mirror of
https://github.com/openai/codex.git
synced 2026-05-13 15:52:40 +00:00
Compare commits
8 Commits
xli-codex/
...
codex/mini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e570952f4 | ||
|
|
bbc664079a | ||
|
|
0e331a685d | ||
|
|
5e7458e515 | ||
|
|
c0e4efdfb5 | ||
|
|
dd7aaab825 | ||
|
|
051accdb6a | ||
|
|
553376ffcb |
@@ -17,6 +17,7 @@ use codex_hooks::UserPromptSubmitRequest;
|
||||
use codex_otel::HOOK_RUN_DURATION_METRIC;
|
||||
use codex_otel::HOOK_RUN_METRIC;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::items::UserMessageItem;
|
||||
use codex_protocol::models::DeveloperInstructions;
|
||||
use codex_protocol::models::ResponseInputItem;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -30,8 +31,10 @@ use codex_protocol::protocol::HookSource;
|
||||
use codex_protocol::protocol::HookStartedEvent;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
|
||||
use crate::event_mapping::parse_turn_item;
|
||||
use crate::session::PreviousTurnSettings;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::tools::sandboxing::PermissionRequestPayload;
|
||||
@@ -57,6 +60,11 @@ pub(crate) enum PendingInputRecord {
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) enum TurnStartTranscriptDrainMode {
|
||||
RegularTurn,
|
||||
InterruptRecovery,
|
||||
}
|
||||
|
||||
struct ContextInjectingHookOutcome {
|
||||
hook_events: Vec<HookCompletedEvent>,
|
||||
outcome: HookRuntimeOutcome,
|
||||
@@ -268,6 +276,107 @@ pub(crate) async fn inspect_pending_input(
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn drain_turn_start_transcript_inputs(
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
mode: TurnStartTranscriptDrainMode,
|
||||
) -> bool {
|
||||
let Ok(permit) = turn_context
|
||||
.transcript_serialization_lock
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
else {
|
||||
return false;
|
||||
};
|
||||
|
||||
drain_turn_start_transcript_inputs_with_permit(sess, turn_context, mode, permit).await
|
||||
}
|
||||
|
||||
pub(crate) async fn drain_turn_start_transcript_inputs_with_permit(
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
mode: TurnStartTranscriptDrainMode,
|
||||
_permit: OwnedSemaphorePermit,
|
||||
) -> bool {
|
||||
let has_queued_start_input = !turn_context.lock_turn_start_transcript_inputs().is_empty();
|
||||
if !has_queued_start_input && matches!(mode, TurnStartTranscriptDrainMode::InterruptRecovery) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Keep the normal turn-start ordering in one serialized region: context first,
|
||||
// then the user prompt, then hook-provided context and previous-turn settings.
|
||||
sess.record_context_updates_and_set_reference_context_item(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
if run_pending_session_start_hooks(sess, turn_context).await {
|
||||
turn_context.lock_turn_start_transcript_inputs().clear();
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut recorded_start_input = false;
|
||||
loop {
|
||||
let input = {
|
||||
let inputs = turn_context.lock_turn_start_transcript_inputs();
|
||||
inputs.first().map(|queued| queued.input.clone())
|
||||
};
|
||||
let Some(input) = input else {
|
||||
break;
|
||||
};
|
||||
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input.clone());
|
||||
let response_item: ResponseItem = initial_input_for_turn.into();
|
||||
let user_prompt_submit_outcome = run_user_prompt_submit_hooks(
|
||||
sess,
|
||||
turn_context,
|
||||
UserMessageItem::new(&input).message(),
|
||||
)
|
||||
.await;
|
||||
|
||||
if user_prompt_submit_outcome.should_stop {
|
||||
record_additional_contexts(
|
||||
sess,
|
||||
turn_context,
|
||||
user_prompt_submit_outcome.additional_contexts,
|
||||
)
|
||||
.await;
|
||||
let mut inputs = turn_context.lock_turn_start_transcript_inputs();
|
||||
if inputs.first().is_some_and(|queued| queued.input == input) {
|
||||
inputs.remove(0);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
sess.record_user_prompt_and_emit_turn_item(
|
||||
turn_context.as_ref(),
|
||||
input.as_slice(),
|
||||
response_item,
|
||||
)
|
||||
.await;
|
||||
record_additional_contexts(
|
||||
sess,
|
||||
turn_context,
|
||||
user_prompt_submit_outcome.additional_contexts,
|
||||
)
|
||||
.await;
|
||||
let mut inputs = turn_context.lock_turn_start_transcript_inputs();
|
||||
if inputs.first().is_some_and(|queued| queued.input == input) {
|
||||
inputs.remove(0);
|
||||
}
|
||||
recorded_start_input = true;
|
||||
}
|
||||
|
||||
if recorded_start_input {
|
||||
sess.set_previous_turn_settings(Some(PreviousTurnSettings {
|
||||
model: turn_context.model_info.slug.clone(),
|
||||
realtime_active: Some(turn_context.realtime_active),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub(crate) async fn record_pending_input(
|
||||
sess: &Arc<Session>,
|
||||
turn_context: &Arc<TurnContext>,
|
||||
|
||||
@@ -128,6 +128,7 @@ use rmcp::model::RequestId;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::watch;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
@@ -139,6 +139,8 @@ pub(super) async fn spawn_review_thread(
|
||||
turn_metadata_state,
|
||||
turn_skills: TurnSkillsContext::new(parent_turn_context.turn_skills.outcome.clone()),
|
||||
turn_timing_state: Arc::new(TurnTimingState::default()),
|
||||
turn_start_transcript_inputs: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
transcript_serialization_lock: Arc::new(Semaphore::new(1)),
|
||||
};
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
|
||||
@@ -144,7 +144,6 @@ use sha2::Sha512;
|
||||
use std::path::Path;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::time::sleep;
|
||||
use tokio::time::timeout;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use wiremock::Mock;
|
||||
@@ -306,12 +305,23 @@ async fn interrupting_regular_turn_waiting_on_startup_prewarm_emits_turn_aborted
|
||||
),
|
||||
)
|
||||
.await;
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
Vec::new(),
|
||||
crate::tasks::RegularTask::new(),
|
||||
)
|
||||
.await;
|
||||
let input = vec![UserInput::Text {
|
||||
text: "hello before prewarm".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
let mut expected_history = sess.build_initial_context(tc.as_ref()).await;
|
||||
expected_history.push(ResponseItem::Message {
|
||||
id: None,
|
||||
role: "user".to_string(),
|
||||
content: vec![ContentItem::InputText {
|
||||
text: "hello before prewarm".to_string(),
|
||||
}],
|
||||
end_turn: None,
|
||||
phase: None,
|
||||
});
|
||||
expected_history.push(crate::tasks::interrupted_turn_history_marker());
|
||||
sess.spawn_task(Arc::clone(&tc), input, crate::tasks::RegularTask::new())
|
||||
.await;
|
||||
|
||||
let first = tokio::time::timeout(std::time::Duration::from_millis(200), rx.recv())
|
||||
.await
|
||||
@@ -324,23 +334,112 @@ async fn interrupting_regular_turn_waiting_on_startup_prewarm_emits_turn_aborted
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
let second = tokio::time::timeout(std::time::Duration::from_secs(2), rx.recv())
|
||||
let (turn_id, reason, completed_at, duration_ms) =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(2), async {
|
||||
loop {
|
||||
let event = rx.recv().await.expect("channel open");
|
||||
if let EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
turn_id,
|
||||
reason,
|
||||
completed_at,
|
||||
duration_ms,
|
||||
}) = event.msg
|
||||
{
|
||||
return (turn_id, reason, completed_at, duration_ms);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await
|
||||
.expect("expected turn aborted event")
|
||||
.expect("channel open");
|
||||
let EventMsg::TurnAborted(TurnAbortedEvent {
|
||||
turn_id,
|
||||
reason,
|
||||
completed_at,
|
||||
duration_ms,
|
||||
}) = second.msg
|
||||
else {
|
||||
panic!("expected turn aborted event");
|
||||
};
|
||||
.expect("expected turn aborted event");
|
||||
assert_eq!(turn_id, Some(tc.sub_id.clone()));
|
||||
assert_eq!(reason, TurnAbortReason::Interrupted);
|
||||
assert!(completed_at.is_some());
|
||||
assert!(duration_ms.is_some());
|
||||
|
||||
let history = sess.clone_history().await;
|
||||
assert_eq!(history.raw_items(), expected_history.as_slice());
|
||||
assert!(tc.lock_turn_start_transcript_inputs().is_empty());
|
||||
assert_eq!(
|
||||
sess.previous_turn_settings().await,
|
||||
Some(PreviousTurnSettings {
|
||||
model: tc.model_info.slug.clone(),
|
||||
realtime_active: Some(tc.realtime_active),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_start_stop_clears_queued_start_input() -> anyhow::Result<()> {
|
||||
let (mut sess, tc) = make_session_and_context().await;
|
||||
std::fs::create_dir_all(&tc.config.codex_home)?;
|
||||
let hook_script = tc.config.codex_home.join("block_session_start_hook.py");
|
||||
std::fs::write(
|
||||
&hook_script,
|
||||
r#"import json
|
||||
import sys
|
||||
|
||||
json.load(sys.stdin)
|
||||
print(json.dumps({"continue": False, "stopReason": "blocked for test"}))
|
||||
"#,
|
||||
)?;
|
||||
std::fs::write(
|
||||
tc.config.codex_home.join("hooks.json"),
|
||||
json!({
|
||||
"hooks": {
|
||||
"SessionStart": [{
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": format!("python3 {}", hook_script.display()),
|
||||
}]
|
||||
}]
|
||||
}
|
||||
})
|
||||
.to_string(),
|
||||
)?;
|
||||
sess.services.hooks = Hooks::new(HooksConfig {
|
||||
feature_enabled: true,
|
||||
config_layer_stack: Some(tc.config.config_layer_stack.clone()),
|
||||
..HooksConfig::default()
|
||||
});
|
||||
let sess = Arc::new(sess);
|
||||
let tc = Arc::new(tc);
|
||||
let input = vec![UserInput::Text {
|
||||
text: "prompt blocked by session start".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
sess.state
|
||||
.lock()
|
||||
.await
|
||||
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
|
||||
tc.lock_turn_start_transcript_inputs()
|
||||
.push(crate::session::turn_context::TurnStartTranscriptInput { input });
|
||||
|
||||
assert!(
|
||||
!crate::hook_runtime::drain_turn_start_transcript_inputs(
|
||||
&sess,
|
||||
&tc,
|
||||
crate::hook_runtime::TurnStartTranscriptDrainMode::RegularTurn,
|
||||
)
|
||||
.await
|
||||
);
|
||||
assert!(tc.lock_turn_start_transcript_inputs().is_empty());
|
||||
|
||||
assert!(
|
||||
crate::hook_runtime::drain_turn_start_transcript_inputs(
|
||||
&sess,
|
||||
&tc,
|
||||
crate::hook_runtime::TurnStartTranscriptDrainMode::InterruptRecovery,
|
||||
)
|
||||
.await
|
||||
);
|
||||
let history = sess.clone_history().await;
|
||||
assert!(
|
||||
user_input_texts(history.raw_items())
|
||||
.iter()
|
||||
.all(|text| !text.contains("prompt blocked by session start"))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn test_model_client_session() -> crate::client::ModelClientSession {
|
||||
@@ -5669,9 +5768,7 @@ impl SessionTask for NeverEndingTask {
|
||||
cancellation_token.cancelled().await;
|
||||
return None;
|
||||
}
|
||||
loop {
|
||||
sleep(Duration::from_secs(60)).await;
|
||||
}
|
||||
std::future::pending::<Option<String>>().await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5742,6 +5839,62 @@ async fn abort_gracefully_emits_turn_aborted_only() {
|
||||
assert!(rx.try_recv().is_err());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[test_log::test]
|
||||
async fn abort_non_regular_task_keeps_pending_session_start_source() {
|
||||
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
|
||||
sess.state
|
||||
.lock()
|
||||
.await
|
||||
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
|
||||
let input = vec![UserInput::Text {
|
||||
text: "review this".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}];
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
input,
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Review,
|
||||
listen_to_cancellation_token: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
assert!(matches!(
|
||||
sess.take_pending_session_start_source().await,
|
||||
Some(codex_hooks::SessionStartSource::Startup)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
#[test_log::test]
|
||||
async fn abort_regular_task_without_start_input_keeps_pending_session_start_source() {
|
||||
let (sess, tc, _rx) = make_session_and_context_with_rx().await;
|
||||
sess.state
|
||||
.lock()
|
||||
.await
|
||||
.set_pending_session_start_source(Some(codex_hooks::SessionStartSource::Startup));
|
||||
sess.spawn_task(
|
||||
Arc::clone(&tc),
|
||||
Vec::new(),
|
||||
NeverEndingTask {
|
||||
kind: TaskKind::Regular,
|
||||
listen_to_cancellation_token: true,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
sess.abort_all_tasks(TurnAbortReason::Interrupted).await;
|
||||
|
||||
assert!(matches!(
|
||||
sess.take_pending_session_start_source().await,
|
||||
Some(codex_hooks::SessionStartSource::Startup)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn task_finish_emits_turn_item_lifecycle_for_leftover_pending_user_input() {
|
||||
let (sess, tc, rx) = make_session_and_context_with_rx().await;
|
||||
|
||||
@@ -18,12 +18,13 @@ use crate::compact_remote::run_inline_remote_auto_compact_task;
|
||||
use crate::connectors;
|
||||
use crate::feedback_tags;
|
||||
use crate::hook_runtime::PendingInputHookDisposition;
|
||||
use crate::hook_runtime::TurnStartTranscriptDrainMode;
|
||||
use crate::hook_runtime::drain_turn_start_transcript_inputs;
|
||||
use crate::hook_runtime::emit_hook_completed_events;
|
||||
use crate::hook_runtime::inspect_pending_input;
|
||||
use crate::hook_runtime::record_additional_contexts;
|
||||
use crate::hook_runtime::record_pending_input;
|
||||
use crate::hook_runtime::run_pending_session_start_hooks;
|
||||
use crate::hook_runtime::run_user_prompt_submit_hooks;
|
||||
use crate::injection::ToolMentionKind;
|
||||
use crate::injection::app_id_from_path;
|
||||
use crate::injection::tool_kind_for_path;
|
||||
@@ -37,7 +38,6 @@ use crate::mentions::collect_tool_mentions_from_messages;
|
||||
use crate::parse_turn_item;
|
||||
use crate::plugins::build_plugin_injections;
|
||||
use crate::resolve_skill_dependencies_for_turn;
|
||||
use crate::session::PreviousTurnSettings;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::stream_events_utils::HandleOutputCtx;
|
||||
@@ -74,7 +74,6 @@ use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::error::Result as CodexResult;
|
||||
use codex_protocol::items::PlanItem;
|
||||
use codex_protocol::items::TurnItem;
|
||||
use codex_protocol::items::UserMessageItem;
|
||||
use codex_protocol::items::build_hook_prompt_message;
|
||||
use codex_protocol::models::BaseInstructions;
|
||||
use codex_protocol::models::ContentItem;
|
||||
@@ -158,9 +157,6 @@ pub(crate) async fn run_turn(
|
||||
|
||||
let skills_outcome = Some(turn_context.turn_skills.outcome.as_ref());
|
||||
|
||||
sess.record_context_updates_and_set_reference_context_item(turn_context.as_ref())
|
||||
.await;
|
||||
|
||||
let loaded_plugins = sess
|
||||
.services
|
||||
.plugins_manager
|
||||
@@ -285,33 +281,15 @@ pub(crate) async fn run_turn(
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if run_pending_session_start_hooks(&sess, &turn_context).await {
|
||||
if !drain_turn_start_transcript_inputs(
|
||||
&sess,
|
||||
&turn_context,
|
||||
TurnStartTranscriptDrainMode::RegularTurn,
|
||||
)
|
||||
.await
|
||||
{
|
||||
return None;
|
||||
}
|
||||
let additional_contexts = if input.is_empty() {
|
||||
Vec::new()
|
||||
} else {
|
||||
let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input.clone());
|
||||
let response_item: ResponseItem = initial_input_for_turn.clone().into();
|
||||
let user_prompt_submit_outcome = run_user_prompt_submit_hooks(
|
||||
&sess,
|
||||
&turn_context,
|
||||
UserMessageItem::new(&input).message(),
|
||||
)
|
||||
.await;
|
||||
if user_prompt_submit_outcome.should_stop {
|
||||
record_additional_contexts(
|
||||
&sess,
|
||||
&turn_context,
|
||||
user_prompt_submit_outcome.additional_contexts,
|
||||
)
|
||||
.await;
|
||||
return None;
|
||||
}
|
||||
sess.record_user_prompt_and_emit_turn_item(turn_context.as_ref(), &input, response_item)
|
||||
.await;
|
||||
user_prompt_submit_outcome.additional_contexts
|
||||
};
|
||||
sess.services
|
||||
.analytics_events_client
|
||||
.track_app_mentioned(tracking.clone(), mentioned_app_invocations);
|
||||
@@ -322,17 +300,6 @@ pub(crate) async fn run_turn(
|
||||
}
|
||||
sess.merge_connector_selection(explicitly_enabled_connectors.clone())
|
||||
.await;
|
||||
record_additional_contexts(&sess, &turn_context, additional_contexts).await;
|
||||
if !input.is_empty() {
|
||||
// Track the previous-turn baseline from the regular user-turn path only so
|
||||
// standalone tasks (compact/shell/review/undo) cannot suppress future
|
||||
// model/realtime injections.
|
||||
sess.set_previous_turn_settings(Some(PreviousTurnSettings {
|
||||
model: turn_context.model_info.slug.clone(),
|
||||
realtime_active: Some(turn_context.realtime_active),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
let agent_task = match sess.ensure_agent_task_registered().await {
|
||||
Ok(agent_task) => agent_task,
|
||||
Err(error) => {
|
||||
|
||||
@@ -24,6 +24,11 @@ impl TurnSkillsContext {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub(crate) struct TurnStartTranscriptInput {
|
||||
pub(crate) input: Vec<UserInput>,
|
||||
}
|
||||
|
||||
/// The context needed for a single turn of the thread.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct TurnContext {
|
||||
@@ -71,8 +76,18 @@ pub(crate) struct TurnContext {
|
||||
pub(crate) turn_metadata_state: Arc<TurnMetadataState>,
|
||||
pub(crate) turn_skills: TurnSkillsContext,
|
||||
pub(crate) turn_timing_state: Arc<TurnTimingState>,
|
||||
pub(crate) turn_start_transcript_inputs: Arc<std::sync::Mutex<Vec<TurnStartTranscriptInput>>>,
|
||||
pub(crate) transcript_serialization_lock: Arc<Semaphore>,
|
||||
}
|
||||
impl TurnContext {
|
||||
pub(crate) fn lock_turn_start_transcript_inputs(
|
||||
&self,
|
||||
) -> std::sync::MutexGuard<'_, Vec<TurnStartTranscriptInput>> {
|
||||
self.turn_start_transcript_inputs
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner)
|
||||
}
|
||||
|
||||
pub(crate) fn model_context_window(&self) -> Option<i64> {
|
||||
let effective_context_window_percent = self.model_info.effective_context_window_percent;
|
||||
self.model_info
|
||||
@@ -197,6 +212,8 @@ impl TurnContext {
|
||||
turn_metadata_state: self.turn_metadata_state.clone(),
|
||||
turn_skills: self.turn_skills.clone(),
|
||||
turn_timing_state: Arc::clone(&self.turn_timing_state),
|
||||
turn_start_transcript_inputs: Arc::clone(&self.turn_start_transcript_inputs),
|
||||
transcript_serialization_lock: Arc::clone(&self.transcript_serialization_lock),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -443,6 +460,8 @@ impl Session {
|
||||
turn_metadata_state,
|
||||
turn_skills: TurnSkillsContext::new(skills_outcome),
|
||||
turn_timing_state: Arc::new(TurnTimingState::default()),
|
||||
turn_start_transcript_inputs: Arc::new(std::sync::Mutex::new(Vec::new())),
|
||||
transcript_serialization_lock: Arc::new(Semaphore::new(1)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ pub(crate) struct RunningTask {
|
||||
pub(crate) done: Arc<Notify>,
|
||||
pub(crate) kind: TaskKind,
|
||||
pub(crate) task: Arc<dyn AnySessionTask>,
|
||||
pub(crate) records_turn_start_transcript: bool,
|
||||
pub(crate) cancellation_token: CancellationToken,
|
||||
pub(crate) handle: Arc<AbortOnDropHandle<()>>,
|
||||
pub(crate) turn_context: Arc<TurnContext>,
|
||||
|
||||
@@ -22,11 +22,15 @@ use tracing::warn;
|
||||
use crate::contextual_user_message::TURN_ABORTED_CLOSE_TAG;
|
||||
use crate::contextual_user_message::TURN_ABORTED_OPEN_TAG;
|
||||
use crate::hook_runtime::PendingInputHookDisposition;
|
||||
use crate::hook_runtime::TurnStartTranscriptDrainMode;
|
||||
use crate::hook_runtime::drain_turn_start_transcript_inputs;
|
||||
use crate::hook_runtime::drain_turn_start_transcript_inputs_with_permit;
|
||||
use crate::hook_runtime::inspect_pending_input;
|
||||
use crate::hook_runtime::record_additional_contexts;
|
||||
use crate::hook_runtime::record_pending_input;
|
||||
use crate::session::session::Session;
|
||||
use crate::session::turn_context::TurnContext;
|
||||
use crate::session::turn_context::TurnStartTranscriptInput;
|
||||
use crate::state::ActiveTurn;
|
||||
use crate::state::RunningTask;
|
||||
use crate::state::TaskKind;
|
||||
@@ -160,6 +164,11 @@ pub(crate) trait SessionTask: Send + Sync + 'static {
|
||||
/// Returns the tracing name for a spawned task span.
|
||||
fn span_name(&self) -> &'static str;
|
||||
|
||||
/// Whether this task owns the normal user-turn transcript start sequence.
|
||||
fn records_turn_start_transcript(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Executes the task until completion or cancellation.
|
||||
///
|
||||
/// Implementations typically stream protocol events using `session` and
|
||||
@@ -197,6 +206,8 @@ pub(crate) trait AnySessionTask: Send + Sync + 'static {
|
||||
|
||||
fn span_name(&self) -> &'static str;
|
||||
|
||||
fn records_turn_start_transcript(&self) -> bool;
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
@@ -224,6 +235,10 @@ where
|
||||
SessionTask::span_name(self)
|
||||
}
|
||||
|
||||
fn records_turn_start_transcript(&self) -> bool {
|
||||
SessionTask::records_turn_start_transcript(self)
|
||||
}
|
||||
|
||||
fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
@@ -270,6 +285,7 @@ impl Session {
|
||||
let task: Arc<dyn AnySessionTask> = Arc::new(task);
|
||||
let task_kind = task.kind();
|
||||
let span_name = task.span_name();
|
||||
let records_turn_start_transcript = task.records_turn_start_transcript();
|
||||
let started_at = Instant::now();
|
||||
turn_context
|
||||
.turn_timing_state
|
||||
@@ -299,6 +315,13 @@ impl Session {
|
||||
}
|
||||
}
|
||||
|
||||
if records_turn_start_transcript && !input.is_empty() {
|
||||
turn_context
|
||||
.lock_turn_start_transcript_inputs()
|
||||
.push(TurnStartTranscriptInput {
|
||||
input: input.clone(),
|
||||
});
|
||||
}
|
||||
let mut active = self.active_turn.lock().await;
|
||||
let turn = active.get_or_insert_with(ActiveTurn::default);
|
||||
debug_assert!(turn.tasks.is_empty());
|
||||
@@ -358,6 +381,7 @@ impl Session {
|
||||
handle: Arc::new(AbortOnDropHandle::new(handle)),
|
||||
kind: task_kind,
|
||||
task,
|
||||
records_turn_start_transcript,
|
||||
cancellation_token,
|
||||
turn_context: Arc::clone(&turn_context),
|
||||
_timer: timer,
|
||||
@@ -618,15 +642,58 @@ impl Session {
|
||||
.cancel_git_enrichment_task();
|
||||
let session_task = task.task;
|
||||
|
||||
select! {
|
||||
_ = task.done.notified() => {
|
||||
},
|
||||
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
|
||||
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
|
||||
let mut handle_aborted = false;
|
||||
// The startup prompt queue is serialized by the regular turn itself. If the serializer is
|
||||
// idle, grab it and abort the task before recovering the queued prompt; if the serializer
|
||||
// is already active, wait for it to finish before force-aborting so history cannot be left
|
||||
// half-written. This can wait on turn-start hooks, but it keeps hook policy and hook-added
|
||||
// context ahead of the model-visible interrupt marker.
|
||||
if reason == TurnAbortReason::Interrupted && task.records_turn_start_transcript {
|
||||
let has_turn_start_transcript_input = {
|
||||
let inputs = task.turn_context.lock_turn_start_transcript_inputs();
|
||||
!inputs.is_empty()
|
||||
};
|
||||
if has_turn_start_transcript_input {
|
||||
match task
|
||||
.turn_context
|
||||
.transcript_serialization_lock
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
{
|
||||
Ok(permit) => {
|
||||
task.handle.abort();
|
||||
handle_aborted = true;
|
||||
let _ = drain_turn_start_transcript_inputs_with_permit(
|
||||
self,
|
||||
&task.turn_context,
|
||||
TurnStartTranscriptDrainMode::InterruptRecovery,
|
||||
permit,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(_) => {
|
||||
let _ = drain_turn_start_transcript_inputs(
|
||||
self,
|
||||
&task.turn_context,
|
||||
TurnStartTranscriptDrainMode::InterruptRecovery,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
task.handle.abort();
|
||||
if !handle_aborted {
|
||||
select! {
|
||||
_ = task.done.notified() => {
|
||||
},
|
||||
_ = tokio::time::sleep(Duration::from_millis(GRACEFULL_INTERRUPTION_TIMEOUT_MS)) => {
|
||||
warn!("task {sub_id} didn't complete gracefully after {}ms", GRACEFULL_INTERRUPTION_TIMEOUT_MS);
|
||||
}
|
||||
}
|
||||
|
||||
task.handle.abort();
|
||||
}
|
||||
|
||||
let session_ctx = Arc::new(SessionTaskContext::new(Arc::clone(self)));
|
||||
session_task
|
||||
|
||||
@@ -33,6 +33,10 @@ impl SessionTask for RegularTask {
|
||||
"session_task.turn"
|
||||
}
|
||||
|
||||
fn records_turn_start_transcript(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
session: Arc<SessionTaskContext>,
|
||||
|
||||
Reference in New Issue
Block a user