1use wasm_bindgen::prelude::*;
8use ruvector_math::{
9 optimal_transport::{SlicedWasserstein, SinkhornSolver, GromovWasserstein},
10 information_geometry::{FisherInformation, NaturalGradient},
11 spherical::SphericalSpace,
12 product_manifold::ProductManifold,
13};
14
15#[wasm_bindgen(start)]
16pub fn start() {
17 #[cfg(feature = "console_error_panic_hook")]
18 console_error_panic_hook::set_once();
19}
20
21#[wasm_bindgen]
27pub struct WasmSlicedWasserstein {
28 inner: SlicedWasserstein,
29}
30
31#[wasm_bindgen]
32impl WasmSlicedWasserstein {
33 #[wasm_bindgen(constructor)]
37 pub fn new(num_projections: usize) -> Self {
38 Self {
39 inner: SlicedWasserstein::new(num_projections),
40 }
41 }
42
43 #[wasm_bindgen(js_name = withPower)]
45 pub fn with_power(self, p: f64) -> Self {
46 Self {
47 inner: self.inner.with_power(p),
48 }
49 }
50
51 #[wasm_bindgen(js_name = withSeed)]
53 pub fn with_seed(self, seed: u64) -> Self {
54 Self {
55 inner: self.inner.with_seed(seed),
56 }
57 }
58
59 #[wasm_bindgen]
65 pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> f64 {
66 use ruvector_math::optimal_transport::OptimalTransport;
67
68 let source_points = to_points(source, dim);
69 let target_points = to_points(target, dim);
70
71 self.inner.distance(&source_points, &target_points)
72 }
73
74 #[wasm_bindgen(js_name = weightedDistance)]
76 pub fn weighted_distance(
77 &self,
78 source: &[f64],
79 source_weights: &[f64],
80 target: &[f64],
81 target_weights: &[f64],
82 dim: usize,
83 ) -> f64 {
84 use ruvector_math::optimal_transport::OptimalTransport;
85
86 let source_points = to_points(source, dim);
87 let target_points = to_points(target, dim);
88
89 self.inner.weighted_distance(
90 &source_points,
91 source_weights,
92 &target_points,
93 target_weights,
94 )
95 }
96}
97
98#[wasm_bindgen]
100pub struct WasmSinkhorn {
101 inner: SinkhornSolver,
102}
103
104#[wasm_bindgen]
105impl WasmSinkhorn {
106 #[wasm_bindgen(constructor)]
111 pub fn new(regularization: f64, max_iterations: usize) -> Self {
112 Self {
113 inner: SinkhornSolver::new(regularization, max_iterations),
114 }
115 }
116
117 #[wasm_bindgen]
119 pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
120 let source_points = to_points(source, dim);
121 let target_points = to_points(target, dim);
122
123 self.inner
124 .distance(&source_points, &target_points)
125 .map_err(|e| JsError::new(&e.to_string()))
126 }
127
128 #[wasm_bindgen(js_name = solveTransport)]
130 pub fn solve_transport(
131 &self,
132 cost_matrix: &[f64],
133 source_weights: &[f64],
134 target_weights: &[f64],
135 n: usize,
136 m: usize,
137 ) -> Result<TransportResult, JsError> {
138 let cost = to_matrix(cost_matrix, n, m);
139
140 let result = self
141 .inner
142 .solve(&cost, source_weights, target_weights)
143 .map_err(|e| JsError::new(&e.to_string()))?;
144
145 Ok(TransportResult {
146 plan: result.plan.into_iter().flatten().collect(),
147 cost: result.cost,
148 iterations: result.iterations,
149 converged: result.converged,
150 })
151 }
152}
153
154#[wasm_bindgen]
156pub struct TransportResult {
157 plan: Vec<f64>,
158 cost: f64,
159 iterations: usize,
160 converged: bool,
161}
162
163#[wasm_bindgen]
164impl TransportResult {
165 #[wasm_bindgen(getter)]
167 pub fn plan(&self) -> Vec<f64> {
168 self.plan.clone()
169 }
170
171 #[wasm_bindgen(getter)]
173 pub fn cost(&self) -> f64 {
174 self.cost
175 }
176
177 #[wasm_bindgen(getter)]
179 pub fn iterations(&self) -> usize {
180 self.iterations
181 }
182
183 #[wasm_bindgen(getter)]
185 pub fn converged(&self) -> bool {
186 self.converged
187 }
188}
189
190#[wasm_bindgen]
192pub struct WasmGromovWasserstein {
193 inner: GromovWasserstein,
194}
195
196#[wasm_bindgen]
197impl WasmGromovWasserstein {
198 #[wasm_bindgen(constructor)]
200 pub fn new(regularization: f64) -> Self {
201 Self {
202 inner: GromovWasserstein::new(regularization),
203 }
204 }
205
206 #[wasm_bindgen]
208 pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
209 let source_points = to_points(source, dim);
210 let target_points = to_points(target, dim);
211
212 self.inner
213 .distance(&source_points, &target_points)
214 .map_err(|e| JsError::new(&e.to_string()))
215 }
216}
217
218#[wasm_bindgen]
224pub struct WasmFisherInformation {
225 inner: FisherInformation,
226}
227
228#[wasm_bindgen]
229impl WasmFisherInformation {
230 #[wasm_bindgen(constructor)]
232 pub fn new() -> Self {
233 Self {
234 inner: FisherInformation::new(),
235 }
236 }
237
238 #[wasm_bindgen(js_name = withDamping)]
240 pub fn with_damping(self, damping: f64) -> Self {
241 Self {
242 inner: self.inner.with_damping(damping),
243 }
244 }
245
246 #[wasm_bindgen(js_name = diagonalFim)]
248 pub fn diagonal_fim(&self, gradients: &[f64], _num_samples: usize, dim: usize) -> Result<Vec<f64>, JsError> {
249 let grads = to_points(gradients, dim);
250 self.inner
251 .diagonal_fim(&grads)
252 .map_err(|e| JsError::new(&e.to_string()))
253 }
254
255 #[wasm_bindgen(js_name = naturalGradient)]
257 pub fn natural_gradient(
258 &self,
259 fim_diag: &[f64],
260 gradient: &[f64],
261 damping: f64,
262 ) -> Vec<f64> {
263 gradient
264 .iter()
265 .zip(fim_diag.iter())
266 .map(|(&g, &f)| g / (f + damping))
267 .collect()
268 }
269}
270
271#[wasm_bindgen]
273pub struct WasmNaturalGradient {
274 inner: NaturalGradient,
275}
276
277#[wasm_bindgen]
278impl WasmNaturalGradient {
279 #[wasm_bindgen(constructor)]
281 pub fn new(learning_rate: f64) -> Self {
282 Self {
283 inner: NaturalGradient::new(learning_rate),
284 }
285 }
286
287 #[wasm_bindgen(js_name = withDamping)]
289 pub fn with_damping(self, damping: f64) -> Self {
290 Self {
291 inner: self.inner.with_damping(damping),
292 }
293 }
294
295 #[wasm_bindgen(js_name = withDiagonal)]
297 pub fn with_diagonal(self, use_diagonal: bool) -> Self {
298 Self {
299 inner: self.inner.with_diagonal(use_diagonal),
300 }
301 }
302
303 #[wasm_bindgen]
305 pub fn step(
306 &mut self,
307 gradient: &[f64],
308 gradient_samples: Option<Vec<f64>>,
309 dim: usize,
310 ) -> Result<Vec<f64>, JsError> {
311 let samples = gradient_samples.map(|s| to_points(&s, dim));
312
313 self.inner
314 .step(gradient, samples.as_deref())
315 .map_err(|e| JsError::new(&e.to_string()))
316 }
317
318 #[wasm_bindgen]
320 pub fn reset(&mut self) {
321 self.inner.reset();
322 }
323}
324
325#[wasm_bindgen]
331pub struct WasmSphericalSpace {
332 inner: SphericalSpace,
333}
334
335#[wasm_bindgen]
336impl WasmSphericalSpace {
337 #[wasm_bindgen(constructor)]
339 pub fn new(ambient_dim: usize) -> Self {
340 Self {
341 inner: SphericalSpace::new(ambient_dim),
342 }
343 }
344
345 #[wasm_bindgen(getter, js_name = ambientDim)]
347 pub fn ambient_dim(&self) -> usize {
348 self.inner.ambient_dim()
349 }
350
351 #[wasm_bindgen]
353 pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
354 self.inner
355 .project(point)
356 .map_err(|e| JsError::new(&e.to_string()))
357 }
358
359 #[wasm_bindgen]
361 pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
362 self.inner
363 .distance(x, y)
364 .map_err(|e| JsError::new(&e.to_string()))
365 }
366
367 #[wasm_bindgen(js_name = expMap)]
369 pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
370 self.inner
371 .exp_map(x, v)
372 .map_err(|e| JsError::new(&e.to_string()))
373 }
374
375 #[wasm_bindgen(js_name = logMap)]
377 pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
378 self.inner
379 .log_map(x, y)
380 .map_err(|e| JsError::new(&e.to_string()))
381 }
382
383 #[wasm_bindgen]
385 pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
386 self.inner
387 .geodesic(x, y, t)
388 .map_err(|e| JsError::new(&e.to_string()))
389 }
390
391 #[wasm_bindgen(js_name = frechetMean)]
393 pub fn frechet_mean(&self, points: &[f64], dim: usize) -> Result<Vec<f64>, JsError> {
394 let pts = to_points(points, dim);
395 self.inner
396 .frechet_mean(&pts, None)
397 .map_err(|e| JsError::new(&e.to_string()))
398 }
399}
400
401#[wasm_bindgen]
407pub struct WasmProductManifold {
408 inner: ProductManifold,
409}
410
411#[wasm_bindgen]
412impl WasmProductManifold {
413 #[wasm_bindgen(constructor)]
419 pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
420 Self {
421 inner: ProductManifold::new(euclidean_dim, hyperbolic_dim, spherical_dim),
422 }
423 }
424
425 #[wasm_bindgen(getter)]
427 pub fn dim(&self) -> usize {
428 self.inner.dim()
429 }
430
431 #[wasm_bindgen]
433 pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
434 self.inner
435 .project(point)
436 .map_err(|e| JsError::new(&e.to_string()))
437 }
438
439 #[wasm_bindgen]
441 pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
442 self.inner
443 .distance(x, y)
444 .map_err(|e| JsError::new(&e.to_string()))
445 }
446
447 #[wasm_bindgen(js_name = expMap)]
449 pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
450 self.inner
451 .exp_map(x, v)
452 .map_err(|e| JsError::new(&e.to_string()))
453 }
454
455 #[wasm_bindgen(js_name = logMap)]
457 pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
458 self.inner
459 .log_map(x, y)
460 .map_err(|e| JsError::new(&e.to_string()))
461 }
462
463 #[wasm_bindgen]
465 pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
466 self.inner
467 .geodesic(x, y, t)
468 .map_err(|e| JsError::new(&e.to_string()))
469 }
470
471 #[wasm_bindgen(js_name = frechetMean)]
473 pub fn frechet_mean(&self, points: &[f64], _num_points: usize) -> Result<Vec<f64>, JsError> {
474 let dim = self.inner.dim();
475 let pts = to_points(points, dim);
476 self.inner
477 .frechet_mean(&pts, None)
478 .map_err(|e| JsError::new(&e.to_string()))
479 }
480
481 #[wasm_bindgen]
483 pub fn knn(&self, query: &[f64], points: &[f64], k: usize) -> Result<Vec<u32>, JsError> {
484 let dim = self.inner.dim();
485 let pts = to_points(points, dim);
486 let neighbors = self
487 .inner
488 .knn(query, &pts, k)
489 .map_err(|e| JsError::new(&e.to_string()))?;
490
491 Ok(neighbors.into_iter().map(|(idx, _)| idx as u32).collect())
492 }
493
494 #[wasm_bindgen(js_name = pairwiseDistances)]
496 pub fn pairwise_distances(&self, points: &[f64]) -> Result<Vec<f64>, JsError> {
497 let dim = self.inner.dim();
498 let pts = to_points(points, dim);
499 let dists = self
500 .inner
501 .pairwise_distances(&pts)
502 .map_err(|e| JsError::new(&e.to_string()))?;
503
504 Ok(dists.into_iter().flatten().collect())
505 }
506}
507
/// Splits a flat, row-major buffer into points of `dim` coordinates each.
///
/// If `flat.len()` is not a multiple of `dim`, the leftover values form a
/// shorter final point (exactly as `slice::chunks` yields them).
///
/// `dim == 0` yields no points. `slice::chunks(0)` would panic, and a panic
/// inside a wasm export aborts the call instead of producing a catchable
/// error.
fn to_points(flat: &[f64], dim: usize) -> Vec<Vec<f64>> {
    if dim == 0 {
        return Vec::new();
    }
    flat.chunks(dim).map(<[f64]>::to_vec).collect()
}
516
/// Reshapes a flat, row-major buffer into at most `rows` rows of `cols` entries.
///
/// Entries beyond the first `rows` rows are ignored. If `flat` is too short,
/// the result has fewer rows and/or a shorter last row, so callers that need
/// an exact shape should check `flat.len() == rows * cols` first.
///
/// `cols == 0` returns `rows` empty rows (a `rows x 0` matrix) instead of
/// panicking inside `slice::chunks(0)`.
fn to_matrix(flat: &[f64], rows: usize, cols: usize) -> Vec<Vec<f64>> {
    if cols == 0 {
        return vec![Vec::new(); rows];
    }
    flat.chunks(cols).take(rows).map(<[f64]>::to_vec).collect()
}
521
522#[wasm_bindgen(typescript_custom_section)]
527const TS_TYPES: &'static str = r#"
528/** Sliced Wasserstein distance for comparing point cloud distributions */
529export interface SlicedWassersteinOptions {
530 numProjections?: number;
531 power?: number;
532 seed?: number;
533}
534
535/** Sinkhorn optimal transport options */
536export interface SinkhornOptions {
537 regularization?: number;
538 maxIterations?: number;
539 threshold?: number;
540}
541
542/** Product manifold configuration */
543export interface ProductManifoldConfig {
544 euclideanDim: number;
545 hyperbolicDim: number;
546 sphericalDim: number;
547 hyperbolicCurvature?: number;
548 sphericalCurvature?: number;
549}
550"#;