1use crate::error::{OptimizeError, OptimizeResult};
30use scirs2_core::ndarray::Array2;
31use scirs2_core::random::{rngs::StdRng, RngExt, SeedableRng};
32
/// Family of random sketching matrices that the solver can draw.
///
/// The enum is `#[non_exhaustive]`, so downstream crates must include a
/// wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SketchType {
    /// Dense sketch with independent Gaussian entries (the default).
    #[default]
    Gaussian,
    /// Subsampled randomized Hadamard transform: random sign flips, a
    /// Walsh-Hadamard transform, and a random selection of rows.
    Hadamard,
    /// Dense sketch whose entries are random signs of equal magnitude.
    Uniform,
    /// Sparse count sketch with a single signed entry per column.
    CountSketch,
}
52
53#[derive(Clone, Debug)]
55pub struct SketchedLeastSquaresConfig {
56 pub sketch_dim: usize,
58 pub sketch_type: SketchType,
60 pub max_iter: usize,
62 pub tol: f64,
64 pub seed: u64,
66 pub refresh_sketch: bool,
68 pub step_size: Option<f64>,
71}
72
73impl Default for SketchedLeastSquaresConfig {
74 fn default() -> Self {
75 Self {
76 sketch_dim: 512,
77 sketch_type: SketchType::Gaussian,
78 max_iter: 100,
79 tol: 1e-6,
80 seed: 42,
81 refresh_sketch: true,
82 step_size: None,
83 }
84 }
85}
86
/// Outcome of a sketched least-squares solve.
#[derive(Clone, Debug)]
pub struct LsqResult {
    /// Final iterate.
    pub x: Vec<f64>,
    /// Euclidean norm of `A x - b` evaluated on the full (unsketched) system.
    pub residual_norm: f64,
    /// Number of iterations that were performed.
    pub n_iter: usize,
    /// Whether the solver's stopping criterion was satisfied.
    pub converged: bool,
}
99
100fn build_gaussian_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
106 let scale = (1.0 / sketch_dim as f64).sqrt();
107 let mut s = Vec::with_capacity(sketch_dim * m);
108 let mut spare: Option<f64> = None;
110 for _ in 0..(sketch_dim * m) {
111 let v = match spare.take() {
112 Some(z) => z,
113 None => {
114 loop {
116 let u: f64 = rng.random::<f64>();
117 let v: f64 = rng.random::<f64>();
118 if u > 0.0 {
119 let mag = (-2.0 * u.ln()).sqrt();
120 let angle = std::f64::consts::TAU * v;
121 spare = Some(mag * angle.sin());
122 break mag * angle.cos();
123 }
124 }
125 }
126 };
127 s.push(v * scale);
128 }
129 s
130}
131
132fn build_rademacher_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
136 let scale = 1.0 / (sketch_dim as f64).sqrt();
137 (0..sketch_dim * m)
138 .map(|_| if rng.random::<bool>() { scale } else { -scale })
139 .collect()
140}
141
142fn build_count_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> Vec<f64> {
146 let mut s = vec![0.0f64; sketch_dim * m];
147 for j in 0..m {
148 let row = rng.random_range(0..sketch_dim);
149 let sign: f64 = if rng.random::<bool>() { 1.0 } else { -1.0 };
150 s[row * m + j] = sign;
151 }
152 s
153}
154
/// Applies the normalised Walsh-Hadamard transform to `x` in place.
///
/// Butterfly passes with doubling half-width are followed by a multiplication
/// with `1 / sqrt(len)`. Slices of length 0 or 1 are left untouched. The
/// indexing assumes the length is a power of two; other lengths run out of
/// bounds and panic.
fn walsh_hadamard_transform(x: &mut [f64]) {
    let len = x.len();
    if len < 2 {
        return;
    }

    let mut half = 1usize;
    while half < len {
        let stride = half * 2;
        let mut start = 0;
        while start < len {
            for lo in start..start + half {
                let hi = lo + half;
                let (a, b) = (x[lo], x[hi]);
                x[lo] = a + b;
                x[hi] = a - b;
            }
            start += stride;
        }
        half = stride;
    }

    let norm = 1.0 / (len as f64).sqrt();
    x.iter_mut().for_each(|e| *e *= norm);
}
180
181fn build_hadamard_sketch(sketch_dim: usize, m: usize, rng: &mut StdRng) -> (Vec<f64>, usize) {
186 let m_pad = m.next_power_of_two();
188 let scale = (m_pad as f64 / sketch_dim as f64).sqrt() / (m_pad as f64).sqrt();
189
190 let signs: Vec<f64> = (0..m_pad)
192 .map(|_| if rng.random::<bool>() { 1.0 } else { -1.0 })
193 .collect();
194
195 let mut perm: Vec<usize> = (0..m_pad).collect();
197 for i in 0..sketch_dim.min(m_pad) {
199 let j = i + rng.random_range(0..(m_pad - i));
200 perm.swap(i, j);
201 }
202 let selected_rows: Vec<usize> = perm[..sketch_dim.min(m_pad)].to_vec();
203
204 let mut s = vec![0.0f64; sketch_dim * m_pad];
214
215 for j in 0..m {
218 let mut col = vec![0.0f64; m_pad];
219 col[j] = signs[j]; walsh_hadamard_transform(&mut col);
222
223 for (k, &row_idx) in selected_rows.iter().enumerate() {
225 s[k * m_pad + j] = scale * col[row_idx];
226 }
227 }
228
229 (s, m_pad)
230}
231
232fn sketch_matrix(s: &[f64], sketch_dim: usize, a: &Array2<f64>, m_actual: usize) -> Vec<f64> {
236 let m = a.nrows();
237 let n = a.ncols();
238 let m_s = m_actual.min(m); let mut sa = vec![0.0f64; sketch_dim * n];
240
241 for k in 0..sketch_dim {
242 for j in 0..n {
243 let mut val = 0.0;
244 for i in 0..m_s {
245 val += s[k * m_actual + i] * a[[i, j]];
246 }
247 sa[k * n + j] = val;
248 }
249 }
250 sa
251}
252
/// Computes the sketched right-hand side `S * b` (length `sketch_dim`).
///
/// `s` is read as a row-major matrix with row stride `m_actual`. Only the
/// first `min(b.len(), m_actual)` entries of each row contribute.
fn sketch_vector(s: &[f64], sketch_dim: usize, b: &[f64], m_actual: usize) -> Vec<f64> {
    let used = b.len().min(m_actual);
    (0..sketch_dim)
        .map(|k| {
            let offset = k * m_actual;
            b[..used]
                .iter()
                .enumerate()
                .fold(0.0, |acc, (i, b_i)| acc + s[offset + i] * b_i)
        })
        .collect()
}
266
/// Computes the gradient `(SA)^T (SA x - Sb)` of the sketched objective.
///
/// `sa` is the row-major `sketch_dim x n` sketched matrix, `sb` the sketched
/// right-hand side and `x` the current iterate. The result has length `n`.
fn sketched_gradient(sa: &[f64], sb: &[f64], x: &[f64], sketch_dim: usize, n: usize) -> Vec<f64> {
    // Sketched residual: SA * x - Sb.
    let residual: Vec<f64> = (0..sketch_dim)
        .map(|k| {
            let row_start = k * n;
            let sax_k = (0..n).fold(0.0, |acc, j| acc + sa[row_start + j] * x[j]);
            sax_k - sb[k]
        })
        .collect();

    // Back-projection: (SA)^T * residual.
    (0..n)
        .map(|j| {
            residual
                .iter()
                .enumerate()
                .fold(0.0, |acc, (k, r_k)| acc + sa[k * n + j] * r_k)
        })
        .collect()
}
290
/// Estimates a stable gradient step size for the sketched problem.
///
/// Gradient descent on `0.5 * ||SA x - Sb||^2` is stable for steps below
/// `2 / L`, where `L = lambda_max((SA)^T SA) = ||SA||_2^2`. Two cheap upper
/// bounds on `L` are combined:
///
/// * the squared Frobenius norm `||SA||_F^2`, and
/// * `||SA||_1 * ||SA||_inf` (largest absolute column sum times largest
///   absolute row sum).
///
/// The returned step is `0.9 / min(bound_1, bound_2)`. The largest squared
/// column norm alone is only a lower bound on `L`, so a step derived from it
/// can exceed `2 / L` and diverge when columns are correlated.
///
/// `sa` is the row-major `sketch_dim x n` sketched matrix. A fallback of
/// `1e-4` is returned when the matrix is empty or numerically zero.
fn estimate_step_size(sa: &[f64], sketch_dim: usize, n: usize) -> f64 {
    const FALLBACK_STEP: f64 = 1e-4;
    const SAFETY_FACTOR: f64 = 0.9;

    if sketch_dim == 0 || n == 0 {
        return FALLBACK_STEP;
    }

    let frobenius_sq: f64 = sa.iter().map(|v| v * v).sum();
    if frobenius_sq < f64::EPSILON {
        return FALLBACK_STEP;
    }

    // One pass over the rows accumulates absolute column sums and tracks the
    // largest absolute row sum.
    let mut col_abs_sums = vec![0.0f64; n];
    let mut max_row_abs_sum = 0.0f64;
    for row in sa.chunks_exact(n).take(sketch_dim) {
        let mut row_abs_sum = 0.0;
        for (col_sum, v) in col_abs_sums.iter_mut().zip(row) {
            let magnitude = v.abs();
            *col_sum += magnitude;
            row_abs_sum += magnitude;
        }
        max_row_abs_sum = max_row_abs_sum.max(row_abs_sum);
    }
    let max_col_abs_sum = col_abs_sums.iter().copied().fold(0.0, f64::max);

    // Both quantities bound ||SA||_2^2 from above; take the tighter one.
    let lipschitz_bound = frobenius_sq.min(max_col_abs_sum * max_row_abs_sum);

    if lipschitz_bound > f64::EPSILON {
        SAFETY_FACTOR / lipschitz_bound
    } else {
        FALLBACK_STEP
    }
}
312
313fn full_residual_norm(a: &Array2<f64>, b: &[f64], x: &[f64]) -> f64 {
315 let m = a.nrows();
316 let mut norm_sq = 0.0;
317 for i in 0..m {
318 let row = a.row(i);
319 let ax_i: f64 = row.iter().zip(x.iter()).map(|(aij, xj)| aij * xj).sum();
320 let r = ax_i - b[i];
321 norm_sq += r * r;
322 }
323 norm_sq.sqrt()
324}
325
326pub fn sketched_least_squares(
341 a: &Array2<f64>,
342 b: &[f64],
343 config: &SketchedLeastSquaresConfig,
344) -> OptimizeResult<LsqResult> {
345 let m = a.nrows();
346 let n = a.ncols();
347
348 if m == 0 || n == 0 {
349 return Err(OptimizeError::InvalidInput(
350 "Matrix A must be non-empty".to_string(),
351 ));
352 }
353 if b.len() != m {
354 return Err(OptimizeError::InvalidInput(format!(
355 "b has length {} but A has {} rows",
356 b.len(),
357 m
358 )));
359 }
360 if config.sketch_dim == 0 {
361 return Err(OptimizeError::InvalidParameter(
362 "sketch_dim must be positive".to_string(),
363 ));
364 }
365
366 let sketch_dim = config.sketch_dim.min(m); let mut x = vec![0.0f64; n];
369 let mut rng = StdRng::seed_from_u64(config.seed);
370
371 let precomputed_sketch: Option<(Vec<f64>, Vec<f64>)> = if !config.refresh_sketch {
373 let (s, m_actual) = build_sketch_matrix(&config.sketch_type, sketch_dim, m, &mut rng);
374 let sa = sketch_matrix(&s, sketch_dim, a, m_actual);
375 let sb = sketch_vector(&s, sketch_dim, b, m_actual);
376 Some((sa, sb))
377 } else {
378 None
379 };
380
381 for iter in 0..config.max_iter {
382 let (sa, sb) = match &precomputed_sketch {
383 Some((sa, sb)) => (sa.clone(), sb.clone()),
384 None => {
385 let (s, m_actual) =
386 build_sketch_matrix(&config.sketch_type, sketch_dim, m, &mut rng);
387 let sa = sketch_matrix(&s, sketch_dim, a, m_actual);
388 let sb = sketch_vector(&s, sketch_dim, b, m_actual);
389 (sa, sb)
390 }
391 };
392
393 let alpha = config
394 .step_size
395 .unwrap_or_else(|| estimate_step_size(&sa, sketch_dim, n));
396
397 let g = sketched_gradient(&sa, &sb, &x, sketch_dim, n);
398
399 let update_norm: f64 = g.iter().map(|v| (alpha * v).powi(2)).sum::<f64>().sqrt();
401 let x_norm: f64 = x.iter().map(|v| v * v).sum::<f64>().sqrt();
402 let rel_change = update_norm / (1.0 + x_norm);
403
404 for (xi, gi) in x.iter_mut().zip(g.iter()) {
406 *xi -= alpha * gi;
407 }
408
409 if rel_change < config.tol {
410 let rn = full_residual_norm(a, b, &x);
411 return Ok(LsqResult {
412 x,
413 residual_norm: rn,
414 n_iter: iter + 1,
415 converged: true,
416 });
417 }
418 }
419
420 let rn = full_residual_norm(a, b, &x);
421 let converged = rn < config.tol * (1.0 + b.iter().map(|v| v * v).sum::<f64>().sqrt());
423
424 Ok(LsqResult {
425 x,
426 residual_norm: rn,
427 n_iter: config.max_iter,
428 converged,
429 })
430}
431
432fn build_sketch_matrix(
436 sketch_type: &SketchType,
437 sketch_dim: usize,
438 m: usize,
439 rng: &mut StdRng,
440) -> (Vec<f64>, usize) {
441 match sketch_type {
442 SketchType::Gaussian => (build_gaussian_sketch(sketch_dim, m, rng), m),
443 SketchType::Uniform => (build_rademacher_sketch(sketch_dim, m, rng), m),
444 SketchType::CountSketch => (build_count_sketch(sketch_dim, m, rng), m),
445 SketchType::Hadamard => build_hadamard_sketch(sketch_dim, m, rng),
446 _ => (build_gaussian_sketch(sketch_dim, m, rng), m),
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a 50 x 2 problem whose exact solution is `x = [1, 2]`.
    ///
    /// Row `i` of `A` is `[i / 50, 1 - i / 50]`. When `noise_scale > 0`, a
    /// uniform perturbation in `[-0.5, 0.5)` times `noise_scale` is added to
    /// each entry of `b`.
    fn make_lsq_problem(noise_scale: f64, rng: &mut StdRng) -> (Array2<f64>, Vec<f64>) {
        const ROWS: usize = 50;
        const COLS: usize = 2;
        let x_true = [1.0, 2.0];

        let mut a_data = Vec::with_capacity(ROWS * COLS);
        let mut b = Vec::with_capacity(ROWS);

        for i in 0..ROWS {
            let a0 = (i as f64) / ROWS as f64;
            let a1 = 1.0 - a0;
            a_data.push(a0);
            a_data.push(a1);

            let mut rhs = a0 * x_true[0] + a1 * x_true[1];
            if noise_scale > 0.0 {
                let u: f64 = rng.random::<f64>() - 0.5;
                rhs += noise_scale * u;
            }
            b.push(rhs);
        }

        let a = Array2::from_shape_vec((ROWS, COLS), a_data).expect("valid shape");
        (a, b)
    }

    /// Noise-free instance of the shared test problem.
    fn noiseless_problem() -> (Array2<f64>, Vec<f64>) {
        let mut rng = StdRng::seed_from_u64(0);
        make_lsq_problem(0.0, &mut rng)
    }

    /// Configuration shared by the convergence tests: 500 iterations, tolerance
    /// `1e-5`, fixed step `0.01`, sketch redrawn every iteration.
    fn refreshing_config(
        sketch_type: SketchType,
        sketch_dim: usize,
        seed: u64,
    ) -> SketchedLeastSquaresConfig {
        SketchedLeastSquaresConfig {
            sketch_dim,
            sketch_type,
            max_iter: 500,
            tol: 1e-5,
            seed,
            refresh_sketch: true,
            step_size: Some(0.01),
        }
    }

    #[test]
    fn test_sketched_ls_gaussian() {
        let (a, b) = noiseless_problem();
        let config = refreshing_config(SketchType::Gaussian, 30, 42);

        let result = sketched_least_squares(&a, &b, &config).expect("sketched LS should succeed");
        assert!(
            (result.x[0] - 1.0).abs() < 0.1,
            "x[0] ≈ 1, got {}",
            result.x[0]
        );
        assert!(
            (result.x[1] - 2.0).abs() < 0.1,
            "x[1] ≈ 2, got {}",
            result.x[1]
        );
    }

    #[test]
    fn test_sketched_ls_count_sketch() {
        let (a, b) = noiseless_problem();
        let config = refreshing_config(SketchType::CountSketch, 30, 77);

        let result =
            sketched_least_squares(&a, &b, &config).expect("count sketch LS should succeed");
        assert!(
            (result.x[0] - 1.0).abs() < 0.2,
            "x[0] ≈ 1, got {}",
            result.x[0]
        );
        assert!(
            (result.x[1] - 2.0).abs() < 0.2,
            "x[1] ≈ 2, got {}",
            result.x[1]
        );
    }

    #[test]
    fn test_sketched_ls_rademacher() {
        let (a, b) = noiseless_problem();
        let config = refreshing_config(SketchType::Uniform, 25, 99);

        let result =
            sketched_least_squares(&a, &b, &config).expect("Rademacher sketch should succeed");
        assert!((result.x[0] - 1.0).abs() < 0.2, "x[0] ≈ 1");
        assert!((result.x[1] - 2.0).abs() < 0.2, "x[1] ≈ 2");
    }

    #[test]
    fn test_sketched_ls_hadamard() {
        let (a, b) = noiseless_problem();
        let config = refreshing_config(SketchType::Hadamard, 20, 42);

        let result = sketched_least_squares(&a, &b, &config).expect("SRHT sketch should succeed");
        assert!(
            (result.x[0] - 1.0).abs() < 0.5,
            "x[0] ≈ 1, got {}",
            result.x[0]
        );
        assert!(
            (result.x[1] - 2.0).abs() < 0.5,
            "x[1] ≈ 2, got {}",
            result.x[1]
        );
    }

    #[test]
    fn test_sketched_ls_static_sketch() {
        let (a, b) = noiseless_problem();
        // Same settings as the Gaussian test, but with a single reused sketch.
        let config = SketchedLeastSquaresConfig {
            refresh_sketch: false,
            ..refreshing_config(SketchType::Gaussian, 30, 42)
        };

        let result =
            sketched_least_squares(&a, &b, &config).expect("static sketch LS should succeed");
        assert!(result.residual_norm < 5.0);
    }

    #[test]
    fn test_sketched_ls_invalid_input() {
        // b is shorter than the number of rows of A.
        let a = Array2::<f64>::zeros((5, 2));
        let b = vec![1.0; 3];
        let config = SketchedLeastSquaresConfig::default();

        let result = sketched_least_squares(&a, &b, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_sketched_ls_zero_sketch_dim_error() {
        let a = Array2::<f64>::eye(4);
        let b = vec![1.0; 4];
        let config = SketchedLeastSquaresConfig {
            sketch_dim: 0,
            ..SketchedLeastSquaresConfig::default()
        };

        let result = sketched_least_squares(&a, &b, &config);
        assert!(result.is_err());
    }

    #[test]
    fn test_sketched_ls_identity_system() {
        // For A = I the least-squares solution is b itself.
        let a = Array2::<f64>::eye(4);
        let b = vec![1.0, 2.0, 3.0, 4.0];

        let config = SketchedLeastSquaresConfig {
            sketch_dim: 4,
            sketch_type: SketchType::Gaussian,
            max_iter: 1000,
            tol: 1e-6,
            seed: 42,
            refresh_sketch: true,
            step_size: Some(0.1),
        };

        let result = sketched_least_squares(&a, &b, &config).expect("identity system should work");
        for (i, (&xi, &bi)) in result.x.iter().zip(b.iter()).enumerate() {
            assert!((xi - bi).abs() < 0.5, "x[{}] ≈ {}, got {}", i, bi, xi);
        }
    }
}