mirror of
https://github.com/openai/codex.git
synced 2026-04-27 08:05:51 +00:00
Feat: add model reroute notification (#12001)
### Summary
Builiding off
5c75aa7b89 (diff-058ae8f109a8b84b4b79bbfa45f522c2233b9d9e139696044ae374d50b6196e0),
we have created a `model/rerouted` notification that captures the event
so that consumers can render as expected. Keep the `EventMsg::Warning`
path in core so that this does not affect TUI rendering.
`model/rerouted` is meant to be generic to account for future usage
including capacity planning etc.
This commit is contained in:
@@ -3,8 +3,10 @@ use app_test_support::McpProcess;
|
||||
use app_test_support::to_response;
|
||||
use codex_app_server_protocol::ItemCompletedNotification;
|
||||
use codex_app_server_protocol::ItemStartedNotification;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::JSONRPCResponse;
|
||||
use codex_app_server_protocol::ModelRerouteReason;
|
||||
use codex_app_server_protocol::ModelReroutedNotification;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
@@ -23,7 +25,7 @@ const REQUESTED_MODEL: &str = "gpt-5.1-codex-max";
|
||||
const SERVER_MODEL: &str = "gpt-5.2-codex";
|
||||
|
||||
#[tokio::test]
|
||||
async fn openai_model_header_mismatch_emits_warning_item_v2() -> Result<()> {
|
||||
async fn openai_model_header_mismatch_emits_model_rerouted_notification_v2() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = responses::start_mock_server().await;
|
||||
@@ -64,64 +66,30 @@ async fn openai_model_header_mismatch_emits_warning_item_v2() -> Result<()> {
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let _turn_resp: JSONRPCResponse = timeout(
|
||||
let turn_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
|
||||
)
|
||||
.await??;
|
||||
let _turn_start: TurnStartResponse = to_response(_turn_resp)?;
|
||||
let turn_start: TurnStartResponse = to_response(turn_resp)?;
|
||||
|
||||
let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notification: JSONRPCNotification = mcp
|
||||
.read_stream_until_notification_message("item/started")
|
||||
.await?;
|
||||
let params = notification.params.expect("item/started params");
|
||||
let started: ItemStartedNotification =
|
||||
serde_json::from_value(params).expect("deserialize item/started");
|
||||
if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
|
||||
return Ok::<ItemStartedNotification, anyhow::Error>(started);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
|
||||
let warning_text =
|
||||
warning_text_from_item(&warning_started.item).expect("expected warning user message item");
|
||||
assert!(warning_text.contains("Warning:"));
|
||||
assert!(warning_text.contains("gpt-5.2 as a fallback"));
|
||||
assert!(warning_text.contains("regain access to gpt-5.3-codex"));
|
||||
|
||||
let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notification: JSONRPCNotification = mcp
|
||||
.read_stream_until_notification_message("item/completed")
|
||||
.await?;
|
||||
let params = notification.params.expect("item/completed params");
|
||||
let completed: ItemCompletedNotification =
|
||||
serde_json::from_value(params).expect("deserialize item/completed");
|
||||
if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
|
||||
return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let rerouted = collect_turn_notifications_and_validate_no_warning_item(&mut mcp).await?;
|
||||
assert_eq!(
|
||||
warning_text_from_item(&warning_completed.item),
|
||||
warning_text_from_item(&warning_started.item)
|
||||
rerouted,
|
||||
ModelReroutedNotification {
|
||||
thread_id: thread.id,
|
||||
turn_id: turn_start.turn.id,
|
||||
from_model: REQUESTED_MODEL.to_string(),
|
||||
to_model: SERVER_MODEL.to_string(),
|
||||
reason: ModelRerouteReason::HighRiskCyberActivity,
|
||||
}
|
||||
);
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn response_model_field_mismatch_emits_warning_item_v2_when_header_matches_requested()
|
||||
async fn response_model_field_mismatch_emits_model_rerouted_notification_v2_when_header_matches_requested()
|
||||
-> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
@@ -174,54 +142,65 @@ async fn response_model_field_mismatch_emits_warning_item_v2_when_header_matches
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
|
||||
)
|
||||
.await??;
|
||||
let _turn_start: TurnStartResponse = to_response(turn_resp)?;
|
||||
let turn_start: TurnStartResponse = to_response(turn_resp)?;
|
||||
|
||||
let warning_started = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notification: JSONRPCNotification = mcp
|
||||
.read_stream_until_notification_message("item/started")
|
||||
.await?;
|
||||
let params = notification.params.expect("item/started params");
|
||||
let started: ItemStartedNotification =
|
||||
serde_json::from_value(params).expect("deserialize item/started");
|
||||
if warning_text_from_item(&started.item).is_some_and(is_cyber_model_warning_text) {
|
||||
return Ok::<ItemStartedNotification, anyhow::Error>(started);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let warning_text =
|
||||
warning_text_from_item(&warning_started.item).expect("expected warning user message item");
|
||||
assert!(warning_text.contains("gpt-5.2 as a fallback"));
|
||||
|
||||
let warning_completed = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notification: JSONRPCNotification = mcp
|
||||
.read_stream_until_notification_message("item/completed")
|
||||
.await?;
|
||||
let params = notification.params.expect("item/completed params");
|
||||
let completed: ItemCompletedNotification =
|
||||
serde_json::from_value(params).expect("deserialize item/completed");
|
||||
if warning_text_from_item(&completed.item).is_some_and(is_cyber_model_warning_text) {
|
||||
return Ok::<ItemCompletedNotification, anyhow::Error>(completed);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let rerouted = collect_turn_notifications_and_validate_no_warning_item(&mut mcp).await?;
|
||||
assert_eq!(
|
||||
warning_text_from_item(&warning_completed.item),
|
||||
warning_text_from_item(&warning_started.item)
|
||||
rerouted,
|
||||
ModelReroutedNotification {
|
||||
thread_id: thread.id,
|
||||
turn_id: turn_start.turn.id,
|
||||
from_model: REQUESTED_MODEL.to_string(),
|
||||
to_model: SERVER_MODEL.to_string(),
|
||||
reason: ModelRerouteReason::HighRiskCyberActivity,
|
||||
}
|
||||
);
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn collect_turn_notifications_and_validate_no_warning_item(
|
||||
mcp: &mut McpProcess,
|
||||
) -> Result<ModelReroutedNotification> {
|
||||
let mut rerouted = None;
|
||||
|
||||
loop {
|
||||
let message = timeout(DEFAULT_READ_TIMEOUT, mcp.read_next_message()).await??;
|
||||
let JSONRPCMessage::Notification(notification) = message else {
|
||||
continue;
|
||||
};
|
||||
match notification.method.as_str() {
|
||||
"model/rerouted" => {
|
||||
let params = notification.params.ok_or_else(|| {
|
||||
anyhow::anyhow!("model/rerouted notifications must include params")
|
||||
})?;
|
||||
let payload: ModelReroutedNotification = serde_json::from_value(params)?;
|
||||
rerouted = Some(payload);
|
||||
}
|
||||
"item/started" => {
|
||||
let params = notification.params.ok_or_else(|| {
|
||||
anyhow::anyhow!("item/started notifications must include params")
|
||||
})?;
|
||||
let payload: ItemStartedNotification = serde_json::from_value(params)?;
|
||||
assert!(!is_warning_user_message_item(&payload.item));
|
||||
}
|
||||
"item/completed" => {
|
||||
let params = notification.params.ok_or_else(|| {
|
||||
anyhow::anyhow!("item/completed notifications must include params")
|
||||
})?;
|
||||
let payload: ItemCompletedNotification = serde_json::from_value(params)?;
|
||||
assert!(!is_warning_user_message_item(&payload.item));
|
||||
}
|
||||
"turn/completed" => {
|
||||
return rerouted.ok_or_else(|| {
|
||||
anyhow::anyhow!("expected model/rerouted notification before turn/completed")
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn warning_text_from_item(item: &ThreadItem) -> Option<&str> {
|
||||
let ThreadItem::UserMessage { content, .. } = item else {
|
||||
return None;
|
||||
@@ -233,9 +212,8 @@ fn warning_text_from_item(item: &ThreadItem) -> Option<&str> {
|
||||
})
|
||||
}
|
||||
|
||||
fn is_cyber_model_warning_text(text: &str) -> bool {
|
||||
text.contains("flagged for potentially high-risk cyber activity")
|
||||
&& text.contains("apply for trusted access: https://chatgpt.com/cyber")
|
||||
fn is_warning_user_message_item(item: &ThreadItem) -> bool {
|
||||
warning_text_from_item(item).is_some()
|
||||
}
|
||||
|
||||
fn create_config_toml(codex_home: &std::path::Path, server_uri: &str) -> std::io::Result<()> {
|
||||
|
||||
Reference in New Issue
Block a user