Skip to main content

scirs2_sparse/ml_preconditioner/
classifier.rs

1//! Classifiers for preconditioner selection.
2//!
3//! Provides both a rule-based heuristic classifier (requires no training) and
4//! a simple random-forest classifier that can be trained on labelled data.
5
6use crate::error::{SparseError, SparseResult};
7
8use super::cost_model;
9use super::feature_extraction::{extract_features, normalize_features};
10use super::types::{
11    CostEstimate, MatrixFeatures, PreconditionerType, SelectionConfig, SelectionResult,
12};
13
14// ============================================================
15// Decision stump
16// ============================================================
17
/// A single axis-aligned split.
#[derive(Debug, Clone)]
pub struct DecisionStump {
    /// Index into the feature vector.
    pub feature_idx: usize,
    /// Split threshold.
    pub threshold: f64,
    /// Class label when feature < threshold.
    pub left_class: usize,
    /// Class label when feature >= threshold.
    pub right_class: usize,
}

impl DecisionStump {
    /// Predict the class label for a feature vector.
    ///
    /// An out-of-range `feature_idx` is treated as a feature value of 0.0
    /// rather than panicking.
    pub fn predict(&self, features: &[f64]) -> usize {
        let value = match features.get(self.feature_idx) {
            Some(&v) => v,
            None => 0.0,
        };
        // NaN compares false, so NaN values fall through to the right class,
        // matching the original `<` comparison.
        if value < self.threshold {
            self.left_class
        } else {
            self.right_class
        }
    }
}
42
43// ============================================================
44// Decision tree
45// ============================================================
46
/// A binary decision tree of bounded depth.
///
/// Built greedily by [`DecisionTree::train`] from labelled feature vectors
/// and evaluated by walking the splits with [`DecisionTree::predict`].
#[derive(Debug, Clone)]
pub enum DecisionTree {
    /// Leaf node with a class label.
    Leaf(usize),
    /// Internal split node.
    Split {
        /// The decision stump at this node.
        stump: DecisionStump,
        /// Left subtree (feature < threshold).
        left: Box<DecisionTree>,
        /// Right subtree (feature >= threshold).
        right: Box<DecisionTree>,
    },
}
62
63impl DecisionTree {
64    /// Build a decision tree from labelled data with a maximum depth.
65    pub fn train(features: &[Vec<f64>], labels: &[usize], max_depth: usize) -> Self {
66        Self::build(features, labels, max_depth, 0)
67    }
68
69    fn build(features: &[Vec<f64>], labels: &[usize], max_depth: usize, depth: usize) -> Self {
70        if labels.is_empty() {
71            return Self::Leaf(0);
72        }
73
74        // Check if all labels are the same
75        let first = labels[0];
76        if labels.iter().all(|&l| l == first) || depth >= max_depth || features.is_empty() {
77            return Self::Leaf(majority_class(labels));
78        }
79
80        let n_features = features.first().map_or(0, |f| f.len());
81        if n_features == 0 {
82            return Self::Leaf(majority_class(labels));
83        }
84
85        // Find best split via Gini impurity reduction
86        let mut best_gini = f64::INFINITY;
87        let mut best_stump = DecisionStump {
88            feature_idx: 0,
89            threshold: 0.0,
90            left_class: 0,
91            right_class: 0,
92        };
93        let mut best_left_idx: Vec<usize> = Vec::new();
94        let mut best_right_idx: Vec<usize> = Vec::new();
95
96        for feat in 0..n_features {
97            // Collect unique thresholds (midpoints of sorted values)
98            let mut vals: Vec<f64> = features.iter().map(|f| f[feat]).collect();
99            vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
100            vals.dedup();
101
102            for window in vals.windows(2) {
103                let threshold = (window[0] + window[1]) / 2.0;
104                let mut left_labels = Vec::new();
105                let mut right_labels = Vec::new();
106                let mut left_idx = Vec::new();
107                let mut right_idx = Vec::new();
108
109                for (i, f) in features.iter().enumerate() {
110                    if f[feat] < threshold {
111                        left_labels.push(labels[i]);
112                        left_idx.push(i);
113                    } else {
114                        right_labels.push(labels[i]);
115                        right_idx.push(i);
116                    }
117                }
118
119                if left_labels.is_empty() || right_labels.is_empty() {
120                    continue;
121                }
122
123                let n_total = labels.len() as f64;
124                let gini = (left_labels.len() as f64 / n_total) * gini_impurity(&left_labels)
125                    + (right_labels.len() as f64 / n_total) * gini_impurity(&right_labels);
126
127                if gini < best_gini {
128                    best_gini = gini;
129                    best_stump = DecisionStump {
130                        feature_idx: feat,
131                        threshold,
132                        left_class: majority_class(&left_labels),
133                        right_class: majority_class(&right_labels),
134                    };
135                    best_left_idx = left_idx;
136                    best_right_idx = right_idx;
137                }
138            }
139        }
140
141        if best_left_idx.is_empty() || best_right_idx.is_empty() {
142            return Self::Leaf(majority_class(labels));
143        }
144
145        let left_features: Vec<Vec<f64>> =
146            best_left_idx.iter().map(|&i| features[i].clone()).collect();
147        let left_labels: Vec<usize> = best_left_idx.iter().map(|&i| labels[i]).collect();
148        let right_features: Vec<Vec<f64>> = best_right_idx
149            .iter()
150            .map(|&i| features[i].clone())
151            .collect();
152        let right_labels: Vec<usize> = best_right_idx.iter().map(|&i| labels[i]).collect();
153
154        Self::Split {
155            stump: best_stump,
156            left: Box::new(Self::build(
157                &left_features,
158                &left_labels,
159                max_depth,
160                depth + 1,
161            )),
162            right: Box::new(Self::build(
163                &right_features,
164                &right_labels,
165                max_depth,
166                depth + 1,
167            )),
168        }
169    }
170
171    /// Predict the class for a single feature vector.
172    pub fn predict(&self, features: &[f64]) -> usize {
173        match self {
174            Self::Leaf(label) => *label,
175            Self::Split { stump, left, right } => {
176                if stump.predict(features) == stump.left_class {
177                    left.predict(features)
178                } else {
179                    right.predict(features)
180                }
181            }
182        }
183    }
184}
185
186// ============================================================
187// Random forest
188// ============================================================
189
190/// A bagged ensemble of decision trees.
191#[derive(Debug, Clone)]
192pub struct RandomForest {
193    /// Individual trees in the ensemble.
194    pub trees: Vec<DecisionTree>,
195    /// Number of distinct class labels.
196    pub n_classes: usize,
197}
198
199impl RandomForest {
200    /// Train a random forest on labelled feature data.
201    ///
202    /// Uses bootstrap sampling (simple cyclic resampling for determinism in
203    /// the absence of an RNG) and builds each tree with max_depth = 5.
204    pub fn train(features: &[Vec<f64>], labels: &[usize], n_trees: usize) -> Self {
205        let n_classes = labels.iter().copied().max().map_or(0, |m| m + 1);
206        let n_samples = features.len();
207        let mut trees = Vec::with_capacity(n_trees);
208
209        for t in 0..n_trees {
210            // Deterministic bootstrap: offset + stride
211            let offset = t % n_samples.max(1);
212            let bag_size = n_samples;
213            let mut bag_features = Vec::with_capacity(bag_size);
214            let mut bag_labels = Vec::with_capacity(bag_size);
215            for i in 0..bag_size {
216                let idx = (offset + i * (t + 1)) % n_samples.max(1);
217                if idx < n_samples {
218                    bag_features.push(features[idx].clone());
219                    bag_labels.push(labels[idx]);
220                }
221            }
222
223            let tree = DecisionTree::train(&bag_features, &bag_labels, 5);
224            trees.push(tree);
225        }
226
227        Self { trees, n_classes }
228    }
229
230    /// Predict the class for a feature vector using majority vote.
231    pub fn predict(&self, features: &[f64]) -> usize {
232        if self.trees.is_empty() {
233            return 0;
234        }
235        let mut votes = vec![0usize; self.n_classes.max(1)];
236        for tree in &self.trees {
237            let pred = tree.predict(features);
238            if pred < votes.len() {
239                votes[pred] += 1;
240            }
241        }
242        votes
243            .iter()
244            .enumerate()
245            .max_by_key(|&(_, &count)| count)
246            .map_or(0, |(idx, _)| idx)
247    }
248}
249
250// ============================================================
251// Heuristic classifier
252// ============================================================
253
254/// Rule-based heuristic classifier that requires no training data.
255///
256/// Uses structural and numerical properties of the matrix to pick
257/// a preconditioner via human-expert rules.
258#[derive(Debug, Clone, Default)]
259pub struct HeuristicClassifier;
260
261impl HeuristicClassifier {
262    /// Select a preconditioner type based on matrix features.
263    pub fn predict(&self, features: &MatrixFeatures) -> PreconditionerType {
264        let is_diag_dominant = features.diag_dominance >= 1.0;
265        let is_symmetric = features.symmetry_measure > 0.95;
266        let is_small = features.n <= 500;
267        let is_dense = features.density > 0.1;
268        let is_large = features.n > 10_000;
269        let is_spd_like = is_diag_dominant && features.has_positive_diagonal && is_symmetric;
270
271        if is_small && is_dense {
272            return PreconditionerType::None;
273        }
274        if is_spd_like {
275            return PreconditionerType::IC0;
276        }
277        if is_diag_dominant && is_symmetric {
278            return PreconditionerType::SSOR;
279        }
280        if is_diag_dominant {
281            return PreconditionerType::Jacobi;
282        }
283        if is_large {
284            return PreconditionerType::AMG;
285        }
286        PreconditionerType::ILU0
287    }
288}
289
290// ============================================================
291// Unified classifier wrapper
292// ============================================================
293
/// Unified classifier that delegates to either a random forest or the
/// heuristic rule set.
///
/// The `Default` implementation yields the heuristic variant, which needs
/// no training data.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum PreconditionerClassifier {
    /// Learned random-forest classifier.
    Forest(RandomForest),
    /// Rule-based heuristic classifier.
    Heuristic(HeuristicClassifier),
}
304
305impl Default for PreconditionerClassifier {
306    fn default() -> Self {
307        Self::Heuristic(HeuristicClassifier)
308    }
309}
310
311impl PreconditionerClassifier {
312    /// Map a class index to a `PreconditionerType`.
313    fn class_to_type(idx: usize) -> PreconditionerType {
314        match idx {
315            0 => PreconditionerType::Jacobi,
316            1 => PreconditionerType::SSOR,
317            2 => PreconditionerType::ILU0,
318            3 => PreconditionerType::IC0,
319            4 => PreconditionerType::AMG,
320            5 => PreconditionerType::SPAI,
321            6 => PreconditionerType::Polynomial,
322            7 => PreconditionerType::None,
323            #[allow(unreachable_patterns)]
324            _ => PreconditionerType::ILU0,
325        }
326    }
327
328    /// Predict the preconditioner type.
329    pub fn predict(&self, features: &MatrixFeatures) -> PreconditionerType {
330        match self {
331            Self::Forest(rf) => {
332                let fv = normalize_features(features);
333                Self::class_to_type(rf.predict(&fv))
334            }
335            Self::Heuristic(h) => h.predict(features),
336            #[allow(unreachable_patterns)]
337            _ => PreconditionerType::ILU0,
338        }
339    }
340}
341
342// ============================================================
343// Top-level selection API
344// ============================================================
345
346/// Select the best preconditioner for a sparse matrix given as raw CSR data.
347///
348/// This is the main entry point. It extracts features, classifies, and
349/// optionally re-ranks candidates by estimated cost.
350pub fn select_preconditioner(
351    values: &[f64],
352    row_ptr: &[usize],
353    col_idx: &[usize],
354    n: usize,
355    config: &SelectionConfig,
356) -> SparseResult<SelectionResult> {
357    let features = extract_features(values, row_ptr, col_idx, n)?;
358
359    let classifier = PreconditionerClassifier::default();
360    let recommended = classifier.predict(&features);
361
362    // Build scored candidate list
363    let candidates = [
364        PreconditionerType::Jacobi,
365        PreconditionerType::SSOR,
366        PreconditionerType::ILU0,
367        PreconditionerType::IC0,
368        PreconditionerType::AMG,
369        PreconditionerType::SPAI,
370        PreconditionerType::Polynomial,
371        PreconditionerType::None,
372    ];
373
374    let mut all_scores: Vec<(PreconditionerType, f64)> = if config.use_cost_model {
375        let ranked = cost_model::rank_by_cost(&features, &candidates);
376        // Invert cost to get a score (lower cost → higher score)
377        let max_cost = ranked
378            .iter()
379            .map(|(_, c)| c.total_cost)
380            .fold(0.0_f64, f64::max);
381        let scale = if max_cost > 1e-30 { max_cost } else { 1.0 };
382        ranked
383            .iter()
384            .map(|(pt, c)| (*pt, 1.0 - c.total_cost / scale))
385            .collect()
386    } else {
387        candidates
388            .iter()
389            .map(|&pt| {
390                let score = if pt == recommended { 1.0 } else { 0.0 };
391                (pt, score)
392            })
393            .collect()
394    };
395
396    // Ensure recommended type gets a bonus
397    for entry in &mut all_scores {
398        if entry.0 == recommended {
399            entry.1 += 0.5;
400        }
401    }
402
403    // Sort descending by score
404    all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
405
406    // Confidence based on score gap between #1 and #2
407    let confidence = if all_scores.len() >= 2 {
408        let gap = all_scores[0].1 - all_scores[1].1;
409        (gap / (all_scores[0].1.abs() + 1e-10)).clamp(0.0, 1.0)
410    } else {
411        1.0
412    };
413
414    Ok(SelectionResult {
415        recommended,
416        confidence,
417        all_scores,
418        features,
419    })
420}
421
422// ============================================================
423// Helpers
424// ============================================================
425
/// Most frequent label in `labels`; ties resolve to the larger label.
/// Returns 0 for an empty slice.
fn majority_class(labels: &[usize]) -> usize {
    if labels.is_empty() {
        return 0;
    }
    let max_label = labels.iter().copied().max().unwrap_or(0);
    let mut counts = vec![0usize; max_label + 1];
    for &label in labels {
        counts[label] += 1;
    }
    // `>=` keeps the last maximum, matching `max_by_key` tie-breaking.
    let mut winner = 0;
    for (label, &count) in counts.iter().enumerate() {
        if count >= counts[winner] {
            winner = label;
        }
    }
    winner
}
441
/// Gini impurity of a label multiset: `1 - sum_k p_k^2`.
/// An empty slice is treated as pure and returns 0.0.
fn gini_impurity(labels: &[usize]) -> f64 {
    let len = labels.len();
    if len == 0 {
        return 0.0;
    }
    let max_label = labels.iter().copied().max().unwrap_or(0);
    let mut counts = vec![0usize; max_label + 1];
    for &label in labels {
        counts[label] += 1;
    }
    let total = len as f64;
    let mut sum_sq = 0.0;
    for &count in &counts {
        let p = count as f64 / total;
        sum_sq += p * p;
    }
    1.0 - sum_sq
}
455
#[cfg(test)]
mod tests {
    use super::*;

    // Diagonally dominant + symmetric + positive diagonal => SPD-like => IC(0).
    #[test]
    fn test_heuristic_diag_dominant_symmetric_spd() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 1000,
            nnz: 5000,
            density: 0.005,
            max_row_nnz: 5,
            mean_row_nnz: 5.0,
            bandwidth: 2,
            bandwidth_ratio: 0.002,
            cond_estimate: 10.0,
            spectral_radius: 6.0,
            diag_dominance: 2.0,
            symmetry_measure: 1.0,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::IC0);
    }

    // Diagonally dominant but nonsymmetric (symmetry 0.3 <= 0.95) => Jacobi.
    #[test]
    fn test_heuristic_diag_dominant_nonsymmetric() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 1000,
            nnz: 5000,
            density: 0.005,
            max_row_nnz: 5,
            mean_row_nnz: 5.0,
            bandwidth: 2,
            bandwidth_ratio: 0.002,
            cond_estimate: 10.0,
            spectral_radius: 6.0,
            diag_dominance: 2.0,
            symmetry_measure: 0.3,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::Jacobi);
    }

    // n <= 500 with density > 0.1 triggers the "no preconditioner" rule
    // before any other rule is considered.
    #[test]
    fn test_heuristic_small_dense() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 50,
            nnz: 500,
            density: 0.2,
            max_row_nnz: 20,
            mean_row_nnz: 10.0,
            bandwidth: 49,
            bandwidth_ratio: 1.0,
            cond_estimate: 5.0,
            spectral_radius: 10.0,
            diag_dominance: 0.5,
            symmetry_measure: 0.8,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::None);
    }

    // Not diagonally dominant, n > 10_000 => AMG branch.
    #[test]
    fn test_heuristic_large_sparse() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 100_000,
            nnz: 500_000,
            density: 0.00005,
            max_row_nnz: 7,
            mean_row_nnz: 5.0,
            bandwidth: 1000,
            bandwidth_ratio: 0.01,
            cond_estimate: 1000.0,
            spectral_radius: 100.0,
            diag_dominance: 0.5,
            symmetry_measure: 0.5,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::AMG);
    }

    // No rule fires (mid-size, sparse, not dominant) => ILU(0) fallback.
    #[test]
    fn test_heuristic_general() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 2000,
            nnz: 20_000,
            density: 0.005,
            max_row_nnz: 15,
            mean_row_nnz: 10.0,
            bandwidth: 200,
            bandwidth_ratio: 0.1,
            cond_estimate: 100.0,
            spectral_radius: 50.0,
            diag_dominance: 0.3,
            symmetry_measure: 0.6,
            has_positive_diagonal: false,
        };
        assert_eq!(h.predict(&features), PreconditionerType::ILU0);
    }

    // End-to-end smoke test through feature extraction + classification.
    #[test]
    fn test_select_preconditioner_tridiag() {
        // 3×3 tridiag SPD
        let values = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let col_idx = vec![0, 1, 0, 1, 2, 1, 2];
        let row_ptr = vec![0, 2, 5, 7];
        let config = SelectionConfig::default();
        let result =
            select_preconditioner(&values, &row_ptr, &col_idx, 3, &config).expect("select");
        // Small dense → should recommend None
        assert_eq!(result.recommended, PreconditionerType::None);
        assert!(!result.all_scores.is_empty());
    }

    // Uniform labels collapse to a single leaf regardless of depth budget.
    #[test]
    fn test_decision_tree_pure_leaf() {
        let features = vec![vec![1.0], vec![2.0], vec![3.0]];
        let labels = vec![0, 0, 0];
        let tree = DecisionTree::train(&features, &labels, 3);
        assert_eq!(tree.predict(&[1.5]), 0);
    }

    // Forest training is deterministic (no RNG), so this cannot be flaky.
    #[test]
    fn test_random_forest_simple() {
        let features = vec![
            vec![0.1, 0.2],
            vec![0.9, 0.8],
            vec![0.15, 0.25],
            vec![0.85, 0.75],
        ];
        let labels = vec![0, 1, 0, 1];
        let rf = RandomForest::train(&features, &labels, 5);
        // Predictions should be consistent for training data
        let pred0 = rf.predict(&[0.1, 0.2]);
        let pred1 = rf.predict(&[0.9, 0.8]);
        // At minimum, check they're valid class indices
        assert!(pred0 < 2);
        assert!(pred1 < 2);
    }

    #[test]
    fn test_classifier_default_is_heuristic() {
        let c = PreconditionerClassifier::default();
        match c {
            PreconditionerClassifier::Heuristic(_) => {}
            _ => panic!("default should be heuristic"),
        }
    }
}