Skip to main content

sklears_simd/
approximate.rs

1//! Approximate computing algorithms for high-performance scenarios
2//!
3//! This module provides approximate SIMD operations with controlled error bounds,
4//! reduced precision arithmetic, and probabilistic algorithms for large-scale computations.
5
6#[cfg(feature = "no-std")]
7extern crate alloc;
8
9#[cfg(feature = "no-std")]
10use alloc::collections::BTreeMap as HashMap;
11#[cfg(not(feature = "no-std"))]
12use std::{collections::HashMap, vec};
13
/// Error budget attached to the result of an approximate operation.
///
/// `relative_error` and `absolute_error` describe how far a returned value may deviate from
/// the exact result; `probability` is the confidence that the deviation stays within them.
#[derive(Debug, Clone, Copy)]
pub struct ErrorBound {
    /// Allowed deviation as a fraction of the exact result (`0.01` = 1%).
    pub relative_error: f64,
    /// Allowed absolute deviation of the result.
    pub absolute_error: f64,
    /// Probability that the actual error is within the bounds above.
    pub probability: f64,
}

impl ErrorBound {
    /// 1% relative, `1e-6` absolute, 99% confidence.
    pub const TIGHT: Self = Self::preset(0.01, 1e-6, 0.99);

    /// 5% relative, `1e-4` absolute, 95% confidence.
    pub const MODERATE: Self = Self::preset(0.05, 1e-4, 0.95);

    /// 10% relative, `1e-3` absolute, 90% confidence.
    pub const RELAXED: Self = Self::preset(0.1, 1e-3, 0.9);

    /// Positional constructor used to spell out the presets above.
    const fn preset(relative_error: f64, absolute_error: f64, probability: f64) -> Self {
        Self {
            relative_error,
            absolute_error,
            probability,
        }
    }
}
41
42/// Approximate SIMD operations with error bounds
43pub mod approximate_ops {
44    use super::*;
45
46    /// Approximate dot product using reduced precision
47    pub fn approximate_dot_product_f32(
48        a: &[f32],
49        b: &[f32],
50        error_bound: ErrorBound,
51    ) -> (f32, ErrorBound) {
52        assert_eq!(a.len(), b.len());
53
54        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
55        {
56            if crate::simd_feature_detected!("avx2") {
57                return unsafe { approximate_dot_product_f32_avx2(a, b, error_bound) };
58            }
59        }
60
61        approximate_dot_product_f32_scalar(a, b, error_bound)
62    }
63
64    /// Approximate sum with controlled precision loss
65    pub fn approximate_sum_f32(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
66        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
67        {
68            if crate::simd_feature_detected!("avx2") {
69                return unsafe { approximate_sum_f32_avx2(data, error_bound) };
70            }
71        }
72
73        approximate_sum_f32_scalar(data, error_bound)
74    }
75
76    /// Approximate L2 norm computation
77    pub fn approximate_l2_norm_f32(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
78        let (sum_squares, error) = approximate_sum_of_squares_f32(data, error_bound);
79        let norm = sum_squares.sqrt();
80
81        // Error propagation through square root
82        let propagated_error = ErrorBound {
83            relative_error: error.relative_error * 0.5, // sqrt reduces relative error
84            absolute_error: error.absolute_error * 0.5,
85            probability: error.probability,
86        };
87
88        (norm, propagated_error)
89    }
90
91    /// Approximate sum of squares
92    pub fn approximate_sum_of_squares_f32(
93        data: &[f32],
94        error_bound: ErrorBound,
95    ) -> (f32, ErrorBound) {
96        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
97        {
98            if crate::simd_feature_detected!("avx2") {
99                return unsafe { approximate_sum_of_squares_f32_avx2(data, error_bound) };
100            }
101        }
102
103        approximate_sum_of_squares_f32_scalar(data, error_bound)
104    }
105
106    // Scalar implementations
107    fn approximate_dot_product_f32_scalar(
108        a: &[f32],
109        b: &[f32],
110        error_bound: ErrorBound,
111    ) -> (f32, ErrorBound) {
112        // Use reduced precision accumulation for speed
113        let mut sum = 0.0f32;
114
115        for (&x, &y) in a.iter().zip(b.iter()) {
116            // Optionally quantize inputs for faster computation
117            let x_approx = quantize_f32(x, 16); // 16-bit precision
118            let y_approx = quantize_f32(y, 16);
119            sum += x_approx * y_approx;
120        }
121
122        // Estimate error based on precision reduction
123        let estimated_error = ErrorBound {
124            relative_error: (error_bound.relative_error + 0.001).min(0.1),
125            absolute_error: error_bound.absolute_error + 1e-5,
126            probability: error_bound.probability * 0.95,
127        };
128
129        (sum, estimated_error)
130    }
131
132    fn approximate_sum_f32_scalar(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
133        // Use Kahan summation for better accuracy while maintaining speed
134        let mut sum = 0.0f32;
135        let mut c = 0.0f32; // Compensation for lost low-order bits
136
137        for &x in data {
138            let x_approx = quantize_f32(x, 16);
139            let y = x_approx - c;
140            let t = sum + y;
141            c = (t - sum) - y;
142            sum = t;
143        }
144
145        let estimated_error = ErrorBound {
146            relative_error: (error_bound.relative_error + 0.0005).min(0.05),
147            absolute_error: error_bound.absolute_error + 1e-6,
148            probability: error_bound.probability * 0.98,
149        };
150
151        (sum, estimated_error)
152    }
153
154    fn approximate_sum_of_squares_f32_scalar(
155        data: &[f32],
156        error_bound: ErrorBound,
157    ) -> (f32, ErrorBound) {
158        let mut sum = 0.0f32;
159
160        for &x in data {
161            let x_approx = quantize_f32(x, 16);
162            sum += x_approx * x_approx;
163        }
164
165        let estimated_error = ErrorBound {
166            relative_error: (error_bound.relative_error + 0.002).min(0.1),
167            absolute_error: error_bound.absolute_error + 1e-5,
168            probability: error_bound.probability * 0.95,
169        };
170
171        (sum, estimated_error)
172    }
173
174    // AVX2 implementations
175    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
176    #[target_feature(enable = "avx2")]
177    unsafe fn approximate_dot_product_f32_avx2(
178        a: &[f32],
179        b: &[f32],
180        error_bound: ErrorBound,
181    ) -> (f32, ErrorBound) {
182        use core::arch::x86_64::*;
183
184        let mut sum_vec = _mm256_setzero_ps();
185        let chunks_a = a.chunks_exact(8);
186        let chunks_b = b.chunks_exact(8);
187        let remainder_a = chunks_a.remainder();
188        let remainder_b = chunks_b.remainder();
189
190        for (chunk_a, chunk_b) in chunks_a.zip(chunks_b) {
191            let vec_a = _mm256_loadu_ps(chunk_a.as_ptr());
192            let vec_b = _mm256_loadu_ps(chunk_b.as_ptr());
193
194            // Use FMA for better accuracy
195            sum_vec = _mm256_fmadd_ps(vec_a, vec_b, sum_vec);
196        }
197
198        // Horizontal sum
199        let sum_high = _mm256_extractf128_ps(sum_vec, 1);
200        let sum_low = _mm256_castps256_ps128(sum_vec);
201        let sum128 = _mm_add_ps(sum_high, sum_low);
202        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
203        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
204        let mut result = _mm_cvtss_f32(sum32);
205
206        // Handle remainder
207        for (&x, &y) in remainder_a.iter().zip(remainder_b.iter()) {
208            result += x * y;
209        }
210
211        let estimated_error = ErrorBound {
212            relative_error: error_bound.relative_error * 0.8, // SIMD typically more accurate
213            absolute_error: error_bound.absolute_error,
214            probability: error_bound.probability * 0.99,
215        };
216
217        (result, estimated_error)
218    }
219
220    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
221    #[target_feature(enable = "avx2")]
222    unsafe fn approximate_sum_f32_avx2(data: &[f32], error_bound: ErrorBound) -> (f32, ErrorBound) {
223        use core::arch::x86_64::*;
224
225        let mut sum_vec = _mm256_setzero_ps();
226        let chunks = data.chunks_exact(8);
227        let remainder = chunks.remainder();
228
229        for chunk in chunks {
230            let vec = _mm256_loadu_ps(chunk.as_ptr());
231            sum_vec = _mm256_add_ps(sum_vec, vec);
232        }
233
234        // Horizontal sum
235        let sum_high = _mm256_extractf128_ps(sum_vec, 1);
236        let sum_low = _mm256_castps256_ps128(sum_vec);
237        let sum128 = _mm_add_ps(sum_high, sum_low);
238        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
239        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
240        let mut result = _mm_cvtss_f32(sum32);
241
242        // Handle remainder
243        for &x in remainder {
244            result += x;
245        }
246
247        let estimated_error = ErrorBound {
248            relative_error: error_bound.relative_error * 0.9,
249            absolute_error: error_bound.absolute_error,
250            probability: error_bound.probability * 0.99,
251        };
252
253        (result, estimated_error)
254    }
255
256    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
257    #[target_feature(enable = "avx2")]
258    unsafe fn approximate_sum_of_squares_f32_avx2(
259        data: &[f32],
260        error_bound: ErrorBound,
261    ) -> (f32, ErrorBound) {
262        use core::arch::x86_64::*;
263
264        let mut sum_vec = _mm256_setzero_ps();
265        let chunks = data.chunks_exact(8);
266        let remainder = chunks.remainder();
267
268        for chunk in chunks {
269            let vec = _mm256_loadu_ps(chunk.as_ptr());
270            sum_vec = _mm256_fmadd_ps(vec, vec, sum_vec);
271        }
272
273        // Horizontal sum
274        let sum_high = _mm256_extractf128_ps(sum_vec, 1);
275        let sum_low = _mm256_castps256_ps128(sum_vec);
276        let sum128 = _mm_add_ps(sum_high, sum_low);
277        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
278        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
279        let mut result = _mm_cvtss_f32(sum32);
280
281        // Handle remainder
282        for &x in remainder {
283            result += x * x;
284        }
285
286        let estimated_error = ErrorBound {
287            relative_error: error_bound.relative_error * 0.85,
288            absolute_error: error_bound.absolute_error,
289            probability: error_bound.probability * 0.99,
290        };
291
292        (result, estimated_error)
293    }
294
295    /// Quantize f32 to reduced precision
296    fn quantize_f32(value: f32, bits: u8) -> f32 {
297        if bits >= 32 {
298            return value;
299        }
300
301        let scale = (1u32 << bits) as f32;
302
303        (value * scale).round() / scale
304    }
305}
306
307/// Reduced precision arithmetic for faster computation
308pub mod reduced_precision {
309    #[cfg(feature = "no-std")]
310    use alloc::{vec, vec::Vec};
311
312    /// 16-bit floating point emulation
313    #[derive(Debug, Clone, Copy, PartialEq)]
314    pub struct F16 {
315        bits: u16,
316    }
317
318    impl F16 {
319        pub fn from_f32(value: f32) -> Self {
320            // Simplified f16 conversion
321            let bits = if value.is_nan() {
322                0x7e00 // NaN
323            } else if value.is_infinite() {
324                if value.is_sign_positive() {
325                    0x7c00
326                } else {
327                    0xfc00
328                }
329            } else if value == 0.0 {
330                if value.is_sign_positive() {
331                    0x0000
332                } else {
333                    0x8000
334                }
335            } else {
336                // Very simplified conversion - not IEEE 754 compliant
337                let abs_val = value.abs();
338                let sign = if value < 0.0 { 0x8000 } else { 0x0000 };
339
340                if abs_val < 6.1e-5 {
341                    sign // Underflow to zero
342                } else if abs_val > 65504.0 {
343                    sign | 0x7c00 // Overflow to infinity
344                } else {
345                    // Approximate conversion
346                    let exp = (abs_val.log2().floor() as i16 + 15).clamp(0, 31) as u16;
347                    let mantissa =
348                        ((abs_val / 2.0_f32.powi(exp as i32 - 15) - 1.0) * 1024.0) as u16 & 0x3ff;
349                    sign | (exp << 10) | mantissa
350                }
351            };
352
353            Self { bits }
354        }
355
356        pub fn to_f32(self) -> f32 {
357            let sign = (self.bits & 0x8000) != 0;
358            let exp = (self.bits >> 10) & 0x1f;
359            let mantissa = self.bits & 0x3ff;
360
361            if exp == 0 {
362                if mantissa == 0 {
363                    if sign {
364                        -0.0
365                    } else {
366                        0.0
367                    }
368                } else {
369                    // Denormalized number
370                    let value = (mantissa as f32) / 1024.0 * 2.0_f32.powi(-14);
371                    if sign {
372                        -value
373                    } else {
374                        value
375                    }
376                }
377            } else if exp == 31 {
378                if mantissa == 0 {
379                    if sign {
380                        f32::NEG_INFINITY
381                    } else {
382                        f32::INFINITY
383                    }
384                } else {
385                    f32::NAN
386                }
387            } else {
388                let value = (1.0 + (mantissa as f32) / 1024.0) * 2.0_f32.powi(exp as i32 - 15);
389                if sign {
390                    -value
391                } else {
392                    value
393                }
394            }
395        }
396    }
397
398    /// 8-bit quantized operations
399    pub struct U8Quantized {
400        scale: f32,
401        zero_point: u8,
402    }
403
404    impl U8Quantized {
405        pub fn new(min_val: f32, max_val: f32) -> Self {
406            let scale = (max_val - min_val) / 255.0;
407            let zero_point = (-min_val / scale).round().clamp(0.0, 255.0) as u8;
408
409            Self { scale, zero_point }
410        }
411
412        pub fn quantize(&self, value: f32) -> u8 {
413            ((value / self.scale) + self.zero_point as f32)
414                .round()
415                .clamp(0.0, 255.0) as u8
416        }
417
418        pub fn dequantize(&self, quantized: u8) -> f32 {
419            (quantized as f32 - self.zero_point as f32) * self.scale
420        }
421
422        pub fn quantized_dot_product(&self, a: &[u8], b: &[u8]) -> f32 {
423            let sum: i32 = a
424                .iter()
425                .zip(b.iter())
426                .map(|(&x, &y)| {
427                    let x_adj = x as i32 - self.zero_point as i32;
428                    let y_adj = y as i32 - self.zero_point as i32;
429                    x_adj * y_adj
430                })
431                .sum();
432
433            sum as f32 * self.scale * self.scale
434        }
435    }
436
437    /// Mixed precision operations
438    pub fn mixed_precision_matrix_multiply(
439        a: &[f32],
440        b: &[f32],
441        rows_a: usize,
442        cols_a: usize,
443        cols_b: usize,
444    ) -> Vec<f32> {
445        assert_eq!(a.len(), rows_a * cols_a);
446        assert_eq!(b.len(), cols_a * cols_b);
447
448        let mut result = vec![0.0f32; rows_a * cols_b];
449
450        // Convert to f16 for intermediate computations
451        let a_f16: Vec<F16> = a.iter().map(|&x| F16::from_f32(x)).collect();
452        let b_f16: Vec<F16> = b.iter().map(|&x| F16::from_f32(x)).collect();
453
454        for i in 0..rows_a {
455            for j in 0..cols_b {
456                let mut sum = 0.0f32;
457                for k in 0..cols_a {
458                    let a_val = a_f16[i * cols_a + k].to_f32();
459                    let b_val = b_f16[k * cols_b + j].to_f32();
460                    sum += a_val * b_val;
461                }
462                result[i * cols_b + j] = sum;
463            }
464        }
465
466        result
467    }
468}
469
470/// Probabilistic algorithms for large-scale computations
471pub mod probabilistic {
472    use super::*;
473    #[cfg(feature = "no-std")]
474    use alloc::{vec, vec::Vec};
475
476    /// Count-Min Sketch for frequency estimation
477    pub struct CountMinSketch {
478        table: Vec<Vec<u32>>,
479        hash_functions: Vec<u64>,
480        width: usize,
481        #[allow(dead_code)] // Stored for introspection and future depth-adaptive queries
482        depth: usize,
483    }
484
485    impl CountMinSketch {
486        pub fn new(width: usize, depth: usize) -> Self {
487            use scirs2_core::random::thread_rng;
488            let mut rng = thread_rng();
489            let hash_functions: Vec<u64> = (0..depth).map(|_| rng.random::<u64>()).collect();
490
491            Self {
492                table: vec![vec![0; width]; depth],
493                hash_functions,
494                width,
495                depth,
496            }
497        }
498
499        pub fn update(&mut self, item: u64, count: u32) {
500            for (i, &hash_seed) in self.hash_functions.iter().enumerate() {
501                let hash = self.hash_item(item, hash_seed);
502                let index = (hash as usize) % self.width;
503                self.table[i][index] = self.table[i][index].saturating_add(count);
504            }
505        }
506
507        pub fn estimate(&self, item: u64) -> u32 {
508            self.hash_functions
509                .iter()
510                .enumerate()
511                .map(|(i, &hash_seed)| {
512                    let hash = self.hash_item(item, hash_seed);
513                    let index = (hash as usize) % self.width;
514                    self.table[i][index]
515                })
516                .min()
517                .unwrap_or(0)
518        }
519
520        fn hash_item(&self, item: u64, seed: u64) -> u64 {
521            // Better hash function (FNV-1a variant)
522            let mut hash = seed.wrapping_mul(14695981039346656037u64);
523            let bytes = item.to_le_bytes();
524            for byte in bytes {
525                hash ^= byte as u64;
526                hash = hash.wrapping_mul(1099511628211);
527            }
528            hash
529        }
530    }
531
532    /// HyperLogLog for cardinality estimation
533    pub struct HyperLogLog {
534        buckets: Vec<u8>,
535        bucket_count: usize,
536        alpha: f64,
537    }
538
539    impl HyperLogLog {
540        pub fn new(precision: u8) -> Self {
541            let bucket_count = 1 << precision;
542            let alpha = match bucket_count {
543                16 => 0.673,
544                32 => 0.697,
545                64 => 0.709,
546                _ => 0.7213 / (1.0 + 1.079 / bucket_count as f64),
547            };
548
549            Self {
550                buckets: vec![0; bucket_count],
551                bucket_count,
552                alpha,
553            }
554        }
555
556        pub fn add(&mut self, item: u64) {
557            let hash = self.hash_item(item);
558            let precision = self.bucket_count.trailing_zeros() as usize;
559            let bucket = (hash & ((self.bucket_count - 1) as u64)) as usize;
560            let remaining_hash = hash >> precision;
561            let leading_zeros = remaining_hash.leading_zeros() as u8 + 1;
562
563            self.buckets[bucket] = self.buckets[bucket].max(leading_zeros);
564        }
565
566        pub fn estimate(&self) -> f64 {
567            let raw_estimate = self.alpha * (self.bucket_count as f64).powi(2)
568                / self
569                    .buckets
570                    .iter()
571                    .map(|&b| 2.0_f64.powi(-(b as i32)))
572                    .sum::<f64>();
573
574            // Small range correction
575            if raw_estimate <= 2.5 * self.bucket_count as f64 {
576                let zeros = self.buckets.iter().filter(|&&b| b == 0).count();
577                if zeros != 0 {
578                    return (self.bucket_count as f64)
579                        * (self.bucket_count as f64 / zeros as f64).ln();
580                }
581            }
582
583            raw_estimate
584        }
585
586        fn hash_item(&self, item: u64) -> u64 {
587            // FNV-1a hash
588            let mut hash = 14695981039346656037u64;
589            let bytes = item.to_le_bytes();
590            for byte in bytes {
591                hash ^= byte as u64;
592                hash = hash.wrapping_mul(1099511628211);
593            }
594            hash
595        }
596    }
597
598    /// Bloom filter for membership testing
599    pub struct BloomFilter {
600        bit_array: Vec<bool>,
601        hash_functions: Vec<u64>,
602        size: usize,
603        #[allow(dead_code)] // Stored for false-positive-rate reporting and future re-hash logic
604        hash_count: usize,
605    }
606
607    impl BloomFilter {
608        pub fn new(expected_elements: usize, false_positive_rate: f64) -> Self {
609            let size = (-(expected_elements as f64 * false_positive_rate.ln())
610                / (2.0_f64.ln().powi(2)))
611            .ceil() as usize;
612            let hash_count =
613                ((size as f64 / expected_elements as f64) * 2.0_f64.ln()).ceil() as usize;
614
615            use scirs2_core::random::thread_rng;
616            let mut rng = thread_rng();
617            let hash_functions: Vec<u64> = (0..hash_count).map(|_| rng.random::<u64>()).collect();
618
619            Self {
620                bit_array: vec![false; size],
621                hash_functions,
622                size,
623                hash_count,
624            }
625        }
626
627        pub fn add(&mut self, item: u64) {
628            for &hash_seed in &self.hash_functions {
629                let hash = self.hash_item(item, hash_seed);
630                let index = (hash as usize) % self.size;
631                self.bit_array[index] = true;
632            }
633        }
634
635        pub fn contains(&self, item: u64) -> bool {
636            self.hash_functions.iter().all(|&hash_seed| {
637                let hash = self.hash_item(item, hash_seed);
638                let index = (hash as usize) % self.size;
639                self.bit_array[index]
640            })
641        }
642
643        fn hash_item(&self, item: u64, seed: u64) -> u64 {
644            item.wrapping_mul(seed).wrapping_add(seed >> 32)
645        }
646    }
647}
648
649/// Sketching techniques for streaming data
650pub mod sketching {
651    use super::*;
652    #[cfg(feature = "no-std")]
653    use alloc::{vec, vec::Vec};
654
655    /// Johnson-Lindenstrauss random projection
656    pub struct RandomProjection {
657        projection_matrix: Vec<f32>,
658        original_dim: usize,
659        projected_dim: usize,
660    }
661
662    impl RandomProjection {
663        pub fn new(original_dim: usize, projected_dim: usize, epsilon: f64) -> Self {
664            // Verify JL lemma constraints
665            let min_dim =
666                (4.0 * (2.0 * epsilon.powi(2) - epsilon.powi(3) / 3.0).ln()).ceil() as usize;
667            assert!(
668                projected_dim >= min_dim,
669                "Projected dimension too small for given epsilon"
670            );
671
672            use scirs2_core::random::thread_rng;
673            let mut rng = thread_rng();
674            let scale = (projected_dim as f32).sqrt();
675
676            let projection_matrix: Vec<f32> = (0..original_dim * projected_dim)
677                .map(|_| {
678                    // Gaussian random projection
679                    let u1: f32 = rng.random::<f32>();
680                    let u2: f32 = rng.random::<f32>();
681                    let z = (-2.0 * u1.ln()).sqrt() * (2.0 * core::f32::consts::PI * u2).cos();
682                    z / scale
683                })
684                .collect();
685
686            Self {
687                projection_matrix,
688                original_dim,
689                projected_dim,
690            }
691        }
692
693        pub fn project(&self, vector: &[f32]) -> Vec<f32> {
694            assert_eq!(vector.len(), self.original_dim);
695
696            let mut result = vec![0.0f32; self.projected_dim];
697
698            for (j, result_j) in result.iter_mut().enumerate() {
699                for (i, &v) in vector.iter().enumerate() {
700                    *result_j += v * self.projection_matrix[j * self.original_dim + i];
701                }
702            }
703
704            result
705        }
706
707        pub fn batch_project(&self, vectors: &[Vec<f32>]) -> Vec<Vec<f32>> {
708            vectors.iter().map(|v| self.project(v)).collect()
709        }
710    }
711
712    /// Frequent Items sketch (Count-Min with improvements)
713    pub struct FrequentItemsSketch {
714        count_min: probabilistic::CountMinSketch,
715        heavy_hitters: HashMap<u64, u32>,
716        threshold: u32,
717        total_count: u64,
718    }
719
720    impl FrequentItemsSketch {
721        pub fn new(width: usize, depth: usize, threshold: u32) -> Self {
722            Self {
723                count_min: probabilistic::CountMinSketch::new(width, depth),
724                heavy_hitters: HashMap::new(),
725                threshold,
726                total_count: 0,
727            }
728        }
729
730        pub fn update(&mut self, item: u64, count: u32) {
731            self.count_min.update(item, count);
732            self.total_count += count as u64;
733
734            let estimated_count = self.count_min.estimate(item);
735            if estimated_count >= self.threshold {
736                *self.heavy_hitters.entry(item).or_insert(0) += count;
737            }
738        }
739
740        pub fn get_frequent_items(&self) -> Vec<(u64, u32)> {
741            self.heavy_hitters.iter().map(|(&k, &v)| (k, v)).collect()
742        }
743
744        pub fn estimate_frequency(&self, item: u64) -> f64 {
745            let count = if let Some(&exact_count) = self.heavy_hitters.get(&item) {
746                exact_count
747            } else {
748                self.count_min.estimate(item)
749            };
750
751            count as f64 / self.total_count as f64
752        }
753    }
754
755    /// Quantile sketching using Q-digest
756    pub struct QuantileSketch {
757        buckets: Vec<(f64, u64)>, // (value, count)
758        max_buckets: usize,
759        total_count: u64,
760    }
761
762    impl QuantileSketch {
763        pub fn new(max_buckets: usize) -> Self {
764            Self {
765                buckets: Vec::new(),
766                max_buckets,
767                total_count: 0,
768            }
769        }
770
771        pub fn add(&mut self, value: f64) {
772            self.total_count += 1;
773
774            // Find insertion point
775            let pos = self
776                .buckets
777                .binary_search_by(|(v, _)| v.partial_cmp(&value).expect("operation should succeed"))
778                .unwrap_or_else(|e| e);
779
780            if pos < self.buckets.len() && (self.buckets[pos].0 - value).abs() < 1e-10 {
781                // Value already exists, increment count
782                self.buckets[pos].1 += 1;
783            } else {
784                // Insert new value
785                self.buckets.insert(pos, (value, 1));
786            }
787
788            // Compress if necessary
789            if self.buckets.len() > self.max_buckets {
790                self.compress();
791            }
792        }
793
794        pub fn quantile(&self, q: f64) -> Option<f64> {
795            if self.buckets.is_empty() || !(0.0..=1.0).contains(&q) {
796                return None;
797            }
798
799            let target_rank = (q * self.total_count as f64) as u64;
800            let mut current_rank = 0;
801
802            for &(value, count) in &self.buckets {
803                current_rank += count;
804                if current_rank >= target_rank {
805                    return Some(value);
806                }
807            }
808
809            self.buckets.last().map(|(v, _)| *v)
810        }
811
812        fn compress(&mut self) {
813            // Simple compression: merge adjacent buckets with smallest combined error
814            while self.buckets.len() > self.max_buckets {
815                let mut min_error = f64::INFINITY;
816                let mut merge_idx = 0;
817
818                for i in 0..self.buckets.len() - 1 {
819                    let error = (self.buckets[i + 1].0 - self.buckets[i].0)
820                        * (self.buckets[i].1 + self.buckets[i + 1].1) as f64;
821                    if error < min_error {
822                        min_error = error;
823                        merge_idx = i;
824                    }
825                }
826
827                // Merge buckets at merge_idx and merge_idx + 1
828                let merged_count = self.buckets[merge_idx].1 + self.buckets[merge_idx + 1].1;
829                let merged_value = (self.buckets[merge_idx].0 * self.buckets[merge_idx].1 as f64
830                    + self.buckets[merge_idx + 1].0 * self.buckets[merge_idx + 1].1 as f64)
831                    / merged_count as f64;
832
833                self.buckets[merge_idx] = (merged_value, merged_count);
834                self.buckets.remove(merge_idx + 1);
835            }
836        }
837    }
838}
839
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    //! Smoke tests for the approximate, reduced-precision, probabilistic and sketching APIs.
    //! Several sketches use random seeds, so their assertions use deliberately loose bounds.

    use super::*;
    #[cfg(feature = "no-std")]
    use alloc::{vec, vec::Vec};
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_approximate_dot_product() {
        // 1*5 + 2*6 + 3*7 + 4*8
        const EXPECTED: f32 = 70.0;
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [5.0f32, 6.0, 7.0, 8.0];

        let (dot, _bound) =
            approximate_ops::approximate_dot_product_f32(&lhs, &rhs, ErrorBound::MODERATE);
        assert_abs_diff_eq!(dot, EXPECTED, epsilon = 1.0);
    }

    #[test]
    fn test_approximate_sum() {
        const EXPECTED: f32 = 15.0;
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0];

        let (total, _bound) = approximate_ops::approximate_sum_f32(&values, ErrorBound::MODERATE);
        assert_abs_diff_eq!(total, EXPECTED, epsilon = 0.1);
    }

    #[test]
    fn test_f16_conversion() {
        for original in [0.0f32, 1.0, -1.0, 10.5, -10.5] {
            let round_trip = reduced_precision::F16::from_f32(original).to_f32();
            assert_abs_diff_eq!(round_trip, original, epsilon = 0.1);
        }
    }

    #[test]
    fn test_u8_quantization() {
        let quantizer = reduced_precision::U8Quantized::new(-10.0, 10.0);

        for original in [-10.0f32, 0.0, 10.0, 5.0, -5.0] {
            let restored = quantizer.dequantize(quantizer.quantize(original));
            assert_abs_diff_eq!(restored, original, epsilon = 0.2);
        }
    }

    #[test]
    fn test_count_min_sketch() {
        let mut sketch = probabilistic::CountMinSketch::new(100, 5);
        for (item, count) in [(42, 10), (42, 5), (100, 3)] {
            sketch.update(item, count);
        }

        assert!(sketch.estimate(42) >= 15);
        assert!(sketch.estimate(100) >= 3);
        assert_eq!(sketch.estimate(999), 0);
    }

    #[test]
    fn test_hyperloglog() {
        let mut hll = probabilistic::HyperLogLog::new(10);
        (0..1000).for_each(|item| hll.add(item));

        let estimate = hll.estimate();
        // Lenient range for HyperLogLog approximation
        assert!((100.0..=10000.0).contains(&estimate));
    }

    #[test]
    fn test_bloom_filter() {
        let mut bloom = probabilistic::BloomFilter::new(1000, 0.01);
        (0..100).for_each(|item| bloom.add(item));

        // Every inserted item must be reported as present.
        assert!((0..100).all(|item| bloom.contains(item)));

        // Items never inserted should rarely be reported.
        let false_positives = (100..200).filter(|&item| bloom.contains(item)).count();
        assert!(false_positives < 5);
    }

    #[test]
    fn test_random_projection() {
        let projection = sketching::RandomProjection::new(100, 20, 0.1);

        let ramp: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let doubled: Vec<f32> = (0..100).map(|i| (i * 2) as f32).collect();

        let projected = projection.project(&ramp);
        assert_eq!(projected.len(), 20);
        let projected_doubled = projection.project(&doubled);

        // A linear map sends `2v` to `2·P(v)`, so the two projections are positively aligned.
        let correlation: f32 = projected
            .iter()
            .zip(&projected_doubled)
            .map(|(a, b)| a * b)
            .sum();
        assert!(correlation > 0.0);
    }

    #[test]
    fn test_quantile_sketch() {
        let mut sketch = sketching::QuantileSketch::new(20);
        (1..=100).for_each(|i| sketch.add(i as f64));

        let median = sketch.quantile(0.5).expect("operation should succeed");
        assert!((45.0..=55.0).contains(&median));

        let q90 = sketch.quantile(0.9).expect("operation should succeed");
        assert!((85.0..=95.0).contains(&q90));
    }

    #[test]
    fn test_frequent_items_sketch() {
        // Threshold 5 so that every item below becomes a heavy hitter.
        let mut sketch = sketching::FrequentItemsSketch::new(100, 5, 5);
        for (item, repeats) in [(42, 20), (100, 15), (200, 5)] {
            for _ in 0..repeats {
                sketch.update(item, 1);
            }
        }

        assert!(!sketch.get_frequent_items().is_empty());

        // True frequency is 20/40 = 0.5; the bound is deliberately lenient.
        assert!(sketch.estimate_frequency(42) >= 0.3);
    }

    #[test]
    fn test_mixed_precision_matrix_multiply() {
        // Two 2x2 row-major matrices.
        let lhs = [1.0f32, 2.0, 3.0, 4.0];
        let rhs = [5.0f32, 6.0, 7.0, 8.0];

        let product = reduced_precision::mixed_precision_matrix_multiply(&lhs, &rhs, 2, 2, 2);

        // [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] = [[19, 22], [43, 50]]
        let expected = [19.0f32, 22.0, 43.0, 50.0];
        for (got, want) in product.iter().zip(&expected) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1.0);
        }
    }
}