Skip to main content

trusty_common/
chat.rs

//! Provider-agnostic streaming chat abstraction with tool-use support.
//!
//! Why: trusty-memory and trusty-search both want to support more than one
//! upstream LLM (OpenRouter for cloud, Ollama / LM Studio for local). Rather
//! than each crate re-implementing the dispatch, we expose a small
//! [`ChatProvider`] trait plus two concrete implementations and an
//! auto-detector for a running local model server. The trait also surfaces
//! OpenAI-style tool/function calling so downstream agents can let the model
//! invoke tools (search, memory recall, shell, etc.).
//!
//! What: defines the [`ChatProvider`] trait, [`ToolDef`] / [`ToolCall`] /
//! [`ChatEvent`] tool-use types, an [`OpenRouterProvider`] and an
//! [`OllamaProvider`] that both speak OpenAI-compatible
//! `/v1/chat/completions` with SSE streaming (including the streamed
//! `tool_calls` shape), and [`auto_detect_local_provider`] which probes
//! `{base_url}/v1/models` with a 1-second timeout.
//!
//! Test: `cargo test -p trusty-common` covers default config values, the
//! unreachable-server path of `auto_detect_local_provider`, SSE delta
//! streaming, and accumulation of streamed tool-call fragments.
21
22use crate::ChatMessage;
23use anyhow::{Context, Result, anyhow};
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use tokio::sync::mpsc::Sender;
27
// Local (Ollama / LM Studio / llama.cpp) server budgets, in seconds.

/// Connect budget for local chat requests; also the connect and total
/// budget of the `/v1/models` liveness probe.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Total request budget for one local streaming chat completion.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;

// OpenRouter endpoint, budgets (seconds) and branding headers.

/// OpenAI-compatible chat-completions endpoint that OpenRouter requests
/// are POSTed to.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect budget for OpenRouter requests.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total request budget for one OpenRouter streaming chat completion.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value sent in the `HTTP-Referer` header of OpenRouter requests.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value sent in the `X-Title` header of OpenRouter requests.
const X_TITLE: &str = "trusty-common";
35
36/// Configuration for a local OpenAI-compatible model server (Ollama, LM
37/// Studio, llama.cpp's server, etc.).
38///
39/// Why: callers want a single struct they can deserialize from config files
40/// and pass to [`auto_detect_local_provider`] without juggling defaults.
41/// What: holds an enable flag, the server's base URL (no trailing slash),
42/// and the default model to request. Defaults target Ollama's standard
43/// localhost binding.
44/// Test: `local_model_config_defaults` asserts the default values.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LocalModelConfig {
47    pub enabled: bool,
48    pub base_url: String,
49    pub model: String,
50}
51
52impl Default for LocalModelConfig {
53    fn default() -> Self {
54        Self {
55            enabled: true,
56            base_url: "http://localhost:11434".to_string(),
57            model: "llama3.2".to_string(),
58        }
59    }
60}
61
62// ─── Tool-use types ───────────────────────────────────────────────────────
63
64/// JSON-Schema description of a callable tool, in OpenAI function-calling
65/// shape.
66///
67/// Why: downstream agents (trusty-memory, trusty-search) expose tools like
68/// `memory_recall` or `web_search` to the LLM. The OpenAI tool format is the
69/// de-facto common denominator across OpenRouter, Ollama, LM Studio, and
70/// most cloud providers.
71/// What: `name` and `description` are passed verbatim; `parameters` is a
72/// JSON Schema object (typically `{"type":"object","properties":{...}}`).
73/// Test: `tool_def_serializes_as_function` checks the wire shape.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolDef {
76    pub name: String,
77    pub description: String,
78    pub parameters: serde_json::Value,
79}
80
81/// A tool invocation the model wants the host to perform.
82///
83/// Why: the streaming chat API emits `tool_calls` in fragments — first an
84/// `id` + `function.name`, then a string of `function.arguments` deltas.
85/// We accumulate fragments and surface one fully-formed [`ToolCall`] per
86/// invocation to the caller.
87/// What: `id` is the upstream's call id (echoed back in subsequent
88/// `role:"tool"` messages); `name` is the function name; `arguments` is a
89/// JSON string (NOT a parsed value — many models emit malformed JSON and
90/// callers want the raw text for error reporting / repair).
91/// Test: `accumulates_streamed_tool_call_fragments`.
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
93pub struct ToolCall {
94    pub id: String,
95    pub name: String,
96    pub arguments: String,
97}
98
99/// Streaming chat event.
100///
101/// Why: replaces the previous "string-only" channel so callers can
102/// distinguish text deltas from tool invocations and from terminal
103/// success/error without parsing magic markers out of the text stream.
104/// What: `Delta` is a content chunk; `ToolCall` is a fully-accumulated tool
105/// invocation; `Done` signals the upstream stream terminated normally;
106/// `Error` carries a human-readable message for stream-mid failures (the
107/// provider also returns `Err` from `chat_stream`, but `Error` lets the
108/// caller display partial-stream failures inline).
109/// Test: `ollama_provider_streams_sse_deltas`.
110#[derive(Debug, Clone)]
111pub enum ChatEvent {
112    Delta(String),
113    ToolCall(ToolCall),
114    Done,
115    Error(String),
116}
117
118/// Streaming chat provider abstraction.
119///
120/// Why: downstream crates (trusty-memory, trusty-search) want to support
121/// multiple LLM backends without hard-coding which one to call. Providers
122/// expose a uniform streaming interface so the caller can swap them at
123/// runtime based on configuration / availability.
124/// What: implementors stream [`ChatEvent`]s into `tx`. Pass an empty
125/// `tools` vec to disable tool use entirely (the provider MUST then omit
126/// the `tools` field from the upstream request — some models error on an
127/// empty array). Returning `Ok(())` means the stream completed normally;
128/// the caller should also expect a final [`ChatEvent::Done`].
129/// Test: implementations are covered by their own unit tests in this
130/// module plus integration tests in downstream crates.
131#[async_trait]
132pub trait ChatProvider: Send + Sync {
133    /// Human-readable provider name (e.g. `"openrouter"`, `"ollama"`).
134    fn name(&self) -> &str;
135    /// Model identifier sent on every request.
136    fn model(&self) -> &str;
137    /// Stream chat events into `tx`. `tools` empty disables tool use.
138    async fn chat_stream(
139        &self,
140        messages: Vec<ChatMessage>,
141        tools: Vec<ToolDef>,
142        tx: Sender<ChatEvent>,
143    ) -> Result<()>;
144}
145
146// ─── Shared SSE / request types ────────────────────────────────────────────
147
148#[derive(Debug, Serialize)]
149struct OpenAiToolWire<'a> {
150    #[serde(rename = "type")]
151    kind: &'static str,
152    function: OpenAiFunctionWire<'a>,
153}
154
155#[derive(Debug, Serialize)]
156struct OpenAiFunctionWire<'a> {
157    name: &'a str,
158    description: &'a str,
159    parameters: &'a serde_json::Value,
160}
161
162#[derive(Debug, Serialize)]
163struct ChatRequestWire<'a> {
164    model: &'a str,
165    messages: &'a [ChatMessage],
166    stream: bool,
167    #[serde(skip_serializing_if = "Option::is_none")]
168    tools: Option<Vec<OpenAiToolWire<'a>>>,
169}
170
171fn tools_wire(tools: &[ToolDef]) -> Option<Vec<OpenAiToolWire<'_>>> {
172    if tools.is_empty() {
173        None
174    } else {
175        Some(
176            tools
177                .iter()
178                .map(|t| OpenAiToolWire {
179                    kind: "function",
180                    function: OpenAiFunctionWire {
181                        name: &t.name,
182                        description: &t.description,
183                        parameters: &t.parameters,
184                    },
185                })
186                .collect(),
187        )
188    }
189}
190
191/// Accumulator for streamed tool-call fragments.
192///
193/// Why: OpenAI-style streaming sends each tool call across multiple SSE
194/// frames: the first frame at a given `index` carries `id` and
195/// `function.name`; subsequent frames append to `function.arguments`. We
196/// accumulate by `index` and emit fully-formed [`ToolCall`]s only after the
197/// stream terminates (or we see `finish_reason: tool_calls`).
198/// What: a vector slot per index, growing as needed; merge logic is in
199/// `apply_delta`. `finalize` drops slots that never received an id (defensive
200/// — shouldn't happen but avoids emitting half-baked calls).
201/// Test: `accumulates_streamed_tool_call_fragments`.
202#[derive(Debug, Default)]
203struct ToolCallAccumulator {
204    // index -> (id, name, args)
205    slots: Vec<Option<(String, String, String)>>,
206}
207
208impl ToolCallAccumulator {
209    fn apply_delta(&mut self, tool_calls: &serde_json::Value) {
210        let Some(arr) = tool_calls.as_array() else {
211            return;
212        };
213        for tc in arr {
214            let idx = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
215            while self.slots.len() <= idx {
216                self.slots.push(None);
217            }
218            let slot = self.slots[idx].get_or_insert_with(|| {
219                (String::new(), String::new(), String::new())
220            });
221            if let Some(id) = tc.get("id").and_then(|v| v.as_str())
222                && !id.is_empty()
223            {
224                slot.0 = id.to_string();
225            }
226            if let Some(func) = tc.get("function") {
227                if let Some(name) = func.get("name").and_then(|v| v.as_str())
228                    && !name.is_empty()
229                {
230                    slot.1 = name.to_string();
231                }
232                if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
233                    slot.2.push_str(args);
234                }
235            }
236        }
237    }
238
239    fn finalize(self) -> Vec<ToolCall> {
240        self.slots
241            .into_iter()
242            .filter_map(|opt| {
243                opt.and_then(|(id, name, arguments)| {
244                    if name.is_empty() {
245                        None
246                    } else {
247                        Some(ToolCall {
248                            id,
249                            name,
250                            arguments,
251                        })
252                    }
253                })
254            })
255            .collect()
256    }
257}
258
259/// Drive one OpenAI-compatible SSE stream into the caller's [`ChatEvent`]
260/// channel.
261///
262/// Why: OpenRouter and Ollama both speak the same wire format; sharing the
263/// loop keeps the two providers in lock-step.
264/// What: reads `resp.bytes_stream()`, splits on newlines, parses `data:`
265/// frames, forwards `delta.content` as [`ChatEvent::Delta`], accumulates
266/// `delta.tool_calls`, and on `[DONE]` (or upstream EOF) emits one
267/// [`ChatEvent::ToolCall`] per accumulated call followed by
268/// [`ChatEvent::Done`].
269/// Test: covered by `ollama_provider_streams_sse_deltas` and
270/// `accumulates_streamed_tool_call_fragments`.
271async fn pump_openai_sse(resp: reqwest::Response, tx: Sender<ChatEvent>) -> Result<()> {
272    use futures_util::StreamExt;
273
274    let mut acc = ToolCallAccumulator::default();
275    let mut buf = String::new();
276    let mut stream = resp.bytes_stream();
277
278    while let Some(chunk) = stream.next().await {
279        let bytes = chunk.context("read chat stream chunk")?;
280        let text = match std::str::from_utf8(&bytes) {
281            Ok(s) => s,
282            Err(_) => continue,
283        };
284        buf.push_str(text);
285
286        while let Some(idx) = buf.find('\n') {
287            let line: String = buf.drain(..=idx).collect();
288            let line = line.trim();
289            let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
290                continue;
291            };
292            if payload.is_empty() {
293                continue;
294            }
295            if payload == "[DONE]" {
296                // Flush accumulated tool calls and finish.
297                for call in std::mem::take(&mut acc).finalize() {
298                    if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
299                        return Ok(());
300                    }
301                }
302                let _ = tx.send(ChatEvent::Done).await;
303                return Ok(());
304            }
305            let v: serde_json::Value = match serde_json::from_str(payload) {
306                Ok(v) => v,
307                Err(_) => continue,
308            };
309            let delta = v
310                .get("choices")
311                .and_then(|c| c.get(0))
312                .and_then(|c| c.get("delta"));
313            if let Some(delta) = delta {
314                if let Some(content) = delta.get("content").and_then(|c| c.as_str())
315                    && !content.is_empty()
316                    && tx
317                        .send(ChatEvent::Delta(content.to_string()))
318                        .await
319                        .is_err()
320                {
321                    return Ok(());
322                }
323                if let Some(tc) = delta.get("tool_calls") {
324                    acc.apply_delta(tc);
325                }
326            }
327        }
328    }
329
330    // Upstream EOF without a [DONE] sentinel — still flush and finish.
331    for call in acc.finalize() {
332        if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
333            return Ok(());
334        }
335    }
336    let _ = tx.send(ChatEvent::Done).await;
337    Ok(())
338}
339
340// ─── OpenRouter ───────────────────────────────────────────────────────────
341
/// OpenRouter-backed cloud chat provider.
///
/// Why: gives callers the same [`ChatProvider`] surface for the cloud path
/// as for a local model.
/// What: holds the credentials and model id used to POST OpenAI-compatible
/// streaming chat completions with bearer auth and trusty-common branding
/// headers.
/// Test: metadata is covered by `openrouter_provider_reports_metadata`;
/// streaming and tool calls by downstream integration tests plus the
/// SSE-pump unit tests in this module.
pub struct OpenRouterProvider {
    /// API key, sent as the bearer token on every request.
    pub api_key: String,
    /// Model id requested on every call.
    pub model: String,
}
355
356impl OpenRouterProvider {
357    /// Construct a provider from an API key and model id.
358    ///
359    /// Why: keeps callers from poking the public fields directly so the
360    /// struct can grow optional knobs without breaking call sites.
361    /// What: stores both fields verbatim.
362    /// Test: trivially exercised by `openrouter_provider_reports_metadata`.
363    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
364        Self {
365            api_key: api_key.into(),
366            model: model.into(),
367        }
368    }
369}
370
371#[async_trait]
372impl ChatProvider for OpenRouterProvider {
373    fn name(&self) -> &str {
374        "openrouter"
375    }
376
377    fn model(&self) -> &str {
378        &self.model
379    }
380
381    async fn chat_stream(
382        &self,
383        messages: Vec<ChatMessage>,
384        tools: Vec<ToolDef>,
385        tx: Sender<ChatEvent>,
386    ) -> Result<()> {
387        if self.api_key.is_empty() {
388            return Err(anyhow!("openrouter api key is empty"));
389        }
390        let client = reqwest::Client::builder()
391            .connect_timeout(std::time::Duration::from_secs(
392                OPENROUTER_CONNECT_TIMEOUT_SECS,
393            ))
394            .timeout(std::time::Duration::from_secs(
395                OPENROUTER_REQUEST_TIMEOUT_SECS,
396            ))
397            .build()
398            .context("build reqwest client for OpenRouterProvider::chat_stream")?;
399
400        let tools_wire = tools_wire(&tools);
401        let body = ChatRequestWire {
402            model: &self.model,
403            messages: &messages,
404            stream: true,
405            tools: tools_wire,
406        };
407        let resp = client
408            .post(OPENROUTER_URL)
409            .bearer_auth(&self.api_key)
410            .header("HTTP-Referer", HTTP_REFERER)
411            .header("X-Title", X_TITLE)
412            .json(&body)
413            .send()
414            .await
415            .context("POST openrouter chat completions (stream)")?;
416
417        let status = resp.status();
418        if !status.is_success() {
419            let text = resp.text().await.unwrap_or_default();
420            return Err(anyhow!("openrouter HTTP {status}: {text}"));
421        }
422
423        pump_openai_sse(resp, tx).await
424    }
425}
426
427// ─── Ollama / OpenAI-compatible local ─────────────────────────────────────
428
/// Chat provider for a local OpenAI-compatible server (Ollama, LM Studio,
/// llama.cpp's `server`, vLLM, etc.).
///
/// Why: running a model server locally during development avoids API
/// costs and latency, and the OpenAI-compatible `/v1/chat/completions`
/// endpoint with SSE streaming is the de-facto common denominator among
/// such servers.
/// What: holds the server location and the model id. `chat_stream` POSTs
/// `{model, messages, tools?, stream: true}` and parses SSE `data:` frames
/// exactly as the OpenRouter path does.
/// Test: metadata is covered by `ollama_provider_reports_metadata`;
/// streaming and tool-call accumulation by
/// `ollama_provider_streams_sse_deltas` and
/// `accumulates_streamed_tool_call_fragments`.
pub struct OllamaProvider {
    /// Base URL of the server; `/v1/chat/completions` is appended to it.
    pub base_url: String,
    /// Model id requested on every call.
    pub model: String,
}
445
446impl OllamaProvider {
447    /// Construct a provider from a base URL and model id.
448    ///
449    /// Why: parallel to [`OpenRouterProvider::new`] so callers see a
450    /// consistent shape across providers.
451    /// What: stores both fields verbatim; the base URL should NOT have a
452    /// trailing slash — the implementation appends `/v1/chat/completions`.
453    /// Test: covered by `ollama_provider_reports_metadata`.
454    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
455        Self {
456            base_url: base_url.into(),
457            model: model.into(),
458        }
459    }
460}
461
462#[async_trait]
463impl ChatProvider for OllamaProvider {
464    fn name(&self) -> &str {
465        "ollama"
466    }
467
468    fn model(&self) -> &str {
469        &self.model
470    }
471
472    async fn chat_stream(
473        &self,
474        messages: Vec<ChatMessage>,
475        tools: Vec<ToolDef>,
476        tx: Sender<ChatEvent>,
477    ) -> Result<()> {
478        let client = reqwest::Client::builder()
479            .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
480            .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
481            .build()
482            .context("build reqwest client for OllamaProvider::chat_stream")?;
483
484        let url = format!(
485            "{}/v1/chat/completions",
486            self.base_url.trim_end_matches('/')
487        );
488        let tools_wire = tools_wire(&tools);
489        let body = ChatRequestWire {
490            model: &self.model,
491            messages: &messages,
492            stream: true,
493            tools: tools_wire,
494        };
495        let resp = client
496            .post(&url)
497            .json(&body)
498            .send()
499            .await
500            .with_context(|| format!("POST {url}"))?;
501
502        let status = resp.status();
503        if !status.is_success() {
504            let text = resp.text().await.unwrap_or_default();
505            return Err(anyhow!("local chat HTTP {status}: {text}"));
506        }
507
508        pump_openai_sse(resp, tx).await
509    }
510}
511
512/// Probe a local model server and return an [`OllamaProvider`] if reachable.
513///
514/// Why: at startup, downstream daemons want to know whether a local model
515/// server is running before falling back to a cloud provider. The OpenAI
516/// `/v1/models` endpoint is a cheap, side-effect-free liveness check that
517/// Ollama, LM Studio, and llama.cpp's server all implement.
518/// What: GETs `{base_url}/v1/models` with a 1-second total timeout. Returns
519/// `Some(OllamaProvider { base_url, model: "" })` on any 2xx response.
520/// Returns `None` on network errors, timeouts, or non-2xx status. Never
521/// returns an error — the caller treats absence as "no local provider
522/// available" and is responsible for setting the model id afterwards (e.g.
523/// from [`LocalModelConfig::model`]).
524/// Test: `auto_detect_returns_none_on_unreachable` points at a closed port
525/// and asserts `None` within the 1-second budget;
526/// `auto_detect_returns_some_on_200` spins up an in-process server and
527/// asserts a provider is returned.
528pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
529    let client = reqwest::Client::builder()
530        .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
531        .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
532        .build()
533        .ok()?;
534
535    let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
536    match client.get(&url).send().await {
537        Ok(resp) if resp.status().is_success() => {
538            Some(OllamaProvider::new(base_url.to_string(), String::new()))
539        }
540        _ => None,
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    // ─── Shared fixtures ──────────────────────────────────────────────────

    /// Bind an ephemeral localhost port and spawn a task that accepts ONE
    /// connection, reads (and discards) whatever request bytes arrive in a
    /// single read, writes `response` verbatim, and shuts the socket down.
    /// Returns the base URL (`http://127.0.0.1:<port>`) of that server.
    async fn serve_once(response: String) -> String {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((mut sock, _)) = listener.accept().await {
                let mut request = [0u8; 4096];
                let _ = sock.read(&mut request).await;
                let _ = sock.write_all(response.as_bytes()).await;
                let _ = sock.shutdown().await;
            }
        });

        format!("http://{addr}")
    }

    /// Frame `body` as a complete `200 OK` HTTP/1.1 response with the given
    /// content type, an exact `Content-Length`, and `Connection: close`.
    fn http_ok(content_type: &str, body: &str) -> String {
        format!(
            "HTTP/1.1 200 OK\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
            body.len()
        )
    }

    /// A plain `role: "user"` message with no tool metadata.
    fn user_message(content: &str) -> ChatMessage {
        ChatMessage {
            role: "user".into(),
            content: content.into(),
            tool_call_id: None,
            tool_calls: None,
        }
    }

    /// Run `OllamaProvider::chat_stream` against `base` on a spawned task
    /// and drain the event channel until the sender is dropped. Returns the
    /// provider's result together with every event received, in order.
    async fn run_chat(
        base: String,
        prompt: &str,
        tools: Vec<ToolDef>,
    ) -> (anyhow::Result<()>, Vec<ChatEvent>) {
        let provider = OllamaProvider::new(base, "test-model");
        let messages = vec![user_message(prompt)];
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let handle =
            tokio::spawn(async move { provider.chat_stream(messages, tools, tx).await });

        let mut events = Vec::new();
        while let Some(ev) = rx.recv().await {
            events.push(ev);
        }
        let result = handle.await.expect("task panicked");
        (result, events)
    }

    // ─── Config & metadata ────────────────────────────────────────────────

    #[test]
    fn local_model_config_defaults() {
        let cfg = LocalModelConfig::default();
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:11434");
        assert_eq!(cfg.model, "llama3.2");
    }

    #[test]
    fn local_model_config_deserializes_from_toml() {
        let toml_src = r#"
            enabled = true
            base_url = "http://localhost:1234"
            model = "qwen2.5-coder"
        "#;
        let cfg: LocalModelConfig = toml::from_str(toml_src).expect("parse TOML");
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:1234");
        assert_eq!(cfg.model, "qwen2.5-coder");
    }

    #[test]
    fn openrouter_provider_reports_metadata() {
        let p = OpenRouterProvider::new("sk-xxx", "anthropic/claude-3.5-sonnet");
        assert_eq!(p.name(), "openrouter");
        assert_eq!(p.model(), "anthropic/claude-3.5-sonnet");
    }

    #[test]
    fn ollama_provider_reports_metadata() {
        let p = OllamaProvider::new("http://localhost:11434", "llama3.2");
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.model(), "llama3.2");
    }

    // ─── Tool wire shape & accumulation ───────────────────────────────────

    #[test]
    fn tool_def_serializes_as_function() {
        // When passed through `tools_wire`, a ToolDef should produce a JSON
        // object that matches the OpenAI function-calling shape.
        let tools = vec![ToolDef {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"],
            }),
        }];
        let wire = tools_wire(&tools).expect("expected Some");
        let v = serde_json::to_value(&wire).unwrap();
        assert_eq!(v[0]["type"], "function");
        assert_eq!(v[0]["function"]["name"], "search");
        assert_eq!(v[0]["function"]["parameters"]["type"], "object");
    }

    #[test]
    fn empty_tools_serializes_to_none() {
        // Empty tools must omit the field entirely so models that error on
        // empty arrays still work.
        assert!(tools_wire(&[]).is_none());
    }

    #[test]
    fn accumulates_streamed_tool_call_fragments() {
        // Simulate three SSE deltas for a single tool call: id+name, then
        // two args fragments. After finalize, we should see one fully-formed
        // ToolCall with concatenated arguments.
        let mut acc = ToolCallAccumulator::default();
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "id": "call_abc",
            "function": { "name": "search", "arguments": "" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "{\"query\":\"" }
        }]));
        acc.apply_delta(&serde_json::json!([{
            "index": 0,
            "function": { "arguments": "rust\"}" }
        }]));
        let calls = acc.finalize();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_abc");
        assert_eq!(calls[0].name, "search");
        assert_eq!(calls[0].arguments, "{\"query\":\"rust\"}");
    }

    // ─── Local-server auto-detection ──────────────────────────────────────

    #[tokio::test]
    async fn auto_detect_returns_none_on_unreachable() {
        // Grab a free port, then close it again so nothing is listening.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        drop(listener);

        let base = format!("http://127.0.0.1:{port}");
        let start = std::time::Instant::now();
        let got = auto_detect_local_provider(&base).await;
        let elapsed = start.elapsed();
        assert!(got.is_none(), "expected None for unreachable server");
        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "auto-detect took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn auto_detect_returns_some_on_200() {
        let base = serve_once(http_ok("application/json", "{\"data\":[]}")).await;

        let got = auto_detect_local_provider(&base).await;
        assert!(got.is_some(), "expected Some for reachable 200 server");
        let p = got.unwrap();
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.base_url, base);
    }

    // ─── End-to-end SSE streaming ─────────────────────────────────────────

    #[tokio::test]
    async fn ollama_provider_streams_sse_deltas() {
        // Inline server replies with two content deltas plus [DONE]. We
        // expect two Delta events followed by Done.
        let sse_body = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"hello \"}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"content\":\"world\"}}]}\n\n",
            "data: [DONE]\n\n",
        );
        let base = serve_once(http_ok("text/event-stream", sse_body)).await;

        let (result, events) = run_chat(base, "hi", vec![]).await;

        let mut deltas = Vec::new();
        let mut saw_done = false;
        for ev in events {
            match ev {
                ChatEvent::Delta(s) => deltas.push(s),
                ChatEvent::Done => saw_done = true,
                ChatEvent::ToolCall(_) => panic!("unexpected tool call"),
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
        assert!(saw_done, "expected ChatEvent::Done");
    }

    #[tokio::test]
    async fn ollama_provider_emits_tool_call() {
        // SSE stream that delivers one tool call across two fragments.
        let sse_body = concat!(
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"id\":\"call_1\",\"function\":{\"name\":\"search\",\"arguments\":\"{\\\"q\\\":\"}}]}}]}\n\n",
            "data: {\"choices\":[{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"rust\\\"}\"}}]}}]}\n\n",
            "data: [DONE]\n\n",
        );
        let base = serve_once(http_ok("text/event-stream", sse_body)).await;

        let tools = vec![ToolDef {
            name: "search".into(),
            description: "search the web".into(),
            parameters: serde_json::json!({"type":"object"}),
        }];
        let (result, events) = run_chat(base, "search rust", tools).await;

        let mut tool_calls = Vec::new();
        let mut saw_done = false;
        for ev in events {
            match ev {
                ChatEvent::ToolCall(tc) => tool_calls.push(tc),
                ChatEvent::Done => saw_done = true,
                ChatEvent::Delta(_) => {}
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_1");
        assert_eq!(tool_calls[0].name, "search");
        assert_eq!(tool_calls[0].arguments, "{\"q\":\"rust\"}");
        assert!(saw_done);
    }
}
811}