diff --git a/codex-rs/core/src/session/tests.rs b/codex-rs/core/src/session/tests.rs index 175f022289..5318b72151 100644 --- a/codex-rs/core/src/session/tests.rs +++ b/codex-rs/core/src/session/tests.rs @@ -1911,6 +1911,108 @@ async fn record_token_usage_info_notifies_extension_contributors() { assert_eq!(expected, actual); } +#[tokio::test] +async fn turn_start_lifecycle_exposes_turn_metadata_and_token_baseline() { + struct SessionTurnStartMarker; + struct ThreadTurnStartMarker; + + #[derive(Debug, PartialEq, Eq)] + struct RecordedTurnStart { + turn_id: String, + collaboration_mode: CollaborationMode, + token_usage_at_turn_start: TokenUsage, + saw_session_store: bool, + saw_thread_store: bool, + } + + struct TurnStartRecorder { + records: Arc>>, + } + + #[async_trait::async_trait] + impl codex_extension_api::TurnLifecycleContributor for TurnStartRecorder { + async fn on_turn_start(&self, input: codex_extension_api::TurnStartInput<'_>) { + self.records + .lock() + .expect("turn start records lock") + .push(RecordedTurnStart { + turn_id: input.turn_id.to_string(), + collaboration_mode: input.collaboration_mode.clone(), + token_usage_at_turn_start: input.token_usage_at_turn_start.clone(), + saw_session_store: input + .session_store + .get::() + .is_some(), + saw_thread_store: input.thread_store.get::().is_some(), + }); + } + } + + let (mut session, turn_context) = make_session_and_context().await; + let records = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut builder = codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.turn_lifecycle_contributor(Arc::new(TurnStartRecorder { + records: Arc::clone(&records), + })); + session.services.extensions = Arc::new(builder.build()); + session + .services + .session_extension_data + .insert(SessionTurnStartMarker); + session + .services + .thread_extension_data + .insert(ThreadTurnStartMarker); + + let token_usage_at_turn_start = TokenUsage { + input_tokens: 120, + cached_input_tokens: 15, + output_tokens: 40, + reasoning_output_tokens: 9, + total_tokens: 169, + }; + session + .state + .lock() + .await + .set_token_info(Some(TokenUsageInfo { + total_token_usage: token_usage_at_turn_start.clone(), + last_token_usage: TokenUsage::default(), + model_context_window: turn_context.model_context_window(), + })); + + let turn_context = Arc::new(turn_context); + let session = Arc::new(session); + session + .spawn_task( + Arc::clone(&turn_context), + Vec::new(), + NeverEndingTask { + kind: TaskKind::Regular, + listen_to_cancellation_token: true, + }, + ) + .await; + + session.abort_all_tasks(TurnAbortReason::Interrupted).await; + + let actual = records + .lock() + .expect("turn start records lock") + .drain(..) + .collect::>(); + assert_eq!( + vec![RecordedTurnStart { + turn_id: turn_context.sub_id.clone(), + collaboration_mode: turn_context.collaboration_mode.clone(), + token_usage_at_turn_start, + saw_session_store: true, + saw_thread_store: true, + }], + actual + ); +} + #[tokio::test] async fn config_change_contributor_observes_effective_config_changes() { struct SessionConfigMarker; diff --git a/codex-rs/core/src/tasks/lifecycle.rs b/codex-rs/core/src/tasks/lifecycle.rs index 94b51647c9..8b934175cf 100644 --- a/codex-rs/core/src/tasks/lifecycle.rs +++ b/codex-rs/core/src/tasks/lifecycle.rs @@ -1,16 +1,25 @@ use codex_extension_api::ExtensionData; +use codex_protocol::protocol::TokenUsage; use codex_protocol::protocol::TurnAbortReason; use crate::session::session::Session; +use crate::session::turn_context::TurnContext; impl Session { - pub(super) async fn emit_turn_start_lifecycle(&self, turn_store: &ExtensionData) { + pub(super) async fn emit_turn_start_lifecycle( + &self, + turn_context: &TurnContext, + token_usage_at_turn_start: &TokenUsage, + ) { for contributor in self.services.extensions.turn_lifecycle_contributors() { contributor .on_turn_start(codex_extension_api::TurnStartInput { + turn_id: turn_context.sub_id.as_str(), + collaboration_mode: &turn_context.collaboration_mode, + token_usage_at_turn_start, session_store: &self.services.session_extension_data, thread_store: &self.services.thread_extension_data, - turn_store, + turn_store: turn_context.extension_data.as_ref(), }) .await; } diff --git a/codex-rs/core/src/tasks/mod.rs b/codex-rs/core/src/tasks/mod.rs index 164e65f07c..80ddfee897 100644 --- a/codex-rs/core/src/tasks/mod.rs +++ b/codex-rs/core/src/tasks/mod.rs @@ -357,7 +357,7 @@ impl Session { debug_assert!(turn.tasks.is_empty()); Arc::clone(&turn.turn_state) }; - turn_state.lock().await.token_usage_at_turn_start = token_usage_at_turn_start; + turn_state.lock().await.token_usage_at_turn_start = token_usage_at_turn_start.clone(); let mut pending_items = queued_response_items .into_iter() .map(TurnInput::ResponseInputItem) @@ -366,7 +366,7 @@ impl Session { self.input_queue .extend_pending_input_for_turn_state(turn_state.as_ref(), pending_items) .await; - self.emit_turn_start_lifecycle(turn_context.extension_data.as_ref()) + self.emit_turn_start_lifecycle(turn_context.as_ref(), &token_usage_at_turn_start) .await; let turn_extension_data = Arc::clone(&turn_context.extension_data); diff --git a/codex-rs/ext/extension-api/src/contributors/turn_lifecycle.rs b/codex-rs/ext/extension-api/src/contributors/turn_lifecycle.rs index 0b53183e28..bbd3ae8f39 100644 --- a/codex-rs/ext/extension-api/src/contributors/turn_lifecycle.rs +++ b/codex-rs/ext/extension-api/src/contributors/turn_lifecycle.rs @@ -1,9 +1,17 @@ +use codex_protocol::config_types::CollaborationMode; +use codex_protocol::protocol::TokenUsage; use codex_protocol::protocol::TurnAbortReason; use crate::ExtensionData; /// Input supplied when the host starts a turn. pub struct TurnStartInput<'a> { + /// Stable host-owned turn identifier. + pub turn_id: &'a str, + /// Effective collaboration mode for this turn. + pub collaboration_mode: &'a CollaborationMode, + /// Total token usage snapshot captured when the turn started. + pub token_usage_at_turn_start: &'a TokenUsage, /// Store scoped to the host session runtime. pub session_store: &'a ExtensionData, /// Store scoped to this thread runtime. diff --git a/codex-rs/ext/goal/BUILD.bazel b/codex-rs/ext/goal/BUILD.bazel index 037313da37..05b80c3706 100644 --- a/codex-rs/ext/goal/BUILD.bazel +++ b/codex-rs/ext/goal/BUILD.bazel @@ -3,4 +3,5 @@ load("//:defs.bzl", "codex_rust_crate") codex_rust_crate( name = "goal", crate_name = "codex_goal_extension", + integration_compile_data_extra = ["src/accounting.rs"], ) diff --git a/codex-rs/ext/goal/src/accounting.rs b/codex-rs/ext/goal/src/accounting.rs index 712e325849..cc04684169 100644 --- a/codex-rs/ext/goal/src/accounting.rs +++ b/codex-rs/ext/goal/src/accounting.rs @@ -1,3 +1,4 @@ +use codex_protocol::config_types::ModeKind; use codex_protocol::protocol::TokenUsage; use std::collections::HashMap; use std::sync::Mutex; @@ -17,6 +18,8 @@ struct GoalAccountingInner { #[derive(Debug, Default)] struct GoalTurnAccounting { token_delta: i64, + last_accounted_token_usage: TokenUsage, + account_tokens: bool, stopped: bool, } @@ -27,24 +30,42 @@ pub(crate) struct RecordedTokenDelta { } impl GoalAccountingState { - pub(crate) fn start_turn(&self, turn_id: impl Into) { + pub(crate) fn start_turn( + &self, + turn_id: impl Into, + collaboration_mode: ModeKind, + token_usage_at_turn_start: &TokenUsage, + ) { let turn_id = turn_id.into(); - self.inner().turns.entry(turn_id).or_default().stopped = false; + self.inner().turns.insert( + turn_id, + GoalTurnAccounting { + token_delta: 0, + last_accounted_token_usage: token_usage_at_turn_start.clone(), + account_tokens: !matches!(collaboration_mode, ModeKind::Plan), + stopped: false, + }, + ); } pub(crate) fn record_token_usage( &self, turn_id: impl Into, - usage: &TokenUsage, + total_usage: &TokenUsage, ) -> Option { - let delta = goal_token_delta_for_usage(usage); - if delta <= 0 { + let turn_id = turn_id.into(); + let mut inner = self.inner(); + let turn = inner.turns.get_mut(&turn_id)?; + if turn.stopped || !turn.account_tokens { return None; } - let turn_id = turn_id.into(); - let mut inner = self.inner(); - let turn = inner.turns.entry(turn_id).or_default(); + let delta = + token_delta_since_last_accounting(&turn.last_accounted_token_usage, total_usage); + turn.last_accounted_token_usage = total_usage.clone(); + if delta <= 0 { + return None; + } turn.token_delta = turn.token_delta.saturating_add(delta); let turn_delta = turn.token_delta; inner.unflushed_token_delta = inner.unflushed_token_delta.saturating_add(delta); @@ -65,6 +86,21 @@ impl GoalAccountingState { } } +fn token_delta_since_last_accounting(last: &TokenUsage, current: &TokenUsage) -> i64 { + let delta = TokenUsage { + input_tokens: current.input_tokens.saturating_sub(last.input_tokens), + cached_input_tokens: current + .cached_input_tokens + .saturating_sub(last.cached_input_tokens), + output_tokens: current.output_tokens.saturating_sub(last.output_tokens), + reasoning_output_tokens: current + .reasoning_output_tokens + .saturating_sub(last.reasoning_output_tokens), + total_tokens: current.total_tokens.saturating_sub(last.total_tokens), + }; + goal_token_delta_for_usage(&delta) +} + pub(crate) fn goal_token_delta_for_usage(usage: &TokenUsage) -> i64 { usage .input_tokens diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index 657d285f74..6d8eb42f0e 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -108,10 +108,11 @@ where return; } - // TODO: TurnStartInput should expose collaboration mode and token usage - // at turn start. Goals need mode to suppress plan-mode accounting and - // the token baseline to account deltas exactly. - accounting_state(input.thread_store).start_turn(input.turn_store.level_id()); + accounting_state(input.thread_store).start_turn( + input.turn_id, + input.collaboration_mode.mode, + input.token_usage_at_turn_start, + ); } async fn on_turn_stop(&self, input: TurnStopInput<'_>) { @@ -158,7 +159,7 @@ where } let Some(_recorded) = accounting_state(thread_store) - .record_token_usage(turn_store.level_id(), &token_usage.last_token_usage) + .record_token_usage(turn_store.level_id(), &token_usage.total_token_usage) else { return; }; diff --git a/codex-rs/ext/goal/tests/accounting.rs b/codex-rs/ext/goal/tests/accounting.rs new file mode 100644 index 0000000000..99e9c93005 --- /dev/null +++ b/codex-rs/ext/goal/tests/accounting.rs @@ -0,0 +1,68 @@ +#![allow(dead_code)] + +#[path = "../src/accounting.rs"] +mod accounting; + +use accounting::GoalAccountingState; +use codex_protocol::config_types::ModeKind; +use codex_protocol::protocol::TokenUsage; +use pretty_assertions::assert_eq; + +#[test] +fn goal_accounting_uses_turn_start_baseline_for_exact_deltas() { + let state = GoalAccountingState::default(); + state.start_turn( + "turn-1", + ModeKind::Default, + &token_usage( + /*input_tokens*/ 100, /*cached_input_tokens*/ 10, /*output_tokens*/ 30, + /*reasoning_output_tokens*/ 5, /*total_tokens*/ 135, + ), + ); + + let recorded = state + .record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 120, /*cached_input_tokens*/ 14, + /*output_tokens*/ 42, /*reasoning_output_tokens*/ 8, + /*total_tokens*/ 162, + ), + ) + .expect("token delta should be recorded"); + + assert_eq!(28, recorded.turn_delta); + assert_eq!(28, recorded.thread_unflushed_delta); +} + +#[test] +fn goal_accounting_ignores_plan_mode_turns() { + let state = GoalAccountingState::default(); + state.start_turn("turn-1", ModeKind::Plan, &TokenUsage::default()); + + let recorded = state.record_token_usage( + "turn-1", + &token_usage( + /*input_tokens*/ 20, /*cached_input_tokens*/ 5, /*output_tokens*/ 8, + /*reasoning_output_tokens*/ 2, /*total_tokens*/ 30, + ), + ); + + assert_eq!(None, recorded); +} + +fn token_usage( + input_tokens: i64, + cached_input_tokens: i64, + output_tokens: i64, + reasoning_output_tokens: i64, + total_tokens: i64, +) -> TokenUsage { + TokenUsage { + input_tokens, + cached_input_tokens, + output_tokens, + reasoning_output_tokens, + total_tokens, + } +}