1use crate::errors::AppError;
10use crate::retry::AttemptOutcome;
11use secrecy::{ExposeSecret, SecretBox};
12use serde::{Deserialize, Serialize};
13use std::time::Duration;
14
/// Embeddings endpoint used by `OpenRouterClient::new`.
const OPENROUTER_EMBEDDINGS_URL: &str = "https://openrouter.ai/api/v1/embeddings";
/// Overall per-request timeout, in seconds, configured on the HTTP client.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Connection timeout, in seconds, configured on the HTTP client.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Largest number of inputs `embed_batch` places in a single request.
const MAX_BATCH_SIZE: usize = 32;
19
/// JSON body posted to the embeddings endpoint.
///
/// The two `Option` fields are left out of the serialized payload when `None`.
#[derive(Serialize)]
struct EmbeddingRequest<'a> {
    /// Identifier of the embedding model.
    model: &'a str,
    /// A single text or a batch of texts to embed.
    input: EmbeddingInput<'a>,
    /// Requested vector length; the client fills it only when `model_supports_mrl` matched.
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
    /// Encoding of the returned vectors; the client always sends `"float"`.
    encoding_format: &'a str,
    /// Optional input-type hint, forwarded unchanged from the caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    input_type: Option<&'a str>,
}
30
/// Value of the request's `input` field.
///
/// Being `untagged`, the variant name is not serialized: `Single` becomes a
/// JSON string and `Batch` becomes a JSON array of strings.
#[derive(Serialize)]
#[serde(untagged)]
enum EmbeddingInput<'a> {
    /// One text.
    Single(&'a str),
    /// Several texts sent in one request.
    Batch(Vec<&'a str>),
}
37
/// Successful embeddings result: the `data` array of a response.
///
/// Deserializing directly into this type requires `data` to be present.
#[derive(Deserialize)]
struct EmbeddingResponse {
    /// One entry per embedded input.
    data: Vec<EmbeddingData>,
}
42
/// One entry of the response `data` array.
#[derive(Deserialize)]
struct EmbeddingData {
    /// The embedding vector.
    embedding: Vec<f32>,
    /// Position reported by the provider; `embed_batch` sorts entries by it.
    index: usize,
}
48
/// Lenient view of a response body in which both `data` and `error` are
/// optional, so a body carrying only an error object still deserializes.
#[derive(Deserialize)]
struct EmbeddingEnvelope {
    /// Embedding entries, when the body contains them.
    #[serde(default)]
    data: Option<Vec<EmbeddingData>>,
    /// Provider error object, when the body contains one.
    #[serde(default)]
    error: Option<ApiError>,
}
63
64use crate::openrouter_http::ApiError;
69
/// An embedding failure together with its retry classification.
#[derive(Debug)]
pub struct EmbedError {
    /// Underlying application error; `Display` prints it and converting an
    /// `EmbedError` into an `AppError` yields it.
    pub source: AppError,
    /// Retry classification assigned at the point where the error was built.
    pub retry_class: AttemptOutcome,
}
89
90impl EmbedError {
91 fn new(source: AppError, retry_class: AttemptOutcome) -> Self {
92 Self {
93 source,
94 retry_class,
95 }
96 }
97}
98
99impl std::fmt::Display for EmbedError {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 std::fmt::Display::fmt(&self.source, f)
102 }
103}
104
105impl std::error::Error for EmbedError {
106 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
107 Some(&self.source)
108 }
109}
110
111impl From<AppError> for EmbedError {
120 fn from(source: AppError) -> Self {
121 Self::new(source, AttemptOutcome::HardFailure)
122 }
123}
124
125impl From<EmbedError> for AppError {
132 fn from(err: EmbedError) -> Self {
133 err.source
134 }
135}
136
/// Client for the OpenRouter embeddings endpoint, bound to a single model.
pub struct OpenRouterClient {
    /// HTTP client built once in `new` with the default timeouts and user agent.
    client: reqwest::Client,
    /// API key; sent as a bearer token in the `Authorization` header.
    api_key: SecretBox<String>,
    /// Model identifier placed in every request.
    model: String,
    /// Target vector length: longer vectors are cut to it, shorter ones are rejected.
    dim: usize,
    /// Whether requests include the `dimensions` field (from `model_supports_mrl`).
    supports_mrl: bool,
    /// Input-type value suggested for this model (from `model_default_input_type`);
    /// only exposed through the getter, not applied to requests automatically.
    default_input_type: Option<&'static str>,
    /// Endpoint URL; the production URL unless replaced by the test-only constructor.
    base_url: String,
}
150
/// Reports whether `model` belongs to a family for which the client sends the
/// `dimensions` request field.
///
/// The check is a plain substring match against a fixed list of family
/// markers, so provider prefixes and version suffixes do not matter.
fn model_supports_mrl(model: &str) -> bool {
    const MRL_MODEL_MARKERS: [&str; 5] = [
        "qwen3-embedding",
        "text-embedding-3",
        "gemini-embedding",
        "llama-nemotron-embed",
        "bge-m3",
    ];

    MRL_MODEL_MARKERS
        .iter()
        .any(|&marker| model.contains(marker))
}
158
/// Picks the default `input_type` value for `model` by substring match.
///
/// * contains `llama-nemotron-embed` -> `Some("passage")`
/// * contains `mistral-embed` -> `None`
/// * anything else -> `Some("search_document")`
///
/// The arms are tried in that order, so the first matching marker wins.
fn model_default_input_type(model: &str) -> Option<&'static str> {
    match model {
        m if m.contains("llama-nemotron-embed") => Some("passage"),
        m if m.contains("mistral-embed") => None,
        _ => Some("search_document"),
    }
}
168
169impl OpenRouterClient {
170 pub fn new(api_key: SecretBox<String>, model: String, dim: usize) -> Result<Self, AppError> {
171 let client = reqwest::Client::builder()
172 .timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS))
173 .connect_timeout(Duration::from_secs(DEFAULT_CONNECT_TIMEOUT_SECS))
174 .user_agent("sqlite-graphrag/1.1.00")
175 .build()
176 .map_err(|e| AppError::Embedding(format!("failed to build HTTP client: {e}")))?;
177
178 let supports_mrl = model_supports_mrl(&model);
179 let default_input_type = model_default_input_type(&model);
180
181 Ok(Self {
182 client,
183 api_key,
184 model,
185 dim,
186 supports_mrl,
187 default_input_type,
188 base_url: OPENROUTER_EMBEDDINGS_URL.to_string(),
189 })
190 }
191
192 #[cfg(test)]
196 fn new_with_url(
197 api_key: SecretBox<String>,
198 model: String,
199 dim: usize,
200 base_url: String,
201 ) -> Result<Self, AppError> {
202 let mut client = Self::new(api_key, model, dim)?;
203 client.base_url = base_url;
204 Ok(client)
205 }
206
    /// Returns the input-type value chosen for this client's model at
    /// construction time, or `None` when the model has no default.
    pub fn default_input_type(&self) -> Option<&'static str> {
        self.default_input_type
    }
210
211 pub async fn embed_single(
212 &self,
213 text: &str,
214 input_type: Option<&str>,
215 ) -> Result<Vec<f32>, EmbedError> {
216 crate::memory_guard::check_embedding_input_size(text)?;
220
221 let request = EmbeddingRequest {
222 model: &self.model,
223 input: EmbeddingInput::Single(text),
224 dimensions: if self.supports_mrl {
225 Some(self.dim)
226 } else {
227 None
228 },
229 encoding_format: "float",
230 input_type,
231 };
232
233 let response = self.execute_with_retry(&request).await?;
234
235 let embedding = response
236 .data
237 .into_iter()
238 .next()
239 .ok_or_else(|| AppError::Embedding("empty response from OpenRouter".into()))?
240 .embedding;
241
242 Ok(self.truncate_embedding(embedding)?)
243 }
244
245 pub async fn embed_batch(
246 &self,
247 texts: &[&str],
248 input_type: Option<&str>,
249 ) -> Result<Vec<Vec<f32>>, EmbedError> {
250 if texts.is_empty() {
251 return Ok(Vec::new());
252 }
253
254 for text in texts {
258 crate::memory_guard::check_embedding_input_size(text)?;
259 }
260
261 let mut all = Vec::with_capacity(texts.len());
262
263 for chunk in texts.chunks(MAX_BATCH_SIZE) {
264 let request = EmbeddingRequest {
265 model: &self.model,
266 input: EmbeddingInput::Batch(chunk.to_vec()),
267 dimensions: if self.supports_mrl {
268 Some(self.dim)
269 } else {
270 None
271 },
272 encoding_format: "float",
273 input_type,
274 };
275
276 let response = self.execute_with_retry(&request).await?;
277
278 if response.data.len() != chunk.len() {
279 return Err(AppError::Embedding(format!(
280 "expected {} embeddings, got {}",
281 chunk.len(),
282 response.data.len()
283 ))
284 .into());
285 }
286
287 let mut sorted = response.data;
288 sorted.sort_by_key(|d| d.index);
289
290 for d in sorted {
291 all.push(self.truncate_embedding(d.embedding)?);
292 }
293 }
294
295 Ok(all)
296 }
297
298 fn truncate_embedding(&self, embedding: Vec<f32>) -> Result<Vec<f32>, AppError> {
299 if embedding.len() < self.dim {
300 return Err(AppError::Embedding(format!(
301 "embedding dimension {} < requested {}",
302 embedding.len(),
303 self.dim
304 )));
305 }
306 if embedding.len() == self.dim {
307 Ok(embedding)
308 } else {
309 Ok(embedding[..self.dim].to_vec())
310 }
311 }
312
313 async fn execute_with_retry(
320 &self,
321 request: &EmbeddingRequest<'_>,
322 ) -> Result<EmbeddingResponse, EmbedError> {
323 let mut last_err: Option<EmbedError> = None;
324
325 for attempt in 0..crate::openrouter_http::MAX_RETRIES {
326 let result = self
327 .client
328 .post(&self.base_url)
329 .header(
330 "Authorization",
331 format!("Bearer {}", self.api_key.expose_secret()),
332 )
333 .json(request)
334 .send()
335 .await;
336
337 let resp = match result {
338 Ok(r) => r,
339 Err(e) if e.is_timeout() => {
340 return Err(EmbedError::new(
341 AppError::Embedding("OpenRouter request timed out".into()),
342 AttemptOutcome::Transient,
343 ));
344 }
345 Err(e) => {
346 last_err = Some(EmbedError::new(
347 AppError::Embedding(format!("HTTP request failed: {e}")),
348 AttemptOutcome::Transient,
349 ));
350 crate::openrouter_http::backoff(attempt).await;
351 continue;
352 }
353 };
354
355 let status = resp.status();
356
357 if status.is_success() {
358 let body = resp.text().await.map_err(|e| {
359 EmbedError::new(
360 AppError::Embedding(format!("failed to read response body: {e}")),
361 AttemptOutcome::Transient,
362 )
363 })?;
364 match serde_json::from_str::<EmbeddingEnvelope>(&body) {
365 Ok(env) => {
366 if let Some(api_err) = env.error {
371 let retry_class =
372 crate::openrouter_http::provider_error_retry_class(&api_err);
373 return Err(EmbedError::new(
374 AppError::ProviderError {
375 code: api_err.code_string(),
376 message: api_err.message,
377 },
378 retry_class,
379 ));
380 }
381 match env.data {
382 Some(data) => return Ok(EmbeddingResponse { data }),
383 None => {
384 tracing::warn!(
385 attempt,
386 body_len = body.len(),
387 "HTTP 200 with neither data nor error (retrying)"
388 );
389 last_err = Some(EmbedError::new(
390 AppError::Embedding(
391 "OpenRouter 200 response had neither data nor error".into(),
392 ),
393 AttemptOutcome::Transient,
394 ));
395 crate::openrouter_http::backoff(attempt).await;
396 continue;
397 }
398 }
399 }
400 Err(e) => {
401 tracing::warn!(
402 attempt,
403 body_len = body.len(),
404 "HTTP 200 but JSON unparseable (retrying): {e}"
405 );
406 last_err = Some(EmbedError::new(
407 AppError::Embedding(format!("failed to parse embedding response: {e}")),
408 AttemptOutcome::Transient,
409 ));
410 crate::openrouter_http::backoff(attempt).await;
411 continue;
412 }
413 }
414 }
415
416 if status.as_u16() == 401 {
417 return Err(EmbedError::new(
418 AppError::Embedding("invalid OpenRouter API key (HTTP 401)".into()),
419 AttemptOutcome::HardFailure,
420 ));
421 }
422
423 if status.as_u16() == 400 || status.as_u16() == 404 {
424 let body = resp.text().await.unwrap_or_default();
425 return Err(EmbedError::new(
426 AppError::Embedding(format!("OpenRouter returned {status}: {body}")),
427 AttemptOutcome::HardFailure,
428 ));
429 }
430
431 if status.as_u16() == 429 {
432 let retry_after = resp
433 .headers()
434 .get("retry-after")
435 .and_then(|v| v.to_str().ok())
436 .and_then(|v| v.parse::<u64>().ok())
437 .unwrap_or(2);
438 tracing::warn!(
439 attempt,
440 retry_after_secs = retry_after,
441 "OpenRouter rate limited, waiting"
442 );
443 last_err = Some(EmbedError::new(
448 AppError::RateLimited {
449 detail: format!("OpenRouter HTTP 429 (retry-after {retry_after}s)"),
450 },
451 AttemptOutcome::Transient,
452 ));
453 tokio::time::sleep(Duration::from_secs(retry_after)).await;
454 continue;
455 }
456
457 if status.is_server_error() {
458 tracing::warn!(attempt, status = %status, "OpenRouter server error, retrying");
459 last_err = Some(EmbedError::new(
460 AppError::Embedding(format!("OpenRouter server error: {status}")),
461 AttemptOutcome::Transient,
462 ));
463 crate::openrouter_http::backoff(attempt).await;
464 continue;
465 }
466
467 let body = resp.text().await.unwrap_or_default();
468 return Err(EmbedError::new(
469 AppError::Embedding(format!("unexpected HTTP {status}: {body}")),
470 crate::openrouter_http::status_retry_class(status),
471 ));
472 }
473
474 Err(last_err.unwrap_or_else(|| {
479 EmbedError::new(
480 AppError::Embedding("max retries exceeded for OpenRouter request".into()),
481 AttemptOutcome::Transient,
482 )
483 }))
484 }
485}
486
487#[cfg(test)]
488mod tests {
489 use super::*;
490
491 #[test]
492 fn test_supports_mrl_detection() {
493 assert!(model_supports_mrl("qwen/qwen3-embedding-8b"));
494 assert!(model_supports_mrl("qwen/qwen3-embedding-4b"));
495 assert!(model_supports_mrl("openai/text-embedding-3-small"));
496 assert!(model_supports_mrl("openai/text-embedding-3-large"));
497 assert!(model_supports_mrl("google/gemini-embedding-001"));
498 assert!(model_supports_mrl("google/gemini-embedding-2"));
499 assert!(model_supports_mrl(
500 "nvidia/llama-nemotron-embed-vl-1b-v2:free"
501 ));
502 assert!(model_supports_mrl("baai/bge-m3"));
503
504 assert!(!model_supports_mrl("perplexity/pplx-embed-v1-0.6b"));
505 assert!(!model_supports_mrl("mistralai/mistral-embed-2312"));
506 assert!(!model_supports_mrl("some-random-model"));
507 }
508
509 #[test]
510 fn test_model_default_input_type() {
511 assert_eq!(
512 model_default_input_type("nvidia/llama-nemotron-embed-vl-1b-v2:free"),
513 Some("passage")
514 );
515 assert_eq!(
516 model_default_input_type("mistralai/mistral-embed-2312"),
517 None
518 );
519 assert_eq!(
520 model_default_input_type("qwen/qwen3-embedding-8b"),
521 Some("search_document")
522 );
523 assert_eq!(
524 model_default_input_type("openai/text-embedding-3-small"),
525 Some("search_document")
526 );
527 assert_eq!(
528 model_default_input_type("baai/bge-m3"),
529 Some("search_document")
530 );
531 }
532
533 #[test]
534 fn test_truncate_embedding() {
535 let api_key = SecretBox::new(Box::new("test-key".to_string()));
536 let client = OpenRouterClient::new(api_key, "test-model".into(), 3).unwrap();
537
538 let full = vec![1.0, 2.0, 3.0, 4.0, 5.0];
539 let truncated = client.truncate_embedding(full).unwrap();
540 assert_eq!(truncated, vec![1.0, 2.0, 3.0]);
541
542 let exact = vec![1.0, 2.0, 3.0];
543 let kept = client.truncate_embedding(exact).unwrap();
544 assert_eq!(kept, vec![1.0, 2.0, 3.0]);
545
546 let short = vec![1.0, 2.0];
547 let err = client.truncate_embedding(short);
548 assert!(err.is_err());
549 }
550
551 #[test]
552 fn embedding_envelope_surfaces_provider_error_not_missing_field() {
553 let body = r#"{"error":{"code":400,"message":"context length exceeded"}}"#;
556
557 let legacy_err = match serde_json::from_str::<EmbeddingResponse>(body) {
560 Ok(_) => panic!("legacy parse should have failed on an error body"),
561 Err(e) => e.to_string(),
562 };
563 assert!(
564 legacy_err.contains("missing field"),
565 "precondition: legacy parse masks the cause as a missing field: {legacy_err}"
566 );
567
568 let env: EmbeddingEnvelope =
570 serde_json::from_str(body).expect("envelope parses an error body");
571 assert!(env.data.is_none());
572 let api_err = env.error.expect("error object captured");
573 assert_eq!(api_err.message, "context length exceeded");
574 assert_eq!(api_err.code_string(), "400");
575 }
576
577 #[test]
578 fn embedding_envelope_parses_success_body() {
579 let body = r#"{"data":[{"embedding":[1.0,2.0,3.0],"index":0}]}"#;
580 let env: EmbeddingEnvelope =
581 serde_json::from_str(body).expect("envelope parses a success body");
582 assert!(env.error.is_none());
583 let data = env.data.expect("data present");
584 assert_eq!(data.len(), 1);
585 assert_eq!(data[0].embedding, vec![1.0, 2.0, 3.0]);
586 }
587
588 #[test]
589 fn api_error_code_string_handles_number_string_and_missing() {
590 let num: ApiError = serde_json::from_str(r#"{"code":429,"message":"slow down"}"#).unwrap();
591 assert_eq!(num.code_string(), "429");
592
593 let s: ApiError =
594 serde_json::from_str(r#"{"code":"rate_limited","message":"slow down"}"#).unwrap();
595 assert_eq!(s.code_string(), "rate_limited");
596
597 let missing: ApiError = serde_json::from_str(r#"{"message":"oops"}"#).unwrap();
598 assert_eq!(missing.code_string(), "unknown");
599 }
600
601 #[tokio::test]
602 async fn embed_single_rejects_oversized_input_before_request() {
603 let api_key = SecretBox::new(Box::new("test-key".to_string()));
607 let client = OpenRouterClient::new(api_key, "qwen/qwen3-embedding-8b".into(), 384).unwrap();
608 let big = "word ".repeat(crate::constants::EMBEDDING_REQUEST_MAX_TOKENS + 5_000);
609 match client.embed_single(&big, None).await {
610 Err(EmbedError {
611 source: AppError::Validation(msg),
612 retry_class,
613 }) => {
614 assert!(msg.contains("tokens"));
615 assert_eq!(
616 retry_class,
617 AttemptOutcome::HardFailure,
618 "an oversized input is a permanent client error"
619 );
620 }
621 other => unreachable!("expected Validation before request, got: {other:?}"),
622 }
623 }
624
625 async fn client_for(server: &wiremock::MockServer, model: &str) -> OpenRouterClient {
626 OpenRouterClient::new_with_url(
627 SecretBox::new(Box::new("test-key".to_string())),
628 model.to_string(),
629 384,
630 format!("{}/embeddings", server.uri()),
631 )
632 .expect("test client builds")
633 }
634
635 #[tokio::test]
636 async fn embed_single_401_is_hard_failure() {
637 use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
640 let server = MockServer::start().await;
641 Mock::given(method("POST"))
642 .respond_with(ResponseTemplate::new(401))
643 .mount(&server)
644 .await;
645
646 let client = client_for(&server, "qwen/qwen3-embedding-8b").await;
647 let err = client
648 .embed_single("hello", None)
649 .await
650 .expect_err("401 is an error");
651 assert_eq!(err.retry_class, AttemptOutcome::HardFailure);
652 }
653
654 #[tokio::test]
655 async fn embed_single_exhausted_5xx_is_transient() {
656 use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
660 let server = MockServer::start().await;
661 Mock::given(method("POST"))
662 .respond_with(ResponseTemplate::new(503))
663 .mount(&server)
664 .await;
665
666 let client = client_for(&server, "qwen/qwen3-embedding-8b").await;
667 let err = client
668 .embed_single("hello", None)
669 .await
670 .expect_err("persistent 5xx exhausts retries");
671 assert_eq!(err.retry_class, AttemptOutcome::Transient);
672 }
673
674 #[tokio::test]
675 async fn embed_single_provider_error_code_classifies_by_code_not_message() {
676 use wiremock::{matchers::method, Mock, MockServer, ResponseTemplate};
680 let server = MockServer::start().await;
681 Mock::given(method("POST"))
682 .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
683 "error": { "code": "context_length_exceeded", "message": "too many tokens" }
684 })))
685 .mount(&server)
686 .await;
687
688 let client = client_for(&server, "qwen/qwen3-embedding-8b").await;
689 let err = client
690 .embed_single("hello", None)
691 .await
692 .expect_err("provider error must surface");
693 assert_eq!(err.retry_class, AttemptOutcome::HardFailure);
694 }
695}