Unify realtime shutdown in core (#14902)

- route realtime startup, input, and transport failures through a single
shutdown path
- emit one realtime error/closed lifecycle while clearing session state
once

---------

Co-authored-by: Codex <noreply@openai.com>
Co-authored-by: Ahmed Ibrahim <219906144+aibrahim-oai@users.noreply.github.com>
This commit is contained in:
Ahmed Ibrahim
2026-03-17 15:58:52 -07:00
committed by GitHub
parent c6ab4ee537
commit 98be562fd3
3 changed files with 355 additions and 89 deletions

View File

@@ -190,7 +190,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
read_notification::<ThreadRealtimeClosedNotification>(&mut mcp, "thread/realtime/closed")
.await?;
assert_eq!(closed.thread_id, output_audio.thread_id);
assert_eq!(closed.reason.as_deref(), Some("transport_closed"));
assert_eq!(closed.reason.as_deref(), Some("error"));
let connections = realtime_server.connections();
assert_eq!(connections.len(), 1);

View File

@@ -56,6 +56,18 @@ const REALTIME_STARTUP_CONTEXT_TOKEN_BUDGET: usize = 5_000;
const ACTIVE_RESPONSE_CONFLICT_ERROR_PREFIX: &str =
"Conversation already has an active response in progress:";
/// Why a realtime conversation ended.
///
/// `send_realtime_conversation_closed` turns each variant into the `reason`
/// string carried by the `RealtimeConversationClosed` event.
#[derive(Debug)]
enum RealtimeConversationEnd {
// Reported as "requested"; passed by `handle_close`.
Requested,
// Reported as "transport_closed"; the value the fanout task starts with
// before it has seen any error event.
TransportClosed,
// Reported as "error"; selected by the fanout task once it has received a
// `RealtimeEvent::Error`.
Error,
}
/// How `stop_conversation_state` treats a registered fanout task.
enum RealtimeFanoutTaskStop {
// Abort the fanout task and await its handle.
Abort,
// Leave the fanout task alone; its handle is dropped without being aborted
// or awaited.
Detach,
}
/// Owner of the state for the current realtime conversation, if there is one.
pub(crate) struct RealtimeConversationManager {
// `None` while no conversation state is stored. Methods on this type obtain
// the guard with `.lock().await` and `take()` the state out when stopping it.
state: Mutex<Option<ConversationState>>,
}
@@ -120,7 +132,8 @@ struct ConversationState {
user_text_tx: Sender<String>,
writer: RealtimeWebsocketWriter,
handoff: RealtimeHandoffState,
task: JoinHandle<()>,
input_task: JoinHandle<()>,
fanout_task: Option<JoinHandle<()>>,
realtime_active: Arc<AtomicBool>,
}
@@ -150,9 +163,7 @@ impl RealtimeConversationManager {
guard.take()
};
if let Some(state) = previous_state {
state.realtime_active.store(false, Ordering::Relaxed);
state.task.abort();
let _ = state.task.await;
stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await;
}
let session_kind = match session_config.event_parser {
RealtimeEventParser::V1 => RealtimeSessionKind::V1,
@@ -199,12 +210,48 @@ impl RealtimeConversationManager {
user_text_tx,
writer,
handoff,
task,
input_task: task,
fanout_task: None,
realtime_active: Arc::clone(&realtime_active),
});
Ok((events_rx, realtime_active))
}
pub(crate) async fn register_fanout_task(
&self,
realtime_active: &Arc<AtomicBool>,
fanout_task: JoinHandle<()>,
) {
let mut fanout_task = Some(fanout_task);
{
let mut guard = self.state.lock().await;
if let Some(state) = guard.as_mut()
&& Arc::ptr_eq(&state.realtime_active, realtime_active)
{
state.fanout_task = fanout_task.take();
}
}
if let Some(fanout_task) = fanout_task {
fanout_task.abort();
let _ = fanout_task.await;
}
}
/// Tears down the stored conversation state, but only if it belongs to the
/// conversation identified by `realtime_active`.
///
/// Ownership is decided by `Arc::ptr_eq` on the state's `realtime_active`
/// flag. A matching state is removed from the manager while the lock is held
/// and then stopped with `RealtimeFanoutTaskStop::Detach`, so the fanout task
/// handle is neither aborted nor awaited.
pub(crate) async fn finish_if_active(&self, realtime_active: &Arc<AtomicBool>) {
    let finished = {
        let mut guard = self.state.lock().await;
        let is_current = guard
            .as_ref()
            .is_some_and(|state| Arc::ptr_eq(&state.realtime_active, realtime_active));
        if is_current { guard.take() } else { None }
    };

    let Some(state) = finished else {
        return;
    };
    stop_conversation_state(state, RealtimeFanoutTaskStop::Detach).await;
}
pub(crate) async fn audio_in(&self, frame: RealtimeAudioFrame) -> CodexResult<()> {
let sender = {
let guard = self.state.lock().await;
@@ -332,19 +379,78 @@ impl RealtimeConversationManager {
};
if let Some(state) = state {
state.realtime_active.store(false, Ordering::Relaxed);
state.task.abort();
let _ = state.task.await;
stop_conversation_state(state, RealtimeFanoutTaskStop::Abort).await;
}
Ok(())
}
}
/// Stops the tasks owned by a conversation state that has already been removed
/// from the manager.
///
/// The `realtime_active` flag is cleared first, then the input task is aborted
/// and awaited. A registered fanout task is aborted and awaited only for
/// `RealtimeFanoutTaskStop::Abort`; for `Detach` its handle is dropped as is.
async fn stop_conversation_state(
    mut state: ConversationState,
    fanout_task_stop: RealtimeFanoutTaskStop,
) {
    state.realtime_active.store(false, Ordering::Relaxed);

    state.input_task.abort();
    let _ = state.input_task.await;

    // Nothing more to do when no fanout task was ever registered.
    let Some(fanout_task) = state.fanout_task.take() else {
        return;
    };
    match fanout_task_stop {
        RealtimeFanoutTaskStop::Detach => {}
        RealtimeFanoutTaskStop::Abort => {
            fanout_task.abort();
            let _ = fanout_task.await;
        }
    }
}
pub(crate) async fn handle_start(
sess: &Arc<Session>,
sub_id: String,
params: ConversationStartParams,
) -> CodexResult<()> {
let prepared_start = match prepare_realtime_start(sess, params).await {
Ok(prepared_start) => prepared_start,
Err(err) => {
error!("failed to prepare realtime conversation: {err}");
let message = err.to_string();
sess.send_event_raw(Event {
id: sub_id,
msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
payload: RealtimeEvent::Error(message),
}),
})
.await;
return Ok(());
}
};
if let Err(err) = handle_start_inner(sess, &sub_id, prepared_start).await {
error!("failed to start realtime conversation: {err}");
let message = err.to_string();
sess.send_event_raw(Event {
id: sub_id.clone(),
msg: EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
payload: RealtimeEvent::Error(message),
}),
})
.await;
}
Ok(())
}
/// Everything `prepare_realtime_start` resolves before a realtime conversation
/// is opened; consumed by `handle_start_inner`.
struct PreparedRealtimeConversationStart {
// Passed to `sess.conversation.start` together with the headers and config.
api_provider: ApiProvider,
// Result of `realtime_request_headers` for the session id and API key.
extra_headers: Option<HeaderMap>,
// The caller-supplied session id, or the session's conversation id when the
// caller gave none; echoed in the `RealtimeConversationStarted` event.
requested_session_id: Option<String>,
// Echoed in the `RealtimeConversationStarted` event.
version: RealtimeWsVersion,
// Passed to `sess.conversation.start`.
session_config: RealtimeSessionConfig,
}
async fn prepare_realtime_start(
sess: &Arc<Session>,
params: ConversationStartParams,
) -> CodexResult<PreparedRealtimeConversationStart> {
let provider = sess.provider().await;
let auth = sess.services.auth_manager.auth().await;
let realtime_api_key = realtime_api_key(auth.as_ref(), &provider)?;
@@ -380,9 +486,7 @@ pub(crate) async fn handle_start(
RealtimeWsMode::Conversational => RealtimeSessionMode::Conversational,
RealtimeWsMode::Transcription => RealtimeSessionMode::Transcription,
};
let requested_session_id = params
.session_id
.or_else(|| Some(sess.conversation_id.to_string()));
let requested_session_id = params.session_id.or(Some(sess.conversation_id.to_string()));
let session_config = RealtimeSessionConfig {
instructions: prompt,
model,
@@ -392,24 +496,37 @@ pub(crate) async fn handle_start(
};
let extra_headers =
realtime_request_headers(requested_session_id.as_deref(), realtime_api_key.as_str())?;
Ok(PreparedRealtimeConversationStart {
api_provider,
extra_headers,
requested_session_id,
version,
session_config,
})
}
async fn handle_start_inner(
sess: &Arc<Session>,
sub_id: &str,
prepared_start: PreparedRealtimeConversationStart,
) -> CodexResult<()> {
let PreparedRealtimeConversationStart {
api_provider,
extra_headers,
requested_session_id,
version,
session_config,
} = prepared_start;
info!("starting realtime conversation");
let (events_rx, realtime_active) = match sess
let (events_rx, realtime_active) = sess
.conversation
.start(api_provider, extra_headers, session_config)
.await
{
Ok(events_rx) => events_rx,
Err(err) => {
error!("failed to start realtime conversation: {err}");
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::Other).await;
return Ok(());
}
};
.await?;
info!("realtime conversation started");
sess.send_event_raw(Event {
id: sub_id.clone(),
id: sub_id.to_string(),
msg: EventMsg::RealtimeConversationStarted(RealtimeConversationStartedEvent {
session_id: requested_session_id,
version,
@@ -418,12 +535,18 @@ pub(crate) async fn handle_start(
.await;
let sess_clone = Arc::clone(sess);
tokio::spawn(async move {
let sub_id = sub_id.to_string();
let fanout_realtime_active = Arc::clone(&realtime_active);
let fanout_task = tokio::spawn(async move {
let ev = |msg| Event {
id: sub_id.clone(),
msg,
};
let mut end = RealtimeConversationEnd::TransportClosed;
while let Ok(event) = events_rx.recv().await {
if !fanout_realtime_active.load(Ordering::Relaxed) {
break;
}
// if not audio out, log the event
if !matches!(event, RealtimeEvent::AudioOut(_)) {
info!(
@@ -431,6 +554,9 @@ pub(crate) async fn handle_start(
"received realtime conversation event"
);
}
if matches!(event, RealtimeEvent::Error(_)) {
end = RealtimeConversationEnd::Error;
}
let maybe_routed_text = match &event {
RealtimeEvent::HandoffRequested(handoff) => {
realtime_text_from_handoff_request(handoff)
@@ -442,6 +568,9 @@ pub(crate) async fn handle_start(
let sess_for_routed_text = Arc::clone(&sess_clone);
sess_for_routed_text.route_realtime_text_input(text).await;
}
if !fanout_realtime_active.load(Ordering::Relaxed) {
break;
}
sess_clone
.send_event_raw(ev(EventMsg::RealtimeConversationRealtime(
RealtimeConversationRealtimeEvent {
@@ -450,17 +579,20 @@ pub(crate) async fn handle_start(
)))
.await;
}
if realtime_active.swap(false, Ordering::Relaxed) {
info!("realtime conversation transport closed");
if fanout_realtime_active.swap(false, Ordering::Relaxed) {
if matches!(end, RealtimeConversationEnd::TransportClosed) {
info!("realtime conversation transport closed");
}
sess_clone
.send_event_raw(ev(EventMsg::RealtimeConversationClosed(
RealtimeConversationClosedEvent {
reason: Some("transport_closed".to_string()),
},
)))
.conversation
.finish_if_active(&fanout_realtime_active)
.await;
send_realtime_conversation_closed(&sess_clone, sub_id, end).await;
}
});
sess.conversation
.register_fanout_task(&realtime_active, fanout_task)
.await;
Ok(())
}
@@ -472,7 +604,12 @@ pub(crate) async fn handle_audio(
) {
if let Err(err) = sess.conversation.audio_in(params.frame).await {
error!("failed to append realtime audio: {err}");
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await;
if sess.conversation.running_state().await.is_some() {
warn!("realtime audio input failed while the session was already ending");
} else {
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest)
.await;
}
}
}
@@ -480,14 +617,12 @@ fn realtime_text_from_handoff_request(handoff: &RealtimeHandoffRequested) -> Opt
let active_transcript = handoff
.active_transcript
.iter()
.map(|entry| format!("{}: {}", entry.role, entry.text))
.map(|entry| format!("{role}: {text}", role = entry.role, text = entry.text))
.collect::<Vec<_>>()
.join("\n");
(!active_transcript.is_empty())
.then_some(active_transcript)
.or_else(|| {
(!handoff.input_transcript.is_empty()).then(|| handoff.input_transcript.clone())
})
.or((!handoff.input_transcript.is_empty()).then_some(handoff.input_transcript.clone()))
}
fn realtime_api_key(
@@ -547,25 +682,17 @@ pub(crate) async fn handle_text(
debug!(text = %params.text, "[realtime-text] appending realtime conversation text input");
if let Err(err) = sess.conversation.text_in(params.text).await {
error!("failed to append realtime text: {err}");
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest).await;
if sess.conversation.running_state().await.is_some() {
warn!("realtime text input failed while the session was already ending");
} else {
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::BadRequest)
.await;
}
}
}
pub(crate) async fn handle_close(sess: &Arc<Session>, sub_id: String) {
match sess.conversation.shutdown().await {
Ok(()) => {
sess.send_event_raw(Event {
id: sub_id,
msg: EventMsg::RealtimeConversationClosed(RealtimeConversationClosedEvent {
reason: Some("requested".to_string()),
}),
})
.await;
}
Err(err) => {
send_conversation_error(sess, sub_id, err.to_string(), CodexErrorInfo::Other).await;
}
}
end_realtime_conversation(sess, sub_id, RealtimeConversationEnd::Requested).await;
}
fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
@@ -593,6 +720,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
if let Err(err) = writer.send_conversation_item_create(text).await {
let mapped_error = map_api_error(err);
warn!("failed to send input text: {mapped_error}");
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
}
if matches!(session_kind, RealtimeSessionKind::V2) {
@@ -601,6 +731,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
} else if let Err(err) = writer.send_response_create().await {
let mapped_error = map_api_error(err);
warn!("failed to send text response.create: {mapped_error}");
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
} else {
pending_response_create = false;
@@ -625,6 +758,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
{
let mapped_error = map_api_error(err);
warn!("failed to send handoff output: {mapped_error}");
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
}
}
@@ -638,6 +774,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
{
let mapped_error = map_api_error(err);
warn!("failed to send handoff output: {mapped_error}");
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
}
if matches!(session_kind, RealtimeSessionKind::V2) {
@@ -648,6 +787,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
warn!(
"failed to send handoff response.create: {mapped_error}"
);
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
} else {
pending_response_create = false;
@@ -685,6 +827,11 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
warn!(
"failed to send deferred response.create: {mapped_error}"
);
let _ = events_tx
.send(RealtimeEvent::Error(
mapped_error.to_string(),
))
.await;
break;
}
pending_response_create = false;
@@ -732,6 +879,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
warn!(
"failed to send deferred response.create after cancellation: {mapped_error}"
);
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
}
pending_response_create = false;
@@ -773,11 +923,6 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
}
}
Ok(None) => {
let _ = events_tx
.send(RealtimeEvent::Error(
"realtime websocket connection is closed".to_string(),
))
.await;
break;
}
Err(err) => {
@@ -800,6 +945,9 @@ fn spawn_realtime_input_task(input: RealtimeInputTask) -> JoinHandle<()> {
if let Err(err) = writer.send_audio_frame(frame).await {
let mapped_error = map_api_error(err);
error!("failed to send input audio: {mapped_error}");
let _ = events_tx
.send(RealtimeEvent::Error(mapped_error.to_string()))
.await;
break;
}
}
@@ -839,7 +987,7 @@ fn update_output_audio_state(
fn audio_duration_ms(frame: &RealtimeAudioFrame) -> u32 {
let Some(samples_per_channel) = frame
.samples_per_channel
.or_else(|| decoded_samples_per_channel(frame))
.or(decoded_samples_per_channel(frame))
else {
return 0;
};
@@ -870,6 +1018,33 @@ async fn send_conversation_error(
.await;
}
async fn end_realtime_conversation(
sess: &Arc<Session>,
sub_id: String,
end: RealtimeConversationEnd,
) {
let _ = sess.conversation.shutdown().await;
send_realtime_conversation_closed(sess, sub_id, end).await;
}
async fn send_realtime_conversation_closed(
sess: &Arc<Session>,
sub_id: String,
end: RealtimeConversationEnd,
) {
let reason = match end {
RealtimeConversationEnd::Requested => Some("requested".to_string()),
RealtimeConversationEnd::TransportClosed => Some("transport_closed".to_string()),
RealtimeConversationEnd::Error => Some("error".to_string()),
};
sess.send_event_raw(Event {
id: sub_id,
msg: EventMsg::RealtimeConversationClosed(RealtimeConversationClosedEvent { reason }),
})
.await;
}
#[cfg(test)]
#[path = "realtime_conversation_tests.rs"]
mod tests;

View File

@@ -30,15 +30,17 @@ use core_test_support::wait_for_event_match;
use pretty_assertions::assert_eq;
use serde_json::Value;
use serde_json::json;
use serial_test::serial;
use std::ffi::OsString;
use std::fs;
use std::process::Command;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;
const STARTUP_CONTEXT_HEADER: &str = "Startup context from Codex.";
const MEMORY_PROMPT_PHRASE: &str =
"You have access to a memory folder with guidance from prior runs.";
const REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR: &str =
"CODEX_REALTIME_CONVERSATION_TEST_SUBPROCESS";
fn websocket_request_text(
request: &core_test_support::responses::WebSocketRequest,
) -> Option<String> {
@@ -82,6 +84,33 @@ where
tokio::time::sleep(Duration::from_millis(10)).await;
}
}
/// Re-runs the test named `test_name` in a child process of the current test
/// binary and fails if that child does not exit successfully.
///
/// The child is started with `--exact <test_name>` and with
/// `REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR` set to "1", which the tests
/// check to decide whether they are already running inside the subprocess.
/// `openai_api_key` controls the child's `OPENAI_API_KEY_ENV_VAR`: `Some` sets
/// it, `None` removes it from the child's environment.
///
/// # Errors
///
/// Fails if the current executable cannot be resolved or the child cannot be
/// run.
///
/// # Panics
///
/// Panics with the child's captured stdout and stderr when it exits with a
/// non-success status.
fn run_realtime_conversation_test_in_subprocess(
    test_name: &str,
    openai_api_key: Option<&str>,
) -> Result<()> {
    let mut child = Command::new(std::env::current_exe()?);
    child.args(["--exact", test_name]);
    child.env(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR, "1");

    if let Some(key) = openai_api_key {
        child.env(OPENAI_API_KEY_ENV_VAR, key);
    } else {
        child.env_remove(OPENAI_API_KEY_ENV_VAR);
    }

    let finished = child.output()?;
    if !finished.status.success() {
        panic!(
            "subprocess test `{test_name}` failed\nstdout:\n{}\nstderr:\n{}",
            String::from_utf8_lossy(&finished.stdout),
            String::from_utf8_lossy(&finished.stderr),
        );
    }
    Ok(())
}
async fn seed_recent_thread(
test: &TestCodex,
title: &str,
@@ -260,11 +289,16 @@ async fn conversation_start_audio_text_close_round_trip() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[serial(openai_api_key_env)]
async fn conversation_start_uses_openai_env_key_fallback_with_chatgpt_auth() -> Result<()> {
if std::env::var_os(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR).is_none() {
return run_realtime_conversation_test_in_subprocess(
"suite::realtime_conversation::conversation_start_uses_openai_env_key_fallback_with_chatgpt_auth",
Some("env-realtime-key"),
);
}
skip_if_no_network!(Ok(()));
let _env_guard = EnvGuard::set(OPENAI_API_KEY_ENV_VAR, "env-realtime-key");
let server = start_websocket_server(vec![
vec![],
vec![vec![json!({
@@ -369,34 +403,6 @@ async fn conversation_transport_close_emits_closed_event() -> Result<()> {
Ok(())
}
/// Test helper that overrides one environment variable and puts the previous
/// value back when the guard is dropped.
struct EnvGuard {
// Name of the overridden environment variable.
key: &'static str,
// Value the variable had before `set` ran; `None` means it was not set.
original: Option<OsString>,
}
impl EnvGuard {
/// Remembers the current value of `key`, then sets `key` to `value` for the
/// whole process.
fn set(key: &'static str, value: &str) -> Self {
let original = std::env::var_os(key);
// SAFETY: this guard restores the original value before the test exits.
// NOTE(review): restoring the value does not by itself make `set_var`
// sound; the documented requirement is that no other thread touches the
// environment concurrently. Confirm callers guarantee that.
unsafe {
std::env::set_var(key, value);
}
Self { key, original }
}
}
impl Drop for EnvGuard {
/// Restores the variable: sets it back to the saved value, or removes it when
/// it was originally unset.
fn drop(&mut self) {
// SAFETY: this guard restores the original value for the modified env var.
// NOTE(review): same caveat as in `set` — this relies on no concurrent
// environment access from other threads; confirm.
unsafe {
match &self.original {
Some(value) => std::env::set_var(self.key, value),
None => std::env::remove_var(self.key),
}
}
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn conversation_audio_before_start_emits_error() -> Result<()> {
skip_if_no_network!(Ok(()));
@@ -429,6 +435,91 @@ async fn conversation_audio_before_start_emits_error() -> Result<()> {
Ok(())
}
/// A start request that fails before any connection is attempted must produce
/// a realtime error event and no closed event.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn conversation_start_preflight_failure_emits_realtime_error_only() -> Result<()> {
    // Outside the subprocess, re-run this test in a child whose
    // OPENAI_API_KEY_ENV_VAR has been removed (`None` below).
    let in_subprocess = std::env::var_os(REALTIME_CONVERSATION_TEST_SUBPROCESS_ENV_VAR).is_some();
    if !in_subprocess {
        return run_realtime_conversation_test_in_subprocess(
            "suite::realtime_conversation::conversation_start_preflight_failure_emits_realtime_error_only",
            None,
        );
    }
    skip_if_no_network!(Ok(()));

    let server = start_websocket_server(vec![]).await;
    let mut builder = test_codex().with_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing());
    let test = builder.build_with_websocket_server(&server).await?;

    let start_op = Op::RealtimeConversationStart(ConversationStartParams {
        prompt: "backend prompt".to_string(),
        session_id: None,
    });
    test.codex.submit(start_op).await?;

    let error_message = wait_for_event_match(&test.codex, |msg| {
        if let EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
            payload: RealtimeEvent::Error(message),
        }) = msg
        {
            Some(message.clone())
        } else {
            None
        }
    })
    .await;
    assert_eq!(error_message, "realtime conversation requires API key auth");

    // No closed event may show up within the grace period.
    let closed_wait = timeout(
        Duration::from_millis(200),
        wait_for_event_match(&test.codex, |msg| {
            if let EventMsg::RealtimeConversationClosed(closed) = msg {
                Some(closed.clone())
            } else {
                None
            }
        }),
    )
    .await;
    assert!(closed_wait.is_err(), "preflight failure should not emit closed");

    server.shutdown().await;
    Ok(())
}
/// A start request whose realtime connection cannot be established must
/// produce a non-empty realtime error event and no closed event.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn conversation_start_connect_failure_emits_realtime_error_only() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let server = start_websocket_server(vec![]).await;
    // Point the realtime websocket base URL at port 1 on localhost instead of
    // the test server.
    let mut builder = test_codex().with_config(|config| {
        config.experimental_realtime_ws_base_url = Some("http://127.0.0.1:1".to_string());
    });
    let test = builder.build_with_websocket_server(&server).await?;

    let start_op = Op::RealtimeConversationStart(ConversationStartParams {
        prompt: "backend prompt".to_string(),
        session_id: None,
    });
    test.codex.submit(start_op).await?;

    let error_message = wait_for_event_match(&test.codex, |msg| {
        if let EventMsg::RealtimeConversationRealtime(RealtimeConversationRealtimeEvent {
            payload: RealtimeEvent::Error(message),
        }) = msg
        {
            Some(message.clone())
        } else {
            None
        }
    })
    .await;
    assert!(!error_message.is_empty());

    // No closed event may show up within the grace period.
    let closed_wait = timeout(
        Duration::from_millis(200),
        wait_for_event_match(&test.codex, |msg| {
            if let EventMsg::RealtimeConversationClosed(closed) = msg {
                Some(closed.clone())
            } else {
                None
            }
        }),
    )
    .await;
    assert!(closed_wait.is_err(), "connect failure should not emit closed");

    server.shutdown().await;
    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn conversation_text_before_start_emits_error() -> Result<()> {
skip_if_no_network!(Ok(()));