
Crate trustformers_training

§TrustformeRS Training

Training infrastructure and utilities for transformer models in Rust.

This crate provides tools for training transformer models, including:

  • Distributed training: Data parallelism, model parallelism, pipeline parallelism
  • Memory optimization: Gradient checkpointing, mixed precision, ZeRO optimizer
  • Training strategies: Few-shot learning, continual learning, curriculum learning
  • Stability: Gradient clipping, loss scaling, NaN detection and recovery
  • Monitoring: Real-time metrics, experiment tracking, resource monitoring
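The Stability bullet above names gradient clipping and NaN detection. The following standalone sketch is plain Rust and does not use this crate's `GradientUtils` or `GradientRecoveryManager` API. It shows the usual global-norm rule (when the L2 norm over all gradients exceeds `max_norm`, every gradient is multiplied by `max_norm / norm`) together with a finiteness check that tells the caller to skip the step. Modeling gradients as `Vec<Vec<f32>>` (one vector per parameter tensor) and the names `clip_grad_norm` and `ClipOutcome` are assumptions for illustration.

```rust
/// Outcome of one clipping pass over a set of gradients (hypothetical type).
#[derive(Debug, PartialEq)]
pub enum ClipOutcome {
    /// Gradients were finite; holds the global L2 norm measured before clipping.
    Ok { pre_clip_norm: f32 },
    /// At least one gradient was NaN or infinite; the caller should skip the step.
    NonFinite,
}

/// Clips `grads` in place so that their global L2 norm is at most `max_norm`.
/// Non-finite gradients are detected first and left untouched.
pub fn clip_grad_norm(grads: &mut [Vec<f32>], max_norm: f32) -> ClipOutcome {
    // NaN/Inf detection: a single bad value invalidates the whole update.
    if grads.iter().flatten().any(|g| !g.is_finite()) {
        return ClipOutcome::NonFinite;
    }
    // The norm is taken across all parameter tensors at once, not per tensor.
    let norm = grads.iter().flatten().map(|g| g * g).sum::<f32>().sqrt();
    if norm > max_norm {
        // Small epsilon guards against division by a vanishing norm.
        let scale = max_norm / (norm + 1e-6);
        for g in grads.iter_mut().flatten() {
            *g *= scale;
        }
    }
    ClipOutcome::Ok { pre_clip_norm: norm }
}
```

Clipping by the global norm preserves the direction of the overall update and only shrinks its length, which is why it is usually preferred over clipping each value independently.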

§Quick Start

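```rust
use trustformers_training::{Trainer, TrainingConfig};

// Configure training. Note: the crate-root `TrainingConfig` is re-exported from
// `model_versioning`, and `TrainingArguments` (from `training_args`) is also
// exported, so check which config type `Trainer::new` takes in your version.
let config = TrainingConfig {
    batch_size: 32,
    learning_rate: 1e-4,
    num_epochs: 10,
    gradient_accumulation_steps: 4,
    ..Default::default()
};

// `model` and `optimizer` are assumed to be constructed beforehand.
let mut trainer = Trainer::new(model, optimizer, config)?;

// Train on `train_dataset`, evaluating on `val_dataset` (both built elsewhere).
trainer.train(train_dataset, val_dataset)?;
```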
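With `batch_size: 32` and `gradient_accumulation_steps: 4`, each optimizer update sees 32 × 4 = 128 samples, multiplied by the number of data-parallel workers if `batch_size` is per device (an assumption; the config above does not say). The standalone sketch below is plain Rust, not this crate's `Trainer` API, and shows what accumulation does: gradients from consecutive micro-batches are summed, averaged, and applied in a single update. `effective_batch_size` and `train_with_accumulation` are hypothetical helpers, and plain SGD stands in for the real optimizer.

```rust
/// Number of samples contributing to one optimizer update
/// (assumes `batch_size` is per device).
pub fn effective_batch_size(batch_size: usize, accumulation_steps: usize, world_size: usize) -> usize {
    batch_size * accumulation_steps * world_size
}

/// SGD with gradient accumulation: gradients from `accumulation_steps`
/// consecutive micro-batches are averaged, then one update is applied.
/// Returns the number of optimizer steps taken. A trailing partial window
/// is dropped for simplicity. Assumes every gradient has `params.len()` entries.
pub fn train_with_accumulation(
    params: &mut [f32],
    micro_batch_grads: &[Vec<f32>],
    accumulation_steps: usize,
    learning_rate: f32,
) -> usize {
    assert!(accumulation_steps > 0, "accumulation_steps must be positive");
    let mut acc = vec![0.0f32; params.len()];
    let mut optimizer_steps = 0;
    for (i, grads) in micro_batch_grads.iter().enumerate() {
        // Accumulate instead of updating after every micro-batch.
        for (a, g) in acc.iter_mut().zip(grads) {
            *a += g;
        }
        if (i + 1) % accumulation_steps == 0 {
            // Averaging makes the update match one large batch of the same size.
            for (p, a) in params.iter_mut().zip(acc.iter_mut()) {
                *p -= learning_rate * (*a / accumulation_steps as f32);
                *a = 0.0; // reset the buffer for the next window
            }
            optimizer_steps += 1;
        }
    }
    optimizer_steps
}
```

Accumulation lets a memory-limited device emulate a larger batch at the cost of fewer optimizer steps per epoch.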

§Distributed Training

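```rust
use trustformers_training::Trainer;
use trustformers_training::distributed::{DistributedConfig, DistributedStrategy};

// Data parallelism: each of the 8 ranks holds a full model replica and
// trains on its own shard of the data; gradients are averaged across ranks.
let dist_config = DistributedConfig {
    strategy: DistributedStrategy::DataParallel,
    world_size: 8,
    // NCCL is the usual backend for NVIDIA GPUs. The crate also re-exports a
    // `DistributedBackend` type, so this field may expect that instead of a string.
    backend: "nccl",
    ..Default::default()
};

// `model`, `optimizer` and `config` are as in the Quick Start; call
// `trainer.train(...)` afterwards in the same way.
let mut trainer = Trainer::new_distributed(model, optimizer, config, dist_config)?;
```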
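In data-parallel training, every rank computes gradients on its own data shard, and an all-reduce then replaces each rank's gradients with the average over all `world_size` ranks, so the replicas apply identical updates and stay in sync. The sketch below simulates that averaging in a single process; it is plain Rust and does not use this crate's `DataParallelTrainer` or `ProcessGroup`, whose methods are not documented here. Real backends such as NCCL use ring or tree algorithms to perform the same reduction across devices. `all_reduce_mean` is a hypothetical name.

```rust
/// Simulated all-reduce (sum) followed by division by `world_size`.
/// `per_rank_grads[r]` holds rank r's local gradient vector; on return,
/// every rank holds the same averaged gradients.
pub fn all_reduce_mean(per_rank_grads: &mut [Vec<f32>]) {
    let world_size = per_rank_grads.len();
    if world_size == 0 {
        return;
    }
    let len = per_rank_grads[0].len();
    assert!(
        per_rank_grads.iter().all(|g| g.len() == len),
        "all ranks must hold gradients of the same shape"
    );
    // Reduce: element-wise sum over ranks.
    let mut sum = vec![0.0f32; len];
    for grads in per_rank_grads.iter() {
        for (s, g) in sum.iter_mut().zip(grads) {
            *s += g;
        }
    }
    // Broadcast the mean back to every rank.
    for grads in per_rank_grads.iter_mut() {
        for (g, s) in grads.iter_mut().zip(&sum) {
            *g = s / world_size as f32;
        }
    }
}
```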

§Memory Optimization

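  • Gradient Checkpointing: Trade compute for memory by keeping only a subset of activations during the forward pass
  • Mixed Precision: FP16/BF16 training, with loss scaling to prevent FP16 gradient underflow
  • ZeRO: Shard optimizer states and gradients across data-parallel workers
  • Activation Checkpointing: Recompute the discarded activations during the backward pass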
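The loss scaling mentioned under Mixed Precision is usually dynamic: the loss is multiplied by a scale factor before the backward pass, gradients are divided by it afterwards, and the factor shrinks whenever an overflow appears and grows after a run of clean steps. The sketch below is plain Rust and is not this crate's `LossScaler` or `AMPManager`; the growth factor 2.0, backoff 0.5, and the field names are assumptions (commonly used values), and `f32` gradients stand in for FP16 ones.

```rust
/// Hypothetical dynamic loss scaler for FP16 training.
pub struct DynamicLossScaler {
    /// Current factor the loss is multiplied by before backward.
    pub scale: f32,
    growth_factor: f32,
    backoff_factor: f32,
    /// Number of consecutive overflow-free steps before the scale grows.
    growth_interval: u32,
    good_steps: u32,
}

impl DynamicLossScaler {
    pub fn new(initial_scale: f32, growth_interval: u32) -> Self {
        Self {
            scale: initial_scale,
            growth_factor: 2.0,
            backoff_factor: 0.5,
            growth_interval,
            good_steps: 0,
        }
    }

    /// Unscales `grads` in place. Returns `true` if the optimizer step should run,
    /// or `false` on overflow (the step is skipped and the scale is reduced).
    pub fn unscale_and_update(&mut self, grads: &mut [f32]) -> bool {
        if grads.iter().any(|g| !g.is_finite()) {
            self.scale *= self.backoff_factor;
            self.good_steps = 0;
            return false;
        }
        let inv = 1.0 / self.scale;
        for g in grads.iter_mut() {
            *g *= inv;
        }
        self.good_steps += 1;
        if self.good_steps == self.growth_interval {
            // A long run without overflow suggests the scale can be larger.
            self.scale *= self.growth_factor;
            self.good_steps = 0;
        }
        true
    }
}
```

Keeping the scale as large as possible without overflowing pushes small gradient values above the FP16 underflow threshold; BF16 has the same exponent range as FP32, which is why it typically trains without loss scaling.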

§Features

  • distributed: Multi-GPU and multi-node training
  • mixed-precision: FP16/BF16 training support
  • gradient-checkpointing: Memory-efficient training
  • wandb: Weights & Biases integration
  • tensorboard: TensorBoard logging

Re-exports§

pub use continual::CatastrophicPreventionStrategy;
pub use continual::ContinualLearningConfig;
pub use continual::ContinualLearningManager;
pub use continual::EWCConfig;
pub use continual::EWCTrainer;
pub use continual::ExperienceBuffer;
pub use continual::FisherInformation;
pub use continual::MemoryReplay;
pub use continual::MemoryReplayConfig;
pub use continual::ProgressiveConfig;
pub use continual::ProgressiveNetwork;
pub use continual::RegularizationMethod;
pub use continual::TaskBoundaryDetector;
pub use continual::TaskInfo;
pub use continual::TaskModule;
pub use continual::TaskTransition;
pub use distributed::init_distributed_training;
pub use distributed::utils as distributed_utils;
pub use distributed::DataParallelTrainer;
pub use distributed::DistributedBackend;
pub use distributed::DistributedConfig;
pub use distributed::ProcessGroup;
pub use experiment_management::ABTestConfig;
pub use experiment_management::ABTestResults;
pub use experiment_management::ABTestStatus;
pub use experiment_management::ArtifactType;
pub use experiment_management::DataLineage;
pub use experiment_management::DataSplit;
pub use experiment_management::EnvironmentInfo;
pub use experiment_management::ExperimentFilters;
pub use experiment_management::ExperimentManager;
pub use experiment_management::ExperimentMetadata;
pub use experiment_management::ExperimentReport;
pub use experiment_management::ExperimentResults;
pub use experiment_management::ExperimentStatus;
pub use experiment_management::GPUInfo;
pub use experiment_management::HardwareInfo;
pub use experiment_management::HyperparameterComparison;
pub use experiment_management::HyperparameterConfig;
pub use experiment_management::ModelArtifact;
pub use experiment_management::ModelLineage;
pub use experiment_management::ModelProvenance;
pub use experiment_management::ModelSizeInfo;
pub use experiment_management::ParameterChange;
pub use experiment_management::PipelineStep;
pub use experiment_management::QualityAssuranceStep;
pub use experiment_management::SystemInfo;
pub use experiment_management::TrainingPipeline;
pub use few_shot::AdaptationConfig;
pub use few_shot::CrossTaskGeneralizer;
pub use few_shot::FewShotConfig;
pub use few_shot::FewShotExample;
pub use few_shot::FewShotMethod;
pub use few_shot::GeneralizationConfig;
pub use few_shot::ICLExample;
pub use few_shot::InContextConfig;
pub use few_shot::InContextLearner;
pub use few_shot::MAMLConfig;
pub use few_shot::MAMLTrainer;
pub use few_shot::MetaLearningAlgorithm;
pub use few_shot::PromptConfig;
pub use few_shot::PromptTuner;
pub use few_shot::ReptileConfig;
pub use few_shot::ReptileTrainer;
pub use few_shot::SoftPrompt;
pub use few_shot::SupportSet;
pub use few_shot::TaskAdapter;
pub use few_shot::TaskDescriptor;
pub use few_shot::TaskEmbedding;
pub use gradient::GradientUtils;
pub use hyperopt::AcquisitionFunction;
pub use hyperopt::AcquisitionFunctionType;
pub use hyperopt::AdvancedEarlyStoppingConfig;
pub use hyperopt::ArmGenerationStrategy;
pub use hyperopt::ArmStatistics;
pub use hyperopt::BanditAlgorithm;
pub use hyperopt::BanditConfig;
pub use hyperopt::BanditOptimizer;
pub use hyperopt::BayesianOptimization;
pub use hyperopt::CategoricalParameter;
pub use hyperopt::ContinuousParameter;
pub use hyperopt::Direction;
pub use hyperopt::DiscreteParameter;
pub use hyperopt::EarlyStoppingConfig;
pub use hyperopt::EarlyStoppingStrategy;
pub use hyperopt::EvaluationJob;
pub use hyperopt::EvaluationResult;
pub use hyperopt::ExplorationStrategy;
pub use hyperopt::FaultToleranceConfig;
pub use hyperopt::GPSampler;
pub use hyperopt::GPUAllocation;
pub use hyperopt::GridSearch;
pub use hyperopt::HalvingStrategy;
pub use hyperopt::HyperParameter;
pub use hyperopt::Hyperband;
pub use hyperopt::HyperparameterTuner;
pub use hyperopt::JobStatus;
pub use hyperopt::KernelType;
pub use hyperopt::LoadBalancer;
pub use hyperopt::LogParameter;
pub use hyperopt::OptimizationDirection;
pub use hyperopt::OptimizationResult;
pub use hyperopt::PBTConfig;
pub use hyperopt::PBTMember;
pub use hyperopt::PBTStats;
pub use hyperopt::ParallelEvaluationConfig;
pub use hyperopt::ParallelEvaluator;
pub use hyperopt::ParallelStrategy;
pub use hyperopt::ParameterValue;
pub use hyperopt::PopulationBasedTraining;
pub use hyperopt::PriorityLevel;
pub use hyperopt::PruningConfig;
pub use hyperopt::PruningStrategy;
pub use hyperopt::RandomSampler;
pub use hyperopt::RandomSearch;
pub use hyperopt::ResourceAllocation;
pub use hyperopt::ResourceUsage;
pub use hyperopt::RewardFunction;
pub use hyperopt::Sampler;
pub use hyperopt::SamplerConfig;
pub use hyperopt::SearchSpace;
pub use hyperopt::SearchStrategy;
pub use hyperopt::StudyStatistics;
pub use hyperopt::SuccessiveHalving;
pub use hyperopt::SurrogateConfig;
pub use hyperopt::SurrogateModel;
pub use hyperopt::SurrogateModelType;
pub use hyperopt::SurrogateOptimizer;
pub use hyperopt::TPESampler;
pub use hyperopt::Trial;
pub use hyperopt::TrialHistory;
pub use hyperopt::TrialMetrics;
pub use hyperopt::TrialResult;
pub use hyperopt::TrialState;
pub use hyperopt::TunerConfig;
pub use hyperopt::WarmStartConfig;
pub use hyperopt::WarmStartDataSource;
pub use hyperopt::WarmStartStrategy;
pub use losses::CrossEntropyLoss;
pub use losses::Loss;
pub use losses::MSELoss;
pub use metrics::Accuracy;
pub use metrics::F1Score;
pub use metrics::Metric;
pub use metrics::MetricCollection;
pub use metrics::Perplexity;
pub use mixed_precision::utils as mixed_precision_utils;
pub use mixed_precision::AMPManager;
pub use mixed_precision::AdvancedMixedPrecisionConfig;
pub use mixed_precision::AdvancedMixedPrecisionManager;
pub use mixed_precision::ComputeOptimizationManager;
pub use mixed_precision::ComputeOptimizationReport;
pub use mixed_precision::DynamicBatchingConfig;
pub use mixed_precision::DynamicBatchingManager;
pub use mixed_precision::DynamicBatchingReport;
pub use mixed_precision::LayerScalingConfig;
pub use mixed_precision::LossScaler;
pub use mixed_precision::MixedPrecisionConfig;
pub use mixed_precision::MixedPrecisionReport;
pub use qat::fake_quantize;
pub use qat::fake_quantize_mixed_bit;
pub use qat::qat_loss;
pub use qat::ActivationQuantizer;
pub use qat::CalibrationDataset;
pub use qat::LayerQuantConfig;
pub use qat::MixedBitQATTrainer;
pub use qat::MixedBitStrategy;
pub use qat::QATConfig;
pub use qat::QATConv2d;
pub use qat::QATLinear;
pub use qat::QATModel;
pub use qat::QATTrainer;
pub use qat::QuantStats;
pub use qat::QuantizationGradients;
pub use qat::QuantizationParams;
pub use qat::QuantizedModel;
pub use rlhf::ConstitutionalPrinciple;
pub use rlhf::GenerationResult;
pub use rlhf::HumanFeedback;
pub use rlhf::PPOConfig;
pub use rlhf::PPOStepResult;
pub use rlhf::PPOTrainer;
pub use rlhf::PolicyModel;
pub use rlhf::PreferencePair;
pub use rlhf::RLHFConfig;
pub use rlhf::RLHFMetrics;
pub use rlhf::RLHFPhase;
pub use rlhf::RewardModel;
pub use rlhf::RewardModelConfig;
pub use rlhf::RewardPrediction;
pub use rlhf::ValueModel;
pub use trainer::EarlyStoppingCallback;
pub use trainer::LogEntry;
pub use trainer::Trainer;
pub use trainer::TrainerCallback;
pub use trainer::TrainingState;
pub use training_args::EvaluationStrategy;
pub use training_args::SaveStrategy;
pub use training_args::TrainingArguments;
pub use training_dynamics::ConvergenceMetrics;
pub use training_dynamics::GradientFlowMetrics;
pub use training_dynamics::LossLandscapeMetrics;
pub use training_dynamics::TrainingDynamicsAnalyzer;
pub use training_dynamics::TrainingDynamicsConfig;
pub use training_dynamics::TrainingDynamicsReport;
pub use training_dynamics::TrainingDynamicsSnapshot;
pub use training_dynamics::WeightEvolutionMetrics;
pub use adaptive_gradient_scaling::AdaptiveGradientScaler;
pub use adaptive_gradient_scaling::AdaptiveGradientScalingConfig;
pub use adaptive_gradient_scaling::AdaptiveScalingStatistics;
pub use adaptive_gradient_scaling::GradientScalingResult;
pub use adaptive_gradient_scaling::LayerGradientStats as AdaptiveLayerGradientStats;
pub use adaptive_gradient_scaling::StabilityTrend;
pub use adaptive_learning_rate::AdaptationStrategy as LRAdaptationStrategy;
pub use adaptive_learning_rate::AdaptiveLRStatistics;
pub use adaptive_learning_rate::AdaptiveLearningRateConfig;
pub use adaptive_learning_rate::AdaptiveLearningRateScheduler;
pub use adaptive_learning_rate::LearningRateUpdate;
pub use adaptive_learning_rate::PerformanceTrend;
pub use adaptive_learning_rate::SchedulerState;
pub use adaptive_learning_rate::TrainingDynamics as LRTrainingDynamics;
pub use auto_parallelism::utils as auto_parallelism_utils;
pub use auto_parallelism::ArchitectureType;
pub use auto_parallelism::AutoParallelismConfig;
pub use auto_parallelism::AutoParallelismSelector;
pub use auto_parallelism::DeviceType;
pub use auto_parallelism::EvaluationMethod;
pub use auto_parallelism::HardwareConstraints;
pub use auto_parallelism::ModelConstraints;
pub use auto_parallelism::NetworkTopology;
pub use auto_parallelism::OptimizationObjective;
pub use auto_parallelism::ParallelismStrategy;
pub use auto_parallelism::PerformanceRequirements;
pub use auto_parallelism::SelectionAlgorithm;
pub use data_pipeline::ActiveLearningConfig;
pub use data_pipeline::ActiveLearningIntegration;
pub use data_pipeline::ActiveLearningManager;
pub use data_pipeline::ActiveLearningStats;
pub use data_pipeline::AdaptationStrategy;
pub use data_pipeline::AdaptiveAugmentationConfig;
pub use data_pipeline::AlignmentConfig;
pub use data_pipeline::AlignmentMethod;
pub use data_pipeline::AnnotationConfig;
pub use data_pipeline::AnnotationSource;
pub use data_pipeline::AudioAugmentationType;
pub use data_pipeline::AugmentationScheduling;
pub use data_pipeline::AugmentationStats;
pub use data_pipeline::AugmentationStrategy;
pub use data_pipeline::AugmentationStrategyType;
pub use data_pipeline::BatchingConfig;
pub use data_pipeline::BatchingStrategy;
pub use data_pipeline::CacheType;
pub use data_pipeline::CachingConfig;
pub use data_pipeline::CompressionAlgorithm as DataPipelineCompressionAlgorithm;
pub use data_pipeline::CoreSetMethod;
pub use data_pipeline::CurriculumLearningConfig;
pub use data_pipeline::CurriculumLearningManager;
pub use data_pipeline::CurriculumScheduling;
pub use data_pipeline::CurriculumSchedulingStrategy;
pub use data_pipeline::CurriculumStage;
pub use data_pipeline::CurriculumStats;
pub use data_pipeline::CurriculumStrategy;
pub use data_pipeline::DataFilter;
pub use data_pipeline::DataPipeline;
pub use data_pipeline::DataPipelineConfig;
pub use data_pipeline::DataSample;
pub use data_pipeline::DataSelectionCriteria;
pub use data_pipeline::DataSource;
pub use data_pipeline::DataSourceType;
pub use data_pipeline::DataValidationConfig;
pub use data_pipeline::DataValidator;
pub use data_pipeline::DifficultyAssessment;
pub use data_pipeline::DisagreementMeasure;
pub use data_pipeline::DistributedProcessingConfig;
pub use data_pipeline::DiversityConstraint;
pub use data_pipeline::DiversityMeasure;
pub use data_pipeline::DynamicAssessmentMethod;
pub use data_pipeline::DynamicAugmentationConfig;
pub use data_pipeline::DynamicAugmentationManager;
pub use data_pipeline::ErrorHandling;
pub use data_pipeline::EvictionPolicy;
pub use data_pipeline::FeatureExtractionConfig;
pub use data_pipeline::FeatureExtractionMethod;
pub use data_pipeline::FilterType;
pub use data_pipeline::FusionStrategy;
pub use data_pipeline::ImageAugmentationType;
pub use data_pipeline::LoadBalancingStrategy as DataPipelineLoadBalancingStrategy;
pub use data_pipeline::MissingModalityHandling;
pub use data_pipeline::Modality;
pub use data_pipeline::ModalityProcessor;
pub use data_pipeline::ModalityType;
pub use data_pipeline::MultiModalConfig;
pub use data_pipeline::MultiModalHandler;
pub use data_pipeline::MultiModalPreprocessing;
pub use data_pipeline::MultiModalStats;
pub use data_pipeline::NormalizationConfig;
pub use data_pipeline::NormalizationType;
pub use data_pipeline::PacingFunction;
pub use data_pipeline::PacingType;
pub use data_pipeline::PreprocessingConfig;
pub use data_pipeline::PreprocessingStep;
pub use data_pipeline::PreprocessingStepType;
pub use data_pipeline::ProcessingBackend;
pub use data_pipeline::QualityAssessmentMethod;
pub use data_pipeline::QualityControl;
pub use data_pipeline::QueryStrategy;
pub use data_pipeline::SamplingConfig;
pub use data_pipeline::ScheduleType;
pub use data_pipeline::ShuffleConfig;
pub use data_pipeline::ShuffleStrategy;
pub use data_pipeline::StreamingDataset;
pub use data_pipeline::StreamingDatasetConfig;
pub use data_pipeline::StreamingStats;
pub use data_pipeline::SuccessCriteria;
pub use data_pipeline::SynchronizationConfig;
pub use data_pipeline::TextAugmentationType;
pub use data_pipeline::TokenAugmentationType;
pub use data_pipeline::UncertaintyMeasure;
pub use data_pipeline::ValidationError;
pub use data_pipeline::ValidationResult;
pub use data_pipeline::ValidationRule;
pub use data_pipeline::ValidationRuleType;
pub use data_pipeline::ValidationSeverity;
pub use data_pipeline::ValidationStats;
pub use data_pipeline::ValidationStrategy;
pub use data_pipeline::ValidationWarning;
pub use data_pipeline::Validator;
pub use elastic_training::ElasticTrainingConfig;
pub use elastic_training::ElasticTrainingCoordinator;
pub use elastic_training::ScalingDecision;
pub use elastic_training::ScalingType;
pub use elastic_training::SystemStatus;
pub use elastic_training::WorkerInfo;
pub use elastic_training::WorkerStatus;
pub use expert_parallelism::utils as expert_parallelism_utils;
pub use expert_parallelism::ExpertAssignment;
pub use expert_parallelism::ExpertCommunicationPattern;
pub use expert_parallelism::ExpertParallelism;
pub use expert_parallelism::ExpertParallelismConfig;
pub use expert_parallelism::ExpertRoutingStrategy;
pub use expert_parallelism::LoadBalancingStats;
pub use expert_parallelism::LoadBalancingStrategy;
pub use expert_parallelism::TokenRouting;
pub use framework_integration::AggregationFunction;
pub use framework_integration::ArtifactConfig;
pub use framework_integration::ArtifactInfo;
pub use framework_integration::AudioLoggingConfig;
pub use framework_integration::AutoConnectConfig;
pub use framework_integration::ChartType;
pub use framework_integration::ClearMLArtifactConfig;
pub use framework_integration::ClearMLConfig;
pub use framework_integration::ClearMLTaskType;
pub use framework_integration::ColorFormat;
pub use framework_integration::ConflictResolution;
pub use framework_integration::CustomArtifact;
pub use framework_integration::CustomMetric;
pub use framework_integration::CustomMonitoring;
pub use framework_integration::CustomScalar;
pub use framework_integration::ExperimentMetadata as FrameworkExperimentMetadata;
pub use framework_integration::ExperimentStatus as FrameworkExperimentStatus;
pub use framework_integration::ExperimentTracker;
pub use framework_integration::ExportConfig;
pub use framework_integration::ExportFormat;
pub use framework_integration::ExportFrequency;
pub use framework_integration::FrameworkIntegrationManager;
pub use framework_integration::GraphLoggingConfig;
pub use framework_integration::HistogramConfig;
pub use framework_integration::ImageLoggingConfig;
pub use framework_integration::IntegrationConfig;
pub use framework_integration::IntegrationType;
pub use framework_integration::MLflowAdvancedConfig;
pub use framework_integration::MLflowAuth;
pub use framework_integration::MLflowAuthType;
pub use framework_integration::MLflowConfig;
pub use framework_integration::MLflowTracker;
pub use framework_integration::MetricType;
pub use framework_integration::MetricValue;
pub use framework_integration::ModelRegistrationConfig;
pub use framework_integration::ModelStage;
pub use framework_integration::NeptuneConfig;
pub use framework_integration::NeptuneExperimentConfig;
pub use framework_integration::NeptuneMonitoringConfig;
pub use framework_integration::ParameterValue as FrameworkParameterValue;
pub use framework_integration::ProfilingConfig;
pub use framework_integration::ResumeConfig;
pub use framework_integration::ScalarLayout;
pub use framework_integration::SyncConfig;
pub use framework_integration::SyncFrequency;
pub use framework_integration::TensorBoardAdvancedConfig;
pub use framework_integration::TensorBoardConfig;
pub use framework_integration::TensorBoardTracker;
pub use framework_integration::UpdateFrequency;
pub use framework_integration::WandBAdvancedConfig;
pub use framework_integration::WandBConfig;
pub use framework_integration::WandBTracker;
pub use framework_integration::WatchModelConfig;
pub use memory_optimization::CPUOffloadManager;
pub use memory_optimization::GradientCheckpointWrapper;
pub use memory_optimization::MemoryOptimizationConfig;
pub use memory_optimization::MemoryOptimizationStats;
pub use memory_optimization::MemoryOptimizer;
pub use multicloud::AlertType;
pub use multicloud::AuthConfig;
pub use multicloud::AuthType;
pub use multicloud::BudgetAlert;
pub use multicloud::CloudProvider;
pub use multicloud::CloudScheduler;
pub use multicloud::CommunicationPattern;
pub use multicloud::CompressionAlgorithm;
pub use multicloud::CompressionConfig;
pub use multicloud::CostConfig;
pub use multicloud::CostEntry;
pub use multicloud::CostOptimizationStrategy;
pub use multicloud::InstanceType;
pub use multicloud::MultiCloudConfig;
pub use multicloud::MultiCloudOrchestrator;
pub use multicloud::MultiCloudProcessGroup;
pub use multicloud::NodeInfo;
pub use multicloud::NodeStatus;
pub use multicloud::OrchestrationStrategy;
pub use multicloud::PerformanceMetrics as MultiCloudPerformanceMetrics;
pub use multicloud::RecoveryStrategy;
pub use multicloud::SchedulingAlgorithm;
pub use nas_integration::Architecture;
pub use nas_integration::NASAlgorithm;
pub use nas_integration::NASConfig;
pub use nas_integration::NASController;
pub use nas_integration::Operation;
pub use nas_integration::PerformanceMetrics;
pub use nas_integration::SearchSpaceConfig;
pub use nas_integration::TargetPlatform;
pub use parallelism_3d::AggregateParallelismStats;
pub use parallelism_3d::CommBackend;
pub use parallelism_3d::MemoryOptimization;
pub use parallelism_3d::Parallelism3D;
pub use parallelism_3d::Parallelism3DManager;
pub use parallelism_3d::Parallelism3DStats;
pub use parallelism_3d::ParallelismConfig;
pub use parallelism_3d::PipelineSchedule;
pub use sequence_parallelism::utils as sequence_parallelism_utils;
pub use sequence_parallelism::AttentionCommunication;
pub use sequence_parallelism::SequenceChunk;
pub use sequence_parallelism::SequenceCommunicationPattern;
pub use sequence_parallelism::SequenceMemoryOptimization;
pub use sequence_parallelism::SequenceParallelism;
pub use sequence_parallelism::SequenceParallelismConfig;
pub use sequence_parallelism::SequenceParallelismStats;
pub use sequence_parallelism::SequenceSplittingStrategy;
pub use tensor_parallelism::utils as tensor_parallelism_utils;
pub use tensor_parallelism::CommunicationRequirement;
pub use tensor_parallelism::TensorCommunicationPattern;
pub use tensor_parallelism::TensorMemoryOptimization;
pub use tensor_parallelism::TensorOperation;
pub use tensor_parallelism::TensorOperationType;
pub use tensor_parallelism::TensorParallelism;
pub use tensor_parallelism::TensorParallelismConfig;
pub use tensor_parallelism::TensorParallelismStatistics;
pub use tensor_parallelism::TensorPartition;
pub use tensor_parallelism::TensorPartitioningStrategy;
pub use training_monitor::AnomalyReport;
pub use training_monitor::AnomalyType;
pub use training_monitor::HealthStatus;
pub use training_monitor::PerformanceStats;
pub use training_monitor::TrainingHealthStatus;
pub use training_monitor::TrainingMonitor;
pub use training_monitor::TrainingMonitorConfig;
pub use training_monitor::TrainingReport;
pub use advanced_stability_monitor::AdvancedStabilityConfig;
pub use advanced_stability_monitor::AdvancedStabilityMonitor;
pub use advanced_stability_monitor::LossLandscapeAnalysis;
pub use advanced_stability_monitor::PatternDetector;
pub use advanced_stability_monitor::PredictedAnomalyType;
pub use advanced_stability_monitor::PredictiveAnomaly;
pub use advanced_stability_monitor::PreventiveAction;
pub use advanced_stability_monitor::RiskLevel;
pub use advanced_stability_monitor::StabilityReport;
pub use advanced_stability_monitor::StabilityScore;
pub use advanced_stability_monitor::TrainerParameters;
pub use advanced_stability_monitor::TrainingDynamics;
pub use advanced_stability_monitor::TrendDirection;
pub use gradient_anomaly_recovery::AdaptiveThresholds;
pub use gradient_anomaly_recovery::GradientAnomaly;
pub use gradient_anomaly_recovery::GradientAnomalyType;
pub use gradient_anomaly_recovery::GradientRecoveryConfig;
pub use gradient_anomaly_recovery::GradientRecoveryManager;
pub use gradient_anomaly_recovery::GradientRecoveryStrategy;
pub use gradient_anomaly_recovery::GradientSeverity;
pub use gradient_anomaly_recovery::LayerGradientStats;
pub use gradient_anomaly_recovery::RecoveryResult;
pub use gradient_anomaly_recovery::RecoveryStatistics;
pub use cost_tracking::AlertThreshold;
pub use cost_tracking::BillingModel;
pub use cost_tracking::Budget;
pub use cost_tracking::BudgetFilters;
pub use cost_tracking::BudgetPeriod;
pub use cost_tracking::BudgetStatus;
pub use cost_tracking::CostBreakdown;
pub use cost_tracking::CostDataPoint;
pub use cost_tracking::CostDriver;
pub use cost_tracking::CostEntry as CostTrackingCostEntry;
pub use cost_tracking::CostForecastingModel;
pub use cost_tracking::CostRecommendation;
pub use cost_tracking::CostReport;
pub use cost_tracking::CostStatistics;
pub use cost_tracking::CostTracker;
pub use cost_tracking::CostTrend;
pub use cost_tracking::EfficiencyMetrics;
pub use cost_tracking::ForecastingAccuracy;
pub use cost_tracking::ForecastingParameters;
pub use cost_tracking::ImplementationEffort;
pub use cost_tracking::NotificationType;
pub use cost_tracking::RecommendationCategory;
pub use cost_tracking::RecommendationPriority;
pub use cost_tracking::ReportType;
pub use cost_tracking::TimeRange;
pub use model_versioning::ModelRegistry;
pub use model_versioning::ModelStatus;
pub use model_versioning::ModelVersion;
pub use model_versioning::ModelVersioningManager;
pub use model_versioning::PerformanceMetrics as ModelVersioningPerformanceMetrics;
pub use model_versioning::TrainingConfig;
pub use model_versioning::VersionComparison;
pub use online_learning::ConceptDrift;
pub use online_learning::DriftType;
pub use online_learning::OnlineDataPoint;
pub use online_learning::OnlineLearningConfig;
pub use online_learning::OnlineLearningError;
pub use online_learning::OnlineLearningManager;
pub use online_learning::OnlineStatistics;
pub use online_learning::PerformanceWindow;
pub use resource_scheduling::AlertSeverity;
pub use resource_scheduling::AllocationStatus;
pub use resource_scheduling::CostAlert;
pub use resource_scheduling::CostOptimizationRecommendation;
pub use resource_scheduling::CostSnapshot;
pub use resource_scheduling::LocalityPreference;
pub use resource_scheduling::Priority;
pub use resource_scheduling::RecommendationType;
pub use resource_scheduling::ResourceAllocation as SchedulingResourceAllocation;
pub use resource_scheduling::ResourceConstraints;
pub use resource_scheduling::ResourcePool;
pub use resource_scheduling::ResourceRequest;
pub use resource_scheduling::ResourceScheduler;
pub use resource_scheduling::ResourceType;
pub use resource_scheduling::SchedulingAlgorithm as ResourceSchedulingAlgorithm;
pub use resource_scheduling::SchedulingStatistics;
pub use resource_scheduling::StorageSpeed;
pub use ring_attention::utils as ring_attention_utils;
pub use ring_attention::ModelParams;
pub use ring_attention::RingAttentionBlock;
pub use ring_attention::RingAttentionConfig;
pub use ring_attention::RingAttentionManager;
pub use ring_attention::RingAttentionStats;
pub use ring_attention::RingCommunicationPattern;
pub use ring_attention::RingKVPair;
pub use training_orchestration::CheckpointConfig;
pub use training_orchestration::CheckpointInfo;
pub use training_orchestration::EarlyStoppingConfig as OrchestrationEarlyStoppingConfig;
pub use training_orchestration::JobEvent;
pub use training_orchestration::JobPriority;
pub use training_orchestration::JobScheduler;
pub use training_orchestration::JobStatus as OrchestrationJobStatus;
pub use training_orchestration::ModelConfig;
pub use training_orchestration::OrchestrationStatistics;
pub use training_orchestration::ResourceNode;
pub use training_orchestration::ResourceRequirements;
pub use training_orchestration::SchedulingStrategy;
pub use training_orchestration::TrainingJob;
pub use training_orchestration::TrainingJobConfig;
pub use training_orchestration::TrainingMetrics;
pub use training_orchestration::TrainingOrchestrator;
pub use config_validation::ConfigSchema;
pub use config_validation::ConfigValidator;
pub use config_validation::Constraint;
pub use config_validation::FieldSchema;
pub use config_validation::FieldType;
pub use config_validation::Severity;
pub use config_validation::Validatable;
pub use config_validation::ValidatedConfig;
pub use config_validation::ValidationError as ConfigValidationError;
pub use config_validation::ValidationReport;
pub use config_validation::ValidationRule as ConfigValidationRule;
pub use error_codes::get_error_info;
pub use error_codes::get_recovery_actions;
pub use error_codes::is_critical_error;
pub use error_codes::ErrorCodeInfo;
pub use error_codes::ErrorCodeRegistry;
pub use error_handling::ErrorContext;
pub use error_handling::ErrorManager;
pub use error_handling::ErrorPattern;
pub use error_handling::ErrorSeverity;
pub use error_handling::ErrorStatistics;
pub use error_handling::ErrorTrend;
pub use error_handling::ErrorType;
pub use error_handling::RecoveryAction;
pub use error_handling::RecoveryStrategy as ErrorRecoveryStrategy;
pub use error_handling::RecoverySuggestion;
pub use error_handling::SystemInfo as ErrorSystemInfo;
pub use error_handling::TrainingError;
pub use error_handling::TrainingErrorExt;
pub use error_handling::TrainingResult;
pub use simplified_trainer::CheckpointCallback;
pub use simplified_trainer::EarlyStoppingMode;
pub use simplified_trainer::EpochResult;
pub use simplified_trainer::LogLevel;
pub use simplified_trainer::LoggingCallback;
pub use simplified_trainer::MetricsCallback;
pub use simplified_trainer::ProgressCallback;
pub use simplified_trainer::SimpleCallback;
pub use simplified_trainer::SimpleTrainer;
pub use simplified_trainer::SimpleTrainerBuilder;
pub use simplified_trainer::SimpleTrainingConfig;
pub use simplified_trainer::TrainingResults;
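Among the `qat` re-exports above is `fake_quantize`. Its signature is not shown on this page, so the following is an independent sketch of the standard quantize-then-dequantize ("fake quantization") step used in quantization-aware training, not that function. Values are mapped onto an unsigned `num_bits` integer grid with a scale and zero point, rounded and clamped, then mapped back to `f32`, so the forward pass sees the rounding error. `QParams`, `qparams_from_range`, and `fake_quantize_sketch` are hypothetical names. During training, gradients are usually passed straight through the rounding (the straight-through estimator), which is not shown.

```rust
/// Hypothetical parameters for asymmetric uniform quantization.
#[derive(Debug, Clone, Copy)]
pub struct QParams {
    pub scale: f32,
    pub zero_point: i32,
    pub qmin: i32,
    pub qmax: i32,
}

/// Derives scale and zero point so that [min, max] (widened to include 0)
/// maps onto the integer range [0, 2^num_bits - 1].
pub fn qparams_from_range(min: f32, max: f32, num_bits: u32) -> QParams {
    assert!((2..=16).contains(&num_bits), "num_bits must be in 2..=16");
    let qmin = 0i32;
    let qmax = (1i32 << num_bits) - 1;
    // Including 0 guarantees that zero is exactly representable.
    let (min, max) = (min.min(0.0), max.max(0.0));
    let scale = if max > min { (max - min) / (qmax - qmin) as f32 } else { 1.0 };
    let zero_point = (qmin as f32 - min / scale)
        .round()
        .clamp(qmin as f32, qmax as f32) as i32;
    QParams { scale, zero_point, qmin, qmax }
}

/// Quantize, clamp to the integer range, then dequantize back to f32.
pub fn fake_quantize_sketch(x: f32, p: QParams) -> f32 {
    let q = ((x / p.scale).round() + p.zero_point as f32)
        .clamp(p.qmin as f32, p.qmax as f32);
    (q - p.zero_point as f32) * p.scale
}
```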

Modules§

adaptive_gradient_scaling
Adaptive Gradient Scaling for Improved Training Stability
adaptive_learning_rate
Adaptive Learning Rate Schedulers for Dynamic Training Optimization
advanced_stability_monitor
Advanced Training Stability Monitoring System
auto_parallelism
config_validation
continual
cost_tracking
data_pipeline
Data Pipeline Enhancements for TrustformeRS Training
distributed
elastic_training
error_codes
error_handling
experiment_management
Experiment Management Framework
expert_parallelism
few_shot
framework_integration
Framework Integration for TrustformeRS Training
gradient
gradient_anomaly_recovery
Advanced Gradient Anomaly Recovery System
hyperopt
Automated Hyperparameter Tuning Framework
losses
memory_optimization
metrics
mixed_precision
model_versioning
multicloud
nas_integration
online_learning
parallelism_3d
qat
Quantization-Aware Training (QAT) for TrustformeRS
resource_scheduling
ring_attention
rlhf
Reinforcement Learning from Human Feedback (RLHF) infrastructure.
sequence_parallelism
simplified_trainer
tensor_parallelism
trainer
training_args
training_dynamics
Training Dynamics Analysis Module
training_monitor
training_orchestration
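The `hyperopt` module, described above as an automated hyperparameter tuning framework, re-exports `SuccessiveHalving` and `Hyperband`. The sketch below illustrates the successive-halving idea in plain Rust and does not use those types: every surviving configuration is evaluated at the current budget, only the best `1/eta` fraction is kept, and the budget is multiplied by `eta` until one configuration remains. Representing configurations as indices and supplying the objective as a closure that returns a loss (lower is better) are assumptions for illustration.

```rust
/// Hypothetical successive halving over configuration ids.
/// `evaluate(config_id, budget)` returns a loss; lower is better.
pub fn successive_halving<F>(
    mut candidates: Vec<usize>,
    min_budget: u32,
    eta: usize,
    mut evaluate: F,
) -> usize
where
    F: FnMut(usize, u32) -> f64,
{
    assert!(eta >= 2, "eta must be at least 2");
    assert!(!candidates.is_empty(), "need at least one candidate");
    let mut budget = min_budget;
    while candidates.len() > 1 {
        // Score every survivor at the current budget (e.g. epochs or steps).
        let mut scored: Vec<(usize, f64)> = candidates
            .iter()
            .map(|&c| (c, evaluate(c, budget)))
            .collect();
        scored.sort_by(|a, b| a.1.total_cmp(&b.1));
        // Keep the best 1/eta (at least one) and give them more budget.
        let keep = (scored.len() / eta).max(1);
        candidates = scored.into_iter().take(keep).map(|(c, _)| c).collect();
        budget *= eta as u32;
    }
    candidates[0]
}
```

Poor configurations are discarded after cheap, short evaluations, so most of the total budget goes to the few promising ones; Hyperband runs several such brackets with different starting budgets.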

Macros§

create_context
training_error
Helper macros for error creation
validate_range
Helper macros for common validation patterns
validate_required
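The `validate_range` macro is documented only as a helper for common validation patterns, and its actual syntax is not shown here. As a rough illustration of that pattern, the following declarative macro, deliberately named `check_range` to make clear it is hypothetical and not the crate's macro, turns a bounds check into a `Result` whose error message names the offending field.

```rust
/// Hypothetical range-check macro: evaluates to `Err(String)` naming the field
/// when `value` lies outside the inclusive range `[min, max]`, otherwise `Ok(())`.
macro_rules! check_range {
    ($field:expr, $value:expr, $min:expr, $max:expr) => {{
        let v = $value;
        if v < $min || v > $max {
            Err(format!("{} = {} is outside [{}, {}]", $field, v, $min, $max))
        } else {
            Ok(())
        }
    }};
}

/// Example use: reject learning rates outside [0, 1].
pub fn validate_learning_rate(lr: f64) -> Result<(), String> {
    check_range!("learning_rate", lr, 0.0, 1.0)
}
```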