trustformers_training/lib.rs

//! # TrustformeRS Training
//!
//! Training infrastructure and utilities for transformer models in Rust.
//!
//! This crate provides comprehensive tools for training transformer models, including:
//!
//! - **Distributed training**: Data parallelism, model parallelism, pipeline parallelism
//! - **Memory optimization**: Gradient checkpointing, mixed precision, ZeRO optimizer
//! - **Training strategies**: Few-shot learning, continual learning, curriculum learning
//! - **Stability**: Gradient clipping, loss scaling, NaN detection and recovery
//! - **Monitoring**: Real-time metrics, experiment tracking, resource monitoring
//!
//! ## Quick Start
//!
//! ```rust,ignore
//! use trustformers_training::{Trainer, TrainingConfig};
//!
//! // Configure training
//! let config = TrainingConfig {
//!     batch_size: 32,
//!     learning_rate: 1e-4,
//!     num_epochs: 10,
//!     gradient_accumulation_steps: 4,
//!     ..Default::default()
//! };
//!
//! // Create trainer (assumes `model` and `optimizer` were constructed earlier)
//! let mut trainer = Trainer::new(model, optimizer, config)?;
//!
//! // Train on the training set, evaluating on the validation set
//! trainer.train(train_dataset, val_dataset)?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
//!
//! ## Distributed Training
//!
//! ```rust,ignore
//! use trustformers_training::{
//!     distributed::{DistributedConfig, DistributedStrategy},
//!     Trainer,
//! };
//!
//! let dist_config = DistributedConfig {
//!     strategy: DistributedStrategy::DataParallel,
//!     world_size: 8,
//!     backend: "nccl",
//!     ..Default::default()
//! };
//!
//! // `model`, `optimizer`, and `config` as in the Quick Start example
//! let trainer = Trainer::new_distributed(model, optimizer, config, dist_config)?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
//!
//! ## Memory Optimization
//!
//! - **Gradient Checkpointing**: Trade compute for memory
//! - **Mixed Precision**: FP16/BF16 training with loss scaling
//! - **ZeRO**: Sharded optimizer states and gradients
//! - **Activation Checkpointing**: Recompute activations during the backward pass
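//!
//! A minimal sketch combining two of these techniques. The type names follow this
//! crate's re-exports, but the constructor calls (`::new`, `::default`) are
//! illustrative assumptions, not the exact API:
//!
//! ```rust,ignore
//! use trustformers_training::{
//!     AMPManager, MemoryOptimizationConfig, MemoryOptimizer, MixedPrecisionConfig,
//! };
//!
//! // Enable FP16/BF16 training with dynamic loss scaling (assumed defaults).
//! let amp = AMPManager::new(MixedPrecisionConfig::default())?;
//!
//! // Enable gradient checkpointing and CPU offload via the memory optimizer.
//! let mem_opt = MemoryOptimizer::new(MemoryOptimizationConfig::default())?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```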
//!
//! ## Features
//!
//! - `distributed`: Multi-GPU and multi-node training
//! - `mixed-precision`: FP16/BF16 training support
//! - `gradient-checkpointing`: Memory-efficient training
//! - `wandb`: Weights & Biases integration
//! - `tensorboard`: TensorBoard logging
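//!
//! For example, a downstream `Cargo.toml` might enable a subset of these features
//! (a sketch; this assumes the package is published as `trustformers-training`,
//! and a real version should be pinned):
//!
//! ```toml
//! [dependencies]
//! trustformers-training = { version = "*", features = ["distributed", "mixed-precision"] }
//! ```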

// Allow large error types in Result (TrustformersError is large by design)
#![allow(clippy::result_large_err)]
// Allow common patterns in training code
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]
#![allow(clippy::excessive_nesting)]
// Allow training-specific patterns
#![allow(clippy::await_holding_lock)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::empty_line_after_doc_comments)]
#![allow(clippy::manual_clamp)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::vec_init_then_push)]
#![allow(clippy::ptr_arg)]

pub mod adaptive_gradient_scaling;
pub mod adaptive_learning_rate;
pub mod advanced_stability_monitor;
pub mod auto_parallelism;
pub mod config_validation;
pub mod continual;
pub mod cost_tracking;
pub mod data_pipeline;
pub mod distributed;
pub mod elastic_training;
pub mod error_codes;
pub mod error_handling;
pub mod experiment_management;
pub mod expert_parallelism;
pub mod few_shot;
pub mod framework_integration;
pub mod gradient;
pub mod gradient_anomaly_recovery;
pub mod hyperopt;
pub mod losses;
pub mod memory_optimization;
pub mod metrics;
pub mod mixed_precision;
pub mod model_versioning;
pub mod multicloud;
pub mod nas_integration;
pub mod online_learning;
pub mod parallelism_3d;
pub mod qat;
pub mod resource_scheduling;
pub mod ring_attention;
pub mod rlhf;
pub mod sequence_parallelism;
pub mod simplified_trainer;
pub mod tensor_parallelism;
pub mod trainer;
pub mod training_args;
pub mod training_dynamics;
pub mod training_monitor;
pub mod training_orchestration;

pub use continual::{
    CatastrophicPreventionStrategy, ContinualLearningConfig, ContinualLearningManager, EWCConfig,
    EWCTrainer, ExperienceBuffer, FisherInformation, MemoryReplay, MemoryReplayConfig,
    ProgressiveConfig, ProgressiveNetwork, RegularizationMethod, TaskBoundaryDetector, TaskInfo,
    TaskModule, TaskTransition,
};
pub use distributed::{
    init_distributed_training, utils as distributed_utils, DataParallelTrainer, DistributedBackend,
    DistributedConfig, ProcessGroup,
};
pub use experiment_management::{
    ABTestConfig, ABTestResults, ABTestStatus, ArtifactType, DataLineage, DataSplit,
    EnvironmentInfo, ExperimentFilters, ExperimentManager, ExperimentMetadata, ExperimentReport,
    ExperimentResults, ExperimentStatus, GPUInfo, HardwareInfo, HyperparameterComparison,
    HyperparameterConfig, ModelArtifact, ModelLineage, ModelProvenance, ModelSizeInfo,
    ParameterChange, PipelineStep, QualityAssuranceStep, SystemInfo, TrainingPipeline,
};
pub use few_shot::{
    AdaptationConfig, CrossTaskGeneralizer, FewShotConfig, FewShotExample, FewShotMethod,
    GeneralizationConfig, ICLExample, InContextConfig, InContextLearner, MAMLConfig, MAMLTrainer,
    MetaLearningAlgorithm, PromptConfig, PromptTuner, ReptileConfig, ReptileTrainer, SoftPrompt,
    SupportSet, TaskAdapter, TaskDescriptor, TaskEmbedding,
};
pub use gradient::GradientUtils;
pub use hyperopt::{
    // Efficiency features
    AcquisitionFunction,
    AcquisitionFunctionType,
    AdvancedEarlyStoppingConfig,
    ArmGenerationStrategy,
    ArmStatistics,
    BanditAlgorithm,
    BanditConfig,
    BanditOptimizer,
    BayesianOptimization,
    CategoricalParameter,
    ContinuousParameter,
    Direction,
    DiscreteParameter,
    // Configuration
    EarlyStoppingConfig,
    EarlyStoppingStrategy,
    EvaluationJob,
    EvaluationResult,
    ExplorationStrategy,
    FaultToleranceConfig,
    GPSampler,
    GPUAllocation,
    GridSearch,
    HalvingStrategy,
    HyperParameter,
    Hyperband,
    // Core types
    HyperparameterTuner,
    JobStatus,
    KernelType,
    LoadBalancer,
    LogParameter,
    OptimizationDirection,
    OptimizationResult,
    // PBT (Population-based Training)
    PBTConfig,
    PBTMember,
    PBTStats,
    ParallelEvaluationConfig,
    ParallelEvaluator,
    ParallelStrategy,
    ParameterValue,
    PopulationBasedTraining,
    PriorityLevel,
    PruningConfig,
    PruningStrategy,
    RandomSampler,
    RandomSearch,
    ResourceAllocation,
    ResourceUsage,
    RewardFunction,
    // Samplers
    Sampler,
    SamplerConfig,
    // Search space
    SearchSpace,
    // Strategies
    SearchStrategy,
    StudyStatistics,
    SuccessiveHalving,
    SurrogateConfig,
    SurrogateModel,
    SurrogateModelType,
    SurrogateOptimizer,
    TPESampler,
    // Trials
    Trial,
    TrialHistory,
    TrialMetrics,
    TrialResult,
    TrialState,
    TunerConfig,
    WarmStartConfig,
    WarmStartDataSource,
    WarmStartStrategy,
};
pub use losses::{CrossEntropyLoss, Loss, MSELoss};
pub use metrics::{Accuracy, F1Score, Metric, MetricCollection, Perplexity};
pub use mixed_precision::{
    utils as mixed_precision_utils, AMPManager, AdvancedMixedPrecisionConfig,
    AdvancedMixedPrecisionManager, ComputeOptimizationManager, ComputeOptimizationReport,
    DynamicBatchingConfig, DynamicBatchingManager, DynamicBatchingReport, LayerScalingConfig,
    LossScaler, MixedPrecisionConfig, MixedPrecisionReport,
};
pub use qat::{
    fake_quantize, fake_quantize_mixed_bit, qat_loss, ActivationQuantizer, CalibrationDataset,
    LayerQuantConfig, MixedBitQATTrainer, MixedBitStrategy, QATConfig, QATConv2d, QATLinear,
    QATModel, QATTrainer, QuantStats, QuantizationGradients, QuantizationParams, QuantizedModel,
};
pub use rlhf::{
    ConstitutionalPrinciple, GenerationResult, HumanFeedback, PPOConfig, PPOStepResult, PPOTrainer,
    PolicyModel, PreferencePair, RLHFConfig, RLHFMetrics, RLHFPhase, RewardModel,
    RewardModelConfig, RewardPrediction, ValueModel,
};
pub use trainer::{EarlyStoppingCallback, LogEntry, Trainer, TrainerCallback, TrainingState};
pub use training_args::{EvaluationStrategy, SaveStrategy, TrainingArguments};
pub use training_dynamics::{
    ConvergenceMetrics, GradientFlowMetrics, LossLandscapeMetrics, TrainingDynamicsAnalyzer,
    TrainingDynamicsConfig, TrainingDynamicsReport, TrainingDynamicsSnapshot,
    WeightEvolutionMetrics,
};

// New module exports
pub use adaptive_gradient_scaling::{
    AdaptiveGradientScaler, AdaptiveGradientScalingConfig, AdaptiveScalingStatistics,
    GradientScalingResult, LayerGradientStats as AdaptiveLayerGradientStats, StabilityTrend,
};
pub use adaptive_learning_rate::{
    AdaptationStrategy as LRAdaptationStrategy, AdaptiveLRStatistics, AdaptiveLearningRateConfig,
    AdaptiveLearningRateScheduler, LearningRateUpdate, PerformanceTrend, SchedulerState,
    TrainingDynamics as LRTrainingDynamics,
};
pub use auto_parallelism::{
    utils as auto_parallelism_utils, ArchitectureType, AutoParallelismConfig,
    AutoParallelismSelector, DeviceType, EvaluationMethod, HardwareConstraints, ModelConstraints,
    NetworkTopology, OptimizationObjective, ParallelismStrategy, PerformanceRequirements,
    SelectionAlgorithm,
};
pub use data_pipeline::{
    ActiveLearningConfig, ActiveLearningIntegration, ActiveLearningManager, ActiveLearningStats,
    AdaptationStrategy, AdaptiveAugmentationConfig, AlignmentConfig, AlignmentMethod,
    AnnotationConfig, AnnotationSource, AudioAugmentationType, AugmentationScheduling,
    AugmentationStats, AugmentationStrategy, AugmentationStrategyType, BatchingConfig,
    BatchingStrategy, CacheType, CachingConfig,
    CompressionAlgorithm as DataPipelineCompressionAlgorithm, CoreSetMethod,
    CurriculumLearningConfig, CurriculumLearningManager, CurriculumScheduling,
    CurriculumSchedulingStrategy, CurriculumStage, CurriculumStats, CurriculumStrategy, DataFilter,
    DataPipeline, DataPipelineConfig, DataSample, DataSelectionCriteria, DataSource,
    DataSourceType, DataValidationConfig, DataValidator, DifficultyAssessment, DisagreementMeasure,
    DistributedProcessingConfig, DiversityConstraint, DiversityMeasure, DynamicAssessmentMethod,
    DynamicAugmentationConfig, DynamicAugmentationManager, ErrorHandling, EvictionPolicy,
    FeatureExtractionConfig, FeatureExtractionMethod, FilterType, FusionStrategy,
    ImageAugmentationType, LoadBalancingStrategy as DataPipelineLoadBalancingStrategy,
    MissingModalityHandling, Modality, ModalityProcessor, ModalityType, MultiModalConfig,
    MultiModalHandler, MultiModalPreprocessing, MultiModalStats, NormalizationConfig,
    NormalizationType, PacingFunction, PacingType, PreprocessingConfig, PreprocessingStep,
    PreprocessingStepType, ProcessingBackend, QualityAssessmentMethod, QualityControl,
    QueryStrategy, SamplingConfig, ScheduleType, ShuffleConfig, ShuffleStrategy, StreamingDataset,
    StreamingDatasetConfig, StreamingStats, SuccessCriteria, SynchronizationConfig,
    TextAugmentationType, TokenAugmentationType, UncertaintyMeasure, ValidationError,
    ValidationResult, ValidationRule, ValidationRuleType, ValidationSeverity, ValidationStats,
    ValidationStrategy, ValidationWarning, Validator,
};
pub use elastic_training::{
    ElasticTrainingConfig, ElasticTrainingCoordinator, ScalingDecision, ScalingType, SystemStatus,
    WorkerInfo, WorkerStatus,
};
pub use expert_parallelism::{
    utils as expert_parallelism_utils, ExpertAssignment, ExpertCommunicationPattern,
    ExpertParallelism, ExpertParallelismConfig, ExpertRoutingStrategy, LoadBalancingStats,
    LoadBalancingStrategy, TokenRouting,
};
pub use framework_integration::{
    AggregationFunction, ArtifactConfig, ArtifactInfo, AudioLoggingConfig, AutoConnectConfig,
    ChartType, ClearMLArtifactConfig, ClearMLConfig, ClearMLTaskType, ColorFormat,
    ConflictResolution, CustomArtifact, CustomMetric, CustomMonitoring, CustomScalar,
    ExperimentMetadata as FrameworkExperimentMetadata,
    ExperimentStatus as FrameworkExperimentStatus, ExperimentTracker, ExportConfig, ExportFormat,
    ExportFrequency, FrameworkIntegrationManager, GraphLoggingConfig, HistogramConfig,
    ImageLoggingConfig, IntegrationConfig, IntegrationType, MLflowAdvancedConfig, MLflowAuth,
    MLflowAuthType, MLflowConfig, MLflowTracker, MetricType, MetricValue, ModelRegistrationConfig,
    ModelStage, NeptuneConfig, NeptuneExperimentConfig, NeptuneMonitoringConfig,
    ParameterValue as FrameworkParameterValue, ProfilingConfig, ResumeConfig, ScalarLayout,
    SyncConfig, SyncFrequency, TensorBoardAdvancedConfig, TensorBoardConfig, TensorBoardTracker,
    UpdateFrequency, WandBAdvancedConfig, WandBConfig, WandBTracker, WatchModelConfig,
};
pub use memory_optimization::{
    CPUOffloadManager, GradientCheckpointWrapper, MemoryOptimizationConfig,
    MemoryOptimizationStats, MemoryOptimizer,
};
pub use multicloud::{
    AlertType, AuthConfig, AuthType, BudgetAlert, CloudProvider, CloudScheduler,
    CommunicationPattern, CompressionAlgorithm, CompressionConfig, CostConfig, CostEntry,
    CostOptimizationStrategy, InstanceType, MultiCloudConfig, MultiCloudOrchestrator,
    MultiCloudProcessGroup, NodeInfo, NodeStatus, OrchestrationStrategy,
    PerformanceMetrics as MultiCloudPerformanceMetrics, RecoveryStrategy, SchedulingAlgorithm,
};
pub use nas_integration::{
    Architecture, NASAlgorithm, NASConfig, NASController, Operation, PerformanceMetrics,
    SearchSpaceConfig, TargetPlatform,
};
pub use parallelism_3d::{
    AggregateParallelismStats, CommBackend, MemoryOptimization, Parallelism3D,
    Parallelism3DManager, Parallelism3DStats, ParallelismConfig, PipelineSchedule,
};
pub use sequence_parallelism::{
    utils as sequence_parallelism_utils, AttentionCommunication, SequenceChunk,
    SequenceCommunicationPattern, SequenceMemoryOptimization, SequenceParallelism,
    SequenceParallelismConfig, SequenceParallelismStats, SequenceSplittingStrategy,
};
pub use tensor_parallelism::{
    utils as tensor_parallelism_utils, CommunicationRequirement, TensorCommunicationPattern,
    TensorMemoryOptimization, TensorOperation, TensorOperationType, TensorParallelism,
    TensorParallelismConfig, TensorParallelismStatistics, TensorPartition,
    TensorPartitioningStrategy,
};
pub use training_monitor::{
    AnomalyReport, AnomalyType, HealthStatus, PerformanceStats, TrainingHealthStatus,
    TrainingMonitor, TrainingMonitorConfig, TrainingReport,
};

// Advanced stability monitoring exports
pub use advanced_stability_monitor::{
    AdvancedStabilityConfig, AdvancedStabilityMonitor, LossLandscapeAnalysis, PatternDetector,
    PredictedAnomalyType, PredictiveAnomaly, PreventiveAction, RiskLevel, StabilityReport,
    StabilityScore, TrainerParameters, TrainingDynamics, TrendDirection,
};

// Gradient anomaly recovery exports
pub use gradient_anomaly_recovery::{
    AdaptiveThresholds, GradientAnomaly, GradientAnomalyType, GradientRecoveryConfig,
    GradientRecoveryManager, GradientRecoveryStrategy, GradientSeverity, LayerGradientStats,
    RecoveryResult, RecoveryStatistics,
};

// Production feature exports
pub use cost_tracking::{
    AlertThreshold, BillingModel, Budget, BudgetFilters, BudgetPeriod, BudgetStatus, CostBreakdown,
    CostDataPoint, CostDriver, CostEntry as CostTrackingCostEntry, CostForecastingModel,
    CostRecommendation, CostReport, CostStatistics, CostTracker, CostTrend, EfficiencyMetrics,
    ForecastingAccuracy, ForecastingParameters, ImplementationEffort, NotificationType,
    RecommendationCategory, RecommendationPriority, ReportType, TimeRange,
};
pub use model_versioning::{
    ModelRegistry, ModelStatus, ModelVersion, ModelVersioningManager,
    PerformanceMetrics as ModelVersioningPerformanceMetrics, TrainingConfig, VersionComparison,
};
pub use online_learning::{
    ConceptDrift, DriftType, OnlineDataPoint, OnlineLearningConfig, OnlineLearningError,
    OnlineLearningManager, OnlineStatistics, PerformanceWindow,
};
pub use resource_scheduling::{
    AlertSeverity, AllocationStatus, CostAlert, CostOptimizationRecommendation, CostSnapshot,
    LocalityPreference, Priority, RecommendationType,
    ResourceAllocation as SchedulingResourceAllocation, ResourceConstraints, ResourcePool,
    ResourceRequest, ResourceScheduler, ResourceType,
    SchedulingAlgorithm as ResourceSchedulingAlgorithm, SchedulingStatistics, StorageSpeed,
};
pub use ring_attention::{
    utils as ring_attention_utils, ModelParams, RingAttentionBlock, RingAttentionConfig,
    RingAttentionManager, RingAttentionStats, RingCommunicationPattern, RingKVPair,
};
pub use training_orchestration::{
    CheckpointConfig, CheckpointInfo, EarlyStoppingConfig as OrchestrationEarlyStoppingConfig,
    JobEvent, JobPriority, JobScheduler, JobStatus as OrchestrationJobStatus, ModelConfig,
    OrchestrationStatistics, ResourceNode, ResourceRequirements, SchedulingStrategy, TrainingJob,
    TrainingJobConfig, TrainingMetrics, TrainingOrchestrator,
};

// API improvement exports
pub use config_validation::{
    ConfigSchema, ConfigValidator, Constraint, FieldSchema, FieldType, Severity, Validatable,
    ValidatedConfig, ValidationError as ConfigValidationError, ValidationReport,
    ValidationRule as ConfigValidationRule,
};
pub use error_codes::{
    get_error_info, get_recovery_actions, is_critical_error, ErrorCodeInfo, ErrorCodeRegistry,
};
pub use error_handling::{
    ErrorContext, ErrorManager, ErrorPattern, ErrorSeverity, ErrorStatistics, ErrorTrend,
    ErrorType, RecoveryAction, RecoveryStrategy as ErrorRecoveryStrategy, RecoverySuggestion,
    SystemInfo as ErrorSystemInfo, TrainingError, TrainingErrorExt, TrainingResult,
};
pub use simplified_trainer::{
    CheckpointCallback, EarlyStoppingMode, EpochResult, LogLevel, LoggingCallback, MetricsCallback,
    ProgressCallback, SimpleCallback, SimpleTrainer, SimpleTrainerBuilder, SimpleTrainingConfig,
    TrainingResults,
};