Skip to main content

sochdb_vector/
wasm_rerank.rs

1//! WASM rerank plugin interface for in-engine cross-encoder reranking.
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct RerankCandidate {
8    pub doc_id: u64,
9    pub text: String,
10    pub fused_score: f32,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct RerankResult {
15    pub doc_id: u64,
16    pub rerank_score: f32,
17    pub fused_score: f32,
18}
19
20#[derive(Debug, Clone)]
21pub struct WasmRerankPlugin {
22    pub plugin_id: String,
23    pub loaded: bool,
24}
25
26impl WasmRerankPlugin {
27    pub fn new(plugin_id: impl Into<String>) -> Self {
28        Self {
29            plugin_id: plugin_id.into(),
30            loaded: false,
31        }
32    }
33
34    pub fn load(&mut self) {
35        self.loaded = true;
36    }
37
38    /// Rerank top-n candidates. Production path delegates to WASM sandbox;
39    /// fallback uses fused score ordering.
40    pub fn rerank(&self, candidates: &[RerankCandidate], top_k: usize) -> Vec<RerankResult> {
41        let mut scored: Vec<RerankResult> = candidates
42            .iter()
43            .map(|c| RerankResult {
44                doc_id: c.doc_id,
45                rerank_score: c.fused_score,
46                fused_score: c.fused_score,
47            })
48            .collect();
49        scored.sort_by(|a, b| {
50            b.rerank_score
51                .partial_cmp(&a.rerank_score)
52                .unwrap_or(std::cmp::Ordering::Equal)
53        });
54        scored.truncate(top_k);
55        scored
56    }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ProvenanceRerankOutput {
61    pub doc_id: u64,
62    pub rerank_score: f32,
63    pub episode_id: Option<u64>,
64    pub trust_score: f32,
65}
66
67pub fn attach_provenance(
68    reranked: &[RerankResult],
69    provenance: &HashMap<u64, (Option<u64>, f32)>,
70) -> Vec<ProvenanceRerankOutput> {
71    reranked
72        .iter()
73        .map(|r| {
74            let (ep, trust) = provenance.get(&r.doc_id).copied().unwrap_or((None, 0.5));
75            ProvenanceRerankOutput {
76                doc_id: r.doc_id,
77                rerank_score: r.rerank_score,
78                episode_id: ep,
79                trust_score: trust,
80            }
81        })
82        .collect()
83}