Compare commits

...

2 Commits

Author SHA1 Message Date
aibrahim-oai
0df5282b72 Merge branch 'main' into codex/add-session-loading-logic-in-codex-rs 2025-07-17 10:27:48 -07:00
aibrahim-oai
84934a6a62 core: support session loading 2025-07-17 10:26:37 -07:00
3 changed files with 163 additions and 50 deletions

View File

@@ -202,6 +202,20 @@ impl Session {
.map(PathBuf::from)
.map_or_else(|| self.cwd.clone(), |p| self.cwd.join(p))
}
pub async fn load_rollout(&self, path: std::path::PathBuf) -> std::io::Result<()> {
let (rec, saved) = crate::rollout::RolloutRecorder::resume(&path).await?;
{
let mut state = self.state.lock().unwrap();
state.previous_response_id = saved.state.previous_response_id;
if let Some(transcript) = state.zdr_transcript.as_mut() {
transcript.record_items(saved.items.iter());
}
}
let mut guard = self.rollout.lock().unwrap();
*guard = Some(rec);
Ok(())
}
}
/// Mutable state of the agent
@@ -309,6 +323,7 @@ impl Session {
async fn record_conversation_items(&self, items: &[ResponseItem]) {
debug!("Recording items for conversation: {items:?}");
self.record_rollout_items(items).await;
self.record_state_snapshot().await;
if let Some(transcript) = self.state.lock().unwrap().zdr_transcript.as_mut() {
transcript.record_items(items);
@@ -332,6 +347,26 @@ impl Session {
}
}
async fn record_state_snapshot(&self) {
let snapshot = {
let state = self.state.lock().unwrap();
crate::rollout::SessionStateSnapshot {
previous_response_id: state.previous_response_id.clone(),
}
};
let recorder = {
let guard = self.rollout.lock().unwrap();
guard.as_ref().cloned()
};
if let Some(rec) = recorder {
if let Err(e) = rec.record_state(snapshot).await {
error!("failed to record rollout state: {e:#}");
}
}
}
async fn notify_exec_command_begin(&self, sub_id: &str, call_id: &str, params: &ExecParams) {
let event = Event {
id: sub_id.to_string(),
@@ -744,6 +779,22 @@ async fn submission_loop(
other => sess.notify_approval(&id, other),
}
}
Op::LoadSession { path } => {
let sess = match sess.as_ref() {
Some(sess) => sess,
None => {
send_no_session_event(sub.id).await;
continue;
}
};
if let Err(e) = sess.load_rollout(path).await {
let event = Event {
id: sub.id,
msg: EventMsg::Error(ErrorEvent { message: e.to_string() }),
};
tx_event.send(event).await.ok();
}
}
Op::AddToHistory { text } => {
let id = session_id;
let config = config.clone();

View File

@@ -97,6 +97,9 @@ pub enum Op {
decision: ReviewDecision,
},
/// Load a previously saved session from disk and resume from it.
LoadSession { path: std::path::PathBuf },
/// Append an entry to the persistent cross-session message history.
///
/// Note the entry is not guaranteed to be logged if the user has

View File

@@ -7,11 +7,11 @@ use std::fs::File;
use std::fs::{self};
use std::io::Error as IoError;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use time::format_description::FormatItem;
use time::macros::format_description;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::{self};
use uuid::Uuid;
@@ -22,26 +22,60 @@ use crate::models::ResponseItem;
/// Folder inside `~/.codex` that holds saved rollouts.
const SESSIONS_SUBDIR: &str = "sessions";
#[derive(Serialize)]
struct SessionMeta {
id: String,
timestamp: String,
#[derive(Serialize, Deserialize, Clone, Default)]
pub struct SessionMeta {
pub id: String,
pub timestamp: String,
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
pub instructions: Option<String>,
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SessionStateSnapshot {
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_response_id: Option<String>,
}
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct SavedSession {
pub session: SessionMeta,
#[serde(default)]
pub items: Vec<ResponseItem>,
#[serde(default)]
pub state: SessionStateSnapshot,
}
/// Records all [`ResponseItem`]s for a session and flushes them to disk after
/// every update.
///
/// Rollouts are recorded as JSONL and can be inspected with tools such as:
/// Rollouts are recorded as JSON and can be inspected with tools such as:
///
/// ```ignore
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.jsonl
/// $ jq -C . ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.json
/// $ fx ~/.codex/sessions/rollout-2025-05-07T17-24-21-5973b6c0-94b8-487b-a530-2aeb6098ae0e.json
/// ```
#[derive(Clone)]
pub(crate) struct RolloutRecorder {
tx: Sender<String>,
tx: Sender<RolloutCmd>,
}
#[derive(Clone)]
enum RolloutCmd {
AddItems(Vec<ResponseItem>),
UpdateState(SessionStateSnapshot),
}
async fn write_session(file: &mut tokio::fs::File, data: &SavedSession) {
if file.seek(std::io::SeekFrom::Start(0)).await.is_err() {
return;
}
if file.set_len(0).await.is_err() {
return;
}
if let Ok(json) = serde_json::to_vec_pretty(data) {
let _ = file.write_all(&json).await;
}
let _ = file.flush().await;
}
impl RolloutRecorder {
@@ -76,68 +110,92 @@ impl RolloutRecorder {
// A reasonably-sized bounded channel. If the buffer fills up the send
// future will yield, which is fine we only need to ensure we do not
// perform *blocking* I/O on the callers thread.
let (tx, mut rx) = mpsc::channel::<String>(256);
let (tx, mut rx) = mpsc::channel::<RolloutCmd>(256);
let mut data = SavedSession {
session: meta,
items: Vec::new(),
state: SessionStateSnapshot::default(),
};
// Spawn a Tokio task that owns the file handle and performs async
// writes. Using `tokio::fs::File` keeps everything on the async I/O
// driver instead of blocking the runtime.
tokio::task::spawn(async move {
let mut file = tokio::fs::File::from_std(file);
while let Some(line) = rx.recv().await {
// Write line + newline, then flush to disk.
if let Err(e) = file.write_all(line.as_bytes()).await {
tracing::warn!("rollout writer: failed to write line: {e}");
break;
}
if let Err(e) = file.write_all(b"\n").await {
tracing::warn!("rollout writer: failed to write newline: {e}");
break;
}
if let Err(e) = file.flush().await {
tracing::warn!("rollout writer: failed to flush: {e}");
break;
write_session(&mut file, &data).await;
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => data.items.extend(items),
RolloutCmd::UpdateState(state) => data.state = state,
}
write_session(&mut file, &data).await;
}
});
let recorder = Self { tx };
// Ensure SessionMeta is the first item in the file.
recorder.record_item(&meta).await?;
Ok(recorder)
}
/// Append `items` to the rollout file.
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
let mut filtered = Vec::new();
for item in items {
match item {
// Note that function calls may look a bit strange if they are
// "fully qualified MCP tool calls," so we could consider
// reformatting them in that case.
ResponseItem::Message { .. }
| ResponseItem::LocalShellCall { .. }
| ResponseItem::FunctionCall { .. }
| ResponseItem::FunctionCallOutput { .. } => {}
ResponseItem::Reasoning { .. } | ResponseItem::Other => {
// These should never be serialized.
continue;
}
| ResponseItem::FunctionCallOutput { .. } => filtered.push(item.clone()),
ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
}
self.record_item(item).await?;
}
Ok(())
if filtered.is_empty() {
return Ok(());
}
self.tx
.send(RolloutCmd::AddItems(filtered))
.await
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
}
async fn record_item(&self, item: &impl Serialize) -> std::io::Result<()> {
// Serialize the item to JSON first so that the writer thread only has
// to perform the actual write.
let json = serde_json::to_string(item)
.map_err(|e| IoError::other(format!("failed to serialize response items: {e}")))?;
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
self.tx
.send(json)
.send(RolloutCmd::UpdateState(state))
.await
.map_err(|e| IoError::other(format!("failed to queue rollout item: {e}")))
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
}
pub async fn resume(path: &std::path::Path) -> std::io::Result<(Self, SavedSession)> {
let bytes = tokio::fs::read(path).await?;
let saved: SavedSession = serde_json::from_slice(&bytes)
.map_err(|e| IoError::other(format!("failed to parse session: {e}")))?;
let file = std::fs::OpenOptions::new()
.write(true)
.read(true)
.open(path)?;
let saved_clone = saved.clone();
let (tx, mut rx) = mpsc::channel::<RolloutCmd>(256);
tokio::task::spawn(async move {
let mut data = saved_clone;
let mut file = tokio::fs::File::from_std(file);
write_session(&mut file, &data).await;
while let Some(cmd) = rx.recv().await {
match cmd {
RolloutCmd::AddItems(items) => data.items.extend(items),
RolloutCmd::UpdateState(state) => data.state = state,
}
write_session(&mut file, &data).await;
}
});
Ok((Self { tx }, saved))
}
pub async fn load(path: &std::path::Path) -> std::io::Result<SavedSession> {
let bytes = tokio::fs::read(path).await?;
let saved: SavedSession = serde_json::from_slice(&bytes)
.map_err(|e| IoError::other(format!("failed to parse session: {e}")))?;
Ok(saved)
}
}
@@ -171,12 +229,13 @@ fn create_log_file(config: &Config, session_id: Uuid) -> std::io::Result<LogFile
.format(format)
.map_err(|e| IoError::other(format!("failed to format timestamp: {e}")))?;
let filename = format!("rollout-{date_str}-{session_id}.jsonl");
let filename = format!("rollout-{date_str}-{session_id}.json");
let path = dir.join(filename);
let file = std::fs::OpenOptions::new()
.append(true)
.write(true)
.create(true)
.truncate(true)
.open(&path)?;
Ok(LogFileInfo {