Skip to main content

ruvector_verified_wasm/
lib.rs

//! WASM bindings for `ruvector-verified`: proof-carrying vector operations in the browser.
//!
//! # Quick Start (JavaScript)
//!
//! ```js
//! import init, { JsProofEnv } from "ruvector-verified-wasm";
//!
//! await init();
//! const env = new JsProofEnv();
//!
//! // Prove dimension equality (~500ns)
//! const proofId = env.prove_dim_eq(384, 384);  // Ok -> proof ID
//!
//! // Verify a batch of vectors, passed as one flat Float32Array (`dim` floats per vector)
//! const flat = new Float32Array(2 * 384).fill(0.5);
//! const count = env.verify_batch_flat(384, flat);  // Ok -> number of vectors verified
//!
//! // Get statistics
//! console.log(env.stats());
//!
//! // Create attestation (82 bytes)
//! const att = env.create_attestation(proofId);
//! console.log(att.bytes.length); // 82
//! ```
25
26mod utils;
27
28use ruvector_verified::{
29    ProofEnvironment,
30    fast_arena::FastTermArena,
31    cache::ConversionCache,
32    gated::{self, ProofKind, ProofTier},
33    proof_store,
34    vector_types,
35};
36use serde::Serialize;
37use wasm_bindgen::prelude::*;
38
39// ---------------------------------------------------------------------------
40// Module init
41// ---------------------------------------------------------------------------
42
43/// Called automatically when the WASM module is loaded.
44#[wasm_bindgen(start)]
45pub fn init() {
46    utils::set_panic_hook();
47    utils::console_log("ruvector-verified-wasm loaded");
48}
49
50/// Return the crate version.
51#[wasm_bindgen]
52pub fn version() -> String {
53    env!("CARGO_PKG_VERSION").to_string()
54}
55
56// ---------------------------------------------------------------------------
57// JsProofEnv — main entry point
58// ---------------------------------------------------------------------------
59
60/// Proof environment for the browser. Wraps `ProofEnvironment` + ultra caches.
61#[wasm_bindgen]
62pub struct JsProofEnv {
63    env: ProofEnvironment,
64    arena: FastTermArena,
65    cache: ConversionCache,
66}
67
68#[wasm_bindgen]
69impl JsProofEnv {
70    /// Create a new proof environment with all optimizations.
71    #[wasm_bindgen(constructor)]
72    pub fn new() -> Self {
73        Self {
74            env: ProofEnvironment::new(),
75            arena: FastTermArena::with_capacity(4096),
76            cache: ConversionCache::with_capacity(1024),
77        }
78    }
79
80    /// Prove that two dimensions are equal. Returns proof term ID.
81    ///
82    /// Throws if dimensions don't match.
83    pub fn prove_dim_eq(&mut self, expected: u32, actual: u32) -> Result<u32, JsError> {
84        vector_types::prove_dim_eq(&mut self.env, expected, actual)
85            .map_err(|e| JsError::new(&e.to_string()))
86    }
87
88    /// Build a `RuVec n` type term. Returns term ID.
89    pub fn mk_vector_type(&mut self, dim: u32) -> Result<u32, JsError> {
90        vector_types::mk_vector_type(&mut self.env, dim)
91            .map_err(|e| JsError::new(&e.to_string()))
92    }
93
94    /// Build a distance metric type term. Supported: "L2", "Cosine", "Dot".
95    pub fn mk_distance_metric(&mut self, metric: &str) -> Result<u32, JsError> {
96        vector_types::mk_distance_metric(&mut self.env, metric)
97            .map_err(|e| JsError::new(&e.to_string()))
98    }
99
100    /// Verify that a single vector has the expected dimension.
101    pub fn verify_dim_check(&mut self, index_dim: u32, vector: &[f32]) -> Result<u32, JsError> {
102        vector_types::verified_dim_check(&mut self.env, index_dim, vector)
103            .map(|op| op.proof_id)
104            .map_err(|e| JsError::new(&e.to_string()))
105    }
106
107    /// Verify a batch of vectors (passed as flat f32 array + dimension).
108    ///
109    /// `flat_vectors` is a contiguous f32 array; each vector is `dim` elements.
110    /// Returns the number of vectors verified.
111    pub fn verify_batch_flat(
112        &mut self,
113        dim: u32,
114        flat_vectors: &[f32],
115    ) -> Result<u32, JsError> {
116        let d = dim as usize;
117        if flat_vectors.len() % d != 0 {
118            return Err(JsError::new(&format!(
119                "flat_vectors length {} not divisible by dim {}",
120                flat_vectors.len(), dim
121            )));
122        }
123        let slices: Vec<&[f32]> = flat_vectors.chunks_exact(d).collect();
124        vector_types::verify_batch_dimensions(&mut self.env, dim, &slices)
125            .map(|op| op.value as u32)
126            .map_err(|e| JsError::new(&e.to_string()))
127    }
128
129    /// Intern a hash into the FastTermArena. Returns `[term_id, was_cached]`.
130    pub fn arena_intern(&self, hash_hi: u32, hash_lo: u32) -> Vec<u32> {
131        let hash = (hash_hi as u64) << 32 | hash_lo as u64;
132        let (id, cached) = self.arena.intern(hash);
133        vec![id, if cached { 1 } else { 0 }]
134    }
135
136    /// Route a proof to the cheapest tier. Returns tier name.
137    pub fn route_proof(&self, kind: &str) -> Result<JsValue, JsError> {
138        let proof_kind = match kind {
139            "reflexivity" => ProofKind::Reflexivity,
140            "dimension" => ProofKind::DimensionEquality { expected: 0, actual: 0 },
141            "pipeline" => ProofKind::PipelineComposition { stages: 1 },
142            other => ProofKind::Custom { estimated_complexity: other.parse().unwrap_or(10) },
143        };
144        let decision = gated::route_proof(proof_kind, &self.env);
145        let tier_name = match decision.tier {
146            ProofTier::Reflex => "reflex",
147            ProofTier::Standard { .. } => "standard",
148            ProofTier::Deep => "deep",
149        };
150        let result = JsRoutingResult {
151            tier: tier_name.to_string(),
152            reason: decision.reason.to_string(),
153            estimated_steps: decision.estimated_steps,
154        };
155        serde_wasm_bindgen::to_value(&result)
156            .map_err(|e| JsError::new(&e.to_string()))
157    }
158
159    /// Create a proof attestation (82 bytes). Returns serializable object.
160    pub fn create_attestation(&self, proof_id: u32) -> Result<JsValue, JsError> {
161        let att = proof_store::create_attestation(&self.env, proof_id);
162        let bytes = att.to_bytes();
163        let result = JsAttestation {
164            bytes,
165            proof_term_hash: hex_encode(&att.proof_term_hash),
166            environment_hash: hex_encode(&att.environment_hash),
167            verifier_version: format!("{:#010x}", att.verifier_version),
168            reduction_steps: att.reduction_steps,
169            cache_hit_rate_bps: att.cache_hit_rate_bps,
170        };
171        serde_wasm_bindgen::to_value(&result)
172            .map_err(|e| JsError::new(&e.to_string()))
173    }
174
175    /// Get verification statistics.
176    pub fn stats(&self) -> Result<JsValue, JsError> {
177        let s = self.env.stats();
178        let arena_stats = self.arena.stats();
179        let cache_stats = self.cache.stats();
180        let result = JsStats {
181            proofs_constructed: s.proofs_constructed,
182            proofs_verified: s.proofs_verified,
183            cache_hits: s.cache_hits,
184            cache_misses: s.cache_misses,
185            total_reductions: s.total_reductions,
186            terms_allocated: self.env.terms_allocated(),
187            arena_hit_rate: arena_stats.cache_hit_rate(),
188            conversion_cache_hit_rate: cache_stats.hit_rate(),
189        };
190        serde_wasm_bindgen::to_value(&result)
191            .map_err(|e| JsError::new(&e.to_string()))
192    }
193
194    /// Reset the environment (clears cache, resets counters, re-registers builtins).
195    pub fn reset(&mut self) {
196        self.env.reset();
197        self.arena.reset();
198        self.cache.clear();
199    }
200
201    /// Number of terms currently allocated.
202    pub fn terms_allocated(&self) -> u32 {
203        self.env.terms_allocated()
204    }
205}
206
207// ---------------------------------------------------------------------------
208// JSON result types
209// ---------------------------------------------------------------------------
210
211#[derive(Serialize)]
212struct JsRoutingResult {
213    tier: String,
214    reason: String,
215    estimated_steps: u32,
216}
217
218#[derive(Serialize)]
219struct JsAttestation {
220    bytes: Vec<u8>,
221    proof_term_hash: String,
222    environment_hash: String,
223    verifier_version: String,
224    reduction_steps: u32,
225    cache_hit_rate_bps: u16,
226}
227
228#[derive(Serialize)]
229struct JsStats {
230    proofs_constructed: u64,
231    proofs_verified: u64,
232    cache_hits: u64,
233    cache_misses: u64,
234    total_reductions: u64,
235    terms_allocated: u32,
236    arena_hit_rate: f64,
237    conversion_cache_hit_rate: f64,
238}
239
/// Render `bytes` as a lowercase hexadecimal string, two digits per byte.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut encoded = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        encoded.push(char::from(DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    encoded
}