
trustformers_optim/pytorch_compat.rs

//! PyTorch Optimizer API Compatibility Layer
//!
//! This module provides PyTorch-compatible optimizer interfaces for seamless
//! integration with PyTorch-based training workflows. It wraps our native
//! optimizers to provide the familiar PyTorch API while maintaining high performance.
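//!
//! A minimal usage sketch (illustrative only; it assumes this module is exposed as
//! `trustformers_optim::pytorch_compat` and reuses the `Tensor::zeros`/`Tensor::ones`
//! constructors that the tests below rely on):
//!
//! ```ignore
//! use trustformers_core::Tensor;
//! use trustformers_optim::pytorch_compat::{PyTorchOptimizer, PyTorchOptimizerFactory};
//!
//! let weight = Tensor::zeros(&[10, 10])?;
//! let mut optimizer = PyTorchOptimizerFactory::adam(
//!     vec![("weight".to_string(), weight.clone())],
//!     1e-3,
//!     (0.9, 0.999),
//!     1e-8,
//!     0.0,
//!     false,
//! )?;
//!
//! // Parameters and gradients are tracked by name.
//! optimizer.register_param("weight".to_string(), weight)?;
//! optimizer.set_grad("weight".to_string(), Tensor::ones(&[10, 10])?)?;
//!
//! optimizer.step(None)?;      // apply one update
//! optimizer.zero_grad(true)?; // clear gradients, as in PyTorch
//! ```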

use crate::traits::StatefulOptimizer;
use crate::{Adam, AdamW, LRScheduler, OptimizerState, SGD};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use trustformers_core::errors::{Result, TrustformersError};
use trustformers_core::traits::Optimizer;
use trustformers_core::Tensor;

/// PyTorch-compatible optimizer parameter group
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchParamGroup {
    pub params: Vec<String>, // Parameter names/IDs
    pub lr: f64,
    pub weight_decay: f64,
    pub momentum: Option<f64>,
    pub dampening: Option<f64>,
    pub eps: Option<f64>,
    pub betas: Option<(f64, f64)>,
    pub alpha: Option<f64>,
    pub amsgrad: Option<bool>,
    pub maximize: Option<bool>,
    pub foreach: Option<bool>,
    pub differentiable: Option<bool>,
}

impl Default for PyTorchParamGroup {
    fn default() -> Self {
        Self {
            params: Vec::new(),
            lr: 0.001,
            weight_decay: 0.0,
            momentum: None,
            dampening: None,
            eps: Some(1e-8),
            betas: Some((0.9, 0.999)),
            alpha: None,
            amsgrad: Some(false),
            maximize: Some(false),
            foreach: None,
            differentiable: Some(false),
        }
    }
}

/// PyTorch-compatible optimizer state
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerState {
    pub state: HashMap<String, serde_json::Value>,
    pub param_groups: Vec<PyTorchParamGroup>,
}

/// PyTorch-compatible optimizer configuration
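///
/// A small sketch of building a config and handing it to [`PyTorchAdam::from_config`]
/// (illustrative only; fields not listed keep their `Default` values):
///
/// ```ignore
/// let config = PyTorchOptimizerConfig {
///     optimizer_type: "Adam".to_string(),
///     learning_rate: 5e-4,
///     weight_decay: 0.01,
///     ..Default::default()
/// };
/// let optimizer = PyTorchAdam::from_config(config)?;
/// ```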
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchOptimizerConfig {
    pub optimizer_type: String,
    pub learning_rate: f64,
    pub betas: (f64, f64),
    pub epsilon: f64,
    pub weight_decay: f64,
    pub amsgrad: bool,
    pub maximize: bool,
    pub parameters: HashMap<String, serde_json::Value>,
}

impl Default for PyTorchOptimizerConfig {
    fn default() -> Self {
        Self {
            optimizer_type: "Adam".to_string(),
            learning_rate: 1e-3,
            betas: (0.9, 0.999),
            epsilon: 1e-8,
            weight_decay: 0.0,
            amsgrad: false,
            maximize: false,
            parameters: HashMap::new(),
        }
    }
}

/// PyTorch-compatible optimizer interface
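///
/// The trait mirrors `torch.optim.Optimizer`, so generic training code can drive any of
/// the wrappers in this module through the same calls. A hedged sketch:
///
/// ```ignore
/// fn train_step(optimizer: &mut dyn PyTorchOptimizer) -> Result<()> {
///     // The optional closure re-evaluates the loss, as in PyTorch's `step(closure)`.
///     let _loss = optimizer.step(Some(Box::new(|| 0.0)))?;
///     optimizer.zero_grad(true)?;
///     Ok(())
/// }
/// ```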
pub trait PyTorchOptimizer: Send + Sync {
    /// Get parameter groups
    fn param_groups(&self) -> &[PyTorchParamGroup];

    /// Get mutable parameter groups
    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup];

    /// Get optimizer state
    fn state_dict(&self) -> PyTorchOptimizerState;

    /// Load optimizer state
    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()>;

    /// Perform optimization step
    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>>;

    /// Zero gradients
    fn zero_grad(&mut self, set_to_none: bool) -> Result<()>;

    /// Add parameter group
    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()>;

    /// Get defaults
    fn defaults(&self) -> PyTorchParamGroup;
}

/// PyTorch-compatible Adam optimizer
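///
/// Sketch of the expected call sequence (illustrative; tensors are created the same way
/// as in the tests at the bottom of this file):
///
/// ```ignore
/// let mut adam = PyTorchAdam::from_params(vec![
///     ("w".to_string(), Tensor::zeros(&[4, 4])?),
/// ])?;
/// adam.register_param("w".to_string(), Tensor::zeros(&[4, 4])?)?;
/// adam.set_grad("w".to_string(), Tensor::ones(&[4, 4])?)?;
/// adam.step(None)?;
///
/// // State can be captured and restored in the PyTorch `state_dict` style.
/// let snapshot = adam.state_dict();
/// adam.load_state_dict(snapshot)?;
/// ```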
#[derive(Debug)]
pub struct PyTorchAdam {
    inner: Adam,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchAdam {
    /// Create new PyTorch-compatible Adam optimizer
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        _amsgrad: bool,
    ) -> Result<Self> {
        let inner = Adam::new(
            lr as f32,
            (betas.0 as f32, betas.1 as f32),
            eps as f32,
            weight_decay as f32,
        );

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Create with default parameters
    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            ..Default::default()
        };

        Self::new(vec![param_group], 0.001, (0.9, 0.999), 1e-8, 0.0, false)
    }

    /// Create PyTorch Adam optimizer from configuration
    pub fn from_config(config: PyTorchOptimizerConfig) -> Result<Self> {
        // Create parameter group from config
        let param_group = PyTorchParamGroup {
            params: config.parameters.keys().cloned().collect(),
            lr: config.learning_rate,
            weight_decay: config.weight_decay,
            eps: Some(config.epsilon),
            betas: Some(config.betas),
            amsgrad: Some(config.amsgrad),
            maximize: Some(config.maximize),
            ..Default::default()
        };

        Self::new(
            vec![param_group],
            config.learning_rate,
            config.betas,
            config.epsilon,
            config.weight_decay,
            config.amsgrad,
        )
    }

    /// Create PyTorch Adam optimizer from cross-framework configuration
    pub fn from_cross_framework_config(
        config: crate::cross_framework::PyTorchOptimizerConfig,
    ) -> Result<Self> {
        // Extract parameters from the HashMap
        // Fall back to the Adam defaults if `betas` is missing, not an array, or
        // shorter than two elements.
        let betas = config
            .parameters
            .get("betas")
            .and_then(|v| v.as_array())
            .map(|arr| {
                (
                    arr.first().and_then(|v| v.as_f64()).unwrap_or(0.9),
                    arr.get(1).and_then(|v| v.as_f64()).unwrap_or(0.999),
                )
            })
            .unwrap_or((0.9, 0.999));

        let epsilon = config.parameters.get("epsilon").and_then(|v| v.as_f64()).unwrap_or(1e-8);

        let weight_decay =
            config.parameters.get("weight_decay").and_then(|v| v.as_f64()).unwrap_or(0.0);

        let amsgrad = config.parameters.get("amsgrad").and_then(|v| v.as_bool()).unwrap_or(false);

        // Create parameter group from config
        let param_group = PyTorchParamGroup {
            params: Vec::new(),
            lr: config.learning_rate as f64,
            weight_decay,
            eps: Some(epsilon),
            betas: Some(betas),
            amsgrad: Some(amsgrad),
            maximize: Some(false),
            ..Default::default()
        };

        Self::new(
            vec![param_group],
            config.learning_rate as f64,
            betas,
            epsilon,
            weight_decay,
            amsgrad,
        )
    }

    /// Register parameter
    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    /// Set gradient for parameter
    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    /// Load optimizer state from `OptimizerState` format
    ///
    /// Converts `OptimizerState` (with `Vec<f32>` values) to PyTorch-compatible format
    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        // Convert momentum buffers from Vec<f32> into tensors. A full implementation
        // would hand these to the inner optimizer's state management; for now we only
        // make sure each named parameter is registered so later updates have a target.
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                // Create a placeholder parameter if one doesn't exist yet
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        // Convert variance buffers the same way; the inner Adam optimizer would use
        // these as its second-moment estimates.
        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        // The inner optimizer should also sync its step count for bias correction;
        // that hook is not wired up here yet.

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchAdam {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("adam_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(adam_state) = state.state.get("adam_state") {
            if let Ok(optimizer_state) =
                serde_json::from_value::<OptimizerState>(adam_state.clone())
            {
                // Convert OptimizerState to PyTorch-compatible format
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        // Apply gradients to parameters using the inner optimizer
        for group in &self.param_groups {
            for param_name in &group.params {
                // Get copies of parameter and gradient to avoid borrow conflicts
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    // Apply update to parameter copy
                    self.inner.update(&mut param, &grad)?;

                    // Store updated parameter back
                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.001,
            betas: Some((0.9, 0.999)),
            eps: Some(1e-8),
            weight_decay: 0.0,
            amsgrad: Some(false),
            ..Default::default()
        }
    }
}

/// PyTorch-compatible AdamW optimizer
#[derive(Debug)]
pub struct PyTorchAdamW {
    inner: AdamW,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchAdamW {
    /// Create new PyTorch-compatible AdamW optimizer
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        _amsgrad: bool,
    ) -> Result<Self> {
        let inner = AdamW::new(
            lr as f32,
            (betas.0 as f32, betas.1 as f32),
            eps as f32,
            weight_decay as f32,
        );

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Create with default parameters
    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            ..Default::default()
        };

        Self::new(vec![param_group], 0.001, (0.9, 0.999), 1e-8, 0.01, false)
    }

    /// Register parameter
    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    /// Set gradient for parameter
    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    /// Load optimizer state from `OptimizerState` format
    ///
    /// Converts `OptimizerState` (with `Vec<f32>` values) to PyTorch-compatible format
    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        // Convert momentum buffers from Vec<f32> to Tensor format
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        // Convert variance buffers from Vec<f32> to Tensor format (for AdamW)
        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchAdamW {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("adamw_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(adamw_state) = state.state.get("adamw_state") {
            if let Ok(optimizer_state) =
                serde_json::from_value::<OptimizerState>(adamw_state.clone())
            {
                // Convert OptimizerState to PyTorch-compatible format
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        for group in &self.param_groups {
            for param_name in &group.params {
                // Get copies of parameter and gradient to avoid borrow conflicts
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    // Apply update to parameter copy
                    self.inner.update(&mut param, &grad)?;

                    // Store updated parameter back
                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.001,
            betas: Some((0.9, 0.999)),
            eps: Some(1e-8),
            weight_decay: 0.01,
            amsgrad: Some(false),
            ..Default::default()
        }
    }
}

/// PyTorch-compatible SGD optimizer
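///
/// Sketch of constructing SGD with momentum through the explicit constructor
/// (illustrative only; [`PyTorchOptimizerFactory::sgd`] builds the same thing in one call):
///
/// ```ignore
/// let group = PyTorchParamGroup {
///     params: vec!["w".to_string()],
///     lr: 0.01,
///     momentum: Some(0.9),
///     ..Default::default()
/// };
/// let mut sgd = PyTorchSGD::new(vec![group], 0.01, 0.9, 0.0, 0.0, true)?;
/// ```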
#[derive(Debug)]
pub struct PyTorchSGD {
    inner: SGD,
    param_groups: Vec<PyTorchParamGroup>,
    parameters: Arc<Mutex<HashMap<String, Tensor>>>,
    gradients: Arc<Mutex<HashMap<String, Tensor>>>,
}

impl PyTorchSGD {
    /// Create new PyTorch-compatible SGD optimizer
    pub fn new(
        params: Vec<PyTorchParamGroup>,
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<Self> {
        let config = crate::sgd::SGDConfig {
            lr: lr as f32,
            momentum: momentum as f32,
            dampening: dampening as f32,
            weight_decay: weight_decay as f32,
            nesterov,
        };

        let inner = SGD::from_config(config);

        Ok(Self {
            inner,
            param_groups: params,
            parameters: Arc::new(Mutex::new(HashMap::new())),
            gradients: Arc::new(Mutex::new(HashMap::new())),
        })
    }

    /// Create with default parameters
    pub fn from_params(params: impl IntoIterator<Item = (String, Tensor)>) -> Result<Self> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr: 0.01,
            momentum: Some(0.0),
            dampening: Some(0.0),
            weight_decay: 0.0,
            ..Default::default()
        };

        Self::new(vec![param_group], 0.01, 0.0, 0.0, 0.0, false)
    }

    /// Register parameter
    pub fn register_param(&mut self, name: String, param: Tensor) -> Result<()> {
        let mut params = self
            .parameters
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        params.insert(name, param);
        Ok(())
    }

    /// Set gradient for parameter
    pub fn set_grad(&mut self, name: String, grad: Tensor) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.insert(name, grad);
        Ok(())
    }

    /// Load optimizer state from `OptimizerState` format
    ///
    /// Converts `OptimizerState` (with `Vec<f32>` values) to PyTorch-compatible format
    fn load_optimizer_state(&mut self, optimizer_state: OptimizerState) -> Result<()> {
        // Convert momentum buffers from Vec<f32> to Tensor format (SGD momentum)
        for (param_name, momentum_data) in optimizer_state.momentum {
            let momentum_tensor = Tensor::new(momentum_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), momentum_tensor.clone());
            }
        }

        // SGD typically doesn't use variance buffers, but handle them if present
        for (param_name, variance_data) in optimizer_state.variance {
            let variance_tensor = Tensor::new(variance_data)?;
            let mut params = self
                .parameters
                .lock()
                .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
            if !params.contains_key(&param_name) {
                params.insert(param_name.clone(), variance_tensor.clone());
            }
        }

        Ok(())
    }
}

impl PyTorchOptimizer for PyTorchSGD {
    fn param_groups(&self) -> &[PyTorchParamGroup] {
        &self.param_groups
    }

    fn param_groups_mut(&mut self) -> &mut [PyTorchParamGroup] {
        &mut self.param_groups
    }

    fn state_dict(&self) -> PyTorchOptimizerState {
        let state = self.inner.state();
        let state_json = serde_json::to_value(state).unwrap_or_default();

        PyTorchOptimizerState {
            state: [(String::from("sgd_state"), state_json)].into(),
            param_groups: self.param_groups.clone(),
        }
    }

    fn load_state_dict(&mut self, state: PyTorchOptimizerState) -> Result<()> {
        self.param_groups = state.param_groups;

        if let Some(sgd_state) = state.state.get("sgd_state") {
            if let Ok(optimizer_state) = serde_json::from_value::<OptimizerState>(sgd_state.clone())
            {
                // Convert OptimizerState to PyTorch-compatible format
                self.load_optimizer_state(optimizer_state)?;
            }
        }

        Ok(())
    }

    fn step(&mut self, closure: Option<Box<dyn Fn() -> f64>>) -> Result<Option<f64>> {
        let loss = closure.map(|closure_fn| closure_fn());

        for group in &self.param_groups {
            for param_name in &group.params {
                // Get copies of parameter and gradient to avoid borrow conflicts
                let param_copy = {
                    let params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.get(param_name).cloned()
                };
                let grad_copy = {
                    let grads = self.gradients.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    grads.get(param_name).cloned()
                };

                if let (Some(mut param), Some(grad)) = (param_copy, grad_copy) {
                    // Apply update to parameter copy
                    self.inner.update(&mut param, &grad)?;

                    // Store updated parameter back
                    let mut params = self.parameters.lock().map_err(|_| {
                        TrustformersError::runtime_error("Mutex lock poisoned".into())
                    })?;
                    params.insert(param_name.clone(), param);
                }
            }
        }

        Ok(loss)
    }

    fn zero_grad(&mut self, _set_to_none: bool) -> Result<()> {
        let mut grads = self
            .gradients
            .lock()
            .map_err(|_| TrustformersError::runtime_error("Mutex lock poisoned".into()))?;
        grads.clear();
        Ok(())
    }

    fn add_param_group(&mut self, param_group: PyTorchParamGroup) -> Result<()> {
        self.param_groups.push(param_group);
        Ok(())
    }

    fn defaults(&self) -> PyTorchParamGroup {
        PyTorchParamGroup {
            lr: 0.01,
            momentum: Some(0.0),
            dampening: Some(0.0),
            weight_decay: 0.0,
            ..Default::default()
        }
    }
}

/// PyTorch optimizer factory for creating optimizers with PyTorch-compatible API
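///
/// A short sketch of the three constructors; the argument order mirrors
/// `torch.optim.Adam`, `torch.optim.AdamW`, and `torch.optim.SGD`:
///
/// ```ignore
/// let params = vec![("w".to_string(), Tensor::zeros(&[8, 8])?)];
/// let adam = PyTorchOptimizerFactory::adam(params.clone(), 1e-3, (0.9, 0.999), 1e-8, 0.0, false)?;
/// let adamw = PyTorchOptimizerFactory::adamw(params.clone(), 1e-3, (0.9, 0.999), 1e-8, 0.01, false)?;
/// let sgd = PyTorchOptimizerFactory::sgd(params, 0.01, 0.9, 0.0, 0.0, true)?;
/// ```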
pub struct PyTorchOptimizerFactory;

impl PyTorchOptimizerFactory {
    /// Create Adam optimizer with PyTorch API
    pub fn adam(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        amsgrad: bool,
    ) -> Result<PyTorchAdam> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            betas: Some(betas),
            eps: Some(eps),
            weight_decay,
            amsgrad: Some(amsgrad),
            ..Default::default()
        };

        PyTorchAdam::new(vec![param_group], lr, betas, eps, weight_decay, amsgrad)
    }

    /// Create AdamW optimizer with PyTorch API
    pub fn adamw(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        betas: (f64, f64),
        eps: f64,
        weight_decay: f64,
        amsgrad: bool,
    ) -> Result<PyTorchAdamW> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            betas: Some(betas),
            eps: Some(eps),
            weight_decay,
            amsgrad: Some(amsgrad),
            ..Default::default()
        };

        PyTorchAdamW::new(vec![param_group], lr, betas, eps, weight_decay, amsgrad)
    }

    /// Create SGD optimizer with PyTorch API
    pub fn sgd(
        params: impl IntoIterator<Item = (String, Tensor)>,
        lr: f64,
        momentum: f64,
        dampening: f64,
        weight_decay: f64,
        nesterov: bool,
    ) -> Result<PyTorchSGD> {
        let param_group = PyTorchParamGroup {
            params: params.into_iter().map(|(name, _)| name).collect(),
            lr,
            momentum: Some(momentum),
            dampening: Some(dampening),
            weight_decay,
            ..Default::default()
        };

        PyTorchSGD::new(
            vec![param_group],
            lr,
            momentum,
            dampening,
            weight_decay,
            nesterov,
        )
    }
}

/// PyTorch-compatible learning rate scheduler wrapper
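///
/// Sketch of wrapping an optimizer with a scheduler (illustrative; `make_scheduler()` is a
/// hypothetical stand-in for whatever [`LRScheduler`] implementation the caller already has):
///
/// ```ignore
/// let scheduler: Box<dyn LRScheduler> = make_scheduler();
/// let optimizer: Box<dyn PyTorchOptimizer> = Box::new(adam);
/// let mut lr_scheduler = PyTorchLRScheduler::new(optimizer, scheduler);
///
/// for epoch in 0..10i64 {
///     // ... run the training epoch, stepping the underlying optimizer ...
///     lr_scheduler.step(Some(epoch))?;
///     println!("lr after epoch {epoch}: {}", lr_scheduler.get_last_lr());
/// }
/// ```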
pub struct PyTorchLRScheduler {
    inner_scheduler: Box<dyn LRScheduler>,
    optimizer: Box<dyn PyTorchOptimizer>,
    last_epoch: i64,
}

impl PyTorchLRScheduler {
    /// Create new scheduler wrapper
    pub fn new(optimizer: Box<dyn PyTorchOptimizer>, scheduler: Box<dyn LRScheduler>) -> Self {
        Self {
            inner_scheduler: scheduler,
            optimizer,
            last_epoch: -1,
        }
    }

    /// Step the scheduler
    pub fn step(&mut self, epoch: Option<i64>) -> Result<()> {
        let current_epoch = epoch.unwrap_or(self.last_epoch + 1);
        self.last_epoch = current_epoch;

        let new_lr = self.inner_scheduler.get_lr(current_epoch as usize);

        // Update all parameter groups
        for group in self.optimizer.param_groups_mut() {
            group.lr = new_lr as f64;
        }

        Ok(())
    }

    /// Get current learning rate
    pub fn get_last_lr(&self) -> f64 {
        self.inner_scheduler.get_lr(self.last_epoch.max(0) as usize) as f64
    }

    /// Get current state dict
    pub fn state_dict(&self) -> serde_json::Value {
        serde_json::json!({
            "last_epoch": self.last_epoch,
            "scheduler_state": "serialized_state" // Would need scheduler serialization
        })
    }

    /// Load state dict
    pub fn load_state_dict(&mut self, state: serde_json::Value) -> Result<()> {
        if let Some(epoch) = state.get("last_epoch").and_then(|e| e.as_i64()) {
            self.last_epoch = epoch;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use trustformers_core::Tensor;

    #[test]
    fn test_pytorch_adam_creation() {
        let params = vec![
            ("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap()),
            ("param2".to_string(), Tensor::zeros(&[5, 5]).unwrap()),
        ];

        let optimizer = PyTorchAdam::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.param_groups()[0].params.len(), 2);
    }

    #[test]
    fn test_pytorch_adamw_creation() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchAdamW::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.defaults().weight_decay, 0.01);
    }

    #[test]
    fn test_pytorch_sgd_creation() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchSGD::from_params(params).unwrap();
        assert_eq!(optimizer.param_groups().len(), 1);
        assert_eq!(optimizer.defaults().lr, 0.01);
    }

    #[test]
    fn test_pytorch_optimizer_factory() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let adam =
            PyTorchOptimizerFactory::adam(params.clone(), 0.001, (0.9, 0.999), 1e-8, 0.0, false)
                .unwrap();
        assert_eq!(adam.param_groups()[0].lr, 0.001);

        let adamw =
            PyTorchOptimizerFactory::adamw(params.clone(), 0.001, (0.9, 0.999), 1e-8, 0.01, false)
                .unwrap();
        assert_eq!(adamw.param_groups()[0].weight_decay, 0.01);

        let sgd = PyTorchOptimizerFactory::sgd(params, 0.01, 0.9, 0.0, 0.0, false).unwrap();
        assert_eq!(sgd.param_groups()[0].momentum, Some(0.9));
    }

    #[test]
    fn test_param_group_operations() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let mut optimizer = PyTorchAdam::from_params(params).unwrap();

        let new_group = PyTorchParamGroup {
            params: vec!["param2".to_string()],
            lr: 0.002,
            ..Default::default()
        };

        optimizer.add_param_group(new_group).unwrap();
        assert_eq!(optimizer.param_groups().len(), 2);
        assert_eq!(optimizer.param_groups()[1].lr, 0.002);
    }

    #[test]
    fn test_state_dict_operations() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let optimizer = PyTorchAdam::from_params(params).unwrap();
        let state_dict = optimizer.state_dict();

        assert_eq!(state_dict.param_groups.len(), 1);
        assert!(state_dict.state.contains_key("adam_state"));
    }

    #[test]
    fn test_zero_grad() {
        let params = vec![("param1".to_string(), Tensor::zeros(&[10, 10]).unwrap())];

        let mut optimizer = PyTorchAdam::from_params(params).unwrap();
        optimizer
            .set_grad("param1".to_string(), Tensor::ones(&[10, 10]).unwrap())
            .unwrap();

        // Check that gradient is set
        assert_eq!(
            optimizer.gradients.lock().expect("Mutex lock poisoned").len(),
            1
        );

        // Zero gradients
        optimizer.zero_grad(false).unwrap();
        assert_eq!(
            optimizer.gradients.lock().expect("Mutex lock poisoned").len(),
            0
        );
    }
}