Compare commits

...

3 Commits

Author SHA1 Message Date
Ahmed Ibrahim
2f2ec5a5f0 wire 2025-12-04 12:59:26 -08:00
Ahmed Ibrahim
1549e513a7 wire 2025-12-04 12:59:26 -08:00
Ahmed Ibrahim
f745ae7d7f introduce-endpoint 2025-12-04 12:58:53 -08:00
4 changed files with 201 additions and 6 deletions

View File

@@ -0,0 +1,103 @@
#![allow(clippy::expect_used)]
use async_trait::async_trait;
use codex_api::ModelsClient;
use codex_api::auth::AuthProvider;
use codex_api::provider::Provider;
use codex_api::provider::RetryConfig;
use codex_api::provider::WireApi;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::StatusCode;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
/// No-op auth provider for tests: never supplies a bearer token, so the
/// client issues unauthenticated requests.
#[derive(Clone, Default)]
struct DummyAuth;
impl AuthProvider for DummyAuth {
    fn bearer_token(&self) -> Option<String> {
        // Always anonymous — auth behavior is not under test here.
        None
    }
}
/// Test transport that records the last request it executed and answers with
/// a canned `ModelsResponse` payload.
#[derive(Clone)]
struct CapturingTransport {
    // Most recently executed request, shared via Arc so the test (holding a
    // clone of this transport) can inspect what the client sent.
    last_request: Arc<Mutex<Option<Request>>>,
    // Body serialized into every successful `execute` response.
    body: Arc<ModelsResponse>,
}
#[async_trait]
impl HttpTransport for CapturingTransport {
    /// Stores the request for later inspection and replies 200 OK with the
    /// canned `ModelsResponse` serialized as JSON.
    async fn execute(&self, req: Request) -> Result<Response, TransportError> {
        let payload =
            serde_json::to_vec(self.body.as_ref()).expect("serialization should succeed");
        self.last_request
            .lock()
            .expect("lock poisoned")
            .replace(req);
        Ok(Response {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: payload.into(),
        })
    }

    /// Streaming is unsupported by this capture-only transport; any call is a
    /// test failure.
    async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
        Err(TransportError::Build("stream should not run".to_string()))
    }
}
/// Builds a minimal `Provider` pointed at `base_url`, using the Responses
/// wire API, a single attempt (429s are not retried), and short timeouts
/// suitable for unit tests.
fn provider(base_url: &str) -> Provider {
    let retry = RetryConfig {
        max_attempts: 1,
        base_delay: Duration::from_millis(1),
        retry_429: false,
        retry_5xx: true,
        retry_transport: true,
    };
    Provider {
        name: "test".to_string(),
        base_url: base_url.to_owned(),
        query_params: None,
        wire: WireApi::Responses,
        headers: HeaderMap::new(),
        retry,
        stream_idle_timeout: Duration::from_secs(1),
    }
}
/// Verifies that `list_models` targets `<base_url>/models` and forwards the
/// client version as the `client_version` query parameter.
#[tokio::test]
async fn builds_correct_url() {
    let transport = CapturingTransport {
        last_request: Arc::new(Mutex::new(None)),
        body: Arc::new(ModelsResponse { models: Vec::new() }),
    };
    let client = ModelsClient::new(
        transport.clone(),
        provider("https://example.com/api/codex"),
        DummyAuth,
    );

    client
        .list_models("0.1.2", HeaderMap::new())
        .await
        .expect("list_models should succeed");

    // The transport recorded the outbound request; check the final URL.
    let captured = transport.last_request.lock().expect("lock poisoned");
    let url = captured.as_ref().expect("request recorded").url.clone();
    assert_eq!(
        url,
        "https://example.com/api/codex/models?client_version=0.1.2"
    );
}

View File

@@ -47,7 +47,7 @@ impl ConversationManager {
conversations: Arc::new(RwLock::new(HashMap::new())),
auth_manager: auth_manager.clone(),
session_source,
models_manager: Arc::new(ModelsManager::new(auth_manager.get_auth_mode())),
models_manager: Arc::new(ModelsManager::new(auth_manager)),
}
}

View File

@@ -1,8 +1,21 @@
use std::sync::Arc;
use codex_api::ModelsClient;
use codex_api::ReqwestTransport;
use codex_app_server_protocol::AuthMode;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelPreset;
use http::HeaderMap;
use tokio::sync::RwLock;
use crate::AuthManager;
use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
use crate::auth::CodexAuth;
use crate::config::Config;
use crate::default_client::build_reqwest_client;
use crate::error::Result;
use crate::model_provider_info::ModelProviderInfo;
use crate::openai_models::model_family::ModelFamily;
use crate::openai_models::model_family::find_family_for_model;
use crate::openai_models::model_presets::builtin_model_presets;
@@ -11,24 +24,41 @@ use crate::openai_models::model_presets::builtin_model_presets;
pub struct ModelsManager {
    pub available_models: RwLock<Vec<ModelPreset>>,
    pub etag: String,
    // NOTE(review): the next two lines are the before/after of this diff hunk —
    // `auth_mode` (old) is replaced by the full `auth_manager` handle (new) so
    // the auth mode can be re-read on refresh instead of captured once.
    pub auth_mode: Option<AuthMode>,
    pub auth_manager: Arc<AuthManager>,
}
impl ModelsManager {
    // NOTE(review): diff hunk — old signature first, new signature second.
    // The constructor now takes the AuthManager and derives the auth mode
    // from it at call time.
    pub fn new(auth_mode: Option<AuthMode>) -> Self {
    pub fn new(auth_manager: Arc<AuthManager>) -> Self {
        Self {
            // old line (captured auth_mode) followed by new line (derived from
            // the live auth manager).
            available_models: RwLock::new(builtin_model_presets(auth_mode)),
            available_models: RwLock::new(builtin_model_presets(auth_manager.get_auth_mode())),
            etag: String::new(),
            auth_mode,
            auth_manager,
        }
    }
    // Rebuilds the preset list; with the new field this re-reads the current
    // auth mode instead of reusing the one captured at construction.
    pub async fn refresh_available_models(&self) {
        let models = builtin_model_presets(self.auth_mode);
        let models = builtin_model_presets(self.auth_manager.get_auth_mode());
        *self.available_models.write().await = models;
    }
    // Resolves the model family for `model`, layering config overrides on top.
    pub fn construct_model_family(&self, model: &str, config: &Config) -> ModelFamily {
        find_family_for_model(model).with_config_overrides(config)
    }
    /// Fetches the model list from the provider's /models endpoint, using the
    /// crate version as the advertised client version.
    pub async fn fetch_models_from_api(
        &self,
        provider: &ModelProviderInfo,
    ) -> Result<Vec<ModelInfo>> {
        let api_provider = provider.to_api_provider(self.auth_manager.get_auth_mode())?;
        let api_auth = auth_provider_from_auth(self.auth_manager.auth(), provider).await?;
        let transport = ReqwestTransport::new(build_reqwest_client());
        let client = ModelsClient::new(transport, api_provider, api_auth);
        let response = client
            .list_models(env!("CARGO_PKG_VERSION"), HeaderMap::new())
            .await
            .map_err(map_api_error)?;
        Ok(response.models)
    }
}

View File

@@ -0,0 +1,62 @@
use codex_core::WireApi;
use codex_core::create_oss_provider_with_base_url;
use codex_core::openai_models::models_manager::ModelsManager;
use codex_protocol::openai_models::ClientVersion;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelVisibility;
use codex_protocol::openai_models::ModelsResponse;
use codex_protocol::openai_models::ReasoningLevel;
use codex_protocol::openai_models::ShellType;
use core_test_support::responses::mount_models_once;
use core_test_support::responses::start_mock_server;
use pretty_assertions::assert_eq;
/// End-to-end test: ModelsManager fetches the model list from a mocked
/// /models endpoint and forwards the crate version as `client_version`.
#[tokio::test]
async fn fetches_models_via_models_endpoint() {
    let server = start_mock_server().await;
    // Canned response: one model whose fields exercise every ModelInfo field.
    let body = ModelsResponse {
        models: vec![ModelInfo {
            slug: "gpt-test".to_string(),
            display_name: "gpt-test".to_string(),
            description: Some("desc".to_string()),
            default_reasoning_level: ReasoningLevel::Medium,
            supported_reasoning_levels: vec![
                ReasoningLevel::Low,
                ReasoningLevel::Medium,
                ReasoningLevel::High,
            ],
            shell_type: ShellType::ShellCommand,
            visibility: ModelVisibility::List,
            minimal_client_version: ClientVersion(0, 99, 0),
            supported_in_api: true,
            priority: 1,
        }],
    };
    let models_mock = mount_models_once(&server, body.clone()).await;
    let base_url = format!("{}/api/codex", server.uri());
    let provider = create_oss_provider_with_base_url(&base_url, WireApi::Responses);
    // NOTE(review): this call site does not match the new API in this same
    // compare view — `ModelsManager::new` now takes an `Arc<AuthManager>`,
    // not `None`. Confirm this test was updated in a later commit.
    let manager = ModelsManager::new(None);
    // NOTE(review): `fetch_models_from_api` is declared with a single
    // `&ModelProviderInfo` parameter in this diff, but is called here with a
    // trailing `None` — one of the two sides looks stale; verify.
    let models = manager
        .fetch_models_from_api(&provider, None)
        .await
        .expect("fetch models");
    assert_eq!(models, body.models);
    let request = models_mock
        .requests()
        .into_iter()
        .next()
        .expect("models request captured");
    // Path check: base URL prefix is preserved and "/models" is appended.
    assert_eq!(request.url.path(), "/api/codex/models");
    let client_version = request
        .url
        .query_pairs()
        .find(|(k, _)| k == "client_version")
        .map(|(_, v)| v.to_string())
        .expect("client_version query param");
    assert_eq!(client_version, env!("CARGO_PKG_VERSION"));
}