This commit is contained in:
Owen Lin
2025-11-02 13:15:53 -08:00
parent afd59a99af
commit d8075b923c
6 changed files with 424 additions and 13 deletions

View File

@@ -1,8 +1,10 @@
use std::collections::HashMap;
use std::path::PathBuf;
use codex_protocol::ConversationId;
use codex_protocol::account::PlanType;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use mcp_types::ContentBlock as McpContentBlock;
use schemars::JsonSchema;
use serde::Deserialize;
@@ -48,6 +50,69 @@ v2_enum_from_core!(
}
);
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, JsonSchema, TS)]
#[serde(tag = "mode", rename_all = "camelCase")]
#[ts(tag = "mode")]
#[ts(export_to = "v2/")]
pub enum SandboxPolicy {
DangerFullAccess,
ReadOnly,
WorkspaceWrite {
#[serde(default, skip_serializing_if = "Vec::is_empty")]
writable_roots: Vec<PathBuf>,
#[serde(default)]
network_access: bool,
#[serde(default)]
exclude_tmpdir_env_var: bool,
#[serde(default)]
exclude_slash_tmp: bool,
},
}
impl SandboxPolicy {
pub fn to_core(&self) -> codex_protocol::protocol::SandboxPolicy {
match self {
SandboxPolicy::DangerFullAccess => {
codex_protocol::protocol::SandboxPolicy::DangerFullAccess
}
SandboxPolicy::ReadOnly => codex_protocol::protocol::SandboxPolicy::ReadOnly,
SandboxPolicy::WorkspaceWrite {
writable_roots,
network_access,
exclude_tmpdir_env_var,
exclude_slash_tmp,
} => codex_protocol::protocol::SandboxPolicy::WorkspaceWrite {
writable_roots: writable_roots.clone(),
network_access: *network_access,
exclude_tmpdir_env_var: *exclude_tmpdir_env_var,
exclude_slash_tmp: *exclude_slash_tmp,
},
}
}
}
impl From<codex_protocol::protocol::SandboxPolicy> for SandboxPolicy {
fn from(value: codex_protocol::protocol::SandboxPolicy) -> Self {
match value {
codex_protocol::protocol::SandboxPolicy::DangerFullAccess => {
SandboxPolicy::DangerFullAccess
}
codex_protocol::protocol::SandboxPolicy::ReadOnly => SandboxPolicy::ReadOnly,
codex_protocol::protocol::SandboxPolicy::WorkspaceWrite {
writable_roots,
network_access,
exclude_tmpdir_env_var,
exclude_slash_tmp,
} => SandboxPolicy::WorkspaceWrite {
writable_roots,
network_access,
exclude_tmpdir_env_var,
exclude_slash_tmp,
},
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
@@ -284,6 +349,7 @@ pub struct Thread {
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct Turn {
pub id: String,
pub items: Vec<ThreadItem>,
pub status: TurnStatus,
pub error: Option<TurnError>,
@@ -314,16 +380,25 @@ pub enum TurnStatus {
pub struct TurnStartParams {
pub thread_id: String,
pub input: Vec<UserInput>,
pub model: String,
pub effort: ReasoningEffort,
pub summary: String,
// Override the working directory for this turn and subsequent turns.
pub cwd: Option<PathBuf>,
// Override the approval policy for this turn and subsequent turns.
pub approval_policy: Option<AskForApproval>,
// Override the sandbox policy for this turn and subsequent turns.
pub sandbox_policy: Option<SandboxPolicy>,
// Override the model for this turn and subsequent turns.
pub model: Option<String>,
// Override the reasoning effort for this turn and subsequent turns.
pub effort: Option<ReasoningEffort>,
// Override the reasoning summary for this turn and subsequent turns.
pub summary: Option<ReasoningSummary>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
#[ts(export_to = "v2/")]
pub struct TurnStartResponse {
pub turn_id: String,
pub turn: Turn,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
@@ -345,6 +420,7 @@ pub struct TurnInterruptResponse {}
pub enum UserInput {
Text(String),
Image(ImageInput),
LocalImage { path: PathBuf },
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]

View File

@@ -29,6 +29,7 @@ use codex_app_server_protocol::GetConversationSummaryResponse;
use codex_app_server_protocol::GetUserAgentResponse;
use codex_app_server_protocol::GetUserSavedConfigResponse;
use codex_app_server_protocol::GitDiffToRemoteResponse;
use codex_app_server_protocol::ImageInput as V2ImageInput;
use codex_app_server_protocol::InputItem as WireInputItem;
use codex_app_server_protocol::InterruptConversationParams;
use codex_app_server_protocol::InterruptConversationResponse;
@@ -71,6 +72,7 @@ use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::Turn;
use codex_app_server_protocol::UserInfoResponse;
use codex_app_server_protocol::UserInput as V2UserInput;
use codex_app_server_protocol::UserSavedConfig;
use codex_backend_client::Client as BackendClient;
use codex_core::AuthManager;
@@ -209,12 +211,8 @@ impl CodexMessageProcessor {
self.send_unimplemented_error(request_id, "thread/compact")
.await;
}
ClientRequest::TurnStart {
request_id,
params: _,
} => {
self.send_unimplemented_error(request_id, "turn/start")
.await;
ClientRequest::TurnStart { request_id, params } => {
self.turn_start(request_id, params).await;
}
ClientRequest::TurnInterrupt {
request_id,
@@ -942,6 +940,117 @@ impl CodexMessageProcessor {
}
}
async fn turn_start(
&self,
request_id: RequestId,
params: codex_app_server_protocol::TurnStartParams,
) {
// Resolve conversation id from v2 thread id string.
let conversation_id = match ConversationId::from_string(&params.thread_id) {
Ok(id) => id,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("invalid thread id: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
let Ok(conversation) = self
.conversation_manager
.get_conversation(conversation_id)
.await
else {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("conversation not found: {conversation_id}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
};
// Keep a copy of v2 inputs for the notification payload.
let v2_inputs_for_notif = params.input.clone();
// Map v2 input items to core input items.
let mapped_items: Vec<CoreInputItem> = params
.input
.into_iter()
.map(|item| match item {
V2UserInput::Text(text) => CoreInputItem::Text { text },
V2UserInput::Image(V2ImageInput::Image { url }) => {
CoreInputItem::Image { image_url: url }
}
V2UserInput::LocalImage { path } => CoreInputItem::LocalImage { path },
})
.collect();
let has_any_overrides = params.cwd.is_some()
|| params.approval_policy.is_some()
|| params.sandbox_policy.is_some()
|| params.model.is_some()
|| params.effort.is_some()
|| params.summary.is_some();
// If any overrides are provided, update the session turn context first.
if has_any_overrides {
let _ = conversation
.submit(Op::OverrideTurnContext {
cwd: params.cwd,
approval_policy: params
.approval_policy
.map(codex_app_server_protocol::AskForApproval::to_core),
sandbox_policy: params.sandbox_policy.map(|p| p.to_core()),
model: params.model,
effort: params.effort.map(Some),
summary: params.summary,
})
.await;
}
// Start the turn by submitting the user input. Return its submission id as turn_id.
let turn_id = conversation
.submit(Op::UserInput {
items: mapped_items,
})
.await;
match turn_id {
Ok(turn_id) => {
let turn = codex_app_server_protocol::Turn {
id: turn_id.clone(),
items: vec![codex_app_server_protocol::ThreadItem::UserMessage {
id: turn_id,
content: v2_inputs_for_notif,
}],
status: codex_app_server_protocol::TurnStatus::InProgress,
error: None,
};
let response = codex_app_server_protocol::TurnStartResponse { turn: turn.clone() };
self.outgoing.send_response(request_id, response).await;
// Emit v2 turn/started notification.
let notif = codex_app_server_protocol::TurnStartedNotification { turn };
self.outgoing
.send_server_notification(ServerNotification::TurnStarted(notif))
.await;
}
Err(err) => {
let error = JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("failed to start turn: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
}
}
}
async fn thread_archive(&self, request_id: RequestId, params: ThreadArchiveParams) {
// Resolve conversation id from v2 thread id string.
let conversation_id = match ConversationId::from_string(&params.thread.id) {

View File

@@ -35,6 +35,7 @@ use codex_app_server_protocol::ThreadArchiveParams;
use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadResumeParams;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::TurnStartParams as V2TurnStartParams;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCMessage;
@@ -347,6 +348,15 @@ impl McpProcess {
self.send_request("thread/resume", params).await
}
/// Send a `turn/start` JSON-RPC request (v2).
pub async fn send_turn_start_request(
&mut self,
params: V2TurnStartParams,
) -> anyhow::Result<i64> {
let params = Some(serde_json::to_value(params)?);
self.send_request("turn/start", params).await
}
/// Send a `cancelLoginChatGpt` JSON-RPC request.
pub async fn send_cancel_login_chat_gpt_request(
&mut self,

View File

@@ -1,3 +1,5 @@
mod thread_archive;
mod thread_list;
mod thread_resume;
mod thread_start;
mod turn_start;

View File

@@ -52,11 +52,17 @@ async fn thread_archive_moves_rollout_into_archived_directory() -> Result<()> {
.path()
.join(SESSIONS_SUBDIR)
.join(format!("{}.jsonl", thread.id));
assert!(rollout_path.exists(), "expected {} to exist", rollout_path.display());
assert!(
rollout_path.exists(),
"expected {} to exist",
rollout_path.display()
);
// Archive the thread.
let archive_id = mcp
.send_thread_archive_request(ThreadArchiveParams { thread: thread.clone() })
.send_thread_archive_request(ThreadArchiveParams {
thread: thread.clone(),
})
.await?;
let archive_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
@@ -93,4 +99,3 @@ approval_policy = "never"
sandbox_mode = "read-only"
"#
}

View File

@@ -0,0 +1,209 @@
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_final_assistant_message_sse_response;
use app_test_support::create_mock_chat_completions_server;
use app_test_support::to_response;
use codex_app_server_protocol::AddConversationListenerParams;
use codex_app_server_protocol::AddConversationSubscriptionResponse;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::TurnStartedNotification;
use codex_app_server_protocol::UserInput as V2UserInput;
use pretty_assertions::assert_eq;
use std::path::Path;
use tempfile::TempDir;
use tokio::time::timeout;
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn turn_start_emits_notification_and_completes_without_overrides() -> Result<()> {
// Provide a mock server and config so model wiring is valid.
// Two Codex turns hit the mock model (session start + turn/start).
let responses = vec![
create_final_assistant_message_sse_response("Done")?,
create_final_assistant_message_sse_response("Done")?,
];
let server = create_mock_chat_completions_server(responses).await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
// Start a thread (v2) and capture its id.
let thread_req = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("mock-model".to_string()),
model_provider: None,
profile: None,
cwd: None,
approval_policy: None,
sandbox: None,
config: None,
base_instructions: None,
developer_instructions: None,
compact_prompt: None,
include_apply_patch_tool: None,
})
.await?;
let thread_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
)
.await??;
let ThreadStartResponse { thread } = to_response::<ThreadStartResponse>(thread_resp)?;
// Add legacy listener to get task_complete signal.
let conversation_id = codex_protocol::ConversationId::from_string(&thread.id)?;
let add_listener_id = mcp
.send_add_conversation_listener_request(AddConversationListenerParams {
conversation_id,
experimental_raw_events: false,
})
.await?;
let add_listener_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(add_listener_id)),
)
.await??;
let AddConversationSubscriptionResponse { .. } = to_response::<_>(add_listener_resp)?;
// Start a turn with only input and thread_id set.
let turn_req = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread.id.clone(),
input: vec![V2UserInput::Text("Hello".to_string())],
cwd: None,
approval_policy: None,
sandbox_policy: None,
model: None,
effort: None,
summary: None,
})
.await?;
let turn_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
)
.await??;
let TurnStartResponse { turn } = to_response::<TurnStartResponse>(turn_resp)?;
assert!(!turn.id.is_empty());
// Expect a turn/started notification.
let notif: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/started"),
)
.await??;
let started: TurnStartedNotification =
serde_json::from_value(notif.params.expect("params must be present"))?;
assert_eq!(
started.turn.status,
codex_app_server_protocol::TurnStatus::InProgress
);
// And we should see task_complete for the legacy stream as the turn finishes.
let _task_done: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("codex/event/task_complete"),
)
.await??;
drop(server);
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn turn_start_accepts_local_image_input() -> Result<()> {
// Two Codex turns hit the mock model (session start + turn/start).
let responses = vec![
create_final_assistant_message_sse_response("Done")?,
create_final_assistant_message_sse_response("Done")?,
];
let server = create_mock_chat_completions_server(responses).await;
let codex_home = TempDir::new()?;
create_config_toml(codex_home.path(), &server.uri())?;
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
let thread_req = mcp
.send_thread_start_request(ThreadStartParams {
model: Some("mock-model".to_string()),
model_provider: None,
profile: None,
cwd: None,
approval_policy: None,
sandbox: None,
config: None,
base_instructions: None,
developer_instructions: None,
compact_prompt: None,
include_apply_patch_tool: None,
})
.await?;
let thread_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(thread_req)),
)
.await??;
let ThreadStartResponse { thread } = to_response::<ThreadStartResponse>(thread_resp)?;
let image_path = codex_home.path().join("image.png");
// No need to actually write the file; we just exercise the input path.
let turn_req = mcp
.send_turn_start_request(TurnStartParams {
thread_id: thread.id.clone(),
input: vec![V2UserInput::LocalImage { path: image_path }],
cwd: None,
approval_policy: None,
sandbox_policy: None,
model: None,
effort: None,
summary: None,
})
.await?;
let turn_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(turn_req)),
)
.await??;
let TurnStartResponse { turn } = to_response::<TurnStartResponse>(turn_resp)?;
assert!(!turn.id.is_empty());
drop(server);
Ok(())
}
// Helper to create a config.toml pointing at the mock model server.
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
let config_toml = codex_home.join("config.toml");
std::fs::write(
config_toml,
format!(
r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "danger-full-access"
model_provider = "mock_provider"
[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "chat"
request_max_retries = 0
stream_max_retries = 0
"#
),
)
}