mirror of
https://github.com/openai/codex.git
synced 2026-04-24 14:45:27 +00:00
use arc instead of overusing mutex
This commit is contained in:
@@ -30,6 +30,10 @@ pub(crate) struct Conversation {
|
||||
session_id: Uuid,
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
state: Mutex<ConversationState>,
|
||||
}
|
||||
|
||||
struct ConversationState {
|
||||
running: bool,
|
||||
streaming_enabled: bool,
|
||||
buffered_events: Vec<CodexEventNotificationParams>,
|
||||
@@ -42,40 +46,49 @@ impl Conversation {
|
||||
outgoing: Arc<OutgoingMessageSender>,
|
||||
request_id: RequestId,
|
||||
session_id: Uuid,
|
||||
) -> Arc<Mutex<Self>> {
|
||||
let conv = Arc::new(Mutex::new(Self {
|
||||
) -> Arc<Self> {
|
||||
let conv = Arc::new(Self {
|
||||
codex,
|
||||
session_id,
|
||||
outgoing,
|
||||
request_id,
|
||||
running: false,
|
||||
streaming_enabled: false,
|
||||
buffered_events: Vec::new(),
|
||||
pending_elicitations: Vec::new(),
|
||||
}));
|
||||
state: Mutex::new(ConversationState {
|
||||
running: false,
|
||||
streaming_enabled: false,
|
||||
buffered_events: Vec::new(),
|
||||
pending_elicitations: Vec::new(),
|
||||
}),
|
||||
});
|
||||
// Detach a background loop tied to this Conversation
|
||||
Conversation::spawn_loop(conv.clone());
|
||||
conv
|
||||
}
|
||||
|
||||
pub(crate) async fn set_streaming(&mut self, enabled: bool) {
|
||||
pub(crate) async fn set_streaming(&self, enabled: bool) {
|
||||
if enabled {
|
||||
self.streaming_enabled = true;
|
||||
self.emit_initial_state().await;
|
||||
self.drain_pending_elicitations().await;
|
||||
let (events_snapshot, pending_snapshot) = {
|
||||
let mut st = self.state.lock().await;
|
||||
st.streaming_enabled = true;
|
||||
(
|
||||
st.buffered_events.clone(),
|
||||
std::mem::take(&mut st.pending_elicitations),
|
||||
)
|
||||
};
|
||||
self.emit_initial_state_with(events_snapshot).await;
|
||||
self.drain_pending_elicitations_from(pending_snapshot).await;
|
||||
} else {
|
||||
self.streaming_enabled = false;
|
||||
let mut st = self.state.lock().await;
|
||||
st.streaming_enabled = false;
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_loop(this: Arc<Mutex<Self>>) {
|
||||
fn spawn_loop(this: Arc<Self>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
// We clone codex to avoid holding the lock while awaiting next_event
|
||||
let codex = { this.lock().await.codex.clone() };
|
||||
// Codex can be awaited without locking Conversation
|
||||
let codex = this.codex.clone();
|
||||
let res = codex.next_event().await;
|
||||
let mut guard = this.lock().await;
|
||||
guard.handle_next_event(res).await;
|
||||
this.handle_next_event(res).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -85,15 +98,17 @@ impl Conversation {
|
||||
}
|
||||
|
||||
pub(crate) async fn try_submit_user_input(
|
||||
&mut self,
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
items: Vec<codex_core::protocol::InputItem>,
|
||||
) -> Result<(), String> {
|
||||
if self.running {
|
||||
return Err("Session is already running".to_string());
|
||||
{
|
||||
let mut st = self.state.lock().await;
|
||||
if st.running {
|
||||
return Err("Session is already running".to_string());
|
||||
}
|
||||
st.running = true;
|
||||
}
|
||||
// Optimistically mark running to avoid races between quick successive submits
|
||||
self.running = true;
|
||||
let request_id_string = match &request_id {
|
||||
RequestId::String(s) => s.clone(),
|
||||
RequestId::Integer(i) => i.to_string(),
|
||||
@@ -106,23 +121,26 @@ impl Conversation {
|
||||
})
|
||||
.await;
|
||||
if let Err(e) = submit_res {
|
||||
// Revert running on error
|
||||
self.running = false;
|
||||
let mut st = self.state.lock().await;
|
||||
st.running = false;
|
||||
return Err(format!("Failed to submit user input: {e}"));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_next_event<E>(&mut self, res: Result<Event, E>)
|
||||
async fn handle_next_event<E>(&self, res: Result<Event, E>)
|
||||
where
|
||||
E: std::fmt::Display,
|
||||
{
|
||||
match res {
|
||||
Ok(event) => {
|
||||
self.buffered_events.push(CodexEventNotificationParams {
|
||||
meta: None,
|
||||
msg: event.msg.clone(),
|
||||
});
|
||||
{
|
||||
let mut st = self.state.lock().await;
|
||||
st.buffered_events.push(CodexEventNotificationParams {
|
||||
meta: None,
|
||||
msg: event.msg.clone(),
|
||||
});
|
||||
}
|
||||
self.stream_event_if_enabled(&event.msg).await;
|
||||
|
||||
match event.msg {
|
||||
@@ -191,15 +209,13 @@ impl Conversation {
|
||||
}
|
||||
}
|
||||
|
||||
async fn emit_initial_state(&self) {
|
||||
async fn emit_initial_state_with(&self, events: Vec<CodexEventNotificationParams>) {
|
||||
let params = InitialStateNotificationParams {
|
||||
meta: Some(NotificationMeta {
|
||||
conversation_id: Some(ConversationId(self.session_id)),
|
||||
request_id: None,
|
||||
}),
|
||||
initial_state: InitialStatePayload {
|
||||
events: self.buffered_events.clone(),
|
||||
},
|
||||
initial_state: InitialStatePayload { events },
|
||||
};
|
||||
if let Ok(params_val) = serde_json::to_value(¶ms) {
|
||||
self.outgoing
|
||||
@@ -210,8 +226,8 @@ impl Conversation {
|
||||
}
|
||||
}
|
||||
|
||||
async fn drain_pending_elicitations(&mut self) {
|
||||
for item in self.pending_elicitations.drain(..) {
|
||||
async fn drain_pending_elicitations_from(&self, items: Vec<PendingElicitation>) {
|
||||
for item in items {
|
||||
match item {
|
||||
PendingElicitation::ExecRequest(ExecRequest {
|
||||
command,
|
||||
@@ -262,13 +278,14 @@ impl Conversation {
|
||||
}
|
||||
|
||||
async fn process_exec_request(
|
||||
&mut self,
|
||||
&self,
|
||||
command: Vec<String>,
|
||||
cwd: PathBuf,
|
||||
call_id: String,
|
||||
event_id: String,
|
||||
) {
|
||||
if self.streaming_enabled {
|
||||
let should_stream = { self.state.lock().await.streaming_enabled };
|
||||
if should_stream {
|
||||
handle_exec_approval_request(
|
||||
command,
|
||||
cwd,
|
||||
@@ -284,7 +301,8 @@ impl Conversation {
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
self.pending_elicitations
|
||||
let mut st = self.state.lock().await;
|
||||
st.pending_elicitations
|
||||
.push(PendingElicitation::ExecRequest(ExecRequest {
|
||||
command,
|
||||
cwd,
|
||||
@@ -294,7 +312,7 @@ impl Conversation {
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_patch_request(&mut self, req: PatchRequest) {
|
||||
async fn process_patch_request(&self, req: PatchRequest) {
|
||||
let PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
@@ -302,7 +320,8 @@ impl Conversation {
|
||||
changes,
|
||||
event_id,
|
||||
} = req;
|
||||
if self.streaming_enabled {
|
||||
let should_stream = { self.state.lock().await.streaming_enabled };
|
||||
if should_stream {
|
||||
handle_patch_approval_request(
|
||||
call_id,
|
||||
reason,
|
||||
@@ -319,7 +338,8 @@ impl Conversation {
|
||||
)
|
||||
.await;
|
||||
} else {
|
||||
self.pending_elicitations
|
||||
let mut st = self.state.lock().await;
|
||||
st.pending_elicitations
|
||||
.push(PendingElicitation::PatchRequest(PatchRequest {
|
||||
call_id,
|
||||
reason,
|
||||
@@ -331,7 +351,7 @@ impl Conversation {
|
||||
}
|
||||
|
||||
async fn stream_event_if_enabled(&self, msg: &EventMsg) {
|
||||
if !self.streaming_enabled {
|
||||
if !{ self.state.lock().await.streaming_enabled } {
|
||||
return;
|
||||
}
|
||||
let method = msg.to_string();
|
||||
@@ -348,12 +368,14 @@ impl Conversation {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_task_started(&mut self) {
|
||||
self.running = true;
|
||||
async fn handle_task_started(&self) {
|
||||
let mut st = self.state.lock().await;
|
||||
st.running = true;
|
||||
}
|
||||
|
||||
async fn handle_task_clear(&mut self) {
|
||||
self.running = false;
|
||||
async fn handle_task_clear(&self) {
|
||||
let mut st = self.state.lock().await;
|
||||
st.running = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ pub(crate) struct MessageProcessor {
|
||||
initialized: bool,
|
||||
codex_linux_sandbox_exe: Option<PathBuf>,
|
||||
session_map: Arc<Mutex<HashMap<Uuid, Arc<Codex>>>>,
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Mutex<crate::conversation_loop::Conversation>>>>>,
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>>,
|
||||
running_requests_id_to_codex_uuid: Arc<Mutex<HashMap<RequestId, Uuid>>>,
|
||||
/// Track request IDs to the original ToolCallRequestParams for cancellation handling
|
||||
tool_request_map: Arc<Mutex<HashMap<RequestId, ToolCallRequestParams>>>,
|
||||
@@ -70,7 +70,7 @@ impl MessageProcessor {
|
||||
|
||||
pub(crate) fn conversation_map(
|
||||
&self,
|
||||
) -> Arc<Mutex<HashMap<Uuid, Arc<Mutex<crate::conversation_loop::Conversation>>>>> {
|
||||
) -> Arc<Mutex<HashMap<Uuid, Arc<crate::conversation_loop::Conversation>>>> {
|
||||
self.conversation_map.clone()
|
||||
}
|
||||
|
||||
@@ -632,7 +632,7 @@ impl MessageProcessor {
|
||||
let codex_arc = {
|
||||
let sessions_guard = self.conversation_map.lock().await;
|
||||
match sessions_guard.get(&session_id) {
|
||||
Some(conv) => conv.lock().await.codex().clone(),
|
||||
Some(conv) => conv.codex().clone(),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
"Cancel send_message: session not found for session_id: {session_id}"
|
||||
|
||||
@@ -55,10 +55,7 @@ pub(crate) async fn handle_send_message(
|
||||
return;
|
||||
};
|
||||
|
||||
let res = {
|
||||
let mut guard = conversation.lock().await;
|
||||
guard.try_submit_user_input(id.clone(), items).await
|
||||
};
|
||||
let res = conversation.try_submit_user_input(id.clone(), items).await;
|
||||
|
||||
if let Err(e) = res {
|
||||
message_processor
|
||||
@@ -86,8 +83,8 @@ pub(crate) async fn handle_send_message(
|
||||
|
||||
pub(crate) async fn get_session(
|
||||
session_id: Uuid,
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Mutex<Conversation>>>>>,
|
||||
) -> Option<Arc<Mutex<Conversation>>> {
|
||||
conversation_map: Arc<Mutex<HashMap<Uuid, Arc<Conversation>>>>,
|
||||
) -> Option<Arc<Conversation>> {
|
||||
let guard = conversation_map.lock().await;
|
||||
guard.get(&session_id).cloned()
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ pub(crate) async fn handle_stream_conversation(
|
||||
|
||||
if let Some(conv) = conv {
|
||||
tokio::spawn(async move {
|
||||
conv.lock().await.set_streaming(true).await;
|
||||
conv.set_streaming(true).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -52,6 +52,6 @@ pub(crate) async fn handle_cancel(
|
||||
) {
|
||||
let session_id = args.conversation_id.0;
|
||||
if let Some(conv) = get_session(session_id, message_processor.conversation_map()).await {
|
||||
conv.lock().await.set_streaming(false).await;
|
||||
conv.set_streaming(false).await;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user