//! SciRS2-backed executor (CPU/SIMD/GPU via features).
//!
//! **Version**: 0.1.0-beta.1 | **Status**: Production Ready
//!
//! This crate provides a production-ready implementation of the TensorLogic execution
//! traits using the SciRS2 scientific computing library.
//!
//! ## Core Features
//!
//! ### Execution Engine
//! - **Forward pass**: Tensor operations (einsum, element-wise, reductions)
//! - **Backward pass**: Automatic differentiation with stored intermediate values (see the tape sketch below)
//! - **Gradient checking**: Numeric verification of analytic gradients
//! - **Batch execution**: Parallel processing support for multiple inputs
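//!
//! A self-contained sketch of the forward-tape idea: each operation records the
//! intermediates its backward pass needs. This toy `Tape` is illustrative only;
//! the crate's actual `ForwardTape` lives in the `autodiff` module.
//!
//! ```
//! struct Tape { saved: Vec<(f64, f64)> }
//!
//! impl Tape {
//!     fn mul(&mut self, a: f64, b: f64) -> f64 {
//!         self.saved.push((a, b)); // store inputs for the backward pass
//!         a * b
//!     }
//!     fn mul_backward(&mut self, upstream: f64) -> (f64, f64) {
//!         let (a, b) = self.saved.pop().expect("forward ran first");
//!         (upstream * b, upstream * a) // d(ab)/da, d(ab)/db
//!     }
//! }
//!
//! let mut tape = Tape { saved: Vec::new() };
//! assert_eq!(tape.mul(2.0, 3.0), 6.0);
//! assert_eq!(tape.mul_backward(1.0), (3.0, 2.0));
//! ```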
//!
//! ### Performance
//! - **Memory pooling**: Shape-keyed tensor reuse to cut allocation churn (sketch below)
//! - **Operation fusion**: Analysis of fusion opportunities across chained operations
//! - **SIMD support**: Vectorized operations via feature flags
//! - **Profiling**: Detailed performance monitoring and tracing
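//!
//! A toy shape-keyed pool showing the reuse pattern; this is not the
//! `memory_pool` module's API, just the underlying idea on plain `Vec<f64>`
//! buffers.
//!
//! ```
//! use std::collections::HashMap;
//!
//! struct Pool { free: HashMap<Vec<usize>, Vec<Vec<f64>>> }
//!
//! impl Pool {
//!     /// Reuse a released buffer of this shape, or allocate a fresh one.
//!     fn acquire(&mut self, shape: &[usize]) -> Vec<f64> {
//!         self.free.get_mut(shape).and_then(|v| v.pop())
//!             .unwrap_or_else(|| vec![0.0; shape.iter().product()])
//!     }
//!     /// Return a buffer to the pool for later reuse.
//!     fn release(&mut self, shape: Vec<usize>, buf: Vec<f64>) {
//!         self.free.entry(shape).or_default().push(buf);
//!     }
//! }
//!
//! let mut pool = Pool { free: HashMap::new() };
//! let buf = pool.acquire(&[2, 3]);
//! assert_eq!(buf.len(), 6);
//! pool.release(vec![2, 3], buf);
//! ```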
//!
//! ### Reliability
//! - **Error handling**: Comprehensive error types with detailed context
//! - **Execution tracing**: Multi-level debugging and operation tracking
//! - **Numerical stability**: Fallback mechanisms for NaN/Inf handling (sketch below)
//! - **Shape validation**: Runtime shape inference and verification
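//!
//! A standalone sketch of the NaN/Inf fallback idea; the real entry points are
//! `fallback::sanitize_tensor` and `fallback::is_valid`, whose signatures may
//! differ from this illustration.
//!
//! ```
//! /// Replace non-finite values and report how many were fixed.
//! fn sanitize(data: &mut [f64], replacement: f64) -> usize {
//!     let mut fixed = 0;
//!     for v in data.iter_mut() {
//!         if !v.is_finite() { *v = replacement; fixed += 1; }
//!     }
//!     fixed
//! }
//!
//! let mut t = [1.0, f64::NAN, f64::INFINITY, 4.0];
//! assert_eq!(sanitize(&mut t, 0.0), 2);
//! assert_eq!(t, [1.0, 0.0, 0.0, 4.0]);
//! ```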
//!
//! ### Testing
//! - **104 tests**: Unit, integration, and property-based tests
//! - **Property tests**: Mathematical properties verified with proptest
//! - **Gradient tests**: Numeric gradient checking for autodiff correctness (sketch below)
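//!
//! The central-difference idea behind numeric gradient checking, reduced to a
//! single scalar dimension (the `gradient_check` module applies the same idea
//! to full tensors):
//!
//! ```
//! /// Central-difference estimate of df/dx at `x`.
//! fn numeric_grad(f: impl Fn(f64) -> f64, x: f64, eps: f64) -> f64 {
//!     (f(x + eps) - f(x - eps)) / (2.0 * eps)
//! }
//!
//! let analytic = 2.0 * 3.0; // d(x^2)/dx at x = 3
//! assert!((numeric_grad(|x| x * x, 3.0, 1e-6) - analytic).abs() < 1e-4);
//! ```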
//!
//! ## Module Organization
//!
//! - `executor`: Core Scirs2Exec implementation
//! - `autodiff`: Backward pass and gradient computation
//! - `gradient_ops`: Advanced gradient operations (STE, Gumbel-Softmax, soft quantifiers; see the STE sketch below)
//! - `error`: Comprehensive error types and validation
//! - `fallback`: Numerical stability and NaN/Inf handling
//! - `tracing`: Execution debugging and performance tracking
//! - `memory_pool`: Efficient tensor allocation
//! - `fusion`: Operation fusion analysis
//! - `gradient_check`: Numeric gradient verification
//! - `shape_inference`: Runtime shape validation
//! - `batch_executor`: Parallel batch processing
//! - `profiled_executor`: Performance profiling wrapper
//! - `capabilities`: Runtime capability detection
//! - `dependency_analyzer`: Graph dependency analysis for parallel execution
//! - `parallel_executor`: Multi-threaded parallel execution using Rayon
//! - `device`: Device management (CPU/GPU selection)
//! - `execution_mode`: Execution mode abstractions (Eager/Graph/JIT)
//! - `precision`: Precision control (f32/f64/mixed)
//! - `checkpoint`: Checkpoint creation and management
//! - `cuda_detect`: CUDA device detection
//! - `custom_ops`: Registry of custom operations (ELU, GELU, Swish, ...)
//! - `gpu_readiness`: GPU readiness assessment and batch-size recommendations
//! - `graph_optimizer`: Optimization passes over execution graphs
//! - `inplace_ops`: In-place execution of shape-preserving operations
//! - `memory_profiler`: Allocation tracking and memory statistics
//! - `metrics`: Operation, memory, and throughput metrics collection
//! - `quantization`: Tensor quantization (calibration, QAT)
//! - `torsh_interop`: Torsh interoperability (behind the `torsh` feature)
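//!
//! As an example of what `gradient_ops` provides, here is the straight-through
//! estimator (STE) in standalone scalar form: a hard step in the forward pass,
//! with the gradient passed through unchanged in the backward pass. Names here
//! are illustrative, not the re-exported `ste_threshold`/`ste_threshold_backward`
//! signatures.
//!
//! ```
//! /// Forward: non-differentiable hard threshold.
//! fn hard_step(x: f64, threshold: f64) -> f64 {
//!     if x >= threshold { 1.0 } else { 0.0 }
//! }
//! /// Backward (STE): pretend the step was the identity.
//! fn hard_step_ste_grad(upstream: f64) -> f64 {
//!     upstream
//! }
//!
//! assert_eq!(hard_step(0.7, 0.5), 1.0);
//! assert_eq!(hard_step_ste_grad(0.25), 0.25);
//! ```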

pub(crate) mod autodiff;
pub mod batch_executor;
pub mod capabilities;
pub mod checkpoint;
mod conversion;
pub mod cuda_detect;
pub mod custom_ops;
pub mod dependency_analyzer;
pub mod device;
pub(crate) mod einsum_grad;
pub mod error;
pub mod execution_mode;
mod executor;
pub mod fallback;
pub mod fusion;
pub mod gpu_readiness;
pub mod gradient_check;
pub mod gradient_ops;
pub mod graph_optimizer;
pub mod inplace_ops;
pub mod memory_pool;
pub mod memory_profiler;
pub mod metrics;
mod ops;
pub mod parallel_executor;
pub mod precision;
pub mod profiled_executor;
pub mod quantization;
pub mod shape_inference;
pub mod tracing;

#[cfg(feature = "torsh")]
pub mod torsh_interop;

#[cfg(test)]
mod tests;

use scirs2_core::ndarray::ArrayD;

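/// Dynamically-shaped `f64` tensor type used throughout this backend.
///
/// A minimal construction sketch; it assumes `scirs2_core::ndarray` also
/// re-exports `IxDyn` alongside `ArrayD` (only `ArrayD` is imported above),
/// so the block is marked `ignore` rather than run as a doc test:
///
/// ```ignore
/// use scirs2_core::ndarray::{ArrayD, IxDyn};
///
/// // A 2x3 tensor of zeros with a runtime-determined dimension count.
/// let t: ArrayD<f64> = ArrayD::zeros(IxDyn(&[2, 3]));
/// assert_eq!(t.shape(), &[2, 3]);
/// ```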
pub type Scirs2Tensor = ArrayD<f64>;

pub use autodiff::ForwardTape;
pub use batch_executor::ParallelBatchExecutor;
pub use checkpoint::{Checkpoint, CheckpointConfig, CheckpointManager, CheckpointMetadata};
pub use cuda_detect::{
    cuda_device_count, cuda_devices_to_device_list, detect_cuda_devices, is_cuda_available,
    CudaDeviceInfo,
};
pub use custom_ops::{
    BinaryCustomOp, CustomOp, CustomOpContext, EluOp, GeluOp, HardSigmoidOp, HardSwishOp,
    LeakyReluOp, MishOp, OpRegistry, SoftplusOp, SwishOp,
};
pub use dependency_analyzer::{DependencyAnalysis, DependencyStats, OperationDependency};
pub use device::{Device, DeviceError, DeviceManager, DeviceType};
pub use error::{
    NumericalError, NumericalErrorKind, ShapeMismatchError, TlBackendError, TlBackendResult,
};
pub use execution_mode::{
    CompilationStats, CompiledGraph, ExecutionConfig, ExecutionMode, MemoryPlan, OptimizationConfig,
};
pub use executor::Scirs2Exec;
pub use fallback::{is_valid, sanitize_tensor, FallbackConfig};
pub use gpu_readiness::{
    assess_gpu_readiness, generate_recommendations, recommend_batch_size, GpuCapability,
    GpuReadinessReport, WorkloadProfile,
};
pub use gradient_ops::{
    gumbel_softmax, gumbel_softmax_backward, soft_exists, soft_exists_backward, soft_forall,
    soft_forall_backward, ste_threshold, ste_threshold_backward, GumbelSoftmaxConfig,
    QuantifierMode, SteConfig,
};
pub use graph_optimizer::{
    GraphOptimizer, GraphOptimizerBuilder, OptimizationPass, OptimizationStats,
};
pub use inplace_ops::{can_execute_inplace, is_shape_preserving, InplaceExecutor, InplaceStats};
pub use memory_profiler::{
    AllocationRecord, AtomicMemoryCounter, MemoryProfiler, MemoryStats as ProfilerMemoryStats,
};
pub use metrics::{
    format_bytes, shared_metrics, AtomicMetrics, MemoryStats, MetricsCollector, MetricsConfig,
    MetricsSummary, OperationRecord, OperationStats, SharedMetrics, ThroughputStats,
};
pub use parallel_executor::{ParallelConfig, ParallelScirs2Exec, ParallelStats};
pub use precision::{ComputePrecision, Precision, PrecisionConfig, Scalar};
pub use profiled_executor::ProfiledScirs2Exec;
pub use quantization::{
    calibrate_quantization, QatConfig, QuantizationGranularity, QuantizationParams,
    QuantizationScheme, QuantizationStats, QuantizationType, QuantizedTensor,
};
pub use shape_inference::{validate_tensor_shapes, Scirs2ShapeInference};
pub use tracing::{ExecutionTracer, TraceEvent, TraceLevel};