1use anyhow::Result;
8use std::collections::HashMap;
9use trustformers_core::tensor::Tensor;
10
/// Hyper-parameters for the [`SOFO`] optimizer.
///
/// Build one with [`SOFOConfig::new`] (or `Default`) and the chained setters, e.g.
/// `SOFOConfig::new().learning_rate(1e-2).nesterov(false).build()`.
#[derive(Debug, Clone)]
pub struct SOFOConfig {
    /// Step size multiplied into every parameter update by `SOFO::step`.
    pub learning_rate: f32,
    /// Not read by `SOFO::step` in this module.
    pub batch_size: usize,
    /// Added to the optimizer's `total_forward_passes` counter on every step.
    /// `SOFO::step` in this module only counts these passes; it does not run them.
    pub forward_passes: usize,
    /// Every update is scaled by `1 + curvature_strength`.
    pub curvature_strength: f32,
    /// Added to each per-element curvature estimate on every step.
    pub damping: f32,
    /// When strictly positive, `weight_decay * param` is added to the gradient
    /// (coupled L2 decay) before the update is computed.
    pub weight_decay: f32,
    /// Not read by `SOFO::step` in this module.
    pub adaptive_curvature: bool,
    /// Coefficient of the exponential moving average kept in the momentum buffer.
    pub momentum: f32,
    /// Selects the look-ahead (Nesterov-style) combination of momentum and direction.
    pub nesterov: bool,
    /// Not read by `SOFO::step` in this module.
    pub max_condition_number: f32,
    /// Not read by `SOFO::step` in this module.
    pub memory_efficient: bool,
    /// Not read by `SOFO::step` in this module.
    pub parallel_threshold: usize,
}

impl Default for SOFOConfig {
    fn default() -> Self {
        Self {
            learning_rate: 1e-3,
            batch_size: 32,
            forward_passes: 8,
            curvature_strength: 0.1,
            damping: 1e-6,
            weight_decay: 0.0,
            adaptive_curvature: true,
            momentum: 0.9,
            nesterov: true,
            max_condition_number: 1e6,
            memory_efficient: true,
            parallel_threshold: 1000,
        }
    }
}

impl SOFOConfig {
    /// Returns the default configuration (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`SOFOConfig::learning_rate`].
    #[must_use]
    pub fn learning_rate(mut self, lr: f32) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Sets [`SOFOConfig::batch_size`].
    #[must_use]
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Sets [`SOFOConfig::forward_passes`].
    #[must_use]
    pub fn forward_passes(mut self, passes: usize) -> Self {
        self.forward_passes = passes;
        self
    }

    /// Sets [`SOFOConfig::curvature_strength`].
    #[must_use]
    pub fn curvature_strength(mut self, strength: f32) -> Self {
        self.curvature_strength = strength;
        self
    }

    /// Sets [`SOFOConfig::damping`].
    #[must_use]
    pub fn damping(mut self, damping: f32) -> Self {
        self.damping = damping;
        self
    }

    /// Sets [`SOFOConfig::weight_decay`].
    #[must_use]
    pub fn weight_decay(mut self, decay: f32) -> Self {
        self.weight_decay = decay;
        self
    }

    /// Sets [`SOFOConfig::momentum`].
    #[must_use]
    pub fn momentum(mut self, momentum: f32) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets [`SOFOConfig::adaptive_curvature`].
    #[must_use]
    pub fn adaptive_curvature(mut self, enabled: bool) -> Self {
        self.adaptive_curvature = enabled;
        self
    }

    /// Sets [`SOFOConfig::nesterov`].
    #[must_use]
    pub fn nesterov(mut self, enabled: bool) -> Self {
        self.nesterov = enabled;
        self
    }

    /// Sets [`SOFOConfig::max_condition_number`].
    #[must_use]
    pub fn max_condition_number(mut self, max: f32) -> Self {
        self.max_condition_number = max;
        self
    }

    /// Sets [`SOFOConfig::memory_efficient`].
    #[must_use]
    pub fn memory_efficient(mut self, enabled: bool) -> Self {
        self.memory_efficient = enabled;
        self
    }

    /// Sets [`SOFOConfig::parallel_threshold`].
    #[must_use]
    pub fn parallel_threshold(mut self, threshold: usize) -> Self {
        self.parallel_threshold = threshold;
        self
    }

    /// Finishes the builder chain; returns the configuration unchanged.
    #[must_use]
    pub fn build(self) -> Self {
        self
    }
}
91
/// Mutable per-optimizer state for [`SOFO`].
#[derive(Debug, Clone, Default)]
pub struct SOFOState {
    /// Number of calls to `SOFO::step` so far (incremented at the start of each call).
    pub step: u64,
    /// Per-parameter momentum vectors, keyed by parameter name; one entry per element.
    pub momentum_buffers: HashMap<String, Vec<f32>>,
    /// Per-parameter curvature estimates (moving average of squared gradients plus
    /// damping), keyed by parameter name; one entry per element.
    pub curvature_estimates: HashMap<String, Vec<f32>>,
    /// Running sum of the configured `forward_passes` over all steps.
    pub total_forward_passes: u64,
}
100
/// Forward-mode statistics record.
///
/// NOTE(review): nothing in this module populates these fields;
/// `SOFO::get_forward_stats` returns a fixed static instance.
#[derive(Debug, Clone, Default)]
pub struct ForwardModeStats {
    /// Total forward passes (presumably cumulative — confirm with producers).
    pub total_forward_passes: u64,
    /// Average forward-pass time; units are not established here — TODO confirm.
    pub avg_forward_time: f32,
    /// Curvature-estimate accuracy metric; scale not established here — TODO confirm.
    pub curvature_accuracy: f32,
    /// Parallel-efficiency metric; scale not established here — TODO confirm.
    pub parallel_efficiency: f32,
}
109
/// Memory-usage statistics record.
///
/// NOTE(review): nothing in this module populates these fields;
/// `SOFO::get_memory_stats` returns a fixed static instance.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Current memory use in megabytes (by field name).
    pub current_memory_mb: f32,
    /// Peak memory use in megabytes (by field name).
    pub peak_memory_mb: f32,
    /// Efficiency ratio; its definition is not established here — TODO confirm.
    pub efficiency_ratio: f32,
    /// Parameter count; whether tensors or scalars are counted is not established here.
    pub num_parameters: usize,
}
118
/// The SOFO optimizer: a configuration plus the per-parameter state it updates.
pub struct SOFO {
    // Hyper-parameters; only the learning rate can be changed after construction.
    config: SOFOConfig,
    // Step counters and per-parameter momentum/curvature buffers.
    state: SOFOState,
}
124
125impl SOFO {
126 pub fn new(config: SOFOConfig) -> Self {
127 Self {
128 config,
129 state: SOFOState::default(),
130 }
131 }
132
133 pub fn learning_rate(&self) -> f32 {
134 self.config.learning_rate
135 }
136
137 pub fn set_learning_rate(&mut self, lr: f32) {
138 self.config.learning_rate = lr;
139 }
140
141 pub fn step(
143 &mut self,
144 parameters: &mut HashMap<String, Tensor>,
145 gradients: &HashMap<String, Tensor>,
146 ) -> Result<()> {
147 self.state.step += 1;
148
149 self.state.total_forward_passes += self.config.forward_passes as u64;
151
152 for (param_name, gradient) in gradients.iter() {
153 if let Some(parameter) = parameters.get_mut(param_name) {
154 let param_data = parameter.data()?;
156 let grad_data = gradient.data()?;
157
158 if !self.state.momentum_buffers.contains_key(param_name) {
160 self.state
161 .momentum_buffers
162 .insert(param_name.clone(), vec![0.0; param_data.len()]);
163 self.state
164 .curvature_estimates
165 .insert(param_name.clone(), vec![1.0; param_data.len()]);
166 }
167
168 let momentum_buffer = self.state.momentum_buffers.get_mut(param_name).unwrap();
169 let curvature_buffer = self.state.curvature_estimates.get_mut(param_name).unwrap();
170
171 let mut updated_params = param_data.clone();
173 for i in 0..param_data.len() {
174 let effective_grad = if self.config.weight_decay > 0.0 {
176 grad_data[i] + self.config.weight_decay * param_data[i]
177 } else {
178 grad_data[i]
179 };
180
181 let grad_sq = effective_grad * effective_grad;
183 curvature_buffer[i] =
184 0.9 * curvature_buffer[i] + 0.1 * grad_sq + self.config.damping;
185
186 let newton_direction = effective_grad / curvature_buffer[i];
188
189 momentum_buffer[i] = self.config.momentum * momentum_buffer[i]
191 + (1.0 - self.config.momentum) * newton_direction;
192
193 let final_update = if self.config.nesterov {
195 self.config.momentum * momentum_buffer[i] + newton_direction
196 } else {
197 momentum_buffer[i]
198 };
199
200 let curvature_factor = 1.0 + self.config.curvature_strength;
202 updated_params[i] =
203 param_data[i] - self.config.learning_rate * curvature_factor * final_update;
204 }
205
206 *parameter = Tensor::new(updated_params)?;
208 }
209 }
210
211 Ok(())
212 }
213
214 pub fn get_sofo_stats(&self) -> SOFOStats {
215 let avg_condition_number = 5.0; let memory_efficiency_ratio = 10.0; SOFOStats {
219 step: self.state.step,
220 total_forward_passes: self.state.total_forward_passes,
221 avg_curvature_strength: self.config.curvature_strength,
222 avg_condition_number,
223 memory_efficiency_ratio,
224 current_memory_mb: self.state.momentum_buffers.len() as f32 * 0.1,
225 parallel_efficiency: 0.85,
226 num_parameters: self.state.momentum_buffers.len(),
227 }
228 }
229
230 pub fn get_forward_stats(&self) -> &ForwardModeStats {
231 static EMPTY: ForwardModeStats = ForwardModeStats {
232 total_forward_passes: 0,
233 avg_forward_time: 0.0,
234 curvature_accuracy: 1.0,
235 parallel_efficiency: 1.0,
236 };
237 &EMPTY
238 }
239
240 pub fn get_memory_stats(&self) -> &MemoryStats {
241 static EMPTY: MemoryStats = MemoryStats {
242 current_memory_mb: 0.0,
243 peak_memory_mb: 0.0,
244 efficiency_ratio: 1.0,
245 num_parameters: 0,
246 };
247 &EMPTY
248 }
249
250 pub fn reset_state(&mut self) {
251 self.state = SOFOState::default();
252 }
253
254 pub fn get_curvature_estimates(&self) -> &HashMap<String, Vec<f32>> {
255 &self.state.curvature_estimates
256 }
257
258 pub fn get_adaptive_weights(&self) -> HashMap<String, f32> {
259 HashMap::new()
261 }
262}
263
/// Summary statistics produced by `SOFO::get_sofo_stats`.
///
/// NOTE(review): several fields are filled with fixed placeholder values by
/// `SOFO::get_sofo_stats` rather than measured ones — check that method before
/// relying on them.
#[derive(Debug, Clone)]
pub struct SOFOStats {
    /// Number of optimizer steps taken.
    pub step: u64,
    /// Running sum of configured forward passes over all steps.
    pub total_forward_passes: u64,
    /// Curvature strength (taken from the configuration).
    pub avg_curvature_strength: f32,
    /// Average condition number (placeholder value in `get_sofo_stats`).
    pub avg_condition_number: f32,
    /// Memory-efficiency ratio (placeholder value in `get_sofo_stats`).
    pub memory_efficiency_ratio: f32,
    /// Estimated memory use in megabytes (derived from the tracked-tensor count, not measured).
    pub current_memory_mb: f32,
    /// Parallel efficiency (placeholder value in `get_sofo_stats`).
    pub parallel_efficiency: f32,
    /// Number of parameter tensors with optimizer state.
    pub num_parameters: usize,
}