// torsh_optim — crate root (lib.rs)
//! Optimization algorithms for ToRSh
//!
//! This crate provides PyTorch-compatible optimizers built on top of scirs2-optim.
//!
//! # Features
//!
//! - **80+ optimizers**: Comprehensive collection including Adam, SGD, RAdam, Ranger, Lion, Sophia, and more
//! - **Modern optimizers**: Latest research including Schedule-Free AdamW and Prodigy
//! - **Second-order methods**: L-BFGS, Newton-CG, Trust Region, K-FAC, AdaHessian
//! - **Learning rate schedulers**: Step, exponential, cosine annealing, one-cycle, and more
//! - **Mixed precision training**: Full fp16/fp32 support with loss scaling
//! - **Distributed optimization**: AsyncSGD, Elastic Averaging, Federated Learning
//! - **Advanced features**: Gradient accumulation, fused kernels, memory-efficient implementations
//! - **Research features**: Quantum-inspired, neuromorphic, continual learning, green AI optimizers
//!
//! # Quick Start
//!
//! ```rust,no_run
//! use torsh_optim::prelude::*;
//! use torsh_tensor::Tensor;
//! use std::sync::Arc;
//! use parking_lot::RwLock;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create parameters
//! let params = vec![Arc::new(RwLock::new(Tensor::scalar(1.0)?))];
//!
//! // Create optimizer
//! let mut optimizer = Adam::new(params, Some(0.001), None, None, None, false);
//!
//! // Training loop
//! for _ in 0..100 {
//!     // ... compute gradients ...
//!     optimizer.step()?;
//!     optimizer.zero_grad();
//! }
//! # Ok(())
//! # }
//! ```

41#![cfg_attr(not(feature = "std"), no_std)]
42// Note: These allows are necessary for maintaining compatibility with diverse optimizer implementations
43// and reducing noise from legitimate design patterns used across the codebase
44#![allow(dead_code)] // Many optimizers have internal methods not called externally
45#![allow(unused_imports)] // Conditional compilation features may leave some imports unused
46#![allow(unused_variables)] // Some optimizer variants have parameters used only in specific configurations
47#![allow(unused_mut)] // Mutability annotations required for consistency even when not always modified
48
49#[cfg(not(feature = "std"))]
50extern crate alloc;
51
52pub mod adabelief;
53pub mod adabound;
54pub mod adadelta;
55pub mod adagrad;
56pub mod adahessian;
57pub mod adam;
58pub mod adamax;
59pub mod advanced;
60pub mod asgd;
61pub mod bayesian_optimization;
62pub mod benchmarks;
63pub mod checkpointing;
64pub mod composition;
65pub mod continual_learning;
66pub mod cross_framework_validation;
67pub mod debugging;
68pub mod differential_privacy;
69pub mod distributed;
70pub mod evolutionary_strategies;
71pub mod ftrl;
72pub mod fused_kernels;
73pub mod grad_accumulation;
74pub mod gradient_free;
75pub mod green_ai;
76pub mod hyperparameter_tuning;
77pub mod kfac;
78pub mod lamb;
79pub mod lazy_updates;
80pub mod lbfgs;
81pub mod lion;
82pub mod lookahead;
83pub mod low_precision;
84pub mod lr_scheduler;
85pub mod lr_scheduler_additional;
86pub mod lr_scheduler_enhanced;
87pub mod memory_efficient;
88pub mod memory_mapped;
89pub mod mixed_precision;
90pub mod nadam;
91pub mod natural_gradient;
92pub mod neural_optimizer;
93pub mod neuromorphic;
94pub mod newton_cg;
95pub mod numerical_stability_tests;
96pub mod online_learning;
97pub mod optimizer;
98pub mod prodigy;
99pub mod quantum_inspired;
100pub mod radam;
101pub mod ranger;
102pub mod rmsprop;
103pub mod robustness;
104pub mod rprop;
105pub mod schedule_free;
106pub mod sgd;
107pub mod shampoo;
108pub mod sophia;
109pub mod sparse_adam;
110pub mod sparse_updates;
111pub mod state_dict_ops;
112pub mod stress_tests;
113pub mod trust_region;
114pub mod yellowfin;
115
116use parking_lot::RwLock;
117use std::collections::HashMap;
118use std::sync::Arc;
119use torsh_core::error::{Result, TorshError};
120use torsh_tensor::Tensor;
121
122/// Optimizer-specific error type
123#[derive(Debug, thiserror::Error)]
124pub enum OptimizerError {
125    #[error("Tensor operation failed: {0}")]
126    TensorError(#[from] torsh_core::error::TorshError),
127
128    #[error("Invalid parameter: {0}")]
129    InvalidParameter(String),
130
131    #[error("Serialization error: {0}")]
132    SerializationError(String),
133
134    #[error("IO error: {0}")]
135    IoError(#[from] std::io::Error),
136
137    #[error("Checkpoint error: {0}")]
138    CheckpointError(String),
139
140    #[error("Configuration error: {0}")]
141    ConfigError(String),
142
143    #[error("State error: {0}")]
144    StateError(String),
145
146    #[error("Invalid input: {0}")]
147    InvalidInput(String),
148
149    #[error("Numerical error: {0}")]
150    NumericalError(String),
151
152    #[error("Memory map error: {0}")]
153    MemoryMapError(String),
154}
155
156impl From<OptimizerError> for torsh_core::error::TorshError {
157    fn from(err: OptimizerError) -> Self {
158        match err {
159            OptimizerError::TensorError(e) => e,
160            OptimizerError::InvalidParameter(msg) => {
161                torsh_core::error::TorshError::InvalidArgument(msg)
162            }
163            OptimizerError::SerializationError(msg) => {
164                torsh_core::error::TorshError::SerializationError(msg)
165            }
166            OptimizerError::IoError(e) => torsh_core::error::TorshError::IoError(e.to_string()),
167            OptimizerError::CheckpointError(msg) => {
168                torsh_core::error::TorshError::RuntimeError(msg)
169            }
170            OptimizerError::ConfigError(msg) => torsh_core::error::TorshError::ConfigError(msg),
171            OptimizerError::StateError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
172            OptimizerError::InvalidInput(msg) => {
173                torsh_core::error::TorshError::InvalidArgument(msg)
174            }
175            OptimizerError::NumericalError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
176            OptimizerError::MemoryMapError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
177        }
178    }
179}
180
181/// Result type for optimizer operations
182pub type OptimizerResult<T> = std::result::Result<T, OptimizerError>;
183
184// Version information
185pub const VERSION: &str = env!("CARGO_PKG_VERSION");
186pub const VERSION_MAJOR: u32 = 0;
187pub const VERSION_MINOR: u32 = 1;
188pub const VERSION_PATCH: u32 = 0;
189
190// Re-export scirs2 optimizer functionality
191// use scirs2::optim as sci_optim;
192
193/// Base optimizer trait
194pub trait Optimizer {
195    /// Perform a single optimization step
196    fn step(&mut self) -> OptimizerResult<()>;
197
198    /// Zero all gradients
199    fn zero_grad(&mut self);
200
201    /// Get the current learning rate
202    fn get_lr(&self) -> Vec<f32>;
203
204    /// Set the learning rate
205    fn set_lr(&mut self, lr: f32);
206
207    /// Add a parameter group
208    fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>);
209
210    /// Get state dict for serialization
211    fn state_dict(&self) -> OptimizerResult<OptimizerState>;
212
213    /// Load state dict
214    fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()>;
215}
216
217/// Optimizer state for serialization
218#[derive(Debug, Clone)]
219pub struct OptimizerState {
220    /// Optimizer type identifier
221    pub optimizer_type: String,
222    /// Version of the state format
223    pub version: String,
224    /// Parameter group states
225    pub param_groups: Vec<ParamGroupState>,
226    /// Per-parameter optimizer state (keyed by parameter ID)
227    pub state: HashMap<String, HashMap<String, Tensor>>,
228    /// Global optimizer state
229    pub global_state: HashMap<String, f32>,
230}
231
/// Parameter group state: the serializable snapshot of one [`ParamGroup`]
/// (learning rate, per-group options, and the parameter count used to
/// validate structural compatibility when restoring).
#[derive(Debug, Clone)]
pub struct ParamGroupState {
    /// Learning rate for this group
    pub lr: f32,
    /// Additional options for this group (e.g. "weight_decay", "eps")
    pub options: HashMap<String, f32>,
    /// Number of parameters in this group (for validation)
    pub param_count: usize,
}
242
243impl OptimizerState {
244    /// Create a new empty optimizer state
245    pub fn new(optimizer_type: String) -> Self {
246        Self {
247            optimizer_type,
248            version: VERSION.to_string(),
249            param_groups: Vec::new(),
250            state: HashMap::new(),
251            global_state: HashMap::new(),
252        }
253    }
254
255    /// Validate the state structure
256    pub fn validate(&self) -> Result<()> {
257        if self.optimizer_type.is_empty() {
258            return Err(TorshError::InvalidArgument(
259                "Optimizer type cannot be empty".to_string(),
260            ));
261        }
262
263        // Check that all parameter groups are valid
264        for (i, group) in self.param_groups.iter().enumerate() {
265            if !group.lr.is_finite() || group.lr <= 0.0 {
266                return Err(TorshError::InvalidArgument(format!(
267                    "Invalid learning rate in group {}",
268                    i
269                )));
270            }
271        }
272
273        // Check that all state values are finite
274        for (param_id, param_state) in &self.state {
275            for (state_name, tensor) in param_state {
276                // For now, just check that the keys are valid
277                if param_id.is_empty() || state_name.is_empty() {
278                    return Err(TorshError::InvalidArgument(
279                        "State keys cannot be empty".to_string(),
280                    ));
281                }
282            }
283        }
284
285        Ok(())
286    }
287
288    /// Get the total number of parameters across all groups
289    pub fn total_param_count(&self) -> usize {
290        self.param_groups.iter().map(|g| g.param_count).sum()
291    }
292
293    /// Check if state is compatible with another state (same structure)
294    pub fn is_compatible_with(&self, other: &OptimizerState) -> bool {
295        self.optimizer_type == other.optimizer_type
296            && self.param_groups.len() == other.param_groups.len()
297            && self
298                .param_groups
299                .iter()
300                .zip(other.param_groups.iter())
301                .all(|(a, b)| a.param_count == b.param_count)
302    }
303}
304
305impl ParamGroupState {
306    /// Create a new parameter group state
307    pub fn new(lr: f32, param_count: usize) -> Self {
308        Self {
309            lr,
310            options: HashMap::new(),
311            param_count,
312        }
313    }
314
315    /// Create from a ParamGroup
316    pub fn from_param_group(group: &ParamGroup) -> Self {
317        Self {
318            lr: group.lr,
319            options: group.options.clone(),
320            param_count: group.params.len(),
321        }
322    }
323
324    /// Get an option value with a default
325    pub fn get_option(&self, key: &str, default: f32) -> f32 {
326        self.options.get(key).copied().unwrap_or(default)
327    }
328
329    /// Set an option value
330    pub fn set_option(&mut self, key: String, value: f32) {
331        self.options.insert(key, value);
332    }
333}
334
335/// Parameter group
336#[derive(Debug, Clone)]
337pub struct ParamGroup {
338    pub params: Vec<Arc<RwLock<Tensor>>>,
339    pub lr: f32,
340    pub options: HashMap<String, f32>,
341}
342
343/// Builder for creating parameter groups with various options
344#[derive(Debug)]
345pub struct ParamGroupBuilder {
346    params: Vec<Arc<RwLock<Tensor>>>,
347    lr: f32,
348    options: HashMap<String, f32>,
349}
350
351impl ParamGroupBuilder {
352    /// Create a new parameter group builder
353    pub fn new(lr: f32) -> Self {
354        Self {
355            params: Vec::new(),
356            lr,
357            options: HashMap::new(),
358        }
359    }
360
361    /// Add parameters to the group
362    pub fn params(mut self, params: Vec<Arc<RwLock<Tensor>>>) -> Self {
363        self.params = params;
364        self
365    }
366
367    /// Add a single parameter to the group
368    pub fn add_param(mut self, param: Arc<RwLock<Tensor>>) -> Self {
369        self.params.push(param);
370        self
371    }
372
373    /// Set weight decay
374    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
375        self.options
376            .insert("weight_decay".to_string(), weight_decay);
377        self
378    }
379
380    /// Set epsilon
381    pub fn eps(mut self, eps: f32) -> Self {
382        self.options.insert("eps".to_string(), eps);
383        self
384    }
385
386    /// Set a custom option
387    pub fn option(mut self, key: String, value: f32) -> Self {
388        self.options.insert(key, value);
389        self
390    }
391
392    /// Set options from OptimizerOptions
393    pub fn from_options(mut self, options: &OptimizerOptions) -> Self {
394        self.lr = options.lr;
395        self.options = options.to_hashmap();
396        self.options.remove("lr"); // lr is stored separately
397        self
398    }
399
400    /// Build the parameter group
401    pub fn build(self) -> ParamGroup {
402        ParamGroup {
403            params: self.params,
404            lr: self.lr,
405            options: self.options,
406        }
407    }
408}
409
410impl ParamGroup {
411    pub fn new(params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> Self {
412        Self {
413            params,
414            lr,
415            options: HashMap::new(),
416        }
417    }
418
419    pub fn with_options(mut self, options: HashMap<String, f32>) -> Self {
420        self.options = options;
421        self
422    }
423
424    /// Add a single parameter to the group
425    pub fn add_param(&mut self, param: Arc<RwLock<Tensor>>) {
426        self.params.push(param);
427    }
428
429    /// Get a specific option value, falling back to a default
430    pub fn get_option(&self, key: &str, default: f32) -> f32 {
431        self.options.get(key).copied().unwrap_or(default)
432    }
433
434    /// Set a specific option value
435    pub fn set_option(&mut self, key: String, value: f32) {
436        self.options.insert(key, value);
437    }
438
439    /// Get the number of parameters in this group
440    pub fn param_count(&self) -> usize {
441        self.params.len()
442    }
443
444    /// Check if this group has any parameters
445    pub fn is_empty(&self) -> bool {
446        self.params.is_empty()
447    }
448
449    /// Get all parameters that have gradients
450    pub fn params_with_grads(&self) -> Vec<&Arc<RwLock<Tensor>>> {
451        self.params
452            .iter()
453            .filter(|param| param.read().has_grad())
454            .collect()
455    }
456
457    /// Validate that all parameters in the group are valid
458    pub fn validate(&self) -> bool {
459        !self.params.is_empty() && self.lr.is_finite() && self.lr > 0.0
460    }
461
462    /// Get parameter count for each unique shape in the group
463    pub fn get_shape_counts(&self) -> HashMap<Vec<usize>, usize> {
464        let mut shape_counts = HashMap::new();
465        for param in &self.params {
466            let shape = param.read().shape().dims().to_vec();
467            *shape_counts.entry(shape).or_insert(0) += 1;
468        }
469        shape_counts
470    }
471
472    /// Get total number of parameters (not tensors, but individual parameters)
473    pub fn total_param_count(&self) -> usize {
474        self.params.iter().map(|param| param.read().numel()).sum()
475    }
476
477    /// Clear gradients for all parameters in this group
478    pub fn zero_grad(&self) {
479        for param in &self.params {
480            param.write().zero_grad();
481        }
482    }
483
484    /// Check if any parameter in the group has gradients
485    pub fn has_any_grads(&self) -> bool {
486        self.params.iter().any(|param| param.read().has_grad())
487    }
488
489    /// Get gradient norm for all parameters in the group
490    pub fn grad_norm(&self) -> Result<f32> {
491        let mut total_norm_sq = 0.0f32;
492
493        for param in &self.params {
494            let param_guard = param.read();
495            if let Some(grad) = param_guard.grad() {
496                let grad_norm = grad.norm().map_err(|e| {
497                    TorshError::Other(format!("Failed to compute gradient norm: {}", e))
498                })?;
499                let norm_value = grad_norm.to_vec().map_err(|e| {
500                    TorshError::Other(format!("Failed to extract norm value: {}", e))
501                })?[0];
502                total_norm_sq += norm_value * norm_value;
503            }
504        }
505
506        Ok(total_norm_sq.sqrt())
507    }
508
509    /// Apply gradient clipping to all parameters in the group
510    pub fn clip_grads(&self, max_norm: f32) -> Result<f32> {
511        let total_norm = self.grad_norm()?;
512
513        if total_norm > max_norm {
514            let scale = max_norm / total_norm;
515            for param in &self.params {
516                let mut param_guard = param.write();
517                if let Some(grad) = param_guard.grad() {
518                    let clipped_grad = grad.mul_scalar(scale).map_err(|e| {
519                        TorshError::Other(format!("Failed to clip gradient: {}", e))
520                    })?;
521                    param_guard.set_grad(Some(clipped_grad));
522                }
523            }
524        }
525
526        Ok(total_norm)
527    }
528}
529
/// Common optimizer options shared across optimizer implementations.
#[derive(Debug, Clone)]
pub struct OptimizerOptions {
    /// Learning rate (step size).
    pub lr: f32,
    /// L2 weight-decay coefficient (0.0 disables decay).
    pub weight_decay: f32,
    /// Numerical-stability epsilon added to denominators.
    pub eps: f32,
    /// When true, ascend the objective instead of descending.
    pub maximize: bool,
}

impl Default for OptimizerOptions {
    /// PyTorch-style defaults: lr=1e-3, no weight decay, eps=1e-8, minimize.
    fn default() -> Self {
        Self {
            lr: 1e-3,
            weight_decay: 0.0,
            eps: 1e-8,
            maximize: false,
        }
    }
}
549
550impl OptimizerOptions {
551    /// Create new optimizer options with specified learning rate
552    pub fn new(lr: f32) -> Self {
553        Self {
554            lr,
555            ..Default::default()
556        }
557    }
558
559    /// Set weight decay
560    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
561        self.weight_decay = weight_decay;
562        self
563    }
564
565    /// Set epsilon value for numerical stability
566    pub fn with_eps(mut self, eps: f32) -> Self {
567        self.eps = eps;
568        self
569    }
570
571    /// Set maximize flag (for maximization problems)
572    pub fn with_maximize(mut self, maximize: bool) -> Self {
573        self.maximize = maximize;
574        self
575    }
576
577    /// Convert to HashMap for compatibility with parameter groups
578    pub fn to_hashmap(&self) -> HashMap<String, f32> {
579        let mut map = HashMap::new();
580        map.insert("lr".to_string(), self.lr);
581        map.insert("weight_decay".to_string(), self.weight_decay);
582        map.insert("eps".to_string(), self.eps);
583        map.insert(
584            "maximize".to_string(),
585            if self.maximize { 1.0 } else { 0.0 },
586        );
587        map
588    }
589
590    /// Create from HashMap
591    pub fn from_hashmap(map: &HashMap<String, f32>) -> Self {
592        Self {
593            lr: map.get("lr").copied().unwrap_or(1e-3),
594            weight_decay: map.get("weight_decay").copied().unwrap_or(0.0),
595            eps: map.get("eps").copied().unwrap_or(1e-8),
596            maximize: map.get("maximize").copied().unwrap_or(0.0) > 0.0,
597        }
598    }
599
600    /// Validate the options are reasonable
601    pub fn validate(&self) -> Result<()> {
602        if !self.lr.is_finite() || self.lr <= 0.0 {
603            return Err(TorshError::InvalidArgument(
604                "Learning rate must be positive and finite".to_string(),
605            ));
606        }
607        if !self.weight_decay.is_finite() || self.weight_decay < 0.0 {
608            return Err(TorshError::InvalidArgument(
609                "Weight decay must be non-negative and finite".to_string(),
610            ));
611        }
612        if !self.eps.is_finite() || self.eps <= 0.0 {
613            return Err(TorshError::InvalidArgument(
614                "Epsilon must be positive and finite".to_string(),
615            ));
616        }
617        Ok(())
618    }
619
620    /// Create standardized state dict for any optimizer
621    pub fn create_standard_state_dict(
622        optimizer_type: &str,
623        version: Option<&str>,
624        param_groups: &[ParamGroup],
625        state: &HashMap<String, HashMap<String, Tensor>>,
626        global_state: Option<HashMap<String, f32>>,
627    ) -> OptimizerState {
628        let param_group_states = param_groups
629            .iter()
630            .map(|g| ParamGroupState::from_param_group(g))
631            .collect();
632
633        let mut optimizer_state = OptimizerState {
634            optimizer_type: optimizer_type.to_string(),
635            version: version.unwrap_or("1.0").to_string(),
636            param_groups: param_group_states,
637            state: state.clone(),
638            global_state: global_state.unwrap_or_default(),
639        };
640
641        optimizer_state
642    }
643
644    /// Validate state dict compatibility between optimizers
645    pub fn validate_state_compatibility(
646        current_groups: &[ParamGroup],
647        state_groups: &[ParamGroupState],
648    ) -> Result<()> {
649        if current_groups.len() != state_groups.len() {
650            return Err(TorshError::InvalidArgument(format!(
651                "Parameter group count mismatch: expected {}, got {}",
652                current_groups.len(),
653                state_groups.len()
654            )));
655        }
656
657        for (i, (current_group, state_group)) in
658            current_groups.iter().zip(state_groups.iter()).enumerate()
659        {
660            if current_group.params.len() != state_group.param_count {
661                return Err(TorshError::InvalidArgument(format!(
662                    "Parameter count mismatch in group {}: expected {}, got {}",
663                    i,
664                    current_group.params.len(),
665                    state_group.param_count
666                )));
667            }
668        }
669
670        Ok(())
671    }
672}
673
/// Convergence testing utilities
#[cfg(test)]
pub mod convergence_tests {
    use super::*;
    use parking_lot::RwLock;
    use std::ops::Add;
    use std::sync::Arc;
    use torsh_tensor::{
        creation::{randn, zeros},
        Tensor,
    };

    /// Test that an optimizer can minimize a simple quadratic function
    /// f(x, y) = x^2 + y^2 starting from (2, 2), using analytic gradients.
    pub fn test_quadratic_convergence<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        tolerance: f32,
        max_iterations: usize,
    ) -> Result<()> {
        // Create a simple quadratic function: f(x) = x^2 + y^2
        let x = Arc::new(RwLock::new(Tensor::scalar(2.0)?));
        let y = Arc::new(RwLock::new(Tensor::scalar(2.0)?));
        let params = vec![x.clone(), y.clone()];

        let mut optimizer = create_optimizer(params);

        for _ in 0..max_iterations {
            // Compute gradients analytically: df/dx = 2x, df/dy = 2y.
            // Scoped so the read guards are released before writing.
            {
                let x_val = x.read().clone();
                let y_val = y.read().clone();

                let x_grad = x_val.mul_scalar(2.0)?;
                let y_grad = y_val.mul_scalar(2.0)?;

                x.write().set_grad(Some(x_grad));
                y.write().set_grad(Some(y_grad));
            }

            // Optimizer step
            optimizer
                .step()
                .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;

            // Check convergence
            let x_val = x.read().to_vec()?[0];
            let y_val = y.read().to_vec()?[0];
            let loss = x_val * x_val + y_val * y_val;

            if loss < tolerance {
                return Ok(());
            }

            // Clear gradients for next iteration
            optimizer.zero_grad();
        }

        Err(TorshError::Other(format!(
            "Failed to converge within {} iterations",
            max_iterations
        )))
    }

    /// Test that an optimizer can minimize a linear regression problem
    /// on synthetic data generated from y = 2x + 1 + noise.
    pub fn test_linear_regression_convergence<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        tolerance: f32,
        max_iterations: usize,
    ) -> Result<()> {
        // Ground-truth model parameters.
        let true_weight = 2.0;
        let true_bias = 1.0;

        // Generate synthetic data
        let n_samples = 100;
        let x_data = randn::<f32>(&[n_samples, 1])?;
        let noise = randn::<f32>(&[n_samples, 1])?.mul_scalar(0.1)?;
        let y_data = x_data
            .mul_scalar(true_weight)?
            .add_scalar(true_bias)?
            .add(&noise)?;

        // Initialize parameters at zero
        let weight = Arc::new(RwLock::new(zeros(&[1, 1])?));
        let bias = Arc::new(RwLock::new(zeros(&[1])?));
        let params = vec![weight.clone(), bias.clone()];

        let mut optimizer = create_optimizer(params);

        for _ in 0..max_iterations {
            // Forward pass: y_pred = x * weight + bias
            let w_val = weight.read().clone();
            let b_val = bias.read().clone();

            let y_pred = x_data.matmul(&w_val)?.add(&b_val)?;

            // Compute loss: MSE = mean((y_pred - y_true)^2)
            let diff = y_pred.sub(&y_data)?;
            let loss_tensor = diff.pow(2.0)?.mean(Some(&[0]), false)?;
            let loss = loss_tensor.to_vec()?[0];

            // Analytic MSE gradients: dL/dW = (2/n) X^T diff, dL/db = (2/n) sum(diff)
            let grad_scale = 2.0 / n_samples as f32;
            let weight_grad = x_data
                .transpose(0, 1)?
                .matmul(&diff)?
                .mul_scalar(grad_scale)?;
            let bias_grad = diff.sum()?.mul_scalar(grad_scale)?;

            weight.write().set_grad(Some(weight_grad));
            bias.write().set_grad(Some(bias_grad));

            // Optimizer step
            optimizer
                .step()
                .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;

            // Check convergence
            if loss < tolerance {
                // Verify the learned parameters are close to true values
                let learned_weight = weight.read().to_vec()?[0];
                let learned_bias = bias.read().to_vec()?[0];

                if (learned_weight - true_weight).abs() < 0.1
                    && (learned_bias - true_bias).abs() < 0.1
                {
                    return Ok(());
                }
            }

            // Clear gradients for next iteration
            optimizer.zero_grad();
        }

        Err(TorshError::Other(format!(
            "Failed to converge within {} iterations",
            max_iterations
        )))
    }

    /// Test that an optimizer maintains consistent behavior across multiple
    /// independent runs: final values must agree within `tolerance`.
    pub fn test_optimizer_consistency<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        n_runs: usize,
        tolerance: f32,
    ) -> Result<()> {
        let mut final_values = Vec::new();

        for _ in 0..n_runs {
            let param = Arc::new(RwLock::new(Tensor::scalar(1.0)?));
            let params = vec![param.clone()];
            let mut optimizer = create_optimizer(params);

            // Run for a fixed number of steps
            for _ in 0..10 {
                {
                    let param_val = param.read().clone();
                    let grad = param_val.mul_scalar(2.0)?; // Simple gradient
                    param.write().set_grad(Some(grad));
                }

                optimizer
                    .step()
                    .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;
                optimizer.zero_grad();
            }

            final_values.push(param.read().to_vec()?[0]);
        }

        // Check that all runs produce similar results
        let mean_value = final_values.iter().sum::<f32>() / final_values.len() as f32;
        for &value in &final_values {
            if (value - mean_value).abs() > tolerance {
                return Err(TorshError::Other(format!(
                    "Inconsistent optimizer behavior: values vary by more than {}",
                    tolerance
                )));
            }
        }

        Ok(())
    }
}
858
859pub mod prelude {
860    pub use crate::adabelief::AdaBelief;
861    pub use crate::adabound::AdaBound;
862    pub use crate::adadelta::AdaDelta;
863    pub use crate::adagrad::AdaGrad;
864    pub use crate::adahessian::{AdaHessian, AdaHessianBuilder};
865    pub use crate::adam::{Adam, AdamW};
866    pub use crate::adamax::AdaMax;
867    pub use crate::asgd::ASGD;
868    pub use crate::checkpointing::{
869        Checkpoint, CheckpointConfig, CheckpointManager, CheckpointMetadata, CheckpointStatistics,
870        CheckpointSupport, CheckpointingOptimizer,
871    };
872    pub use crate::composition::{
873        CombinationMethod, ComposedOptimizer, CompositionBuilder, CompositionStrategy,
874        OptimizerMetrics, SwitchCriterion, VotingMethod,
875    };
876    pub use crate::debugging::{
877        AnalysisReport, AnalyzerConfig, ConvergenceTracker, GradientFlowPoint, GradientStatistics,
878        HyperparameterSensitivity, OptimizationRecommendation, OptimizationStep, OptimizerAnalyzer,
879        ParameterStatistics, RecommendationCategory, SensitivityReport, SensitivityResult,
880        Severity,
881    };
882    pub use crate::distributed::{
883        utils as distributed_utils, AsyncConfig, AsyncSGD, CommunicationStats, DistributedBackend,
884        DistributedConfig, DistributedOptimizer, ElasticAveragingSGD, SyncStrategy,
885    };
886    pub use crate::ftrl::{FTRLBuilder, FTRL};
887    pub use crate::fused_kernels::{
888        fused_adadelta_step, fused_adagrad_step, fused_adam_step, fused_rmsprop_step,
889        fused_sgd_step, FusedKernelSupport, FusedStats,
890    };
891    pub use crate::grad_accumulation::{
892        with_gradient_accumulation, AccumulatingOptimizer, GradientAccumulationSupport,
893        GradientAccumulator,
894    };
895    pub use crate::kfac::{KFACBuilder, KFAC};
896    pub use crate::lamb::LAMB;
897    pub use crate::lazy_updates::{
898        LazyUpdateConfig, LazyUpdateDecision, LazyUpdateManager, LazyUpdateOptimizer,
899        LazyUpdateStatistics, LazyUpdateSupport, ParameterImportance, PendingUpdate,
900        UpdatePriority,
901    };
902    pub use crate::lbfgs::LBFGS;
903    pub use crate::lion::{Lion, LionBuilder, LionConfig};
904    pub use crate::lookahead::{lookahead_adam, lookahead_radam, lookahead_sgd, Lookahead};
905    pub use crate::low_precision::{
906        LowPrecisionConvertible, LowPrecisionOptimizer, LowPrecisionState, PrecisionType,
907        StateStatistics,
908    };
909    pub use crate::lr_scheduler::{
910        CosineAnnealingLR, ExponentialLR, LRScheduler, OneCycleLR, ReduceLROnPlateau, StepLR,
911    };
912    pub use crate::lr_scheduler_additional::{
913        ConstantLR, CosineAnnealingWarmRestarts, CyclicLR, LinearLR, MultiStepLR, PolynomialLR,
914    };
915    pub use crate::lr_scheduler_enhanced::{
916        utils as lr_enhanced_utils, AdaptiveLRScheduler, AdaptiveSchedulerStats, AdaptiveStrategy,
917        CosineAnnealingWarmRestartsWithWarmup, PolynomialDecayWithWarmup, WarmupStrategy,
918    };
919    pub use crate::memory_efficient::{
920        CircularBuffer, MemoryConfig, MemoryEfficientAdam, MemoryEfficientLBFGS,
921        MemoryEfficientOptimizerBuilder, MemoryPool,
922    };
923    pub use crate::memory_mapped::{
924        MemoryMappedConfig, MemoryMappedFile, MemoryMappedOptimizer, MemoryMappedStateStorage,
925        MemoryMappedSupport, StorageStatistics,
926    };
927    pub use crate::mixed_precision::{
928        with_mixed_precision, MixedPrecisionConfig, MixedPrecisionOptimizer,
929    };
930    pub use crate::nadam::NAdam;
931    pub use crate::natural_gradient::{NaturalGradient, NaturalGradientBuilder};
932    pub use crate::newton_cg::{NewtonCG, NewtonCGBuilder, NewtonCGConfig};
933    pub use crate::online_learning::{
934        OnlineGradientDescent, ProximalGradient, ProximalOperator, SAGA, SVRG,
935    };
936    pub use crate::prodigy::{Prodigy, ProdigyBuilder, ProdigyConfig};
937    pub use crate::radam::RAdam;
938    pub use crate::ranger::{Ranger, RangerBuilder};
939    pub use crate::rmsprop::RMSprop;
940    pub use crate::rprop::Rprop;
941    pub use crate::schedule_free::{ScheduleFreeAdamW, ScheduleFreeAdamWBuilder};
942    pub use crate::sgd::SGD;
943    pub use crate::shampoo::{Shampoo, ShampooBuilder};
944    pub use crate::sophia::{Sophia, SophiaBuilder, SophiaConfig};
945    pub use crate::sparse_adam::SparseAdam;
946    pub use crate::state_dict_ops::{
947        CompressionMethod, CompressionStats, MemoryEstimate, SerializationFormat, StateDictConfig,
948        StateDictManager,
949    };
950    pub use crate::trust_region::{
951        SubproblemSolver, TrustRegionBuilder, TrustRegionConfig, TrustRegionMethod,
952        TrustRegionStrategy,
953    };
954    pub use crate::yellowfin::{YellowFin, YellowFinBuilder, YellowFinConfig};
955    pub use crate::{Optimizer, OptimizerOptions, OptimizerState, ParamGroup, ParamGroupBuilder};
956    pub use crate::{OptimizerError, OptimizerResult};
957}
958
959// Re-export commonly used types
960pub use adam::{Adam, AdamW};
961pub use distributed::{DistributedBackend, DistributedConfig, DistributedOptimizer, SyncStrategy};
962pub use rmsprop::RMSprop;
963pub use sgd::SGD;
964
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built group stores its lr, reports empty, and fails
    /// validation (validation requires at least one parameter).
    #[test]
    fn test_param_group() {
        let params = vec![];
        let group = ParamGroup::new(params, 0.01);
        assert_eq!(group.lr, 0.01);
        assert!(group.is_empty());
        assert_eq!(group.param_count(), 0);
        assert!(!group.validate());
    }
}