1use std::fmt;
24
25use crate::error::LlmError;
26use crate::openai::{OpenAiConfig, OpenAiProvider};
27use crate::provider::{
28 ChatExtras, ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, StatusTx,
29 ToolDefinition,
30};
31
/// Settings used to build a [`CompatibleProvider`].
///
/// Every field except `provider_name` is handed unchanged to the wrapped
/// OpenAI client configuration by [`CompatibleProvider::new`].
///
/// `Debug` is implemented by hand so that `api_key` is never written to logs
/// or panic messages; all other fields are printed as usual.
#[derive(Clone)]
pub struct CompatibleConfig {
    /// Name reported by the provider's `name()` method (e.g. `"groq"`).
    pub provider_name: String,
    /// Credential forwarded to the wrapped client. Redacted in `Debug` output.
    pub api_key: String,
    /// Base URL forwarded to the wrapped client.
    pub base_url: String,
    /// Model identifier forwarded to the wrapped client.
    pub model: String,
    /// Token limit forwarded to the wrapped client.
    pub max_tokens: u32,
    /// Embedding model forwarded to the wrapped client; `None` leaves it unset.
    pub embedding_model: Option<String>,
}

impl fmt::Debug for CompatibleConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CompatibleConfig")
            .field("provider_name", &self.provider_name)
            // Never print the real credential: configs routinely end up in logs.
            .field("api_key", &"<redacted>")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("max_tokens", &self.max_tokens)
            .field("embedding_model", &self.embedding_model)
            .finish()
    }
}
67
68pub struct CompatibleProvider {
73 inner: OpenAiProvider,
74 provider_name: String,
76}
77
78impl CompatibleProvider {
79 #[must_use]
81 pub fn new(cfg: CompatibleConfig) -> Self {
82 let provider_name = cfg.provider_name;
83 let inner = OpenAiProvider::new(OpenAiConfig {
84 api_key: cfg.api_key,
85 base_url: cfg.base_url,
86 model: cfg.model,
87 max_tokens: cfg.max_tokens,
88 embedding_model: cfg.embedding_model,
89 reasoning_effort: None,
90 });
91 Self {
92 inner,
93 provider_name,
94 }
95 }
96}
97
98impl fmt::Debug for CompatibleProvider {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("CompatibleProvider")
101 .field("provider_name", &self.provider_name)
102 .field("inner", &self.inner)
103 .finish_non_exhaustive()
104 }
105}
106
107impl Clone for CompatibleProvider {
108 fn clone(&self) -> Self {
109 Self {
110 inner: self.inner.clone(),
111 provider_name: self.provider_name.clone(),
112 }
113 }
114}
115
116impl CompatibleProvider {
117 pub async fn list_models_remote(
123 &self,
124 ) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
125 self.inner.list_models_remote().await
126 }
127}
128
129impl CompatibleProvider {
130 pub fn set_status_tx(&mut self, tx: StatusTx) {
132 self.inner.status_tx = Some(tx);
133 }
134
135 #[must_use]
137 pub fn with_generation_overrides(mut self, overrides: GenerationOverrides) -> Self {
138 self.inner = self.inner.with_generation_overrides(overrides);
139 self
140 }
141
142 #[must_use]
148 pub fn with_output_schema_forwarding(
149 mut self,
150 enabled: bool,
151 hint_bytes: usize,
152 max_description_bytes: usize,
153 ) -> Self {
154 self.inner =
155 self.inner
156 .with_output_schema_forwarding(enabled, hint_bytes, max_description_bytes);
157 self
158 }
159}
160
161impl LlmProvider for CompatibleProvider {
162 fn context_window(&self) -> Option<usize> {
163 None
164 }
165
166 #[cfg_attr(
167 feature = "profiling",
168 tracing::instrument(
169 name = "llm.chat",
170 skip_all,
171 fields(provider = self.name(), model = self.model_identifier())
172 )
173 )]
174 async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
175 self.inner.chat(messages).await
176 }
177
178 async fn chat_with_extras(
179 &self,
180 messages: &[Message],
181 ) -> Result<(String, ChatExtras), LlmError> {
182 self.inner.chat_with_extras(messages).await
183 }
184
185 #[cfg_attr(
186 feature = "profiling",
187 tracing::instrument(
188 name = "llm.chat_stream",
189 skip_all,
190 fields(provider = self.name(), model = self.model_identifier())
191 )
192 )]
193 async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
194 self.inner.chat_stream(messages).await
195 }
196
197 fn supports_streaming(&self) -> bool {
198 self.inner.supports_streaming()
199 }
200
201 #[cfg_attr(
202 feature = "profiling",
203 tracing::instrument(
204 name = "llm.embed",
205 skip_all,
206 fields(provider = self.name(), model = self.model_identifier())
207 )
208 )]
209 async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
210 self.inner.embed(text).await
211 }
212
213 #[cfg_attr(
214 feature = "profiling",
215 tracing::instrument(
216 name = "llm.embed_batch",
217 skip_all,
218 fields(provider = self.name(), model = self.model_identifier())
219 )
220 )]
221 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, LlmError> {
222 self.inner.embed_batch(texts).await
223 }
224
225 fn supports_embeddings(&self) -> bool {
226 self.inner.supports_embeddings()
227 }
228
229 fn name(&self) -> &str {
230 &self.provider_name
231 }
232
233 fn model_identifier(&self) -> &str {
234 self.inner.model_identifier()
235 }
236
237 fn list_models(&self) -> Vec<String> {
238 self.inner.list_models()
239 }
240
241 fn supports_structured_output(&self) -> bool {
242 self.inner.supports_structured_output()
243 }
244
245 async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
246 where
247 T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
248 Self: Sized,
249 {
250 self.inner.chat_typed(messages).await
251 }
252
253 #[cfg_attr(
254 feature = "profiling",
255 tracing::instrument(
256 name = "llm.chat_with_tools",
257 skip_all,
258 fields(provider = self.name(), model = self.model_identifier(), tool_count = tools.len())
259 )
260 )]
261 async fn chat_with_tools(
262 &self,
263 messages: &[Message],
264 tools: &[ToolDefinition],
265 ) -> Result<ChatResponse, LlmError> {
266 self.inner.chat_with_tools(messages, tools).await
267 }
268
269 fn last_cache_usage(&self) -> Option<(u64, u64)> {
270 self.inner.last_cache_usage()
271 }
272
273 fn last_usage(&self) -> Option<(u64, u64)> {
274 self.inner.last_usage()
275 }
276
277 fn debug_request_json(
278 &self,
279 messages: &[Message],
280 tools: &[ToolDefinition],
281 stream: bool,
282 ) -> serde_json::Value {
283 self.inner.debug_request_json(messages, tools, stream)
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for a throwaway provider named "test"; only the endpoint and
    /// the embedding model differ between the tests that use it.
    fn local_config(base_url: &str, embedding_model: Option<&str>) -> CompatibleConfig {
        CompatibleConfig {
            provider_name: String::from("test"),
            api_key: String::from("key"),
            base_url: String::from(base_url),
            model: String::from("m"),
            max_tokens: 100,
            embedding_model: embedding_model.map(String::from),
        }
    }

    /// Provider named "groq" with no embedding model configured.
    fn test_provider() -> CompatibleProvider {
        let cfg = CompatibleConfig {
            provider_name: String::from("groq"),
            api_key: String::from("key"),
            base_url: String::from("https://api.groq.com/openai/v1"),
            model: String::from("llama-3.3-70b"),
            max_tokens: 4096,
            embedding_model: None,
        };
        CompatibleProvider::new(cfg)
    }

    #[test]
    fn name_returns_custom_provider_name() {
        let provider = test_provider();
        assert_eq!(provider.name(), "groq");
    }

    #[test]
    fn context_window_returns_none() {
        let provider = test_provider();
        assert!(provider.context_window().is_none());
    }

    #[test]
    fn supports_streaming_delegates() {
        let provider = test_provider();
        assert!(provider.supports_streaming());
    }

    #[test]
    fn supports_embeddings_without_model() {
        let provider = test_provider();
        assert!(!provider.supports_embeddings());
    }

    #[test]
    fn supports_embeddings_with_model() {
        let cfg = local_config("http://localhost", Some("embed-model"));
        let provider = CompatibleProvider::new(cfg);
        assert!(provider.supports_embeddings());
    }

    #[test]
    fn clone_preserves_name() {
        let original = test_provider();
        let copy = original.clone();
        assert_eq!(copy.name(), "groq");
    }

    #[test]
    fn debug_contains_provider_name() {
        let rendered = format!("{:?}", test_provider());
        assert!(rendered.contains("groq"));
        assert!(rendered.contains("CompatibleProvider"));
    }

    #[tokio::test]
    async fn chat_unreachable_errors() {
        // Port 1 on loopback: the request is expected to fail.
        let provider = CompatibleProvider::new(local_config("http://127.0.0.1:1", None));
        let conversation = [Message::from_legacy(crate::provider::Role::User, "hello")];
        let outcome = provider.chat(&conversation).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn embed_without_model_errors() {
        let provider = test_provider();
        let outcome = provider.embed("test").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn last_usage_initially_none() {
        let provider = test_provider();
        assert!(provider.last_usage().is_none());
    }

    #[test]
    fn with_output_schema_forwarding_does_not_panic() {
        let provider = test_provider().with_output_schema_forwarding(true, 512, usize::MAX);
        // The builder must hand back a usable provider with its name intact.
        assert_eq!(provider.name(), "groq");
    }
}