Websocket append support (#9128)

Add support for sending incremental append requests over the websocket transport.
This commit is contained in:
pakrym-oai
2026-01-12 22:07:13 -08:00
committed by GitHub
parent ddae70bd62
commit e726a82c8a
6 changed files with 164 additions and 130 deletions

View File

@@ -11,16 +11,18 @@ use codex_api::CompactionInput as ApiCompactionInput;
use codex_api::Prompt as ApiPrompt;
use codex_api::RequestTelemetry;
use codex_api::ReqwestTransport;
use codex_api::ResponseAppendWsRequest;
use codex_api::ResponseCreateWsRequest;
use codex_api::ResponseStream as ApiResponseStream;
use codex_api::ResponsesClient as ApiResponsesClient;
use codex_api::ResponsesOptions as ApiResponsesOptions;
use codex_api::ResponsesRequest;
use codex_api::ResponsesRequestBuilder;
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
use codex_api::SseTelemetry;
use codex_api::TransportError;
use codex_api::build_conversation_headers;
use codex_api::common::Reasoning;
use codex_api::common::ResponsesWsRequest;
use codex_api::create_text_param_for_request;
use codex_api::error::ApiError;
use codex_api::requests::responses::Compression;
@@ -83,6 +85,7 @@ pub struct ModelClient {
/// A session created from a `ModelClient`; carries the shared client state plus
/// per-session websocket bookkeeping.
pub struct ModelClientSession {
/// Shared client state, `Arc`-cloned from the `ModelClient` that created this session.
state: Arc<ModelClientState>,
/// Websocket connection used by this session. Starts as `None` and is filled in
/// when a connection is established.
connection: Option<ApiWebSocketConnection>,
/// Input items of the last websocket request that was streamed successfully.
/// Compared against the next request's input to decide whether only the newly
/// added items need to be sent. Starts empty.
websocket_last_items: Vec<ResponseItem>,
}
#[allow(clippy::too_many_arguments)]
@@ -117,6 +120,7 @@ impl ModelClient {
ModelClientSession {
state: Arc::clone(&self.state),
connection: None,
websocket_last_items: Vec::new(),
}
}
}
@@ -320,49 +324,65 @@ impl ModelClientSession {
}
}
fn build_responses_websocket_request(
/// Returns the items that were added on top of the previous websocket request,
/// or `None` when the new input is not a strict extension of it.
///
/// `Some(tail)` is produced only when all of the following hold:
/// * a previous request was recorded (`websocket_last_items` is non-empty),
/// * `input_items` begins with exactly those recorded items, and
/// * at least one additional item follows them.
///
/// In every other case the caller is expected to fall back to a fresh
/// `response.create` request instead of a `response.append`.
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
    let sent_before = self.websocket_last_items.as_slice();
    // Nothing was sent yet, so there is nothing to append to.
    if sent_before.is_empty() {
        return None;
    }

    // `strip_prefix` yields the remainder only if the old items form a prefix.
    let newly_added = input_items.strip_prefix(sent_before)?;
    if newly_added.is_empty() {
        // Identical input: no new items, so an append would be empty.
        return None;
    }
    Some(newly_added.to_vec())
}
fn prepare_websocket_request(
&self,
api_provider: &codex_api::Provider,
api_prompt: &ApiPrompt,
options: ApiResponsesOptions,
) -> Result<ResponsesRequest> {
options: &ApiResponsesOptions,
) -> ResponsesWsRequest {
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
input: append_items,
});
}
let ApiResponsesOptions {
reasoning,
include,
prompt_cache_key,
text,
store_override,
conversation_id,
session_source,
extra_headers,
compression,
..
} = options;
ResponsesRequestBuilder::new(
&self.state.model_info.slug,
&api_prompt.instructions,
&api_prompt.input,
)
.tools(&api_prompt.tools)
.parallel_tool_calls(api_prompt.parallel_tool_calls)
.reasoning(reasoning)
.include(include)
.prompt_cache_key(prompt_cache_key)
.text(text)
.conversation(conversation_id)
.session_source(session_source)
.store_override(store_override)
.extra_headers(extra_headers)
.compression(compression)
.build(api_provider)
.map_err(map_api_error)
let store = store_override.unwrap_or(false);
let payload = ResponseCreateWsRequest {
model: self.state.model_info.slug.clone(),
instructions: api_prompt.instructions.clone(),
input: api_prompt.input.clone(),
tools: api_prompt.tools.clone(),
tool_choice: "auto".to_string(),
parallel_tool_calls: api_prompt.parallel_tool_calls,
reasoning: reasoning.clone(),
store,
stream: true,
include: include.clone(),
prompt_cache_key: prompt_cache_key.clone(),
text: text.clone(),
};
ResponsesWsRequest::ResponseCreate(payload)
}
async fn websocket_connection(
&mut self,
api_provider: codex_api::Provider,
api_auth: CoreAuthProvider,
headers: ApiHeaderMap,
options: &ApiResponsesOptions,
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
let needs_new = match self.connection.as_ref() {
Some(conn) => conn.is_closed().await,
@@ -370,9 +390,12 @@ impl ModelClientSession {
};
if needs_new {
let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
.connect(headers)
.await?;
let mut headers = options.extra_headers.clone();
headers.extend(build_conversation_headers(options.conversation_id.clone()));
let new_conn: ApiWebSocketConnection =
ApiWebSocketResponsesClient::new(api_provider, api_auth)
.connect(headers)
.await?;
self.connection = Some(new_conn);
}
@@ -533,15 +556,10 @@ impl ModelClientSession {
let compression = self.responses_request_compression(auth.as_ref());
let options = self.build_responses_options(prompt, compression);
let request =
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
let request = self.prepare_websocket_request(&api_prompt, &options);
let connection = match self
.websocket_connection(
api_provider.clone(),
api_auth.clone(),
request.headers.clone(),
)
.websocket_connection(api_provider.clone(), api_auth.clone(), &options)
.await
{
Ok(connection) => connection,
@@ -558,6 +576,7 @@ impl ModelClientSession {
.stream_request(request)
.await
.map_err(map_api_error)?;
self.websocket_last_items = api_prompt.input.clone();
return Ok(map_response_stream(
stream_result,