mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
fork conversation
This commit is contained in:
@@ -130,6 +130,7 @@ pub struct Codex {
|
||||
next_id: AtomicU64,
|
||||
tx_sub: Sender<Submission>,
|
||||
rx_event: Receiver<Event>,
|
||||
session: Arc<Session>,
|
||||
}
|
||||
|
||||
/// Wrapper returned by [`Codex::spawn`] containing the spawned [`Codex`],
|
||||
@@ -144,7 +145,11 @@ pub(crate) const INITIAL_SUBMIT_ID: &str = "";
|
||||
|
||||
impl Codex {
|
||||
/// Spawn a new [`Codex`] and initialize the session.
|
||||
pub async fn spawn(config: Config, auth: Option<CodexAuth>) -> CodexResult<CodexSpawnOk> {
|
||||
pub async fn spawn(
|
||||
config: Config,
|
||||
auth: Option<CodexAuth>,
|
||||
initial_history: Option<Vec<ResponseItem>>,
|
||||
) -> CodexResult<CodexSpawnOk> {
|
||||
let (tx_sub, rx_sub) = async_channel::bounded(64);
|
||||
let (tx_event, rx_event) = async_channel::unbounded();
|
||||
|
||||
@@ -169,21 +174,32 @@ impl Codex {
|
||||
};
|
||||
|
||||
// Generate a unique ID for the lifetime of this Codex session.
|
||||
let (session, turn_context) =
|
||||
Session::new(configure_session, config.clone(), auth, tx_event.clone())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to create session: {e:#}");
|
||||
CodexErr::InternalAgentDied
|
||||
})?;
|
||||
let (session, turn_context) = Session::new(
|
||||
configure_session,
|
||||
config.clone(),
|
||||
auth,
|
||||
tx_event.clone(),
|
||||
initial_history,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to create session: {e:#}");
|
||||
CodexErr::InternalAgentDied
|
||||
})?;
|
||||
let session_id = session.session_id;
|
||||
|
||||
// This task will run until Op::Shutdown is received.
|
||||
tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
|
||||
tokio::spawn(submission_loop(
|
||||
Arc::clone(&session),
|
||||
turn_context,
|
||||
config,
|
||||
rx_sub,
|
||||
));
|
||||
let codex = Codex {
|
||||
next_id: AtomicU64::new(0),
|
||||
tx_sub,
|
||||
rx_event,
|
||||
session: Arc::clone(&session),
|
||||
};
|
||||
|
||||
Ok(CodexSpawnOk { codex, session_id })
|
||||
@@ -218,6 +234,11 @@ impl Codex {
|
||||
.map_err(|_| CodexErr::InternalAgentDied)?;
|
||||
Ok(event)
|
||||
}
|
||||
|
||||
/// Snapshot of the conversation history (oldest → newest).
|
||||
pub(crate) fn history_contents(&self) -> Vec<ResponseItem> {
|
||||
self.session.state.lock_unchecked().history.contents()
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutable state of the agent
|
||||
@@ -325,6 +346,7 @@ impl Session {
|
||||
config: Arc<Config>,
|
||||
auth: Option<CodexAuth>,
|
||||
tx_event: Sender<Event>,
|
||||
initial_history: Option<Vec<ResponseItem>>,
|
||||
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
|
||||
let ConfigureSession {
|
||||
provider,
|
||||
@@ -384,14 +406,16 @@ impl Session {
|
||||
}
|
||||
let rollout_result = match rollout_res {
|
||||
Ok((session_id, maybe_saved, recorder)) => {
|
||||
let restored_items: Option<Vec<ResponseItem>> =
|
||||
maybe_saved.and_then(|saved_session| {
|
||||
let restored_items: Option<Vec<ResponseItem>> = match initial_history {
|
||||
Some(items) => Some(items),
|
||||
None => maybe_saved.and_then(|saved_session| {
|
||||
if saved_session.items.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(saved_session.items)
|
||||
}
|
||||
});
|
||||
}),
|
||||
};
|
||||
RolloutResult {
|
||||
session_id,
|
||||
rollout_recorder: Some(recorder),
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::codex::Codex;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::models::ResponseItem;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::Op;
|
||||
use crate::protocol::Submission;
|
||||
@@ -27,4 +28,9 @@ impl CodexConversation {
|
||||
pub async fn next_event(&self) -> CodexResult<Event> {
|
||||
self.codex.next_event().await
|
||||
}
|
||||
|
||||
/// Return a snapshot of the current conversation history (oldest → newest).
|
||||
pub(crate) fn history_contents(&self) -> Vec<ResponseItem> {
|
||||
self.codex.history_contents()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,7 +54,7 @@ impl ConversationManager {
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth).await?;
|
||||
} = Codex::spawn(config, auth, None).await?;
|
||||
|
||||
// The first event must be `SessionInitialized`. Validate and forward it
|
||||
// to the caller so that they can display it in the conversation
|
||||
@@ -93,4 +93,149 @@ impl ConversationManager {
|
||||
.cloned()
|
||||
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
|
||||
}
|
||||
|
||||
/// Fork an existing conversation by dropping the last `drop_last_messages`
|
||||
/// user/assistant messages from its transcript and starting a new
|
||||
/// conversation with identical configuration (unless overridden by the
|
||||
/// caller's `config`). The new conversation will have a fresh id.
|
||||
pub async fn fork_conversation(
|
||||
&self,
|
||||
base_conversation_id: Uuid,
|
||||
drop_last_messages: usize,
|
||||
config: Config,
|
||||
) -> CodexResult<NewConversation> {
|
||||
// Obtain base conversation currently managed in memory.
|
||||
let base = self.get_conversation(base_conversation_id).await?;
|
||||
let items = base.history_contents();
|
||||
|
||||
// Compute the prefix up to the cut point.
|
||||
let fork_items = truncate_after_dropping_last_messages(items, drop_last_messages);
|
||||
|
||||
// Spawn a new conversation with the computed initial history.
|
||||
let auth = CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method)?;
|
||||
let CodexSpawnOk {
|
||||
codex,
|
||||
session_id: conversation_id,
|
||||
} = Codex::spawn(config, auth, Some(fork_items)).await?;
|
||||
|
||||
// The first event must be `SessionInitialized`. Validate and forward it
|
||||
// to the caller so that they can display it in the conversation
|
||||
// history.
|
||||
let event = codex.next_event().await?;
|
||||
let session_configured = match event {
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::SessionConfigured(session_configured),
|
||||
} if id == INITIAL_SUBMIT_ID => session_configured,
|
||||
_ => {
|
||||
return Err(CodexErr::SessionConfiguredNotFirstEvent);
|
||||
}
|
||||
};
|
||||
|
||||
let conversation = Arc::new(CodexConversation::new(codex));
|
||||
self.conversations
|
||||
.write()
|
||||
.await
|
||||
.insert(conversation_id, conversation.clone());
|
||||
|
||||
Ok(NewConversation {
|
||||
conversation_id,
|
||||
conversation,
|
||||
session_configured,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a prefix of `items` obtained by dropping the last `n` user messages
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(
|
||||
items: Vec<crate::models::ResponseItem>,
|
||||
n: usize,
|
||||
) -> Vec<crate::models::ResponseItem> {
|
||||
if n == 0 || items.is_empty() {
|
||||
return items;
|
||||
}
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
let mut count = 0usize;
|
||||
let mut cut_index = 0usize;
|
||||
for (idx, item) in items.iter().enumerate().rev() {
|
||||
if let crate::models::ResponseItem::Message { role, .. } = item
|
||||
&& role == "user"
|
||||
{
|
||||
count += 1;
|
||||
if count == n {
|
||||
// Cut everything from this user message to the end.
|
||||
cut_index = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if count < n {
|
||||
// If fewer than n messages exist, drop everything.
|
||||
return Vec::new();
|
||||
}
|
||||
items.into_iter().take(cut_index).collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::ContentItem;
|
||||
use crate::models::ReasoningItemReasoningSummary;
|
||||
use crate::models::ResponseItem;
|
||||
|
||||
fn user_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
fn assistant_msg(text: &str) -> ResponseItem {
|
||||
ResponseItem::Message {
|
||||
id: None,
|
||||
role: "assistant".to_string(),
|
||||
content: vec![ContentItem::OutputText {
|
||||
text: text.to_string(),
|
||||
}],
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn drops_from_last_user_only() {
|
||||
let items = vec![
|
||||
user_msg("u1"),
|
||||
assistant_msg("a1"),
|
||||
assistant_msg("a2"),
|
||||
user_msg("u2"),
|
||||
assistant_msg("a3"),
|
||||
ResponseItem::Reasoning {
|
||||
id: "r1".to_string(),
|
||||
summary: vec![ReasoningItemReasoningSummary::SummaryText {
|
||||
text: "s".to_string(),
|
||||
}],
|
||||
content: None,
|
||||
encrypted_content: None,
|
||||
},
|
||||
ResponseItem::FunctionCall {
|
||||
id: None,
|
||||
name: "tool".to_string(),
|
||||
arguments: "{}".to_string(),
|
||||
call_id: "c1".to_string(),
|
||||
},
|
||||
assistant_msg("a4"),
|
||||
];
|
||||
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
assert_eq!(
|
||||
truncated,
|
||||
vec![items[0].clone(), items[1].clone(), items[2].clone()]
|
||||
);
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
assert!(truncated2.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
|
||||
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
|
||||
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
|
||||
use codex_protocol::mcp_protocol::ForkConversationParams;
|
||||
use codex_protocol::mcp_protocol::InputItem as WireInputItem;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationParams;
|
||||
use codex_protocol::mcp_protocol::InterruptConversationResponse;
|
||||
@@ -114,6 +115,10 @@ impl CodexMessageProcessor {
|
||||
// created before processing any subsequent messages.
|
||||
self.process_new_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::ForkConversation { request_id, params } => {
|
||||
// Same reasoning as NewConversation: ensure ordering.
|
||||
self.process_fork_conversation(request_id, params).await;
|
||||
}
|
||||
ClientRequest::SendUserMessage { request_id, params } => {
|
||||
self.send_user_message(request_id, params).await;
|
||||
}
|
||||
@@ -377,6 +382,79 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_fork_conversation(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
params: ForkConversationParams,
|
||||
) {
|
||||
let ForkConversationParams {
|
||||
conversation_id,
|
||||
drop_last_messages,
|
||||
overrides,
|
||||
} = params;
|
||||
|
||||
// Verify the base conversation exists.
|
||||
if self
|
||||
.conversation_manager
|
||||
.get_conversation(conversation_id.0)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("conversation not found: {conversation_id}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Derive config from overrides (or defaults) for the new conversation.
|
||||
let new_conv_params = overrides.unwrap_or_default();
|
||||
let config = match derive_config_from_params(
|
||||
new_conv_params,
|
||||
self.codex_linux_sandbox_exe.clone(),
|
||||
) {
|
||||
Ok(config) => config,
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: format!("error deriving config: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match self
|
||||
.conversation_manager
|
||||
.fork_conversation(conversation_id.0, drop_last_messages, config)
|
||||
.await
|
||||
{
|
||||
Ok(new_conv) => {
|
||||
let NewConversation {
|
||||
conversation_id,
|
||||
session_configured,
|
||||
..
|
||||
} = new_conv;
|
||||
let response = NewConversationResponse {
|
||||
conversation_id: ConversationId(conversation_id),
|
||||
model: session_configured.model,
|
||||
};
|
||||
self.outgoing.send_response(request_id, response).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INTERNAL_ERROR_CODE,
|
||||
message: format!("error forking conversation: {err}"),
|
||||
data: None,
|
||||
};
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) {
|
||||
let SendUserMessageParams {
|
||||
conversation_id,
|
||||
|
||||
@@ -22,6 +22,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
|
||||
codex_protocol::mcp_protocol::ServerRequest::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::NewConversationParams::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::NewConversationResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::ForkConversationParams::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::AddConversationListenerParams::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::AddConversationSubscriptionResponse::export_all_to(out_dir)?;
|
||||
codex_protocol::mcp_protocol::RemoveConversationListenerParams::export_all_to(out_dir)?;
|
||||
|
||||
@@ -53,6 +53,15 @@ pub enum ClientRequest {
|
||||
request_id: RequestId,
|
||||
params: NewConversationParams,
|
||||
},
|
||||
/// Start a new conversation by forking an existing one and truncating the
|
||||
/// last N user/assistant messages from its transcript. The new
|
||||
/// conversation will have a fresh conversation id; all other
|
||||
/// configuration can be optionally overridden via `overrides`.
|
||||
ForkConversation {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
params: ForkConversationParams,
|
||||
},
|
||||
SendUserMessage {
|
||||
#[serde(rename = "id")]
|
||||
request_id: RequestId,
|
||||
@@ -152,6 +161,20 @@ pub struct NewConversationResponse {
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ForkConversationParams {
|
||||
/// Existing conversation to fork from.
|
||||
pub conversation_id: ConversationId,
|
||||
/// Positive number of trailing user/assistant messages to drop in the
|
||||
/// fork. `1` means the last message is excluded; `2` excludes the last
|
||||
/// two messages, and so on.
|
||||
pub drop_last_messages: usize,
|
||||
/// Optional overrides for the new conversation's initial configuration.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub overrides: Option<NewConversationParams>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AddConversationSubscriptionResponse {
|
||||
|
||||
Reference in New Issue
Block a user