tensorlogic_train/
lib.rs

//! Training scaffolds: loss wiring, schedules, callbacks.
//!
//! This crate provides comprehensive training infrastructure for Tensorlogic models:
//! - Loss functions (standard and logical constraint-based)
//! - Optimizer wrappers around SciRS2
//! - Training loops with callbacks
//! - Batch management
//! - Validation and metrics
//! - Regularization techniques
//! - Data augmentation
//! - Logging and monitoring
//! - Curriculum learning strategies
//! - Transfer learning utilities
//! - Hyperparameter optimization (grid search, random search)
//! - Cross-validation utilities
//! - Model ensembling
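//!
//! # Examples
//!
//! A minimal sketch of how these pieces compose into a training run. The
//! constructor signatures shown here (`MseLoss::new`, `StepLrScheduler::new`,
//! `Trainer::fit`, and so on) are illustrative assumptions, not the verbatim
//! API; consult each module's docs for the real signatures.
//!
//! ```ignore
//! use tensorlogic_train::{
//!     AdamOptimizer, EarlyStoppingCallback, MseLoss, OptimizerConfig,
//!     StepLrScheduler, Trainer, TrainerConfig,
//! };
//!
//! // Wire up the training stack: loss, optimizer, LR schedule, callback.
//! let loss = MseLoss::new();
//! let optimizer = AdamOptimizer::new(OptimizerConfig::default());
//! let scheduler = StepLrScheduler::new(10, 0.1); // decay LR by 0.1 every 10 epochs
//! let early_stop = EarlyStoppingCallback::new(5); // stop after 5 stale epochs
//!
//! // Drive the loop; `model`, `train_data`, and `val_data` come from the caller.
//! let mut trainer = Trainer::new(TrainerConfig::default());
//! trainer.fit(&mut model, &train_data, Some(&val_data))?;
//! ```
//!
//! Model selection follows the same pattern; here is a sketch of k-fold
//! cross-validation, again with assumed constructor and method shapes:
//!
//! ```ignore
//! use tensorlogic_train::KFold;
//!
//! let kfold = KFold::new(5); // five folds; `num_samples` comes from the caller
//! for (train_idx, val_idx) in kfold.split(num_samples) {
//!     // Train on `train_idx`, evaluate on `val_idx`, aggregate the metrics.
//! }
//! ```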

mod augmentation;
mod batch;
mod callbacks;
mod crossval;
mod curriculum;
mod ensemble;
mod error;
mod hyperparameter;
mod logging;
mod loss;
mod metrics;
mod model;
mod optimizer;
mod regularization;
mod scheduler;
mod trainer;
mod transfer;

pub use augmentation::{
    CompositeAugmenter, DataAugmenter, MixupAugmenter, NoAugmentation, NoiseAugmenter,
    RotationAugmenter, ScaleAugmenter,
};
pub use batch::{extract_batch, BatchConfig, BatchIterator, DataShuffler};
pub use callbacks::{
    BatchCallback, Callback, CallbackList, CheckpointCallback, EarlyStoppingCallback,
    EpochCallback, GradientAccumulationCallback, GradientMonitor, GradientSummary,
    HistogramCallback, HistogramStats, LearningRateFinder, ModelEMACallback, ProfilingCallback,
    ProfilingStats, ReduceLrOnPlateauCallback, SWACallback, TrainingCheckpoint, ValidationCallback,
};
pub use error::{TrainError, TrainResult};
pub use logging::{ConsoleLogger, FileLogger, LoggingBackend, MetricsLogger, TensorBoardLogger};
pub use loss::{
    BCEWithLogitsLoss, ConstraintViolationLoss, ContrastiveLoss, CrossEntropyLoss, DiceLoss,
    FocalLoss, HingeLoss, HuberLoss, KLDivergenceLoss, LogicalLoss, Loss, LossConfig, MseLoss,
    RuleSatisfactionLoss, TripletLoss, TverskyLoss,
};
pub use metrics::{
    Accuracy, BalancedAccuracy, CohensKappa, ConfusionMatrix, F1Score,
    MatthewsCorrelationCoefficient, Metric, MetricTracker, PerClassMetrics, Precision, Recall,
    RocCurve, TopKAccuracy,
};
pub use model::{AutodiffModel, DynamicModel, LinearModel, Model};
pub use optimizer::{
    AdaBeliefOptimizer, AdaMaxOptimizer, AdagradOptimizer, AdamOptimizer, AdamWOptimizer,
    GradClipMode, LambOptimizer, LarsOptimizer, LookaheadOptimizer, NAdamOptimizer, Optimizer,
    OptimizerConfig, RAdamOptimizer, RMSpropOptimizer, SamOptimizer, SgdOptimizer,
};
pub use regularization::{
    CompositeRegularization, ElasticNetRegularization, L1Regularization, L2Regularization,
    Regularizer,
};
pub use scheduler::{
    CosineAnnealingLrScheduler, CyclicLrMode, CyclicLrScheduler, ExponentialLrScheduler,
    LrScheduler, MultiStepLrScheduler, NoamScheduler, OneCycleLrScheduler, PlateauMode,
    PolynomialDecayLrScheduler, ReduceLROnPlateauScheduler, StepLrScheduler,
    WarmupCosineLrScheduler,
};
pub use trainer::{Trainer, TrainerConfig, TrainingHistory, TrainingState};

// Curriculum learning
pub use curriculum::{
    CompetenceCurriculum, CurriculumManager, CurriculumStrategy, ExponentialCurriculum,
    LinearCurriculum, SelfPacedCurriculum, TaskCurriculum,
};

// Transfer learning
pub use transfer::{
    DiscriminativeFineTuning, FeatureExtractorMode, LayerFreezingConfig, ProgressiveUnfreezing,
    TransferLearningManager,
};

// Hyperparameter optimization
pub use hyperparameter::{
    GridSearch, HyperparamConfig, HyperparamResult, HyperparamSpace, HyperparamValue, RandomSearch,
};

// Cross-validation
pub use crossval::{
    CrossValidationResults, CrossValidationSplit, KFold, LeaveOneOut, StratifiedKFold,
    TimeSeriesSplit,
};

// Ensembling
pub use ensemble::{
    AveragingEnsemble, BaggingHelper, Ensemble, StackingEnsemble, VotingEnsemble, VotingMode,
};