Skip to main content

trusty_common/
chat.rs

1//! Provider-agnostic streaming chat abstraction with tool-use support.
2//!
3//! Why: trusty-memory and trusty-search both want to support more than one
4//! upstream LLM (OpenRouter for cloud, Ollama / LM Studio for local). Rather
5//! than each crate re-implementing the dispatch, we expose a small
6//! [`ChatProvider`] trait plus two concrete implementations and an
7//! auto-detector for a running local model server. The trait also surfaces
8//! OpenAI-style tool/function calling so downstream agents can let the model
9//! invoke tools (search, memory recall, shell, etc.).
10//!
11//! What: defines the [`ChatProvider`] trait, [`ToolDef`] / [`ToolCall`] /
12//! [`ChatEvent`] tool-use types, an [`OpenRouterProvider`] and an
13//! [`OllamaProvider`] that both speak OpenAI-compatible
14//! `/v1/chat/completions` with SSE streaming (including the streamed
15//! `tool_calls` shape), and [`auto_detect_local_provider`] which probes
16//! `{base_url}/v1/models` with a 1-second timeout.
17//!
18//! Test: `cargo test -p trusty-common` covers default config values, the
19//! unreachable-server path of `auto_detect_local_provider`, SSE delta
20//! streaming, and accumulation of streamed tool-call fragments.
21
22use crate::ChatMessage;
23use anyhow::{Context, Result, anyhow};
24use async_trait::async_trait;
25use serde::{Deserialize, Serialize};
26use tokio::sync::mpsc::Sender;
27
// Local model server timeouts, in seconds.
/// Budget for probing a local server (connect + total), also reused as the
/// connect timeout for local chat requests.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Total budget for one local chat request, streamed body included.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;

// OpenRouter endpoint and timeouts.
/// OpenAI-compatible chat completions endpoint on OpenRouter.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect timeout for OpenRouter requests, in seconds.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total timeout for one OpenRouter request, in seconds.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;

// Attribution headers sent with every OpenRouter request.
/// Value of the `HTTP-Referer` header.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value of the `X-Title` header.
const X_TITLE: &str = "trusty-common";
35
36/// Configuration for a local OpenAI-compatible model server (Ollama, LM
37/// Studio, llama.cpp's server, etc.).
38///
39/// Why: callers want a single struct they can deserialize from config files
40/// and pass to [`auto_detect_local_provider`] without juggling defaults.
41/// What: holds an enable flag, the server's base URL (no trailing slash),
42/// and the default model to request. Defaults target Ollama's standard
43/// localhost binding.
44/// Test: `local_model_config_defaults` asserts the default values.
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct LocalModelConfig {
47    pub enabled: bool,
48    pub base_url: String,
49    pub model: String,
50}
51
52impl Default for LocalModelConfig {
53    fn default() -> Self {
54        Self {
55            enabled: true,
56            base_url: "http://localhost:11434".to_string(),
57            model: "qwen3:30b".to_string(),
58        }
59    }
60}
61
62// ─── Tool-use types ───────────────────────────────────────────────────────
63
64/// JSON-Schema description of a callable tool, in OpenAI function-calling
65/// shape.
66///
67/// Why: downstream agents (trusty-memory, trusty-search) expose tools like
68/// `memory_recall` or `web_search` to the LLM. The OpenAI tool format is the
69/// de-facto common denominator across OpenRouter, Ollama, LM Studio, and
70/// most cloud providers.
71/// What: `name` and `description` are passed verbatim; `parameters` is a
72/// JSON Schema object (typically `{"type":"object","properties":{...}}`).
73/// Test: `tool_def_serializes_as_function` checks the wire shape.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ToolDef {
76    pub name: String,
77    pub description: String,
78    pub parameters: serde_json::Value,
79}
80
81/// A tool invocation the model wants the host to perform.
82///
83/// Why: the streaming chat API emits `tool_calls` in fragments — first an
84/// `id` + `function.name`, then a string of `function.arguments` deltas.
85/// We accumulate fragments and surface one fully-formed [`ToolCall`] per
86/// invocation to the caller.
87/// What: `id` is the upstream's call id (echoed back in subsequent
88/// `role:"tool"` messages); `name` is the function name; `arguments` is a
89/// JSON string (NOT a parsed value — many models emit malformed JSON and
90/// callers want the raw text for error reporting / repair).
91/// Test: `accumulates_streamed_tool_call_fragments`.
92#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
93pub struct ToolCall {
94    pub id: String,
95    pub name: String,
96    pub arguments: String,
97}
98
99/// Streaming chat event.
100///
101/// Why: replaces the previous "string-only" channel so callers can
102/// distinguish text deltas from tool invocations and from terminal
103/// success/error without parsing magic markers out of the text stream.
104/// What: `Delta` is a content chunk; `ToolCall` is a fully-accumulated tool
105/// invocation; `Done` signals the upstream stream terminated normally;
106/// `Error` carries a human-readable message for stream-mid failures (the
107/// provider also returns `Err` from `chat_stream`, but `Error` lets the
108/// caller display partial-stream failures inline).
109/// Test: `ollama_provider_streams_sse_deltas`.
110#[derive(Debug, Clone)]
111pub enum ChatEvent {
112    Delta(String),
113    ToolCall(ToolCall),
114    Done,
115    Error(String),
116}
117
118/// Streaming chat provider abstraction.
119///
120/// Why: downstream crates (trusty-memory, trusty-search) want to support
121/// multiple LLM backends without hard-coding which one to call. Providers
122/// expose a uniform streaming interface so the caller can swap them at
123/// runtime based on configuration / availability.
124/// What: implementors stream [`ChatEvent`]s into `tx`. Pass an empty
125/// `tools` vec to disable tool use entirely (the provider MUST then omit
126/// the `tools` field from the upstream request — some models error on an
127/// empty array). Returning `Ok(())` means the stream completed normally;
128/// the caller should also expect a final [`ChatEvent::Done`].
129/// Test: implementations are covered by their own unit tests in this
130/// module plus integration tests in downstream crates.
131#[async_trait]
132pub trait ChatProvider: Send + Sync {
133    /// Human-readable provider name (e.g. `"openrouter"`, `"ollama"`).
134    fn name(&self) -> &str;
135    /// Model identifier sent on every request.
136    fn model(&self) -> &str;
137    /// Stream chat events into `tx`. `tools` empty disables tool use.
138    async fn chat_stream(
139        &self,
140        messages: Vec<ChatMessage>,
141        tools: Vec<ToolDef>,
142        tx: Sender<ChatEvent>,
143    ) -> Result<()>;
144}
145
146// ─── Shared SSE / request types ────────────────────────────────────────────
147
148#[derive(Debug, Serialize)]
149struct OpenAiToolWire<'a> {
150    #[serde(rename = "type")]
151    kind: &'static str,
152    function: OpenAiFunctionWire<'a>,
153}
154
155#[derive(Debug, Serialize)]
156struct OpenAiFunctionWire<'a> {
157    name: &'a str,
158    description: &'a str,
159    parameters: &'a serde_json::Value,
160}
161
162#[derive(Debug, Serialize)]
163struct ChatRequestWire<'a> {
164    model: &'a str,
165    messages: &'a [ChatMessage],
166    stream: bool,
167    #[serde(skip_serializing_if = "Option::is_none")]
168    tools: Option<Vec<OpenAiToolWire<'a>>>,
169}
170
171fn tools_wire(tools: &[ToolDef]) -> Option<Vec<OpenAiToolWire<'_>>> {
172    if tools.is_empty() {
173        None
174    } else {
175        Some(
176            tools
177                .iter()
178                .map(|t| OpenAiToolWire {
179                    kind: "function",
180                    function: OpenAiFunctionWire {
181                        name: &t.name,
182                        description: &t.description,
183                        parameters: &t.parameters,
184                    },
185                })
186                .collect(),
187        )
188    }
189}
190
191/// Accumulator for streamed tool-call fragments.
192///
193/// Why: OpenAI-style streaming sends each tool call across multiple SSE
194/// frames: the first frame at a given `index` carries `id` and
195/// `function.name`; subsequent frames append to `function.arguments`. We
196/// accumulate by `index` and emit fully-formed [`ToolCall`]s only after the
197/// stream terminates (or we see `finish_reason: tool_calls`).
198/// What: a vector slot per index, growing as needed; merge logic is in
199/// `apply_delta`. `finalize` drops slots that never received an id (defensive
200/// — shouldn't happen but avoids emitting half-baked calls).
201/// Test: `accumulates_streamed_tool_call_fragments`.
202#[derive(Debug, Default)]
203struct ToolCallAccumulator {
204    // index -> (id, name, args)
205    slots: Vec<Option<(String, String, String)>>,
206}
207
208impl ToolCallAccumulator {
209    fn apply_delta(&mut self, tool_calls: &serde_json::Value) {
210        let Some(arr) = tool_calls.as_array() else {
211            return;
212        };
213        for tc in arr {
214            let idx = tc.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as usize;
215            while self.slots.len() <= idx {
216                self.slots.push(None);
217            }
218            let slot = self.slots[idx]
219                .get_or_insert_with(|| (String::new(), String::new(), String::new()));
220            if let Some(id) = tc.get("id").and_then(|v| v.as_str())
221                && !id.is_empty()
222            {
223                slot.0 = id.to_string();
224            }
225            if let Some(func) = tc.get("function") {
226                if let Some(name) = func.get("name").and_then(|v| v.as_str())
227                    && !name.is_empty()
228                {
229                    slot.1 = name.to_string();
230                }
231                if let Some(args) = func.get("arguments").and_then(|v| v.as_str()) {
232                    slot.2.push_str(args);
233                }
234            }
235        }
236    }
237
238    fn finalize(self) -> Vec<ToolCall> {
239        self.slots
240            .into_iter()
241            .filter_map(|opt| {
242                opt.and_then(|(id, name, arguments)| {
243                    if name.is_empty() {
244                        None
245                    } else {
246                        Some(ToolCall {
247                            id,
248                            name,
249                            arguments,
250                        })
251                    }
252                })
253            })
254            .collect()
255    }
256}
257
258/// Drive one OpenAI-compatible SSE stream into the caller's [`ChatEvent`]
259/// channel.
260///
261/// Why: OpenRouter and Ollama both speak the same wire format; sharing the
262/// loop keeps the two providers in lock-step.
263/// What: reads `resp.bytes_stream()`, splits on newlines, parses `data:`
264/// frames, forwards `delta.content` as [`ChatEvent::Delta`], accumulates
265/// `delta.tool_calls`, and on `[DONE]` (or upstream EOF) emits one
266/// [`ChatEvent::ToolCall`] per accumulated call followed by
267/// [`ChatEvent::Done`].
268/// Test: covered by `ollama_provider_streams_sse_deltas` and
269/// `accumulates_streamed_tool_call_fragments`.
270async fn pump_openai_sse(resp: reqwest::Response, tx: Sender<ChatEvent>) -> Result<()> {
271    use futures_util::StreamExt;
272
273    let mut acc = ToolCallAccumulator::default();
274    let mut buf = String::new();
275    let mut stream = resp.bytes_stream();
276
277    while let Some(chunk) = stream.next().await {
278        let bytes = chunk.context("read chat stream chunk")?;
279        let text = match std::str::from_utf8(&bytes) {
280            Ok(s) => s,
281            Err(_) => continue,
282        };
283        buf.push_str(text);
284
285        while let Some(idx) = buf.find('\n') {
286            let line: String = buf.drain(..=idx).collect();
287            let line = line.trim();
288            let Some(payload) = line.strip_prefix("data:").map(str::trim) else {
289                continue;
290            };
291            if payload.is_empty() {
292                continue;
293            }
294            if payload == "[DONE]" {
295                // Flush accumulated tool calls and finish.
296                for call in std::mem::take(&mut acc).finalize() {
297                    if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
298                        return Ok(());
299                    }
300                }
301                let _ = tx.send(ChatEvent::Done).await;
302                return Ok(());
303            }
304            let v: serde_json::Value = match serde_json::from_str(payload) {
305                Ok(v) => v,
306                Err(_) => continue,
307            };
308            let delta = v
309                .get("choices")
310                .and_then(|c| c.get(0))
311                .and_then(|c| c.get("delta"));
312            if let Some(delta) = delta {
313                if let Some(content) = delta.get("content").and_then(|c| c.as_str())
314                    && !content.is_empty()
315                    && tx
316                        .send(ChatEvent::Delta(content.to_string()))
317                        .await
318                        .is_err()
319                {
320                    return Ok(());
321                }
322                if let Some(tc) = delta.get("tool_calls") {
323                    acc.apply_delta(tc);
324                }
325            }
326        }
327    }
328
329    // Upstream EOF without a [DONE] sentinel — still flush and finish.
330    for call in acc.finalize() {
331        if tx.send(ChatEvent::ToolCall(call)).await.is_err() {
332            return Ok(());
333        }
334    }
335    let _ = tx.send(ChatEvent::Done).await;
336    Ok(())
337}
338
339// ─── OpenRouter ───────────────────────────────────────────────────────────
340
/// [`ChatProvider`] backed by the OpenRouter cloud API.
///
/// Why: lets callers choose OpenRouter or a local model through one
/// uniform trait.
/// What: keeps an API key and a model id; requests are OpenAI-compatible
/// streaming chat completions sent with bearer auth plus trusty-common
/// attribution headers.
/// Test: metadata via `openrouter_provider_reports_metadata`; streaming and
/// tool calls via the SSE-pump unit tests here and downstream integration
/// tests.
pub struct OpenRouterProvider {
    /// Bearer token for the `Authorization` header.
    pub api_key: String,
    /// OpenRouter model slug, e.g. `anthropic/claude-3.5-sonnet`.
    pub model: String,
}

impl OpenRouterProvider {
    /// Build a provider from an API key and a model id.
    ///
    /// Why: a constructor insulates call sites from the field layout, so
    /// optional knobs can be added later without breaking them.
    /// What: converts and stores both arguments unchanged.
    /// Test: exercised by `openrouter_provider_reports_metadata`.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key = api_key.into();
        let model = model.into();
        Self { api_key, model }
    }
}
369
370#[async_trait]
371impl ChatProvider for OpenRouterProvider {
372    fn name(&self) -> &str {
373        "openrouter"
374    }
375
376    fn model(&self) -> &str {
377        &self.model
378    }
379
380    async fn chat_stream(
381        &self,
382        messages: Vec<ChatMessage>,
383        tools: Vec<ToolDef>,
384        tx: Sender<ChatEvent>,
385    ) -> Result<()> {
386        if self.api_key.is_empty() {
387            return Err(anyhow!("openrouter api key is empty"));
388        }
389        let client = reqwest::Client::builder()
390            .connect_timeout(std::time::Duration::from_secs(
391                OPENROUTER_CONNECT_TIMEOUT_SECS,
392            ))
393            .timeout(std::time::Duration::from_secs(
394                OPENROUTER_REQUEST_TIMEOUT_SECS,
395            ))
396            .build()
397            .context("build reqwest client for OpenRouterProvider::chat_stream")?;
398
399        let tools_wire = tools_wire(&tools);
400        let body = ChatRequestWire {
401            model: &self.model,
402            messages: &messages,
403            stream: true,
404            tools: tools_wire,
405        };
406        let resp = client
407            .post(OPENROUTER_URL)
408            .bearer_auth(&self.api_key)
409            .header("HTTP-Referer", HTTP_REFERER)
410            .header("X-Title", X_TITLE)
411            .json(&body)
412            .send()
413            .await
414            .context("POST openrouter chat completions (stream)")?;
415
416        let status = resp.status();
417        if !status.is_success() {
418            let text = resp.text().await.unwrap_or_default();
419            return Err(anyhow!("openrouter HTTP {status}: {text}"));
420        }
421
422        pump_openai_sse(resp, tx).await
423    }
424}
425
426// ─── Ollama / OpenAI-compatible local ─────────────────────────────────────
427
/// [`ChatProvider`] for local OpenAI-compatible servers (Ollama, LM Studio,
/// llama.cpp's `server`, vLLM, etc.).
///
/// Why: developers often run a model server locally to skip API costs and
/// network latency, and an OpenAI-compatible `/v1/chat/completions` endpoint
/// with SSE streaming is what these servers have in common.
/// What: keeps the server's base URL and the model id to request;
/// `chat_stream` POSTs `{model, messages, tools?, stream: true}` and parses
/// the SSE `data:` frames exactly like the OpenRouter path.
/// Test: metadata via `ollama_provider_reports_metadata`; streaming and
/// tool-call assembly via `ollama_provider_streams_sse_deltas` and
/// `accumulates_streamed_tool_call_fragments`.
pub struct OllamaProvider {
    /// Server root, e.g. `http://localhost:11434`.
    pub base_url: String,
    /// Model id sent with each request.
    pub model: String,
}

impl OllamaProvider {
    /// Build a provider from a base URL and a model id.
    ///
    /// Why: mirrors [`OpenRouterProvider::new`] so every provider is built
    /// the same way.
    /// What: converts and stores both arguments unchanged. Pass the base URL
    /// without a trailing slash; `/v1/chat/completions` is appended to it.
    /// Test: covered by `ollama_provider_reports_metadata`.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        let base_url = base_url.into();
        let model = model.into();
        Self { base_url, model }
    }
}
460
461#[async_trait]
462impl ChatProvider for OllamaProvider {
463    fn name(&self) -> &str {
464        "ollama"
465    }
466
467    fn model(&self) -> &str {
468        &self.model
469    }
470
471    async fn chat_stream(
472        &self,
473        messages: Vec<ChatMessage>,
474        tools: Vec<ToolDef>,
475        tx: Sender<ChatEvent>,
476    ) -> Result<()> {
477        let client = reqwest::Client::builder()
478            .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
479            .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
480            .build()
481            .context("build reqwest client for OllamaProvider::chat_stream")?;
482
483        let url = format!(
484            "{}/v1/chat/completions",
485            self.base_url.trim_end_matches('/')
486        );
487        let tools_wire = tools_wire(&tools);
488        let body = ChatRequestWire {
489            model: &self.model,
490            messages: &messages,
491            stream: true,
492            tools: tools_wire,
493        };
494        let resp = client
495            .post(&url)
496            .json(&body)
497            .send()
498            .await
499            .with_context(|| format!("POST {url}"))?;
500
501        let status = resp.status();
502        if !status.is_success() {
503            let text = resp.text().await.unwrap_or_default();
504            return Err(anyhow!("local chat HTTP {status}: {text}"));
505        }
506
507        pump_openai_sse(resp, tx).await
508    }
509}
510
511/// Probe a local model server and return an [`OllamaProvider`] if reachable.
512///
513/// Why: at startup, downstream daemons want to know whether a local model
514/// server is running before falling back to a cloud provider. The OpenAI
515/// `/v1/models` endpoint is a cheap, side-effect-free liveness check that
516/// Ollama, LM Studio, and llama.cpp's server all implement.
517/// What: GETs `{base_url}/v1/models` with a 1-second total timeout. Returns
518/// `Some(OllamaProvider { base_url, model: "" })` on any 2xx response.
519/// Returns `None` on network errors, timeouts, or non-2xx status. Never
520/// returns an error — the caller treats absence as "no local provider
521/// available" and is responsible for setting the model id afterwards (e.g.
522/// from [`LocalModelConfig::model`]).
523/// Test: `auto_detect_returns_none_on_unreachable` points at a closed port
524/// and asserts `None` within the 1-second budget;
525/// `auto_detect_returns_some_on_200` spins up an in-process server and
526/// asserts a provider is returned.
527pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
528    let client = reqwest::Client::builder()
529        .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
530        .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
531        .build()
532        .ok()?;
533
534    let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
535    match client.get(&url).send().await {
536        Ok(resp) if resp.status().is_success() => {
537            Some(OllamaProvider::new(base_url.to_string(), String::new()))
538        }
539        _ => None,
540    }
541}
542
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Start a throwaway HTTP/1.1 server on an ephemeral localhost port. It
    /// answers exactly one request with `200 OK`, the given `Content-Type`,
    /// and `body`, then closes. Returns `http://127.0.0.1:<port>`.
    async fn serve_one_response(content_type: &'static str, body: &'static str) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let base = format!("http://{}", listener.local_addr().unwrap());
        tokio::spawn(async move {
            let Ok((mut sock, _)) = listener.accept().await else {
                return;
            };
            // Consume the start of the request; its content does not matter.
            let mut request = [0u8; 4096];
            let _ = sock.read(&mut request).await;
            let reply = format!(
                "HTTP/1.1 200 OK\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
                body.len()
            );
            let _ = sock.write_all(reply.as_bytes()).await;
            let _ = sock.shutdown().await;
        });
        base
    }

    /// A single user-role chat message with no tool metadata.
    fn user_message(content: &'static str) -> ChatMessage {
        ChatMessage {
            role: "user".into(),
            content: content.into(),
            tool_call_id: None,
            tool_calls: None,
        }
    }

    /// Run `provider.chat_stream` on a background task and collect every
    /// event it emits, along with the task's final result.
    async fn collect_events(
        provider: OllamaProvider,
        prompt: &'static str,
        tools: Vec<ToolDef>,
    ) -> (Vec<ChatEvent>, Result<()>) {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<ChatEvent>(8);
        let worker = tokio::spawn(async move {
            provider
                .chat_stream(vec![user_message(prompt)], tools, tx)
                .await
        });
        let mut events = Vec::new();
        while let Some(event) = rx.recv().await {
            events.push(event);
        }
        let outcome = worker.await.expect("task panicked");
        (events, outcome)
    }

    #[test]
    fn local_model_config_defaults() {
        let LocalModelConfig {
            enabled,
            base_url,
            model,
        } = LocalModelConfig::default();
        assert!(enabled);
        assert_eq!(base_url, "http://localhost:11434");
        assert_eq!(model, "qwen3:30b");
    }

    #[test]
    fn openrouter_provider_reports_metadata() {
        let provider = OpenRouterProvider::new("sk-xxx", "anthropic/claude-3.5-sonnet");
        assert_eq!(
            (provider.name(), provider.model()),
            ("openrouter", "anthropic/claude-3.5-sonnet")
        );
    }

    #[test]
    fn ollama_provider_reports_metadata() {
        let provider = OllamaProvider::new("http://localhost:11434", "llama3.2");
        assert_eq!((provider.name(), provider.model()), ("ollama", "llama3.2"));
    }

    #[test]
    fn tool_def_serializes_as_function() {
        // Run through `tools_wire`, a ToolDef must come out in the OpenAI
        // function-calling shape.
        let tools = [ToolDef {
            name: "search".into(),
            description: "Search the web".into(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": { "query": { "type": "string" } },
                "required": ["query"],
            }),
        }];
        let json = serde_json::to_value(tools_wire(&tools).expect("expected Some")).unwrap();
        let first = &json[0];
        assert_eq!(first["type"], "function");
        assert_eq!(first["function"]["name"], "search");
        assert_eq!(first["function"]["parameters"]["type"], "object");
    }

    #[test]
    fn empty_tools_serializes_to_none() {
        // No tools means no `tools` field at all, because some models error
        // on an empty array.
        assert!(tools_wire(&[]).is_none());
    }

    #[test]
    fn accumulates_streamed_tool_call_fragments() {
        // One call split across three deltas: id+name, then two argument
        // fragments. Finalizing must yield a single, concatenated ToolCall.
        let fragments = [
            serde_json::json!([{
                "index": 0,
                "id": "call_abc",
                "function": { "name": "search", "arguments": "" }
            }]),
            serde_json::json!([{
                "index": 0,
                "function": { "arguments": "{\"query\":\"" }
            }]),
            serde_json::json!([{
                "index": 0,
                "function": { "arguments": "rust\"}" }
            }]),
        ];
        let mut acc = ToolCallAccumulator::default();
        for fragment in &fragments {
            acc.apply_delta(fragment);
        }
        assert_eq!(
            acc.finalize(),
            vec![ToolCall {
                id: "call_abc".into(),
                name: "search".into(),
                arguments: "{\"query\":\"rust\"}".into(),
            }]
        );
    }

    #[tokio::test]
    async fn auto_detect_returns_none_on_unreachable() {
        // Reserve a free port, then release it so nothing is listening there.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };
        let base = format!("http://127.0.0.1:{port}");

        let started = std::time::Instant::now();
        let got = auto_detect_local_provider(&base).await;
        let elapsed = started.elapsed();

        assert!(got.is_none(), "expected None for unreachable server");
        assert!(
            elapsed < std::time::Duration::from_secs(2),
            "auto-detect took too long: {elapsed:?}"
        );
    }

    #[tokio::test]
    async fn auto_detect_returns_some_on_200() {
        let base = serve_one_response("application/json", r#"{"data":[]}"#).await;
        let provider = auto_detect_local_provider(&base)
            .await
            .expect("expected Some for reachable 200 server");
        assert_eq!(provider.name(), "ollama");
        assert_eq!(provider.base_url, base);
    }

    #[test]
    fn local_model_config_deserializes_from_toml() {
        let toml_src = r#"
            enabled = true
            base_url = "http://localhost:1234"
            model = "qwen2.5-coder"
        "#;
        let cfg: LocalModelConfig = toml::from_str(toml_src).expect("parse TOML");
        assert!(cfg.enabled);
        assert_eq!(cfg.base_url, "http://localhost:1234");
        assert_eq!(cfg.model, "qwen2.5-coder");
    }

    #[tokio::test]
    async fn ollama_provider_streams_sse_deltas() {
        // Two content deltas then [DONE]: expect two Delta events and a Done.
        let base = serve_one_response(
            "text/event-stream",
            concat!(
                r#"data: {"choices":[{"delta":{"content":"hello "}}]}"#,
                "\n\n",
                r#"data: {"choices":[{"delta":{"content":"world"}}]}"#,
                "\n\n",
                "data: [DONE]\n\n",
            ),
        )
        .await;

        let (events, result) =
            collect_events(OllamaProvider::new(base, "test-model"), "hi", Vec::new()).await;

        let mut deltas = Vec::new();
        let mut saw_done = false;
        for event in events {
            match event {
                ChatEvent::Delta(text) => deltas.push(text),
                ChatEvent::Done => saw_done = true,
                ChatEvent::ToolCall(_) => panic!("unexpected tool call"),
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(deltas, vec!["hello ".to_string(), "world".to_string()]);
        assert!(saw_done, "expected ChatEvent::Done");
    }

    #[tokio::test]
    async fn ollama_provider_emits_tool_call() {
        // One tool call delivered across two fragments, then [DONE].
        let base = serve_one_response(
            "text/event-stream",
            concat!(
                r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"search","arguments":"{\"q\":"}}]}}]}"#,
                "\n\n",
                r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"rust\"}"}}]}}]}"#,
                "\n\n",
                "data: [DONE]\n\n",
            ),
        )
        .await;

        let tools = vec![ToolDef {
            name: "search".into(),
            description: "search the web".into(),
            parameters: serde_json::json!({"type":"object"}),
        }];
        let (events, result) = collect_events(
            OllamaProvider::new(base, "test-model"),
            "search rust",
            tools,
        )
        .await;

        let mut tool_calls = Vec::new();
        let mut saw_done = false;
        for event in events {
            match event {
                ChatEvent::ToolCall(call) => tool_calls.push(call),
                ChatEvent::Done => saw_done = true,
                ChatEvent::Delta(_) => {}
                ChatEvent::Error(e) => panic!("stream error: {e}"),
            }
        }
        assert!(result.is_ok(), "chat_stream errored: {result:?}");
        assert_eq!(
            tool_calls,
            vec![ToolCall {
                id: "call_1".into(),
                name: "search".into(),
                arguments: r#"{"q":"rust"}"#.into(),
            }]
        );
        assert!(saw_done);
    }
}