mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
fixes for opt-out
This commit is contained in:
@@ -395,7 +395,6 @@ pub async fn run_main_with_transport(
|
||||
);
|
||||
continue;
|
||||
};
|
||||
let was_initialized = session.initialized;
|
||||
let pre_synced_initialize =
|
||||
!session.initialized && request.method == "initialize";
|
||||
if pre_synced_initialize
|
||||
@@ -411,17 +410,10 @@ pub async fn run_main_with_transport(
|
||||
.process_request(connection_id, request, session)
|
||||
.await;
|
||||
|
||||
if pre_synced_initialize
|
||||
&& !session.initialized
|
||||
&& let Some(connection_state) =
|
||||
connections.lock().await.get_mut(&connection_id)
|
||||
if let Some(connection_state) =
|
||||
connections.lock().await.get_mut(&connection_id)
|
||||
{
|
||||
connection_state.session.initialized = false;
|
||||
} else if session.initialized != was_initialized
|
||||
&& let Some(connection_state) =
|
||||
connections.lock().await.get_mut(&connection_id)
|
||||
{
|
||||
connection_state.session.initialized = session.initialized;
|
||||
connection_state.session = session.clone();
|
||||
}
|
||||
}
|
||||
JSONRPCMessage::Response(response) => {
|
||||
@@ -457,7 +449,7 @@ pub async fn run_main_with_transport(
|
||||
break;
|
||||
};
|
||||
let mut connections = connections.lock().await;
|
||||
route_outgoing_envelope(&mut connections, envelope).await;
|
||||
route_outgoing_envelope(&mut connections, envelope);
|
||||
}
|
||||
created = thread_created_rx.recv(), if listen_for_threads => {
|
||||
match created {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashSet;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
@@ -114,10 +115,11 @@ pub(crate) struct MessageProcessor {
|
||||
config_warnings: Arc<Vec<ConfigWarningNotification>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub(crate) struct ConnectionSessionState {
|
||||
pub(crate) initialized: bool,
|
||||
experimental_api_enabled: bool,
|
||||
pub(crate) opted_out_notification_methods: HashSet<String>,
|
||||
}
|
||||
|
||||
pub(crate) struct MessageProcessorArgs {
|
||||
@@ -256,9 +258,8 @@ impl MessageProcessor {
|
||||
None => (false, Vec::new()),
|
||||
};
|
||||
session.experimental_api_enabled = experimental_api_enabled;
|
||||
self.outgoing
|
||||
.set_opted_out_notification_methods(opt_out_notification_methods)
|
||||
.await;
|
||||
session.opted_out_notification_methods =
|
||||
opt_out_notification_methods.into_iter().collect();
|
||||
let ClientInfo {
|
||||
name,
|
||||
title: _title,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
@@ -47,7 +46,6 @@ pub(crate) struct OutgoingMessageSender {
|
||||
next_server_request_id: AtomicI64,
|
||||
sender: mpsc::Sender<OutgoingEnvelope>,
|
||||
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
|
||||
opted_out_notification_methods: Mutex<HashSet<String>>,
|
||||
}
|
||||
|
||||
impl OutgoingMessageSender {
|
||||
@@ -56,21 +54,9 @@ impl OutgoingMessageSender {
|
||||
next_server_request_id: AtomicI64::new(0),
|
||||
sender,
|
||||
request_id_to_callback: Mutex::new(HashMap::new()),
|
||||
opted_out_notification_methods: Mutex::new(HashSet::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec<String>) {
|
||||
let mut opted_out = self.opted_out_notification_methods.lock().await;
|
||||
opted_out.clear();
|
||||
opted_out.extend(methods);
|
||||
}
|
||||
|
||||
async fn should_skip_notification(&self, method: &str) -> bool {
|
||||
let opted_out = self.opted_out_notification_methods.lock().await;
|
||||
opted_out.contains(method)
|
||||
}
|
||||
|
||||
pub(crate) async fn send_request(
|
||||
&self,
|
||||
request: ServerRequestPayload,
|
||||
@@ -186,10 +172,6 @@ impl OutgoingMessageSender {
|
||||
}
|
||||
|
||||
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
|
||||
let method = notification.to_string();
|
||||
if self.should_skip_notification(&method).await {
|
||||
return;
|
||||
}
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
.send(OutgoingEnvelope::Broadcast {
|
||||
@@ -204,12 +186,6 @@ impl OutgoingMessageSender {
|
||||
/// All notifications should be migrated to [`ServerNotification`] and
|
||||
/// [`OutgoingMessage::Notification`] should be removed.
|
||||
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
|
||||
if self
|
||||
.should_skip_notification(notification.method.as_str())
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
let outgoing_message = OutgoingMessage::Notification(notification);
|
||||
if let Err(err) = self
|
||||
.sender
|
||||
|
||||
@@ -413,7 +413,27 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<Strin
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn route_outgoing_envelope(
|
||||
fn should_skip_notification_for_connection(
|
||||
connection_state: &ConnectionState,
|
||||
message: &OutgoingMessage,
|
||||
) -> bool {
|
||||
match message {
|
||||
OutgoingMessage::AppServerNotification(notification) => {
|
||||
let method = notification.to_string();
|
||||
connection_state
|
||||
.session
|
||||
.opted_out_notification_methods
|
||||
.contains(method.as_str())
|
||||
}
|
||||
OutgoingMessage::Notification(notification) => connection_state
|
||||
.session
|
||||
.opted_out_notification_methods
|
||||
.contains(notification.method.as_str()),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn route_outgoing_envelope(
|
||||
connections: &mut HashMap<ConnectionId, ConnectionState>,
|
||||
envelope: OutgoingEnvelope,
|
||||
) {
|
||||
@@ -429,15 +449,27 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
);
|
||||
return;
|
||||
};
|
||||
if connection_state.writer.send(message).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
match connection_state.writer.try_send(message) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
|
||||
connections.remove(&connection_id);
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"dropping slow connection with full outgoing queue: {:?}",
|
||||
connection_id
|
||||
);
|
||||
connections.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutgoingEnvelope::Broadcast { message } => {
|
||||
let target_connections: Vec<ConnectionId> = connections
|
||||
.iter()
|
||||
.filter_map(|(connection_id, connection_state)| {
|
||||
if connection_state.session.initialized {
|
||||
if connection_state.session.initialized
|
||||
&& !should_skip_notification_for_connection(connection_state, &message)
|
||||
{
|
||||
Some(*connection_id)
|
||||
} else {
|
||||
None
|
||||
@@ -449,8 +481,18 @@ pub(crate) async fn route_outgoing_envelope(
|
||||
let Some(connection_state) = connections.get(&connection_id) else {
|
||||
continue;
|
||||
};
|
||||
if connection_state.writer.send(message.clone()).await.is_err() {
|
||||
connections.remove(&connection_id);
|
||||
match connection_state.writer.try_send(message.clone()) {
|
||||
Ok(()) => {}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
|
||||
connections.remove(&connection_id);
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
|
||||
warn!(
|
||||
"dropping slow connection with full outgoing queue: {:?}",
|
||||
connection_id
|
||||
);
|
||||
connections.remove(&connection_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user