Skip to main content

scirs2_optimize/sketched/
mod.rs

1//! Sketched Gradient Descent for Large-Scale Least-Squares
2//!
3//! Randomized sketching compresses the m×n system (m rows, n columns) into a
4//! sketch of dimension (sketch_dim × n), dramatically reducing memory and compute
5//! requirements when m is very large.
6//!
7//! ## Algorithm
8//!
9//! At each iteration t:
10//! 1. Draw sketch matrix S_t of shape (sketch_dim × m).
11//! 2. Form sketched system: Ã = S A, b̃ = S b.
12//! 3. Compute sketched gradient: g = Ã^T (Ã x - b̃).
13//! 4. Update: x ← x - α g.
14//!
15//! ## Sketch Types
16//!
17//! - **Gaussian**: Each entry of S is drawn i.i.d. N(0, 1/sketch_dim).
//! - **Hadamard** (SRHT): Subsampled randomized Hadamard transform — random sign flips,
//!   then an orthonormal Walsh-Hadamard transform, then uniform row sampling (the input
//!   dimension is zero-padded to the next power of two).
//! - **Uniform** (Rademacher): Entries are independently ±1/√sketch_dim, each sign with probability ½.
21//! - **CountSketch**: Each column has exactly one non-zero entry ±1. Very sparse and fast.
22//!
23//! ## References
24//!
25//! - Mahoney, M.W. (2011). "Randomized Algorithms for Matrices and Data"
26//! - Woodruff, D.P. (2014). "Sketching as a Tool for Numerical Linear Algebra"
27//! - Drineas, P. et al. (2011). "Faster Least Squares Approximation"
28
29use crate::error::{OptimizeError, OptimizeResult};
30use scirs2_core::ndarray::Array2;
31use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
32
/// Type of sketch matrix to use for dimensionality reduction.
///
/// [`SketchType::Gaussian`] is the default.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Default)]
pub enum SketchType {
    /// Dense Gaussian sketch: S_{ij} ~ N(0, 1/sketch_dim).
    #[default]
    Gaussian,
    /// Subsampled Randomized Hadamard Transform (SRHT).
    Hadamard,
    /// Rademacher sketch: entries are ±1/√sketch_dim with equal probability.
    Uniform,
    /// Count sketch: each column of S has exactly one non-zero ±1 entry.
    CountSketch,
}
52
53/// Configuration for sketched least-squares solver.
54#[derive(Clone, Debug)]
55pub struct SketchedLeastSquaresConfig {
56    /// Sketch dimension m_s (number of rows in the sketch matrix, m_s << m).
57    pub sketch_dim: usize,
58    /// Type of random sketch.
59    pub sketch_type: SketchType,
60    /// Maximum number of iterations.
61    pub max_iter: usize,
62    /// Convergence tolerance on the relative change in x: ||x_{t+1} - x_t|| / (1 + ||x_t||).
63    pub tol: f64,
64    /// Random seed for reproducibility.
65    pub seed: u64,
66    /// Whether to refresh the sketch at every iteration (recommended for accuracy).
67    pub refresh_sketch: bool,
68    /// Step size (learning rate) for gradient updates.
69    /// When `None`, uses the sketch-based Lipschitz estimate.
70    pub step_size: Option<f64>,
71}
72
73impl Default for SketchedLeastSquaresConfig {
74    fn default() -> Self {
75        Self {
76            sketch_dim: 512,
77            sketch_type: SketchType::Gaussian,
78            max_iter: 100,
79            tol: 1e-6,
80            seed: 42,
81            refresh_sketch: true,
82            step_size: None,
83        }
84    }
85}
86
/// Outcome of [`sketched_least_squares`].
#[derive(Clone, Debug)]
pub struct LsqResult {
    /// Approximate minimizer of `‖Ax − b‖²`.
    pub x: Vec<f64>,
    /// Residual norm `‖Ax − b‖` measured on the full (unsketched) system.
    pub residual_norm: f64,
    /// Number of gradient iterations that were executed.
    pub n_iter: usize,
    /// `true` when a stopping criterion was met.
    pub converged: bool,
}
99
100// ─── Sketch construction helpers ─────────────────────────────────────────────
101
102/// Build a Gaussian sketch matrix S of shape (sketch_dim × m).
103///
104/// Entries S_{ij} ~ N(0, 1/sketch_dim) using Box-Muller transform.
105fn build_gaussian_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
106    let scale = (1.0 / sketch_dim as f64).sqrt();
107    let mut s = Vec::with_capacity(sketch_dim * m);
108    // Box-Muller: generate pairs of standard normals
109    let mut spare: Option<f64> = None;
110    for _ in 0..(sketch_dim * m) {
111        let v = match spare.take() {
112            Some(z) => z,
113            None => {
114                // Box-Muller
115                loop {
116                    let u: f64 = rng.random::<f64>();
117                    let v: f64 = rng.random::<f64>();
118                    if u > 0.0 {
119                        let mag = (-2.0 * u.ln()).sqrt();
120                        let angle = std::f64::consts::TAU * v;
121                        spare = Some(mag * angle.sin());
122                        break mag * angle.cos();
123                    }
124                }
125            }
126        };
127        s.push(v * scale);
128    }
129    s
130}
131
132/// Build a Rademacher sketch matrix S of shape (sketch_dim × m).
133///
134/// Entries are uniform ±1 / √sketch_dim.
135fn build_rademacher_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
136    let scale = 1.0 / (sketch_dim as f64).sqrt();
137    (0..sketch_dim * m)
138        .map(|_| if rng.random::<bool>() { scale } else { -scale })
139        .collect()
140}
141
142/// Build a Count sketch matrix of shape (sketch_dim × m).
143///
144/// Each column j has exactly one non-zero entry at a random row h(j) with sign σ(j) ∈ {±1}.
145fn build_count_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
146    let mut s = vec![0.0f64; sketch_dim * m];
147    for j in 0..m {
148        let row = rng.random_range(0..sketch_dim);
149        let sign: f64 = if rng.random::<bool>() { 1.0 } else { -1.0 };
150        s[row * m + j] = sign;
151    }
152    s
153}
154
/// Apply the orthonormal Walsh-Hadamard transform to `x` in place.
///
/// Computes `x ← H·x / sqrt(len)` with the iterative butterfly (`O(len · log len)`). Here
/// `H` is the ±1 Hadamard matrix in natural (Sylvester) order. Because of the
/// `1/sqrt(len)` factor, the transform is orthonormal and is its own inverse.
/// Slices of length 0 or 1 are left unchanged.
///
/// # Panics
/// Panics if `x.len() > 1` and the length is not a power of two.
fn walsh_hadamard_transform(x: &mut [f64]) {
    let len = x.len();
    if len <= 1 {
        return;
    }
    // The butterfly below requires a power-of-two length. State the precondition
    // explicitly instead of failing later with an opaque index error.
    assert!(
        len.is_power_of_two(),
        "walsh_hadamard_transform: length {len} is not a power of two"
    );

    let mut half = 1;
    while half < len {
        // Each block of 2·half entries pairs element j with element j + half.
        for block in x.chunks_exact_mut(2 * half) {
            let (lo, hi) = block.split_at_mut(half);
            for (u, v) in lo.iter_mut().zip(hi.iter_mut()) {
                let (sum, diff) = (*u + *v, *u - *v);
                *u = sum;
                *v = diff;
            }
        }
        half <<= 1;
    }

    let inv_sqrt_len = 1.0 / (len as f64).sqrt();
    x.iter_mut().for_each(|xi| *xi *= inv_sqrt_len);
}
180
181/// Build a SRHT sketch matrix of shape (sketch_dim × m).
182///
183/// Applies D (random sign flips), then H (Hadamard), then samples sketch_dim rows.
184/// m is padded to the next power of 2 if needed.
185fn build_hadamard_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> (Vec<f64>, usize) {
186    // Pad m to next power of 2
187    let m_pad = m.next_power_of_two();
188    let scale = (m_pad as f64 / sketch_dim as f64).sqrt() / (m_pad as f64).sqrt();
189
190    // Random diagonal sign matrix D (m_pad entries)
191    let signs: Vec<f64> = (0..m_pad)
192        .map(|_| if rng.random::<bool>() { 1.0 } else { -1.0 })
193        .collect();
194
195    // Random row-selection permutation (sample sketch_dim rows without replacement from m_pad)
196    let mut perm: Vec<usize> = (0..m_pad).collect();
197    // Fisher-Yates partial shuffle for first sketch_dim elements
198    for i in 0..sketch_dim.min(m_pad) {
199        let j = i + rng.random_range(0..(m_pad - i));
200        perm.swap(i, j);
201    }
202    let selected_rows: Vec<usize> = perm[..sketch_dim.min(m_pad)].to_vec();
203
204    // Build each row of S by applying D then H to canonical basis vectors
205    // More practically: S[k, :] = scale * e_{selected_rows[k]}^T H D
206    // We represent S as a dense matrix for application to vectors
207    // Actually: to apply S to a vector v of length m:
208    //   1. Pad v to m_pad with zeros
209    //   2. Apply D: u = D v_pad (element-wise multiply)
210    //   3. Apply H: w = H u (WHT)
211    //   4. Select rows: Sv = scale * w[selected_rows]
212    // For the matrix form needed in matrix-matrix products, we build S explicitly.
213    let mut s = vec![0.0f64; sketch_dim * m_pad];
214
215    // Build each column of S^T (which is a row of S)
216    // We process each basis vector e_j and apply D then H
217    for j in 0..m {
218        let mut col = vec![0.0f64; m_pad];
219        col[j] = signs[j]; // D applied to e_j
220
221        walsh_hadamard_transform(&mut col);
222
223        // Now col = H D e_j; select rows
224        for (k, &row_idx) in selected_rows.iter().enumerate() {
225            s[k * m_pad + j] = scale * col[row_idx];
226        }
227    }
228
229    (s, m_pad)
230}
231
232// ─── Matrix-vector operations ─────────────────────────────────────────────────
233
234/// Compute S A for sketch matrix S (sketch_dim × m) and A (m × n), giving (sketch_dim × n).
235fn sketch_matrix(s: &[f64], sketch_dim: usize, a: &Array2<f64>, m_actual: usize) -> Vec<f64> {
236    let m = a.nrows();
237    let n = a.ncols();
238    let m_s = m_actual.min(m); // rows to use from S (in case of padding)
239    let mut sa = vec![0.0f64; sketch_dim * n];
240
241    for k in 0..sketch_dim {
242        for j in 0..n {
243            let mut val = 0.0;
244            for i in 0..m_s {
245                val += s[k * m_actual + i] * a[[i, j]];
246            }
247            sa[k * n + j] = val;
248        }
249    }
250    sa
251}
252
/// Apply the sketch to a right-hand side and return `S·b` (length `sketch_dim`).
///
/// `s` is a flat row-major buffer whose rows are `m_actual` entries apart. Only the first
/// `min(b.len(), m_actual)` columns of each row take part.
fn sketch_vector(s: &[f64], sketch_dim: usize, b: &[f64], m_actual: usize) -> Vec<f64> {
    let width = b.len().min(m_actual);
    (0..sketch_dim)
        .map(|row| {
            let start = row * m_actual;
            s[start..start + width]
                .iter()
                .zip(&b[..width])
                .fold(0.0_f64, |acc, (s_ki, b_i)| acc + s_ki * b_i)
        })
        .collect()
}
266
/// Sketched least-squares gradient `g = (SA)ᵀ (SA·x − Sb)`.
///
/// - `sa`: the flat row-major `sketch_dim × n` matrix `SA`.
/// - `sb`: the sketched right-hand side (length `sketch_dim`).
/// - `x`: the current iterate (length `n`).
///
/// Returns a vector of length `n`.
fn sketched_gradient(sa: &[f64], sb: &[f64], x: &[f64], sketch_dim: usize, n: usize) -> Vec<f64> {
    // Sketched residual, one entry per sketch row.
    let residual: Vec<f64> = (0..sketch_dim)
        .map(|row| {
            let fitted = sa[row * n..(row + 1) * n]
                .iter()
                .zip(&x[..n])
                .fold(0.0_f64, |acc, (a_kj, x_j)| acc + a_kj * x_j);
            fitted - sb[row]
        })
        .collect();

    // Back-project through (SA)ᵀ: dot each column of SA with the residual.
    (0..n)
        .map(|col| {
            residual
                .iter()
                .enumerate()
                .fold(0.0_f64, |acc, (row, r_k)| acc + sa[row * n + col] * r_k)
        })
        .collect()
}
290
/// Choose a gradient step size that is guaranteed to be stable for the sketched problem.
///
/// Gradient descent on `½‖SA·x − Sb‖²` converges for any step `α < 2 / λ_max((SA)ᵀSA)`.
/// Since `λ_max((SA)ᵀSA) = σ_max(SA)² ≤ ‖SA‖_F²`, the step `α = 1 / ‖SA‖_F²` gives
/// `α·λ_max ≤ 1` for every matrix, however strongly its columns are correlated.
///
/// A bound built from the largest column norm alone (the largest diagonal entry of
/// `(SA)ᵀSA`) is not safe. With `n` identical columns, `λ_max` is `n` times that entry.
/// The Frobenius bound can be conservative for large `n`; pass an explicit `step_size` in
/// the config when a larger step is known to be stable.
///
/// `sa` is the flat row-major `sketch_dim × n` sketched matrix. Returns `1e-4` when `SA`
/// is numerically zero.
fn estimate_step_size(sa: &[f64], sketch_dim: usize, n: usize) -> f64 {
    debug_assert_eq!(sa.len(), sketch_dim * n, "SA must be sketch_dim × n");
    let frobenius_sq: f64 = sa.iter().map(|v| v * v).sum();
    if frobenius_sq < f64::EPSILON {
        1e-4
    } else {
        1.0 / frobenius_sq
    }
}
312
313/// Compute the full residual norm ||Ax - b||.
314fn full_residual_norm(a: &Array2<f64>, b: &[f64], x: &[f64]) -> f64 {
315    let m = a.nrows();
316    let mut norm_sq = 0.0;
317    for i in 0..m {
318        let row = a.row(i);
319        let ax_i: f64 = row.iter().zip(x.iter()).map(|(aij, xj)| aij * xj).sum();
320        let r = ax_i - b[i];
321        norm_sq += r * r;
322    }
323    norm_sq.sqrt()
324}
325
326// ─── Public API ──────────────────────────────────────────────────────────────
327
328/// Solve the least-squares problem min ||Ax - b||² using sketched gradient descent.
329///
330/// At each iteration, forms a random sketch S of A and b, computes the sketched
331/// gradient g = (SA)^T (SA x - Sb), and performs a gradient step x ← x - α g.
332///
333/// # Arguments
334/// - `a`: Coefficient matrix of shape (m, n) with m >> n typical.
335/// - `b`: Right-hand side vector of length m.
336/// - `config`: Solver configuration.
337///
338/// # Returns
339/// A [`LsqResult`] with the approximate minimizer.
340pub fn sketched_least_squares(
341    a: &Array2<f64>,
342    b: &[f64],
343    config: &SketchedLeastSquaresConfig,
344) -> OptimizeResult<LsqResult> {
345    let m = a.nrows();
346    let n = a.ncols();
347
348    if m == 0 || n == 0 {
349        return Err(OptimizeError::InvalidInput(
350            "Matrix A must be non-empty".to_string(),
351        ));
352    }
353    if b.len() != m {
354        return Err(OptimizeError::InvalidInput(format!(
355            "b has length {} but A has {} rows",
356            b.len(),
357            m
358        )));
359    }
360    if config.sketch_dim == 0 {
361        return Err(OptimizeError::InvalidParameter(
362            "sketch_dim must be positive".to_string(),
363        ));
364    }
365
366    let sketch_dim = config.sketch_dim.min(m); // Sketch cannot be larger than m
367
368    let mut x = vec![0.0f64; n];
369    let mut rng = StdRng::seed_from_u64(config.seed);
370
371    // Precompute sketch once if not refreshing
372    let precomputed_sketch: Option<(Vec<f64>, Vec<f64>)> = if !config.refresh_sketch {
373        let (s, m_actual) = build_sketch_matrix(&config.sketch_type, sketch_dim, m, &mut rng);
374        let sa = sketch_matrix(&s, sketch_dim, a, m_actual);
375        let sb = sketch_vector(&s, sketch_dim, b, m_actual);
376        Some((sa, sb))
377    } else {
378        None
379    };
380
381    for iter in 0..config.max_iter {
382        let (sa, sb) = match &precomputed_sketch {
383            Some((sa, sb)) => (sa.clone(), sb.clone()),
384            None => {
385                let (s, m_actual) =
386                    build_sketch_matrix(&config.sketch_type, sketch_dim, m, &mut rng);
387                let sa = sketch_matrix(&s, sketch_dim, a, m_actual);
388                let sb = sketch_vector(&s, sketch_dim, b, m_actual);
389                (sa, sb)
390            }
391        };
392
393        let alpha = config
394            .step_size
395            .unwrap_or_else(|| estimate_step_size(&sa, sketch_dim, n));
396
397        let g = sketched_gradient(&sa, &sb, &x, sketch_dim, n);
398
399        // Compute update norm for convergence check
400        let update_norm: f64 = g.iter().map(|v| (alpha * v).powi(2)).sum::<f64>().sqrt();
401        let x_norm: f64 = x.iter().map(|v| v * v).sum::<f64>().sqrt();
402        let rel_change = update_norm / (1.0 + x_norm);
403
404        // Apply update
405        for (xi, gi) in x.iter_mut().zip(g.iter()) {
406            *xi -= alpha * gi;
407        }
408
409        if rel_change < config.tol {
410            let rn = full_residual_norm(a, b, &x);
411            return Ok(LsqResult {
412                x,
413                residual_norm: rn,
414                n_iter: iter + 1,
415                converged: true,
416            });
417        }
418    }
419
420    let rn = full_residual_norm(a, b, &x);
421    // Check convergence based on residual norm for consistency
422    let converged = rn < config.tol * (1.0 + b.iter().map(|v| v * v).sum::<f64>().sqrt());
423
424    Ok(LsqResult {
425        x,
426        residual_norm: rn,
427        n_iter: config.max_iter,
428        converged,
429    })
430}
431
432/// Build a sketch matrix (flat row-major array) and return (S, m_actual).
433///
434/// `m_actual` may differ from `m` for Hadamard sketches (padding to power of 2).
435fn build_sketch_matrix(
436    sketch_type: &SketchType,
437    sketch_dim: usize,
438    m: usize,
439    rng: &mut StdRng,
440) -> (Vec<f64>, usize) {
441    match sketch_type {
442        SketchType::Gaussian => (build_gaussian_sketch(sketch_dim, m, rng), m),
443        SketchType::Uniform => (build_rademacher_sketch(sketch_dim, m, rng), m),
444        SketchType::CountSketch => (build_count_sketch(sketch_dim, m, rng), m),
445        SketchType::Hadamard => build_hadamard_sketch(sketch_dim, m, rng),
446        _ => (build_gaussian_sketch(sketch_dim, m, rng), m),
447    }
448}
449
450// ─── Tests ───────────────────────────────────────────────────────────────────
451
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Exact solution of the reference problem built by `make_lsq_problem`.
    const X_STAR: [f64; 2] = [1.0, 2.0];

    /// Build an overdetermined 50 × 2 least-squares problem with exact solution
    /// `x* = [1, 2]`. Row `i` of `A` is `[t, 1 − t]` with `t = i / 50`.
    /// When `noise_scale > 0`, each `b_i` is perturbed by `noise_scale · U(−0.5, 0.5)`
    /// drawn from `rng`.
    fn make_lsq_problem(noise_scale: f64, rng: &mut StdRng) -> (Array2<f64>, Vec<f64>) {
        const ROWS: usize = 50;
        let mut entries = Vec::with_capacity(ROWS * X_STAR.len());
        let mut rhs = Vec::with_capacity(ROWS);

        for i in 0..ROWS {
            let t = (i as f64) / ROWS as f64;
            let row = [t, 1.0 - t];
            entries.extend_from_slice(&row);

            let mut target = row[0] * X_STAR[0] + row[1] * X_STAR[1];
            if noise_scale > 0.0 {
                let jitter: f64 = rng.random::<f64>() - 0.5;
                target += noise_scale * jitter;
            }
            rhs.push(target);
        }

        let a = Array2::from_shape_vec((ROWS, X_STAR.len()), entries).expect("valid shape");
        (a, rhs)
    }

    /// Shared configuration for the reference-problem tests: 500 iterations, `tol = 1e-5`,
    /// a fresh sketch every iteration and a fixed step of 0.01.
    fn reference_config(
        sketch_type: SketchType,
        sketch_dim: usize,
        seed: u64,
    ) -> SketchedLeastSquaresConfig {
        SketchedLeastSquaresConfig {
            sketch_dim,
            sketch_type,
            max_iter: 500,
            tol: 1e-5,
            seed,
            refresh_sketch: true,
            step_size: Some(0.01),
        }
    }

    /// Solve the noise-free reference problem with `config`; panic with `what` on error.
    fn solve_reference(config: &SketchedLeastSquaresConfig, what: &str) -> LsqResult {
        let mut rng = StdRng::seed_from_u64(0);
        let (a, b) = make_lsq_problem(0.0, &mut rng);
        sketched_least_squares(&a, &b, config).expect(what)
    }

    /// Assert every component of `x` lies within `tol` of `X_STAR`.
    fn assert_recovers_x_star(x: &[f64], tol: f64) {
        for (i, (&got, &want)) in x.iter().zip(X_STAR.iter()).enumerate() {
            assert!((got - want).abs() < tol, "x[{}] ≈ {}, got {}", i, want, got);
        }
    }

    #[test]
    fn test_sketched_ls_gaussian() {
        let config = reference_config(SketchType::Gaussian, 30, 42);
        let result = solve_reference(&config, "sketched LS should succeed");
        assert_recovers_x_star(&result.x, 0.1);
    }

    #[test]
    fn test_sketched_ls_count_sketch() {
        let config = reference_config(SketchType::CountSketch, 30, 77);
        let result = solve_reference(&config, "count sketch LS should succeed");
        assert_recovers_x_star(&result.x, 0.2);
    }

    #[test]
    fn test_sketched_ls_rademacher() {
        let config = reference_config(SketchType::Uniform, 25, 99);
        let result = solve_reference(&config, "Rademacher sketch should succeed");
        assert_recovers_x_star(&result.x, 0.2);
    }

    #[test]
    fn test_sketched_ls_hadamard() {
        let config = reference_config(SketchType::Hadamard, 20, 42);
        let result = solve_reference(&config, "SRHT sketch should succeed");
        // Looser tolerance for the SRHT variant.
        assert_recovers_x_star(&result.x, 0.5);
    }

    #[test]
    fn test_sketched_ls_static_sketch() {
        // A single sketch drawn up front and reused on every iteration.
        let config = SketchedLeastSquaresConfig {
            refresh_sketch: false,
            ..reference_config(SketchType::Gaussian, 30, 42)
        };
        let result = solve_reference(&config, "static sketch LS should succeed");
        // A fixed sketch is weaker, but the residual should still shrink.
        assert!(result.residual_norm < 5.0);
    }

    #[test]
    fn test_sketched_ls_invalid_input() {
        let a = Array2::<f64>::zeros((5, 2));
        let b = vec![1.0; 3]; // three entries for a five-row matrix
        let result = sketched_least_squares(&a, &b, &SketchedLeastSquaresConfig::default());
        assert!(result.is_err());
    }

    #[test]
    fn test_sketched_ls_zero_sketch_dim_error() {
        let a = Array2::<f64>::eye(4);
        let b = vec![1.0; 4];
        let config = SketchedLeastSquaresConfig {
            sketch_dim: 0,
            ..SketchedLeastSquaresConfig::default()
        };
        assert!(sketched_least_squares(&a, &b, &config).is_err());
    }

    #[test]
    fn test_sketched_ls_identity_system() {
        // With A = I₄ the exact solution equals b.
        let a = Array2::<f64>::eye(4);
        let b = vec![1.0, 2.0, 3.0, 4.0];
        let config = SketchedLeastSquaresConfig {
            sketch_dim: 4,
            sketch_type: SketchType::Gaussian,
            max_iter: 1000,
            tol: 1e-6,
            seed: 42,
            refresh_sketch: true,
            step_size: Some(0.1),
        };

        let result = sketched_least_squares(&a, &b, &config).expect("identity system should work");
        for (i, (&xi, &bi)) in result.x.iter().zip(&b).enumerate() {
            assert!((xi - bi).abs() < 0.5, "x[{}] ≈ {}, got {}", i, bi, xi);
        }
    }
}