Compare commits

...

1 Commits

Author SHA1 Message Date
Owen Lin
b4626d518b feat(guardian): stop guardian subagents when turn is interrupted 2026-04-10 17:03:28 -07:00
4 changed files with 410 additions and 21 deletions

View File

@@ -5,6 +5,7 @@ use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
@@ -4436,9 +4437,35 @@ impl Session {
pub async fn interrupt_task(self: &Arc<Self>) {
info!("interrupt received: abort current task, if any");
let has_active_turn = { self.active_turn.lock().await.is_some() };
if has_active_turn {
let active_turn_ids = {
self.active_turn
.lock()
.await
.as_ref()
.map(|turn| turn.tasks.keys().cloned().collect::<Vec<_>>())
.unwrap_or_default()
};
if !active_turn_ids.is_empty() {
let mut cancelled_guardian_review_count = 0;
for turn_id in &active_turn_ids {
cancelled_guardian_review_count += self
.guardian_review_session
.cancel_active_reviews_for_turn(turn_id)
.await;
}
if cancelled_guardian_review_count > 0 {
for turn_id in &active_turn_ids {
self.guardian_review_session
.wait_for_no_active_reviews_for_turn(turn_id, Duration::from_secs(5))
.await;
}
}
self.abort_all_tasks(TurnAbortReason::Interrupted).await;
for turn_id in active_turn_ids {
self.guardian_review_session
.shutdown_active_reviews_for_turn(&turn_id)
.await;
}
} else {
self.cancel_mcp_startup().await;
}

View File

@@ -155,6 +155,7 @@ async fn run_guardian_review(
let outcome = run_guardian_review_session(
session.clone(),
turn.clone(),
review_id.clone(),
request,
retry_reason,
schema,
@@ -317,6 +318,7 @@ pub(crate) async fn review_approval_request_with_cancel(
pub(super) async fn run_guardian_review_session(
session: Arc<Session>,
turn: Arc<TurnContext>,
review_id: String,
request: GuardianApprovalRequest,
retry_reason: Option<String>,
schema: serde_json::Value,
@@ -383,6 +385,7 @@ pub(super) async fn run_guardian_review_session(
.run_review(GuardianReviewSessionParams {
parent_session: Arc::clone(&session),
parent_turn: turn.clone(),
review_id: review_id.clone(),
spawn_config: guardian_config,
request,
retry_reason,

View File

@@ -64,6 +64,7 @@ pub(crate) enum GuardianReviewSessionOutcome {
pub(crate) struct GuardianReviewSessionParams {
pub(crate) parent_session: Arc<Session>,
pub(crate) parent_turn: Arc<TurnContext>,
pub(crate) review_id: String,
pub(crate) spawn_config: Config,
pub(crate) request: GuardianApprovalRequest,
pub(crate) retry_reason: Option<String>,
@@ -75,7 +76,7 @@ pub(crate) struct GuardianReviewSessionParams {
pub(crate) external_cancel: Option<CancellationToken>,
}
#[derive(Default)]
#[derive(Clone, Default)]
pub(crate) struct GuardianReviewSessionManager {
state: Arc<Mutex<GuardianReviewSessionState>>,
}
@@ -84,6 +85,19 @@ pub(crate) struct GuardianReviewSessionManager {
struct GuardianReviewSessionState {
trunk: Option<Arc<GuardianReviewSession>>,
ephemeral_reviews: Vec<Arc<GuardianReviewSession>>,
active_reviews: HashMap<String, ActiveGuardianReview>,
}
struct ActiveGuardianReview {
parent_turn_id: String,
cancel_token: CancellationToken,
review_session: Option<Arc<GuardianReviewSession>>,
}
struct ActiveGuardianReviewRegistration {
manager: GuardianReviewSessionManager,
review_id: String,
registered: bool,
}
struct GuardianReviewSession {
@@ -246,15 +260,42 @@ impl Drop for EphemeralReviewCleanup {
}
}
impl ActiveGuardianReviewRegistration {
async fn unregister(mut self) {
self.registered = false;
self.manager.unregister_active_review(&self.review_id).await;
}
}
impl Drop for ActiveGuardianReviewRegistration {
fn drop(&mut self) {
if !self.registered {
return;
}
let manager = self.manager.clone();
let review_id = self.review_id.clone();
if let Ok(handle) = tokio::runtime::Handle::try_current() {
drop(handle.spawn(async move {
manager.unregister_active_review(&review_id).await;
}));
}
}
}
impl GuardianReviewSessionManager {
pub(crate) async fn shutdown(&self) {
let (review_session, ephemeral_reviews) = {
let (review_session, ephemeral_reviews, active_reviews) = {
let mut state = self.state.lock().await;
(
state.trunk.take(),
std::mem::take(&mut state.ephemeral_reviews),
std::mem::take(&mut state.active_reviews),
)
};
for active_review in active_reviews.into_values() {
active_review.cancel_token.cancel();
}
if let Some(review_session) = review_session {
review_session.shutdown().await;
}
@@ -264,6 +305,17 @@ impl GuardianReviewSessionManager {
}
pub(crate) async fn run_review(
&self,
mut params: GuardianReviewSessionParams,
) -> GuardianReviewSessionOutcome {
let (registration, cancel_token) = self.register_active_review(&params).await;
params.external_cancel = Some(cancel_token);
let outcome = Box::pin(self.run_review_inner(params)).await;
registration.unregister().await;
outcome
}
async fn run_review_inner(
&self,
params: GuardianReviewSessionParams,
) -> GuardianReviewSessionOutcome {
@@ -326,30 +378,30 @@ impl GuardianReviewSessionManager {
};
if trunk.reuse_key != next_reuse_key {
return self
.run_ephemeral_review(
params,
next_reuse_key,
deadline,
/*fork_snapshot*/ None,
)
.await;
return Box::pin(self.run_ephemeral_review(
params,
next_reuse_key,
deadline,
/*fork_snapshot*/ None,
))
.await;
}
let trunk_guard = match trunk.review_lock.try_lock() {
Ok(trunk_guard) => trunk_guard,
Err(_) => {
return self
.run_ephemeral_review(
params,
next_reuse_key,
deadline,
trunk.fork_snapshot().await,
)
.await;
return Box::pin(self.run_ephemeral_review(
params,
next_reuse_key,
deadline,
trunk.fork_snapshot().await,
))
.await;
}
};
self.record_active_review_session(&params.review_id, Arc::clone(&trunk))
.await;
let (outcome, keep_review_session) =
run_review_on_session(trunk.as_ref(), &params, deadline).await;
if keep_review_session && matches!(outcome, GuardianReviewSessionOutcome::Completed(_)) {
@@ -367,6 +419,143 @@ impl GuardianReviewSessionManager {
}
}
async fn register_active_review(
&self,
params: &GuardianReviewSessionParams,
) -> (ActiveGuardianReviewRegistration, CancellationToken) {
let cancel_token = CancellationToken::new();
if let Some(external_cancel) = params.external_cancel.clone() {
if external_cancel.is_cancelled() {
cancel_token.cancel();
} else {
let cancel_token_for_task = cancel_token.clone();
drop(tokio::spawn(async move {
tokio::select! {
_ = external_cancel.cancelled() => cancel_token_for_task.cancel(),
_ = cancel_token_for_task.cancelled() => {}
}
}));
}
}
let previous = self.state.lock().await.active_reviews.insert(
params.review_id.clone(),
ActiveGuardianReview {
parent_turn_id: params.parent_turn.sub_id.clone(),
cancel_token: cancel_token.clone(),
review_session: None,
},
);
if previous.is_some() {
warn!(
review_id = %params.review_id,
"overwriting active guardian approval review"
);
}
(
ActiveGuardianReviewRegistration {
manager: self.clone(),
review_id: params.review_id.clone(),
registered: true,
},
cancel_token,
)
}
async fn unregister_active_review(&self, review_id: &str) {
self.state.lock().await.active_reviews.remove(review_id);
}
async fn record_active_review_session(
&self,
review_id: &str,
review_session: Arc<GuardianReviewSession>,
) {
let mut state = self.state.lock().await;
if let Some(active_review) = state.active_reviews.get_mut(review_id) {
active_review.review_session = Some(review_session);
}
}
pub(crate) async fn cancel_active_reviews_for_turn(&self, parent_turn_id: &str) -> usize {
let active_reviews = {
let state = self.state.lock().await;
state
.active_reviews
.values()
.filter(|review| review.parent_turn_id == parent_turn_id)
.map(|review| review.cancel_token.clone())
.collect::<Vec<_>>()
};
let count = active_reviews.len();
for cancel_token in active_reviews {
cancel_token.cancel();
}
count
}
pub(crate) async fn wait_for_no_active_reviews_for_turn(
&self,
parent_turn_id: &str,
timeout: Duration,
) {
let result = tokio::time::timeout(timeout, async {
while self.active_review_count_for_turn(parent_turn_id).await > 0 {
tokio::time::sleep(Duration::from_millis(10)).await;
}
})
.await;
if result.is_err() {
warn!(
parent_turn_id,
"timed out waiting for guardian approval reviews to abort"
);
}
}
pub(crate) async fn shutdown_active_reviews_for_turn(&self, parent_turn_id: &str) {
let active_reviews = {
let mut state = self.state.lock().await;
let mut active_reviews = Vec::new();
let mut retained_reviews = HashMap::new();
for (review_id, active_review) in std::mem::take(&mut state.active_reviews) {
if active_review.parent_turn_id == parent_turn_id {
active_reviews.push(active_review);
} else {
retained_reviews.insert(review_id, active_review);
}
}
state.active_reviews = retained_reviews;
active_reviews
};
for active_review in active_reviews {
active_review.cancel_token.cancel();
if let Some(review_session) = active_review.review_session {
review_session.shutdown_in_background();
}
}
}
async fn active_review_count_for_turn(&self, parent_turn_id: &str) -> usize {
self.state
.lock()
.await
.active_reviews
.values()
.filter(|review| review.parent_turn_id == parent_turn_id)
.count()
}
#[cfg(test)]
pub(crate) async fn active_review_count_for_turn_for_test(
&self,
parent_turn_id: &str,
) -> usize {
self.active_review_count_for_turn(parent_turn_id).await
}
#[cfg(test)]
pub(crate) async fn cache_for_test(&self, codex: Codex) {
let reuse_key = GuardianReviewSessionReuseKey::from_spawn_config(
@@ -414,7 +603,7 @@ impl GuardianReviewSessionManager {
let snapshot = state.last_committed_fork_snapshot.as_ref()?;
match &snapshot.initial_history {
InitialHistory::Forked(items) => Some(items.clone()),
InitialHistory::New | InitialHistory::Resumed(_) => None,
InitialHistory::New | InitialHistory::Cleared | InitialHistory::Resumed(_) => None,
}
}
@@ -487,6 +676,8 @@ impl GuardianReviewSessionManager {
let mut cleanup =
EphemeralReviewCleanup::new(Arc::clone(&self.state), Arc::clone(&review_session));
self.record_active_review_session(&params.review_id, Arc::clone(&review_session))
.await;
let (outcome, _) = run_review_on_session(review_session.as_ref(), &params, deadline).await;
if let Some(review_session) = self.take_active_ephemeral(&review_session).await {
cleanup.disarm();

View File

@@ -14,6 +14,9 @@ use crate::config_loader::NetworkDomainPermissionToml;
use crate::config_loader::NetworkDomainPermissionsToml;
use crate::config_loader::RequirementSource;
use crate::config_loader::Sourced;
use crate::state::TaskKind;
use crate::tasks::SessionTask;
use crate::tasks::SessionTaskContext;
use crate::test_support;
use codex_config::config_toml::ConfigToml;
use codex_network_proxy::NetworkProxyConfig;
@@ -30,6 +33,8 @@ use codex_protocol::protocol::GuardianUserAuthorization;
use codex_protocol::protocol::ReviewDecision;
use codex_protocol::protocol::RolloutItem;
use codex_protocol::protocol::SandboxPolicy;
use codex_protocol::protocol::TurnAbortReason;
use codex_protocol::user_input::UserInput;
use core_test_support::PathBufExt;
use core_test_support::TempDirExt;
use core_test_support::context_snapshot;
@@ -60,6 +65,30 @@ fn fixed_guardian_parent_session_id() -> ThreadId {
.expect("fixed parent session id should be a valid UUID")
}
#[derive(Clone, Copy)]
struct GuardianNeverEndingTask;
impl SessionTask for GuardianNeverEndingTask {
fn kind(&self) -> TaskKind {
TaskKind::Regular
}
fn span_name(&self) -> &'static str {
"session_task.guardian_never_ending"
}
async fn run(
self: Arc<Self>,
_session: Arc<SessionTaskContext>,
_ctx: Arc<TurnContext>,
_input: Vec<UserInput>,
cancellation_token: CancellationToken,
) -> Option<String> {
cancellation_token.cancelled().await;
None
}
}
async fn guardian_test_session_and_turn(
server: &wiremock::MockServer,
) -> (Arc<Session>, Arc<TurnContext>) {
@@ -718,6 +747,141 @@ async fn cancelled_guardian_review_emits_terminal_abort_without_warning() {
assert!(warnings.is_empty());
}
#[tokio::test]
async fn interrupting_parent_turn_cancels_active_guardian_review() -> anyhow::Result<()> {
let (_gate_tx, gate_rx) = tokio::sync::oneshot::channel();
let (server, _) = start_streaming_sse_server(vec![vec![
StreamingSseChunk {
gate: None,
body: sse(vec![ev_response_created("resp-guardian-cancel")]),
},
StreamingSseChunk {
gate: Some(gate_rx),
body: sse(vec![
ev_assistant_message(
"msg-guardian-cancel",
&serde_json::json!({
"risk_level": "low",
"user_authorization": "high",
"outcome": "allow",
"rationale": "would have completed if not cancelled",
})
.to_string(),
),
ev_completed("resp-guardian-cancel"),
]),
},
]])
.await;
let (mut session, mut turn, rx) = crate::codex::make_session_and_context_with_rx().await;
let session_mut = Arc::get_mut(&mut session).expect("single session ref");
let turn_mut = Arc::get_mut(&mut turn).expect("single turn ref");
let mut config = (*turn_mut.config).clone();
config.model_provider.base_url = Some(format!("{}/v1", server.uri()));
config.user_instructions = None;
let config = Arc::new(config);
session_mut.services.models_manager = Arc::new(test_support::models_manager_with_provider(
config.codex_home.clone(),
Arc::clone(&session_mut.services.auth_manager),
config.model_provider.clone(),
));
turn_mut.config = Arc::clone(&config);
turn_mut.provider = config.model_provider.clone();
turn_mut.user_instructions = None;
seed_guardian_parent_history(&session, &turn).await;
session
.spawn_task(Arc::clone(&turn), Vec::new(), GuardianNeverEndingTask)
.await;
let review_session = Arc::clone(&session);
let review_turn = Arc::clone(&turn);
let mut review = tokio::spawn(async move {
review_approval_request(
&review_session,
&review_turn,
"review-shell-guardian-cancel".to_string(),
GuardianApprovalRequest::Shell {
id: "shell-guardian-cancel".to_string(),
command: vec!["git".to_string(), "push".to_string()],
cwd: PathBuf::from("/repo/codex-rs/core"),
sandbox_permissions: crate::sandboxing::SandboxPermissions::UseDefault,
additional_permissions: None,
justification: Some("Push the docs fix.".to_string()),
},
/*retry_reason*/ None,
)
.await
});
tokio::select! {
wait_result = tokio::time::timeout(Duration::from_secs(10), async {
loop {
let review_count = session
.guardian_review_session
.active_review_count_for_turn_for_test(&turn.sub_id)
.await;
if review_count == 1 {
break;
}
tokio::task::yield_now().await;
}
}) => {
wait_result.expect("guardian review should register as active");
}
decision = &mut review => {
panic!("guardian review completed before cancellation: {decision:?}");
}
}
tokio::time::timeout(Duration::from_secs(10), async {
loop {
if !server.requests().await.is_empty() {
break;
}
tokio::task::yield_now().await;
}
})
.await
.expect("guardian review request should start");
session.interrupt_task().await;
let decision = tokio::time::timeout(Duration::from_secs(10), review)
.await
.expect("guardian review should abort")?;
assert_eq!(decision, ReviewDecision::Abort);
assert_eq!(
session
.guardian_review_session
.active_review_count_for_turn_for_test(&turn.sub_id)
.await,
0
);
let mut guardian_statuses = Vec::new();
let mut turn_aborted = false;
while let Ok(event) = rx.try_recv() {
match event.msg {
EventMsg::GuardianAssessment(event) => guardian_statuses.push(event.status),
EventMsg::TurnAborted(event) => {
turn_aborted = event.reason == TurnAbortReason::Interrupted;
}
_ => {}
}
}
assert_eq!(
guardian_statuses,
vec![
GuardianAssessmentStatus::InProgress,
GuardianAssessmentStatus::Aborted,
]
);
assert!(turn_aborted, "expected parent turn to be interrupted");
Ok(())
}
#[tokio::test]
async fn routes_approval_to_guardian_requires_auto_only_review_policy() {
let (_session, mut turn) = crate::codex::make_session_and_context().await;
@@ -893,6 +1057,7 @@ async fn guardian_review_request_layout_matches_model_visible_request_snapshot()
let outcome = run_guardian_review_session_for_test(
Arc::clone(&session),
Arc::clone(&turn),
"review-shell-1".to_string(),
request,
Some("Sandbox denied outbound git push to github.com.".to_string()),
guardian_output_schema(),
@@ -1013,6 +1178,7 @@ async fn guardian_reuses_prompt_cache_key_and_appends_prior_reviews() -> anyhow:
let first_outcome = run_guardian_review_session_for_test(
Arc::clone(&session),
Arc::clone(&turn),
"review-shell-1".to_string(),
first_request,
Some("First retry reason".to_string()),
guardian_output_schema(),
@@ -1059,6 +1225,7 @@ async fn guardian_reuses_prompt_cache_key_and_appends_prior_reviews() -> anyhow:
let second_outcome = run_guardian_review_session_for_test(
Arc::clone(&session),
Arc::clone(&turn),
"review-shell-2".to_string(),
second_request,
Some("Second retry reason".to_string()),
guardian_output_schema(),
@@ -1101,6 +1268,7 @@ async fn guardian_reuses_prompt_cache_key_and_appends_prior_reviews() -> anyhow:
let third_outcome = run_guardian_review_session_for_test(
Arc::clone(&session),
Arc::clone(&turn),
"review-shell-3".to_string(),
third_request,
Some("Third retry reason".to_string()),
guardian_output_schema(),