fork conversation

This commit is contained in:
Ahmed Ibrahim
2025-08-21 21:10:06 -07:00
parent 750ca9e21d
commit 682ec7f0ef
6 changed files with 290 additions and 13 deletions

View File

@@ -130,6 +130,7 @@ pub struct Codex {
next_id: AtomicU64,
tx_sub: Sender<Submission>,
rx_event: Receiver<Event>,
session: Arc<Session>,
}
/// Wrapper returned by [`Codex::spawn`] containing the spawned [`Codex`],
@@ -144,7 +145,11 @@ pub(crate) const INITIAL_SUBMIT_ID: &str = "";
impl Codex {
/// Spawn a new [`Codex`] and initialize the session.
pub async fn spawn(config: Config, auth: Option<CodexAuth>) -> CodexResult<CodexSpawnOk> {
pub async fn spawn(
config: Config,
auth: Option<CodexAuth>,
initial_history: Option<Vec<ResponseItem>>,
) -> CodexResult<CodexSpawnOk> {
let (tx_sub, rx_sub) = async_channel::bounded(64);
let (tx_event, rx_event) = async_channel::unbounded();
@@ -169,21 +174,32 @@ impl Codex {
};
// Generate a unique ID for the lifetime of this Codex session.
let (session, turn_context) =
Session::new(configure_session, config.clone(), auth, tx_event.clone())
.await
.map_err(|e| {
error!("Failed to create session: {e:#}");
CodexErr::InternalAgentDied
})?;
let (session, turn_context) = Session::new(
configure_session,
config.clone(),
auth,
tx_event.clone(),
initial_history,
)
.await
.map_err(|e| {
error!("Failed to create session: {e:#}");
CodexErr::InternalAgentDied
})?;
let session_id = session.session_id;
// This task will run until Op::Shutdown is received.
tokio::spawn(submission_loop(session, turn_context, config, rx_sub));
tokio::spawn(submission_loop(
Arc::clone(&session),
turn_context,
config,
rx_sub,
));
let codex = Codex {
next_id: AtomicU64::new(0),
tx_sub,
rx_event,
session: Arc::clone(&session),
};
Ok(CodexSpawnOk { codex, session_id })
@@ -218,6 +234,11 @@ impl Codex {
.map_err(|_| CodexErr::InternalAgentDied)?;
Ok(event)
}
/// Snapshot of the conversation history (oldest → newest).
pub(crate) fn history_contents(&self) -> Vec<ResponseItem> {
self.session.state.lock_unchecked().history.contents()
}
}
/// Mutable state of the agent
@@ -325,6 +346,7 @@ impl Session {
config: Arc<Config>,
auth: Option<CodexAuth>,
tx_event: Sender<Event>,
initial_history: Option<Vec<ResponseItem>>,
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
let ConfigureSession {
provider,
@@ -384,14 +406,16 @@ impl Session {
}
let rollout_result = match rollout_res {
Ok((session_id, maybe_saved, recorder)) => {
let restored_items: Option<Vec<ResponseItem>> =
maybe_saved.and_then(|saved_session| {
let restored_items: Option<Vec<ResponseItem>> = match initial_history {
Some(items) => Some(items),
None => maybe_saved.and_then(|saved_session| {
if saved_session.items.is_empty() {
None
} else {
Some(saved_session.items)
}
});
}),
};
RolloutResult {
session_id,
rollout_recorder: Some(recorder),

View File

@@ -1,5 +1,6 @@
use crate::codex::Codex;
use crate::error::Result as CodexResult;
use crate::models::ResponseItem;
use crate::protocol::Event;
use crate::protocol::Op;
use crate::protocol::Submission;
@@ -27,4 +28,9 @@ impl CodexConversation {
pub async fn next_event(&self) -> CodexResult<Event> {
self.codex.next_event().await
}
/// Return a snapshot of the current conversation history (oldest → newest).
pub(crate) fn history_contents(&self) -> Vec<ResponseItem> {
self.codex.history_contents()
}
}

View File

@@ -54,7 +54,7 @@ impl ConversationManager {
let CodexSpawnOk {
codex,
session_id: conversation_id,
} = Codex::spawn(config, auth).await?;
} = Codex::spawn(config, auth, None).await?;
// The first event must be `SessionInitialized`. Validate and forward it
// to the caller so that they can display it in the conversation
@@ -93,4 +93,149 @@ impl ConversationManager {
.cloned()
.ok_or_else(|| CodexErr::ConversationNotFound(conversation_id))
}
/// Fork an existing conversation by dropping the last `drop_last_messages`
/// user/assistant messages from its transcript and starting a new
/// conversation with identical configuration (unless overridden by the
/// caller's `config`). The new conversation will have a fresh id.
pub async fn fork_conversation(
&self,
base_conversation_id: Uuid,
drop_last_messages: usize,
config: Config,
) -> CodexResult<NewConversation> {
// Obtain base conversation currently managed in memory.
let base = self.get_conversation(base_conversation_id).await?;
let items = base.history_contents();
// Compute the prefix up to the cut point.
let fork_items = truncate_after_dropping_last_messages(items, drop_last_messages);
// Spawn a new conversation with the computed initial history.
let auth = CodexAuth::from_codex_home(&config.codex_home, config.preferred_auth_method)?;
let CodexSpawnOk {
codex,
session_id: conversation_id,
} = Codex::spawn(config, auth, Some(fork_items)).await?;
// The first event must be `SessionInitialized`. Validate and forward it
// to the caller so that they can display it in the conversation
// history.
let event = codex.next_event().await?;
let session_configured = match event {
Event {
id,
msg: EventMsg::SessionConfigured(session_configured),
} if id == INITIAL_SUBMIT_ID => session_configured,
_ => {
return Err(CodexErr::SessionConfiguredNotFirstEvent);
}
};
let conversation = Arc::new(CodexConversation::new(codex));
self.conversations
.write()
.await
.insert(conversation_id, conversation.clone());
Ok(NewConversation {
conversation_id,
conversation,
session_configured,
})
}
}
/// Return a prefix of `items` obtained by dropping the last `n` user messages
/// and all items that follow them.
fn truncate_after_dropping_last_messages(
items: Vec<crate::models::ResponseItem>,
n: usize,
) -> Vec<crate::models::ResponseItem> {
if n == 0 || items.is_empty() {
return items;
}
// Walk backwards counting only `user` Message items, find cut index.
let mut count = 0usize;
let mut cut_index = 0usize;
for (idx, item) in items.iter().enumerate().rev() {
if let crate::models::ResponseItem::Message { role, .. } = item
&& role == "user"
{
count += 1;
if count == n {
// Cut everything from this user message to the end.
cut_index = idx;
break;
}
}
}
if count < n {
// If fewer than n messages exist, drop everything.
return Vec::new();
}
items.into_iter().take(cut_index).collect()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::ContentItem;
use crate::models::ReasoningItemReasoningSummary;
use crate::models::ResponseItem;
fn user_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "user".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
fn assistant_msg(text: &str) -> ResponseItem {
ResponseItem::Message {
id: None,
role: "assistant".to_string(),
content: vec![ContentItem::OutputText {
text: text.to_string(),
}],
}
}
#[test]
fn drops_from_last_user_only() {
let items = vec![
user_msg("u1"),
assistant_msg("a1"),
assistant_msg("a2"),
user_msg("u2"),
assistant_msg("a3"),
ResponseItem::Reasoning {
id: "r1".to_string(),
summary: vec![ReasoningItemReasoningSummary::SummaryText {
text: "s".to_string(),
}],
content: None,
encrypted_content: None,
},
ResponseItem::FunctionCall {
id: None,
name: "tool".to_string(),
arguments: "{}".to_string(),
call_id: "c1".to_string(),
},
assistant_msg("a4"),
];
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
assert_eq!(
truncated,
vec![items[0].clone(), items[1].clone(), items[2].clone()]
);
let truncated2 = truncate_after_dropping_last_messages(items, 2);
assert!(truncated2.is_empty());
}
}

View File

@@ -47,6 +47,7 @@ use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::mcp_protocol::EXEC_COMMAND_APPROVAL_METHOD;
use codex_protocol::mcp_protocol::ExecCommandApprovalParams;
use codex_protocol::mcp_protocol::ExecCommandApprovalResponse;
use codex_protocol::mcp_protocol::ForkConversationParams;
use codex_protocol::mcp_protocol::InputItem as WireInputItem;
use codex_protocol::mcp_protocol::InterruptConversationParams;
use codex_protocol::mcp_protocol::InterruptConversationResponse;
@@ -114,6 +115,10 @@ impl CodexMessageProcessor {
// created before processing any subsequent messages.
self.process_new_conversation(request_id, params).await;
}
ClientRequest::ForkConversation { request_id, params } => {
// Same reasoning as NewConversation: ensure ordering.
self.process_fork_conversation(request_id, params).await;
}
ClientRequest::SendUserMessage { request_id, params } => {
self.send_user_message(request_id, params).await;
}
@@ -377,6 +382,79 @@ impl CodexMessageProcessor {
}
}
async fn process_fork_conversation(
&self,
request_id: RequestId,
params: ForkConversationParams,
) {
let ForkConversationParams {
conversation_id,
drop_last_messages,
overrides,
} = params;
// Verify the base conversation exists.
if self
.conversation_manager
.get_conversation(conversation_id.0)
.await
.is_err()
{
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("conversation not found: {conversation_id}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
// Derive config from overrides (or defaults) for the new conversation.
let new_conv_params = overrides.unwrap_or_default();
let config = match derive_config_from_params(
new_conv_params,
self.codex_linux_sandbox_exe.clone(),
) {
Ok(config) => config,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("error deriving config: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
match self
.conversation_manager
.fork_conversation(conversation_id.0, drop_last_messages, config)
.await
{
Ok(new_conv) => {
let NewConversation {
conversation_id,
session_configured,
..
} = new_conv;
let response = NewConversationResponse {
conversation_id: ConversationId(conversation_id),
model: session_configured.model,
};
self.outgoing.send_response(request_id, response).await;
}
Err(err) => {
let error = JSONRPCErrorError {
code: INTERNAL_ERROR_CODE,
message: format!("error forking conversation: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
}
}
}
async fn send_user_message(&self, request_id: RequestId, params: SendUserMessageParams) {
let SendUserMessageParams {
conversation_id,

View File

@@ -22,6 +22,7 @@ pub fn generate_ts(out_dir: &Path, prettier: Option<&Path>) -> Result<()> {
codex_protocol::mcp_protocol::ServerRequest::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::NewConversationParams::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::NewConversationResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::ForkConversationParams::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::AddConversationListenerParams::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::AddConversationSubscriptionResponse::export_all_to(out_dir)?;
codex_protocol::mcp_protocol::RemoveConversationListenerParams::export_all_to(out_dir)?;

View File

@@ -53,6 +53,15 @@ pub enum ClientRequest {
request_id: RequestId,
params: NewConversationParams,
},
/// Start a new conversation by forking an existing one and truncating the
/// last N user/assistant messages from its transcript. The new
/// conversation will have a fresh conversation id; all other
/// configuration can be optionally overridden via `overrides`.
ForkConversation {
#[serde(rename = "id")]
request_id: RequestId,
params: ForkConversationParams,
},
SendUserMessage {
#[serde(rename = "id")]
request_id: RequestId,
@@ -152,6 +161,20 @@ pub struct NewConversationResponse {
pub model: String,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct ForkConversationParams {
/// Existing conversation to fork from.
pub conversation_id: ConversationId,
/// Positive number of trailing user/assistant messages to drop in the
/// fork. `1` means the last message is excluded; `2` excludes the last
/// two messages, and so on.
pub drop_last_messages: usize,
/// Optional overrides for the new conversation's initial configuration.
#[serde(skip_serializing_if = "Option::is_none")]
pub overrides: Option<NewConversationParams>,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, TS)]
#[serde(rename_all = "camelCase")]
pub struct AddConversationSubscriptionResponse {