Compare commits

...

35 Commits

Author SHA1 Message Date
jif-oai
d664ea4c11 feat: coloring for user commands 2026-03-03 20:58:34 +00:00
jif-oai
9047e4dc6f feat: rendering engine 2026-03-03 15:44:59 +00:00
jif-oai
8159f05dfd feat: wire spreadsheet artifact (#13362) 2026-03-03 15:27:37 +00:00
jif-oai
24ba01b9da feat: artifact presentation part 7 (#13360) 2026-03-03 15:03:25 +00:00
jif-oai
1df040e62b feat: add multi-actions to presentation tool (#13357) 2026-03-03 14:37:26 +00:00
jif-oai
ad393fa753 feat: pres artifact part 5 (#13355)
Mostly written by Codex
2026-03-03 14:08:01 +00:00
jif-oai
821024f9c9 feat: spreadsheet part 3 (#13350)
=
2026-03-03 13:09:37 +00:00
jif-oai
a7d90b867d feat: presentation part 4 (#13348) 2026-03-03 12:51:31 +00:00
jif-oai
875eaac0d1 feat: spreadsheet v2 (#13347) 2026-03-03 12:38:27 +00:00
jif-oai
8c5e50ef39 feat: spreadsheet artifact (#13345) 2026-03-03 12:25:40 +00:00
jif-oai
564a883c2a feat: pres artifact 3 (#13346) 2026-03-03 12:18:25 +00:00
jif-oai
72dc444b2c feat: pres artifact 2 (#13344) 2026-03-03 12:00:34 +00:00
jif-oai
4874b9291a feat: presentation artifact p1 (#13341)
Part 1 of presentation tool artifact
2026-03-03 11:38:03 +00:00
pash-openai
07e532dcb9 app-server service tier plumbing (plus some cleanup) (#13334)
Follow-up to https://github.com/openai/codex/pull/13212 to expose fast-tier
controls to the app server.
(The majority of this PR is generated schema JSON; the actual code change is
+69 / -35, plus 24 tests.)

- add service tier fields to the app-server protocol surfaces used by
thread lifecycle, turn start, config, and session configured events
- thread service tier through the app-server message processor and core
thread config snapshots
- allow runtime config overrides to carry service tier for app-server
callers

cleanup:
- Remove unneeded "legacy" code supporting "standard": we moved to
None | "fast", so "standard" is no longer needed.
2026-03-03 02:35:09 -08:00
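The post-cleanup tier shape can be sketched in a few lines (a std-only illustration: the enum name and `"fast"` value follow the generated schema additions in this PR, but `encode` is a hypothetical helper, not the actual plumbing):

```rust
// Sketch of the tier shape after the cleanup: the tier is either unset
// (None) or Fast; the old "standard" variant no longer exists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ServiceTier {
    Fast,
}

// Hypothetical wire encoding matching the schema's single "fast" value.
fn encode(tier: Option<ServiceTier>) -> Option<&'static str> {
    tier.map(|t| match t {
        ServiceTier::Fast => "fast",
    })
}
```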
jif-oai
938c6dd388 fix: db windows path (#13336) 2026-03-03 09:50:52 +00:00
jif-oai
cacefb5228 fix: agent when profile (#13235)
Co-authored-by: Josh McKinney <joshka@openai.com>
Co-authored-by: Codex <noreply@openai.com>
2026-03-03 09:20:25 +00:00
jif-oai
3166a5ba82 fix: agent race (#13248)
https://github.com/openai/codex/issues/13244
2026-03-03 09:19:37 +00:00
bwanner-oai
6deb72c04b Renaming Team to Business plan during TUI onboarding (#13313)
Team is referred to as "Business"
2026-03-02 23:13:29 -08:00
Felipe Coury
745c48b088 fix(core): scope file search gitignore to repository context (#13250)
Closes #3493

## Problem

When a user's home directory (or any ancestor) contains a broad
`.gitignore` (e.g. `*` + `!.gitignore`), the `@` file mention picker in
Codex silently hides valid repository files like `package.json`. The
picker returns `no matches` for searches that should succeed. This is
surprising because manually typed paths still work, making the failure
hard to diagnose.

## Mental model

Git itself never walks above the repository root to assemble its ignore
list. Its `.gitignore` resolution is strictly scoped: it reads
`.gitignore` files from the repo root downward, the per-repo
`.git/info/exclude`, and the user's global excludes file (via
`core.excludesFile`). A `.gitignore` sitting in a parent directory above
the repo root has no effect on `git status`, `git ls-files`, or any
other git operation. Our file search should replicate this contract
exactly.
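Git's scoping contract can be sketched as a walk from the repo root down to the file's directory (a std-only illustration; `gitignore_candidates` is a hypothetical helper, not part of git or the `ignore` crate):

```rust
use std::path::{Path, PathBuf};

// Illustration of git's contract: the only .gitignore files consulted for
// `file` are those from the repo root down to the file's own directory.
// No ancestor directory above `repo_root` is ever read.
fn gitignore_candidates(repo_root: &Path, file: &Path) -> Vec<PathBuf> {
    let Some(rel_dir) = file.parent().and_then(|d| d.strip_prefix(repo_root).ok()) else {
        return Vec::new(); // file outside the repo: no repo-scoped ignores apply
    };
    let mut dir = repo_root.to_path_buf();
    let mut out = vec![dir.join(".gitignore")];
    for comp in rel_dir.components() {
        dir.push(comp);
        out.push(dir.join(".gitignore"));
    }
    out
}
```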

The `ignore` crate's `WalkBuilder` has a `require_git` flag that
controls whether it follows this contract:

- `require_git(false)` (the previous setting): the walker reads
`.gitignore` files from _all_ ancestor directories, even those above or
outside the repository root. This is a deliberate divergence from git's
behavior in the `ignore` crate, intended for non-git use cases. It means
a `~/.gitignore` with `*` will suppress every file in the walk—something
git itself would never do.

- `require_git(true)` (this fix): the walker only applies `.gitignore`
semantics when it detects a `.git` directory, scoping ignore resolution
to the repository boundary. This matches git's own behavior: parent
`.gitignore` files above the repo root have no effect.

The fix is a one-line change: `require_git(false)` becomes
`require_git(true)`.

## How `require_git(false)` got here

The setting was introduced in af338cc (#2981, "Improve @ file search:
include specific hidden dirs such as .github, .gitlab"). That PR's goal
was to make hidden directories like `.github` and `.vscode` discoverable
by setting `.hidden(false)` on the walker. The `require_git(false)` was
added alongside it with the comment _"Don't require git to be present to
apply git-related ignore rules"_—the author likely intended gitignore
rules to still filter results even when no `.git` directory exists (e.g.
searching an extracted tarball that has a `.gitignore` but no `.git`).

The unintended consequence: with `require_git(false)`, the `ignore`
crate walks _above_ the search root to find `.gitignore` files in
ancestor directories. This is a side effect the original author almost
certainly didn't anticipate. The PR message says "Preserve `.gitignore`
semantics," but `require_git(false)` actually _breaks_ git's semantics
by applying ancestor ignore files that git itself would never read.

In short: the intent was "apply gitignore even without `.git`" but the
effect was "apply gitignore from every ancestor directory." This fix
restores git-correct scoping.

## Non-goals

- This PR does not change behavior when `respect_gitignore` is `false`
(that path already disables all git-related ignore rules).
- The first test
(`parent_gitignore_outside_repo_does_not_hide_repo_files`) intentionally
omits `git init`. The `ignore` crate's `require_git(true)` causes it to
skip gitignore processing entirely when no `.git` exists, which is the
desired behavior for that scenario. A second test
(`git_repo_still_respects_local_gitignore_when_enabled`) covers the
complementary case with a real git repo.

## Tradeoffs

**Behavioral shift**: With `require_git(true)`, directories that contain
`.gitignore` files but are _not_ inside a git repository will no longer
have those ignore rules applied during `@` search. This is a correctness
improvement for the primary use case (searching inside repos), but
changes behavior for the edge case of searching non-repo directories
that happen to have `.gitignore` files. In practice, Codex is
overwhelmingly used inside git repositories, so this tradeoff strongly
favors the fix.

**Two test strategies**: The first test omits `git init` to verify
parent ignore leakage is blocked; the second runs `git init` to verify
the repo's own `.gitignore` is still honored. Together they cover both
sides of the `require_git(true)` contract.

## Architecture

The change is in `walker_worker()` within
`codex-rs/file-search/src/lib.rs`, which configures the
`ignore::WalkBuilder` used by the file search walker thread. The walker
feeds discovered file paths into `nucleo` for fuzzy matching. The
`require_git` flag controls whether the walker consults `.gitignore`
files at all—it sits upstream of all ignore processing.

```
walker_worker
  └─ WalkBuilder::new(root)
       ├─ .hidden(false)         — include dotfiles
       ├─ .follow_links(true)    — follow symlinks
       ├─ .require_git(true)     — ← THE FIX: only apply gitignore in git repos
       └─ (conditional) git_ignore(false), git_global(false), etc.
            └─ applied when respect_gitignore == false
```

## Tests

- `parent_gitignore_outside_repo_does_not_hide_repo_files`: creates a
temp directory tree with a parent `.gitignore` containing `*`, a child
"repo" directory with `package.json` and `.vscode/settings.json`, and
asserts that both files are discoverable via `run()` with
`respect_gitignore: true`.
- `git_repo_still_respects_local_gitignore_when_enabled`: the
complementary test—runs `git init` inside the child directory and
verifies that the repo's own `.gitignore` exclusions still work (e.g.
`.vscode/extensions.json` is excluded while `.vscode/settings.json` is
whitelisted). Confirms that `require_git(true)` does not disable
gitignore processing inside actual git repositories.
2026-03-02 21:52:20 -07:00
pash-openai
2f5b01abd6 add fast mode toggle (#13212)
- add a local Fast mode setting in codex-core (similar to how model id
is currently stored on disk locally)
- send `service_tier=priority` on requests when Fast is enabled
- add `/fast` in the TUI and persist it locally
- feature flag
2026-03-02 20:29:33 -08:00
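The toggle's effect on outgoing requests can be sketched as follows (struct and function names are assumptions; only the `service_tier=priority` behavior comes from the PR summary):

```rust
// Sketch: /fast flips a locally persisted flag; when it is on, requests
// carry service_tier=priority, otherwise the field is omitted entirely.
struct LocalSettings {
    fast_mode: bool,
}

fn request_service_tier(settings: &LocalSettings) -> Option<&'static str> {
    settings.fast_mode.then_some("priority")
}
```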
rakan-oai
56cc2c71f4 tui: preserve kill buffer across submit and slash-command clears (#12006)
## Problem

Before this change, composer paths that cleared the textarea after
submit or slash-command dispatch
also cleared the textarea kill buffer. That meant a user could `Ctrl+K`
part of a draft, trigger a
composer action that cleared the visible draft, and then lose the
ability to `Ctrl+Y` the killed
text back.

This was especially awkward for workflows where the user wants to
temporarily remove text, run a
composer action such as changing reasoning level or dispatching a slash
command, and then restore
the killed text into the now-empty draft.

## Mental model

This change separates visible draft state from editing-history state.

The visible draft includes the current textarea contents and text
elements that should be cleared
when the composer submits or dispatches a command. The kill buffer is
different: it represents the
most recent killed text and should survive those composer-driven clears
so the user can still yank
it back afterward.

After this change, submit and slash-command dispatch still clear the
visible textarea contents, but
they no longer erase the most recent kill.

## Non-goals

This does not implement a multi-entry kill ring or change the semantics
of `Ctrl+K` and `Ctrl+Y`
beyond preserving the existing yank target across these clears.

It also does not change how submit, slash-command parsing, prompt
expansion, or attachment handling
work, except that those flows no longer discard the textarea kill buffer
as a side effect of
clearing the draft.

## Tradeoffs

The main tradeoff is that clearing the visible textarea is no longer
equivalent to fully resetting
all editing state. That is intentional here, because submit and
slash-command dispatch are composer
actions, not requests to forget the user's most recent kill.

The benefit is better editing continuity. The cost is that callers must
understand that full-buffer
replacement resets visible draft state but not the kill buffer.

## Architecture

The behavioral change is in `TextArea`: full-buffer replacement now
rebuilds text and elements
without clearing `kill_buffer`.

`ChatComposer` already clears the textarea after successful submit and
slash-command dispatch by
calling into those textarea replacement paths. With this change, those
existing composer flows
inherit the new behavior automatically: the visible draft is cleared,
but the last killed text
remains available for `Ctrl+Y`.
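The contract can be sketched minimally (type and method names assumed, not the actual `TextArea` API): full-buffer replacement resets visible text but leaves the kill buffer intact.

```rust
// Minimal sketch of the contract: set_text rebuilds the visible draft
// without touching kill_buffer, so Ctrl+Y still works after a clear.
#[derive(Default)]
struct TextArea {
    text: String,
    kill_buffer: String,
}

impl TextArea {
    // Ctrl+K: kill from byte offset `at` to the end of the draft.
    fn kill_to_end(&mut self, at: usize) {
        self.kill_buffer = self.text.split_off(at);
    }
    // Full-buffer replacement used by submit / slash-command clears.
    fn set_text(&mut self, text: &str) {
        self.text = text.to_string(); // kill_buffer deliberately untouched
    }
    // Ctrl+Y: yank the most recent kill back into the draft.
    fn yank(&mut self) {
        let killed = self.kill_buffer.clone();
        self.text.push_str(&killed);
    }
}
```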

The tests cover both layers:

- `TextArea` verifies that the kill buffer survives full-buffer
replacement.
- `ChatComposer` verifies that it survives submit.
- `ChatComposer` also verifies that it survives slash-command dispatch.

## Observability

There is no dedicated logging for kill-buffer preservation. The most
direct way to reason about the
behavior is to inspect textarea-wide replacement paths and confirm
whether they treat the kill
buffer as visible-buffer state or as editing-history state.

If this regresses in the future, the likely failure mode is simple and
user-visible: `Ctrl+Y` stops
restoring text after submit or slash-command clears even though ordinary
kill/yank still works
within a single uninterrupted draft.

## Tests

Added focused regression coverage for the new contract:

- `kill_buffer_persists_across_set_text`
- `kill_buffer_persists_after_submit`
- `kill_buffer_persists_after_slash_command_dispatch`

Local verification:
- `just fmt`
- `cargo test -p codex-tui`

---------

Co-authored-by: Josh McKinney <joshka@openai.com>
2026-03-03 02:06:08 +00:00
Celia Chen
0bb152b01d chore: remove SkillMetadata.permissions and derive skill sandboxing from permission_profile (#13061)
## Summary

This change removes the compiled permissions field from skill metadata
and keeps permission_profile as the single source of truth.

Skill loading no longer compiles skill permissions eagerly. Instead, the
zsh-fork skill escalation path compiles `skill.permission_profile` when
it needs to determine the sandbox to apply for a skill script.

## Behavior change

For skills that declare:
```
permissions: {}
```
we now treat that the same as having no skill permissions override, instead
of creating and using a default read-only sandbox. This change makes the
behavior more intuitive:

- only non-empty skill permission profiles affect sandboxing
- omitting `permissions` and writing `permissions: {}` now mean the same
thing
- skill metadata keeps a single permissions representation instead of also
storing derived state

Overall, this makes skill sandbox behavior easier to understand and more
predictable.
2026-03-03 01:29:53 +00:00
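The derivation rule can be sketched as follows (all type shapes here are assumptions for illustration; only the "empty profile means no override" behavior comes from the PR):

```rust
// Sketch: only a non-empty permission profile yields a sandbox override;
// an omitted profile and an empty `permissions: {}` behave identically.
#[derive(Default)]
struct PermissionProfile {
    rules: Vec<String>,
}

#[derive(Debug, PartialEq)]
struct Sandbox {
    rules: Vec<String>,
}

fn sandbox_override(profile: Option<&PermissionProfile>) -> Option<Sandbox> {
    match profile {
        // Compile lazily, and only when there is something to compile.
        Some(p) if !p.rules.is_empty() => Some(Sandbox { rules: p.rules.clone() }),
        // Omitted or empty: no override; caller uses the default sandbox.
        _ => None,
    }
}
```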
Owen Lin
9965bf31fa feat(app-server-test-client): support tracing (#13286) 2026-03-02 17:24:48 -08:00
Brian Fioca
50084339a6 Adjusting plan prompt for clarity and verbosity (#13284)
`plan.md` prompt changes to tighten plan clarity and verbosity.
2026-03-03 01:14:39 +00:00
Ruslan Nigmatullin
9022cdc563 app-server: Silence thread status changes caused by thread being created (#13079)
Currently we emit `thread/status/changed` with `Idle` status right
before sending the `thread/started` event (which also carries `Idle` status).
There is no point in that: the client has no way to know the prior state of a
thread that did not exist yet, so silence these kinds of notifications.
2026-03-03 00:52:28 +00:00
Owen Lin
146b798129 fix(app-server): emit turn/started only when turn actually starts (#13261)
This is a follow-up for https://github.com/openai/codex/pull/13047

## Why
We had a race where `turn/started` could be observed before the thread
had actually transitioned to `Active`. This was because we eagerly
emitted `turn/started` in the request handler for `turn/start` (and
`review/start`).

That was showing up as flaky `thread/resume` tests, but the real issue
was broader: a client could see `turn/started` and still get back an
idle thread immediately afterward.

The first idea was to eagerly call
`thread_watch_manager.note_turn_started(...)` from the `turn/start`
request path. That turns out to be unsafe, because
`submit(Op::UserInput)` only queues work. If a turn starts and completes
quickly, request-path bookkeeping can race with the real lifecycle
events and leave stale running state behind.

**The real fix** is to move `turn/started` to emit only after the turn
_actually_ starts, so we do that by waiting for the
`EventMsg::TurnStarted` notification emitted by codex core. We do this
for both `turn/start` and `review/start`.

I also verified this change is safe for our first-party codex apps -
they don't have any assumptions that `turn/started` is emitted before
the RPC response to `turn/start` (which is correct anyway).

I also removed `single_client_mode` since it isn't really necessary now.

## Testing
- `cargo test -p codex-app-server thread_resume -- --nocapture`
- `cargo test -p codex-app-server
'suite::v2::turn_start::turn_start_emits_notifications_and_accepts_model_override'
-- --exact --nocapture`
- `cargo test -p codex-app-server`
2026-03-02 16:43:31 -08:00
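The ordering fix can be sketched with a channel: the notification goes out only once core's `TurnStarted` event arrives, never eagerly from the request handler (the event name comes from the PR; everything else here is assumed plumbing for illustration):

```rust
use std::sync::mpsc;

// Subset of core events, shape assumed for illustration.
enum EventMsg {
    TurnStarted,
    Other,
}

// Instead of emitting turn/started eagerly in the turn/start handler,
// wait until core reports that the turn actually began.
fn emit_turn_started_when_ready(rx: &mpsc::Receiver<EventMsg>) -> bool {
    for event in rx.iter() {
        if matches!(event, EventMsg::TurnStarted) {
            return true; // safe to emit turn/started now
        }
    }
    false // channel closed without the turn ever starting
}
```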
Ahmed Ibrahim
b20b6aa46f Update realtime websocket API (#13265)
- migrate the realtime websocket transport to the new session and
handoff flow
- make the realtime model configurable in config.toml and use API-key
auth for the websocket

---------

Co-authored-by: Codex <noreply@openai.com>
2026-03-02 16:05:40 -08:00
Owen Lin
d473e8d56d feat(app-server): add tracing to all app-server APIs (#13285)
### Overview
This PR adds the first piece of tracing for app-server JSON-RPC
requests.

There are two main changes:
- JSON-RPC requests can now take an optional W3C trace context at the
top level via a `trace` field (`traceparent` / `tracestate`).
- app-server now creates a dedicated request span for every inbound
JSON-RPC request in `MessageProcessor`, and uses the request-level trace
context as the parent when present.

For compatibility with existing flows, app-server still falls back to
the TRACEPARENT env var when there is no request-level traceparent.

This PR is intentionally scoped to the app-server boundary. In a
followup, we'll actually propagate trace context through the async
handoff into core execution spans like run_turn, which will make
app-server traces much more useful.

### Spans
A few details on the app-server span shape:
- each inbound request gets its own server span
- span/resource names are based on the JSON-RPC method (`initialize`,
`thread/start`, `turn/start`, etc.)
- spans record transport (stdio vs websocket), request id, connection
id, and client name/version when available
- `initialize` stores client metadata in session state so later requests
on the same connection can reuse it
2026-03-02 16:01:41 -08:00
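The resolution order can be sketched as below (field names follow the PR's schema additions; `effective_traceparent` itself is a hypothetical helper): the request-level `traceparent` wins, with the `TRACEPARENT` env var as fallback.

```rust
// Request-level trace context as described by the new `trace` field.
struct W3cTraceContext {
    traceparent: Option<String>,
    tracestate: Option<String>,
}

// Prefer the request-level traceparent; fall back to the TRACEPARENT
// env var for compatibility with existing flows.
fn effective_traceparent(request_trace: Option<&W3cTraceContext>) -> Option<String> {
    request_trace
        .and_then(|t| t.traceparent.clone())
        .or_else(|| std::env::var("TRACEPARENT").ok())
}
```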
Ruslan Nigmatullin
14fcb6645c app-server: Update thread/name/set to support not-loaded threads (#13282)
Currently, `thread/name/set` only works for loaded threads.
Expand its scope to also support persisted but not-yet-loaded threads for a
more predictable API surface.
This will make it possible to rename threads discovered via
`thread/list` and similar operations.
2026-03-02 15:13:18 -08:00
Josh McKinney
75e7c804ea test(app-server): increase flow test timeout to reduce flake (#11814)
## Summary
- increase `DEFAULT_READ_TIMEOUT` in `codex_message_processor_flow` from
20s to 45s
- keep test behavior the same while avoiding platform timing flakes

## Why
Windows ARM64 CI showed these tests taking about 24s before
`task_complete`, which could fail early and produce wiremock
request-count mismatches.

## Testing
- just fmt
- cargo test -p codex-app-server codex_message_processor_flow --
--nocapture
2026-03-02 12:29:28 -08:00
Dylan Hurd
e10df4ba10 fix(core) shell_snapshot multiline exports (#12642)
## Summary
Codex discovered this one - shell_snapshot tests were breaking on my
machine because I had a multiline env var. We should handle these!

## Testing
- [x] existing tests pass
- [x] Updated unit tests
2026-03-02 12:08:17 -07:00
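One common way to make multiline values survive in snapshot `export` lines is single-quoting (the helper name here is an assumption, not the actual fix): embedded newlines stay literal inside single quotes, and embedded single quotes use the POSIX `'\''` escape (close the quote, emit a literal `'`, reopen).

```rust
// Sketch: single-quote the value so embedded newlines survive verbatim
// in a shell snapshot's `export` line; escape embedded single quotes
// with the POSIX '\'' idiom.
fn export_line(name: &str, value: &str) -> String {
    format!("export {}='{}'", name, value.replace('\'', r"'\''"))
}
```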
jif-oai
f8838fd6f3 feat: enable ma through /agent (#13246)
<img width="639" height="139" alt="Screenshot 2026-03-02 at 16 06 41"
src="https://github.com/user-attachments/assets/c006fcec-c1e7-41ce-bb84-c121d5ffb501"
/>

Then
<img width="372" height="37" alt="Screenshot 2026-03-02 at 16 06 49"
src="https://github.com/user-attachments/assets/aa4ad703-e7e7-4620-9032-f5cd4f48ff79"
/>
2026-03-02 18:37:29 +00:00
Charley Cunningham
7979ce453a tui: restore draft footer hints (#13202)
## Summary
- restore `Tab to queue` when a draft is present and the agent is
running
- keep draft-idle footers passive by showing the normal footer or status
line instead of `? for shortcuts`
- align footer snapshot coverage with the updated draft footer behavior

## Codex author
`codex resume 019c7f1c-43aa-73e0-97c7-40f457396bb0`

---------

Co-authored-by: Codex <noreply@openai.com>
2026-03-02 10:26:13 -08:00
Eric Traut
7709bf32a3 Fix project trust config parsing so CLI overrides work (#13090)
Fixes #13076

This PR fixes a bug that causes command-line config overrides for MCP
subtables to not be merged correctly.

Summary
- make project trust loading go through the dedicated struct so CLI
overrides can update trusted project-local MCP transports

---------

Co-authored-by: jif-oai <jif@openai.com>
2026-03-02 11:10:38 -07:00
Michael Bolin
3241c1c6cc fix: use https://git.savannah.gnu.org/git/bash instead of https://github.com/bolinfest/bash (#13057)
Historically, we cloned the Bash repo from
https://github.com/bminor/bash, but for whatever reason, it was removed
at some point.

I had a local clone of it, so I pushed it to
https://github.com/bolinfest/bash so that we could continue running our
CI job. I did this in https://github.com/openai/codex/pull/9563, and as
you can see, I did not tamper with the commit hash we used as the basis
of this build.

Using a personal fork is not great, so this PR changes the CI job to use
what appears to be considered the source of truth for Bash, which is
https://git.savannah.gnu.org/git/bash.git.

Though in testing this out, it appears this Git server does not support
the combination of `git clone --depth 1
https://git.savannah.gnu.org/git/bash` and `git fetch --depth 1 origin
a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b`, as it fails with the
following error:

```
error: Server does not allow request for unadvertised object a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b
```

so unfortunately this means that we have to do a full clone instead of a
shallow clone in our CI jobs, which will be a bit slower.

Also updated `codex-rs/shell-escalation/README.md` to reflect this
change.
2026-03-02 09:09:54 -08:00
215 changed files with 33824 additions and 1497 deletions


@@ -146,9 +146,8 @@ jobs:
 shell: bash
 run: |
 set -euo pipefail
-git clone --depth 1 https://github.com/bolinfest/bash /tmp/bash
+git clone https://git.savannah.gnu.org/git/bash /tmp/bash
 cd /tmp/bash
-git fetch --depth 1 origin a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b
 git checkout a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b
 git apply "${GITHUB_WORKSPACE}/shell-tool-mcp/patches/bash-exec-wrapper.patch"
 ./configure --without-bash-malloc
@@ -188,9 +187,8 @@ jobs:
 shell: bash
 run: |
 set -euo pipefail
-git clone --depth 1 https://github.com/bolinfest/bash /tmp/bash
+git clone https://git.savannah.gnu.org/git/bash /tmp/bash
 cd /tmp/bash
-git fetch --depth 1 origin a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b
 git checkout a8a1c2fac029404d3f42cd39f5a20f24b6e4fe4b
 git apply "${GITHUB_WORKSPACE}/shell-tool-mcp/patches/bash-exec-wrapper.patch"
 ./configure --without-bash-malloc

MODULE.bazel.lock (generated, 47 lines changed): diff suppressed because one or more lines are too long

codex-rs/Cargo.lock (generated, 711 lines changed): diff suppressed because it is too large


@@ -33,6 +33,8 @@ members = [
"mcp-server",
"network-proxy",
"ollama",
"artifact-presentation",
"artifact-spreadsheet",
"process-hardening",
"protocol",
"rmcp-client",
@@ -109,6 +111,8 @@ codex-mcp-server = { path = "mcp-server" }
codex-network-proxy = { path = "network-proxy" }
codex-ollama = { path = "ollama" }
codex-otel = { path = "otel" }
codex-artifact-presentation = { path = "artifact-presentation" }
codex-artifact-spreadsheet = { path = "artifact-spreadsheet" }
codex-process-hardening = { path = "process-hardening" }
codex-protocol = { path = "protocol" }
codex-responses-api-proxy = { path = "responses-api-proxy" }
@@ -214,7 +218,10 @@ os_info = "3.12.0"
owo-colors = "4.3.0"
path-absolutize = "3.1.1"
pathdiff = "0.2"
font8x8 = "0.3.1"
tiny-skia = "0.11.4"
portable-pty = "0.9.0"
ppt-rs = "0.2.6"
predicates = "3"
pretty_assertions = "1.4.1"
pulldown-cmark = "0.10"
@@ -349,6 +356,7 @@ ignored = [
"openssl-sys",
"codex-utils-readiness",
"codex-secrets",
"codex-artifact-spreadsheet"
]
[profile.release]


@@ -20,6 +20,7 @@ codex-utils-absolute-path = { workspace = true }
schemars = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
serde_with = { workspace = true }
shlex = { workspace = true }
strum_macros = { workspace = true }
thiserror = { workspace = true }


@@ -1703,6 +1703,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"Settings": {
"description": "Settings for a collaboration mode.",
"properties": {
@@ -1933,6 +1939,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}
@@ -2155,6 +2178,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}
@@ -2299,6 +2339,23 @@
"string",
"null"
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
}
},
"type": "object"
@@ -2409,6 +2466,24 @@
],
"description": "Override the sandbox policy for this turn and subsequent turns."
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
],
"description": "Override the service tier for this turn and subsequent turns."
},
"summary": {
"anyOf": [
{


@@ -1138,6 +1138,16 @@
],
"description": "How to sandbox commands executed in the system"
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"session_id": {
"$ref": "#/definitions/ThreadId"
},
@@ -4234,8 +4244,14 @@
{
"additionalProperties": false,
"properties": {
"SessionCreated": {
"SessionUpdated": {
"properties": {
"instructions": {
"type": [
"string",
"null"
]
},
"session_id": {
"type": "string"
}
@@ -4246,27 +4262,6 @@
"type": "object"
}
},
"required": [
"SessionCreated"
],
"title": "SessionCreatedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"SessionUpdated": {
"properties": {
"backend_prompt": {
"type": [
"string",
"null"
]
}
},
"type": "object"
}
},
"required": [
"SessionUpdated"
],
@@ -4297,6 +4292,40 @@
"title": "ConversationItemAddedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"ConversationItemDone": {
"properties": {
"item_id": {
"type": "string"
}
},
"required": [
"item_id"
],
"type": "object"
}
},
"required": [
"ConversationItemDone"
],
"title": "ConversationItemDoneRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"HandoffRequested": {
"$ref": "#/definitions/RealtimeHandoffRequested"
}
},
"required": [
"HandoffRequested"
],
"title": "HandoffRequestedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
@@ -4312,6 +4341,47 @@
}
]
},
"RealtimeHandoffMessage": {
"properties": {
"role": {
"type": "string"
},
"text": {
"type": "string"
}
},
"required": [
"role",
"text"
],
"type": "object"
},
"RealtimeHandoffRequested": {
"properties": {
"handoff_id": {
"type": "string"
},
"input_transcript": {
"type": "string"
},
"item_id": {
"type": "string"
},
"messages": {
"items": {
"$ref": "#/definitions/RealtimeHandoffMessage"
},
"type": "array"
}
},
"required": [
"handoff_id",
"input_transcript",
"item_id",
"messages"
],
"type": "object"
},
"ReasoningEffort": {
"description": "See https://platform.openai.com/docs/guides/reasoning?api-mode=responses#get-started-with-reasoning",
"enum": [
@@ -5350,6 +5420,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"SessionNetworkProxyRuntime": {
"properties": {
"admin_addr": {
@@ -6694,6 +6770,16 @@
],
"description": "How to sandbox commands executed in the system"
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"session_id": {
"$ref": "#/definitions/ThreadId"
},


@@ -70,7 +70,18 @@
"method": {
"type": "string"
},
"params": true
"params": true,
"trace": {
"anyOf": [
{
"$ref": "#/definitions/W3cTraceContext"
},
{
"type": "null"
}
],
"description": "Optional W3C Trace Context for distributed tracing."
}
},
"required": [
"id",
@@ -102,6 +113,23 @@
"type": "integer"
}
]
},
"W3cTraceContext": {
"properties": {
"traceparent": {
"type": [
"string",
"null"
]
},
"tracestate": {
"type": [
"string",
"null"
]
}
},
"type": "object"
}
},
"description": "Refers to any valid JSON-RPC object that can be decoded off the wire, or encoded to be sent.",


@@ -11,6 +11,23 @@
"type": "integer"
}
]
},
"W3cTraceContext": {
"properties": {
"traceparent": {
"type": [
"string",
"null"
]
},
"tracestate": {
"type": [
"string",
"null"
]
}
},
"type": "object"
}
},
"description": "A request that expects a response.",
@@ -21,7 +38,18 @@
"method": {
"type": "string"
},
"params": true
"params": true,
"trace": {
"anyOf": [
{
"$ref": "#/definitions/W3cTraceContext"
},
{
"type": "null"
}
],
"description": "Optional W3C Trace Context for distributed tracing."
}
},
"required": [
"id",


@@ -2306,6 +2306,16 @@
],
"description": "How to sandbox commands executed in the system"
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"session_id": {
"$ref": "#/definitions/v2/ThreadId"
},
@@ -5016,7 +5026,18 @@
"method": {
"type": "string"
},
"params": true
"params": true,
"trace": {
"anyOf": [
{
"$ref": "#/definitions/W3cTraceContext"
},
{
"type": "null"
}
],
"description": "Optional W3C Trace Context for distributed tracing."
}
},
"required": [
"id",
@@ -5443,8 +5464,14 @@
{
"additionalProperties": false,
"properties": {
"SessionCreated": {
"SessionUpdated": {
"properties": {
"instructions": {
"type": [
"string",
"null"
]
},
"session_id": {
"type": "string"
}
@@ -5455,27 +5482,6 @@
"type": "object"
}
},
"required": [
"SessionCreated"
],
"title": "SessionCreatedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"SessionUpdated": {
"properties": {
"backend_prompt": {
"type": [
"string",
"null"
]
}
},
"type": "object"
}
},
"required": [
"SessionUpdated"
],
@@ -5506,6 +5512,40 @@
"title": "ConversationItemAddedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"ConversationItemDone": {
"properties": {
"item_id": {
"type": "string"
}
},
"required": [
"item_id"
],
"type": "object"
}
},
"required": [
"ConversationItemDone"
],
"title": "ConversationItemDoneRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
"HandoffRequested": {
"$ref": "#/definitions/RealtimeHandoffRequested"
}
},
"required": [
"HandoffRequested"
],
"title": "HandoffRequestedRealtimeEvent",
"type": "object"
},
{
"additionalProperties": false,
"properties": {
@@ -5521,6 +5561,47 @@
}
]
},
"RealtimeHandoffMessage": {
"properties": {
"role": {
"type": "string"
},
"text": {
"type": "string"
}
},
"required": [
"role",
"text"
],
"type": "object"
},
"RealtimeHandoffRequested": {
"properties": {
"handoff_id": {
"type": "string"
},
"input_transcript": {
"type": "string"
},
"item_id": {
"type": "string"
},
"messages": {
"items": {
"$ref": "#/definitions/RealtimeHandoffMessage"
},
"type": "array"
}
},
"required": [
"handoff_id",
"input_transcript",
"item_id",
"messages"
],
"type": "object"
},
"RejectConfig": {
"properties": {
"mcp_elicitations": {
@@ -7220,6 +7301,23 @@
}
]
},
"W3cTraceContext": {
"properties": {
"traceparent": {
"type": [
"string",
"null"
]
},
"tracestate": {
"type": [
"string",
"null"
]
}
},
"type": "object"
},
"v2": {
"AbsolutePathBuf": {
"description": "A path that is guaranteed to be absolute and normalized (though it is not guaranteed to be canonicalized or exist on the filesystem).\n\nIMPORTANT: When deserializing an `AbsolutePathBuf`, a base path must be set using [AbsolutePathBufGuard::new]. If no base path is set, the deserialization will fail unless the path being deserialized is already absolute.",
@@ -8507,6 +8605,16 @@
}
]
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"tools": {
"anyOf": [
{
@@ -10757,6 +10865,16 @@
}
]
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"web_search": {
"anyOf": [
{
@@ -11908,6 +12026,12 @@
"title": "ServerRequestResolvedNotification",
"type": "object"
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"SessionSource": {
"oneOf": [
{
@@ -12723,6 +12847,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}
@@ -12761,6 +12902,16 @@
"sandbox": {
"$ref": "#/definitions/v2/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/v2/Thread"
}
@@ -13706,6 +13857,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}
@@ -13744,6 +13912,16 @@
"sandbox": {
"$ref": "#/definitions/v2/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/v2/Thread"
}
@@ -13922,6 +14100,23 @@
"string",
"null"
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
}
},
"title": "ThreadStartParams",
@@ -13955,6 +14150,16 @@
"sandbox": {
"$ref": "#/definitions/v2/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/v2/Thread"
}
@@ -14522,6 +14727,24 @@
],
"description": "Override the sandbox policy for this turn and subsequent turns."
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/v2/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
],
"description": "Override the service tier for this turn and subsequent turns."
},
"summary": {
"anyOf": [
{


@@ -323,6 +323,16 @@
}
]
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"tools": {
"anyOf": [
{
@@ -608,6 +618,16 @@
}
]
},
"service_tier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"web_search": {
"anyOf": [
{
@@ -685,6 +705,12 @@
},
"type": "object"
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"ToolsV2": {
"properties": {
"view_image": {


@@ -50,6 +50,12 @@
"danger-full-access"
],
"type": "string"
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
}
},
"description": "There are two ways to fork a thread: 1. By thread_id: load the thread from disk by thread_id and fork it into a new thread. 2. By path: load the thread from disk by path and fork it into a new thread.\n\nIf using path, the thread_id param will be ignored.\n\nPrefer using thread_id whenever possible.",
@@ -112,6 +118,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}


@@ -738,6 +738,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"SessionSource": {
"oneOf": [
{
@@ -1906,6 +1912,16 @@
"sandbox": {
"$ref": "#/definitions/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/Thread"
}


@@ -738,6 +738,12 @@
],
"type": "string"
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"WebSearchAction": {
"oneOf": [
{
@@ -910,6 +916,23 @@
}
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
},
"threadId": {
"type": "string"
}


@@ -738,6 +738,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"SessionSource": {
"oneOf": [
{
@@ -1906,6 +1912,16 @@
"sandbox": {
"$ref": "#/definitions/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/Thread"
}


@@ -75,6 +75,12 @@
"danger-full-access"
],
"type": "string"
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
}
},
"properties": {
@@ -156,6 +162,23 @@
"string",
"null"
]
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
]
}
},
"title": "ThreadStartParams",


@@ -738,6 +738,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"SessionSource": {
"oneOf": [
{
@@ -1906,6 +1912,16 @@
"sandbox": {
"$ref": "#/definitions/SandboxPolicy"
},
"serviceTier": {
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
"thread": {
"$ref": "#/definitions/Thread"
}


@@ -299,6 +299,12 @@
}
]
},
"ServiceTier": {
"enum": [
"fast"
],
"type": "string"
},
"Settings": {
"description": "Settings for a collaboration mode.",
"properties": {
@@ -539,6 +545,24 @@
],
"description": "Override the sandbox policy for this turn and subsequent turns."
},
"serviceTier": {
"anyOf": [
{
"anyOf": [
{
"$ref": "#/definitions/ServiceTier"
},
{
"type": "null"
}
]
},
{
"type": "null"
}
],
"description": "Override the service tier for this turn and subsequent turns."
},
"summary": {
"anyOf": [
{


@@ -2,6 +2,7 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { RealtimeAudioFrame } from "./RealtimeAudioFrame";
import type { RealtimeHandoffRequested } from "./RealtimeHandoffRequested";
import type { JsonValue } from "./serde_json/JsonValue";
export type RealtimeEvent = { "SessionCreated": { session_id: string, } } | { "SessionUpdated": { backend_prompt: string | null, } } | { "AudioOut": RealtimeAudioFrame } | { "ConversationItemAdded": JsonValue } | { "Error": string };
export type RealtimeEvent = { "SessionUpdated": { session_id: string, instructions: string | null, } } | { "AudioOut": RealtimeAudioFrame } | { "ConversationItemAdded": JsonValue } | { "ConversationItemDone": { item_id: string, } } | { "HandoffRequested": RealtimeHandoffRequested } | { "Error": string };


@@ -0,0 +1,5 @@
// GENERATED CODE! DO NOT MODIFY BY HAND!
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
export type RealtimeHandoffMessage = { role: string, text: string, };


@@ -0,0 +1,6 @@
// GENERATED CODE! DO NOT MODIFY BY HAND!
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { RealtimeHandoffMessage } from "./RealtimeHandoffMessage";
export type RealtimeHandoffRequested = { handoff_id: string, item_id: string, input_transcript: string, messages: Array<RealtimeHandoffMessage>, };


@@ -6,10 +6,11 @@ import type { InputItem } from "./InputItem";
import type { ReasoningEffort } from "./ReasoningEffort";
import type { ReasoningSummary } from "./ReasoningSummary";
import type { SandboxPolicy } from "./SandboxPolicy";
import type { ServiceTier } from "./ServiceTier";
import type { ThreadId } from "./ThreadId";
import type { JsonValue } from "./serde_json/JsonValue";
export type SendUserTurnParams = { conversationId: ThreadId, items: Array<InputItem>, cwd: string, approvalPolicy: AskForApproval, sandboxPolicy: SandboxPolicy, model: string, effort: ReasoningEffort | null, summary: ReasoningSummary,
export type SendUserTurnParams = { conversationId: ThreadId, items: Array<InputItem>, cwd: string, approvalPolicy: AskForApproval, sandboxPolicy: SandboxPolicy, model: string, serviceTier?: ServiceTier | null | null, effort: ReasoningEffort | null, summary: ReasoningSummary,
/**
* Optional JSON Schema used to constrain the final assistant message for this turn.
*/


@@ -0,0 +1,5 @@
// GENERATED CODE! DO NOT MODIFY BY HAND!
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
export type ServiceTier = "fast";


@@ -5,6 +5,7 @@ import type { AskForApproval } from "./AskForApproval";
import type { EventMsg } from "./EventMsg";
import type { ReasoningEffort } from "./ReasoningEffort";
import type { SandboxPolicy } from "./SandboxPolicy";
import type { ServiceTier } from "./ServiceTier";
import type { SessionNetworkProxyRuntime } from "./SessionNetworkProxyRuntime";
import type { ThreadId } from "./ThreadId";
@@ -16,7 +17,7 @@ thread_name?: string,
/**
* Tell the client what model is being queried.
*/
model: string, model_provider_id: string,
model: string, model_provider_id: string, service_tier: ServiceTier | null,
/**
* When to escalate for approval for execution
*/


@@ -3,6 +3,7 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { EventMsg } from "./EventMsg";
import type { ReasoningEffort } from "./ReasoningEffort";
import type { ServiceTier } from "./ServiceTier";
import type { ThreadId } from "./ThreadId";
export type SessionConfiguredNotification = { sessionId: ThreadId, model: string, reasoningEffort: ReasoningEffort | null, historyLogId: bigint, historyEntryCount: number, initialMessages: Array<EventMsg> | null, rolloutPath: string, };
export type SessionConfiguredNotification = { sessionId: ThreadId, model: string, serviceTier: ServiceTier | null, reasoningEffort: ReasoningEffort | null, historyLogId: bigint, historyEntryCount: number, initialMessages: Array<EventMsg> | null, rolloutPath: string, };


@@ -161,6 +161,8 @@ export type { RealtimeConversationClosedEvent } from "./RealtimeConversationClos
export type { RealtimeConversationRealtimeEvent } from "./RealtimeConversationRealtimeEvent";
export type { RealtimeConversationStartedEvent } from "./RealtimeConversationStartedEvent";
export type { RealtimeEvent } from "./RealtimeEvent";
export type { RealtimeHandoffMessage } from "./RealtimeHandoffMessage";
export type { RealtimeHandoffRequested } from "./RealtimeHandoffRequested";
export type { ReasoningContentDeltaEvent } from "./ReasoningContentDeltaEvent";
export type { ReasoningEffort } from "./ReasoningEffort";
export type { ReasoningItem } from "./ReasoningItem";
@@ -198,6 +200,7 @@ export type { SendUserTurnParams } from "./SendUserTurnParams";
export type { SendUserTurnResponse } from "./SendUserTurnResponse";
export type { ServerNotification } from "./ServerNotification";
export type { ServerRequest } from "./ServerRequest";
export type { ServiceTier } from "./ServiceTier";
export type { SessionConfiguredEvent } from "./SessionConfiguredEvent";
export type { SessionConfiguredNotification } from "./SessionConfiguredNotification";
export type { SessionNetworkProxyRuntime } from "./SessionNetworkProxyRuntime";


@@ -4,6 +4,7 @@
import type { ForcedLoginMethod } from "../ForcedLoginMethod";
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ReasoningSummary } from "../ReasoningSummary";
import type { ServiceTier } from "../ServiceTier";
import type { Verbosity } from "../Verbosity";
import type { WebSearchMode } from "../WebSearchMode";
import type { JsonValue } from "../serde_json/JsonValue";
@@ -14,4 +15,4 @@ import type { SandboxMode } from "./SandboxMode";
import type { SandboxWorkspaceWrite } from "./SandboxWorkspaceWrite";
import type { ToolsV2 } from "./ToolsV2";
export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array<JsonValue> | { [key in string]?: JsonValue } | null });
export type Config = {model: string | null, review_model: string | null, model_context_window: bigint | null, model_auto_compact_token_limit: bigint | null, model_provider: string | null, approval_policy: AskForApproval | null, sandbox_mode: SandboxMode | null, sandbox_workspace_write: SandboxWorkspaceWrite | null, forced_chatgpt_workspace_id: string | null, forced_login_method: ForcedLoginMethod | null, web_search: WebSearchMode | null, tools: ToolsV2 | null, profile: string | null, profiles: { [key in string]?: ProfileV2 }, instructions: string | null, developer_instructions: string | null, compact_prompt: string | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, service_tier: ServiceTier | null, analytics: AnalyticsConfig | null} & ({ [key in string]?: number | string | boolean | Array<JsonValue> | { [key in string]?: JsonValue } | null });


@@ -3,9 +3,10 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ReasoningSummary } from "../ReasoningSummary";
import type { ServiceTier } from "../ServiceTier";
import type { Verbosity } from "../Verbosity";
import type { WebSearchMode } from "../WebSearchMode";
import type { JsonValue } from "../serde_json/JsonValue";
import type { AskForApproval } from "./AskForApproval";
export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array<JsonValue> | { [key in string]?: JsonValue } | null });
export type ProfileV2 = { model: string | null, model_provider: string | null, approval_policy: AskForApproval | null, service_tier: ServiceTier | null, model_reasoning_effort: ReasoningEffort | null, model_reasoning_summary: ReasoningSummary | null, model_verbosity: Verbosity | null, web_search: WebSearchMode | null, chatgpt_base_url: string | null, } & ({ [key in string]?: number | string | boolean | Array<JsonValue> | { [key in string]?: JsonValue } | null });


@@ -1,6 +1,7 @@
// GENERATED CODE! DO NOT MODIFY BY HAND!
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ServiceTier } from "../ServiceTier";
import type { JsonValue } from "../serde_json/JsonValue";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxMode } from "./SandboxMode";
@@ -21,7 +22,7 @@ export type ThreadForkParams = {threadId: string, /**
path?: string | null, /**
* Configuration overrides for the forked thread, if any.
*/
model?: string | null, modelProvider?: string | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, baseInstructions?: string | null, developerInstructions?: string | null, /**
model?: string | null, modelProvider?: string | null, serviceTier?: ServiceTier | null | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, baseInstructions?: string | null, developerInstructions?: string | null, /**
* If true, persist additional rollout EventMsg variants required to
* reconstruct a richer thread history on subsequent resume/fork/read.
*/


@@ -2,8 +2,9 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ServiceTier } from "../ServiceTier";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxPolicy } from "./SandboxPolicy";
import type { Thread } from "./Thread";
export type ThreadForkResponse = { thread: Thread, model: string, modelProvider: string, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };
export type ThreadForkResponse = { thread: Thread, model: string, modelProvider: string, serviceTier: ServiceTier | null, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };


@@ -3,6 +3,7 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { Personality } from "../Personality";
import type { ResponseItem } from "../ResponseItem";
import type { ServiceTier } from "../ServiceTier";
import type { JsonValue } from "../serde_json/JsonValue";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxMode } from "./SandboxMode";
@@ -30,7 +31,7 @@ history?: Array<ResponseItem> | null, /**
path?: string | null, /**
* Configuration overrides for the resumed thread, if any.
*/
model?: string | null, modelProvider?: string | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, baseInstructions?: string | null, developerInstructions?: string | null, personality?: Personality | null, /**
model?: string | null, modelProvider?: string | null, serviceTier?: ServiceTier | null | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, baseInstructions?: string | null, developerInstructions?: string | null, personality?: Personality | null, /**
* If true, persist additional rollout EventMsg variants required to
* reconstruct a richer thread history on subsequent resume/fork/read.
*/


@@ -2,8 +2,9 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ServiceTier } from "../ServiceTier";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxPolicy } from "./SandboxPolicy";
import type { Thread } from "./Thread";
export type ThreadResumeResponse = { thread: Thread, model: string, modelProvider: string, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };
export type ThreadResumeResponse = { thread: Thread, model: string, modelProvider: string, serviceTier: ServiceTier | null, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };


@@ -2,11 +2,12 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { Personality } from "../Personality";
import type { ServiceTier } from "../ServiceTier";
import type { JsonValue } from "../serde_json/JsonValue";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxMode } from "./SandboxMode";
export type ThreadStartParams = {model?: string | null, modelProvider?: string | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, serviceName?: string | null, baseInstructions?: string | null, developerInstructions?: string | null, personality?: Personality | null, ephemeral?: boolean | null, /**
export type ThreadStartParams = {model?: string | null, modelProvider?: string | null, serviceTier?: ServiceTier | null | null, cwd?: string | null, approvalPolicy?: AskForApproval | null, sandbox?: SandboxMode | null, config?: { [key in string]?: JsonValue } | null, serviceName?: string | null, baseInstructions?: string | null, developerInstructions?: string | null, personality?: Personality | null, ephemeral?: boolean | null, /**
* If true, opt into emitting raw Responses API items on the event stream.
* This is for internal use only (e.g. Codex Cloud).
*/


@@ -2,8 +2,9 @@
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ServiceTier } from "../ServiceTier";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxPolicy } from "./SandboxPolicy";
import type { Thread } from "./Thread";
export type ThreadStartResponse = { thread: Thread, model: string, modelProvider: string, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };
export type ThreadStartResponse = { thread: Thread, model: string, modelProvider: string, serviceTier: ServiceTier | null, cwd: string, approvalPolicy: AskForApproval, sandbox: SandboxPolicy, reasoningEffort: ReasoningEffort | null, };


@@ -5,6 +5,7 @@ import type { CollaborationMode } from "../CollaborationMode";
import type { Personality } from "../Personality";
import type { ReasoningEffort } from "../ReasoningEffort";
import type { ReasoningSummary } from "../ReasoningSummary";
import type { ServiceTier } from "../ServiceTier";
import type { JsonValue } from "../serde_json/JsonValue";
import type { AskForApproval } from "./AskForApproval";
import type { SandboxPolicy } from "./SandboxPolicy";
@@ -23,6 +24,9 @@ sandboxPolicy?: SandboxPolicy | null, /**
* Override the model for this turn and subsequent turns.
*/
model?: string | null, /**
* Override the service tier for this turn and subsequent turns.
*/
serviceTier?: ServiceTier | null | null, /**
* Override the reasoning effort for this turn and subsequent turns.
*/
effort?: ReasoningEffort | null, /**


@@ -1,6 +1,7 @@
//! We do not do true JSON-RPC 2.0, as we neither send nor expect the
//! "jsonrpc": "2.0" field.
use codex_protocol::protocol::W3cTraceContext;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
@@ -38,6 +39,10 @@ pub struct JSONRPCRequest {
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub params: Option<serde_json::Value>,
/// Optional W3C Trace Context for distributed tracing.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[ts(optional)]
pub trace: Option<W3cTraceContext>,
}
/// A notification which does not expect a response.


@@ -3,6 +3,7 @@
pub mod common;
mod mappers;
mod serde_helpers;
pub mod thread_history;
pub mod v1;
pub mod v2;


@@ -0,0 +1,23 @@
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
pub fn deserialize_double_option<'de, T, D>(deserializer: D) -> Result<Option<Option<T>>, D::Error>
where
T: Deserialize<'de>,
D: Deserializer<'de>,
{
serde_with::rust::double_option::deserialize(deserializer)
}
pub fn serialize_double_option<T, S>(
value: &Option<Option<T>>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
T: Serialize,
S: Serializer,
{
serde_with::rust::double_option::serialize(value, serializer)
}
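The helpers above re-export serde_with's `double_option`, which is what lets a field typed `Option<Option<ServiceTier>>` distinguish three wire states: key absent (keep the current tier), key present with `null` (explicitly clear the override), and key present with a value. A minimal stdlib-only sketch of that three-way branch — the `describe` function and its messages are illustrative, not part of the codebase:

```rust
// Hypothetical illustration of the three states an
// Option<Option<ServiceTier>> request field can encode.
#[derive(Debug, PartialEq)]
enum ServiceTier {
    Fast,
}

fn describe(tier: &Option<Option<ServiceTier>>) -> &'static str {
    match tier {
        // Key was omitted from the JSON payload entirely.
        None => "omitted: keep the existing tier",
        // Key was present with an explicit null.
        Some(None) => "explicit null: clear the tier override",
        // Key was present with a concrete value.
        Some(Some(ServiceTier::Fast)) => "explicit value: use the fast tier",
    }
}

fn main() {
    assert_eq!(describe(&None), "omitted: keep the existing tier");
    assert_eq!(describe(&Some(None)), "explicit null: clear the tier override");
    assert_eq!(
        describe(&Some(Some(ServiceTier::Fast))),
        "explicit value: use the fast tier"
    );
    println!("ok");
}
```

The param tests in this diff rely on exactly this distinction: `Some(None)` serializes to an explicit `"serviceTier": null`, while `None` omits the key via `skip_serializing_if = "Option::is_none"`.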


@@ -5,6 +5,7 @@ use codex_protocol::ThreadId;
use codex_protocol::config_types::ForcedLoginMethod;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::config_types::Verbosity;
use codex_protocol::models::ResponseItem;
use codex_protocol::openai_models::ReasoningEffort;
@@ -419,6 +420,13 @@ pub struct SendUserTurnParams {
pub approval_policy: AskForApproval,
pub sandbox_policy: SandboxPolicy,
pub model: String,
#[serde(
default,
deserialize_with = "super::serde_helpers::deserialize_double_option",
serialize_with = "super::serde_helpers::serialize_double_option",
skip_serializing_if = "Option::is_none"
)]
pub service_tier: Option<Option<ServiceTier>>,
pub effort: Option<ReasoningEffort>,
pub summary: ReasoningSummary,
/// Optional JSON Schema used to constrain the final assistant message for this turn.
@@ -429,6 +437,55 @@ pub struct SendUserTurnParams {
#[serde(rename_all = "camelCase")]
pub struct SendUserTurnResponse {}
#[cfg(test)]
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::path::PathBuf;
#[test]
fn send_user_turn_params_preserve_explicit_null_service_tier() {
let params = SendUserTurnParams {
conversation_id: ThreadId::new(),
items: vec![],
cwd: PathBuf::from("/tmp"),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: "gpt-4.1".to_string(),
service_tier: Some(None),
effort: None,
summary: ReasoningSummary::Auto,
output_schema: None,
};
let serialized = serde_json::to_value(&params).expect("params should serialize");
assert_eq!(
serialized.get("serviceTier"),
Some(&serde_json::Value::Null)
);
let roundtrip: SendUserTurnParams =
serde_json::from_value(serialized).expect("params should deserialize");
assert_eq!(roundtrip.service_tier, Some(None));
let without_override = SendUserTurnParams {
conversation_id: ThreadId::new(),
items: vec![],
cwd: PathBuf::from("/tmp"),
approval_policy: AskForApproval::Never,
sandbox_policy: SandboxPolicy::DangerFullAccess,
model: "gpt-4.1".to_string(),
service_tier: None,
effort: None,
summary: ReasoningSummary::Auto,
output_schema: None,
};
let serialized_without_override =
serde_json::to_value(&without_override).expect("params should serialize");
assert_eq!(serialized_without_override.get("serviceTier"), None);
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, JsonSchema, TS)]
#[serde(rename_all = "camelCase")]
pub struct InterruptConversationParams {
@@ -555,6 +612,7 @@ pub struct LoginChatGptCompleteNotification {
pub struct SessionConfiguredNotification {
pub session_id: ThreadId,
pub model: String,
pub service_tier: Option<ServiceTier>,
pub reasoning_effort: Option<ReasoningEffort>,
pub history_log_id: u64,
#[ts(type = "number")]


@@ -17,6 +17,7 @@ use codex_protocol::config_types::ModeKind;
use codex_protocol::config_types::Personality;
use codex_protocol::config_types::ReasoningSummary;
use codex_protocol::config_types::SandboxMode as CoreSandboxMode;
use codex_protocol::config_types::ServiceTier;
use codex_protocol::config_types::Verbosity;
use codex_protocol::config_types::WebSearchMode;
use codex_protocol::items::AgentMessageContent as CoreAgentMessageContent;
@@ -392,6 +393,7 @@ pub struct ProfileV2 {
pub model: Option<String>,
pub model_provider: Option<String>,
pub approval_policy: Option<AskForApproval>,
pub service_tier: Option<ServiceTier>,
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: Option<ReasoningSummary>,
pub model_verbosity: Option<Verbosity>,
@@ -503,6 +505,7 @@ pub struct Config {
pub model_reasoning_effort: Option<ReasoningEffort>,
pub model_reasoning_summary: Option<ReasoningSummary>,
pub model_verbosity: Option<Verbosity>,
pub service_tier: Option<ServiceTier>,
pub analytics: Option<AnalyticsConfig>,
#[experimental("config/read.apps")]
#[serde(default)]
@@ -1788,6 +1791,14 @@ pub struct ThreadStartParams {
pub model: Option<String>,
#[ts(optional = nullable)]
pub model_provider: Option<String>,
#[serde(
default,
deserialize_with = "super::serde_helpers::deserialize_double_option",
serialize_with = "super::serde_helpers::serialize_double_option",
skip_serializing_if = "Option::is_none"
)]
#[ts(optional = nullable)]
pub service_tier: Option<Option<ServiceTier>>,
#[ts(optional = nullable)]
pub cwd: Option<String>,
#[ts(optional = nullable)]
@@ -1850,6 +1861,7 @@ pub struct ThreadStartResponse {
pub thread: Thread,
pub model: String,
pub model_provider: String,
pub service_tier: Option<ServiceTier>,
pub cwd: PathBuf,
pub approval_policy: AskForApproval,
pub sandbox: SandboxPolicy,
@@ -1891,6 +1903,14 @@ pub struct ThreadResumeParams {
pub model: Option<String>,
#[ts(optional = nullable)]
pub model_provider: Option<String>,
#[serde(
default,
deserialize_with = "super::serde_helpers::deserialize_double_option",
serialize_with = "super::serde_helpers::serialize_double_option",
skip_serializing_if = "Option::is_none"
)]
#[ts(optional = nullable)]
pub service_tier: Option<Option<ServiceTier>>,
#[ts(optional = nullable)]
pub cwd: Option<String>,
#[ts(optional = nullable)]
@@ -1919,6 +1939,7 @@ pub struct ThreadResumeResponse {
pub thread: Thread,
pub model: String,
pub model_provider: String,
pub service_tier: Option<ServiceTier>,
pub cwd: PathBuf,
pub approval_policy: AskForApproval,
pub sandbox: SandboxPolicy,
@@ -1951,6 +1972,14 @@ pub struct ThreadForkParams {
pub model: Option<String>,
#[ts(optional = nullable)]
pub model_provider: Option<String>,
#[serde(
default,
deserialize_with = "super::serde_helpers::deserialize_double_option",
serialize_with = "super::serde_helpers::serialize_double_option",
skip_serializing_if = "Option::is_none"
)]
#[ts(optional = nullable)]
pub service_tier: Option<Option<ServiceTier>>,
#[ts(optional = nullable)]
pub cwd: Option<String>,
#[ts(optional = nullable)]
@@ -1977,6 +2006,7 @@ pub struct ThreadForkResponse {
pub thread: Thread,
pub model: String,
pub model_provider: String,
pub service_tier: Option<ServiceTier>,
pub cwd: PathBuf,
pub approval_policy: AskForApproval,
pub sandbox: SandboxPolicy,
@@ -2837,6 +2867,15 @@ pub struct TurnStartParams {
/// Override the model for this turn and subsequent turns.
#[ts(optional = nullable)]
pub model: Option<String>,
/// Override the service tier for this turn and subsequent turns.
#[serde(
default,
deserialize_with = "super::serde_helpers::deserialize_double_option",
serialize_with = "super::serde_helpers::serialize_double_option",
skip_serializing_if = "Option::is_none"
)]
#[ts(optional = nullable)]
pub service_tier: Option<Option<ServiceTier>>,
/// Override the reasoning effort for this turn and subsequent turns.
#[ts(optional = nullable)]
pub effort: Option<ReasoningEffort>,
@@ -4566,4 +4605,56 @@ mod tests {
})
);
}
#[test]
fn thread_start_params_preserve_explicit_null_service_tier() {
let params: ThreadStartParams = serde_json::from_value(json!({ "serviceTier": null }))
.expect("params should deserialize");
assert_eq!(params.service_tier, Some(None));
let serialized = serde_json::to_value(&params).expect("params should serialize");
assert_eq!(
serialized.get("serviceTier"),
Some(&serde_json::Value::Null)
);
let serialized_without_override =
serde_json::to_value(ThreadStartParams::default()).expect("params should serialize");
assert_eq!(serialized_without_override.get("serviceTier"), None);
}
#[test]
fn turn_start_params_preserve_explicit_null_service_tier() {
let params: TurnStartParams = serde_json::from_value(json!({
"threadId": "thread_123",
"input": [],
"serviceTier": null
}))
.expect("params should deserialize");
assert_eq!(params.service_tier, Some(None));
let serialized = serde_json::to_value(&params).expect("params should serialize");
assert_eq!(
serialized.get("serviceTier"),
Some(&serde_json::Value::Null)
);
let without_override = TurnStartParams {
thread_id: "thread_123".to_string(),
input: vec![],
cwd: None,
approval_policy: None,
sandbox_policy: None,
model: None,
service_tier: None,
effort: None,
summary: None,
output_schema: None,
collaboration_mode: None,
personality: None,
};
let serialized_without_override =
serde_json::to_value(&without_override).expect("params should serialize");
assert_eq!(serialized_without_override.get("serviceTier"), None);
}
}


@@ -15,6 +15,7 @@ use std::process::Command;
use std::process::Stdio;
use std::thread;
use std::time::Duration;
use std::time::SystemTime;
use anyhow::Context;
use anyhow::Result;
@@ -71,6 +72,7 @@ use codex_app_server_protocol::UserInput as V2UserInput;
use codex_protocol::ThreadId;
use codex_protocol::protocol::Event;
use codex_protocol::protocol::EventMsg;
use codex_protocol::protocol::W3cTraceContext;
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde_json::Value;
@@ -104,6 +106,8 @@ const NOTIFICATIONS_TO_OPT_OUT: &[&str] = &[
"item/reasoning/summaryTextDelta",
"item/reasoning/textDelta",
];
const APP_SERVER_GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
const APP_SERVER_GRACEFUL_SHUTDOWN_POLL_INTERVAL: Duration = Duration::from_millis(100);
/// Minimal launcher that initializes the Codex app-server and logs the handshake.
#[derive(Parser)]
@@ -498,25 +502,26 @@ fn send_message(
config_overrides: &[String],
user_message: String,
) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let conversation = client.start_thread()?;
println!("< newConversation response: {conversation:?}");
let subscription = client.add_conversation_listener(&conversation.conversation_id)?;
println!("< addConversationListener response: {subscription:?}");
let send_response = client.send_user_message(&conversation.conversation_id, &user_message)?;
println!("< sendUserMessage response: {send_response:?}");
client.stream_conversation(&conversation.conversation_id)?;
client.remove_thread_listener(subscription.subscription_id)?;
Ok(())
})
}
pub fn send_message_v2(
@@ -574,82 +579,85 @@ fn trigger_zsh_fork_multi_cmd_approval(
let default_prompt = "Run this exact command using shell command execution without rewriting or splitting it: /usr/bin/true && /usr/bin/true";
let message = user_message.unwrap_or_else(|| default_prompt.to_string());
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let thread_response = client.thread_start(ThreadStartParams {
dynamic_tools: dynamic_tools.clone(),
..Default::default()
})?;
println!("< thread/start response: {thread_response:?}");
client.command_approval_behavior = match abort_on {
Some(index) => CommandApprovalBehavior::AbortOn(index),
None => CommandApprovalBehavior::AlwaysAccept,
};
client.command_approval_count = 0;
client.command_approval_item_ids.clear();
client.command_execution_statuses.clear();
client.last_turn_status = None;
let mut turn_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: message,
text_elements: Vec::new(),
}],
..Default::default()
};
turn_params.approval_policy = Some(AskForApproval::OnRequest);
turn_params.sandbox_policy = Some(SandboxPolicy::ReadOnly {
access: ReadOnlyAccess::FullAccess,
});
let turn_response = client.turn_start(turn_params)?;
println!("< turn/start response: {turn_response:?}");
client.stream_turn(&thread_response.thread.id, &turn_response.turn.id)?;
if client.command_approval_count < min_approvals {
bail!(
"expected at least {min_approvals} command approvals, got {}",
client.command_approval_count
);
}
let mut approvals_per_item = std::collections::BTreeMap::new();
for item_id in &client.command_approval_item_ids {
*approvals_per_item.entry(item_id.clone()).or_insert(0usize) += 1;
}
let max_approvals_for_one_item = approvals_per_item.values().copied().max().unwrap_or(0);
if max_approvals_for_one_item < min_approvals {
bail!(
"expected at least {min_approvals} approvals for one command item, got max {max_approvals_for_one_item} with map {approvals_per_item:?}"
);
}
let last_command_status = client.command_execution_statuses.last();
if abort_on.is_none() {
if last_command_status != Some(&CommandExecutionStatus::Completed) {
bail!("expected completed command execution, got {last_command_status:?}");
}
if client.last_turn_status != Some(TurnStatus::Completed) {
bail!(
"expected completed turn in all-accept flow, got {:?}",
client.last_turn_status
);
}
} else if last_command_status == Some(&CommandExecutionStatus::Completed) {
bail!(
"expected non-completed command execution in mixed approval/decline flow, got {last_command_status:?}"
);
}
println!(
"[zsh-fork multi-approval summary] approvals={}, approvals_per_item={approvals_per_item:?}, command_statuses={:?}, turn_status={:?}",
client.command_approval_count, client.command_execution_statuses, client.last_turn_status
);
Ok(())
})
}
fn resume_message_v2(
@@ -661,30 +669,30 @@ fn resume_message_v2(
) -> Result<()> {
ensure_dynamic_tools_unused(dynamic_tools, "resume-message-v2")?;
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let resume_response = client.thread_resume(ThreadResumeParams {
thread_id,
..Default::default()
})?;
println!("< thread/resume response: {resume_response:?}");
let turn_response = client.turn_start(TurnStartParams {
thread_id: resume_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: user_message,
text_elements: Vec::new(),
}],
..Default::default()
})?;
println!("< turn/start response: {turn_response:?}");
client.stream_turn(&resume_response.thread.id, &turn_response.turn.id)?;
Ok(())
})
}
fn thread_resume_follow(
@@ -787,34 +795,34 @@ fn send_message_v2_with_policies(
sandbox_policy: Option<SandboxPolicy>,
dynamic_tools: &Option<Vec<DynamicToolSpec>>,
) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize_with_experimental_api(experimental_api)?;
println!("< initialize response: {initialize:?}");
let thread_response = client.thread_start(ThreadStartParams {
dynamic_tools: dynamic_tools.clone(),
..Default::default()
})?;
println!("< thread/start response: {thread_response:?}");
let mut turn_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: user_message,
// Test client sends plain text without UI element ranges.
text_elements: Vec::new(),
}],
..Default::default()
};
turn_params.approval_policy = approval_policy;
turn_params.sandbox_policy = sandbox_policy;
let turn_response = client.turn_start(turn_params)?;
println!("< turn/start response: {turn_response:?}");
client.stream_turn(&thread_response.thread.id, &turn_response.turn.id)?;
Ok(())
})
}
fn send_follow_up_v2(
@@ -824,119 +832,130 @@ fn send_follow_up_v2(
follow_up_message: String,
dynamic_tools: &Option<Vec<DynamicToolSpec>>,
) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let thread_response = client.thread_start(ThreadStartParams {
dynamic_tools: dynamic_tools.clone(),
..Default::default()
})?;
println!("< thread/start response: {thread_response:?}");
let first_turn_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: first_message,
// Test client sends plain text without UI element ranges.
text_elements: Vec::new(),
}],
..Default::default()
};
let first_turn_response = client.turn_start(first_turn_params)?;
println!("< turn/start response (initial): {first_turn_response:?}");
client.stream_turn(&thread_response.thread.id, &first_turn_response.turn.id)?;
let follow_up_params = TurnStartParams {
thread_id: thread_response.thread.id.clone(),
input: vec![V2UserInput::Text {
text: follow_up_message,
// Test client sends plain text without UI element ranges.
text_elements: Vec::new(),
}],
..Default::default()
};
let follow_up_response = client.turn_start(follow_up_params)?;
println!("< turn/start response (follow-up): {follow_up_response:?}");
client.stream_turn(&thread_response.thread.id, &follow_up_response.turn.id)?;
Ok(())
})
}
fn test_login(endpoint: &Endpoint, config_overrides: &[String]) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let login_response = client.login_chat_gpt()?;
println!("< loginChatGpt response: {login_response:?}");
println!(
"Open the following URL in your browser to continue:\n{}",
login_response.auth_url
);
let completion = client.wait_for_login_completion(&login_response.login_id)?;
println!("< loginChatGptComplete notification: {completion:?}");
if completion.success {
println!("Login succeeded.");
Ok(())
} else {
bail!(
"login failed: {}",
completion
.error
.as_deref()
.unwrap_or("unknown error from loginChatGptComplete")
);
}
})
}
fn get_account_rate_limits(endpoint: &Endpoint, config_overrides: &[String]) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let response = client.get_account_rate_limits()?;
println!("< account/rateLimits/read response: {response:?}");
Ok(())
})
}
fn model_list(endpoint: &Endpoint, config_overrides: &[String]) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let response = client.model_list(ModelListParams::default())?;
println!("< model/list response: {response:?}");
Ok(())
})
}
fn thread_list(endpoint: &Endpoint, config_overrides: &[String], limit: u32) -> Result<()> {
with_client(endpoint, config_overrides, |client| {
let initialize = client.initialize()?;
println!("< initialize response: {initialize:?}");
let response = client.thread_list(ThreadListParams {
cursor: None,
limit: Some(limit),
sort_key: None,
model_providers: None,
source_kinds: None,
archived: None,
cwd: None,
search_term: None,
})?;
println!("< thread/list response: {response:?}");
Ok(())
})
}
fn with_client<T>(
endpoint: &Endpoint,
config_overrides: &[String],
f: impl FnOnce(&mut CodexClient) -> Result<T>,
) -> Result<T> {
let mut client = CodexClient::connect(endpoint, config_overrides)?;
let result = f(&mut client);
client.print_trace_summary();
result
}
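The `with_client` refactor above centralizes connect and trace-summary printing around a caller-supplied closure. A standalone sketch of that shape, stripped of the app-server specifics (the `Client` type here is hypothetical, not the real `CodexClient`):

```rust
// Generic sketch of the with_client pattern: acquire, run closure, always run teardown.
struct Client {
    trace_id: String,
}

impl Client {
    fn connect() -> Result<Client, String> {
        Ok(Client {
            trace_id: "abc123".to_string(),
        })
    }
    fn print_trace_summary(&self) {
        println!("[trace] {}", self.trace_id);
    }
}

fn with_client<T>(f: impl FnOnce(&mut Client) -> Result<T, String>) -> Result<T, String> {
    let mut client = Client::connect()?;
    let result = f(&mut client);
    // Teardown/reporting runs whether or not the closure succeeded.
    client.print_trace_summary();
    result
}

fn main() {
    let out = with_client(|client| Ok(client.trace_id.len()));
    assert_eq!(out, Ok(6));
    println!("{out:?}");
}
```

The point of the design is that every command path gets the trace summary for free, instead of each `fn` duplicating connect-and-report boilerplate.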
fn ensure_dynamic_tools_unused(
@@ -993,6 +1012,8 @@ struct CodexClient {
command_approval_item_ids: Vec<String>,
command_execution_statuses: Vec<CommandExecutionStatus>,
last_turn_status: Option<TurnStatus>,
trace_id: String,
trace_root_span_id: String,
}
#[derive(Debug, Clone, Copy)]
@@ -1052,6 +1073,8 @@ impl CodexClient {
command_approval_item_ids: Vec::new(),
command_execution_statuses: Vec::new(),
last_turn_status: None,
trace_id: generate_trace_id(),
trace_root_span_id: generate_parent_span_id(),
})
}
@@ -1073,6 +1096,8 @@ impl CodexClient {
command_approval_item_ids: Vec::new(),
command_execution_statuses: Vec::new(),
last_turn_status: None,
trace_id: generate_trace_id(),
trace_root_span_id: generate_parent_span_id(),
})
}
@@ -1438,12 +1463,32 @@ impl CodexClient {
}
fn write_request(&mut self, request: &ClientRequest) -> Result<()> {
let request = self.jsonrpc_request_with_trace(request)?;
let request_json = serde_json::to_string(&request)?;
let request_pretty = serde_json::to_string_pretty(&request)?;
print_multiline_with_prefix("> ", &request_pretty);
self.write_payload(&request_json)
}
fn jsonrpc_request_with_trace(&self, request: &ClientRequest) -> Result<JSONRPCRequest> {
let request_value = serde_json::to_value(request)?;
let mut request: JSONRPCRequest = serde_json::from_value(request_value)
.context("client request was not a valid JSON-RPC request")?;
request.trace = Some(W3cTraceContext {
traceparent: Some(format!(
"00-{}-{}-01",
self.trace_id, self.trace_root_span_id
)),
tracestate: None,
});
Ok(request)
}
fn print_trace_summary(&self) {
println!("\n[Datadog trace]");
println!("go/trace/{}\n", self.trace_id);
}
fn wait_for_response<T>(&mut self, request_id: RequestId, method: &str) -> Result<T>
where
T: DeserializeOwned,
@@ -1709,6 +1754,15 @@ impl CodexClient {
}
}
fn generate_trace_id() -> String {
Uuid::new_v4().simple().to_string()
}
fn generate_parent_span_id() -> String {
let uuid = Uuid::new_v4().simple().to_string();
uuid[..16].to_string()
}
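The two helpers above feed the W3C `traceparent` carrier that `jsonrpc_request_with_trace` attaches to every request (`00-{trace_id}-{span_id}-01`). A standalone sketch of that formatting, with a fixed hex string standing in for the UUID-derived ids:

```rust
// Sketch only: mirrors how the client assembles its traceparent header.
fn main() {
    // generate_trace_id() yields 32 lowercase hex chars; hard-coded here for illustration.
    let trace_id = "4bf92f3577b34da6a3ce929d0e0e4736";
    // generate_parent_span_id() keeps the first 16 hex chars of a fresh simple UUID.
    let parent_span_id = &trace_id[..16];
    // Version 00, sampled flag 01, per the W3C Trace Context header layout.
    let traceparent = format!("00-{trace_id}-{parent_span_id}-01");
    assert_eq!(traceparent.len(), 55);
    println!("{traceparent}");
}
```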
fn print_multiline_with_prefix(prefix: &str, payload: &str) {
for line in payload.lines() {
println!("{prefix}{line}");
@@ -1728,11 +1782,18 @@ impl Drop for CodexClient {
return;
}
thread::sleep(Duration::from_millis(100));
let deadline = SystemTime::now() + APP_SERVER_GRACEFUL_SHUTDOWN_TIMEOUT;
loop {
if let Ok(Some(status)) = child.try_wait() {
println!("[codex app-server exited: {status}]");
return;
}
if SystemTime::now() >= deadline {
break;
}
thread::sleep(APP_SERVER_GRACEFUL_SHUTDOWN_POLL_INTERVAL);
}
let _ = child.kill();

View File

@@ -21,6 +21,7 @@ async-trait = { workspace = true }
codex-arg0 = { workspace = true }
codex-cloud-requirements = { workspace = true }
codex-core = { workspace = true }
codex-otel = { workspace = true }
codex-shell-command = { workspace = true }
codex-utils-cli = { workspace = true }
codex-backend-client = { workspace = true }

View File

@@ -63,7 +63,7 @@ Use the thread APIs to create, list, or archive conversations. Drive a conversat
- Initialize once per connection: Immediately after opening a transport connection, send an `initialize` request with your client metadata, then emit an `initialized` notification. Any other request on that connection before this handshake gets rejected.
- Start (or resume) a thread: Call `thread/start` to open a fresh conversation. The response returns the thread object and you'll also get a `thread/started` notification. If you're continuing an existing conversation, call `thread/resume` with its ID instead. If you want to branch from an existing conversation, call `thread/fork` to create a new thread id with copied history.
The returned `thread.ephemeral` flag tells you whether the session is intentionally in-memory only; when it is `true`, `thread.path` is `null`.
- Begin a turn: To send user input, call `turn/start` with the target `threadId` and the user's input. Optional fields let you override model, cwd, sandbox policy, etc. This immediately returns the new turn object. The app-server emits `turn/started` when that turn actually begins running.
- Stream events: After `turn/start`, keep reading JSON-RPC notifications on stdout. You'll see `item/started`, `item/completed`, deltas like `item/agentMessage/delta`, tool progress, etc. These represent streaming model output plus any side effects (commands, tool calls, reasoning notes).
- Finish the turn: When the model is done (or the turn is interrupted via a `turn/interrupt` call), the server sends `turn/completed` with the final turn state and token usage.
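A minimal client exchange following the steps above looks roughly like this (ids, `clientInfo` fields, and the text-input shape are illustrative, not a verbatim transcript):

```json
{ "jsonrpc": "2.0", "id": 1, "method": "initialize", "params": { "clientInfo": { "name": "example-client", "version": "0.1.0" } } }
{ "jsonrpc": "2.0", "method": "initialized" }
{ "jsonrpc": "2.0", "id": 2, "method": "thread/start", "params": {} }
{ "jsonrpc": "2.0", "id": 3, "method": "turn/start", "params": { "threadId": "thread_123", "input": [ { "type": "text", "text": "hello" } ] } }
```

After the `turn/start` response, the client keeps reading notifications (`turn/started`, `item/*` events) until `turn/completed` arrives.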
@@ -120,16 +120,16 @@ Example with notification opt-out:
## API Overview
- `thread/start` — create a new thread; emits `thread/started` (including the current `thread.status`) and auto-subscribes you to turn/item events for that thread.
- `thread/resume` — reopen an existing thread by id so subsequent `turn/start` calls append to it.
- `thread/fork` — fork an existing thread into a new thread id by copying the stored history; emits `thread/started` (including the current `thread.status`) and auto-subscribes you to turn/item events for the new thread.
- `thread/list` — page through stored rollouts; supports cursor-based pagination and optional `modelProviders`, `sourceKinds`, `archived`, `cwd`, and `searchTerm` filters. Each returned `thread` includes `status` (`ThreadStatus`), defaulting to `notLoaded` when the thread is not currently loaded.
- `thread/loaded/list` — list the thread ids currently loaded in memory.
- `thread/read` — read a stored thread by id without resuming it; optionally include turns via `includeTurns`. The returned `thread` includes `status` (`ThreadStatus`), defaulting to `notLoaded` when the thread is not currently loaded.
- `thread/status/changed` — notification emitted when a loaded thread's status changes (`threadId` + new `status`).
- `thread/archive` — move a thread's rollout file into the archived directory; returns `{}` on success and emits `thread/archived`.
- `thread/unsubscribe` — unsubscribe this connection from thread turn/item events. If this was the last subscriber, the server shuts down and unloads the thread, then emits `thread/closed`.
- `thread/name/set` — set or update a thread's user-facing name for either a loaded thread or a persisted rollout; returns `{}` on success. Thread names are not required to be unique; name lookups resolve to the most recently updated thread.
- `thread/unarchive` — move an archived rollout file back into the sessions directory; returns the restored `thread` on success and emits `thread/unarchived`.
- `thread/compact/start` — trigger conversation history compaction for a thread; returns `{}` immediately while progress streams through standard turn/item notifications.
- `thread/backgroundTerminals/clean` — terminate all running background terminals for a thread (experimental; requires `capabilities.experimentalApi`); returns `{}` when the cleanup request is accepted.
@@ -273,10 +273,11 @@ When `nextCursor` is `null`, you've reached the final page.
### Example: Track thread status changes
`thread/status/changed` is emitted whenever a loaded thread's status changes after it has already been introduced to the client:
- Includes `threadId` and the new `status`.
- Status can be `notLoaded`, `idle`, `systemError`, or `active` (with `activeFlags`; `active` implies running).
- `thread/start`, `thread/fork`, and detached review threads do not emit a separate initial `thread/status/changed`; their `thread/started` notification already carries the current `thread.status`.
```json
{ "method": "thread/status/changed", "params": {
@@ -620,7 +621,7 @@ Because audio is intentionally separate from `ThreadItem`, clients can opt out o
### Turn events
The app-server streams JSON-RPC notifications while a turn is running. Each turn emits `turn/started` when it begins running and ends with `turn/completed` (final `turn` status). Token usage events stream separately via `thread/tokenUsage/updated`. Clients subscribe to the events they care about, rendering each item incrementally as updates arrive. The per-item lifecycle is always: `item/started` → zero or more item-specific deltas → `item/completed`.
- `turn/started``{ turn }` with the turn id, empty `items`, and `status: "inProgress"`.
- `turn/completed``{ turn }` where `turn.status` is `completed`, `interrupted`, or `failed`; failures carry `{ error: { message, codexErrorInfo?, additionalDetails? } }`.
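As a sketch of the terminal event described above (field values illustrative), a failed turn's `turn/completed` notification would carry:

```json
{ "method": "turn/completed", "params": {
    "threadId": "thread_123",
    "turn": {
      "id": "turn_456",
      "items": [],
      "status": "failed",
      "error": { "message": "model stream disconnected" }
    }
} }
```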

View File

@@ -0,0 +1,101 @@
use crate::message_processor::ConnectionSessionState;
use crate::outgoing_message::ConnectionId;
use crate::transport::AppServerTransport;
use codex_app_server_protocol::InitializeParams;
use codex_app_server_protocol::JSONRPCRequest;
use codex_otel::set_parent_from_context;
use codex_otel::set_parent_from_w3c_trace_context;
use codex_otel::traceparent_context_from_env;
use codex_protocol::protocol::W3cTraceContext;
use tracing::Span;
use tracing::field;
use tracing::info_span;
pub(crate) fn request_span(
request: &JSONRPCRequest,
transport: AppServerTransport,
connection_id: ConnectionId,
session: &ConnectionSessionState,
) -> Span {
let span = info_span!(
"app_server.request",
otel.kind = "server",
otel.name = request.method.as_str(),
rpc.system = "jsonrpc",
rpc.method = request.method.as_str(),
rpc.transport = transport_name(transport),
rpc.request_id = ?request.id,
app_server.connection_id = ?connection_id,
app_server.api_version = "v2",
app_server.client_name = field::Empty,
app_server.client_version = field::Empty,
);
let initialize_client_info = initialize_client_info(request);
if let Some(client_name) = client_name(initialize_client_info.as_ref(), session) {
span.record("app_server.client_name", client_name);
}
if let Some(client_version) = client_version(initialize_client_info.as_ref(), session) {
span.record("app_server.client_version", client_version);
}
if let Some(traceparent) = request
.trace
.as_ref()
.and_then(|trace| trace.traceparent.as_deref())
{
let trace = W3cTraceContext {
traceparent: Some(traceparent.to_string()),
tracestate: request
.trace
.as_ref()
.and_then(|value| value.tracestate.clone()),
};
if !set_parent_from_w3c_trace_context(&span, &trace) {
tracing::warn!(
rpc_method = request.method.as_str(),
rpc_request_id = ?request.id,
"ignoring invalid inbound request trace carrier"
);
}
} else if let Some(context) = traceparent_context_from_env() {
set_parent_from_context(&span, context);
}
span
}
fn transport_name(transport: AppServerTransport) -> &'static str {
match transport {
AppServerTransport::Stdio => "stdio",
AppServerTransport::WebSocket { .. } => "websocket",
}
}
fn client_name<'a>(
initialize_client_info: Option<&'a InitializeParams>,
session: &'a ConnectionSessionState,
) -> Option<&'a str> {
if let Some(params) = initialize_client_info {
return Some(params.client_info.name.as_str());
}
session.app_server_client_name.as_deref()
}
fn client_version<'a>(
initialize_client_info: Option<&'a InitializeParams>,
session: &'a ConnectionSessionState,
) -> Option<&'a str> {
if let Some(params) = initialize_client_info {
return Some(params.client_info.version.as_str());
}
session.client_version.as_deref()
}
fn initialize_client_info(request: &JSONRPCRequest) -> Option<InitializeParams> {
if request.method != "initialize" {
return None;
}
let params = request.params.clone()?;
serde_json::from_value(params).ok()
}

View File

@@ -83,6 +83,7 @@ use codex_app_server_protocol::TurnError;
use codex_app_server_protocol::TurnInterruptResponse;
use codex_app_server_protocol::TurnPlanStep;
use codex_app_server_protocol::TurnPlanUpdatedNotification;
use codex_app_server_protocol::TurnStartedNotification;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::build_turns_from_rollout_items;
use codex_app_server_protocol::convert_patch_changes;
@@ -185,12 +186,30 @@ pub(crate) async fn apply_bespoke_event_handling(
msg,
} = event;
match msg {
EventMsg::TurnStarted(_) => {
EventMsg::TurnStarted(payload) => {
// While not technically necessary as it was already done on TurnComplete, be extra cautious and abort any pending server requests.
outgoing.abort_pending_server_requests().await;
thread_watch_manager
.note_turn_started(&conversation_id.to_string())
.await;
if let ApiVersion::V2 = api_version {
let turn = {
let state = thread_state.lock().await;
state.active_turn_snapshot().unwrap_or_else(|| Turn {
id: payload.turn_id.clone(),
items: Vec::new(),
error: None,
status: TurnStatus::InProgress,
})
};
let notification = TurnStartedNotification {
thread_id: conversation_id.to_string(),
turn,
};
outgoing
.send_server_notification(ServerNotification::TurnStarted(notification))
.await;
}
}
EventMsg::TurnComplete(_ev) => {
// All per-thread requests are bound to a turn, so abort them.
@@ -232,7 +251,6 @@ pub(crate) async fn apply_bespoke_event_handling(
EventMsg::RealtimeConversationRealtime(event) => {
if let ApiVersion::V2 = api_version {
match event.payload {
RealtimeEvent::SessionCreated { .. } => {}
RealtimeEvent::SessionUpdated { .. } => {}
RealtimeEvent::AudioOut(audio) => {
let notification = ThreadRealtimeOutputAudioDeltaNotification {
@@ -256,6 +274,24 @@ pub(crate) async fn apply_bespoke_event_handling(
))
.await;
}
RealtimeEvent::ConversationItemDone { .. } => {}
RealtimeEvent::HandoffRequested(handoff) => {
let notification = ThreadRealtimeItemAddedNotification {
thread_id: conversation_id.to_string(),
item: serde_json::json!({
"type": "handoff_request",
"handoff_id": handoff.handoff_id,
"item_id": handoff.item_id,
"input_transcript": handoff.input_transcript,
"messages": handoff.messages,
}),
};
outgoing
.send_server_notification(ServerNotification::ThreadRealtimeItemAdded(
notification,
))
.await;
}
RealtimeEvent::Error(message) => {
let notification = ThreadRealtimeErrorNotification {
thread_id: conversation_id.to_string(),

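The `TurnStarted` handler above prefers a snapshot of the already-active turn and only falls back to a freshly-minted `InProgress` turn. A minimal std-only sketch of that pattern, with `Turn`, `TurnStatus`, and the state type as simplified stand-ins for the real app-server protocol types:

```rust
// Sketch of the snapshot-or-default pattern; types are stand-ins, not the
// real codex_app_server_protocol definitions.
use std::sync::Mutex;

#[derive(Clone, Debug, PartialEq)]
enum TurnStatus {
    InProgress,
}

#[derive(Clone, Debug, PartialEq)]
struct Turn {
    id: String,
    status: TurnStatus,
}

#[derive(Default)]
struct ThreadState {
    active_turn: Option<Turn>,
}

impl ThreadState {
    fn active_turn_snapshot(&self) -> Option<Turn> {
        self.active_turn.clone()
    }
}

fn turn_for_started_event(state: &Mutex<ThreadState>, turn_id: &str) -> Turn {
    // Hold the lock only long enough to clone a snapshot, then fall back to
    // a fresh in-progress turn when nothing is active yet.
    let state = state.lock().expect("state lock poisoned");
    state.active_turn_snapshot().unwrap_or_else(|| Turn {
        id: turn_id.to_string(),
        status: TurnStatus::InProgress,
    })
}
```

Releasing the lock before sending the notification (as the diff does with its inner block) keeps the await on `send_server_notification` outside the critical section.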

@@ -141,6 +141,7 @@ use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadListResponse;
use codex_app_server_protocol::ThreadLoadedListParams;
use codex_app_server_protocol::ThreadLoadedListResponse;
use codex_app_server_protocol::ThreadNameUpdatedNotification;
use codex_app_server_protocol::ThreadReadParams;
use codex_app_server_protocol::ThreadReadResponse;
use codex_app_server_protocol::ThreadRealtimeAppendAudioParams;
@@ -172,7 +173,6 @@ use codex_app_server_protocol::Turn;
use codex_app_server_protocol::TurnInterruptParams;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::TurnStartedNotification;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::TurnSteerParams;
use codex_app_server_protocol::TurnSteerResponse;
@@ -374,7 +374,6 @@ pub(crate) struct CodexMessageProcessor {
outgoing: Arc<OutgoingMessageSender>,
arg0_paths: Arg0DispatchPaths,
config: Arc<Config>,
single_client_mode: bool,
cli_overrides: Vec<(String, TomlValue)>,
cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
active_login: Arc<Mutex<Option<ActiveLogin>>>,
@@ -401,7 +400,6 @@ struct ListenerTaskContext {
thread_watch_manager: ThreadWatchManager,
fallback_model_provider: String,
codex_home: PathBuf,
single_client_mode: bool,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
@@ -418,7 +416,6 @@ pub(crate) struct CodexMessageProcessorArgs {
pub(crate) config: Arc<Config>,
pub(crate) cli_overrides: Vec<(String, TomlValue)>,
pub(crate) cloud_requirements: Arc<RwLock<CloudRequirementsLoader>>,
pub(crate) single_client_mode: bool,
pub(crate) feedback: CodexFeedback,
}
@@ -463,7 +460,6 @@ impl CodexMessageProcessor {
config,
cli_overrides,
cloud_requirements,
single_client_mode,
feedback,
} = args;
Self {
@@ -472,7 +468,6 @@ impl CodexMessageProcessor {
outgoing: outgoing.clone(),
arg0_paths,
config,
single_client_mode,
cli_overrides,
cloud_requirements,
active_login: Arc::new(Mutex::new(None)),
@@ -2047,6 +2042,7 @@ impl CodexMessageProcessor {
let ThreadStartParams {
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -2064,6 +2060,7 @@ impl CodexMessageProcessor {
let mut typesafe_overrides = self.build_thread_config_overrides(
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -2081,7 +2078,6 @@ impl CodexMessageProcessor {
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
};
tokio::spawn(async move {
@@ -2204,7 +2200,7 @@ impl CodexMessageProcessor {
listener_task_context
.thread_watch_manager
.upsert_thread(thread.clone())
.upsert_thread_silently(thread.clone())
.await;
thread.status = resolve_thread_status(
@@ -2219,6 +2215,7 @@ impl CodexMessageProcessor {
thread: thread.clone(),
model: config_snapshot.model,
model_provider: config_snapshot.model_provider_id,
service_tier: config_snapshot.service_tier,
cwd: config_snapshot.cwd,
approval_policy: config_snapshot.approval_policy.into(),
sandbox: config_snapshot.sandbox_policy.into(),
@@ -2255,6 +2252,7 @@ impl CodexMessageProcessor {
&self,
model: Option<String>,
model_provider: Option<String>,
service_tier: Option<Option<codex_protocol::config_types::ServiceTier>>,
cwd: Option<String>,
approval_policy: Option<codex_app_server_protocol::AskForApproval>,
sandbox: Option<SandboxMode>,
@@ -2265,6 +2263,7 @@ impl CodexMessageProcessor {
ConfigOverrides {
model,
model_provider,
service_tier,
cwd: cwd.map(PathBuf::from),
approval_policy: approval_policy
.map(codex_app_server_protocol::AskForApproval::to_core),
@@ -2341,6 +2340,14 @@ impl CodexMessageProcessor {
async fn thread_set_name(&self, request_id: ConnectionRequestId, params: ThreadSetNameParams) {
let ThreadSetNameParams { thread_id, name } = params;
let thread_id = match ThreadId::from_string(&thread_id) {
Ok(id) => id,
Err(err) => {
self.send_invalid_request_error(request_id, format!("invalid thread id: {err}"))
.await;
return;
}
};
let Some(name) = codex_core::util::normalize_thread_name(&name) else {
self.send_invalid_request_error(
request_id,
@@ -2350,15 +2357,43 @@ impl CodexMessageProcessor {
return;
};
let (_, thread) = match self.load_thread(&thread_id).await {
Ok(v) => v,
Err(error) => {
self.outgoing.send_error(request_id, error).await;
if let Ok(thread) = self.thread_manager.get_thread(thread_id).await {
if let Err(err) = thread.submit(Op::SetThreadName { name }).await {
self.send_internal_error(request_id, format!("failed to set thread name: {err}"))
.await;
return;
}
};
if let Err(err) = thread.submit(Op::SetThreadName { name }).await {
self.outgoing
.send_response(request_id, ThreadSetNameResponse {})
.await;
return;
}
let thread_exists =
match find_thread_path_by_id_str(&self.config.codex_home, &thread_id.to_string()).await
{
Ok(Some(_)) => true,
Ok(None) => false,
Err(err) => {
self.send_invalid_request_error(
request_id,
format!("failed to locate thread id {thread_id}: {err}"),
)
.await;
return;
}
};
if !thread_exists {
self.send_invalid_request_error(request_id, format!("thread not found: {thread_id}"))
.await;
return;
}
if let Err(err) =
codex_core::append_thread_name(&self.config.codex_home, thread_id, &name).await
{
self.send_internal_error(request_id, format!("failed to set thread name: {err}"))
.await;
return;
@@ -2367,6 +2402,13 @@ impl CodexMessageProcessor {
self.outgoing
.send_response(request_id, ThreadSetNameResponse {})
.await;
let notification = ThreadNameUpdatedNotification {
thread_id: thread_id.to_string(),
thread_name: Some(name),
};
self.outgoing
.send_server_notification(ServerNotification::ThreadNameUpdated(notification))
.await;
}
async fn thread_unarchive(
@@ -3023,6 +3065,7 @@ impl CodexMessageProcessor {
path,
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -3055,6 +3098,7 @@ impl CodexMessageProcessor {
let typesafe_overrides = self.build_thread_config_overrides(
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -3153,6 +3197,7 @@ impl CodexMessageProcessor {
thread,
model: session_configured.model,
model_provider: session_configured.model_provider_id,
service_tier: session_configured.service_tier,
cwd: session_configured.cwd,
approval_policy: session_configured.approval_policy.into(),
sandbox: session_configured.sandbox_policy.into(),
@@ -3460,6 +3505,7 @@ impl CodexMessageProcessor {
path,
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -3540,6 +3586,7 @@ impl CodexMessageProcessor {
let typesafe_overrides = self.build_thread_config_overrides(
model,
model_provider,
service_tier,
cwd,
approval_policy,
sandbox,
@@ -3667,7 +3714,7 @@ impl CodexMessageProcessor {
}
self.thread_watch_manager
.upsert_thread(thread.clone())
.upsert_thread_silently(thread.clone())
.await;
thread.status = resolve_thread_status(
@@ -3681,6 +3728,7 @@ impl CodexMessageProcessor {
thread: thread.clone(),
model: session_configured.model,
model_provider: session_configured.model_provider_id,
service_tier: session_configured.service_tier,
cwd: session_configured.cwd,
approval_policy: session_configured.approval_policy.into(),
sandbox: session_configured.sandbox_policy.into(),
@@ -4589,6 +4637,7 @@ impl CodexMessageProcessor {
SessionConfiguredNotification {
session_id: session_configured.session_id,
model: session_configured.model.clone(),
service_tier: session_configured.service_tier,
reasoning_effort: session_configured.reasoning_effort,
history_log_id: session_configured.history_log_id,
history_entry_count: session_configured.history_entry_count,
@@ -4810,6 +4859,7 @@ impl CodexMessageProcessor {
SessionConfiguredNotification {
session_id: session_configured.session_id,
model: session_configured.model.clone(),
service_tier: session_configured.service_tier,
reasoning_effort: session_configured.reasoning_effort,
history_log_id: session_configured.history_log_id,
history_entry_count: session_configured.history_entry_count,
@@ -5237,6 +5287,7 @@ impl CodexMessageProcessor {
approval_policy,
sandbox_policy,
model,
service_tier,
effort,
summary,
output_schema,
@@ -5286,6 +5337,7 @@ impl CodexMessageProcessor {
model,
effort,
summary: Some(summary),
service_tier,
final_output_json_schema: output_schema,
collaboration_mode: None,
personality: None,
@@ -5827,6 +5879,7 @@ impl CodexMessageProcessor {
|| params.approval_policy.is_some()
|| params.sandbox_policy.is_some()
|| params.model.is_some()
|| params.service_tier.is_some()
|| params.effort.is_some()
|| params.summary.is_some()
|| collaboration_mode.is_some()
@@ -5843,6 +5896,7 @@ impl CodexMessageProcessor {
model: params.model,
effort: params.effort.map(Some),
summary: params.summary,
service_tier: params.service_tier,
collaboration_mode,
personality: params.personality,
})
@@ -5866,17 +5920,8 @@ impl CodexMessageProcessor {
status: TurnStatus::InProgress,
};
let response = TurnStartResponse { turn: turn.clone() };
let response = TurnStartResponse { turn };
self.outgoing.send_response(request_id, response).await;
// Emit v2 turn/started notification.
let notif = TurnStartedNotification {
thread_id: params.thread_id,
turn,
};
self.outgoing
.send_server_notification(ServerNotification::TurnStarted(notif))
.await;
}
Err(err) => {
let error = JSONRPCErrorError {
@@ -6167,24 +6212,15 @@ impl CodexMessageProcessor {
&self,
request_id: &ConnectionRequestId,
turn: Turn,
parent_thread_id: String,
review_thread_id: String,
) {
let response = ReviewStartResponse {
turn: turn.clone(),
turn,
review_thread_id,
};
self.outgoing
.send_response(request_id.clone(), response)
.await;
let notif = TurnStartedNotification {
thread_id: parent_thread_id,
turn,
};
self.outgoing
.send_server_notification(ServerNotification::TurnStarted(notif))
.await;
}
async fn start_inline_review(
@@ -6200,13 +6236,8 @@ impl CodexMessageProcessor {
match turn_id {
Ok(turn_id) => {
let turn = Self::build_review_turn(turn_id, display_text);
self.emit_review_started(
request_id,
turn,
parent_thread_id.clone(),
parent_thread_id,
)
.await;
self.emit_review_started(request_id, turn, parent_thread_id)
.await;
Ok(())
}
Err(err) => Err(JSONRPCErrorError {
@@ -6281,7 +6312,7 @@ impl CodexMessageProcessor {
Ok(summary) => {
let mut thread = summary_to_thread(summary);
self.thread_watch_manager
.upsert_thread(thread.clone())
.upsert_thread_silently(thread.clone())
.await;
thread.status = resolve_thread_status(
self.thread_watch_manager
@@ -6320,7 +6351,7 @@ impl CodexMessageProcessor {
let turn = Self::build_review_turn(turn_id, display_text);
let review_thread_id = thread_id.to_string();
self.emit_review_started(request_id, turn, review_thread_id.clone(), review_thread_id)
self.emit_review_started(request_id, turn, review_thread_id)
.await;
Ok(())
@@ -6496,7 +6527,6 @@ impl CodexMessageProcessor {
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
},
conversation_id,
connection_id,
@@ -6584,7 +6614,6 @@ impl CodexMessageProcessor {
thread_watch_manager: self.thread_watch_manager.clone(),
fallback_model_provider: self.config.model_provider_id.clone(),
codex_home: self.config.codex_home.clone(),
single_client_mode: self.single_client_mode,
},
conversation_id,
conversation,
@@ -6616,7 +6645,6 @@ impl CodexMessageProcessor {
thread_watch_manager,
fallback_model_provider,
codex_home,
single_client_mode,
} = listener_task_context;
let outgoing_for_task = Arc::clone(&outgoing);
tokio::spawn(async move {
@@ -6666,9 +6694,7 @@ impl CodexMessageProcessor {
);
let raw_events_enabled = {
let mut thread_state = thread_state.lock().await;
if !single_client_mode {
thread_state.track_current_turn_event(&event.msg);
}
thread_state.track_current_turn_event(&event.msg);
thread_state.experimental_raw_events
};
let subscribed_connection_ids = thread_state_manager
@@ -7142,6 +7168,7 @@ async fn handle_pending_thread_resume_request(
let ThreadConfigSnapshot {
model,
model_provider_id,
service_tier,
approval_policy,
sandbox_policy,
cwd,
@@ -7152,6 +7179,7 @@ async fn handle_pending_thread_resume_request(
thread,
model,
model_provider: model_provider_id,
service_tier,
cwd,
approval_policy: approval_policy.into(),
sandbox: sandbox_policy.into(),

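The `service_tier: Option<Option<ServiceTier>>` parameter threaded through `build_thread_config_overrides` above uses nested options to express three states. A small sketch of those semantics, with `ServiceTier` as a simplified stand-in for the protocol enum:

```rust
// Three-state override carried by Option<Option<T>>:
//   None          -> caller did not mention service tier; keep current value
//   Some(None)    -> caller explicitly cleared it
//   Some(Some(t)) -> caller set it to `t`
// `ServiceTier` here is a stand-in, not the real config_types enum.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ServiceTier {
    Fast,
}

fn apply_override(
    current: Option<ServiceTier>,
    requested: Option<Option<ServiceTier>>,
) -> Option<ServiceTier> {
    match requested {
        None => current,      // no override supplied
        Some(value) => value, // override wins, including an explicit clear
    }
}
```

This is why the turn-start path maps `params.effort.map(Some)` but passes `params.service_tier` through unchanged: the caller-facing field is already doubly optional.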

@@ -52,6 +52,7 @@ use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::registry::Registry;
use tracing_subscriber::util::SubscriberInitExt;
mod app_server_tracing;
mod bespoke_event_handling;
mod codex_message_processor;
mod config_api;
@@ -447,7 +448,7 @@ pub async fn run_main_with_transport(
let otel = codex_core::otel_init::build_provider(
&config,
env!("CARGO_PKG_VERSION"),
Some("codex_app_server"),
Some("codex-app-server"),
default_analytics_enabled,
)
.map_err(|e| {
@@ -557,7 +558,6 @@ pub async fn run_main_with_transport(
outgoing: outgoing_message_sender,
arg0_paths,
config: Arc::new(config),
single_client_mode,
cli_overrides,
loader_overrides,
cloud_requirements: cloud_requirements.clone(),
@@ -675,6 +675,7 @@ pub async fn run_main_with_transport(
.process_request(
connection_id,
request,
transport,
&mut connection_state.session,
&connection_state.outbound_initialized,
)


@@ -12,6 +12,7 @@ use crate::external_agent_config_api::ExternalAgentConfigApi;
use crate::outgoing_message::ConnectionId;
use crate::outgoing_message::ConnectionRequestId;
use crate::outgoing_message::OutgoingMessageSender;
use crate::transport::AppServerTransport;
use async_trait::async_trait;
use codex_app_server_protocol::ChatgptAuthTokensRefreshParams;
use codex_app_server_protocol::ChatgptAuthTokensRefreshReason;
@@ -59,6 +60,7 @@ use tokio::sync::watch;
use tokio::time::Duration;
use tokio::time::timeout;
use toml::Value as TomlValue;
use tracing::Instrument;
const EXTERNAL_AUTH_REFRESH_TIMEOUT: Duration = Duration::from_secs(10);
@@ -141,13 +143,13 @@ pub(crate) struct ConnectionSessionState {
pub(crate) experimental_api_enabled: bool,
pub(crate) opted_out_notification_methods: HashSet<String>,
pub(crate) app_server_client_name: Option<String>,
pub(crate) client_version: Option<String>,
}
pub(crate) struct MessageProcessorArgs {
pub(crate) outgoing: Arc<OutgoingMessageSender>,
pub(crate) arg0_paths: Arg0DispatchPaths,
pub(crate) config: Arc<Config>,
pub(crate) single_client_mode: bool,
pub(crate) cli_overrides: Vec<(String, TomlValue)>,
pub(crate) loader_overrides: LoaderOverrides,
pub(crate) cloud_requirements: CloudRequirementsLoader,
@@ -163,7 +165,6 @@ impl MessageProcessor {
outgoing,
arg0_paths,
config,
single_client_mode,
cli_overrides,
loader_overrides,
cloud_requirements,
@@ -199,7 +200,6 @@ impl MessageProcessor {
config: Arc::clone(&config),
cli_overrides: cli_overrides.clone(),
cloud_requirements: cloud_requirements.clone(),
single_client_mode,
feedback,
});
let config_api = ConfigApi::new(
@@ -224,46 +224,50 @@ impl MessageProcessor {
&mut self,
connection_id: ConnectionId,
request: JSONRPCRequest,
transport: AppServerTransport,
session: &mut ConnectionSessionState,
outbound_initialized: &AtomicBool,
) {
let request_method = request.method.as_str();
tracing::trace!(
?connection_id,
request_id = ?request.id,
"app-server request: {request_method}"
);
let request_id = ConnectionRequestId {
connection_id,
request_id: request.id.clone(),
};
let request_json = match serde_json::to_value(&request) {
Ok(request_json) => request_json,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
let request_span =
crate::app_server_tracing::request_span(&request, transport, connection_id, session);
async {
let request_method = request.method.as_str();
tracing::trace!(
?connection_id,
request_id = ?request.id,
"app-server request: {request_method}"
);
let request_id = ConnectionRequestId {
connection_id,
request_id: request.id.clone(),
};
let request_json = match serde_json::to_value(&request) {
Ok(request_json) => request_json,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
let codex_request = match serde_json::from_value::<ClientRequest>(request_json) {
Ok(codex_request) => codex_request,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
let codex_request = match serde_json::from_value::<ClientRequest>(request_json) {
Ok(codex_request) => codex_request,
Err(err) => {
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: format!("Invalid request: {err}"),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
};
match codex_request {
match codex_request {
// Handle Initialize internally so CodexMessageProcessor does not have to concern
// itself with the `initialized` bool.
ClientRequest::Initialize { request_id, params } => {
@@ -304,6 +308,8 @@ impl MessageProcessor {
title: _title,
version,
} = params.client_info;
session.app_server_client_name = Some(name.clone());
session.client_version = Some(version.clone());
if let Err(error) = set_default_originator(name.clone()) {
match error {
SetOriginatorError::InvalidHeaderValue => {
@@ -330,7 +336,6 @@ impl MessageProcessor {
if let Ok(mut suffix) = USER_AGENT_SUFFIX.lock() {
*suffix = Some(user_agent_suffix);
}
session.app_server_client_name = Some(name.clone());
let user_agent = get_codex_user_agent();
let response = InitializeResponse { user_agent };
@@ -355,91 +360,97 @@ impl MessageProcessor {
return;
}
}
}
}
if let Some(reason) = codex_request.experimental_reason()
&& !session.experimental_api_enabled
{
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: experimental_required_message(reason),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
if let Some(reason) = codex_request.experimental_reason()
&& !session.experimental_api_enabled
{
let error = JSONRPCErrorError {
code: INVALID_REQUEST_ERROR_CODE,
message: experimental_required_message(reason),
data: None,
};
self.outgoing.send_error(request_id, error).await;
return;
}
match codex_request {
ClientRequest::ConfigRead { request_id, params } => {
self.handle_config_read(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ExternalAgentConfigDetect { request_id, params } => {
self.handle_external_agent_config_detect(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ExternalAgentConfigImport { request_id, params } => {
self.handle_external_agent_config_import(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigValueWrite { request_id, params } => {
self.handle_config_value_write(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigBatchWrite { request_id, params } => {
self.handle_config_batch_write(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigRequirementsRead {
request_id,
params: _,
} => {
self.handle_config_requirements_read(ConnectionRequestId {
connection_id,
request_id,
})
.await;
}
other => {
// Box the delegated future so this wrapper's async state machine does not
// inline the full `CodexMessageProcessor::process_request` future, which
// can otherwise push worker-thread stack usage over the edge.
self.codex_message_processor
.process_request(connection_id, other, session.app_server_client_name.clone())
.boxed()
match codex_request {
ClientRequest::ConfigRead { request_id, params } => {
self.handle_config_read(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ExternalAgentConfigDetect { request_id, params } => {
self.handle_external_agent_config_detect(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ExternalAgentConfigImport { request_id, params } => {
self.handle_external_agent_config_import(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigValueWrite { request_id, params } => {
self.handle_config_value_write(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigBatchWrite { request_id, params } => {
self.handle_config_batch_write(
ConnectionRequestId {
connection_id,
request_id,
},
params,
)
.await;
}
ClientRequest::ConfigRequirementsRead {
request_id,
params: _,
} => {
self.handle_config_requirements_read(ConnectionRequestId {
connection_id,
request_id,
})
.await;
}
other => {
// Box the delegated future so this wrapper's async state machine does not
// inline the full `CodexMessageProcessor::process_request` future, which
// can otherwise push worker-thread stack usage over the edge.
self.codex_message_processor
.process_request(
connection_id,
other,
session.app_server_client_name.clone(),
)
.boxed()
.await;
}
}
}
.instrument(request_span)
.await;
}
pub(crate) async fn process_notification(&self, notification: JSONRPCNotification) {

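The comment in the delegating match arm above explains that boxing the child future keeps it out of the wrapper's async state machine. A self-contained sketch of the size effect, using `std::mem::size_of_val` on the futures themselves (no executor needed); the sizes shown are illustrative, not the real processor's:

```rust
// Why the delegated future is boxed: awaiting a large child future inline
// embeds its whole state in the parent's state machine, while Box::pin
// moves that state to the heap and stores only a pointer across the await.
async fn large_future() -> u64 {
    // Large local state held across the await point, so it becomes part of
    // this future's state machine.
    let buf = [0u8; 4096];
    std::future::ready(()).await;
    buf.iter().map(|b| *b as u64).sum()
}

fn inlined_size() -> usize {
    // Parent future embeds large_future's ~4 KiB of state.
    std::mem::size_of_val(&async { large_future().await })
}

fn boxed_size() -> usize {
    // Parent future only holds a Pin<Box<...>> across the await.
    std::mem::size_of_val(&async { Box::pin(large_future()).await })
}
```

When such a future is polled on a worker thread, the inlined variant's state lives wherever the future does, which is how deep delegation chains can push stack usage "over the edge" as the comment puts it.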

@@ -91,7 +91,12 @@ impl ThreadWatchManager {
}
pub(crate) async fn upsert_thread(&self, thread: Thread) {
self.mutate_and_publish(move |state| state.upsert_thread(thread.id))
self.mutate_and_publish(move |state| state.upsert_thread(thread.id, true))
.await;
}
pub(crate) async fn upsert_thread_silently(&self, thread: Thread) {
self.mutate_and_publish(move |state| state.upsert_thread(thread.id, false))
.await;
}
@@ -289,14 +294,22 @@ struct ThreadWatchState {
}
impl ThreadWatchState {
fn upsert_thread(&mut self, thread_id: String) -> Option<ThreadStatusChangedNotification> {
fn upsert_thread(
&mut self,
thread_id: String,
emit_notification: bool,
) -> Option<ThreadStatusChangedNotification> {
let previous_status = self.status_for(&thread_id);
let runtime = self
.runtime_by_thread_id
.entry(thread_id.clone())
.or_default();
runtime.is_loaded = true;
self.status_changed_notification(thread_id, previous_status)
if emit_notification {
self.status_changed_notification(thread_id, previous_status)
} else {
None
}
}
fn remove_thread(&mut self, thread_id: &str) -> Option<ThreadStatusChangedNotification> {
@@ -692,6 +705,45 @@ mod tests {
);
}
#[tokio::test]
async fn silent_upsert_skips_initial_notification() {
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(8);
let manager = ThreadWatchManager::new_with_outgoing(Arc::new(OutgoingMessageSender::new(
outgoing_tx,
)));
manager
.upsert_thread_silently(test_thread(
INTERACTIVE_THREAD_ID,
codex_app_server_protocol::SessionSource::Cli,
))
.await;
assert_eq!(
manager
.loaded_status_for_thread(INTERACTIVE_THREAD_ID)
.await,
ThreadStatus::Idle,
);
assert!(
timeout(Duration::from_millis(100), outgoing_rx.recv())
.await
.is_err(),
"silent upsert should not emit thread/status/changed"
);
manager.note_turn_started(INTERACTIVE_THREAD_ID).await;
assert_eq!(
recv_status_changed_notification(&mut outgoing_rx).await,
ThreadStatusChangedNotification {
thread_id: INTERACTIVE_THREAD_ID.to_string(),
status: ThreadStatus::Active {
active_flags: vec![],
},
},
);
}
async fn wait_for_status(
manager: &ThreadWatchManager,
thread_id: &str,

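The `upsert_thread` / `upsert_thread_silently` split above funnels into one mutation path with an `emit_notification` flag. A std-only sketch of the shape, with stand-in types rather than the real watch-manager state:

```rust
// One state-mutation path; the flag decides whether a status-changed
// notification is produced. Callers that introduce a thread via
// thread/started can skip the redundant thread/status/changed event.
use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq)]
struct StatusChanged {
    thread_id: String,
}

#[derive(Default)]
struct WatchState {
    loaded: HashMap<String, bool>,
}

impl WatchState {
    fn upsert_thread(
        &mut self,
        thread_id: String,
        emit_notification: bool,
    ) -> Option<StatusChanged> {
        // The state change always happens...
        self.loaded.insert(thread_id.clone(), true);
        // ...but the notification is only surfaced when requested.
        emit_notification.then(|| StatusChanged { thread_id })
    }
}
```

Returning `Option<Notification>` from the pure state method keeps the I/O decision at the call site, which is what the `silent_upsert_skips_initial_notification` test locks in.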

@@ -744,6 +744,7 @@ mod tests {
id: codex_app_server_protocol::RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
});
assert!(
enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request).await
@@ -885,6 +886,7 @@ mod tests {
id: codex_app_server_protocol::RequestId::Integer(7),
method: "config/read".to_string(),
params: Some(json!({ "includeLayers": false })),
trace: None,
});
let enqueue_result = tokio::time::timeout(


@@ -891,6 +891,7 @@ impl McpProcess {
id: RequestId::Integer(request_id),
method: method.to_string(),
params,
trace: None,
});
self.send_jsonrpc_message(message).await?;
Ok(request_id)


@@ -36,7 +36,7 @@ use std::path::Path;
use tempfile::TempDir;
use tokio::time::timeout;
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
const DEFAULT_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(45);
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn test_codex_jsonrpc_conversation_flow() -> Result<()> {
@@ -337,6 +337,7 @@ async fn test_send_user_turn_changes_approval_policy_behavior() -> Result<()> {
model: "mock-model".to_string(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: None,
})
.await?;
@@ -453,6 +454,7 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<(
model: model.clone(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: None,
})
.await?;
@@ -481,6 +483,7 @@ async fn test_send_user_turn_updates_sandbox_and_cwd_between_turns() -> Result<(
model: model.clone(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: None,
})
.await?;


@@ -92,6 +92,7 @@ async fn send_user_turn_accepts_output_schema_v1() -> Result<()> {
model: "mock-model".to_string(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: Some(output_schema.clone()),
})
.await?;
@@ -184,6 +185,7 @@ async fn send_user_turn_rejects_oversized_input_v1() -> Result<()> {
model: "mock-model".to_string(),
effort: Some(ReasoningEffort::Low),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: None,
})
.await?;
@@ -273,6 +275,7 @@ async fn send_user_turn_output_schema_is_per_turn_v1() -> Result<()> {
model: "mock-model".to_string(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: Some(output_schema.clone()),
})
.await?;
@@ -321,6 +324,7 @@ async fn send_user_turn_output_schema_is_per_turn_v1() -> Result<()> {
model: "mock-model".to_string(),
effort: Some(ReasoningEffort::Medium),
summary: ReasoningSummary::Auto,
service_tier: None,
output_schema: None,
})
.await?;


@@ -30,7 +30,7 @@ async fn app_server_default_analytics_disabled_without_flag() -> Result<()> {
let provider = codex_core::otel_init::build_provider(
&config,
SERVICE_VERSION,
Some("codex_app_server"),
Some("codex-app-server"),
false,
)
.map_err(|err| anyhow::anyhow!(err.to_string()))?;
@@ -55,7 +55,7 @@ async fn app_server_default_analytics_enabled_with_flag() -> Result<()> {
let provider = codex_core::otel_init::build_provider(
&config,
SERVICE_VERSION,
Some("codex_app_server"),
Some("codex-app-server"),
true,
)
.map_err(|err| anyhow::anyhow!(err.to_string()))?;


@@ -174,6 +174,7 @@ pub(super) async fn send_request(
id: RequestId::Integer(id),
method: method.to_string(),
params,
trace: None,
});
send_jsonrpc(stream, message).await
}


@@ -5,6 +5,7 @@ use app_test_support::create_mock_responses_server_sequence_unchecked;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::LoginApiKeyParams;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadRealtimeAppendAudioParams;
use codex_app_server_protocol::ThreadRealtimeAppendAudioResponse;
@@ -42,20 +43,17 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
let responses_server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
let realtime_server = start_websocket_server(vec![vec![
vec![json!({
"type": "session.created",
"session": { "id": "sess_backend" }
})],
vec![json!({
"type": "session.updated",
"session": { "backend_prompt": "backend prompt" }
"session": { "id": "sess_backend", "instructions": "backend prompt" }
})],
vec![],
vec![
json!({
"type": "response.output_audio.delta",
"type": "conversation.output_audio.delta",
"delta": "AQID",
"sample_rate": 24_000,
"num_channels": 1,
"channels": 1,
"samples_per_channel": 512
}),
json!({
@@ -84,6 +82,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
let mut mcp = McpProcess::new(codex_home.path()).await?;
mcp.initialize().await?;
login_with_api_key(&mut mcp, "sk-test-key").await?;
let thread_start_request_id = mcp
.send_thread_start_request(ThreadStartParams::default())
@@ -182,7 +181,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
assert_eq!(connection.len(), 3);
assert_eq!(
connection[0].body_json()["type"].as_str(),
Some("session.create")
Some("session.update")
);
let mut request_types = [
connection[1].body_json()["type"]
@@ -199,7 +198,7 @@ async fn realtime_conversation_streams_v2_notifications() -> Result<()> {
request_types,
[
"conversation.item.create".to_string(),
"response.input_audio.delta".to_string(),
"input_audio_buffer.append".to_string(),
]
);
@@ -214,8 +213,8 @@ async fn realtime_conversation_stop_emits_closed_notification() -> Result<()> {
let responses_server = create_mock_responses_server_sequence_unchecked(Vec::new()).await;
let realtime_server = start_websocket_server(vec![vec![
vec![json!({
"type": "session.created",
"session": { "id": "sess_backend" }
"type": "session.updated",
"session": { "id": "sess_backend", "instructions": "backend prompt" }
})],
vec![],
]])
@@ -231,6 +230,7 @@ async fn realtime_conversation_stop_emits_closed_notification() -> Result<()> {
let mut mcp = McpProcess::new(codex_home.path()).await?;
mcp.initialize().await?;
login_with_api_key(&mut mcp, "sk-test-key").await?;
let thread_start_request_id = mcp
.send_thread_start_request(ThreadStartParams::default())
@@ -349,6 +349,22 @@ async fn read_notification<T: DeserializeOwned>(mcp: &mut McpProcess, method: &s
Ok(serde_json::from_value(params)?)
}
async fn login_with_api_key(mcp: &mut McpProcess, api_key: &str) -> Result<()> {
let request_id = mcp
.send_login_api_key_request(LoginApiKeyParams {
api_key: api_key.to_string(),
})
.await?;
timeout(
DEFAULT_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(request_id)),
)
.await??;
Ok(())
}
fn create_config_toml(
codex_home: &Path,
responses_server_uri: &str,


@@ -8,6 +8,7 @@ use app_test_support::to_response;
use codex_app_server_protocol::ItemCompletedNotification;
use codex_app_server_protocol::ItemStartedNotification;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
@@ -19,9 +20,12 @@ use codex_app_server_protocol::ServerRequest;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::ThreadStatusChangedNotification;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::UserInput as V2UserInput;
use pretty_assertions::assert_eq;
use serde_json::json;
use tempfile::TempDir;
use tokio::time::timeout;
@@ -301,6 +305,31 @@ async fn review_start_with_detached_delivery_returns_new_thread_id() -> Result<(
"detached review should run on a different thread"
);
let deadline = tokio::time::Instant::now() + DEFAULT_READ_TIMEOUT;
let notification = loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
let message = timeout(remaining, mcp.read_next_message()).await??;
let JSONRPCMessage::Notification(notification) = message else {
continue;
};
if notification.method == "thread/status/changed" {
let status_changed: ThreadStatusChangedNotification =
serde_json::from_value(notification.params.expect("params must be present"))?;
if status_changed.thread_id == review_thread_id {
anyhow::bail!(
"detached review threads should be introduced without a preceding thread/status/changed"
);
}
continue;
}
if notification.method == "thread/started" {
break notification;
}
};
let started: ThreadStartedNotification =
serde_json::from_value(notification.params.expect("params must be present"))?;
assert_eq!(started.thread.id, review_thread_id);
Ok(())
}
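The same drain loop (skip unrelated messages, fail fast on a forbidden method, stop at the target or the deadline) recurs in several tests in this diff. A minimal std-only sketch of the pattern — the `Notification` type and helper name here are hypothetical, not part of the test-support crate:

```rust
use std::time::{Duration, Instant};

/// Hypothetical stand-in for a decoded JSON-RPC notification.
struct Notification {
    method: &'static str,
}

/// Drain `messages` until `target` arrives; error if `forbidden` shows up
/// first or the deadline elapses. Mirrors the loop in the tests above.
fn drain_until(
    messages: impl IntoIterator<Item = Notification>,
    target: &str,
    forbidden: &str,
    deadline: Instant,
) -> Result<Notification, String> {
    for msg in messages {
        if Instant::now() >= deadline {
            return Err("deadline elapsed".to_string());
        }
        if msg.method == forbidden {
            return Err(format!("unexpected `{forbidden}` before `{target}`"));
        }
        if msg.method == target {
            return Ok(msg);
        }
        // Any other notification is ignored, as in the tests.
    }
    Err(format!("stream ended before `{target}`"))
}

fn main() {
    let stream = vec![
        Notification { method: "item/completed" },
        Notification { method: "thread/started" },
    ];
    let deadline = Instant::now() + Duration::from_secs(5);
    let found = drain_until(stream, "thread/started", "thread/status/changed", deadline);
    assert!(found.is_ok());
}
```

In the real tests the deadline is shared across iterations via `deadline.saturating_duration_since(now)` so the whole loop, not each read, is bounded by `DEFAULT_READ_TIMEOUT`.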
@@ -389,6 +418,11 @@ async fn start_default_thread(mcp: &mut McpProcess) -> Result<String> {
)
.await??;
let ThreadStartResponse { thread, .. } = to_response::<ThreadStartResponse>(thread_resp)?;
timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/started"),
)
.await??;
Ok(thread.id)
}


@@ -4,7 +4,7 @@ use app_test_support::create_fake_rollout;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::SessionSource;
@@ -15,6 +15,7 @@ use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::ThreadStatus;
use codex_app_server_protocol::ThreadStatusChangedNotification;
use codex_app_server_protocol::TurnStatus;
use codex_app_server_protocol::UserInput;
use pretty_assertions::assert_eq;
@@ -124,11 +125,27 @@ async fn thread_fork_creates_new_thread_and_emits_started() -> Result<()> {
}
// A corresponding thread/started notification should arrive.
let notif: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/started"),
)
.await??;
let deadline = tokio::time::Instant::now() + DEFAULT_READ_TIMEOUT;
let notif = loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
let message = timeout(remaining, mcp.read_next_message()).await??;
let JSONRPCMessage::Notification(notif) = message else {
continue;
};
if notif.method == "thread/status/changed" {
let status_changed: ThreadStatusChangedNotification =
serde_json::from_value(notif.params.expect("params must be present"))?;
if status_changed.thread_id == thread.id {
anyhow::bail!(
"thread/fork should introduce the thread without a preceding thread/status/changed"
);
}
continue;
}
if notif.method == "thread/started" {
break notif;
}
};
let started_params = notif.params.clone().expect("params must be present");
let started_thread_json = started_params
.get("thread")


@@ -10,6 +10,7 @@ use codex_app_server_protocol::SessionSource;
use codex_app_server_protocol::ThreadItem;
use codex_app_server_protocol::ThreadListParams;
use codex_app_server_protocol::ThreadListResponse;
use codex_app_server_protocol::ThreadNameUpdatedNotification;
use codex_app_server_protocol::ThreadReadParams;
use codex_app_server_protocol::ThreadReadResponse;
use codex_app_server_protocol::ThreadResumeParams;
@@ -220,25 +221,6 @@ async fn thread_name_set_is_reflected_in_read_list_and_resume() -> Result<()> {
let mut mcp = McpProcess::new(codex_home.path()).await?;
timeout(DEFAULT_READ_TIMEOUT, mcp.initialize()).await??;
// `thread/name/set` operates on loaded threads (via ThreadManager). A rollout existing on disk
// is not enough; we must `thread/resume` first to load it into the running server.
let pre_resume_id = mcp
.send_thread_resume_request(ThreadResumeParams {
thread_id: conversation_id.clone(),
..Default::default()
})
.await?;
let pre_resume_resp: JSONRPCResponse = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_response_message(RequestId::Integer(pre_resume_id)),
)
.await??;
let ThreadResumeResponse {
thread: pre_resumed,
..
} = to_response::<ThreadResumeResponse>(pre_resume_resp)?;
assert_eq!(pre_resumed.id, conversation_id);
// Set a user-facing thread title.
let new_name = "My renamed thread";
let set_id = mcp
@@ -253,6 +235,15 @@ async fn thread_name_set_is_reflected_in_read_list_and_resume() -> Result<()> {
)
.await??;
let _: ThreadSetNameResponse = to_response::<ThreadSetNameResponse>(set_resp)?;
let notification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/name/updated"),
)
.await??;
let notification: ThreadNameUpdatedNotification =
serde_json::from_value(notification.params.expect("thread/name/updated params"))?;
assert_eq!(notification.thread_id, conversation_id);
assert_eq!(notification.thread_name.as_deref(), Some(new_name));
// Read should now surface `thread.name`, and the wire payload must include `name`.
let read_id = mcp


@@ -1,4 +1,3 @@
use anyhow::Context;
use anyhow::Result;
use app_test_support::McpProcess;
use app_test_support::create_apply_patch_sse_response;
@@ -17,7 +16,6 @@ use codex_app_server_protocol::FileChangeApprovalDecision;
use codex_app_server_protocol::FileChangeRequestApprovalResponse;
use codex_app_server_protocol::ItemStartedNotification;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::PatchApplyStatus;
use codex_app_server_protocol::PatchChangeKind;
@@ -30,7 +28,6 @@ use codex_app_server_protocol::ThreadResumeResponse;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStatus;
use codex_app_server_protocol::ThreadStatusChangedNotification;
use codex_app_server_protocol::TurnStartParams;
use codex_app_server_protocol::TurnStartResponse;
use codex_app_server_protocol::TurnStatus;
@@ -293,7 +290,7 @@ async fn thread_resume_keeps_in_flight_turn_streaming() -> Result<()> {
.await??;
timeout(
DEFAULT_READ_TIMEOUT,
wait_for_thread_status_active(&mut primary, &thread.id),
primary.read_stream_until_notification_message("turn/started"),
)
.await??;
@@ -400,7 +397,7 @@ async fn thread_resume_rejects_history_when_thread_is_running() -> Result<()> {
to_response::<TurnStartResponse>(running_turn_resp)?;
timeout(
DEFAULT_READ_TIMEOUT,
wait_for_thread_status_active(&mut primary, &thread_id),
primary.read_stream_until_notification_message("turn/started"),
)
.await??;
@@ -516,7 +513,7 @@ async fn thread_resume_rejects_mismatched_path_when_thread_is_running() -> Resul
to_response::<TurnStartResponse>(running_turn_resp)?;
timeout(
DEFAULT_READ_TIMEOUT,
wait_for_thread_status_active(&mut primary, &thread_id),
primary.read_stream_until_notification_message("turn/started"),
)
.await??;
@@ -619,7 +616,7 @@ async fn thread_resume_rejoins_running_thread_even_with_override_mismatch() -> R
.await??;
timeout(
DEFAULT_READ_TIMEOUT,
wait_for_thread_status_active(&mut primary, &thread.id),
primary.read_stream_until_notification_message("turn/started"),
)
.await??;
@@ -1419,30 +1416,6 @@ required = true
)
}
async fn wait_for_thread_status_active(
mcp: &mut McpProcess,
thread_id: &str,
) -> Result<ThreadStatusChangedNotification> {
loop {
let status_changed_notif: JSONRPCNotification = mcp
.read_stream_until_notification_message("thread/status/changed")
.await?;
let status_changed_params = status_changed_notif
.params
.context("thread/status/changed params must be present")?;
let status_changed: ThreadStatusChangedNotification =
serde_json::from_value(status_changed_params)?;
if status_changed.thread_id == thread_id
&& status_changed.status
== (ThreadStatus::Active {
active_flags: Vec::new(),
})
{
return Ok(status_changed);
}
}
}
#[allow(dead_code)]
fn set_rollout_mtime(path: &Path, updated_at_rfc3339: &str) -> Result<()> {
let parsed = chrono::DateTime::parse_from_rfc3339(updated_at_rfc3339)?.with_timezone(&Utc);


@@ -3,16 +3,18 @@ use app_test_support::McpProcess;
use app_test_support::create_mock_responses_server_repeating_assistant;
use app_test_support::to_response;
use codex_app_server_protocol::JSONRPCError;
use codex_app_server_protocol::JSONRPCNotification;
use codex_app_server_protocol::JSONRPCMessage;
use codex_app_server_protocol::JSONRPCResponse;
use codex_app_server_protocol::RequestId;
use codex_app_server_protocol::ThreadStartParams;
use codex_app_server_protocol::ThreadStartResponse;
use codex_app_server_protocol::ThreadStartedNotification;
use codex_app_server_protocol::ThreadStatus;
use codex_app_server_protocol::ThreadStatusChangedNotification;
use codex_core::config::set_project_trust_level;
use codex_protocol::config_types::TrustLevel;
use codex_protocol::openai_models::ReasoningEffort;
use pretty_assertions::assert_eq;
use serde_json::Value;
use std::path::Path;
use tempfile::TempDir;
@@ -92,11 +94,27 @@ async fn thread_start_creates_thread_and_emits_started() -> Result<()> {
assert_eq!(thread.name, None);
// A corresponding thread/started notification should arrive.
let notif: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("thread/started"),
)
.await??;
let deadline = tokio::time::Instant::now() + DEFAULT_READ_TIMEOUT;
let notif = loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
let message = timeout(remaining, mcp.read_next_message()).await??;
let JSONRPCMessage::Notification(notif) = message else {
continue;
};
if notif.method == "thread/status/changed" {
let status_changed: ThreadStatusChangedNotification =
serde_json::from_value(notif.params.expect("params must be present"))?;
if status_changed.thread_id == thread.id {
anyhow::bail!(
"thread/start should introduce the thread without a preceding thread/status/changed"
);
}
continue;
}
if notif.method == "thread/started" {
break notif;
}
};
let started_params = notif.params.clone().expect("params must be present");
let started_thread_json = started_params
.get("thread")


@@ -434,6 +434,21 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<(
started.turn.status,
codex_app_server_protocol::TurnStatus::InProgress
);
assert_eq!(started.turn.id, turn.id);
let completed_notif: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;
let completed: TurnCompletedNotification = serde_json::from_value(
completed_notif
.params
.expect("turn/completed params must be present"),
)?;
assert_eq!(completed.thread_id, thread.id);
assert_eq!(completed.turn.id, turn.id);
assert_eq!(completed.turn.status, TurnStatus::Completed);
// Send a second turn that exercises the overrides path: change the model.
let turn_req2 = mcp
@@ -457,25 +472,30 @@ async fn turn_start_emits_notifications_and_accepts_model_override() -> Result<(
// Ensure the second turn has a different id than the first.
assert_ne!(turn.id, turn2.id);
// Expect a second turn/started notification as well.
let _notif2: JSONRPCNotification = timeout(
let notif2: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/started"),
)
.await??;
let started2: TurnStartedNotification =
serde_json::from_value(notif2.params.expect("params must be present"))?;
assert_eq!(started2.thread_id, thread.id);
assert_eq!(started2.turn.id, turn2.id);
assert_eq!(started2.turn.status, TurnStatus::InProgress);
let completed_notif: JSONRPCNotification = timeout(
let completed_notif2: JSONRPCNotification = timeout(
DEFAULT_READ_TIMEOUT,
mcp.read_stream_until_notification_message("turn/completed"),
)
.await??;
let completed: TurnCompletedNotification = serde_json::from_value(
completed_notif
let completed2: TurnCompletedNotification = serde_json::from_value(
completed_notif2
.params
.expect("turn/completed params must be present"),
)?;
assert_eq!(completed.thread_id, thread.id);
assert_eq!(completed.turn.status, TurnStatus::Completed);
assert_eq!(completed2.thread_id, thread.id);
assert_eq!(completed2.turn.id, turn2.id);
assert_eq!(completed2.turn.status, TurnStatus::Completed);
Ok(())
}
@@ -1365,6 +1385,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> {
model: Some("mock-model".to_string()),
effort: Some(ReasoningEffort::Medium),
summary: Some(ReasoningSummary::Auto),
service_tier: None,
personality: None,
output_schema: None,
collaboration_mode: None,
@@ -1396,6 +1417,7 @@ async fn turn_start_updates_sandbox_and_cwd_between_turns_v2() -> Result<()> {
model: Some("mock-model".to_string()),
effort: Some(ReasoningEffort::Medium),
summary: Some(ReasoningSummary::Auto),
service_tier: None,
personality: None,
output_schema: None,
collaboration_mode: None,


@@ -0,0 +1,6 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "artifact-presentation",
crate_name = "codex_artifact_presentation",
)


@@ -0,0 +1,30 @@
[package]
name = "codex-artifact-presentation"
version.workspace = true
edition.workspace = true
license.workspace = true
[lib]
name = "codex_artifact_presentation"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
base64 = { workspace = true }
font8x8 = { workspace = true }
image = { workspace = true, features = ["jpeg", "png"] }
ppt-rs = { workspace = true }
reqwest = { workspace = true, features = ["blocking"] }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tiny-skia = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
zip = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
tempfile = { workspace = true }
tiny_http = { workspace = true }


@@ -0,0 +1,6 @@
mod presentation_artifact;
#[cfg(test)]
mod tests;
pub use presentation_artifact::*;


@@ -0,0 +1,250 @@
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use image::GenericImageView;
use image::ImageFormat;
use image::codecs::jpeg::JpegEncoder;
use image::imageops::FilterType;
use ppt_rs::Chart;
use ppt_rs::ChartSeries;
use ppt_rs::ChartType;
use ppt_rs::Hyperlink as PptHyperlink;
use ppt_rs::HyperlinkAction as PptHyperlinkAction;
use ppt_rs::Image;
use ppt_rs::Presentation;
use ppt_rs::Shape;
use ppt_rs::ShapeFill;
use ppt_rs::ShapeLine;
use ppt_rs::ShapeType;
use ppt_rs::SlideContent;
use ppt_rs::SlideLayout;
use ppt_rs::TableBuilder;
use ppt_rs::TableCell;
use ppt_rs::TableRow;
use ppt_rs::generator::ArrowSize;
use ppt_rs::generator::ArrowType;
use ppt_rs::generator::CellAlign;
use ppt_rs::generator::Connector;
use ppt_rs::generator::ConnectorLine;
use ppt_rs::generator::ConnectorType;
use ppt_rs::generator::LineDash;
use ppt_rs::generator::generate_image_content_type;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::io::Cursor;
use std::io::Read;
use std::io::Seek;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use thiserror::Error;
use uuid::Uuid;
use zip::ZipArchive;
use zip::ZipWriter;
use zip::write::SimpleFileOptions;
const POINT_TO_EMU: u32 = 12_700;
const DEFAULT_SLIDE_WIDTH_POINTS: u32 = 720;
const DEFAULT_SLIDE_HEIGHT_POINTS: u32 = 540;
const DEFAULT_IMPORTED_TITLE_LEFT: u32 = 36;
const DEFAULT_IMPORTED_TITLE_TOP: u32 = 24;
const DEFAULT_IMPORTED_TITLE_WIDTH: u32 = 648;
const DEFAULT_IMPORTED_TITLE_HEIGHT: u32 = 48;
const DEFAULT_IMPORTED_CONTENT_LEFT: u32 = 48;
const DEFAULT_IMPORTED_CONTENT_TOP: u32 = 96;
const DEFAULT_IMPORTED_CONTENT_WIDTH: u32 = 624;
const DEFAULT_IMPORTED_CONTENT_HEIGHT: u32 = 324;
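The constants above encode the Office Open XML coordinate system: positions are given in points and converted to English Metric Units at 12,700 EMU per point (914,400 per inch), so the default 720x540-point slide is the standard 10" x 7.5" canvas. A small sketch of that conversion:

```rust
// Point -> EMU conversion implied by the constants above.
// OOXML uses English Metric Units: 12,700 EMU per point.
const POINT_TO_EMU: u32 = 12_700;

fn points_to_emu(points: u32) -> u32 {
    points * POINT_TO_EMU
}

fn main() {
    assert_eq!(points_to_emu(720), 9_144_000); // default slide width
    assert_eq!(points_to_emu(540), 6_858_000); // default slide height
}
```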
#[derive(Debug, Error)]
pub enum PresentationArtifactError {
#[error("missing `artifact_id` for action `{action}`")]
MissingArtifactId { action: String },
#[error("unknown artifact id `{artifact_id}` for action `{action}`")]
UnknownArtifactId { action: String, artifact_id: String },
#[error("unknown action `{0}`")]
UnknownAction(String),
#[error("invalid args for action `{action}`: {message}")]
InvalidArgs { action: String, message: String },
#[error("unsupported feature for action `{action}`: {message}")]
UnsupportedFeature { action: String, message: String },
#[error("failed to import PPTX `{path}`: {message}")]
ImportFailed { path: PathBuf, message: String },
#[error("failed to export PPTX `{path}`: {message}")]
ExportFailed { path: PathBuf, message: String },
#[error("failed to render preview for action `{action}`: {message}")]
RenderFailed { action: String, message: String },
}
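The `#[error("...")]` attributes above come from `thiserror`, which derives `Display` (and `std::error::Error`) from the format strings. A hand-rolled std-only equivalent for one variant, to show what the derive expands to:

```rust
use std::fmt;

// Manual equivalent of `#[derive(Error)]` + `#[error("missing `artifact_id`
// for action `{action}`")]` for a single variant (sketch only; the real
// crate uses thiserror to generate this).
#[derive(Debug)]
enum ArtifactError {
    MissingArtifactId { action: String },
}

impl fmt::Display for ArtifactError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ArtifactError::MissingArtifactId { action } => {
                write!(f, "missing `artifact_id` for action `{action}`")
            }
        }
    }
}

impl std::error::Error for ArtifactError {}

fn main() {
    let err = ArtifactError::MissingArtifactId { action: "export_pptx".into() };
    assert_eq!(err.to_string(), "missing `artifact_id` for action `export_pptx`");
}
```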
#[derive(Debug, Clone, Deserialize)]
pub struct PresentationArtifactRequest {
pub artifact_id: Option<String>,
pub action: String,
#[serde(default)]
pub args: Value,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PresentationArtifactToolRequest {
pub artifact_id: Option<String>,
pub actions: Vec<PresentationArtifactToolAction>,
}
#[derive(Debug, Clone)]
pub struct PresentationArtifactExecutionRequest {
pub artifact_id: Option<String>,
pub requests: Vec<PresentationArtifactRequest>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct PresentationArtifactToolAction {
pub action: String,
#[serde(default)]
pub args: Value,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PathAccessKind {
Read,
Write,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PathAccessRequirement {
pub action: String,
pub kind: PathAccessKind,
pub path: PathBuf,
}
impl PresentationArtifactRequest {
pub fn is_mutating(&self) -> bool {
!is_read_only_action(&self.action)
}
pub fn required_path_accesses(
&self,
cwd: &Path,
) -> Result<Vec<PathAccessRequirement>, PresentationArtifactError> {
let access = match self.action.as_str() {
"import_pptx" => {
let args: ImportPptxArgs = parse_args(&self.action, &self.args)?;
vec![PathAccessRequirement {
action: self.action.clone(),
kind: PathAccessKind::Read,
path: resolve_path(cwd, &args.path),
}]
}
"export_pptx" => {
let args: ExportPptxArgs = parse_args(&self.action, &self.args)?;
vec![PathAccessRequirement {
action: self.action.clone(),
kind: PathAccessKind::Write,
path: resolve_path(cwd, &args.path),
}]
}
"export_preview" => {
let args: ExportPreviewArgs = parse_args(&self.action, &self.args)?;
vec![PathAccessRequirement {
action: self.action.clone(),
kind: PathAccessKind::Write,
path: resolve_path(cwd, &args.path),
}]
}
"add_image" => {
let args: AddImageArgs = parse_args(&self.action, &self.args)?;
match args.image_source()? {
ImageInputSource::Path(path) => vec![PathAccessRequirement {
action: self.action.clone(),
kind: PathAccessKind::Read,
path: resolve_path(cwd, &path),
}],
ImageInputSource::DataUrl(_)
| ImageInputSource::Blob(_)
| ImageInputSource::Uri(_)
| ImageInputSource::Placeholder => Vec::new(),
}
}
"replace_image" => {
let args: ReplaceImageArgs = parse_args(&self.action, &self.args)?;
match (
&args.path,
&args.data_url,
&args.blob,
&args.uri,
&args.prompt,
) {
(Some(path), None, None, None, None) => vec![PathAccessRequirement {
action: self.action.clone(),
kind: PathAccessKind::Read,
path: resolve_path(cwd, path),
}],
(None, Some(_), None, None, None)
| (None, None, Some(_), None, None)
| (None, None, None, Some(_), None)
| (None, None, None, None, Some(_)) => Vec::new(),
_ => {
return Err(PresentationArtifactError::InvalidArgs {
action: self.action.clone(),
message:
"provide exactly one of `path`, `data_url`, `blob`, or `uri`, or provide `prompt` for a placeholder image"
.to_string(),
});
}
}
}
_ => Vec::new(),
};
Ok(access)
}
}
impl PresentationArtifactToolRequest {
pub fn is_mutating(&self) -> Result<bool, PresentationArtifactError> {
Ok(self.actions.iter().any(|request| !is_read_only_action(&request.action)))
}
pub fn into_execution_request(
self,
) -> Result<PresentationArtifactExecutionRequest, PresentationArtifactError> {
if self.actions.is_empty() {
return Err(PresentationArtifactError::InvalidArgs {
action: "presentation_artifact".to_string(),
message: "`actions` must contain at least one item".to_string(),
});
}
Ok(PresentationArtifactExecutionRequest {
artifact_id: self.artifact_id,
requests: self
.actions
.into_iter()
.map(|request| PresentationArtifactRequest {
artifact_id: None,
action: request.action,
args: request.args,
})
.collect(),
})
}
pub fn required_path_accesses(
&self,
cwd: &Path,
) -> Result<Vec<PathAccessRequirement>, PresentationArtifactError> {
let mut accesses = Vec::new();
for request in &self.actions {
accesses.extend(
PresentationArtifactRequest {
artifact_id: None,
action: request.action.clone(),
args: request.args.clone(),
}
.required_path_accesses(cwd)?,
);
}
Ok(accesses)
}
}
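`is_mutating` above reduces a whole batch to a single flag: the tool request is mutating iff any action in it is not read-only. A std-only sketch of that rule — the read-only action names here are illustrative, not the crate's actual `is_read_only_action` table:

```rust
// Illustrative read-only set; the real predicate lives in the crate.
fn is_read_only_action(action: &str) -> bool {
    matches!(action, "inspect" | "resolve" | "render_preview")
}

// A batch is mutating iff any contained action is not read-only.
fn batch_is_mutating(actions: &[&str]) -> bool {
    actions.iter().any(|a| !is_read_only_action(a))
}

fn main() {
    assert!(!batch_is_mutating(&["inspect", "render_preview"]));
    assert!(batch_is_mutating(&["inspect", "add_slide"]));
}
```

This mirrors why `required_path_accesses` also iterates every action: a single `export_pptx` buried in an otherwise read-only batch is enough to require a write grant.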


@@ -0,0 +1,736 @@
#[derive(Debug, Deserialize)]
struct CreateArgs {
name: Option<String>,
slide_size: Option<Value>,
theme: Option<ThemeArgs>,
}
#[derive(Debug, Deserialize)]
struct ImportPptxArgs {
path: PathBuf,
}
#[derive(Debug, Deserialize)]
struct ExportPptxArgs {
path: PathBuf,
}
#[derive(Debug, Deserialize)]
struct ExportPreviewArgs {
path: PathBuf,
slide_index: Option<u32>,
format: Option<String>,
scale: Option<f32>,
quality: Option<u8>,
}
#[derive(Debug, Default, Deserialize)]
struct RenderPreviewArgs {
slide_index: Option<u32>,
scale: Option<f32>,
include_background: Option<bool>,
}
#[derive(Debug, Default, Deserialize)]
struct AddSlideArgs {
layout: Option<String>,
notes: Option<String>,
background_fill: Option<String>,
}
#[derive(Debug, Deserialize)]
struct CreateLayoutArgs {
name: String,
kind: Option<String>,
parent_layout_id: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum PreviewOutputFormat {
Png,
Jpeg,
Svg,
}
impl PreviewOutputFormat {
fn extension(self) -> &'static str {
match self {
Self::Png => "png",
Self::Jpeg => "jpg",
Self::Svg => "svg",
}
}
}
#[derive(Debug, Deserialize)]
struct AddLayoutPlaceholderArgs {
layout_id: String,
name: String,
placeholder_type: String,
index: Option<u32>,
text: Option<String>,
geometry: Option<String>,
position: Option<PositionArgs>,
}
#[derive(Debug, Deserialize)]
struct LayoutIdArgs {
layout_id: String,
}
#[derive(Debug, Deserialize)]
struct SetSlideLayoutArgs {
slide_index: u32,
layout_id: String,
}
#[derive(Debug, Deserialize)]
struct UpdatePlaceholderTextArgs {
slide_index: u32,
name: String,
text: String,
}
#[derive(Debug, Deserialize)]
struct NotesArgs {
slide_index: u32,
text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct NotesVisibilityArgs {
slide_index: u32,
visible: bool,
}
#[derive(Debug, Deserialize)]
struct ThemeArgs {
color_scheme: HashMap<String, String>,
major_font: Option<String>,
minor_font: Option<String>,
}
#[derive(Debug, Deserialize)]
struct StyleNameArgs {
name: String,
}
#[derive(Debug, Deserialize)]
struct AddStyleArgs {
name: String,
#[serde(flatten)]
styling: TextStylingArgs,
}
#[derive(Debug, Deserialize)]
struct InspectArgs {
kind: Option<String>,
include: Option<String>,
exclude: Option<String>,
search: Option<String>,
target_id: Option<String>,
target: Option<InspectTargetArgs>,
max_chars: Option<usize>,
}
#[derive(Debug, Clone, Deserialize)]
struct InspectTargetArgs {
id: String,
before_lines: Option<usize>,
after_lines: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct ResolveArgs {
id: String,
}
#[derive(Debug, Clone, Deserialize)]
struct PatchOperationInput {
artifact_id: Option<String>,
action: String,
#[serde(default)]
args: Value,
}
#[derive(Debug, Deserialize)]
struct RecordPatchArgs {
operations: Vec<PatchOperationInput>,
}
#[derive(Debug, Deserialize)]
struct ApplyPatchArgs {
operations: Option<Vec<PatchOperationInput>>,
patch: Option<PresentationPatch>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct PresentationPatch {
version: u32,
artifact_id: String,
operations: Vec<PatchOperation>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PatchOperation {
action: String,
#[serde(default)]
args: Value,
}
#[derive(Debug, Default, Deserialize)]
struct InsertSlideArgs {
index: Option<u32>,
after_slide_index: Option<u32>,
layout: Option<String>,
notes: Option<String>,
background_fill: Option<String>,
}
#[derive(Debug, Deserialize)]
struct SlideIndexArgs {
slide_index: u32,
}
#[derive(Debug, Deserialize)]
struct MoveSlideArgs {
from_index: u32,
to_index: u32,
}
#[derive(Debug, Deserialize)]
struct SetActiveSlideArgs {
slide_index: u32,
}
#[derive(Debug, Deserialize)]
struct SetSlideBackgroundArgs {
slide_index: u32,
fill: String,
}
#[derive(Debug, Clone, Deserialize)]
struct PositionArgs {
left: u32,
top: u32,
width: u32,
height: u32,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct PartialPositionArgs {
left: Option<u32>,
top: Option<u32>,
width: Option<u32>,
height: Option<u32>,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct TextStylingArgs {
style: Option<String>,
font_size: Option<u32>,
font_family: Option<String>,
color: Option<String>,
fill: Option<String>,
alignment: Option<String>,
bold: Option<bool>,
italic: Option<bool>,
underline: Option<bool>,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct TextLayoutArgs {
insets: Option<TextInsetsArgs>,
wrap: Option<String>,
auto_fit: Option<String>,
vertical_alignment: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct TextInsetsArgs {
left: u32,
right: u32,
top: u32,
bottom: u32,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum RichTextInput {
Plain(String),
Paragraphs(Vec<RichParagraphInput>),
}
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum RichParagraphInput {
Plain(String),
Runs(Vec<RichRunInput>),
}
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
enum RichRunInput {
Plain(String),
Styled(RichRunObjectInput),
}
#[derive(Debug, Clone, Deserialize)]
struct RichRunObjectInput {
run: String,
#[serde(default)]
text_style: TextStylingArgs,
link: Option<RichTextLinkInput>,
}
#[derive(Debug, Clone, Deserialize)]
struct RichTextLinkInput {
uri: Option<String>,
is_external: Option<bool>,
}
#[derive(Debug, Deserialize)]
struct AddTextShapeArgs {
slide_index: u32,
text: String,
position: PositionArgs,
#[serde(flatten)]
styling: TextStylingArgs,
#[serde(default)]
text_layout: TextLayoutArgs,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct StrokeArgs {
color: String,
width: u32,
style: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AddShapeArgs {
slide_index: u32,
geometry: String,
position: PositionArgs,
fill: Option<String>,
stroke: Option<StrokeArgs>,
text: Option<String>,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
#[serde(default)]
text_style: TextStylingArgs,
#[serde(default)]
text_layout: TextLayoutArgs,
}
#[derive(Debug, Clone, Default, Deserialize)]
struct ConnectorLineArgs {
color: Option<String>,
width: Option<u32>,
style: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PointArgs {
left: u32,
top: u32,
}
#[derive(Debug, Deserialize)]
struct AddConnectorArgs {
slide_index: u32,
connector_type: String,
start: PointArgs,
end: PointArgs,
line: Option<ConnectorLineArgs>,
start_arrow: Option<String>,
end_arrow: Option<String>,
arrow_size: Option<String>,
label: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AddImageArgs {
slide_index: u32,
path: Option<PathBuf>,
data_url: Option<String>,
blob: Option<String>,
uri: Option<String>,
position: PositionArgs,
fit: Option<ImageFitMode>,
crop: Option<ImageCropArgs>,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
lock_aspect_ratio: Option<bool>,
alt: Option<String>,
prompt: Option<String>,
}
impl AddImageArgs {
fn image_source(&self) -> Result<ImageInputSource, PresentationArtifactError> {
match (&self.path, &self.data_url, &self.blob, &self.uri) {
(Some(path), None, None, None) => Ok(ImageInputSource::Path(path.clone())),
(None, Some(data_url), None, None) => Ok(ImageInputSource::DataUrl(data_url.clone())),
(None, None, Some(blob), None) => Ok(ImageInputSource::Blob(blob.clone())),
(None, None, None, Some(uri)) => Ok(ImageInputSource::Uri(uri.clone())),
(None, None, None, None) if self.prompt.is_some() => Ok(ImageInputSource::Placeholder),
_ => Err(PresentationArtifactError::InvalidArgs {
action: "add_image".to_string(),
message:
"provide exactly one of `path`, `data_url`, `blob`, or `uri`, or provide `prompt` for a placeholder image"
.to_string(),
}),
}
}
}
enum ImageInputSource {
Path(PathBuf),
DataUrl(String),
Blob(String),
Uri(String),
Placeholder,
}
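`image_source` validates mutual exclusion by matching on a tuple of `Option`s: exactly one populated source succeeds, and the catch-all arm rejects both "none" and "more than one". A reduced std-only sketch of that idiom with only two sources plus the `prompt` placeholder (names are illustrative):

```rust
#[derive(Debug, PartialEq)]
enum Source {
    Path(String),
    DataUrl(String),
    Placeholder,
}

// Reject zero or multiple sources in one catch-all match arm, as
// `AddImageArgs::image_source` does above.
fn pick_source(
    path: Option<String>,
    data_url: Option<String>,
    prompt: Option<String>,
) -> Result<Source, String> {
    match (path, data_url) {
        (Some(p), None) => Ok(Source::Path(p)),
        (None, Some(d)) => Ok(Source::DataUrl(d)),
        (None, None) if prompt.is_some() => Ok(Source::Placeholder),
        _ => Err("provide exactly one of `path` or `data_url`, or `prompt`".to_string()),
    }
}

fn main() {
    assert_eq!(pick_source(Some("a.png".into()), None, None), Ok(Source::Path("a.png".into())));
    assert!(pick_source(Some("a.png".into()), Some("data:".into()), None).is_err());
    assert_eq!(pick_source(None, None, Some("a cat photo".into())), Ok(Source::Placeholder));
}
```

Because the pattern enumerates the valid combinations exhaustively, adding a fifth source only requires widening the tuple; the catch-all keeps rejecting every ambiguous mix.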
#[derive(Debug, Clone, Deserialize)]
struct ImageCropArgs {
left: f64,
top: f64,
right: f64,
bottom: f64,
}
#[derive(Debug, Deserialize)]
struct AddTableArgs {
slide_index: u32,
position: PositionArgs,
rows: Vec<Vec<Value>>,
column_widths: Option<Vec<u32>>,
row_heights: Option<Vec<u32>>,
style: Option<String>,
style_options: Option<TableStyleOptionsArgs>,
borders: Option<TableBordersArgs>,
right_to_left: Option<bool>,
}
#[derive(Debug, Deserialize)]
struct AddChartArgs {
slide_index: u32,
position: PositionArgs,
chart_type: String,
categories: Vec<String>,
series: Vec<ChartSeriesArgs>,
title: Option<String>,
style_index: Option<u32>,
has_legend: Option<bool>,
legend_position: Option<String>,
#[serde(default)]
legend_text_style: TextStylingArgs,
x_axis_title: Option<String>,
y_axis_title: Option<String>,
data_labels: Option<ChartDataLabelsArgs>,
chart_fill: Option<String>,
plot_area_fill: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ChartSeriesArgs {
name: String,
values: Vec<f64>,
categories: Option<Vec<String>>,
x_values: Option<Vec<f64>>,
fill: Option<String>,
stroke: Option<StrokeArgs>,
marker: Option<ChartMarkerArgs>,
data_label_overrides: Option<Vec<ChartDataLabelOverrideArgs>>,
}
#[derive(Debug, Clone, Deserialize)]
struct ChartMarkerArgs {
symbol: Option<String>,
size: Option<u32>,
}
#[derive(Debug, Clone, Deserialize)]
struct ChartDataLabelsArgs {
show_value: Option<bool>,
show_category_name: Option<bool>,
show_leader_lines: Option<bool>,
position: Option<String>,
#[serde(default)]
text_style: TextStylingArgs,
}
#[derive(Debug, Clone, Deserialize)]
struct ChartDataLabelOverrideArgs {
idx: u32,
text: Option<String>,
position: Option<String>,
#[serde(default)]
text_style: TextStylingArgs,
fill: Option<String>,
stroke: Option<StrokeArgs>,
}
#[derive(Debug, Deserialize)]
struct UpdateTextArgs {
element_id: String,
text: String,
#[serde(default)]
styling: TextStylingArgs,
#[serde(default)]
text_layout: TextLayoutArgs,
}
#[derive(Debug, Deserialize)]
struct SetRichTextArgs {
element_id: Option<String>,
slide_index: Option<u32>,
row: Option<u32>,
column: Option<u32>,
notes: Option<bool>,
text: RichTextInput,
#[serde(default)]
styling: TextStylingArgs,
#[serde(default)]
text_layout: TextLayoutArgs,
}
#[derive(Debug, Deserialize)]
struct FormatTextRangeArgs {
element_id: Option<String>,
slide_index: Option<u32>,
row: Option<u32>,
column: Option<u32>,
notes: Option<bool>,
query: Option<String>,
occurrence: Option<usize>,
start_cp: Option<usize>,
length: Option<usize>,
#[serde(default)]
styling: TextStylingArgs,
#[serde(default)]
text_layout: TextLayoutArgs,
link: Option<RichTextLinkInput>,
spacing_before: Option<u32>,
spacing_after: Option<u32>,
line_spacing: Option<f32>,
}
#[derive(Debug, Deserialize)]
struct ReplaceTextArgs {
element_id: String,
search: String,
replace: String,
}
#[derive(Debug, Deserialize)]
struct InsertTextAfterArgs {
element_id: String,
after: String,
insert: String,
}
#[derive(Debug, Deserialize)]
struct SetHyperlinkArgs {
element_id: String,
link_type: Option<String>,
url: Option<String>,
slide_index: Option<u32>,
address: Option<String>,
subject: Option<String>,
path: Option<String>,
tooltip: Option<String>,
highlight_click: Option<bool>,
clear: Option<bool>,
}
#[derive(Debug, Deserialize)]
struct UpdateShapeStyleArgs {
element_id: String,
position: Option<PartialPositionArgs>,
fill: Option<String>,
stroke: Option<StrokeArgs>,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
fit: Option<ImageFitMode>,
crop: Option<ImageCropArgs>,
lock_aspect_ratio: Option<bool>,
z_order: Option<u32>,
#[serde(default)]
text_layout: TextLayoutArgs,
}
#[derive(Debug, Deserialize)]
struct ElementIdArgs {
element_id: String,
}
#[derive(Debug, Deserialize)]
struct ReplaceImageArgs {
element_id: String,
path: Option<PathBuf>,
data_url: Option<String>,
blob: Option<String>,
uri: Option<String>,
fit: Option<ImageFitMode>,
crop: Option<ImageCropArgs>,
rotation: Option<i32>,
flip_horizontal: Option<bool>,
flip_vertical: Option<bool>,
lock_aspect_ratio: Option<bool>,
alt: Option<String>,
prompt: Option<String>,
}
#[derive(Debug, Deserialize)]
struct UpdateTableCellArgs {
element_id: String,
row: u32,
column: u32,
value: Value,
#[serde(default)]
styling: TextStylingArgs,
background_fill: Option<String>,
alignment: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct TableStyleOptionsArgs {
header_row: Option<bool>,
banded_rows: Option<bool>,
banded_columns: Option<bool>,
first_column: Option<bool>,
last_column: Option<bool>,
total_row: Option<bool>,
}
#[derive(Debug, Clone, Deserialize)]
struct TableBorderArgs {
color: String,
width: u32,
}
#[derive(Debug, Clone, Deserialize)]
struct TableBordersArgs {
outside: Option<TableBorderArgs>,
inside: Option<TableBorderArgs>,
top: Option<TableBorderArgs>,
bottom: Option<TableBorderArgs>,
left: Option<TableBorderArgs>,
right: Option<TableBorderArgs>,
}
#[derive(Debug, Deserialize)]
struct UpdateTableStyleArgs {
element_id: String,
style: Option<String>,
style_options: Option<TableStyleOptionsArgs>,
borders: Option<TableBordersArgs>,
right_to_left: Option<bool>,
}
#[derive(Debug, Deserialize)]
struct StyleTableBlockArgs {
element_id: String,
row: u32,
column: u32,
row_count: u32,
column_count: u32,
#[serde(default)]
styling: TextStylingArgs,
background_fill: Option<String>,
alignment: Option<String>,
borders: Option<TableBordersArgs>,
}
#[derive(Debug, Deserialize)]
struct MergeTableCellsArgs {
element_id: String,
start_row: u32,
end_row: u32,
start_column: u32,
end_column: u32,
}
#[derive(Debug, Deserialize)]
struct UpdateChartArgs {
element_id: String,
title: Option<String>,
categories: Option<Vec<String>>,
style_index: Option<u32>,
has_legend: Option<bool>,
legend_position: Option<String>,
#[serde(default)]
legend_text_style: TextStylingArgs,
x_axis_title: Option<String>,
y_axis_title: Option<String>,
data_labels: Option<ChartDataLabelsArgs>,
chart_fill: Option<String>,
plot_area_fill: Option<String>,
}
#[derive(Debug, Deserialize)]
struct AddChartSeriesArgs {
element_id: String,
name: String,
values: Vec<f64>,
categories: Option<Vec<String>>,
x_values: Option<Vec<f64>>,
fill: Option<String>,
stroke: Option<StrokeArgs>,
marker: Option<ChartMarkerArgs>,
}
#[derive(Debug, Deserialize)]
struct SetCommentAuthorArgs {
display_name: String,
initials: String,
email: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
struct CommentPositionArgs {
x: u32,
y: u32,
}
#[derive(Debug, Deserialize)]
struct AddCommentThreadArgs {
slide_index: Option<u32>,
element_id: Option<String>,
query: Option<String>,
occurrence: Option<usize>,
start_cp: Option<usize>,
length: Option<usize>,
text: String,
position: Option<CommentPositionArgs>,
}
#[derive(Debug, Deserialize)]
struct AddCommentReplyArgs {
thread_id: String,
text: String,
}
#[derive(Debug, Deserialize)]
struct ToggleCommentReactionArgs {
thread_id: String,
message_id: Option<String>,
emoji: String,
}
#[derive(Debug, Deserialize)]
struct CommentThreadIdArgs {
thread_id: String,
}


@@ -0,0 +1,871 @@
fn inspect_document(document: &PresentationDocument, args: &InspectArgs) -> String {
let include_kinds = args
.include
.as_deref()
.or(args.kind.as_deref())
.unwrap_or(
"deck,slide,textbox,shape,connector,table,chart,image,notes,layoutList,textRange,comment",
);
let included_kinds = include_kinds
.split(',')
.map(str::trim)
.filter(|entry| !entry.is_empty())
.collect::<HashSet<_>>();
let excluded_kinds = args
.exclude
.as_deref()
.unwrap_or_default()
.split(',')
.map(str::trim)
.filter(|entry| !entry.is_empty())
.collect::<HashSet<_>>();
let include = |name: &str| included_kinds.contains(name) && !excluded_kinds.contains(name);
let mut records: Vec<(Value, Option<String>)> = Vec::new();
if include("deck") {
records.push((
serde_json::json!({
"kind": "deck",
"id": format!("pr/{}", document.artifact_id),
"name": document.name,
"slides": document.slides.len(),
"styleIds": document
.named_text_styles()
.iter()
.map(|style| format!("st/{}", style.name))
.collect::<Vec<_>>(),
"activeSlideIndex": document.active_slide_index,
"activeSlideId": document.active_slide_index.and_then(|index| document.slides.get(index)).map(|slide| format!("sl/{}", slide.slide_id)),
"commentThreadIds": document
.comment_threads
.iter()
.map(|thread| format!("th/{}", thread.thread_id))
.collect::<Vec<_>>(),
}),
None,
));
}
if include("styleList") {
for style in document.named_text_styles() {
records.push((named_text_style_to_json(&style, "st"), None));
}
}
if include("layoutList") {
for layout in &document.layouts {
let placeholders = resolved_layout_placeholders(document, &layout.layout_id, "inspect")
.unwrap_or_default()
.into_iter()
.map(|placeholder| {
serde_json::json!({
"name": placeholder.definition.name,
"type": placeholder.definition.placeholder_type,
"sourceLayoutId": placeholder.source_layout_id,
"textPreview": placeholder.definition.text,
})
})
.collect::<Vec<_>>();
records.push((
serde_json::json!({
"kind": "layout",
"id": format!("ly/{}", layout.layout_id),
"layoutId": layout.layout_id,
"name": layout.name,
"type": match layout.kind { LayoutKind::Layout => "layout", LayoutKind::Master => "master" },
"parentLayoutId": layout.parent_layout_id,
"placeholders": placeholders,
}),
None,
));
}
}
for (index, slide) in document.slides.iter().enumerate() {
let slide_id = format!("sl/{}", slide.slide_id);
if include("slide") {
records.push((
serde_json::json!({
"kind": "slide",
"id": slide_id,
"slide": index + 1,
"slideIndex": index,
"isActive": document.active_slide_index == Some(index),
"layoutId": slide.layout_id,
"elements": slide.elements.len(),
}),
Some(slide_id.clone()),
));
}
if include("notes") && !slide.notes.text.is_empty() {
records.push((
serde_json::json!({
"kind": "notes",
"id": format!("nt/{}", slide.slide_id),
"slide": index + 1,
"visible": slide.notes.visible,
"text": slide.notes.text,
"textPreview": slide.notes.text.replace('\n', " | "),
"textChars": slide.notes.text.chars().count(),
"textLines": slide.notes.text.lines().count(),
"richText": rich_text_to_proto(&slide.notes.text, &slide.notes.rich_text),
}),
Some(slide_id.clone()),
));
}
if include("textRange") {
records.extend(
slide
.notes
.rich_text
.ranges
.iter()
.map(|range| {
let mut record = text_range_to_proto(&slide.notes.text, range);
record["kind"] = Value::String("textRange".to_string());
record["slide"] = Value::from(index + 1);
record["slideIndex"] = Value::from(index);
record["hostAnchor"] = Value::String(format!("nt/{}", slide.slide_id));
record["hostKind"] = Value::String("notes".to_string());
(record, Some(slide_id.clone()))
}),
);
}
for element in &slide.elements {
let mut record = match element {
PresentationElement::Text(text) => {
if !include("textbox") {
continue;
}
serde_json::json!({
"kind": "textbox",
"id": format!("sh/{}", text.element_id),
"slide": index + 1,
"text": text.text,
"textStyle": text_style_to_proto(&text.style),
"textPreview": text.text.replace('\n', " | "),
"textChars": text.text.chars().count(),
"textLines": text.text.lines().count(),
"richText": rich_text_to_proto(&text.text, &text.rich_text),
"bbox": [text.frame.left, text.frame.top, text.frame.width, text.frame.height],
"bboxUnit": "points",
})
}
PresentationElement::Shape(shape) => {
                    if !(include("shape") || (include("textbox") && shape.text.is_some())) {
continue;
}
let kind = if shape.text.is_some() && include("textbox") {
"textbox"
} else {
"shape"
};
let mut record = serde_json::json!({
"kind": kind,
"id": format!("sh/{}", shape.element_id),
"slide": index + 1,
"geometry": format!("{:?}", shape.geometry),
"text": shape.text,
"textStyle": text_style_to_proto(&shape.text_style),
"richText": shape
.text
.as_ref()
.zip(shape.rich_text.as_ref())
.map(|(text, rich_text)| rich_text_to_proto(text, rich_text))
.unwrap_or(Value::Null),
"rotation": shape.rotation_degrees,
"flipHorizontal": shape.flip_horizontal,
"flipVertical": shape.flip_vertical,
"bbox": [shape.frame.left, shape.frame.top, shape.frame.width, shape.frame.height],
"bboxUnit": "points",
});
if let Some(text) = &shape.text {
record["textPreview"] = Value::String(text.replace('\n', " | "));
record["textChars"] = Value::from(text.chars().count());
record["textLines"] = Value::from(text.lines().count());
}
record
}
PresentationElement::Connector(connector) => {
if !include("shape") && !include("connector") {
continue;
}
serde_json::json!({
"kind": "connector",
"id": format!("cn/{}", connector.element_id),
"slide": index + 1,
"connectorType": format!("{:?}", connector.connector_type),
"start": [connector.start.left, connector.start.top],
"end": [connector.end.left, connector.end.top],
"lineStyle": format!("{:?}", connector.line_style),
"label": connector.label,
})
}
PresentationElement::Table(table) => {
if !include("table") {
continue;
}
serde_json::json!({
"kind": "table",
"id": format!("tb/{}", table.element_id),
"slide": index + 1,
"rows": table.rows.len(),
"cols": table.rows.iter().map(std::vec::Vec::len).max().unwrap_or(0),
"columnWidths": table.column_widths,
"rowHeights": table.row_heights,
"preview": table.rows.first().map(|row| row.iter().map(|cell| cell.text.clone()).collect::<Vec<_>>().join(" | ")),
"style": table.style,
"styleOptions": table_style_options_to_proto(&table.style_options),
"borders": table.borders.as_ref().map(table_borders_to_proto),
"rightToLeft": table.right_to_left,
"cellTextStyles": table
.rows
.iter()
.map(|row| row.iter().map(|cell| text_style_to_proto(&cell.text_style)).collect::<Vec<_>>())
.collect::<Vec<_>>(),
"rowsData": table
.rows
.iter()
.map(|row| row.iter().map(table_cell_to_proto).collect::<Vec<_>>())
.collect::<Vec<_>>(),
"bbox": [table.frame.left, table.frame.top, table.frame.width, table.frame.height],
"bboxUnit": "points",
})
}
PresentationElement::Chart(chart) => {
if !include("chart") {
continue;
}
serde_json::json!({
"kind": "chart",
"id": format!("ch/{}", chart.element_id),
"slide": index + 1,
"chartType": format!("{:?}", chart.chart_type),
"title": chart.title,
"styleIndex": chart.style_index,
"hasLegend": chart.has_legend,
"legend": chart.legend.as_ref().map(chart_legend_to_proto),
"xAxis": chart.x_axis.as_ref().map(chart_axis_to_proto),
"yAxis": chart.y_axis.as_ref().map(chart_axis_to_proto),
"dataLabels": chart.data_labels.as_ref().map(chart_data_labels_to_proto),
"chartFill": chart.chart_fill,
"plotAreaFill": chart.plot_area_fill,
"series": chart
.series
.iter()
.map(|series| serde_json::json!({
"name": series.name,
"values": series.values,
"categories": series.categories,
"xValues": series.x_values,
"fill": series.fill,
"stroke": series.stroke.as_ref().map(stroke_to_proto),
"marker": series.marker.as_ref().map(chart_marker_to_proto),
"dataLabelOverrides": series
.data_label_overrides
.iter()
.map(chart_data_label_override_to_proto)
.collect::<Vec<_>>(),
}))
.collect::<Vec<_>>(),
"bbox": [chart.frame.left, chart.frame.top, chart.frame.width, chart.frame.height],
"bboxUnit": "points",
})
}
PresentationElement::Image(image) => {
if !include("image") {
continue;
}
serde_json::json!({
"kind": "image",
"id": format!("im/{}", image.element_id),
"slide": index + 1,
"alt": image.alt_text,
"prompt": image.prompt,
"fit": format!("{:?}", image.fit_mode),
"rotation": image.rotation_degrees,
"flipHorizontal": image.flip_horizontal,
"flipVertical": image.flip_vertical,
"crop": image.crop.map(|(left, top, right, bottom)| serde_json::json!({
"left": left,
"top": top,
"right": right,
"bottom": bottom,
})),
"isPlaceholder": image.is_placeholder,
"lockAspectRatio": image.lock_aspect_ratio,
"bbox": [image.frame.left, image.frame.top, image.frame.width, image.frame.height],
"bboxUnit": "points",
})
}
};
if let Some(placeholder) = match element {
PresentationElement::Text(text) => text.placeholder.as_ref(),
PresentationElement::Shape(shape) => shape.placeholder.as_ref(),
PresentationElement::Connector(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => None,
PresentationElement::Image(image) => image.placeholder.as_ref(),
} {
record["placeholder"] = Value::String(placeholder.placeholder_type.clone());
record["placeholderName"] = Value::String(placeholder.name.clone());
record["placeholderIndex"] =
placeholder.index.map(Value::from).unwrap_or(Value::Null);
}
if let PresentationElement::Shape(shape) = element
&& let Some(stroke) = &shape.stroke
{
record["stroke"] = serde_json::json!({
"color": stroke.color,
"width": stroke.width,
"style": stroke.style.as_api_str(),
});
}
if let Some(hyperlink) = match element {
PresentationElement::Text(text) => text.hyperlink.as_ref(),
PresentationElement::Shape(shape) => shape.hyperlink.as_ref(),
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => None,
} {
record["hyperlink"] = hyperlink.to_json();
}
records.push((record, Some(slide_id.clone())));
if include("textRange") {
match element {
PresentationElement::Text(text) => {
records.extend(text.rich_text.ranges.iter().map(|range| {
let mut record = text_range_to_proto(&text.text, range);
record["kind"] = Value::String("textRange".to_string());
record["slide"] = Value::from(index + 1);
record["slideIndex"] = Value::from(index);
record["hostAnchor"] = Value::String(format!("sh/{}", text.element_id));
record["hostKind"] = Value::String("textbox".to_string());
(record, Some(slide_id.clone()))
}));
}
PresentationElement::Shape(shape) => {
if let Some((text, rich_text)) = shape.text.as_ref().zip(shape.rich_text.as_ref()) {
records.extend(rich_text.ranges.iter().map(|range| {
let mut record = text_range_to_proto(text, range);
record["kind"] = Value::String("textRange".to_string());
record["slide"] = Value::from(index + 1);
record["slideIndex"] = Value::from(index);
record["hostAnchor"] = Value::String(format!("sh/{}", shape.element_id));
record["hostKind"] = Value::String("textbox".to_string());
(record, Some(slide_id.clone()))
}));
}
}
PresentationElement::Table(table) => {
for (row_index, row) in table.rows.iter().enumerate() {
for (column_index, cell) in row.iter().enumerate() {
records.extend(cell.rich_text.ranges.iter().map(|range| {
let mut record = text_range_to_proto(&cell.text, range);
record["kind"] = Value::String("textRange".to_string());
record["slide"] = Value::from(index + 1);
record["slideIndex"] = Value::from(index);
record["hostAnchor"] = Value::String(format!(
"tb/{}#cell/{row_index}/{column_index}",
table.element_id
));
record["hostKind"] = Value::String("tableCell".to_string());
(record, Some(slide_id.clone()))
}));
}
}
}
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Chart(_) => {}
}
}
}
}
if include("comment") {
records.extend(document.comment_threads.iter().map(|thread| {
let mut record = comment_thread_to_proto(thread);
record["id"] = Value::String(format!("th/{}", thread.thread_id));
(record, None)
}));
}
if let Some(target_id) = args.target_id.as_deref() {
records.retain(|(record, slide_id)| {
legacy_target_matches(target_id, record, slide_id.as_deref())
});
if records.is_empty() {
records.push((
serde_json::json!({
"kind": "notice",
"noticeType": "targetNotFound",
"target": { "id": target_id },
"message": format!("No inspect records matched target `{target_id}`."),
}),
None,
));
}
}
if let Some(search) = args.search.as_deref() {
let search_lowercase = search.to_ascii_lowercase();
records.retain(|(record, _)| {
record
.to_string()
.to_ascii_lowercase()
.contains(&search_lowercase)
});
if records.is_empty() {
records.push((
serde_json::json!({
"kind": "notice",
"noticeType": "noMatches",
"search": search,
"message": format!("No inspect records matched search `{search}`."),
}),
None,
));
}
}
if let Some(target) = args.target.as_ref() {
if let Some(target_index) = records.iter().position(|(record, _)| {
record.get("id").and_then(Value::as_str) == Some(target.id.as_str())
}) {
let start = target_index.saturating_sub(target.before_lines.unwrap_or(0));
let end = (target_index + target.after_lines.unwrap_or(0) + 1).min(records.len());
records = records.into_iter().skip(start).take(end - start).collect();
} else {
records = vec![(
serde_json::json!({
"kind": "notice",
"noticeType": "targetNotFound",
"target": {
"id": target.id,
"beforeLines": target.before_lines,
"afterLines": target.after_lines,
},
"message": format!("No inspect records matched target `{}`.", target.id),
}),
None,
)];
}
}
let mut lines = Vec::new();
let mut omitted_lines = 0usize;
let mut omitted_chars = 0usize;
for line in records.into_iter().map(|(record, _)| record.to_string()) {
let separator_len = usize::from(!lines.is_empty());
if let Some(max_chars) = args.max_chars
&& lines.iter().map(String::len).sum::<usize>() + separator_len + line.len() > max_chars
{
omitted_lines += 1;
omitted_chars += line.len();
continue;
}
lines.push(line);
}
if omitted_lines > 0 {
lines.push(
serde_json::json!({
"kind": "notice",
"noticeType": "truncation",
"maxChars": args.max_chars,
"omittedLines": omitted_lines,
"omittedChars": omitted_chars,
"message": format!(
"Truncated inspect output by omitting {omitted_lines} lines. Increase maxChars or narrow the filter."
),
})
.to_string(),
);
}
lines.join("\n")
}
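
// The include/exclude handling at the top of `inspect_document` reduces to
// CSV set membership. The sketch below is an illustrative, stdlib-only mirror
// of that filter (the function names and signatures here are hypothetical,
// not the ones used above):

```rust
use std::collections::HashSet;

// Parse a comma-separated kind list into a set, trimming whitespace and
// dropping empty entries, as `inspect_document` does.
fn csv_set(list: &str) -> HashSet<&str> {
    list.split(',')
        .map(str::trim)
        .filter(|entry| !entry.is_empty())
        .collect()
}

// A kind is emitted when it appears in `include` and not in `exclude`.
fn kind_included(kind: &str, include: &str, exclude: &str) -> bool {
    csv_set(include).contains(kind) && !csv_set(exclude).contains(kind)
}

fn main() {
    assert!(kind_included("chart", "slide, chart", ""));
    assert!(!kind_included("chart", "slide,chart", "chart"));
    assert!(!kind_included("notes", "slide,chart", ""));
}
```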
fn legacy_target_matches(target_id: &str, record: &Value, slide_id: Option<&str>) -> bool {
record.get("id").and_then(Value::as_str) == Some(target_id) || slide_id == Some(target_id)
}
fn add_text_metadata(record: &mut Value, text: &str) {
record["textPreview"] = Value::String(text.replace('\n', " | "));
record["textChars"] = Value::from(text.chars().count());
record["textLines"] = Value::from(text.lines().count());
}
fn normalize_element_lookup_id(element_id: &str) -> &str {
element_id
.split_once('/')
.map(|(_, normalized)| normalized)
.unwrap_or(element_id)
}
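
// Anchors in this module follow a `<kind>/<raw-id>` convention (`sh/…`,
// `tb/…`, `th/…`), and `normalize_element_lookup_id` strips everything up to
// the first `/`. A minimal standalone sketch of that behavior (the
// `normalize` name below is illustrative):

```rust
// Strip a single `<kind>/` prefix from an anchor ID; un-prefixed IDs pass
// through unchanged. Only the first `/` is consumed, so composite suffixes
// such as table-cell anchors survive intact.
fn normalize(element_id: &str) -> &str {
    element_id
        .split_once('/')
        .map(|(_, rest)| rest)
        .unwrap_or(element_id)
}

fn main() {
    assert_eq!(normalize("sh/abc123"), "abc123");
    assert_eq!(normalize("abc123"), "abc123");
    assert_eq!(normalize("tb/t1#cell/0/1"), "t1#cell/0/1");
}
```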
fn resolve_anchor(
document: &PresentationDocument,
id: &str,
action: &str,
) -> Result<Value, PresentationArtifactError> {
if id == format!("pr/{}", document.artifact_id) {
return Ok(serde_json::json!({
"kind": "deck",
"id": id,
"artifactId": document.artifact_id,
"name": document.name,
"slideCount": document.slides.len(),
"styleIds": document
.named_text_styles()
.iter()
.map(|style| format!("st/{}", style.name))
.collect::<Vec<_>>(),
"activeSlideIndex": document.active_slide_index,
"activeSlideId": document.active_slide_index.and_then(|index| document.slides.get(index)).map(|slide| format!("sl/{}", slide.slide_id)),
}));
}
if let Some(style_name) = id.strip_prefix("st/") {
let named_style = document
.named_text_styles()
.into_iter()
.find(|style| style.name == style_name)
.ok_or_else(|| PresentationArtifactError::UnsupportedFeature {
action: action.to_string(),
message: format!("unknown style id `{id}`"),
})?;
return Ok(named_text_style_to_json(&named_style, "st"));
}
for (slide_index, slide) in document.slides.iter().enumerate() {
let slide_id = format!("sl/{}", slide.slide_id);
if id == slide_id {
return Ok(serde_json::json!({
"kind": "slide",
"id": slide_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"isActive": document.active_slide_index == Some(slide_index),
"layoutId": slide.layout_id,
"notesId": (!slide.notes.text.is_empty()).then(|| format!("nt/{}", slide.slide_id)),
"elementIds": slide.elements.iter().map(|element| {
let prefix = match element {
PresentationElement::Text(_) | PresentationElement::Shape(_) => "sh",
PresentationElement::Connector(_) => "cn",
PresentationElement::Image(_) => "im",
PresentationElement::Table(_) => "tb",
PresentationElement::Chart(_) => "ch",
};
format!("{prefix}/{}", element.element_id())
}).collect::<Vec<_>>(),
}));
}
let notes_id = format!("nt/{}", slide.slide_id);
if id == notes_id {
let mut record = serde_json::json!({
"kind": "notes",
"id": notes_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"visible": slide.notes.visible,
"text": slide.notes.text,
});
add_text_metadata(&mut record, &slide.notes.text);
record["richText"] = rich_text_to_proto(&slide.notes.text, &slide.notes.rich_text);
return Ok(record);
}
if let Some(range_id) = id.strip_prefix("tr/")
&& let Some(record) = slide
.notes
.rich_text
.ranges
.iter()
.find(|range| range.range_id == range_id)
.map(|range| {
let mut record = text_range_to_proto(&slide.notes.text, range);
record["kind"] = Value::String("textRange".to_string());
record["id"] = Value::String(id.to_string());
record["slide"] = Value::from(slide_index + 1);
record["slideIndex"] = Value::from(slide_index);
record["hostAnchor"] = Value::String(notes_id.clone());
record["hostKind"] = Value::String("notes".to_string());
record
})
{
return Ok(record);
}
for element in &slide.elements {
let mut record = match element {
PresentationElement::Text(text) => {
let mut record = serde_json::json!({
"kind": "textbox",
"id": format!("sh/{}", text.element_id),
"elementId": text.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"text": text.text,
"textStyle": text_style_to_proto(&text.style),
"richText": rich_text_to_proto(&text.text, &text.rich_text),
"bbox": [text.frame.left, text.frame.top, text.frame.width, text.frame.height],
"bboxUnit": "points",
});
add_text_metadata(&mut record, &text.text);
record
}
PresentationElement::Shape(shape) => {
let mut record = serde_json::json!({
"kind": if shape.text.is_some() { "textbox" } else { "shape" },
"id": format!("sh/{}", shape.element_id),
"elementId": shape.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"geometry": format!("{:?}", shape.geometry),
"text": shape.text,
"textStyle": text_style_to_proto(&shape.text_style),
"richText": shape
.text
.as_ref()
.zip(shape.rich_text.as_ref())
.map(|(text, rich_text)| rich_text_to_proto(text, rich_text))
.unwrap_or(Value::Null),
"rotation": shape.rotation_degrees,
"flipHorizontal": shape.flip_horizontal,
"flipVertical": shape.flip_vertical,
"bbox": [shape.frame.left, shape.frame.top, shape.frame.width, shape.frame.height],
"bboxUnit": "points",
});
if let Some(text) = &shape.text {
add_text_metadata(&mut record, text);
}
record
}
PresentationElement::Connector(connector) => serde_json::json!({
"kind": "connector",
"id": format!("cn/{}", connector.element_id),
"elementId": connector.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"connectorType": format!("{:?}", connector.connector_type),
"start": [connector.start.left, connector.start.top],
"end": [connector.end.left, connector.end.top],
"lineStyle": format!("{:?}", connector.line_style),
"label": connector.label,
}),
PresentationElement::Image(image) => serde_json::json!({
"kind": "image",
"id": format!("im/{}", image.element_id),
"elementId": image.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"alt": image.alt_text,
"prompt": image.prompt,
"fit": format!("{:?}", image.fit_mode),
"rotation": image.rotation_degrees,
"flipHorizontal": image.flip_horizontal,
"flipVertical": image.flip_vertical,
"crop": image.crop.map(|(left, top, right, bottom)| serde_json::json!({
"left": left,
"top": top,
"right": right,
"bottom": bottom,
})),
"isPlaceholder": image.is_placeholder,
"lockAspectRatio": image.lock_aspect_ratio,
"bbox": [image.frame.left, image.frame.top, image.frame.width, image.frame.height],
"bboxUnit": "points",
}),
PresentationElement::Table(table) => serde_json::json!({
"kind": "table",
"id": format!("tb/{}", table.element_id),
"elementId": table.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"rows": table.rows.len(),
"cols": table.rows.iter().map(std::vec::Vec::len).max().unwrap_or(0),
"columnWidths": table.column_widths,
"rowHeights": table.row_heights,
"style": table.style,
"styleOptions": table_style_options_to_proto(&table.style_options),
"borders": table.borders.as_ref().map(table_borders_to_proto),
"rightToLeft": table.right_to_left,
"cellTextStyles": table
.rows
.iter()
.map(|row| row.iter().map(|cell| text_style_to_proto(&cell.text_style)).collect::<Vec<_>>())
.collect::<Vec<_>>(),
"rowsData": table
.rows
.iter()
.map(|row| row.iter().map(table_cell_to_proto).collect::<Vec<_>>())
.collect::<Vec<_>>(),
"bbox": [table.frame.left, table.frame.top, table.frame.width, table.frame.height],
"bboxUnit": "points",
}),
PresentationElement::Chart(chart) => serde_json::json!({
"kind": "chart",
"id": format!("ch/{}", chart.element_id),
"elementId": chart.element_id,
"slide": slide_index + 1,
"slideIndex": slide_index,
"chartType": format!("{:?}", chart.chart_type),
"title": chart.title,
"styleIndex": chart.style_index,
"hasLegend": chart.has_legend,
"legend": chart.legend.as_ref().map(chart_legend_to_proto),
"xAxis": chart.x_axis.as_ref().map(chart_axis_to_proto),
"yAxis": chart.y_axis.as_ref().map(chart_axis_to_proto),
"dataLabels": chart.data_labels.as_ref().map(chart_data_labels_to_proto),
"chartFill": chart.chart_fill,
"plotAreaFill": chart.plot_area_fill,
"series": chart
.series
.iter()
.map(|series| serde_json::json!({
"name": series.name,
"values": series.values,
"categories": series.categories,
"xValues": series.x_values,
"fill": series.fill,
"stroke": series.stroke.as_ref().map(stroke_to_proto),
"marker": series.marker.as_ref().map(chart_marker_to_proto),
"dataLabelOverrides": series
.data_label_overrides
.iter()
.map(chart_data_label_override_to_proto)
.collect::<Vec<_>>(),
}))
.collect::<Vec<_>>(),
"bbox": [chart.frame.left, chart.frame.top, chart.frame.width, chart.frame.height],
"bboxUnit": "points",
}),
};
if let Some(hyperlink) = match element {
PresentationElement::Text(text) => text.hyperlink.as_ref(),
PresentationElement::Shape(shape) => shape.hyperlink.as_ref(),
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => None,
} {
record["hyperlink"] = hyperlink.to_json();
}
if let PresentationElement::Shape(shape) = element
&& let Some(stroke) = &shape.stroke
{
record["stroke"] = serde_json::json!({
"color": stroke.color,
"width": stroke.width,
"style": stroke.style.as_api_str(),
});
}
if let Some(placeholder) = match element {
PresentationElement::Text(text) => text.placeholder.as_ref(),
PresentationElement::Shape(shape) => shape.placeholder.as_ref(),
PresentationElement::Image(image) => image.placeholder.as_ref(),
PresentationElement::Connector(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => None,
} {
record["placeholder"] = Value::String(placeholder.placeholder_type.clone());
record["placeholderName"] = Value::String(placeholder.name.clone());
record["placeholderIndex"] =
placeholder.index.map(Value::from).unwrap_or(Value::Null);
}
if record.get("id").and_then(Value::as_str) == Some(id) {
return Ok(record);
}
if let Some(range_id) = id.strip_prefix("tr/") {
match element {
PresentationElement::Text(text) => {
if let Some(range) =
text.rich_text.ranges.iter().find(|range| range.range_id == range_id)
{
let mut range_record = text_range_to_proto(&text.text, range);
range_record["kind"] = Value::String("textRange".to_string());
range_record["id"] = Value::String(id.to_string());
range_record["slide"] = Value::from(slide_index + 1);
range_record["slideIndex"] = Value::from(slide_index);
range_record["hostAnchor"] =
Value::String(format!("sh/{}", text.element_id));
range_record["hostKind"] = Value::String("textbox".to_string());
return Ok(range_record);
}
}
PresentationElement::Shape(shape) => {
if let Some((text, rich_text)) =
shape.text.as_ref().zip(shape.rich_text.as_ref())
&& let Some(range) =
rich_text.ranges.iter().find(|range| range.range_id == range_id)
{
let mut range_record = text_range_to_proto(text, range);
range_record["kind"] = Value::String("textRange".to_string());
range_record["id"] = Value::String(id.to_string());
range_record["slide"] = Value::from(slide_index + 1);
range_record["slideIndex"] = Value::from(slide_index);
range_record["hostAnchor"] =
Value::String(format!("sh/{}", shape.element_id));
range_record["hostKind"] = Value::String("textbox".to_string());
return Ok(range_record);
}
}
PresentationElement::Table(table) => {
for (row_index, row) in table.rows.iter().enumerate() {
for (column_index, cell) in row.iter().enumerate() {
if let Some(range) = cell
.rich_text
.ranges
.iter()
.find(|range| range.range_id == range_id)
{
let mut range_record = text_range_to_proto(&cell.text, range);
range_record["kind"] = Value::String("textRange".to_string());
range_record["id"] = Value::String(id.to_string());
range_record["slide"] = Value::from(slide_index + 1);
range_record["slideIndex"] = Value::from(slide_index);
range_record["hostAnchor"] = Value::String(format!(
"tb/{}#cell/{row_index}/{column_index}",
table.element_id
));
range_record["hostKind"] =
Value::String("tableCell".to_string());
return Ok(range_record);
}
}
}
}
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Chart(_) => {}
}
}
}
}
if let Some(thread_id) = id.strip_prefix("th/")
&& let Some(thread) = document
.comment_threads
.iter()
.find(|thread| thread.thread_id == thread_id)
{
let mut record = comment_thread_to_proto(thread);
record["id"] = Value::String(id.to_string());
return Ok(record);
}
for layout in &document.layouts {
let layout_id = format!("ly/{}", layout.layout_id);
if id == layout_id {
return Ok(serde_json::json!({
"kind": "layout",
"id": layout_id,
"layoutId": layout.layout_id,
"name": layout.name,
"type": match layout.kind {
LayoutKind::Layout => "layout",
LayoutKind::Master => "master",
},
"parentLayoutId": layout.parent_layout_id,
"placeholders": layout_placeholder_list(document, &layout.layout_id, action)?,
}));
}
}
Err(PresentationArtifactError::UnsupportedFeature {
action: action.to_string(),
message: format!("unknown resolve id `{id}`"),
})
}

File diff suppressed because it is too large


@@ -0,0 +1,11 @@
include!("api.rs");
include!("manager.rs");
include!("response.rs");
include!("model.rs");
include!("args.rs");
include!("parsing.rs");
include!("proto.rs");
include!("inspect.rs");
include!("pptx.rs");
include!("snapshot.rs");
include!("render.rs");

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,851 @@
const CODEX_METADATA_ENTRY: &str = "ppt/codex-document.json";
fn import_codex_metadata_document(path: &Path) -> Result<Option<PresentationDocument>, String> {
let file = std::fs::File::open(path).map_err(|error| error.to_string())?;
let mut archive = ZipArchive::new(file).map_err(|error| error.to_string())?;
let mut entry = match archive.by_name(CODEX_METADATA_ENTRY) {
Ok(entry) => entry,
Err(zip::result::ZipError::FileNotFound) => return Ok(None),
Err(error) => return Err(error.to_string()),
};
let mut bytes = Vec::new();
entry.read_to_end(&mut bytes)
.map_err(|error| error.to_string())?;
serde_json::from_slice(&bytes)
.map(Some)
.map_err(|error| error.to_string())
}
fn build_pptx_bytes(document: &PresentationDocument, action: &str) -> Result<Vec<u8>, String> {
let bytes = document
.to_ppt_rs()
.build()
.map_err(|error| format!("{action}: {error}"))?;
patch_pptx_package(bytes, document).map_err(|error| format!("{action}: {error}"))
}
struct SlideImageAsset {
xml: String,
relationship_xml: String,
media_path: String,
media_bytes: Vec<u8>,
extension: String,
}
fn normalized_image_extension(format: &str) -> String {
match format.to_ascii_lowercase().as_str() {
"jpeg" => "jpg".to_string(),
other => other.to_string(),
}
}
fn image_relationship_xml(relationship_id: &str, target: &str) -> String {
format!(
r#"<Relationship Id="{relationship_id}" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/image" Target="{}"/>"#,
ppt_rs::escape_xml(target)
)
}
fn image_picture_xml(
image: &ImageElement,
shape_id: usize,
relationship_id: &str,
frame: Rect,
crop: Option<ImageCrop>,
) -> String {
let blip_fill = if let Some((crop_left, crop_top, crop_right, crop_bottom)) = crop {
format!(
r#"<p:blipFill>
<a:blip r:embed="{relationship_id}"/>
<a:srcRect l="{}" t="{}" r="{}" b="{}"/>
<a:stretch>
<a:fillRect/>
</a:stretch>
</p:blipFill>"#,
(crop_left * 100_000.0).round() as u32,
(crop_top * 100_000.0).round() as u32,
(crop_right * 100_000.0).round() as u32,
(crop_bottom * 100_000.0).round() as u32,
)
} else {
format!(
r#"<p:blipFill>
<a:blip r:embed="{relationship_id}"/>
<a:stretch>
<a:fillRect/>
</a:stretch>
</p:blipFill>"#
)
};
let descr = image
.alt_text
.as_deref()
.map(|alt| format!(r#" descr="{}""#, ppt_rs::escape_xml(alt)))
.unwrap_or_default();
let no_change_aspect = if image.lock_aspect_ratio { 1 } else { 0 };
let rotation = image
.rotation_degrees
.map(|rotation| format!(r#" rot="{}""#, i64::from(rotation) * 60_000))
.unwrap_or_default();
let flip_horizontal = if image.flip_horizontal {
r#" flipH="1""#
} else {
""
};
let flip_vertical = if image.flip_vertical {
r#" flipV="1""#
} else {
""
};
format!(
r#"<p:pic>
<p:nvPicPr>
<p:cNvPr id="{shape_id}" name="Picture {shape_id}"{descr}/>
<p:cNvPicPr>
<a:picLocks noChangeAspect="{no_change_aspect}"/>
</p:cNvPicPr>
<p:nvPr/>
</p:nvPicPr>
{blip_fill}
<p:spPr>
<a:xfrm{rotation}{flip_horizontal}{flip_vertical}>
<a:off x="{}" y="{}"/>
<a:ext cx="{}" cy="{}"/>
</a:xfrm>
<a:prstGeom prst="rect">
<a:avLst/>
</a:prstGeom>
</p:spPr>
</p:pic>"#,
points_to_emu(frame.left),
points_to_emu(frame.top),
points_to_emu(frame.width),
points_to_emu(frame.height),
)
}
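The crop and frame math above depends on two OOXML unit conventions: positions and extents are emitted in EMU (12,700 EMU per point), while `<a:srcRect>` crop offsets are thousandths of a percent, so a 0..1 fraction scales by 100,000. A minimal standalone sketch of those conversions (the real `points_to_emu` is assumed to round the same way):

```rust
// Assumed OOXML unit conversions used when emitting picture XML:
// 1 point = 12_700 EMU, and <a:srcRect> crops are thousandths of a percent.
fn points_to_emu(points: f64) -> i64 {
    (points * 12_700.0).round() as i64
}

fn crop_to_src_rect_units(fraction: f64) -> u32 {
    (fraction * 100_000.0).round() as u32
}

fn main() {
    assert_eq!(points_to_emu(72.0), 914_400); // one inch
    assert_eq!(crop_to_src_rect_units(0.25), 25_000);
    assert_eq!(crop_to_src_rect_units(1.0), 100_000);
}
```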
fn slide_image_assets(
slide: &PresentationSlide,
next_media_index: &mut usize,
) -> Vec<SlideImageAsset> {
let mut ordered = slide.elements.iter().collect::<Vec<_>>();
ordered.sort_by_key(|element| element.z_order());
let shape_count = ordered
.iter()
.filter(|element| {
matches!(
element,
PresentationElement::Text(_)
| PresentationElement::Shape(_)
| PresentationElement::Image(ImageElement { payload: None, .. })
)
})
.count()
+ usize::from(slide.background_fill.is_some());
let mut image_index = 0_usize;
let mut assets = Vec::new();
for element in ordered {
let PresentationElement::Image(image) = element else {
continue;
};
let Some(payload) = &image.payload else {
continue;
};
let (left, top, width, height, fitted_crop) = if image.fit_mode != ImageFitMode::Stretch {
fit_image(image)
} else {
(
image.frame.left,
image.frame.top,
image.frame.width,
image.frame.height,
None,
)
};
image_index += 1;
let relationship_id = format!("rIdImage{image_index}");
let extension = normalized_image_extension(&payload.format);
let media_name = format!("image{next_media_index}.{extension}");
*next_media_index += 1;
assets.push(SlideImageAsset {
xml: image_picture_xml(
image,
20 + shape_count + image_index - 1,
&relationship_id,
Rect {
left,
top,
width,
height,
},
image.crop.or(fitted_crop),
),
relationship_xml: image_relationship_xml(
&relationship_id,
&format!("../media/{media_name}"),
),
media_path: format!("ppt/media/{media_name}"),
media_bytes: payload.bytes.clone(),
extension,
});
}
assets
}
fn patch_pptx_package(
source_bytes: Vec<u8>,
document: &PresentationDocument,
) -> Result<Vec<u8>, String> {
let mut archive =
ZipArchive::new(Cursor::new(source_bytes)).map_err(|error| error.to_string())?;
let mut writer = ZipWriter::new(Cursor::new(Vec::new()));
let mut next_media_index = 1_usize;
let mut pending_slide_relationships = HashMap::new();
let mut pending_slide_images = HashMap::new();
let mut pending_media = Vec::new();
let mut image_extensions = BTreeSet::new();
for (slide_index, slide) in document.slides.iter().enumerate() {
let slide_number = slide_index + 1;
let images = slide_image_assets(slide, &mut next_media_index);
let mut relationships = slide_hyperlink_relationships(slide);
relationships.extend(images.iter().map(|image| image.relationship_xml.clone()));
if !relationships.is_empty() {
pending_slide_relationships.insert(slide_number, relationships);
}
if !images.is_empty() {
image_extensions.extend(images.iter().map(|image| image.extension.clone()));
pending_media.extend(
images
.iter()
.map(|image| (image.media_path.clone(), image.media_bytes.clone())),
);
pending_slide_images.insert(slide_number, images);
}
}
for index in 0..archive.len() {
let mut file = archive.by_index(index).map_err(|error| error.to_string())?;
if file.is_dir() {
continue;
}
let name = file.name().to_string();
if name == CODEX_METADATA_ENTRY {
continue;
}
let options = file.options();
let mut bytes = Vec::new();
file.read_to_end(&mut bytes)
.map_err(|error| error.to_string())?;
writer
.start_file(&name, options)
.map_err(|error| error.to_string())?;
if name == "[Content_Types].xml" {
writer
.write_all(update_content_types_xml(bytes, &image_extensions)?.as_bytes())
.map_err(|error| error.to_string())?;
continue;
}
if name == "ppt/presentation.xml" {
writer
.write_all(
update_presentation_xml_dimensions(bytes, document.slide_size)?.as_bytes(),
)
.map_err(|error| error.to_string())?;
continue;
}
if let Some(slide_number) = parse_slide_xml_path(&name) {
writer
.write_all(
update_slide_xml(
bytes,
&document.slides[slide_number - 1],
pending_slide_images
.get(&slide_number)
.map(std::vec::Vec::as_slice)
.unwrap_or(&[]),
)?
.as_bytes(),
)
.map_err(|error| error.to_string())?;
continue;
}
if let Some(slide_number) = parse_slide_relationships_path(&name)
&& let Some(relationships) = pending_slide_relationships.remove(&slide_number)
{
writer
.write_all(update_slide_relationships_xml(bytes, &relationships)?.as_bytes())
.map_err(|error| error.to_string())?;
continue;
}
writer
.write_all(&bytes)
.map_err(|error| error.to_string())?;
}
for (slide_number, relationships) in pending_slide_relationships {
writer
.start_file(
format!("ppt/slides/_rels/slide{slide_number}.xml.rels"),
SimpleFileOptions::default(),
)
.map_err(|error| error.to_string())?;
writer
.write_all(slide_relationships_xml(&relationships).as_bytes())
.map_err(|error| error.to_string())?;
}
for (path, bytes) in pending_media {
writer
.start_file(path, SimpleFileOptions::default())
.map_err(|error| error.to_string())?;
writer
.write_all(&bytes)
.map_err(|error| error.to_string())?;
}
writer
.start_file(CODEX_METADATA_ENTRY, SimpleFileOptions::default())
.map_err(|error| error.to_string())?;
writer
.write_all(
&serde_json::to_vec(document).map_err(|error| error.to_string())?,
)
.map_err(|error| error.to_string())?;
writer
.finish()
.map_err(|error| error.to_string())
.map(Cursor::into_inner)
}
fn update_presentation_xml_dimensions(
existing_bytes: Vec<u8>,
slide_size: Rect,
) -> Result<String, String> {
let existing = String::from_utf8(existing_bytes).map_err(|error| error.to_string())?;
let updated = replace_self_closing_xml_tag(
&existing,
"p:sldSz",
&format!(
r#"<p:sldSz cx="{}" cy="{}" type="screen4x3"/>"#,
points_to_emu(slide_size.width),
points_to_emu(slide_size.height)
),
)?;
replace_self_closing_xml_tag(
&updated,
"p:notesSz",
&format!(
r#"<p:notesSz cx="{}" cy="{}"/>"#,
points_to_emu(slide_size.height),
points_to_emu(slide_size.width)
),
)
}
fn replace_self_closing_xml_tag(xml: &str, tag: &str, replacement: &str) -> Result<String, String> {
let start = xml
.find(&format!("<{tag} "))
.ok_or_else(|| format!("presentation xml is missing `<{tag} .../>`"))?;
let end = xml[start..]
.find("/>")
.map(|offset| start + offset + 2)
.ok_or_else(|| format!("presentation xml tag `{tag}` is not self-closing"))?;
Ok(format!("{}{replacement}{}", &xml[..start], &xml[end..]))
}
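`replace_self_closing_xml_tag` drives both dimension rewrites above: it locates the first `<tag …/>` and splices in the replacement. A self-contained sketch of the same approach, with simplified error strings:

```rust
// Sketch: replace the first self-closing `<tag .../>` occurrence in an XML string.
fn replace_self_closing_xml_tag(xml: &str, tag: &str, replacement: &str) -> Result<String, String> {
    let start = xml
        .find(&format!("<{tag} "))
        .ok_or_else(|| format!("missing `<{tag} .../>`"))?;
    let end = xml[start..]
        .find("/>")
        .map(|offset| start + offset + 2)
        .ok_or_else(|| format!("tag `{tag}` is not self-closing"))?;
    Ok(format!("{}{replacement}{}", &xml[..start], &xml[end..]))
}

fn main() {
    let xml = r#"<p:presentation><p:sldSz cx="1" cy="2"/></p:presentation>"#;
    let updated =
        replace_self_closing_xml_tag(xml, "p:sldSz", r#"<p:sldSz cx="9144000" cy="6858000"/>"#)
            .unwrap();
    assert!(updated.contains(r#"cx="9144000""#));
    assert!(!updated.contains(r#"cx="1""#));
}
```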
fn slide_hyperlink_relationships(slide: &PresentationSlide) -> Vec<String> {
let mut ordered = slide.elements.iter().collect::<Vec<_>>();
ordered.sort_by_key(|element| element.z_order());
let mut hyperlink_index = 1_u32;
let mut relationships = Vec::new();
for element in ordered {
let Some(hyperlink) = (match element {
PresentationElement::Text(text) => text.hyperlink.as_ref(),
PresentationElement::Shape(shape) => shape.hyperlink.as_ref(),
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => None,
}) else {
continue;
};
let relationship_id = format!("rIdHyperlink{hyperlink_index}");
hyperlink_index += 1;
relationships.push(hyperlink.relationship_xml(&relationship_id));
}
relationships
}
fn parse_slide_relationships_path(path: &str) -> Option<usize> {
path.strip_prefix("ppt/slides/_rels/slide")?
.strip_suffix(".xml.rels")?
.parse::<usize>()
.ok()
}
fn parse_slide_xml_path(path: &str) -> Option<usize> {
path.strip_prefix("ppt/slides/slide")?
.strip_suffix(".xml")?
.parse::<usize>()
.ok()
}
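The two path parsers above can be exercised on their own; this sketch reproduces `parse_slide_xml_path` to show which archive entries match (note that `.xml.rels` entries fall through, so slide XML and relationship files never collide):

```rust
// Sketch: extract the slide number from a `ppt/slides/slideN.xml` archive path.
fn parse_slide_xml_path(path: &str) -> Option<usize> {
    path.strip_prefix("ppt/slides/slide")?
        .strip_suffix(".xml")?
        .parse::<usize>()
        .ok()
}

fn main() {
    assert_eq!(parse_slide_xml_path("ppt/slides/slide3.xml"), Some(3));
    // Relationship files do not end in bare `.xml`, so they are skipped.
    assert_eq!(parse_slide_xml_path("ppt/slides/_rels/slide3.xml.rels"), None);
    // Layouts live under a different prefix.
    assert_eq!(parse_slide_xml_path("ppt/slideLayouts/slideLayout1.xml"), None);
}
```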
fn update_slide_relationships_xml(
existing_bytes: Vec<u8>,
relationships: &[String],
) -> Result<String, String> {
let existing = String::from_utf8(existing_bytes).map_err(|error| error.to_string())?;
let injected = relationships.join("\n");
existing
.contains("</Relationships>")
.then(|| existing.replace("</Relationships>", &format!("{injected}\n</Relationships>")))
.ok_or_else(|| {
"slide relationships xml is missing a closing `</Relationships>`".to_string()
})
}
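Several of the patchers here share one pattern: splice new XML immediately before a known closing tag, and fail loudly if that tag is absent. A generic sketch of the pattern (the helper name is hypothetical, not part of the codebase):

```rust
// Hypothetical generalization of the "inject before the closing tag" pattern
// used for relationships, content types, and the slide shape tree.
fn inject_before_closing(xml: &str, closing: &str, injected: &str) -> Result<String, String> {
    xml.contains(closing)
        .then(|| xml.replace(closing, &format!("{injected}\n{closing}")))
        .ok_or_else(|| format!("xml is missing a closing `{closing}`"))
}

fn main() {
    let xml = "<Relationships></Relationships>";
    let updated =
        inject_before_closing(xml, "</Relationships>", r#"<Relationship Id="rId1"/>"#).unwrap();
    assert!(updated.contains("<Relationship Id=\"rId1\"/>\n</Relationships>"));
    assert!(inject_before_closing("<a/>", "</Types>", "x").is_err());
}
```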
fn slide_relationships_xml(relationships: &[String]) -> String {
let body = relationships.join("\n");
format!(
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">
{body}
</Relationships>"#
)
}
fn update_content_types_xml(
existing_bytes: Vec<u8>,
image_extensions: &BTreeSet<String>,
) -> Result<String, String> {
let existing = String::from_utf8(existing_bytes).map_err(|error| error.to_string())?;
if image_extensions.is_empty() {
return Ok(existing);
}
let existing_lower = existing.to_ascii_lowercase();
let additions = image_extensions
.iter()
.filter(|extension| {
!existing_lower.contains(&format!(
r#"extension="{}""#,
extension.to_ascii_lowercase()
))
})
.map(|extension| generate_image_content_type(extension))
.collect::<Vec<_>>();
if additions.is_empty() {
return Ok(existing);
}
existing
.contains("</Types>")
.then(|| existing.replace("</Types>", &format!("{}\n</Types>", additions.join("\n"))))
.ok_or_else(|| "content types xml is missing a closing `</Types>`".to_string())
}
fn update_slide_xml(
existing_bytes: Vec<u8>,
slide: &PresentationSlide,
slide_images: &[SlideImageAsset],
) -> Result<String, String> {
let existing = String::from_utf8(existing_bytes).map_err(|error| error.to_string())?;
let existing = replace_image_placeholders(existing, slide_images)?;
let existing = apply_shape_block_patches(existing, slide)?;
let table_xml = slide_table_xml(slide);
if table_xml.is_empty() {
return Ok(existing);
}
existing
.contains("</p:spTree>")
.then(|| existing.replace("</p:spTree>", &format!("{table_xml}\n</p:spTree>")))
.ok_or_else(|| "slide xml is missing a closing `</p:spTree>`".to_string())
}
fn replace_image_placeholders(
existing: String,
slide_images: &[SlideImageAsset],
) -> Result<String, String> {
if slide_images.is_empty() {
return Ok(existing);
}
let mut updated = String::with_capacity(existing.len());
let mut remaining = existing.as_str();
for image in slide_images {
let marker = remaining
.find("name=\"Image Placeholder: ")
.ok_or_else(|| {
"slide xml is missing an image placeholder block for exported images".to_string()
})?;
let start = remaining[..marker].rfind("<p:sp>").ok_or_else(|| {
"slide xml is missing an opening `<p:sp>` for image placeholder".to_string()
})?;
let end = remaining[marker..]
.find("</p:sp>")
.map(|offset| marker + offset + "</p:sp>".len())
.ok_or_else(|| {
"slide xml is missing a closing `</p:sp>` for image placeholder".to_string()
})?;
updated.push_str(&remaining[..start]);
updated.push_str(&image.xml);
remaining = &remaining[end..];
}
updated.push_str(remaining);
Ok(updated)
}
#[derive(Clone, Copy)]
struct ShapeXmlPatch {
line_style: Option<LineStyle>,
flip_horizontal: bool,
flip_vertical: bool,
}
fn apply_shape_block_patches(
existing: String,
slide: &PresentationSlide,
) -> Result<String, String> {
let mut patches = Vec::new();
if slide.background_fill.is_some() {
patches.push(None);
}
let mut ordered = slide.elements.iter().collect::<Vec<_>>();
ordered.sort_by_key(|element| element.z_order());
for element in ordered {
match element {
PresentationElement::Text(_) => patches.push(None),
PresentationElement::Shape(shape) => patches.push(Some(ShapeXmlPatch {
line_style: shape
.stroke
.as_ref()
.map(|stroke| stroke.style)
.filter(|style| *style != LineStyle::Solid),
flip_horizontal: shape.flip_horizontal,
flip_vertical: shape.flip_vertical,
})),
PresentationElement::Image(ImageElement { payload: None, .. }) => patches.push(None),
PresentationElement::Connector(_)
| PresentationElement::Image(_)
| PresentationElement::Table(_)
| PresentationElement::Chart(_) => {}
}
}
if patches.iter().all(|patch| {
patch.is_none_or(|patch| {
patch.line_style.is_none() && !patch.flip_horizontal && !patch.flip_vertical
})
}) {
return Ok(existing);
}
let mut updated = String::with_capacity(existing.len());
let mut remaining = existing.as_str();
for patch in patches {
let Some(start) = remaining.find("<p:sp>") else {
return Err("slide xml is missing an expected `<p:sp>` block".to_string());
};
let end = remaining[start..]
.find("</p:sp>")
.map(|offset| start + offset + "</p:sp>".len())
.ok_or_else(|| "slide xml is missing a closing `</p:sp>` block".to_string())?;
updated.push_str(&remaining[..start]);
let block = &remaining[start..end];
if let Some(patch) = patch {
updated.push_str(&patch_shape_block(block, patch)?);
} else {
updated.push_str(block);
}
remaining = &remaining[end..];
}
updated.push_str(remaining);
Ok(updated)
}
fn patch_shape_block(block: &str, patch: ShapeXmlPatch) -> Result<String, String> {
let block = if let Some(line_style) = patch.line_style {
patch_shape_block_dash(block, line_style)?
} else {
block.to_string()
};
if patch.flip_horizontal || patch.flip_vertical {
patch_shape_block_flip(&block, patch.flip_horizontal, patch.flip_vertical)
} else {
Ok(block)
}
}
fn patch_shape_block_dash(block: &str, line_style: LineStyle) -> Result<String, String> {
let Some(line_start) = block.find("<a:ln") else {
return Err("shape block is missing an `<a:ln>` entry for stroke styling".to_string());
};
if let Some(dash_start) = block[line_start..].find("<a:prstDash") {
let dash_start = line_start + dash_start;
let dash_end = block[dash_start..]
.find("/>")
.map(|offset| dash_start + offset + 2)
.ok_or_else(|| "shape line dash entry is missing a closing `/>`".to_string())?;
let mut patched = String::with_capacity(block.len() + 32);
patched.push_str(&block[..dash_start]);
patched.push_str(&format!(
r#"<a:prstDash val="{}"/>"#,
line_style.to_ppt_xml()
));
patched.push_str(&block[dash_end..]);
return Ok(patched);
}
if let Some(line_end) = block[line_start..].find("</a:ln>") {
let line_end = line_start + line_end;
let mut patched = String::with_capacity(block.len() + 32);
patched.push_str(&block[..line_end]);
patched.push_str(&format!(
r#"<a:prstDash val="{}"/>"#,
line_style.to_ppt_xml()
));
patched.push_str(&block[line_end..]);
return Ok(patched);
}
let line_end = block[line_start..]
.find("/>")
.map(|offset| line_start + offset + 2)
.ok_or_else(|| "shape line entry is missing a closing marker".to_string())?;
let line_tag = &block[line_start..line_end - 2];
let mut patched = String::with_capacity(block.len() + 48);
patched.push_str(&block[..line_start]);
patched.push_str(line_tag);
patched.push('>');
patched.push_str(&format!(
r#"<a:prstDash val="{}"/>"#,
line_style.to_ppt_xml()
));
patched.push_str("</a:ln>");
patched.push_str(&block[line_end..]);
Ok(patched)
}
fn patch_shape_block_flip(
block: &str,
flip_horizontal: bool,
flip_vertical: bool,
) -> Result<String, String> {
let Some(xfrm_start) = block.find("<a:xfrm") else {
return Err("shape block is missing an `<a:xfrm>` entry for flip styling".to_string());
};
let tag_end = block[xfrm_start..]
.find('>')
.map(|offset| xfrm_start + offset)
.ok_or_else(|| "shape transform entry is missing a closing `>`".to_string())?;
let tag = &block[xfrm_start..=tag_end];
let mut patched_tag = tag.to_string();
patched_tag = upsert_xml_attribute(
&patched_tag,
"flipH",
if flip_horizontal { "1" } else { "0" },
);
patched_tag =
upsert_xml_attribute(&patched_tag, "flipV", if flip_vertical { "1" } else { "0" });
Ok(format!(
"{}{}{}",
&block[..xfrm_start],
patched_tag,
&block[tag_end + 1..]
))
}
fn upsert_xml_attribute(tag: &str, attribute: &str, value: &str) -> String {
let needle = format!(r#"{attribute}=""#);
if let Some(start) = tag.find(&needle) {
let value_start = start + needle.len();
if let Some(end_offset) = tag[value_start..].find('"') {
let end = value_start + end_offset;
return format!("{}{}{}", &tag[..value_start], value, &tag[end..]);
}
}
let insert_at = tag.len() - 1;
format!(r#"{} {attribute}="{value}""#, &tag[..insert_at]) + &tag[insert_at..]
}
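`upsert_xml_attribute` either rewrites an existing attribute value in place or appends the attribute just before the closing `>`. Reproduced as a standalone sketch:

```rust
// Sketch: set `attribute="value"` on an XML start tag, updating it if present
// or inserting it before the trailing `>` otherwise.
fn upsert_xml_attribute(tag: &str, attribute: &str, value: &str) -> String {
    let needle = format!(r#"{attribute}=""#);
    if let Some(start) = tag.find(&needle) {
        let value_start = start + needle.len();
        if let Some(end_offset) = tag[value_start..].find('"') {
            let end = value_start + end_offset;
            return format!("{}{}{}", &tag[..value_start], value, &tag[end..]);
        }
    }
    // No existing attribute: insert before the final `>` character.
    let insert_at = tag.len() - 1;
    format!(r#"{} {attribute}="{value}""#, &tag[..insert_at]) + &tag[insert_at..]
}

fn main() {
    assert_eq!(upsert_xml_attribute("<a:xfrm>", "flipH", "1"), r#"<a:xfrm flipH="1">"#);
    assert_eq!(
        upsert_xml_attribute(r#"<a:xfrm flipH="0">"#, "flipH", "1"),
        r#"<a:xfrm flipH="1">"#
    );
}
```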
fn slide_table_xml(slide: &PresentationSlide) -> String {
let mut ordered = slide.elements.iter().collect::<Vec<_>>();
ordered.sort_by_key(|element| element.z_order());
let mut table_index = 0_usize;
ordered
.into_iter()
.filter_map(|element| {
let PresentationElement::Table(table) = element else {
return None;
};
table_index += 1;
let rows = table
.rows
.clone()
.into_iter()
.enumerate()
.map(|(row_index, row)| {
let cells = row
.into_iter()
.enumerate()
.map(|(column_index, cell)| {
build_table_cell(cell, &table.merges, row_index, column_index)
})
.collect::<Vec<_>>();
let mut table_row = TableRow::new(cells);
if let Some(height) = table.row_heights.get(row_index) {
table_row = table_row.with_height(points_to_emu(*height));
}
table_row
})
.collect::<Vec<_>>();
Some(ppt_rs::generator::table::generate_table_xml(
&ppt_rs::generator::table::Table::new(
rows,
table
.column_widths
.iter()
.copied()
.map(points_to_emu)
.collect(),
points_to_emu(table.frame.left),
points_to_emu(table.frame.top),
),
300 + table_index,
))
})
.collect::<Vec<_>>()
.join("\n")
}
pub(crate) fn write_preview_image_bytes(
png_bytes: &[u8],
target_path: &Path,
format: PreviewOutputFormat,
scale: f32,
quality: u8,
action: &str,
) -> Result<(), PresentationArtifactError> {
if matches!(format, PreviewOutputFormat::Png) && scale == 1.0 {
std::fs::write(target_path, png_bytes).map_err(|error| {
PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: error.to_string(),
}
})?;
return Ok(());
}
let mut preview = image::load_from_memory(png_bytes).map_err(|error| {
PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: format!("{action}: {error}"),
}
})?;
if scale != 1.0 {
let width = (preview.width() as f32 * scale).round().max(1.0) as u32;
let height = (preview.height() as f32 * scale).round().max(1.0) as u32;
preview = preview.resize_exact(width, height, FilterType::Lanczos3);
}
let file = std::fs::File::create(target_path).map_err(|error| {
PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: error.to_string(),
}
})?;
let mut writer = std::io::BufWriter::new(file);
match format {
PreviewOutputFormat::Png => {
preview
.write_to(&mut writer, ImageFormat::Png)
.map_err(|error| PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: format!("{action}: {error}"),
})?
}
PreviewOutputFormat::Jpeg => {
let rgb = preview.to_rgb8();
let mut encoder = JpegEncoder::new_with_quality(&mut writer, quality);
encoder.encode_image(&rgb).map_err(|error| {
PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: format!("{action}: {error}"),
}
})?;
}
PreviewOutputFormat::Svg => {
let mut png_bytes = Cursor::new(Vec::new());
preview
.write_to(&mut png_bytes, ImageFormat::Png)
.map_err(|error| PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: format!("{action}: {error}"),
})?;
let embedded_png = BASE64_STANDARD.encode(png_bytes.into_inner());
let svg = format!(
r#"<svg xmlns="http://www.w3.org/2000/svg" width="{}" height="{}" viewBox="0 0 {} {}"><image href="data:image/png;base64,{embedded_png}" width="{}" height="{}"/></svg>"#,
preview.width(),
preview.height(),
preview.width(),
preview.height(),
preview.width(),
preview.height(),
);
writer.write_all(svg.as_bytes()).map_err(|error| {
PresentationArtifactError::ExportFailed {
path: target_path.to_path_buf(),
message: format!("{action}: {error}"),
}
})?;
}
}
Ok(())
}
fn parse_preview_output_format(
format: Option<&str>,
path: &Path,
action: &str,
) -> Result<PreviewOutputFormat, PresentationArtifactError> {
let value = format
.map(str::to_owned)
.or_else(|| {
path.extension()
.and_then(|extension| extension.to_str())
.map(str::to_owned)
})
.unwrap_or_else(|| "png".to_string());
match value.to_ascii_lowercase().as_str() {
"png" => Ok(PreviewOutputFormat::Png),
"jpg" | "jpeg" => Ok(PreviewOutputFormat::Jpeg),
"svg" => Ok(PreviewOutputFormat::Svg),
other => Err(PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("preview format `{other}` is not supported"),
}),
}
}
fn normalize_preview_scale(
scale: Option<f32>,
action: &str,
) -> Result<f32, PresentationArtifactError> {
let scale = scale.unwrap_or(1.0);
if !scale.is_finite() || scale <= 0.0 {
return Err(PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: "`scale` must be a positive number".to_string(),
});
}
Ok(scale)
}
fn normalize_preview_quality(
quality: Option<u8>,
action: &str,
) -> Result<u8, PresentationArtifactError> {
let quality = quality.unwrap_or(90);
if quality == 0 || quality > 100 {
return Err(PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: "`quality` must be between 1 and 100".to_string(),
});
}
Ok(quality)
}
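The two normalizers above validate optional caller inputs against fixed ranges, with a default when the value is omitted. A simplified sketch using `String` errors in place of `PresentationArtifactError`:

```rust
// Sketch of the preview-option validation, with String errors standing in
// for PresentationArtifactError::InvalidArgs.
fn normalize_preview_scale(scale: Option<f32>) -> Result<f32, String> {
    let scale = scale.unwrap_or(1.0);
    if !scale.is_finite() || scale <= 0.0 {
        return Err("`scale` must be a positive number".to_string());
    }
    Ok(scale)
}

fn normalize_preview_quality(quality: Option<u8>) -> Result<u8, String> {
    let quality = quality.unwrap_or(90);
    if quality == 0 || quality > 100 {
        return Err("`quality` must be between 1 and 100".to_string());
    }
    Ok(quality)
}

fn main() {
    assert_eq!(normalize_preview_scale(None).unwrap(), 1.0);
    assert!(normalize_preview_scale(Some(0.0)).is_err());
    assert!(normalize_preview_scale(Some(f32::NAN)).is_err());
    assert_eq!(normalize_preview_quality(None).unwrap(), 90);
    assert!(normalize_preview_quality(Some(101)).is_err());
}
```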


@@ -0,0 +1,614 @@
fn document_to_proto(
document: &PresentationDocument,
action: &str,
) -> Result<Value, PresentationArtifactError> {
let layouts = document
.layouts
.iter()
.map(|layout| layout_to_proto(document, layout, action))
.collect::<Result<Vec<_>, _>>()?;
let slides = document
.slides
.iter()
.enumerate()
.map(|(slide_index, slide)| slide_to_proto(slide, slide_index))
.collect::<Vec<_>>();
Ok(serde_json::json!({
"kind": "presentation",
"artifactId": document.artifact_id,
"anchor": format!("pr/{}", document.artifact_id),
"name": document.name,
"slideSize": rect_to_proto(document.slide_size),
"activeSlideIndex": document.active_slide_index,
"activeSlideId": document
.active_slide_index
.and_then(|index| document.slides.get(index))
.map(|slide| slide.slide_id.clone()),
"theme": serde_json::json!({
"colorScheme": document.theme.color_scheme,
"hexColorMap": document.theme.color_scheme,
"majorFont": document.theme.major_font,
"minorFont": document.theme.minor_font,
}),
"styles": document
.named_text_styles()
.iter()
.map(|style| named_text_style_to_json(style, "st"))
.collect::<Vec<_>>(),
"masters": document
.layouts
.iter()
.filter(|layout| layout.kind == LayoutKind::Master)
.map(|layout| layout.layout_id.clone())
.collect::<Vec<_>>(),
"layouts": layouts,
"slides": slides,
"commentAuthor": document.comment_self.as_ref().map(comment_author_to_proto),
"commentThreads": document
.comment_threads
.iter()
.map(comment_thread_to_proto)
.collect::<Vec<_>>(),
}))
}
fn layout_to_proto(
document: &PresentationDocument,
layout: &LayoutDocument,
action: &str,
) -> Result<Value, PresentationArtifactError> {
let placeholders = layout
.placeholders
.iter()
.map(placeholder_definition_to_proto)
.collect::<Vec<_>>();
let resolved_placeholders = resolved_layout_placeholders(document, &layout.layout_id, action)?
.into_iter()
.map(|placeholder| {
let mut value = placeholder_definition_to_proto(&placeholder.definition);
value["sourceLayoutId"] = Value::String(placeholder.source_layout_id);
value
})
.collect::<Vec<_>>();
Ok(serde_json::json!({
"layoutId": layout.layout_id,
"anchor": format!("ly/{}", layout.layout_id),
"name": layout.name,
"kind": match layout.kind {
LayoutKind::Layout => "layout",
LayoutKind::Master => "master",
},
"parentLayoutId": layout.parent_layout_id,
"placeholders": placeholders,
"resolvedPlaceholders": resolved_placeholders,
}))
}
fn placeholder_definition_to_proto(placeholder: &PlaceholderDefinition) -> Value {
serde_json::json!({
"name": placeholder.name,
"placeholderType": placeholder.placeholder_type,
"index": placeholder.index,
"text": placeholder.text,
"geometry": format!("{:?}", placeholder.geometry),
"frame": rect_to_proto(placeholder.frame),
})
}
fn slide_to_proto(slide: &PresentationSlide, slide_index: usize) -> Value {
serde_json::json!({
"slideId": slide.slide_id,
"anchor": format!("sl/{}", slide.slide_id),
"index": slide_index,
"layoutId": slide.layout_id,
"backgroundFill": slide.background_fill,
"notes": serde_json::json!({
"anchor": format!("nt/{}", slide.slide_id),
"text": slide.notes.text,
"visible": slide.notes.visible,
"textPreview": slide.notes.text.replace('\n', " | "),
"textChars": slide.notes.text.chars().count(),
"textLines": slide.notes.text.lines().count(),
"richText": rich_text_to_proto(&slide.notes.text, &slide.notes.rich_text),
}),
"elements": slide.elements.iter().map(element_to_proto).collect::<Vec<_>>(),
})
}
fn element_to_proto(element: &PresentationElement) -> Value {
match element {
PresentationElement::Text(text) => {
let mut record = serde_json::json!({
"kind": "text",
"elementId": text.element_id,
"anchor": format!("sh/{}", text.element_id),
"frame": rect_to_proto(text.frame),
"text": text.text,
"textPreview": text.text.replace('\n', " | "),
"textChars": text.text.chars().count(),
"textLines": text.text.lines().count(),
"fill": text.fill,
"style": text_style_to_proto(&text.style),
"richText": rich_text_to_proto(&text.text, &text.rich_text),
"zOrder": text.z_order,
});
if let Some(placeholder) = &text.placeholder {
record["placeholder"] = placeholder_ref_to_proto(placeholder);
}
if let Some(hyperlink) = &text.hyperlink {
record["hyperlink"] = hyperlink.to_json();
}
record
}
PresentationElement::Shape(shape) => {
let mut record = serde_json::json!({
"kind": "shape",
"elementId": shape.element_id,
"anchor": format!("sh/{}", shape.element_id),
"geometry": format!("{:?}", shape.geometry),
"frame": rect_to_proto(shape.frame),
"fill": shape.fill,
"stroke": shape.stroke.as_ref().map(stroke_to_proto),
"text": shape.text,
"textStyle": text_style_to_proto(&shape.text_style),
"richText": shape
.text
.as_ref()
.zip(shape.rich_text.as_ref())
.map(|(text, rich_text)| rich_text_to_proto(text, rich_text))
.unwrap_or(Value::Null),
"rotation": shape.rotation_degrees,
"flipHorizontal": shape.flip_horizontal,
"flipVertical": shape.flip_vertical,
"zOrder": shape.z_order,
});
if let Some(text) = &shape.text {
record["textPreview"] = Value::String(text.replace('\n', " | "));
record["textChars"] = Value::from(text.chars().count());
record["textLines"] = Value::from(text.lines().count());
}
if let Some(placeholder) = &shape.placeholder {
record["placeholder"] = placeholder_ref_to_proto(placeholder);
}
if let Some(hyperlink) = &shape.hyperlink {
record["hyperlink"] = hyperlink.to_json();
}
record
}
PresentationElement::Connector(connector) => serde_json::json!({
"kind": "connector",
"elementId": connector.element_id,
"anchor": format!("cn/{}", connector.element_id),
"connectorType": format!("{:?}", connector.connector_type),
"start": serde_json::json!({
"left": connector.start.left,
"top": connector.start.top,
"unit": "points",
}),
"end": serde_json::json!({
"left": connector.end.left,
"top": connector.end.top,
"unit": "points",
}),
"line": stroke_to_proto(&connector.line),
"lineStyle": connector.line_style.as_api_str(),
"startArrow": format!("{:?}", connector.start_arrow),
"endArrow": format!("{:?}", connector.end_arrow),
"arrowSize": format!("{:?}", connector.arrow_size),
"label": connector.label,
"zOrder": connector.z_order,
}),
PresentationElement::Image(image) => {
let mut record = serde_json::json!({
"kind": "image",
"elementId": image.element_id,
"anchor": format!("im/{}", image.element_id),
"frame": rect_to_proto(image.frame),
"fit": format!("{:?}", image.fit_mode),
"crop": image.crop.map(|(left, top, right, bottom)| serde_json::json!({
"left": left,
"top": top,
"right": right,
"bottom": bottom,
})),
"rotation": image.rotation_degrees,
"flipHorizontal": image.flip_horizontal,
"flipVertical": image.flip_vertical,
"lockAspectRatio": image.lock_aspect_ratio,
"alt": image.alt_text,
"prompt": image.prompt,
"isPlaceholder": image.is_placeholder,
"payload": image.payload.as_ref().map(image_payload_to_proto),
"zOrder": image.z_order,
});
if let Some(placeholder) = &image.placeholder {
record["placeholder"] = placeholder_ref_to_proto(placeholder);
}
record
}
PresentationElement::Table(table) => serde_json::json!({
"kind": "table",
"elementId": table.element_id,
"anchor": format!("tb/{}", table.element_id),
"frame": rect_to_proto(table.frame),
"rows": table.rows.iter().map(|row| {
row.iter().map(table_cell_to_proto).collect::<Vec<_>>()
}).collect::<Vec<_>>(),
"columnWidths": table.column_widths,
"rowHeights": table.row_heights,
"style": table.style,
"styleOptions": table_style_options_to_proto(&table.style_options),
"borders": table.borders.as_ref().map(table_borders_to_proto),
"rightToLeft": table.right_to_left,
"merges": table.merges.iter().map(|merge| serde_json::json!({
"startRow": merge.start_row,
"endRow": merge.end_row,
"startColumn": merge.start_column,
"endColumn": merge.end_column,
})).collect::<Vec<_>>(),
"zOrder": table.z_order,
}),
PresentationElement::Chart(chart) => serde_json::json!({
"kind": "chart",
"elementId": chart.element_id,
"anchor": format!("ch/{}", chart.element_id),
"frame": rect_to_proto(chart.frame),
"chartType": format!("{:?}", chart.chart_type),
"title": chart.title,
"categories": chart.categories,
"styleIndex": chart.style_index,
"hasLegend": chart.has_legend,
"legend": chart.legend.as_ref().map(chart_legend_to_proto),
"xAxis": chart.x_axis.as_ref().map(chart_axis_to_proto),
"yAxis": chart.y_axis.as_ref().map(chart_axis_to_proto),
"dataLabels": chart.data_labels.as_ref().map(chart_data_labels_to_proto),
"chartFill": chart.chart_fill,
"plotAreaFill": chart.plot_area_fill,
"series": chart.series.iter().map(|series| serde_json::json!({
"name": series.name,
"values": series.values,
"categories": series.categories,
"xValues": series.x_values,
"fill": series.fill,
"stroke": series.stroke.as_ref().map(stroke_to_proto),
"marker": series.marker.as_ref().map(chart_marker_to_proto),
"dataLabelOverrides": series
.data_label_overrides
.iter()
.map(chart_data_label_override_to_proto)
.collect::<Vec<_>>(),
})).collect::<Vec<_>>(),
"zOrder": chart.z_order,
}),
}
}
fn rect_to_proto(rect: Rect) -> Value {
serde_json::json!({
"left": rect.left,
"top": rect.top,
"width": rect.width,
"height": rect.height,
"unit": "points",
})
}
fn stroke_to_proto(stroke: &StrokeStyle) -> Value {
serde_json::json!({
"color": stroke.color,
"width": stroke.width,
"style": stroke.style.as_api_str(),
"unit": "points",
})
}
fn text_style_to_proto(style: &TextStyle) -> Value {
serde_json::json!({
"styleName": style.style_name,
"fontSize": style.font_size,
"fontFamily": style.font_family,
"color": style.color,
"alignment": style.alignment,
"bold": style.bold,
"italic": style.italic,
"underline": style.underline,
})
}
fn rich_text_to_proto(text: &str, rich_text: &RichTextState) -> Value {
serde_json::json!({
"layout": text_layout_to_proto(&rich_text.layout),
"ranges": rich_text
.ranges
.iter()
.map(|range| text_range_to_proto(text, range))
.collect::<Vec<_>>(),
})
}
fn text_range_to_proto(text: &str, range: &TextRangeAnnotation) -> Value {
serde_json::json!({
"rangeId": range.range_id,
"anchor": format!("tr/{}", range.range_id),
"startCp": range.start_cp,
"length": range.length,
"text": text_slice_by_codepoint_range(text, range.start_cp, range.length),
"style": text_style_to_proto(&range.style),
"hyperlink": range.hyperlink.as_ref().map(HyperlinkState::to_json),
"spacingBefore": range.spacing_before,
"spacingAfter": range.spacing_after,
"lineSpacing": range.line_spacing,
})
}
fn text_layout_to_proto(layout: &TextLayoutState) -> Value {
serde_json::json!({
"insets": layout.insets.map(|insets| serde_json::json!({
"left": insets.left,
"right": insets.right,
"top": insets.top,
"bottom": insets.bottom,
"unit": "points",
})),
"wrap": layout.wrap.map(text_wrap_mode_to_proto),
"autoFit": layout.auto_fit.map(text_auto_fit_mode_to_proto),
"verticalAlignment": layout
.vertical_alignment
.map(text_vertical_alignment_to_proto),
})
}
fn text_wrap_mode_to_proto(mode: TextWrapMode) -> &'static str {
match mode {
TextWrapMode::Square => "square",
TextWrapMode::None => "none",
}
}
fn text_auto_fit_mode_to_proto(mode: TextAutoFitMode) -> &'static str {
match mode {
TextAutoFitMode::None => "none",
TextAutoFitMode::ShrinkText => "shrinkText",
TextAutoFitMode::ResizeShapeToFitText => "resizeShapeToFitText",
}
}
fn text_vertical_alignment_to_proto(alignment: TextVerticalAlignment) -> &'static str {
match alignment {
TextVerticalAlignment::Top => "top",
TextVerticalAlignment::Middle => "middle",
TextVerticalAlignment::Bottom => "bottom",
}
}
fn placeholder_ref_to_proto(placeholder: &PlaceholderRef) -> Value {
serde_json::json!({
"name": placeholder.name,
"placeholderType": placeholder.placeholder_type,
"index": placeholder.index,
})
}
fn image_payload_to_proto(payload: &ImagePayload) -> Value {
serde_json::json!({
"format": payload.format,
"widthPx": payload.width_px,
"heightPx": payload.height_px,
"bytesBase64": BASE64_STANDARD.encode(&payload.bytes),
})
}
fn table_cell_to_proto(cell: &TableCellSpec) -> Value {
serde_json::json!({
"text": cell.text,
"textStyle": text_style_to_proto(&cell.text_style),
"richText": rich_text_to_proto(&cell.text, &cell.rich_text),
"backgroundFill": cell.background_fill,
"alignment": cell.alignment,
"borders": cell.borders.as_ref().map(table_borders_to_proto),
})
}
fn table_style_options_to_proto(style_options: &TableStyleOptions) -> Value {
serde_json::json!({
"headerRow": style_options.header_row,
"bandedRows": style_options.banded_rows,
"bandedColumns": style_options.banded_columns,
"firstColumn": style_options.first_column,
"lastColumn": style_options.last_column,
"totalRow": style_options.total_row,
})
}
fn table_borders_to_proto(borders: &TableBorders) -> Value {
serde_json::json!({
"outside": borders.outside.as_ref().map(table_border_to_proto),
"inside": borders.inside.as_ref().map(table_border_to_proto),
"top": borders.top.as_ref().map(table_border_to_proto),
"bottom": borders.bottom.as_ref().map(table_border_to_proto),
"left": borders.left.as_ref().map(table_border_to_proto),
"right": borders.right.as_ref().map(table_border_to_proto),
})
}
fn table_border_to_proto(border: &TableBorder) -> Value {
serde_json::json!({
"color": border.color,
"width": border.width,
"unit": "points",
})
}
fn chart_marker_to_proto(marker: &ChartMarkerStyle) -> Value {
serde_json::json!({
"symbol": marker.symbol,
"size": marker.size,
})
}
fn chart_data_labels_to_proto(data_labels: &ChartDataLabels) -> Value {
serde_json::json!({
"showValue": data_labels.show_value,
"showCategoryName": data_labels.show_category_name,
"showLeaderLines": data_labels.show_leader_lines,
"position": data_labels.position,
"textStyle": text_style_to_proto(&data_labels.text_style),
})
}
fn chart_legend_to_proto(legend: &ChartLegend) -> Value {
serde_json::json!({
"position": legend.position,
"textStyle": text_style_to_proto(&legend.text_style),
})
}
fn chart_axis_to_proto(axis: &ChartAxisSpec) -> Value {
serde_json::json!({
"title": axis.title,
})
}
fn chart_data_label_override_to_proto(override_spec: &ChartDataLabelOverride) -> Value {
serde_json::json!({
"idx": override_spec.idx,
"text": override_spec.text,
"position": override_spec.position,
"textStyle": text_style_to_proto(&override_spec.text_style),
"fill": override_spec.fill,
"stroke": override_spec.stroke.as_ref().map(stroke_to_proto),
})
}
fn comment_author_to_proto(author: &CommentAuthorProfile) -> Value {
serde_json::json!({
"displayName": author.display_name,
"initials": author.initials,
"email": author.email,
})
}
fn comment_thread_to_proto(thread: &CommentThread) -> Value {
serde_json::json!({
"kind": "comment",
"threadId": thread.thread_id,
"anchor": format!("th/{}", thread.thread_id),
"target": comment_target_to_proto(&thread.target),
"position": thread.position.as_ref().map(comment_position_to_proto),
"status": comment_status_to_proto(thread.status),
"messages": thread.messages.iter().map(comment_message_to_proto).collect::<Vec<_>>(),
})
}
fn comment_target_to_proto(target: &CommentTarget) -> Value {
match target {
CommentTarget::Slide { slide_id } => serde_json::json!({
"type": "slide",
"slideId": slide_id,
"slideAnchor": format!("sl/{slide_id}"),
}),
CommentTarget::Element {
slide_id,
element_id,
} => serde_json::json!({
"type": "element",
"slideId": slide_id,
"slideAnchor": format!("sl/{slide_id}"),
"elementId": element_id,
"elementAnchor": format!("sh/{element_id}"),
}),
CommentTarget::TextRange {
slide_id,
element_id,
start_cp,
length,
context,
} => serde_json::json!({
"type": "textRange",
"slideId": slide_id,
"slideAnchor": format!("sl/{slide_id}"),
"elementId": element_id,
"elementAnchor": format!("sh/{element_id}"),
"startCp": start_cp,
"length": length,
"context": context,
}),
}
}
fn comment_position_to_proto(position: &CommentPosition) -> Value {
serde_json::json!({
"x": position.x,
"y": position.y,
"unit": "points",
})
}
fn comment_message_to_proto(message: &CommentMessage) -> Value {
serde_json::json!({
"messageId": message.message_id,
"author": comment_author_to_proto(&message.author),
"text": message.text,
"createdAt": message.created_at,
"reactions": message.reactions,
})
}
fn comment_status_to_proto(status: CommentThreadStatus) -> &'static str {
match status {
CommentThreadStatus::Active => "active",
CommentThreadStatus::Resolved => "resolved",
}
}
fn text_slice_by_codepoint_range(text: &str, start_cp: usize, length: usize) -> String {
text.chars().skip(start_cp).take(length).collect()
}
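`text_slice_by_codepoint_range` slices by Unicode scalar values rather than bytes, so multi-byte characters are never split (byte indexing could panic mid-character). A standalone check of that behavior, reusing the function verbatim:

```rust
fn text_slice_by_codepoint_range(text: &str, start_cp: usize, length: usize) -> String {
    text.chars().skip(start_cp).take(length).collect()
}

fn main() {
    // "héllo" is 6 bytes but 5 codepoints; a byte slice at index 2 would split `é`.
    assert_eq!(text_slice_by_codepoint_range("héllo", 1, 3), "éll");
    // Out-of-range starts yield an empty string instead of panicking.
    assert_eq!(text_slice_by_codepoint_range("abc", 5, 2), "");
    println!("ok");
}
```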
fn build_table_cell(
cell: TableCellSpec,
merges: &[TableMergeRegion],
row_index: usize,
column_index: usize,
) -> TableCell {
let mut table_cell = TableCell::new(&cell.text);
if cell.text_style.bold {
table_cell = table_cell.bold();
}
if cell.text_style.italic {
table_cell = table_cell.italic();
}
if cell.text_style.underline {
table_cell = table_cell.underline();
}
if let Some(color) = cell.text_style.color {
table_cell = table_cell.text_color(&color);
}
if let Some(fill) = cell.background_fill {
table_cell = table_cell.background_color(&fill);
}
if let Some(size) = cell.text_style.font_size {
table_cell = table_cell.font_size(size);
}
if let Some(font_family) = cell.text_style.font_family {
table_cell = table_cell.font_family(&font_family);
}
if let Some(alignment) = cell.alignment.or(cell.text_style.alignment) {
table_cell = match alignment {
TextAlignment::Left => table_cell.align_left(),
TextAlignment::Center => table_cell.align_center(),
TextAlignment::Right => table_cell.align_right(),
TextAlignment::Justify => table_cell.align(CellAlign::Justify),
};
}
for merge in merges {
if row_index == merge.start_row && column_index == merge.start_column {
table_cell = table_cell
.grid_span((merge.end_column - merge.start_column + 1) as u32)
.row_span((merge.end_row - merge.start_row + 1) as u32);
} else if row_index >= merge.start_row
&& row_index <= merge.end_row
&& column_index >= merge.start_column
&& column_index <= merge.end_column
{
if row_index == merge.start_row {
table_cell = table_cell.h_merge();
} else {
table_cell = table_cell.v_merge();
}
}
}
table_cell
}
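The merge handling in `build_table_cell` can be distilled into a small classifier. This is a sketch with hypothetical `MergeRegion` and `CellMerge` types mirroring the loop above: the cell at a region's top-left anchor carries the full `grid_span`/`row_span`, while every other covered cell is marked as a horizontal continuation (first row) or a vertical continuation (later rows).

```rust
#[derive(Debug, PartialEq)]
enum CellMerge {
    Normal,
    Anchor { grid_span: u32, row_span: u32 },
    HMerge,
    VMerge,
}

// Hypothetical stand-in for the `TableMergeRegion` fields used above.
struct MergeRegion {
    start_row: usize,
    end_row: usize,
    start_column: usize,
    end_column: usize,
}

fn classify(region: &MergeRegion, row: usize, column: usize) -> CellMerge {
    if row == region.start_row && column == region.start_column {
        // The anchor cell spans the whole merged region.
        CellMerge::Anchor {
            grid_span: (region.end_column - region.start_column + 1) as u32,
            row_span: (region.end_row - region.start_row + 1) as u32,
        }
    } else if (region.start_row..=region.end_row).contains(&row)
        && (region.start_column..=region.end_column).contains(&column)
    {
        if row == region.start_row {
            CellMerge::HMerge
        } else {
            CellMerge::VMerge
        }
    } else {
        CellMerge::Normal
    }
}

fn main() {
    // A 2x2 merge covering rows 1..=2, columns 0..=1.
    let region = MergeRegion { start_row: 1, end_row: 2, start_column: 0, end_column: 1 };
    assert_eq!(classify(&region, 1, 0), CellMerge::Anchor { grid_span: 2, row_span: 2 });
    assert_eq!(classify(&region, 1, 1), CellMerge::HMerge);
    assert_eq!(classify(&region, 2, 0), CellMerge::VMerge);
    assert_eq!(classify(&region, 0, 0), CellMerge::Normal);
    println!("ok");
}
```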


@@ -0,0 +1,148 @@
#[derive(Debug, Clone, Serialize)]
pub struct PresentationArtifactResponse {
pub artifact_id: String,
pub action: String,
pub summary: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub executed_actions: Option<Vec<String>>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub exported_paths: Vec<PathBuf>,
#[serde(skip_serializing_if = "Option::is_none")]
pub artifact_snapshot: Option<ArtifactSnapshot>,
#[serde(skip_serializing_if = "Option::is_none")]
pub slide_list: Option<Vec<SlideListEntry>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub layout_list: Option<Vec<LayoutListEntry>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub placeholder_list: Option<Vec<PlaceholderListEntry>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub theme: Option<ThemeSnapshot>,
#[serde(skip_serializing_if = "Option::is_none")]
pub inspect_ndjson: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub resolved_record: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub proto_json: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub patch: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub active_slide_index: Option<usize>,
#[serde(skip)]
pub rendered_preview: Option<RenderedPreview>,
}
#[derive(Debug, Clone)]
pub struct RenderedPreview {
pub slide_index: usize,
pub png_bytes: Vec<u8>,
}
impl PresentationArtifactResponse {
fn new(
artifact_id: String,
action: String,
summary: String,
artifact_snapshot: ArtifactSnapshot,
) -> Self {
Self {
artifact_id,
action,
summary,
executed_actions: None,
exported_paths: Vec::new(),
artifact_snapshot: Some(artifact_snapshot),
slide_list: None,
layout_list: None,
placeholder_list: None,
theme: None,
inspect_ndjson: None,
resolved_record: None,
proto_json: None,
patch: None,
active_slide_index: None,
rendered_preview: None,
}
}
}
fn response_for_document_state(
artifact_id: String,
action: String,
summary: String,
document: Option<&PresentationDocument>,
) -> PresentationArtifactResponse {
PresentationArtifactResponse {
artifact_id,
action,
summary,
executed_actions: None,
exported_paths: Vec::new(),
artifact_snapshot: document.map(snapshot_for_document),
slide_list: None,
layout_list: None,
placeholder_list: None,
theme: document.map(PresentationDocument::theme_snapshot),
inspect_ndjson: None,
resolved_record: None,
proto_json: None,
patch: None,
active_slide_index: document.and_then(|current| current.active_slide_index),
rendered_preview: None,
}
}
#[derive(Debug, Clone, Serialize)]
pub struct ArtifactSnapshot {
pub slide_count: usize,
pub slides: Vec<SlideSnapshot>,
}
#[derive(Debug, Clone, Serialize)]
pub struct SlideSnapshot {
pub slide_id: String,
pub index: usize,
pub element_ids: Vec<String>,
pub element_types: Vec<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct SlideListEntry {
pub slide_id: String,
pub index: usize,
pub is_active: bool,
pub notes: Option<String>,
pub notes_visible: bool,
pub background_fill: Option<String>,
pub layout_id: Option<String>,
pub element_count: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct LayoutListEntry {
pub layout_id: String,
pub name: String,
pub kind: String,
pub parent_layout_id: Option<String>,
pub placeholder_count: usize,
}
#[derive(Debug, Clone, Serialize)]
pub struct PlaceholderListEntry {
pub scope: String,
pub source_layout_id: Option<String>,
pub slide_index: Option<usize>,
pub element_id: Option<String>,
pub name: String,
pub placeholder_type: String,
pub index: Option<u32>,
pub geometry: Option<String>,
pub text_preview: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
pub struct ThemeSnapshot {
pub color_scheme: HashMap<String, String>,
pub hex_color_map: HashMap<String, String>,
pub major_font: Option<String>,
pub minor_font: Option<String>,
}


@@ -0,0 +1,339 @@
fn cell_value_to_string(value: Value) -> String {
match value {
Value::Null => String::new(),
Value::String(text) => text,
Value::Bool(boolean) => boolean.to_string(),
Value::Number(number) => number.to_string(),
other => other.to_string(),
}
}
fn snapshot_for_document(document: &PresentationDocument) -> ArtifactSnapshot {
ArtifactSnapshot {
slide_count: document.slides.len(),
slides: document
.slides
.iter()
.enumerate()
.map(|(index, slide)| SlideSnapshot {
slide_id: slide.slide_id.clone(),
index,
element_ids: slide
.elements
.iter()
.map(|element| element.element_id().to_string())
.collect(),
element_types: slide
.elements
.iter()
.map(|element| element.kind().to_string())
.collect(),
})
.collect(),
}
}
fn slide_list(document: &PresentationDocument) -> Vec<SlideListEntry> {
document
.slides
.iter()
.enumerate()
.map(|(index, slide)| SlideListEntry {
slide_id: slide.slide_id.clone(),
index,
is_active: document.active_slide_index == Some(index),
notes: (slide.notes.visible && !slide.notes.text.is_empty())
.then(|| slide.notes.text.clone()),
notes_visible: slide.notes.visible,
background_fill: slide.background_fill.clone(),
layout_id: slide.layout_id.clone(),
element_count: slide.elements.len(),
})
.collect()
}
fn layout_list(document: &PresentationDocument) -> Vec<LayoutListEntry> {
document
.layouts
.iter()
.map(|layout| LayoutListEntry {
layout_id: layout.layout_id.clone(),
name: layout.name.clone(),
kind: match layout.kind {
LayoutKind::Layout => "layout".to_string(),
LayoutKind::Master => "master".to_string(),
},
parent_layout_id: layout.parent_layout_id.clone(),
placeholder_count: layout.placeholders.len(),
})
.collect()
}
fn points_to_emu(points: u32) -> u32 {
points.saturating_mul(POINT_TO_EMU)
}
fn emu_to_points(emu: u32) -> u32 {
emu / POINT_TO_EMU
}
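The `POINT_TO_EMU` constant is not shown in this hunk; assuming it is the standard OOXML factor of 12,700 EMU per point (914,400 EMU per inch at 72 points per inch), the conversions above behave as follows. Note that `emu_to_points` uses truncating integer division, so values that are not exact multiples round down:

```rust
// Assumption: 1 inch = 914,400 EMU and 72 points per inch, so 1 point = 12,700 EMU.
const POINT_TO_EMU: u32 = 12_700;

fn points_to_emu(points: u32) -> u32 {
    points.saturating_mul(POINT_TO_EMU)
}

fn emu_to_points(emu: u32) -> u32 {
    emu / POINT_TO_EMU
}

fn main() {
    assert_eq!(points_to_emu(72), 914_400); // one inch
    assert_eq!(emu_to_points(914_400), 72);
    // Truncating division: just under one point rounds down to zero.
    assert_eq!(emu_to_points(12_699), 0);
    println!("ok");
}
```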
type ImageCrop = (f64, f64, f64, f64);
type FittedImage = (u32, u32, u32, u32, Option<ImageCrop>);
pub(crate) fn fit_image(image: &ImageElement) -> FittedImage {
let Some(payload) = image.payload.as_ref() else {
return (
image.frame.left,
image.frame.top,
image.frame.width,
image.frame.height,
None,
);
};
let frame = image.frame;
let source_width = payload.width_px as f64;
let source_height = payload.height_px as f64;
let target_width = frame.width as f64;
let target_height = frame.height as f64;
let source_ratio = source_width / source_height;
let target_ratio = target_width / target_height;
match image.fit_mode {
ImageFitMode::Stretch => (frame.left, frame.top, frame.width, frame.height, None),
ImageFitMode::Contain => {
let scale = if source_ratio > target_ratio {
target_width / source_width
} else {
target_height / source_height
};
let width = (source_width * scale).round() as u32;
let height = (source_height * scale).round() as u32;
let left = frame.left + frame.width.saturating_sub(width) / 2;
let top = frame.top + frame.height.saturating_sub(height) / 2;
(left, top, width, height, None)
}
ImageFitMode::Cover => {
let scale = if source_ratio > target_ratio {
target_height / source_height
} else {
target_width / source_width
};
let width = source_width * scale;
let height = source_height * scale;
let crop_x = ((width - target_width).max(0.0) / width) / 2.0;
let crop_y = ((height - target_height).max(0.0) / height) / 2.0;
(
frame.left,
frame.top,
frame.width,
frame.height,
Some((crop_x, crop_y, crop_x, crop_y)),
)
}
}
}
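The `Cover` arm of `fit_image` scales the image until it fills the frame, then crops the overflow symmetrically as fractions of the scaled size. A standalone re-derivation of just that crop math (the `cover_crop` helper is an illustration, not part of the source):

```rust
// Returns the symmetric (crop_x, crop_y) fractions for cover-fit, as above.
fn cover_crop(src_w: f64, src_h: f64, frame_w: f64, frame_h: f64) -> (f64, f64) {
    let scale = if src_w / src_h > frame_w / frame_h {
        frame_h / src_h // wider than the frame: match heights, crop left/right
    } else {
        frame_w / src_w // taller than the frame: match widths, crop top/bottom
    };
    let (w, h) = (src_w * scale, src_h * scale);
    let crop_x = ((w - frame_w).max(0.0) / w) / 2.0;
    let crop_y = ((h - frame_h).max(0.0) / h) / 2.0;
    (crop_x, crop_y)
}

fn main() {
    // A 200x100 image in a 100x100 frame: crop 25% off each side, none vertically.
    assert_eq!(cover_crop(200.0, 100.0, 100.0, 100.0), (0.25, 0.0));
    // A 100x200 image in a 100x100 frame: crop 25% off top and bottom.
    assert_eq!(cover_crop(100.0, 200.0, 100.0, 100.0), (0.0, 0.25));
    println!("ok");
}
```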
fn normalize_image_crop(
crop: ImageCropArgs,
action: &str,
) -> Result<ImageCrop, PresentationArtifactError> {
for (name, value) in [
("left", crop.left),
("top", crop.top),
("right", crop.right),
("bottom", crop.bottom),
] {
if !(0.0..=1.0).contains(&value) {
return Err(PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("image crop `{name}` must be between 0.0 and 1.0"),
});
}
}
Ok((crop.left, crop.top, crop.right, crop.bottom))
}
fn load_image_payload_from_path(
path: &Path,
action: &str,
) -> Result<ImagePayload, PresentationArtifactError> {
let bytes = std::fs::read(path).map_err(|error| PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to read image `{}`: {error}", path.display()),
})?;
build_image_payload(
bytes,
path.file_name()
.and_then(|name| name.to_str())
.unwrap_or("image")
.to_string(),
action,
)
}
fn load_image_payload_from_data_url(
data_url: &str,
action: &str,
) -> Result<ImagePayload, PresentationArtifactError> {
let (header, payload) =
data_url
.split_once(',')
.ok_or_else(|| PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: "data_url must include a MIME header and base64 payload".to_string(),
})?;
let mime = header
.strip_prefix("data:")
.and_then(|prefix| prefix.strip_suffix(";base64"))
.ok_or_else(|| PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: "data_url must be base64-encoded".to_string(),
})?;
let bytes = BASE64_STANDARD.decode(payload).map_err(|error| {
PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to decode image data_url: {error}"),
}
})?;
build_image_payload(
bytes,
format!("image.{}", image_extension_from_mime(mime)),
action,
)
}
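The header parsing in `load_image_payload_from_data_url` accepts only the `data:<mime>;base64,<payload>` shape; anything else is rejected before base64 decoding. A minimal sketch of that split logic using only the standard library (the `split_data_url` helper is illustrative):

```rust
// Extracts (mime, payload) from a `data:<mime>;base64,<payload>` string.
fn split_data_url(data_url: &str) -> Option<(&str, &str)> {
    let (header, payload) = data_url.split_once(',')?;
    let mime = header.strip_prefix("data:")?.strip_suffix(";base64")?;
    Some((mime, payload))
}

fn main() {
    assert_eq!(
        split_data_url("data:image/png;base64,iVBORw0KGgo="),
        Some(("image/png", "iVBORw0KGgo="))
    );
    // Non-base64 data URLs and malformed headers are rejected.
    assert_eq!(split_data_url("data:image/png,raw"), None);
    assert_eq!(split_data_url("no-comma"), None);
    println!("ok");
}
```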
fn load_image_payload_from_blob(
blob: &str,
action: &str,
) -> Result<ImagePayload, PresentationArtifactError> {
let bytes = BASE64_STANDARD.decode(blob.trim()).map_err(|error| {
PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to decode image blob: {error}"),
}
})?;
let extension = image::guess_format(&bytes)
.ok()
.map(image_extension_from_format)
.unwrap_or("png");
build_image_payload(bytes, format!("image.{extension}"), action)
}
fn load_image_payload_from_uri(
uri: &str,
action: &str,
) -> Result<ImagePayload, PresentationArtifactError> {
let response =
reqwest::blocking::get(uri).map_err(|error| PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to fetch image `{uri}`: {error}"),
})?;
let status = response.status();
if !status.is_success() {
return Err(PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to fetch image `{uri}`: HTTP {status}"),
});
}
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(|value| value.split(';').next().unwrap_or(value).trim().to_string());
let bytes = response
.bytes()
.map_err(|error| PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to read image `{uri}`: {error}"),
})?;
build_image_payload(
bytes.to_vec(),
infer_remote_image_filename(uri, content_type.as_deref()),
action,
)
}
fn infer_remote_image_filename(uri: &str, content_type: Option<&str>) -> String {
let path_name = reqwest::Url::parse(uri)
.ok()
.and_then(|url| {
url.path_segments()
.and_then(Iterator::last)
.map(str::to_owned)
})
.filter(|segment| !segment.is_empty());
match (path_name, content_type) {
(Some(path_name), _) if Path::new(&path_name).extension().is_some() => path_name,
(Some(path_name), Some(content_type)) => {
format!("{path_name}.{}", image_extension_from_mime(content_type))
}
(Some(path_name), None) => path_name,
(None, Some(content_type)) => format!("image.{}", image_extension_from_mime(content_type)),
(None, None) => "image.png".to_string(),
}
}
fn build_image_payload(
bytes: Vec<u8>,
filename: String,
action: &str,
) -> Result<ImagePayload, PresentationArtifactError> {
let image = image::load_from_memory(&bytes).map_err(|error| {
PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("failed to decode image bytes: {error}"),
}
})?;
let (width_px, height_px) = image.dimensions();
let format = Path::new(&filename)
.extension()
.and_then(|extension| extension.to_str())
.unwrap_or("png")
.to_uppercase();
Ok(ImagePayload {
bytes,
format,
width_px,
height_px,
})
}
fn image_extension_from_mime(mime: &str) -> &'static str {
match mime {
"image/jpeg" => "jpg",
"image/gif" => "gif",
"image/webp" => "webp",
_ => "png",
}
}
fn image_extension_from_format(format: image::ImageFormat) -> &'static str {
match format {
image::ImageFormat::Jpeg => "jpg",
image::ImageFormat::Gif => "gif",
image::ImageFormat::WebP => "webp",
image::ImageFormat::Bmp => "bmp",
image::ImageFormat::Tiff => "tiff",
_ => "png",
}
}
fn index_out_of_range(action: &str, index: usize, len: usize) -> PresentationArtifactError {
PresentationArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("slide index {index} is out of range for {len} slides"),
}
}
fn to_index(value: u32) -> Result<usize, PresentationArtifactError> {
usize::try_from(value).map_err(|_| PresentationArtifactError::InvalidArgs {
action: "insert_slide".to_string(),
message: "index does not fit in usize".to_string(),
})
}
fn resequence_z_order(slide: &mut PresentationSlide) {
for (index, element) in slide.elements.iter_mut().enumerate() {
element.set_z_order(index);
}
}


@@ -0,0 +1,6 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "artifact-spreadsheet",
crate_name = "codex_artifact_spreadsheet",
)


@@ -0,0 +1,25 @@
[package]
name = "codex-artifact-spreadsheet"
version.workspace = true
edition.workspace = true
license.workspace = true
[lib]
name = "codex_artifact_spreadsheet"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
base64 = { workspace = true }
regex = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
uuid = { workspace = true, features = ["v4"] }
zip = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
tempfile = { workspace = true }


@@ -0,0 +1,245 @@
use serde::Deserialize;
use serde::Serialize;
use crate::SpreadsheetArtifactError;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct CellAddress {
pub column: u32,
pub row: u32,
}
impl CellAddress {
pub fn parse(address: &str) -> Result<Self, SpreadsheetArtifactError> {
let trimmed = address.trim();
if trimmed.is_empty() {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "address is empty".to_string(),
});
}
let mut split = 0usize;
for (index, ch) in trimmed.char_indices() {
if ch.is_ascii_alphabetic() {
split = index + ch.len_utf8();
} else {
break;
}
}
let (letters, digits) = trimmed.split_at(split);
if letters.is_empty() || digits.is_empty() {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "expected A1-style address".to_string(),
});
}
if !letters.chars().all(|ch| ch.is_ascii_alphabetic())
|| !digits.chars().all(|ch| ch.is_ascii_digit())
{
return Err(SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "expected letters followed by digits".to_string(),
});
}
let column = column_letters_to_index(letters)?;
let row = digits
.parse::<u32>()
.map_err(|_| SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "row must be a positive integer".to_string(),
})?;
if row == 0 {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "row must be positive".to_string(),
});
}
Ok(Self { column, row })
}
pub fn to_a1(self) -> String {
format!("{}{}", column_index_to_letters(self.column), self.row)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CellRange {
pub start: CellAddress,
pub end: CellAddress,
}
impl CellRange {
pub fn parse(address: &str) -> Result<Self, SpreadsheetArtifactError> {
let trimmed = address.trim();
if trimmed.is_empty() {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: address.to_string(),
message: "range is empty".to_string(),
});
}
let (start, end) = if let Some((left, right)) = trimmed.split_once(':') {
(CellAddress::parse(left)?, CellAddress::parse(right)?)
} else {
let cell = CellAddress::parse(trimmed)?;
(cell, cell)
};
let normalized = Self {
start: CellAddress {
column: start.column.min(end.column),
row: start.row.min(end.row),
},
end: CellAddress {
column: start.column.max(end.column),
row: start.row.max(end.row),
},
};
Ok(normalized)
}
pub fn from_start_end(start: CellAddress, end: CellAddress) -> Self {
Self {
start: CellAddress {
column: start.column.min(end.column),
row: start.row.min(end.row),
},
end: CellAddress {
column: start.column.max(end.column),
row: start.row.max(end.row),
},
}
}
pub fn to_a1(&self) -> String {
if self.is_single_cell() {
self.start.to_a1()
} else {
format!("{}:{}", self.start.to_a1(), self.end.to_a1())
}
}
pub fn is_single_cell(&self) -> bool {
self.start == self.end
}
pub fn is_single_row(&self) -> bool {
self.start.row == self.end.row
}
pub fn is_single_column(&self) -> bool {
self.start.column == self.end.column
}
pub fn width(&self) -> usize {
(self.end.column - self.start.column + 1) as usize
}
pub fn height(&self) -> usize {
(self.end.row - self.start.row + 1) as usize
}
pub fn contains(&self, address: CellAddress) -> bool {
self.start.column <= address.column
&& address.column <= self.end.column
&& self.start.row <= address.row
&& address.row <= self.end.row
}
pub fn contains_range(&self, other: &CellRange) -> bool {
self.contains(other.start) && self.contains(other.end)
}
pub fn intersects(&self, other: &CellRange) -> bool {
!(self.end.column < other.start.column
|| other.end.column < self.start.column
|| self.end.row < other.start.row
|| other.end.row < self.start.row)
}
pub fn addresses(&self) -> impl Iterator<Item = CellAddress> {
let range = self.clone();
(range.start.row..=range.end.row).flat_map(move |row| {
let range = range.clone();
(range.start.column..=range.end.column).map(move |column| CellAddress { column, row })
})
}
}
pub fn column_letters_to_index(column: &str) -> Result<u32, SpreadsheetArtifactError> {
let trimmed = column.trim();
if trimmed.is_empty() {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: column.to_string(),
message: "column is empty".to_string(),
});
}
let mut result = 0u32;
for ch in trimmed.chars() {
if !ch.is_ascii_alphabetic() {
return Err(SpreadsheetArtifactError::InvalidAddress {
address: column.to_string(),
message: "column must contain only letters".to_string(),
});
}
result = result
.checked_mul(26)
.and_then(|value| value.checked_add((ch.to_ascii_uppercase() as u8 - b'A' + 1) as u32))
.ok_or_else(|| SpreadsheetArtifactError::InvalidAddress {
address: column.to_string(),
message: "column is too large".to_string(),
})?;
}
Ok(result)
}
pub fn column_index_to_letters(mut index: u32) -> String {
if index == 0 {
return String::new();
}
let mut letters = Vec::new();
while index > 0 {
let remainder = (index - 1) % 26;
letters.push((b'A' + remainder as u8) as char);
index = (index - 1) / 26;
}
letters.iter().rev().collect()
}
pub fn parse_column_reference(reference: &str) -> Result<(u32, u32), SpreadsheetArtifactError> {
let trimmed = reference.trim();
if let Some((left, right)) = trimmed.split_once(':') {
let start = column_letters_to_index(left)?;
let end = column_letters_to_index(right)?;
Ok((start.min(end), start.max(end)))
} else {
let column = column_letters_to_index(trimmed)?;
Ok((column, column))
}
}
pub fn is_valid_cell_reference(address: &str) -> bool {
CellAddress::parse(address).is_ok()
}
pub fn is_valid_range_reference(address: &str) -> bool {
CellRange::parse(address).is_ok()
}
pub fn is_valid_row_reference(address: &str) -> bool {
CellRange::parse(address)
.map(|range| range.is_single_row())
.unwrap_or(false)
}
pub fn is_valid_column_reference(address: &str) -> bool {
parse_column_reference(address).is_ok()
}


@@ -0,0 +1,357 @@
use serde::Deserialize;
use serde::Serialize;
use crate::CellAddress;
use crate::CellRange;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SpreadsheetChartType {
Area,
Bar,
Doughnut,
Line,
Pie,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum SpreadsheetChartLegendPosition {
Bottom,
Top,
Left,
Right,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChartLegend {
pub visible: bool,
pub position: SpreadsheetChartLegendPosition,
pub overlay: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChartAxis {
pub linked_number_format: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChartSeries {
pub id: u32,
pub name: Option<String>,
pub category_sheet_name: Option<String>,
pub category_range: String,
pub value_sheet_name: Option<String>,
pub value_range: String,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChart {
pub id: u32,
pub chart_type: SpreadsheetChartType,
pub source_sheet_name: Option<String>,
pub source_range: Option<String>,
pub title: Option<String>,
pub style_index: u32,
pub display_blanks_as: String,
pub legend: SpreadsheetChartLegend,
pub category_axis: SpreadsheetChartAxis,
pub value_axis: SpreadsheetChartAxis,
#[serde(default)]
pub series: Vec<SpreadsheetChartSeries>,
}
#[derive(Debug, Clone, Default)]
pub struct SpreadsheetChartLookup {
pub id: Option<u32>,
pub index: Option<usize>,
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChartCreateOptions {
pub id: Option<u32>,
pub title: Option<String>,
pub legend_visible: Option<bool>,
pub legend_position: Option<SpreadsheetChartLegendPosition>,
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetChartProperties {
pub title: Option<String>,
pub legend_visible: Option<bool>,
pub legend_position: Option<SpreadsheetChartLegendPosition>,
}
impl SpreadsheetSheet {
pub fn list_charts(
&self,
range: Option<&CellRange>,
) -> Result<Vec<SpreadsheetChart>, SpreadsheetArtifactError> {
Ok(self
.charts
.iter()
.filter(|chart| {
range.is_none_or(|target| {
chart
.source_range
.as_deref()
.map(CellRange::parse)
.transpose()
.ok()
.flatten()
.is_some_and(|chart_range| chart_range.intersects(target))
})
})
.cloned()
.collect())
}
pub fn get_chart(
&self,
action: &str,
lookup: SpreadsheetChartLookup,
) -> Result<&SpreadsheetChart, SpreadsheetArtifactError> {
if let Some(id) = lookup.id {
return self
.charts
.iter()
.find(|chart| chart.id == id)
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("chart id `{id}` was not found"),
});
}
if let Some(index) = lookup.index {
return self.charts.get(index).ok_or_else(|| {
SpreadsheetArtifactError::IndexOutOfRange {
action: action.to_string(),
index,
len: self.charts.len(),
}
});
}
Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart id or index is required".to_string(),
})
}
pub fn create_chart(
&mut self,
action: &str,
chart_type: SpreadsheetChartType,
source_sheet_name: Option<String>,
source_range: &CellRange,
options: SpreadsheetChartCreateOptions,
) -> Result<u32, SpreadsheetArtifactError> {
if source_range.width() < 2 {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart source range must include at least two columns".to_string(),
});
}
let id = if let Some(id) = options.id {
if self.charts.iter().any(|chart| chart.id == id) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("chart id `{id}` already exists"),
});
}
id
} else {
self.charts.iter().map(|chart| chart.id).max().unwrap_or(0) + 1
};
let series = (source_range.start.column + 1..=source_range.end.column)
.enumerate()
.map(|(index, value_column)| SpreadsheetChartSeries {
id: index as u32 + 1,
name: None,
category_sheet_name: source_sheet_name.clone(),
category_range: CellRange::from_start_end(
source_range.start,
CellAddress {
column: source_range.start.column,
row: source_range.end.row,
},
)
.to_a1(),
value_sheet_name: source_sheet_name.clone(),
value_range: CellRange::from_start_end(
CellAddress {
column: value_column,
row: source_range.start.row,
},
CellAddress {
column: value_column,
row: source_range.end.row,
},
)
.to_a1(),
})
.collect::<Vec<_>>();
self.charts.push(SpreadsheetChart {
id,
chart_type,
source_sheet_name,
source_range: Some(source_range.to_a1()),
title: options.title,
style_index: 102,
display_blanks_as: "gap".to_string(),
legend: SpreadsheetChartLegend {
visible: options.legend_visible.unwrap_or(true),
position: options
.legend_position
.unwrap_or(SpreadsheetChartLegendPosition::Bottom),
overlay: false,
},
category_axis: SpreadsheetChartAxis {
linked_number_format: true,
},
value_axis: SpreadsheetChartAxis {
linked_number_format: true,
},
series,
});
Ok(id)
}
pub fn add_chart_series(
&mut self,
action: &str,
lookup: SpreadsheetChartLookup,
mut series: SpreadsheetChartSeries,
) -> Result<u32, SpreadsheetArtifactError> {
validate_chart_series(action, &series)?;
let chart = self.get_chart_mut(action, lookup)?;
let next_id = chart.series.iter().map(|entry| entry.id).max().unwrap_or(0) + 1;
series.id = next_id;
chart.series.push(series);
Ok(next_id)
}
pub fn delete_chart(
&mut self,
action: &str,
lookup: SpreadsheetChartLookup,
) -> Result<(), SpreadsheetArtifactError> {
let index = if let Some(id) = lookup.id {
self.charts
.iter()
.position(|chart| chart.id == id)
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("chart id `{id}` was not found"),
})?
} else if let Some(index) = lookup.index {
if index >= self.charts.len() {
return Err(SpreadsheetArtifactError::IndexOutOfRange {
action: action.to_string(),
index,
len: self.charts.len(),
});
}
index
} else {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart id or index is required".to_string(),
});
};
self.charts.remove(index);
Ok(())
}
pub fn set_chart_properties(
&mut self,
action: &str,
lookup: SpreadsheetChartLookup,
properties: SpreadsheetChartProperties,
) -> Result<(), SpreadsheetArtifactError> {
let chart = self.get_chart_mut(action, lookup)?;
if let Some(title) = properties.title {
chart.title = Some(title);
}
if let Some(visible) = properties.legend_visible {
chart.legend.visible = visible;
}
if let Some(position) = properties.legend_position {
chart.legend.position = position;
}
Ok(())
}
pub fn validate_charts(&self, action: &str) -> Result<(), SpreadsheetArtifactError> {
for chart in &self.charts {
if let Some(source_range) = &chart.source_range {
let range = CellRange::parse(source_range)?;
if range.width() < 2 {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!(
"chart `{}` source range `{source_range}` is too narrow",
chart.id
),
});
}
}
for series in &chart.series {
validate_chart_series(action, series)?;
}
}
Ok(())
}
fn get_chart_mut(
&mut self,
action: &str,
lookup: SpreadsheetChartLookup,
) -> Result<&mut SpreadsheetChart, SpreadsheetArtifactError> {
if let Some(id) = lookup.id {
return self
.charts
.iter_mut()
.find(|chart| chart.id == id)
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("chart id `{id}` was not found"),
});
}
if let Some(index) = lookup.index {
let len = self.charts.len();
return self.charts.get_mut(index).ok_or_else(|| {
SpreadsheetArtifactError::IndexOutOfRange {
action: action.to_string(),
index,
len,
}
});
}
Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart id or index is required".to_string(),
})
}
}
fn validate_chart_series(
action: &str,
series: &SpreadsheetChartSeries,
) -> Result<(), SpreadsheetArtifactError> {
let category_range = CellRange::parse(&series.category_range)?;
let value_range = CellRange::parse(&series.value_range)?;
if !category_range.is_single_column() || !value_range.is_single_column() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart category and value ranges must be single-column ranges".to_string(),
});
}
if category_range.height() != value_range.height() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "chart category and value series lengths must match".to_string(),
});
}
Ok(())
}
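The series construction in `create_chart` treats the first column of the source range as the shared category axis and every remaining column as one value series. A minimal standalone sketch of that split, using plain tuples and a hypothetical `a1` helper instead of the crate's `CellRange`/`CellAddress` types (column/row indices are zero-based here):

```rust
// Sketch of the column-splitting logic in `create_chart`. `col_letter`
// and `a1` are hypothetical stand-ins for the crate's A1 helpers.
fn col_letter(mut col: u32) -> String {
    // Supports A..Z and beyond (AA, AB, ...).
    let mut out = Vec::new();
    loop {
        out.push(b'A' + (col % 26) as u8);
        if col < 26 {
            break;
        }
        col = col / 26 - 1;
    }
    out.reverse();
    String::from_utf8(out).unwrap()
}

fn a1(c0: u32, r0: u32, c1: u32, r1: u32) -> String {
    format!("{}{}:{}{}", col_letter(c0), r0 + 1, col_letter(c1), r1 + 1)
}

/// First column is the shared category range; every remaining column
/// becomes one (category_range, value_range) series pair.
fn split_series(c0: u32, r0: u32, c1: u32, r1: u32) -> Vec<(String, String)> {
    (c0 + 1..=c1)
        .map(|vc| (a1(c0, r0, c0, r1), a1(vc, r0, vc, r1)))
        .collect()
}

fn main() {
    // A1:C4 -> categories in A1:A4, series values in B1:B4 and C1:C4.
    let series = split_series(0, 0, 2, 3);
    assert_eq!(series.len(), 2);
    assert_eq!(series[0], ("A1:A4".to_string(), "B1:B4".to_string()));
    assert_eq!(series[1], ("A1:A4".to_string(), "C1:C4".to_string()));
}
```

This also shows why `create_chart` rejects source ranges narrower than two columns: with only the category column present, the value-column range `c0 + 1..=c1` would be empty and no series could be built.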


@@ -0,0 +1,308 @@
use serde::Deserialize;
use serde::Serialize;
use crate::CellRange;
use crate::SpreadsheetArtifact;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum SpreadsheetConditionalFormatType {
Expression,
CellIs,
ColorScale,
DataBar,
IconSet,
Top10,
UniqueValues,
DuplicateValues,
ContainsText,
NotContainsText,
BeginsWith,
EndsWith,
ContainsBlanks,
NotContainsBlanks,
ContainsErrors,
NotContainsErrors,
TimePeriod,
AboveAverage,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetColorScale {
pub min_type: Option<String>,
pub mid_type: Option<String>,
pub max_type: Option<String>,
pub min_value: Option<String>,
pub mid_value: Option<String>,
pub max_value: Option<String>,
pub min_color: String,
pub mid_color: Option<String>,
pub max_color: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetDataBar {
pub color: String,
pub min_length: Option<u8>,
pub max_length: Option<u8>,
pub show_value: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetIconSet {
pub style: String,
pub show_value: Option<bool>,
pub reverse_order: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetConditionalFormat {
pub id: u32,
pub range: String,
pub rule_type: SpreadsheetConditionalFormatType,
pub operator: Option<String>,
#[serde(default)]
pub formulas: Vec<String>,
pub text: Option<String>,
pub dxf_id: Option<u32>,
pub stop_if_true: bool,
pub priority: u32,
pub rank: Option<u32>,
pub percent: Option<bool>,
pub time_period: Option<String>,
pub above_average: Option<bool>,
pub equal_average: Option<bool>,
pub color_scale: Option<SpreadsheetColorScale>,
pub data_bar: Option<SpreadsheetDataBar>,
pub icon_set: Option<SpreadsheetIconSet>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetConditionalFormatCollection {
pub sheet_name: String,
pub range: String,
}
impl SpreadsheetConditionalFormatCollection {
pub fn new(sheet_name: String, range: &CellRange) -> Self {
Self {
sheet_name,
range: range.to_a1(),
}
}
pub fn range(&self) -> Result<CellRange, SpreadsheetArtifactError> {
CellRange::parse(&self.range)
}
pub fn list(
&self,
artifact: &SpreadsheetArtifact,
) -> Result<Vec<SpreadsheetConditionalFormat>, SpreadsheetArtifactError> {
let sheet = artifact.sheet_lookup(
"conditional_format_collection",
Some(&self.sheet_name),
None,
)?;
Ok(sheet.list_conditional_formats(Some(&self.range()?)))
}
pub fn add(
&self,
artifact: &mut SpreadsheetArtifact,
mut format: SpreadsheetConditionalFormat,
) -> Result<u32, SpreadsheetArtifactError> {
format.range = self.range.clone();
artifact.add_conditional_format("conditional_format_collection", &self.sheet_name, format)
}
pub fn delete(
&self,
artifact: &mut SpreadsheetArtifact,
id: u32,
) -> Result<(), SpreadsheetArtifactError> {
artifact.delete_conditional_format("conditional_format_collection", &self.sheet_name, id)
}
}
impl SpreadsheetArtifact {
pub fn validate_conditional_formats(
&self,
action: &str,
sheet_name: &str,
) -> Result<(), SpreadsheetArtifactError> {
let sheet = self.sheet_lookup(action, Some(sheet_name), None)?;
for format in &sheet.conditional_formats {
validate_conditional_format(self, format, action)?;
}
Ok(())
}
pub fn add_conditional_format(
&mut self,
action: &str,
sheet_name: &str,
mut format: SpreadsheetConditionalFormat,
) -> Result<u32, SpreadsheetArtifactError> {
validate_conditional_format(self, &format, action)?;
let sheet = self.sheet_lookup_mut(action, Some(sheet_name), None)?;
let next_id = sheet
.conditional_formats
.iter()
.map(|entry| entry.id)
.max()
.unwrap_or(0)
+ 1;
format.id = next_id;
format.priority = if format.priority == 0 {
next_id
} else {
format.priority
};
sheet.conditional_formats.push(format);
Ok(next_id)
}
pub fn delete_conditional_format(
&mut self,
action: &str,
sheet_name: &str,
id: u32,
) -> Result<(), SpreadsheetArtifactError> {
let sheet = self.sheet_lookup_mut(action, Some(sheet_name), None)?;
let previous_len = sheet.conditional_formats.len();
sheet.conditional_formats.retain(|entry| entry.id != id);
if sheet.conditional_formats.len() == previous_len {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("conditional format `{id}` was not found"),
});
}
Ok(())
}
}
impl SpreadsheetSheet {
pub fn conditional_format_collection(
&self,
range: &CellRange,
) -> SpreadsheetConditionalFormatCollection {
SpreadsheetConditionalFormatCollection::new(self.name.clone(), range)
}
pub fn list_conditional_formats(
&self,
range: Option<&CellRange>,
) -> Vec<SpreadsheetConditionalFormat> {
self.conditional_formats
.iter()
.filter(|entry| {
range.is_none_or(|target| {
CellRange::parse(&entry.range)
.map(|entry_range| entry_range.intersects(target))
.unwrap_or(false)
})
})
.cloned()
.collect()
}
}
fn validate_conditional_format(
artifact: &SpreadsheetArtifact,
format: &SpreadsheetConditionalFormat,
action: &str,
) -> Result<(), SpreadsheetArtifactError> {
CellRange::parse(&format.range)?;
if let Some(dxf_id) = format.dxf_id
&& artifact.get_differential_format(dxf_id).is_none()
{
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("differential format `{dxf_id}` was not found"),
});
}
let has_style = format.dxf_id.is_some();
let has_intrinsic_visual =
format.color_scale.is_some() || format.data_bar.is_some() || format.icon_set.is_some();
match format.rule_type {
SpreadsheetConditionalFormatType::Expression | SpreadsheetConditionalFormatType::CellIs => {
if format.formulas.is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "conditional format formulas are required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::ContainsText
| SpreadsheetConditionalFormatType::NotContainsText
| SpreadsheetConditionalFormatType::BeginsWith
| SpreadsheetConditionalFormatType::EndsWith => {
if format.text.as_deref().unwrap_or_default().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "conditional format text is required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::ColorScale => {
if format.color_scale.is_none() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "color scale settings are required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::DataBar => {
if format.data_bar.is_none() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "data bar settings are required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::IconSet => {
if format.icon_set.is_none() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "icon set settings are required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::Top10 => {
if format.rank.is_none() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "top10 rank is required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::TimePeriod => {
if format.time_period.as_deref().unwrap_or_default().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "time period is required".to_string(),
});
}
}
SpreadsheetConditionalFormatType::AboveAverage => {}
SpreadsheetConditionalFormatType::UniqueValues
| SpreadsheetConditionalFormatType::DuplicateValues
| SpreadsheetConditionalFormatType::ContainsBlanks
| SpreadsheetConditionalFormatType::NotContainsBlanks
| SpreadsheetConditionalFormatType::ContainsErrors
| SpreadsheetConditionalFormatType::NotContainsErrors => {}
}
if !has_style && !has_intrinsic_visual {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "conditional formatting requires at least one style component".to_string(),
});
}
Ok(())
}
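`create_chart`, `add_chart_series`, and `add_conditional_format` all allocate ids with the same max-plus-one scheme. Isolated as a standalone sketch (hypothetical `next_id` helper, not part of the crate):

```rust
// The id-allocation pattern shared by charts, chart series, and
// conditional formats: one past the current maximum, starting at 1.
fn next_id(existing: &[u32]) -> u32 {
    existing.iter().copied().max().unwrap_or(0) + 1
}

fn main() {
    assert_eq!(next_id(&[]), 1); // empty collection starts at 1
    assert_eq!(next_id(&[1, 2, 3]), 4);
    // Deleting an interior id does not free it for reuse; ids stay
    // monotonic until the maximum itself is removed.
    assert_eq!(next_id(&[1, 3]), 4);
}
```

Note that `add_conditional_format` additionally reuses the freshly allocated id as the rule's `priority` when the caller passes `priority == 0`, so later rules default to lower precedence.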


@@ -0,0 +1,39 @@
use std::path::PathBuf;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum SpreadsheetArtifactError {
#[error("missing `artifact_id` for action `{action}`")]
MissingArtifactId { action: String },
#[error("unknown artifact id `{artifact_id}` for action `{action}`")]
UnknownArtifactId { action: String, artifact_id: String },
#[error("unknown action `{0}`")]
UnknownAction(String),
#[error("invalid args for action `{action}`: {message}")]
InvalidArgs { action: String, message: String },
#[error("invalid address `{address}`: {message}")]
InvalidAddress { address: String, message: String },
#[error("sheet lookup failed for action `{action}`: {message}")]
SheetLookup { action: String, message: String },
#[error("index `{index}` is out of range for action `{action}`; len={len}")]
IndexOutOfRange {
action: String,
index: usize,
len: usize,
},
#[error("merge conflict for action `{action}` on range `{range}` with `{conflict}`")]
MergeConflict {
action: String,
range: String,
conflict: String,
},
#[error("formula error at `{location}`: {message}")]
Formula { location: String, message: String },
#[error("serialization failed: {message}")]
Serialization { message: String },
#[error("failed to import XLSX `{path}`: {message}")]
ImportFailed { path: PathBuf, message: String },
#[error("failed to export XLSX `{path}`: {message}")]
ExportFailed { path: PathBuf, message: String },
}



@@ -0,0 +1,535 @@
use std::collections::BTreeSet;
use crate::CellAddress;
use crate::CellRange;
use crate::SpreadsheetArtifact;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetCellValue;
#[derive(Debug, Clone)]
enum Token {
Number(f64),
Cell(String),
Ident(String),
Plus,
Minus,
Star,
Slash,
LParen,
RParen,
Colon,
Comma,
}
#[derive(Debug, Clone)]
enum Expr {
Number(f64),
Cell(CellAddress),
Range(CellRange),
UnaryMinus(Box<Expr>),
Binary {
op: BinaryOp,
left: Box<Expr>,
right: Box<Expr>,
},
Function {
name: String,
args: Vec<Expr>,
},
}
#[derive(Debug, Clone, Copy)]
enum BinaryOp {
Add,
Subtract,
Multiply,
Divide,
}
#[derive(Debug, Clone)]
enum EvalValue {
Scalar(Option<SpreadsheetCellValue>),
Range(Vec<Option<SpreadsheetCellValue>>),
}
pub(crate) fn recalculate_workbook(artifact: &mut SpreadsheetArtifact) {
let updates = artifact
.sheets
.iter()
.enumerate()
.flat_map(|(sheet_index, sheet)| {
sheet.cells.iter().filter_map(move |(address, cell)| {
cell.formula
.as_ref()
.map(|formula| (sheet_index, *address, formula.clone()))
})
})
.map(|(sheet_index, address, formula)| {
let mut stack = BTreeSet::new();
let value = evaluate_formula(artifact, sheet_index, &formula, &mut stack)
.unwrap_or_else(|error| {
Some(SpreadsheetCellValue::Error(map_error_to_code(&error)))
});
(sheet_index, address, value)
})
.collect::<Vec<_>>();
for (sheet_index, address, value) in updates {
if let Some(sheet) = artifact.sheets.get_mut(sheet_index)
&& let Some(cell) = sheet.cells.get_mut(&address)
{
cell.value = value;
}
}
}
fn evaluate_formula(
artifact: &SpreadsheetArtifact,
sheet_index: usize,
formula: &str,
stack: &mut BTreeSet<(usize, CellAddress)>,
) -> Result<Option<SpreadsheetCellValue>, SpreadsheetArtifactError> {
let source = formula.trim().trim_start_matches('=');
let tokens = tokenize(source)?;
let mut parser = Parser::new(tokens);
let expr = parser.parse_expression()?;
if parser.has_remaining() {
return Err(SpreadsheetArtifactError::Formula {
location: formula.to_string(),
message: "unexpected trailing tokens".to_string(),
});
}
match evaluate_expr(artifact, sheet_index, &expr, stack)? {
EvalValue::Scalar(value) => Ok(value),
EvalValue::Range(_) => Err(SpreadsheetArtifactError::Formula {
location: formula.to_string(),
message: "range expressions are only allowed inside functions".to_string(),
}),
}
}
fn evaluate_expr(
artifact: &SpreadsheetArtifact,
sheet_index: usize,
expr: &Expr,
stack: &mut BTreeSet<(usize, CellAddress)>,
) -> Result<EvalValue, SpreadsheetArtifactError> {
match expr {
Expr::Number(value) => Ok(EvalValue::Scalar(Some(number_to_value(*value)))),
Expr::Cell(address) => evaluate_cell_reference(artifact, sheet_index, *address, stack),
Expr::Range(range) => {
let sheet = artifact.sheets.get(sheet_index).ok_or_else(|| {
SpreadsheetArtifactError::Formula {
location: range.to_a1(),
message: "sheet index was not found".to_string(),
}
})?;
let values = range
.addresses()
.map(|address| sheet.get_cell(address).and_then(|cell| cell.value.clone()))
.collect::<Vec<_>>();
Ok(EvalValue::Range(values))
}
Expr::UnaryMinus(inner) => {
let value = evaluate_scalar(artifact, sheet_index, inner, stack)?;
Ok(EvalValue::Scalar(match value {
None => Some(SpreadsheetCellValue::Integer(0)),
Some(SpreadsheetCellValue::Integer(value)) => {
Some(SpreadsheetCellValue::Integer(-value))
}
Some(SpreadsheetCellValue::Float(value)) => {
Some(SpreadsheetCellValue::Float(-value))
}
Some(SpreadsheetCellValue::Error(value)) => {
Some(SpreadsheetCellValue::Error(value))
}
Some(_) => Some(SpreadsheetCellValue::Error("#VALUE!".to_string())),
}))
}
Expr::Binary { op, left, right } => {
let left = evaluate_scalar(artifact, sheet_index, left, stack)?;
let right = evaluate_scalar(artifact, sheet_index, right, stack)?;
Ok(EvalValue::Scalar(Some(apply_binary_op(*op, left, right)?)))
}
Expr::Function { name, args } => {
let mut numeric = Vec::new();
for arg in args {
match evaluate_expr(artifact, sheet_index, arg, stack)? {
EvalValue::Scalar(value) => {
if let Some(number) = scalar_to_number(value.clone())? {
numeric.push(number);
}
}
EvalValue::Range(values) => {
for value in values {
if let Some(number) = scalar_to_number(value.clone())? {
numeric.push(number);
}
}
}
}
}
let upper = name.to_ascii_uppercase();
let result = match upper.as_str() {
"SUM" => numeric.iter().sum::<f64>(),
"AVERAGE" => {
if numeric.is_empty() {
return Ok(EvalValue::Scalar(None));
}
numeric.iter().sum::<f64>() / numeric.len() as f64
}
"MIN" => numeric.iter().copied().reduce(f64::min).unwrap_or(0.0),
"MAX" => numeric.iter().copied().reduce(f64::max).unwrap_or(0.0),
_ => {
return Ok(EvalValue::Scalar(Some(SpreadsheetCellValue::Error(
"#NAME?".to_string(),
))));
}
};
Ok(EvalValue::Scalar(Some(number_to_value(result))))
}
}
}
fn evaluate_scalar(
artifact: &SpreadsheetArtifact,
sheet_index: usize,
expr: &Expr,
stack: &mut BTreeSet<(usize, CellAddress)>,
) -> Result<Option<SpreadsheetCellValue>, SpreadsheetArtifactError> {
match evaluate_expr(artifact, sheet_index, expr, stack)? {
EvalValue::Scalar(value) => Ok(value),
EvalValue::Range(_) => Err(SpreadsheetArtifactError::Formula {
location: format!("{expr:?}"),
message: "expected a scalar expression".to_string(),
}),
}
}
fn evaluate_cell_reference(
artifact: &SpreadsheetArtifact,
sheet_index: usize,
address: CellAddress,
stack: &mut BTreeSet<(usize, CellAddress)>,
) -> Result<EvalValue, SpreadsheetArtifactError> {
let Some(sheet) = artifact.sheets.get(sheet_index) else {
return Err(SpreadsheetArtifactError::Formula {
location: address.to_a1(),
message: "sheet index was not found".to_string(),
});
};
let key = (sheet_index, address);
if !stack.insert(key) {
return Ok(EvalValue::Scalar(Some(SpreadsheetCellValue::Error(
"#CYCLE!".to_string(),
))));
}
let value = if let Some(cell) = sheet.get_cell(address) {
if let Some(formula) = &cell.formula {
evaluate_formula(artifact, sheet_index, formula, stack)?
} else {
cell.value.clone()
}
} else {
None
};
stack.remove(&key);
Ok(EvalValue::Scalar(value))
}
fn apply_binary_op(
op: BinaryOp,
left: Option<SpreadsheetCellValue>,
right: Option<SpreadsheetCellValue>,
) -> Result<SpreadsheetCellValue, SpreadsheetArtifactError> {
if let Some(SpreadsheetCellValue::Error(value)) = &left {
return Ok(SpreadsheetCellValue::Error(value.clone()));
}
if let Some(SpreadsheetCellValue::Error(value)) = &right {
return Ok(SpreadsheetCellValue::Error(value.clone()));
}
let left = scalar_to_number(left)?;
let right = scalar_to_number(right)?;
let left = left.unwrap_or(0.0);
let right = right.unwrap_or(0.0);
let result = match op {
BinaryOp::Add => left + right,
BinaryOp::Subtract => left - right,
BinaryOp::Multiply => left * right,
BinaryOp::Divide => {
if right == 0.0 {
return Ok(SpreadsheetCellValue::Error("#DIV/0!".to_string()));
}
left / right
}
};
Ok(number_to_value(result))
}
fn scalar_to_number(
value: Option<SpreadsheetCellValue>,
) -> Result<Option<f64>, SpreadsheetArtifactError> {
match value {
None => Ok(None),
Some(SpreadsheetCellValue::Integer(value)) => Ok(Some(value as f64)),
Some(SpreadsheetCellValue::Float(value)) => Ok(Some(value)),
Some(SpreadsheetCellValue::Bool(value)) => Ok(Some(if value { 1.0 } else { 0.0 })),
Some(SpreadsheetCellValue::Error(value)) => Err(SpreadsheetArtifactError::Formula {
location: value,
message: "encountered error value".to_string(),
}),
Some(other) => Err(SpreadsheetArtifactError::Formula {
location: format!("{other:?}"),
message: "value is not numeric".to_string(),
}),
}
}
fn number_to_value(number: f64) -> SpreadsheetCellValue {
if number.fract() == 0.0 {
SpreadsheetCellValue::Integer(number as i64)
} else {
SpreadsheetCellValue::Float(number)
}
}
fn map_error_to_code(error: &SpreadsheetArtifactError) -> String {
match error {
SpreadsheetArtifactError::Formula { message, .. } => {
if message.contains("cycle") {
"#CYCLE!".to_string()
} else if message.contains("not numeric") || message.contains("scalar") {
"#VALUE!".to_string()
} else {
"#ERROR!".to_string()
}
}
SpreadsheetArtifactError::InvalidAddress { .. } => "#REF!".to_string(),
_ => "#ERROR!".to_string(),
}
}
fn tokenize(source: &str) -> Result<Vec<Token>, SpreadsheetArtifactError> {
let chars = source.chars().collect::<Vec<_>>();
let mut index = 0usize;
let mut tokens = Vec::new();
while index < chars.len() {
let ch = chars[index];
if ch.is_ascii_whitespace() {
index += 1;
continue;
}
match ch {
'+' => {
tokens.push(Token::Plus);
index += 1;
}
'-' => {
tokens.push(Token::Minus);
index += 1;
}
'*' => {
tokens.push(Token::Star);
index += 1;
}
'/' => {
tokens.push(Token::Slash);
index += 1;
}
'(' => {
tokens.push(Token::LParen);
index += 1;
}
')' => {
tokens.push(Token::RParen);
index += 1;
}
':' => {
tokens.push(Token::Colon);
index += 1;
}
',' => {
tokens.push(Token::Comma);
index += 1;
}
'0'..='9' | '.' => {
let start = index;
index += 1;
while index < chars.len() && (chars[index].is_ascii_digit() || chars[index] == '.')
{
index += 1;
}
let number = source[start..index].parse::<f64>().map_err(|_| {
SpreadsheetArtifactError::Formula {
location: source.to_string(),
message: "invalid numeric literal".to_string(),
}
})?;
tokens.push(Token::Number(number));
}
'A'..='Z' | 'a'..='z' | '_' => {
let start = index;
index += 1;
while index < chars.len()
&& (chars[index].is_ascii_alphanumeric() || chars[index] == '_')
{
index += 1;
}
let text = source[start..index].to_string();
if text.chars().any(|part| part.is_ascii_digit())
&& text.chars().any(|part| part.is_ascii_alphabetic())
{
tokens.push(Token::Cell(text));
} else {
tokens.push(Token::Ident(text));
}
}
other => {
return Err(SpreadsheetArtifactError::Formula {
location: source.to_string(),
message: format!("unsupported token `{other}`"),
});
}
}
}
Ok(tokens)
}
struct Parser {
tokens: Vec<Token>,
index: usize,
}
impl Parser {
fn new(tokens: Vec<Token>) -> Self {
Self { tokens, index: 0 }
}
fn has_remaining(&self) -> bool {
self.index < self.tokens.len()
}
fn parse_expression(&mut self) -> Result<Expr, SpreadsheetArtifactError> {
let mut expr = self.parse_term()?;
while let Some(token) = self.peek() {
let op = match token {
Token::Plus => BinaryOp::Add,
Token::Minus => BinaryOp::Subtract,
_ => break,
};
self.index += 1;
let right = self.parse_term()?;
expr = Expr::Binary {
op,
left: Box::new(expr),
right: Box::new(right),
};
}
Ok(expr)
}
fn parse_term(&mut self) -> Result<Expr, SpreadsheetArtifactError> {
let mut expr = self.parse_factor()?;
while let Some(token) = self.peek() {
let op = match token {
Token::Star => BinaryOp::Multiply,
Token::Slash => BinaryOp::Divide,
_ => break,
};
self.index += 1;
let right = self.parse_factor()?;
expr = Expr::Binary {
op,
left: Box::new(expr),
right: Box::new(right),
};
}
Ok(expr)
}
fn parse_factor(&mut self) -> Result<Expr, SpreadsheetArtifactError> {
match self.peek() {
Some(Token::Minus) => {
self.index += 1;
Ok(Expr::UnaryMinus(Box::new(self.parse_factor()?)))
}
_ => self.parse_primary(),
}
}
fn parse_primary(&mut self) -> Result<Expr, SpreadsheetArtifactError> {
match self.next().cloned() {
Some(Token::Number(value)) => Ok(Expr::Number(value)),
Some(Token::Cell(address)) => {
let start = CellAddress::parse(&address)?;
if matches!(self.peek(), Some(Token::Colon)) {
self.index += 1;
let Some(Token::Cell(end)) = self.next().cloned() else {
return Err(SpreadsheetArtifactError::Formula {
location: address,
message: "expected cell after `:`".to_string(),
});
};
Ok(Expr::Range(CellRange::from_start_end(
start,
CellAddress::parse(&end)?,
)))
} else {
Ok(Expr::Cell(start))
}
}
Some(Token::Ident(name)) => {
if !matches!(self.next(), Some(Token::LParen)) {
return Err(SpreadsheetArtifactError::Formula {
location: name,
message: "expected `(` after function name".to_string(),
});
}
let mut args = Vec::new();
if !matches!(self.peek(), Some(Token::RParen)) {
loop {
args.push(self.parse_expression()?);
if matches!(self.peek(), Some(Token::Comma)) {
self.index += 1;
continue;
}
break;
}
}
if !matches!(self.next(), Some(Token::RParen)) {
return Err(SpreadsheetArtifactError::Formula {
location: name,
message: "expected `)`".to_string(),
});
}
Ok(Expr::Function { name, args })
}
Some(Token::LParen) => {
let expr = self.parse_expression()?;
if !matches!(self.next(), Some(Token::RParen)) {
return Err(SpreadsheetArtifactError::Formula {
location: format!("{expr:?}"),
message: "expected `)`".to_string(),
});
}
Ok(expr)
}
other => Err(SpreadsheetArtifactError::Formula {
location: format!("{other:?}"),
message: "unexpected token".to_string(),
}),
}
}
fn peek(&self) -> Option<&Token> {
self.tokens.get(self.index)
}
fn next(&mut self) -> Option<&Token> {
let token = self.tokens.get(self.index);
self.index += usize::from(token.is_some());
token
}
}
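The `Parser` above encodes operator precedence in its call chain: `parse_expression` handles `+`/`-`, `parse_term` handles `*`/`/`, and `parse_factor` handles unary minus before falling through to primaries. The same ladder can be illustrated with a stripped-down numeric evaluator (an assumption-laden sketch: numbers, parentheses, and the four operators only, with no cells, ranges, functions, or error handling):

```rust
// Minimal recursive-descent evaluator over the same precedence ladder as
// the `Parser` in this file. Parsing and evaluation are fused for brevity.
struct Eval<'a> {
    bytes: &'a [u8],
    pos: usize,
}

impl Eval<'_> {
    fn peek(&self) -> Option<u8> {
        self.bytes.get(self.pos).copied()
    }
    // Lowest precedence: addition and subtraction.
    fn expression(&mut self) -> f64 {
        let mut value = self.term();
        while let Some(op @ (b'+' | b'-')) = self.peek() {
            self.pos += 1;
            let rhs = self.term();
            value = if op == b'+' { value + rhs } else { value - rhs };
        }
        value
    }
    // Binds tighter: multiplication and division.
    fn term(&mut self) -> f64 {
        let mut value = self.factor();
        while let Some(op @ (b'*' | b'/')) = self.peek() {
            self.pos += 1;
            let rhs = self.factor();
            value = if op == b'*' { value * rhs } else { value / rhs };
        }
        value
    }
    // Tightest: unary minus, parentheses, numeric literals.
    fn factor(&mut self) -> f64 {
        match self.peek() {
            Some(b'-') => {
                self.pos += 1;
                -self.factor()
            }
            Some(b'(') => {
                self.pos += 1;
                let value = self.expression();
                self.pos += 1; // consume `)`
                value
            }
            _ => {
                let start = self.pos;
                while matches!(self.peek(), Some(b'0'..=b'9' | b'.')) {
                    self.pos += 1;
                }
                std::str::from_utf8(&self.bytes[start..self.pos])
                    .unwrap()
                    .parse()
                    .unwrap()
            }
        }
    }
}

fn eval(src: &str) -> f64 {
    Eval { bytes: src.as_bytes(), pos: 0 }.expression()
}

fn main() {
    assert_eq!(eval("1+2*3"), 7.0);   // `*` binds tighter than `+`
    assert_eq!(eval("(1+2)*3"), 9.0); // parentheses override precedence
    assert_eq!(eval("-2*4"), -8.0);   // unary minus is a factor
}
```

The real engine separates parsing (`Expr`) from evaluation so that cell references can be resolved against the workbook, with a visited-set (`stack`) guarding against reference cycles.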


@@ -0,0 +1,26 @@
mod address;
mod chart;
mod conditional;
mod error;
mod formula;
mod manager;
mod model;
mod pivot;
mod render;
mod style;
mod table;
mod xlsx;
#[cfg(test)]
mod tests;
pub use address::*;
pub use chart::*;
pub use conditional::*;
pub use error::*;
pub use manager::*;
pub use model::*;
pub use pivot::*;
pub use render::*;
pub use style::*;
pub use table::*;

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,177 @@
use std::collections::BTreeMap;
use serde::Deserialize;
use serde::Serialize;
use crate::CellRange;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetCellRangeRef;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotFieldItem {
pub item_type: Option<String>,
pub index: Option<u32>,
pub hidden: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotField {
pub index: u32,
pub name: Option<String>,
pub axis: Option<String>,
#[serde(default)]
pub items: Vec<SpreadsheetPivotFieldItem>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotFieldReference {
pub field_index: u32,
pub field_name: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotPageField {
pub field_index: u32,
pub field_name: Option<String>,
pub selected_item: Option<u32>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotDataField {
pub field_index: u32,
pub field_name: Option<String>,
pub name: Option<String>,
pub subtotal: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotFilter {
pub field_index: Option<u32>,
pub field_name: Option<String>,
pub filter_type: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotTable {
pub name: String,
pub cache_id: u32,
pub address: Option<String>,
#[serde(default)]
pub row_fields: Vec<SpreadsheetPivotFieldReference>,
#[serde(default)]
pub column_fields: Vec<SpreadsheetPivotFieldReference>,
#[serde(default)]
pub page_fields: Vec<SpreadsheetPivotPageField>,
#[serde(default)]
pub data_fields: Vec<SpreadsheetPivotDataField>,
#[serde(default)]
pub filters: Vec<SpreadsheetPivotFilter>,
#[serde(default)]
pub pivot_fields: Vec<SpreadsheetPivotField>,
pub style_name: Option<String>,
pub part_path: Option<String>,
}
#[derive(Debug, Clone, Default)]
pub struct SpreadsheetPivotTableLookup<'a> {
pub name: Option<&'a str>,
pub index: Option<usize>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetPivotCacheDefinition {
pub definition_path: String,
#[serde(default)]
pub field_names: Vec<Option<String>>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetPivotPreservation {
#[serde(default)]
pub caches: BTreeMap<u32, SpreadsheetPivotCacheDefinition>,
#[serde(default)]
pub parts: BTreeMap<String, String>,
}
impl SpreadsheetPivotTable {
pub fn range(&self) -> Result<Option<CellRange>, SpreadsheetArtifactError> {
self.address.as_deref().map(CellRange::parse).transpose()
}
pub fn range_ref(
&self,
sheet_name: &str,
) -> Result<Option<SpreadsheetCellRangeRef>, SpreadsheetArtifactError> {
Ok(self
.range()?
.map(|range| SpreadsheetCellRangeRef::new(sheet_name.to_string(), &range)))
}
}
impl SpreadsheetSheet {
pub fn list_pivot_tables(
&self,
range: Option<&CellRange>,
) -> Result<Vec<SpreadsheetPivotTable>, SpreadsheetArtifactError> {
Ok(self
.pivot_tables
.iter()
.filter(|pivot_table| {
range.is_none_or(|target| {
pivot_table
.range()
.ok()
.flatten()
.is_some_and(|pivot_range| pivot_range.intersects(target))
})
})
.cloned()
.collect())
}
pub fn get_pivot_table(
&self,
action: &str,
lookup: SpreadsheetPivotTableLookup,
) -> Result<&SpreadsheetPivotTable, SpreadsheetArtifactError> {
if let Some(name) = lookup.name {
return self
.pivot_tables
.iter()
.find(|pivot_table| pivot_table.name == name)
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("pivot table `{name}` was not found"),
});
}
if let Some(index) = lookup.index {
return self.pivot_tables.get(index).ok_or_else(|| {
SpreadsheetArtifactError::IndexOutOfRange {
action: action.to_string(),
index,
len: self.pivot_tables.len(),
}
});
}
Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "pivot table name or index is required".to_string(),
})
}
pub fn validate_pivot_tables(&self, action: &str) -> Result<(), SpreadsheetArtifactError> {
for pivot_table in &self.pivot_tables {
if pivot_table.name.trim().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "pivot table name cannot be empty".to_string(),
});
}
if let Some(address) = &pivot_table.address {
CellRange::parse(address)?;
}
}
Ok(())
}
}
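`get_pivot_table` follows the same lookup precedence as the chart lookups earlier in this change: an explicit name (or id) wins, then a positional index, and supplying neither is an error. A standalone sketch of that precedence over plain string slices (hypothetical `lookup` helper, simplified error type):

```rust
// Name-or-index lookup precedence shared by the chart and pivot-table
// lookup methods, reduced to string slices and a String error.
fn lookup<'a>(
    items: &'a [&'a str],
    name: Option<&str>,
    index: Option<usize>,
) -> Result<&'a str, String> {
    if let Some(name) = name {
        return items
            .iter()
            .find(|&&item| item == name)
            .copied()
            .ok_or_else(|| format!("`{name}` was not found"));
    }
    if let Some(index) = index {
        return items
            .get(index)
            .copied()
            .ok_or_else(|| format!("index {index} out of range (len={})", items.len()));
    }
    Err("name or index is required".to_string())
}

fn main() {
    let pivots = ["Sales", "Costs"];
    assert_eq!(lookup(&pivots, Some("Costs"), None), Ok("Costs"));
    assert_eq!(lookup(&pivots, None, Some(0)), Ok("Sales"));
    // Name takes precedence when both are supplied.
    assert_eq!(lookup(&pivots, Some("Sales"), Some(1)), Ok("Sales"));
    assert!(lookup(&pivots, None, None).is_err());
}
```

Keeping the index branch behind the name branch means a stale index silently loses to a matching name, which is the safer failure mode for callers that pass both.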


@@ -0,0 +1,373 @@
use std::fs;
use std::path::Path;
use std::path::PathBuf;
use serde::Deserialize;
use serde::Serialize;
use crate::CellAddress;
use crate::CellRange;
use crate::SpreadsheetArtifact;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetRenderOptions {
pub output_path: Option<PathBuf>,
pub center_address: Option<String>,
pub width: Option<u32>,
pub height: Option<u32>,
pub include_headers: bool,
pub scale: f64,
pub performance_mode: bool,
}
impl Default for SpreadsheetRenderOptions {
fn default() -> Self {
Self {
output_path: None,
center_address: None,
width: None,
height: None,
include_headers: true,
scale: 1.0,
performance_mode: false,
}
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpreadsheetRenderedOutput {
pub path: PathBuf,
pub html: String,
}
impl SpreadsheetSheet {
pub fn render_html(
&self,
range: Option<&CellRange>,
options: &SpreadsheetRenderOptions,
) -> Result<String, SpreadsheetArtifactError> {
let center = options
.center_address
.as_deref()
.map(CellAddress::parse)
.transpose()?;
let viewport = render_viewport(self, range, center, options)?;
let title = range
.map(CellRange::to_a1)
.unwrap_or_else(|| self.name.clone());
Ok(format!(
concat!(
"<!doctype html><html><head><meta charset=\"utf-8\">",
"<title>{}</title>",
"<style>{}</style>",
"</head><body>",
"<section class=\"spreadsheet-preview\" data-sheet=\"{}\" data-performance-mode=\"{}\">",
"<header><h1>{}</h1><p>{}</p></header>",
"<div class=\"viewport\" style=\"{}\">",
"<table>{}</table>",
"</div></section></body></html>"
),
html_escape(&title),
preview_css(),
html_escape(&self.name),
options.performance_mode,
html_escape(&title),
html_escape(&viewport.to_a1()),
viewport_style(options),
render_table(self, &viewport, options),
))
}
}
impl SpreadsheetArtifact {
pub fn render_workbook_previews(
&self,
cwd: &Path,
options: &SpreadsheetRenderOptions,
) -> Result<Vec<SpreadsheetRenderedOutput>, SpreadsheetArtifactError> {
let sheets = if self.sheets.is_empty() {
vec![SpreadsheetSheet::new("Sheet1".to_string())]
} else {
self.sheets.clone()
};
let output_paths = workbook_output_paths(self, cwd, options, &sheets);
sheets
.iter()
.zip(output_paths)
.map(|(sheet, path)| {
let html = sheet.render_html(None, options)?;
write_rendered_output(&path, &html)?;
Ok(SpreadsheetRenderedOutput { path, html })
})
.collect()
}
pub fn render_sheet_preview(
&self,
cwd: &Path,
sheet: &SpreadsheetSheet,
options: &SpreadsheetRenderOptions,
) -> Result<SpreadsheetRenderedOutput, SpreadsheetArtifactError> {
let path = single_output_path(
cwd,
self,
options.output_path.as_deref(),
&format!("render_{}", sanitize_file_component(&sheet.name)),
);
let html = sheet.render_html(None, options)?;
write_rendered_output(&path, &html)?;
Ok(SpreadsheetRenderedOutput { path, html })
}
pub fn render_range_preview(
&self,
cwd: &Path,
sheet: &SpreadsheetSheet,
range: &CellRange,
options: &SpreadsheetRenderOptions,
) -> Result<SpreadsheetRenderedOutput, SpreadsheetArtifactError> {
let path = single_output_path(
cwd,
self,
options.output_path.as_deref(),
&format!(
"render_{}_{}",
sanitize_file_component(&sheet.name),
sanitize_file_component(&range.to_a1())
),
);
let html = sheet.render_html(Some(range), options)?;
write_rendered_output(&path, &html)?;
Ok(SpreadsheetRenderedOutput { path, html })
}
}
fn render_viewport(
sheet: &SpreadsheetSheet,
range: Option<&CellRange>,
center: Option<CellAddress>,
options: &SpreadsheetRenderOptions,
) -> Result<CellRange, SpreadsheetArtifactError> {
let base = range
.cloned()
.or_else(|| sheet.minimum_range())
.unwrap_or_else(|| {
CellRange::from_start_end(
CellAddress { column: 1, row: 1 },
CellAddress { column: 1, row: 1 },
)
});
let Some(center) = center else {
return Ok(base);
};
let visible_columns = options
.width
.map(|width| estimated_visible_count(width, 96.0, options.scale))
.unwrap_or(base.width() as u32);
let visible_rows = options
.height
.map(|height| estimated_visible_count(height, 28.0, options.scale))
.unwrap_or(base.height() as u32);
let half_columns = visible_columns / 2;
let half_rows = visible_rows / 2;
let start_column = center
.column
.saturating_sub(half_columns)
.max(base.start.column);
let start_row = center.row.saturating_sub(half_rows).max(base.start.row);
let end_column = (start_column + visible_columns.saturating_sub(1)).min(base.end.column);
let end_row = (start_row + visible_rows.saturating_sub(1)).min(base.end.row);
Ok(CellRange::from_start_end(
CellAddress {
column: start_column,
row: start_row,
},
CellAddress {
column: end_column.max(start_column),
row: end_row.max(start_row),
},
))
}
fn estimated_visible_count(dimension: u32, cell_size: f64, scale: f64) -> u32 {
((dimension as f64 / (cell_size * scale.max(0.1))).floor() as u32).max(1)
}
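`estimated_visible_count` converts a pixel dimension into a cell count, clamping the scale to at least 0.1 (so a zero scale cannot divide away the viewport) and the result to at least one cell. A standalone copy with a few spot checks:

```rust
// Standalone copy of the viewport sizing helper above.
fn estimated_visible_count(dimension: u32, cell_size: f64, scale: f64) -> u32 {
    ((dimension as f64 / (cell_size * scale.max(0.1))).floor() as u32).max(1)
}

fn main() {
    assert_eq!(estimated_visible_count(960, 96.0, 1.0), 10); // 960px / 96px per column
    assert_eq!(estimated_visible_count(960, 96.0, 0.5), 20); // zooming out fits more cells
    assert_eq!(estimated_visible_count(10, 96.0, 1.0), 1);   // never below one cell
}
```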
fn render_table(
sheet: &SpreadsheetSheet,
range: &CellRange,
options: &SpreadsheetRenderOptions,
) -> String {
let mut rows = Vec::new();
if options.include_headers {
let mut header = vec!["<tr><th class=\"corner\"></th>".to_string()];
for column in range.start.column..=range.end.column {
header.push(format!(
"<th>{}</th>",
crate::column_index_to_letters(column)
));
}
header.push("</tr>".to_string());
rows.push(header.join(""));
}
for row in range.start.row..=range.end.row {
let mut cells = Vec::new();
if options.include_headers {
cells.push(format!("<th>{row}</th>"));
}
for column in range.start.column..=range.end.column {
let address = CellAddress { column, row };
let view = sheet.get_cell_view(address);
let value = view
.data
.as_ref()
.map(render_data_value)
.unwrap_or_default();
cells.push(format!(
"<td data-address=\"{}\" data-style-index=\"{}\">{}</td>",
address.to_a1(),
view.style_index,
html_escape(&value)
));
}
rows.push(format!("<tr>{}</tr>", cells.join("")));
}
rows.join("")
}
fn render_data_value(value: &serde_json::Value) -> String {
match value {
serde_json::Value::String(value) => value.clone(),
serde_json::Value::Bool(value) => value.to_string(),
serde_json::Value::Number(value) => value.to_string(),
serde_json::Value::Null => String::new(),
other => other.to_string(),
}
}
fn viewport_style(options: &SpreadsheetRenderOptions) -> String {
let mut style = vec![
format!("--scale: {}", options.scale.max(0.1)),
format!(
"--headers: {}",
if options.include_headers { "1" } else { "0" }
),
];
if let Some(width) = options.width {
style.push(format!("width: {width}px"));
}
if let Some(height) = options.height {
style.push(format!("height: {height}px"));
}
style.push("overflow: auto".to_string());
style.join("; ")
}
fn preview_css() -> &'static str {
concat!(
"body{margin:0;padding:24px;background:#f5f3ee;color:#1e1e1e;font-family:Georgia,serif;}",
".spreadsheet-preview{display:flex;flex-direction:column;gap:16px;}",
"header h1{margin:0;font-size:24px;}header p{margin:0;color:#6b6257;font-size:13px;}",
".viewport{border:1px solid #d6d0c7;background:#fff;box-shadow:0 12px 30px rgba(0,0,0,.08);}",
"table{border-collapse:collapse;transform:scale(var(--scale));transform-origin:top left;}",
"th,td{border:1px solid #ddd3c6;padding:6px 10px;min-width:72px;max-width:240px;font-size:13px;text-align:left;vertical-align:top;}",
"th{background:#f0ebe3;font-weight:600;position:sticky;top:0;z-index:1;}",
".corner{background:#e7e0d6;left:0;z-index:2;}",
"td{white-space:pre-wrap;}"
)
}
fn write_rendered_output(path: &Path, html: &str) -> Result<(), SpreadsheetArtifactError> {
if let Some(parent) = path.parent() {
fs::create_dir_all(parent).map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
}
fs::write(path, html).map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})
}
fn workbook_output_paths(
artifact: &SpreadsheetArtifact,
cwd: &Path,
options: &SpreadsheetRenderOptions,
sheets: &[SpreadsheetSheet],
) -> Vec<PathBuf> {
if let Some(output_path) = options.output_path.as_deref() {
if output_path.extension().is_some_and(|ext| ext == "html") {
let stem = output_path
.file_stem()
.and_then(|value| value.to_str())
.unwrap_or("render");
let parent = output_path.parent().unwrap_or(cwd);
return sheets
.iter()
.map(|sheet| {
parent.join(format!(
"{}_{}.html",
stem,
sanitize_file_component(&sheet.name)
))
})
.collect();
}
return sheets
.iter()
.map(|sheet| output_path.join(format!("{}.html", sanitize_file_component(&sheet.name))))
.collect();
}
sheets
.iter()
.map(|sheet| {
cwd.join(format!(
"{}_render_{}.html",
artifact.artifact_id,
sanitize_file_component(&sheet.name)
))
})
.collect()
}
fn single_output_path(
cwd: &Path,
artifact: &SpreadsheetArtifact,
output_path: Option<&Path>,
suffix: &str,
) -> PathBuf {
if let Some(output_path) = output_path {
return if output_path.extension().is_some_and(|ext| ext == "html") {
output_path.to_path_buf()
} else {
output_path.join(format!("{suffix}.html"))
};
}
cwd.join(format!("{}_{}.html", artifact.artifact_id, suffix))
}
fn sanitize_file_component(value: &str) -> String {
value
.chars()
.map(|character| {
if character.is_ascii_alphanumeric() {
character
} else {
'_'
}
})
.collect()
}
fn html_escape(value: &str) -> String {
value
.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&#39;")
}
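The replacement order in `html_escape` is load-bearing: `&` must be escaped first, otherwise the `&` characters introduced by the later replacements (`&lt;`, `&gt;`, ...) would themselves be double-escaped. A standalone copy demonstrating the combined result:

```rust
// Standalone copy of `html_escape` above; `&` is replaced first so the
// entities emitted by later rules are not re-escaped.
fn html_escape(value: &str) -> String {
    value
        .replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&#39;")
}

fn main() {
    assert_eq!(
        html_escape("<b>\"A & B\"</b>"),
        "&lt;b&gt;&quot;A &amp; B&quot;&lt;/b&gt;"
    );
}
```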


@@ -0,0 +1,580 @@
use std::collections::BTreeMap;
use serde::Deserialize;
use serde::Serialize;
use crate::CellRange;
use crate::SpreadsheetArtifact;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetCellRangeRef;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetFontFace {
pub font_family: Option<String>,
pub font_scheme: Option<String>,
pub typeface: Option<String>,
}
impl SpreadsheetFontFace {
fn merge(&self, patch: &Self) -> Self {
Self {
font_family: patch
.font_family
.clone()
.or_else(|| self.font_family.clone()),
font_scheme: patch
.font_scheme
.clone()
.or_else(|| self.font_scheme.clone()),
typeface: patch.typeface.clone().or_else(|| self.typeface.clone()),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct SpreadsheetTextStyle {
pub bold: Option<bool>,
pub italic: Option<bool>,
pub underline: Option<bool>,
pub font_size: Option<f64>,
pub font_color: Option<String>,
pub text_alignment: Option<String>,
pub anchor: Option<String>,
pub vertical_text_orientation: Option<String>,
pub text_rotation: Option<i32>,
pub paragraph_spacing: Option<bool>,
pub bottom_inset: Option<f64>,
pub left_inset: Option<f64>,
pub right_inset: Option<f64>,
pub top_inset: Option<f64>,
pub font_family: Option<String>,
pub font_scheme: Option<String>,
pub typeface: Option<String>,
pub font_face: Option<SpreadsheetFontFace>,
}
impl SpreadsheetTextStyle {
fn merge(&self, patch: &Self) -> Self {
Self {
bold: patch.bold.or(self.bold),
italic: patch.italic.or(self.italic),
underline: patch.underline.or(self.underline),
font_size: patch.font_size.or(self.font_size),
font_color: patch.font_color.clone().or_else(|| self.font_color.clone()),
text_alignment: patch
.text_alignment
.clone()
.or_else(|| self.text_alignment.clone()),
anchor: patch.anchor.clone().or_else(|| self.anchor.clone()),
vertical_text_orientation: patch
.vertical_text_orientation
.clone()
.or_else(|| self.vertical_text_orientation.clone()),
text_rotation: patch.text_rotation.or(self.text_rotation),
paragraph_spacing: patch.paragraph_spacing.or(self.paragraph_spacing),
bottom_inset: patch.bottom_inset.or(self.bottom_inset),
left_inset: patch.left_inset.or(self.left_inset),
right_inset: patch.right_inset.or(self.right_inset),
top_inset: patch.top_inset.or(self.top_inset),
font_family: patch
.font_family
.clone()
.or_else(|| self.font_family.clone()),
font_scheme: patch
.font_scheme
.clone()
.or_else(|| self.font_scheme.clone()),
typeface: patch.typeface.clone().or_else(|| self.typeface.clone()),
font_face: match (&self.font_face, &patch.font_face) {
(Some(base), Some(update)) => Some(base.merge(update)),
(None, Some(update)) => Some(update.clone()),
(Some(base), None) => Some(base.clone()),
(None, None) => None,
},
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetGradientStop {
pub position: f64,
pub color: String,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetFillRectangle {
pub left: f64,
pub right: f64,
pub top: f64,
pub bottom: f64,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct SpreadsheetFill {
pub solid_fill_color: Option<String>,
pub pattern_type: Option<String>,
pub pattern_foreground_color: Option<String>,
pub pattern_background_color: Option<String>,
#[serde(default)]
pub color_transforms: Vec<String>,
pub gradient_fill_type: Option<String>,
#[serde(default)]
pub gradient_stops: Vec<SpreadsheetGradientStop>,
pub gradient_kind: Option<String>,
pub angle: Option<f64>,
pub scaled: Option<bool>,
pub path_type: Option<String>,
pub fill_rectangle: Option<SpreadsheetFillRectangle>,
pub image_reference: Option<String>,
}
impl SpreadsheetFill {
fn merge(&self, patch: &Self) -> Self {
Self {
solid_fill_color: patch
.solid_fill_color
.clone()
.or_else(|| self.solid_fill_color.clone()),
pattern_type: patch
.pattern_type
.clone()
.or_else(|| self.pattern_type.clone()),
pattern_foreground_color: patch
.pattern_foreground_color
.clone()
.or_else(|| self.pattern_foreground_color.clone()),
pattern_background_color: patch
.pattern_background_color
.clone()
.or_else(|| self.pattern_background_color.clone()),
color_transforms: if patch.color_transforms.is_empty() {
self.color_transforms.clone()
} else {
patch.color_transforms.clone()
},
gradient_fill_type: patch
.gradient_fill_type
.clone()
.or_else(|| self.gradient_fill_type.clone()),
gradient_stops: if patch.gradient_stops.is_empty() {
self.gradient_stops.clone()
} else {
patch.gradient_stops.clone()
},
gradient_kind: patch
.gradient_kind
.clone()
.or_else(|| self.gradient_kind.clone()),
angle: patch.angle.or(self.angle),
scaled: patch.scaled.or(self.scaled),
path_type: patch.path_type.clone().or_else(|| self.path_type.clone()),
fill_rectangle: patch
.fill_rectangle
.clone()
.or_else(|| self.fill_rectangle.clone()),
image_reference: patch
.image_reference
.clone()
.or_else(|| self.image_reference.clone()),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetBorderLine {
pub style: Option<String>,
pub color: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetBorder {
pub top: Option<SpreadsheetBorderLine>,
pub right: Option<SpreadsheetBorderLine>,
pub bottom: Option<SpreadsheetBorderLine>,
pub left: Option<SpreadsheetBorderLine>,
}
impl SpreadsheetBorder {
fn merge(&self, patch: &Self) -> Self {
Self {
top: patch.top.clone().or_else(|| self.top.clone()),
right: patch.right.clone().or_else(|| self.right.clone()),
bottom: patch.bottom.clone().or_else(|| self.bottom.clone()),
left: patch.left.clone().or_else(|| self.left.clone()),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetAlignment {
pub horizontal: Option<String>,
pub vertical: Option<String>,
}
impl SpreadsheetAlignment {
fn merge(&self, patch: &Self) -> Self {
Self {
horizontal: patch.horizontal.clone().or_else(|| self.horizontal.clone()),
vertical: patch.vertical.clone().or_else(|| self.vertical.clone()),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct SpreadsheetNumberFormat {
pub format_id: Option<u32>,
pub format_code: Option<String>,
}
impl SpreadsheetNumberFormat {
fn merge(&self, patch: &Self) -> Self {
Self {
format_id: patch.format_id.or(self.format_id),
format_code: patch
.format_code
.clone()
.or_else(|| self.format_code.clone()),
}
}
fn normalized(mut self) -> Self {
if self.format_code.is_none() {
self.format_code = self.format_id.and_then(builtin_number_format_code);
}
self
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct SpreadsheetCellFormat {
pub text_style_id: Option<u32>,
pub fill_id: Option<u32>,
pub border_id: Option<u32>,
pub alignment: Option<SpreadsheetAlignment>,
pub number_format_id: Option<u32>,
pub wrap_text: Option<bool>,
pub base_cell_style_format_id: Option<u32>,
}
impl SpreadsheetCellFormat {
pub fn wrap(mut self) -> Self {
self.wrap_text = Some(true);
self
}
pub fn unwrap(mut self) -> Self {
self.wrap_text = Some(false);
self
}
fn merge(&self, patch: &Self) -> Self {
Self {
text_style_id: patch.text_style_id.or(self.text_style_id),
fill_id: patch.fill_id.or(self.fill_id),
border_id: patch.border_id.or(self.border_id),
alignment: match (&self.alignment, &patch.alignment) {
(Some(base), Some(update)) => Some(base.merge(update)),
(None, Some(update)) => Some(update.clone()),
(Some(base), None) => Some(base.clone()),
(None, None) => None,
},
number_format_id: patch.number_format_id.or(self.number_format_id),
wrap_text: patch.wrap_text.or(self.wrap_text),
base_cell_style_format_id: patch
.base_cell_style_format_id
.or(self.base_cell_style_format_id),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct SpreadsheetDifferentialFormat {
pub text_style_id: Option<u32>,
pub fill_id: Option<u32>,
pub border_id: Option<u32>,
pub alignment: Option<SpreadsheetAlignment>,
pub number_format_id: Option<u32>,
pub wrap_text: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetCellFormatSummary {
pub style_index: u32,
pub text_style: Option<SpreadsheetTextStyle>,
pub fill: Option<SpreadsheetFill>,
pub border: Option<SpreadsheetBorder>,
pub alignment: Option<SpreadsheetAlignment>,
pub number_format: Option<SpreadsheetNumberFormat>,
pub wrap_text: Option<bool>,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetRangeFormat {
pub sheet_name: String,
pub range: String,
}
impl SpreadsheetRangeFormat {
pub fn new(sheet_name: String, range: &CellRange) -> Self {
Self {
sheet_name,
range: range.to_a1(),
}
}
pub fn range_ref(&self) -> Result<SpreadsheetCellRangeRef, SpreadsheetArtifactError> {
let range = CellRange::parse(&self.range)?;
Ok(SpreadsheetCellRangeRef::new(
self.sheet_name.clone(),
&range,
))
}
pub fn top_left_style_index(
&self,
sheet: &SpreadsheetSheet,
) -> Result<u32, SpreadsheetArtifactError> {
self.range_ref()?.top_left_style_index(sheet)
}
pub fn top_left_cell_format(
&self,
artifact: &SpreadsheetArtifact,
sheet: &SpreadsheetSheet,
) -> Result<Option<SpreadsheetCellFormatSummary>, SpreadsheetArtifactError> {
let range = self.range_ref()?.range()?;
Ok(artifact.cell_format_summary(sheet.top_left_style_index(&range)))
}
}
impl SpreadsheetArtifact {
pub fn create_text_style(
&mut self,
style: SpreadsheetTextStyle,
source_style_id: Option<u32>,
merge_with_existing_components: bool,
) -> Result<u32, SpreadsheetArtifactError> {
let created = if let Some(source_style_id) = source_style_id {
let source = self
.text_styles
.get(&source_style_id)
.cloned()
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: "create_text_style".to_string(),
message: format!("text style `{source_style_id}` was not found"),
})?;
if merge_with_existing_components {
source.merge(&style)
} else {
style
}
} else {
style
};
Ok(insert_with_next_id(&mut self.text_styles, created))
}
pub fn get_text_style(&self, style_id: u32) -> Option<&SpreadsheetTextStyle> {
self.text_styles.get(&style_id)
}
pub fn create_fill(
&mut self,
fill: SpreadsheetFill,
source_fill_id: Option<u32>,
merge_with_existing_components: bool,
) -> Result<u32, SpreadsheetArtifactError> {
let created = if let Some(source_fill_id) = source_fill_id {
let source = self.fills.get(&source_fill_id).cloned().ok_or_else(|| {
SpreadsheetArtifactError::InvalidArgs {
action: "create_fill".to_string(),
message: format!("fill `{source_fill_id}` was not found"),
}
})?;
if merge_with_existing_components {
source.merge(&fill)
} else {
fill
}
} else {
fill
};
Ok(insert_with_next_id(&mut self.fills, created))
}
pub fn get_fill(&self, fill_id: u32) -> Option<&SpreadsheetFill> {
self.fills.get(&fill_id)
}
pub fn create_border(
&mut self,
border: SpreadsheetBorder,
source_border_id: Option<u32>,
merge_with_existing_components: bool,
) -> Result<u32, SpreadsheetArtifactError> {
let created = if let Some(source_border_id) = source_border_id {
let source = self
.borders
.get(&source_border_id)
.cloned()
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: "create_border".to_string(),
message: format!("border `{source_border_id}` was not found"),
})?;
if merge_with_existing_components {
source.merge(&border)
} else {
border
}
} else {
border
};
Ok(insert_with_next_id(&mut self.borders, created))
}
pub fn get_border(&self, border_id: u32) -> Option<&SpreadsheetBorder> {
self.borders.get(&border_id)
}
pub fn create_number_format(
&mut self,
format: SpreadsheetNumberFormat,
source_number_format_id: Option<u32>,
merge_with_existing_components: bool,
) -> Result<u32, SpreadsheetArtifactError> {
let created = if let Some(source_number_format_id) = source_number_format_id {
let source = self
.number_formats
.get(&source_number_format_id)
.cloned()
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: "create_number_format".to_string(),
message: format!("number format `{source_number_format_id}` was not found"),
})?;
if merge_with_existing_components {
source.merge(&format)
} else {
format
}
} else {
format
};
Ok(insert_with_next_id(
&mut self.number_formats,
created.normalized(),
))
}
pub fn get_number_format(&self, number_format_id: u32) -> Option<&SpreadsheetNumberFormat> {
self.number_formats.get(&number_format_id)
}
pub fn create_cell_format(
&mut self,
format: SpreadsheetCellFormat,
source_format_id: Option<u32>,
merge_with_existing_components: bool,
) -> Result<u32, SpreadsheetArtifactError> {
let created = if let Some(source_format_id) = source_format_id {
let source = self
.cell_formats
.get(&source_format_id)
.cloned()
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: "create_cell_format".to_string(),
message: format!("cell format `{source_format_id}` was not found"),
})?;
if merge_with_existing_components {
source.merge(&format)
} else {
format
}
} else {
format
};
Ok(insert_with_next_id(&mut self.cell_formats, created))
}
pub fn get_cell_format(&self, format_id: u32) -> Option<&SpreadsheetCellFormat> {
self.cell_formats.get(&format_id)
}
pub fn create_differential_format(&mut self, format: SpreadsheetDifferentialFormat) -> u32 {
insert_with_next_id(&mut self.differential_formats, format)
}
pub fn get_differential_format(
&self,
format_id: u32,
) -> Option<&SpreadsheetDifferentialFormat> {
self.differential_formats.get(&format_id)
}
pub fn resolve_cell_format(&self, style_index: u32) -> Option<SpreadsheetCellFormat> {
let format = self.cell_formats.get(&style_index)?.clone();
resolve_cell_format_recursive(&self.cell_formats, &format, 0)
}
pub fn cell_format_summary(&self, style_index: u32) -> Option<SpreadsheetCellFormatSummary> {
let resolved = self.resolve_cell_format(style_index)?;
Some(SpreadsheetCellFormatSummary {
style_index,
text_style: resolved
.text_style_id
.and_then(|id| self.text_styles.get(&id).cloned()),
fill: resolved.fill_id.and_then(|id| self.fills.get(&id).cloned()),
border: resolved
.border_id
.and_then(|id| self.borders.get(&id).cloned()),
alignment: resolved.alignment,
number_format: resolved
.number_format_id
.and_then(|id| self.number_formats.get(&id).cloned()),
wrap_text: resolved.wrap_text,
})
}
}
impl SpreadsheetSheet {
pub fn range_format(&self, range: &CellRange) -> SpreadsheetRangeFormat {
SpreadsheetRangeFormat::new(self.name.clone(), range)
}
}
fn insert_with_next_id<T>(map: &mut BTreeMap<u32, T>, value: T) -> u32 {
let next_id = map.last_key_value().map(|(key, _)| key + 1).unwrap_or(1);
map.insert(next_id, value);
next_id
}
fn resolve_cell_format_recursive(
cell_formats: &BTreeMap<u32, SpreadsheetCellFormat>,
format: &SpreadsheetCellFormat,
depth: usize,
) -> Option<SpreadsheetCellFormat> {
if depth > 32 {
return None;
}
let base = format
.base_cell_style_format_id
.and_then(|id| cell_formats.get(&id))
.and_then(|base| resolve_cell_format_recursive(cell_formats, base, depth + 1));
Some(match base {
Some(base) => base.merge(format),
None => format.clone(),
})
}
fn builtin_number_format_code(format_id: u32) -> Option<String> {
match format_id {
0 => Some("General".to_string()),
1 => Some("0".to_string()),
2 => Some("0.00".to_string()),
3 => Some("#,##0".to_string()),
4 => Some("#,##0.00".to_string()),
9 => Some("0%".to_string()),
10 => Some("0.00%".to_string()),
_ => None,
}
}
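`insert_with_next_id` allocates ids as one past the largest existing key of the `BTreeMap`, so ids stay monotonically increasing and are never reused after a deletion. A standalone copy showing that behavior:

```rust
use std::collections::BTreeMap;

// Standalone copy of `insert_with_next_id`: the next id is one past the
// largest existing key, starting at 1 for an empty map.
fn insert_with_next_id<T>(map: &mut BTreeMap<u32, T>, value: T) -> u32 {
    let next_id = map.last_key_value().map(|(key, _)| key + 1).unwrap_or(1);
    map.insert(next_id, value);
    next_id
}

fn main() {
    let mut map = BTreeMap::new();
    assert_eq!(insert_with_next_id(&mut map, "a"), 1);
    assert_eq!(insert_with_next_id(&mut map, "b"), 2);
    map.remove(&1);
    // The freed id 1 is not reused; allocation continues past the max key.
    assert_eq!(insert_with_next_id(&mut map, "c"), 3);
}
```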


@@ -0,0 +1,630 @@
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use serde::Deserialize;
use serde::Serialize;
use crate::CellAddress;
use crate::CellRange;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetCellValue;
use crate::SpreadsheetSheet;
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetTableColumn {
pub id: u32,
pub name: String,
pub totals_row_label: Option<String>,
pub totals_row_function: Option<String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetTable {
pub id: u32,
pub name: String,
pub display_name: String,
pub range: String,
pub header_row_count: u32,
pub totals_row_count: u32,
pub style_name: Option<String>,
pub show_first_column: bool,
pub show_last_column: bool,
pub show_row_stripes: bool,
pub show_column_stripes: bool,
#[serde(default)]
pub columns: Vec<SpreadsheetTableColumn>,
#[serde(default)]
pub filters: BTreeMap<u32, String>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpreadsheetTableView {
pub id: u32,
pub name: String,
pub display_name: String,
pub address: String,
pub full_range: String,
pub header_row_count: u32,
pub totals_row_count: u32,
pub totals_row_visible: bool,
pub header_row_range: Option<String>,
pub data_body_range: Option<String>,
pub totals_row_range: Option<String>,
pub style_name: Option<String>,
pub show_first_column: bool,
pub show_last_column: bool,
pub show_row_stripes: bool,
pub show_column_stripes: bool,
pub columns: Vec<SpreadsheetTableColumn>,
}
#[derive(Debug, Clone, Default)]
pub struct SpreadsheetTableLookup<'a> {
pub name: Option<&'a str>,
pub display_name: Option<&'a str>,
pub id: Option<u32>,
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetCreateTableOptions {
pub name: Option<String>,
pub display_name: Option<String>,
pub header_row_count: u32,
pub totals_row_count: u32,
pub style_name: Option<String>,
pub show_first_column: bool,
pub show_last_column: bool,
pub show_row_stripes: bool,
pub show_column_stripes: bool,
}
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SpreadsheetTableStyleOptions {
pub style_name: Option<String>,
pub show_first_column: Option<bool>,
pub show_last_column: Option<bool>,
pub show_row_stripes: Option<bool>,
pub show_column_stripes: Option<bool>,
}
impl SpreadsheetTable {
pub fn range(&self) -> Result<CellRange, SpreadsheetArtifactError> {
CellRange::parse(&self.range)
}
pub fn address(&self) -> String {
self.range.clone()
}
pub fn full_range(&self) -> String {
self.range.clone()
}
pub fn totals_row_visible(&self) -> bool {
self.totals_row_count > 0
}
pub fn header_row_range(&self) -> Result<Option<CellRange>, SpreadsheetArtifactError> {
if self.header_row_count == 0 {
return Ok(None);
}
let range = self.range()?;
Ok(Some(CellRange::from_start_end(
range.start,
CellAddress {
column: range.end.column,
row: range.start.row + self.header_row_count - 1,
},
)))
}
pub fn data_body_range(&self) -> Result<Option<CellRange>, SpreadsheetArtifactError> {
let range = self.range()?;
let start_row = range.start.row + self.header_row_count;
let end_row = range.end.row.saturating_sub(self.totals_row_count);
if start_row > end_row {
return Ok(None);
}
Ok(Some(CellRange::from_start_end(
CellAddress {
column: range.start.column,
row: start_row,
},
CellAddress {
column: range.end.column,
row: end_row,
},
)))
}
pub fn totals_row_range(&self) -> Result<Option<CellRange>, SpreadsheetArtifactError> {
if self.totals_row_count == 0 {
return Ok(None);
}
let range = self.range()?;
Ok(Some(CellRange::from_start_end(
CellAddress {
column: range.start.column,
row: range.end.row - self.totals_row_count + 1,
},
range.end,
)))
}
pub fn view(&self) -> Result<SpreadsheetTableView, SpreadsheetArtifactError> {
Ok(SpreadsheetTableView {
id: self.id,
name: self.name.clone(),
display_name: self.display_name.clone(),
address: self.address(),
full_range: self.full_range(),
header_row_count: self.header_row_count,
totals_row_count: self.totals_row_count,
totals_row_visible: self.totals_row_visible(),
header_row_range: self.header_row_range()?.map(|range| range.to_a1()),
data_body_range: self.data_body_range()?.map(|range| range.to_a1()),
totals_row_range: self.totals_row_range()?.map(|range| range.to_a1()),
style_name: self.style_name.clone(),
show_first_column: self.show_first_column,
show_last_column: self.show_last_column,
show_row_stripes: self.show_row_stripes,
show_column_stripes: self.show_column_stripes,
columns: self.columns.clone(),
})
}
}
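The `header_row_range` / `data_body_range` / `totals_row_range` split above reduces to simple row arithmetic over the table's inclusive row span: the header claims the first `header_row_count` rows, totals claim the last `totals_row_count`, and the data body is whatever remains (possibly nothing). A standalone sketch using plain 1-based row numbers in place of `CellRange`:

```rust
// Row arithmetic behind data_body_range, with plain 1-based inclusive rows.
fn data_body_rows(
    start_row: u32,
    end_row: u32,
    header_rows: u32,
    totals_rows: u32,
) -> Option<(u32, u32)> {
    let body_start = start_row + header_rows;
    let body_end = end_row.saturating_sub(totals_rows);
    // An empty body (header and totals consume everything) yields None,
    // mirroring data_body_range returning Ok(None).
    (body_start <= body_end).then_some((body_start, body_end))
}

fn main() {
    // Rows 1..=10 with one header row and one totals row: body is rows 2..=9.
    assert_eq!(data_body_rows(1, 10, 1, 1), Some((2, 9)));
    // A two-row table that is all header and totals has no data body.
    assert_eq!(data_body_rows(1, 2, 1, 1), None);
}
```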
impl SpreadsheetSheet {
pub fn create_table(
&mut self,
action: &str,
range: &CellRange,
options: SpreadsheetCreateTableOptions,
) -> Result<u32, SpreadsheetArtifactError> {
validate_table_geometry(
action,
range,
options.header_row_count,
options.totals_row_count,
)?;
for table in &self.tables {
let table_range = table.range()?;
if table_range.intersects(range) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!(
"table range `{}` intersects existing table `{}`",
range.to_a1(),
table.name
),
});
}
}
let next_id = self.tables.iter().map(|table| table.id).max().unwrap_or(0) + 1;
let name = options.name.unwrap_or_else(|| format!("Table{next_id}"));
if name.trim().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "table name cannot be empty".to_string(),
});
}
let display_name = options.display_name.unwrap_or_else(|| name.clone());
if display_name.trim().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "table display_name cannot be empty".to_string(),
});
}
ensure_unique_table_name(&self.tables, action, &name, &display_name, None)?;
let columns = build_table_columns(self, range, options.header_row_count);
self.tables.push(SpreadsheetTable {
id: next_id,
name,
display_name,
range: range.to_a1(),
header_row_count: options.header_row_count,
totals_row_count: options.totals_row_count,
style_name: options.style_name,
show_first_column: options.show_first_column,
show_last_column: options.show_last_column,
show_row_stripes: options.show_row_stripes,
show_column_stripes: options.show_column_stripes,
columns,
filters: BTreeMap::new(),
});
Ok(next_id)
}
pub fn list_tables(
&self,
range: Option<&CellRange>,
) -> Result<Vec<SpreadsheetTableView>, SpreadsheetArtifactError> {
self.tables
.iter()
.filter(|table| {
range.is_none_or(|target| {
table
.range()
.map(|table_range| table_range.intersects(target))
.unwrap_or(false)
})
})
.map(SpreadsheetTable::view)
.collect()
}
pub fn get_table(
&self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<&SpreadsheetTable, SpreadsheetArtifactError> {
self.table_lookup_internal(action, lookup)
}
pub fn get_table_view(
&self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<SpreadsheetTableView, SpreadsheetArtifactError> {
self.get_table(action, lookup)?.view()
}
pub fn delete_table(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<(), SpreadsheetArtifactError> {
let index = self.table_index(action, lookup)?;
self.tables.remove(index);
Ok(())
}
pub fn set_table_style(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
options: SpreadsheetTableStyleOptions,
) -> Result<(), SpreadsheetArtifactError> {
let table = self.table_lookup_mut(action, lookup)?;
// Unlike the show_* flags below, style_name is always replaced, so passing
// None clears any existing style.
table.style_name = options.style_name;
if let Some(value) = options.show_first_column {
table.show_first_column = value;
}
if let Some(value) = options.show_last_column {
table.show_last_column = value;
}
if let Some(value) = options.show_row_stripes {
table.show_row_stripes = value;
}
if let Some(value) = options.show_column_stripes {
table.show_column_stripes = value;
}
Ok(())
}
pub fn clear_table_filters(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<(), SpreadsheetArtifactError> {
self.table_lookup_mut(action, lookup)?.filters.clear();
Ok(())
}
pub fn reapply_table_filters(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<(), SpreadsheetArtifactError> {
// Filters are stored declaratively, so reapplying only validates that the
// table exists.
let _ = self.table_lookup_mut(action, lookup)?;
Ok(())
}
pub fn rename_table_column(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
column_id: Option<u32>,
column_name: Option<&str>,
new_name: String,
) -> Result<SpreadsheetTableColumn, SpreadsheetArtifactError> {
if new_name.trim().is_empty() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "table column name cannot be empty".to_string(),
});
}
let table = self.table_lookup_mut(action, lookup)?;
if table.columns.iter().any(|column| {
column.name == new_name
&& Some(column.id) != column_id
// Allow a no-op rename when looking up the column by its current name.
&& Some(column.name.as_str()) != column_name
}) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("table column `{new_name}` already exists"),
});
}
let column = table_column_lookup_mut(&mut table.columns, action, column_id, column_name)?;
column.name = new_name;
Ok(column.clone())
}
pub fn set_table_column_totals(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
column_id: Option<u32>,
column_name: Option<&str>,
totals_row_label: Option<String>,
totals_row_function: Option<String>,
) -> Result<SpreadsheetTableColumn, SpreadsheetArtifactError> {
let table = self.table_lookup_mut(action, lookup)?;
let column = table_column_lookup_mut(&mut table.columns, action, column_id, column_name)?;
column.totals_row_label = totals_row_label;
column.totals_row_function = totals_row_function;
Ok(column.clone())
}
pub fn validate_tables(&self, action: &str) -> Result<(), SpreadsheetArtifactError> {
let mut seen_names = BTreeSet::new();
let mut seen_display_names = BTreeSet::new();
for table in &self.tables {
let range = table.range()?;
validate_table_geometry(
action,
&range,
table.header_row_count,
table.totals_row_count,
)?;
if !seen_names.insert(table.name.clone()) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("duplicate table name `{}`", table.name),
});
}
if !seen_display_names.insert(table.display_name.clone()) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("duplicate table display_name `{}`", table.display_name),
});
}
let column_names = table
.columns
.iter()
.map(|column| column.name.clone())
.collect::<BTreeSet<_>>();
if column_names.len() != table.columns.len() {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("table `{}` has duplicate column names", table.name),
});
}
}
for index in 0..self.tables.len() {
for other in index + 1..self.tables.len() {
if self.tables[index]
.range()?
.intersects(&self.tables[other].range()?)
{
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!(
"table `{}` intersects table `{}`",
self.tables[index].name, self.tables[other].name
),
});
}
}
}
Ok(())
}
fn table_index(
&self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<usize, SpreadsheetArtifactError> {
self.tables
.iter()
.position(|table| table_matches_lookup(table, lookup.clone()))
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: describe_missing_table(lookup),
})
}
fn table_lookup_internal(
&self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<&SpreadsheetTable, SpreadsheetArtifactError> {
self.tables
.iter()
.find(|table| table_matches_lookup(table, lookup.clone()))
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: describe_missing_table(lookup),
})
}
fn table_lookup_mut(
&mut self,
action: &str,
lookup: SpreadsheetTableLookup<'_>,
) -> Result<&mut SpreadsheetTable, SpreadsheetArtifactError> {
let index = self.table_index(action, lookup)?;
Ok(&mut self.tables[index])
}
}
fn table_matches_lookup(table: &SpreadsheetTable, lookup: SpreadsheetTableLookup<'_>) -> bool {
if let Some(name) = lookup.name {
table.name == name
} else if let Some(display_name) = lookup.display_name {
table.display_name == display_name
} else if let Some(id) = lookup.id {
table.id == id
} else {
false
}
}
fn describe_missing_table(lookup: SpreadsheetTableLookup<'_>) -> String {
if let Some(name) = lookup.name {
format!("table name `{name}` was not found")
} else if let Some(display_name) = lookup.display_name {
format!("table display_name `{display_name}` was not found")
} else if let Some(id) = lookup.id {
format!("table id `{id}` was not found")
} else {
"table name, display_name, or id is required".to_string()
}
}
fn ensure_unique_table_name(
tables: &[SpreadsheetTable],
action: &str,
name: &str,
display_name: &str,
exclude_id: Option<u32>,
) -> Result<(), SpreadsheetArtifactError> {
if tables.iter().any(|table| {
Some(table.id) != exclude_id && (table.name == name || table.display_name == name)
}) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("table name `{name}` already exists"),
});
}
if tables.iter().any(|table| {
Some(table.id) != exclude_id
&& (table.display_name == display_name || table.name == display_name)
}) {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: format!("table display_name `{display_name}` already exists"),
});
}
Ok(())
}
fn validate_table_geometry(
action: &str,
range: &CellRange,
header_row_count: u32,
totals_row_count: u32,
) -> Result<(), SpreadsheetArtifactError> {
if range.width() == 0 {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "table range must include at least one column".to_string(),
});
}
if header_row_count + totals_row_count > range.height() as u32 {
return Err(SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: "table range is smaller than header and totals rows".to_string(),
});
}
Ok(())
}
fn build_table_columns(
sheet: &SpreadsheetSheet,
range: &CellRange,
header_row_count: u32,
) -> Vec<SpreadsheetTableColumn> {
let header_row = range.start.row + header_row_count.saturating_sub(1);
let default_names = (0..range.width())
.map(|index| format!("Column{}", index + 1))
.collect::<Vec<_>>();
let names = unique_table_column_names(
(range.start.column..=range.end.column)
.enumerate()
.map(|(index, column)| {
if header_row_count == 0 {
return default_names[index].clone();
}
sheet
.get_cell(CellAddress {
column,
row: header_row,
})
.and_then(|cell| cell.value.as_ref())
.map(cell_value_to_table_header)
.filter(|value| !value.trim().is_empty())
.unwrap_or_else(|| default_names[index].clone())
})
.collect::<Vec<_>>(),
);
names
.into_iter()
.enumerate()
.map(|(index, name)| SpreadsheetTableColumn {
id: index as u32 + 1,
name,
totals_row_label: None,
totals_row_function: None,
})
.collect()
}
fn unique_table_column_names(names: Vec<String>) -> Vec<String> {
let mut seen = BTreeMap::<String, u32>::new();
names
.into_iter()
.map(|name| {
let entry = seen.entry(name.clone()).or_insert(0);
*entry += 1;
if *entry == 1 {
name
} else {
format!("{name}_{}", *entry)
}
})
.collect()
}
fn cell_value_to_table_header(value: &SpreadsheetCellValue) -> String {
match value {
SpreadsheetCellValue::Bool(value) => value.to_string(),
SpreadsheetCellValue::Integer(value) => value.to_string(),
SpreadsheetCellValue::Float(value) => value.to_string(),
SpreadsheetCellValue::String(value)
| SpreadsheetCellValue::DateTime(value)
| SpreadsheetCellValue::Error(value) => value.clone(),
}
}
fn table_column_lookup_mut<'a>(
columns: &'a mut [SpreadsheetTableColumn],
action: &str,
column_id: Option<u32>,
column_name: Option<&str>,
) -> Result<&'a mut SpreadsheetTableColumn, SpreadsheetArtifactError> {
columns
.iter_mut()
.find(|column| {
if let Some(column_id) = column_id {
column.id == column_id
} else if let Some(column_name) = column_name {
column.name == column_name
} else {
false
}
})
.ok_or_else(|| SpreadsheetArtifactError::InvalidArgs {
action: action.to_string(),
message: if let Some(column_id) = column_id {
format!("table column id `{column_id}` was not found")
} else if let Some(column_name) = column_name {
format!("table column `{column_name}` was not found")
} else {
"table column id or name is required".to_string()
},
})
}


@@ -0,0 +1,860 @@
use std::collections::BTreeMap;
use std::fs::File;
use std::io::Read;
use std::io::Write;
use std::path::Path;
use std::path::PathBuf;
use regex::Regex;
use zip::ZipArchive;
use zip::ZipWriter;
use zip::write::SimpleFileOptions;
use crate::CellAddress;
use crate::CellRange;
use crate::SpreadsheetArtifact;
use crate::SpreadsheetArtifactError;
use crate::SpreadsheetCell;
use crate::SpreadsheetCellValue;
use crate::SpreadsheetSheet;
pub(crate) fn write_xlsx(
artifact: &mut SpreadsheetArtifact,
path: &Path,
) -> Result<PathBuf, SpreadsheetArtifactError> {
if artifact.auto_recalculate {
artifact.recalculate();
}
for sheet in &mut artifact.sheets {
sheet.cleanup_and_validate_sheet()?;
}
let file = File::create(path).map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
let mut zip = ZipWriter::new(file);
let options = SimpleFileOptions::default();
let sheet_count = artifact.sheets.len().max(1);
zip.start_file("[Content_Types].xml", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(content_types_xml(sheet_count).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.add_directory("_rels/", options).map_err(|error| {
SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
}
})?;
zip.start_file("_rels/.rels", options).map_err(|error| {
SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
}
})?;
zip.write_all(root_relationships_xml().as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.add_directory("docProps/", options).map_err(|error| {
SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
}
})?;
zip.start_file("docProps/app.xml", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(app_xml(artifact).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.start_file("docProps/core.xml", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(core_xml(artifact).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.add_directory("xl/", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.start_file("xl/workbook.xml", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(workbook_xml(artifact).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.add_directory("xl/_rels/", options).map_err(|error| {
SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
}
})?;
zip.start_file("xl/_rels/workbook.xml.rels", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(workbook_relationships_xml(artifact).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.start_file("xl/styles.xml", options).map_err(|error| {
SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
}
})?;
zip.write_all(styles_xml(artifact).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.add_directory("xl/worksheets/", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
if artifact.sheets.is_empty() {
let empty = SpreadsheetSheet::new("Sheet1".to_string());
zip.start_file("xl/worksheets/sheet1.xml", options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(sheet_xml(&empty).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
} else {
for (index, sheet) in artifact.sheets.iter().enumerate() {
zip.start_file(format!("xl/worksheets/sheet{}.xml", index + 1), options)
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
zip.write_all(sheet_xml(sheet).as_bytes())
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
}
}
zip.finish()
.map_err(|error| SpreadsheetArtifactError::ExportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
Ok(path.to_path_buf())
}
pub(crate) fn import_xlsx(
path: &Path,
artifact_id: Option<String>,
) -> Result<SpreadsheetArtifact, SpreadsheetArtifactError> {
let file = File::open(path).map_err(|error| SpreadsheetArtifactError::ImportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
let mut archive =
ZipArchive::new(file).map_err(|error| SpreadsheetArtifactError::ImportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
let workbook_xml = read_zip_entry(&mut archive, "xl/workbook.xml", path)?;
let workbook_rels = read_zip_entry(&mut archive, "xl/_rels/workbook.xml.rels", path)?;
let workbook_name = if archive.by_name("docProps/core.xml").is_ok() {
let title =
extract_workbook_title(&read_zip_entry(&mut archive, "docProps/core.xml", path)?);
(!title.trim().is_empty()).then_some(title)
} else {
None
};
let shared_strings = if archive.by_name("xl/sharedStrings.xml").is_ok() {
Some(parse_shared_strings(&read_zip_entry(
&mut archive,
"xl/sharedStrings.xml",
path,
)?)?)
} else {
None
};
let relationships = parse_relationships(&workbook_rels)?;
let sheets = parse_sheet_definitions(&workbook_xml)?
.into_iter()
.map(|(name, relation)| {
let target = relationships.get(&relation).ok_or_else(|| {
SpreadsheetArtifactError::ImportFailed {
path: path.to_path_buf(),
message: format!("missing relationship `{relation}` for sheet `{name}`"),
}
})?;
let normalized = if target.starts_with('/') {
target.trim_start_matches('/').to_string()
} else if target.starts_with("xl/") {
target.clone()
} else {
format!("xl/{target}")
};
Ok((name, normalized))
})
.collect::<Result<Vec<_>, SpreadsheetArtifactError>>()?;
let mut artifact = SpreadsheetArtifact::new(workbook_name.or_else(|| {
path.file_stem()
.and_then(|value| value.to_str())
.map(str::to_string)
}));
if let Some(artifact_id) = artifact_id {
artifact.artifact_id = artifact_id;
}
artifact.sheets.clear();
for (name, target) in sheets {
let xml = read_zip_entry(&mut archive, &target, path)?;
let sheet = parse_sheet(&name, &xml, shared_strings.as_deref())?;
artifact.sheets.push(sheet);
}
Ok(artifact)
}
fn read_zip_entry(
archive: &mut ZipArchive<File>,
entry: &str,
path: &Path,
) -> Result<String, SpreadsheetArtifactError> {
let mut file =
archive
.by_name(entry)
.map_err(|error| SpreadsheetArtifactError::ImportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
let mut text = String::new();
file.read_to_string(&mut text)
.map_err(|error| SpreadsheetArtifactError::ImportFailed {
path: path.to_path_buf(),
message: error.to_string(),
})?;
Ok(text)
}
fn parse_sheet_definitions(
workbook_xml: &str,
) -> Result<Vec<(String, String)>, SpreadsheetArtifactError> {
let regex = Regex::new(r#"<sheet\b([^>]*)/?>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
let mut sheets = Vec::new();
for captures in regex.captures_iter(workbook_xml) {
let Some(attributes) = captures.get(1).map(|value| value.as_str()) else {
continue;
};
let Some(name) = extract_attribute(attributes, "name") else {
continue;
};
let relation = extract_attribute(attributes, "r:id")
.or_else(|| extract_attribute(attributes, "id"))
.unwrap_or_default();
sheets.push((xml_unescape(&name), relation));
}
Ok(sheets)
}
fn parse_relationships(xml: &str) -> Result<BTreeMap<String, String>, SpreadsheetArtifactError> {
let regex = Regex::new(r#"<Relationship\b([^>]*)/?>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
Ok(regex
.captures_iter(xml)
.filter_map(|captures| {
let attributes = captures.get(1)?.as_str();
let id = extract_attribute(attributes, "Id")?;
let target = extract_attribute(attributes, "Target")?;
Some((id, target))
})
.collect())
}
fn parse_shared_strings(xml: &str) -> Result<Vec<String>, SpreadsheetArtifactError> {
let regex = Regex::new(r#"(?s)<si\b[^>]*>(.*?)</si>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
regex
.captures_iter(xml)
.filter_map(|captures| captures.get(1).map(|value| value.as_str()))
.map(all_text_nodes)
.collect()
}
fn parse_sheet(
name: &str,
xml: &str,
shared_strings: Option<&[String]>,
) -> Result<SpreadsheetSheet, SpreadsheetArtifactError> {
let mut sheet = SpreadsheetSheet::new(name.to_string());
if let Some(sheet_view) = first_tag_attributes(xml, "sheetView")
&& let Some(show_grid_lines) = extract_attribute(&sheet_view, "showGridLines")
{
sheet.show_grid_lines = show_grid_lines != "0";
}
if let Some(format_pr) = first_tag_attributes(xml, "sheetFormatPr") {
sheet.default_row_height = extract_attribute(&format_pr, "defaultRowHeight")
.and_then(|value| value.parse::<f64>().ok());
sheet.default_column_width = extract_attribute(&format_pr, "defaultColWidth")
.and_then(|value| value.parse::<f64>().ok());
}
let col_regex = Regex::new(r#"<col\b([^>]*)/?>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
for captures in col_regex.captures_iter(xml) {
let Some(attributes) = captures.get(1).map(|value| value.as_str()) else {
continue;
};
let Some(min) =
extract_attribute(attributes, "min").and_then(|value| value.parse::<u32>().ok())
else {
continue;
};
let Some(max) =
extract_attribute(attributes, "max").and_then(|value| value.parse::<u32>().ok())
else {
continue;
};
let Some(width) =
extract_attribute(attributes, "width").and_then(|value| value.parse::<f64>().ok())
else {
continue;
};
for column in min..=max {
sheet.column_widths.insert(column, width);
}
}
let row_regex = Regex::new(r#"(?s)<row\b([^>]*)>(.*?)</row>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
let cell_regex = Regex::new(r#"(?s)<c\b([^>]*)>(.*?)</c>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
for row_captures in row_regex.captures_iter(xml) {
let row_attributes = row_captures
.get(1)
.map(|value| value.as_str())
.unwrap_or_default();
if let Some(row_index) =
extract_attribute(row_attributes, "r").and_then(|value| value.parse::<u32>().ok())
&& let Some(height) =
extract_attribute(row_attributes, "ht").and_then(|value| value.parse::<f64>().ok())
&& row_index > 0
&& height > 0.0
{
sheet.row_heights.insert(row_index, height);
}
let Some(row_body) = row_captures.get(2).map(|value| value.as_str()) else {
continue;
};
for cell_captures in cell_regex.captures_iter(row_body) {
let Some(attributes) = cell_captures.get(1).map(|value| value.as_str()) else {
continue;
};
let Some(body) = cell_captures.get(2).map(|value| value.as_str()) else {
continue;
};
let Some(address) = extract_attribute(attributes, "r") else {
continue;
};
let address = CellAddress::parse(&address)?;
let style_index = extract_attribute(attributes, "s")
.and_then(|value| value.parse::<u32>().ok())
.unwrap_or(0);
let cell_type = extract_attribute(attributes, "t").unwrap_or_default();
let formula =
first_tag_text(body, "f").map(|value| format!("={}", xml_unescape(&value)));
let value = parse_cell_value(body, &cell_type, shared_strings)?;
let cell = SpreadsheetCell {
value,
formula,
style_index,
citations: Vec::new(),
};
if !cell.is_empty() {
sheet.cells.insert(address, cell);
}
}
}
let merge_regex = Regex::new(r#"<mergeCell\b([^>]*)/?>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
for captures in merge_regex.captures_iter(xml) {
let Some(attributes) = captures.get(1).map(|value| value.as_str()) else {
continue;
};
if let Some(reference) = extract_attribute(attributes, "ref") {
sheet.merged_ranges.push(CellRange::parse(&reference)?);
}
}
Ok(sheet)
}
fn parse_cell_value(
body: &str,
cell_type: &str,
shared_strings: Option<&[String]>,
) -> Result<Option<SpreadsheetCellValue>, SpreadsheetArtifactError> {
let inline_text = first_tag_text(body, "t").map(|value| xml_unescape(&value));
let raw_value = first_tag_text(body, "v").map(|value| xml_unescape(&value));
let parsed = match cell_type {
"inlineStr" => inline_text.map(SpreadsheetCellValue::String),
"s" => raw_value
.and_then(|value| value.parse::<usize>().ok())
.and_then(|index| shared_strings.and_then(|entries| entries.get(index).cloned()))
.map(SpreadsheetCellValue::String),
"b" => raw_value.map(|value| SpreadsheetCellValue::Bool(value == "1")),
"str" => raw_value.map(SpreadsheetCellValue::String),
"e" => raw_value.map(SpreadsheetCellValue::Error),
_ => match raw_value {
Some(value) => {
if let Ok(integer) = value.parse::<i64>() {
Some(SpreadsheetCellValue::Integer(integer))
} else if let Ok(float) = value.parse::<f64>() {
Some(SpreadsheetCellValue::Float(float))
} else {
Some(SpreadsheetCellValue::String(value))
}
}
None => None,
},
};
Ok(parsed)
}
fn content_types_xml(sheet_count: usize) -> String {
let mut overrides = String::new();
for index in 1..=sheet_count {
overrides.push_str(&format!(
r#"<Override PartName="/xl/worksheets/sheet{index}.xml" ContentType="application/vnd.openxmlformats-officedocument.spreadsheetml.worksheet+xml"/>"#
));
}
format!(
"{}{}{}{}{}{}{}{}{}{}",
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<Types xmlns="http://schemas.openxmlformats.org/package/2006/content-types">"#,
r#"<Default Extension="rels" ContentType="application/vnd.openxmlformats-package.relationships+xml"/>"#,
r#"<Default Extension="xml" ContentType="application/xml"/>"#,
r#"<Override PartName="/xl/workbook.xml" ContentType="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet.main+xml"/>"#,
r#"<Override PartName="/xl/styles.xml" ContentType="application/vnd.openxmlformats-officedocument.spreadsheetml.styles+xml"/>"#,
r#"<Override PartName="/docProps/core.xml" ContentType="application/vnd.openxmlformats-package.core-properties+xml"/>"#,
r#"<Override PartName="/docProps/app.xml" ContentType="application/vnd.openxmlformats-officedocument.extended-properties+xml"/>"#,
overrides,
r#"</Types>"#
)
}
fn root_relationships_xml() -> &'static str {
concat!(
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">"#,
r#"<Relationship Id="rId1" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/officeDocument" Target="xl/workbook.xml"/>"#,
r#"<Relationship Id="rId2" Type="http://schemas.openxmlformats.org/package/2006/relationships/metadata/core-properties" Target="docProps/core.xml"/>"#,
r#"<Relationship Id="rId3" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/extended-properties" Target="docProps/app.xml"/>"#,
r#"</Relationships>"#
)
}
fn app_xml(artifact: &SpreadsheetArtifact) -> String {
let title = artifact
.name
.clone()
.unwrap_or_else(|| "Spreadsheet".to_string());
format!(
concat!(
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<Properties xmlns="http://schemas.openxmlformats.org/officeDocument/2006/extended-properties" xmlns:vt="http://schemas.openxmlformats.org/officeDocument/2006/docPropsVTypes">"#,
r#"<Application>Codex</Application>"#,
r#"<DocSecurity>0</DocSecurity>"#,
r#"<ScaleCrop>false</ScaleCrop>"#,
r#"<HeadingPairs><vt:vector size="2" baseType="variant"><vt:variant><vt:lpstr>Worksheets</vt:lpstr></vt:variant><vt:variant><vt:i4>{}</vt:i4></vt:variant></vt:vector></HeadingPairs>"#,
r#"<TitlesOfParts><vt:vector size="{}" baseType="lpstr">{}</vt:vector></TitlesOfParts>"#,
r#"<Company>OpenAI</Company>"#,
r#"<Manager>{}</Manager>"#,
r#"</Properties>"#
),
artifact.sheets.len(),
artifact.sheets.len(),
artifact
.sheets
.iter()
.map(|sheet| format!(r#"<vt:lpstr>{}</vt:lpstr>"#, xml_escape(&sheet.name)))
.collect::<Vec<_>>()
.join(""),
xml_escape(&title),
)
}
fn core_xml(artifact: &SpreadsheetArtifact) -> String {
let title = artifact
.name
.clone()
.unwrap_or_else(|| artifact.artifact_id.clone());
format!(
concat!(
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<cp:coreProperties xmlns:cp="http://schemas.openxmlformats.org/package/2006/metadata/core-properties" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:dcterms="http://purl.org/dc/terms/" xmlns:dcmitype="http://purl.org/dc/dcmitype/" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">"#,
r#"<dc:title>{}</dc:title>"#,
r#"<dc:creator>Codex</dc:creator>"#,
r#"<cp:lastModifiedBy>Codex</cp:lastModifiedBy>"#,
r#"</cp:coreProperties>"#
),
xml_escape(&title),
)
}
fn workbook_xml(artifact: &SpreadsheetArtifact) -> String {
let sheets = if artifact.sheets.is_empty() {
r#"<sheet name="Sheet1" sheetId="1" r:id="rId1"/>"#.to_string()
} else {
artifact
.sheets
.iter()
.enumerate()
.map(|(index, sheet)| {
format!(
r#"<sheet name="{}" sheetId="{}" r:id="rId{}"/>"#,
xml_escape(&sheet.name),
index + 1,
index + 1
)
})
.collect::<Vec<_>>()
.join("")
};
format!(
"{}{}{}<sheets>{}</sheets>{}",
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<workbook xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main" xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships">"#,
r#"<bookViews><workbookView/></bookViews>"#,
sheets,
r#"</workbook>"#
)
}
fn workbook_relationships_xml(artifact: &SpreadsheetArtifact) -> String {
let sheet_relationships = if artifact.sheets.is_empty() {
r#"<Relationship Id="rId1" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/worksheet" Target="worksheets/sheet1.xml"/>"#.to_string()
} else {
artifact
.sheets
.iter()
.enumerate()
.map(|(index, _)| {
format!(
r#"<Relationship Id="rId{}" Type="http://schemas.openxmlformats.org/officeDocument/2006/relationships/worksheet" Target="worksheets/sheet{}.xml"/>"#,
index + 1,
index + 1
)
})
.collect::<Vec<_>>()
.join("")
};
let style_relation_id = artifact.sheets.len().max(1) + 1;
format!(
"{}{}{}<Relationship Id=\"rId{}\" Type=\"http://schemas.openxmlformats.org/officeDocument/2006/relationships/styles\" Target=\"styles.xml\"/>{}",
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<Relationships xmlns="http://schemas.openxmlformats.org/package/2006/relationships">"#,
sheet_relationships,
style_relation_id,
r#"</Relationships>"#
)
}
fn styles_xml(artifact: &SpreadsheetArtifact) -> String {
let max_style_index = artifact
.sheets
.iter()
.flat_map(|sheet| sheet.cells.values().map(|cell| cell.style_index))
.max()
.unwrap_or(0);
let cell_xfs = (0..=max_style_index)
.map(|_| r#"<xf numFmtId="0" fontId="0" fillId="0" borderId="0" xfId="0"/>"#)
.collect::<Vec<_>>()
.join("");
format!(
concat!(
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<styleSheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">"#,
r#"<fonts count="1"><font/></fonts>"#,
r#"<fills count="2"><fill><patternFill patternType="none"/></fill><fill><patternFill patternType="gray125"/></fill></fills>"#,
r#"<borders count="1"><border/></borders>"#,
r#"<cellStyleXfs count="1"><xf numFmtId="0" fontId="0" fillId="0" borderId="0"/></cellStyleXfs>"#,
r#"<cellXfs count="{}">{}</cellXfs>"#,
r#"<cellStyles count="1"><cellStyle name="Normal" xfId="0" builtinId="0"/></cellStyles>"#,
r#"</styleSheet>"#
),
max_style_index + 1,
cell_xfs,
)
}
fn sheet_xml(sheet: &SpreadsheetSheet) -> String {
let mut rows = BTreeMap::<u32, Vec<(CellAddress, &SpreadsheetCell)>>::new();
for row_index in sheet.row_heights.keys() {
rows.entry(*row_index).or_default();
}
for (address, cell) in &sheet.cells {
rows.entry(address.row).or_default().push((*address, cell));
}
let sheet_data = rows
.into_iter()
.map(|(row_index, mut entries)| {
entries.sort_by_key(|(address, _)| address.column);
let cells = entries
.into_iter()
.map(|(address, cell)| cell_xml(address, cell))
.collect::<Vec<_>>()
.join("");
let height = sheet
.row_heights
.get(&row_index)
.map(|value| format!(r#" ht="{value}" customHeight="1""#))
.unwrap_or_default();
format!(r#"<row r="{row_index}"{height}>{cells}</row>"#)
})
.collect::<Vec<_>>()
.join("");
let cols = if sheet.column_widths.is_empty() {
String::new()
} else {
let mut groups = Vec::new();
let mut iter = sheet.column_widths.iter().peekable();
while let Some((&start, &width)) = iter.next() {
let mut end = start;
while let Some((next_column, next_width)) =
iter.peek().map(|(column, width)| (**column, **width))
{
if next_column == end + 1 && (next_width - width).abs() < f64::EPSILON {
end = next_column;
iter.next();
} else {
break;
}
}
groups.push(format!(
r#"<col min="{start}" max="{end}" width="{width}" customWidth="1"/>"#
));
}
format!("<cols>{}</cols>", groups.join(""))
};
let merge_cells = if sheet.merged_ranges.is_empty() {
String::new()
} else {
format!(
r#"<mergeCells count="{}">{}</mergeCells>"#,
sheet.merged_ranges.len(),
sheet
.merged_ranges
.iter()
.map(|range| format!(r#"<mergeCell ref="{}"/>"#, range.to_a1()))
.collect::<Vec<_>>()
.join("")
)
};
let default_row_height = sheet.default_row_height.unwrap_or(15.0);
let default_column_width = sheet.default_column_width.unwrap_or(8.43);
let grid_lines = if sheet.show_grid_lines { "1" } else { "0" };
format!(
"{}{}<sheetViews><sheetView workbookViewId=\"0\" showGridLines=\"{}\"/></sheetViews><sheetFormatPr defaultRowHeight=\"{}\" defaultColWidth=\"{}\"/>{}<sheetData>{}</sheetData>{}{}",
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>"#,
r#"<worksheet xmlns="http://schemas.openxmlformats.org/spreadsheetml/2006/main">"#,
grid_lines,
default_row_height,
default_column_width,
cols,
sheet_data,
merge_cells,
r#"</worksheet>"#
)
}
fn cell_xml(address: CellAddress, cell: &SpreadsheetCell) -> String {
let style = if cell.style_index == 0 {
String::new()
} else {
format!(r#" s="{}""#, cell.style_index)
};
if let Some(formula) = &cell.formula {
let formula = xml_escape(formula.trim_start_matches('='));
let value_xml = match &cell.value {
Some(SpreadsheetCellValue::Bool(value)) => {
format!(
r#" t="b"><f>{formula}</f><v>{}</v></c>"#,
usize::from(*value)
)
}
Some(SpreadsheetCellValue::Integer(value)) => {
format!(r#"><f>{formula}</f><v>{value}</v></c>"#)
}
Some(SpreadsheetCellValue::Float(value)) => {
format!(r#"><f>{formula}</f><v>{value}</v></c>"#)
}
Some(SpreadsheetCellValue::String(value))
| Some(SpreadsheetCellValue::DateTime(value)) => format!(
r#" t="str"><f>{formula}</f><v>{}</v></c>"#,
xml_escape(value)
),
Some(SpreadsheetCellValue::Error(value)) => {
format!(r#" t="e"><f>{formula}</f><v>{}</v></c>"#, xml_escape(value))
}
None => format!(r#"><f>{formula}</f></c>"#),
};
return format!(r#"<c r="{}"{style}{value_xml}"#, address.to_a1());
}
match &cell.value {
Some(SpreadsheetCellValue::Bool(value)) => format!(
r#"<c r="{}"{style} t="b"><v>{}</v></c>"#,
address.to_a1(),
usize::from(*value)
),
Some(SpreadsheetCellValue::Integer(value)) => {
format!(r#"<c r="{}"{style}><v>{value}</v></c>"#, address.to_a1())
}
Some(SpreadsheetCellValue::Float(value)) => {
format!(r#"<c r="{}"{style}><v>{value}</v></c>"#, address.to_a1())
}
Some(SpreadsheetCellValue::String(value)) | Some(SpreadsheetCellValue::DateTime(value)) => {
format!(
r#"<c r="{}"{style} t="inlineStr"><is><t>{}</t></is></c>"#,
address.to_a1(),
xml_escape(value)
)
}
Some(SpreadsheetCellValue::Error(value)) => format!(
r#"<c r="{}"{style} t="e"><v>{}</v></c>"#,
address.to_a1(),
xml_escape(value)
),
None => format!(r#"<c r="{}"{style}/>"#, address.to_a1()),
}
}
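The helpers above lean on `address.to_a1()` for A1-style cell references. As a self-contained sketch (the real `CellAddress` may use different index conventions), column letters are a bijective base-26 encoding over a 0-based column index:

```rust
// Hypothetical stand-in for `CellAddress::to_a1`: 0-based column, 1-based row.
// Bijective base-26: 0 -> "A", 25 -> "Z", 26 -> "AA", 702 -> "AAA".
fn to_a1(col: u32, row: u32) -> String {
    let mut letters = String::new();
    let mut c = col;
    loop {
        letters.insert(0, char::from(b'A' + (c % 26) as u8));
        if c < 26 {
            break;
        }
        // "Z" is followed by "AA", not "BA", hence the extra -1.
        c = c / 26 - 1;
    }
    format!("{letters}{row}")
}
```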
fn first_tag_attributes(xml: &str, tag: &str) -> Option<String> {
// Lazy `[^>]*?` so self-closing tags do not capture the trailing '/'.
let regex = Regex::new(&format!(r#"<{tag}\b([^>]*?)/?>"#)).ok()?;
let captures = regex.captures(xml)?;
captures.get(1).map(|value| value.as_str().to_string())
}
fn first_tag_text(xml: &str, tag: &str) -> Option<String> {
let regex = Regex::new(&format!(r#"(?s)<{tag}\b[^>]*>(.*?)</{tag}>"#)).ok()?;
let captures = regex.captures(xml)?;
captures.get(1).map(|value| value.as_str().to_string())
}
fn extract_workbook_title(xml: &str) -> String {
let Ok(regex) =
Regex::new(r#"(?s)<(?:[A-Za-z0-9_]+:)?title\b[^>]*>(.*?)</(?:[A-Za-z0-9_]+:)?title>"#)
else {
return String::new();
};
regex
.captures(xml)
.and_then(|captures| captures.get(1).map(|value| xml_unescape(value.as_str())))
.unwrap_or_default()
}
fn all_text_nodes(xml: &str) -> Result<String, SpreadsheetArtifactError> {
let regex = Regex::new(r#"(?s)<t\b[^>]*>(.*?)</t>"#).map_err(|error| {
SpreadsheetArtifactError::Serialization {
message: error.to_string(),
}
})?;
Ok(regex
.captures_iter(xml)
.filter_map(|captures| captures.get(1).map(|value| xml_unescape(value.as_str())))
.collect::<Vec<_>>()
.join(""))
}
fn extract_attribute(attributes: &str, name: &str) -> Option<String> {
let pattern = format!(r#"{name}="([^"]*)""#);
let regex = Regex::new(&pattern).ok()?;
let captures = regex.captures(attributes)?;
captures.get(1).map(|value| xml_unescape(value.as_str()))
}
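Since the attribute values here are always double-quoted, the regex in `extract_attribute` can also be stated as plain string search. A minimal regex-free sketch (note: a production version should additionally check that the match is preceded by whitespace, or `color="…"` would satisfy a lookup for `r`):

```rust
// Find `name="` and take everything up to the next double quote.
// Entity unescaping is left to the caller in this sketch.
fn extract_attribute_raw(attributes: &str, name: &str) -> Option<String> {
    let needle = format!("{name}=\"");
    let start = attributes.find(&needle)? + needle.len();
    let end = attributes[start..].find('"')? + start;
    Some(attributes[start..end].to_string())
}
```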
fn xml_escape(value: &str) -> String {
value
.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&apos;")
}
fn xml_unescape(value: &str) -> String {
value
.replace("&apos;", "'")
.replace("&quot;", "\"")
.replace("&gt;", ">")
.replace("&lt;", "<")
.replace("&amp;", "&")
}
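The replacement order in the two functions above is load-bearing: `&` must be escaped first and unescaped last, otherwise the entities produced by the other replacements would themselves be re-processed (`<` → `&lt;` → `&amp;lt;`). A minimal round-trip check of that ordering:

```rust
// Same replacement chains as xml_escape/xml_unescape above:
// '&' first when escaping, "&amp;" last when unescaping.
fn escape(value: &str) -> String {
    value
        .replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&apos;")
}

fn unescape(value: &str) -> String {
    value
        .replace("&apos;", "'")
        .replace("&quot;", "\"")
        .replace("&gt;", ">")
        .replace("&lt;", "<")
        .replace("&amp;", "&")
}
```

With this ordering, `unescape(&escape(s)) == s` holds for any `s`, including strings that already contain entities.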


@@ -155,6 +155,8 @@ pub struct ResponsesApiRequest {
pub stream: bool,
pub include: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<TextControls>,
@@ -174,6 +176,7 @@ impl From<&ResponsesApiRequest> for ResponseCreateWsRequest {
store: request.store,
stream: request.stream,
include: request.include.clone(),
service_tier: request.service_tier.clone(),
prompt_cache_key: request.prompt_cache_key.clone(),
text: request.text.clone(),
generate: None,
@@ -197,6 +200,8 @@ pub struct ResponseCreateWsRequest {
pub stream: bool,
pub include: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_tier: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_cache_key: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<TextControls>,


@@ -4,7 +4,10 @@ use crate::endpoint::realtime_websocket::protocol::RealtimeAudioFrame;
use crate::endpoint::realtime_websocket::protocol::RealtimeEvent;
use crate::endpoint::realtime_websocket::protocol::RealtimeOutboundMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeSessionConfig;
use crate::endpoint::realtime_websocket::protocol::SessionCreateSession;
use crate::endpoint::realtime_websocket::protocol::SessionAudio;
use crate::endpoint::realtime_websocket::protocol::SessionAudioFormat;
use crate::endpoint::realtime_websocket::protocol::SessionAudioInput;
use crate::endpoint::realtime_websocket::protocol::SessionAudioOutput;
use crate::endpoint::realtime_websocket::protocol::SessionUpdateSession;
use crate::endpoint::realtime_websocket::protocol::parse_realtime_event;
use crate::error::ApiError;
@@ -13,6 +16,8 @@ use codex_utils_rustls_provider::ensure_rustls_crypto_provider;
use futures::SinkExt;
use futures::StreamExt;
use http::HeaderMap;
use http::HeaderValue;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
@@ -205,23 +210,13 @@ impl RealtimeWebsocketConnection {
self.writer.send_conversation_item_create(text).await
}
pub async fn send_session_update(
pub async fn send_conversation_handoff_append(
&self,
backend_prompt: String,
conversation_id: Option<String>,
handoff_id: String,
output_text: String,
) -> Result<(), ApiError> {
self.writer
.send_session_update(backend_prompt, conversation_id)
.await
}
pub async fn send_session_create(
&self,
backend_prompt: String,
conversation_id: Option<String>,
) -> Result<(), ApiError> {
self.writer
.send_session_create(backend_prompt, conversation_id)
.send_conversation_handoff_append(handoff_id, output_text)
.await
}
@@ -262,13 +257,8 @@ impl RealtimeWebsocketConnection {
impl RealtimeWebsocketWriter {
pub async fn send_audio_frame(&self, frame: RealtimeAudioFrame) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::InputAudioDelta {
delta: frame.data,
sample_rate: frame.sample_rate,
num_channels: frame.num_channels,
samples_per_channel: frame.samples_per_channel,
})
.await
self.send_json(RealtimeOutboundMessage::InputAudioBufferAppend { audio: frame.data })
.await
}
pub async fn send_conversation_item_create(&self, text: String) -> Result<(), ApiError> {
@@ -285,29 +275,34 @@ impl RealtimeWebsocketWriter {
.await
}
pub async fn send_session_update(
pub async fn send_conversation_handoff_append(
&self,
backend_prompt: String,
conversation_id: Option<String>,
handoff_id: String,
output_text: String,
) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionUpdate {
session: Some(SessionUpdateSession {
backend_prompt,
conversation_id,
}),
self.send_json(RealtimeOutboundMessage::ConversationHandoffAppend {
handoff_id,
output_text,
})
.await
}
pub async fn send_session_create(
&self,
backend_prompt: String,
conversation_id: Option<String>,
) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionCreate {
session: SessionCreateSession {
backend_prompt,
conversation_id,
pub async fn send_session_update(&self, instructions: String) -> Result<(), ApiError> {
self.send_json(RealtimeOutboundMessage::SessionUpdate {
session: SessionUpdateSession {
kind: "quicksilver".to_string(),
instructions,
audio: SessionAudio {
input: SessionAudioInput {
format: SessionAudioFormat {
kind: "audio/pcm".to_string(),
rate: 24_000,
},
},
output: SessionAudioOutput {
voice: "mundo".to_string(),
},
},
},
})
.await
@@ -413,14 +408,21 @@ impl RealtimeWebsocketClient {
default_headers: HeaderMap,
) -> Result<RealtimeWebsocketConnection, ApiError> {
ensure_rustls_crypto_provider();
// Keep provider base_url semantics aligned with HTTP clients; derive the ws endpoint here.
let ws_url = websocket_url_from_api_url(self.provider.base_url.as_str())?;
let ws_url = websocket_url_from_api_url(
self.provider.base_url.as_str(),
self.provider.query_params.as_ref(),
config.model.as_deref(),
)?;
let mut request = ws_url
.as_str()
.into_client_request()
.map_err(|err| ApiError::Stream(format!("failed to build websocket request: {err}")))?;
let headers = merge_request_headers(&self.provider.headers, extra_headers, default_headers);
let headers = merge_request_headers(
&self.provider.headers,
with_session_id_header(extra_headers, config.session_id.as_deref())?,
default_headers,
);
request.headers_mut().extend(headers);
info!("connecting realtime websocket: {ws_url}");
@@ -439,11 +441,12 @@ impl RealtimeWebsocketClient {
let (stream, rx_message) = WsStream::new(stream);
let connection = RealtimeWebsocketConnection::new(stream, rx_message);
debug!(
conversation_id = config.session_id.as_deref().unwrap_or("<none>"),
"realtime websocket sending session.create"
session_id = config.session_id.as_deref().unwrap_or("<none>"),
"realtime websocket sending session.update"
);
connection
.send_session_create(config.prompt, config.session_id)
.writer
.send_session_update(config.instructions)
.await?;
Ok(connection)
}
@@ -464,38 +467,99 @@ fn merge_request_headers(
headers
}
fn with_session_id_header(
mut headers: HeaderMap,
session_id: Option<&str>,
) -> Result<HeaderMap, ApiError> {
let Some(session_id) = session_id else {
return Ok(headers);
};
headers.insert(
"x-session-id",
HeaderValue::from_str(session_id).map_err(|err| {
ApiError::Stream(format!("invalid realtime session id header: {err}"))
})?,
);
Ok(headers)
}
fn websocket_config() -> WebSocketConfig {
WebSocketConfig::default()
}
fn websocket_url_from_api_url(api_url: &str) -> Result<Url, ApiError> {
fn websocket_url_from_api_url(
api_url: &str,
query_params: Option<&HashMap<String, String>>,
model: Option<&str>,
) -> Result<Url, ApiError> {
let mut url = Url::parse(api_url)
.map_err(|err| ApiError::Stream(format!("failed to parse realtime api_url: {err}")))?;
normalize_realtime_path(&mut url);
match url.scheme() {
"ws" | "wss" => {
if url.path().is_empty() || url.path() == "/" {
url.set_path("/ws");
}
Ok(url)
}
"ws" | "wss" => {}
"http" | "https" => {
if url.path().is_empty() || url.path() == "/" {
url.set_path("/ws");
}
let scheme = if url.scheme() == "http" { "ws" } else { "wss" };
let _ = url.set_scheme(scheme);
Ok(url)
}
scheme => Err(ApiError::Stream(format!(
"unsupported realtime api_url scheme: {scheme}"
))),
scheme => {
return Err(ApiError::Stream(format!(
"unsupported realtime api_url scheme: {scheme}"
)));
}
}
{
let mut query = url.query_pairs_mut();
query.append_pair("intent", "quicksilver");
if let Some(model) = model {
query.append_pair("model", model);
}
if let Some(query_params) = query_params {
for (key, value) in query_params {
if key == "intent" || (key == "model" && model.is_some()) {
continue;
}
query.append_pair(key, value);
}
}
}
Ok(url)
}
fn normalize_realtime_path(url: &mut Url) {
let path = url.path().to_string();
if path.is_empty() || path == "/" {
url.set_path("/v1/realtime");
return;
}
if path.ends_with("/realtime") {
return;
}
if path.ends_with("/realtime/") {
url.set_path(path.trim_end_matches('/'));
return;
}
if path.ends_with("/v1") {
url.set_path(&format!("{path}/realtime"));
return;
}
if path.ends_with("/v1/") {
url.set_path(&format!("{path}realtime"));
}
}
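The branch order in `normalize_realtime_path` can be restated on bare path strings, which makes the precedence easier to test in isolation (a sketch; the real function mutates a `url::Url` in place):

```rust
// Same rules as normalize_realtime_path, expressed on &str:
// empty/root paths get the default endpoint, existing `/realtime` paths
// are kept (trailing slash trimmed), `/v1` bases get `/realtime` appended,
// and anything else passes through untouched.
fn normalize_path(path: &str) -> String {
    if path.is_empty() || path == "/" {
        return "/v1/realtime".to_string();
    }
    if path.ends_with("/realtime") {
        return path.to_string();
    }
    if path.ends_with("/realtime/") {
        return path.trim_end_matches('/').to_string();
    }
    if path.ends_with("/v1") {
        return format!("{path}/realtime");
    }
    if path.ends_with("/v1/") {
        return format!("{path}realtime");
    }
    path.to_string()
}
```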
#[cfg(test)]
mod tests {
use super::*;
use crate::endpoint::realtime_websocket::protocol::RealtimeHandoffMessage;
use crate::endpoint::realtime_websocket::protocol::RealtimeHandoffRequested;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use serde_json::Value;
@@ -507,17 +571,18 @@ mod tests {
use tokio_tungstenite::tungstenite::Message;
#[test]
fn parse_session_created_event() {
fn parse_session_updated_event() {
let payload = json!({
"type": "session.created",
"session": {"id": "sess_123"}
"type": "session.updated",
"session": {"id": "sess_123", "instructions": "backend prompt"}
})
.to_string();
assert_eq!(
parse_realtime_event(payload.as_str()),
Some(RealtimeEvent::SessionCreated {
session_id: "sess_123".to_string()
Some(RealtimeEvent::SessionUpdated {
session_id: "sess_123".to_string(),
instructions: Some("backend prompt".to_string()),
})
);
}
@@ -525,10 +590,10 @@ mod tests {
#[test]
fn parse_audio_delta_event() {
let payload = json!({
"type": "response.output_audio.delta",
"type": "conversation.output_audio.delta",
"delta": "AAA=",
"sample_rate": 48000,
"num_channels": 1,
"channels": 1,
"samples_per_channel": 960
})
.to_string();
@@ -547,17 +612,59 @@ mod tests {
fn parse_conversation_item_added_event() {
let payload = json!({
"type": "conversation.item.added",
"item": {"type": "spawn_transcript", "seq": 7}
"item": {"type": "message", "seq": 7}
})
.to_string();
assert_eq!(
parse_realtime_event(payload.as_str()),
Some(RealtimeEvent::ConversationItemAdded(
json!({"type": "spawn_transcript", "seq": 7})
json!({"type": "message", "seq": 7})
))
);
}
#[test]
fn parse_conversation_item_done_event() {
let payload = json!({
"type": "conversation.item.done",
"item": {"id": "item_123", "type": "message"}
})
.to_string();
assert_eq!(
parse_realtime_event(payload.as_str()),
Some(RealtimeEvent::ConversationItemDone {
item_id: "item_123".to_string(),
})
);
}
#[test]
fn parse_handoff_requested_event() {
let payload = json!({
"type": "conversation.handoff.requested",
"handoff_id": "handoff_123",
"item_id": "item_123",
"input_transcript": "delegate this",
"messages": [
{"role": "user", "text": "delegate this"}
]
})
.to_string();
assert_eq!(
parse_realtime_event(payload.as_str()),
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: "handoff_123".to_string(),
item_id: "item_123".to_string(),
input_transcript: "delegate this".to_string(),
messages: vec![RealtimeHandoffMessage {
role: "user".to_string(),
text: "delegate this".to_string(),
}],
}))
);
}
#[test]
fn merge_request_headers_matches_http_precedence() {
let mut provider_headers = HeaderMap::new();
@@ -593,14 +700,61 @@ mod tests {
#[test]
fn websocket_url_from_http_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url("http://127.0.0.1:8011").expect("build ws url");
assert_eq!(url.as_str(), "ws://127.0.0.1:8011/ws");
let url =
websocket_url_from_api_url("http://127.0.0.1:8011", None, None).expect("build ws url");
assert_eq!(
url.as_str(),
"ws://127.0.0.1:8011/v1/realtime?intent=quicksilver"
);
}
#[test]
fn websocket_url_from_ws_base_defaults_to_ws_path() {
let url = websocket_url_from_api_url("wss://example.com").expect("build ws url");
assert_eq!(url.as_str(), "wss://example.com/ws");
let url =
websocket_url_from_api_url("wss://example.com", None, Some("realtime-test-model"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?intent=quicksilver&model=realtime-test-model"
);
}
#[test]
fn websocket_url_from_v1_base_appends_realtime_path() {
let url = websocket_url_from_api_url("https://api.openai.com/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://api.openai.com/v1/realtime?intent=quicksilver&model=snapshot"
);
}
#[test]
fn websocket_url_from_nested_v1_base_appends_realtime_path() {
let url =
websocket_url_from_api_url("https://example.com/openai/v1", None, Some("snapshot"))
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/openai/v1/realtime?intent=quicksilver&model=snapshot"
);
}
#[test]
fn websocket_url_preserves_existing_realtime_path_and_extra_query_params() {
let url = websocket_url_from_api_url(
"https://example.com/v1/realtime?foo=bar",
Some(&HashMap::from([
("trace".to_string(), "1".to_string()),
("intent".to_string(), "ignored".to_string()),
])),
Some("snapshot"),
)
.expect("build ws url");
assert_eq!(
url.as_str(),
"wss://example.com/v1/realtime?foo=bar&intent=quicksilver&model=snapshot&trace=1"
);
}
#[tokio::test]
@@ -620,26 +774,38 @@ mod tests {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["backend_prompt"],
first_json["session"]["type"],
Value::String("quicksilver".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["conversation_id"],
Value::String("conv_1".to_string())
first_json["session"]["audio"]["input"]["format"]["type"],
Value::String("audio/pcm".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["format"]["rate"],
Value::from(24_000)
);
assert_eq!(
first_json["session"]["audio"]["output"]["voice"],
Value::String("mundo".to_string())
);
ws.send(Message::Text(
json!({
"type": "session.created",
"session": {"id": "sess_mock"}
"type": "session.updated",
"session": {"id": "sess_mock", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.created");
.expect("send session.updated");
let second = ws
.next()
@@ -649,7 +815,7 @@ mod tests {
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "response.input_audio.delta");
assert_eq!(second_json["type"], "input_audio_buffer.append");
let third = ws
.next()
@@ -662,12 +828,24 @@ mod tests {
assert_eq!(third_json["type"], "conversation.item.create");
assert_eq!(third_json["item"]["content"][0]["text"], "hello agent");
let fourth = ws
.next()
.await
.expect("fourth msg")
.expect("fourth msg ok")
.into_text()
.expect("text");
let fourth_json: Value = serde_json::from_str(&fourth).expect("json");
assert_eq!(fourth_json["type"], "conversation.handoff.append");
assert_eq!(fourth_json["handoff_id"], "handoff_1");
assert_eq!(fourth_json["output_text"], "hello from codex");
ws.send(Message::Text(
json!({
"type": "response.output_audio.delta",
"type": "conversation.output_audio.delta",
"delta": "AQID",
"sample_rate": 48000,
"num_channels": 1
"channels": 1
})
.to_string()
.into(),
@@ -677,8 +855,11 @@ mod tests {
ws.send(Message::Text(
json!({
"type": "conversation.item.added",
"item": {"type": "spawn_transcript", "seq": 2}
"type": "conversation.handoff.requested",
"handoff_id": "handoff_1",
"item_id": "item_2",
"input_transcript": "delegate now",
"messages": [{"role": "user", "text": "delegate now"}]
})
.to_string()
.into(),
@@ -705,7 +886,8 @@ mod tests {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
},
HeaderMap::new(),
@@ -721,8 +903,9 @@ mod tests {
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionCreated {
session_id: "sess_mock".to_string()
RealtimeEvent::SessionUpdated {
session_id: "sess_mock".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
@@ -739,6 +922,13 @@ mod tests {
.send_conversation_item_create("hello agent".to_string())
.await
.expect("send item");
connection
.send_conversation_handoff_append(
"handoff_1".to_string(),
"hello from codex".to_string(),
)
.await
.expect("send handoff");
let audio_event = connection
.next_event()
@@ -762,10 +952,15 @@ mod tests {
.expect("event");
assert_eq!(
added_event,
RealtimeEvent::ConversationItemAdded(json!({
"type": "spawn_transcript",
"seq": 2
}))
RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id: "handoff_1".to_string(),
item_id: "item_2".to_string(),
input_transcript: "delegate now".to_string(),
messages: vec![RealtimeHandoffMessage {
role: "user".to_string(),
text: "delegate now".to_string(),
}],
})
);
connection.close().await.expect("close");
@@ -789,7 +984,7 @@ mod tests {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
let second = ws
.next()
@@ -799,18 +994,18 @@ mod tests {
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "response.input_audio.delta");
assert_eq!(second_json["type"], "input_audio_buffer.append");
ws.send(Message::Text(
json!({
"type": "session.created",
"session": {"id": "sess_after_send"}
"type": "session.updated",
"session": {"id": "sess_after_send", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.created");
.expect("send session.updated");
});
let provider = Provider {
@@ -831,7 +1026,8 @@ mod tests {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_1".to_string()),
},
HeaderMap::new(),
@@ -862,8 +1058,9 @@ mod tests {
let next_event = next_result.expect("next event").expect("event");
assert_eq!(
next_event,
RealtimeEvent::SessionCreated {
session_id: "sess_after_send".to_string()
RealtimeEvent::SessionUpdated {
session_id: "sess_after_send".to_string(),
instructions: Some("backend prompt".to_string()),
}
);


@@ -1,49 +1,63 @@
pub use codex_protocol::protocol::RealtimeAudioFrame;
pub use codex_protocol::protocol::RealtimeEvent;
pub use codex_protocol::protocol::RealtimeHandoffMessage;
pub use codex_protocol::protocol::RealtimeHandoffRequested;
use serde::Serialize;
use serde_json::Value;
use tracing::debug;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RealtimeSessionConfig {
pub prompt: String,
pub instructions: String,
pub model: Option<String>,
pub session_id: Option<String>,
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
pub(super) enum RealtimeOutboundMessage {
#[serde(rename = "response.input_audio.delta")]
InputAudioDelta {
delta: String,
sample_rate: u32,
num_channels: u16,
#[serde(skip_serializing_if = "Option::is_none")]
samples_per_channel: Option<u32>,
#[serde(rename = "input_audio_buffer.append")]
InputAudioBufferAppend { audio: String },
#[serde(rename = "conversation.handoff.append")]
ConversationHandoffAppend {
handoff_id: String,
output_text: String,
},
#[serde(rename = "session.create")]
SessionCreate { session: SessionCreateSession },
#[serde(rename = "session.update")]
SessionUpdate {
#[serde(skip_serializing_if = "Option::is_none")]
session: Option<SessionUpdateSession>,
},
SessionUpdate { session: SessionUpdateSession },
#[serde(rename = "conversation.item.create")]
ConversationItemCreate { item: ConversationItem },
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionUpdateSession {
pub(super) backend_prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) conversation_id: Option<String>,
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) instructions: String,
pub(super) audio: SessionAudio,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionCreateSession {
pub(super) backend_prompt: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub(super) conversation_id: Option<String>,
pub(super) struct SessionAudio {
pub(super) input: SessionAudioInput,
pub(super) output: SessionAudioOutput,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioInput {
pub(super) format: SessionAudioFormat,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioFormat {
#[serde(rename = "type")]
pub(super) kind: String,
pub(super) rate: u32,
}
#[derive(Debug, Clone, Serialize)]
pub(super) struct SessionAudioOutput {
pub(super) voice: String,
}
#[derive(Debug, Clone, Serialize)]
@@ -78,30 +92,25 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
}
};
match message_type {
"session.created" => {
let session = parsed.get("session").and_then(Value::as_object);
let session_id = session
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("session_id")
.and_then(Value::as_str)
.map(str::to_string)
});
session_id.map(|id| RealtimeEvent::SessionCreated { session_id: id })
}
"session.updated" => {
let backend_prompt = parsed
let session_id = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("backend_prompt"))
.and_then(|session| session.get("id"))
.and_then(Value::as_str)
.map(str::to_string);
Some(RealtimeEvent::SessionUpdated { backend_prompt })
let instructions = parsed
.get("session")
.and_then(Value::as_object)
.and_then(|session| session.get("instructions"))
.and_then(Value::as_str)
.map(str::to_string);
session_id.map(|session_id| RealtimeEvent::SessionUpdated {
session_id,
instructions,
})
}
"response.output_audio.delta" => {
"conversation.output_audio.delta" => {
let data = parsed
.get("delta")
.and_then(Value::as_str)
@@ -112,7 +121,8 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.and_then(Value::as_u64)
.and_then(|v| u32::try_from(v).ok())?;
let num_channels = parsed
.get("num_channels")
.get("channels")
.or_else(|| parsed.get("num_channels"))
.and_then(Value::as_u64)
.and_then(|v| u16::try_from(v).ok())?;
Some(RealtimeEvent::AudioOut(RealtimeAudioFrame {
@@ -129,10 +139,55 @@ pub(super) fn parse_realtime_event(payload: &str) -> Option<RealtimeEvent> {
.get("item")
.cloned()
.map(RealtimeEvent::ConversationItemAdded),
"conversation.item.done" => parsed
.get("item")
.and_then(Value::as_object)
.and_then(|item| item.get("id"))
.and_then(Value::as_str)
.map(str::to_string)
.map(|item_id| RealtimeEvent::ConversationItemDone { item_id }),
"conversation.handoff.requested" => {
let handoff_id = parsed
.get("handoff_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let item_id = parsed
.get("item_id")
.and_then(Value::as_str)
.map(str::to_string)?;
let input_transcript = parsed
.get("input_transcript")
.and_then(Value::as_str)
.map(str::to_string)?;
let messages = parsed
.get("messages")
.and_then(Value::as_array)?
.iter()
.filter_map(|message| {
let role = message.get("role").and_then(Value::as_str)?.to_string();
let text = message.get("text").and_then(Value::as_str)?.to_string();
Some(RealtimeHandoffMessage { role, text })
})
.collect();
Some(RealtimeEvent::HandoffRequested(RealtimeHandoffRequested {
handoff_id,
item_id,
input_transcript,
messages,
}))
}
"error" => parsed
.get("message")
.and_then(Value::as_str)
.map(str::to_string)
.or_else(|| {
parsed
.get("error")
.and_then(Value::as_object)
.and_then(|error| error.get("message"))
.and_then(Value::as_str)
.map(str::to_string)
})
.or_else(|| parsed.get("error").map(std::string::ToString::to_string))
.map(RealtimeEvent::Error),
_ => {


@@ -265,6 +265,7 @@ async fn streaming_client_retries_on_transport_error() -> Result<()> {
store: false,
stream: true,
include: Vec::new(),
service_tier: None,
prompt_cache_key: None,
text: None,
};
@@ -306,6 +307,7 @@ async fn azure_default_store_attaches_ids_and_headers() -> Result<()> {
store: true,
stream: true,
include: Vec::new(),
service_tier: None,
prompt_cache_key: None,
text: None,
};


@@ -78,26 +78,34 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
assert_eq!(
first_json["session"]["backend_prompt"],
first_json["session"]["type"],
Value::String("quicksilver".to_string())
);
assert_eq!(
first_json["session"]["instructions"],
Value::String("backend prompt".to_string())
);
assert_eq!(
first_json["session"]["conversation_id"],
Value::String("conv_123".to_string())
first_json["session"]["audio"]["input"]["format"]["type"],
Value::String("audio/pcm".to_string())
);
assert_eq!(
first_json["session"]["audio"]["input"]["format"]["rate"],
Value::from(24_000)
);
ws.send(Message::Text(
json!({
"type": "session.created",
"session": {"id": "sess_mock"}
"type": "session.updated",
"session": {"id": "sess_mock", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.created");
.expect("send session.updated");
let second = ws
.next()
@@ -107,14 +115,14 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "response.input_audio.delta");
assert_eq!(second_json["type"], "input_audio_buffer.append");
ws.send(Message::Text(
json!({
"type": "response.output_audio.delta",
"type": "conversation.output_audio.delta",
"delta": "AQID",
"sample_rate": 48000,
"num_channels": 1
"channels": 1
})
.to_string()
.into(),
@@ -128,7 +136,8 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
@@ -144,8 +153,9 @@ async fn realtime_ws_e2e_session_create_and_event_flow() {
.expect("event");
assert_eq!(
created,
RealtimeEvent::SessionCreated {
session_id: "sess_mock".to_string()
RealtimeEvent::SessionUpdated {
session_id: "sess_mock".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
@@ -189,7 +199,7 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
let second = ws
.next()
@@ -199,18 +209,18 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
.into_text()
.expect("text");
let second_json: Value = serde_json::from_str(&second).expect("json");
assert_eq!(second_json["type"], "response.input_audio.delta");
assert_eq!(second_json["type"], "input_audio_buffer.append");
ws.send(Message::Text(
json!({
"type": "session.created",
"session": {"id": "sess_after_send"}
"type": "session.updated",
"session": {"id": "sess_after_send", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.created");
.expect("send session.updated");
})
.await;
@@ -218,7 +228,8 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
@@ -249,8 +260,9 @@ async fn realtime_ws_e2e_send_while_next_event_waits() {
let next_event = next_result.expect("next event").expect("event");
assert_eq!(
next_event,
RealtimeEvent::SessionCreated {
session_id: "sess_after_send".to_string()
RealtimeEvent::SessionUpdated {
session_id: "sess_after_send".to_string(),
instructions: Some("backend prompt".to_string()),
}
);
@@ -269,7 +281,7 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
ws.send(Message::Close(None)).await.expect("send close");
})
@@ -279,7 +291,8 @@ async fn realtime_ws_e2e_disconnected_emitted_once() {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
@@ -308,7 +321,7 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
.into_text()
.expect("text");
let first_json: Value = serde_json::from_str(&first).expect("json");
assert_eq!(first_json["type"], "session.create");
assert_eq!(first_json["type"], "session.update");
ws.send(Message::Text(
json!({
@@ -323,14 +336,14 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
ws.send(Message::Text(
json!({
"type": "session.created",
"session": {"id": "sess_after_unknown"}
"type": "session.updated",
"session": {"id": "sess_after_unknown", "instructions": "backend prompt"}
})
.to_string()
.into(),
))
.await
.expect("send session.created");
.expect("send session.updated");
})
.await;
@@ -338,7 +351,8 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
let connection = client
.connect(
RealtimeSessionConfig {
prompt: "backend prompt".to_string(),
instructions: "backend prompt".to_string(),
model: Some("realtime-test-model".to_string()),
session_id: Some("conv_123".to_string()),
},
HeaderMap::new(),
@@ -354,8 +368,9 @@ async fn realtime_ws_e2e_ignores_unknown_text_events() {
.expect("event");
assert_eq!(
event,
RealtimeEvent::SessionCreated {
session_id: "sess_after_unknown".to_string()
RealtimeEvent::SessionUpdated {
session_id: "sess_after_unknown".to_string(),
instructions: Some("backend prompt".to_string()),
}
);


@@ -42,6 +42,8 @@ codex-hooks = { workspace = true }
codex-keyring-store = { workspace = true }
codex-network-proxy = { workspace = true }
codex-otel = { workspace = true }
codex-artifact-presentation = { workspace = true }
codex-artifact-spreadsheet = { workspace = true }
codex-protocol = { workspace = true }
codex-rmcp-client = { workspace = true }
codex-state = { workspace = true }
