Add cancellation token

This commit is contained in:
jif-oai
2025-10-16 15:22:56 +01:00
parent 39c72b3151
commit 32bd302d80
3 changed files with 41 additions and 17 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -1202,6 +1202,7 @@ dependencies = [
"tempfile",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"wiremock",
]

View File

@@ -14,6 +14,7 @@ serde_json = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "signal"] }
tokio-stream = { workspace = true }
tokio-util = { workspace = true }
tracing = { workspace = true, features = ["log"] }
futures = "0.3"

View File

@@ -1,5 +1,4 @@
use std::fs;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
@@ -17,6 +16,7 @@ use codex_core::protocol::Op;
use codex_protocol::ConversationId;
use tokio::signal;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use crate::progress::ProgressReporter;
@@ -214,8 +214,13 @@ impl InftyOrchestrator {
self.maybe_post_objective(&solver_role, sessions, &mut state, options)
.await?;
let ctrl_c = signal::ctrl_c();
tokio::pin!(ctrl_c);
// Cancellation token that propagates Ctrl+C to nested awaits
let cancel = CancellationToken::new();
let cancel_on_ctrl_c = cancel.clone();
tokio::spawn(async move {
let _ = signal::ctrl_c().await;
cancel_on_ctrl_c.cancel();
});
'event_loop: loop {
tokio::select! {
@@ -243,6 +248,7 @@ impl InftyOrchestrator {
options,
&director_role,
&solver_role,
cancel.clone(),
)
.await?;
sessions.store.touch()?;
@@ -277,6 +283,7 @@ impl InftyOrchestrator {
summary_ref,
options,
&solver_role,
cancel.clone(),
)
.await?;
if !verified { state.pending_solver_turn_completion = true; continue; }
@@ -307,16 +314,16 @@ impl InftyOrchestrator {
println!("Stream error: {:?}", error);
}
e => {
tracing::info!("Unhandled event: {:?}", e);
print!(".");
let _ = std::io::stdout().flush();
tracing::info!("Unhandled event: {:?}", e); // todo move to trace
}
}
}
_ = &mut ctrl_c => {
if let Some(progress) = self.progress.as_ref() {
progress.run_interrupted();
}
_ = cancel.cancelled() => {
if let Some(progress) = self.progress.as_ref() { progress.run_interrupted(); }
// Proactively interrupt any in-flight role turns for fast shutdown.
let _ = sessions.solver.conversation.submit(Op::Interrupt).await;
let _ = sessions.director.conversation.submit(Op::Interrupt).await;
for v in &sessions.verifiers { let _ = v.conversation.submit(Op::Interrupt).await; }
// Cleanup is handled by the caller (drive_run) to avoid double-shutdown
bail!("run interrupted by Ctrl+C");
}
@@ -357,17 +364,25 @@ impl InftyOrchestrator {
options: &RunExecutionOptions,
director_role: &DirectorRole,
solver_role: &SolverRole,
cancel: CancellationToken,
) -> Result<()> {
let request = DirectionRequestPayload::new(prompt, options.objective.as_deref());
let directive_payload = director_role
.call(&request)
.await
.context("director response was not valid directive JSON")?;
let directive_payload = tokio::select! {
r = director_role.call(&request) => {
r.context("director response was not valid directive JSON")?
}
_ = cancel.cancelled() => {
bail!("interrupted")
}
};
if let Some(progress) = self.progress.as_ref() {
progress.director_response(&directive_payload);
}
let req = SolverRequest::from(directive_payload);
solver_role.call(&req).await?;
tokio::select! {
r = solver_role.call(&req) => { r?; }
_ = cancel.cancelled() => { bail!("interrupted"); }
}
Ok(())
}
@@ -379,6 +394,7 @@ impl InftyOrchestrator {
summary: Option<&str>,
options: &RunExecutionOptions,
solver_role: &SolverRole,
cancel: CancellationToken,
) -> Result<bool> {
let relative = deliverable_path
.strip_prefix(sessions.store.path())
@@ -392,14 +408,20 @@ impl InftyOrchestrator {
if verifier_pool.is_empty() {
return Ok(true);
}
let round = verifier_pool.collect_round(&request).await?;
let round = tokio::select! {
r = verifier_pool.collect_round(&request) => { r? }
_ = cancel.cancelled() => { bail!("interrupted"); }
};
verifier_pool
.rotate_passing(sessions, &self.conversation_manager, &round.passing_roles)
.await?;
let summary_result = round.summary;
self.emit_verification_summary(&summary_result);
let req = SolverRequest::from(&summary_result);
solver_role.call(&req).await?;
tokio::select! {
r = solver_role.call(&req) => { r?; }
_ = cancel.cancelled() => { bail!("interrupted"); }
}
Ok(summary_result.overall.is_pass())
}