mirror of
https://github.com/openai/codex.git
synced 2026-02-02 23:13:37 +00:00
Compare commits
1 Commits
remove/doc
...
pakrym/ws
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e4d129379 |
37
codex-rs/Cargo.lock
generated
37
codex-rs/Cargo.lock
generated
@@ -982,8 +982,10 @@ dependencies = [
|
||||
"thiserror 2.0.17",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
"tokio-tungstenite",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"url",
|
||||
"wiremock",
|
||||
]
|
||||
|
||||
@@ -2344,6 +2346,12 @@ dependencies = [
|
||||
"syn 2.0.104",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
||||
|
||||
[[package]]
|
||||
name = "dbus"
|
||||
version = "0.9.9"
|
||||
@@ -7095,6 +7103,18 @@ dependencies = [
|
||||
"tokio-stream",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a9daff607c6d2bf6c16fd681ccb7eecc83e4e2cdc1ca067ffaadfca5de7f084"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.16"
|
||||
@@ -7489,6 +7509,23 @@ dependencies = [
|
||||
"ratatui-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.3.1",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
"sha1",
|
||||
"thiserror 2.0.17",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.18.0"
|
||||
|
||||
@@ -207,6 +207,7 @@ thiserror = "2.0.17"
|
||||
time = "0.3"
|
||||
tiny_http = "0.12"
|
||||
tokio = "1"
|
||||
tokio-tungstenite = "0.26.1"
|
||||
tokio-stream = "0.1.18"
|
||||
tokio-test = "0.4"
|
||||
tokio-util = "0.7.16"
|
||||
|
||||
@@ -15,7 +15,9 @@ serde = { workspace = true, features = ["derive"] }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = ["macros", "rt", "sync", "time"] }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
url = { workspace = true }
|
||||
eventsource-stream = { workspace = true }
|
||||
regex-lite = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["codec"] }
|
||||
|
||||
@@ -2,4 +2,5 @@ pub mod chat;
|
||||
pub mod compact;
|
||||
pub mod models;
|
||||
pub mod responses;
|
||||
pub mod responses_ws;
|
||||
mod streaming;
|
||||
|
||||
708
codex-rs/codex-api/src/endpoint/responses_ws.rs
Normal file
708
codex-rs/codex-api/src/endpoint/responses_ws.rs
Normal file
@@ -0,0 +1,708 @@
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::common::Prompt as ApiPrompt;
|
||||
use crate::common::ResponseEvent;
|
||||
use crate::common::ResponseStream;
|
||||
use crate::endpoint::responses::ResponsesOptions;
|
||||
use crate::error::ApiError;
|
||||
use crate::provider::Provider;
|
||||
use crate::requests::ResponsesRequestBuilder;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use futures::SinkExt;
|
||||
use futures::StreamExt;
|
||||
use http::HeaderMap;
|
||||
use http::HeaderValue;
|
||||
use serde::Deserialize;
|
||||
use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
use tokio_tungstenite::WebSocketStream;
|
||||
use tokio_tungstenite::connect_async;
|
||||
use tokio_tungstenite::tungstenite;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
use url::Url;
|
||||
|
||||
/// Capacity of the bounded channels used in this module: the one carrying raw text frames
/// from the socket reader task and the one carrying decoded response events to the caller.
const WS_BUFFER: usize = 1600;
|
||||
|
||||
/// Client WebSocket stream over a TCP connection that may or may not be TLS-wrapped.
type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
|
||||
/// Write half obtained by splitting a [`WsStream`].
type WsSender = futures::stream::SplitSink<WsStream, Message>;
|
||||
|
||||
/// Handle to a Responses API session carried over a WebSocket.
///
/// The handle only wraps an `Arc`, so clones share the same connection, session state
/// and turn gate.
#[derive(Clone)]
|
||||
pub struct ResponsesWsSession<A: AuthProvider + Clone> {
|
||||
// Shared session internals; every clone of the handle points at the same value.
inner: Arc<ResponsesWsInner<A>>,
|
||||
}
|
||||
|
||||
/// State shared by every clone of a [`ResponsesWsSession`].
struct ResponsesWsInner<A: AuthProvider + Clone> {
|
||||
// Provider used to build request bodies and to derive the WebSocket URL and base headers.
provider: Provider,
|
||||
// Source of the bearer token / account id attached to the handshake request.
auth: A,
|
||||
// Cached connection; `None` until the first successful connect and again after `reset`.
connection: Mutex<Option<Arc<ResponsesWsConnection>>>,
|
||||
// Bookkeeping shared with the spawned response task (see `WsSessionState`).
state: Arc<Mutex<WsSessionState>>,
|
||||
// Created with a single permit; the permit is handed to the response task for the
// duration of a turn.
turn_gate: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
/// Bookkeeping that decides whether the next request is sent as a full `response.create`
/// or as an incremental `response.append`.
#[derive(Default)]
|
||||
struct WsSessionState {
|
||||
// Index into the prompt input from which the next append delta starts. Set to
// `input_len + output_count` when a turn completes and to 0 on reset/failure.
last_sent_len: usize,
|
||||
// `false` forces the next request to be a `response.create`.
active: bool,
|
||||
}
|
||||
|
||||
/// One open WebSocket connection, split into a write half and a channel of incoming frames.
struct ResponsesWsConnection {
|
||||
// Write half of the socket; locked for the duration of each outgoing message.
sender: Mutex<WsSender>,
|
||||
// Text payloads (or a terminal error) forwarded by the background reader task
// spawned in `connect`.
receiver: Mutex<mpsc::Receiver<Result<String, ApiError>>>,
|
||||
}
|
||||
|
||||
impl<A: AuthProvider + Clone> ResponsesWsSession<A> {
|
||||
pub fn new(provider: Provider, auth: A) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(ResponsesWsInner {
|
||||
provider,
|
||||
auth,
|
||||
connection: Mutex::new(None),
|
||||
state: Arc::new(Mutex::new(WsSessionState::default())),
|
||||
turn_gate: Arc::new(Semaphore::new(1)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn reset(&self) {
|
||||
{
|
||||
let mut guard = self.inner.connection.lock().await;
|
||||
*guard = None;
|
||||
}
|
||||
let mut state = self.inner.state.lock().await;
|
||||
state.last_sent_len = 0;
|
||||
state.active = false;
|
||||
}
|
||||
|
||||
pub async fn stream_prompt(
|
||||
&self,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let ResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
} = options;
|
||||
|
||||
let request = ResponsesRequestBuilder::new(model, &prompt.instructions, &prompt.input)
|
||||
.tools(&prompt.tools)
|
||||
.parallel_tool_calls(prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.build(&self.inner.provider)?;
|
||||
|
||||
let input_len = prompt.input.len();
|
||||
let event = {
|
||||
let mut state = self.inner.state.lock().await;
|
||||
let should_reset = !state.active || input_len < state.last_sent_len;
|
||||
if should_reset {
|
||||
state.last_sent_len = 0;
|
||||
}
|
||||
state.active = true;
|
||||
if should_reset {
|
||||
build_create_event(request.body)?
|
||||
} else {
|
||||
let delta = prompt
|
||||
.input
|
||||
.get(state.last_sent_len..)
|
||||
.unwrap_or_default()
|
||||
.to_vec();
|
||||
build_append_event(delta)
|
||||
}
|
||||
};
|
||||
|
||||
let permit = self
|
||||
.inner
|
||||
.turn_gate
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.map_err(|_| ApiError::Stream("responses websocket closed".into()))?;
|
||||
|
||||
let connection = self.ensure_connection(request.headers).await?;
|
||||
if let Err(err) = connection.send(&event).await {
|
||||
self.reset().await;
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
Ok(spawn_ws_response_stream(
|
||||
connection,
|
||||
self.inner.state.clone(),
|
||||
input_len,
|
||||
permit,
|
||||
))
|
||||
}
|
||||
|
||||
async fn ensure_connection(
|
||||
&self,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Arc<ResponsesWsConnection>, ApiError> {
|
||||
let existing = { self.inner.connection.lock().await.clone() };
|
||||
if let Some(connection) = existing {
|
||||
return Ok(connection);
|
||||
}
|
||||
|
||||
let connection =
|
||||
ResponsesWsConnection::connect(&self.inner.provider, &self.inner.auth, extra_headers)
|
||||
.await?;
|
||||
let connection = Arc::new(connection);
|
||||
|
||||
let mut guard = self.inner.connection.lock().await;
|
||||
if guard.is_none() {
|
||||
*guard = Some(connection.clone());
|
||||
}
|
||||
Ok(connection)
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesWsConnection {
|
||||
async fn connect<A: AuthProvider>(
|
||||
provider: &Provider,
|
||||
auth: &A,
|
||||
extra_headers: HeaderMap,
|
||||
) -> Result<Self, ApiError> {
|
||||
let url = ws_url(provider)?;
|
||||
let headers = build_ws_headers(provider, auth, extra_headers);
|
||||
let request = build_ws_request(url, headers)?;
|
||||
let (stream, _response) = connect_async(request).await.map_err(map_ws_error)?;
|
||||
let (sender, mut receiver) = stream.split();
|
||||
let (tx, rx) = mpsc::channel(WS_BUFFER);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
let message = receiver.next().await;
|
||||
let message = match message {
|
||||
Some(Ok(message)) => message,
|
||||
Some(Err(err)) => {
|
||||
let _ = tx
|
||||
.send(Err(ApiError::Stream(format!("websocket error: {err}"))))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
None => {
|
||||
let _ = tx
|
||||
.send(Err(ApiError::Stream(
|
||||
"websocket closed unexpectedly".into(),
|
||||
)))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Message::Text(text) => {
|
||||
if tx.send(Ok(text.to_string())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Message::Binary(bytes) => {
|
||||
if let Ok(text) = String::from_utf8(bytes.to_vec())
|
||||
&& tx.send(Ok(text)).await.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
Message::Close(_) => {
|
||||
let _ = tx
|
||||
.send(Err(ApiError::Stream("websocket closed".into())))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
Message::Ping(_) | Message::Pong(_) => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Self {
|
||||
sender: Mutex::new(sender),
|
||||
receiver: Mutex::new(rx),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send(&self, payload: &Value) -> Result<(), ApiError> {
|
||||
let text = serde_json::to_string(payload)
|
||||
.map_err(|err| ApiError::Stream(format!("failed to encode ws payload: {err}")))?;
|
||||
let mut sender = self.sender.lock().await;
|
||||
sender
|
||||
.send(Message::Text(text.into()))
|
||||
.await
|
||||
.map_err(|err| ApiError::Stream(format!("websocket send failed: {err}")))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_create_event(body: Value) -> Result<Value, ApiError> {
|
||||
let Value::Object(mut payload) = body else {
|
||||
return Err(ApiError::Stream(
|
||||
"responses create body was not an object".into(),
|
||||
));
|
||||
};
|
||||
payload.remove("stream");
|
||||
payload.remove("background");
|
||||
let mut event = serde_json::Map::new();
|
||||
event.insert(
|
||||
"type".to_string(),
|
||||
Value::String("response.create".to_string()),
|
||||
);
|
||||
event.extend(payload);
|
||||
Ok(Value::Object(event))
|
||||
}
|
||||
|
||||
/// Builds a `response.append` event carrying only the input items that have not been
/// sent on this session yet.
fn build_append_event(input: Vec<ResponseItem>) -> Value {
|
||||
serde_json::json!({
|
||||
"type": "response.append",
|
||||
"input": input,
|
||||
})
|
||||
}
|
||||
|
||||
fn ws_url(provider: &Provider) -> Result<Url, ApiError> {
|
||||
let url = provider.url_for_path("responses");
|
||||
let mut url = Url::parse(&url)
|
||||
.map_err(|err| ApiError::Stream(format!("invalid websocket url: {err}")))?;
|
||||
let scheme = match url.scheme() {
|
||||
"https" => "wss",
|
||||
"http" => "ws",
|
||||
"wss" => "wss",
|
||||
"ws" => "ws",
|
||||
other => {
|
||||
return Err(ApiError::Stream(format!(
|
||||
"unsupported websocket scheme: {other}"
|
||||
)));
|
||||
}
|
||||
};
|
||||
if url.scheme() != scheme {
|
||||
url.set_scheme(scheme)
|
||||
.map_err(|_| ApiError::Stream("failed to set websocket scheme".into()))?;
|
||||
}
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
fn build_ws_headers<A: AuthProvider>(
|
||||
provider: &Provider,
|
||||
auth: &A,
|
||||
extra_headers: HeaderMap,
|
||||
) -> HeaderMap {
|
||||
let mut headers = provider.headers.clone();
|
||||
headers.extend(extra_headers);
|
||||
if let Some(token) = auth.bearer_token()
|
||||
&& let Ok(header) = format!("Bearer {token}").parse()
|
||||
{
|
||||
let _ = headers.insert(http::header::AUTHORIZATION, header);
|
||||
}
|
||||
if let Some(account_id) = auth.account_id()
|
||||
&& let Ok(header) = HeaderValue::from_str(&account_id)
|
||||
{
|
||||
let _ = headers.insert("ChatGPT-Account-ID", header);
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
fn build_ws_request(url: Url, headers: HeaderMap) -> Result<http::Request<()>, ApiError> {
|
||||
let mut builder = http::Request::builder()
|
||||
.method(http::Method::GET)
|
||||
.uri(url.as_str());
|
||||
for (name, value) in headers.iter() {
|
||||
builder = builder.header(name, value);
|
||||
}
|
||||
builder
|
||||
.body(())
|
||||
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))
|
||||
}
|
||||
|
||||
fn map_ws_error(err: tungstenite::Error) -> ApiError {
|
||||
let transport = match err {
|
||||
tungstenite::Error::Http(response) => TransportError::Http {
|
||||
status: response.status(),
|
||||
headers: Some(response.headers().clone()),
|
||||
body: None,
|
||||
},
|
||||
tungstenite::Error::Url(err) => TransportError::Build(err.to_string()),
|
||||
tungstenite::Error::Io(err) => TransportError::Network(err.to_string()),
|
||||
other => TransportError::Network(other.to_string()),
|
||||
};
|
||||
ApiError::Transport(transport)
|
||||
}
|
||||
|
||||
fn spawn_ws_response_stream(
|
||||
connection: Arc<ResponsesWsConnection>,
|
||||
state: Arc<Mutex<WsSessionState>>,
|
||||
input_len: usize,
|
||||
permit: OwnedSemaphorePermit,
|
||||
) -> ResponseStream {
|
||||
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(WS_BUFFER);
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit;
|
||||
let mut output_count: usize = 0;
|
||||
let mut draining = false;
|
||||
let mut can_send = true;
|
||||
let mut receiver = connection.receiver.lock().await;
|
||||
loop {
|
||||
let message = receiver.recv().await;
|
||||
let message = match message {
|
||||
Some(message) => message,
|
||||
None => {
|
||||
if can_send && !draining {
|
||||
let _ = tx_event
|
||||
.send(Err(ApiError::Stream(
|
||||
"websocket closed while awaiting responses".into(),
|
||||
)))
|
||||
.await;
|
||||
}
|
||||
let mut state = state.lock().await;
|
||||
state.active = false;
|
||||
state.last_sent_len = 0;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Ok(text) => {
|
||||
trace!("WS event: {text}");
|
||||
let event: WsEvent = match serde_json::from_str(&text) {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
debug!("Failed to parse WS event: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match event.kind.as_str() {
|
||||
"response.output_item.done" => {
|
||||
let Some(item_val) = event.item else {
|
||||
continue;
|
||||
};
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
output_count = output_count.saturating_add(1);
|
||||
if can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemDone(item)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.output_item.added" => {
|
||||
let Some(item_val) = event.item else {
|
||||
continue;
|
||||
};
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.added");
|
||||
continue;
|
||||
};
|
||||
if can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::OutputItemAdded(item)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.output_text.delta" => {
|
||||
if let Some(delta) = event.delta
|
||||
&& can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::OutputTextDelta(delta)))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_text.delta" => {
|
||||
if let (Some(delta), Some(summary_index)) =
|
||||
(event.delta, event.summary_index)
|
||||
&& can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningSummaryDelta {
|
||||
delta,
|
||||
summary_index,
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.reasoning_text.delta" => {
|
||||
if let (Some(delta), Some(content_index)) =
|
||||
(event.delta, event.content_index)
|
||||
&& can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningContentDelta {
|
||||
delta,
|
||||
content_index,
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.reasoning_summary_part.added" => {
|
||||
if let Some(summary_index) = event.summary_index
|
||||
&& can_send
|
||||
&& tx_event
|
||||
.send(Ok(ResponseEvent::ReasoningSummaryPartAdded {
|
||||
summary_index,
|
||||
}))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.created" => {
|
||||
if can_send
|
||||
&& tx_event.send(Ok(ResponseEvent::Created {})).await.is_err()
|
||||
{
|
||||
can_send = false;
|
||||
}
|
||||
}
|
||||
"response.failed" => {
|
||||
let error = map_failed_response(&event);
|
||||
if can_send && tx_event.send(Err(error)).await.is_err() {
|
||||
can_send = false;
|
||||
}
|
||||
let mut state = state.lock().await;
|
||||
state.active = false;
|
||||
state.last_sent_len = 0;
|
||||
draining = true;
|
||||
}
|
||||
"response.done" | "response.completed" => {
|
||||
let completed = match completed_event(&event) {
|
||||
Ok(event) => event,
|
||||
Err(err) => {
|
||||
if can_send {
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
let mut state = state.lock().await;
|
||||
state.active = false;
|
||||
state.last_sent_len = 0;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !draining {
|
||||
if can_send {
|
||||
let _ = tx_event.send(Ok(completed)).await;
|
||||
}
|
||||
let mut state = state.lock().await;
|
||||
state.last_sent_len = input_len.saturating_add(output_count);
|
||||
state.active = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
if can_send && !draining {
|
||||
let _ = tx_event.send(Err(err)).await;
|
||||
}
|
||||
let mut state = state.lock().await;
|
||||
state.active = false;
|
||||
state.last_sent_len = 0;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
ResponseStream { rx_event }
|
||||
}
|
||||
|
||||
/// Error object found under `response.error` of a `response.failed` event.
#[derive(Debug, Deserialize)]
|
||||
// NOTE(review): `r#type`, `plan_type` and `resets_at` are not read anywhere in this file,
// which is presumably why dead-code warnings are silenced — confirm before removing.
#[allow(dead_code)]
|
||||
struct Error {
|
||||
r#type: Option<String>,
|
||||
// Machine-readable code; compared against known values to classify the failure.
code: Option<String>,
|
||||
// Human-readable message; also scanned for a "try again in …" retry hint.
message: Option<String>,
|
||||
plan_type: Option<String>,
|
||||
resets_at: Option<i64>,
|
||||
}
|
||||
|
||||
/// Subset of the `response` object of a completion event that this module reads.
#[derive(Debug, Deserialize)]
|
||||
struct ResponseCompleted {
|
||||
// Response identifier, reported as `response_id` in `ResponseEvent::Completed`.
id: String,
|
||||
// Token accounting; `None` when the field is absent.
#[serde(default)]
|
||||
usage: Option<ResponseUsage>,
|
||||
}
|
||||
|
||||
/// Token usage as reported on the wire; converted into `TokenUsage` via `From`.
///
/// Every field is `#[serde(default)]`, so missing counters deserialize as `0` and missing
/// detail objects as `None`.
#[derive(Debug, Deserialize, Clone)]
|
||||
struct ResponseUsage {
|
||||
#[serde(default)]
|
||||
input_tokens: i64,
|
||||
#[serde(default)]
|
||||
input_tokens_details: Option<ResponseInputTokensDetails>,
|
||||
#[serde(default)]
|
||||
output_tokens: i64,
|
||||
#[serde(default)]
|
||||
output_tokens_details: Option<ResponseOutputTokensDetails>,
|
||||
#[serde(default)]
|
||||
total_tokens: i64,
|
||||
}
|
||||
|
||||
impl From<ResponseUsage> for TokenUsage {
|
||||
fn from(value: ResponseUsage) -> Self {
|
||||
TokenUsage {
|
||||
input_tokens: value.input_tokens,
|
||||
cached_input_tokens: value
|
||||
.input_tokens_details
|
||||
.map(|d| d.cached_tokens)
|
||||
.unwrap_or(0),
|
||||
output_tokens: value.output_tokens,
|
||||
reasoning_output_tokens: value
|
||||
.output_tokens_details
|
||||
.map(|d| d.reasoning_tokens)
|
||||
.unwrap_or(0),
|
||||
total_tokens: value.total_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Breakdown of input tokens; only the cached count is read.
#[derive(Debug, Deserialize, Clone)]
|
||||
struct ResponseInputTokensDetails {
|
||||
// Copied into `TokenUsage::cached_input_tokens`.
cached_tokens: i64,
|
||||
}
|
||||
|
||||
/// Breakdown of output tokens; only the reasoning count is read.
#[derive(Debug, Deserialize, Clone)]
|
||||
struct ResponseOutputTokensDetails {
|
||||
// Copied into `TokenUsage::reasoning_output_tokens`.
reasoning_tokens: i64,
|
||||
}
|
||||
|
||||
/// Loosely typed view of one incoming WebSocket event; which optional fields are present
/// depends on the event `type`.
#[derive(Deserialize, Debug)]
|
||||
struct WsEvent {
|
||||
// Event name (JSON key `type`), e.g. `response.output_item.done`.
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
// Raw `response` object; parsed for completion and failure events.
response: Option<Value>,
|
||||
// Raw output item; parsed as a `ResponseItem` for `output_item.*` events.
item: Option<Value>,
|
||||
// Text fragment for the `*.delta` events.
delta: Option<String>,
|
||||
summary_index: Option<i64>,
|
||||
content_index: Option<i64>,
|
||||
// Top-level usage; consulted for completion events that carry no `response` object.
#[serde(default)]
|
||||
usage: Option<ResponseUsage>,
|
||||
}
|
||||
|
||||
fn completed_event(event: &WsEvent) -> Result<ResponseEvent, ApiError> {
|
||||
if let Some(response) = &event.response {
|
||||
let completed =
|
||||
serde_json::from_value::<ResponseCompleted>(response.clone()).map_err(|err| {
|
||||
ApiError::Stream(format!("failed to parse response.completed: {err}"))
|
||||
})?;
|
||||
return Ok(ResponseEvent::Completed {
|
||||
response_id: completed.id,
|
||||
token_usage: completed.usage.map(Into::into),
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(usage) = event.usage.clone() {
|
||||
return Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: Some(usage.into()),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ResponseEvent::Completed {
|
||||
response_id: String::new(),
|
||||
token_usage: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn map_failed_response(event: &WsEvent) -> ApiError {
|
||||
let Some(resp_val) = event.response.clone() else {
|
||||
return ApiError::Stream("response.failed event received".into());
|
||||
};
|
||||
|
||||
let Some(error) = resp_val.get("error") else {
|
||||
return ApiError::Stream("response.failed event received".into());
|
||||
};
|
||||
|
||||
let Ok(error) = serde_json::from_value::<Error>(error.clone()) else {
|
||||
return ApiError::Stream("response.failed event received".into());
|
||||
};
|
||||
|
||||
if is_context_window_error(&error) {
|
||||
ApiError::ContextWindowExceeded
|
||||
} else if is_quota_exceeded_error(&error) {
|
||||
ApiError::QuotaExceeded
|
||||
} else if is_usage_not_included(&error) {
|
||||
ApiError::UsageNotIncluded
|
||||
} else {
|
||||
let delay = try_parse_retry_after(&error);
|
||||
let message = error.message.unwrap_or_default();
|
||||
ApiError::Retryable { message, delay }
|
||||
}
|
||||
}
|
||||
|
||||
fn try_parse_retry_after(err: &Error) -> Option<std::time::Duration> {
|
||||
if err.code.as_deref() != Some("rate_limit_exceeded") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let re = rate_limit_regex();
|
||||
if let Some(message) = &err.message
|
||||
&& let Some(captures) = re.captures(message)
|
||||
{
|
||||
let seconds = captures.get(1);
|
||||
let unit = captures.get(2);
|
||||
|
||||
if let (Some(value), Some(unit)) = (seconds, unit) {
|
||||
let value = value.as_str().parse::<f64>().ok()?;
|
||||
let unit = unit.as_str().to_ascii_lowercase();
|
||||
|
||||
if unit == "s" || unit.starts_with("second") {
|
||||
return Some(std::time::Duration::from_secs_f64(value));
|
||||
} else if unit == "ms" {
|
||||
return Some(std::time::Duration::from_millis(value as u64));
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_context_window_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("context_length_exceeded")
|
||||
}
|
||||
|
||||
fn is_quota_exceeded_error(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("insufficient_quota")
|
||||
}
|
||||
|
||||
fn is_usage_not_included(error: &Error) -> bool {
|
||||
error.code.as_deref() == Some("usage_not_included")
|
||||
}
|
||||
|
||||
/// Returns the lazily compiled, case-insensitive pattern that matches
/// "try again in <number><unit>" in rate-limit messages.
///
/// Capture group 1 is the (optionally fractional) number, group 2 the unit
/// (`s`, `ms`, `second` or `seconds`).
fn rate_limit_regex() -> &'static regex_lite::Regex {
    static RE: std::sync::OnceLock<regex_lite::Regex> = std::sync::OnceLock::new();
    // The pattern is a compile-time constant, so a failure to compile is a programming
    // error. Note the single backslashes: inside a raw string, `\\s` would match a literal
    // backslash followed by `s` and the pattern would never match real messages.
    #[expect(clippy::unwrap_used)]
    RE.get_or_init(|| {
        regex_lite::Regex::new(r"(?i)try again in\s*(\d+(?:\.\d+)?)\s*(s|ms|seconds?)").unwrap()
    })
}
|
||||
@@ -25,6 +25,7 @@ pub use crate::endpoint::compact::CompactClient;
|
||||
pub use crate::endpoint::models::ModelsClient;
|
||||
pub use crate::endpoint::responses::ResponsesClient;
|
||||
pub use crate::endpoint::responses::ResponsesOptions;
|
||||
pub use crate::endpoint::responses_ws::ResponsesWsSession;
|
||||
pub use crate::error::ApiError;
|
||||
pub use crate::provider::Provider;
|
||||
pub use crate::provider::WireApi;
|
||||
|
||||
@@ -47,9 +47,11 @@ use crate::default_client::build_reqwest_client;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result;
|
||||
use crate::features::FEATURES;
|
||||
use crate::flags::CODEX_RS_RESPONSES_WS;
|
||||
use crate::flags::CODEX_RS_SSE_FIXTURE;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use crate::responses_ws::ResponsesWsManager;
|
||||
use crate::tools::spec::create_tools_json_for_chat_completions_api;
|
||||
use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
|
||||
@@ -60,6 +62,7 @@ pub struct ModelClient {
|
||||
model_info: ModelInfo,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
responses_ws: Option<Arc<ResponsesWsManager>>,
|
||||
conversation_id: ThreadId,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
@@ -74,6 +77,7 @@ impl ModelClient {
|
||||
model_info: ModelInfo,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
responses_ws: Option<Arc<ResponsesWsManager>>,
|
||||
effort: Option<ReasoningEffortConfig>,
|
||||
summary: ReasoningSummaryConfig,
|
||||
conversation_id: ThreadId,
|
||||
@@ -85,6 +89,7 @@ impl ModelClient {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
responses_ws,
|
||||
conversation_id,
|
||||
effort,
|
||||
summary,
|
||||
@@ -115,7 +120,12 @@ impl ModelClient {
|
||||
/// based on the `show_raw_agent_reasoning` flag in the config.
|
||||
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
match self.provider.wire_api {
|
||||
WireApi::Responses => self.stream_responses_api(prompt).await,
|
||||
WireApi::Responses => {
|
||||
if *CODEX_RS_RESPONSES_WS && let Some(manager) = self.responses_ws.as_ref() {
|
||||
return self.stream_responses_ws(prompt, manager).await;
|
||||
}
|
||||
self.stream_responses_api(prompt).await
|
||||
}
|
||||
WireApi::Chat => {
|
||||
let api_stream = self.stream_chat_completions(prompt).await?;
|
||||
|
||||
@@ -283,6 +293,108 @@ impl ModelClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fn stream_responses_ws(
|
||||
&self,
|
||||
prompt: &Prompt,
|
||||
manager: &Arc<ResponsesWsManager>,
|
||||
) -> Result<ResponseStream> {
|
||||
if let Some(path) = &*CODEX_RS_SSE_FIXTURE {
|
||||
warn!(path, "Streaming from fixture");
|
||||
let stream = codex_api::stream_from_fixture(path, self.provider.stream_idle_timeout())
|
||||
.map_err(map_api_error)?;
|
||||
return Ok(map_response_stream(stream, self.otel_manager.clone()));
|
||||
}
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_info = self.get_model_info();
|
||||
let instructions = prompt.get_full_instructions(&model_info).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
||||
let default_reasoning_effort = model_info.default_reasoning_level;
|
||||
let reasoning = if model_info.supports_reasoning_summaries {
|
||||
Some(Reasoning {
|
||||
effort: self.effort.or(default_reasoning_effort),
|
||||
summary: if self.summary == ReasoningSummaryConfig::None {
|
||||
None
|
||||
} else {
|
||||
Some(self.summary)
|
||||
},
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let include: Vec<String> = if reasoning.is_some() {
|
||||
vec!["reasoning.encrypted_content".to_string()]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
let verbosity = if model_info.support_verbosity {
|
||||
self.config.model_verbosity.or(model_info.default_verbosity)
|
||||
} else {
|
||||
if self.config.model_verbosity.is_some() {
|
||||
warn!(
|
||||
"model_verbosity is set but ignored as the model does not support verbosity: {}",
|
||||
model_info.slug
|
||||
);
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
|
||||
let api_prompt = build_api_prompt(prompt, instructions.clone(), tools_json);
|
||||
let conversation_id = self.conversation_id.to_string();
|
||||
let session_source = self.session_source.clone();
|
||||
|
||||
let mut refreshed = false;
|
||||
loop {
|
||||
let auth = auth_manager.as_ref().and_then(|m| m.auth());
|
||||
let api_provider = self
|
||||
.provider
|
||||
.to_api_provider(auth.as_ref().map(|a| a.mode))?;
|
||||
let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?;
|
||||
|
||||
let options = ApiResponsesOptions {
|
||||
reasoning: reasoning.clone(),
|
||||
include: include.clone(),
|
||||
prompt_cache_key: Some(conversation_id.clone()),
|
||||
text: text.clone(),
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config),
|
||||
};
|
||||
|
||||
let stream_result = manager
|
||||
.stream_prompt(
|
||||
api_provider,
|
||||
api_auth,
|
||||
&self.get_model(),
|
||||
&api_prompt,
|
||||
options,
|
||||
)
|
||||
.await;
|
||||
|
||||
match stream_result {
|
||||
Ok(stream) => {
|
||||
return Ok(map_response_stream(stream, self.otel_manager.clone()));
|
||||
}
|
||||
Err(ApiError::Transport(TransportError::Http { status, .. }))
|
||||
if status == StatusCode::UNAUTHORIZED =>
|
||||
{
|
||||
manager.reset().await;
|
||||
handle_unauthorized(status, &mut refreshed, &auth_manager, &auth).await?;
|
||||
continue;
|
||||
}
|
||||
Err(err) => {
|
||||
manager.reset().await;
|
||||
return Err(map_api_error(err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a clone of the provider configuration this client was created with.
pub fn get_provider(&self) -> ModelProviderInfo {
|
||||
self.provider.clone()
|
||||
}
|
||||
|
||||
@@ -19,9 +19,11 @@ use crate::compact_remote::run_inline_remote_auto_compact_task;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::features::Feature;
|
||||
use crate::features::Features;
|
||||
use crate::flags::CODEX_RS_RESPONSES_WS;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::parse_command::parse_command;
|
||||
use crate::parse_turn_item;
|
||||
use crate::responses_ws::ResponsesWsManager;
|
||||
use crate::stream_events_utils::HandleOutputCtx;
|
||||
use crate::stream_events_utils::handle_non_tool_response_item;
|
||||
use crate::stream_events_utils::handle_output_item_done;
|
||||
@@ -506,6 +508,7 @@ impl Session {
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
otel_manager: &OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
responses_ws: Option<Arc<ResponsesWsManager>>,
|
||||
session_configuration: &SessionConfiguration,
|
||||
per_turn_config: Config,
|
||||
model_info: ModelInfo,
|
||||
@@ -524,6 +527,7 @@ impl Session {
|
||||
model_info.clone(),
|
||||
otel_manager,
|
||||
provider,
|
||||
responses_ws,
|
||||
session_configuration.model_reasoning_effort,
|
||||
session_configuration.model_reasoning_summary,
|
||||
conversation_id,
|
||||
@@ -676,6 +680,11 @@ impl Session {
|
||||
.map(Arc::new);
|
||||
}
|
||||
let state = SessionState::new(session_configuration.clone());
|
||||
let responses_ws = if *CODEX_RS_RESPONSES_WS {
|
||||
Some(Arc::new(ResponsesWsManager::new()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let services = SessionServices {
|
||||
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::default())),
|
||||
@@ -692,6 +701,7 @@ impl Session {
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
skills_manager,
|
||||
agent_control,
|
||||
responses_ws,
|
||||
};
|
||||
|
||||
let sess = Arc::new(Session {
|
||||
@@ -952,6 +962,7 @@ impl Session {
|
||||
Some(Arc::clone(&self.services.auth_manager)),
|
||||
&self.services.otel_manager,
|
||||
session_configuration.provider.clone(),
|
||||
self.services.responses_ws.clone(),
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_info,
|
||||
@@ -2243,6 +2254,7 @@ async fn spawn_review_thread(
|
||||
model_info.clone(),
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
per_turn_config.model_reasoning_effort,
|
||||
per_turn_config.model_reasoning_summary,
|
||||
sess.conversation_id,
|
||||
@@ -3532,12 +3544,14 @@ mod tests {
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
skills_manager,
|
||||
agent_control,
|
||||
responses_ws: None,
|
||||
};
|
||||
|
||||
let turn_context = Session::make_turn_context(
|
||||
Some(Arc::clone(&auth_manager)),
|
||||
&otel_manager,
|
||||
session_configuration.provider.clone(),
|
||||
None,
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_info,
|
||||
@@ -3626,12 +3640,14 @@ mod tests {
|
||||
tool_approvals: Mutex::new(ApprovalStore::default()),
|
||||
skills_manager,
|
||||
agent_control,
|
||||
responses_ws: None,
|
||||
};
|
||||
|
||||
let turn_context = Arc::new(Session::make_turn_context(
|
||||
Some(Arc::clone(&auth_manager)),
|
||||
&otel_manager,
|
||||
session_configuration.provider.clone(),
|
||||
None,
|
||||
&session_configuration,
|
||||
per_turn_config,
|
||||
model_info,
|
||||
|
||||
@@ -3,4 +3,5 @@ use env_flags::env_flags;
|
||||
env_flags! {
|
||||
/// Fixture path for offline tests (see client.rs).
|
||||
pub CODEX_RS_SSE_FIXTURE: Option<&str> = None;
|
||||
/// When true, a new `Session` creates a `ResponsesWsManager` and stores it in
/// its services; when false (the default) none is created.
pub CODEX_RS_RESPONSES_WS: bool = false;
|
||||
}
|
||||
|
||||
@@ -78,6 +78,7 @@ pub use auth::AuthManager;
|
||||
pub use auth::CodexAuth;
|
||||
pub mod default_client;
|
||||
pub mod project_doc;
|
||||
mod responses_ws;
|
||||
mod rollout;
|
||||
pub(crate) mod safety;
|
||||
pub mod seatbelt;
|
||||
@@ -134,5 +135,6 @@ pub use codex_protocol::models::LocalShellStatus;
|
||||
pub use codex_protocol::models::ResponseItem;
|
||||
pub use compact::content_items_to_text;
|
||||
pub use event_mapping::parse_turn_item;
|
||||
pub use responses_ws::ResponsesWsManager;
|
||||
pub mod compact;
|
||||
pub mod otel_init;
|
||||
|
||||
79
codex-rs/core/src/responses_ws.rs
Normal file
79
codex-rs/core/src/responses_ws.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
use crate::api_bridge::CoreAuthProvider;
|
||||
use codex_api::Prompt as ApiPrompt;
|
||||
use codex_api::Provider;
|
||||
use codex_api::ResponseStream;
|
||||
use codex_api::ResponsesOptions;
|
||||
use codex_api::ResponsesWsSession;
|
||||
use codex_api::error::ApiError;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Caches a single `ResponsesWsSession` so that successive prompts can reuse
/// it, together with the provider base URL that session was created for.
pub struct ResponsesWsManager {
|
||||
/// The cached session; `None` until one is created and again after `reset`.
session: Mutex<Option<ResponsesWsSession<CoreAuthProvider>>>,
|
||||
/// Base URL of the provider the cached session was created with; compared
/// against the incoming provider to decide whether the cache must be dropped.
base_url: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ResponsesWsManager {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
formatter.debug_struct("ResponsesWsManager").finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesWsManager {
|
||||
pub(crate) fn new() -> Self {
|
||||
Self {
|
||||
session: Mutex::new(None),
|
||||
base_url: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn reset(&self) {
|
||||
{
|
||||
let mut guard = self.session.lock().await;
|
||||
*guard = None;
|
||||
}
|
||||
let mut base_url = self.base_url.lock().await;
|
||||
*base_url = None;
|
||||
}
|
||||
|
||||
pub(crate) async fn stream_prompt(
|
||||
&self,
|
||||
provider: Provider,
|
||||
auth: CoreAuthProvider,
|
||||
model: &str,
|
||||
prompt: &ApiPrompt,
|
||||
options: ResponsesOptions,
|
||||
) -> Result<ResponseStream, ApiError> {
|
||||
let should_reset = self
|
||||
.base_url
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.map(|url| url != &provider.base_url)
|
||||
.unwrap_or(false);
|
||||
if should_reset {
|
||||
self.reset().await;
|
||||
}
|
||||
|
||||
let existing = { self.session.lock().await.clone() };
|
||||
let session = if let Some(session) = existing {
|
||||
session
|
||||
} else {
|
||||
let session = ResponsesWsSession::new(provider.clone(), auth);
|
||||
{
|
||||
let mut guard = self.session.lock().await;
|
||||
if guard.is_none() {
|
||||
*guard = Some(session.clone());
|
||||
let mut base_url = self.base_url.lock().await;
|
||||
*base_url = Some(provider.base_url.clone());
|
||||
}
|
||||
}
|
||||
session
|
||||
};
|
||||
|
||||
let stream = session.stream_prompt(model, prompt, options).await;
|
||||
if stream.is_err() {
|
||||
self.reset().await;
|
||||
}
|
||||
stream
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ use crate::agent::AgentControl;
|
||||
use crate::exec_policy::ExecPolicyManager;
|
||||
use crate::mcp_connection_manager::McpConnectionManager;
|
||||
use crate::models_manager::manager::ModelsManager;
|
||||
use crate::responses_ws::ResponsesWsManager;
|
||||
use crate::skills::SkillsManager;
|
||||
use crate::tools::sandboxing::ApprovalStore;
|
||||
use crate::unified_exec::UnifiedExecProcessManager;
|
||||
@@ -30,4 +31,5 @@ pub(crate) struct SessionServices {
|
||||
pub(crate) tool_approvals: Mutex<ApprovalStore>,
|
||||
pub(crate) skills_manager: Arc<SkillsManager>,
|
||||
pub(crate) agent_control: AgentControl,
|
||||
pub(crate) responses_ws: Option<Arc<ResponsesWsManager>>,
|
||||
}
|
||||
|
||||
@@ -94,6 +94,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
|
||||
@@ -95,6 +95,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
|
||||
@@ -87,6 +87,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
@@ -182,6 +183,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
@@ -275,6 +277,7 @@ async fn responses_respects_model_info_overrides_from_config() {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
|
||||
@@ -1167,6 +1167,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
|
||||
model_info,
|
||||
otel_manager,
|
||||
provider,
|
||||
None,
|
||||
effort,
|
||||
summary,
|
||||
conversation_id,
|
||||
|
||||
Reference in New Issue
Block a user