Skip to main content

vtcode_core/llm/providers/ollama/
client.rs

1/// High-level Ollama client for server interaction and model management.
2/// Adapted from OpenAI Codex's codex-ollama/src/client.rs
3use std::io;
4use std::time::Duration;
5
6use futures::StreamExt;
7use futures::stream::BoxStream;
8use semver::Version;
9use serde_json::Value as JsonValue;
10
11use super::OLLAMA_CONNECTION_ERROR;
12use super::ollama_model_name_from_fields;
13use super::pull::{OllamaPullEvent, OllamaPullProgressReporter};
14use super::url::base_url_to_host_root;
15
/// Client for interacting with a local or remote Ollama instance.
///
/// Instances are created through [`OllamaClient::try_from_base_url`], which
/// probes the server before handing the client back to the caller.
pub struct OllamaClient {
    /// HTTP client used for every request issued by this type.
    client: reqwest::Client,
    /// Server root derived from the caller's base URL via
    /// `base_url_to_host_root`; endpoint paths such as `/api/tags` are
    /// appended to it after trimming any trailing `/`.
    host_root: String,
}
21
22impl OllamaClient {
23    /// Create a client from a base URL and verify the server is reachable.
24    pub async fn try_from_base_url(base_url: &str) -> io::Result<Self> {
25        let host_root = base_url_to_host_root(base_url);
26        let client = reqwest::Client::builder()
27            .connect_timeout(Duration::from_secs(5))
28            .build()
29            .unwrap_or_else(|_| reqwest::Client::new());
30
31        let instance = Self { client, host_root };
32        instance.probe_server().await?;
33        Ok(instance)
34    }
35
36    /// Probe whether the server is reachable.
37    async fn probe_server(&self) -> io::Result<()> {
38        let url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
39        let resp = self.client.get(url).send().await.map_err(|err| {
40            tracing::warn!("Failed to connect to Ollama server: {err:?}");
41            io::Error::other(OLLAMA_CONNECTION_ERROR)
42        })?;
43
44        if resp.status().is_success() {
45            Ok(())
46        } else {
47            tracing::warn!(
48                "Failed to probe server at {}: HTTP {}",
49                self.host_root,
50                resp.status()
51            );
52            Err(io::Error::other(OLLAMA_CONNECTION_ERROR))
53        }
54    }
55
56    /// Fetch the list of model names available on the server.
57    pub async fn fetch_models(&self) -> io::Result<Vec<String>> {
58        let tags_url = format!("{}/api/tags", self.host_root.trim_end_matches('/'));
59        let resp = self
60            .client
61            .get(tags_url)
62            .send()
63            .await
64            .map_err(io::Error::other)?;
65
66        if !resp.status().is_success() {
67            return Ok(Vec::new());
68        }
69
70        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
71        let names = val
72            .get("models")
73            .and_then(|m| m.as_array())
74            .map(|arr| {
75                arr.iter()
76                    .filter_map(|v| {
77                        ollama_model_name_from_fields(
78                            v.get("name").and_then(|n| n.as_str()),
79                            v.get("model").and_then(|n| n.as_str()),
80                        )
81                    })
82                    .map(str::to_string)
83                    .collect::<Vec<_>>()
84            })
85            .unwrap_or_default();
86
87        Ok(names)
88    }
89
90    /// Query the server for its version string, returning `None` when unavailable.
91    /// Adapted from OpenAI Codex's codex-ollama/src/client.rs
92    pub async fn fetch_version(&self) -> io::Result<Option<Version>> {
93        let version_url = format!("{}/api/version", self.host_root.trim_end_matches('/'));
94        let resp = self
95            .client
96            .get(version_url)
97            .send()
98            .await
99            .map_err(io::Error::other)?;
100
101        if !resp.status().is_success() {
102            return Ok(None);
103        }
104
105        let val = resp.json::<JsonValue>().await.map_err(io::Error::other)?;
106        let Some(version_str) = val.get("version").and_then(|v| v.as_str()).map(str::trim) else {
107            return Ok(None);
108        };
109
110        let normalized = version_str.trim_start_matches('v');
111        match Version::parse(normalized) {
112            Ok(version) => Ok(Some(version)),
113            Err(err) => {
114                tracing::warn!("Failed to parse Ollama version `{version_str}`: {err}");
115                Ok(None)
116            }
117        }
118    }
119
120    /// Start a model pull and return a stream of events.
121    pub async fn pull_model_stream(
122        &self,
123        model: &str,
124    ) -> io::Result<BoxStream<'static, OllamaPullEvent>> {
125        let url = format!("{}/api/pull", self.host_root.trim_end_matches('/'));
126        let resp = self
127            .client
128            .post(url)
129            .json(&serde_json::json!({"model": model, "stream": true}))
130            .send()
131            .await
132            .map_err(io::Error::other)?;
133
134        if !resp.status().is_success() {
135            return Err(io::Error::other(format!(
136                "failed to start pull: HTTP {}",
137                resp.status()
138            )));
139        }
140
141        let mut stream = resp.bytes_stream();
142        let mut buf = String::new();
143
144        let s = async_stream::stream! {
145            while let Some(chunk) = stream.next().await {
146                match chunk {
147                    Ok(bytes) => {
148                        if let Ok(text) = std::str::from_utf8(&bytes) {
149                            buf.push_str(text);
150                            while let Some(pos) = buf.find('\n') {
151                                let line = buf.drain(..=pos).collect::<String>();
152                                let text = line.trim();
153                                if text.is_empty() { continue; }
154                                if let Ok(value) = serde_json::from_str::<JsonValue>(text) {
155                                    for ev in super::parser::pull_events_from_value(&value) {
156                                        yield ev;
157                                    }
158                                    if let Some(err_msg) = value.get("error").and_then(|e| e.as_str()) {
159                                        yield OllamaPullEvent::Error(err_msg.to_string());
160                                        return;
161                                    }
162                                }
163                            }
164                        }
165                    }
166                    Err(_) => {
167                        // Connection error: end the stream.
168                        return;
169                    }
170                }
171            }
172        };
173
174        Ok(Box::pin(s))
175    }
176
177    /// High-level helper to pull a model and drive a progress reporter.
178    /// Adapted from OpenAI Codex's codex-ollama/src/client.rs
179    pub async fn pull_with_reporter(
180        &self,
181        model: &str,
182        reporter: &mut dyn OllamaPullProgressReporter,
183    ) -> io::Result<()> {
184        reporter.on_event(&OllamaPullEvent::Status(format!(
185            "Pulling model {model}..."
186        )))?;
187        let mut stream = self.pull_model_stream(model).await?;
188
189        while let Some(event) = stream.next().await {
190            reporter.on_event(&event)?;
191            match event {
192                OllamaPullEvent::Success => {
193                    return Ok(());
194                }
195                OllamaPullEvent::Error(err) => {
196                    // Empirically, ollama returns a 200 OK response even when
197                    // the output stream includes an error message. Verify with:
198                    //
199                    // `curl -i http://localhost:11434/api/pull -d '{ "model": "foobarbaz" }'`
200                    //
201                    // When we see an error in the stream, we return it to the
202                    // caller as an I/O error.
203                    return Err(io::Error::other(err));
204                }
205                _ => {}
206            }
207        }
208
209        // Stream ended without explicit success or error.
210        Err(io::Error::other("Pull stream ended unexpectedly"))
211    }
212}
213
#[cfg(test)]
mod tests {
    use semver::Version;

    /// Placeholder: exercising client creation needs a live Ollama server.
    #[test]
    fn test_client_creation_requires_valid_base_url() {
        // Intentionally empty. URL handling is covered by the tests in url.rs.
    }

    /// `semver` accepts the plain `major.minor.patch` form of a version string.
    #[test]
    fn test_version_parsing() {
        let parsed = Version::parse("0.14.1").expect("parse version");
        assert_eq!((parsed.major, parsed.minor, parsed.patch), (0, 14, 1));
    }

    /// A leading `v` is stripped before the version is handed to `semver`.
    #[test]
    fn test_version_parsing_strips_v_prefix() {
        let normalized = "v0.13.4".trim_start_matches('v');
        let parsed = Version::parse(normalized).expect("parse version");
        assert_eq!(parsed, Version::new(0, 13, 4));
    }
}
241}