Skip to main content

sqlite_graphrag/
embedding_api.rs

//! HTTP client for the OpenRouter embeddings API.
//!
//! Sends embedding requests to the OpenAI-compatible endpoint at
//! `openrouter.ai/api/v1/embeddings` and returns dense `Vec<f32>`
//! vectors. Transient failures are retried: 5xx responses, connection
//! errors, and unparseable success bodies with exponential backoff plus
//! jitter, and 429 after the `Retry-After` delay. Permanent errors
//! (401, 400, 404) and request timeouts abort immediately.
8
9use crate::errors::AppError;
10use secrecy::{ExposeSecret, SecretBox};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
/// Full URL of the OpenRouter embeddings endpoint every request is POSTed to.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Overall per-request timeout, in seconds, configured on the HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Connection-establishment timeout, in seconds, configured on the HTTP client.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Maximum number of texts `embed_batch` sends in a single request.
const MAX_BATCH_SIZE: usize = 32;
/// Total number of attempts made per request, the first try included
/// (the retry loop iterates `0..MAX_RETRIES`).
const MAX_RETRIES: u32 = 4;
19
/// JSON request body serialized and POSTed to the embeddings endpoint.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Identifier of the embedding model to use.
    model: &'a str,
    /// The text, or batch of texts, to embed.
    input: EmbeddingInput<'a>,
    /// Requested output dimensionality; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding of the returned vectors; this file always sends `"float"`.
    encoding_format: &'a str,
    /// Optional input-type hint; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<&'a str>,
}
30
/// Value of the request's `input` field.
///
/// Because the enum is `untagged`, the variant name never appears in the
/// JSON: `Single` serializes as a bare string and `Batch` as an array of
/// strings.
#[derive(Serialize)]
#[serde(untagged)]
enum EmbeddingInput<'a> {
    /// One text to embed.
    Single(&'a str),
    /// Several texts to embed in one request.
    Batch(Vec<&'a str>),
}
37
/// Subset of the embeddings response body that this client reads.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingData>,
}
42
/// A single embedding entry inside [`EmbeddingResponse`].
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector as returned by the API.
    embedding: Vec<f32>,
    /// Index reported by the API for this entry; `embed_batch` sorts on it
    /// to order results.
    index: usize,
}
48
/// Client for requesting embeddings from OpenRouter with a fixed model and
/// output dimension.
pub struct OpenRouterClient {
    /// HTTP client used for every request, built once in `new`.
    client: reqwest::Client,
    /// API key sent as a bearer token; kept in a `SecretBox`.
    api_key: SecretBox<String>,
    /// Model identifier sent with every request.
    model: String,
    /// Length every returned embedding is cut down to (and must reach).
    dim: usize,
    /// Whether `dimensions` is included in requests, as decided by
    /// `model_supports_mrl` for `model`.
    supports_mrl: bool,
    /// Input-type hint derived from `model` by `model_default_input_type`.
    default_input_type: Option<&'static str>,
}
57
/// Reports whether `model` belongs to a family this client treats as
/// MRL-capable, i.e. one for which requests carry a `dimensions` field.
///
/// Matching is a case-sensitive substring test against the model id.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_FAMILIES: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];

    MRL_FAMILIES.iter().any(|&family| model.contains(family))
}
65
/// Picks the `input_type` hint to use by default for `model`.
///
/// Returns `Some("passage")` for ids containing `llama-nemotron-embed`,
/// `None` for ids containing `mistral-embed`, and
/// `Some("search_document")` for everything else. The arms are tried in
/// that order, so the nemotron rule wins if both substrings are present.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    match model {
        id if id.contains("llama-nemotron-embed") => Some("passage"),
        id if id.contains("mistral-embed") => None,
        _ => Some("search_document"),
    }
}
75
76impl OpenRouterClient {
77    pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
78        let client = reqwest::Client::builder()
79            .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
80            .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
81            .user_agent("sqlite-graphrag/1.0.93")
82            .build()
83            .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
84
85        let supports_mrl = model_supports_mrl(&model);
86        let default_input_type = model_default_input_type(&model);
87
88        Ok(Self {
89            client,
90            api_key,
91            model,
92            dim,
93            supports_mrl,
94            default_input_type,
95        })
96    }
97
98    pub fn default_input_type(&self) -> Option<&'static str> {
99        self.default_input_type
100    }
101
102    pub async fn embed_single(
103        &self,
104        text: &str,
105        input_type: Option<&str>,
106    ) -> Result<Vec<f32>, AppError> {
107        let request = EmbeddingRequest {
108            model: &self.model,
109            input: EmbeddingInput::Single(text),
110            dimensions: if self.supports_mrl {
111                Some(self.dim)
112            } else {
113                None
114            },
115            encoding_format: "float",
116            input_type,
117        };
118
119        let response = self.execute_with_retry(&request).await?;
120
121        let embedding = response
122            .data
123            .into_iter()
124            .next()
125            .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
126            .embedding;
127
128        self.truncate_embedding(embedding)
129    }
130
131    pub async fn embed_batch(
132        &self,
133        texts: &[&str],
134        input_type: Option<&str>,
135    ) -> Result<Vec<Vec<f32>>, AppError> {
136        if texts.is_empty() {
137            return Ok(Vec::new());
138        }
139
140        let mut all = Vec::with_capacity(texts.len());
141
142        for chunk in texts.chunks(MAX_BATCH_SIZE) {
143            let request = EmbeddingRequest {
144                model: &self.model,
145                input: EmbeddingInput::Batch(chunk.to_vec()),
146                dimensions: if self.supports_mrl {
147                    Some(self.dim)
148                } else {
149                    None
150                },
151                encoding_format: "float",
152                input_type,
153            };
154
155            let response = self.execute_with_retry(&request).await?;
156
157            if response.data.len() != chunk.len() {
158                return Err(AppError::Embedding(format!(
159                    "expected {} embeddings, got {}",
160                    chunk.len(),
161                    response.data.len()
162                )));
163            }
164
165            let mut sorted = response.data;
166            sorted.sort_by_key(|d| d.index);
167
168            for d in sorted {
169                all.push(self.truncate_embedding(d.embedding)?);
170            }
171        }
172
173        Ok(all)
174    }
175
176    fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
177        if embedding.len() < self.dim {
178            return Err(AppError::Embedding(format!(
179                "embedding dimension {} < requested {}",
180                embedding.len(),
181                self.dim
182            )));
183        }
184        if embedding.len() == self.dim {
185            Ok(embedding)
186        } else {
187            Ok(embedding[..self.dim].to_vec())
188        }
189    }
190
191    async fn execute_with_retry(
192        &self,
193        request: &EmbeddingRequest<'_>,
194    ) -> Result<EmbeddingResponse, AppError> {
195        let mut last_err = None;
196
197        for attempt in 0..MAX_RETRIES {
198            let result = self
199                .client
200                .post(OPENROUTER_EMBEDDINGS_URL)
201                .header(
202                    "Authorization",
203                    format!("Bearer {}", self.api_key.expose_secret()),
204                )
205                .json(request)
206                .send()
207                .await;
208
209            let resp = match result {
210                Ok(r) => r,
211                Err(e) if e.is_timeout() => {
212                    return Err(AppError::Embedding("OpenRouter request timed out".into()));
213                }
214                Err(e) => {
215                    last_err = Some(AppError::Embedding(format!("HTTP request failed: {e}")));
216                    Self::backoff(attempt).await;
217                    continue;
218                }
219            };
220
221            let status = resp.status();
222
223            if status.is_success() {
224                let body = resp.text().await.map_err(|e| {
225                    AppError::Embedding(format!("failed to read response body: {e}"))
226                })?;
227                match serde_json::from_str::<EmbeddingResponse>(&body) {
228                    Ok(parsed) => return Ok(parsed),
229                    Err(e) => {
230                        tracing::warn!(
231                            attempt,
232                            body_len = body.len(),
233                            "HTTP 200 but parse failed (retrying): {e}"
234                        );
235                        last_err = Some(AppError::Embedding(format!(
236                            "failed to parse embedding response: {e}"
237                        )));
238                        Self::backoff(attempt).await;
239                        continue;
240                    }
241                }
242            }
243
244            if status.as_u16() == 401 {
245                return Err(AppError::Embedding(
246                    "invalid OpenRouter API key (HTTP 401)".into(),
247                ));
248            }
249
250            if status.as_u16() == 400 || status.as_u16() == 404 {
251                let body = resp.text().await.unwrap_or_default();
252                return Err(AppError::Embedding(format!(
253                    "OpenRouter returned {status}: {body}"
254                )));
255            }
256
257            if status.as_u16() == 429 {
258                let retry_after = resp
259                    .headers()
260                    .get("retry-after")
261                    .and_then(|v| v.to_str().ok())
262                    .and_then(|v| v.parse::<u64>().ok())
263                    .unwrap_or(2);
264                tracing::warn!(
265                    attempt,
266                    retry_after_secs = retry_after,
267                    "OpenRouter rate limited, waiting"
268                );
269                tokio::time::sleep(Duration::from_secs(retry_after)).await;
270                continue;
271            }
272
273            if status.is_server_error() {
274                tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
275                last_err = Some(AppError::Embedding(format!(
276                    "OpenRouter server error: {status}"
277                )));
278                Self::backoff(attempt).await;
279                continue;
280            }
281
282            let body = resp.text().await.unwrap_or_default();
283            return Err(AppError::Embedding(format!(
284                "unexpected HTTP {status}: {body}"
285            )));
286        }
287
288        Err(last_err.unwrap_or_else(|| {
289            AppError::Embedding("max retries exceeded for OpenRouter request".into())
290        }))
291    }
292
293    async fn backoff(attempt: u32) {
294        let base_ms = 1000u64 * 2u64.pow(attempt);
295        let jitter = fastrand::u64(0..500);
296        let sleep_ms = base_ms + jitter;
297        tracing::debug!(attempt, sleep_ms, "exponential backoff");
298        tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Model ids from the known MRL families are detected; others are not.
    #[test]
    fn test_supports_mrl_detection() {
        let supported = [
            "qwen/qwen3-embedding-8b",
            "qwen/qwen3-embedding-4b",
            "openai/text-embedding-3-small",
            "openai/text-embedding-3-large",
            "google/gemini-embedding-001",
            "google/gemini-embedding-2",
            "nvidia/llama-nemotron-embed-vl-1b-v2:free",
            "baai/bge-m3",
        ];
        for model in supported {
            assert!(model_supports_mrl(model), "model: {model}");
        }

        let unsupported = [
            "perplexity/pplx-embed-v1-0.6b",
            "mistralai/mistral-embed-2312",
            "some-random-model",
        ];
        for model in unsupported {
            assert!(!model_supports_mrl(model), "model: {model}");
        }
    }

    /// Each model family maps to its expected default input-type hint.
    #[test]
    fn test_model_default_input_type() {
        let cases: [(&str, Option<&'static str>); 5] = [
            ("nvidia/llama-nemotron-embed-vl-1b-v2:free", Some("passage")),
            ("mistralai/mistral-embed-2312", None),
            ("qwen/qwen3-embedding-8b", Some("search_document")),
            ("openai/text-embedding-3-small", Some("search_document")),
            ("baai/bge-m3", Some("search_document")),
        ];
        for (model, expected) in cases {
            assert_eq!(model_default_input_type(model), expected, "model: {model}");
        }
    }

    /// Longer vectors are cut to `dim`, exact ones pass through, shorter
    /// ones are rejected.
    #[test]
    fn test_truncate_embedding() {
        let key = SecretBox::new(Box::new("test-key".to_string()));
        let client = OpenRouterClient::new(key, "test-model".into(), 3).unwrap();

        let longer = client
            .truncate_embedding(vec![1.0, 2.0, 3.0, 4.0, 5.0])
            .unwrap();
        assert_eq!(longer, vec![1.0, 2.0, 3.0]);

        let same = client.truncate_embedding(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(same, vec![1.0, 2.0, 3.0]);

        let too_short = client.truncate_embedding(vec![1.0, 2.0]);
        assert!(too_short.is_err());
    }
}