//! Model training: losses, metrics, optimizers, learning-rate schedulers, callbacks,
//! data loading/augmentation, and a configurable `Trainer`, alongside techniques such
//! as curriculum and transfer learning, hyperparameter search, cross-validation,
//! ensembling, distillation, pruning, quantization, mixed precision, and meta-learning.
//!
//! Structured logging support is gated behind the `structured-logging` feature.

mod augmentation;
mod batch;
mod callbacks;
mod crossval;
mod curriculum;
mod data;
mod distillation;
mod dropblock;
mod ensemble;
mod error;
mod few_shot;
mod gradient_centralization;
mod hyperparameter;
mod label_smoothing;
mod logging;
mod loss;
mod memory;
mod meta_learning;
mod metrics;
mod mixed_precision;
mod model;
mod multitask;
mod optimizer;
mod optimizers;
mod pruning;
mod quantization;
mod regularization;
mod sampling;
mod scheduler;
mod stochastic_depth;
mod trainer;
mod transfer;
mod utils;

#[cfg(feature = "structured-logging")]
pub mod structured_logging;

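// Core training stack: augmentation, batching, callbacks, logging, losses, metrics,
// models, optimizers, regularization, LR schedulers, and the trainer itself.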
pub use augmentation::{
    CompositeAugmenter, CutMixAugmenter, CutOutAugmenter, DataAugmenter, MixupAugmenter,
    NoAugmentation, NoiseAugmenter, RandomErasingAugmenter, RotationAugmenter, ScaleAugmenter,
};
pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
pub use callbacks::{
    BatchCallback, Callback, CallbackList, CheckpointCallback, CheckpointCompression,
    EarlyStoppingCallback, EpochCallback, GradientAccumulationCallback, GradientAccumulationStats,
    GradientMonitor, GradientScalingStrategy, GradientSummary, HistogramCallback, HistogramStats,
    LearningRateFinder, ModelEMACallback, ProfilingCallback, ProfilingStats,
    ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
};
pub use error::{TrainError, TrainResult};
pub use logging::{
    ConsoleLogger, CsvLogger, FileLogger, JsonlLogger, LoggingBackend, MetricsLogger,
    TensorBoardLogger,
};
pub use loss::{
    BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
    FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
    PolyLoss, RuleSatisfactionLoss, TripletLoss, TverskyLoss,
};
pub use metrics::{
    Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, DiceCoefficient,
    ExpectedCalibrationError, F1Score, IoU, MatthewsCorrelationCoefficient,
    MaximumCalibrationError, MeanAveragePrecision, MeanIoU, Metric, MetricTracker,
    NormalizedDiscountedCumulativeGain, PerClassMetrics, Precision, Recall, RocCurve, TopKAccuracy,
};
pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
pub use optimizer::{
    AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamPOptimizer,
    AdamWOptimizer, GradClipMode, LambOptimizer, LarsOptimizer, LionConfig, LionOptimizer,
    LookaheadOptimizer, NAdamOptimizer, Optimizer, OptimizerConfig, ProdigyConfig,
    ProdigyOptimizer, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, ScheduleFreeAdamW,
    ScheduleFreeConfig, SgdOptimizer, SophiaConfig, SophiaOptimizer, SophiaVariant,
};
pub use regularization::{
    CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
    L2Regularization, MaxNormRegularization, OrthogonalRegularization, Regularizer,
    SpectralNormalization,
};
pub use scheduler::{
    CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
    LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
    PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, SgdrScheduler, StepLrScheduler,
    WarmupCosineLrScheduler,
};
pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};

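// Curriculum and transfer learning.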
pub use curriculum::{
    CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
    LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
};

pub use transfer::{
    DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
    TransferLearningManager,
};

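// Hyperparameter search and cross-validation.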
pub use hyperparameter::{
    AcquisitionFunction, BayesianOptimization, GaussianProcess, GpKernel, GridSearch,
    HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
};

pub use crossval::{
    CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
    TimeSeriesSplit,
};

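// Ensembling, multi-task learning, knowledge distillation, and label smoothing.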
pub use ensemble::{
    AveragingEnsemble, BaggingHelper, Ensemble, ModelSoup, SoupRecipe, StackingEnsemble,
    VotingEnsemble, VotingMode,
};

pub use multitask::{MultiTaskLoss, PCGrad, TaskWeightingStrategy};

pub use distillation::{AttentionTransferLoss, DistillationLoss, FeatureDistillationLoss};

pub use label_smoothing::{LabelSmoothingLoss, MixupLoss};

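// Memory-efficient training, dataset loading/preprocessing, and general utilities.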
pub use memory::{
    CheckpointStrategy, GradientCheckpointConfig, MemoryBudgetManager, MemoryEfficientTraining,
    MemoryProfilerCallback, MemorySettings, MemoryStats,
};

pub use data::{
    CsvLoader, DataPreprocessor, Dataset, LabelEncoder, OneHotEncoder, PreprocessingMethod,
};

pub use utils::{
    compare_models, compute_gradient_stats, format_duration, print_gradient_report, GradientStats,
    LrRangeTestAnalyzer, ModelSummary, ParameterDifference, ParameterStats, TimeEstimator,
};

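// Pruning and example sampling / hard-example mining.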
pub use pruning::{
    GlobalPruner, GradientPruner, LayerPruningStats, MagnitudePruner, Pruner, PruningConfig,
    PruningMask, PruningStats, StructuredPruner, StructuredPruningAxis,
};

pub use sampling::{
    BatchReweighter, ClassBalancedSampler, CurriculumSampler, FocalSampler, HardNegativeMiner,
    ImportanceSampler, MiningStrategy, OnlineHardExampleMiner, ReweightingStrategy,
};

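// Quantization and mixed-precision training.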
pub use quantization::{
    BitWidth, DynamicRangeCalibrator, Granularity, QuantizationAwareTraining, QuantizationConfig,
    QuantizationMode, QuantizationParams, QuantizedTensor, Quantizer,
};

pub use mixed_precision::{
    AutocastContext, GradientScaler, LossScaler, MixedPrecisionStats, MixedPrecisionTrainer,
    PrecisionMode,
};

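// Few-shot and meta-learning.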
pub use few_shot::{
    DistanceMetric, EpisodeSampler, FewShotAccuracy, MatchingNetwork, PrototypicalDistance,
    ShotType, SupportSet,
};

pub use meta_learning::{
    MAMLConfig, MetaLearner, MetaStats, MetaTask, Reptile, ReptileConfig, MAML,
};

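// Gradient centralization, stochastic depth, and DropBlock regularization.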
pub use gradient_centralization::{GcConfig, GcStats, GcStrategy, GradientCentralization};

pub use stochastic_depth::{DropPath, ExponentialStochasticDepth, LinearStochasticDepth};

pub use dropblock::{DropBlock, LinearDropBlockScheduler};