Mirror of https://github.com/openai/codex.git (synced 2026-02-02 23:13:37 +00:00)

Compare commits: shell-proc ... jif/client (16 commits)
| SHA1 |
|---|
| 6e936f92a4 |
| 423c40d9b8 |
| bed7f051a4 |
| e770917d9b |
| cfe7fb825c |
| 69341a04f4 |
| 3315336366 |
| 7a9c344bd5 |
| 482e6d7fad |
| a1eee44844 |
| 16b5f62b31 |
| bae5341585 |
| 43f7733ad7 |
| 0bba25eff3 |
| bb8f963b08 |
| 7c48a0b717 |
30  codex-rs/Cargo.lock  (generated)

@@ -829,6 +829,29 @@ dependencies = [
 "tracing",
]

[[package]]
name = "codex-api-client"
version = "0.0.0"
dependencies = [
 "async-trait",
 "bytes",
 "codex-app-server-protocol",
 "codex-otel",
 "codex-protocol",
 "eventsource-stream",
 "futures",
 "maplit",
 "regex-lite",
 "reqwest",
 "serde",
 "serde_json",
 "thiserror 2.0.16",
 "tokio",
 "tokio-util",
 "toml",
 "tracing",
]

[[package]]
name = "codex-app-server"
version = "0.0.0"

@@ -1039,6 +1062,7 @@ name = "codex-common"
version = "0.0.0"
dependencies = [
 "clap",
 "codex-api-client",
 "codex-app-server-protocol",
 "codex-core",
 "codex-protocol",

@@ -1059,6 +1083,7 @@ dependencies = [
 "base64",
 "bytes",
 "chrono",
 "codex-api-client",
 "codex-app-server-protocol",
 "codex-apply-patch",
 "codex-async-utils",

@@ -1131,6 +1156,7 @@ dependencies = [
 "anyhow",
 "assert_cmd",
 "clap",
 "codex-api-client",
 "codex-arg0",
 "codex-common",
 "codex-core",

@@ -1206,7 +1232,6 @@ name = "codex-git"
version = "0.0.0"
dependencies = [
 "assert_matches",
 "once_cell",
 "pretty_assertions",
 "regex",
 "schemars 0.8.22",

@@ -1296,6 +1321,7 @@ dependencies = [
 "assert_matches",
 "async-stream",
 "bytes",
 "codex-api-client",
 "codex-core",
 "futures",
 "reqwest",

@@ -1437,6 +1463,7 @@ dependencies = [
 "chrono",
 "clap",
 "codex-ansi-escape",
 "codex-api-client",
 "codex-app-server-protocol",
 "codex-arg0",
 "codex-common",

@@ -1670,6 +1697,7 @@ version = "0.0.0"
dependencies = [
 "anyhow",
 "assert_cmd",
 "codex-api-client",
 "codex-core",
 "codex-protocol",
 "notify",
codex-rs/Cargo.toml

@@ -38,7 +38,7 @@ members = [
    "utils/pty",
    "utils/readiness",
    "utils/string",
    "utils/tokenizer",
    "utils/tokenizer", "api-client",
]
resolver = "2"

@@ -87,6 +87,7 @@ codex-utils-pty = { path = "utils/pty" }
codex-utils-readiness = { path = "utils/readiness" }
codex-utils-string = { path = "utils/string" }
codex-utils-tokenizer = { path = "utils/tokenizer" }
codex-api-client = { path = "api-client" }
core_test_support = { path = "core/tests/common" }
mcp-types = { path = "mcp-types" }
mcp_test_support = { path = "mcp-server/tests/common" }
28  codex-rs/api-client/Cargo.toml  (new file)

@@ -0,0 +1,28 @@
[package]
name = "codex-api-client"
version.workspace = true
edition.workspace = true

[dependencies]
async-trait = { workspace = true }
bytes = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-otel = { workspace = true }
codex-protocol = { path = "../protocol" }
eventsource-stream = { workspace = true }
futures = { workspace = true, default-features = false, features = ["std"] }
regex-lite = { workspace = true }
reqwest = { workspace = true, features = ["json", "stream"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["sync", "time", "rt", "rt-multi-thread", "macros", "io-util"] }
tokio-util = { workspace = true }
tracing = { workspace = true }

[dev-dependencies]
maplit = "1.0.2"
toml = { workspace = true }

[lints]
workspace = true
16  codex-rs/api-client/src/api.rs  (new file)

@@ -0,0 +1,16 @@
use async_trait::async_trait;

use crate::error::Error;
use crate::prompt::Prompt;
use crate::stream::ResponseStream;

#[async_trait]
pub trait ApiClient: Send + Sync {
    type Config: Send + Sync;

    async fn new(config: Self::Config) -> Result<Self, Error>
    where
        Self: Sized;

    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream, Error>;
}
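For context, a minimal sketch (not part of the diff) of how this trait surface can be driven generically; the helper name `run` is illustrative only:

use codex_api_client::{ApiClient, Error, Prompt, ResponseStream};

// Generic driver over any ApiClient implementation: build the client from its
// associated Config, then open a streaming response for one prompt.
async fn run<C: ApiClient>(config: C::Config, prompt: Prompt) -> Result<ResponseStream, Error> {
    // `new` is async so implementations can perform setup I/O if they need to.
    let client = C::new(config).await?;
    client.stream(prompt).await
}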
15  codex-rs/api-client/src/auth.rs  (new file)

@@ -0,0 +1,15 @@
use async_trait::async_trait;
use codex_app_server_protocol::AuthMode;

#[derive(Debug, Clone)]
pub struct AuthContext {
    pub mode: AuthMode,
    pub bearer_token: Option<String>,
    pub account_id: Option<String>,
}

#[async_trait]
pub trait AuthProvider: Send + Sync {
    async fn auth_context(&self) -> Option<AuthContext>;
    async fn refresh_token(&self) -> Result<Option<String>, String>;
}
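A hedged sketch (not in the diff) of implementing this trait with a fixed API key, e.g. for tests; `StaticKeyProvider` is a hypothetical name:

use async_trait::async_trait;
use codex_api_client::{AuthContext, AuthProvider};
use codex_app_server_protocol::AuthMode;

// Hypothetical provider that always hands back the same bearer token.
struct StaticKeyProvider {
    key: String,
}

#[async_trait]
impl AuthProvider for StaticKeyProvider {
    async fn auth_context(&self) -> Option<AuthContext> {
        Some(AuthContext {
            mode: AuthMode::ApiKey,
            bearer_token: Some(self.key.clone()),
            account_id: None,
        })
    }

    // A static key has no refresh flow; report the unchanged token.
    async fn refresh_token(&self) -> Result<Option<String>, String> {
        Ok(Some(self.key.clone()))
    }
}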
866  codex-rs/api-client/src/chat.rs  (new file)

@@ -0,0 +1,866 @@
use std::collections::VecDeque;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;

use async_trait::async_trait;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use serde_json::Value;
use serde_json::json;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

use crate::api::ApiClient;
use crate::error::Error;
use crate::model_provider::ModelProviderInfo;
use crate::prompt::Prompt;
use crate::stream::ResponseEvent;
use crate::stream::ResponseStream;

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Clone, Copy, Debug)]
pub enum ChatAggregationMode {
    AggregatedOnly,
    Streaming,
}

#[derive(Clone)]
pub struct ChatCompletionsApiClientConfig {
    pub http_client: reqwest::Client,
    pub provider: ModelProviderInfo,
    pub model: String,
    pub otel_event_manager: OtelEventManager,
    pub session_source: SessionSource,
    pub aggregation_mode: ChatAggregationMode,
}

#[derive(Clone)]
pub struct ChatCompletionsApiClient {
    config: ChatCompletionsApiClientConfig,
}

#[async_trait]
impl ApiClient for ChatCompletionsApiClient {
    type Config = ChatCompletionsApiClientConfig;

    async fn new(config: Self::Config) -> Result<Self> {
        Ok(Self { config })
    }

    async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
        Self::validate_prompt(&prompt)?;

        let payload = self.build_payload(&prompt)?;
        let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);

        let mut attempt = 0u64;
        let max_retries = self.config.provider.request_max_retries();

        loop {
            attempt += 1;

            let mut req_builder = self
                .config
                .provider
                .create_request_builder(&self.config.http_client, &None)
                .await?;

            if let SessionSource::SubAgent(sub) = &self.config.session_source {
                let subagent = if let SubAgentSource::Other(label) = sub {
                    label.clone()
                } else {
                    serde_json::to_value(sub)
                        .ok()
                        .and_then(|v| v.as_str().map(std::string::ToString::to_string))
                        .unwrap_or_else(|| "other".to_string())
                };
                req_builder = req_builder.header("x-openai-subagent", subagent);
            }

            let res = self
                .config
                .otel_event_manager
                .log_request(attempt, || {
                    req_builder
                        .header(reqwest::header::ACCEPT, "text/event-stream")
                        .json(&payload)
                        .send()
                })
                .await;

            match res {
                Ok(resp) if resp.status().is_success() => {
                    let stream = resp
                        .bytes_stream()
                        .map_err(|err| Error::ResponseStreamFailed {
                            source: err,
                            request_id: None,
                        });
                    let idle_timeout = self.config.provider.stream_idle_timeout();
                    let otel = self.config.otel_event_manager.clone();
                    let mode = self.config.aggregation_mode;

                    tokio::spawn(process_chat_sse(
                        stream,
                        tx_event.clone(),
                        idle_timeout,
                        otel,
                        mode,
                    ));

                    return Ok(ResponseStream { rx_event });
                }
                Ok(resp) => {
                    if attempt >= max_retries {
                        let status = resp.status();
                        let body = resp
                            .text()
                            .await
                            .unwrap_or_else(|_| "<failed to read response>".to_string());
                        return Err(Error::UnexpectedStatus { status, body });
                    }

                    let retry_after = resp
                        .headers()
                        .get(reqwest::header::RETRY_AFTER)
                        .and_then(|v| v.to_str().ok())
                        .and_then(|s| s.parse::<u64>().ok())
                        .map(Duration::from_secs);
                    tokio::time::sleep(retry_after.unwrap_or_else(|| backoff(attempt))).await;
                }
                Err(error) => {
                    if attempt >= max_retries {
                        return Err(Error::Http(error));
                    }
                    tokio::time::sleep(backoff(attempt)).await;
                }
            }
        }
    }
}

impl ChatCompletionsApiClient {
    fn validate_prompt(prompt: &Prompt) -> Result<()> {
        if prompt.output_schema.is_some() {
            return Err(Error::UnsupportedOperation(
                "output_schema is not supported for Chat Completions API".to_string(),
            ));
        }
        Ok(())
    }

    fn build_payload(&self, prompt: &Prompt) -> Result<serde_json::Value> {
        let mut messages = Vec::<serde_json::Value>::new();
        messages.push(json!({ "role": "system", "content": prompt.instructions }));

        let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
            std::collections::HashMap::new();

        let mut last_emitted_role: Option<&str> = None;
        for item in &prompt.input {
            match item {
                ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
                ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                    last_emitted_role = Some("assistant");
                }
                ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
                ResponseItem::Reasoning { .. }
                | ResponseItem::Other
                | ResponseItem::CustomToolCall { .. }
                | ResponseItem::CustomToolCallOutput { .. }
                | ResponseItem::WebSearchCall { .. }
                | ResponseItem::GhostSnapshot { .. } => {}
            }
        }

        let mut last_user_index: Option<usize> = None;
        for (idx, item) in prompt.input.iter().enumerate() {
            if let ResponseItem::Message { role, .. } = item
                && role == "user"
            {
                last_user_index = Some(idx);
            }
        }

        if !matches!(last_emitted_role, Some("user")) {
            for (idx, item) in prompt.input.iter().enumerate() {
                if let Some(u_idx) = last_user_index
                    && idx <= u_idx
                {
                    continue;
                }

                if let ResponseItem::Reasoning {
                    content: Some(items),
                    ..
                } = item
                {
                    let mut text = String::new();
                    for entry in items {
                        match entry {
                            ReasoningItemContent::ReasoningText { text: segment }
                            | ReasoningItemContent::Text { text: segment } => {
                                text.push_str(segment);
                            }
                        }
                    }
                    if text.trim().is_empty() {
                        continue;
                    }

                    let mut attached = false;
                    if idx > 0
                        && let ResponseItem::Message { role, .. } = &prompt.input[idx - 1]
                        && role == "assistant"
                    {
                        reasoning_by_anchor_index
                            .entry(idx - 1)
                            .and_modify(|v| v.push_str(&text))
                            .or_insert(text.clone());
                        attached = true;
                    }

                    if !attached && idx + 1 < prompt.input.len() {
                        match &prompt.input[idx + 1] {
                            ResponseItem::FunctionCall { .. }
                            | ResponseItem::LocalShellCall { .. } => {
                                reasoning_by_anchor_index
                                    .entry(idx + 1)
                                    .and_modify(|v| v.push_str(&text))
                                    .or_insert(text.clone());
                            }
                            ResponseItem::Message { role, .. } if role == "assistant" => {
                                reasoning_by_anchor_index
                                    .entry(idx + 1)
                                    .and_modify(|v| v.push_str(&text))
                                    .or_insert(text.clone());
                            }
                            _ => {}
                        }
                    }
                }
            }
        }

        let mut last_assistant_text: Option<String> = None;

        for (idx, item) in prompt.input.iter().enumerate() {
            match item {
                ResponseItem::Message { role, content, .. } => {
                    let mut text = String::new();
                    let mut items: Vec<serde_json::Value> = Vec::new();
                    let mut saw_image = false;

                    for c in content {
                        match c {
                            ContentItem::InputText { text: t }
                            | ContentItem::OutputText { text: t } => {
                                text.push_str(t);
                                items.push(json!({"type":"text","text": t}));
                            }
                            ContentItem::InputImage { image_url } => {
                                saw_image = true;
                                items.push(
                                    json!({"type":"image_url","image_url": {"url": image_url}}),
                                );
                            }
                        }
                    }

                    if role == "assistant" {
                        if let Some(prev) = &last_assistant_text
                            && prev == &text
                        {
                            continue;
                        }
                        last_assistant_text = Some(text.clone());
                    }

                    let content_value = if role == "assistant" {
                        json!(text)
                    } else if saw_image {
                        json!(items)
                    } else {
                        json!(text)
                    };

                    let mut message = json!({
                        "role": role,
                        "content": content_value,
                    });

                    if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                        && let Some(obj) = message.as_object_mut()
                    {
                        obj.insert("reasoning".to_string(), json!({"text": reasoning}));
                    }

                    messages.push(message);
                }
                ResponseItem::FunctionCall {
                    name,
                    arguments,
                    call_id,
                    ..
                } => {
                    messages.push(json!({
                        "role": "assistant",
                        "tool_calls": [{
                            "id": call_id,
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": arguments,
                            },
                        }],
                    }));
                }
                ResponseItem::FunctionCallOutput { call_id, output } => {
                    let content_value = if let Some(items) = &output.content_items {
                        let mapped: Vec<serde_json::Value> = items
                            .iter()
                            .map(|item| match item {
                                FunctionCallOutputContentItem::InputText { text } => {
                                    json!({"type":"text","text": text})
                                }
                                FunctionCallOutputContentItem::InputImage { image_url } => {
                                    json!({"type":"image_url","image_url": {"url": image_url}})
                                }
                            })
                            .collect();
                        json!(mapped)
                    } else {
                        json!(output.content)
                    };

                    messages.push(json!({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": content_value,
                    }));
                }
                ResponseItem::LocalShellCall {
                    id,
                    call_id,
                    action,
                    ..
                } => {
                    let tool_id = call_id
                        .clone()
                        .filter(|value| !value.is_empty())
                        .or_else(|| id.clone())
                        .unwrap_or_default();
                    messages.push(json!({
                        "role": "assistant",
                        "tool_calls": [{
                            "id": tool_id,
                            "type": "function",
                            "function": {
                                "name": "shell",
                                "arguments": serde_json::to_string(action).unwrap_or_default(),
                            },
                        }],
                    }));
                }
                ResponseItem::CustomToolCall {
                    call_id,
                    name,
                    input,
                    ..
                } => {
                    messages.push(json!({
                        "role": "assistant",
                        "tool_calls": [{
                            "id": call_id.clone(),
                            "type": "function",
                            "function": {
                                "name": name,
                                "arguments": input,
                            },
                        }],
                    }));
                }
                ResponseItem::CustomToolCallOutput { call_id, output } => {
                    messages.push(json!({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": output,
                    }));
                }
                ResponseItem::WebSearchCall { .. }
                | ResponseItem::Reasoning { .. }
                | ResponseItem::Other
                | ResponseItem::GhostSnapshot { .. } => {}
            }
        }

        let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
        let payload = json!({
            "model": self.config.model,
            "messages": messages,
            "stream": true,
            "tools": tools_json,
        });

        trace!("chat completions payload: {}", payload);
        Ok(payload)
    }
}

async fn append_assistant_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    assistant_item: &mut Option<ResponseItem>,
    text: String,
) {
    if assistant_item.is_none() {
        let item = ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![],
        };
        *assistant_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Message { content, .. }) = assistant_item {
        content.push(ContentItem::OutputText { text: text.clone() });
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
            .await;
    }
}

async fn append_reasoning_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    reasoning_item: &mut Option<ResponseItem>,
    text: String,
) {
    if reasoning_item.is_none() {
        let item = ResponseItem::Reasoning {
            id: String::new(),
            summary: Vec::new(),
            content: Some(vec![]),
            encrypted_content: None,
        };
        *reasoning_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Reasoning {
        content: Some(content),
        ..
    }) = reasoning_item
    {
        content.push(ReasoningItemContent::ReasoningText { text: text.clone() });

        let _ = tx_event
            .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
            .await;
    }
}

async fn process_chat_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    idle_timeout: Duration,
    otel_event_manager: OtelEventManager,
    _aggregation_mode: ChatAggregationMode,
) where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();

    #[derive(Default)]
    struct FunctionCallState {
        name: Option<String>,
        arguments: String,
        call_id: Option<String>,
    }

    let mut function_call_state = FunctionCallState::default();
    let mut assistant_item: Option<ResponseItem> = None;
    let mut reasoning_item: Option<ResponseItem> = None;

    loop {
        let response = timeout(idle_timeout, stream.next()).await;
        otel_event_manager.log_sse_event(&response, idle_timeout);

        let sse = match response {
            Ok(Some(Ok(sse))) => sse,
            Ok(Some(Err(e))) => {
                debug!("SSE Error: {e:#}");
                let event = Error::Stream(e.to_string(), None);
                let _ = tx_event.send(Err(event)).await;
                return;
            }
            Ok(None) => {
                if let Some(item) = assistant_item.take() {
                    let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                }
                if let Some(item) = reasoning_item.take() {
                    let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                }
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(Error::Stream(
                        "idle timeout waiting for SSE".into(),
                        None,
                    )))
                    .await;
                return;
            }
        };

        trace!("chat_completions received SSE chunk: {}", sse.data);

        if sse.data.trim() == "[DONE]" {
            if let Some(item) = assistant_item.take() {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }
            if let Some(item) = reasoning_item.take() {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }
            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                    token_usage: None,
                }))
                .await;
            return;
        }

        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };

        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));

        if let Some(choice) = choice_opt {
            if let Some(content) = choice
                .get("delta")
                .and_then(|d| d.get("content"))
                .and_then(|c| c.as_str())
                && !content.is_empty()
            {
                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
            }

            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
                let mut maybe_text = reasoning_val
                    .as_str()
                    .map(str::to_string)
                    .filter(|s| !s.is_empty());

                if maybe_text.is_none() && reasoning_val.is_object() {
                    if let Some(s) = reasoning_val
                        .get("text")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    } else if let Some(s) = reasoning_val
                        .get("content")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    }
                }

                if let Some(reasoning) = maybe_text {
                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
                }
            }

            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
            {
                if let Some(s) = message_reasoning.as_str() {
                    if !s.is_empty() {
                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                    }
                } else if let Some(obj) = message_reasoning.as_object()
                    && let Some(s) = obj
                        .get("text")
                        .and_then(|v| v.as_str())
                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
                    && !s.is_empty()
                {
                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                }
            }

            if let Some(tool_calls) = choice
                .get("delta")
                .and_then(|d| d.get("tool_calls"))
                .and_then(|v| v.as_array())
            {
                for call in tool_calls {
                    if let Some(index) = call.get("index").and_then(serde_json::Value::as_u64)
                        && index == 0
                        && let Some(function) = call.get("function")
                    {
                        if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                            function_call_state.name = Some(name.to_string());
                        }
                        if let Some(arguments) = function.get("arguments").and_then(|a| a.as_str())
                        {
                            function_call_state.arguments.push_str(arguments);
                        }
                        if let Some(id) = call.get("id").and_then(|i| i.as_str()) {
                            function_call_state.call_id = Some(id.to_string());
                        }

                        if let Some(finish) = choice.get("finish_reason").and_then(|f| f.as_str())
                            && finish == "tool_calls"
                            && let Some(name) = function_call_state.name.take()
                        {
                            let call_id = function_call_state.call_id.take().unwrap_or_default();
                            let arguments = std::mem::take(&mut function_call_state.arguments);
                            let item = ResponseItem::FunctionCall {
                                id: None,
                                name,
                                arguments,
                                call_id,
                            };
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                    }
                }
            }
        }
    }
}

pub trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    fn aggregate(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }

    fn streaming_mode(self) -> AggregatedChatStream<Self>
    where
        Self: Unpin,
    {
        AggregatedChatStream::new(self, AggregateMode::Streaming)
    }
}

impl<S> AggregateStreamExt for S where S: Stream<Item = Result<ResponseEvent>> + Sized + Unpin {}

enum AggregateMode {
    AggregatedOnly,
    Streaming,
}

pub struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    cumulative_reasoning: String,
    pending: VecDeque<ResponseEvent>,
    mode: AggregateMode,
}

impl<S> AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    fn new(inner: S, mode: AggregateMode) -> Self {
        Self {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: VecDeque::new(),
            mode,
        }
    }
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Some(ev) = self.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut self.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    let is_assistant_message = matches!(
                        &item,
                        ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        match self.mode {
                            AggregateMode::AggregatedOnly => {
                                if self.cumulative.is_empty()
                                    && let ResponseItem::Message { content, .. } = &item
                                    && let Some(text) = content.iter().find_map(|c| match c {
                                        ContentItem::OutputText { text } => Some(text),
                                        _ => None,
                                    })
                                {
                                    self.cumulative.push_str(text);
                                }
                                continue;
                            }
                            AggregateMode::Streaming => {
                                if self.cumulative.is_empty() {
                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                                        item,
                                    ))));
                                } else {
                                    continue;
                                }
                            }
                        }
                    }

                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    let mut emitted_any = false;

                    if !self.cumulative_reasoning.is_empty()
                        && matches!(self.mode, AggregateMode::AggregatedOnly)
                    {
                        let aggregated_reasoning = ResponseItem::Reasoning {
                            id: String::new(),
                            summary: Vec::new(),
                            content: Some(vec![ReasoningItemContent::ReasoningText {
                                text: std::mem::take(&mut self.cumulative_reasoning),
                            }]),
                            encrypted_content: None,
                        };
                        self.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                        emitted_any = true;
                    }

                    if !self.cumulative.is_empty() {
                        let aggregated_message = ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![ContentItem::OutputText {
                                text: std::mem::take(&mut self.cumulative),
                            }],
                        };
                        self.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                        emitted_any = true;
                    }

                    if emitted_any {
                        self.pending.push_back(ResponseEvent::Completed {
                            response_id: response_id.clone(),
                            token_usage: token_usage.clone(),
                        });
                        if let Some(ev) = self.pending.pop_front() {
                            return Poll::Ready(Some(Ok(ev)));
                        }
                    }

                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    })));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Created))) => continue,
                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
                    self.cumulative.push_str(&delta);
                    if matches!(self.mode, AggregateMode::Streaming) {
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    self.cumulative_reasoning.push_str(&delta);
                    if matches!(self.mode, AggregateMode::Streaming) {
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => continue,
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => continue,
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                }
            }
        }
    }
}

fn create_tools_json_for_chat_completions_api(
    tools: &[serde_json::Value],
) -> Result<Vec<serde_json::Value>> {
    let tools_json = tools
        .iter()
        .filter_map(|tool| {
            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
                return None;
            }

            let function_value = if let Some(function) = tool.get("function") {
                function.clone()
            } else if let Some(map) = tool.as_object() {
                let mut function = map.clone();
                function.remove("type");
                Value::Object(function)
            } else {
                return None;
            };

            Some(json!({
                "type": "function",
                "function": function_value,
            }))
        })
        .collect::<Vec<serde_json::Value>>();
    Ok(tools_json)
}

fn backoff(attempt: u64) -> Duration {
    let capped = attempt.min(6);
    Duration::from_millis(100 * 2u64.pow(capped as u32))
}
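The `backoff` schedule above is 100 ms times 2^min(attempt, 6): delays double from 200 ms and cap at 6.4 s from attempt 6 onward. An illustrative check (not part of the diff, assuming module scope for the private `backoff` fn):

#[test]
fn backoff_doubles_then_caps() {
    assert_eq!(backoff(1), Duration::from_millis(200));
    assert_eq!(backoff(3), Duration::from_millis(800));
    assert_eq!(backoff(6), Duration::from_millis(6_400));
    assert_eq!(backoff(10), Duration::from_millis(6_400)); // capped at 2^6
}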
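And a hypothetical consumer (not in the diff) of the `AggregateStreamExt` extension trait defined in this file, assuming `ResponseStream` satisfies the trait's `Stream<Item = Result<ResponseEvent>> + Unpin` bound and `ResponseItem` implements `Debug`:

use futures::StreamExt;

// Collapse a chat stream into aggregated final items only.
async fn print_final_items(stream: ResponseStream) {
    let mut aggregated = stream.aggregate();
    while let Some(event) = aggregated.next().await {
        if let Ok(ResponseEvent::OutputItemDone(item)) = event {
            println!("{item:?}");
        }
    }
}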
42  codex-rs/api-client/src/error.rs  (new file)

@@ -0,0 +1,42 @@
use std::time::Duration;

use thiserror::Error;

pub type Result<T> = std::result::Result<T, Error>;

#[derive(Debug, Error)]
pub enum Error {
    #[error("{0}")]
    UnsupportedOperation(String),
    #[error(transparent)]
    Http(#[from] reqwest::Error),
    #[error("{source}")]
    ResponseStreamFailed {
        #[source]
        source: reqwest::Error,
        request_id: Option<String>,
    },
    #[error("{0}")]
    Stream(String, Option<Duration>),
    #[error("unexpected status {status}: {body}")]
    UnexpectedStatus {
        status: reqwest::StatusCode,
        body: String,
    },
    #[error("retry limit reached (status {status}, request id: {request_id:?})")]
    RetryLimit {
        status: reqwest::StatusCode,
        request_id: Option<String>,
    },
    #[error("missing environment variable {var}")]
    MissingEnvVar {
        var: String,
        instructions: Option<String>,
    },
    #[error("{0}")]
    Auth(String),
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    #[error("{0}")]
    Other(String),
}
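A hypothetical sketch (not in the diff) of how a caller might bucket these variants when deciding whether to retry; the classification is illustrative, not part of the crate:

// Transport-level failures are typically worth retrying; the rest are not.
fn is_retryable(err: &Error) -> bool {
    matches!(
        err,
        Error::Http(_) | Error::ResponseStreamFailed { .. } | Error::Stream(..)
    )
}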
35  codex-rs/api-client/src/lib.rs  (new file)

@@ -0,0 +1,35 @@
pub mod api;
pub mod auth;
pub mod chat;
pub mod error;
pub mod model_provider;
pub mod prompt;
pub mod responses;
pub mod stream;

pub use crate::api::ApiClient;
pub use crate::auth::AuthContext;
pub use crate::auth::AuthProvider;
pub use crate::chat::AggregateStreamExt;
pub use crate::chat::ChatAggregationMode;
pub use crate::chat::ChatCompletionsApiClient;
pub use crate::chat::ChatCompletionsApiClientConfig;
pub use crate::error::Error;
pub use crate::error::Result;
pub use crate::model_provider::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use crate::model_provider::ModelProviderInfo;
pub use crate::model_provider::WireApi;
pub use crate::model_provider::built_in_model_providers;
pub use crate::model_provider::create_oss_provider;
pub use crate::model_provider::create_oss_provider_with_base_url;
pub use crate::prompt::Prompt;
pub use crate::responses::ResponsesApiClient;
pub use crate::responses::ResponsesApiClientConfig;
pub use crate::responses::stream_from_fixture;
pub use crate::stream::EventStream;
pub use crate::stream::Reasoning;
pub use crate::stream::ResponseEvent;
pub use crate::stream::ResponseStream;
pub use crate::stream::TextControls;
pub use crate::stream::TextFormat;
pub use crate::stream::TextFormatType;
codex-rs/api-client/src/model_provider.rs

@@ -5,17 +5,18 @@
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//!    key. These override or extend the defaults at runtime.

use crate::CodexAuth;
use crate::default_client::CodexHttpClient;
use crate::default_client::CodexRequestBuilder;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::time::Duration;

use crate::error::EnvVarError;
use codex_app_server_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;

use crate::auth::AuthContext;
use crate::error::Error;
use crate::error::Result;

const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;

@@ -23,19 +24,19 @@ const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
const DEFAULT_OLLAMA_PORT: u32 = 11434;

/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// Responses API. The two protocols use different request/response shapes
/// and cannot be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,

    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    #[default]
    Chat,

@@ -50,87 +51,79 @@ pub struct ModelProviderInfo {
    pub base_url: Option<String>,
    /// Environment variable that stores the user's API key for this provider.
    pub env_key: Option<String>,

    /// Optional instructions to help the user get a valid value for the
    /// variable and set it.
    pub env_key_instructions: Option<String>,

    /// Value to use with `Authorization: Bearer <token>` header. Use of this
    /// config is discouraged in favor of `env_key` for security reasons, but
    /// this may be necessary when using this programmatically.
    pub experimental_bearer_token: Option<String>,

    /// Which wire protocol this provider expects.
    #[serde(default)]
    pub wire_api: WireApi,

    /// Optional query parameters to append to the base URL.
    pub query_params: Option<HashMap<String, String>>,

    /// Additional HTTP headers to include in requests to this provider where
    /// the (key, value) pairs are the header name and value.
    pub http_headers: Option<HashMap<String, String>>,

    /// Optional HTTP headers to include in requests to this provider where the
    /// (key, value) pairs are the header name and _environment variable_ whose
    /// (key, value) pairs are the header name and environment variable whose
    /// value should be used. If the environment variable is not set, or the
    /// value is empty, the header will not be included in the request.
    pub env_http_headers: Option<HashMap<String, String>>,

    /// Maximum number of times to retry a failed HTTP request to this provider.
    pub request_max_retries: Option<u64>,

    /// Number of times to retry reconnecting a dropped streaming response before failing.
    pub stream_max_retries: Option<u64>,

    /// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
    /// the connection as lost.
    pub stream_idle_timeout_ms: Option<u64>,

    /// Does this provider require an OpenAI API Key or ChatGPT login token? If true,
    /// user is presented with login screen on first run, and login preference and token/key
    /// are stored in auth.json. If false (which is the default), login screen is skipped,
    /// and API key (if needed) comes from the "env_key" environment variable.
    /// the user is presented with a login screen on first run, and login preference and token/key
    /// are stored in auth.json. If false (which is the default), the login screen is skipped,
    /// and the API key (if needed) comes from the `env_key` environment variable.
    #[serde(default)]
    pub requires_openai_auth: bool,
}

impl ModelProviderInfo {
    /// Construct a `POST` RequestBuilder for the given URL using the provided
    /// [`CodexHttpClient`] applying:
    /// • provider-specific headers (static + env based)
    /// • Bearer auth header when an API key is available.
    /// • Auth token for OAuth.
    /// Construct a `POST` request builder for the given URL using the provided
    /// [`reqwest::Client`] applying:
    /// - provider-specific headers (static and environment based)
    /// - Bearer auth header when an API key is available
    /// - Auth token for OAuth
    ///
    /// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
    /// one produced by [`ModelProviderInfo::api_key`].
    pub async fn create_request_builder<'a>(
        &'a self,
        client: &'a CodexHttpClient,
        auth: &Option<CodexAuth>,
    ) -> crate::error::Result<CodexRequestBuilder> {
    /// If the provider declares an `env_key` but the variable is missing or empty, this returns an
    /// error identical to the one produced by [`ModelProviderInfo::api_key`].
    pub async fn create_request_builder(
        &self,
        client: &reqwest::Client,
        auth: &Option<AuthContext>,
    ) -> Result<reqwest::RequestBuilder> {
        let effective_auth = if let Some(secret_key) = &self.experimental_bearer_token {
            Some(CodexAuth::from_api_key(secret_key))
            Some(AuthContext {
                mode: AuthMode::ApiKey,
                bearer_token: Some(secret_key.clone()),
                account_id: None,
            })
        } else {
            match self.api_key() {
                Ok(Some(key)) => Some(CodexAuth::from_api_key(&key)),
                Ok(None) => auth.clone(),
                Err(err) => {
                    if auth.is_some() {
                        auth.clone()
                    } else {
                        return Err(err);
                    }
                }
            match self.api_key()? {
                Some(key) => Some(AuthContext {
                    mode: AuthMode::ApiKey,
                    bearer_token: Some(key),
                    account_id: None,
                }),
                None => auth.clone(),
            }
        };

        let url = self.get_full_url(&effective_auth);

        let url = self.get_full_url(effective_auth.as_ref());
        let mut builder = client.post(url);

        if let Some(auth) = effective_auth.as_ref() {
            builder = builder.bearer_auth(auth.get_token().await?);
        if let Some(context) = effective_auth.as_ref()
            && let Some(token) = context.bearer_token.as_ref()
        {
            builder = builder.bearer_auth(token);
        }

        Ok(self.apply_http_headers(builder))

@@ -149,10 +142,10 @@ impl ModelProviderInfo {
        })
    }

    pub(crate) fn get_full_url(&self, auth: &Option<CodexAuth>) -> String {
    pub fn get_full_url(&self, auth: Option<&AuthContext>) -> String {
        let default_base_url = if matches!(
            auth,
            Some(CodexAuth {
            Some(AuthContext {
                mode: AuthMode::ChatGPT,
                ..
            })

@@ -165,7 +158,7 @@ impl ModelProviderInfo {
        let base_url = self
            .base_url
            .clone()
            .unwrap_or(default_base_url.to_string());
            .unwrap_or_else(|| default_base_url.to_string());

        match self.wire_api {
            WireApi::Responses => format!("{base_url}/responses{query_string}"),

@@ -173,7 +166,7 @@ impl ModelProviderInfo {
        }
    }

    pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
    pub fn is_azure_responses_endpoint(&self) -> bool {
        if self.wire_api != WireApi::Responses {
            return false;
        }

@@ -188,10 +181,9 @@ impl ModelProviderInfo {
            .unwrap_or(false)
    }

    /// Apply provider-specific HTTP headers (both static and environment-based)
    /// onto an existing [`CodexRequestBuilder`] and return the updated
    /// builder.
    fn apply_http_headers(&self, mut builder: CodexRequestBuilder) -> CodexRequestBuilder {
    /// Apply provider-specific HTTP headers (both static and environment-based) onto an existing
    /// [`reqwest::RequestBuilder`] and return the updated builder.
    fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
        if let Some(extra) = &self.http_headers {
            for (k, v) in extra {
                builder = builder.header(k, v);

@@ -210,10 +202,9 @@ impl ModelProviderInfo {
        builder
    }

    /// If `env_key` is Some, returns the API key for this provider if present
    /// (and non-empty) in the environment. If `env_key` is required but
    /// cannot be found, returns an error.
    pub fn api_key(&self) -> crate::error::Result<Option<String>> {
    /// If `env_key` is `Some`, returns the API key for this provider if present (and non-empty) in
    /// the environment. If `env_key` is required but cannot be found, returns an error.
    pub fn api_key(&self) -> Result<Option<String>> {
        match &self.env_key {
            Some(env_key) => {
                let env_value = std::env::var(env_key);

@@ -225,11 +216,9 @@ impl ModelProviderInfo {
                        Ok(Some(v))
                    }
                })
                .map_err(|_| {
                    crate::error::CodexErr::EnvVar(EnvVarError {
                        var: env_key.clone(),
                        instructions: self.env_key_instructions.clone(),
                    })
                .map_err(|_| Error::MissingEnvVar {
                    var: env_key.clone(),
                    instructions: self.env_key_instructions.clone(),
                })
            }
            None => Ok(None),

@@ -258,28 +247,23 @@ impl ModelProviderInfo {
    }
}

const DEFAULT_OLLAMA_PORT: u32 = 11434;

pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";

/// Built-in default provider list.
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    use ModelProviderInfo as P;

    // We do not want to be in the business of adjucating which third-party
    // providers are bundled with Codex CLI, so we only include the OpenAI and
    // open source ("oss") providers by default. Users are encouraged to add to
    // `model_providers` in config.toml to add their own providers.
    // We do not want to be in the business of adjudicating which third-party providers are bundled
    // with Codex CLI, so we only include the OpenAI and open source ("oss") providers by default.
    // Users are encouraged to add to `model_providers` in config.toml to add their own providers.
    [
        (
            "openai",
            P {
                name: "OpenAI".into(),
                // Allow users to override the default OpenAI endpoint by
                // exporting `OPENAI_BASE_URL`. This is useful when pointing
                // Codex at a proxy, mock server, or Azure-style deployment
                // without requiring a full TOML override for the built-in
                // OpenAI provider.
                // Allow users to override the default OpenAI endpoint by exporting `OPENAI_BASE_URL`.
                // This is useful when pointing Codex at a proxy, mock server, or Azure-style
                // deployment without requiring a full TOML override for the built-in OpenAI provider.
                base_url: std::env::var("OPENAI_BASE_URL")
                    .ok()
                    .filter(|v| !v.trim().is_empty()),

@@ -318,9 +302,10 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
    .collect()
}

/// Convenience helper for the built-in OSS provider.
pub fn create_oss_provider() -> ModelProviderInfo {
    // These CODEX_OSS_ environment variables are experimental: we may
    // switch to reading values from config.toml instead.
    // These CODEX_OSS_ environment variables are experimental: we may switch to reading values from
    // config.toml instead.
    let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
        .ok()
        .filter(|v| !v.trim().is_empty())

@@ -366,23 +351,23 @@ fn matches_azure_responses_base_url(base_url: &str) -> bool {
        "azure-api.",
        "azurefd.",
    ];
    AZURE_MARKERS.iter().any(|marker| base.contains(marker))
    AZURE_MARKERS.iter().any(|needle| base.contains(needle))
}

#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use maplit::hashmap;

    #[test]
    fn test_deserialize_ollama_model_provider_toml() {
    fn deserializes_defaults_without_optional_fields() {
        let azure_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
"#;
        let expected_provider = ModelProviderInfo {
            name: "Ollama".into(),
            base_url: Some("http://localhost:11434/v1".into()),
            name: "Azure".into(),
            base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
            env_key: None,
            env_key_instructions: None,
            experimental_bearer_token: None,

@@ -415,7 +400,7 @@ query_params = { api-version = "2025-04-01-preview" }
            env_key_instructions: None,
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: Some(maplit::hashmap! {
            query_params: Some(hashmap! {
                "api-version".to_string() => "2025-04-01-preview".to_string(),
            }),
            http_headers: None,

@@ -447,10 +432,10 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
            experimental_bearer_token: None,
            wire_api: WireApi::Chat,
            query_params: None,
            http_headers: Some(maplit::hashmap! {
            http_headers: Some(hashmap! {
                "X-Example-Header".to_string() => "example-value".to_string(),
            }),
            env_http_headers: Some(maplit::hashmap! {
            env_http_headers: Some(hashmap! {
                "X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
            }),
            request_max_retries: None,

@@ -516,16 +501,12 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
        };
        assert!(named_provider.is_azure_responses_endpoint());

        let negative_cases = [
            "https://api.openai.com/v1",
            "https://example.com/openai",
            "https://myproxy.azurewebsites.net/openai",
        ];
        let negative_cases = ["https://api.openai.com/v1", "https://example.com"];
        for base_url in negative_cases {
            let provider = provider_for(base_url);
            assert!(
                !provider.is_azure_responses_endpoint(),
                "expected {base_url} not to be detected as Azure"
                "expected {base_url} to be non-Azure"
            );
        }
    }
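As the tests above suggest, `wire_api` defaults to chat when omitted from TOML; a further illustrative check (not in the diff) that an explicit `wire_api = "responses"` round-trips through the `#[serde(rename_all = "lowercase")]` attribute:

#[test]
fn deserializes_responses_wire_api() {
    // `toml` is already a dev-dependency of this crate.
    let provider: ModelProviderInfo = toml::from_str(
        r#"
name = "Example"
base_url = "https://example.com/v1"
wire_api = "responses"
"#,
    )
    .expect("deserialize");
    assert_eq!(provider.wire_api, WireApi::Responses);
}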
49  codex-rs/api-client/src/prompt.rs  (new file)

@@ -0,0 +1,49 @@
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use serde_json::Value;

use crate::Reasoning;
use crate::TextControls;

#[derive(Debug, Clone, Default)]
pub struct Prompt {
    pub instructions: String,
    pub input: Vec<ResponseItem>,
    pub tools: Vec<Value>,
    pub parallel_tool_calls: bool,
    pub output_schema: Option<Value>,
    pub reasoning: Option<Reasoning>,
    pub text_controls: Option<TextControls>,
    pub prompt_cache_key: Option<String>,
    pub previous_response_id: Option<String>,
    pub session_source: Option<SessionSource>,
}

impl Prompt {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        instructions: String,
        input: Vec<ResponseItem>,
        tools: Vec<Value>,
        parallel_tool_calls: bool,
        output_schema: Option<Value>,
        reasoning: Option<Reasoning>,
        text_controls: Option<TextControls>,
        prompt_cache_key: Option<String>,
        previous_response_id: Option<String>,
        session_source: Option<SessionSource>,
    ) -> Self {
        Self {
            instructions,
            input,
            tools,
            parallel_tool_calls,
            output_schema,
            reasoning,
            text_controls,
            prompt_cache_key,
            previous_response_id,
            session_source,
        }
    }
}
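Since `Prompt` derives `Default`, callers can sidestep the ten-argument constructor with struct-update syntax; a hypothetical construction (not in the diff):

use codex_api_client::Prompt;

// Builder shortcut relying on the derived Default for the remaining fields.
fn simple_prompt() -> Prompt {
    Prompt {
        instructions: "You are a concise assistant.".to_string(),
        ..Prompt::default()
    }
}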
742
codex-rs/api-client/src/responses.rs
Normal file
742
codex-rs/api-client/src/responses.rs
Normal file
@@ -0,0 +1,742 @@
|
||||
use std::io::BufRead;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::Bytes;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_otel::otel_event_manager::OtelEventManager;
|
||||
use codex_protocol::ConversationId;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
use codex_protocol::protocol::RateLimitSnapshot;
|
||||
use codex_protocol::protocol::RateLimitWindow;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::SubAgentSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use eventsource_stream::Eventsource;
|
||||
use futures::Stream;
|
||||
use futures::StreamExt;
|
||||
use futures::TryStreamExt;
|
||||
use regex_lite::Regex;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::HeaderMap;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
use tracing::trace;
|
||||
|
||||
use crate::api::ApiClient;
|
||||
use crate::auth::AuthProvider;
|
||||
use crate::error::Error;
|
||||
use crate::model_provider::ModelProviderInfo;
|
||||
use crate::prompt::Prompt;
|
||||
use crate::stream::ResponseEvent;
|
||||
use crate::stream::ResponseStream;
|
||||
|
||||
type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ResponsesApiClientConfig {
|
||||
pub http_client: reqwest::Client,
|
||||
pub provider: ModelProviderInfo,
|
||||
pub model: String,
|
||||
pub conversation_id: ConversationId,
|
||||
pub auth_provider: Option<Arc<dyn AuthProvider>>,
|
||||
pub otel_event_manager: OtelEventManager,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ResponsesApiClient {
|
||||
config: ResponsesApiClientConfig,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ApiClient for ResponsesApiClient {
|
||||
type Config = ResponsesApiClientConfig;
|
||||
|
||||
async fn new(config: Self::Config) -> Result<Self> {
|
||||
Ok(Self { config })
|
||||
}
|
||||
|
||||
async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
|
||||
if self.config.provider.wire_api != crate::model_provider::WireApi::Responses {
|
||||
return Err(Error::UnsupportedOperation(
|
||||
"ResponsesApiClient requires a Responses provider".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut payload_json = self.build_payload(&prompt)?;
|
||||
|
||||
if self.config.provider.is_azure_responses_endpoint()
|
||||
&& let Some(input_value) = payload_json.get_mut("input")
|
||||
&& let Some(array) = input_value.as_array_mut()
|
||||
{
|
||||
attach_item_ids_array(array, &prompt.input);
|
||||
}
|
||||
|
||||
let max_attempts = self.config.provider.request_max_retries();
|
||||
for attempt in 0..=max_attempts {
|
||||
match self
|
||||
.attempt_stream_responses(attempt, &prompt, &payload_json)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => return Ok(stream),
|
||||
Err(StreamAttemptError::Fatal(err)) => return Err(err),
|
||||
Err(retryable) => {
|
||||
if attempt == max_attempts {
|
||||
return Err(retryable.into_error());
|
||||
}
|
||||
|
||||
tokio::time::sleep(retryable.delay(attempt)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("attempt_stream_responses should always return");
|
||||
}
|
||||
}
impl ResponsesApiClient {
    fn build_payload(&self, prompt: &Prompt) -> Result<Value> {
        let azure_workaround = self.config.provider.is_azure_responses_endpoint();

        let mut payload = json!({
            "model": self.config.model,
            "instructions": prompt.instructions,
            "input": prompt.input,
            "tools": prompt.tools,
            "tool_choice": "auto",
            "parallel_tool_calls": prompt.parallel_tool_calls,
            "store": azure_workaround,
            "stream": true,
            "prompt_cache_key": prompt
                .prompt_cache_key
                .clone()
                .unwrap_or_else(|| self.config.conversation_id.to_string()),
        });

        if let Some(reasoning) = prompt.reasoning.as_ref()
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert("reasoning".to_string(), serde_json::to_value(reasoning)?);
        }

        if let Some(text) = prompt.text_controls.as_ref()
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert("text".to_string(), serde_json::to_value(text)?);
        }

        if let Some(prev) = prompt.previous_response_id.as_ref()
            && let Some(obj) = payload.as_object_mut()
        {
            obj.insert(
                "previous_response_id".to_string(),
                Value::String(prev.clone()),
            );
        }

        let include = if prompt.reasoning.is_some() {
            vec!["reasoning.encrypted_content".to_string()]
        } else {
            Vec::new()
        };
        if let Some(obj) = payload.as_object_mut() {
            obj.insert(
                "include".to_string(),
                Value::Array(include.into_iter().map(Value::String).collect()),
            );
        }

        Ok(payload)
    }
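    // Shape of the resulting payload (illustrative values only; "gpt-5" and the
    // cache key are placeholders, and "include" is empty unless reasoning is
    // requested):
    // {"model":"gpt-5","instructions":"…","input":[…],"tools":[…],
    //  "tool_choice":"auto","parallel_tool_calls":true,"store":false,
    //  "stream":true,"prompt_cache_key":"<conversation-id>","include":[]}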
    async fn attempt_stream_responses(
        &self,
        attempt: u64,
        prompt: &Prompt,
        payload_json: &Value,
    ) -> std::result::Result<ResponseStream, StreamAttemptError> {
        let auth = match &self.config.auth_provider {
            Some(provider) => provider.auth_context().await,
            None => None,
        };

        trace!(
            "POST to {}: {:?}",
            self.config.provider.get_full_url(auth.as_ref()),
            serde_json::to_string(payload_json)
                .unwrap_or_else(|_| "<unable to serialize payload>".to_string())
        );

        let mut req_builder = self
            .config
            .provider
            .create_request_builder(&self.config.http_client, &auth)
            .await
            .map_err(StreamAttemptError::Fatal)?;

        if let Some(SessionSource::SubAgent(sub)) = prompt.session_source.as_ref() {
            let subagent = match sub {
                SubAgentSource::Other(label) => label.clone(),
                other => serde_json::to_value(other)
                    .ok()
                    .and_then(|v| v.as_str().map(ToString::to_string))
                    .unwrap_or_else(|| "other".to_string()),
            };
            req_builder = req_builder.header("x-openai-subagent", subagent);
        }

        req_builder = req_builder
            .header("conversation_id", self.config.conversation_id.to_string())
            .header("session_id", self.config.conversation_id.to_string())
            .header(reqwest::header::ACCEPT, "text/event-stream")
            .json(payload_json);

        if let Some(ctx) = auth.as_ref()
            && ctx.mode == AuthMode::ChatGPT
            && let Some(account_id) = ctx.account_id.as_ref()
        {
            req_builder = req_builder.header("chatgpt-account-id", account_id);
        }

        let res = self
            .config
            .otel_event_manager
            .log_request(attempt, || req_builder.send())
            .await;

        let mut request_id = None;
        if let Ok(resp) = &res {
            request_id = resp
                .headers()
                .get("cf-ray")
                .and_then(|v| v.to_str().ok())
                .map(std::string::ToString::to_string);
        }

        match res {
            Ok(resp) if resp.status().is_success() => {
                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);

                if let Some(snapshot) = parse_rate_limit_snapshot(resp.headers())
                    && tx_event
                        .send(Ok(ResponseEvent::RateLimits(snapshot)))
                        .await
                        .is_err()
                {
                    debug!("receiver dropped rate limit snapshot event");
                }

                let request_id_for_stream = request_id.clone();
                let stream = resp
                    .bytes_stream()
                    .map_err(move |err| Error::ResponseStreamFailed {
                        source: err,
                        request_id: request_id_for_stream.clone(),
                    });
                tokio::spawn(process_sse(
                    stream,
                    tx_event,
                    self.config.provider.stream_idle_timeout(),
                    self.config.otel_event_manager.clone(),
                ));

                Ok(ResponseStream { rx_event })
            }
            Ok(res) => {
                let status = res.status();

                let retry_after_secs = res
                    .headers()
                    .get(reqwest::header::RETRY_AFTER)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok());
                let retry_after = retry_after_secs.map(|s| Duration::from_millis(s * 1_000));

                if status == StatusCode::UNAUTHORIZED
                    && let Some(provider) = self.config.auth_provider.as_ref()
                    && let Some(ctx) = auth.as_ref()
                    && ctx.mode == AuthMode::ChatGPT
                {
                    provider
                        .refresh_token()
                        .await
                        .map_err(|err| StreamAttemptError::Fatal(Error::Auth(err)))?;
                }

                if !(status == StatusCode::TOO_MANY_REQUESTS
                    || status == StatusCode::UNAUTHORIZED
                    || status.is_server_error())
                {
                    // Surface error body.
                    let body = res
                        .text()
                        .await
                        .unwrap_or_else(|_| "<failed to read response>".to_string());
                    return Err(StreamAttemptError::Fatal(Error::UnexpectedStatus {
                        status,
                        body,
                    }));
                }

                Err(StreamAttemptError::RetryableHttpError {
                    status,
                    retry_after,
                    request_id,
                })
            }
            Err(err) => Err(StreamAttemptError::RetryableTransportError(Error::Http(
                err,
            ))),
        }
    }
}
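/// Outcome of a single streaming attempt: retryable HTTP errors may carry a
/// server-supplied `Retry-After` hint, transport errors retry on pure backoff,
/// and fatal errors abort the retry loop immediately.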
enum StreamAttemptError {
    RetryableHttpError {
        status: StatusCode,
        retry_after: Option<Duration>,
        request_id: Option<String>,
    },
    RetryableTransportError(Error),
    Fatal(Error),
}

impl StreamAttemptError {
    fn delay(&self, attempt: u64) -> Duration {
        let backoff_attempt = attempt + 1;
        match self {
            StreamAttemptError::RetryableHttpError { retry_after, .. } => {
                retry_after.unwrap_or_else(|| backoff(backoff_attempt))
            }
            StreamAttemptError::RetryableTransportError { .. } => backoff(backoff_attempt),
            StreamAttemptError::Fatal(_) => Duration::from_secs(0),
        }
    }

    fn into_error(self) -> Error {
        match self {
            StreamAttemptError::RetryableHttpError {
                status, request_id, ..
            } => Error::RetryLimit { status, request_id },
            StreamAttemptError::RetryableTransportError(error) => error,
            StreamAttemptError::Fatal(error) => error,
        }
    }
}
#[derive(Debug, Deserialize, Serialize)]
struct SseEvent {
    #[serde(rename = "type")]
    kind: String,
    response: Option<Value>,
    item: Option<Value>,
    delta: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ResponseCompleted {
    id: String,
    usage: Option<ResponseCompletedUsage>,
}

#[derive(Debug, Deserialize)]
struct ResponseCompletedUsage {
    input_tokens: i64,
    input_tokens_details: Option<ResponseCompletedInputTokensDetails>,
    output_tokens: i64,
    output_tokens_details: Option<ResponseCompletedOutputTokensDetails>,
    total_tokens: i64,
}

impl From<ResponseCompletedUsage> for TokenUsage {
    fn from(val: ResponseCompletedUsage) -> Self {
        TokenUsage {
            input_tokens: val.input_tokens,
            cached_input_tokens: val
                .input_tokens_details
                .map(|d| d.cached_tokens)
                .unwrap_or(0),
            output_tokens: val.output_tokens,
            reasoning_output_tokens: val
                .output_tokens_details
                .map(|d| d.reasoning_tokens)
                .unwrap_or(0),
            total_tokens: val.total_tokens,
        }
    }
}

#[derive(Debug, Deserialize)]
struct ResponseCompletedInputTokensDetails {
    cached_tokens: i64,
}

#[derive(Debug, Deserialize)]
struct ResponseCompletedOutputTokensDetails {
    reasoning_tokens: i64,
}
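// Azure Responses endpoints need item ids to round-trip on the serialized
// `input` array, so the ids recorded on the original items are re-attached to
// the matching JSON objects by position.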
fn attach_item_ids_array(items: &mut [Value], original_items: &[ResponseItem]) {
    for (value, item) in items.iter_mut().zip(original_items.iter()) {
        if let ResponseItem::Reasoning { id, .. }
        | ResponseItem::Message { id: Some(id), .. }
        | ResponseItem::WebSearchCall { id: Some(id), .. }
        | ResponseItem::FunctionCall { id: Some(id), .. }
        | ResponseItem::LocalShellCall { id: Some(id), .. }
        | ResponseItem::CustomToolCall { id: Some(id), .. }
        | ResponseItem::CustomToolCallOutput { call_id: id, .. }
        | ResponseItem::FunctionCallOutput { call_id: id, .. } = item
        {
            if id.is_empty() {
                continue;
            }

            if let Some(obj) = value.as_object_mut() {
                obj.insert("id".to_string(), Value::String(id.clone()));
            }
        }
    }
}
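// Rate-limit state arrives on `x-codex-{primary,secondary}-*` response
// headers; each window is surfaced only when its headers carry non-zero data.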
fn parse_rate_limit_snapshot(headers: &HeaderMap) -> Option<RateLimitSnapshot> {
    let primary = parse_rate_limit_window(
        headers,
        "x-codex-primary-used-percent",
        "x-codex-primary-window-minutes",
        "x-codex-primary-reset-at",
    );

    let secondary = parse_rate_limit_window(
        headers,
        "x-codex-secondary-used-percent",
        "x-codex-secondary-window-minutes",
        "x-codex-secondary-reset-at",
    );

    Some(RateLimitSnapshot { primary, secondary })
}

fn parse_rate_limit_window(
    headers: &HeaderMap,
    used_percent_header: &str,
    window_minutes_header: &str,
    resets_at_header: &str,
) -> Option<RateLimitWindow> {
    let used_percent: Option<f64> = parse_header_f64(headers, used_percent_header);

    used_percent.and_then(|used_percent| {
        let window_minutes = parse_header_i64(headers, window_minutes_header);
        let resets_at = parse_header_i64(headers, resets_at_header);

        let has_data = used_percent != 0.0
            || window_minutes.is_some_and(|minutes| minutes != 0)
            || resets_at.is_some();

        has_data.then_some(RateLimitWindow {
            used_percent,
            window_minutes,
            resets_at,
        })
    })
}

fn parse_header_f64(headers: &HeaderMap, name: &str) -> Option<f64> {
    parse_header_str(headers, name)?
        .parse::<f64>()
        .ok()
        .filter(|v| v.is_finite())
}

fn parse_header_i64(headers: &HeaderMap, name: &str) -> Option<i64> {
    parse_header_str(headers, name)?.parse::<i64>().ok()
}

fn parse_header_str<'a>(headers: &'a HeaderMap, name: &str) -> Option<&'a str> {
    headers.get(name)?.to_str().ok()
}
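// The SSE loop below waits up to `idle_timeout` between events, maps each
// `response.*` event onto a `ResponseEvent`, and treats a stream that closes
// without `response.completed` as an error.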
async fn process_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    idle_timeout: Duration,
    otel_event_manager: OtelEventManager,
) where
    S: Stream<Item = Result<Bytes>> + Unpin + Send + 'static,
{
    let mut stream = stream.eventsource();

    let mut response_completed: Option<ResponseCompleted> = None;
    let mut response_error: Option<Error> = None;

    loop {
        let start = std::time::Instant::now();
        let response = timeout(idle_timeout, stream.next()).await;
        let duration = start.elapsed();
        otel_event_manager.log_sse_event(&response, duration);

        let sse = match response {
            Ok(Some(Ok(sse))) => sse,
            Ok(Some(Err(e))) => {
                debug!("SSE Error: {e:#}");
                let event = Error::Stream(e.to_string(), None);
                let _ = tx_event.send(Err(event)).await;
                return;
            }
            Ok(None) => {
                match response_completed {
                    Some(ResponseCompleted {
                        id: response_id,
                        usage,
                    }) => {
                        if let Some(token_usage) = &usage {
                            otel_event_manager.sse_event_completed(
                                token_usage.input_tokens,
                                token_usage.output_tokens,
                                token_usage
                                    .input_tokens_details
                                    .as_ref()
                                    .map(|d| d.cached_tokens),
                                token_usage
                                    .output_tokens_details
                                    .as_ref()
                                    .map(|d| d.reasoning_tokens),
                                token_usage.total_tokens,
                            );
                        }
                        let event = ResponseEvent::Completed {
                            response_id,
                            token_usage: usage.map(Into::into),
                        };
                        let _ = tx_event.send(Ok(event)).await;
                    }
                    None => {
                        let error = response_error.unwrap_or(Error::Stream(
                            "stream closed before response.completed".into(),
                            None,
                        ));
                        otel_event_manager.see_event_completed_failed(&error);

                        let _ = tx_event.send(Err(error)).await;
                    }
                }
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(Error::Stream(
                        "idle timeout waiting for SSE".into(),
                        None,
                    )))
                    .await;
                return;
            }
        };

        let raw = sse.data.clone();
        trace!("SSE event: {}", raw);

        let event: SseEvent = match serde_json::from_str(&sse.data) {
            Ok(event) => event,
            Err(e) => {
                debug!("Failed to parse SSE event: {e}, data: {}", &sse.data);
                continue;
            }
        };

        match event.kind.as_str() {
            "response.output_item.done" => {
                let Some(item_val) = event.item else { continue };
                let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
                    debug!("failed to parse ResponseItem from output_item.done");
                    continue;
                };

                let event = ResponseEvent::OutputItemDone(item);
                if tx_event.send(Ok(event)).await.is_err() {
                    return;
                }
            }
            "response.output_text.delta" => {
                if let Some(delta) = event.delta {
                    let event = ResponseEvent::OutputTextDelta(delta);
                    if tx_event.send(Ok(event)).await.is_err() {
                        return;
                    }
                }
            }
            "response.reasoning_summary_text.delta" => {
                if let Some(delta) = event.delta {
                    let event = ResponseEvent::ReasoningSummaryDelta(delta);
                    if tx_event.send(Ok(event)).await.is_err() {
                        return;
                    }
                }
            }
            "response.reasoning_text.delta" => {
                if let Some(delta) = event.delta {
                    let event = ResponseEvent::ReasoningContentDelta(delta);
                    if tx_event.send(Ok(event)).await.is_err() {
                        return;
                    }
                }
            }
            "response.created" => {
                if event.response.is_some() {
                    let _ = tx_event.send(Ok(ResponseEvent::Created)).await;
                }
            }
            "response.failed" => {
                if let Some(resp_val) = event.response {
                    response_error = Some(Error::Stream(
                        "response.failed event received".to_string(),
                        None,
                    ));

                    if let Some(error) = resp_val.get("error") {
                        match serde_json::from_value::<ErrorResponse>(error.clone()) {
                            Ok(error) => {
                                if is_context_window_error(&error) {
                                    response_error = Some(Error::UnsupportedOperation(
                                        "context window exceeded".to_string(),
                                    ));
                                } else {
                                    let delay = try_parse_retry_after(&error);
                                    let message = error.message.clone().unwrap_or_default();
                                    response_error = Some(Error::Stream(message, delay));
                                }
                            }
                            Err(e) => {
                                let error = format!("failed to parse ErrorResponse: {e}");
                                debug!(error);
                                response_error = Some(Error::Stream(error, None))
                            }
                        }
                    }
                }
            }
            "response.completed" => {
                if let Some(resp_val) = event.response {
                    match serde_json::from_value::<ResponseCompleted>(resp_val) {
                        Ok(r) => {
                            response_completed = Some(r);
                        }
                        Err(e) => {
                            let error = format!("failed to parse ResponseCompleted: {e}");
                            debug!(error);
                            response_error = Some(Error::Stream(error, None));
                            continue;
                        }
                    };
                };
            }
"response.output_item.added" => {
|
||||
let Some(item_val) = event.item else { continue };
|
||||
let Ok(item) = serde_json::from_value::<ResponseItem>(item_val) else {
|
||||
debug!("failed to parse ResponseItem from output_item.done");
|
||||
continue;
|
||||
};
|
||||
|
||||
let event = ResponseEvent::OutputItemAdded(item);
|
||||
if tx_event.send(Ok(event)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
            "response.reasoning_summary_part.added" => {
                let event = ResponseEvent::ReasoningSummaryPartAdded;
                if tx_event.send(Ok(event)).await.is_err() {
                    return;
                }
            }
            _ => {}
        }
    }
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
    code: Option<String>,
    message: Option<String>,
}

fn backoff(attempt: u64) -> Duration {
    let exponent = attempt.min(6) as u32;
    let base = 2u64.pow(exponent);
    Duration::from_millis(base * 100)
}
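// Doubling schedule capped at 2^6: attempt 1 waits 200 ms, attempt 2 waits
// 400 ms, attempt 3 waits 800 ms, and attempt 7 onward tops out at 6.4 s.
// No jitter is applied here.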
fn rate_limit_regex() -> Option<&'static Regex> {
    static RE: OnceLock<Option<Regex>> = OnceLock::new();

    RE.get_or_init(|| Regex::new(r"Please try again in (\d+(?:\.\d+)?)(s|ms)").ok())
        .as_ref()
}

fn try_parse_retry_after(err: &ErrorResponse) -> Option<Duration> {
    if err.code.as_deref() != Some("rate_limit_exceeded") {
        return None;
    }

    if let Some(re) = rate_limit_regex()
        && let Some(message) = &err.message
        && let Some(captures) = re.captures(message)
    {
        let seconds = captures.get(1);
        let unit = captures.get(2);

        if let (Some(value), Some(unit)) = (seconds, unit) {
            let value = value.as_str().parse::<f64>().ok()?;
            let unit = unit.as_str();

            if unit == "s" {
                return Some(Duration::from_secs_f64(value));
            } else if unit == "ms" {
                return Some(Duration::from_millis(value as u64));
            }
        }
    }
    None
}

fn is_context_window_error(error: &ErrorResponse) -> bool {
    error.code.as_deref() == Some("context_length_exceeded")
}
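// Example: code "rate_limit_exceeded" with message "Please try again in 1.5s"
// yields Duration::from_secs_f64(1.5); "Please try again in 250ms" yields a
// 250 ms delay.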
/// Used in tests to stream responses from a text SSE fixture file.
pub async fn stream_from_fixture(
    path: impl AsRef<Path>,
    provider: ModelProviderInfo,
    otel_event_manager: OtelEventManager,
) -> Result<ResponseStream> {
    let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
    let display_path = path.as_ref().display().to_string();
    let file = std::fs::File::open(path.as_ref())
        .map_err(|e| Error::Other(format!("failed to open fixture {display_path}: {e}")))?;
    let lines = std::io::BufReader::new(file).lines();

    let mut content = String::new();
    for line in lines {
        let line =
            line.map_err(|e| Error::Other(format!("failed to read fixture {display_path}: {e}")))?;
        content.push_str(&line);
        content.push_str("\n\n");
    }

    let rdr = std::io::Cursor::new(content);
    let stream = ReaderStream::new(rdr).map_err(|e| Error::Other(e.to_string()));
    tokio::spawn(process_sse(
        stream,
        tx_event,
        provider.stream_idle_timeout(),
        otel_event_manager,
    ));
    Ok(ResponseStream { rx_event })
}
83
codex-rs/api-client/src/stream.rs
Normal file
@@ -0,0 +1,83 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::RateLimitSnapshot;
use codex_protocol::protocol::TokenUsage;
use futures::Stream;
use serde::Serialize;
use serde_json::Value;
use tokio::sync::mpsc;

use crate::error::Result;

#[derive(Debug, Serialize, Clone)]
pub struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<ReasoningEffortConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub summary: Option<ReasoningSummaryConfig>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub enum TextFormatType {
    #[default]
    JsonSchema,
}

#[derive(Debug, Serialize, Default, Clone)]
pub struct TextFormat {
    pub r#type: TextFormatType,
    pub strict: bool,
    pub schema: Value,
    pub name: String,
}

#[derive(Debug, Serialize, Default, Clone)]
pub struct TextControls {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub verbosity: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub format: Option<TextFormat>,
}

#[derive(Debug)]
pub enum ResponseEvent {
    Created,
    OutputItemDone(ResponseItem),
    OutputItemAdded(ResponseItem),
    Completed {
        response_id: String,
        token_usage: Option<TokenUsage>,
    },
    OutputTextDelta(String),
    ReasoningSummaryDelta(String),
    ReasoningContentDelta(String),
    ReasoningSummaryPartAdded,
    RateLimits(RateLimitSnapshot),
}

#[derive(Debug)]
pub struct EventStream<T> {
    pub(crate) rx_event: mpsc::Receiver<T>,
}

impl<T> EventStream<T> {
    pub fn from_receiver(rx_event: mpsc::Receiver<T>) -> Self {
        Self { rx_event }
    }
}

impl<T> Stream for EventStream<T> {
    type Item = T;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}

pub type ResponseStream = EventStream<Result<ResponseEvent>>;
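// Usage sketch (assumes `use futures::StreamExt` and an already-constructed
// client; the variable names are illustrative):
//     let mut events = client.stream(prompt).await?;
//     while let Some(event) = events.next().await {
//         // each item is a Result<ResponseEvent>
//     }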
@@ -8,6 +8,7 @@ workspace = true

[dependencies]
clap = { workspace = true, features = ["derive", "wrap_help"], optional = true }
codex-api-client = { workspace = true }
codex-core = { workspace = true }
codex-protocol = { workspace = true }
codex-app-server-protocol = { workspace = true }

@@ -1,4 +1,4 @@
use codex_core::WireApi;
use codex_api_client::WireApi;
use codex_core::config::Config;

use crate::sandbox_summary::summarize_sandbox_policy;

@@ -22,6 +22,7 @@ chrono = { workspace = true, features = ["serde"] }
codex-app-server-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
codex-async-utils = { workspace = true }
codex-api-client = { workspace = true }
codex-file-search = { workspace = true }
codex-git = { workspace = true }
codex-keyring-store = { workspace = true }

@@ -22,7 +22,6 @@ use crate::auth::storage::AuthStorageBackend;
use crate::auth::storage::create_auth_storage;
use crate::config::Config;
use crate::default_client::CodexHttpClient;
use crate::token_data::PlanType;
use crate::token_data::TokenData;
use crate::token_data::parse_id_token;
use crate::util::try_parse_error_message;
@@ -153,11 +152,6 @@ impl CodexAuth {
        self.get_current_token_data().and_then(|t| t.id_token.email)
    }

    pub(crate) fn get_plan_type(&self) -> Option<PlanType> {
        self.get_current_token_data()
            .and_then(|t| t.id_token.chatgpt_plan_type)
    }

    fn get_current_auth_json(&self) -> Option<AuthDotJson> {
        #[expect(clippy::unwrap_used)]
        self.auth_dot_json.lock().unwrap().clone()
@@ -1,967 +0,0 @@
use std::time::Duration;

use crate::ModelProviderInfo;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::client_common::ResponseStream;
use crate::default_client::CodexHttpClient;
use crate::error::CodexErr;
use crate::error::ConnectionFailedError;
use crate::error::ResponseStreamFailed;
use crate::error::Result;
use crate::error::RetryLimitReachedError;
use crate::error::UnexpectedResponseError;
use crate::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use bytes::Bytes;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::models::ContentItem;
use codex_protocol::models::FunctionCallOutputContentItem;
use codex_protocol::models::ReasoningItemContent;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use eventsource_stream::Eventsource;
use futures::Stream;
use futures::StreamExt;
use futures::TryStreamExt;
use reqwest::StatusCode;
use serde_json::json;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;
use tokio::time::timeout;
use tracing::debug;
use tracing::trace;

/// Implementation for the classic Chat Completions API.
pub(crate) async fn stream_chat_completions(
    prompt: &Prompt,
    model_family: &ModelFamily,
    client: &CodexHttpClient,
    provider: &ModelProviderInfo,
    otel_event_manager: &OtelEventManager,
    session_source: &SessionSource,
) -> Result<ResponseStream> {
    if prompt.output_schema.is_some() {
        return Err(CodexErr::UnsupportedOperation(
            "output_schema is not supported for Chat Completions API".to_string(),
        ));
    }
    // Build messages array
    let mut messages = Vec::<serde_json::Value>::new();

    let full_instructions = prompt.get_full_instructions(model_family);
    messages.push(json!({"role": "system", "content": full_instructions}));

    let input = prompt.get_formatted_input();

    // Pre-scan: map Reasoning blocks to the adjacent assistant anchor after the last user.
    // - If the last emitted message is a user message, drop all reasoning.
    // - Otherwise, for each Reasoning item after the last user message, attach it
    //   to the immediate previous assistant message (stop turns) or the immediate
    //   next assistant anchor (tool-call turns: function/local shell call, or assistant message).
    let mut reasoning_by_anchor_index: std::collections::HashMap<usize, String> =
        std::collections::HashMap::new();

    // Determine the last role that would be emitted to Chat Completions.
    let mut last_emitted_role: Option<&str> = None;
    for item in &input {
        match item {
            ResponseItem::Message { role, .. } => last_emitted_role = Some(role.as_str()),
            ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                last_emitted_role = Some("assistant")
            }
            ResponseItem::FunctionCallOutput { .. } => last_emitted_role = Some("tool"),
            ResponseItem::Reasoning { .. } | ResponseItem::Other => {}
            ResponseItem::CustomToolCall { .. } => {}
            ResponseItem::CustomToolCallOutput { .. } => {}
            ResponseItem::WebSearchCall { .. } => {}
            ResponseItem::GhostSnapshot { .. } => {}
        }
    }

    // Find the last user message index in the input.
    let mut last_user_index: Option<usize> = None;
    for (idx, item) in input.iter().enumerate() {
        if let ResponseItem::Message { role, .. } = item
            && role == "user"
        {
            last_user_index = Some(idx);
        }
    }

    // Attach reasoning only if the conversation does not end with a user message.
    if !matches!(last_emitted_role, Some("user")) {
        for (idx, item) in input.iter().enumerate() {
            // Only consider reasoning that appears after the last user message.
            if let Some(u_idx) = last_user_index
                && idx <= u_idx
            {
                continue;
            }

            if let ResponseItem::Reasoning {
                content: Some(items),
                ..
            } = item
            {
                let mut text = String::new();
                for entry in items {
                    match entry {
                        ReasoningItemContent::ReasoningText { text: segment }
                        | ReasoningItemContent::Text { text: segment } => text.push_str(segment),
                    }
                }
                if text.trim().is_empty() {
                    continue;
                }

                // Prefer immediate previous assistant message (stop turns)
                let mut attached = false;
                if idx > 0
                    && let ResponseItem::Message { role, .. } = &input[idx - 1]
                    && role == "assistant"
                {
                    reasoning_by_anchor_index
                        .entry(idx - 1)
                        .and_modify(|v| v.push_str(&text))
                        .or_insert(text.clone());
                    attached = true;
                }

                // Otherwise, attach to immediate next assistant anchor (tool-calls or assistant message)
                if !attached && idx + 1 < input.len() {
                    match &input[idx + 1] {
                        ResponseItem::FunctionCall { .. } | ResponseItem::LocalShellCall { .. } => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        ResponseItem::Message { role, .. } if role == "assistant" => {
                            reasoning_by_anchor_index
                                .entry(idx + 1)
                                .and_modify(|v| v.push_str(&text))
                                .or_insert(text.clone());
                        }
                        _ => {}
                    }
                }
            }
        }
    }
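    // Example: in [user, reasoning, function_call] the reasoning anchors to the
    // function_call that follows it; in [user, assistant, reasoning] it anchors
    // to the assistant message that precedes it.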
    // Track last assistant text we emitted to avoid duplicate assistant messages
    // in the outbound Chat Completions payload (can happen if a final
    // aggregated assistant message was recorded alongside an earlier partial).
    let mut last_assistant_text: Option<String> = None;

    for (idx, item) in input.iter().enumerate() {
        match item {
            ResponseItem::Message { role, content, .. } => {
                // Build content either as a plain string (typical for assistant text)
                // or as an array of content items when images are present (user/tool multimodal).
                let mut text = String::new();
                let mut items: Vec<serde_json::Value> = Vec::new();
                let mut saw_image = false;

                for c in content {
                    match c {
                        ContentItem::InputText { text: t }
                        | ContentItem::OutputText { text: t } => {
                            text.push_str(t);
                            items.push(json!({"type":"text","text": t}));
                        }
                        ContentItem::InputImage { image_url } => {
                            saw_image = true;
                            items.push(json!({"type":"image_url","image_url": {"url": image_url}}));
                        }
                    }
                }

                // Skip exact-duplicate assistant messages.
                if role == "assistant" {
                    if let Some(prev) = &last_assistant_text
                        && prev == &text
                    {
                        continue;
                    }
                    last_assistant_text = Some(text.clone());
                }

                // For assistant messages, always send a plain string for compatibility.
                // For user messages, if an image is present, send an array of content items.
                let content_value = if role == "assistant" {
                    json!(text)
                } else if saw_image {
                    json!(items)
                } else {
                    json!(text)
                };

                let mut msg = json!({"role": role, "content": content_value});
                if role == "assistant"
                    && let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::FunctionCall {
                name,
                arguments,
                call_id,
                ..
            } => {
                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": call_id,
                        "type": "function",
                        "function": {
                            "name": name,
                            "arguments": arguments,
                        }
                    }]
                });
                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::LocalShellCall {
                id,
                call_id: _,
                status,
                action,
            } => {
                // Confirm with API team.
                let mut msg = json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": id.clone().unwrap_or_else(|| "".to_string()),
                        "type": "local_shell_call",
                        "status": status,
                        "action": action,
                    }]
                });
                if let Some(reasoning) = reasoning_by_anchor_index.get(&idx)
                    && let Some(obj) = msg.as_object_mut()
                {
                    obj.insert("reasoning".to_string(), json!(reasoning));
                }
                messages.push(msg);
            }
            ResponseItem::FunctionCallOutput { call_id, output } => {
                // Prefer structured content items when available (e.g., images)
                // otherwise fall back to the legacy plain-string content.
                let content_value = if let Some(items) = &output.content_items {
                    let mapped: Vec<serde_json::Value> = items
                        .iter()
                        .map(|it| match it {
                            FunctionCallOutputContentItem::InputText { text } => {
                                json!({"type":"text","text": text})
                            }
                            FunctionCallOutputContentItem::InputImage { image_url } => {
                                json!({"type":"image_url","image_url": {"url": image_url}})
                            }
                        })
                        .collect();
                    json!(mapped)
                } else {
                    json!(output.content)
                };

                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": content_value,
                }));
            }
            ResponseItem::CustomToolCall {
                id,
                call_id: _,
                name,
                input,
                status: _,
            } => {
                messages.push(json!({
                    "role": "assistant",
                    "content": null,
                    "tool_calls": [{
                        "id": id,
                        "type": "custom",
                        "custom": {
                            "name": name,
                            "input": input,
                        }
                    }]
                }));
            }
            ResponseItem::CustomToolCallOutput { call_id, output } => {
                messages.push(json!({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": output,
                }));
            }
            ResponseItem::GhostSnapshot { .. } => {
                // Ghost snapshots annotate history but are not sent to the model.
                continue;
            }
            ResponseItem::Reasoning { .. }
            | ResponseItem::WebSearchCall { .. }
            | ResponseItem::Other => {
                // Omit these items from the conversation history.
                continue;
            }
        }
    }

    let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
    let payload = json!({
        "model": model_family.slug,
        "messages": messages,
        "stream": true,
        "tools": tools_json,
    });

    debug!(
        "POST to {}: {}",
        provider.get_full_url(&None),
        serde_json::to_string_pretty(&payload).unwrap_or_default()
    );

    let mut attempt = 0;
    let max_retries = provider.request_max_retries();
    loop {
        attempt += 1;

        let mut req_builder = provider.create_request_builder(client, &None).await?;

        // Include subagent header only for subagent sessions.
        if let SessionSource::SubAgent(sub) = session_source.clone() {
            let subagent = if let SubAgentSource::Other(label) = sub {
                label
            } else {
                serde_json::to_value(&sub)
                    .ok()
                    .and_then(|v| v.as_str().map(std::string::ToString::to_string))
                    .unwrap_or_else(|| "other".to_string())
            };
            req_builder = req_builder.header("x-openai-subagent", subagent);
        }

        let res = otel_event_manager
            .log_request(attempt, || {
                req_builder
                    .header(reqwest::header::ACCEPT, "text/event-stream")
                    .json(&payload)
                    .send()
            })
            .await;

        match res {
            Ok(resp) if resp.status().is_success() => {
                let (tx_event, rx_event) = mpsc::channel::<Result<ResponseEvent>>(1600);
                let stream = resp.bytes_stream().map_err(|e| {
                    CodexErr::ResponseStreamFailed(ResponseStreamFailed {
                        source: e,
                        request_id: None,
                    })
                });
                tokio::spawn(process_chat_sse(
                    stream,
                    tx_event,
                    provider.stream_idle_timeout(),
                    otel_event_manager.clone(),
                ));
                return Ok(ResponseStream { rx_event });
            }
            Ok(res) => {
                let status = res.status();
                if !(status == StatusCode::TOO_MANY_REQUESTS || status.is_server_error()) {
                    let body = (res.text().await).unwrap_or_default();
                    return Err(CodexErr::UnexpectedStatus(UnexpectedResponseError {
                        status,
                        body,
                        request_id: None,
                    }));
                }

                if attempt > max_retries {
                    return Err(CodexErr::RetryLimit(RetryLimitReachedError {
                        status,
                        request_id: None,
                    }));
                }

                let retry_after_secs = res
                    .headers()
                    .get(reqwest::header::RETRY_AFTER)
                    .and_then(|v| v.to_str().ok())
                    .and_then(|s| s.parse::<u64>().ok());

                let delay = retry_after_secs
                    .map(|s| Duration::from_millis(s * 1_000))
                    .unwrap_or_else(|| backoff(attempt));
                tokio::time::sleep(delay).await;
            }
            Err(e) => {
                if attempt > max_retries {
                    return Err(CodexErr::ConnectionFailed(ConnectionFailedError {
                        source: e,
                    }));
                }
                let delay = backoff(attempt);
                tokio::time::sleep(delay).await;
            }
        }
    }
}
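/// Lazily opens the assistant `Message` item on the first text delta (emitting
/// `OutputItemAdded`), then appends each chunk and mirrors it downstream as an
/// `OutputTextDelta` event. `append_reasoning_text` below does the same for
/// raw reasoning content.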
async fn append_assistant_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    assistant_item: &mut Option<ResponseItem>,
    text: String,
) {
    if assistant_item.is_none() {
        let item = ResponseItem::Message {
            id: None,
            role: "assistant".to_string(),
            content: vec![],
        };
        *assistant_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Message { content, .. }) = assistant_item {
        content.push(ContentItem::OutputText { text: text.clone() });
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputTextDelta(text.clone())))
            .await;
    }
}

async fn append_reasoning_text(
    tx_event: &mpsc::Sender<Result<ResponseEvent>>,
    reasoning_item: &mut Option<ResponseItem>,
    text: String,
) {
    if reasoning_item.is_none() {
        let item = ResponseItem::Reasoning {
            id: String::new(),
            summary: Vec::new(),
            content: Some(vec![]),
            encrypted_content: None,
        };
        *reasoning_item = Some(item.clone());
        let _ = tx_event
            .send(Ok(ResponseEvent::OutputItemAdded(item)))
            .await;
    }

    if let Some(ResponseItem::Reasoning {
        content: Some(content),
        ..
    }) = reasoning_item
    {
        content.push(ReasoningItemContent::ReasoningText { text: text.clone() });

        let _ = tx_event
            .send(Ok(ResponseEvent::ReasoningContentDelta(text.clone())))
            .await;
    }
}
/// Lightweight SSE processor for the Chat Completions streaming format. The
/// output is mapped onto Codex's internal [`ResponseEvent`] so that the rest
/// of the pipeline can stay agnostic of the underlying wire format.
async fn process_chat_sse<S>(
    stream: S,
    tx_event: mpsc::Sender<Result<ResponseEvent>>,
    idle_timeout: Duration,
    otel_event_manager: OtelEventManager,
) where
    S: Stream<Item = Result<Bytes>> + Unpin,
{
    let mut stream = stream.eventsource();

    // State to accumulate a function call across streaming chunks.
    // OpenAI may split the `arguments` string over multiple `delta` events
    // until the chunk whose `finish_reason` is `tool_calls` is emitted. We
    // keep collecting the pieces here and forward a single
    // `ResponseItem::FunctionCall` once the call is complete.
    #[derive(Default)]
    struct FunctionCallState {
        name: Option<String>,
        arguments: String,
        call_id: Option<String>,
        active: bool,
    }

    let mut fn_call_state = FunctionCallState::default();
    let mut assistant_item: Option<ResponseItem> = None;
    let mut reasoning_item: Option<ResponseItem> = None;

    loop {
        let start = std::time::Instant::now();
        let response = timeout(idle_timeout, stream.next()).await;
        let duration = start.elapsed();
        otel_event_manager.log_sse_event(&response, duration);

        let sse = match response {
            Ok(Some(Ok(ev))) => ev,
            Ok(Some(Err(e))) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(e.to_string(), None)))
                    .await;
                return;
            }
            Ok(None) => {
                // Stream closed gracefully – emit Completed with dummy id.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;
                return;
            }
            Err(_) => {
                let _ = tx_event
                    .send(Err(CodexErr::Stream(
                        "idle timeout waiting for SSE".into(),
                        None,
                    )))
                    .await;
                return;
            }
        };

        // OpenAI Chat streaming sends a literal string "[DONE]" when finished.
        if sse.data.trim() == "[DONE]" {
            // Emit any finalized items before closing so downstream consumers receive
            // terminal events for both assistant content and raw reasoning.
            if let Some(item) = assistant_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            if let Some(item) = reasoning_item {
                let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
            }

            let _ = tx_event
                .send(Ok(ResponseEvent::Completed {
                    response_id: String::new(),
                    token_usage: None,
                }))
                .await;
            return;
        }

        // Parse JSON chunk
        let chunk: serde_json::Value = match serde_json::from_str(&sse.data) {
            Ok(v) => v,
            Err(_) => continue,
        };
        trace!("chat_completions received SSE chunk: {chunk:?}");

        let choice_opt = chunk.get("choices").and_then(|c| c.get(0));

        if let Some(choice) = choice_opt {
            // Handle assistant content tokens as streaming deltas.
            if let Some(content) = choice
                .get("delta")
                .and_then(|d| d.get("content"))
                .and_then(|c| c.as_str())
                && !content.is_empty()
            {
                append_assistant_text(&tx_event, &mut assistant_item, content.to_string()).await;
            }

            // Forward any reasoning/thinking deltas if present.
            // Some providers stream `reasoning` as a plain string while others
            // nest the text under an object (e.g. `{ "reasoning": { "text": "…" } }`).
            if let Some(reasoning_val) = choice.get("delta").and_then(|d| d.get("reasoning")) {
                let mut maybe_text = reasoning_val
                    .as_str()
                    .map(str::to_string)
                    .filter(|s| !s.is_empty());

                if maybe_text.is_none() && reasoning_val.is_object() {
                    if let Some(s) = reasoning_val
                        .get("text")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    } else if let Some(s) = reasoning_val
                        .get("content")
                        .and_then(|t| t.as_str())
                        .filter(|s| !s.is_empty())
                    {
                        maybe_text = Some(s.to_string());
                    }
                }

                if let Some(reasoning) = maybe_text {
                    // Accumulate so we can emit a terminal Reasoning item at the end.
                    append_reasoning_text(&tx_event, &mut reasoning_item, reasoning).await;
                }
            }

            // Some providers only include reasoning on the final message object.
            if let Some(message_reasoning) = choice.get("message").and_then(|m| m.get("reasoning"))
            {
                // Accept either a plain string or an object with { text | content }
                if let Some(s) = message_reasoning.as_str() {
                    if !s.is_empty() {
                        append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                    }
                } else if let Some(obj) = message_reasoning.as_object()
                    && let Some(s) = obj
                        .get("text")
                        .and_then(|v| v.as_str())
                        .or_else(|| obj.get("content").and_then(|v| v.as_str()))
                    && !s.is_empty()
                {
                    append_reasoning_text(&tx_event, &mut reasoning_item, s.to_string()).await;
                }
            }

            // Handle streaming function / tool calls.
            if let Some(tool_calls) = choice
                .get("delta")
                .and_then(|d| d.get("tool_calls"))
                .and_then(|tc| tc.as_array())
                && let Some(tool_call) = tool_calls.first()
            {
                // Mark that we have an active function call in progress.
                fn_call_state.active = true;

                // Extract call_id if present.
                if let Some(id) = tool_call.get("id").and_then(|v| v.as_str()) {
                    fn_call_state.call_id.get_or_insert_with(|| id.to_string());
                }

                // Extract function details if present.
                if let Some(function) = tool_call.get("function") {
                    if let Some(name) = function.get("name").and_then(|n| n.as_str()) {
                        fn_call_state.name.get_or_insert_with(|| name.to_string());
                    }

                    if let Some(args_fragment) = function.get("arguments").and_then(|a| a.as_str())
                    {
                        fn_call_state.arguments.push_str(args_fragment);
                    }
                }
            }

            // Emit end-of-turn when finish_reason signals completion.
            if let Some(finish_reason) = choice.get("finish_reason").and_then(|v| v.as_str()) {
                match finish_reason {
                    "tool_calls" if fn_call_state.active => {
                        // First, flush the terminal raw reasoning so UIs can finalize
                        // the reasoning stream before any exec/tool events begin.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }

                        // Then emit the FunctionCall response item.
                        let item = ResponseItem::FunctionCall {
                            id: None,
                            name: fn_call_state.name.clone().unwrap_or_else(|| "".to_string()),
                            arguments: fn_call_state.arguments.clone(),
                            call_id: fn_call_state.call_id.clone().unwrap_or_else(String::new),
                        };

                        let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                    }
                    "stop" => {
                        // Regular turn without tool-call. Emit the final assistant message
                        // as a single OutputItemDone so non-delta consumers see the result.
                        if let Some(item) = assistant_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                        // Also emit a terminal Reasoning item so UIs can finalize raw reasoning.
                        if let Some(item) = reasoning_item.take() {
                            let _ = tx_event.send(Ok(ResponseEvent::OutputItemDone(item))).await;
                        }
                    }
                    _ => {}
                }

                // Emit Completed regardless of reason so the agent can advance.
                let _ = tx_event
                    .send(Ok(ResponseEvent::Completed {
                        response_id: String::new(),
                        token_usage: None,
                    }))
                    .await;

                // Prepare for potential next turn (should not happen in same stream).
                // fn_call_state = FunctionCallState::default();

                return; // End processing for this SSE stream.
            }
        }
    }
}
/// Optional client-side aggregation helper
///
/// Stream adapter that merges the incremental `OutputItemDone` chunks coming from
/// [`process_chat_sse`] into a *running* assistant message, **suppressing the
/// per-token deltas**. The stream stays silent while the model is thinking
/// and only emits two events per turn:
///
/// 1. `ResponseEvent::OutputItemDone` with the *complete* assistant message
///    (fully concatenated).
/// 2. The original `ResponseEvent::Completed` right after it.
///
/// This mirrors the behaviour the TypeScript CLI exposes to its higher layers.
///
/// The adapter is intentionally *lossless*: callers who do **not** opt in via
/// [`AggregateStreamExt::aggregate()`] keep receiving the original unmodified
/// events.
#[derive(Copy, Clone, Eq, PartialEq)]
enum AggregateMode {
    AggregatedOnly,
    Streaming,
}
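// AggregatedOnly suppresses per-token deltas and re-emits a single final item
// per turn; Streaming forwards deltas as they arrive and still finalizes an
// aggregated `OutputItemDone` when `Completed` is observed.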
pub(crate) struct AggregatedChatStream<S> {
    inner: S,
    cumulative: String,
    cumulative_reasoning: String,
    pending: std::collections::VecDeque<ResponseEvent>,
    mode: AggregateMode,
}

impl<S> Stream for AggregatedChatStream<S>
where
    S: Stream<Item = Result<ResponseEvent>> + Unpin,
{
    type Item = Result<ResponseEvent>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // First, flush any buffered events from the previous call.
        if let Some(ev) = this.pending.pop_front() {
            return Poll::Ready(Some(Ok(ev)));
        }

        loop {
            match Pin::new(&mut this.inner).poll_next(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item)))) => {
                    // If this is an incremental assistant message chunk, accumulate but
                    // do NOT emit yet. Forward any other item (e.g. FunctionCall) right
                    // away so downstream consumers see it.

                    let is_assistant_message = matches!(
                        &item,
                        codex_protocol::models::ResponseItem::Message { role, .. } if role == "assistant"
                    );

                    if is_assistant_message {
                        match this.mode {
                            AggregateMode::AggregatedOnly => {
                                // Only use the final assistant message if we have not
                                // seen any deltas; otherwise, deltas already built the
                                // cumulative text and this would duplicate it.
                                if this.cumulative.is_empty()
                                    && let codex_protocol::models::ResponseItem::Message {
                                        content,
                                        ..
                                    } = &item
                                    && let Some(text) = content.iter().find_map(|c| match c {
                                        codex_protocol::models::ContentItem::OutputText {
                                            text,
                                        } => Some(text),
                                        _ => None,
                                    })
                                {
                                    this.cumulative.push_str(text);
                                }
                                // Swallow assistant message here; emit on Completed.
                                continue;
                            }
                            AggregateMode::Streaming => {
                                // In streaming mode, if we have not seen any deltas, forward
                                // the final assistant message directly. If deltas were seen,
                                // suppress the final message to avoid duplication.
                                if this.cumulative.is_empty() {
                                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(
                                        item,
                                    ))));
                                } else {
                                    continue;
                                }
                            }
                        }
                    }

                    // Not an assistant message – forward immediately.
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemDone(item))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::RateLimits(snapshot))));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Completed {
                    response_id,
                    token_usage,
                }))) => {
                    // Build any aggregated items in the correct order: Reasoning first, then Message.
                    let mut emitted_any = false;

                    if !this.cumulative_reasoning.is_empty()
                        && matches!(this.mode, AggregateMode::AggregatedOnly)
                    {
                        let aggregated_reasoning =
                            codex_protocol::models::ResponseItem::Reasoning {
                                id: String::new(),
                                summary: Vec::new(),
                                content: Some(vec![
                                    codex_protocol::models::ReasoningItemContent::ReasoningText {
                                        text: std::mem::take(&mut this.cumulative_reasoning),
                                    },
                                ]),
                                encrypted_content: None,
                            };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_reasoning));
                        emitted_any = true;
                    }

                    // Always emit the final aggregated assistant message when any
                    // content deltas have been observed. In AggregatedOnly mode this
                    // is the sole assistant output; in Streaming mode this finalizes
                    // the streamed deltas into a terminal OutputItemDone so callers
                    // can persist/render the message once per turn.
                    if !this.cumulative.is_empty() {
                        let aggregated_message = codex_protocol::models::ResponseItem::Message {
                            id: None,
                            role: "assistant".to_string(),
                            content: vec![codex_protocol::models::ContentItem::OutputText {
                                text: std::mem::take(&mut this.cumulative),
                            }],
                        };
                        this.pending
                            .push_back(ResponseEvent::OutputItemDone(aggregated_message));
                        emitted_any = true;
                    }

                    // Always emit Completed last when anything was aggregated.
                    if emitted_any {
                        this.pending.push_back(ResponseEvent::Completed {
                            response_id: response_id.clone(),
                            token_usage: token_usage.clone(),
                        });
                        // Return the first pending event now.
                        if let Some(ev) = this.pending.pop_front() {
                            return Poll::Ready(Some(Ok(ev)));
                        }
                    }

                    // Nothing aggregated – forward Completed directly.
                    return Poll::Ready(Some(Ok(ResponseEvent::Completed {
                        response_id,
                        token_usage,
                    })));
                }
                Poll::Ready(Some(Ok(ResponseEvent::Created))) => {
                    // These events are exclusive to the Responses API and
                    // will never appear in a Chat Completions stream.
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta)))) => {
                    // Always accumulate deltas so we can emit a final OutputItemDone at Completed.
                    this.cumulative.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::OutputTextDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta)))) => {
                    // Always accumulate reasoning deltas so we can emit a final Reasoning item at Completed.
                    this.cumulative_reasoning.push_str(&delta);
                    if matches!(this.mode, AggregateMode::Streaming) {
                        // In streaming mode, also forward the delta immediately.
                        return Poll::Ready(Some(Ok(ResponseEvent::ReasoningContentDelta(delta))));
                    } else {
                        continue;
                    }
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryDelta(_)))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::ReasoningSummaryPartAdded))) => {
                    continue;
                }
                Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item)))) => {
                    return Poll::Ready(Some(Ok(ResponseEvent::OutputItemAdded(item))));
                }
            }
        }
    }
}

/// Extension trait that activates aggregation on any stream of [`ResponseEvent`].
pub(crate) trait AggregateStreamExt: Stream<Item = Result<ResponseEvent>> + Sized {
    /// Returns a new stream that emits **only** the final assistant message
    /// per turn instead of every incremental delta. The produced
    /// `ResponseEvent` sequence for a typical text turn looks like:
    ///
    /// ```ignore
    /// OutputItemDone(<full message>)
    /// Completed
    /// ```
    ///
    /// No other `OutputItemDone` events will be seen by the caller.
    ///
    /// Usage:
    ///
    /// ```ignore
    /// let agg_stream = client.stream(&prompt).await?.aggregate();
    /// while let Some(event) = agg_stream.next().await {
    ///     // event now contains cumulative text
    /// }
    /// ```
    fn aggregate(self) -> AggregatedChatStream<Self> {
        AggregatedChatStream::new(self, AggregateMode::AggregatedOnly)
    }
}

impl<T> AggregateStreamExt for T where T: Stream<Item = Result<ResponseEvent>> + Sized {}

impl<S> AggregatedChatStream<S> {
    fn new(inner: S, mode: AggregateMode) -> Self {
        AggregatedChatStream {
            inner,
            cumulative: String::new(),
            cumulative_reasoning: String::new(),
            pending: std::collections::VecDeque::new(),
            mode,
        }
    }

    pub(crate) fn streaming_mode(inner: S) -> Self {
        Self::new(inner, AggregateMode::Streaming)
    }
}
File diff suppressed because it is too large
@@ -1,348 +1,45 @@
use crate::client_common::tools::ToolSpec;
use std::borrow::Cow;
use std::ops::Deref;

use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use codex_api_client::EventStream;
pub use codex_api_client::Prompt;
pub use codex_api_client::Reasoning;
pub use codex_api_client::TextControls;
pub use codex_api_client::TextFormat;
pub use codex_api_client::TextFormatType;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::config_types::Verbosity as VerbosityConfig;
use codex_protocol::models::ResponseItem;
use futures::Stream;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::borrow::Cow;
use std::collections::HashSet;
use std::ops::Deref;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use tokio::sync::mpsc;

use crate::model_family::ModelFamily;

/// Review thread system prompt. Edit `core/src/review_prompt.md` to customize.
pub const REVIEW_PROMPT: &str = include_str!("../review_prompt.md");

// Centralized templates for review-related user messages
pub const REVIEW_EXIT_SUCCESS_TMPL: &str = include_str!("../templates/review/exit_success.xml");
pub const REVIEW_EXIT_INTERRUPTED_TMPL: &str =
    include_str!("../templates/review/exit_interrupted.xml");

/// API request payload for a single model turn
#[derive(Default, Debug, Clone)]
pub struct Prompt {
    /// Conversation context input items.
    pub input: Vec<ResponseItem>,

    /// Tools available to the model, including additional tools sourced from
    /// external MCP servers.
    pub(crate) tools: Vec<ToolSpec>,

    /// Whether parallel tool calls are permitted for this prompt.
    pub(crate) parallel_tool_calls: bool,

    /// Optional override for the built-in BASE_INSTRUCTIONS.
    pub base_instructions_override: Option<String>,

    /// Optional output schema for the model's response.
    pub output_schema: Option<Value>,
}

impl Prompt {
    pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
        let base = self
            .base_instructions_override
            .as_deref()
            .unwrap_or(model.base_instructions.deref());
        // When there are no custom instructions, add apply_patch_tool_instructions if:
        // - the model needs special instructions (4.1)
        // AND
        // - there is no apply_patch tool present
        let is_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
            ToolSpec::Function(f) => f.name == "apply_patch",
            ToolSpec::Freeform(f) => f.name == "apply_patch",
            _ => false,
        });
        if self.base_instructions_override.is_none()
            && model.needs_special_apply_patch_instructions
            && !is_apply_patch_tool_present
        {
            Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
        } else {
            Cow::Borrowed(base)
        }
    }

    pub(crate) fn get_formatted_input(&self) -> Vec<ResponseItem> {
        let mut input = self.input.clone();

        // when using the *Freeform* apply_patch tool specifically, tool outputs
        // should be structured text, not json. Do NOT reserialize when using
        // the Function tool - note that this differs from the check above for
        // instructions. We declare the result as a named variable for clarity.
        let is_freeform_apply_patch_tool_present = self.tools.iter().any(|tool| match tool {
            ToolSpec::Freeform(f) => f.name == "apply_patch",
            _ => false,
        });
        if is_freeform_apply_patch_tool_present {
            reserialize_shell_outputs(&mut input);
        }

        input
pub fn compute_full_instructions<'a>(
    base_override: Option<&'a str>,
    model: &'a ModelFamily,
    is_apply_patch_present: bool,
) -> Cow<'a, str> {
    let base = base_override.unwrap_or(model.base_instructions.deref());
    if base_override.is_none()
        && model.needs_special_apply_patch_instructions
        && !is_apply_patch_present
    {
        Cow::Owned(format!("{base}\n{APPLY_PATCH_TOOL_INSTRUCTIONS}"))
    } else {
        Cow::Borrowed(base)
    }
}

fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
    let mut shell_call_ids: HashSet<String> = HashSet::new();

    items.iter_mut().for_each(|item| match item {
        ResponseItem::LocalShellCall { call_id, id, .. } => {
            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
                shell_call_ids.insert(identifier);
            }
        }
        ResponseItem::CustomToolCall {
            id: _,
            status: _,
            call_id,
            name,
            input: _,
        } => {
            if name == "apply_patch" {
                shell_call_ids.insert(call_id.clone());
            }
        }
        ResponseItem::CustomToolCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(output)
            {
                *output = structured
            }
        }
        ResponseItem::FunctionCall { name, call_id, .. }
            if is_shell_tool_name(name) || name == "apply_patch" =>
        {
            shell_call_ids.insert(call_id.clone());
        }
        ResponseItem::FunctionCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(&output.content)
            {
                output.content = structured
            }
        }
        _ => {}
    })
}

fn is_shell_tool_name(name: &str) -> bool {
    matches!(name, "shell" | "container.exec")
}

#[derive(Deserialize)]
struct ExecOutputJson {
    output: String,
    metadata: ExecOutputMetadataJson,
}

#[derive(Deserialize)]
struct ExecOutputMetadataJson {
    exit_code: i32,
    duration_seconds: f32,
}

fn parse_structured_shell_output(raw: &str) -> Option<String> {
    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
    Some(build_structured_output(&parsed))
}

fn build_structured_output(parsed: &ExecOutputJson) -> String {
    let mut sections = Vec::new();
    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
    sections.push(format!(
        "Wall time: {} seconds",
        parsed.metadata.duration_seconds
    ));

    let mut output = parsed.output.clone();
    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
        sections.push(format!("Total output lines: {total_lines}"));
        if let Some(stripped) = strip_total_output_header(&output) {
            output = stripped.to_string();
        }
    }

    sections.push("Output:".to_string());
    sections.push(output);

    sections.join("\n")
}

fn extract_total_output_lines(output: &str) -> Option<u32> {
    let marker_start = output.find("[... omitted ")?;
    let marker = &output[marker_start..];
    let (_, after_of) = marker.split_once(" of ")?;
    let (total_segment, _) = after_of.split_once(' ')?;
    total_segment.parse::<u32>().ok()
}

fn strip_total_output_header(output: &str) -> Option<&str> {
    let after_prefix = output.strip_prefix("Total output lines: ")?;
    let (_, remainder) = after_prefix.split_once('\n')?;
    let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
    Some(remainder)
}
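As a self-contained sketch of the reserialization these helpers perform: the struct shapes mirror the `ExecOutputJson` definitions in this hunk, while the `raw` payload and the `main` wrapper are invented for illustration (only `serde` and `serde_json` are needed to run it):

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct ExecOutputJson {
    output: String,
    metadata: ExecOutputMetadataJson,
}

#[derive(Deserialize)]
struct ExecOutputMetadataJson {
    exit_code: i32,
    duration_seconds: f32,
}

fn main() {
    // Hypothetical JSON as a shell tool call would record it.
    let raw = r#"{"output":"ok\n","metadata":{"exit_code":0,"duration_seconds":0.25}}"#;
    let parsed: ExecOutputJson = serde_json::from_str(raw).expect("valid exec output JSON");
    // Rebuild the structured text that the freeform apply_patch path expects,
    // in the same section order as build_structured_output above.
    let structured = format!(
        "Exit code: {}\nWall time: {} seconds\nOutput:\n{}",
        parsed.metadata.exit_code, parsed.metadata.duration_seconds, parsed.output
    );
    assert!(structured.starts_with("Exit code: 0"));
    println!("{structured}");
}
```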

#[derive(Debug)]
pub enum ResponseEvent {
    Created,
    OutputItemDone(ResponseItem),
    OutputItemAdded(ResponseItem),
    Completed {
        response_id: String,
        token_usage: Option<TokenUsage>,
    },
    OutputTextDelta(String),
    ReasoningSummaryDelta(String),
    ReasoningContentDelta(String),
    ReasoningSummaryPartAdded,
    RateLimits(RateLimitSnapshot),
}

#[derive(Debug, Serialize)]
pub(crate) struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) effort: Option<ReasoningEffortConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) summary: Option<ReasoningSummaryConfig>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "snake_case")]
pub(crate) enum TextFormatType {
    #[default]
    JsonSchema,
}

#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextFormat {
    pub(crate) r#type: TextFormatType,
    pub(crate) strict: bool,
    pub(crate) schema: Value,
    pub(crate) name: String,
}

/// Controls under the `text` field in the Responses API for GPT-5.
#[derive(Debug, Serialize, Default, Clone)]
pub(crate) struct TextControls {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) verbosity: Option<OpenAiVerbosity>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) format: Option<TextFormat>,
}

#[derive(Debug, Serialize, Default, Clone)]
#[serde(rename_all = "lowercase")]
pub(crate) enum OpenAiVerbosity {
    Low,
    #[default]
    Medium,
    High,
}

impl From<VerbosityConfig> for OpenAiVerbosity {
    fn from(v: VerbosityConfig) -> Self {
        match v {
            VerbosityConfig::Low => OpenAiVerbosity::Low,
            VerbosityConfig::Medium => OpenAiVerbosity::Medium,
            VerbosityConfig::High => OpenAiVerbosity::High,
        }
    }
}

/// Request object that is serialized as JSON and POST'ed when using the
/// Responses API.
#[derive(Debug, Serialize)]
pub(crate) struct ResponsesApiRequest<'a> {
    pub(crate) model: &'a str,
    pub(crate) instructions: &'a str,
    // TODO(mbolin): ResponseItem::Other should not be serialized. Currently,
    // we code defensively to avoid this case, but perhaps we should use a
    // separate enum for serialization.
    pub(crate) input: &'a Vec<ResponseItem>,
    pub(crate) tools: &'a [serde_json::Value],
    pub(crate) tool_choice: &'static str,
    pub(crate) parallel_tool_calls: bool,
    pub(crate) reasoning: Option<Reasoning>,
    pub(crate) store: bool,
    pub(crate) stream: bool,
    pub(crate) include: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) prompt_cache_key: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub(crate) text: Option<TextControls>,
}

pub(crate) mod tools {
    use crate::tools::spec::JsonSchema;
    use serde::Deserialize;
    use serde::Serialize;

    /// When serialized as JSON, this produces a valid "Tool" in the OpenAI
    /// Responses API.
    #[derive(Debug, Clone, Serialize, PartialEq)]
    #[serde(tag = "type")]
    pub(crate) enum ToolSpec {
        #[serde(rename = "function")]
        Function(ResponsesApiTool),
        #[serde(rename = "local_shell")]
        LocalShell {},
        // TODO: Understand why we get an error on web_search although the API docs say it's supported.
        // https://platform.openai.com/docs/guides/tools-web-search?api-mode=responses#:~:text=%7B%20type%3A%20%22web_search%22%20%7D%2C
        #[serde(rename = "web_search")]
        WebSearch {},
        #[serde(rename = "custom")]
        Freeform(FreeformTool),
    }

    impl ToolSpec {
        pub(crate) fn name(&self) -> &str {
            match self {
                ToolSpec::Function(tool) => tool.name.as_str(),
                ToolSpec::LocalShell {} => "local_shell",
                ToolSpec::WebSearch {} => "web_search",
                ToolSpec::Freeform(tool) => tool.name.as_str(),
            }
        }
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct FreeformTool {
        pub(crate) name: String,
        pub(crate) description: String,
        pub(crate) format: FreeformToolFormat,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct FreeformToolFormat {
        pub(crate) r#type: String,
        pub(crate) syntax: String,
        pub(crate) definition: String,
    }

    #[derive(Debug, Clone, Serialize, PartialEq)]
    pub struct ResponsesApiTool {
        pub(crate) name: String,
        pub(crate) description: String,
        /// TODO: Validation. When strict is set to true, the JSON schema,
        /// `required` and `additional_properties` must be present. All fields in
        /// `properties` must be present in `required`.
        pub(crate) strict: bool,
        pub(crate) parameters: JsonSchema,
    }
}

pub(crate) fn create_reasoning_param_for_request(
pub fn create_reasoning_param_for_request(
    model_family: &ModelFamily,
    effort: Option<ReasoningEffortConfig>,
    summary: ReasoningSummaryConfig,
@@ -357,7 +54,7 @@ pub(crate) fn create_reasoning_param_for_request(
    })
}

pub(crate) fn create_text_param_for_request(
pub fn create_text_param_for_request(
    verbosity: Option<VerbosityConfig>,
    output_schema: &Option<Value>,
) -> Option<TextControls> {
@@ -366,7 +63,11 @@ pub(crate) fn create_text_param_for_request(
    }

    Some(TextControls {
        verbosity: verbosity.map(std::convert::Into::into),
        verbosity: verbosity.map(|v| match v {
            VerbosityConfig::Low => "low".to_string(),
            VerbosityConfig::Medium => "medium".to_string(),
            VerbosityConfig::High => "high".to_string(),
        }),
        format: output_schema.as_ref().map(|schema| TextFormat {
            r#type: TextFormatType::JsonSchema,
            strict: true,
@@ -376,178 +77,54 @@ pub(crate) fn create_text_param_for_request(
    })
}

pub struct ResponseStream {
    pub(crate) rx_event: mpsc::Receiver<Result<ResponseEvent>>,
}
pub use codex_api_client::ResponseEvent;

impl Stream for ResponseStream {
    type Item = Result<ResponseEvent>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.rx_event.poll_recv(cx)
    }
}
pub type ResponseStream = EventStream<Result<ResponseEvent>>;

#[cfg(test)]
mod tests {
    use crate::model_family::find_family_for_model;
    use pretty_assertions::assert_eq;

    use super::*;

    struct InstructionsTestCase {
        pub slug: &'static str,
        pub expects_apply_patch_instructions: bool,
    }
    #[test]
    fn get_full_instructions_no_user_content() {
        let prompt = Prompt {
            ..Default::default()
        };
        let test_cases = vec![
            InstructionsTestCase {
                slug: "gpt-3.5",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-4.1",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-4o",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-5",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "codex-mini-latest",
                expects_apply_patch_instructions: true,
            },
            InstructionsTestCase {
                slug: "gpt-oss:120b",
                expects_apply_patch_instructions: false,
            },
            InstructionsTestCase {
                slug: "gpt-5-codex",
                expects_apply_patch_instructions: false,
            },
        ];
        for test_case in test_cases {
            let model_family = find_family_for_model(test_case.slug).expect("known model slug");
            let expected = if test_case.expects_apply_patch_instructions {
                format!(
                    "{}\n{}",
                    model_family.clone().base_instructions,
                    APPLY_PATCH_TOOL_INSTRUCTIONS
                )
            } else {
                model_family.clone().base_instructions
            };

            let full = prompt.get_full_instructions(&model_family);
            assert_eq!(full, expected);
        }
    }
    use crate::model_family::find_family_for_model;

    #[test]
    fn serializes_text_verbosity_when_set() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: Some(TextControls {
                verbosity: Some(OpenAiVerbosity::Low),
                format: None,
            }),
        };
    fn compute_full_instructions_respects_apply_patch_flag() {
        let model = find_family_for_model("gpt-4.1").expect("model");
        let with_tool = compute_full_instructions(None, &model, true);
        assert_eq!(with_tool.as_ref(), model.base_instructions.deref());

        let v = serde_json::to_value(&req).expect("json");
        assert_eq!(
            v.get("text")
                .and_then(|t| t.get("verbosity"))
                .and_then(|s| s.as_str()),
            Some("low")
        let without_tool = compute_full_instructions(None, &model, false);
        assert!(
            without_tool
                .as_ref()
                .ends_with(APPLY_PATCH_TOOL_INSTRUCTIONS)
        );
    }

    #[test]
    fn serializes_text_schema_with_strict_format() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
    fn create_text_controls_includes_verbosity() {
        let controls = create_text_param_for_request(Some(VerbosityConfig::Low), &None)
            .expect("text controls");
        assert_eq!(controls.verbosity.as_deref(), Some("low"));
        assert!(controls.format.is_none());
    }

    #[test]
    fn create_text_controls_includes_schema() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {
                "answer": {"type": "string"}
            },
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        });
        let text_controls =
        let controls =
            create_text_param_for_request(None, &Some(schema.clone())).expect("text controls");

        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: Some(text_controls),
        };

        let v = serde_json::to_value(&req).expect("json");
        let text = v.get("text").expect("text field");
        assert!(text.get("verbosity").is_none());
        let format = text.get("format").expect("format field");

        assert_eq!(
            format.get("name"),
            Some(&serde_json::Value::String("codex_output_schema".into()))
        );
        assert_eq!(
            format.get("type"),
            Some(&serde_json::Value::String("json_schema".into()))
        );
        assert_eq!(format.get("strict"), Some(&serde_json::Value::Bool(true)));
        assert_eq!(format.get("schema"), Some(&schema));
        let format = controls.format.expect("format");
        assert_eq!(format.name, "codex_output_schema");
        assert!(format.strict);
        assert_eq!(format.schema, schema);
    }

    #[test]
    fn omits_text_when_not_set() {
        let input: Vec<ResponseItem> = vec![];
        let tools: Vec<serde_json::Value> = vec![];
        let req = ResponsesApiRequest {
            model: "gpt-5",
            instructions: "i",
            input: &input,
            tools: &tools,
            tool_choice: "auto",
            parallel_tool_calls: true,
            reasoning: None,
            store: false,
            stream: true,
            include: vec![],
            prompt_cache_key: None,
            text: None,
        };

        let v = serde_json::to_value(&req).expect("json");
        assert!(v.get("text").is_none());
    fn create_text_controls_none_when_no_options() {
        assert!(create_text_param_for_request(None, &None).is_none());
    }
}

@@ -51,7 +51,6 @@ use tracing::error;
use tracing::info;
use tracing::warn;

use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
@@ -64,8 +63,10 @@ use crate::error::CodexErr;
use crate::error::Result as CodexResult;
#[cfg(test)]
use crate::exec::StreamOutput;
use codex_api_client::ModelProviderInfo;
// Removed: legacy executor wiring replaced by ToolOrchestrator flows.
// legacy normalize_exec_result no longer used after orchestrator migration
use crate::conversation_history::ResponsesApiChainState;
use crate::mcp::auth::compute_auth_statuses;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::model_family::find_family_for_model;
@@ -301,7 +302,7 @@ pub(crate) struct SessionConfiguration {
    provider: ModelProviderInfo,

    /// If not specified, server will use its default model.
    model: String,
    pub(crate) model: String,

    model_reasoning_effort: Option<ReasoningEffortConfig>,
    model_reasoning_summary: ReasoningSummaryConfig,
@@ -313,7 +314,7 @@ pub(crate) struct SessionConfiguration {
    user_instructions: Option<String>,

    /// Base instructions override.
    base_instructions: Option<String>,
    pub(crate) base_instructions: Option<String>,

    /// Compact prompt override.
    compact_prompt: Option<String>,
@@ -333,7 +334,7 @@ pub(crate) struct SessionConfiguration {
    cwd: PathBuf,

    /// Set of feature flags for this session
    features: Features,
    pub(crate) features: Features,

    // TODO(pakrym): Remove config from here
    original_config_do_not_use: Arc<Config>,
@@ -586,8 +587,9 @@ impl Session {
            config.active_profile.clone(),
        );

        // Create the mutable state for the Session.
        let state = SessionState::new(session_configuration.clone());
        let model_family = find_family_for_model(&session_configuration.model)
            .unwrap_or_else(|| config.model_family.clone());
        let state = SessionState::new(session_configuration.clone(), model_family);

        let services = SessionServices {
            mcp_connection_manager,
@@ -694,7 +696,6 @@ impl Session {

    pub(crate) async fn update_settings(&self, updates: SessionSettingsUpdate) {
        let mut state = self.state.lock().await;

        state.session_configuration = state.session_configuration.apply(&updates);
    }

@@ -978,6 +979,31 @@ impl Session {
        state.replace_history(items);
    }

    async fn update_responses_api_chain_state(
        &self,
        response_id: Option<String>,
    ) {
        let mut state = self.state.lock().await;

        let Some(response_id) = response_id.filter(|id| !id.is_empty()) else {
            state.reset_responses_api_chain();
            return;
        };

        let mut history = state.clone_history();
        let prompt_items = history.get_history_for_prompt();
        let last_message_id = prompt_items
            .iter()
            .rev()
            .find_map(crate::state::response_item_id)
            .map(ToString::to_string);

        state.set_responses_api_chain(ResponsesApiChainState {
            last_response_id: Some(response_id),
            last_message_id,
        });
    }

    async fn persist_rollout_response_items(&self, items: &[ResponseItem]) {
        let rollout_items: Vec<RolloutItem> = items
            .iter()
@@ -1761,30 +1787,32 @@ pub(crate) async fn run_task(
        .collect::<Vec<ResponseItem>>();

    // Construct the input that we will send to the model.
    let turn_input: Vec<ResponseItem> = {
        sess.record_conversation_items(&turn_context, &pending_input)
            .await;
        sess.clone_history().await.get_history_for_prompt()
    };
    sess.record_conversation_items(&turn_context, &pending_input)
        .await;
    let mut state = sess.state.lock().await;
    let prompt = state.prompt_for_turn();

    let turn_input_messages: Vec<String> = turn_input
        .iter()
        .filter_map(|item| match item {
            ResponseItem::Message { content, .. } => Some(content),
            _ => None,
        })
        .flat_map(|content| {
            content.iter().filter_map(|item| match item {
                ContentItem::OutputText { text } => Some(text.clone()),
    let turn_input_messages: Vec<String> = {
        prompt
            .input
            .iter()
            .filter_map(|item| match item {
                ResponseItem::Message { content, .. } => Some(content),
                _ => None,
            })
            })
        .collect();
            .flat_map(|content| {
                content.iter().filter_map(|item| match item {
                    ContentItem::OutputText { text } => Some(text.clone()),
                    _ => None,
                })
            })
            .collect()
    };
    match run_turn(
        Arc::clone(&sess),
        Arc::clone(&turn_context),
        Arc::clone(&turn_diff_tracker),
        turn_input,
        prompt,
        cancellation_token.child_token(),
    )
    .await
@@ -1870,7 +1898,7 @@ async fn run_turn(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    input: Vec<ResponseItem>,
    mut prompt: Prompt,
    cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
    let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
@@ -1879,27 +1907,39 @@ async fn run_turn(
        Some(mcp_tools),
    ));

    let tool_specs = router.specs();
    let (tools_json, has_freeform_apply_patch) =
        crate::tools::spec::tools_metadata_for_prompt(&tool_specs)?;
    crate::conversation_history::format_prompt_items(&mut prompt.input, has_freeform_apply_patch);

    let apply_patch_present = tool_specs.iter().any(|spec| spec.name() == "apply_patch");

    let instructions = crate::client_common::compute_full_instructions(
        turn_context.base_instructions.as_deref(),
        &turn_context.client.get_model_family(),
        apply_patch_present,
    )
    .into_owned();

    let model_supports_parallel = turn_context
        .client
        .get_model_family()
        .supports_parallel_tool_calls;
    let parallel_tool_calls = model_supports_parallel;
    let prompt = Prompt {
        input,
        tools: router.specs(),
        parallel_tool_calls,
        base_instructions_override: turn_context.base_instructions.clone(),
        output_schema: turn_context.final_output_json_schema.clone(),
    };
    prompt.instructions = instructions.clone();
    prompt.tools = tools_json;
    prompt.parallel_tool_calls = parallel_tool_calls;
    prompt.output_schema = turn_context.final_output_json_schema.clone();

    let mut retries = 0;
    loop {
        let attempt_prompt = prompt.clone();
        match try_run_turn(
            Arc::clone(&router),
            Arc::clone(&sess),
            Arc::clone(&turn_context),
            Arc::clone(&turn_diff_tracker),
            &prompt,
            attempt_prompt,
            cancellation_token.child_token(),
        )
        .await
@@ -1980,7 +2020,7 @@ async fn try_run_turn(
    sess: Arc<Session>,
    turn_context: Arc<TurnContext>,
    turn_diff_tracker: SharedTurnDiffTracker,
    prompt: &Prompt,
    prompt: Prompt,
    cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
    let rollout_item = RolloutItem::TurnContext(TurnContextItem {
@@ -1996,7 +2036,7 @@ async fn try_run_turn(
    let mut stream = turn_context
        .client
        .clone()
        .stream(prompt)
        .stream(&prompt)
        .or_cancel(&cancellation_token)
        .await??;

@@ -2129,7 +2169,7 @@ async fn try_run_turn(
                sess.update_rate_limits(&turn_context, snapshot).await;
            }
            ResponseEvent::Completed {
                response_id: _,
                response_id,
                token_usage,
            } => {
                sess.update_token_usage_info(&turn_context, token_usage.as_ref())
@@ -2139,6 +2179,10 @@ async fn try_run_turn(
                    let mut tracker = turn_diff_tracker.lock().await;
                    tracker.get_unified_diff()
                };
                sess.update_responses_api_chain_state(
                    Some(response_id.clone()),
                )
                .await;
                if let Ok(Some(unified_diff)) = unified_diff {
                    let msg = EventMsg::TurnDiff(TurnDiffEvent { unified_diff });
                    sess.send_event(&turn_context, msg).await;
@@ -2534,7 +2578,9 @@ mod tests {
            session_source: SessionSource::Exec,
        };

        let state = SessionState::new(session_configuration.clone());
        let model_family = find_family_for_model(&session_configuration.model)
            .unwrap_or_else(|| config.model_family.clone());
        let state = SessionState::new(session_configuration.clone(), model_family);

        let services = SessionServices {
            mcp_connection_manager: McpConnectionManager::default(),
@@ -2610,7 +2656,9 @@ mod tests {
            session_source: SessionSource::Exec,
        };

        let state = SessionState::new(session_configuration.clone());
        let model_family = find_family_for_model(&session_configuration.model)
            .unwrap_or_else(|| config.model_family.clone());
        let state = SessionState::new(session_configuration.clone(), model_family);

        let services = SessionServices {
            mcp_connection_manager: McpConnectionManager::default(),

@@ -3,7 +3,7 @@ use std::sync::Arc;
use super::Session;
use super::TurnContext;
use super::get_last_assistant_message_from_turn;
use crate::Prompt;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
@@ -84,11 +84,9 @@ async fn run_compact_task_inner(

    loop {
        let turn_input = history.get_history_for_prompt();
        let prompt = Prompt {
            input: turn_input.clone(),
            ..Default::default()
        };
        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), &prompt).await;
        let turn_input_len = turn_input.len();
        let (prompt, _) = crate::state::build_prompt_from_items(turn_input, None);
        let attempt_result = drain_to_completed(&sess, turn_context.as_ref(), prompt).await;

        match attempt_result {
            Ok(()) => {
@@ -107,7 +105,7 @@ async fn run_compact_task_inner(
                return;
            }
            Err(e @ CodexErr::ContextWindowExceeded) => {
                if turn_input.len() > 1 {
                if turn_input_len > 1 {
                    // Trim from the beginning to preserve cache (prefix-based) and keep recent messages intact.
                    error!(
                        "Context window exceeded while compacting; removing oldest history item. Error: {e}"
@@ -251,9 +249,9 @@ fn build_compacted_history_with_limit(
async fn drain_to_completed(
    sess: &Session,
    turn_context: &TurnContext,
    prompt: &Prompt,
    prompt: Prompt,
) -> CodexResult<()> {
    let mut stream = turn_context.client.clone().stream(prompt).await?;
    let mut stream = turn_context.client.clone().stream(&prompt).await?;
    loop {
        let maybe_event = stream.next().await;
        let Some(event) = maybe_event else {

@@ -25,13 +25,13 @@ use crate::git_info::resolve_root_git_project_for_trust;
use crate::model_family::ModelFamily;
use crate::model_family::derive_default_model_family;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::project_doc::DEFAULT_PROJECT_DOC_FILENAME;
use crate::project_doc::LOCAL_PROJECT_DOC_FILENAME;
use crate::protocol::AskForApproval;
use crate::protocol::SandboxPolicy;
use codex_api_client::ModelProviderInfo;
use codex_api_client::built_in_model_providers;
use codex_app_server_protocol::Tools;
use codex_app_server_protocol::UserSavedConfig;
use codex_protocol::config_types::ForcedLoginMethod;
@@ -2802,7 +2802,7 @@ model_verbosity = "high"
            name: "OpenAI using Chat Completions".to_string(),
            base_url: Some("https://api.openai.com/v1".to_string()),
            env_key: Some("OPENAI_API_KEY".to_string()),
            wire_api: crate::WireApi::Chat,
            wire_api: codex_api_client::WireApi::Chat,
            env_key_instructions: None,
            experimental_bearer_token: None,
            query_params: None,

@@ -7,6 +7,7 @@ use codex_protocol::protocol::TokenUsage;
use codex_protocol::protocol::TokenUsageInfo;
use codex_utils_string::take_bytes_at_char_boundary;
use codex_utils_string::take_last_bytes_at_char_boundary;
use std::collections::HashSet;
use std::ops::Deref;

// Model-formatting limits: clients get full streams; only content sent to the model is truncated.
@@ -22,6 +23,13 @@ pub(crate) struct ConversationHistory {
    /// The oldest items are at the beginning of the vector.
    items: Vec<ResponseItem>,
    token_info: Option<TokenUsageInfo>,
    responses_api_chain: Option<ResponsesApiChainState>,
}

#[derive(Debug, Clone, Default)]
pub(crate) struct ResponsesApiChainState {
    pub last_response_id: Option<String>,
    pub last_message_id: Option<String>,
}

impl ConversationHistory {
@@ -29,6 +37,7 @@ impl ConversationHistory {
        Self {
            items: Vec::new(),
            token_info: TokenUsageInfo::new_or_append(&None, &None, None),
            responses_api_chain: None,
        }
    }

@@ -71,6 +80,10 @@ impl ConversationHistory {
    // Returns the history prepared for sending to the model.
    // With extra response items filtered out and GhostCommits removed.
    pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
        self.build_prompt_history()
    }

    fn build_prompt_history(&mut self) -> Vec<ResponseItem> {
        let mut history = self.get_history();
        Self::remove_ghost_snapshots(&mut history);
        Self::remove_reasoning_before_last_turn(&mut history);
@@ -91,6 +104,7 @@ impl ConversationHistory {

    pub(crate) fn replace(&mut self, items: Vec<ResponseItem>) {
        self.items = items;
        self.reset_responses_api_chain();
    }

    pub(crate) fn update_token_info(
@@ -429,6 +443,18 @@ impl ConversationHistory {
            | ResponseItem::Other => item.clone(),
        }
    }

    pub(crate) fn responses_api_chain(&self) -> Option<ResponsesApiChainState> {
        self.responses_api_chain.clone()
    }

    pub(crate) fn reset_responses_api_chain(&mut self) {
        self.responses_api_chain = None;
    }

    pub(crate) fn set_responses_api_chain(&mut self, chain: ResponsesApiChainState) {
        self.responses_api_chain = Some(chain);
    }
}

pub(crate) fn format_output_for_model_body(content: &str) -> String {
@@ -519,6 +545,102 @@ fn is_api_message(message: &ResponseItem) -> bool {
    }
}

fn reserialize_shell_outputs(items: &mut [ResponseItem]) {
    let mut shell_call_ids: HashSet<String> = HashSet::new();
    items.iter_mut().for_each(|item| match item {
        ResponseItem::LocalShellCall { call_id, id, .. } => {
            if let Some(identifier) = call_id.clone().or_else(|| id.clone()) {
                shell_call_ids.insert(identifier);
            }
        }
        ResponseItem::CustomToolCall { call_id, name, .. } => {
            if name == "apply_patch" {
                shell_call_ids.insert(call_id.clone());
            }
        }
        ResponseItem::CustomToolCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(output)
            {
                *output = structured;
            }
        }
        ResponseItem::FunctionCall { name, call_id, .. }
            if name == "shell" || name == "container.exec" || name == "apply_patch" =>
        {
            shell_call_ids.insert(call_id.clone());
        }
        ResponseItem::FunctionCallOutput { call_id, output } => {
            if shell_call_ids.remove(call_id)
                && let Some(structured) = parse_structured_shell_output(&output.content)
            {
                output.content = structured;
            }
        }
        _ => {}
    });
}

#[derive(serde::Deserialize)]
struct ExecOutputJson {
    output: String,
    metadata: ExecOutputMetadataJson,
}

#[derive(serde::Deserialize)]
struct ExecOutputMetadataJson {
    exit_code: i32,
    duration_seconds: f32,
}

fn parse_structured_shell_output(raw: &str) -> Option<String> {
    let parsed: ExecOutputJson = serde_json::from_str(raw).ok()?;
    Some(build_structured_output(&parsed))
}

fn build_structured_output(parsed: &ExecOutputJson) -> String {
    let mut sections = Vec::new();
    sections.push(format!("Exit code: {}", parsed.metadata.exit_code));
    sections.push(format!(
        "Wall time: {} seconds",
        parsed.metadata.duration_seconds
    ));

    let mut output = parsed.output.clone();
    if let Some(total_lines) = extract_total_output_lines(&parsed.output) {
        sections.push(format!("Total output lines: {total_lines}"));
        if let Some(stripped) = strip_total_output_header(&output) {
            output = stripped.to_string();
        }
    }

    sections.push("Output:".to_string());
    sections.push(output);

    sections.join("\n")
}

fn extract_total_output_lines(output: &str) -> Option<u32> {
    let marker_start = output.find("[... omitted ")?;
    let marker = &output[marker_start..];
    let (_, after_of) = marker.split_once(" of ")?;
    let (total_segment, _) = after_of.split_once(' ')?;
    total_segment.parse::<u32>().ok()
}

fn strip_total_output_header(output: &str) -> Option<&str> {
    let after_prefix = output.strip_prefix("Total output lines: ")?;
    let (_, remainder) = after_prefix.split_once('\n')?;
    let remainder = remainder.strip_prefix('\n').unwrap_or(remainder);
    Some(remainder)
}

pub(crate) fn format_prompt_items(items: &mut [ResponseItem], has_freeform_apply_patch: bool) {
    if has_freeform_apply_patch {
        reserialize_shell_outputs(items);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

@@ -41,6 +41,14 @@ impl CodexHttpClient {
        Self { inner }
    }

    pub fn inner(&self) -> &reqwest::Client {
        &self.inner
    }

    pub fn clone_inner(&self) -> reqwest::Client {
        self.inner.clone()
    }

    pub fn get<U>(&self, url: U) -> CodexRequestBuilder
    where
        U: IntoUrl,

@@ -43,6 +43,8 @@ pub enum Feature {
    SandboxCommandAssessment,
    /// Create a ghost commit at each turn.
    GhostCommit,
    /// Enable chaining Responses API calls via previous response IDs.
    ResponsesApiChaining,
}

impl Feature {
@@ -295,4 +297,10 @@ pub const FEATURES: &[FeatureSpec] = &[
        stage: Stage::Experimental,
        default_enabled: false,
    },
    FeatureSpec {
        id: Feature::ResponsesApiChaining,
        key: "responses_api_chaining",
        stage: Stage::Experimental,
        default_enabled: false,
    },
];

@@ -8,7 +8,6 @@
mod apply_patch;
pub mod auth;
pub mod bash;
mod chat_completions;
mod client;
mod client_common;
pub mod codex;
@@ -19,9 +18,11 @@ mod command_safety;
pub mod config;
pub mod config_loader;
mod conversation_history;
mod conversation_manager;
pub mod custom_prompts;
mod environment_context;
pub mod error;
mod event_mapping;
pub mod exec;
pub mod exec_env;
pub mod features;
@@ -32,22 +33,14 @@ pub mod mcp;
mod mcp_connection_manager;
mod mcp_tool_call;
mod message_history;
mod model_provider_info;
pub mod parse_command;
mod response_processing;
pub mod review_format;
pub mod sandboxing;
pub mod token_data;
mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
mod event_mapping;
pub mod review_format;
pub use codex_protocol::protocol::InitialHistory;
pub use conversation_manager::ConversationManager;
pub use conversation_manager::NewConversation;

@@ -5,13 +5,13 @@ use std::time::Duration;
use std::time::Instant;

use crate::AuthManager;
use crate::ModelProviderInfo;
use crate::client::ModelClient;
use crate::client_common::Prompt;
use crate::client_common::ResponseEvent;
use crate::config::Config;
use crate::protocol::SandboxPolicy;
use askama::Template;
use codex_api_client::ModelProviderInfo;
use codex_api_client::Prompt;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::models::ContentItem;
@@ -126,12 +126,10 @@ pub(crate) async fn assess_command(
            role: "user".to_string(),
            content: vec![ContentItem::InputText { text: user_prompt }],
        }],
        tools: Vec::new(),
        parallel_tool_calls: false,
        base_instructions_override: Some(system_prompt),
        output_schema: Some(sandbox_assessment_schema()),
        instructions: system_prompt,
        ..Default::default()
    };

    let child_otel =
        parent_otel.with_model(config.model.as_str(), config.model_family.slug.as_str());

@@ -4,6 +4,8 @@ mod turn;

pub(crate) use service::SessionServices;
pub(crate) use session::SessionState;
pub(crate) use session::build_prompt_from_items;
pub(crate) use session::response_item_id;
pub(crate) use turn::ActiveTurn;
pub(crate) use turn::RunningTask;
pub(crate) use turn::TaskKind;

@@ -2,26 +2,41 @@

use codex_protocol::models::ResponseItem;

use crate::client_common::Prompt;
use crate::client_common::compute_full_instructions;
use crate::codex::SessionConfiguration;
use crate::conversation_history::ConversationHistory;
use crate::conversation_history::ResponsesApiChainState;
use crate::conversation_history::format_prompt_items;
use crate::features::Feature;
use crate::model_family::ModelFamily;
use crate::protocol::RateLimitSnapshot;
use crate::protocol::TokenUsage;
use crate::protocol::TokenUsageInfo;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::ToolsConfigParams;
use crate::tools::spec::build_specs;
use crate::tools::spec::tools_metadata_for_prompt;

/// Persistent, session-scoped state previously stored directly on `Session`.
pub(crate) struct SessionState {
    pub(crate) session_configuration: SessionConfiguration,
    pub(crate) history: ConversationHistory,
    pub(crate) latest_rate_limits: Option<RateLimitSnapshot>,
    pub(crate) model_family: ModelFamily,
}

impl SessionState {
    /// Create a new session state mirroring previous `State::default()` semantics.
    pub(crate) fn new(session_configuration: SessionConfiguration) -> Self {
    pub(crate) fn new(
        session_configuration: SessionConfiguration,
        model_family: ModelFamily,
    ) -> Self {
        Self {
            session_configuration,
            history: ConversationHistory::new(),
            latest_rate_limits: None,
            model_family,
        }
    }

@@ -42,6 +57,16 @@ impl SessionState {
        self.history.replace(items);
    }

    pub(crate) fn reset_responses_api_chain(&mut self) {
        self.history.reset_responses_api_chain();
    }

    pub(crate) fn set_responses_api_chain(&mut self, chain: ResponsesApiChainState) {
        if self.session_configuration.features.enabled(Feature::ResponsesApiChaining) {
            self.history.set_responses_api_chain(chain);
        }
    }

    // Token/rate limit helpers
    pub(crate) fn update_token_info_from_usage(
        &mut self,
@@ -68,4 +93,84 @@ impl SessionState {
    pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
        self.history.set_token_usage_full(context_window);
    }

    pub(crate) fn prompt_for_turn(&mut self) -> Prompt {
        let tools_config = ToolsConfig::new(&ToolsConfigParams {
            model_family: &self.model_family,
            features: &self.session_configuration.features,
        });
        let (tool_specs, _registry) = build_specs(&tools_config, None).build();
        let tool_specs = tool_specs.into_iter().map(|c| c.spec).collect::<Vec<_>>();

        let prompt_items = self.history.get_history_for_prompt();
        let chain_state = self.history.responses_api_chain();
        let (mut prompt, reset_chain) = build_prompt_from_items(prompt_items, chain_state.as_ref());
        if reset_chain {
            self.reset_responses_api_chain();
        }

        // Populate prompt fields that depend only on session state.
        let (tools_json, has_freeform_apply_patch) =
            tools_metadata_for_prompt(&tool_specs).unwrap_or_default();
        format_prompt_items(&mut prompt.input, has_freeform_apply_patch);

        let apply_patch_present = tool_specs.iter().any(|spec| spec.name() == "apply_patch");
        let base_override = self.session_configuration.base_instructions.as_deref();
        let instructions =
            compute_full_instructions(base_override, &self.model_family, apply_patch_present)
                .into_owned();

        prompt.instructions = instructions;
        prompt.tools = tools_json;
        prompt.parallel_tool_calls = self.model_family.supports_parallel_tool_calls;

        prompt
    }
}
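Read together with `run_task` above, the per-turn flow this enables is roughly the following sketch; the bindings are illustrative and error handling, locking, and tool dispatch are elided:

```rust
// Sketch: SessionState::prompt_for_turn assembles the Prompt (tools JSON,
// instructions, parallel_tool_calls, and any Responses API chaining fields),
// and the Completed event's response_id is fed back into the chain state.
use futures::StreamExt;

async fn one_turn(state: &mut SessionState, client: &ModelClient) -> CodexResult<()> {
    let prompt = state.prompt_for_turn();
    let mut stream = client.stream(&prompt).await?;
    while let Some(event) = stream.next().await {
        if let ResponseEvent::Completed { response_id, .. } = event? {
            // In the real code this goes through
            // Session::update_responses_api_chain_state, which also derives
            // last_message_id from the history.
            state.set_responses_api_chain(ResponsesApiChainState {
                last_response_id: Some(response_id),
                last_message_id: None, // derived from history in the real code
            });
            break;
        }
    }
    Ok(())
}
```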

pub(crate) fn response_item_id(item: &ResponseItem) -> Option<&str> {
    match item {
        ResponseItem::Message { id: Some(id), .. }
        | ResponseItem::Reasoning { id, .. }
        | ResponseItem::LocalShellCall { id: Some(id), .. }
        | ResponseItem::FunctionCall { id: Some(id), .. }
        | ResponseItem::CustomToolCall { id: Some(id), .. }
        | ResponseItem::WebSearchCall { id: Some(id), .. } => Some(id.as_str()),
        _ => None,
    }
}

pub(crate) fn build_prompt_from_items(
    prompt_items: Vec<ResponseItem>,
    chain_state: Option<&ResponsesApiChainState>,
) -> (Prompt, bool) {
    let mut prompt = Prompt {
        ..Prompt::default()
    };

    if let Some(state) = chain_state {
        if let Some(last_message_id) = state.last_message_id.as_ref() {
            if let Some(position) = prompt_items
                .iter()
                .position(|item| response_item_id(item) == Some(last_message_id.as_str()))
            {
                if let Some(previous_response_id) = state.last_response_id.clone() {
                    prompt.previous_response_id = Some(previous_response_id);
                }
                prompt.input = prompt_items.into_iter().skip(position + 1).collect();
                return (prompt, false);
            }
            prompt.input = prompt_items;
            return (prompt, true);
        }

        if let Some(previous_response_id) = state.last_response_id.clone() {
            prompt.previous_response_id = Some(previous_response_id);
        }
        prompt.input = prompt_items;
        return (prompt, false);
    }

    prompt.input = prompt_items;
    (prompt, false)
}
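To make the chaining contract concrete, a hypothetical sketch; the IDs are invented for illustration, and `items` stands for a `Vec<ResponseItem>` in which one item carries the id `"msg_2"`:

```rust
// Sketch: when the chained message id is still present in history, only the
// suffix after it is sent and previous_response_id is reused; when the id is
// gone (e.g. history was rewritten), the full history is sent and the second
// element of the returned tuple tells the caller to reset the chain.
let chain = ResponsesApiChainState {
    last_response_id: Some("resp_123".to_string()), // hypothetical
    last_message_id: Some("msg_2".to_string()),     // hypothetical
};
let (prompt, reset_chain) = build_prompt_from_items(items, Some(&chain));
assert!(!reset_chain);
assert_eq!(prompt.previous_response_id.as_deref(), Some("resp_123"));
// prompt.input now holds only the items recorded after "msg_2"; had "msg_2"
// been absent, prompt.input would be the full history and reset_chain true.
```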

@@ -3,10 +3,6 @@ use std::collections::BTreeMap;
use crate::apply_patch;
use crate::apply_patch::InternalApplyPatchInvocation;
use crate::apply_patch::convert_apply_patch_to_protocol;
use crate::client_common::tools::FreeformTool;
use crate::client_common::tools::FreeformToolFormat;
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
@@ -20,7 +16,11 @@ use crate::tools::runtimes::apply_patch::ApplyPatchRequest;
use crate::tools::runtimes::apply_patch::ApplyPatchRuntime;
use crate::tools::sandboxing::ToolCtx;
use crate::tools::spec::ApplyPatchToolArgs;
use crate::tools::spec::FreeformTool;
use crate::tools::spec::FreeformToolFormat;
use crate::tools::spec::JsonSchema;
use crate::tools::spec::ResponsesApiTool;
use crate::tools::spec::ToolSpec;
use async_trait::async_trait;
use serde::Deserialize;
use serde::Serialize;

@@ -1,5 +1,3 @@
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
@@ -9,6 +7,8 @@ use crate::tools::context::ToolPayload;
use crate::tools::registry::ToolHandler;
use crate::tools::registry::ToolKind;
use crate::tools::spec::JsonSchema;
use crate::tools::spec::ResponsesApiTool;
use crate::tools::spec::ToolSpec;
use async_trait::async_trait;
use codex_protocol::plan_tool::UpdatePlanArgs;
use codex_protocol::protocol::EventMsg;

@@ -6,11 +6,11 @@ use async_trait::async_trait;
use codex_protocol::models::ResponseInputItem;
use tracing::warn;

use crate::client_common::tools::ToolSpec;
use crate::function_tool::FunctionCallError;
use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolOutput;
use crate::tools::context::ToolPayload;
use crate::tools::spec::ToolSpec;

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ToolKind {

@@ -1,7 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::client_common::tools::ToolSpec;
use crate::codex::Session;
use crate::codex::TurnContext;
use crate::function_tool::FunctionCallError;
@@ -10,6 +9,7 @@ use crate::tools::context::ToolInvocation;
use crate::tools::context::ToolPayload;
use crate::tools::registry::ConfiguredToolSpec;
use crate::tools::registry::ToolRegistry;
use crate::tools::spec::ToolSpec;
use crate::tools::spec::ToolsConfig;
use crate::tools::spec::build_specs;
use codex_protocol::models::LocalShellAction;

@@ -1,5 +1,3 @@
use crate::client_common::tools::ResponsesApiTool;
use crate::client_common::tools::ToolSpec;
use crate::features::Feature;
use crate::features::Features;
use crate::model_family::ModelFamily;
@@ -22,6 +20,52 @@ pub enum ConfigShellToolType {
    Streamable,
}

#[derive(Debug, Clone, Serialize, PartialEq)]
#[serde(tag = "type")]
pub(crate) enum ToolSpec {
    #[serde(rename = "function")]
    Function(ResponsesApiTool),
    #[serde(rename = "local_shell")]
    LocalShell {},
    #[serde(rename = "web_search")]
    WebSearch {},
    #[serde(rename = "custom")]
    Freeform(FreeformTool),
}

impl ToolSpec {
    pub(crate) fn name(&self) -> &str {
        match self {
            ToolSpec::Function(tool) => tool.name.as_str(),
            ToolSpec::LocalShell {} => "local_shell",
            ToolSpec::WebSearch {} => "web_search",
            ToolSpec::Freeform(tool) => tool.name.as_str(),
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformTool {
    pub(crate) name: String,
    pub(crate) description: String,
    pub(crate) format: FreeformToolFormat,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FreeformToolFormat {
    pub(crate) r#type: String,
    pub(crate) syntax: String,
    pub(crate) definition: String,
}

#[derive(Debug, Clone, Serialize, PartialEq)]
pub struct ResponsesApiTool {
    pub(crate) name: String,
    pub(crate) description: String,
    pub(crate) strict: bool,
    pub(crate) parameters: JsonSchema,
}

#[derive(Debug, Clone)]
pub(crate) struct ToolsConfig {
    pub shell_type: ConfigShellToolType,
@@ -666,9 +710,6 @@ pub(crate) struct ApplyPatchToolArgs {
    pub(crate) input: String,
}

/// Returns JSON values that are compatible with Function Calling in the
/// Responses API:
/// https://platform.openai.com/docs/guides/function-calling?api-mode=responses
pub fn create_tools_json_for_responses_api(
    tools: &[ToolSpec],
) -> crate::error::Result<Vec<serde_json::Value>> {
@@ -681,35 +722,16 @@ pub fn create_tools_json_for_responses_api(

    Ok(tools_json)
}
/// Returns JSON values that are compatible with Function Calling in the
/// Chat Completions API:
/// https://platform.openai.com/docs/guides/function-calling?api-mode=chat
pub(crate) fn create_tools_json_for_chat_completions_api(
    tools: &[ToolSpec],
) -> crate::error::Result<Vec<serde_json::Value>> {
    // We start with the JSON for the Responses API and then rewrite it to match
    // the chat completions tool call format.
    let responses_api_tools_json = create_tools_json_for_responses_api(tools)?;
    let tools_json = responses_api_tools_json
        .into_iter()
        .filter_map(|mut tool| {
            if tool.get("type") != Some(&serde_json::Value::String("function".to_string())) {
                return None;
            }

            if let Some(map) = tool.as_object_mut() {
                // Remove "type" field as it is not needed in chat completions.
                map.remove("type");
                Some(json!({
                    "type": "function",
                    "function": map,
                }))
            } else {
                None
            }
        })
        .collect::<Vec<serde_json::Value>>();
    Ok(tools_json)
pub fn tools_metadata_for_prompt(
    tools: &[ToolSpec],
) -> crate::error::Result<(Vec<serde_json::Value>, bool)> {
    let tools_json = create_tools_json_for_responses_api(tools)?;
    let has_freeform_apply_patch = tools.iter().any(|tool| match tool {
        ToolSpec::Freeform(freeform) => freeform.name == "apply_patch",
        _ => false,
    });
    Ok((tools_json, has_freeform_apply_patch))
}
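The Chat Completions rewrite that the removed function performed is mechanical; a self-contained sketch with a hypothetical function tool (only `serde_json` is needed, and the tool fields are invented for illustration):

```rust
use serde_json::{Value, json};

fn main() {
    // A Responses API "function" tool (hypothetical schema).
    let tool = json!({
        "type": "function",
        "name": "apply_patch",
        "description": "Apply a patch",
        "parameters": {"type": "object", "properties": {}}
    });
    // Chat Completions expects the same fields nested under a "function" key,
    // with the inner "type" marker removed, as the filter_map above did.
    let mut inner = tool.as_object().expect("object").clone();
    inner.remove("type");
    let chat_tool = json!({"type": "function", "function": Value::Object(inner)});
    println!("{chat_tool}");
}
```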
|
||||
|
||||
pub(crate) fn mcp_tool_to_openai_tool(
|
||||
@@ -1002,7 +1024,6 @@ pub(crate) fn build_specs(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::client_common::tools::FreeformTool;
|
||||
use crate::model_family::find_family_for_model;
|
||||
use crate::tools::registry::ConfiguredToolSpec;
|
||||
use mcp_types::ToolInputSchema;
|
||||
|
||||
@@ -1,15 +1,15 @@
 use std::sync::Arc;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::LocalShellAction;
 use codex_core::LocalShellExecAction;
 use codex_core::LocalShellStatus;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
@@ -97,10 +97,12 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
         codex_protocol::protocol::SessionSource::Exec,
     );

-    let mut prompt = Prompt::default();
-    prompt.input = input;
+    let prompt = Prompt {
+        input,
+        ..Prompt::default()
+    };

-    let mut stream = match client.stream(&prompt).await {
+    let mut stream = match client.stream_for_test(prompt).await {
         Ok(s) => s,
         Err(e) => panic!("stream chat failed: {e}"),
     };
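The test refactors in these hunks swap a `mut` binding plus field assignment for Rust's struct update syntax. A minimal sketch of the pattern, using a stand-in `Prompt` type (the real one has more fields and lives in codex-core):

/// Stand-in for the real Prompt type; only the shape matters here.
#[derive(Default)]
struct Prompt {
    input: Vec<String>,
    store_response: bool,
}

fn main() {
    // Explicit fields first, then `..Prompt::default()` fills in the rest,
    // so no `mut` binding or post-construction assignment is needed.
    let prompt = Prompt {
        input: vec!["hello".to_string()],
        ..Prompt::default()
    };
    assert!(!prompt.store_response);
    println!("{} input item(s)", prompt.input.len());
}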
@@ -2,14 +2,14 @@ use assert_matches::assert_matches;
 use std::sync::Arc;
 use tracing_test::traced_test;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::spawn::CODEX_SANDBOX_NETWORK_DISABLED_ENV_VAR;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
@@ -97,16 +97,18 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
         codex_protocol::protocol::SessionSource::Exec,
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".to_string(),
-        content: vec![ContentItem::InputText {
-            text: "hello".to_string(),
-        }],
-    }];
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".to_string(),
+            content: vec![ContentItem::InputText {
+                text: "hello".to_string(),
+            }],
+        }],
+        ..Prompt::default()
+    };

-    let mut stream = match client.stream(&prompt).await {
+    let mut stream = match client.stream_for_test(prompt).await {
         Ok(s) => s,
         Err(e) => panic!("stream chat failed: {e}"),
     };
@@ -9,6 +9,7 @@ path = "lib.rs"
 [dependencies]
 anyhow = { workspace = true }
 assert_cmd = { workspace = true }
+codex-api-client = { workspace = true }
 codex-core = { workspace = true }
 codex-protocol = { workspace = true }
 notify = { workspace = true }
@@ -4,11 +4,11 @@ use std::path::PathBuf;
 use std::sync::Arc;

 use anyhow::Result;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::config::Config;
 use codex_core::features::Feature;
 use codex_core::protocol::AskForApproval;
@@ -1,13 +1,13 @@
 use std::sync::Arc;

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_app_server_protocol::AuthMode;
 use codex_core::ContentItem;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_otel::otel_event_manager::OtelEventManager;
 use codex_protocol::ConversationId;
 use codex_protocol::protocol::SessionSource;
@@ -82,16 +82,18 @@ async fn responses_stream_includes_subagent_header_on_review() {
         SessionSource::SubAgent(codex_protocol::protocol::SubAgentSource::Review),
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".into(),
-        content: vec![ContentItem::InputText {
-            text: "hello".into(),
-        }],
-    }];
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".into(),
+            content: vec![ContentItem::InputText {
+                text: "hello".into(),
+            }],
+        }],
+        ..Prompt::default()
+    };

-    let mut stream = client.stream(&prompt).await.expect("stream failed");
+    let mut stream = client.stream_for_test(prompt).await.expect("stream failed");
     while let Some(event) = stream.next().await {
         if matches!(event, Ok(ResponseEvent::Completed { .. })) {
             break;
@@ -172,16 +174,18 @@ async fn responses_stream_includes_subagent_header_on_other() {
         )),
     );

-    let mut prompt = Prompt::default();
-    prompt.input = vec![ResponseItem::Message {
-        id: None,
-        role: "user".into(),
-        content: vec![ContentItem::InputText {
-            text: "hello".into(),
-        }],
-    }];
+    let prompt = Prompt {
+        input: vec![ResponseItem::Message {
+            id: None,
+            role: "user".into(),
+            content: vec![ContentItem::InputText {
+                text: "hello".into(),
+            }],
+        }],
+        ..Prompt::default()
+    };

-    let mut stream = client.stream(&prompt).await.expect("stream failed");
+    let mut stream = client.stream_for_test(prompt).await.expect("stream failed");
     while let Some(event) = stream.next().await {
         if matches!(event, Ok(ResponseEvent::Completed { .. })) {
             break;
@@ -1,3 +1,6 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
+use codex_api_client::built_in_model_providers;
 use codex_app_server_protocol::AuthMode;
 use codex_core::CodexAuth;
 use codex_core::ContentItem;
@@ -6,15 +9,13 @@ use codex_core::LocalShellAction;
 use codex_core::LocalShellExecAction;
 use codex_core::LocalShellStatus;
 use codex_core::ModelClient;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
 use codex_core::Prompt;
 use codex_core::ResponseEvent;
 use codex_core::ResponseItem;
-use codex_core::WireApi;
 use codex_core::auth::AuthCredentialsStoreMode;
-use codex_core::built_in_model_providers;
 use codex_core::error::CodexErr;
 use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
@@ -678,6 +679,98 @@ async fn includes_developer_instructions_message_in_request() {
     assert_message_ends_with(&request_body["input"][2], "</environment_context>");
 }

+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn responses_api_chaining_sets_store_and_previous_id() {
+    skip_if_no_network!();
+
+    let server = MockServer::start().await;
+    let first_response = responses::sse(vec![
+        responses::ev_response_created("resp-first"),
+        responses::ev_assistant_message("m1", "hi there"),
+        responses::ev_completed("resp-first"),
+    ]);
+    let second_response = responses::sse(vec![
+        responses::ev_response_created("resp-second"),
+        responses::ev_assistant_message("m2", "second reply"),
+        responses::ev_completed("resp-second"),
+    ]);
+    let response_mock =
+        responses::mount_sse_sequence(&server, vec![first_response, second_response]).await;
+
+    let model_provider = ModelProviderInfo {
+        base_url: Some(format!("{}/v1", server.uri())),
+        ..built_in_model_providers()["openai"].clone()
+    };
+
+    let codex_home = TempDir::new().unwrap();
+    let mut config = load_default_config_for_test(&codex_home);
+    config.model_provider = model_provider;
+    config.features.enable(Feature::ResponsesApiChaining);
+
+    let conversation_manager =
+        ConversationManager::with_auth(CodexAuth::from_api_key("Test API Key"));
+    let codex = conversation_manager
+        .new_conversation(config)
+        .await
+        .expect("create new conversation")
+        .conversation;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![UserInput::Text {
+                text: "first turn".into(),
+            }],
+        })
+        .await
+        .unwrap();
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    codex
+        .submit(Op::UserInput {
+            items: vec![UserInput::Text {
+                text: "second turn".into(),
+            }],
+        })
+        .await
+        .unwrap();
+    wait_for_event(&codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await;
+
+    let requests = response_mock.requests();
+    assert_eq!(
+        requests.len(),
+        2,
+        "expected two responses API calls for two turns"
+    );
+
+    let first_body = requests[0].body_json();
+    assert_eq!(first_body["store"], serde_json::Value::Bool(true));
+    assert!(
+        first_body.get("previous_response_id").is_none(),
+        "first request should not set previous_response_id"
+    );
+
+    let second_body = requests[1].body_json();
+    assert_eq!(second_body["store"], serde_json::Value::Bool(true));
+    assert_eq!(
+        second_body["previous_response_id"].as_str(),
+        Some("resp-first")
+    );
+
+    let second_input = requests[1].input();
+    assert_eq!(
+        second_input.len(),
+        1,
+        "second request should only send new user input items"
+    );
+    let user_item = &second_input[0];
+    assert_eq!(user_item["type"].as_str(), Some("message"));
+    assert_eq!(user_item["role"].as_str(), Some("user"));
+    let content = user_item["content"][0]["text"]
+        .as_str()
+        .expect("missing user message text");
+    assert_eq!(content, "second turn");
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn azure_responses_request_includes_store_and_reasoning_ids() {
     skip_if_no_network!();
@@ -800,7 +893,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
     });

     let mut stream = client
-        .stream(&prompt)
+        .stream_for_test(prompt)
         .await
         .expect("responses stream to start");
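The new chaining test above pins down the wire contract: every request sets "store": true, and from the second turn onward the request carries previous_response_id and only the new input items. A hedged sketch of the second-turn body those assertions imply (field names taken from the test; the content item type is assumed to match codex's InputText serialization):

use serde_json::json;

fn main() {
    // Second-turn request body shape asserted by the test above.
    let second_request = json!({
        "store": true,
        "previous_response_id": "resp-first",
        "input": [{
            "type": "message",
            "role": "user",
            "content": [{ "type": "input_text", "text": "second turn" }],
        }],
    });
    assert_eq!(
        second_request["previous_response_id"].as_str(),
        Some("resp-first")
    );
    println!("{second_request}");
}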
@@ -1,8 +1,8 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::protocol::ErrorEvent;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
@@ -9,12 +9,12 @@

 use super::compact::FIRST_REPLY;
 use super::compact::SUMMARY_TEXT;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::codex::compact::SUMMARIZATION_PROMPT;
 use codex_core::config::Config;
 use codex_core::config::OPENAI_DEFAULT_MODEL;
@@ -1,8 +1,8 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::NewConversation;
-use codex_core::built_in_model_providers;
 use codex_core::parse_turn_item;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
@@ -1,9 +1,9 @@
 #![allow(clippy::unwrap_used)]

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
 use codex_core::protocol::EventMsg;
@@ -1,9 +1,9 @@
 #![allow(clippy::unwrap_used)]

+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
-use codex_core::built_in_model_providers;
 use codex_core::config::OPENAI_DEFAULT_MODEL;
 use codex_core::features::Feature;
 use codex_core::model_family::find_family_for_model;
@@ -1,11 +1,11 @@
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::built_in_model_providers;
 use codex_core::CodexAuth;
 use codex_core::CodexConversation;
 use codex_core::ContentItem;
 use codex_core::ConversationManager;
-use codex_core::ModelProviderInfo;
 use codex_core::REVIEW_PROMPT;
 use codex_core::ResponseItem;
-use codex_core::built_in_model_providers;
 use codex_core::config::Config;
 use codex_core::protocol::ENVIRONMENT_CONTEXT_OPEN_TAG;
 use codex_core::protocol::EventMsg;
@@ -422,7 +422,7 @@ async fn stdio_image_completions_round_trip() -> anyhow::Result<()> {

     let fixture = test_codex()
         .with_config(move |config| {
-            config.model_provider.wire_api = codex_core::WireApi::Chat;
+            config.model_provider.wire_api = codex_api_client::WireApi::Chat;
             config.features.enable(Feature::RmcpClient);
             config.mcp_servers.insert(
                 server_name.to_string(),
@@ -1,7 +1,7 @@
 use std::time::Duration;

-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
 use codex_protocol::user_input::UserInput;
@@ -3,8 +3,8 @@

 use std::time::Duration;

-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::protocol::EventMsg;
 use codex_core::protocol::Op;
 use codex_protocol::user_input::UserInput;
@@ -24,6 +24,7 @@ codex-common = { workspace = true, features = [
     "sandbox_summary",
 ] }
 codex-core = { workspace = true }
+codex-api-client = { workspace = true }
 codex-ollama = { workspace = true }
 codex-protocol = { workspace = true }
 mcp-types = { workspace = true }
@@ -11,8 +11,8 @@ pub mod event_processor_with_jsonl_output;
 pub mod exec_events;

 pub use cli::Cli;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::AuthManager;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::ConversationManager;
 use codex_core::NewConversation;
 use codex_core::auth::enforce_login_restrictions;
@@ -13,6 +13,7 @@ workspace = true
 [dependencies]
 async-stream = { workspace = true }
 bytes = { workspace = true }
+codex-api-client = { workspace = true }
 codex-core = { workspace = true }
 futures = { workspace = true }
 reqwest = { workspace = true, features = ["json", "stream"] }
@@ -10,9 +10,9 @@ use crate::pull::PullEvent;
 use crate::pull::PullProgressReporter;
 use crate::url::base_url_to_host_root;
 use crate::url::is_openai_compatible_base_url;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
-use codex_core::ModelProviderInfo;
-use codex_core::WireApi;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
+use codex_api_client::ModelProviderInfo;
+use codex_api_client::WireApi;
 use codex_core::config::Config;

 const OLLAMA_CONNECTION_ERROR: &str = "No running Ollama server detected. Start it with: `ollama serve` (after installing). Install instructions: https://github.com/ollama/ollama?tab=readme-ov-file#ollama";
@@ -47,7 +47,7 @@ impl OllamaClient {

     #[cfg(test)]
     async fn try_from_provider_with_base_url(base_url: &str) -> io::Result<Self> {
-        let provider = codex_core::create_oss_provider_with_base_url(base_url);
+        let provider = codex_api_client::create_oss_provider_with_base_url(base_url);
         Self::try_from_provider(&provider).await
     }
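After the move, OSS provider construction is reached through codex-api-client instead of codex-core. A small usage sketch, under the assumption that the helper keeps the `&str` argument visible in the hunk above and returns the ModelProviderInfo type imported there (Ollama's default OpenAI-compatible endpoint shown for illustration):

use codex_api_client::ModelProviderInfo;
use codex_api_client::create_oss_provider_with_base_url;

fn main() {
    // Same call the test helper above makes, just outside #[cfg(test)].
    // The return type annotation is an assumption based on the imports above.
    let provider: ModelProviderInfo =
        create_oss_provider_with_base_url("http://localhost:11434/v1");
    let _ = provider;
}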
@@ -34,6 +34,7 @@ codex-common = { workspace = true, features = [
     "sandbox_summary",
 ] }
 codex-core = { workspace = true }
+codex-api-client = { workspace = true }
 codex-file-search = { workspace = true }
 codex-login = { workspace = true }
 codex-ollama = { workspace = true }
@@ -6,9 +6,9 @@
 use additional_dirs::add_dir_warning_message;
 use app::App;
 pub use app::AppExitInfo;
+use codex_api_client::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_app_server_protocol::AuthMode;
 use codex_core::AuthManager;
-use codex_core::BUILT_IN_OSS_MODEL_PROVIDER_ID;
 use codex_core::CodexAuth;
 use codex_core::INTERACTIVE_SESSION_SOURCES;
 use codex_core::RolloutRecorder;
@@ -8,7 +8,6 @@ readme = "README.md"
 workspace = true

 [dependencies]
-once_cell = "1"
 regex = "1"
 schemars = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
@@ -6,12 +6,12 @@
 //! mode via [`ApplyGitRequest::preflight`] and inspect the resulting paths to
 //! learn what would change before applying for real.

-use once_cell::sync::Lazy;
 use regex::Regex;
 use std::ffi::OsStr;
 use std::io;
 use std::path::Path;
 use std::path::PathBuf;
+use std::sync::LazyLock;

 /// Parameters for invoking [`apply_git_patch`].
 #[derive(Debug, Clone)]
@@ -192,7 +192,7 @@ fn render_command_for_log(cwd: &Path, git_cfg: &[String], args: &[String]) -> St

 /// Collect every path referenced by the diff headers inside `diff --git` sections.
 pub fn extract_paths_from_patch(diff_text: &str) -> Vec<String> {
-    static RE: Lazy<Regex> = Lazy::new(|| {
+    static RE: LazyLock<Regex> = LazyLock::new(|| {
         Regex::new(r"(?m)^diff --git a/(.*?) b/(.*)$")
             .unwrap_or_else(|e| panic!("invalid regex: {e}"))
     });
@@ -275,62 +275,64 @@ pub fn parse_git_apply_output(
     }
 }

-static APPLIED_CLEAN: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+cleanly\\.?$"));
-static APPLIED_CONFLICTS: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+with conflicts\\.?$"));
-static APPLYING_WITH_REJECTS: Lazy<Regex> = Lazy::new(|| {
+static APPLIED_CLEAN: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+cleanly\\.?$"));
+static APPLIED_CONFLICTS: LazyLock<Regex> = LazyLock::new(|| {
+    regex_ci("^Applied patch(?: to)?\\s+(?P<path>.+?)\\s+with conflicts\\.?$")
+});
+static APPLYING_WITH_REJECTS: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci("^Applying patch\\s+(?P<path>.+?)\\s+with\\s+\\d+\\s+rejects?\\.{0,3}$")
 });
-static CHECKING_PATCH: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Checking patch\\s+(?P<path>.+?)\\.\\.\\.$"));
-static UNMERGED_LINE: Lazy<Regex> = Lazy::new(|| regex_ci("^U\\s+(?P<path>.+)$"));
-static PATCH_FAILED: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)(?::\\d+)?(?:\\s|$)"));
-static DOES_NOT_APPLY: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+patch does not apply$"));
-static THREE_WAY_START: Lazy<Regex> = Lazy::new(|| {
+static CHECKING_PATCH: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^Checking patch\\s+(?P<path>.+?)\\.\\.\\.$"));
+static UNMERGED_LINE: LazyLock<Regex> = LazyLock::new(|| regex_ci("^U\\s+(?P<path>.+)$"));
+static PATCH_FAILED: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)(?::\\d+)?(?:\\s|$)"));
+static DOES_NOT_APPLY: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+patch does not apply$"));
+static THREE_WAY_START: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci("^(?:Performing three-way merge|Falling back to three-way merge)\\.\\.\\.$")
 });
-static THREE_WAY_FAILED: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Failed to perform three-way merge\\.\\.\\.$"));
-static FALLBACK_DIRECT: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Falling back to direct application\\.\\.\\.$"));
-static LACKS_BLOB: Lazy<Regex> = Lazy::new(|| {
+static THREE_WAY_FAILED: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^Failed to perform three-way merge\\.\\.\\.$"));
+static FALLBACK_DIRECT: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^Falling back to direct application\\.\\.\\.$"));
+static LACKS_BLOB: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci(
         "^(?:error: )?repository lacks the necessary blob to (?:perform|fall back on) 3-?way merge\\.?$",
     )
 });
-static INDEX_MISMATCH: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not match index\\b"));
-static NOT_IN_INDEX: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not exist in index\\b"));
-static ALREADY_EXISTS_WT: Lazy<Regex> = Lazy::new(|| {
+static INDEX_MISMATCH: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not match index\\b"));
+static NOT_IN_INDEX: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^error:\\s+(?P<path>.+?):\\s+does not exist in index\\b"));
+static ALREADY_EXISTS_WT: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci("^error:\\s+(?P<path>.+?)\\s+already exists in (?:the )?working directory\\b")
 });
-static FILE_EXISTS: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)\\s+File exists"));
-static RENAMED_DELETED: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^error:\\s+path\\s+(?P<path>.+?)\\s+has been renamed\\/deleted"));
-static CANNOT_APPLY_BINARY: Lazy<Regex> = Lazy::new(|| {
+static FILE_EXISTS: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^error:\\s+patch failed:\\s+(?P<path>.+?)\\s+File exists"));
+static RENAMED_DELETED: LazyLock<Regex> = LazyLock::new(|| {
+    regex_ci("^error:\\s+path\\s+(?P<path>.+?)\\s+has been renamed\\/deleted")
+});
+static CANNOT_APPLY_BINARY: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci(
         "^error:\\s+cannot apply binary patch to\\s+['\\\"]?(?P<path>.+?)['\\\"]?\\s+without full index line$",
     )
 });
-static BINARY_DOES_NOT_APPLY: Lazy<Regex> = Lazy::new(|| {
+static BINARY_DOES_NOT_APPLY: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci("^error:\\s+binary patch does not apply to\\s+['\\\"]?(?P<path>.+?)['\\\"]?$")
 });
-static BINARY_INCORRECT_RESULT: Lazy<Regex> = Lazy::new(|| {
+static BINARY_INCORRECT_RESULT: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci(
         "^error:\\s+binary patch to\\s+['\\\"]?(?P<path>.+?)['\\\"]?\\s+creates incorrect result\\b",
     )
 });
-static CANNOT_READ_CURRENT: Lazy<Regex> = Lazy::new(|| {
+static CANNOT_READ_CURRENT: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci("^error:\\s+cannot read the current contents of\\s+['\\\"]?(?P<path>.+?)['\\\"]?$")
 });
-static SKIPPED_PATCH: Lazy<Regex> =
-    Lazy::new(|| regex_ci("^Skipped patch\\s+['\\\"]?(?P<path>.+?)['\\\"]\\.$"));
-static CANNOT_MERGE_BINARY_WARN: Lazy<Regex> = Lazy::new(|| {
+static SKIPPED_PATCH: LazyLock<Regex> =
+    LazyLock::new(|| regex_ci("^Skipped patch\\s+['\\\"]?(?P<path>.+?)['\\\"]\\.$"));
+static CANNOT_MERGE_BINARY_WARN: LazyLock<Regex> = LazyLock::new(|| {
     regex_ci(
         "^warning:\\s*Cannot merge binary files:\\s+(?P<path>.+?)\\s+\\(ours\\s+vs\\.\\s+theirs\\)",
     )
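The once_cell removal above is mechanical because std::sync::LazyLock (stable since Rust 1.80) is a drop-in replacement for once_cell::sync::Lazy in static position. A minimal sketch of the migrated pattern, reusing one of the regexes from the diff:

use regex::Regex;
use std::sync::LazyLock;

// Compiled on first access, exactly like the once_cell version it replaces.
static DIFF_HEADER: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?m)^diff --git a/(.*?) b/(.*)$")
        .unwrap_or_else(|e| panic!("invalid regex: {e}"))
});

fn main() {
    assert!(DIFF_HEADER.is_match("diff --git a/src/lib.rs b/src/lib.rs"));
}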