From c69cde3547c87c3423434ff37273dcadbcce8817 Mon Sep 17 00:00:00 2001 From: jif-oai Date: Mon, 18 May 2026 21:55:57 +0200 Subject: [PATCH] Add tool lifecycle extension contributor (#23309) ## Why Extensions that need to track runtime progress currently have no typed host signal for tool execution. The goal extension in particular needs to observe tool attempts without inspecting tool payloads, owning tool implementations, or staying coupled to core-only runtime plumbing. This adds a narrow lifecycle contributor API for host-owned tool execution: extensions can observe when an accepted tool call starts and how it finishes, while policy hooks and tool handlers continue to own payload rewriting, blocking, and execution. Relevant code: - [`ToolLifecycleContributor`](https://github.com/openai/codex/blob/3ad2850ffc7d8a1da19c65a92425637a59098f1b/codex-rs/ext/extension-api/src/contributors.rs#L119) defines the extension-facing observer contract. - [`tool_lifecycle.rs`](https://github.com/openai/codex/blob/3ad2850ffc7d8a1da19c65a92425637a59098f1b/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs) defines the typed start/finish inputs, source, and outcome enums. - [`notify_tool_start` / `notify_tool_finish`](https://github.com/openai/codex/blob/3ad2850ffc7d8a1da19c65a92425637a59098f1b/codex-rs/core/src/tools/lifecycle.rs) bridges core tool dispatch into the extension registry. ## What Changed - Added `ToolLifecycleContributor` to `codex-extension-api`, including: - `ToolStartInput` - `ToolFinishInput` - `ToolCallSource` - `ToolCallOutcome` - Added registration and lookup support on `ExtensionRegistryBuilder` / `ExtensionRegistry`. - Wired core tool dispatch to notify lifecycle contributors for: - accepted tool starts - completed tool calls, including the tool output success marker - pre-tool-use blocks - failures before or after the handler runs - cancellation/abort in the parallel tool path - Registered the goal extension as a lifecycle contributor and added the outcome filter it will use for goal progress accounting. ## Test Coverage - Added `dispatch_notifies_tool_lifecycle_contributors` to cover lifecycle notification ordering and outcomes for successful and handler-failed tool calls. --- codex-rs/core/src/tools/lifecycle.rs | 98 ++++++++ codex-rs/core/src/tools/mod.rs | 1 + codex-rs/core/src/tools/parallel.rs | 234 +++++++++++++++--- codex-rs/core/src/tools/registry.rs | 66 ++++- codex-rs/core/src/tools/registry_tests.rs | 194 +++++++++++++++ codex-rs/core/src/tools/router.rs | 53 +++- .../ext/extension-api/src/contributors.rs | 23 ++ .../src/contributors/tool_lifecycle.rs | 82 ++++++ codex-rs/ext/extension-api/src/lib.rs | 6 + codex-rs/ext/extension-api/src/registry.rs | 15 ++ codex-rs/ext/goal/src/extension.rs | 39 +++ 11 files changed, 777 insertions(+), 34 deletions(-) create mode 100644 codex-rs/core/src/tools/lifecycle.rs create mode 100644 codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs diff --git a/codex-rs/core/src/tools/lifecycle.rs b/codex-rs/core/src/tools/lifecycle.rs new file mode 100644 index 0000000000..ad8b492cce --- /dev/null +++ b/codex-rs/core/src/tools/lifecycle.rs @@ -0,0 +1,98 @@ +use codex_extension_api::ToolCallOutcome; +use codex_extension_api::ToolCallSource as ExtensionToolCallSource; +use codex_extension_api::ToolFinishInput; +use codex_extension_api::ToolStartInput; +use codex_tools::ToolName; + +use crate::session::session::Session; +use crate::session::turn_context::TurnContext; +use crate::tools::context::ToolCallSource; +use crate::tools::context::ToolInvocation; + +pub(crate) async fn notify_tool_start(invocation: &ToolInvocation) { + for contributor in invocation + .session + .services + .extensions + .tool_lifecycle_contributors() + { + contributor + .on_tool_start(ToolStartInput { + session_store: &invocation.session.services.session_extension_data, + thread_store: &invocation.session.services.thread_extension_data, + turn_store: invocation.turn.extension_data.as_ref(), + turn_id: invocation.turn.sub_id.as_str(), + call_id: invocation.call_id.as_str(), + tool_name: &invocation.tool_name, + source: extension_tool_call_source(invocation.source.clone()), + }) + .await; + } +} + +pub(crate) async fn notify_tool_finish(invocation: &ToolInvocation, outcome: ToolCallOutcome) { + notify_tool_finish_parts( + invocation.session.as_ref(), + invocation.turn.as_ref(), + invocation.call_id.as_str(), + &invocation.tool_name, + invocation.source.clone(), + outcome, + ) + .await; +} + +pub(crate) async fn notify_tool_aborted( + session: &Session, + turn: &TurnContext, + call_id: &str, + tool_name: &ToolName, + source: ToolCallSource, +) { + notify_tool_finish_parts( + session, + turn, + call_id, + tool_name, + source, + ToolCallOutcome::Aborted, + ) + .await; +} + +async fn notify_tool_finish_parts( + session: &Session, + turn: &TurnContext, + call_id: &str, + tool_name: &ToolName, + source: ToolCallSource, + outcome: ToolCallOutcome, +) { + for contributor in session.services.extensions.tool_lifecycle_contributors() { + contributor + .on_tool_finish(ToolFinishInput { + session_store: &session.services.session_extension_data, + thread_store: &session.services.thread_extension_data, + turn_store: turn.extension_data.as_ref(), + turn_id: turn.sub_id.as_str(), + call_id, + tool_name, + source: extension_tool_call_source(source.clone()), + outcome, + }) + .await; + } +} + +fn extension_tool_call_source(source: ToolCallSource) -> ExtensionToolCallSource { + match source { + ToolCallSource::Direct => ExtensionToolCallSource::Direct, + ToolCallSource::CodeMode { + cell_id, + runtime_tool_call_id, + } => ExtensionToolCallSource::CodeMode { + cell_id, + runtime_tool_call_id, + }, + } +} diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs index dd63116353..5b7a17f428 100644 --- a/codex-rs/core/src/tools/mod.rs +++ b/codex-rs/core/src/tools/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod events; pub(crate) mod handlers; pub(crate) mod hook_names; pub(crate) mod hosted_spec; +pub(crate) mod lifecycle; pub(crate) mod network_approval; pub(crate) mod orchestrator; pub(crate) mod parallel; diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 4c79e4b168..15954869e2 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -1,7 +1,10 @@ use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::time::Instant; use tokio::sync::RwLock; +use tokio::task::JoinError; use tokio_util::either::Either; use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; @@ -15,6 +18,7 @@ use crate::session::turn_context::TurnContext; use crate::tools::context::AbortedToolOutput; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; +use crate::tools::lifecycle::notify_tool_aborted; use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::router::ToolCall; @@ -89,6 +93,12 @@ impl ToolCallRuntime { let lock = Arc::clone(&self.parallel_execution); let invocation_cancellation_token = cancellation_token.clone(); let started = Instant::now(); + let abort_session = Arc::clone(&session); + let abort_source = source.clone(); + let abort_turn = Arc::clone(&turn); + let terminal_outcome_reached = Arc::new(AtomicBool::new(false)); + let dispatch_terminal_outcome_reached = Arc::clone(&terminal_outcome_reached); + let dispatch_call = call.clone(); let dispatch_span = trace_span!( "dispatch_tool_call_with_code_mode_result", @@ -97,47 +107,69 @@ impl ToolCallRuntime { call_id = call.call_id.as_str(), aborted = false, ); + let abort_dispatch_span = dispatch_span.clone(); - let handle: AbortOnDropHandle> = + let mut handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { - tokio::select! { - _ = cancellation_token.cancelled() => { - let secs = started.elapsed().as_secs_f32().max(0.1); - dispatch_span.record("aborted", true); - Ok(Self::aborted_response(&call, secs)) - }, - res = async { - let _guard = if supports_parallel { - Either::Left(lock.read().await) - } else { - Either::Right(lock.write().await) - }; + let _guard = if supports_parallel { + Either::Left(lock.read().await) + } else { + Either::Right(lock.write().await) + }; - router - .dispatch_tool_call_with_code_mode_result( - session, - turn, - invocation_cancellation_token, - tracker, - call.clone(), - source, - ) - .instrument(dispatch_span.clone()) - .await - } => res, - } + router + .dispatch_tool_call_with_terminal_outcome( + session, + turn, + invocation_cancellation_token, + tracker, + dispatch_call, + source, + dispatch_terminal_outcome_reached, + ) + .instrument(dispatch_span.clone()) + .await })); async move { - handle.await.map_err(|err| { - FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) - })? + tokio::select! { + res = &mut handle => res.map_err(Self::tool_task_join_error)?, + _ = cancellation_token.cancelled() => { + if terminal_outcome_reached.load(Ordering::Acquire) || handle.is_finished() { + handle.await.map_err(Self::tool_task_join_error)? + } else { + let secs = started.elapsed().as_secs_f32().max(0.1); + abort_dispatch_span.record("aborted", true); + handle.abort(); + match handle.await { + Ok(result) => result, + Err(err) if err.is_cancelled() => { + let response = Self::aborted_response(&call, secs); + notify_tool_aborted( + abort_session.as_ref(), + abort_turn.as_ref(), + call.call_id.as_str(), + &call.tool_name, + abort_source, + ) + .await; + Ok(response) + } + Err(err) => Err(Self::tool_task_join_error(err)), + } + } + }, + } } .in_current_span() } } impl ToolCallRuntime { + fn tool_task_join_error(err: JoinError) -> FunctionCallError { + FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) + } + fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem { let message = err.to_string(); match call.payload { @@ -189,3 +221,147 @@ impl ToolCallRuntime { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use crate::tools::context::FunctionToolOutput; + use crate::tools::context::ToolInvocation; + use crate::tools::registry::CoreToolRuntime; + use crate::tools::registry::ToolExecutor; + use crate::tools::registry::ToolRegistry; + use crate::turn_diff_tracker::TurnDiffTracker; + use codex_extension_api::ToolCallOutcome; + use codex_protocol::models::FunctionCallOutputBody; + use codex_protocol::models::FunctionCallOutputPayload; + use pretty_assertions::assert_eq; + use tokio::sync::Notify; + use tokio::sync::oneshot; + + struct ImmediateHandler { + tool_name: codex_tools::ToolName, + } + + #[async_trait::async_trait] + impl ToolExecutor for ImmediateHandler { + fn tool_name(&self) -> codex_tools::ToolName { + self.tool_name.clone() + } + + async fn handle( + &self, + _invocation: ToolInvocation, + ) -> Result, FunctionCallError> { + Ok(Box::new(FunctionToolOutput::from_text( + "ok".to_string(), + Some(true), + ))) + } + } + + impl CoreToolRuntime for ImmediateHandler {} + + struct BlockingFinishContributor { + records: Arc>>, + finish_started: std::sync::Mutex>>, + allow_finish: Arc, + } + + impl codex_extension_api::ToolLifecycleContributor for BlockingFinishContributor { + fn on_tool_finish<'a>( + &'a self, + input: codex_extension_api::ToolFinishInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let allow_finish = Arc::clone(&self.allow_finish); + let finish_started = self + .finish_started + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take(); + let outcome = input.outcome; + Box::pin(async move { + if let Some(finish_started) = finish_started { + let _ = finish_started.send(()); + } + allow_finish.notified().await; + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(outcome); + }) + } + } + + #[tokio::test] + async fn cancellation_after_handler_finishes_preserves_completed_lifecycle() + -> anyhow::Result<()> { + let (mut session, turn_context) = crate::session::tests::make_session_and_context().await; + let records = Arc::new(std::sync::Mutex::new(Vec::new())); + let (finish_started_tx, finish_started_rx) = oneshot::channel(); + let allow_finish = Arc::new(Notify::new()); + let mut builder = + codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.tool_lifecycle_contributor(Arc::new(BlockingFinishContributor { + records: Arc::clone(&records), + finish_started: std::sync::Mutex::new(Some(finish_started_tx)), + allow_finish: Arc::clone(&allow_finish), + })); + session.services.extensions = Arc::new(builder.build()); + + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + let tool_name = codex_tools::ToolName::plain("test_tool"); + let handler = Arc::new(ImmediateHandler { + tool_name: tool_name.clone(), + }) as Arc; + let router = Arc::new(ToolRouter::from_parts( + ToolRegistry::from_tools([handler]), + Vec::new(), + )); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let runtime = ToolCallRuntime::new(router, session, turn_context, tracker); + let cancellation_token = CancellationToken::new(); + let call = ToolCall { + tool_name, + call_id: "call-1".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + }; + + let response_task = + tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone())); + tokio::time::timeout(Duration::from_secs(1), finish_started_rx) + .await + .expect("timed out waiting for lifecycle notification to start") + .expect("lifecycle notification should start"); + cancellation_token.cancel(); + tokio::time::sleep(Duration::from_millis(10)).await; + allow_finish.notify_waiters(); + + let response = tokio::time::timeout(Duration::from_secs(1), response_task) + .await + .expect("timed out waiting for tool response") + .expect("tool response task should join")?; + let expected_response = ResponseInputItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text("ok".to_string()), + success: Some(true), + }, + }; + assert_eq!(expected_response, response); + + let actual = records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .drain(..) + .collect::>(); + assert_eq!(vec![ToolCallOutcome::Completed { success: true }], actual); + + Ok(()) + } +} diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 96b686a83a..363c3f2a01 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::time::Duration; use crate::function_tool::FunctionCallError; @@ -18,9 +20,12 @@ use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; use crate::tools::flat_tool_name; use crate::tools::hook_names::HookToolName; +use crate::tools::lifecycle::notify_tool_finish; +use crate::tools::lifecycle::notify_tool_start; use crate::tools::tool_dispatch_trace::ToolDispatchTrace; use crate::tools::tool_search_entry::ToolSearchInfo; use crate::util::error_or_panic; +use codex_extension_api::ToolCallOutcome; use codex_protocol::models::ResponseInputItem; use codex_protocol::protocol::EventMsg; use codex_tools::ToolName; @@ -298,13 +303,23 @@ impl ToolRegistry { Some(tool.supports_parallel_tool_calls()) } + #[allow(dead_code)] + pub(crate) async fn dispatch_any( + &self, + invocation: ToolInvocation, + ) -> Result { + self.dispatch_any_with_terminal_outcome(invocation, /*terminal_outcome_reached*/ None) + .await + } + #[expect( clippy::await_holding_invalid_type, reason = "tool dispatch must keep active-turn accounting atomic" )] - pub(crate) async fn dispatch_any( + pub(crate) async fn dispatch_any_with_terminal_outcome( &self, mut invocation: ToolInvocation, + terminal_outcome_reached: Option>, ) -> Result { let tool_name = invocation.tool_name.clone(); let tool_name_flat = flat_tool_name(&tool_name); @@ -389,6 +404,8 @@ impl ToolRegistry { return Err(err); } + notify_tool_start(&invocation).await; + if let Some(pre_tool_use_payload) = tool.pre_tool_use_payload(&invocation) { match run_pre_tool_use_hooks( &invocation.session, @@ -402,13 +419,33 @@ impl ToolRegistry { PreToolUseHookResult::Blocked(message) => { let err = FunctionCallError::RespondToModel(message); dispatch_trace.record_failed(&err); + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } + notify_tool_finish(&invocation, ToolCallOutcome::Blocked).await; return Err(err); } PreToolUseHookResult::Continue { updated_input: Some(updated_input), - } => { - invocation = tool.with_updated_hook_input(invocation, updated_input)?; - } + } => match tool.with_updated_hook_input(invocation.clone(), updated_input) { + Ok(updated_invocation) => { + invocation = updated_invocation; + } + Err(err) => { + dispatch_trace.record_failed(&err); + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } + notify_tool_finish( + &invocation, + ToolCallOutcome::Failed { + handler_executed: false, + }, + ) + .await; + return Err(err); + } + }, PreToolUseHookResult::Continue { updated_input: None, } => {} @@ -503,6 +540,27 @@ impl ToolRegistry { } } + let lifecycle_outcome = match &result { + Ok(_) => { + let guard = response_cell.lock().await; + match guard.as_ref() { + Some(result) => ToolCallOutcome::Completed { + success: result.result.success_for_logging(), + }, + None => ToolCallOutcome::Failed { + handler_executed: true, + }, + } + } + Err(_) => ToolCallOutcome::Failed { + handler_executed: true, + }, + }; + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } + notify_tool_finish(&invocation, lifecycle_outcome).await; + if let Err(err) = invocation .session .goal_runtime_apply(GoalRuntimeEvent::ToolCompleted { diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index defacf33c0..e3ecfc8f98 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -23,6 +23,97 @@ impl ToolExecutor for TestHandler { impl CoreToolRuntime for TestHandler {} +#[derive(Clone)] +enum LifecycleTestResult { + Ok { success: bool }, + Err, +} + +struct LifecycleTestHandler { + tool_name: codex_tools::ToolName, + result: LifecycleTestResult, +} + +#[async_trait::async_trait] +impl ToolExecutor for LifecycleTestHandler { + fn tool_name(&self) -> codex_tools::ToolName { + self.tool_name.clone() + } + + async fn handle( + &self, + _invocation: ToolInvocation, + ) -> Result, FunctionCallError> { + match self.result.clone() { + LifecycleTestResult::Ok { success } => Ok(Box::new( + crate::tools::context::FunctionToolOutput::from_text( + "ok".to_string(), + Some(success), + ), + )), + LifecycleTestResult::Err => Err(FunctionCallError::RespondToModel( + "handler failed".to_string(), + )), + } + } +} + +impl CoreToolRuntime for LifecycleTestHandler {} + +#[derive(Debug, PartialEq, Eq)] +enum RecordedToolLifecycle { + Start { + call_id: String, + tool_name: codex_tools::ToolName, + }, + Finish { + call_id: String, + tool_name: codex_tools::ToolName, + outcome: codex_extension_api::ToolCallOutcome, + }, +} + +struct ToolLifecycleRecorder { + records: Arc>>, +} + +impl codex_extension_api::ToolLifecycleContributor for ToolLifecycleRecorder { + fn on_tool_start<'a>( + &'a self, + input: codex_extension_api::ToolStartInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let record = RecordedToolLifecycle::Start { + call_id: input.call_id.to_string(), + tool_name: input.tool_name.clone(), + }; + Box::pin(async move { + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(record); + }) + } + + fn on_tool_finish<'a>( + &'a self, + input: codex_extension_api::ToolFinishInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let record = RecordedToolLifecycle::Finish { + call_id: input.call_id.to_string(), + tool_name: input.tool_name.clone(), + outcome: input.outcome, + }; + Box::pin(async move { + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(record); + }) + } +} + #[test] fn handler_looks_up_namespaced_aliases_explicitly() { let namespace = "mcp__codex_apps__gmail"; @@ -61,3 +152,106 @@ fn handler_looks_up_namespaced_aliases_explicitly() { .is_some_and(|handler| Arc::ptr_eq(handler, &namespaced_handler)) ); } + +#[tokio::test] +async fn dispatch_notifies_tool_lifecycle_contributors() -> anyhow::Result<()> { + let (mut session, turn) = crate::session::tests::make_session_and_context().await; + let records = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut builder = codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.tool_lifecycle_contributor(Arc::new(ToolLifecycleRecorder { + records: Arc::clone(&records), + })); + session.services.extensions = Arc::new(builder.build()); + + let ok_tool = codex_tools::ToolName::plain("ok_tool"); + let failing_tool = codex_tools::ToolName::plain("failing_tool"); + let ok_handler = Arc::new(LifecycleTestHandler { + tool_name: ok_tool.clone(), + result: LifecycleTestResult::Ok { success: false }, + }) as Arc; + let failing_handler = Arc::new(LifecycleTestHandler { + tool_name: failing_tool.clone(), + result: LifecycleTestResult::Err, + }) as Arc; + let registry = ToolRegistry::new(HashMap::from([ + (ok_tool.clone(), ok_handler), + (failing_tool.clone(), failing_handler), + ])); + let session = Arc::new(session); + let turn = Arc::new(turn); + + registry + .dispatch_any(test_invocation( + Arc::clone(&session), + Arc::clone(&turn), + "ok-call", + ok_tool.clone(), + )) + .await?; + let err = match registry + .dispatch_any(test_invocation( + Arc::clone(&session), + Arc::clone(&turn), + "failing-call", + failing_tool.clone(), + )) + .await + { + Ok(_) => panic!("failing handler should return an error"), + Err(err) => err, + }; + assert_eq!(err.to_string(), "handler failed"); + + let expected = vec![ + RecordedToolLifecycle::Start { + call_id: "ok-call".to_string(), + tool_name: ok_tool.clone(), + }, + RecordedToolLifecycle::Finish { + call_id: "ok-call".to_string(), + tool_name: ok_tool, + outcome: codex_extension_api::ToolCallOutcome::Completed { success: false }, + }, + RecordedToolLifecycle::Start { + call_id: "failing-call".to_string(), + tool_name: failing_tool.clone(), + }, + RecordedToolLifecycle::Finish { + call_id: "failing-call".to_string(), + tool_name: failing_tool, + outcome: codex_extension_api::ToolCallOutcome::Failed { + handler_executed: true, + }, + }, + ]; + let actual = records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .drain(..) + .collect::>(); + assert_eq!(expected, actual); + + Ok(()) +} + +fn test_invocation( + session: Arc, + turn: Arc, + call_id: &str, + tool_name: codex_tools::ToolName, +) -> ToolInvocation { + ToolInvocation { + session, + turn, + cancellation_token: tokio_util::sync::CancellationToken::new(), + tracker: Arc::new(tokio::sync::Mutex::new( + crate::turn_diff_tracker::TurnDiffTracker::new(), + )), + call_id: call_id.to_string(), + tool_name, + source: crate::tools::context::ToolCallSource::Direct, + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + } +} diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 2477ba347c..a279ec88d9 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -19,6 +19,7 @@ use codex_tools::ToolName; use codex_tools::ToolSpec; use codex_tools::ToolsConfig; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use tokio_util::sync::CancellationToken; use tracing::instrument; @@ -123,6 +124,7 @@ impl ToolRouter { } } + #[allow(dead_code)] #[instrument(level = "trace", skip_all, err)] pub async fn dispatch_tool_call_with_code_mode_result( &self, @@ -132,6 +134,53 @@ impl ToolRouter { tracker: SharedTurnDiffTracker, call: ToolCall, source: ToolCallSource, + ) -> Result { + self.dispatch_tool_call_with_code_mode_result_inner( + session, + turn, + cancellation_token, + tracker, + call, + source, + /*terminal_outcome_reached*/ None, + ) + .await + } + + #[instrument(level = "trace", skip_all, err)] + #[allow(clippy::too_many_arguments)] + pub(crate) async fn dispatch_tool_call_with_terminal_outcome( + &self, + session: Arc, + turn: Arc, + cancellation_token: CancellationToken, + tracker: SharedTurnDiffTracker, + call: ToolCall, + source: ToolCallSource, + terminal_outcome_reached: Arc, + ) -> Result { + self.dispatch_tool_call_with_code_mode_result_inner( + session, + turn, + cancellation_token, + tracker, + call, + source, + Some(terminal_outcome_reached), + ) + .await + } + + #[allow(clippy::too_many_arguments)] + async fn dispatch_tool_call_with_code_mode_result_inner( + &self, + session: Arc, + turn: Arc, + cancellation_token: CancellationToken, + tracker: SharedTurnDiffTracker, + call: ToolCall, + source: ToolCallSource, + terminal_outcome_reached: Option>, ) -> Result { let ToolCall { tool_name, @@ -150,7 +199,9 @@ impl ToolRouter { payload, }; - self.registry.dispatch_any(invocation).await + self.registry + .dispatch_any_with_terminal_outcome(invocation, terminal_outcome_reached) + .await } } diff --git a/codex-rs/ext/extension-api/src/contributors.rs b/codex-rs/ext/extension-api/src/contributors.rs index 4aa28d1049..8968e08573 100644 --- a/codex-rs/ext/extension-api/src/contributors.rs +++ b/codex-rs/ext/extension-api/src/contributors.rs @@ -11,6 +11,7 @@ use crate::ExtensionData; mod prompt; mod thread_lifecycle; +mod tool_lifecycle; mod turn_lifecycle; pub use prompt::PromptFragment; @@ -18,6 +19,11 @@ pub use prompt::PromptSlot; pub use thread_lifecycle::ThreadResumeInput; pub use thread_lifecycle::ThreadStartInput; pub use thread_lifecycle::ThreadStopInput; +pub use tool_lifecycle::ToolCallOutcome; +pub use tool_lifecycle::ToolCallSource; +pub use tool_lifecycle::ToolFinishInput; +pub use tool_lifecycle::ToolLifecycleFuture; +pub use tool_lifecycle::ToolStartInput; pub use turn_lifecycle::TurnAbortInput; pub use turn_lifecycle::TurnStartInput; pub use turn_lifecycle::TurnStopInput; @@ -111,6 +117,23 @@ pub trait ToolContributor: Send + Sync { ) -> Vec>>; } +/// Contributor for host-owned tool lifecycle gates. +/// +/// Implementations should use these callbacks to observe tool execution without +/// inspecting or rewriting tool input/output. Use `ToolContributor` for owning a +/// tool implementation and hooks for policy that needs tool payloads. +pub trait ToolLifecycleContributor: Send + Sync { + /// Called once the host has accepted a tool call for execution. + fn on_tool_start<'a>(&'a self, _input: ToolStartInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(std::future::ready(())) + } + + /// Called after the tool call returns, is blocked, fails, or is cancelled. + fn on_tool_finish<'a>(&'a self, _input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(std::future::ready(())) + } +} + /// Future returned by one claimed approval-review contribution. pub type ApprovalReviewFuture<'a> = std::pin::Pin + Send + 'a>>; diff --git a/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs b/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs new file mode 100644 index 0000000000..486bca2643 --- /dev/null +++ b/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs @@ -0,0 +1,82 @@ +use std::future::Future; +use std::pin::Pin; + +use codex_tools::ToolName; + +use crate::ExtensionData; + +/// Future returned by one tool-lifecycle callback. +pub type ToolLifecycleFuture<'a> = Pin + Send + 'a>>; + +/// Host-visible source for a model tool call. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ToolCallSource { + /// The model invoked the tool directly. + Direct, + /// Code mode invoked the tool while executing a runtime cell. + CodeMode { + /// Runtime cell that issued the nested tool request. + cell_id: String, + /// Code-mode's per-cell tool invocation id. + runtime_tool_call_id: String, + }, +} + +/// Extension-facing outcome for a finished tool call. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ToolCallOutcome { + /// The tool returned a normal output. + Completed { + /// The tool output's own success marker for telemetry/logging. + success: bool, + }, + /// The tool was blocked by host policy before the handler ran. + Blocked, + /// The tool did not produce a normal output. + Failed { + /// Whether the host reached the tool handler before the failure. + handler_executed: bool, + }, + /// The host cancelled the tool before normal completion. Cancellation can + /// win before the dispatch path accepts the call, so contributors should not + /// assume a matching start callback exists. + Aborted, +} + +/// Input supplied when the host starts executing one tool call. +pub struct ToolStartInput<'a> { + /// Store scoped to the host session runtime. + pub session_store: &'a ExtensionData, + /// Store scoped to this thread runtime. + pub thread_store: &'a ExtensionData, + /// Store scoped to this turn runtime. + pub turn_store: &'a ExtensionData, + /// Current turn submission id. + pub turn_id: &'a str, + /// Model-visible tool call id. + pub call_id: &'a str, + /// Tool name as routed by the host. + pub tool_name: &'a ToolName, + /// Source that issued the tool call. + pub source: ToolCallSource, +} + +/// Input supplied when the host finishes executing one tool call. +pub struct ToolFinishInput<'a> { + /// Store scoped to the host session runtime. + pub session_store: &'a ExtensionData, + /// Store scoped to this thread runtime. + pub thread_store: &'a ExtensionData, + /// Store scoped to this turn runtime. + pub turn_store: &'a ExtensionData, + /// Current turn submission id. + pub turn_id: &'a str, + /// Model-visible tool call id. + pub call_id: &'a str, + /// Tool name as routed by the host. + pub tool_name: &'a ToolName, + /// Source that issued the tool call. + pub source: ToolCallSource, + /// Host-observed result of the tool call. + pub outcome: ToolCallOutcome, +} diff --git a/codex-rs/ext/extension-api/src/lib.rs b/codex-rs/ext/extension-api/src/lib.rs index fe33d42128..373f3735a4 100644 --- a/codex-rs/ext/extension-api/src/lib.rs +++ b/codex-rs/ext/extension-api/src/lib.rs @@ -28,7 +28,13 @@ pub use contributors::ThreadResumeInput; pub use contributors::ThreadStartInput; pub use contributors::ThreadStopInput; pub use contributors::TokenUsageContributor; +pub use contributors::ToolCallOutcome; +pub use contributors::ToolCallSource; pub use contributors::ToolContributor; +pub use contributors::ToolFinishInput; +pub use contributors::ToolLifecycleContributor; +pub use contributors::ToolLifecycleFuture; +pub use contributors::ToolStartInput; pub use contributors::TurnAbortInput; pub use contributors::TurnItemContributionFuture; pub use contributors::TurnItemContributor; diff --git a/codex-rs/ext/extension-api/src/registry.rs b/codex-rs/ext/extension-api/src/registry.rs index 41d0967126..4577ddc048 100644 --- a/codex-rs/ext/extension-api/src/registry.rs +++ b/codex-rs/ext/extension-api/src/registry.rs @@ -10,6 +10,7 @@ use crate::NoopExtensionEventSink; use crate::ThreadLifecycleContributor; use crate::TokenUsageContributor; use crate::ToolContributor; +use crate::ToolLifecycleContributor; use crate::TurnItemContributor; use crate::TurnLifecycleContributor; @@ -22,6 +23,7 @@ pub struct ExtensionRegistryBuilder { token_usage_contributors: Vec>, context_contributors: Vec>, tool_contributors: Vec>, + tool_lifecycle_contributors: Vec>, turn_item_contributors: Vec>, approval_review_contributors: Vec>, } @@ -37,6 +39,7 @@ impl Default for ExtensionRegistryBuilder { approval_review_contributors: Vec::new(), context_contributors: Vec::new(), tool_contributors: Vec::new(), + tool_lifecycle_contributors: Vec::new(), turn_item_contributors: Vec::new(), } } @@ -99,6 +102,11 @@ impl ExtensionRegistryBuilder { self.tool_contributors.push(contributor); } + /// Registers one tool-lifecycle contributor. + pub fn tool_lifecycle_contributor(&mut self, contributor: Arc) { + self.tool_lifecycle_contributors.push(contributor); + } + /// Registers one ordered turn-item contributor. pub fn turn_item_contributor(&mut self, contributor: Arc) { self.turn_item_contributors.push(contributor); @@ -115,6 +123,7 @@ impl ExtensionRegistryBuilder { approval_review_contributors: self.approval_review_contributors, context_contributors: self.context_contributors, tool_contributors: self.tool_contributors, + tool_lifecycle_contributors: self.tool_lifecycle_contributors, turn_item_contributors: self.turn_item_contributors, } } @@ -129,6 +138,7 @@ pub struct ExtensionRegistry { token_usage_contributors: Vec>, context_contributors: Vec>, tool_contributors: Vec>, + tool_lifecycle_contributors: Vec>, turn_item_contributors: Vec>, approval_review_contributors: Vec>, } @@ -182,6 +192,11 @@ impl ExtensionRegistry { &self.tool_contributors } + /// Returns the registered tool-lifecycle contributors. + pub fn tool_lifecycle_contributors(&self) -> &[Arc] { + &self.tool_lifecycle_contributors + } + /// Returns the registered ordered turn-item contributors. pub fn turn_item_contributors(&self) -> &[Arc] { &self.turn_item_contributors diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index a0ce7117f4..a8d4f5c289 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -9,7 +9,11 @@ use codex_extension_api::NoopExtensionEventSink; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; use codex_extension_api::TokenUsageContributor; +use codex_extension_api::ToolCallOutcome; use codex_extension_api::ToolContributor; +use codex_extension_api::ToolFinishInput; +use codex_extension_api::ToolLifecycleContributor; +use codex_extension_api::ToolLifecycleFuture; use codex_extension_api::TurnAbortInput; use codex_extension_api::TurnLifecycleContributor; use codex_extension_api::TurnStartInput; @@ -21,6 +25,7 @@ use codex_protocol::protocol::TurnAbortReason; use crate::accounting::GoalAccountingState; use crate::events::GoalEventEmitter; +use crate::spec::UPDATE_GOAL_TOOL_NAME; use crate::tool::CreateGoalRequest; use crate::tool::GoalToolExecutor; @@ -221,6 +226,25 @@ where } } +impl ToolLifecycleContributor for GoalExtension +where + C: Send + Sync + 'static, +{ + fn on_tool_finish<'a>(&'a self, input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(async move { + let _should_count_for_goal_progress = goal_enabled(input.thread_store) + && tool_attempt_counts_for_goal_progress(input.outcome) + && !(input.tool_name.namespace.is_none() + && input.tool_name.name == UPDATE_GOAL_TOOL_NAME); + + // TODO: commit active goal progress through host goal storage and emit + // ThreadGoalUpdated when the persisted goal changes. This replaces + // GoalRuntimeEvent::ToolCompleted once the goal extension owns runtime + // accounting. + }) + } +} + // TODO: app-server initiated goal set/clear operations need a contributor or // backend callback here. They currently happen outside thread/turn/token // lifecycle, but the goal extension must observe them to account before @@ -288,6 +312,7 @@ pub fn install_with_backend( registry.config_contributor(extension.clone()); registry.turn_lifecycle_contributor(extension.clone()); registry.token_usage_contributor(extension.clone()); + registry.tool_lifecycle_contributor(extension.clone()); registry.tool_contributor(extension); } @@ -300,3 +325,17 @@ fn goal_enabled(thread_store: &ExtensionData) -> bool { fn accounting_state(thread_store: &ExtensionData) -> Arc { thread_store.get_or_init::(GoalAccountingState::default) } + +fn tool_attempt_counts_for_goal_progress(outcome: ToolCallOutcome) -> bool { + match outcome { + ToolCallOutcome::Completed { .. } => true, + ToolCallOutcome::Failed { + handler_executed: true, + } => true, + ToolCallOutcome::Blocked + | ToolCallOutcome::Failed { + handler_executed: false, + } + | ToolCallOutcome::Aborted => false, + } +}