Move streaming transport logic onto ModelClientSession

This commit is contained in:
pakrym-oai
2026-02-04 14:19:31 -08:00
parent 52c25b39b1
commit 14ac8471ee

View File

@@ -148,14 +148,6 @@ impl ModelClient {
}
}
fn disable_websockets(&self) -> bool {
self.state.disable_websockets.load(Ordering::Relaxed)
}
fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
websocket_enabled && !self.state.disable_websockets.swap(true, Ordering::Relaxed)
}
/// Compacts the current conversation history using the Compact endpoint.
///
/// This is a unary call (no streaming) that returns a new list of
@@ -258,8 +250,31 @@ impl ModelClient {
extra_headers
}
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
fn build_request_telemetry(otel_manager: &OtelManager) -> Arc<dyn RequestTelemetry> {
let telemetry = Arc::new(ApiTelemetry::new(otel_manager.clone()));
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry;
request_telemetry
}
}
impl ModelClientSession {
fn disable_websockets(&self) -> bool {
self.client.state.disable_websockets.load(Ordering::Relaxed)
}
fn activate_http_fallback(&self, websocket_enabled: bool) -> bool {
websocket_enabled
&& !self
.client
.state
.disable_websockets
.swap(true, Ordering::Relaxed)
}
fn responses_websocket_enabled(&self) -> bool {
self.state.provider.supports_websockets && self.state.enable_responses_websockets
self.client.state.provider.supports_websockets
&& self.client.state.enable_responses_websockets
}
fn build_responses_request(prompt: &Prompt) -> Result<ApiPrompt> {
@@ -277,7 +292,6 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
web_search_eligible: bool,
turn_metadata_header: Option<&str>,
turn_state: &Arc<OnceLock<String>>,
compression: Compression,
) -> ApiResponsesOptions {
let turn_metadata_header =
@@ -304,9 +318,12 @@ impl ModelClient {
};
let verbosity = if model_info.support_verbosity {
self.state.model_verbosity.or(model_info.default_verbosity)
self.client
.state
.model_verbosity
.or(model_info.default_verbosity)
} else {
if self.state.model_verbosity.is_some() {
if self.client.state.model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
model_info.slug
@@ -316,7 +333,7 @@ impl ModelClient {
};
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
let conversation_id = self.state.conversation_id.to_string();
let conversation_id = self.client.state.conversation_id.to_string();
ApiResponsesOptions {
reasoning,
@@ -325,28 +342,25 @@ impl ModelClient {
text,
store_override: None,
conversation_id: Some(conversation_id),
session_source: Some(self.state.session_source.clone()),
session_source: Some(self.client.state.session_source.clone()),
extra_headers: build_responses_headers(
self.state.beta_features_header.as_deref(),
self.client.state.beta_features_header.as_deref(),
web_search_eligible,
Some(turn_state),
Some(&self.turn_state),
turn_metadata_header.as_ref(),
),
compression,
turn_state: Some(Arc::clone(turn_state)),
turn_state: Some(Arc::clone(&self.turn_state)),
}
}
fn get_incremental_items(
websocket_last_items: &[ResponseItem],
input_items: &[ResponseItem],
) -> Option<Vec<ResponseItem>> {
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
// Checks whether the current request input is an incremental append to the previous request.
// If items in the new request contain all the items from the previous request we build
// a response.append request otherwise we start with a fresh response.create request.
let previous_len = websocket_last_items.len();
let previous_len = self.websocket_last_items.len();
let can_append = previous_len > 0
&& input_items.starts_with(websocket_last_items)
&& input_items.starts_with(&self.websocket_last_items)
&& previous_len < input_items.len();
if can_append {
Some(input_items[previous_len..].to_vec())
@@ -356,14 +370,12 @@ impl ModelClient {
}
fn prepare_websocket_request(
&self,
model_slug: &str,
websocket_last_items: &[ResponseItem],
api_prompt: &ApiPrompt,
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
if let Some(append_items) =
Self::get_incremental_items(websocket_last_items, &api_prompt.input)
{
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
@@ -398,14 +410,13 @@ impl ModelClient {
}
async fn ensure_websocket_connection(
&self,
&mut self,
otel_manager: &OtelManager,
connection: &mut Option<ApiWebSocketConnection>,
api_provider: codex_api::Provider,
api_auth: CoreAuthProvider,
options: &ApiResponsesOptions,
) -> std::result::Result<(), ApiError> {
let needs_new = match connection.as_ref() {
let needs_new = match self.connection.as_ref() {
Some(conn) => conn.is_closed().await,
None => true,
};
@@ -413,7 +424,7 @@ impl ModelClient {
if needs_new {
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
if self.state.include_timing_metrics {
if self.client.state.include_timing_metrics {
headers.insert(
X_RESPONSESAPI_INCLUDE_TIMING_METRICS_HEADER,
HeaderValue::from_static("true"),
@@ -428,16 +439,16 @@ impl ModelClient {
Some(websocket_telemetry),
)
.await?;
*connection = Some(new_conn);
self.connection = Some(new_conn);
}
Ok(())
}
fn responses_request_compression(&self, auth: Option<&crate::auth::CodexAuth>) -> Compression {
if self.state.enable_request_compression
if self.client.state.enable_request_compression
&& auth.is_some_and(CodexAuth::is_chatgpt_auth)
&& self.state.provider.is_openai()
&& self.client.state.provider.is_openai()
{
Compression::Zstd
} else {
@@ -459,17 +470,18 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
web_search_eligible: bool,
turn_metadata_header: Option<&str>,
turn_state: &Arc<OnceLock<String>>,
) -> Result<ResponseStream> {
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
warn!(path, "Streaming from fixture");
let stream =
codex_api::stream_from_fixture(path, self.state.provider.stream_idle_timeout())
.map_err(map_api_error)?;
let stream = codex_api::stream_from_fixture(
path,
self.client.state.provider.stream_idle_timeout(),
)
.map_err(map_api_error)?;
return Ok(map_response_stream(stream, otel_manager.clone()));
}
let auth_manager = self.state.auth_manager.clone();
let auth_manager = self.client.state.auth_manager.clone();
let api_prompt = Self::build_responses_request(prompt)?;
let mut auth_recovery = auth_manager
@@ -481,10 +493,11 @@ impl ModelClient {
None => None,
};
let api_provider = self
.client
.state
.provider
.to_api_provider(auth.as_ref().map(CodexAuth::internal_auth_mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.client.state.provider)?;
let transport = ReqwestTransport::new(build_reqwest_client());
let (request_telemetry, sse_telemetry) = Self::build_streaming_telemetry(otel_manager);
let compression = self.responses_request_compression(auth.as_ref());
@@ -499,7 +512,6 @@ impl ModelClient {
summary,
web_search_eligible,
turn_metadata_header,
turn_state,
compression,
);
@@ -525,7 +537,7 @@ impl ModelClient {
/// Streams a turn via the Responses API over WebSocket transport.
#[allow(clippy::too_many_arguments)]
async fn stream_responses_websocket(
&self,
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
otel_manager: &OtelManager,
@@ -533,11 +545,8 @@ impl ModelClient {
summary: ReasoningSummaryConfig,
web_search_eligible: bool,
turn_metadata_header: Option<&str>,
connection: &mut Option<ApiWebSocketConnection>,
websocket_last_items: &mut Vec<ResponseItem>,
turn_state: &Arc<OnceLock<String>>,
) -> Result<ResponseStream> {
let auth_manager = self.state.auth_manager.clone();
let auth_manager = self.client.state.auth_manager.clone();
let api_prompt = Self::build_responses_request(prompt)?;
let mut auth_recovery = auth_manager
@@ -549,10 +558,11 @@ impl ModelClient {
None => None,
};
let api_provider = self
.client
.state
.provider
.to_api_provider(auth.as_ref().map(CodexAuth::internal_auth_mode))?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.state.provider)?;
let api_auth = auth_provider_from_auth(auth.clone(), &self.client.state.provider)?;
let compression = self.responses_request_compression(auth.as_ref());
let options = self.build_responses_options(
@@ -562,20 +572,13 @@ impl ModelClient {
summary,
web_search_eligible,
turn_metadata_header,
turn_state,
compression,
);
let request = Self::prepare_websocket_request(
&model_info.slug,
websocket_last_items,
&api_prompt,
&options,
);
let request = self.prepare_websocket_request(&model_info.slug, &api_prompt, &options);
match self
.ensure_websocket_connection(
otel_manager,
connection,
api_provider.clone(),
api_auth.clone(),
&options,
@@ -592,7 +595,7 @@ impl ModelClient {
Err(err) => return Err(map_api_error(err)),
}
let connection = connection.as_ref().ok_or_else(|| {
let connection = self.connection.as_ref().ok_or_else(|| {
map_api_error(ApiError::Stream(
"websocket connection is unavailable".to_string(),
))
@@ -602,7 +605,7 @@ impl ModelClient {
.stream_request(request)
.await
.map_err(map_api_error)?;
*websocket_last_items = api_prompt.input.clone();
self.websocket_last_items = api_prompt.input.clone();
return Ok(map_response_stream(stream_result, otel_manager.clone()));
}
@@ -625,15 +628,6 @@ impl ModelClient {
websocket_telemetry
}
/// Builds request telemetry for unary API calls (e.g., Compact endpoint).
fn build_request_telemetry(otel_manager: &OtelManager) -> Arc<dyn RequestTelemetry> {
let telemetry = Arc::new(ApiTelemetry::new(otel_manager.clone()));
let request_telemetry: Arc<dyn RequestTelemetry> = telemetry;
request_telemetry
}
}
impl ModelClientSession {
#[allow(clippy::too_many_arguments)]
pub async fn stream(
&mut self,
@@ -644,72 +638,43 @@ impl ModelClientSession {
summary: ReasoningSummaryConfig,
web_search_eligible: bool,
turn_metadata_header: Option<&str>,
) -> Result<ResponseStream> {
self.stream_with_state(
prompt,
model_info,
otel_manager,
effort,
summary,
web_search_eligible,
turn_metadata_header,
)
.await
}
#[allow(clippy::too_many_arguments)]
async fn stream_with_state(
&mut self,
prompt: &Prompt,
model_info: &ModelInfo,
otel_manager: &OtelManager,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
web_search_eligible: bool,
turn_metadata_header: Option<&str>,
) -> Result<ResponseStream> {
let wire_api = self.client.state.provider.wire_api;
match wire_api {
WireApi::Responses => {
let websocket_enabled =
self.client.responses_websocket_enabled() && !self.client.disable_websockets();
self.responses_websocket_enabled() && !self.disable_websockets();
if websocket_enabled {
self.client
.stream_responses_websocket(
prompt,
model_info,
otel_manager,
effort,
summary,
web_search_eligible,
turn_metadata_header,
&mut self.connection,
&mut self.websocket_last_items,
&self.turn_state,
)
.await
self.stream_responses_websocket(
prompt,
model_info,
otel_manager,
effort,
summary,
web_search_eligible,
turn_metadata_header,
)
.await
} else {
self.client
.stream_responses_api(
prompt,
model_info,
otel_manager,
effort,
summary,
web_search_eligible,
turn_metadata_header,
&self.turn_state,
)
.await
self.stream_responses_api(
prompt,
model_info,
otel_manager,
effort,
summary,
web_search_eligible,
turn_metadata_header,
)
.await
}
}
}
}
pub(crate) fn try_switch_fallback_transport(&mut self, otel_manager: &OtelManager) -> bool {
let websocket_enabled = self.client.responses_websocket_enabled();
let activated = self.client.activate_http_fallback(websocket_enabled);
let websocket_enabled = self.responses_websocket_enabled();
let activated = self.activate_http_fallback(websocket_enabled);
if activated {
warn!("falling back to HTTP");
otel_manager.counter(