diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 5318b72151..a08243701c 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -1918,6 +1918,9 @@ async fn turn_start_lifecycle_exposes_turn_metadata_and_token_baseline() { #[derive(Debug, PartialEq, Eq)] struct RecordedTurnStart { + session_level_id: String, + thread_level_id: String, + turn_level_id: String, turn_id: String, collaboration_mode: CollaborationMode, token_usage_at_turn_start: TokenUsage, @@ -1936,6 +1939,9 @@ async fn turn_start_lifecycle_exposes_turn_metadata_and_token_baseline() { .lock() .expect("turn start records lock") .push(RecordedTurnStart { + session_level_id: input.session_store.level_id().to_string(), + thread_level_id: input.thread_store.level_id().to_string(), + turn_level_id: input.turn_store.level_id().to_string(), turn_id: input.turn_id.to_string(), collaboration_mode: input.collaboration_mode.clone(), token_usage_at_turn_start: input.token_usage_at_turn_start.clone(), @@ -1965,52 +1971,43 @@ async fn turn_start_lifecycle_exposes_turn_metadata_and_token_baseline() { .insert(ThreadTurnStartMarker); let token_usage_at_turn_start = TokenUsage { - input_tokens: 120, - cached_input_tokens: 15, - output_tokens: 40, - reasoning_output_tokens: 9, - total_tokens: 169, + input_tokens: 100, + cached_input_tokens: 40, + output_tokens: 25, + reasoning_output_tokens: 5, + total_tokens: 130, }; - session - .state - .lock() - .await - .set_token_info(Some(TokenUsageInfo { - total_token_usage: token_usage_at_turn_start.clone(), - last_token_usage: TokenUsage::default(), - model_context_window: turn_context.model_context_window(), - })); + set_total_token_usage(&session, token_usage_at_turn_start.clone()).await; - let turn_context = Arc::new(turn_context); - let session = Arc::new(session); - session - .spawn_task( - Arc::clone(&turn_context), - Vec::new(), - NeverEndingTask { - kind: TaskKind::Regular, - listen_to_cancellation_token: true, - }, - ) - .await; + let expected = RecordedTurnStart { + session_level_id: session.session_id().to_string(), + thread_level_id: session.conversation_id.to_string(), + turn_level_id: turn_context.sub_id.clone(), + turn_id: turn_context.sub_id.clone(), + collaboration_mode: turn_context.collaboration_mode.clone(), + token_usage_at_turn_start, + saw_session_store: true, + saw_thread_store: true, + }; - session.abort_all_tasks(TurnAbortReason::Interrupted).await; + let sess = Arc::new(session); + sess.spawn_task( + Arc::new(turn_context), + Vec::new(), + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + sess.abort_all_tasks(TurnAbortReason::Interrupted).await; let actual = records .lock() .expect("turn start records lock") .drain(..) .collect::>(); - assert_eq!( - vec![RecordedTurnStart { - turn_id: turn_context.sub_id.clone(), - collaboration_mode: turn_context.collaboration_mode.clone(), - token_usage_at_turn_start, - saw_session_store: true, - saw_thread_store: true, - }], - actual - ); + assert_eq!(vec![expected], actual); } #[tokio::test] diff --git a/codex-rs/ext/goal/src/accounting.rs b/codex-rs/ext/goal/src/accounting.rs index cc04684169..b4fa14669d 100644 --- a/codex-rs/ext/goal/src/accounting.rs +++ b/codex-rs/ext/goal/src/accounting.rs @@ -1,26 +1,50 @@ use codex_protocol::config_types::ModeKind; use codex_protocol::protocol::TokenUsage; +use codex_state::ThreadGoalStatus; use std::collections::HashMap; use std::sync::Mutex; use std::sync::PoisonError; +use std::time::Duration; +use std::time::Instant; #[derive(Debug, Default)] pub(crate) struct GoalAccountingState { inner: Mutex, } -#[derive(Debug, Default)] +#[derive(Debug)] struct GoalAccountingInner { + current_turn_id: Option, turns: HashMap, - unflushed_token_delta: i64, + wall_clock: GoalWallClockAccounting, } -#[derive(Debug, Default)] +#[derive(Debug)] struct GoalTurnAccounting { - token_delta: i64, + current_token_usage: TokenUsage, last_accounted_token_usage: TokenUsage, + active_goal_id: Option, account_tokens: bool, - stopped: bool, +} + +#[derive(Debug)] +struct GoalWallClockAccounting { + last_accounted_at: Instant, + active_goal_id: Option, +} + +#[derive(Debug, Clone)] +pub(crate) struct GoalProgressSnapshot { + pub(crate) current_token_usage: TokenUsage, + pub(crate) expected_goal_id: String, + pub(crate) time_delta_seconds: i64, + pub(crate) token_delta: i64, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum BudgetLimitedGoalDisposition { + KeepActive, + ClearActive, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -37,17 +61,21 @@ impl GoalAccountingState { token_usage_at_turn_start: &TokenUsage, ) { let turn_id = turn_id.into(); - self.inner().turns.insert( + let mut inner = self.inner(); + inner.current_turn_id = Some(turn_id.clone()); + inner.turns.insert( turn_id, - GoalTurnAccounting { - token_delta: 0, - last_accounted_token_usage: token_usage_at_turn_start.clone(), - account_tokens: !matches!(collaboration_mode, ModeKind::Plan), - stopped: false, - }, + GoalTurnAccounting::new( + token_usage_at_turn_start.clone(), + !matches!(collaboration_mode, ModeKind::Plan), + ), ); } + pub(crate) fn current_turn_id(&self) -> Option { + self.inner().current_turn_id.clone() + } + pub(crate) fn record_token_usage( &self, turn_id: impl Into, @@ -56,28 +84,107 @@ impl GoalAccountingState { let turn_id = turn_id.into(); let mut inner = self.inner(); let turn = inner.turns.get_mut(&turn_id)?; - if turn.stopped || !turn.account_tokens { + turn.current_token_usage = total_usage.clone(); + if !turn.account_tokens { return None; } - let delta = - token_delta_since_last_accounting(&turn.last_accounted_token_usage, total_usage); - turn.last_accounted_token_usage = total_usage.clone(); + let delta = turn.token_delta_since_last_accounting(); if delta <= 0 { return None; } - turn.token_delta = turn.token_delta.saturating_add(delta); - let turn_delta = turn.token_delta; - inner.unflushed_token_delta = inner.unflushed_token_delta.saturating_add(delta); Some(RecordedTokenDelta { - turn_delta, - thread_unflushed_delta: inner.unflushed_token_delta, + turn_delta: delta, + thread_unflushed_delta: inner.thread_unflushed_token_delta(), }) } - pub(crate) fn stop_turn(&self, turn_id: &str) { - if let Some(turn) = self.inner().turns.get_mut(turn_id) { - turn.stopped = true; + pub(crate) fn mark_turn_goal_active(&self, turn_id: &str, goal_id: impl Into) { + let mut inner = self.inner(); + let goal_id = goal_id.into(); + if let Some(turn) = inner.turns.get_mut(turn_id) { + turn.active_goal_id = Some(goal_id.clone()); + if inner.current_turn_id.as_deref() == Some(turn_id) { + inner.wall_clock.mark_active_goal(goal_id); + } + } + } + + pub(crate) fn mark_current_turn_goal_active( + &self, + goal_id: impl Into, + ) -> Option { + let mut inner = self.inner(); + let turn_id = inner.current_turn_id.clone()?; + let goal_id = goal_id.into(); + let turn = inner.turns.get_mut(turn_id.as_str())?; + turn.active_goal_id = Some(goal_id.clone()); + turn.reset_baseline_to_current(); + inner.wall_clock.mark_active_goal(goal_id); + Some(turn_id) + } + + pub(crate) fn clear_current_turn_goal(&self) -> Option { + let mut inner = self.inner(); + let turn_id = inner.current_turn_id.clone()?; + if let Some(turn) = inner.turns.get_mut(turn_id.as_str()) { + turn.active_goal_id = None; + } + inner.wall_clock.clear_active_goal(); + Some(turn_id) + } + + pub(crate) fn progress_snapshot(&self, turn_id: &str) -> Option { + let inner = self.inner(); + let turn = inner.turns.get(turn_id)?; + if !turn.account_tokens { + return None; + } + let expected_goal_id = turn.active_goal_id()?; + let token_delta = turn.token_delta_since_last_accounting(); + let time_delta_seconds = + if inner.wall_clock.active_goal_id.as_deref() == Some(expected_goal_id.as_str()) { + inner.wall_clock.time_delta_since_last_accounting() + } else { + 0 + }; + if time_delta_seconds == 0 && token_delta <= 0 { + return None; + } + Some(GoalProgressSnapshot { + current_token_usage: turn.current_token_usage.clone(), + expected_goal_id, + time_delta_seconds, + token_delta, + }) + } + + pub(crate) fn mark_progress_accounted_for_status( + &self, + turn_id: &str, + snapshot: &GoalProgressSnapshot, + status: ThreadGoalStatus, + budget_limited_goal_disposition: BudgetLimitedGoalDisposition, + ) { + let clear_active_goal = should_clear_active_goal(status, budget_limited_goal_disposition); + let mut inner = self.inner(); + if let Some(turn) = inner.turns.get_mut(turn_id) { + turn.last_accounted_token_usage = snapshot.current_token_usage.clone(); + if clear_active_goal { + turn.active_goal_id = None; + } + } + inner.wall_clock.mark_accounted(snapshot.time_delta_seconds); + if clear_active_goal { + inner.wall_clock.clear_active_goal(); + } + } + + pub(crate) fn finish_turn(&self, turn_id: &str) { + let mut inner = self.inner(); + inner.turns.remove(turn_id); + if inner.current_turn_id.as_deref() == Some(turn_id) { + inner.current_turn_id = None; } } @@ -107,3 +214,108 @@ pub(crate) fn goal_token_delta_for_usage(usage: &TokenUsage) -> i64 { .saturating_sub(usage.cached_input_tokens) .saturating_add(usage.output_tokens.max(0)) } + +impl Default for GoalAccountingInner { + fn default() -> Self { + Self { + current_turn_id: None, + turns: HashMap::new(), + wall_clock: GoalWallClockAccounting::new(), + } + } +} + +impl GoalAccountingInner { + fn thread_unflushed_token_delta(&self) -> i64 { + self.turns + .values() + .filter(|turn| turn.account_tokens) + .fold(0_i64, |total, turn| { + total.saturating_add(turn.token_delta_since_last_accounting().max(0)) + }) + } +} + +impl GoalTurnAccounting { + fn new(current_token_usage: TokenUsage, account_tokens: bool) -> Self { + Self { + last_accounted_token_usage: current_token_usage.clone(), + current_token_usage, + active_goal_id: None, + account_tokens, + } + } + + fn active_goal_id(&self) -> Option { + self.active_goal_id.clone() + } + + fn reset_baseline_to_current(&mut self) { + self.last_accounted_token_usage = self.current_token_usage.clone(); + } + + fn token_delta_since_last_accounting(&self) -> i64 { + token_delta_since_last_accounting( + &self.last_accounted_token_usage, + &self.current_token_usage, + ) + } +} + +impl GoalWallClockAccounting { + fn new() -> Self { + Self { + last_accounted_at: Instant::now(), + active_goal_id: None, + } + } + + fn time_delta_since_last_accounting(&self) -> i64 { + i64::try_from(self.last_accounted_at.elapsed().as_secs()).unwrap_or(i64::MAX) + } + + fn mark_accounted(&mut self, accounted_seconds: i64) { + if accounted_seconds <= 0 { + return; + } + let advance = Duration::from_secs(u64::try_from(accounted_seconds).unwrap_or(u64::MAX)); + self.last_accounted_at = self + .last_accounted_at + .checked_add(advance) + .unwrap_or_else(Instant::now); + } + + fn reset_baseline(&mut self) { + self.last_accounted_at = Instant::now(); + } + + fn mark_active_goal(&mut self, goal_id: impl Into) { + let goal_id = goal_id.into(); + if self.active_goal_id.as_deref() != Some(goal_id.as_str()) { + self.reset_baseline(); + self.active_goal_id = Some(goal_id); + } + } + + fn clear_active_goal(&mut self) { + self.active_goal_id = None; + self.reset_baseline(); + } +} + +fn should_clear_active_goal( + status: ThreadGoalStatus, + budget_limited_goal_disposition: BudgetLimitedGoalDisposition, +) -> bool { + match status { + ThreadGoalStatus::Active => false, + ThreadGoalStatus::BudgetLimited => matches!( + budget_limited_goal_disposition, + BudgetLimitedGoalDisposition::ClearActive + ), + ThreadGoalStatus::Paused + | ThreadGoalStatus::Blocked + | ThreadGoalStatus::UsageLimited + | ThreadGoalStatus::Complete => true, + } +} diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index 6d8eb42f0e..61114bf57e 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -18,13 +18,15 @@ use codex_extension_api::TurnLifecycleContributor; use codex_extension_api::TurnStartInput; use codex_extension_api::TurnStopInput; use codex_protocol::ThreadId; +use codex_protocol::protocol::ThreadGoal; use codex_protocol::protocol::TokenUsageInfo; -use codex_protocol::protocol::TurnAbortReason; +use crate::accounting::BudgetLimitedGoalDisposition; use crate::accounting::GoalAccountingState; use crate::events::GoalEventEmitter; use crate::spec::UPDATE_GOAL_TOOL_NAME; use crate::tool::GoalToolExecutor; +use crate::tool::protocol_goal_from_state; #[derive(Clone, Debug)] pub struct GoalExtensionConfig { @@ -108,11 +110,39 @@ where return; } - accounting_state(input.thread_store).start_turn( + let accounting = accounting_state(input.thread_store); + accounting.start_turn( input.turn_id, input.collaboration_mode.mode, input.token_usage_at_turn_start, ); + if matches!( + input.collaboration_mode.mode, + codex_protocol::config_types::ModeKind::Plan + ) { + accounting.clear_current_turn_goal(); + return; + } + let Ok(thread_id) = ThreadId::from_string(input.thread_store.level_id()) else { + return; + }; + let Ok(goal) = self + .state_dbs + .thread_goals() + .get_thread_goal(thread_id) + .await + else { + return; + }; + if let Some(goal) = goal + && matches!( + goal.status, + codex_state::ThreadGoalStatus::Active + | codex_state::ThreadGoalStatus::BudgetLimited + ) + { + accounting.mark_turn_goal_active(input.turn_id, goal.goal_id); + } } async fn on_turn_stop(&self, input: TurnStopInput<'_>) { @@ -120,13 +150,16 @@ where return; } - // TODO: this should flush wall-clock and any unflushed token usage to - // persisted goal storage, emit ThreadGoalUpdated, and optionally inject - // budget-limit steering through a host event/input capability. - // TODO: the host also needs an idle/next-turn wake capability so an - // active goal can enqueue continuation context after the turn is fully - // cleared, only when there is no pending user or mailbox work. - accounting_state(input.thread_store).stop_turn(input.turn_store.level_id()); + let turn_id = input.turn_store.level_id(); + self.account_active_goal_progress( + input.thread_store, + turn_id, + &format!("{turn_id}:turn-stop"), + codex_state::GoalAccountingMode::ActiveOnly, + BudgetLimitedGoalDisposition::ClearActive, + ) + .await; + accounting_state(input.thread_store).finish_turn(turn_id); } async fn on_turn_abort(&self, input: TurnAbortInput<'_>) { @@ -134,11 +167,16 @@ where return; } - accounting_state(input.thread_store).stop_turn(input.turn_store.level_id()); - if input.reason == TurnAbortReason::Interrupted { - // TODO: interrupted turns should pause the active goal via persisted - // goal storage and emit ThreadGoalUpdated with turn_id None. - } + let turn_id = input.turn_store.level_id(); + self.account_active_goal_progress( + input.thread_store, + turn_id, + &format!("{turn_id}:turn-abort"), + codex_state::GoalAccountingMode::ActiveOnly, + BudgetLimitedGoalDisposition::ClearActive, + ) + .await; + accounting_state(input.thread_store).finish_turn(turn_id); } } @@ -163,11 +201,6 @@ where else { return; }; - - // TODO: TokenUsageContributor needs a host goal storage capability so - // this recorded delta can be committed to the active persisted goal. - // It also needs an event/input capability to emit ThreadGoalUpdated and - // inject budget-limit steering when accounting changes goal status. } } @@ -177,15 +210,21 @@ where { fn on_tool_finish<'a>(&'a self, input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { Box::pin(async move { - let _should_count_for_goal_progress = goal_enabled(input.thread_store) + let should_count_for_goal_progress = goal_enabled(input.thread_store) && tool_attempt_counts_for_goal_progress(input.outcome) && !(input.tool_name.namespace.is_none() && input.tool_name.name == UPDATE_GOAL_TOOL_NAME); - - // TODO: commit active goal progress through host goal storage and emit - // ThreadGoalUpdated when the persisted goal changes. This replaces - // GoalRuntimeEvent::ToolCompleted once the goal extension owns runtime - // accounting. + if !should_count_for_goal_progress { + return; + } + self.account_active_goal_progress( + input.thread_store, + input.turn_id, + input.call_id, + codex_state::GoalAccountingMode::ActiveOnly, + BudgetLimitedGoalDisposition::KeepActive, + ) + .await; }) } } @@ -216,16 +255,19 @@ where Arc::new(GoalToolExecutor::get( thread_id, Arc::clone(&self.state_dbs), + accounting_state(thread_store), self.event_emitter.clone(), )), Arc::new(GoalToolExecutor::create( thread_id, Arc::clone(&self.state_dbs), + accounting_state(thread_store), self.event_emitter.clone(), )), Arc::new(GoalToolExecutor::update( thread_id, Arc::clone(&self.state_dbs), + accounting_state(thread_store), self.event_emitter.clone(), )), ] @@ -275,3 +317,50 @@ fn tool_attempt_counts_for_goal_progress(outcome: ToolCallOutcome) -> bool { | ToolCallOutcome::Aborted => false, } } + +impl GoalExtension { + async fn account_active_goal_progress( + &self, + thread_store: &ExtensionData, + turn_id: &str, + event_id: &str, + mode: codex_state::GoalAccountingMode, + budget_limited_goal_disposition: BudgetLimitedGoalDisposition, + ) -> Option { + let Ok(thread_id) = ThreadId::from_string(thread_store.level_id()) else { + return None; + }; + let accounting = accounting_state(thread_store); + let snapshot = accounting.progress_snapshot(turn_id)?; + let outcome = self + .state_dbs + .thread_goals() + .account_thread_goal_usage( + thread_id, + snapshot.time_delta_seconds, + snapshot.token_delta, + mode, + Some(snapshot.expected_goal_id.as_str()), + ) + .await + .ok()?; + match outcome { + codex_state::GoalAccountingOutcome::Updated(goal) => { + accounting.mark_progress_accounted_for_status( + turn_id, + &snapshot, + goal.status, + budget_limited_goal_disposition, + ); + let goal = protocol_goal_from_state(goal); + self.event_emitter.thread_goal_updated( + event_id.to_string(), + Some(turn_id.to_string()), + goal.clone(), + ); + Some(goal) + } + codex_state::GoalAccountingOutcome::Unchanged(_) => None, + } + } +} diff --git a/codex-rs/ext/goal/src/spec.rs b/codex-rs/ext/goal/src/spec.rs index e127846ab5..70e89e4fbb 100644 --- a/codex-rs/ext/goal/src/spec.rs +++ b/codex-rs/ext/goal/src/spec.rs @@ -60,9 +60,9 @@ pub fn create_update_goal_tool() -> ToolSpec { let properties = BTreeMap::from([( "status".to_string(), JsonSchema::string_enum( - vec![json!("complete")], + vec![json!("complete"), json!("blocked")], Some( - "Required. Set to complete only when the objective is achieved and no required work remains." + "Required. Set to complete only when the objective is achieved and no required work remains. Set to blocked only when the goal cannot currently proceed without a user decision, missing dependency, or external unblock." .to_string(), ), ), @@ -71,8 +71,9 @@ pub fn create_update_goal_tool() -> ToolSpec { ToolSpec::Function(ResponsesApiTool { name: UPDATE_GOAL_TOOL_NAME.to_string(), description: r#"Update the existing goal. -Use this tool only to mark the goal achieved. +Use this tool only to mark the goal achieved or blocked. Set status to `complete` only when the objective has actually been achieved and no required work remains. +Set status to `blocked` only when the goal cannot currently proceed until something external changes. Do not mark a goal complete merely because its budget is nearly exhausted or because you are stopping work. You cannot use this tool to pause, resume, or budget-limit a goal; those status changes are controlled by the user or system. When marking a budgeted goal achieved with status `complete`, report the final token usage from the tool result to the user."# diff --git a/codex-rs/ext/goal/src/tool.rs b/codex-rs/ext/goal/src/tool.rs index 96b160f938..fa3b7d5089 100644 --- a/codex-rs/ext/goal/src/tool.rs +++ b/codex-rs/ext/goal/src/tool.rs @@ -15,6 +15,8 @@ use codex_protocol::protocol::validate_thread_goal_objective; use serde::Deserialize; use serde::Serialize; +use crate::accounting::BudgetLimitedGoalDisposition; +use crate::accounting::GoalAccountingState; use crate::events::GoalEventEmitter; use crate::spec::CREATE_GOAL_TOOL_NAME; use crate::spec::GET_GOAL_TOOL_NAME; @@ -28,6 +30,7 @@ pub(crate) struct GoalToolExecutor { kind: GoalToolKind, thread_id: ThreadId, state_db: Arc, + accounting_state: Arc, event_emitter: GoalEventEmitter, } @@ -69,12 +72,14 @@ impl GoalToolExecutor { pub(crate) fn get( thread_id: ThreadId, state_db: Arc, + accounting_state: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Get, thread_id, state_db, + accounting_state, event_emitter, } } @@ -82,12 +87,14 @@ impl GoalToolExecutor { pub(crate) fn create( thread_id: ThreadId, state_db: Arc, + accounting_state: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Create, thread_id, state_db, + accounting_state, event_emitter, } } @@ -95,12 +102,14 @@ impl GoalToolExecutor { pub(crate) fn update( thread_id: ThreadId, state_db: Arc, + accounting_state: Arc, event_emitter: GoalEventEmitter, ) -> Self { Self { kind: GoalToolKind::Update, thread_id, state_db, + accounting_state, event_emitter, } } @@ -179,8 +188,11 @@ impl GoalToolExecutor { ) })?; fill_empty_thread_preview_if_possible(self.state_db.as_ref(), self.thread_id, &goal).await; + let turn_id = self + .accounting_state + .mark_current_turn_goal_active(goal.goal_id.clone()); let goal = protocol_goal_from_state(goal); - self.emit_goal_updated_from_tool_call(&invocation, goal.clone()); + self.emit_goal_updated_from_tool_call(&invocation, turn_id, goal.clone()); goal_response(Some(goal), CompletionBudgetReport::Omit) } @@ -189,15 +201,29 @@ impl GoalToolExecutor { invocation: ToolCall, ) -> Result, FunctionCallError> { let args: UpdateGoalArgs = parse_arguments(invocation.function_arguments()?)?; - if args.status != ThreadGoalStatus::Complete { + if !matches!( + args.status, + ThreadGoalStatus::Complete | ThreadGoalStatus::Blocked + ) { return Err(FunctionCallError::RespondToModel( - "update_goal can only mark the existing goal complete; pause, resume, and budget-limited status changes are controlled by the user or system" + "update_goal can only mark the existing goal complete or blocked; pause, resume, budget-limited, and usage-limited status changes are controlled by the user or system" .to_string(), )); } - // TODO: update_goal needs a host callback before completion to flush - // final active-turn accounting with budget steering suppressed. + self.account_active_goal_progress( + match args.status { + ThreadGoalStatus::Complete => codex_state::GoalAccountingMode::ActiveOrComplete, + ThreadGoalStatus::Blocked => codex_state::GoalAccountingMode::ActiveOrStopped, + ThreadGoalStatus::Active + | ThreadGoalStatus::Paused + | ThreadGoalStatus::UsageLimited + | ThreadGoalStatus::BudgetLimited => unreachable!("status validated above"), + }, + invocation.call_id.as_str(), + BudgetLimitedGoalDisposition::ClearActive, + ) + .await?; let goal = self .state_db .thread_goals() @@ -205,14 +231,14 @@ impl GoalToolExecutor { self.thread_id, codex_state::GoalUpdate { objective: None, - status: Some(codex_state::ThreadGoalStatus::Complete), + status: Some(state_status_from_protocol(args.status)), token_budget: None, expected_goal_id: None, }, ) .await .map_err(|err| { - FunctionCallError::RespondToModel(format!("failed to complete goal: {err}")) + FunctionCallError::RespondToModel(format!("failed to update goal: {err}")) })? .map(protocol_goal_from_state) .ok_or_else(|| { @@ -220,19 +246,72 @@ impl GoalToolExecutor { "cannot update goal because this thread has no goal".to_string(), ) })?; - self.emit_goal_updated_from_tool_call(&invocation, goal.clone()); - goal_response(Some(goal), CompletionBudgetReport::Include) + let turn_id = self.accounting_state.clear_current_turn_goal(); + self.emit_goal_updated_from_tool_call(&invocation, turn_id, goal.clone()); + goal_response( + Some(goal), + if args.status == ThreadGoalStatus::Complete { + CompletionBudgetReport::Include + } else { + CompletionBudgetReport::Omit + }, + ) } - fn emit_goal_updated_from_tool_call(&self, invocation: &ToolCall, goal: ThreadGoal) { - // TODO: ToolCall should expose the current turn submission id so goal - // tool events can set ThreadGoalUpdatedEvent.turn_id exactly as core - // does today. Until then, correlate the event with the tool call id. - self.event_emitter.thread_goal_updated( - invocation.call_id.clone(), - /*turn_id*/ None, - goal, - ); + fn emit_goal_updated_from_tool_call( + &self, + invocation: &ToolCall, + turn_id: Option, + goal: ThreadGoal, + ) { + self.event_emitter + .thread_goal_updated(invocation.call_id.clone(), turn_id, goal); + } + + async fn account_active_goal_progress( + &self, + mode: codex_state::GoalAccountingMode, + event_id: &str, + budget_limited_goal_disposition: BudgetLimitedGoalDisposition, + ) -> Result, FunctionCallError> { + let Some(turn_id) = self.accounting_state.current_turn_id() else { + return Ok(None); + }; + let Some(snapshot) = self.accounting_state.progress_snapshot(turn_id.as_str()) else { + return Ok(None); + }; + let outcome = self + .state_db + .thread_goals() + .account_thread_goal_usage( + self.thread_id, + snapshot.time_delta_seconds, + snapshot.token_delta, + mode, + Some(snapshot.expected_goal_id.as_str()), + ) + .await + .map_err(|err| { + FunctionCallError::RespondToModel(format!("failed to account goal progress: {err}")) + })?; + Ok(match outcome { + codex_state::GoalAccountingOutcome::Updated(goal) => { + self.accounting_state.mark_progress_accounted_for_status( + turn_id.as_str(), + &snapshot, + goal.status, + budget_limited_goal_disposition, + ); + let goal = protocol_goal_from_state(goal); + self.event_emitter.thread_goal_updated( + event_id.to_string(), + Some(turn_id), + goal.clone(), + ); + Some(goal) + } + codex_state::GoalAccountingOutcome::Unchanged(_) => None, + }) } } @@ -298,7 +377,7 @@ async fn fill_empty_thread_preview_if_possible( } } -fn protocol_goal_from_state(goal: codex_state::ThreadGoal) -> ThreadGoal { +pub(crate) fn protocol_goal_from_state(goal: codex_state::ThreadGoal) -> ThreadGoal { ThreadGoal { thread_id: goal.thread_id, objective: goal.objective, @@ -322,6 +401,17 @@ fn protocol_status_from_state(status: codex_state::ThreadGoalStatus) -> ThreadGo } } +fn state_status_from_protocol(status: ThreadGoalStatus) -> codex_state::ThreadGoalStatus { + match status { + ThreadGoalStatus::Active => codex_state::ThreadGoalStatus::Active, + ThreadGoalStatus::Paused => codex_state::ThreadGoalStatus::Paused, + ThreadGoalStatus::Blocked => codex_state::ThreadGoalStatus::Blocked, + ThreadGoalStatus::UsageLimited => codex_state::ThreadGoalStatus::UsageLimited, + ThreadGoalStatus::BudgetLimited => codex_state::ThreadGoalStatus::BudgetLimited, + ThreadGoalStatus::Complete => codex_state::ThreadGoalStatus::Complete, + } +} + fn completion_budget_report(goal: &ThreadGoal) -> Option { if goal.token_budget.is_none() && goal.time_used_seconds <= 0 { None diff --git a/codex-rs/ext/goal/tests/goal_extension_backend.rs b/codex-rs/ext/goal/tests/goal_extension_backend.rs index a79c69c633..d04dbe6021 100644 --- a/codex-rs/ext/goal/tests/goal_extension_backend.rs +++ b/codex-rs/ext/goal/tests/goal_extension_backend.rs @@ -1,16 +1,32 @@ use std::sync::Arc; +use std::sync::Mutex; +use std::sync::PoisonError; use codex_extension_api::ExtensionData; +use codex_extension_api::ExtensionEventSink; use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::FunctionCallError; use codex_extension_api::ThreadStartInput; use codex_extension_api::ToolCall; +use codex_extension_api::ToolCallOutcome; +use codex_extension_api::ToolCallSource; use codex_extension_api::ToolExecutor; +use codex_extension_api::ToolFinishInput; +use codex_extension_api::ToolName; use codex_extension_api::ToolPayload; +use codex_extension_api::TurnStartInput; +use codex_extension_api::TurnStopInput; use codex_goal_extension::install_with_backend; use codex_protocol::ThreadId; -use codex_protocol::ToolName; +use codex_protocol::config_types::CollaborationMode; +use codex_protocol::config_types::ModeKind; +use codex_protocol::config_types::Settings; +use codex_protocol::protocol::Event; +use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::SessionSource; +use codex_protocol::protocol::ThreadGoalStatus; +use codex_protocol::protocol::TokenUsage; +use codex_protocol::protocol::TokenUsageInfo; use pretty_assertions::assert_eq; use serde_json::json; use tempfile::TempDir; @@ -70,7 +86,8 @@ async fn installed_goal_tools_reject_duplicate_goal_creation() -> anyhow::Result let runtime = test_runtime().await?; let thread_id = test_thread_id()?; seed_thread_metadata(runtime.as_ref(), thread_id).await?; - let tools = installed_tools(runtime, thread_id).await; + let harness = GoalExtensionHarness::new(runtime, thread_id).await?; + let tools = harness.tools(); let create_tool = tool_by_name(&tools, "create_goal"); let first = tool_call( @@ -100,6 +117,273 @@ async fn installed_goal_tools_reject_duplicate_goal_creation() -> anyhow::Result Ok(()) } +#[tokio::test] +async fn create_goal_resets_baseline_before_turn_stop_accounting() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let harness = GoalExtensionHarness::new(runtime.clone(), thread_id).await?; + harness + .start_turn( + "turn-1", + &token_usage( + /*input_tokens*/ 100, /*cached_input_tokens*/ 10, + /*output_tokens*/ 30, /*reasoning_output_tokens*/ 5, + /*total_tokens*/ 135, + ), + ) + .await; + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 120, /*cached_input_tokens*/ 14, + /*output_tokens*/ 42, /*reasoning_output_tokens*/ 8, + /*total_tokens*/ 162, + ), + ) + .await; + + let tools = harness.tools(); + let create_tool = tool_by_name(&tools, "create_goal"); + create_tool + .handle(tool_call( + "create_goal", + "call-create-goal", + json!({ "objective": "ship goal extension backend" }), + )) + .await?; + + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 127, /*cached_input_tokens*/ 16, + /*output_tokens*/ 52, /*reasoning_output_tokens*/ 10, + /*total_tokens*/ 189, + ), + ) + .await; + harness.stop_turn("turn-1").await; + + let goal = runtime + .thread_goals() + .get_thread_goal(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("goal should exist"))?; + assert_eq!(15, goal.tokens_used); + assert_eq!(ThreadGoalStatus::Active, protocol_status(goal.status)); + Ok(()) +} + +#[tokio::test] +async fn tool_finish_accounts_active_goal_progress_and_emits_event() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let harness = GoalExtensionHarness::new(runtime.clone(), thread_id).await?; + harness.start_turn("turn-1", &TokenUsage::default()).await; + + let tools = harness.tools(); + let create_tool = tool_by_name(&tools, "create_goal"); + create_tool + .handle(tool_call( + "create_goal", + "call-create-goal", + json!({ "objective": "ship goal extension backend" }), + )) + .await?; + harness.sink.clear(); + + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 20, /*cached_input_tokens*/ 5, /*output_tokens*/ 8, + /*reasoning_output_tokens*/ 2, /*total_tokens*/ 30, + ), + ) + .await; + harness + .notify_tool_finish("turn-1", "call-shell", "shell") + .await; + + let goal = runtime + .thread_goals() + .get_thread_goal(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("goal should exist"))?; + assert_eq!(23, goal.tokens_used); + + assert_eq!( + vec![CapturedGoalEvent { + event_id: "call-shell".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::Active, + tokens_used: 23, + }], + harness.sink.goal_events() + ); + Ok(()) +} + +#[tokio::test] +async fn budget_limited_goal_keeps_accruing_until_turn_stop() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let harness = GoalExtensionHarness::new(runtime.clone(), thread_id).await?; + harness.start_turn("turn-1", &TokenUsage::default()).await; + + let tools = harness.tools(); + let create_tool = tool_by_name(&tools, "create_goal"); + create_tool + .handle(tool_call( + "create_goal", + "call-create-goal", + json!({ + "objective": "ship goal extension backend", + "token_budget": 25, + }), + )) + .await?; + harness.sink.clear(); + + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 20, /*cached_input_tokens*/ 5, + /*output_tokens*/ 10, /*reasoning_output_tokens*/ 0, + /*total_tokens*/ 30, + ), + ) + .await; + harness + .notify_tool_finish("turn-1", "call-shell", "shell") + .await; + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 24, /*cached_input_tokens*/ 5, + /*output_tokens*/ 16, /*reasoning_output_tokens*/ 0, + /*total_tokens*/ 40, + ), + ) + .await; + harness.stop_turn("turn-1").await; + + let goal = runtime + .thread_goals() + .get_thread_goal(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("goal should exist"))?; + assert_eq!(35, goal.tokens_used); + assert_eq!(codex_state::ThreadGoalStatus::BudgetLimited, goal.status); + + assert_eq!( + vec![ + CapturedGoalEvent { + event_id: "call-shell".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::BudgetLimited, + tokens_used: 25, + }, + CapturedGoalEvent { + event_id: "turn-1:turn-stop".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::BudgetLimited, + tokens_used: 35, + }, + ], + harness.sink.goal_events() + ); + Ok(()) +} + +#[tokio::test] +async fn update_goal_can_block_and_accounts_final_progress() -> anyhow::Result<()> { + let runtime = test_runtime().await?; + let thread_id = test_thread_id()?; + seed_thread_metadata(runtime.as_ref(), thread_id).await?; + let harness = GoalExtensionHarness::new(runtime.clone(), thread_id).await?; + harness.start_turn("turn-1", &TokenUsage::default()).await; + + let tools = harness.tools(); + let create_tool = tool_by_name(&tools, "create_goal"); + create_tool + .handle(tool_call( + "create_goal", + "call-create-goal", + json!({ "objective": "ship goal extension backend" }), + )) + .await?; + harness.sink.clear(); + + harness + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 20, /*cached_input_tokens*/ 5, /*output_tokens*/ 8, + /*reasoning_output_tokens*/ 2, /*total_tokens*/ 30, + ), + ) + .await; + let update_tool = tool_by_name(&tools, "update_goal"); + let invocation = tool_call( + "update_goal", + "call-update-goal", + json!({ "status": "blocked" }), + ); + let output = update_tool.handle(invocation.clone()).await?; + let result = output.code_mode_result(&invocation.payload); + + assert_eq!( + result, + json!({ + "goal": { + "threadId": thread_id, + "objective": "ship goal extension backend", + "status": "blocked", + "tokensUsed": 23, + "timeUsedSeconds": 0, + "createdAt": result["goal"]["createdAt"], + "updatedAt": result["goal"]["updatedAt"], + }, + "remainingTokens": serde_json::Value::Null, + "completionBudgetReport": serde_json::Value::Null, + }) + ); + + let goal = runtime + .thread_goals() + .get_thread_goal(thread_id) + .await? + .ok_or_else(|| anyhow::anyhow!("goal should exist"))?; + assert_eq!(23, goal.tokens_used); + assert_eq!(codex_state::ThreadGoalStatus::Blocked, goal.status); + + assert_eq!( + vec![ + CapturedGoalEvent { + event_id: "call-update-goal".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::Active, + tokens_used: 23, + }, + CapturedGoalEvent { + event_id: "call-update-goal".to_string(), + turn_id: Some("turn-1".to_string()), + status: ThreadGoalStatus::Blocked, + tokens_used: 23, + }, + ], + harness.sink.goal_events() + ); + Ok(()) +} + async fn installed_tools( runtime: Arc, thread_id: ThreadId, @@ -126,6 +410,118 @@ async fn installed_tools( .collect() } +struct GoalExtensionHarness { + registry: codex_extension_api::ExtensionRegistry<()>, + session_store: ExtensionData, + thread_store: ExtensionData, + sink: Arc, +} + +impl GoalExtensionHarness { + async fn new( + runtime: Arc, + thread_id: ThreadId, + ) -> anyhow::Result { + let sink = Arc::new(RecordingEventSink::default()); + let mut builder = ExtensionRegistryBuilder::<()>::with_event_sink(sink.clone()); + install_with_backend(&mut builder, runtime, |_| true); + let registry = builder.build(); + let session_store = ExtensionData::new("session-1"); + let thread_store = ExtensionData::new(thread_id.to_string()); + for contributor in registry.thread_lifecycle_contributors() { + contributor + .on_thread_start(ThreadStartInput { + config: &(), + session_store: &session_store, + thread_store: &thread_store, + }) + .await; + } + Ok(Self { + registry, + session_store, + thread_store, + sink, + }) + } + + fn tools(&self) -> Vec>> { + self.registry + .tool_contributors() + .iter() + .flat_map(|contributor| contributor.tools(&self.session_store, &self.thread_store)) + .collect() + } + + async fn start_turn(&self, turn_id: &str, usage: &TokenUsage) { + let turn_store = ExtensionData::new(turn_id); + let collaboration_mode = default_collaboration_mode(); + for contributor in self.registry.turn_lifecycle_contributors() { + contributor + .on_turn_start(TurnStartInput { + turn_id, + collaboration_mode: &collaboration_mode, + token_usage_at_turn_start: usage, + session_store: &self.session_store, + thread_store: &self.thread_store, + turn_store: &turn_store, + }) + .await; + } + } + + async fn stop_turn(&self, turn_id: &str) { + let turn_store = ExtensionData::new(turn_id); + for contributor in self.registry.turn_lifecycle_contributors() { + contributor + .on_turn_stop(TurnStopInput { + session_store: &self.session_store, + thread_store: &self.thread_store, + turn_store: &turn_store, + }) + .await; + } + } + + async fn record_token_usage(&self, turn_id: &str, usage: &TokenUsage) { + let turn_store = ExtensionData::new(turn_id); + let token_usage = TokenUsageInfo { + total_token_usage: usage.clone(), + last_token_usage: TokenUsage::default(), + model_context_window: None, + }; + for contributor in self.registry.token_usage_contributors() { + contributor + .on_token_usage( + &self.session_store, + &self.thread_store, + &turn_store, + &token_usage, + ) + .await; + } + } + + async fn notify_tool_finish(&self, turn_id: &str, call_id: &str, tool_name: &str) { + let turn_store = ExtensionData::new(turn_id); + let tool_name = codex_extension_api::ToolName::plain(tool_name); + for contributor in self.registry.tool_lifecycle_contributors() { + contributor + .on_tool_finish(ToolFinishInput { + session_store: &self.session_store, + thread_store: &self.thread_store, + turn_store: &turn_store, + turn_id, + call_id, + tool_name: &tool_name, + source: ToolCallSource::Direct, + outcome: ToolCallOutcome::Completed { success: true }, + }) + .await; + } + } +} + fn tool_by_name<'a>( tools: &'a [Arc>], name: &str, @@ -139,7 +535,7 @@ fn tool_by_name<'a>( fn tool_call(tool_name: &str, call_id: &str, arguments: serde_json::Value) -> ToolCall { ToolCall { call_id: call_id.to_string(), - tool_name: ToolName::plain(tool_name), + tool_name: codex_extension_api::ToolName::plain(tool_name), payload: ToolPayload::Function { arguments: arguments.to_string(), }, @@ -169,3 +565,85 @@ async fn seed_thread_metadata( ); runtime.upsert_thread(&builder.build("test-provider")).await } + +#[derive(Debug, Default)] +struct RecordingEventSink { + events: Mutex>, +} + +impl RecordingEventSink { + fn goal_events(&self) -> Vec { + self.events() + .iter() + .filter_map(|event| match &event.msg { + EventMsg::ThreadGoalUpdated(updated) => Some(CapturedGoalEvent { + event_id: event.id.clone(), + turn_id: updated.turn_id.clone(), + status: updated.goal.status, + tokens_used: updated.goal.tokens_used, + }), + _ => None, + }) + .collect() + } + + fn clear(&self) { + self.events().clear(); + } + + fn events(&self) -> std::sync::MutexGuard<'_, Vec> { + self.events.lock().unwrap_or_else(PoisonError::into_inner) + } +} + +impl ExtensionEventSink for RecordingEventSink { + fn emit(&self, event: Event) { + self.events().push(event); + } +} + +#[derive(Debug, PartialEq, Eq)] +struct CapturedGoalEvent { + event_id: String, + turn_id: Option, + status: ThreadGoalStatus, + tokens_used: i64, +} + +fn default_collaboration_mode() -> CollaborationMode { + CollaborationMode { + mode: ModeKind::Default, + settings: Settings { + model: "gpt-5".to_string(), + reasoning_effort: None, + developer_instructions: None, + }, + } +} + +fn token_usage( + input_tokens: i64, + cached_input_tokens: i64, + output_tokens: i64, + reasoning_output_tokens: i64, + total_tokens: i64, +) -> TokenUsage { + TokenUsage { + input_tokens, + cached_input_tokens, + output_tokens, + reasoning_output_tokens, + total_tokens, + } +} + +fn protocol_status(status: codex_state::ThreadGoalStatus) -> ThreadGoalStatus { + match status { + codex_state::ThreadGoalStatus::Active => ThreadGoalStatus::Active, + codex_state::ThreadGoalStatus::Paused => ThreadGoalStatus::Paused, + codex_state::ThreadGoalStatus::Blocked => ThreadGoalStatus::Blocked, + codex_state::ThreadGoalStatus::UsageLimited => ThreadGoalStatus::UsageLimited, + codex_state::ThreadGoalStatus::BudgetLimited => ThreadGoalStatus::BudgetLimited, + codex_state::ThreadGoalStatus::Complete => ThreadGoalStatus::Complete, + } +}