1pub mod anomaly_alerts;
97pub mod anomaly_detection;
98pub mod autograd_traits;
99pub mod complex_ops;
100pub mod global_adapter;
101pub mod grad_mode;
102pub mod gradient_storage;
103pub mod guards;
104pub mod variable_env;
105
106pub mod checkpoint_scheduler;
108pub mod common_utils;
109pub mod communication_efficient;
110pub mod compression;
111pub mod context;
112pub mod differentiable_programming;
113pub mod discrete_ops;
114pub mod distributed;
115pub mod error_diagnostics;
116pub mod error_handling;
117pub mod external_ad_integration;
118pub mod federated_learning;
119pub mod flamegraph;
120pub mod function;
121pub mod function_optimization;
122pub mod garbage_collection;
123pub mod gradient_checking;
124pub mod gradient_filtering;
125pub mod gradient_flow_analysis;
126pub mod gradient_scaling;
127pub mod gradient_scheduler;
128pub mod gradient_tracer;
129pub mod gradient_validation;
130pub mod graph_opt;
131pub mod graph_visualization;
132pub mod hyperparameter_optimization;
133pub mod iterative_solvers;
134pub mod jax_transformations;
135pub mod matrix_calculus;
136pub mod memory;
137pub mod meta_gradient;
138pub mod metrics_collection;
139pub mod mlx_compat;
140pub mod onnx_integration;
141pub mod operation_introspection;
142pub mod operation_replay;
143pub mod optimization_diff;
144pub mod parameter_server;
145pub mod profiler;
146pub mod property_testing;
147pub mod pytorch_compat;
148pub mod scirs2_integration;
149pub mod simd_gradient;
150pub mod simd_ops;
151pub mod staleness_handling;
152pub mod stochastic_graphs;
153pub mod structured_logging;
154pub mod symbolic;
155pub mod tensorflow_compat;
156pub mod visualization;
157pub mod vjp_optimization;
158
159pub mod auto_tuning;
161pub mod automatic_error_recovery;
162pub mod blas_integration;
163pub mod buffer_optimization;
164pub mod cross_framework_verification;
165pub mod custom_backends;
166pub mod edge_case_handling;
167pub mod exception_safety;
168pub mod gpu_gradient;
169pub mod graceful_degradation;
170pub mod gradient_clipping;
171pub mod gradient_hooks;
172pub mod hardware_acceleration;
173pub mod health_diagnostics;
174pub mod higher_order_gradients;
175pub mod integration_patterns;
176pub mod intelligent_chunking;
177pub mod interactive_debugger;
178pub mod neural_architecture_search;
179pub mod neural_ode;
180pub mod parallel_gradient;
181pub mod performance_regression;
182pub mod profiling_debugging_integration;
183pub mod progress_reporting;
184pub mod quantum_autograd;
185pub mod raii_resources;
186pub mod regression_testing;
187pub mod scirs2_integration_testing;
188pub mod specialized_gradient_libs;
189pub mod stress_testing;
190
191pub mod ad_framework_compatibility;
193
194pub mod stable_api;
196
197pub mod examples;
199
200pub mod audit_logging;
202pub mod capacity_planning;
203pub mod error_rate_monitoring;
204pub mod flamegraph_generation;
205pub mod gradient_tracing;
206pub mod operation_cost_analysis;
207pub mod performance_dashboard;
208
209pub use crate::error_handling::{AutogradError, AutogradResult};
211
212pub use crate::autograd_traits::{
213 AutogradTensor, AutogradTensorFactory, BackwardTensor, GradientAccumulation,
214};
215
216pub use crate::global_adapter::{
217 backward_global, create_gradient_tensor, get_global_adapter, get_gradient_global,
218};
219
220pub use crate::grad_mode::{
221 is_grad_enabled, pop_grad_enabled, push_grad_enabled, set_grad_enabled, with_grad_mode,
222};
223
224pub use crate::guards::{enable_grad, no_grad, EnableGradGuard, GradModeGuard, NoGradGuard};
225
226pub use crate::gradient_storage::{
227 get_gradient_storage, GlobalGradientStorage, GradientStorage, HashMapGradientStorage,
228};
229
230pub use crate::variable_env::{
231 clear_variable_env, get_or_create_variable_env, handle_inplace_operation,
232 is_variable_env_initialized, validate_inplace_operation, with_variable_env, InplaceConfig,
233 InplaceStrategy,
234};
235
236pub use crate::complex_ops::backward_complex;
237pub use crate::pytorch_compat::backward;
238
239pub use crate::anomaly_detection::{
240 detect_complex_anomalies,
241 recovery::{
242 AnomalyRecoverySystem, RecoveryConfig, RecoveryResult, RecoveryStats, RecoveryStrategy,
243 },
244};
245
246pub use crate::scirs2_integration::{GradientTensor, SciRS2AutogradAdapter};
247
248pub use crate::auto_tuning::{
249 AppliedOptimization, AutoTuningController, OptimizationType, ParameterValue,
250 PerformanceSnapshot, TuningConfig, TuningRecommendation, TuningStatistics,
251};
252
253pub use crate::error_diagnostics::{
254 DiagnosticRecommendation, DiagnosticReport, DiagnosticStatus, DiagnosticsConfig,
255 ErrorCorrelation, ErrorDiagnosticsSystem, ErrorPattern, LabeledErrorEvent, MLAnalysisResult,
256 MLPatternPrediction, MLPatternRecognitionSystem, MLSystemConfig, PatternLabel, SeverityLevel,
257 TemporalContext,
258};
259
/// Full crate version string, captured from `Cargo.toml` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Major version component.
// NOTE(review): the numeric components below are hard-coded; confirm they are
// kept in sync with CARGO_PKG_VERSION when the crate version is bumped.
pub const VERSION_MAJOR: u32 = 0;
/// Minor version component.
pub const VERSION_MINOR: u32 = 1;
/// Patch version component.
pub const VERSION_PATCH: u32 = 0;
265
266use torsh_core::error::{Result, TorshError};
268
/// Version stamp used to detect conflicting in-place mutations of a tensor.
///
/// Two stamps are "compatible" when they refer to the same tensor id *and*
/// carry the same version number; every in-place write bumps the version,
/// invalidating any older snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TensorVersion {
    /// Monotonically increasing mutation counter (starts at 0).
    pub version: usize,
    /// Identifier of the tensor this stamp belongs to.
    pub tensor_id: usize,
}

impl TensorVersion {
    /// Creates the initial (version 0) stamp for `tensor_id`.
    pub fn new(tensor_id: usize) -> Self {
        TensorVersion {
            tensor_id,
            version: 0,
        }
    }

    /// Bumps the version in place and returns a copy of the updated stamp.
    pub fn increment(&mut self) -> Self {
        self.version += 1;
        *self
    }

    /// Returns `true` when `other` refers to the same tensor at the same
    /// version.
    pub fn is_compatible_with(&self, other: &TensorVersion) -> bool {
        (self.tensor_id, self.version) == (other.tensor_id, other.version)
    }
}
301
/// Process-wide counter backing [`new_tensor_id`].
static TENSOR_ID_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

/// Returns a process-unique tensor identifier.
///
/// Thread-safe: ids are handed out via an atomic fetch-and-increment, so no
/// two calls ever observe the same value.
pub fn new_tensor_id() -> usize {
    use std::sync::atomic::Ordering;
    TENSOR_ID_COUNTER.fetch_add(1, Ordering::SeqCst)
}
309
310pub mod inplace_versioning {
315 use super::*;
316
317 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
319 pub enum VersionConflictStrategy {
320 Error,
322 Warn,
324 CopyOnWrite,
326 Allow,
328 }
329
330 impl Default for VersionConflictStrategy {
331 fn default() -> Self {
332 Self::Warn
333 }
334 }
335
336 pub fn check_version_compatibility(
338 current: &TensorVersion,
339 expected: &TensorVersion,
340 strategy: VersionConflictStrategy,
341 operation_name: &str,
342 ) -> Result<()> {
343 if !current.is_compatible_with(expected) {
344 let message = format!(
345 "Version conflict in in-place operation '{}': expected version {} but found {}",
346 operation_name, expected.version, current.version
347 );
348
349 match strategy {
350 VersionConflictStrategy::Error => {
351 return Err(TorshError::AutogradError(message));
352 }
353 VersionConflictStrategy::Warn => {
354 tracing::warn!("{}", message);
355 }
356 VersionConflictStrategy::CopyOnWrite => {
357 tracing::info!("Triggering copy-on-write for: {}", message);
358 }
359 VersionConflictStrategy::Allow => {
360 tracing::debug!("Allowing version conflict: {}", message);
361 }
362 }
363 }
364 Ok(())
365 }
366
367 pub fn update_version_after_inplace(
369 version: &mut TensorVersion,
370 operation_name: &str,
371 ) -> TensorVersion {
372 let old_version = *version;
373 let new_version = version.increment();
374 tracing::debug!(
375 "Updated tensor {} version from {} to {} after in-place operation '{}'",
376 version.tensor_id,
377 old_version.version,
378 new_version.version,
379 operation_name
380 );
381 new_version
382 }
383}
384
/// Convenience prelude: `use <crate>::prelude::*;` pulls in the most commonly
/// used traits, guard types, and global accessors in one import.
///
/// Re-exports are grouped by subsystem; items that would collide are renamed
/// with `as` (e.g. `ADFrameworkAdapter`, `HardwareOptimizationLevel`).
pub mod prelude {
    // Core autograd surface: tensor traits, global adapter, grad-mode
    // switches, storage, guards, and the variable environment.
    pub use crate::autograd_traits::{AutogradTensor, BackwardTensor, GradientAccumulation};
    pub use crate::global_adapter::{
        backward_global, create_gradient_tensor, get_global_adapter, get_gradient_global,
    };
    pub use crate::grad_mode::{is_grad_enabled, set_grad_enabled, with_grad_mode};
    pub use crate::gradient_storage::{get_gradient_storage, GradientStorage};
    pub use crate::guards::{enable_grad, no_grad, EnableGradGuard, NoGradGuard};
    pub use crate::variable_env::{with_variable_env, InplaceStrategy};
    pub use crate::{new_tensor_id, TensorVersion};

    // Research/architecture tooling.
    pub use crate::neural_architecture_search::{
        ConvOperation, DARTSArchitecture, DARTSPruningResult, IdentityOperation, MixedOperation,
        ProgressiveDARTS, ProgressiveStageInfo, SamplingStrategy, SearchableOperation,
        ZeroOperation, DARTS,
    };

    pub use crate::neural_ode::{
        AdjointMethod, AdjointSolution, IntegrationMethod, NeuralODE, NeuralODELayer, ODESolver,
        ODESolverConfig, ODESystem,
    };

    // Performance backends: parallel, GPU, and SIMD gradient computation plus
    // chunking heuristics.
    pub use crate::parallel_gradient::{
        configure_global_parallel, get_global_parallel_computer, get_global_parallel_computer_mut,
        ParallelConfig, ParallelGradientComputer, ParallelStats,
    };

    pub use crate::gpu_gradient::{
        get_global_gpu_computer, initialize_global_gpu, is_global_gpu_available, ActivationType,
        GpuBackend, GpuConfig, GpuGradientComputer, GpuStats,
    };

    pub use crate::simd_gradient::{
        configure_global_simd, get_global_simd_computer, get_global_simd_computer_mut,
        SimdCapability, SimdConfig, SimdGradientComputer, SimdStats,
    };

    pub use crate::intelligent_chunking::{
        configure_global_chunker, get_global_chunker, get_global_chunker_mut, ChunkingConfig,
        ChunkingStats, ChunkingStrategy, IntelligentChunker,
    };

    // Gradient post-processing: clipping, higher-order derivatives, hooks.
    pub use crate::gradient_clipping::{
        clip_gradients_global, configure_global_clipper, get_global_clipper,
        get_global_clipper_mut, ClippingStats, ClippingStrategy, GradientClipper,
    };

    pub use crate::higher_order_gradients::{
        configure_global_higher_order, get_global_higher_order, get_global_higher_order_mut,
        ComputationMode, GradientOrder, HigherOrderConfig, HigherOrderGradient, HigherOrderStats,
    };

    pub use crate::gradient_hooks::{
        get_global_hook_manager, get_global_hook_manager_mut, register_global_hook, GradientHook,
        GradientHookManager, HookContext, HookPriority, HookStats, HookType,
    };

    // Robustness: error recovery, edge-case handling, graceful degradation,
    // exception safety.
    pub use crate::automatic_error_recovery::{
        get_global_recovery, recover_from_error, AutomaticErrorRecovery, CorrectiveTransform,
        RecoveryAction, RecoveryStrategy, TransientFailureType,
    };
    pub use crate::with_error_recovery;

    pub use crate::edge_case_handling::{
        get_global_edge_case_handler, handle_tensor_edge_cases, validate_tensor_shapes,
        EdgeCaseHandler, EdgeCaseStrategy, EdgeCaseTransformation, EdgeCaseType, TensorInfo,
    };

    pub use crate::quantum_autograd::{
        Complex, Observable, PauliX, QuantumCircuit, QuantumExpectationValue, QuantumGate,
        QuantumState, QuantumStateGradient, Qubit, RotationY, VQEResult, CNOT, VQE,
    };

    // Verification and testing utilities.
    pub use crate::cross_framework_verification::{
        get_global_verifier, initialize_verification_frameworks, ComparisonTolerance,
        CrossFrameworkVerifier, FrameworkAdapter, GradientComparisonResult, GradientData,
        MockPyTorchAdapter, SupportedFramework, TorshFrameworkAdapter, VerificationReport,
    };

    pub use crate::regression_testing::{
        get_global_regression_tester, GradientRegressionTester, GradientTestCase,
        RegressionTestResult, RegressionTestStatistics,
    };

    pub use crate::exception_safety::{
        get_global_executor, AutogradTransaction, ComputationGraphGuard, ExceptionSafeExecutor,
        ExceptionSafetyAnalyzer, ExceptionSafetyLevel, GradientStorageGuard, ResourceGuard,
        SafetyViolation, SafetyViolationReport, TransactionOperation,
    };

    pub use crate::graceful_degradation::{
        get_global_degradation_manager, DegradationEvent, DegradationStatistics,
        DegradationStrategy, FallbackFunction, GracefulDegradationManager, MatrixMultiplyFallback,
        OperationCategory, UnsupportedOperationInfo, UnsupportedReason,
    };

    pub use crate::scirs2_integration_testing::{
        get_global_integration_tester, run_scirs2_integration_tests, CompatibilitySummary,
        PerformanceSummary, SciRS2IntegrationTestCase, SciRS2IntegrationTestSuite,
        SciRS2IntegrationTester, SciRS2TestResult, SciRS2Version, TestCategory,
        TestCategoryResults,
    };

    pub use crate::integration_patterns::{
        IntegrationDocumentation, IntegrationPatterns, MigrationGuide, MigrationScenario, Pattern,
        PatternCategory, PatternDocumentation, TroubleshootingGuide, TroubleshootingIssue,
    };

    // External acceleration and library integrations.
    pub use crate::blas_integration::{
        blas_dot, blas_gemm, blas_gemv, get_global_blas_manager, BlasConfig, BlasImplementation,
        BlasManager, BlasOperation, BlasPerformanceReport, BlasProvider, PureRustBlasProvider,
    };

    pub use crate::specialized_gradient_libs::{
        get_global_specialized_manager, BenchmarkReport, CasADiLibrary, ComputationResult,
        Function, GradientComputationType, LibraryUsageReport, QuadraticFunction, SparseGradient,
        SpecializedGradientLibrary, SpecializedLibConfig, SpecializedLibrary,
        SpecializedLibraryManager,
    };

    pub use crate::custom_backends::{
        get_active_backend, get_global_backend_registry, AutogradBackend, BackendCapability,
        BackendConfig, BackendInfo, BackendRegistry, BackendTensor, CustomOperation, DataType,
        DeviceConfig, DeviceType, GradFunction, OperationContext, OptimizationLevel,
        PerformanceStats, ReferenceBackend,
    };

    pub use crate::hardware_acceleration::{
        get_global_acceleration_manager, AccelerationConfig, AcceleratorBenchmarkReport,
        AcceleratorType, AcceleratorUsageReport, Conv2DParams, CudaAccelerator, DeviceStats,
        HardwareAccelerationManager, HardwareAccelerator, HardwareCapability, HardwareDevice,
        HardwareMemoryHandle, MetalAccelerator, OptimizationLevel as HardwareOptimizationLevel,
        PrecisionPreference,
    };

    pub use crate::profiling_debugging_integration::{
        get_global_profiling_debugging_manager, AnalysisCapability, CPUProfile, DebuggingConfig,
        DebuggingReport, DebuggingSession, DebuggingTool, ExternalDebugger, ExternalProfiler,
        GPUProfile, GdbDebugger, Hotspot, IntegrationConfig, IntegrationReport, MemoryError,
        MemoryProfile, PerfProfiler, ProfilingConfig, ProfilingDebuggingManager, ProfilingReport,
        ProfilingSession, ProfilingTool, StackTrace, ThreadError,
    };

    // Cross-framework compatibility (PyTorch et al.); `FrameworkAdapter` is
    // renamed to avoid clashing with the verification variant above.
    pub use crate::ad_framework_compatibility::{
        check_framework_compatibility, convert_tensor, get_global_compatibility_manager,
        migrate_model, ADFramework, ADFrameworkCompatibilityManager, AutomationLevel,
        CompatibilityLevel, CompatibilityReport, CustomOperationDefinition, EffortLevel,
        FrameworkAdapter as ADFrameworkAdapter, FrameworkCapabilities, FrameworkTensor,
        MigrationCapability, MigrationData, MigrationOperation, MigrationPlan, MigrationResult,
        MigrationStep, PerformanceComparison, PerformanceMetrics, PyTorchAdapter, PyTorchTensor,
        RequiredTransformation, UniversalDataType, UniversalOperation, UniversalTensor,
        ValidationResult,
    };
}
564
565pub mod accumulate {
567 use super::*;
568 use crate::autograd_traits::AutogradTensor;
569 use num_traits::Float;
570
571 pub fn accumulate_gradient_safe<T>(
573 existing: &mut dyn AutogradTensor<T>,
574 new_grad: &dyn AutogradTensor<T>,
575 overflow_threshold: Option<T>,
576 ) -> Result<()>
577 where
578 T: torsh_core::dtype::TensorElement
579 + Float
580 + Clone
581 + std::fmt::Debug
582 + std::fmt::Display
583 + Send
584 + Sync,
585 {
586 if existing.shape() != new_grad.shape() {
588 return Err(TorshError::AutogradError(format!(
589 "Shape mismatch in gradient accumulation: {:?} vs {:?}",
590 existing.shape(),
591 new_grad.shape()
592 )));
593 }
594
595 let existing_data = existing.data();
597 let new_data = new_grad.data();
598
599 if let Some(threshold) = overflow_threshold {
601 for (existing_val, new_val) in existing_data.iter().zip(new_data.iter()) {
602 let sum = *existing_val + *new_val;
603 if sum.abs() > threshold {
604 return Err(TorshError::AutogradError(format!(
605 "Gradient accumulation overflow detected: {} + {} = {} > {}",
606 existing_val, new_val, sum, threshold
607 )));
608 }
609 }
610 }
611
612 tracing::debug!(
615 "Gradient accumulation requested for tensor with shape {:?}",
616 existing.shape()
617 );
618
619 Ok(())
620 }
621
622 pub fn check_accumulation_overflow<T>(val1: T, val2: T, threshold: T) -> bool
624 where
625 T: Float + PartialOrd,
626 {
627 let sum = val1 + val2;
628 sum.abs() > threshold
629 }
630}
631
632pub mod clip {
634 use super::*;
635 use crate::autograd_traits::AutogradTensor;
636 use num_traits::Float;
637
638 pub fn clip_grad_norm<T>(
640 gradients: &[&dyn AutogradTensor<T>],
641 max_norm: T,
642 norm_type: f32,
643 ) -> Result<T>
644 where
645 T: torsh_core::dtype::TensorElement
646 + Float
647 + Clone
648 + std::fmt::Debug
649 + Send
650 + Sync
651 + std::fmt::Display,
652 f32: From<T>,
653 {
654 if gradients.is_empty() {
655 return Ok(<T as num_traits::Zero>::zero());
656 }
657
658 let mut total_norm = <T as num_traits::Zero>::zero();
660
661 for grad in gradients {
662 let data = grad.data();
663 for &val in data.iter() {
664 if norm_type == 2.0 {
665 total_norm = total_norm + val * val;
666 } else if norm_type == 1.0 {
667 total_norm = total_norm + val.abs();
668 } else {
669 let abs_val = val.abs();
670 total_norm = total_norm
671 + abs_val.powf(T::from(norm_type).unwrap_or(<T as num_traits::One>::one()));
672 }
673 }
674 }
675
676 if norm_type == 2.0 {
677 total_norm = total_norm.sqrt();
678 } else if norm_type != 1.0 {
679 total_norm =
680 total_norm.powf(T::from(1.0 / norm_type).unwrap_or(<T as num_traits::One>::one()));
681 }
682
683 tracing::debug!("Calculated gradient norm: {:?}", total_norm);
684
685 let clip_coef =
687 max_norm / (total_norm + T::from(1e-6).unwrap_or(<T as num_traits::One>::one()));
688 let clip_coef = clip_coef.min(<T as num_traits::One>::one());
689
690 tracing::debug!("Gradient clipping coefficient: {:?}", clip_coef);
691
692 Ok(total_norm)
693 }
694
695 pub fn clip_grad_value<T>(
697 gradient: &dyn AutogradTensor<T>,
698 min_value: T,
699 max_value: T,
700 ) -> Result<Vec<T>>
701 where
702 T: torsh_core::dtype::TensorElement + Float + Clone + std::fmt::Debug + Send + Sync,
703 {
704 let data = gradient.data();
705 let clipped: Vec<T> = data
706 .iter()
707 .map(|&val| val.max(min_value).min(max_value))
708 .collect();
709
710 Ok(clipped)
711 }
712}
713
714pub mod forward_mode {
716
717 use num_traits::Float;
718
719 #[derive(Debug, Clone, Copy, PartialEq)]
721 pub struct Dual<T> {
722 pub value: T,
724 pub derivative: T,
726 }
727
728 impl<T: Float> Dual<T> {
729 pub fn new(value: T, derivative: T) -> Self {
731 Self { value, derivative }
732 }
733
734 pub fn variable(value: T) -> Self {
736 Self::new(value, <T as num_traits::One>::one())
737 }
738
739 pub fn constant(value: T) -> Self {
741 Self::new(value, T::zero())
742 }
743 }
744
745 impl<T: Float> std::ops::Add for Dual<T> {
746 type Output = Self;
747
748 fn add(self, rhs: Self) -> Self::Output {
749 Self::new(self.value + rhs.value, self.derivative + rhs.derivative)
750 }
751 }
752
753 impl<T: Float> std::ops::Mul for Dual<T> {
754 type Output = Self;
755
756 fn mul(self, rhs: Self) -> Self::Output {
757 Self::new(
758 self.value * rhs.value,
759 self.derivative * rhs.value + self.value * rhs.derivative,
760 )
761 }
762 }
763
764 pub fn forward_derivative<T, F>(input: T, f: F) -> (T, T)
766 where
767 T: Float + Clone,
768 F: Fn(Dual<T>) -> Dual<T>,
769 {
770 let dual_input = Dual::variable(input);
771 let dual_output = f(dual_input);
772 (dual_output.value, dual_output.derivative)
773 }
774}
775
#[cfg(test)]
mod tests {
    use super::*;

    /// `increment` bumps the shared counter and returns the updated copy.
    #[test]
    fn test_tensor_version() {
        let mut v = TensorVersion::new(42);
        assert_eq!((v.tensor_id, v.version), (42, 0));

        let bumped = v.increment();
        assert_eq!(bumped.version, 1);
        assert_eq!(v.version, 1);
    }

    /// Compatibility requires matching tensor id *and* version.
    #[test]
    fn test_tensor_version_compatibility() {
        let a = TensorVersion::new(1);
        let b = TensorVersion::new(1);
        let c = TensorVersion::new(2);

        assert!(a.is_compatible_with(&b));
        assert!(!a.is_compatible_with(&c));
    }

    /// The global id counter never hands out the same id twice.
    #[test]
    fn test_unique_tensor_ids() {
        let ids = [new_tensor_id(), new_tensor_id(), new_tensor_id()];

        assert_ne!(ids[0], ids[1]);
        assert_ne!(ids[1], ids[2]);
        assert_ne!(ids[0], ids[2]);
    }

    /// Dual-number add/mul follow the sum and product rules.
    #[test]
    fn test_dual_number_arithmetic() {
        use forward_mode::Dual;

        let x = Dual::new(2.0, 1.0);
        let y = Dual::new(3.0, 0.0);

        let sum = x + y;
        assert_eq!((sum.value, sum.derivative), (5.0, 1.0));

        let product = x * y;
        assert_eq!((product.value, product.derivative), (6.0, 3.0));
    }

    /// d/dx x² evaluated at x = 3 is 6.
    #[test]
    fn test_forward_derivative() {
        use forward_mode::forward_derivative;

        let (value, derivative) = forward_derivative(3.0, |x| x * x);
        assert_eq!(value, 9.0);
        assert_eq!(derivative, 6.0);
    }
}