mirror of
https://github.com/openai/codex.git
synced 2026-04-08 06:44:58 +00:00
Compare commits
1 Commits
codex/func
...
dev/jn/cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
163c32eedb |
@@ -174,6 +174,7 @@ impl ThreadHistoryBuilder {
|
||||
self.handle_dynamic_tool_call_response(payload)
|
||||
}
|
||||
EventMsg::McpToolCallBegin(payload) => self.handle_mcp_tool_call_begin(payload),
|
||||
EventMsg::McpToolCallProgress(_) => {}
|
||||
EventMsg::McpToolCallEnd(payload) => self.handle_mcp_tool_call_end(payload),
|
||||
EventMsg::ViewImageToolCall(payload) => self.handle_view_image_tool_call(payload),
|
||||
EventMsg::ImageGenerationBegin(payload) => self.handle_image_generation_begin(payload),
|
||||
|
||||
@@ -55,6 +55,7 @@ use codex_app_server_protocol::McpServerElicitationRequestResponse;
|
||||
use codex_app_server_protocol::McpServerStartupState;
|
||||
use codex_app_server_protocol::McpServerStatusUpdatedNotification;
|
||||
use codex_app_server_protocol::McpToolCallError;
|
||||
use codex_app_server_protocol::McpToolCallProgressNotification;
|
||||
use codex_app_server_protocol::McpToolCallResult;
|
||||
use codex_app_server_protocol::McpToolCallStatus;
|
||||
use codex_app_server_protocol::ModelReroutedNotification;
|
||||
@@ -122,6 +123,7 @@ use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::ExecApprovalRequestEvent;
|
||||
use codex_protocol::protocol::McpToolCallBeginEvent;
|
||||
use codex_protocol::protocol::McpToolCallEndEvent;
|
||||
use codex_protocol::protocol::McpToolCallProgressEvent;
|
||||
use codex_protocol::protocol::Op;
|
||||
use codex_protocol::protocol::RealtimeEvent;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
@@ -980,6 +982,16 @@ pub(crate) async fn apply_bespoke_event_handling(
|
||||
.send_server_notification(ServerNotification::ItemCompleted(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::McpToolCallProgress(progress_event) => {
|
||||
let notification = construct_mcp_tool_call_progress_notification(
|
||||
progress_event,
|
||||
conversation_id.to_string(),
|
||||
event_turn_id.clone(),
|
||||
);
|
||||
outgoing
|
||||
.send_server_notification(ServerNotification::McpToolCallProgress(notification))
|
||||
.await;
|
||||
}
|
||||
EventMsg::CollabAgentSpawnBegin(begin_event) => {
|
||||
let item = ThreadItem::CollabAgentToolCall {
|
||||
id: begin_event.call_id,
|
||||
@@ -2823,6 +2835,19 @@ async fn construct_mcp_tool_call_end_notification(
|
||||
}
|
||||
}
|
||||
|
||||
fn construct_mcp_tool_call_progress_notification(
|
||||
progress_event: McpToolCallProgressEvent,
|
||||
thread_id: String,
|
||||
turn_id: String,
|
||||
) -> McpToolCallProgressNotification {
|
||||
McpToolCallProgressNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item_id: progress_event.call_id,
|
||||
message: progress_event.message,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -4205,6 +4230,32 @@ mod tests {
|
||||
assert_eq!(notification, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_construct_mcp_tool_call_progress_notification() {
|
||||
let progress_event = McpToolCallProgressEvent {
|
||||
call_id: "call_progress".to_string(),
|
||||
message: "indexing".to_string(),
|
||||
};
|
||||
|
||||
let thread_id = ThreadId::new().to_string();
|
||||
let turn_id = "turn_5".to_string();
|
||||
let notification = construct_mcp_tool_call_progress_notification(
|
||||
progress_event.clone(),
|
||||
thread_id.clone(),
|
||||
turn_id.clone(),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
notification,
|
||||
McpToolCallProgressNotification {
|
||||
thread_id,
|
||||
turn_id,
|
||||
item_id: progress_event.call_id,
|
||||
message: progress_event.message,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_handle_turn_diff_emits_v2_notification() -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
|
||||
@@ -50,6 +50,7 @@ use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_rmcp_client::SendElicitation;
|
||||
use codex_rmcp_client::SendProgressNotification;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::future::FutureExt;
|
||||
use futures::future::Shared;
|
||||
@@ -1012,6 +1013,7 @@ impl McpConnectionManager {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
progress_notification: Option<SendProgressNotification>,
|
||||
) -> Result<CallToolResult> {
|
||||
let client = self.client_by_name(server).await?;
|
||||
if !client.tool_filter.allows(tool) {
|
||||
@@ -1022,7 +1024,13 @@ impl McpConnectionManager {
|
||||
|
||||
let result: rmcp::model::CallToolResult = client
|
||||
.client
|
||||
.call_tool(tool.to_string(), arguments, meta, client.tool_timeout)
|
||||
.call_tool(
|
||||
tool.to_string(),
|
||||
arguments,
|
||||
meta,
|
||||
client.tool_timeout,
|
||||
progress_notification,
|
||||
)
|
||||
.await
|
||||
.with_context(|| format!("tool call failed for `{server}/{tool}`"))?;
|
||||
|
||||
|
||||
@@ -129,6 +129,7 @@ use codex_protocol::request_user_input::RequestUserInputArgs;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::SendProgressNotification;
|
||||
use codex_rollout::state_db;
|
||||
use codex_shell_command::parse_command::parse_command;
|
||||
use codex_terminal_detection::user_agent;
|
||||
@@ -4280,12 +4281,13 @@ impl Session {
|
||||
tool: &str,
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
progress_notification: Option<SendProgressNotification>,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.call_tool(server, tool, arguments, meta)
|
||||
.call_tool(server, tool, arguments, meta, progress_notification)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -7050,6 +7052,7 @@ fn realtime_text_for_event(msg: &EventMsg) -> Option<String> {
|
||||
| EventMsg::McpStartupUpdate(_)
|
||||
| EventMsg::McpStartupComplete(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallProgress(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::WebSearchEnd(_)
|
||||
|
||||
@@ -369,6 +369,24 @@ async fn forward_events(
|
||||
break;
|
||||
}
|
||||
}
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::McpToolCallProgress(event),
|
||||
} => {
|
||||
if !forward_event_or_shutdown(
|
||||
&codex,
|
||||
&tx_sub,
|
||||
&cancel_token,
|
||||
Event {
|
||||
id,
|
||||
msg: EventMsg::McpToolCallProgress(event),
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
other => {
|
||||
if !forward_event_or_shutdown(&codex, &tx_sub, &cancel_token, other).await
|
||||
{
|
||||
|
||||
@@ -41,6 +41,7 @@ use codex_protocol::protocol::EventMsg;
|
||||
use codex_protocol::protocol::McpInvocation;
|
||||
use codex_protocol::protocol::McpToolCallBeginEvent;
|
||||
use codex_protocol::protocol::McpToolCallEndEvent;
|
||||
use codex_protocol::protocol::McpToolCallProgressEvent;
|
||||
use codex_protocol::protocol::ReviewDecision;
|
||||
use codex_protocol::protocol::SandboxPolicy;
|
||||
use codex_protocol::request_user_input::RequestUserInputAnswer;
|
||||
@@ -50,7 +51,9 @@ use codex_protocol::request_user_input::RequestUserInputQuestionOption;
|
||||
use codex_protocol::request_user_input::RequestUserInputResponse;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::SendProgressNotification;
|
||||
use codex_rollout::state_db;
|
||||
use rmcp::model::ProgressNotificationParam;
|
||||
use rmcp::model::ToolAnnotations;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
@@ -153,6 +156,11 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
.await
|
||||
.server_origin(&server)
|
||||
.map(str::to_string);
|
||||
let progress_notification = build_mcp_tool_call_progress_notification_sender(
|
||||
Arc::clone(&sess),
|
||||
Arc::clone(turn_context),
|
||||
call_id.clone(),
|
||||
);
|
||||
|
||||
let tool_call_begin_event = EventMsg::McpToolCallBegin(McpToolCallBeginEvent {
|
||||
call_id: call_id.clone(),
|
||||
@@ -183,6 +191,7 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
&tool_name,
|
||||
arguments_value.clone(),
|
||||
request_meta.clone(),
|
||||
Some(progress_notification.clone()),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e:?}"))
|
||||
@@ -296,9 +305,15 @@ pub(crate) async fn handle_mcp_tool_call(
|
||||
let start = Instant::now();
|
||||
// Perform the tool call.
|
||||
let result = async {
|
||||
sess.call_tool(&server, &tool_name, arguments_value.clone(), request_meta)
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e:?}"))
|
||||
sess.call_tool(
|
||||
&server,
|
||||
&tool_name,
|
||||
arguments_value.clone(),
|
||||
request_meta,
|
||||
Some(progress_notification),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("tool call error: {e:?}"))
|
||||
}
|
||||
.instrument(mcp_tool_call_span(
|
||||
sess.as_ref(),
|
||||
@@ -504,6 +519,66 @@ async fn notify_mcp_tool_call_event(sess: &Session, turn_context: &TurnContext,
|
||||
sess.send_event(turn_context, event).await;
|
||||
}
|
||||
|
||||
fn build_mcp_tool_call_progress_notification_sender(
|
||||
sess: Arc<Session>,
|
||||
turn_context: Arc<TurnContext>,
|
||||
call_id: String,
|
||||
) -> SendProgressNotification {
|
||||
Arc::new(move |notification: ProgressNotificationParam| {
|
||||
let sess = Arc::clone(&sess);
|
||||
let turn_context = Arc::clone(&turn_context);
|
||||
let call_id = call_id.clone();
|
||||
Box::pin(async move {
|
||||
if let Some(message) = format_mcp_tool_call_progress_message(¬ification) {
|
||||
notify_mcp_tool_call_event(
|
||||
sess.as_ref(),
|
||||
turn_context.as_ref(),
|
||||
EventMsg::McpToolCallProgress(McpToolCallProgressEvent { call_id, message }),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
fn format_mcp_tool_call_progress_message(
|
||||
notification: &ProgressNotificationParam,
|
||||
) -> Option<String> {
|
||||
if let Some(message) = notification.message.as_deref() {
|
||||
let trimmed = message.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(total) = notification.total
|
||||
&& notification.progress.is_finite()
|
||||
&& total.is_finite()
|
||||
&& total > 0.0
|
||||
{
|
||||
return Some(format!(
|
||||
"{} / {}",
|
||||
format_mcp_progress_value(notification.progress),
|
||||
format_mcp_progress_value(total),
|
||||
));
|
||||
}
|
||||
|
||||
notification.progress.is_finite().then(|| {
|
||||
format!(
|
||||
"Progress: {}",
|
||||
format_mcp_progress_value(notification.progress)
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn format_mcp_progress_value(value: f64) -> String {
|
||||
let formatted = format!("{value:.2}");
|
||||
formatted
|
||||
.trim_end_matches('0')
|
||||
.trim_end_matches('.')
|
||||
.to_string()
|
||||
}
|
||||
|
||||
struct McpAppUsageMetadata {
|
||||
connector_id: Option<String>,
|
||||
app_name: Option<String>,
|
||||
|
||||
@@ -132,6 +132,40 @@ fn approval_question_text_prepends_safety_reason() {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_progress_prefers_server_message() {
|
||||
let notification = rmcp::model::ProgressNotificationParam {
|
||||
progress_token: rmcp::model::ProgressToken(rmcp::model::NumberOrString::String(
|
||||
"token".into(),
|
||||
)),
|
||||
progress: 1.0,
|
||||
total: Some(3.0),
|
||||
message: Some(" indexing files ".to_string()),
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
format_mcp_tool_call_progress_message(¬ification),
|
||||
Some("indexing files".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_progress_formats_numeric_fallback() {
|
||||
let notification = rmcp::model::ProgressNotificationParam {
|
||||
progress_token: rmcp::model::ProgressToken(rmcp::model::NumberOrString::String(
|
||||
"token".into(),
|
||||
)),
|
||||
progress: 2.0,
|
||||
total: Some(5.0),
|
||||
message: None,
|
||||
};
|
||||
|
||||
assert_eq!(
|
||||
format_mcp_tool_call_progress_message(¬ification),
|
||||
Some("2 / 5".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_tool_call_span_records_expected_fields() {
|
||||
let buffer: &'static std::sync::Mutex<Vec<u8>> =
|
||||
|
||||
@@ -334,6 +334,7 @@ async fn run_codex_tool_session_inner(
|
||||
| EventMsg::AgentReasoning(_)
|
||||
| EventMsg::AgentReasoningSectionBreak(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallProgress(_)
|
||||
| EventMsg::McpToolCallEnd(_)
|
||||
| EventMsg::McpListToolsResponse(_)
|
||||
| EventMsg::ListSkillsResponse(_)
|
||||
|
||||
@@ -1297,6 +1297,8 @@ pub enum EventMsg {
|
||||
|
||||
McpToolCallBegin(McpToolCallBeginEvent),
|
||||
|
||||
McpToolCallProgress(McpToolCallProgressEvent),
|
||||
|
||||
McpToolCallEnd(McpToolCallEndEvent),
|
||||
|
||||
WebSearchBegin(WebSearchBeginEvent),
|
||||
@@ -2168,6 +2170,14 @@ pub struct McpToolCallBeginEvent {
|
||||
pub invocation: McpInvocation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS, PartialEq)]
|
||||
pub struct McpToolCallProgressEvent {
|
||||
/// Identifier for the corresponding McpToolCallBegin that is still running.
|
||||
pub call_id: String,
|
||||
/// User-visible progress text for the current MCP tool call.
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema, TS, PartialEq)]
|
||||
pub struct McpToolCallEndEvent {
|
||||
/// Identifier for the corresponding McpToolCallBegin that finished.
|
||||
|
||||
@@ -12,6 +12,7 @@ use rmcp::model::ListResourceTemplatesResult;
|
||||
use rmcp::model::ListResourcesResult;
|
||||
use rmcp::model::ListToolsResult;
|
||||
use rmcp::model::PaginatedRequestParams;
|
||||
use rmcp::model::ProgressNotificationParam;
|
||||
use rmcp::model::RawResource;
|
||||
use rmcp::model::RawResourceTemplate;
|
||||
use rmcp::model::ReadResourceRequestParams;
|
||||
@@ -26,6 +27,8 @@ use rmcp::model::ToolAnnotations;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::task;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct TestToolServer {
|
||||
@@ -47,6 +50,7 @@ impl TestToolServer {
|
||||
let tools = vec![
|
||||
Self::echo_tool(),
|
||||
Self::echo_dash_tool(),
|
||||
Self::progress_tool(),
|
||||
Self::image_tool(),
|
||||
Self::image_scenario_tool(),
|
||||
];
|
||||
@@ -113,6 +117,24 @@ impl TestToolServer {
|
||||
tool
|
||||
}
|
||||
|
||||
fn progress_tool() -> Tool {
|
||||
#[expect(clippy::expect_used)]
|
||||
let schema: JsonObject = serde_json::from_value(serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"steps": { "type": "integer", "minimum": 1 }
|
||||
},
|
||||
"additionalProperties": false
|
||||
}))
|
||||
.expect("progress tool schema should deserialize");
|
||||
|
||||
Tool::new(
|
||||
Cow::Borrowed("progress"),
|
||||
Cow::Borrowed("Emit progress notifications before completing."),
|
||||
Arc::new(schema),
|
||||
)
|
||||
}
|
||||
|
||||
/// Tool intended for manual testing of Codex TUI rendering for MCP image tool results.
|
||||
///
|
||||
/// This exists to exercise edge cases where a `CallToolResult.content` includes image blocks
|
||||
@@ -210,6 +232,16 @@ struct EchoArgs {
|
||||
env_var: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ProgressArgs {
|
||||
#[serde(default = "default_progress_steps")]
|
||||
steps: u64,
|
||||
}
|
||||
|
||||
fn default_progress_steps() -> u64 {
|
||||
3
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
/// Scenarios for `image_scenario`, intended to exercise Codex TUI handling of MCP image outputs.
|
||||
@@ -315,7 +347,7 @@ impl ServerHandler for TestToolServer {
|
||||
async fn call_tool(
|
||||
&self,
|
||||
request: CallToolRequestParams,
|
||||
_context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
context: rmcp::service::RequestContext<rmcp::service::RoleServer>,
|
||||
) -> Result<CallToolResult, McpError> {
|
||||
match request.name.as_ref() {
|
||||
"echo" | "echo-tool" => {
|
||||
@@ -345,6 +377,40 @@ impl ServerHandler for TestToolServer {
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
"progress" => {
|
||||
let args = match request.arguments {
|
||||
Some(arguments) => serde_json::from_value(serde_json::Value::Object(
|
||||
arguments.into_iter().collect(),
|
||||
))
|
||||
.map_err(|err| McpError::invalid_params(err.to_string(), None))?,
|
||||
None => ProgressArgs {
|
||||
steps: default_progress_steps(),
|
||||
},
|
||||
};
|
||||
let progress_token = context.meta.get_progress_token().ok_or_else(|| {
|
||||
McpError::invalid_params("missing progress token for progress tool", None)
|
||||
})?;
|
||||
for step in 1..=args.steps {
|
||||
context
|
||||
.peer
|
||||
.notify_progress(ProgressNotificationParam {
|
||||
progress_token: progress_token.clone(),
|
||||
progress: step as f64,
|
||||
total: Some(args.steps as f64),
|
||||
message: Some(format!("step {step}")),
|
||||
})
|
||||
.await
|
||||
.map_err(|err| McpError::internal_error(err.to_string(), None))?;
|
||||
sleep(Duration::from_millis(10)).await;
|
||||
}
|
||||
|
||||
Ok(CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(json!({ "steps": args.steps })),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
"image" => {
|
||||
// Read a data URL (e.g. data:image/png;base64,AAA...) from env and convert to
|
||||
// an MCP image content block. Tests set MCP_TEST_IMAGE_DATA_URL.
|
||||
|
||||
@@ -28,4 +28,5 @@ pub use rmcp_client::ElicitationResponse;
|
||||
pub use rmcp_client::ListToolsWithConnectorIdResult;
|
||||
pub use rmcp_client::RmcpClient;
|
||||
pub use rmcp_client::SendElicitation;
|
||||
pub use rmcp_client::SendProgressNotification;
|
||||
pub use rmcp_client::ToolWithConnectorId;
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use rmcp::ClientHandler;
|
||||
use rmcp::RoleClient;
|
||||
use rmcp::handler::client::progress::ProgressDispatcher;
|
||||
use rmcp::model::CancelledNotificationParam;
|
||||
use rmcp::model::ClientInfo;
|
||||
use rmcp::model::CreateElicitationRequestParams;
|
||||
@@ -23,6 +24,7 @@ use crate::rmcp_client::SendElicitation;
|
||||
pub(crate) struct LoggingClientHandler {
|
||||
client_info: ClientInfo,
|
||||
send_elicitation: Arc<SendElicitation>,
|
||||
progress_handler: ProgressDispatcher,
|
||||
}
|
||||
|
||||
impl LoggingClientHandler {
|
||||
@@ -30,8 +32,13 @@ impl LoggingClientHandler {
|
||||
Self {
|
||||
client_info,
|
||||
send_elicitation: Arc::new(send_elicitation),
|
||||
progress_handler: ProgressDispatcher::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn progress_handler(&self) -> ProgressDispatcher {
|
||||
self.progress_handler.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientHandler for LoggingClientHandler {
|
||||
@@ -66,6 +73,7 @@ impl ClientHandler for LoggingClientHandler {
|
||||
"MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})",
|
||||
params.progress_token, params.progress, params.total, params.message
|
||||
);
|
||||
self.progress_handler.handle_notification(params).await;
|
||||
}
|
||||
|
||||
async fn on_resource_updated(
|
||||
|
||||
@@ -452,6 +452,8 @@ impl From<ElicitationResponse> for CreateElicitationResult {
|
||||
pub type SendElicitation = Box<
|
||||
dyn Fn(RequestId, Elicitation) -> BoxFuture<'static, Result<ElicitationResponse>> + Send + Sync,
|
||||
>;
|
||||
pub type SendProgressNotification =
|
||||
Arc<dyn Fn(rmcp::model::ProgressNotificationParam) -> BoxFuture<'static, ()> + Send + Sync>;
|
||||
|
||||
pub struct ToolWithConnectorId {
|
||||
pub tool: Tool,
|
||||
@@ -702,6 +704,7 @@ impl RmcpClient {
|
||||
arguments: Option<serde_json::Value>,
|
||||
meta: Option<serde_json::Value>,
|
||||
timeout: Option<Duration>,
|
||||
progress_notification: Option<SendProgressNotification>,
|
||||
) -> Result<CallToolResult> {
|
||||
self.refresh_oauth_if_needed().await;
|
||||
let arguments = match arguments {
|
||||
@@ -728,12 +731,21 @@ impl RmcpClient {
|
||||
arguments,
|
||||
task: None,
|
||||
};
|
||||
let progress_handler = self
|
||||
.initialize_context
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.map(|context| context.handler.progress_handler())
|
||||
.ok_or_else(|| anyhow!("client not initialized"))?;
|
||||
let result = self
|
||||
.run_service_operation("tools/call", timeout, move |service| {
|
||||
let rmcp_params = rmcp_params.clone();
|
||||
let meta = meta.clone();
|
||||
let progress_handler = progress_handler.clone();
|
||||
let progress_notification = progress_notification.clone();
|
||||
async move {
|
||||
let result = service
|
||||
let handle = service
|
||||
.peer()
|
||||
.send_request_with_option(
|
||||
ClientRequest::CallToolRequest(rmcp::model::CallToolRequest {
|
||||
@@ -746,9 +758,25 @@ impl RmcpClient {
|
||||
meta,
|
||||
},
|
||||
)
|
||||
.await?
|
||||
.await_response()
|
||||
.await?;
|
||||
let result = if let Some(progress_notification) = progress_notification {
|
||||
let mut progress_subscriber = progress_handler
|
||||
.subscribe(handle.progress_token.clone())
|
||||
.await;
|
||||
let mut response = handle.await_response().boxed();
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = &mut response => break result?,
|
||||
notification = progress_subscriber.next() => {
|
||||
if let Some(notification) = notification {
|
||||
progress_notification(notification).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
handle.await_response().await?
|
||||
};
|
||||
match result {
|
||||
ServerResult::CallToolResult(result) => Ok(result),
|
||||
_ => Err(rmcp::service::ServiceError::UnexpectedResponse),
|
||||
|
||||
115
codex-rs/rmcp-client/tests/progress.rs
Normal file
115
codex-rs/rmcp-client/tests/progress.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
use std::ffi::OsString;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_rmcp_client::SendProgressNotification;
|
||||
use codex_utils_cargo_bin::CargoBinError;
|
||||
use futures::FutureExt as _;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::ClientCapabilities;
|
||||
use rmcp::model::ElicitationCapability;
|
||||
use rmcp::model::FormElicitationCapability;
|
||||
use rmcp::model::Implementation;
|
||||
use rmcp::model::InitializeRequestParams;
|
||||
use rmcp::model::ProtocolVersion;
|
||||
use serde_json::json;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
fn stdio_server_bin() -> Result<PathBuf, CargoBinError> {
|
||||
codex_utils_cargo_bin::cargo_bin("test_stdio_server")
|
||||
}
|
||||
|
||||
fn init_params() -> InitializeRequestParams {
|
||||
InitializeRequestParams {
|
||||
meta: None,
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
extensions: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: Some(ElicitationCapability {
|
||||
form: Some(FormElicitationCapability {
|
||||
schema_validation: None,
|
||||
}),
|
||||
url: None,
|
||||
}),
|
||||
tasks: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-test".into(),
|
||||
version: "0.0.0-test".into(),
|
||||
title: Some("Codex rmcp progress test".into()),
|
||||
description: None,
|
||||
icons: None,
|
||||
website_url: None,
|
||||
},
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn rmcp_client_forwards_progress_notifications() -> anyhow::Result<()> {
|
||||
let client = RmcpClient::new_stdio_client(
|
||||
stdio_server_bin()?.into(),
|
||||
Vec::<OsString>::new(),
|
||||
None,
|
||||
&[],
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.initialize(
|
||||
init_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let received_messages = Arc::new(Mutex::new(Vec::new()));
|
||||
let progress_notification: SendProgressNotification = Arc::new({
|
||||
let received_messages = Arc::clone(&received_messages);
|
||||
move |notification| {
|
||||
let received_messages = Arc::clone(&received_messages);
|
||||
async move {
|
||||
received_messages.lock().await.push(notification.message);
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
});
|
||||
|
||||
let result = client
|
||||
.call_tool(
|
||||
"progress".to_string(),
|
||||
Some(json!({ "steps": 3 })),
|
||||
None,
|
||||
Some(Duration::from_secs(5)),
|
||||
Some(progress_notification),
|
||||
)
|
||||
.await?;
|
||||
|
||||
assert_eq!(result.structured_content, Some(json!({ "steps": 3 })));
|
||||
assert_eq!(
|
||||
*received_messages.lock().await,
|
||||
vec![
|
||||
Some("step 1".to_string()),
|
||||
Some("step 2".to_string()),
|
||||
Some("step 3".to_string()),
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -107,6 +107,7 @@ async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result<Ca
|
||||
Some(json!({ "message": message })),
|
||||
/*meta*/ None,
|
||||
Some(Duration::from_secs(5)),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -143,6 +143,7 @@ fn event_msg_persistence_mode(ev: &EventMsg) -> Option<EventPersistenceMode> {
|
||||
| EventMsg::SessionConfigured(_)
|
||||
| EventMsg::ThreadNameUpdated(_)
|
||||
| EventMsg::McpToolCallBegin(_)
|
||||
| EventMsg::McpToolCallProgress(_)
|
||||
| EventMsg::WebSearchBegin(_)
|
||||
| EventMsg::ExecCommandBegin(_)
|
||||
| EventMsg::TerminalInteraction(_)
|
||||
|
||||
@@ -6927,6 +6927,7 @@ impl ChatWidget {
|
||||
EventMsg::ImageGenerationBegin(ev) => self.on_image_generation_begin(ev),
|
||||
EventMsg::ImageGenerationEnd(ev) => self.on_image_generation_end(ev),
|
||||
EventMsg::McpToolCallBegin(ev) => self.on_mcp_tool_call_begin(ev),
|
||||
EventMsg::McpToolCallProgress(_) => {}
|
||||
EventMsg::McpToolCallEnd(ev) => self.on_mcp_tool_call_end(ev),
|
||||
EventMsg::WebSearchBegin(ev) => self.on_web_search_begin(ev),
|
||||
EventMsg::WebSearchEnd(ev) => self.on_web_search_end(ev),
|
||||
|
||||
Reference in New Issue
Block a user