mirror of
https://github.com/openai/codex.git
synced 2026-06-01 19:02:59 +00:00
Add tool lifecycle extension contributor (#23309)
## Why Extensions that need to track runtime progress currently have no typed host signal for tool execution. The goal extension in particular needs to observe tool attempts without inspecting tool payloads, owning tool implementations, or staying coupled to core-only runtime plumbing. This adds a narrow lifecycle contributor API for host-owned tool execution: extensions can observe when an accepted tool call starts and how it finishes, while policy hooks and tool handlers continue to own payload rewriting, blocking, and execution. Relevant code: - [`ToolLifecycleContributor`](3ad2850ffc/codex-rs/ext/extension-api/src/contributors.rs (L119)) defines the extension-facing observer contract. - [`tool_lifecycle.rs`](3ad2850ffc/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs) defines the typed start/finish inputs, source, and outcome enums. - [`notify_tool_start` / `notify_tool_finish`](3ad2850ffc/codex-rs/core/src/tools/lifecycle.rs) bridges core tool dispatch into the extension registry. ## What Changed - Added `ToolLifecycleContributor` to `codex-extension-api`, including: - `ToolStartInput` - `ToolFinishInput` - `ToolCallSource` - `ToolCallOutcome` - Added registration and lookup support on `ExtensionRegistryBuilder` / `ExtensionRegistry`. - Wired core tool dispatch to notify lifecycle contributors for: - accepted tool starts - completed tool calls, including the tool output success marker - pre-tool-use blocks - failures before or after the handler runs - cancellation/abort in the parallel tool path - Registered the goal extension as a lifecycle contributor and added the outcome filter it will use for goal progress accounting. ## Test Coverage - Added `dispatch_notifies_tool_lifecycle_contributors` to cover lifecycle notification ordering and outcomes for successful and handler-failed tool calls.
This commit is contained in:
98
codex-rs/core/src/tools/lifecycle.rs
Normal file
98
codex-rs/core/src/tools/lifecycle.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use codex_extension_api::ToolCallOutcome;
|
||||
use codex_extension_api::ToolCallSource as ExtensionToolCallSource;
|
||||
use codex_extension_api::ToolFinishInput;
|
||||
use codex_extension_api::ToolStartInput;
|
||||
use codex_tools::ToolName;
|
||||
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::tools::context::ToolCallSource;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
|
||||
pub(crate) async fn notify_tool_start(invocation: &ToolInvocation) {
|
||||
for contributor in invocation
|
||||
.session
|
||||
.services
|
||||
.extensions
|
||||
.tool_lifecycle_contributors()
|
||||
{
|
||||
contributor
|
||||
.on_tool_start(ToolStartInput {
|
||||
session_store: &invocation.session.services.session_extension_data,
|
||||
thread_store: &invocation.session.services.thread_extension_data,
|
||||
turn_store: invocation.turn.extension_data.as_ref(),
|
||||
turn_id: invocation.turn.sub_id.as_str(),
|
||||
call_id: invocation.call_id.as_str(),
|
||||
tool_name: &invocation.tool_name,
|
||||
source: extension_tool_call_source(invocation.source.clone()),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_tool_finish(invocation: &ToolInvocation, outcome: ToolCallOutcome) {
|
||||
notify_tool_finish_parts(
|
||||
invocation.session.as_ref(),
|
||||
invocation.turn.as_ref(),
|
||||
invocation.call_id.as_str(),
|
||||
&invocation.tool_name,
|
||||
invocation.source.clone(),
|
||||
outcome,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_tool_aborted(
|
||||
session: &Session,
|
||||
turn: &TurnContext,
|
||||
call_id: &str,
|
||||
tool_name: &ToolName,
|
||||
source: ToolCallSource,
|
||||
) {
|
||||
notify_tool_finish_parts(
|
||||
session,
|
||||
turn,
|
||||
call_id,
|
||||
tool_name,
|
||||
source,
|
||||
ToolCallOutcome::Aborted,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
async fn notify_tool_finish_parts(
|
||||
session: &Session,
|
||||
turn: &TurnContext,
|
||||
call_id: &str,
|
||||
tool_name: &ToolName,
|
||||
source: ToolCallSource,
|
||||
outcome: ToolCallOutcome,
|
||||
) {
|
||||
for contributor in session.services.extensions.tool_lifecycle_contributors() {
|
||||
contributor
|
||||
.on_tool_finish(ToolFinishInput {
|
||||
session_store: &session.services.session_extension_data,
|
||||
thread_store: &session.services.thread_extension_data,
|
||||
turn_store: turn.extension_data.as_ref(),
|
||||
turn_id: turn.sub_id.as_str(),
|
||||
call_id,
|
||||
tool_name,
|
||||
source: extension_tool_call_source(source.clone()),
|
||||
outcome,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
fn extension_tool_call_source(source: ToolCallSource) -> ExtensionToolCallSource {
|
||||
match source {
|
||||
ToolCallSource::Direct => ExtensionToolCallSource::Direct,
|
||||
ToolCallSource::CodeMode {
|
||||
cell_id,
|
||||
runtime_tool_call_id,
|
||||
} => ExtensionToolCallSource::CodeMode {
|
||||
cell_id,
|
||||
runtime_tool_call_id,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ pub(crate) mod events;
|
||||
pub(crate) mod handlers;
|
||||
pub(crate) mod hook_names;
|
||||
pub(crate) mod hosted_spec;
|
||||
pub(crate) mod lifecycle;
|
||||
pub(crate) mod network_approval;
|
||||
pub(crate) mod orchestrator;
|
||||
pub(crate) mod parallel;
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Instant;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::task::JoinError;
|
||||
use tokio_util::either::Either;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::AbortOnDropHandle;
|
||||
@@ -15,6 +18,7 @@ use crate::session::turn_context::TurnContext;
|
||||
use crate::tools::context::AbortedToolOutput;
|
||||
use crate::tools::context::SharedTurnDiffTracker;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::lifecycle::notify_tool_aborted;
|
||||
use crate::tools::registry::AnyToolResult;
|
||||
use crate::tools::registry::ToolArgumentDiffConsumer;
|
||||
use crate::tools::router::ToolCall;
|
||||
@@ -89,6 +93,12 @@ impl ToolCallRuntime {
|
||||
let lock = Arc::clone(&self.parallel_execution);
|
||||
let invocation_cancellation_token = cancellation_token.clone();
|
||||
let started = Instant::now();
|
||||
let abort_session = Arc::clone(&session);
|
||||
let abort_source = source.clone();
|
||||
let abort_turn = Arc::clone(&turn);
|
||||
let terminal_outcome_reached = Arc::new(AtomicBool::new(false));
|
||||
let dispatch_terminal_outcome_reached = Arc::clone(&terminal_outcome_reached);
|
||||
let dispatch_call = call.clone();
|
||||
|
||||
let dispatch_span = trace_span!(
|
||||
"dispatch_tool_call_with_code_mode_result",
|
||||
@@ -97,47 +107,69 @@ impl ToolCallRuntime {
|
||||
call_id = call.call_id.as_str(),
|
||||
aborted = false,
|
||||
);
|
||||
let abort_dispatch_span = dispatch_span.clone();
|
||||
|
||||
let handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
|
||||
let mut handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
|
||||
AbortOnDropHandle::new(tokio::spawn(async move {
|
||||
tokio::select! {
|
||||
_ = cancellation_token.cancelled() => {
|
||||
let secs = started.elapsed().as_secs_f32().max(0.1);
|
||||
dispatch_span.record("aborted", true);
|
||||
Ok(Self::aborted_response(&call, secs))
|
||||
},
|
||||
res = async {
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
let _guard = if supports_parallel {
|
||||
Either::Left(lock.read().await)
|
||||
} else {
|
||||
Either::Right(lock.write().await)
|
||||
};
|
||||
|
||||
router
|
||||
.dispatch_tool_call_with_code_mode_result(
|
||||
session,
|
||||
turn,
|
||||
invocation_cancellation_token,
|
||||
tracker,
|
||||
call.clone(),
|
||||
source,
|
||||
)
|
||||
.instrument(dispatch_span.clone())
|
||||
.await
|
||||
} => res,
|
||||
}
|
||||
router
|
||||
.dispatch_tool_call_with_terminal_outcome(
|
||||
session,
|
||||
turn,
|
||||
invocation_cancellation_token,
|
||||
tracker,
|
||||
dispatch_call,
|
||||
source,
|
||||
dispatch_terminal_outcome_reached,
|
||||
)
|
||||
.instrument(dispatch_span.clone())
|
||||
.await
|
||||
}));
|
||||
|
||||
async move {
|
||||
handle.await.map_err(|err| {
|
||||
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
|
||||
})?
|
||||
tokio::select! {
|
||||
res = &mut handle => res.map_err(Self::tool_task_join_error)?,
|
||||
_ = cancellation_token.cancelled() => {
|
||||
if terminal_outcome_reached.load(Ordering::Acquire) || handle.is_finished() {
|
||||
handle.await.map_err(Self::tool_task_join_error)?
|
||||
} else {
|
||||
let secs = started.elapsed().as_secs_f32().max(0.1);
|
||||
abort_dispatch_span.record("aborted", true);
|
||||
handle.abort();
|
||||
match handle.await {
|
||||
Ok(result) => result,
|
||||
Err(err) if err.is_cancelled() => {
|
||||
let response = Self::aborted_response(&call, secs);
|
||||
notify_tool_aborted(
|
||||
abort_session.as_ref(),
|
||||
abort_turn.as_ref(),
|
||||
call.call_id.as_str(),
|
||||
&call.tool_name,
|
||||
abort_source,
|
||||
)
|
||||
.await;
|
||||
Ok(response)
|
||||
}
|
||||
Err(err) => Err(Self::tool_task_join_error(err)),
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
.in_current_span()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToolCallRuntime {
|
||||
fn tool_task_join_error(err: JoinError) -> FunctionCallError {
|
||||
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
|
||||
}
|
||||
|
||||
fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
|
||||
let message = err.to_string();
|
||||
match call.payload {
|
||||
@@ -189,3 +221,147 @@ impl ToolCallRuntime {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::tools::context::FunctionToolOutput;
|
||||
use crate::tools::context::ToolInvocation;
|
||||
use crate::tools::registry::CoreToolRuntime;
|
||||
use crate::tools::registry::ToolExecutor;
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::turn_diff_tracker::TurnDiffTracker;
|
||||
use codex_extension_api::ToolCallOutcome;
|
||||
use codex_protocol::models::FunctionCallOutputBody;
|
||||
use codex_protocol::models::FunctionCallOutputPayload;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::oneshot;
|
||||
|
||||
struct ImmediateHandler {
|
||||
tool_name: codex_tools::ToolName,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ToolExecutor<ToolInvocation> for ImmediateHandler {
|
||||
fn tool_name(&self) -> codex_tools::ToolName {
|
||||
self.tool_name.clone()
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
_invocation: ToolInvocation,
|
||||
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
|
||||
Ok(Box::new(FunctionToolOutput::from_text(
|
||||
"ok".to_string(),
|
||||
Some(true),
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
impl CoreToolRuntime for ImmediateHandler {}
|
||||
|
||||
struct BlockingFinishContributor {
|
||||
records: Arc<std::sync::Mutex<Vec<ToolCallOutcome>>>,
|
||||
finish_started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
|
||||
allow_finish: Arc<Notify>,
|
||||
}
|
||||
|
||||
impl codex_extension_api::ToolLifecycleContributor for BlockingFinishContributor {
|
||||
fn on_tool_finish<'a>(
|
||||
&'a self,
|
||||
input: codex_extension_api::ToolFinishInput<'a>,
|
||||
) -> codex_extension_api::ToolLifecycleFuture<'a> {
|
||||
let records = Arc::clone(&self.records);
|
||||
let allow_finish = Arc::clone(&self.allow_finish);
|
||||
let finish_started = self
|
||||
.finish_started
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.take();
|
||||
let outcome = input.outcome;
|
||||
Box::pin(async move {
|
||||
if let Some(finish_started) = finish_started {
|
||||
let _ = finish_started.send(());
|
||||
}
|
||||
allow_finish.notified().await;
|
||||
records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(outcome);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cancellation_after_handler_finishes_preserves_completed_lifecycle()
|
||||
-> anyhow::Result<()> {
|
||||
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
|
||||
let records = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let (finish_started_tx, finish_started_rx) = oneshot::channel();
|
||||
let allow_finish = Arc::new(Notify::new());
|
||||
let mut builder =
|
||||
codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
|
||||
builder.tool_lifecycle_contributor(Arc::new(BlockingFinishContributor {
|
||||
records: Arc::clone(&records),
|
||||
finish_started: std::sync::Mutex::new(Some(finish_started_tx)),
|
||||
allow_finish: Arc::clone(&allow_finish),
|
||||
}));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
|
||||
let session = Arc::new(session);
|
||||
let turn_context = Arc::new(turn_context);
|
||||
let tool_name = codex_tools::ToolName::plain("test_tool");
|
||||
let handler = Arc::new(ImmediateHandler {
|
||||
tool_name: tool_name.clone(),
|
||||
}) as Arc<dyn CoreToolRuntime>;
|
||||
let router = Arc::new(ToolRouter::from_parts(
|
||||
ToolRegistry::from_tools([handler]),
|
||||
Vec::new(),
|
||||
));
|
||||
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
let runtime = ToolCallRuntime::new(router, session, turn_context, tracker);
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let call = ToolCall {
|
||||
tool_name,
|
||||
call_id: "call-1".to_string(),
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let response_task =
|
||||
tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone()));
|
||||
tokio::time::timeout(Duration::from_secs(1), finish_started_rx)
|
||||
.await
|
||||
.expect("timed out waiting for lifecycle notification to start")
|
||||
.expect("lifecycle notification should start");
|
||||
cancellation_token.cancel();
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
allow_finish.notify_waiters();
|
||||
|
||||
let response = tokio::time::timeout(Duration::from_secs(1), response_task)
|
||||
.await
|
||||
.expect("timed out waiting for tool response")
|
||||
.expect("tool response task should join")?;
|
||||
let expected_response = ResponseInputItem::FunctionCallOutput {
|
||||
call_id: "call-1".to_string(),
|
||||
output: FunctionCallOutputPayload {
|
||||
body: FunctionCallOutputBody::Text("ok".to_string()),
|
||||
success: Some(true),
|
||||
},
|
||||
};
|
||||
assert_eq!(expected_response, response);
|
||||
|
||||
let actual = records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(vec![ToolCallOutcome::Completed { success: true }], actual);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::function_tool::FunctionCallError;
|
||||
@@ -18,9 +20,12 @@ use crate::tools::context::ToolOutput;
|
||||
use crate::tools::context::ToolPayload;
|
||||
use crate::tools::flat_tool_name;
|
||||
use crate::tools::hook_names::HookToolName;
|
||||
use crate::tools::lifecycle::notify_tool_finish;
|
||||
use crate::tools::lifecycle::notify_tool_start;
|
||||
use crate::tools::tool_dispatch_trace::ToolDispatchTrace;
|
||||
use crate::tools::tool_search_entry::ToolSearchInfo;
|
||||
use crate::util::error_or_panic;
|
||||
use codex_extension_api::ToolCallOutcome;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_tools::ToolName;
|
||||
@@ -298,13 +303,23 @@ impl ToolRegistry {
|
||||
Some(tool.supports_parallel_tool_calls())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) async fn dispatch_any(
|
||||
&self,
|
||||
invocation: ToolInvocation,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
self.dispatch_any_with_terminal_outcome(invocation, /*terminal_outcome_reached*/ None)
|
||||
.await
|
||||
}
|
||||
|
||||
#[expect(
|
||||
clippy::await_holding_invalid_type,
|
||||
reason = "tool dispatch must keep active-turn accounting atomic"
|
||||
)]
|
||||
pub(crate) async fn dispatch_any(
|
||||
pub(crate) async fn dispatch_any_with_terminal_outcome(
|
||||
&self,
|
||||
mut invocation: ToolInvocation,
|
||||
terminal_outcome_reached: Option<Arc<AtomicBool>>,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
let tool_name = invocation.tool_name.clone();
|
||||
let tool_name_flat = flat_tool_name(&tool_name);
|
||||
@@ -389,6 +404,8 @@ impl ToolRegistry {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
notify_tool_start(&invocation).await;
|
||||
|
||||
if let Some(pre_tool_use_payload) = tool.pre_tool_use_payload(&invocation) {
|
||||
match run_pre_tool_use_hooks(
|
||||
&invocation.session,
|
||||
@@ -402,13 +419,33 @@ impl ToolRegistry {
|
||||
PreToolUseHookResult::Blocked(message) => {
|
||||
let err = FunctionCallError::RespondToModel(message);
|
||||
dispatch_trace.record_failed(&err);
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(&invocation, ToolCallOutcome::Blocked).await;
|
||||
return Err(err);
|
||||
}
|
||||
PreToolUseHookResult::Continue {
|
||||
updated_input: Some(updated_input),
|
||||
} => {
|
||||
invocation = tool.with_updated_hook_input(invocation, updated_input)?;
|
||||
}
|
||||
} => match tool.with_updated_hook_input(invocation.clone(), updated_input) {
|
||||
Ok(updated_invocation) => {
|
||||
invocation = updated_invocation;
|
||||
}
|
||||
Err(err) => {
|
||||
dispatch_trace.record_failed(&err);
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(
|
||||
&invocation,
|
||||
ToolCallOutcome::Failed {
|
||||
handler_executed: false,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
return Err(err);
|
||||
}
|
||||
},
|
||||
PreToolUseHookResult::Continue {
|
||||
updated_input: None,
|
||||
} => {}
|
||||
@@ -503,6 +540,27 @@ impl ToolRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
let lifecycle_outcome = match &result {
|
||||
Ok(_) => {
|
||||
let guard = response_cell.lock().await;
|
||||
match guard.as_ref() {
|
||||
Some(result) => ToolCallOutcome::Completed {
|
||||
success: result.result.success_for_logging(),
|
||||
},
|
||||
None => ToolCallOutcome::Failed {
|
||||
handler_executed: true,
|
||||
},
|
||||
}
|
||||
}
|
||||
Err(_) => ToolCallOutcome::Failed {
|
||||
handler_executed: true,
|
||||
},
|
||||
};
|
||||
if let Some(terminal_outcome_reached) = &terminal_outcome_reached {
|
||||
terminal_outcome_reached.store(true, Ordering::Release);
|
||||
}
|
||||
notify_tool_finish(&invocation, lifecycle_outcome).await;
|
||||
|
||||
if let Err(err) = invocation
|
||||
.session
|
||||
.goal_runtime_apply(GoalRuntimeEvent::ToolCompleted {
|
||||
|
||||
@@ -23,6 +23,97 @@ impl ToolExecutor<ToolInvocation> for TestHandler {
|
||||
|
||||
impl CoreToolRuntime for TestHandler {}
|
||||
|
||||
#[derive(Clone)]
|
||||
enum LifecycleTestResult {
|
||||
Ok { success: bool },
|
||||
Err,
|
||||
}
|
||||
|
||||
struct LifecycleTestHandler {
|
||||
tool_name: codex_tools::ToolName,
|
||||
result: LifecycleTestResult,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ToolExecutor<ToolInvocation> for LifecycleTestHandler {
|
||||
fn tool_name(&self) -> codex_tools::ToolName {
|
||||
self.tool_name.clone()
|
||||
}
|
||||
|
||||
async fn handle(
|
||||
&self,
|
||||
_invocation: ToolInvocation,
|
||||
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
|
||||
match self.result.clone() {
|
||||
LifecycleTestResult::Ok { success } => Ok(Box::new(
|
||||
crate::tools::context::FunctionToolOutput::from_text(
|
||||
"ok".to_string(),
|
||||
Some(success),
|
||||
),
|
||||
)),
|
||||
LifecycleTestResult::Err => Err(FunctionCallError::RespondToModel(
|
||||
"handler failed".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CoreToolRuntime for LifecycleTestHandler {}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
enum RecordedToolLifecycle {
|
||||
Start {
|
||||
call_id: String,
|
||||
tool_name: codex_tools::ToolName,
|
||||
},
|
||||
Finish {
|
||||
call_id: String,
|
||||
tool_name: codex_tools::ToolName,
|
||||
outcome: codex_extension_api::ToolCallOutcome,
|
||||
},
|
||||
}
|
||||
|
||||
struct ToolLifecycleRecorder {
|
||||
records: Arc<std::sync::Mutex<Vec<RecordedToolLifecycle>>>,
|
||||
}
|
||||
|
||||
impl codex_extension_api::ToolLifecycleContributor for ToolLifecycleRecorder {
|
||||
fn on_tool_start<'a>(
|
||||
&'a self,
|
||||
input: codex_extension_api::ToolStartInput<'a>,
|
||||
) -> codex_extension_api::ToolLifecycleFuture<'a> {
|
||||
let records = Arc::clone(&self.records);
|
||||
let record = RecordedToolLifecycle::Start {
|
||||
call_id: input.call_id.to_string(),
|
||||
tool_name: input.tool_name.clone(),
|
||||
};
|
||||
Box::pin(async move {
|
||||
records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(record);
|
||||
})
|
||||
}
|
||||
|
||||
fn on_tool_finish<'a>(
|
||||
&'a self,
|
||||
input: codex_extension_api::ToolFinishInput<'a>,
|
||||
) -> codex_extension_api::ToolLifecycleFuture<'a> {
|
||||
let records = Arc::clone(&self.records);
|
||||
let record = RecordedToolLifecycle::Finish {
|
||||
call_id: input.call_id.to_string(),
|
||||
tool_name: input.tool_name.clone(),
|
||||
outcome: input.outcome,
|
||||
};
|
||||
Box::pin(async move {
|
||||
records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.push(record);
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handler_looks_up_namespaced_aliases_explicitly() {
|
||||
let namespace = "mcp__codex_apps__gmail";
|
||||
@@ -61,3 +152,106 @@ fn handler_looks_up_namespaced_aliases_explicitly() {
|
||||
.is_some_and(|handler| Arc::ptr_eq(handler, &namespaced_handler))
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn dispatch_notifies_tool_lifecycle_contributors() -> anyhow::Result<()> {
|
||||
let (mut session, turn) = crate::session::tests::make_session_and_context().await;
|
||||
let records = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let mut builder = codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
|
||||
builder.tool_lifecycle_contributor(Arc::new(ToolLifecycleRecorder {
|
||||
records: Arc::clone(&records),
|
||||
}));
|
||||
session.services.extensions = Arc::new(builder.build());
|
||||
|
||||
let ok_tool = codex_tools::ToolName::plain("ok_tool");
|
||||
let failing_tool = codex_tools::ToolName::plain("failing_tool");
|
||||
let ok_handler = Arc::new(LifecycleTestHandler {
|
||||
tool_name: ok_tool.clone(),
|
||||
result: LifecycleTestResult::Ok { success: false },
|
||||
}) as Arc<dyn CoreToolRuntime>;
|
||||
let failing_handler = Arc::new(LifecycleTestHandler {
|
||||
tool_name: failing_tool.clone(),
|
||||
result: LifecycleTestResult::Err,
|
||||
}) as Arc<dyn CoreToolRuntime>;
|
||||
let registry = ToolRegistry::new(HashMap::from([
|
||||
(ok_tool.clone(), ok_handler),
|
||||
(failing_tool.clone(), failing_handler),
|
||||
]));
|
||||
let session = Arc::new(session);
|
||||
let turn = Arc::new(turn);
|
||||
|
||||
registry
|
||||
.dispatch_any(test_invocation(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
"ok-call",
|
||||
ok_tool.clone(),
|
||||
))
|
||||
.await?;
|
||||
let err = match registry
|
||||
.dispatch_any(test_invocation(
|
||||
Arc::clone(&session),
|
||||
Arc::clone(&turn),
|
||||
"failing-call",
|
||||
failing_tool.clone(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
Ok(_) => panic!("failing handler should return an error"),
|
||||
Err(err) => err,
|
||||
};
|
||||
assert_eq!(err.to_string(), "handler failed");
|
||||
|
||||
let expected = vec![
|
||||
RecordedToolLifecycle::Start {
|
||||
call_id: "ok-call".to_string(),
|
||||
tool_name: ok_tool.clone(),
|
||||
},
|
||||
RecordedToolLifecycle::Finish {
|
||||
call_id: "ok-call".to_string(),
|
||||
tool_name: ok_tool,
|
||||
outcome: codex_extension_api::ToolCallOutcome::Completed { success: false },
|
||||
},
|
||||
RecordedToolLifecycle::Start {
|
||||
call_id: "failing-call".to_string(),
|
||||
tool_name: failing_tool.clone(),
|
||||
},
|
||||
RecordedToolLifecycle::Finish {
|
||||
call_id: "failing-call".to_string(),
|
||||
tool_name: failing_tool,
|
||||
outcome: codex_extension_api::ToolCallOutcome::Failed {
|
||||
handler_executed: true,
|
||||
},
|
||||
},
|
||||
];
|
||||
let actual = records
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
.drain(..)
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(expected, actual);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_invocation(
|
||||
session: Arc<crate::session::session::Session>,
|
||||
turn: Arc<crate::session::turn_context::TurnContext>,
|
||||
call_id: &str,
|
||||
tool_name: codex_tools::ToolName,
|
||||
) -> ToolInvocation {
|
||||
ToolInvocation {
|
||||
session,
|
||||
turn,
|
||||
cancellation_token: tokio_util::sync::CancellationToken::new(),
|
||||
tracker: Arc::new(tokio::sync::Mutex::new(
|
||||
crate::turn_diff_tracker::TurnDiffTracker::new(),
|
||||
)),
|
||||
call_id: call_id.to_string(),
|
||||
tool_name,
|
||||
source: crate::tools::context::ToolCallSource::Direct,
|
||||
payload: ToolPayload::Function {
|
||||
arguments: "{}".to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ use codex_tools::ToolName;
|
||||
use codex_tools::ToolSpec;
|
||||
use codex_tools::ToolsConfig;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::instrument;
|
||||
|
||||
@@ -123,6 +124,7 @@ impl ToolRouter {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
pub async fn dispatch_tool_call_with_code_mode_result(
|
||||
&self,
|
||||
@@ -132,6 +134,53 @@ impl ToolRouter {
|
||||
tracker: SharedTurnDiffTracker,
|
||||
call: ToolCall,
|
||||
source: ToolCallSource,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
self.dispatch_tool_call_with_code_mode_result_inner(
|
||||
session,
|
||||
turn,
|
||||
cancellation_token,
|
||||
tracker,
|
||||
call,
|
||||
source,
|
||||
/*terminal_outcome_reached*/ None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[instrument(level = "trace", skip_all, err)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn dispatch_tool_call_with_terminal_outcome(
|
||||
&self,
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
cancellation_token: CancellationToken,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
call: ToolCall,
|
||||
source: ToolCallSource,
|
||||
terminal_outcome_reached: Arc<AtomicBool>,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
self.dispatch_tool_call_with_code_mode_result_inner(
|
||||
session,
|
||||
turn,
|
||||
cancellation_token,
|
||||
tracker,
|
||||
call,
|
||||
source,
|
||||
Some(terminal_outcome_reached),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn dispatch_tool_call_with_code_mode_result_inner(
|
||||
&self,
|
||||
session: Arc<Session>,
|
||||
turn: Arc<TurnContext>,
|
||||
cancellation_token: CancellationToken,
|
||||
tracker: SharedTurnDiffTracker,
|
||||
call: ToolCall,
|
||||
source: ToolCallSource,
|
||||
terminal_outcome_reached: Option<Arc<AtomicBool>>,
|
||||
) -> Result<AnyToolResult, FunctionCallError> {
|
||||
let ToolCall {
|
||||
tool_name,
|
||||
@@ -150,7 +199,9 @@ impl ToolRouter {
|
||||
payload,
|
||||
};
|
||||
|
||||
self.registry.dispatch_any(invocation).await
|
||||
self.registry
|
||||
.dispatch_any_with_terminal_outcome(invocation, terminal_outcome_reached)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user