Skip to main content

smooth_operator_adapter_postgres/
reranker.rs

1//! Adapter-specific reranker: the live [`GatewayReranker`] (feature gap G8).
2//!
3//! The provider-agnostic [`Reranker`] trait, the identity [`NoopReranker`], and
4//! the network-free [`LexicalReranker`] all live in
5//! [`smooth_operator::rerank`] — the shared home so the retrieval path can swap a
6//! reranker in without depending on any paid API. This module holds only the
7//! adapter-specific [`GatewayReranker`]: a cross-encoder reranker behind the
8//! SmooAI LiteLLM gateway's Cohere/Voyage-style `/v1/rerank` endpoint, exactly as
9//! [`GatewayEmbedder`](crate::GatewayEmbedder) holds the live `/v1/embeddings`
10//! client. It drags `reqwest` and so lives here rather than in `core`.
11//!
12//! ## Endpoint shape
13//!
14//! The gateway exposes a Cohere/Voyage-compatible rerank route:
15//!
16//! ```text
17//! POST {base}/v1/rerank
18//! { "model": "...", "query": "...", "documents": ["doc0", "doc1", ...],
19//!   "top_n": K }
20//! → { "results": [ { "index": <usize>, "relevance_score": <f32> }, ... ] }
21//! ```
22//!
23//! The reranker sends the candidate chunks as `documents`, reads the returned
24//! `index → relevance_score`, and reorders the **original** candidates by that
25//! score (highest first), truncating to `top_k`.
26//!
27//! ## Failure is non-fatal (never drop the turn)
28//!
29//! A reranker is a *quality* stage, not a *correctness* stage: the upstream
30//! retrieval already produced a usable, rank-ordered candidate set. So any
31//! failure on the rerank call — network error, non-2xx, malformed JSON, an
32//! out-of-range index — degrades gracefully to the **input order** (truncated to
33//! `top_k`), logging a [`tracing::warn!`]. It never panics and never drops the
34//! turn. This mirrors the embedder's "fail loud, keep working" posture, tuned to
35//! the fact that an identity reorder is a perfectly safe fallback here.
36//!
37//! ## Testability seam
38//!
39//! [`GatewayReranker`] is generic over a [`RerankBackend`] — the thing that turns
40//! `(query, documents, top_n)` into `(index, score)` pairs. The production
41//! backend is [`HttpRerankBackend`] (the real `/v1/rerank` call). Unit tests
42//! inject a stub backend so the reorder/truncate/error-fallback logic is exercised
43//! **without touching the network**, exactly how `github_search` tests its
44//! `GithubSearchBackend`.
45
46use anyhow::{anyhow, Result};
47use async_trait::async_trait;
48
49use smooth_operator::rerank::Reranker;
50use smooth_operator_core::KnowledgeResult;
51
/// Rerank model id sent to the gateway when none is configured explicitly
/// (a Cohere-compatible model name). It is separate from both the embedding
/// model and the chat model.
pub const DEFAULT_RERANK_MODEL: &str = "rerank-english-v3.0";
55
/// A single `(index, score)` pair produced by a [`RerankBackend`].
///
/// `index` points back into the `documents` array of the request that produced
/// it; `relevance_score` says how well that document matches the query.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RerankScore {
    /// Position of the scored document within the `documents` slice the
    /// backend was handed.
    pub index: usize,
    /// Query relevance of the document — larger means more relevant. Candidates
    /// are ordered by this value, highest first.
    pub relevance_score: f32,
}
66
67/// A pluggable rerank backend.
68///
69/// The production [`HttpRerankBackend`] POSTs to the gateway's `/v1/rerank`. Tests
70/// inject a stub so the [`GatewayReranker`] reorder/truncate/error-fallback logic
71/// runs offline (mirrors `github_search`'s `GithubSearchBackend` seam).
72#[async_trait]
73pub trait RerankBackend: Send + Sync {
74    /// Score `documents` against `query`, returning at most `top_n` `(index,
75    /// score)` pairs. Implementations need not sort — [`GatewayReranker`] sorts
76    /// the returned scores itself.
77    ///
78    /// # Errors
79    /// Returns an error if the upstream rerank call fails; [`GatewayReranker`]
80    /// catches it and falls back to the input order.
81    async fn rerank(
82        &self,
83        query: &str,
84        documents: &[String],
85        top_n: usize,
86    ) -> Result<Vec<RerankScore>>;
87}
88
89/// The real backend: a Cohere/Voyage-style `/v1/rerank` call over the gateway.
90#[derive(Clone)]
91pub struct HttpRerankBackend {
92    client: reqwest::Client,
93    base_url: String,
94    api_key: String,
95    model: String,
96}
97
98impl HttpRerankBackend {
99    /// Build from explicit config.
100    #[must_use]
101    pub fn new(
102        base_url: impl Into<String>,
103        api_key: impl Into<String>,
104        model: impl Into<String>,
105    ) -> Self {
106        Self {
107            client: reqwest::Client::new(),
108            base_url: base_url.into(),
109            api_key: api_key.into(),
110            model: model.into(),
111        }
112    }
113}
114
115#[async_trait]
116impl RerankBackend for HttpRerankBackend {
117    async fn rerank(
118        &self,
119        query: &str,
120        documents: &[String],
121        top_n: usize,
122    ) -> Result<Vec<RerankScore>> {
123        // Trim a trailing slash so `{base}/v1/rerank` is well-formed whether the
124        // configured URL ends in `/` or not.
125        let url = format!("{}/v1/rerank", self.base_url.trim_end_matches('/'));
126        let body = serde_json::json!({
127            "model": self.model,
128            "query": query,
129            "documents": documents,
130            "top_n": top_n,
131        });
132
133        let resp = self
134            .client
135            .post(&url)
136            .bearer_auth(&self.api_key)
137            .json(&body)
138            .send()
139            .await?;
140
141        if !resp.status().is_success() {
142            let status = resp.status();
143            let text = resp.text().await.unwrap_or_default();
144            return Err(anyhow!("rerank request failed ({status}): {text}"));
145        }
146
147        #[derive(serde::Deserialize)]
148        struct ResultItem {
149            index: usize,
150            relevance_score: f32,
151        }
152        #[derive(serde::Deserialize)]
153        struct RerankResponse {
154            results: Vec<ResultItem>,
155        }
156
157        let parsed: RerankResponse = resp.json().await?;
158        Ok(parsed
159            .results
160            .into_iter()
161            .map(|r| RerankScore {
162                index: r.index,
163                relevance_score: r.relevance_score,
164            })
165            .collect())
166    }
167}
168
169/// Cross-encoder reranker over the SmooAI gateway's `/v1/rerank` endpoint
170/// (feature gap G8).
171///
172/// Reorders retrieval candidates by a sharp query↔candidate relevance model and
173/// truncates to `top_k`. On any backend failure it falls back to the input order
174/// (truncated) — a reranker is a quality stage, so an identity reorder is always
175/// a safe fallback. Construct with [`from_env`](Self::from_env) for the live
176/// gateway, or [`with_backend`](Self::with_backend) to inject a stub in tests.
177pub struct GatewayReranker {
178    backend: std::sync::Arc<dyn RerankBackend>,
179}
180
181impl GatewayReranker {
182    /// Build over an explicit backend. Production passes an
183    /// [`HttpRerankBackend`]; tests pass a stub.
184    #[must_use]
185    pub fn with_backend(backend: std::sync::Arc<dyn RerankBackend>) -> Self {
186        Self { backend }
187    }
188
189    /// Build the live gateway reranker from explicit config.
190    #[must_use]
191    pub fn new(
192        base_url: impl Into<String>,
193        api_key: impl Into<String>,
194        model: impl Into<String>,
195    ) -> Self {
196        Self::with_backend(std::sync::Arc::new(HttpRerankBackend::new(
197            base_url, api_key, model,
198        )))
199    }
200
201    /// Build from `SMOOAI_GATEWAY_URL` + `SMOOAI_GATEWAY_KEY`, defaulting the
202    /// model to [`DEFAULT_RERANK_MODEL`].
203    ///
204    /// # Errors
205    /// Returns an error if either environment variable is unset.
206    pub fn from_env() -> Result<Self> {
207        let base_url = std::env::var("SMOOAI_GATEWAY_URL")
208            .map_err(|_| anyhow!("SMOOAI_GATEWAY_URL is not set"))?;
209        let api_key = std::env::var("SMOOAI_GATEWAY_KEY")
210            .map_err(|_| anyhow!("SMOOAI_GATEWAY_KEY is not set"))?;
211        Ok(Self::new(base_url, api_key, DEFAULT_RERANK_MODEL))
212    }
213
214    /// Reorder `candidates` by the backend's `(index, score)` pairs.
215    ///
216    /// Scores are sorted descending; any candidate the backend did NOT score (or
217    /// scored with an out-of-range index, which is ignored) keeps a lower priority
218    /// than scored ones and retains its upstream order — so a partial response
219    /// still degrades sanely rather than dropping candidates.
220    fn reorder(
221        mut scores: Vec<RerankScore>,
222        candidates: Vec<KnowledgeResult>,
223        top_k: usize,
224    ) -> Vec<KnowledgeResult> {
225        let n = candidates.len();
226        // Move candidates into Options so we can take() each at most once.
227        let mut slots: Vec<Option<KnowledgeResult>> = candidates.into_iter().map(Some).collect();
228
229        // Stable sort by score descending so equal scores keep backend order.
230        scores.sort_by(|a, b| {
231            b.relevance_score
232                .partial_cmp(&a.relevance_score)
233                .unwrap_or(std::cmp::Ordering::Equal)
234        });
235
236        let mut out: Vec<KnowledgeResult> = Vec::with_capacity(top_k.min(n));
237        let mut taken = vec![false; n];
238        for s in scores {
239            if out.len() >= top_k {
240                break;
241            }
242            // Ignore out-of-range indices from a misbehaving backend rather than
243            // panicking.
244            if s.index < n && !taken[s.index] {
245                if let Some(c) = slots[s.index].take() {
246                    taken[s.index] = true;
247                    out.push(c);
248                }
249            }
250        }
251        // Append any unscored candidates (backend returned fewer than n, e.g. a
252        // top_n cutoff) in upstream order, until top_k is reached.
253        if out.len() < top_k {
254            for (i, slot) in slots.iter_mut().enumerate() {
255                if out.len() >= top_k {
256                    break;
257                }
258                if !taken[i] {
259                    if let Some(c) = slot.take() {
260                        out.push(c);
261                    }
262                }
263            }
264        }
265        out
266    }
267}
268
269#[async_trait]
270impl Reranker for GatewayReranker {
271    async fn rerank(
272        &self,
273        query: &str,
274        candidates: Vec<KnowledgeResult>,
275        top_k: usize,
276    ) -> Vec<KnowledgeResult> {
277        if candidates.is_empty() || top_k == 0 {
278            return Vec::new();
279        }
280        let documents: Vec<String> = candidates.iter().map(|c| c.chunk.clone()).collect();
281
282        match self.backend.rerank(query, &documents, top_k).await {
283            Ok(scores) => Self::reorder(scores, candidates, top_k),
284            Err(e) => {
285                // Quality stage: a rerank failure must never drop the turn. Fall
286                // back to the upstream order, truncated to top_k.
287                tracing::warn!(
288                    error = %e,
289                    "GatewayReranker call failed; falling back to upstream candidate order"
290                );
291                let mut fallback = candidates;
292                fallback.truncate(top_k);
293                fallback
294            }
295        }
296    }
297}
298
#[cfg(test)]
mod tests {
    // Unit tests for `GatewayReranker`. Every test except `live_rerank` injects a
    // stub `RerankBackend`, so no network access is needed.
    use super::*;
    use std::sync::Arc;

    /// Build a `KnowledgeResult` fixture with the given id and chunk text. The
    /// score is fixed at 0.5 and the source is `{id}.md`.
    fn result(id: &str, chunk: &str) -> KnowledgeResult {
        KnowledgeResult {
            document_id: id.to_string(),
            chunk: chunk.to_string(),
            score: 0.5,
            source: format!("{id}.md"),
        }
    }

    /// A stub backend that returns caller-supplied scores — no network. It
    /// ignores the query, documents and `top_n` it is given.
    struct StubBackend {
        scores: Vec<RerankScore>,
    }
    #[async_trait]
    impl RerankBackend for StubBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Ok(self.scores.clone())
        }
    }

    /// A stub backend that always errors — exercises the graceful fallback.
    struct ErrorBackend;
    #[async_trait]
    impl RerankBackend for ErrorBackend {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_n: usize,
        ) -> Result<Vec<RerankScore>> {
            Err(anyhow!("simulated rerank API failure"))
        }
    }

    /// TDD: the highest-relevance candidate is seeded LAST; the GatewayReranker
    /// must promote it to the front using the backend's scores.
    #[tokio::test]
    async fn gateway_reranker_reorders_by_relevance() {
        // Upstream order: index 0 weak, index 1 medium, index 2 strong.
        let candidates = vec![
            result("shipping", "shipping takes 5-7 days"),
            result("warranty", "warranty is one year"),
            result("returns", "30 day refund window"),
        ];
        // Backend says index 2 is most relevant, then 1, then 0.
        let scores = vec![
            RerankScore {
                index: 0,
                relevance_score: 0.1,
            },
            RerankScore {
                index: 1,
                relevance_score: 0.4,
            },
            RerankScore {
                index: 2,
                relevance_score: 0.95,
            },
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(StubBackend { scores }));
        let out = reranker.rerank("refund returns", candidates, 3).await;

        assert_eq!(
            out.iter()
                .map(|r| r.document_id.as_str())
                .collect::<Vec<_>>(),
            vec!["returns", "warranty", "shipping"],
            "candidates must be reordered by descending relevance score"
        );
    }

    /// With `top_k = 1`, only the single best-scored candidate survives.
    #[tokio::test]
    async fn gateway_reranker_truncates_to_top_k() {
        let candidates = vec![
            result("a", "alpha"),
            result("b", "beta"),
            result("c", "gamma"),
        ];
        let scores = vec![
            RerankScore {
                index: 2,
                relevance_score: 0.9,
            },
            RerankScore {
                index: 0,
                relevance_score: 0.5,
            },
            RerankScore {
                index: 1,
                relevance_score: 0.1,
            },
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(StubBackend { scores }));
        let out = reranker.rerank("q", candidates, 1).await;
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].document_id, "c", "top_k=1 keeps only the best");
    }

    /// API error → input order preserved (truncated), no panic, no drop.
    #[tokio::test]
    async fn gateway_reranker_error_falls_back_to_input_order() {
        let candidates = vec![
            result("first", "one"),
            result("second", "two"),
            result("third", "three"),
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(ErrorBackend));
        let out = reranker.rerank("anything", candidates, 2).await;

        assert_eq!(out.len(), 2, "fallback truncates to top_k");
        assert_eq!(
            out.iter()
                .map(|r| r.document_id.as_str())
                .collect::<Vec<_>>(),
            vec!["first", "second"],
            "on error the upstream order is preserved"
        );
    }

    /// A partial backend response (fewer scores than candidates) still returns
    /// the unscored candidates in upstream order rather than dropping them.
    #[tokio::test]
    async fn gateway_reranker_partial_scores_appends_unscored_in_order() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb"), result("c", "ccc")];
        // Backend only scored index 2.
        let scores = vec![RerankScore {
            index: 2,
            relevance_score: 0.9,
        }];
        let reranker = GatewayReranker::with_backend(Arc::new(StubBackend { scores }));
        let out = reranker.rerank("q", candidates, 3).await;
        assert_eq!(
            out.iter()
                .map(|r| r.document_id.as_str())
                .collect::<Vec<_>>(),
            vec!["c", "a", "b"],
            "scored candidate first, then unscored in upstream order"
        );
    }

    /// An out-of-range index from a misbehaving backend is ignored, not panicked.
    #[tokio::test]
    async fn gateway_reranker_ignores_out_of_range_index() {
        let candidates = vec![result("a", "aaa"), result("b", "bbb")];
        let scores = vec![
            RerankScore {
                index: 99, // out of range
                relevance_score: 0.99,
            },
            RerankScore {
                index: 1,
                relevance_score: 0.5,
            },
        ];
        let reranker = GatewayReranker::with_backend(Arc::new(StubBackend { scores }));
        let out = reranker.rerank("q", candidates, 2).await;
        // index 99 ignored; index 1 promoted, then unscored index 0 appended.
        assert_eq!(
            out.iter()
                .map(|r| r.document_id.as_str())
                .collect::<Vec<_>>(),
            vec!["b", "a"]
        );
    }

    /// No candidates in → no candidates out (the backend is not consulted).
    #[tokio::test]
    async fn gateway_reranker_empty_candidates_yields_empty() {
        let reranker = GatewayReranker::with_backend(Arc::new(StubBackend { scores: vec![] }));
        let out = reranker.rerank("q", vec![], 3).await;
        assert!(out.is_empty());
    }

    /// Live rerank against the real gateway — only with `SMOOTH_AGENT_E2E=1` and a
    /// `SMOOAI_GATEWAY_KEY`. Ignored by default (network + creds + the gateway
    /// must actually expose a `/v1/rerank` route, which is not guaranteed). Run:
    /// `SMOOTH_AGENT_E2E=1 cargo test -p smooai-smooth-operator-adapter-postgres \
    ///    reranker::tests::live_rerank -- --ignored --nocapture`
    #[tokio::test]
    #[ignore = "network + creds: gated on SMOOTH_AGENT_E2E=1 and a /v1/rerank route"]
    async fn live_rerank() {
        // Even when run with `--ignored`, skip (pass) unless explicitly opted in.
        if std::env::var("SMOOTH_AGENT_E2E").as_deref() != Ok("1") {
            eprintln!("skipping live rerank: set SMOOTH_AGENT_E2E=1 to run");
            return;
        }
        let Ok(reranker) = GatewayReranker::from_env() else {
            eprintln!("skipping live rerank: SMOOAI_GATEWAY_URL / SMOOAI_GATEWAY_KEY not set");
            return;
        };
        let candidates = vec![
            result("shipping", "Standard shipping takes 5 to 7 business days."),
            result("warranty", "Warranty claims must be filed within one year."),
            result(
                "returns",
                "Our return policy: refunds within the 30 day window.",
            ),
        ];
        let out = reranker
            .rerank("how do refunds and returns work", candidates, 3)
            .await;
        eprintln!(
            "live rerank order: {:?}",
            out.iter()
                .map(|r| r.document_id.as_str())
                .collect::<Vec<_>>()
        );
        // Only the count is asserted: the exact live ordering depends on the
        // model, and a backend failure would still return all 3 via fallback.
        assert_eq!(out.len(), 3, "live rerank should return all 3 reordered");
    }
}