1use ternlang_core::trit::Trit;
13use ternlang_ml::{TritMatrix, quantize, bitnet_threshold};
14use rayon::prelude::*;
15use serde::{Serialize, Deserialize};
16
/// Ternary-quantized embedding store searched by integer dot-product similarity.
///
/// Built from dense `f32` vectors with [`RuVectorDB::from_f32`] and queried with
/// [`RuVectorDB::search`].
#[derive(Serialize, Deserialize)]
pub struct RuVectorDB {
    /// Quantized embeddings, one row per document.
    pub embeddings: TritMatrix,
    /// Per-document labels; `metadata[i]` is reported for embedding row `i`.
    pub metadata: Vec<String>,
}
25
26impl RuVectorDB {
27 pub fn from_f32(embeddings: &[Vec<f32>], metadata: Vec<String>) -> anyhow::Result<Self> {
29 if embeddings.is_empty() {
30 return Err(anyhow::anyhow!("Embeddings cannot be empty"));
31 }
32 let rows = embeddings.len();
33 let cols = embeddings[0].len();
34
35 let flat_f32: Vec<f32> = embeddings.iter().flatten().cloned().collect();
37 let threshold = bitnet_threshold(&flat_f32);
38
39 let trit_matrix = TritMatrix::from_f32(rows, cols, &flat_f32, threshold);
40
41 Ok(Self {
42 embeddings: trit_matrix,
43 metadata,
44 })
45 }
46
47 pub fn search(&self, query_f32: &[f32], top_k: usize) -> Vec<SearchResult> {
51 let threshold = bitnet_threshold(query_f32);
52 let query_trits = quantize(query_f32, threshold);
53
54 let scores = self.sparse_gemv_similarity(&query_trits);
56
57 let mut results: Vec<SearchResult> = scores.into_iter()
58 .enumerate()
59 .map(|(i, score)| SearchResult {
60 index: i,
61 score,
62 metadata: self.metadata.get(i).cloned().unwrap_or_default(),
63 })
64 .collect();
65
66 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
68 results.truncate(top_k);
69
70 results
71 }
72
73 fn sparse_gemv_similarity(&self, query: &[Trit]) -> Vec<f32> {
76 let num_docs = self.embeddings.rows;
77 let dim = self.embeddings.cols;
78
79 let q_flat: Vec<i8> = query.iter().map(|&t| match t {
81 Trit::Affirm => 1,
82 Trit::Reject => -1,
83 Trit::Tend => 0,
84 }).collect();
85
86 let db_flat = self.embeddings.to_i8_vec();
88
89 (0..num_docs).into_par_iter().map(|row_idx| {
91 let row_data = &db_flat[row_idx * dim .. (row_idx + 1) * dim];
92 let mut acc: i32 = 0;
93
94 for i in 0..dim {
97 let qi = q_flat[i];
98 if qi == 0 { continue; }
99
100 let di = row_data[i];
101 if di == 0 { continue; }
102
103 acc += (qi * di) as i32;
104 }
105
106 acc as f32
107 }).collect()
108 }
109}
110
/// A single hit returned by [`RuVectorDB::search`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Row index of the matching embedding in the database.
    pub index: usize,
    /// Dot-product similarity between the quantized query and the row, stored as
    /// `f32`; higher means more similar.
    pub score: f32,
    /// Metadata string for that row, or an empty string if none was supplied.
    pub metadata: String,
}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ruvector_search() {
        // Three 4-dimensional documents; the assertions expect "Doc A" to rank first.
        let docs: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, -1.0, 0.5],
            vec![-1.0, 1.0, 0.0, 0.0],
            vec![0.1, 0.1, 0.1, 0.1],
        ];
        let labels: Vec<String> = ["Doc A", "Doc B", "Doc C"]
            .iter()
            .map(|label| label.to_string())
            .collect();

        let db = RuVectorDB::from_f32(&docs, labels)
            .expect("non-empty rectangular input must build");

        let query: Vec<f32> = vec![1.0, 0.0, -1.0, 0.0];
        let hits = db.search(&query, 2);

        assert_eq!(hits.len(), 2);
        let (best, runner_up) = (&hits[0], &hits[1]);
        assert_eq!(best.metadata, "Doc A");
        assert!(best.score > runner_up.score);
    }
}