Feat: add model reroute notification (#12001)

### Summary
Builiding off
5c75aa7b89 (diff-058ae8f109a8b84b4b79bbfa45f522c2233b9d9e139696044ae374d50b6196e0),
we have created a `model/rerouted` notification that captures the event
so that consumers can render as expected. Keep the `EventMsg::Warning`
path in core so that this does not affect TUI rendering.

`model/rerouted` is meant to be generic to account for future usage
including capacity planning etc.
This commit is contained in:
Shijie Rao
2026-02-17 11:02:23 -08:00
committed by GitHub
parent a1b8e34938
commit 48018e9eac
28 changed files with 605 additions and 146 deletions

View File

@@ -3,8 +3,10 @@ use app_test_support::McpProcess;
use app_test_support::to_response;
use codex_app_server_protocol::ItemCompletedNotification;
use codex_app_server_protocol::ItemStartedNotification;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::ModelRerouteReason;
use codex_app_server_protocol::ModelReroutedNotification;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadStartParams;
@@ -23,7 +25,7 @@ const REQUESTED_MODEL: &str = "gpt-5.1-codex-max";
const SERVER_MODEL: &str = "gpt-5.2-codex";
#[tokio::test]
async fn openai_model_header_mismatch_emits_warning_item_v2() -> Result<()> {
async fn openai_model_header_mismatch_emits_model_rerouted_notification_v2() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = responses::start_mock_server().await;
@@ -64,64 +66,30 @@ async fn openai_model_header_mismatch_emits_warning_item_v2() -> Result<()> {
..Default::default()
})
.await?;
let _turn_resp: JSONRPCResponse = timeout(
let turn_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
)
.await??;
let _turn_start: TurnStartResponse = to_response(_turn_resp)?;
let turn_start: TurnStartResponse = to_response(turn_resp)?;
let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
loop {
let notification: JSONRPCNotification = mcp
.read_stream_until_notification_message("item/started")
.await?;
let params = notification.params.expect("item/started params");
let started: ItemStartedNotification =
serde_json::from_value(params).expect("deserialize item/started");
if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
return Ok::<ItemStartedNotification, anyhow::Error>(started);
}
}
})
.await??;
let warning_text =
warning_text_from_item(&warning_started.item).expect("expected warning user message item");
assert!(warning_text.contains("Warning:"));
assert!(warning_text.contains("gpt-5.2 as a fallback"));
assert!(warning_text.contains("regain access to gpt-5.3-codex"));
let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
loop {
let notification: JSONRPCNotification = mcp
.read_stream_until_notification_message("item/completed")
.await?;
let params = notification.params.expect("item/completed params");
let completed: ItemCompletedNotification =
serde_json::from_value(params).expect("deserialize item/completed");
if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
}
}
})
.await??;
let rerouted = collect_turn_notifications_and_validate_no_warning_item(&mut mcp).await?;
assert_eq!(
warning_text_from_item(&warning_completed.item),
warning_text_from_item(&warning_started.item)
rerouted,
ModelReroutedNotification {
thread_id: thread.id,
turn_id: turn_start.turn.id,
from_model: REQUESTED_MODEL.to_string(),
to_model: SERVER_MODEL.to_string(),
reason: ModelRerouteReason::HighRiskCyberActivity,
}
);
timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;
Ok(())
}
#[tokio::test]
async fn response_model_field_mismatch_emits_warning_item_v2_when_header_matches_requested()
async fn response_model_field_mismatch_emits_model_rerouted_notification_v2_when_header_matches_requested()
-> Result<()> {
skip_if_no_network!(Ok(()));
@@ -174,54 +142,65 @@ async fn response_model_field_mismatch_emits_warning_item_v2_when_header_matches
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
)
.await??;
let _turn_start: TurnStartResponse = to_response(turn_resp)?;
let turn_start: TurnStartResponse = to_response(turn_resp)?;
let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
loop {
let notification: JSONRPCNotification = mcp
.read_stream_until_notification_message("item/started")
.await?;
let params = notification.params.expect("item/started params");
let started: ItemStartedNotification =
serde_json::from_value(params).expect("deserialize item/started");
if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
return Ok::<ItemStartedNotification, anyhow::Error>(started);
}
}
})
.await??;
let warning_text =
warning_text_from_item(&warning_started.item).expect("expected warning user message item");
assert!(warning_text.contains("gpt-5.2 as a fallback"));
let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
loop {
let notification: JSONRPCNotification = mcp
.read_stream_until_notification_message("item/completed")
.await?;
let params = notification.params.expect("item/completed params");
let completed: ItemCompletedNotification =
serde_json::from_value(params).expect("deserialize item/completed");
if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
}
}
})
.await??;
let rerouted = collect_turn_notifications_and_validate_no_warning_item(&mut mcp).await?;
assert_eq!(
warning_text_from_item(&warning_completed.item),
warning_text_from_item(&warning_started.item)
rerouted,
ModelReroutedNotification {
thread_id: thread.id,
turn_id: turn_start.turn.id,
from_model: REQUESTED_MODEL.to_string(),
to_model: SERVER_MODEL.to_string(),
reason: ModelRerouteReason::HighRiskCyberActivity,
}
);
timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;
Ok(())
}
async fn collect_turn_notifications_and_validate_no_warning_item(
mcp: &mut McpProcess,
) -> Result<ModelReroutedNotification> {
let mut rerouted = None;
loop {
let message = timeout(DEFAULT_READ_TIMEOUT, mcp.read_next_message()).await??;
let JSONRPCMessage::Notification(notification) = message else {
continue;
};
match notification.method.as_str() {
"model/rerouted" => {
let params = notification.params.ok_or_else(|| {
anyhow::anyhow!("model/rerouted notifications must include params")
})?;
let payload: ModelReroutedNotification = serde_json::from_value(params)?;
rerouted = Some(payload);
}
"item/started" => {
let params = notification.params.ok_or_else(|| {
anyhow::anyhow!("item/started notifications must include params")
})?;
let payload: ItemStartedNotification = serde_json::from_value(params)?;
assert!(!is_warning_user_message_item(&payload.item));
}
"item/completed" => {
let params = notification.params.ok_or_else(|| {
anyhow::anyhow!("item/completed notifications must include params")
})?;
let payload: ItemCompletedNotification = serde_json::from_value(params)?;
assert!(!is_warning_user_message_item(&payload.item));
}
"turn/completed" => {
return rerouted.ok_or_else(|| {
anyhow::anyhow!("expected model/rerouted notification before turn/completed")
});
}
_ => {}
}
}
}
fn warning_text_from_item(item: &ThreadItem) -> Option<&str> {
let ThreadItem::UserMessage { content, .. } = item else {
return None;
@@ -233,9 +212,8 @@ fn warning_text_from_item(item: &ThreadItem) -> Option<&str> {
})
}
fn is_cyber_model_warning_text(text: &str) -> bool {
text.contains("flagged for potentially high-risk cyber activity")
&& text.contains("apply for trusted access: https://chatgpt.com/cyber")
fn is_warning_user_message_item(item: &ThreadItem) -> bool {
warning_text_from_item(item).is_some()
}
fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> {