mirror of
https://github.com/openai/codex.git
synced 2026-05-27 06:25:48 +00:00
## Why

Extensions that need to track runtime progress currently have no typed host signal for tool execution. The goal extension in particular needs to observe tool attempts without inspecting tool payloads, owning tool implementations, or staying coupled to core-only runtime plumbing.

This adds a narrow lifecycle contributor API for host-owned tool execution: extensions can observe when an accepted tool call starts and how it finishes, while policy hooks and tool handlers continue to own payload rewriting, blocking, and execution.

Relevant code:

- [`ToolLifecycleContributor`](3ad2850ffc/codex-rs/ext/extension-api/src/contributors.rs#L119) defines the extension-facing observer contract.
- [`tool_lifecycle.rs`](3ad2850ffc/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs) defines the typed start/finish inputs, source, and outcome enums.
- [`notify_tool_start` / `notify_tool_finish`](3ad2850ffc/codex-rs/core/src/tools/lifecycle.rs) bridge core tool dispatch into the extension registry.

## What Changed

- Added `ToolLifecycleContributor` to `codex-extension-api`, including:
  - `ToolStartInput`
  - `ToolFinishInput`
  - `ToolCallSource`
  - `ToolCallOutcome`
- Added registration and lookup support on `ExtensionRegistryBuilder` / `ExtensionRegistry`.
- Wired core tool dispatch to notify lifecycle contributors for:
  - accepted tool starts
  - completed tool calls, including the tool output success marker
  - pre-tool-use blocks
  - failures before or after the handler runs
  - cancellation/abort in the parallel tool path
- Registered the goal extension as a lifecycle contributor and added the outcome filter it will use for goal progress accounting.

## Test Coverage

- Added `dispatch_notifies_tool_lifecycle_contributors` to cover lifecycle notification ordering and outcomes for successful and handler-failed tool calls.
368 lines
14 KiB
Rust
368 lines
14 KiB
Rust
use std::sync::Arc;
|
|
use std::sync::atomic::AtomicBool;
|
|
use std::sync::atomic::Ordering;
|
|
use std::time::Instant;
|
|
|
|
use tokio::sync::RwLock;
|
|
use tokio::task::JoinError;
|
|
use tokio_util::either::Either;
|
|
use tokio_util::sync::CancellationToken;
|
|
use tokio_util::task::AbortOnDropHandle;
|
|
use tracing::Instrument;
|
|
use tracing::instrument;
|
|
use tracing::trace_span;
|
|
|
|
use crate::function_tool::FunctionCallError;
|
|
use crate::session::session::Session;
|
|
use crate::session::turn_context::TurnContext;
|
|
use crate::tools::context::AbortedToolOutput;
|
|
use crate::tools::context::SharedTurnDiffTracker;
|
|
use crate::tools::context::ToolPayload;
|
|
use crate::tools::lifecycle::notify_tool_aborted;
|
|
use crate::tools::registry::AnyToolResult;
|
|
use crate::tools::registry::ToolArgumentDiffConsumer;
|
|
use crate::tools::router::ToolCall;
|
|
use crate::tools::router::ToolCallSource;
|
|
use crate::tools::router::ToolRouter;
|
|
use codex_protocol::error::CodexErr;
|
|
use codex_protocol::models::ResponseInputItem;
|
|
|
|
#[derive(Clone)]
|
|
pub(crate) struct ToolCallRuntime {
|
|
router: Arc<ToolRouter>,
|
|
session: Arc<Session>,
|
|
turn_context: Arc<TurnContext>,
|
|
tracker: SharedTurnDiffTracker,
|
|
parallel_execution: Arc<RwLock<()>>,
|
|
}
|
|
|
|
impl ToolCallRuntime {
|
|
pub(crate) fn new(
|
|
router: Arc<ToolRouter>,
|
|
session: Arc<Session>,
|
|
turn_context: Arc<TurnContext>,
|
|
tracker: SharedTurnDiffTracker,
|
|
) -> Self {
|
|
Self {
|
|
router,
|
|
session,
|
|
turn_context,
|
|
tracker,
|
|
parallel_execution: Arc::new(RwLock::new(())),
|
|
}
|
|
}
|
|
|
|
pub(crate) fn create_diff_consumer(
|
|
&self,
|
|
tool_name: &codex_tools::ToolName,
|
|
) -> Option<Box<dyn ToolArgumentDiffConsumer>> {
|
|
self.router.create_diff_consumer(tool_name)
|
|
}
|
|
|
|
#[instrument(level = "trace", skip_all)]
|
|
pub(crate) fn handle_tool_call(
|
|
self,
|
|
call: ToolCall,
|
|
cancellation_token: CancellationToken,
|
|
) -> impl std::future::Future<Output = Result<ResponseInputItem, CodexErr>> {
|
|
let error_call = call.clone();
|
|
let future =
|
|
self.handle_tool_call_with_source(call, ToolCallSource::Direct, cancellation_token);
|
|
async move {
|
|
match future.await {
|
|
Ok(response) => Ok(response.into_response()),
|
|
Err(FunctionCallError::Fatal(message)) => Err(CodexErr::Fatal(message)),
|
|
Err(other) => Ok(Self::failure_response(error_call, other)),
|
|
}
|
|
}
|
|
.in_current_span()
|
|
}
|
|
|
|
#[instrument(level = "trace", skip_all)]
|
|
pub(crate) fn handle_tool_call_with_source(
|
|
self,
|
|
call: ToolCall,
|
|
source: ToolCallSource,
|
|
cancellation_token: CancellationToken,
|
|
) -> impl std::future::Future<Output = Result<AnyToolResult, FunctionCallError>> {
|
|
let supports_parallel = self.router.tool_supports_parallel(&call);
|
|
let router = Arc::clone(&self.router);
|
|
let session = Arc::clone(&self.session);
|
|
let turn = Arc::clone(&self.turn_context);
|
|
let tracker = Arc::clone(&self.tracker);
|
|
let lock = Arc::clone(&self.parallel_execution);
|
|
let invocation_cancellation_token = cancellation_token.clone();
|
|
let started = Instant::now();
|
|
let abort_session = Arc::clone(&session);
|
|
let abort_source = source.clone();
|
|
let abort_turn = Arc::clone(&turn);
|
|
let terminal_outcome_reached = Arc::new(AtomicBool::new(false));
|
|
let dispatch_terminal_outcome_reached = Arc::clone(&terminal_outcome_reached);
|
|
let dispatch_call = call.clone();
|
|
|
|
let dispatch_span = trace_span!(
|
|
"dispatch_tool_call_with_code_mode_result",
|
|
otel.name = %call.tool_name,
|
|
tool_name = %call.tool_name,
|
|
call_id = call.call_id.as_str(),
|
|
aborted = false,
|
|
);
|
|
let abort_dispatch_span = dispatch_span.clone();
|
|
|
|
let mut handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
|
|
AbortOnDropHandle::new(tokio::spawn(async move {
|
|
let _guard = if supports_parallel {
|
|
Either::Left(lock.read().await)
|
|
} else {
|
|
Either::Right(lock.write().await)
|
|
};
|
|
|
|
router
|
|
.dispatch_tool_call_with_terminal_outcome(
|
|
session,
|
|
turn,
|
|
invocation_cancellation_token,
|
|
tracker,
|
|
dispatch_call,
|
|
source,
|
|
dispatch_terminal_outcome_reached,
|
|
)
|
|
.instrument(dispatch_span.clone())
|
|
.await
|
|
}));
|
|
|
|
async move {
|
|
tokio::select! {
|
|
res = &mut handle => res.map_err(Self::tool_task_join_error)?,
|
|
_ = cancellation_token.cancelled() => {
|
|
if terminal_outcome_reached.load(Ordering::Acquire) || handle.is_finished() {
|
|
handle.await.map_err(Self::tool_task_join_error)?
|
|
} else {
|
|
let secs = started.elapsed().as_secs_f32().max(0.1);
|
|
abort_dispatch_span.record("aborted", true);
|
|
handle.abort();
|
|
match handle.await {
|
|
Ok(result) => result,
|
|
Err(err) if err.is_cancelled() => {
|
|
let response = Self::aborted_response(&call, secs);
|
|
notify_tool_aborted(
|
|
abort_session.as_ref(),
|
|
abort_turn.as_ref(),
|
|
call.call_id.as_str(),
|
|
&call.tool_name,
|
|
abort_source,
|
|
)
|
|
.await;
|
|
Ok(response)
|
|
}
|
|
Err(err) => Err(Self::tool_task_join_error(err)),
|
|
}
|
|
}
|
|
},
|
|
}
|
|
}
|
|
.in_current_span()
|
|
}
|
|
}
|
|
|
|
impl ToolCallRuntime {
|
|
fn tool_task_join_error(err: JoinError) -> FunctionCallError {
|
|
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
|
|
}
|
|
|
|
fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
|
|
let message = err.to_string();
|
|
match call.payload {
|
|
ToolPayload::ToolSearch { .. } => ResponseInputItem::ToolSearchOutput {
|
|
call_id: call.call_id,
|
|
status: "completed".to_string(),
|
|
execution: "client".to_string(),
|
|
tools: Vec::new(),
|
|
},
|
|
ToolPayload::Custom { .. } => ResponseInputItem::CustomToolCallOutput {
|
|
call_id: call.call_id,
|
|
name: None,
|
|
output: codex_protocol::models::FunctionCallOutputPayload {
|
|
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
|
success: Some(false),
|
|
},
|
|
},
|
|
_ => ResponseInputItem::FunctionCallOutput {
|
|
call_id: call.call_id,
|
|
output: codex_protocol::models::FunctionCallOutputPayload {
|
|
body: codex_protocol::models::FunctionCallOutputBody::Text(message),
|
|
success: Some(false),
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
fn aborted_response(call: &ToolCall, secs: f32) -> AnyToolResult {
|
|
AnyToolResult {
|
|
call_id: call.call_id.clone(),
|
|
payload: call.payload.clone(),
|
|
result: Box::new(AbortedToolOutput {
|
|
message: Self::abort_message(call, secs),
|
|
}),
|
|
post_tool_use_payload: None,
|
|
}
|
|
}
|
|
|
|
fn abort_message(call: &ToolCall, secs: f32) -> String {
|
|
if call.tool_name.namespace.is_none()
|
|
&& matches!(
|
|
call.tool_name.name.as_str(),
|
|
"shell_command" | "unified_exec"
|
|
)
|
|
{
|
|
format!("Wall time: {secs:.1} seconds\naborted by user")
|
|
} else {
|
|
format!("aborted by user after {secs:.1}s")
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
use std::time::Duration;
|
|
|
|
use crate::tools::context::FunctionToolOutput;
|
|
use crate::tools::context::ToolInvocation;
|
|
use crate::tools::registry::CoreToolRuntime;
|
|
use crate::tools::registry::ToolExecutor;
|
|
use crate::tools::registry::ToolRegistry;
|
|
use crate::turn_diff_tracker::TurnDiffTracker;
|
|
use codex_extension_api::ToolCallOutcome;
|
|
use codex_protocol::models::FunctionCallOutputBody;
|
|
use codex_protocol::models::FunctionCallOutputPayload;
|
|
use pretty_assertions::assert_eq;
|
|
use tokio::sync::Notify;
|
|
use tokio::sync::oneshot;
|
|
|
|
struct ImmediateHandler {
|
|
tool_name: codex_tools::ToolName,
|
|
}
|
|
|
|
#[async_trait::async_trait]
|
|
impl ToolExecutor<ToolInvocation> for ImmediateHandler {
|
|
fn tool_name(&self) -> codex_tools::ToolName {
|
|
self.tool_name.clone()
|
|
}
|
|
|
|
async fn handle(
|
|
&self,
|
|
_invocation: ToolInvocation,
|
|
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
|
|
Ok(Box::new(FunctionToolOutput::from_text(
|
|
"ok".to_string(),
|
|
Some(true),
|
|
)))
|
|
}
|
|
}
|
|
|
|
// Empty impl: no `CoreToolRuntime` items are overridden for the test handler.
impl CoreToolRuntime for ImmediateHandler {}
|
|
|
|
struct BlockingFinishContributor {
|
|
records: Arc<std::sync::Mutex<Vec<ToolCallOutcome>>>,
|
|
finish_started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
|
|
allow_finish: Arc<Notify>,
|
|
}
|
|
|
|
impl codex_extension_api::ToolLifecycleContributor for BlockingFinishContributor {
|
|
fn on_tool_finish<'a>(
|
|
&'a self,
|
|
input: codex_extension_api::ToolFinishInput<'a>,
|
|
) -> codex_extension_api::ToolLifecycleFuture<'a> {
|
|
let records = Arc::clone(&self.records);
|
|
let allow_finish = Arc::clone(&self.allow_finish);
|
|
let finish_started = self
|
|
.finish_started
|
|
.lock()
|
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
.take();
|
|
let outcome = input.outcome;
|
|
Box::pin(async move {
|
|
if let Some(finish_started) = finish_started {
|
|
let _ = finish_started.send(());
|
|
}
|
|
allow_finish.notified().await;
|
|
records
|
|
.lock()
|
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
.push(outcome);
|
|
})
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn cancellation_after_handler_finishes_preserves_completed_lifecycle()
|
|
-> anyhow::Result<()> {
|
|
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
|
|
let records = Arc::new(std::sync::Mutex::new(Vec::new()));
|
|
let (finish_started_tx, finish_started_rx) = oneshot::channel();
|
|
let allow_finish = Arc::new(Notify::new());
|
|
let mut builder =
|
|
codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
|
|
builder.tool_lifecycle_contributor(Arc::new(BlockingFinishContributor {
|
|
records: Arc::clone(&records),
|
|
finish_started: std::sync::Mutex::new(Some(finish_started_tx)),
|
|
allow_finish: Arc::clone(&allow_finish),
|
|
}));
|
|
session.services.extensions = Arc::new(builder.build());
|
|
|
|
let session = Arc::new(session);
|
|
let turn_context = Arc::new(turn_context);
|
|
let tool_name = codex_tools::ToolName::plain("test_tool");
|
|
let handler = Arc::new(ImmediateHandler {
|
|
tool_name: tool_name.clone(),
|
|
}) as Arc<dyn CoreToolRuntime>;
|
|
let router = Arc::new(ToolRouter::from_parts(
|
|
ToolRegistry::from_tools([handler]),
|
|
Vec::new(),
|
|
));
|
|
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
|
let runtime = ToolCallRuntime::new(router, session, turn_context, tracker);
|
|
let cancellation_token = CancellationToken::new();
|
|
let call = ToolCall {
|
|
tool_name,
|
|
call_id: "call-1".to_string(),
|
|
payload: ToolPayload::Function {
|
|
arguments: "{}".to_string(),
|
|
},
|
|
};
|
|
|
|
let response_task =
|
|
tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone()));
|
|
tokio::time::timeout(Duration::from_secs(1), finish_started_rx)
|
|
.await
|
|
.expect("timed out waiting for lifecycle notification to start")
|
|
.expect("lifecycle notification should start");
|
|
cancellation_token.cancel();
|
|
tokio::time::sleep(Duration::from_millis(10)).await;
|
|
allow_finish.notify_waiters();
|
|
|
|
let response = tokio::time::timeout(Duration::from_secs(1), response_task)
|
|
.await
|
|
.expect("timed out waiting for tool response")
|
|
.expect("tool response task should join")?;
|
|
let expected_response = ResponseInputItem::FunctionCallOutput {
|
|
call_id: "call-1".to_string(),
|
|
output: FunctionCallOutputPayload {
|
|
body: FunctionCallOutputBody::Text("ok".to_string()),
|
|
success: Some(true),
|
|
},
|
|
};
|
|
assert_eq!(expected_response, response);
|
|
|
|
let actual = records
|
|
.lock()
|
|
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
|
.drain(..)
|
|
.collect::<Vec<_>>();
|
|
assert_eq!(vec![ToolCallOutcome::Completed { success: true }], actual);
|
|
|
|
Ok(())
|
|
}
|
|
}
|