mirror of
https://github.com/openai/codex.git
synced 2026-06-02 19:31:59 +00:00
Merge branch 'main' into jif/feature-new-tool-
This commit is contained in:
@@ -1,27 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::Weak;
|
||||
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ThreadGoalUpdatedNotification;
|
||||
use codex_core::NewThread;
|
||||
use codex_core::StartThreadOptions;
|
||||
use codex_core::ThreadManager;
|
||||
use codex_core::config::Config;
|
||||
use codex_extension_api::AgentSpawnFuture;
|
||||
use codex_extension_api::AgentSpawner;
|
||||
use codex_extension_api::ExtensionEventSink;
|
||||
use codex_extension_api::ExtensionRegistry;
|
||||
use codex_extension_api::ExtensionRegistryBuilder;
|
||||
use codex_protocol::ThreadId;
|
||||
use codex_protocol::error::CodexErr;
|
||||
use codex_protocol::protocol::Event;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
|
||||
pub(crate) fn thread_extensions<S>(guardian_agent_spawner: S) -> Arc<ExtensionRegistry<Config>>
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
|
||||
pub(crate) fn thread_extensions<S>(
|
||||
guardian_agent_spawner: S,
|
||||
event_sink: Arc<dyn ExtensionEventSink>,
|
||||
) -> Arc<ExtensionRegistry<Config>>
|
||||
where
|
||||
S: AgentSpawner<StartThreadOptions, Spawned = NewThread, Error = CodexErr> + 'static,
|
||||
{
|
||||
let mut builder = ExtensionRegistryBuilder::<Config>::new();
|
||||
let mut builder = ExtensionRegistryBuilder::<Config>::with_event_sink(event_sink);
|
||||
codex_guardian::install(&mut builder, guardian_agent_spawner);
|
||||
codex_memories_extension::install(&mut builder, codex_otel::global());
|
||||
Arc::new(builder.build())
|
||||
}
|
||||
|
||||
pub(crate) fn app_server_extension_event_sink(
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
) -> Arc<dyn ExtensionEventSink> {
|
||||
Arc::new(AppServerExtensionEventSink { outgoing })
|
||||
}
|
||||
|
||||
struct AppServerExtensionEventSink {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
}
|
||||
|
||||
impl ExtensionEventSink for AppServerExtensionEventSink {
|
||||
fn emit(&self, event: Event) {
|
||||
match event.msg {
|
||||
EventMsg::ThreadGoalUpdated(thread_goal_event) => {
|
||||
self.outgoing
|
||||
.try_send_server_notification(ServerNotification::ThreadGoalUpdated(
|
||||
ThreadGoalUpdatedNotification {
|
||||
thread_id: thread_goal_event.thread_id.to_string(),
|
||||
turn_id: thread_goal_event.turn_id,
|
||||
goal: thread_goal_event.goal.into(),
|
||||
},
|
||||
));
|
||||
}
|
||||
msg => {
|
||||
tracing::debug!(event_id = %event.id, ?msg, "dropping unsupported extension event");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn guardian_agent_spawner(
|
||||
thread_manager: Weak<ThreadManager>,
|
||||
) -> impl AgentSpawner<StartThreadOptions, Spawned = NewThread, Error = CodexErr> {
|
||||
@@ -39,3 +79,84 @@ pub(crate) fn guardian_agent_spawner(
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_analytics::AnalyticsEventsClient;
|
||||
use codex_app_server_protocol::ServerNotification;
|
||||
use codex_app_server_protocol::ThreadGoal as AppServerThreadGoal;
|
||||
use codex_app_server_protocol::ThreadGoalStatus as AppServerThreadGoalStatus;
|
||||
use codex_protocol::protocol::ThreadGoal;
|
||||
use codex_protocol::protocol::ThreadGoalStatus;
|
||||
use codex_protocol::protocol::ThreadGoalUpdatedEvent;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::*;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessage;
|
||||
|
||||
#[tokio::test]
|
||||
async fn app_server_event_sink_forwards_thread_goal_updates() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(4);
|
||||
let outgoing = Arc::new(OutgoingMessageSender::new(
|
||||
outgoing_tx,
|
||||
AnalyticsEventsClient::disabled(),
|
||||
));
|
||||
let sink = app_server_extension_event_sink(outgoing);
|
||||
let thread_id = ThreadId::default();
|
||||
|
||||
sink.emit(Event {
|
||||
id: "call-1".to_string(),
|
||||
msg: EventMsg::ThreadGoalUpdated(ThreadGoalUpdatedEvent {
|
||||
thread_id,
|
||||
turn_id: Some("turn-1".to_string()),
|
||||
goal: ThreadGoal {
|
||||
thread_id,
|
||||
objective: "wire extension events".to_string(),
|
||||
status: ThreadGoalStatus::Active,
|
||||
token_budget: Some(123),
|
||||
tokens_used: 45,
|
||||
time_used_seconds: 6,
|
||||
created_at: 7,
|
||||
updated_at: 8,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
let envelope = timeout(Duration::from_secs(1), outgoing_rx.recv())
|
||||
.await
|
||||
.expect("timed out waiting for forwarded extension event")
|
||||
.expect("outgoing channel closed unexpectedly");
|
||||
let OutgoingEnvelope::Broadcast { message } = envelope else {
|
||||
panic!("expected broadcast notification");
|
||||
};
|
||||
let OutgoingMessage::AppServerNotification(ServerNotification::ThreadGoalUpdated(
|
||||
notification,
|
||||
)) = message
|
||||
else {
|
||||
panic!("expected thread goal updated notification");
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
ThreadGoalUpdatedNotification {
|
||||
thread_id: thread_id.to_string(),
|
||||
turn_id: Some("turn-1".to_string()),
|
||||
goal: AppServerThreadGoal {
|
||||
thread_id: thread_id.to_string(),
|
||||
objective: "wire extension events".to_string(),
|
||||
status: AppServerThreadGoalStatus::Active,
|
||||
token_budget: Some(123),
|
||||
tokens_used: 45,
|
||||
time_used_seconds: 6,
|
||||
created_at: 7,
|
||||
updated_at: 8,
|
||||
},
|
||||
},
|
||||
notification
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -114,6 +114,7 @@ mod tests {
|
||||
use codex_core::init_state_db;
|
||||
use codex_core::thread_store_from_config;
|
||||
use codex_exec_server::EnvironmentManager;
|
||||
use codex_extension_api::NoopExtensionEventSink;
|
||||
use codex_login::AuthManager;
|
||||
use codex_login::CodexAuth;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
@@ -186,7 +187,10 @@ mod tests {
|
||||
auth_manager,
|
||||
SessionSource::Exec,
|
||||
Arc::new(EnvironmentManager::default_for_tests()),
|
||||
thread_extensions(guardian_agent_spawner(thread_manager.clone())),
|
||||
thread_extensions(
|
||||
guardian_agent_spawner(thread_manager.clone()),
|
||||
Arc::new(NoopExtensionEventSink),
|
||||
),
|
||||
/*analytics_events_client*/ None,
|
||||
thread_store,
|
||||
Some(state_db.clone()),
|
||||
|
||||
@@ -8,6 +8,7 @@ use crate::attestation::app_server_attestation_provider;
|
||||
use crate::config_manager::ConfigManager;
|
||||
use crate::connection_rpc_gate::ConnectionRpcGate;
|
||||
use crate::error_code::invalid_request;
|
||||
use crate::extensions::app_server_extension_event_sink;
|
||||
use crate::extensions::guardian_agent_spawner;
|
||||
use crate::extensions::thread_extensions;
|
||||
use crate::fs_watch::FsWatchManager;
|
||||
@@ -310,7 +311,10 @@ impl MessageProcessor {
|
||||
auth_manager.clone(),
|
||||
session_source,
|
||||
environment_manager,
|
||||
thread_extensions(guardian_agent_spawner(thread_manager.clone())),
|
||||
thread_extensions(
|
||||
guardian_agent_spawner(thread_manager.clone()),
|
||||
app_server_extension_event_sink(outgoing.clone()),
|
||||
),
|
||||
Some(analytics_events_client.clone()),
|
||||
Arc::clone(&thread_store),
|
||||
state_db.clone(),
|
||||
|
||||
@@ -555,6 +555,16 @@ impl OutgoingMessageSender {
|
||||
.await;
|
||||
}
|
||||
|
||||
pub(crate) fn try_send_server_notification(&self, notification: ServerNotification) {
|
||||
tracing::trace!("app-server event: {notification}");
|
||||
let outgoing_message = OutgoingMessage::AppServerNotification(notification);
|
||||
if let Err(err) = self.sender.try_send(OutgoingEnvelope::Broadcast {
|
||||
message: outgoing_message,
|
||||
}) {
|
||||
warn!("failed to send server notification to client without waiting: {err:?}");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification_to_connections(
|
||||
&self,
|
||||
connection_ids: &[ConnectionId],
|
||||
|
||||
@@ -20,14 +20,12 @@ use crate::tools;
|
||||
/// Contributes Codex memory read-path prompt context and memory read tools.
|
||||
#[derive(Clone, Default)]
|
||||
pub(crate) struct MemoriesExtension {
|
||||
_metrics_client: Option<MetricsClient>,
|
||||
metrics_client: Option<MetricsClient>,
|
||||
}
|
||||
|
||||
impl MemoriesExtension {
|
||||
fn new(metrics_client: Option<MetricsClient>) -> Self {
|
||||
Self {
|
||||
_metrics_client: metrics_client,
|
||||
}
|
||||
Self { metrics_client }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,7 +104,10 @@ impl ToolContributor for MemoriesExtension {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
tools::memory_tools(LocalMemoriesBackend::from_codex_home(&config.codex_home))
|
||||
tools::memory_tools(
|
||||
LocalMemoriesBackend::from_codex_home(&config.codex_home),
|
||||
self.metrics_client.clone(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
mod backend;
|
||||
mod extension;
|
||||
mod local;
|
||||
mod metrics;
|
||||
mod prompts;
|
||||
mod schema;
|
||||
mod tools;
|
||||
|
||||
69
codex-rs/ext/memories/src/metrics.rs
Normal file
69
codex-rs/ext/memories/src/metrics.rs
Normal file
@@ -0,0 +1,69 @@
|
||||
use codex_otel::MetricsClient;
|
||||
|
||||
use crate::MEMORY_TOOLS_NAMESPACE;
|
||||
|
||||
pub(crate) const MEMORIES_TOOL_CALL_METRIC: &str = "codex.memories.tool.call";
|
||||
|
||||
pub(crate) fn record_tool_call(
|
||||
metrics_client: Option<&MetricsClient>,
|
||||
operation: &str,
|
||||
scope: &str,
|
||||
success: bool,
|
||||
truncated: &str,
|
||||
) {
|
||||
let Some(metrics_client) = metrics_client else {
|
||||
return;
|
||||
};
|
||||
|
||||
let tool = format!("{MEMORY_TOOLS_NAMESPACE}{operation}");
|
||||
let _ = metrics_client.counter(
|
||||
MEMORIES_TOOL_CALL_METRIC,
|
||||
/*inc*/ 1,
|
||||
&[
|
||||
("tool", tool.as_str()),
|
||||
("operation", operation),
|
||||
("scope", scope),
|
||||
("status", status_tag(success)),
|
||||
("truncated", truncated),
|
||||
],
|
||||
);
|
||||
}
|
||||
|
||||
pub(crate) fn scope_from_path(path: &str) -> &'static str {
|
||||
let path = path.trim_matches('/');
|
||||
let path = path.strip_prefix("./").unwrap_or(path);
|
||||
|
||||
if path.is_empty() {
|
||||
"root"
|
||||
} else if path == "MEMORY.md" {
|
||||
"memory_md"
|
||||
} else if path == "memory_summary.md" {
|
||||
"memory_summary"
|
||||
} else if path == "raw_memories.md" {
|
||||
"raw_memories"
|
||||
} else if path == "rollout_summaries" || path.starts_with("rollout_summaries/") {
|
||||
"rollout_summaries"
|
||||
} else if path == "skills" || path.starts_with("skills/") {
|
||||
"skills"
|
||||
} else if path == "extensions/ad_hoc/notes" || path.starts_with("extensions/ad_hoc/notes/") {
|
||||
"ad_hoc_notes"
|
||||
} else {
|
||||
"other"
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn scope_from_optional_path(path: Option<&str>, default: &'static str) -> &'static str {
|
||||
path.map_or(default, scope_from_path)
|
||||
}
|
||||
|
||||
pub(crate) fn truncated_tag(truncated: Option<bool>) -> &'static str {
|
||||
match truncated {
|
||||
Some(true) => "true",
|
||||
Some(false) => "false",
|
||||
None => "unknown",
|
||||
}
|
||||
}
|
||||
|
||||
fn status_tag(success: bool) -> &'static str {
|
||||
if success { "succeeded" } else { "failed" }
|
||||
}
|
||||
@@ -445,10 +445,13 @@ async fn search_tool_rejects_legacy_single_query() {
|
||||
|
||||
fn memory_tool(memory_root: &Path, tool_name: &str) -> Arc<dyn ToolExecutor<ToolCall>> {
|
||||
let expected_tool_name = memory_tool_name(tool_name);
|
||||
crate::tools::memory_tools(LocalMemoriesBackend::from_memory_root(memory_root))
|
||||
.into_iter()
|
||||
.find(|tool| tool.tool_name() == expected_tool_name)
|
||||
.unwrap_or_else(|| panic!("{tool_name} tool should be registered"))
|
||||
crate::tools::memory_tools(
|
||||
LocalMemoriesBackend::from_memory_root(memory_root),
|
||||
/*metrics_client*/ None,
|
||||
)
|
||||
.into_iter()
|
||||
.find(|tool| tool.tool_name() == expected_tool_name)
|
||||
.unwrap_or_else(|| panic!("{tool_name} tool should be registered"))
|
||||
}
|
||||
|
||||
fn memory_tool_name(tool_name: &str) -> ToolName {
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_extension_api::ToolCall;
|
||||
use codex_extension_api::ToolExecutor;
|
||||
use codex_extension_api::ToolName;
|
||||
use codex_extension_api::ToolSpec;
|
||||
use codex_otel::MetricsClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
@@ -11,6 +12,7 @@ use crate::ADD_AD_HOC_NOTE_TOOL_NAME;
|
||||
use crate::backend::AddAdHocMemoryNoteRequest;
|
||||
use crate::backend::AddAdHocMemoryNoteResponse;
|
||||
use crate::backend::MemoriesBackend;
|
||||
use crate::metrics::record_tool_call;
|
||||
|
||||
use super::backend_error_to_function_call;
|
||||
use super::memory_function_tool;
|
||||
@@ -36,6 +38,7 @@ struct AddAdHocNoteArgs {
|
||||
#[derive(Clone)]
|
||||
pub(super) struct AddAdHocNoteTool<B> {
|
||||
pub(super) backend: B,
|
||||
pub(super) metrics_client: Option<MetricsClient>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -66,8 +69,15 @@ where
|
||||
filename: args.filename,
|
||||
note: args.note,
|
||||
})
|
||||
.await
|
||||
.map_err(backend_error_to_function_call)?;
|
||||
.await;
|
||||
record_tool_call(
|
||||
self.metrics_client.as_ref(),
|
||||
ADD_AD_HOC_NOTE_TOOL_NAME,
|
||||
"ad_hoc_notes",
|
||||
response.is_ok(),
|
||||
"not_applicable",
|
||||
);
|
||||
let response = response.map_err(backend_error_to_function_call)?;
|
||||
Ok(Box::new(JsonToolOutput::new(json!(response))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_extension_api::ToolCall;
|
||||
use codex_extension_api::ToolExecutor;
|
||||
use codex_extension_api::ToolName;
|
||||
use codex_extension_api::ToolSpec;
|
||||
use codex_otel::MetricsClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
@@ -13,6 +14,9 @@ use crate::MAX_LIST_RESULTS;
|
||||
use crate::backend::ListMemoriesRequest;
|
||||
use crate::backend::ListMemoriesResponse;
|
||||
use crate::backend::MemoriesBackend;
|
||||
use crate::metrics::record_tool_call;
|
||||
use crate::metrics::scope_from_optional_path;
|
||||
use crate::metrics::truncated_tag;
|
||||
|
||||
use super::backend_error_to_function_call;
|
||||
use super::clamp_max_results;
|
||||
@@ -32,6 +36,7 @@ struct ListArgs {
|
||||
#[derive(Clone)]
|
||||
pub(super) struct ListTool<B> {
|
||||
pub(super) backend: B,
|
||||
pub(super) metrics_client: Option<MetricsClient>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -57,6 +62,7 @@ where
|
||||
{
|
||||
let backend = self.backend.clone();
|
||||
let args: ListArgs = parse_args(&call)?;
|
||||
let scope = scope_from_optional_path(args.path.as_deref(), "root");
|
||||
let response = backend
|
||||
.list(ListMemoriesRequest {
|
||||
path: args.path,
|
||||
@@ -67,8 +73,15 @@ where
|
||||
MAX_LIST_RESULTS,
|
||||
),
|
||||
})
|
||||
.await
|
||||
.map_err(backend_error_to_function_call)?;
|
||||
.await;
|
||||
record_tool_call(
|
||||
self.metrics_client.as_ref(),
|
||||
LIST_TOOL_NAME,
|
||||
scope,
|
||||
response.is_ok(),
|
||||
truncated_tag(response.as_ref().ok().map(|response| response.truncated)),
|
||||
);
|
||||
let response = response.map_err(backend_error_to_function_call)?;
|
||||
Ok(Box::new(JsonToolOutput::new(json!(response))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ use codex_extension_api::ToolExecutor;
|
||||
use codex_extension_api::ToolName;
|
||||
use codex_extension_api::ToolSpec;
|
||||
use codex_extension_api::parse_tool_input_schema;
|
||||
use codex_otel::MetricsClient;
|
||||
use codex_tools::ResponsesApiNamespace;
|
||||
use codex_tools::ResponsesApiNamespaceTool;
|
||||
use codex_tools::default_namespace_description;
|
||||
@@ -24,21 +25,30 @@ mod list;
|
||||
mod read;
|
||||
mod search;
|
||||
|
||||
pub(crate) fn memory_tools<B>(backend: B) -> Vec<Arc<dyn ToolExecutor<ToolCall>>>
|
||||
pub(crate) fn memory_tools<B>(
|
||||
backend: B,
|
||||
metrics_client: Option<MetricsClient>,
|
||||
) -> Vec<Arc<dyn ToolExecutor<ToolCall>>>
|
||||
where
|
||||
B: MemoriesBackend,
|
||||
{
|
||||
vec![
|
||||
Arc::new(ad_hoc_note::AddAdHocNoteTool {
|
||||
backend: backend.clone(),
|
||||
metrics_client: metrics_client.clone(),
|
||||
}),
|
||||
Arc::new(list::ListTool {
|
||||
backend: backend.clone(),
|
||||
metrics_client: metrics_client.clone(),
|
||||
}),
|
||||
Arc::new(read::ReadTool {
|
||||
backend: backend.clone(),
|
||||
metrics_client: metrics_client.clone(),
|
||||
}),
|
||||
Arc::new(search::SearchTool {
|
||||
backend,
|
||||
metrics_client,
|
||||
}),
|
||||
Arc::new(search::SearchTool { backend }),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_extension_api::ToolCall;
|
||||
use codex_extension_api::ToolExecutor;
|
||||
use codex_extension_api::ToolName;
|
||||
use codex_extension_api::ToolSpec;
|
||||
use codex_otel::MetricsClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
@@ -12,6 +13,9 @@ use crate::READ_TOOL_NAME;
|
||||
use crate::backend::MemoriesBackend;
|
||||
use crate::backend::ReadMemoryRequest;
|
||||
use crate::backend::ReadMemoryResponse;
|
||||
use crate::metrics::record_tool_call;
|
||||
use crate::metrics::scope_from_path;
|
||||
use crate::metrics::truncated_tag;
|
||||
|
||||
use super::backend_error_to_function_call;
|
||||
use super::memory_function_tool;
|
||||
@@ -31,6 +35,7 @@ struct ReadArgs {
|
||||
#[derive(Clone)]
|
||||
pub(super) struct ReadTool<B> {
|
||||
pub(super) backend: B,
|
||||
pub(super) metrics_client: Option<MetricsClient>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -56,15 +61,24 @@ where
|
||||
{
|
||||
let backend = self.backend.clone();
|
||||
let args: ReadArgs = parse_args(&call)?;
|
||||
let path = args.path;
|
||||
let scope = scope_from_path(path.as_str());
|
||||
let response = backend
|
||||
.read(ReadMemoryRequest {
|
||||
path: args.path,
|
||||
path: path.clone(),
|
||||
line_offset: args.line_offset.unwrap_or(1),
|
||||
max_lines: args.max_lines,
|
||||
max_tokens: DEFAULT_READ_MAX_TOKENS,
|
||||
})
|
||||
.await
|
||||
.map_err(backend_error_to_function_call)?;
|
||||
.await;
|
||||
record_tool_call(
|
||||
self.metrics_client.as_ref(),
|
||||
READ_TOOL_NAME,
|
||||
scope,
|
||||
response.is_ok(),
|
||||
truncated_tag(response.as_ref().ok().map(|response| response.truncated)),
|
||||
);
|
||||
let response = response.map_err(backend_error_to_function_call)?;
|
||||
Ok(Box::new(JsonToolOutput::new(json!(response))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use codex_extension_api::ToolCall;
|
||||
use codex_extension_api::ToolExecutor;
|
||||
use codex_extension_api::ToolName;
|
||||
use codex_extension_api::ToolSpec;
|
||||
use codex_otel::MetricsClient;
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
@@ -14,6 +15,9 @@ use crate::backend::MemoriesBackend;
|
||||
use crate::backend::SearchMatchMode;
|
||||
use crate::backend::SearchMemoriesRequest;
|
||||
use crate::backend::SearchMemoriesResponse;
|
||||
use crate::metrics::record_tool_call;
|
||||
use crate::metrics::scope_from_optional_path;
|
||||
use crate::metrics::truncated_tag;
|
||||
|
||||
use super::backend_error_to_function_call;
|
||||
use super::clamp_max_results;
|
||||
@@ -40,6 +44,7 @@ struct SearchArgs {
|
||||
#[derive(Clone)]
|
||||
pub(super) struct SearchTool<B> {
|
||||
pub(super) backend: B,
|
||||
pub(super) metrics_client: Option<MetricsClient>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -65,10 +70,16 @@ where
|
||||
{
|
||||
let backend = self.backend.clone();
|
||||
let args: SearchArgs = parse_args(&call)?;
|
||||
let response = backend
|
||||
.search(args.into_request())
|
||||
.await
|
||||
.map_err(backend_error_to_function_call)?;
|
||||
let scope = scope_from_optional_path(args.path.as_deref(), "all");
|
||||
let response = backend.search(args.into_request()).await;
|
||||
record_tool_call(
|
||||
self.metrics_client.as_ref(),
|
||||
SEARCH_TOOL_NAME,
|
||||
scope,
|
||||
response.is_ok(),
|
||||
truncated_tag(response.as_ref().ok().map(|response| response.truncated)),
|
||||
);
|
||||
let response = response.map_err(backend_error_to_function_call)?;
|
||||
Ok(Box::new(JsonToolOutput::new(json!(response))))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user