mirror of
https://github.com/openai/codex.git
synced 2026-04-19 20:24:50 +00:00
Compare commits
2 Commits
xl/plugins
...
ruslan/ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1038c7dd83 | ||
|
|
e8f7382a28 |
@@ -7,7 +7,7 @@
|
||||
//! deriving client identity from the typed [`ClientRequest`] rather than
|
||||
//! from a parsed JSON envelope.
|
||||
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::message_processor::ConnectionState;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::transport::AppServerTransport;
|
||||
use codex_app_server_protocol::ClientRequest;
|
||||
@@ -25,7 +25,7 @@ pub(crate) fn request_span(
|
||||
request: &JSONRPCRequest,
|
||||
transport: AppServerTransport,
|
||||
connection_id: ConnectionId,
|
||||
session: &ConnectionSessionState,
|
||||
session: &ConnectionState,
|
||||
) -> Span {
|
||||
let initialize_client_info = initialize_client_info(request);
|
||||
let method = request.method.as_str();
|
||||
@@ -62,7 +62,7 @@ pub(crate) fn request_span(
|
||||
pub(crate) fn typed_request_span(
|
||||
request: &ClientRequest,
|
||||
connection_id: ConnectionId,
|
||||
session: &ConnectionSessionState,
|
||||
session: &ConnectionState,
|
||||
) -> Span {
|
||||
let method = request.method();
|
||||
let span = app_server_request_span_template(&method, "in-process", request.id(), connection_id);
|
||||
@@ -142,7 +142,7 @@ fn attach_parent_context(
|
||||
|
||||
fn client_name<'a>(
|
||||
initialize_client_info: Option<&'a InitializeParams>,
|
||||
session: &'a ConnectionSessionState,
|
||||
session: &'a ConnectionState,
|
||||
) -> Option<&'a str> {
|
||||
if let Some(params) = initialize_client_info {
|
||||
return Some(params.client_info.name.as_str());
|
||||
@@ -152,7 +152,7 @@ fn client_name<'a>(
|
||||
|
||||
fn client_version<'a>(
|
||||
initialize_client_info: Option<&'a InitializeParams>,
|
||||
session: &'a ConnectionSessionState,
|
||||
session: &'a ConnectionState,
|
||||
) -> Option<&'a str> {
|
||||
if let Some(params) = initialize_client_info {
|
||||
return Some(params.client_info.version.as_str());
|
||||
|
||||
@@ -39,21 +39,17 @@
|
||||
//! helpers, surface-specific startup policy, and bounded shutdown.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::io::Error as IoError;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::error_code::INTERNAL_ERROR_CODE;
|
||||
use crate::error_code::INVALID_REQUEST_ERROR_CODE;
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::message_processor::ConnectionState;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::message_processor::MessageProcessorArgs;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
@@ -360,18 +356,14 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
let outgoing_message_sender = Arc::new(OutgoingMessageSender::new(outgoing_tx));
|
||||
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel::<QueuedOutgoingMessage>(channel_capacity);
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_experimental_api_enabled = Arc::new(AtomicBool::new(false));
|
||||
let outbound_opted_out_notification_methods = Arc::new(RwLock::new(HashSet::new()));
|
||||
let connection_state = Arc::new(ConnectionState::default());
|
||||
|
||||
let mut outbound_connections = HashMap::<ConnectionId, OutboundConnectionState>::new();
|
||||
outbound_connections.insert(
|
||||
IN_PROCESS_CONNECTION_ID,
|
||||
OutboundConnectionState::new(
|
||||
writer_tx,
|
||||
Arc::clone(&outbound_initialized),
|
||||
Arc::clone(&outbound_experimental_api_enabled),
|
||||
Arc::clone(&outbound_opted_out_notification_methods),
|
||||
Arc::clone(&connection_state),
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -403,7 +395,6 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
remote_control_handle: None,
|
||||
}));
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let session = Arc::new(ConnectionSessionState::default());
|
||||
let mut listen_for_threads = true;
|
||||
|
||||
loop {
|
||||
@@ -411,32 +402,15 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
command = processor_rx.recv() => {
|
||||
match command {
|
||||
Some(ProcessorCommand::Request(request)) => {
|
||||
let was_initialized = session.initialized();
|
||||
let was_initialized = connection_state.initialized();
|
||||
processor
|
||||
.process_client_request(
|
||||
IN_PROCESS_CONNECTION_ID,
|
||||
*request,
|
||||
Arc::clone(&session),
|
||||
&outbound_initialized,
|
||||
Arc::clone(&connection_state),
|
||||
)
|
||||
.await;
|
||||
let opted_out_notification_methods_snapshot =
|
||||
session.opted_out_notification_methods();
|
||||
let experimental_api_enabled =
|
||||
session.experimental_api_enabled();
|
||||
let is_initialized = session.initialized();
|
||||
if let Ok(mut opted_out_notification_methods) =
|
||||
outbound_opted_out_notification_methods.write()
|
||||
{
|
||||
*opted_out_notification_methods =
|
||||
opted_out_notification_methods_snapshot;
|
||||
} else {
|
||||
warn!("failed to update outbound opted-out notifications");
|
||||
}
|
||||
outbound_experimental_api_enabled.store(
|
||||
experimental_api_enabled,
|
||||
Ordering::Release,
|
||||
);
|
||||
let is_initialized = connection_state.initialized();
|
||||
if !was_initialized && is_initialized {
|
||||
processor.send_initialize_notifications().await;
|
||||
}
|
||||
@@ -452,7 +426,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle {
|
||||
created = thread_created_rx.recv(), if listen_for_threads => {
|
||||
match created {
|
||||
Ok(thread_id) => {
|
||||
let connection_ids = if session.initialized() {
|
||||
let connection_ids = if connection_state.initialized() {
|
||||
vec![IN_PROCESS_CONNECTION_ID]
|
||||
} else {
|
||||
Vec::<ConnectionId>::new()
|
||||
|
||||
@@ -11,13 +11,11 @@ use codex_features::Feature;
|
||||
use codex_login::AuthManager;
|
||||
use codex_utils_cli::CliConfigOverrides;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::Result as IoResult;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use crate::message_processor::ConnectionState;
|
||||
use crate::message_processor::MessageProcessor;
|
||||
use crate::message_processor::MessageProcessorArgs;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
@@ -25,7 +23,6 @@ use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingMessageSender;
|
||||
use crate::outgoing_message::QueuedOutgoingMessage;
|
||||
use crate::transport::CHANNEL_CAPACITY;
|
||||
use crate::transport::ConnectionState;
|
||||
use crate::transport::OutboundConnectionState;
|
||||
use crate::transport::TransportEvent;
|
||||
use crate::transport::auth::policy_from_settings;
|
||||
@@ -117,9 +114,7 @@ enum OutboundControlEvent {
|
||||
connection_id: ConnectionId,
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
connection_state: Arc<ConnectionState>,
|
||||
},
|
||||
/// Remove state for a closed/disconnected connection.
|
||||
Closed { connection_id: ConnectionId },
|
||||
@@ -604,17 +599,13 @@ pub async fn run_main_with_transport(
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
connection_state,
|
||||
} => {
|
||||
outbound_connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
writer,
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
connection_state,
|
||||
disconnect_sender,
|
||||
),
|
||||
);
|
||||
@@ -670,7 +661,7 @@ pub async fn run_main_with_transport(
|
||||
}));
|
||||
let mut thread_created_rx = processor.thread_created_receiver();
|
||||
let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count();
|
||||
let mut connections = HashMap::<ConnectionId, ConnectionState>::new();
|
||||
let mut connections = HashMap::<ConnectionId, Arc<ConnectionState>>::new();
|
||||
let transport_shutdown_token = transport_shutdown_token.clone();
|
||||
async move {
|
||||
let mut listen_for_threads = true;
|
||||
@@ -714,37 +705,20 @@ pub async fn run_main_with_transport(
|
||||
writer,
|
||||
disconnect_sender,
|
||||
} => {
|
||||
let outbound_initialized = Arc::new(AtomicBool::new(false));
|
||||
let outbound_experimental_api_enabled =
|
||||
Arc::new(AtomicBool::new(false));
|
||||
let outbound_opted_out_notification_methods =
|
||||
Arc::new(RwLock::new(HashSet::new()));
|
||||
let connection_state = Arc::new(ConnectionState::default());
|
||||
if outbound_control_tx
|
||||
.send(OutboundControlEvent::Opened {
|
||||
connection_id,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
initialized: Arc::clone(&outbound_initialized),
|
||||
experimental_api_enabled: Arc::clone(
|
||||
&outbound_experimental_api_enabled,
|
||||
),
|
||||
opted_out_notification_methods: Arc::clone(
|
||||
&outbound_opted_out_notification_methods,
|
||||
),
|
||||
connection_state: Arc::clone(&connection_state),
|
||||
})
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
connections.insert(
|
||||
connection_id,
|
||||
ConnectionState::new(
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
),
|
||||
);
|
||||
connections.insert(connection_id, connection_state);
|
||||
}
|
||||
TransportEvent::ConnectionClosed { connection_id } => {
|
||||
if connections.remove(&connection_id).is_none() {
|
||||
@@ -765,43 +739,21 @@ pub async fn run_main_with_transport(
|
||||
TransportEvent::IncomingMessage { connection_id, message } => {
|
||||
match message {
|
||||
JSONRPCMessage::Request(request) => {
|
||||
let Some(connection_state) = connections.get_mut(&connection_id) else {
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
warn!("dropping request from unknown connection: {connection_id:?}");
|
||||
continue;
|
||||
};
|
||||
let was_initialized =
|
||||
connection_state.session.initialized();
|
||||
connection_state.initialized();
|
||||
processor
|
||||
.process_request(
|
||||
connection_id,
|
||||
request,
|
||||
transport,
|
||||
Arc::clone(&connection_state.session),
|
||||
Arc::clone(connection_state),
|
||||
)
|
||||
.await;
|
||||
let opted_out_notification_methods_snapshot = connection_state
|
||||
.session
|
||||
.opted_out_notification_methods();
|
||||
let experimental_api_enabled =
|
||||
connection_state.session.experimental_api_enabled();
|
||||
let is_initialized = connection_state.session.initialized();
|
||||
if let Ok(mut opted_out_notification_methods) = connection_state
|
||||
.outbound_opted_out_notification_methods
|
||||
.write()
|
||||
{
|
||||
*opted_out_notification_methods =
|
||||
opted_out_notification_methods_snapshot;
|
||||
} else {
|
||||
warn!(
|
||||
"failed to update outbound opted-out notifications"
|
||||
);
|
||||
}
|
||||
connection_state
|
||||
.outbound_experimental_api_enabled
|
||||
.store(
|
||||
experimental_api_enabled,
|
||||
std::sync::atomic::Ordering::Release,
|
||||
);
|
||||
let is_initialized = connection_state.initialized();
|
||||
if !was_initialized && is_initialized {
|
||||
processor
|
||||
.send_initialize_notifications_to_connection(
|
||||
@@ -844,7 +796,7 @@ pub async fn run_main_with_transport(
|
||||
Ok(thread_id) => {
|
||||
let mut initialized_connection_ids = Vec::new();
|
||||
for (connection_id, connection_state) in &connections {
|
||||
if connection_state.session.initialized() {
|
||||
if connection_state.initialized() {
|
||||
initialized_connection_ids.push(*connection_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,19 +176,20 @@ pub(crate) struct MessageProcessor {
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct ConnectionSessionState {
|
||||
initialized: OnceLock<InitializedConnectionSessionState>,
|
||||
pub(crate) struct ConnectionState {
|
||||
initialized: OnceLock<InitializedConnectionState>,
|
||||
pub(crate) outbound_initialized: AtomicBool,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InitializedConnectionSessionState {
|
||||
struct InitializedConnectionState {
|
||||
experimental_api_enabled: bool,
|
||||
opted_out_notification_methods: HashSet<String>,
|
||||
app_server_client_name: String,
|
||||
client_version: String,
|
||||
}
|
||||
|
||||
impl ConnectionSessionState {
|
||||
impl ConnectionState {
|
||||
pub(crate) fn initialized(&self) -> bool {
|
||||
self.initialized.get().is_some()
|
||||
}
|
||||
@@ -218,9 +219,27 @@ impl ConnectionSessionState {
|
||||
.map(|session| session.client_version.as_str())
|
||||
}
|
||||
|
||||
fn initialize(&self, session: InitializedConnectionSessionState) -> Result<(), ()> {
|
||||
fn initialize(&self, session: InitializedConnectionState) -> Result<(), ()> {
|
||||
self.initialized.set(session).map_err(|_| ())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new_test(
|
||||
experimental_api_enabled: bool,
|
||||
opted_out_notification_methods: HashSet<String>,
|
||||
) -> Self {
|
||||
let state = Self::default();
|
||||
state
|
||||
.initialize(InitializedConnectionState {
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
app_server_client_name: "codex-app-server-tests".to_string(),
|
||||
client_version: "0.1.0".to_string(),
|
||||
})
|
||||
.expect("test connection state should initialize once");
|
||||
state.outbound_initialized.store(true, Ordering::Release);
|
||||
state
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct MessageProcessorArgs {
|
||||
@@ -344,7 +363,7 @@ impl MessageProcessor {
|
||||
connection_id: ConnectionId,
|
||||
request: JSONRPCRequest,
|
||||
transport: AppServerTransport,
|
||||
session: Arc<ConnectionSessionState>,
|
||||
connection_state: Arc<ConnectionState>,
|
||||
) {
|
||||
let request_method = request.method.as_str();
|
||||
tracing::trace!(
|
||||
@@ -356,8 +375,12 @@ impl MessageProcessor {
|
||||
connection_id,
|
||||
request_id: request.id.clone(),
|
||||
};
|
||||
let request_span =
|
||||
crate::app_server_tracing::request_span(&request, transport, connection_id, &session);
|
||||
let request_span = crate::app_server_tracing::request_span(
|
||||
&request,
|
||||
transport,
|
||||
connection_id,
|
||||
&connection_state,
|
||||
);
|
||||
let request_trace = request.trace.as_ref().map(|trace| W3cTraceContext {
|
||||
traceparent: trace.traceparent.clone(),
|
||||
tracestate: trace.tracestate.clone(),
|
||||
@@ -392,15 +415,13 @@ impl MessageProcessor {
|
||||
return;
|
||||
}
|
||||
};
|
||||
// Websocket callers finalize outbound readiness in lib.rs after mirroring
|
||||
// session state into outbound state and sending initialize notifications to
|
||||
// this specific connection. Passing `None` avoids marking the connection
|
||||
// ready too early from inside the shared request handler.
|
||||
// Websocket callers finalize outbound readiness in lib.rs after sending
|
||||
// initialize notifications to this specific connection.
|
||||
self.handle_client_request(
|
||||
request_id.clone(),
|
||||
codex_request,
|
||||
Arc::clone(&session),
|
||||
/*outbound_initialized*/ None,
|
||||
Arc::clone(&connection_state),
|
||||
/*initialize_outbound*/ false,
|
||||
request_context.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -417,15 +438,17 @@ impl MessageProcessor {
|
||||
self: &Arc<Self>,
|
||||
connection_id: ConnectionId,
|
||||
request: ClientRequest,
|
||||
session: Arc<ConnectionSessionState>,
|
||||
outbound_initialized: &AtomicBool,
|
||||
connection_state: Arc<ConnectionState>,
|
||||
) {
|
||||
let request_id = ConnectionRequestId {
|
||||
connection_id,
|
||||
request_id: request.id().clone(),
|
||||
};
|
||||
let request_span =
|
||||
crate::app_server_tracing::typed_request_span(&request, connection_id, &session);
|
||||
let request_span = crate::app_server_tracing::typed_request_span(
|
||||
&request,
|
||||
connection_id,
|
||||
&connection_state,
|
||||
);
|
||||
let request_context =
|
||||
RequestContext::new(request_id.clone(), request_span, /*parent_trace*/ None);
|
||||
tracing::trace!(
|
||||
@@ -443,8 +466,8 @@ impl MessageProcessor {
|
||||
self.handle_client_request(
|
||||
request_id.clone(),
|
||||
request,
|
||||
Arc::clone(&session),
|
||||
Some(outbound_initialized),
|
||||
Arc::clone(&connection_state),
|
||||
/*initialize_outbound*/ true,
|
||||
request_context.clone(),
|
||||
)
|
||||
.await;
|
||||
@@ -569,11 +592,11 @@ impl MessageProcessor {
|
||||
self: &Arc<Self>,
|
||||
connection_request_id: ConnectionRequestId,
|
||||
codex_request: ClientRequest,
|
||||
session: Arc<ConnectionSessionState>,
|
||||
// `Some(...)` means the caller wants initialize to immediately mark the
|
||||
// connection outbound-ready. Websocket JSON-RPC calls pass `None` so
|
||||
connection_state: Arc<ConnectionState>,
|
||||
// `true` means the caller wants initialize to immediately mark the
|
||||
// connection outbound-ready. Websocket JSON-RPC calls pass `false` so
|
||||
// lib.rs can deliver connection-scoped initialize notifications first.
|
||||
outbound_initialized: Option<&AtomicBool>,
|
||||
initialize_outbound: bool,
|
||||
request_context: RequestContext,
|
||||
) {
|
||||
let connection_id = connection_request_id.connection_id;
|
||||
@@ -584,7 +607,7 @@ impl MessageProcessor {
|
||||
connection_id,
|
||||
request_id,
|
||||
};
|
||||
if session.initialized() {
|
||||
if connection_state.initialized() {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "Already initialized".to_string(),
|
||||
@@ -634,8 +657,8 @@ impl MessageProcessor {
|
||||
let originator = name.clone();
|
||||
let user_agent_suffix = format!("{name}; {version}");
|
||||
let codex_home = self.config.codex_home.clone();
|
||||
if session
|
||||
.initialize(InitializedConnectionSessionState {
|
||||
if connection_state
|
||||
.initialize(InitializedConnectionState {
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods: opt_out_notification_methods
|
||||
.into_iter()
|
||||
@@ -697,11 +720,13 @@ impl MessageProcessor {
|
||||
.send_response(connection_request_id, response)
|
||||
.await;
|
||||
|
||||
if let Some(outbound_initialized) = outbound_initialized {
|
||||
if initialize_outbound {
|
||||
// In-process clients can complete readiness immediately here. The
|
||||
// websocket path defers this until lib.rs finishes transport-layer
|
||||
// initialize handling for the specific connection.
|
||||
outbound_initialized.store(true, Ordering::Release);
|
||||
connection_state
|
||||
.outbound_initialized
|
||||
.store(true, Ordering::Release);
|
||||
self.codex_message_processor
|
||||
.connection_initialized(connection_id)
|
||||
.await;
|
||||
@@ -712,7 +737,7 @@ impl MessageProcessor {
|
||||
self.dispatch_initialized_client_request(
|
||||
connection_request_id,
|
||||
codex_request,
|
||||
session,
|
||||
connection_state,
|
||||
request_context,
|
||||
)
|
||||
.await;
|
||||
@@ -722,12 +747,10 @@ impl MessageProcessor {
|
||||
self: &Arc<Self>,
|
||||
connection_request_id: ConnectionRequestId,
|
||||
codex_request: ClientRequest,
|
||||
session: Arc<ConnectionSessionState>,
|
||||
connection_state: Arc<ConnectionState>,
|
||||
request_context: RequestContext,
|
||||
) {
|
||||
let connection_id = connection_request_id.connection_id;
|
||||
|
||||
if !session.initialized() {
|
||||
if !connection_state.initialized() {
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
message: "Not initialized".to_string(),
|
||||
@@ -738,7 +761,7 @@ impl MessageProcessor {
|
||||
}
|
||||
|
||||
if let Some(reason) = codex_request.experimental_reason()
|
||||
&& !session.experimental_api_enabled()
|
||||
&& !connection_state.experimental_api_enabled()
|
||||
{
|
||||
let error = JSONRPCErrorError {
|
||||
code: INVALID_REQUEST_ERROR_CODE,
|
||||
@@ -753,14 +776,16 @@ impl MessageProcessor {
|
||||
| ClientRequest::TurnSteer { request_id, .. } = &codex_request
|
||||
{
|
||||
self.analytics_events_client.track_request(
|
||||
connection_id.0,
|
||||
connection_request_id.connection_id.0,
|
||||
request_id.clone(),
|
||||
codex_request.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
let app_server_client_name = session.app_server_client_name().map(str::to_string);
|
||||
let client_version = session.client_version().map(str::to_string);
|
||||
let app_server_client_name = connection_state
|
||||
.app_server_client_name()
|
||||
.map(str::to_string);
|
||||
let client_version = connection_state.client_version().map(str::to_string);
|
||||
Arc::clone(self)
|
||||
.handle_initialized_client_request(
|
||||
connection_request_id,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::ConnectionSessionState;
|
||||
use super::ConnectionState;
|
||||
use super::MessageProcessor;
|
||||
use super::MessageProcessorArgs;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
@@ -111,7 +111,7 @@ struct TracingHarness {
|
||||
_codex_home: TempDir,
|
||||
processor: Arc<MessageProcessor>,
|
||||
outgoing_rx: mpsc::Receiver<crate::outgoing_message::OutgoingEnvelope>,
|
||||
session: Arc<ConnectionSessionState>,
|
||||
session: Arc<ConnectionState>,
|
||||
tracing: &'static TestTracing,
|
||||
}
|
||||
|
||||
@@ -129,7 +129,7 @@ impl TracingHarness {
|
||||
_codex_home: codex_home,
|
||||
processor,
|
||||
outgoing_rx,
|
||||
session: Arc::new(ConnectionSessionState::default()),
|
||||
session: Arc::new(ConnectionState::default()),
|
||||
tracing,
|
||||
};
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
pub(crate) mod auth;
|
||||
|
||||
use crate::error_code::OVERLOADED_ERROR_CODE;
|
||||
use crate::message_processor::ConnectionSessionState;
|
||||
use crate::message_processor::ConnectionState;
|
||||
use crate::outgoing_message::ConnectionId;
|
||||
use crate::outgoing_message::OutgoingEnvelope;
|
||||
use crate::outgoing_message::OutgoingError;
|
||||
@@ -11,12 +11,9 @@ use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_app_server_protocol::JSONRPCMessage;
|
||||
use codex_app_server_protocol::ServerRequest;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::net::SocketAddr;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::atomic::Ordering;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -117,32 +114,8 @@ pub(crate) enum TransportEvent {
|
||||
},
|
||||
}
|
||||
|
||||
pub(crate) struct ConnectionState {
|
||||
pub(crate) outbound_initialized: Arc<AtomicBool>,
|
||||
pub(crate) outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) session: Arc<ConnectionSessionState>,
|
||||
}
|
||||
|
||||
impl ConnectionState {
|
||||
pub(crate) fn new(
|
||||
outbound_initialized: Arc<AtomicBool>,
|
||||
outbound_experimental_api_enabled: Arc<AtomicBool>,
|
||||
outbound_opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
outbound_initialized,
|
||||
outbound_experimental_api_enabled,
|
||||
outbound_opted_out_notification_methods,
|
||||
session: Arc::new(ConnectionSessionState::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct OutboundConnectionState {
|
||||
pub(crate) initialized: Arc<AtomicBool>,
|
||||
pub(crate) experimental_api_enabled: Arc<AtomicBool>,
|
||||
pub(crate) opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
pub(crate) connection_state: Arc<ConnectionState>,
|
||||
pub(crate) writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
}
|
||||
@@ -150,15 +123,11 @@ pub(crate) struct OutboundConnectionState {
|
||||
impl OutboundConnectionState {
|
||||
pub(crate) fn new(
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
initialized: Arc<AtomicBool>,
|
||||
experimental_api_enabled: Arc<AtomicBool>,
|
||||
opted_out_notification_methods: Arc<RwLock<HashSet<String>>>,
|
||||
connection_state: Arc<ConnectionState>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
) -> Self {
|
||||
Self {
|
||||
initialized,
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
connection_state,
|
||||
writer,
|
||||
disconnect_sender,
|
||||
}
|
||||
@@ -260,11 +229,9 @@ fn should_skip_notification_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: &OutgoingMessage,
|
||||
) -> bool {
|
||||
let Ok(opted_out_notification_methods) = connection_state.opted_out_notification_methods.read()
|
||||
else {
|
||||
warn!("failed to read outbound opted-out notifications");
|
||||
return false;
|
||||
};
|
||||
let opted_out_notification_methods = connection_state
|
||||
.connection_state
|
||||
.opted_out_notification_methods();
|
||||
match message {
|
||||
OutgoingMessage::AppServerNotification(notification) => {
|
||||
let method = notification.to_string();
|
||||
@@ -329,9 +296,7 @@ fn filter_outgoing_message_for_connection(
|
||||
connection_state: &OutboundConnectionState,
|
||||
message: OutgoingMessage,
|
||||
) -> OutgoingMessage {
|
||||
let experimental_api_enabled = connection_state
|
||||
.experimental_api_enabled
|
||||
.load(Ordering::Acquire);
|
||||
let experimental_api_enabled = connection_state.connection_state.experimental_api_enabled();
|
||||
match message {
|
||||
OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval {
|
||||
request_id,
|
||||
@@ -367,7 +332,10 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
if connection_state.initialized.load(Ordering::Acquire)
|
||||
if connection_state
|
||||
.connection_state
|
||||
.outbound_initialized
|
||||
.load(Ordering::Acquire)
|
||||
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||
{
|
||||
Some(*connection_id)
|
||||
@@ -402,6 +370,8 @@ mod tests {
|
||||
use codex_utils_absolute_path::AbsolutePathBuf;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use std::collections::HashSet;
|
||||
|
||||
use tokio::time::Duration;
|
||||
use tokio::time::timeout;
|
||||
|
||||
@@ -409,6 +379,23 @@ mod tests {
|
||||
AbsolutePathBuf::from_absolute_path(path).expect("absolute path")
|
||||
}
|
||||
|
||||
fn outbound_connection_state(
|
||||
writer: mpsc::Sender<QueuedOutgoingMessage>,
|
||||
experimental_api_enabled: bool,
|
||||
opted_out_notification_methods: impl IntoIterator<Item = &'static str>,
|
||||
disconnect_sender: Option<CancellationToken>,
|
||||
) -> OutboundConnectionState {
|
||||
let opted_out_notification_methods = opted_out_notification_methods
|
||||
.into_iter()
|
||||
.map(str::to_string)
|
||||
.collect::<HashSet<_>>();
|
||||
let connection_state = Arc::new(ConnectionState::new_test(
|
||||
experimental_api_enabled,
|
||||
opted_out_notification_methods,
|
||||
));
|
||||
OutboundConnectionState::new(writer, connection_state, disconnect_sender)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn listen_off_parses_as_off_transport() {
|
||||
assert_eq!(
|
||||
@@ -615,18 +602,14 @@ mod tests {
|
||||
async fn to_connection_notification_respects_opt_out_filters() {
|
||||
let connection_id = ConnectionId(7);
|
||||
let (writer_tx, mut writer_rx) = mpsc::channel(1);
|
||||
let initialized = Arc::new(AtomicBool::new(true));
|
||||
let opted_out_notification_methods =
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()])));
|
||||
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
initialized,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
opted_out_notification_methods,
|
||||
/*experimental_api_enabled*/ true,
|
||||
["configWarning"],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -662,11 +645,10 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::from(["configWarning".to_string()]))),
|
||||
/*experimental_api_enabled*/ true,
|
||||
["configWarning"],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -702,11 +684,10 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ true,
|
||||
[],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -748,11 +729,10 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(false)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ false,
|
||||
[],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -810,11 +790,10 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ true,
|
||||
[],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
@@ -887,21 +866,19 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
fast_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
fast_writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ true,
|
||||
[],
|
||||
Some(fast_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
connections.insert(
|
||||
slow_connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
slow_writer_tx.clone(),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ true,
|
||||
[],
|
||||
Some(slow_disconnect_token.clone()),
|
||||
),
|
||||
);
|
||||
@@ -982,11 +959,10 @@ mod tests {
|
||||
let mut connections = HashMap::new();
|
||||
connections.insert(
|
||||
connection_id,
|
||||
OutboundConnectionState::new(
|
||||
outbound_connection_state(
|
||||
writer_tx,
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(AtomicBool::new(true)),
|
||||
Arc::new(RwLock::new(HashSet::new())),
|
||||
/*experimental_api_enabled*/ true,
|
||||
[],
|
||||
/*disconnect_sender*/ None,
|
||||
),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user