1use prelude::*;
23
24use super::ffi;
25pub use super::ffi::KernelType;
26
27use utils::{check_data_dimensionality, check_matched_dimensions};
28
/// Hyperparameters for a `SVC` (C-SVM classification) model.
///
/// Built via [`Hyperparameters::new`] and the chainable setter methods,
/// then turned into a model with `build`.
#[derive(Clone, Serialize, Deserialize)]
pub struct Hyperparameters {
    // Number of features (columns) the input data must have.
    dim: usize,
    // Number of distinct target classes.
    num_classes: usize,
    // Parameter struct handed to the underlying `ffi` SVM implementation.
    svm_parameter: ffi::SvmParameter,
}
36
37impl Hyperparameters {
38 pub fn new(dim: usize, kernel: KernelType, num_classes: usize) -> Hyperparameters {
39 Hyperparameters {
40 dim: dim,
41 num_classes: num_classes,
42 svm_parameter: ffi::SvmParameter::new(ffi::SvmType::C_SVC, kernel, num_classes, dim),
43 }
44 }
45
46 pub fn C(&mut self, C: f64) -> &mut Hyperparameters {
50 self.svm_parameter.C = C;
51 self
52 }
53
54 pub fn degree(&mut self, degree: i32) -> &mut Hyperparameters {
57 self.svm_parameter.degree = degree;
58 self
59 }
60
61 pub fn gamma(&mut self, gamma: f64) -> &mut Hyperparameters {
64 self.svm_parameter.gamma = gamma;
65 self
66 }
67
68 pub fn coef0(&mut self, coef0: f64) -> &mut Hyperparameters {
71 self.svm_parameter.coef0 = coef0;
72 self
73 }
74
75 pub fn cache_size(&mut self, cache_size: f64) -> &mut Hyperparameters {
78 self.svm_parameter.cache_size = cache_size;
79 self
80 }
81
82 pub fn build(&self) -> SVC {
86 SVC {
87 dim: self.dim,
88 hyperparams: self.to_owned(),
89 model: None,
90 }
91 }
92
93 fn svm_parameter(&self) -> &ffi::SvmParameter {
94 &self.svm_parameter
95 }
96}
97
/// Support vector classifier backed by the `ffi` SVM implementation.
#[derive(Clone, Serialize, Deserialize)]
pub struct SVC {
    // Expected number of input features.
    dim: usize,
    // Hyperparameters this model was built with.
    hyperparams: Hyperparameters,
    // Fitted model state; `None` until `fit` has succeeded.
    model: Option<ffi::SvmModel>,
}
105
// Implements `SupervisedModel` for `SVC` over a given input matrix type
// (instantiated below for dense `Array` and `SparseRowArray`).
macro_rules! impl_supervised_model {
    ($x_type:ty) => {
        impl<'a> SupervisedModel<&'a $x_type> for SVC {
            /// Fit the model to `X` with targets `y`.
            fn fit(&mut self, X: &$x_type, y: &Array) -> Result<(), &'static str> {
                try!(check_data_dimensionality(self.dim, X));
                try!(check_matched_dimensions(X, y));

                let svm_params = self.hyperparams.svm_parameter();

                // `svm_params` is already a `&SvmParameter`; pass it
                // directly rather than taking another reference.
                self.model = Some(try!(ffi::fit(X, y, svm_params)));

                Ok(())
            }

            /// Return the decision function values for `X`.
            ///
            /// Errors if `X` has the wrong dimensionality or the model
            /// has not been fit.
            fn decision_function(&self, X: &$x_type) -> Result<Array, &'static str> {
                try!(check_data_dimensionality(self.dim, X));

                match self.model {
                    Some(ref model) => {
                        let (decision_function, _) = ffi::predict(model, X);
                        Ok(decision_function)
                    }
                    None => Err("Model must be fit before predicting."),
                }
            }

            /// Return the predicted class labels for `X`.
            fn predict(&self, X: &$x_type) -> Result<Array, &'static str> {
                // Fix: `predict` previously skipped the dimensionality
                // check performed by both `fit` and `decision_function`,
                // letting mis-shaped input reach the FFI layer unchecked.
                try!(check_data_dimensionality(self.dim, X));

                match self.model {
                    Some(ref model) => {
                        let (_, predicted_class) = ffi::predict(model, X);
                        Ok(predicted_class)
                    }
                    None => Err("Model must be fit before predicting."),
                }
            }
        }
    };
}
144
// Generate `SupervisedModel` impls for dense and sparse input matrices.
impl_supervised_model!(Array);
impl_supervised_model!(SparseRowArray);
147
#[cfg(test)]
mod tests {
    use super::*;

    use rand::{SeedableRng, StdRng};

    use prelude::*;

    use cross_validation::cross_validation::CrossValidation;
    use datasets::iris::load_data;
    use metrics::accuracy_score;

    use bincode;

    #[cfg(feature = "all_tests")]
    use datasets::newsgroups;

    // Generates a cross-validated iris accuracy test for the given kernel.
    // Asserts mean test-fold accuracy exceeds 0.97.
    macro_rules! test_iris_kernel {
        ($kernel:expr, $fn_name:ident) => {
            #[test]
            fn $fn_name() {
                let (data, target) = load_data();

                let mut test_accuracy = 0.0;
                let mut train_accuracy = 0.0;

                let no_splits = 10;

                let mut cv = CrossValidation::new(data.rows(), no_splits);
                // Fixed seed keeps the folds — and so the asserted
                // accuracy — deterministic across runs.
                cv.set_rng(StdRng::from_seed(&[100]));

                for (train_idx, test_idx) in cv {
                    let x_train = data.get_rows(&train_idx);
                    let x_test = data.get_rows(&test_idx);

                    let y_train = target.get_rows(&train_idx);

                    // 3 classes in the iris dataset.
                    let mut model = Hyperparameters::new(data.cols(), $kernel, 3).build();

                    model.fit(&x_train, &y_train).unwrap();

                    let y_hat = model.predict(&x_test).unwrap();

                    test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
                    train_accuracy += accuracy_score(&y_train, &model.predict(&x_train).unwrap());
                }

                // Average the per-fold accuracies.
                test_accuracy /= no_splits as f32;
                train_accuracy /= no_splits as f32;

                println!("Accuracy {}", test_accuracy);
                println!("Train accuracy {}", train_accuracy);
                assert!(test_accuracy > 0.97);
            }
        };
    }

    test_iris_kernel!(KernelType::Linear, test_iris_linear);
    test_iris_kernel!(KernelType::Polynomial, test_iris_polynomial);
    test_iris_kernel!(KernelType::RBF, test_iris_rbf);

    // Same cross-validated iris test, but with the data converted to a
    // sparse matrix to exercise the `SparseRowArray` impl.
    #[test]
    fn test_sparse_iris() {
        let (dense_data, target) = load_data();
        let data = SparseRowArray::from(&dense_data);

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), KernelType::Linear, 3).build();

            model.fit(&x_train, &y_train).unwrap();

            let y_hat = model.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&y_train, &model.predict(&x_train).unwrap());
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);
        assert!(test_accuracy > 0.97);
    }

    // Checks that a fitted model survives a bincode round trip: the
    // deserialized copy must predict as accurately as the original.
    #[test]
    fn serialization() {
        let (data, target) = load_data();

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), KernelType::Linear, 3).build();

            model.fit(&x_train, &y_train).unwrap();

            // Round-trip the fitted model through bincode, then predict
            // with the deserialized copy only.
            let encoded = bincode::serialize(&model).unwrap();
            let decoded: SVC = bincode::deserialize(&encoded).unwrap();

            let y_hat = decoded.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&y_train, &decoded.predict(&x_train).unwrap());
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);
        assert!(test_accuracy > 0.97);
    }

    // Larger 20-class text-classification smoke test; only built with
    // the `all_tests` feature because the dataset is expensive to load.
    #[test]
    #[cfg(feature = "all_tests")]
    fn test_newsgroups() {
        let (X, target) = newsgroups::load_data();

        let no_splits = 2;
        let mut test_accuracy = 0.0;

        let mut cv = CrossValidation::new(X.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = X.get_rows(&train_idx);
            let x_test = X.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            // 20 newsgroups -> 20 classes.
            let mut model = Hyperparameters::new(X.cols(), KernelType::Linear, 20).build();

            model.fit(&x_train, &y_train).unwrap();

            let y_hat = model.predict(&x_test).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
        }

        test_accuracy /= no_splits as f32;
        println!("{}", test_accuracy);

        assert!(test_accuracy > 0.8);
    }
}