mirror of
https://github.com/openai/codex.git
synced 2026-05-02 10:26:45 +00:00
feat: add graph representation of agent network (#15056)
Add a representation of the agent graph. This is now used for: * Cascade close agents (when I close a parent, it close the kids) * Cascade resume (oposite) Later, this will also be used for post-compaction stuffing of the context Direct fix for: https://github.com/openai/codex/issues/14458
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
use super::*;
|
||||
use codex_protocol::protocol::SessionSource;
|
||||
|
||||
impl StateRuntime {
|
||||
pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
|
||||
@@ -78,6 +79,172 @@ ORDER BY position ASC
|
||||
Ok(Some(tools))
|
||||
}
|
||||
|
||||
/// Persist or replace the directional parent-child edge for a spawned thread.
|
||||
pub async fn upsert_thread_spawn_edge(
|
||||
&self,
|
||||
parent_thread_id: ThreadId,
|
||||
child_thread_id: ThreadId,
|
||||
status: crate::DirectionalThreadSpawnEdgeStatus,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO thread_spawn_edges (
|
||||
parent_thread_id,
|
||||
child_thread_id,
|
||||
status
|
||||
) VALUES (?, ?, ?)
|
||||
ON CONFLICT(child_thread_id) DO UPDATE SET
|
||||
parent_thread_id = excluded.parent_thread_id,
|
||||
status = excluded.status
|
||||
"#,
|
||||
)
|
||||
.bind(parent_thread_id.to_string())
|
||||
.bind(child_thread_id.to_string())
|
||||
.bind(status.as_ref())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the persisted lifecycle status of a spawned thread's incoming edge.
|
||||
pub async fn set_thread_spawn_edge_status(
|
||||
&self,
|
||||
child_thread_id: ThreadId,
|
||||
status: crate::DirectionalThreadSpawnEdgeStatus,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query("UPDATE thread_spawn_edges SET status = ? WHERE child_thread_id = ?")
|
||||
.bind(status.as_ref())
|
||||
.bind(child_thread_id.to_string())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List direct spawned children of `parent_thread_id` whose edge matches `status`.
|
||||
pub async fn list_thread_spawn_children_with_status(
|
||||
&self,
|
||||
parent_thread_id: ThreadId,
|
||||
status: crate::DirectionalThreadSpawnEdgeStatus,
|
||||
) -> anyhow::Result<Vec<ThreadId>> {
|
||||
self.list_thread_spawn_children_matching(parent_thread_id, Some(status))
|
||||
.await
|
||||
}
|
||||
|
||||
/// List spawned descendants of `root_thread_id` whose edges match `status`.
|
||||
///
|
||||
/// Descendants are returned breadth-first by depth, then by thread id for stable ordering.
|
||||
pub async fn list_thread_spawn_descendants_with_status(
|
||||
&self,
|
||||
root_thread_id: ThreadId,
|
||||
status: crate::DirectionalThreadSpawnEdgeStatus,
|
||||
) -> anyhow::Result<Vec<ThreadId>> {
|
||||
self.list_thread_spawn_descendants_matching(root_thread_id, Some(status))
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_thread_spawn_children_matching(
|
||||
&self,
|
||||
parent_thread_id: ThreadId,
|
||||
status: Option<crate::DirectionalThreadSpawnEdgeStatus>,
|
||||
) -> anyhow::Result<Vec<ThreadId>> {
|
||||
let mut query = String::from(
|
||||
"SELECT child_thread_id FROM thread_spawn_edges WHERE parent_thread_id = ?",
|
||||
);
|
||||
if status.is_some() {
|
||||
query.push_str(" AND status = ?");
|
||||
}
|
||||
query.push_str(" ORDER BY child_thread_id");
|
||||
|
||||
let mut sql = sqlx::query(query.as_str()).bind(parent_thread_id.to_string());
|
||||
if let Some(status) = status {
|
||||
sql = sql.bind(status.to_string());
|
||||
}
|
||||
|
||||
let rows = sql.fetch_all(self.pool.as_ref()).await?;
|
||||
rows.into_iter()
|
||||
.map(|row| {
|
||||
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn list_thread_spawn_descendants_matching(
|
||||
&self,
|
||||
root_thread_id: ThreadId,
|
||||
status: Option<crate::DirectionalThreadSpawnEdgeStatus>,
|
||||
) -> anyhow::Result<Vec<ThreadId>> {
|
||||
let status_filter = if status.is_some() {
|
||||
" AND status = ?"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let query = format!(
|
||||
r#"
|
||||
WITH RECURSIVE subtree(child_thread_id, depth) AS (
|
||||
SELECT child_thread_id, 1
|
||||
FROM thread_spawn_edges
|
||||
WHERE parent_thread_id = ?{status_filter}
|
||||
UNION ALL
|
||||
SELECT edge.child_thread_id, subtree.depth + 1
|
||||
FROM thread_spawn_edges AS edge
|
||||
JOIN subtree ON edge.parent_thread_id = subtree.child_thread_id
|
||||
WHERE 1 = 1{status_filter}
|
||||
)
|
||||
SELECT child_thread_id
|
||||
FROM subtree
|
||||
ORDER BY depth ASC, child_thread_id ASC
|
||||
"#
|
||||
);
|
||||
|
||||
let mut sql = sqlx::query(query.as_str()).bind(root_thread_id.to_string());
|
||||
if let Some(status) = status {
|
||||
let status = status.to_string();
|
||||
sql = sql.bind(status.clone()).bind(status);
|
||||
}
|
||||
|
||||
let rows = sql.fetch_all(self.pool.as_ref()).await?;
|
||||
rows.into_iter()
|
||||
.map(|row| {
|
||||
ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn insert_thread_spawn_edge_if_absent(
|
||||
&self,
|
||||
parent_thread_id: ThreadId,
|
||||
child_thread_id: ThreadId,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO thread_spawn_edges (
|
||||
parent_thread_id,
|
||||
child_thread_id,
|
||||
status
|
||||
) VALUES (?, ?, ?)
|
||||
ON CONFLICT(child_thread_id) DO NOTHING
|
||||
"#,
|
||||
)
|
||||
.bind(parent_thread_id.to_string())
|
||||
.bind(child_thread_id.to_string())
|
||||
.bind(crate::DirectionalThreadSpawnEdgeStatus::Open.as_ref())
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn insert_thread_spawn_edge_from_source_if_absent(
|
||||
&self,
|
||||
child_thread_id: ThreadId,
|
||||
source: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let Some(parent_thread_id) = thread_spawn_parent_thread_id_from_source_str(source) else {
|
||||
return Ok(());
|
||||
};
|
||||
self.insert_thread_spawn_edge_if_absent(parent_thread_id, child_thread_id)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Find a rollout path by thread id using the underlying database.
|
||||
pub async fn find_rollout_path_by_id(
|
||||
&self,
|
||||
@@ -276,6 +443,8 @@ ON CONFLICT(id) DO NOTHING
|
||||
.bind("enabled")
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
|
||||
.await?;
|
||||
Ok(result.rows_affected() > 0)
|
||||
}
|
||||
|
||||
@@ -420,6 +589,8 @@ ON CONFLICT(id) DO UPDATE SET
|
||||
.bind(creation_memory_mode.unwrap_or("enabled"))
|
||||
.execute(self.pool.as_ref())
|
||||
.await?;
|
||||
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -602,6 +773,18 @@ pub(super) fn extract_memory_mode(items: &[RolloutItem]) -> Option<String> {
|
||||
})
|
||||
}
|
||||
|
||||
fn thread_spawn_parent_thread_id_from_source_str(source: &str) -> Option<ThreadId> {
|
||||
let parsed_source = serde_json::from_str(source)
|
||||
.or_else(|_| serde_json::from_value::<SessionSource>(Value::String(source.to_string())));
|
||||
match parsed_source.ok() {
|
||||
Some(SessionSource::SubAgent(codex_protocol::protocol::SubAgentSource::ThreadSpawn {
|
||||
parent_thread_id,
|
||||
..
|
||||
})) => Some(parent_thread_id),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn push_thread_filters<'a>(
|
||||
builder: &mut QueryBuilder<'a, Sqlite>,
|
||||
archived_only: bool,
|
||||
@@ -680,6 +863,7 @@ pub(super) fn push_thread_order_and_limit(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::DirectionalThreadSpawnEdgeStatus;
|
||||
use crate::runtime::test_support::test_thread_metadata;
|
||||
use crate::runtime::test_support::unique_temp_dir;
|
||||
use codex_protocol::protocol::EventMsg;
|
||||
@@ -1072,4 +1256,94 @@ mod tests {
|
||||
assert_eq!(persisted.tokens_used, 321);
|
||||
assert_eq!(persisted.updated_at, override_updated_at);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn thread_spawn_edges_track_directional_status() {
|
||||
let codex_home = unique_temp_dir();
|
||||
let runtime = StateRuntime::init(codex_home, "test-provider".to_string())
|
||||
.await
|
||||
.expect("state db should initialize");
|
||||
let parent_thread_id =
|
||||
ThreadId::from_string("00000000-0000-0000-0000-000000000900").expect("valid thread id");
|
||||
let child_thread_id =
|
||||
ThreadId::from_string("00000000-0000-0000-0000-000000000901").expect("valid thread id");
|
||||
let grandchild_thread_id =
|
||||
ThreadId::from_string("00000000-0000-0000-0000-000000000902").expect("valid thread id");
|
||||
|
||||
runtime
|
||||
.upsert_thread_spawn_edge(
|
||||
parent_thread_id,
|
||||
child_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("child edge insert should succeed");
|
||||
runtime
|
||||
.upsert_thread_spawn_edge(
|
||||
child_thread_id,
|
||||
grandchild_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("grandchild edge insert should succeed");
|
||||
|
||||
let children = runtime
|
||||
.list_thread_spawn_children_with_status(
|
||||
parent_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("open child list should load");
|
||||
assert_eq!(children, vec![child_thread_id]);
|
||||
|
||||
let descendants = runtime
|
||||
.list_thread_spawn_descendants_with_status(
|
||||
parent_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("open descendants should load");
|
||||
assert_eq!(descendants, vec![child_thread_id, grandchild_thread_id]);
|
||||
|
||||
runtime
|
||||
.set_thread_spawn_edge_status(child_thread_id, DirectionalThreadSpawnEdgeStatus::Closed)
|
||||
.await
|
||||
.expect("edge close should succeed");
|
||||
|
||||
let open_children = runtime
|
||||
.list_thread_spawn_children_with_status(
|
||||
parent_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("open child list should load");
|
||||
assert_eq!(open_children, Vec::<ThreadId>::new());
|
||||
|
||||
let closed_children = runtime
|
||||
.list_thread_spawn_children_with_status(
|
||||
parent_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Closed,
|
||||
)
|
||||
.await
|
||||
.expect("closed child list should load");
|
||||
assert_eq!(closed_children, vec![child_thread_id]);
|
||||
|
||||
let closed_descendants = runtime
|
||||
.list_thread_spawn_descendants_with_status(
|
||||
parent_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Closed,
|
||||
)
|
||||
.await
|
||||
.expect("closed descendants should load");
|
||||
assert_eq!(closed_descendants, vec![child_thread_id]);
|
||||
|
||||
let open_descendants_from_child = runtime
|
||||
.list_thread_spawn_descendants_with_status(
|
||||
child_thread_id,
|
||||
DirectionalThreadSpawnEdgeStatus::Open,
|
||||
)
|
||||
.await
|
||||
.expect("open descendants from child should load");
|
||||
assert_eq!(open_descendants_from_child, vec![grandchild_thread_id]);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user