1pub mod adversarial;
25mod augmentation;
26mod batch;
27mod callbacks;
28pub mod checkpoint;
29mod crossval;
30mod curriculum;
31mod data;
32mod distillation;
33mod dropblock;
34pub mod early_stopping;
35mod ensemble;
36mod error;
37mod few_shot;
38mod gradient_accumulation;
39mod gradient_centralization;
40mod hyperparameter;
41mod label_smoothing;
42mod logging;
43pub mod lora;
44mod loss;
45mod lr_scheduler;
46mod memory;
47mod meta_learning;
48mod metrics;
49mod mixed_precision;
50mod model;
51mod multitask;
52pub mod nas;
53pub mod neural_ode;
54pub mod online_learning;
55mod optimizer;
56mod optimizers;
57mod pruning;
58mod quantization;
59mod regularization;
60mod sampling;
61mod scheduler;
62mod stochastic_depth;
63mod trainer;
64mod transfer;
65mod utils;
66pub mod weight_init;
67
68#[cfg(feature = "structured-logging")]
69pub mod structured_logging;
70
71pub use augmentation::{
72 center_crop_2d,
73 clip,
74 cutmix,
75 denormalize,
76 dropout,
77 dropout_mask,
78 gaussian_noise,
79 mixup,
80 normalize,
81 random_crop_2d,
82 random_hflip,
83 random_vflip,
84 AugRng,
86 AugStats,
87 AugmentationError,
88 AugmentationPipeline,
89 AugmentationStep,
90 CompositeAugmenter,
91 CutMixAugmenter,
92 CutOutAugmenter,
93 DataAugmenter,
94 MixupAugmenter,
95 NoAugmentation,
96 NoiseAugmenter,
97 RandomErasingAugmenter,
98 RotationAugmenter,
99 ScaleAugmenter,
100};
101pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
102pub use callbacks::{
103 BatchCallback, Callback, CallbackList, CheckpointCallback, CheckpointCompression,
104 EarlyStoppingCallback, EpochCallback, GradientAccumulationCallback, GradientAccumulationStats,
105 GradientMonitor, GradientScalingStrategy, GradientSummary, HistogramCallback, HistogramStats,
106 LearningRateFinder, ModelEMACallback, ProfilingCallback, ProfilingStats,
107 ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
108};
109pub use error::{TrainError, TrainResult};
110pub use logging::{
111 ConsoleLogger, CsvLogger, FileLogger, JsonlLogger, LoggingBackend, MetricsLogger,
112 TensorBoardLogger,
113};
114pub use loss::{
115 BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
116 FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
117 PolyLoss, RuleSatisfactionLoss, TripletLoss, TverskyLoss,
118};
119pub use lr_scheduler::{
120 CosineAnnealingScheduler, CyclicalScheduler, LrSchedulerV2,
121 OneCycleLrScheduler as OneCyclePolicyScheduler, SchedulerConfig, SchedulerError, SchedulerType,
122 StepDecayScheduler, WarmupScheduler,
123};
124pub use metrics::{
125 Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, DiceCoefficient,
126 ExpectedCalibrationError, F1Score, IoU, MatthewsCorrelationCoefficient,
127 MaximumCalibrationError, MeanAveragePrecision, MeanIoU, Metric, MetricTracker,
128 NormalizedDiscountedCumulativeGain, PerClassMetrics, Precision, Recall, RocCurve, TopKAccuracy,
129};
130pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
131pub use optimizer::{
132 AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamPOptimizer,
133 AdamWOptimizer, GradClipMode, LambOptimizer, LarsOptimizer, LionConfig, LionOptimizer,
134 LookaheadOptimizer, NAdamOptimizer, Optimizer, OptimizerConfig, ProdigyConfig,
135 ProdigyOptimizer, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, ScheduleFreeAdamW,
136 ScheduleFreeConfig, SgdOptimizer, SophiaConfig, SophiaOptimizer, SophiaVariant,
137};
138pub use regularization::{
139 CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
140 L2Regularization, MaxNormRegularization, OrthogonalRegularization, Regularizer,
141 SpectralNormalization,
142};
143pub use scheduler::{
144 CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
145 LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
146 PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, SgdrScheduler, StepLrScheduler,
147 WarmupCosineLrScheduler,
148};
149pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};
150
151pub use curriculum::{
153 CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
154 LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
155};
156
157pub use transfer::{
159 DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
160 TransferLearningManager,
161};
162
163pub use hyperparameter::{
165 AcquisitionFunction, BayesianOptimization, GaussianProcess, GpKernel, GridSearch,
166 HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
167};
168
169pub use crossval::{
171 CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
172 TimeSeriesSplit,
173};
174
175pub use ensemble::{
177 AveragingEnsemble, BaggingHelper, Ensemble, ModelSoup, SoupRecipe, StackingEnsemble,
178 VotingEnsemble, VotingMode,
179};
180
181pub use multitask::{MultiTaskLoss, PCGrad, TaskWeightingStrategy};
183
184pub use distillation::{AttentionTransferLoss, DistillationLoss, FeatureDistillationLoss};
186
187pub use label_smoothing::{LabelSmoothingLoss, MixupLoss};
189
190pub use memory::{
192 CheckpointStrategy, GradientCheckpointConfig, MemoryBudgetManager, MemoryEfficientTraining,
193 MemoryProfilerCallback, MemorySettings, MemoryStats,
194};
195
196pub use data::{
198 CsvLoader, DataPreprocessor, Dataset, LabelEncoder, OneHotEncoder, PreprocessingMethod,
199};
200
201pub use utils::{
203 compare_models, compute_gradient_stats, format_duration, print_gradient_report, GradientStats,
204 LrRangeTestAnalyzer, ModelSummary, ParameterDifference, ParameterStats, TimeEstimator,
205};
206
207pub use pruning::{
209 GlobalPruner, GradientPruner, LayerPruningStats, MagnitudePruner, Pruner, PruningConfig,
210 PruningMask, PruningStats, StructuredPruner, StructuredPruningAxis,
211};
212
213pub use sampling::{
215 BatchReweighter, ClassBalancedSampler, CurriculumSampler, FocalSampler, HardNegativeMiner,
216 ImportanceSampler, MiningStrategy, OnlineHardExampleMiner, ReweightingStrategy,
217};
218
219pub use quantization::{
221 BitWidth, DynamicRangeCalibrator, Granularity, QuantizationAwareTraining, QuantizationConfig,
222 QuantizationMode, QuantizationParams, QuantizedTensor, Quantizer,
223};
224
225pub use mixed_precision::{
227 AutocastContext, GradientScaler, LossScaler, MixedPrecisionStats, MixedPrecisionTrainer,
228 PrecisionMode,
229};
230
231pub use few_shot::{
233 DistanceMetric, EpisodeSampler, FewShotAccuracy, MatchingNetwork, PrototypicalDistance,
234 ShotType, SupportSet,
235};
236
237pub use meta_learning::{
239 MAMLConfig, MetaLearner, MetaStats, MetaTask, Reptile, ReptileConfig, MAML,
240};
241
242pub use gradient_accumulation::{
244 AccumulationConfig, AccumulationError, AccumulationStats, GradientAccumulator, GradientBuffer,
245};
246
247pub use gradient_centralization::{GcConfig, GcStats, GcStrategy, GradientCentralization};
249
250pub use stochastic_depth::{DropPath, ExponentialStochasticDepth, LinearStochasticDepth};
252
253pub use dropblock::{DropBlock, LinearDropBlockScheduler};
255
256pub use early_stopping::{
258 EarlyStoppingConfig, EarlyStoppingDecision, EarlyStoppingMonitor, MonitorMode,
259 MultiMetricMonitor, MultiMetricPolicy, PlateauDetector, TrainingProgress,
260};
261
262pub use checkpoint::{
264 deserialize_checkpoint, serialize_checkpoint, CheckpointError, CheckpointFormat,
265 CheckpointManager, CheckpointMetadata, LossTracker, OptimizerCheckpoint, ParamState,
266};
267
268pub use weight_init::{
270 compute_fans, constant_init, gain_for_activation, kaiming_normal, kaiming_uniform,
271 lecun_normal, lecun_uniform, normal_init, ones_init, orthogonal_init, uniform_init,
272 xavier_normal, xavier_uniform, zeros_init, FanMode, InitError, InitRng, InitStats,
273};
274
275pub use online_learning::{
277 online_evaluate, Ftrl, OGDLoss, OnlineError, OnlineGradientDescent, OnlineLearner, OnlineStats,
278 OnlineUpdateResult, PAVariant, PassiveAggressive, Perceptron,
279};
280
281pub use adversarial::{
283 adversarial_training_loss, fgsm, pgd, project_l1, project_l2, project_linf, robustness_eval,
284 AdversarialError, AdversarialExample, AdversarialTrainStats, AttackConfig, AttackLoss,
285 AttackModel, CrossEntropyAttackLoss, LinearAttackModel, MseAttackLoss, PerturbNorm,
286};
287
288pub use neural_ode::{
290 dopri5_solve, rk4_solve, AdaptiveSolution, AdjointResult, NeuralOde, OdeError, OdeFunc,
291 OdeSolution, OdeSolverConfig,
292};
293
294pub use lora::{LoraAdapter, LoraConfig, LoraError, LoraLayer};
296
297pub use nas::{
299 ArchSampler, ArchSearchSpace, Architecture, LayerSpec, NasResult, RandomArchSearch,
300 RegularizedEvolution,
301};