1use std::collections::HashMap;
2use trustformers_core::errors::{Result, TrustformersError};
3use trustformers_core::tensor::Tensor;
4use trustformers_core::traits::Optimizer;
5
#[derive(Debug)]
pub struct PDEAwareOptimizer {
    /// Base learning rate before PDE-aware scaling is applied.
    pub learning_rate: f32,
    /// Exponential decay rate for the first-moment (momentum) estimate.
    pub beta1: f32,
    /// Exponential decay rate for the second-moment (variance) estimate.
    pub beta2: f32,
    /// Small constant added to the denominator for numerical stability.
    pub epsilon: f32,
    /// Decoupled weight-decay coefficient; applied only when greater than zero.
    pub weight_decay: f32,
    /// Weight of the residual-variance term when damping the learning rate.
    pub residual_variance_weight: f32,
    /// Stored alignment factor. NOTE(review): not read by the visible code in
    /// this file — confirm whether other modules use it.
    pub gradient_alignment_factor: f32,
    /// Smoothing factor for the exponentially-smoothed residual variance.
    pub smoothing_factor: f32,
    /// Gradient-norm threshold above which a region counts as "sharp".
    pub sharp_gradient_threshold: f32,
    /// Number of `update` calls performed so far (drives bias correction).
    pub step: usize,
    /// Per-parameter first-moment buffers, keyed by the parameter's buffer address.
    pub momentum: HashMap<String, Vec<f32>>,
    /// Per-parameter second-moment buffers, keyed by the parameter's buffer address.
    pub variance: HashMap<String, Vec<f32>>,
    /// Bounded window of recent residual-variance observations.
    pub residual_variance_history: Vec<f32>,
    /// Gradient-alignment record. NOTE(review): never written by the visible
    /// code in this file — confirm intended use.
    pub gradient_alignment_history: Vec<f32>,
}
39
/// Hyperparameter bundle for [`PDEAwareOptimizer`]; see the preset
/// constructors for equation-specific tunings.
#[derive(Debug, Clone)]
pub struct PDEAwareConfig {
    /// Base learning rate before PDE-aware scaling.
    pub learning_rate: f32,
    /// First-moment exponential decay rate.
    pub beta1: f32,
    /// Second-moment exponential decay rate.
    pub beta2: f32,
    /// Denominator stabilizer for the Adam-style step.
    pub epsilon: f32,
    /// Decoupled weight-decay coefficient.
    pub weight_decay: f32,
    /// Weight of the residual-variance term when damping the learning rate.
    pub residual_variance_weight: f32,
    /// Alignment factor carried into the optimizer state.
    pub gradient_alignment_factor: f32,
    /// Smoothing factor for the residual-variance running estimate.
    pub smoothing_factor: f32,
    /// Gradient-norm threshold marking a "sharp" gradient region.
    pub sharp_gradient_threshold: f32,
}
52
53impl Default for PDEAwareConfig {
54 fn default() -> Self {
55 Self {
56 learning_rate: 1e-3,
57 beta1: 0.9,
58 beta2: 0.999,
59 epsilon: 1e-8,
60 weight_decay: 0.0,
61 residual_variance_weight: 0.1,
62 gradient_alignment_factor: 0.05,
63 smoothing_factor: 0.95,
64 sharp_gradient_threshold: 1.0,
65 }
66 }
67}
68
69impl Default for PDEAwareOptimizer {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl PDEAwareOptimizer {
76 pub fn new() -> Self {
77 Self::from_config(PDEAwareConfig::default())
78 }
79
80 pub fn from_config(config: PDEAwareConfig) -> Self {
81 Self {
82 learning_rate: config.learning_rate,
83 beta1: config.beta1,
84 beta2: config.beta2,
85 epsilon: config.epsilon,
86 weight_decay: config.weight_decay,
87 residual_variance_weight: config.residual_variance_weight,
88 gradient_alignment_factor: config.gradient_alignment_factor,
89 smoothing_factor: config.smoothing_factor,
90 sharp_gradient_threshold: config.sharp_gradient_threshold,
91 step: 0,
92 momentum: HashMap::new(),
93 variance: HashMap::new(),
94 residual_variance_history: Vec::new(),
95 gradient_alignment_history: Vec::new(),
96 }
97 }
98
99 pub fn for_burgers_equation() -> Self {
101 Self::from_config(PDEAwareConfig {
102 learning_rate: 5e-4,
103 beta1: 0.95,
104 beta2: 0.999,
105 epsilon: 1e-10,
106 weight_decay: 1e-6,
107 residual_variance_weight: 0.15,
108 gradient_alignment_factor: 0.08,
109 smoothing_factor: 0.98,
110 sharp_gradient_threshold: 0.8,
111 })
112 }
113
114 pub fn for_allen_cahn() -> Self {
116 Self::from_config(PDEAwareConfig {
117 learning_rate: 1e-3,
118 beta1: 0.9,
119 beta2: 0.995,
120 epsilon: 1e-9,
121 weight_decay: 1e-5,
122 residual_variance_weight: 0.2,
123 gradient_alignment_factor: 0.1,
124 smoothing_factor: 0.95,
125 sharp_gradient_threshold: 1.5,
126 })
127 }
128
129 pub fn for_kdv_equation() -> Self {
131 Self::from_config(PDEAwareConfig {
132 learning_rate: 2e-4,
133 beta1: 0.95,
134 beta2: 0.9995,
135 epsilon: 1e-12,
136 weight_decay: 0.0,
137 residual_variance_weight: 0.25,
138 gradient_alignment_factor: 0.12,
139 smoothing_factor: 0.99,
140 sharp_gradient_threshold: 0.5,
141 })
142 }
143
144 pub fn for_sharp_gradients() -> Self {
146 Self::from_config(PDEAwareConfig {
147 learning_rate: 1e-4,
148 beta1: 0.95,
149 beta2: 0.9999,
150 epsilon: 1e-10,
151 weight_decay: 1e-7,
152 residual_variance_weight: 0.3,
153 gradient_alignment_factor: 0.15,
154 smoothing_factor: 0.99,
155 sharp_gradient_threshold: 0.3,
156 })
157 }
158
159 fn compute_residual_variance_from_norm(&mut self, grad_norm: f32) -> f32 {
161 let variance = grad_norm;
162
163 self.residual_variance_history.push(variance);
165
166 if self.residual_variance_history.len() > 100 {
168 self.residual_variance_history.remove(0);
169 }
170
171 if self.residual_variance_history.len() > 1 {
173 let prev_variance =
174 self.residual_variance_history[self.residual_variance_history.len() - 2];
175 self.smoothing_factor * prev_variance + (1.0 - self.smoothing_factor) * variance
176 } else {
177 variance
178 }
179 }
180
181 fn is_sharp_gradient_region_from_norm(&self, grad_norm: f32, max_grad: f32) -> bool {
183 grad_norm > self.sharp_gradient_threshold || max_grad > 2.0 * self.sharp_gradient_threshold
185 }
186
187 pub fn adaptive_learning_rate(
189 &self,
190 base_lr: f32,
191 residual_variance: f32,
192 is_sharp_region: bool,
193 ) -> f32 {
194 let mut adaptive_lr = base_lr;
195
196 if residual_variance > 0.1 {
198 adaptive_lr *= 1.0 / (1.0 + self.residual_variance_weight * residual_variance);
199 }
200
201 if is_sharp_region {
203 adaptive_lr *= 0.5;
204 }
205
206 adaptive_lr.clamp(base_lr * 0.01, base_lr * 2.0)
208 }
209
210 pub fn get_pde_stats(&self) -> PDEAwareStats {
212 let avg_residual_variance = if !self.residual_variance_history.is_empty() {
213 self.residual_variance_history.iter().sum::<f32>()
214 / self.residual_variance_history.len() as f32
215 } else {
216 0.0
217 };
218
219 PDEAwareStats {
220 step: self.step,
221 average_residual_variance: avg_residual_variance,
222 parameters_tracked: self.momentum.len(),
223 }
224 }
225}
226
/// Read-only snapshot returned by [`PDEAwareOptimizer::get_pde_stats`].
#[derive(Debug, Clone)]
pub struct PDEAwareStats {
    /// Number of `update` calls performed so far.
    pub step: usize,
    /// Mean of the recorded residual-variance window (0.0 when empty).
    pub average_residual_variance: f32,
    /// Number of distinct parameter buffers with cached moment state.
    pub parameters_tracked: usize,
}
233
234impl Optimizer for PDEAwareOptimizer {
235 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
236 match (parameter, grad) {
237 (Tensor::F32(param), Tensor::F32(grad_arr)) => {
238 self.step += 1;
239
240 let param_id = format!("{:p}", param.as_ptr());
241
242 let grad_norm: f32 = grad_arr.iter().map(|g| g * g).sum::<f32>().sqrt();
244 let max_grad: f32 = grad_arr.iter().map(|g| g.abs()).fold(0.0, f32::max);
245
246 let residual_variance = self.compute_residual_variance_from_norm(grad_norm);
247 let is_sharp_region = self.is_sharp_gradient_region_from_norm(grad_norm, max_grad);
248
249 let adaptive_lr = self.adaptive_learning_rate(
251 self.learning_rate,
252 residual_variance,
253 is_sharp_region,
254 );
255
256 let m = self
258 .momentum
259 .entry(param_id.clone())
260 .or_insert_with(|| vec![0.0; grad_arr.len()]);
261 let v = self.variance.entry(param_id).or_insert_with(|| vec![0.0; grad_arr.len()]);
262
263 if m.len() != grad_arr.len() || v.len() != grad_arr.len() {
264 return Err(TrustformersError::tensor_op_error(
265 "Momentum/variance buffer size mismatch",
266 "pde_aware_update",
267 ));
268 }
269
270 for i in 0..grad_arr.len() {
272 m[i] = self.beta1 * m[i] + (1.0 - self.beta1) * grad_arr[i];
273 v[i] = self.beta2 * v[i] + (1.0 - self.beta2) * grad_arr[i] * grad_arr[i];
274 }
275
276 let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
278 let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
279
280 let mut update_vec = vec![0.0; param.len()];
282 for i in 0..param.len() {
283 let m_hat = m[i] / bias_correction1;
284 let v_hat = v[i] / bias_correction2;
285
286 let update = adaptive_lr * m_hat / (v_hat.sqrt() + self.epsilon);
287 update_vec[i] = update;
288
289 if self.weight_decay > 0.0 {
291 update_vec[i] += self.weight_decay * param[i];
292 }
293 }
294
295 for (i, update) in update_vec.iter().enumerate() {
297 param[i] -= update;
298 }
299
300 Ok(())
301 },
302 _ => Err(TrustformersError::tensor_op_error(
303 "Unsupported tensor types for PDEAwareOptimizer",
304 "pde_aware_update",
305 )),
306 }
307 }
308
309 fn zero_grad(&mut self) {
310 }
312
313 fn step(&mut self) {
314 }
316
317 fn get_lr(&self) -> f32 {
318 self.learning_rate
319 }
320
321 fn set_lr(&mut self, lr: f32) {
322 self.learning_rate = lr;
323 }
324}