Compare commits

...

1 Commits

Author SHA1 Message Date
Noah MacCallum
b861e9ea81 Enable image generation continuation 2026-05-11 16:02:41 -07:00
9 changed files with 223 additions and 16 deletions

View File

@@ -460,6 +460,9 @@
"image_generation": {
"type": "boolean"
},
"image_generation_continuation": {
"type": "boolean"
},
"in_app_browser": {
"type": "boolean"
},
@@ -4102,6 +4105,9 @@
"image_generation": {
"type": "boolean"
},
"image_generation_continuation": {
"type": "boolean"
},
"in_app_browser": {
"type": "boolean"
},

View File

@@ -5,13 +5,25 @@ use std::fmt::Display;
pub(crate) struct ImageGenerationInstructions {
image_output_dir: String,
image_output_path: String,
call_id: Option<String>,
saved_path: Option<String>,
revised_prompt: Option<String>,
}
impl ImageGenerationInstructions {
pub(crate) fn new(image_output_dir: impl Display, image_output_path: impl Display) -> Self {
pub(crate) fn for_generated_image(
image_output_dir: impl Display,
image_output_path: impl Display,
call_id: impl Display,
saved_path: impl Display,
revised_prompt: Option<&str>,
) -> Self {
Self {
image_output_dir: image_output_dir.to_string(),
image_output_path: image_output_path.to_string(),
call_id: Some(call_id.to_string()),
saved_path: Some(saved_path.to_string()),
revised_prompt: revised_prompt.map(str::to_string),
}
}
}
@@ -22,9 +34,25 @@ impl ContextualUserFragment for ImageGenerationInstructions {
const END_MARKER: &'static str = "";
fn body(&self) -> String {
format!(
let mut body = format!(
"Generated images are saved to {} as {} by default.\nIf you need to use a generated image at another path, copy it and leave the original in place unless the user explicitly asks you to delete it.",
self.image_output_dir, self.image_output_path
)
);
if let (Some(call_id), Some(saved_path)) = (&self.call_id, &self.saved_path) {
body.push_str(&format!(
"\n\nThe most recent image_generation_call completed and was saved locally.\nArtifact metadata:\n- call_id: {call_id}\n- saved_path: {saved_path}"
));
if let Some(revised_prompt) = self.revised_prompt.as_deref()
&& !revised_prompt.is_empty()
{
body.push_str(&format!("\n- revised_prompt: {revised_prompt}"));
}
body.push_str(
"\n\nContinue the workflow now. If caller, user, or skill instructions ask for final message text alongside the image, provide it. If the generated image alone fully satisfies the request and no follow-up text is needed, keep any final response minimal or empty.",
);
}
body
}
}

View File

@@ -6594,9 +6594,11 @@ async fn handle_output_item_done_records_image_save_history_message() {
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
cancellation_token: CancellationToken::new(),
};
handle_output_item_done(&mut ctx, item.clone(), /*previously_active_item*/ None)
.await
.expect("image generation item should succeed");
let output =
handle_output_item_done(&mut ctx, item.clone(), /*previously_active_item*/ None)
.await
.expect("image generation item should succeed");
assert!(output.image_generation_follow_up_requested);
let history = session.clone_history().await;
let image_output_path = crate::stream_events_utils::image_generation_artifact_path(
@@ -6608,9 +6610,12 @@ async fn handle_output_item_done_records_image_save_history_message() {
.parent()
.expect("generated image path should have a parent");
let image_message: ResponseItem = crate::context::ContextualUserFragment::into(
crate::context::ImageGenerationInstructions::new(
crate::context::ImageGenerationInstructions::for_generated_image(
image_output_dir.display(),
image_output_path.display(),
call_id,
expected_saved_path.display(),
Some("a tiny blue square"),
),
);
assert_eq!(history.raw_items(), &[image_message, item]);
@@ -6646,9 +6651,11 @@ async fn handle_output_item_done_skips_image_save_message_when_save_fails() {
tool_runtime: test_tool_runtime(Arc::clone(&session), Arc::clone(&turn_context)),
cancellation_token: CancellationToken::new(),
};
handle_output_item_done(&mut ctx, item.clone(), /*previously_active_item*/ None)
.await
.expect("image generation item should still complete");
let output =
handle_output_item_done(&mut ctx, item.clone(), /*previously_active_item*/ None)
.await
.expect("image generation item should still complete");
assert!(!output.image_generation_follow_up_requested);
let history = session.clone_history().await;
assert_eq!(history.raw_items(), &[item]);

View File

@@ -119,6 +119,8 @@ use tracing::trace;
use tracing::trace_span;
use tracing::warn;
const MAX_IMAGE_GENERATION_FOLLOW_UPS_PER_TURN: usize = 1;
/// Takes a user message as input and runs a loop where, at each sampling request, the model
/// replies with either:
///
@@ -380,6 +382,7 @@ pub(crate) async fn run_turn(
// 1. At the start of a turn, so the fresh user prompt in `input` gets sampled first.
// 2. After auto-compact, when model/tool continuation needs to resume before any steer.
let mut can_drain_pending_input = input.is_empty();
let mut image_generation_follow_ups = 0usize;
loop {
if run_pending_session_start_hooks(&sess, &turn_context).await {
@@ -466,10 +469,20 @@ pub(crate) async fn run_turn(
{
Ok(sampling_request_output) => {
let SamplingRequestResult {
needs_follow_up: model_needs_follow_up,
needs_follow_up: sampling_needs_follow_up,
image_generation_follow_up_requested,
last_agent_message: sampling_request_last_agent_message,
} = sampling_request_output;
can_drain_pending_input = true;
let image_generation_follow_up = image_generation_follow_up_requested
&& turn_context
.features
.enabled(Feature::ImageGenerationContinuation)
&& image_generation_follow_ups < MAX_IMAGE_GENERATION_FOLLOW_UPS_PER_TURN;
if image_generation_follow_up {
image_generation_follow_ups += 1;
}
let model_needs_follow_up = sampling_needs_follow_up || image_generation_follow_up;
let has_pending_input = sess.has_pending_input().await;
let needs_follow_up = model_needs_follow_up || has_pending_input;
let total_usage_tokens = sess.get_total_token_usage().await;
@@ -485,6 +498,9 @@ pub(crate) async fn run_turn(
auto_compact_limit,
token_limit_reached,
model_needs_follow_up,
sampling_needs_follow_up,
image_generation_follow_up_requested,
image_generation_follow_up,
has_pending_input,
needs_follow_up,
"post sampling token usage"
@@ -1278,6 +1294,7 @@ pub(crate) async fn built_tools(
#[derive(Debug)]
struct SamplingRequestResult {
needs_follow_up: bool,
image_generation_follow_up_requested: bool,
last_agent_message: Option<String>,
}
@@ -1867,6 +1884,7 @@ async fn try_run_sampling_request(
let mut in_flight: FuturesOrdered<BoxFuture<'static, CodexResult<ResponseInputItem>>> =
FuturesOrdered::new();
let mut needs_follow_up = false;
let mut image_generation_follow_up_requested = false;
let mut last_agent_message: Option<String> = None;
let mut active_item: Option<TurnItem> = None;
let mut active_tool_argument_diff_consumer: Option<(
@@ -1998,10 +2016,13 @@ async fn try_run_sampling_request(
last_agent_message = Some(agent_message);
}
needs_follow_up |= output_result.needs_follow_up;
image_generation_follow_up_requested |=
output_result.image_generation_follow_up_requested;
// todo: remove before stabilizing multi-agent v2
if preempt_for_mailbox_mail && sess.mailbox_rx.lock().await.has_pending() {
break Ok(SamplingRequestResult {
needs_follow_up: true,
image_generation_follow_up_requested: false,
last_agent_message,
});
}
@@ -2127,6 +2148,7 @@ async fn try_run_sampling_request(
completed_response_id = Some(response_id);
break Ok(SamplingRequestResult {
needs_follow_up,
image_generation_follow_up_requested,
last_agent_message,
});
}

View File

@@ -209,6 +209,7 @@ pub(crate) type InFlightFuture<'f> =
pub(crate) struct OutputItemResult {
pub last_agent_message: Option<String>,
pub needs_follow_up: bool,
pub image_generation_follow_up_requested: bool,
pub tool_future: Option<InFlightFuture<'static>>,
}
@@ -266,6 +267,10 @@ pub(crate) async fn handle_output_item_done(
)
.await;
if let Some(turn_item) = turn_item {
let image_generation_follow_up_requested = matches!(
&turn_item,
TurnItem::ImageGeneration(item) if item.saved_path.is_some()
);
if previously_active_item.is_none() {
let mut started_item = turn_item.clone();
if let TurnItem::ImageGeneration(item) = &mut started_item {
@@ -282,6 +287,7 @@ pub(crate) async fn handle_output_item_done(
ctx.sess
.emit_turn_item_completed(&ctx.turn_context, turn_item)
.await;
output.image_generation_follow_up_requested |= image_generation_follow_up_requested;
}
record_completed_response_item(ctx.sess.as_ref(), ctx.turn_context.as_ref(), &item)
.await;
@@ -387,7 +393,7 @@ pub(crate) async fn handle_non_tool_response_item(
.await
{
Ok(path) => {
image_item.saved_path = Some(path);
image_item.saved_path = Some(path.clone());
let image_output_path = image_generation_artifact_path(
&turn_context.config.codex_home,
&session_id,
@@ -396,11 +402,15 @@ pub(crate) async fn handle_non_tool_response_item(
let image_output_dir = image_output_path
.parent()
.unwrap_or_else(|| turn_context.config.codex_home.clone());
let message: ResponseItem =
ContextualUserFragment::into(ImageGenerationInstructions::new(
let message: ResponseItem = ContextualUserFragment::into(
ImageGenerationInstructions::for_generated_image(
image_output_dir.display(),
image_output_path.display(),
));
&image_item.id,
path.display(),
image_item.revised_prompt.as_deref(),
),
);
sess.record_conversation_items(turn_context, &[message])
.await;
}

View File

@@ -1,6 +1,7 @@
#![cfg(not(target_os = "windows"))]
use anyhow::Ok;
use codex_features::Feature;
use codex_protocol::config_types::CollaborationMode;
use codex_protocol::config_types::ModeKind;
use codex_protocol::config_types::Settings;
@@ -30,6 +31,7 @@ use core_test_support::responses::ev_response_created;
use core_test_support::responses::ev_web_search_call_added_partial;
use core_test_support::responses::ev_web_search_call_done;
use core_test_support::responses::mount_sse_once;
use core_test_support::responses::mount_sse_sequence;
use core_test_support::responses::sse;
use core_test_support::responses::start_mock_server;
use core_test_support::skip_if_no_network;
@@ -368,7 +370,21 @@ async fn image_generation_call_event_is_emitted() -> anyhow::Result<()> {
ev_image_generation_call(call_id, "completed", "A tiny blue square", "Zm9v"),
ev_completed("resp-1"),
]);
mount_sse_once(&server, first_response).await;
let responses = mount_sse_sequence(
&server,
vec![
first_response,
sse(vec![
ev_response_created("resp-2"),
ev_assistant_message(
"msg-1",
"The image is ready and the saved path is available.",
),
ev_completed("resp-2"),
]),
],
)
.await;
codex
.submit(Op::UserInput {
@@ -424,6 +440,97 @@ async fn image_generation_call_event_is_emitted() -> anyhow::Result<()> {
end.saved_path.as_ref().map(AbsolutePathBuf::as_path),
Some(expected_saved_path.as_path())
);
let final_message = wait_for_event_match(&codex, |ev| match ev {
EventMsg::ItemCompleted(ItemCompletedEvent {
item: TurnItem::AgentMessage(item),
..
}) => Some(item.clone()),
_ => None,
})
.await;
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
let final_text = final_message
.content
.iter()
.map(|entry| match entry {
AgentMessageContent::Text { text } => text.as_str(),
})
.collect::<String>();
assert_eq!(
final_text,
"The image is ready and the saved path is available."
);
let requests = responses.requests();
assert_eq!(requests.len(), 2);
let second_request_developer_texts = requests[1].message_input_texts("developer");
assert!(
second_request_developer_texts
.iter()
.any(|text| text.contains("The most recent image_generation_call completed"))
);
assert!(
second_request_developer_texts
.iter()
.any(|text| text.contains(&expected_saved_path.display().to_string()))
);
assert_eq!(std::fs::read(&expected_saved_path)?, b"foo");
let _ = std::fs::remove_file(&expected_saved_path);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn image_generation_follow_up_can_be_disabled() -> anyhow::Result<()> {
skip_if_no_network!(Ok(()));
let server = start_mock_server().await;
let mut builder = test_codex().with_config(|config| {
config
.features
.disable(Feature::ImageGenerationContinuation)
.expect("test config should allow disabling image generation continuation");
});
let TestCodex {
codex,
config,
session_configured,
..
} = builder.build(&server).await?;
let call_id = "ig_no_follow_up";
let expected_saved_path = image_generation_artifact_path(
config.codex_home.as_path(),
&session_configured.thread_id.to_string(),
call_id,
);
let _ = std::fs::remove_file(&expected_saved_path);
let responses = mount_sse_once(
&server,
sse(vec![
ev_response_created("resp-1"),
ev_image_generation_call(call_id, "completed", "A tiny blue square", "Zm9v"),
ev_completed("resp-1"),
]),
)
.await;
codex
.submit(Op::UserInput {
environments: None,
items: vec![UserInput::Text {
text: "generate a tiny blue square".into(),
text_elements: Vec::new(),
}],
final_output_json_schema: None,
responsesapi_client_metadata: None,
})
.await?;
wait_for_event(&codex, |ev| matches!(ev, EventMsg::TurnComplete(_))).await;
assert_eq!(responses.requests().len(), 1);
assert_eq!(std::fs::read(&expected_saved_path)?, b"foo");
let _ = std::fs::remove_file(&expected_saved_path);

View File

@@ -424,6 +424,10 @@ async fn model_change_from_image_to_text_strips_prior_image_content() -> Result<
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(move |config| {
config.model = Some(image_model_slug.to_string());
config
.features
.disable(Feature::ImageGenerationContinuation)
.expect("test config should allow disabling image generation continuation");
});
let test = builder.build(&server).await?;
let models_manager = test.thread_manager.get_models_manager();
@@ -536,6 +540,10 @@ async fn generated_image_is_replayed_for_image_capable_models() -> Result<()> {
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(move |config| {
config.model = Some(image_model_slug.to_string());
config
.features
.disable(Feature::ImageGenerationContinuation)
.expect("test config should allow disabling image generation continuation");
});
let test = builder.build(&server).await?;
let saved_path = image_generation_artifact_path(
@@ -650,6 +658,10 @@ async fn model_change_from_generated_image_to_text_preserves_prior_generated_ima
.with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing())
.with_config(move |config| {
config.model = Some(image_model_slug.to_string());
config
.features
.disable(Feature::ImageGenerationContinuation)
.expect("test config should allow disabling image generation continuation");
});
let test = builder.build(&server).await?;
let saved_path = image_generation_artifact_path(

View File

@@ -192,6 +192,9 @@ pub enum Feature {
ExternalMigration,
/// Allow the model to invoke the built-in image generation tool.
ImageGeneration,
/// Continue sampling after image generation so workflows can return text
/// and artifact metadata alongside the generated image.
ImageGenerationContinuation,
/// Allow prompting and installing missing MCP dependencies.
SkillMcpDependencyInstall,
/// Prompt for missing skill env var dependencies.
@@ -1022,6 +1025,12 @@ pub const FEATURES: &[FeatureSpec] = &[
stage: Stage::Stable,
default_enabled: true,
},
FeatureSpec {
id: Feature::ImageGenerationContinuation,
key: "image_generation_continuation",
stage: Stage::Stable,
default_enabled: true,
},
FeatureSpec {
id: Feature::SkillMcpDependencyInstall,
key: "skill_mcp_dependency_install",

View File

@@ -226,6 +226,12 @@ fn image_generation_is_stable_and_enabled_by_default() {
assert_eq!(Feature::ImageGeneration.default_enabled(), true);
}
#[test]
fn image_generation_continuation_is_stable_and_enabled_by_default() {
assert_eq!(Feature::ImageGenerationContinuation.stage(), Stage::Stable);
assert_eq!(Feature::ImageGenerationContinuation.default_enabled(), true);
}
#[test]
fn use_legacy_landlock_config_records_deprecation_notice() {
let mut entries = BTreeMap::new();