#![cfg_attr(not(feature = "std"), no_std)]
//! Optimizer implementations and shared optimizer infrastructure (parameter
//! groups, optimizer state, and error types) built on `torsh_core` and
//! `torsh_tensor`.

#![allow(dead_code)]
#![allow(unused_imports)]
#![allow(unused_variables)]
#![allow(unused_mut)]

#[cfg(not(feature = "std"))]
extern crate alloc;

pub mod adabelief;
pub mod adabound;
pub mod adadelta;
pub mod adagrad;
pub mod adahessian;
pub mod adam;
pub mod adamax;
pub mod advanced;
pub mod asgd;
pub mod bayesian_optimization;
pub mod benchmarks;
pub mod checkpointing;
pub mod composition;
pub mod continual_learning;
pub mod cross_framework_validation;
pub mod debugging;
pub mod differential_privacy;
pub mod distributed;
pub mod evolutionary_strategies;
pub mod ftrl;
pub mod fused_kernels;
pub mod grad_accumulation;
pub mod gradient_free;
pub mod green_ai;
pub mod hyperparameter_tuning;
pub mod kfac;
pub mod lamb;
pub mod lazy_updates;
pub mod lbfgs;
pub mod lion;
pub mod lookahead;
pub mod low_precision;
pub mod lr_scheduler;
pub mod lr_scheduler_additional;
pub mod lr_scheduler_enhanced;
pub mod memory_efficient;
pub mod memory_mapped;
pub mod mixed_precision;
pub mod nadam;
pub mod natural_gradient;
pub mod neural_optimizer;
pub mod neuromorphic;
pub mod newton_cg;
pub mod numerical_stability_tests;
pub mod online_learning;
pub mod optimizer;
pub mod prodigy;
pub mod quantum_inspired;
pub mod radam;
pub mod ranger;
pub mod rmsprop;
pub mod robustness;
pub mod rprop;
pub mod schedule_free;
pub mod sgd;
pub mod shampoo;
pub mod sophia;
pub mod sparse_adam;
pub mod sparse_updates;
pub mod state_dict_ops;
pub mod stress_tests;
pub mod trust_region;
pub mod yellowfin;

use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use torsh_core::error::{Result, TorshError};
use torsh_tensor::Tensor;

/// Errors that can occur during optimizer operations.
#[derive(Debug, thiserror::Error)]
pub enum OptimizerError {
    #[error("Tensor operation failed: {0}")]
    TensorError(#[from] torsh_core::error::TorshError),

    #[error("Invalid parameter: {0}")]
    InvalidParameter(String),

    #[error("Serialization error: {0}")]
    SerializationError(String),

    #[error("IO error: {0}")]
    IoError(#[from] std::io::Error),

    #[error("Checkpoint error: {0}")]
    CheckpointError(String),

    #[error("Configuration error: {0}")]
    ConfigError(String),

    #[error("State error: {0}")]
    StateError(String),

    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Numerical error: {0}")]
    NumericalError(String),

    #[error("Memory map error: {0}")]
    MemoryMapError(String),
}

impl From<OptimizerError> for torsh_core::error::TorshError {
    fn from(err: OptimizerError) -> Self {
        match err {
            OptimizerError::TensorError(e) => e,
            OptimizerError::InvalidParameter(msg) => {
                torsh_core::error::TorshError::InvalidArgument(msg)
            }
            OptimizerError::SerializationError(msg) => {
                torsh_core::error::TorshError::SerializationError(msg)
            }
            OptimizerError::IoError(e) => torsh_core::error::TorshError::IoError(e.to_string()),
            OptimizerError::CheckpointError(msg) => {
                torsh_core::error::TorshError::RuntimeError(msg)
            }
            OptimizerError::ConfigError(msg) => torsh_core::error::TorshError::ConfigError(msg),
            OptimizerError::StateError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
            OptimizerError::InvalidInput(msg) => {
                torsh_core::error::TorshError::InvalidArgument(msg)
            }
            OptimizerError::NumericalError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
            OptimizerError::MemoryMapError(msg) => torsh_core::error::TorshError::RuntimeError(msg),
        }
    }
}

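// A minimal sketch (added for illustration, not part of the original tests)
// checking that optimizer-level errors map onto the torsh-core variants used
// in the `From` impl above.
#[cfg(test)]
mod error_conversion_tests {
    use super::*;

    #[test]
    fn invalid_parameter_maps_to_invalid_argument() {
        let err = OptimizerError::InvalidParameter("learning rate must be positive".to_string());
        // The conversion is defined by `impl From<OptimizerError> for TorshError`.
        let core_err: torsh_core::error::TorshError = err.into();
        assert!(matches!(
            core_err,
            torsh_core::error::TorshError::InvalidArgument(_)
        ));
    }
}
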
/// Convenience result alias for optimizer operations.
pub type OptimizerResult<T> = std::result::Result<T, OptimizerError>;

/// Crate version string, taken from Cargo metadata at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
pub const VERSION_MAJOR: u32 = 0;
pub const VERSION_MINOR: u32 = 1;
pub const VERSION_PATCH: u32 = 0;

/// Common interface implemented by every optimizer in this crate.
pub trait Optimizer {
    /// Perform a single optimization step using the gradients currently
    /// stored on the parameters.
    fn step(&mut self) -> OptimizerResult<()>;

    /// Clear the gradients of all managed parameters.
    fn zero_grad(&mut self);

    /// Return the current learning rate of each parameter group.
    fn get_lr(&self) -> Vec<f32>;

    /// Set the learning rate for all parameter groups.
    fn set_lr(&mut self, lr: f32);

    /// Add a new parameter group with per-group options.
    fn add_param_group(&mut self, params: Vec<Arc<RwLock<Tensor>>>, options: HashMap<String, f32>);

    /// Export the optimizer state for checkpointing.
    fn state_dict(&self) -> OptimizerResult<OptimizerState>;

    /// Restore the optimizer state from a previously exported state dict.
    fn load_state_dict(&mut self, state: OptimizerState) -> OptimizerResult<()>;
}

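// Illustrative sketch (not part of the crate's public API): one training step
// written purely against the `Optimizer` trait. It assumes gradients have
// already been populated on the parameters by a separate backward pass.
#[allow(dead_code)]
fn generic_training_step<O: Optimizer>(optimizer: &mut O) -> OptimizerResult<()> {
    // Apply the update rule using whatever gradients are stored on the params.
    optimizer.step()?;
    // Clear gradients so the next backward pass starts from zero.
    optimizer.zero_grad();
    Ok(())
}
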
/// Serializable snapshot of an optimizer's configuration and per-parameter state.
#[derive(Debug, Clone)]
pub struct OptimizerState {
    /// Name of the optimizer that produced this state (e.g. "sgd", "adam").
    pub optimizer_type: String,
    /// Version of the crate that produced this state.
    pub version: String,
    /// Per-group configuration: learning rate, options, and parameter count.
    pub param_groups: Vec<ParamGroupState>,
    /// Per-parameter state tensors, keyed by parameter id and state name.
    pub state: HashMap<String, HashMap<String, Tensor>>,
    /// Scalar state shared across the whole optimizer.
    pub global_state: HashMap<String, f32>,
}

/// Serializable snapshot of a single parameter group.
#[derive(Debug, Clone)]
pub struct ParamGroupState {
    /// Learning rate of the group.
    pub lr: f32,
    /// Additional per-group options (e.g. weight decay, eps).
    pub options: HashMap<String, f32>,
    /// Number of parameters in the group.
    pub param_count: usize,
}

impl OptimizerState {
    /// Create an empty state for the given optimizer type, stamped with the crate version.
    pub fn new(optimizer_type: String) -> Self {
        Self {
            optimizer_type,
            version: VERSION.to_string(),
            param_groups: Vec::new(),
            state: HashMap::new(),
            global_state: HashMap::new(),
        }
    }

    /// Basic sanity checks: non-empty optimizer type, finite positive learning
    /// rates, and non-empty state keys.
    pub fn validate(&self) -> Result<()> {
        if self.optimizer_type.is_empty() {
            return Err(TorshError::InvalidArgument(
                "Optimizer type cannot be empty".to_string(),
            ));
        }

        for (i, group) in self.param_groups.iter().enumerate() {
            if !group.lr.is_finite() || group.lr <= 0.0 {
                return Err(TorshError::InvalidArgument(format!(
                    "Invalid learning rate in group {}",
                    i
                )));
            }
        }

        for (param_id, param_state) in &self.state {
            for (state_name, _tensor) in param_state {
                if param_id.is_empty() || state_name.is_empty() {
                    return Err(TorshError::InvalidArgument(
                        "State keys cannot be empty".to_string(),
                    ));
                }
            }
        }

        Ok(())
    }

    /// Total number of parameters across all groups.
    pub fn total_param_count(&self) -> usize {
        self.param_groups.iter().map(|g| g.param_count).sum()
    }

    /// Two states are compatible when they come from the same optimizer type
    /// and have matching per-group parameter counts.
    pub fn is_compatible_with(&self, other: &OptimizerState) -> bool {
        self.optimizer_type == other.optimizer_type
            && self.param_groups.len() == other.param_groups.len()
            && self
                .param_groups
                .iter()
                .zip(other.param_groups.iter())
                .all(|(a, b)| a.param_count == b.param_count)
    }
}

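// A minimal sketch (illustrative only) of how the state snapshot above is
// expected to behave: a well-formed state validates, and compatibility is
// judged on optimizer type plus per-group parameter counts.
#[cfg(test)]
mod optimizer_state_tests {
    use super::*;

    #[test]
    fn state_validation_and_compatibility() {
        let mut a = OptimizerState::new("sgd".to_string());
        a.param_groups.push(ParamGroupState::new(0.01, 3));
        assert!(a.validate().is_ok());
        assert_eq!(a.total_param_count(), 3);

        // Same optimizer type and matching param counts => compatible,
        // even though the learning rates differ.
        let mut b = OptimizerState::new("sgd".to_string());
        b.param_groups.push(ParamGroupState::new(0.1, 3));
        assert!(a.is_compatible_with(&b));

        // A non-positive learning rate must be rejected by validation.
        b.param_groups[0].lr = 0.0;
        assert!(b.validate().is_err());
    }
}
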
impl ParamGroupState {
    pub fn new(lr: f32, param_count: usize) -> Self {
        Self {
            lr,
            options: HashMap::new(),
            param_count,
        }
    }

    pub fn from_param_group(group: &ParamGroup) -> Self {
        Self {
            lr: group.lr,
            options: group.options.clone(),
            param_count: group.params.len(),
        }
    }

    pub fn get_option(&self, key: &str, default: f32) -> f32 {
        self.options.get(key).copied().unwrap_or(default)
    }

    pub fn set_option(&mut self, key: String, value: f32) {
        self.options.insert(key, value);
    }
}

/// A group of parameters that share a learning rate and per-group options.
#[derive(Debug, Clone)]
pub struct ParamGroup {
    pub params: Vec<Arc<RwLock<Tensor>>>,
    pub lr: f32,
    pub options: HashMap<String, f32>,
}

/// Builder for constructing a [`ParamGroup`] fluently.
#[derive(Debug)]
pub struct ParamGroupBuilder {
    params: Vec<Arc<RwLock<Tensor>>>,
    lr: f32,
    options: HashMap<String, f32>,
}

impl ParamGroupBuilder {
    pub fn new(lr: f32) -> Self {
        Self {
            params: Vec::new(),
            lr,
            options: HashMap::new(),
        }
    }

    pub fn params(mut self, params: Vec<Arc<RwLock<Tensor>>>) -> Self {
        self.params = params;
        self
    }

    pub fn add_param(mut self, param: Arc<RwLock<Tensor>>) -> Self {
        self.params.push(param);
        self
    }

    pub fn weight_decay(mut self, weight_decay: f32) -> Self {
        self.options
            .insert("weight_decay".to_string(), weight_decay);
        self
    }

    pub fn eps(mut self, eps: f32) -> Self {
        self.options.insert("eps".to_string(), eps);
        self
    }

    pub fn option(mut self, key: String, value: f32) -> Self {
        self.options.insert(key, value);
        self
    }

    /// Copy the learning rate and options from an [`OptimizerOptions`] value.
    /// The learning rate is stored separately on the group, so the redundant
    /// "lr" entry is dropped from the options map.
    pub fn from_options(mut self, options: &OptimizerOptions) -> Self {
        self.lr = options.lr;
        self.options = options.to_hashmap();
        self.options.remove("lr");
        self
    }

    pub fn build(self) -> ParamGroup {
        ParamGroup {
            params: self.params,
            lr: self.lr,
            options: self.options,
        }
    }
}

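// A minimal sketch (illustrative only) of the builder: the learning rate and
// options are carried through to the built group. Parameters are omitted to
// keep the example tensor-free; "momentum" is just an arbitrary option key.
#[cfg(test)]
mod param_group_builder_tests {
    use super::*;

    #[test]
    fn builder_carries_lr_and_options() {
        let group = ParamGroupBuilder::new(0.05)
            .weight_decay(1e-4)
            .eps(1e-8)
            .option("momentum".to_string(), 0.9)
            .build();

        assert_eq!(group.lr, 0.05);
        assert_eq!(group.get_option("weight_decay", 0.0), 1e-4);
        assert_eq!(group.get_option("eps", 0.0), 1e-8);
        assert_eq!(group.get_option("momentum", 0.0), 0.9);
        assert!(group.is_empty());
    }
}
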
impl ParamGroup {
    pub fn new(params: Vec<Arc<RwLock<Tensor>>>, lr: f32) -> Self {
        Self {
            params,
            lr,
            options: HashMap::new(),
        }
    }

    pub fn with_options(mut self, options: HashMap<String, f32>) -> Self {
        self.options = options;
        self
    }

    pub fn add_param(&mut self, param: Arc<RwLock<Tensor>>) {
        self.params.push(param);
    }

    pub fn get_option(&self, key: &str, default: f32) -> f32 {
        self.options.get(key).copied().unwrap_or(default)
    }

    pub fn set_option(&mut self, key: String, value: f32) {
        self.options.insert(key, value);
    }

    pub fn param_count(&self) -> usize {
        self.params.len()
    }

    pub fn is_empty(&self) -> bool {
        self.params.is_empty()
    }

    /// Return only the parameters that currently have gradients attached.
    pub fn params_with_grads(&self) -> Vec<&Arc<RwLock<Tensor>>> {
        self.params
            .iter()
            .filter(|param| param.read().has_grad())
            .collect()
    }

    /// A group is valid when it has at least one parameter and a finite,
    /// positive learning rate.
    pub fn validate(&self) -> bool {
        !self.params.is_empty() && self.lr.is_finite() && self.lr > 0.0
    }

    /// Count how many parameters share each shape.
    pub fn get_shape_counts(&self) -> HashMap<Vec<usize>, usize> {
        let mut shape_counts = HashMap::new();
        for param in &self.params {
            let shape = param.read().shape().dims().to_vec();
            *shape_counts.entry(shape).or_insert(0) += 1;
        }
        shape_counts
    }

    /// Total number of scalar elements across all parameters in the group.
    pub fn total_param_count(&self) -> usize {
        self.params.iter().map(|param| param.read().numel()).sum()
    }

    pub fn zero_grad(&self) {
        for param in &self.params {
            param.write().zero_grad();
        }
    }

    pub fn has_any_grads(&self) -> bool {
        self.params.iter().any(|param| param.read().has_grad())
    }

    /// Global L2 norm of all gradients in the group, i.e.
    /// sqrt(sum over parameters of ||grad||^2).
    pub fn grad_norm(&self) -> Result<f32> {
        let mut total_norm_sq = 0.0f32;

        for param in &self.params {
            let param_guard = param.read();
            if let Some(grad) = param_guard.grad() {
                let grad_norm = grad.norm().map_err(|e| {
                    TorshError::Other(format!("Failed to compute gradient norm: {}", e))
                })?;
                let norm_value = grad_norm.to_vec().map_err(|e| {
                    TorshError::Other(format!("Failed to extract norm value: {}", e))
                })?[0];
                total_norm_sq += norm_value * norm_value;
            }
        }

        Ok(total_norm_sq.sqrt())
    }

    /// Clip gradients by their global norm: if the total norm exceeds
    /// `max_norm`, every gradient is scaled by `max_norm / total_norm`.
    /// Returns the pre-clipping total norm.
    pub fn clip_grads(&self, max_norm: f32) -> Result<f32> {
        let total_norm = self.grad_norm()?;

        if total_norm > max_norm {
            let scale = max_norm / total_norm;
            for param in &self.params {
                let mut param_guard = param.write();
                if let Some(grad) = param_guard.grad() {
                    let clipped_grad = grad.mul_scalar(scale).map_err(|e| {
                        TorshError::Other(format!("Failed to clip gradient: {}", e))
                    })?;
                    param_guard.set_grad(Some(clipped_grad));
                }
            }
        }

        Ok(total_norm)
    }
}

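// A minimal sketch (illustrative only) of the option and validation behaviour
// above: option lookups fall back to the caller-supplied default, and a group
// with no parameters never validates, even with a valid learning rate.
#[cfg(test)]
mod param_group_option_tests {
    use super::*;

    #[test]
    fn options_and_validation() {
        let mut group = ParamGroup::new(vec![], 0.01);

        assert_eq!(group.get_option("weight_decay", 0.0), 0.0);
        group.set_option("weight_decay".to_string(), 1e-2);
        assert_eq!(group.get_option("weight_decay", 0.0), 1e-2);

        assert_eq!(group.param_count(), 0);
        assert!(!group.validate());
    }
}
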
/// Common hyperparameters shared by most optimizers.
#[derive(Debug, Clone)]
pub struct OptimizerOptions {
    pub lr: f32,
    pub weight_decay: f32,
    pub eps: f32,
    pub maximize: bool,
}

impl Default for OptimizerOptions {
    fn default() -> Self {
        Self {
            lr: 1e-3,
            weight_decay: 0.0,
            eps: 1e-8,
            maximize: false,
        }
    }
}

impl OptimizerOptions {
    pub fn new(lr: f32) -> Self {
        Self {
            lr,
            ..Default::default()
        }
    }

    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    pub fn with_eps(mut self, eps: f32) -> Self {
        self.eps = eps;
        self
    }

    pub fn with_maximize(mut self, maximize: bool) -> Self {
        self.maximize = maximize;
        self
    }

    /// Convert the options to a flat map; `maximize` is encoded as 1.0 / 0.0.
    pub fn to_hashmap(&self) -> HashMap<String, f32> {
        let mut map = HashMap::new();
        map.insert("lr".to_string(), self.lr);
        map.insert("weight_decay".to_string(), self.weight_decay);
        map.insert("eps".to_string(), self.eps);
        map.insert(
            "maximize".to_string(),
            if self.maximize { 1.0 } else { 0.0 },
        );
        map
    }

    /// Rebuild options from a flat map, falling back to defaults for missing keys.
    pub fn from_hashmap(map: &HashMap<String, f32>) -> Self {
        Self {
            lr: map.get("lr").copied().unwrap_or(1e-3),
            weight_decay: map.get("weight_decay").copied().unwrap_or(0.0),
            eps: map.get("eps").copied().unwrap_or(1e-8),
            maximize: map.get("maximize").copied().unwrap_or(0.0) > 0.0,
        }
    }

    pub fn validate(&self) -> Result<()> {
        if !self.lr.is_finite() || self.lr <= 0.0 {
            return Err(TorshError::InvalidArgument(
                "Learning rate must be positive and finite".to_string(),
            ));
        }
        if !self.weight_decay.is_finite() || self.weight_decay < 0.0 {
            return Err(TorshError::InvalidArgument(
                "Weight decay must be non-negative and finite".to_string(),
            ));
        }
        if !self.eps.is_finite() || self.eps <= 0.0 {
            return Err(TorshError::InvalidArgument(
                "Epsilon must be positive and finite".to_string(),
            ));
        }
        Ok(())
    }

    /// Assemble a standard [`OptimizerState`] from parameter groups and
    /// per-parameter state, stamping it with the given optimizer type.
    pub fn create_standard_state_dict(
        optimizer_type: &str,
        version: Option<&str>,
        param_groups: &[ParamGroup],
        state: &HashMap<String, HashMap<String, Tensor>>,
        global_state: Option<HashMap<String, f32>>,
    ) -> OptimizerState {
        let param_group_states = param_groups
            .iter()
            .map(ParamGroupState::from_param_group)
            .collect();

        OptimizerState {
            optimizer_type: optimizer_type.to_string(),
            version: version.unwrap_or("1.0").to_string(),
            param_groups: param_group_states,
            state: state.clone(),
            global_state: global_state.unwrap_or_default(),
        }
    }

    /// Check that a loaded state matches the current parameter groups in both
    /// group count and per-group parameter count.
    pub fn validate_state_compatibility(
        current_groups: &[ParamGroup],
        state_groups: &[ParamGroupState],
    ) -> Result<()> {
        if current_groups.len() != state_groups.len() {
            return Err(TorshError::InvalidArgument(format!(
                "Parameter group count mismatch: expected {}, got {}",
                current_groups.len(),
                state_groups.len()
            )));
        }

        for (i, (current_group, state_group)) in
            current_groups.iter().zip(state_groups.iter()).enumerate()
        {
            if current_group.params.len() != state_group.param_count {
                return Err(TorshError::InvalidArgument(format!(
                    "Parameter count mismatch in group {}: expected {}, got {}",
                    i,
                    current_group.params.len(),
                    state_group.param_count
                )));
            }
        }

        Ok(())
    }
}

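// A minimal sketch (illustrative only): `to_hashmap`/`from_hashmap` should
// round-trip the option set, including the 1.0/0.0 encoding of `maximize`.
#[cfg(test)]
mod optimizer_options_tests {
    use super::*;

    #[test]
    fn options_hashmap_round_trip() {
        let opts = OptimizerOptions::new(1e-2)
            .with_weight_decay(1e-4)
            .with_eps(1e-6)
            .with_maximize(true);

        let map = opts.to_hashmap();
        let restored = OptimizerOptions::from_hashmap(&map);

        assert_eq!(restored.lr, 1e-2);
        assert_eq!(restored.weight_decay, 1e-4);
        assert_eq!(restored.eps, 1e-6);
        assert!(restored.maximize);
        assert!(restored.validate().is_ok());
    }
}
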
/// Reusable convergence checks for concrete optimizer implementations.
#[cfg(test)]
pub mod convergence_tests {
    use super::*;
    use parking_lot::RwLock;
    use std::ops::Add;
    use std::sync::Arc;
    use torsh_tensor::{
        creation::{randn, zeros},
        Tensor,
    };

    /// Minimize f(x, y) = x^2 + y^2 starting from (2, 2); succeeds once the
    /// loss drops below `tolerance` within `max_iterations` steps.
    pub fn test_quadratic_convergence<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        tolerance: f32,
        max_iterations: usize,
    ) -> Result<()> {
        let x = Arc::new(RwLock::new(Tensor::scalar(2.0)?));
        let y = Arc::new(RwLock::new(Tensor::scalar(2.0)?));
        let params = vec![x.clone(), y.clone()];

        let mut optimizer = create_optimizer(params);

        for _ in 0..max_iterations {
            // Analytic gradients of f(x, y) = x^2 + y^2: df/dx = 2x, df/dy = 2y.
            {
                let x_val = x.read().clone();
                let y_val = y.read().clone();

                let x_grad = x_val.mul_scalar(2.0)?;
                let y_grad = y_val.mul_scalar(2.0)?;

                x.write().set_grad(Some(x_grad));
                y.write().set_grad(Some(y_grad));
            }

            optimizer
                .step()
                .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;

            let x_val = x.read().to_vec()?[0];
            let y_val = y.read().to_vec()?[0];
            let loss = x_val * x_val + y_val * y_val;

            if loss < tolerance {
                return Ok(());
            }

            optimizer.zero_grad();
        }

        Err(TorshError::Other(format!(
            "Failed to converge within {} iterations",
            max_iterations
        )))
    }

    /// Fit y = 2x + 1 (plus noise) with a single weight and bias; succeeds
    /// once the MSE loss drops below `tolerance` and the learned parameters
    /// are close to the true ones.
    pub fn test_linear_regression_convergence<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        tolerance: f32,
        max_iterations: usize,
    ) -> Result<()> {
        let true_weight = 2.0;
        let true_bias = 1.0;

        let n_samples = 100;
        let x_data = randn::<f32>(&[n_samples, 1])?;
        let noise = randn::<f32>(&[n_samples, 1])?.mul_scalar(0.1)?;
        let y_data = x_data
            .mul_scalar(true_weight)?
            .add_scalar(true_bias)?
            .add(&noise)?;

        let weight = Arc::new(RwLock::new(zeros(&[1, 1])?));
        let bias = Arc::new(RwLock::new(zeros(&[1])?));
        let params = vec![weight.clone(), bias.clone()];

        let mut optimizer = create_optimizer(params);

        for _ in 0..max_iterations {
            let w_val = weight.read().clone();
            let b_val = bias.read().clone();

            let y_pred = x_data.matmul(&w_val)?.add(&b_val)?;

            let diff = y_pred.sub(&y_data)?;
            let loss_tensor = diff.pow(2.0)?.mean(Some(&[0]), false)?;
            let loss = loss_tensor.to_vec()?[0];

            // Analytic MSE gradients: dL/dw = (2/n) X^T (y_pred - y),
            // dL/db = (2/n) sum(y_pred - y).
            let grad_scale = 2.0 / n_samples as f32;
            let weight_grad = x_data
                .transpose(0, 1)?
                .matmul(&diff)?
                .mul_scalar(grad_scale)?;
            let bias_grad = diff.sum()?.mul_scalar(grad_scale)?;

            weight.write().set_grad(Some(weight_grad));
            bias.write().set_grad(Some(bias_grad));

            optimizer
                .step()
                .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;

            if loss < tolerance {
                let learned_weight = weight.read().to_vec()?[0];
                let learned_bias = bias.read().to_vec()?[0];

                if (learned_weight - true_weight).abs() < 0.1
                    && (learned_bias - true_bias).abs() < 0.1
                {
                    return Ok(());
                }
            }

            optimizer.zero_grad();
        }

        Err(TorshError::Other(format!(
            "Failed to converge within {} iterations",
            max_iterations
        )))
    }

    /// Run the same short optimization `n_runs` times and check that the final
    /// parameter values agree to within `tolerance`.
    pub fn test_optimizer_consistency<O: Optimizer>(
        create_optimizer: impl Fn(Vec<Arc<RwLock<Tensor>>>) -> O,
        n_runs: usize,
        tolerance: f32,
    ) -> Result<()> {
        let mut final_values = Vec::new();

        for _ in 0..n_runs {
            let param = Arc::new(RwLock::new(Tensor::scalar(1.0)?));
            let params = vec![param.clone()];
            let mut optimizer = create_optimizer(params);

            for _ in 0..10 {
                {
                    let param_val = param.read().clone();
                    let grad = param_val.mul_scalar(2.0)?;
                    param.write().set_grad(Some(grad));
                }

                optimizer
                    .step()
                    .map_err(|e| TorshError::Other(format!("Optimizer step failed: {}", e)))?;
                optimizer.zero_grad();
            }

            final_values.push(param.read().to_vec()?[0]);
        }

        let mean_value = final_values.iter().sum::<f32>() / final_values.len() as f32;
        for &value in &final_values {
            if (value - mean_value).abs() > tolerance {
                return Err(TorshError::Other(format!(
                    "Inconsistent optimizer behavior: values vary by more than {}",
                    tolerance
                )));
            }
        }

        Ok(())
    }
}

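// Sketch of how the helpers in `convergence_tests` are intended to be driven
// from a concrete optimizer's test suite. `SGD::new(params, lr)` is an assumed
// constructor signature and may not match the real `crate::sgd::SGD` API, so
// this remains a comment rather than a compiled test:
//
//     #[test]
//     fn sgd_converges_on_quadratic() {
//         convergence_tests::test_quadratic_convergence(
//             |params| SGD::new(params, 0.1),
//             1e-4,
//             1_000,
//         )
//         .unwrap();
//     }
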
/// Convenience re-exports of the most commonly used optimizers and utilities.
pub mod prelude {
    pub use crate::adabelief::AdaBelief;
    pub use crate::adabound::AdaBound;
    pub use crate::adadelta::AdaDelta;
    pub use crate::adagrad::AdaGrad;
    pub use crate::adahessian::{AdaHessian, AdaHessianBuilder};
    pub use crate::adam::{Adam, AdamW};
    pub use crate::adamax::AdaMax;
    pub use crate::asgd::ASGD;
    pub use crate::checkpointing::{
        Checkpoint, CheckpointConfig, CheckpointManager, CheckpointMetadata, CheckpointStatistics,
        CheckpointSupport, CheckpointingOptimizer,
    };
    pub use crate::composition::{
        CombinationMethod, ComposedOptimizer, CompositionBuilder, CompositionStrategy,
        OptimizerMetrics, SwitchCriterion, VotingMethod,
    };
    pub use crate::debugging::{
        AnalysisReport, AnalyzerConfig, ConvergenceTracker, GradientFlowPoint, GradientStatistics,
        HyperparameterSensitivity, OptimizationRecommendation, OptimizationStep, OptimizerAnalyzer,
        ParameterStatistics, RecommendationCategory, SensitivityReport, SensitivityResult,
        Severity,
    };
    pub use crate::distributed::{
        utils as distributed_utils, AsyncConfig, AsyncSGD, CommunicationStats, DistributedBackend,
        DistributedConfig, DistributedOptimizer, ElasticAveragingSGD, SyncStrategy,
    };
    pub use crate::ftrl::{FTRLBuilder, FTRL};
    pub use crate::fused_kernels::{
        fused_adadelta_step, fused_adagrad_step, fused_adam_step, fused_rmsprop_step,
        fused_sgd_step, FusedKernelSupport, FusedStats,
    };
    pub use crate::grad_accumulation::{
        with_gradient_accumulation, AccumulatingOptimizer, GradientAccumulationSupport,
        GradientAccumulator,
    };
    pub use crate::kfac::{KFACBuilder, KFAC};
    pub use crate::lamb::LAMB;
    pub use crate::lazy_updates::{
        LazyUpdateConfig, LazyUpdateDecision, LazyUpdateManager, LazyUpdateOptimizer,
        LazyUpdateStatistics, LazyUpdateSupport, ParameterImportance, PendingUpdate,
        UpdatePriority,
    };
    pub use crate::lbfgs::LBFGS;
    pub use crate::lion::{Lion, LionBuilder, LionConfig};
    pub use crate::lookahead::{lookahead_adam, lookahead_radam, lookahead_sgd, Lookahead};
    pub use crate::low_precision::{
        LowPrecisionConvertible, LowPrecisionOptimizer, LowPrecisionState, PrecisionType,
        StateStatistics,
    };
    pub use crate::lr_scheduler::{
        CosineAnnealingLR, ExponentialLR, LRScheduler, OneCycleLR, ReduceLROnPlateau, StepLR,
    };
    pub use crate::lr_scheduler_additional::{
        ConstantLR, CosineAnnealingWarmRestarts, CyclicLR, LinearLR, MultiStepLR, PolynomialLR,
    };
    pub use crate::lr_scheduler_enhanced::{
        utils as lr_enhanced_utils, AdaptiveLRScheduler, AdaptiveSchedulerStats, AdaptiveStrategy,
        CosineAnnealingWarmRestartsWithWarmup, PolynomialDecayWithWarmup, WarmupStrategy,
    };
    pub use crate::memory_efficient::{
        CircularBuffer, MemoryConfig, MemoryEfficientAdam, MemoryEfficientLBFGS,
        MemoryEfficientOptimizerBuilder, MemoryPool,
    };
    pub use crate::memory_mapped::{
        MemoryMappedConfig, MemoryMappedFile, MemoryMappedOptimizer, MemoryMappedStateStorage,
        MemoryMappedSupport, StorageStatistics,
    };
    pub use crate::mixed_precision::{
        with_mixed_precision, MixedPrecisionConfig, MixedPrecisionOptimizer,
    };
    pub use crate::nadam::NAdam;
    pub use crate::natural_gradient::{NaturalGradient, NaturalGradientBuilder};
    pub use crate::newton_cg::{NewtonCG, NewtonCGBuilder, NewtonCGConfig};
    pub use crate::online_learning::{
        OnlineGradientDescent, ProximalGradient, ProximalOperator, SAGA, SVRG,
    };
    pub use crate::prodigy::{Prodigy, ProdigyBuilder, ProdigyConfig};
    pub use crate::radam::RAdam;
    pub use crate::ranger::{Ranger, RangerBuilder};
    pub use crate::rmsprop::RMSprop;
    pub use crate::rprop::Rprop;
    pub use crate::schedule_free::{ScheduleFreeAdamW, ScheduleFreeAdamWBuilder};
    pub use crate::sgd::SGD;
    pub use crate::shampoo::{Shampoo, ShampooBuilder};
    pub use crate::sophia::{Sophia, SophiaBuilder, SophiaConfig};
    pub use crate::sparse_adam::SparseAdam;
    pub use crate::state_dict_ops::{
        CompressionMethod, CompressionStats, MemoryEstimate, SerializationFormat, StateDictConfig,
        StateDictManager,
    };
    pub use crate::trust_region::{
        SubproblemSolver, TrustRegionBuilder, TrustRegionConfig, TrustRegionMethod,
        TrustRegionStrategy,
    };
    pub use crate::yellowfin::{YellowFin, YellowFinBuilder, YellowFinConfig};
    pub use crate::{Optimizer, OptimizerOptions, OptimizerState, ParamGroup, ParamGroupBuilder};
    pub use crate::{OptimizerError, OptimizerResult};
}

// The most commonly used optimizers are also re-exported at the crate root.
pub use adam::{Adam, AdamW};
pub use distributed::{DistributedBackend, DistributedConfig, DistributedOptimizer, SyncStrategy};
pub use rmsprop::RMSprop;
pub use sgd::SGD;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_param_group() {
        let params = vec![];
        let group = ParamGroup::new(params, 0.01);
        assert_eq!(group.lr, 0.01);
    }
}