fixes for opt-out

This commit is contained in:
Max Johnson
2026-02-10 15:35:16 -08:00
parent 7560a70d4e
commit 9b57222059
4 changed files with 57 additions and 46 deletions

View File

@@ -395,7 +395,6 @@ pub async fn run_main_with_transport(
);
continue;
};
let was_initialized = session.initialized;
let pre_synced_initialize =
!session.initialized && request.method == "initialize";
if pre_synced_initialize
@@ -411,17 +410,10 @@ pub async fn run_main_with_transport(
.process_request(connection_id, request, session)
.await;
if pre_synced_initialize
&& !session.initialized
&& let Some(connection_state) =
connections.lock().await.get_mut(&connection_id)
if let Some(connection_state) =
connections.lock().await.get_mut(&connection_id)
{
connection_state.session.initialized = false;
} else if session.initialized != was_initialized
&& let Some(connection_state) =
connections.lock().await.get_mut(&connection_id)
{
connection_state.session.initialized = session.initialized;
connection_state.session = session.clone();
}
}
JSONRPCMessage::Response(response) => {
@@ -457,7 +449,7 @@ pub async fn run_main_with_transport(
break;
};
let mut connections = connections.lock().await;
route_outgoing_envelope(&mut connections, envelope).await;
route_outgoing_envelope(&mut connections, envelope);
}
created = thread_created_rx.recv(), if listen_for_threads => {
match created {

View File

@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::RwLock;
@@ -114,10 +115,11 @@ pub(crate) struct MessageProcessor {
config_warnings: Arc<Vec<ConfigWarningNotification>>,
}
#[derive(Debug, Default)]
#[derive(Clone, Debug, Default)]
pub(crate) struct ConnectionSessionState {
pub(crate) initialized: bool,
experimental_api_enabled: bool,
pub(crate) opted_out_notification_methods: HashSet<String>,
}
pub(crate) struct MessageProcessorArgs {
@@ -256,9 +258,8 @@ impl MessageProcessor {
None => (false, Vec::new()),
};
session.experimental_api_enabled = experimental_api_enabled;
self.outgoing
.set_opted_out_notification_methods(opt_out_notification_methods)
.await;
session.opted_out_notification_methods =
opt_out_notification_methods.into_iter().collect();
let ClientInfo {
name,
title: _title,

View File

@@ -1,5 +1,4 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::atomic::AtomicI64;
use std::sync::atomic::Ordering;
@@ -47,7 +46,6 @@ pub(crate) struct OutgoingMessageSender {
next_server_request_id: AtomicI64,
sender: mpsc::Sender<OutgoingEnvelope>,
request_id_to_callback: Mutex<HashMap<RequestId, oneshot::Sender<Result>>>,
opted_out_notification_methods: Mutex<HashSet<String>>,
}
impl OutgoingMessageSender {
@@ -56,21 +54,9 @@ impl OutgoingMessageSender {
next_server_request_id: AtomicI64::new(0),
sender,
request_id_to_callback: Mutex::new(HashMap::new()),
opted_out_notification_methods: Mutex::new(HashSet::new()),
}
}
pub(crate) async fn set_opted_out_notification_methods(&self, methods: Vec<String>) {
let mut opted_out = self.opted_out_notification_methods.lock().await;
opted_out.clear();
opted_out.extend(methods);
}
async fn should_skip_notification(&self, method: &str) -> bool {
let opted_out = self.opted_out_notification_methods.lock().await;
opted_out.contains(method)
}
pub(crate) async fn send_request(
&self,
request: ServerRequestPayload,
@@ -186,10 +172,6 @@ impl OutgoingMessageSender {
}
pub(crate) async fn send_server_notification(&self, notification: ServerNotification) {
let method = notification.to_string();
if self.should_skip_notification(&method).await {
return;
}
if let Err(err) = self
.sender
.send(OutgoingEnvelope::Broadcast {
@@ -204,12 +186,6 @@ impl OutgoingMessageSender {
/// All notifications should be migrated to [`ServerNotification`] and
/// [`OutgoingMessage::Notification`] should be removed.
pub(crate) async fn send_notification(&self, notification: OutgoingNotification) {
if self
.should_skip_notification(notification.method.as_str())
.await
{
return;
}
let outgoing_message = OutgoingMessage::Notification(notification);
if let Err(err) = self
.sender

View File

@@ -413,7 +413,27 @@ fn serialize_outgoing_message(outgoing_message: OutgoingMessage) -> Option<Strin
}
}
pub(crate) async fn route_outgoing_envelope(
fn should_skip_notification_for_connection(
connection_state: &ConnectionState,
message: &OutgoingMessage,
) -> bool {
match message {
OutgoingMessage::AppServerNotification(notification) => {
let method = notification.to_string();
connection_state
.session
.opted_out_notification_methods
.contains(method.as_str())
}
OutgoingMessage::Notification(notification) => connection_state
.session
.opted_out_notification_methods
.contains(notification.method.as_str()),
_ => false,
}
}
pub(crate) fn route_outgoing_envelope(
connections: &mut HashMap<ConnectionId, ConnectionState>,
envelope: OutgoingEnvelope,
) {
@@ -429,15 +449,27 @@ pub(crate) async fn route_outgoing_envelope(
);
return;
};
if connection_state.writer.send(message).await.is_err() {
connections.remove(&connection_id);
match connection_state.writer.try_send(message) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
connections.remove(&connection_id);
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
warn!(
"dropping slow connection with full outgoing queue: {:?}",
connection_id
);
connections.remove(&connection_id);
}
}
}
OutgoingEnvelope::Broadcast { message } => {
let target_connections: Vec<ConnectionId> = connections
.iter()
.filter_map(|(connection_id, connection_state)| {
if connection_state.session.initialized {
if connection_state.session.initialized
&& !should_skip_notification_for_connection(connection_state, &message)
{
Some(*connection_id)
} else {
None
@@ -449,8 +481,18 @@ pub(crate) async fn route_outgoing_envelope(
let Some(connection_state) = connections.get(&connection_id) else {
continue;
};
if connection_state.writer.send(message.clone()).await.is_err() {
connections.remove(&connection_id);
match connection_state.writer.try_send(message.clone()) {
Ok(()) => {}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
connections.remove(&connection_id);
}
Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
warn!(
"dropping slow connection with full outgoing queue: {:?}",
connection_id
);
connections.remove(&connection_id);
}
}
}
}