// rusty_machine/learning/gp.rs

use learning::toolkit::kernel::{Kernel, SquaredExp};
use linalg::{Matrix, BaseMatrix};
use linalg::Vector;
use learning::{LearningResult, SupModel};
use learning::error::{Error, ErrorKind};

36pub trait MeanFunc {
38 fn func(&self, x: Matrix<f64>) -> Vector<f64>;
40}
41
/// A constant mean function: every input is mapped to the same value `a`.
#[derive(Clone, Copy, Debug)]
pub struct ConstMean {
    /// The constant value returned for every input row.
    a: f64,
}
47
48impl Default for ConstMean {
50 fn default() -> ConstMean {
51 ConstMean { a: 0f64 }
52 }
53}
54
55impl MeanFunc for ConstMean {
56 fn func(&self, x: Matrix<f64>) -> Vector<f64> {
57 Vector::zeros(x.rows()) + self.a
58 }
59}
60
61#[derive(Debug)]
67pub struct GaussianProcess<T: Kernel, U: MeanFunc> {
68 ker: T,
69 mean: U,
70 pub noise: f64,
72 alpha: Option<Vector<f64>>,
73 train_mat: Option<Matrix<f64>>,
74 train_data: Option<Matrix<f64>>,
75}
76
77impl Default for GaussianProcess<SquaredExp, ConstMean> {
88 fn default() -> GaussianProcess<SquaredExp, ConstMean> {
89 GaussianProcess {
90 ker: SquaredExp::default(),
91 mean: ConstMean::default(),
92 noise: 0f64,
93 train_mat: None,
94 train_data: None,
95 alpha: None,
96 }
97 }
98}
99
100impl<T: Kernel, U: MeanFunc> GaussianProcess<T, U> {
101 pub fn new(ker: T, mean: U, noise: f64) -> GaussianProcess<T, U> {
114 GaussianProcess {
115 ker: ker,
116 mean: mean,
117 noise: noise,
118 train_mat: None,
119 train_data: None,
120 alpha: None,
121 }
122 }
123
124 fn ker_mat(&self, m1: &Matrix<f64>, m2: &Matrix<f64>) -> LearningResult<Matrix<f64>> {
126 if m1.cols() != m2.cols() {
127 Err(Error::new(ErrorKind::InvalidState,
128 "Inputs to kernel matrices have different column counts."))
129 } else {
130 let dim1 = m1.rows();
131 let dim2 = m2.rows();
132
133 let mut ker_data = Vec::with_capacity(dim1 * dim2);
134 ker_data.extend(m1.iter_rows().flat_map(|row1| {
135 m2.iter_rows()
136 .map(move |row2| self.ker.kernel(row1, row2))
137 }));
138
139 Ok(Matrix::new(dim1, dim2, ker_data))
140 }
141 }
142}
143
144impl<T: Kernel, U: MeanFunc> SupModel<Matrix<f64>, Vector<f64>> for GaussianProcess<T, U> {
145 fn predict(&self, inputs: &Matrix<f64>) -> LearningResult<Vector<f64>> {
147
148 if let (&Some(ref alpha), &Some(ref t_data)) = (&self.alpha, &self.train_data) {
150 let mean = self.mean.func(inputs.clone());
151 let post_mean = try!(self.ker_mat(inputs, t_data)) * alpha;
152 Ok(mean + post_mean)
153 } else {
154 Err(Error::new(ErrorKind::UntrainedModel, "The model has not been trained."))
155 }
156 }
157
158 fn train(&mut self, inputs: &Matrix<f64>, targets: &Vector<f64>) -> LearningResult<()> {
160 let noise_mat = Matrix::identity(inputs.rows()) * self.noise;
161
162 let ker_mat = self.ker_mat(inputs, inputs).unwrap();
163
164 let train_mat = try!((ker_mat + noise_mat).cholesky().map_err(|_| {
165 Error::new(ErrorKind::InvalidState,
166 "Could not compute Cholesky decomposition.")
167 }));
168
169 let x = train_mat.solve_l_triangular(targets - self.mean.func(inputs.clone())).unwrap();
170 let alpha = train_mat.transpose().solve_u_triangular(x).unwrap();
171
172 self.train_mat = Some(train_mat);
173 self.train_data = Some(inputs.clone());
174 self.alpha = Some(alpha);
175
176 Ok(())
177 }
178}
179
180impl<T: Kernel, U: MeanFunc> GaussianProcess<T, U> {
181 pub fn get_posterior(&self,
187 inputs: &Matrix<f64>)
188 -> LearningResult<(Vector<f64>, Matrix<f64>)> {
189 if let (&Some(ref t_mat), &Some(ref alpha), &Some(ref t_data)) = (&self.train_mat,
190 &self.alpha,
191 &self.train_data) {
192 let mean = self.mean.func(inputs.clone());
193
194 let post_mean = mean + try!(self.ker_mat(inputs, t_data)) * alpha;
195
196 let test_mat = try!(self.ker_mat(inputs, t_data));
197 let mut var_data = Vec::with_capacity(inputs.rows() * inputs.cols());
198 for row in test_mat.iter_rows() {
199 let test_point = Vector::new(row.to_vec());
200 var_data.append(&mut t_mat.solve_l_triangular(test_point).unwrap().into_vec());
201 }
202
203 let v_mat = Matrix::new(test_mat.rows(), test_mat.cols(), var_data);
204
205 let post_var = try!(self.ker_mat(inputs, inputs)) - &v_mat * v_mat.transpose();
206
207 Ok((post_mean, post_var))
208 } else {
209 Err(Error::new_untrained())
210 }
211 }
212}