rustlearn/svm/libsvm/svc.rs

//! Support Vector Classifier using the `libsvm` library.
//!
//! Both dense and sparse input data are supported.
//!
//! # Examples
//!
//! ```
//! use rustlearn::prelude::*;
//! use rustlearn::datasets::iris;
//! use rustlearn::svm::libsvm::svc::{Hyperparameters, KernelType};
//!
//! let (X, y) = iris::load_data();
//!
//! let mut model = Hyperparameters::new(4, KernelType::Linear, 3)
//!     .C(0.3)
//!     .build();
//!
//! model.fit(&X, &y).unwrap();
//!
//! let prediction = model.predict(&X).unwrap();
//! ```
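//!
//! Sparse input works the same way; a minimal sketch, converting the
//! dense iris matrix with `SparseRowArray::from` (as the test suite does):
//!
//! ```
//! use rustlearn::prelude::*;
//! use rustlearn::datasets::iris;
//! use rustlearn::svm::libsvm::svc::{Hyperparameters, KernelType};
//!
//! let (X, y) = iris::load_data();
//! let X_sparse = SparseRowArray::from(&X);
//!
//! let mut model = Hyperparameters::new(4, KernelType::Linear, 3)
//!     .build();
//!
//! model.fit(&X_sparse, &y).unwrap();
//! let prediction = model.predict(&X_sparse).unwrap();
//! ```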

use prelude::*;

use super::ffi;
pub use super::ffi::KernelType;

use utils::{check_data_dimensionality, check_matched_dimensions};

#[derive(Clone, Serialize, Deserialize)]
/// Hyperparameters for the SVC model.
pub struct Hyperparameters {
    dim: usize,
    num_classes: usize,
    svm_parameter: ffi::SvmParameter,
}

impl Hyperparameters {
    pub fn new(dim: usize, kernel: KernelType, num_classes: usize) -> Hyperparameters {
        Hyperparameters {
            dim: dim,
            num_classes: num_classes,
            svm_parameter: ffi::SvmParameter::new(ffi::SvmType::C_SVC, kernel, num_classes, dim),
        }
    }

    /// Set the regularization parameter `C`; smaller values
    /// mean more regularization.
    /// Default is `1.0`.
    pub fn C(&mut self, C: f64) -> &mut Hyperparameters {
        self.svm_parameter.C = C;
        self
    }

    /// Set the degree of the polynomial kernel. No effect on other
    /// kernels. Default is `3`.
    pub fn degree(&mut self, degree: i32) -> &mut Hyperparameters {
        self.svm_parameter.degree = degree;
        self
    }

    /// Set the `gamma` parameter used by the RBF, polynomial,
    /// and sigmoid kernels.
    /// Default is `1 / self.dim`.
    pub fn gamma(&mut self, gamma: f64) -> &mut Hyperparameters {
        self.svm_parameter.gamma = gamma;
        self
    }

    /// Set the `coef0` parameter used by the polynomial and sigmoid
    /// kernels. Default is `0.0`.
    pub fn coef0(&mut self, coef0: f64) -> &mut Hyperparameters {
        self.svm_parameter.coef0 = coef0;
        self
    }

    /// Set the `libsvm` cache size, in megabytes.
    /// Default is `100.0`.
    pub fn cache_size(&mut self, cache_size: f64) -> &mut Hyperparameters {
        self.svm_parameter.cache_size = cache_size;
        self
    }

    /// Build the SVC model. `libsvm` natively supports multiclass
    /// problems via one-vs-one (OvO) estimation, so no one-vs-rest
    /// wrapper is provided.
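    ///
    /// # Examples
    ///
    /// A minimal configuration sketch; the dimensions and parameter
    /// values below are illustrative only:
    ///
    /// ```
    /// use rustlearn::svm::libsvm::svc::{Hyperparameters, KernelType};
    ///
    /// // Illustrative values: 10 features, 2 classes.
    /// let model = Hyperparameters::new(10, KernelType::RBF, 2)
    ///     .gamma(0.5)
    ///     .C(10.0)
    ///     .build();
    /// ```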
    pub fn build(&self) -> SVC {
        SVC {
            dim: self.dim,
            hyperparams: self.to_owned(),
            model: None,
        }
    }

    fn svm_parameter(&self) -> &ffi::SvmParameter {
        &self.svm_parameter
    }
}

/// Support Vector Classifier provided by the `libsvm` library.
#[derive(Clone, Serialize, Deserialize)]
pub struct SVC {
    dim: usize,
    hyperparams: Hyperparameters,
    model: Option<ffi::SvmModel>,
}

// Implement `SupervisedModel` for both dense (`Array`) and sparse
// (`SparseRowArray`) inputs via a macro.
macro_rules! impl_supervised_model {
    ($x_type:ty) => {
        impl<'a> SupervisedModel<&'a $x_type> for SVC {
            fn fit(&mut self, X: &$x_type, y: &Array) -> Result<(), &'static str> {
                try!(check_data_dimensionality(self.dim, X));
                try!(check_matched_dimensions(X, y));

                let svm_params = self.hyperparams.svm_parameter();

                self.model = Some(try!(ffi::fit(X, y, svm_params)));

                Ok(())
            }

            fn decision_function(&self, X: &$x_type) -> Result<Array, &'static str> {
                try!(check_data_dimensionality(self.dim, X));

                match self.model {
                    Some(ref model) => {
                        let (decision_function, _) = ffi::predict(model, X);
                        Ok(decision_function)
                    }
                    None => Err("Model must be fit before predicting."),
                }
            }

            fn predict(&self, X: &$x_type) -> Result<Array, &'static str> {
                // Mirror the dimensionality check performed in
                // `decision_function`.
                try!(check_data_dimensionality(self.dim, X));

                match self.model {
                    Some(ref model) => {
                        let (_, predicted_class) = ffi::predict(model, X);
                        Ok(predicted_class)
                    }
                    None => Err("Model must be fit before predicting."),
                }
            }
        }
    };
}

impl_supervised_model!(Array);
impl_supervised_model!(SparseRowArray);

#[cfg(test)]
mod tests {
    use super::*;

    use rand::{SeedableRng, StdRng};

    use prelude::*;

    use cross_validation::cross_validation::CrossValidation;
    use datasets::iris::load_data;
    use metrics::accuracy_score;

    use bincode;

    #[cfg(feature = "all_tests")]
    use datasets::newsgroups;

    macro_rules! test_iris_kernel {
        ($kernel:expr, $fn_name:ident) => {
            #[test]
            fn $fn_name() {
                let (data, target) = load_data();

                let mut test_accuracy = 0.0;
                let mut train_accuracy = 0.0;

                let no_splits = 10;

                let mut cv = CrossValidation::new(data.rows(), no_splits);
                cv.set_rng(StdRng::from_seed(&[100]));

                for (train_idx, test_idx) in cv {
                    let x_train = data.get_rows(&train_idx);
                    let x_test = data.get_rows(&test_idx);

                    let y_train = target.get_rows(&train_idx);

                    let mut model = Hyperparameters::new(data.cols(), $kernel, 3).build();

                    model.fit(&x_train, &y_train).unwrap();

                    let y_hat = model.predict(&x_test).unwrap();

                    test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
                    train_accuracy += accuracy_score(&y_train, &model.predict(&x_train).unwrap());
                }

                test_accuracy /= no_splits as f32;
                train_accuracy /= no_splits as f32;

                println!("Accuracy {}", test_accuracy);
                println!("Train accuracy {}", train_accuracy);
                assert!(test_accuracy > 0.97);
            }
        };
    }

    test_iris_kernel!(KernelType::Linear, test_iris_linear);
    test_iris_kernel!(KernelType::Polynomial, test_iris_polynomial);
    test_iris_kernel!(KernelType::RBF, test_iris_rbf);

    #[test]
    fn test_sparse_iris() {
        let (dense_data, target) = load_data();
        let data = SparseRowArray::from(&dense_data);

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), KernelType::Linear, 3).build();

            model.fit(&x_train, &y_train).unwrap();

            let y_hat = model.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&y_train, &model.predict(&x_train).unwrap());
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);
        assert!(test_accuracy > 0.97);
    }

    #[test]
    fn serialization() {
        let (data, target) = load_data();

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), KernelType::Linear, 3).build();

            model.fit(&x_train, &y_train).unwrap();

            let encoded = bincode::serialize(&model).unwrap();
            let decoded: SVC = bincode::deserialize(&encoded).unwrap();

            let y_hat = decoded.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&y_train, &decoded.predict(&x_train).unwrap());
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);
        assert!(test_accuracy > 0.97);
    }
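
    // Added sketch, not part of the original suite: predicting with a
    // model that has not been fit should return an error rather than
    // panic, per the `None` arm of `predict`.
    #[test]
    fn test_unfit_model_returns_error() {
        let (data, _) = load_data();

        let model = Hyperparameters::new(data.cols(), KernelType::Linear, 3).build();

        assert!(model.predict(&data).is_err());
    }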

    #[test]
    #[cfg(feature = "all_tests")]
    fn test_newsgroups() {
        let (X, target) = newsgroups::load_data();

        let no_splits = 2;
        let mut test_accuracy = 0.0;

        let mut cv = CrossValidation::new(X.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = X.get_rows(&train_idx);
            let x_test = X.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(X.cols(), KernelType::Linear, 20).build();

            model.fit(&x_train, &y_train).unwrap();

            let y_hat = model.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
        }

        test_accuracy /= no_splits as f32;
        println!("{}", test_accuracy);

        // This could definitely be improved
        // with better hyperparameter choice.
        assert!(test_accuracy > 0.8);
    }
}