Compare commits

...

5 Commits

Author SHA1 Message Date
Roy Han
3dd791ecd5 fix image rollback test lint 2026-05-20 15:04:20 -07:00
rhan-oai
02fbf02750 Merge branch 'main' into rhan/fix-image-poisoning 2026-05-20 14:33:14 -07:00
Roy Han
9f80e1643d fix invalid image rollback in debug builds 2026-05-20 14:02:24 -07:00
rhan-oai
d651a764cb Update api_bridge.rs 2026-05-20 12:25:32 -07:00
Roy Han
ef5acb20ad repair invalid image poisoning across resume 2026-05-20 11:15:23 -07:00
7 changed files with 264 additions and 178 deletions

View File

@@ -56,8 +56,8 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
}
if status == http::StatusCode::BAD_REQUEST {
if let Ok(parsed) = serde_json::from_str::<Value>(&body_text)
&& let Some(error) = parsed.get("error")
let parsed = serde_json::from_str::<Value>(&body_text).ok();
if let Some(error) = parsed.as_ref().and_then(|parsed| parsed.get("error"))
&& error.get("code").and_then(Value::as_str)
== Some(CYBER_POLICY_ERROR_CODE)
{
@@ -68,8 +68,19 @@ pub fn map_api_error(err: ApiError) -> CodexErr {
.map(str::to_string)
.unwrap_or_else(|| CYBER_POLICY_FALLBACK_MESSAGE.to_string());
CodexErr::CyberPolicy { message }
} else if body_text
.contains("The image data you provided does not represent a valid image")
} else if body_text.contains(INVALID_IMAGE_ERROR_MESSAGE)
|| body_text.contains(EMPTY_BASE64_IMAGE_ERROR_MESSAGE)
|| parsed
.as_ref()
.and_then(|parsed| parsed.get("error"))
.is_some_and(|error| {
error.get("code").and_then(Value::as_str)
== Some(INVALID_VALUE_ERROR_CODE)
&& error
.get("param")
.and_then(Value::as_str)
.is_some_and(|param| param.ends_with(".image_url"))
})
{
CodexErr::InvalidImageRequest()
} else {
@@ -141,6 +152,10 @@ const X_ERROR_JSON_HEADER: &str = "x-error-json";
const CYBER_POLICY_ERROR_CODE: &str = "cyber_policy";
const CYBER_POLICY_FALLBACK_MESSAGE: &str =
"This request has been flagged for possible cybersecurity risk.";
const INVALID_VALUE_ERROR_CODE: &str = "invalid_value";
const INVALID_IMAGE_ERROR_MESSAGE: &str =
"The image data you provided does not represent a valid image";
const EMPTY_BASE64_IMAGE_ERROR_MESSAGE: &str = "Expected a base64-encoded data URL with an image MIME type, but got empty base64-encoded bytes.";
#[cfg(test)]
#[path = "api_bridge_tests.rs"]

View File

@@ -102,6 +102,50 @@ fn map_api_error_uses_cyber_policy_fallback_for_missing_message() {
);
}
#[test]
fn map_api_error_maps_invalid_image_param_from_400_body() {
let body = serde_json::json!({
"error": {
"message": "Invalid inline image.",
"type": "invalid_request_error",
"param": "input[13].content[2].image_url",
"code": "invalid_value"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: None,
body: Some(body),
}));
assert!(matches!(err, CodexErr::InvalidImageRequest()));
}
#[test]
fn map_api_error_maps_empty_base64_image_message_from_400_body() {
let body = serde_json::json!({
"error": {
"message": format!(
"Invalid 'input[0].content[2].image_url': {EMPTY_BASE64_IMAGE_ERROR_MESSAGE}"
),
"type": "invalid_request_error",
"param": null,
"code": "invalid_value"
}
})
.to_string();
let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::BAD_REQUEST,
url: Some("http://example.com/v1/responses".to_string()),
headers: None,
body: Some(body),
}));
assert!(matches!(err, CodexErr::InvalidImageRequest()));
}
#[test]
fn map_api_error_keeps_unknown_400_errors_generic() {
let body = serde_json::json!({

View File

@@ -184,40 +184,6 @@ impl ContextManager {
self.history_version = self.history_version.saturating_add(1);
}
/// Replace image content in the last turn if it originated from a tool output.
/// Returns true when a tool image was replaced, false otherwise.
pub(crate) fn replace_last_turn_images(&mut self, placeholder: &str) -> bool {
let Some(index) = self.items.iter().rposition(|item| {
matches!(item, ResponseItem::FunctionCallOutput { .. }) || is_user_turn_boundary(item)
}) else {
return false;
};
match &mut self.items[index] {
ResponseItem::FunctionCallOutput { output, .. } => {
let Some(content_items) = output.content_items_mut() else {
return false;
};
let mut replaced = false;
let placeholder = placeholder.to_string();
for item in content_items.iter_mut() {
if matches!(item, FunctionCallOutputContentItem::InputImage { .. }) {
*item = FunctionCallOutputContentItem::InputText {
text: placeholder.clone(),
};
replaced = true;
}
}
if replaced {
self.history_version = self.history_version.saturating_add(1);
}
replaced
}
ResponseItem::Message { .. } => false,
_ => false,
}
}
/// Drop the last `num_turns` instruction turns from this history.
///
/// Instruction turns are history messages that should behave like a new prompt boundary:

View File

@@ -671,63 +671,6 @@ fn remove_last_item_removes_matching_call_for_output() {
assert_eq!(h.raw_items(), vec![user_msg("before tool call")]);
}
#[test]
fn replace_last_turn_images_replaces_tool_output_images() {
let items = vec![
user_input_text_msg("hi"),
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::ContentItems(vec![
FunctionCallOutputContentItem::InputImage {
image_url: "data:image/png;base64,AAA".to_string(),
detail: Some(DEFAULT_IMAGE_DETAIL),
},
]),
success: Some(true),
},
},
];
let mut history = create_history_with_items(items);
assert!(history.replace_last_turn_images("Invalid image"));
assert_eq!(
history.raw_items(),
vec![
user_input_text_msg("hi"),
ResponseItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::ContentItems(vec![
FunctionCallOutputContentItem::InputText {
text: "Invalid image".to_string(),
},
]),
success: Some(true),
},
},
]
);
}
#[test]
fn replace_last_turn_images_does_not_touch_user_images() {
let items = vec![ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::InputImage {
image_url: "data:image/png;base64,AAA".to_string(),
detail: Some(DEFAULT_IMAGE_DETAIL),
}],
phase: None,
}];
let mut history = create_history_with_items(items.clone());
assert!(!history.replace_last_turn_images("Invalid image"));
assert_eq!(history.raw_items(), items);
}
#[test]
fn remove_first_item_handles_local_shell_pair() {
let items = vec![

View File

@@ -12,6 +12,7 @@ use tracing::info_span;
use crate::session::SteerInputError;
use crate::session::session::Session;
use crate::session::session::SessionSettingsUpdate;
use crate::session::turn_context::TurnContext;
use crate::config::Config;
use crate::realtime_context::REALTIME_TURN_TOKEN_BUDGET;
@@ -499,61 +500,23 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
}
let turn_context = sess.new_default_turn_with_sub_id(sub_id).await;
let live_thread = match sess.live_thread_for_persistence("rollback thread") {
Ok(live_thread) => live_thread,
Err(_) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: "thread rollback requires persisted thread history".to_string(),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
if let Err(err) = live_thread.flush().await {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: format!("failed to flush thread persistence for rollback replay: {err}"),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
let (rollback_msg, flush_error) =
match apply_thread_rollback(sess, &turn_context, num_turns).await {
Ok(outcome) => outcome,
Err(message) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message,
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
let stored_history = match live_thread.load_history(/*include_archived*/ false).await {
Ok(history) => history,
Err(err) => {
sess.send_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: EventMsg::Error(ErrorEvent {
message: format!("failed to load thread history for rollback replay: {err}"),
codex_error_info: Some(CodexErrorInfo::ThreadRollbackFailed),
}),
})
.await;
return;
}
};
let rollback_event = ThreadRolledBackEvent { num_turns };
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event.clone());
let replay_items = stored_history
.items
.into_iter()
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
.collect::<Vec<_>>();
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;
if let Err(err) = sess.flush_rollout().await {
if let Some(err) = flush_error {
sess.send_event(
turn_context.as_ref(),
EventMsg::Warning(WarningEvent {
@@ -572,6 +535,40 @@ pub async fn thread_rollback(sess: &Arc<Session>, sub_id: String, num_turns: u32
.await;
}
pub(super) async fn apply_thread_rollback(
sess: &Arc<Session>,
turn_context: &Arc<TurnContext>,
num_turns: u32,
) -> Result<(EventMsg, Option<String>), String> {
let live_thread = sess
.live_thread_for_persistence("rollback thread")
.map_err(|_| "thread rollback requires persisted thread history".to_string())?;
live_thread
.flush()
.await
.map_err(|err| format!("failed to flush thread persistence for rollback replay: {err}"))?;
let stored_history = live_thread
.load_history(/*include_archived*/ false)
.await
.map_err(|err| format!("failed to load thread history for rollback replay: {err}"))?;
let rollback_event = ThreadRolledBackEvent { num_turns };
let rollback_msg = EventMsg::ThreadRolledBack(rollback_event);
let replay_items = stored_history
.items
.into_iter()
.chain(std::iter::once(RolloutItem::EventMsg(rollback_msg.clone())))
.collect::<Vec<_>>();
sess.apply_rollout_reconstruction(turn_context.as_ref(), replay_items.as_slice())
.await;
sess.recompute_token_usage(turn_context.as_ref()).await;
sess.persist_rollout_items(&[RolloutItem::EventMsg(rollback_msg.clone())])
.await;
let flush_error = sess.flush_rollout().await.err().map(|err| err.to_string());
Ok((rollback_msg, flush_error))
}
pub(super) async fn persist_thread_memory_mode_update(
sess: &Arc<Session>,
mode: ThreadMemoryMode,

View File

@@ -37,6 +37,7 @@ use crate::mentions::collect_tool_mentions_from_messages;
use crate::plugins::build_plugin_injections;
use crate::session::PreviousTurnSettings;
use crate::session::TurnInput;
use crate::session::handlers::apply_thread_rollback;
use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::stream_events_utils::HandleOutputCtx;
@@ -87,6 +88,7 @@ use codex_protocol::protocol::AgentMessageContentDeltaEvent;
use codex_protocol::protocol::AgentReasoningSectionBreakEvent;
use codex_protocol::protocol::CodexErrorInfo;
use codex_protocol::protocol::ErrorEvent;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::PlanDeltaEvent;
use codex_protocol::protocol::ReasoningContentDeltaEvent;
@@ -410,19 +412,42 @@ pub(crate) async fn run_turn(
break;
}
Err(CodexErr::InvalidImageRequest()) => {
error!(
"Invalid image detected; rolling back the current turn to prevent poisoning",
);
let rollback_succeeded = match apply_thread_rollback(
&sess,
&turn_context,
/*num_turns*/ 1,
)
.await
{
let mut state = sess.state.lock().await;
error_or_panic(
"Invalid image detected; sanitizing tool output to prevent poisoning",
);
if state.history.replace_last_turn_images("Invalid image") {
continue;
Ok((rollback_msg, flush_error)) => {
if let Some(err) = flush_error {
warn!(
"rolled back invalid-image turn in memory, but failed to flush rollback marker: {err}"
);
}
sess.deliver_event_raw(Event {
id: turn_context.sub_id.clone(),
msg: rollback_msg,
})
.await;
true
}
}
Err(err) => {
warn!("failed to rollback invalid-image turn: {err}");
false
}
};
let event = EventMsg::Error(ErrorEvent {
message: "Invalid image in your last message. Please remove it and try again."
.to_string(),
message: if rollback_succeeded {
"This turn contained invalid image data and was rolled back. Please retry."
.to_string()
} else {
"This turn contained invalid image data and could not be repaired automatically. Please retry."
.to_string()
},
codex_error_info: Some(CodexErrorInfo::BadRequest),
});
sess.send_event(&turn_context, event).await;

View File

@@ -3,6 +3,7 @@
use anyhow::Context;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use codex_core::RolloutRecorder;
use codex_exec_server::CreateDirectoryOptions;
use codex_exec_server::LOCAL_ENVIRONMENT_ID;
use codex_exec_server::REMOTE_ENVIRONMENT_ID;
@@ -25,7 +26,9 @@ use codex_protocol::permissions::FileSystemSandboxPolicy;
use codex_protocol::permissions::NetworkSandboxPolicy;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::InitialHistory;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::TurnEnvironmentSelection;
use codex_protocol::user_input::UserInput;
use core_test_support::PathBufExt;
@@ -62,9 +65,7 @@ use tempfile::TempDir;
use tokio::time::Duration;
use wiremock::BodyPrintLimit;
use wiremock::MockServer;
#[cfg(not(debug_assertions))]
use wiremock::ResponseTemplate;
#[cfg(not(debug_assertions))]
use wiremock::matchers::body_string_contains;
const VIEW_IMAGE_TURN_COMPLETE_TIMEOUT: Duration = Duration::from_secs(30);
@@ -1448,9 +1449,8 @@ async fn view_image_tool_returns_unsupported_message_for_text_only_model() -> an
Ok(())
}
#[cfg(not(debug_assertions))]
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()> {
async fn rolls_back_invalid_local_image_after_bad_request() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
@@ -1466,14 +1466,27 @@ async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()>
.set_body_string(INVALID_IMAGE_ERROR),
)
.await;
let unexpected_retry_mock = responses::mount_sse_once_match(
&server,
body_string_contains("Invalid image"),
sse(vec![
ev_response_created("resp-unexpected"),
ev_assistant_message("msg-unexpected", "unexpected retry"),
ev_completed("resp-unexpected"),
]),
)
.await;
let success_response = sse(vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "done"),
ev_completed("resp-2"),
]);
let completion_mock = responses::mount_sse_once(&server, success_response).await;
let resumed_completion_mock = responses::mount_sse_once_match(
&server,
body_string_contains("after resume"),
sse(vec![
ev_response_created("resp-2"),
ev_assistant_message("msg-1", "still done"),
ev_completed("resp-2"),
]),
)
.await;
let mut builder = test_codex();
let test = builder.build_with_remote_env(&server).await?;
@@ -1484,9 +1497,21 @@ async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()>
} = &test;
let rel_path = "assets/poisoned.png";
let abs_path = write_workspace_png(&test, rel_path, 1024, 512, [10u8, 20, 30, 255]).await?;
let abs_path = write_workspace_png(
&test,
rel_path,
/*width*/ 1024,
/*height*/ 512,
[10u8, 20, 30, 255],
)
.await?;
let session_model = session_configured.model.clone();
let rollout_path = session_configured
.rollout_path
.clone()
.expect("rollout path");
let home = test.home.clone();
codex
.submit(disabled_user_turn(
@@ -1499,8 +1524,33 @@ async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()>
))
.await?;
let rollback_event = wait_for_event_with_timeout(
codex,
|event| matches!(event, EventMsg::ThreadRolledBack(_)),
VIEW_IMAGE_TURN_COMPLETE_TIMEOUT,
)
.await;
let EventMsg::ThreadRolledBack(rollback) = rollback_event else {
unreachable!()
};
assert_eq!(rollback.num_turns, 1);
let error_event = wait_for_event_with_timeout(
codex,
|event| matches!(event, EventMsg::Error(_)),
VIEW_IMAGE_TURN_COMPLETE_TIMEOUT,
)
.await;
let EventMsg::Error(error) = error_event else {
unreachable!()
};
assert_eq!(
error.message,
"This turn contained invalid image data and was rolled back. Please retry."
);
wait_for_event_with_timeout(
&codex,
codex,
|event| matches!(event, EventMsg::TurnComplete(_)),
VIEW_IMAGE_TURN_COMPLETE_TIMEOUT,
)
@@ -1512,14 +1562,60 @@ async fn replaces_invalid_local_image_after_bad_request() -> anyhow::Result<()>
"initial request should include the uploaded image"
);
let second_request = completion_mock.single_request();
let second_body = second_request.body_json();
assert!(
find_image_message(&second_body).is_none(),
"second request should replace the invalid image"
unexpected_retry_mock.requests().is_empty(),
"invalid-image recovery should rollback instead of retrying with sanitized text"
);
let InitialHistory::Resumed(resumed_history) =
RolloutRecorder::get_rollout_history(&rollout_path).await?
else {
panic!("expected resumed rollout history");
};
assert!(resumed_history.history.iter().any(|item| {
matches!(
item,
RolloutItem::EventMsg(EventMsg::ThreadRolledBack(rollback))
if rollback.num_turns == 1
)
}));
let mut resume_builder = test_codex();
let resumed = resume_builder.resume(&server, home, rollout_path).await?;
resumed
.codex
.submit(disabled_user_turn(
&resumed,
vec![UserInput::Text {
text: "after resume".to_string(),
text_elements: Vec::new(),
}],
resumed.session_configured.model.clone(),
))
.await?;
wait_for_event_with_timeout(
&resumed.codex,
|event| matches!(event, EventMsg::TurnComplete(_)),
VIEW_IMAGE_TURN_COMPLETE_TIMEOUT,
)
.await;
let resumed_request = resumed_completion_mock.single_request();
assert!(
find_image_message(&resumed_request.body_json()).is_none(),
"resumed request should not replay the rolled-back image turn"
);
let resumed_user_texts = resumed_request.message_input_texts("user");
assert!(
resumed_user_texts.iter().any(|text| text == "after resume"),
"resumed request should contain only the new follow-up text"
);
assert!(
resumed_user_texts
.iter()
.all(|text| text != "Invalid image"),
"rolled-back turns should not replay placeholder text"
);
let user_texts = second_request.message_input_texts("user");
assert!(user_texts.iter().any(|text| text == "Invalid image"));
Ok(())
}