mirror of
https://github.com/openai/codex.git
synced 2026-04-24 22:54:54 +00:00
adding tests
This commit is contained in:
@@ -102,6 +102,18 @@ pub enum CodexErr {
|
||||
|
||||
#[error("{0}")]
|
||||
EnvVar(EnvVarError),
|
||||
|
||||
// ------------------------------
|
||||
// Ollama‑specific errors
|
||||
// ------------------------------
|
||||
#[error("no running Ollama server detected; please install/start Ollama")]
|
||||
OllamaServerUnreachable,
|
||||
|
||||
#[error("ollama model not found: {0}")]
|
||||
OllamaModelNotFound(String),
|
||||
|
||||
#[error("ollama pull failed: {0}")]
|
||||
OllamaPullFailed(String),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::error::CodexErr;
|
||||
use crate::error::Result as CoreResult;
|
||||
use crate::model_provider_info::ModelProviderInfo;
|
||||
use crate::model_provider_info::WireApi;
|
||||
use bytes::BytesMut;
|
||||
@@ -56,10 +58,10 @@ pub fn probe_url_for_base(base_url: &str) -> String {
|
||||
}
|
||||
|
||||
/// Convenience helper to probe an Ollama server given a provider style base URL.
|
||||
pub async fn probe_ollama_server(base_url: &str) -> io::Result<bool> {
|
||||
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
|
||||
let url = probe_url_for_base(base_url);
|
||||
let resp = reqwest::Client::new().get(url).send().await;
|
||||
Ok(matches!(resp, Ok(r) if r.status().is_success()))
|
||||
let resp = reqwest::Client::new().get(url).send().await?;
|
||||
Ok(resp.status().is_success())
|
||||
}
|
||||
/// Coordinator wrapper used by frontends when responding to `--ollama`.
|
||||
///
|
||||
@@ -68,7 +70,7 @@ pub async fn probe_ollama_server(base_url: &str) -> io::Result<bool> {
|
||||
/// - If the server is reachable, ensures an `[model_providers.ollama]` entry
|
||||
/// exists in `config.toml` with sensible defaults.
|
||||
/// - If no server is reachable, returns an error.
|
||||
pub async fn ensure_configured_and_running() -> io::Result<()> {
|
||||
pub async fn ensure_configured_and_running() -> CoreResult<()> {
|
||||
use crate::config::find_codex_home;
|
||||
use toml::Value as TomlValue;
|
||||
|
||||
@@ -94,9 +96,7 @@ pub async fn ensure_configured_and_running() -> io::Result<()> {
|
||||
// Probe reachability.
|
||||
let ok = probe_ollama_server(&base_url).await?;
|
||||
if !ok {
|
||||
return Err(io::Error::other(
|
||||
"No running Ollama server detected. Please install/start Ollama: https://github.com/ollama/ollama?tab=readme-ov-file#ollama",
|
||||
));
|
||||
return Err(CodexErr::OllamaServerUnreachable);
|
||||
}
|
||||
|
||||
// Ensure provider entry exists with defaults.
|
||||
@@ -231,13 +231,8 @@ impl PullProgressReporter for CliProgressReporter {
|
||||
|
||||
/// For now the TUI reporter delegates to the CLI reporter. This keeps UI and
|
||||
/// CLI behavior aligned until a dedicated TUI integration is implemented.
|
||||
#[derive(Default)]
|
||||
pub struct TuiProgressReporter(CliProgressReporter);
|
||||
|
||||
impl Default for TuiProgressReporter {
|
||||
fn default() -> Self {
|
||||
Self(CliProgressReporter::new())
|
||||
}
|
||||
}
|
||||
impl TuiProgressReporter {
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
@@ -358,20 +353,14 @@ impl OllamaClient {
|
||||
let text = text.trim();
|
||||
if text.is_empty() { continue; }
|
||||
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
|
||||
for ev in pull_events_from_value(&value) { yield ev; }
|
||||
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
|
||||
yield PullEvent::Status(format!("error: {err_msg}"));
|
||||
return;
|
||||
}
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
yield PullEvent::Status(status.to_string());
|
||||
if status == "success" { yield PullEvent::Success; return; }
|
||||
}
|
||||
let digest = value.get("digest").and_then(|d| d.as_str()).unwrap_or("").to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if total.is_some() || completed.is_some() {
|
||||
yield PullEvent::ChunkProgress { digest, total, completed };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -513,7 +502,7 @@ pub async fn ensure_model_available(
|
||||
client: &OllamaClient,
|
||||
config_path: &Path,
|
||||
reporter: &mut dyn PullProgressReporter,
|
||||
) -> io::Result<()> {
|
||||
) -> CoreResult<()> {
|
||||
let mut listed = read_ollama_models_list(config_path);
|
||||
let available = client.fetch_models().await.unwrap_or_default();
|
||||
if available.iter().any(|m| m == model) {
|
||||
@@ -527,9 +516,7 @@ pub async fn ensure_model_available(
|
||||
}
|
||||
|
||||
if !DEFAULT_PULL_ALLOWLIST.contains(&model) {
|
||||
return Err(io::Error::other(format!(
|
||||
"Model `{model}` not found locally and not in allowlist for automatic download."
|
||||
)));
|
||||
return Err(CodexErr::OllamaModelNotFound(model.to_string()));
|
||||
}
|
||||
|
||||
loop {
|
||||
@@ -568,25 +555,20 @@ fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
|
||||
}
|
||||
|
||||
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
|
||||
// Ensure the provider tables exist first, then take a single mutable borrow.
|
||||
if !doc["model_providers"].is_table() {
|
||||
if doc.get("model_providers").is_none() {
|
||||
doc.as_table_mut()
|
||||
.insert("model_providers", Item::Table(Table::new()));
|
||||
} else if !doc["model_providers"].is_table() {
|
||||
doc["model_providers"] = Item::Table(Table::new());
|
||||
}
|
||||
{
|
||||
// Narrow scope: mutate/create the nested `ollama` table without keeping a borrow alive.
|
||||
let providers = match doc["model_providers"].as_table_mut() {
|
||||
Some(table) => table,
|
||||
None => return Box::leak(Box::new(Table::default())),
|
||||
};
|
||||
if !providers.contains_key("ollama") || !providers["ollama"].is_table() {
|
||||
providers["ollama"] = Item::Table(Table::new());
|
||||
}
|
||||
|
||||
let providers = doc["model_providers"]
|
||||
.as_table_mut()
|
||||
.expect("providers table");
|
||||
if providers.get("ollama").is_none() || !providers["ollama"].is_table() {
|
||||
providers["ollama"] = Item::Table(Table::new());
|
||||
}
|
||||
// Now, safely borrow the `ollama` table mutably once and return it.
|
||||
let tbl = match doc["model_providers"]["ollama"].as_table_mut() {
|
||||
Some(table) => table,
|
||||
None => return Box::leak(Box::new(Table::default())),
|
||||
};
|
||||
let tbl = providers["ollama"].as_table_mut().expect("ollama table");
|
||||
if !tbl.contains_key("name") {
|
||||
tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
|
||||
}
|
||||
@@ -607,3 +589,133 @@ pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
|
||||
}
|
||||
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
|
||||
}
|
||||
|
||||
// Convert a single JSON object representing a pull update into one or more events.
|
||||
fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
|
||||
let mut events = Vec::new();
|
||||
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
|
||||
events.push(PullEvent::Status(status.to_string()));
|
||||
if status == "success" {
|
||||
events.push(PullEvent::Success);
|
||||
}
|
||||
}
|
||||
let digest = value
|
||||
.get("digest")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let total = value.get("total").and_then(|t| t.as_u64());
|
||||
let completed = value.get("completed").and_then(|t| t.as_u64());
|
||||
if total.is_some() || completed.is_some() {
|
||||
events.push(PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
});
|
||||
}
|
||||
events
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use toml_edit::DocumentMut as Document;
|
||||
|
||||
#[test]
|
||||
fn test_base_url_to_host_root() {
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/v1"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
assert_eq!(
|
||||
base_url_to_host_root("http://localhost:11434/"),
|
||||
"http://localhost:11434"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_probe_url_for_base() {
|
||||
assert_eq!(
|
||||
probe_url_for_base("http://localhost:11434/v1"),
|
||||
"http://localhost:11434/v1/models"
|
||||
);
|
||||
assert_eq!(
|
||||
probe_url_for_base("http://localhost:11434"),
|
||||
"http://localhost:11434/api/tags"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_status_and_success() {
|
||||
let v: JsonValue = serde_json::json!({"status":"verifying"});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"status":"success"});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 2);
|
||||
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
|
||||
assert!(matches!(events2[1], PullEvent::Success));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pull_events_decoder_progress() {
|
||||
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
|
||||
let events = pull_events_from_value(&v);
|
||||
assert_eq!(events.len(), 1);
|
||||
match &events[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:abc");
|
||||
assert_eq!(*total, Some(100));
|
||||
assert_eq!(*completed, None);
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
|
||||
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
|
||||
let events2 = pull_events_from_value(&v2);
|
||||
assert_eq!(events2.len(), 1);
|
||||
match &events2[0] {
|
||||
PullEvent::ChunkProgress {
|
||||
digest,
|
||||
total,
|
||||
completed,
|
||||
} => {
|
||||
assert_eq!(digest, "sha256:def");
|
||||
assert_eq!(*total, None);
|
||||
assert_eq!(*completed, Some(42));
|
||||
}
|
||||
_ => panic!("expected ChunkProgress"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_upsert_provider_and_models() {
|
||||
let mut doc = Document::new();
|
||||
let tbl = upsert_provider_ollama(&mut doc);
|
||||
assert!(tbl.contains_key("name"));
|
||||
assert!(tbl.contains_key("base_url"));
|
||||
assert!(tbl.contains_key("wire_api"));
|
||||
set_ollama_models(&mut doc, &vec!["llama3.2:3b".to_string()]);
|
||||
let root = doc.as_table();
|
||||
let mp = root
|
||||
.get("model_providers")
|
||||
.and_then(|i| i.as_table())
|
||||
.expect("model_providers");
|
||||
let ollama = mp.get("ollama").and_then(|i| i.as_table()).expect("ollama");
|
||||
let arr = ollama.get("models").expect("models array");
|
||||
assert!(arr.is_array(), "models should be an array");
|
||||
let s = doc.to_string();
|
||||
assert!(s.contains("model_providers"));
|
||||
assert!(s.contains("ollama"));
|
||||
assert!(s.contains("models"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,6 +128,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
// have a configured provider entry and that a local server is running.
|
||||
if ollama {
|
||||
if let Err(e) = ensure_configured_and_running().await {
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
@@ -172,6 +173,7 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any
|
||||
if let Err(e) =
|
||||
ensure_model_available(&model_name, &client, &config_path, &mut reporter).await
|
||||
{
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
@@ -342,6 +342,7 @@ pub async fn run_main(
|
||||
if let Err(e) = ensure_configured_and_running().await {
|
||||
#[allow(clippy::print_stderr)]
|
||||
{
|
||||
tracing::error!("{e}");
|
||||
eprintln!("{e}");
|
||||
}
|
||||
std::process::exit(1);
|
||||
@@ -407,6 +408,7 @@ pub async fn run_main(
|
||||
ensure_model_available(model_name, &client, &config_path, &mut reporter).await
|
||||
{
|
||||
let mut out = std::io::stderr();
|
||||
tracing::error!("{e}");
|
||||
let _ = out.write_all(format!("{e}\n").as_bytes());
|
||||
let _ = out.flush();
|
||||
std::process::exit(1);
|
||||
|
||||
Reference in New Issue
Block a user