Skip to main content

trusty_common/chat/openai_compat/
providers.rs

1//! Concrete OpenAI-compatible chat provider implementations.
2//!
3//! Why: OpenRouter (cloud) and Ollama (local) both use the same SSE pump but
4//! differ in auth, URL, and timeout configuration. Keeping both concrete
5//! providers in one file makes the symmetry obvious and lets us share the
6//! `build_consolidation_prompt`-style helpers without re-importing.
7//! What: `OpenRouterProvider`, `OllamaProvider`, and
8//! `auto_detect_local_provider` — the public surface that callers import from
9//! the `openai_compat` module.
10//! Test: `openrouter_provider_reports_metadata`,
11//! `ollama_provider_reports_metadata`, `ollama_provider_streams_sse_deltas`,
12//! `ollama_provider_emits_tool_call`,
13//! `auto_detect_returns_none_on_unreachable`,
14//! `auto_detect_returns_some_on_200`.
15
16use super::sse_pump::pump_openai_sse;
17use super::wire::tools_wire;
18use crate::ChatMessage;
19use crate::chat::{ChatEvent, ChatProvider, ToolDef};
20use anyhow::{Context, Result, anyhow};
21use async_trait::async_trait;
22use tokio::sync::mpsc::Sender;
23
/// Seconds allowed for the local liveness probe: used as both the connect
/// and the total timeout in `auto_detect_local_provider`, and reused as the
/// connect timeout in `OllamaProvider::chat_stream`.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Total request timeout, in seconds, for `OllamaProvider::chat_stream`.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Endpoint that `OpenRouterProvider::chat_stream` POSTs to.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Connect timeout, in seconds, for `OpenRouterProvider::chat_stream`.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Total request timeout, in seconds, for `OpenRouterProvider::chat_stream`.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value sent in the `HTTP-Referer` header on OpenRouter requests.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value sent in the `X-Title` header on OpenRouter requests.
const X_TITLE: &str = "trusty-common";
31
/// Cloud chat provider backed by OpenRouter.
///
/// Why: lets callers pick OpenRouter or a local model uniformly through
/// the [`ChatProvider`] trait.
/// What: stores an API key and model id; POSTs OpenAI-compatible streaming
/// chat completions with bearer auth and trusty-common branding headers.
/// Implements `Clone` and a redacting `Debug` so the provider can be
/// duplicated and logged without exposing the key.
/// Test: shape covered by `openrouter_provider_reports_metadata`; the
/// streaming and tool-call paths are covered by integration tests in
/// downstream crates plus the SSE-pump unit tests in this module.
#[derive(Clone)]
pub struct OpenRouterProvider {
    /// Bearer token sent on every request; treated as a secret.
    pub api_key: String,
    /// Model id placed in the request body.
    pub model: String,
}

/// Why: the API key is a secret; a derived `Debug` would print it verbatim
/// whenever the provider ends up in a log line or panic message.
/// What: prints the struct name, a fixed `<redacted>` placeholder in place
/// of `api_key`, and the real `model`.
impl std::fmt::Debug for OpenRouterProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenRouterProvider")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}

impl OpenRouterProvider {
    /// Construct a provider from an API key and model id.
    ///
    /// Why: keeps callers from poking the public fields directly so the
    /// struct can grow optional knobs without breaking call sites.
    /// What: converts both arguments into owned `String`s and stores them
    /// verbatim; no validation happens here.
    /// Test: trivially exercised by `openrouter_provider_reports_metadata`.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
        }
    }
}
60
61#[async_trait]
62impl ChatProvider for OpenRouterProvider {
63    fn name(&self) -> &str {
64        "openrouter"
65    }
66
67    fn model(&self) -> &str {
68        &self.model
69    }
70
71    async fn chat_stream(
72        &self,
73        messages: Vec<ChatMessage>,
74        tools: Vec<ToolDef>,
75        tx: Sender<ChatEvent>,
76    ) -> Result<()> {
77        if self.api_key.is_empty() {
78            return Err(anyhow!("openrouter api key is empty"));
79        }
80        let client = reqwest::Client::builder()
81            .connect_timeout(std::time::Duration::from_secs(
82                OPENROUTER_CONNECT_TIMEOUT_SECS,
83            ))
84            .timeout(std::time::Duration::from_secs(
85                OPENROUTER_REQUEST_TIMEOUT_SECS,
86            ))
87            .build()
88            .context("build reqwest client for OpenRouterProvider::chat_stream")?;
89
90        let tw = tools_wire(&tools);
91        let body = super::wire::ChatRequestWire {
92            model: &self.model,
93            messages: &messages,
94            stream: true,
95            tools: tw,
96        };
97        let resp = client
98            .post(OPENROUTER_URL)
99            .bearer_auth(&self.api_key)
100            .header("HTTP-Referer", HTTP_REFERER)
101            .header("X-Title", X_TITLE)
102            .json(&body)
103            .send()
104            .await
105            .context("POST openrouter chat completions (stream)")?;
106
107        let status = resp.status();
108        if !status.is_success() {
109            let text = resp.text().await.unwrap_or_default();
110            return Err(anyhow!("openrouter HTTP {status}: {text}"));
111        }
112
113        pump_openai_sse(resp, tx).await
114    }
115}
116
/// Local chat provider for OpenAI-compatible servers (Ollama, LM Studio,
/// llama.cpp's `server`, vLLM, etc.).
///
/// Why: developers increasingly run a local model server during dev to avoid
/// API costs and latency. The OpenAI-compatible `/v1/chat/completions`
/// endpoint with SSE streaming is the de-facto common denominator.
/// What: stores the server's base URL and the model id to request.
/// `chat_stream` POSTs `{model, messages, tools?, stream: true}` and parses
/// SSE `data:` frames identically to the OpenRouter path. Derives `Debug`
/// and `Clone`; neither field is a secret, so the derived output is safe to
/// log.
/// Test: shape covered by `ollama_provider_reports_metadata`; streaming and
/// tool-call accumulation by `ollama_provider_streams_sse_deltas` and
/// `accumulates_streamed_tool_call_fragments`.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// Server root, e.g. `http://localhost:11434`; stored exactly as given.
    pub base_url: String,
    /// Model id placed in the request body; may be empty until the caller
    /// fills it in.
    pub model: String,
}

impl OllamaProvider {
    /// Construct a provider from a base URL and model id.
    ///
    /// Why: parallel to [`OpenRouterProvider::new`] so callers see a
    /// consistent shape across providers.
    /// What: stores both fields verbatim. A trailing slash on the base URL
    /// is tolerated: `chat_stream` trims trailing `/` characters before
    /// appending `/v1/chat/completions`, so the stored value is not
    /// normalised here.
    /// Test: covered by `ollama_provider_reports_metadata`.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
        }
    }
}
149
150#[async_trait]
151impl ChatProvider for OllamaProvider {
152    fn name(&self) -> &str {
153        "ollama"
154    }
155
156    fn model(&self) -> &str {
157        &self.model
158    }
159
160    async fn chat_stream(
161        &self,
162        messages: Vec<ChatMessage>,
163        tools: Vec<ToolDef>,
164        tx: Sender<ChatEvent>,
165    ) -> Result<()> {
166        let client = reqwest::Client::builder()
167            .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
168            .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
169            .build()
170            .context("build reqwest client for OllamaProvider::chat_stream")?;
171
172        let url = format!(
173            "{}/v1/chat/completions",
174            self.base_url.trim_end_matches('/')
175        );
176        let tw = tools_wire(&tools);
177        let body = super::wire::ChatRequestWire {
178            model: &self.model,
179            messages: &messages,
180            stream: true,
181            tools: tw,
182        };
183        let resp = client
184            .post(&url)
185            .json(&body)
186            .send()
187            .await
188            .with_context(|| format!("POST {url}"))?;
189
190        let status = resp.status();
191        if !status.is_success() {
192            let text = resp.text().await.unwrap_or_default();
193            return Err(anyhow!("local chat HTTP {status}: {text}"));
194        }
195
196        pump_openai_sse(resp, tx).await
197    }
198}
199
200/// Probe a local model server and return an [`OllamaProvider`] if reachable.
201///
202/// Why: at startup, downstream daemons want to know whether a local model
203/// server is running before falling back to a cloud provider. The OpenAI
204/// `/v1/models` endpoint is a cheap, side-effect-free liveness check that
205/// Ollama, LM Studio, and llama.cpp's server all implement.
206/// What: GETs `{base_url}/v1/models` with a 1-second total timeout. Returns
207/// `Some(OllamaProvider { base_url, model: "" })` on any 2xx response.
208/// Returns `None` on network errors, timeouts, or non-2xx status. Never
209/// returns an error — the caller treats absence as "no local provider
210/// available" and is responsible for setting the model id afterwards (e.g.
211/// from [`super::LocalModelConfig::model`]).
212/// Test: `auto_detect_returns_none_on_unreachable` points at a closed port
213/// and asserts `None` within the 1-second budget;
214/// `auto_detect_returns_some_on_200` spins up an in-process server and
215/// asserts a provider is returned.
216pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
217    let client = reqwest::Client::builder()
218        .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
219        .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
220        .build()
221        .ok()?;
222
223    let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
224    match client.get(&url).send().await {
225        Ok(resp) if resp.status().is_success() => {
226            Some(OllamaProvider::new(base_url.to_string(), String::new()))
227        }
228        _ => None,
229    }
230}