mirror of
https://github.com/openai/codex.git
synced 2026-04-30 01:16:54 +00:00
Websocket append support (#9128)
Support an incremental append request in websocket transport.
This commit is contained in:
@@ -11,16 +11,18 @@ use codex_api::CompactionInput as ApiCompactionInput;
|
||||
use codex_api::Prompt as ApiPrompt;
|
||||
use codex_api::RequestTelemetry;
|
||||
use codex_api::ReqwestTransport;
|
||||
use codex_api::ResponseAppendWsRequest;
|
||||
use codex_api::ResponseCreateWsRequest;
|
||||
use codex_api::ResponseStream as ApiResponseStream;
|
||||
use codex_api::ResponsesClient as ApiResponsesClient;
|
||||
use codex_api::ResponsesOptions as ApiResponsesOptions;
|
||||
use codex_api::ResponsesRequest;
|
||||
use codex_api::ResponsesRequestBuilder;
|
||||
use codex_api::ResponsesWebsocketClient as ApiWebSocketResponsesClient;
|
||||
use codex_api::ResponsesWebsocketConnection as ApiWebSocketConnection;
|
||||
use codex_api::SseTelemetry;
|
||||
use codex_api::TransportError;
|
||||
use codex_api::build_conversation_headers;
|
||||
use codex_api::common::Reasoning;
|
||||
use codex_api::common::ResponsesWsRequest;
|
||||
use codex_api::create_text_param_for_request;
|
||||
use codex_api::error::ApiError;
|
||||
use codex_api::requests::responses::Compression;
|
||||
@@ -83,6 +85,7 @@ pub struct ModelClient {
|
||||
pub struct ModelClientSession {
|
||||
state: Arc<ModelClientState>,
|
||||
connection: Option<ApiWebSocketConnection>,
|
||||
websocket_last_items: Vec<ResponseItem>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -117,6 +120,7 @@ impl ModelClient {
|
||||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -320,49 +324,65 @@ impl ModelClientSession {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_responses_websocket_request(
|
||||
fn get_incremental_items(&self, input_items: &[ResponseItem]) -> Option<Vec<ResponseItem>> {
|
||||
// Checks whether the current request input is an incremental append to the previous request.
|
||||
// If items in the new request contain all the items from the previous request we build
|
||||
// a response.append request otherwise we start with a fresh response.create request.
|
||||
let previous_len = self.websocket_last_items.len();
|
||||
let can_append = previous_len > 0
|
||||
&& input_items.starts_with(&self.websocket_last_items)
|
||||
&& previous_len < input_items.len();
|
||||
if can_append {
|
||||
Some(input_items[previous_len..].to_vec())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn prepare_websocket_request(
|
||||
&self,
|
||||
api_provider: &codex_api::Provider,
|
||||
api_prompt: &ApiPrompt,
|
||||
options: ApiResponsesOptions,
|
||||
) -> Result<ResponsesRequest> {
|
||||
options: &ApiResponsesOptions,
|
||||
) -> ResponsesWsRequest {
|
||||
if let Some(append_items) = self.get_incremental_items(&api_prompt.input) {
|
||||
return ResponsesWsRequest::ResponseAppend(ResponseAppendWsRequest {
|
||||
input: append_items,
|
||||
});
|
||||
}
|
||||
|
||||
let ApiResponsesOptions {
|
||||
reasoning,
|
||||
include,
|
||||
prompt_cache_key,
|
||||
text,
|
||||
store_override,
|
||||
conversation_id,
|
||||
session_source,
|
||||
extra_headers,
|
||||
compression,
|
||||
..
|
||||
} = options;
|
||||
|
||||
ResponsesRequestBuilder::new(
|
||||
&self.state.model_info.slug,
|
||||
&api_prompt.instructions,
|
||||
&api_prompt.input,
|
||||
)
|
||||
.tools(&api_prompt.tools)
|
||||
.parallel_tool_calls(api_prompt.parallel_tool_calls)
|
||||
.reasoning(reasoning)
|
||||
.include(include)
|
||||
.prompt_cache_key(prompt_cache_key)
|
||||
.text(text)
|
||||
.conversation(conversation_id)
|
||||
.session_source(session_source)
|
||||
.store_override(store_override)
|
||||
.extra_headers(extra_headers)
|
||||
.compression(compression)
|
||||
.build(api_provider)
|
||||
.map_err(map_api_error)
|
||||
let store = store_override.unwrap_or(false);
|
||||
let payload = ResponseCreateWsRequest {
|
||||
model: self.state.model_info.slug.clone(),
|
||||
instructions: api_prompt.instructions.clone(),
|
||||
input: api_prompt.input.clone(),
|
||||
tools: api_prompt.tools.clone(),
|
||||
tool_choice: "auto".to_string(),
|
||||
parallel_tool_calls: api_prompt.parallel_tool_calls,
|
||||
reasoning: reasoning.clone(),
|
||||
store,
|
||||
stream: true,
|
||||
include: include.clone(),
|
||||
prompt_cache_key: prompt_cache_key.clone(),
|
||||
text: text.clone(),
|
||||
};
|
||||
|
||||
ResponsesWsRequest::ResponseCreate(payload)
|
||||
}
|
||||
|
||||
async fn websocket_connection(
|
||||
&mut self,
|
||||
api_provider: codex_api::Provider,
|
||||
api_auth: CoreAuthProvider,
|
||||
headers: ApiHeaderMap,
|
||||
options: &ApiResponsesOptions,
|
||||
) -> std::result::Result<&ApiWebSocketConnection, ApiError> {
|
||||
let needs_new = match self.connection.as_ref() {
|
||||
Some(conn) => conn.is_closed().await,
|
||||
@@ -370,9 +390,12 @@ impl ModelClientSession {
|
||||
};
|
||||
|
||||
if needs_new {
|
||||
let new_conn = ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.await?;
|
||||
let mut headers = options.extra_headers.clone();
|
||||
headers.extend(build_conversation_headers(options.conversation_id.clone()));
|
||||
let new_conn: ApiWebSocketConnection =
|
||||
ApiWebSocketResponsesClient::new(api_provider, api_auth)
|
||||
.connect(headers)
|
||||
.await?;
|
||||
self.connection = Some(new_conn);
|
||||
}
|
||||
|
||||
@@ -533,15 +556,10 @@ impl ModelClientSession {
|
||||
let compression = self.responses_request_compression(auth.as_ref());
|
||||
|
||||
let options = self.build_responses_options(prompt, compression);
|
||||
let request =
|
||||
self.build_responses_websocket_request(&api_provider, &api_prompt, options)?;
|
||||
let request = self.prepare_websocket_request(&api_prompt, &options);
|
||||
|
||||
let connection = match self
|
||||
.websocket_connection(
|
||||
api_provider.clone(),
|
||||
api_auth.clone(),
|
||||
request.headers.clone(),
|
||||
)
|
||||
.websocket_connection(api_provider.clone(), api_auth.clone(), &options)
|
||||
.await
|
||||
{
|
||||
Ok(connection) => connection,
|
||||
@@ -558,6 +576,7 @@ impl ModelClientSession {
|
||||
.stream_request(request)
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
self.websocket_last_items = api_prompt.input.clone();
|
||||
|
||||
return Ok(map_response_stream(
|
||||
stream_result,
|
||||
|
||||
Reference in New Issue
Block a user