Compare commits

...

2 Commits

Author SHA1 Message Date
Ahmed Ibrahim
2d00570ca5 Merge branch 'main' into sticky-r 2026-01-15 13:39:20 -08:00
Ahmed Ibrahim
840ec2d650 progress 2026-01-15 12:02:21 -08:00
11 changed files with 163 additions and 12 deletions

View File

@@ -13,6 +13,8 @@ use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
pub const TURN_STATE_HEADER: &str = "x-codex-turn-state";
/// Canonical prompt input for Chat and Responses endpoints.
#[derive(Debug, Clone)]
pub struct Prompt {
@@ -60,6 +62,7 @@ pub enum ResponseEvent {
},
RateLimits(RateLimitSnapshot),
ModelsEtag(String),
TurnState(String),
}
#[derive(Debug, Serialize, Clone)]

View File

@@ -56,6 +56,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
prompt: &ApiPrompt,
conversation_id: Option<String>,
session_source: Option<SessionSource>,
extra_headers: HeaderMap,
) -> Result<ResponseStream, ApiError> {
use crate::requests::ChatRequestBuilder;
@@ -63,6 +64,7 @@ impl<T: HttpTransport, A: AuthProvider> ChatClient<T, A> {
ChatRequestBuilder::new(model, &prompt.instructions, &prompt.input, &prompt.tools)
.conversation_id(conversation_id)
.session_source(session_source)
.extra_headers(extra_headers)
.build(self.streaming.provider())?;
self.stream_request(request).await
@@ -159,6 +161,9 @@ impl Stream for AggregatedStream {
Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
}
Poll::Ready(Some(Ok(ResponseEvent::TurnState(state)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::TurnState(state))));
}
Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag)))) => {
return Poll::Ready(Some(Ok(ResponseEvent::ModelsEtag(etag))));
}

View File

@@ -2,6 +2,7 @@ use crate::auth::AuthProvider;
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::ResponsesWsRequest;
use crate::common::TURN_STATE_HEADER;
use crate::error::ApiError;
use crate::provider::Provider;
use crate::sse::responses::ResponsesStreamEvent;
@@ -32,13 +33,15 @@ pub struct ResponsesWebsocketConnection {
stream: Arc<Mutex<Option<WsStream>>>,
// TODO (pakrym): is this the right place for timeout?
idle_timeout: Duration,
turn_state_header: Option<String>,
}
impl ResponsesWebsocketConnection {
fn new(stream: WsStream, idle_timeout: Duration) -> Self {
fn new(stream: WsStream, idle_timeout: Duration, turn_state_header: Option<String>) -> Self {
Self {
stream: Arc::new(Mutex::new(Some(stream))),
idle_timeout,
turn_state_header,
}
}
@@ -54,11 +57,17 @@ impl ResponsesWebsocketConnection {
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
let stream = Arc::clone(&self.stream);
let idle_timeout = self.idle_timeout;
let turn_state = self.turn_state_header.clone();
let request_body = serde_json::to_value(&request).map_err(|err| {
ApiError::Stream(format!("failed to encode websocket request: {err}"))
})?;
tokio::spawn(async move {
if let Some(turn_state) = turn_state {
let _ = tx_event
.send(Ok(ResponseEvent::TurnState(turn_state)))
.await;
}
let mut guard = stream.lock().await;
let Some(ws_stream) = guard.as_mut() else {
let _ = tx_event
@@ -108,10 +117,15 @@ impl<A: AuthProvider> ResponsesWebsocketClient<A> {
headers.extend(extra_headers);
apply_auth_headers(&mut headers, &self.auth);
let stream = connect_websocket(ws_url, headers).await?;
let (stream, response_headers) = connect_websocket(ws_url, headers).await?;
let turn_state_header = response_headers
.get(TURN_STATE_HEADER)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
Ok(ResponsesWebsocketConnection::new(
stream,
self.provider.stream_idle_timeout,
turn_state_header,
))
}
}
@@ -130,17 +144,20 @@ fn apply_auth_headers(headers: &mut HeaderMap, auth: &impl AuthProvider) {
}
}
async fn connect_websocket(url: Url, headers: HeaderMap) -> Result<WsStream, ApiError> {
async fn connect_websocket(
url: Url,
headers: HeaderMap,
) -> Result<(WsStream, HeaderMap), ApiError> {
let mut request = url
.clone()
.into_client_request()
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
request.headers_mut().extend(headers);
let (stream, _) = tokio_tungstenite::connect_async(request)
let (stream, response) = tokio_tungstenite::connect_async(request)
.await
.map_err(|err| map_ws_error(err, &url))?;
Ok(stream)
Ok((stream, response.headers().clone()))
}
fn map_ws_error(err: WsError, url: &Url) -> ApiError {

View File

@@ -26,6 +26,7 @@ pub struct ChatRequestBuilder<'a> {
tools: &'a [Value],
conversation_id: Option<String>,
session_source: Option<SessionSource>,
extra_headers: HeaderMap,
}
impl<'a> ChatRequestBuilder<'a> {
@@ -42,6 +43,7 @@ impl<'a> ChatRequestBuilder<'a> {
tools,
conversation_id: None,
session_source: None,
extra_headers: HeaderMap::new(),
}
}
@@ -55,6 +57,11 @@ impl<'a> ChatRequestBuilder<'a> {
self
}
pub fn extra_headers(mut self, headers: HeaderMap) -> Self {
self.extra_headers = headers;
self
}
pub fn build(self, _provider: &Provider) -> Result<ChatRequest, ApiError> {
let mut messages = Vec::<Value>::new();
messages.push(json!({"role": "system", "content": self.instructions}));
@@ -298,7 +305,8 @@ impl<'a> ChatRequestBuilder<'a> {
"tools": self.tools,
});
let mut headers = build_conversation_headers(self.conversation_id);
let mut headers = self.extra_headers;
headers.extend(build_conversation_headers(self.conversation_id));
if let Some(subagent) = subagent_header(&self.session_source) {
insert_header(&mut headers, "x-openai-subagent", &subagent);
}

View File

@@ -1,5 +1,6 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::TURN_STATE_HEADER;
use crate::error::ApiError;
use crate::telemetry::SseTelemetry;
use codex_client::StreamResponse;
@@ -23,9 +24,20 @@ pub(crate) fn spawn_chat_stream(
idle_timeout: Duration,
telemetry: Option<std::sync::Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let turn_state = stream_response
.headers
.get(TURN_STATE_HEADER)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
let bytes = stream_response.bytes;
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
process_chat_sse(stream_response.bytes, tx_event, idle_timeout, telemetry).await;
if let Some(turn_state) = turn_state {
let _ = tx_event
.send(Ok(ResponseEvent::TurnState(turn_state)))
.await;
}
process_chat_sse(bytes, tx_event, idle_timeout, telemetry).await;
});
ResponseStream { rx_event }
}

View File

@@ -1,5 +1,6 @@
use crate::common::ResponseEvent;
use crate::common::ResponseStream;
use crate::common::TURN_STATE_HEADER;
use crate::error::ApiError;
use crate::rate_limits::parse_rate_limit;
use crate::telemetry::SseTelemetry;
@@ -51,6 +52,11 @@ pub fn spawn_response_stream(
telemetry: Option<Arc<dyn SseTelemetry>>,
) -> ResponseStream {
let rate_limits = parse_rate_limit(&stream_response.headers);
let turn_state = stream_response
.headers
.get(TURN_STATE_HEADER)
.and_then(|v| v.to_str().ok())
.map(ToString::to_string);
let models_etag = stream_response
.headers
.get("X-Models-Etag")
@@ -58,6 +64,11 @@ pub fn spawn_response_stream(
.map(ToString::to_string);
let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent, ApiError>>(1600);
tokio::spawn(async move {
if let Some(turn_state) = turn_state {
let _ = tx_event
.send(Ok(ResponseEvent::TurnState(turn_state)))
.await;
}
if let Some(snapshot) = rate_limits {
let _ = tx_event.send(Ok(ResponseEvent::RateLimits(snapshot))).await;
}

View File

@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::Mutex;
use crate::api_bridge::CoreAuthProvider;
use crate::api_bridge::auth_provider_from_auth;
@@ -23,6 +24,7 @@ use codex_api::TransportError;
use codex_api::build_conversation_headers;
use codex_api::common::Reasoning;
use codex_api::common::ResponsesWsRequest;
use codex_api::common::TURN_STATE_HEADER;
use codex_api::create_text_param_for_request;
use codex_api::error::ApiError;
use codex_api::requests::responses::Compression;
@@ -89,6 +91,8 @@ pub struct ModelClientSession {
state: Arc<ModelClientState>,
connection: Option<ApiWebSocketConnection>,
websocket_last_items: Vec<ResponseItem>,
turn_state_header: Option<Arc<Mutex<Option<String>>>>,
websocket_turn_state_header: Option<String>,
}
#[allow(clippy::too_many_arguments)]
@@ -124,6 +128,8 @@ impl ModelClient {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
turn_state_header: None,
websocket_turn_state_header: None,
}
}
}
@@ -185,6 +191,16 @@ impl ModelClient {
/// This is a unary call (no streaming) that returns a new list of
/// `ResponseItem`s representing the compacted transcript.
pub async fn compact_conversation_history(&self, prompt: &Prompt) -> Result<Vec<ResponseItem>> {
    // Convenience entry point: delegate to the header-aware variant with an
    // empty header map so callers without extra headers keep the old signature.
    let no_extra_headers = ApiHeaderMap::new();
    let compaction = self.compact_conversation_history_with_headers(prompt, no_extra_headers);
    compaction.await
}
/// Compacts the current conversation history using the Compact endpoint with extra headers.
pub async fn compact_conversation_history_with_headers(
&self,
prompt: &Prompt,
mut extra_headers: ApiHeaderMap,
) -> Result<Vec<ResponseItem>> {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
@@ -212,7 +228,6 @@ impl ModelClient {
instructions: &instructions,
};
let mut extra_headers = ApiHeaderMap::new();
if let SessionSource::SubAgent(sub) = &self.state.session_source {
let subagent = if let crate::protocol::SubAgentSource::Other(label) = sub {
label.clone()
@@ -235,6 +250,14 @@ impl ModelClient {
}
impl ModelClientSession {
/// Attaches a shared turn-state slot to this session and returns the session.
///
/// The slot is stored as `Some(..)` in `self.turn_state_header`; the `Arc` is
/// moved in as-is, so the caller and the session observe the same mutex.
/// Any previously attached slot is replaced.
pub(crate) fn with_turn_state_header(
mut self,
turn_state_header: Arc<Mutex<Option<String>>>,
) -> Self {
self.turn_state_header = Some(turn_state_header);
self
}
/// Streams a single model turn using either the Responses or Chat
/// Completions wire API, depending on the configured provider.
///
@@ -314,6 +337,9 @@ impl ModelClientSession {
let text = create_text_param_for_request(verbosity, &prompt.output_schema);
let conversation_id = self.state.conversation_id.to_string();
let mut extra_headers = beta_feature_headers(&self.state.config);
self.insert_turn_state_header(&mut extra_headers);
ApiResponsesOptions {
reasoning,
include,
@@ -387,12 +413,16 @@ impl ModelClientSession {
api_auth: CoreAuthProvider,
options: &ApiResponsesOptions,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
let turn_state_header = self.turn_state_header_value();
let needs_new = match self.connection.as_ref() {
Some(conn) => conn.is_closed().await,
None => true,
};
let needs_refresh = !needs_new
&& turn_state_header.is_some()
&& self.websocket_turn_state_header != turn_state_header;
if needs_new {
if needs_new || needs_refresh {
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
let new_conn: ApiWebSocketConnection =
@@ -400,6 +430,8 @@ impl ModelClientSession {
.connect(headers)
.await?;
self.connection = Some(new_conn);
self.websocket_turn_state_header = turn_state_header;
self.websocket_last_items.clear();
}
self.connection.as_ref().ok_or(ApiError::Stream(
@@ -459,12 +491,16 @@ impl ModelClientSession {
let client = ApiChatClient::new(transport, api_provider, api_auth)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let mut extra_headers = ApiHeaderMap::new();
self.insert_turn_state_header(&mut extra_headers);
let stream_result = client
.stream_prompt(
&self.state.model_info.slug,
&api_prompt,
Some(conversation_id.clone()),
Some(session_source.clone()),
extra_headers,
)
.await;
@@ -588,6 +624,19 @@ impl ModelClientSession {
}
}
/// Returns a copy of the value currently held in the shared turn-state slot.
///
/// Yields `None` when no slot has been attached to this session or when the
/// attached slot is empty.
fn turn_state_header_value(&self) -> Option<String> {
    // No shared slot attached: nothing to report.
    let shared = self.turn_state_header.as_ref()?;
    // A poisoned mutex is recovered rather than treated as an error; the
    // stored value is still read.
    let slot = match shared.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    slot.clone()
}
/// Adds this session's current turn-state value, if any, to `headers`.
///
/// Delegates to the free function of the same name; `headers` is passed
/// through unchanged to it when the session has no value.
fn insert_turn_state_header(&self, headers: &mut ApiHeaderMap) {
    let current = self.turn_state_header_value();
    insert_turn_state_header(headers, current.as_deref());
}
/// Builds request and SSE telemetry for streaming API calls (Chat/Responses).
fn build_streaming_telemetry(&self) -> (Arc<dyn RequestTelemetry>, Arc<dyn SseTelemetry>) {
let telemetry = Arc::new(ApiTelemetry::new(self.state.otel_manager.clone()));
@@ -617,6 +666,14 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
}
}
/// Inserts `value` into `headers` under the `TURN_STATE_HEADER` name.
///
/// Does nothing when `value` is `None` or when the string cannot be converted
/// into a `HeaderValue`; in both cases `headers` is left untouched.
pub(crate) fn insert_turn_state_header(headers: &mut ApiHeaderMap, value: Option<&str>) {
    let Some(raw) = value else {
        return;
    };
    // Conversion failures are deliberately ignored: the header is simply omitted.
    let Ok(header_value) = HeaderValue::from_str(raw) else {
        return;
    };
    headers.insert(TURN_STATE_HEADER, header_value);
}
fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
let enabled = FEATURES
.iter()

View File

@@ -383,6 +383,7 @@ pub(crate) struct Session {
pub(crate) struct TurnContext {
pub(crate) sub_id: String,
pub(crate) client: ModelClient,
pub(crate) turn_state_header: Arc<std::sync::Mutex<Option<String>>>,
/// The session's current working directory. All relative paths provided by
/// the model as well as sandbox policies are resolved against this path
/// instead of `std::env::current_dir()`.
@@ -403,6 +404,21 @@ pub(crate) struct TurnContext {
}
impl TurnContext {
/// Stores `value` in the shared turn-state slot, overwriting any previous value.
///
/// A poisoned mutex is recovered rather than propagated, so the write always
/// takes place.
pub(crate) fn capture_turn_state_header(&self, value: String) {
    let mut slot = match self.turn_state_header.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    *slot = Some(value);
}
/// Returns a copy of the value currently held in the shared turn-state slot,
/// or `None` when the slot is empty.
///
/// A poisoned mutex is recovered rather than propagated.
pub(crate) fn turn_state_header(&self) -> Option<String> {
    let slot = match self.turn_state_header.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    };
    slot.clone()
}
pub(crate) fn resolve_path(&self, path: Option<String>) -> PathBuf {
path.as_ref()
.map(PathBuf::from)
@@ -545,6 +561,7 @@ impl Session {
TurnContext {
sub_id,
client,
turn_state_header: Arc::new(std::sync::Mutex::new(None)),
cwd: session_configuration.cwd.clone(),
developer_instructions: session_configuration.developer_instructions.clone(),
base_instructions: session_configuration.base_instructions.clone(),
@@ -2443,6 +2460,7 @@ async fn spawn_review_thread(
client,
tools_config,
ghost_snapshot: parent_turn_context.ghost_snapshot.clone(),
turn_state_header: Arc::new(std::sync::Mutex::new(None)),
developer_instructions: None,
user_instructions: None,
base_instructions: Some(base_instructions.clone()),
@@ -2577,7 +2595,10 @@ pub(crate) async fn run_turn(
// many turns, from the perspective of the user, it is a single turn.
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let mut client_session = turn_context.client.new_session();
let mut client_session = turn_context
.client
.new_session()
.with_turn_state_header(Arc::clone(&turn_context.turn_state_header));
loop {
// Note that pending_input would be something like a message the user
@@ -2953,6 +2974,9 @@ async fn try_run_turn(
.refresh_if_new_etag(etag, &config)
.await;
}
ResponseEvent::TurnState(value) => {
turn_context.capture_turn_state_header(value);
}
ResponseEvent::Completed {
response_id: _,
token_usage,

View File

@@ -301,7 +301,10 @@ async fn drain_to_completed(
turn_context: &TurnContext,
prompt: &Prompt,
) -> CodexResult<()> {
let mut client_session = turn_context.client.new_session();
let mut client_session = turn_context
.client
.new_session()
.with_turn_state_header(Arc::clone(&turn_context.turn_state_header));
let mut stream = client_session.stream(prompt).await?;
loop {
let maybe_event = stream.next().await;
@@ -319,6 +322,9 @@ async fn drain_to_completed(
Ok(ResponseEvent::RateLimits(snapshot)) => {
sess.update_rate_limits(turn_context, snapshot).await;
}
Ok(ResponseEvent::TurnState(value)) => {
turn_context.capture_turn_state_header(value);
}
Ok(ResponseEvent::Completed { token_usage, .. }) => {
sess.update_token_usage_info(turn_context, token_usage.as_ref())
.await;

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use crate::Prompt;
use crate::client::insert_turn_state_header;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::error::Result as CodexResult;
@@ -10,6 +11,7 @@ use crate::protocol::EventMsg;
use crate::protocol::RolloutItem;
use crate::protocol::TurnStartedEvent;
use codex_protocol::models::ResponseItem;
use http::HeaderMap;
pub(crate) async fn run_inline_remote_auto_compact_task(
sess: Arc<Session>,
@@ -58,9 +60,14 @@ async fn run_remote_compact_task_inner_impl(
output_schema: None,
};
let mut extra_headers = HeaderMap::new();
insert_turn_state_header(
&mut extra_headers,
turn_context.turn_state_header().as_deref(),
);
let mut new_history = turn_context
.client
.compact_conversation_history(&prompt)
.compact_conversation_history_with_headers(&prompt, extra_headers)
.await?;
if !ghost_snapshots.is_empty() {

View File

@@ -481,6 +481,7 @@ impl OtelManager {
}
ResponseEvent::RateLimits(_) => "rate_limits".into(),
ResponseEvent::ModelsEtag(_) => "models_etag".into(),
ResponseEvent::TurnState(_) => "turn_state".into(),
}
}