mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
codex: tighten realtime shutdown ordering (#14902)
Co-own the realtime fanout task so shutdown stops the full session before closed is emitted. Also keep start/connect failures as error-only until started is sent.

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
@@ -63,6 +63,11 @@ enum RealtimeConversationEnd {
|
||||
Error,
|
||||
}
|
||||
|
||||
/// Selects what `stop_conversation_state` does with a conversation's fanout
/// task handle once the input task has been stopped.
///
/// The common traits are derived so the mode can be logged, compared in
/// assertions, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RealtimeFanoutTaskStop {
    /// Abort the fanout task and wait for it to finish.
    Abort,
    /// Leave the fanout task alone; its handle is released without aborting.
    Detach,
}
|
||||
|
||||
/// Holds the state of the realtime conversation, if one is currently stored.
///
/// The state sits behind a mutex whose `lock()` is awaited by the methods of
/// this type, and is wrapped in an `Option` so it can be taken out (leaving
/// `None`) when a conversation is stopped or replaced.
pub(crate) struct RealtimeConversationManager {
|
||||
// `Some` while a conversation is stored; the methods visible here `take()` it before stopping its tasks.
state: Mutex<Option<ConversationState>>,
|
||||
}
|
||||
@@ -127,7 +132,8 @@ struct ConversationState {
|
||||
user_text_tx: Sender<String>,
|
||||
writer: RealtimeWebsocketWriter,
|
||||
handoff: RealtimeHandoffState,
|
||||
task: JoinHandle<()>,
|
||||
input_task: JoinHandle<()>,
|
||||
fanout_task: Option<JoinHandle<()>>,
|
||||
realtime_active: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
@@ -157,9 +163,7 @@ impl RealtimeConversationManager {
|
||||
guard.take()
|
||||
};
|
||||
if let Some(state) = previous_state {
|
||||
state.realtime_active.store(false, Ordering::Relaxed);
|
||||
state.task.abort();
|
||||
let _ = state.task.await;
|
||||
stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await;
|
||||
}
|
||||
let session_kind = match session_config.event_parser {
|
||||
RealtimeEventParser::V1 => RealtimeSessionKind::V1,
|
||||
@@ -206,12 +210,48 @@ impl RealtimeConversationManager {
|
||||
user_text_tx,
|
||||
writer,
|
||||
handoff,
|
||||
task,
|
||||
input_task: task,
|
||||
fanout_task: None,
|
||||
realtime_active: Arc::clone(&realtime_active),
|
||||
});
|
||||
Ok((events_rx, realtime_active))
|
||||
}
|
||||
|
||||
pub(crate) async fn register_fanout_task(
|
||||
&self,
|
||||
realtime_active: &Arc<AtomicBool>,
|
||||
fanout_task: JoinHandle<()>,
|
||||
) {
|
||||
let mut fanout_task = Some(fanout_task);
|
||||
{
|
||||
let mut guard = self.state.lock().await;
|
||||
if let Some(state) = guard.as_mut()
|
||||
&& Arc::ptr_eq(&state.realtime_active, realtime_active)
|
||||
{
|
||||
state.fanout_task = fanout_task.take();
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(fanout_task) = fanout_task {
|
||||
fanout_task.abort();
|
||||
let _ = fanout_task.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn finish_if_active(&self, realtime_active: &Arc<AtomicBool>) {
|
||||
let state = {
|
||||
let mut guard = self.state.lock().await;
|
||||
match guard.as_ref() {
|
||||
Some(state) if Arc::ptr_eq(&state.realtime_active, realtime_active) => guard.take(),
|
||||
_ => None,
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(state) = state {
|
||||
stop_conversation_state(state, RealtimeFanoutTaskStop::Detach).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn audio_in(&self, frame: RealtimeAudioFrame) -> CodexResult<()> {
|
||||
let sender = {
|
||||
let guard = self.state.lock().await;
|
||||
@@ -339,14 +379,32 @@ impl RealtimeConversationManager {
|
||||
};
|
||||
|
||||
if let Some(state) = state {
|
||||
state.realtime_active.store(false, Ordering::Relaxed);
|
||||
state.task.abort();
|
||||
let _ = state.task.await;
|
||||
stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn stop_conversation_state(
|
||||
mut state: ConversationState,
|
||||
fanout_task_stop: RealtimeFanoutTaskStop,
|
||||
) {
|
||||
state.realtime_active.store(false, Ordering::Relaxed);
|
||||
state.input_task.abort();
|
||||
let _ = state.input_task.await;
|
||||
|
||||
match state.fanout_task.take() {
|
||||
Some(fanout_task) => match fanout_task_stop {
|
||||
RealtimeFanoutTaskStop::Abort => {
|
||||
fanout_task.abort();
|
||||
let _ = fanout_task.await;
|
||||
}
|
||||
RealtimeFanoutTaskStop::Detach => {}
|
||||
},
|
||||
None => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn handle_start(
|
||||
sess: &Arc<Session>,
|
||||
sub_id: String,
|
||||
@@ -378,7 +436,6 @@ pub(crate) async fn handle_start(
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
end_realtime_conversation(sess, sub_id, RealtimeConversationEnd::Error).await;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -477,13 +534,17 @@ async fn handle_start_inner(
|
||||
|
||||
let sess_clone = Arc::clone(sess);
|
||||
let sub_id = sub_id.to_string();
|
||||
tokio::spawn(async move {
|
||||
let fanout_realtime_active = Arc::clone(&realtime_active);
|
||||
let fanout_task = tokio::spawn(async move {
|
||||
let ev = |msg| Event {
|
||||
id: sub_id.clone(),
|
||||
msg,
|
||||
};
|
||||
let mut end = RealtimeConversationEnd::TransportClosed;
|
||||
while let Ok(event) = events_rx.recv().await {
|
||||
if !fanout_realtime_active.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
// if not audio out, log the event
|
||||
if !matches!(event, RealtimeEvent::AudioOut(_)) {
|
||||
info!(
|
||||
@@ -505,6 +566,9 @@ async fn handle_start_inner(
|
||||
let sess_for_routed_text = Arc::clone(&sess_clone);
|
||||
sess_for_routed_text.route_realtime_text_input(text).await;
|
||||
}
|
||||
if !fanout_realtime_active.load(Ordering::Relaxed) {
|
||||
break;
|
||||
}
|
||||
sess_clone
|
||||
.send_event_raw(ev(EventMsg::RealtimeConversationRealtime(
|
||||
RealtimeConversationRealtimeEvent {
|
||||
@@ -513,13 +577,20 @@ async fn handle_start_inner(
|
||||
)))
|
||||
.await;
|
||||
}
|
||||
if realtime_active.swap(false, Ordering::Relaxed) {
|
||||
if fanout_realtime_active.swap(false, Ordering::Relaxed) {
|
||||
if matches!(end, RealtimeConversationEnd::TransportClosed) {
|
||||
info!("realtime conversation transport closed");
|
||||
}
|
||||
end_realtime_conversation(&sess_clone, sub_id, end).await;
|
||||
sess_clone
|
||||
.conversation
|
||||
.finish_if_active(&fanout_realtime_active)
|
||||
.await;
|
||||
send_realtime_conversation_closed(&sess_clone, sub_id, end).await;
|
||||
}
|
||||
});
|
||||
sess.conversation
|
||||
.register_fanout_task(&realtime_active, fanout_task)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -532,15 +603,7 @@ pub(crate) async fn handle_audio(
|
||||
if let Err(err) = sess.conversation.audio_in(params.frame).await {
|
||||
error!("failed to append realtime audio: {err}");
|
||||
if sess.conversation.running_state().await.is_some() {
|
||||
let message = err.to_string();
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::Error(message),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
end_realtime_conversation(sess, sub_id, RealtimeConversationEnd::Error).await;
|
||||
warn!("realtime audio input failed while the session was already ending");
|
||||
} else {
|
||||
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest)
|
||||
.await;
|
||||
@@ -620,15 +683,7 @@ pub(crate) async fn handle_text(
|
||||
if let Err(err) = sess.conversation.text_in(params.text).await {
|
||||
error!("failed to append realtime text: {err}");
|
||||
if sess.conversation.running_state().await.is_some() {
|
||||
let message = err.to_string();
|
||||
sess.send_event_raw(Event {
|
||||
id: sub_id.clone(),
|
||||
msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
|
||||
payload: RealtimeEvent::Error(message),
|
||||
}),
|
||||
})
|
||||
.await;
|
||||
end_realtime_conversation(sess, sub_id, RealtimeConversationEnd::Error).await;
|
||||
warn!("realtime text input failed while the session was already ending");
|
||||
} else {
|
||||
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest)
|
||||
.await;
|
||||
@@ -665,6 +720,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
if let Err(err) = writer.send_conversation_item_create(text).await {
|
||||
let mapped_error = map_api_error(err);
|
||||
warn!("failed to send input text: {mapped_error}");
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
if matches!(session_kind, RealtimeSessionKind::V2) {
|
||||
@@ -673,6 +731,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
} else if let Err(err) = writer.send_response_create().await {
|
||||
let mapped_error = map_api_error(err);
|
||||
warn!("failed to send text response.create: {mapped_error}");
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
} else {
|
||||
pending_response_create = false;
|
||||
@@ -697,6 +758,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
{
|
||||
let mapped_error = map_api_error(err);
|
||||
warn!("failed to send handoff output: {mapped_error}");
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -710,6 +774,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
{
|
||||
let mapped_error = map_api_error(err);
|
||||
warn!("failed to send handoff output: {mapped_error}");
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
if matches!(session_kind, RealtimeSessionKind::V2) {
|
||||
@@ -720,6 +787,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
warn!(
|
||||
"failed to send handoff response.create: {mapped_error}"
|
||||
);
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
} else {
|
||||
pending_response_create = false;
|
||||
@@ -757,6 +827,11 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
warn!(
|
||||
"failed to send deferred response.create: {mapped_error}"
|
||||
);
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(
|
||||
mapped_error.to_string(),
|
||||
))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
pending_response_create = false;
|
||||
@@ -804,6 +879,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
warn!(
|
||||
"failed to send deferred response.create after cancellation: {mapped_error}"
|
||||
);
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
pending_response_create = false;
|
||||
@@ -867,6 +945,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
|
||||
if let Err(err) = writer.send_audio_frame(frame).await {
|
||||
let mapped_error = map_api_error(err);
|
||||
error!("failed to send input audio: {mapped_error}");
|
||||
let _ = events_tx
|
||||
.send(RealtimeEvent::Error(mapped_error.to_string()))
|
||||
.await;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -943,7 +1024,14 @@ async fn end_realtime_conversation(
|
||||
end: RealtimeConversationEnd,
|
||||
) {
|
||||
let _ = sess.conversation.shutdown().await;
|
||||
send_realtime_conversation_closed(sess, sub_id, end).await;
|
||||
}
|
||||
|
||||
async fn send_realtime_conversation_closed(
|
||||
sess: &Arc<Session>,
|
||||
sub_id: String,
|
||||
end: RealtimeConversationEnd,
|
||||
) {
|
||||
let reason = match end {
|
||||
RealtimeConversationEnd::Requested => Some("requested".to_string()),
|
||||
RealtimeConversationEnd::TransportClosed => Some("transport_closed".to_string()),
|
||||
|
||||
@@ -12,7 +12,6 @@ use codex_protocol::protocol::ErrorEvent;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::RealtimeAudioFrame;
|
||||
use codex_protocol::protocol::RealtimeConversationClosedEvent;
|
||||
use codex_protocol::protocol::RealtimeConversationRealtimeEvent;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
@@ -480,7 +479,7 @@ async fn conversation_start_preflight_failure_emits_realtime_error_only() -> Res
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn conversation_start_connect_failure_emits_realtime_error_and_closed() -> Result<()> {
|
||||
async fn conversation_start_connect_failure_emits_realtime_error_only() -> Result<()> {
|
||||
skip_if_no_network!(Ok(()));
|
||||
|
||||
let server = start_websocket_server(vec![]).await;
|
||||
@@ -505,17 +504,15 @@ async fn conversation_start_connect_failure_emits_realtime_error_and_closed() ->
|
||||
.await;
|
||||
assert!(!err.is_empty());
|
||||
|
||||
let closed = wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()),
|
||||
_ => None,
|
||||
let closed = timeout(Duration::from_millis(200), async {
|
||||
wait_for_event_match(&test.codex, |msg| match msg {
|
||||
EventMsg::RealtimeConversationClosed(closed) => Some(closed.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.await
|
||||
})
|
||||
.await;
|
||||
assert_eq!(
|
||||
closed,
|
||||
RealtimeConversationClosedEvent {
|
||||
reason: Some("error".to_string()),
|
||||
}
|
||||
);
|
||||
assert!(closed.is_err(), "connect failure should not emit closed");
|
||||
|
||||
server.shutdown().await;
|
||||
Ok(())
|
||||
|
||||
Reference in New Issue
Block a user