Compare commits

...

1 Commits

Author SHA1 Message Date
pakrym-oai
0e4d129379 ws 2026-01-08 08:20:41 -08:00
16 changed files with 969 additions and 1 deletions

37
codex-rs/Cargo.lock generated
View File

@@ -982,8 +982,10 @@ dependencies = [
"thiserror 2.0.17",
"tokio",
"tokio-test",
"tokio-tungstenite",
"tokio-util",
"tracing",
"url",
"wiremock",
]
@@ -2344,6 +2346,12 @@ dependencies = [
"syn 2.0.104",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "dbus"
version = "0.9.9"
@@ -7095,6 +7103,18 @@ dependencies = [
"tokio-stream",
]
[[package]]
name = "tokio-tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.7.16"
@@ -7489,6 +7509,23 @@ dependencies = [
"ratatui-core",
]
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http 1.3.1",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror 2.0.17",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.18.0"

View File

@@ -207,6 +207,7 @@ thiserror = "2.0.17"
time = "0.3"
tiny_http = "0.12"
tokio = "1"
tokio-tungstenite = "0.26.1"
tokio-stream = "0.1.18"
tokio-test = "0.4"
tokio-util = "0.7.16"

View File

@@ -15,7 +15,9 @@ serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] }
tokio-tungstenite = { workspace = true }
tracing = { workspace = true }
url = { workspace = true }
eventsource-stream = { workspace = true }
regex-lite = { workspace = true }
tokio-util = { workspace = true, features = ["codec"] }

View File

@@ -2,4 +2,5 @@ pub mod chat;
pub mod compact;
pub mod models;
pub mod responses;
pub mod responses_ws;
mod streaming;

View File

@@ -0,0 +1,708 @@
use crate::auth::AuthProvider;
use crate::common::Prompt as ApiPrompt;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::endpoint::responses::ResponsesOptions;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::requests::ResponsesRequestBuilder;
use codex_client::TransportError;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::TokenUsage;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use serde::Deserialize;
use serde_json::Value;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::sync::OwnedSemaphorePermit;
use tokio::sync::Semaphore;
use tokio::sync::mpsc;
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite;
use tokio_tungstenite::tungstenite::Message;
use tracing::debug;
use tracing::trace;
use url::Url;
/// Capacity of the bounded mpsc channels created in this file (raw frames and decoded events).
const WS_BUFFER: usize = 1600;
/// Client WebSocket stream over a TCP connection that may or may not be TLS-wrapped.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a split [`WsStream`].
type WsSender = futures::stream::SplitSink<WsStream, Message>;
/// Cloneable handle to a Responses WebSocket session.
///
/// Clones share the same connection, bookkeeping and turn gate through the inner `Arc`.
#[derive(Clone)]
pub struct ResponsesWsSession<A: AuthProvider + Clone> {
inner: Arc<ResponsesWsInner<A>>,
}
/// Shared state behind a [`ResponsesWsSession`].
struct ResponsesWsInner<A: AuthProvider + Clone> {
/// Provider used to build request bodies and the WebSocket URL.
provider: Provider,
/// Source of the bearer token and account id sent with the handshake.
auth: A,
/// Currently cached connection, if any; cleared by `reset`.
connection: Mutex<Option<Arc<ResponsesWsConnection>>>,
/// Incremental-send bookkeeping, shared with the per-turn reader task.
state: Arc<Mutex<WsSessionState>>,
/// Single-permit semaphore; a turn's permit is moved into its reader task and released when it exits.
turn_gate: Arc<Semaphore>,
}
/// Tracks how much of the conversation has already been sent on this session.
#[derive(Default)]
struct WsSessionState {
/// Number of leading prompt items treated as already sent; set to the prompt input length plus
/// the number of completed output items when a turn finishes, and reset to 0 on failure.
last_sent_len: usize,
/// When true, the next turn may be sent as `response.append` instead of `response.create`.
active: bool,
}
/// One open WebSocket connection: the write half plus a channel fed by a background reader task.
struct ResponsesWsConnection {
/// Write half of the socket, locked for each outgoing message.
sender: Mutex<WsSender>,
/// Text payloads (or a terminal error) forwarded by the reader task spawned in `connect`.
receiver: Mutex<mpsc::Receiver<Result<String, ApiError>>>,
}
impl<A: AuthProvider + Clone> ResponsesWsSession<A> {
    /// Creates a session with no open connection; the socket is opened lazily by the first
    /// call to [`stream_prompt`](Self::stream_prompt).
    pub fn new(provider: Provider, auth: A) -> Self {
        Self {
            inner: Arc::new(ResponsesWsInner {
                provider,
                auth,
                connection: Mutex::new(None),
                state: Arc::new(Mutex::new(WsSessionState::default())),
                turn_gate: Arc::new(Semaphore::new(1)),
            }),
        }
    }

    /// Drops the cached connection and clears the incremental-send bookkeeping, so the next
    /// turn reconnects and sends a full `response.create`.
    pub async fn reset(&self) {
        {
            let mut guard = self.inner.connection.lock().await;
            *guard = None;
        }
        let mut state = self.inner.state.lock().await;
        state.last_sent_len = 0;
        state.active = false;
    }

    /// Sends one turn over the WebSocket and returns a stream of its response events.
    ///
    /// The first turn (or any turn after a reset/failure, or whose input is shorter than what
    /// was already sent) is sent as `response.create` with the full request body; otherwise
    /// only the input items past `last_sent_len` are sent as `response.append`.
    ///
    /// # Errors
    ///
    /// Returns an error when the request cannot be built, the connection cannot be
    /// established, or the event cannot be written to the socket. On connection or send
    /// failure the session is reset before returning.
    pub async fn stream_prompt(
        &self,
        model: &str,
        prompt: &ApiPrompt,
        options: ResponsesOptions,
    ) -> Result<ResponseStream, ApiError> {
        let ResponsesOptions {
            reasoning,
            include,
            prompt_cache_key,
            text,
            store_override,
            conversation_id,
            session_source,
            extra_headers,
        } = options;
        let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
            .tools(&prompt.tools)
            .parallel_tool_calls(prompt.parallel_tool_calls)
            .reasoning(reasoning)
            .include(include)
            .prompt_cache_key(prompt_cache_key)
            .text(text)
            .conversation(conversation_id)
            .session_source(session_source)
            .store_override(store_override)
            .extra_headers(extra_headers)
            .build(&self.inner.provider)?;
        let input_len = prompt.input.len();

        // Take the turn gate BEFORE looking at the bookkeeping. The previous turn's reader task
        // holds the permit until after it has recorded `last_sent_len`, so waiting here
        // guarantees the create/append decision below is made against up-to-date state even if
        // the caller starts this turn the instant it sees the previous `Completed` event.
        let permit = self
            .inner
            .turn_gate
            .clone()
            .acquire_owned()
            .await
            .map_err(|_| ApiError::Stream("responses websocket closed".into()))?;

        let event = {
            let mut state = self.inner.state.lock().await;
            let start_fresh = !state.active || input_len < state.last_sent_len;
            let event = if start_fresh {
                build_create_event(request.body)?
            } else {
                let delta = prompt
                    .input
                    .get(state.last_sent_len..)
                    .unwrap_or_default()
                    .to_vec();
                build_append_event(delta)
            };
            // Only mutate the bookkeeping once the event has been built successfully.
            if start_fresh {
                state.last_sent_len = 0;
            }
            state.active = true;
            event
        };

        let connection = match self.ensure_connection(request.headers).await {
            Ok(connection) => connection,
            Err(err) => {
                // Nothing was sent: undo `active = true` so the next attempt starts with
                // `response.create` rather than appending to a response that never existed.
                self.reset().await;
                return Err(err);
            }
        };
        if let Err(err) = connection.send(&event).await {
            self.reset().await;
            return Err(err);
        }
        Ok(spawn_ws_response_stream(
            connection,
            self.inner.state.clone(),
            input_len,
            permit,
        ))
    }

    /// Returns the cached connection, opening (and caching) a new one when none exists.
    ///
    /// `extra_headers` are only used when a new connection has to be opened.
    async fn ensure_connection(
        &self,
        extra_headers: HeaderMap,
    ) -> Result<Arc<ResponsesWsConnection>, ApiError> {
        let existing = { self.inner.connection.lock().await.clone() };
        if let Some(connection) = existing {
            return Ok(connection);
        }
        // Connect without holding the lock so `reset` is never blocked on network I/O.
        let connection =
            ResponsesWsConnection::connect(&self.inner.provider, &self.inner.auth, extra_headers)
                .await?;
        let mut guard = self.inner.connection.lock().await;
        // If another caller registered a connection meanwhile, use that one so every caller
        // shares the connection that `reset` can actually drop.
        Ok(Arc::clone(guard.get_or_insert(Arc::new(connection))))
    }
}
impl ResponsesWsConnection {
async fn connect<A: AuthProvider>(
provider: &Provider,
auth: &A,
extra_headers: HeaderMap,
) -> Result<Self, ApiError> {
let url = ws_url(provider)?;
let headers = build_ws_headers(provider, auth, extra_headers);
let request = build_ws_request(url, headers)?;
let (stream, _response) = connect_async(request).await.map_err(map_ws_error)?;
let (sender, mut receiver) = stream.split();
let (tx, rx) = mpsc::channel(WS_BUFFER);
tokio::spawn(async move {
loop {
let message = receiver.next().await;
let message = match message {
Some(Ok(message)) => message,
Some(Err(err)) => {
let _ = tx
.send(Err(ApiError::Stream(format!("websocket error: {err}"))))
.await;
return;
}
None => {
let _ = tx
.send(Err(ApiError::Stream(
"websocket closed unexpectedly".into(),
)))
.await;
return;
}
};
match message {
Message::Text(text) => {
if tx.send(Ok(text.to_string())).await.is_err() {
return;
}
}
Message::Binary(bytes) => {
if let Ok(text) = String::from_utf8(bytes.to_vec())
&& tx.send(Ok(text)).await.is_err()
{
return;
}
}
Message::Close(_) => {
let _ = tx
.send(Err(ApiError::Stream("websocket closed".into())))
.await;
return;
}
Message::Ping(_) | Message::Pong(_) => {}
_ => {}
}
}
});
Ok(Self {
sender: Mutex::new(sender),
receiver: Mutex::new(rx),
})
}
async fn send(&self, payload: &Value) -> Result<(), ApiError> {
let text = serde_json::to_string(payload)
.map_err(|err| ApiError::Stream(format!("failed to encode ws payload: {err}")))?;
let mut sender = self.sender.lock().await;
sender
.send(Message::Text(text.into()))
.await
.map_err(|err| ApiError::Stream(format!("websocket send failed: {err}")))
}
}
/// Wraps a Responses request body in a `response.create` WebSocket event.
///
/// The `stream` and `background` keys are dropped from the body and the remaining keys are
/// merged into an object whose `type` is `response.create`. Fails when `body` is not a JSON
/// object.
fn build_create_event(body: Value) -> Result<Value, ApiError> {
    match body {
        Value::Object(mut fields) => {
            for key in ["stream", "background"] {
                fields.remove(key);
            }
            let mut event = serde_json::Map::new();
            event.insert("type".to_string(), Value::from("response.create"));
            event.extend(fields);
            Ok(Value::Object(event))
        }
        _ => Err(ApiError::Stream(
            "responses create body was not an object".into(),
        )),
    }
}
/// Builds a `response.append` WebSocket event whose `input` is the given items.
fn build_append_event(input: Vec<ResponseItem>) -> Value {
serde_json::json!({
"type": "response.append",
"input": input,
})
}
/// Derives the WebSocket URL for the provider's `responses` endpoint.
///
/// `https`/`wss` become `wss` and `http`/`ws` become `ws`; any other scheme is rejected.
fn ws_url(provider: &Provider) -> Result<Url, ApiError> {
    let raw = provider.url_for_path("responses");
    let mut endpoint = match Url::parse(&raw) {
        Ok(parsed) => parsed,
        Err(err) => return Err(ApiError::Stream(format!("invalid websocket url: {err}"))),
    };
    let target = match endpoint.scheme() {
        "https" | "wss" => "wss",
        "http" | "ws" => "ws",
        other => {
            return Err(ApiError::Stream(format!(
                "unsupported websocket scheme: {other}"
            )));
        }
    };
    // Only rewrite the scheme when it actually changes.
    if target != endpoint.scheme() && endpoint.set_scheme(target).is_err() {
        return Err(ApiError::Stream("failed to set websocket scheme".into()));
    }
    Ok(endpoint)
}
fn build_ws_headers<A: AuthProvider>(
provider: &Provider,
auth: &A,
extra_headers: HeaderMap,
) -> HeaderMap {
let mut headers = provider.headers.clone();
headers.extend(extra_headers);
if let Some(token) = auth.bearer_token()
&& let Ok(header) = format!("Bearer {token}").parse()
{
let _ = headers.insert(http::header::AUTHORIZATION, header);
}
if let Some(account_id) = auth.account_id()
&& let Ok(header) = HeaderValue::from_str(&account_id)
{
let _ = headers.insert("ChatGPT-Account-ID", header);
}
headers
}
fn build_ws_request(url: Url, headers: HeaderMap) -> Result<http::Request<()>, ApiError> {
let mut builder = http::Request::builder()
.method(http::Method::GET)
.uri(url.as_str());
for (name, value) in headers.iter() {
builder = builder.header(name, value);
}
builder
.body(())
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))
}
fn map_ws_error(err: tungstenite::Error) -> ApiError {
let transport = match err {
tungstenite::Error::Http(response) => TransportError::Http {
status: response.status(),
headers: Some(response.headers().clone()),
body: None,
},
tungstenite::Error::Url(err) => TransportError::Build(err.to_string()),
tungstenite::Error::Io(err) => TransportError::Network(err.to_string()),
other => TransportError::Network(other.to_string()),
};
ApiError::Transport(transport)
}
/// Clears the incremental-send bookkeeping so the next turn starts with `response.create`.
async fn invalidate_ws_state(state: &Mutex<WsSessionState>) {
    let mut guard = state.lock().await;
    guard.active = false;
    guard.last_sent_len = 0;
}

/// Forwards `item` to the consumer unless it has already hung up.
///
/// A failed send flips `can_send` to false so later events are dropped instead of retried; the
/// reader keeps running so the turn is still consumed from the socket.
async fn forward_ws_event(
    tx: &mpsc::Sender<Result<ResponseEvent, ApiError>>,
    can_send: &mut bool,
    item: Result<ResponseEvent, ApiError>,
) {
    if *can_send && tx.send(item).await.is_err() {
        *can_send = false;
    }
}

/// Spawns the per-turn reader task and returns the stream it feeds.
///
/// The task reads raw text events from `connection`, translates them into `ResponseEvent`s and
/// stops at `response.done` / `response.completed`, on a transport error, or when the
/// connection's channel closes. `permit` is held until the task exits, which is what keeps
/// turns on one session strictly sequential.
///
/// On successful completion `last_sent_len` becomes `input_len` plus the number of
/// `response.output_item.done` items seen; on any failure the bookkeeping is cleared.
fn spawn_ws_response_stream(
    connection: Arc<ResponsesWsConnection>,
    state: Arc<Mutex<WsSessionState>>,
    input_len: usize,
    permit: OwnedSemaphorePermit,
) -> ResponseStream {
    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(WS_BUFFER);
    tokio::spawn(async move {
        // Released when the task returns, i.e. only after the bookkeeping has been updated.
        let _permit = permit;
        let mut output_count: usize = 0;
        // Set after `response.failed`: keep reading until the terminal event but stay quiet.
        let mut draining = false;
        let mut can_send = true;
        let mut receiver = connection.receiver.lock().await;
        loop {
            let text = match receiver.recv().await {
                Some(Ok(text)) => text,
                Some(Err(err)) => {
                    if !draining {
                        forward_ws_event(&tx_event, &mut can_send, Err(err)).await;
                    }
                    invalidate_ws_state(&state).await;
                    return;
                }
                None => {
                    if !draining {
                        let closed =
                            ApiError::Stream("websocket closed while awaiting responses".into());
                        forward_ws_event(&tx_event, &mut can_send, Err(closed)).await;
                    }
                    invalidate_ws_state(&state).await;
                    return;
                }
            };

            trace!("WS event: {text}");
            let event: WsEvent = match serde_json::from_str(&text) {
                Ok(event) => event,
                Err(err) => {
                    debug!("Failed to parse WS event: {err}");
                    continue;
                }
            };

            match event.kind.as_str() {
                "response.output_item.done" => {
                    let Some(item_val) = event.item else {
                        continue;
                    };
                    let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
                        debug!("failed to parse ResponseItem from output_item.done");
                        continue;
                    };
                    // Counted even when the consumer is gone: the total feeds `last_sent_len`.
                    output_count = output_count.saturating_add(1);
                    let forwarded = Ok(ResponseEvent::OutputItemDone(item));
                    forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                }
                "response.output_item.added" => {
                    let Some(item_val) = event.item else {
                        continue;
                    };
                    let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
                        debug!("failed to parse ResponseItem from output_item.added");
                        continue;
                    };
                    let forwarded = Ok(ResponseEvent::OutputItemAdded(item));
                    forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                }
                "response.output_text.delta" => {
                    if let Some(delta) = event.delta {
                        let forwarded = Ok(ResponseEvent::OutputTextDelta(delta));
                        forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                    }
                }
                "response.reasoning_summary_text.delta" => {
                    if let (Some(delta), Some(summary_index)) = (event.delta, event.summary_index)
                    {
                        let forwarded = Ok(ResponseEvent::ReasoningSummaryDelta {
                            delta,
                            summary_index,
                        });
                        forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                    }
                }
                "response.reasoning_text.delta" => {
                    if let (Some(delta), Some(content_index)) = (event.delta, event.content_index)
                    {
                        let forwarded = Ok(ResponseEvent::ReasoningContentDelta {
                            delta,
                            content_index,
                        });
                        forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                    }
                }
                "response.reasoning_summary_part.added" => {
                    if let Some(summary_index) = event.summary_index {
                        let forwarded =
                            Ok(ResponseEvent::ReasoningSummaryPartAdded { summary_index });
                        forward_ws_event(&tx_event, &mut can_send, forwarded).await;
                    }
                }
                "response.created" => {
                    forward_ws_event(&tx_event, &mut can_send, Ok(ResponseEvent::Created {}))
                        .await;
                }
                "response.failed" => {
                    let error = map_failed_response(&event);
                    forward_ws_event(&tx_event, &mut can_send, Err(error)).await;
                    invalidate_ws_state(&state).await;
                    // NOTE(review): this keeps the turn gate held until a later
                    // `response.done`/`response.completed` (or socket close) arrives. If the
                    // server treats `response.failed` as terminal, the next turn would wait on
                    // the gate until the socket closes - confirm against the protocol.
                    draining = true;
                }
                "response.done" | "response.completed" => {
                    match completed_event(&event) {
                        Err(err) => {
                            forward_ws_event(&tx_event, &mut can_send, Err(err)).await;
                            invalidate_ws_state(&state).await;
                        }
                        Ok(completed) => {
                            if !draining {
                                // Record the new high-water mark BEFORE announcing completion:
                                // a consumer that starts the next turn as soon as it sees
                                // `Completed` must observe the updated `last_sent_len`.
                                {
                                    let mut guard = state.lock().await;
                                    guard.last_sent_len = input_len.saturating_add(output_count);
                                    guard.active = true;
                                }
                                forward_ws_event(&tx_event, &mut can_send, Ok(completed)).await;
                            }
                        }
                    }
                    return;
                }
                _ => {}
            }
        }
    });
    ResponseStream { rx_event }
}
/// Error object read from the `error` field of a `response.failed` event's `response`.
#[derive(Debug, Deserialize)]
// `type`, `plan_type` and `resets_at` are deserialized but not read anywhere in this file.
#[allow(dead_code)]
struct Error {
r#type: Option<String>,
/// Machine-readable code; drives the error classification in `map_failed_response`.
code: Option<String>,
/// Human-readable message; also searched for a "try again in ..." retry hint.
message: Option<String>,
plan_type: Option<String>,
resets_at: Option<i64>,
}
/// The `response` object of a `response.done` / `response.completed` event.
#[derive(Debug, Deserialize)]
struct ResponseCompleted {
/// Response identifier; required, so a payload without it fails to parse.
id: String,
/// Token accounting, when the server includes it.
#[serde(default)]
usage: Option<ResponseUsage>,
}
/// Token usage as reported on the wire; every field falls back to its default when absent.
#[derive(Debug, Deserialize, Clone)]
struct ResponseUsage {
#[serde(default)]
input_tokens: i64,
/// Optional breakdown of the input tokens (source of `cached_input_tokens`).
#[serde(default)]
input_tokens_details: Option<ResponseInputTokensDetails>,
#[serde(default)]
output_tokens: i64,
/// Optional breakdown of the output tokens (source of `reasoning_output_tokens`).
#[serde(default)]
output_tokens_details: Option<ResponseOutputTokensDetails>,
#[serde(default)]
total_tokens: i64,
}
impl From<ResponseUsage> for TokenUsage {
fn from(value: ResponseUsage) -> Self {
TokenUsage {
input_tokens: value.input_tokens,
cached_input_tokens: value
.input_tokens_details
.map(|d| d.cached_tokens)
.unwrap_or(0),
output_tokens: value.output_tokens,
reasoning_output_tokens: value
.output_tokens_details
.map(|d| d.reasoning_tokens)
.unwrap_or(0),
total_tokens: value.total_tokens,
}
}
}
/// Breakdown of input tokens inside a usage report.
#[derive(Debug, Deserialize, Clone)]
struct ResponseInputTokensDetails {
/// Copied into `TokenUsage::cached_input_tokens`.
cached_tokens: i64,
}
/// Breakdown of output tokens inside a usage report.
#[derive(Debug, Deserialize, Clone)]
struct ResponseOutputTokensDetails {
/// Copied into `TokenUsage::reasoning_output_tokens`.
reasoning_tokens: i64,
}
/// Loosely-typed view of one incoming WebSocket event; only the fields this file reads.
#[derive(Deserialize, Debug)]
struct WsEvent {
/// Event name from the JSON `type` field, e.g. `response.output_text.delta`.
#[serde(rename = "type")]
kind: String,
/// Raw `response` object; read for completed and failed events.
response: Option<Value>,
/// Raw item; read for `response.output_item.added` / `.done`.
item: Option<Value>,
/// Text fragment for the `*.delta` events.
delta: Option<String>,
summary_index: Option<i64>,
content_index: Option<i64>,
/// Top-level usage; used by `completed_event` only when `response` is absent.
#[serde(default)]
usage: Option<ResponseUsage>,
}
fn completed_event(event: &WsEvent) -> Result<ResponseEvent, ApiError> {
if let Some(response) = &event.response {
let completed =
serde_json::from_value::<ResponseCompleted>(response.clone()).map_err(|err| {
ApiError::Stream(format!("failed to parse response.completed: {err}"))
})?;
return Ok(ResponseEvent::Completed {
response_id: completed.id,
token_usage: completed.usage.map(Into::into),
});
}
if let Some(usage) = event.usage.clone() {
return Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: Some(usage.into()),
});
}
Ok(ResponseEvent::Completed {
response_id: String::new(),
token_usage: None,
})
}
fn map_failed_response(event: &WsEvent) -> ApiError {
let Some(resp_val) = event.response.clone() else {
return ApiError::Stream("response.failed event received".into());
};
let Some(error) = resp_val.get("error") else {
return ApiError::Stream("response.failed event received".into());
};
let Ok(error) = serde_json::from_value::<Error>(error.clone()) else {
return ApiError::Stream("response.failed event received".into());
};
if is_context_window_error(&error) {
ApiError::ContextWindowExceeded
} else if is_quota_exceeded_error(&error) {
ApiError::QuotaExceeded
} else if is_usage_not_included(&error) {
ApiError::UsageNotIncluded
} else {
let delay = try_parse_retry_after(&error);
let message = error.message.unwrap_or_default();
ApiError::Retryable { message, delay }
}
}
/// Extracts the retry delay from a `rate_limit_exceeded` error message.
///
/// Looks for "try again in <number><unit>" where the unit is `s`, `ms`, `second` or `seconds`.
/// Returns `None` for any other error code, when the message has no such hint, or when the
/// number cannot be represented as a `Duration`.
fn try_parse_retry_after(err: &Error) -> Option<std::time::Duration> {
    if err.code.as_deref() != Some("rate_limit_exceeded") {
        return None;
    }
    let message = err.message.as_deref()?;
    let captures = rate_limit_regex().captures(message)?;
    let value = captures.get(1)?.as_str().parse::<f64>().ok()?;
    let unit = captures.get(2)?.as_str().to_ascii_lowercase();

    if unit == "s" || unit.starts_with("second") {
        // `from_secs_f64` panics on overflow / non-finite input, and a long enough digit run
        // in a server-provided message parses to infinity. Use the fallible constructor.
        std::time::Duration::try_from_secs_f64(value).ok()
    } else if unit == "ms" {
        // Float-to-int `as` saturates, so this cannot panic; fractional ms are truncated.
        Some(std::time::Duration::from_millis(value as u64))
    } else {
        None
    }
}
/// True when the error code is `context_length_exceeded`.
fn is_context_window_error(error: &Error) -> bool {
    matches!(error.code.as_deref(), Some("context_length_exceeded"))
}
/// True when the error code is `insufficient_quota`.
fn is_quota_exceeded_error(error: &Error) -> bool {
    matches!(error.code.as_deref(), Some("insufficient_quota"))
}
/// True when the error code is `usage_not_included`.
fn is_usage_not_included(error: &Error) -> bool {
    matches!(error.code.as_deref(), Some("usage_not_included"))
}
/// Returns the lazily-compiled, case-insensitive pattern for "try again in <number><unit>".
///
/// Capture group 1 is the (possibly fractional) number, group 2 the unit (`s`, `ms`,
/// `second` or `seconds`).
fn rate_limit_regex() -> &'static regex_lite::Regex {
    static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
    // The pattern is a compile-time constant, so a failure here is a programming error.
    #[expect(clippy::unwrap_used)]
    RE.get_or_init(|| {
        // Raw string: escapes are written with a SINGLE backslash. Doubling them makes the
        // pattern look for a literal `\` and it would never match a real message.
        regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
    })
}

View File

@@ -25,6 +25,7 @@ pub use crate::endpoint::compact::CompactClient;
pub use crate::endpoint::models::ModelsClient;
pub use crate::endpoint::responses::ResponsesClient;
pub use crate::endpoint::responses::ResponsesOptions;
pub use crate::endpoint::responses_ws::ResponsesWsSession;
pub use crate::error::ApiError;
pub use crate::provider::Provider;
pub use crate::provider::WireApi;

View File

@@ -47,9 +47,11 @@ use crate::default_client::build_reqwest_client;
use crate::error::CodexErr;
use crate::error::Result;
use crate::features::FEATURES;
use crate::flags::CODEX_RS_RESPONSES_WS;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::responses_ws::ResponsesWsManager;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;
@@ -60,6 +62,7 @@ pub struct ModelClient {
model_info: ModelInfo,
otel_manager: OtelManager,
provider: ModelProviderInfo,
responses_ws: Option<Arc<ResponsesWsManager>>,
conversation_id: ThreadId,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
@@ -74,6 +77,7 @@ impl ModelClient {
model_info: ModelInfo,
otel_manager: OtelManager,
provider: ModelProviderInfo,
responses_ws: Option<Arc<ResponsesWsManager>>,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
conversation_id: ThreadId,
@@ -85,6 +89,7 @@ impl ModelClient {
model_info,
otel_manager,
provider,
responses_ws,
conversation_id,
effort,
summary,
@@ -115,7 +120,12 @@ impl ModelClient {
/// based on the `show_raw_agent_reasoning` flag in the config.
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
match self.provider.wire_api {
WireApi::Responses => self.stream_responses_api(prompt).await,
WireApi::Responses => {
if *CODEX_RS_RESPONSES_WS && let Some(manager) = self.responses_ws.as_ref() {
return self.stream_responses_ws(prompt, manager).await;
}
self.stream_responses_api(prompt).await
}
WireApi::Chat => {
let api_stream = self.stream_chat_completions(prompt).await?;
@@ -283,6 +293,108 @@ impl ModelClient {
}
}
/// Streams one turn through the Responses WebSocket `manager`.
///
/// When the SSE fixture flag is set the fixture is streamed instead. Otherwise the request
/// options are derived from the model info and config, and the prompt is sent through
/// `manager`. Every failure resets the manager; a 401 additionally goes through
/// `handle_unauthorized` and the request is retried, any other error is returned.
async fn stream_responses_ws(
    &self,
    prompt: &Prompt,
    manager: &Arc<ResponsesWsManager>,
) -> Result<ResponseStream> {
    if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
        warn!(path, "Streaming from fixture");
        let fixture = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout())
            .map_err(map_api_error)?;
        return Ok(map_response_stream(fixture, self.otel_manager.clone()));
    }

    let auth_manager = self.auth_manager.clone();
    let model_info = self.get_model_info();
    let instructions = prompt.get_full_instructions(&model_info).into_owned();
    let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;

    let default_reasoning_effort = model_info.default_reasoning_level;
    let reasoning = model_info.supports_reasoning_summaries.then(|| Reasoning {
        effort: self.effort.or(default_reasoning_effort),
        summary: (self.summary != ReasoningSummaryConfig::None).then_some(self.summary),
    });

    // Encrypted reasoning content is only requested when reasoning is enabled.
    let include: Vec<String> = match &reasoning {
        Some(_) => vec!["reasoning.encrypted_content".to_string()],
        None => Vec::new(),
    };

    let verbosity = match (model_info.support_verbosity, self.config.model_verbosity) {
        (true, configured) => configured.or(model_info.default_verbosity),
        (false, Some(_)) => {
            warn!(
                "model_verbosity is set but ignored as the model does not support verbosity: {}",
                model_info.slug
            );
            None
        }
        (false, None) => None,
    };

    let text = create_text_param_for_request(verbosity, &prompt.output_schema);
    let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json);
    let conversation_id = self.conversation_id.to_string();
    let session_source = self.session_source.clone();

    // Set by `handle_unauthorized` so a token refresh is attempted at most once.
    let mut refreshed = false;
    loop {
        let auth = auth_manager.as_ref().and_then(|m| m.auth());
        let api_provider = self
            .provider
            .to_api_provider(auth.as_ref().map(|a| a.mode))?;
        let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
        let options = ApiResponsesOptions {
            reasoning: reasoning.clone(),
            include: include.clone(),
            prompt_cache_key: Some(conversation_id.clone()),
            text: text.clone(),
            store_override: None,
            conversation_id: Some(conversation_id.clone()),
            session_source: Some(session_source.clone()),
            extra_headers: beta_feature_headers(&self.config),
        };

        let attempt = manager
            .stream_prompt(
                api_provider,
                api_auth,
                &self.get_model(),
                &api_prompt,
                options,
            )
            .await;
        let failure = match attempt {
            Ok(stream) => return Ok(map_response_stream(stream, self.otel_manager.clone())),
            Err(failure) => failure,
        };

        // Every failure path starts by discarding the WebSocket session.
        manager.reset().await;
        match failure {
            ApiError::Transport(TransportError::Http { status, .. })
                if status == StatusCode::UNAUTHORIZED =>
            {
                handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
            }
            other => return Err(map_api_error(other)),
        }
    }
}
/// Returns a clone of this client's provider configuration.
pub fn get_provider(&self) -> ModelProviderInfo {
self.provider.clone()
}

View File

@@ -19,9 +19,11 @@ use crate::compact_remote::run_inline_remote_auto_compact_task;
use crate::exec_policy::ExecPolicyManager;
use crate::features::Feature;
use crate::features::Features;
use crate::flags::CODEX_RS_RESPONSES_WS;
use crate::models_manager::manager::ModelsManager;
use crate::parse_command::parse_command;
use crate::parse_turn_item;
use crate::responses_ws::ResponsesWsManager;
use crate::stream_events_utils::HandleOutputCtx;
use crate::stream_events_utils::handle_non_tool_response_item;
use crate::stream_events_utils::handle_output_item_done;
@@ -506,6 +508,7 @@ impl Session {
auth_manager: Option<Arc<AuthManager>>,
otel_manager: &OtelManager,
provider: ModelProviderInfo,
responses_ws: Option<Arc<ResponsesWsManager>>,
session_configuration: &SessionConfiguration,
per_turn_config: Config,
model_info: ModelInfo,
@@ -524,6 +527,7 @@ impl Session {
model_info.clone(),
otel_manager,
provider,
responses_ws,
session_configuration.model_reasoning_effort,
session_configuration.model_reasoning_summary,
conversation_id,
@@ -676,6 +680,11 @@ impl Session {
.map(Arc::new);
}
let state = SessionState::new(session_configuration.clone());
let responses_ws = if *CODEX_RS_RESPONSES_WS {
Some(Arc::new(ResponsesWsManager::new()))
} else {
None
};
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
@@ -692,6 +701,7 @@ impl Session {
tool_approvals: Mutex::new(ApprovalStore::default()),
skills_manager,
agent_control,
responses_ws,
};
let sess = Arc::new(Session {
@@ -952,6 +962,7 @@ impl Session {
Some(Arc::clone(&self.services.auth_manager)),
&self.services.otel_manager,
session_configuration.provider.clone(),
self.services.responses_ws.clone(),
&session_configuration,
per_turn_config,
model_info,
@@ -2243,6 +2254,7 @@ async fn spawn_review_thread(
model_info.clone(),
otel_manager,
provider,
None,
per_turn_config.model_reasoning_effort,
per_turn_config.model_reasoning_summary,
sess.conversation_id,
@@ -3532,12 +3544,14 @@ mod tests {
tool_approvals: Mutex::new(ApprovalStore::default()),
skills_manager,
agent_control,
responses_ws: None,
};
let turn_context = Session::make_turn_context(
Some(Arc::clone(&auth_manager)),
&otel_manager,
session_configuration.provider.clone(),
None,
&session_configuration,
per_turn_config,
model_info,
@@ -3626,12 +3640,14 @@ mod tests {
tool_approvals: Mutex::new(ApprovalStore::default()),
skills_manager,
agent_control,
responses_ws: None,
};
let turn_context = Arc::new(Session::make_turn_context(
Some(Arc::clone(&auth_manager)),
&otel_manager,
session_configuration.provider.clone(),
None,
&session_configuration,
per_turn_config,
model_info,

View File

@@ -3,4 +3,5 @@ use env_flags::env_flags;
env_flags! {
/// Fixture path for offline tests (see client.rs).
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
/// When true, a session creates a Responses WebSocket manager and Responses-API turns are
/// streamed through it instead of the regular Responses request path (see client.rs).
pub CODEX_RS_RESPONSES_WS: bool = false;
}

View File

@@ -78,6 +78,7 @@ pub use auth::AuthManager;
pub use auth::CodexAuth;
pub mod default_client;
pub mod project_doc;
mod responses_ws;
mod rollout;
pub(crate) mod safety;
pub mod seatbelt;
@@ -134,5 +135,6 @@ pub use codex_protocol::models::LocalShellStatus;
pub use codex_protocol::models::ResponseItem;
pub use compact::content_items_to_text;
pub use event_mapping::parse_turn_item;
pub use responses_ws::ResponsesWsManager;
pub mod compact;
pub mod otel_init;

View File

@@ -0,0 +1,79 @@
use crate::api_bridge::CoreAuthProvider;
use codex_api::Prompt as ApiPrompt;
use codex_api::Provider;
use codex_api::ResponseStream;
use codex_api::ResponsesOptions;
use codex_api::ResponsesWsSession;
use codex_api::error::ApiError;
use tokio::sync::Mutex;
/// Owns at most one [`ResponsesWsSession`] and remembers which provider base URL it was
/// created for.
pub struct ResponsesWsManager {
/// Current session; created on first use by `stream_prompt` and cleared by `reset`.
session: Mutex<Option<ResponsesWsSession<CoreAuthProvider>>>,
/// `base_url` of the provider the current session was created with; a request for a
/// different URL triggers a reset.
base_url: Mutex<Option<String>>,
}
/// Prints only the type name; the session and URL are intentionally not rendered.
impl std::fmt::Debug for ResponsesWsManager {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less `debug_struct(..).finish()` writes just the name, as this does.
        formatter.write_str("ResponsesWsManager")
    }
}
impl ResponsesWsManager {
    /// Creates a manager with no session.
    pub(crate) fn new() -> Self {
        Self {
            session: Mutex::new(None),
            base_url: Mutex::new(None),
        }
    }

    /// Drops the current session (if any) and forgets its base URL; the next
    /// `stream_prompt` call creates a fresh session.
    pub(crate) async fn reset(&self) {
        {
            let mut guard = self.session.lock().await;
            *guard = None;
        }
        let mut base_url = self.base_url.lock().await;
        *base_url = None;
    }

    /// Streams `prompt` through the managed session, creating it when needed.
    ///
    /// The session is discarded first when `provider.base_url` differs from the URL the
    /// current session was created with. `provider` and `auth` are only used when a new
    /// session has to be created.
    ///
    /// # Errors
    ///
    /// Propagates the session's error; the manager is reset before returning it.
    pub(crate) async fn stream_prompt(
        &self,
        provider: Provider,
        auth: CoreAuthProvider,
        model: &str,
        prompt: &ApiPrompt,
        options: ResponsesOptions,
    ) -> Result<ResponseStream, ApiError> {
        let url_changed = self
            .base_url
            .lock()
            .await
            .as_ref()
            .is_some_and(|url| url != &provider.base_url);
        if url_changed {
            self.reset().await;
        }

        // Look up and (if necessary) create the session under ONE lock acquisition, so two
        // concurrent callers cannot each build a session with only one of them registered.
        let session = {
            let mut guard = self.session.lock().await;
            if let Some(existing) = guard.clone() {
                existing
            } else {
                let created_for = provider.base_url.clone();
                let created = ResponsesWsSession::new(provider, auth);
                *guard = Some(created.clone());
                *self.base_url.lock().await = Some(created_for);
                created
            }
        };

        let stream = session.stream_prompt(model, prompt, options).await;
        if stream.is_err() {
            self.reset().await;
        }
        stream
    }
}

View File

@@ -6,6 +6,7 @@ use crate::agent::AgentControl;
use crate::exec_policy::ExecPolicyManager;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::models_manager::manager::ModelsManager;
use crate::responses_ws::ResponsesWsManager;
use crate::skills::SkillsManager;
use crate::tools::sandboxing::ApprovalStore;
use crate::unified_exec::UnifiedExecProcessManager;
@@ -30,4 +31,5 @@ pub(crate) struct SessionServices {
pub(crate) tool_approvals: Mutex<ApprovalStore>,
pub(crate) skills_manager: Arc<SkillsManager>,
pub(crate) agent_control: AgentControl,
pub(crate) responses_ws: Option<Arc<ResponsesWsManager>>,
}

View File

@@ -94,6 +94,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,

View File

@@ -95,6 +95,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,

View File

@@ -87,6 +87,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,
@@ -182,6 +183,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,
@@ -275,6 +277,7 @@ async fn responses_respects_model_info_overrides_from_config() {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,

View File

@@ -1167,6 +1167,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
model_info,
otel_manager,
provider,
None,
effort,
summary,
conversation_id,