feat: add graph representation of agent network (#15056)

Add a representation of the agent graph. This is now used for:
* Cascade close agents (when I close a parent, it closes the kids)
* Cascade resume (the opposite)

Later, this will also be used for post-compaction stuffing of the
context

Direct fix for: https://github.com/openai/codex/issues/14458
This commit is contained in:
jif-oai
2026-03-19 10:21:25 +00:00
committed by GitHub
parent db5781a088
commit 70cdb17703
15 changed files with 1561 additions and 52 deletions

View File

@@ -1,4 +1,5 @@
use super::*;
use codex_protocol::protocol::SessionSource;
impl StateRuntime {
pub async fn get_thread(&self, id: ThreadId) -> anyhow::Result<Option<crate::ThreadMetadata>> {
@@ -78,6 +79,172 @@ ORDER BY position ASC
Ok(Some(tools))
}
/// Record that `parent_thread_id` spawned `child_thread_id`, storing `status` on the edge.
///
/// The statement's conflict target is `child_thread_id`: when a row for that child already
/// exists, its `parent_thread_id` and `status` are overwritten with the supplied values.
///
/// # Errors
/// Returns the error produced by executing the statement.
pub async fn upsert_thread_spawn_edge(
    &self,
    parent_thread_id: ThreadId,
    child_thread_id: ThreadId,
    status: crate::DirectionalThreadSpawnEdgeStatus,
) -> anyhow::Result<()> {
    // Thread ids are stored in their string form.
    let parent = parent_thread_id.to_string();
    let child = child_thread_id.to_string();
    let upsert = sqlx::query(
        r#"
INSERT INTO thread_spawn_edges (
parent_thread_id,
child_thread_id,
status
) VALUES (?, ?, ?)
ON CONFLICT(child_thread_id) DO UPDATE SET
parent_thread_id = excluded.parent_thread_id,
status = excluded.status
"#,
    )
    .bind(parent)
    .bind(child)
    .bind(status.as_ref());
    upsert.execute(self.pool.as_ref()).await?;
    Ok(())
}
/// Overwrite the `status` stored on the edge whose child is `child_thread_id`.
///
/// The number of affected rows is not inspected, so a child without a stored edge is not
/// reported as an error.
///
/// # Errors
/// Returns the error produced by executing the statement.
pub async fn set_thread_spawn_edge_status(
    &self,
    child_thread_id: ThreadId,
    status: crate::DirectionalThreadSpawnEdgeStatus,
) -> anyhow::Result<()> {
    let child = child_thread_id.to_string();
    // Bind order follows the placeholders: status first, then the child id.
    let update =
        sqlx::query("UPDATE thread_spawn_edges SET status = ? WHERE child_thread_id = ?")
            .bind(status.as_ref())
            .bind(child);
    update.execute(self.pool.as_ref()).await?;
    Ok(())
}
/// Return the threads spawned directly by `parent_thread_id` whose edge carries `status`.
///
/// Thin public wrapper that forwards to `list_thread_spawn_children_matching` with the
/// status filter always set.
pub async fn list_thread_spawn_children_with_status(
    &self,
    parent_thread_id: ThreadId,
    status: crate::DirectionalThreadSpawnEdgeStatus,
) -> anyhow::Result<Vec<ThreadId>> {
    let required_status = Some(status);
    self.list_thread_spawn_children_matching(parent_thread_id, required_status)
        .await
}
/// Return the spawned descendants of `root_thread_id` whose edges carry `status`.
///
/// Results come back shallowest-first (ordered by depth), with thread id as the tie-breaker
/// so the ordering is stable. Thin public wrapper that forwards to
/// `list_thread_spawn_descendants_matching` with the status filter always set.
pub async fn list_thread_spawn_descendants_with_status(
    &self,
    root_thread_id: ThreadId,
    status: crate::DirectionalThreadSpawnEdgeStatus,
) -> anyhow::Result<Vec<ThreadId>> {
    let required_status = Some(status);
    self.list_thread_spawn_descendants_matching(root_thread_id, required_status)
        .await
}
/// List the direct children of `parent_thread_id`, optionally restricted to edges whose
/// stored status equals `status`.
///
/// With `status == None` every child edge is returned. Rows are ordered by
/// `child_thread_id`.
///
/// # Errors
/// Returns an error if the query fails, if the `child_thread_id` column cannot be read as a
/// string, or if a stored id does not convert to a `ThreadId`.
async fn list_thread_spawn_children_matching(
    &self,
    parent_thread_id: ThreadId,
    status: Option<crate::DirectionalThreadSpawnEdgeStatus>,
) -> anyhow::Result<Vec<ThreadId>> {
    let mut query = String::from(
        "SELECT child_thread_id FROM thread_spawn_edges WHERE parent_thread_id = ?",
    );
    if status.is_some() {
        query.push_str(" AND status = ?");
    }
    query.push_str(" ORDER BY child_thread_id");

    let mut sql = sqlx::query(query.as_str()).bind(parent_thread_id.to_string());
    // Bind the same `AsRef<str>` rendering that the insert/update statements persist, so the
    // filter compares against exactly what was written (and no temporary `String` is built).
    // NOTE(review): assumes the enum's `AsRef<str>` and `Display` renderings agree — confirm.
    if let Some(status) = status.as_ref() {
        sql = sql.bind(status.as_ref());
    }

    let rows = sql.fetch_all(self.pool.as_ref()).await?;
    rows.into_iter()
        .map(|row| {
            ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
        })
        .collect()
}
/// List every descendant reachable from `root_thread_id` through spawn edges, optionally
/// restricted to edges whose stored status equals `status`.
///
/// The walk is a recursive CTE that tracks depth (direct children are depth 1). The status
/// filter is applied both to the root's own edges and to every recursive step, so an edge
/// that does not match is neither returned nor traversed: threads below it are excluded.
/// Rows are ordered by depth, then by `child_thread_id`.
///
/// NOTE(review): the CTE uses `UNION ALL` with a growing depth column and has no cycle
/// guard; if edges could ever form a cycle the recursion would not terminate — confirm that
/// writers cannot create one.
///
/// # Errors
/// Returns an error if the query fails, if the `child_thread_id` column cannot be read as a
/// string, or if a stored id does not convert to a `ThreadId`.
async fn list_thread_spawn_descendants_matching(
    &self,
    root_thread_id: ThreadId,
    status: Option<crate::DirectionalThreadSpawnEdgeStatus>,
) -> anyhow::Result<Vec<ThreadId>> {
    let status_filter = if status.is_some() {
        " AND status = ?"
    } else {
        ""
    };
    let query = format!(
        r#"
WITH RECURSIVE subtree(child_thread_id, depth) AS (
SELECT child_thread_id, 1
FROM thread_spawn_edges
WHERE parent_thread_id = ?{status_filter}
UNION ALL
SELECT edge.child_thread_id, subtree.depth + 1
FROM thread_spawn_edges AS edge
JOIN subtree ON edge.parent_thread_id = subtree.child_thread_id
WHERE 1 = 1{status_filter}
)
SELECT child_thread_id
FROM subtree
ORDER BY depth ASC, child_thread_id ASC
"#
    );

    let mut sql = sqlx::query(query.as_str()).bind(root_thread_id.to_string());
    if let Some(status) = status.as_ref() {
        // Use the same `AsRef<str>` rendering that the insert/update statements persist.
        // NOTE(review): assumes `AsRef<str>` and `Display` renderings agree — confirm.
        let status: &str = status.as_ref();
        // `status_filter` appears twice in the query (anchor and recursive member), so the
        // value is bound once per placeholder; a borrowed `&str` needs no clone.
        sql = sql.bind(status).bind(status);
    }

    let rows = sql.fetch_all(self.pool.as_ref()).await?;
    rows.into_iter()
        .map(|row| {
            ThreadId::try_from(row.try_get::<String, _>("child_thread_id")?).map_err(Into::into)
        })
        .collect()
}
/// Insert an `Open` edge from `parent_thread_id` to `child_thread_id` unless the child
/// already has an edge.
///
/// On a `child_thread_id` conflict the statement does nothing, leaving the existing row
/// (its parent and status) untouched.
///
/// # Errors
/// Returns the error produced by executing the statement.
async fn insert_thread_spawn_edge_if_absent(
    &self,
    parent_thread_id: ThreadId,
    child_thread_id: ThreadId,
) -> anyhow::Result<()> {
    // Edges created through this path always start out as `Open`.
    let initial_status = crate::DirectionalThreadSpawnEdgeStatus::Open;
    let parent = parent_thread_id.to_string();
    let child = child_thread_id.to_string();
    let insert = sqlx::query(
        r#"
INSERT INTO thread_spawn_edges (
parent_thread_id,
child_thread_id,
status
) VALUES (?, ?, ?)
ON CONFLICT(child_thread_id) DO NOTHING
"#,
    )
    .bind(parent)
    .bind(child)
    .bind(initial_status.as_ref());
    insert.execute(self.pool.as_ref()).await?;
    Ok(())
}
/// Derive the spawning parent from a serialized session `source` and, when one is found,
/// insert the parent → `child_thread_id` edge if the child has none yet.
///
/// A `source` that yields no parent thread id is a no-op and returns `Ok(())`.
async fn insert_thread_spawn_edge_from_source_if_absent(
    &self,
    child_thread_id: ThreadId,
    source: &str,
) -> anyhow::Result<()> {
    match thread_spawn_parent_thread_id_from_source_str(source) {
        Some(parent_thread_id) => {
            self.insert_thread_spawn_edge_if_absent(parent_thread_id, child_thread_id)
                .await
        }
        // Nothing to record when the source names no spawning parent.
        None => Ok(()),
    }
}
/// Find a rollout path by thread id using the underlying database.
pub async fn find_rollout_path_by_id(
&self,
@@ -276,6 +443,8 @@ ON CONFLICT(id) DO NOTHING
.bind("enabled")
.execute(self.pool.as_ref())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
.await?;
Ok(result.rows_affected() > 0)
}
@@ -420,6 +589,8 @@ ON CONFLICT(id) DO UPDATE SET
.bind(creation_memory_mode.unwrap_or("enabled"))
.execute(self.pool.as_ref())
.await?;
self.insert_thread_spawn_edge_from_source_if_absent(metadata.id, metadata.source.as_str())
.await?;
Ok(())
}
@@ -602,6 +773,18 @@ pub(super) fn extract_memory_mode(items: &[RolloutItem]) -> Option<String> {
})
}
/// Extract the parent thread id from a stored session-source string.
///
/// `source` is first parsed as a JSON document; if that fails, the raw text is wrapped in a
/// JSON string value and deserialized again. Returns `Some(parent_thread_id)` only when the
/// result is `SessionSource::SubAgent(SubAgentSource::ThreadSpawn { .. })`; any other
/// variant, or a failure of both parse attempts, yields `None`.
fn thread_spawn_parent_thread_id_from_source_str(source: &str) -> Option<ThreadId> {
    let session_source: SessionSource = match serde_json::from_str(source) {
        Ok(parsed) => parsed,
        // Fallback: treat the whole input as a bare (unquoted) string value.
        Err(_) => serde_json::from_value(Value::String(source.to_string())).ok()?,
    };
    if let SessionSource::SubAgent(codex_protocol::protocol::SubAgentSource::ThreadSpawn {
        parent_thread_id,
        ..
    }) = session_source
    {
        Some(parent_thread_id)
    } else {
        None
    }
}
pub(super) fn push_thread_filters<'a>(
builder: &mut QueryBuilder<'a, Sqlite>,
archived_only: bool,
@@ -680,6 +863,7 @@ pub(super) fn push_thread_order_and_limit(
#[cfg(test)]
mod tests {
use super::*;
use crate::DirectionalThreadSpawnEdgeStatus;
use crate::runtime::test_support::test_thread_metadata;
use crate::runtime::test_support::unique_temp_dir;
use codex_protocol::protocol::EventMsg;
@@ -1072,4 +1256,94 @@ mod tests {
assert_eq!(persisted.tokens_used, 321);
assert_eq!(persisted.updated_at, override_updated_at);
}
// Exercises the spawn-edge API on a parent → child → grandchild chain: both edges start
// `Open`, then the child's incoming edge is flipped to `Closed` and the child/descendant
// listings are re-checked for each status.
#[tokio::test]
async fn thread_spawn_edges_track_directional_status() {
    let runtime = StateRuntime::init(unique_temp_dir(), "test-provider".to_string())
        .await
        .expect("state db should initialize");

    let thread_id = |raw: &str| ThreadId::from_string(raw).expect("valid thread id");
    let parent_id = thread_id("00000000-0000-0000-0000-000000000900");
    let child_id = thread_id("00000000-0000-0000-0000-000000000901");
    let grandchild_id = thread_id("00000000-0000-0000-0000-000000000902");

    // Build the chain top-down; each entry carries its own failure message.
    for (spawner, spawned, failure) in [
        (parent_id, child_id, "child edge insert should succeed"),
        (
            child_id,
            grandchild_id,
            "grandchild edge insert should succeed",
        ),
    ] {
        runtime
            .upsert_thread_spawn_edge(spawner, spawned, DirectionalThreadSpawnEdgeStatus::Open)
            .await
            .expect(failure);
    }

    // While everything is open: one direct child, two descendants in depth order.
    let direct_open = runtime
        .list_thread_spawn_children_with_status(parent_id, DirectionalThreadSpawnEdgeStatus::Open)
        .await
        .expect("open child list should load");
    assert_eq!(direct_open, vec![child_id]);

    let all_open = runtime
        .list_thread_spawn_descendants_with_status(
            parent_id,
            DirectionalThreadSpawnEdgeStatus::Open,
        )
        .await
        .expect("open descendants should load");
    assert_eq!(all_open, vec![child_id, grandchild_id]);

    // Close only the parent → child edge.
    runtime
        .set_thread_spawn_edge_status(child_id, DirectionalThreadSpawnEdgeStatus::Closed)
        .await
        .expect("edge close should succeed");

    let direct_open_after_close = runtime
        .list_thread_spawn_children_with_status(parent_id, DirectionalThreadSpawnEdgeStatus::Open)
        .await
        .expect("open child list should load");
    assert_eq!(direct_open_after_close, Vec::<ThreadId>::new());

    let direct_closed = runtime
        .list_thread_spawn_children_with_status(
            parent_id,
            DirectionalThreadSpawnEdgeStatus::Closed,
        )
        .await
        .expect("closed child list should load");
    assert_eq!(direct_closed, vec![child_id]);

    let all_closed = runtime
        .list_thread_spawn_descendants_with_status(
            parent_id,
            DirectionalThreadSpawnEdgeStatus::Closed,
        )
        .await
        .expect("closed descendants should load");
    assert_eq!(all_closed, vec![child_id]);

    // The child → grandchild edge was never closed, so it is still listed from the child.
    let open_below_child = runtime
        .list_thread_spawn_descendants_with_status(
            child_id,
            DirectionalThreadSpawnEdgeStatus::Open,
        )
        .await
        .expect("open descendants from child should load");
    assert_eq!(open_below_child, vec![grandchild_id]);
}
}