sochdb_vector/
wasm_rerank.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct RerankCandidate {
8 pub doc_id: u64,
9 pub text: String,
10 pub fused_score: f32,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct RerankResult {
15 pub doc_id: u64,
16 pub rerank_score: f32,
17 pub fused_score: f32,
18}
19
/// Handle for a rerank plugin identified by `plugin_id`.
///
/// NOTE(review): `loaded` is a plain flag; nothing in this type instantiates a
/// WASM module — confirm whether that happens elsewhere.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WasmRerankPlugin {
    /// Identifier of the plugin.
    pub plugin_id: String,
    /// Whether `load` has been called on this handle.
    pub loaded: bool,
}
25
26impl WasmRerankPlugin {
27 pub fn new(plugin_id: impl Into<String>) -> Self {
28 Self {
29 plugin_id: plugin_id.into(),
30 loaded: false,
31 }
32 }
33
34 pub fn load(&mut self) {
35 self.loaded = true;
36 }
37
38 pub fn rerank(&self, candidates: &[RerankCandidate], top_k: usize) -> Vec<RerankResult> {
41 let mut scored: Vec<RerankResult> = candidates
42 .iter()
43 .map(|c| RerankResult {
44 doc_id: c.doc_id,
45 rerank_score: c.fused_score,
46 fused_score: c.fused_score,
47 })
48 .collect();
49 scored.sort_by(|a, b| {
50 b.rerank_score
51 .partial_cmp(&a.rerank_score)
52 .unwrap_or(std::cmp::Ordering::Equal)
53 });
54 scored.truncate(top_k);
55 scored
56 }
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct ProvenanceRerankOutput {
61 pub doc_id: u64,
62 pub rerank_score: f32,
63 pub episode_id: Option<u64>,
64 pub trust_score: f32,
65}
66
67pub fn attach_provenance(
68 reranked: &[RerankResult],
69 provenance: &HashMap<u64, (Option<u64>, f32)>,
70) -> Vec<ProvenanceRerankOutput> {
71 reranked
72 .iter()
73 .map(|r| {
74 let (ep, trust) = provenance.get(&r.doc_id).copied().unwrap_or((None, 0.5));
75 ProvenanceRerankOutput {
76 doc_id: r.doc_id,
77 rerank_score: r.rerank_score,
78 episode_id: ep,
79 trust_score: trust,
80 }
81 })
82 .collect()
83}