Skip to main content

tt_provider_compat/
compat.rs

1//! Shared base for OpenAI-compatible provider adapters.
2//!
3//! Many inference providers (Mistral, Groq, Together AI, OpenRouter) expose an
4//! endpoint that is wire-compatible with OpenAI's `POST /chat/completions` API.
5//! Rather than duplicating HTTP plumbing in each adapter crate, this module
6//! provides [`OpenAICompatibleProvider`] — a single generic implementation that
7//! each adapter instantiates with its own [`CompatConfig`].
8//!
9//! # Billing note
10//!
11//! [`CompatConfig::fee_multiplier`] is stored but **not applied at request time**.
12//! Token counts and raw per-token costs flow through the response unchanged;
13//! the billing layer (in `tt-core`) multiplies by `fee_multiplier` when it
14//! computes the final USD charge displayed on the dashboard. This is intentional:
15//! the adapter should not alter usage numbers, only report them faithfully.
16//! (Tracked as a follow-up in the cost-accounting work item.)
17//!
18//! # Usage
19//!
//! ```rust,no_run
//! use std::collections::HashMap;
//! use tt_provider_compat::{ClientConfig, CompatConfig, OpenAICompatibleProvider};
//!
//! let cfg = CompatConfig {
//!     id: "my-provider",
//!     default_base_url: "https://api.example.com/v1".to_string(),
//!     models: vec![],
//!     pricing_table: HashMap::new(),
//!     fee_multiplier: 1.0,
//!     allow_local: false,
//! };
//! let _provider = OpenAICompatibleProvider::new(ClientConfig::default(), cfg);
//! ```
36
37use std::collections::HashMap;
38
39use async_trait::async_trait;
40use futures::stream::BoxStream;
41use reqwest::Client;
42use tracing::instrument;
43use tt_shared::{
44    filter_extra_headers, validate_provider_url, ChatCompletionChunk, ChatCompletionRequest,
45    ChatCompletionResponse, EmbeddingsRequest, EmbeddingsResponse, ModelInfo, ModelPricing,
46    Provider, ProviderError, RequestContext,
47};
48
49use crate::client::{build_client, ClientConfig};
50use crate::errors::{map_reqwest_error, map_response_error};
51use crate::{stream, translate};
52
53// ---------------------------------------------------------------------------
54// Configuration
55// ---------------------------------------------------------------------------
56
57/// Per-provider configuration for [`OpenAICompatibleProvider`].
58///
59/// Construct this once at startup (or in a lazy static) and pass it to
60/// [`OpenAICompatibleProvider::new`].
61pub struct CompatConfig {
62    /// Stable, lower-case identifier used by the routing and telemetry layers.
63    ///
64    /// Examples: `"mistral"`, `"groq"`, `"together"`, `"openrouter"`.
65    pub id: &'static str,
66
67    /// Default base URL, used when the caller's [`RequestContext`] does not
68    /// supply a `base_url` override in its credentials.
69    pub default_base_url: String,
70
71    /// All models exposed by this provider configuration.
72    pub models: Vec<ModelInfo>,
73
74    /// Pricing keyed by model ID string, mirroring the per-provider tables in
75    /// the OpenAI adapter's `pricing.rs`.
76    pub pricing_table: HashMap<String, ModelPricing>,
77
78    /// Optional fee multiplier stored for the billing layer (e.g. `1.05` for a
79    /// 5% BYOK fee on OpenRouter).
80    ///
81    /// **This value is NOT applied to usage at request time.** The adapter
82    /// faithfully reports raw token counts; the dashboard billing pass applies
83    /// the multiplier when computing the final USD charge. Default: `1.0`.
84    pub fee_multiplier: f64,
85
86    /// When `true`, skip SSRF URL validation for private/loopback addresses.
87    ///
88    /// Set to `true` only for local providers (Ollama, vLLM, LM Studio) that
89    /// legitimately target `http://localhost` or `http://127.0.0.1`. All hosted
90    /// providers must use `false`.
91    pub allow_local: bool,
92}
93
94// ---------------------------------------------------------------------------
95// Provider struct
96// ---------------------------------------------------------------------------
97
98/// Generic OpenAI-compatible chat-completion adapter.
99///
100/// Holds an HTTP client and a [`CompatConfig`] that varies per provider.
101/// All four thin adapter crates (Mistral, Groq, Together, OpenRouter) wrap
102/// this struct and forward every [`Provider`] method to it.
103pub struct OpenAICompatibleProvider {
104    client: Client,
105    cfg: CompatConfig,
106}
107
108impl OpenAICompatibleProvider {
109    /// Construct a new adapter from the given HTTP client configuration and
110    /// provider-specific configuration.
111    ///
112    /// # Panics
113    ///
114    /// Panics if the underlying [`reqwest::Client`] cannot be constructed (very
115    /// rare — only happens with invalid TLS configuration).
116    pub fn new(client_cfg: ClientConfig, cfg: CompatConfig) -> Self {
117        let client = build_client(&client_cfg)
118            .unwrap_or_else(|e| panic!("failed to build HTTP client for {}: {e}", cfg.id));
119        Self { client, cfg }
120    }
121
122    /// The fee multiplier stored in this provider's config.
123    ///
124    /// Exposed so that the billing layer can retrieve it without accessing
125    /// private fields. See [`CompatConfig::fee_multiplier`] for semantics.
126    pub fn fee_multiplier(&self) -> f64 {
127        self.cfg.fee_multiplier
128    }
129
130    /// Resolve the base URL: prefer the credential override, fall back to the
131    /// compiled-in default.
132    fn base_url<'a>(&'a self, ctx: &'a RequestContext) -> &'a str {
133        ctx.credentials
134            .base_url
135            .as_deref()
136            .unwrap_or(self.cfg.default_base_url.as_str())
137    }
138}
139
140// ---------------------------------------------------------------------------
141// Provider trait implementation
142// ---------------------------------------------------------------------------
143
144#[async_trait]
145impl Provider for OpenAICompatibleProvider {
146    fn id(&self) -> &'static str {
147        self.cfg.id
148    }
149
150    fn models(&self) -> Vec<ModelInfo> {
151        self.cfg.models.clone()
152    }
153
154    fn pricing(&self, model: &str) -> Option<ModelPricing> {
155        self.cfg.pricing_table.get(model).cloned()
156    }
157
158    fn dropped_params(&self, req: &tt_shared::ChatCompletionRequest) -> Vec<String> {
159        crate::translate::dropped_params(req)
160    }
161
162    /// Non-streaming chat completion via `POST /chat/completions`.
163    ///
164    /// Translates the canonical request, sends it to the provider's endpoint
165    /// (resolved from credentials or the default base URL), and maps any HTTP
166    /// error to the appropriate [`ProviderError`] variant.
167    #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
168    async fn chat_completion(
169        &self,
170        req: ChatCompletionRequest,
171        ctx: &RequestContext,
172    ) -> Result<ChatCompletionResponse, ProviderError> {
173        let base_url = self.base_url(ctx);
174        validate_provider_url(base_url, self.cfg.allow_local)
175            .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
176
177        let url = format!("{base_url}/chat/completions");
178        let body = translate::translate_request(req)?;
179
180        let mut rb = self
181            .client
182            .post(&url)
183            .header(
184                "Authorization",
185                format!("Bearer {}", ctx.credentials.api_key.expose()),
186            )
187            .header("Content-Type", "application/json")
188            .json(&body);
189
190        for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
191            rb = rb.header(name, value);
192        }
193
194        let response = rb.send().await.map_err(map_reqwest_error)?;
195
196        let status = response.status().as_u16();
197        let retry_after = response
198            .headers()
199            .get("retry-after")
200            .and_then(|v| v.to_str().ok())
201            .map(|s| s.to_string());
202
203        let response_text = response.text().await.map_err(map_reqwest_error)?;
204
205        if status >= 400 {
206            return Err(map_response_error(
207                status,
208                &response_text,
209                retry_after.as_deref(),
210            ));
211        }
212
213        translate::deserialize_response(&response_text)
214    }
215
216    /// Streaming chat completion via `POST /chat/completions` with `stream: true`.
217    ///
218    /// Returns a [`BoxStream`] that yields [`ChatCompletionChunk`] values parsed
219    /// from OpenAI-compatible SSE events. HTTP errors before the first byte are
220    /// surfaced as `Err` before any chunk is produced.
221    #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
222    async fn chat_completion_stream(
223        &self,
224        req: ChatCompletionRequest,
225        ctx: &RequestContext,
226    ) -> Result<BoxStream<'static, Result<ChatCompletionChunk, ProviderError>>, ProviderError> {
227        let base_url = self.base_url(ctx);
228        validate_provider_url(base_url, self.cfg.allow_local)
229            .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
230
231        let base_url = base_url.to_string();
232        let client = self.client.clone();
233        stream::stream_chat_completion(client, &base_url, req, ctx).await
234    }
235
236    /// Embeddings via `POST /embeddings`.
237    ///
238    /// Sends the canonical [`EmbeddingsRequest`] to the provider's `/embeddings`
239    /// endpoint (resolved from credentials or the default base URL). All
240    /// OpenAI-compatible providers (Mistral, Together, etc.) expose the same
241    /// `/embeddings` path and wire format, so no translation is needed beyond
242    /// what [`translate::translate_embeddings_request`] provides.
243    #[instrument(skip(self, ctx), fields(provider = %self.cfg.id, model = %req.model))]
244    async fn embeddings(
245        &self,
246        req: EmbeddingsRequest,
247        ctx: &RequestContext,
248    ) -> Result<EmbeddingsResponse, ProviderError> {
249        let base_url = self.base_url(ctx);
250        validate_provider_url(base_url, self.cfg.allow_local)
251            .map_err(|e| ProviderError::InvalidRequest(format!("blocked provider URL: {e}")))?;
252
253        let url = format!("{base_url}/embeddings");
254        let body = translate::translate_embeddings_request(req)?;
255
256        let mut rb = self
257            .client
258            .post(&url)
259            .header(
260                "Authorization",
261                format!("Bearer {}", ctx.credentials.api_key.expose()),
262            )
263            .header("Content-Type", "application/json")
264            .json(&body);
265
266        for (name, value) in &filter_extra_headers(&ctx.credentials.extra_headers) {
267            rb = rb.header(name, value);
268        }
269
270        let response = rb.send().await.map_err(map_reqwest_error)?;
271
272        let status = response.status().as_u16();
273        let retry_after = response
274            .headers()
275            .get("retry-after")
276            .and_then(|v| v.to_str().ok())
277            .map(|s| s.to_string());
278
279        let response_text = response.text().await.map_err(map_reqwest_error)?;
280
281        if status >= 400 {
282            return Err(map_response_error(
283                status,
284                &response_text,
285                retry_after.as_deref(),
286            ));
287        }
288
289        translate::deserialize_embeddings_response(&response_text)
290    }
291}