1use crate::errors::AppError;
10use secrecy::{ExposeSecret, SecretBox};
11use serde::{Deserialize, Serialize};
12use std::time::Duration;
13
/// Embeddings endpoint; every request made by this client is POSTed here.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Overall per-request timeout, in seconds, configured on the HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Timeout, in seconds, for establishing the connection.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Maximum number of texts sent in one batch request; larger inputs are split into chunks.
const MAX_BATCH_SIZE: usize = 32;
/// Total number of attempts per request (the first try plus up to three retries),
/// despite the name: the retry loop runs `0..MAX_RETRIES`.
const MAX_RETRIES: u32 = 4;
19
/// JSON body POSTed to the embeddings endpoint.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Model identifier, e.g. `qwen/qwen3-embedding-8b`.
    model: &'a str,
    /// One text or a batch of texts to embed.
    input: EmbeddingInput<'a>,
    /// Requested output dimensionality; omitted from the JSON when `None`.
    /// The client only sets it for models detected as MRL-capable.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding of the returned vectors; this client always sends `"float"`.
    encoding_format: &'a str,
    /// Optional hint such as `"search_document"` or `"passage"`; omitted from the JSON
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<&'a str>,
}
30
/// Value of the request's `input` field.
///
/// Serialized untagged: `Single` becomes a bare JSON string and `Batch` a JSON array of
/// strings.
#[derive(Serialize)]
#[serde(untagged)]
enum EmbeddingInput<'a> {
    /// One text; used by `embed_single`.
    Single(&'a str),
    /// Several texts; used by `embed_batch`, one entry per text in the chunk.
    Batch(Vec<&'a str>),
}
37
/// Successful embeddings payload returned by `execute_with_retry`.
///
/// Built from `EmbeddingEnvelope::data` at runtime; the `Deserialize` derive is also
/// exercised directly by the tests.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per input text; `embed_batch` sorts entries by `index` before use.
    data: Vec<EmbeddingData>,
}
42
/// A single vector from the response.
#[derive(Deserialize)]
struct EmbeddingData {
    /// The vector as returned; it may be longer than the configured dimension and is
    /// cut down by `truncate_embedding`.
    embedding: Vec<f32>,
    /// Presumably the position of the matching text within the request's `input`;
    /// `embed_batch` sorts by it to restore input order.
    index: usize,
}
48
/// Lenient view of an HTTP 200 body in which either `data` or `error` may be present.
///
/// Parsing into this instead of `EmbeddingResponse` lets an `error` object inside a 200
/// body be reported as a provider error rather than as a "missing field" parse failure.
#[derive(Deserialize)]
struct EmbeddingEnvelope {
    /// Embedding entries, when the request succeeded.
    #[serde(default)]
    data: Option<Vec<EmbeddingData>>,
    /// Error object, when the provider reported a failure.
    #[serde(default)]
    error: Option<ApiError>,
}
63
/// Error object that may appear in a response body.
#[derive(Deserialize)]
struct ApiError {
    /// Error code, which may be a JSON number or a JSON string; kept as raw JSON and
    /// normalized by `code_string`.
    #[serde(default)]
    code: Option<serde_json::Value>,
    /// Human-readable description; empty when the field is absent.
    #[serde(default)]
    message: String,
}
75
76impl ApiError {
77 fn code_string(&self) -> String {
80 match &self.code {
81 Some(serde_json::Value::String(s)) => s.clone(),
82 Some(other) => other.to_string(),
83 None => "unknown".to_string(),
84 }
85 }
86}
87
/// Async client for the OpenRouter embeddings endpoint.
///
/// Every vector it returns is cut to exactly `dim` components; vectors shorter than
/// `dim` are rejected with an error.
pub struct OpenRouterClient {
    /// HTTP client configured with timeouts and the user agent.
    client: reqwest::Client,
    /// API key, sent in the `Authorization` header of every request.
    api_key: SecretBox<String>,
    /// Model identifier sent with every request.
    model: String,
    /// Target embedding dimensionality.
    dim: usize,
    /// Whether the `dimensions` request field is sent; derived from the model name.
    supports_mrl: bool,
    /// Default `input_type` for this model (from `model_default_input_type`). It is only
    /// exposed via `default_input_type()`, not applied automatically to requests.
    default_input_type: Option<&'static str>,
}
96
/// Reports whether `model` belongs to a family that accepts a reduced `dimensions`
/// value (Matryoshka-style embeddings), judged by substring match on the model id.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_FAMILIES: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];
    MRL_FAMILIES.iter().any(|&family| model.contains(family))
}
104
/// Chooses the default `input_type` hint for `model`.
///
/// Nemotron embedders get `"passage"`, Mistral embedders get no hint at all, and every
/// other model gets `"search_document"`. Checks are substring matches, applied in that
/// order.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    match model {
        m if m.contains("llama-nemotron-embed") => Some("passage"),
        m if m.contains("mistral-embed") => None,
        _ => Some("search_document"),
    }
}
114
115impl OpenRouterClient {
116 pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
117 let client = reqwest::Client::builder()
118 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
119 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
120 .user_agent("sqlite-graphrag/1.0.96")
121 .build()
122 .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
123
124 let supports_mrl = model_supports_mrl(&model);
125 let default_input_type = model_default_input_type(&model);
126
127 Ok(Self {
128 client,
129 api_key,
130 model,
131 dim,
132 supports_mrl,
133 default_input_type,
134 })
135 }
136
137 pub fn default_input_type(&self) -> Option<&'static str> {
138 self.default_input_type
139 }
140
141 pub async fn embed_single(
142 &self,
143 text: &str,
144 input_type: Option<&str>,
145 ) -> Result<Vec<f32>, AppError> {
146 crate::memory_guard::check_embedding_input_size(text)?;
150
151 let request = EmbeddingRequest {
152 model: &self.model,
153 input: EmbeddingInput::Single(text),
154 dimensions: if self.supports_mrl {
155 Some(self.dim)
156 } else {
157 None
158 },
159 encoding_format: "float",
160 input_type,
161 };
162
163 let response = self.execute_with_retry(&request).await?;
164
165 let embedding = response
166 .data
167 .into_iter()
168 .next()
169 .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
170 .embedding;
171
172 self.truncate_embedding(embedding)
173 }
174
175 pub async fn embed_batch(
176 &self,
177 texts: &[&str],
178 input_type: Option<&str>,
179 ) -> Result<Vec<Vec<f32>>, AppError> {
180 if texts.is_empty() {
181 return Ok(Vec::new());
182 }
183
184 for text in texts {
188 crate::memory_guard::check_embedding_input_size(text)?;
189 }
190
191 let mut all = Vec::with_capacity(texts.len());
192
193 for chunk in texts.chunks(MAX_BATCH_SIZE) {
194 let request = EmbeddingRequest {
195 model: &self.model,
196 input: EmbeddingInput::Batch(chunk.to_vec()),
197 dimensions: if self.supports_mrl {
198 Some(self.dim)
199 } else {
200 None
201 },
202 encoding_format: "float",
203 input_type,
204 };
205
206 let response = self.execute_with_retry(&request).await?;
207
208 if response.data.len() != chunk.len() {
209 return Err(AppError::Embedding(format!(
210 "expected {} embeddings, got {}",
211 chunk.len(),
212 response.data.len()
213 )));
214 }
215
216 let mut sorted = response.data;
217 sorted.sort_by_key(|d| d.index);
218
219 for d in sorted {
220 all.push(self.truncate_embedding(d.embedding)?);
221 }
222 }
223
224 Ok(all)
225 }
226
227 fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
228 if embedding.len() < self.dim {
229 return Err(AppError::Embedding(format!(
230 "embedding dimension {} < requested {}",
231 embedding.len(),
232 self.dim
233 )));
234 }
235 if embedding.len() == self.dim {
236 Ok(embedding)
237 } else {
238 Ok(embedding[..self.dim].to_vec())
239 }
240 }
241
242 async fn execute_with_retry(
243 &self,
244 request: &EmbeddingRequest<'_>,
245 ) -> Result<EmbeddingResponse, AppError> {
246 let mut last_err = None;
247
248 for attempt in 0..MAX_RETRIES {
249 let result = self
250 .client
251 .post(OPENROUTER_EMBEDDINGS_URL)
252 .header(
253 "Authorization",
254 format!("Bearer {}", self.api_key.expose_secret()),
255 )
256 .json(request)
257 .send()
258 .await;
259
260 let resp = match result {
261 Ok(r) => r,
262 Err(e) if e.is_timeout() => {
263 return Err(AppError::Embedding("OpenRouter request timed out".into()));
264 }
265 Err(e) => {
266 last_err = Some(AppError::Embedding(format!("HTTP request failed: {e}")));
267 Self::backoff(attempt).await;
268 continue;
269 }
270 };
271
272 let status = resp.status();
273
274 if status.is_success() {
275 let body = resp.text().await.map_err(|e| {
276 AppError::Embedding(format!("failed to read response body: {e}"))
277 })?;
278 match serde_json::from_str::<EmbeddingEnvelope>(&body) {
279 Ok(env) => {
280 if let Some(api_err) = env.error {
285 return Err(AppError::ProviderError {
286 code: api_err.code_string(),
287 message: api_err.message,
288 });
289 }
290 match env.data {
291 Some(data) => return Ok(EmbeddingResponse { data }),
292 None => {
293 tracing::warn!(
294 attempt,
295 body_len = body.len(),
296 "HTTP 200 with neither data nor error (retrying)"
297 );
298 last_err = Some(AppError::Embedding(
299 "OpenRouter 200 response had neither data nor error".into(),
300 ));
301 Self::backoff(attempt).await;
302 continue;
303 }
304 }
305 }
306 Err(e) => {
307 tracing::warn!(
308 attempt,
309 body_len = body.len(),
310 "HTTP 200 but JSON unparseable (retrying): {e}"
311 );
312 last_err = Some(AppError::Embedding(format!(
313 "failed to parse embedding response: {e}"
314 )));
315 Self::backoff(attempt).await;
316 continue;
317 }
318 }
319 }
320
321 if status.as_u16() == 401 {
322 return Err(AppError::Embedding(
323 "invalid OpenRouter API key (HTTP 401)".into(),
324 ));
325 }
326
327 if status.as_u16() == 400 || status.as_u16() == 404 {
328 let body = resp.text().await.unwrap_or_default();
329 return Err(AppError::Embedding(format!(
330 "OpenRouter returned {status}: {body}"
331 )));
332 }
333
334 if status.as_u16() == 429 {
335 let retry_after = resp
336 .headers()
337 .get("retry-after")
338 .and_then(|v| v.to_str().ok())
339 .and_then(|v| v.parse::<u64>().ok())
340 .unwrap_or(2);
341 tracing::warn!(
342 attempt,
343 retry_after_secs = retry_after,
344 "OpenRouter rate limited, waiting"
345 );
346 last_err = Some(AppError::RateLimited {
351 detail: format!("OpenRouter HTTP 429 (retry-after {retry_after}s)"),
352 });
353 tokio::time::sleep(Duration::from_secs(retry_after)).await;
354 continue;
355 }
356
357 if status.is_server_error() {
358 tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
359 last_err = Some(AppError::Embedding(format!(
360 "OpenRouter server error: {status}"
361 )));
362 Self::backoff(attempt).await;
363 continue;
364 }
365
366 let body = resp.text().await.unwrap_or_default();
367 return Err(AppError::Embedding(format!(
368 "unexpected HTTP {status}: {body}"
369 )));
370 }
371
372 Err(last_err.unwrap_or_else(|| {
373 AppError::Embedding("max retries exceeded for OpenRouter request".into())
374 }))
375 }
376
377 async fn backoff(attempt: u32) {
378 let base_ms = 1000u64 * 2u64.pow(attempt);
379 let jitter = fastrand::u64(0..500);
380 let sleep_ms = base_ms + jitter;
381 tracing::debug!(attempt, sleep_ms, "exponential backoff");
382 tokio::time::sleep(Duration::from_millis(sleep_ms)).await;
383 }
384}
385
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_supports_mrl_detection() {
        let mrl_models = [
            "qwen/qwen3-embedding-8b",
            "qwen/qwen3-embedding-4b",
            "openai/text-embedding-3-small",
            "openai/text-embedding-3-large",
            "google/gemini-embedding-001",
            "google/gemini-embedding-2",
            "nvidia/llama-nemotron-embed-vl-1b-v2:free",
            "baai/bge-m3",
        ];
        for model in mrl_models {
            assert!(model_supports_mrl(model), "{model} should be MRL-capable");
        }

        let fixed_dim_models = [
            "perplexity/pplx-embed-v1-0.6b",
            "mistralai/mistral-embed-2312",
            "some-random-model",
        ];
        for model in fixed_dim_models {
            assert!(!model_supports_mrl(model), "{model} should not be MRL-capable");
        }
    }

    #[test]
    fn test_model_default_input_type() {
        let cases = [
            ("nvidia/llama-nemotron-embed-vl-1b-v2:free", Some("passage")),
            ("mistralai/mistral-embed-2312", None),
            ("qwen/qwen3-embedding-8b", Some("search_document")),
            ("openai/text-embedding-3-small", Some("search_document")),
            ("baai/bge-m3", Some("search_document")),
        ];
        for (model, expected) in cases {
            assert_eq!(model_default_input_type(model), expected, "model: {model}");
        }
    }

    #[test]
    fn test_truncate_embedding() {
        let api_key = SecretBox::new(Box::new("test-key".to_string()));
        let client = OpenRouterClient::new(api_key, "test-model".into(), 3).unwrap();

        let full = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let truncated = client.truncate_embedding(full).unwrap();
        assert_eq!(truncated, vec![1.0, 2.0, 3.0]);

        let exact = vec![1.0, 2.0, 3.0];
        let kept = client.truncate_embedding(exact).unwrap();
        assert_eq!(kept, vec![1.0, 2.0, 3.0]);

        let short = vec![1.0, 2.0];
        let err = client.truncate_embedding(short);
        assert!(err.is_err());
    }

    #[test]
    fn embedding_envelope_surfaces_provider_error_not_missing_field() {
        let body = r#"{"error":{"code":400,"message":"context length exceeded"}}"#;

        // Precondition: parsing straight into `EmbeddingResponse` hides the real cause
        // behind a "missing field" error.
        let legacy_err = match serde_json::from_str::<EmbeddingResponse>(body) {
            Ok(_) => panic!("legacy parse should have failed on an error body"),
            Err(e) => e.to_string(),
        };
        assert!(
            legacy_err.contains("missing field"),
            "precondition: legacy parse masks the cause as a missing field: {legacy_err}"
        );

        let env: EmbeddingEnvelope =
            serde_json::from_str(body).expect("envelope parses an error body");
        assert!(env.data.is_none());
        let api_err = env.error.expect("error object captured");
        assert_eq!(api_err.message, "context length exceeded");
        assert_eq!(api_err.code_string(), "400");
    }

    #[test]
    fn embedding_envelope_parses_success_body() {
        let body = r#"{"data":[{"embedding":[1.0,2.0,3.0],"index":0}]}"#;
        let env: EmbeddingEnvelope =
            serde_json::from_str(body).expect("envelope parses a success body");
        assert!(env.error.is_none());
        let data = env.data.expect("data present");
        assert_eq!(data.len(), 1);
        assert_eq!(data[0].embedding, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn api_error_code_string_handles_number_string_and_missing() {
        let num: ApiError = serde_json::from_str(r#"{"code":429,"message":"slow down"}"#).unwrap();
        assert_eq!(num.code_string(), "429");

        let s: ApiError =
            serde_json::from_str(r#"{"code":"rate_limited","message":"slow down"}"#).unwrap();
        assert_eq!(s.code_string(), "rate_limited");

        let missing: ApiError = serde_json::from_str(r#"{"message":"oops"}"#).unwrap();
        assert_eq!(missing.code_string(), "unknown");
    }

    #[tokio::test]
    async fn embed_single_rejects_oversized_input_before_request() {
        let api_key = SecretBox::new(Box::new("test-key".to_string()));
        // The input-size guard runs before any HTTP call, so no network is needed.
        let client = OpenRouterClient::new(api_key, "qwen/qwen3-embedding-8b".into(), 384).unwrap();
        let big = "word ".repeat(crate::constants::EMBEDDING_REQUEST_MAX_TOKENS + 5_000);
        match client.embed_single(&big, None).await {
            Err(AppError::Validation(msg)) => assert!(msg.contains("tokens")),
            // A mismatch is a test failure, not an impossible state, so `panic!` is the
            // right macro here.
            other => panic!("expected Validation before request, got: {other:?}"),
        }
    }
}