optimization for replay stamping

This commit is contained in:
Roy Han
2026-03-17 13:45:31 -07:00
parent 2aada64f0d
commit 7a6c1d985a
2 changed files with 94 additions and 46 deletions

View File

@@ -1133,6 +1133,52 @@ fn tool_call_metadata_or_default(item: &ResponseItem) -> Option<ResponseItemMeta
}
}
#[derive(Clone, Debug, Default)]
struct ToolApprovalMetadataSnapshot {
approval_outcomes_by_call_id: HashMap<String, ReviewDecisionMetadata>,
pending_approval_call_ids: HashSet<String>,
}
fn stamp_tool_approval_metadata_with_snapshot(
turn_context: &TurnContext,
response_item: ResponseItem,
snapshot: Option<&ToolApprovalMetadataSnapshot>,
) -> ResponseItem {
let Some(snapshot) = snapshot else {
return response_item;
};
let Some(call_id) = response_item_tool_call_id(&response_item) else {
return response_item;
};
let outcome = snapshot.approval_outcomes_by_call_id.get(call_id).cloned();
let has_pending_approval = snapshot.pending_approval_call_ids.contains(call_id);
let mut metadata = match tool_call_metadata_or_default(&response_item) {
Some(metadata) => metadata,
None => return response_item,
};
metadata.sandbox_policy = Some(sandbox_policy_to_metadata(
turn_context.sandbox_policy.get(),
));
match outcome {
Some(review_decision) => {
metadata.is_tool_call_escalated = Some(true);
metadata.review_decision = Some(review_decision);
}
None if !has_pending_approval => {
metadata.is_tool_call_escalated = Some(false);
metadata.review_decision = None;
}
None => {
return stamp_tool_metadata_on_response_item(response_item, metadata);
}
}
stamp_tool_metadata_on_response_item(response_item, metadata)
}
#[derive(Clone)]
pub(crate) struct SessionConfiguration {
/// Provider identifier ("openai", "openrouter", ...).
@@ -3419,44 +3465,50 @@ impl Session {
if !self.enabled(Feature::ItemMetadata) {
return response_item;
}
let Some(call_id) = response_item_tool_call_id(&response_item) else {
return response_item;
};
let (outcome, has_pending_approval) = {
let snapshot = {
let active = self.active_turn.lock().await;
let Some(at) = active.as_ref() else {
return response_item;
};
let ts = at.turn_state.lock().await;
let outcome = ts.approval_outcome_for_call_id(call_id);
let has_pending_approval = ts.has_pending_approval_for_call_id(call_id);
(outcome, has_pending_approval)
let (approval_outcomes_by_call_id, pending_approval_call_ids) =
ts.approval_metadata_snapshot();
ToolApprovalMetadataSnapshot {
approval_outcomes_by_call_id,
pending_approval_call_ids,
}
};
stamp_tool_approval_metadata_with_snapshot(turn_context, response_item, Some(&snapshot))
}
let mut metadata = match tool_call_metadata_or_default(&response_item) {
Some(metadata) => metadata,
None => return response_item,
};
metadata.sandbox_policy = Some(sandbox_policy_to_metadata(
turn_context.sandbox_policy.get(),
));
match outcome {
Some(review_decision) => {
metadata.is_tool_call_escalated = Some(true);
metadata.review_decision = Some(review_decision);
}
None if !has_pending_approval => {
metadata.is_tool_call_escalated = Some(false);
metadata.review_decision = None;
}
None => {
return stamp_tool_metadata_on_response_item(response_item, metadata);
}
pub(crate) async fn stamp_tool_approval_metadata_on_items(
&self,
turn_context: &TurnContext,
response_items: Vec<ResponseItem>,
) -> Vec<ResponseItem> {
if !self.enabled(Feature::ItemMetadata) {
return response_items;
}
let snapshot = {
let active = self.active_turn.lock().await;
let Some(at) = active.as_ref() else {
return response_items;
};
let ts = at.turn_state.lock().await;
let (approval_outcomes_by_call_id, pending_approval_call_ids) =
ts.approval_metadata_snapshot();
ToolApprovalMetadataSnapshot {
approval_outcomes_by_call_id,
pending_approval_call_ids,
}
};
stamp_tool_metadata_on_response_item(response_item, metadata)
response_items
.into_iter()
.map(|item| {
stamp_tool_approval_metadata_with_snapshot(turn_context, item, Some(&snapshot))
})
.collect()
}
pub async fn resolve_elicitation(
@@ -6060,13 +6112,12 @@ pub(crate) async fn run_turn(
.clone_history()
.await
.for_prompt(&turn_context.model_info.input_modalities);
let mut sampling_request_input = Vec::with_capacity(sampling_request_input_items.len());
for item in sampling_request_input_items {
sampling_request_input.push(
sess.stamp_tool_approval_metadata(turn_context.as_ref(), item)
.await,
);
}
let sampling_request_input = sess
.stamp_tool_approval_metadata_on_items(
turn_context.as_ref(),
sampling_request_input_items,
)
.await;
let sampling_request_input_messages = sampling_request_input
.iter()

View File

@@ -2,6 +2,7 @@
use indexmap::IndexMap;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::Notify;
@@ -125,17 +126,13 @@ impl TurnState {
self.approval_outcomes_by_call_id.insert(call_id, decision);
}
pub(crate) fn approval_outcome_for_call_id(
pub(crate) fn approval_metadata_snapshot(
&self,
call_id: &str,
) -> Option<ReviewDecisionMetadata> {
self.approval_outcomes_by_call_id.get(call_id).cloned()
}
pub(crate) fn has_pending_approval_for_call_id(&self, call_id: &str) -> bool {
self.pending_approval_call_ids
.values()
.any(|pending_call_id| pending_call_id == call_id)
) -> (HashMap<String, ReviewDecisionMetadata>, HashSet<String>) {
(
self.approval_outcomes_by_call_id.clone(),
self.pending_approval_call_ids.values().cloned().collect(),
)
}
pub(crate) fn remove_pending_approval(