Compare commits

...

1 Commits

Author SHA1 Message Date
shijie-openai
53f9480cd8 Respect fast_mode for Fast service tier 2026-05-08 13:14:14 -07:00
6 changed files with 125 additions and 21 deletions

View File

@@ -799,6 +799,8 @@ fn get_service_tier(
account_plan_type: Option<AccountPlanType>,
fast_mode_enabled: bool,
) -> Option<String> {
let configured_service_tier =
normalize_configured_service_tier(configured_service_tier, fast_mode_enabled);
if configured_service_tier.is_some() || fast_default_opt_out || !fast_mode_enabled {
return configured_service_tier;
}
@@ -808,6 +810,19 @@ fn get_service_tier(
.then_some(ServiceTier::Fast.request_value().to_string())
}
/// Normalizes a configured service tier string before it is stored or sent.
///
/// * `None` in, `None` out.
/// * A value that parses as [`ServiceTier::Fast`] is dropped (`None`) when
///   `fast_mode_enabled` is `false`.
/// * Any other value that parses as a [`ServiceTier`] is rewritten to that
///   tier's `request_value()` string.
/// * A value that does not parse as a [`ServiceTier`] is returned unchanged.
fn normalize_configured_service_tier(
    service_tier: Option<String>,
    fast_mode_enabled: bool,
) -> Option<String> {
    // Nothing configured, nothing to normalize.
    let raw_tier = service_tier?;

    let Some(parsed_tier) = ServiceTier::from_request_value(&raw_tier) else {
        // Unrecognized values are forwarded verbatim.
        return Some(raw_tier);
    };

    // A configured Fast tier is only honored while fast mode is enabled.
    if !fast_mode_enabled && matches!(parsed_tier, ServiceTier::Fast) {
        return None;
    }

    Some(parsed_tier.request_value().to_string())
}
fn is_enterprise_default_service_tier_plan(plan_type: AccountPlanType) -> bool {
plan_type == AccountPlanType::Enterprise
|| plan_type.is_business_like()
@@ -1334,7 +1349,10 @@ impl Session {
) -> ConstraintResult<()> {
let (previous_cwd, permission_profile_changed, next_cwd, codex_home, session_source) = {
let mut state = self.state.lock().await;
let updated = match state.session_configuration.apply(&updates) {
let updated = match state
.session_configuration
.apply_with_fast_mode(&updates, self.features.enabled(Feature::FastMode))
{
Ok(updated) => updated,
Err(err) => {
warn!("rejected session settings update: {err}");
@@ -1379,7 +1397,10 @@ impl Session {
updates: &SessionSettingsUpdate,
) -> ConstraintResult<()> {
let state = self.state.lock().await;
state.session_configuration.apply(updates).map(|_| ())
state
.session_configuration
.apply_with_fast_mode(updates, self.features.enabled(Feature::FastMode))
.map(|_| ())
}
pub(crate) async fn set_session_startup_prewarm(

View File

@@ -1,7 +1,6 @@
use super::*;
use crate::goals::GoalRuntimeState;
use codex_protocol::SessionId;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::permissions::FileSystemPath;
use codex_protocol::permissions::FileSystemSpecialPath;
use codex_protocol::protocol::ThreadSource;
@@ -151,7 +150,16 @@ impl SessionConfiguration {
}
}
/// Test-only shorthand for [`Self::apply_with_fast_mode`] with fast mode
/// treated as enabled.
#[cfg(test)]
pub(crate) fn apply(&self, updates: &SessionSettingsUpdate) -> ConstraintResult<Self> {
    let fast_mode_enabled = true;
    self.apply_with_fast_mode(updates, fast_mode_enabled)
}
pub(crate) fn apply_with_fast_mode(
&self,
updates: &SessionSettingsUpdate,
fast_mode_enabled: bool,
) -> ConstraintResult<Self> {
let mut next_configuration = self.clone();
let current_sandbox_policy = self.sandbox_policy();
let current_file_system_sandbox_policy = self.file_system_sandbox_policy();
@@ -186,12 +194,8 @@ impl SessionConfiguration {
if let Some(service_tier) = updates.service_tier.clone() {
// TODO(aibrahim): Remove once v2 clients no longer send the legacy
// "fast" service tier value.
next_configuration.service_tier = service_tier.map(|service_tier| {
ServiceTier::from_request_value(&service_tier)
.map_or(service_tier, |service_tier| {
service_tier.request_value().to_string()
})
});
next_configuration.service_tier =
normalize_configured_service_tier(service_tier, fast_mode_enabled);
}
if let Some(personality) = updates.personality {
next_configuration.personality = Some(personality);

View File

@@ -3094,6 +3094,37 @@ fn get_service_tier_does_not_default_non_enterprise_or_disabled_fast_mode() {
);
}
/// With fast mode disabled, a configured Fast tier (in its request-value
/// spelling or as the literal "fast") resolves to `None`, while a configured
/// Flex tier is returned as-is.
#[test]
fn get_service_tier_filters_configured_fast_when_fast_mode_is_disabled() {
    let flex = ServiceTier::Flex.request_value().to_string();

    // (configured tier, expected resolved tier)
    let cases = [
        (ServiceTier::Fast.request_value().to_string(), None),
        ("fast".to_string(), None),
        (flex.clone(), Some(flex)),
    ];

    for (configured, expected) in cases {
        assert_eq!(
            get_service_tier(
                Some(configured),
                /*fast_default_opt_out*/ false,
                Some(AccountPlanType::Enterprise),
                /*fast_mode_enabled*/ false,
            ),
            expected
        );
    }
}
#[tokio::test]
async fn session_settings_null_service_tier_update_clears_service_tier() {
let session_configuration = make_session_configuration_for_tests().await;
@@ -3125,6 +3156,23 @@ async fn session_settings_legacy_fast_service_tier_update_uses_priority_request_
);
}
/// Applying a settings update that requests the Fast tier while fast mode is
/// disabled succeeds but leaves `service_tier` unset.
#[tokio::test]
async fn session_settings_fast_service_tier_update_is_ignored_when_fast_mode_is_disabled() {
    let base_configuration = make_session_configuration_for_tests().await;

    let fast_tier_update = SessionSettingsUpdate {
        service_tier: Some(Some(ServiceTier::Fast.request_value().to_string())),
        ..Default::default()
    };

    let updated = base_configuration
        .apply_with_fast_mode(&fast_tier_update, /*fast_mode_enabled*/ false)
        .expect("service tier update should apply");

    assert_eq!(updated.service_tier, None);
}
pub(crate) async fn make_session_configuration_for_tests() -> SessionConfiguration {
let codex_home = tempfile::tempdir().expect("create temp dir");
let config = build_test_config(codex_home.path()).await;

View File

@@ -585,7 +585,11 @@ impl Session {
) -> CodexResult<Arc<TurnContext>> {
let update_result: CodexResult<_> = {
let mut state = self.state.lock().await;
match state.session_configuration.clone().apply(&updates) {
match state
.session_configuration
.clone()
.apply_with_fast_mode(&updates, self.features.enabled(Feature::FastMode))
{
Ok(next) => {
let mut effective_environments = updates
.environments

View File

@@ -240,7 +240,7 @@ async fn websocket_v2_test_codex_shell_chain() -> Result<()> {
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_first_turn_uses_updated_fast_tier_after_startup_prewarm() -> Result<()> {
async fn websocket_v2_drops_updated_fast_tier_when_fast_mode_disabled() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![vec![
@@ -281,8 +281,8 @@ async fn websocket_v2_first_turn_uses_updated_fast_tier_after_startup_prewarm()
.body_json();
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
assert_eq!(first_turn["service_tier"].as_str(), Some("priority"));
assert_eq!(first_turn.get("previous_response_id"), None);
assert_eq!(first_turn.get("service_tier"), None);
assert_eq!(first_turn["previous_response_id"].as_str(), Some("warm-1"));
assert!(
first_turn
.get("input")
@@ -323,7 +323,7 @@ async fn websocket_v2_first_turn_drops_fast_tier_after_startup_prewarm() -> Resu
.body_json();
assert_eq!(warmup["type"].as_str(), Some("response.create"));
assert_eq!(warmup["generate"].as_bool(), Some(false));
assert_eq!(warmup["service_tier"].as_str(), Some("priority"));
assert_eq!(warmup.get("service_tier"), None);
test.submit_turn_with_service_tier("hello", /*service_tier*/ None)
.await?;
@@ -338,7 +338,7 @@ async fn websocket_v2_first_turn_drops_fast_tier_after_startup_prewarm() -> Resu
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
assert_eq!(first_turn.get("service_tier"), None);
assert_eq!(first_turn.get("previous_response_id"), None);
assert_eq!(first_turn["previous_response_id"].as_str(), Some("warm-1"));
assert!(
first_turn
.get("input")
@@ -351,14 +351,14 @@ async fn websocket_v2_first_turn_drops_fast_tier_after_startup_prewarm() -> Resu
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn websocket_v2_next_turn_uses_updated_service_tier() -> Result<()> {
async fn websocket_v2_next_turn_uses_updated_flex_service_tier() -> Result<()> {
skip_if_no_network!(Ok(()));
let server = start_websocket_server(vec![vec![
vec![ev_response_created("warm-1"), ev_completed("warm-1")],
vec![
ev_response_created("resp-1"),
ev_assistant_message("msg-1", "fast"),
ev_assistant_message("msg-1", "flex"),
ev_completed("resp-1"),
],
vec![
@@ -385,7 +385,7 @@ async fn websocket_v2_next_turn_uses_updated_service_tier() -> Result<()> {
assert_eq!(warmup["generate"].as_bool(), Some(false));
assert_eq!(warmup.get("service_tier"), None);
test.submit_turn_with_service_tier("first", Some(ServiceTier::Fast))
test.submit_turn_with_service_tier("first", Some(ServiceTier::Flex))
.await?;
test.submit_turn_with_service_tier("second", /*service_tier*/ None)
.await?;
@@ -404,7 +404,7 @@ async fn websocket_v2_next_turn_uses_updated_service_tier() -> Result<()> {
.body_json();
assert_eq!(first_turn["type"].as_str(), Some("response.create"));
assert_eq!(first_turn["service_tier"].as_str(), Some("priority"));
assert_eq!(first_turn["service_tier"].as_str(), Some("flex"));
assert_eq!(first_turn.get("previous_response_id"), None);
assert!(
first_turn

View File

@@ -298,7 +298,7 @@ async fn service_tier_change_is_applied_on_next_http_turn() -> Result<()> {
let test = test_codex().build(&server).await?;
test.submit_turn_with_service_tier("fast turn", Some(ServiceTier::Fast))
test.submit_turn_with_service_tier("flex turn", Some(ServiceTier::Flex))
.await?;
test.submit_turn_with_service_tier("standard turn", /*service_tier*/ None)
.await?;
@@ -309,7 +309,7 @@ async fn service_tier_change_is_applied_on_next_http_turn() -> Result<()> {
let first_body = requests[0].body_json();
let second_body = requests[1].body_json();
assert_eq!(first_body["service_tier"].as_str(), Some("priority"));
assert_eq!(first_body["service_tier"].as_str(), Some("flex"));
assert_eq!(second_body.get("service_tier"), None);
Ok(())
@@ -334,6 +334,33 @@ async fn flex_service_tier_is_applied_to_http_turn() -> Result<()> {
Ok(())
}
/// Submitting a turn with the Fast tier while `Feature::FastMode` is disabled
/// in the config must produce a request body without a `service_tier` field.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn fast_service_tier_is_not_applied_when_fast_mode_is_disabled() -> Result<()> {
    skip_if_no_network!(Ok(()));

    let mock_server = start_mock_server().await;
    let response_mock = mount_sse_once(&mock_server, sse_completed("resp-1")).await;

    // Build the harness with the FastMode feature switched off.
    let harness = test_codex()
        .with_config(|config| {
            config
                .features
                .disable(Feature::FastMode)
                .expect("test config should allow feature update");
        })
        .build(&mock_server)
        .await?;

    harness
        .submit_turn_with_service_tier("fast turn", Some(ServiceTier::Fast))
        .await?;

    // Exactly one request is expected; it must not carry a service tier.
    let body = response_mock.single_request().body_json();
    assert_eq!(body.get("service_tier"), None);

    Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn model_change_from_image_to_text_strips_prior_image_content() -> Result<()> {
skip_if_no_network!(Ok(()));