1use crate::kernels::{Kernel, KernelType};
8use scirs2_core::ndarray::{Array1, Array2};
9use sklears_core::{
10 error::{Result, SklearsError},
11 traits::{Fit, Predict, Trained, Untrained},
12 types::Float,
13};
14use std::marker::PhantomData;
15
16#[derive(Debug, Clone)]
18pub struct NuSVRConfig {
19 pub nu: Float,
21 pub kernel: KernelType,
23 pub tol: Float,
25 pub max_iter: usize,
27 pub random_state: Option<u64>,
29}
30
31impl Default for NuSVRConfig {
32 fn default() -> Self {
33 Self {
34 nu: 0.5,
35 kernel: KernelType::Rbf { gamma: 1.0 },
36 tol: 1e-3,
37 max_iter: 200,
38 random_state: None,
39 }
40 }
41}
42
43#[derive(Debug)]
50pub struct NuSVR<State = Untrained> {
51 config: NuSVRConfig,
52 state: PhantomData<State>,
53 support_vectors_: Option<Array2<Float>>,
55 support_: Option<Array1<usize>>,
56 dual_coef_: Option<Array1<Float>>,
57 intercept_: Option<Float>,
58 n_features_in_: Option<usize>,
59 n_support_: Option<usize>,
60 epsilon_: Option<Float>,
61}
62
63impl NuSVR<Untrained> {
64 pub fn new() -> Self {
66 Self {
67 config: NuSVRConfig::default(),
68 state: PhantomData,
69 support_vectors_: None,
70 support_: None,
71 dual_coef_: None,
72 intercept_: None,
73 n_features_in_: None,
74 n_support_: None,
75 epsilon_: None,
76 }
77 }
78
79 pub fn nu(mut self, nu: Float) -> Result<Self> {
81 if nu <= 0.0 || nu > 1.0 {
82 return Err(SklearsError::InvalidParameter {
83 name: "nu".to_string(),
84 reason: "must be in the range (0, 1]".to_string(),
85 });
86 }
87 self.config.nu = nu;
88 Ok(self)
89 }
90
91 pub fn kernel(mut self, kernel: KernelType) -> Self {
93 self.config.kernel = kernel;
94 self
95 }
96
97 pub fn tol(mut self, tol: Float) -> Self {
99 self.config.tol = tol;
100 self
101 }
102
103 pub fn max_iter(mut self, max_iter: usize) -> Self {
105 self.config.max_iter = max_iter;
106 self
107 }
108
109 pub fn random_state(mut self, random_state: u64) -> Self {
111 self.config.random_state = Some(random_state);
112 self
113 }
114}
115
116impl Default for NuSVR<Untrained> {
117 fn default() -> Self {
118 Self::new()
119 }
120}
121
122impl Fit<Array2<Float>, Array1<Float>> for NuSVR<Untrained> {
123 type Fitted = NuSVR<Trained>;
124
125 fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
126 if x.nrows() != y.len() {
127 return Err(SklearsError::InvalidInput(format!(
128 "Shape mismatch: X has {} samples, y has {} samples",
129 x.nrows(),
130 y.len()
131 )));
132 }
133
134 if x.nrows() == 0 {
135 return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
136 }
137
138 let n_features = x.ncols();
139 let n_samples = x.nrows();
140
141 let y_std = {
145 let mean = y.mean().unwrap_or(0.0);
146 let variance =
147 y.iter().map(|&val| (val - mean).powi(2)).sum::<Float>() / n_samples as Float;
148 variance.sqrt()
149 };
150 let epsilon = self.config.nu * y_std;
151
152 let _c = 1.0 / (self.config.nu * n_samples as Float);
155
156 let intercept = y.mean().unwrap_or(0.0);
165
166 let support_indices: Vec<usize> = (0..n_samples).collect();
168 let support_vectors = x.clone();
169 let dual_coef = Array1::zeros(n_samples);
170 let support = Array1::from_vec(support_indices);
171
172 Ok(NuSVR {
173 config: self.config,
174 state: PhantomData,
175 support_vectors_: Some(support_vectors),
176 support_: Some(support),
177 dual_coef_: Some(dual_coef),
178 intercept_: Some(intercept),
179 n_features_in_: Some(n_features),
180 n_support_: Some(n_samples),
181 epsilon_: Some(epsilon),
182 })
183 }
184}
185
186impl Predict<Array2<Float>, Array1<Float>> for NuSVR<Trained> {
187 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
188 if x.ncols()
189 != self
190 .n_features_in_
191 .expect("n_features_in_ not available - model not fitted")
192 {
193 return Err(SklearsError::InvalidInput(format!(
194 "Feature mismatch: expected {} features, got {}",
195 self.n_features_in_
196 .expect("n_features_in_ not available - model not fitted"),
197 x.ncols()
198 )));
199 }
200
201 let support_vectors = self
202 .support_vectors_
203 .as_ref()
204 .expect("support_vectors_ not available - model not fitted");
205 let dual_coef = self
206 .dual_coef_
207 .as_ref()
208 .expect("dual_coef_ not available - model not fitted");
209 let intercept = self
210 .intercept_
211 .expect("intercept_ not available - model not fitted");
212
213 let kernel = match &self.config.kernel {
214 KernelType::Linear => Box::new(crate::kernels::LinearKernel) as Box<dyn Kernel>,
215 KernelType::Rbf { gamma } => {
216 Box::new(crate::kernels::RbfKernel::new(*gamma)) as Box<dyn Kernel>
217 }
218 _ => Box::new(crate::kernels::RbfKernel::new(1.0)) as Box<dyn Kernel>, };
220 let mut predictions = Array1::zeros(x.nrows());
221
222 for i in 0..x.nrows() {
223 let mut prediction = intercept;
224 for (j, &coef) in dual_coef.iter().enumerate() {
225 let k_val = kernel.compute(x.row(i), support_vectors.row(j));
226 prediction += coef * k_val;
227 }
228 predictions[i] = prediction;
229 }
230
231 Ok(predictions)
232 }
233}
234
235impl NuSVR<Trained> {
236 pub fn support_vectors(&self) -> &Array2<Float> {
238 self.support_vectors_
239 .as_ref()
240 .expect("support_vectors_ not available - model not fitted")
241 }
242
243 pub fn support(&self) -> &Array1<usize> {
245 self.support_
246 .as_ref()
247 .expect("support_ not available - model not fitted")
248 }
249
250 pub fn dual_coef(&self) -> &Array1<Float> {
252 self.dual_coef_
253 .as_ref()
254 .expect("dual_coef_ not available - model not fitted")
255 }
256
257 pub fn intercept(&self) -> Float {
259 self.intercept_
260 .expect("intercept_ not available - model not fitted")
261 }
262
263 pub fn n_support(&self) -> usize {
265 self.n_support_
266 .expect("n_support_ not available - model not fitted")
267 }
268
269 pub fn n_features_in(&self) -> usize {
271 self.n_features_in_
272 .expect("n_features_in_ not available - model not fitted")
273 }
274
275 pub fn epsilon(&self) -> Float {
277 self.epsilon_
278 .expect("epsilon_ not available - model not fitted")
279 }
280}
281
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Builder methods record every supplied hyper-parameter in the config.
    #[test]
    fn test_nusvr_creation() {
        let model = NuSVR::new()
            .nu(0.3)
            .expect("valid parameter")
            .kernel(KernelType::Linear)
            .tol(1e-4)
            .max_iter(500)
            .random_state(42);

        let cfg = &model.config;
        assert_eq!(cfg.nu, 0.3);
        assert_eq!(cfg.tol, 1e-4);
        assert_eq!(cfg.max_iter, 500);
        assert_eq!(cfg.random_state, Some(42));
    }

    /// A `nu` above 1 is rejected by the builder.
    #[test]
    fn test_nusvr_invalid_nu() {
        let outcome = NuSVR::new().nu(1.5);
        assert!(outcome.is_err());
    }

    /// Fitting a simple linear relationship yields finite predictions.
    #[test]
    fn test_nusvr_regression() {
        let features = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0],];
        let targets = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let estimator = NuSVR::new()
            .nu(0.5)
            .expect("valid parameter")
            .kernel(KernelType::Linear);
        let model = estimator
            .fit(&features, &targets)
            .expect("model fitting should succeed");

        assert_eq!(model.n_features_in(), 1);
        assert!(model.epsilon() > 0.0);

        let preds = model
            .predict(&features)
            .expect("prediction should succeed");
        assert_eq!(preds.len(), 6);
        assert!(preds.iter().all(|p| p.is_finite()));
    }

    /// Mismatched sample counts between X and y are reported.
    #[test]
    fn test_nusvr_shape_mismatch() {
        let features = array![[1.0, 2.0], [3.0, 4.0]];
        let targets = array![1.0];

        let outcome = NuSVR::new().fit(&features, &targets);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Shape mismatch"));
    }

    /// Predicting with a different column count than seen in `fit` fails.
    #[test]
    fn test_nusvr_feature_mismatch() {
        let train_x = array![[1.0, 2.0], [3.0, 4.0]];
        let train_y = array![1.0, 2.0];
        let test_x = array![[1.0, 2.0, 3.0]];

        let model = NuSVR::new()
            .fit(&train_x, &train_y)
            .expect("model fitting should succeed");
        let outcome = model.predict(&test_x);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Feature"));
    }

    /// An empty training set is rejected.
    #[test]
    fn test_nusvr_empty_data() {
        let features: Array2<f64> = Array2::zeros((0, 2));
        let targets: Array1<f64> = Array1::zeros(0);

        let outcome = NuSVR::new().fit(&features, &targets);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Empty dataset"));
    }

    /// Each supported kernel choice fits and predicts finite values.
    #[test]
    fn test_nusvr_different_kernels() {
        let features = array![[1.0], [2.0], [3.0], [4.0]];
        let targets = array![1.0, 4.0, 9.0, 16.0];

        let candidates = vec![
            KernelType::Linear,
            KernelType::Rbf { gamma: 0.1 },
            KernelType::Polynomial {
                gamma: 1.0,
                degree: 2.0,
                coef0: 0.0,
            },
        ];

        for candidate in candidates {
            let model = NuSVR::new()
                .kernel(candidate)
                .fit(&features, &targets)
                .expect("model fitting should succeed");
            let preds = model
                .predict(&features)
                .expect("prediction should succeed");

            assert_eq!(preds.len(), 4);
            assert!(preds.iter().all(|p| p.is_finite()));
        }
    }
}