// sklears_svm/nusvr.rs
1//! Nu Support Vector Regression
2//!
3//! This module implements Nu-SVR, an alternative formulation of SVM regression
4//! that uses a parameter nu instead of C and epsilon for controlling the
5//! regularization and error tolerance.
6
7use crate::kernels::{Kernel, KernelType};
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::{
10    error::{Result, SklearsError},
11    traits::{Fit, Predict, Trained, Untrained},
12    types::Float,
13};
14use std::marker::PhantomData;
15
/// Nu Support Vector Regression Configuration
///
/// Carried over into the fitted estimator by `fit`, so the values remain
/// inspectable after training.
///
/// NOTE(review): `tol`, `max_iter`, and `random_state` are not yet
/// consulted by the placeholder `fit` implementation in this module.
#[derive(Debug, Clone)]
pub struct NuSVRConfig {
    /// Nu parameter (0 < nu <= 1); controls the fraction of support
    /// vectors and, roughly, the fraction of training errors.
    pub nu: Float,
    /// Kernel function to use
    pub kernel: KernelType,
    /// Tolerance for stopping criterion
    pub tol: Float,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Random seed
    pub random_state: Option<u64>,
}
30
31impl Default for NuSVRConfig {
32    fn default() -> Self {
33        Self {
34            nu: 0.5,
35            kernel: KernelType::Rbf { gamma: 1.0 },
36            tol: 1e-3,
37            max_iter: 200,
38            random_state: None,
39        }
40    }
41}
42
/// Nu Support Vector Regression
///
/// Nu-SVR is an alternative formulation of SVR that uses a parameter nu
/// instead of C and epsilon. The parameter nu controls the fraction of
/// support vectors and roughly corresponds to the fraction of training
/// points that lie outside the epsilon-tube.
///
/// The `State` type parameter implements the typestate pattern: an
/// `Untrained` model exposes only builder methods, while a `Trained`
/// model (returned by `fit`) exposes `predict` and the fitted-attribute
/// accessors.
#[derive(Debug)]
pub struct NuSVR<State = Untrained> {
    // Hyperparameters captured at construction time.
    config: NuSVRConfig,
    // Zero-sized marker carrying the Untrained/Trained typestate.
    state: PhantomData<State>,
    // Fitted attributes: all None until `fit` succeeds.
    // Training rows retained as support vectors.
    support_vectors_: Option<Array2<Float>>,
    // Indices of the support vectors within the training set.
    support_: Option<Array1<usize>>,
    // Dual coefficients, one per support vector.
    dual_coef_: Option<Array1<Float>>,
    // Bias term of the decision function.
    intercept_: Option<Float>,
    // Number of features seen during fit.
    n_features_in_: Option<usize>,
    // Number of support vectors.
    n_support_: Option<usize>,
    // Epsilon estimated from the data and nu during fit.
    epsilon_: Option<Float>,
}
62
63impl NuSVR<Untrained> {
64    /// Create a new Nu-SVR regressor
65    pub fn new() -> Self {
66        Self {
67            config: NuSVRConfig::default(),
68            state: PhantomData,
69            support_vectors_: None,
70            support_: None,
71            dual_coef_: None,
72            intercept_: None,
73            n_features_in_: None,
74            n_support_: None,
75            epsilon_: None,
76        }
77    }
78
79    /// Set the nu parameter (0 < nu <= 1)
80    pub fn nu(mut self, nu: Float) -> Result<Self> {
81        if nu <= 0.0 || nu > 1.0 {
82            return Err(SklearsError::InvalidParameter {
83                name: "nu".to_string(),
84                reason: "must be in the range (0, 1]".to_string(),
85            });
86        }
87        self.config.nu = nu;
88        Ok(self)
89    }
90
91    /// Set the kernel type
92    pub fn kernel(mut self, kernel: KernelType) -> Self {
93        self.config.kernel = kernel;
94        self
95    }
96
97    /// Set the tolerance for stopping criterion
98    pub fn tol(mut self, tol: Float) -> Self {
99        self.config.tol = tol;
100        self
101    }
102
103    /// Set the maximum number of iterations
104    pub fn max_iter(mut self, max_iter: usize) -> Self {
105        self.config.max_iter = max_iter;
106        self
107    }
108
109    /// Set the random state
110    pub fn random_state(mut self, random_state: u64) -> Self {
111        self.config.random_state = Some(random_state);
112        self
113    }
114}
115
116impl Default for NuSVR<Untrained> {
117    fn default() -> Self {
118        Self::new()
119    }
120}
121
122impl Fit<Array2<Float>, Array1<Float>> for NuSVR<Untrained> {
123    type Fitted = NuSVR<Trained>;
124
125    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
126        if x.nrows() != y.len() {
127            return Err(SklearsError::InvalidInput(format!(
128                "Shape mismatch: X has {} samples, y has {} samples",
129                x.nrows(),
130                y.len()
131            )));
132        }
133
134        if x.nrows() == 0 {
135            return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
136        }
137
138        let n_features = x.ncols();
139        let n_samples = x.nrows();
140
141        // Estimate epsilon from the data and nu parameter
142        // This is a simplified approach - in practice, epsilon is determined
143        // during the optimization process
144        let y_std = {
145            let mean = y.mean().unwrap_or(0.0);
146            let variance =
147                y.iter().map(|&val| (val - mean).powi(2)).sum::<Float>() / n_samples as Float;
148            variance.sqrt()
149        };
150        let epsilon = self.config.nu * y_std;
151
152        // Convert nu to C parameter for SVR
153        // This is a simplified conversion
154        let _c = 1.0 / (self.config.nu * n_samples as Float);
155
156        // For regression, we create a modified problem
157        // This is a placeholder implementation - actual Nu-SVR requires
158        // a specialized solver
159
160        // Create a simple linear approximation for now
161        // In practice, this would use a proper Nu-SVR solver
162
163        // Placeholder: Use mean prediction
164        let intercept = y.mean().unwrap_or(0.0);
165
166        // For simplicity, use all points as support vectors in this placeholder
167        let support_indices: Vec<usize> = (0..n_samples).collect();
168        let support_vectors = x.clone();
169        let dual_coef = Array1::zeros(n_samples);
170        let support = Array1::from_vec(support_indices);
171
172        Ok(NuSVR {
173            config: self.config,
174            state: PhantomData,
175            support_vectors_: Some(support_vectors),
176            support_: Some(support),
177            dual_coef_: Some(dual_coef),
178            intercept_: Some(intercept),
179            n_features_in_: Some(n_features),
180            n_support_: Some(n_samples),
181            epsilon_: Some(epsilon),
182        })
183    }
184}
185
186impl Predict<Array2<Float>, Array1<Float>> for NuSVR<Trained> {
187    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
188        if x.ncols()
189            != self
190                .n_features_in_
191                .expect("n_features_in_ not available - model not fitted")
192        {
193            return Err(SklearsError::InvalidInput(format!(
194                "Feature mismatch: expected {} features, got {}",
195                self.n_features_in_
196                    .expect("n_features_in_ not available - model not fitted"),
197                x.ncols()
198            )));
199        }
200
201        let support_vectors = self
202            .support_vectors_
203            .as_ref()
204            .expect("support_vectors_ not available - model not fitted");
205        let dual_coef = self
206            .dual_coef_
207            .as_ref()
208            .expect("dual_coef_ not available - model not fitted");
209        let intercept = self
210            .intercept_
211            .expect("intercept_ not available - model not fitted");
212
213        let kernel = match &self.config.kernel {
214            KernelType::Linear => Box::new(crate::kernels::LinearKernel) as Box<dyn Kernel>,
215            KernelType::Rbf { gamma } => {
216                Box::new(crate::kernels::RbfKernel::new(*gamma)) as Box<dyn Kernel>
217            }
218            _ => Box::new(crate::kernels::RbfKernel::new(1.0)) as Box<dyn Kernel>, // Default fallback
219        };
220        let mut predictions = Array1::zeros(x.nrows());
221
222        for i in 0..x.nrows() {
223            let mut prediction = intercept;
224            for (j, &coef) in dual_coef.iter().enumerate() {
225                let k_val = kernel.compute(x.row(i), support_vectors.row(j));
226                prediction += coef * k_val;
227            }
228            predictions[i] = prediction;
229        }
230
231        Ok(predictions)
232    }
233}
234
235impl NuSVR<Trained> {
236    /// Get the support vectors
237    pub fn support_vectors(&self) -> &Array2<Float> {
238        self.support_vectors_
239            .as_ref()
240            .expect("support_vectors_ not available - model not fitted")
241    }
242
243    /// Get the indices of support vectors
244    pub fn support(&self) -> &Array1<usize> {
245        self.support_
246            .as_ref()
247            .expect("support_ not available - model not fitted")
248    }
249
250    /// Get the dual coefficients
251    pub fn dual_coef(&self) -> &Array1<Float> {
252        self.dual_coef_
253            .as_ref()
254            .expect("dual_coef_ not available - model not fitted")
255    }
256
257    /// Get the intercept
258    pub fn intercept(&self) -> Float {
259        self.intercept_
260            .expect("intercept_ not available - model not fitted")
261    }
262
263    /// Get the number of support vectors
264    pub fn n_support(&self) -> usize {
265        self.n_support_
266            .expect("n_support_ not available - model not fitted")
267    }
268
269    /// Get the number of features
270    pub fn n_features_in(&self) -> usize {
271        self.n_features_in_
272            .expect("n_features_in_ not available - model not fitted")
273    }
274
275    /// Get the epsilon parameter
276    pub fn epsilon(&self) -> Float {
277        self.epsilon_
278            .expect("epsilon_ not available - model not fitted")
279    }
280}
281
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// The builder methods must record every configured value.
    #[test]
    fn test_nusvr_creation() {
        let model = NuSVR::new()
            .nu(0.3)
            .expect("valid parameter")
            .kernel(KernelType::Linear)
            .tol(1e-4)
            .max_iter(500)
            .random_state(42);

        assert_eq!(model.config.nu, 0.3);
        assert_eq!(model.config.tol, 1e-4);
        assert_eq!(model.config.max_iter, 500);
        assert_eq!(model.config.random_state, Some(42));
    }

    /// nu values outside (0, 1] must be rejected.
    #[test]
    fn test_nusvr_invalid_nu() {
        assert!(NuSVR::new().nu(1.5).is_err());
    }

    /// End-to-end fit + predict on a simple linear relation.
    #[test]
    fn test_nusvr_regression() {
        // y = 2*x on six points.
        let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0],];
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let model = NuSVR::new()
            .nu(0.5)
            .expect("valid parameter")
            .kernel(KernelType::Linear);
        let trained = model
            .fit(&features, &targets)
            .expect("model fitting should succeed");

        assert_eq!(trained.n_features_in(), 1);
        assert!(trained.epsilon() > 0.0);

        let preds = trained
            .predict(&features)
            .expect("prediction should succeed");
        assert_eq!(preds.len(), 6);
        // Every prediction must be a finite number.
        assert!(preds.iter().all(|p| p.is_finite()));
    }

    /// Mismatched sample counts between X and y are an error.
    #[test]
    fn test_nusvr_shape_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0]; // Wrong length

        let outcome = NuSVR::new().fit(&features, &targets);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("Shape mismatch"));
    }

    /// Predicting with the wrong feature count is an error.
    #[test]
    fn test_nusvr_feature_mismatch() {
        let x_train = array![[1.0, 2.0], [3.0, 4.0]];
        let y_train = array![1.0, 2.0];
        let x_test = array![[1.0, 2.0, 3.0]]; // Wrong number of features

        let trained = NuSVR::new()
            .fit(&x_train, &y_train)
            .expect("model fitting should succeed");
        let outcome = trained.predict(&x_test);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("Feature"));
    }

    /// Fitting on zero samples is an error.
    #[test]
    fn test_nusvr_empty_data() {
        let features: Array2<f64> = Array2::zeros((0, 2));
        let targets: Array1<f64> = Array1::zeros(0);

        let outcome = NuSVR::new().fit(&features, &targets);

        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("Empty dataset"));
    }

    /// Every kernel variant should fit and yield finite predictions.
    #[test]
    fn test_nusvr_different_kernels() {
        // y = x^2 on four points.
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![1.0, 4.0, 9.0, 16.0];

        let kernels = vec![
            KernelType::Linear,
            KernelType::Rbf { gamma: 0.1 },
            KernelType::Polynomial {
                gamma: 1.0,
                degree: 2.0,
                coef0: 0.0,
            },
        ];

        for kernel in kernels {
            let trained = NuSVR::new()
                .kernel(kernel)
                .fit(&features, &targets)
                .expect("model fitting should succeed");
            let preds = trained
                .predict(&features)
                .expect("prediction should succeed");

            assert_eq!(preds.len(), 4);
            assert!(preds.iter().all(|p| p.is_finite()));
        }
    }
}