ruvector_math_wasm/
lib.rs

1//! WebAssembly bindings for ruvector-math
2//!
3//! This crate provides JavaScript/TypeScript bindings for the advanced
4//! mathematics in ruvector-math, enabling browser-based vector search
5//! with optimal transport, information geometry, and product manifolds.
6
7use wasm_bindgen::prelude::*;
8use ruvector_math::{
9    optimal_transport::{SlicedWasserstein, SinkhornSolver, GromovWasserstein},
10    information_geometry::{FisherInformation, NaturalGradient},
11    spherical::SphericalSpace,
12    product_manifold::ProductManifold,
13};
14
15#[wasm_bindgen(start)]
16pub fn start() {
17    #[cfg(feature = "console_error_panic_hook")]
18    console_error_panic_hook::set_once();
19}
20
21// ============================================================================
22// Optimal Transport
23// ============================================================================
24
25/// Sliced Wasserstein distance calculator for WASM
26#[wasm_bindgen]
27pub struct WasmSlicedWasserstein {
28    inner: SlicedWasserstein,
29}
30
31#[wasm_bindgen]
32impl WasmSlicedWasserstein {
33    /// Create a new Sliced Wasserstein calculator
34    ///
35    /// @param num_projections - Number of random 1D projections (100-1000 typical)
36    #[wasm_bindgen(constructor)]
37    pub fn new(num_projections: usize) -> Self {
38        Self {
39            inner: SlicedWasserstein::new(num_projections),
40        }
41    }
42
43    /// Set Wasserstein power (1 for W1, 2 for W2)
44    #[wasm_bindgen(js_name = withPower)]
45    pub fn with_power(self, p: f64) -> Self {
46        Self {
47            inner: self.inner.with_power(p),
48        }
49    }
50
51    /// Set random seed for reproducibility
52    #[wasm_bindgen(js_name = withSeed)]
53    pub fn with_seed(self, seed: u64) -> Self {
54        Self {
55            inner: self.inner.with_seed(seed),
56        }
57    }
58
59    /// Compute distance between two point clouds
60    ///
61    /// @param source - Source points as flat array [x1, y1, z1, x2, y2, z2, ...]
62    /// @param target - Target points as flat array
63    /// @param dim - Dimension of each point
64    #[wasm_bindgen]
65    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> f64 {
66        use ruvector_math::optimal_transport::OptimalTransport;
67
68        let source_points = to_points(source, dim);
69        let target_points = to_points(target, dim);
70
71        self.inner.distance(&source_points, &target_points)
72    }
73
74    /// Compute weighted distance
75    #[wasm_bindgen(js_name = weightedDistance)]
76    pub fn weighted_distance(
77        &self,
78        source: &[f64],
79        source_weights: &[f64],
80        target: &[f64],
81        target_weights: &[f64],
82        dim: usize,
83    ) -> f64 {
84        use ruvector_math::optimal_transport::OptimalTransport;
85
86        let source_points = to_points(source, dim);
87        let target_points = to_points(target, dim);
88
89        self.inner.weighted_distance(
90            &source_points,
91            source_weights,
92            &target_points,
93            target_weights,
94        )
95    }
96}
97
98/// Sinkhorn optimal transport solver for WASM
99#[wasm_bindgen]
100pub struct WasmSinkhorn {
101    inner: SinkhornSolver,
102}
103
104#[wasm_bindgen]
105impl WasmSinkhorn {
106    /// Create a new Sinkhorn solver
107    ///
108    /// @param regularization - Entropy regularization (0.01-0.1 typical)
109    /// @param max_iterations - Maximum iterations (100-1000 typical)
110    #[wasm_bindgen(constructor)]
111    pub fn new(regularization: f64, max_iterations: usize) -> Self {
112        Self {
113            inner: SinkhornSolver::new(regularization, max_iterations),
114        }
115    }
116
117    /// Compute transport cost between point clouds
118    #[wasm_bindgen]
119    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
120        let source_points = to_points(source, dim);
121        let target_points = to_points(target, dim);
122
123        self.inner
124            .distance(&source_points, &target_points)
125            .map_err(|e| JsError::new(&e.to_string()))
126    }
127
128    /// Solve optimal transport and return transport plan
129    #[wasm_bindgen(js_name = solveTransport)]
130    pub fn solve_transport(
131        &self,
132        cost_matrix: &[f64],
133        source_weights: &[f64],
134        target_weights: &[f64],
135        n: usize,
136        m: usize,
137    ) -> Result<TransportResult, JsError> {
138        let cost = to_matrix(cost_matrix, n, m);
139
140        let result = self
141            .inner
142            .solve(&cost, source_weights, target_weights)
143            .map_err(|e| JsError::new(&e.to_string()))?;
144
145        Ok(TransportResult {
146            plan: result.plan.into_iter().flatten().collect(),
147            cost: result.cost,
148            iterations: result.iterations,
149            converged: result.converged,
150        })
151    }
152}
153
154/// Result of Sinkhorn transport computation
155#[wasm_bindgen]
156pub struct TransportResult {
157    plan: Vec<f64>,
158    cost: f64,
159    iterations: usize,
160    converged: bool,
161}
162
163#[wasm_bindgen]
164impl TransportResult {
165    /// Get transport plan as flat array
166    #[wasm_bindgen(getter)]
167    pub fn plan(&self) -> Vec<f64> {
168        self.plan.clone()
169    }
170
171    /// Get total transport cost
172    #[wasm_bindgen(getter)]
173    pub fn cost(&self) -> f64 {
174        self.cost
175    }
176
177    /// Get number of iterations
178    #[wasm_bindgen(getter)]
179    pub fn iterations(&self) -> usize {
180        self.iterations
181    }
182
183    /// Whether algorithm converged
184    #[wasm_bindgen(getter)]
185    pub fn converged(&self) -> bool {
186        self.converged
187    }
188}
189
190/// Gromov-Wasserstein distance for WASM
191#[wasm_bindgen]
192pub struct WasmGromovWasserstein {
193    inner: GromovWasserstein,
194}
195
196#[wasm_bindgen]
197impl WasmGromovWasserstein {
198    /// Create a new Gromov-Wasserstein calculator
199    #[wasm_bindgen(constructor)]
200    pub fn new(regularization: f64) -> Self {
201        Self {
202            inner: GromovWasserstein::new(regularization),
203        }
204    }
205
206    /// Compute GW distance between point clouds
207    #[wasm_bindgen]
208    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
209        let source_points = to_points(source, dim);
210        let target_points = to_points(target, dim);
211
212        self.inner
213            .distance(&source_points, &target_points)
214            .map_err(|e| JsError::new(&e.to_string()))
215    }
216}
217
218// ============================================================================
219// Information Geometry
220// ============================================================================
221
222/// Fisher Information for WASM
223#[wasm_bindgen]
224pub struct WasmFisherInformation {
225    inner: FisherInformation,
226}
227
228#[wasm_bindgen]
229impl WasmFisherInformation {
230    /// Create a new Fisher Information calculator
231    #[wasm_bindgen(constructor)]
232    pub fn new() -> Self {
233        Self {
234            inner: FisherInformation::new(),
235        }
236    }
237
238    /// Set damping factor
239    #[wasm_bindgen(js_name = withDamping)]
240    pub fn with_damping(self, damping: f64) -> Self {
241        Self {
242            inner: self.inner.with_damping(damping),
243        }
244    }
245
246    /// Compute diagonal FIM from gradient samples
247    #[wasm_bindgen(js_name = diagonalFim)]
248    pub fn diagonal_fim(&self, gradients: &[f64], _num_samples: usize, dim: usize) -> Result<Vec<f64>, JsError> {
249        let grads = to_points(gradients, dim);
250        self.inner
251            .diagonal_fim(&grads)
252            .map_err(|e| JsError::new(&e.to_string()))
253    }
254
255    /// Compute natural gradient
256    #[wasm_bindgen(js_name = naturalGradient)]
257    pub fn natural_gradient(
258        &self,
259        fim_diag: &[f64],
260        gradient: &[f64],
261        damping: f64,
262    ) -> Vec<f64> {
263        gradient
264            .iter()
265            .zip(fim_diag.iter())
266            .map(|(&g, &f)| g / (f + damping))
267            .collect()
268    }
269}
270
271/// Natural Gradient optimizer for WASM
272#[wasm_bindgen]
273pub struct WasmNaturalGradient {
274    inner: NaturalGradient,
275}
276
277#[wasm_bindgen]
278impl WasmNaturalGradient {
279    /// Create a new Natural Gradient optimizer
280    #[wasm_bindgen(constructor)]
281    pub fn new(learning_rate: f64) -> Self {
282        Self {
283            inner: NaturalGradient::new(learning_rate),
284        }
285    }
286
287    /// Set damping factor
288    #[wasm_bindgen(js_name = withDamping)]
289    pub fn with_damping(self, damping: f64) -> Self {
290        Self {
291            inner: self.inner.with_damping(damping),
292        }
293    }
294
295    /// Use diagonal approximation
296    #[wasm_bindgen(js_name = withDiagonal)]
297    pub fn with_diagonal(self, use_diagonal: bool) -> Self {
298        Self {
299            inner: self.inner.with_diagonal(use_diagonal),
300        }
301    }
302
303    /// Compute update step
304    #[wasm_bindgen]
305    pub fn step(
306        &mut self,
307        gradient: &[f64],
308        gradient_samples: Option<Vec<f64>>,
309        dim: usize,
310    ) -> Result<Vec<f64>, JsError> {
311        let samples = gradient_samples.map(|s| to_points(&s, dim));
312
313        self.inner
314            .step(gradient, samples.as_deref())
315            .map_err(|e| JsError::new(&e.to_string()))
316    }
317
318    /// Reset optimizer state
319    #[wasm_bindgen]
320    pub fn reset(&mut self) {
321        self.inner.reset();
322    }
323}
324
325// ============================================================================
326// Spherical Geometry
327// ============================================================================
328
329/// Spherical space operations for WASM
330#[wasm_bindgen]
331pub struct WasmSphericalSpace {
332    inner: SphericalSpace,
333}
334
335#[wasm_bindgen]
336impl WasmSphericalSpace {
337    /// Create a new spherical space S^{n-1} embedded in R^n
338    #[wasm_bindgen(constructor)]
339    pub fn new(ambient_dim: usize) -> Self {
340        Self {
341            inner: SphericalSpace::new(ambient_dim),
342        }
343    }
344
345    /// Get ambient dimension
346    #[wasm_bindgen(getter, js_name = ambientDim)]
347    pub fn ambient_dim(&self) -> usize {
348        self.inner.ambient_dim()
349    }
350
351    /// Project point onto sphere
352    #[wasm_bindgen]
353    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
354        self.inner
355            .project(point)
356            .map_err(|e| JsError::new(&e.to_string()))
357    }
358
359    /// Geodesic distance on sphere
360    #[wasm_bindgen]
361    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
362        self.inner
363            .distance(x, y)
364            .map_err(|e| JsError::new(&e.to_string()))
365    }
366
367    /// Exponential map: move from x in direction v
368    #[wasm_bindgen(js_name = expMap)]
369    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
370        self.inner
371            .exp_map(x, v)
372            .map_err(|e| JsError::new(&e.to_string()))
373    }
374
375    /// Logarithmic map: tangent vector at x pointing toward y
376    #[wasm_bindgen(js_name = logMap)]
377    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
378        self.inner
379            .log_map(x, y)
380            .map_err(|e| JsError::new(&e.to_string()))
381    }
382
383    /// Geodesic interpolation at fraction t
384    #[wasm_bindgen]
385    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
386        self.inner
387            .geodesic(x, y, t)
388            .map_err(|e| JsError::new(&e.to_string()))
389    }
390
391    /// Fréchet mean of points
392    #[wasm_bindgen(js_name = frechetMean)]
393    pub fn frechet_mean(&self, points: &[f64], dim: usize) -> Result<Vec<f64>, JsError> {
394        let pts = to_points(points, dim);
395        self.inner
396            .frechet_mean(&pts, None)
397            .map_err(|e| JsError::new(&e.to_string()))
398    }
399}
400
401// ============================================================================
402// Product Manifolds
403// ============================================================================
404
405/// Product manifold for WASM: E^e × H^h × S^s
406#[wasm_bindgen]
407pub struct WasmProductManifold {
408    inner: ProductManifold,
409}
410
411#[wasm_bindgen]
412impl WasmProductManifold {
413    /// Create a new product manifold
414    ///
415    /// @param euclidean_dim - Dimension of Euclidean component
416    /// @param hyperbolic_dim - Dimension of hyperbolic component
417    /// @param spherical_dim - Dimension of spherical component
418    #[wasm_bindgen(constructor)]
419    pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
420        Self {
421            inner: ProductManifold::new(euclidean_dim, hyperbolic_dim, spherical_dim),
422        }
423    }
424
425    /// Get total dimension
426    #[wasm_bindgen(getter)]
427    pub fn dim(&self) -> usize {
428        self.inner.dim()
429    }
430
431    /// Project point onto manifold
432    #[wasm_bindgen]
433    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
434        self.inner
435            .project(point)
436            .map_err(|e| JsError::new(&e.to_string()))
437    }
438
439    /// Compute distance in product manifold
440    #[wasm_bindgen]
441    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
442        self.inner
443            .distance(x, y)
444            .map_err(|e| JsError::new(&e.to_string()))
445    }
446
447    /// Exponential map
448    #[wasm_bindgen(js_name = expMap)]
449    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
450        self.inner
451            .exp_map(x, v)
452            .map_err(|e| JsError::new(&e.to_string()))
453    }
454
455    /// Logarithmic map
456    #[wasm_bindgen(js_name = logMap)]
457    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
458        self.inner
459            .log_map(x, y)
460            .map_err(|e| JsError::new(&e.to_string()))
461    }
462
463    /// Geodesic interpolation
464    #[wasm_bindgen]
465    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
466        self.inner
467            .geodesic(x, y, t)
468            .map_err(|e| JsError::new(&e.to_string()))
469    }
470
471    /// Fréchet mean
472    #[wasm_bindgen(js_name = frechetMean)]
473    pub fn frechet_mean(&self, points: &[f64], _num_points: usize) -> Result<Vec<f64>, JsError> {
474        let dim = self.inner.dim();
475        let pts = to_points(points, dim);
476        self.inner
477            .frechet_mean(&pts, None)
478            .map_err(|e| JsError::new(&e.to_string()))
479    }
480
481    /// K-nearest neighbors
482    #[wasm_bindgen]
483    pub fn knn(&self, query: &[f64], points: &[f64], k: usize) -> Result<Vec<u32>, JsError> {
484        let dim = self.inner.dim();
485        let pts = to_points(points, dim);
486        let neighbors = self
487            .inner
488            .knn(query, &pts, k)
489            .map_err(|e| JsError::new(&e.to_string()))?;
490
491        Ok(neighbors.into_iter().map(|(idx, _)| idx as u32).collect())
492    }
493
494    /// Pairwise distances
495    #[wasm_bindgen(js_name = pairwiseDistances)]
496    pub fn pairwise_distances(&self, points: &[f64]) -> Result<Vec<f64>, JsError> {
497        let dim = self.inner.dim();
498        let pts = to_points(points, dim);
499        let dists = self
500            .inner
501            .pairwise_distances(&pts)
502            .map_err(|e| JsError::new(&e.to_string()))?;
503
504        Ok(dists.into_iter().flatten().collect())
505    }
506}
507
508// ============================================================================
509// Utility functions
510// ============================================================================
511
/// Convert a flat array into consecutive points of `dim` coordinates each.
///
/// A trailing run shorter than `dim` is kept as a final, shorter point. This preserves the
/// original slicing behaviour; wrappers that can return an error should validate
/// `flat.len() % dim == 0` first.
///
/// A `dim` of zero yields no points instead of panicking. `slice::chunks(0)` panics, which
/// in WebAssembly traps the module.
fn to_points(flat: &[f64], dim: usize) -> Vec<Vec<f64>> {
    if dim == 0 {
        return Vec::new();
    }
    flat.chunks(dim).map(|point| point.to_vec()).collect()
}
516
/// Convert a flat, row-by-row array into at most `rows` rows of `cols` entries each.
///
/// Extra data beyond `rows` rows is discarded and a short final row is kept as-is. This
/// matches the original slicing; callers should validate `flat.len() == rows * cols`.
///
/// A `cols` of zero yields an empty matrix instead of panicking in `slice::chunks`.
fn to_matrix(flat: &[f64], rows: usize, cols: usize) -> Vec<Vec<f64>> {
    if cols == 0 {
        return Vec::new();
    }
    flat.chunks(cols).take(rows).map(|row| row.to_vec()).collect()
}
521
522// ============================================================================
523// TypeScript type definitions
524// ============================================================================
525
// Extra TypeScript declarations that wasm-bindgen appends verbatim to the generated
// `.d.ts` file. Editing this string changes the published type definitions.
//
// NOTE(review): no exported constructor in this module takes these option objects.
// `WasmSinkhorn::new` takes only (regularization, max_iterations), so
// `SinkhornOptions.threshold` has no Rust counterpart. `WasmProductManifold::new` takes only
// the three dimensions, so the curvature fields are unused here. Confirm whether a JS wrapper
// elsewhere consumes these interfaces.
#[wasm_bindgen(typescript_custom_section)]
const TS_TYPES: &'static str = r#"
/** Sliced Wasserstein distance for comparing point cloud distributions */
export interface SlicedWassersteinOptions {
    numProjections?: number;
    power?: number;
    seed?: number;
}

/** Sinkhorn optimal transport options */
export interface SinkhornOptions {
    regularization?: number;
    maxIterations?: number;
    threshold?: number;
}

/** Product manifold configuration */
export interface ProductManifoldConfig {
    euclideanDim: number;
    hyperbolicDim: number;
    sphericalDim: number;
    hyperbolicCurvature?: number;
    sphericalCurvature?: number;
}
"#;