jimmyfraiture
2025-09-29 10:49:19 +01:00
parent 491ba05f71
commit 2efe961ac1
16 changed files with 265 additions and 363 deletions

View File

@@ -1,19 +1,29 @@
pub mod config_types;
pub mod exec_command;
pub mod function_tool;
pub mod model_family;
pub mod model_provider;
pub mod notifications;
pub mod runtime;
pub mod runtime_config;
pub mod services;
pub mod shell;
pub mod token_data;
pub mod tooling;
pub mod truncate;
pub mod turn_diff_tracker;
pub mod unified_exec;
pub use config_types::*;
pub use function_tool::*;
pub use model_family::*;
pub use model_provider::*;
pub use notifications::*;
pub use runtime::*;
pub use runtime_config::*;
pub use services::*;
pub use shell::*;
pub use token_data::*;
pub use tooling::*;
pub use turn_diff_tracker::*;
pub use unified_exec::*;
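With every submodule re-exported at the crate root, downstream crates can pull the shared types straight from `codex_agent`; a minimal illustrative import, assuming the crate is consumed under that name (one import per line, matching the codebase style):
use codex_agent::AgentConfig;
use codex_agent::AgentRuntime;
use codex_agent::ApplyPatchToolType;
use codex_agent::ModelFamily;
use codex_agent::ModelProviderInfo;
use codex_agent::WireApi;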

View File

@@ -0,0 +1,15 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tooling::ApplyPatchToolType;
/// Metadata describing consistent behaviour across a family of models.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
pub slug: String,
pub family: String,
pub needs_special_apply_patch_instructions: bool,
pub supports_reasoning_summaries: bool,
pub reasoning_summary_format: ReasoningSummaryFormat,
pub uses_local_shell_tool: bool,
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
pub base_instructions: String,
}
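A short sketch (not part of the commit) of how a host might branch on these fields when wiring up patch tooling; the helper name is hypothetical:
fn apply_patch_strategy(family: &ModelFamily) -> &'static str {
    match family.apply_patch_tool_type {
        // The model wants apply_patch exposed as a dedicated tool.
        Some(ApplyPatchToolType::Freeform) => "freeform tool",
        Some(ApplyPatchToolType::Function) => "function tool",
        // No dedicated tool, but the prompt needs extra apply_patch guidance.
        None if family.needs_special_apply_patch_instructions => "bash with extra instructions",
        None => "plain bash",
    }
}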

View File

@@ -0,0 +1,54 @@
use std::collections::HashMap;
use codex_protocol::mcp_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
/// Wire protocol variants supported by model providers.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
Responses,
#[default]
Chat,
}
/// Serializable representation of a provider definition shared across hosts.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
pub name: String,
pub base_url: Option<String>,
pub env_key: Option<String>,
pub env_key_instructions: Option<String>,
#[serde(default)]
pub wire_api: WireApi,
pub query_params: Option<HashMap<String, String>>,
pub http_headers: Option<HashMap<String, String>>,
pub env_http_headers: Option<HashMap<String, String>>,
pub request_max_retries: Option<u64>,
pub stream_max_retries: Option<u64>,
pub stream_idle_timeout_ms: Option<u64>,
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
pub fn wire_api(&self) -> WireApi {
self.wire_api
}
pub fn requires_auth(&self) -> bool {
self.requires_openai_auth
}
pub fn base_url(&self, auth_mode: AuthMode) -> String {
let fallback = if auth_mode == AuthMode::ChatGPT {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
self.base_url
.clone()
.unwrap_or_else(|| fallback.to_string())
}
}
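The fallback in `base_url` can be exercised directly; a hedged sketch, assuming `AuthMode` exposes an `ApiKey` variant alongside `ChatGPT`:
#[test]
fn base_url_falls_back_by_auth_mode() {
    let provider = ModelProviderInfo {
        name: "OpenAI".into(),
        base_url: None,
        env_key: None,
        env_key_instructions: None,
        wire_api: WireApi::Responses,
        query_params: None,
        http_headers: None,
        env_http_headers: None,
        request_max_retries: None,
        stream_max_retries: None,
        stream_idle_timeout_ms: None,
        requires_openai_auth: true,
    };
    // No explicit base_url: ChatGPT auth routes to the Codex backend,
    // anything else to the public OpenAI endpoint.
    assert_eq!(
        provider.base_url(AuthMode::ChatGPT),
        "https://chatgpt.com/backend-api/codex"
    );
    assert_eq!(provider.base_url(AuthMode::ApiKey), "https://api.openai.com/v1");
}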

View File

@@ -0,0 +1,16 @@
use async_trait::async_trait;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::Op;
use codex_protocol::protocol::Submission;
/// Minimal async interface for interacting with an agent runtime.
#[async_trait]
pub trait AgentRuntime: Send + Sync {
type Error: std::error::Error + Send + Sync + 'static;
async fn submit(&self, op: Op) -> Result<String, Self::Error>;
async fn submit_with_id(&self, submission: Submission) -> Result<(), Self::Error>;
async fn next_event(&self) -> Result<Event, Self::Error>;
}
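Host loops can now be written generically over any runtime; a minimal sketch (the helper name is ours, and the stop condition is elided because `Event`'s variants live in codex_protocol):
async fn submit_and_wait<R: AgentRuntime>(runtime: &R, op: Op) -> Result<Event, R::Error> {
    // submit returns the id the runtime assigned to this submission.
    let _submission_id = runtime.submit(op).await?;
    // A real host would loop here until the event stream signals completion.
    runtime.next_event().await
}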

View File

@@ -0,0 +1,46 @@
use std::collections::HashMap;
use std::path::PathBuf;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
use crate::config_types::ShellEnvironmentPolicy;
use crate::model_family::ModelFamily;
use crate::model_provider::ModelProviderInfo;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Verbosity;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
/// Configuration surface consumed by the agent runtime regardless of host.
#[derive(Debug, Clone, PartialEq)]
pub struct AgentConfig {
pub model: String,
pub review_model: String,
pub model_family: ModelFamily,
pub model_context_window: Option<u64>,
pub model_auto_compact_token_limit: Option<i64>,
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: ReasoningSummary,
pub model_verbosity: Option<Verbosity>,
pub model_provider: ModelProviderInfo,
pub approval_policy: AskForApproval,
pub sandbox_policy: SandboxPolicy,
pub shell_environment_policy: ShellEnvironmentPolicy,
pub user_instructions: Option<String>,
pub base_instructions: Option<String>,
pub notify: Option<Vec<String>>,
pub cwd: PathBuf,
pub codex_home: PathBuf,
pub history: History,
pub mcp_servers: HashMap<String, McpServerConfig>,
pub include_plan_tool: bool,
pub include_apply_patch_tool: bool,
pub include_view_image_tool: bool,
pub tools_web_search_request: bool,
pub use_experimental_streamable_shell_tool: bool,
pub use_experimental_unified_exec_tool: bool,
pub show_raw_agent_reasoning: bool,
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub project_doc_max_bytes: usize,
}
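An illustrative helper (an assumption, not part of the commit) showing how the runtime might resolve effective base instructions from this struct:
fn effective_base_instructions(config: &AgentConfig) -> &str {
    // Prefer an explicit override; otherwise fall back to the model family default.
    config
        .base_instructions
        .as_deref()
        .unwrap_or(&config.model_family.base_instructions)
}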

View File

@@ -0,0 +1,10 @@
use serde::Deserialize;
use serde::Serialize;
/// Represents which apply_patch tool variant a model expects.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
Freeform,
Function,
}
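The `snake_case` rename controls the wire spelling; a quick sketch, assuming `serde_json` is available as a dev-dependency:
#[test]
fn apply_patch_tool_type_round_trips_in_snake_case() {
    assert_eq!(
        serde_json::to_string(&ApplyPatchToolType::Freeform).unwrap(),
        "\"freeform\""
    );
    let parsed: ApplyPatchToolType = serde_json::from_str("\"function\"").unwrap();
    assert_eq!(parsed, ApplyPatchToolType::Function);
}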

View File

@@ -1,49 +1,6 @@
use std::collections::HashMap;
use std::path::PathBuf;
pub use codex_agent::AgentConfig;
use crate::config::Config;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
use crate::config_types::ShellEnvironmentPolicy;
use crate::model_family::ModelFamily;
use crate::model_provider_info::ModelProviderInfo;
use codex_protocol::config_types::ReasoningEffort;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::Verbosity;
use codex_protocol::protocol::AskForApproval;
use codex_protocol::protocol::SandboxPolicy;
#[derive(Debug, Clone, PartialEq)]
pub struct AgentConfig {
pub model: String,
pub review_model: String,
pub model_family: ModelFamily,
pub model_context_window: Option<u64>,
pub model_auto_compact_token_limit: Option<i64>,
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: ReasoningSummary,
pub model_verbosity: Option<Verbosity>,
pub model_provider: ModelProviderInfo,
pub approval_policy: AskForApproval,
pub sandbox_policy: SandboxPolicy,
pub shell_environment_policy: ShellEnvironmentPolicy,
pub user_instructions: Option<String>,
pub base_instructions: Option<String>,
pub notify: Option<Vec<String>>,
pub cwd: PathBuf,
pub codex_home: PathBuf,
pub history: History,
pub mcp_servers: HashMap<String, McpServerConfig>,
pub include_plan_tool: bool,
pub include_apply_patch_tool: bool,
pub include_view_image_tool: bool,
pub tools_web_search_request: bool,
pub use_experimental_streamable_shell_tool: bool,
pub use_experimental_unified_exec_tool: bool,
pub show_raw_agent_reasoning: bool,
pub codex_linux_sandbox_exe: Option<PathBuf>,
pub project_doc_max_bytes: usize,
}
impl From<&Config> for AgentConfig {
fn from(config: &Config) -> Self {

View File

@@ -22,6 +22,7 @@ use crate::client_common::ResponseStream;
use crate::error::CodexErr;
use crate::error::Result;
use crate::model_family::ModelFamily;
use crate::model_provider_info::ModelProviderExt;
use crate::openai_tools::create_tools_json_for_chat_completions_api;
use crate::util::backoff;
use codex_protocol::models::ContentItem;

View File

@@ -23,6 +23,8 @@ use tracing::debug;
use tracing::trace;
use tracing::warn;
use crate::ModelProviderInfo;
use crate::WireApi;
use crate::agent_config::AgentConfig;
use crate::chat_completions::AggregateStreamExt;
use crate::chat_completions::stream_chat_completions;
@@ -38,8 +40,7 @@ use crate::error::Result;
use crate::error::UsageLimitReachedError;
use crate::flags::CODEX_RS_SSE_FIXTURE;
use crate::model_family::ModelFamily;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use crate::model_provider_info::ModelProviderExt;
use crate::openai_model_info::get_model_info;
use crate::openai_tools::create_tools_json_for_responses_api;
use crate::protocol::RateLimitSnapshot;

View File

@@ -76,6 +76,7 @@ use crate::exec_env::create_env;
use crate::mcp_connection_manager::McpConnectionManager;
use crate::mcp_tool_call::handle_mcp_tool_call;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderExt;
use crate::openai_model_info::get_model_info;
use crate::openai_tools::ApplyPatchToolArgs;
use crate::openai_tools::ToolsConfig;
@@ -262,6 +263,23 @@ impl Codex {
}
}
#[async_trait::async_trait]
impl codex_agent::AgentRuntime for Codex {
type Error = CodexErr;
async fn submit(&self, op: Op) -> CodexResult<String> {
Codex::submit(self, op).await
}
async fn submit_with_id(&self, submission: Submission) -> CodexResult<()> {
Codex::submit_with_id(self, submission).await
}
async fn next_event(&self) -> CodexResult<Event> {
Codex::next_event(self).await
}
}
use crate::state::SessionState;
/// Context for an initialized model agent

View File

@@ -7,6 +7,7 @@ use crate::Prompt;
use crate::client_common::ResponseEvent;
use crate::error::CodexErr;
use crate::error::Result as CodexResult;
use crate::model_provider_info::ModelProviderExt;
use crate::protocol::AgentMessageEvent;
use crate::protocol::CompactedItem;
use crate::protocol::ErrorEvent;

View File

@@ -1,3 +1,4 @@
use crate::ModelProviderInfo;
use crate::config_profile::ConfigProfile;
use crate::config_types::History;
use crate::config_types::McpServerConfig;
@@ -12,7 +13,6 @@ use crate::git_info::resolve_root_git_project_for_trust;
use crate::model_family::ModelFamily;
use crate::model_family::derive_default_model_family;
use crate::model_family::find_family_for_model;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::built_in_model_providers;
use crate::openai_model_info::get_model_info;
use crate::protocol::AskForApproval;

View File

@@ -42,8 +42,6 @@ mod truncate;
mod unified_exec;
mod user_instructions;
pub use model_provider_info::BUILT_IN_OSS_MODEL_PROVIDER_ID;
pub use model_provider_info::ModelProviderInfo;
pub use model_provider_info::WireApi;
pub use model_provider_info::built_in_model_providers;
pub use model_provider_info::create_oss_provider_with_base_url;
mod conversation_manager;
@@ -103,6 +101,8 @@ pub use client_common::ResponseEvent;
pub use client_common::ResponseStream;
pub use codex::compact::content_items_to_text;
pub use codex::compact::is_session_prefix_message;
pub use codex_agent::ModelProviderInfo;
pub use codex_agent::WireApi;
pub use codex_agent::services::CredentialsProvider;
pub use codex_agent::services::McpInterface;
pub use codex_agent::services::Notifier;

View File

@@ -1,48 +1,12 @@
use crate::config_types::ReasoningSummaryFormat;
use crate::tool_apply_patch::ApplyPatchToolType;
use codex_agent::ApplyPatchToolType;
pub use codex_agent::ModelFamily;
/// The `instructions` field in the payload sent to a model should always start
/// with this content.
const BASE_INSTRUCTIONS: &str = include_str!("../prompt.md");
const GPT_5_CODEX_INSTRUCTIONS: &str = include_str!("../gpt_5_codex_prompt.md");
/// A model family is a group of models that share certain characteristics.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ModelFamily {
/// The full model slug used to derive this model family, e.g.
/// "gpt-4.1-2025-04-14".
pub slug: String,
/// The model family name, e.g. "gpt-4.1". Note this should be able to be used
/// with [`crate::openai_model_info::get_model_info`].
pub family: String,
/// True if the model needs additional instructions on how to use the
/// "virtual" `apply_patch` CLI.
pub needs_special_apply_patch_instructions: bool,
// Whether the `reasoning` field can be set when making a request to this
// model family. Note it has `effort` and `summary` subfields (though
// `summary` is optional).
pub supports_reasoning_summaries: bool,
// Define if we need a special handling of reasoning summary
pub reasoning_summary_format: ReasoningSummaryFormat,
// This should be set to true when the model expects a tool named
// "local_shell" to be provided. Its contract must be understood natively by
// the model such that its description can be omitted.
// See https://platform.openai.com/docs/guides/tools-local-shell
pub uses_local_shell_tool: bool,
/// Present if the model performs better when `apply_patch` is provided as
/// a tool call instead of just a bash command
pub apply_patch_tool_type: Option<ApplyPatchToolType>,
// Instructions to use for querying the model
pub base_instructions: String,
}
macro_rules! model_family {
(
$slug:expr, $family:expr $(, $key:ident : $value:expr )* $(,)?

View File

@@ -5,17 +5,19 @@
//! 2. User-defined entries inside `~/.codex/config.toml` under the `model_providers`
//! key. These override or extend the defaults at runtime.
use crate::CodexAuth;
use crate::ProviderAuth;
use async_trait::async_trait;
pub use codex_agent::ModelProviderInfo;
pub use codex_agent::WireApi;
use codex_protocol::mcp_protocol::AuthMode;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use std::env::VarError;
use std::sync::Arc;
use std::time::Duration;
use crate::CodexAuth;
use crate::ProviderAuth;
use crate::error::EnvVarError;
const DEFAULT_STREAM_IDLE_TIMEOUT_MS: u64 = 300_000;
const DEFAULT_STREAM_MAX_RETRIES: u64 = 5;
const DEFAULT_REQUEST_MAX_RETRIES: u64 = 4;
@@ -24,100 +26,42 @@ const MAX_STREAM_MAX_RETRIES: u64 = 100;
/// Hard cap for user-configured `request_max_retries`.
const MAX_REQUEST_MAX_RETRIES: u64 = 100;
/// Wire protocol that the provider speaks. Most third-party services only
/// implement the classic OpenAI Chat Completions JSON schema, whereas OpenAI
/// itself (and a handful of others) additionally expose the more modern
/// *Responses* API. The two protocols use different request/response shapes
/// and *cannot* be auto-detected at runtime, therefore each provider entry
/// must declare which one it expects.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum WireApi {
    /// The Responses API exposed by OpenAI at `/v1/responses`.
    Responses,
    /// Regular Chat Completions compatible with `/v1/chat/completions`.
    #[default]
    Chat,
}
#[async_trait]
pub trait ModelProviderExt {
    async fn create_request_builder(
        &self,
        client: &reqwest::Client,
        auth: &Option<Arc<dyn ProviderAuth>>,
    ) -> crate::error::Result<reqwest::RequestBuilder>;
    fn get_full_url(&self, auth: &Option<Arc<dyn ProviderAuth>>) -> String;
    fn is_azure_responses_endpoint(&self) -> bool;
    fn apply_http_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder;
    fn api_key(&self) -> crate::error::Result<Option<String>>;
    fn request_max_retries(&self) -> u64;
    fn stream_max_retries(&self) -> u64;
    fn stream_idle_timeout(&self) -> Duration;
}
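Because these helpers moved from inherent methods to an extension trait, call sites must now bring the trait into scope; that is why `use crate::model_provider_info::ModelProviderExt;` appears across the files above. A sketch of a caller (the function name is ours):
use crate::model_provider_info::ModelProviderExt;

fn retry_budget(provider: &ModelProviderInfo) -> u64 {
    // Resolves through the trait method, not the struct field of the same name.
    provider.request_max_retries()
}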
/// Serializable representation of a provider definition.
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub struct ModelProviderInfo {
/// Friendly display name.
pub name: String,
/// Base URL for the provider's OpenAI-compatible API.
pub base_url: Option<String>,
/// Environment variable that stores the user's API key for this provider.
pub env_key: Option<String>,
/// Optional instructions to help the user get a valid value for the
/// variable and set it.
pub env_key_instructions: Option<String>,
/// Which wire protocol this provider expects.
#[serde(default)]
pub wire_api: WireApi,
/// Optional query parameters to append to the base URL.
pub query_params: Option<HashMap<String, String>>,
/// Additional HTTP headers to include in requests to this provider where
/// the (key, value) pairs are the header name and value.
pub http_headers: Option<HashMap<String, String>>,
/// Optional HTTP headers to include in requests to this provider where the
/// (key, value) pairs are the header name and _environment variable_ whose
/// value should be used. If the environment variable is not set, or the
/// value is empty, the header will not be included in the request.
pub env_http_headers: Option<HashMap<String, String>>,
/// Maximum number of times to retry a failed HTTP request to this provider.
pub request_max_retries: Option<u64>,
/// Number of times to retry reconnecting a dropped streaming response before failing.
pub stream_max_retries: Option<u64>,
/// Idle timeout (in milliseconds) to wait for activity on a streaming response before treating
/// the connection as lost.
pub stream_idle_timeout_ms: Option<u64>,
/// Does this provider require an OpenAI API key or ChatGPT login token? If true,
/// the user is presented with the login screen on first run, and the login
/// preference and token/key are stored in auth.json. If false (the default),
/// the login screen is skipped and the API key (if needed) comes from the
/// "env_key" environment variable.
#[serde(default)]
pub requires_openai_auth: bool,
}
impl ModelProviderInfo {
/// Construct a `POST` RequestBuilder for the given URL using the provided
/// reqwest Client applying:
/// • provider-specific headers (static + env based)
/// • Bearer auth header when an API key is available.
/// • Auth token for OAuth.
///
/// If the provider declares an `env_key` but the variable is missing/empty, returns an [`Err`] identical to the
/// one produced by [`ModelProviderInfo::api_key`].
pub async fn create_request_builder(
#[async_trait]
impl ModelProviderExt for ModelProviderInfo {
async fn create_request_builder(
&self,
client: &reqwest::Client,
auth: &Option<Arc<dyn ProviderAuth>>,
) -> crate::error::Result<reqwest::RequestBuilder> {
let effective_auth: Option<Arc<dyn ProviderAuth>> = match self.api_key() {
Ok(Some(key)) => Some(Arc::new(CodexAuth::from_api_key(&key))),
Ok(None) => auth.clone(),
Err(err) => {
if auth.is_some() {
auth.clone()
} else {
return Err(err);
}
}
let effective_auth: Option<Arc<dyn ProviderAuth>> = match self.api_key()? {
Some(key) => Some(Arc::new(CodexAuth::from_api_key(&key))),
None => auth.clone(),
};
let url = self.get_full_url(&effective_auth);
let mut builder = client.post(url);
if let Some(auth) = effective_auth.as_ref() {
@@ -127,26 +71,13 @@ impl ModelProviderInfo {
Ok(self.apply_http_headers(builder))
}
fn get_query_string(&self) -> String {
self.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
pub(crate) fn get_full_url(&self, auth: &Option<Arc<dyn ProviderAuth>>) -> String {
fn get_full_url(&self, auth: &Option<Arc<dyn ProviderAuth>>) -> String {
let default_base_url = if auth.as_ref().map(|a| a.mode()) == Some(AuthMode::ChatGPT) {
"https://chatgpt.com/backend-api/codex"
} else {
"https://api.openai.com/v1"
};
let query_string = self.get_query_string();
let query_string = get_query_string(self);
let base_url = self
.base_url
.clone()
@@ -158,7 +89,7 @@ impl ModelProviderInfo {
}
}
pub(crate) fn is_azure_responses_endpoint(&self) -> bool {
fn is_azure_responses_endpoint(&self) -> bool {
if self.wire_api != WireApi::Responses {
return false;
}
@@ -173,9 +104,6 @@ impl ModelProviderInfo {
.unwrap_or(false)
}
/// Apply provider-specific HTTP headers (both static and environment-based)
/// onto an existing `reqwest::RequestBuilder` and return the updated
/// builder.
fn apply_http_headers(&self, mut builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
if let Some(extra) = &self.http_headers {
for (k, v) in extra {
@@ -195,10 +123,7 @@ impl ModelProviderInfo {
builder
}
/// If `env_key` is Some, returns the API key for this provider if present
/// (and non-empty) in the environment. If `env_key` is required but
/// cannot be found, returns an error.
pub fn api_key(&self) -> crate::error::Result<Option<String>> {
fn api_key(&self) -> crate::error::Result<Option<String>> {
match &self.env_key {
Some(env_key) => {
let env_value = std::env::var(env_key);
@@ -221,28 +146,39 @@ impl ModelProviderInfo {
}
}
/// Effective maximum number of request retries for this provider.
pub fn request_max_retries(&self) -> u64 {
fn request_max_retries(&self) -> u64 {
self.request_max_retries
.unwrap_or(DEFAULT_REQUEST_MAX_RETRIES)
.min(MAX_REQUEST_MAX_RETRIES)
}
/// Effective maximum number of stream reconnection attempts for this provider.
pub fn stream_max_retries(&self) -> u64 {
fn stream_max_retries(&self) -> u64 {
self.stream_max_retries
.unwrap_or(DEFAULT_STREAM_MAX_RETRIES)
.min(MAX_STREAM_MAX_RETRIES)
}
/// Effective idle timeout for streaming responses.
pub fn stream_idle_timeout(&self) -> Duration {
fn stream_idle_timeout(&self) -> Duration {
self.stream_idle_timeout_ms
.map(Duration::from_millis)
.unwrap_or(Duration::from_millis(DEFAULT_STREAM_IDLE_TIMEOUT_MS))
}
}
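A sketch of the clamping these accessors implement, using the constants defined above (default 4, hard cap 100 for requests); the demo function is hypothetical:
fn demo_retry_caps(mut provider: ModelProviderInfo) {
    provider.request_max_retries = Some(1_000);
    assert_eq!(provider.request_max_retries(), 100); // capped at MAX_REQUEST_MAX_RETRIES
    provider.request_max_retries = None;
    assert_eq!(provider.request_max_retries(), 4); // DEFAULT_REQUEST_MAX_RETRIES
}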
fn get_query_string(provider: &ModelProviderInfo) -> String {
provider
.query_params
.as_ref()
.map_or_else(String::new, |params| {
let full_params = params
.iter()
.map(|(k, v)| format!("{k}={v}"))
.collect::<Vec<_>>()
.join("&");
format!("?{full_params}")
})
}
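The now free-standing helper yields an empty string when `query_params` is absent and a `?`-prefixed pair list otherwise; an illustrative check (the function name is ours):
fn demo_query_string(mut provider: ModelProviderInfo) {
    provider.query_params = Some(HashMap::from([(
        "api-version".to_string(),
        "2025-04-01-preview".to_string(),
    )]));
    assert_eq!(get_query_string(&provider), "?api-version=2025-04-01-preview");
    provider.query_params = None;
    assert_eq!(get_query_string(&provider), "");
}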
const DEFAULT_OLLAMA_PORT: u32 = 11434;
pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
@@ -251,20 +187,11 @@ pub const BUILT_IN_OSS_MODEL_PROVIDER_ID: &str = "oss";
pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
use ModelProviderInfo as P;
// We do not want to be in the business of adjudicating which third-party
// providers are bundled with Codex CLI, so we only include the OpenAI and
// open source ("oss") providers by default. Users are encouraged to extend
// `model_providers` in config.toml with their own providers.
[
(
"openai",
P {
name: "OpenAI".into(),
// Allow users to override the default OpenAI endpoint by
// exporting `OPENAI_BASE_URL`. This is useful when pointing
// Codex at a proxy, mock server, or Azure-style deployment
// without requiring a full TOML override for the built-in
// OpenAI provider.
base_url: std::env::var("OPENAI_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty()),
@@ -288,7 +215,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
.into_iter()
.collect(),
),
// Use global defaults for retry/timeout unless overridden in config.toml.
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
@@ -303,8 +229,6 @@ pub fn built_in_model_providers() -> HashMap<String, ModelProviderInfo> {
}
pub fn create_oss_provider() -> ModelProviderInfo {
// These CODEX_OSS_ environment variables are experimental: we may
// switch to reading values from config.toml instead.
let codex_oss_base_url = match std::env::var("CODEX_OSS_BASE_URL")
.ok()
.filter(|v| !v.trim().is_empty())
@@ -357,130 +281,11 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;
#[test]
fn test_deserialize_ollama_model_provider_toml() {
let ollama_provider_toml = r#"
name = "Ollama"
base_url = "http://localhost:11434/v1"
"#;
let expected_provider = ModelProviderInfo {
name: "Ollama".into(),
base_url: Some("http://localhost:11434/v1".into()),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(ollama_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_azure_model_provider_toml() {
let azure_provider_toml = r#"
name = "Azure"
base_url = "https://xxxxx.openai.azure.com/openai"
env_key = "AZURE_OPENAI_API_KEY"
query_params = { api-version = "2025-04-01-preview" }
"#;
let expected_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://xxxxx.openai.azure.com/openai".into()),
env_key: Some("AZURE_OPENAI_API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: Some(maplit::hashmap! {
"api-version".to_string() => "2025-04-01-preview".to_string(),
}),
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(azure_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn test_deserialize_example_model_provider_toml() {
let example_provider_toml = r#"
name = "Example"
base_url = "https://example.com"
env_key = "API_KEY"
http_headers = { "X-Example-Header" = "example-value" }
env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
"#;
let expected_provider = ModelProviderInfo {
name: "Example".into(),
base_url: Some("https://example.com".into()),
env_key: Some("API_KEY".into()),
env_key_instructions: None,
wire_api: WireApi::Chat,
query_params: None,
http_headers: Some(maplit::hashmap! {
"X-Example-Header".to_string() => "example-value".to_string(),
}),
env_http_headers: Some(maplit::hashmap! {
"X-Example-Env-Header".to_string() => "EXAMPLE_ENV_VAR".to_string(),
}),
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
};
let provider: ModelProviderInfo = toml::from_str(example_provider_toml).unwrap();
assert_eq!(expected_provider, provider);
}
#[test]
fn detects_azure_responses_base_urls() {
fn provider_for(base_url: &str) -> ModelProviderInfo {
ModelProviderInfo {
name: "test".into(),
base_url: Some(base_url.into()),
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
env_http_headers: None,
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
}
}
let positive_cases = [
"https://foo.openai.azure.com/openai",
"https://foo.openai.azure.us/openai/deployments/bar",
"https://foo.cognitiveservices.azure.cn/openai",
"https://foo.aoai.azure.com/openai",
"https://foo.openai.azure-api.net/openai",
"https://foo.z01.azurefd.net/",
];
for base_url in positive_cases {
let provider = provider_for(base_url);
assert!(
provider.is_azure_responses_endpoint(),
"expected {base_url} to be detected as Azure"
);
}
let named_provider = ModelProviderInfo {
name: "Azure".into(),
base_url: Some("https://example.com".into()),
#[tokio::test]
async fn creates_request_builder_with_auth() {
let provider = ModelProviderInfo {
name: "openai".to_string(),
base_url: None,
env_key: None,
env_key_instructions: None,
wire_api: WireApi::Responses,
@@ -490,21 +295,33 @@ env_http_headers = { "X-Example-Env-Header" = "EXAMPLE_ENV_VAR" }
request_max_retries: None,
stream_max_retries: None,
stream_idle_timeout_ms: None,
requires_openai_auth: false,
requires_openai_auth: true,
};
assert!(named_provider.is_azure_responses_endpoint());
let client = reqwest::Client::new();
let auth =
Some(Arc::new(CodexAuth::create_dummy_chatgpt_auth_for_testing())
as Arc<dyn ProviderAuth>);
let negative_cases = [
"https://api.openai.com/v1",
"https://example.com/openai",
"https://myproxy.azurewebsites.net/openai",
];
for base_url in negative_cases {
let provider = provider_for(base_url);
assert!(
!provider.is_azure_responses_endpoint(),
"expected {base_url} not to be detected as Azure"
);
}
let builder = provider
.create_request_builder(&client, &auth)
.await
.expect("builder");
let request = builder.build().expect("request");
assert_eq!(request.method(), reqwest::Method::POST);
assert_eq!(
request.url().as_str(),
"https://chatgpt.com/backend-api/codex/responses"
);
}
#[test]
fn azure_detection() {
let mut provider = create_oss_provider();
assert!(!provider.is_azure_responses_endpoint());
provider.name = "azure".to_string();
provider.wire_api = WireApi::Responses;
assert!(provider.is_azure_responses_endpoint());
}
}

View File

@@ -1,5 +1,3 @@
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use crate::openai_tools::FreeformTool;
@@ -7,16 +5,10 @@ use crate::openai_tools::FreeformToolFormat;
use crate::openai_tools::JsonSchema;
use crate::openai_tools::OpenAiTool;
use crate::openai_tools::ResponsesApiTool;
pub use codex_agent::ApplyPatchToolType;
const APPLY_PATCH_LARK_GRAMMAR: &str = include_str!("tool_apply_patch.lark");
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum ApplyPatchToolType {
Freeform,
Function,
}
/// Returns a custom tool that can be used to edit files. Well-suited for GPT-5 models
/// https://platform.openai.com/docs/guides/function-calling#custom-tools
pub(crate) fn create_apply_patch_freeform_tool() -> OpenAiTool {