Break the Ollama provider down into multiple files

This commit is contained in:
pap
2025-08-04 18:10:46 +01:00
parent 5df778471c
commit 304d01c099
5 changed files with 583 additions and 545 deletions

View File

@@ -0,0 +1,162 @@
use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::collections::VecDeque;
use std::io;
use crate::model_provider_info::{ModelProviderInfo, WireApi};
use super::parser::pull_events_from_value;
use super::url::{base_url_to_host_root, is_openai_compatible_base_url};
use super::{DEFAULT_BASE_URL, PullEvent, PullProgressReporter};
/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
    /// Shared HTTP client reused across all requests.
    client: reqwest::Client,
    /// Native Ollama host root, e.g. "http://localhost:11434". Endpoint
    /// builders trim trailing slashes before appending paths.
    host_root: String,
    /// True when the configured base URL is OpenAI-compatible (ends in "/v1");
    /// selects the "/v1/models" health probe instead of native "/api/tags".
    uses_openai_compat: bool,
}
impl OllamaClient {
    /// Build a client from a provider definition. Falls back to the default
    /// local URL if no base_url is configured.
    pub fn from_provider(provider: &ModelProviderInfo) -> Self {
        let base_url = provider
            .base_url
            .clone()
            .unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
        // The original expression was `A || (B && A)`, which reduces to `A`:
        // OpenAI compatibility is determined by the URL shape alone, so the
        // redundant wire_api check was dropped.
        let uses_openai_compat = is_openai_compatible_base_url(&base_url);
        let host_root = base_url_to_host_root(&base_url);
        Self {
            client: reqwest::Client::new(),
            host_root,
            uses_openai_compat,
        }
    }

    /// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
    pub fn from_host_root(host_root: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            host_root: host_root.into(),
            // A raw host root is always treated as a native Ollama endpoint.
            uses_openai_compat: false,
        }
    }

    /// Probe whether the server is reachable by hitting the appropriate health
    /// endpoint ("/v1/models" for OpenAI-compatible roots, "/api/tags" otherwise).
    ///
    /// Transport errors are folded into `Ok(false)` rather than propagated.
    pub async fn probe_server(&self) -> io::Result<bool> {
        let url = if self.uses_openai_compat {
            format!("{}/v1/models", self.host_root.trim_end_matches('/'))
        } else {
            format!("{}/api/tags", self.host_root.trim_end_matches('/'))
        };
        let resp = self.client.get(url).send().await;
        Ok(matches!(resp, Ok(r) if r.status().is_success()))
    }

    /// Return the list of model names known to the local Ollama instance.
    ///
    /// A non-success HTTP status yields an empty list; transport and JSON
    /// decoding failures surface as `io::Error`.
    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .get(tags_url)
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Ok(Vec::new());
        }
        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
        // Expected response shape: {"models": [{"name": "..."}, ...]}.
        let names = val
            .get("models")
            .and_then(|m| m.as_array())
            .map(|arr| {
                arr.iter()
                    .filter_map(|v| v.get("name").and_then(|n| n.as_str()))
                    .map(|s| s.to_string())
                    .collect::<Vec<_>>()
            })
            .unwrap_or_default();
        Ok(names)
    }

    /// Start a model pull and emit streaming events. The returned stream ends when
    /// a Success event is observed or the server closes the connection.
    ///
    /// # Errors
    /// Fails if the request cannot be sent or the server answers with a
    /// non-success status; errors *within* the stream end it silently.
    pub async fn pull_model_stream(
        &self,
        model: &str,
    ) -> io::Result<BoxStream<'static, PullEvent>> {
        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
        let resp = self
            .client
            .post(url)
            .json(&serde_json::json!({"model": model, "stream": true}))
            .send()
            .await
            .map_err(io::Error::other)?;
        if !resp.status().is_success() {
            return Err(io::Error::other(format!(
                "failed to start pull: HTTP {}",
                resp.status()
            )));
        }
        let mut stream = resp.bytes_stream();
        // Accumulates partial chunks until a complete newline-delimited JSON
        // line is available. (An unused `_pending` VecDeque was removed.)
        let mut buf = BytesMut::new();
        let s = async_stream::stream! {
            while let Some(chunk) = stream.next().await {
                match chunk {
                    Ok(bytes) => {
                        buf.extend_from_slice(&bytes);
                        // Drain every complete line currently buffered.
                        while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
                            let line = buf.split_to(pos + 1);
                            if let Ok(text) = std::str::from_utf8(&line) {
                                let text = text.trim();
                                if text.is_empty() { continue; }
                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
                                    for ev in pull_events_from_value(&value) { yield ev; }
                                    // A server-reported error terminates the stream.
                                    if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
                                        yield PullEvent::Status(format!("error: {err_msg}"));
                                        return;
                                    }
                                    // "success" is the terminal status.
                                    if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
                                        if status == "success" { yield PullEvent::Success; return; }
                                    }
                                }
                            }
                        }
                    }
                    Err(_) => {
                        // Connection error: end the stream.
                        return;
                    }
                }
            }
        };
        Ok(Box::pin(s))
    }

    /// High-level helper to pull a model and drive a progress reporter.
    ///
    /// Forwards every stream event to `reporter` and stops once Success is seen.
    pub async fn pull_with_reporter(
        &self,
        model: &str,
        reporter: &mut dyn PullProgressReporter,
    ) -> io::Result<()> {
        reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
        let mut stream = self.pull_model_stream(model).await?;
        while let Some(event) = stream.next().await {
            reporter.on_event(&event)?;
            if matches!(event, PullEvent::Success) {
                break;
            }
        }
        Ok(())
    }
}

View File

@@ -0,0 +1,243 @@
use std::io;
use std::path::Path;
use std::str::FromStr;
use toml_edit::DocumentMut as Document;
use toml_edit::Item;
use toml_edit::Table;
use toml_edit::Value as TomlValueEdit;
use super::DEFAULT_BASE_URL;
/// Read the list of models recorded under [model_providers.ollama].models.
///
/// Any failure along the way (unreadable file, invalid TOML, missing keys,
/// non-string entries) degrades to an empty list.
pub fn read_ollama_models_list(config_path: &Path) -> Vec<String> {
    let Ok(raw) = std::fs::read_to_string(config_path) else {
        return Vec::new();
    };
    let Ok(parsed) = toml::from_str::<toml::Value>(&raw) else {
        return Vec::new();
    };
    let toml::Value::Table(root) = parsed else {
        return Vec::new();
    };
    let models = root
        .get("model_providers")
        .and_then(|v| v.as_table())
        .and_then(|t| t.get("ollama"))
        .and_then(|v| v.as_table())
        .and_then(|t| t.get("models"))
        .and_then(|v| v.as_array());
    match models {
        // Silently skip any non-string array entries, as before.
        Some(arr) => arr
            .iter()
            .filter_map(|v| v.as_str())
            .map(|s| s.to_string())
            .collect(),
        None => Vec::new(),
    }
}
/// Result-based wrapper over [`read_ollama_models_list`] for callers that
/// expect a uniform `io::Result` API. This call itself never fails.
pub fn read_config_models(config_path: &Path) -> io::Result<Vec<String>> {
    let models = read_ollama_models_list(config_path);
    Ok(models)
}
/// Overwrite the recorded models list under [model_providers.ollama].models using toml_edit.
pub fn write_ollama_models_list(config_path: &Path, models: &[String]) -> io::Result<()> {
let mut doc = read_document(config_path)?;
{
let tbl = upsert_provider_ollama(&mut doc);
let mut arr = toml_edit::Array::new();
for m in models {
arr.push(TomlValueEdit::from(m.clone()));
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
write_document(config_path, &doc)
}
/// Write models list via a uniform name expected by higher layers.
/// Thin alias over [`write_ollama_models_list`]; same errors, same effect.
pub fn write_config_models(config_path: &Path, models: &[String]) -> io::Result<()> {
    write_ollama_models_list(config_path, models)
}
/// Ensure `[model_providers.ollama]` exists with sensible defaults on disk.
/// Returns true if the config file was actually created or modified.
pub fn ensure_ollama_provider_entry(codex_home: &Path) -> io::Result<bool> {
    let config_path = codex_home.join("config.toml");
    let mut doc = read_document(&config_path)?;
    // Detect whether the upsert changed anything by comparing rendered forms;
    // only touch the file when it did.
    let original = doc.to_string();
    upsert_provider_ollama(&mut doc);
    if doc.to_string() == original {
        return Ok(false);
    }
    write_document(&config_path, &doc)?;
    Ok(true)
}
/// Alias name mirroring the refactor plan wording.
/// Delegates to [`ensure_ollama_provider_entry`]; returns true when the
/// config file on disk was created or updated.
pub fn ensure_provider_entry_and_defaults(codex_home: &Path) -> io::Result<bool> {
    ensure_ollama_provider_entry(codex_home)
}
/// Read whether the provider exists and how many models are recorded under it.
///
/// Returns `(provider_present, models_count)`; an unreadable or invalid
/// config file yields `(false, 0)`.
pub fn read_provider_state(config_path: &Path) -> (bool, usize) {
    let root = match std::fs::read_to_string(config_path)
        .ok()
        .and_then(|s| toml::from_str::<toml::Value>(&s).ok())
    {
        Some(toml::Value::Table(root)) => root,
        _ => return (false, 0),
    };
    // Resolve [model_providers.ollama] once; the original walked the same
    // key chain twice. Both answers derive from this single lookup.
    let ollama = root
        .get("model_providers")
        .and_then(|v| v.as_table())
        .and_then(|t| t.get("ollama"));
    let provider_present = ollama.is_some();
    let models_count = ollama
        .and_then(|v| v.as_table())
        .and_then(|t| t.get("models"))
        .and_then(|v| v.as_array())
        .map(|arr| arr.len())
        .unwrap_or(0);
    (provider_present, models_count)
}
// ---------- toml_edit helpers ----------
/// Load `path` as an editable toml_edit document. A missing file is treated
/// as an empty document (so first-run writes can proceed); any other I/O
/// error propagates.
fn read_document(path: &Path) -> io::Result<Document> {
    match std::fs::read_to_string(path) {
        // Parse failures are folded into io::Error for a uniform error type.
        Ok(s) => Document::from_str(&s).map_err(io::Error::other),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Document::new()),
        Err(e) => Err(e),
    }
}
/// Render `doc` and write it to `path`, creating missing parent directories
/// first.
fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
    let rendered = doc.to_string();
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir)?;
    }
    std::fs::write(path, rendered)
}
/// Ensure `[model_providers.ollama]` exists in `doc` (creating it, or
/// replacing a non-table entry, at each level), seed any missing default
/// keys, and return a mutable reference to the table.
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
    // The original open-coded each level twice (an "ensure" pass followed by
    // a re-checking "bind" pass); `ensure_table` does both in one place.
    let providers = ensure_table(doc.as_table_mut(), "model_providers");
    let tbl = ensure_table(providers, "ollama");
    // Seed defaults without clobbering user-provided values.
    if !tbl.contains_key("name") {
        tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
    }
    if !tbl.contains_key("base_url") {
        tbl["base_url"] = Item::Value(TomlValueEdit::from(DEFAULT_BASE_URL));
    }
    if !tbl.contains_key("wire_api") {
        tbl["wire_api"] = Item::Value(TomlValueEdit::from("chat"));
    }
    tbl
}

/// Return a mutable reference to `parent[key]` as a table, inserting a fresh
/// table when the key is missing or currently holds a non-table item.
fn ensure_table<'a>(parent: &'a mut Table, key: &str) -> &'a mut Table {
    let already_table = parent.get(key).map(Item::is_table).unwrap_or(false);
    if !already_table {
        parent.insert(key, Item::Table(Table::new()));
    }
    match parent.get_mut(key).and_then(Item::as_table_mut) {
        Some(t) => t,
        None => unreachable!("{key} was just set to a table"),
    }
}
/// Record `models` under `[model_providers.ollama].models`, replacing any
/// existing list.
pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
    let arr: toml_edit::Array = models
        .iter()
        .map(|m| TomlValueEdit::from(m.clone()))
        .collect();
    upsert_provider_ollama(doc)["models"] = Item::Value(TomlValueEdit::Array(arr));
}
#[cfg(test)]
mod tests {
    use super::*;
    use toml_edit::DocumentMut as Document;
    // End-to-end check on an empty document: upserting the provider seeds the
    // default keys, and set_ollama_models records a models array under
    // [model_providers.ollama].
    #[test]
    fn test_upsert_provider_and_models() {
        let mut doc = Document::new();
        let tbl = upsert_provider_ollama(&mut doc);
        // Default keys must be present after the upsert.
        assert!(tbl.contains_key("name"));
        assert!(tbl.contains_key("base_url"));
        assert!(tbl.contains_key("wire_api"));
        set_ollama_models(&mut doc, &[String::from("llama3.2:3b")]);
        // Walk the document to confirm the nested structure exists.
        let root = doc.as_table();
        let mp = match root.get("model_providers").and_then(|i| i.as_table()) {
            Some(t) => t,
            None => panic!("model_providers"),
        };
        let ollama = match mp.get("ollama").and_then(|i| i.as_table()) {
            Some(t) => t,
            None => panic!("ollama"),
        };
        let arr = match ollama.get("models") {
            Some(v) => v,
            None => panic!("models array"),
        };
        assert!(arr.is_array(), "models should be an array");
        // The rendered TOML should mention every key we wrote.
        let s = doc.to_string();
        assert!(s.contains("model_providers"));
        assert!(s.contains("ollama"));
        assert!(s.contains("models"));
    }
}

View File

@@ -1,68 +1,25 @@
use crate::error::CodexErr;
use crate::error::Result as CoreResult;
use crate::model_provider_info::ModelProviderInfo;
use crate::model_provider_info::WireApi;
use bytes::BytesMut;
use futures::StreamExt;
use futures::stream::BoxStream;
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::io;
use std::io::Write;
use std::path::Path;
use std::str::FromStr;
use toml_edit::DocumentMut as Document;
use toml_edit::Item;
use toml_edit::Table;
use toml_edit::Value as TomlValueEdit;
pub const DEFAULT_BASE_URL: &str = "http://localhost:11434/v1";
pub const DEFAULT_WIRE_API: WireApi = WireApi::Chat;
pub const DEFAULT_WIRE_API: crate::model_provider_info::WireApi =
crate::model_provider_info::WireApi::Chat;
pub const DEFAULT_PULL_ALLOWLIST: &[&str] = &["llama3.2:3b"];
/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
fn is_openai_compatible_base_url(base_url: &str) -> bool {
base_url.trim_end_matches('/').ends_with("/v1")
}
pub mod client;
pub mod config;
pub mod parser;
pub mod url;
/// Convert a provider base_url into the native Ollama host root.
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
pub fn base_url_to_host_root(base_url: &str) -> String {
let trimmed = base_url.trim_end_matches('/');
if trimmed.ends_with("/v1") {
trimmed
.trim_end_matches("/v1")
.trim_end_matches('/')
.to_string()
} else {
trimmed.to_string()
}
}
/// Variant that considers an explicit WireApi value; provided to centralize
/// host root computation in one place for future extension.
pub fn base_url_to_host_root_with_wire(base_url: &str, _wire_api: WireApi) -> String {
base_url_to_host_root(base_url)
}
/// Compute the probe URL to verify if an Ollama server is reachable.
/// If the configured base is OpenAI-compatible (/v1), probe "models", otherwise
/// fall back to the native "/api/tags" endpoint.
pub fn probe_url_for_base(base_url: &str) -> String {
if is_openai_compatible_base_url(base_url) {
format!("{}/models", base_url.trim_end_matches('/'))
} else {
format!("{}/api/tags", base_url.trim_end_matches('/'))
}
}
/// Convenience helper to probe an Ollama server given a provider style base URL.
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
let url = probe_url_for_base(base_url);
let resp = reqwest::Client::new().get(url).send().await?;
Ok(resp.status().is_success())
}
pub use client::OllamaClient;
pub use config::{read_config_models, read_provider_state, write_config_models};
pub use url::{
base_url_to_host_root, base_url_to_host_root_with_wire, probe_ollama_server, probe_url_for_base,
};
/// Coordinator wrapper used by frontends when responding to `--ollama`.
///
/// - Probes the server using the configured base_url when present, otherwise
@@ -94,13 +51,13 @@ pub async fn ensure_configured_and_running() -> CoreResult<()> {
};
// Probe reachability.
let ok = probe_ollama_server(&base_url).await?;
let ok = url::probe_ollama_server(&base_url).await?;
if !ok {
return Err(CodexErr::OllamaServerUnreachable);
}
// Ensure provider entry exists with defaults.
let _ = ensure_ollama_provider_entry(&codex_home)?;
let _ = config::ensure_ollama_provider_entry(&codex_home)?;
Ok(())
}
@@ -243,254 +200,6 @@ impl PullProgressReporter for TuiProgressReporter {
self.0.on_event(event)
}
}
/// Client for interacting with a local Ollama instance.
pub struct OllamaClient {
client: reqwest::Client,
host_root: String,
uses_openai_compat: bool,
}
impl OllamaClient {
/// Build a client from a provider definition. Falls back to the default
/// local URL if no base_url is configured.
pub fn from_provider(provider: &ModelProviderInfo) -> Self {
let base_url = provider
.base_url
.clone()
.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
let uses_openai_compat = is_openai_compatible_base_url(&base_url)
|| matches!(provider.wire_api, WireApi::Chat)
&& is_openai_compatible_base_url(&base_url);
let host_root = base_url_to_host_root(&base_url);
Self {
client: reqwest::Client::new(),
host_root,
uses_openai_compat,
}
}
/// Low-level constructor given a raw host root, e.g. "http://localhost:11434".
pub fn from_host_root(host_root: impl Into<String>) -> Self {
Self {
client: reqwest::Client::new(),
host_root: host_root.into(),
uses_openai_compat: false,
}
}
/// Probe whether the server is reachable by hitting the appropriate health endpoint.
pub async fn probe_server(&self) -> io::Result<bool> {
let url = if self.uses_openai_compat {
format!("{}/v1/models", self.host_root.trim_end_matches('/'))
} else {
format!("{}/api/tags", self.host_root.trim_end_matches('/'))
};
let resp = self.client.get(url).send().await;
Ok(matches!(resp, Ok(r) if r.status().is_success()))
}
/// Return the list of model names known to the local Ollama instance.
pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
let resp = self
.client
.get(tags_url)
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Ok(Vec::new());
}
let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
let names = val
.get("models")
.and_then(|m| m.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.get("name").and_then(|n| n.as_str()))
.map(|s| s.to_string())
.collect::<Vec<_>>()
})
.unwrap_or_default();
Ok(names)
}
/// Start a model pull and emit streaming events. The returned stream ends when
/// a Success event is observed or the server closes the connection.
pub async fn pull_model_stream(
&self,
model: &str,
) -> io::Result<BoxStream<'static, PullEvent>> {
let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
let resp = self
.client
.post(url)
.json(&serde_json::json!({"model": model, "stream": true}))
.send()
.await
.map_err(io::Error::other)?;
if !resp.status().is_success() {
return Err(io::Error::other(format!(
"failed to start pull: HTTP {}",
resp.status()
)));
}
let mut stream = resp.bytes_stream();
let mut buf = BytesMut::new();
let _pending: VecDeque<PullEvent> = VecDeque::new();
// Using an async stream adaptor backed by unfold-like manual loop.
let s = async_stream::stream! {
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
buf.extend_from_slice(&bytes);
while let Some(pos) = buf.iter().position(|b| *b == b'\n') {
let line = buf.split_to(pos + 1);
if let Ok(text) = std::str::from_utf8(&line) {
let text = text.trim();
if text.is_empty() { continue; }
if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
for ev in pull_events_from_value(&value) { yield ev; }
if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
yield PullEvent::Status(format!("error: {err_msg}"));
return;
}
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
if status == "success" { yield PullEvent::Success; return; }
}
}
}
}
}
Err(_) => {
// Connection error: end the stream.
return;
}
}
}
};
Ok(Box::pin(s))
}
/// High-level helper to pull a model and drive a progress reporter.
pub async fn pull_with_reporter(
&self,
model: &str,
reporter: &mut dyn PullProgressReporter,
) -> io::Result<()> {
reporter.on_event(&PullEvent::Status(format!("Pulling model {model}...")))?;
let mut stream = self.pull_model_stream(model).await?;
while let Some(event) = stream.next().await {
reporter.on_event(&event)?;
if matches!(event, PullEvent::Success) {
break;
}
}
Ok(())
}
}
/// Read the list of models recorded under [model_providers.ollama].models.
pub fn read_ollama_models_list(config_path: &Path) -> Vec<String> {
match std::fs::read_to_string(config_path)
.ok()
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
{
Some(toml::Value::Table(root)) => root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("models"))
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| v.as_str().map(|s| s.to_string()))
.collect::<Vec<_>>()
})
.unwrap_or_default(),
_ => Vec::new(),
}
}
/// Convenience wrapper that returns the models list as an io::Result for callers
/// that want a uniform Result-based API.
pub fn read_config_models(config_path: &Path) -> io::Result<Vec<String>> {
Ok(read_ollama_models_list(config_path))
}
/// Overwrite the recorded models list under [model_providers.ollama].models using toml_edit.
pub fn write_ollama_models_list(config_path: &Path, models: &[String]) -> io::Result<()> {
let mut doc = read_document(config_path)?;
{
let tbl = upsert_provider_ollama(&mut doc);
let mut arr = toml_edit::Array::new();
for m in models {
arr.push(TomlValueEdit::from(m.clone()));
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
write_document(config_path, &doc)
}
/// Write models list via a uniform name expected by higher layers.
pub fn write_config_models(config_path: &Path, models: &[String]) -> io::Result<()> {
write_ollama_models_list(config_path, models)
}
/// Ensure `[model_providers.ollama]` exists with sensible defaults on disk.
/// Returns true if it created/updated the entry.
pub fn ensure_ollama_provider_entry(codex_home: &Path) -> io::Result<bool> {
let config_path = codex_home.join("config.toml");
let mut doc = read_document(&config_path)?;
let before = doc.to_string();
let _tbl = upsert_provider_ollama(&mut doc);
let after = doc.to_string();
if before != after {
write_document(&config_path, &doc)?;
Ok(true)
} else {
Ok(false)
}
}
/// Alias name mirroring the refactor plan wording.
pub fn ensure_provider_entry_and_defaults(codex_home: &Path) -> io::Result<bool> {
ensure_ollama_provider_entry(codex_home)
}
/// Read whether the provider exists and how many models are recorded under it.
pub fn read_provider_state(config_path: &Path) -> (bool, usize) {
match std::fs::read_to_string(config_path)
.ok()
.and_then(|s| toml::from_str::<toml::Value>(&s).ok())
{
Some(toml::Value::Table(root)) => {
let provider_present = root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.map(|_| true)
.unwrap_or(false);
let models_count = root
.get("model_providers")
.and_then(|v| v.as_table())
.and_then(|t| t.get("ollama"))
.and_then(|v| v.as_table())
.and_then(|t| t.get("models"))
.and_then(|v| v.as_array())
.map(|arr| arr.len())
.unwrap_or(0);
(provider_present, models_count)
}
_ => (false, 0),
}
}
/// Ensure a model is available locally.
///
/// - If the model is already present, ensure it is recorded in config.toml.
@@ -503,14 +212,14 @@ pub async fn ensure_model_available(
config_path: &Path,
reporter: &mut dyn PullProgressReporter,
) -> CoreResult<()> {
let mut listed = read_ollama_models_list(config_path);
let mut listed = config::read_ollama_models_list(config_path);
let available = client.fetch_models().await.unwrap_or_default();
if available.iter().any(|m| m == model) {
if !listed.iter().any(|m| m == model) {
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = write_ollama_models_list(config_path, &listed);
let _ = config::write_ollama_models_list(config_path, &listed);
}
return Ok(());
}
@@ -533,244 +242,6 @@ pub async fn ensure_model_available(
listed.push(model.to_string());
listed.sort();
listed.dedup();
let _ = write_ollama_models_list(config_path, &listed);
let _ = config::write_ollama_models_list(config_path, &listed);
Ok(())
}
// ---------- toml_edit helpers ----------
fn read_document(path: &Path) -> io::Result<Document> {
match std::fs::read_to_string(path) {
Ok(s) => Document::from_str(&s).map_err(io::Error::other),
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(Document::new()),
Err(e) => Err(e),
}
}
fn write_document(path: &Path, doc: &Document) -> io::Result<()> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(path, doc.to_string())
}
pub fn upsert_provider_ollama(doc: &mut Document) -> &mut Table {
// Ensure "model_providers" exists and is a table.
let needs_init = match doc.get("model_providers") {
None => true,
Some(item) => !item.is_table(),
};
if needs_init {
doc.as_table_mut()
.insert("model_providers", Item::Table(Table::new()));
}
// Now, get a mutable reference to the "model_providers" table without `expect`/`unwrap`.
let providers: &mut Table = {
// Insert if missing.
if doc.as_table().get("model_providers").is_none() {
doc.as_table_mut()
.insert("model_providers", Item::Table(Table::new()));
}
match doc.as_table_mut().get_mut("model_providers") {
Some(item) => {
if !item.is_table() {
*item = Item::Table(Table::new());
}
match item.as_table_mut() {
Some(t) => t,
None => unreachable!("model_providers was set to a table"),
}
}
None => unreachable!("model_providers should exist after insertion"),
}
};
// Ensure "ollama" exists and is a table.
let needs_ollama_init = match providers.get("ollama") {
None => true,
Some(item) => !item.is_table(),
};
if needs_ollama_init {
providers.insert("ollama", Item::Table(Table::new()));
}
// Get a mutable reference to the "ollama" table without `expect`/`unwrap`.
let tbl: &mut Table = {
let needs_set = match providers.get("ollama") {
None => true,
Some(item) => !item.is_table(),
};
if needs_set {
providers.insert("ollama", Item::Table(Table::new()));
}
match providers.get_mut("ollama") {
Some(item) => {
if !item.is_table() {
*item = Item::Table(Table::new());
}
match item.as_table_mut() {
Some(t) => t,
None => unreachable!("ollama was set to a table"),
}
}
None => unreachable!("ollama should exist after insertion"),
}
};
if !tbl.contains_key("name") {
tbl["name"] = Item::Value(TomlValueEdit::from("Ollama"));
}
if !tbl.contains_key("base_url") {
tbl["base_url"] = Item::Value(TomlValueEdit::from(DEFAULT_BASE_URL));
}
if !tbl.contains_key("wire_api") {
tbl["wire_api"] = Item::Value(TomlValueEdit::from("chat"));
}
tbl
}
pub fn set_ollama_models(doc: &mut Document, models: &[String]) {
let tbl = upsert_provider_ollama(doc);
let mut arr = toml_edit::Array::new();
for m in models {
arr.push(TomlValueEdit::from(m.clone()));
}
tbl["models"] = Item::Value(TomlValueEdit::Array(arr));
}
// Convert a single JSON object representing a pull update into one or more events.
fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
let mut events = Vec::new();
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
events.push(PullEvent::Status(status.to_string()));
if status == "success" {
events.push(PullEvent::Success);
}
}
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if total.is_some() || completed.is_some() {
events.push(PullEvent::ChunkProgress {
digest,
total,
completed,
});
}
events
}
#[cfg(test)]
mod tests {
use super::*;
use toml_edit::DocumentMut as Document;
#[test]
fn test_base_url_to_host_root() {
assert_eq!(
base_url_to_host_root("http://localhost:11434/v1"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434"),
"http://localhost:11434"
);
assert_eq!(
base_url_to_host_root("http://localhost:11434/"),
"http://localhost:11434"
);
}
#[test]
fn test_probe_url_for_base() {
assert_eq!(
probe_url_for_base("http://localhost:11434/v1"),
"http://localhost:11434/v1/models"
);
assert_eq!(
probe_url_for_base("http://localhost:11434"),
"http://localhost:11434/api/tags"
);
}
#[test]
fn test_pull_events_decoder_status_and_success() {
let v: JsonValue = serde_json::json!({"status":"verifying"});
let events = pull_events_from_value(&v);
assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
let v2: JsonValue = serde_json::json!({"status":"success"});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 2);
assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
assert!(matches!(events2[1], PullEvent::Success));
}
#[test]
fn test_pull_events_decoder_progress() {
let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
let events = pull_events_from_value(&v);
assert_eq!(events.len(), 1);
match &events[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:abc");
assert_eq!(*total, Some(100));
assert_eq!(*completed, None);
}
_ => panic!("expected ChunkProgress"),
}
let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
let events2 = pull_events_from_value(&v2);
assert_eq!(events2.len(), 1);
match &events2[0] {
PullEvent::ChunkProgress {
digest,
total,
completed,
} => {
assert_eq!(digest, "sha256:def");
assert_eq!(*total, None);
assert_eq!(*completed, Some(42));
}
_ => panic!("expected ChunkProgress"),
}
}
#[test]
fn test_upsert_provider_and_models() {
let mut doc = Document::new();
let tbl = upsert_provider_ollama(&mut doc);
assert!(tbl.contains_key("name"));
assert!(tbl.contains_key("base_url"));
assert!(tbl.contains_key("wire_api"));
set_ollama_models(&mut doc, &[String::from("llama3.2:3b")]);
let root = doc.as_table();
let mp = match root.get("model_providers").and_then(|i| i.as_table()) {
Some(t) => t,
None => panic!("model_providers"),
};
let ollama = match mp.get("ollama").and_then(|i| i.as_table()) {
Some(t) => t,
None => panic!("ollama"),
};
let arr = match ollama.get("models") {
Some(v) => v,
None => panic!("models array"),
};
assert!(arr.is_array(), "models should be an array");
let s = doc.to_string();
assert!(s.contains("model_providers"));
assert!(s.contains("ollama"));
assert!(s.contains("models"));
}
}

View File

@@ -0,0 +1,82 @@
use serde_json::Value as JsonValue;
use super::PullEvent;
// Convert a single JSON object representing a pull update into one or more events.
pub(crate) fn pull_events_from_value(value: &JsonValue) -> Vec<PullEvent> {
let mut events = Vec::new();
if let Some(status) = value.get("status").and_then(|s| s.as_str()) {
events.push(PullEvent::Status(status.to_string()));
if status == "success" {
events.push(PullEvent::Success);
}
}
let digest = value
.get("digest")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string();
let total = value.get("total").and_then(|t| t.as_u64());
let completed = value.get("completed").and_then(|t| t.as_u64());
if total.is_some() || completed.is_some() {
events.push(PullEvent::ChunkProgress {
digest,
total,
completed,
});
}
events
}
#[cfg(test)]
mod tests {
    use super::*;
    // A plain status line yields exactly one Status event; "success" yields
    // Status followed by the terminal Success marker.
    #[test]
    fn test_pull_events_decoder_status_and_success() {
        let v: JsonValue = serde_json::json!({"status":"verifying"});
        let events = pull_events_from_value(&v);
        assert!(matches!(events.as_slice(), [PullEvent::Status(s)] if s == "verifying"));
        let v2: JsonValue = serde_json::json!({"status":"success"});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 2);
        assert!(matches!(events2[0], PullEvent::Status(ref s) if s == "success"));
        assert!(matches!(events2[1], PullEvent::Success));
    }
    // Either "total" or "completed" alone is enough to produce a single
    // ChunkProgress event; the absent counter stays None.
    #[test]
    fn test_pull_events_decoder_progress() {
        let v: JsonValue = serde_json::json!({"digest":"sha256:abc","total":100});
        let events = pull_events_from_value(&v);
        assert_eq!(events.len(), 1);
        match &events[0] {
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:abc");
                assert_eq!(*total, Some(100));
                assert_eq!(*completed, None);
            }
            _ => panic!("expected ChunkProgress"),
        }
        let v2: JsonValue = serde_json::json!({"digest":"sha256:def","completed":42});
        let events2 = pull_events_from_value(&v2);
        assert_eq!(events2.len(), 1);
        match &events2[0] {
            PullEvent::ChunkProgress {
                digest,
                total,
                completed,
            } => {
                assert_eq!(digest, "sha256:def");
                assert_eq!(*total, None);
                assert_eq!(*completed, Some(42));
            }
            _ => panic!("expected ChunkProgress"),
        }
    }
}

View File

@@ -0,0 +1,80 @@
use crate::error::Result as CoreResult;
/// Identify whether a base_url points at an OpenAI-compatible root (".../v1").
/// Trailing slashes are ignored for the comparison.
pub(crate) fn is_openai_compatible_base_url(base_url: &str) -> bool {
    let normalized = base_url.trim_end_matches('/');
    normalized.ends_with("/v1")
}
/// Convert a provider base_url into the native Ollama host root.
/// For example, "http://localhost:11434/v1" -> "http://localhost:11434".
///
/// Uses `strip_suffix` so exactly one trailing "/v1" segment is removed; the
/// previous `trim_end_matches("/v1")` stripped *repeated* "/v1" suffixes
/// (e.g. ".../v1/v1" collapsed to the bare host).
pub fn base_url_to_host_root(base_url: &str) -> String {
    let trimmed = base_url.trim_end_matches('/');
    match trimmed.strip_suffix("/v1") {
        // Drop any slashes left immediately before the removed segment.
        Some(host) => host.trim_end_matches('/').to_string(),
        None => trimmed.to_string(),
    }
}
/// Variant that considers an explicit WireApi value; provided to centralize
/// host root computation in one place for future extension.
///
/// The wire_api argument is currently ignored; behavior is identical to
/// [`base_url_to_host_root`].
pub fn base_url_to_host_root_with_wire(
    base_url: &str,
    _wire_api: crate::model_provider_info::WireApi,
) -> String {
    base_url_to_host_root(base_url)
}
/// Compute the probe URL to verify if an Ollama server is reachable.
///
/// OpenAI-compatible roots (ending in "/v1") are probed via their "models"
/// listing; plain Ollama hosts via the native "/api/tags" endpoint.
pub fn probe_url_for_base(base_url: &str) -> String {
    let root = base_url.trim_end_matches('/');
    // Inlined is_openai_compatible_base_url: a root is OpenAI-compatible
    // exactly when it ends with "/v1".
    if root.ends_with("/v1") {
        format!("{root}/models")
    } else {
        format!("{root}/api/tags")
    }
}
/// Convenience helper to probe an Ollama server given a provider style base URL.
///
/// Builds the health-check URL via [`probe_url_for_base`] and reports whether
/// the GET came back with a success (2xx) status.
///
/// # Errors
/// Returns an error only when the HTTP request itself fails (e.g. connection
/// refused); a non-2xx response is reported as `Ok(false)`.
pub async fn probe_ollama_server(base_url: &str) -> CoreResult<bool> {
    let url = probe_url_for_base(base_url);
    let resp = reqwest::Client::new().get(url).send().await?;
    Ok(resp.status().is_success())
}
#[cfg(test)]
mod tests {
    use super::*;
    // Host-root extraction strips a "/v1" suffix and trailing slashes.
    #[test]
    fn test_base_url_to_host_root() {
        assert_eq!(
            base_url_to_host_root("http://localhost:11434/v1"),
            "http://localhost:11434"
        );
        assert_eq!(
            base_url_to_host_root("http://localhost:11434"),
            "http://localhost:11434"
        );
        assert_eq!(
            base_url_to_host_root("http://localhost:11434/"),
            "http://localhost:11434"
        );
    }
    // OpenAI-compatible bases probe "models"; native bases probe "/api/tags".
    #[test]
    fn test_probe_url_for_base() {
        assert_eq!(
            probe_url_for_base("http://localhost:11434/v1"),
            "http://localhost:11434/v1/models"
        );
        assert_eq!(
            probe_url_for_base("http://localhost:11434"),
            "http://localhost:11434/api/tags"
        );
    }
}