synaptic_huggingface/
reranker.rs1use synaptic_core::{Document, SynapticError};
2
/// Reranker checkpoints from the BAAI `bge-reranker` family, plus an escape
/// hatch for any other model id.
#[derive(Debug)]
#[derive(Clone, PartialEq, Eq)]
pub enum BgeRerankerModel {
    /// Model id `BAAI/bge-reranker-v2-m3`.
    BgeRerankerV2M3,
    /// Model id `BAAI/bge-reranker-large`.
    BgeRerankerLarge,
    /// Model id `BAAI/bge-reranker-base`.
    BgeRerankerBase,
    /// Any other model id, used verbatim (e.g. `"org/name"`).
    Custom(String),
}

impl BgeRerankerModel {
    /// Returns the model id for this variant.
    ///
    /// Built-in variants map to fixed `BAAI/...` ids; `Custom` yields the
    /// wrapped string unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::BgeRerankerV2M3 => "BAAI/bge-reranker-v2-m3",
            Self::BgeRerankerLarge => "BAAI/bge-reranker-large",
            Self::BgeRerankerBase => "BAAI/bge-reranker-base",
            Self::Custom(id) => id.as_str(),
        }
    }
}

impl std::fmt::Display for BgeRerankerModel {
    /// Writes the same text as [`BgeRerankerModel::as_str`]; formatter flags
    /// such as width or fill are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
32
33pub struct HuggingFaceReranker {
38 api_key: String,
39 model: String,
40 base_url: String,
41 client: reqwest::Client,
42}
43
44impl HuggingFaceReranker {
45 pub fn new(api_key: impl Into<String>) -> Self {
46 Self {
47 api_key: api_key.into(),
48 model: BgeRerankerModel::BgeRerankerV2M3.to_string(),
49 base_url: "https://api-inference.huggingface.co/models".to_string(),
50 client: reqwest::Client::new(),
51 }
52 }
53
54 pub fn with_model(mut self, model: BgeRerankerModel) -> Self {
55 self.model = model.as_str().to_string();
56 self
57 }
58
59 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
60 self.base_url = url.into();
61 self
62 }
63
64 pub async fn rerank(
69 &self,
70 query: &str,
71 documents: Vec<Document>,
72 top_k: usize,
73 ) -> Result<Vec<(Document, f32)>, SynapticError> {
74 if documents.is_empty() {
75 return Ok(Vec::new());
76 }
77 let sentences: Vec<&str> = documents.iter().map(|d| d.content.as_str()).collect();
78 let body = serde_json::json!({
79 "inputs": {
80 "source_sentence": query,
81 "sentences": sentences,
82 }
83 });
84 let url = format!("{}/{}", self.base_url, self.model);
85 let resp = self
86 .client
87 .post(&url)
88 .header("Authorization", format!("Bearer {}", self.api_key))
89 .header("Content-Type", "application/json")
90 .header("x-wait-for-model", "true")
91 .json(&body)
92 .send()
93 .await
94 .map_err(|e| SynapticError::Retriever(format!("HuggingFace rerank request: {e}")))?;
95 let status = resp.status().as_u16();
96 let json: serde_json::Value = resp
97 .json()
98 .await
99 .map_err(|e| SynapticError::Retriever(format!("HuggingFace rerank parse: {e}")))?;
100 if status != 200 {
101 return Err(SynapticError::Retriever(format!(
102 "HuggingFace API error ({}): {}",
103 status, json
104 )));
105 }
106 let scores = json
108 .as_array()
109 .ok_or_else(|| SynapticError::Retriever("expected array response".to_string()))?;
110 let mut scored: Vec<(Document, f32)> = scores
111 .iter()
112 .enumerate()
113 .filter_map(|(i, v)| {
114 let score = v.as_f64()? as f32;
115 let doc = documents.get(i)?.clone();
116 Some((doc, score))
117 })
118 .collect();
119 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120 Ok(scored.into_iter().take(top_k).collect())
121 }
122}