tensorlogic_train/lib.rs

//! Training scaffolds: loss wiring, schedules, callbacks.
//!
//! **Version**: 0.1.0 | **Status**: Production Ready
//!
//! This crate provides comprehensive training infrastructure for Tensorlogic models:
//! - Loss functions (standard and logical constraint-based)
//! - Optimizer wrappers around SciRS2
//! - Training loops with callbacks
//! - Batch management
//! - Validation and metrics
//! - Regularization techniques
//! - Data augmentation
//! - Logging and monitoring
//! - Curriculum learning strategies
//! - Transfer learning utilities
//! - Hyperparameter optimization (grid search, random search)
//! - Cross-validation utilities
//! - Model ensembling
//! - Model pruning and compression
//! - Model quantization (int8, int4, int2)
//! - Mixed precision training (FP16, BF16)
//! - Advanced sampling strategies
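//!
//! # Example
//!
//! A minimal sketch of wiring the pieces above together. The type names are
//! re-exported by this crate, but the constructor calls shown are assumptions
//! for illustration only, which is why the block is marked `ignore`.
//!
//! ```ignore
//! use tensorlogic_train::{MseLoss, SgdOptimizer, Trainer, TrainerConfig};
//!
//! // Assumed constructors: check the individual modules for the real signatures.
//! let loss = MseLoss::new();
//! let optimizer = SgdOptimizer::new(0.01);
//! let mut trainer = Trainer::new(TrainerConfig::default(), loss, optimizer);
//!
//! // Training is then driven through the trainer and its callback list.
//! ```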
pub mod adversarial;
mod augmentation;
mod batch;
mod callbacks;
pub mod checkpoint;
mod crossval;
mod curriculum;
mod data;
mod distillation;
mod dropblock;
pub mod early_stopping;
mod ensemble;
mod error;
mod few_shot;
mod gradient_accumulation;
mod gradient_centralization;
mod hyperparameter;
mod label_smoothing;
mod logging;
pub mod lora;
mod loss;
mod lr_scheduler;
mod memory;
mod meta_learning;
mod metrics;
mod mixed_precision;
mod model;
mod multitask;
pub mod neural_ode;
pub mod online_learning;
mod optimizer;
mod optimizers;
mod pruning;
mod quantization;
mod regularization;
mod sampling;
mod scheduler;
mod stochastic_depth;
mod trainer;
mod transfer;
mod utils;
pub mod weight_init;

#[cfg(feature = "structured-logging")]
pub mod structured_logging;

pub use augmentation::{
    center_crop_2d,
    clip,
    cutmix,
    denormalize,
    dropout,
    dropout_mask,
    gaussian_noise,
    mixup,
    normalize,
    random_crop_2d,
    random_hflip,
    random_vflip,
    // Functional API (v2)
    AugRng,
    AugStats,
    AugmentationError,
    AugmentationPipeline,
    AugmentationStep,
    CompositeAugmenter,
    CutMixAugmenter,
    CutOutAugmenter,
    DataAugmenter,
    MixupAugmenter,
    NoAugmentation,
    NoiseAugmenter,
    RandomErasingAugmenter,
    RotationAugmenter,
    ScaleAugmenter,
};
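// For reference, `mixup` and `cutmix` above follow the standard formulations: mixup blends
// inputs as x_mix = lambda * x_i + (1 - lambda) * x_j with labels mixed the same way and
// lambda typically drawn from Beta(alpha, alpha); cutmix pastes a rectangular patch from one
// sample into another and mixes labels by patch area. Exact argument names are not shown here.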
pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
pub use callbacks::{
    BatchCallback, Callback, CallbackList, CheckpointCallback, CheckpointCompression,
    EarlyStoppingCallback, EpochCallback, GradientAccumulationCallback, GradientAccumulationStats,
    GradientMonitor, GradientScalingStrategy, GradientSummary, HistogramCallback, HistogramStats,
    LearningRateFinder, ModelEMACallback, ProfilingCallback, ProfilingStats,
    ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
};
pub use error::{TrainError, TrainResult};
pub use logging::{
    ConsoleLogger, CsvLogger, FileLogger, JsonlLogger, LoggingBackend, MetricsLogger,
    TensorBoardLogger,
};
pub use loss::{
    BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
    FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
    PolyLoss, RuleSatisfactionLoss, TripletLoss, TverskyLoss,
};
pub use lr_scheduler::{
    CosineAnnealingScheduler, CyclicalScheduler, LrSchedulerV2,
    OneCycleLrScheduler as OneCyclePolicyScheduler, SchedulerConfig, SchedulerError, SchedulerType,
    StepDecayScheduler, WarmupScheduler,
};
pub use metrics::{
    Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, DiceCoefficient,
    ExpectedCalibrationError, F1Score, IoU, MatthewsCorrelationCoefficient,
    MaximumCalibrationError, MeanAveragePrecision, MeanIoU, Metric, MetricTracker,
    NormalizedDiscountedCumulativeGain, PerClassMetrics, Precision, Recall, RocCurve, TopKAccuracy,
};
pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
pub use optimizer::{
    AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamPOptimizer,
    AdamWOptimizer, GradClipMode, LambOptimizer, LarsOptimizer, LionConfig, LionOptimizer,
    LookaheadOptimizer, NAdamOptimizer, Optimizer, OptimizerConfig, ProdigyConfig,
    ProdigyOptimizer, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, ScheduleFreeAdamW,
    ScheduleFreeConfig, SgdOptimizer, SophiaConfig, SophiaOptimizer, SophiaVariant,
};
pub use regularization::{
    CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
    L2Regularization, MaxNormRegularization, OrthogonalRegularization, Regularizer,
    SpectralNormalization,
};
pub use scheduler::{
    CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
    LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
    PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, SgdrScheduler, StepLrScheduler,
    WarmupCosineLrScheduler,
};
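// Note on the cosine schedulers above: standard cosine annealing sets
// lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)) over a cycle of T steps,
// and SGDR adds periodic warm restarts of that cycle. Whether these types expose lr_min and
// lr_max under those exact names is not shown in this file.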
pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};

// Curriculum learning
pub use curriculum::{
    CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
    LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
};

// Transfer learning
pub use transfer::{
    DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
    TransferLearningManager,
};

// Hyperparameter optimization
pub use hyperparameter::{
    AcquisitionFunction, BayesianOptimization, GaussianProcess, GpKernel, GridSearch,
    HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
};

// Cross-validation
pub use crossval::{
    CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
    TimeSeriesSplit,
};

// Ensembling
pub use ensemble::{
    AveragingEnsemble, BaggingHelper, Ensemble, ModelSoup, SoupRecipe, StackingEnsemble,
    VotingEnsemble, VotingMode,
};

// Multi-task learning
pub use multitask::{MultiTaskLoss, PCGrad, TaskWeightingStrategy};

// Knowledge distillation
pub use distillation::{AttentionTransferLoss, DistillationLoss, FeatureDistillationLoss};

// Label smoothing
pub use label_smoothing::{LabelSmoothingLoss, MixupLoss};
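// For reference, label smoothing with factor eps over K classes replaces a one-hot target y
// with (1 - eps) * y + eps / K; a mixup loss combines the two targets with the same lambda
// used to mix the inputs. Exact parameter names on these types are not shown here.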

// Memory management and profiling
pub use memory::{
    CheckpointStrategy, GradientCheckpointConfig, MemoryBudgetManager, MemoryEfficientTraining,
    MemoryProfilerCallback, MemorySettings, MemoryStats,
};

// Data loading and preprocessing
pub use data::{
    CsvLoader, DataPreprocessor, Dataset, LabelEncoder, OneHotEncoder, PreprocessingMethod,
};

// Utilities for model introspection and analysis
pub use utils::{
    compare_models, compute_gradient_stats, format_duration, print_gradient_report, GradientStats,
    LrRangeTestAnalyzer, ModelSummary, ParameterDifference, ParameterStats, TimeEstimator,
};

// Model pruning and compression
pub use pruning::{
    GlobalPruner, GradientPruner, LayerPruningStats, MagnitudePruner, Pruner, PruningConfig,
    PruningMask, PruningStats, StructuredPruner, StructuredPruningAxis,
};

// Advanced sampling strategies
pub use sampling::{
    BatchReweighter, ClassBalancedSampler, CurriculumSampler, FocalSampler, HardNegativeMiner,
    ImportanceSampler, MiningStrategy, OnlineHardExampleMiner, ReweightingStrategy,
};

// Model quantization and compression
pub use quantization::{
    BitWidth, DynamicRangeCalibrator, Granularity, QuantizationAwareTraining, QuantizationConfig,
    QuantizationMode, QuantizationParams, QuantizedTensor, Quantizer,
};
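// For reference, affine quantization maps a real value x to q = round(x / scale) + zero_point
// and dequantizes as x_hat = (q - zero_point) * scale; scale and zero_point are the usual
// contents of quantization params, though the exact field names are not shown in this file.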

// Mixed precision training
pub use mixed_precision::{
    AutocastContext, GradientScaler, LossScaler, MixedPrecisionStats, MixedPrecisionTrainer,
    PrecisionMode,
};
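// For reference, mixed precision training relies on loss scaling: the loss is multiplied by a
// scale factor before the backward pass so small FP16 gradients do not underflow, and gradients
// are divided by the same factor before the optimizer step; the scaler types above cover that
// role, with exact APIs not shown here.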

// Few-shot learning
pub use few_shot::{
    DistanceMetric, EpisodeSampler, FewShotAccuracy, MatchingNetwork, PrototypicalDistance,
    ShotType, SupportSet,
};

// Meta-learning
pub use meta_learning::{
    MAMLConfig, MetaLearner, MetaStats, MetaTask, Reptile, ReptileConfig, MAML,
};
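// For reference, MAML adapts per-task parameters with an inner step
// theta_i = theta - alpha * grad(L_i(theta)) and then updates theta against the post-adaptation
// losses, while Reptile moves theta toward the inner-loop result: theta += eps * (theta_i - theta).
// Inner-step counts and learning rates live in the config types, not shown here.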

// Gradient accumulation and micro-batching
pub use gradient_accumulation::{
    AccumulationConfig, AccumulationError, AccumulationStats, GradientAccumulator, GradientBuffer,
};

// Gradient centralization
pub use gradient_centralization::{GcConfig, GcStats, GcStrategy, GradientCentralization};

// Stochastic Depth (DropPath)
pub use stochastic_depth::{DropPath, ExponentialStochasticDepth, LinearStochasticDepth};

// DropBlock regularization
pub use dropblock::{DropBlock, LinearDropBlockScheduler};

// Early stopping
pub use early_stopping::{
    EarlyStoppingConfig, EarlyStoppingDecision, EarlyStoppingMonitor, MonitorMode,
    MultiMetricMonitor, MultiMetricPolicy, PlateauDetector, TrainingProgress,
};

// Optimizer checkpointing
pub use checkpoint::{
    deserialize_checkpoint, serialize_checkpoint, CheckpointError, CheckpointFormat,
    CheckpointManager, CheckpointMetadata, LossTracker, OptimizerCheckpoint, ParamState,
};

// Weight initialization strategies
pub use weight_init::{
    compute_fans, constant_init, gain_for_activation, kaiming_normal, kaiming_uniform,
    lecun_normal, lecun_uniform, normal_init, ones_init, orthogonal_init, uniform_init,
    xavier_normal, xavier_uniform, zeros_init, FanMode, InitError, InitRng, InitStats,
};
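// For reference: xavier_uniform draws from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)),
// xavier_normal uses std = sqrt(2 / (fan_in + fan_out)), kaiming (He) init uses
// std = sqrt(2 / fan_in) for ReLU-family activations, and lecun init uses std = sqrt(1 / fan_in).
// These are the standard formulas; how gain_for_activation scales them is not shown here.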

// Online learning algorithms
pub use online_learning::{
    online_evaluate, Ftrl, OGDLoss, OnlineError, OnlineGradientDescent, OnlineLearner, OnlineStats,
    OnlineUpdateResult, PAVariant, PassiveAggressive, Perceptron,
};

// Adversarial training utilities
pub use adversarial::{
    adversarial_training_loss, fgsm, pgd, project_l1, project_l2, project_linf, robustness_eval,
    AdversarialError, AdversarialExample, AdversarialTrainStats, AttackConfig, AttackLoss,
    AttackModel, CrossEntropyAttackLoss, LinearAttackModel, MseAttackLoss, PerturbNorm,
};
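// For reference, FGSM perturbs an input as x_adv = x + eps * sign(grad_x L(x, y)), and PGD
// iterates smaller FGSM-style steps with a projection back onto the allowed ball (the
// project_l1 / project_l2 / project_linf helpers above match the usual norm balls).
// Exact signatures are not shown in this file.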

// Neural ODE — continuous-depth models with adjoint sensitivity
pub use neural_ode::{
    dopri5_solve, rk4_solve, AdaptiveSolution, AdjointResult, NeuralOde, OdeError, OdeFunc,
    OdeSolution, OdeSolverConfig,
};

// LoRA — low-rank adaptation for parameter-efficient fine-tuning
pub use lora::{LoraAdapter, LoraConfig, LoraError, LoraLayer};
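// For reference, LoRA freezes a weight matrix W and learns a low-rank update so the effective
// weight is W + (alpha / r) * B * A, with B of shape (d, r), A of shape (r, k) and rank
// r << min(d, k); only A and B are trained. How LoraConfig names alpha and r is not shown here.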