mirror of
https://github.com/openai/codex.git
synced 2026-05-15 08:42:34 +00:00
366 lines
12 KiB
Python
366 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import argparse
|
|
import asyncio
|
|
import datetime as dt
|
|
import json
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import websockets
|
|
|
|
|
|
HOST = "127.0.0.1"
|
|
DEFAULT_PORT = 8765
|
|
PATH = "/v1/responses"
|
|
|
|
CALL_ID = "shell-command-call"
|
|
FUNCTION_NAME = "shell_command"
|
|
FUNCTION_ARGS_JSON = json.dumps({"command": "echo websocket"}, separators=(",", ":"))
|
|
|
|
ASSISTANT_TEXT = "done"
|
|
SCENARIO_NORMAL = "normal"
|
|
SCENARIO_USAGE_LIMIT = "usage-limit"
|
|
|
|
|
|
def _utc_iso() -> str:
|
|
return dt.datetime.now(tz=dt.timezone.utc).isoformat(timespec="milliseconds")
|
|
|
|
|
|
def _default_usage() -> dict[str, Any]:
|
|
return {
|
|
"input_tokens": 0,
|
|
"input_tokens_details": None,
|
|
"output_tokens": 0,
|
|
"output_tokens_details": None,
|
|
"total_tokens": 0,
|
|
}
|
|
|
|
|
|
def _event_response_created(response_id: str) -> dict[str, Any]:
|
|
return {"type": "response.created", "response": {"id": response_id}}
|
|
|
|
|
|
def _event_response_done() -> dict[str, Any]:
|
|
return {"type": "response.done", "response": {"usage": _default_usage()}}
|
|
|
|
|
|
def _event_response_completed(response_id: str) -> dict[str, Any]:
|
|
return {"type": "response.completed", "response": {"id": response_id, "usage": _default_usage()}}
|
|
|
|
|
|
def _event_function_call(call_id: str, name: str, arguments_json: str) -> dict[str, Any]:
|
|
return {
|
|
"type": "response.output_item.done",
|
|
"item": {"type": "function_call", "call_id": call_id, "name": name, "arguments": arguments_json},
|
|
}
|
|
|
|
|
|
def _event_assistant_message(message_id: str, text: str) -> dict[str, Any]:
|
|
return {
|
|
"type": "response.output_item.done",
|
|
"item": {
|
|
"type": "message",
|
|
"role": "assistant",
|
|
"id": message_id,
|
|
"content": [{"type": "output_text", "text": text}],
|
|
},
|
|
}
|
|
|
|
|
|
def _event_usage_limit_error() -> dict[str, Any]:
|
|
return {
|
|
"type": "error",
|
|
"status": 429,
|
|
"error": {
|
|
"type": "usage_limit_reached",
|
|
"message": "The usage limit has been reached",
|
|
"plan_type": "pro",
|
|
"resets_at": 1704067242,
|
|
"resets_in_seconds": 1234,
|
|
},
|
|
"headers": {
|
|
"x-codex-primary-used-percent": "100.0",
|
|
"x-codex-secondary-used-percent": "87.5",
|
|
"x-codex-primary-over-secondary-limit-percent": "95.0",
|
|
"x-codex-primary-window-minutes": "15",
|
|
"x-codex-secondary-window-minutes": "60",
|
|
},
|
|
}
|
|
|
|
|
|
def _event_approaching_rate_limits() -> dict[str, Any]:
|
|
return {
|
|
"type": "codex.rate_limits",
|
|
"plan_type": "plus",
|
|
"rate_limits": {
|
|
"allowed": True,
|
|
"limit_reached": False,
|
|
"primary": {
|
|
"used_percent": 95,
|
|
"window_minutes": 15,
|
|
"reset_at": 1704067242,
|
|
},
|
|
"secondary": {
|
|
"used_percent": 87.5,
|
|
"window_minutes": 60,
|
|
"reset_at": 1704067242,
|
|
},
|
|
},
|
|
"code_review_rate_limits": None,
|
|
"credits": None,
|
|
"promo": None,
|
|
}
|
|
|
|
|
|
def _dump_json(payload: Any) -> str:
|
|
return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
|
|
|
|
|
def _print_request(prefix: str, payload: Any) -> None:
|
|
pretty = json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True)
|
|
sys.stdout.write(f"{prefix} {_utc_iso()}\n{pretty}\n")
|
|
sys.stdout.flush()
|
|
|
|
|
|
@dataclass
|
|
class ScenarioState:
|
|
scenario: str
|
|
usage_limit_delay_seconds: float
|
|
usage_limit_approach_first: bool
|
|
approaching_rate_limits_emitted: bool = False
|
|
usage_limit_emitted: bool = False
|
|
success_count: int = 0
|
|
|
|
|
|
async def _handle_connection(
|
|
websocket: Any,
|
|
*,
|
|
scenario_state: ScenarioState,
|
|
expected_path: str = PATH,
|
|
) -> None:
|
|
# websockets v15 exposes the request path here.
|
|
path = getattr(getattr(websocket, "request", None), "path", None)
|
|
if path is None:
|
|
# Older handler signatures could pass `path` separately; accept if unavailable.
|
|
path = "(unknown)"
|
|
|
|
sys.stdout.write(f"[conn] {_utc_iso()} connected path={path}\n")
|
|
sys.stdout.flush()
|
|
|
|
path_no_qs = path.split("?", 1)[0] if path != "(unknown)" else path
|
|
if path_no_qs != "(unknown)" and path_no_qs != expected_path:
|
|
sys.stdout.write(f"[conn] {_utc_iso()} rejecting unexpected path (expected {expected_path})\n")
|
|
sys.stdout.flush()
|
|
await websocket.close(code=1008, reason="unexpected websocket path")
|
|
return
|
|
|
|
async def recv_json(label: str) -> Any:
|
|
msg = await websocket.recv()
|
|
if isinstance(msg, bytes):
|
|
payload = json.loads(msg.decode("utf-8"))
|
|
else:
|
|
payload = json.loads(msg)
|
|
_print_request(f"[{label}] recv", payload)
|
|
return payload
|
|
|
|
async def send_event(ev: dict[str, Any]) -> None:
|
|
sys.stdout.write(f"[conn] {_utc_iso()} send {_dump_json(ev)}\n")
|
|
await websocket.send(_dump_json(ev))
|
|
|
|
if scenario_state.scenario == SCENARIO_NORMAL:
|
|
# Request 1: provoke a function call (mirrors `codex-rs/core/tests/suite/agent_websocket.rs`).
|
|
await recv_json("req1")
|
|
await send_event(_event_response_created("resp-1"))
|
|
await send_event(_event_function_call(CALL_ID, FUNCTION_NAME, FUNCTION_ARGS_JSON))
|
|
await send_event(_event_response_done())
|
|
|
|
# Request 2: expect appended tool output; send final assistant message.
|
|
await recv_json("req2")
|
|
await send_event(_event_response_created("resp-2"))
|
|
await send_event(_event_assistant_message("msg-1", ASSISTANT_TEXT))
|
|
await send_event(_event_response_completed("resp-2"))
|
|
else:
|
|
request_count = 0
|
|
while True:
|
|
request_count += 1
|
|
payload = await recv_json(f"req{request_count}")
|
|
if payload.get("generate") is False:
|
|
await send_event(_event_response_created(f"resp-warmup-{request_count}"))
|
|
await send_event(_event_response_completed(f"resp-warmup-{request_count}"))
|
|
continue
|
|
|
|
if (
|
|
scenario_state.usage_limit_approach_first
|
|
and not scenario_state.approaching_rate_limits_emitted
|
|
):
|
|
scenario_state.approaching_rate_limits_emitted = True
|
|
await send_event(_event_approaching_rate_limits())
|
|
await send_event(_event_response_created("resp-approaching-limits"))
|
|
await send_event(_event_assistant_message("msg-approaching-limits", ASSISTANT_TEXT))
|
|
await send_event(_event_response_completed("resp-approaching-limits"))
|
|
continue
|
|
|
|
if not scenario_state.usage_limit_emitted:
|
|
scenario_state.usage_limit_emitted = True
|
|
sys.stdout.write(
|
|
f"[conn] {_utc_iso()} waiting {scenario_state.usage_limit_delay_seconds:.1f}s before usage limit\n"
|
|
)
|
|
sys.stdout.flush()
|
|
await asyncio.sleep(scenario_state.usage_limit_delay_seconds)
|
|
await send_event(_event_usage_limit_error())
|
|
continue
|
|
|
|
scenario_state.success_count += 1
|
|
response_id = f"resp-after-limit-{scenario_state.success_count}"
|
|
message_id = f"msg-after-limit-{scenario_state.success_count}"
|
|
await send_event(_event_response_created(response_id))
|
|
await send_event(_event_assistant_message(message_id, ASSISTANT_TEXT))
|
|
await send_event(_event_response_completed(response_id))
|
|
|
|
sys.stdout.write(f"[conn] {_utc_iso()} closing\n")
|
|
sys.stdout.flush()
|
|
await websocket.close()
|
|
|
|
|
|
async def _serve(
|
|
port: int,
|
|
scenario: str,
|
|
usage_limit_delay_seconds: float,
|
|
usage_limit_approach_first: bool,
|
|
) -> int:
|
|
scenario_state = ScenarioState(
|
|
scenario=scenario,
|
|
usage_limit_delay_seconds=usage_limit_delay_seconds,
|
|
usage_limit_approach_first=usage_limit_approach_first,
|
|
)
|
|
|
|
async def handler(ws: Any) -> None:
|
|
try:
|
|
await _handle_connection(ws, scenario_state=scenario_state, expected_path=PATH)
|
|
except websockets.exceptions.ConnectionClosed:
|
|
return
|
|
|
|
try:
|
|
server = await websockets.serve(handler, HOST, port)
|
|
except OSError as err:
|
|
sys.stderr.write(f"[server] failed to bind ws://{HOST}:{port}: {err}\n")
|
|
return 2
|
|
bound_port = server.sockets[0].getsockname()[1]
|
|
ws_uri = f"ws://{HOST}:{bound_port}"
|
|
|
|
sys.stdout.write("[server] mock Responses WebSocket server running\n")
|
|
if scenario == SCENARIO_USAGE_LIMIT:
|
|
if usage_limit_approach_first:
|
|
sys.stdout.write(
|
|
f"[server] scenario={SCENARIO_USAGE_LIMIT} first real turn approaches limits; next real turn errors after {usage_limit_delay_seconds:.1f}s\n"
|
|
)
|
|
else:
|
|
sys.stdout.write(
|
|
f"[server] scenario={SCENARIO_USAGE_LIMIT} first real turn errors after {usage_limit_delay_seconds:.1f}s\n"
|
|
)
|
|
sys.stdout.write(f"""Add this to your config.toml:
|
|
|
|
|
|
[model_providers.localapi_ws]
|
|
base_url = "{ws_uri}/v1"
|
|
name = "localapi_ws"
|
|
wire_api = "responses"
|
|
env_key = "OPENAI_API_KEY_STAGING"
|
|
supports_websockets = true
|
|
|
|
[profiles.localapi_ws]
|
|
model = "gpt-5.2"
|
|
model_provider = "localapi_ws"
|
|
model_reasoning_effort = "medium"
|
|
|
|
|
|
start codex with `codex --profile localapi_ws`
|
|
""")
|
|
if scenario == SCENARIO_USAGE_LIMIT:
|
|
if usage_limit_approach_first:
|
|
sys.stdout.write(
|
|
"""To exercise the approaching-limit then paused-queue flow:
|
|
1. Submit one prompt and dismiss the approaching-rate-limits prompt.
|
|
2. Submit a second prompt.
|
|
3. Before the delayed usage-limit error arrives, press Enter for a steer and Tab for a queued follow-up.
|
|
4. After the error, confirm both queued sections stay paused until you resume them.
|
|
"""
|
|
)
|
|
else:
|
|
sys.stdout.write(
|
|
"""To exercise the paused queue state:
|
|
1. Submit one prompt.
|
|
2. Before the delayed usage-limit error arrives, type follow-ups and press Tab to queue them.
|
|
3. After the error, confirm queued sends stay paused until you press Enter and choose `Resume queued sends`.
|
|
"""
|
|
)
|
|
sys.stdout.flush()
|
|
|
|
try:
|
|
await asyncio.Future()
|
|
finally:
|
|
server.close()
|
|
await server.wait_closed()
|
|
return 0
|
|
|
|
|
|
def main() -> int:
|
|
parser = argparse.ArgumentParser(
|
|
description=(
|
|
"Mock a minimal Responses API WebSocket endpoint for the `test_codex` flow.\n"
|
|
f"Binds to {HOST}:{DEFAULT_PORT} by default and logs incoming JSON requests to stdout."
|
|
),
|
|
formatter_class=argparse.RawTextHelpFormatter,
|
|
)
|
|
parser.add_argument(
|
|
"--port",
|
|
type=int,
|
|
default=DEFAULT_PORT,
|
|
help=f"Bind port (default: {DEFAULT_PORT}; use 0 for random free port).",
|
|
)
|
|
parser.add_argument(
|
|
"--scenario",
|
|
choices=[SCENARIO_NORMAL, SCENARIO_USAGE_LIMIT],
|
|
default=SCENARIO_NORMAL,
|
|
help=(
|
|
"Behavior to emulate (default: normal).\n"
|
|
"Use `usage-limit` to make the first real turn hit a usage limit."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--usage-limit-delay-seconds",
|
|
type=float,
|
|
default=15.0,
|
|
help=(
|
|
"Delay before sending the usage-limit error in the `usage-limit` scenario "
|
|
"(default: 15.0)."
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--usage-limit-approach-first",
|
|
action="store_true",
|
|
help=(
|
|
"In the `usage-limit` scenario, complete one successful near-limit turn before "
|
|
"the delayed hard-limit turn."
|
|
),
|
|
)
|
|
args = parser.parse_args()
|
|
|
|
try:
|
|
return asyncio.run(
|
|
_serve(
|
|
args.port,
|
|
args.scenario,
|
|
args.usage_limit_delay_seconds,
|
|
args.usage_limit_approach_first,
|
|
)
|
|
)
|
|
except KeyboardInterrupt:
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
raise SystemExit(main())
|