sklears_core/
utils.rs

//! Utility functions for machine learning operations
//!
//! This module provides common utility functions that are frequently needed
//! across different machine learning algorithms and workflows.
5use crate::types::Float;
6// SciRS2 Policy: Using scirs2_core::ndarray (COMPLIANT)
7use scirs2_core::ndarray::Array1;
8
/// Generate a time-based seed for pseudo-random number generators
///
/// The seed is the number of nanoseconds between the Unix epoch and the
/// current system time, truncated to 64 bits. It is derived from the wall
/// clock only: good enough to make independent runs differ, but predictable,
/// so it must not be used for anything security-sensitive.
///
/// A system clock that reads earlier than the Unix epoch does not cause a
/// panic; the distance to the epoch is used as seed material instead.
///
/// # Returns
/// A `u64` seed value
pub fn generate_random_seed() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let offset_from_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // `duration_since` fails only when the clock is set before the epoch;
        // the error carries the (positive) offset, which is equally usable.
        .unwrap_or_else(|err| err.duration());

    // `as_nanos` yields a u128; the cast keeps the fast-changing low 64 bits.
    offset_from_epoch.as_nanos() as u64
}
17
18/// Calculate the entropy of a discrete distribution
19///
20/// # Arguments
21/// * `probabilities` - Array of probabilities that should sum to 1.0
22///
23/// # Returns
24/// The Shannon entropy in bits
25///
26/// # Example
27/// ```
28/// use sklears_core::utils::entropy;
29/// use scirs2_core::ndarray::array;
30///
31/// let probs = array![0.5, 0.5];
32/// let ent = entropy(&probs);
33/// assert!((ent - 1.0).abs() < 1e-10);
34/// ```
35pub fn entropy(probabilities: &Array1<Float>) -> Float {
36    probabilities
37        .iter()
38        .filter(|&&p| p > 0.0)
39        .map(|&p| -p * p.log2())
40        .sum()
41}
42
43/// Calculate Gini impurity for a discrete distribution
44///
45/// # Arguments
46/// * `probabilities` - Array of probabilities that should sum to 1.0
47///
48/// # Returns
49/// The Gini impurity score
50///
51/// # Example
52/// ```
53/// use sklears_core::utils::gini_impurity;
54/// use scirs2_core::ndarray::array;
55///
56/// let probs = array![0.5, 0.5];
57/// let gini = gini_impurity(&probs);
58/// assert!((gini - 0.5).abs() < 1e-10);
59/// ```
60pub fn gini_impurity(probabilities: &Array1<Float>) -> Float {
61    1.0 - probabilities.iter().map(|&p| p * p).sum::<Float>()
62}
63
64/// Normalize an array to have zero mean and unit variance
65///
66/// # Arguments
67/// * `array` - Input array to normalize
68///
69/// # Returns
70/// Normalized array with zero mean and unit variance
71///
72/// # Example
73/// ```
74/// use sklears_core::utils::standardize;
75/// use scirs2_core::ndarray::array;
76///
77/// let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
78/// let normalized = standardize(&data);
79/// let mean = normalized.mean().unwrap();
80/// assert!(mean.abs() < 1e-10);
81/// ```
82pub fn standardize(array: &Array1<Float>) -> Array1<Float> {
83    let mean = array.mean().unwrap();
84    let std = array.std(0.0);
85
86    if std > 1e-10 {
87        (array - mean) / std
88    } else {
89        array.clone()
90    }
91}
92
93/// Min-max normalize an array to the range [0, 1]
94///
95/// # Arguments
96/// * `array` - Input array to normalize
97///
98/// # Returns
99/// Normalized array with values in [0, 1]
100///
101/// # Example
102/// ```
103/// use sklears_core::utils::min_max_normalize;
104/// use scirs2_core::ndarray::array;
105///
106/// let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
107/// let normalized = min_max_normalize(&data);
108/// assert!((normalized[[0]] - 0.0).abs() < 1e-10);
109/// assert!((normalized[[4]] - 1.0).abs() < 1e-10);
110/// ```
111pub fn min_max_normalize(array: &Array1<Float>) -> Array1<Float> {
112    let min_val = array.iter().fold(Float::INFINITY, |a, &b| a.min(b));
113    let max_val = array.iter().fold(Float::NEG_INFINITY, |a, &b| a.max(b));
114    let range = max_val - min_val;
115
116    if range > 1e-10 {
117        (array - min_val) / range
118    } else {
119        Array1::zeros(array.len())
120    }
121}
122
123/// Calculate the cosine similarity between two vectors
124///
125/// # Arguments
126/// * `a` - First vector
127/// * `b` - Second vector
128///
129/// # Returns
130/// Cosine similarity value between -1 and 1
131///
132/// # Example
133/// ```
134/// use sklears_core::utils::cosine_similarity;
135/// use scirs2_core::ndarray::array;
136///
137/// let a = array![1.0, 0.0];
138/// let b = array![0.0, 1.0];
139/// let sim = cosine_similarity(&a, &b);
140/// assert!((sim - 0.0).abs() < 1e-10);
141/// ```
142pub fn cosine_similarity(a: &Array1<Float>, b: &Array1<Float>) -> Float {
143    if a.len() != b.len() {
144        panic!("Arrays must have the same length");
145    }
146
147    let dot_product = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum::<Float>();
148    let norm_a = a.iter().map(|&x| x * x).sum::<Float>().sqrt();
149    let norm_b = b.iter().map(|&x| x * x).sum::<Float>().sqrt();
150
151    if norm_a > 1e-10 && norm_b > 1e-10 {
152        dot_product / (norm_a * norm_b)
153    } else {
154        0.0
155    }
156}
157
158/// Calculate Euclidean distance between two points
159///
160/// # Arguments
161/// * `a` - First point
162/// * `b` - Second point
163///
164/// # Returns
165/// Euclidean distance
166///
167/// # Example
168/// ```
169/// use sklears_core::utils::euclidean_distance;
170/// use scirs2_core::ndarray::array;
171///
172/// let a = array![0.0, 0.0];
173/// let b = array![3.0, 4.0];
174/// let dist = euclidean_distance(&a, &b);
175/// assert!((dist - 5.0).abs() < 1e-10);
176/// ```
177pub fn euclidean_distance(a: &Array1<Float>, b: &Array1<Float>) -> Float {
178    if a.len() != b.len() {
179        panic!("Arrays must have the same length");
180    }
181
182    a.iter()
183        .zip(b.iter())
184        .map(|(&x, &y)| (x - y).powi(2))
185        .sum::<Float>()
186        .sqrt()
187}
188
189/// Calculate Manhattan distance between two points
190///
191/// # Arguments
192/// * `a` - First point
193/// * `b` - Second point
194///
195/// # Returns
196/// Manhattan distance
197///
198/// # Example
199/// ```
200/// use sklears_core::utils::manhattan_distance;
201/// use scirs2_core::ndarray::array;
202///
203/// let a = array![0.0, 0.0];
204/// let b = array![3.0, 4.0];
205/// let dist = manhattan_distance(&a, &b);
206/// assert!((dist - 7.0).abs() < 1e-10);
207/// ```
208pub fn manhattan_distance(a: &Array1<Float>, b: &Array1<Float>) -> Float {
209    if a.len() != b.len() {
210        panic!("Arrays must have the same length");
211    }
212
213    a.iter().zip(b.iter()).map(|(&x, &y)| (x - y).abs()).sum()
214}
215
216/// Check if a value is approximately zero within a tolerance
217///
218/// # Arguments
219/// * `value` - Value to check
220/// * `tolerance` - Tolerance level (default: 1e-10)
221///
222/// # Returns
223/// True if the value is within tolerance of zero
224pub fn is_zero(value: Float, tolerance: Option<Float>) -> bool {
225    let tol = tolerance.unwrap_or(1e-10);
226    value.abs() < tol
227}
228
229/// Clamp a value between minimum and maximum bounds
230///
231/// # Arguments
232/// * `value` - Value to clamp
233/// * `min_val` - Minimum bound
234/// * `max_val` - Maximum bound
235///
236/// # Returns
237/// Clamped value
238pub fn clamp(value: Float, min_val: Float, max_val: Float) -> Float {
239    value.clamp(min_val, max_val)
240}
241
/// Calculate the number of combinations (n choose k)
///
/// The binomial coefficient is built up multiplicatively, one factor at a
/// time. Each intermediate product is evaluated in 128-bit arithmetic, so the
/// function returns the exact value whenever that value fits in `usize`,
/// even if a partial product would not.
///
/// # Arguments
/// * `n` - Total number of items
/// * `k` - Number of items to choose
///
/// # Returns
/// Number of combinations; `0` when `k > n`
///
/// # Panics
/// Panics if the binomial coefficient itself does not fit in `usize`.
pub fn combinations(n: usize, k: usize) -> usize {
    if k > n {
        return 0;
    }

    // C(n, k) == C(n, n - k): iterate over the smaller of the two.
    let steps = k.min(n - k);
    let limit = usize::MAX as u128;

    // Invariant: after iteration `i`, `result == C(n, i + 1)`.
    let mut result: usize = 1;
    for i in 0..steps {
        // Both factors are below 2^64, so the product cannot overflow u128.
        // The division is exact because C(n, i) * (n - i) == C(n, i + 1) * (i + 1).
        let numerator = (result as u128) * ((n - i) as u128);
        let next = numerator / ((i as u128) + 1);

        // Partial binomials only grow while i < n / 2, so exceeding the limit
        // here means the final answer cannot fit either.
        if next > limit {
            panic!("combinations({}, {}) overflows usize", n, k);
        }
        result = next as usize;
    }

    result
}
267
268/// Generate samples from a multivariate normal distribution
269///
270/// Generates samples from a multivariate normal distribution with specified mean
271/// and identity covariance matrix (independent components with unit variance).
272///
273/// # Arguments
274/// * `mean` - Mean vector of the distribution
275/// * `n_samples` - Number of samples to generate
276/// * `rng` - Random number generator
277///
278/// # Returns
279/// Array of shape (n_samples, n_features) containing the generated samples
280///
281/// # Example
282/// ```
283/// use sklears_core::utils::multivariate_normal_samples;
284/// use scirs2_core::ndarray::array;
285/// use scirs2_core::random::rngs::StdRng;
286/// use scirs2_core::random::SeedableRng;
287///
288/// let mean = array![0.0, 1.0];
289/// let mut rng = StdRng::seed_from_u64(42);
290/// let samples = multivariate_normal_samples(&mean, 100, &mut rng);
291/// assert_eq!(samples.shape(), &[100, 2]);
292/// ```
293pub fn multivariate_normal_samples<R: scirs2_core::random::Rng>(
294    mean: &Array1<Float>,
295    n_samples: usize,
296    rng: &mut R,
297) -> scirs2_core::ndarray::Array2<Float> {
298    use scirs2_core::ndarray::Array2;
299    use scirs2_core::random::essentials::Normal;
300    use scirs2_core::Distribution;
301
302    let n_features = mean.len();
303    let mut samples = Array2::zeros((n_samples, n_features));
304
305    // Standard normal distribution (mean=0, std=1)
306    let standard_normal =
307        Normal::new(0.0, 1.0).expect("Failed to create standard normal distribution");
308
309    // Generate samples: X = μ + Z where Z ~ N(0, I)
310    for i in 0..n_samples {
311        for j in 0..n_features {
312            let z = standard_normal.sample(rng);
313            samples[(i, j)] = mean[j] + z;
314        }
315    }
316
317    samples
318}
319
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // SciRS2 Policy: Using scirs2_core::ndarray and scirs2_core::random (COMPLIANT)
    use scirs2_core::ndarray::array;

    // Four equally likely outcomes carry exactly two bits.
    #[test]
    fn test_entropy_uniform() {
        let uniform = array![0.25, 0.25, 0.25, 0.25];
        assert!((entropy(&uniform) - 2.0).abs() < 1e-10);
    }

    // A certain outcome carries no information.
    #[test]
    fn test_entropy_certain() {
        let certain = array![1.0, 0.0, 0.0, 0.0];
        assert!(entropy(&certain).abs() < 1e-10);
    }

    // A balanced two-class split has impurity 0.5.
    #[test]
    fn test_gini_impurity_uniform() {
        let balanced = array![0.5, 0.5];
        assert!((gini_impurity(&balanced) - 0.5).abs() < 1e-10);
    }

    // A pure node has impurity 0.
    #[test]
    fn test_gini_impurity_pure() {
        let pure = array![1.0, 0.0];
        assert!(gini_impurity(&pure).abs() < 1e-10);
    }

    // Standardized data has zero mean and unit standard deviation.
    #[test]
    fn test_standardize() {
        let scaled = standardize(&array![1.0, 2.0, 3.0, 4.0, 5.0]);
        let centre = scaled.mean().unwrap();
        let spread = scaled.std(0.0);
        assert!(centre.abs() < 1e-10);
        assert!((spread - 1.0).abs() < 1e-10);
    }

    // The smallest input maps to 0 and the largest to 1.
    #[test]
    fn test_min_max_normalize() {
        let scaled = min_max_normalize(&array![1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!((scaled[0] - 0.0).abs() < 1e-10);
        assert!((scaled[4] - 1.0).abs() < 1e-10);
    }

    // Parallel vectors score 1, orthogonal vectors score 0.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = array![1.0, 0.0];
        let same_direction = array![1.0, 0.0];
        assert!((cosine_similarity(&x_axis, &same_direction) - 1.0).abs() < 1e-10);

        let y_axis = array![0.0, 1.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 1e-10);
    }

    // 3-4-5 right triangle.
    #[test]
    fn test_euclidean_distance() {
        let origin = array![0.0, 0.0];
        let corner = array![3.0, 4.0];
        assert!((euclidean_distance(&origin, &corner) - 5.0).abs() < 1e-10);
    }

    // |3| + |4| = 7.
    #[test]
    fn test_manhattan_distance() {
        let origin = array![0.0, 0.0];
        let corner = array![3.0, 4.0];
        assert!((manhattan_distance(&origin, &corner) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_multivariate_normal_samples() {
        use scirs2_core::random::rngs::StdRng;
        use scirs2_core::random::SeedableRng;

        let mut rng = StdRng::seed_from_u64(42);
        let samples = multivariate_normal_samples(&array![0.0, 1.0], 100, &mut rng);

        // Check shape
        assert_eq!(samples.shape(), &[100, 2]);

        // With 100 samples, means should be within ~0.3 of true means (roughly 3 * std_err)
        let mean_checks = [
            (0, 0.0, "Mean of first component should be close to 0.0"),
            (1, 1.0, "Mean of second component should be close to 1.0"),
        ];
        for &(column, expected, message) in mean_checks.iter() {
            let observed = samples.column(column).mean().unwrap();
            assert!((observed - expected).abs() < 0.3, "{}", message);
        }

        // Check that samples have reasonable variance (should be close to 1.0)
        let std_checks = [
            (0, "Std of first component should be close to 1.0"),
            (1, "Std of second component should be close to 1.0"),
        ];
        for &(column, message) in std_checks.iter() {
            let observed = samples.column(column).std(0.0);
            assert!(observed > 0.7 && observed < 1.3, "{}", message);
        }
    }

    // Default tolerance is 1e-10; an explicit tolerance overrides it.
    #[test]
    fn test_is_zero() {
        for &tiny in [0.0, 1e-12].iter() {
            assert!(is_zero(tiny, None));
        }
        assert!(!is_zero(1e-8, None));
        assert!(is_zero(0.01, Some(0.1)));
    }

    // Inside, below and above the interval.
    #[test]
    fn test_clamp() {
        let cases = [(5.0, 5.0), (-1.0, 0.0), (15.0, 10.0)];
        for &(input, expected) in cases.iter() {
            assert_eq!(clamp(input, 0.0, 10.0), expected);
        }
    }

    // Each case is (n, k, expected n-choose-k).
    #[test]
    fn test_combinations() {
        let cases = [
            (5, 0, 1),
            (5, 1, 5),
            (5, 2, 10),
            (5, 3, 10),
            (5, 5, 1),
            (3, 5, 0),
        ];
        for &(n, k, expected) in cases.iter() {
            assert_eq!(combinations(n, k), expected);
        }
    }

    #[test]
    fn test_generate_random_seed() {
        let first = generate_random_seed();
        std::thread::sleep(std::time::Duration::from_nanos(1000));
        let second = generate_random_seed();
        // Seeds should be different (extremely unlikely to be the same)
        assert_ne!(first, second);
    }
}