Skip to main content

zeph_llm/
compatible.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! `OpenAI`-compatible provider adapter.
5//!
6//! [`CompatibleProvider`] wraps [`crate::openai::OpenAiProvider`] and adds a named
7//! provider label for logging. Use it for any endpoint that exposes the `OpenAI` Chat
8//! Completions and Embeddings API (Together AI, Fireworks, Anyscale, local vLLM, etc.).
9//!
10//! # Configuration
11//!
12//! ```toml
13//! [[llm.providers]]
14//! name = "together"
15//! type = "compatible"
16//! provider_name = "together-ai"
17//! base_url = "https://api.together.xyz/v1"
18//! model = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
19//! max_tokens = 4096
20//! api_key_vault = "ZEPH_TOGETHER_API_KEY"
21//! ```
22
23use std::fmt;
24
25use crate::error::LlmError;
26use crate::openai::{OpenAiConfig, OpenAiProvider};
27use crate::provider::{
28    ChatExtras, ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, StatusTx,
29    ToolDefinition,
30};
31
/// Configuration for [`CompatibleProvider`].
///
/// Pass to [`CompatibleProvider::new`] instead of individual positional arguments to avoid
/// silent parameter transposition.
///
/// The [`fmt::Debug`] implementation is written by hand so that `api_key` is never rendered:
/// the secret is replaced by a `<redacted>` placeholder, which keeps the credential out of
/// logs, panic messages and `dbg!` output.
///
/// # Examples
///
/// ```
/// use zeph_llm::compatible::{CompatibleConfig, CompatibleProvider};
///
/// let cfg = CompatibleConfig {
///     provider_name: "together-ai".into(),
///     api_key: "key".into(),
///     base_url: "https://api.together.xyz/v1".into(),
///     model: "meta-llama/Llama-3.3-70B-Instruct-Turbo".into(),
///     max_tokens: 4096,
///     embedding_model: None,
/// };
/// let provider = CompatibleProvider::new(cfg);
/// ```
#[derive(Clone)]
pub struct CompatibleConfig {
    /// Human-readable provider name used in logs and [`LlmProvider::name`].
    pub provider_name: String,
    /// Secret API key sent in the `Authorization: Bearer` header.
    ///
    /// Never included in the `Debug` representation of this struct.
    pub api_key: String,
    /// Base URL of the endpoint, e.g. `"https://api.together.xyz/v1"`.
    pub base_url: String,
    /// Chat model identifier.
    pub model: String,
    /// Upper bound on completion tokens returned by the model.
    pub max_tokens: u32,
    /// Embedding model identifier. Set to `None` when the endpoint does not support embeddings.
    pub embedding_model: Option<String>,
}

impl fmt::Debug for CompatibleConfig {
    /// Format every field except `api_key`, which is shown as `"<redacted>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CompatibleConfig")
            .field("provider_name", &self.provider_name)
            // The key is a credential: print a placeholder instead of the value.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("embedding_model", &self.embedding_model)
            .finish()
    }
}
67
68/// [`LlmProvider`] adapter for OpenAI-compatible REST endpoints.
69///
70/// Delegates all operations to an inner [`OpenAiProvider`] while exposing a
71/// configurable `provider_name` for logging and routing identification.
72pub struct CompatibleProvider {
73    inner: OpenAiProvider,
74    /// Human-readable name used in logs and [`LlmProvider::name`].
75    provider_name: String,
76}
77
78impl CompatibleProvider {
79    /// Create a new provider from a [`CompatibleConfig`].
80    #[must_use]
81    pub fn new(cfg: CompatibleConfig) -> Self {
82        let provider_name = cfg.provider_name;
83        let inner = OpenAiProvider::new(OpenAiConfig {
84            api_key: cfg.api_key,
85            base_url: cfg.base_url,
86            model: cfg.model,
87            max_tokens: cfg.max_tokens,
88            embedding_model: cfg.embedding_model,
89            reasoning_effort: None,
90        });
91        Self {
92            inner,
93            provider_name,
94        }
95    }
96}
97
98impl fmt::Debug for CompatibleProvider {
99    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100        f.debug_struct("CompatibleProvider")
101            .field("provider_name", &self.provider_name)
102            .field("inner", &self.inner)
103            .finish_non_exhaustive()
104    }
105}
106
107impl Clone for CompatibleProvider {
108    fn clone(&self) -> Self {
109        Self {
110            inner: self.inner.clone(),
111            provider_name: self.provider_name.clone(),
112        }
113    }
114}
115
116impl CompatibleProvider {
117    /// Fetch models via the inner `OpenAiProvider`. Cache slug is derived from base URL.
118    ///
119    /// # Errors
120    ///
121    /// Returns an error if the API request fails.
122    pub async fn list_models_remote(
123        &self,
124    ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
125        self.inner.list_models_remote().await
126    }
127}
128
129impl CompatibleProvider {
130    /// Attach a status channel for streaming progress events to the TUI.
131    pub fn set_status_tx(&mut self, tx: StatusTx) {
132        self.inner.status_tx = Some(tx);
133    }
134
135    /// Override generation parameters (temperature, top-p, etc.) for all subsequent calls.
136    #[must_use]
137    pub fn with_generation_overrides(mut self, overrides: GenerationOverrides) -> Self {
138        self.inner = self.inner.with_generation_overrides(overrides);
139        self
140    }
141
142    /// Forward MCP tool output schemas as JSON hints appended to tool descriptions.
143    ///
144    /// Delegates to the inner [`OpenAiProvider`]. When `enabled` is `false` the call is a no-op.
145    /// `hint_bytes` caps the JSON representation; `max_description_bytes` caps the combined
146    /// description string.
147    #[must_use]
148    pub fn with_output_schema_forwarding(
149        mut self,
150        enabled: bool,
151        hint_bytes: usize,
152        max_description_bytes: usize,
153    ) -> Self {
154        self.inner =
155            self.inner
156                .with_output_schema_forwarding(enabled, hint_bytes, max_description_bytes);
157        self
158    }
159}
160
161impl LlmProvider for CompatibleProvider {
162    fn context_window(&self) -> Option<usize> {
163        None
164    }
165
166    #[cfg_attr(
167        feature = "profiling",
168        tracing::instrument(
169            name = "llm.chat",
170            skip_all,
171            fields(provider = self.name(), model = self.model_identifier())
172        )
173    )]
174    async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
175        self.inner.chat(messages).await
176    }
177
178    async fn chat_with_extras(
179        &self,
180        messages: &[Message],
181    ) -> Result<(String, ChatExtras), LlmError> {
182        self.inner.chat_with_extras(messages).await
183    }
184
185    #[cfg_attr(
186        feature = "profiling",
187        tracing::instrument(
188            name = "llm.chat_stream",
189            skip_all,
190            fields(provider = self.name(), model = self.model_identifier())
191        )
192    )]
193    async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
194        self.inner.chat_stream(messages).await
195    }
196
197    fn supports_streaming(&self) -> bool {
198        self.inner.supports_streaming()
199    }
200
201    #[cfg_attr(
202        feature = "profiling",
203        tracing::instrument(
204            name = "llm.embed",
205            skip_all,
206            fields(provider = self.name(), model = self.model_identifier())
207        )
208    )]
209    async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
210        self.inner.embed(text).await
211    }
212
213    #[cfg_attr(
214        feature = "profiling",
215        tracing::instrument(
216            name = "llm.embed_batch",
217            skip_all,
218            fields(provider = self.name(), model = self.model_identifier())
219        )
220    )]
221    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, LlmError> {
222        self.inner.embed_batch(texts).await
223    }
224
225    fn supports_embeddings(&self) -> bool {
226        self.inner.supports_embeddings()
227    }
228
229    fn name(&self) -> &str {
230        &self.provider_name
231    }
232
233    fn model_identifier(&self) -> &str {
234        self.inner.model_identifier()
235    }
236
237    fn list_models(&self) -> Vec<String> {
238        self.inner.list_models()
239    }
240
241    fn supports_structured_output(&self) -> bool {
242        self.inner.supports_structured_output()
243    }
244
245    async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
246    where
247        T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
248        Self: Sized,
249    {
250        self.inner.chat_typed(messages).await
251    }
252
253    #[cfg_attr(
254        feature = "profiling",
255        tracing::instrument(
256            name = "llm.chat_with_tools",
257            skip_all,
258            fields(provider = self.name(), model = self.model_identifier(), tool_count = tools.len())
259        )
260    )]
261    async fn chat_with_tools(
262        &self,
263        messages: &[Message],
264        tools: &[ToolDefinition],
265    ) -> Result<ChatResponse, LlmError> {
266        self.inner.chat_with_tools(messages, tools).await
267    }
268
269    fn last_cache_usage(&self) -> Option<(u64, u64)> {
270        self.inner.last_cache_usage()
271    }
272
273    fn last_usage(&self) -> Option<(u64, u64)> {
274        self.inner.last_usage()
275    }
276
277    fn debug_request_json(
278        &self,
279        messages: &[Message],
280        tools: &[ToolDefinition],
281        stream: bool,
282    ) -> serde_json::Value {
283        self.inner.debug_request_json(messages, tools, stream)
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider labelled `"groq"` without an embedding model; shared by most tests.
    fn groq_provider() -> CompatibleProvider {
        let cfg = CompatibleConfig {
            provider_name: "groq".into(),
            api_key: "key".into(),
            base_url: "https://api.groq.com/openai/v1".into(),
            model: "llama-3.3-70b".into(),
            max_tokens: 4096,
            embedding_model: None,
        };
        CompatibleProvider::new(cfg)
    }

    /// Provider labelled `"test"` pointing at `base_url`, with an optional embedding model.
    fn local_provider(base_url: &str, embedding_model: Option<&str>) -> CompatibleProvider {
        let cfg = CompatibleConfig {
            provider_name: "test".into(),
            api_key: "key".into(),
            base_url: base_url.into(),
            model: "m".into(),
            max_tokens: 100,
            embedding_model: embedding_model.map(Into::into),
        };
        CompatibleProvider::new(cfg)
    }

    #[test]
    fn name_returns_custom_provider_name() {
        let provider = groq_provider();
        assert_eq!(provider.name(), "groq");
    }

    #[test]
    fn context_window_returns_none() {
        let provider = groq_provider();
        assert!(provider.context_window().is_none());
    }

    #[test]
    fn supports_streaming_delegates() {
        let provider = groq_provider();
        assert!(provider.supports_streaming());
    }

    #[test]
    fn supports_embeddings_without_model() {
        let provider = groq_provider();
        assert!(!provider.supports_embeddings());
    }

    #[test]
    fn supports_embeddings_with_model() {
        let provider = local_provider("http://localhost", Some("embed-model"));
        assert!(provider.supports_embeddings());
    }

    #[test]
    fn clone_preserves_name() {
        let original = groq_provider();
        let copy = original.clone();
        assert_eq!(copy.name(), "groq");
    }

    #[test]
    fn debug_contains_provider_name() {
        let rendered = format!("{:?}", groq_provider());
        assert!(rendered.contains("groq"));
        assert!(rendered.contains("CompatibleProvider"));
    }

    #[tokio::test]
    async fn chat_unreachable_errors() {
        // Port 1 on loopback is used as an endpoint nothing is expected to listen on.
        let provider = local_provider("http://127.0.0.1:1", None);
        let conversation = [Message::from_legacy(crate::provider::Role::User, "hello")];
        let outcome = provider.chat(&conversation).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn embed_without_model_errors() {
        let provider = groq_provider();
        let outcome = provider.embed("test").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn last_usage_initially_none() {
        let provider = groq_provider();
        assert!(provider.last_usage().is_none());
    }

    #[test]
    fn with_output_schema_forwarding_does_not_panic() {
        // Smoke test: the builder must compile, not panic, and keep the provider name.
        let provider = groq_provider().with_output_schema_forwarding(true, 512, usize::MAX);
        assert_eq!(provider.name(), "groq");
    }
}