// trustformers_core — crate root (lib.rs)
//! # TrustformeRS Core
//!
//! Core traits, types, and utilities for the TrustformeRS transformer library.
//!
//! This crate provides the foundational building blocks for implementing transformer models
//! in pure Rust with zero-cost abstractions. It includes:
//!
//! - **Tensor operations**: High-performance tensor abstractions with GPU acceleration
//! - **Neural network layers**: Attention mechanisms, feed-forward networks, normalization
//! - **Model traits**: Unified interfaces for models, tokenizers, and configurations
//! - **Device management**: CPU, CUDA, Metal, and other hardware backend support
//! - **Quantization**: INT4/INT8/FP16 quantization for efficient inference
//! - **Memory management**: Caching, checkpointing, and memory-efficient operations
//! - **Hardware acceleration**: SIMD, BLAS, GPU kernels, and compiler optimizations
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use trustformers_core::{
//!     tensor::Tensor,
//!     layers::Linear,
//!     traits::Layer,
//! };
//!
//! // Create a tensor on CPU
//! let input = Tensor::randn(&[32, 512])?;
//!
//! // Create a linear layer
//! let linear = Linear::new(512, 768, true);
//! let output = linear.forward(input)?;
//! # Ok::<(), Box<dyn std::error::Error>>(())
//! ```
//!
//! ## Architecture
//!
//! TrustformeRS Core follows a dual-layer architecture:
//! - High-level ML abstractions (tensors, layers, models)
//! - Low-level scientific computing via SciRS2 (SIMD, parallel ops, BLAS)
//!
//! All external dependencies (PyTorch, ONNX Runtime, tokenizers) are abstracted
//! through unified interfaces to maintain flexibility and testability.
//!
//! ## Features
//!
//! - `cuda`: NVIDIA GPU support via CUDA
//! - `metal`: Apple GPU support via Metal
//! - `opencl`: OpenCL GPU support
//! - `mkl`: Intel MKL BLAS backend
//! - `quantization`: Model quantization support
//! - `distributed`: Distributed training utilities
//!
//! ## Safety and Performance
//!
//! - **Memory-safe**: Pure Rust with no unsafe code in critical paths
//! - **Zero-cost abstractions**: Performance comparable to C++ implementations
//! - **GPU-accelerated**: Automatic dispatch to GPU when available
//! - **SIMD-optimized**: Vectorized operations for CPU performance

// Crate-wide clippy allowances. Each lint is deliberately relaxed because
// numerical/ML code routinely triggers it without indicating a real defect.
#![allow(clippy::excessive_nesting)] // Algorithm-heavy code often requires deep nesting
#![allow(clippy::result_large_err)] // Large error enums are intentional for rich error context
#![allow(clippy::excessive_precision)] // High-precision floats needed for ML computations
#![allow(clippy::module_name_repetitions)] // Module names often repeat for clarity
#![allow(clippy::similar_names)] // Similar variable names are common in mathematical code
#![allow(clippy::too_many_arguments)] // ML functions often require many parameters
#![allow(clippy::too_many_lines)] // Complex algorithms require longer functions

67pub mod ab_testing;
68pub mod adaptive_computation;
69pub mod autodiff;
70pub mod blas;
71pub mod cache;
72pub mod checkpoint;
73pub mod compiler;
74pub mod compression;
75pub mod device;
76pub mod error;
77pub mod errors;
78pub mod evaluation;
79pub mod export;
80pub mod generation;
81pub mod gpu;
82// Temporarily disabled when CUDA feature is enabled - needs cudarc 0.17.7 API migration
83#[cfg(not(feature = "cuda"))]
84pub mod gpu_accelerated;
85pub mod gpu_ops;
86pub mod hardware;
87// Temporarily disabled when CUDA feature is enabled - needs cudarc 0.17.7 API migration
88#[cfg(not(feature = "cuda"))]
89pub mod hardware_acceleration;
90pub mod kernel_fusion;
91pub mod kernel_tuning;
92pub mod kernels;
93pub mod layers;
94pub mod leaderboard;
95pub mod memory;
96pub mod monitoring;
97pub mod neuromorphic;
98pub mod numa_optimization;
99pub mod ops;
100pub mod optical;
101pub mod parallel;
102pub mod patterns;
103pub mod peft;
104pub mod performance;
105pub mod plugins;
106pub mod quantization;
107pub mod quantum;
108pub mod sparse_ops;
109pub mod sparse_tensor;
110pub mod tensor;
111pub mod tensor_debugger;
112pub mod testing;
113#[cfg(test)]
114pub mod tests;
115pub mod tokenizer_backend;
116pub mod traits;
117pub mod utils;
118pub mod versioning;
119pub mod visualization;
120
121pub use ab_testing::{
122    ABTestManager, ABTestSummary, ConfidenceLevel, DeploymentStrategy, Experiment,
123    ExperimentConfig, ExperimentStatus, HealthCheck, HealthCheckType, MetricCollector,
124    MetricSummary, MetricType, MetricValue, Recommendation, RollbackCondition, RolloutController,
125    RolloutStatus, RoutingStrategy, StatisticalAnalyzer, TestRecommendation, TestResult,
126    TrafficSplitter, UserSegment, Variant,
127};
128pub use adaptive_computation::{
129    AdaptiveComputationConfig, AdaptiveComputationManager, AdaptiveComputationStrategy,
130    ComplexityEstimationMethod, ComplexityEstimator, ComputationBudget, ComputationPath,
131    ConfidenceBasedStrategy, DynamicDepthStrategy, EntropyBasedComplexityEstimator, LayerMetrics,
132    LayerSkipPattern, PerformanceTracker, ResourceAllocation, UncertaintyBasedStrategy,
133};
134pub use blas::{
135    blas_optimizer, init_blas, optimized_dot, optimized_gemm, optimized_gemv, BlasBackend,
136    BlasConfig, BlasOperation, BlasOptimizer,
137};
138pub use cache::{
139    CacheConfig, CacheEntry, CacheKey, CacheKeyBuilder, CacheMetrics, EvictionPolicy,
140    InferenceCache, LRUEviction, SizeBasedEviction, TTLEviction,
141};
142pub use checkpoint::{
143    convert_checkpoint, detect_format, load_checkpoint, save_checkpoint, CheckpointConverter,
144    CheckpointFormat, ConversionConfig, ConversionResult, JaxCheckpoint, LayerMapping,
145    PyTorchCheckpoint, TensorFlowCheckpoint, TrustformersCheckpoint, WeightMapping,
146    WeightMappingRule,
147};
148pub use compiler::{
149    CompilationResult, CompilerConfig, CompilerOptimizer, ComputationGraph, DeviceType, GraphEdge,
150    GraphNode, HardwareTarget, OptimizationLevel, OptimizationRecommendation,
151    OptimizationRecommendations, OptimizationResult, PassResult, RecommendationCategory,
152    RecommendationPriority,
153};
154pub use compression::{
155    // Convenience functions
156    create_compression_pipeline,
157    AccuracyRetention,
158    AttentionDistiller,
159    ChannelPruner,
160    CompressionConfig as CompressionPipelineConfig,
161    CompressionEvaluator,
162    // Metrics exports
163    CompressionMetrics,
164    // Pipeline exports
165    CompressionPipeline,
166    CompressionRatio,
167    CompressionReport,
168    CompressionResult,
169    CompressionStage,
170    CompressionTargets,
171    // Distillation exports
172    DistillationConfig,
173    DistillationLoss,
174    DistillationResult,
175    DistillationStrategy,
176    FeatureDistiller,
177    FilterPruner,
178    GradualPruner,
179    HeadPruner,
180    HiddenStateDistiller,
181    InferenceSpeedup,
182    KnowledgeDistiller,
183    LayerDistiller,
184    LayerPruner,
185    MagnitudePruner,
186    ModelSizeReduction,
187    PipelineBuilder,
188    PruningConfig,
189    PruningResult,
190    PruningSchedule,
191    PruningStats,
192    // Pruning exports
193    PruningStrategy,
194    ResponseDistiller,
195    SparsityMetric,
196    StructuredPruner,
197    StudentModel,
198    TeacherModel,
199    UnstructuredPruner,
200};
201pub use device::Device;
202pub use errors::{Result, TrustformersError};
203pub use evaluation::{
204    Accuracy, DatasetLoader, DatasetManager, DatasetSample, EvaluationConfig, EvaluationDataset,
205    EvaluationHarness, EvaluationResult, EvaluationSuite, Evaluator, ExactMatch, F1Average,
206    F1Score, FileDatasetLoader, GLUEEvaluator, GLUETask, MemoryDatasetLoader, Metric,
207    MetricCollection, OtherBenchmark, Perplexity, SuperGLUEEvaluator, SuperGLUETask, BLEU,
208};
209pub use export::{
210    CoreMLExporter, ExportConfig, ExportFormat, ExportPrecision, ExportQuantization, GGMLExporter,
211    GGUFExporter, ModelExporter, ONNXExporter, TensorRTExporter, UniversalExporter,
212};
213pub use generation::{
214    FinishReason,
215    GenerationConfig,
216    GenerationStrategy,
217    GenerationStream,
218    // KVCache, // Now exported from cache module
219    // SpeculativeDecoder, TextGenerator,  // Temporarily disabled due to missing modules
220};
221#[cfg(not(feature = "cuda"))]
222pub use gpu_accelerated::{GpuAcceleratedOps, GpuOpsConfig, GpuPrecision};
223pub use hardware::{
224    AsicBackend, AsicDevice, AsicOperationSet, DataType, HardwareBackend, HardwareCapabilities,
225    HardwareConfig, HardwareDevice, HardwareManager, HardwareMetrics, HardwareOperation,
226    HardwareRegistry, HardwareResult, HardwareType, OperationMode, PrecisionMode,
227};
228pub use kernel_fusion::{
229    ComputationGraph as FusionComputationGraph, DataType as FusionDataType, Device as FusionDevice,
230    FusedKernel, FusionConstraint, FusionOpportunity, FusionPattern, FusionStatistics,
231    GraphNode as FusionGraphNode, KernelFusionEngine, KernelImplementation, MemoryLayout,
232    NodeMetadata, OperationType, TensorInfo,
233};
234pub use kernel_tuning::{
235    get_kernel_tuner, Backend as TuningBackend, KernelParams, KernelTuner,
236    Operation as TuningOperation, PlatformInfo, TuningConfig, TuningStatistics,
237};
238pub use kernels::fused_ops::ActivationType;
239pub use kernels::{
240    FusedAttentionDropout, FusedBiasActivation, FusedGELU, FusedLinear, FusedMatmulScale,
241    OptimizedRoPE, RoPEConfig, RoPEScalingType, SIMDLayerNorm, SIMDSoftmax, VectorizedRoPE,
242};
243// Import hardware traits separately
244pub use hardware::traits::{
245    AsyncHardwareOperation, AsyncOperationHandle, AsyncOperationStatus, DeviceMemory, DeviceStatus,
246    HardwareScheduler, MemoryType, MemoryUsage as HardwareMemoryUsage, OperationParameter,
247    OperationRequirements, PerformanceRequirements, SchedulerStatistics,
248};
249// Import ASIC types from asic submodule
250pub use autodiff::{
251    AnalysisResult,
252    AutodiffEngine,
253    ComputationGraph as AutodiffComputationGraph,
254    DebuggerConfig,
255    GradientFlowStats,
256    GradientMode,
257    GradientTape,
258    // Debugger exports
259    GraphDebugger,
260    GraphIssue,
261    GraphNode as AutodiffGraphNode,
262    GraphOutputFormat,
263    IssueSeverity,
264    IssueType,
265    MemoryStats,
266    NodeDebugInfo,
267    NodeId,
268    OperationType as AutodiffOperationType,
269    TapeEntry,
270    TraversalInfo,
271    Variable,
272    VariableRef,
273};
274pub use hardware::asic::{
275    AsicDeviceConfig, AsicDriver, AsicDriverFactory, AsicMemoryConfig, AsicPerformanceMonitor,
276    AsicSpec, AsicType, AsicVendor, CacheConfig as AsicCacheConfig,
277};
278#[cfg(not(feature = "cuda"))]
279pub use hardware_acceleration::{
280    api as hardware_acceleration_api, AccelerationBackend, AccelerationConfig, AccelerationStats,
281    HardwareAccelerator,
282};
283pub use leaderboard::{
284    LeaderboardCategory, LeaderboardClient, LeaderboardEntry, LeaderboardFilter,
285    LeaderboardManager, LeaderboardQuery, LeaderboardRanking, LeaderboardStats, LeaderboardStorage,
286    LeaderboardSubmission, RankingCriteria, SubmissionValidator,
287};
288pub use memory::{
289    get_memory_manager, get_tensor, init_memory_manager, return_tensor, AdaptiveStrategy,
290    MemoryConfig, MemoryEvictionPolicy, MemoryMappedTensor, MemoryPoolStats, TensorMemoryPool,
291    TensorView,
292};
293pub use monitoring::{
294    AttentionPattern, AttentionPatternType, AttentionReport, AttentionVisualizer,
295    AttentionVisualizerConfig, Counter, Gauge, Histogram, MemoryReport, MemorySnapshot,
296    MemoryTracker, MemoryTrackerConfig, MemoryUsage, MetricsCollector, MetricsCollectorConfig,
297    MetricsSummary, ModelMonitor, ModelProfiler, MonitoringConfig, MonitoringReport,
298    MonitoringSession, OptimizationSuggestion, OptimizationType, ProfilerConfig, ProfilingReport,
299};
300pub use numa_optimization::{
301    get_numa_allocator, init_numa_allocator, numa_alloc, numa_free, AllocationStats,
302    HotspotSeverity, NumaAllocation, NumaAllocator, NumaNode, NumaPerformanceMonitor, NumaPolicy,
303    NumaStrategy, NumaTopology, NumaTrafficAnalysis, ThreadAffinity, ThreadPriority,
304    TrafficHotspot,
305};
306pub use parallel::{
307    init_parallelism,
308    parallel_chunk_map,
309    parallel_context,
310    parallel_execute,
311    parallel_map,
312    ActivationType as ParallelActivationType,
313    AsyncTensorParallel,
314    // Parallel layers exports
315    ColumnParallelLinear,
316    CommunicationBackend,
317    Communicator,
318    DeviceMesh,
319    DistributedTensor,
320    InitMethod,
321    MemoryPolicy,
322    MicrobatchManager,
323    // Model parallel exports
324    ModelParallelConfig,
325    ModelParallelContext,
326    ModelParallelStrategy,
327    NumaConfig,
328    ParallelContext,
329    ParallelMLP,
330    ParallelMultiHeadAttention,
331    ParallelOps,
332    ParallelismStrategy,
333    PipelineExecutor,
334    // Pipeline parallel exports
335    PipelineLayer,
336    PipelineModel,
337    PipelineOp,
338    PipelineOptimizer,
339    PipelineSchedule,
340    PipelineScheduleType,
341    PipelineStage,
342    RowParallelLinear,
343    TensorParallelInit,
344    // Tensor parallel exports
345    TensorParallelOps,
346    TensorParallelShapes,
347    TensorPartition,
348};
349pub use patterns::{
350    Buildable, Builder, BuilderError, BuilderResult, ConfigBuilder, ConfigBuilderImpl,
351    ConfigManager, ConfigMetadata, ConfigSerializable, CpuLimits, EnvironmentConfig, GpuLimits,
352    LoggingConfig, MemoryLimits, PatternError, PatternResult, PerformanceConfig, ResourceConfig,
353    SecurityConfig, StandardBuilder, StandardConfig, UnifiedConfig, ValidatedBuilder,
354};
355pub use peft::{
356    AdapterLayer, LoRALayer, PeftConfig, PeftMethod, PeftModel, PrefixTuningLayer,
357    PromptTuningEmbedding, QLoRALayer,
358};
359pub use performance::{
360    BenchmarkBuilder, BenchmarkCategory, BenchmarkConfig, BenchmarkDSL, BenchmarkMetadata,
361    BenchmarkRegistry, BenchmarkReport, BenchmarkResult, BenchmarkRunner, BenchmarkRunnerBuilder,
362    BenchmarkSpec, BenchmarkSuite, ComparisonResult, ContinuousBenchmark,
363    ContinuousBenchmarkConfig, CustomBenchmark, Framework, HuggingFaceBenchmark, LatencyMetrics,
364    MemoryMetrics, MemoryProfiler, MemorySnapshot as PerformanceMemorySnapshot,
365    MemoryTracker as PerformanceMemoryTracker, MetricsTracker, ModelComparison,
366    PerformanceProfiler, PerformanceRegression, ProfileResult, PytorchBenchmark, ReportFormat,
367    Reporter, RunConfig, RunMode, ThroughputMetrics,
368};
369pub use plugins::{
370    Dependency, GpuRequirements, Plugin, PluginContext, PluginEvent, PluginEventHandler,
371    PluginInfo, PluginLoader, PluginManager, PluginRegistry, SystemRequirements,
372};
373pub use quantization::{
374    dequantize_bitsandbytes,
375    estimate_quantization_error,
376    from_bitsandbytes_format,
377    quantize_4bit,
378    quantize_dynamic_tree,
379    quantize_int8,
380    select_fp8_format,
381    to_bitsandbytes_format,
382    AWQQuantizer,
383    ActivationLayerQuantConfig,
384    // Activation quantization
385    ActivationQuantConfig,
386    ActivationQuantScheme,
387    ActivationQuantizer,
388    ActivationStats,
389    // Mixed-bit quantization
390    AutoBitAllocationStrategy,
391    // BitsAndBytes compatibility
392    BitsAndBytesConfig,
393    // GGUF K-quant formats
394    BlockQ2K,
395    BlockQ3K,
396    BlockQ4K,
397    BnBComputeType,
398    BnBConfig,
399    BnBQuantType,
400    BnBQuantizer,
401    BnBStorageType,
402    // FP8 quantization
403    DelayedScalingConfig,
404    FP8Config,
405    FP8Format,
406    FP8Quantizer,
407    FP8Tensor,
408    FakeQuantize,
409    GPTQQuantizer,
410    KQuantConfig,
411    KQuantTensor,
412    KQuantType,
413    KQuantizer,
414    LayerQuantConfig,
415    MixedBitConfig,
416    MixedBitQuantizedTensor,
417    MixedBitQuantizer,
418    Observer,
419    QATConfig,
420    QuantState,
421    QuantizationConfig,
422    QuantizationScheme,
423    QuantizedActivation,
424    QuantizedBlock,
425    QuantizedTensor,
426    Quantizer,
427    ScaleFactors,
428    ScalingStrategy,
429    SensitivityConfig,
430    SensitivityMetric,
431};
432pub use sparse_ops::{
433    conversion, pruning, sparse_attention, sparse_matmul, BlockSparsity, NMSparsity,
434    StructuredSparsityPattern,
435};
436pub use sparse_tensor::{SparseFormat, SparseIndices, SparseTensor};
437pub use tensor::{
438    DType, EvalContext, ExprNode, OpType, OptimizationHints, Tensor, TensorExpr, TensorType,
439};
440pub use tensor_debugger::{
441    DebugTensorStats, OperationTrace, Severity, TensorDebugIssue, TensorDebugger,
442    TensorDebuggerConfig, TensorIssueType, WatchCondition, Watchpoint,
443};
444pub use tokenizer_backend::{Encoding, Tokenizer, TokenizerError};
445pub use traits::{Config, Layer, Model};
446pub use versioning::{
447    ActiveDeployment,
448    // Storage types
449    Artifact,
450    ArtifactType,
451    DateRange,
452    DeploymentConfig,
453    DeploymentEvent,
454    DeploymentEventType,
455    // Deployment types
456    DeploymentManager,
457    DeploymentStatistics,
458    DeploymentStatus,
459    DeploymentStrategy as VersioningDeploymentStrategy,
460    Environment,
461    FileSystemStorage,
462    HealthStatus,
463    InMemoryStorage,
464    LifecycleEvent,
465    LifecyclePolicies,
466    LifecycleStatistics,
467    // Metadata types
468    ModelMetadata,
469    ModelRegistry,
470    ModelRoutingResult,
471    ModelSource,
472    ModelStorage,
473    ModelTag,
474    // Core versioning types
475    ModelVersionManager,
476    PromotionResult,
477    RegistryStatistics,
478    SortBy,
479    SortOrder,
480    TagMatchMode,
481    VersionExperimentConfig,
482    VersionExperimentResult,
483    VersionFilter,
484    VersionLifecycle,
485    VersionMetricType,
486    // Registry types
487    VersionQuery,
488    VersionStats,
489    // Lifecycle types
490    VersionStatus,
491    VersionTransition,
492    // Integration types
493    VersionedABTestManager,
494    VersionedExperiment,
495    VersionedExperimentStatus,
496    VersionedModel,
497};
498pub use visualization::{
499    ColorScheme, OutputFormat, TensorHeatmap, TensorHistogram, TensorSliceView, TensorStats,
500    TensorVisualizer, VisualizationConfig,
501};