1#![deny(unsafe_code)]
221#![allow(dead_code)]
222#![allow(unused_variables)]
223#![allow(unused_assignments)]
224#![allow(unused_mut)]
225#![allow(clippy::result_large_err)]
226#![allow(clippy::needless_range_loop)]
227#![allow(clippy::type_complexity)]
228#![allow(clippy::vec_init_then_push)]
229#![allow(clippy::clone_on_copy)]
230#![allow(clippy::if_same_then_else)]
231#![allow(clippy::doc_overindented_list_items)]
232
233pub mod activation_function;
234pub mod backends;
235pub mod benchmarks;
236pub mod data;
237pub mod deployment;
238pub mod distributed;
239pub mod layers;
240pub mod loss;
241pub mod metrics;
242pub mod mixed_precision;
243pub mod model;
244pub mod model_parallel;
245pub mod optimizers;
246pub mod peft;
247pub mod pipeline;
248pub mod pretrained;
249pub mod scheduler;
250#[cfg(feature = "serialize")]
251pub mod serialization;
252pub mod trainer;
253pub mod training;
254pub mod training_pipeline;
255pub mod utils;
256
257#[cfg(feature = "onnx")]
258pub mod onnx;
259
260pub mod active_learning;
261pub mod adversarial;
262pub mod anomaly_detection;
263pub mod tensorflow_compat;
264pub mod text_generation_pipelines;
265pub use anomaly_detection::{
266 compute_anomaly_metrics,
267 compute_auc_roc,
268 compute_average_precision,
269 AdDeepSvdd,
270 AdDeepSvddConfig,
271 AdLinear,
273 AdMlp,
274 AeAnomaly,
275 AeAnomalyConfig,
276 AnomalyError,
277 AnomalyEvaluationMetrics,
278 AnomalyMemory,
279 AnomalyMetrics,
280 AnomalyThresholder,
281 AnomalyTransformerModel,
282 AnomalyVAE,
283 AnomalyVaeConfig,
284 CouplingLayer,
285 DeepSVDD,
286 DeepSvddConfig,
287 FlowAnomaly,
288 FlowAnomalyConfig,
289 GaussianMixtureAnomaly,
290 IsolationForestConfig,
291 IsolationNode,
292 IsolationTree,
293 MemAeConfig,
294 MemoryAugmentedAE,
295 NeuralIsolationForest,
296 PatchTSAD,
297 PatchTsadConfig,
298 RobustRCF,
299 RobustRcfConfig,
300 SpectralResidual,
301 VaeAnomaly,
302 VaeAnomalyConfig,
303};
304pub mod architecture_distillation;
305pub mod bayesian;
306pub mod bayesian_dl;
307pub use bayesian_dl::{
308 compute_calibration, nll_classification, BdlLinear, BdlMlp, CalibrationResult,
309 DeepEnsemble as BdlDeepEnsemble, DeepEnsembleConfig as BdlDeepEnsembleConfig,
310 LaplaceApproximation, SghcmConfig, SghcmSampler, SgldConfig, SgldSampler, SwagConfig,
311 SwagModel, TemperatureScaling as BdlTemperatureScaling,
312};
313pub mod continual_learning;
314pub mod contrastive;
315pub mod curriculum_learning;
316pub use curriculum_learning::{
317 AutoCurriculum, ClBabyStepConfig, ClBanditArm, ClCompetenceStrategy, ClConfidenceMethod,
318 ClCurriculumScheduler, ClDifficultyMethod, ClDifficultyScorer, ClEpochResult, ClLambdaSchedule,
319 ClReport, ClSplWeightingStrategy, ClTrainingMode, CompetenceLearning, CurriculumMetrics,
320 CurriculumTrainer, DataShapley, KnnShapley, LavaValuation, SelfPacedLearning,
321};
322pub mod lifelong_learning;
323pub use lifelong_learning::{
324 AgemOptimizer,
325 ClMetrics,
326 ContinualLearningMetrics,
327 DarkExperienceReplay,
328 EpisodeMemory,
329 ExperienceReplay,
330 ForgettingMeasure,
331 GemConstraint,
332 GemTrainer,
333 GenerativeReplay,
334 HatMask,
335 LearningWithoutForgetting,
336 LllAGemModel,
337 LllCoPE,
338 LllDer,
339 LllDerEntry,
340 LllER,
341 LllErSample,
342 LllGemModel,
343 LllHAT,
344 LllLinearLayer,
346 LllMetrics,
347 LllReport,
348 LllTaskOracle,
349 ModularNetwork,
350 PackNetMasker,
351 ProgressiveGrowth,
352 ReplayBuffer,
353 SynapticIntelligence,
354 OWM,
355};
356pub mod lm_evaluation;
357pub use lm_evaluation::{
358 LmeBERTScore,
359 LmeBERTScoreResult,
361 LmeBLEU,
362 LmeBLEUResult,
364 LmeCalibration,
366 LmeChrF,
368 LmeCodeEval,
369 LmeCodeResult,
371 LmeDistinctN,
373 LmeFewShotConfig,
374 LmeFewShotEval,
375 LmeFewShotExample,
377 LmeHarmEval,
378 LmeHarmResult,
380 LmeMC1Result,
382 LmeMC2Result,
383 LmeMathEval,
384 LmeMathResult,
386 LmeMeteor,
388 LmePerplexity,
389 LmePerplexityConfig,
391 LmeROUGE,
392 LmeROUGEScore,
394 LmeReport,
396 LmeTruthfulnessEval,
397 LmeWER,
399};
400pub mod differentiable_physics;
401pub mod diffusion;
402pub mod distillation;
403pub mod document_understanding;
404pub use document_understanding::{
405 create_spans,
406 detect_table_structure,
407 exact_match,
408 extract_answer,
409 extractive_summarize,
410 f1_answer,
411 f1_score_ner,
413 merge_lines,
414 mmr_summarize,
415 normalize_bbox,
416 parse_form,
417 rouge_n,
418 sort_reading_order,
419 table_to_csv,
420 table_to_json,
421 validate_field,
422 BboxEmbedding,
423 DocClass,
425 DocClassifier,
426 DocEmbedder,
427 DocEmbedderConfig,
428 DocEntity,
429 DocQaConfig,
431 EntityType,
433 FieldType,
435 FormField,
436 InformationExtractor,
437 InputSpan,
438 LayoutLm,
439 LayoutLmConfig,
441 LayoutLmLayer,
442 OcrBox,
444 PoolingStrategy,
446 SentenceScorer,
448 Table,
449 TableCell,
451};
452pub mod ensemble;
453pub mod federated;
454pub mod flows;
455pub mod hierarchical_time_series;
456pub mod hparam;
457pub mod hyperdimensional;
458pub use hyperdimensional::{
459 bind_binary, bind_bipolar, bundle_binary, bundle_bipolar, compute_hdc_stats,
460 orthogonality_test, permute_binary, permute_bipolar, unbind_binary, unbind_bipolar, BinaryHv,
461 BipolarHv, HdClassifier, HdSequenceEncoder, HdcError, HdcStats, HvType, IdEncoder, ItemMemory,
462 LevelEncoder, OnlineHdc, RealHv, SdmConfig, SparseSdm, ThermometerEncoder, HD_DIM,
463};
464pub mod hyperparameter_optimization;
465pub use hyperparameter_optimization::{
466 best_trial,
467 hypervolume_contribution,
468 importance_by_fanova,
469 nsga2_select,
470 pareto_front,
471 BohbConfig,
473 BohbOptimizer,
474 CmaHpoConfig,
476 CmaHpoState,
477 EvolutionaryStrategy,
478 GpHpo,
480 HbBracket,
481 HpSpace,
482 HpType,
484 HpoAcqFunction,
485 HpoBayesianConfig,
486 HpoBayesianOptimizer,
487 HpoLogger,
488 HpoStudy,
489 HpoTrial,
490 HyperBandConfig,
492 HyperBandScheduler,
493 HyperParameter,
494 KdeSampler,
495 MedianStopping,
497 MoHpoConfig,
499 MoObservation,
500 MultiObjectiveHpo,
501 OptDirection,
502 PbtConfig,
504 PbtMember,
505 PbtPopulation,
506 PercentileStop,
507 PopulationBasedTraining,
508 PreviousStudy,
510 SuccessiveHalving,
511 TrialStatus,
513 WarmStartSampler,
514};
515pub mod lora_adapters;
516pub use lora_adapters::{
517 AdaLoraLayer, BitFit, BitFitMask, DoraLayer, Ia3Layer, LoftqInit, LoftqResult, LoraLayer,
518 LoraPlus, PeftManager, PrefixTuning, PromptTuning, SingularValueImportance,
519};
520pub mod learning_to_learn;
521pub use learning_to_learn::{
522 compute_linear_mse_grad, GradientPreprocessor, L2lAttnBlock, L2lError, L2lHiddenState,
523 L2lLstmCell, L2lMetrics, L2lScheduler, L2lTask, L2lTcBlock, L2lTensor, LstmMetaOptimizer,
524 MetaDataset, MetaLearningTrainer, MetaSampledTask, MetaTaskType, OptimizerNetwork, SnailModel,
525 WarmStartOptimizer,
526};
527pub mod lr_finder;
528pub mod meta_learning;
529pub mod model_utils;
530pub mod music_generation;
531pub use music_generation::{
532 generate_drum_pattern,
533 AdsrEnvelope,
534 AudioSynthesizer,
535 BjorklundPattern,
537 Chord,
539 ChordDetection,
540 ChordDetector,
541 DrumStyle,
542 EventType,
544 GenerationConfig as MelodyGenerationConfig,
545 MelodyGenerator,
546 MidiSequence,
547 MuseTransformer,
549 MusicEvalReport,
551 MusicMetrics,
552 MusicTheoryAnalyzer,
553 MusicVariationGenerator,
555 NoteEvent,
556 PianoRollEncoder,
558 RhythmGrid,
559 ScaleType,
561 SynthConfig,
562 VoiceLeadingAnalysis,
564 WaveType,
566};
567pub mod multi_objective;
568pub use multi_objective::{
569 AugmentedLagrangian, CaGrad, GeneticOperators, GradNormOptimizer, HypervolumeIndicator, ImtlG,
570 InteriorPointMethod, LinearScalarization, Moead, MooHeads, NsgaIii, PFLLayer, ParetoFront,
571 PcGrad, PenaltyMethod, ProjectedGradient, R2Indicator, ReferencePointSampler,
572 SelectionOperator, TchebycheffScalarization, UncertaintyWeighting,
573};
574pub mod multi_fidelity;
575pub mod multi_task;
576pub mod neural_ode;
577pub mod neural_sde;
578pub use neural_sde::{
579 compute_sde_metrics,
580 shuffle_product,
581 AncestralSampler,
583 CdeVectorField,
584 ControlledSde,
585 LatentSde,
586 LatentSdeConfig,
587 LogSignatureLayer,
588 NaturalCubicSpline,
589 NeuralCde,
590 NeuralCdeConfig,
591 NeuralRde,
592 NsdeMlp,
593 PathSignature,
594 RoughPath,
595 ScoreMatchingSde,
596 SdeDecoder,
597 SdeDiffusionNet,
598 SdeDriftNet,
599 SdeEncoder,
600 SdeMetrics,
601 SdeMetricsExtended,
602 SdeTrainer,
603 SignatureConfig,
604 SignatureKernel,
605 SignatureTransform,
606 VeSde,
607 VpSde,
608};
609pub mod energy_models;
610pub mod quantization;
611pub mod rl;
612pub mod self_supervised;
613pub mod signal;
614pub mod speech_recognition;
615pub use speech_recognition::{
616 cer,
617 edit_distance,
619 log_mel_spectrogram,
621 wer,
622 AsrMetrics,
623 AudioConvStem,
625 BeamEntry,
627 CtcBeamDecoder,
628 CtcConfig,
629 JointNetwork,
630 LmRescorer,
631 NgramLm,
633 PredictionNetwork,
635 RnntDecoder,
636 SpeakerDiarizer,
637 SpectralClustering,
639 SpeechAugmentation,
641 SpeechPipeline,
643 VadConfig,
644 VoiceActivityDetector,
646 WhisperConfig,
648 WhisperDecoder,
649 WhisperDecoderLayer,
650 WhisperEncoder,
651 WhisperEncoderLayer,
652};
653pub mod simulation_based_inference;
654pub use simulation_based_inference::{
655 c2st_accuracy,
656 simulation_based_calibration,
657 tarp_test,
658 AbcRejection,
660 AbcSmcSampler,
661 ExpectedCoveragePlot,
663 FlowPosteriorSampler,
664 FlowSbiTrainer,
665 GaussianSimulator,
666 LocalPredictivePerformance,
667 MadeLayer,
669 NeuralDensityEstimator,
670 NeuralLikelihood,
671 NeuralRatioEstimator,
673 NlePosteriorSampler,
674 NreRatioEstimator,
675 RoundSummary,
676 SbcResult,
678 SbiClassifier,
680 SbiExtendedReport,
681 SbiNormalizingFlow,
683 SbiReport,
684 SequentialNpe,
685 Simulator,
687 SnleConfig,
689 SnleEstimator,
690 SnlePosterior,
691 SnpeConfig,
693 SnpePosterior,
694 SummaryStatistics,
695 TarpResult,
696};
697
698pub mod simulation_ml;
699pub use simulation_ml::{
700 AdaptiveSampler, DenseLayer, EnsembleSurrogate, GpSurrogate, LatinHypercubeSampler,
701 MlCorrectionModel, NeuralSurrogate, PhysicsResidual, PiSurrogate, RbfKernel,
702 ReynoldsStressTensor, SimError, SimulationDataAugmenter, SimulationMetrics, SmagorinskyModel,
703 Spring, SpringMassSystem, SurrogateActivation, SurrogateConfig,
704};
705pub mod spectral;
706pub mod ssm;
707pub mod structured_prediction;
708pub use structured_prediction::{
709 ctc_loss,
710 label_smoothing_loss,
711 log_sum_exp as sp_log_sum_exp,
712 ordered_prediction_loss,
713 sequence_cross_entropy,
715 softmax as sp_softmax,
716 BeliefPropagation,
717 ConstituencyConfig,
718 ConstituencyParser,
719 ConstrainedDecoding,
720 DependencyParser,
722 EnergyNetConfig,
724 EnergyNetwork,
725 FactorGraph,
727 HammingLoss,
728 LinearChainCrfConfig,
730 NeuralCrfConfig,
732 NeuralCrfLayer,
733 PartialCrfLoss,
734 SecondOrderCrf,
735 SecondOrderCrfConfig,
737 SemanticRoleLabeler,
738 SpLinear,
740 SpLinearChainCrf,
741 SpStructuredMetrics,
743 Span,
745 SpanClassifier,
746 SpanExtractor,
747 SrlAnnotation,
748 SsvmConfig,
750 StructuredLoss,
751 StructuredSvm,
752};
753pub mod symbolic_math;
754pub use symbolic_math::{
755 balance,
756 crossover as sym_crossover,
757 derivative as sym_derivative,
758 dim_divide,
759 dim_multiply,
760 eval as sym_eval,
761 evolve as sym_evolve,
762 find_pi_groups,
763 frobenius_derivative,
764 init_population as sym_init_population,
765 integrate as sym_integrate,
766 is_dimensionless,
767 mutate as sym_mutate,
768 parse_equation,
770 parse_token_sequence,
771 poly_add,
772 poly_div_rem,
773 poly_evaluate,
774 poly_gcd,
775 poly_mul,
776 poly_sub,
777 roots_companion_matrix,
778 simplify as sym_simplify,
779 tournament_select as sym_tournament_select,
780 trace_derivative,
781 Dimension,
783 Expr,
785 ExprBasis,
787 ExprToken,
789 ExpressionHasher,
791 GradientSymbolicRegressor,
792 Individual,
794 MatrixExpr,
796 NeuralExpressionSynthesizer,
797 Polynomial,
799 SynthesizerConfig,
800};
801pub mod tabular_learning;
802pub use tabular_learning::{
803 AttentiveTransformer, CatBoostEncoder, CyclicEncoder, FTTransformer, FTTransformerConfig,
804 FeatureEncoder, MinMaxScaler, MixedInputHead, NodeModel, ObliviousTree, QuantileTransformer,
805 SaintBlock, SaintModel, StandardScaler, TabNet, TabNetConfig, TabTransformer,
806 TabTransformerConfig, TabularAugmentation, TabularMetrics,
807};
808pub mod time_series;
809pub mod tokenizer;
810pub mod vae;
811pub mod variational_inference;
812pub use variational_inference::{
813 compute_pareto_k,
814 diagnose_vi,
815 effective_sample_size,
816 vi_log_sum_exp,
817 AdviModel,
818 AdviVariable,
819 BbviConfig,
821 BbviResult,
822 BlackBoxVi,
823 FlowVi,
824 FullRankGaussian,
826 MeanFieldGaussian,
828 ParameterConstraint,
830 PlanarFlowLayer,
832 StructuredVi,
834 ViDiagnostics,
836 ViDistribution,
838 ViSvgd,
839 ViSvgdConfig,
841};
842pub mod graph_generation;
843pub mod graph_matching;
844pub mod graph_transformer;
845pub use graph_matching::{
846 GmEdge, GmEditCostModel, GmEditOp, GmGraph, GmMcsResult, GmReport, GmSoftAssignment,
847 GraduatedAssignment, GraphEditDistance, GraphMatchMetrics, MaxCommonSubgraph, RandomWalkKernel,
848 ShortestPathKernel, SpectralAlignment, Vf2Matcher, WeisfeilerLemanKernel,
849};
850pub mod automl;
851pub mod nas;
852pub use automl::{
853 AlgorithmPerformancePredictor,
855 AlgorithmSelector,
856 ArchitectureBank,
857 ArchitectureDecoder,
858 ArchitectureEnsemble,
859 ArchitecturePredictor,
860 AutoFeaturePipeline,
861 AutoMlPipeline,
862 AutoMlReport,
863 AutoNormalizer,
864 BayesianOptimizer as AutomlBayesianOptimizer,
865 CellBasedNas,
866 CellOp,
867 Config as AutomlConfig,
868 ConfigSpace,
869 DatasetMetaFeatures,
870 DistributionType,
871 EarlyStoppingRule,
872 EditDistance,
873 EfficientNasPredictor,
874 FeatureInteractionSearch,
875 FeatureSelector,
876 GradNormScore,
877 GradientBasedNas,
878 GraphEncoding,
879 HyperBand,
880 HyperBandBracket,
881 JacobianScore,
882 LandmarkingFeatures,
883 MetaFeatureNormalizer,
884 MfHpType,
885 MultiObjectiveOptimizer,
886 NaswotScore,
887 PathEncoding,
888 PipelineOptimizer,
889 PipelineStep,
890 PolynomialFeatures,
891 PortfolioSelector,
892 ProxylessNas,
893 SinglePathOneShot,
894 SmacOptimizer,
895 SupernetLayer,
896 SupernetOp,
897 SynflowScore,
898 TpeSampler,
899 TransferNasFeatures,
900 TrialMetric,
901 ZenScore,
902};
903pub mod robotics;
904pub use robotics::{
905 AStarPlanner, BehavioralCloning, DaggerPolicy, DemonstrationBuffer, DhParam, GailDiscriminator,
906 GraspQuality, MotionPrimitive, OccupancyGrid, ParticleFilter, PotentialFieldNavigator,
907 RecurrentStateSpaceModel, RewardPredictor, RobotKinematics, WorldModelDecoder,
908 WorldModelEncoder,
909};
910pub mod riemannian_geometry;
911pub use riemannian_geometry::{
912 mat_inv_nn,
913 mat_mul_nn,
915 matrix_exp_sym,
916 matrix_log_sym,
917 symmetrize_nn,
918 FrechetMean as RgFrechetMean,
920 RgGrassmannManifold,
922 RgSo3Manifold,
924 RgSpdManifold,
926 RgStiefelManifold,
928 RiemannianAdam,
929 RiemannianAdamConfig,
931 RiemannianBatchNorm,
933 RiemannianManifold,
935 RiemannianSgd,
936};
937pub mod causal_inference;
938pub mod causal_representation;
939pub mod memory_networks;
940pub mod pinn;
941pub mod point_processes;
942pub use causal_representation::{
943 compute_mig, compute_modularity, compute_sap, CrLinear, CrMlp, DeepScm, DisentanglementMetrics,
944 DscmConfig, DscmMechanism, FactorVae, FactorVaeConfig, IvaeConfig, IvaeDecoder, IvaeEncoder,
945 IvaeModel, IvaePrior, NonlinearIca, SlowIcaConfig, TcDiscriminator, TcVae, TcVaeConfig,
946 TcVaeDecoder, TcVaeEncoder,
947};
948pub mod bayesian_opt;
949pub mod multimodal;
950pub mod probabilistic;
951pub use probabilistic::{
952 nig_loss, ActNorm as ProbActNorm, CalibrationEvaluator, DeepEnsemble, DirichletOutput,
953 EnsembleMember, EvidentialClassLayer, EvidentialLayer, Flow, FlowModel, FlowResult,
954 GaussianProcess as ProbGaussianProcess, GpKernel, GpPrediction, IsotonicCalibrator, NigOutput,
955 PlattScaling, RealNvpCoupling, SnapshotEnsemble, TemperatureScaling,
956};
957pub mod reward_learning;
958pub use reward_learning::{
959 compute_returns, cross_entropy_binary, kendall_tau, sigmoid, BradleyTerryModel, Comparison,
960 CuriosityShaper, GoalRewardShaper, IrlConfig, MaxCausalEntIrl, MaxEntIrl, PreferenceDataset,
961 RewardModel, RewardModelConfig, RewardModelMetrics, RewardShaper, RlhfConfig, RlhfResult,
962 RlhfTrainer,
963};
964
965pub use bayesian_opt::{
966 AcquisitionFunction, AcquisitionResult, BayesOptConfig, BayesOptResult, BayesSearchSpace,
967 BayesianOptimizer, CmaEs, CmaEsConfig, CmaEsResult, CmaEsState, FidelityLevel, GaussianProcess,
968 GpConfig, KernelType, MultiFidelityConfig, MultiFidelityOptimizer, ObservationRecord,
969};
970
971pub use graph_transformer::{
972 layer_norm as gt_layer_norm, matmul as gt_matmul, softmax as gt_softmax, ChebNetLayer,
973 GatConfig, GraphAttentionTransformer, GraphTransformerError, GraphTransformerLayer,
974 LaplacianPositionalEncoding, PosEncodingType, RandomWalkPositionalEncoding, ReadoutType,
975};
976
977pub use graph_generation::{
978 AtomType, BondType, GraphGenError, GraphPropertyPredictor, GraphRnn, GraphRnnConfig,
979 MolecularFingerprint, MolecularGraph, MpnnConfig, MpnnLayer, VgaeConfig, VgaeEncoder,
980};
981
982pub use nas::{
983 decode_architecture_string,
984 encode_architecture_string,
985 network_stats,
986 AgingEvolutionNas,
988 ArchitectureEvaluator,
989 CellConfig,
990 CellEdge,
991 CellEncoding,
992 DartsCell,
993 DartsConfig,
994 DartsOptimizer,
995 DartsState,
996 EvoNasConfig,
997 EvolutionResult,
998 EvolutionaryNas,
999 GumbelSoftmax,
1000 LotteryTicketPruner,
1001 MixedOp,
1002 NasLogger,
1003 NasSummary,
1004 NetworkEncoding,
1005 NetworkStats,
1006 NodeConfig,
1007 OneShotConfig,
1008 OneShotNas,
1009 OpType,
1011 OpsChoice,
1012 RandomNasConfig,
1013 RandomNasSearch,
1014 RandomSearchNas,
1015 SearchSpace,
1016 TicketState,
1017};
1018
1019pub use pinn::{
1020 BoundaryCondition, BurgersEquation, CollocationSampler, HeatEquation, LossComponents,
1021 NumericalGradient, PdeResidual, PinnActivation, PinnConfig, PinnLoss, PinnNetwork,
1022 PinnSolution, PinnTrainer, PoissonEquation, WaveEquation,
1023};
1024
1025pub use activation_function::ActivationFunction;
1026pub use benchmarks::{
1027 compare_models, BenchmarkConfig, BenchmarkMetrics, BenchmarkResults, ModelBenchmark,
1028};
1029pub use distributed::{
1030 models::utils::{create_data_parallel, create_distributed_data_parallel, init_process_group},
1031 models::{DDPConfig, DataParallel, DistributedDataParallel, SynchronizationMode},
1032 BackendConfig, CollectiveOp, CollectiveResult, CommunicationBackend, CommunicationBackendImpl,
1033 CommunicationGroup, CommunicationMetrics, CommunicationRuntime, CompressionAlgorithm,
1034 ReductionOp,
1035};
1036pub use layers::{
1037 compute_slopes, naive_attention, scaled_dot_product_attention, AlibiAttention, AlibiMask,
1038 AlibiSlopes, BahdanauAttention, Conv2D, Dense, Dropout, FlashAttention, FlashConfig, KVCache,
1039 Layer, LuongAttention, MultiHeadAttention, OnlineSoftmax, RMSNorm, RopeConfig, RopeEmbedding,
1040 RotaryInterpolation, TransformerDecoder, TransformerEncoder, GRU, LSTM, RNN,
1041};
1042pub use loss::{
1043 advanced_knowledge_distillation_loss, binary_cross_entropy, categorical_cross_entropy,
1044 focal_loss, hinge_loss, huber_loss, knowledge_distillation_loss, mse, quantile_loss,
1045 sparse_categorical_cross_entropy,
1046};
1047pub use metrics::{
1048 accuracy, confusion_matrix, f1_score, mean_absolute_percentage_error, precision, r_squared,
1049 recall, top_k_accuracy,
1050};
1051pub use mixed_precision::MixedPrecisionTrainer;
1052pub use model::{
1053 FunctionalModel, FunctionalModelBuilder, Input, Model, Node, Sequential, SharedLayer,
1054};
1055pub use model_parallel::{
1056 CommunicationPattern, MemoryRequirements, ModelParallelConfig, ModelParallelCoordinator,
1057 ParallelLayer, PipelineConfig, PlacementStrategy, SplitLayer, TensorParallelConfig,
1058};
1059pub use optimizers::{
1060 clip_gradients_adaptive,
1061 clip_gradients_by_global_norm,
1062 clip_gradients_by_norm,
1063 clip_gradients_by_value,
1064 AdaBelief,
1065 Adadelta,
1066 Adagrad,
1067 Adam,
1068 AdamW,
1069 AnnealStrategy,
1070 CosineAnnealingScheduler,
1071 ExponentialDecayScheduler,
1072 LambConfig,
1073 LambOptimizer,
1074 LinearScheduler,
1075 Lion,
1076 LionConfig,
1078 LionOptimizer,
1079 Lookahead,
1080 LrScheduler,
1082 MetricMode,
1083 MuonConfig,
1084 MuonOptimizer,
1085 Nadam,
1086 OneCycleLrScheduler,
1087 Optimizer,
1088 ParameterGroup,
1089 ParameterGroupOptimizer,
1090 PolynomialDecayScheduler,
1091 RAdam,
1092 RMSprop,
1093 SchedReduceLrOnPlateau,
1094 WarmupScheduler,
1095 LAMB,
1096 SGD,
1097};
1098pub use pipeline::{MicroBatch, PipelineMetrics, PipelineModelBuilder, PipelineParallelModel};
1099pub use scheduler::{
1100 ConstantLR, CosineAnnealingLR, ExponentialLR, LearningRateScheduler, PolynomialLR,
1101 ReduceLROnPlateau, StepLR, WarmupCosineDecayLR,
1102};
1103pub use trainer::{
1104 Callback, EarlyStopping, LearningRateReduction, ModelCheckpoint, Trainer, TrainingMetrics,
1105 TrainingState,
1106};
1107pub use training::{
1108 create_distillation_trainer, create_distillation_trainer_with_temperature,
1109 create_memory_efficient_trainer, create_trainer_for_large_model, AccumulationTrainingConfig,
1110 DistillationConfig, DistillationMetrics, DistillationTrainer, DistillationTrainerBuilder,
1111 GradientAccumulationTrainer, TrainingStats,
1112};
1113pub use training_pipeline::{
1114 quick_train, TrainingPipeline, TrainingPipelineConfig, TrainingResults,
1115};
1116
1117pub use trainer::TensorboardCallback;
1118
1119#[cfg(feature = "onnx")]
1120pub use onnx::{
1121 OnnxAttribute, OnnxDataType, OnnxExport, OnnxGraph, OnnxModel, OnnxNode, OnnxTensor,
1122 OnnxValueInfo,
1123};
1124
1125pub use tensorflow_compat::{
1126 load_tensorflow_model, load_tensorflow_model_with_config, SavedModel, SavedModelLoader,
1127 SavedModelMetadata,
1128};
1129
1130pub use text_generation_pipelines::{
1131 CFGDecoder, CharTokenizer, EtaSampler, GenerationCache, GreedyDecoder, MinPSampler,
1132 PipelineResult, RepetitionPenaltyProcessor, SamplingStrategy as TgpSamplingStrategy,
1133 SimpleVocab, StreamingGenerator, TemperatureScaledDecoder, TgpConfig, TgpPipelineMetrics,
1134 Tokenizer as TgpTokenizer, TypicalSamplerTgp,
1135};
1136
1137pub use data::{DataPipelineConfig, NeuralDataPipeline, NeuralTransforms, TrainingBatch};
1138pub use deployment::{
1139 conservative_pruning_config, edge_fusion_config, edge_pruning_config, edge_quantization_config,
1140 fuse_layers, mobile_fusion_config, mobile_pruning_config, mobile_quantization_config,
1141 optimize_for_deployment, prune_model, quantize_model, ultra_low_precision_config,
1142 DeploymentMetadata, DeploymentModel, FusedLayer, FusionConfig, FusionPattern, FusionStats,
1143 LayerFusion, ModelOptimizer, ModelPruner, ModelQuantizer, OptimizationConfig,
1144 OptimizationStats, PrunedLayer, PruningConfig, PruningMask, PruningScope, PruningStats,
1145 PruningStrategy, QuantizationConfig, QuantizationParams, QuantizationPrecision,
1146 QuantizationStats, QuantizationStrategy, QuantizedLayer,
1147};
1148pub use peft::{
1149 AdaLoRAAdapter, AdaLoRAConfig, AdaLoRAStats, IA3Adapter, IA3Config, IA3InitStrategy,
1150 IA3ScalingType, IA3Stats, ImportanceMetric, LoRAAdapter, LoRAConfig, LoRADense, LoRALayer,
1151 MultiIA3Adapter, MultiIA3Stats, PEFTAdapter, PEFTConfig, PEFTLayer, PEFTMethod, PEFTStats,
1152 PTuningTaskType, PTuningV2Adapter, PTuningV2Config, PTuningV2Stats, PrefixTaskType,
1153 PrefixTuningAdapter, PrefixTuningConfig, PrefixTuningStats, PromptLayerConfig, QLoRAAdapter,
1154 QLoRAConfig, QLoRAMemoryStats, QuantizationType, RankAdaptationStats, TokenPosition,
1155};
1156pub use pretrained::{
1157 BasicBlock, BottleneckBlock, EfficientNet, EfficientNetConfig, MBConvBlock,
1158 PatchEmbedding as PretrainedPatchEmbedding, ResNet, ResNetBlockType, SEBlock,
1159 VisionTransformer as PretrainedVisionTransformer,
1160};
1161
1162include!("reexports_ext.rs");