rusty_machine/learning/optim/
fmincg.rs1use learning::optim::{Optimizable, OptimAlgorithm};
32use linalg::Vector;
33
34use std::cmp;
35use std::f64;
36
37
38#[derive(Clone, Copy, Debug)]
40pub struct ConjugateGD {
41 pub rho: f64,
43 pub sig: f64,
45 pub int: f64,
47 pub ext: f64,
49 pub max: usize,
51 pub ratio: f64,
53
54 pub iters: usize,
56}
57
58impl Default for ConjugateGD {
70 fn default() -> ConjugateGD {
71 ConjugateGD {
72 rho: 0.01,
73 sig: 0.5,
74 int: 0.1,
75 ext: 3.0,
76 max: 20,
77 ratio: 100.0,
78 iters: 100,
79 }
80 }
81}
82
83impl<M: Optimizable> OptimAlgorithm<M> for ConjugateGD {
84 fn optimize(&self,
85 model: &M,
86 start: &[f64],
87 inputs: &M::Inputs,
88 targets: &M::Targets)
89 -> Vec<f64> {
90 let mut i = 0usize;
91 let mut ls_failed = false;
92
93 let (mut f1, vec_df1) = model.compute_grad(start, inputs, targets);
94 let mut df1 = Vector::new(vec_df1);
95
96 let red = 1f64;
98
99 let length = self.iters as i32;
100
101 let mut s = -df1.clone();
102 let mut d1 = -s.dot(&s);
103 let mut z1 = red / (1f64 - d1);
104
105 let mut x = Vector::new(start.to_vec());
106
107 let (mut f2, mut df2): (f64, Vector<f64>);
108
109 while (i as i32) < length.abs() {
110 if length > 0 {
111 i += 1;
112 }
113
114 let (x0, f0) = (x.clone(), f1);
115
116 x = x + &s * z1;
117
118 let cost = model.compute_grad(x.data(), inputs, targets);
119 f2 = cost.0;
120 df2 = Vector::new(cost.1);
121
122 if length < 0 {
123 i += 1;
124 }
125
126 let mut d2 = df2.dot(&s);
127
128 let (mut f3, mut d3, mut z3) = (f1, d1, -z1);
129
130 let mut m = if length > 0 {
131 self.max as i32
132 } else {
133 cmp::min(self.max as i32, -length - (i as i32))
134 };
135
136 let mut success = false;
137 let mut limit = -1f64;
138
139 loop {
140 let mut z2: f64;
141
142 while ((f2 > (f1 + z1 * self.rho * d1)) || (d2 > -self.sig * d1)) && (m > 0i32) {
143
144 limit = z1;
145
146 if f2 > f1 {
147 z2 = z3 - (0.5 * d3 * z3 * z3) / (d3 * z3 + f2 - f3);
148 } else {
149 let a = 6f64 * (f2 - f3) / z3 + 3f64 * (d2 + d3);
150 let b = 3f64 * (f3 - f2) - z3 * (2f64 * d2 + d3);
151 z2 = ((b * b - a * d2 * z3 * z3).sqrt() - b) / a;
152 }
153
154 if z2.is_nan() || z2.is_infinite() {
155 z2 = z3 / 2f64;
156 }
157
158 if z2 <= self.int * z3 {
159 if z2 <= (1f64 - self.int) * z3 {
160 z2 = (1f64 - self.int) * z3;
161 }
162 } else if self.int * z3 <= (1f64 - self.int) * z3 {
163 z2 = (1f64 - self.int) * z3;
164 } else {
165 z2 = self.int * z3;
166 }
167
168 z1 += z2;
169 x = x + &s * z2;
170 let cost_grad = model.compute_grad(x.data(), inputs, targets);
171 f2 = cost_grad.0;
172 df2 = Vector::new(cost_grad.1);
173
174 m -= 1i32;
175 if length < 0 {
176 i += 1;
177 }
178
179 d2 = df2.dot(&s);
180 z3 -= z2;
181 }
182
183 if f2 > f1 + z1 * self.rho * d1 || d2 > -self.sig * d1 {
184 break;
185 } else if d2 > self.sig * d1 {
186 success = true;
187 break;
188 } else if m == 0i32 {
189 break;
190 }
191
192 let a = 6f64 * (f2 - f3) / z3 + 3f64 * (d2 + d3);
193 let b = 3f64 * (f3 - f2) - z3 * (2f64 * d2 + d3);
194 z2 = -d2 * z3 * z3 / (b + (b * b - a * d2 * z3 * z3).sqrt());
195
196 if z2.is_nan() || z2.is_infinite() || z2 < 0f64 {
197 if limit < -0.5 {
198 z2 = z1 * (self.ext - 1f64);
199 } else {
200 z2 = (limit - z1) / 2f64;
201 }
202 } else if (limit > -0.5) && (z2 + z1 > limit) {
203 z2 = (limit - z1) / 2f64;
204 } else if (limit < -0.5) && (z2 + z1 > z1 * self.ext) {
205 z2 = z1 * (self.ext - 1f64);
206 } else if z2 < -z3 * self.int {
207 z2 = -z3 * self.int;
208 } else if (limit > -0.5) && (z2 < (limit - z1) * (1f64 - self.int)) {
209 z2 = (limit - z1) * (1f64 - self.int);
210 }
211
212 f3 = f2;
213 d3 = d2;
214 z3 = -z2;
215 z1 += z2;
216 x = x + &s * z2;
217
218 let cost_grad = model.compute_grad(x.data(), inputs, targets);
219 f2 = cost_grad.0;
220 df2 = Vector::new(cost_grad.1);
221
222 m -= 1;
223 if length < 0 {
224 i += 1;
225 }
226
227 d2 = df2.dot(&s);
228 }
229
230 if success {
231 f1 = f2;
232 s = s * (&df2 - &df1).dot(&df2) / df1.dot(&df1) - &df2;
233
234 df1 = df2;
235
236 d2 = df1.dot(&s);
237
238 if d2 > 0f64 {
239 s = -&df1;
240 d2 = -s.dot(&s);
241 }
242
243 let ratio = d1 / (d2 - f64::MIN_POSITIVE);
244 if self.ratio < ratio {
245 z1 *= self.ratio;
246 } else {
247 z1 *= ratio;
248 }
249
250 d1 = d2;
251 ls_failed = false;
252 } else {
253 x = x0;
254 f1 = f0;
255
256 if ls_failed || i as i32 > length.abs() {
257 break;
258 }
259
260 df1 = df2;
261
262 s = -&df1;
263 d1 = -s.dot(&s);
264
265 z1 = 1f64 / (1f64 - d1);
266 ls_failed = true;
267 }
268
269 }
270 x.into_vec()
271 }
272}