rustlearn/factorization/factorization_machines.rs

//! A factorization machine model implemented using stochastic gradient descent.
//!
//! A [factorization machine](http://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) (Rendle 2010)
//! model combines the advantages of linear and factorization models. In this implementation, it approximates
//! second-order feature interactions (as in a quadratic SVM) via reduced-rank matrix factorization.
//! This allows it to estimate feature interactions even in sparse datasets (like recommender systems) where
//! traditional polynomial SVMs fail.
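//!
//! For an input vector `x`, the raw score (before the sigmoid link) follows the
//! standard FM formulation, with linear weights `w` and per-feature latent factor
//! rows `v_i` (note that this implementation has no global bias term):
//!
//! ```text
//! y(x) = Σ_i w_i x_i + Σ_{i<j} ⟨v_i, v_j⟩ x_i x_j
//! ```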
//!
//! The complexity of the model is controlled by the dimensionality of the factorization matrix:
//! a higher setting will make the model more expressive at the expense of training time and
//! risk of overfitting.
//!
//! # Parallelism
//!
//! The model supports multithreaded model fitting via asynchronous stochastic
//! gradient descent (Hogwild).
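//!
//! As a sketch, a two-class model (built with `build()` rather than
//! `one_vs_rest()`) can be fitted with four Hogwild threads like so,
//! with `X` and `y` as in the example below:
//!
//! ```ignore
//! let mut model = Hyperparameters::new(4, 10).build();
//! model.fit_parallel(&X, &y, 4).unwrap();
//! ```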
//!
//! # Examples
//!
//! ```
//! use rustlearn::prelude::*;
//! use rustlearn::factorization::factorization_machines::Hyperparameters;
//! use rustlearn::datasets::iris;
//!
//! let (X, y) = iris::load_data();
//!
//! let mut model = Hyperparameters::new(4, 10)
//!                                 .one_vs_rest();
//!
//! model.fit(&X, &y).unwrap();
//!
//! let prediction = model.predict(&X).unwrap();
//! ```
#![allow(non_snake_case)]

use std::cmp;

use prelude::*;

use multiclass::OneVsRestWrapper;
use utils::{
    check_data_dimensionality, check_matched_dimensions, check_valid_labels, EncodableRng,
};

use rand;
use rand::distributions::IndependentSample;

use crossbeam;

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

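// Gradient of the logistic loss with respect to the raw (pre-sigmoid) score:
// for y_hat = sigmoid(z), d loss / d z simplifies to y_hat - y. The function
// returns the gradient rather than the loss value itself.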
fn logistic_loss(y: f32, y_hat: f32) -> f32 {
    y_hat - y
}

macro_rules! max {
    ($x:expr, $y:expr) => {{
        match $x > $y {
            true => $x,
            false => $y,
        }
    }};
}

macro_rules! min {
    ($x:expr, $y:expr) => {{
        match $x < $y {
            true => $x,
            false => $y,
        }
    }};
}

/// Hyperparameters for a FactorizationMachine
#[derive(Serialize, Deserialize)]
pub struct Hyperparameters {
    dim: usize,
    num_components: usize,

    learning_rate: f32,
    l2_penalty: f32,
    l1_penalty: f32,
    rng: EncodableRng,
}

impl Hyperparameters {
    /// Creates new Hyperparameters.
    ///
    /// The complexity of the model is controlled by the dimensionality of the factorization matrix:
    /// a higher `num_components` setting will make the model more expressive
    /// at the expense of training time and risk of overfitting.
    pub fn new(dim: usize, num_components: usize) -> Hyperparameters {
        Hyperparameters {
            dim: dim,
            num_components: num_components,
            learning_rate: 0.05,
            l2_penalty: 0.0,
            l1_penalty: 0.0,
            rng: EncodableRng::new(),
        }
    }

    /// Set the initial learning rate.
    ///
    /// During fitting, the learning rate decreases more for parameters that have
    /// received larger gradient updates. This maintains more stable estimates
    /// for common features while allowing fast learning for rare features.
    pub fn learning_rate(&mut self, learning_rate: f32) -> &mut Hyperparameters {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the L2 penalty.
    pub fn l2_penalty(&mut self, l2_penalty: f32) -> &mut Hyperparameters {
        self.l2_penalty = l2_penalty;
        self
    }

    /// Set the L1 penalty.
    pub fn l1_penalty(&mut self, l1_penalty: f32) -> &mut Hyperparameters {
        self.l1_penalty = l1_penalty;
        self
    }

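    /// Set the random number generator used for initializing the latent factors.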
    pub fn rng(&mut self, rng: rand::StdRng) -> &mut Hyperparameters {
        self.rng.rng = rng;
        self
    }

    /// Build a two-class model.
    pub fn build(&self) -> FactorizationMachine {
        let mut rng = self.rng.clone();

        FactorizationMachine {
            dim: self.dim,
            num_components: self.num_components,

            learning_rate: self.learning_rate,
            l2_penalty: self.l2_penalty,
            l1_penalty: self.l1_penalty,

            coefficients: Array::zeros(self.dim, 1),
            latent_factors: self.init_latent_factors_array(&mut rng),
            gradsq: Array::ones(self.dim, 1),
            latent_gradsq: Array::ones(self.dim, self.num_components),
            applied_l2: Array::ones(self.dim, 1),
            applied_l1: Array::zeros(self.dim, 1),
            latent_applied_l2: Array::ones(self.dim, self.num_components),
            latent_applied_l1: Array::zeros(self.dim, self.num_components),
            accumulated_l2: 1.0,
            accumulated_l1: 0.0,

            rng: rng,
        }
    }

    /// Initialize the latent factors.
    fn init_latent_factors_array(&self, rng: &mut EncodableRng) -> Array {
        let mut data = Vec::with_capacity(self.dim * self.num_components);
        let normal =
            rand::distributions::normal::Normal::new(0.0, 1.0 / self.num_components as f64);

        for _ in 0..(self.dim * self.num_components) {
            data.push(normal.ind_sample(&mut rng.rng) as f32)
        }

        let mut array = Array::from(data);
        array.reshape(self.dim, self.num_components);
        array
    }

    /// Build a one-vs-rest multiclass model.
    #[allow(dead_code)]
    pub fn one_vs_rest(&self) -> OneVsRestWrapper<FactorizationMachine> {
        let base_model = self.build();

        OneVsRestWrapper::new(base_model)
    }
}

/// A two-class factorization machine implemented using stochastic gradient descent.
#[derive(Clone, Serialize, Deserialize)]
pub struct FactorizationMachine {
    dim: usize,
    num_components: usize,

    learning_rate: f32,
    l2_penalty: f32,
    l1_penalty: f32,

    coefficients: Array,
    latent_factors: Array,
    gradsq: Array,
    latent_gradsq: Array,
    applied_l2: Array,
    applied_l1: Array,
    latent_applied_l2: Array,
    latent_applied_l1: Array,
    accumulated_l2: f32,
    accumulated_l1: f32,

    rng: EncodableRng,
}

impl FactorizationMachine {
    fn compute_prediction<T: NonzeroIterable>(&self, row: &T, component_sum: &mut [f32]) -> f32 {
        let mut result = 0.0;

        for (feature_idx, feature_value) in row.iter_nonzero() {
            result += feature_value * self.coefficients.get(feature_idx, 0);
        }

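        // Pairwise interaction term, computed with the O(k * nnz) identity
        // from the FM paper: for each latent dimension,
        //     sum_{i<j} v_i v_j x_i x_j = 0.5 * ((sum_i v_i x_i)^2 - sum_i (v_i x_i)^2).
        // The per-dimension sums are stashed in `component_sum` so that the
        // gradient step can reuse them.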
        for component_idx in 0..self.num_components {
            let mut component_sum_elem = 0.0;
            let mut component_sum_sq_elem = 0.0;

            for (feature_idx, feature_value) in row.iter_nonzero() {
                let val = self.latent_factors.get(feature_idx, component_idx) * feature_value;
                component_sum_elem += val;
                component_sum_sq_elem += val.powi(2);
            }

            component_sum[component_idx] = component_sum_elem;

            result += 0.5 * (component_sum_elem.powi(2) - component_sum_sq_elem);
        }

        result
    }

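    /// Apply any outstanding regularization to a single parameter.
    ///
    /// Regularization is applied lazily: `accumulated_l2`/`accumulated_l1` grow
    /// with every sample seen, while `applied_l2`/`applied_l1` track how much has
    /// already been applied to this parameter. A parameter is only caught up when
    /// its feature occurs in a row, which keeps the updates sparse.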
    fn apply_regularization(
        parameter_value: &mut f32,
        applied_l2: &mut f32,
        applied_l1: &mut f32,
        local_learning_rate: f32,
        accumulated_l2: f32,
        accumulated_l1: f32,
    ) {
        let l2_update = accumulated_l2 / *applied_l2;

        *parameter_value *= 1.0 - (1.0 - l2_update) * local_learning_rate;
        *applied_l2 *= l2_update;

        let l1_potential_update = accumulated_l1 - *applied_l1;
        let pre_update_coeff = *parameter_value;

        if *parameter_value > 0.0 {
            *parameter_value = max!(
                0.0,
                *parameter_value - l1_potential_update * local_learning_rate
            );
        } else {
            *parameter_value = min!(
                0.0,
                *parameter_value + l1_potential_update * local_learning_rate
            );
        }

        let l1_actual_update = (pre_update_coeff - *parameter_value).abs();
        *applied_l1 += l1_actual_update;
    }

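    /// Perform one SGD step on a single row: Adagrad-style per-parameter
    /// learning rates for both the linear coefficients and the latent factors,
    /// followed by lazily applied L1/L2 regularization.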
    fn update<T: NonzeroIterable>(&mut self, row: T, loss: f32, component_sum: &[f32]) {
        for (feature_idx, feature_value) in (&row).iter_nonzero() {
            // Update coefficients
            let gradsq = self.gradsq.get_mut(feature_idx, 0);
            let local_learning_rate = self.learning_rate / gradsq.sqrt();
            let coefficient_value = self.coefficients.get_mut(feature_idx, 0);

            let applied_l2 = self.applied_l2.get_mut(feature_idx, 0);
            let applied_l1 = self.applied_l1.get_mut(feature_idx, 0);

            let gradient = loss * feature_value;

            *coefficient_value -= local_learning_rate * gradient;
            *gradsq += gradient.powi(2);

            FactorizationMachine::apply_regularization(
                coefficient_value,
                applied_l2,
                applied_l1,
                local_learning_rate,
                self.accumulated_l2,
                self.accumulated_l1,
            );

            // Update latent factors
            let slice_start = feature_idx * self.num_components;
            let slice_stop = slice_start + self.num_components;

            let component_row = &mut self.latent_factors.as_mut_slice()[slice_start..slice_stop];
            let gradsq_row = &mut self.latent_gradsq.as_mut_slice()[slice_start..slice_stop];
            let applied_l2_row =
                &mut self.latent_applied_l2.as_mut_slice()[slice_start..slice_stop];
            let applied_l1_row =
                &mut self.latent_applied_l1.as_mut_slice()[slice_start..slice_stop];

            for (component_value, (gradsq, (applied_l2, (applied_l1, component_sum_value)))) in
                component_row.iter_mut().zip(
                    gradsq_row.iter_mut().zip(
                        applied_l2_row
                            .iter_mut()
                            .zip(applied_l1_row.iter_mut().zip(component_sum.iter())),
                    ),
                ) {
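                // Partial derivative of the prediction w.r.t. this latent
                // factor: x_i * (sum_j v_j x_j) - v_i * x_i^2, where the inner
                // sum is the cached `component_sum_value` from the forward pass.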
                let local_learning_rate = self.learning_rate / gradsq.sqrt();
                let update = loss * ((component_sum_value * feature_value)
                    - (*component_value * feature_value.powi(2)));

                *component_value -= local_learning_rate * update;
                *gradsq += update.powi(2);

                FactorizationMachine::apply_regularization(
                    component_value,
                    applied_l2,
                    applied_l1,
                    local_learning_rate,
                    self.accumulated_l2,
                    self.accumulated_l1,
                );
            }
        }
    }

    fn accumulate_regularization(&mut self) {
        self.accumulated_l2 *= 1.0 - self.l2_penalty;
        self.accumulated_l1 += self.l1_penalty;
    }

    fn fit_sigmoid<'a, T>(&mut self, X: &'a T, y: &Array) -> Result<(), &'static str>
    where
        T: IndexableMatrix,
        &'a T: RowIterable,
    {
        let component_sum = &mut vec![0.0; self.num_components][..];

        for (row, &true_y) in X.iter_rows().zip(y.data().iter()) {
            let y_hat = sigmoid(self.compute_prediction(&row, component_sum));

            let loss = logistic_loss(true_y, y_hat);

            self.update(row, loss, component_sum);

            self.accumulate_regularization();
        }

        self.regularize_all();

        Ok(())
    }

    /// Perform a dummy update pass over all features to force regularization to be applied.
    fn regularize_all(&mut self) {
        if self.l1_penalty == 0.0 && self.l2_penalty == 0.0 {
            return;
        }

        let array = Array::ones(1, self.dim);
        let num_components = self.num_components;

        self.update(&array.view_row(0), 0.0, &vec![0.0; num_components][..]);

        self.accumulated_l2 = 1.0;
        self.accumulated_l1 = 0.0;
    }

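    /// Return a reference to the estimated linear coefficients.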
    pub fn get_coefficients(&self) -> &Array {
        &self.coefficients
    }

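    /// Return a reference to the estimated latent factor matrix.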
    pub fn get_latent_factors(&self) -> &Array {
        &self.latent_factors
    }
}

impl<'a, T> SupervisedModel<&'a T> for FactorizationMachine
where
    &'a T: RowIterable,
    T: IndexableMatrix,
{
    fn fit(&mut self, X: &'a T, y: &Array) -> Result<(), &'static str> {
        try!(check_data_dimensionality(self.dim, X));
        try!(check_matched_dimensions(X, y));
        try!(check_valid_labels(y));

        self.fit_sigmoid(X, y)
    }

    fn decision_function(&self, X: &'a T) -> Result<Array, &'static str> {
        try!(check_data_dimensionality(self.dim, X));

        let mut data = Vec::with_capacity(X.rows());

        let component_sum = &mut vec![0.0; self.num_components][..];

        for row in X.iter_rows() {
            let prediction = self.compute_prediction(&row, component_sum);
            data.push(sigmoid(prediction));
        }

        Ok(Array::from(data))
    }
}

impl<'a, T> ParallelSupervisedModel<&'a T> for FactorizationMachine
where
    &'a T: RowIterable,
    T: IndexableMatrix + Sync,
{
    fn fit_parallel(
        &mut self,
        X: &'a T,
        y: &Array,
        num_threads: usize,
    ) -> Result<(), &'static str> {
        try!(check_data_dimensionality(self.dim, X));
        try!(check_matched_dimensions(X, y));
        try!(check_valid_labels(y));

        let rows_per_thread = X.rows() / num_threads + 1;
        let num_components = self.num_components;

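        // Hogwild: each worker gets an unsynchronized mutable alias of the
        // model and applies updates without locking. Conflicting writes are
        // tolerated; with sparse inputs collisions are rare enough that SGD
        // still converges.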
        let model_ptr = unsafe { &*(self as *const FactorizationMachine) };

        crossbeam::scope(|scope| {
            for thread_num in 0..num_threads {
                scope.spawn(move || {
                    let start = thread_num * rows_per_thread;
                    let stop = cmp::min((thread_num + 1) * rows_per_thread, X.rows());

                    let mut component_sum = vec![0.0; num_components];

                    let model = unsafe {
                        &mut *(model_ptr as *const FactorizationMachine
                            as *mut FactorizationMachine)
                    };

                    for (row, &true_y) in X
                        .iter_rows_range(start..stop)
                        .zip(y.data()[start..stop].iter())
                    {
                        let y_hat = sigmoid(model.compute_prediction(&row, &mut component_sum[..]));
                        let loss = logistic_loss(true_y, y_hat);
                        model.update(row, loss, &mut component_sum[..]);
                        model.accumulate_regularization();
                    }
                });
            }
        });

        self.regularize_all();

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use rand::{SeedableRng, StdRng};

    use prelude::*;

    use cross_validation::cross_validation::CrossValidation;
    use datasets::iris::load_data;
    use metrics::accuracy_score;
    use multiclass::OneVsRest;

    #[cfg(feature = "all_tests")]
    use datasets::newsgroups;

    use super::*;

    #[test]
    fn basic_updating() {
        let mut model = Hyperparameters::new(2, 2)
            .learning_rate(0.01)
            .l2_penalty(0.0)
            .l1_penalty(0.0)
            .build();

        // Set to zero to allow coefficient update tests
        // to be straightforward
        for elem in model.latent_factors.as_mut_slice().iter_mut() {
            *elem = 0.0;
        }

        let y = Array::ones(1, 1);
        let X = Array::from(&vec![vec![1.0, -0.1]]);

        model.fit(&X, &y).unwrap();

        assert!(model.gradsq.data()[0] > 1.0);
        assert!(model.gradsq.data()[1] > 1.0);

        assert!(model.coefficients.data()[0] == 0.005);
        assert!(model.coefficients.data()[1] == -0.0005);

        model.fit(&X, &y).unwrap();
        println!("model coefficients {:?}", model.coefficients.data());

        assert!(model.coefficients.data()[0] == 0.009460844);
        assert!(model.coefficients.data()[1] == -0.0009981153);
    }

    #[test]
    fn test_basic_l1() {
        let mut model = Hyperparameters::new(2, 2)
            .learning_rate(0.01)
            .l2_penalty(0.0)
            .l1_penalty(100.0)
            .rng(StdRng::from_seed(&[100]))
            .build();

        let y = Array::ones(1, 1);
        let X = Array::from(&vec![vec![1.0, -0.1]]);

        for &elem in model.latent_factors.data() {
            assert!(elem != 0.0);
        }

        model.fit(&X, &y).unwrap();

        assert!(model.gradsq.data()[0] > 1.0);
        assert!(model.gradsq.data()[1] > 1.0);

        // All the coefficients/factors should
        // have been regularized away to zero.
        for &elem in model.coefficients.data() {
            assert!(elem == 0.0);
        }

        for &elem in model.latent_factors.data() {
            assert!(elem == 0.0);
        }
    }

    #[test]
    fn test_iris() {
        let (data, target) = load_data();

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), 5)
                .learning_rate(0.05)
                .l2_penalty(0.0)
                .rng(StdRng::from_seed(&[100]))
                .one_vs_rest();

            for _ in 0..20 {
                model.fit(&x_train, &y_train).unwrap();
            }

            let y_hat = model.predict(&x_test).unwrap();
            let y_hat_train = model.predict(&x_train).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&target.get_rows(&train_idx), &y_hat_train);
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);

        assert!(test_accuracy > 0.94);
    }

    #[test]
    fn test_iris_parallel() {
        let (data, target) = load_data();

        // Get a binary target so that the parallelism
        // goes through the FM model and not through the
        // OvR wrapper.
        let (_, target) = OneVsRest::split(&target).next().unwrap();

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let no_splits = 10;

        let mut cv = CrossValidation::new(data.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = data.get_rows(&train_idx);
            let x_test = data.get_rows(&test_idx);

            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(data.cols(), 5)
                .learning_rate(0.05)
                .l2_penalty(0.0)
                .rng(StdRng::from_seed(&[100]))
                .build();

            for _ in 0..20 {
                model.fit_parallel(&x_train, &y_train, 4).unwrap();
            }

            let y_hat = model.predict(&x_test).unwrap();
            let y_hat_train = model.predict(&x_train).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&target.get_rows(&train_idx), &y_hat_train);
        }

        test_accuracy /= no_splits as f32;
        train_accuracy /= no_splits as f32;

        println!("Accuracy {}", test_accuracy);
        println!("Train accuracy {}", train_accuracy);

        assert!(test_accuracy > 0.94);
    }

    #[test]
    #[cfg(feature = "all_tests")]
    fn test_fm_newsgroups() {
        let (X, target) = newsgroups::load_data();

        let no_splits = 2;

        let mut test_accuracy = 0.0;
        let mut train_accuracy = 0.0;

        let mut cv = CrossValidation::new(X.rows(), no_splits);
        cv.set_rng(StdRng::from_seed(&[100]));

        for (train_idx, test_idx) in cv {
            let x_train = X.get_rows(&train_idx);

            let x_test = X.get_rows(&test_idx);
            let y_train = target.get_rows(&train_idx);

            let mut model = Hyperparameters::new(X.cols(), 10)
                .learning_rate(0.005)
                .rng(StdRng::from_seed(&[100]))
                .one_vs_rest();

            for _ in 0..5 {
                model.fit(&x_train, &y_train).unwrap();
                println!("fit");
            }

            let y_hat = model.predict(&x_test).unwrap();
            let y_hat_train = model.predict(&x_train).unwrap();

            test_accuracy += accuracy_score(&target.get_rows(&test_idx), &y_hat);
            train_accuracy += accuracy_score(&target.get_rows(&train_idx), &y_hat_train);
        }

        train_accuracy /= no_splits as f32;
        test_accuracy /= no_splits as f32;

        println!("Train accuracy {}", train_accuracy);
        println!("Test accuracy {}", test_accuracy);

        assert!(train_accuracy > 0.95);
        assert!(test_accuracy > 0.65);
    }
}

#[cfg(feature = "bench")]
#[allow(unused_imports)]
mod bench {
    use rand::{SeedableRng, StdRng};

    use test::Bencher;

    use prelude::*;

    use cross_validation::cross_validation::CrossValidation;
    use datasets::iris::load_data;
    use datasets::newsgroups;
    use metrics::{accuracy_score, roc_auc_score};
    use multiclass::OneVsRest;

    use super::*;

    #[bench]
    fn bench_iris_sparse(b: &mut Bencher) {
        let (data, target) = load_data();

        let sparse_data = SparseRowArray::from(&data);

        let mut model = Hyperparameters::new(data.cols(), 5)
            .learning_rate(0.05)
            .l2_penalty(0.0)
            .rng(StdRng::from_seed(&[100]))
            .one_vs_rest();

        b.iter(|| {
            model.fit(&sparse_data, &target).unwrap();
        });
    }

    #[bench]
    fn bench_fm_newsgroups(b: &mut Bencher) {
        let (X, target) = newsgroups::load_data();
        let (_, target) = OneVsRest::split(&target).next().unwrap();

        let X = X.get_rows(&(..500));
        let target = target.get_rows(&(..500));

        let mut model = Hyperparameters::new(X.cols(), 10)
            .rng(StdRng::from_seed(&[100]))
            .build();

        b.iter(|| {
            model.fit(&X, &target).unwrap();
        });
    }

    #[bench]
    fn bench_fm_newsgroups_parallel(b: &mut Bencher) {
        let (X, target) = newsgroups::load_data();
        let (_, target) = OneVsRest::split(&target).next().unwrap();

        let X = X.get_rows(&(..500));
        let target = target.get_rows(&(..500));

        let mut model = Hyperparameters::new(X.cols(), 10)
            .rng(StdRng::from_seed(&[100]))
            .build();

        b.iter(|| {
            model.fit_parallel(&X, &target, 2).unwrap();
        });
    }
}