1use anyhow::{anyhow, Result};
47use async_trait::async_trait;
48
49use smooth_operator::rerank::Reranker;
50use smooth_operator_core::KnowledgeResult;
51
/// Model name that `GatewayReranker::from_env` passes to the rerank backend.
pub const DEFAULT_RERANK_MODEL: &str = "rerank-english-v3.0";
55
/// One relevance score produced by a [`RerankBackend`].
///
/// `index` points into the `documents` slice that was passed to
/// [`RerankBackend::rerank`]; a larger `relevance_score` ranks earlier.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RerankScore {
    /// Position of the scored document in the request's `documents` slice.
    pub index: usize,
    /// Backend-assigned relevance; candidates are ordered by this, highest first.
    pub relevance_score: f32,
}
66
67#[async_trait]
73pub trait RerankBackend: Send + Sync {
74 async fn rerank(
82 &self,
83 query: &str,
84 documents: &[String],
85 top_n: usize,
86 ) -> Result<Vec<RerankScore>>;
87}
88
89#[derive(Clone)]
91pub struct HttpRerankBackend {
92 client: reqwest::Client,
93 base_url: String,
94 api_key: String,
95 model: String,
96}
97
98impl HttpRerankBackend {
99 #[must_use]
101 pub fn new(
102 base_url: impl Into<String>,
103 api_key: impl Into<String>,
104 model: impl Into<String>,
105 ) -> Self {
106 Self {
107 client: reqwest::Client::new(),
108 base_url: base_url.into(),
109 api_key: api_key.into(),
110 model: model.into(),
111 }
112 }
113}
114
115#[async_trait]
116impl RerankBackend for HttpRerankBackend {
117 async fn rerank(
118 &self,
119 query: &str,
120 documents: &[String],
121 top_n: usize,
122 ) -> Result<Vec<RerankScore>> {
123 let url = format!("{}/v1/rerank", self.base_url.trim_end_matches('/'));
126 let body = serde_json::json!({
127 "model": self.model,
128 "query": query,
129 "documents": documents,
130 "top_n": top_n,
131 });
132
133 let resp = self
134 .client
135 .post(&url)
136 .bearer_auth(&self.api_key)
137 .json(&body)
138 .send()
139 .await?;
140
141 if !resp.status().is_success() {
142 let status = resp.status();
143 let text = resp.text().await.unwrap_or_default();
144 return Err(anyhow!("rerank request failed ({status}): {text}"));
145 }
146
147 #[derive(serde::Deserialize)]
148 struct ResultItem {
149 index: usize,
150 relevance_score: f32,
151 }
152 #[derive(serde::Deserialize)]
153 struct RerankResponse {
154 results: Vec<ResultItem>,
155 }
156
157 let parsed: RerankResponse = resp.json().await?;
158 Ok(parsed
159 .results
160 .into_iter()
161 .map(|r| RerankScore {
162 index: r.index,
163 relevance_score: r.relevance_score,
164 })
165 .collect())
166 }
167}
168
169pub struct GatewayReranker {
178 backend: std::sync::Arc<dyn RerankBackend>,
179}
180
181impl GatewayReranker {
182 #[must_use]
185 pub fn with_backend(backend: std::sync::Arc<dyn RerankBackend>) -> Self {
186 Self { backend }
187 }
188
189 #[must_use]
191 pub fn new(
192 base_url: impl Into<String>,
193 api_key: impl Into<String>,
194 model: impl Into<String>,
195 ) -> Self {
196 Self::with_backend(std::sync::Arc::new(HttpRerankBackend::new(
197 base_url, api_key, model,
198 )))
199 }
200
201 pub fn from_env() -> Result<Self> {
207 let base_url = std::env::var("SMOOAI_GATEWAY_URL")
208 .map_err(|_| anyhow!("SMOOAI_GATEWAY_URL is not set"))?;
209 let api_key = std::env::var("SMOOAI_GATEWAY_KEY")
210 .map_err(|_| anyhow!("SMOOAI_GATEWAY_KEY is not set"))?;
211 Ok(Self::new(base_url, api_key, DEFAULT_RERANK_MODEL))
212 }
213
214 fn reorder(
221 mut scores: Vec<RerankScore>,
222 candidates: Vec<KnowledgeResult>,
223 top_k: usize,
224 ) -> Vec<KnowledgeResult> {
225 let n = candidates.len();
226 let mut slots: Vec<Option<KnowledgeResult>> = candidates.into_iter().map(Some).collect();
228
229 scores.sort_by(|a, b| {
231 b.relevance_score
232 .partial_cmp(&a.relevance_score)
233 .unwrap_or(std::cmp::Ordering::Equal)
234 });
235
236 let mut out: Vec<KnowledgeResult> = Vec::with_capacity(top_k.min(n));
237 let mut taken = vec![false; n];
238 for s in scores {
239 if out.len() >= top_k {
240 break;
241 }
242 if s.index < n && !taken[s.index] {
245 if let Some(c) = slots[s.index].take() {
246 taken[s.index] = true;
247 out.push(c);
248 }
249 }
250 }
251 if out.len() < top_k {
254 for (i, slot) in slots.iter_mut().enumerate() {
255 if out.len() >= top_k {
256 break;
257 }
258 if !taken[i] {
259 if let Some(c) = slot.take() {
260 out.push(c);
261 }
262 }
263 }
264 }
265 out
266 }
267}
268
269#[async_trait]
270impl Reranker for GatewayReranker {
271 async fn rerank(
272 &self,
273 query: &str,
274 candidates: Vec<KnowledgeResult>,
275 top_k: usize,
276 ) -> Vec<KnowledgeResult> {
277 if candidates.is_empty() || top_k == 0 {
278 return Vec::new();
279 }
280 let documents: Vec<String> = candidates.iter().map(|c| c.chunk.clone()).collect();
281
282 match self.backend.rerank(query, &documents, top_k).await {
283 Ok(scores) => Self::reorder(scores, candidates, top_k),
284 Err(e) => {
285 tracing::warn!(
288 error = %e,
289 "GatewayReranker call failed; falling back to upstream candidate order"
290 );
291 let mut fallback = candidates;
292 fallback.truncate(top_k);
293 fallback
294 }
295 }
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Builds a candidate with a fixed upstream score and a `<id>.md` source.
    fn result(id: &str, chunk: &str) -> KnowledgeResult {
        KnowledgeResult {
            document_id: id.to_string(),
            chunk: chunk.to_string(),
            score: 0.5,
            source: format!("{id}.md"),
        }
    }

    /// Shorthand for one backend score entry.
    fn score(index: usize, relevance_score: f32) -> RerankScore {
        RerankScore {
            index,
            relevance_score,
        }
    }

    /// Document ids of `results`, in order, for compact ordering assertions.
    fn ids(results: &[KnowledgeResult]) -> Vec<&str> {
        results.iter().map(|r| r.document_id.as_str()).collect()
    }

    /// Backend that ignores its inputs and returns a canned score list.
    struct StubBackend {
        scores: Vec<RerankScore>,
    }
    #[async_trait]
    impl RerankBackend for StubBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Ok(self.scores.clone())
        }
    }

    /// Backend that always fails, to exercise the fallback path.
    struct ErrorBackend;
    #[async_trait]
    impl RerankBackend for ErrorBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Err(anyhow!("simulated rerank API failure"))
        }
    }

    fn stub(scores: Vec<RerankScore>) -> GatewayReranker {
        GatewayReranker::with_backend(Arc::new(StubBackend { scores }))
    }

    #[tokio::test]
    async fn gateway_reranker_reorders_by_relevance() {
        let candidates = vec![
            result("shipping", "shipping takes 5-7 days"),
            result("warranty", "warranty is one year"),
            result("returns", "30 day refund window"),
        ];
        let reranker = stub(vec![score(0, 0.1), score(1, 0.4), score(2, 0.95)]);
        let out = reranker.rerank("refund returns", candidates, 3).await;

        assert_eq!(
            ids(&out),
            ["returns", "warranty", "shipping"],
            "candidates must be reordered by descending relevance score"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_truncates_to_top_k() {
        let candidates = vec![
            result("a", "alpha"),
            result("b", "beta"),
            result("c", "gamma"),
        ];
        let reranker = stub(vec![score(2, 0.9), score(0, 0.5), score(1, 0.1)]);
        let out = reranker.rerank("q", candidates, 1).await;
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].document_id, "c", "top_k=1 keeps only the best");
    }

    #[tokio::test]
    async fn gateway_reranker_error_falls_back_to_input_order() {
        let candidates = vec![
            result("first", "one"),
            result("second", "two"),
            result("third", "three"),
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(ErrorBackend));
        let out = reranker.rerank("anything", candidates, 2).await;

        assert_eq!(out.len(), 2, "fallback truncates to top_k");
        assert_eq!(
            ids(&out),
            ["first", "second"],
            "on error the upstream order is preserved"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_partial_scores_appends_unscored_in_order() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb"), result("c", "ccc")];
        let reranker = stub(vec![score(2, 0.9)]);
        let out = reranker.rerank("q", candidates, 3).await;
        assert_eq!(
            ids(&out),
            ["c", "a", "b"],
            "scored candidate first, then unscored in upstream order"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_ignores_out_of_range_index() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub(vec![score(99, 0.99), score(1, 0.5)]);
        let out = reranker.rerank("q", candidates, 2).await;
        assert_eq!(ids(&out), ["b", "a"]);
    }

    #[tokio::test]
    async fn gateway_reranker_duplicate_index_is_used_once() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub(vec![score(1, 0.9), score(1, 0.8), score(0, 0.1)]);
        let out = reranker.rerank("q", candidates, 3).await;
        assert_eq!(
            ids(&out),
            ["b", "a"],
            "a repeated backend index must not duplicate a candidate"
        );
    }

    #[tokio::test]
    async fn gateway_reranker_top_k_larger_than_candidates_returns_all() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let reranker = stub(vec![score(0, 0.2), score(1, 0.7)]);
        let out = reranker.rerank("q", candidates, 10).await;
        assert_eq!(ids(&out), ["b", "a"]);
    }

    #[tokio::test]
    async fn gateway_reranker_zero_top_k_yields_empty() {
        let reranker = stub(vec![score(0, 0.9)]);
        let out = reranker.rerank("q", vec![result("a", "aaa")], 0).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    async fn gateway_reranker_empty_candidates_yields_empty() {
        let reranker = stub(vec![]);
        let out = reranker.rerank("q", vec![], 3).await;
        assert!(out.is_empty());
    }

    #[tokio::test]
    #[ignore = "network + creds: gated on SMOOTH_AGENT_E2E=1 and a /v1/rerank route"]
    async fn live_rerank() {
        if std::env::var("SMOOTH_AGENT_E2E").as_deref() != Ok("1") {
            eprintln!("skipping live rerank: set SMOOTH_AGENT_E2E=1 to run");
            return;
        }
        let Ok(reranker) = GatewayReranker::from_env() else {
            eprintln!("skipping live rerank: SMOOAI_GATEWAY_URL / SMOOAI_GATEWAY_KEY not set");
            return;
        };
        let candidates = vec![
            result("shipping", "Standard shipping takes 5 to 7 business days."),
            result("warranty", "Warranty claims must be filed within one year."),
            result(
                "returns",
                "Our return policy: refunds within the 30 day window.",
            ),
        ];
        let out = reranker
            .rerank("how do refunds and returns work", candidates, 3)
            .await;
        eprintln!("live rerank order: {:?}", ids(&out));
        assert_eq!(out.len(), 3, "live rerank should return all 3 reordered");
    }
}