Skip to main content

ruvector_solver_wasm/
lib.rs

//! WASM bindings for the RuVector sublinear-time solver.
//!
//! Exposes a [`JsSolver`] struct that can be constructed from JavaScript and
//! used to solve sparse linear systems, compute Personalized PageRank, and
//! estimate solve complexity -- all within the browser or any WASM runtime.
//!
//! # Quick Start (JavaScript)
//!
//! ```js
//! import { JsSolver } from "ruvector-solver-wasm";
//!
//! const solver = new JsSolver();
//!
//! // CSR representation of the 3x3 diagonally-dominant matrix
//! // [[4, -1, 0], [-1, 4, -1], [0, -1, 4]].
//! const values   = new Float32Array([4, -1, -1, 4, -1, -1, 4]);
//! const colIdx   = new Uint32Array([0, 1, 0, 1, 2, 1, 2]);
//! const rowPtrs  = new Uint32Array([0, 2, 5, 7]);
//! const rhs      = new Float32Array([1, 0, 1]);
//!
//! const result = solver.solve(values, colIdx, rowPtrs, 3, 3, rhs);
//! // { solution, iterations, residual, converged, algorithm, time_us }
//! console.log(result);
//! ```
23
24mod utils;
25
26use ruvector_solver::types::{
27    Algorithm, ComplexityClass, ComplexityEstimate, CsrMatrix, SparsityProfile,
28};
29use serde::Serialize;
30use wasm_bindgen::prelude::*;
31
32use crate::utils::{console_log, csr_from_js_arrays, set_panic_hook};
33
34// ---------------------------------------------------------------------------
35// Module initialisation
36// ---------------------------------------------------------------------------
37
38/// Called automatically when the WASM module is loaded.
39#[wasm_bindgen(start)]
40pub fn init() {
41    set_panic_hook();
42    console_log("ruvector-solver-wasm module loaded");
43}
44
45/// Return the crate version.
46#[wasm_bindgen]
47pub fn version() -> String {
48    env!("CARGO_PKG_VERSION").to_string()
49}
50
51// ---------------------------------------------------------------------------
52// JsSolver
53// ---------------------------------------------------------------------------
54
55/// Top-level solver handle exposed to JavaScript.
56///
57/// Wraps the algorithm router and iterative solvers, providing a high-level
58/// API that accepts CSR arrays directly from JS typed arrays.
59#[wasm_bindgen]
60pub struct JsSolver {
61    /// Default maximum iterations.
62    max_iterations: usize,
63    /// Default convergence tolerance.
64    tolerance: f64,
65    /// Default teleportation probability for PageRank.
66    alpha: f64,
67}
68
69#[wasm_bindgen]
70impl JsSolver {
71    /// Construct a new solver with default parameters.
72    ///
73    /// - `max_iterations`: 1000
74    /// - `tolerance`: 1e-6
75    /// - `alpha` (PageRank teleport): 0.15
76    #[wasm_bindgen(constructor)]
77    pub fn new() -> Self {
78        Self {
79            max_iterations: 1000,
80            tolerance: 1e-6,
81            alpha: 0.15,
82        }
83    }
84
85    /// Set the maximum number of iterations for iterative solvers.
86    #[wasm_bindgen(js_name = "setMaxIterations")]
87    pub fn set_max_iterations(&mut self, max_iterations: usize) {
88        self.max_iterations = max_iterations;
89    }
90
91    /// Set the convergence tolerance.
92    #[wasm_bindgen(js_name = "setTolerance")]
93    pub fn set_tolerance(&mut self, tolerance: f64) {
94        self.tolerance = tolerance;
95    }
96
97    /// Set the teleportation probability for PageRank.
98    #[wasm_bindgen(js_name = "setAlpha")]
99    pub fn set_alpha(&mut self, alpha: f64) {
100        self.alpha = alpha;
101    }
102
103    // -----------------------------------------------------------------------
104    // Solve Ax = b
105    // -----------------------------------------------------------------------
106
107    /// Solve a sparse linear system `Ax = b`.
108    ///
109    /// The matrix `A` is provided in CSR format via three flat arrays.
110    /// Returns a JSON-serialisable result object on success.
111    ///
112    /// # Arguments
113    ///
114    /// * `values`      - Non-zero values (`Float32Array`).
115    /// * `col_indices` - Column indices for each non-zero (`Uint32Array`).
116    /// * `row_ptrs`    - Row pointers of length `rows + 1` (`Uint32Array`).
117    /// * `rows`        - Number of rows.
118    /// * `cols`        - Number of columns.
119    /// * `rhs`         - Right-hand side vector `b` (`Float32Array`).
120    ///
121    /// # Errors
122    ///
123    /// Returns `JsError` on invalid input or non-convergence.
124    pub fn solve(
125        &self,
126        values: &[f32],
127        col_indices: &[u32],
128        row_ptrs: &[u32],
129        rows: usize,
130        cols: usize,
131        rhs: &[f32],
132    ) -> Result<JsValue, JsError> {
133        let csr = csr_from_js_arrays(values, col_indices, row_ptrs, rows, cols)
134            .map_err(|e| JsError::new(&e))?;
135
136        if rows != cols {
137            return Err(JsError::new(
138                "solve requires a square matrix (rows must equal cols)",
139            ));
140        }
141        if rhs.len() != rows {
142            return Err(JsError::new(&format!(
143                "rhs length {} does not match matrix rows {}",
144                rhs.len(),
145                rows,
146            )));
147        }
148
149        // Analyse sparsity to choose the algorithm.
150        let profile = analyze_sparsity(&csr);
151        let algorithm = select_algorithm(&profile);
152
153        // Perform the solve.
154        let start = js_sys::Date::now();
155        let result = match algorithm {
156            Algorithm::Neumann => neumann_solve(&csr, rhs, self.tolerance, self.max_iterations),
157            Algorithm::CG => cg_solve(&csr, rhs, self.tolerance, self.max_iterations),
158            _ => {
159                // Fallback: try Neumann first, then CG.
160                let nr = neumann_solve(&csr, rhs, self.tolerance, self.max_iterations);
161                if nr.converged {
162                    nr
163                } else {
164                    cg_solve(&csr, rhs, self.tolerance, self.max_iterations)
165                }
166            }
167        };
168        let elapsed_us = ((js_sys::Date::now() - start) * 1000.0) as u64;
169
170        let js_result = JsSolverResult {
171            solution: result.solution,
172            iterations: result.iterations,
173            residual: result.residual,
174            converged: result.converged,
175            algorithm: result.algorithm.to_string(),
176            time_us: elapsed_us,
177        };
178
179        serde_wasm_bindgen::to_value(&js_result)
180            .map_err(|e| JsError::new(&format!("serialisation error: {}", e)))
181    }
182
183    // -----------------------------------------------------------------------
184    // Personalized PageRank
185    // -----------------------------------------------------------------------
186
187    /// Compute Personalized PageRank from a single source node.
188    ///
189    /// Uses the power-iteration method with teleportation probability `alpha`
190    /// (configurable via [`set_alpha`](JsSolver::set_alpha)).
191    ///
192    /// # Arguments
193    ///
194    /// * `values`      - Edge weights (`Float32Array`).
195    /// * `col_indices` - Column indices (`Uint32Array`).
196    /// * `row_ptrs`    - Row pointers (`Uint32Array`).
197    /// * `rows`        - Number of nodes.
198    /// * `source`      - Source node index.
199    /// * `tolerance`   - Convergence tolerance (L1 residual).
200    ///
201    /// # Errors
202    ///
203    /// Returns `JsError` on invalid input.
204    pub fn pagerank(
205        &self,
206        values: &[f32],
207        col_indices: &[u32],
208        row_ptrs: &[u32],
209        rows: usize,
210        source: usize,
211        tolerance: f64,
212    ) -> Result<JsValue, JsError> {
213        let csr = csr_from_js_arrays(values, col_indices, row_ptrs, rows, rows)
214            .map_err(|e| JsError::new(&e))?;
215
216        if source >= rows {
217            return Err(JsError::new(&format!(
218                "source node {} out of bounds (graph has {} nodes)",
219                source, rows,
220            )));
221        }
222
223        let tol = if tolerance > 0.0 {
224            tolerance
225        } else {
226            self.tolerance
227        };
228
229        let start = js_sys::Date::now();
230        let result = power_iteration_ppr(&csr, source, self.alpha, tol, self.max_iterations);
231        let elapsed_us = ((js_sys::Date::now() - start) * 1000.0) as u64;
232
233        let js_result = JsPageRankResult {
234            scores: result.scores,
235            iterations: result.iterations,
236            residual: result.residual,
237            converged: result.converged,
238            time_us: elapsed_us,
239        };
240
241        serde_wasm_bindgen::to_value(&js_result)
242            .map_err(|e| JsError::new(&format!("serialisation error: {}", e)))
243    }
244
245    // -----------------------------------------------------------------------
246    // Complexity estimation
247    // -----------------------------------------------------------------------
248
249    /// Estimate the computational complexity of solving a system with the
250    /// given matrix without performing the actual solve.
251    ///
252    /// Returns a JSON object with the selected algorithm, estimated FLOPS,
253    /// estimated iterations, memory usage, and complexity class.
254    #[wasm_bindgen(js_name = "estimateComplexity")]
255    pub fn estimate_complexity(
256        &self,
257        values: &[f32],
258        col_indices: &[u32],
259        row_ptrs: &[u32],
260        rows: usize,
261        cols: usize,
262    ) -> Result<JsValue, JsError> {
263        let csr = csr_from_js_arrays(values, col_indices, row_ptrs, rows, cols)
264            .map_err(|e| JsError::new(&e))?;
265
266        let profile = analyze_sparsity(&csr);
267        let algorithm = select_algorithm(&profile);
268        let estimate = build_complexity_estimate(&profile, algorithm);
269
270        let js_est = JsComplexityEstimate {
271            algorithm: algorithm.to_string(),
272            estimated_flops: estimate.estimated_flops,
273            estimated_iterations: estimate.estimated_iterations,
274            estimated_memory_bytes: estimate.estimated_memory_bytes,
275            complexity_class: format!("{:?}", estimate.complexity_class),
276            density: profile.density,
277            is_diag_dominant: profile.is_diag_dominant,
278            estimated_spectral_radius: profile.estimated_spectral_radius,
279        };
280
281        serde_wasm_bindgen::to_value(&js_est)
282            .map_err(|e| JsError::new(&format!("serialisation error: {}", e)))
283    }
284}
285
286// ---------------------------------------------------------------------------
287// JS-facing result types (serde-serialisable)
288// ---------------------------------------------------------------------------
289
/// JSON-serialisable solve result returned to JavaScript.
///
/// No serde renames are applied, so the JS object keys are these field names
/// verbatim (snake_case, e.g. `time_us`). Field order determines key order.
#[derive(Serialize)]
struct JsSolverResult {
    /// Approximate solution vector `x`.
    solution: Vec<f32>,
    /// Iterations performed by the algorithm that produced `solution`.
    iterations: usize,
    /// Final residual norm reported by that algorithm.
    residual: f64,
    /// Whether the residual fell below the solver tolerance.
    converged: bool,
    /// Name of the algorithm that produced `solution` (its `Display` form).
    algorithm: String,
    /// Wall-clock solve time in microseconds.
    time_us: u64,
}

/// JSON-serialisable PageRank result.
///
/// Keys are the field names verbatim, in declaration order.
#[derive(Serialize)]
struct JsPageRankResult {
    /// One score per node, indexed by node id.
    scores: Vec<f32>,
    /// Power-iteration steps performed.
    iterations: usize,
    /// L1 change between the last two iterates.
    residual: f64,
    /// Whether `residual` fell below the tolerance.
    converged: bool,
    /// Wall-clock time in microseconds.
    time_us: u64,
}

/// JSON-serialisable complexity estimate.
///
/// Keys are the field names verbatim, in declaration order.
#[derive(Serialize)]
struct JsComplexityEstimate {
    /// Algorithm the router would select (its `Display` form).
    algorithm: String,
    /// Heuristic floating-point operation count.
    estimated_flops: u64,
    /// Heuristic iteration count.
    estimated_iterations: usize,
    /// Heuristic working-set size in bytes.
    estimated_memory_bytes: usize,
    /// Complexity class (its `Debug` form).
    complexity_class: String,
    /// `nnz / (rows * cols)`.
    density: f64,
    /// Whether every row was found strictly diagonally dominant.
    is_diag_dominant: bool,
    /// Heuristic spectral-radius estimate used for algorithm selection.
    estimated_spectral_radius: f64,
}
323
324// ---------------------------------------------------------------------------
325// Internal solver result (before JS conversion)
326// ---------------------------------------------------------------------------
327
328struct InternalSolveResult {
329    solution: Vec<f32>,
330    iterations: usize,
331    residual: f64,
332    converged: bool,
333    algorithm: Algorithm,
334}
335
336struct InternalPprResult {
337    scores: Vec<f32>,
338    iterations: usize,
339    residual: f64,
340    converged: bool,
341}
342
343// ---------------------------------------------------------------------------
344// Sparsity analysis
345// ---------------------------------------------------------------------------
346
347/// Analyse the sparsity structure of a CSR matrix to inform algorithm
348/// selection.
349fn analyze_sparsity(csr: &CsrMatrix<f32>) -> SparsityProfile {
350    let nnz = csr.values.len();
351    let n = csr.rows;
352    let total_elements = if n > 0 && csr.cols > 0 {
353        n * csr.cols
354    } else {
355        1
356    };
357    let density = nnz as f64 / total_elements as f64;
358
359    let mut is_diag_dominant = true;
360    let mut max_nnz_per_row: usize = 0;
361    let mut est_spectral_sum = 0.0f64;
362    let mut symmetric_check = true;
363
364    for row in 0..n {
365        let start = csr.row_ptr[row];
366        let end = csr.row_ptr[row + 1];
367        let row_nnz = end - start;
368        if row_nnz > max_nnz_per_row {
369            max_nnz_per_row = row_nnz;
370        }
371
372        let mut diag_val = 0.0f64;
373        let mut off_diag_sum = 0.0f64;
374
375        for idx in start..end {
376            let col = csr.col_indices[idx];
377            let val = csr.values[idx] as f64;
378            if col == row {
379                diag_val = val.abs();
380            } else {
381                off_diag_sum += val.abs();
382            }
383        }
384
385        if diag_val <= off_diag_sum && diag_val > 0.0 {
386            is_diag_dominant = false;
387        }
388        if diag_val > 0.0 {
389            est_spectral_sum += off_diag_sum / diag_val;
390        } else if off_diag_sum > 0.0 {
391            is_diag_dominant = false;
392            est_spectral_sum += 1.0; // pessimistic
393        }
394    }
395
396    let estimated_spectral_radius = if n > 0 {
397        est_spectral_sum / n as f64
398    } else {
399        0.0
400    };
401
402    // Quick structural symmetry check (sample-based for large matrices).
403    let check_limit = n.min(64);
404    'outer: for row in 0..check_limit {
405        let start = csr.row_ptr[row];
406        let end = csr.row_ptr[row + 1];
407        for idx in start..end {
408            let col = csr.col_indices[idx];
409            if col >= n || col == row {
410                continue;
411            }
412            // Check if (col, row) entry exists.
413            let col_start = csr.row_ptr[col];
414            let col_end = csr.row_ptr[col + 1];
415            let found = csr.col_indices[col_start..col_end]
416                .iter()
417                .any(|&c| c == row);
418            if !found {
419                symmetric_check = false;
420                break 'outer;
421            }
422        }
423    }
424
425    let avg_nnz = if n > 0 { nnz as f64 / n as f64 } else { 0.0 };
426
427    // Rough condition estimate from spectral radius.
428    let estimated_condition = if estimated_spectral_radius < 1.0 {
429        1.0 / (1.0 - estimated_spectral_radius)
430    } else {
431        estimated_spectral_radius * 100.0 // pessimistic
432    };
433
434    SparsityProfile {
435        rows: n,
436        cols: csr.cols,
437        nnz,
438        density,
439        is_diag_dominant,
440        estimated_spectral_radius,
441        estimated_condition,
442        is_symmetric_structure: symmetric_check,
443        avg_nnz_per_row: avg_nnz,
444        max_nnz_per_row,
445    }
446}
447
448// ---------------------------------------------------------------------------
449// Algorithm selection
450// ---------------------------------------------------------------------------
451
452/// Select the best algorithm given a sparsity profile.
453fn select_algorithm(profile: &SparsityProfile) -> Algorithm {
454    // Neumann series requires spectral radius < 1.
455    if profile.is_diag_dominant && profile.estimated_spectral_radius < 0.95 {
456        return Algorithm::Neumann;
457    }
458
459    // CG is good for symmetric positive-definite systems.
460    if profile.is_symmetric_structure && profile.is_diag_dominant {
461        return Algorithm::CG;
462    }
463
464    // Default: CG for general sparse systems.
465    Algorithm::CG
466}
467
468// ---------------------------------------------------------------------------
469// Neumann series solver
470// ---------------------------------------------------------------------------
471
472/// Neumann series solver for diagonally dominant systems.
473///
474/// Computes `x = sum_{k=0}^{K} (I - D^{-1} A)^k D^{-1} b` where `D` is the
475/// diagonal of `A`. This converges when the spectral radius of `D^{-1}(A - D)`
476/// is less than 1.
477fn neumann_solve(
478    csr: &CsrMatrix<f32>,
479    rhs: &[f32],
480    tolerance: f64,
481    max_iterations: usize,
482) -> InternalSolveResult {
483    let n = csr.rows;
484
485    // Extract diagonal and compute D^{-1} b.
486    let mut diag_inv = vec![0.0f32; n];
487    for row in 0..n {
488        let start = csr.row_ptr[row];
489        let end = csr.row_ptr[row + 1];
490        for idx in start..end {
491            if csr.col_indices[idx] == row {
492                let d = csr.values[idx];
493                diag_inv[row] = if d.abs() > 1e-30 { 1.0 / d } else { 0.0 };
494                break;
495            }
496        }
497    }
498
499    // x = D^{-1} b  (initial approximation: zeroth-order term).
500    let mut x: Vec<f32> = rhs
501        .iter()
502        .zip(diag_inv.iter())
503        .map(|(&b, &di)| b * di)
504        .collect();
505
506    // Iterate: x_{k+1} = x_k + D^{-1} r_k   where r_k = b - A x_k.
507    let mut residual_buf = vec![0.0f32; n];
508    let mut converged = false;
509    let mut iterations = 0;
510    let mut final_residual = f64::MAX;
511
512    for k in 0..max_iterations {
513        // Compute r = b - A x.
514        spmv(csr, &x, &mut residual_buf);
515        for i in 0..n {
516            residual_buf[i] = rhs[i] - residual_buf[i];
517        }
518
519        // Residual norm.
520        let res_norm: f64 = residual_buf
521            .iter()
522            .map(|&r| (r as f64) * (r as f64))
523            .sum::<f64>()
524            .sqrt();
525
526        final_residual = res_norm;
527        iterations = k + 1;
528
529        if res_norm < tolerance {
530            converged = true;
531            break;
532        }
533
534        // Check for divergence.
535        if !res_norm.is_finite() {
536            break;
537        }
538
539        // Update: x += D^{-1} r.
540        for i in 0..n {
541            x[i] += diag_inv[i] * residual_buf[i];
542        }
543    }
544
545    InternalSolveResult {
546        solution: x,
547        iterations,
548        residual: final_residual,
549        converged,
550        algorithm: Algorithm::Neumann,
551    }
552}
553
554// ---------------------------------------------------------------------------
555// Conjugate Gradient solver
556// ---------------------------------------------------------------------------
557
558/// Conjugate Gradient solver for symmetric positive-definite systems.
559///
560/// Standard CG with residual-based convergence detection.
561fn cg_solve(
562    csr: &CsrMatrix<f32>,
563    rhs: &[f32],
564    tolerance: f64,
565    max_iterations: usize,
566) -> InternalSolveResult {
567    let n = csr.rows;
568
569    // x_0 = 0, r_0 = b, p_0 = r_0.
570    let mut x = vec![0.0f32; n];
571    let mut r: Vec<f32> = rhs.to_vec();
572    let mut p: Vec<f32> = rhs.to_vec();
573    let mut ap = vec![0.0f32; n];
574
575    let mut rr: f64 = r.iter().map(|&v| (v as f64) * (v as f64)).sum();
576    let mut converged = false;
577    let mut iterations = 0;
578    let mut final_residual = rr.sqrt();
579
580    if final_residual < tolerance {
581        return InternalSolveResult {
582            solution: x,
583            iterations: 0,
584            residual: final_residual,
585            converged: true,
586            algorithm: Algorithm::CG,
587        };
588    }
589
590    for k in 0..max_iterations {
591        // ap = A * p.
592        spmv(csr, &p, &mut ap);
593
594        // alpha = r^T r / (p^T A p).
595        let pap: f64 = p
596            .iter()
597            .zip(ap.iter())
598            .map(|(&pi, &ai)| (pi as f64) * (ai as f64))
599            .sum();
600
601        if pap.abs() < 1e-30 {
602            // Breakdown: p is in the null space.
603            iterations = k + 1;
604            break;
605        }
606
607        let alpha = rr / pap;
608        let alpha_f32 = alpha as f32;
609
610        // x += alpha * p.
611        for i in 0..n {
612            x[i] += alpha_f32 * p[i];
613        }
614
615        // r -= alpha * A p.
616        for i in 0..n {
617            r[i] -= alpha_f32 * ap[i];
618        }
619
620        let rr_new: f64 = r.iter().map(|&v| (v as f64) * (v as f64)).sum();
621        final_residual = rr_new.sqrt();
622        iterations = k + 1;
623
624        if final_residual < tolerance {
625            converged = true;
626            break;
627        }
628
629        if !rr_new.is_finite() {
630            break;
631        }
632
633        // beta = r_{k+1}^T r_{k+1} / r_k^T r_k.
634        let beta = rr_new / rr;
635        let beta_f32 = beta as f32;
636
637        // p = r + beta * p.
638        for i in 0..n {
639            p[i] = r[i] + beta_f32 * p[i];
640        }
641
642        rr = rr_new;
643    }
644
645    InternalSolveResult {
646        solution: x,
647        iterations,
648        residual: final_residual,
649        converged,
650        algorithm: Algorithm::CG,
651    }
652}
653
654// ---------------------------------------------------------------------------
655// Power-iteration PPR
656// ---------------------------------------------------------------------------
657
658/// Power iteration for Personalized PageRank.
659///
660/// Computes `pi = alpha * s + (1 - alpha) * M^T pi` where `s` is the source
661/// distribution and `M` is the row-normalised transition matrix.
662fn power_iteration_ppr(
663    csr: &CsrMatrix<f32>,
664    source: usize,
665    alpha: f64,
666    tolerance: f64,
667    max_iterations: usize,
668) -> InternalPprResult {
669    let n = csr.rows;
670    let alpha_f32 = alpha as f32;
671    let one_minus_alpha = (1.0 - alpha) as f32;
672
673    // Compute row sums (out-degree) for normalisation.
674    let mut row_sums = vec![0.0f32; n];
675    for row in 0..n {
676        let start = csr.row_ptr[row];
677        let end = csr.row_ptr[row + 1];
678        let sum: f32 = csr.values[start..end].iter().sum();
679        // Dangling nodes get uniform teleport.
680        row_sums[row] = if sum > 0.0 { sum } else { 1.0 };
681    }
682
683    // pi starts as the source distribution.
684    let mut pi = vec![0.0f32; n];
685    pi[source] = 1.0;
686
687    let mut new_pi = vec![0.0f32; n];
688    let mut converged = false;
689    let mut iterations = 0;
690    let mut final_residual = f64::MAX;
691
692    for k in 0..max_iterations {
693        // new_pi = alpha * e_source + (1-alpha) * M^T * pi
694        // where M[i][j] = A[i][j] / row_sum[i].
695        new_pi.fill(0.0);
696
697        // Scatter: for each row i, distribute pi[i] to neighbours.
698        for row in 0..n {
699            if pi[row] == 0.0 {
700                continue;
701            }
702            let start = csr.row_ptr[row];
703            let end = csr.row_ptr[row + 1];
704            let inv_deg = pi[row] / row_sums[row];
705
706            for idx in start..end {
707                let col = csr.col_indices[idx];
708                new_pi[col] += one_minus_alpha * csr.values[idx] * inv_deg;
709            }
710        }
711
712        // Teleportation.
713        new_pi[source] += alpha_f32;
714
715        // L1 residual.
716        let l1_diff: f64 = pi
717            .iter()
718            .zip(new_pi.iter())
719            .map(|(&a, &b)| ((a - b) as f64).abs())
720            .sum();
721
722        std::mem::swap(&mut pi, &mut new_pi);
723        final_residual = l1_diff;
724        iterations = k + 1;
725
726        if l1_diff < tolerance {
727            converged = true;
728            break;
729        }
730
731        if !l1_diff.is_finite() {
732            break;
733        }
734    }
735
736    InternalPprResult {
737        scores: pi,
738        iterations,
739        residual: final_residual,
740        converged,
741    }
742}
743
744// ---------------------------------------------------------------------------
745// Complexity estimation
746// ---------------------------------------------------------------------------
747
748/// Build a [`ComplexityEstimate`] based on the sparsity profile and selected
749/// algorithm.
750fn build_complexity_estimate(
751    profile: &SparsityProfile,
752    algorithm: Algorithm,
753) -> ComplexityEstimate {
754    let n = profile.rows;
755    let nnz = profile.nnz;
756
757    match algorithm {
758        Algorithm::Neumann => {
759            // O(nnz * log(1/eps)) iterations; each iteration is O(nnz).
760            let est_iters = if profile.estimated_spectral_radius < 1.0 {
761                ((1.0 / (1.0 - profile.estimated_spectral_radius)).ln() * 10.0).ceil() as usize
762            } else {
763                1000
764            };
765            let est_flops = (nnz as u64) * (est_iters as u64) * 2;
766
767            ComplexityEstimate {
768                algorithm,
769                estimated_flops: est_flops,
770                estimated_iterations: est_iters,
771                estimated_memory_bytes: n * 4 * 3, // x, r, diag_inv
772                complexity_class: ComplexityClass::SublinearNnz,
773            }
774        }
775        Algorithm::CG => {
776            // CG converges in O(sqrt(kappa)) iterations.
777            let kappa = profile.estimated_condition.max(1.0);
778            let est_iters = (kappa.sqrt() * 2.0).ceil().min(n as f64) as usize;
779            let est_flops = (nnz as u64) * (est_iters as u64) * 2;
780
781            ComplexityEstimate {
782                algorithm,
783                estimated_flops: est_flops,
784                estimated_iterations: est_iters,
785                estimated_memory_bytes: n * 4 * 4, // x, r, p, Ap
786                complexity_class: ComplexityClass::SqrtCondition,
787            }
788        }
789        Algorithm::ForwardPush | Algorithm::BackwardPush => {
790            // O(1/epsilon) work, sublinear in graph size.
791            let est_iters = 1000;
792            ComplexityEstimate {
793                algorithm,
794                estimated_flops: est_iters as u64 * profile.avg_nnz_per_row.ceil() as u64,
795                estimated_iterations: est_iters,
796                estimated_memory_bytes: n * 8 * 2, // estimate + residual
797                complexity_class: ComplexityClass::SublinearNnz,
798            }
799        }
800        _ => {
801            // Conservative fallback.
802            ComplexityEstimate {
803                algorithm,
804                estimated_flops: (nnz as u64) * (n as u64),
805                estimated_iterations: n,
806                estimated_memory_bytes: n * n * 4,
807                complexity_class: ComplexityClass::Quadratic,
808            }
809        }
810    }
811}
812
813// ---------------------------------------------------------------------------
814// Low-level utilities
815// ---------------------------------------------------------------------------
816
817/// Sparse matrix-vector product `y = A * x` using the types::CsrMatrix layout.
818#[inline]
819fn spmv(csr: &CsrMatrix<f32>, x: &[f32], y: &mut [f32]) {
820    y.iter_mut().for_each(|v| *v = 0.0);
821    for row in 0..csr.rows {
822        let start = csr.row_ptr[row];
823        let end = csr.row_ptr[row + 1];
824        let mut sum = 0.0f32;
825        for idx in start..end {
826            sum += csr.values[idx] * x[csr.col_indices[idx]];
827        }
828        y[row] = sum;
829    }
830}
831
832// ---------------------------------------------------------------------------
833// Tests
834// ---------------------------------------------------------------------------
835
#[cfg(test)]
mod tests {
    use super::*;

    /// CSR arrays (as JS would pass them) for the 3x3 diagonally dominant matrix
    ///  [[ 4, -1,  0],
    ///   [-1,  4, -1],
    ///   [ 0, -1,  4]]
    fn test_matrix() -> (Vec<f32>, Vec<u32>, Vec<u32>, usize, usize) {
        (
            vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0],
            vec![0, 1, 0, 1, 2, 1, 2],
            vec![0, 2, 5, 7],
            3,
            3,
        )
    }

    /// `test_matrix()` converted through the JS-facing CSR constructor.
    fn test_csr() -> CsrMatrix<f32> {
        let (vals, cols, ptrs, rows, c) = test_matrix();
        csr_from_js_arrays(&vals, &cols, &ptrs, rows, c).unwrap()
    }

    /// Asserts `A * x` matches `rhs` component-wise to within 1e-4.
    fn assert_solves(csr: &CsrMatrix<f32>, x: &[f32], rhs: &[f32]) {
        let mut ax = vec![0.0f32; rhs.len()];
        spmv(csr, x, &mut ax);
        for (i, (&got, &want)) in ax.iter().zip(rhs).enumerate() {
            assert!((got - want).abs() < 1e-4, "Ax[{}] = {} != {}", i, got, want);
        }
    }

    #[test]
    fn test_analyze_sparsity() {
        let profile = analyze_sparsity(&test_csr());
        assert_eq!((profile.rows, profile.cols, profile.nnz), (3, 3, 7));
        assert!(profile.is_diag_dominant);
        assert!(profile.estimated_spectral_radius < 1.0);
    }

    #[test]
    fn test_select_algorithm_neumann_for_diag_dominant() {
        let chosen = select_algorithm(&analyze_sparsity(&test_csr()));
        assert_eq!(chosen, Algorithm::Neumann);
    }

    #[test]
    fn test_neumann_solve_identity() {
        // For the identity, the solution must equal the right-hand side.
        let csr =
            csr_from_js_arrays(&[1.0f32, 1.0, 1.0], &[0u32, 1, 2], &[0u32, 1, 2, 3], 3, 3)
                .unwrap();
        let rhs = [1.0f32, 2.0, 3.0];

        let result = neumann_solve(&csr, &rhs, 1e-6, 100);
        assert!(result.converged);
        for (i, (&got, &want)) in result.solution.iter().zip(&rhs).enumerate() {
            assert!((got - want).abs() < 1e-4, "solution[{}] = {} != {}", i, got, want);
        }
    }

    #[test]
    fn test_neumann_solve_tridiagonal() {
        let csr = test_csr();
        let rhs = [1.0f32, 0.0, 1.0];

        let result = neumann_solve(&csr, &rhs, 1e-6, 1000);
        assert!(result.converged, "residual = {}", result.residual);
        assert!(result.iterations < 100);
        assert_solves(&csr, &result.solution, &rhs);
    }

    #[test]
    fn test_cg_solve_tridiagonal() {
        let csr = test_csr();
        let rhs = [1.0f32, 0.0, 1.0];

        let result = cg_solve(&csr, &rhs, 1e-6, 1000);
        assert!(result.converged, "residual = {}", result.residual);
        assert_solves(&csr, &result.solution, &rhs);
    }

    #[test]
    fn test_power_iteration_ppr_convergence() {
        // Directed 3-cycle 0 -> 1 -> 2 -> 0 with unit weights.
        let csr =
            csr_from_js_arrays(&[1.0f32, 1.0, 1.0], &[1u32, 2, 0], &[0u32, 1, 2, 3], 3, 3)
                .unwrap();

        let result = power_iteration_ppr(&csr, 0, 0.15, 1e-6, 1000);
        assert!(result.converged, "residual = {}", result.residual);

        // The source keeps the largest share of the mass...
        let source_score = result.scores[0];
        assert!(result.scores[1..].iter().all(|&s| source_score > s));

        // ...and the scores form (approximately) a probability distribution.
        let total: f32 = result.scores.iter().sum();
        assert!((total - 1.0).abs() < 0.1, "sum = {}", total);
    }

    #[test]
    fn test_complexity_estimate() {
        let profile = analyze_sparsity(&test_csr());
        let est = build_complexity_estimate(&profile, Algorithm::Neumann);

        assert_eq!(est.algorithm, Algorithm::Neumann);
        assert!(est.estimated_flops > 0);
        assert!(est.estimated_iterations > 0);
        assert!(est.estimated_memory_bytes > 0);
        assert_eq!(est.complexity_class, ComplexityClass::SublinearNnz);
    }

    #[test]
    fn test_spmv_basic() {
        // [[2, 1], [0, 3]] * [1, 2] = [4, 6]
        let csr = CsrMatrix {
            row_ptr: vec![0, 2, 3],
            col_indices: vec![0, 1, 1],
            values: vec![2.0f32, 1.0, 3.0],
            rows: 2,
            cols: 2,
        };
        let mut y = [0.0f32; 2];
        spmv(&csr, &[1.0f32, 2.0], &mut y);
        assert_eq!(y, [4.0, 6.0]);
    }

    #[test]
    fn test_version_not_empty() {
        assert_ne!(version(), "");
    }
}