2026-03-09 Simplify experimental python SDK wire models

This commit is contained in:
Shaqayeq
2026-03-09 11:53:03 -07:00
parent 17818d6243
commit eb98b16533
148 changed files with 8847 additions and 17217 deletions

View File

@@ -29,8 +29,15 @@ def sdk_root() -> Path:
return repo_root() / "sdk" / "python"
def schema_dir() -> Path:
return repo_root() / "codex-rs" / "app-server-protocol" / "schema" / "json" / "v2"
def schema_bundle_path() -> Path:
return (
repo_root()
/ "codex-rs"
/ "app-server-protocol"
/ "schema"
/ "json"
/ "codex_app_server_protocol.v2.schemas.json"
)
def schema_root_dir() -> Path:
@@ -202,36 +209,44 @@ def bundle_all_platform_binaries(channel: str) -> None:
def generate_v2_all() -> None:
out_dir = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all"
if out_dir.exists():
shutil.rmtree(out_dir)
out_path = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all.py"
out_dir = out_path.parent
old_package_dir = out_dir / "v2_all"
if old_package_dir.exists():
shutil.rmtree(old_package_dir)
out_dir.mkdir(parents=True, exist_ok=True)
run_python_module(
"datamodel_code_generator",
[
"--input",
str(schema_dir()),
str(schema_bundle_path()),
"--input-file-type",
"jsonschema",
"--output",
str(out_dir),
str(out_path),
"--output-model-type",
"pydantic_v2.BaseModel",
"--target-python-version",
"3.10",
"--snake-case-field",
"--allow-population-by-field-name",
"--use-union-operator",
"--reuse-model",
"--disable-timestamp",
"--use-double-quotes",
],
cwd=sdk_root(),
)
_normalize_generated_timestamps(out_dir)
(out_dir / "__init__.py").touch()
_normalize_generated_timestamps(out_path)
def _notification_specs() -> list[tuple[str, str]]:
server_notifications = json.loads((schema_root_dir() / "ServerNotification.json").read_text())
one_of = server_notifications.get("oneOf", [])
generated_source = (
sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all.py"
).read_text()
specs: list[tuple[str, str]] = []
v2_dir = sdk_root() / "src" / "codex_app_server" / "generated" / "v2_all"
for variant in one_of:
props = variant.get("properties", {})
@@ -249,7 +264,7 @@ def _notification_specs() -> list[tuple[str, str]]:
if not isinstance(ref, str) or not ref.startswith("#/definitions/"):
continue
class_name = ref.split("/")[-1]
if not (v2_dir / f"{class_name}.py").exists():
if f"class {class_name}(" not in generated_source and f"{class_name} =" not in generated_source:
# Skip schema variants that are not emitted into the generated v2 surface.
continue
specs.append((method, class_name))
@@ -274,7 +289,7 @@ def generate_notification_registry() -> None:
]
for class_name in class_names:
lines.append(f"from .v2_all.{class_name} import {class_name}")
lines.append(f"from .v2_all import {class_name}")
lines.extend(
[
"",
@@ -344,260 +359,17 @@ def generate_codex_event_types() -> None:
def _normalize_generated_timestamps(root: Path) -> None:
timestamp_re = re.compile(r"^#\s+timestamp:\s+.+$", flags=re.MULTILINE)
for py_file in root.rglob("*.py"):
py_files = [root] if root.is_file() else sorted(root.rglob("*.py"))
for py_file in py_files:
content = py_file.read_text()
normalized = timestamp_re.sub("# timestamp: <normalized>", content)
if normalized != content:
py_file.write_text(normalized)
# ---- protocol_types.py generation ----
def load_schema(name: str) -> dict[str, Any]:
return json.loads((schema_dir() / f"{name}.json").read_text())
def object_props(schema: dict[str, Any], node: dict[str, Any]) -> tuple[dict[str, Any], set[str]]:
if "$ref" in node:
ref = node["$ref"]
if ref.startswith("#/definitions/"):
key = ref.split("/")[-1]
return object_props(schema, schema["definitions"][key])
raise ValueError(f"unsupported ref: {ref}")
return node.get("properties", {}), set(node.get("required", []))
def field_type(v: dict[str, Any]) -> str:
if "$ref" in v:
ref = v["$ref"]
if ref.endswith("Thread"):
return "ThreadObject"
if ref.endswith("Turn"):
return "TurnObject"
if ref.endswith("ThreadTokenUsage"):
return "ThreadTokenUsage"
return "dict[str, Any]"
if "anyOf" in v:
non_null = [x for x in v["anyOf"] if x.get("type") != "null"]
if len(non_null) == 1:
return f"{field_type(non_null[0])} | None"
t = v.get("type")
if t == "string":
return "str"
if t == "integer":
return "int"
if t == "boolean":
return "bool"
if t == "array":
if (v.get("items") or {}).get("$ref", "").endswith("Thread"):
return "list[ThreadObject]"
if (v.get("items") or {}).get("$ref", "").endswith("Turn"):
return "list[TurnObject]"
return "list[dict[str, Any]]"
return "dict[str, Any]"
def render_typed_dict(name: str, props: dict[str, Any], req: set[str]) -> str:
lines = [f"class {name}(TypedDict):"]
if not props:
lines.append(" pass")
return "\n".join(lines)
for k, v in props.items():
t = field_type(v)
if k in req:
lines.append(f" {k}: {t}")
else:
lines.append(f" {k}: NotRequired[{t}]")
return "\n".join(lines)
def generate_protocol_types() -> None:
out = sdk_root() / "src" / "codex_app_server" / "generated" / "protocol_types.py"
tsr = load_schema("ThreadStartResponse")
turs = load_schema("TurnStartResponse")
ttu = load_schema("ThreadTokenUsageUpdatedNotification")
thread_props, thread_req = object_props(tsr, tsr["definitions"].get("Thread", {}))
turn_props, turn_req = object_props(turs, turs["definitions"].get("Turn", {}))
usage_props, usage_req = object_props(ttu, ttu["definitions"].get("ThreadTokenUsage", {}))
roots = {
"ThreadStartResponse": object_props(tsr, tsr),
"TurnStartResponse": object_props(turs, turs),
"ThreadTokenUsageUpdatedNotificationParams": object_props(ttu, ttu),
}
parts = [
"from __future__ import annotations",
"",
"from typing import Any, NotRequired, TypedDict",
"",
"# Generated by scripts/update_sdk_artifacts.py",
"",
render_typed_dict("ThreadObject", thread_props, thread_req),
"",
render_typed_dict("TurnObject", turn_props, turn_req),
"",
render_typed_dict("ThreadTokenUsage", usage_props, usage_req),
"",
]
for name, (props, req) in roots.items():
parts.append(render_typed_dict(name, props, req))
parts.append("")
out.write_text("\n".join(parts))
# ---- schema_types.py generation ----
TARGET_SCHEMAS = {
"ThreadStartResponse": "ThreadStartResponse.json",
"ThreadResumeResponse": "ThreadResumeResponse.json",
"ThreadReadResponse": "ThreadReadResponse.json",
"ThreadListResponse": "ThreadListResponse.json",
"ThreadForkResponse": "ThreadForkResponse.json",
"ThreadArchiveResponse": "ThreadArchiveResponse.json",
"ThreadUnarchiveResponse": "ThreadUnarchiveResponse.json",
"ThreadSetNameResponse": "ThreadSetNameResponse.json",
"ThreadCompactStartResponse": "ThreadCompactStartResponse.json",
"TurnStartResponse": "TurnStartResponse.json",
"TurnSteerResponse": "TurnSteerResponse.json",
"ModelListResponse": "ModelListResponse.json",
}
@dataclass(slots=True)
class FieldSpec:
name: str
annotation: str
required: bool
source_expr: str
@dataclass(slots=True)
class ClassSpec:
name: str
fields: list[FieldSpec]
def py_type_for_schema(schema: dict[str, Any], defs: dict[str, Any], nested: set[str]) -> tuple[str, str]:
if "$ref" in schema:
ref = schema["$ref"].split("/")[-1]
if ref in nested:
return ref, "object"
rd = defs.get(ref, {})
if rd.get("type") == "string":
return "str", "scalar"
if rd.get("type") == "integer":
return "int", "scalar"
if rd.get("type") == "boolean":
return "bool", "scalar"
return "Any", "scalar"
t = schema.get("type")
if t == "string":
return "str", "scalar"
if t == "integer":
return "int", "scalar"
if t == "boolean":
return "bool", "scalar"
if t == "array":
item_t, _ = py_type_for_schema(schema.get("items", {}), defs, nested)
return f"list[{item_t}]", "array"
if t == "object":
return "dict[str, Any]", "object"
return "Any", "scalar"
def field_source(field_name: str, py_type: str, kind: str) -> str:
g = f'payload.get("{field_name}")'
if py_type == "str":
return f"str({g} or '')"
if py_type == "int":
return f"int({g} or 0)"
if py_type == "bool":
return f"bool({g})"
if kind == "array":
return f"list({g} or [])"
return g
def class_from_schema(name: str, schema: dict[str, Any], defs: dict[str, Any], nested: set[str]) -> ClassSpec:
props = schema.get("properties", {})
req = set(schema.get("required", []))
fields: list[FieldSpec] = []
for n, s in props.items():
t, k = py_type_for_schema(s, defs, nested)
fields.append(FieldSpec(name=n, annotation=t, required=n in req, source_expr=field_source(n, t, k)))
return ClassSpec(name=name, fields=fields)
def generate_schema_types() -> None:
out = sdk_root() / "src" / "codex_app_server" / "generated" / "schema_types.py"
raw: dict[str, dict[str, Any]] = {}
defs: dict[str, Any] = {}
for cname, fname in TARGET_SCHEMAS.items():
data = json.loads((schema_dir() / fname).read_text())
raw[cname] = data
defs.update(data.get("definitions", {}))
nested = {"Thread", "Turn"}
specs: list[ClassSpec] = []
for n in sorted(nested):
if defs.get(n):
specs.append(class_from_schema(n, defs[n], defs, nested))
for name, root in raw.items():
specs.append(class_from_schema(name, root, defs, nested))
parts: list[str] = [
"# Auto-generated by scripts/update_sdk_artifacts.py",
"# DO NOT EDIT MANUALLY.",
"",
"from __future__ import annotations",
"",
"from dataclasses import dataclass",
"from typing import Any, TypedDict",
"",
]
for spec in specs:
parts.append(f"class {spec.name}Dict(TypedDict, total=False):")
if spec.fields:
for f in spec.fields:
parts.append(f" {f.name}: {f.annotation}")
else:
parts.append(" pass")
parts.append("")
parts.append("@dataclass(slots=True, kw_only=True)")
parts.append(f"class {spec.name}:")
if spec.fields:
for f in spec.fields:
default = "" if f.required else " = None"
parts.append(f" {f.name}: {f.annotation}{default}")
else:
parts.append(" pass")
parts.append("")
out.write_text("\n".join(parts) + "\n")
TYPE_ALIAS_MAP: dict[tuple[str, str], str] = {
("codex_app_server.generated.v2_all.ThreadStartParams", "AskForApproval"): "AskForApproval",
("codex_app_server.generated.v2_all.ThreadStartParams", "Personality"): "Personality",
("codex_app_server.generated.v2_all.ThreadStartParams", "SandboxMode"): "SandboxMode",
("codex_app_server.generated.v2_all.ThreadListParams", "ThreadSortKey"): "ThreadSortKey",
("codex_app_server.generated.v2_all.ThreadListParams", "ThreadSourceKind"): "ThreadSourceKind",
("codex_app_server.generated.v2_all.ThreadResumeParams", "AskForApproval"): "ResumeAskForApproval",
("codex_app_server.generated.v2_all.ThreadResumeParams", "Personality"): "ResumePersonality",
("codex_app_server.generated.v2_all.ThreadResumeParams", "SandboxMode"): "ResumeSandboxMode",
("codex_app_server.generated.v2_all.ThreadForkParams", "AskForApproval"): "ForkAskForApproval",
("codex_app_server.generated.v2_all.ThreadForkParams", "SandboxMode"): "ForkSandboxMode",
("codex_app_server.generated.v2_all.TurnStartParams", "AskForApproval"): "TurnAskForApproval",
("codex_app_server.generated.v2_all.TurnStartParams", "Personality"): "TurnPersonality",
("codex_app_server.generated.v2_all.TurnStartParams", "ReasoningEffort"): "TurnReasoningEffort",
("codex_app_server.generated.v2_all.TurnStartParams", "SandboxPolicy"): "TurnSandboxPolicy",
("codex_app_server.generated.v2_all.TurnStartParams", "ReasoningSummary"): "TurnReasoningSummary",
}
FIELD_ANNOTATION_OVERRIDES: dict[str, str] = {
# Keep public API typed without falling back to `Any`.
"config": "JsonObject",
"outputSchema": "JsonObject",
"output_schema": "JsonObject",
}
@@ -636,10 +408,8 @@ def _annotation_to_source(annotation: Any) -> str:
if isinstance(annotation, type):
if annotation.__module__ == "builtins":
return annotation.__name__
alias = TYPE_ALIAS_MAP.get((annotation.__module__, annotation.__name__))
if alias is not None:
return alias
return "Any"
return annotation.__name__
return repr(annotation)
def _camel_to_snake(name: str) -> str:
@@ -663,7 +433,7 @@ def _load_public_fields(module_name: str, class_name: str, *, exclude: set[str]
fields.append(
PublicFieldSpec(
wire_name=name,
py_name=_camel_to_snake(name),
py_name=name,
annotation=annotation,
required=required,
)
@@ -731,7 +501,7 @@ def _render_codex_block(
*_kw_signature_lines(resume_fields),
" ) -> Thread:",
" params = ThreadResumeParams(",
" threadId=thread_id,",
" thread_id=thread_id,",
*_model_arg_lines(resume_fields),
" )",
" resumed = self._client.thread_resume(thread_id, params)",
@@ -744,7 +514,7 @@ def _render_codex_block(
*_kw_signature_lines(fork_fields),
" ) -> Thread:",
" params = ThreadForkParams(",
" threadId=thread_id,",
" thread_id=thread_id,",
*_model_arg_lines(fork_fields),
" )",
" forked = self._client.thread_fork(thread_id, params)",
@@ -798,7 +568,7 @@ def _render_async_codex_block(
" ) -> AsyncThread:",
" await self._ensure_initialized()",
" params = ThreadResumeParams(",
" threadId=thread_id,",
" thread_id=thread_id,",
*_model_arg_lines(resume_fields),
" )",
" resumed = await self._client.thread_resume(thread_id, params)",
@@ -812,7 +582,7 @@ def _render_async_codex_block(
" ) -> AsyncThread:",
" await self._ensure_initialized()",
" params = ThreadForkParams(",
" threadId=thread_id,",
" thread_id=thread_id,",
*_model_arg_lines(fork_fields),
" )",
" forked = await self._client.thread_fork(thread_id, params)",
@@ -842,7 +612,7 @@ def _render_thread_block(
" ) -> Turn:",
" wire_input = _to_wire_input(input)",
" params = TurnStartParams(",
" threadId=self.id,",
" thread_id=self.id,",
" input=wire_input,",
*_model_arg_lines(turn_fields),
" )",
@@ -865,7 +635,7 @@ def _render_async_thread_block(
" await self._codex._ensure_initialized()",
" wire_input = _to_wire_input(input)",
" params = TurnStartParams(",
" threadId=self.id,",
" thread_id=self.id,",
" input=wire_input,",
*_model_arg_lines(turn_fields),
" )",
@@ -890,27 +660,27 @@ def generate_public_api_flat_methods() -> None:
sys.path.insert(0, src_dir_str)
thread_start_fields = _load_public_fields(
"codex_app_server.generated.v2_all.ThreadStartParams",
"codex_app_server.generated.v2_all",
"ThreadStartParams",
)
thread_list_fields = _load_public_fields(
"codex_app_server.generated.v2_all.ThreadListParams",
"codex_app_server.generated.v2_all",
"ThreadListParams",
)
thread_resume_fields = _load_public_fields(
"codex_app_server.generated.v2_all.ThreadResumeParams",
"codex_app_server.generated.v2_all",
"ThreadResumeParams",
exclude={"threadId"},
exclude={"thread_id"},
)
thread_fork_fields = _load_public_fields(
"codex_app_server.generated.v2_all.ThreadForkParams",
"codex_app_server.generated.v2_all",
"ThreadForkParams",
exclude={"threadId"},
exclude={"thread_id"},
)
turn_start_fields = _load_public_fields(
"codex_app_server.generated.v2_all.TurnStartParams",
"codex_app_server.generated.v2_all",
"TurnStartParams",
exclude={"threadId", "input"},
exclude={"thread_id", "input"},
)
source = public_api_path.read_text()
@@ -950,8 +720,6 @@ def generate_public_api_flat_methods() -> None:
def generate_types() -> None:
# v2_all is the authoritative generated surface.
generate_v2_all()
generate_protocol_types()
generate_schema_types()
generate_notification_registry()
generate_codex_event_types()
generate_public_api_flat_methods()