// sklears_svm/hyperparameter_optimization/grid_search.rs

1//! Grid Search Cross-Validation for hyperparameter optimization
2
3use std::time::Instant;
4
5#[cfg(feature = "parallel")]
6use rayon::prelude::*;
7use scirs2_core::ndarray::{Array1, Array2};
8use scirs2_core::random::Random;
9
10// Type aliases for compatibility
11type DMatrix<T> = Array2<T>;
12type DVector<T> = Array1<T>;
13
14use crate::kernels::KernelType;
15use crate::svc::SVC;
16use sklears_core::error::{Result, SklearsError};
17use sklears_core::traits::{Fit, Predict};
18
19use super::{
20    OptimizationConfig, OptimizationResult, ParameterSet, ParameterSpec, ScoringMetric, SearchSpace,
21};
22
/// Grid Search hyperparameter optimizer
///
/// Exhaustively evaluates every combination of the parameter values described
/// by `search_space`, scoring each candidate with k-fold cross-validation.
pub struct GridSearchCV {
    // Cross-validation, scoring, verbosity and parallelism settings.
    config: OptimizationConfig,
    // Parameter ranges (C, kernel, tol, max_iter) to enumerate.
    search_space: SearchSpace,
    // Seeded RNG; no method visible in this file draws from it — presumably
    // kept for API parity with the stochastic optimizers. TODO confirm.
    rng: Random<scirs2_core::random::rngs::StdRng>,
}
29
30impl GridSearchCV {
31    /// Create a new grid search optimizer
32    pub fn new(config: OptimizationConfig, search_space: SearchSpace) -> Self {
33        let rng = if let Some(seed) = config.random_state {
34            Random::seed(seed)
35        } else {
36            Random::seed(42) // Default seed for reproducibility
37        };
38
39        Self {
40            config,
41            search_space,
42            rng,
43        }
44    }
45
    /// Run grid search optimization
    ///
    /// Enumerates every combination in the search space, scores each one with
    /// k-fold cross-validation (in parallel when the `parallel` feature is on
    /// and `config.n_jobs` is set), and returns the best parameter set plus
    /// the full evaluation history.
    ///
    /// Candidates whose evaluation fails are retained with a score of
    /// `-f64::INFINITY` so a single bad configuration cannot abort the search.
    pub fn fit(&mut self, x: &DMatrix<f64>, y: &DVector<f64>) -> Result<OptimizationResult> {
        let start_time = Instant::now();

        // Generate the full Cartesian product of candidate parameter sets.
        let param_grid = self.generate_parameter_grid()?;

        if self.config.verbose {
            println!(
                "Grid search with {} parameter combinations",
                param_grid.len()
            );
        }

        // Evaluate all parameter combinations. Exactly one of the two cfg
        // arms below survives compilation and forms this block's tail
        // expression.
        let cv_results: Vec<(ParameterSet, f64)> = {
            #[cfg(feature = "parallel")]
            if self.config.n_jobs.is_some() {
                // Parallel evaluation.
                // NOTE(review): unlike the sequential paths, this branch emits
                // no per-candidate verbose output — confirm that is intended
                // (it does avoid interleaved prints across threads).
                param_grid
                    .into_par_iter()
                    .map(|params| {
                        // Failed evaluations score -inf so they can never win.
                        let score = self
                            .evaluate_params(&params, x, y)
                            .unwrap_or(-f64::INFINITY);
                        (params, score)
                    })
                    .collect()
            } else {
                // Sequential evaluation
                param_grid
                    .into_iter()
                    .map(|params| {
                        let score = self
                            .evaluate_params(&params, x, y)
                            .unwrap_or(-f64::INFINITY);
                        if self.config.verbose {
                            println!("Params: {:?}, Score: {:.6}", params, score);
                        }
                        (params, score)
                    })
                    .collect()
            }

            #[cfg(not(feature = "parallel"))]
            {
                // Sequential evaluation (parallel feature disabled)
                param_grid
                    .into_iter()
                    .map(|params| {
                        let score = self
                            .evaluate_params(&params, x, y)
                            .unwrap_or(-f64::INFINITY);
                        if self.config.verbose {
                            println!("Params: {:?}, Score: {:.6}", params, score);
                        }
                        (params, score)
                    })
                    .collect()
            }
        };

        // Pick the highest-scoring candidate; partial_cmp's Equal fallback
        // keeps the comparison total even if NaN scores slip through.
        let (best_params, best_score) = cv_results
            .iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(p, s)| (p.clone(), *s))
            .ok_or_else(|| {
                SklearsError::Other("No valid parameter combinations found".to_string())
            })?;

        // score_history preserves evaluation order for later inspection.
        let score_history: Vec<f64> = cv_results.iter().map(|(_, score)| *score).collect();
        let n_iterations = cv_results.len();

        Ok(OptimizationResult {
            best_params,
            best_score,
            cv_results,
            n_iterations,
            optimization_time: start_time.elapsed().as_secs_f64(),
            score_history,
        })
    }
129
130    /// Generate parameter grid for grid search
131    fn generate_parameter_grid(&mut self) -> Result<Vec<ParameterSet>> {
132        let mut param_grid = Vec::new();
133
134        // Clone specs to avoid borrowing conflicts
135        let c_spec = self.search_space.c.clone();
136        let kernel_spec = self.search_space.kernel.clone();
137        let tol_spec = self.search_space.tol.clone();
138        let max_iter_spec = self.search_space.max_iter.clone();
139
140        // Generate C values
141        let c_values = self.generate_values(&c_spec, 10)?;
142
143        // Generate kernel values
144        let kernel_values = if let Some(kernel_spec) = kernel_spec {
145            self.generate_kernel_values(&kernel_spec)?
146        } else {
147            vec![KernelType::Rbf { gamma: 1.0 }]
148        };
149
150        // Generate tolerance values
151        let tol_values = if let Some(tol_spec) = tol_spec {
152            self.generate_values(&tol_spec, 5)?
153        } else {
154            vec![1e-3]
155        };
156
157        // Generate max_iter values
158        let max_iter_values = if let Some(max_iter_spec) = max_iter_spec {
159            self.generate_values(&max_iter_spec, 3)?
160                .into_iter()
161                .map(|v| v as usize)
162                .collect()
163        } else {
164            vec![1000]
165        };
166
167        // Generate all combinations
168        for &c in &c_values {
169            for kernel in &kernel_values {
170                for &tol in &tol_values {
171                    for &max_iter in &max_iter_values {
172                        param_grid.push(ParameterSet {
173                            c,
174                            kernel: kernel.clone(),
175                            tol,
176                            max_iter,
177                        });
178                    }
179                }
180            }
181        }
182
183        Ok(param_grid)
184    }
185
186    /// Generate values from parameter specification
187    fn generate_values(&mut self, spec: &ParameterSpec, n_values: usize) -> Result<Vec<f64>> {
188        match spec {
189            ParameterSpec::Fixed(value) => Ok(vec![*value]),
190            ParameterSpec::Uniform { min, max } => Ok((0..n_values)
191                .map(|i| min + (max - min) * i as f64 / (n_values - 1) as f64)
192                .collect()),
193            ParameterSpec::LogUniform { min, max } => {
194                let log_min = min.ln();
195                let log_max = max.ln();
196                Ok((0..n_values)
197                    .map(|i| {
198                        let log_val =
199                            log_min + (log_max - log_min) * i as f64 / (n_values - 1) as f64;
200                        log_val.exp()
201                    })
202                    .collect())
203            }
204            ParameterSpec::Choice(choices) => Ok(choices.clone()),
205            ParameterSpec::KernelChoice(_) => Err(SklearsError::InvalidInput(
206                "Use generate_kernel_values for kernel specs".to_string(),
207            )),
208        }
209    }
210
211    /// Generate kernel values from kernel specification
212    fn generate_kernel_values(&mut self, spec: &ParameterSpec) -> Result<Vec<KernelType>> {
213        match spec {
214            ParameterSpec::KernelChoice(kernels) => Ok(kernels.clone()),
215            _ => Err(SklearsError::InvalidInput(
216                "Invalid kernel specification".to_string(),
217            )),
218        }
219    }
220
221    /// Evaluate parameter set using cross-validation
222    fn evaluate_params(
223        &self,
224        params: &ParameterSet,
225        x: &DMatrix<f64>,
226        y: &DVector<f64>,
227    ) -> Result<f64> {
228        let scores = self.cross_validate(params, x, y)?;
229        Ok(scores.iter().sum::<f64>() / scores.len() as f64)
230    }
231
232    /// Perform cross-validation
233    fn cross_validate(
234        &self,
235        params: &ParameterSet,
236        x: &DMatrix<f64>,
237        y: &DVector<f64>,
238    ) -> Result<Vec<f64>> {
239        let n_samples = x.nrows();
240        let fold_size = n_samples / self.config.cv_folds;
241        let mut scores = Vec::new();
242
243        for fold in 0..self.config.cv_folds {
244            let start_idx = fold * fold_size;
245            let end_idx = if fold == self.config.cv_folds - 1 {
246                n_samples
247            } else {
248                (fold + 1) * fold_size
249            };
250
251            // Create train/test splits
252            let mut x_train_data = Vec::new();
253            let mut y_train_vals = Vec::new();
254            let mut x_test_data = Vec::new();
255            let mut y_test_vals = Vec::new();
256
257            for i in 0..n_samples {
258                if i >= start_idx && i < end_idx {
259                    // Test set
260                    for j in 0..x.ncols() {
261                        x_test_data.push(x[[i, j]]);
262                    }
263                    y_test_vals.push(y[i]);
264                } else {
265                    // Training set
266                    for j in 0..x.ncols() {
267                        x_train_data.push(x[[i, j]]);
268                    }
269                    y_train_vals.push(y[i]);
270                }
271            }
272
273            let n_train = y_train_vals.len();
274            let n_test = y_test_vals.len();
275            let n_features = x.ncols();
276
277            let x_train = Array2::from_shape_vec((n_train, n_features), x_train_data)?;
278            let y_train = Array1::from_vec(y_train_vals);
279            let x_test = Array2::from_shape_vec((n_test, n_features), x_test_data)?;
280            let y_test = Array1::from_vec(y_test_vals);
281
282            // Train and evaluate model
283            let svm = SVC::new()
284                .c(params.c)
285                .kernel(params.kernel.clone())
286                .tol(params.tol)
287                .max_iter(params.max_iter);
288
289            let fitted_svm = svm.fit(&x_train, &y_train)?;
290            let y_pred = fitted_svm.predict(&x_test)?;
291
292            let score = self.calculate_score(&y_test, &y_pred)?;
293            scores.push(score);
294        }
295
296        Ok(scores)
297    }
298
299    /// Calculate score based on scoring metric
300    fn calculate_score(&self, y_true: &DVector<f64>, y_pred: &DVector<f64>) -> Result<f64> {
301        match self.config.scoring {
302            ScoringMetric::Accuracy => {
303                let correct = y_true
304                    .iter()
305                    .zip(y_pred.iter())
306                    .map(|(&t, &p)| if (t - p).abs() < 0.5 { 1.0 } else { 0.0 })
307                    .sum::<f64>();
308                Ok(correct / y_true.len() as f64)
309            }
310            ScoringMetric::MeanSquaredError => {
311                let mse = y_true
312                    .iter()
313                    .zip(y_pred.iter())
314                    .map(|(&t, &p)| (t - p).powi(2))
315                    .sum::<f64>()
316                    / y_true.len() as f64;
317                Ok(-mse) // Negative because we want to maximize
318            }
319            ScoringMetric::MeanAbsoluteError => {
320                let mae = y_true
321                    .iter()
322                    .zip(y_pred.iter())
323                    .map(|(&t, &p)| (t - p).abs())
324                    .sum::<f64>()
325                    / y_true.len() as f64;
326                Ok(-mae) // Negative because we want to maximize
327            }
328            _ => {
329                // For now, default to accuracy for other metrics
330                let correct = y_true
331                    .iter()
332                    .zip(y_pred.iter())
333                    .map(|(&t, &p)| if (t - p).abs() < 0.5 { 1.0 } else { 0.0 })
334                    .sum::<f64>();
335                Ok(correct / y_true.len() as f64)
336            }
337        }
338    }
339}