mirror of
https://github.com/openai/codex.git
synced 2026-02-02 06:57:03 +00:00
Compare commits
3 Commits
dev/cc/tmp
...
wire-to-co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f2ec5a5f0 | ||
|
|
1549e513a7 | ||
|
|
f745ae7d7f |
103
codex-rs/codex-api/tests/models_endpoint.rs
Normal file
103
codex-rs/codex-api/tests/models_endpoint.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
#![allow(clippy::expect_used)]
|
||||
|
||||
use async_trait::async_trait;
|
||||
use codex_api::ModelsClient;
|
||||
use codex_api::auth::AuthProvider;
|
||||
use codex_api::provider::Provider;
|
||||
use codex_api::provider::RetryConfig;
|
||||
use codex_api::provider::WireApi;
|
||||
use codex_client::HttpTransport;
|
||||
use codex_client::Request;
|
||||
use codex_client::Response;
|
||||
use codex_client::StreamResponse;
|
||||
use codex_client::TransportError;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use http::HeaderMap;
|
||||
use http::StatusCode;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Clone, Default)]
|
||||
struct DummyAuth;
|
||||
|
||||
impl AuthProvider for DummyAuth {
|
||||
fn bearer_token(&self) -> Option<String> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CapturingTransport {
|
||||
last_request: Arc<Mutex<Option<Request>>>,
|
||||
body: Arc<ModelsResponse>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HttpTransport for CapturingTransport {
|
||||
async fn execute(&self, req: Request) -> Result<Response, TransportError> {
|
||||
*self.last_request.lock().expect("lock poisoned") = Some(req);
|
||||
let body = serde_json::to_vec(&*self.body).expect("serialization should succeed");
|
||||
Ok(Response {
|
||||
status: StatusCode::OK,
|
||||
headers: HeaderMap::new(),
|
||||
body: body.into(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream(&self, _req: Request) -> Result<StreamResponse, TransportError> {
|
||||
Err(TransportError::Build("stream should not run".to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn provider(base_url: &str) -> Provider {
|
||||
Provider {
|
||||
name: "test".to_string(),
|
||||
base_url: base_url.to_string(),
|
||||
query_params: None,
|
||||
wire: WireApi::Responses,
|
||||
headers: HeaderMap::new(),
|
||||
retry: RetryConfig {
|
||||
max_attempts: 1,
|
||||
base_delay: Duration::from_millis(1),
|
||||
retry_429: false,
|
||||
retry_5xx: true,
|
||||
retry_transport: true,
|
||||
},
|
||||
stream_idle_timeout: Duration::from_secs(1),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn builds_correct_url() {
|
||||
let response = ModelsResponse { models: Vec::new() };
|
||||
let transport = CapturingTransport {
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
body: Arc::new(response),
|
||||
};
|
||||
|
||||
let client = ModelsClient::new(
|
||||
transport.clone(),
|
||||
provider("https://example.com/api/codex"),
|
||||
DummyAuth,
|
||||
);
|
||||
|
||||
client
|
||||
.list_models("0.1.2", HeaderMap::new())
|
||||
.await
|
||||
.expect("list_models should succeed");
|
||||
|
||||
let url = transport
|
||||
.last_request
|
||||
.lock()
|
||||
.expect("lock poisoned")
|
||||
.as_ref()
|
||||
.expect("request recorded")
|
||||
.url
|
||||
.clone();
|
||||
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://example.com/api/codex/models?client_version=0.1.2"
|
||||
);
|
||||
}
|
||||
@@ -47,7 +47,7 @@ impl ConversationManager {
|
||||
conversations: Arc::new(RwLock::new(HashMap::new())),
|
||||
auth_manager: auth_manager.clone(),
|
||||
session_source,
|
||||
models_manager: Arc::new(ModelsManager::new(auth_manager.get_auth_mode())),
|
||||
models_manager: Arc::new(ModelsManager::new(auth_manager)),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_api::ModelsClient;
|
||||
use codex_api::ReqwestTransport;
|
||||
use codex_app_server_protocol::AuthMode;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelPreset;
|
||||
use http::HeaderMap;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::AuthManager;
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
use crate::auth::CodexAuth;
|
||||
use crate::config::Config;
|
||||
use crate::default_client::build_reqwest_client;
|
||||
use crate::error::Result;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::openai_models::model_family::ModelFamily;
|
||||
use crate::openai_models::model_family::find_family_for_model;
|
||||
use crate::openai_models::model_presets::builtin_model_presets;
|
||||
@@ -11,24 +24,41 @@ use crate::openai_models::model_presets::builtin_model_presets;
|
||||
pub struct ModelsManager {
|
||||
pub available_models: RwLock<Vec<ModelPreset>>,
|
||||
pub etag: String,
|
||||
pub auth_mode: Option<AuthMode>,
|
||||
pub auth_manager: Arc<AuthManager>,
|
||||
}
|
||||
|
||||
impl ModelsManager {
|
||||
pub fn new(auth_mode: Option<AuthMode>) -> Self {
|
||||
pub fn new(auth_manager: Arc<AuthManager>) -> Self {
|
||||
Self {
|
||||
available_models: RwLock::new(builtin_model_presets(auth_mode)),
|
||||
available_models: RwLock::new(builtin_model_presets(auth_manager.get_auth_mode())),
|
||||
etag: String::new(),
|
||||
auth_mode,
|
||||
auth_manager,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_available_models(&self) {
|
||||
let models = builtin_model_presets(self.auth_mode);
|
||||
let models = builtin_model_presets(self.auth_manager.get_auth_mode());
|
||||
*self.available_models.write().await = models;
|
||||
}
|
||||
|
||||
pub fn construct_model_family(&self, model: &str, config: &Config) -> ModelFamily {
|
||||
find_family_for_model(model).with_config_overrides(config)
|
||||
}
|
||||
|
||||
pub async fn fetch_models_from_api(
|
||||
&self,
|
||||
provider: &ModelProviderInfo,
|
||||
) -> Result<Vec<ModelInfo>> {
|
||||
let api_provider = provider.to_api_provider(self.auth_manager.get_auth_mode())?;
|
||||
let api_auth = auth_provider_from_auth(self.auth_manager.auth(), provider).await?;
|
||||
let transport = ReqwestTransport::new(build_reqwest_client());
|
||||
let client = ModelsClient::new(transport, api_provider, api_auth);
|
||||
|
||||
let response = client
|
||||
.list_models(env!("CARGO_PKG_VERSION"), HeaderMap::new())
|
||||
.await
|
||||
.map_err(map_api_error)?;
|
||||
|
||||
Ok(response.models)
|
||||
}
|
||||
}
|
||||
|
||||
62
codex-rs/core/tests/models_manager.rs
Normal file
62
codex-rs/core/tests/models_manager.rs
Normal file
@@ -0,0 +1,62 @@
|
||||
use codex_core::WireApi;
|
||||
use codex_core::create_oss_provider_with_base_url;
|
||||
use codex_core::openai_models::models_manager::ModelsManager;
|
||||
use codex_protocol::openai_models::ClientVersion;
|
||||
use codex_protocol::openai_models::ModelInfo;
|
||||
use codex_protocol::openai_models::ModelVisibility;
|
||||
use codex_protocol::openai_models::ModelsResponse;
|
||||
use codex_protocol::openai_models::ReasoningLevel;
|
||||
use codex_protocol::openai_models::ShellType;
|
||||
use core_test_support::responses::mount_models_once;
|
||||
use core_test_support::responses::start_mock_server;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
#[tokio::test]
|
||||
async fn fetches_models_via_models_endpoint() {
|
||||
let server = start_mock_server().await;
|
||||
let body = ModelsResponse {
|
||||
models: vec![ModelInfo {
|
||||
slug: "gpt-test".to_string(),
|
||||
display_name: "gpt-test".to_string(),
|
||||
description: Some("desc".to_string()),
|
||||
default_reasoning_level: ReasoningLevel::Medium,
|
||||
supported_reasoning_levels: vec![
|
||||
ReasoningLevel::Low,
|
||||
ReasoningLevel::Medium,
|
||||
ReasoningLevel::High,
|
||||
],
|
||||
shell_type: ShellType::ShellCommand,
|
||||
visibility: ModelVisibility::List,
|
||||
minimal_client_version: ClientVersion(0, 99, 0),
|
||||
supported_in_api: true,
|
||||
priority: 1,
|
||||
}],
|
||||
};
|
||||
let models_mock = mount_models_once(&server, body.clone()).await;
|
||||
|
||||
let base_url = format!("{}/api/codex", server.uri());
|
||||
let provider = create_oss_provider_with_base_url(&base_url, WireApi::Responses);
|
||||
let manager = ModelsManager::new(None);
|
||||
|
||||
let models = manager
|
||||
.fetch_models_from_api(&provider, None)
|
||||
.await
|
||||
.expect("fetch models");
|
||||
|
||||
assert_eq!(models, body.models);
|
||||
|
||||
let request = models_mock
|
||||
.requests()
|
||||
.into_iter()
|
||||
.next()
|
||||
.expect("models request captured");
|
||||
assert_eq!(request.url.path(), "/api/codex/models");
|
||||
|
||||
let client_version = request
|
||||
.url
|
||||
.query_pairs()
|
||||
.find(|(k, _)| k == "client_version")
|
||||
.map(|(_, v)| v.to_string())
|
||||
.expect("client_version query param");
|
||||
assert_eq!(client_version, env!("CARGO_PKG_VERSION"));
|
||||
}
|
||||
Reference in New Issue
Block a user