Compare commits

...

1 Commits

Author SHA1 Message Date
pakrym-oai
5ffd9f6c56 core: add API client factory 2026-05-06 16:07:39 -07:00
15 changed files with 252 additions and 87 deletions

1
codex-rs/Cargo.lock generated
View File

@@ -2439,6 +2439,7 @@ dependencies = [
"codex-app-server-protocol",
"codex-apply-patch",
"codex-async-utils",
"codex-client",
"codex-code-mode",
"codex-config",
"codex-connectors",

View File

@@ -32,6 +32,7 @@ codex-app-server-protocol = { workspace = true }
codex-apply-patch = { workspace = true }
codex-async-utils = { workspace = true }
codex-code-mode = { workspace = true }
codex-client = { workspace = true }
codex-connectors = { workspace = true }
codex-config = { workspace = true }
codex-core-plugins = { workspace = true }

View File

@@ -62,6 +62,7 @@ use codex_api::build_session_headers;
use codex_api::create_text_param_for_request;
use codex_api::response_create_client_metadata;
use codex_app_server_protocol::AuthMode;
use codex_client::HttpTransport;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_login::RefreshTokenError;
@@ -184,6 +185,99 @@ struct CurrentClientSetup {
api_auth: SharedAuthProvider,
}
/// Session-scoped builder of codex-api REST clients.
///
/// Only the runtime model provider handle is stored. The API provider and auth state are looked
/// up from that handle each time a client is requested, so a client built after an auth refresh
/// is constructed from the credentials current at that moment.
#[derive(Debug, Clone)]
pub struct ApiClientFactory {
    /// Model provider handle queried for auth and API provider state on every client request.
    provider: SharedModelProvider,
}
impl ApiClientFactory {
pub(crate) fn new(provider: SharedModelProvider) -> Self {
Self { provider }
}
pub async fn create<C: ApiClient<ReqwestTransport>>(&self) -> Result<C> {
Ok(self.current_setup().await?.create())
}
async fn current_setup(&self) -> Result<CurrentClientSetup> {
let auth = self.provider.auth().await;
let api_provider = self.provider.api_provider().await?;
let api_auth = self.provider.api_auth().await?;
Ok(CurrentClientSetup {
auth,
api_provider,
api_auth,
})
}
}
impl CurrentClientSetup {
    /// Instantiates a client of type `C` from this resolved setup.
    ///
    /// Each call builds a new `ReqwestTransport` via `build_reqwest_client()` and hands the
    /// client clones of the stored API provider and API auth, leaving `self` reusable.
    fn create<C>(&self) -> C
    where
        C: ApiClient<ReqwestTransport>,
    {
        let transport = ReqwestTransport::new(build_reqwest_client());
        let api_provider = self.api_provider.clone();
        let api_auth = self.api_auth.clone();
        C::from_api_parts(transport, api_provider, api_auth)
    }
}
/// A codex-api REST client that can be assembled from a transport plus resolved API
/// provider/auth parts.
///
/// Implementations are expected to be thin adapters that forward to the concrete client's
/// existing constructor. Going through this trait allows [`ApiClientFactory`] to build any
/// client type generically instead of naming each one inside the factory.
pub trait ApiClient<T>: Sized
where
    T: HttpTransport,
{
    /// Builds the client from `transport`, `api_provider`, and `api_auth`.
    fn from_api_parts(
        transport: T,
        api_provider: ApiProvider,
        api_auth: SharedAuthProvider,
    ) -> Self;
}
impl<T: HttpTransport> ApiClient<T> for ApiCompactClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}
impl<T: HttpTransport> ApiClient<T> for ApiMemoriesClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}
impl<T: HttpTransport> ApiClient<T> for ApiRealtimeCallClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}
impl<T: HttpTransport> ApiClient<T> for ApiResponsesClient<T> {
fn from_api_parts(
transport: T,
api_provider: ApiProvider,
api_auth: SharedAuthProvider,
) -> Self {
Self::new(transport, api_provider, api_auth)
}
}
#[derive(Clone, Copy)]
struct RequestRouteTelemetry {
endpoint: &'static str,
@@ -226,6 +320,7 @@ pub struct ModelClient {
/// contract and can cause routing bugs.
pub struct ModelClientSession {
client: ModelClient,
api_client_factory: ApiClientFactory,
websocket_session: WebsocketSession,
/// Turn state for sticky routing.
///
@@ -316,6 +411,31 @@ impl ModelClient {
beta_features_header: Option<String>,
) -> Self {
let model_provider = create_model_provider(provider_info, auth_manager);
Self::from_model_provider(
model_provider,
session_id,
thread_id,
installation_id,
session_source,
model_verbosity,
enable_request_compression,
include_timing_metrics,
beta_features_header,
)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn from_model_provider(
model_provider: SharedModelProvider,
session_id: SessionId,
thread_id: ThreadId,
installation_id: String,
session_source: SessionSource,
model_verbosity: Option<VerbosityConfig>,
enable_request_compression: bool,
include_timing_metrics: bool,
beta_features_header: Option<String>,
) -> Self {
let codex_api_key_env_enabled = model_provider
.auth_manager()
.as_ref()
@@ -346,8 +466,17 @@ impl ModelClient {
/// This constructor does not perform network I/O itself; the session opens a websocket lazily
/// when the first stream request is issued.
pub fn new_session(&self) -> ModelClientSession {
    // Default path: derive the client factory from this client's own provider handle.
    self.new_session_with_client_factory(ApiClientFactory::new(Arc::clone(
        &self.state.provider,
    )))
}
pub(crate) fn new_session_with_client_factory(
&self,
api_client_factory: ApiClientFactory,
) -> ModelClientSession {
ModelClientSession {
client: self.clone(),
api_client_factory,
websocket_session: self.take_cached_websocket_session(),
turn_state: Arc::new(OnceLock::new()),
}
@@ -422,6 +551,7 @@ impl ModelClient {
/// session-scoped.
pub(crate) async fn compact_conversation_history(
&self,
api_client_factory: &ApiClientFactory,
prompt: &Prompt,
model_info: &ModelInfo,
settings: CompactConversationRequestSettings,
@@ -431,8 +561,7 @@ impl ModelClient {
if prompt.input.is_empty() {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = api_client_factory.current_setup().await?;
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
@@ -463,9 +592,9 @@ impl ModelClient {
text,
..
} = request;
let client =
ApiCompactClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let client = client_setup
.create::<ApiCompactClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry));
let payload = ApiCompactionInput {
model: &model,
input: &input,
@@ -503,23 +632,23 @@ impl ModelClient {
pub(crate) async fn create_realtime_call_with_headers(
&self,
api_client_factory: &ApiClientFactory,
sdp: String,
session_config: ApiRealtimeSessionConfig,
extra_headers: ApiHeaderMap,
) -> Result<RealtimeWebrtcCallStart> {
// Create the media call over HTTP first, then retain matching auth so realtime can attach
// the server-side control WebSocket to the call id from that HTTP response.
let client_setup = self.current_client_setup().await?;
let client_setup = api_client_factory.current_setup().await?;
let mut sideband_headers = extra_headers.clone();
sideband_headers.extend(sideband_websocket_auth_headers(
client_setup.api_auth.as_ref(),
));
let transport = ReqwestTransport::new(build_reqwest_client());
let response =
ApiRealtimeCallClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.create_with_session_and_headers(sdp, session_config, extra_headers)
.await
.map_err(map_api_error)?;
let response = client_setup
.create::<ApiRealtimeCallClient<ReqwestTransport>>()
.create_with_session_and_headers(sdp, session_config, extra_headers)
.await
.map_err(map_api_error)?;
Ok(RealtimeWebrtcCallStart {
sdp: response.sdp,
call_id: response.call_id,
@@ -535,6 +664,7 @@ impl ModelClient {
/// `ModelClient` session-scoped.
pub async fn summarize_memories(
&self,
api_client_factory: &ApiClientFactory,
raw_memories: Vec<ApiRawMemory>,
model_info: &ModelInfo,
effort: Option<ReasoningEffortConfig>,
@@ -544,8 +674,7 @@ impl ModelClient {
return Ok(Vec::new());
}
let client_setup = self.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = api_client_factory.current_setup().await?;
let request_telemetry = Self::build_request_telemetry(
session_telemetry,
AuthRequestTelemetryContext::new(
@@ -556,9 +685,9 @@ impl ModelClient {
RequestRouteTelemetry::for_endpoint(MEMORIES_SUMMARIZE_ENDPOINT),
self.state.auth_env_telemetry.clone(),
);
let client =
ApiMemoriesClient::new(transport, client_setup.api_provider, client_setup.api_auth)
.with_telemetry(Some(request_telemetry));
let client = client_setup
.create::<ApiMemoriesClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry));
let payload = ApiMemorySummarizeInput {
model: model_info.slug.clone(),
@@ -747,21 +876,6 @@ impl ModelClient {
true
}
/// Returns auth + provider configuration resolved from the current session auth state.
///
/// This centralizes setup used by both prewarm and normal request paths so they stay in
/// lockstep when auth/provider resolution changes.
async fn current_client_setup(&self) -> Result<CurrentClientSetup> {
let auth = self.state.provider.auth().await;
let api_provider = self.state.provider.api_provider().await?;
let api_auth = self.state.provider.api_auth().await?;
Ok(CurrentClientSetup {
auth,
api_provider,
api_auth,
})
}
/// Opens a websocket connection using the same header and telemetry wiring as normal turns.
///
/// Both startup prewarm and in-turn `needs_new` reconnects call this path so handshake
@@ -1038,11 +1152,15 @@ impl ModelClientSession {
return Ok(());
}
let client_setup = self.client.current_client_setup().await.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
})?;
let client_setup = self
.api_client_factory
.current_setup()
.await
.map_err(|err| {
ApiError::Stream(format!(
"failed to build websocket prewarm client setup: {err}"
))
})?;
let auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
@@ -1201,8 +1319,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let transport = ReqwestTransport::new(build_reqwest_client());
let client_setup = self.api_client_factory.current_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),
@@ -1227,12 +1344,9 @@ impl ModelClientSession {
)?;
let inference_trace_attempt = inference_trace.start_attempt();
inference_trace_attempt.record_started(&request);
let client = ApiResponsesClient::new(
transport,
client_setup.api_provider,
client_setup.api_auth,
)
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let client = client_setup
.create::<ApiResponsesClient<ReqwestTransport>>()
.with_telemetry(Some(request_telemetry), Some(sse_telemetry));
let stream_result = client.stream_request(request, options).await;
match stream_result {
@@ -1314,7 +1428,7 @@ impl ModelClientSession {
.map(AuthManager::unauthorized_recovery);
let mut pending_retry = PendingUnauthorizedRetry::default();
loop {
let client_setup = self.client.current_client_setup().await?;
let client_setup = self.api_client_factory.current_setup().await?;
let request_auth_context = AuthRequestTelemetryContext::new(
client_setup.auth.as_ref().map(CodexAuth::auth_mode),
client_setup.api_auth.as_ref(),

View File

@@ -1,3 +1,4 @@
use super::ApiClientFactory;
use super::AuthRequestTelemetryContext;
use super::ModelClient;
use super::PendingUnauthorizedRetry;
@@ -11,6 +12,7 @@ use codex_api::ApiError;
use codex_api::ResponseEvent;
use codex_app_server_protocol::AuthMode;
use codex_model_provider::BearerAuthProvider;
use codex_model_provider::create_model_provider;
use codex_model_provider_info::WireApi;
use codex_model_provider_info::create_oss_provider_with_base_url;
use codex_otel::SessionTelemetry;
@@ -308,6 +310,10 @@ async fn summarize_memories_returns_empty_for_empty_input() {
let output = client
.summarize_memories(
&ApiClientFactory::new(create_model_provider(
create_oss_provider_with_base_url("https://example.com/v1", WireApi::Responses),
/*auth_manager*/ None,
)),
Vec::new(),
&model_info,
/*effort*/ None,

View File

@@ -167,7 +167,10 @@ async fn run_compact_task_inner_impl(
let max_retries = turn_context.provider.info().stream_max_retries();
let mut retries = 0;
let mut client_session = sess.services.model_client.new_session();
let mut client_session = sess
.services
.model_client
.new_session_with_client_factory(sess.services.api_client_factory.clone());
// Reuse one client session so turn-scoped state (sticky routing, websocket incremental
// request tracking)
// survives retries within this compact turn.

View File

@@ -169,6 +169,7 @@ async fn run_remote_compact_task_inner_impl(
.services
.model_client
.compact_conversation_history(
&sess.services.api_client_factory,
&prompt,
&turn_context.model_info,
CompactConversationRequestSettings {

View File

@@ -195,7 +195,10 @@ async fn run_remote_compact_task_inner_impl(
)
.await
} else {
let mut owned_client_session = sess.services.model_client.new_session();
let mut owned_client_session = sess
.services
.model_client
.new_session_with_client_factory(sess.services.api_client_factory.clone());
run_remote_compaction_request_v2(
sess,
turn_context,

View File

@@ -178,6 +178,8 @@ mod tasks;
mod user_shell_command;
pub mod util;
pub use client::ApiClient;
pub use client::ApiClientFactory;
pub use client::ModelClient;
pub use client::ModelClientSession;
pub use client::X_CODEX_INSTALLATION_ID_HEADER;

View File

@@ -1,3 +1,4 @@
use crate::client::ApiClientFactory;
use crate::client::ModelClient;
use crate::realtime_context::build_realtime_startup_context;
use crate::realtime_prompt::prepare_realtime_backend_prompt;
@@ -229,6 +230,7 @@ struct RealtimeStart {
extra_headers: Option<HeaderMap>,
session_config: RealtimeSessionConfig,
model_client: ModelClient,
api_client_factory: ApiClientFactory,
sdp: Option<String>,
}
@@ -281,6 +283,7 @@ impl RealtimeConversationManager {
extra_headers,
session_config,
model_client,
api_client_factory,
sdp,
} = start;
let event_parser = session_config.event_parser;
@@ -310,6 +313,7 @@ impl RealtimeConversationManager {
let (task, sdp) = if let Some(sdp) = sdp {
let call = model_client
.create_realtime_call_with_headers(
&api_client_factory,
sdp,
session_config.clone(),
extra_headers.unwrap_or_default(),
@@ -789,6 +793,7 @@ async fn handle_start_inner(
extra_headers,
session_config,
model_client: sess.services.model_client.clone(),
api_client_factory: sess.services.api_client_factory.clone(),
sdp,
};
let start_output = sess.conversation.start(start).await?;

View File

@@ -164,6 +164,7 @@ use tracing::instrument;
use tracing::warn;
use uuid::Uuid;
use crate::client::ApiClientFactory;
use crate::client::ModelClient;
use crate::codex_thread::ThreadConfigSnapshot;
use crate::compact::collect_user_messages;
@@ -307,6 +308,7 @@ use codex_git_utils::get_git_repo_root;
use codex_mcp::compute_auth_statuses;
use codex_mcp::host_owned_codex_apps_enabled;
use codex_mcp::with_codex_apps_mcp;
use codex_model_provider::create_model_provider;
use codex_otel::SessionTelemetry;
use codex_otel::THREAD_STARTED_METRIC;
use codex_otel::TelemetryAuthMode;

View File

@@ -795,6 +795,22 @@ impl Session {
SessionId::from(thread_id)
};
let agent_control = agent_control.with_session_id(session_id);
let model_provider = create_model_provider(
session_configuration.provider.clone(),
Some(Arc::clone(&auth_manager)),
);
let api_client_factory = ApiClientFactory::new(Arc::clone(&model_provider));
let model_client = ModelClient::from_model_provider(
model_provider,
session_id,
thread_id,
installation_id.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
);
let services = SessionServices {
// Initialize the MCP connection manager with an uninitialized
// instance. It will be replaced with one created via
@@ -836,18 +852,8 @@ impl Session {
state_db: state_db_ctx.clone(),
live_thread: live_thread_init.as_ref().cloned(),
thread_store: Arc::clone(&thread_store),
model_client: ModelClient::new(
Some(Arc::clone(&auth_manager)),
session_id,
thread_id,
installation_id.clone(),
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Self::build_model_client_beta_features_header(config.as_ref()),
),
model_client,
api_client_factory,
code_mode_service: crate::tools::code_mode::CodeModeService::new(),
environment_manager,
};

View File

@@ -3710,6 +3710,22 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
.expect("create environment"),
);
let model_provider = create_model_provider(
session_configuration.provider.clone(),
Some(auth_manager.clone()),
);
let api_client_factory = ApiClientFactory::new(Arc::clone(&model_provider));
let model_client = ModelClient::from_model_provider(
model_provider,
thread_id.into(),
thread_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
);
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new_uninitialized(
&config.permissions.approval_policy,
@@ -3759,18 +3775,8 @@ pub(crate) async fn make_session_and_context() -> (Session, TurnContext) {
.await
.expect("state db should initialize"),
)),
model_client: ModelClient::new(
Some(auth_manager.clone()),
thread_id.into(),
thread_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
),
model_client,
api_client_factory,
code_mode_service: crate::tools::code_mode::CodeModeService::new(),
environment_manager: Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
};
@@ -5397,6 +5403,22 @@ where
)
.await
.expect("state db should initialize");
let model_provider = create_model_provider(
session_configuration.provider.clone(),
Some(Arc::clone(&auth_manager)),
);
let api_client_factory = ApiClientFactory::new(Arc::clone(&model_provider));
let model_client = ModelClient::from_model_provider(
model_provider,
thread_id.into(),
thread_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
);
let services = SessionServices {
mcp_connection_manager: Arc::new(RwLock::new(McpConnectionManager::new_uninitialized(
&config.permissions.approval_policy,
@@ -5441,18 +5463,8 @@ where
codex_thread_store::LocalThreadStoreConfig::from_config(config.as_ref()),
state_db,
)),
model_client: ModelClient::new(
Some(Arc::clone(&auth_manager)),
thread_id.into(),
thread_id,
/*installation_id*/ "11111111-1111-4111-8111-111111111111".to_string(),
session_configuration.provider.clone(),
session_configuration.session_source.clone(),
config.model_verbosity,
config.features.enabled(Feature::EnableRequestCompression),
config.features.enabled(Feature::RuntimeMetrics),
Session::build_model_client_beta_features_header(config.as_ref()),
),
model_client,
api_client_factory,
code_mode_service: crate::tools::code_mode::CodeModeService::new(),
environment_manager: Arc::new(codex_exec_server::EnvironmentManager::default_for_tests()),
};

View File

@@ -148,8 +148,11 @@ pub(crate) async fn run_turn(
let model_info = turn_context.model_info.clone();
let auto_compact_limit = model_info.auto_compact_token_limit().unwrap_or(i64::MAX);
let mut client_session =
prewarmed_client_session.unwrap_or_else(|| sess.services.model_client.new_session());
let mut client_session = prewarmed_client_session.unwrap_or_else(|| {
sess.services
.model_client
.new_session_with_client_factory(sess.services.api_client_factory.clone())
});
// TODO(ccunningham): Pre-turn compaction runs before context updates and the
// new user message are recorded. Estimate pending incoming items (context
// diffs/full reinjection + user input) and trigger compaction preemptively

View File

@@ -224,7 +224,10 @@ async fn schedule_startup_prewarm_inner(
let startup_turn_metadata_header = startup_turn_context
.turn_metadata_state
.current_header_value();
let mut client_session = session.services.model_client.new_session();
let mut client_session = session
.services
.model_client
.new_session_with_client_factory(session.services.api_client_factory.clone());
client_session
.prewarm_websocket(
&startup_prompt,

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use crate::SkillsManager;
use crate::agent::AgentControl;
use crate::client::ApiClientFactory;
use crate::client::ModelClient;
use crate::config::StartedNetworkProxy;
use crate::exec_policy::ExecPolicyManager;
@@ -66,6 +67,8 @@ pub(crate) struct SessionServices {
pub(crate) thread_store: Arc<dyn ThreadStore>,
/// Session-scoped model client shared across turns.
pub(crate) model_client: ModelClient,
/// Factory used to construct codex-api clients from the session's current auth/provider state.
pub(crate) api_client_factory: ApiClientFactory,
pub(crate) code_mode_service: CodeModeService,
/// Shared process-level environment registry. Sessions carry an `Arc` handle so they can pass
/// the same manager through child-thread spawn paths without reconstructing it.