adding tests

This commit is contained in:
pap
2025-08-04 00:59:20 +01:00
parent b2ed15430f
commit 42c726be79
4 changed files with 168 additions and 40 deletions

View File

@@ -102,6 +102,18 @@ pub enum CodexErr {
#[error("{0}")]
EnvVar(EnvVarError),
// ------------------------------
// Ollamaspecific errors
// ------------------------------
#[error("no running Ollama server detected; please install/start Ollama")]
OllamaServerUnreachable,
#[error("ollama model not found: {0}")]
OllamaModelNotFound(String),
#[error("ollama pull failed: {0}")]
OllamaPullFailed(String),
}
#[derive(Debug)]

View File

@@ -1,3 +1,5 @@
use crate::error::CodexErr;
use crate::error::Result as CoreResult;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use bytes::BytesMut;
@@ -56,10 +58,10 @@ pub fn probe_url_for_base(base_url: &str) -> String {
}
/// Convenience helper to probe an Ollama server given a provider style base URL.
pub async fn probe_ollama_server(base_url: &str) -> io::Result<bool> {
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
let url = probe_url_for_base(base_url);
let resp = reqwest::Client::new().get(url).send().await;
Ok(matches!(resp, Ok(r) if r.status().is_success()))
let resp = reqwest::Client::new().get(url).send().await?;
Ok(resp.status().is_success())
}
/// Coordinator wrapper used by frontends when responding to `--ollama`.
///
@@ -68,7 +70,7 @@ pub async fn probe_ollama_server(base_url: &str) -> io::Result<bool> {
/// - If the server is reachable, ensures an `[model_providers.ollama]` entry
/// exists in `config.toml` with sensible defaults.
/// - If no server is reachable, returns an error.
pub async fn ensure_configured_and_running() -> io::Result<()> {
pub async fn ensure_configured_and_running() -> CoreResult<()> {
use crate::config::find_codex_home;
use toml::Value as TomlValue;
@@ -94,9 +96,7 @@ pub async fn ensure_configured_and_running() -> io::Result<()> {
// Probe reachability.
let ok = probe_ollama_server(&base_url).await?;
if !ok {
return Err(io::Error::other(
"No running Ollama server detected. Please install/start Ollama: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
));
return Err(CodexErr::OllamaServerUnreachable);
}
// Ensure provider entry exists with defaults.
@@ -231,13 +231,8 @@ impl PullProgressReporter for CliProgressReporter {
/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
/// CLI behavior aligned until a dedicated TUI integration is implemented.
#[derive(Default)]
pub struct TuiProgressReporter(CliProgressReporter);
impl Default for TuiProgressReporter {
fn default() -> Self {
Self(CliProgressReporter::new())
}
}
impl TuiProgressReporter {
pub fn new() -> Self {
Default::default()
@@ -358,20 +353,14 @@ impl OllamaClient {
let text = text.trim();
if text.is_empty() { continue; }
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
for ev in pull_events_from_value(&value) { yield ev; }
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
yield PullEvent::Status(format!("error: {err_msg}"));
return;
}
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
yield PullEvent::Status(status.to_string());
if status == "success" { yield PullEvent::Success; return; }
}
let digest = value.get("digest").and_then(|d| d.as_str()).unwrap_or("").to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if total.is_some() || completed.is_some() {
yield PullEvent::ChunkProgress { digest, total, completed };
}
}
}
}
@@ -513,7 +502,7 @@ pub async fn ensure_model_available(
client: &OllamaClient,
config_path: &Path,
reporter: &mut dyn PullProgressReporter,
) -> io::Result<()> {
) -> CoreResult<()> {
let mut listed = read_ollama_models_list(config_path);
let available = client.fetch_models().await.unwrap_or_default();
if available.iter().any(|m| m == model) {
@@ -527,9 +516,7 @@ pub async fn ensure_model_available(
}
if !DEFAULT_PULL_ALLOWLIST.contains(&model) {
return Err(io::Error::other(format!(
"Model `{model}` not found locally and not in allowlist for automatic download."
)));
return Err(CodexErr::OllamaModelNotFound(model.to_string()));
}
loop {
@@ -568,25 +555,20 @@ fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
}
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
// Ensure the provider tables exist first, then take a single mutable borrow.
if !doc["model_providers"].is_table() {
if doc.get("model_providers").is_none() {
doc.as_table_mut()
.insert("model_providers", Item::Table(Table::new()));
} else if !doc["model_providers"].is_table() {
doc["model_providers"] = Item::Table(Table::new());
}
{
// Narrow scope: mutate/create the nested `ollama` table without keeping a borrow alive.
let providers = match doc["model_providers"].as_table_mut() {
Some(table) => table,
None => return Box::leak(Box::new(Table::default())),
};
if !providers.contains_key("ollama") || !providers["ollama"].is_table() {
providers["ollama"] = Item::Table(Table::new());
}
let providers = doc["model_providers"]
.as_table_mut()
.expect("providers table");
if providers.get("ollama").is_none() || !providers["ollama"].is_table() {
providers["ollama"] = Item::Table(Table::new());
}
// Now, safely borrow the `ollama` table mutably once and return it.
let tbl = match doc["model_providers"]["ollama"].as_table_mut() {
Some(table) => table,
None => return Box::leak(Box::new(Table::default())),
};
let tbl = providers["ollama"].as_table_mut().expect("ollama table");
if !tbl.contains_key("name") {
tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
}
@@ -607,3 +589,133 @@ pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
// Convert a single JSON object representing a pull update into one or more events.
fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
let mut events = Vec::new();
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
events.push(PullEvent::Status(status.to_string()));
if status == "success" {
events.push(PullEvent::Success);
}
}
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if total.is_some() || completed.is_some() {
events.push(PullEvent::ChunkProgress {
digest,
total,
completed,
});
}
events
}
#[cfg(test)]
mod tests {
use super::*;
use toml_edit::DocumentMut as Document;
#[test]
fn test_base_url_to_host_root() {
assert_eq!(
base_url_to_host_root("http://localhost:11434/v1"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434/"),
"http://localhost:11434"
);
}
#[test]
fn test_probe_url_for_base() {
assert_eq!(
probe_url_for_base("http://localhost:11434/v1"),
"http://localhost:11434/v1/models"
);
assert_eq!(
probe_url_for_base("http://localhost:11434"),
"http://localhost:11434/api/tags"
);
}
#[test]
fn test_pull_events_decoder_status_and_success() {
let v: JsonValue = serde_json::json!({"status":"verifying"});
let events = pull_events_from_value(&v);
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
let v2: JsonValue = serde_json::json!({"status":"success"});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 2);
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
assert!(matches!(events2[1], PullEvent::Success));
}
#[test]
fn test_pull_events_decoder_progress() {
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
let events = pull_events_from_value(&v);
assert_eq!(events.len(), 1);
match &events[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:abc");
assert_eq!(*total, Some(100));
assert_eq!(*completed, None);
}
_ => panic!("expected ChunkProgress"),
}
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 1);
match &events2[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:def");
assert_eq!(*total, None);
assert_eq!(*completed, Some(42));
}
_ => panic!("expected ChunkProgress"),
}
}
#[test]
fn test_upsert_provider_and_models() {
let mut doc = Document::new();
let tbl = upsert_provider_ollama(&mut doc);
assert!(tbl.contains_key("name"));
assert!(tbl.contains_key("base_url"));
assert!(tbl.contains_key("wire_api"));
set_ollama_models(&mut doc, &vec!["llama3.2:3b".to_string()]);
let root = doc.as_table();
let mp = root
.get("model_providers")
.and_then(|i| i.as_table())
.expect("model_providers");
let ollama = mp.get("ollama").and_then(|i| i.as_table()).expect("ollama");
let arr = ollama.get("models").expect("models array");
assert!(arr.is_array(), "models should be an array");
let s = doc.to_string();
assert!(s.contains("model_providers"));
assert!(s.contains("ollama"));
assert!(s.contains("models"));
}
}

View File

@@ -128,6 +128,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
// have a configured provider entry and that a local server is running.
if ollama {
if let Err(e) = ensure_configured_and_running().await {
tracing::error!("{e}");
eprintln!("{e}");
std::process::exit(1);
}
@@ -172,6 +173,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
if let Err(e) =
ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
{
tracing::error!("{e}");
eprintln!("{e}");
std::process::exit(1);
}

View File

@@ -342,6 +342,7 @@ pub async fn run_main(
if let Err(e) = ensure_configured_and_running().await {
#[allow(clippy::print_stderr)]
{
tracing::error!("{e}");
eprintln!("{e}");
}
std::process::exit(1);
@@ -407,6 +408,7 @@ pub async fn run_main(
ensure_model_available(model_name, &client, &config_path, &mut reporter).await
{
let mut out = std::io::stderr();
tracing::error!("{e}");
let _ = out.write_all(format!("{e}\n").as_bytes());
let _ = out.flush();
std::process::exit(1);