Wire with_remote_overrides to construct model families (#7621)

- This PR wires `with_remote_overrides` and make the
`construct_model_families` an async function
- Moves getting model family a level above to keep the function `sync`
- Updates the tests to local, offline, and `sync` helper for model
families
This commit is contained in:
Ahmed Ibrahim
2025-12-05 10:40:15 -08:00
committed by GitHub
parent 5f80ad6da8
commit d08efb1743
16 changed files with 147 additions and 108 deletions

View File

@@ -14,6 +14,7 @@ use crate::compact_remote::run_inline_remote_auto_compact_task;
use crate::exec_policy::load_exec_policy_for_features;
use crate::features::Feature;
use crate::features::Features;
use crate::openai_models::model_family::ModelFamily;
use crate::openai_models::models_manager::ModelsManager;
use crate::parse_command::parse_command;
use crate::parse_turn_item;
@@ -398,35 +399,39 @@ pub(crate) struct SessionSettingsUpdate {
}
impl Session {
fn make_turn_context(
auth_manager: Option<Arc<AuthManager>>,
models_manager: Arc<ModelsManager>,
otel_event_manager: &OtelEventManager,
provider: ModelProviderInfo,
session_configuration: &SessionConfiguration,
conversation_id: ConversationId,
sub_id: String,
) -> TurnContext {
fn build_per_turn_config(session_configuration: &SessionConfiguration) -> Config {
let config = session_configuration.original_config_do_not_use.clone();
let features = &config.features;
let mut per_turn_config = (*config).clone();
per_turn_config.model = session_configuration.model.clone();
per_turn_config.model_reasoning_effort = session_configuration.model_reasoning_effort;
per_turn_config.model_reasoning_summary = session_configuration.model_reasoning_summary;
per_turn_config.features = features.clone();
let model_family =
models_manager.construct_model_family(&per_turn_config.model, &per_turn_config);
per_turn_config.features = config.features.clone();
per_turn_config
}
#[allow(clippy::too_many_arguments)]
fn make_turn_context(
auth_manager: Option<Arc<AuthManager>>,
otel_event_manager: &OtelEventManager,
provider: ModelProviderInfo,
session_configuration: &SessionConfiguration,
mut per_turn_config: Config,
model_family: ModelFamily,
conversation_id: ConversationId,
sub_id: String,
) -> TurnContext {
if let Some(model_info) = get_model_info(&model_family) {
per_turn_config.model_context_window = Some(model_info.context_window);
}
let otel_event_manager = otel_event_manager.clone().with_model(
session_configuration.model.as_str(),
session_configuration.model.as_str(),
model_family.slug.as_str(),
);
let per_turn_config = Arc::new(per_turn_config);
let client = ModelClient::new(
Arc::new(per_turn_config.clone()),
per_turn_config.clone(),
auth_manager,
model_family.clone(),
otel_event_manager,
@@ -439,7 +444,7 @@ impl Session {
let tools_config = ToolsConfig::new(&ToolsConfigParams {
model_family: &model_family,
features,
features: &per_turn_config.features,
});
TurnContext {
@@ -452,14 +457,14 @@ impl Session {
user_instructions: session_configuration.user_instructions.clone(),
approval_policy: session_configuration.approval_policy,
sandbox_policy: session_configuration.sandbox_policy.clone(),
shell_environment_policy: config.shell_environment_policy.clone(),
shell_environment_policy: per_turn_config.shell_environment_policy.clone(),
tools_config,
final_output_json_schema: None,
codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(),
codex_linux_sandbox_exe: per_turn_config.codex_linux_sandbox_exe.clone(),
tool_call_gate: Arc::new(ReadinessFlag::new()),
exec_policy: session_configuration.exec_policy.clone(),
truncation_policy: TruncationPolicy::new(
&per_turn_config,
per_turn_config.as_ref(),
model_family.truncation_policy,
),
}
@@ -545,7 +550,9 @@ impl Session {
});
}
let model_family = models_manager.construct_model_family(&config.model, &config);
let model_family = models_manager
.construct_model_family(&config.model, &config)
.await;
// todo(aibrahim): why are we passing model here while it can change?
let otel_event_manager = OtelEventManager::new(
conversation_id,
@@ -768,12 +775,19 @@ impl Session {
session_configuration
};
let per_turn_config = Self::build_per_turn_config(&session_configuration);
let model_family = self
.services
.models_manager
.construct_model_family(&per_turn_config.model, &per_turn_config)
.await;
let mut turn_context: TurnContext = Self::make_turn_context(
Some(Arc::clone(&self.services.auth_manager)),
Arc::clone(&self.services.models_manager),
&self.services.otel_event_manager,
session_configuration.provider.clone(),
&session_configuration,
per_turn_config,
model_family,
self.conversation_id,
sub_id,
);
@@ -1907,7 +1921,8 @@ async fn spawn_review_thread(
let review_model_family = sess
.services
.models_manager
.construct_model_family(&model, &config);
.construct_model_family(&model, &config)
.await;
// For reviews, disable web_search and view_image regardless of global settings.
let mut review_features = sess.features.clone();
review_features
@@ -2812,15 +2827,12 @@ mod tests {
fn otel_event_manager(
conversation_id: ConversationId,
config: &Config,
models_manager: &ModelsManager,
model_family: &ModelFamily,
) -> OtelEventManager {
OtelEventManager::new(
conversation_id,
config.model.as_str(),
models_manager
.construct_model_family(&config.model, config)
.slug
.as_str(),
model_family.slug.as_str(),
None,
Some("test@test.com".to_string()),
Some(AuthMode::ChatGPT),
@@ -2843,9 +2855,6 @@ mod tests {
let auth_manager =
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
let models_manager = Arc::new(ModelsManager::new(auth_manager.clone()));
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &models_manager);
let session_configuration = SessionConfiguration {
provider: config.model_provider.clone(),
model: config.model.clone(),
@@ -2862,6 +2871,11 @@ mod tests {
exec_policy: Arc::new(RwLock::new(ExecPolicy::empty())),
session_source: SessionSource::Exec,
};
let per_turn_config = Session::build_per_turn_config(&session_configuration);
let model_family =
ModelsManager::construct_model_family_offline(&per_turn_config.model, &per_turn_config);
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &model_family);
let state = SessionState::new(session_configuration.clone());
@@ -2875,16 +2889,17 @@ mod tests {
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
auth_manager: auth_manager.clone(),
otel_event_manager: otel_event_manager.clone(),
models_manager: models_manager.clone(),
models_manager,
tool_approvals: Mutex::new(ApprovalStore::default()),
};
let turn_context = Session::make_turn_context(
Some(Arc::clone(&auth_manager)),
models_manager,
&otel_event_manager,
session_configuration.provider.clone(),
&session_configuration,
per_turn_config,
model_family,
conversation_id,
"turn_id".to_string(),
);
@@ -2922,9 +2937,6 @@ mod tests {
let auth_manager =
AuthManager::from_auth_for_testing(CodexAuth::from_api_key("Test API Key"));
let models_manager = Arc::new(ModelsManager::new(auth_manager.clone()));
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &models_manager);
let session_configuration = SessionConfiguration {
provider: config.model_provider.clone(),
model: config.model.clone(),
@@ -2941,6 +2953,11 @@ mod tests {
exec_policy: Arc::new(RwLock::new(ExecPolicy::empty())),
session_source: SessionSource::Exec,
};
let per_turn_config = Session::build_per_turn_config(&session_configuration);
let model_family =
ModelsManager::construct_model_family_offline(&per_turn_config.model, &per_turn_config);
let otel_event_manager =
otel_event_manager(conversation_id, config.as_ref(), &model_family);
let state = SessionState::new(session_configuration.clone());
@@ -2954,16 +2971,17 @@ mod tests {
show_raw_agent_reasoning: config.show_raw_agent_reasoning,
auth_manager: Arc::clone(&auth_manager),
otel_event_manager: otel_event_manager.clone(),
models_manager: models_manager.clone(),
models_manager,
tool_approvals: Mutex::new(ApprovalStore::default()),
};
let turn_context = Arc::new(Session::make_turn_context(
Some(Arc::clone(&auth_manager)),
models_manager,
&otel_event_manager,
session_configuration.provider.clone(),
&session_configuration,
per_turn_config,
model_family,
conversation_id,
"turn_id".to_string(),
));