Files
codex/codex-rs/core/src/memories/phase2.rs
2026-02-15 06:49:15 -08:00

475 lines
16 KiB
Rust

use crate::agent::AgentStatus;
use crate::agent::status::is_final as is_final_agent_status;
use crate::codex::Session;
use crate::config::Config;
use crate::memories::memory_root;
use crate::memories::metrics;
use crate::memories::phase_two;
use crate::memories::prompts::build_consolidation_prompt;
use crate::memories::start::emit_memory_progress;
use crate::memories::storage::rebuild_raw_memories_file_from_memories;
use crate::memories::storage::sync_rollout_summaries_from_memories;
use codex_config::Constrained;
use codex_protocol::ThreadId;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use codex_protocol::user_input::UserInput;
use codex_state::StateRuntime;
use codex_utils_absolute_path::AbsolutePathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::watch;
use tracing::warn;
/// Ownership handle for a successfully claimed global phase-2 job.
#[derive(Debug, Clone, Default)]
struct Claim {
    // Ownership token proving this session holds the job lease; passed to the
    // DB heartbeat/succeed/fail calls.
    token: String,
    // Input watermark recorded by the DB at claim time; compared against
    // stage-1 `source_updated_at` timestamps in `get_watermark`.
    watermark: i64,
}
/// Counters gathered during a phase-2 run and reported by [`emit_metrics`].
#[derive(Debug, Clone, Default)]
struct Counters {
    // Number of stage-1 outputs fed into this consolidation run.
    input: i64,
}
/// Runs memory phase 2 (aka consolidation) in strict order. The method represents the linear
/// flow of the consolidation phase: claim the job, build the agent config, load
/// stage-1 outputs, sync on-disk artifacts, then spawn and supervise the agent.
///
/// Every early-exit path records the job outcome in the state DB (via
/// [`job::failed`] / [`job::succeed`]) and emits a user-visible progress line.
pub(super) async fn run(
    session: &Arc<Session>,
    config: Arc<Config>,
    progress_sub_id: &Option<String>,
) {
    // Without a state DB there is no job table to claim from; skip entirely.
    let Some(db) = session.services.state_db.as_deref() else {
        emit_memory_progress(
            session.as_ref(),
            progress_sub_id,
            "phase 2 skipped (state db unavailable)",
        )
        .await;
        return;
    };
    let root = memory_root(&config.codex_home);
    let max_raw_memories = config.memories.max_raw_memories_for_global;
    // 1. Claim the job. A claim error is not always a failure: the job may be
    //    up to date or already running elsewhere.
    let claim = match job::claim(session, db).await {
        Ok(claim) => claim,
        Err(e) => {
            // The error string doubles as the metric status label.
            session.services.otel_manager.counter(
                metrics::MEMORY_PHASE_TWO_JOBS,
                1,
                &[("status", e)],
            );
            let progress = match e {
                "skipped_not_dirty" => "phase 2 up to date",
                "skipped_running" => "phase 2 already running",
                _ => "phase 2 failed to claim global job",
            };
            emit_memory_progress(session.as_ref(), progress_sub_id, progress).await;
            return;
        }
    };
    // 2. Get the config for the agent.
    let Some(agent_config) = agent::get_config(config.clone()) else {
        // If we can't get the config, we can't consolidate.
        tracing::error!("failed to get agent config");
        job::failed(session, db, &claim, "failed_sandbox_policy").await;
        emit_memory_progress(
            session.as_ref(),
            progress_sub_id,
            "phase 2 failed (sandbox policy rejected)",
        )
        .await;
        return;
    };
    // 3. Query the memories (stage-1 outputs) from the DB.
    let raw_memories = match db.list_stage1_outputs_for_global(max_raw_memories).await {
        Ok(memories) => memories,
        Err(err) => {
            tracing::error!("failed to list stage1 outputs from global: {}", err);
            job::failed(session, db, &claim, "failed_load_stage1_outputs").await;
            emit_memory_progress(
                session.as_ref(),
                progress_sub_id,
                "phase 2 failed (could not load stage-1 outputs)",
            )
            .await;
            return;
        }
    };
    // Watermark to record on success: newest stage-1 timestamp, never lower
    // than the watermark captured at claim time.
    let new_watermark = get_watermark(claim.watermark, &raw_memories);
    // 4. Update the file system by syncing the raw memories with the one extracted from DB at
    //    step 3.
    // [`rollout_summaries/`]
    if let Err(err) =
        sync_rollout_summaries_from_memories(&root, &raw_memories, max_raw_memories).await
    {
        tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
        job::failed(session, db, &claim, "failed_sync_artifacts").await;
        emit_memory_progress(
            session.as_ref(),
            progress_sub_id,
            "phase 2 failed (could not sync local artifacts)",
        )
        .await;
        return;
    }
    // [`raw_memories.md`]
    if let Err(err) =
        rebuild_raw_memories_file_from_memories(&root, &raw_memories, max_raw_memories).await
    {
        tracing::error!("failed syncing local memory artifacts for global consolidation: {err}");
        job::failed(session, db, &claim, "failed_rebuild_raw_memories").await;
        emit_memory_progress(
            session.as_ref(),
            progress_sub_id,
            "phase 2 failed (could not rebuild raw memories)",
        )
        .await;
        return;
    }
    if raw_memories.is_empty() {
        // We check only after sync of the file system, so the on-disk
        // artifacts are refreshed even when there is nothing to consolidate.
        job::succeed(session, db, &claim, new_watermark, "succeeded_no_input").await;
        emit_memory_progress(
            session.as_ref(),
            progress_sub_id,
            "phase 2 complete (no stage-1 outputs)",
        )
        .await;
        return;
    }
    // 5. Spawn the agent.
    let prompt = agent::get_prompt(config);
    let source = SessionSource::SubAgent(SubAgentSource::MemoryConsolidation);
    let thread_id = match session
        .services
        .agent_control
        .spawn_agent(agent_config, prompt, Some(source))
        .await
    {
        Ok(thread_id) => thread_id,
        Err(err) => {
            tracing::error!("failed to spawn global memory consolidation agent: {err}");
            job::failed(session, db, &claim, "failed_spawn_agent").await;
            emit_memory_progress(
                session.as_ref(),
                progress_sub_id,
                "phase 2 failed (could not spawn consolidation agent)",
            )
            .await;
            return;
        }
    };
    emit_memory_progress(session.as_ref(), progress_sub_id, "phase 2 running").await;
    // 6. Spawn the agent handler: a detached task that supervises the agent
    //    and finalizes the job (see `agent::handle`).
    agent::handle(
        session,
        claim,
        new_watermark,
        thread_id,
        progress_sub_id.clone(),
    );
    // 7. Metrics and logs.
    let counters = Counters {
        input: raw_memories.len() as i64,
    };
    emit_metrics(session, counters);
}
mod job {
use super::*;
pub(super) async fn claim(
session: &Arc<Session>,
db: &StateRuntime,
) -> Result<Claim, &'static str> {
let otel_manager = &session.services.otel_manager;
let claim = db
.try_claim_global_phase2_job(session.conversation_id, phase_two::JOB_LEASE_SECONDS)
.await
.map_err(|e| {
tracing::error!("failed to claim job: {}", e);
"failed_claim"
})?;
let (token, watermark) = match claim {
codex_state::Phase2JobClaimOutcome::Claimed {
ownership_token,
input_watermark,
} => {
otel_manager.counter(metrics::MEMORY_PHASE_TWO_JOBS, 1, &[("status", "claimed")]);
(ownership_token, input_watermark)
}
codex_state::Phase2JobClaimOutcome::SkippedNotDirty => return Err("skipped_not_dirty"),
codex_state::Phase2JobClaimOutcome::SkippedRunning => return Err("skipped_running"),
};
Ok(Claim { token, watermark })
}
pub(super) async fn failed(
session: &Arc<Session>,
db: &StateRuntime,
claim: &Claim,
reason: &'static str,
) {
session.services.otel_manager.counter(
metrics::MEMORY_PHASE_TWO_JOBS,
1,
&[("status", reason)],
);
if matches!(
db.mark_global_phase2_job_failed(
&claim.token,
reason,
phase_two::JOB_RETRY_DELAY_SECONDS,
)
.await,
Ok(false)
) {
let _ = db
.mark_global_phase2_job_failed_if_unowned(
&claim.token,
reason,
phase_two::JOB_RETRY_DELAY_SECONDS,
)
.await;
}
}
pub(super) async fn succeed(
session: &Arc<Session>,
db: &StateRuntime,
claim: &Claim,
completion_watermark: i64,
reason: &'static str,
) {
session.services.otel_manager.counter(
metrics::MEMORY_PHASE_TWO_JOBS,
1,
&[("status", reason)],
);
let _ = db
.mark_global_phase2_job_succeeded(&claim.token, completion_watermark)
.await;
}
}
/// Helpers to configure, spawn, and supervise the consolidation agent.
mod agent {
    use super::*;

    /// Builds the config for the consolidation agent: cwd set to the memory
    /// root, approvals disabled, and a workspace-write sandbox limited to
    /// `codex_home` with network access off.
    ///
    /// Returns `None` when the constrained sandbox policy rejects the value.
    pub(super) fn get_config(config: Arc<Config>) -> Option<Config> {
        let root = memory_root(&config.codex_home);
        let mut agent_config = config.as_ref().clone();
        agent_config.cwd = root;
        // Approval policy: the agent runs unattended, so never ask.
        agent_config.permissions.approval_policy = Constrained::allow_only(AskForApproval::Never);
        // Sandbox policy
        let mut writable_roots = Vec::new();
        match AbsolutePathBuf::from_absolute_path(agent_config.codex_home.clone()) {
            Ok(codex_home) => writable_roots.push(codex_home),
            // Best-effort: on failure we only lose the writable root and warn.
            Err(err) => warn!(
                "memory phase-2 consolidation could not add codex_home writable root {}: {err}",
                agent_config.codex_home.display()
            ),
        }
        // The consolidation agent only needs local codex_home write access and no network.
        let consolidation_sandbox_policy = SandboxPolicy::WorkspaceWrite {
            writable_roots,
            read_only_access: Default::default(),
            network_access: false,
            exclude_tmpdir_env_var: false,
            exclude_slash_tmp: false,
        };
        // `set` fails if the policy falls outside the constrained range; in
        // that case we cannot build a valid config, so bail with `None`.
        agent_config
            .permissions
            .sandbox_policy
            .set(consolidation_sandbox_policy)
            .ok()?;
        // Use the configured phase-2 model, falling back to the default.
        agent_config.model = Some(
            config
                .memories
                .phase_2_model
                .clone()
                .unwrap_or(phase_two::MODEL.to_string()),
        );
        Some(agent_config)
    }

    /// Builds the single text prompt pointing the agent at the memory root.
    pub(super) fn get_prompt(config: Arc<Config>) -> Vec<UserInput> {
        let root = memory_root(&config.codex_home);
        let prompt = build_consolidation_prompt(&root);
        vec![UserInput::Text {
            text: prompt,
            text_elements: vec![],
        }]
    }

    /// Handle the agent while it is running: subscribes to its status stream,
    /// keeps the job lease alive via heartbeats, then finalizes the job and
    /// progress reporting once a final status is reached. Runs detached.
    pub(super) fn handle(
        session: &Arc<Session>,
        claim: Claim,
        new_watermark: i64,
        thread_id: ThreadId,
        progress_sub_id: Option<String>,
    ) {
        // Re-check the DB here because the spawned task needs an owned handle
        // (the caller only held a borrowed one).
        let Some(db) = session.services.state_db.clone() else {
            let session = Arc::clone(session);
            tokio::spawn(async move {
                emit_memory_progress(
                    session.as_ref(),
                    &progress_sub_id,
                    "phase 2 failed (state db unavailable)",
                )
                .await;
            });
            return;
        };
        let session = session.clone();
        tokio::spawn(async move {
            let agent_control = session.services.agent_control.clone();
            // TODO(jif) we might have a very small race here.
            let rx = match agent_control.subscribe_status(thread_id).await {
                Ok(rx) => rx,
                Err(err) => {
                    tracing::error!("agent_control.subscribe_status failed: {err:?}");
                    job::failed(&session, &db, &claim, "failed_subscribe_status").await;
                    emit_memory_progress(
                        session.as_ref(),
                        &progress_sub_id,
                        "phase 2 failed (status subscription unavailable)",
                    )
                    .await;
                    return;
                }
            };
            // Loop the agent until we have the final status.
            let final_status = loop_agent(
                db.clone(),
                claim.token.clone(),
                new_watermark,
                thread_id,
                rx,
            )
            .await;
            // Only a completed agent counts as success; any other terminal
            // state fails the job.
            if matches!(final_status, AgentStatus::Completed(_)) {
                job::succeed(&session, &db, &claim, new_watermark, "succeeded").await;
            } else {
                job::failed(&session, &db, &claim, "failed_agent").await;
            }
            let progress = match &final_status {
                AgentStatus::Completed(_) => "phase 2 complete",
                AgentStatus::Errored(_) | AgentStatus::NotFound => "phase 2 failed",
                AgentStatus::Shutdown => "phase 2 cancelled",
                // Non-final statuses can reach here when the watch channel
                // closes before a final status is observed (see loop_agent).
                AgentStatus::PendingInit | AgentStatus::Running => "phase 2 failed",
            };
            emit_memory_progress(session.as_ref(), &progress_sub_id, progress).await;
            // Fire and forget close of the agent.
            if !matches!(final_status, AgentStatus::Shutdown | AgentStatus::NotFound) {
                tokio::spawn(async move {
                    if let Err(err) = agent_control.shutdown_agent(thread_id).await {
                        warn!(
                            "failed to auto-close global memory consolidation agent {thread_id}: {err}"
                        );
                    }
                });
            } else {
                tracing::warn!("The agent was already gone");
            }
        });
    }

    /// Waits for the agent to reach a final status while heartbeating the
    /// phase-2 job lease. May return a synthesized `Errored` status when the
    /// lease heartbeat fails or ownership is lost.
    async fn loop_agent(
        db: Arc<StateRuntime>,
        token: String,
        // Unused in the loop; the caller records the watermark on success.
        _new_watermark: i64,
        thread_id: ThreadId,
        mut rx: watch::Receiver<AgentStatus>,
    ) -> AgentStatus {
        let mut heartbeat_interval =
            tokio::time::interval(Duration::from_secs(phase_two::JOB_HEARTBEAT_SECONDS));
        // Skip (rather than burst) heartbeat ticks missed while waiting on
        // status changes.
        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        loop {
            let status = rx.borrow().clone();
            if is_final_agent_status(&status) {
                break status;
            }
            tokio::select! {
                update = rx.changed() => {
                    if update.is_err() {
                        // Sender dropped: no further updates will arrive, so
                        // return the last observed (non-final) status.
                        tracing::warn!(
                            "lost status updates for global memory consolidation agent {thread_id}"
                        );
                        break status;
                    }
                }
                _ = heartbeat_interval.tick() => {
                    match db
                        .heartbeat_global_phase2_job(
                            &token,
                            phase_two::JOB_LEASE_SECONDS,
                        )
                        .await
                    {
                        Ok(true) => {}
                        // `Ok(false)`: the DB no longer recognizes our token
                        // as the owner; stop supervising.
                        Ok(false) => {
                            break AgentStatus::Errored(
                                "lost global phase-2 ownership during heartbeat".to_string(),
                            );
                        }
                        Err(err) => {
                            break AgentStatus::Errored(format!(
                                "phase-2 heartbeat update failed: {err}"
                            ));
                        }
                    }
                }
            }
        }
    }
}
/// Computes the completion watermark for this run: the newest
/// `source_updated_at` timestamp (unix seconds) among the stage-1 outputs,
/// clamped so it never goes below the watermark recorded at claim time.
///
/// Folding from `claimed_watermark` covers both the empty-input case and the
/// monotonicity requirement in one expression, replacing the previous
/// redundant `unwrap_or(...).max(...)` chain (and its lingering TODO).
pub(super) fn get_watermark(
    claimed_watermark: i64,
    latest_memories: &[codex_state::Stage1Output],
) -> i64 {
    latest_memories
        .iter()
        .map(|memory| memory.source_updated_at.timestamp())
        .fold(claimed_watermark, i64::max)
}
/// Emits OTEL counters for a phase-2 run whose agent was spawned:
/// the input size (when non-zero) and an `agent_spawned` job status.
fn emit_metrics(session: &Arc<Session>, counters: Counters) {
    // Borrow the manager instead of cloning it; `counter` is invoked directly
    // on the field elsewhere in this module, so a reference suffices.
    let otel = &session.services.otel_manager;
    if counters.input > 0 {
        otel.counter(metrics::MEMORY_PHASE_TWO_INPUT, counters.input, &[]);
    }
    otel.counter(
        metrics::MEMORY_PHASE_TWO_JOBS,
        1,
        &[("status", "agent_spawned")],
    );
}