pub mod adversarial;
mod augmentation;
mod batch;
mod callbacks;
pub mod checkpoint;
mod crossval;
mod curriculum;
mod data;
mod distillation;
mod dropblock;
pub mod early_stopping;
mod ensemble;
mod error;
mod few_shot;
mod gradient_accumulation;
mod gradient_centralization;
mod hyperparameter;
mod label_smoothing;
mod logging;
pub mod lora;
mod loss;
mod lr_scheduler;
mod memory;
mod meta_learning;
mod metrics;
mod mixed_precision;
mod model;
mod multitask;
pub mod neural_ode;
pub mod online_learning;
mod optimizer;
mod optimizers;
mod pruning;
mod quantization;
mod regularization;
mod sampling;
mod scheduler;
mod stochastic_depth;
mod trainer;
mod transfer;
mod utils;
pub mod weight_init;

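/// Structured logging support, available only when the `structured-logging`
/// Cargo feature is enabled.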
#[cfg(feature = "structured-logging")]
pub mod structured_logging;

pub use augmentation::{
    center_crop_2d, clip, cutmix, denormalize, dropout, dropout_mask, gaussian_noise, mixup,
    normalize, random_crop_2d, random_hflip, random_vflip, AugRng, AugStats, AugmentationError,
    AugmentationPipeline, AugmentationStep, CompositeAugmenter, CutMixAugmenter, CutOutAugmenter,
    DataAugmenter, MixupAugmenter, NoAugmentation, NoiseAugmenter, RandomErasingAugmenter,
    RotationAugmenter, ScaleAugmenter,
};
pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
pub use callbacks::{
    BatchCallback, Callback, CallbackList, CheckpointCallback, CheckpointCompression,
    EarlyStoppingCallback, EpochCallback, GradientAccumulationCallback, GradientAccumulationStats,
    GradientMonitor, GradientScalingStrategy, GradientSummary, HistogramCallback, HistogramStats,
    LearningRateFinder, ModelEMACallback, ProfilingCallback, ProfilingStats,
    ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
};
pub use error::{TrainError, TrainResult};
pub use logging::{
    ConsoleLogger, CsvLogger, FileLogger, JsonlLogger, LoggingBackend, MetricsLogger,
    TensorBoardLogger,
};
pub use loss::{
    BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
    FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
    PolyLoss, RuleSatisfactionLoss, TripletLoss, TverskyLoss,
};
pub use lr_scheduler::{
    CosineAnnealingScheduler, CyclicalScheduler, LrSchedulerV2,
    OneCycleLrScheduler as OneCyclePolicyScheduler, SchedulerConfig, SchedulerError, SchedulerType,
    StepDecayScheduler, WarmupScheduler,
};
pub use metrics::{
    Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, DiceCoefficient,
    ExpectedCalibrationError, F1Score, IoU, MatthewsCorrelationCoefficient,
    MaximumCalibrationError, MeanAveragePrecision, MeanIoU, Metric, MetricTracker,
    NormalizedDiscountedCumulativeGain, PerClassMetrics, Precision, Recall, RocCurve, TopKAccuracy,
};
pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
pub use optimizer::{
    AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamPOptimizer,
    AdamWOptimizer, GradClipMode, LambOptimizer, LarsOptimizer, LionConfig, LionOptimizer,
    LookaheadOptimizer, NAdamOptimizer, Optimizer, OptimizerConfig, ProdigyConfig,
    ProdigyOptimizer, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, ScheduleFreeAdamW,
    ScheduleFreeConfig, SgdOptimizer, SophiaConfig, SophiaOptimizer, SophiaVariant,
};
pub use regularization::{
    CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
    L2Regularization, MaxNormRegularization, OrthogonalRegularization, Regularizer,
    SpectralNormalization,
};
pub use scheduler::{
    CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
    LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
    PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, SgdrScheduler, StepLrScheduler,
    WarmupCosineLrScheduler,
};
pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};
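// A purely illustrative sketch of how the core exports above could compose
// into a training loop. Every constructor and method name below is
// hypothetical; the real signatures live in the `trainer`, `model`,
// `optimizer`, and `loss` modules:
//
//     let mut model = LinearModel::new(input_dim, output_dim);
//     let mut optimizer = AdamOptimizer::default();
//     let loss = MseLoss::default();
//     let mut trainer = Trainer::new(TrainerConfig::default());
//     let history = trainer.fit(&mut model, &mut optimizer, &loss, &dataset)?;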

pub use curriculum::{
    CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
    LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
};

pub use transfer::{
    DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
    TransferLearningManager,
};

pub use hyperparameter::{
    AcquisitionFunction, BayesianOptimization, GaussianProcess, GpKernel, GridSearch,
    HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
};

pub use crossval::{
    CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
    TimeSeriesSplit,
};

pub use ensemble::{
    AveragingEnsemble, BaggingHelper, Ensemble, ModelSoup, SoupRecipe, StackingEnsemble,
    VotingEnsemble, VotingMode,
};

pub use multitask::{MultiTaskLoss, PCGrad, TaskWeightingStrategy};

pub use distillation::{AttentionTransferLoss, DistillationLoss, FeatureDistillationLoss};
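// DistillationLoss is presumably the standard Hinton et al. (2015) recipe,
// L = alpha * CE(student, labels)
//       + (1 - alpha) * T^2 * KL(softmax(teacher / T) || softmax(student / T)),
// but the exact weighting and temperature handling are defined in
// `distillation`, not asserted here.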

pub use label_smoothing::{LabelSmoothingLoss, MixupLoss};
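// Standard label smoothing replaces one-hot targets y with
// y_smooth = (1 - eps) * y + eps / K for K classes; whether LabelSmoothingLoss
// uses exactly this parameterization is defined in `label_smoothing`.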

pub use memory::{
    CheckpointStrategy, GradientCheckpointConfig, MemoryBudgetManager, MemoryEfficientTraining,
    MemoryProfilerCallback, MemorySettings, MemoryStats,
};

pub use data::{
    CsvLoader, DataPreprocessor, Dataset, LabelEncoder, OneHotEncoder, PreprocessingMethod,
};

pub use utils::{
    compare_models, compute_gradient_stats, format_duration, print_gradient_report, GradientStats,
    LrRangeTestAnalyzer, ModelSummary, ParameterDifference, ParameterStats, TimeEstimator,
};

pub use pruning::{
    GlobalPruner, GradientPruner, LayerPruningStats, MagnitudePruner, Pruner, PruningConfig,
    PruningMask, PruningStats, StructuredPruner, StructuredPruningAxis,
};

pub use sampling::{
    BatchReweighter, ClassBalancedSampler, CurriculumSampler, FocalSampler, HardNegativeMiner,
    ImportanceSampler, MiningStrategy, OnlineHardExampleMiner, ReweightingStrategy,
};

pub use quantization::{
    BitWidth, DynamicRangeCalibrator, Granularity, QuantizationAwareTraining, QuantizationConfig,
    QuantizationMode, QuantizationParams, QuantizedTensor, Quantizer,
};

pub use mixed_precision::{
    AutocastContext, GradientScaler, LossScaler, MixedPrecisionStats, MixedPrecisionTrainer,
    PrecisionMode,
};
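// The conventional mixed-precision recipe, which GradientScaler presumably
// mirrors: multiply the loss by a scale factor S before backprop so small
// half-precision gradients do not underflow, unscale by 1/S before the
// optimizer step, and on inf/NaN gradients skip the step and reduce S.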

pub use few_shot::{
    DistanceMetric, EpisodeSampler, FewShotAccuracy, MatchingNetwork, PrototypicalDistance,
    ShotType, SupportSet,
};

pub use meta_learning::{
    MAMLConfig, MetaLearner, MetaStats, MetaTask, Reptile, ReptileConfig, MAML,
};

pub use gradient_accumulation::{
    AccumulationConfig, AccumulationError, AccumulationStats, GradientAccumulator, GradientBuffer,
};
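// Gradient accumulation trades memory for effective batch size: micro-batches
// of size b accumulated over n steps give the optimizer an effective batch of
// b * n. The usual recipe divides each micro-batch loss by n (or averages the
// buffered gradients) so the update matches one large batch; see
// `gradient_accumulation` for the exact convention used here.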

pub use gradient_centralization::{GcConfig, GcStats, GcStrategy, GradientCentralization};

pub use stochastic_depth::{DropPath, ExponentialStochasticDepth, LinearStochasticDepth};

pub use dropblock::{DropBlock, LinearDropBlockScheduler};

pub use early_stopping::{
    EarlyStoppingConfig, EarlyStoppingDecision, EarlyStoppingMonitor, MonitorMode,
    MultiMetricMonitor, MultiMetricPolicy, PlateauDetector, TrainingProgress,
};

pub use checkpoint::{
    deserialize_checkpoint, serialize_checkpoint, CheckpointError, CheckpointFormat,
    CheckpointManager, CheckpointMetadata, LossTracker, OptimizerCheckpoint, ParamState,
};

pub use weight_init::{
    compute_fans, constant_init, gain_for_activation, kaiming_normal, kaiming_uniform,
    lecun_normal, lecun_uniform, normal_init, ones_init, orthogonal_init, uniform_init,
    xavier_normal, xavier_uniform, zeros_init, FanMode, InitError, InitRng, InitStats,
};
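// Standard initializer definitions, for reference (the implementations in
// `weight_init` are assumed to follow them): Xavier/Glorot draws with
// std = gain * sqrt(2 / (fan_in + fan_out)); Kaiming/He with
// std = gain / sqrt(fan), where the fan is selected by FanMode; LeCun with
// std = sqrt(1 / fan_in).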

pub use online_learning::{
    online_evaluate, Ftrl, OGDLoss, OnlineError, OnlineGradientDescent, OnlineLearner, OnlineStats,
    OnlineUpdateResult, PAVariant, PassiveAggressive, Perceptron,
};

pub use adversarial::{
    adversarial_training_loss, fgsm, pgd, project_l1, project_l2, project_linf, robustness_eval,
    AdversarialError, AdversarialExample, AdversarialTrainStats, AttackConfig, AttackLoss,
    AttackModel, CrossEntropyAttackLoss, LinearAttackModel, MseAttackLoss, PerturbNorm,
};
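// For reference, the textbook attacks these exports are named after:
// FGSM perturbs x' = x + eps * sign(grad_x L(x, y)); PGD iterates
// x_{t+1} = proj(x_t + alpha * sign(grad_x L(x_t, y))), projecting back into
// the eps-ball after each step (project_l1 / project_l2 / project_linf).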

pub use neural_ode::{
    dopri5_solve, rk4_solve, AdaptiveSolution, AdjointResult, NeuralOde, OdeError, OdeFunc,
    OdeSolution, OdeSolverConfig,
};

pub use lora::{LoraAdapter, LoraConfig, LoraError, LoraLayer};
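// LoRA (Hu et al., 2021) for reference: freeze the base weight W and learn a
// low-rank update, giving an effective weight W + (alpha / r) * B * A with
// A: r x k, B: d x r, and rank r << min(d, k). The exact scaling used by
// LoraLayer is defined in the `lora` module.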