mirror of
https://github.com/openai/codex.git
synced 2026-04-25 15:15:15 +00:00
595 lines
19 KiB
Rust
595 lines
19 KiB
Rust
use anyhow::Result;
|
|
use app_test_support::ChatGptAuthFixture;
|
|
use app_test_support::McpProcess;
|
|
use app_test_support::create_fake_rollout;
|
|
use app_test_support::create_mock_responses_server_repeating_assistant;
|
|
use app_test_support::to_response;
|
|
use app_test_support::write_chatgpt_auth;
|
|
use codex_app_server_protocol::JSONRPCError;
|
|
use codex_app_server_protocol::JSONRPCMessage;
|
|
use codex_app_server_protocol::JSONRPCResponse;
|
|
use codex_app_server_protocol::RequestId;
|
|
use codex_app_server_protocol::SessionSource;
|
|
use codex_app_server_protocol::ThreadForkParams;
|
|
use codex_app_server_protocol::ThreadForkResponse;
|
|
use codex_app_server_protocol::ThreadItem;
|
|
use codex_app_server_protocol::ThreadListParams;
|
|
use codex_app_server_protocol::ThreadListResponse;
|
|
use codex_app_server_protocol::ThreadStartParams;
|
|
use codex_app_server_protocol::ThreadStartResponse;
|
|
use codex_app_server_protocol::ThreadStartedNotification;
|
|
use codex_app_server_protocol::ThreadStatus;
|
|
use codex_app_server_protocol::ThreadStatusChangedNotification;
|
|
use codex_app_server_protocol::TurnStartParams;
|
|
use codex_app_server_protocol::TurnStartResponse;
|
|
use codex_app_server_protocol::TurnStatus;
|
|
use codex_app_server_protocol::UserInput;
|
|
use codex_config::types::AuthCredentialsStoreMode;
|
|
use codex_login::REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR;
|
|
use pretty_assertions::assert_eq;
|
|
use serde_json::Value;
|
|
use serde_json::json;
|
|
use std::path::Path;
|
|
use tempfile::TempDir;
|
|
use tokio::time::timeout;
|
|
use wiremock::Mock;
|
|
use wiremock::MockServer;
|
|
use wiremock::ResponseTemplate;
|
|
use wiremock::matchers::method;
|
|
use wiremock::matchers::path;
|
|
|
|
use super::analytics::assert_basic_thread_initialized_event;
|
|
use super::analytics::enable_analytics_capture;
|
|
use super::analytics::thread_initialized_event;
|
|
use super::analytics::wait_for_analytics_payload;
|
|
|
|
/// Upper bound applied to each individual wait on the app-server process in this module
/// (initialize, request/response reads, notification loops, analytics capture).
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);
|
|
|
|
#[tokio::test]
|
|
async fn thread_fork_creates_new_thread_and_emits_started() -> Result<()> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let codex_home = TempDir::new()?;
|
|
create_config_toml(codex_home.path(), &server.uri())?;
|
|
|
|
let preview = "Saved user message";
|
|
let conversation_id = create_fake_rollout(
|
|
codex_home.path(),
|
|
"2025-01-05T12-00-00",
|
|
"2025-01-05T12:00:00Z",
|
|
preview,
|
|
Some("mock_provider"),
|
|
/*git_info*/ None,
|
|
)?;
|
|
|
|
let original_path = codex_home
|
|
.path()
|
|
.join("sessions")
|
|
.join("2025")
|
|
.join("01")
|
|
.join("05")
|
|
.join(format!(
|
|
"rollout-2025-01-05T12-00-00-{conversation_id}.jsonl"
|
|
));
|
|
assert!(
|
|
original_path.exists(),
|
|
"expected original rollout to exist at {}",
|
|
original_path.display()
|
|
);
|
|
let original_contents = std::fs::read_to_string(&original_path)?;
|
|
|
|
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
|
|
|
let fork_id = mcp
|
|
.send_thread_fork_request(ThreadForkParams {
|
|
thread_id: conversation_id.clone(),
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let fork_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(fork_id)),
|
|
)
|
|
.await??;
|
|
let fork_result = fork_resp.result.clone();
|
|
let ThreadForkResponse { thread, .. } = to_response::<ThreadForkResponse>(fork_resp)?;
|
|
|
|
// Wire contract: thread title field is `name`, serialized as null when unset.
|
|
let thread_json = fork_result
|
|
.get("thread")
|
|
.and_then(Value::as_object)
|
|
.expect("thread/fork result.thread must be an object");
|
|
assert_eq!(
|
|
thread_json.get("name"),
|
|
Some(&Value::Null),
|
|
"forked threads do not inherit a name; expected `name: null`"
|
|
);
|
|
|
|
let after_contents = std::fs::read_to_string(&original_path)?;
|
|
assert_eq!(
|
|
after_contents, original_contents,
|
|
"fork should not mutate the original rollout file"
|
|
);
|
|
|
|
assert_ne!(thread.id, conversation_id);
|
|
assert_eq!(thread.forked_from_id, Some(conversation_id.clone()));
|
|
assert_eq!(thread.preview, preview);
|
|
assert_eq!(thread.model_provider, "mock_provider");
|
|
assert_eq!(thread.status, ThreadStatus::Idle);
|
|
let thread_path = thread.path.clone().expect("thread path");
|
|
assert!(thread_path.is_absolute());
|
|
assert_ne!(thread_path, original_path);
|
|
assert!(thread.cwd.is_absolute());
|
|
assert_eq!(thread.source, SessionSource::VsCode);
|
|
assert_eq!(thread.name, None);
|
|
|
|
assert_eq!(
|
|
thread.turns.len(),
|
|
1,
|
|
"expected forked thread to include one turn"
|
|
);
|
|
let turn = &thread.turns[0];
|
|
assert_eq!(turn.status, TurnStatus::Interrupted);
|
|
assert_eq!(turn.items.len(), 1, "expected user message item");
|
|
match &turn.items[0] {
|
|
ThreadItem::UserMessage { content, .. } => {
|
|
assert_eq!(
|
|
content,
|
|
&vec![UserInput::Text {
|
|
text: preview.to_string(),
|
|
text_elements: Vec::new(),
|
|
}]
|
|
);
|
|
}
|
|
other => panic!("expected user message item, got {other:?}"),
|
|
}
|
|
|
|
// A corresponding thread/started notification should arrive.
|
|
let deadline = tokio::time::Instant::now() + DEFAULT_READ_TIMEOUT;
|
|
let notif = loop {
|
|
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
|
|
let message = timeout(remaining, mcp.read_next_message()).await??;
|
|
let JSONRPCMessage::Notification(notif) = message else {
|
|
continue;
|
|
};
|
|
if notif.method == "thread/status/changed" {
|
|
let status_changed: ThreadStatusChangedNotification =
|
|
serde_json::from_value(notif.params.expect("params must be present"))?;
|
|
if status_changed.thread_id == thread.id {
|
|
anyhow::bail!(
|
|
"thread/fork should introduce the thread without a preceding thread/status/changed"
|
|
);
|
|
}
|
|
continue;
|
|
}
|
|
if notif.method == "thread/started" {
|
|
break notif;
|
|
}
|
|
};
|
|
let started_params = notif.params.clone().expect("params must be present");
|
|
let started_thread_json = started_params
|
|
.get("thread")
|
|
.and_then(Value::as_object)
|
|
.expect("thread/started params.thread must be an object");
|
|
assert_eq!(
|
|
started_thread_json.get("name"),
|
|
Some(&Value::Null),
|
|
"thread/started must serialize `name: null` when unset"
|
|
);
|
|
let started: ThreadStartedNotification =
|
|
serde_json::from_value(notif.params.expect("params must be present"))?;
|
|
assert_eq!(started.thread, thread);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn thread_fork_tracks_thread_initialized_analytics() -> Result<()> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
|
|
let codex_home = TempDir::new()?;
|
|
create_config_toml_with_chatgpt_base_url(
|
|
codex_home.path(),
|
|
&server.uri(),
|
|
&server.uri(),
|
|
/*general_analytics_enabled*/ true,
|
|
)?;
|
|
enable_analytics_capture(&server, codex_home.path()).await?;
|
|
|
|
let conversation_id = create_fake_rollout(
|
|
codex_home.path(),
|
|
"2025-01-05T12-00-00",
|
|
"2025-01-05T12:00:00Z",
|
|
"Saved user message",
|
|
Some("mock_provider"),
|
|
/*git_info*/ None,
|
|
)?;
|
|
|
|
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
|
|
|
let fork_id = mcp
|
|
.send_thread_fork_request(ThreadForkParams {
|
|
thread_id: conversation_id,
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let fork_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(fork_id)),
|
|
)
|
|
.await??;
|
|
let ThreadForkResponse { thread, .. } = to_response::<ThreadForkResponse>(fork_resp)?;
|
|
|
|
let payload = wait_for_analytics_payload(&server, DEFAULT_READ_TIMEOUT).await?;
|
|
let event = thread_initialized_event(&payload)?;
|
|
assert_basic_thread_initialized_event(event, &thread.id, "mock-model", "forked");
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn thread_fork_rejects_unmaterialized_thread() -> Result<()> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let codex_home = TempDir::new()?;
|
|
create_config_toml(codex_home.path(), &server.uri())?;
|
|
|
|
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
|
|
|
let start_id = mcp
|
|
.send_thread_start_request(ThreadStartParams {
|
|
model: Some("mock-model".to_string()),
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let start_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
|
|
)
|
|
.await??;
|
|
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
|
|
|
|
let fork_id = mcp
|
|
.send_thread_fork_request(ThreadForkParams {
|
|
thread_id: thread.id,
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let fork_err: JSONRPCError = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_error_message(RequestId::Integer(fork_id)),
|
|
)
|
|
.await??;
|
|
assert!(
|
|
fork_err
|
|
.error
|
|
.message
|
|
.contains("no rollout found for thread id"),
|
|
"unexpected fork error: {}",
|
|
fork_err.error.message
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn thread_fork_surfaces_cloud_requirements_load_errors() -> Result<()> {
|
|
let server = MockServer::start().await;
|
|
Mock::given(method("GET"))
|
|
.and(path("/backend-api/wham/config/requirements"))
|
|
.respond_with(
|
|
ResponseTemplate::new(401)
|
|
.insert_header("content-type", "text/html")
|
|
.set_body_string("<html>nope</html>"),
|
|
)
|
|
.mount(&server)
|
|
.await;
|
|
Mock::given(method("POST"))
|
|
.and(path("/oauth/token"))
|
|
.respond_with(ResponseTemplate::new(401).set_body_json(json!({
|
|
"error": { "code": "refresh_token_invalidated" }
|
|
})))
|
|
.mount(&server)
|
|
.await;
|
|
|
|
let codex_home = TempDir::new()?;
|
|
let model_server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let chatgpt_base_url = format!("{}/backend-api", server.uri());
|
|
create_config_toml_with_chatgpt_base_url(
|
|
codex_home.path(),
|
|
&model_server.uri(),
|
|
&chatgpt_base_url,
|
|
/*general_analytics_enabled*/ false,
|
|
)?;
|
|
write_chatgpt_auth(
|
|
codex_home.path(),
|
|
ChatGptAuthFixture::new("chatgpt-token")
|
|
.refresh_token("stale-refresh-token")
|
|
.plan_type("business")
|
|
.chatgpt_user_id("user-123")
|
|
.chatgpt_account_id("account-123")
|
|
.account_id("account-123"),
|
|
AuthCredentialsStoreMode::File,
|
|
)?;
|
|
|
|
let conversation_id = create_fake_rollout(
|
|
codex_home.path(),
|
|
"2025-01-05T12-00-00",
|
|
"2025-01-05T12:00:00Z",
|
|
"Saved user message",
|
|
Some("mock_provider"),
|
|
/*git_info*/ None,
|
|
)?;
|
|
|
|
let refresh_token_url = format!("{}/oauth/token", server.uri());
|
|
let mut mcp = McpProcess::new_with_env(
|
|
codex_home.path(),
|
|
&[
|
|
("OPENAI_API_KEY", None),
|
|
(
|
|
REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR,
|
|
Some(refresh_token_url.as_str()),
|
|
),
|
|
],
|
|
)
|
|
.await?;
|
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
|
|
|
let fork_id = mcp
|
|
.send_thread_fork_request(ThreadForkParams {
|
|
thread_id: conversation_id,
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let fork_err: JSONRPCError = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_error_message(RequestId::Integer(fork_id)),
|
|
)
|
|
.await??;
|
|
|
|
assert!(
|
|
fork_err
|
|
.error
|
|
.message
|
|
.contains("failed to load configuration"),
|
|
"unexpected fork error: {}",
|
|
fork_err.error.message
|
|
);
|
|
assert_eq!(
|
|
fork_err.error.data,
|
|
Some(json!({
|
|
"reason": "cloudRequirements",
|
|
"errorCode": "Auth",
|
|
"action": "relogin",
|
|
"statusCode": 401,
|
|
"detail": "Your access token could not be refreshed because your refresh token was revoked. Please log out and sign in again.",
|
|
}))
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn thread_fork_ephemeral_remains_pathless_and_omits_listing() -> Result<()> {
|
|
let server = create_mock_responses_server_repeating_assistant("Done").await;
|
|
let codex_home = TempDir::new()?;
|
|
create_config_toml(codex_home.path(), &server.uri())?;
|
|
|
|
let preview = "Saved user message";
|
|
let conversation_id = create_fake_rollout(
|
|
codex_home.path(),
|
|
"2025-01-05T12-00-00",
|
|
"2025-01-05T12:00:00Z",
|
|
preview,
|
|
Some("mock_provider"),
|
|
/*git_info*/ None,
|
|
)?;
|
|
|
|
let mut mcp = McpProcess::new(codex_home.path()).await?;
|
|
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
|
|
|
let fork_id = mcp
|
|
.send_thread_fork_request(ThreadForkParams {
|
|
thread_id: conversation_id.clone(),
|
|
ephemeral: true,
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let fork_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(fork_id)),
|
|
)
|
|
.await??;
|
|
let fork_result = fork_resp.result.clone();
|
|
let ThreadForkResponse { thread, .. } = to_response::<ThreadForkResponse>(fork_resp)?;
|
|
let fork_thread_id = thread.id.clone();
|
|
|
|
assert!(
|
|
thread.ephemeral,
|
|
"ephemeral forks should be marked explicitly"
|
|
);
|
|
assert_eq!(
|
|
thread.path, None,
|
|
"ephemeral forks should not expose a path"
|
|
);
|
|
assert_eq!(thread.preview, preview);
|
|
assert_eq!(thread.status, ThreadStatus::Idle);
|
|
assert_eq!(thread.name, None);
|
|
assert_eq!(thread.turns.len(), 1, "expected copied fork history");
|
|
|
|
let turn = &thread.turns[0];
|
|
assert_eq!(turn.status, TurnStatus::Completed);
|
|
assert_eq!(turn.items.len(), 1, "expected user message item");
|
|
match &turn.items[0] {
|
|
ThreadItem::UserMessage { content, .. } => {
|
|
assert_eq!(
|
|
content,
|
|
&vec![UserInput::Text {
|
|
text: preview.to_string(),
|
|
text_elements: Vec::new(),
|
|
}]
|
|
);
|
|
}
|
|
other => panic!("expected user message item, got {other:?}"),
|
|
}
|
|
|
|
let thread_json = fork_result
|
|
.get("thread")
|
|
.and_then(Value::as_object)
|
|
.expect("thread/fork result.thread must be an object");
|
|
assert_eq!(
|
|
thread_json.get("ephemeral").and_then(Value::as_bool),
|
|
Some(true),
|
|
"ephemeral forks should serialize `ephemeral: true`"
|
|
);
|
|
|
|
let deadline = tokio::time::Instant::now() + DEFAULT_READ_TIMEOUT;
|
|
let notif = loop {
|
|
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
|
|
let message = timeout(remaining, mcp.read_next_message()).await??;
|
|
let JSONRPCMessage::Notification(notif) = message else {
|
|
continue;
|
|
};
|
|
if notif.method == "thread/status/changed" {
|
|
let status_changed: ThreadStatusChangedNotification =
|
|
serde_json::from_value(notif.params.expect("params must be present"))?;
|
|
if status_changed.thread_id == fork_thread_id {
|
|
anyhow::bail!(
|
|
"thread/fork should introduce the thread without a preceding thread/status/changed"
|
|
);
|
|
}
|
|
continue;
|
|
}
|
|
if notif.method == "thread/started" {
|
|
break notif;
|
|
}
|
|
};
|
|
let started_params = notif.params.clone().expect("params must be present");
|
|
let started_thread_json = started_params
|
|
.get("thread")
|
|
.and_then(Value::as_object)
|
|
.expect("thread/started params.thread must be an object");
|
|
assert_eq!(
|
|
started_thread_json
|
|
.get("ephemeral")
|
|
.and_then(Value::as_bool),
|
|
Some(true),
|
|
"thread/started should serialize `ephemeral: true` for ephemeral forks"
|
|
);
|
|
let started: ThreadStartedNotification =
|
|
serde_json::from_value(notif.params.expect("params must be present"))?;
|
|
assert_eq!(started.thread, thread);
|
|
|
|
let list_id = mcp
|
|
.send_thread_list_request(ThreadListParams {
|
|
cursor: None,
|
|
limit: Some(10),
|
|
sort_key: None,
|
|
model_providers: None,
|
|
source_kinds: None,
|
|
archived: None,
|
|
cwd: None,
|
|
search_term: None,
|
|
})
|
|
.await?;
|
|
let list_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(list_id)),
|
|
)
|
|
.await??;
|
|
let ThreadListResponse { data, .. } = to_response::<ThreadListResponse>(list_resp)?;
|
|
assert!(
|
|
data.iter().all(|candidate| candidate.id != fork_thread_id),
|
|
"ephemeral forks should not appear in thread/list"
|
|
);
|
|
assert!(
|
|
data.iter().any(|candidate| candidate.id == conversation_id),
|
|
"persistent source thread should remain listed"
|
|
);
|
|
|
|
let turn_id = mcp
|
|
.send_turn_start_request(TurnStartParams {
|
|
thread_id: fork_thread_id,
|
|
input: vec![UserInput::Text {
|
|
text: "continue".to_string(),
|
|
text_elements: Vec::new(),
|
|
}],
|
|
..Default::default()
|
|
})
|
|
.await?;
|
|
let turn_resp: JSONRPCResponse = timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_response_message(RequestId::Integer(turn_id)),
|
|
)
|
|
.await??;
|
|
let _: TurnStartResponse = to_response::<TurnStartResponse>(turn_resp)?;
|
|
timeout(
|
|
DEFAULT_READ_TIMEOUT,
|
|
mcp.read_stream_until_notification_message("turn/completed"),
|
|
)
|
|
.await??;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Writes `<codex_home>/config.toml` so that `mock_provider` is the active model provider.
///
/// The provider's Responses API base URL is `{server_uri}/v1`. Approvals are set to
/// `never`, the sandbox to `read-only`, and request/stream retries are disabled (`0`).
fn create_config_toml(codex_home: &Path, server_uri: &str) -> std::io::Result<()> {
    let contents = format!(
        r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"

model_provider = "mock_provider"

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
    );
    std::fs::write(codex_home.join("config.toml"), contents)
}
|
|
|
|
/// Writes `<codex_home>/config.toml` for the `mock_provider` model provider (base URL
/// `{server_uri}/v1`) and points `chatgpt_base_url` at the given URL.
///
/// A `[features]` table is always emitted. It contains `general_analytics = true` only
/// when `general_analytics_enabled` is set; otherwise it is empty.
fn create_config_toml_with_chatgpt_base_url(
    codex_home: &Path,
    server_uri: &str,
    chatgpt_base_url: &str,
    general_analytics_enabled: bool,
) -> std::io::Result<()> {
    // The leading newline leaves a blank line between `[features]` and the flag.
    let features_body: &str = if general_analytics_enabled {
        "\ngeneral_analytics = true"
    } else {
        ""
    };
    let contents = format!(
        r#"
model = "mock-model"
approval_policy = "never"
sandbox_mode = "read-only"
chatgpt_base_url = "{chatgpt_base_url}"

model_provider = "mock_provider"

[features]
{features_body}

[model_providers.mock_provider]
name = "Mock provider for test"
base_url = "{server_uri}/v1"
wire_api = "responses"
request_max_retries = 0
stream_max_retries = 0
"#
    );
    std::fs::write(codex_home.join("config.toml"), contents)
}
|