mirror of
https://github.com/openai/codex.git
synced 2026-02-01 22:47:52 +00:00
Compare commits
2 Commits
owen/updat
...
owen/threa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1ea209077 | ||
|
|
593fb755c4 |
@@ -43,8 +43,14 @@ use codex_app_server_protocol::SendUserMessageParams;
|
||||
use codex_app_server_protocol::SendUserMessageResponse;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadResumeParams;
|
||||
use codex_app_server_protocol::ThreadResumeResponse;
|
||||
use codex_app_server_protocol::ThreadRollbackParams;
|
||||
use codex_app_server_protocol::ThreadRollbackResponse;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::ThreadStartResponse;
|
||||
use codex_app_server_protocol::Turn;
|
||||
use codex_app_server_protocol::TurnStartParams;
|
||||
use codex_app_server_protocol::TurnStartResponse;
|
||||
use codex_app_server_protocol::TurnStatus;
|
||||
@@ -113,6 +119,9 @@ enum CliCommand {
|
||||
TestLogin,
|
||||
/// Fetch the current account rate limits from the Codex app-server.
|
||||
GetAccountRateLimits,
|
||||
/// Send multiple turns, roll back the most recent turn, and verify the thread history changed.
|
||||
#[command(name = "thread-rollback")]
|
||||
ThreadRollback,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@@ -134,6 +143,7 @@ fn main() -> Result<()> {
|
||||
} => send_follow_up_v2(codex_bin, first_message, follow_up_message),
|
||||
CliCommand::TestLogin => test_login(codex_bin),
|
||||
CliCommand::GetAccountRateLimits => get_account_rate_limits(codex_bin),
|
||||
CliCommand::ThreadRollback => thread_rollback(codex_bin),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -213,10 +223,7 @@ fn send_message_v2_with_policies(
|
||||
turn_params.approval_policy = approval_policy;
|
||||
turn_params.sandbox_policy = sandbox_policy;
|
||||
|
||||
let turn_response = client.turn_start(turn_params)?;
|
||||
println!("< turn/start response: {turn_response:?}");
|
||||
|
||||
client.stream_turn(&thread_response.thread.id, &turn_response.turn.id)?;
|
||||
let _ = client.run_turn(turn_params)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -234,27 +241,8 @@ fn send_follow_up_v2(
|
||||
let thread_response = client.thread_start(ThreadStartParams::default())?;
|
||||
println!("< thread/start response: {thread_response:?}");
|
||||
|
||||
let first_turn_params = TurnStartParams {
|
||||
thread_id: thread_response.thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: first_message,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let first_turn_response = client.turn_start(first_turn_params)?;
|
||||
println!("< turn/start response (initial): {first_turn_response:?}");
|
||||
client.stream_turn(&thread_response.thread.id, &first_turn_response.turn.id)?;
|
||||
|
||||
let follow_up_params = TurnStartParams {
|
||||
thread_id: thread_response.thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: follow_up_message,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
let follow_up_response = client.turn_start(follow_up_params)?;
|
||||
println!("< turn/start response (follow-up): {follow_up_response:?}");
|
||||
client.stream_turn(&thread_response.thread.id, &follow_up_response.turn.id)?;
|
||||
let _ = client.run_turn_text(&thread_response.thread.id, first_message)?;
|
||||
let _ = client.run_turn_text(&thread_response.thread.id, follow_up_message)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -301,6 +289,143 @@ fn get_account_rate_limits(codex_bin: String) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn thread_rollback(codex_bin: String) -> Result<()> {
|
||||
let codex_bin_resume = codex_bin.clone();
|
||||
let mut client = CodexClient::spawn(codex_bin)?;
|
||||
|
||||
let initialize = client.initialize()?;
|
||||
println!("< initialize response: {initialize:?}");
|
||||
|
||||
let thread_response = client.thread_start(ThreadStartParams::default())?;
|
||||
println!("< thread/start response: {thread_response:?}");
|
||||
let thread_id = thread_response.thread.id;
|
||||
|
||||
let _ = client.run_turn_text(&thread_id, "Say pineapple")?;
|
||||
let _ = client.run_turn_text(&thread_id, "Say banana")?;
|
||||
|
||||
let rollback_response = client.thread_rollback(ThreadRollbackParams {
|
||||
thread_id: thread_id.clone(),
|
||||
num_turns: 1,
|
||||
})?;
|
||||
println!("< thread/rollback response: {rollback_response:?}");
|
||||
|
||||
let answer = client
|
||||
.run_turn_text(&thread_id, "What was the last word you said?")?
|
||||
.context("turn completed without an agent message item")?;
|
||||
|
||||
if answer.to_lowercase().contains("pineapple") {
|
||||
println!("Rollback success!");
|
||||
} else {
|
||||
println!("Rollback did not work as expected!");
|
||||
}
|
||||
|
||||
let mut resume_client = CodexClient::spawn(codex_bin_resume)?;
|
||||
let initialize = resume_client.initialize()?;
|
||||
println!("< initialize response (resume client): {initialize:?}");
|
||||
|
||||
let resume_response = resume_client.thread_resume(ThreadResumeParams {
|
||||
thread_id: thread_id.clone(),
|
||||
..Default::default()
|
||||
})?;
|
||||
println!("< thread/resume response: {resume_response:?}");
|
||||
|
||||
verify_resumed_thread_after_rollback(
|
||||
&resume_response,
|
||||
"Say pineapple",
|
||||
"Say banana",
|
||||
"What was the last word you said?",
|
||||
"pineapple",
|
||||
)?;
|
||||
println!("Resume verification success!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_resumed_thread_after_rollback(
|
||||
resume: &ThreadResumeResponse,
|
||||
expected_first_prompt: &str,
|
||||
rolled_back_prompt: &str,
|
||||
expected_follow_up_prompt: &str,
|
||||
expected_word: &str,
|
||||
) -> Result<()> {
|
||||
let mut saw_expected_first_turn = false;
|
||||
let mut saw_expected_follow_up_turn = false;
|
||||
|
||||
for turn in &resume.thread.turns {
|
||||
let user_messages = turn_user_messages(turn);
|
||||
let agent_messages = turn_agent_messages(turn);
|
||||
|
||||
for user_message in &user_messages {
|
||||
if user_message.contains(rolled_back_prompt) {
|
||||
bail!(
|
||||
"thread/resume returned a rolled back prompt: {rolled_back_prompt:?} (thread {})",
|
||||
resume.thread.id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if user_messages
|
||||
.iter()
|
||||
.any(|message| message.contains(expected_first_prompt))
|
||||
&& agent_messages
|
||||
.iter()
|
||||
.any(|message| message.to_lowercase().contains(expected_word))
|
||||
{
|
||||
saw_expected_first_turn = true;
|
||||
}
|
||||
|
||||
if user_messages
|
||||
.iter()
|
||||
.any(|message| message.contains(expected_follow_up_prompt))
|
||||
&& agent_messages
|
||||
.iter()
|
||||
.any(|message| message.to_lowercase().contains(expected_word))
|
||||
{
|
||||
saw_expected_follow_up_turn = true;
|
||||
}
|
||||
}
|
||||
|
||||
if !saw_expected_first_turn {
|
||||
bail!(
|
||||
"thread/resume did not include expected prompt {expected_first_prompt:?} with answer containing {expected_word:?}"
|
||||
);
|
||||
}
|
||||
|
||||
if !saw_expected_follow_up_turn {
|
||||
bail!(
|
||||
"thread/resume did not include expected prompt {expected_follow_up_prompt:?} with answer containing {expected_word:?}"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn turn_user_messages(turn: &Turn) -> Vec<String> {
|
||||
turn.items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
ThreadItem::UserMessage { content, .. } => Some(content),
|
||||
_ => None,
|
||||
})
|
||||
.flat_map(|content| {
|
||||
content.iter().filter_map(|input| match input {
|
||||
V2UserInput::Text { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn turn_agent_messages(turn: &Turn) -> Vec<String> {
|
||||
turn.items
|
||||
.iter()
|
||||
.filter_map(|item| match item {
|
||||
ThreadItem::AgentMessage { text, .. } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
struct CodexClient {
|
||||
child: Child,
|
||||
stdin: Option<ChildStdin>,
|
||||
@@ -422,6 +547,16 @@ impl CodexClient {
|
||||
self.send_request(request, request_id, "thread/start")
|
||||
}
|
||||
|
||||
fn thread_resume(&mut self, params: ThreadResumeParams) -> Result<ThreadResumeResponse> {
|
||||
let request_id = self.request_id();
|
||||
let request = ClientRequest::ThreadResume {
|
||||
request_id: request_id.clone(),
|
||||
params,
|
||||
};
|
||||
|
||||
self.send_request(request, request_id, "thread/resume")
|
||||
}
|
||||
|
||||
fn turn_start(&mut self, params: TurnStartParams) -> Result<TurnStartResponse> {
|
||||
let request_id = self.request_id();
|
||||
let request = ClientRequest::TurnStart {
|
||||
@@ -432,6 +567,39 @@ impl CodexClient {
|
||||
self.send_request(request, request_id, "turn/start")
|
||||
}
|
||||
|
||||
/// Starts a turn with `params`, prints the `turn/start` response, and then
/// hands the new turn to `stream_turn`, returning its result.
///
/// # Errors
///
/// Propagates failures from `turn_start` and `stream_turn`.
fn run_turn(&mut self, params: TurnStartParams) -> Result<Option<String>> {
    // `params` is consumed by `turn_start`, so keep the thread id for streaming.
    let owning_thread = params.thread_id.clone();

    let started = self.turn_start(params)?;
    println!("< turn/start response: {started:?}");

    let turn_id = &started.turn.id;
    self.stream_turn(&owning_thread, turn_id)
}
|
||||
|
||||
fn run_turn_text(
|
||||
&mut self,
|
||||
thread_id: &str,
|
||||
user_message: impl Into<String>,
|
||||
) -> Result<Option<String>> {
|
||||
let turn_params = TurnStartParams {
|
||||
thread_id: thread_id.to_string(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: user_message.into(),
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
self.run_turn(turn_params)
|
||||
}
|
||||
|
||||
fn thread_rollback(&mut self, params: ThreadRollbackParams) -> Result<ThreadRollbackResponse> {
|
||||
let request_id = self.request_id();
|
||||
let request = ClientRequest::ThreadRollback {
|
||||
request_id: request_id.clone(),
|
||||
params,
|
||||
};
|
||||
|
||||
self.send_request(request, request_id, "thread/rollback")
|
||||
}
|
||||
|
||||
fn login_chat_gpt(&mut self) -> Result<LoginChatGptResponse> {
|
||||
let request_id = self.request_id();
|
||||
let request = ClientRequest::LoginChatGpt {
|
||||
@@ -526,7 +694,9 @@ impl CodexClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result<()> {
|
||||
fn stream_turn(&mut self, thread_id: &str, turn_id: &str) -> Result<Option<String>> {
|
||||
let mut last_agent_message = None::<String>;
|
||||
|
||||
loop {
|
||||
let notification = self.next_notification()?;
|
||||
|
||||
@@ -561,7 +731,16 @@ impl CodexClient {
|
||||
println!("\n< item started: {:?}", payload.item);
|
||||
}
|
||||
ServerNotification::ItemCompleted(payload) => {
|
||||
println!("< item completed: {:?}", payload.item);
|
||||
if payload.thread_id == thread_id && payload.turn_id == turn_id {
|
||||
if let ThreadItem::AgentMessage { text, .. } = payload.item {
|
||||
last_agent_message = Some(text);
|
||||
println!("< agent message completed >");
|
||||
} else {
|
||||
println!("< item completed: {:?}", payload.item);
|
||||
}
|
||||
} else {
|
||||
println!("< item completed: {:?}", payload.item);
|
||||
}
|
||||
}
|
||||
ServerNotification::TurnCompleted(payload) => {
|
||||
if payload.turn.id == turn_id {
|
||||
@@ -583,7 +762,7 @@ impl CodexClient {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(last_agent_message)
|
||||
}
|
||||
|
||||
fn extract_event(
|
||||
|
||||
Reference in New Issue
Block a user