// rust_als/als.rs

use hashbrown::HashMap;
use nalgebra::{DMatrix, DVector};
use rand::{Rng, thread_rng};
use rayon::prelude::*;

use crate::errors::ALSError;

const DEFAULT_ITERATIONS: usize = 10;
const DEFAULT_EPS: f64 = 1.0e-9;
const DEFAULT_REG: f64 = 1.0;

type T = f64;
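/// A single sparse entry of R given as (row index, column index, value).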
pub type RTriplet<T> = (usize, usize, T);

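/// Alternating least squares factorization R ≈ XᵀY of an N x M sparse matrix R.
///
/// R is stored twice (row-first and column-first) for fast access during the
/// alternating updates; X and Y hold one K-dimensional feature vector per row
/// and column of R, respectively.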
pub struct ALS<T> {
    n: usize,
    m: usize,
    k: usize,
    r_row_first: HashMap<usize, HashMap<usize, T>>,
    r_col_first: HashMap<usize, HashMap<usize, T>>,
    x_mat: Vec<DVector<T>>,
    y_mat: Vec<DVector<T>>,
    default_iters: usize,
    default_regularization: T,
}

impl ALS<T> {

    /// Constructs a new ALS learner for an initially empty sparse matrix R of size N x M,
    /// using K features for X and Y.
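    ///
    /// A minimal usage sketch (marked `ignore`, since the surrounding crate layout is assumed):
    ///
    /// ```ignore
    /// let mut als = ALS::new(100, 50, 8);     // 100 x 50 sparse R, K = 8 features
    /// als.add((0, 3, 4.5)).ok();              // observe R[0, 3] = 4.5
    /// als.train();                            // run the default number of ALS sweeps
    /// let prediction = als.predict_r_val(0, 3);
    /// ```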
    pub fn new(n: usize, m: usize, k: usize) -> Self {
        let mut als = ALS {
            n,
            m,
            k,
            r_row_first: HashMap::new(),
            r_col_first: HashMap::new(),
            x_mat: vec![],
            y_mat: vec![],
            default_iters: DEFAULT_ITERATIONS,
            default_regularization: DEFAULT_REG,
        };
        als.init_y();
        als.init_x();
        als
    }

    /// Adds a value to the sparse matrix R, overwriting and returning any previous value
    /// stored at the same indices.
    pub fn add(&mut self, e: RTriplet<T>) -> Result<Option<T>, ALSError<T>> {
        if e.0 >= self.n {
            return Err(ALSError::InvalidTripletError(e, format!("{} exceeds row index range for R = {}x{}", e.0, self.n, self.m)));
        }
        if e.1 >= self.m {
            return Err(ALSError::InvalidTripletError(e, format!("{} exceeds column index range of R = {}x{}", e.1, self.n, self.m)));
        }

        // Both orderings hold the same entries, so the previous value (if any) only
        // needs to be captured from the row-first map.
        let previous_entry_val = self.r_row_first
            .entry(e.0)
            .or_insert_with(HashMap::new)
            .insert(e.1, e.2);

        self.r_col_first
            .entry(e.1)
            .or_insert_with(HashMap::new)
            .insert(e.0, e.2);

        Ok(previous_entry_val)
    }

    /// Resets all entries of X with values uniformly sampled from [0, 1 / sqrt(K)).
    pub fn reset_x(&mut self) {
        let upper_init_bound: T = 1.0 / (self.k as T).sqrt();
        self.x_mat.par_iter_mut().for_each(|x_col| {
            x_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
        });
    }

    /// Resets all entries of Y with values uniformly sampled from [0, 1 / sqrt(K)).
    pub fn reset_y(&mut self) {
        let upper_init_bound: T = 1.0 / (self.k as T).sqrt();
        self.y_mat.par_iter_mut().for_each(|y_col| {
            y_col.fill_with(|| thread_rng().gen_range(0.0..upper_init_bound))
        });
    }

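    // Initial factor entries are drawn from [0, 1 / sqrt(K)), which keeps every
    // initial prediction x_nᵀ y_m in [0, 1).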
    fn init_x(&mut self) {
        self.x_mat = Vec::with_capacity(self.n);
        let upper_init_bound: T = 1.0 / (self.k as T).sqrt();
        self.x_mat.par_extend((0..self.n).into_par_iter()
            .map(|_| DVector::<T>::from_fn(
                self.k,
                |_, _| thread_rng().gen_range(0.0..upper_init_bound))));
    }

    fn init_y(&mut self) {
        self.y_mat = Vec::with_capacity(self.m);
        let upper_init_bound: T = 1.0 / (self.k as T).sqrt();
        self.y_mat.par_extend((0..self.m).into_par_iter()
            .map(|_| DVector::<T>::from_fn(
                self.k,
                |_, _| thread_rng().gen_range(0.0..upper_init_bound))));
    }

    /// Clears all entries of R.
    pub fn reset_r(&mut self) {
        self.r_row_first = HashMap::new();
        self.r_col_first = HashMap::new();
    }

    /// Sets the regularization factor.
    pub fn set_regularization(&mut self, lambda: T) {
        self.default_regularization = lambda;
    }

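    /// Sets the default number of iterations used by `train()`.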
    pub fn set_default_iters(&mut self, iters: usize) {
        self.default_iters = iters;
    }

    /// Trains for the specified number of iterations.
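    ///
    /// Each iteration first fixes Y and solves, for every row n with observed entries R_n,
    ///     x_n = (Σ_{m ∈ R_n} y_m y_mᵀ + λI)⁻¹ Σ_{m ∈ R_n} r_nm y_m,
    /// then fixes X and performs the symmetric update for every observed column of R.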
    pub fn train_for(&mut self, iters: usize) {
        self.ensure_x_y_existence();
        let mut precomp_yyt: HashMap<usize, DMatrix<T>> = HashMap::with_capacity(self.m);
        let mut precomp_xxt: HashMap<usize, DMatrix<T>> = HashMap::with_capacity(self.n);
        let reg_diag = DMatrix::<T>::from_diagonal_element(self.k, self.k, self.default_regularization);
        precomp_yyt.par_extend(
            self.r_col_first.par_keys()
                .map(|i_m| {
                    (*i_m, DMatrix::<T>::zeros(self.k, self.k))
                })
        );
        precomp_xxt.par_extend(
            self.r_row_first.par_keys()
                .map(|i_n| {
                    (*i_n, DMatrix::<T>::zeros(self.k, self.k))
                })
        );
        for _ in 0..iters {
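            // Refresh the cached outer products y_m * y_mᵀ for every column that occurs in R.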
            precomp_yyt.par_iter_mut().for_each(|(i_m, kk_term)| {
                let y_i = &self.y_mat[*i_m];
                y_i.mul_to(&y_i.transpose(), kk_term);
            });

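            // With Y fixed, solve the regularized K x K normal equations for each row factor x_n.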
            self.x_mat.par_iter_mut().enumerate().for_each(|(i_n, x_row)| {
                if let Some(r_row) = self.r_row_first.get(&i_n) {
                    let mut first_sum = reg_diag.clone();
                    let mut second_sum: DVector<T> = DVector::zeros(self.k);
                    r_row.iter().for_each(|(i_m, r_nm)| {
                        first_sum += precomp_yyt.get(i_m).unwrap();
                        second_sum += &(&self.y_mat[*i_m] * *r_nm);
                    });
                    if !first_sum.try_inverse_mut() {
                        first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
                    }
                    first_sum.mul_to(&second_sum, x_row);
                }
            });

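            // Refresh the cached outer products x_n * x_nᵀ for every row that occurs in R.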
            precomp_xxt.par_iter_mut().for_each(|(i_n, kk_term)| {
                let x_i = &self.x_mat[*i_n];
                x_i.mul_to(&x_i.transpose(), kk_term);
            });

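            // With X fixed, solve the regularized K x K normal equations for each column factor y_m.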
            self.y_mat.par_iter_mut().enumerate().for_each(|(i_m, y_row)| {
                if let Some(r_col) = self.r_col_first.get(&i_m) {
                    let mut first_sum = reg_diag.clone();
                    let mut second_sum: DVector<T> = DVector::zeros(self.k);
                    r_col.iter().for_each(|(i_n, r_nm)| {
                        first_sum += precomp_xxt.get(i_n).unwrap();
                        second_sum += &(&self.x_mat[*i_n] * *r_nm);
                    });
                    if !first_sum.try_inverse_mut() {
                        first_sum = first_sum.pseudo_inverse(DEFAULT_EPS).unwrap();
                    }
                    first_sum.mul_to(&second_sum, y_row);
                }
            });
        }
    }

    fn ensure_x_y_existence(&mut self) {
        if self.x_mat.len() != self.n {
            self.init_x();
        }

        if self.y_mat.len() != self.m {
            self.init_y();
        }
    }

    /// Trains for the default number of iterations set for the instance.
    pub fn train(&mut self) {
        self.train_for(self.default_iters);
    }

    /// Gets the K-dimensional feature vector of a row, if the index is in range.
    pub fn get_row_factors(&self, row: usize) -> Option<&DVector<T>> {
        self.x_mat.get(row)
    }
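    /// Gets the K-dimensional feature vector of a column, if the index is in range.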
    pub fn get_col_factors(&self, col: usize) -> Option<&DVector<T>> {
        self.y_mat.get(col)
    }

    pub fn get_x(&self) -> &Vec<DVector<T>> {
        &self.x_mat
    }

    pub fn get_y(&self) -> &Vec<DVector<T>> {
        &self.y_mat
    }

    /// Computes the regularized squared-error cost between the factorization XᵀY and R.
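    ///
    /// C = Σ_{(n,m) ∈ R} (r_nm - x_nᵀ y_m)² + λ (Σ_n ‖x_n‖² + Σ_m ‖y_m‖²)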
    pub fn cost(&mut self) -> T {
        self.ensure_x_y_existence();
        let r_term: T = self.r_row_first.par_iter().map(|(i_n, col)| {
            col
                .par_iter()
                .map(|(i_m, val)|
                    (*val - (self.x_mat[*i_n].transpose() * &self.y_mat[*i_m])[(0, 0)])
                        .powi(2)
                )
                .sum::<T>()
        }).sum::<T>();

        let x_term: T = self.x_mat
            .par_iter()
            .map(|x_in| (x_in.transpose() * x_in)[(0, 0)])
            .sum::<T>();

        let y_term: T = self.y_mat
            .par_iter()
            .map(|y_in| (y_in.transpose() * y_in)[(0, 0)])
            .sum::<T>();

        r_term + self.default_regularization * (x_term + y_term)
    }

    /// Predicts the value of R at the given row and column index.
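    ///
    /// Panics if `n >= N` or `m >= M`.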
    pub fn predict_r_val(&self, n: usize, m: usize) -> T {
        (self.x_mat[n].transpose() * &self.y_mat[m])[(0, 0)]
    }
}
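
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal smoke-test sketch for the ALS loop: factorize a tiny 3 x 3 matrix and
    // check that the regularized cost does not increase after a few training sweeps.
    #[test]
    fn als_does_not_increase_cost_on_tiny_matrix() {
        let mut als = ALS::new(3, 3, 2);
        for (n, m, v) in [(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0), (2, 0, 4.0), (2, 2, 5.0)] {
            assert!(als.add((n, m, v)).is_ok());
        }
        let initial_cost = als.cost();
        als.train_for(20);
        assert!(als.cost() <= initial_cost);
    }
}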