trusty_common/chat/openai_compat/providers.rs
1//! Concrete OpenAI-compatible chat provider implementations.
2//!
3//! Why: OpenRouter (cloud) and Ollama (local) both use the same SSE pump but
4//! differ in auth, URL, and timeout configuration. Keeping both concrete
5//! providers in one file makes the symmetry obvious and lets us share the
6//! `build_consolidation_prompt`-style helpers without re-importing.
7//! What: `OpenRouterProvider`, `OllamaProvider`, and
8//! `auto_detect_local_provider` — the public surface that callers import from
9//! the `openai_compat` module.
10//! Test: `openrouter_provider_reports_metadata`,
11//! `ollama_provider_reports_metadata`, `ollama_provider_streams_sse_deltas`,
12//! `ollama_provider_emits_tool_call`,
13//! `auto_detect_returns_none_on_unreachable`,
14//! `auto_detect_returns_some_on_200`.
15
16use super::sse_pump::pump_openai_sse;
17use super::wire::tools_wire;
18use crate::ChatMessage;
19use crate::chat::{ChatEvent, ChatProvider, ToolDef};
20use anyhow::{Context, Result, anyhow};
21use async_trait::async_trait;
22use tokio::sync::mpsc::Sender;
23
// ---- OpenRouter (cloud) ----------------------------------------------------

/// Endpoint that receives OpenRouter streaming chat-completion POSTs.
const OPENROUTER_URL: &str = "https://openrouter.ai/api/v1/chat/completions";
/// Seconds allowed for establishing the connection to OpenRouter.
const OPENROUTER_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Seconds passed to the HTTP client's overall `timeout` for OpenRouter requests.
const OPENROUTER_REQUEST_TIMEOUT_SECS: u64 = 120;
/// Value sent in the `HTTP-Referer` header on OpenRouter requests.
const HTTP_REFERER: &str = "https://github.com/bobmatnyc/trusty-common";
/// Value sent in the `X-Title` header on OpenRouter requests.
const X_TITLE: &str = "trusty-common";

// ---- Local OpenAI-compatible servers ---------------------------------------

/// Seconds budgeted for the local liveness probe (both connect and overall
/// timeout); also reused as the connect timeout for local chat requests.
const LOCAL_PROBE_TIMEOUT_SECS: u64 = 1;
/// Seconds passed to the HTTP client's overall `timeout` for local chat requests.
const LOCAL_REQUEST_TIMEOUT_SECS: u64 = 120;
31
/// Cloud chat provider backed by OpenRouter.
///
/// Why: lets callers pick OpenRouter or a local model uniformly through
/// the [`ChatProvider`] trait.
/// What: stores an API key and model id; POSTs OpenAI-compatible streaming
/// chat completions with bearer auth and trusty-common branding headers.
/// Implements `Clone`, and a `Debug` that redacts the API key so the value
/// can be logged safely.
/// Test: shape covered by `openrouter_provider_reports_metadata`; the
/// streaming and tool-call paths are covered by integration tests in
/// downstream crates plus the SSE-pump unit tests in this module.
#[derive(Clone)]
pub struct OpenRouterProvider {
    /// Secret used for bearer authentication; never printed by `Debug`.
    pub api_key: String,
    /// Model id placed in the `model` field of each request.
    pub model: String,
}

/// Hand-written instead of derived so the API key cannot leak into logs,
/// panic messages, or test output via `{:?}`.
impl std::fmt::Debug for OpenRouterProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenRouterProvider")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .finish()
    }
}

impl OpenRouterProvider {
    /// Construct a provider from an API key and model id.
    ///
    /// Why: keeps callers from poking the public fields directly so the
    /// struct can grow optional knobs without breaking call sites.
    /// What: converts both arguments into owned `String`s and stores them
    /// unchanged; no validation happens here.
    /// Test: trivially exercised by `openrouter_provider_reports_metadata`.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
        }
    }
}
60
61#[async_trait]
62impl ChatProvider for OpenRouterProvider {
63 fn name(&self) -> &str {
64 "openrouter"
65 }
66
67 fn model(&self) -> &str {
68 &self.model
69 }
70
71 async fn chat_stream(
72 &self,
73 messages: Vec<ChatMessage>,
74 tools: Vec<ToolDef>,
75 tx: Sender<ChatEvent>,
76 ) -> Result<()> {
77 if self.api_key.is_empty() {
78 return Err(anyhow!("openrouter api key is empty"));
79 }
80 let client = reqwest::Client::builder()
81 .connect_timeout(std::time::Duration::from_secs(
82 OPENROUTER_CONNECT_TIMEOUT_SECS,
83 ))
84 .timeout(std::time::Duration::from_secs(
85 OPENROUTER_REQUEST_TIMEOUT_SECS,
86 ))
87 .build()
88 .context("build reqwest client for OpenRouterProvider::chat_stream")?;
89
90 let tw = tools_wire(&tools);
91 let body = super::wire::ChatRequestWire {
92 model: &self.model,
93 messages: &messages,
94 stream: true,
95 tools: tw,
96 };
97 let resp = client
98 .post(OPENROUTER_URL)
99 .bearer_auth(&self.api_key)
100 .header("HTTP-Referer", HTTP_REFERER)
101 .header("X-Title", X_TITLE)
102 .json(&body)
103 .send()
104 .await
105 .context("POST openrouter chat completions (stream)")?;
106
107 let status = resp.status();
108 if !status.is_success() {
109 let text = resp.text().await.unwrap_or_default();
110 return Err(anyhow!("openrouter HTTP {status}: {text}"));
111 }
112
113 pump_openai_sse(resp, tx).await
114 }
115}
116
/// Local chat provider for OpenAI-compatible servers (Ollama, LM Studio,
/// llama.cpp's `server`, vLLM, etc.).
///
/// Why: developers increasingly run a local model server during dev to avoid
/// API costs and latency. The OpenAI-compatible `/v1/chat/completions`
/// endpoint with SSE streaming is the de-facto common denominator.
/// What: stores the server's base URL and the model id to request.
/// `chat_stream` POSTs `{model, messages, tools?, stream: true}` and parses
/// SSE `data:` frames identically to the OpenRouter path. Implements
/// `Debug` and `Clone` so values can be logged and duplicated freely.
/// Test: shape covered by `ollama_provider_reports_metadata`; streaming and
/// tool-call accumulation by `ollama_provider_streams_sse_deltas` and
/// `accumulates_streamed_tool_call_fragments`.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
    /// Base URL of the server, e.g. scheme + host + port; request paths are
    /// appended to it.
    pub base_url: String,
    /// Model id placed in the `model` field of each request.
    pub model: String,
}

impl OllamaProvider {
    /// Construct a provider from a base URL and model id.
    ///
    /// Why: parallel to [`OpenRouterProvider::new`] so callers see a
    /// consistent shape across providers.
    /// What: stores both fields verbatim. A trailing slash on the base URL
    /// is tolerated: it is kept in the field, and `chat_stream` trims it
    /// before appending `/v1/chat/completions`.
    /// Test: covered by `ollama_provider_reports_metadata`.
    pub fn new(base_url: impl Into<String>, model: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            model: model.into(),
        }
    }
}
149
150#[async_trait]
151impl ChatProvider for OllamaProvider {
152 fn name(&self) -> &str {
153 "ollama"
154 }
155
156 fn model(&self) -> &str {
157 &self.model
158 }
159
160 async fn chat_stream(
161 &self,
162 messages: Vec<ChatMessage>,
163 tools: Vec<ToolDef>,
164 tx: Sender<ChatEvent>,
165 ) -> Result<()> {
166 let client = reqwest::Client::builder()
167 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
168 .timeout(std::time::Duration::from_secs(LOCAL_REQUEST_TIMEOUT_SECS))
169 .build()
170 .context("build reqwest client for OllamaProvider::chat_stream")?;
171
172 let url = format!(
173 "{}/v1/chat/completions",
174 self.base_url.trim_end_matches('/')
175 );
176 let tw = tools_wire(&tools);
177 let body = super::wire::ChatRequestWire {
178 model: &self.model,
179 messages: &messages,
180 stream: true,
181 tools: tw,
182 };
183 let resp = client
184 .post(&url)
185 .json(&body)
186 .send()
187 .await
188 .with_context(|| format!("POST {url}"))?;
189
190 let status = resp.status();
191 if !status.is_success() {
192 let text = resp.text().await.unwrap_or_default();
193 return Err(anyhow!("local chat HTTP {status}: {text}"));
194 }
195
196 pump_openai_sse(resp, tx).await
197 }
198}
199
200/// Probe a local model server and return an [`OllamaProvider`] if reachable.
201///
202/// Why: at startup, downstream daemons want to know whether a local model
203/// server is running before falling back to a cloud provider. The OpenAI
204/// `/v1/models` endpoint is a cheap, side-effect-free liveness check that
205/// Ollama, LM Studio, and llama.cpp's server all implement.
206/// What: GETs `{base_url}/v1/models` with a 1-second total timeout. Returns
207/// `Some(OllamaProvider { base_url, model: "" })` on any 2xx response.
208/// Returns `None` on network errors, timeouts, or non-2xx status. Never
209/// returns an error — the caller treats absence as "no local provider
210/// available" and is responsible for setting the model id afterwards (e.g.
211/// from [`super::LocalModelConfig::model`]).
212/// Test: `auto_detect_returns_none_on_unreachable` points at a closed port
213/// and asserts `None` within the 1-second budget;
214/// `auto_detect_returns_some_on_200` spins up an in-process server and
215/// asserts a provider is returned.
216pub async fn auto_detect_local_provider(base_url: &str) -> Option<OllamaProvider> {
217 let client = reqwest::Client::builder()
218 .connect_timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
219 .timeout(std::time::Duration::from_secs(LOCAL_PROBE_TIMEOUT_SECS))
220 .build()
221 .ok()?;
222
223 let url = format!("{}/v1/models", base_url.trim_end_matches('/'));
224 match client.get(&url).send().await {
225 Ok(resp) if resp.status().is_success() => {
226 Some(OllamaProvider::new(base_url.to_string(), String::new()))
227 }
228 _ => None,
229 }
230}