From d6978506f3f4fa3f8562d0044602c30d99a231f8 Mon Sep 17 00:00:00 2001 From: canvrno-oai Date: Tue, 12 May 2026 18:47:38 -0700 Subject: [PATCH] Remove OneShotShutdownGate --- codex-rs/core/src/codex_delegate.rs | 37 +------- codex-rs/core/src/codex_delegate_tests.rs | 103 ---------------------- 2 files changed, 4 insertions(+), 136 deletions(-) diff --git a/codex-rs/core/src/codex_delegate.rs b/codex-rs/core/src/codex_delegate.rs index 23291b8189..ecfcc9c44b 100644 --- a/codex-rs/core/src/codex_delegate.rs +++ b/codex-rs/core/src/codex_delegate.rs @@ -10,7 +10,6 @@ use codex_protocol::protocol::Event; use codex_protocol::protocol::EventMsg; use codex_protocol::protocol::ExecApprovalRequestEvent; use codex_protocol::protocol::McpInvocation; -use codex_protocol::protocol::McpStartupStatus; use codex_protocol::protocol::Op; use codex_protocol::protocol::RequestUserInputEvent; use codex_protocol::protocol::ReviewDecision; @@ -207,9 +206,11 @@ fn spawn_one_shot_event_bridge(io: Codex, child_cancel: CancellationToken) -> Co let session_loop_termination = io.session_loop_termination.clone(); let io_for_bridge = io; tokio::spawn(async move { - let mut shutdown_gate = OneShotShutdownGate::default(); while let Ok(event) = io_for_bridge.next_event().await { - let should_shutdown = shutdown_gate.observe(&event.msg); + let should_shutdown = matches!( + event.msg, + EventMsg::TurnComplete(_) | EventMsg::TurnAborted(_) + ); let _ = tx_bridge.send(event).await; if should_shutdown { let _ = ops_tx @@ -414,36 +415,6 @@ async fn forward_event_or_shutdown( } } -#[derive(Default)] -struct OneShotShutdownGate { - turn_finished: bool, - mcp_startup_pending: bool, -} - -impl OneShotShutdownGate { - fn observe(&mut self, msg: &EventMsg) -> bool { - match msg { - EventMsg::McpStartupUpdate(update) - if matches!(&update.status, McpStartupStatus::Starting) => - { - self.mcp_startup_pending = true; - } - EventMsg::McpStartupComplete(_) => { - self.mcp_startup_pending = false; - } - EventMsg::TurnComplete(_) => { - self.turn_finished = true; - } - EventMsg::TurnAborted(_) => { - self.turn_finished = true; - self.mcp_startup_pending = false; - } - _ => {} - } - self.turn_finished && !self.mcp_startup_pending - } -} - /// Forward ops from a caller to a sub-agent, respecting cancellation. async fn forward_ops( codex: Arc, diff --git a/codex-rs/core/src/codex_delegate_tests.rs b/codex-rs/core/src/codex_delegate_tests.rs index 27c18eab5b..392350fa35 100644 --- a/codex-rs/core/src/codex_delegate_tests.rs +++ b/codex-rs/core/src/codex_delegate_tests.rs @@ -21,7 +21,6 @@ use codex_protocol::protocol::RawResponseItemEvent; use codex_protocol::protocol::ReviewDecision; use codex_protocol::protocol::TurnAbortReason; use codex_protocol::protocol::TurnAbortedEvent; -use codex_protocol::protocol::TurnCompleteEvent; use codex_protocol::request_permissions::RequestPermissionProfile; use codex_protocol::request_permissions::RequestPermissionsEvent; use codex_protocol::request_permissions::RequestPermissionsResponse; @@ -195,108 +194,6 @@ async fn forward_events_forwards_mcp_startup_events() { assert_eq!(vec!["starting", "failed", "complete"], received); } -#[test] -fn one_shot_shutdown_waits_for_mcp_startup_complete_after_turn_complete() { - let mut gate = OneShotShutdownGate::default(); - - assert!(!gate.observe(&mcp_startup_update(McpStartupStatus::Starting))); - assert!(!gate.observe(&turn_complete())); - assert!(gate.observe(&mcp_startup_complete())); -} - -#[test] -fn one_shot_shutdown_does_not_wait_when_mcp_startup_is_not_pending() { - let mut gate = OneShotShutdownGate::default(); - - assert!(!gate.observe(&mcp_startup_update(McpStartupStatus::Ready))); - assert!(gate.observe(&turn_complete())); -} - -#[tokio::test] -async fn one_shot_bridge_waits_to_shutdown_until_mcp_startup_complete() { - let (tx_events, rx_events) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY); - let (_agent_status_tx, agent_status) = watch::channel(AgentStatus::PendingInit); - let (session, _ctx, _rx_evt) = crate::session::tests::make_session_and_context_with_rx().await; - let io = Codex { - tx_sub, - rx_event: rx_events, - agent_status, - session, - session_loop_termination: completed_session_loop_termination(), - }; - let child_cancel = CancellationToken::new(); - let bridged = spawn_one_shot_event_bridge(io, child_cancel.clone()); - - tx_events - .send(event( - "starting", - mcp_startup_update(McpStartupStatus::Starting), - )) - .await - .unwrap(); - tx_events - .send(event("turn-complete", turn_complete())) - .await - .unwrap(); - - assert_eq!("starting", next_event_id(&bridged).await); - assert_eq!("turn-complete", next_event_id(&bridged).await); - tokio::task::yield_now().await; - assert!(rx_sub.try_recv().is_err()); - assert!(!child_cancel.is_cancelled()); - - tx_events - .send(event("startup-complete", mcp_startup_complete())) - .await - .unwrap(); - - assert_eq!("startup-complete", next_event_id(&bridged).await); - let shutdown = timeout(Duration::from_secs(1), rx_sub.recv()) - .await - .expect("bridge did not send shutdown") - .expect("shutdown submission missing"); - assert_eq!("shutdown", shutdown.id); - assert!(matches!(shutdown.op, Op::Shutdown)); - assert!(child_cancel.is_cancelled()); -} - -fn event(id: &str, msg: EventMsg) -> Event { - Event { - id: id.to_string(), - msg, - } -} - -async fn next_event_id(codex: &Codex) -> String { - timeout(Duration::from_secs(1), codex.next_event()) - .await - .expect("bridged event missing") - .expect("bridged event channel closed") - .id -} - -fn mcp_startup_update(status: McpStartupStatus) -> EventMsg { - EventMsg::McpStartupUpdate(McpStartupUpdateEvent { - server: "github".to_string(), - status, - }) -} - -fn mcp_startup_complete() -> EventMsg { - EventMsg::McpStartupComplete(McpStartupCompleteEvent::default()) -} - -fn turn_complete() -> EventMsg { - EventMsg::TurnComplete(TurnCompleteEvent { - turn_id: "turn-1".to_string(), - last_agent_message: None, - completed_at: None, - duration_ms: None, - time_to_first_token_ms: None, - }) -} - #[tokio::test] async fn forward_ops_preserves_submission_trace_context() { let (tx_sub, rx_sub) = bounded(SUBMISSION_CHANNEL_CAPACITY);