mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
progress
This commit is contained in:
@@ -100,3 +100,6 @@ This folder is the root of a Cargo workspace. It contains quite a bit of experim
|
||||
- [`exec/`](./exec) "headless" CLI for use in automation.
|
||||
- [`tui/`](./tui) CLI that launches a fullscreen TUI built with [Ratatui](https://ratatui.rs/).
|
||||
- [`cli/`](./cli) CLI multitool that provides the aforementioned CLIs via subcommands.
|
||||
- [`cli/`](./cli) CLI multitool that provides the aforementioned CLIs via subcommands.
|
||||
|
||||
high
|
||||
|
||||
@@ -10,6 +10,7 @@ use std::time::Duration;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::event_mapping::map_response_item_to_event_messages;
|
||||
use crate::rollout::RolloutItem;
|
||||
use async_channel::Receiver;
|
||||
use async_channel::Sender;
|
||||
use codex_apply_patch::ApplyPatchAction;
|
||||
@@ -492,13 +493,10 @@ impl Session {
|
||||
|
||||
// Dispatch the SessionConfiguredEvent first and then report any errors.
|
||||
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
|
||||
let initial_messages = match &initial_history {
|
||||
InitialHistory::New => None,
|
||||
InitialHistory::Forked(items) => Some(sess.build_initial_messages(items)),
|
||||
InitialHistory::Resumed(resumed_history) => {
|
||||
Some(sess.build_initial_messages(&resumed_history.history))
|
||||
}
|
||||
};
|
||||
let initial_messages = Some(
|
||||
sess.apply_initial_history(&turn_context, initial_history.clone())
|
||||
.await,
|
||||
);
|
||||
|
||||
let events = std::iter::once(Event {
|
||||
id: INITIAL_SUBMIT_ID.to_owned(),
|
||||
@@ -512,9 +510,7 @@ impl Session {
|
||||
})
|
||||
.chain(post_session_configured_error_events.into_iter());
|
||||
for event in events {
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
error!("failed to send event: {e:?}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
Ok((sess, turn_context))
|
||||
@@ -537,26 +533,31 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_initial_history(
|
||||
async fn apply_initial_history(
|
||||
&self,
|
||||
turn_context: &TurnContext,
|
||||
conversation_history: InitialHistory,
|
||||
) {
|
||||
) -> Vec<EventMsg> {
|
||||
match conversation_history {
|
||||
InitialHistory::New => {
|
||||
self.record_initial_history_new(turn_context).await;
|
||||
}
|
||||
InitialHistory::New => self.record_initial_history_new(turn_context).await,
|
||||
InitialHistory::Forked(items) => {
|
||||
self.record_initial_history_from_items(items).await;
|
||||
self.record_conversation_items_internal(&items, false).await;
|
||||
items
|
||||
.into_iter()
|
||||
.flat_map(|ri| {
|
||||
map_response_item_to_event_messages(&ri, self.show_raw_agent_reasoning)
|
||||
})
|
||||
.filter(|m| matches!(m, EventMsg::UserMessage(_)))
|
||||
.collect()
|
||||
}
|
||||
InitialHistory::Resumed(resumed_history) => {
|
||||
self.record_initial_history_from_items(resumed_history.history)
|
||||
.await;
|
||||
self.record_initial_history_resumed(resumed_history.history)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_initial_history_new(&self, turn_context: &TurnContext) {
|
||||
async fn record_initial_history_new(&self, turn_context: &TurnContext) -> Vec<EventMsg> {
|
||||
// record the initial user instructions and environment context,
|
||||
// regardless of whether we restored items.
|
||||
// TODO: Those items shouldn't be "user messages" IMO. Maybe developer messages.
|
||||
@@ -571,28 +572,59 @@ impl Session {
|
||||
Some(self.user_shell.clone()),
|
||||
)));
|
||||
self.record_conversation_items(&conversation_items).await;
|
||||
vec![]
|
||||
}
|
||||
|
||||
async fn record_initial_history_from_items(&self, items: Vec<ResponseItem>) {
|
||||
self.record_conversation_items_internal(&items, false).await;
|
||||
}
|
||||
|
||||
/// build the initial messages vector for SessionConfigured by converting
|
||||
/// ResponseItems into EventMsg.
|
||||
fn build_initial_messages(&self, items: &[ResponseItem]) -> Vec<EventMsg> {
|
||||
items
|
||||
.iter()
|
||||
.flat_map(|item| {
|
||||
map_response_item_to_event_messages(item, self.show_raw_agent_reasoning)
|
||||
})
|
||||
.collect()
|
||||
async fn record_initial_history_resumed(&self, items: Vec<RolloutItem>) -> Vec<EventMsg> {
|
||||
let mut responses: Vec<ResponseItem> = Vec::new();
|
||||
for item in items.clone() {
|
||||
if let RolloutItem::ResponseItem(v) = item {
|
||||
responses.extend(v);
|
||||
}
|
||||
}
|
||||
if !responses.is_empty() {
|
||||
self.record_conversation_items_internal(&responses, false)
|
||||
.await;
|
||||
}
|
||||
|
||||
let mut msgs = Vec::new();
|
||||
for item in items {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(responses) => msgs.extend(
|
||||
responses
|
||||
.iter()
|
||||
.flat_map(|ri| {
|
||||
map_response_item_to_event_messages(ri, self.show_raw_agent_reasoning)
|
||||
})
|
||||
.filter(|m| matches!(m, EventMsg::UserMessage(_))),
|
||||
),
|
||||
RolloutItem::Event(events) => msgs.extend(events.iter().map(|e| e.msg.clone())),
|
||||
}
|
||||
}
|
||||
msgs
|
||||
}
|
||||
|
||||
/// Sends the given event to the client and swallows the send event, if
|
||||
/// any, logging it as an error.
|
||||
/// Sends the given event to the client and records it to the rollout (if enabled).
|
||||
/// Any send/record errors are logged and swallowed.
|
||||
pub(crate) async fn send_event(&self, event: Event) {
|
||||
let event_to_record = event.clone();
|
||||
if let Err(e) = self.tx_event.send(event).await {
|
||||
error!("failed to send tool call event: {e}");
|
||||
error!("failed to send event: {e}");
|
||||
}
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock_unchecked();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
if let Some(rec) = recorder
|
||||
&& let Err(e) = rec
|
||||
.record_items(crate::rollout::RolloutItem::Event(vec![event_to_record]))
|
||||
.await
|
||||
{
|
||||
error!("failed to record rollout event: {e:#}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -624,7 +656,7 @@ impl Session {
|
||||
reason,
|
||||
}),
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
@@ -656,7 +688,7 @@ impl Session {
|
||||
grant_root,
|
||||
}),
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
rx_approve
|
||||
}
|
||||
|
||||
@@ -689,27 +721,23 @@ impl Session {
|
||||
async fn record_conversation_items_internal(&self, items: &[ResponseItem], persist: bool) {
|
||||
debug!("Recording items for conversation: {items:?}");
|
||||
if persist {
|
||||
self.record_state_snapshot(items).await;
|
||||
self.record_state_snapshot(RolloutItem::ResponseItem(items.to_vec()))
|
||||
.await;
|
||||
}
|
||||
|
||||
self.state.lock_unchecked().history.record_items(items);
|
||||
}
|
||||
|
||||
async fn record_state_snapshot(&self, items: &[ResponseItem]) {
|
||||
let snapshot = { crate::rollout::SessionStateSnapshot {} };
|
||||
|
||||
async fn record_state_snapshot(&self, item: RolloutItem) {
|
||||
let recorder = {
|
||||
let guard = self.rollout.lock_unchecked();
|
||||
guard.as_ref().cloned()
|
||||
};
|
||||
|
||||
if let Some(rec) = recorder {
|
||||
if let Err(e) = rec.record_state(snapshot).await {
|
||||
error!("failed to record rollout state: {e:#}");
|
||||
}
|
||||
if let Err(e) = rec.record_items(items).await {
|
||||
error!("failed to record rollout items: {e:#}");
|
||||
}
|
||||
if let Some(rec) = recorder
|
||||
&& let Err(e) = rec.record_items(item).await
|
||||
{
|
||||
error!("failed to record rollout items: {e:#}");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -752,7 +780,7 @@ impl Session {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn on_exec_command_end(
|
||||
@@ -799,7 +827,7 @@ impl Session {
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
|
||||
// If this is an apply_patch, after we emit the end patch, emit a second event
|
||||
// with the full turn diff if there is one.
|
||||
@@ -811,7 +839,7 @@ impl Session {
|
||||
id: sub_id.into(),
|
||||
msg,
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -877,7 +905,7 @@ impl Session {
|
||||
message: message.into(),
|
||||
}),
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn notify_stream_error(&self, sub_id: &str, message: impl Into<String>) {
|
||||
@@ -887,7 +915,7 @@ impl Session {
|
||||
message: message.into(),
|
||||
}),
|
||||
};
|
||||
let _ = self.tx_event.send(event).await;
|
||||
self.send_event(event).await;
|
||||
}
|
||||
|
||||
/// Build the full turn input by concatenating the current conversation
|
||||
@@ -1050,9 +1078,9 @@ impl AgentTask {
|
||||
id: self.sub_id,
|
||||
msg: EventMsg::TurnAborted(TurnAbortedEvent { reason }),
|
||||
};
|
||||
let tx_event = self.sess.tx_event.clone();
|
||||
let sess = self.sess.clone();
|
||||
tokio::spawn(async move {
|
||||
tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1257,7 +1285,7 @@ async fn submission_loop(
|
||||
|
||||
Op::GetHistoryEntryRequest { offset, log_id } => {
|
||||
let config = config.clone();
|
||||
let tx_event = sess.tx_event.clone();
|
||||
let sess_for_spawn = sess.clone();
|
||||
let sub_id = sub.id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -1285,13 +1313,10 @@ async fn submission_loop(
|
||||
),
|
||||
};
|
||||
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
warn!("failed to send GetHistoryEntryResponse event: {e}");
|
||||
}
|
||||
sess_for_spawn.send_event(event).await;
|
||||
});
|
||||
}
|
||||
Op::ListMcpTools => {
|
||||
let tx_event = sess.tx_event.clone();
|
||||
let sub_id = sub.id.clone();
|
||||
|
||||
// This is a cheap lookup from the connection manager's cache.
|
||||
@@ -1302,12 +1327,9 @@ async fn submission_loop(
|
||||
crate::protocol::McpListToolsResponseEvent { tools },
|
||||
),
|
||||
};
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
warn!("failed to send McpListToolsResponse event: {e}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
Op::ListCustomPrompts => {
|
||||
let tx_event = sess.tx_event.clone();
|
||||
let sub_id = sub.id.clone();
|
||||
|
||||
let custom_prompts: Vec<CustomPrompt> =
|
||||
@@ -1323,9 +1345,7 @@ async fn submission_loop(
|
||||
custom_prompts,
|
||||
}),
|
||||
};
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
warn!("failed to send ListCustomPromptsResponse event: {e}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
Op::Compact => {
|
||||
// Create a summarization request as user input
|
||||
@@ -1361,22 +1381,17 @@ async fn submission_loop(
|
||||
message: "Failed to shutdown rollout recorder".to_string(),
|
||||
}),
|
||||
};
|
||||
if let Err(e) = sess.tx_event.send(event).await {
|
||||
warn!("failed to send error message: {e:?}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
let event = Event {
|
||||
id: sub.id.clone(),
|
||||
msg: EventMsg::ShutdownComplete,
|
||||
};
|
||||
if let Err(e) = sess.tx_event.send(event).await {
|
||||
warn!("failed to send Shutdown event: {e}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
break;
|
||||
}
|
||||
Op::GetHistory => {
|
||||
let tx_event = sess.tx_event.clone();
|
||||
let sub_id = sub.id.clone();
|
||||
|
||||
let event = Event {
|
||||
@@ -1386,9 +1401,7 @@ async fn submission_loop(
|
||||
entries: sess.state.lock_unchecked().history.contents(),
|
||||
}),
|
||||
};
|
||||
if let Err(e) = tx_event.send(event).await {
|
||||
warn!("failed to send ConversationHistory event: {e}");
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
_ => {
|
||||
// Ignore unknown ops; enum is non_exhaustive to allow extensions.
|
||||
@@ -1426,12 +1439,10 @@ async fn run_task(
|
||||
model_context_window: turn_context.client.get_model_context_window(),
|
||||
}),
|
||||
};
|
||||
if sess.tx_event.send(event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
sess.send_event(event).await;
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
sess.record_conversation_items(&[initial_input_for_turn.clone().into()])
|
||||
sess.record_conversation_items(&[ResponseItem::from(initial_input_for_turn.clone())])
|
||||
.await;
|
||||
|
||||
let mut last_agent_message: Option<String> = None;
|
||||
@@ -1600,7 +1611,7 @@ async fn run_task(
|
||||
message: e.to_string(),
|
||||
}),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
// let the user continue the conversation
|
||||
break;
|
||||
}
|
||||
@@ -1611,7 +1622,7 @@ async fn run_task(
|
||||
id: sub_id,
|
||||
msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
async fn run_turn(
|
||||
@@ -1787,13 +1798,11 @@ async fn try_run_turn(
|
||||
output.push(ProcessedResponseItem { item, response });
|
||||
}
|
||||
ResponseEvent::WebSearchCallBegin { call_id } => {
|
||||
let _ = sess
|
||||
.tx_event
|
||||
.send(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||
})
|
||||
.await;
|
||||
sess.send_event(Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::WebSearchBegin(WebSearchBeginEvent { call_id }),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
ResponseEvent::Completed {
|
||||
response_id: _,
|
||||
@@ -1824,7 +1833,7 @@ async fn try_run_turn(
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
let _ = sess.tx_event.send(event).await;
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
|
||||
return Ok(output);
|
||||
@@ -1834,21 +1843,21 @@ async fn try_run_turn(
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentMessageDelta(AgentMessageDeltaEvent { delta }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryDelta(delta) => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningDelta(AgentReasoningDeltaEvent { delta }),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningSummaryPartAdded => {
|
||||
let event = Event {
|
||||
id: sub_id.to_string(),
|
||||
msg: EventMsg::AgentReasoningSectionBreak(AgentReasoningSectionBreakEvent {}),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
ResponseEvent::ReasoningContentDelta(delta) => {
|
||||
if sess.show_raw_agent_reasoning {
|
||||
@@ -1858,7 +1867,7 @@ async fn try_run_turn(
|
||||
AgentReasoningRawContentDeltaEvent { delta },
|
||||
),
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1879,9 +1888,7 @@ async fn run_compact_task(
|
||||
model_context_window,
|
||||
}),
|
||||
};
|
||||
if sess.tx_event.send(start_event).await.is_err() {
|
||||
return;
|
||||
}
|
||||
sess.send_event(start_event).await;
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input);
|
||||
let turn_input: Vec<ResponseItem> =
|
||||
@@ -2059,7 +2066,7 @@ async fn handle_response_item(
|
||||
id: sub_id.to_string(),
|
||||
msg,
|
||||
};
|
||||
sess.tx_event.send(event).await.ok();
|
||||
sess.send_event(event).await;
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::error::Result as CodexResult;
|
||||
use crate::protocol::Event;
|
||||
use crate::protocol::EventMsg;
|
||||
use crate::protocol::SessionConfiguredEvent;
|
||||
use crate::rollout::RolloutItem;
|
||||
use crate::rollout::RolloutRecorder;
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -21,11 +22,11 @@ use tokio::sync::RwLock;
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ResumedHistory {
|
||||
pub conversation_id: ConversationId,
|
||||
pub history: Vec<ResponseItem>,
|
||||
pub history: Vec<RolloutItem>,
|
||||
pub rollout_path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InitialHistory {
|
||||
New,
|
||||
Resumed(ResumedHistory),
|
||||
@@ -178,7 +179,12 @@ impl ConversationManager {
|
||||
/// and all items that follow them.
|
||||
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
|
||||
if n == 0 {
|
||||
return InitialHistory::Forked(items);
|
||||
return InitialHistory::Resumed(
|
||||
items
|
||||
.into_iter()
|
||||
.map(|ri| RolloutItem::ResponseItem(vec![ri]))
|
||||
.collect(),
|
||||
);
|
||||
}
|
||||
|
||||
// Walk backwards counting only `user` Message items, find cut index.
|
||||
@@ -200,7 +206,13 @@ fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) ->
|
||||
// No prefix remains after dropping; start a new conversation.
|
||||
InitialHistory::New
|
||||
} else {
|
||||
InitialHistory::Forked(items.into_iter().take(cut_index).collect())
|
||||
InitialHistory::Resumed(
|
||||
items
|
||||
.into_iter()
|
||||
.take(cut_index)
|
||||
.map(|ri| RolloutItem::ResponseItem(vec![ri]))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,12 +268,34 @@ mod tests {
|
||||
];
|
||||
|
||||
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
|
||||
assert_eq!(
|
||||
truncated,
|
||||
InitialHistory::Forked(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
|
||||
);
|
||||
if let InitialHistory::Resumed(resumed) = truncated {
|
||||
let get_text = |ri: &ResponseItem| -> Option<String> {
|
||||
if let ResponseItem::Message { content, .. } = ri {
|
||||
for c in content {
|
||||
if let ContentItem::OutputText { text } = c {
|
||||
return Some(text.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
let texts: Vec<String> = resumed
|
||||
.iter()
|
||||
.filter_map(|it| match it {
|
||||
RolloutItem::ResponseItem(v) if !v.is_empty() => get_text(&v[0]),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
assert_eq!(
|
||||
texts,
|
||||
vec!["u1".to_string(), "a1".to_string(), "a2".to_string()]
|
||||
);
|
||||
} else {
|
||||
panic!("expected Resumed history");
|
||||
}
|
||||
|
||||
let truncated2 = truncate_after_dropping_last_messages(items, 2);
|
||||
assert_eq!(truncated2, InitialHistory::New);
|
||||
assert!(matches!(truncated2, InitialHistory::New));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ pub mod list;
|
||||
pub(crate) mod policy;
|
||||
pub mod recorder;
|
||||
|
||||
pub use recorder::RolloutItem;
|
||||
pub use recorder::RolloutRecorder;
|
||||
pub use recorder::RolloutRecorderParams;
|
||||
pub use recorder::SessionMeta;
|
||||
|
||||
@@ -7,6 +7,7 @@ use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use codex_protocol::mcp_protocol::ConversationId;
|
||||
use codex_protocol::protocol::Event;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
@@ -47,7 +48,7 @@ struct SessionMetaWithGit {
|
||||
git: Option<GitInfo>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
|
||||
pub struct SessionStateSnapshot {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default, Clone)]
|
||||
@@ -85,9 +86,24 @@ pub enum RolloutRecorderParams {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RolloutItem {
|
||||
ResponseItem(Vec<ResponseItem>),
|
||||
Event(Vec<Event>),
|
||||
}
|
||||
|
||||
impl<T> From<T> for RolloutItem
|
||||
where
|
||||
T: AsRef<[ResponseItem]>,
|
||||
{
|
||||
fn from(items: T) -> Self {
|
||||
RolloutItem::ResponseItem(items.as_ref().to_vec())
|
||||
}
|
||||
}
|
||||
|
||||
enum RolloutCmd {
|
||||
AddItems(Vec<ResponseItem>),
|
||||
UpdateState(SessionStateSnapshot),
|
||||
AddResponseItems(Vec<ResponseItem>),
|
||||
AddEvent(Vec<Event>),
|
||||
Shutdown { ack: oneshot::Sender<()> },
|
||||
}
|
||||
|
||||
@@ -172,7 +188,14 @@ impl RolloutRecorder {
|
||||
Ok(Self { tx })
|
||||
}
|
||||
|
||||
pub(crate) async fn record_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
pub(crate) async fn record_items(&self, item: RolloutItem) -> std::io::Result<()> {
|
||||
match item {
|
||||
RolloutItem::ResponseItem(items) => self.record_response_items(&items).await,
|
||||
RolloutItem::Event(events) => self.record_event(&events).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn record_response_items(&self, items: &[ResponseItem]) -> std::io::Result<()> {
|
||||
let mut filtered = Vec::new();
|
||||
for item in items {
|
||||
// Note that function calls may look a bit strange if they are
|
||||
@@ -186,16 +209,16 @@ impl RolloutRecorder {
|
||||
return Ok(());
|
||||
}
|
||||
self.tx
|
||||
.send(RolloutCmd::AddItems(filtered))
|
||||
.send(RolloutCmd::AddResponseItems(filtered))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout items: {e}")))
|
||||
}
|
||||
|
||||
pub(crate) async fn record_state(&self, state: SessionStateSnapshot) -> std::io::Result<()> {
|
||||
async fn record_event(&self, events: &[Event]) -> std::io::Result<()> {
|
||||
self.tx
|
||||
.send(RolloutCmd::UpdateState(state))
|
||||
.send(RolloutCmd::AddEvent(events.to_vec()))
|
||||
.await
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout state: {e}")))
|
||||
.map_err(|e| IoError::other(format!("failed to queue rollout event: {e}")))
|
||||
}
|
||||
|
||||
pub async fn get_rollout_history(path: &Path) -> std::io::Result<InitialHistory> {
|
||||
@@ -221,7 +244,7 @@ impl RolloutRecorder {
|
||||
}
|
||||
};
|
||||
|
||||
let mut items = Vec::new();
|
||||
let mut items: Vec<RolloutItem> = Vec::new();
|
||||
for line in lines {
|
||||
if line.trim().is_empty() {
|
||||
continue;
|
||||
@@ -230,21 +253,32 @@ impl RolloutRecorder {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
if v.get("record_type")
|
||||
.and_then(|rt| rt.as_str())
|
||||
.map(|s| s == "state")
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => {
|
||||
if is_persisted_response_item(&item) {
|
||||
items.push(item);
|
||||
match v.get("record_type").and_then(|rt| rt.as_str()) {
|
||||
Some("state") => continue,
|
||||
Some("event") => {
|
||||
let mut ev_val = v.clone();
|
||||
if let Some(obj) = ev_val.as_object_mut() {
|
||||
obj.remove("record_type");
|
||||
}
|
||||
match serde_json::from_value::<Event>(ev_val) {
|
||||
Ok(ev) => items.push(RolloutItem::Event(vec![ev])),
|
||||
Err(e) => warn!("failed to parse event: {v:?}, error: {e}"),
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to parse item: {v:?}, error: {e}");
|
||||
Some("response") | None => {
|
||||
match serde_json::from_value::<ResponseItem>(v.clone()) {
|
||||
Ok(item) => {
|
||||
if is_persisted_response_item(&item) {
|
||||
items.push(RolloutItem::ResponseItem(vec![item]));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("failed to parse response item: {v:?}, error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(other) => {
|
||||
warn!("unknown record_type in rollout: {other}");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,26 +390,28 @@ async fn rollout_writer(
|
||||
// Process rollout commands
|
||||
while let Some(cmd) = rx.recv().await {
|
||||
match cmd {
|
||||
RolloutCmd::AddItems(items) => {
|
||||
RolloutCmd::AddResponseItems(items) => {
|
||||
for item in items {
|
||||
if is_persisted_response_item(&item) {
|
||||
writer.write_line(&item).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
RolloutCmd::UpdateState(state) => {
|
||||
#[derive(Serialize)]
|
||||
struct StateLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
state: &'a SessionStateSnapshot,
|
||||
RolloutCmd::AddEvent(events) => {
|
||||
for event in events {
|
||||
#[derive(Serialize)]
|
||||
struct EventLine<'a> {
|
||||
record_type: &'static str,
|
||||
#[serde(flatten)]
|
||||
event: &'a Event,
|
||||
}
|
||||
writer
|
||||
.write_line(&EventLine {
|
||||
record_type: "event",
|
||||
event: &event,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
writer
|
||||
.write_line(&StateLine {
|
||||
record_type: "state",
|
||||
state: &state,
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
RolloutCmd::Shutdown { ack } => {
|
||||
let _ = ack.send(());
|
||||
|
||||
Reference in New Issue
Block a user