jif-oai
2025-10-30 20:44:42 +00:00
parent 482e6d7fad
commit 7a9c344bd5
3 changed files with 63 additions and 92 deletions

View File

@@ -335,7 +335,7 @@ pub(crate) struct SessionConfiguration {
cwd: PathBuf,
/// Set of feature flags for this session
features: Features,
pub features: Features,
// TODO(pakrym): Remove config from here
original_config_do_not_use: Arc<Config>,
@@ -379,12 +379,6 @@ pub(crate) struct SessionSettingsUpdate {
pub(crate) final_output_json_schema: Option<Option<Value>>,
}
#[derive(Clone)]
struct PreparedPrompt {
prompt: Prompt,
full_prompt_items: Vec<ResponseItem>,
}
impl Session {
fn make_turn_context(
auth_manager: Option<Arc<AuthManager>>,
@@ -948,51 +942,10 @@ impl Session {
self.send_raw_response_items(turn_context, items).await;
}
async fn prepare_prompt(&self, turn_context: &TurnContext) -> PreparedPrompt {
let use_chain = turn_context.client.supports_responses_api_chaining();
let mut history = self.clone_history().await;
let full_prompt_items = history.get_history_for_prompt();
let mut prompt = Prompt::default();
prompt.store_response = use_chain;
if !use_chain {
{
let mut state = self.state.lock().await;
state.reset_responses_api_chain();
}
prompt.input = full_prompt_items.clone();
return PreparedPrompt {
prompt,
full_prompt_items,
};
}
let mut previous_response_id = None;
let mut request_items = full_prompt_items.clone();
{
let mut state = self.state.lock().await;
if let Some(chain) = state.responses_api_chain()
&& let Some(prev_id) = chain.last_response_id
{
let prefix = common_prefix_len(&chain.last_prompt_items, &full_prompt_items);
if prefix == 0 && !chain.last_prompt_items.is_empty() {
state.reset_responses_api_chain();
} else {
previous_response_id = Some(prev_id);
request_items = full_prompt_items[prefix..].to_vec();
}
}
}
prompt.previous_response_id = previous_response_id;
prompt.input = request_items;
PreparedPrompt {
prompt,
full_prompt_items,
}
async fn prompt_for_turn(&self, turn_context: &TurnContext) -> Prompt {
let supports_chain = turn_context.client.supports_responses_api_chaining();
let mut state = self.state.lock().await;
state.prompt_for_turn(supports_chain)
}
fn reconstruct_history_from_rollout(
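The removed prepare_prompt assembled the Responses API chaining request inside Session; after this change prompt_for_turn only locks the session state and delegates to SessionState::prompt_for_turn, so the chaining decision sits next to the history it reads. Below is a minimal, self-contained sketch of the two request shapes that call can produce, assuming a local stand-in Prompt that mirrors only the three fields visible in this diff (store_response, previous_response_id, input), not the real crate::client_common::Prompt, and String in place of ResponseItem.

// Stand-in for crate::client_common::Prompt; only the fields shown in this
// diff are modelled, and String replaces ResponseItem for brevity.
#[derive(Debug, Default)]
struct Prompt {
    store_response: bool,
    previous_response_id: Option<String>,
    input: Vec<String>,
}

// Turn 1: no chain exists yet, so the full history is sent and the server
// is asked to store the response for later reference.
fn first_turn(full_history: Vec<String>) -> Prompt {
    Prompt {
        store_response: true,
        previous_response_id: None,
        input: full_history,
    }
}

// Turn 2+: only the items appended since the stored response are sent,
// together with the id of that response.
fn follow_up_turn(prev_id: String, new_items: Vec<String>) -> Prompt {
    Prompt {
        store_response: true,
        previous_response_id: Some(prev_id),
        input: new_items,
    }
}

fn main() {
    let p1 = first_turn(vec!["user: hi".into()]);
    let p2 = follow_up_turn("resp_123".into(), vec!["user: and now?".into()]);
    println!("{p1:?}\n{p2:?}");
}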
@@ -1035,12 +988,11 @@ impl Session {
async fn update_responses_api_chain_state(
&self,
chaining_intent: bool,
supports_responses_api_chaining: bool,
response_id: Option<String>,
prompt_items: Vec<ResponseItem>,
) {
let mut state = self.state.lock().await;
if !chaining_intent {
if !supports_responses_api_chaining {
state.reset_responses_api_chain();
return;
}
@@ -1050,6 +1002,9 @@ impl Session {
return;
};
let mut history = state.clone_history();
let prompt_items = history.get_history_for_prompt();
state.set_responses_api_chain(ResponsesApiChainState {
last_response_id: Some(response_id),
last_prompt_items: prompt_items,
@@ -1841,11 +1796,11 @@ pub(crate) async fn run_task(
// Construct the input that we will send to the model.
sess.record_conversation_items(&turn_context, &pending_input)
.await;
let prepared_prompt = sess.prepare_prompt(&turn_context).await;
let prompt = sess.prompt_for_turn(&turn_context).await;
let turn_input_messages: Vec<String> = {
prepared_prompt
.full_prompt_items
prompt
.input
.iter()
.filter_map(|item| match item {
ResponseItem::Message { content, .. } => Some(content),
@@ -1863,7 +1818,7 @@ pub(crate) async fn run_task(
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
prepared_prompt,
prompt,
cancellation_token.child_token(),
)
.await
@@ -1944,18 +1899,12 @@ pub(crate) async fn run_task(
last_agent_message
}
fn common_prefix_len(lhs: &[ResponseItem], rhs: &[ResponseItem]) -> usize {
lhs.iter()
.zip(rhs.iter())
.take_while(|(l, r)| l == r)
.count()
}
async fn run_turn(
sess: Arc<Session>,
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
mut prepared_prompt: PreparedPrompt,
mut prompt: Prompt,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let mcp_tools = sess.services.mcp_connection_manager.list_all_tools();
@@ -1967,17 +1916,7 @@ async fn run_turn(
let tool_specs = router.specs();
let (tools_json, has_freeform_apply_patch) =
crate::tools::spec::tools_metadata_for_prompt(&tool_specs)?;
crate::conversation_history::format_prompt_items(
&mut prepared_prompt.prompt.input,
has_freeform_apply_patch,
);
crate::conversation_history::format_prompt_items(
&mut prepared_prompt.full_prompt_items,
has_freeform_apply_patch,
);
let mut prompt = prepared_prompt.prompt;
let full_prompt_items = prepared_prompt.full_prompt_items;
crate::conversation_history::format_prompt_items(&mut prompt.input, has_freeform_apply_patch);
let apply_patch_present = tool_specs.iter().any(|spec| spec.name() == "apply_patch");
@@ -2003,14 +1942,12 @@ async fn run_turn(
let mut retries = 0;
loop {
let attempt_payload = payload.clone();
let attempt_full_items = full_prompt_items.clone();
match try_run_turn(
Arc::clone(&router),
Arc::clone(&sess),
Arc::clone(&turn_context),
Arc::clone(&turn_diff_tracker),
attempt_payload,
attempt_full_items,
cancellation_token.child_token(),
)
.await
@@ -2092,10 +2029,9 @@ async fn try_run_turn(
turn_context: Arc<TurnContext>,
turn_diff_tracker: SharedTurnDiffTracker,
payload: StreamPayload,
full_prompt_items: Vec<ResponseItem>,
cancellation_token: CancellationToken,
) -> CodexResult<TurnRunResult> {
let chaining_intent = payload.prompt.store_response;
let supports_responses_api_chaining = payload.prompt.store_response;
let rollout_item = RolloutItem::TurnContext(TurnContextItem {
cwd: turn_context.cwd.clone(),
@@ -2253,15 +2189,9 @@ async fn try_run_turn(
let mut tracker = turn_diff_tracker.lock().await;
tracker.get_unified_diff()
};
let prompt_items_for_chain = if chaining_intent {
full_prompt_items
} else {
Vec::new()
};
sess.update_responses_api_chain_state(
chaining_intent,
supports_responses_api_chaining,
Some(response_id.clone()),
prompt_items_for_chain,
)
.await;
if let Ok(Some(unified_diff)) = unified_diff {
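With this hunk, try_run_turn no longer threads full_prompt_items through the retry loop: update_responses_api_chain_state now takes the supports_responses_api_chaining flag and recomputes the prompt snapshot from the session's own history. A hedged sketch of that bookkeeping follows, using local stand-in types (String in place of ResponseItem) rather than the crate's real ones.

// Stand-ins for the crate's session-state bookkeeping; only the fields
// touched by this change are modelled.
#[derive(Debug, Default)]
struct ResponsesApiChainState {
    last_response_id: Option<String>,
    last_prompt_items: Vec<String>,
}

#[derive(Debug, Default)]
struct State {
    history: Vec<String>,
    chain: Option<ResponsesApiChainState>,
}

impl State {
    // Mirrors update_responses_api_chain_state after this commit: the flag
    // decides whether to keep a chain at all, and the prompt snapshot is
    // recomputed from the session's own history rather than passed in.
    fn update_chain(&mut self, supports_chaining: bool, response_id: Option<String>) {
        if !supports_chaining {
            // Chaining is off for this client/model: drop any stale chain.
            self.chain = None;
            return;
        }
        let Some(response_id) = response_id else {
            return;
        };
        // Snapshot the prompt-shaped history so the next turn can send only
        // the suffix appended after this point.
        self.chain = Some(ResponsesApiChainState {
            last_response_id: Some(response_id),
            last_prompt_items: self.history.clone(),
        });
    }
}

fn main() {
    let mut state = State {
        history: vec!["user: hi".into(), "assistant: hello".into()],
        ..Default::default()
    };
    state.update_chain(true, Some("resp_123".into()));
    println!("{:?}", state.chain);
}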

View File

@@ -80,6 +80,10 @@ impl ConversationHistory {
// Returns the history prepared for sending to the model.
// With extra response items filtered out and GhostCommits removed.
pub(crate) fn get_history_for_prompt(&mut self) -> Vec<ResponseItem> {
self.build_prompt_history()
}
fn build_prompt_history(&mut self) -> Vec<ResponseItem> {
let mut history = self.get_history();
Self::remove_ghost_snapshots(&mut history);
Self::remove_reasoning_before_last_turn(&mut history);
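Here the public accessor keeps its signature and forwards to a private build_prompt_history, presumably to leave room for internal reuse of the same filtering. A small stand-in sketch of the delegation, with the real filtering (ghost snapshots, pre-last-turn reasoning) collapsed into a single retain call purely for illustration:

// Stand-in for ConversationHistory; String replaces ResponseItem.
struct ConversationHistory {
    items: Vec<String>,
}

impl ConversationHistory {
    // Public entry point: unchanged signature, now a thin wrapper.
    fn get_history_for_prompt(&mut self) -> Vec<String> {
        self.build_prompt_history()
    }

    // Private builder holding the actual prompt-shaping logic.
    fn build_prompt_history(&mut self) -> Vec<String> {
        let mut history = self.items.clone();
        history.retain(|item| !item.starts_with("ghost:"));
        history
    }
}

fn main() {
    let mut h = ConversationHistory {
        items: vec!["ghost:snapshot".into(), "user: hi".into()],
    };
    assert_eq!(h.get_history_for_prompt(), vec!["user: hi".to_string()]);
}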

View File

@@ -2,6 +2,7 @@
use codex_protocol::models::ResponseItem;
use crate::client_common::Prompt;
use crate::codex::SessionConfiguration;
use crate::conversation_history::ConversationHistory;
use crate::conversation_history::ResponsesApiChainState;
@@ -51,10 +52,6 @@ impl SessionState {
self.history.set_responses_api_chain(chain);
}
pub(crate) fn responses_api_chain(&self) -> Option<ResponsesApiChainState> {
self.history.responses_api_chain()
}
// Token/rate limit helpers
pub(crate) fn update_token_info_from_usage(
&mut self,
@@ -81,4 +78,44 @@ impl SessionState {
pub(crate) fn set_token_usage_full(&mut self, context_window: i64) {
self.history.set_token_usage_full(context_window);
}
pub(crate) fn prompt_for_turn(&mut self, supports_responses_api_chaining: bool) -> Prompt {
let mut prompt = Prompt::default();
prompt.store_response = supports_responses_api_chaining;
let mut prompt_items = self.history.get_history_for_prompt();
if !supports_responses_api_chaining {
self.reset_responses_api_chain();
prompt.input = prompt_items;
return prompt;
}
let mut previous_response_id = None;
if let Some(chain_state) = self.history.responses_api_chain() {
if let Some(prev_id) = chain_state.last_response_id {
let prefix = common_prefix_len(&chain_state.last_prompt_items, &prompt_items);
let matches_previous_prompt = prefix == chain_state.last_prompt_items.len();
if matches_previous_prompt {
previous_response_id = Some(prev_id);
if prefix > 0 {
prompt_items.drain(..prefix);
}
} else if !chain_state.last_prompt_items.is_empty() {
self.reset_responses_api_chain();
}
}
}
prompt.previous_response_id = previous_response_id;
prompt.input = prompt_items;
prompt
}
}
fn common_prefix_len(lhs: &[ResponseItem], rhs: &[ResponseItem]) -> usize {
lhs.iter()
.zip(rhs.iter())
.take_while(|(left, right)| left == right)
.count()
}
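A few worked cases for the prefix rule above, again using String as a stand-in for ResponseItem: the chain is reused only when the previous prompt is a full prefix of the new one, in which case only the tail is sent; any divergence resets the chain and the full prompt is resent.

// The same zip/take_while count as common_prefix_len above, adapted to the
// stand-in item type.
fn common_prefix_len(lhs: &[&str], rhs: &[&str]) -> usize {
    lhs.iter()
        .zip(rhs.iter())
        .take_while(|(left, right)| left == right)
        .count()
}

fn main() {
    let previous = ["user: hi", "assistant: hello"];
    let appended_only = ["user: hi", "assistant: hello", "user: more"];
    let rewritten = ["user: hi", "assistant: bye"];

    // Previous prompt is a full prefix: reuse previous_response_id and send
    // only the tail after the prefix.
    let prefix = common_prefix_len(&previous, &appended_only);
    assert_eq!(prefix, previous.len());
    assert_eq!(&appended_only[prefix..], ["user: more"]);

    // Histories diverged: the chain must be reset.
    let prefix = common_prefix_len(&previous, &rewritten);
    assert_ne!(prefix, previous.len());
}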