Skip to main content

synaptic_huggingface/
reranker.rs

1use synaptic_core::{Document, SynapticError};
2
/// BGE reranker models known by name to this crate, plus an escape hatch
/// for any other model hosted on HuggingFace.
///
/// Use [`BgeRerankerModel::as_str`] (or `Display`) to obtain the Hub model ID.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BgeRerankerModel {
    /// Hub ID `BAAI/bge-reranker-v2-m3`: multilingual cross-encoder (recommended).
    BgeRerankerV2M3,
    /// Hub ID `BAAI/bge-reranker-large`: highest quality, English-focused.
    BgeRerankerLarge,
    /// Hub ID `BAAI/bge-reranker-base`: fast, good quality, English-focused.
    BgeRerankerBase,
    /// Arbitrary HuggingFace model ID, used verbatim.
    Custom(String),
}
15
16impl BgeRerankerModel {
17    pub fn as_str(&self) -> &str {
18        match self {
19            BgeRerankerModel::BgeRerankerV2M3 => "BAAI/bge-reranker-v2-m3",
20            BgeRerankerModel::BgeRerankerLarge => "BAAI/bge-reranker-large",
21            BgeRerankerModel::BgeRerankerBase => "BAAI/bge-reranker-base",
22            BgeRerankerModel::Custom(s) => s.as_str(),
23        }
24    }
25}
26
27impl std::fmt::Display for BgeRerankerModel {
28    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29        write!(f, "{}", self.as_str())
30    }
31}
32
33/// Reranker using HuggingFace Inference API (BGE cross-encoder models).
34///
35/// Calls the sentence-similarity inference endpoint with `source_sentence`/`sentences`
36/// format and returns documents sorted by relevance score.
37pub struct HuggingFaceReranker {
38    api_key: String,
39    model: String,
40    base_url: String,
41    client: reqwest::Client,
42}
43
44impl HuggingFaceReranker {
45    pub fn new(api_key: impl Into<String>) -> Self {
46        Self {
47            api_key: api_key.into(),
48            model: BgeRerankerModel::BgeRerankerV2M3.to_string(),
49            base_url: "https://api-inference.huggingface.co/models".to_string(),
50            client: reqwest::Client::new(),
51        }
52    }
53
54    pub fn with_model(mut self, model: BgeRerankerModel) -> Self {
55        self.model = model.as_str().to_string();
56        self
57    }
58
59    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
60        self.base_url = url.into();
61        self
62    }
63
64    /// Rerank documents by relevance to the query.
65    ///
66    /// Returns `(document, score)` pairs sorted by relevance score descending,
67    /// limited to `top_k` results.
68    pub async fn rerank(
69        &self,
70        query: &str,
71        documents: Vec<Document>,
72        top_k: usize,
73    ) -> Result<Vec<(Document, f32)>, SynapticError> {
74        if documents.is_empty() {
75            return Ok(Vec::new());
76        }
77        let sentences: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
78        let body = serde_json::json!({
79            "inputs": {
80                "source_sentence": query,
81                "sentences": sentences,
82            }
83        });
84        let url = format!("{}/{}", self.base_url, self.model);
85        let resp = self
86            .client
87            .post(&url)
88            .header("Authorization", format!("Bearer {}", self.api_key))
89            .header("Content-Type", "application/json")
90            .header("x-wait-for-model", "true")
91            .json(&body)
92            .send()
93            .await
94            .map_err(|e| SynapticError::Retriever(format!("HuggingFace rerank request: {e}")))?;
95        let status = resp.status().as_u16();
96        let json: serde_json::Value = resp
97            .json()
98            .await
99            .map_err(|e| SynapticError::Retriever(format!("HuggingFace rerank parse: {e}")))?;
100        if status != 200 {
101            return Err(SynapticError::Retriever(format!(
102                "HuggingFace API error ({}): {}",
103                status, json
104            )));
105        }
106        // Response is an array of floats, one per input sentence, in input order
107        let scores = json
108            .as_array()
109            .ok_or_else(|| SynapticError::Retriever("expected array response".to_string()))?;
110        let mut scored: Vec<(Document, f32)> = scores
111            .iter()
112            .enumerate()
113            .filter_map(|(i, v)| {
114                let score = v.as_f64()? as f32;
115                let doc = documents.get(i)?.clone();
116                Some((doc, score))
117            })
118            .collect();
119        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120        Ok(scored.into_iter().take(top_k).collect())
121    }
122}