feat: add phase 1 mem client

This commit is contained in:
jif-oai
2026-02-04 14:55:28 +00:00
parent 583e5d4f41
commit 9ca9c5f88c
8 changed files with 498 additions and 4 deletions

View File

@@ -29,4 +29,13 @@ The public interface of this crate is intentionally small and uniform:
- Output: `Vec<ResponseItem>`.
- `CompactClient::compact_input(&CompactionInput, extra_headers)` wraps the JSON encoding and retry/telemetry wiring.
- **Memory trace summarize endpoint**
- Input: `MemoryTraceSummarizeInput` (re-exported as `codex_api::MemoryTraceSummarizeInput`):
- `model: String`.
- `traces: Vec<MemoryTrace>`.
- `MemoryTrace` includes `id`, `metadata.source_path`, and normalized `items`.
- `reasoning: Option<Reasoning>`.
- Output: `Vec<MemoryTraceSummaryOutput>`.
- `MemoriesClient::trace_summarize_input(&MemoryTraceSummarizeInput, extra_headers)` wraps JSON encoding and retry/telemetry wiring.
All HTTP details (URLs, headers, retry/backoff policies, SSE framing) are encapsulated in `codex-api` and `codex-client`. Callers construct prompts/inputs using protocol types and work with typed streams of `ResponseEvent` or compacted `ResponseItem` values.

View File

@@ -6,6 +6,7 @@ use codex_protocol::openai_models::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::pin::Pin;
@@ -37,6 +38,33 @@ pub struct CompactionInput<'a> {
pub instructions: &'a str,
}
/// Canonical input payload for the memory trace summarize endpoint.
#[derive(Debug, Clone, Serialize)]
pub struct MemoryTraceSummarizeInput {
pub model: String,
pub traces: Vec<MemoryTrace>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning: Option<Reasoning>,
}
#[derive(Debug, Clone, Serialize)]
pub struct MemoryTrace {
pub id: String,
pub metadata: MemoryTraceMetadata,
pub items: Vec<Value>,
}
#[derive(Debug, Clone, Serialize)]
pub struct MemoryTraceMetadata {
pub source_path: String,
}
#[derive(Debug, Clone, Deserialize, PartialEq, Eq)]
pub struct MemoryTraceSummaryOutput {
pub trace_summary: String,
pub memory_summary: String,
}
#[derive(Debug)]
pub enum ResponseEvent {
Created,

View File

@@ -0,0 +1,108 @@
use crate::auth::AuthProvider;
use crate::common::MemoryTraceSummarizeInput;
use crate::common::MemoryTraceSummaryOutput;
use crate::endpoint::session::EndpointSession;
use crate::error::ApiError;
use crate::provider::Provider;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use http::HeaderMap;
use http::Method;
use serde::Deserialize;
use serde_json::to_value;
use std::sync::Arc;
pub struct MemoriesClient<T: HttpTransport, A: AuthProvider> {
session: EndpointSession<T, A>,
}
impl<T: HttpTransport, A: AuthProvider> MemoriesClient<T, A> {
pub fn new(transport: T, provider: Provider, auth: A) -> Self {
Self {
session: EndpointSession::new(transport, provider, auth),
}
}
pub fn with_telemetry(self, request: Option<Arc<dyn RequestTelemetry>>) -> Self {
Self {
session: self.session.with_request_telemetry(request),
}
}
fn path() -> &'static str {
"memories/trace_summarize"
}
pub async fn trace_summarize(
&self,
body: serde_json::Value,
extra_headers: HeaderMap,
) -> Result<Vec<MemoryTraceSummaryOutput>, ApiError> {
let resp = self
.session
.execute(Method::POST, Self::path(), extra_headers, Some(body))
.await?;
let parsed: TraceSummarizeResponse =
serde_json::from_slice(&resp.body).map_err(|e| ApiError::Stream(e.to_string()))?;
Ok(parsed.output)
}
pub async fn trace_summarize_input(
&self,
input: &MemoryTraceSummarizeInput,
extra_headers: HeaderMap,
) -> Result<Vec<MemoryTraceSummaryOutput>, ApiError> {
let body = to_value(input).map_err(|e| {
ApiError::Stream(format!(
"failed to encode memory trace summarize input: {e}"
))
})?;
self.trace_summarize(body, extra_headers).await
}
}
#[derive(Debug, Deserialize)]
struct TraceSummarizeResponse {
output: Vec<MemoryTraceSummaryOutput>,
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
#[derive(Clone, Default)]
struct DummyTransport;
#[async_trait]
impl HttpTransport for DummyTransport {
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
Err(TransportError::Build("execute should not run".to_string()))
}
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
Err(TransportError::Build("stream should not run".to_string()))
}
}
#[derive(Clone, Default)]
struct DummyAuth;
impl AuthProvider for DummyAuth {
fn bearer_token(&self) -> Option<String> {
None
}
}
#[test]
fn path_is_memories_trace_summarize() {
assert_eq!(
MemoriesClient::<DummyTransport, DummyAuth>::path(),
"memories/trace_summarize"
);
}
}

View File

@@ -1,5 +1,6 @@
pub mod aggregate;
pub mod compact;
pub mod memories;
pub mod models;
pub mod responses;
pub mod responses_websocket;

View File

@@ -15,6 +15,10 @@ pub use codex_client::TransportError;
pub use crate::auth::AuthProvider;
pub use crate::common::CompactionInput;
pub use crate::common::MemoryTrace;
pub use crate::common::MemoryTraceMetadata;
pub use crate::common::MemoryTraceSummarizeInput;
pub use crate::common::MemoryTraceSummaryOutput;
pub use crate::common::Prompt;
pub use crate::common::ResponseAppendWsRequest;
pub use crate::common::ResponseCreateWsRequest;
@@ -24,6 +28,7 @@ pub use crate::common::ResponsesApiRequest;
pub use crate::common::create_text_param_for_request;
pub use crate::endpoint::aggregate::AggregateStreamExt;
pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::memories::MemoriesClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;

View File

@@ -7,6 +7,10 @@ use crate::api_bridge::map_api_error;
use crate::auth::UnauthorizedRecovery;
use codex_api::CompactClient as ApiCompactClient;
use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::MemoriesClient as ApiMemoriesClient;
use codex_api::MemoryTrace as ApiMemoryTrace;
use codex_api::MemoryTraceSummarizeInput as ApiMemoryTraceSummarizeInput;
use codex_api::MemoryTraceSummaryOutput as ApiMemoryTraceSummaryOutput;
use codex_api::Prompt as ApiPrompt;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
@@ -238,6 +242,55 @@ impl ModelClient {
instructions: &instructions,
};
let extra_headers = self.build_subagent_headers();
client
.compact_input(&payload, extra_headers)
.await
.map_err(map_api_error)
}
/// Builds memory summaries for each provided normalized trace.
///
/// This is a unary call (no streaming) to `/v1/memories/trace_summarize`.
pub async fn summarize_memory_traces(
&self,
traces: Vec<ApiMemoryTrace>,
) -> Result<Vec<ApiMemoryTraceSummaryOutput>> {
if traces.is_empty() {
return Ok(Vec::new());
}
let auth_manager = self.state.auth_manager.clone();
let auth = match auth_manager.as_ref() {
Some(manager) => manager.auth().await,
None => None,
};
let api_provider = self
.state
.provider
.to_api_provider(auth.as_ref().map(CodexAuth::internal_auth_mode))?;
let api_auth = auth_provider_from_auth(auth, &self.state.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let request_telemetry = self.build_request_telemetry();
let client = ApiMemoriesClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry));
let payload = ApiMemoryTraceSummarizeInput {
model: self.state.model_info.slug.clone(),
traces,
reasoning: self.state.effort.map(|effort| Reasoning {
effort: Some(effort),
summary: None,
}),
};
client
.trace_summarize_input(&payload, self.build_subagent_headers())
.await
.map_err(map_api_error)
}
fn build_subagent_headers(&self) -> ApiHeaderMap {
let mut extra_headers = ApiHeaderMap::new();
if let SessionSource::SubAgent(sub) = &self.state.session_source {
let subagent = match sub {
@@ -250,10 +303,7 @@ impl ModelClient {
extra_headers.insert("x-openai-subagent", val);
}
}
client
.compact_input(&payload, extra_headers)
.await
.map_err(map_api_error)
extra_headers
}
}

View File

@@ -161,4 +161,5 @@ pub use codex_protocol::models::ResponseItem;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub mod compact;
pub mod memory_trace;
pub mod otel_init;

View File

@@ -0,0 +1,292 @@
use std::path::Path;
use std::path::PathBuf;
use crate::ModelClient;
use crate::error::CodexErr;
use crate::error::Result;
use codex_api::MemoryTrace as ApiMemoryTrace;
use codex_api::MemoryTraceMetadata as ApiMemoryTraceMetadata;
use serde_json::Map;
use serde_json::Value;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BuiltTraceMemory {
pub trace_id: String,
pub source_path: PathBuf,
pub trace_summary: String,
pub memory_summary: String,
}
struct PreparedTrace {
trace_id: String,
source_path: PathBuf,
payload: ApiMemoryTrace,
}
/// Loads raw trace files, normalizes trace items, and builds memory summaries.
///
/// The request/response wiring mirrors the memory trace summarize E2E flow:
/// `/v1/memories/trace_summarize` with one output object per input trace.
pub async fn build_memories_from_trace_files(
client: &ModelClient,
trace_paths: &[PathBuf],
) -> Result<Vec<BuiltTraceMemory>> {
if trace_paths.is_empty() {
return Ok(Vec::new());
}
let mut prepared = Vec::with_capacity(trace_paths.len());
for (index, path) in trace_paths.iter().enumerate() {
prepared.push(prepare_trace(index + 1, path).await?);
}
let traces = prepared.iter().map(|trace| trace.payload.clone()).collect();
let output = client.summarize_memory_traces(traces).await?;
if output.len() != prepared.len() {
return Err(CodexErr::InvalidRequest(format!(
"unexpected memory summarize output length: expected {}, got {}",
prepared.len(),
output.len()
)));
}
Ok(prepared
.into_iter()
.zip(output)
.map(|(trace, summary)| BuiltTraceMemory {
trace_id: trace.trace_id,
source_path: trace.source_path,
trace_summary: summary.trace_summary,
memory_summary: summary.memory_summary,
})
.collect())
}
async fn prepare_trace(index: usize, path: &Path) -> Result<PreparedTrace> {
let text = load_trace_text(path).await?;
let items = load_trace_items(path, &text)?;
let trace_id = build_trace_id(index, path);
let source_path = path.to_path_buf();
Ok(PreparedTrace {
trace_id: trace_id.clone(),
source_path: source_path.clone(),
payload: ApiMemoryTrace {
id: trace_id,
metadata: ApiMemoryTraceMetadata {
source_path: source_path.display().to_string(),
},
items,
},
})
}
async fn load_trace_text(path: &Path) -> Result<String> {
let raw = tokio::fs::read(path).await?;
Ok(decode_trace_bytes(&raw))
}
fn decode_trace_bytes(raw: &[u8]) -> String {
if let Some(without_bom) = raw.strip_prefix(&[0xEF, 0xBB, 0xBF])
&& let Ok(text) = String::from_utf8(without_bom.to_vec())
{
return text;
}
if let Ok(text) = String::from_utf8(raw.to_vec()) {
return text;
}
raw.iter().map(|b| char::from(*b)).collect()
}
fn load_trace_items(path: &Path, text: &str) -> Result<Vec<Value>> {
if let Ok(Value::Array(items)) = serde_json::from_str::<Value>(text) {
let dict_items = items
.into_iter()
.filter(serde_json::Value::is_object)
.collect::<Vec<_>>();
if dict_items.is_empty() {
return Err(CodexErr::InvalidRequest(format!(
"no object items found in trace file: {}",
path.display()
)));
}
return normalize_trace_items(dict_items, path);
}
let mut parsed_items = Vec::new();
for line in text.lines() {
let line = line.trim();
if line.is_empty() || (!line.starts_with('{') && !line.starts_with('[')) {
continue;
}
let Ok(obj) = serde_json::from_str::<Value>(line) else {
continue;
};
match obj {
Value::Object(_) => parsed_items.push(obj),
Value::Array(inner) => {
parsed_items.extend(inner.into_iter().filter(serde_json::Value::is_object))
}
_ => {}
}
}
if parsed_items.is_empty() {
return Err(CodexErr::InvalidRequest(format!(
"no JSON items parsed from trace file: {}",
path.display()
)));
}
normalize_trace_items(parsed_items, path)
}
fn normalize_trace_items(items: Vec<Value>, path: &Path) -> Result<Vec<Value>> {
let mut normalized = Vec::new();
for item in items {
let Value::Object(obj) = item else {
continue;
};
if let Some(payload) = obj.get("payload") {
if obj.get("type").and_then(Value::as_str) != Some("response_item") {
continue;
}
match payload {
Value::Object(payload_item) => {
if is_allowed_trace_item(payload_item) {
normalized.push(Value::Object(payload_item.clone()));
}
}
Value::Array(payload_items) => {
for payload_item in payload_items {
if let Value::Object(payload_item) = payload_item
&& is_allowed_trace_item(payload_item)
{
normalized.push(Value::Object(payload_item.clone()));
}
}
}
_ => {}
}
continue;
}
if is_allowed_trace_item(&obj) {
normalized.push(Value::Object(obj));
}
}
if normalized.is_empty() {
return Err(CodexErr::InvalidRequest(format!(
"no valid trace items after normalization: {}",
path.display()
)));
}
Ok(normalized)
}
fn is_allowed_trace_item(item: &Map<String, Value>) -> bool {
let Some(item_type) = item.get("type").and_then(Value::as_str) else {
return false;
};
if item_type == "message" {
return matches!(
item.get("role").and_then(Value::as_str),
Some("assistant" | "system" | "developer" | "user")
);
}
true
}
fn build_trace_id(index: usize, path: &Path) -> String {
let stem = path
.file_stem()
.map(|stem| stem.to_string_lossy().into_owned())
.filter(|stem| !stem.is_empty())
.unwrap_or_else(|| "trace".to_string());
format!("trace_{index}_{stem}")
}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use tempfile::tempdir;
#[test]
fn normalize_trace_items_handles_payload_wrapper_and_message_role_filtering() {
let items = vec![
serde_json::json!({
"type": "response_item",
"payload": {"type": "message", "role": "assistant", "content": []}
}),
serde_json::json!({
"type": "response_item",
"payload": [
{"type": "message", "role": "user", "content": []},
{"type": "message", "role": "tool", "content": []},
{"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"}
]
}),
serde_json::json!({
"type": "not_response_item",
"payload": {"type": "message", "role": "assistant", "content": []}
}),
serde_json::json!({
"type": "message",
"role": "developer",
"content": []
}),
];
let normalized = normalize_trace_items(items, Path::new("trace.json")).expect("normalize");
let expected = vec![
serde_json::json!({"type": "message", "role": "assistant", "content": []}),
serde_json::json!({"type": "message", "role": "user", "content": []}),
serde_json::json!({"type": "function_call", "name": "shell", "arguments": "{}", "call_id": "c1"}),
serde_json::json!({"type": "message", "role": "developer", "content": []}),
];
assert_eq!(normalized, expected);
}
#[test]
fn load_trace_items_supports_jsonl_arrays_and_objects() {
let text = r#"
{"type":"response_item","payload":{"type":"message","role":"assistant","content":[]}}
[{"type":"message","role":"user","content":[]},{"type":"message","role":"tool","content":[]}]
"#;
let loaded = load_trace_items(Path::new("trace.jsonl"), text).expect("load");
let expected = vec![
serde_json::json!({"type":"message","role":"assistant","content":[]}),
serde_json::json!({"type":"message","role":"user","content":[]}),
];
assert_eq!(loaded, expected);
}
#[tokio::test]
async fn load_trace_text_decodes_utf8_sig() {
let dir = tempdir().expect("tempdir");
let path = dir.path().join("trace.json");
tokio::fs::write(
&path,
[
0xEF, 0xBB, 0xBF, b'[', b'{', b'"', b't', b'y', b'p', b'e', b'"', b':', b'"', b'm',
b'e', b's', b's', b'a', b'g', b'e', b'"', b',', b'"', b'r', b'o', b'l', b'e', b'"',
b':', b'"', b'u', b's', b'e', b'r', b'"', b',', b'"', b'c', b'o', b'n', b't', b'e',
b'n', b't', b'"', b':', b'[', b']', b'}', b']',
],
)
.await
.expect("write");
let text = load_trace_text(&path).await.expect("decode");
assert!(text.starts_with('['));
}
}