// Crate-wide Clippy lint suppressions. Every attribute below is an inner
// attribute (`#![...]`), so it applies to the whole crate, including every
// module declared in this file.
// NOTE(review): the reason each lint is allowed is not recorded here — confirm
// which suppressions are still required and narrow them to the offending
// items where possible.
#![allow(clippy::needless_range_loop)]
#![allow(clippy::useless_vec)]
#![allow(clippy::redundant_locals)]
#![allow(clippy::len_without_is_empty)]
#![allow(clippy::await_holding_lock)]
#![allow(clippy::if_same_then_else)]
#![allow(clippy::derivable_impls)]
#![allow(clippy::wrong_self_convention)]
#![allow(clippy::same_item_push)]
#![allow(clippy::vec_init_then_push)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::result_large_err)]
#![allow(clippy::excessive_nesting)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::type_complexity)]
#![allow(clippy::large_enum_variant)]
136
137#[cfg(feature = "bert")]
138pub mod bert;
139
140#[cfg(feature = "roberta")]
141pub mod roberta;
142
143#[cfg(feature = "distilbert")]
144pub mod distilbert;
145
146#[cfg(feature = "gpt2")]
147pub mod gpt2;
148
149#[cfg(feature = "gpt_neo")]
150pub mod gpt_neo;
151
152#[cfg(feature = "gpt_j")]
153pub mod gpt_j;
154
155#[cfg(feature = "t5")]
156pub mod t5;
157
158#[cfg(feature = "albert")]
159pub mod albert;
160
161#[cfg(feature = "electra")]
162pub mod electra;
163
164#[cfg(feature = "deberta")]
165pub mod deberta;
166
167#[cfg(feature = "vit")]
168pub mod vit;
169
170#[cfg(feature = "llama")]
171pub mod llama;
172
173#[cfg(feature = "gpt_neox")]
174pub mod gpt_neox;
175
176#[cfg(feature = "mistral")]
177pub mod mistral;
178
179#[cfg(feature = "clip")]
180pub mod clip;
181pub mod cogvlm;
182pub mod recursive;
183
184#[cfg(feature = "blip2")]
186pub mod blip2;
187
188#[cfg(feature = "llava")]
189pub mod llava;
190
191#[cfg(feature = "dalle")]
192pub mod dalle;
193
194#[cfg(feature = "flamingo")]
195pub mod flamingo;
196
197#[cfg(feature = "gemma")]
198pub mod gemma;
199
200#[cfg(feature = "qwen")]
201pub mod qwen;
202
203#[cfg(feature = "phi3")]
204pub mod phi3;
205
206pub mod hyena;
208pub mod mamba;
209pub mod retnet;
210pub mod rwkv;
211pub mod s4;
212
213pub mod falcon;
215pub mod stablelm;
216
217pub mod command_r;
219
220pub mod claude;
222
223pub mod moe;
225
226pub mod fnet;
228pub mod linformer;
229pub mod performer;
230
231pub mod sparse_attention;
233
234pub mod cross_attention;
236
237pub mod hierarchical;
239
240pub mod advanced_quantization;
242pub mod ring_attention;
243pub mod weight_loading;
244
245pub mod generation_utils;
247
248pub mod batch_inference;
250
251pub mod dynamic_pruning;
253
254pub mod knowledge_distillation;
256
257pub mod model_compression;
259
260pub mod continual_learning;
262
263pub mod curriculum_learning;
265
266pub mod multi_task_learning;
268
269pub mod progressive_training;
271
272pub mod meta_learning;
274
275#[cfg(feature = "llama")]
277pub mod code_specialized;
278
279#[cfg(feature = "llama")]
281pub mod math_specialized;
282
283pub mod scientific_specialized;
285
286pub mod legal_medical_specialized;
288
289pub mod creative_writing_specialized;
291
292pub mod common_patterns;
294
295pub mod comprehensive_testing;
297
298pub mod model_cards;
300
301pub mod neural_architecture_search;
303
304pub mod automated_model_design;
306
307pub mod hybrid_architectures;
309
310pub mod memory_profiling;
312
313pub mod error_recovery;
315
316pub mod mixed_bit_quantization;
318
319pub mod performance_optimization;
321
322pub mod model_serving;
324
325pub mod xlstm;
327
328pub mod biologically_inspired;
330
331pub mod quantum_classical_hybrids;
333
334pub mod benchmarking;
336
337pub mod numerical_parity_tests;
339
340pub mod developer_tools;
342
343#[cfg(feature = "bert")]
344pub use bert::{BertConfig, BertForMaskedLM, BertForSequenceClassification, BertModel};
345
346#[cfg(feature = "roberta")]
347pub use roberta::{
348 RobertaConfig, RobertaForMaskedLM, RobertaForQuestionAnswering,
349 RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel,
350};
351
352#[cfg(feature = "distilbert")]
353pub use distilbert::{
354 DistilBertConfig, DistilBertForMaskedLM, DistilBertForQuestionAnswering,
355 DistilBertForSequenceClassification, DistilBertForTokenClassification, DistilBertModel,
356};
357
358#[cfg(feature = "gpt2")]
359pub use gpt2::{Gpt2Config, Gpt2LMHeadModel, Gpt2Model};
360
361#[cfg(feature = "gpt_neo")]
362pub use gpt_neo::{GptNeoConfig, GptNeoLMHeadModel, GptNeoModel};
363
364#[cfg(feature = "gpt_j")]
365pub use gpt_j::{GptJConfig, GptJLMHeadModel, GptJModel};
366
367#[cfg(feature = "t5")]
368pub use t5::{T5Config, T5ForConditionalGeneration, T5Model};
369
370#[cfg(feature = "albert")]
371pub use albert::{
372 AlbertConfig, AlbertForMaskedLM, AlbertForQuestionAnswering, AlbertForSequenceClassification,
373 AlbertForTokenClassification, AlbertModel,
374};
375
376#[cfg(feature = "electra")]
377pub use electra::{
378 ElectraConfig, ElectraForMultipleChoice, ElectraForPreTraining, ElectraForQuestionAnswering,
379 ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel,
380};
381
382#[cfg(feature = "deberta")]
383pub use deberta::{
384 DebertaConfig, DebertaForMaskedLM, DebertaForMultipleChoice, DebertaForQuestionAnswering,
385 DebertaForSequenceClassification, DebertaForTokenClassification, DebertaModel,
386};
387
388#[cfg(feature = "vit")]
389pub use vit::{ViTConfig, ViTForImageClassification, ViTModel};
390
391#[cfg(feature = "llama")]
392pub use llama::{LlamaConfig, LlamaForCausalLM, LlamaModel};
393
394#[cfg(feature = "gpt_neox")]
395pub use gpt_neox::{GPTNeoXConfig, GPTNeoXForCausalLM, GPTNeoXModel};
396
397#[cfg(feature = "mistral")]
398pub use mistral::{MistralConfig, MistralForCausalLM, MistralModel};
399
400#[cfg(feature = "clip")]
401pub use clip::{CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig};
402
403#[cfg(feature = "blip2")]
404pub use blip2::{
405 Blip2ConditionalGenerationOutput, Blip2Config, Blip2ForConditionalGeneration, Blip2Model,
406 Blip2Output, Blip2QFormerConfig, Blip2QFormerModel, Blip2QFormerOutput, Blip2TextConfig,
407 Blip2VisionConfig, Blip2VisionModel, LanguageModelOutput,
408};
409
410#[cfg(feature = "llava")]
411pub use llava::{LlavaConfig, LlavaForConditionalGeneration, LlavaVisionConfig};
412
413#[cfg(feature = "dalle")]
414pub use dalle::{
415 DalleConfig, DalleDiffusionConfig, DalleImageConfig, DalleImageEncoder, DalleMLP, DalleModel,
416 DalleModelOutput, DalleTextConfig, DalleTextEncoder, DalleTimeEmbedding, DalleUNet, DalleVAE,
417 DalleVisionConfig,
418};
419
420#[cfg(feature = "flamingo")]
421pub use flamingo::{
422 FlamingoConfig, FlamingoLanguageConfig, FlamingoLanguageModel, FlamingoLanguageOutput,
423 FlamingoModel, FlamingoOutput, FlamingoPerceiverConfig, FlamingoVisionConfig,
424 FlamingoVisionEncoder, FlamingoXAttentionConfig, PerceiverResampler,
425};
426
427#[cfg(feature = "gemma")]
428pub use gemma::{GemmaConfig, GemmaForCausalLM, GemmaModel};
429
430#[cfg(feature = "qwen")]
431pub use qwen::{QwenConfig, QwenForCausalLM, QwenModel};
432
433#[cfg(feature = "phi3")]
434pub use phi3::{Phi3Config, Phi3ForCausalLM, Phi3Model};
435
436pub use automated_model_design::{
437 ArchitectureTemplate, ConstraintSolver, DeploymentEnvironment, DesignPatternLibrary,
438 DesignRequirements, DesignRequirementsBuilder, Modality, ModelDesign, ModelDesignMetadata,
439 ModelDesigner, ModelMetrics, PerformanceTarget, ResourceConstraints,
440 TaskType as DesignTaskType, TemplateMetadata,
441};
442pub use claude::{ClaudeConfig, ClaudeForCausalLM, ClaudeModel};
443#[cfg(feature = "llama")]
444pub use code_specialized::{
445 CodeLlamaConfig, CodeLlamaForCausalLM, CodeLlamaModel, CodeModelVariant, CodeSpecialTokens,
446 CodeSpecializedConfig, CodeSpecializedForCausalLM, CodeSpecializedModel, DeepSeekCoderConfig,
447 DeepSeekCoderForCausalLM, DeepSeekCoderModel, QwenCoderConfig, QwenCoderForCausalLM,
448 QwenCoderModel, StarCoderConfig, StarCoderForCausalLM, StarCoderModel,
449};
450pub use command_r::{CommandRConfig, CommandRForCausalLM, CommandRModel};
451pub use common_patterns::{
452 components, get_global_registry, ArchitectureType, ComputeRequirements, EvaluableModel,
453 EvaluationData, EvaluationMetric, EvaluationResults, GenerationConfig, GenerationStrategy,
454 GenerativeModel, InitializationStrategy, MemoryEstimate, ModelFamily, ModelFamilyMetadata,
455 ModelRegistry, ModelUtils, TaskType as CommonTaskType,
456};
457pub use comprehensive_testing::{
458 reporting, BiasMetric, BiasmitigationStrategy, FairnessAssessment, FairnessConfig,
459 FairnessMetricType, FairnessResult, FairnessTestData, FairnessViolation, GroupData,
460 LayerPerformance, MemoryAnalysis, ModelTestSuite, NumericalDifferences, NumericalParityResults,
461 OverallPerformance, PerformanceProfiler, PerformanceResults, ReferenceComparator,
462 StatisticalTest, TestDataType, TestInputConfig, TestResult, TestStatistics,
463 ThroughputMeasurements, TimingInfo, ValidationConfig,
464};
465pub use continual_learning::{
466 utils as continual_learning_utils, ContinualLearningConfig, ContinualLearningMetrics,
467 ContinualLearningOutput, ContinualLearningTrainer, ContinualStrategy, LearningRateSchedule,
468 MemoryBuffer, MemorySelectionStrategy, TaskEvaluation, TaskInfo,
469};
470pub use creative_writing_specialized::{
471 CreativeWritingConfig, CreativeWritingForCausalLM, CreativeWritingModel,
472 CreativeWritingSpecialTokens, EmotionalTone, ImprovementType, LiteraryDevice,
473 NarrativePerspective, PoetryStyle, StyleAnalysis, WritingGenre, WritingImprovement,
474 WritingStyle,
475};
476pub use cross_attention::{
477 AdaptiveCrossAttention, CrossAttention, CrossAttentionConfig, GatedCrossAttention,
478 HierarchicalCrossAttention, MultiHeadCrossAttention, SparseCrossAttention,
479};
480pub use curriculum_learning::{
481 utils as curriculum_learning_utils, CurriculumAnalysis, CurriculumConfig,
482 CurriculumEpochOutput, CurriculumExample, CurriculumLearningOutput, CurriculumLearningTrainer,
483 CurriculumStats, CurriculumStrategy, DifficultyMeasure, PacingFunction,
484};
485pub use dynamic_pruning::*;
486pub use error_recovery::{
487 ErrorCategory, ErrorRecoveryManager, ErrorTrends, ModelCheckpoint, RecoverableOperation,
488 RecoveryAttempt, RecoveryConfig, RecoveryMetrics, RecoveryReport, RecoveryStrategy,
489};
490pub use falcon::{FalconConfig, FalconForCausalLM, FalconModel};
491pub use fnet::{FNetConfig, FNetForMaskedLM, FNetForSequenceClassification, FNetModel};
492pub use hierarchical::{
493 HierarchicalConfig, HierarchicalForLanguageModeling, HierarchicalForSequenceClassification,
494 HierarchicalTransformer, NestedTransformer, PyramidTransformer, TreeTransformer,
495};
496pub use hybrid_architectures::{
497 AdaptiveConfig, ArchitecturalComponent, ArchitectureSummary, AttentionType, CNNArchitecture,
498 CrossModalConfig, EnsembleMethod, FusionStrategy, GlobalParams, HierarchyType,
499 HybridArchitecture, HybridConfig, HybridConfigBuilder, MemoryType, ParallelFusionMethod,
500 RNNCellType, StateSpaceType, SwitchingCriteria, TransformerVariant,
501};
502pub use hyena::{
503 HyenaConfig, HyenaForLanguageModeling, HyenaForSequenceClassification, HyenaModel,
504};
505pub use knowledge_distillation::{
506 utils as knowledge_distillation_utils, DistillationConfig, DistillationOutput,
507 DistillationStrategy, KnowledgeDistillationTrainer, ProgressiveStage, StudentOutputs,
508 TeacherOutputs,
509};
510pub use legal_medical_specialized::{
511 Citation, CitationType, ComplianceReport, ComplianceViolation, DocumentAnalysis,
512 LegalMedicalConfig, LegalMedicalDomain, LegalMedicalForCausalLM, LegalMedicalModel,
513 LegalMedicalSpecialTokens, LegalSystem, MedicalStandard, PrivacyRequirement,
514};
515pub use linformer::{
516 LinformerConfig, LinformerForMaskedLM, LinformerForSequenceClassification, LinformerModel,
517};
518pub use mamba::{MambaConfig, MambaModel};
519#[cfg(feature = "llama")]
520pub use math_specialized::{
521 ChainOfThoughtConfig, DeepSeekMathConfig, DeepSeekMathForCausalLM, DeepSeekMathModel,
522 MammothConfig, MammothForCausalLM, MammothModel, MathDomain, MathLlamaConfig,
523 MathLlamaForCausalLM, MathLlamaModel, MathModelVariant, MathProblemType, MathReasoningOutput,
524 MathSpecialTokens, MathSpecializedConfig, MathSpecializedForCausalLM, MathSpecializedModel,
525 MinervaConfig, MinervaForCausalLM, MinervaModel, ReasoningStep, ReasoningStrategy,
526};
527pub use meta_learning::{
528 utils as meta_learning_utils, ConvergenceMetrics, EpisodeResult, EvaluationResult, Example,
529 ExampleSet, MetaAlgorithm, MetaLearner, MetaLearningConfig, MetaLearningModel, MetaOptimizer,
530 MetaStatistics, PerformanceMetrics, Task, TaskBatch, TaskResult, TaskSampler,
531 TaskType as MetaTaskType,
532};
533pub use mixed_bit_quantization::{
534 BitAllocationStrategy, CalibrationConfig, CalibrationMethod,
535 HardwareConstraints as QuantizationHardwareConstraints,
536 HardwarePlatform as QuantizationHardwarePlatform, LayerQuantizationConstraints,
537 MixedBitQuantizationConfig, MixedBitQuantizer, ProgressiveQuantizationConfig,
538 QuantizationFormat, QuantizationParams, QuantizationQualityMetrics, QuantizationResults,
539 QuantizedLayerInfo, SensitivityAnalysisResults,
540};
541pub use model_compression::{
542 utils as model_compression_utils, ClusteringMethod, CompressedModel, CompressionAnalysis,
543 CompressionConfig, CompressionPipeline, CompressionStrategy, CompressionSummary,
544 DecompositionType, LayerCompressionStats, OptimizationObjective, PruningStrategy,
545 StructuredPruningGranularity,
546};
547pub use model_serving::{
548 InferenceRequest, InferenceResponse, LoadBalancer, LoadBalancingStrategy, ModelInstance,
549 ModelServingManager, RequestPriority, RequestQueue, ServingConfig, ServingMetrics,
550};
551pub use moe::{
552 glam_config, switch_config, Expert, ExpertParallel, MLPExpert, MoEConfig, RouterOutput,
553 RoutingStats, SparseMoE, SwitchMoE, TopKRouter,
554};
555pub use multi_task_learning::{
556 utils as multi_task_learning_utils, LossBalancingStrategy, MTLAnalysis, MTLArchitecture,
557 MTLConfig, MTLStats, MultiTaskEvaluation, MultiTaskLearningTrainer, MultiTaskOutput,
558 TaskConfig, TaskEvaluation as MTLTaskEvaluation, TaskPriority, TaskType as MTLTaskType,
559};
560pub use neural_architecture_search::{
561 Architecture, ArchitectureConstraint, ArchitectureEvaluation, ArchitectureMetadata,
562 DimensionRange, HardwareConstraints, HardwarePlatform, NASConfig, NeuralArchitectureSearcher,
563 OptimizationObjective as NASOptimizationObjective, SearchSpace, SearchStatistics,
564 SearchStrategy,
565};
566pub use performance_optimization::{
567 BatchProcessor, BatchingStrategy, CachedTensor, DynamicBatchManager, GpuCacheStatistics,
568 GpuMemoryChunk, GpuMemoryOptimizer, GpuMemoryPool, GpuMemoryStats,
569 GpuOptimizationRecommendations, GpuTensorCache, MemoryOptimizer, PerformanceConfig,
570 PerformanceMonitor, PerformanceStatistics,
571};
572pub use performer::{
573 PerformerConfig, PerformerForMaskedLM, PerformerForSequenceClassification, PerformerModel,
574};
575pub use progressive_training::{
576 utils as progressive_training_utils, GrowthDimension, GrowthEvent, GrowthInfo, GrowthResult,
577 GrowthSchedule, GrowthStrategy, LearningProgress, ProgressiveConfig, ProgressiveModel,
578 ProgressiveTrainer,
579};
580pub use retnet::{
581 RetNetConfig, RetNetForLanguageModeling, RetNetForSequenceClassification, RetNetModel,
582};
583pub use rwkv::{RwkvConfig, RwkvModel};
584pub use s4::{S4Config, S4ForLanguageModeling, S4Model};
585pub use scientific_specialized::{
586 CitationStyle, ScientificAnalysis, ScientificConfig, ScientificDomain, ScientificForCausalLM,
587 ScientificModel, ScientificSpecialTokens,
588};
589pub use sparse_attention::{
590 utils as sparse_attention_utils, SparseAttention, SparseAttentionConfig, SparseAttentionMask,
591 SparsePattern,
592};
593pub use stablelm::{StableLMConfig, StableLMForCausalLM, StableLMModel};
594pub use weight_loading::{
595 auto_create_loader, create_distributed_loader, create_gguf_loader, create_huggingface_loader,
596 create_memory_mapped_loader, DistributedStats, DistributedWeightLoader, GGMLType, GGUFLoader,
597 HuggingFaceLoader, LazyTensor, MemoryMappedLoader, QuantizationConfig, StreamingLoader,
598 TensorMetadata, WeightDataType, WeightFormat, WeightLoader, WeightLoadingConfig,
599};
600
601pub use xlstm::{
602 ExponentialGatingConfig, FeedForward, MLstmBlock, MLstmConfig, SLstmBlock, SLstmConfig,
603 XLSTMBlockConfig, XLSTMBlockType, XLSTMConfig, XLSTMForCausalLM,
604 XLSTMForSequenceClassification, XLSTMLayer, XLSTMModel, XLSTMState,
605};
606
#[cfg(test)]
mod tests {
    /// Placeholder check: asserts that `2 + 2` equals `4`.
    ///
    /// It does not exercise any item exported by this crate.
    #[test]
    fn it_works() {
        let sum = 2 + 2;
        assert_eq!(sum, 4);
    }
}