mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
test
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
@@ -58,8 +57,8 @@ use crate::tools::spec::create_tools_json_for_responses_api;
|
||||
pub struct ModelClient {
|
||||
config: Arc<Config>,
|
||||
auth_manager: Option<Arc<AuthManager>>,
|
||||
model_family: RwLock<ModelFamily>,
|
||||
models_etag: RwLock<Option<String>>,
|
||||
model_family: ModelFamily,
|
||||
models_etag: Option<String>,
|
||||
otel_manager: OtelManager,
|
||||
provider: ModelProviderInfo,
|
||||
conversation_id: ConversationId,
|
||||
@@ -85,8 +84,8 @@ impl ModelClient {
|
||||
Self {
|
||||
config,
|
||||
auth_manager,
|
||||
model_family: RwLock::new(model_family),
|
||||
models_etag: RwLock::new(models_etag),
|
||||
model_family,
|
||||
models_etag,
|
||||
otel_manager,
|
||||
provider,
|
||||
conversation_id,
|
||||
@@ -112,22 +111,6 @@ impl ModelClient {
|
||||
&self.provider
|
||||
}
|
||||
|
||||
pub fn update_models_etag(&self, models_etag: Option<String>) {
|
||||
let mut guard = self
|
||||
.models_etag
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = models_etag;
|
||||
}
|
||||
|
||||
pub fn update_model_family(&self, model_family: ModelFamily) {
|
||||
let mut guard = self
|
||||
.model_family
|
||||
.write()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
*guard = model_family;
|
||||
}
|
||||
|
||||
/// Streams a single model turn using either the Responses or Chat
|
||||
/// Completions wire API, depending on the configured provider.
|
||||
///
|
||||
@@ -167,7 +150,7 @@ impl ModelClient {
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let instructions = prompt.get_full_instructions(model_family).into_owned();
|
||||
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
|
||||
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
|
||||
let conversation_id = self.conversation_id.to_string();
|
||||
@@ -221,7 +204,7 @@ impl ModelClient {
|
||||
|
||||
let auth_manager = self.auth_manager.clone();
|
||||
let model_family = self.get_model_family();
|
||||
let instructions = prompt.get_full_instructions(&model_family).into_owned();
|
||||
let instructions = prompt.get_full_instructions(model_family).into_owned();
|
||||
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
|
||||
|
||||
let reasoning = if model_family.supports_reasoning_summaries {
|
||||
@@ -282,7 +265,7 @@ impl ModelClient {
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id.clone()),
|
||||
session_source: Some(session_source.clone()),
|
||||
extra_headers: beta_feature_headers(&self.config, self.get_models_etag()),
|
||||
extra_headers: beta_feature_headers(&self.config, self.get_models_etag().clone()),
|
||||
};
|
||||
|
||||
let stream_result = client
|
||||
@@ -322,18 +305,12 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Returns the currently configured model family.
|
||||
pub fn get_model_family(&self) -> ModelFamily {
|
||||
self.model_family
|
||||
.read()
|
||||
.map(|model_family| model_family.clone())
|
||||
.unwrap_or_else(|err| err.into_inner().clone())
|
||||
pub fn get_model_family(&self) -> &ModelFamily {
|
||||
&self.model_family
|
||||
}
|
||||
|
||||
fn get_models_etag(&self) -> Option<String> {
|
||||
self.models_etag
|
||||
.read()
|
||||
.map(|models_etag| models_etag.clone())
|
||||
.unwrap_or_else(|err| err.into_inner().clone())
|
||||
fn get_models_etag(&self) -> &Option<String> {
|
||||
&self.models_etag
|
||||
}
|
||||
|
||||
/// Returns the current reasoning effort setting.
|
||||
@@ -370,7 +347,7 @@ impl ModelClient {
|
||||
.with_telemetry(Some(request_telemetry));
|
||||
|
||||
let instructions = prompt
|
||||
.get_full_instructions(&self.get_model_family())
|
||||
.get_full_instructions(self.get_model_family())
|
||||
.into_owned();
|
||||
let payload = ApiCompactionInput {
|
||||
model: &self.get_model(),
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use crate::client_common::tools::ToolSpec;
|
||||
use crate::codex::Session;
|
||||
use crate::codex::TurnContext;
|
||||
use crate::error::Result;
|
||||
use crate::features::Feature;
|
||||
use crate::openai_models::model_family::ModelFamily;
|
||||
use crate::tools::ToolRouter;
|
||||
pub use codex_api::common::ResponseEvent;
|
||||
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
|
||||
use codex_protocol::models::ResponseItem;
|
||||
@@ -44,6 +48,27 @@ pub struct Prompt {
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub(crate) fn new(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
router: &ToolRouter,
|
||||
input: &[ResponseItem],
|
||||
) -> Prompt {
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.supports_parallel_tool_calls;
|
||||
|
||||
Prompt {
|
||||
input: input.to_vec(),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls: model_supports_parallel
|
||||
&& sess.enabled(Feature::ParallelToolCalls),
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
|
||||
let base = self
|
||||
.base_instructions_override
|
||||
|
||||
@@ -2226,7 +2226,7 @@ fn errors_to_info(errors: &[SkillError]) -> Vec<SkillErrorInfo> {
|
||||
///
|
||||
pub(crate) async fn run_task(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
mut turn_context: Arc<TurnContext>,
|
||||
input: Vec<UserInput>,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Option<String> {
|
||||
@@ -2309,7 +2309,7 @@ pub(crate) async fn run_task(
|
||||
.collect::<Vec<String>>();
|
||||
match run_turn(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
&mut turn_context,
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
turn_input,
|
||||
cancellation_token.child_token(),
|
||||
@@ -2386,7 +2386,7 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
|
||||
)]
|
||||
async fn run_turn(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
turn_context: &mut Arc<TurnContext>,
|
||||
turn_diff_tracker: SharedTurnDiffTracker,
|
||||
input: Vec<ResponseItem>,
|
||||
cancellation_token: CancellationToken,
|
||||
@@ -2399,19 +2399,20 @@ async fn run_turn(
|
||||
.list_all_tools()
|
||||
.or_cancel(&cancellation_token)
|
||||
.await?;
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
));
|
||||
|
||||
let mut retries = 0;
|
||||
loop {
|
||||
let prompt = build_prompt(
|
||||
let router = Arc::new(ToolRouter::from_config(
|
||||
&turn_context.tools_config,
|
||||
Some(
|
||||
mcp_tools
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|(name, tool)| (name, tool.tool))
|
||||
.collect(),
|
||||
),
|
||||
));
|
||||
let prompt = Prompt::new(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
router.as_ref(),
|
||||
@@ -2421,7 +2422,7 @@ async fn run_turn(
|
||||
match try_run_turn(
|
||||
Arc::clone(&router),
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(&turn_context),
|
||||
Arc::clone(turn_context),
|
||||
Arc::clone(&turn_diff_tracker),
|
||||
&prompt,
|
||||
cancellation_token.child_token(),
|
||||
@@ -2437,13 +2438,13 @@ async fn run_turn(
|
||||
Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
|
||||
Err(e @ CodexErr::Fatal(_)) => return Err(e),
|
||||
Err(e @ CodexErr::ContextWindowExceeded) => {
|
||||
sess.set_total_tokens_full(&turn_context).await;
|
||||
sess.set_total_tokens_full(turn_context).await;
|
||||
return Err(e);
|
||||
}
|
||||
Err(CodexErr::UsageLimitReached(e)) => {
|
||||
let rate_limits = e.rate_limits.clone();
|
||||
if let Some(rate_limits) = rate_limits {
|
||||
sess.update_rate_limits(&turn_context, rate_limits).await;
|
||||
sess.update_rate_limits(turn_context, rate_limits).await;
|
||||
}
|
||||
return Err(CodexErr::UsageLimitReached(e));
|
||||
}
|
||||
@@ -2459,7 +2460,23 @@ async fn run_turn(
|
||||
retries += 1;
|
||||
// Refresh models if we got an outdated models error
|
||||
if matches!(e, CodexErr::OutdatedModels) {
|
||||
refresh_models_after_outdated_error(sess.as_ref(), turn_context.as_ref())
|
||||
let config = {
|
||||
let state = sess.state.lock().await;
|
||||
state
|
||||
.session_configuration
|
||||
.original_config_do_not_use
|
||||
.clone()
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.services
|
||||
.models_manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
{
|
||||
error!("failed to refresh models after outdated models error: {err}");
|
||||
}
|
||||
*turn_context = sess
|
||||
.new_default_turn_with_sub_id(turn_context.sub_id.clone())
|
||||
.await;
|
||||
}
|
||||
let delay = match e {
|
||||
@@ -2474,7 +2491,7 @@ async fn run_turn(
|
||||
// user understands what is happening instead of staring
|
||||
// at a seemingly frozen screen.
|
||||
sess.notify_stream_error(
|
||||
&turn_context,
|
||||
turn_context,
|
||||
format!("Reconnecting... {retries}/{max_retries}"),
|
||||
e,
|
||||
)
|
||||
@@ -2489,53 +2506,6 @@ async fn run_turn(
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(
|
||||
sess: &Session,
|
||||
turn_context: &TurnContext,
|
||||
router: &ToolRouter,
|
||||
input: &[ResponseItem],
|
||||
) -> Prompt {
|
||||
let model_supports_parallel = turn_context
|
||||
.client
|
||||
.get_model_family()
|
||||
.supports_parallel_tool_calls;
|
||||
|
||||
Prompt {
|
||||
input: input.to_vec(),
|
||||
tools: router.specs(),
|
||||
parallel_tool_calls: model_supports_parallel && sess.enabled(Feature::ParallelToolCalls),
|
||||
base_instructions_override: turn_context.base_instructions.clone(),
|
||||
output_schema: turn_context.final_output_json_schema.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_models_after_outdated_error(sess: &Session, turn_context: &TurnContext) {
|
||||
let config = {
|
||||
let state = sess.state.lock().await;
|
||||
state
|
||||
.session_configuration
|
||||
.original_config_do_not_use
|
||||
.clone()
|
||||
};
|
||||
if let Err(err) = sess
|
||||
.services
|
||||
.models_manager
|
||||
.refresh_available_models(&config)
|
||||
.await
|
||||
{
|
||||
error!("failed to refresh models after outdated models error: {err}");
|
||||
}
|
||||
let models_etag = sess.services.models_manager.get_models_etag().await;
|
||||
let model = turn_context.client.get_model();
|
||||
let model_family = sess
|
||||
.services
|
||||
.models_manager
|
||||
.construct_model_family(&model, &config)
|
||||
.await;
|
||||
turn_context.client.update_model_family(model_family);
|
||||
turn_context.client.update_models_etag(models_etag);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct TurnRunResult {
|
||||
needs_follow_up: bool,
|
||||
|
||||
@@ -530,6 +530,17 @@ async fn remote_models_refresh_etag_after_outdated_models() -> Result<()> {
|
||||
})
|
||||
.await?;
|
||||
|
||||
let stream_error =
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::StreamError(_))).await;
|
||||
let EventMsg::StreamError(stream_error) = stream_error else {
|
||||
unreachable!();
|
||||
};
|
||||
assert!(
|
||||
stream_error.message.starts_with("Reconnecting..."),
|
||||
"unexpected stream error message: {}",
|
||||
stream_error.message
|
||||
);
|
||||
|
||||
wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
|
||||
|
||||
// Phase 3c: assert the refresh happened and the ETag was updated.
|
||||
|
||||
Reference in New Issue
Block a user