1#![allow(clippy::result_large_err)]
74#![allow(clippy::too_many_arguments)]
76#![allow(clippy::type_complexity)]
77#![allow(clippy::excessive_nesting)]
// Crate-wide suppression of `clippy::await_holding_lock`, the lint that flags a
// non-async-aware lock guard (e.g. a `std::sync::MutexGuard`) still alive across an
// `.await` point.
// NOTE(review): no rationale for this allow is visible here. Holding a blocking lock
// across an await can stall or deadlock the async executor. Confirm whether the
// crate-wide allow is still needed, or narrow it to the specific call sites (or move
// those sites to an async-aware mutex / drop the guard before awaiting).
#![allow(clippy::await_holding_lock)]
80#![allow(clippy::needless_range_loop)]
81#![allow(clippy::empty_line_after_doc_comments)]
82#![allow(clippy::manual_clamp)]
83#![allow(clippy::derivable_impls)]
84#![allow(clippy::vec_init_then_push)]
85#![allow(clippy::ptr_arg)]
86
87pub mod adaptive_gradient_scaling;
88pub mod adaptive_learning_rate;
89pub mod advanced_stability_monitor;
90pub mod auto_parallelism;
91pub mod config_validation;
92pub mod continual;
93pub mod cost_tracking;
94pub mod data_pipeline;
95pub mod distributed;
96pub mod elastic_training;
97pub mod error_codes;
98pub mod error_handling;
99pub mod experiment_management;
100pub mod expert_parallelism;
101pub mod few_shot;
102pub mod framework_integration;
103pub mod gradient;
104pub mod gradient_anomaly_recovery;
105pub mod hyperopt;
106pub mod losses;
107pub mod memory_optimization;
108pub mod metrics;
109pub mod mixed_precision;
110pub mod model_versioning;
111pub mod multicloud;
112pub mod nas_integration;
113pub mod online_learning;
114pub mod parallelism_3d;
115pub mod qat;
116pub mod resource_scheduling;
117pub mod ring_attention;
118pub mod rlhf;
119pub mod sequence_parallelism;
120pub mod simplified_trainer;
121pub mod tensor_parallelism;
122pub mod trainer;
123pub mod training_args;
124pub mod training_dynamics;
125pub mod training_monitor;
126pub mod training_orchestration;
127
128pub use continual::{
129 CatastrophicPreventionStrategy, ContinualLearningConfig, ContinualLearningManager, EWCConfig,
130 EWCTrainer, ExperienceBuffer, FisherInformation, MemoryReplay, MemoryReplayConfig,
131 ProgressiveConfig, ProgressiveNetwork, RegularizationMethod, TaskBoundaryDetector, TaskInfo,
132 TaskModule, TaskTransition,
133};
134pub use distributed::{
135 init_distributed_training, utils as distributed_utils, DataParallelTrainer, DistributedBackend,
136 DistributedConfig, ProcessGroup,
137};
138pub use experiment_management::{
139 ABTestConfig, ABTestResults, ABTestStatus, ArtifactType, DataLineage, DataSplit,
140 EnvironmentInfo, ExperimentFilters, ExperimentManager, ExperimentMetadata, ExperimentReport,
141 ExperimentResults, ExperimentStatus, GPUInfo, HardwareInfo, HyperparameterComparison,
142 HyperparameterConfig, ModelArtifact, ModelLineage, ModelProvenance, ModelSizeInfo,
143 ParameterChange, PipelineStep, QualityAssuranceStep, SystemInfo, TrainingPipeline,
144};
145pub use few_shot::{
146 AdaptationConfig, CrossTaskGeneralizer, FewShotConfig, FewShotExample, FewShotMethod,
147 GeneralizationConfig, ICLExample, InContextConfig, InContextLearner, MAMLConfig, MAMLTrainer,
148 MetaLearningAlgorithm, PromptConfig, PromptTuner, ReptileConfig, ReptileTrainer, SoftPrompt,
149 SupportSet, TaskAdapter, TaskDescriptor, TaskEmbedding,
150};
151pub use gradient::GradientUtils;
152pub use hyperopt::{
153 AcquisitionFunction,
155 AcquisitionFunctionType,
156 AdvancedEarlyStoppingConfig,
157 ArmGenerationStrategy,
158 ArmStatistics,
159 BanditAlgorithm,
160 BanditConfig,
161 BanditOptimizer,
162 BayesianOptimization,
163 CategoricalParameter,
164 ContinuousParameter,
165 Direction,
166 DiscreteParameter,
167 EarlyStoppingConfig,
169 EarlyStoppingStrategy,
170 EvaluationJob,
171 EvaluationResult,
172 ExplorationStrategy,
173 FaultToleranceConfig,
174 GPSampler,
175 GPUAllocation,
176 GridSearch,
177 HalvingStrategy,
178 HyperParameter,
179 Hyperband,
180 HyperparameterTuner,
182 JobStatus,
183 KernelType,
184 LoadBalancer,
185 LogParameter,
186 OptimizationDirection,
187 OptimizationResult,
188 PBTConfig,
190 PBTMember,
191 PBTStats,
192 ParallelEvaluationConfig,
193 ParallelEvaluator,
194 ParallelStrategy,
195 ParameterValue,
196 PopulationBasedTraining,
197 PriorityLevel,
198 PruningConfig,
199 PruningStrategy,
200 RandomSampler,
201 RandomSearch,
202 ResourceAllocation,
203 ResourceUsage,
204 RewardFunction,
205 Sampler,
207 SamplerConfig,
208 SearchSpace,
210 SearchStrategy,
212 StudyStatistics,
213 SuccessiveHalving,
214 SurrogateConfig,
215 SurrogateModel,
216 SurrogateModelType,
217 SurrogateOptimizer,
218 TPESampler,
219 Trial,
221 TrialHistory,
222 TrialMetrics,
223 TrialResult,
224 TrialState,
225 TunerConfig,
226 WarmStartConfig,
227 WarmStartDataSource,
228 WarmStartStrategy,
229};
230pub use losses::{CrossEntropyLoss, Loss, MSELoss};
231pub use metrics::{Accuracy, F1Score, Metric, MetricCollection, Perplexity};
232pub use mixed_precision::{
233 utils as mixed_precision_utils, AMPManager, AdvancedMixedPrecisionConfig,
234 AdvancedMixedPrecisionManager, ComputeOptimizationManager, ComputeOptimizationReport,
235 DynamicBatchingConfig, DynamicBatchingManager, DynamicBatchingReport, LayerScalingConfig,
236 LossScaler, MixedPrecisionConfig, MixedPrecisionReport,
237};
238pub use qat::{
239 fake_quantize, fake_quantize_mixed_bit, qat_loss, ActivationQuantizer, CalibrationDataset,
240 LayerQuantConfig, MixedBitQATTrainer, MixedBitStrategy, QATConfig, QATConv2d, QATLinear,
241 QATModel, QATTrainer, QuantStats, QuantizationGradients, QuantizationParams, QuantizedModel,
242};
243pub use rlhf::{
244 ConstitutionalPrinciple, GenerationResult, HumanFeedback, PPOConfig, PPOStepResult, PPOTrainer,
245 PolicyModel, PreferencePair, RLHFConfig, RLHFMetrics, RLHFPhase, RewardModel,
246 RewardModelConfig, RewardPrediction, ValueModel,
247};
248pub use trainer::{EarlyStoppingCallback, LogEntry, Trainer, TrainerCallback, TrainingState};
249pub use training_args::{EvaluationStrategy, SaveStrategy, TrainingArguments};
250pub use training_dynamics::{
251 ConvergenceMetrics, GradientFlowMetrics, LossLandscapeMetrics, TrainingDynamicsAnalyzer,
252 TrainingDynamicsConfig, TrainingDynamicsReport, TrainingDynamicsSnapshot,
253 WeightEvolutionMetrics,
254};
255
256pub use adaptive_gradient_scaling::{
258 AdaptiveGradientScaler, AdaptiveGradientScalingConfig, AdaptiveScalingStatistics,
259 GradientScalingResult, LayerGradientStats as AdaptiveLayerGradientStats, StabilityTrend,
260};
261pub use adaptive_learning_rate::{
262 AdaptationStrategy as LRAdaptationStrategy, AdaptiveLRStatistics, AdaptiveLearningRateConfig,
263 AdaptiveLearningRateScheduler, LearningRateUpdate, PerformanceTrend, SchedulerState,
264 TrainingDynamics as LRTrainingDynamics,
265};
266pub use auto_parallelism::{
267 utils as auto_parallelism_utils, ArchitectureType, AutoParallelismConfig,
268 AutoParallelismSelector, DeviceType, EvaluationMethod, HardwareConstraints, ModelConstraints,
269 NetworkTopology, OptimizationObjective, ParallelismStrategy, PerformanceRequirements,
270 SelectionAlgorithm,
271};
272pub use data_pipeline::{
273 ActiveLearningConfig, ActiveLearningIntegration, ActiveLearningManager, ActiveLearningStats,
274 AdaptationStrategy, AdaptiveAugmentationConfig, AlignmentConfig, AlignmentMethod,
275 AnnotationConfig, AnnotationSource, AudioAugmentationType, AugmentationScheduling,
276 AugmentationStats, AugmentationStrategy, AugmentationStrategyType, BatchingConfig,
277 BatchingStrategy, CacheType, CachingConfig,
278 CompressionAlgorithm as DataPipelineCompressionAlgorithm, CoreSetMethod,
279 CurriculumLearningConfig, CurriculumLearningManager, CurriculumScheduling,
280 CurriculumSchedulingStrategy, CurriculumStage, CurriculumStats, CurriculumStrategy, DataFilter,
281 DataPipeline, DataPipelineConfig, DataSample, DataSelectionCriteria, DataSource,
282 DataSourceType, DataValidationConfig, DataValidator, DifficultyAssessment, DisagreementMeasure,
283 DistributedProcessingConfig, DiversityConstraint, DiversityMeasure, DynamicAssessmentMethod,
284 DynamicAugmentationConfig, DynamicAugmentationManager, ErrorHandling, EvictionPolicy,
285 FeatureExtractionConfig, FeatureExtractionMethod, FilterType, FusionStrategy,
286 ImageAugmentationType, LoadBalancingStrategy as DataPipelineLoadBalancingStrategy,
287 MissingModalityHandling, Modality, ModalityProcessor, ModalityType, MultiModalConfig,
288 MultiModalHandler, MultiModalPreprocessing, MultiModalStats, NormalizationConfig,
289 NormalizationType, PacingFunction, PacingType, PreprocessingConfig, PreprocessingStep,
290 PreprocessingStepType, ProcessingBackend, QualityAssessmentMethod, QualityControl,
291 QueryStrategy, SamplingConfig, ScheduleType, ShuffleConfig, ShuffleStrategy, StreamingDataset,
292 StreamingDatasetConfig, StreamingStats, SuccessCriteria, SynchronizationConfig,
293 TextAugmentationType, TokenAugmentationType, UncertaintyMeasure, ValidationError,
294 ValidationResult, ValidationRule, ValidationRuleType, ValidationSeverity, ValidationStats,
295 ValidationStrategy, ValidationWarning, Validator,
296};
297pub use elastic_training::{
298 ElasticTrainingConfig, ElasticTrainingCoordinator, ScalingDecision, ScalingType, SystemStatus,
299 WorkerInfo, WorkerStatus,
300};
301pub use expert_parallelism::{
302 utils as expert_parallelism_utils, ExpertAssignment, ExpertCommunicationPattern,
303 ExpertParallelism, ExpertParallelismConfig, ExpertRoutingStrategy, LoadBalancingStats,
304 LoadBalancingStrategy, TokenRouting,
305};
306pub use framework_integration::{
307 AggregationFunction, ArtifactConfig, ArtifactInfo, AudioLoggingConfig, AutoConnectConfig,
308 ChartType, ClearMLArtifactConfig, ClearMLConfig, ClearMLTaskType, ColorFormat,
309 ConflictResolution, CustomArtifact, CustomMetric, CustomMonitoring, CustomScalar,
310 ExperimentMetadata as FrameworkExperimentMetadata,
311 ExperimentStatus as FrameworkExperimentStatus, ExperimentTracker, ExportConfig, ExportFormat,
312 ExportFrequency, FrameworkIntegrationManager, GraphLoggingConfig, HistogramConfig,
313 ImageLoggingConfig, IntegrationConfig, IntegrationType, MLflowAdvancedConfig, MLflowAuth,
314 MLflowAuthType, MLflowConfig, MLflowTracker, MetricType, MetricValue, ModelRegistrationConfig,
315 ModelStage, NeptuneConfig, NeptuneExperimentConfig, NeptuneMonitoringConfig,
316 ParameterValue as FrameworkParameterValue, ProfilingConfig, ResumeConfig, ScalarLayout,
317 SyncConfig, SyncFrequency, TensorBoardAdvancedConfig, TensorBoardConfig, TensorBoardTracker,
318 UpdateFrequency, WandBAdvancedConfig, WandBConfig, WandBTracker, WatchModelConfig,
319};
320pub use memory_optimization::{
321 CPUOffloadManager, GradientCheckpointWrapper, MemoryOptimizationConfig,
322 MemoryOptimizationStats, MemoryOptimizer,
323};
324pub use multicloud::{
325 AlertType, AuthConfig, AuthType, BudgetAlert, CloudProvider, CloudScheduler,
326 CommunicationPattern, CompressionAlgorithm, CompressionConfig, CostConfig, CostEntry,
327 CostOptimizationStrategy, InstanceType, MultiCloudConfig, MultiCloudOrchestrator,
328 MultiCloudProcessGroup, NodeInfo, NodeStatus, OrchestrationStrategy,
329 PerformanceMetrics as MultiCloudPerformanceMetrics, RecoveryStrategy, SchedulingAlgorithm,
330};
331pub use nas_integration::{
332 Architecture, NASAlgorithm, NASConfig, NASController, Operation, PerformanceMetrics,
333 SearchSpaceConfig, TargetPlatform,
334};
335pub use parallelism_3d::{
336 AggregateParallelismStats, CommBackend, MemoryOptimization, Parallelism3D,
337 Parallelism3DManager, Parallelism3DStats, ParallelismConfig, PipelineSchedule,
338};
339pub use sequence_parallelism::{
340 utils as sequence_parallelism_utils, AttentionCommunication, SequenceChunk,
341 SequenceCommunicationPattern, SequenceMemoryOptimization, SequenceParallelism,
342 SequenceParallelismConfig, SequenceParallelismStats, SequenceSplittingStrategy,
343};
344pub use tensor_parallelism::{
345 utils as tensor_parallelism_utils, CommunicationRequirement, TensorCommunicationPattern,
346 TensorMemoryOptimization, TensorOperation, TensorOperationType, TensorParallelism,
347 TensorParallelismConfig, TensorParallelismStatistics, TensorPartition,
348 TensorPartitioningStrategy,
349};
350pub use training_monitor::{
351 AnomalyReport, AnomalyType, HealthStatus, PerformanceStats, TrainingHealthStatus,
352 TrainingMonitor, TrainingMonitorConfig, TrainingReport,
353};
354
355pub use advanced_stability_monitor::{
357 AdvancedStabilityConfig, AdvancedStabilityMonitor, LossLandscapeAnalysis, PatternDetector,
358 PredictedAnomalyType, PredictiveAnomaly, PreventiveAction, RiskLevel, StabilityReport,
359 StabilityScore, TrainerParameters, TrainingDynamics, TrendDirection,
360};
361
362pub use gradient_anomaly_recovery::{
364 AdaptiveThresholds, GradientAnomaly, GradientAnomalyType, GradientRecoveryConfig,
365 GradientRecoveryManager, GradientRecoveryStrategy, GradientSeverity, LayerGradientStats,
366 RecoveryResult, RecoveryStatistics,
367};
368
369pub use cost_tracking::{
371 AlertThreshold, BillingModel, Budget, BudgetFilters, BudgetPeriod, BudgetStatus, CostBreakdown,
372 CostDataPoint, CostDriver, CostEntry as CostTrackingCostEntry, CostForecastingModel,
373 CostRecommendation, CostReport, CostStatistics, CostTracker, CostTrend, EfficiencyMetrics,
374 ForecastingAccuracy, ForecastingParameters, ImplementationEffort, NotificationType,
375 RecommendationCategory, RecommendationPriority, ReportType, TimeRange,
376};
377pub use model_versioning::{
378 ModelRegistry, ModelStatus, ModelVersion, ModelVersioningManager,
379 PerformanceMetrics as ModelVersioningPerformanceMetrics, TrainingConfig, VersionComparison,
380};
381pub use online_learning::{
382 ConceptDrift, DriftType, OnlineDataPoint, OnlineLearningConfig, OnlineLearningError,
383 OnlineLearningManager, OnlineStatistics, PerformanceWindow,
384};
385pub use resource_scheduling::{
386 AlertSeverity, AllocationStatus, CostAlert, CostOptimizationRecommendation, CostSnapshot,
387 LocalityPreference, Priority, RecommendationType,
388 ResourceAllocation as SchedulingResourceAllocation, ResourceConstraints, ResourcePool,
389 ResourceRequest, ResourceScheduler, ResourceType,
390 SchedulingAlgorithm as ResourceSchedulingAlgorithm, SchedulingStatistics, StorageSpeed,
391};
392pub use ring_attention::{
393 utils as ring_attention_utils, ModelParams, RingAttentionBlock, RingAttentionConfig,
394 RingAttentionManager, RingAttentionStats, RingCommunicationPattern, RingKVPair,
395};
396pub use training_orchestration::{
397 CheckpointConfig, CheckpointInfo, EarlyStoppingConfig as OrchestrationEarlyStoppingConfig,
398 JobEvent, JobPriority, JobScheduler, JobStatus as OrchestrationJobStatus, ModelConfig,
399 OrchestrationStatistics, ResourceNode, ResourceRequirements, SchedulingStrategy, TrainingJob,
400 TrainingJobConfig, TrainingMetrics, TrainingOrchestrator,
401};
402
403pub use config_validation::{
405 ConfigSchema, ConfigValidator, Constraint, FieldSchema, FieldType, Severity, Validatable,
406 ValidatedConfig, ValidationError as ConfigValidationError, ValidationReport,
407 ValidationRule as ConfigValidationRule,
408};
409pub use error_codes::{
410 get_error_info, get_recovery_actions, is_critical_error, ErrorCodeInfo, ErrorCodeRegistry,
411};
412pub use error_handling::{
413 ErrorContext, ErrorManager, ErrorPattern, ErrorSeverity, ErrorStatistics, ErrorTrend,
414 ErrorType, RecoveryAction, RecoveryStrategy as ErrorRecoveryStrategy, RecoverySuggestion,
415 SystemInfo as ErrorSystemInfo, TrainingError, TrainingErrorExt, TrainingResult,
416};
417pub use simplified_trainer::{
418 CheckpointCallback, EarlyStoppingMode, EpochResult, LogLevel, LoggingCallback, MetricsCallback,
419 ProgressCallback, SimpleCallback, SimpleTrainer, SimpleTrainerBuilder, SimpleTrainingConfig,
420 TrainingResults,
421};