// tenflowers_neural/lib.rs

//! # TenfloweRS Neural Network Framework
//!
//! TenfloweRS Neural is a comprehensive, production-ready deep learning library built in pure Rust.
//! It provides a high-level API for building, training, and deploying neural networks with a focus
//! on safety, performance, and ease of use.
//!
//! ## Features
//!
//! - **Comprehensive Layer Library**: Dense, convolutional, recurrent, attention, normalization, and more
//! - **Advanced Training**: Gradient accumulation, mixed precision, distributed training
//! - **Modern Architectures**: Transformers, ResNet, EfficientNet, Vision Transformers, BERT, GPT
//! - **PEFT Methods**: LoRA, QLoRA, Prefix Tuning, P-Tuning v2, IA³
//! - **Optimization**: SGD, Adam, AdamW, Lion, LAMB, AdaBelief with advanced scheduling
//! - **Deployment**: Model quantization, pruning, ONNX export, mobile optimization
//! - **SciRS2 Integration**: Built on the robust SciRS2 scientific computing ecosystem
//!
//! ## Quick Start
//!
//! ### Building a Simple Neural Network
//!
//! ```rust,ignore
//! use tenflowers_neural::{Sequential, Dense, ActivationFunction};
//! use tenflowers_core::{Tensor, Device};
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Create a simple feedforward network
//! let mut model = Sequential::new();
//! model.add(Dense::new(784, 128)?);
//! model.add_activation(ActivationFunction::ReLU);
//! model.add(Dense::new(128, 10)?);
//! model.add_activation(ActivationFunction::Softmax);
//!
//! // Forward pass
//! let input = Tensor::zeros(&[32, 784]); // batch_size=32, features=784
//! let output = model.forward(&input)?;
//! # Ok(())
//! # }
//! ```
//!
//! ### Training with the High-Level API
//!
//! ```rust,ignore
//! use tenflowers_neural::{quick_train, Sequential, Dense, SGD};
//! use tenflowers_neural::loss::categorical_cross_entropy;
//! use tenflowers_core::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! // Build model
//! let mut model = Sequential::new();
//! model.add(Dense::new(10, 64)?);
//! model.add(Dense::new(64, 3)?);
//!
//! // Prepare data
//! let x_train = Tensor::zeros(&[100, 10]);
//! let y_train = Tensor::zeros(&[100, 3]);
//!
//! // Train with one line
//! let results = quick_train(
//!     model,
//!     &x_train,
//!     &y_train,
//!     Box::new(SGD::new(0.01)),
//!     categorical_cross_entropy,
//!     10, // epochs
//!     32, // batch_size
//! )?;
//! # Ok(())
//! # }
//! ```
//!
//! ### Advanced Training with Callbacks
//!
//! ```rust,ignore
//! use tenflowers_neural::{Trainer, EarlyStopping, ModelCheckpoint};
//! use tenflowers_neural::{Sequential, Dense, Adam};
//! use tenflowers_neural::loss::mse;
//! use tenflowers_core::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let model = Sequential::new();
//! let optimizer = Box::new(Adam::new(0.001));
//!
//! let mut trainer = Trainer::new(model, optimizer, mse);
//! trainer.add_callback(Box::new(EarlyStopping::new(5, 0.001)));
//! trainer.add_callback(Box::new(ModelCheckpoint::new("best_model.bin")?));
//!
//! let x_train = Tensor::zeros(&[1000, 10]);
//! let y_train = Tensor::zeros(&[1000, 1]);
//!
//! trainer.fit(&x_train, &y_train, 100, 32)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Architecture Overview
//!
//! The crate is organized into the following modules:
//!
//! - [`layers`]: Neural network layer implementations (Dense, Conv, RNN, Attention, etc.)
//! - [`model`]: Model abstractions (Sequential, Functional, custom models)
//! - [`optimizers`]: Optimization algorithms (SGD, Adam, AdamW, Lion, etc.)
//! - [`loss`]: Loss functions (MSE, cross-entropy, focal loss, etc.)
//! - [`metrics`]: Evaluation metrics (accuracy, F1, precision, recall, etc.)
//! - [`trainer`]: High-level training API with callbacks and hooks
//! - [`scheduler`]: Learning rate scheduling strategies
//! - [`distributed`]: Distributed and data-parallel training
//! - [`peft`]: Parameter-efficient fine-tuning methods
//! - [`deployment`]: Model optimization and export utilities
//! - [`pretrained`]: Pretrained model architectures and weights
//!
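//! Most of these modules appear in the examples above and below. The
//! [`scheduler`] and [`metrics`] modules are not shown elsewhere, so the
//! sketch below illustrates them; the constructor arguments and metric
//! signatures used here are illustrative and may differ slightly from the
//! current API.
//!
//! ```rust,ignore
//! use tenflowers_neural::{Sequential, Dense, StepLR};
//! use tenflowers_neural::metrics::accuracy;
//! use tenflowers_core::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut model = Sequential::new();
//! model.add(Dense::new(10, 3)?);
//!
//! // Step schedule: start at 0.01 and halve the learning rate every 10 epochs
//! // (constructor arguments are illustrative).
//! let schedule = StepLR::new(0.01, 10, 0.5);
//!
//! // Score predictions against targets with a metric from the `metrics` module.
//! let predictions = model.forward(&Tensor::zeros(&[100, 10]))?;
//! let targets = Tensor::zeros(&[100, 3]);
//! let acc = accuracy(&predictions, &targets)?;
//! # Ok(())
//! # }
//! ```
//!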
//! ## GPU Acceleration
//!
//! TenfloweRS supports GPU acceleration through the SciRS2 ecosystem. GPU operations
//! are automatically dispatched when tensors are placed on GPU devices:
//!
//! ```rust,ignore
//! use tenflowers_core::{Tensor, Device};
//! use tenflowers_neural::Dense;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # #[cfg(feature = "gpu")]
//! # {
//! let device = Device::gpu(0)?; // Use GPU 0
//! let layer = Dense::new(128, 64)?;
//! let input = Tensor::zeros(&[32, 128]).to_device(&device)?;
//! let output = layer.forward(&input)?; // Runs on GPU
//! # }
//! # Ok(())
//! # }
//! ```
//!
//! ## Mixed Precision Training
//!
//! For faster training and reduced memory usage:
//!
//! ```rust,ignore
//! use tenflowers_neural::{MixedPrecisionTrainer, Sequential, Adam};
//! use tenflowers_neural::loss::mse;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let model = Sequential::new();
//! let optimizer = Box::new(Adam::new(0.001));
//!
//! let mut trainer = MixedPrecisionTrainer::new(
//!     model,
//!     optimizer,
//!     mse,
//!     true, // enable loss scaling
//! );
//! # Ok(())
//! # }
//! ```
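//!
//! ## Gradient Accumulation
//!
//! When a full batch does not fit in memory, the [`training`] module can
//! accumulate gradients over several micro-batches before each optimizer
//! step. The sketch below is illustrative only: the helper's argument list
//! is an assumption and may differ from the current API.
//!
//! ```rust,ignore
//! use tenflowers_neural::{create_memory_efficient_trainer, Sequential, Dense, Adam};
//! use tenflowers_neural::loss::mse;
//! use tenflowers_core::Tensor;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let mut model = Sequential::new();
//! model.add(Dense::new(512, 1)?);
//!
//! // Accumulate gradients over 4 micro-batches, so an effective batch of 128
//! // only keeps 32 samples in memory at a time (arguments are illustrative).
//! let mut trainer = create_memory_efficient_trainer(
//!     model,
//!     Box::new(Adam::new(0.001)),
//!     mse,
//!     4, // accumulation steps
//! );
//!
//! let x_train = Tensor::zeros(&[1024, 512]);
//! let y_train = Tensor::zeros(&[1024, 1]);
//! trainer.fit(&x_train, &y_train, 10, 32)?;
//! # Ok(())
//! # }
//! ```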
//!
//! ## Distributed Training
//!
//! Scale training across multiple GPUs:
//!
//! ```rust,ignore
//! use tenflowers_neural::{create_data_parallel, Sequential, Dense};
//! use tenflowers_core::Device;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! # #[cfg(feature = "gpu")]
//! # {
//! let model = Sequential::new();
//! let devices = vec![Device::gpu(0)?, Device::gpu(1)?];
//! let parallel_model = create_data_parallel(model, devices)?;
//! # }
//! # Ok(())
//! # }
//! ```
//!
//! ## PEFT (Parameter-Efficient Fine-Tuning)
//!
//! Fine-tune large models efficiently:
//!
//! ```rust,ignore
//! use tenflowers_neural::peft::{LoRALayer, LoRAConfig};
//! use tenflowers_neural::Dense;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let base_layer = Dense::new(768, 768)?;
//! let lora_config = LoRAConfig {
//!     rank: 8,
//!     alpha: 16.0,
//!     dropout: 0.1,
//! };
//! let lora_layer = LoRALayer::wrap(base_layer, lora_config)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Model Deployment
//!
//! Optimize models for production:
//!
//! ```rust,ignore
//! use tenflowers_neural::deployment::{ModelOptimizer, OptimizationConfig};
//! use tenflowers_neural::Sequential;
//!
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! let model = Sequential::new();
//! let config = OptimizationConfig {
//!     quantize: true,
//!     prune_threshold: Some(0.01),
//!     fuse_operations: true,
//! };
//!
//! let optimizer = ModelOptimizer::new(config);
//! let optimized_model = optimizer.optimize(model)?;
//! # Ok(())
//! # }
//! ```
//!
//! ## Contributing
//!
//! TenfloweRS is part of the SciRS2 ecosystem. For contributions, issues, or questions,
//! please visit our GitHub repository.

#![deny(unsafe_code)]
#![allow(dead_code)]
#![allow(unused_variables)]
#![allow(unused_assignments)]
#![allow(unused_mut)]
#![allow(clippy::result_large_err)]
#![allow(clippy::needless_range_loop)]
#![allow(clippy::type_complexity)]
#![allow(clippy::vec_init_then_push)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::doc_overindented_list_items)]

pub mod activation_function;
pub mod backends;
pub mod benchmarks;
pub mod data;
pub mod deployment;
pub mod distributed;
pub mod layers;
pub mod loss;
pub mod metrics;
pub mod mixed_precision;
pub mod model;
pub mod model_parallel;
pub mod optimizers;
pub mod peft;
pub mod pipeline;
pub mod pretrained;
pub mod scheduler;
#[cfg(feature = "serialize")]
pub mod serialization;
pub mod trainer;
pub mod training;
pub mod training_pipeline;
pub mod utils;

#[cfg(feature = "onnx")]
pub mod onnx;

pub mod active_learning;
pub mod adversarial;
pub mod anomaly_detection;
pub mod tensorflow_compat;
pub mod text_generation_pipelines;
pub use anomaly_detection::{
    compute_anomaly_metrics,
    compute_auc_roc,
    compute_average_precision,
    AdDeepSvdd,
    AdDeepSvddConfig,
    // Extended f64-based anomaly detection
    AdLinear,
    AdMlp,
    AeAnomaly,
    AeAnomalyConfig,
    AnomalyError,
    AnomalyEvaluationMetrics,
    AnomalyMemory,
    AnomalyMetrics,
    AnomalyThresholder,
    AnomalyTransformerModel,
    AnomalyVAE,
    AnomalyVaeConfig,
    CouplingLayer,
    DeepSVDD,
    DeepSvddConfig,
    FlowAnomaly,
    FlowAnomalyConfig,
    GaussianMixtureAnomaly,
    IsolationForestConfig,
    IsolationNode,
    IsolationTree,
    MemAeConfig,
    MemoryAugmentedAE,
    NeuralIsolationForest,
    PatchTSAD,
    PatchTsadConfig,
    RobustRCF,
    RobustRcfConfig,
    SpectralResidual,
    VaeAnomaly,
    VaeAnomalyConfig,
};
pub mod architecture_distillation;
pub mod bayesian;
pub mod bayesian_dl;
pub use bayesian_dl::{
    compute_calibration, nll_classification, BdlLinear, BdlMlp, CalibrationResult,
    DeepEnsemble as BdlDeepEnsemble, DeepEnsembleConfig as BdlDeepEnsembleConfig,
    LaplaceApproximation, SghcmConfig, SghcmSampler, SgldConfig, SgldSampler, SwagConfig,
    SwagModel, TemperatureScaling as BdlTemperatureScaling,
};
pub mod continual_learning;
pub mod contrastive;
pub mod curriculum_learning;
pub use curriculum_learning::{
    AutoCurriculum, ClBabyStepConfig, ClBanditArm, ClCompetenceStrategy, ClConfidenceMethod,
    ClCurriculumScheduler, ClDifficultyMethod, ClDifficultyScorer, ClEpochResult, ClLambdaSchedule,
    ClReport, ClSplWeightingStrategy, ClTrainingMode, CompetenceLearning, CurriculumMetrics,
    CurriculumTrainer, DataShapley, KnnShapley, LavaValuation, SelfPacedLearning,
};
pub mod lifelong_learning;
pub use lifelong_learning::{
    AgemOptimizer,
    ClMetrics,
    ContinualLearningMetrics,
    DarkExperienceReplay,
    EpisodeMemory,
    ExperienceReplay,
    ForgettingMeasure,
    GemConstraint,
    GemTrainer,
    GenerativeReplay,
    HatMask,
    LearningWithoutForgetting,
    LllAGemModel,
    LllCoPE,
    LllDer,
    LllDerEntry,
    LllER,
    LllErSample,
    LllGemModel,
    LllHAT,
    // Lll-prefixed extensions
    LllLinearLayer,
    LllMetrics,
    LllReport,
    LllTaskOracle,
    ModularNetwork,
    PackNetMasker,
    ProgressiveGrowth,
    ReplayBuffer,
    SynapticIntelligence,
    OWM,
};
pub mod lm_evaluation;
pub use lm_evaluation::{
    LmeBERTScore,
    // §4 BERTScore
    LmeBERTScoreResult,
    LmeBLEU,
    // §2 BLEU
    LmeBLEUResult,
    // §D Calibration
    LmeCalibration,
    // §C ChrF
    LmeChrF,
    LmeCodeEval,
    // §9 Code evaluation
    LmeCodeResult,
    // §A Distinct-N
    LmeDistinctN,
    LmeFewShotConfig,
    LmeFewShotEval,
    // §6 Few-shot evaluation
    LmeFewShotExample,
    LmeHarmEval,
    // §7 Harm evaluation
    LmeHarmResult,
    // §8 Truthfulness evaluation
    LmeMC1Result,
    LmeMC2Result,
    LmeMathEval,
    // §5 Math evaluation
    LmeMathResult,
    // §B METEOR
    LmeMeteor,
    LmePerplexity,
    // §1 Perplexity
    LmePerplexityConfig,
    LmeROUGE,
    // §3 ROUGE
    LmeROUGEScore,
    // §10 Report
    LmeReport,
    LmeTruthfulnessEval,
    // §E WER
    LmeWER,
};
pub mod differentiable_physics;
pub mod diffusion;
pub mod distillation;
pub mod document_understanding;
pub use document_understanding::{
    create_spans,
    detect_table_structure,
    exact_match,
    extract_answer,
    extractive_summarize,
    f1_answer,
    // DocumentMetrics
    f1_score_ner,
    merge_lines,
    mmr_summarize,
    normalize_bbox,
    parse_form,
    rouge_n,
    sort_reading_order,
    table_to_csv,
    table_to_json,
    validate_field,
    BboxEmbedding,
    // DocumentClassifier
    DocClass,
    DocClassifier,
    DocEmbedder,
    DocEmbedderConfig,
    DocEntity,
    // QuestionAnsweringDoc
    DocQaConfig,
    // InformationExtractor
    EntityType,
    // FormParser
    FieldType,
    FormField,
    InformationExtractor,
    InputSpan,
    LayoutLm,
    // LayoutLM
    LayoutLmConfig,
    LayoutLmLayer,
    // DocumentOCR
    OcrBox,
    // DocumentEmbedder
    PoolingStrategy,
    // DocumentSummarizer
    SentenceScorer,
    Table,
    // TableExtractor
    TableCell,
};
pub mod ensemble;
pub mod federated;
pub mod flows;
pub mod hierarchical_time_series;
pub mod hparam;
pub mod hyperdimensional;
pub use hyperdimensional::{
    bind_binary, bind_bipolar, bundle_binary, bundle_bipolar, compute_hdc_stats,
    orthogonality_test, permute_binary, permute_bipolar, unbind_binary, unbind_bipolar, BinaryHv,
    BipolarHv, HdClassifier, HdSequenceEncoder, HdcError, HdcStats, HvType, IdEncoder, ItemMemory,
    LevelEncoder, OnlineHdc, RealHv, SdmConfig, SparseSdm, ThermometerEncoder, HD_DIM,
};
pub mod hyperparameter_optimization;
pub use hyperparameter_optimization::{
    best_trial,
    hypervolume_contribution,
    importance_by_fanova,
    nsga2_select,
    pareto_front,
    // §4 BOHB
    BohbConfig,
    BohbOptimizer,
    // §6 CMA-ES
    CmaHpoConfig,
    CmaHpoState,
    EvolutionaryStrategy,
    // §2 GP surrogate & acquisition
    GpHpo,
    HbBracket,
    HpSpace,
    // §1 HpSpace
    HpType,
    HpoAcqFunction,
    HpoBayesianConfig,
    HpoBayesianOptimizer,
    HpoLogger,
    HpoStudy,
    HpoTrial,
    // §3 HyperBand
    HyperBandConfig,
    HyperBandScheduler,
    HyperParameter,
    KdeSampler,
    // §8 Early Termination
    MedianStopping,
    // §7 Multi-Objective
    MoHpoConfig,
    MoObservation,
    MultiObjectiveHpo,
    OptDirection,
    // §5 PBT
    PbtConfig,
    PbtMember,
    PbtPopulation,
    PercentileStop,
    PopulationBasedTraining,
    // §10 Warm Starting
    PreviousStudy,
    SuccessiveHalving,
    // §9 HpoStudy / Logger
    TrialStatus,
    WarmStartSampler,
};
pub mod lora_adapters;
pub use lora_adapters::{
    AdaLoraLayer, BitFit, BitFitMask, DoraLayer, Ia3Layer, LoftqInit, LoftqResult, LoraLayer,
    LoraPlus, PeftManager, PrefixTuning, PromptTuning, SingularValueImportance,
};
pub mod learning_to_learn;
pub use learning_to_learn::{
    compute_linear_mse_grad, GradientPreprocessor, L2lAttnBlock, L2lError, L2lHiddenState,
    L2lLstmCell, L2lMetrics, L2lScheduler, L2lTask, L2lTcBlock, L2lTensor, LstmMetaOptimizer,
    MetaDataset, MetaLearningTrainer, MetaSampledTask, MetaTaskType, OptimizerNetwork, SnailModel,
    WarmStartOptimizer,
};
pub mod lr_finder;
pub mod meta_learning;
pub mod model_utils;
pub mod music_generation;
pub use music_generation::{
    generate_drum_pattern,
    AdsrEnvelope,
    AudioSynthesizer,
    // §6 Rhythm
    BjorklundPattern,
    // §4 Chord Detector
    Chord,
    ChordDetection,
    ChordDetector,
    DrumStyle,
    // §1 MIDI
    EventType,
    GenerationConfig as MelodyGenerationConfig,
    MelodyGenerator,
    MidiSequence,
    // §3 MuseTransformer
    MuseTransformer,
    // §10 Metrics
    MusicEvalReport,
    MusicMetrics,
    MusicTheoryAnalyzer,
    // §8 Variation Generator
    MusicVariationGenerator,
    NoteEvent,
    // §2 Piano Roll
    PianoRollEncoder,
    RhythmGrid,
    // §5 Melody Generator
    ScaleType,
    SynthConfig,
    // §9 Theory Analyzer
    VoiceLeadingAnalysis,
    // §7 Audio Synthesizer
    WaveType,
};
pub mod multi_objective;
pub use multi_objective::{
    AugmentedLagrangian, CaGrad, GeneticOperators, GradNormOptimizer, HypervolumeIndicator, ImtlG,
    InteriorPointMethod, LinearScalarization, Moead, MooHeads, NsgaIii, PFLLayer, ParetoFront,
    PcGrad, PenaltyMethod, ProjectedGradient, R2Indicator, ReferencePointSampler,
    SelectionOperator, TchebycheffScalarization, UncertaintyWeighting,
};
pub mod multi_fidelity;
pub mod multi_task;
pub mod neural_ode;
pub mod neural_sde;
pub use neural_sde::{
    compute_sde_metrics,
    shuffle_product,
    // advanced
    AncestralSampler,
    CdeVectorField,
    ControlledSde,
    LatentSde,
    LatentSdeConfig,
    LogSignatureLayer,
    NaturalCubicSpline,
    NeuralCde,
    NeuralCdeConfig,
    NeuralRde,
    NsdeMlp,
    PathSignature,
    RoughPath,
    ScoreMatchingSde,
    SdeDecoder,
    SdeDiffusionNet,
    SdeDriftNet,
    SdeEncoder,
    SdeMetrics,
    SdeMetricsExtended,
    SdeTrainer,
    SignatureConfig,
    SignatureKernel,
    SignatureTransform,
    VeSde,
    VpSde,
};
pub mod energy_models;
pub mod quantization;
pub mod rl;
pub mod self_supervised;
pub mod signal;
pub mod speech_recognition;
pub use speech_recognition::{
    cer,
    // Metrics
    edit_distance,
    // Spectrogram
    log_mel_spectrogram,
    wer,
    AsrMetrics,
    // Encoder / Decoder
    AudioConvStem,
    // CTC
    BeamEntry,
    CtcBeamDecoder,
    CtcConfig,
    JointNetwork,
    LmRescorer,
    // LM
    NgramLm,
    // RNN-T
    PredictionNetwork,
    RnntDecoder,
    SpeakerDiarizer,
    // Diarization
    SpectralClustering,
    // Augmentation
    SpeechAugmentation,
    // Pipeline
    SpeechPipeline,
    VadConfig,
    // VAD
    VoiceActivityDetector,
    // Config types
    WhisperConfig,
    WhisperDecoder,
    WhisperDecoderLayer,
    WhisperEncoder,
    WhisperEncoderLayer,
};
pub mod simulation_based_inference;
pub use simulation_based_inference::{
    c2st_accuracy,
    simulation_based_calibration,
    tarp_test,
    // Advanced: ABC
    AbcRejection,
    AbcSmcSampler,
    // Advanced: SBI diagnostics
    ExpectedCoveragePlot,
    FlowPosteriorSampler,
    FlowSbiTrainer,
    GaussianSimulator,
    LocalPredictivePerformance,
    // MADE / NDE
    MadeLayer,
    NeuralDensityEstimator,
    NeuralLikelihood,
    // SNRE
    NeuralRatioEstimator,
    NlePosteriorSampler,
    NreRatioEstimator,
    RoundSummary,
    // Diagnostics
    SbcResult,
    // Advanced: NRE / NLE posterior samplers
    SbiClassifier,
    SbiExtendedReport,
    // Advanced: Normalizing Flow SBI
    SbiNormalizingFlow,
    SbiReport,
    SequentialNpe,
    // Simulator trait + Gaussian toy
    Simulator,
    // SNLE
    SnleConfig,
    SnleEstimator,
    SnlePosterior,
    // SNPE
    SnpeConfig,
    SnpePosterior,
    SummaryStatistics,
    TarpResult,
};

pub mod simulation_ml;
pub use simulation_ml::{
    AdaptiveSampler, DenseLayer, EnsembleSurrogate, GpSurrogate, LatinHypercubeSampler,
    MlCorrectionModel, NeuralSurrogate, PhysicsResidual, PiSurrogate, RbfKernel,
    ReynoldsStressTensor, SimError, SimulationDataAugmenter, SimulationMetrics, SmagorinskyModel,
    Spring, SpringMassSystem, SurrogateActivation, SurrogateConfig,
};
pub mod spectral;
pub mod ssm;
pub mod structured_prediction;
pub use structured_prediction::{
    ctc_loss,
    label_smoothing_loss,
    log_sum_exp as sp_log_sum_exp,
    ordered_prediction_loss,
    // §6 Loss functions
    sequence_cross_entropy,
    softmax as sp_softmax,
    BeliefPropagation,
    ConstituencyConfig,
    ConstituencyParser,
    ConstrainedDecoding,
    // Advanced: Graph-based structured prediction
    DependencyParser,
    // §4 Energy-Based Model
    EnergyNetConfig,
    EnergyNetwork,
    // §5 Belief Propagation
    FactorGraph,
    HammingLoss,
    // §1 Linear-Chain CRF
    LinearChainCrfConfig,
    // Advanced: Neural CRF++
    NeuralCrfConfig,
    NeuralCrfLayer,
    PartialCrfLoss,
    SecondOrderCrf,
    // §2 Second-Order CRF
    SecondOrderCrfConfig,
    SemanticRoleLabeler,
    // §0 Shared utilities
    SpLinear,
    SpLinearChainCrf,
    // Advanced: Metrics
    SpStructuredMetrics,
    // Advanced: Span-based models
    Span,
    SpanClassifier,
    SpanExtractor,
    SrlAnnotation,
    // §3 Structured SVM
    SsvmConfig,
    StructuredLoss,
    StructuredSvm,
};
pub mod symbolic_math;
pub use symbolic_math::{
    balance,
    crossover as sym_crossover,
    derivative as sym_derivative,
    dim_divide,
    dim_multiply,
    eval as sym_eval,
    evolve as sym_evolve,
    find_pi_groups,
    frobenius_derivative,
    init_population as sym_init_population,
    integrate as sym_integrate,
    is_dimensionless,
    mutate as sym_mutate,
    // Equation balancer
    parse_equation,
    parse_token_sequence,
    poly_add,
    poly_div_rem,
    poly_evaluate,
    poly_gcd,
    poly_mul,
    poly_sub,
    roots_companion_matrix,
    simplify as sym_simplify,
    tournament_select as sym_tournament_select,
    trace_derivative,
    // Dimensional analysis
    Dimension,
    // Expression tree
    Expr,
    // Regressor
    ExprBasis,
    // Synthesizer
    ExprToken,
    // Hasher
    ExpressionHasher,
    GradientSymbolicRegressor,
    // Genetic programming
    Individual,
    // Matrix calculus
    MatrixExpr,
    NeuralExpressionSynthesizer,
    // Polynomial arithmetic
    Polynomial,
    SynthesizerConfig,
};
pub mod tabular_learning;
pub use tabular_learning::{
    AttentiveTransformer, CatBoostEncoder, CyclicEncoder, FTTransformer, FTTransformerConfig,
    FeatureEncoder, MinMaxScaler, MixedInputHead, NodeModel, ObliviousTree, QuantileTransformer,
    SaintBlock, SaintModel, StandardScaler, TabNet, TabNetConfig, TabTransformer,
    TabTransformerConfig, TabularAugmentation, TabularMetrics,
};
pub mod time_series;
pub mod tokenizer;
pub mod vae;
pub mod variational_inference;
pub use variational_inference::{
    compute_pareto_k,
    diagnose_vi,
    effective_sample_size,
    vi_log_sum_exp,
    AdviModel,
    AdviVariable,
    // §4 BBVI
    BbviConfig,
    BbviResult,
    BlackBoxVi,
    FlowVi,
    // §3 FullRankGaussian
    FullRankGaussian,
    // §2 MeanFieldGaussian
    MeanFieldGaussian,
    // §5 ADVI
    ParameterConstraint,
    // §7 FlowVi / PlanarFlow
    PlanarFlowLayer,
    // §6 StructuredVi
    StructuredVi,
    // §9 Diagnostics
    ViDiagnostics,
    // §1 Variational family enum
    ViDistribution,
    ViSvgd,
    // §8 ViSvgd (separate from monte_carlo::SvgdOptimizer)
    ViSvgdConfig,
};
pub mod graph_generation;
pub mod graph_matching;
pub mod graph_transformer;
pub use graph_matching::{
    GmEdge, GmEditCostModel, GmEditOp, GmGraph, GmMcsResult, GmReport, GmSoftAssignment,
    GraduatedAssignment, GraphEditDistance, GraphMatchMetrics, MaxCommonSubgraph, RandomWalkKernel,
    ShortestPathKernel, SpectralAlignment, Vf2Matcher, WeisfeilerLemanKernel,
};
pub mod automl;
pub mod nas;
pub use automl::{
    // meta_features
    AlgorithmPerformancePredictor,
    AlgorithmSelector,
    ArchitectureBank,
    ArchitectureDecoder,
    ArchitectureEnsemble,
    ArchitecturePredictor,
    AutoFeaturePipeline,
    AutoMlPipeline,
    AutoMlReport,
    AutoNormalizer,
    BayesianOptimizer as AutomlBayesianOptimizer,
    CellBasedNas,
    CellOp,
    Config as AutomlConfig,
    ConfigSpace,
    DatasetMetaFeatures,
    DistributionType,
    EarlyStoppingRule,
    EditDistance,
    EfficientNasPredictor,
    FeatureInteractionSearch,
    FeatureSelector,
    GradNormScore,
    GradientBasedNas,
    GraphEncoding,
    HyperBand,
    HyperBandBracket,
    JacobianScore,
    LandmarkingFeatures,
    MetaFeatureNormalizer,
    MfHpType,
    MultiObjectiveOptimizer,
    NaswotScore,
    PathEncoding,
    PipelineOptimizer,
    PipelineStep,
    PolynomialFeatures,
    PortfolioSelector,
    ProxylessNas,
    SinglePathOneShot,
    SmacOptimizer,
    SupernetLayer,
    SupernetOp,
    SynflowScore,
    TpeSampler,
    TransferNasFeatures,
    TrialMetric,
    ZenScore,
};
pub mod robotics;
pub use robotics::{
    AStarPlanner, BehavioralCloning, DaggerPolicy, DemonstrationBuffer, DhParam, GailDiscriminator,
    GraspQuality, MotionPrimitive, OccupancyGrid, ParticleFilter, PotentialFieldNavigator,
    RecurrentStateSpaceModel, RewardPredictor, RobotKinematics, WorldModelDecoder,
    WorldModelEncoder,
};
pub mod riemannian_geometry;
pub use riemannian_geometry::{
    mat_inv_nn,
    // Linear algebra helpers
    mat_mul_nn,
    matrix_exp_sym,
    matrix_log_sym,
    symmetrize_nn,
    // Fréchet mean
    FrechetMean as RgFrechetMean,
    // Grassmann manifold
    RgGrassmannManifold,
    // SO(3)
    RgSo3Manifold,
    // SPD manifold
    RgSpdManifold,
    // Stiefel manifold
    RgStiefelManifold,
    RiemannianAdam,
    // Optimizers
    RiemannianAdamConfig,
    // Batch normalization
    RiemannianBatchNorm,
    // Core trait
    RiemannianManifold,
    RiemannianSgd,
};
pub mod causal_inference;
pub mod causal_representation;
pub mod memory_networks;
pub mod pinn;
pub mod point_processes;
pub use causal_representation::{
    compute_mig, compute_modularity, compute_sap, CrLinear, CrMlp, DeepScm, DisentanglementMetrics,
    DscmConfig, DscmMechanism, FactorVae, FactorVaeConfig, IvaeConfig, IvaeDecoder, IvaeEncoder,
    IvaeModel, IvaePrior, NonlinearIca, SlowIcaConfig, TcDiscriminator, TcVae, TcVaeConfig,
    TcVaeDecoder, TcVaeEncoder,
};
pub mod bayesian_opt;
pub mod multimodal;
pub mod probabilistic;
pub use probabilistic::{
    nig_loss, ActNorm as ProbActNorm, CalibrationEvaluator, DeepEnsemble, DirichletOutput,
    EnsembleMember, EvidentialClassLayer, EvidentialLayer, Flow, FlowModel, FlowResult,
    GaussianProcess as ProbGaussianProcess, GpKernel, GpPrediction, IsotonicCalibrator, NigOutput,
    PlattScaling, RealNvpCoupling, SnapshotEnsemble, TemperatureScaling,
};
pub mod reward_learning;
pub use reward_learning::{
    compute_returns, cross_entropy_binary, kendall_tau, sigmoid, BradleyTerryModel, Comparison,
    CuriosityShaper, GoalRewardShaper, IrlConfig, MaxCausalEntIrl, MaxEntIrl, PreferenceDataset,
    RewardModel, RewardModelConfig, RewardModelMetrics, RewardShaper, RlhfConfig, RlhfResult,
    RlhfTrainer,
};

pub use bayesian_opt::{
    AcquisitionFunction, AcquisitionResult, BayesOptConfig, BayesOptResult, BayesSearchSpace,
    BayesianOptimizer, CmaEs, CmaEsConfig, CmaEsResult, CmaEsState, FidelityLevel, GaussianProcess,
    GpConfig, KernelType, MultiFidelityConfig, MultiFidelityOptimizer, ObservationRecord,
};

pub use graph_transformer::{
    layer_norm as gt_layer_norm, matmul as gt_matmul, softmax as gt_softmax, ChebNetLayer,
    GatConfig, GraphAttentionTransformer, GraphTransformerError, GraphTransformerLayer,
    LaplacianPositionalEncoding, PosEncodingType, RandomWalkPositionalEncoding, ReadoutType,
};

pub use graph_generation::{
    AtomType, BondType, GraphGenError, GraphPropertyPredictor, GraphRnn, GraphRnnConfig,
    MolecularFingerprint, MolecularGraph, MpnnConfig, MpnnLayer, VgaeConfig, VgaeEncoder,
};

pub use nas::{
    decode_architecture_string,
    encode_architecture_string,
    network_stats,
    // Legacy NAS components
    AgingEvolutionNas,
    ArchitectureEvaluator,
    CellConfig,
    CellEdge,
    CellEncoding,
    DartsCell,
    DartsConfig,
    DartsOptimizer,
    DartsState,
    EvoNasConfig,
    EvolutionResult,
    EvolutionaryNas,
    GumbelSoftmax,
    LotteryTicketPruner,
    MixedOp,
    NasLogger,
    NasSummary,
    NetworkEncoding,
    NetworkStats,
    NodeConfig,
    OneShotConfig,
    OneShotNas,
    // New comprehensive NAS components
    OpType,
    OpsChoice,
    RandomNasConfig,
    RandomNasSearch,
    RandomSearchNas,
    SearchSpace,
    TicketState,
};

pub use pinn::{
    BoundaryCondition, BurgersEquation, CollocationSampler, HeatEquation, LossComponents,
    NumericalGradient, PdeResidual, PinnActivation, PinnConfig, PinnLoss, PinnNetwork,
    PinnSolution, PinnTrainer, PoissonEquation, WaveEquation,
};

pub use activation_function::ActivationFunction;
pub use benchmarks::{
    compare_models, BenchmarkConfig, BenchmarkMetrics, BenchmarkResults, ModelBenchmark,
};
pub use distributed::{
    models::utils::{create_data_parallel, create_distributed_data_parallel, init_process_group},
    models::{DDPConfig, DataParallel, DistributedDataParallel, SynchronizationMode},
    BackendConfig, CollectiveOp, CollectiveResult, CommunicationBackend, CommunicationBackendImpl,
    CommunicationGroup, CommunicationMetrics, CommunicationRuntime, CompressionAlgorithm,
    ReductionOp,
};
pub use layers::{
    compute_slopes, naive_attention, scaled_dot_product_attention, AlibiAttention, AlibiMask,
    AlibiSlopes, BahdanauAttention, Conv2D, Dense, Dropout, FlashAttention, FlashConfig, KVCache,
    Layer, LuongAttention, MultiHeadAttention, OnlineSoftmax, RMSNorm, RopeConfig, RopeEmbedding,
    RotaryInterpolation, TransformerDecoder, TransformerEncoder, GRU, LSTM, RNN,
};
pub use loss::{
    advanced_knowledge_distillation_loss, binary_cross_entropy, categorical_cross_entropy,
    focal_loss, hinge_loss, huber_loss, knowledge_distillation_loss, mse, quantile_loss,
    sparse_categorical_cross_entropy,
};
pub use metrics::{
    accuracy, confusion_matrix, f1_score, mean_absolute_percentage_error, precision, r_squared,
    recall, top_k_accuracy,
};
pub use mixed_precision::MixedPrecisionTrainer;
pub use model::{
    FunctionalModel, FunctionalModelBuilder, Input, Model, Node, Sequential, SharedLayer,
};
pub use model_parallel::{
    CommunicationPattern, MemoryRequirements, ModelParallelConfig, ModelParallelCoordinator,
    ParallelLayer, PipelineConfig, PlacementStrategy, SplitLayer, TensorParallelConfig,
};
pub use optimizers::{
    clip_gradients_adaptive,
    clip_gradients_by_global_norm,
    clip_gradients_by_norm,
    clip_gradients_by_value,
    AdaBelief,
    Adadelta,
    Adagrad,
    Adam,
    AdamW,
    AnnealStrategy,
    CosineAnnealingScheduler,
    ExponentialDecayScheduler,
    LambConfig,
    LambOptimizer,
    LinearScheduler,
    Lion,
    // New flat vec-based optimizers
    LionConfig,
    LionOptimizer,
    Lookahead,
    // New stateful LR schedulers
    LrScheduler,
    MetricMode,
    MuonConfig,
    MuonOptimizer,
    Nadam,
    OneCycleLrScheduler,
    Optimizer,
    ParameterGroup,
    ParameterGroupOptimizer,
    PolynomialDecayScheduler,
    RAdam,
    RMSprop,
    SchedReduceLrOnPlateau,
    WarmupScheduler,
    LAMB,
    SGD,
};
pub use pipeline::{MicroBatch, PipelineMetrics, PipelineModelBuilder, PipelineParallelModel};
pub use scheduler::{
    ConstantLR, CosineAnnealingLR, ExponentialLR, LearningRateScheduler, PolynomialLR,
    ReduceLROnPlateau, StepLR, WarmupCosineDecayLR,
};
pub use trainer::{
    Callback, EarlyStopping, LearningRateReduction, ModelCheckpoint, Trainer, TrainingMetrics,
    TrainingState,
};
pub use training::{
    create_distillation_trainer, create_distillation_trainer_with_temperature,
    create_memory_efficient_trainer, create_trainer_for_large_model, AccumulationTrainingConfig,
    DistillationConfig, DistillationMetrics, DistillationTrainer, DistillationTrainerBuilder,
    GradientAccumulationTrainer, TrainingStats,
};
pub use training_pipeline::{
    quick_train, TrainingPipeline, TrainingPipelineConfig, TrainingResults,
};

pub use trainer::TensorboardCallback;

#[cfg(feature = "onnx")]
pub use onnx::{
    OnnxAttribute, OnnxDataType, OnnxExport, OnnxGraph, OnnxModel, OnnxNode, OnnxTensor,
    OnnxValueInfo,
};

pub use tensorflow_compat::{
    load_tensorflow_model, load_tensorflow_model_with_config, SavedModel, SavedModelLoader,
    SavedModelMetadata,
};

pub use text_generation_pipelines::{
    CFGDecoder, CharTokenizer, EtaSampler, GenerationCache, GreedyDecoder, MinPSampler,
    PipelineResult, RepetitionPenaltyProcessor, SamplingStrategy as TgpSamplingStrategy,
    SimpleVocab, StreamingGenerator, TemperatureScaledDecoder, TgpConfig, TgpPipelineMetrics,
    Tokenizer as TgpTokenizer, TypicalSamplerTgp,
};

pub use data::{DataPipelineConfig, NeuralDataPipeline, NeuralTransforms, TrainingBatch};
pub use deployment::{
    conservative_pruning_config, edge_fusion_config, edge_pruning_config, edge_quantization_config,
    fuse_layers, mobile_fusion_config, mobile_pruning_config, mobile_quantization_config,
    optimize_for_deployment, prune_model, quantize_model, ultra_low_precision_config,
    DeploymentMetadata, DeploymentModel, FusedLayer, FusionConfig, FusionPattern, FusionStats,
    LayerFusion, ModelOptimizer, ModelPruner, ModelQuantizer, OptimizationConfig,
    OptimizationStats, PrunedLayer, PruningConfig, PruningMask, PruningScope, PruningStats,
    PruningStrategy, QuantizationConfig, QuantizationParams, QuantizationPrecision,
    QuantizationStats, QuantizationStrategy, QuantizedLayer,
};
pub use peft::{
    AdaLoRAAdapter, AdaLoRAConfig, AdaLoRAStats, IA3Adapter, IA3Config, IA3InitStrategy,
    IA3ScalingType, IA3Stats, ImportanceMetric, LoRAAdapter, LoRAConfig, LoRADense, LoRALayer,
    MultiIA3Adapter, MultiIA3Stats, PEFTAdapter, PEFTConfig, PEFTLayer, PEFTMethod, PEFTStats,
    PTuningTaskType, PTuningV2Adapter, PTuningV2Config, PTuningV2Stats, PrefixTaskType,
    PrefixTuningAdapter, PrefixTuningConfig, PrefixTuningStats, PromptLayerConfig, QLoRAAdapter,
    QLoRAConfig, QLoRAMemoryStats, QuantizationType, RankAdaptationStats, TokenPosition,
};
pub use pretrained::{
    BasicBlock, BottleneckBlock, EfficientNet, EfficientNetConfig, MBConvBlock,
    PatchEmbedding as PretrainedPatchEmbedding, ResNet, ResNetBlockType, SEBlock,
    VisionTransformer as PretrainedVisionTransformer,
};

// Additional module declarations and re-exports are in reexports_ext.rs
// to keep lib.rs under the 2000-line policy limit.
include!("reexports_ext.rs");