mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
context in headers
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
@@ -5,6 +6,7 @@ use crate::api_bridge::CoreAuthProvider;
|
||||
use crate::api_bridge::auth_provider_from_auth;
|
||||
use crate::api_bridge::map_api_error;
|
||||
use crate::auth::UnauthorizedRecovery;
|
||||
use crate::turn_metadata::build_turn_metadata_header;
|
||||
use codex_api::AggregateStreamExt;
|
||||
use codex_api::ChatClient as ApiChatClient;
|
||||
use codex_api::CompactClient as ApiCompactClient;
|
||||
@@ -72,6 +74,7 @@ use crate::transport_manager::TransportManager;
|
||||
|
||||
pub const WEB_SEARCH_ELIGIBLE_HEADER: &str = "x-oai-web-search-eligible";
|
||||
pub const X_CODEX_TURN_STATE_HEADER: &str = "x-codex-turn-state";
|
||||
pub const X_CODEX_TURN_METADATA_HEADER: &str = "x-codex-turn-metadata";
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ModelClientState {
|
||||
@@ -108,6 +111,10 @@ pub struct ModelClientSession {
|
||||
/// keep sending it unchanged between turn requests (e.g., for retries, incremental
|
||||
/// appends, or continuation requests), and must not send it between different turns.
|
||||
turn_state: Arc<OnceLock<String>>,
|
||||
/// Turn-scoped metadata attached to every request in the turn.
|
||||
turn_metadata_header: Option<HeaderValue>,
|
||||
/// Working directory used to lazily compute turn metadata at send time.
|
||||
turn_metadata_cwd: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -141,12 +148,31 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
/// Creates a session that carries neither a precomputed turn-metadata
/// header nor a working directory from which one could be derived later.
pub fn new_session(&self) -> ModelClientSession {
    let (no_header, no_cwd) = (None, None);
    self.new_session_with_turn_metadata_and_cwd(no_header, no_cwd)
}
|
||||
|
||||
pub fn new_session_with_turn_metadata(
|
||||
&self,
|
||||
turn_metadata_header: Option<String>,
|
||||
) -> ModelClientSession {
|
||||
self.new_session_with_turn_metadata_and_cwd(turn_metadata_header, None)
|
||||
}
|
||||
|
||||
pub fn new_session_with_turn_metadata_and_cwd(
|
||||
&self,
|
||||
turn_metadata_header: Option<String>,
|
||||
turn_metadata_cwd: Option<PathBuf>,
|
||||
) -> ModelClientSession {
|
||||
let turn_metadata_header =
|
||||
turn_metadata_header.and_then(|value| HeaderValue::from_str(&value).ok());
|
||||
ModelClientSession {
|
||||
state: Arc::clone(&self.state),
|
||||
connection: None,
|
||||
websocket_last_items: Vec::new(),
|
||||
transport_manager: self.state.transport_manager.clone(),
|
||||
turn_state: Arc::new(OnceLock::new()),
|
||||
turn_metadata_header,
|
||||
turn_metadata_cwd,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -257,6 +283,21 @@ impl ModelClient {
|
||||
}
|
||||
|
||||
impl ModelClientSession {
|
||||
async fn ensure_turn_metadata_header(&mut self) {
|
||||
if self.turn_metadata_header.is_some() {
|
||||
return;
|
||||
}
|
||||
let Some(cwd) = self.turn_metadata_cwd.as_deref() else {
|
||||
return;
|
||||
};
|
||||
let Some(value) = build_turn_metadata_header(cwd).await else {
|
||||
return;
|
||||
};
|
||||
if let Ok(header_value) = HeaderValue::from_str(value.as_str()) {
|
||||
self.turn_metadata_header = Some(header_value);
|
||||
}
|
||||
}
|
||||
|
||||
/// Streams a single model turn using either the Responses or Chat
|
||||
/// Completions wire API, depending on the configured provider.
|
||||
///
|
||||
@@ -264,6 +305,9 @@ impl ModelClientSession {
|
||||
/// based on the `show_raw_agent_reasoning` flag in the config.
|
||||
pub async fn stream(&mut self, prompt: &Prompt) -> Result<ResponseStream> {
|
||||
let wire_api = self.state.provider.wire_api;
|
||||
if matches!(wire_api, WireApi::Responses) {
|
||||
self.ensure_turn_metadata_header().await;
|
||||
}
|
||||
match wire_api {
|
||||
WireApi::Responses => {
|
||||
let websocket_enabled = self.responses_websocket_enabled()
|
||||
@@ -380,7 +424,11 @@ impl ModelClientSession {
|
||||
store_override: None,
|
||||
conversation_id: Some(conversation_id),
|
||||
session_source: Some(self.state.session_source.clone()),
|
||||
extra_headers: build_responses_headers(&self.state.config, Some(&self.turn_state)),
|
||||
extra_headers: build_responses_headers(
|
||||
&self.state.config,
|
||||
Some(&self.turn_state),
|
||||
self.turn_metadata_header.as_ref(),
|
||||
),
|
||||
compression,
|
||||
turn_state: Some(Arc::clone(&self.turn_state)),
|
||||
}
|
||||
@@ -713,6 +761,7 @@ fn experimental_feature_headers(config: &Config) -> ApiHeaderMap {
|
||||
fn build_responses_headers(
|
||||
config: &Config,
|
||||
turn_state: Option<&Arc<OnceLock<String>>>,
|
||||
turn_metadata_header: Option<&HeaderValue>,
|
||||
) -> ApiHeaderMap {
|
||||
let mut headers = experimental_feature_headers(config);
|
||||
headers.insert(
|
||||
@@ -731,6 +780,9 @@ fn build_responses_headers(
|
||||
{
|
||||
headers.insert(X_CODEX_TURN_STATE_HEADER, header_value);
|
||||
}
|
||||
if let Some(header_value) = turn_metadata_header {
|
||||
headers.insert(X_CODEX_TURN_METADATA_HEADER, header_value.clone());
|
||||
}
|
||||
headers
|
||||
}
|
||||
|
||||
|
||||
@@ -119,6 +119,7 @@ use crate::error::Result as CodexResult;
|
||||
use crate::exec::StreamOutput;
|
||||
use crate::exec_policy::ExecPolicyUpdateError;
|
||||
use crate::feedback_tags;
|
||||
use crate::git_info::get_git_repo_root;
|
||||
use crate::instructions::UserInstructions;
|
||||
use crate::mcp::CODEX_APPS_MCP_SERVER_NAME;
|
||||
use crate::mcp::auth::compute_auth_statuses;
|
||||
@@ -503,8 +504,8 @@ pub(crate) struct TurnContext {
|
||||
pub(crate) tool_call_gate: Arc<ReadinessFlag>,
|
||||
pub(crate) truncation_policy: TruncationPolicy,
|
||||
pub(crate) dynamic_tools: Vec<DynamicToolSpec>,
|
||||
pub(crate) turn_metadata_header: Option<String>,
|
||||
}
|
||||
|
||||
impl TurnContext {
|
||||
pub(crate) fn resolve_path(&self, path: Option<String>) -> PathBuf {
|
||||
path.as_ref()
|
||||
@@ -705,6 +706,7 @@ impl Session {
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
truncation_policy: model_info.truncation_policy.into(),
|
||||
dynamic_tools: session_configuration.dynamic_tools.clone(),
|
||||
turn_metadata_header: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3176,6 +3178,7 @@ async fn spawn_review_thread(
|
||||
tool_call_gate: Arc::new(ReadinessFlag::new()),
|
||||
dynamic_tools: parent_turn_context.dynamic_tools.clone(),
|
||||
truncation_policy: model_info.truncation_policy.into(),
|
||||
turn_metadata_header: None,
|
||||
};
|
||||
|
||||
// Seed the child task with the review prompt as the initial user message.
|
||||
@@ -3379,7 +3382,10 @@ pub(crate) async fn run_turn(
|
||||
// many turns, from the perspective of the user, it is a single turn.
|
||||
let turn_diff_tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
|
||||
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
let mut client_session = turn_context.client.new_session_with_turn_metadata_and_cwd(
|
||||
turn_context.turn_metadata_header.clone(),
|
||||
Some(turn_context.cwd.clone()),
|
||||
);
|
||||
|
||||
loop {
|
||||
// Note that pending_input would be something like a message the user
|
||||
@@ -4442,8 +4448,6 @@ pub(super) fn get_last_assistant_message_from_turn(responses: &[ResponseItem]) -
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context;
|
||||
|
||||
use crate::git_info::get_git_repo_root;
|
||||
#[cfg(test)]
|
||||
pub(crate) use tests::make_session_and_context_with_rx;
|
||||
|
||||
|
||||
@@ -335,7 +335,10 @@ async fn drain_to_completed(
|
||||
turn_context: &TurnContext,
|
||||
prompt: &Prompt,
|
||||
) -> CodexResult<()> {
|
||||
let mut client_session = turn_context.client.new_session();
|
||||
let mut client_session = turn_context.client.new_session_with_turn_metadata_and_cwd(
|
||||
turn_context.turn_metadata_header.clone(),
|
||||
Some(turn_context.cwd.clone()),
|
||||
);
|
||||
let mut stream = client_session.stream(prompt).await?;
|
||||
loop {
|
||||
let maybe_event = stream.next().await;
|
||||
|
||||
@@ -28,6 +28,7 @@ impl EnvironmentContext {
|
||||
cwd,
|
||||
// should compare all fields except shell
|
||||
shell: _,
|
||||
..
|
||||
} = other;
|
||||
|
||||
self.cwd == *cwd
|
||||
@@ -66,6 +67,7 @@ impl EnvironmentContext {
|
||||
|
||||
let shell_name = self.shell.name();
|
||||
lines.push(format!(" <shell>{shell_name}</shell>"));
|
||||
|
||||
lines.push(ENVIRONMENT_CONTEXT_CLOSE_TAG.to_string());
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashSet;
|
||||
use std::path::Path;
|
||||
use std::path::PathBuf;
|
||||
@@ -109,6 +110,92 @@ pub async fn collect_git_info(cwd: &Path) -> Option<GitInfo> {
|
||||
Some(git_info)
|
||||
}
|
||||
|
||||
/// Collect fetch remotes in a multi-root-friendly format: {"origin": "https://..."}.
|
||||
pub async fn get_git_remote_urls(cwd: &Path) -> Option<BTreeMap<String, String>> {
|
||||
let is_git_repo = run_git_command_with_timeout(&["rev-parse", "--git-dir"], cwd)
|
||||
.await?
|
||||
.status
|
||||
.success();
|
||||
if !is_git_repo {
|
||||
return None;
|
||||
}
|
||||
|
||||
get_git_remote_urls_assume_git_repo(cwd).await
|
||||
}
|
||||
|
||||
/// Collect fetch remotes without checking whether `cwd` is in a git repo.
|
||||
pub async fn get_git_remote_urls_assume_git_repo(cwd: &Path) -> Option<BTreeMap<String, String>> {
|
||||
get_git_remote_urls_assume_git_repo_with_timeout(cwd, GIT_COMMAND_TIMEOUT).await
|
||||
}
|
||||
|
||||
/// Collect fetch remotes without checking whether `cwd` is in a git repo,
|
||||
/// using the provided timeout.
|
||||
pub async fn get_git_remote_urls_assume_git_repo_with_timeout(
|
||||
cwd: &Path,
|
||||
timeout_dur: TokioDuration,
|
||||
) -> Option<BTreeMap<String, String>> {
|
||||
let output = run_git_command_with_timeout_override(&["remote", "-v"], cwd, timeout_dur).await?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).ok()?;
|
||||
parse_git_remote_urls(stdout.as_str())
|
||||
}
|
||||
|
||||
/// Return the current HEAD commit hash without checking whether `cwd` is in a git repo.
|
||||
pub async fn get_head_commit_hash(cwd: &Path) -> Option<String> {
|
||||
get_head_commit_hash_with_timeout(cwd, GIT_COMMAND_TIMEOUT).await
|
||||
}
|
||||
|
||||
/// Return the current HEAD commit hash without checking whether `cwd` is in a
|
||||
/// git repo, using the provided timeout.
|
||||
pub async fn get_head_commit_hash_with_timeout(
|
||||
cwd: &Path,
|
||||
timeout_dur: TokioDuration,
|
||||
) -> Option<String> {
|
||||
let output =
|
||||
run_git_command_with_timeout_override(&["rev-parse", "HEAD"], cwd, timeout_dur).await?;
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8(output.stdout).ok()?;
|
||||
let hash = stdout.trim();
|
||||
if hash.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(hash.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses the output of `git remote -v` into a map of remote name to fetch URL.
///
/// Only fetch entries are considered; push entries and lines that do not
/// match the expected shape are ignored. The name and URL are split at the
/// first tab, falling back to the first space when the line has no tab.
///
/// For promisor (partial-clone) remotes git appends the configured filter
/// after the direction marker, e.g. `origin\t<url> (fetch) [blob:none]`; such
/// lines are accepted as well.
///
/// Returns `None` when no fetch remote with a non-empty URL is found.
fn parse_git_remote_urls(stdout: &str) -> Option<BTreeMap<String, String>> {
    const FETCH_MARKER: &str = " (fetch)";

    let mut remotes = BTreeMap::new();
    for line in stdout.lines() {
        // Accept the marker at the very end of the line, or directly before a
        // single trailing ` [<filter>]` group emitted for promisor remotes.
        let fetch_entry = line.strip_suffix(FETCH_MARKER).or_else(|| {
            let (head, _filter) = line.strip_suffix(']')?.rsplit_once(" [")?;
            head.strip_suffix(FETCH_MARKER)
        });
        let Some(fetch_line) = fetch_entry else {
            continue;
        };

        let Some((name, url_part)) = fetch_line
            .split_once('\t')
            .or_else(|| fetch_line.split_once(' '))
        else {
            continue;
        };

        let url = url_part.trim_start();
        if !url.is_empty() {
            remotes.insert(name.to_string(), url.to_string());
        }
    }

    if remotes.is_empty() {
        None
    } else {
        Some(remotes)
    }
}
|
||||
|
||||
/// A minimal commit summary entry used for pickers (subject + timestamp + sha).
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct CommitLogEntry {
|
||||
@@ -185,11 +272,19 @@ pub async fn git_diff_to_remote(cwd: &Path) -> Option<GitDiffToRemote> {
|
||||
|
||||
/// Run a git command with a timeout to prevent blocking on large repositories
|
||||
async fn run_git_command_with_timeout(args: &[&str], cwd: &Path) -> Option<std::process::Output> {
|
||||
let result = timeout(
|
||||
GIT_COMMAND_TIMEOUT,
|
||||
Command::new("git").args(args).current_dir(cwd).output(),
|
||||
)
|
||||
.await;
|
||||
run_git_command_with_timeout_override(args, cwd, GIT_COMMAND_TIMEOUT).await
|
||||
}
|
||||
|
||||
/// Run a git command with a caller-provided timeout.
|
||||
async fn run_git_command_with_timeout_override(
|
||||
args: &[&str],
|
||||
cwd: &Path,
|
||||
timeout_dur: TokioDuration,
|
||||
) -> Option<std::process::Output> {
|
||||
let mut command = Command::new("git");
|
||||
command.args(args).current_dir(cwd);
|
||||
command.kill_on_drop(true);
|
||||
let result = timeout(timeout_dur, command.output()).await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(output)) => Some(output),
|
||||
|
||||
@@ -101,6 +101,7 @@ pub mod state_db;
|
||||
pub mod terminal;
|
||||
mod tools;
|
||||
pub mod turn_diff_tracker;
|
||||
mod turn_metadata;
|
||||
pub use rollout::ARCHIVED_SESSIONS_SUBDIR;
|
||||
pub use rollout::INTERACTIVE_SESSION_SOURCES;
|
||||
pub use rollout::RolloutRecorder;
|
||||
@@ -131,6 +132,7 @@ pub mod util;
|
||||
|
||||
pub use apply_patch::CODEX_APPLY_PATCH_ARG1;
|
||||
pub use client::WEB_SEARCH_ELIGIBLE_HEADER;
|
||||
pub use client::X_CODEX_TURN_METADATA_HEADER;
|
||||
pub use command_safety::is_dangerous_command;
|
||||
pub use command_safety::is_safe_command;
|
||||
pub use exec_policy::ExecPolicyError;
|
||||
|
||||
52
codex-rs/core/src/turn_metadata.rs
Normal file
52
codex-rs/core/src/turn_metadata.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
|
||||
use serde::Serialize;
|
||||
use tokio::time::Duration as TokioDuration;
|
||||
|
||||
use crate::git_info::get_git_remote_urls_assume_git_repo_with_timeout;
|
||||
use crate::git_info::get_git_repo_root;
|
||||
use crate::git_info::get_head_commit_hash_with_timeout;
|
||||
|
||||
const TURN_METADATA_TIMEOUT: TokioDuration = TokioDuration::from_millis(150);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TurnMetadataWorkspace {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
associated_remote_urls: Option<BTreeMap<String, String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
latest_git_commit_hash: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct TurnMetadata {
|
||||
workspaces: BTreeMap<String, TurnMetadataWorkspace>,
|
||||
}
|
||||
|
||||
/// Build a per-turn metadata header value for the given working directory.
|
||||
///
|
||||
/// This is intentionally evaluated lazily at request send time so turns that
|
||||
/// never reach the model do not pay the git subprocess cost.
|
||||
pub(crate) async fn build_turn_metadata_header(cwd: &Path) -> Option<String> {
|
||||
let repo_root = get_git_repo_root(cwd)?;
|
||||
|
||||
// Keep git subprocess work bounded per command, without wrapping the
|
||||
// entire metadata build in a separate timeout.
|
||||
let (latest_git_commit_hash, associated_remote_urls) = tokio::join!(
|
||||
get_head_commit_hash_with_timeout(cwd, TURN_METADATA_TIMEOUT),
|
||||
get_git_remote_urls_assume_git_repo_with_timeout(cwd, TURN_METADATA_TIMEOUT)
|
||||
);
|
||||
if latest_git_commit_hash.is_none() && associated_remote_urls.is_none() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut workspaces = BTreeMap::new();
|
||||
workspaces.insert(
|
||||
repo_root.to_string_lossy().into_owned(),
|
||||
TurnMetadataWorkspace {
|
||||
associated_remote_urls,
|
||||
latest_git_commit_hash,
|
||||
},
|
||||
);
|
||||
serde_json::to_string(&TurnMetadata { workspaces }).ok()
|
||||
}
|
||||
@@ -23,6 +23,7 @@ use core_test_support::load_default_config_for_test;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::test_codex::test_codex;
|
||||
use futures::StreamExt;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tempfile::TempDir;
|
||||
use wiremock::matchers::header;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user