This commit is contained in:
jif-oai
2025-11-10 12:13:40 +00:00
parent dabf219a45
commit 10c880d886
3 changed files with 120 additions and 219 deletions

View File

@@ -7,6 +7,7 @@ pub mod error;
pub mod model_provider;
pub mod prompt;
pub mod responses;
pub mod routed_client;
pub mod stream;
pub use crate::aggregate::AggregateStreamExt;
@@ -28,6 +29,8 @@ pub use crate::prompt::Prompt;
pub use crate::responses::ResponsesApiClient;
pub use crate::responses::ResponsesApiClientConfig;
pub use crate::responses::stream_from_fixture;
pub use crate::routed_client::RoutedApiClient;
pub use crate::routed_client::RoutedApiClientConfig;
pub use crate::stream::EventStream;
pub use crate::stream::Reasoning;
pub use crate::stream::ResponseEvent;

View File

@@ -0,0 +1,87 @@
use std::path::PathBuf;
use std::sync::Arc;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::protocol::SessionSource;
use crate::ApiClient;
use crate::ChatAggregationMode;
use crate::ChatCompletionsApiClient;
use crate::ChatCompletionsApiClientConfig;
use crate::Prompt;
use crate::ResponseStream;
use crate::ResponsesApiClient;
use crate::ResponsesApiClientConfig;
use crate::Result;
use crate::WireApi;
use crate::auth::AuthProvider;
use crate::model_provider::ModelProviderInfo;
use crate::responses::stream_from_fixture;
/// Configuration for [`RoutedApiClient`], which dispatches to the appropriate
/// API client implementation based on the provider wire API.
#[derive(Clone)]
pub struct RoutedApiClientConfig {
// Shared HTTP client; cloned into each backend config per stream call.
pub http_client: reqwest::Client,
// Provider info; `provider.wire_api` selects Responses vs Chat routing.
pub provider: ModelProviderInfo,
pub model: String,
pub conversation_id: ConversationId,
// Optional auth; only forwarded to the Responses backend.
pub auth_provider: Option<Arc<dyn AuthProvider>>,
pub otel_event_manager: OtelEventManager,
// Only forwarded to the Chat Completions backend.
pub session_source: SessionSource,
// Aggregation behavior for Chat Completions streams.
pub chat_aggregation_mode: ChatAggregationMode,
// When set, Responses streams replay this SSE fixture instead of hitting the network.
pub responses_fixture_path: Option<PathBuf>,
}
/// API client that routes each stream request to the Responses or Chat
/// Completions backend according to the configured provider's wire API.
#[derive(Clone)]
pub struct RoutedApiClient {
// Immutable routing configuration; backend clients are built from it on demand.
config: RoutedApiClientConfig,
}
impl RoutedApiClient {
    /// Wraps the given configuration in a router; no backend is created until
    /// the first call to [`Self::stream`].
    pub fn new(config: RoutedApiClientConfig) -> Self {
        Self { config }
    }

    /// Streams a model response, dispatching on the provider's wire API.
    pub async fn stream(&self, prompt: Prompt) -> Result<ResponseStream> {
        match self.config.provider.wire_api {
            WireApi::Chat => self.stream_via_chat(prompt).await,
            WireApi::Responses => self.stream_via_responses(prompt).await,
        }
    }

    /// Streams through the Responses API, or replays an SSE fixture file when
    /// one is configured (used by tests).
    async fn stream_via_responses(&self, prompt: Prompt) -> Result<ResponseStream> {
        if let Some(fixture) = &self.config.responses_fixture_path {
            let provider = self.config.provider.clone();
            let otel = self.config.otel_event_manager.clone();
            return stream_from_fixture(fixture, provider, otel).await;
        }
        // Build a fresh backend client per call from the routing config.
        let backend = ResponsesApiClient::new(ResponsesApiClientConfig {
            http_client: self.config.http_client.clone(),
            provider: self.config.provider.clone(),
            model: self.config.model.clone(),
            conversation_id: self.config.conversation_id,
            auth_provider: self.config.auth_provider.clone(),
            otel_event_manager: self.config.otel_event_manager.clone(),
        })
        .await?;
        backend.stream(prompt).await
    }

    /// Streams through the Chat Completions API.
    async fn stream_via_chat(&self, prompt: Prompt) -> Result<ResponseStream> {
        let backend = ChatCompletionsApiClient::new(ChatCompletionsApiClientConfig {
            http_client: self.config.http_client.clone(),
            provider: self.config.provider.clone(),
            model: self.config.model.clone(),
            otel_event_manager: self.config.otel_event_manager.clone(),
            session_source: self.config.session_source.clone(),
            aggregation_mode: self.config.chat_aggregation_mode,
        })
        .await?;
        backend.stream(prompt).await
    }
}

View File

@@ -2,27 +2,22 @@ use std::fmt;
use std::sync::Arc;
use async_trait::async_trait;
use codex_api_client::AggregateStreamExt;
use codex_api_client::ApiClient;
use codex_api_client::AuthContext;
use codex_api_client::AuthProvider;
use codex_api_client::ChatAggregationMode;
use codex_api_client::ChatCompletionsApiClient;
use codex_api_client::ChatCompletionsApiClientConfig;
use codex_api_client::ModelProviderInfo;
use codex_api_client::ResponsesApiClient;
use codex_api_client::ResponsesApiClientConfig;
use codex_api_client::Result as ApiClientResult;
use codex_api_client::RoutedApiClient;
use codex_api_client::RoutedApiClientConfig;
use codex_api_client::WireApi;
use codex_api_client::stream_from_fixture;
use codex_otel::otel_event_manager::OtelEventManager;
use codex_protocol::ConversationId;
use codex_protocol::config_types::ReasoningEffort as ReasoningEffortConfig;
use codex_protocol::config_types::ReasoningSummary as ReasoningSummaryConfig;
use codex_protocol::protocol::SessionSource;
use futures::StreamExt;
use futures::stream::BoxStream;
use reqwest::StatusCode;
use std::path::PathBuf;
use tokio::sync::OnceCell;
use tokio::sync::mpsc;
use tracing::warn;
@@ -54,7 +49,7 @@ pub struct ModelClient {
auth_manager: Option<Arc<AuthManager>>,
otel_event_manager: OtelEventManager,
provider: ModelProviderInfo,
backend: Arc<OnceCell<ModelBackend>>,
api_client: Arc<OnceCell<RoutedApiClient>>,
conversation_id: ConversationId,
effort: Option<ReasoningEffortConfig>,
summary: ReasoningSummaryConfig,
@@ -67,57 +62,11 @@ impl fmt::Debug for ModelClient {
.field("provider", &self.provider.name)
.field("model", &self.config.model)
.field("conversation_id", &self.conversation_id)
.field("backend_initialized", &self.backend.get().is_some())
.field("client_initialized", &self.api_client.get().is_some())
.finish()
}
}
// Boxed event stream returned by either backend.
type ApiClientStream = BoxStream<'static, ApiClientResult<ResponseEvent>>;
// Wire-API-specific backend selected from the provider configuration.
enum ModelBackend {
Responses(ResponsesBackend),
Chat(ChatBackend),
}
impl ModelBackend {
// Delegates the stream request to whichever backend variant this is.
async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
match self {
ModelBackend::Responses(backend) => backend.stream(prompt).await,
ModelBackend::Chat(backend) => backend.stream(prompt).await,
}
}
}
// Backend wrapper for the Responses API client.
struct ResponsesBackend {
client: ResponsesApiClient,
}
impl ResponsesBackend {
// Streams via the Responses client and boxes the resulting event stream.
async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
self.client
.stream(prompt)
.await
.map(futures::StreamExt::boxed)
}
}
// Backend wrapper for the Chat Completions client.
struct ChatBackend {
client: ChatCompletionsApiClient,
// When true, raw reasoning events pass through; otherwise deltas are aggregated.
show_reasoning: bool,
}
impl ChatBackend {
// Streams via the Chat client, choosing streaming vs aggregated mode
// based on whether raw reasoning should be shown.
async fn stream(&self, prompt: codex_api_client::Prompt) -> ApiClientResult<ApiClientStream> {
let stream = self.client.stream(prompt).await?;
let stream = if self.show_reasoning {
stream.streaming_mode().boxed()
} else {
stream.aggregate().boxed()
};
Ok(stream)
}
}
#[allow(clippy::too_many_arguments)]
impl ModelClient {
pub fn new(
@@ -130,14 +79,14 @@ impl ModelClient {
conversation_id: ConversationId,
session_source: SessionSource,
) -> Self {
let backend = Arc::new(OnceCell::new());
let api_client = Arc::new(OnceCell::new());
Self {
config,
auth_manager,
otel_event_manager,
provider,
backend,
api_client,
conversation_id,
effort,
summary,
@@ -169,26 +118,14 @@ impl ModelClient {
pub async fn stream(&self, prompt: &Prompt) -> Result<ResponseStream> {
let api_prompt = self.build_api_prompt(prompt)?;
if self.provider.wire_api == WireApi::Responses
&& let Some(path) = &*CODEX_RS_SSE_FIXTURE
{
warn!(path, "Streaming from fixture");
let stream =
stream_from_fixture(path, self.provider.clone(), self.otel_event_manager.clone())
.await
.map_err(map_api_error)?
.boxed();
return Ok(wrap_stream(stream));
}
let backend = self
.backend
.get_or_try_init(|| async { self.build_backend().await })
let client = self
.api_client
.get_or_try_init(|| async { self.build_api_client().await })
.await
.map_err(map_api_error)?;
let api_stream = backend.stream(api_prompt).await.map_err(map_api_error)?;
let api_stream = client.stream(api_prompt).await.map_err(map_api_error)?;
Ok(wrap_stream(api_stream))
}
@@ -209,15 +146,16 @@ impl ModelClient {
self.summary,
);
if !self.config.model_family.support_verbosity && self.config.model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
self.config.model_family.family
);
}
let verbosity = if self.config.model_family.support_verbosity {
self.config.model_verbosity
} else {
if self.config.model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
self.config.model_family.family
);
}
None
};
@@ -236,53 +174,32 @@ impl ModelClient {
})
}
async fn build_backend(&self) -> ApiClientResult<ModelBackend> {
match self.provider.wire_api {
WireApi::Responses => self.build_responses_backend().await,
WireApi::Chat => self.build_chat_backend().await,
}
}
async fn build_responses_backend(&self) -> ApiClientResult<ModelBackend> {
async fn build_api_client(&self) -> ApiClientResult<RoutedApiClient> {
let show_reasoning = self.config.show_raw_agent_reasoning;
let auth_provider = self.auth_manager.as_ref().map(|manager| {
Arc::new(AuthManagerProvider::new(Arc::clone(manager))) as Arc<dyn AuthProvider>
});
let responses_fixture_path: Option<PathBuf> =
CODEX_RS_SSE_FIXTURE.as_ref().map(PathBuf::from);
let http_client = create_client().clone_inner();
let config = ResponsesApiClientConfig {
let config = RoutedApiClientConfig {
http_client,
provider: self.provider.clone(),
model: self.config.model.clone(),
conversation_id: self.conversation_id,
auth_provider,
otel_event_manager: self.otel_event_manager.clone(),
};
let client = ResponsesApiClient::new(config).await?;
Ok(ModelBackend::Responses(ResponsesBackend { client }))
}
async fn build_chat_backend(&self) -> ApiClientResult<ModelBackend> {
let show_reasoning = self.config.show_raw_agent_reasoning;
let http_client = create_client().clone_inner();
let config = ChatCompletionsApiClientConfig {
http_client,
provider: self.provider.clone(),
model: self.config.model.clone(),
otel_event_manager: self.otel_event_manager.clone(),
session_source: self.session_source.clone(),
aggregation_mode: if show_reasoning {
chat_aggregation_mode: if show_reasoning {
ChatAggregationMode::Streaming
} else {
ChatAggregationMode::AggregatedOnly
},
responses_fixture_path,
};
let client = ChatCompletionsApiClient::new(config).await?;
Ok(ModelBackend::Chat(ChatBackend {
client,
show_reasoning,
}))
Ok(RoutedApiClient::new(config))
}
pub fn get_provider(&self) -> ModelProviderInfo {
@@ -358,17 +275,13 @@ impl AuthProvider for AuthManagerProvider {
}
}
fn wrap_stream(stream: ApiClientStream) -> ResponseStream {
fn wrap_stream(stream: codex_api_client::ResponseStream) -> ResponseStream {
let (tx, rx) = mpsc::channel::<Result<ResponseEvent>>(1600);
tokio::spawn(async move {
let mut stream = stream;
while let Some(item) = stream.next().await {
let mapped = match item {
Ok(event) => Ok(event),
Err(err) => Err(map_api_error(err)),
};
let mapped = item.map_err(map_api_error);
if tx.send(mapped).await.is_err() {
break;
}
@@ -415,107 +328,5 @@ pub async fn stream_for_turn(
ctx: &crate::codex::TurnContext,
prompt: &Prompt,
) -> Result<ResponseStream> {
let instructions = prompt
.get_full_instructions(&ctx.client.get_model_family())
.into_owned();
let input = prompt.get_formatted_input();
let tools = match ctx.client.get_provider().wire_api {
WireApi::Responses => create_tools_json_for_responses_api(&prompt.tools)?,
WireApi::Chat => create_tools_json_for_chat_completions_api(&prompt.tools)?,
};
let reasoning = create_reasoning_param_for_request(
&ctx.client.get_model_family(),
ctx.client.get_reasoning_effort(),
ctx.client.get_reasoning_summary(),
);
let verbosity = if ctx.client.get_model_family().support_verbosity {
ctx.client.config().model_verbosity
} else {
if ctx.client.config().model_verbosity.is_some() {
warn!(
"model_verbosity is set but ignored as the model does not support verbosity: {}",
ctx.client.get_model_family().family
);
}
None
};
let text_controls = create_text_param_for_request(verbosity, &prompt.output_schema);
let api_prompt = codex_api_client::Prompt {
instructions,
input,
tools,
parallel_tool_calls: prompt.parallel_tool_calls,
output_schema: prompt.output_schema.clone(),
reasoning,
text_controls,
prompt_cache_key: Some(ctx.client.conversation_id.to_string()),
session_source: Some(ctx.client.get_session_source()),
};
if ctx.client.get_provider().wire_api == WireApi::Responses
&& let Some(path) = &*CODEX_RS_SSE_FIXTURE
{
warn!(path, "Streaming from fixture");
let stream = stream_from_fixture(
path,
ctx.client.get_provider(),
ctx.client.get_otel_event_manager(),
)
.await
.map_err(map_api_error)?
.boxed();
return Ok(wrap_stream(stream));
}
let http_client = create_client().clone_inner();
let api_stream = match ctx.client.get_provider().wire_api {
WireApi::Responses => {
let auth_provider = ctx.client.get_auth_manager().as_ref().map(|m| {
Arc::new(AuthManagerProvider::new(Arc::clone(m))) as Arc<dyn AuthProvider>
});
let cfg = ResponsesApiClientConfig {
http_client,
provider: ctx.client.get_provider(),
model: ctx.client.get_model(),
conversation_id: ctx.client.conversation_id,
auth_provider,
otel_event_manager: ctx.client.get_otel_event_manager(),
};
let client = ResponsesApiClient::new(cfg).await.map_err(map_api_error)?;
client
.stream(api_prompt)
.await
.map_err(map_api_error)?
.boxed()
}
WireApi::Chat => {
let cfg = ChatCompletionsApiClientConfig {
http_client,
provider: ctx.client.get_provider(),
model: ctx.client.get_model(),
otel_event_manager: ctx.client.get_otel_event_manager(),
session_source: ctx.client.get_session_source(),
aggregation_mode: if ctx.client.config().show_raw_agent_reasoning {
ChatAggregationMode::Streaming
} else {
ChatAggregationMode::AggregatedOnly
},
};
let client = ChatCompletionsApiClient::new(cfg)
.await
.map_err(map_api_error)?;
client
.stream(api_prompt)
.await
.map_err(map_api_error)?
.boxed()
}
};
Ok(wrap_stream(api_stream))
ctx.client.stream(prompt).await
}