§SciRS2 Autograd - Automatic Differentiation for Rust
scirs2-autograd provides PyTorch-style automatic differentiation with lazy tensor evaluation, enabling efficient gradient computation for scientific computing and deep learning.
§Key Features
- Reverse-mode Autodiff: Efficient backpropagation for neural networks
- Lazy Evaluation: Build computation graphs, evaluate only when needed
- Higher-order Gradients: Compute derivatives of derivatives
- Neural Network Ops: Optimized operations for deep learning
- Optimizers: Adam, SGD, RMSprop with state management
- Model Persistence: Save and load trained models
- Variable Management: Namespace-based variable organization
§Installation
[dependencies]
scirs2-autograd = { version = "0.1.5", features = ["blas"] }
§BLAS Acceleration (Recommended)
For fast matrix operations, enable BLAS (uses OxiBLAS, a pure-Rust implementation):
[dependencies]
scirs2-autograd = { version = "0.1.5", features = ["blas"] }
§Quick Start
§Basic Differentiation
Compute gradients of a simple function:
use scirs2_autograd as ag;
use ag::tensor_ops as T;
ag::run(|ctx: &mut ag::Context<f64>| {
    // Define variables
    let x = ctx.placeholder("x", &[]);
    let y = ctx.placeholder("y", &[]);

    // Build computation graph: z = 2x² + 3y + 1
    let z = 2.0 * x * x + 3.0 * y + 1.0;

    // Compute dz/dy
    let dz_dy = &T::grad(&[z], &[y])[0];
    println!("dz/dy = {:?}", dz_dy.eval(ctx)); // => 3.0

    // Compute dz/dx (feed x = 2)
    let dz_dx = &T::grad(&[z], &[x])[0];
    let x_val = scirs2_core::ndarray::arr0(2.0);
    let result = ctx.evaluator()
        .push(dz_dx)
        .feed(x, x_val.view().into_dyn())
        .run()[0]
        .clone();
    println!("dz/dx at x=2: {:?}", result); // => 8.0

    // Higher-order: d²z/dx²
    let d2z_dx2 = &T::grad(&[dz_dx], &[x])[0];
    println!("d²z/dx² = {:?}", d2z_dx2.eval(ctx)); // => 4.0
});

§Neural Network Training
Train a multi-layer perceptron for MNIST:
use scirs2_autograd as ag;
use ag::optimizers::adam::Adam;
use ag::tensor_ops::*;
use ag::prelude::*;
// Create variable environment
let mut env = ag::VariableEnvironment::new();
let mut rng = ag::ndarray_ext::ArrayRng::<f32>::default();
// Initialize network weights
env.name("w1").set(rng.glorot_uniform(&[784, 256]));
env.name("b1").set(ag::ndarray_ext::zeros(&[1, 256]));
env.name("w2").set(rng.glorot_uniform(&[256, 10]));
env.name("b2").set(ag::ndarray_ext::zeros(&[1, 10]));
// Create Adam optimizer
let var_ids = env.default_namespace().current_var_ids();
let adam = Adam::default("adam", var_ids, &mut env);
// Training loop
// Training loop
for epoch in 0..10 {
    env.run(|ctx| {
        // Define computation graph
        let x = ctx.placeholder("x", &[-1, 784]);
        let y = ctx.placeholder("y", &[-1]);
        let w1 = ctx.variable("w1");
        let b1 = ctx.variable("b1");
        let w2 = ctx.variable("w2");
        let b2 = ctx.variable("b2");

        // Forward pass: x -> hidden -> output
        let hidden = relu(matmul(x, w1) + b1);
        let logits = matmul(hidden, w2) + b2;

        // Loss: cross-entropy
        let loss = reduce_mean(
            sparse_softmax_cross_entropy(logits, &y),
            &[0],
            false,
        );

        // Backpropagation
        let params = &[w1, b1, w2, b2];
        let grads = &grad(&[loss], params);

        // Update weights (requires actual data feeding)
        // let mut feeder = ag::Feeder::new();
        // feeder.push(x, x_batch).push(y, y_batch);
        // adam.update(params, grads, ctx, feeder);
    });
}

§Custom Operations
Define custom differentiable operations:
use scirs2_autograd as ag;
use ag::tensor_ops::*;
ag::run::<f64, _, _>(|ctx| {
    let x = ones(&[3, 4], ctx);

    // Apply custom transformations using tensor.map()
    let y = x.map(|arr| arr.mapv(|v: f64| v * 2.0 + 1.0));

    // Hooks for debugging
    let z = x.showshape(); // Print shape
    let w = x.raw_hook(|arr| println!("Tensor value: {}", arr));
});

§Core Concepts
§Tensors
Lazy-evaluated multi-dimensional arrays with automatic gradient tracking:
use scirs2_autograd as ag;
use ag::tensor_ops::*;
use ag::prelude::*;
ag::run::<f64, _, _>(|ctx| {
    // Create tensors
    let a = zeros(&[2, 3], ctx);           // All zeros
    let b = ones(&[2, 3], ctx);            // All ones
    let c = ctx.placeholder("c", &[2, 3]); // Placeholder (fill later)
    let d = ctx.variable("d");             // Trainable variable
});

§Computation Graphs
Build graphs of operations, evaluate lazily:
use scirs2_autograd as ag;
use ag::tensor_ops as T;
ag::run::<f64, _, _>(|ctx| {
    let x = ctx.placeholder("x", &[2, 2]);
    let y = ctx.placeholder("y", &[2, 2]);

    // Build graph (no computation yet)
    let z = T::matmul(x, y);
    let w = T::sigmoid(z);

    // Evaluate when needed
    // let result = w.eval(ctx);
});

§Gradient Computation
Reverse-mode automatic differentiation:
use scirs2_autograd as ag;
use ag::tensor_ops as T;
ag::run(|ctx| {
    let x = ctx.placeholder("x", &[]);
    let y = x * x * x; // y = x³

    // Compute dy/dx = 3x²
    let dy_dx = &T::grad(&[y], &[x])[0];

    // Evaluate at x=2: 3 * 2² = 12
    let x_val = scirs2_core::ndarray::arr0(2.0);
    let grad_val = ctx.evaluator()
        .push(dy_dx)
        .feed(x, x_val.view().into_dyn())
        .run()[0]
        .clone();
});

§Available Operations
§Basic Math
- Arithmetic: +, -, *, /, pow
- Comparison: equal, not_equal, greater, less
- Reduction: sum, mean, max, min
§Neural Network Ops
- Activations: relu, sigmoid, tanh, softmax, gelu
- Pooling: max_pool2d, avg_pool2d
- Convolution: conv2d, conv2d_transpose
- Normalization: batch_norm, layer_norm
- Dropout: dropout
§Matrix Operations
- matmul - Matrix multiplication
- transpose - Matrix transpose
- reshape - Change tensor shape
- concat - Concatenate tensors
- split - Split tensor
§Loss Functions
- sparse_softmax_cross_entropy - Classification loss
- sigmoid_cross_entropy - Binary classification
- softmax_cross_entropy - Multi-class loss
§Optimizers
Built-in optimization algorithms:
- Adam: Adaptive moment estimation (recommended)
- SGD: Stochastic gradient descent with momentum
- RMSprop: Root mean square propagation
- Adagrad: Adaptive learning rates
§Model Persistence
Save and load trained models:
use scirs2_autograd as ag;

let mut env = ag::VariableEnvironment::<f64>::new();

// After training...
env.save("model.safetensors")?;

// Later, load the model
let env = ag::VariableEnvironment::<f64>::load("model.safetensors")?;

§Performance
scirs2-autograd is designed for efficiency:
- Lazy Evaluation: Build graphs without computation overhead
- Minimal Allocations: Reuse memory where possible
- BLAS Integration: Fast matrix operations via OxiBLAS (pure Rust)
- Zero-copy: Efficient data handling with ndarray views
Typical training speed: 0.11 sec/epoch for an MNIST MLP (2.7 GHz Intel Core i5)
§Integration
- scirs2-neural: High-level neural network layers
- scirs2-linalg: Matrix operations
- scirs2-optimize: Optimization algorithms
- ndarray: Core array library (re-exported)
§Comparison with PyTorch
| Feature | PyTorch | scirs2-autograd |
|---|---|---|
| Autodiff | ✅ | ✅ |
| Dynamic Graphs | ✅ | ✅ |
| GPU Support | ✅ | ❌ (v0.2.0) |
| Type Safety | ❌ | ✅ |
| Memory Safety | ⚠️ | ✅ |
| Pure Rust | ❌ | ✅ |
§v0.2.0 Features
- GPU Acceleration: CUDA, Metal, OpenCL, WebGPU backends
- Higher-Order Derivatives: Hessian-vector products, full Jacobians
- Memory Optimization: Advanced checkpointing, memory pooling
- Graph Optimization: CSE, operation fusion
- Distributed Training: Data/model parallelism
- Symbolic Differentiation: Analytical derivatives
§Version
Current version: 0.2.0 (Development - Target Release February 2026)
Re-exports§
pub use crate::ndarray_ext::{array_gen, NdArray, NdArrayView, NdArrayViewMut};
pub use crate::evaluation::{Evaluator, Feeder};
pub use crate::tensor::Tensor;
pub use crate::error::{AutogradError, EvalError, OpError, Result};
pub use crate::graph::{run, Context};
pub use crate::high_performance::{memory_efficient_grad_accumulation, parallel_gradient_computation, simd_backward_pass, ultra_backward_pass};
pub use crate::variable::{AutogradTensor, SafeVariable, SafeVariableEnvironment, VariableEnvironment};
pub use crate::optimizers::{FunctionalAdam, FunctionalOptimizer, FunctionalSGD};
pub use crate::forward_mode::{gradient_forward, hessian, hessian_vector_product, jacobian_forward, jvp, DualNumber};
pub use crate::transforms::{batched_value_and_grad, check_grad, compose, iterate, numerical_jacobian, pmap, scan, stop_gradient, stop_gradient_1d, stop_gradient_dual, vmap, Checkpoint, JitHint};
pub use crate::custom_gradient::{custom_op, custom_unary_op, detach, gradient_reversal, scale_gradient, selective_stop_gradient, CustomGradientOp, ScaleGradient, SelectiveStopGradient};
pub use crate::gradient_accumulation::{GradientAccumulator, GradientStats, VirtualBatchAccumulator};
pub use crate::higher_order::extensions::{efficient_second_order_grad, fisher_diagonal, fisher_information, fisher_information_forward, hessian_diagonal, hessian_diagonal_forward, laplacian, laplacian_forward};
pub use crate::jacobian_ops::{batch_jacobian, jacobian_auto, jacobian_check, jacobian_diagonal, jacobian_reverse, jvp_forward, jvp_graph, numerical_jacobian as numerical_jacobian_fd, vjp_multi, vjp_reverse};
pub use crate::visualization::{graph_summary, graph_to_dot, graph_to_json, graph_to_mermaid, GraphStats};
pub use crate::scheduling::{build_memory_plan, critical_path, forward_schedule, level_decomposition, memory_optimal_schedule, parallel_analysis, reverse_schedule, validate_schedule, work_stealing_schedule, CriticalPath, MemoryPlan, ParallelAnalysis, Schedule, ScheduleDirection, WorkStealingSchedule};
pub use crate::graph_transforms::{analyse_graph, detect_algebraic_simplifications, detect_cse, detect_fusions, find_dead_nodes, find_foldable_constants, infer_shapes, AlgebraicSimplification, FusionGroup, FusionKind, SimplificationRule, TransformReport};
pub use crate::autodiff_enhanced::{binomial_checkpoint_plan, build_rematerialization_plan, plan_jacobian_computation, select_jacobian_mode, solve_implicit_diff, sqrt_checkpoint_plan, uniform_checkpoint_plan, CheckpointPlan, CheckpointStrategy, DiffRuleRegistry, ImplicitDiffConfig, ImplicitDiffResult, JacobianMode, MixedModeJacobianPlan, RematerializationDecision, RematerializationPolicy};
pub use crate::profiling::{analyse_gradient_flow, classify_gradient, count_ops, estimate_bandwidth, estimate_flops, graph_complexity, has_gradient_issues, profile_graph, total_flops, BandwidthEstimate, EstimateConfidence, FlopEstimate, GradientFlowStats, GradientHealth, GradientThresholds, GraphComplexity, OpCounts, OpTiming, OperationProfiler, ProfilingReport};
Modules§
- autodiff_enhanced - Enhanced automatic differentiation strategies
- custom_gradient - Custom gradient rules for automatic differentiation
- distributed - Distributed automatic differentiation
- error - Error types for the autograd module
- error_helpers - Helper functions to eliminate repetitive expect() patterns throughout the codebase.
- evaluation
- forward_mode - Forward-mode automatic differentiation via dual numbers
- gpu - GPU-accelerated automatic differentiation
- gradient_accumulation - Gradient accumulation for large effective batch sizes
- gradient_clipping - Gradient clipping utilities
- graph
- graph_transforms - Graph transformation passes for computation graph optimisation
- high_performance - High-Performance Autograd APIs for ToRSh Integration
- higher_order - Higher-order automatic differentiation
- hooks - You can register hooks on ag::Tensor objects for debugging.
- integration - Integration utilities for working with other SciRS2 modules
- jacobian_ops - Enhanced Jacobian computation for automatic differentiation
- jax - JAX-style functional transformations re-exported for convenience.
- memory_pool - Tensor memory pool for reducing allocation pressure during training.
- ndarray - Complete ndarray re-export for the SciRS2 ecosystem
- ndarray_ext - A small extension of ndarray
- onnx - ONNX ecosystem interoperability
- op - Implementing differentiable operations
- optimization - Graph optimization and expression simplification for computation graphs
- optimizers - A collection of gradient descent optimizers
- parallel - Parallel processing and thread pool optimizations
- prelude - Exports useful trait implementations
- profiling - Profiling and debugging tools for computation graphs
- rand - Random number generation for the SciRS2 ecosystem
- schedulers - Learning rate schedulers
- scheduling - Operator scheduling for computation graphs
- symbolic - Symbolic differentiation capabilities
- tensor
- tensor_ops - A collection of functions for manipulating autograd::Tensor objects
- test_helper - Provides helper functions for testing.
- testing - Numerical stability testing framework for automatic differentiation
- tracing - Enhanced tracing and recording capabilities for computation graphs
- transforms - JAX-inspired function transformation primitives
- validation - Domain validation utilities for mathematical operations.
- variable - Variable and namespace
- visualization - Computation graph visualization and analysis tools
Macros§
- symexpr - Helper macro for building symbolic expressions

Traits§
- Float - A primitive type in this crate, which is actually a decorated scirs2_core::numeric::Float.

Functions§
- rand - Convenience function to generate a random value of the inferred type