1use crate::error::{SparseError, SparseResult};
7
8use super::cost_model;
9use super::feature_extraction::{extract_features, normalize_features};
10use super::types::{
11 CostEstimate, MatrixFeatures, PreconditionerType, SelectionConfig, SelectionResult,
12};
13
/// A single axis-aligned decision rule: compare one feature against a
/// threshold and emit one of two class labels.
#[derive(Debug, Clone)]
pub struct DecisionStump {
    /// Index into the feature vector to test.
    pub feature_idx: usize,
    /// Split point: values strictly below go left, the rest go right.
    pub threshold: f64,
    /// Class emitted when the feature value is below the threshold.
    pub left_class: usize,
    /// Class emitted when the feature value is at or above the threshold.
    pub right_class: usize,
}

impl DecisionStump {
    /// Applies the rule to one feature vector. A missing feature
    /// (index out of bounds) is treated as 0.0.
    pub fn predict(&self, features: &[f64]) -> usize {
        let value = match features.get(self.feature_idx) {
            Some(&v) => v,
            None => 0.0,
        };
        if value < self.threshold {
            self.left_class
        } else {
            self.right_class
        }
    }
}
42
/// A binary decision tree: either a terminal class label or an internal
/// split that delegates to two subtrees.
#[derive(Debug, Clone)]
pub enum DecisionTree {
    /// Terminal node holding the predicted class index.
    Leaf(usize),
    /// Internal node; the stump decides which child handles a sample.
    Split {
        /// Axis-aligned test (feature index + threshold) for this node.
        stump: DecisionStump,
        /// Subtree for samples whose tested feature is below the threshold.
        left: Box<DecisionTree>,
        /// Subtree for samples at or above the threshold.
        right: Box<DecisionTree>,
    },
}
62
63impl DecisionTree {
64 pub fn train(features: &[Vec<f64>], labels: &[usize], max_depth: usize) -> Self {
66 Self::build(features, labels, max_depth, 0)
67 }
68
69 fn build(features: &[Vec<f64>], labels: &[usize], max_depth: usize, depth: usize) -> Self {
70 if labels.is_empty() {
71 return Self::Leaf(0);
72 }
73
74 let first = labels[0];
76 if labels.iter().all(|&l| l == first) || depth >= max_depth || features.is_empty() {
77 return Self::Leaf(majority_class(labels));
78 }
79
80 let n_features = features.first().map_or(0, |f| f.len());
81 if n_features == 0 {
82 return Self::Leaf(majority_class(labels));
83 }
84
85 let mut best_gini = f64::INFINITY;
87 let mut best_stump = DecisionStump {
88 feature_idx: 0,
89 threshold: 0.0,
90 left_class: 0,
91 right_class: 0,
92 };
93 let mut best_left_idx: Vec<usize> = Vec::new();
94 let mut best_right_idx: Vec<usize> = Vec::new();
95
96 for feat in 0..n_features {
97 let mut vals: Vec<f64> = features.iter().map(|f| f[feat]).collect();
99 vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
100 vals.dedup();
101
102 for window in vals.windows(2) {
103 let threshold = (window[0] + window[1]) / 2.0;
104 let mut left_labels = Vec::new();
105 let mut right_labels = Vec::new();
106 let mut left_idx = Vec::new();
107 let mut right_idx = Vec::new();
108
109 for (i, f) in features.iter().enumerate() {
110 if f[feat] < threshold {
111 left_labels.push(labels[i]);
112 left_idx.push(i);
113 } else {
114 right_labels.push(labels[i]);
115 right_idx.push(i);
116 }
117 }
118
119 if left_labels.is_empty() || right_labels.is_empty() {
120 continue;
121 }
122
123 let n_total = labels.len() as f64;
124 let gini = (left_labels.len() as f64 / n_total) * gini_impurity(&left_labels)
125 + (right_labels.len() as f64 / n_total) * gini_impurity(&right_labels);
126
127 if gini < best_gini {
128 best_gini = gini;
129 best_stump = DecisionStump {
130 feature_idx: feat,
131 threshold,
132 left_class: majority_class(&left_labels),
133 right_class: majority_class(&right_labels),
134 };
135 best_left_idx = left_idx;
136 best_right_idx = right_idx;
137 }
138 }
139 }
140
141 if best_left_idx.is_empty() || best_right_idx.is_empty() {
142 return Self::Leaf(majority_class(labels));
143 }
144
145 let left_features: Vec<Vec<f64>> =
146 best_left_idx.iter().map(|&i| features[i].clone()).collect();
147 let left_labels: Vec<usize> = best_left_idx.iter().map(|&i| labels[i]).collect();
148 let right_features: Vec<Vec<f64>> = best_right_idx
149 .iter()
150 .map(|&i| features[i].clone())
151 .collect();
152 let right_labels: Vec<usize> = best_right_idx.iter().map(|&i| labels[i]).collect();
153
154 Self::Split {
155 stump: best_stump,
156 left: Box::new(Self::build(
157 &left_features,
158 &left_labels,
159 max_depth,
160 depth + 1,
161 )),
162 right: Box::new(Self::build(
163 &right_features,
164 &right_labels,
165 max_depth,
166 depth + 1,
167 )),
168 }
169 }
170
171 pub fn predict(&self, features: &[f64]) -> usize {
173 match self {
174 Self::Leaf(label) => *label,
175 Self::Split { stump, left, right } => {
176 if stump.predict(features) == stump.left_class {
177 left.predict(features)
178 } else {
179 right.predict(features)
180 }
181 }
182 }
183 }
184}
185
/// An ensemble of decision trees combined by majority vote.
#[derive(Debug, Clone)]
pub struct RandomForest {
    /// The trained trees; each casts one vote during prediction.
    pub trees: Vec<DecisionTree>,
    /// Number of distinct classes (max training label + 1).
    pub n_classes: usize,
}
198
199impl RandomForest {
200 pub fn train(features: &[Vec<f64>], labels: &[usize], n_trees: usize) -> Self {
205 let n_classes = labels.iter().copied().max().map_or(0, |m| m + 1);
206 let n_samples = features.len();
207 let mut trees = Vec::with_capacity(n_trees);
208
209 for t in 0..n_trees {
210 let offset = t % n_samples.max(1);
212 let bag_size = n_samples;
213 let mut bag_features = Vec::with_capacity(bag_size);
214 let mut bag_labels = Vec::with_capacity(bag_size);
215 for i in 0..bag_size {
216 let idx = (offset + i * (t + 1)) % n_samples.max(1);
217 if idx < n_samples {
218 bag_features.push(features[idx].clone());
219 bag_labels.push(labels[idx]);
220 }
221 }
222
223 let tree = DecisionTree::train(&bag_features, &bag_labels, 5);
224 trees.push(tree);
225 }
226
227 Self { trees, n_classes }
228 }
229
230 pub fn predict(&self, features: &[f64]) -> usize {
232 if self.trees.is_empty() {
233 return 0;
234 }
235 let mut votes = vec![0usize; self.n_classes.max(1)];
236 for tree in &self.trees {
237 let pred = tree.predict(features);
238 if pred < votes.len() {
239 votes[pred] += 1;
240 }
241 }
242 votes
243 .iter()
244 .enumerate()
245 .max_by_key(|&(_, &count)| count)
246 .map_or(0, |(idx, _)| idx)
247 }
248}
249
/// Rule-based fallback classifier: maps matrix features to a
/// preconditioner via fixed thresholds, with no training required.
#[derive(Debug, Clone, Default)]
pub struct HeuristicClassifier;
260
261impl HeuristicClassifier {
262 pub fn predict(&self, features: &MatrixFeatures) -> PreconditionerType {
264 let is_diag_dominant = features.diag_dominance >= 1.0;
265 let is_symmetric = features.symmetry_measure > 0.95;
266 let is_small = features.n <= 500;
267 let is_dense = features.density > 0.1;
268 let is_large = features.n > 10_000;
269 let is_spd_like = is_diag_dominant && features.has_positive_diagonal && is_symmetric;
270
271 if is_small && is_dense {
272 return PreconditionerType::None;
273 }
274 if is_spd_like {
275 return PreconditionerType::IC0;
276 }
277 if is_diag_dominant && is_symmetric {
278 return PreconditionerType::SSOR;
279 }
280 if is_diag_dominant {
281 return PreconditionerType::Jacobi;
282 }
283 if is_large {
284 return PreconditionerType::AMG;
285 }
286 PreconditionerType::ILU0
287 }
288}
289
/// Strategy used to choose a preconditioner from matrix features: either
/// a trained random forest or the rule-based heuristic.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum PreconditionerClassifier {
    /// Trained random-forest model operating on normalized feature vectors.
    Forest(RandomForest),
    /// Threshold-based heuristic requiring no training data.
    Heuristic(HeuristicClassifier),
}
304
impl Default for PreconditionerClassifier {
    /// Defaults to the heuristic classifier, which needs no training data.
    fn default() -> Self {
        Self::Heuristic(HeuristicClassifier)
    }
}
310
311impl PreconditionerClassifier {
312 fn class_to_type(idx: usize) -> PreconditionerType {
314 match idx {
315 0 => PreconditionerType::Jacobi,
316 1 => PreconditionerType::SSOR,
317 2 => PreconditionerType::ILU0,
318 3 => PreconditionerType::IC0,
319 4 => PreconditionerType::AMG,
320 5 => PreconditionerType::SPAI,
321 6 => PreconditionerType::Polynomial,
322 7 => PreconditionerType::None,
323 #[allow(unreachable_patterns)]
324 _ => PreconditionerType::ILU0,
325 }
326 }
327
328 pub fn predict(&self, features: &MatrixFeatures) -> PreconditionerType {
330 match self {
331 Self::Forest(rf) => {
332 let fv = normalize_features(features);
333 Self::class_to_type(rf.predict(&fv))
334 }
335 Self::Heuristic(h) => h.predict(features),
336 #[allow(unreachable_patterns)]
337 _ => PreconditionerType::ILU0,
338 }
339 }
340}
341
342pub fn select_preconditioner(
351 values: &[f64],
352 row_ptr: &[usize],
353 col_idx: &[usize],
354 n: usize,
355 config: &SelectionConfig,
356) -> SparseResult<SelectionResult> {
357 let features = extract_features(values, row_ptr, col_idx, n)?;
358
359 let classifier = PreconditionerClassifier::default();
360 let recommended = classifier.predict(&features);
361
362 let candidates = [
364 PreconditionerType::Jacobi,
365 PreconditionerType::SSOR,
366 PreconditionerType::ILU0,
367 PreconditionerType::IC0,
368 PreconditionerType::AMG,
369 PreconditionerType::SPAI,
370 PreconditionerType::Polynomial,
371 PreconditionerType::None,
372 ];
373
374 let mut all_scores: Vec<(PreconditionerType, f64)> = if config.use_cost_model {
375 let ranked = cost_model::rank_by_cost(&features, &candidates);
376 let max_cost = ranked
378 .iter()
379 .map(|(_, c)| c.total_cost)
380 .fold(0.0_f64, f64::max);
381 let scale = if max_cost > 1e-30 { max_cost } else { 1.0 };
382 ranked
383 .iter()
384 .map(|(pt, c)| (*pt, 1.0 - c.total_cost / scale))
385 .collect()
386 } else {
387 candidates
388 .iter()
389 .map(|&pt| {
390 let score = if pt == recommended { 1.0 } else { 0.0 };
391 (pt, score)
392 })
393 .collect()
394 };
395
396 for entry in &mut all_scores {
398 if entry.0 == recommended {
399 entry.1 += 0.5;
400 }
401 }
402
403 all_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
405
406 let confidence = if all_scores.len() >= 2 {
408 let gap = all_scores[0].1 - all_scores[1].1;
409 (gap / (all_scores[0].1.abs() + 1e-10)).clamp(0.0, 1.0)
410 } else {
411 1.0
412 };
413
414 Ok(SelectionResult {
415 recommended,
416 confidence,
417 all_scores,
418 features,
419 })
420}
421
/// Returns the most frequent label, or 0 for an empty slice.
/// Ties resolve to the largest label value, because
/// `Iterator::max_by_key` keeps the last maximum it sees.
fn majority_class(labels: &[usize]) -> usize {
    if labels.is_empty() {
        return 0;
    }
    let max_label = labels.iter().copied().max().unwrap_or(0);
    let mut histogram = vec![0usize; max_label + 1];
    labels.iter().for_each(|&label| histogram[label] += 1);
    histogram
        .iter()
        .enumerate()
        .max_by_key(|&(_, &count)| count)
        .map_or(0, |(label, _)| label)
}
441
/// Gini impurity of a label multiset: `1 - sum_k p_k^2`, where `p_k` is
/// the empirical frequency of class `k`. Empty input yields 0.0 (pure).
fn gini_impurity(labels: &[usize]) -> f64 {
    if labels.is_empty() {
        return 0.0;
    }
    let max_label = labels.iter().copied().max().unwrap_or(0);
    let mut histogram = vec![0usize; max_label + 1];
    for &label in labels {
        histogram[label] += 1;
    }
    let total = labels.len() as f64;
    let purity: f64 = histogram
        .iter()
        .map(|&count| {
            let p = count as f64 / total;
            p * p
        })
        .sum();
    1.0 - purity
}
455
#[cfg(test)]
mod tests {
    use super::*;

    // Diagonally dominant + positive diagonal + fully symmetric matrix
    // should be treated as SPD-like and get incomplete Cholesky.
    #[test]
    fn test_heuristic_diag_dominant_symmetric_spd() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 1000,
            nnz: 5000,
            density: 0.005,
            max_row_nnz: 5,
            mean_row_nnz: 5.0,
            bandwidth: 2,
            bandwidth_ratio: 0.002,
            cond_estimate: 10.0,
            spectral_radius: 6.0,
            diag_dominance: 2.0,
            symmetry_measure: 1.0,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::IC0);
    }

    // Diagonally dominant but non-symmetric (symmetry 0.3 <= 0.95):
    // the symmetric branches are skipped, leaving Jacobi.
    #[test]
    fn test_heuristic_diag_dominant_nonsymmetric() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 1000,
            nnz: 5000,
            density: 0.005,
            max_row_nnz: 5,
            mean_row_nnz: 5.0,
            bandwidth: 2,
            bandwidth_ratio: 0.002,
            cond_estimate: 10.0,
            spectral_radius: 6.0,
            diag_dominance: 2.0,
            symmetry_measure: 0.3,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::Jacobi);
    }

    // Small (n <= 500) and dense (density > 0.1) matrices need no
    // preconditioner; this rule has top priority.
    #[test]
    fn test_heuristic_small_dense() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 50,
            nnz: 500,
            density: 0.2,
            max_row_nnz: 20,
            mean_row_nnz: 10.0,
            bandwidth: 49,
            bandwidth_ratio: 1.0,
            cond_estimate: 5.0,
            spectral_radius: 10.0,
            diag_dominance: 0.5,
            symmetry_measure: 0.8,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::None);
    }

    // Large (n > 10_000), non-dominant, non-symmetric: falls through to
    // the AMG branch.
    #[test]
    fn test_heuristic_large_sparse() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 100_000,
            nnz: 500_000,
            density: 0.00005,
            max_row_nnz: 7,
            mean_row_nnz: 5.0,
            bandwidth: 1000,
            bandwidth_ratio: 0.01,
            cond_estimate: 1000.0,
            spectral_radius: 100.0,
            diag_dominance: 0.5,
            symmetry_measure: 0.5,
            has_positive_diagonal: true,
        };
        assert_eq!(h.predict(&features), PreconditionerType::AMG);
    }

    // No rule matches (not small/dense, not dominant, not large):
    // default branch returns ILU(0).
    #[test]
    fn test_heuristic_general() {
        let h = HeuristicClassifier;
        let features = MatrixFeatures {
            n: 2000,
            nnz: 20_000,
            density: 0.005,
            max_row_nnz: 15,
            mean_row_nnz: 10.0,
            bandwidth: 200,
            bandwidth_ratio: 0.1,
            cond_estimate: 100.0,
            spectral_radius: 50.0,
            diag_dominance: 0.3,
            symmetry_measure: 0.6,
            has_positive_diagonal: false,
        };
        assert_eq!(h.predict(&features), PreconditionerType::ILU0);
    }

    // End-to-end on a 3x3 CSR tridiagonal matrix. n <= 500 and a high
    // density (7 nonzeros in 9 entries) trigger the small-dense rule,
    // so "None" is expected; the score list must still be populated.
    #[test]
    fn test_select_preconditioner_tridiag() {
        let values = vec![4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0];
        let col_idx = vec![0, 1, 0, 1, 2, 1, 2];
        let row_ptr = vec![0, 2, 5, 7];
        let config = SelectionConfig::default();
        let result =
            select_preconditioner(&values, &row_ptr, &col_idx, 3, &config).expect("select");
        assert_eq!(result.recommended, PreconditionerType::None);
        assert!(!result.all_scores.is_empty());
    }

    // All labels identical: training must yield a pure leaf that
    // predicts that label for any input.
    #[test]
    fn test_decision_tree_pure_leaf() {
        let features = vec![vec![1.0], vec![2.0], vec![3.0]];
        let labels = vec![0, 0, 0];
        let tree = DecisionTree::train(&features, &labels, 3);
        assert_eq!(tree.predict(&[1.5]), 0);
    }

    // Smoke test: a small two-class forest must emit in-range class
    // indices (only checks pred < 2, not exact labels).
    #[test]
    fn test_random_forest_simple() {
        let features = vec![
            vec![0.1, 0.2],
            vec![0.9, 0.8],
            vec![0.15, 0.25],
            vec![0.85, 0.75],
        ];
        let labels = vec![0, 1, 0, 1];
        let rf = RandomForest::train(&features, &labels, 5);
        let pred0 = rf.predict(&[0.1, 0.2]);
        let pred1 = rf.predict(&[0.9, 0.8]);
        assert!(pred0 < 2);
        assert!(pred1 < 2);
    }

    // The no-training default strategy must be the heuristic variant.
    #[test]
    fn test_classifier_default_is_heuristic() {
        let c = PreconditionerClassifier::default();
        match c {
            PreconditionerClassifier::Heuristic(_) => {}
            _ => panic!("default should be heuristic"),
        }
    }
}