Skip to main content

tensorlogic_train/
lib.rs

1//! Training scaffolds: loss wiring, schedules, callbacks.
2//!
3//! **Version**: 0.1.1 | **Status**: Production Ready
4//!
5//! This crate provides comprehensive training infrastructure for Tensorlogic models:
6//! - Loss functions (standard and logical constraint-based)
7//! - Optimizer wrappers around SciRS2
8//! - Training loops with callbacks
9//! - Batch management
10//! - Validation and metrics
11//! - Regularization techniques
12//! - Data augmentation
13//! - Logging and monitoring
14//! - Curriculum learning strategies
15//! - Transfer learning utilities
16//! - Hyperparameter optimization (grid search, random search, Bayesian optimization)
17//! - Cross-validation utilities
18//! - Model ensembling
19//! - Model pruning and compression
20//! - Model quantization (int8, int4, int2)
21//! - Mixed precision training (FP16, BF16)
22//! - Advanced sampling strategies
23
24pub mod adversarial;
25mod augmentation;
26mod batch;
27mod callbacks;
28pub mod checkpoint;
29mod crossval;
30mod curriculum;
31mod data;
32mod distillation;
33mod dropblock;
34pub mod early_stopping;
35mod ensemble;
36mod error;
37mod few_shot;
38mod gradient_accumulation;
39mod gradient_centralization;
40mod hyperparameter;
41mod label_smoothing;
42mod logging;
43pub mod lora;
44mod loss;
45mod lr_scheduler;
46mod memory;
47mod meta_learning;
48mod metrics;
49mod mixed_precision;
50mod model;
51mod multitask;
52pub mod nas;
53pub mod neural_ode;
54pub mod online_learning;
55mod optimizer;
56mod optimizers;
57mod pruning;
58mod quantization;
59mod regularization;
60mod sampling;
61mod scheduler;
62mod stochastic_depth;
63mod trainer;
64mod transfer;
65mod utils;
66pub mod weight_init;
67
68#[cfg(feature = "structured-logging")]
69pub mod structured_logging;
70
71pub use augmentation::{
72    center_crop_2d,
73    clip,
74    cutmix,
75    denormalize,
76    dropout,
77    dropout_mask,
78    gaussian_noise,
79    mixup,
80    normalize,
81    random_crop_2d,
82    random_hflip,
83    random_vflip,
84    // Functional API (v2)
85    AugRng,
86    AugStats,
87    AugmentationError,
88    AugmentationPipeline,
89    AugmentationStep,
90    CompositeAugmenter,
91    CutMixAugmenter,
92    CutOutAugmenter,
93    DataAugmenter,
94    MixupAugmenter,
95    NoAugmentation,
96    NoiseAugmenter,
97    RandomErasingAugmenter,
98    RotationAugmenter,
99    ScaleAugmenter,
100};
101pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
102pub use callbacks::{
103    BatchCallback, Callback, CallbackList, CheckpointCallback, CheckpointCompression,
104    EarlyStoppingCallback, EpochCallback, GradientAccumulationCallback, GradientAccumulationStats,
105    GradientMonitor, GradientScalingStrategy, GradientSummary, HistogramCallback, HistogramStats,
106    LearningRateFinder, ModelEMACallback, ProfilingCallback, ProfilingStats,
107    ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
108};
109pub use error::{TrainError, TrainResult};
110pub use logging::{
111    ConsoleLogger, CsvLogger, FileLogger, JsonlLogger, LoggingBackend, MetricsLogger,
112    TensorBoardLogger,
113};
114pub use loss::{
115    BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
116    FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
117    PolyLoss, RuleSatisfactionLoss, TripletLoss, TverskyLoss,
118};
119pub use lr_scheduler::{
120    CosineAnnealingScheduler, CyclicalScheduler, LrSchedulerV2,
121    OneCycleLrScheduler as OneCyclePolicyScheduler, SchedulerConfig, SchedulerError, SchedulerType,
122    StepDecayScheduler, WarmupScheduler,
123};
124pub use metrics::{
125    Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, DiceCoefficient,
126    ExpectedCalibrationError, F1Score, IoU, MatthewsCorrelationCoefficient,
127    MaximumCalibrationError, MeanAveragePrecision, MeanIoU, Metric, MetricTracker,
128    NormalizedDiscountedCumulativeGain, PerClassMetrics, Precision, Recall, RocCurve, TopKAccuracy,
129};
130pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
131pub use optimizer::{
132    AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamPOptimizer,
133    AdamWOptimizer, GradClipMode, LambOptimizer, LarsOptimizer, LionConfig, LionOptimizer,
134    LookaheadOptimizer, NAdamOptimizer, Optimizer, OptimizerConfig, ProdigyConfig,
135    ProdigyOptimizer, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, ScheduleFreeAdamW,
136    ScheduleFreeConfig, SgdOptimizer, SophiaConfig, SophiaOptimizer, SophiaVariant,
137};
138pub use regularization::{
139    CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
140    L2Regularization, MaxNormRegularization, OrthogonalRegularization, Regularizer,
141    SpectralNormalization,
142};
143pub use scheduler::{
144    CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
145    LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
146    PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, SgdrScheduler, StepLrScheduler,
147    WarmupCosineLrScheduler,
148};
149pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};
150
151// Curriculum learning
152pub use curriculum::{
153    CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
154    LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
155};
156
157// Transfer learning
158pub use transfer::{
159    DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
160    TransferLearningManager,
161};
162
163// Hyperparameter optimization
164pub use hyperparameter::{
165    AcquisitionFunction, BayesianOptimization, GaussianProcess, GpKernel, GridSearch,
166    HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
167};
168
169// Cross-validation
170pub use crossval::{
171    CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
172    TimeSeriesSplit,
173};
174
175// Ensembling
176pub use ensemble::{
177    AveragingEnsemble, BaggingHelper, Ensemble, ModelSoup, SoupRecipe, StackingEnsemble,
178    VotingEnsemble, VotingMode,
179};
180
181// Multi-task learning
182pub use multitask::{MultiTaskLoss, PCGrad, TaskWeightingStrategy};
183
184// Knowledge distillation
185pub use distillation::{AttentionTransferLoss, DistillationLoss, FeatureDistillationLoss};
186
187// Label smoothing
188pub use label_smoothing::{LabelSmoothingLoss, MixupLoss};
189
190// Memory management and profiling
191pub use memory::{
192    CheckpointStrategy, GradientCheckpointConfig, MemoryBudgetManager, MemoryEfficientTraining,
193    MemoryProfilerCallback, MemorySettings, MemoryStats,
194};
195
196// Data loading and preprocessing
197pub use data::{
198    CsvLoader, DataPreprocessor, Dataset, LabelEncoder, OneHotEncoder, PreprocessingMethod,
199};
200
201// Utilities for model introspection and analysis
202pub use utils::{
203    compare_models, compute_gradient_stats, format_duration, print_gradient_report, GradientStats,
204    LrRangeTestAnalyzer, ModelSummary, ParameterDifference, ParameterStats, TimeEstimator,
205};
206
207// Model pruning and compression
208pub use pruning::{
209    GlobalPruner, GradientPruner, LayerPruningStats, MagnitudePruner, Pruner, PruningConfig,
210    PruningMask, PruningStats, StructuredPruner, StructuredPruningAxis,
211};
212
213// Advanced sampling strategies
214pub use sampling::{
215    BatchReweighter, ClassBalancedSampler, CurriculumSampler, FocalSampler, HardNegativeMiner,
216    ImportanceSampler, MiningStrategy, OnlineHardExampleMiner, ReweightingStrategy,
217};
218
219// Model quantization and compression
220pub use quantization::{
221    BitWidth, DynamicRangeCalibrator, Granularity, QuantizationAwareTraining, QuantizationConfig,
222    QuantizationMode, QuantizationParams, QuantizedTensor, Quantizer,
223};
224
225// Mixed precision training
226pub use mixed_precision::{
227    AutocastContext, GradientScaler, LossScaler, MixedPrecisionStats, MixedPrecisionTrainer,
228    PrecisionMode,
229};
230
231// Few-shot learning
232pub use few_shot::{
233    DistanceMetric, EpisodeSampler, FewShotAccuracy, MatchingNetwork, PrototypicalDistance,
234    ShotType, SupportSet,
235};
236
237// Meta-learning
238pub use meta_learning::{
239    MAMLConfig, MetaLearner, MetaStats, MetaTask, Reptile, ReptileConfig, MAML,
240};
241
242// Gradient accumulation and micro-batching
243pub use gradient_accumulation::{
244    AccumulationConfig, AccumulationError, AccumulationStats, GradientAccumulator, GradientBuffer,
245};
246
247// Gradient centralization
248pub use gradient_centralization::{GcConfig, GcStats, GcStrategy, GradientCentralization};
249
250// Stochastic Depth (DropPath)
251pub use stochastic_depth::{DropPath, ExponentialStochasticDepth, LinearStochasticDepth};
252
253// DropBlock regularization
254pub use dropblock::{DropBlock, LinearDropBlockScheduler};
255
256// Early stopping
257pub use early_stopping::{
258    EarlyStoppingConfig, EarlyStoppingDecision, EarlyStoppingMonitor, MonitorMode,
259    MultiMetricMonitor, MultiMetricPolicy, PlateauDetector, TrainingProgress,
260};
261
262// Optimizer checkpointing
263pub use checkpoint::{
264    deserialize_checkpoint, serialize_checkpoint, CheckpointError, CheckpointFormat,
265    CheckpointManager, CheckpointMetadata, LossTracker, OptimizerCheckpoint, ParamState,
266};
267
268// Weight initialization strategies
269pub use weight_init::{
270    compute_fans, constant_init, gain_for_activation, kaiming_normal, kaiming_uniform,
271    lecun_normal, lecun_uniform, normal_init, ones_init, orthogonal_init, uniform_init,
272    xavier_normal, xavier_uniform, zeros_init, FanMode, InitError, InitRng, InitStats,
273};
274
275// Online learning algorithms
276pub use online_learning::{
277    online_evaluate, Ftrl, OGDLoss, OnlineError, OnlineGradientDescent, OnlineLearner, OnlineStats,
278    OnlineUpdateResult, PAVariant, PassiveAggressive, Perceptron,
279};
280
281// Adversarial training utilities
282pub use adversarial::{
283    adversarial_training_loss, fgsm, pgd, project_l1, project_l2, project_linf, robustness_eval,
284    AdversarialError, AdversarialExample, AdversarialTrainStats, AttackConfig, AttackLoss,
285    AttackModel, CrossEntropyAttackLoss, LinearAttackModel, MseAttackLoss, PerturbNorm,
286};
287
288// Neural ODE — continuous-depth models with adjoint sensitivity
289pub use neural_ode::{
290    dopri5_solve, rk4_solve, AdaptiveSolution, AdjointResult, NeuralOde, OdeError, OdeFunc,
291    OdeSolution, OdeSolverConfig,
292};
293
294// LoRA — low-rank adaptation for parameter-efficient fine-tuning
295pub use lora::{LoraAdapter, LoraConfig, LoraError, LoraLayer};
296
297// Neural Architecture Search
298pub use nas::{
299    ArchSampler, ArchSearchSpace, Architecture, LayerSpec, NasResult, RandomArchSearch,
300    RegularizedEvolution,
301};