mirror of
https://github.com/openai/codex.git
synced 2026-06-02 19:31:59 +00:00
Compare commits
3 Commits
jif/fix-ra
...
cconger/st
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d7973334f | ||
|
|
b93007878c | ||
|
|
54d4da7f29 |
@@ -2588,6 +2588,17 @@
|
||||
],
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadDecrementElicitationParams": {
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadForkParams": {
|
||||
"description": "There are two ways to fork a thread: 1. By thread_id: load the thread from disk by thread_id and fork it into a new thread. 2. By path: load the thread from disk by path and fork it into a new thread.\n\nIf using path, the thread_id param will be ignored.\n\nPrefer using thread_id whenever possible.",
|
||||
"properties": {
|
||||
@@ -2661,6 +2672,17 @@
|
||||
"ThreadId": {
|
||||
"type": "string"
|
||||
},
|
||||
"ThreadIncrementElicitationParams": {
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadListParams": {
|
||||
"properties": {
|
||||
"archived": {
|
||||
@@ -3523,6 +3545,54 @@
|
||||
"title": "Thread/archiveRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
"$ref": "#/definitions/RequestId"
|
||||
},
|
||||
"method": {
|
||||
"enum": [
|
||||
"thread/increment_elicitation"
|
||||
],
|
||||
"title": "Thread/incrementElicitationRequestMethod",
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"$ref": "#/definitions/ThreadIncrementElicitationParams"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"method",
|
||||
"params"
|
||||
],
|
||||
"title": "Thread/incrementElicitationRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
"$ref": "#/definitions/RequestId"
|
||||
},
|
||||
"method": {
|
||||
"enum": [
|
||||
"thread/decrement_elicitation"
|
||||
],
|
||||
"title": "Thread/decrementElicitationRequestMethod",
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"$ref": "#/definitions/ThreadDecrementElicitationParams"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"method",
|
||||
"params"
|
||||
],
|
||||
"title": "Thread/decrementElicitationRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
|
||||
@@ -570,6 +570,54 @@
|
||||
"title": "Thread/archiveRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
"$ref": "#/definitions/RequestId"
|
||||
},
|
||||
"method": {
|
||||
"enum": [
|
||||
"thread/increment_elicitation"
|
||||
],
|
||||
"title": "Thread/incrementElicitationRequestMethod",
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"$ref": "#/definitions/v2/ThreadIncrementElicitationParams"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"method",
|
||||
"params"
|
||||
],
|
||||
"title": "Thread/incrementElicitationRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
"$ref": "#/definitions/RequestId"
|
||||
},
|
||||
"method": {
|
||||
"enum": [
|
||||
"thread/decrement_elicitation"
|
||||
],
|
||||
"title": "Thread/decrementElicitationRequestMethod",
|
||||
"type": "string"
|
||||
},
|
||||
"params": {
|
||||
"$ref": "#/definitions/v2/ThreadDecrementElicitationParams"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"id",
|
||||
"method",
|
||||
"params"
|
||||
],
|
||||
"title": "Thread/decrementElicitationRequest",
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"id": {
|
||||
@@ -15359,6 +15407,38 @@
|
||||
"title": "ThreadCompactStartResponse",
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadDecrementElicitationParams": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"title": "ThreadDecrementElicitationParams",
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadDecrementElicitationResponse": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"count": {
|
||||
"format": "uint64",
|
||||
"minimum": 0.0,
|
||||
"type": "integer"
|
||||
},
|
||||
"paused": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"count",
|
||||
"paused"
|
||||
],
|
||||
"title": "ThreadDecrementElicitationResponse",
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadForkParams": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"description": "There are two ways to fork a thread: 1. By thread_id: load the thread from disk by thread_id and fork it into a new thread. 2. By path: load the thread from disk by path and fork it into a new thread.\n\nIf using path, the thread_id param will be ignored.\n\nPrefer using thread_id whenever possible.",
|
||||
@@ -15477,6 +15557,38 @@
|
||||
"ThreadId": {
|
||||
"type": "string"
|
||||
},
|
||||
"ThreadIncrementElicitationParams": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"title": "ThreadIncrementElicitationParams",
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadIncrementElicitationResponse": {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"count": {
|
||||
"format": "uint64",
|
||||
"minimum": 0.0,
|
||||
"type": "integer"
|
||||
},
|
||||
"paused": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"count",
|
||||
"paused"
|
||||
],
|
||||
"title": "ThreadIncrementElicitationResponse",
|
||||
"type": "object"
|
||||
},
|
||||
"ThreadItem": {
|
||||
"oneOf": [
|
||||
{
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"title": "ThreadDecrementElicitationParams",
|
||||
"type": "object"
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"count": {
|
||||
"format": "uint64",
|
||||
"minimum": 0.0,
|
||||
"type": "integer"
|
||||
},
|
||||
"paused": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"count",
|
||||
"paused"
|
||||
],
|
||||
"title": "ThreadDecrementElicitationResponse",
|
||||
"type": "object"
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"threadId": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"threadId"
|
||||
],
|
||||
"title": "ThreadIncrementElicitationParams",
|
||||
"type": "object"
|
||||
}
|
||||
@@ -0,0 +1,19 @@
|
||||
{
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"properties": {
|
||||
"count": {
|
||||
"format": "uint64",
|
||||
"minimum": 0.0,
|
||||
"type": "integer"
|
||||
},
|
||||
"paused": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"count",
|
||||
"paused"
|
||||
],
|
||||
"title": "ThreadIncrementElicitationResponse",
|
||||
"type": "object"
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,5 @@
|
||||
// GENERATED CODE! DO NOT MODIFY BY HAND!
|
||||
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
|
||||
export type ThreadDecrementElicitationParams = { threadId: string, };
|
||||
@@ -0,0 +1,5 @@
|
||||
// GENERATED CODE! DO NOT MODIFY BY HAND!
|
||||
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
|
||||
export type ThreadDecrementElicitationResponse = { count: bigint, paused: boolean, };
|
||||
@@ -0,0 +1,5 @@
|
||||
// GENERATED CODE! DO NOT MODIFY BY HAND!
|
||||
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
|
||||
export type ThreadIncrementElicitationParams = { threadId: string, };
|
||||
@@ -0,0 +1,5 @@
|
||||
// GENERATED CODE! DO NOT MODIFY BY HAND!
|
||||
|
||||
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
|
||||
|
||||
export type ThreadIncrementElicitationResponse = { count: bigint, paused: boolean, };
|
||||
@@ -152,8 +152,12 @@ export type { ThreadArchiveResponse } from "./ThreadArchiveResponse";
|
||||
export type { ThreadArchivedNotification } from "./ThreadArchivedNotification";
|
||||
export type { ThreadCompactStartParams } from "./ThreadCompactStartParams";
|
||||
export type { ThreadCompactStartResponse } from "./ThreadCompactStartResponse";
|
||||
export type { ThreadDecrementElicitationParams } from "./ThreadDecrementElicitationParams";
|
||||
export type { ThreadDecrementElicitationResponse } from "./ThreadDecrementElicitationResponse";
|
||||
export type { ThreadForkParams } from "./ThreadForkParams";
|
||||
export type { ThreadForkResponse } from "./ThreadForkResponse";
|
||||
export type { ThreadIncrementElicitationParams } from "./ThreadIncrementElicitationParams";
|
||||
export type { ThreadIncrementElicitationResponse } from "./ThreadIncrementElicitationResponse";
|
||||
export type { ThreadItem } from "./ThreadItem";
|
||||
export type { ThreadListParams } from "./ThreadListParams";
|
||||
export type { ThreadListResponse } from "./ThreadListResponse";
|
||||
|
||||
@@ -202,6 +202,14 @@ client_request_definitions! {
|
||||
params: v2::ThreadArchiveParams,
|
||||
response: v2::ThreadArchiveResponse,
|
||||
},
|
||||
ThreadIncrementElicitation => "thread/increment_elicitation" {
|
||||
params: v2::ThreadIncrementElicitationParams,
|
||||
response: v2::ThreadIncrementElicitationResponse,
|
||||
},
|
||||
ThreadDecrementElicitation => "thread/decrement_elicitation" {
|
||||
params: v2::ThreadDecrementElicitationParams,
|
||||
response: v2::ThreadDecrementElicitationResponse,
|
||||
},
|
||||
ThreadSetName => "thread/name/set" {
|
||||
params: v2::ThreadSetNameParams,
|
||||
response: v2::ThreadSetNameResponse,
|
||||
|
||||
@@ -1692,6 +1692,36 @@ pub struct ThreadArchiveParams {
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadArchiveResponse {}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadIncrementElicitationParams {
|
||||
pub thread_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadIncrementElicitationResponse {
|
||||
pub count: u64,
|
||||
pub paused: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadDecrementElicitationParams {
|
||||
pub thread_id: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
pub struct ThreadDecrementElicitationResponse {
|
||||
pub count: u64,
|
||||
pub paused: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[ts(export_to = "v2/")]
|
||||
|
||||
@@ -122,6 +122,8 @@ Example with notification opt-out:
|
||||
- `thread/read` — read a stored thread by id without resuming it; optionally include turns via `includeTurns`. The returned `thread` includes `status` (`ThreadStatus`), defaulting to `notLoaded` when the thread is not currently loaded.
|
||||
- `thread/status/changed` — notification emitted when a loaded thread’s status changes (`threadId` + new `status`).
|
||||
- `thread/archive` — move a thread’s rollout file into the archived directory; returns `{}` on success and emits `thread/archived`.
|
||||
- `thread/increment_elicitation` — atomically increment a thread-scoped out-of-band elicitation counter and pause shell-tool timeout stopwatches while the counter is above zero; returns `{ count, paused }`.
|
||||
- `thread/decrement_elicitation` — atomically decrement that counter and resume shell-tool timeout stopwatches when it reaches zero; returns `{ count, paused }`.
|
||||
- `thread/name/set` — set or update a thread’s user-facing name; returns `{}` on success. Thread names are not required to be unique; name lookups resolve to the most recently updated thread.
|
||||
- `thread/unarchive` — move an archived rollout file back into the sessions directory; returns the restored `thread` on success and emits `thread/unarchived`.
|
||||
- `thread/compact/start` — trigger conversation history compaction for a thread; returns `{}` immediately while progress streams through standard turn/item notifications.
|
||||
@@ -309,6 +311,18 @@ Use `thread/unarchive` to move an archived rollout back into the sessions direct
|
||||
{ "method": "thread/unarchived", "params": { "threadId": "thr_b" } }
|
||||
```
|
||||
|
||||
### Example: Pause command timeout stopwatches for out-of-band approval
|
||||
|
||||
Use `thread/increment_elicitation` and `thread/decrement_elicitation` to bracket external approval flows (for example, a downstream CLI waiting on user approval). While `count > 0`, active shell-tool timeout stopwatches are paused.
|
||||
|
||||
```json
|
||||
{ "method": "thread/increment_elicitation", "id": 24, "params": { "threadId": "thr_b" } }
|
||||
{ "id": 24, "result": { "count": 1, "paused": true } }
|
||||
|
||||
{ "method": "thread/decrement_elicitation", "id": 25, "params": { "threadId": "thr_b" } }
|
||||
{ "id": 25, "result": { "count": 0, "paused": false } }
|
||||
```
|
||||
|
||||
### Example: Trigger thread compaction
|
||||
|
||||
Use `thread/compact/start` to trigger manual history compaction for a thread. The request returns immediately with `{}`.
|
||||
|
||||
@@ -128,8 +128,12 @@ use codex_app_server_protocol::ThreadBackgroundTerminalsCleanParams;
|
||||
use codex_app_server_protocol::ThreadBackgroundTerminalsCleanResponse;
|
||||
use codex_app_server_protocol::ThreadCompactStartParams;
|
||||
use codex_app_server_protocol::ThreadCompactStartResponse;
|
||||
use codex_app_server_protocol::ThreadDecrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadDecrementElicitationResponse;
|
||||
use codex_app_server_protocol::ThreadForkParams;
|
||||
use codex_app_server_protocol::ThreadForkResponse;
|
||||
use codex_app_server_protocol::ThreadIncrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadIncrementElicitationResponse;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadListParams;
|
||||
use codex_app_server_protocol::ThreadListResponse;
|
||||
@@ -544,6 +548,14 @@ impl CodexMessageProcessor {
|
||||
self.thread_archive(to_connection_request_id(request_id), params)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ThreadIncrementElicitation { request_id, params } => {
|
||||
self.thread_increment_elicitation(to_connection_request_id(request_id), params)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ThreadDecrementElicitation { request_id, params } => {
|
||||
self.thread_decrement_elicitation(to_connection_request_id(request_id), params)
|
||||
.await;
|
||||
}
|
||||
ClientRequest::ThreadSetName { request_id, params } => {
|
||||
self.thread_set_name(to_connection_request_id(request_id), params)
|
||||
.await;
|
||||
@@ -2169,6 +2181,79 @@ impl CodexMessageProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
async fn thread_increment_elicitation(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
params: ThreadIncrementElicitationParams,
|
||||
) {
|
||||
let (_, thread) = match self.load_thread(¶ms.thread_id).await {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match thread.increment_out_of_band_elicitation_count().await {
|
||||
Ok(count) => {
|
||||
self.outgoing
|
||||
.send_response(
|
||||
request_id,
|
||||
ThreadIncrementElicitationResponse {
|
||||
count,
|
||||
paused: count > 0,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(err) => {
|
||||
self.send_internal_error(
|
||||
request_id,
|
||||
format!("failed to increment out-of-band elicitation counter: {err}"),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn thread_decrement_elicitation(
|
||||
&self,
|
||||
request_id: ConnectionRequestId,
|
||||
params: ThreadDecrementElicitationParams,
|
||||
) {
|
||||
let (_, thread) = match self.load_thread(¶ms.thread_id).await {
|
||||
Ok(value) => value,
|
||||
Err(error) => {
|
||||
self.outgoing.send_error(request_id, error).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match thread.decrement_out_of_band_elicitation_count().await {
|
||||
Ok(count) => {
|
||||
self.outgoing
|
||||
.send_response(
|
||||
request_id,
|
||||
ThreadDecrementElicitationResponse {
|
||||
count,
|
||||
paused: count > 0,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Err(CodexErr::InvalidRequest(message)) => {
|
||||
self.send_invalid_request_error(request_id, message).await;
|
||||
}
|
||||
Err(err) => {
|
||||
self.send_internal_error(
|
||||
request_id,
|
||||
format!("failed to decrement out-of-band elicitation counter: {err}"),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn thread_set_name(&self, request_id: ConnectionRequestId, params: ThreadSetNameParams) {
|
||||
let ThreadSetNameParams { thread_id, name } = params;
|
||||
let Some(name) = codex_core::util::normalize_thread_name(&name) else {
|
||||
|
||||
@@ -28,6 +28,7 @@ pub use models_cache::write_models_cache;
|
||||
pub use models_cache::write_models_cache_with_models;
|
||||
pub use responses::create_apply_patch_sse_response;
|
||||
pub use responses::create_exec_command_sse_response;
|
||||
pub use responses::create_exec_command_sse_response_for_command;
|
||||
pub use responses::create_final_assistant_message_sse_response;
|
||||
pub use responses::create_request_user_input_sse_response;
|
||||
pub use responses::create_shell_command_sse_response;
|
||||
|
||||
@@ -53,7 +53,9 @@ use codex_app_server_protocol::SetDefaultModelParams;
|
||||
use codex_app_server_protocol::SkillsListParams;
|
||||
use codex_app_server_protocol::ThreadArchiveParams;
|
||||
use codex_app_server_protocol::ThreadCompactStartParams;
|
||||
use codex_app_server_protocol::ThreadDecrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadForkParams;
|
||||
use codex_app_server_protocol::ThreadIncrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadListParams;
|
||||
use codex_app_server_protocol::ThreadLoadedListParams;
|
||||
use codex_app_server_protocol::ThreadReadParams;
|
||||
@@ -472,6 +474,26 @@ impl McpProcess {
|
||||
self.send_request("thread/read", params).await
|
||||
}
|
||||
|
||||
/// Send a `thread/increment_elicitation` JSON-RPC request.
|
||||
pub async fn send_thread_increment_elicitation_request(
|
||||
&mut self,
|
||||
params: ThreadIncrementElicitationParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("thread/increment_elicitation", params)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send a `thread/decrement_elicitation` JSON-RPC request.
|
||||
pub async fn send_thread_decrement_elicitation_request(
|
||||
&mut self,
|
||||
params: ThreadDecrementElicitationParams,
|
||||
) -> anyhow::Result<i64> {
|
||||
let params = Some(serde_json::to_value(params)?);
|
||||
self.send_request("thread/decrement_elicitation", params)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send a `model/list` JSON-RPC request.
|
||||
pub async fn send_list_models_request(
|
||||
&mut self,
|
||||
|
||||
@@ -50,9 +50,18 @@ pub fn create_exec_command_sse_response(call_id: &str) -> anyhow::Result<String>
|
||||
let command = std::iter::once(cmd.to_string())
|
||||
.chain(args.into_iter().map(str::to_string))
|
||||
.collect::<Vec<_>>();
|
||||
create_exec_command_sse_response_for_command(command, 500, call_id)
|
||||
}
|
||||
|
||||
pub fn create_exec_command_sse_response_for_command(
|
||||
command: Vec<String>,
|
||||
yield_time_ms: u64,
|
||||
call_id: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let command_str = shlex::try_join(command.iter().map(String::as_str))?;
|
||||
let tool_call_arguments = serde_json::to_string(&json!({
|
||||
"cmd": command.join(" "),
|
||||
"yield_time_ms": 500
|
||||
"cmd": command_str,
|
||||
"yield_time_ms": yield_time_ms
|
||||
}))?;
|
||||
Ok(responses::sse(vec![
|
||||
responses::ev_response_created("resp-1"),
|
||||
|
||||
77
codex-rs/app-server/tests/fixtures/elicitation_stopwatch/orchestrator.py
vendored
Normal file
77
codex-rs/app-server/tests/fixtures/elicitation_stopwatch/orchestrator.py
vendored
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
REQUESTED_FILENAME = "elicitation_requested"
|
||||
RELEASE_FILENAME = "elicitation_release"
|
||||
|
||||
|
||||
def requested_path(state_dir: Path) -> Path:
|
||||
return state_dir / REQUESTED_FILENAME
|
||||
|
||||
|
||||
def release_path(state_dir: Path) -> Path:
|
||||
return state_dir / RELEASE_FILENAME
|
||||
|
||||
|
||||
def cmd_wait_for_request(state_dir: Path, timeout_seconds: float) -> int:
|
||||
deadline = time.monotonic() + timeout_seconds
|
||||
while time.monotonic() < deadline:
|
||||
if requested_path(state_dir).exists():
|
||||
return 0
|
||||
time.sleep(0.05)
|
||||
|
||||
print(
|
||||
f"timed out waiting for {requested_path(state_dir)}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
|
||||
def cmd_release(state_dir: Path) -> int:
|
||||
release_path(state_dir).write_text("approved\n", encoding="utf-8")
|
||||
return 0
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--state-dir",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Directory shared with the elicitation trigger script.",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
wait_parser = subparsers.add_parser("wait-for-request")
|
||||
wait_parser.add_argument(
|
||||
"--timeout-seconds",
|
||||
type=float,
|
||||
default=5.0,
|
||||
)
|
||||
|
||||
subparsers.add_parser("release")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
state_dir: Path = args.state_dir
|
||||
state_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.command == "wait-for-request":
|
||||
return cmd_wait_for_request(state_dir, args.timeout_seconds)
|
||||
|
||||
if args.command == "release":
|
||||
return cmd_release(state_dir)
|
||||
|
||||
print(f"unsupported command: {args.command}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
43
codex-rs/app-server/tests/fixtures/elicitation_stopwatch/trigger_elicitation.py
vendored
Normal file
43
codex-rs/app-server/tests/fixtures/elicitation_stopwatch/trigger_elicitation.py
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
REQUESTED_FILENAME = "elicitation_requested"
|
||||
RELEASE_FILENAME = "elicitation_release"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--state-dir",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Directory shared with the test orchestrator.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
state_dir: Path = args.state_dir
|
||||
state_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
requested = state_dir / REQUESTED_FILENAME
|
||||
release = state_dir / RELEASE_FILENAME
|
||||
|
||||
requested.write_text(f"pid={os.getpid()}\n", encoding="utf-8")
|
||||
print("waited for a user approval", file=sys.stderr, flush=True)
|
||||
|
||||
while not release.exists():
|
||||
time.sleep(0.05)
|
||||
|
||||
print("approval received", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
@@ -2,6 +2,7 @@ use anyhow::Result;
|
||||
use app_test_support::McpProcess;
|
||||
use app_test_support::create_apply_patch_sse_response;
|
||||
use app_test_support::create_exec_command_sse_response;
|
||||
use app_test_support::create_exec_command_sse_response_for_command;
|
||||
use app_test_support::create_fake_rollout;
|
||||
use app_test_support::create_final_assistant_message_sse_response;
|
||||
use app_test_support::create_mock_responses_server_sequence;
|
||||
@@ -26,6 +27,10 @@ use codex_app_server_protocol::PatchChangeKind;
|
||||
use codex_app_server_protocol::RequestId;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use codex_app_server_protocol::TextElement;
|
||||
use codex_app_server_protocol::ThreadDecrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadDecrementElicitationResponse;
|
||||
use codex_app_server_protocol::ThreadIncrementElicitationParams;
|
||||
use codex_app_server_protocol::ThreadIncrementElicitationResponse;
|
||||
use codex_app_server_protocol::ThreadItem;
|
||||
use codex_app_server_protocol::ThreadStartParams;
|
||||
use codex_app_server_protocol::ThreadStartResponse;
|
||||
@@ -45,12 +50,14 @@ use codex_protocol::config_types::ModeKind;
|
||||
use codex_protocol::config_types::Personality;
|
||||
use codex_protocol::config_types::Settings;
|
||||
use codex_protocol::openai_models::ReasoningEffort;
|
||||
use codex_utils_cargo_bin::find_resource;
|
||||
use core_test_support::responses;
|
||||
use core_test_support::skip_if_no_network;
|
||||
use pretty_assertions::assert_eq;
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::Path;
|
||||
use tempfile::TempDir;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[cfg(windows)]
|
||||
@@ -1716,6 +1723,207 @@ async fn turn_start_file_change_approval_decline_v2() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg_attr(
|
||||
windows,
|
||||
ignore = "relies on local Python fixture scripts and POSIX unified exec timing"
|
||||
)]
|
||||
async fn thread_elicitation_pauses_unified_exec_stopwatch() -> Result<()> {
|
||||
let tmp = TempDir::new()?;
|
||||
let codex_home = tmp.path().join("codex_home");
|
||||
std::fs::create_dir(&codex_home)?;
|
||||
let workspace = tmp.path().join("workspace");
|
||||
std::fs::create_dir(&workspace)?;
|
||||
let state_dir = tmp.path().join("elicitation_state");
|
||||
std::fs::create_dir(&state_dir)?;
|
||||
|
||||
let orchestrator_script =
|
||||
find_resource!("tests/fixtures/elicitation_stopwatch/orchestrator.py")?;
|
||||
let trigger_script =
|
||||
find_resource!("tests/fixtures/elicitation_stopwatch/trigger_elicitation.py")?;
|
||||
|
||||
let responses = vec![
|
||||
create_exec_command_sse_response_for_command(
|
||||
vec![
|
||||
"python3".to_string(),
|
||||
trigger_script.to_string_lossy().to_string(),
|
||||
"--state-dir".to_string(),
|
||||
state_dir.to_string_lossy().to_string(),
|
||||
],
|
||||
30_000,
|
||||
"uexec-elicitation",
|
||||
)?,
|
||||
create_final_assistant_message_sse_response("done")?,
|
||||
];
|
||||
let server = create_mock_responses_server_sequence(responses).await;
|
||||
create_config_toml_with_sandbox(
|
||||
&codex_home,
|
||||
&server.uri(),
|
||||
"never",
|
||||
&BTreeMap::from([(Feature::UnifiedExec, true)]),
|
||||
"danger-full-access",
|
||||
)?;
|
||||
|
||||
let mut mcp = McpProcess::new(&codex_home).await?;
|
||||
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
|
||||
|
||||
let start_id = mcp
|
||||
.send_thread_start_request(ThreadStartParams {
|
||||
model: Some("mock-model".to_string()),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
let start_resp: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(start_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(start_resp)?;
|
||||
|
||||
let turn_id = mcp
|
||||
.send_turn_start_request(TurnStartParams {
|
||||
thread_id: thread.id.clone(),
|
||||
input: vec![V2UserInput::Text {
|
||||
text: "run the local approval fixture".to_string(),
|
||||
text_elements: Vec::new(),
|
||||
}],
|
||||
cwd: Some(workspace.clone()),
|
||||
sandbox_policy: Some(codex_app_server_protocol::SandboxPolicy::DangerFullAccess),
|
||||
..Default::default()
|
||||
})
|
||||
.await?;
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(turn_id)),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let started_command = timeout(DEFAULT_READ_TIMEOUT, async {
|
||||
loop {
|
||||
let notif = mcp
|
||||
.read_stream_until_notification_message("item/started")
|
||||
.await?;
|
||||
let started: ItemStartedNotification = serde_json::from_value(
|
||||
notif
|
||||
.params
|
||||
.clone()
|
||||
.expect("item/started should include params"),
|
||||
)?;
|
||||
if let ThreadItem::CommandExecution { .. } = started.item {
|
||||
return Ok::<ThreadItem, anyhow::Error>(started.item);
|
||||
}
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let ThreadItem::CommandExecution { id, status, .. } = started_command else {
|
||||
unreachable!("loop ensures we break on command execution items");
|
||||
};
|
||||
assert_eq!(id, "uexec-elicitation");
|
||||
assert_eq!(status, CommandExecutionStatus::InProgress);
|
||||
|
||||
run_elicitation_orchestrator(
|
||||
&orchestrator_script,
|
||||
&state_dir,
|
||||
"wait-for-request",
|
||||
Some(5),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let increment_request_id = mcp
|
||||
.send_thread_increment_elicitation_request(ThreadIncrementElicitationParams {
|
||||
thread_id: thread.id.clone(),
|
||||
})
|
||||
.await?;
|
||||
let increment_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(increment_request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadIncrementElicitationResponse { count, paused } =
|
||||
to_response::<ThreadIncrementElicitationResponse>(increment_response)?;
|
||||
assert_eq!(count, 1);
|
||||
assert!(paused);
|
||||
|
||||
// Hold longer than the default 10s unified-exec timeout. If the stopwatch is not paused,
|
||||
// the command exits/times out and the model will issue the second /responses request.
|
||||
assert!(
|
||||
timeout(
|
||||
std::time::Duration::from_secs(11),
|
||||
read_completed_command_execution_item(&mut mcp),
|
||||
)
|
||||
.await
|
||||
.is_err(),
|
||||
"command execution should remain in progress while elicitation is active"
|
||||
);
|
||||
let requests_during_pause = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch received requests while paused");
|
||||
assert_eq!(
|
||||
requests_during_pause.len(),
|
||||
1,
|
||||
"unexpected extra inference request while elicitation pause was active"
|
||||
);
|
||||
|
||||
let decrement_request_id = mcp
|
||||
.send_thread_decrement_elicitation_request(ThreadDecrementElicitationParams {
|
||||
thread_id: thread.id.clone(),
|
||||
})
|
||||
.await?;
|
||||
let decrement_response: JSONRPCResponse = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_response_message(RequestId::Integer(decrement_request_id)),
|
||||
)
|
||||
.await??;
|
||||
let ThreadDecrementElicitationResponse { count, paused } =
|
||||
to_response::<ThreadDecrementElicitationResponse>(decrement_response)?;
|
||||
assert_eq!(count, 0);
|
||||
assert!(!paused);
|
||||
|
||||
run_elicitation_orchestrator(&orchestrator_script, &state_dir, "release", None).await?;
|
||||
|
||||
let completed_command = timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
read_completed_command_execution_item(&mut mcp),
|
||||
)
|
||||
.await??;
|
||||
let ThreadItem::CommandExecution {
|
||||
id,
|
||||
status,
|
||||
exit_code,
|
||||
aggregated_output,
|
||||
..
|
||||
} = completed_command
|
||||
else {
|
||||
unreachable!("helper only returns command execution items");
|
||||
};
|
||||
assert_eq!(id, "uexec-elicitation");
|
||||
assert_eq!(status, CommandExecutionStatus::Completed);
|
||||
assert_eq!(exit_code, Some(0));
|
||||
let aggregated_output = aggregated_output.expect("expected command output");
|
||||
assert!(aggregated_output.contains("waited for a user approval"));
|
||||
assert!(aggregated_output.contains("approval received"));
|
||||
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("codex/event/task_complete"),
|
||||
)
|
||||
.await??;
|
||||
timeout(
|
||||
DEFAULT_READ_TIMEOUT,
|
||||
mcp.read_stream_until_notification_message("turn/completed"),
|
||||
)
|
||||
.await??;
|
||||
|
||||
let requests_after_resume = server
|
||||
.received_requests()
|
||||
.await
|
||||
.expect("failed to fetch received requests after resume");
|
||||
assert_eq!(requests_after_resume.len(), 2);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg_attr(windows, ignore = "process id reporting differs on Windows")]
|
||||
async fn command_execution_notifications_include_process_id() -> Result<()> {
|
||||
@@ -1853,6 +2061,55 @@ async fn command_execution_notifications_include_process_id() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_elicitation_orchestrator(
|
||||
orchestrator_script: &Path,
|
||||
state_dir: &Path,
|
||||
command: &str,
|
||||
timeout_seconds: Option<u64>,
|
||||
) -> Result<()> {
|
||||
let mut cmd = Command::new("python3");
|
||||
cmd.arg(orchestrator_script)
|
||||
.arg("--state-dir")
|
||||
.arg(state_dir)
|
||||
.arg(command);
|
||||
|
||||
if let Some(timeout_seconds) = timeout_seconds {
|
||||
cmd.arg("--timeout-seconds")
|
||||
.arg(timeout_seconds.to_string());
|
||||
}
|
||||
|
||||
let output = cmd.output().await?;
|
||||
if output.status.success() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
anyhow::bail!(
|
||||
"orchestrator command `{command}` failed with status {:?}\nstdout:\n{}\nstderr:\n{}",
|
||||
output.status.code(),
|
||||
stdout,
|
||||
stderr,
|
||||
);
|
||||
}
|
||||
|
||||
async fn read_completed_command_execution_item(mcp: &mut McpProcess) -> Result<ThreadItem> {
|
||||
loop {
|
||||
let notif = mcp
|
||||
.read_stream_until_notification_message("item/completed")
|
||||
.await?;
|
||||
let completed: ItemCompletedNotification = serde_json::from_value(
|
||||
notif
|
||||
.params
|
||||
.clone()
|
||||
.expect("item/completed should include params"),
|
||||
)?;
|
||||
if let ThreadItem::CommandExecution { .. } = completed.item {
|
||||
return Ok(completed.item);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a config.toml pointing at the mock model server.
|
||||
fn create_config_toml(
|
||||
codex_home: &Path,
|
||||
|
||||
@@ -2399,6 +2399,20 @@ impl Session {
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn notify_out_of_band_elicitation_pause_state(
|
||||
&self,
|
||||
paused: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
self.services
|
||||
.mcp_connection_manager
|
||||
.read()
|
||||
.await
|
||||
.notify_out_of_band_elicitation_state_change(&crate::OutOfBandElicitationState {
|
||||
paused,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
/// Records input items: always append to conversation history and
|
||||
/// persist these response items to rollout.
|
||||
pub(crate) async fn record_conversation_items(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::agent::AgentStatus;
|
||||
use crate::codex::Codex;
|
||||
use crate::codex::SteerInputError;
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CodexResult;
|
||||
use crate::features::Feature;
|
||||
use crate::file_watcher::WatchRegistration;
|
||||
@@ -18,6 +19,7 @@ use codex_protocol::protocol::SessionSource;
|
||||
use codex_protocol::protocol::TokenUsage;
|
||||
use codex_protocol::user_input::UserInput;
|
||||
use std::path::PathBuf;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::watch;
|
||||
|
||||
use crate::state_db::StateDbHandle;
|
||||
@@ -37,6 +39,7 @@ pub struct ThreadConfigSnapshot {
|
||||
pub struct CodexThread {
|
||||
pub(crate) codex: Codex,
|
||||
rollout_path: Option<PathBuf>,
|
||||
out_of_band_elicitation_count: Mutex<u64>,
|
||||
_watch_registration: WatchRegistration,
|
||||
}
|
||||
|
||||
@@ -51,6 +54,7 @@ impl CodexThread {
|
||||
Self {
|
||||
codex,
|
||||
rollout_path,
|
||||
out_of_band_elicitation_count: Mutex::new(0),
|
||||
_watch_registration: watch_registration,
|
||||
}
|
||||
}
|
||||
@@ -130,4 +134,53 @@ impl CodexThread {
|
||||
pub fn enabled(&self, feature: Feature) -> bool {
|
||||
self.codex.enabled(feature)
|
||||
}
|
||||
|
||||
pub async fn increment_out_of_band_elicitation_count(&self) -> CodexResult<u64> {
|
||||
let mut guard = self.out_of_band_elicitation_count.lock().await;
|
||||
let was_zero = *guard == 0;
|
||||
*guard = guard.checked_add(1).ok_or_else(|| {
|
||||
CodexErr::Fatal("out-of-band elicitation count overflowed".to_string())
|
||||
})?;
|
||||
|
||||
if was_zero
|
||||
&& let Err(err) = self
|
||||
.codex
|
||||
.session
|
||||
.notify_out_of_band_elicitation_pause_state(true)
|
||||
.await
|
||||
{
|
||||
*guard -= 1;
|
||||
return Err(CodexErr::Fatal(format!(
|
||||
"failed to pause out-of-band elicitation state: {err:#}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(*guard)
|
||||
}
|
||||
|
||||
pub async fn decrement_out_of_band_elicitation_count(&self) -> CodexResult<u64> {
|
||||
let mut guard = self.out_of_band_elicitation_count.lock().await;
|
||||
if *guard == 0 {
|
||||
return Err(CodexErr::InvalidRequest(
|
||||
"out-of-band elicitation count is already zero".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
*guard -= 1;
|
||||
let now_zero = *guard == 0;
|
||||
if now_zero
|
||||
&& let Err(err) = self
|
||||
.codex
|
||||
.session
|
||||
.notify_out_of_band_elicitation_pause_state(false)
|
||||
.await
|
||||
{
|
||||
*guard += 1;
|
||||
return Err(CodexErr::Fatal(format!(
|
||||
"failed to resume out-of-band elicitation state: {err:#}"
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(*guard)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,8 +44,11 @@ mod mcp_connection_manager;
|
||||
pub mod models_manager;
|
||||
mod network_policy_decision;
|
||||
pub mod network_proxy_loader;
|
||||
pub use mcp_connection_manager::MCP_OUT_OF_BAND_ELICITATION_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_OUT_OF_BAND_ELICITATION_STATE_METHOD;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
pub use mcp_connection_manager::MCP_SANDBOX_STATE_METHOD;
|
||||
pub use mcp_connection_manager::OutOfBandElicitationState;
|
||||
pub use mcp_connection_manager::SandboxState;
|
||||
mod mcp_tool_call;
|
||||
mod memories;
|
||||
|
||||
@@ -289,6 +289,7 @@ struct ManagedClient {
|
||||
tool_filter: ToolFilter,
|
||||
tool_timeout: Option<Duration>,
|
||||
server_supports_sandbox_state_capability: bool,
|
||||
server_supports_out_of_band_elicitation_state_capability: bool,
|
||||
}
|
||||
|
||||
impl ManagedClient {
|
||||
@@ -307,6 +308,24 @@ impl ManagedClient {
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn notify_out_of_band_elicitation_state_change(
|
||||
&self,
|
||||
state: &OutOfBandElicitationState,
|
||||
) -> Result<()> {
|
||||
if !self.server_supports_out_of_band_elicitation_state_capability {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let _response = self
|
||||
.client
|
||||
.send_custom_request(
|
||||
MCP_OUT_OF_BAND_ELICITATION_STATE_METHOD,
|
||||
Some(serde_json::to_value(state)?),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -360,13 +379,27 @@ impl AsyncManagedClient {
|
||||
let managed = self.client().await?;
|
||||
managed.notify_sandbox_state_change(sandbox_state).await
|
||||
}
|
||||
|
||||
async fn notify_out_of_band_elicitation_state_change(
|
||||
&self,
|
||||
state: &OutOfBandElicitationState,
|
||||
) -> Result<()> {
|
||||
let managed = self.client().await?;
|
||||
managed
|
||||
.notify_out_of_band_elicitation_state_change(state)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
pub const MCP_SANDBOX_STATE_CAPABILITY: &str = "codex/sandbox-state";
|
||||
pub const MCP_OUT_OF_BAND_ELICITATION_STATE_CAPABILITY: &str =
|
||||
"codex/out-of-band-elicitation-state";
|
||||
|
||||
/// Custom MCP request to push sandbox state updates.
|
||||
/// When used, the `params` field of the notification is [`SandboxState`].
|
||||
pub const MCP_SANDBOX_STATE_METHOD: &str = "codex/sandbox-state/update";
|
||||
pub const MCP_OUT_OF_BAND_ELICITATION_STATE_METHOD: &str =
|
||||
"codex/out-of-band-elicitation-state/update";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -378,6 +411,12 @@ pub struct SandboxState {
|
||||
pub use_linux_sandbox_bwrap: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct OutOfBandElicitationState {
|
||||
pub paused: bool,
|
||||
}
|
||||
|
||||
/// A thin wrapper around a set of running [`RmcpClient`] instances.
|
||||
pub(crate) struct McpConnectionManager {
|
||||
clients: HashMap<String, AsyncManagedClient>,
|
||||
@@ -885,6 +924,41 @@ impl McpConnectionManager {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn notify_out_of_band_elicitation_state_change(
|
||||
&self,
|
||||
state: &OutOfBandElicitationState,
|
||||
) -> Result<()> {
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for async_managed_client in self.clients.values() {
|
||||
let state = state.clone();
|
||||
let async_managed_client = async_managed_client.clone();
|
||||
join_set.spawn(async move {
|
||||
async_managed_client
|
||||
.notify_out_of_band_elicitation_state_change(&state)
|
||||
.await
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(join_res) = join_set.join_next().await {
|
||||
match join_res {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => {
|
||||
warn!(
|
||||
"Failed to notify out-of-band elicitation state change to MCP server: {err:#}",
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Task panic when notifying out-of-band elicitation state change to MCP server: {err:#}",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_update(
|
||||
@@ -1123,6 +1197,12 @@ async fn start_server_task(
|
||||
.as_ref()
|
||||
.and_then(|exp| exp.get(MCP_SANDBOX_STATE_CAPABILITY))
|
||||
.is_some();
|
||||
let server_supports_out_of_band_elicitation_state_capability = initialize_result
|
||||
.capabilities
|
||||
.experimental
|
||||
.as_ref()
|
||||
.and_then(|exp| exp.get(MCP_OUT_OF_BAND_ELICITATION_STATE_CAPABILITY))
|
||||
.is_some();
|
||||
|
||||
let managed = ManagedClient {
|
||||
client: Arc::clone(&client),
|
||||
@@ -1130,6 +1210,7 @@ async fn start_server_task(
|
||||
tool_timeout: Some(tool_timeout),
|
||||
tool_filter,
|
||||
server_supports_sandbox_state_capability,
|
||||
server_supports_out_of_band_elicitation_state_capability,
|
||||
};
|
||||
|
||||
Ok(managed)
|
||||
|
||||
@@ -82,6 +82,7 @@ mod mcp;
|
||||
mod mcp_escalation_policy;
|
||||
mod socket;
|
||||
mod stopwatch;
|
||||
mod stopwatch_controller;
|
||||
|
||||
pub use mcp::ExecResult;
|
||||
|
||||
|
||||
@@ -4,8 +4,11 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use anyhow::Result;
|
||||
use codex_core::MCP_OUT_OF_BAND_ELICITATION_STATE_CAPABILITY;
|
||||
use codex_core::MCP_OUT_OF_BAND_ELICITATION_STATE_METHOD;
|
||||
use codex_core::MCP_SANDBOX_STATE_CAPABILITY;
|
||||
use codex_core::MCP_SANDBOX_STATE_METHOD;
|
||||
use codex_core::OutOfBandElicitationState;
|
||||
use codex_core::SandboxState;
|
||||
use codex_core::protocol::SandboxPolicy;
|
||||
use codex_execpolicy::Policy;
|
||||
@@ -32,11 +35,13 @@ use crate::posix::escalate_server::EscalateServer;
|
||||
use crate::posix::escalate_server::{self};
|
||||
use crate::posix::mcp_escalation_policy::McpEscalationPolicy;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
use crate::posix::stopwatch_controller::StopwatchController;
|
||||
|
||||
/// Path to our patched bash.
|
||||
const CODEX_BASH_PATH_ENV_VAR: &str = "CODEX_BASH_PATH";
|
||||
|
||||
const SANDBOX_STATE_CAPABILITY_VERSION: &str = "1.0.0";
|
||||
const OUT_OF_BAND_ELICITATION_STATE_CAPABILITY_VERSION: &str = "1.0.0";
|
||||
|
||||
pub(crate) fn get_bash_path() -> Result<PathBuf> {
|
||||
std::env::var(CODEX_BASH_PATH_ENV_VAR)
|
||||
@@ -83,6 +88,7 @@ pub struct ExecTool {
|
||||
policy: Arc<RwLock<Policy>>,
|
||||
preserve_program_paths: bool,
|
||||
sandbox_state: Arc<RwLock<Option<SandboxState>>>,
|
||||
stopwatch_controller: StopwatchController,
|
||||
}
|
||||
|
||||
#[tool_router]
|
||||
@@ -100,6 +106,7 @@ impl ExecTool {
|
||||
policy,
|
||||
preserve_program_paths,
|
||||
sandbox_state: Arc::new(RwLock::new(None)),
|
||||
stopwatch_controller: StopwatchController::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,6 +123,7 @@ impl ExecTool {
|
||||
.unwrap_or(codex_core::exec::DEFAULT_EXEC_COMMAND_TIMEOUT_MS),
|
||||
);
|
||||
let stopwatch = Stopwatch::new(effective_timeout);
|
||||
let stopwatch_id = self.stopwatch_controller.register(stopwatch.clone()).await;
|
||||
let cancel_token = stopwatch.cancellation_token();
|
||||
let sandbox_state =
|
||||
self.sandbox_state
|
||||
@@ -141,8 +149,9 @@ impl ExecTool {
|
||||
|
||||
let result = escalate_server
|
||||
.exec(params, cancel_token, &sandbox_state)
|
||||
.await
|
||||
.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
.await;
|
||||
self.stopwatch_controller.unregister(stopwatch_id).await;
|
||||
let result = result.map_err(|e| McpError::internal_error(e.to_string(), None))?;
|
||||
Ok(CallToolResult::success(vec![Content::json(
|
||||
ExecResult::from(result),
|
||||
)?]))
|
||||
@@ -169,6 +178,15 @@ impl ServerHandler for ExecTool {
|
||||
MCP_SANDBOX_STATE_CAPABILITY.to_string(),
|
||||
sandbox_state_capability,
|
||||
);
|
||||
let mut out_of_band_elicitation_state_capability = JsonObject::new();
|
||||
out_of_band_elicitation_state_capability.insert(
|
||||
"version".to_string(),
|
||||
serde_json::Value::String(OUT_OF_BAND_ELICITATION_STATE_CAPABILITY_VERSION.to_string()),
|
||||
);
|
||||
experimental_capabilities.insert(
|
||||
MCP_OUT_OF_BAND_ELICITATION_STATE_CAPABILITY.to_string(),
|
||||
out_of_band_elicitation_state_capability,
|
||||
);
|
||||
ServerInfo {
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
capabilities: ServerCapabilities::builder()
|
||||
@@ -197,27 +215,47 @@ impl ServerHandler for ExecTool {
|
||||
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
|
||||
) -> Result<CustomResult, McpError> {
|
||||
let CustomRequest { method, params, .. } = request;
|
||||
if method != MCP_SANDBOX_STATE_METHOD {
|
||||
return Err(McpError::method_not_found::<CodexSandboxStateUpdateMethod>());
|
||||
match method.as_ref() {
|
||||
MCP_SANDBOX_STATE_METHOD => {
|
||||
let Some(params) = params else {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing params for sandbox state request".to_string(),
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
let Ok(sandbox_state) = serde_json::from_value::<SandboxState>(params.clone())
|
||||
else {
|
||||
return Err(McpError::invalid_params(
|
||||
"failed to deserialize sandbox state".to_string(),
|
||||
Some(params),
|
||||
));
|
||||
};
|
||||
|
||||
*self.sandbox_state.write().await = Some(sandbox_state);
|
||||
Ok(CustomResult::new(json!({})))
|
||||
}
|
||||
MCP_OUT_OF_BAND_ELICITATION_STATE_METHOD => {
|
||||
let Some(params) = params else {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing params for out-of-band elicitation state request".to_string(),
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
let Ok(state) = serde_json::from_value::<OutOfBandElicitationState>(params.clone())
|
||||
else {
|
||||
return Err(McpError::invalid_params(
|
||||
"failed to deserialize out-of-band elicitation state".to_string(),
|
||||
Some(params),
|
||||
));
|
||||
};
|
||||
|
||||
self.stopwatch_controller.set_paused(state.paused).await;
|
||||
Ok(CustomResult::new(json!({})))
|
||||
}
|
||||
_ => Err(McpError::method_not_found::<CodexSandboxStateUpdateMethod>()),
|
||||
}
|
||||
|
||||
let Some(params) = params else {
|
||||
return Err(McpError::invalid_params(
|
||||
"missing params for sandbox state request".to_string(),
|
||||
None,
|
||||
));
|
||||
};
|
||||
|
||||
let Ok(sandbox_state) = serde_json::from_value::<SandboxState>(params.clone()) else {
|
||||
return Err(McpError::invalid_params(
|
||||
"failed to deserialize sandbox state".to_string(),
|
||||
Some(params),
|
||||
));
|
||||
};
|
||||
|
||||
*self.sandbox_state.write().await = Some(sandbox_state);
|
||||
|
||||
Ok(CustomResult::new(json!({})))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ impl Stopwatch {
|
||||
result
|
||||
}
|
||||
|
||||
async fn pause(&self) {
|
||||
pub(crate) async fn pause(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
guard.active_pauses += 1;
|
||||
if guard.active_pauses == 1
|
||||
@@ -101,7 +101,7 @@ impl Stopwatch {
|
||||
}
|
||||
}
|
||||
|
||||
async fn resume(&self) {
|
||||
pub(crate) async fn resume(&self) {
|
||||
let mut guard = self.inner.lock().await;
|
||||
if guard.active_pauses == 0 {
|
||||
return;
|
||||
|
||||
143
codex-rs/exec-server/src/posix/stopwatch_controller.rs
Normal file
143
codex-rs/exec-server/src/posix/stopwatch_controller.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub(crate) struct StopwatchController {
|
||||
state: Arc<Mutex<StopwatchControllerState>>,
|
||||
operation_lock: Arc<Mutex<()>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct StopwatchControllerState {
|
||||
paused: bool,
|
||||
next_stopwatch_id: u64,
|
||||
stopwatches: HashMap<u64, Stopwatch>,
|
||||
}
|
||||
|
||||
impl StopwatchController {
|
||||
pub(crate) async fn register(&self, stopwatch: Stopwatch) -> u64 {
|
||||
let _operation_guard = self.operation_lock.lock().await;
|
||||
let (stopwatch_id, paused) = {
|
||||
let mut guard = self.state.lock().await;
|
||||
let stopwatch_id = guard.next_stopwatch_id;
|
||||
guard.next_stopwatch_id += 1;
|
||||
guard.stopwatches.insert(stopwatch_id, stopwatch.clone());
|
||||
(stopwatch_id, guard.paused)
|
||||
};
|
||||
|
||||
if paused {
|
||||
stopwatch.pause().await;
|
||||
}
|
||||
|
||||
stopwatch_id
|
||||
}
|
||||
|
||||
pub(crate) async fn unregister(&self, stopwatch_id: u64) {
|
||||
let _operation_guard = self.operation_lock.lock().await;
|
||||
let (stopwatch, paused) = {
|
||||
let mut guard = self.state.lock().await;
|
||||
(guard.stopwatches.remove(&stopwatch_id), guard.paused)
|
||||
};
|
||||
|
||||
if paused && let Some(stopwatch) = stopwatch {
|
||||
stopwatch.resume().await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn set_paused(&self, paused: bool) {
|
||||
let _operation_guard = self.operation_lock.lock().await;
|
||||
let stopwatches = {
|
||||
let mut guard = self.state.lock().await;
|
||||
if guard.paused == paused {
|
||||
return;
|
||||
}
|
||||
guard.paused = paused;
|
||||
guard.stopwatches.values().cloned().collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
for stopwatch in stopwatches {
|
||||
if paused {
|
||||
stopwatch.pause().await;
|
||||
} else {
|
||||
stopwatch.resume().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::StopwatchController;
|
||||
use crate::posix::stopwatch::Stopwatch;
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pausing_controller_pauses_registered_stopwatch() {
|
||||
let controller = StopwatchController::default();
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
|
||||
let stopwatch_id = controller.register(stopwatch).await;
|
||||
controller.set_paused(true).await;
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(30), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
controller.set_paused(false).await;
|
||||
controller.unregister(stopwatch_id).await;
|
||||
token.cancelled().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registering_while_paused_starts_paused() {
|
||||
let controller = StopwatchController::default();
|
||||
controller.set_paused(true).await;
|
||||
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
|
||||
let stopwatch_id = controller.register(stopwatch).await;
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(30), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
controller.set_paused(false).await;
|
||||
controller.unregister(stopwatch_id).await;
|
||||
token.cancelled().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unregistering_while_paused_resumes_controller_pause() {
|
||||
let controller = StopwatchController::default();
|
||||
let stopwatch = Stopwatch::new(Duration::from_millis(50));
|
||||
let token = stopwatch.cancellation_token();
|
||||
|
||||
let stopwatch_id = controller.register(stopwatch).await;
|
||||
controller.set_paused(true).await;
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(30), token.cancelled())
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
controller.unregister(stopwatch_id).await;
|
||||
|
||||
assert!(
|
||||
timeout(Duration::from_millis(120), token.cancelled())
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user