Struct Tensor 

pub struct Tensor { /* private fields */ }

High-performance multi-dimensional tensor with automatic differentiation support

The core data structure for machine learning operations, designed for maximum performance with zero-cost abstractions. Supports arbitrary dimensionality, SIMD optimization, gradient tracking, device placement, and natural mathematical expressions through operator overloading.

§Key Features

  • Raw Pointer Storage: Zero-overhead memory access for maximum performance
  • SIMD Optimization: AVX2 alignment and vectorized operations
  • Memory Efficiency: Optimized alignment strategies for different tensor sizes
  • GradTrack Integration: Built-in gradient tracking and computation
  • Device Support: CPU and future CUDA device placement
  • View Tensors: Zero-copy tensor views with shared memory management
  • Thread Safety: Send + Sync implementation for concurrent usage
  • Operator Overloading: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)

§Memory Layout

Tensors use row-major memory layout with size-dependent alignment:

  • Small tensors (≤8 elements): 16-byte SSE alignment
  • Medium tensors (9-1024 elements): 32-byte AVX2 alignment
  • Large tensors (>1024 elements): 64-byte cache-line alignment
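
The medium class can be observed through the public API (a minimal sketch, assuming data() exposes the start of the underlying buffer as in the examples below; the thresholds themselves are internal policy):

use train_station::Tensor;

// 32 x 32 = 1024 elements falls in the medium class, so the buffer
// should be at least 32-byte aligned for AVX2 loads.
let tensor = Tensor::zeros(vec![32, 32]);
assert_eq!(tensor.data().as_ptr() as usize % 32, 0);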

§Performance Characteristics

  • Memory Overhead: ~64 bytes per tensor (excluding data)
  • SIMD Ready: Properly aligned for vectorized operations
  • Cache Friendly: Optimized memory layout for CPU cache hierarchies
  • Zero-Cost Views: View tensors share memory without copying
  • Thread Safe: Atomic ID generation and lock-free operations
  • Operator Performance: Zero-cost operator overloading for mathematical expressions
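
As a sketch of the zero-cost view claim (assuming view() takes signed dimension sizes, as it does in the repository examples further below):

use train_station::Tensor;

// Reshaping through a view shares the underlying buffer; no data is copied.
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let viewed = tensor.view(vec![3, 2]);
assert_eq!(viewed.size(), tensor.size());
assert_eq!(viewed.shape().dims(), vec![3, 2]);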

§Safety

This struct uses unsafe code for performance. The following invariants must be maintained:

  • data must be valid for shape.size elements
  • data must be properly aligned for f32
  • data must not be aliased while the tensor exists
  • shape.size must match the actual allocated memory
  • allocation_owner must be valid if present
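
These invariants are maintained internally by the safe constructors; code that stays on the public API never needs to verify them. A minimal sketch of the safe pattern:

use train_station::Tensor;

// from_slice allocates and initializes in one step, so the invariants
// above hold by construction; data() then provides safe, bounds-checked access.
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
assert_eq!(tensor.data(), &[1.0, 2.0, 3.0, 4.0]);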

§Examples

§Basic Tensor Operations

use train_station::Tensor;

// Create tensors with different configurations
let tensor = Tensor::new(vec![2, 3]);
let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();

// Access tensor properties
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);
assert!(tensor.is_contiguous());

§Operator Overloading

use train_station::Tensor;

// Create tensors for operations
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();

// Tensor operations with operators
let result = a.clone() + b.clone();                    // Tensor addition
let result = a.clone() * b.clone();                    // Element-wise multiplication
let result = a.clone() - b.clone();                    // Tensor subtraction
let result = a.clone() / b.clone();                    // Element-wise division

// Scalar operations
let result = a.clone() + 5.0;                          // Tensor + scalar
let result = 5.0 + a.clone();                          // Scalar + tensor
let result = a.clone() * 3.0;                          // Tensor * scalar
let result = 3.0 * a.clone();                          // Scalar * tensor

// Compound expressions
let result = (a.clone() + b.clone()) * 2.0 - 1.0;      // Complex mathematical expressions

// Assignment operators
let mut c = a.clone();
c += b.clone();                                        // In-place addition
c *= 2.0;                                              // In-place scalar multiplication

// Negation
let result = -a;                                       // Negate all elements

§Thread Safety

This type is Send + Sync and can be safely shared between threads. All operations are thread-safe through atomic ID generation and thread-local gradtrack storage.
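
A minimal sketch of cross-thread sharing, relying only on the documented Send + Sync bounds:

use std::sync::Arc;
use std::thread;
use train_station::Tensor;

// Share one tensor across four reader threads; each computes a reduction.
let tensor = Arc::new(Tensor::ones(vec![4, 4]));
let handles: Vec<_> = (0..4)
    .map(|_| {
        let t = Arc::clone(&tensor);
        thread::spawn(move || t.sum().value())
    })
    .collect();
for handle in handles {
    assert_eq!(handle.join().unwrap(), 16.0);
}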

Implementations§

impl Tensor

pub fn capacity_elems(&self) -> usize

Returns the allocated capacity in elements, which may be padded beyond the logical size for alignment
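
The exact padding is an implementation detail of the alignment strategy, but the capacity never falls below the logical size. A minimal sketch of that invariant:

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3]);
assert_eq!(tensor.size(), 6);
assert!(tensor.capacity_elems() >= tensor.size());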

pub fn new(shape_dims: Vec<usize>) -> Self

Creates a new tensor with the specified shape and optimized memory layout

Allocates memory with size-dependent alignment for optimal performance:

  • Small tensors (≤8 elements): 16-byte SSE alignment
  • Medium tensors (9-1024 elements): 32-byte AVX2 alignment
  • Large tensors (>1024 elements): 64-byte cache-line alignment
§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
§Returns

A new tensor with uninitialized data. The data must be initialized before use to avoid undefined behavior.

§Performance
  • Memory Allocation: Single allocation with optimized alignment
  • SIMD Ready: Properly aligned for vectorized operations
  • Cache Friendly: Optimized for CPU cache hierarchies
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Safety

The returned tensor contains uninitialized memory. You must initialize the data before performing any operations that read from it.

§Examples
use train_station::Tensor;

// Create tensors of different sizes
let small_tensor = Tensor::new(vec![2, 3]);      // 16-byte alignment
let medium_tensor = Tensor::new(vec![32, 32]);   // 32-byte alignment
let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment

// Initialize data before use
let mut tensor = Tensor::new(vec![2, 3]);
tensor.fill(0.0); // Initialize with zeros
Examples found in repository
examples/getting_started/tensor_basics.rs (line 70)
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}

pub fn shape(&self) -> &Shape

Returns the shape and dimensional information of the tensor

Provides access to the tensor’s dimensions, size, strides, and memory layout information. This is used for shape validation, memory access calculations, and optimization decisions.

§Returns

Reference to the tensor’s shape information containing dimensions, size, strides, and memory layout type.

§Performance
  • Time Complexity: O(1) - direct field access
  • Memory: No allocation - returns reference to existing data
§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
let shape = tensor.shape();
assert_eq!(shape.dims(), vec![2, 3, 4]);
assert_eq!(shape.size(), 24);
assert_eq!(shape.rank(), 3);
Examples found in repository
examples/neural_networks/basic_decoder.rs (line 78)
77    fn triple(t: &Tensor) -> (usize, usize, usize) {
78        let d = t.shape().dims();
79        (d[0], d[1], d[2])
80    }
81}
82
83#[allow(unused)]
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85    println!("=== Basic Decoder Example ===");
86
87    let batch = 2usize;
88    let src = 7usize;
89    let tgt = 5usize;
90    let embed = 32usize;
91    let heads = 4usize;
92
93    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96    let mut dec = DecoderBlock::new(embed, heads, Some(456));
97    let out = dec.forward(&tgt_in, &memory, None, None);
98    println!("Output shape: {:?}", out.shape().dims());
99
100    let mut opt = Adam::with_learning_rate(0.01);
101    let mut params = dec.parameters();
102    for p in &params {
103        opt.add_parameter(p);
104    }
105    let mut loss = out.mean();
106    loss.backward(None);
107    opt.step(&mut params);
108    opt.zero_grad(&mut params);
109    println!("Loss: {:.6}", loss.value());
110    println!("=== Done ===");
111    Ok(())
112}
More examples
examples/neural_networks/basic_encoder.rs (line 67)
66    fn triple(t: &Tensor) -> (usize, usize, usize) {
67        let d = t.shape().dims();
68        (d[0], d[1], d[2])
69    }
70}
71
72#[allow(unused)]
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}
examples/RL_training/ppo_continuous.rs (line 112)
106    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107        // Returns (mean [B, A], log_std [A])
108        let mean = self.net.forward(state);
109        (
110            mean,
111            self.log_std
112                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113        )
114    }
examples/RL_training/ppo_discrete.rs (line 296)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
examples/RL_training/td3.rs (line 196)
193    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194        // Concatenate along feature dim (dim=1) for batched inputs
195        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198        let sa = Tensor::cat(&[s_view, a_view], 1);
199        self.net.forward(&sa, None)
200    }
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 162)
153fn demonstrate_layer_creation() {
154    println!("--- Layer Creation ---");
155
156    let layer = LinearLayer::new(3, 2, Some(42));
157
158    println!("Created linear layer:");
159    println!("  Input size: {}", layer.input_size);
160    println!("  Output size: {}", layer.output_size);
161    println!("  Parameter count: {}", layer.parameter_count());
162    println!("  Weight shape: {:?}", layer.weight.shape().dims());
163    println!("  Bias shape: {:?}", layer.bias.shape().dims());
164    println!("  Weight requires grad: {}", layer.weight.requires_grad());
165    println!("  Bias requires grad: {}", layer.bias.requires_grad());
166}
167
168/// Demonstrate forward pass with gradient tracking
169fn demonstrate_forward_pass() {
170    println!("\n--- Forward Pass (with gradients) ---");
171
172    let layer = LinearLayer::new(3, 2, Some(43));
173
174    // Single input
175    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176    let output = layer.forward(&input);
177
178    println!("Single input:");
179    println!("  Input: {:?}", input.data());
180    println!("  Output: {:?}", output.data());
181    println!("  Output requires grad: {}", output.requires_grad());
182
183    // Batch input
184    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185    let batch_output = layer.forward(&batch_input);
186
187    println!("Batch input:");
188    println!("  Input shape: {:?}", batch_input.shape().dims());
189    println!("  Output shape: {:?}", batch_output.shape().dims());
190    println!("  Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195    println!("\n--- Forward Pass (no gradients) ---");
196
197    let layer = LinearLayer::new(3, 2, Some(44));
198
199    // Single input
200    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201    let output = layer.forward_no_grad(&input);
202
203    println!("Single input (no grad):");
204    println!("  Input: {:?}", input.data());
205    println!("  Output: {:?}", output.data());
206    println!("  Output requires grad: {}", output.requires_grad());
207
208    // Compare with grad version
209    let output_with_grad = layer.forward(&input);
210    println!("Comparison:");
211    println!(
212        "  Same values: {}",
213        output.data() == output_with_grad.data()
214    );
215    println!("  No grad requires grad: {}", output.requires_grad());
216    println!(
217        "  With grad requires grad: {}",
218        output_with_grad.requires_grad()
219    );
220}
221
222/// Demonstrate complete training loop
223fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
224    println!("\n--- Training Loop ---");
225
226    // Create layer and training data
227    let mut layer = LinearLayer::new(2, 1, Some(45));
228
229    // Simple regression task: y = 2*x1 + 3*x2 + 1
230    let x_data = Tensor::from_slice(
231        &[
232            1.0, 1.0, // x1=1, x2=1 -> y=6
233            2.0, 1.0, // x1=2, x2=1 -> y=8
234            1.0, 2.0, // x1=1, x2=2 -> y=9
235            2.0, 2.0, // x1=2, x2=2 -> y=11
236        ],
237        vec![4, 2],
238    )
239    .unwrap();
240
241    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();
242
243    println!("Training data:");
244    println!("  X shape: {:?}", x_data.shape().dims());
245    println!("  Y shape: {:?}", y_true.shape().dims());
246    println!("  Target function: y = 2*x1 + 3*x2 + 1");
247
248    // Create optimizer
249    let config = AdamConfig {
250        learning_rate: 0.01,
251        beta1: 0.9,
252        beta2: 0.999,
253        eps: 1e-8,
254        weight_decay: 0.0,
255        amsgrad: false,
256    };
257
258    let mut optimizer = Adam::with_config(config);
259    let params = layer.parameters();
260    for param in &params {
261        optimizer.add_parameter(param);
262    }
263
264    println!("Optimizer setup complete. Starting training...");
265
266    // Training loop
267    let num_epochs = 100;
268    let mut losses = Vec::new();
269
270    for epoch in 0..num_epochs {
271        // Forward pass
272        let y_pred = layer.forward(&x_data);
273
274        // Compute loss: MSE
275        let diff = y_pred.sub_tensor(&y_true);
276        let mut loss = diff.pow_scalar(2.0).mean();
277
278        // Backward pass
279        loss.backward(None);
280
281        // Optimizer step
282        let mut params = layer.parameters();
283        optimizer.step(&mut params);
284        optimizer.zero_grad(&mut params);
285
286        losses.push(loss.value());
287
288        // Print progress
289        if epoch % 20 == 0 || epoch == num_epochs - 1 {
290            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
291        }
292    }
293
294    // Evaluate final model
295    let final_predictions = layer.forward_no_grad(&x_data);
296
297    println!("\nFinal model evaluation:");
298    println!("  Learned weights: {:?}", layer.weight.data());
299    println!("  Learned bias: {:?}", layer.bias.data());
300    println!("  Target weights: [2.0, 3.0]");
301    println!("  Target bias: [1.0]");
302
303    println!("  Predictions vs True:");
304    for i in 0..4 {
305        let pred = final_predictions.data()[i];
306        let true_val = y_true.data()[i];
307        println!(
308            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
309            i + 1,
310            pred,
311            true_val,
312            (pred - true_val).abs()
313        );
314    }
315
316    // Training analysis
317    let initial_loss = losses[0];
318    let final_loss = losses[losses.len() - 1];
319    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
320
321    println!("\nTraining Analysis:");
322    println!("  Initial loss: {:.6}", initial_loss);
323    println!("  Final loss: {:.6}", final_loss);
324    println!("  Loss reduction: {:.1}%", loss_reduction);
325
326    Ok(())
327}
328
329/// Demonstrate single vs batch inference
330fn demonstrate_single_vs_batch_inference() {
331    println!("\n--- Single vs Batch Inference ---");
332
333    let layer = LinearLayer::new(4, 3, Some(46));
334
335    // Single inference
336    println!("Single inference:");
337    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338    let single_output = layer.forward_no_grad(&single_input);
339    println!("  Input shape: {:?}", single_input.shape().dims());
340    println!("  Output shape: {:?}", single_output.shape().dims());
341    println!("  Output: {:?}", single_output.data());
342
343    // Batch inference
344    println!("Batch inference:");
345    let batch_input = Tensor::from_slice(
346        &[
347            1.0, 2.0, 3.0, 4.0, // Sample 1
348            5.0, 6.0, 7.0, 8.0, // Sample 2
349            9.0, 10.0, 11.0, 12.0, // Sample 3
350        ],
351        vec![3, 4],
352    )
353    .unwrap();
354    let batch_output = layer.forward_no_grad(&batch_input);
355    println!("  Input shape: {:?}", batch_input.shape().dims());
356    println!("  Output shape: {:?}", batch_output.shape().dims());
357
358    // Verify batch consistency - first sample should match single inference
359    let _first_batch_sample = batch_output.view(vec![3, 3]); // Reshape to access first sample
360    let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361    let single_sample_data = single_output.data();
362
363    println!("Consistency check:");
364    println!("  Single output: {:?}", single_sample_data);
365    println!("  First batch sample: {:?}", first_sample_data);
366    println!(
367        "  Match: {}",
368        single_sample_data
369            .iter()
370            .zip(first_sample_data.iter())
371            .all(|(a, b)| (a - b).abs() < 1e-6)
372    );
373}

pub fn size(&self) -> usize

Returns the total number of elements in the tensor

Provides the total count of elements across all dimensions. This is used for memory allocation, iteration bounds, and performance optimization.

§Returns

Total number of elements as usize

§Performance
  • Time Complexity: O(1) - direct field access
  • Memory: No allocation - returns stored value
§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
assert_eq!(tensor.size(), 24); // 2 * 3 * 4

let scalar = Tensor::new(vec![1]);
assert_eq!(scalar.size(), 1);

let empty = Tensor::new(vec![0]);
assert_eq!(empty.size(), 0);
Examples found in repository
examples/getting_started/tensor_basics.rs (line 186)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
More examples
examples/iterators/performance_optimization.rs (line 209)
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163    println!("\n--- Memory Optimization ---");
164
165    // Create a large tensor for memory testing
166    let size = 10000;
167    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168    let tensor = Tensor::from_slice(&data, vec![size])?;
169
170    println!("Processing tensor of size: {}", size);
171
172    // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173    println!("\nPattern 1: Streaming Processing");
174    let chunk_size = 1000;
175    let start = Instant::now();
176    let flattened = tensor.view(vec![size as i32]);
177    let _streamed_result: Tensor = flattened
178        .chunks(chunk_size)
179        .map(|c| c.pow_scalar(2.0).sqrt())
180        .collect_shape(vec![size]);
181    let streamed_time = start.elapsed();
182
183    // Pattern 2: Full processing
184    let start = Instant::now();
185    let _full_result: Tensor = tensor
186        .iter_elements()
187        .map(|elem| elem.pow_scalar(2.0).sqrt())
188        .collect_shape(vec![size]);
189    let full_time = start.elapsed();
190
191    println!("  Streaming time: {:?}", streamed_time);
192    println!("  Full processing time: {:?}", full_time);
193    println!(
194        "  Memory efficiency ratio: {:.2}x",
195        full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196    );
197
198    // Pattern 3: Lazy evaluation with take
199    println!("\nPattern 2: Lazy Evaluation");
200    let start = Instant::now();
201    let lazy_result: Tensor = tensor
202        .iter_elements()
203        .take(1000) // Only process first 1000 elements
204        .map(|elem| elem.pow_scalar(2.0).sqrt())
205        .collect_shape(vec![1000]);
206    let lazy_time = start.elapsed();
207
208    println!("  Lazy processing (1000 elements): {:?}", lazy_time);
209    println!("  Lazy result size: {}", lazy_result.size());
210
211    // Pattern 4: Memory-efficient filtering
212    println!("\nPattern 3: Memory-Efficient Filtering");
213    let start = Instant::now();
214    let filtered_result: Tensor = tensor
215        .iter_elements()
216        .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217        .map(|elem| elem.mul_scalar(2.0))
218        .collect();
219    let filtered_time = start.elapsed();
220
221    println!("  Filtered processing: {:?}", filtered_time);
222    println!(
223        "  Filtered result size: {} (reduced from {})",
224        filtered_result.size(),
225        size
226    );
227
228    Ok(())
229}
230
231/// Demonstrate large-scale processing techniques
232///
233/// Shows how to efficiently process very large datasets using
234/// iterator patterns and optimization strategies.
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236    println!("\n--- Large-Scale Processing ---");
237
238    // Simulate large dataset processing
239    let sizes = vec![10000, 50000, 100000];
240
241    for size in sizes {
242        println!("\nProcessing dataset of size: {}", size);
243
244        // Generate large dataset
245        let data: Vec<f32> = (0..size)
246            .map(|i| {
247                let x = i as f32 / size as f32;
248                x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249            })
250            .collect();
251
252        let tensor = Tensor::from_slice(&data, vec![size])?;
253
254        // Technique 1: Batch processing
255        let batch_size = 1000;
256        let start = Instant::now();
257
258        let mut batch_results = Vec::new();
259        for batch_start in (0..size).step_by(batch_size) {
260            let batch_end = (batch_start + batch_size).min(size);
261            let batch: Tensor = tensor
262                .iter_range(batch_start, batch_end)
263                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264                .collect();
265            batch_results.push(batch);
266        }
267        let batch_time = start.elapsed();
268
269        // Technique 2: Parallel-like processing with stride
270        let start = Instant::now();
271        let stride = 4;
272        let strided_result: Tensor = tensor
273            .iter()
274            .enumerate()
275            .filter(|(i, _)| i % stride == 0)
276            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277            .collect();
278        let strided_time = start.elapsed();
279
280        // Technique 3: Hierarchical processing
281        let start = Instant::now();
282        let coarse: Tensor = tensor
283            .iter()
284            .enumerate()
285            .filter(|(i, _)| i % 10 == 0) // Every 10th element
286            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287            .collect();
288        let fine: Tensor = tensor
289            .iter()
290            .enumerate()
291            .filter(|(i, _)| i % 10 != 0) // Rest of elements
292            .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293            .collect();
294        let hierarchical_time = start.elapsed();
295
296        // Report performance
297        println!("  Batch processing: {:?}", batch_time);
298        println!("  Strided processing: {:?}", strided_time);
299        println!("  Hierarchical processing: {:?}", hierarchical_time);
300
301        // Memory usage analysis
302        let total_batches = size.div_ceil(batch_size);
303        println!("  Batch count: {}", total_batches);
304        println!("  Strided result size: {}", strided_result.size());
305        println!(
306            "  Hierarchical: coarse={}, fine={}",
307            coarse.size(),
308            fine.size()
309        );
310    }
311
312    Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320    println!("\n--- Optimization Techniques ---");
321
322    let size = 50000;
323    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324    let tensor = Tensor::from_slice(&data, vec![size])?;
325
326    println!("Optimizing processing for size: {}", size);
327
328    // Technique 1: Operation fusion
329    println!("\nTechnique 1: Operation Fusion");
330    let start = Instant::now();
331    let fused_result: Tensor = tensor
332        .iter()
333        .map(|elem| {
334            // Fuse multiple operations into single chain
335            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336        })
337        .collect();
338    let fused_time = start.elapsed();
339
340    // Technique 2: Conditional optimization
341    println!("\nTechnique 2: Conditional Optimization");
342    let start = Instant::now();
343    let conditional_result: Tensor = tensor
344        .iter()
345        .map(|elem| {
346            let val = elem.value();
347            if val < size as f32 / 2.0 {
348                elem.mul_scalar(2.0) // Simple operation for small values
349            } else {
350                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351            }
352        })
353        .collect();
354    let conditional_time = start.elapsed();
355
356    // Technique 3: Cache-friendly processing
357    println!("\nTechnique 3: Cache-Friendly Processing");
358    let start = Instant::now();
359    let cache_friendly_result: Tensor = tensor
360        .iter()
361        .take(1000) // Process in cache-friendly chunks
362        .map(|elem| elem.mul_scalar(2.0))
363        .collect();
364    let cache_friendly_time = start.elapsed();
365
366    // Technique 4: Memory pooling simulation
367    println!("\nTechnique 4: Memory Pooling Simulation");
368    let start = Instant::now();
369    let pooled_result: Tensor = tensor
370        .iter()
371        .enumerate()
372        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373        .map(|(_, elem)| elem.pow_scalar(2.0))
374        .collect();
375    let pooled_time = start.elapsed();
376
377    // Report optimization results
378    println!("  Fused operations: {:?}", fused_time);
379    println!("  Conditional optimization: {:?}", conditional_time);
380    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
381    println!("  Memory pooling simulation: {:?}", pooled_time);
382
383    // Performance analysis
384    let fastest = fused_time
385        .min(conditional_time)
386        .min(cache_friendly_time)
387        .min(pooled_time);
388    println!("  Fastest technique: {:?}", fastest);
389
390    // Memory efficiency analysis
391    println!("  Fused result size: {}", fused_result.size());
392    println!("  Conditional result size: {}", conditional_result.size());
393    println!(
394        "  Cache-friendly result size: {}",
395        cache_friendly_result.size()
396    );
397    println!("  Pooled result size: {}", pooled_result.size());
398
399    // Technique 5: Gradient optimization
400    println!("\nTechnique 5: Gradient Optimization");
401    let grad_tensor = tensor.with_requires_grad();
402    let start = Instant::now();
403
404    let grad_result: Tensor = grad_tensor
405        .iter()
406        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407        .collect();
408
409    let mut loss = grad_result.sum();
410    loss.backward(None);
411    let grad_time = start.elapsed();
412
413    println!("  Gradient computation: {:?}", grad_time);
414    println!(
415        "  Gradient tracking enabled: {}",
416        grad_result.requires_grad()
417    );
418
419    Ok(())
420}
examples/iterators/advanced_patterns.rs (line 128)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251    println!("\n--- Batch Operations ---");
252
253    // Create a larger dataset for batch processing
254    let size = 100;
255    let data: Vec<f32> = (0..size)
256        .map(|i| {
257            let x = i as f32 / size as f32;
258            x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259        })
260        .collect();
261
262    let tensor = Tensor::from_slice(&data, vec![size])?;
263    println!("Dataset size: {}", tensor.size());
264
265    // Batch processing with windowing (iterator views)
266    println!("\nBatch processing with sliding windows:");
267    let batch_size = 10;
268    let batches: Vec<Tensor> = tensor
269        .iter()
270        .collect::<Vec<_>>()
271        .chunks(batch_size)
272        .map(|chunk| {
273            // Process each batch independently
274            chunk
275                .iter()
276                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277                .collect()
278        })
279        .collect();
280
281    println!(
282        "  Processed {} batches of size {}",
283        batches.len(),
284        batch_size
285    );
286    for (i, batch) in batches.iter().enumerate() {
287        println!(
288            "    Batch {}: mean={:.3}, std={:.3}",
289            i,
290            batch.mean().value(),
291            batch.std().value()
292        );
293    }
294
295    // Parallel-like processing with stride
296    println!("\nStrided processing (every nth element):");
297    let stride = 5;
298    let strided: Tensor = tensor
299        .iter()
300        .enumerate()
301        .filter(|(i, _)| i % stride == 0)
302        .map(|(_, elem)| elem)
303        .collect();
304    println!("  Strided (every {}th): {:?}", stride, strided.data());
305
306    // Hierarchical processing
307    println!("\nHierarchical processing (coarse to fine):");
308    let coarse: Tensor = tensor
309        .iter()
310        .enumerate()
311        .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312        .map(|(_, elem)| elem)
313        .collect();
314
315    let fine: Tensor = tensor
316        .iter()
317        .enumerate()
318        .filter(|(i, _)| i % 4 != 0) // Take the rest
319        .map(|(_, elem)| elem)
320        .collect();
321
322    println!("  Coarse (every 4th): {:?}", coarse.data());
323    println!("  Fine (rest): {:?}", fine.data());
324
325    // Combine coarse and fine with different processing
326    let combined: Tensor = coarse
327        .iter()
328        .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329        .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330        .collect();
331    println!("  Combined: {:?}", combined.data());
332
333    Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}

pub fn device(&self) -> Device

Returns the device where this tensor is located

Provides the physical location of the tensor data (CPU/GPU). This determines which operations can be performed on the tensor and where computations will be executed.

§Returns

Device enum indicating the tensor’s physical location

§Performance
  • Time Complexity: O(1) - direct field access
  • Memory: No allocation - returns stored value
§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3]);
assert!(tensor.device().is_cpu());
assert!(!tensor.device().is_cuda());
Examples found in repository
examples/getting_started/tensor_basics.rs (line 188)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}

pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self

Creates a new tensor with the specified shape on a specific device

Allocates memory on the specified device using the same optimized alignment strategy as new(). Currently only the CPU device is supported, with CUDA support planned.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
  • device - The device where the tensor should be allocated
§Returns

A new tensor with uninitialized data on the specified device

§Performance
  • Memory Allocation: Device-specific allocation with optimized alignment
  • SIMD Ready: Properly aligned for vectorized operations on target device
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Panics

Panics if the specified device is not supported (currently only CPU is supported; CUDA requires the corresponding feature flag)

§Examples
use train_station::Tensor;

let tensor = Tensor::new_on_device(vec![2, 3], train_station::Device::cpu());
assert!(tensor.device().is_cpu());
assert_eq!(tensor.size(), 6);

pub fn with_requires_grad(self) -> Self

Enable gradient computation for this tensor

Builder method that enables automatic gradient tracking for this tensor. When enabled, all operations involving this tensor will be recorded in the computation graph for gradient computation during the backward pass.

§Returns

self with gradient tracking enabled

§Performance
  • Time Complexity: O(1) - simple field assignment
  • Memory: No additional allocation
  • Overhead: Minimal gradtrack tracking overhead when gradients computed
§Examples
use train_station::Tensor;

let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(tensor.requires_grad());
Examples found in repository
examples/RL_training/ppo_continuous.rs (line 103)
99    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
100        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
101        let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
102            .unwrap()
103            .with_requires_grad();
104        Self { net, log_std }
105    }
More examples
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 59)
53    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55        let scale = (1.0 / input_size as f32).sqrt();
56
57        let weight = Tensor::randn(vec![input_size, output_size], seed)
58            .mul_scalar(scale)
59            .with_requires_grad();
60        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62        Self {
63            weight,
64            bias,
65            input_size,
66            output_size,
67        }
68    }
69
70    /// Forward pass: output = input @ weight + bias
71    pub fn forward(&self, input: &Tensor) -> Tensor {
72        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
73        let output = input.matmul(&self.weight);
74        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
75        output.add_tensor(&self.bias)
76    }
77
78    /// Forward pass without gradients (for inference)
79    #[allow(unused)]
80    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
81        let _guard = NoGradTrack::new();
82        self.forward(input)
83    }
84
85    /// Get all parameters for optimization
86    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
87        vec![&mut self.weight, &mut self.bias]
88    }
89
90    /// Save layer parameters to JSON
91    #[allow(unused)]
92    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
93        // Create directory if it doesn't exist
94        if let Some(parent) = std::path::Path::new(path).parent() {
95            fs::create_dir_all(parent)?;
96        }
97
98        let weight_path = format!("{}_weight.json", path);
99        let bias_path = format!("{}_bias.json", path);
100
101        self.weight.save_json(&weight_path)?;
102        self.bias.save_json(&bias_path)?;
103
104        println!("Saved linear layer to {} (weight and bias)", path);
105        Ok(())
106    }
107
108    /// Load layer parameters from JSON
109    #[allow(unused)]
110    pub fn load_json(
111        path: &str,
112        input_size: usize,
113        output_size: usize,
114    ) -> Result<Self, Box<dyn std::error::Error>> {
115        let weight_path = format!("{}_weight.json", path);
116        let bias_path = format!("{}_bias.json", path);
117
118        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
119        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
120
121        Ok(Self {
122            weight,
123            bias,
124            input_size,
125            output_size,
126        })
127    }
examples/iterators/element_iteration.rs (line 179)
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
examples/neural_networks/feedforward_network.rs (line 214)
200    pub fn load_json(
201        path: &str,
202        config: FeedForwardConfig,
203    ) -> Result<Self, Box<dyn std::error::Error>> {
204        let mut layers = Vec::new();
205        let mut current_size = config.input_size;
206        let mut layer_idx = 0;
207
208        // Load hidden layers
209        for &hidden_size in &config.hidden_sizes {
210            let layer_path = format!("{}_layer_{}", path, layer_idx);
211            let weight_path = format!("{}_weight.json", layer_path);
212            let bias_path = format!("{}_bias.json", layer_path);
213
214            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
215            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
216
217            layers.push(LinearLayer {
218                weight,
219                bias,
220                input_size: current_size,
221                output_size: hidden_size,
222            });
223
224            current_size = hidden_size;
225            layer_idx += 1;
226        }
227
228        // Load output layer
229        let layer_path = format!("{}_layer_{}", path, layer_idx);
230        let weight_path = format!("{}_weight.json", layer_path);
231        let bias_path = format!("{}_bias.json", layer_path);
232
233        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
234        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
235
236        layers.push(LinearLayer {
237            weight,
238            bias,
239            input_size: current_size,
240            output_size: config.output_size,
241        });
242
243        Ok(Self { layers, config })
244    }
examples/getting_started/optimizer_basics.rs (line 51)
47fn demonstrate_basic_optimizer_setup() {
48    println!("--- Basic Optimizer Setup ---");
49
50    // Create parameters that require gradients
51    let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52    let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54    println!("Created parameters:");
55    println!(
56        "  Weight: shape {:?}, requires_grad: {}",
57        weight.shape().dims(),
58        weight.requires_grad()
59    );
60    println!(
61        "  Bias: shape {:?}, requires_grad: {}",
62        bias.shape().dims(),
63        bias.requires_grad()
64    );
65
66    // Create Adam optimizer with default configuration
67    let mut optimizer = Adam::new();
68    println!(
69        "Created Adam optimizer with learning rate: {}",
70        optimizer.learning_rate()
71    );
72
73    // Add parameters to optimizer
74    optimizer.add_parameter(&weight);
75    optimizer.add_parameter(&bias);
76    println!(
77        "Added {} parameters to optimizer",
78        optimizer.parameter_count()
79    );
80
81    // Create optimizer with custom configuration
82    let config = AdamConfig {
83        learning_rate: 0.01,
84        beta1: 0.9,
85        beta2: 0.999,
86        eps: 1e-8,
87        weight_decay: 0.0,
88        amsgrad: false,
89    };
90
91    let mut custom_optimizer = Adam::with_config(config);
92    custom_optimizer.add_parameter(&weight);
93    custom_optimizer.add_parameter(&bias);
94
95    println!(
96        "Created custom optimizer with learning rate: {}",
97        custom_optimizer.learning_rate()
98    );
99
100    // Demonstrate parameter linking
101    println!("Parameter linking completed successfully");
102}
103
104/// Demonstrate simple linear regression training
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106    println!("\n--- Linear Regression Training ---");
107
108    // Create model parameters
109    let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112    // Create optimizer
113    let mut optimizer = Adam::with_learning_rate(0.01);
114    optimizer.add_parameter(&weight);
115    optimizer.add_parameter(&bias);
116
117    // Create simple training data: y = 2*x + 1
118    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121    println!("Training data:");
122    println!("  X: {:?}", x_data.data());
123    println!("  Y: {:?}", y_true.data());
124    println!("  Target: y = 2*x + 1");
125
126    // Training loop
127    let num_epochs = 100;
128    let mut losses = Vec::new();
129
130    for epoch in 0..num_epochs {
131        // Forward pass: y_pred = x * weight + bias
132        let y_pred = x_data.matmul(&weight) + &bias;
133
134        // Compute loss: MSE
135        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137        // Backward pass
138        loss.backward(None);
139
140        // Optimizer step
141        optimizer.step(&mut [&mut weight, &mut bias]);
142        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144        losses.push(loss.value());
145
146        // Print progress every 20 epochs
147        if epoch % 20 == 0 || epoch == num_epochs - 1 {
148            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149        }
150    }
151
152    // Evaluate final model
153    let final_predictions = x_data.matmul(&weight) + &bias;
154    println!("\nFinal model evaluation:");
155    println!("  Learned weight: {:.6}", weight.value());
156    println!("  Learned bias: {:.6}", bias.value());
157    println!("  Predictions vs True:");
158
159    for i in 0..5 {
160        let x1 = x_data.data()[i];
161        let pred = final_predictions.data()[i];
162        let true_val = y_true.data()[i];
163        println!(
164            "    x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165            x1,
166            pred,
167            true_val,
168            (pred - true_val).abs()
169        );
170    }
171
172    Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177    println!("\n--- Advanced Training Patterns ---");
178
179    // Create a more complex model
180    let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181    let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183    // Create optimizer with different learning rate
184    let mut optimizer = Adam::with_learning_rate(0.005);
185    optimizer.add_parameter(&weight);
186    optimizer.add_parameter(&bias);
187
188    // Create training data: y = 2*x + [1, 4]
189    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190    let y_true = Tensor::from_slice(
191        &[3.0, 6.0, 5.0, 8.0, 7.0, 10.0, 9.0, 12.0, 11.0, 14.0],
192        vec![5, 2],
193    )
194    .unwrap();
195
196    println!("Advanced training with monitoring:");
197    println!("  Initial learning rate: {}", optimizer.learning_rate());
198
199    // Training loop with monitoring
200    let num_epochs = 50;
201    let mut losses = Vec::new();
202    let mut weight_norms = Vec::new();
203    let mut gradient_norms = Vec::new();
204
205    for epoch in 0..num_epochs {
206        // Forward pass
207        let y_pred = x_data.matmul(&weight) + &bias;
208        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210        // Backward pass
211        loss.backward(None);
212
213        // Compute gradient norm before optimizer step
214        let gradient_norm = weight.grad_owned().unwrap().norm();
215
216        // Optimizer step
217        optimizer.step(&mut [&mut weight, &mut bias]);
218        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220        // Learning rate scheduling: reduce every 10 epochs
221        if epoch > 0 && epoch % 10 == 0 {
222            let current_lr = optimizer.learning_rate();
223            let new_lr = current_lr * 0.5;
224            optimizer.set_learning_rate(new_lr);
225            println!(
226                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227                epoch, current_lr, new_lr
228            );
229        }
230
231        // Record metrics
232        losses.push(loss.value());
233        weight_norms.push(weight.norm().value());
234        gradient_norms.push(gradient_norm.value());
235
236        // Print detailed progress
237        if epoch % 10 == 0 || epoch == num_epochs - 1 {
238            println!(
239                "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240                epoch,
241                loss.value(),
242                weight.norm().value(),
243                gradient_norm.value()
244            );
245        }
246    }
247
248    println!("Final learning rate: {}", optimizer.learning_rate());
249
250    // Analyze training progression
251    let initial_loss = losses[0];
252    let final_loss = losses[losses.len() - 1];
253    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255    println!("\nTraining Analysis:");
256    println!("  Initial loss: {:.6}", initial_loss);
257    println!("  Final loss: {:.6}", final_loss);
258    println!("  Loss reduction: {:.1}%", loss_reduction);
259    println!("  Final weight norm: {:.6}", weight.norm().value());
260    println!("  Final bias: {:?}", bias.data());
261
262    Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267    println!("\n--- Learning Rate Scheduling ---");
268
269    // Create simple model
270    let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273    // Create optimizer with high initial learning rate
274    let mut optimizer = Adam::with_learning_rate(0.1);
275    optimizer.add_parameter(&weight);
276    optimizer.add_parameter(&bias);
277
278    // Simple data
279    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280    let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282    println!("Initial learning rate: {}", optimizer.learning_rate());
283
284    // Training loop with learning rate scheduling
285    let num_epochs = 50;
286    let mut losses = Vec::new();
287
288    for epoch in 0..num_epochs {
289        // Forward pass
290        let y_pred = x_data.matmul(&weight) + &bias;
291        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293        // Backward pass
294        loss.backward(None);
295
296        // Optimizer step
297        optimizer.step(&mut [&mut weight, &mut bias]);
298        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300        // Learning rate scheduling: reduce every 10 epochs
301        if epoch > 0 && epoch % 10 == 0 {
302            let current_lr = optimizer.learning_rate();
303            let new_lr = current_lr * 0.5;
304            optimizer.set_learning_rate(new_lr);
305            println!(
306                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307                epoch, current_lr, new_lr
308            );
309        }
310
311        losses.push(loss.value());
312
313        // Print progress
314        if epoch % 10 == 0 || epoch == num_epochs - 1 {
315            println!(
316                "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317                epoch,
318                loss.value(),
319                optimizer.learning_rate()
320            );
321        }
322    }
323
324    println!("Final learning rate: {}", optimizer.learning_rate());
325
326    Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331    println!("\n--- Training Monitoring ---");
332
333    // Create model
334    let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337    // Create optimizer
338    let mut optimizer = Adam::with_learning_rate(0.01);
339    optimizer.add_parameter(&weight);
340    optimizer.add_parameter(&bias);
341
342    // Training data
343    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346    // Training loop with comprehensive monitoring
347    let num_epochs = 30;
348    let mut losses = Vec::new();
349    let mut weight_history = Vec::new();
350    let mut bias_history = Vec::new();
351
352    for epoch in 0..num_epochs {
353        // Forward pass
354        let y_pred = x_data.matmul(&weight) + &bias;
355        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357        // Backward pass
358        loss.backward(None);
359
360        // Optimizer step
361        optimizer.step(&mut [&mut weight, &mut bias]);
362        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364        // Record history
365        losses.push(loss.value());
366        weight_history.push(weight.value());
367        bias_history.push(bias.value());
368
369        // Print detailed monitoring
370        if epoch % 5 == 0 || epoch == num_epochs - 1 {
371            println!(
372                "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373                epoch,
374                loss.value(),
375                weight.value(),
376                bias.value()
377            );
378        }
379    }
380
381    // Analyze training progression
382    println!("\nTraining Analysis:");
383    println!("  Initial loss: {:.6}", losses[0]);
384    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
385    println!(
386        "  Loss reduction: {:.1}%",
387        (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388    );
389
390    // Compute statistics
391    let loss_mean = compute_mean(&losses);
392    let loss_std = compute_std(&losses);
393    let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394    let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396    println!("  Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397    println!("  Weight change: {:.6}", weight_change);
398    println!("  Bias change: {:.6}", bias_change);
399    println!("  Final weight norm: {:.6}", weight.norm().value());
400    println!("  Final bias: {:.6}", bias.value());
401
402    Ok(())
403}
examples/getting_started/serialization_basics.rs (line 113)
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110    println!("\n--- Optimizer Serialization ---");
111
112    // Create an optimizer with some parameters
113    let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114    let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116    let config = AdamConfig {
117        learning_rate: 0.001,
118        beta1: 0.9,
119        beta2: 0.999,
120        eps: 1e-8,
121        weight_decay: 0.0,
122        amsgrad: false,
123    };
124
125    let mut optimizer = Adam::with_config(config);
126    optimizer.add_parameter(&weight);
127    optimizer.add_parameter(&bias);
128
129    println!(
130        "Created optimizer with {} parameters",
131        optimizer.parameter_count()
132    );
133    println!("Learning rate: {}", optimizer.learning_rate());
134
135    // Simulate some training steps
136    for _ in 0..3 {
137        let mut loss = weight.sum() + bias.sum();
138        loss.backward(None);
139        optimizer.step(&mut [&mut weight, &mut bias]);
140        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141    }
142
143    // Save optimizer state
144    let optimizer_path = "temp_optimizer.json";
145    optimizer.save_json(optimizer_path)?;
146    println!("Saved optimizer to: {}", optimizer_path);
147
148    // Load optimizer state
149    let loaded_optimizer = Adam::load_json(optimizer_path)?;
150    println!(
151        "Loaded optimizer with {} parameters",
152        loaded_optimizer.parameter_count()
153    );
154    println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156    // Verify optimizer state
157    assert_eq!(
158        optimizer.parameter_count(),
159        loaded_optimizer.parameter_count()
160    );
161    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162    println!("Optimizer serialization verification: PASSED");
163
164    Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169    println!("\n--- Format Comparison ---");
170
171    // Create a larger tensor for comparison
172    let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174    // Save in both formats
175    tensor.save_json("temp_comparison.json")?;
176    tensor.save_binary("temp_comparison.bin")?;
177
178    // Compare file sizes
179    let json_size = fs::metadata("temp_comparison.json")?.len();
180    let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182    println!("JSON file size: {} bytes", json_size);
183    println!("Binary file size: {} bytes", binary_size);
184    println!(
185        "Compression ratio: {:.2}x",
186        json_size as f64 / binary_size as f64
187    );
188
189    // Load and verify both formats
190    let json_tensor = Tensor::load_json("temp_comparison.json")?;
191    let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193    assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194    assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195    assert_eq!(tensor.data(), json_tensor.data());
196    assert_eq!(tensor.data(), binary_tensor.data());
197
198    println!("Format comparison verification: PASSED");
199
200    Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205    println!("\n--- Model Checkpointing ---");
206
207    // Create a simple model (weights and bias)
208    let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209    let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211    // Create optimizer
212    let mut optimizer = Adam::with_learning_rate(0.01);
213    optimizer.add_parameter(&weights);
214    optimizer.add_parameter(&bias);
215
216    println!("Initial weights: {:?}", weights.data());
217    println!("Initial bias: {:?}", bias.data());
218
219    // Simulate training
220    for epoch in 0..5 {
221        let mut loss = weights.sum() + bias.sum();
222        loss.backward(None);
223        optimizer.step(&mut [&mut weights, &mut bias]);
224        optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226        if epoch % 2 == 0 {
227            // Save checkpoint
228            let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229            fs::create_dir_all(&checkpoint_dir)?;
230
231            weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232            bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233            optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235            println!("Saved checkpoint for epoch {}", epoch);
236        }
237    }
238
239    // Load from checkpoint
240    let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241    let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242    let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244    println!("Loaded weights: {:?}", loaded_weights.data());
245    println!("Loaded bias: {:?}", loaded_bias.data());
246    println!(
247        "Loaded optimizer learning rate: {}",
248        loaded_optimizer.learning_rate()
249    );
250
251    // Verify checkpoint integrity
252    assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253    assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256    println!("Checkpointing verification: PASSED");
257
258    Ok(())
259}

pub fn set_requires_grad(&mut self, requires_grad: bool)

Set gradient tracking for this tensor

Controls whether the gradtrack system tracks operations on this tensor and computes gradients during backward pass. When disabled, clears any existing gradients and gradient functions.

§Arguments
  • requires_grad - Whether to track gradients for this tensor
§Performance
  • Time Complexity: O(1) - simple field assignment
  • Memory: May free gradient storage when disabled
  • Overhead: Zero overhead when gradients disabled
§Examples
use train_station::Tensor;

let mut tensor = Tensor::ones(vec![2, 3]);
tensor.set_requires_grad(true);
assert!(tensor.requires_grad());

// Disable gradient tracking
tensor.set_requires_grad(false);
assert!(!tensor.requires_grad());
Examples found in repository
examples/RL_training/dqn.rs (line 102)
100    fn set_requires_grad_all(&mut self, enable: bool) {
101        for l in &mut self.layers {
102            l.weight.set_requires_grad(enable);
103            l.bias.set_requires_grad(enable);
104        }
105    }
106
107    // In-place copy (preserve tensor IDs and optimizer links on targets)
108    fn copy_from(&mut self, other: &Self) {
109        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110            {
111                let src = s.weight.data();
112                let dst = t.weight.data_mut();
113                dst.copy_from_slice(src);
114            }
115            {
116                let src = s.bias.data();
117                let dst = t.bias.data_mut();
118                dst.copy_from_slice(src);
119            }
120            t.weight.set_requires_grad(false);
121            t.bias.set_requires_grad(false);
122        }
123    }
More examples
examples/RL_training/td3.rs (line 109)
107    fn set_requires_grad_all(&mut self, enable: bool) {
108        for l in &mut self.layers {
109            l.weight.set_requires_grad(enable);
110            l.bias.set_requires_grad(enable);
111        }
112    }
113
114    fn copy_from(&mut self, other: &Self) {
115        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116            {
117                let src = s.weight.data();
118                let dst = t.weight.data_mut();
119                dst.copy_from_slice(src);
120            }
121            {
122                let src = s.bias.data();
123                let dst = t.bias.data_mut();
124                dst.copy_from_slice(src);
125            }
126            t.weight.set_requires_grad(false);
127            t.bias.set_requires_grad(false);
128        }
129    }
130
131    fn soft_update_from(&mut self, source: &Self, tau: f32) {
132        let _ng = NoGradTrack::new();
133        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134            // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135            let new_w = t
136                .weight
137                .mul_scalar(1.0 - tau)
138                .add_tensor(&s.weight.mul_scalar(tau));
139            let new_b = t
140                .bias
141                .mul_scalar(1.0 - tau)
142                .add_tensor(&s.bias.mul_scalar(tau));
143            {
144                let src = new_w.data();
145                let dst = t.weight.data_mut();
146                dst.copy_from_slice(src);
147            }
148            {
149                let src = new_b.data();
150                let dst = t.bias.data_mut();
151                dst.copy_from_slice(src);
152            }
153            t.weight.set_requires_grad(false);
154            t.bias.set_requires_grad(false);
155        }
156    }

pub fn retain_grad(self) -> Self

Mark this tensor to retain gradients after backward, even if it is non-leaf.

Builder-style API: returns self with retain_grad enabled. After backward(), call materialize_grad() or grad_or_fetch() to copy the accumulated gradient from the GradGraph into self.grad so that grad() returns Some(&Tensor).
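
§Examples

A minimal sketch of the retain-then-materialize flow, assembled only from methods documented on this page:

use train_station::Tensor;

let x = Tensor::ones(vec![2, 2]).with_requires_grad();
// Non-leaf intermediate, marked to keep its gradient after backward
let mut y = x.mul_scalar(2.0).retain_grad();
let mut loss = y.sum();
loss.backward(None);
// Copy the accumulated gradient from the grad engine into y.grad
y.materialize_grad();
assert!(y.grad().is_some());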


pub fn retain_grad_(&mut self, enable: bool)

In-place variant to enable or disable gradient retention for non-leaf tensors.
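
§Examples

A minimal sketch of the in-place variant, assembled only from methods documented on this page:

use train_station::Tensor;

let x = Tensor::ones(vec![2, 2]).with_requires_grad();
let mut y = x.mul_scalar(2.0);
// Enable retention in place on an existing non-leaf tensor
y.retain_grad_(true);
let mut loss = y.sum();
loss.backward(None);
y.materialize_grad();
assert!(y.grad().is_some());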


pub fn requires_grad(&self) -> bool

Check if this tensor requires gradients

§Returns

true if gradient tracking is enabled for this tensor

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3]);
assert!(!tensor.requires_grad());

let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(grad_tensor.requires_grad());
Examples found in repository
examples/neural_networks/basic_linear_layer.rs (line 164)
153fn demonstrate_layer_creation() {
154    println!("--- Layer Creation ---");
155
156    let layer = LinearLayer::new(3, 2, Some(42));
157
158    println!("Created linear layer:");
159    println!("  Input size: {}", layer.input_size);
160    println!("  Output size: {}", layer.output_size);
161    println!("  Parameter count: {}", layer.parameter_count());
162    println!("  Weight shape: {:?}", layer.weight.shape().dims());
163    println!("  Bias shape: {:?}", layer.bias.shape().dims());
164    println!("  Weight requires grad: {}", layer.weight.requires_grad());
165    println!("  Bias requires grad: {}", layer.bias.requires_grad());
166}
167
168/// Demonstrate forward pass with gradient tracking
169fn demonstrate_forward_pass() {
170    println!("\n--- Forward Pass (with gradients) ---");
171
172    let layer = LinearLayer::new(3, 2, Some(43));
173
174    // Single input
175    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176    let output = layer.forward(&input);
177
178    println!("Single input:");
179    println!("  Input: {:?}", input.data());
180    println!("  Output: {:?}", output.data());
181    println!("  Output requires grad: {}", output.requires_grad());
182
183    // Batch input
184    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185    let batch_output = layer.forward(&batch_input);
186
187    println!("Batch input:");
188    println!("  Input shape: {:?}", batch_input.shape().dims());
189    println!("  Output shape: {:?}", batch_output.shape().dims());
190    println!("  Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195    println!("\n--- Forward Pass (no gradients) ---");
196
197    let layer = LinearLayer::new(3, 2, Some(44));
198
199    // Single input
200    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201    let output = layer.forward_no_grad(&input);
202
203    println!("Single input (no grad):");
204    println!("  Input: {:?}", input.data());
205    println!("  Output: {:?}", output.data());
206    println!("  Output requires grad: {}", output.requires_grad());
207
208    // Compare with grad version
209    let output_with_grad = layer.forward(&input);
210    println!("Comparison:");
211    println!(
212        "  Same values: {}",
213        output.data() == output_with_grad.data()
214    );
215    println!("  No grad requires grad: {}", output.requires_grad());
216    println!(
217        "  With grad requires grad: {}",
218        output_with_grad.requires_grad()
219    );
220}
More examples
examples/iterators/element_iteration.rs (line 192)
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
examples/getting_started/optimizer_basics.rs (line 58)
47fn demonstrate_basic_optimizer_setup() {
48    println!("--- Basic Optimizer Setup ---");
49
50    // Create parameters that require gradients
51    let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52    let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54    println!("Created parameters:");
55    println!(
56        "  Weight: shape {:?}, requires_grad: {}",
57        weight.shape().dims(),
58        weight.requires_grad()
59    );
60    println!(
61        "  Bias: shape {:?}, requires_grad: {}",
62        bias.shape().dims(),
63        bias.requires_grad()
64    );
65
66    // Create Adam optimizer with default configuration
67    let mut optimizer = Adam::new();
68    println!(
69        "Created Adam optimizer with learning rate: {}",
70        optimizer.learning_rate()
71    );
72
73    // Add parameters to optimizer
74    optimizer.add_parameter(&weight);
75    optimizer.add_parameter(&bias);
76    println!(
77        "Added {} parameters to optimizer",
78        optimizer.parameter_count()
79    );
80
81    // Create optimizer with custom configuration
82    let config = AdamConfig {
83        learning_rate: 0.01,
84        beta1: 0.9,
85        beta2: 0.999,
86        eps: 1e-8,
87        weight_decay: 0.0,
88        amsgrad: false,
89    };
90
91    let mut custom_optimizer = Adam::with_config(config);
92    custom_optimizer.add_parameter(&weight);
93    custom_optimizer.add_parameter(&bias);
94
95    println!(
96        "Created custom optimizer with learning rate: {}",
97        custom_optimizer.learning_rate()
98    );
99
100    // Demonstrate parameter linking
101    println!("Parameter linking completed successfully");
102}
examples/neural_networks/feedforward_network.rs (line 332)
313fn demonstrate_forward_pass() {
314    println!("\n--- Forward Pass ---");
315
316    let config = FeedForwardConfig {
317        input_size: 3,
318        hidden_sizes: vec![5, 3],
319        output_size: 2,
320        use_bias: true,
321    };
322    let network = FeedForwardNetwork::new(config, Some(43));
323
324    // Single input
325    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
326    let output = network.forward(&input);
327
328    println!("Single input forward pass:");
329    println!("  Input shape: {:?}", input.shape().dims());
330    println!("  Output shape: {:?}", output.shape().dims());
331    println!("  Output: {:?}", output.data());
332    println!("  Output requires grad: {}", output.requires_grad());
333
334    // Batch input
335    let batch_input = Tensor::from_slice(
336        &[
337            1.0, 2.0, 3.0, // Sample 1
338            4.0, 5.0, 6.0, // Sample 2
339            7.0, 8.0, 9.0, // Sample 3
340        ],
341        vec![3, 3],
342    )
343    .unwrap();
344    let batch_output = network.forward(&batch_input);
345
346    println!("Batch input forward pass:");
347    println!("  Input shape: {:?}", batch_input.shape().dims());
348    println!("  Output shape: {:?}", batch_output.shape().dims());
349    println!("  Output requires grad: {}", batch_output.requires_grad());
350
351    // Compare with no-grad version
352    let output_no_grad = network.forward_no_grad(&input);
353    println!("No-grad comparison:");
354    println!("  Same values: {}", output.data() == output_no_grad.data());
355    println!("  With grad requires grad: {}", output.requires_grad());
356    println!(
357        "  No grad requires grad: {}",
358        output_no_grad.requires_grad()
359    );
360}
examples/iterators/performance_optimization.rs (line 416)
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320    println!("\n--- Optimization Techniques ---");
321
322    let size = 50000;
323    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324    let tensor = Tensor::from_slice(&data, vec![size])?;
325
326    println!("Optimizing processing for size: {}", size);
327
328    // Technique 1: Operation fusion
329    println!("\nTechnique 1: Operation Fusion");
330    let start = Instant::now();
331    let fused_result: Tensor = tensor
332        .iter()
333        .map(|elem| {
334            // Fuse multiple operations into single chain
335            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336        })
337        .collect();
338    let fused_time = start.elapsed();
339
340    // Technique 2: Conditional optimization
341    println!("\nTechnique 2: Conditional Optimization");
342    let start = Instant::now();
343    let conditional_result: Tensor = tensor
344        .iter()
345        .map(|elem| {
346            let val = elem.value();
347            if val < size as f32 / 2.0 {
348                elem.mul_scalar(2.0) // Simple operation for small values
349            } else {
350                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351            }
352        })
353        .collect();
354    let conditional_time = start.elapsed();
355
356    // Technique 3: Cache-friendly processing
357    println!("\nTechnique 3: Cache-Friendly Processing");
358    let start = Instant::now();
359    let cache_friendly_result: Tensor = tensor
360        .iter()
361        .take(1000) // Process in cache-friendly chunks
362        .map(|elem| elem.mul_scalar(2.0))
363        .collect();
364    let cache_friendly_time = start.elapsed();
365
366    // Technique 4: Memory pooling simulation
367    println!("\nTechnique 4: Memory Pooling Simulation");
368    let start = Instant::now();
369    let pooled_result: Tensor = tensor
370        .iter()
371        .enumerate()
372        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373        .map(|(_, elem)| elem.pow_scalar(2.0))
374        .collect();
375    let pooled_time = start.elapsed();
376
377    // Report optimization results
378    println!("  Fused operations: {:?}", fused_time);
379    println!("  Conditional optimization: {:?}", conditional_time);
380    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
381    println!("  Memory pooling simulation: {:?}", pooled_time);
382
383    // Performance analysis
384    let fastest = fused_time
385        .min(conditional_time)
386        .min(cache_friendly_time)
387        .min(pooled_time);
388    println!("  Fastest technique: {:?}", fastest);
389
390    // Memory efficiency analysis
391    println!("  Fused result size: {}", fused_result.size());
392    println!("  Conditional result size: {}", conditional_result.size());
393    println!(
394        "  Cache-friendly result size: {}",
395        cache_friendly_result.size()
396    );
397    println!("  Pooled result size: {}", pooled_result.size());
398
399    // Technique 5: Gradient optimization
400    println!("\nTechnique 5: Gradient Optimization");
401    let grad_tensor = tensor.with_requires_grad();
402    let start = Instant::now();
403
404    let grad_result: Tensor = grad_tensor
405        .iter()
406        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407        .collect();
408
409    let mut loss = grad_result.sum();
410    loss.backward(None);
411    let grad_time = start.elapsed();
412
413    println!("  Gradient computation: {:?}", grad_time);
414    println!(
415        "  Gradient tracking enabled: {}",
416        grad_result.requires_grad()
417    );
418
419    Ok(())
420}

pub fn grad(&self) -> Option<&Tensor>

Get a reference to this tensor’s locally cached gradient (if any)

This accessor returns only the gradient cached on this tensor’s grad field. It does NOT query the global/autograd gradient storage. For leaf tensors, gradients are accumulated and tracked by the grad engine and are not automatically written back to the local grad field.

To make grad() return Some(&Tensor):

  • Enable retention on this tensor (typically for non-leaf tensors) with tensor.retain_grad_(true), or use the builder-style tensor.retain_grad()
  • After backward(), call tensor.materialize_grad() or tensor.grad_or_fetch() to copy the accumulated gradient from the autograd engine into self.grad

If you want to read gradients without caching them locally, prefer grad_owned, which consults the global gradient storage.

§Returns

Optional reference to the locally cached gradient tensor, or None if not materialized on this tensor.

§Examples
use train_station::Tensor;

let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);
// Without materialization, grad() typically returns None for leaves
assert!(x.grad().is_none());
// Materialize to cache locally so grad() works
x.retain_grad_(true);
x.materialize_grad();
assert!(x.grad().is_some());
Examples found in repository
examples/iterators/element_iteration.rs (line 199)
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}

pub fn materialize_grad(&mut self) -> bool

Fetch the accumulated gradient after backward and cache it on this tensor if retain_grad is enabled.

Returns true if a gradient was found and cached. After a successful call, grad() will return Some(&Tensor) even for non-leaf tensors.
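
§Examples

A minimal sketch mirroring the grad() example above, asserting the boolean return value:

use train_station::Tensor;

let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);
x.retain_grad_(true);
// Returns true when an accumulated gradient was found and cached
assert!(x.materialize_grad());
assert!(x.grad().is_some());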


pub fn grad_or_fetch(&mut self) -> Option<&Tensor>

Convenience accessor: if retain_grad is enabled, fetch and cache the gradient on first access so callers can immediately get a reference.
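
§Examples

A minimal sketch, assuming retention has been enabled as described above:

use train_station::Tensor;

let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);
x.retain_grad_(true);
// Fetches from the grad engine and caches locally on first access
let g = x.grad_or_fetch().unwrap();
assert_eq!(g.shape().dims(), vec![2, 3]);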


pub fn grad_owned(&self) -> Option<Tensor>

Get the accumulated gradient from the autograd engine as an owned tensor

This accessor queries the global/shared gradient storage for this tensor’s ID and returns the current accumulated gradient by value. It complements grad, which only returns a locally cached reference.

  • Works for leaf tensors without any prior materialization step
  • Does not modify or clear internal gradient state
  • Suitable for optimizers and logging that need the latest gradients
§Returns

Some(Tensor) containing the current accumulated gradient when available, otherwise None.

§Examples
use train_station::Tensor;

let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);

// Fetch directly from autograd storage (no materialization required)
let g = x.grad_owned().unwrap();
assert_eq!(g.shape().dims(), vec![2, 3]);
Examples found in repository
examples/RL_training/dqn.rs (line 280)
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278    let mut total_sq = 0.0f32;
279    for p in parameters.iter() {
280        if let Some(g) = p.grad_owned() {
281            for &v in g.data() {
282                total_sq += v * v;
283            }
284        }
285    }
286    let norm = total_sq.sqrt();
287    if norm > max_norm {
288        let scale = max_norm / (norm + eps);
289        for p in parameters.iter_mut() {
290            if let Some(g) = p.grad_owned() {
291                p.set_grad(g.mul_scalar(scale));
292            }
293        }
294    }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298    let mut total_sq = 0.0f32;
299    for p in parameters.iter_mut() {
300        if let Some(g) = p.grad_owned() {
301            for &v in g.data() {
302                total_sq += v * v;
303            }
304        }
305    }
306    total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310    let _ng = NoGradTrack::new();
311    let mut total_sq = 0.0f32;
312    for p in parameters.iter_mut() {
313        for &v in p.data() {
314            total_sq += v * v;
315        }
316    }
317    total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
More examples
examples/RL_training/ppo_continuous.rs (line 296)
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
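The entropy term in the listing above (lines 541-546) folds the closed-form differential entropy of a Gaussian, H = 0.5*ln(2*pi*e) + log_std per action dimension, into a tensor expression. A minimal standalone check of that constant, using plain f32 arithmetic and independent of the Tensor API:

fn main() {
    // 0.5 * ln(2*pi*e) ≈ 1.4189385; Gaussian entropy is this constant plus log_std
    let half_ln_2pie = 0.5 * (2.0f32 * std::f32::consts::PI * std::f32::consts::E).ln();
    let log_std = -0.5f32; // illustrative value
    let entropy_per_dim = half_ln_2pie + log_std;
    println!("0.5*ln(2*pi*e) = {:.7}", half_ln_2pie);
    println!("entropy(log_std = -0.5) = {:.7}", entropy_per_dim);
}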
examples/RL_training/ppo_discrete.rs (line 255)
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    ge_low.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
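Both PPO listings clamp the probability ratio with ReLU identities instead of a dedicated clamp op: max(x, l) = relu(x - l) + l and min(a, h) = a - relu(a - h). A scalar sanity check of the composed clamp, using plain f32 and independent of the Tensor API:

fn main() {
    let (low, high) = (0.8f32, 1.2f32); // 1 ± clip_eps with clip_eps = 0.2
    for &x in &[0.5f32, 1.0, 1.5] {
        let ge_low = (x - low).max(0.0) + low; // max(x, low)
        let clamped = ge_low - (ge_low - high).max(0.0); // min(ge_low, high)
        assert!((clamped - x.clamp(low, high)).abs() < 1e-6);
        println!("x={:.1} -> clamped={:.1}", x, clamped);
    }
}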
examples/supervised_training/supervised_bce.rs (line 26)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44    // pred: [B,1] with sigmoid; threshold at 0.5
45    let p = pred.data();
46    let t = targets.data();
47    let mut correct = 0usize;
48    for i in 0..p.len() {
49        let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50        if (yhat - t[i]).abs() < 1e-6 {
51            correct += 1;
52        }
53    }
54    correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
67
68pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69    println!("=== Supervised FFN Example (XOR) ===");
70
71    // Dataset: XOR (repeat to form a small batch)
72    let inputs: Vec<f32> = vec![
73        0.0, 0.0, // -> 0
74        0.0, 1.0, // -> 1
75        1.0, 0.0, // -> 1
76        1.0, 1.0, // -> 0
77    ];
78    let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80    // Repeat the base patterns to stabilize training
81    let repeats = 64usize; // effective batch = 4 * repeats = 256
82    let mut xs = Vec::with_capacity(repeats * inputs.len());
83    let mut ys = Vec::with_capacity(repeats * targets.len());
84    for _ in 0..repeats {
85        xs.extend_from_slice(&inputs);
86        ys.extend_from_slice(&targets);
87    }
88
89    let batch = xs.len() / 2; // two features
90    let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91    let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93    // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94    let cfg = FeedForwardConfig {
95        input_size: 2,
96        hidden_sizes: vec![32, 32],
97        output_size: 1,
98        use_bias: true,
99    };
100    let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102    // Optimizer and parameter linking
103    let mut opt = Adam::with_learning_rate(1e-3);
104    for p in net.parameters() {
105        opt.add_parameter(p);
106    }
107
108    let epochs = 1000usize;
109    let max_grad_norm = 1.0f32;
110    let mut best_loss = f32::INFINITY;
111    let mut best_acc = 0.0f32;
112
113    for e in 0..epochs {
114        // Zero grads each iteration
115        {
116            let mut params = net.parameters();
117            opt.zero_grad(&mut params);
118        }
119
120        // Forward -> logits; use numerically stable BCE-with-logits for loss
121        let logits = net.forward(&x_t);
122        let mut loss = bce_with_logits(&logits, &y_t);
123        loss.backward(None);
124
125        // Step only params with grads
126        {
127            let params = net.parameters();
128            let mut with_grads: Vec<&mut Tensor> = Vec::new();
129            for p in params {
130                if p.grad_owned().is_some() {
131                    with_grads.push(p);
132                }
133            }
134            if !with_grads.is_empty() {
135                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136                opt.step(&mut with_grads);
137                opt.zero_grad(&mut with_grads);
138            }
139        }
140
141        // Metrics (use sigmoid only for reporting accuracy)
142        let preds = logits.sigmoid();
143        let acc = accuracy(&preds, &y_t);
144        if loss.value() < best_loss {
145            best_loss = loss.value();
146        }
147        if acc > best_acc {
148            best_acc = acc;
149        }
150        if e % 10 == 0 || e + 1 == epochs {
151            println!(
152                "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153                e,
154                loss.value(),
155                acc,
156                best_loss,
157                best_acc
158            );
159        }
160
161        // Clear graphs to avoid stale accumulation across epochs
162        clear_all_graphs_known();
163    }
164
165    // Quick sanity check predictions
166    let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167    let out = net.forward(&test).sigmoid();
168    println!("predictions (approx): {:?}", out.data());
169
170    println!("=== Supervised training finished ===");
171    Ok(())
172}
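The bce_with_logits helper above relies on the algebraic identity -[y*ln(sigma(z)) + (1 - y)*ln(1 - sigma(z))] = relu(z) - z*y + ln(1 + exp(-|z|)), which avoids evaluating sigma(z) near 0 or 1. A scalar check of the identity, plain f32 and independent of the Tensor API:

fn main() {
    for &(z, y) in &[(2.0f32, 1.0f32), (-3.0, 0.0), (0.5, 1.0)] {
        let sigma = 1.0 / (1.0 + (-z).exp());
        let naive = -(y * sigma.ln() + (1.0 - y) * (1.0 - sigma).ln());
        let stable = z.max(0.0) - z * y + (1.0 + (-z.abs()).exp()).ln();
        assert!((naive - stable).abs() < 1e-5);
        println!("z={:+.1} y={} -> bce={:.6}", z, y, stable);
    }
}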
examples/supervised_training/supervised_classification.rs (line 26)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
60
61fn accuracy_from_logits(
62    logits: &Tensor,
63    labels: &[usize],
64    batch: usize,
65    num_classes: usize,
66) -> f32 {
67    let row = logits.data();
68    let mut correct = 0usize;
69    for (i, &label) in labels.iter().enumerate().take(batch) {
70        let base = i * num_classes;
71        let mut best_j = 0usize;
72        let mut best_v = row[base];
73        for j in 1..num_classes {
74            let v = row[base + j];
75            if v > best_v {
76                best_v = v;
77                best_j = j;
78            }
79        }
80        if best_j == label {
81            correct += 1;
82        }
83    }
84    correct as f32 / batch as f32
85}
86
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88    println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90    // Synthetic 2D inputs, 3 classes with linear-ish separations
91    let n = 1200usize;
92    let classes = 3usize;
93    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94    let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96    // Simple RNG
97    let mut state: u64 = 424242;
98    let mut rand_f32 = || {
99        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100        ((state >> 16) as u32) as f32 / (u32::MAX as f32) // truncate to 32 bits so the result stays in [0, 1]
101    };
102
103    for _ in 0..n {
104        let x1 = rand_f32() * 4.0 - 2.0;
105        let x2 = rand_f32() * 4.0 - 2.0;
106        // Class by quadrant-ish rule with noise
107        let mut c = if x1 + 0.5 * x2 > 0.5 {
108            0
109        } else if x1 - x2 < -0.5 {
110            1
111        } else {
112            2
113        };
114        if rand_f32() < 0.05 {
115            c = (c + 1) % classes;
116        }
117        xs.push(x1);
118        xs.push(x2);
119        ys.push(c);
120    }
121
122    // Normalize inputs per-feature to [-1, 1]
123    let mut min1 = f32::INFINITY;
124    let mut max1 = f32::NEG_INFINITY;
125    let mut min2 = f32::INFINITY;
126    let mut max2 = f32::NEG_INFINITY;
127    for i in (0..xs.len()).step_by(2) {
128        let a = xs[i];
129        let b = xs[i + 1];
130        if a < min1 {
131            min1 = a;
132        }
133        if a > max1 {
134            max1 = a;
135        }
136        if b < min2 {
137            min2 = b;
138        }
139        if b > max2 {
140            max2 = b;
141        }
142    }
143    let rng1 = (max1 - min1).max(1e-8);
144    let rng2 = (max2 - min2).max(1e-8);
145    for i in (0..xs.len()).step_by(2) {
146        let a = xs[i];
147        let b = xs[i + 1];
148        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150    }
151
152    // Train/Val split (80/20)
153    let n_train = (n as f32 * 0.8) as usize;
154    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155    let y_train = ys[..n_train].to_vec();
156    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157    let y_val = ys[n_train..].to_vec();
158
159    // Model: 2 -> 64 -> 64 -> 3 (logits)
160    let cfg = FeedForwardConfig {
161        input_size: 2,
162        hidden_sizes: vec![64, 64],
163        output_size: classes,
164        use_bias: true,
165    };
166    let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168    // Optimizer
169    let mut opt = Adam::with_learning_rate(1e-3);
170    for p in net.parameters() {
171        opt.add_parameter(p);
172    }
173
174    let epochs = 300usize;
175    let max_grad_norm = 1.0f32;
176    let mut best_val_acc = 0.0f32;
177    let mut best_val_loss = f32::INFINITY;
178
179    for e in 0..epochs {
180        // Zero grads
181        {
182            let mut params = net.parameters();
183            opt.zero_grad(&mut params);
184        }
185
186        // Forward logits
187        let logits = net.forward(&x_train);
188        let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189        loss.backward(None);
190
191        // Step clipped
192        {
193            let params = net.parameters();
194            let mut with_grads: Vec<&mut Tensor> = Vec::new();
195            for p in params {
196                if p.grad_owned().is_some() {
197                    with_grads.push(p);
198                }
199            }
200            if !with_grads.is_empty() {
201                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202                opt.step(&mut with_grads);
203                opt.zero_grad(&mut with_grads);
204            }
205        }
206
207        // Metrics
208        let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209        let val_logits = net.forward(&x_val);
210        let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211        let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212        if val_acc > best_val_acc {
213            best_val_acc = val_acc;
214        }
215        if val_loss < best_val_loss {
216            best_val_loss = val_loss;
217        }
218
219        if e % 10 == 0 || e + 1 == epochs {
220            println!(
221                "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222                e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223            );
224        }
225
226        clear_all_graphs_known();
227    }
228
229    // Quick sample preds via softmax
230    let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231    let sm = net.forward(&samples).softmax(1);
232    println!("sample class probs: {:?}", sm.data());
233
234    println!("=== Supervised classification finished ===");
235    Ok(())
236}
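cross_entropy_logits above subtracts the per-row maximum before exponentiating; this log-sum-exp shift leaves softmax unchanged while preventing f32 overflow. A scalar illustration, independent of the Tensor API:

fn main() {
    let logits = [1000.0f32, 1001.0, 1002.0]; // naive exp() here would overflow f32
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exp: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exp.iter().sum();
    let probs: Vec<f32> = exp.iter().map(|e| e / sum).collect();
    println!("softmax = {:?}", probs); // ≈ [0.0900, 0.2447, 0.6652]
}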
examples/supervised_training/supervised_regression.rs (line 25)
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23    let mut total_sq = 0.0f32;
24    for p in parameters.iter() {
25        if let Some(g) = p.grad_owned() {
26            for &v in g.data() {
27                total_sq += v * v;
28            }
29        }
30    }
31    let norm = total_sq.sqrt();
32    if norm > max_norm {
33        let scale = max_norm / (norm + eps);
34        for p in parameters.iter_mut() {
35            if let Some(g) = p.grad_owned() {
36                p.set_grad(g.mul_scalar(scale));
37            }
38        }
39    }
40}
41
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43    pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
60
61pub fn main() -> Result<(), Box<dyn std::error::Error>> {
62    println!("=== Supervised Regression Example (MSE) ===");
63
64    // Generate simple synthetic data: y = 2*x1 - 3*x2 + 0.5 + noise
65    let n = 1024usize;
66    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
67    let mut ys: Vec<f32> = Vec::with_capacity(n);
68    // Simple LCG RNG for reproducibility
69    let mut state: u64 = 123456789;
70    let mut rand_f32 = || {
71        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
72        ((state >> 16) as u32) as f32 / (u32::MAX as f32) // truncate to 32 bits so the result stays in [0, 1]
73    };
74    for _ in 0..n {
75        let x1 = rand_f32() * 2.0 - 1.0;
76        let x2 = rand_f32() * 2.0 - 1.0;
77        let noise = (rand_f32() * 2.0 - 1.0) * 0.05;
78        let y = 2.0 * x1 - 3.0 * x2 + 0.5 + noise;
79        xs.push(x1);
80        xs.push(x2);
81        ys.push(y);
82    }
83
84    // Normalize targets to [-1, 1] (max-abs scaling) for reasonable loss magnitudes
85    let mut max_abs = 0.0f32;
86    for &v in &ys {
87        let a = v.abs();
88        if a > max_abs {
89            max_abs = a;
90        }
91    }
92    if max_abs < 1e-8 {
93        max_abs = 1.0;
94    }
95    for v in ys.iter_mut() {
96        *v /= max_abs;
97    }
98
99    // Normalize inputs per-feature to [-1, 1] (min-max scaling)
100    let mut min1 = f32::INFINITY;
101    let mut max1 = f32::NEG_INFINITY;
102    let mut min2 = f32::INFINITY;
103    let mut max2 = f32::NEG_INFINITY;
104    for i in (0..xs.len()).step_by(2) {
105        let a = xs[i];
106        let b = xs[i + 1];
107        if a < min1 {
108            min1 = a;
109        }
110        if a > max1 {
111            max1 = a;
112        }
113        if b < min2 {
114            min2 = b;
115        }
116        if b > max2 {
117            max2 = b;
118        }
119    }
120    let rng1 = (max1 - min1).max(1e-8);
121    let rng2 = (max2 - min2).max(1e-8);
122    for i in (0..xs.len()).step_by(2) {
123        let a = xs[i];
124        let b = xs[i + 1];
125        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
126        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
127    }
128
129    // Train/Val split (80/20)
130    let n_train = (n as f32 * 0.8) as usize;
131    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
132    let y_train = Tensor::from_slice(&ys[..n_train], vec![n_train, 1]).unwrap();
133    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
134    let y_val = Tensor::from_slice(&ys[n_train..], vec![n - n_train, 1]).unwrap();
135
136    // Model config: 2 -> 64 -> 64 -> 1
137    let cfg = FeedForwardConfig {
138        input_size: 2,
139        hidden_sizes: vec![64, 64],
140        output_size: 1,
141        use_bias: true,
142    };
143    let mut net = FeedForwardNetwork::new(cfg, Some(2025));
144
145    // Optimizer and parameter linking
146    let mut opt = Adam::with_learning_rate(1e-3);
147    for p in net.parameters() {
148        opt.add_parameter(p);
149    }
150
151    let epochs = 400usize;
152    let max_grad_norm = 1.0f32;
153    let mut best_val_rmse = f32::INFINITY;
154    let mut best_val_r2 = -f32::INFINITY;
155
156    for e in 0..epochs {
157        // Zero grads
158        {
159            let mut params = net.parameters();
160            opt.zero_grad(&mut params);
161        }
162
163        // Forward
164        let pred = net.forward(&x_train);
165        let mut loss = mse(&pred, &y_train);
166        loss.backward(None);
167
168        // Step
169        {
170            let params = net.parameters();
171            let mut with_grads: Vec<&mut Tensor> = Vec::new();
172            for p in params {
173                if p.grad_owned().is_some() {
174                    with_grads.push(p);
175                }
176            }
177            if !with_grads.is_empty() {
178                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
179                opt.step(&mut with_grads);
180                opt.zero_grad(&mut with_grads);
181            }
182        }
183
184        // Metrics
185        let train_rmse = rmse(&pred, &y_train);
186        let train_r2 = r2_score(&pred, &y_train);
187        let val_pred = net.forward(&x_val);
188        let val_rmse = rmse(&val_pred, &y_val);
189        let val_r2 = r2_score(&val_pred, &y_val);
190        if val_rmse < best_val_rmse {
191            best_val_rmse = val_rmse;
192        }
193        if val_r2 > best_val_r2 {
194            best_val_r2 = val_r2;
195        }
196
197        if e % 20 == 0 || e + 1 == epochs {
198            // Clamp displayed R^2 to avoid huge negative prints on early epochs
199            let train_r2_disp = train_r2.max(-10.0);
200            let val_r2_disp = val_r2.max(-10.0);
201            println!(
202                "epoch {:4} | train_rmse={:.4} r2={:.3} | val_rmse={:.4} r2={:.3} | best_val_rmse={:.4} best_val_r2={:.3}",
203                e, train_rmse, train_r2_disp, val_rmse, val_r2_disp, best_val_rmse, best_val_r2
204            );
205        }
206
207        clear_all_graphs_known();
208    }
209
210    // Quick sanity predictions on small samples
211    let sample = Tensor::from_slice(&[0.5, -0.25, -0.8, 0.3], vec![2, 2]).unwrap();
212    let sample_pred = net.forward(&sample);
213    println!("samples pred: {:?}", sample_pred.data());
214
215    println!("=== Supervised regression finished ===");
216    Ok(())
217}
Source

pub fn id(&self) -> usize

Get the unique ID of this tensor

Returns the unique identifier assigned to this tensor during creation. This ID is used for gradtrack tracking and tensor identification.

§Returns

Unique tensor ID as usize

§Examples
use train_station::Tensor;

let tensor1 = Tensor::new(vec![2, 3]);
let tensor2 = Tensor::new(vec![2, 3]);
assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
Source

pub fn detach(&self) -> Self

Detach this tensor from the computation graph

Returns a new tensor with the same data but no gradient tracking. This is useful when you want to use a tensor in inference without affecting the computation graph.

§Returns

A new tensor with the same data but gradient tracking disabled

§Examples
use train_station::Tensor;

let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
let detached = tensor.detach();
assert!(!detached.requires_grad());
assert_eq!(tensor.size(), detached.size());
Source

pub fn detach_(&mut self)

Detach this tensor from the computation graph in place

Similar to detach() but modifies this tensor in place. This is useful when you want to disable gradient tracking for the current tensor without creating a copy.

§Examples
use train_station::Tensor;

let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(tensor.requires_grad());
tensor.detach_();
assert!(!tensor.requires_grad());
Source

pub fn backward(&mut self, grad_output: Option<Tensor>)

Entry point for backward pass on this tensor

Computes gradients for all tensors in the computation graph that have requires_grad set to true. This is the main entry point for automatic differentiation.

§Arguments
  • grad_output - Optional gradient tensor for the output. If None, assumes the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
§Examples
use train_station::Tensor;

let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut result = tensor.add_scalar(5.0);
result.backward(None);
// Note: Gradient computation depends on the gradtrack system implementation
Examples found in repository
examples/neural_networks/basic_encoder.rs (line 95)
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}
More examples
examples/neural_networks/basic_decoder.rs (line 106)
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85    println!("=== Basic Decoder Example ===");
86
87    let batch = 2usize;
88    let src = 7usize;
89    let tgt = 5usize;
90    let embed = 32usize;
91    let heads = 4usize;
92
93    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96    let mut dec = DecoderBlock::new(embed, heads, Some(456));
97    let out = dec.forward(&tgt_in, &memory, None, None);
98    println!("Output shape: {:?}", out.shape().dims());
99
100    let mut opt = Adam::with_learning_rate(0.01);
101    let mut params = dec.parameters();
102    for p in &params {
103        opt.add_parameter(p);
104    }
105    let mut loss = out.mean();
106    loss.backward(None);
107    opt.step(&mut params);
108    opt.zero_grad(&mut params);
109    println!("Loss: {:.6}", loss.value());
110    println!("=== Done ===");
111    Ok(())
112}
examples/iterators/element_iteration.rs (line 196)
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
examples/neural_networks/basic_transformer.rs (line 171)
148    pub fn train_non_autoregressive_steps(
149        &mut self,
150        src: &Tensor,
151        tgt: &Tensor,
152        steps: usize,
153        lr: f32,
154    ) {
155        let mut opt = Adam::with_learning_rate(lr);
156        {
157            let params_once = self.parameters();
158            for p in &params_once {
159                opt.add_parameter(p);
160            }
161        }
162        for step in 0..steps {
163            // forward + backward scope (immutable borrow)
164            {
165                let pred = self.forward(src, tgt);
166                let diff = pred.sub_tensor(tgt);
167                let mut loss = diff.pow_scalar(2.0).mean();
168                if step == 0 || step + 1 == steps {
169                    println!("NAR train step {}: loss={:.6}", step, loss.value());
170                }
171                loss.backward(None);
172            }
173            // step + zero_grad scope (mutable borrow)
174            let mut params_step = self.parameters();
175            opt.step(&mut params_step);
176            opt.zero_grad(&mut params_step);
177        }
178    }
179
180    /// Auto-regressive training (teacher forcing): predict next token with causal mask
181    pub fn train_autoregressive_steps(
182        &mut self,
183        src: &Tensor,
184        tgt: &Tensor,
185        steps: usize,
186        lr: f32,
187    ) {
188        let mut opt = Adam::with_learning_rate(lr);
189        {
190            let params_once = self.parameters();
191            for p in &params_once {
192                opt.add_parameter(p);
193            }
194        }
195
196        // Build encoder memory once (static dataset demo)
197        let mut memory = src.clone();
198        for enc in &self.encoders {
199            memory = enc.forward(&memory, None);
200        }
201
202        let (b, t, _e) = Self::triple(tgt);
203        // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204        let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205        for step in 0..steps {
206            // forward + backward scope
207            {
208                let mut out = tgt.clone();
209                for dec in &self.decoders {
210                    out = dec.forward(&out, &memory, Some(&causal), None);
211                }
212                let diff = out.sub_tensor(tgt);
213                let mut loss = diff.pow_scalar(2.0).mean();
214                if step == 0 || step + 1 == steps {
215                    println!("AR  train step {}: loss={:.6}", step, loss.value());
216                }
217                loss.backward(None);
218            }
219            let mut params_step = self.parameters();
220            opt.step(&mut params_step);
221            opt.zero_grad(&mut params_step);
222        }
223    }
224
225    fn triple(t: &Tensor) -> (usize, usize, usize) {
226        let d = t.shape().dims();
227        (d[0], d[1], d[2])
228    }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232    println!("=== Basic Transformer Example ===");
233
234    let batch = 2usize;
235    let src_len = 8usize;
236    let tgt_len = 6usize;
237    let embed = 32usize;
238    let heads = 4usize;
239    let layers = 2usize;
240
241    let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242    let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244    let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245    let out = trf.forward(&src, &tgt);
246    println!("Output shape: {:?}", out.shape().dims());
247
248    // Quick optimization step
249    let mut opt = Adam::with_learning_rate(0.005);
250    let mut params = trf.parameters();
251    for p in &params {
252        opt.add_parameter(p);
253    }
254    let mut loss = out.mean();
255    loss.backward(None);
256    opt.step(&mut params);
257    opt.zero_grad(&mut params);
258    println!("Loss: {:.6}", loss.value());
259
260    // Demo: non auto-regressive inference (single pass)
261    let nar = trf.infer_non_autoregressive(&src, tgt_len);
262    println!("NAR output shape: {:?}", nar.shape().dims());
263
264    // Demo: auto-regressive inference (toy)
265    let ar = trf.infer_autoregressive(&src, 3);
266    println!("AR output shape: {:?}", ar.shape().dims());
267
268    // NAR training demo
269    let nar_tgt = tgt.clone();
270    trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272    // AR training demo (teacher-forced)
273    let ar_tgt = tgt.clone();
274    trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275    println!("=== Done ===");
276    Ok(())
277}
examples/getting_started/serialization_basics.rs (line 138)
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110    println!("\n--- Optimizer Serialization ---");
111
112    // Create an optimizer with some parameters
113    let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114    let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116    let config = AdamConfig {
117        learning_rate: 0.001,
118        beta1: 0.9,
119        beta2: 0.999,
120        eps: 1e-8,
121        weight_decay: 0.0,
122        amsgrad: false,
123    };
124
125    let mut optimizer = Adam::with_config(config);
126    optimizer.add_parameter(&weight);
127    optimizer.add_parameter(&bias);
128
129    println!(
130        "Created optimizer with {} parameters",
131        optimizer.parameter_count()
132    );
133    println!("Learning rate: {}", optimizer.learning_rate());
134
135    // Simulate some training steps
136    for _ in 0..3 {
137        let mut loss = weight.sum() + bias.sum();
138        loss.backward(None);
139        optimizer.step(&mut [&mut weight, &mut bias]);
140        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141    }
142
143    // Save optimizer state
144    let optimizer_path = "temp_optimizer.json";
145    optimizer.save_json(optimizer_path)?;
146    println!("Saved optimizer to: {}", optimizer_path);
147
148    // Load optimizer state
149    let loaded_optimizer = Adam::load_json(optimizer_path)?;
150    println!(
151        "Loaded optimizer with {} parameters",
152        loaded_optimizer.parameter_count()
153    );
154    println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156    // Verify optimizer state
157    assert_eq!(
158        optimizer.parameter_count(),
159        loaded_optimizer.parameter_count()
160    );
161    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162    println!("Optimizer serialization verification: PASSED");
163
164    Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169    println!("\n--- Format Comparison ---");
170
171    // Create a larger tensor for comparison
172    let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174    // Save in both formats
175    tensor.save_json("temp_comparison.json")?;
176    tensor.save_binary("temp_comparison.bin")?;
177
178    // Compare file sizes
179    let json_size = fs::metadata("temp_comparison.json")?.len();
180    let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182    println!("JSON file size: {} bytes", json_size);
183    println!("Binary file size: {} bytes", binary_size);
184    println!(
185        "Compression ratio: {:.2}x",
186        json_size as f64 / binary_size as f64
187    );
188
189    // Load and verify both formats
190    let json_tensor = Tensor::load_json("temp_comparison.json")?;
191    let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193    assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194    assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195    assert_eq!(tensor.data(), json_tensor.data());
196    assert_eq!(tensor.data(), binary_tensor.data());
197
198    println!("Format comparison verification: PASSED");
199
200    Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205    println!("\n--- Model Checkpointing ---");
206
207    // Create a simple model (weights and bias)
208    let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209    let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211    // Create optimizer
212    let mut optimizer = Adam::with_learning_rate(0.01);
213    optimizer.add_parameter(&weights);
214    optimizer.add_parameter(&bias);
215
216    println!("Initial weights: {:?}", weights.data());
217    println!("Initial bias: {:?}", bias.data());
218
219    // Simulate training
220    for epoch in 0..5 {
221        let mut loss = weights.sum() + bias.sum();
222        loss.backward(None);
223        optimizer.step(&mut [&mut weights, &mut bias]);
224        optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226        if epoch % 2 == 0 {
227            // Save checkpoint
228            let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229            fs::create_dir_all(&checkpoint_dir)?;
230
231            weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232            bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233            optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235            println!("Saved checkpoint for epoch {}", epoch);
236        }
237    }
238
239    // Load from checkpoint
240    let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241    let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242    let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244    println!("Loaded weights: {:?}", loaded_weights.data());
245    println!("Loaded bias: {:?}", loaded_bias.data());
246    println!(
247        "Loaded optimizer learning rate: {}",
248        loaded_optimizer.learning_rate()
249    );
250
251    // Verify checkpoint integrity
252    assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253    assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256    println!("Checkpointing verification: PASSED");
257
258    Ok(())
259}
examples/neural_networks/multi_head_attention.rs (line 209)
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166    println!("=== Multi-Head Attention Example ===");
167
168    let batch = 2usize;
169    let src_len = 5usize;
170    let tgt_len = 4usize;
171    let embed = 16usize;
172    let heads = 4usize;
173
174    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal mask for target self-attention shape [b, h, tq, tk]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184    if src_len >= tgt_len {
185        // set upper triangle to a large negative value for head 0
186        for i in 0..tgt_len {
187            for j in (i + 1)..src_len {
188                let idx = [0usize, 0usize, i, j];
189                // Quick set via data_mut using a slice view
190                let offset = mask.memory_offset(&idx);
191                let data = mask.data_mut();
192                data[offset] = -1e9;
193            }
194        }
195    }
196
197    let out = mha.forward(&query, &key, &value, Some(&mask));
198    println!("Output shape: {:?}", out.shape().dims());
199
200    // Tiny training step to confirm gradients are wired
201    let mut optimizer = Adam::with_learning_rate(0.01);
202    let mut params = mha.parameters();
203    for p in &params {
204        optimizer.add_parameter(p);
205    }
206
207    // Dummy loss = mean of output
208    let mut loss = out.mean();
209    loss.backward(None);
210    optimizer.step(&mut params);
211    optimizer.zero_grad(&mut params);
212
213    println!("Loss: {:.6}", loss.value());
214    println!("=== Done ===");
215    Ok(())
216}
Source

pub unsafe fn as_ptr(&self) -> *const f32

Returns a raw pointer to the tensor data for unsafe operations

§Safety

This is unsafe because it provides direct access to the underlying memory. The caller must ensure:

  • The tensor is not dropped while the pointer is used
  • No concurrent mutable access occurs
  • Bounds are respected
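§Examples

A minimal sketch of reading through the raw pointer while the invariants above hold (the tensor outlives the pointer and no concurrent mutable access occurs):

use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
unsafe {
    let ptr = tensor.as_ptr();
    assert_eq!(*ptr.add(3), 4.0); // read within bounds (size() == 4)
}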
Source

pub unsafe fn as_mut_ptr(&mut self) -> *mut f32

Returns a mutable raw pointer to the tensor data for unsafe operations

§Safety

This is unsafe because it provides direct mutable access to the underlying memory. The caller must ensure:

  • The tensor is not dropped while the pointer is used
  • No concurrent access occurs
  • Bounds are respected
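§Examples

A minimal sketch of writing through the raw pointer, assuming exclusive access for the pointer's lifetime:

use train_station::Tensor;

let mut tensor = Tensor::zeros(vec![4]);
let n = tensor.size();
unsafe {
    let ptr = tensor.as_mut_ptr();
    for i in 0..n {
        *ptr.add(i) = i as f32; // writes stay within bounds
    }
}
assert_eq!(tensor.data(), &[0.0, 1.0, 2.0, 3.0]);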
Source

pub fn grad_fn(&self) -> &GradFn

Get a reference to the gradient function (for gradtrack)

Returns a reference to the gradient function associated with this tensor. This is used internally by the gradtrack system to compute gradients.

§Returns

Reference to the gradient function

§Implementation Details

This method is used by the gradtrack engine to access the gradient computation function during backward pass.
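§Examples

A minimal sketch; GradFn is treated here as an opaque handle consumed by the gradtrack engine:

use train_station::Tensor;

let tensor = Tensor::ones(vec![2, 2]).with_requires_grad();
let _node = tensor.grad_fn(); // inspected internally during the backward pass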

Source

pub fn set_grad(&mut self, grad: Tensor)

Set gradient from external source

Sets the gradient tensor for this tensor. This is used internally by the gradtrack system to set gradients during backward pass.

§Arguments
  • grad - The gradient tensor to set
§Implementation Details

This method is used internally by the gradtrack engine to set gradients during backward pass. It only sets the gradient if gradient tracking is enabled for this tensor.

Examples found in repository
examples/RL_training/dqn.rs (line 291)
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278    let mut total_sq = 0.0f32;
279    for p in parameters.iter() {
280        if let Some(g) = p.grad_owned() {
281            for &v in g.data() {
282                total_sq += v * v;
283            }
284        }
285    }
286    let norm = total_sq.sqrt();
287    if norm > max_norm {
288        let scale = max_norm / (norm + eps);
289        for p in parameters.iter_mut() {
290            if let Some(g) = p.grad_owned() {
291                p.set_grad(g.mul_scalar(scale));
292            }
293        }
294    }
295}
More examples
examples/RL_training/ppo_continuous.rs (line 307)
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
examples/RL_training/ppo_discrete.rs (line 266)
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
examples/supervised_training/supervised_bce.rs (line 37)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
examples/supervised_training/supervised_classification.rs (line 37)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
examples/supervised_training/supervised_regression.rs (line 36)
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23    let mut total_sq = 0.0f32;
24    for p in parameters.iter() {
25        if let Some(g) = p.grad_owned() {
26            for &v in g.data() {
27                total_sq += v * v;
28            }
29        }
30    }
31    let norm = total_sq.sqrt();
32    if norm > max_norm {
33        let scale = max_norm / (norm + eps);
34        for p in parameters.iter_mut() {
35            if let Some(g) = p.grad_owned() {
36                p.set_grad(g.mul_scalar(scale));
37            }
38        }
39    }
40}
Source

pub fn zero_grad(&mut self)

Clear accumulated gradients for this tensor

This method is used by optimizers to zero gradients before each backward pass. It clears any accumulated gradients, allowing for fresh gradient computation.

§Examples
use train_station::Tensor;

let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
tensor.set_grad(Tensor::ones(vec![2, 3]));
assert!(tensor.grad().is_some());
tensor.zero_grad();
assert!(tensor.grad().is_none());
Source

pub fn is_contiguous(&self) -> bool

Checks if the tensor data is stored contiguously in memory

§Returns

true if the tensor data is contiguous, enabling optimized SIMD operations

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
assert!(tensor.is_contiguous());
Examples found in repository
examples/getting_started/tensor_basics.rs (line 187)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
Source

pub fn strides(&self) -> &[usize]

Gets the memory strides for all dimensions

§Returns

Reference to the stride vector for efficient memory access calculations

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
assert_eq!(tensor.strides(), &[12, 4, 1]);
Source

pub fn stride(&self, dim: usize) -> usize

Gets the memory stride for a specific dimension

§Arguments
  • dim - The dimension index
§Returns

The memory stride for the given dimension

§Panics

Panics if dim is out of bounds
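For a row-major [2, 3, 4] tensor the strides are [12, 4, 1] (see strides above); the sketch below assumes exactly that layout.

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
assert_eq!(tensor.stride(0), 12);
assert_eq!(tensor.stride(1), 4);
assert_eq!(tensor.stride(2), 1);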

Source

pub fn memory_offset(&self, indices: &[usize]) -> usize

Calculates the linear memory offset for given multi-dimensional indices

§Arguments
  • indices - Slice containing one index per dimension
§Returns

Linear memory offset for direct memory access

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
let offset = tensor.memory_offset(&[1, 2, 3]);
// offset = 1*12 + 2*4 + 3*1 = 23
Examples found in repository?
examples/neural_networks/basic_transformer.rs (line 95)
76    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77        let (b, _s, e) = Self::triple(src);
78        let mut memory = src.clone();
79        for enc in &self.encoders {
80            memory = enc.forward(&memory, None);
81        }
82
83        let mut out_seq: Vec<Tensor> = Vec::new();
84        // Start token: zeros
85        let mut current = Tensor::zeros(vec![b, 1, e]);
86        for _step in 0..max_steps {
87            // Build causal mask for length t
88            let t = current.shape().dims()[1];
89            let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90            // Upper triangle as false -> masked for all batches and heads
91            for bb in 0..b {
92                for hh in 0..self.num_heads {
93                    for i in 0..t {
94                        for j in (i + 1)..t {
95                            let offset = causal.memory_offset(&[bb, hh, i, j]);
96                            let data = causal.data_mut();
97                            data[offset] = 0.0;
98                        }
99                    }
100                }
101            }
102            let mut step_out = current.clone();
103            for dec in &self.decoders {
104                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105            }
106            // (Toy) append placeholder token; real models would project last token
107            out_seq.push(step_out.clone());
108            // Append a zero token to grow sequence by 1 for next causal computation
109            current = Tensor::zeros(vec![b, t + 1, e]);
110        }
111        // Simple return of final sequence placeholder
112        current
113    }
114
115    /// Non auto-regressive inference: single forward pass
116    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117        let (b, _s, e) = Self::triple(src);
118        let mut memory = src.clone();
119        for enc in &self.encoders {
120            memory = enc.forward(&memory, None);
121        }
122        let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123        let mut out = tgt.clone();
124        for dec in &self.decoders {
125            out = dec.forward(&out, &memory, None, None);
126        }
127        out
128    }
129
130    /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131    fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132        let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133        for bb in 0..batch {
134            for hh in 0..heads {
135                for i in 0..t {
136                    for j in (i + 1)..t {
137                        let offset = mask.memory_offset(&[bb, hh, i, j]);
138                        let data = mask.data_mut();
139                        data[offset] = 0.0;
140                    }
141                }
142            }
143        }
144        mask
145    }
examples/neural_networks/multi_head_attention.rs (line 190)
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166    println!("=== Multi-Head Attention Example ===");
167
168    let batch = 2usize;
169    let src_len = 5usize;
170    let tgt_len = 4usize;
171    let embed = 16usize;
172    let heads = 4usize;
173
174    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal mask for target self-attention shape [b, h, tq, tk]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184    if src_len >= tgt_len {
185        // set upper triangle to a large negative value for head 0
186        for i in 0..tgt_len {
187            for j in (i + 1)..src_len {
188                let idx = [0usize, 0usize, i, j];
189                // Quick set via data_mut using a slice view
190                let offset = mask.memory_offset(&idx);
191                let data = mask.data_mut();
192                data[offset] = -1e9;
193            }
194        }
195    }
196
197    let out = mha.forward(&query, &key, &value, Some(&mask));
198    println!("Output shape: {:?}", out.shape().dims());
199
200    // Tiny training step to confirm gradients are wired
201    let mut optimizer = Adam::with_learning_rate(0.01);
202    let mut params = mha.parameters();
203    for p in &params {
204        optimizer.add_parameter(p);
205    }
206
207    // Dummy loss = mean of output
208    let mut loss = out.mean();
209    loss.backward(None);
210    optimizer.step(&mut params);
211    optimizer.zero_grad(&mut params);
212
213    println!("Loss: {:.6}", loss.value());
214    println!("=== Done ===");
215    Ok(())
216}
Source

pub fn memory_alignment(&self) -> usize

Gets the memory alignment of the tensor data

§Returns

The memory alignment in bytes (typically 32 for SIMD optimization)
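A minimal sketch; the exact value is size-dependent, so only a power-of-two alignment of at least align_of::<f32>() is assumed here.

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![4, 4]);
let align = tensor.memory_alignment();
// Typically 32 bytes for SIMD, but always a power of two
assert!(align.is_power_of_two());
assert!(align >= std::mem::align_of::<f32>());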

Source

pub fn is_broadcastable_with(&self, other: &Tensor) -> bool

Checks if this tensor is broadcastable with another tensor

§Arguments
  • other - The other tensor to check broadcasting compatibility
§Returns

true if the tensors are broadcastable according to NumPy broadcasting rules

§Examples
use train_station::Tensor;

let a = Tensor::new(vec![2, 3, 4]);
let b = Tensor::new(vec![1, 3, 4]);
assert!(a.is_broadcastable_with(&b));
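Under NumPy rules, shapes are compared from the trailing dimension backwards, and each dimension pair must be equal or contain a 1. A sketch of a pair that fails that test:

use train_station::Tensor;

let a = Tensor::new(vec![2, 3, 4]);
let c = Tensor::new(vec![2, 2, 4]);
// Middle dimensions 3 vs 2: neither equal nor 1
assert!(!a.is_broadcastable_with(&c));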
Source

pub fn memory_footprint(&self) -> usize

Gets the total number of bytes allocated for this tensor

§Returns

Total memory footprint in bytes
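A minimal sketch, assuming only that the footprint covers at least the raw element storage (alignment padding may add more).

§Examples
use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3]);
// 6 f32 elements need at least 24 bytes; padding may add more
assert!(tensor.memory_footprint() >= 6 * std::mem::size_of::<f32>());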

Source

pub fn get(&self, indices: &[usize]) -> f32

Get a single element from the tensor at the specified indices

§Arguments
  • indices - Multi-dimensional indices to access the element
§Returns

The value at the specified position

§Panics

Panics if indices are out of bounds or indices length doesn’t match tensor rank

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let value = tensor.get(&[0, 1]);
assert_eq!(value, 2.0);
Examples found in repository?
examples/getting_started/tensor_basics.rs (line 162)
156fn demonstrate_data_access() {
157    println!("\n--- Data Access ---");
158
159    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
160
161    // Access individual elements
162    println!("Element [0, 0]: {}", tensor.get(&[0, 0]));
163    println!("Element [0, 1]: {}", tensor.get(&[0, 1]));
164    println!("Element [1, 0]: {}", tensor.get(&[1, 0]));
165    println!("Element [1, 1]: {}", tensor.get(&[1, 1]));
166
167    // Access data as slice
168    let data = tensor.data();
169    println!("Data as slice: {:?}", data);
170
171    // Iterate over elements
172    println!("Elements:");
173    for (i, &value) in data.iter().enumerate() {
174        println!("  [{}]: {}", i, value);
175    }
176}
Source

pub fn set(&mut self, indices: &[usize], value: f32)

Set a single element in the tensor at the specified indices

§Arguments
  • indices - Multi-dimensional indices to set the element
  • value - The value to set
§Panics

Panics if indices are out of bounds or indices length doesn’t match tensor rank

§Examples
use train_station::Tensor;

let mut tensor = Tensor::new(vec![2, 2]);
tensor.set(&[0, 1], 42.0);
assert_eq!(tensor.get(&[0, 1]), 42.0);
Source

pub fn data(&self) -> &[f32]

Returns a safe slice of the tensor’s underlying data

Provides safe access to the tensor’s data without requiring unsafe pointer operations. This is the preferred way to access tensor data for reading values, comparisons, and other operations that don’t require direct pointer manipulation.

§Returns

A slice containing all tensor elements in row-major order

§Performance
  • Zero-Cost: Direct slice creation with no copying
  • Cache-Friendly: Sequential memory access pattern
  • Safe: No unsafe code required for basic data access
§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let data = tensor.data();

// Safe indexing and comparisons
assert_eq!(data[0], 1.0);
assert_eq!(data.len(), tensor.size());
Examples found in repository?
examples/RL_training/dqn.rs (line 111)
108    fn copy_from(&mut self, other: &Self) {
109        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110            {
111                let src = s.weight.data();
112                let dst = t.weight.data_mut();
113                dst.copy_from_slice(src);
114            }
115            {
116                let src = s.bias.data();
117                let dst = t.bias.data_mut();
118                dst.copy_from_slice(src);
119            }
120            t.weight.set_requires_grad(false);
121            t.bias.set_requires_grad(false);
122        }
123    }
124}
125
126// -------------------------------
127// Q-Network (state -> Q-values over actions)
128// -------------------------------
129
130struct QNet {
131    net: Mlp,
132}
133
134impl QNet {
135    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
136        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
137        Self { net }
138    }
139    fn forward(&self, state: &Tensor) -> Tensor {
140        self.net.forward(state, None)
141    }
142    fn parameters(&mut self) -> Vec<&mut Tensor> {
143        self.net.parameters()
144    }
145    fn set_requires_grad_all(&mut self, enable: bool) {
146        self.net.set_requires_grad_all(enable);
147    }
148}
149
150// -------------------------------
151// Discrete YardEnv (3 actions: -1, 0, +1)
152// -------------------------------
153
154struct YardEnv {
155    pos: f32,
156    vel: f32,
157    steps: usize,
158    max_steps: usize,
159    rng: SmallRng,
160}
161
162impl YardEnv {
163    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];
164
165    fn new(seed: u64) -> Self {
166        let mut env = Self {
167            pos: 0.0,
168            vel: 0.0,
169            steps: 0,
170            max_steps: 200,
171            rng: SmallRng::new(seed),
172        };
173        env.reset();
174        env
175    }
176
177    fn reset(&mut self) -> Tensor {
178        self.pos = self.rng.uniform(-0.5, 0.5);
179        self.vel = self.rng.uniform(-0.1, 0.1);
180        self.steps = 0;
181        self.state_tensor()
182    }
183
184    fn state_tensor(&self) -> Tensor {
185        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
186    }
187
188    fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
189        let a = Self::ACTIONS[action_index.min(2)];
190        self.vel += 0.1 * a - 0.01 * self.pos;
191        self.pos += self.vel;
192        self.steps += 1;
193        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
194        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
195        (self.state_tensor(), reward, done)
196    }
197}
198
199// -------------------------------
200// Replay Buffer
201// -------------------------------
202
203struct ReplayBuffer {
204    capacity: usize,
205    size: usize,
206    pos: usize,
207    state_dim: usize,
208    states: Vec<f32>,
209    actions: Vec<usize>,
210    rewards: Vec<f32>,
211    dones: Vec<f32>,
212    next_states: Vec<f32>,
213}
214
215impl ReplayBuffer {
216    fn new(capacity: usize, state_dim: usize) -> Self {
217        Self {
218            capacity,
219            size: 0,
220            pos: 0,
221            state_dim,
222            states: vec![0.0; capacity * state_dim],
223            actions: vec![0usize; capacity],
224            rewards: vec![0.0; capacity],
225            dones: vec![0.0; capacity],
226            next_states: vec![0.0; capacity * state_dim],
227        }
228    }
229
230    fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
231        let i = self.pos;
232        let so = i * self.state_dim;
233        self.states[so..so + self.state_dim].copy_from_slice(s);
234        self.actions[i] = a_idx;
235        self.rewards[i] = r;
236        self.dones[i] = d;
237        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
238        self.pos = (self.pos + 1) % self.capacity;
239        self.size = self.size.saturating_add(1).min(self.capacity);
240    }
241
242    fn can_sample(&self, batch_size: usize) -> bool {
243        self.size >= batch_size
244    }
245
246    fn sample(
247        &self,
248        batch_size: usize,
249        rng: &mut SmallRng,
250    ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
251        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
252        let mut a_idx = Vec::with_capacity(batch_size);
253        let mut r_vec = Vec::with_capacity(batch_size);
254        let mut d_vec = Vec::with_capacity(batch_size);
255        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
256        for _ in 0..batch_size {
257            let idx = rng.sample_index(self.size);
258            let so = idx * self.state_dim;
259            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
260            a_idx.push(self.actions[idx]);
261            r_vec.push(self.rewards[idx]);
262            d_vec.push(self.dones[idx]);
263            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
264        }
265        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
266        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
267        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
268        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
269        (s, a_idx, r, d, s2)
270    }
271}
272
273// -------------------------------
274// Helpers
275// -------------------------------
276
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278    let mut total_sq = 0.0f32;
279    for p in parameters.iter() {
280        if let Some(g) = p.grad_owned() {
281            for &v in g.data() {
282                total_sq += v * v;
283            }
284        }
285    }
286    let norm = total_sq.sqrt();
287    if norm > max_norm {
288        let scale = max_norm / (norm + eps);
289        for p in parameters.iter_mut() {
290            if let Some(g) = p.grad_owned() {
291                p.set_grad(g.mul_scalar(scale));
292            }
293        }
294    }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298    let mut total_sq = 0.0f32;
299    for p in parameters.iter_mut() {
300        if let Some(g) = p.grad_owned() {
301            for &v in g.data() {
302                total_sq += v * v;
303            }
304        }
305    }
306    total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310    let _ng = NoGradTrack::new();
311    let mut total_sq = 0.0f32;
312    for p in parameters.iter_mut() {
313        for &v in p.data() {
314            total_sq += v * v;
315        }
316    }
317    total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
examples/RL_training/td3.rs (line 117)
114    fn copy_from(&mut self, other: &Self) {
115        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116            {
117                let src = s.weight.data();
118                let dst = t.weight.data_mut();
119                dst.copy_from_slice(src);
120            }
121            {
122                let src = s.bias.data();
123                let dst = t.bias.data_mut();
124                dst.copy_from_slice(src);
125            }
126            t.weight.set_requires_grad(false);
127            t.bias.set_requires_grad(false);
128        }
129    }
130
131    fn soft_update_from(&mut self, source: &Self, tau: f32) {
132        let _ng = NoGradTrack::new();
133        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134            // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135            let new_w = t
136                .weight
137                .mul_scalar(1.0 - tau)
138                .add_tensor(&s.weight.mul_scalar(tau));
139            let new_b = t
140                .bias
141                .mul_scalar(1.0 - tau)
142                .add_tensor(&s.bias.mul_scalar(tau));
143            {
144                let src = new_w.data();
145                let dst = t.weight.data_mut();
146                dst.copy_from_slice(src);
147            }
148            {
149                let src = new_b.data();
150                let dst = t.bias.data_mut();
151                dst.copy_from_slice(src);
152            }
153            t.weight.set_requires_grad(false);
154            t.bias.set_requires_grad(false);
155        }
156    }
157}
158
159// -------------------------------
160// Actor and Critic
161// -------------------------------
162
163struct Actor {
164    net: Mlp,
165}
166
167impl Actor {
168    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
169        // Smaller net for faster demo: sd -> 64 -> 64 -> ad, tanh output
170        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
171        Self { net }
172    }
173    fn forward(&self, state: &Tensor) -> Tensor {
174        self.net.forward(state, Some(tanh_bounded))
175    }
176    fn parameters(&mut self) -> Vec<&mut Tensor> {
177        self.net.parameters()
178    }
179    fn set_requires_grad_all(&mut self, enable: bool) {
180        self.net.set_requires_grad_all(enable);
181    }
182}
183
184struct Critic {
185    net: Mlp,
186}
187
188impl Critic {
189    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
190        let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
191        Self { net }
192    }
193    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194        // Concatenate along feature dim (dim=1) for batched inputs
195        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198        let sa = Tensor::cat(&[s_view, a_view], 1);
199        self.net.forward(&sa, None)
200    }
201    fn parameters(&mut self) -> Vec<&mut Tensor> {
202        self.net.parameters()
203    }
204    fn set_requires_grad_all(&mut self, enable: bool) {
205        self.net.set_requires_grad_all(enable);
206    }
207}
208
209// -------------------------------
210// Simple continuous control environment: YardEnv
211// State: normalized features [pos/3, clamp(vel/1, -1..1), bias(=0)] ; Action: scalar in [-1, 1]
212// Dynamics: vel += 0.1*act - 0.01*pos; pos += vel
213// Reward: -(pos^2) - 0.1*act^2 ; Episode ends if |pos| > 3 or step >= max_steps
214// -------------------------------
215
216struct YardEnv {
217    pos: f32,
218    vel: f32,
219    steps: usize,
220    max_steps: usize,
221    rng: SmallRng,
222}
223
224impl YardEnv {
225    fn new(seed: u64) -> Self {
226        let mut env = Self {
227            pos: 0.0,
228            vel: 0.0,
229            steps: 0,
230            max_steps: 200,
231            rng: SmallRng::new(seed),
232        };
233        env.reset();
234        env
235    }
236
237    fn reset(&mut self) -> Tensor {
238        self.pos = self.rng.uniform(-0.5, 0.5);
239        self.vel = self.rng.uniform(-0.1, 0.1);
240        self.steps = 0;
241        self.state_tensor()
242    }
243
244    fn state_tensor(&self) -> Tensor {
245        // Normalize to keep critic inputs bounded:
246        // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247        // - Velocity scaled by 1.0 and clamped to [-1,1]
248        let pos_n = self.pos / 3.0;
249        let vel_n = self.vel.clamp(-1.0, 1.0);
250        Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251    }
252
253    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254        let a = action_value.clamp(-1.0, 1.0);
255        self.vel += 0.1 * a - 0.01 * self.pos;
256        self.pos += self.vel;
257        self.steps += 1;
258
259        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261        (self.state_tensor(), reward, done)
262    }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270    capacity: usize,
271    size: usize,
272    pos: usize,
273    state_dim: usize,
274    action_dim: usize,
275    states: Vec<f32>,
276    actions: Vec<f32>,
277    rewards: Vec<f32>,
278    dones: Vec<f32>,
279    next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283    fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284        Self {
285            capacity,
286            size: 0,
287            pos: 0,
288            state_dim,
289            action_dim,
290            states: vec![0.0; capacity * state_dim],
291            actions: vec![0.0; capacity * action_dim],
292            rewards: vec![0.0; capacity],
293            dones: vec![0.0; capacity],
294            next_states: vec![0.0; capacity * state_dim],
295        }
296    }
297
298    fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299        let i = self.pos;
300        let so = i * self.state_dim;
301        let ao = i * self.action_dim;
302        self.states[so..so + self.state_dim].copy_from_slice(s);
303        self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304        self.rewards[i] = r;
305        self.dones[i] = d;
306        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308        self.pos = (self.pos + 1) % self.capacity;
309        self.size = self.size.saturating_add(1).min(self.capacity);
310    }
311
312    fn can_sample(&self, batch_size: usize) -> bool {
313        self.size >= batch_size
314    }
315
316    fn sample(
317        &self,
318        batch_size: usize,
319        rng: &mut SmallRng,
320    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322        let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323        let mut r_vec = Vec::with_capacity(batch_size);
324        let mut d_vec = Vec::with_capacity(batch_size);
325        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327        for _ in 0..batch_size {
328            let idx = rng.sample_index(self.size);
329            let so = idx * self.state_dim;
330            let ao = idx * self.action_dim;
331            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332            a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333            r_vec.push(self.rewards[idx]);
334            d_vec.push(self.dones[idx]);
335            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336        }
337
338        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339        let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343        (s, a, r, d, s2)
344    }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352    // Compute global L2 norm of all grads
353    let mut total_sq = 0.0f32;
354    for p in parameters.iter() {
355        if let Some(g) = p.grad_owned() {
356            for &v in g.data() {
357                total_sq += v * v;
358            }
359        }
360    }
361    let norm = total_sq.sqrt();
362    if norm > max_norm {
363        let scale = max_norm / (norm + eps);
364        for p in parameters.iter_mut() {
365            if let Some(g) = p.grad_owned() {
366                let scaled = g.mul_scalar(scale);
367                p.set_grad(scaled);
368            }
369        }
370    }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375    let mut total_sq = 0.0f32;
376    for p in parameters.iter_mut() {
377        if let Some(g) = p.grad_owned() {
378            for &v in g.data() {
379                total_sq += v * v;
380            }
381        }
382    }
383    total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388    let _ng = NoGradTrack::new();
389    let mut total_sq = 0.0f32;
390    for p in parameters.iter_mut() {
391        for &v in p.data() {
392            total_sq += v * v;
393        }
394    }
395    total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403    println!("=== TD3 Example (YardEnv) ===");
404
405    // Environment / problem dims
406    let state_dim = 3usize;
407    let action_dim = 1usize;
408
409    // Hyperparameters (small for demo)
410    let gamma = 0.99f32;
411    let tau = 0.005f32; // Polyak
412    let policy_noise = 0.2f32; // target smoothing noise stddev
413    let exploration_noise = 0.1f32; // behavior policy noise stddev
414    let policy_delay = 2usize;
415    let batch_size = 64usize;
416    let start_steps = 500usize; // random exploration steps
417    let total_steps = 1500usize;
418    let max_grad_norm = 1.0f32;
419
420    // Models
421    let mut actor = Actor::new(state_dim, action_dim, Some(11));
422    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423    actor_targ.net.copy_from(&actor.net);
424    actor_targ.set_requires_grad_all(false);
425
426    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430    critic1_targ.net.copy_from(&critic1.net);
431    critic2_targ.net.copy_from(&critic2.net);
432    critic1_targ.set_requires_grad_all(false);
433    critic2_targ.set_requires_grad_all(false);
434
435    // Optimizers
436    let mut actor_opt = Adam::with_learning_rate(1e-3);
437    for p in actor.parameters() {
438        actor_opt.add_parameter(p);
439    }
440
441    let mut critic_opt = Adam::with_learning_rate(1e-4);
442    for p in critic1.parameters() {
443        critic_opt.add_parameter(p);
444    }
445    for p in critic2.parameters() {
446        critic_opt.add_parameter(p);
447    }
448
449    // Replay buffer and env
450    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451    let mut env = YardEnv::new(1234);
452    let mut rng = SmallRng::new(987654321);
453
454    // Reset & metric trackers
455    let mut state = env.reset(); // [1, state_dim]
456    let mut episode_return = 0.0f32;
457    let mut episode = 0usize;
458    let mut ema_return: Option<f32> = None;
459    let ema_alpha = 0.05f32; // smooth short-term
460    let mut best_return = f32::NEG_INFINITY;
461    let mut policy_updates: usize = 0;
462
463    for t in 0..total_steps {
464        // Select action
465        let action_tensor = if t < start_steps {
466            let a = rng.uniform(-1.0, 1.0);
467            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468        } else {
469            // Behavior policy with exploration noise
470            let _ng = NoGradTrack::new();
471            let det = actor.forward(&state);
472            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473            tanh_bounded(&det.add_tensor(&noise))
474        };
475        let action_value = action_tensor.data()[0];
476
477        // Environment step
478        let (next_state, reward, done) = env.step(action_value);
479        episode_return += reward;
480
481        // Store transition
482        let s_slice = state.data().to_vec();
483        let a_slice = action_tensor.data().to_vec();
484        let s2_slice = next_state.data().to_vec();
485        rb.push(
486            &s_slice,
487            &a_slice,
488            reward,
489            if done { 1.0 } else { 0.0 },
490            &s2_slice,
491        );
492
493        state = if done {
494            let st = env.reset();
495            // Metrics: update EMA and best
496            ema_return = Some(match ema_return {
497                None => episode_return,
498                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499            });
500            if episode_return > best_return {
501                best_return = episode_return;
502            }
503            println!(
504                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505                t,
506                episode,
507                episode_return,
508                ema_return.unwrap_or(episode_return),
509                best_return,
510                rb.size,
511                policy_updates
512            );
513            episode_return = 0.0;
514            episode += 1;
515            st
516        } else {
517            next_state
518        };
519
520        // Training
521        if rb.can_sample(batch_size) {
522            // Sample batch
523            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525            // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526            let target_q = {
527                let _ng = NoGradTrack::new();
528                // Target actions with smoothing noise (tanh bounds)
529                let noise =
530                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532                let q1_t = critic1_targ.forward(&s2, &a_targ);
533                let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535                // Elementwise min via data() since this path is no-grad
536                let q1d = q1_t.data();
537                let q2d = q2_t.data();
538                let mut min_vec = Vec::with_capacity(batch_size);
539                for i in 0..batch_size {
540                    let v1 = q1d[i];
541                    let v2 = q2d[i];
542                    min_vec.push(v1.min(v2));
543                }
544                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547            };
548
549            // Critic update (both critics)
550            // Zero grads in a short scope, then drop borrows before forward
551            {
552                let mut params = {
553                    let c_params = critic1.parameters();
554                    let c2_params = critic2.parameters();
555                    let mut tmp: Vec<&mut Tensor> = Vec::new();
556                    tmp.extend(c_params);
557                    tmp.extend(c2_params);
558                    tmp
559                };
560                critic_opt.zero_grad(&mut params);
561            }
562
563            // Forward current Q estimates
564            let q1 = critic1.forward(&s, &a);
565            let q2 = critic2.forward(&s, &a);
566            let diff1 = q1.sub_tensor(&target_q);
567            let diff2 = q2.sub_tensor(&target_q);
568            let mut critic_loss = diff1
569                .pow_scalar(2.0)
570                .mean()
571                .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573            // Backward
574            critic_loss.backward(None);
575
576            // Optional gradient clipping + step (only for params that received grads)
577            {
578                let params = {
579                    let c_params = critic1.parameters();
580                    let c2_params = critic2.parameters();
581                    let mut tmp: Vec<&mut Tensor> = Vec::new();
582                    tmp.extend(c_params);
583                    tmp.extend(c2_params);
584                    tmp
585                };
586                let mut with_grads: Vec<&mut Tensor> = Vec::new();
587                for p in params {
588                    if p.grad_owned().is_some() {
589                        with_grads.push(p);
590                    }
591                }
592                if !with_grads.is_empty() {
593                    // Pre-step metrics
594                    let grad_norm_before = grad_global_norm(&mut with_grads);
595                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596                    critic_opt.step(&mut with_grads);
597                    critic_opt.zero_grad(&mut with_grads);
598
599                    // Post-step metrics (param norm)
600                    let mut for_norm_params = {
601                        let c_params = critic1.parameters();
602                        let c2_params = critic2.parameters();
603                        let mut tmp: Vec<&mut Tensor> = Vec::new();
604                        tmp.extend(c_params);
605                        tmp.extend(c2_params);
606                        tmp
607                    };
608                    let param_norm = params_l2_norm(&mut for_norm_params);
609
610                    // Print compact critic metrics occasionally
611                    if t % 100 == 0 {
612                        let q1_mean = q1.mean().value();
613                        let q2_mean = q2.mean().value();
614                        let tq_mean = target_q.mean().value();
615                        println!(
616                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617                            t,
618                            critic_loss.value(),
619                            q1_mean,
620                            q2_mean,
621                            tq_mean,
622                            grad_norm_before,
623                            param_norm
624                        );
625                    }
626                }
627            }
628
629            // Delayed policy update
630            if t % policy_delay == 0 {
631                // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632                // Zero actor grads before backward
633                {
634                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
635                    actor_opt.zero_grad(&mut a_params);
636                }
637
638                let a_pred = actor.forward(&s);
639                let q_for_actor = critic1.forward(&s, &a_pred);
640                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641                actor_loss.backward(None);
642
643                {
644                    let a_params: Vec<&mut Tensor> = actor.parameters();
645                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
646                    for p in a_params {
647                        if p.grad_owned().is_some() {
648                            with_grads.push(p);
649                        }
650                    }
651                    if !with_grads.is_empty() {
652                        let grad_norm_before = grad_global_norm(&mut with_grads);
653                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654                        actor_opt.step(&mut with_grads);
655                        actor_opt.zero_grad(&mut with_grads);
656
657                        // Post-step param norm
658                        let mut for_norm_params = actor.parameters();
659                        let param_norm = params_l2_norm(&mut for_norm_params);
660
661                        policy_updates += 1;
662                        if t % 200 == 0 {
663                            println!(
664                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665                                t,
666                                actor_loss.value(),
667                                grad_norm_before,
668                                param_norm,
669                                actor_opt.learning_rate(),
670                                critic_opt.learning_rate(),
671                                policy_updates
672                            );
673                        }
674                    }
675                }
676
677                // Target updates (Polyak averaging, no grad)
678                actor_targ.net.soft_update_from(&actor.net, tau);
679                critic1_targ.net.soft_update_from(&critic1.net, tau);
680                critic2_targ.net.soft_update_from(&critic2.net, tau);
681            }
682
683            // Clear entire graphs to avoid stale accumulation across iterations
684            clear_all_graphs_known();
685        }
686    }
687
688    println!("=== TD3 training finished ===");
689    Ok(())
690}
examples/RL_training/ppo_continuous.rs (line 297)
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
examples/RL_training/ppo_discrete.rs (line 256)
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/supervised_training/supervised_bce.rs (line 27)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44    // pred: [B,1] with sigmoid; threshold at 0.5
45    let p = pred.data();
46    let t = targets.data();
47    let mut correct = 0usize;
48    for i in 0..p.len() {
49        let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50        if (yhat - t[i]).abs() < 1e-6 {
51            correct += 1;
52        }
53    }
54    correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
67
68pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69    println!("=== Supervised FFN Example (XOR) ===");
70
71    // Dataset: XOR (repeat to form a small batch)
72    let inputs: Vec<f32> = vec![
73        0.0, 0.0, // -> 0
74        0.0, 1.0, // -> 1
75        1.0, 0.0, // -> 1
76        1.0, 1.0, // -> 0
77    ];
78    let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80    // Repeat the base patterns to stabilize training
81    let repeats = 64usize; // effective batch = 4 * repeats = 256
82    let mut xs = Vec::with_capacity(repeats * inputs.len());
83    let mut ys = Vec::with_capacity(repeats * targets.len());
84    for _ in 0..repeats {
85        xs.extend_from_slice(&inputs);
86        ys.extend_from_slice(&targets);
87    }
88
89    let batch = xs.len() / 2; // two features
90    let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91    let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93    // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94    let cfg = FeedForwardConfig {
95        input_size: 2,
96        hidden_sizes: vec![32, 32],
97        output_size: 1,
98        use_bias: true,
99    };
100    let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102    // Optimizer and parameter linking
103    let mut opt = Adam::with_learning_rate(1e-3);
104    for p in net.parameters() {
105        opt.add_parameter(p);
106    }
107
108    let epochs = 1000usize;
109    let max_grad_norm = 1.0f32;
110    let mut best_loss = f32::INFINITY;
111    let mut best_acc = 0.0f32;
112
113    for e in 0..epochs {
114        // Zero grads each iteration
115        {
116            let mut params = net.parameters();
117            opt.zero_grad(&mut params);
118        }
119
120        // Forward -> logits; use numerically stable BCE-with-logits for loss
121        let logits = net.forward(&x_t);
122        let mut loss = bce_with_logits(&logits, &y_t);
123        loss.backward(None);
124
125        // Step only params with grads
126        {
127            let params = net.parameters();
128            let mut with_grads: Vec<&mut Tensor> = Vec::new();
129            for p in params {
130                if p.grad_owned().is_some() {
131                    with_grads.push(p);
132                }
133            }
134            if !with_grads.is_empty() {
135                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136                opt.step(&mut with_grads);
137                opt.zero_grad(&mut with_grads);
138            }
139        }
140
141        // Metrics (use sigmoid only for reporting accuracy)
142        let preds = logits.sigmoid();
143        let acc = accuracy(&preds, &y_t);
144        if loss.value() < best_loss {
145            best_loss = loss.value();
146        }
147        if acc > best_acc {
148            best_acc = acc;
149        }
150        if e % 10 == 0 || e + 1 == epochs {
151            println!(
152                "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153                e,
154                loss.value(),
155                acc,
156                best_loss,
157                best_acc
158            );
159        }
160
161        // Clear graphs to avoid stale accumulation across epochs
162        clear_all_graphs_known();
163    }
164
165    // Quick sanity check predictions
166    let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167    let out = net.forward(&test).sigmoid();
168    println!("predictions (approx): {:?}", out.data());
169
170    println!("=== Supervised training finished ===");
171    Ok(())
172}
examples/supervised_training/supervised_classification.rs (line 27)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
60
61fn accuracy_from_logits(
62    logits: &Tensor,
63    labels: &[usize],
64    batch: usize,
65    num_classes: usize,
66) -> f32 {
67    let row = logits.data();
68    let mut correct = 0usize;
69    for (i, &label) in labels.iter().enumerate().take(batch) {
70        let base = i * num_classes;
71        let mut best_j = 0usize;
72        let mut best_v = row[base];
73        for j in 1..num_classes {
74            let v = row[base + j];
75            if v > best_v {
76                best_v = v;
77                best_j = j;
78            }
79        }
80        if best_j == label {
81            correct += 1;
82        }
83    }
84    correct as f32 / batch as f32
85}
86
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88    println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90    // Synthetic 2D inputs, 3 classes with linear-ish separations
91    let n = 1200usize;
92    let classes = 3usize;
93    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94    let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96    // Simple RNG
97    let mut state: u64 = 424242;
98    let mut rand_f32 = || {
99        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100        ((state >> 16) as u32) as f32 / (u32::MAX as f32) // keep 32 bits so the result stays in [0, 1)
101    };
102
103    for _ in 0..n {
104        let x1 = rand_f32() * 4.0 - 2.0;
105        let x2 = rand_f32() * 4.0 - 2.0;
106        // Class by quadrant-ish rule with noise
107        let mut c = if x1 + 0.5 * x2 > 0.5 {
108            0
109        } else if x1 - x2 < -0.5 {
110            1
111        } else {
112            2
113        };
114        if rand_f32() < 0.05 {
115            c = (c + 1) % classes;
116        }
117        xs.push(x1);
118        xs.push(x2);
119        ys.push(c);
120    }
121
122    // Normalize inputs per-feature to [-1, 1]
123    let mut min1 = f32::INFINITY;
124    let mut max1 = f32::NEG_INFINITY;
125    let mut min2 = f32::INFINITY;
126    let mut max2 = f32::NEG_INFINITY;
127    for i in (0..xs.len()).step_by(2) {
128        let a = xs[i];
129        let b = xs[i + 1];
130        if a < min1 {
131            min1 = a;
132        }
133        if a > max1 {
134            max1 = a;
135        }
136        if b < min2 {
137            min2 = b;
138        }
139        if b > max2 {
140            max2 = b;
141        }
142    }
143    let rng1 = (max1 - min1).max(1e-8);
144    let rng2 = (max2 - min2).max(1e-8);
145    for i in (0..xs.len()).step_by(2) {
146        let a = xs[i];
147        let b = xs[i + 1];
148        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150    }
151
152    // Train/Val split (80/20)
153    let n_train = (n as f32 * 0.8) as usize;
154    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155    let y_train = ys[..n_train].to_vec();
156    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157    let y_val = ys[n_train..].to_vec();
158
159    // Model: 2 -> 64 -> 64 -> 3 (logits)
160    let cfg = FeedForwardConfig {
161        input_size: 2,
162        hidden_sizes: vec![64, 64],
163        output_size: classes,
164        use_bias: true,
165    };
166    let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168    // Optimizer
169    let mut opt = Adam::with_learning_rate(1e-3);
170    for p in net.parameters() {
171        opt.add_parameter(p);
172    }
173
174    let epochs = 300usize;
175    let max_grad_norm = 1.0f32;
176    let mut best_val_acc = 0.0f32;
177    let mut best_val_loss = f32::INFINITY;
178
179    for e in 0..epochs {
180        // Zero grads
181        {
182            let mut params = net.parameters();
183            opt.zero_grad(&mut params);
184        }
185
186        // Forward logits
187        let logits = net.forward(&x_train);
188        let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189        loss.backward(None);
190
191        // Step clipped
192        {
193            let params = net.parameters();
194            let mut with_grads: Vec<&mut Tensor> = Vec::new();
195            for p in params {
196                if p.grad_owned().is_some() {
197                    with_grads.push(p);
198                }
199            }
200            if !with_grads.is_empty() {
201                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202                opt.step(&mut with_grads);
203                opt.zero_grad(&mut with_grads);
204            }
205        }
206
207        // Metrics
208        let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209        let val_logits = net.forward(&x_val);
210        let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211        let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212        if val_acc > best_val_acc {
213            best_val_acc = val_acc;
214        }
215        if val_loss < best_val_loss {
216            best_val_loss = val_loss;
217        }
218
219        if e % 10 == 0 || e + 1 == epochs {
220            println!(
221                "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222                e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223            );
224        }
225
226        clear_all_graphs_known();
227    }
228
229    // Quick sample preds via softmax
230    let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231    let sm = net.forward(&samples).softmax(1);
232    println!("sample class probs: {:?}", sm.data());
233
234    println!("=== Supervised classification finished ===");
235    Ok(())
236}
Source

pub fn data_mut(&mut self) -> &mut [f32]

Returns a mutable slice of the tensor’s underlying data

Provides safe mutable access to the tensor’s data without requiring unsafe pointer operations. Use this for in-place modifications of tensor values.

§Returns

A mutable slice containing all tensor elements in row-major order

§Performance
  • Zero-Cost: Direct slice creation with no copying
  • Cache-Friendly: Sequential memory access pattern
  • Safe: No unsafe code required for basic data modification
§Examples
use train_station::Tensor;

let mut tensor = Tensor::new(vec![2, 2]);
let data = tensor.data_mut();

// Safe indexing for modification
data[0] = 1.0;
data[1] = 2.0;

assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[0, 1]), 2.0);
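
Because the returned slice covers the whole buffer in row-major order, bulk updates can use standard slice operations such as iter_mut and copy_from_slice, as the repository examples below do. A minimal sketch using only APIs shown on this page:

use train_station::Tensor;

// Scale every element in place without allocating a new tensor
let mut tensor = Tensor::ones(vec![2, 3]);
for value in tensor.data_mut().iter_mut() {
    *value *= 0.5;
}
assert_eq!(tensor.get(&[0, 0]), 0.5);

// Bulk-copy raw values from another tensor of the same size
let source = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
tensor.data_mut().copy_from_slice(source.data());
assert_eq!(tensor.get(&[1, 2]), 6.0);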
Examples found in repository
examples/RL_training/dqn.rs (line 112)
108    fn copy_from(&mut self, other: &Self) {
109        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110            {
111                let src = s.weight.data();
112                let dst = t.weight.data_mut();
113                dst.copy_from_slice(src);
114            }
115            {
116                let src = s.bias.data();
117                let dst = t.bias.data_mut();
118                dst.copy_from_slice(src);
119            }
120            t.weight.set_requires_grad(false);
121            t.bias.set_requires_grad(false);
122        }
123    }
More examples
examples/RL_training/td3.rs (line 118)
114    fn copy_from(&mut self, other: &Self) {
115        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116            {
117                let src = s.weight.data();
118                let dst = t.weight.data_mut();
119                dst.copy_from_slice(src);
120            }
121            {
122                let src = s.bias.data();
123                let dst = t.bias.data_mut();
124                dst.copy_from_slice(src);
125            }
126            t.weight.set_requires_grad(false);
127            t.bias.set_requires_grad(false);
128        }
129    }
130
131    fn soft_update_from(&mut self, source: &Self, tau: f32) {
132        let _ng = NoGradTrack::new();
133        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134            // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135            let new_w = t
136                .weight
137                .mul_scalar(1.0 - tau)
138                .add_tensor(&s.weight.mul_scalar(tau));
139            let new_b = t
140                .bias
141                .mul_scalar(1.0 - tau)
142                .add_tensor(&s.bias.mul_scalar(tau));
143            {
144                let src = new_w.data();
145                let dst = t.weight.data_mut();
146                dst.copy_from_slice(src);
147            }
148            {
149                let src = new_b.data();
150                let dst = t.bias.data_mut();
151                dst.copy_from_slice(src);
152            }
153            t.weight.set_requires_grad(false);
154            t.bias.set_requires_grad(false);
155        }
156    }
examples/getting_started/tensor_basics.rs (line 72)
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}
examples/neural_networks/basic_transformer.rs (line 96)
76    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77        let (b, _s, e) = Self::triple(src);
78        let mut memory = src.clone();
79        for enc in &self.encoders {
80            memory = enc.forward(&memory, None);
81        }
82
83        let mut out_seq: Vec<Tensor> = Vec::new();
84        // Start token: zeros
85        let mut current = Tensor::zeros(vec![b, 1, e]);
86        for _step in 0..max_steps {
87            // Build causal mask for length t
88            let t = current.shape().dims()[1];
89            let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90            // Upper triangle as false -> masked for all batches and heads
91            for bb in 0..b {
92                for hh in 0..self.num_heads {
93                    for i in 0..t {
94                        for j in (i + 1)..t {
95                            let offset = causal.memory_offset(&[bb, hh, i, j]);
96                            let data = causal.data_mut();
97                            data[offset] = 0.0;
98                        }
99                    }
100                }
101            }
102            let mut step_out = current.clone();
103            for dec in &self.decoders {
104                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105            }
106            // (Toy) append placeholder token; real models would project last token
107            out_seq.push(step_out.clone());
108            // Append a zero token to grow sequence by 1 for next causal computation
109            current = Tensor::zeros(vec![b, t + 1, e]);
110        }
111        // Simple return of final sequence placeholder
112        current
113    }
114
115    /// Non auto-regressive inference: single forward pass
116    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117        let (b, _s, e) = Self::triple(src);
118        let mut memory = src.clone();
119        for enc in &self.encoders {
120            memory = enc.forward(&memory, None);
121        }
122        let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123        let mut out = tgt.clone();
124        for dec in &self.decoders {
125            out = dec.forward(&out, &memory, None, None);
126        }
127        out
128    }
129
130    /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131    fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132        let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133        for bb in 0..batch {
134            for hh in 0..heads {
135                for i in 0..t {
136                    for j in (i + 1)..t {
137                        let offset = mask.memory_offset(&[bb, hh, i, j]);
138                        let data = mask.data_mut();
139                        data[offset] = 0.0;
140                    }
141                }
142            }
143        }
144        mask
145    }
examples/neural_networks/multi_head_attention.rs (line 191)
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166    println!("=== Multi-Head Attention Example ===");
167
168    let batch = 2usize;
169    let src_len = 5usize;
170    let tgt_len = 4usize;
171    let embed = 16usize;
172    let heads = 4usize;
173
174    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal mask for target self-attention shape [b, h, tq, tk]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184    if src_len >= tgt_len {
185        // set upper triangle to a large negative value for head 0
186        for i in 0..tgt_len {
187            for j in (i + 1)..src_len {
188                let idx = [0usize, 0usize, i, j];
189                // Quick set via data_mut using a slice view
190                let offset = mask.memory_offset(&idx);
191                let data = mask.data_mut();
192                data[offset] = -1e9;
193            }
194        }
195    }
196
197    let out = mha.forward(&query, &key, &value, Some(&mask));
198    println!("Output shape: {:?}", out.shape().dims());
199
200    // Tiny training step to confirm gradients are wired
201    let mut optimizer = Adam::with_learning_rate(0.01);
202    let mut params = mha.parameters();
203    for p in &params {
204        optimizer.add_parameter(p);
205    }
206
207    // Dummy loss = mean of output
208    let mut loss = out.mean();
209    loss.backward(None);
210    optimizer.step(&mut params);
211    optimizer.zero_grad(&mut params);
212
213    println!("Loss: {:.6}", loss.value());
214    println!("=== Done ===");
215    Ok(())
216}
Source

pub fn value(&self) -> f32

Extract scalar value from single-element tensor

This method provides a convenient way to extract the scalar value from tensors that contain exactly one element. This is commonly used with element iterator results and scalar tensor operations.

§Returns

The scalar value contained in this tensor

§Panics

Panics if the tensor does not contain exactly one element

§Examples
use train_station::Tensor;

// Single-element tensor
let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
assert_eq!(scalar.value(), 42.0);
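
Reductions such as sum() and mean() produce single-element tensors, so value() is the usual bridge from tensor results back to plain f32; the training loops below use it to log loss values. A minimal sketch:

use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

// Reductions yield single-element tensors; value() unwraps them to f32
let mean = tensor.mean().value();
assert_eq!(mean, 2.5);

// Calling value() on the 4-element tensor itself would panic:
// let _ = tensor.value();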
Examples found in repository
examples/supervised_training/supervised_regression.rs (line 47)
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
More examples
examples/neural_networks/basic_encoder.rs (line 98)
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}
examples/getting_started/tensor_basics.rs (line 192)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
examples/neural_networks/basic_decoder.rs (line 109)
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85    println!("=== Basic Decoder Example ===");
86
87    let batch = 2usize;
88    let src = 7usize;
89    let tgt = 5usize;
90    let embed = 32usize;
91    let heads = 4usize;
92
93    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96    let mut dec = DecoderBlock::new(embed, heads, Some(456));
97    let out = dec.forward(&tgt_in, &memory, None, None);
98    println!("Output shape: {:?}", out.shape().dims());
99
100    let mut opt = Adam::with_learning_rate(0.01);
101    let mut params = dec.parameters();
102    for p in &params {
103        opt.add_parameter(p);
104    }
105    let mut loss = out.mean();
106    loss.backward(None);
107    opt.step(&mut params);
108    opt.zero_grad(&mut params);
109    println!("Loss: {:.6}", loss.value());
110    println!("=== Done ===");
111    Ok(())
112}
examples/neural_networks/basic_transformer.rs (line 169)
148    pub fn train_non_autoregressive_steps(
149        &mut self,
150        src: &Tensor,
151        tgt: &Tensor,
152        steps: usize,
153        lr: f32,
154    ) {
155        let mut opt = Adam::with_learning_rate(lr);
156        {
157            let params_once = self.parameters();
158            for p in &params_once {
159                opt.add_parameter(p);
160            }
161        }
162        for step in 0..steps {
163            // forward + backward scope (immutable borrow)
164            {
165                let pred = self.forward(src, tgt);
166                let diff = pred.sub_tensor(tgt);
167                let mut loss = diff.pow_scalar(2.0).mean();
168                if step == 0 || step + 1 == steps {
169                    println!("NAR train step {}: loss={:.6}", step, loss.value());
170                }
171                loss.backward(None);
172            }
173            // step + zero_grad scope (mutable borrow)
174            let mut params_step = self.parameters();
175            opt.step(&mut params_step);
176            opt.zero_grad(&mut params_step);
177        }
178    }
179
180    /// Auto-regressive training (teacher forcing): predict next token with causal mask
181    pub fn train_autoregressive_steps(
182        &mut self,
183        src: &Tensor,
184        tgt: &Tensor,
185        steps: usize,
186        lr: f32,
187    ) {
188        let mut opt = Adam::with_learning_rate(lr);
189        {
190            let params_once = self.parameters();
191            for p in &params_once {
192                opt.add_parameter(p);
193            }
194        }
195
196        // Build encoder memory once (static dataset demo)
197        let mut memory = src.clone();
198        for enc in &self.encoders {
199            memory = enc.forward(&memory, None);
200        }
201
202        let (b, t, _e) = Self::triple(tgt);
203        // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204        let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205        for step in 0..steps {
206            // forward + backward scope
207            {
208                let mut out = tgt.clone();
209                for dec in &self.decoders {
210                    out = dec.forward(&out, &memory, Some(&causal), None);
211                }
212                let diff = out.sub_tensor(tgt);
213                let mut loss = diff.pow_scalar(2.0).mean();
214                if step == 0 || step + 1 == steps {
215                    println!("AR  train step {}: loss={:.6}", step, loss.value());
216                }
217                loss.backward(None);
218            }
219            let mut params_step = self.parameters();
220            opt.step(&mut params_step);
221            opt.zero_grad(&mut params_step);
222        }
223    }
224
225    fn triple(t: &Tensor) -> (usize, usize, usize) {
226        let d = t.shape().dims();
227        (d[0], d[1], d[2])
228    }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232    println!("=== Basic Transformer Example ===");
233
234    let batch = 2usize;
235    let src_len = 8usize;
236    let tgt_len = 6usize;
237    let embed = 32usize;
238    let heads = 4usize;
239    let layers = 2usize;
240
241    let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242    let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244    let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245    let out = trf.forward(&src, &tgt);
246    println!("Output shape: {:?}", out.shape().dims());
247
248    // Quick optimization step
249    let mut opt = Adam::with_learning_rate(0.005);
250    let mut params = trf.parameters();
251    for p in &params {
252        opt.add_parameter(p);
253    }
254    let mut loss = out.mean();
255    loss.backward(None);
256    opt.step(&mut params);
257    opt.zero_grad(&mut params);
258    println!("Loss: {:.6}", loss.value());
259
260    // Demo: non auto-regressive inference (single pass)
261    let nar = trf.infer_non_autoregressive(&src, tgt_len);
262    println!("NAR output shape: {:?}", nar.shape().dims());
263
264    // Demo: auto-regressive inference (toy)
265    let ar = trf.infer_autoregressive(&src, 3);
266    println!("AR output shape: {:?}", ar.shape().dims());
267
268    // NAR training demo
269    let nar_tgt = tgt.clone();
270    trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272    // AR training demo (teacher-forced)
273    let ar_tgt = tgt.clone();
274    trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275    println!("=== Done ===");
276    Ok(())
277}
examples/iterators/element_iteration.rs (line 106)
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94    println!("\n--- Basic Element Iteration ---");
95
96    // Create a simple tensor for demonstration
97    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98    println!("Original tensor: {:?}", tensor.data());
99
100    // Basic iteration with for loop
101    println!("\nBasic iteration with for loop:");
102    for (i, element) in tensor.iter().enumerate() {
103        println!(
104            "  Element {}: value = {:.1}, shape = {:?}",
105            i,
106            element.value(),
107            element.shape().dims()
108        );
109    }
110
111    // Element-wise transformation
112    println!("\nElement-wise transformation (2x + 1):");
113    let transformed: Tensor = tensor
114        .iter()
115        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116        .collect();
117    println!("  Result: {:?}", transformed.data());
118
119    // Filtering elements
120    println!("\nFiltering elements (values > 3.0):");
121    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122    println!("  Filtered: {:?}", filtered.data());
123
124    Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132    println!("\n--- Standard Iterator Methods ---");
133
134    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136    // Using map for transformations
137    println!("\nMap transformation (square each element):");
138    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139    println!("  Squared: {:?}", squared.data());
140
141    // Using enumerate for indexed operations
142    println!("\nEnumerate with indexed operations:");
143    let indexed: Tensor = tensor
144        .iter()
145        .enumerate()
146        .map(|(i, elem)| elem.add_scalar(i as f32))
147        .collect();
148    println!("  Indexed: {:?}", indexed.data());
149
150    // Using fold for reduction
151    println!("\nFold for sum calculation:");
152    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153    println!("  Sum: {:.1}", sum);
154
155    // Using find for element search
156    println!("\nFind specific element:");
157    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158        println!("  Found element with value 3.0: {:.1}", found.value());
159    }
160
161    // Using any/all for condition checking
162    println!("\nCondition checking:");
163    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165    println!("  All positive: {}", all_positive);
166    println!("  Any > 4.0: {}", any_large);
167
168    Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}
Source

pub fn view(&self, new_shape: Vec<i32>) -> Tensor

Create a view with a new shape (requires contiguous memory)

Behaves like PyTorch view: tensor must be contiguous and the total number of elements must remain the same. Supports -1 inference for one dimension.

§Arguments
  • new_shape - New shape for the tensor (can contain -1 for inference)
§Returns

A tensor viewing the same data with a new shape

§Examples
use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let y = x.view(vec![2, 2]);
assert_eq!(y.shape().dims(), vec![2, 2]);
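
The -1 inference mentioned above lets one dimension be computed from the total element count. A short sketch of the inference behavior:

use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();

// -1 is inferred from the element count: 6 / 2 = 3
let y = x.view(vec![2, -1]);
assert_eq!(y.shape().dims(), vec![2, 3]);

// The view shares memory with the source, so no data is copied
assert_eq!(y.get(&[1, 2]), 6.0);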
Examples found in repository
examples/RL_training/ppo_continuous.rs (line 112)
106    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107        // Returns (mean [B, A], log_std [A])
108        let mean = self.net.forward(state);
109        (
110            mean,
111            self.log_std
112                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113        )
114    }
More examples
examples/RL_training/td3.rs (line 196)
193    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194        // Concatenate along feature dim (dim=1) for batched inputs
195        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198        let sa = Tensor::cat(&[s_view, a_view], 1);
199        self.net.forward(&sa, None)
200    }
examples/neural_networks/basic_encoder.rs (line 59)
53    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54        let attn = self.mha.forward(input, input, input, attn_mask);
55        let res1 = attn.add_tensor(input);
56
57        // Feed-forward network with ReLU and residual
58        let (b, t, e) = Self::triple(input);
59        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60        let hidden = self.ffn_in.forward(&x2d).relu();
61        let out2d = self.ffn_out.forward(&hidden);
62        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63        out.add_tensor(&res1)
64    }
examples/neural_networks/basic_decoder.rs (line 70)
56    pub fn forward(
57        &self,
58        tgt: &Tensor,
59        memory: &Tensor,
60        causal_mask: Option<&Tensor>,
61        cross_mask: Option<&Tensor>,
62    ) -> Tensor {
63        let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64        let res1 = self_attn.add_tensor(tgt);
65
66        let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67        let res2 = cross.add_tensor(&res1);
68
69        let (b, t, e) = Self::triple(tgt);
70        let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71        let hidden = self.ffn_in.forward(&x2d).relu();
72        let out2d = self.ffn_out.forward(&hidden);
73        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74        out.add_tensor(&res2)
75    }
examples/getting_started/tensor_basics.rs (line 131)
120fn demonstrate_shape_operations() {
121    println!("\n--- Shape Operations ---");
122
123    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
124    println!(
125        "Original: shape {:?}, data: {:?}",
126        tensor.shape().dims(),
127        tensor.data()
128    );
129
130    // Reshape (view)
131    let reshaped = tensor.view(vec![3, 2]);
132    println!(
133        "Reshaped to [3, 2]: shape {:?}, data: {:?}",
134        reshaped.shape().dims(),
135        reshaped.data()
136    );
137
138    // Create a different shaped tensor for demonstration
139    let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
140    println!(
141        "2D tensor: shape {:?}, data: {:?}",
142        tensor_2d.shape().dims(),
143        tensor_2d.data()
144    );
145
146    // Create a 1D tensor
147    let tensor_1d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
148    println!(
149        "1D tensor: shape {:?}, data: {:?}",
150        tensor_1d.shape().dims(),
151        tensor_1d.data()
152    );
153}
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 359)
330fn demonstrate_single_vs_batch_inference() {
331    println!("\n--- Single vs Batch Inference ---");
332
333    let layer = LinearLayer::new(4, 3, Some(46));
334
335    // Single inference
336    println!("Single inference:");
337    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338    let single_output = layer.forward_no_grad(&single_input);
339    println!("  Input shape: {:?}", single_input.shape().dims());
340    println!("  Output shape: {:?}", single_output.shape().dims());
341    println!("  Output: {:?}", single_output.data());
342
343    // Batch inference
344    println!("Batch inference:");
345    let batch_input = Tensor::from_slice(
346        &[
347            1.0, 2.0, 3.0, 4.0, // Sample 1
348            5.0, 6.0, 7.0, 8.0, // Sample 2
349            9.0, 10.0, 11.0, 12.0, // Sample 3
350        ],
351        vec![3, 4],
352    )
353    .unwrap();
354    let batch_output = layer.forward_no_grad(&batch_input);
355    println!("  Input shape: {:?}", batch_input.shape().dims());
356    println!("  Output shape: {:?}", batch_output.shape().dims());
357
358    // Verify batch consistency - first sample should match single inference
359    let _first_batch_sample = batch_output.view(vec![3, 3]); // Reshape to access first sample
360    let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361    let single_sample_data = single_output.data();
362
363    println!("Consistency check:");
364    println!("  Single output: {:?}", single_sample_data);
365    println!("  First batch sample: {:?}", first_sample_data);
366    println!(
367        "  Match: {}",
368        single_sample_data
369            .iter()
370            .zip(first_sample_data.iter())
371            .all(|(a, b)| (a - b).abs() < 1e-6)
372    );
373}
Source

pub fn element_view(&self, index: usize) -> Tensor

Create an element view for the specified index

Returns a scalar tensor (shape [1]) that views a single element of the source tensor. Maintains gradient tracking.

§Arguments
  • index - Linear index of the element to view
§Returns

A scalar tensor viewing the specified element

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let element = tensor.element_view(1);
assert_eq!(element.value(), 2.0);
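A sketch of the gradient behavior (assuming, as with select documented below, that gradient flows only to the viewed element):

use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap()
    .with_requires_grad();
let mut element = tensor.element_view(1);
element.backward(None);

// Expected: only the viewed element accumulates gradient
let grad = tensor.grad_owned().expect("gradient missing");
assert_eq!(grad.get(&[1]), 1.0);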
Examples found in repository
examples/iterators/advanced_patterns.rs (line 132)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
Source

pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor

Create a slice view of the tensor

Returns a view of a contiguous or strided slice of the source tensor. Element i of the slice maps to source linear index start + i * step.

§Arguments
  • start - Starting index
  • step - Step size (1 for contiguous)
  • length - Number of elements
§Returns

A tensor viewing the specified slice

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
assert_eq!(slice.get(&[0]), 2.0);
assert_eq!(slice.get(&[1]), 4.0);
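Because tensors use row-major layout, a strided slice can also extract a matrix column. A sketch, assuming the index mapping described above:

use train_station::Tensor;

// 2x3 row-major matrix: [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Column 1 lives at linear indices 1 and 4: start = 1, step = 3 (the row stride)
let col = m.slice_view(1, 3, 2);
assert_eq!(col.get(&[0]), 2.0);
assert_eq!(col.get(&[1]), 5.0);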
Examples found in repository
examples/RL_training/ppo_discrete.rs (line 475)
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/RL_training/ppo_continuous.rs (line 489)
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
Source

pub fn allocation_owner(&self) -> Option<&Arc<Allocation>>

Get the allocation owner for this tensor

Returns the shared allocation owner if this tensor is a view, or None if this tensor owns its memory directly.

§Returns

Optional reference to the allocation owner

§Implementation Details

This method is used internally to manage memory lifecycle for tensor views. It helps determine whether a tensor shares memory with another tensor.
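A minimal sketch of the distinction, assuming view tensors share their source's allocation as described above:

use train_station::Tensor;

let owner = Tensor::ones(vec![2, 2]);
assert!(owner.allocation_owner().is_none()); // owns its memory directly

let v = owner.view(vec![4]);
assert!(v.allocation_owner().is_some()); // view shares the source allocation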

Source

pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self

Create a new tensor with uninitialized memory

This method allocates memory for a tensor without initializing it to any value. This is useful for performance-critical operations where the memory will be immediately overwritten, such as matrix multiplication results.

§Safety

The caller must ensure that all memory is written before reading from the tensor. Reading from uninitialized memory is undefined behavior.

§Arguments
  • shape_dims - The dimensions of the tensor
§Returns

A tensor with uninitialized memory

§Performance
  • No Initialization: Skips memory initialization for maximum performance
  • SIMD Ready: Properly aligned for vectorized operations
  • Memory Efficient: Uses optimized alignment strategies
§Example
use train_station::Tensor;

// Create uninitialized tensor for matmul result
let mut result = Tensor::new_uninitialized(vec![100, 100]);
// Initialize the memory before use
for value in result.data_mut() {
    *value = 0.0;
}
Source

pub fn new_uninitialized_aligned(shape_dims: Vec<usize>, alignment_bytes: usize) -> Self

Create a new uninitialized tensor with an explicit alignment request (in bytes)

This is intended for internal high-performance paths (e.g., packed GEMM panels) where stronger alignment such as 64 bytes is desired even on AVX2 systems.
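A sketch of the intended use; the same initialize-before-read safety rule as new_uninitialized applies:

use train_station::Tensor;

// Request 64-byte (cache-line) alignment for a packed panel
let mut panel = Tensor::new_uninitialized_aligned(vec![64, 64], 64);

// The memory is uninitialized: write every element before reading
for value in panel.data_mut() {
    *value = 0.0;
}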

Source

impl Tensor

Source

pub fn gather(&self, dim: usize, indices: &[usize], index_shape: &[usize]) -> Tensor

Gather values along a dimension using a tensor of indices

This operation extracts elements from the input tensor based on indices provided along a specified dimension. The output tensor has the same shape as the index tensor; each output element is read from the input at the corresponding position, with the coordinate along the specified dimension replaced by the index value.

The gather operation is commonly used in machine learning for operations like embedding lookups, attention mechanisms, and advanced indexing patterns.

§Arguments
  • dim - The dimension along which to gather values (must be < tensor rank)
  • indices - Flattened indices buffer containing the positions to gather from
  • index_shape - Shape of the indices tensor and output tensor
§Returns

A new tensor with shape index_shape containing the gathered values

§Examples
§Basic Gather Operation
use train_station::Tensor;

// Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();

// Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
let indices = [2, 0, 1, 1];
let index_shape = [2, 2];
let result = tensor.gather(1, &indices, &index_shape);

// Result shape is [2, 2]
assert_eq!(result.shape().dims(), vec![2, 2]);

// Row 0: indices [2, 0] -> [0.2, 0.0]
assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);

// Row 1: indices [1, 1] -> [0.4, 0.4]
assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
§Gather with Gradient Tracking
use train_station::Tensor;

let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
    .with_requires_grad();

let indices = [1, 1, 0, 2];
let index_shape = [2, 2];
let mut result = tensor.gather(1, &indices, &index_shape);

// Compute gradients
result.backward(None);
let grad = tensor.grad_owned().expect("gradient missing");

// Verify gradient accumulation for repeated indices
assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
§Performance Characteristics
  • Time Complexity: O(n) where n is the number of elements in the output
  • Memory Usage: Creates a new tensor with the same size as the index tensor
  • Optimization: Uses precomputed strides for efficient memory access
  • GradTrack Overhead: Minimal overhead when gradient tracking is enabled
§Implementation Details

The gather operation works through the following steps (a standalone sketch follows the list):

  1. Validating input dimensions and index bounds
  2. Creating an output tensor with the specified index shape
  3. Iterating through all positions in the output tensor
  4. Computing source offsets using the input tensor’s strides
  5. Copying values from the input tensor to the output tensor
  6. Registering the operation for gradient computation if needed
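As an illustration of steps 3 to 5, here is a standalone sketch on flat row-major buffers (gather_dim1 is a hypothetical helper, not the library's internal code), gathering along dim = 1 of a 2-D input:

// Gather along dim = 1 of a row-major [rows, cols] buffer.
// `indices` is flattened with logical shape [rows, out_cols].
fn gather_dim1(input: &[f32], cols: usize, indices: &[usize], out_cols: usize) -> Vec<f32> {
    let rows = input.len() / cols;
    let mut out = Vec::with_capacity(rows * out_cols);
    for r in 0..rows {
        for c in 0..out_cols {
            // The coordinate along dim 1 is replaced by the index value
            let src_col = indices[r * out_cols + c];
            out.push(input[r * cols + src_col]);
        }
    }
    out
}

// Matches the basic example above: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
// gathered with indices [2, 0, 1, 1] yields [0.2, 0.0, 0.4, 0.4].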
§Safety

This function performs bounds checking to ensure:

  • The specified dimension is within the tensor’s rank
  • All indices are within bounds for the specified dimension
  • The index shape is compatible with the input tensor shape
  • The indices buffer length matches the product of index shape dimensions
§Panics

This function will panic if:

  • dim is greater than or equal to the tensor’s rank
  • Any index in indices is out of bounds for the specified dimension
  • The index_shape rank doesn’t match the input tensor’s rank
  • The index_shape dimensions don’t match the input tensor (except along dim)
  • The indices length doesn’t equal the product of index_shape dimensions
Examples found in repository
examples/supervised_training/supervised_classification.rs (line 57)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/RL_training/ppo_discrete.rs (line 286)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
examples/RL_training/dqn.rs (line 470)
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
Source

impl Tensor

Source

pub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor

Select elements along a dimension using a list of indices

This operation extracts elements from the input tensor along a specified dimension using the provided indices. The output tensor has the same shape as the input except along the specified dimension, where the size becomes the length of the indices array.

The index_select operation is commonly used for extracting specific rows, columns, or slices from tensors, and is particularly useful in machine learning for operations like embedding lookups and attention mechanisms.

§Arguments
  • dim - The dimension along which to select elements (must be < tensor rank)
  • indices - Array of indices specifying which elements to select along dim
§Returns

A new tensor with the same shape as the input except along dim, where the size is indices.len()

§Examples
§Basic Index Selection
use train_station::Tensor;

// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();

// Select columns 2 and 0 from dimension 1
let result = tensor.index_select(1, &[2, 0]);

// Result shape is [2, 2] (same as input except dim 1 is now 2)
assert_eq!(result.shape().dims(), vec![2, 2]);

// Row 0: selected columns [2, 0] -> [2.0, 0.0]
assert_eq!(result.get(&[0, 0]), 2.0);
assert_eq!(result.get(&[0, 1]), 0.0);

// Row 1: selected columns [2, 0] -> [5.0, 3.0]
assert_eq!(result.get(&[1, 0]), 5.0);
assert_eq!(result.get(&[1, 1]), 3.0);
§Index Selection with Gradient Tracking
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
    .with_requires_grad();

// Select specific elements with gradient tracking enabled
let mut result = tensor.index_select(1, &[1, 2]);
result.backward(None);

// Verify gradients are computed correctly
let grad = tensor.grad_owned().expect("gradient missing");
assert_eq!(grad.shape().dims(), vec![2, 3]);
§Selecting Rows from a Matrix
use train_station::Tensor;

// Create a 3x2 matrix
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();

// Select rows 2 and 0 (dimension 0)
let result = tensor.index_select(0, &[2, 0]);

// Result shape is [2, 2]
assert_eq!(result.shape().dims(), vec![2, 2]);

// Selected rows: row 2 [5.0, 6.0], row 0 [1.0, 2.0]
assert_eq!(result.get(&[0, 0]), 5.0); // First row of result (was row 2)
assert_eq!(result.get(&[0, 1]), 6.0);
assert_eq!(result.get(&[1, 0]), 1.0); // Second row of result (was row 0)
assert_eq!(result.get(&[1, 1]), 2.0);
§Performance Characteristics
  • Time Complexity: O(n) where n is the number of elements in the output tensor
  • Memory Usage: Creates a new tensor with size equal to the output shape
  • Optimization: Uses precomputed strides for efficient memory access
  • GradTrack Overhead: Minimal overhead when gradient tracking is enabled
  • Memory Layout: Output tensor is always contiguous for optimal performance
§Implementation Details

The index_select operation works through the following steps (a simplified sketch follows the list):

  1. Validating the dimension and index bounds
  2. Computing the output shape (same as input except along dim)
  3. Creating a new contiguous output tensor
  4. Iterating through all positions in the output tensor using nested loops:
    • Outer loop: iterate over dimensions before dim
    • Middle loop: iterate over the selected indices
    • Inner loop: iterate over dimensions after dim
  5. Computing source offsets using the input tensor’s strides
  6. Copying values from input to output tensor
  7. Registering the operation for gradient computation if needed
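A simplified sketch of the row-selection case (dim = 0) on a flat row-major buffer; index_select_rows is a hypothetical helper, not the library's internal code:

// Select whole rows of a row-major [rows, cols] buffer by index.
fn index_select_rows(input: &[f32], cols: usize, indices: &[usize]) -> Vec<f32> {
    let mut out = Vec::with_capacity(indices.len() * cols);
    for &row in indices {
        // Source offset from the row stride; each selected row is copied whole
        let base = row * cols;
        out.extend_from_slice(&input[base..base + cols]);
    }
    out
}

// Matches "Selecting Rows from a Matrix" above: a 3x2 matrix
// [1.0..6.0] with indices [2, 0] yields [5.0, 6.0, 1.0, 2.0].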
§Safety

This function performs comprehensive bounds checking to ensure:

  • The specified dimension is within the tensor’s rank
  • All indices are within bounds for the specified dimension
  • Memory access is safe through proper offset calculations
§Panics

This function will panic if:

  • dim is greater than or equal to the tensor’s rank
  • Any index in indices is out of bounds for the specified dimension
§Thread Safety

This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates a new output tensor.

Source

impl Tensor

Source

pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor

Fill masked elements with a specified value

This operation returns a copy of the input tensor where elements are replaced by the specified value wherever the corresponding boolean mask is true. Elements where the mask is false retain their original values from the input tensor.

The masked_fill operation is commonly used in machine learning for operations like masking attention weights, zeroing out specific elements, and implementing dropout-like functionality.

§Arguments
  • mask - Boolean array with the same length as the number of tensor elements
  • value - The value to fill masked positions with
§Returns

A new tensor with the same shape as the input, where masked elements are replaced by value and unmasked elements retain their original values

§Examples
§Basic Masked Fill
use train_station::Tensor;

// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();

// Create a mask: [false, true, false, true, false, true]
let mask = [false, true, false, true, false, true];
let result = tensor.masked_fill(&mask, -1.0);

// Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
assert_eq!(result.shape().dims(), vec![2, 3]);
assert_eq!(result.get(&[0, 0]), 0.0);   // Unmasked
assert_eq!(result.get(&[0, 1]), -1.0);  // Masked
assert_eq!(result.get(&[0, 2]), 2.0);   // Unmasked
assert_eq!(result.get(&[1, 0]), -1.0);  // Masked
assert_eq!(result.get(&[1, 1]), 4.0);   // Unmasked
assert_eq!(result.get(&[1, 2]), -1.0);  // Masked
§Masked Fill with Gradient Tracking
use train_station::Tensor;

let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
    .with_requires_grad();

// Create a mask with some true values
let mask = [false, true, false, true, false, false];
let mut result = tensor.masked_fill(&mask, 5.0);

// Compute gradients
result.backward(None);
let grad = tensor.grad_owned().expect("gradient missing");

// Gradients should be zero where mask is true, 1 elsewhere
assert_eq!(grad.shape().dims(), vec![2, 3]);
assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6);   // Masked: no gradient
assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6);   // Unmasked: gradient flows
§Zeroing Out Specific Elements
use train_station::Tensor;

// Create a tensor with some values
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Create a mask to zero out every other element
let mask = [true, false, true, false, true, false];
let result = tensor.masked_fill(&mask, 0.0);

// Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
assert_eq!(result.get(&[0, 0]), 0.0);  // Zeroed
assert_eq!(result.get(&[0, 1]), 2.0);  // Kept
assert_eq!(result.get(&[0, 2]), 0.0);  // Zeroed
assert_eq!(result.get(&[1, 0]), 4.0);  // Kept
assert_eq!(result.get(&[1, 1]), 0.0);  // Zeroed
assert_eq!(result.get(&[1, 2]), 6.0);  // Kept
§Performance Characteristics
  • Time Complexity: O(n) where n is the number of elements in the tensor
  • Memory Usage: Creates a new tensor with the same size as the input
  • Optimization: Uses efficient stride-based iteration for non-contiguous tensors
  • GradTrack Overhead: Minimal overhead when gradient tracking is enabled
  • Memory Layout: Output tensor is always contiguous for optimal performance
§Implementation Details

The masked_fill operation works through the following steps (a minimal sketch follows the list):

  1. Validating that the mask length equals the number of tensor elements
  2. Creating a new contiguous output tensor with the same shape
  3. Iterating through all elements in logical order
  4. For each element, checking the corresponding mask value:
    • If mask is true: use the fill value
    • If mask is false: copy the original value from input tensor
  5. Computing source offsets using the input tensor’s shape for non-contiguous tensors
  6. Registering the operation for gradient computation if needed
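For a contiguous tensor, the core of the operation reduces to an element-wise choice. A minimal sketch on a flat buffer (masked_fill_flat is a hypothetical helper, not the library's internal code):

// Replace elements where the mask is true; keep the rest unchanged.
fn masked_fill_flat(input: &[f32], mask: &[bool], value: f32) -> Vec<f32> {
    assert_eq!(input.len(), mask.len(), "mask length must equal element count");
    input
        .iter()
        .zip(mask.iter())
        .map(|(&x, &m)| if m { value } else { x })
        .collect()
}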
§Safety

This function performs bounds checking to ensure:

  • The mask length equals the number of tensor elements
  • Memory access is safe through proper offset calculations
  • The operation handles both contiguous and non-contiguous tensors correctly
§Panics

This function will panic if:

  • The mask length does not equal the number of tensor elements
§Thread Safety

This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates a new output tensor.

§GradTrack Behavior

When gradient tracking is enabled:

  • Gradients do not flow through masked positions (they are zeroed)
  • Gradients flow normally through unmasked positions
  • This behavior is useful for implementing operations like dropout
Examples found in repository
examples/neural_networks/multi_head_attention.rs (line 101)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
Source

impl Tensor

Source

pub fn select(&self, dim: usize, index: usize) -> Tensor

Select a slice along a given dimension at a specific index

This operation extracts a slice from the input tensor by fixing a specific dimension at a given index. The result is a tensor with one fewer dimension than the input, containing the selected slice.

The select operation returns a view (zero-copy) when the base offset is zero, otherwise it creates a contiguous copy to ensure correctness. This operation is commonly used for extracting specific rows, columns, or slices from tensors.

§Arguments
  • dim - The dimension along which to select (must be < tensor rank)
  • index - The index within the specified dimension to select (must be < dim size)
§Returns

A tensor with the selected slice. The result has the same shape as the input except with the specified dimension removed.

§Performance

Returns a view when possible (base offset is zero) to avoid copying. With a non-zero base offset, it falls back to a contiguous copy for correctness. Gradients propagate back to the selected slice when GradTrack is enabled.

§Examples
§Basic Row Selection
use train_station::Tensor;

// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();

// Select row 1 (dimension 0, index 1)
let result = tensor.select(0, 1);

// Result shape is [3] (dimension 0 removed)
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 3.0);  // First element of row 1
assert_eq!(result.get(&[1]), 4.0);  // Second element of row 1
assert_eq!(result.get(&[2]), 5.0);  // Third element of row 1
§Column Selection
use train_station::Tensor;

// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();

// Select column 1 (dimension 1, index 1)
let result = tensor.select(1, 1);

// Result shape is [2] (dimension 1 removed)
assert_eq!(result.shape().dims(), vec![2]);
assert_eq!(result.get(&[0]), 1.0);  // Column 1, row 0
assert_eq!(result.get(&[1]), 4.0);  // Column 1, row 1
§Select with Gradient Tracking
use train_station::Tensor;

let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
    .with_requires_grad();

// Select row 1 with gradient tracking enabled
let mut result = tensor.select(0, 1);
result.backward(None);

// Verify gradients are computed correctly
let grad = tensor.grad_owned().expect("gradient missing");
assert_eq!(grad.shape().dims(), vec![2, 2]);
// Only row 1 receives gradients
assert_eq!(grad.get(&[0, 0]), 0.0);  // Row 0: no gradient
assert_eq!(grad.get(&[0, 1]), 0.0);  // Row 0: no gradient
assert_eq!(grad.get(&[1, 0]), 1.0);  // Row 1: gradient flows
assert_eq!(grad.get(&[1, 1]), 1.0);  // Row 1: gradient flows
§Performance Characteristics
  • Time Complexity: O(n) where n is the number of elements in the selected slice
  • Memory Usage: Zero-copy view when base offset is zero, otherwise creates a copy
  • Optimization: Uses efficient stride-based access for non-contiguous tensors
  • GradTrack Overhead: Minimal overhead when gradient tracking is enabled
  • Memory Layout: Result is contiguous when a copy is made, view otherwise
§Implementation Details

The select operation works by the following steps; the stride arithmetic of steps 2 through 4 is sketched after the list:

  1. Validating the dimension and index bounds
  2. Computing the new shape by removing the selected dimension
  3. Computing the new strides by removing the selected dimension’s stride
  4. Calculating the base offset for the selected slice
  5. If base offset is zero: creating a view with adjusted shape/strides
  6. If base offset is non-zero: creating a contiguous copy of the slice
  7. Registering the operation for gradient computation if needed
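A standalone illustration of the stride arithmetic in steps 2 through 4 (plain Rust, not the crate's internal code):

// Compute the output shape, strides, and base offset for select(dim, index)
// on a row-major tensor.
fn select_layout(dims: &[usize], dim: usize, index: usize) -> (Vec<usize>, Vec<usize>, usize) {
    assert!(dim < dims.len() && index < dims[dim]);

    // Row-major strides: the last dimension has stride 1.
    let mut strides = vec![1usize; dims.len()];
    for i in (0..dims.len() - 1).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }

    // Base offset of the selected slice, then drop the fixed dimension.
    let offset = index * strides[dim];
    let mut new_dims = dims.to_vec();
    let mut new_strides = strides;
    new_dims.remove(dim);
    new_strides.remove(dim);
    (new_dims, new_strides, offset)
}

// select(0, 1) on a [2, 3] tensor: shape [3], stride [1], base offset 3,
// which is why that call copies instead of returning a view.
assert_eq!(select_layout(&[2, 3], 0, 1), (vec![3], vec![1], 3));
// select(1, 1): shape [2], stride [3], base offset 1.
assert_eq!(select_layout(&[2, 3], 1, 1), (vec![2], vec![3], 1));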
§Safety

This function performs comprehensive bounds checking to ensure:

  • The tensor has non-zero rank
  • The specified dimension is within the tensor’s rank
  • The index is within bounds for the specified dimension
  • Memory access is safe through proper offset calculations
§Panics

This function will panic if:

  • The tensor has zero rank
  • dim is greater than or equal to the tensor’s rank
  • index is greater than or equal to the size of the specified dimension
§Thread Safety

This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates either a view or a new tensor.

§View vs Copy Behavior
  • View (zero-copy): When the base offset is zero, returns a view that shares the same memory as the input tensor with adjusted shape and strides
  • Copy: When the base offset is non-zero, creates a contiguous copy to ensure correctness across all operations (see the sketch below)
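A minimal sketch of the two paths on a small matrix:

use train_station::Tensor;

let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();

// Row 0 starts at base offset zero, so this can be a zero-copy view.
let row0 = tensor.select(0, 0);

// Row 1 starts at offset 3, so a contiguous copy is made instead.
let row1 = tensor.select(0, 1);

// Both paths produce the same logical result.
assert_eq!(row0.get(&[0]), 0.0);
assert_eq!(row1.get(&[0]), 3.0);
assert!(row1.is_contiguous());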
§GradTrack Behavior

When gradient tracking is enabled:

  • Gradients are scattered back to the selected slice in the input tensor
  • Other positions in the input tensor receive zero gradients
  • This behavior ensures correct gradient flow for the selected elements
Source

impl Tensor

Source

pub fn zeros(shape_dims: Vec<usize>) -> Self

Creates a new tensor filled with zeros

Convenience constructor that creates a tensor and initializes all elements to zero. Uses optimized SIMD operations for efficient zero initialization.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
§Returns

A new tensor with all elements initialized to zero

§Performance
  • Memory Allocation: Single allocation with optimized alignment
  • Initialization: SIMD-optimized zero filling for large tensors
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;

let tensor = Tensor::zeros(vec![2, 3]);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);

// Verify all elements are zero
assert_eq!(tensor.get(&[0, 0]), 0.0);
assert_eq!(tensor.get(&[1, 2]), 0.0);
Examples found in repository
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 60)
53    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55        let scale = (1.0 / input_size as f32).sqrt();
56
57        let weight = Tensor::randn(vec![input_size, output_size], seed)
58            .mul_scalar(scale)
59            .with_requires_grad();
60        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62        Self {
63            weight,
64            bias,
65            input_size,
66            output_size,
67        }
68    }
examples/getting_started/tensor_basics.rs (line 46)
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}
examples/getting_started/optimizer_basics.rs (line 52)
47fn demonstrate_basic_optimizer_setup() {
48    println!("--- Basic Optimizer Setup ---");
49
50    // Create parameters that require gradients
51    let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52    let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54    println!("Created parameters:");
55    println!(
56        "  Weight: shape {:?}, requires_grad: {}",
57        weight.shape().dims(),
58        weight.requires_grad()
59    );
60    println!(
61        "  Bias: shape {:?}, requires_grad: {}",
62        bias.shape().dims(),
63        bias.requires_grad()
64    );
65
66    // Create Adam optimizer with default configuration
67    let mut optimizer = Adam::new();
68    println!(
69        "Created Adam optimizer with learning rate: {}",
70        optimizer.learning_rate()
71    );
72
73    // Add parameters to optimizer
74    optimizer.add_parameter(&weight);
75    optimizer.add_parameter(&bias);
76    println!(
77        "Added {} parameters to optimizer",
78        optimizer.parameter_count()
79    );
80
81    // Create optimizer with custom configuration
82    let config = AdamConfig {
83        learning_rate: 0.01,
84        beta1: 0.9,
85        beta2: 0.999,
86        eps: 1e-8,
87        weight_decay: 0.0,
88        amsgrad: false,
89    };
90
91    let mut custom_optimizer = Adam::with_config(config);
92    custom_optimizer.add_parameter(&weight);
93    custom_optimizer.add_parameter(&bias);
94
95    println!(
96        "Created custom optimizer with learning rate: {}",
97        custom_optimizer.learning_rate()
98    );
99
100    // Demonstrate parameter linking
101    println!("Parameter linking completed successfully");
102}
103
104/// Demonstrate simple linear regression training
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106    println!("\n--- Linear Regression Training ---");
107
108    // Create model parameters
109    let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112    // Create optimizer
113    let mut optimizer = Adam::with_learning_rate(0.01);
114    optimizer.add_parameter(&weight);
115    optimizer.add_parameter(&bias);
116
117    // Create simple training data: y = 2*x + 1
118    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121    println!("Training data:");
122    println!("  X: {:?}", x_data.data());
123    println!("  Y: {:?}", y_true.data());
124    println!("  Target: y = 2*x + 1");
125
126    // Training loop
127    let num_epochs = 100;
128    let mut losses = Vec::new();
129
130    for epoch in 0..num_epochs {
131        // Forward pass: y_pred = x * weight + bias
132        let y_pred = x_data.matmul(&weight) + &bias;
133
134        // Compute loss: MSE
135        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137        // Backward pass
138        loss.backward(None);
139
140        // Optimizer step
141        optimizer.step(&mut [&mut weight, &mut bias]);
142        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144        losses.push(loss.value());
145
146        // Print progress every 20 epochs
147        if epoch % 20 == 0 || epoch == num_epochs - 1 {
148            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149        }
150    }
151
152    // Evaluate final model
153    let final_predictions = x_data.matmul(&weight) + &bias;
154    println!("\nFinal model evaluation:");
155    println!("  Learned weight: {:.6}", weight.value());
156    println!("  Learned bias: {:.6}", bias.value());
157    println!("  Predictions vs True:");
158
159    for i in 0..5 {
160        let x1 = x_data.data()[i];
161        let pred = final_predictions.data()[i];
162        let true_val = y_true.data()[i];
163        println!(
164            "    x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165            x1,
166            pred,
167            true_val,
168            (pred - true_val).abs()
169        );
170    }
171
172    Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177    println!("\n--- Advanced Training Patterns ---");
178
179    // Create a more complex model
180    let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181    let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183    // Create optimizer with different learning rate
184    let mut optimizer = Adam::with_learning_rate(0.005);
185    optimizer.add_parameter(&weight);
186    optimizer.add_parameter(&bias);
187
188    // Create training data: y = 2*x + [1, 3]
189    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190    let y_true = Tensor::from_slice(
191        &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192        vec![5, 2],
193    )
194    .unwrap();
195
196    println!("Advanced training with monitoring:");
197    println!("  Initial learning rate: {}", optimizer.learning_rate());
198
199    // Training loop with monitoring
200    let num_epochs = 50;
201    let mut losses = Vec::new();
202    let mut weight_norms = Vec::new();
203    let mut gradient_norms = Vec::new();
204
205    for epoch in 0..num_epochs {
206        // Forward pass
207        let y_pred = x_data.matmul(&weight) + &bias;
208        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210        // Backward pass
211        loss.backward(None);
212
213        // Compute gradient norm before optimizer step
214        let gradient_norm = weight.grad_owned().unwrap().norm();
215
216        // Optimizer step
217        optimizer.step(&mut [&mut weight, &mut bias]);
218        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220        // Learning rate scheduling: reduce every 10 epochs
221        if epoch > 0 && epoch % 10 == 0 {
222            let current_lr = optimizer.learning_rate();
223            let new_lr = current_lr * 0.5;
224            optimizer.set_learning_rate(new_lr);
225            println!(
226                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227                epoch, current_lr, new_lr
228            );
229        }
230
231        // Record metrics
232        losses.push(loss.value());
233        weight_norms.push(weight.norm().value());
234        gradient_norms.push(gradient_norm.value());
235
236        // Print detailed progress
237        if epoch % 10 == 0 || epoch == num_epochs - 1 {
238            println!(
239                "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240                epoch,
241                loss.value(),
242                weight.norm().value(),
243                gradient_norm.value()
244            );
245        }
246    }
247
248    println!("Final learning rate: {}", optimizer.learning_rate());
249
250    // Analyze training progression
251    let initial_loss = losses[0];
252    let final_loss = losses[losses.len() - 1];
253    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255    println!("\nTraining Analysis:");
256    println!("  Initial loss: {:.6}", initial_loss);
257    println!("  Final loss: {:.6}", final_loss);
258    println!("  Loss reduction: {:.1}%", loss_reduction);
259    println!("  Final weight norm: {:.6}", weight.norm().value());
260    println!("  Final bias: {:?}", bias.data());
261
262    Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267    println!("\n--- Learning Rate Scheduling ---");
268
269    // Create simple model
270    let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273    // Create optimizer with high initial learning rate
274    let mut optimizer = Adam::with_learning_rate(0.1);
275    optimizer.add_parameter(&weight);
276    optimizer.add_parameter(&bias);
277
278    // Simple data
279    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280    let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282    println!("Initial learning rate: {}", optimizer.learning_rate());
283
284    // Training loop with learning rate scheduling
285    let num_epochs = 50;
286    let mut losses = Vec::new();
287
288    for epoch in 0..num_epochs {
289        // Forward pass
290        let y_pred = x_data.matmul(&weight) + &bias;
291        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293        // Backward pass
294        loss.backward(None);
295
296        // Optimizer step
297        optimizer.step(&mut [&mut weight, &mut bias]);
298        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300        // Learning rate scheduling: reduce every 10 epochs
301        if epoch > 0 && epoch % 10 == 0 {
302            let current_lr = optimizer.learning_rate();
303            let new_lr = current_lr * 0.5;
304            optimizer.set_learning_rate(new_lr);
305            println!(
306                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307                epoch, current_lr, new_lr
308            );
309        }
310
311        losses.push(loss.value());
312
313        // Print progress
314        if epoch % 10 == 0 || epoch == num_epochs - 1 {
315            println!(
316                "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317                epoch,
318                loss.value(),
319                optimizer.learning_rate()
320            );
321        }
322    }
323
324    println!("Final learning rate: {}", optimizer.learning_rate());
325
326    Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331    println!("\n--- Training Monitoring ---");
332
333    // Create model
334    let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337    // Create optimizer
338    let mut optimizer = Adam::with_learning_rate(0.01);
339    optimizer.add_parameter(&weight);
340    optimizer.add_parameter(&bias);
341
342    // Training data
343    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346    // Training loop with comprehensive monitoring
347    let num_epochs = 30;
348    let mut losses = Vec::new();
349    let mut weight_history = Vec::new();
350    let mut bias_history = Vec::new();
351
352    for epoch in 0..num_epochs {
353        // Forward pass
354        let y_pred = x_data.matmul(&weight) + &bias;
355        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357        // Backward pass
358        loss.backward(None);
359
360        // Optimizer step
361        optimizer.step(&mut [&mut weight, &mut bias]);
362        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364        // Record history
365        losses.push(loss.value());
366        weight_history.push(weight.value());
367        bias_history.push(bias.value());
368
369        // Print detailed monitoring
370        if epoch % 5 == 0 || epoch == num_epochs - 1 {
371            println!(
372                "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373                epoch,
374                loss.value(),
375                weight.value(),
376                bias.value()
377            );
378        }
379    }
380
381    // Analyze training progression
382    println!("\nTraining Analysis:");
383    println!("  Initial loss: {:.6}", losses[0]);
384    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
385    println!(
386        "  Loss reduction: {:.1}%",
387        (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388    );
389
390    // Compute statistics
391    let loss_mean = compute_mean(&losses);
392    let loss_std = compute_std(&losses);
393    let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394    let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396    println!("  Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397    println!("  Weight change: {:.6}", weight_change);
398    println!("  Bias change: {:.6}", bias_change);
399    println!("  Final weight norm: {:.6}", weight.norm().value());
400    println!("  Final bias: {:.6}", bias.value());
401
402    Ok(())
403}
examples/neural_networks/basic_transformer.rs (line 85)
76    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77        let (b, _s, e) = Self::triple(src);
78        let mut memory = src.clone();
79        for enc in &self.encoders {
80            memory = enc.forward(&memory, None);
81        }
82
83        let mut out_seq: Vec<Tensor> = Vec::new();
84        // Start token: zeros
85        let mut current = Tensor::zeros(vec![b, 1, e]);
86        for _step in 0..max_steps {
87            // Build causal mask for length t
88            let t = current.shape().dims()[1];
89            let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90            // Upper triangle as false -> masked for all batches and heads
91            for bb in 0..b {
92                for hh in 0..self.num_heads {
93                    for i in 0..t {
94                        for j in (i + 1)..t {
95                            let offset = causal.memory_offset(&[bb, hh, i, j]);
96                            let data = causal.data_mut();
97                            data[offset] = 0.0;
98                        }
99                    }
100                }
101            }
102            let mut step_out = current.clone();
103            for dec in &self.decoders {
104                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105            }
106            // (Toy) append placeholder token; real models would project last token
107            out_seq.push(step_out.clone());
108            // Append a zero token to grow sequence by 1 for next causal computation
109            current = Tensor::zeros(vec![b, t + 1, e]);
110        }
111        // Simple return of final sequence placeholder
112        current
113    }
114
115    /// Non auto-regressive inference: single forward pass
116    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117        let (b, _s, e) = Self::triple(src);
118        let mut memory = src.clone();
119        for enc in &self.encoders {
120            memory = enc.forward(&memory, None);
121        }
122        let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123        let mut out = tgt.clone();
124        for dec in &self.decoders {
125            out = dec.forward(&out, &memory, None, None);
126        }
127        out
128    }
examples/neural_networks/multi_head_attention.rs (line 181)
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166    println!("=== Multi-Head Attention Example ===");
167
168    let batch = 2usize;
169    let src_len = 5usize;
170    let tgt_len = 4usize;
171    let embed = 16usize;
172    let heads = 4usize;
173
174    let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175    let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176    let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178    let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180    // Simple causal mask for target self-attention shape [b, h, tq, tk]
181    let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182    // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183    // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184    if src_len >= tgt_len {
185        // set upper triangle to a large negative value for head 0
186        for i in 0..tgt_len {
187            for j in (i + 1)..src_len {
188                let idx = [0usize, 0usize, i, j];
189                // Quick set via data_mut using a slice view
190                let offset = mask.memory_offset(&idx);
191                let data = mask.data_mut();
192                data[offset] = -1e9;
193            }
194        }
195    }
196
197    let out = mha.forward(&query, &key, &value, Some(&mask));
198    println!("Output shape: {:?}", out.shape().dims());
199
200    // Tiny training step to confirm gradients are wired
201    let mut optimizer = Adam::with_learning_rate(0.01);
202    let mut params = mha.parameters();
203    for p in &params {
204        optimizer.add_parameter(p);
205    }
206
207    // Dummy loss = mean of output
208    let mut loss = out.mean();
209    loss.backward(None);
210    optimizer.step(&mut params);
211    optimizer.zero_grad(&mut params);
212
213    println!("Loss: {:.6}", loss.value());
214    println!("=== Done ===");
215    Ok(())
216}
examples/optimizers/adam_configurations.rs (line 93)
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85    println!("--- Default Adam Configuration ---");
86
87    // Create a simple regression problem: y = 2*x + 1
88    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91    // Create model parameters
92    let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95    // Create Adam optimizer with default configuration
96    let mut optimizer = Adam::new();
97    optimizer.add_parameter(&weight);
98    optimizer.add_parameter(&bias);
99
100    println!("Default Adam configuration:");
101    println!("  Learning rate: {}", optimizer.learning_rate());
102    println!("  Initial weight: {:.6}", weight.value());
103    println!("  Initial bias: {:.6}", bias.value());
104
105    // Training loop
106    let num_epochs = 50;
107    let mut losses = Vec::new();
108
109    for epoch in 0..num_epochs {
110        // Forward pass
111        let y_pred = x_data.matmul(&weight) + &bias;
112        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114        // Backward pass
115        loss.backward(None);
116
117        // Optimizer step
118        optimizer.step(&mut [&mut weight, &mut bias]);
119        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121        losses.push(loss.value());
122
123        if epoch % 10 == 0 || epoch == num_epochs - 1 {
124            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125        }
126    }
127
128    // Evaluate final model
129    let _final_predictions = x_data.matmul(&weight) + &bias;
130    println!("\nFinal model:");
131    println!("  Learned weight: {:.6} (target: 2.0)", weight.value());
132    println!("  Learned bias: {:.6} (target: 1.0)", bias.value());
133    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
134
135    Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140    println!("\n--- Learning Rate Comparison ---");
141
142    let learning_rates = [0.001, 0.01, 0.1];
143    let mut results = Vec::new();
144
145    for &lr in &learning_rates {
146        println!("\nTesting learning rate: {}", lr);
147
148        let stats = train_with_config(TrainingConfig {
149            learning_rate: lr,
150            ..Default::default()
151        })?;
152
153        results.push((lr, stats.clone()));
154
155        println!("  Final loss: {:.6}", stats.final_loss);
156        println!("  Convergence epoch: {}", stats.convergence_epoch);
157    }
158
159    // Compare results
160    println!("\nLearning Rate Comparison Summary:");
161    for (lr, stats) in &results {
162        println!(
163            "  LR={:6}: Loss={:.6}, Converged@{}",
164            lr, stats.final_loss, stats.convergence_epoch
165        );
166    }
167
168    Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173    println!("\n--- Weight Decay Comparison ---");
174
175    let weight_decays = [0.0, 0.001, 0.01];
176    let mut results = Vec::new();
177
178    for &wd in &weight_decays {
179        println!("\nTesting weight decay: {}", wd);
180
181        let stats = train_with_config(TrainingConfig {
182            weight_decay: wd,
183            ..Default::default()
184        })?;
185
186        results.push((wd, stats.clone()));
187
188        println!("  Final loss: {:.6}", stats.final_loss);
189        println!("  Final weight norm: {:.6}", stats.weight_norm);
190    }
191
192    // Compare results
193    println!("\nWeight Decay Comparison Summary:");
194    for (wd, stats) in &results {
195        println!(
196            "  WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197            wd, stats.final_loss, stats.weight_norm
198        );
199    }
200
201    Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206    println!("\n--- Beta Parameter Tuning ---");
207
208    let beta_configs = [
209        (0.9, 0.999),  // Default
210        (0.8, 0.999),  // More aggressive momentum
211        (0.95, 0.999), // Less aggressive momentum
212        (0.9, 0.99),   // Faster second moment decay
213    ];
214
215    let mut results = Vec::new();
216
217    for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218        println!(
219            "\nTesting beta configuration {}: beta1={}, beta2={}",
220            i + 1,
221            beta1,
222            beta2
223        );
224
225        let config = TrainingConfig {
226            beta1: *beta1,
227            beta2: *beta2,
228            ..Default::default()
229        };
230
231        let stats = train_with_config(config)?;
232        results.push(((*beta1, *beta2), stats.clone()));
233
234        println!("  Final loss: {:.6}", stats.final_loss);
235        println!("  Convergence epoch: {}", stats.convergence_epoch);
236    }
237
238    // Compare results
239    println!("\nBeta Parameter Comparison Summary:");
240    for ((beta1, beta2), stats) in &results {
241        println!(
242            "  B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243            beta1, beta2, stats.final_loss, stats.convergence_epoch
244        );
245    }
246
247    Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252    println!("\n--- Configuration Benchmarking ---");
253
254    // Define configurations to benchmark
255    let configs = vec![
256        (
257            "Conservative",
258            TrainingConfig {
259                learning_rate: 0.001,
260                weight_decay: 0.001,
261                beta1: 0.95,
262                ..Default::default()
263            },
264        ),
265        (
266            "Balanced",
267            TrainingConfig {
268                learning_rate: 0.01,
269                weight_decay: 0.0,
270                beta1: 0.9,
271                ..Default::default()
272            },
273        ),
274        (
275            "Aggressive",
276            TrainingConfig {
277                learning_rate: 0.1,
278                weight_decay: 0.0,
279                beta1: 0.8,
280                ..Default::default()
281            },
282        ),
283    ];
284
285    let mut benchmark_results = Vec::new();
286
287    for (name, config) in configs {
288        println!("\nBenchmarking {} configuration:", name);
289
290        let start_time = std::time::Instant::now();
291        let stats = train_with_config(config.clone())?;
292        let elapsed = start_time.elapsed();
293
294        println!("  Training time: {:.2}ms", elapsed.as_millis());
295        println!("  Final loss: {:.6}", stats.final_loss);
296        println!("  Convergence: {} epochs", stats.convergence_epoch);
297
298        benchmark_results.push((name.to_string(), stats, elapsed));
299    }
300
301    // Summary
302    println!("\nBenchmarking Summary:");
303    for (name, stats, elapsed) in &benchmark_results {
304        println!(
305            "  {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306            name,
307            stats.final_loss,
308            elapsed.as_millis(),
309            stats.convergence_epoch
310        );
311    }
312
313    Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318    // Create training data
319    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322    // Create model parameters
323    let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326    // Create optimizer with custom configuration
327    let adam_config = AdamConfig {
328        learning_rate: config.learning_rate,
329        beta1: config.beta1,
330        beta2: config.beta2,
331        eps: 1e-8,
332        weight_decay: config.weight_decay,
333        amsgrad: false,
334    };
335
336    let mut optimizer = Adam::with_config(adam_config);
337    optimizer.add_parameter(&weight);
338    optimizer.add_parameter(&bias);
339
340    // Training loop
341    let mut losses = Vec::new();
342    let mut convergence_epoch = config.epochs;
343
344    for epoch in 0..config.epochs {
345        // Forward pass
346        let y_pred = x_data.matmul(&weight) + &bias;
347        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349        // Backward pass
350        loss.backward(None);
351
352        // Optimizer step
353        optimizer.step(&mut [&mut weight, &mut bias]);
354        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356        let loss_value = loss.value();
357        losses.push(loss_value);
358
359        // Check for convergence (loss < 0.01)
360        if loss_value < 0.01 && convergence_epoch == config.epochs {
361            convergence_epoch = epoch;
362        }
363    }
364
365    Ok(TrainingStats {
366        config,
367        final_loss: losses[losses.len() - 1],
368        loss_history: losses,
369        convergence_epoch,
370        weight_norm: weight.norm().value(),
371    })
372}
Source

pub fn ones(shape_dims: Vec<usize>) -> Self

Creates a new tensor filled with ones

Convenience constructor that creates a tensor and initializes all elements to one. Uses optimized SIMD operations for efficient initialization.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
§Returns

A new tensor with all elements initialized to one

§Performance
  • Memory Allocation: Single allocation with optimized alignment
  • Initialization: SIMD-optimized one filling for large tensors
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;

let tensor = Tensor::ones(vec![2, 3]);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);

// Verify all elements are one
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 2]), 1.0);
Examples found in repository
examples/getting_started/tensor_basics.rs (line 53)
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}
examples/neural_networks/basic_transformer.rs (line 89)
76    pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77        let (b, _s, e) = Self::triple(src);
78        let mut memory = src.clone();
79        for enc in &self.encoders {
80            memory = enc.forward(&memory, None);
81        }
82
83        let mut out_seq: Vec<Tensor> = Vec::new();
84        // Start token: zeros
85        let mut current = Tensor::zeros(vec![b, 1, e]);
86        for _step in 0..max_steps {
87            // Build causal mask for length t
88            let t = current.shape().dims()[1];
89            let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90            // Upper triangle as false -> masked for all batches and heads
91            for bb in 0..b {
92                for hh in 0..self.num_heads {
93                    for i in 0..t {
94                        for j in (i + 1)..t {
95                            let offset = causal.memory_offset(&[bb, hh, i, j]);
96                            let data = causal.data_mut();
97                            data[offset] = 0.0;
98                        }
99                    }
100                }
101            }
102            let mut step_out = current.clone();
103            for dec in &self.decoders {
104                step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105            }
106            // (Toy) append placeholder token; real models would project last token
107            out_seq.push(step_out.clone());
108            // Append a zero token to grow sequence by 1 for next causal computation
109            current = Tensor::zeros(vec![b, t + 1, e]);
110        }
111        // Simple return of final sequence placeholder
112        current
113    }
114
115    /// Non auto-regressive inference: single forward pass
116    pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117        let (b, _s, e) = Self::triple(src);
118        let mut memory = src.clone();
119        for enc in &self.encoders {
120            memory = enc.forward(&memory, None);
121        }
122        let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123        let mut out = tgt.clone();
124        for dec in &self.decoders {
125            out = dec.forward(&out, &memory, None, None);
126        }
127        out
128    }
129
130    /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131    fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132        let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133        for bb in 0..batch {
134            for hh in 0..heads {
135                for i in 0..t {
136                    for j in (i + 1)..t {
137                        let offset = mask.memory_offset(&[bb, hh, i, j]);
138                        let data = mask.data_mut();
139                        data[offset] = 0.0;
140                    }
141                }
142            }
143        }
144        mask
145    }
examples/RL_training/dqn.rs (line 471)
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
examples/RL_training/td3.rs (line 545)
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403    println!("=== TD3 Example (YardEnv) ===");
404
405    // Environment / problem dims
406    let state_dim = 3usize;
407    let action_dim = 1usize;
408
409    // Hyperparameters (small for demo)
410    let gamma = 0.99f32;
411    let tau = 0.005f32; // Polyak
412    let policy_noise = 0.2f32; // target smoothing noise stddev
413    let exploration_noise = 0.1f32; // behavior policy noise stddev
414    let policy_delay = 2usize;
415    let batch_size = 64usize;
416    let start_steps = 500usize; // random exploration steps
417    let total_steps = 1500usize;
418    let max_grad_norm = 1.0f32;
419
420    // Models
421    let mut actor = Actor::new(state_dim, action_dim, Some(11));
422    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423    actor_targ.net.copy_from(&actor.net);
424    actor_targ.set_requires_grad_all(false);
425
426    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430    critic1_targ.net.copy_from(&critic1.net);
431    critic2_targ.net.copy_from(&critic2.net);
432    critic1_targ.set_requires_grad_all(false);
433    critic2_targ.set_requires_grad_all(false);
434
435    // Optimizers
436    let mut actor_opt = Adam::with_learning_rate(1e-3);
437    for p in actor.parameters() {
438        actor_opt.add_parameter(p);
439    }
440
441    let mut critic_opt = Adam::with_learning_rate(1e-4);
442    for p in critic1.parameters() {
443        critic_opt.add_parameter(p);
444    }
445    for p in critic2.parameters() {
446        critic_opt.add_parameter(p);
447    }
448
449    // Replay buffer and env
450    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451    let mut env = YardEnv::new(1234);
452    let mut rng = SmallRng::new(987654321);
453
454    // Reset & metric trackers
455    let mut state = env.reset(); // [1, state_dim]
456    let mut episode_return = 0.0f32;
457    let mut episode = 0usize;
458    let mut ema_return: Option<f32> = None;
459    let ema_alpha = 0.05f32; // smooth short-term
460    let mut best_return = f32::NEG_INFINITY;
461    let mut policy_updates: usize = 0;
462
463    for t in 0..total_steps {
464        // Select action
465        let action_tensor = if t < start_steps {
466            let a = rng.uniform(-1.0, 1.0);
467            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468        } else {
469            // Behavior policy with exploration noise
470            let _ng = NoGradTrack::new();
471            let det = actor.forward(&state);
472            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473            tanh_bounded(&det.add_tensor(&noise))
474        };
475        let action_value = action_tensor.data()[0];
476
477        // Environment step
478        let (next_state, reward, done) = env.step(action_value);
479        episode_return += reward;
480
481        // Store transition
482        let s_slice = state.data().to_vec();
483        let a_slice = action_tensor.data().to_vec();
484        let s2_slice = next_state.data().to_vec();
485        rb.push(
486            &s_slice,
487            &a_slice,
488            reward,
489            if done { 1.0 } else { 0.0 },
490            &s2_slice,
491        );
492
493        state = if done {
494            let st = env.reset();
495            // Metrics: update EMA and best
496            ema_return = Some(match ema_return {
497                None => episode_return,
498                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499            });
500            if episode_return > best_return {
501                best_return = episode_return;
502            }
503            println!(
504                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505                t,
506                episode,
507                episode_return,
508                ema_return.unwrap_or(episode_return),
509                best_return,
510                rb.size,
511                policy_updates
512            );
513            episode_return = 0.0;
514            episode += 1;
515            st
516        } else {
517            next_state
518        };
519
520        // Training
521        if rb.can_sample(batch_size) {
522            // Sample batch
523            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525            // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526            let target_q = {
527                let _ng = NoGradTrack::new();
528                // Target actions with smoothing noise (tanh bounds)
529                let noise =
530                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532                let q1_t = critic1_targ.forward(&s2, &a_targ);
533                let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535                // Elementwise min via data() since this path is no-grad
536                let q1d = q1_t.data();
537                let q2d = q2_t.data();
538                let mut min_vec = Vec::with_capacity(batch_size);
539                for i in 0..batch_size {
540                    let v1 = q1d[i];
541                    let v2 = q2d[i];
542                    min_vec.push(v1.min(v2));
543                }
544                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547            };
548
549            // Critic update (both critics)
550            // Zero grads in a short scope, then drop borrows before forward
551            {
552                let mut params = {
553                    let c_params = critic1.parameters();
554                    let c2_params = critic2.parameters();
555                    let mut tmp: Vec<&mut Tensor> = Vec::new();
556                    tmp.extend(c_params);
557                    tmp.extend(c2_params);
558                    tmp
559                };
560                critic_opt.zero_grad(&mut params);
561            }
562
563            // Forward current Q estimates
564            let q1 = critic1.forward(&s, &a);
565            let q2 = critic2.forward(&s, &a);
566            let diff1 = q1.sub_tensor(&target_q);
567            let diff2 = q2.sub_tensor(&target_q);
568            let mut critic_loss = diff1
569                .pow_scalar(2.0)
570                .mean()
571                .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573            // Backward
574            critic_loss.backward(None);
575
576            // Optional gradient clipping + step (only for params that received grads)
577            {
578                let params = {
579                    let c_params = critic1.parameters();
580                    let c2_params = critic2.parameters();
581                    let mut tmp: Vec<&mut Tensor> = Vec::new();
582                    tmp.extend(c_params);
583                    tmp.extend(c2_params);
584                    tmp
585                };
586                let mut with_grads: Vec<&mut Tensor> = Vec::new();
587                for p in params {
588                    if p.grad_owned().is_some() {
589                        with_grads.push(p);
590                    }
591                }
592                if !with_grads.is_empty() {
593                    // Pre-step metrics
594                    let grad_norm_before = grad_global_norm(&mut with_grads);
595                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596                    critic_opt.step(&mut with_grads);
597                    critic_opt.zero_grad(&mut with_grads);
598
599                    // Post-step metrics (param norm)
600                    let mut for_norm_params = {
601                        let c_params = critic1.parameters();
602                        let c2_params = critic2.parameters();
603                        let mut tmp: Vec<&mut Tensor> = Vec::new();
604                        tmp.extend(c_params);
605                        tmp.extend(c2_params);
606                        tmp
607                    };
608                    let param_norm = params_l2_norm(&mut for_norm_params);
609
610                    // Print compact critic metrics occasionally
611                    if t % 100 == 0 {
612                        let q1_mean = q1.mean().value();
613                        let q2_mean = q2.mean().value();
614                        let tq_mean = target_q.mean().value();
615                        println!(
616                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617                            t,
618                            critic_loss.value(),
619                            q1_mean,
620                            q2_mean,
621                            tq_mean,
622                            grad_norm_before,
623                            param_norm
624                        );
625                    }
626                }
627            }
628
629            // Delayed policy update
630            if t % policy_delay == 0 {
631                // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632                // Zero actor grads before backward
633                {
634                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
635                    actor_opt.zero_grad(&mut a_params);
636                }
637
638                let a_pred = actor.forward(&s);
639                let q_for_actor = critic1.forward(&s, &a_pred);
640                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641                actor_loss.backward(None);
642
643                {
644                    let a_params: Vec<&mut Tensor> = actor.parameters();
645                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
646                    for p in a_params {
647                        if p.grad_owned().is_some() {
648                            with_grads.push(p);
649                        }
650                    }
651                    if !with_grads.is_empty() {
652                        let grad_norm_before = grad_global_norm(&mut with_grads);
653                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654                        actor_opt.step(&mut with_grads);
655                        actor_opt.zero_grad(&mut with_grads);
656
657                        // Post-step param norm
658                        let mut for_norm_params = actor.parameters();
659                        let param_norm = params_l2_norm(&mut for_norm_params);
660
661                        policy_updates += 1;
662                        if t % 200 == 0 {
663                            println!(
664                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665                                t,
666                                actor_loss.value(),
667                                grad_norm_before,
668                                param_norm,
669                                actor_opt.learning_rate(),
670                                critic_opt.learning_rate(),
671                                policy_updates
672                            );
673                        }
674                    }
675                }
676
677                // Target updates (Polyak averaging, no grad)
678                actor_targ.net.soft_update_from(&actor.net, tau);
679                critic1_targ.net.soft_update_from(&critic1.net, tau);
680                critic2_targ.net.soft_update_from(&critic2.net, tau);
681            }
682
683            // Clear all known graphs to avoid stale accumulation across iterations
684            clear_all_graphs_known();
685        }
686    }
687
688    println!("=== TD3 training finished ===");
689    Ok(())
690}
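The target updates above implement Polyak averaging: each target parameter is nudged a small step tau toward its online counterpart. A minimal sketch of that rule over plain f32 slices (a hypothetical helper, shown only to illustrate what this example's soft_update_from applies per parameter tensor):

/// Polyak (soft) update: theta_target <- tau * theta + (1 - tau) * theta_target.
/// Hypothetical slice-based sketch of the rule `soft_update_from` applies.
fn polyak_update(target: &mut [f32], source: &[f32], tau: f32) {
    debug_assert_eq!(target.len(), source.len());
    for (t, &s) in target.iter_mut().zip(source.iter()) {
        *t = tau * s + (1.0 - tau) * *t;
    }
}

With tau = 0.005 as in this example, the targets trail the online networks slowly, which stabilizes the bootstrapped critic targets.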
Source

pub fn zeros_on_device(shape_dims: Vec<usize>, device: Device) -> Self

Creates a new tensor filled with zeros on a specific device

Convenience constructor that creates a tensor on the specified device and initializes all elements to zero. Uses optimized SIMD operations for efficient zero initialization.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
  • device - The device where the tensor should be allocated
§Returns

A new tensor with all elements initialized to zero

§Performance
  • Memory Allocation: Device-specific allocation with optimized alignment
  • Initialization: SIMD-optimized zero filling for large tensors
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
use train_station::Device;

let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
assert_eq!(tensor.device(), Device::cpu());
assert_eq!(tensor.size(), 4);

// Verify all elements are zero
assert_eq!(tensor.get(&[0, 0]), 0.0);
assert_eq!(tensor.get(&[1, 1]), 0.0);
Examples found in repository
examples/getting_started/tensor_basics.rs (line 201)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
Source

pub fn ones_on_device(shape_dims: Vec<usize>, device: Device) -> Self

Creates a new tensor filled with ones on a specific device

Convenience constructor that creates a tensor on the specified device and initializes all elements to one. Uses optimized SIMD operations for efficient initialization.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
  • device - The device where the tensor should be allocated
§Returns

A new tensor with all elements initialized to one

§Performance
  • Memory Allocation: Device-specific allocation with optimized alignment
  • Initialization: SIMD-optimized filling with ones for large tensors
  • Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
use train_station::Device;

let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
assert_eq!(tensor.device(), Device::cpu());
assert_eq!(tensor.size(), 4);

// Verify all elements are one
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 1]), 1.0);
Source

pub fn fill(&mut self, value: f32)

Fills the tensor with a constant value using SIMD optimization

Efficiently initializes all elements of the tensor to the specified value. Uses SIMD operations for large tensors to maximize performance.

§Arguments
  • value - The value to fill the tensor with
§Performance
  • SIMD Optimization: Uses AVX2 for large tensors when available
  • Unrolled Loops: 4x unrolling for better instruction throughput
  • Memory Bandwidth: Optimized for maximum memory bandwidth utilization
§Examples
use train_station::Tensor;

let mut tensor = Tensor::new(vec![2, 3]);
tensor.fill(42.0);

// Verify all elements are 42.0
assert_eq!(tensor.get(&[0, 0]), 42.0);
assert_eq!(tensor.get(&[1, 2]), 42.0);
§Zero-Sized Tensor Handling
use train_station::Tensor;

let mut empty_tensor = Tensor::new(vec![0]);
empty_tensor.fill(42.0); // Should not panic
assert_eq!(empty_tensor.size(), 0);
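The 4x unrolling noted under §Performance can be pictured in scalar form; the sketch below is a hypothetical illustration only (the actual implementation presumably dispatches to AVX2 intrinsics when available):

// Scalar sketch of a 4x-unrolled fill loop (hypothetical illustration).
fn fill_unrolled(data: &mut [f32], value: f32) {
    let n = data.len();
    let mut i = 0;
    // Four stores per iteration improve instruction-level parallelism.
    while i + 4 <= n {
        data[i] = value;
        data[i + 1] = value;
        data[i + 2] = value;
        data[i + 3] = value;
        i += 4;
    }
    // Scalar tail for the remaining 0-3 elements (also covers empty tensors).
    while i < n {
        data[i] = value;
        i += 1;
    }
}

Note that for a zero-sized tensor both loops are skipped, consistent with the zero-sized handling shown above.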
Source

impl Tensor

Source

pub fn from_slice(data: &[f32], shape_dims: Vec<usize>) -> Result<Self, String>

Creates a tensor from a slice of data

Creates a new tensor with the specified shape and copies data from the provided slice. Validates that the data size matches the tensor shape before performing the copy operation.

This method provides an efficient way to create tensors from existing data sources while ensuring data integrity and proper memory management.

§Arguments
  • data - Slice of f32 values to copy into the tensor
  • shape_dims - Vector of dimension sizes defining the tensor shape
§Returns
  • Ok(Tensor) - Successfully created tensor with copied data
  • Err(String) - Error if data size doesn’t match shape
§Performance
  • Memory Copy: Efficient non-overlapping copy using SIMD when possible
  • Validation: Fast size validation before allocation
  • Alignment: Proper memory alignment for optimal performance
  • Large Data: Optimized handling of large datasets
§Examples
§Basic Usage
use train_station::Tensor;

let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 2]), 6.0);
§Multi-Dimensional Data
use train_station::Tensor;

// 1D tensor
let data_1d = [1.0, 2.0, 3.0];
let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
assert_eq!(tensor_1d.shape().dims(), vec![3]);
assert_eq!(tensor_1d.get(&[1]), 2.0);

// 3D tensor
let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
assert_eq!(tensor_3d.shape().dims(), vec![2, 2, 2]);
assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
§Error Handling
use train_station::Tensor;

// Size mismatch error
let data = [1.0, 2.0, 3.0];
let result = Tensor::from_slice(&data, vec![2, 2]);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("Data size 3 doesn't match shape size 4"));
§Zero-Sized Tensors
use train_station::Tensor;

// Handle empty tensors gracefully
let data: [f32; 0] = [];
let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
assert_eq!(tensor.size(), 0);
assert_eq!(tensor.shape().dims(), vec![0]);
§Large Data Sets
use train_station::Tensor;

// Efficient handling of large datasets
let size = 1000;
let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![size]).unwrap();

assert_eq!(tensor.size(), size);
assert_eq!(tensor.get(&[0]), 0.0);
assert_eq!(tensor.get(&[100]), 100.0);
assert_eq!(tensor.get(&[999]), 999.0);
§Implementation Details

This method performs the following steps:

  1. Shape Validation: Creates a Shape object and validates dimensions
  2. Size Check: Ensures data length matches the calculated tensor size
  3. Memory Allocation: Allocates tensor memory with proper alignment
  4. Data Copy: Uses efficient non-overlapping memory copy operation
  5. Return: Returns the created tensor or descriptive error message

The memory copy operation uses std::ptr::copy_nonoverlapping for maximum performance and safety, ensuring no data corruption occurs during the copy process.
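A rough sketch of the size check (step 2) and the copy (step 4), assuming a destination pointer obtained from the aligned allocation in step 3; this helper is hypothetical, not the library's actual internals:

// Hypothetical sketch of the validate-then-copy pattern described above.
fn validated_copy(data: &[f32], shape_size: usize, dst: *mut f32) -> Result<(), String> {
    if data.len() != shape_size {
        return Err(format!(
            "Data size {} doesn't match shape size {}",
            data.len(),
            shape_size
        ));
    }
    // The freshly allocated destination cannot alias the source slice.
    unsafe { std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len()) };
    Ok(())
}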

Examples found in repository
examples/RL_training/dqn.rs (line 185)
184    fn state_tensor(&self) -> Tensor {
185        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
186    }
187
188    fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
189        let a = Self::ACTIONS[action_index.min(2)];
190        self.vel += 0.1 * a - 0.01 * self.pos;
191        self.pos += self.vel;
192        self.steps += 1;
193        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
194        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
195        (self.state_tensor(), reward, done)
196    }
197}
198
199// -------------------------------
200// Replay Buffer
201// -------------------------------
202
203struct ReplayBuffer {
204    capacity: usize,
205    size: usize,
206    pos: usize,
207    state_dim: usize,
208    states: Vec<f32>,
209    actions: Vec<usize>,
210    rewards: Vec<f32>,
211    dones: Vec<f32>,
212    next_states: Vec<f32>,
213}
214
215impl ReplayBuffer {
216    fn new(capacity: usize, state_dim: usize) -> Self {
217        Self {
218            capacity,
219            size: 0,
220            pos: 0,
221            state_dim,
222            states: vec![0.0; capacity * state_dim],
223            actions: vec![0usize; capacity],
224            rewards: vec![0.0; capacity],
225            dones: vec![0.0; capacity],
226            next_states: vec![0.0; capacity * state_dim],
227        }
228    }
229
230    fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
231        let i = self.pos;
232        let so = i * self.state_dim;
233        self.states[so..so + self.state_dim].copy_from_slice(s);
234        self.actions[i] = a_idx;
235        self.rewards[i] = r;
236        self.dones[i] = d;
237        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
238        self.pos = (self.pos + 1) % self.capacity;
239        self.size = self.size.saturating_add(1).min(self.capacity);
240    }
241
242    fn can_sample(&self, batch_size: usize) -> bool {
243        self.size >= batch_size
244    }
245
246    fn sample(
247        &self,
248        batch_size: usize,
249        rng: &mut SmallRng,
250    ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
251        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
252        let mut a_idx = Vec::with_capacity(batch_size);
253        let mut r_vec = Vec::with_capacity(batch_size);
254        let mut d_vec = Vec::with_capacity(batch_size);
255        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
256        for _ in 0..batch_size {
257            let idx = rng.sample_index(self.size);
258            let so = idx * self.state_dim;
259            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
260            a_idx.push(self.actions[idx]);
261            r_vec.push(self.rewards[idx]);
262            d_vec.push(self.dones[idx]);
263            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
264        }
265        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
266        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
267        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
268        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
269        (s, a_idx, r, d, s2)
270    }
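For orientation, a hypothetical snippet driving the buffer above (ReplayBuffer and SmallRng are this example's own types; state_dim = 3):

let mut rb = ReplayBuffer::new(10_000, 3);
let mut rng = SmallRng::new(7);
rb.push(&[0.1, 0.0, 0.0], 1, -0.5, 0.0, &[0.2, 0.1, 0.0]);
if rb.can_sample(1) {
    let (s, _a_idx, r, _d, _s2) = rb.sample(1, &mut rng);
    assert_eq!(s.shape().dims(), vec![1, 3]);
    assert_eq!(r.shape().dims(), vec![1, 1]);
}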
examples/RL_training/ppo_discrete.rs (line 151)
150    fn state_tensor(&self) -> Tensor {
151        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
152    }
153    fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
154        let a = Self::ACTIONS[action_idx.min(2)];
155        self.vel += 0.1 * a - 0.01 * self.pos;
156        self.pos += self.vel;
157        self.steps += 1;
158        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
159        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
160        (self.state_tensor(), reward, done)
161    }
162}
163
164// -------------------------------
165// Rollout storage
166// -------------------------------
167
168struct RolloutBatch {
169    states: Vec<f32>,
170    actions: Vec<usize>,
171    old_logps: Vec<f32>,
172    rewards: Vec<f32>,
173    dones: Vec<f32>,
174    values: Vec<f32>,
175    next_states: Vec<f32>,
176    _state_dim: usize,
177}
178impl RolloutBatch {
179    fn new(cap: usize, sd: usize) -> Self {
180        Self {
181            states: Vec::with_capacity(cap * sd),
182            actions: Vec::with_capacity(cap),
183            old_logps: Vec::with_capacity(cap),
184            rewards: Vec::with_capacity(cap),
185            dones: Vec::with_capacity(cap),
186            values: Vec::with_capacity(cap),
187            next_states: Vec::with_capacity(cap * sd),
188            _state_dim: sd,
189        }
190    }
191    #[allow(clippy::too_many_arguments)]
192    fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
193        self.states.extend_from_slice(s);
194        self.actions.push(a);
195        self.old_logps.push(lp);
196        self.rewards.push(r);
197        self.dones.push(d);
198        self.values.push(v);
199        self.next_states.extend_from_slice(s2);
200    }
201    fn len(&self) -> usize {
202        self.actions.len()
203    }
204}
205
206// -------------------------------
207// Helpers
208// -------------------------------
209
210#[allow(clippy::too_many_arguments)]
211fn compute_gae(
212    returns_out: &mut [f32],
213    adv_out: &mut [f32],
214    rewards: &[f32],
215    dones: &[f32],
216    values: &[f32],
217    next_values: &[f32],
218    gamma: f32,
219    lam: f32,
220) {
221    let n = rewards.len();
222    let mut gae = 0.0f32;
223    for t in (0..n).rev() {
224        let not_done = 1.0 - dones[t];
225        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
226        gae = delta + gamma * lam * not_done * gae;
227        adv_out[t] = gae;
228        returns_out[t] = gae + values[t];
229    }
230}
231
232fn normalize_in_place(x: &mut [f32], eps: f32) {
233    let n = x.len() as f32;
234    if n <= 1.0 {
235        return;
236    }
237    let mean = x.iter().copied().sum::<f32>() / n;
238    let var = x
239        .iter()
240        .map(|v| {
241            let d = v - mean;
242            d * d
243        })
244        .sum::<f32>()
245        / n;
246    let std = (var + eps).sqrt();
247    for v in x.iter_mut() {
248        *v = (*v - mean) / std;
249    }
250}
251
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
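The update above is the standard PPO clipped-surrogate step built from the helpers defined earlier; with \rho_t the probability ratio (ratio), \hat{A}_t the normalized advantage (adv_mb), and \epsilon = clip_eps, the actor term is

L^{CLIP}(\theta) = -\,\mathbb{E}_t\left[\min\left(\rho_t \hat{A}_t,\ \mathrm{clip}(\rho_t,\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t\right)\right], \qquad \rho_t = \exp\left(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\mathrm{old}}(a_t \mid s_t)\right)

clamp_ratio realizes the clip with the ReLU identities

\max(x, a) = a + \mathrm{relu}(x - a), \qquad \min(x, b) = b - \mathrm{relu}(x - b)

and compute_gae implements the generalized advantage estimator

\delta_t = r_t + \gamma (1 - d_t)\, V(s_{t+1}) - V(s_t), \qquad \hat{A}_t = \delta_t + \gamma \lambda (1 - d_t)\, \hat{A}_{t+1}, \qquad R_t = \hat{A}_t + V(s_t)

where d_t is the done flag; the returns R_t supply the critic regression targets.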
examples/RL_training/ppo_continuous.rs (line 101)
99    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
100        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
101        let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
102            .unwrap()
103            .with_requires_grad();
104        Self { net, log_std }
105    }
106    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107        // Returns (mean [B, A], log_std [A])
108        let mean = self.net.forward(state);
109        (
110            mean,
111            self.log_std
112                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113        )
114    }
115    fn parameters(&mut self) -> Vec<&mut Tensor> {
116        let mut ps = self.net.parameters();
117        ps.push(&mut self.log_std);
118        ps
119    }
120}
121
122// -------------------------------
123// Critic: value function V(s)
124// -------------------------------
125
126struct Critic {
127    net: Mlp,
128}
129impl Critic {
130    fn new(state_dim: usize, seed: Option<u64>) -> Self {
131        Self {
132            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
133        }
134    }
135    fn forward(&self, state: &Tensor) -> Tensor {
136        self.net.forward(state)
137    }
138    fn parameters(&mut self) -> Vec<&mut Tensor> {
139        self.net.parameters()
140    }
141}
142
143// -------------------------------
144// Continuous YardEnv (same dynamics as TD3 env)
145// -------------------------------
146
147struct YardEnv {
148    pos: f32,
149    vel: f32,
150    steps: usize,
151    max_steps: usize,
152    rng: SmallRng,
153}
154impl YardEnv {
155    fn new(seed: u64) -> Self {
156        let mut e = Self {
157            pos: 0.0,
158            vel: 0.0,
159            steps: 0,
160            max_steps: 200,
161            rng: SmallRng::new(seed),
162        };
163        e.reset();
164        e
165    }
166    fn reset(&mut self) -> Tensor {
167        self.pos = (self.rng.next_f32() * 1.0) - 0.5;
168        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
169        self.steps = 0;
170        self.state_tensor()
171    }
172    fn state_tensor(&self) -> Tensor {
173        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
174    }
175    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
176        let a = action_value.clamp(-1.0, 1.0);
177        self.vel += 0.1 * a - 0.01 * self.pos;
178        self.pos += self.vel;
179        self.steps += 1;
180        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
181        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
182        (self.state_tensor(), reward, done)
183    }
184}
185
186// -------------------------------
187// Trajectory storage
188// -------------------------------
189
190struct RolloutBatch {
191    states: Vec<f32>,
192    actions: Vec<f32>,
193    log_probs: Vec<f32>,
194    rewards: Vec<f32>,
195    dones: Vec<f32>,
196    values: Vec<f32>,
197    next_states: Vec<f32>,
198    _state_dim: usize,
199}
200impl RolloutBatch {
201    fn new(capacity: usize, state_dim: usize) -> Self {
202        Self {
203            states: Vec::with_capacity(capacity * state_dim),
204            actions: Vec::with_capacity(capacity),
205            log_probs: Vec::with_capacity(capacity),
206            rewards: Vec::with_capacity(capacity),
207            dones: Vec::with_capacity(capacity),
208            values: Vec::with_capacity(capacity),
209            next_states: Vec::with_capacity(capacity * state_dim),
210            _state_dim: state_dim,
211        }
212    }
213
214    #[allow(clippy::too_many_arguments)]
215    fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
216        self.states.extend_from_slice(s);
217        self.actions.push(a);
218        self.log_probs.push(lp);
219        self.rewards.push(r);
220        self.dones.push(d);
221        self.values.push(v);
222        self.next_states.extend_from_slice(s2);
223    }
224
225    fn len(&self) -> usize {
226        self.actions.len()
227    }
228}
229
230// -------------------------------
231// Math helpers
232// -------------------------------
233
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar((2.0 * std::f32::consts::PI).ln()) // constant term ln(2π)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (sample from detached values so the rollout doesn't grow the graph; stored log_probs are reused)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
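gaussian_log_prob above evaluates the log-density of a diagonal Gaussian policy, summed over the A action dimensions:

\log p(a \mid \mu, \sigma) = -\frac{1}{2} \sum_{j=1}^{A} \left[ \frac{(a_j - \mu_j)^2}{\sigma_j^2} + 2\log\sigma_j + \log(2\pi) \right]

The per-dimension constant is \log(2\pi) \approx 1.8379, matching the .add_scalar term in the chain, and the entropy bonus in the update uses the corresponding closed form H = \sum_j \left( \log\sigma_j + \tfrac{1}{2}\log(2\pi e) \right).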
examples/RL_training/td3.rs (line 250)
244    fn state_tensor(&self) -> Tensor {
245        // Normalize to keep critic inputs bounded:
246        // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247        // - Velocity scaled by 1.0 and clamped to [-1,1]
248        let pos_n = self.pos / 3.0;
249        let vel_n = self.vel.clamp(-1.0, 1.0);
250        Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251    }
252
253    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254        let a = action_value.clamp(-1.0, 1.0);
255        self.vel += 0.1 * a - 0.01 * self.pos;
256        self.pos += self.vel;
257        self.steps += 1;
258
259        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261        (self.state_tensor(), reward, done)
262    }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270    capacity: usize,
271    size: usize,
272    pos: usize,
273    state_dim: usize,
274    action_dim: usize,
275    states: Vec<f32>,
276    actions: Vec<f32>,
277    rewards: Vec<f32>,
278    dones: Vec<f32>,
279    next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283    fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284        Self {
285            capacity,
286            size: 0,
287            pos: 0,
288            state_dim,
289            action_dim,
290            states: vec![0.0; capacity * state_dim],
291            actions: vec![0.0; capacity * action_dim],
292            rewards: vec![0.0; capacity],
293            dones: vec![0.0; capacity],
294            next_states: vec![0.0; capacity * state_dim],
295        }
296    }
297
298    fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299        let i = self.pos;
300        let so = i * self.state_dim;
301        let ao = i * self.action_dim;
302        self.states[so..so + self.state_dim].copy_from_slice(s);
303        self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304        self.rewards[i] = r;
305        self.dones[i] = d;
306        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308        self.pos = (self.pos + 1) % self.capacity;
309        self.size = self.size.saturating_add(1).min(self.capacity);
310    }
311
312    fn can_sample(&self, batch_size: usize) -> bool {
313        self.size >= batch_size
314    }
315
316    fn sample(
317        &self,
318        batch_size: usize,
319        rng: &mut SmallRng,
320    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322        let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323        let mut r_vec = Vec::with_capacity(batch_size);
324        let mut d_vec = Vec::with_capacity(batch_size);
325        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327        for _ in 0..batch_size {
328            let idx = rng.sample_index(self.size);
329            let so = idx * self.state_dim;
330            let ao = idx * self.action_dim;
331            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332            a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333            r_vec.push(self.rewards[idx]);
334            d_vec.push(self.dones[idx]);
335            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336        }
337
338        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339        let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343        (s, a, r, d, s2)
344    }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352    // Compute global L2 norm of all grads
353    let mut total_sq = 0.0f32;
354    for p in parameters.iter() {
355        if let Some(g) = p.grad_owned() {
356            for &v in g.data() {
357                total_sq += v * v;
358            }
359        }
360    }
361    let norm = total_sq.sqrt();
362    if norm > max_norm {
363        let scale = max_norm / (norm + eps);
364        for p in parameters.iter_mut() {
365            if let Some(g) = p.grad_owned() {
366                let scaled = g.mul_scalar(scale);
367                p.set_grad(scaled);
368            }
369        }
370    }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375    let mut total_sq = 0.0f32;
376    for p in parameters.iter_mut() {
377        if let Some(g) = p.grad_owned() {
378            for &v in g.data() {
379                total_sq += v * v;
380            }
381        }
382    }
383    total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388    let _ng = NoGradTrack::new();
389    let mut total_sq = 0.0f32;
390    for p in parameters.iter_mut() {
391        for &v in p.data() {
392            total_sq += v * v;
393        }
394    }
395    total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403    println!("=== TD3 Example (YardEnv) ===");
404
405    // Environment / problem dims
406    let state_dim = 3usize;
407    let action_dim = 1usize;
408
409    // Hyperparameters (small for demo)
410    let gamma = 0.99f32;
411    let tau = 0.005f32; // Polyak
412    let policy_noise = 0.2f32; // target smoothing noise stddev
413    let exploration_noise = 0.1f32; // behavior policy noise stddev
414    let policy_delay = 2usize;
415    let batch_size = 64usize;
416    let start_steps = 500usize; // random exploration steps
417    let total_steps = 1500usize;
418    let max_grad_norm = 1.0f32;
419
420    // Models
421    let mut actor = Actor::new(state_dim, action_dim, Some(11));
422    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423    actor_targ.net.copy_from(&actor.net);
424    actor_targ.set_requires_grad_all(false);
425
426    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430    critic1_targ.net.copy_from(&critic1.net);
431    critic2_targ.net.copy_from(&critic2.net);
432    critic1_targ.set_requires_grad_all(false);
433    critic2_targ.set_requires_grad_all(false);
434
435    // Optimizers
436    let mut actor_opt = Adam::with_learning_rate(1e-3);
437    for p in actor.parameters() {
438        actor_opt.add_parameter(p);
439    }
440
441    let mut critic_opt = Adam::with_learning_rate(1e-4);
442    for p in critic1.parameters() {
443        critic_opt.add_parameter(p);
444    }
445    for p in critic2.parameters() {
446        critic_opt.add_parameter(p);
447    }
448
449    // Replay buffer and env
450    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451    let mut env = YardEnv::new(1234);
452    let mut rng = SmallRng::new(987654321);
453
454    // Reset & metric trackers
455    let mut state = env.reset(); // [1, state_dim]
456    let mut episode_return = 0.0f32;
457    let mut episode = 0usize;
458    let mut ema_return: Option<f32> = None;
459    let ema_alpha = 0.05f32; // smooth short-term
460    let mut best_return = f32::NEG_INFINITY;
461    let mut policy_updates: usize = 0;
462
463    for t in 0..total_steps {
464        // Select action
465        let action_tensor = if t < start_steps {
466            let a = rng.uniform(-1.0, 1.0);
467            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468        } else {
469            // Behavior policy with exploration noise
470            let _ng = NoGradTrack::new();
471            let det = actor.forward(&state);
472            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473            tanh_bounded(&det.add_tensor(&noise))
474        };
475        let action_value = action_tensor.data()[0];
476
477        // Environment step
478        let (next_state, reward, done) = env.step(action_value);
479        episode_return += reward;
480
481        // Store transition
482        let s_slice = state.data().to_vec();
483        let a_slice = action_tensor.data().to_vec();
484        let s2_slice = next_state.data().to_vec();
485        rb.push(
486            &s_slice,
487            &a_slice,
488            reward,
489            if done { 1.0 } else { 0.0 },
490            &s2_slice,
491        );
492
493        state = if done {
494            let st = env.reset();
495            // Metrics: update EMA and best
496            ema_return = Some(match ema_return {
497                None => episode_return,
498                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499            });
500            if episode_return > best_return {
501                best_return = episode_return;
502            }
503            println!(
504                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505                t,
506                episode,
507                episode_return,
508                ema_return.unwrap_or(episode_return),
509                best_return,
510                rb.size,
511                policy_updates
512            );
513            episode_return = 0.0;
514            episode += 1;
515            st
516        } else {
517            next_state
518        };
519
520        // Training
521        if rb.can_sample(batch_size) {
522            // Sample batch
523            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525            // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526            let target_q = {
527                let _ng = NoGradTrack::new();
528                // Target actions with smoothing noise (tanh bounds)
529                let noise =
530                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532                let q1_t = critic1_targ.forward(&s2, &a_targ);
533                let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535                // Elementwise min via data() since this path is no-grad
536                let q1d = q1_t.data();
537                let q2d = q2_t.data();
538                let mut min_vec = Vec::with_capacity(batch_size);
539                for i in 0..batch_size {
540                    let v1 = q1d[i];
541                    let v2 = q2d[i];
542                    min_vec.push(v1.min(v2));
543                }
544                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547            };
548
549            // Critic update (both critics)
550            // Zero grads in a short scope, then drop borrows before forward
551            {
552                let mut params = {
553                    let c_params = critic1.parameters();
554                    let c2_params = critic2.parameters();
555                    let mut tmp: Vec<&mut Tensor> = Vec::new();
556                    tmp.extend(c_params);
557                    tmp.extend(c2_params);
558                    tmp
559                };
560                critic_opt.zero_grad(&mut params);
561            }
562
563            // Forward current Q estimates
564            let q1 = critic1.forward(&s, &a);
565            let q2 = critic2.forward(&s, &a);
566            let diff1 = q1.sub_tensor(&target_q);
567            let diff2 = q2.sub_tensor(&target_q);
568            let mut critic_loss = diff1
569                .pow_scalar(2.0)
570                .mean()
571                .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573            // Backward
574            critic_loss.backward(None);
575
576            // Optional gradient clipping + step (only for params that received grads)
577            {
578                let params = {
579                    let c_params = critic1.parameters();
580                    let c2_params = critic2.parameters();
581                    let mut tmp: Vec<&mut Tensor> = Vec::new();
582                    tmp.extend(c_params);
583                    tmp.extend(c2_params);
584                    tmp
585                };
586                let mut with_grads: Vec<&mut Tensor> = Vec::new();
587                for p in params {
588                    if p.grad_owned().is_some() {
589                        with_grads.push(p);
590                    }
591                }
592                if !with_grads.is_empty() {
593                    // Pre-step metrics
594                    let grad_norm_before = grad_global_norm(&mut with_grads);
595                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596                    critic_opt.step(&mut with_grads);
597                    critic_opt.zero_grad(&mut with_grads);
598
599                    // Post-step metrics (param norm)
600                    let mut for_norm_params = {
601                        let c_params = critic1.parameters();
602                        let c2_params = critic2.parameters();
603                        let mut tmp: Vec<&mut Tensor> = Vec::new();
604                        tmp.extend(c_params);
605                        tmp.extend(c2_params);
606                        tmp
607                    };
608                    let param_norm = params_l2_norm(&mut for_norm_params);
609
610                    // Print compact critic metrics occasionally
611                    if t % 100 == 0 {
612                        let q1_mean = q1.mean().value();
613                        let q2_mean = q2.mean().value();
614                        let tq_mean = target_q.mean().value();
615                        println!(
616                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617                            t,
618                            critic_loss.value(),
619                            q1_mean,
620                            q2_mean,
621                            tq_mean,
622                            grad_norm_before,
623                            param_norm
624                        );
625                    }
626                }
627            }
628
629            // Delayed policy update
630            if t % policy_delay == 0 {
631                // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632                // Zero actor grads before backward
633                {
634                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
635                    actor_opt.zero_grad(&mut a_params);
636                }
637
638                let a_pred = actor.forward(&s);
639                let q_for_actor = critic1.forward(&s, &a_pred);
640                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641                actor_loss.backward(None);
642
643                {
644                    let a_params: Vec<&mut Tensor> = actor.parameters();
645                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
646                    for p in a_params {
647                        if p.grad_owned().is_some() {
648                            with_grads.push(p);
649                        }
650                    }
651                    if !with_grads.is_empty() {
652                        let grad_norm_before = grad_global_norm(&mut with_grads);
653                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654                        actor_opt.step(&mut with_grads);
655                        actor_opt.zero_grad(&mut with_grads);
656
657                        // Post-step param norm
658                        let mut for_norm_params = actor.parameters();
659                        let param_norm = params_l2_norm(&mut for_norm_params);
660
661                        policy_updates += 1;
662                        if t % 200 == 0 {
663                            println!(
664                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665                                t,
666                                actor_loss.value(),
667                                grad_norm_before,
668                                param_norm,
669                                actor_opt.learning_rate(),
670                                critic_opt.learning_rate(),
671                                policy_updates
672                            );
673                        }
674                    }
675                }
676
677                // Target updates (Polyak averaging, no grad): theta' <- tau*theta + (1 - tau)*theta'
678                actor_targ.net.soft_update_from(&actor.net, tau);
679                critic1_targ.net.soft_update_from(&critic1.net, tau);
680                critic2_targ.net.soft_update_from(&critic2.net, tau);
681            }
682
683            // Clear entire graphs to avoid stale accumulation across iterations
684            clear_all_graphs_known();
685        }
686    }
687
688    println!("=== TD3 training finished ===");
689    Ok(())
690}
examples/getting_started/tensor_operators.rs (line 49)
46fn demonstrate_basic_operators() {
47    println!("--- Basic Tensor-Tensor Operators ---");
48
49    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
50    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
51
52    println!("Tensor A: {:?}", a.data());
53    println!("Tensor B: {:?}", b.data());
54
55    // Addition
56    let c = &a + &b;
57    println!("A + B: {:?}", c.data());
58
59    // Subtraction
60    let d = &a - &b;
61    println!("A - B: {:?}", d.data());
62
63    // Multiplication
64    let e = &a * &b;
65    println!("A * B: {:?}", e.data());
66
67    // Division
68    let f = &a / &b;
69    println!("A / B: {:?}", f.data());
70}
71
72/// Demonstrate tensor-scalar operators
73fn demonstrate_scalar_operators() {
74    println!("\n--- Tensor-Scalar Operators ---");
75
76    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
77    println!("Original tensor: {:?}", tensor.data());
78
79    // Tensor + scalar
80    let result1 = &tensor + 5.0;
81    println!("Tensor + 5.0: {:?}", result1.data());
82
83    // Scalar + tensor
84    let result2 = 5.0 + &tensor;
85    println!("5.0 + Tensor: {:?}", result2.data());
86
87    // Tensor - scalar
88    let result3 = &tensor - 2.0;
89    println!("Tensor - 2.0: {:?}", result3.data());
90
91    // Tensor * scalar
92    let result4 = &tensor * 3.0;
93    println!("Tensor * 3.0: {:?}", result4.data());
94
95    // Scalar * tensor
96    let result5 = 3.0 * &tensor;
97    println!("3.0 * Tensor: {:?}", result5.data());
98
99    // Tensor / scalar
100    let result6 = &tensor / 2.0;
101    println!("Tensor / 2.0: {:?}", result6.data());
102}
103
104/// Demonstrate assignment operators
105fn demonstrate_operator_assignment() {
106    println!("\n--- Assignment Operators ---");
107
108    let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109    println!("Original tensor: {:?}", tensor.data());
110
111    // In-place addition
112    tensor += 5.0;
113    println!("After += 5.0: {:?}", tensor.data());
114
115    // In-place subtraction
116    tensor -= 2.0;
117    println!("After -= 2.0: {:?}", tensor.data());
118
119    // In-place multiplication
120    tensor *= 3.0;
121    println!("After *= 3.0: {:?}", tensor.data());
122
123    // In-place division
124    tensor /= 2.0;
125    println!("After /= 2.0: {:?}", tensor.data());
126}
127
128/// Demonstrate operator chaining and complex expressions
129fn demonstrate_operator_chaining() {
130    println!("\n--- Operator Chaining ---");
131
132    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
133    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
134    let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
135
136    println!("Tensor A: {:?}", a.data());
137    println!("Tensor B: {:?}", b.data());
138    println!("Tensor C: {:?}", c.data());
139
140    // Complex expression: (A + B) * C - 5
141    let result = (&a + &b) * &c - 5.0;
142    println!("(A + B) * C - 5: {:?}", result.data());
143
144    // Another complex expression: A * 2 + B / 2
145    let result2 = &a * 2.0 + &b / 2.0;
146    println!("A * 2 + B / 2: {:?}", result2.data());
147
148    // Negation and addition: -A + B * C
149    let result3 = -&a + &b * &c;
150    println!("-A + B * C: {:?}", result3.data());
151
152    // Division with parentheses: (A + B) / (C - 1)
153    let result4 = (&a + &b) / (&c - 1.0);
154    println!("(A + B) / (C - 1): {:?}", result4.data());
155}
156
157/// Demonstrate broadcasting behavior
158fn demonstrate_broadcasting() {
159    println!("\n--- Broadcasting ---");
160
161    // 2D tensor
162    let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
163    println!(
164        "2D tensor: shape {:?}, data: {:?}",
165        tensor_2d.shape().dims(),
166        tensor_2d.data()
167    );
168
169    // 1D tensor (will be broadcasted)
170    let tensor_1d = Tensor::from_slice(&[10.0, 20.0], vec![2]).unwrap();
171    println!(
172        "1D tensor: shape {:?}, data: {:?}",
173        tensor_1d.shape().dims(),
174        tensor_1d.data()
175    );
176
177    // Broadcasting addition
178    let broadcast_sum = &tensor_2d + &tensor_1d;
179    println!(
180        "Broadcast sum: shape {:?}, data: {:?}",
181        broadcast_sum.shape().dims(),
182        broadcast_sum.data()
183    );
184
185    // Broadcasting multiplication
186    let broadcast_mul = &tensor_2d * &tensor_1d;
187    println!(
188        "Broadcast multiplication: shape {:?}, data: {:?}",
189        broadcast_mul.shape().dims(),
190        broadcast_mul.data()
191    );
192
193    // Broadcasting with scalar
194    let broadcast_scalar = &tensor_2d + 100.0;
195    println!(
196        "Broadcast scalar: shape {:?}, data: {:?}",
197        broadcast_scalar.shape().dims(),
198        broadcast_scalar.data()
199    );
200}
201
202/// Demonstrate equivalence between operators and method calls
203fn demonstrate_method_equivalence() {
204    println!("\n--- Operator vs Method Call Equivalence ---");
205
206    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209    // Addition: operator vs method
210    let operator_result = &a + &b;
211    let method_result = a.add_tensor(&b);
212
213    println!("A + B (operator): {:?}", operator_result.data());
214    println!("A.add_tensor(B): {:?}", method_result.data());
215    println!(
216        "Results are equal: {}",
217        operator_result.data() == method_result.data()
218    );
219
220    // Multiplication: operator vs method
221    let operator_result = &a * &b;
222    let method_result = a.mul_tensor(&b);
223
224    println!("A * B (operator): {:?}", operator_result.data());
225    println!("A.mul_tensor(B): {:?}", method_result.data());
226    println!(
227        "Results are equal: {}",
228        operator_result.data() == method_result.data()
229    );
230
231    // Scalar addition: operator vs method
232    let operator_result = &a + 5.0;
233    let method_result = a.add_scalar(5.0);
234
235    println!("A + 5.0 (operator): {:?}", operator_result.data());
236    println!("A.add_scalar(5.0): {:?}", method_result.data());
237    println!(
238        "Results are equal: {}",
239        operator_result.data() == method_result.data()
240    );
241}
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 175)
169fn demonstrate_forward_pass() {
170    println!("\n--- Forward Pass (with gradients) ---");
171
172    let layer = LinearLayer::new(3, 2, Some(43));
173
174    // Single input
175    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176    let output = layer.forward(&input);
177
178    println!("Single input:");
179    println!("  Input: {:?}", input.data());
180    println!("  Output: {:?}", output.data());
181    println!("  Output requires grad: {}", output.requires_grad());
182
183    // Batch input
184    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185    let batch_output = layer.forward(&batch_input);
186
187    println!("Batch input:");
188    println!("  Input shape: {:?}", batch_input.shape().dims());
189    println!("  Output shape: {:?}", batch_output.shape().dims());
190    println!("  Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195    println!("\n--- Forward Pass (no gradients) ---");
196
197    let layer = LinearLayer::new(3, 2, Some(44));
198
199    // Single input
200    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201    let output = layer.forward_no_grad(&input);
202
203    println!("Single input (no grad):");
204    println!("  Input: {:?}", input.data());
205    println!("  Output: {:?}", output.data());
206    println!("  Output requires grad: {}", output.requires_grad());
207
208    // Compare with grad version
209    let output_with_grad = layer.forward(&input);
210    println!("Comparison:");
211    println!(
212        "  Same values: {}",
213        output.data() == output_with_grad.data()
214    );
215    println!("  No grad requires grad: {}", output.requires_grad());
216    println!(
217        "  With grad requires grad: {}",
218        output_with_grad.requires_grad()
219    );
220}
221
222/// Demonstrate complete training loop
223fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
224    println!("\n--- Training Loop ---");
225
226    // Create layer and training data
227    let mut layer = LinearLayer::new(2, 1, Some(45));
228
229    // Simple regression task: y = 2*x1 + 3*x2 + 1
230    let x_data = Tensor::from_slice(
231        &[
232            1.0, 1.0, // x1=1, x2=1 -> y=6
233            2.0, 1.0, // x1=2, x2=1 -> y=8
234            1.0, 2.0, // x1=1, x2=2 -> y=9
235            2.0, 2.0, // x1=2, x2=2 -> y=11
236        ],
237        vec![4, 2],
238    )
239    .unwrap();
240
241    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();
242
243    println!("Training data:");
244    println!("  X shape: {:?}", x_data.shape().dims());
245    println!("  Y shape: {:?}", y_true.shape().dims());
246    println!("  Target function: y = 2*x1 + 3*x2 + 1");
247
248    // Create optimizer
249    let config = AdamConfig {
250        learning_rate: 0.01,
251        beta1: 0.9,
252        beta2: 0.999,
253        eps: 1e-8,
254        weight_decay: 0.0,
255        amsgrad: false,
256    };
257
258    let mut optimizer = Adam::with_config(config);
259    let params = layer.parameters();
260    for param in &params {
261        optimizer.add_parameter(param);
262    }
263
264    println!("Optimizer setup complete. Starting training...");
265
266    // Training loop
267    let num_epochs = 100;
268    let mut losses = Vec::new();
269
270    for epoch in 0..num_epochs {
271        // Forward pass
272        let y_pred = layer.forward(&x_data);
273
274        // Compute loss: MSE
275        let diff = y_pred.sub_tensor(&y_true);
276        let mut loss = diff.pow_scalar(2.0).mean();
277
278        // Backward pass
279        loss.backward(None);
280
281        // Optimizer step
282        let mut params = layer.parameters();
283        optimizer.step(&mut params);
284        optimizer.zero_grad(&mut params);
285
286        losses.push(loss.value());
287
288        // Print progress
289        if epoch % 20 == 0 || epoch == num_epochs - 1 {
290            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
291        }
292    }
293
294    // Evaluate final model
295    let final_predictions = layer.forward_no_grad(&x_data);
296
297    println!("\nFinal model evaluation:");
298    println!("  Learned weights: {:?}", layer.weight.data());
299    println!("  Learned bias: {:?}", layer.bias.data());
300    println!("  Target weights: [2.0, 3.0]");
301    println!("  Target bias: [1.0]");
302
303    println!("  Predictions vs True:");
304    for i in 0..4 {
305        let pred = final_predictions.data()[i];
306        let true_val = y_true.data()[i];
307        println!(
308            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
309            i + 1,
310            pred,
311            true_val,
312            (pred - true_val).abs()
313        );
314    }
315
316    // Training analysis
317    let initial_loss = losses[0];
318    let final_loss = losses[losses.len() - 1];
319    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
320
321    println!("\nTraining Analysis:");
322    println!("  Initial loss: {:.6}", initial_loss);
323    println!("  Final loss: {:.6}", final_loss);
324    println!("  Loss reduction: {:.1}%", loss_reduction);
325
326    Ok(())
327}
328
329/// Demonstrate single vs batch inference
330fn demonstrate_single_vs_batch_inference() {
331    println!("\n--- Single vs Batch Inference ---");
332
333    let layer = LinearLayer::new(4, 3, Some(46));
334
335    // Single inference
336    println!("Single inference:");
337    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338    let single_output = layer.forward_no_grad(&single_input);
339    println!("  Input shape: {:?}", single_input.shape().dims());
340    println!("  Output shape: {:?}", single_output.shape().dims());
341    println!("  Output: {:?}", single_output.data());
342
343    // Batch inference
344    println!("Batch inference:");
345    let batch_input = Tensor::from_slice(
346        &[
347            1.0, 2.0, 3.0, 4.0, // Sample 1
348            5.0, 6.0, 7.0, 8.0, // Sample 2
349            9.0, 10.0, 11.0, 12.0, // Sample 3
350        ],
351        vec![3, 4],
352    )
353    .unwrap();
354    let batch_output = layer.forward_no_grad(&batch_input);
355    println!("  Input shape: {:?}", batch_input.shape().dims());
356    println!("  Output shape: {:?}", batch_output.shape().dims());
357
358    // Verify batch consistency - first sample should match single inference
359    let _first_batch_sample = batch_output.view(vec![3, 3]); // View of the batch output (unused); data is read directly below
360    let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361    let single_sample_data = single_output.data();
362
363    println!("Consistency check:");
364    println!("  Single output: {:?}", single_sample_data);
365    println!("  First batch sample: {:?}", first_sample_data);
366    println!(
367        "  Match: {}",
368        single_sample_data
369            .iter()
370            .zip(first_sample_data.iter())
371            .all(|(a, b)| (a - b).abs() < 1e-6)
372    );
373}
374
375/// Demonstrate serialization and loading
376fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
377    println!("\n--- Serialization ---");
378
379    // Create and train a simple layer
380    let mut original_layer = LinearLayer::new(2, 1, Some(47));
381
382    // Simple training data
383    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
384    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();
385
386    let mut optimizer = Adam::with_learning_rate(0.01);
387    let params = original_layer.parameters();
388    for param in &params {
389        optimizer.add_parameter(param);
390    }
391
392    // Train for a few epochs
393    for _ in 0..10 {
394        let y_pred = original_layer.forward(&x_data);
395        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
396        loss.backward(None);
397
398        let mut params = original_layer.parameters();
399        optimizer.step(&mut params);
400        optimizer.zero_grad(&mut params);
401    }
402
403    println!("Original layer trained");
404    println!("  Weight: {:?}", original_layer.weight.data());
405    println!("  Bias: {:?}", original_layer.bias.data());
406
407    // Save layer
408    original_layer.save_json("temp_linear_layer")?;
409
410    // Load layer
411    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;
412
413    println!("Loaded layer");
414    println!("  Weight: {:?}", loaded_layer.weight.data());
415    println!("  Bias: {:?}", loaded_layer.bias.data());
416
417    // Verify consistency
418    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
419    let original_output = original_layer.forward_no_grad(&test_input);
420    let loaded_output = loaded_layer.forward_no_grad(&test_input);
421
422    println!("Consistency check:");
423    println!("  Original output: {:?}", original_output.data());
424    println!("  Loaded output: {:?}", loaded_output.data());
425    println!(
426        "  Match: {}",
427        original_output
428            .data()
429            .iter()
430            .zip(loaded_output.data().iter())
431            .all(|(a, b)| (a - b).abs() < 1e-6)
432    );
433
434    println!("Serialization verification: PASSED");
435
436    Ok(())
437}
Source§

impl Tensor

Source

pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self

Creates a tensor with normally distributed random values (mean=0, std=1)

Similar to PyTorch’s torch.randn(), creates a tensor filled with random values drawn from a standard normal distribution (mean=0, standard deviation=1). Uses the Box-Muller transform for efficient normal-distribution generation.

This method provides high-quality random number generation, with optional reproducibility via a fixed seed. The generated values follow a standard normal distribution, making them suitable for weight initialization and other machine learning applications.

§Arguments
  • shape_dims - Vector of dimension sizes defining the tensor shape
  • seed - Optional seed for reproducible random generation
§Returns

A new tensor with normally distributed random values

§Performance
  • Box-Muller Transform: Efficient normal distribution generation
  • SIMD Optimization: Vectorized operations for large tensors
  • Memory Efficient: Single-pass generation with optimized allocation
  • Thread Safe: Uses thread-local random state
§Examples
§Basic Usage
use train_station::Tensor;

// Create a 2x3 tensor with random normal values
let tensor = Tensor::randn(vec![2, 3], None);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);

// Verify random values are generated
let first_value = tensor.get(&[0, 0]);
assert!(first_value != 0.0); // Should be random
§Reproducible Generation
use train_station::Tensor;

// Create with fixed seed for reproducible results
let tensor1 = Tensor::randn(vec![100], Some(42));
let tensor2 = Tensor::randn(vec![100], Some(42));

// tensor1 and tensor2 will have identical values
for i in 0..tensor1.size() {
    assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
}
§Statistical Properties
use train_station::Tensor;

// Generate large tensor for statistical analysis
let tensor = Tensor::randn(vec![1000], Some(42));
assert_eq!(tensor.size(), 1000);

// Check that values are reasonable (within 4 standard deviations)
let mut min_val = f32::INFINITY;
let mut max_val = f32::NEG_INFINITY;
let mut sum = 0.0;

for i in 0..tensor.size() {
    let val = tensor.get(&[i]);
    min_val = min_val.min(val);
    max_val = max_val.max(val);
    sum += val;
}

let mean = sum / tensor.size() as f32;

// Mean should be close to 0, values should be within reasonable bounds
assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
assert!(min_val > -4.0, "Values should not be too negative, min: {}", min_val);
assert!(max_val < 4.0, "Values should not be too positive, max: {}", max_val);
§Zero-Sized Tensors
use train_station::Tensor;

// Handle empty tensors gracefully
let tensor = Tensor::randn(vec![0], Some(42));
assert_eq!(tensor.size(), 0);
assert_eq!(tensor.shape().dims(), vec![0]);
§Implementation Details

This method uses the Box-Muller transform to generate normally distributed random variables from uniform random variables. The process involves:

  1. Random Number Generation: Uses Xorshift algorithm for uniform random numbers
  2. Box-Muller Transform: Converts uniform random variables to normal distribution
  3. SIMD Optimization: Vectorized operations for large tensors when available
  4. Numerical Stability: Robust handling of edge cases and potential NaN values

The Box-Muller transform ensures that the generated values follow a true normal distribution with mean=0 and standard deviation=1, making it suitable for machine learning applications requiring normally distributed random values.
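For intuition, a minimal standalone sketch of the transform is shown below. This is not the library’s internal SIMD implementation; the uniform inputs are assumed to come from any RNG that yields values in (0, 1].

/// Box-Muller sketch: maps two uniform samples in (0, 1] to two
/// independent standard-normal samples (mean=0, std=1).
fn box_muller(u1: f32, u2: f32) -> (f32, f32) {
    // Clamp away from zero so ln(u1) stays finite.
    let u1 = u1.max(f32::MIN_POSITIVE);
    let r = (-2.0 * u1.ln()).sqrt();
    let theta = 2.0 * std::f32::consts::PI * u2;
    (r * theta.cos(), r * theta.sin())
}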

Examples found in repository
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 57)
53    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55        let scale = (1.0 / input_size as f32).sqrt();
56
57        let weight = Tensor::randn(vec![input_size, output_size], seed)
58            .mul_scalar(scale)
59            .with_requires_grad();
60        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62        Self {
63            weight,
64            bias,
65            input_size,
66            output_size,
67        }
68    }
More examples
examples/neural_networks/basic_encoder.rs (line 81)
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}
examples/neural_networks/basic_decoder.rs (line 93)
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85    println!("=== Basic Decoder Example ===");
86
87    let batch = 2usize;
88    let src = 7usize;
89    let tgt = 5usize;
90    let embed = 32usize;
91    let heads = 4usize;
92
93    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96    let mut dec = DecoderBlock::new(embed, heads, Some(456));
97    let out = dec.forward(&tgt_in, &memory, None, None);
98    println!("Output shape: {:?}", out.shape().dims());
99
100    let mut opt = Adam::with_learning_rate(0.01);
101    let mut params = dec.parameters();
102    for p in &params {
103        opt.add_parameter(p);
104    }
105    let mut loss = out.mean();
106    loss.backward(None);
107    opt.step(&mut params);
108    opt.zero_grad(&mut params);
109    println!("Loss: {:.6}", loss.value());
110    println!("=== Done ===");
111    Ok(())
112}
examples/getting_started/tensor_basics.rs (line 80)
42fn demonstrate_tensor_creation() {
43    println!("--- Tensor Creation ---");
44
45    // Create tensors with different initializations
46    let zeros = Tensor::zeros(vec![2, 3]);
47    println!(
48        "Zeros tensor: shape {:?}, data: {:?}",
49        zeros.shape().dims(),
50        zeros.data()
51    );
52
53    let ones = Tensor::ones(vec![3, 2]);
54    println!(
55        "Ones tensor: shape {:?}, data: {:?}",
56        ones.shape().dims(),
57        ones.data()
58    );
59
60    // Create tensor from slice
61    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62    let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63    println!(
64        "From slice: shape {:?}, data: {:?}",
65        from_slice.shape().dims(),
66        from_slice.data()
67    );
68
69    // Create tensor with specific value
70    let mut filled = Tensor::new(vec![2, 2]);
71    {
72        let data = filled.data_mut();
73        for value in data.iter_mut() {
74            *value = 42.0;
75        }
76    }
77    println!("Filled with 42: {:?}", filled.data());
78
79    // Create tensor with random data
80    let random = Tensor::randn(vec![2, 2], Some(42));
81    println!(
82        "Random tensor: shape {:?}, data: {:?}",
83        random.shape().dims(),
84        random.data()
85    );
86}
examples/neural_networks/feedforward_network.rs (line 385)
363fn demonstrate_configurable_architectures() {
364    println!("\n--- Configurable Architectures ---");
365
366    let architectures = vec![
367        ("Shallow", vec![8]),
368        ("Medium", vec![16, 8]),
369        ("Deep", vec![32, 16, 8, 4]),
370        ("Wide", vec![64, 32]),
371        ("Bottleneck", vec![16, 4, 16]),
372    ];
373
374    for (name, hidden_sizes) in architectures {
375        let config = FeedForwardConfig {
376            input_size: 10,
377            hidden_sizes,
378            output_size: 3,
379            use_bias: true,
380        };
381
382        let network = FeedForwardNetwork::new(config.clone(), Some(44));
383
384        // Test forward pass
385        let test_input = Tensor::randn(vec![5, 10], Some(45)); // Batch of 5
386        let output = network.forward_no_grad(&test_input);
387
388        println!("{} network:", name);
389        println!("  Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
390        println!("  Parameters: {}", network.parameter_count());
391        println!("  Test output shape: {:?}", output.shape().dims());
392        println!(
393            "  Output range: [{:.3}, {:.3}]",
394            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
395            output
396                .data()
397                .iter()
398                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
399        );
400    }
401}
examples/neural_networks/basic_transformer.rs (line 241)
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232    println!("=== Basic Transformer Example ===");
233
234    let batch = 2usize;
235    let src_len = 8usize;
236    let tgt_len = 6usize;
237    let embed = 32usize;
238    let heads = 4usize;
239    let layers = 2usize;
240
241    let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242    let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244    let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245    let out = trf.forward(&src, &tgt);
246    println!("Output shape: {:?}", out.shape().dims());
247
248    // Quick optimization step
249    let mut opt = Adam::with_learning_rate(0.005);
250    let mut params = trf.parameters();
251    for p in &params {
252        opt.add_parameter(p);
253    }
254    let mut loss = out.mean();
255    loss.backward(None);
256    opt.step(&mut params);
257    opt.zero_grad(&mut params);
258    println!("Loss: {:.6}", loss.value());
259
260    // Demo: non auto-regressive inference (single pass)
261    let nar = trf.infer_non_autoregressive(&src, tgt_len);
262    println!("NAR output shape: {:?}", nar.shape().dims());
263
264    // Demo: auto-regressive inference (toy)
265    let ar = trf.infer_autoregressive(&src, 3);
266    println!("AR output shape: {:?}", ar.shape().dims());
267
268    // NAR training demo
269    let nar_tgt = tgt.clone();
270    trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272    // AR training demo (teacher-forced)
273    let ar_tgt = tgt.clone();
274    trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275    println!("=== Done ===");
276    Ok(())
277}
Source

pub fn fill_randn(&mut self, seed: Option<u64>)

Fills the tensor with normally distributed random values

Fills an existing tensor in place with random values from a standard normal distribution. Uses the Box-Muller transform for efficiency and SIMD optimization for large tensors.

This method is used internally by randn() and provides the core random number generation functionality with optimized performance characteristics.

§Arguments
  • seed - Optional seed for reproducible random generation
§Performance
  • Box-Muller Transform: Generates pairs of normal random variables
  • SIMD Optimization: Vectorized operations when possible
  • Memory Efficient: Single-pass generation
  • Unrolled Loops: 4x unrolling for better instruction throughput
§Implementation Details

The method performs the following steps:

  1. Zero-sized Check: Returns early for empty tensors
  2. RNG Initialization: Creates Xorshift RNG with seed or system time
  3. SIMD Detection: Checks for AVX2 availability for optimized path
  4. Generation: Uses SIMD or scalar path based on hardware support
  5. Completion: Fills all tensor elements with normal random values

The method automatically handles hardware capabilities and falls back to scalar operations when SIMD is not available, ensuring compatibility across different CPU architectures.
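§Examples

A minimal usage sketch, relying only on the documented seed-reproducibility behavior:

use train_station::Tensor;

// Allocate first, then fill in place with reproducible normal values.
let mut t = Tensor::zeros(vec![2, 3]);
t.fill_randn(Some(42));
let first = t.get(&[0, 0]);

// Refilling with the same seed reproduces the same values.
t.fill_randn(Some(42));
assert!((t.get(&[0, 0]) - first).abs() < 1e-6);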

Source§

impl Tensor

Source

pub fn chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_>

Standard slice-like chunks iterator. Use this instead of iter_chunks.

Iterates over contiguous or view-backed slices of the tensor with the specified chunk size. In no-grad fast mode, a single contiguous owner may be materialized to optimize subsequent views.

§Arguments
  • chunk_size - Number of elements per chunk (must be > 0)
§Examples
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;

let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let y = t.chunks(2).map(|c| c.mul_scalar(2.0)).collect_shape(vec![6]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
Examples found in repository
examples/iterators/performance_optimization.rs (line 178)
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163    println!("\n--- Memory Optimization ---");
164
165    // Create a large tensor for memory testing
166    let size = 10000;
167    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168    let tensor = Tensor::from_slice(&data, vec![size])?;
169
170    println!("Processing tensor of size: {}", size);
171
172    // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173    println!("\nPattern 1: Streaming Processing");
174    let chunk_size = 1000;
175    let start = Instant::now();
176    let flattened = tensor.view(vec![size as i32]);
177    let _streamed_result: Tensor = flattened
178        .chunks(chunk_size)
179        .map(|c| c.pow_scalar(2.0).sqrt())
180        .collect_shape(vec![size]);
181    let streamed_time = start.elapsed();
182
183    // Pattern 2: Full processing
184    let start = Instant::now();
185    let _full_result: Tensor = tensor
186        .iter_elements()
187        .map(|elem| elem.pow_scalar(2.0).sqrt())
188        .collect_shape(vec![size]);
189    let full_time = start.elapsed();
190
191    println!("  Streaming time: {:?}", streamed_time);
192    println!("  Full processing time: {:?}", full_time);
193    println!(
194        "  Memory efficiency ratio: {:.2}x",
195        full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196    );
197
198    // Pattern 3: Lazy evaluation with take
199    println!("\nPattern 2: Lazy Evaluation");
200    let start = Instant::now();
201    let lazy_result: Tensor = tensor
202        .iter_elements()
203        .take(1000) // Only process first 1000 elements
204        .map(|elem| elem.pow_scalar(2.0).sqrt())
205        .collect_shape(vec![1000]);
206    let lazy_time = start.elapsed();
207
208    println!("  Lazy processing (1000 elements): {:?}", lazy_time);
209    println!("  Lazy result size: {}", lazy_result.size());
210
211    // Pattern 4: Memory-efficient filtering
212    println!("\nPattern 3: Memory-Efficient Filtering");
213    let start = Instant::now();
214    let filtered_result: Tensor = tensor
215        .iter_elements()
216        .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217        .map(|elem| elem.mul_scalar(2.0))
218        .collect();
219    let filtered_time = start.elapsed();
220
221    println!("  Filtered processing: {:?}", filtered_time);
222    println!(
223        "  Filtered result size: {} (reduced from {})",
224        filtered_result.size(),
225        size
226    );
227
228    Ok(())
229}
Source

pub fn chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_>

Standard slice-like exact chunks iterator. Use this instead of iter_chunks_exact.

Yields only the exact chunks of size chunk_size, exposing any remainder via remainder(). See chunks() for a variant that yields the remainder as the last (smaller) chunk.

§Examples
use train_station::Tensor;

let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
let mut it = t.chunks_exact(2);
assert_eq!(it.next().unwrap().data(), &[1.0, 2.0]);
assert_eq!(it.next().unwrap().data(), &[3.0, 4.0]);
assert_eq!(it.remainder().data(), &[5.0]);
Source

pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_>

👎Deprecated: Use Tensor::chunks(…) instead. This alias will be removed before 1.0.
Source

pub fn iter_chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_>

👎Deprecated: Use Tensor::chunks_exact(…) instead. This alias will be removed before 1.0.
Source§

impl Tensor

Source

pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(iter: I, dims: Vec<usize>) -> Tensor

Collects the tensors yielded by the iterator into a single tensor with the target shape, copying data in iterator order. The copy is SIMD-optimized when available; asserts that the total element count matches dims.
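§Examples

A minimal sketch, assuming the chunk tensors arrive in iterator order and their combined size equals the target shape:

use train_station::Tensor;

let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
// Transform each chunk, then reassemble into the target shape.
let y = Tensor::collect_into_shape(t.chunks(3).map(|c| c.mul_scalar(2.0)), vec![6]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);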

Source§

impl Tensor

Source

pub fn collect_values_shape<I: IntoIterator<Item = f32>>(iter: I, dims: Vec<usize>) -> Tensor

Inherent helper to collect any iterator of f32 into a shaped tensor. This mirrors the ValuesCollectExt::collect_shape functionality but does not require importing the extension trait.
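§Examples

A minimal sketch; since this is an inherent helper, no extension-trait import is required:

use train_station::Tensor;

// Collect plain f32 values into a shaped tensor.
let t = Tensor::collect_values_shape((0..6).map(|i| i as f32), vec![2, 3]);
assert_eq!(t.shape().dims(), vec![2, 3]);
assert_eq!(t.data(), &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);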

Source§

impl Tensor

Source

pub fn iter_elements(&self) -> TensorElementIterator<'_>

Create an iterator over scalar elements (flattened view)

Each yielded item is a [1]-shaped Tensor view that shares storage with the source. This iterator is GradTrack-aware; element operations propagate gradients to the original tensor when gradients are enabled.

§Returns

An iterator producing scalar view tensors in row-major order.

§Examples

Collect transformed elements back to the original shape using collect_shape:

use train_station::tensor::TensorCollectExt;
use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let y = x
    .iter_elements()
    .map(|e| e.mul_scalar(2.0))
    .collect_shape(vec![2, 2]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0]);
Examples found in repository
examples/iterators/performance_optimization.rs (line 108)
87fn demonstrate_performance_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Performance Benchmarking ---");
89
90    // Create test data of different sizes
91    let sizes = vec![100, 1000, 10000];
92
93    for size in sizes {
94        println!("\nBenchmarking with tensor size: {}", size);
95
96        // Generate test data
97        let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
98        let tensor = Tensor::from_slice(&data, vec![size])?;
99
100        // Benchmark 1: Direct tensor operations
101        let start = Instant::now();
102        let direct_result = tensor.mul_scalar(2.0).add_scalar(1.0);
103        let direct_time = start.elapsed();
104
105        // Benchmark 2: Iterator-based operations (grad-enabled views, flatten + collect_shape)
106        let start = Instant::now();
107        let iterator_result: Tensor = tensor
108            .iter_elements()
109            .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
110            .collect_shape(vec![size]);
111        let iterator_time = start.elapsed();
112
113        // Benchmark 3: Chained iterator operations
114        let start = Instant::now();
115        let _chained_result: Tensor = tensor
116            .iter_elements()
117            .map(|elem| elem.mul_scalar(2.0))
118            .filter(|elem| elem.value() > size as f32)
119            .map(|elem| elem.add_scalar(1.0))
120            .collect();
121        let chained_time = start.elapsed();
122
123        // Benchmark 4: NoGrad raw data streaming + collect_shape
124        let start = Instant::now();
125        let _streamed: Tensor = with_no_grad(|| {
126            tensor
127                .data()
128                .iter()
129                .copied()
130                .map(|x| 2.0 * x + 1.0)
131                .collect_shape(vec![size])
132        });
133        let streaming_time = start.elapsed();
134
135        // Report results
136        println!("  Direct operations: {:?}", direct_time);
137        println!("  Iterator operations: {:?}", iterator_time);
138        println!("  Chained operations: {:?}", chained_time);
139        println!("  NoGrad streaming (data.iter): {:?}", streaming_time);
140
141        // Verify correctness
142        assert_eq!(direct_result.data(), iterator_result.data());
143        println!(
144            "  Results match: {}",
145            direct_result.data() == iterator_result.data()
146        );
147
148        // Performance ratios
149        let ratio = iterator_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
150        let ratio_stream = streaming_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
151        println!("  Iterator/Direct ratio: {:.2}x", ratio);
152        println!("  Streaming/Direct ratio: {:.2}x", ratio_stream);
153    }
154
155    Ok(())
156}
Source

pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_>

Create an iterator over a clamped range of elements

Produces scalar view tensors from start..end (clamped to [0, size]).

§Arguments
  • start - Start index (inclusive)
  • end - End index (exclusive)
§Examples
use train_station::Tensor;

let x = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let vals: Vec<f32> = x.iter_range(2, 5).map(|e| e.value()).collect();
assert_eq!(vals, vec![2.0, 3.0, 4.0]);
Examples found in repository
examples/iterators/performance_optimization.rs (line 262)
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236    println!("\n--- Large-Scale Processing ---");
237
238    // Simulate large dataset processing
239    let sizes = vec![10000, 50000, 100000];
240
241    for size in sizes {
242        println!("\nProcessing dataset of size: {}", size);
243
244        // Generate large dataset
245        let data: Vec<f32> = (0..size)
246            .map(|i| {
247                let x = i as f32 / size as f32;
248                x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249            })
250            .collect();
251
252        let tensor = Tensor::from_slice(&data, vec![size])?;
253
254        // Technique 1: Batch processing
255        let batch_size = 1000;
256        let start = Instant::now();
257
258        let mut batch_results = Vec::new();
259        for batch_start in (0..size).step_by(batch_size) {
260            let batch_end = (batch_start + batch_size).min(size);
261            let batch: Tensor = tensor
262                .iter_range(batch_start, batch_end)
263                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264                .collect();
265            batch_results.push(batch);
266        }
267        let batch_time = start.elapsed();
268
269        // Technique 2: Parallel-like processing with stride
270        let start = Instant::now();
271        let stride = 4;
272        let strided_result: Tensor = tensor
273            .iter()
274            .enumerate()
275            .filter(|(i, _)| i % stride == 0)
276            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277            .collect();
278        let strided_time = start.elapsed();
279
280        // Technique 3: Hierarchical processing
281        let start = Instant::now();
282        let coarse: Tensor = tensor
283            .iter()
284            .enumerate()
285            .filter(|(i, _)| i % 10 == 0) // Every 10th element
286            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287            .collect();
288        let fine: Tensor = tensor
289            .iter()
290            .enumerate()
291            .filter(|(i, _)| i % 10 != 0) // Rest of elements
292            .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293            .collect();
294        let hierarchical_time = start.elapsed();
295
296        // Report performance
297        println!("  Batch processing: {:?}", batch_time);
298        println!("  Strided processing: {:?}", strided_time);
299        println!("  Hierarchical processing: {:?}", hierarchical_time);
300
301        // Memory usage analysis
302        let total_batches = size.div_ceil(batch_size);
303        println!("  Batch count: {}", total_batches);
304        println!("  Strided result size: {}", strided_result.size());
305        println!(
306            "  Hierarchical: coarse={}, fine={}",
307            coarse.size(),
308            fine.size()
309        );
310    }
311
312    Ok(())
313}
More examples
examples/iterators/advanced_patterns.rs (line 363)
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}
Source§

impl Tensor

Source

pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_>

Iterate over sub-tensors along a specific dimension.

Produces view tensors by slicing along the given dimension; each item has that dimension removed, so its rank is one less than the source tensor's. Views share storage and preserve gradient tracking semantics.

§Arguments
  • dim - Dimension to iterate over
§Examples
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;

let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
let out = t.iter_dim(0).map(|row| row.add_scalar(1.0)).collect_shape(vec![2, 3]);
assert_eq!(out.data(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
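Columns can be iterated the same way; a minimal sketch (not from the crate's shipped examples), assuming iter_dim(1) on a [2, 3] tensor yields three column views of shape [2] per the rank-reduction rule above:

use train_station::Tensor;

let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
// Row-major [2, 3]: columns are [1, 4], [2, 5], [3, 6]
let col_sums: Vec<f32> = t.iter_dim(1).map(|col| col.sum().value()).collect();
assert_eq!(col_sums, vec![5.0, 7.0, 9.0]);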
Source

pub fn iter(&self) -> TensorDimIterator<'_>

Default iterator over the outermost dimension, yielding sub-tensors (N-D) or scalar views (1-D).

This is equivalent to iter_dim(0) with an important optimization:

  • For 1-D tensors, it yields scalar element views of shape [1] (the same items as iter_flat()), preserving GradTrack tracking while keeping collection fast.
  • For N-D tensors (rank > 1), it yields sub-tensors with the outermost dimension removed (rank − 1), suitable for row/batch-wise processing.

Views share storage with the source tensor and preserve gradient tracking semantics. Use collect_shape([..]) to reconstruct shape efficiently after per-item transforms.

§Returns

An iterator producing view tensors for each slice along the outermost dimension (or scalar views for 1-D).

§Examples

1-D: element views (shape [1]) and shape-preserving collection

use train_station::tensor::TensorCollectExt;
use train_station::Tensor;

let v = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let out = v.iter().map(|e| e.add_scalar(1.0)).collect_shape(vec![6]);
assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);

2-D: row-wise transforms and shape-preserving collection

use train_station::tensor::TensorCollectExt;
use train_station::Tensor;

let m = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
let y = m.iter().map(|row| row.mul_scalar(2.0)).collect_shape(vec![2, 3]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
Examples found in repository
examples/iterators/element_iteration.rs (line 102)
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94    println!("\n--- Basic Element Iteration ---");
95
96    // Create a simple tensor for demonstration
97    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98    println!("Original tensor: {:?}", tensor.data());
99
100    // Basic iteration with for loop
101    println!("\nBasic iteration with for loop:");
102    for (i, element) in tensor.iter().enumerate() {
103        println!(
104            "  Element {}: value = {:.1}, shape = {:?}",
105            i,
106            element.value(),
107            element.shape().dims()
108        );
109    }
110
111    // Element-wise transformation
112    println!("\nElement-wise transformation (2x + 1):");
113    let transformed: Tensor = tensor
114        .iter()
115        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116        .collect();
117    println!("  Result: {:?}", transformed.data());
118
119    // Filtering elements
120    println!("\nFiltering elements (values > 3.0):");
121    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122    println!("  Filtered: {:?}", filtered.data());
123
124    Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132    println!("\n--- Standard Iterator Methods ---");
133
134    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136    // Using map for transformations
137    println!("\nMap transformation (square each element):");
138    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139    println!("  Squared: {:?}", squared.data());
140
141    // Using enumerate for indexed operations
142    println!("\nEnumerate with indexed operations:");
143    let indexed: Tensor = tensor
144        .iter()
145        .enumerate()
146        .map(|(i, elem)| elem.add_scalar(i as f32))
147        .collect();
148    println!("  Indexed: {:?}", indexed.data());
149
150    // Using fold for reduction
151    println!("\nFold for sum calculation:");
152    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153    println!("  Sum: {:.1}", sum);
154
155    // Using find for element search
156    println!("\nFind specific element:");
157    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158        println!("  Found element with value 3.0: {:.1}", found.value());
159    }
160
161    // Using any/all for condition checking
162    println!("\nCondition checking:");
163    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165    println!("  All positive: {}", all_positive);
166    println!("  Any > 4.0: {}", any_large);
167
168    Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}
258
259/// Demonstrate per-row transforms with shape-preserving collection
260///
261/// Shows how to use `iter()` over the outer dimension on a 2D tensor and
262/// `collect_shape([..])` to maintain the original shape after mapping.
263fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
264    println!("\n--- Row-wise iteration with collect_shape ---");
265    let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
266    println!("Input shape: {:?}", mat.shape().dims());
267
268    // Map each row: 1.1*x + 0.5, then collect back to [3,4]
269    let out: Tensor = mat
270        .iter()
271        .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
272        .collect_shape(vec![3, 4]);
273    println!("  Output shape: {:?}", out.shape().dims());
274
275    Ok(())
276}
277
278/// Demonstrate NoGrad fast paths and raw data streaming
279///
280/// Highlights how to get maximum performance in inference by:
281/// - Disabling gradient tracking with `with_no_grad`
282/// - Iterating raw values via `tensor.data().iter().copied()`
283/// - Using `collect_shape` to stream directly into destination tensors
284fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
285    println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");
286
287    let input = Tensor::from_slice(
288        &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
289        vec![4, 6],
290    )?;
291    println!("Input shape: {:?}", input.shape().dims());
292
293    // NoGrad: stream values directly and reshape
294    let out = with_no_grad(|| {
295        input
296            .data()
297            .iter()
298            .copied()
299            .map(|x| 1.2 * x - 0.3)
300            .collect_shape(vec![4, 6])
301    });
302    println!(
303        "  NoGrad streamed map (1.2x-0.3) -> shape {:?}",
304        out.shape().dims()
305    );
306
307    // Compare to view-based element iteration in NoGrad
308    let out_view: Tensor = with_no_grad(|| {
309        input
310            .iter()
311            .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
312            .collect_shape(vec![4, 6])
313    });
314    println!(
315        "  NoGrad view-based map shape {:?}",
316        out_view.shape().dims()
317    );
318
319    // Quick parity check
320    assert_eq!(out.data(), out_view.data());
321    println!("  Parity check passed.");
322
323    // Show simple flatten + collect back to a different shape
324    let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
325    println!(
326        "  Reshaped via streaming collect_shape: {:?}",
327        reshaped.shape().dims()
328    );
329
330    Ok(())
331}
More examples
examples/iterators/advanced_patterns.rs (line 113)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251    println!("\n--- Batch Operations ---");
252
253    // Create a larger dataset for batch processing
254    let size = 100;
255    let data: Vec<f32> = (0..size)
256        .map(|i| {
257            let x = i as f32 / size as f32;
258            x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259        })
260        .collect();
261
262    let tensor = Tensor::from_slice(&data, vec![size])?;
263    println!("Dataset size: {}", tensor.size());
264
265    // Batch processing with windowing (iterator views)
266    println!("\nBatch processing with sliding windows:");
267    let batch_size = 10;
268    let batches: Vec<Tensor> = tensor
269        .iter()
270        .collect::<Vec<_>>()
271        .chunks(batch_size)
272        .map(|chunk| {
273            // Process each batch independently
274            chunk
275                .iter()
276                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277                .collect()
278        })
279        .collect();
280
281    println!(
282        "  Processed {} batches of size {}",
283        batches.len(),
284        batch_size
285    );
286    for (i, batch) in batches.iter().enumerate() {
287        println!(
288            "    Batch {}: mean={:.3}, std={:.3}",
289            i,
290            batch.mean().value(),
291            batch.std().value()
292        );
293    }
294
295    // Parallel-like processing with stride
296    println!("\nStrided processing (every nth element):");
297    let stride = 5;
298    let strided: Tensor = tensor
299        .iter()
300        .enumerate()
301        .filter(|(i, _)| i % stride == 0)
302        .map(|(_, elem)| elem)
303        .collect();
304    println!("  Strided (every {}th): {:?}", stride, strided.data());
305
306    // Hierarchical processing
307    println!("\nHierarchical processing (coarse to fine):");
308    let coarse: Tensor = tensor
309        .iter()
310        .enumerate()
311        .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312        .map(|(_, elem)| elem)
313        .collect();
314
315    let fine: Tensor = tensor
316        .iter()
317        .enumerate()
318        .filter(|(i, _)| i % 4 != 0) // Take the rest
319        .map(|(_, elem)| elem)
320        .collect();
321
322    println!("  Coarse (every 4th): {:?}", coarse.data());
323    println!("  Fine (rest): {:?}", fine.data());
324
325    // Combine coarse and fine with different processing
326    let combined: Tensor = coarse
327        .iter()
328        .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329        .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330        .collect();
331    println!("  Combined: {:?}", combined.data());
332
333    Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}
examples/iterators/performance_optimization.rs (line 273)
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236    println!("\n--- Large-Scale Processing ---");
237
238    // Simulate large dataset processing
239    let sizes = vec![10000, 50000, 100000];
240
241    for size in sizes {
242        println!("\nProcessing dataset of size: {}", size);
243
244        // Generate large dataset
245        let data: Vec<f32> = (0..size)
246            .map(|i| {
247                let x = i as f32 / size as f32;
248                x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249            })
250            .collect();
251
252        let tensor = Tensor::from_slice(&data, vec![size])?;
253
254        // Technique 1: Batch processing
255        let batch_size = 1000;
256        let start = Instant::now();
257
258        let mut batch_results = Vec::new();
259        for batch_start in (0..size).step_by(batch_size) {
260            let batch_end = (batch_start + batch_size).min(size);
261            let batch: Tensor = tensor
262                .iter_range(batch_start, batch_end)
263                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264                .collect();
265            batch_results.push(batch);
266        }
267        let batch_time = start.elapsed();
268
269        // Technique 2: Parallel-like processing with stride
270        let start = Instant::now();
271        let stride = 4;
272        let strided_result: Tensor = tensor
273            .iter()
274            .enumerate()
275            .filter(|(i, _)| i % stride == 0)
276            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277            .collect();
278        let strided_time = start.elapsed();
279
280        // Technique 3: Hierarchical processing
281        let start = Instant::now();
282        let coarse: Tensor = tensor
283            .iter()
284            .enumerate()
285            .filter(|(i, _)| i % 10 == 0) // Every 10th element
286            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287            .collect();
288        let fine: Tensor = tensor
289            .iter()
290            .enumerate()
291            .filter(|(i, _)| i % 10 != 0) // Rest of elements
292            .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293            .collect();
294        let hierarchical_time = start.elapsed();
295
296        // Report performance
297        println!("  Batch processing: {:?}", batch_time);
298        println!("  Strided processing: {:?}", strided_time);
299        println!("  Hierarchical processing: {:?}", hierarchical_time);
300
301        // Memory usage analysis
302        let total_batches = size.div_ceil(batch_size);
303        println!("  Batch count: {}", total_batches);
304        println!("  Strided result size: {}", strided_result.size());
305        println!(
306            "  Hierarchical: coarse={}, fine={}",
307            coarse.size(),
308            fine.size()
309        );
310    }
311
312    Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320    println!("\n--- Optimization Techniques ---");
321
322    let size = 50000;
323    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324    let tensor = Tensor::from_slice(&data, vec![size])?;
325
326    println!("Optimizing processing for size: {}", size);
327
328    // Technique 1: Operation fusion
329    println!("\nTechnique 1: Operation Fusion");
330    let start = Instant::now();
331    let fused_result: Tensor = tensor
332        .iter()
333        .map(|elem| {
334            // Fuse multiple operations into single chain
335            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336        })
337        .collect();
338    let fused_time = start.elapsed();
339
340    // Technique 2: Conditional optimization
341    println!("\nTechnique 2: Conditional Optimization");
342    let start = Instant::now();
343    let conditional_result: Tensor = tensor
344        .iter()
345        .map(|elem| {
346            let val = elem.value();
347            if val < size as f32 / 2.0 {
348                elem.mul_scalar(2.0) // Simple operation for small values
349            } else {
350                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351            }
352        })
353        .collect();
354    let conditional_time = start.elapsed();
355
356    // Technique 3: Cache-friendly processing
357    println!("\nTechnique 3: Cache-Friendly Processing");
358    let start = Instant::now();
359    let cache_friendly_result: Tensor = tensor
360        .iter()
361        .take(1000) // Process in cache-friendly chunks
362        .map(|elem| elem.mul_scalar(2.0))
363        .collect();
364    let cache_friendly_time = start.elapsed();
365
366    // Technique 4: Memory pooling simulation
367    println!("\nTechnique 4: Memory Pooling Simulation");
368    let start = Instant::now();
369    let pooled_result: Tensor = tensor
370        .iter()
371        .enumerate()
372        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373        .map(|(_, elem)| elem.pow_scalar(2.0))
374        .collect();
375    let pooled_time = start.elapsed();
376
377    // Report optimization results
378    println!("  Fused operations: {:?}", fused_time);
379    println!("  Conditional optimization: {:?}", conditional_time);
380    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
381    println!("  Memory pooling simulation: {:?}", pooled_time);
382
383    // Performance analysis
384    let fastest = fused_time
385        .min(conditional_time)
386        .min(cache_friendly_time)
387        .min(pooled_time);
388    println!("  Fastest technique: {:?}", fastest);
389
390    // Memory efficiency analysis
391    println!("  Fused result size: {}", fused_result.size());
392    println!("  Conditional result size: {}", conditional_result.size());
393    println!(
394        "  Cache-friendly result size: {}",
395        cache_friendly_result.size()
396    );
397    println!("  Pooled result size: {}", pooled_result.size());
398
399    // Technique 5: Gradient optimization
400    println!("\nTechnique 5: Gradient Optimization");
401    let grad_tensor = tensor.with_requires_grad();
402    let start = Instant::now();
403
404    let grad_result: Tensor = grad_tensor
405        .iter()
406        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407        .collect();
408
409    let mut loss = grad_result.sum();
410    loss.backward(None);
411    let grad_time = start.elapsed();
412
413    println!("  Gradient computation: {:?}", grad_time);
414    println!(
415        "  Gradient tracking enabled: {}",
416        grad_result.requires_grad()
417    );
418
419    Ok(())
420}
Source

pub fn outer_iter(&self) -> TensorDimIterator<'_>

Explicit alias for outermost-dimension iteration of sub-tensors. Equivalent to iter_dim(0).
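
§Examples

A minimal sketch (hypothetical usage, assuming outer_iter matches the iter_dim(0) examples above):

use train_station::Tensor;

let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
// Each item is a row view of shape [2]
let row_sums: Vec<f32> = m.outer_iter().map(|row| row.sum().value()).collect();
assert_eq!(row_sums, vec![3.0, 7.0]);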

Source§

impl Tensor

Source

pub fn windows(&self, window_size: usize) -> TensorWindowsIterator<'_>

Overlapping windows iterator with step = 1. Use this instead of the deprecated iter_windows.

Produces overlapping linear windows as view tensors. In no-grad fast mode, a contiguous owner may be materialized once for faster subsequent views.

§Arguments
  • window_size - Length of each window (> 0)
§Examples
use train_station::Tensor;

let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let v: Vec<f32> = t.windows(3).map(|w| w.sum().value()).collect();
assert_eq!(v, vec![6.0, 9.0]);
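
Per-window reductions pair naturally with this iterator; a hedged sketch of a trailing 3-point moving average, assuming mean() on each window view as used elsewhere in this documentation:

use train_station::Tensor;

let series = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
// Overlapping windows: [1, 2, 3], [2, 3, 4], [3, 4, 5]
let avg: Vec<f32> = series.windows(3).map(|w| w.mean().value()).collect();
assert_eq!(avg, vec![2.0, 3.0, 4.0]);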
Source

pub fn windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_>

Overlapping windows iterator with a custom step. Use this instead of the deprecated iter_windows_step.

Produces windows starting at positions 0, step, 2*step, ... up to the last valid start. Reverse iteration yields the same sequence in reverse.

§Arguments
  • window_size - Length of each window (> 0)
  • step - Step between consecutive window starts (> 0)
§Examples
use train_station::Tensor;

let t = Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
let wins: Vec<Tensor> = t.windows_step(3, 2).collect();
assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
assert_eq!(wins[1].data(), &[3.0, 4.0, 5.0]);
assert_eq!(wins[2].data(), &[5.0, 6.0, 7.0]);
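
Setting step equal to window_size yields non-overlapping chunks; a minimal sketch under that reading of the start-position rule above:

use train_station::Tensor;

let t = Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
// Window starts at 0 and 4 only: disjoint chunks
let chunks: Vec<Tensor> = t.windows_step(4, 4).collect();
assert_eq!(chunks.len(), 2);
assert_eq!(chunks[0].data(), &[1.0, 2.0, 3.0, 4.0]);
assert_eq!(chunks[1].data(), &[5.0, 6.0, 7.0, 8.0]);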
Source

pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_>

👎Deprecated: Use Tensor::windows(…) instead. This alias will be removed before 1.0.
Source

pub fn iter_windows_step(&self, window_size: usize, step: usize) -> TensorWindowsIterator<'_>

👎Deprecated: Use Tensor::windows_step(…) instead. This alias will be removed before 1.0.
Source§

impl Tensor

Source

pub fn add_tensor(&self, other: &Tensor) -> Tensor

Element-wise addition with another tensor, with broadcasting support.

Performs element-wise addition with automatic broadcasting: output[i] = self[i] + other[i]

Broadcasting enables addition between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:

  • Dimensions are aligned from the rightmost dimension
  • Dimensions are compatible if they are equal, or one of them is 1
  • Missing dimensions are treated as 1
§Arguments
  • other - Tensor to add. Shapes must be broadcast-compatible.
§Returns

A new tensor containing the element-wise sum with broadcast result shape

§Examples
§Same Shape Addition
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 5.0);
assert_eq!(c.get(&[1]), 7.0);
assert_eq!(c.get(&[2]), 9.0);
§Broadcasting Addition
use train_station::Tensor;

// Broadcasting: [2, 1] + [1, 3] -> [2, 3]
let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 11.0);
assert_eq!(c.get(&[0, 1]), 21.0);
assert_eq!(c.get(&[1, 0]), 12.0);
assert_eq!(c.get(&[1, 1]), 22.0);
§Scalar Broadcasting
use train_station::Tensor;

// Scalar broadcasting: [2, 3] + scalar -> [2, 3]
let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 6.0);
assert_eq!(c.get(&[1, 2]), 6.0);
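§Row Bias Broadcasting

A common use of the rules above is adding a per-feature bias to a batch; a minimal sketch, assuming [2, 3] + [3] aligns from the right with the missing batch dimension treated as 1:

use train_station::Tensor;

let x = Tensor::ones(vec![2, 3]);
let bias = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let y = x.add_tensor(&bias);
assert_eq!(y.shape().dims(), vec![2, 3]);
assert_eq!(y.get(&[0, 0]), 2.0);
assert_eq!(y.get(&[1, 2]), 4.0);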
§Panics

Panics if tensor shapes are not broadcast-compatible

Examples found in repository
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 75)
71    pub fn forward(&self, input: &Tensor) -> Tensor {
72        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
73        let output = input.matmul(&self.weight);
74        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
75        output.add_tensor(&self.bias)
76    }
More examples
examples/supervised_training/supervised_bce.rs (line 63)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/RL_training/ppo_discrete.rs (line 299)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low); // max(ratio, low) = relu(ratio - low) + low
300    ge_low.sub_tensor(&ge_low.sub_tensor(&high).relu()) // min(ge_low, high) = ge_low - relu(ge_low - high)
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/neural_networks/basic_encoder.rs (line 55)
53    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54        let attn = self.mha.forward(input, input, input, attn_mask);
55        let res1 = attn.add_tensor(input);
56
57        // Feed-forward network with ReLU and residual
58        let (b, t, e) = Self::triple(input);
59        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60        let hidden = self.ffn_in.forward(&x2d).relu();
61        let out2d = self.ffn_out.forward(&hidden);
62        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63        out.add_tensor(&res1)
64    }
examples/RL_training/ppo_continuous.rs (line 244)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2*pi) normalizer
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) via ReLU: max(x, a) = relu(x - a) + a, min(x, b) = x - relu(x - b)
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    ratio_ge_low.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
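The minibatch update above is the standard PPO clipped surrogate, L_clip = E[min(r_t · Â_t, clip(r_t, 1 − ε, 1 + ε) · Â_t)] with ε = clip_eps; since there is no element-wise min/max/clamp primitive here, the clipping is assembled from relu, as noted in the inline comments.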
examples/neural_networks/basic_decoder.rs (line 64)
56    pub fn forward(
57        &self,
58        tgt: &Tensor,
59        memory: &Tensor,
60        causal_mask: Option<&Tensor>,
61        cross_mask: Option<&Tensor>,
62    ) -> Tensor {
63        let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64        let res1 = self_attn.add_tensor(tgt);
65
66        let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67        let res2 = cross.add_tensor(&res1);
68
69        let (b, t, e) = Self::triple(tgt);
70        let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71        let hidden = self.ffn_in.forward(&x2d).relu();
72        let out2d = self.ffn_out.forward(&hidden);
73        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74        out.add_tensor(&res2)
75    }
Source

pub fn add_scalar(&self, scalar: f32) -> Tensor

Broadcast addition with a scalar value.

Adds the scalar to every element: output[i] = self[i] + scalar
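
§Examples
§Basic Scalar Addition

A minimal usage sketch in the style of the other scalar operations; the expected values follow directly from the element-wise semantics above:

use train_station::Tensor;

let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
let b = a.add_scalar(5.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 15.0);
assert_eq!(b.get(&[1]), 25.0);
assert_eq!(b.get(&[2]), 35.0);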

Examples found in repository
examples/RL_training/dqn.rs (line 323)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
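This pseudo-Huber form, sqrt(d² + 1) − 1, is a smooth surrogate for the Huber loss: it behaves like d²/2 for small residuals and grows roughly like |d| for large ones, which keeps gradients bounded.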
More examples
examples/supervised_training/supervised_bce.rs (line 64)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
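This follows the numerically stable identity BCE(z, y) = relu(z) − z·y + ln(1 + e^(−|z|)), which avoids evaluating e^z for large positive z.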
examples/RL_training/ppo_continuous.rs (line 243)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI.ln()) // ln(2) + ln(π) = ln(2π)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
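For reference, the chain above evaluates the diagonal-Gaussian log-density, log p(a) = −(1/2) Σ_j [ (a_j − μ_j)² / σ_j² + 2·log σ_j + ln(2π) ], with sum_dims(&[1], true) producing the [B, 1] per-sample total.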
examples/getting_started/tensor_basics.rs (line 112)
89fn demonstrate_basic_operations() {
90    println!("\n--- Basic Operations ---");
91
92    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95    // Addition
96    let sum = a.add_tensor(&b);
97    println!("A + B: {:?}", sum.data());
98
99    // Subtraction
100    let diff = a.sub_tensor(&b);
101    println!("A - B: {:?}", diff.data());
102
103    // Multiplication
104    let product = a.mul_tensor(&b);
105    println!("A * B: {:?}", product.data());
106
107    // Division
108    let quotient = a.div_tensor(&b);
109    println!("A / B: {:?}", quotient.data());
110
111    // Scalar operations
112    let scalar_add = a.add_scalar(5.0);
113    println!("A + 5.0: {:?}", scalar_add.data());
114
115    let scalar_mul = a.mul_scalar(2.0);
116    println!("A * 2.0: {:?}", scalar_mul.data());
117}
examples/iterators/element_iteration.rs (line 115)
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94    println!("\n--- Basic Element Iteration ---");
95
96    // Create a simple tensor for demonstration
97    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98    println!("Original tensor: {:?}", tensor.data());
99
100    // Basic iteration with for loop
101    println!("\nBasic iteration with for loop:");
102    for (i, element) in tensor.iter().enumerate() {
103        println!(
104            "  Element {}: value = {:.1}, shape = {:?}",
105            i,
106            element.value(),
107            element.shape().dims()
108        );
109    }
110
111    // Element-wise transformation
112    println!("\nElement-wise transformation (2x + 1):");
113    let transformed: Tensor = tensor
114        .iter()
115        .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116        .collect();
117    println!("  Result: {:?}", transformed.data());
118
119    // Filtering elements
120    println!("\nFiltering elements (values > 3.0):");
121    let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122    println!("  Filtered: {:?}", filtered.data());
123
124    Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132    println!("\n--- Standard Iterator Methods ---");
133
134    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136    // Using map for transformations
137    println!("\nMap transformation (square each element):");
138    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139    println!("  Squared: {:?}", squared.data());
140
141    // Using enumerate for indexed operations
142    println!("\nEnumerate with indexed operations:");
143    let indexed: Tensor = tensor
144        .iter()
145        .enumerate()
146        .map(|(i, elem)| elem.add_scalar(i as f32))
147        .collect();
148    println!("  Indexed: {:?}", indexed.data());
149
150    // Using fold for reduction
151    println!("\nFold for sum calculation:");
152    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153    println!("  Sum: {:.1}", sum);
154
155    // Using find for element search
156    println!("\nFind specific element:");
157    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158        println!("  Found element with value 3.0: {:.1}", found.value());
159    }
160
161    // Using any/all for condition checking
162    println!("\nCondition checking:");
163    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165    println!("  All positive: {}", all_positive);
166    println!("  Any > 4.0: {}", any_large);
167
168    Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}
258
259/// Demonstrate per-row transforms with shape-preserving collection
260///
261/// Shows how to use `iter()` over the outer dimension on a 2D tensor and
262/// `collect_shape([..])` to maintain the original shape after mapping.
263fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
264    println!("\n--- Row-wise iteration with collect_shape ---");
265    let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
266    println!("Input shape: {:?}", mat.shape().dims());
267
268    // Map each row: 1.1*x + 0.5, then collect back to [3,4]
269    let out: Tensor = mat
270        .iter()
271        .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
272        .collect_shape(vec![3, 4]);
273    println!("  Output shape: {:?}", out.shape().dims());
274
275    Ok(())
276}
277
278/// Demonstrate NoGrad fast paths and raw data streaming
279///
280/// Highlights how to get maximum performance in inference by:
281/// - Disabling gradient tracking with `with_no_grad`
282/// - Iterating raw values via `tensor.data().iter().copied()`
283/// - Using `collect_shape` to stream directly into destination tensors
284fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
285    println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");
286
287    let input = Tensor::from_slice(
288        &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
289        vec![4, 6],
290    )?;
291    println!("Input shape: {:?}", input.shape().dims());
292
293    // NoGrad: stream values directly and reshape
294    let out = with_no_grad(|| {
295        input
296            .data()
297            .iter()
298            .copied()
299            .map(|x| 1.2 * x - 0.3)
300            .collect_shape(vec![4, 6])
301    });
302    println!(
303        "  NoGrad streamed map (1.2x-0.3) -> shape {:?}",
304        out.shape().dims()
305    );
306
307    // Compare to view-based element iteration in NoGrad
308    let out_view: Tensor = with_no_grad(|| {
309        input
310            .iter()
311            .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
312            .collect_shape(vec![4, 6])
313    });
314    println!(
315        "  NoGrad view-based map shape {:?}",
316        out_view.shape().dims()
317    );
318
319    // Quick parity check
320    assert_eq!(out.data(), out_view.data());
321    println!("  Parity check passed.");
322
323    // Show simple flatten + collect back to a different shape
324    let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
325    println!(
326        "  Reshaped via streaming collect_shape: {:?}",
327        reshaped.shape().dims()
328    );
329
330    Ok(())
331}
examples/getting_started/tensor_operators.rs (line 233)
203fn demonstrate_method_equivalence() {
204    println!("\n--- Operator vs Method Call Equivalence ---");
205
206    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209    // Addition: operator vs method
210    let operator_result = &a + &b;
211    let method_result = a.add_tensor(&b);
212
213    println!("A + B (operator): {:?}", operator_result.data());
214    println!("A.add_tensor(B): {:?}", method_result.data());
215    println!(
216        "Results are equal: {}",
217        operator_result.data() == method_result.data()
218    );
219
220    // Multiplication: operator vs method
221    let operator_result = &a * &b;
222    let method_result = a.mul_tensor(&b);
223
224    println!("A * B (operator): {:?}", operator_result.data());
225    println!("A.mul_tensor(B): {:?}", method_result.data());
226    println!(
227        "Results are equal: {}",
228        operator_result.data() == method_result.data()
229    );
230
231    // Scalar addition: operator vs method
232    let operator_result = &a + 5.0;
233    let method_result = a.add_scalar(5.0);
234
235    println!("A + 5.0 (operator): {:?}", operator_result.data());
236    println!("A.add_scalar(5.0): {:?}", method_result.data());
237    println!(
238        "Results are equal: {}",
239        operator_result.data() == method_result.data()
240    );
241}
Source

impl Tensor

Source

pub fn div_tensor(&self, other: &Tensor) -> Tensor

Element-wise division with another tensor with broadcasting support.

Performs element-wise division with automatic broadcasting: output[i] = self[i] / other[i]

Broadcasting enables division between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:

  • Dimensions are aligned from the rightmost dimension
  • Dimensions are compatible if they are equal, or one of them is 1
  • Missing dimensions are treated as 1
§Arguments
  • other - Tensor to divide by. Shapes must be broadcast-compatible.
§Returns

A new tensor containing the element-wise quotient with broadcast result shape

§Examples
§Same Shape Division
use train_station::Tensor;

let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 5.0);
assert_eq!(c.get(&[1]), 5.0);
assert_eq!(c.get(&[2]), 6.0);
§Broadcasting Division
use train_station::Tensor;

// Broadcasting: [2, 1] / [1, 3] -> [2, 3]
let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 10.0);
assert_eq!(c.get(&[0, 1]), 5.0);
assert_eq!(c.get(&[1, 0]), 20.0);
assert_eq!(c.get(&[1, 1]), 10.0);
§Scalar Division
use train_station::Tensor;

// Scalar division: [2, 3] / scalar -> [2, 3]
let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 0.5);
assert_eq!(c.get(&[1, 2]), 0.5);
§Panics

Panics if tensor shapes are not broadcast-compatible or if division by zero occurs

Examples found in repository
examples/RL_training/ppo_continuous.rs (line 242)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI.ln()) // ln(2) + ln(π) = ln(2π)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
More examples
examples/getting_started/tensor_basics.rs (line 108)
89fn demonstrate_basic_operations() {
90    println!("\n--- Basic Operations ---");
91
92    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95    // Addition
96    let sum = a.add_tensor(&b);
97    println!("A + B: {:?}", sum.data());
98
99    // Subtraction
100    let diff = a.sub_tensor(&b);
101    println!("A - B: {:?}", diff.data());
102
103    // Multiplication
104    let product = a.mul_tensor(&b);
105    println!("A * B: {:?}", product.data());
106
107    // Division
108    let quotient = a.div_tensor(&b);
109    println!("A / B: {:?}", quotient.data());
110
111    // Scalar operations
112    let scalar_add = a.add_scalar(5.0);
113    println!("A + 5.0: {:?}", scalar_add.data());
114
115    let scalar_mul = a.mul_scalar(2.0);
116    println!("A * 2.0: {:?}", scalar_mul.data());
117}
Source

pub fn div_scalar(&self, scalar: f32) -> Tensor

Broadcast division with a scalar value.

Divides every element by the scalar: output[i] = self[i] / scalar

§Arguments
  • scalar - Value to divide each element by (must not be zero)
§Returns

A new tensor with each element divided by the scalar

§Examples
§Basic Scalar Division
use train_station::Tensor;

let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
let b = a.div_scalar(10.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0);
assert_eq!(b.get(&[1]), 2.0);
assert_eq!(b.get(&[2]), 3.0);
§Multi-dimensional Scalar Division
use train_station::Tensor;

let a = Tensor::ones(vec![2, 3]);
let b = a.div_scalar(2.0);
assert_eq!(b.shape().dims(), vec![2, 3]);
assert_eq!(b.get(&[0, 0]), 0.5);
assert_eq!(b.get(&[1, 2]), 0.5);
§Panics

Panics if scalar is zero

Examples found in repository
examples/neural_networks/multi_head_attention.rs (line 92)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
More examples
examples/iterators/advanced_patterns.rs (line 114)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251    println!("\n--- Batch Operations ---");
252
253    // Create a larger dataset for batch processing
254    let size = 100;
255    let data: Vec<f32> = (0..size)
256        .map(|i| {
257            let x = i as f32 / size as f32;
258            x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259        })
260        .collect();
261
262    let tensor = Tensor::from_slice(&data, vec![size])?;
263    println!("Dataset size: {}", tensor.size());
264
265    // Batch processing with windowing (iterator views)
266    println!("\nBatch processing with sliding windows:");
267    let batch_size = 10;
268    let batches: Vec<Tensor> = tensor
269        .iter()
270        .collect::<Vec<_>>()
271        .chunks(batch_size)
272        .map(|chunk| {
273            // Process each batch independently
274            chunk
275                .iter()
276                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277                .collect()
278        })
279        .collect();
280
281    println!(
282        "  Processed {} batches of size {}",
283        batches.len(),
284        batch_size
285    );
286    for (i, batch) in batches.iter().enumerate() {
287        println!(
288            "    Batch {}: mean={:.3}, std={:.3}",
289            i,
290            batch.mean().value(),
291            batch.std().value()
292        );
293    }
294
295    // Parallel-like processing with stride
296    println!("\nStrided processing (every nth element):");
297    let stride = 5;
298    let strided: Tensor = tensor
299        .iter()
300        .enumerate()
301        .filter(|(i, _)| i % stride == 0)
302        .map(|(_, elem)| elem)
303        .collect();
304    println!("  Strided (every {}th): {:?}", stride, strided.data());
305
306    // Hierarchical processing
307    println!("\nHierarchical processing (coarse to fine):");
308    let coarse: Tensor = tensor
309        .iter()
310        .enumerate()
311        .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312        .map(|(_, elem)| elem)
313        .collect();
314
315    let fine: Tensor = tensor
316        .iter()
317        .enumerate()
318        .filter(|(i, _)| i % 4 != 0) // Take the rest
319        .map(|(_, elem)| elem)
320        .collect();
321
322    println!("  Coarse (every 4th): {:?}", coarse.data());
323    println!("  Fine (rest): {:?}", fine.data());
324
325    // Combine coarse and fine with different processing
326    let combined: Tensor = coarse
327        .iter()
328        .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329        .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330        .collect();
331    println!("  Combined: {:?}", combined.data());
332
333    Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}
Source

impl Tensor

Source

pub fn exp(&self) -> Tensor

Element-wise exponential function.

Computes e^x for each element: output[i] = e^(self[i])

§Returns

A new tensor with the exponential of each element

§Examples
§Basic Exponential
use train_station::Tensor;

let a = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
let b = a.exp();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // e^0 = 1
assert!((b.get(&[1]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
assert!((b.get(&[2]) - 7.38906).abs() < 1e-5); // e^2 ≈ 7.38906
§Negative Values
use train_station::Tensor;

let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.exp();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 0.36788).abs() < 1e-5); // e^(-1) ≈ 0.36788
assert_eq!(b.get(&[1]), 1.0); // e^0 = 1
assert!((b.get(&[2]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
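
Note that for f32 inputs the result overflows to infinity once x exceeds ln(f32::MAX) ≈ 88.72, so shifting inputs first (as in the log-sum-exp patterns below) is common practice.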
Examples found in repository
examples/supervised_training/supervised_bce.rs (line 64)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/supervised_training/supervised_classification.rs (line 53)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/RL_training/ppo_discrete.rs (line 281)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
examples/RL_training/ppo_continuous.rs (line 236)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2π)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
examples/iterators/element_iteration.rs (line 240)
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}

impl Tensor

Source

pub fn leaky_relu(&self, negative_slope: f32) -> Tensor

Element-wise Leaky ReLU activation.

Applies Leaky ReLU to each element: output[i] = max(0, self[i]) + negative_slope * min(0, self[i])

Unlike standard ReLU, it allows a small, non-zero gradient when the unit is not active (self[i] < 0).

§Arguments
  • negative_slope - Slope for negative values (typically small, e.g., 0.01 or 0.1)
§Returns

A new tensor with Leaky ReLU applied to each element

§Examples
§Basic Leaky ReLU
use train_station::Tensor;

let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
let b = a.leaky_relu(0.1);
assert_eq!(b.shape().dims(), vec![4]);
assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
§Different Negative Slopes
use train_station::Tensor;

let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.leaky_relu(0.01); // Smaller negative slope
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
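Since the point of the negative slope is gradient flow, a minimal gradtrack sketch (assuming leaky_relu is differentiable like the other element-wise operations in this document) shows the upstream gradient scaled by negative_slope on the negative side instead of being zeroed:

use train_station::Tensor;

let mut x = Tensor::from_slice(&[-2.0, 3.0], vec![2])
    .unwrap()
    .with_requires_grad();
let mut y = x.leaky_relu(0.1).mean();
y.backward(None);
// mean() scales each element's gradient by 1/2.
if let Some(g) = x.grad_owned() {
    assert!((g.data()[0] - 0.05).abs() < 1e-6); // negative side: 0.1 * 0.5
    assert!((g.data()[1] - 0.5).abs() < 1e-6); // positive side: 1.0 * 0.5
}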

impl Tensor

Source

pub fn log(&self) -> Tensor

Element-wise natural logarithm.

Computes the natural logarithm for each element: output[i] = ln(self[i])

§Returns

A new tensor with the natural logarithm of each element

§Examples
§Basic Natural Logarithm
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.71828, 7.38906], vec![3]).unwrap();
let b = a.log();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // ln(1) = 0
assert!((b.get(&[1]) - 1.0).abs() < 1e-5); // ln(e) ≈ 1
assert!((b.get(&[2]) - 2.0).abs() < 1e-5); // ln(e^2) ≈ 2
§Mathematical Properties
use train_station::Tensor;

let a = Tensor::from_slice(&[4.0, 8.0, 16.0], vec![3]).unwrap();
let b = a.log();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 1.38629).abs() < 1e-5); // ln(4) ≈ 1.38629
assert!((b.get(&[1]) - 2.07944).abs() < 1e-5); // ln(8) ≈ 2.07944
assert!((b.get(&[2]) - 2.77259).abs() < 1e-5); // ln(16) ≈ 2.77259
§Panics

Panics if any element is non-positive (x <= 0)
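Because of this, code that may hit zero values typically offsets by a small epsilon before calling log; the ppo_discrete example below uses exactly this pattern (probs.add_scalar(1e-8).log()). A minimal sketch:

use train_station::Tensor;

// Offset by a small epsilon so zero values do not panic.
let probs = Tensor::from_slice(&[0.0, 0.25, 0.75], vec![3]).unwrap();
let safe_log = probs.add_scalar(1e-8).log();
assert!(safe_log.data().iter().all(|v| v.is_finite()));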

Examples found in repository
examples/supervised_training/supervised_bce.rs (line 64): see the bce_with_logits listing under exp above.
examples/supervised_training/supervised_classification.rs (line 55): see the cross_entropy_logits listing under exp above.
examples/RL_training/ppo_discrete.rs (line 283)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/iterators/element_iteration.rs (line 242): see the demonstrate_advanced_patterns listing under exp above.

impl Tensor

Source

pub fn matmul(&self, other: &Tensor) -> Tensor

Matrix multiplication with intelligent kernel dispatch.

Performs matrix multiplication using optimized SIMD kernels selected based on:

  • Runtime SIMD capability (AVX512/AVX2/SSE2/Scalar)
  • Matrix operation type (1D@1D, 1D@2D, 2D@1D, 2D@2D, ND@ND)
  • Matrix size classification (Small/Medium/Large)
  • Memory alignment characteristics
§Arguments
  • other - Right-hand side tensor for multiplication
§Returns

Result tensor with appropriate shape based on operation type

§Panics

Panics if tensor shapes are incompatible for matrix multiplication

§Examples
use train_station::Tensor;

// 2D matrix multiplication
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
let result = a.matmul(&b);
assert_eq!(result.shape().dims(), vec![2, 2]);

// 1D dot product
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let result = a.matmul(&b);
assert_eq!(result.shape().dims(), vec![]); // Scalar
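
The mixed 1D/2D cases follow from the operation types listed above; a sketch, assuming the usual promotion convention (the 1D operand is treated as a row or column vector and the promoted dimension is dropped from the result):

use train_station::Tensor;

let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let v = Tensor::from_slice(&[1.0, 0.0, 1.0], vec![3]).unwrap();

// 2D @ 1D: [2, 3] @ [3] -> [2]
let mv = m.matmul(&v);
assert_eq!(mv.shape().dims(), vec![2]);

// 1D @ 2D: [2] @ [2, 3] -> [3]
let v2 = Tensor::from_slice(&[1.0, 1.0], vec![2]).unwrap();
let vm = v2.matmul(&m);
assert_eq!(vm.shape().dims(), vec![3]);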
Examples found in repository
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 73)
71    pub fn forward(&self, input: &Tensor) -> Tensor {
72        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
73        let output = input.matmul(&self.weight);
74        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
75        output.add_tensor(&self.bias)
76    }
examples/optimizers/adam_configurations.rs (line 111)
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85    println!("--- Default Adam Configuration ---");
86
87    // Create a simple regression problem: y = 2*x + 1
88    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91    // Create model parameters
92    let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95    // Create Adam optimizer with default configuration
96    let mut optimizer = Adam::new();
97    optimizer.add_parameter(&weight);
98    optimizer.add_parameter(&bias);
99
100    println!("Default Adam configuration:");
101    println!("  Learning rate: {}", optimizer.learning_rate());
102    println!("  Initial weight: {:.6}", weight.value());
103    println!("  Initial bias: {:.6}", bias.value());
104
105    // Training loop
106    let num_epochs = 50;
107    let mut losses = Vec::new();
108
109    for epoch in 0..num_epochs {
110        // Forward pass
111        let y_pred = x_data.matmul(&weight) + &bias;
112        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114        // Backward pass
115        loss.backward(None);
116
117        // Optimizer step
118        optimizer.step(&mut [&mut weight, &mut bias]);
119        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121        losses.push(loss.value());
122
123        if epoch % 10 == 0 || epoch == num_epochs - 1 {
124            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125        }
126    }
127
128    // Evaluate final model
129    let _final_predictions = x_data.matmul(&weight) + &bias;
130    println!("\nFinal model:");
131    println!("  Learned weight: {:.6} (target: 2.0)", weight.value());
132    println!("  Learned bias: {:.6} (target: 1.0)", bias.value());
133    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
134
135    Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140    println!("\n--- Learning Rate Comparison ---");
141
142    let learning_rates = [0.001, 0.01, 0.1];
143    let mut results = Vec::new();
144
145    for &lr in &learning_rates {
146        println!("\nTesting learning rate: {}", lr);
147
148        let stats = train_with_config(TrainingConfig {
149            learning_rate: lr,
150            ..Default::default()
151        })?;
152
153        results.push((lr, stats.clone()));
154
155        println!("  Final loss: {:.6}", stats.final_loss);
156        println!("  Convergence epoch: {}", stats.convergence_epoch);
157    }
158
159    // Compare results
160    println!("\nLearning Rate Comparison Summary:");
161    for (lr, stats) in &results {
162        println!(
163            "  LR={:6}: Loss={:.6}, Converged@{}",
164            lr, stats.final_loss, stats.convergence_epoch
165        );
166    }
167
168    Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173    println!("\n--- Weight Decay Comparison ---");
174
175    let weight_decays = [0.0, 0.001, 0.01];
176    let mut results = Vec::new();
177
178    for &wd in &weight_decays {
179        println!("\nTesting weight decay: {}", wd);
180
181        let stats = train_with_config(TrainingConfig {
182            weight_decay: wd,
183            ..Default::default()
184        })?;
185
186        results.push((wd, stats.clone()));
187
188        println!("  Final loss: {:.6}", stats.final_loss);
189        println!("  Final weight norm: {:.6}", stats.weight_norm);
190    }
191
192    // Compare results
193    println!("\nWeight Decay Comparison Summary:");
194    for (wd, stats) in &results {
195        println!(
196            "  WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197            wd, stats.final_loss, stats.weight_norm
198        );
199    }
200
201    Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206    println!("\n--- Beta Parameter Tuning ---");
207
208    let beta_configs = [
209        (0.9, 0.999),  // Default
210        (0.8, 0.999),  // More aggressive momentum
211        (0.95, 0.999), // Less aggressive momentum
212        (0.9, 0.99),   // Faster second moment decay
213    ];
214
215    let mut results = Vec::new();
216
217    for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218        println!(
219            "\nTesting beta configuration {}: beta1={}, beta2={}",
220            i + 1,
221            beta1,
222            beta2
223        );
224
225        let config = TrainingConfig {
226            beta1: *beta1,
227            beta2: *beta2,
228            ..Default::default()
229        };
230
231        let stats = train_with_config(config)?;
232        results.push(((*beta1, *beta2), stats.clone()));
233
234        println!("  Final loss: {:.6}", stats.final_loss);
235        println!("  Convergence epoch: {}", stats.convergence_epoch);
236    }
237
238    // Compare results
239    println!("\nBeta Parameter Comparison Summary:");
240    for ((beta1, beta2), stats) in &results {
241        println!(
242            "  B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243            beta1, beta2, stats.final_loss, stats.convergence_epoch
244        );
245    }
246
247    Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252    println!("\n--- Configuration Benchmarking ---");
253
254    // Define configurations to benchmark
255    let configs = vec![
256        (
257            "Conservative",
258            TrainingConfig {
259                learning_rate: 0.001,
260                weight_decay: 0.001,
261                beta1: 0.95,
262                ..Default::default()
263            },
264        ),
265        (
266            "Balanced",
267            TrainingConfig {
268                learning_rate: 0.01,
269                weight_decay: 0.0,
270                beta1: 0.9,
271                ..Default::default()
272            },
273        ),
274        (
275            "Aggressive",
276            TrainingConfig {
277                learning_rate: 0.1,
278                weight_decay: 0.0,
279                beta1: 0.8,
280                ..Default::default()
281            },
282        ),
283    ];
284
285    let mut benchmark_results = Vec::new();
286
287    for (name, config) in configs {
288        println!("\nBenchmarking {} configuration:", name);
289
290        let start_time = std::time::Instant::now();
291        let stats = train_with_config(config.clone())?;
292        let elapsed = start_time.elapsed();
293
294        println!("  Training time: {:.2}ms", elapsed.as_millis());
295        println!("  Final loss: {:.6}", stats.final_loss);
296        println!("  Convergence: {} epochs", stats.convergence_epoch);
297
298        benchmark_results.push((name.to_string(), stats, elapsed));
299    }
300
301    // Summary
302    println!("\nBenchmarking Summary:");
303    for (name, stats, elapsed) in &benchmark_results {
304        println!(
305            "  {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306            name,
307            stats.final_loss,
308            elapsed.as_millis(),
309            stats.convergence_epoch
310        );
311    }
312
313    Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318    // Create training data
319    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322    // Create model parameters
323    let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326    // Create optimizer with custom configuration
327    let adam_config = AdamConfig {
328        learning_rate: config.learning_rate,
329        beta1: config.beta1,
330        beta2: config.beta2,
331        eps: 1e-8,
332        weight_decay: config.weight_decay,
333        amsgrad: false,
334    };
335
336    let mut optimizer = Adam::with_config(adam_config);
337    optimizer.add_parameter(&weight);
338    optimizer.add_parameter(&bias);
339
340    // Training loop
341    let mut losses = Vec::new();
342    let mut convergence_epoch = config.epochs;
343
344    for epoch in 0..config.epochs {
345        // Forward pass
346        let y_pred = x_data.matmul(&weight) + &bias;
347        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349        // Backward pass
350        loss.backward(None);
351
352        // Optimizer step
353        optimizer.step(&mut [&mut weight, &mut bias]);
354        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356        let loss_value = loss.value();
357        losses.push(loss_value);
358
359        // Check for convergence (loss < 0.01)
360        if loss_value < 0.01 && convergence_epoch == config.epochs {
361            convergence_epoch = epoch;
362        }
363    }
364
365    Ok(TrainingStats {
366        config,
367        final_loss: losses[losses.len() - 1],
368        loss_history: losses,
369        convergence_epoch,
370        weight_norm: weight.norm().value(),
371    })
372}
examples/optimizers/learning_rate_scheduling.rs (line 343)
319fn train_with_scheduler(
320    scheduler: &mut dyn LearningRateScheduler,
321    num_epochs: usize,
322) -> Result<TrainingStats, Box<dyn std::error::Error>> {
323    // Create training data: y = 2*x + 1
324    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
325    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
326
327    // Create model parameters
328    let mut weight = Tensor::randn(vec![1, 1], Some(456)).with_requires_grad();
329    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
330
331    // Create optimizer with initial learning rate
332    let mut optimizer = Adam::with_learning_rate(0.05);
333    optimizer.add_parameter(&weight);
334    optimizer.add_parameter(&bias);
335
336    // Training loop
337    let mut losses = Vec::new();
338    let mut lr_history = Vec::new();
339    let mut convergence_epoch = num_epochs;
340
341    for epoch in 0..num_epochs {
342        // Forward pass
343        let y_pred = x_data.matmul(&weight) + &bias;
344        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
345
346        // Backward pass
347        loss.backward(None);
348
349        // Update learning rate using scheduler
350        let current_lr = optimizer.learning_rate();
351        let new_lr = scheduler.step(current_lr, epoch, loss.value());
352
353        if (new_lr - current_lr).abs() > 1e-8 {
354            optimizer.set_learning_rate(new_lr);
355        }
356
357        // Optimizer step
358        optimizer.step(&mut [&mut weight, &mut bias]);
359        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
360
361        let loss_value = loss.value();
362        losses.push(loss_value);
363        lr_history.push(new_lr);
364
365        // Check for convergence
366        if loss_value < 0.01 && convergence_epoch == num_epochs {
367            convergence_epoch = epoch;
368        }
369    }
370
371    Ok(TrainingStats {
372        scheduler_name: scheduler.name().to_string(),
373        final_loss: losses[losses.len() - 1],
374        lr_history,
375        loss_history: losses,
376        convergence_epoch,
377    })
378}
examples/getting_started/optimizer_basics.rs (line 132)
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106    println!("\n--- Linear Regression Training ---");
107
108    // Create model parameters
109    let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112    // Create optimizer
113    let mut optimizer = Adam::with_learning_rate(0.01);
114    optimizer.add_parameter(&weight);
115    optimizer.add_parameter(&bias);
116
117    // Create simple training data: y = 2*x + 1
118    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121    println!("Training data:");
122    println!("  X: {:?}", x_data.data());
123    println!("  Y: {:?}", y_true.data());
124    println!("  Target: y = 2*x + 1");
125
126    // Training loop
127    let num_epochs = 100;
128    let mut losses = Vec::new();
129
130    for epoch in 0..num_epochs {
131        // Forward pass: y_pred = x * weight + bias
132        let y_pred = x_data.matmul(&weight) + &bias;
133
134        // Compute loss: MSE
135        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137        // Backward pass
138        loss.backward(None);
139
140        // Optimizer step
141        optimizer.step(&mut [&mut weight, &mut bias]);
142        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144        losses.push(loss.value());
145
146        // Print progress every 20 epochs
147        if epoch % 20 == 0 || epoch == num_epochs - 1 {
148            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149        }
150    }
151
152    // Evaluate final model
153    let final_predictions = x_data.matmul(&weight) + &bias;
154    println!("\nFinal model evaluation:");
155    println!("  Learned weight: {:.6}", weight.value());
156    println!("  Learned bias: {:.6}", bias.value());
157    println!("  Predictions vs True:");
158
159    for i in 0..5 {
160        let x1 = x_data.data()[i];
161        let pred = final_predictions.data()[i];
162        let true_val = y_true.data()[i];
163        println!(
164            "    x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165            x1,
166            pred,
167            true_val,
168            (pred - true_val).abs()
169        );
170    }
171
172    Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177    println!("\n--- Advanced Training Patterns ---");
178
179    // Create a more complex model
180    let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181    let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183    // Create optimizer with different learning rate
184    let mut optimizer = Adam::with_learning_rate(0.005);
185    optimizer.add_parameter(&weight);
186    optimizer.add_parameter(&bias);
187
188    // Create training data: y = 2*x + [1, 3]
189    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190    let y_true = Tensor::from_slice(
191        &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192        vec![5, 2],
193    )
194    .unwrap();
195
196    println!("Advanced training with monitoring:");
197    println!("  Initial learning rate: {}", optimizer.learning_rate());
198
199    // Training loop with monitoring
200    let num_epochs = 50;
201    let mut losses = Vec::new();
202    let mut weight_norms = Vec::new();
203    let mut gradient_norms = Vec::new();
204
205    for epoch in 0..num_epochs {
206        // Forward pass
207        let y_pred = x_data.matmul(&weight) + &bias;
208        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210        // Backward pass
211        loss.backward(None);
212
213        // Compute gradient norm before optimizer step
214        let gradient_norm = weight.grad_owned().unwrap().norm();
215
216        // Optimizer step
217        optimizer.step(&mut [&mut weight, &mut bias]);
218        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220        // Learning rate scheduling: reduce every 10 epochs
221        if epoch > 0 && epoch % 10 == 0 {
222            let current_lr = optimizer.learning_rate();
223            let new_lr = current_lr * 0.5;
224            optimizer.set_learning_rate(new_lr);
225            println!(
226                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227                epoch, current_lr, new_lr
228            );
229        }
230
231        // Record metrics
232        losses.push(loss.value());
233        weight_norms.push(weight.norm().value());
234        gradient_norms.push(gradient_norm.value());
235
236        // Print detailed progress
237        if epoch % 10 == 0 || epoch == num_epochs - 1 {
238            println!(
239                "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240                epoch,
241                loss.value(),
242                weight.norm().value(),
243                gradient_norm.value()
244            );
245        }
246    }
247
248    println!("Final learning rate: {}", optimizer.learning_rate());
249
250    // Analyze training progression
251    let initial_loss = losses[0];
252    let final_loss = losses[losses.len() - 1];
253    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255    println!("\nTraining Analysis:");
256    println!("  Initial loss: {:.6}", initial_loss);
257    println!("  Final loss: {:.6}", final_loss);
258    println!("  Loss reduction: {:.1}%", loss_reduction);
259    println!("  Final weight norm: {:.6}", weight.norm().value());
260    println!("  Final bias: {:?}", bias.data());
261
262    Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267    println!("\n--- Learning Rate Scheduling ---");
268
269    // Create simple model
270    let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273    // Create optimizer with high initial learning rate
274    let mut optimizer = Adam::with_learning_rate(0.1);
275    optimizer.add_parameter(&weight);
276    optimizer.add_parameter(&bias);
277
278    // Simple data
279    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280    let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282    println!("Initial learning rate: {}", optimizer.learning_rate());
283
284    // Training loop with learning rate scheduling
285    let num_epochs = 50;
286    let mut losses = Vec::new();
287
288    for epoch in 0..num_epochs {
289        // Forward pass
290        let y_pred = x_data.matmul(&weight) + &bias;
291        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293        // Backward pass
294        loss.backward(None);
295
296        // Optimizer step
297        optimizer.step(&mut [&mut weight, &mut bias]);
298        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300        // Learning rate scheduling: reduce every 10 epochs
301        if epoch > 0 && epoch % 10 == 0 {
302            let current_lr = optimizer.learning_rate();
303            let new_lr = current_lr * 0.5;
304            optimizer.set_learning_rate(new_lr);
305            println!(
306                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307                epoch, current_lr, new_lr
308            );
309        }
310
311        losses.push(loss.value());
312
313        // Print progress
314        if epoch % 10 == 0 || epoch == num_epochs - 1 {
315            println!(
316                "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317                epoch,
318                loss.value(),
319                optimizer.learning_rate()
320            );
321        }
322    }
323
324    println!("Final learning rate: {}", optimizer.learning_rate());
325
326    Ok(())
327}
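The halving schedule in the loop above can be expressed as a small closed-form helper; a minimal sketch (hypothetical, not part of the train_station API):

// Step decay: multiply the initial rate by `factor` once per `step_size` epochs
fn step_decay(initial_lr: f32, epoch: usize, step_size: usize, factor: f32) -> f32 {
    initial_lr * factor.powi((epoch / step_size) as i32)
}

// With initial_lr = 0.1, step_size = 10, factor = 0.5:
// epochs 0-9 -> 0.1, epochs 10-19 -> 0.05, epochs 20-29 -> 0.025
assert!((step_decay(0.1, 15, 10, 0.5) - 0.05).abs() < 1e-6);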
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331    println!("\n--- Training Monitoring ---");
332
333    // Create model
334    let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337    // Create optimizer
338    let mut optimizer = Adam::with_learning_rate(0.01);
339    optimizer.add_parameter(&weight);
340    optimizer.add_parameter(&bias);
341
342    // Training data
343    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346    // Training loop with comprehensive monitoring
347    let num_epochs = 30;
348    let mut losses = Vec::new();
349    let mut weight_history = Vec::new();
350    let mut bias_history = Vec::new();
351
352    for epoch in 0..num_epochs {
353        // Forward pass
354        let y_pred = x_data.matmul(&weight) + &bias;
355        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357        // Backward pass
358        loss.backward(None);
359
360        // Optimizer step
361        optimizer.step(&mut [&mut weight, &mut bias]);
362        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364        // Record history
365        losses.push(loss.value());
366        weight_history.push(weight.value());
367        bias_history.push(bias.value());
368
369        // Print detailed monitoring
370        if epoch % 5 == 0 || epoch == num_epochs - 1 {
371            println!(
372                "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373                epoch,
374                loss.value(),
375                weight.value(),
376                bias.value()
377            );
378        }
379    }
380
381    // Analyze training progression
382    println!("\nTraining Analysis:");
383    println!("  Initial loss: {:.6}", losses[0]);
384    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
385    println!(
386        "  Loss reduction: {:.1}%",
387        (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388    );
389
390    // Compute statistics
391    let loss_mean = compute_mean(&losses);
392    let loss_std = compute_std(&losses);
393    let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394    let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396    println!("  Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397    println!("  Weight change: {:.6}", weight_change);
398    println!("  Bias change: {:.6}", bias_change);
399    println!("  Final weight norm: {:.6}", weight.norm().value());
400    println!("  Final bias: {:.6}", bias.value());
401
402    Ok(())
403}
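The monitoring code above calls compute_mean and compute_std, which this excerpt does not show; a minimal sketch of what such helpers typically compute (assumed here, not copied from the example file):

fn compute_mean(values: &[f32]) -> f32 {
    values.iter().sum::<f32>() / values.len() as f32
}

fn compute_std(values: &[f32]) -> f32 {
    // Population standard deviation around the mean
    let mean = compute_mean(values);
    let variance =
        values.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / values.len() as f32;
    variance.sqrt()
}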
examples/neural_networks/multi_head_attention.rs (line 92)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
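For reference, the forward pass above is standard multi-head scaled dot-product attention: Q, K, and V are projected and reshaped from [b, t, e] to [b, h, t, d] with e = h * d, the attention weights are softmax(Q * K^T / sqrt(d)) with shape [b, h, tq, tk], and the per-head contexts are permuted and flattened back to [b, tq, e] before the output projection.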
Source§

impl Tensor

Source

pub fn mul_tensor(&self, other: &Tensor) -> Tensor

Element-wise multiplication with another tensor, supporting broadcasting.

Performs element-wise multiplication with automatic broadcasting: output[i] = self[i] * other[i]

Broadcasting enables multiplication between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:

  • Dimensions are aligned from the rightmost dimension
  • Dimensions are compatible if they are equal, or one of them is 1
  • Missing dimensions are treated as 1
§Arguments
  • other - Tensor to multiply. Shapes must be broadcast-compatible.
§Returns

A new tensor containing the element-wise product with broadcast result shape

§Examples
§Same Shape Multiplication
use train_station::Tensor;

let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
let c = a.mul_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
§Broadcasting Multiplication
use train_station::Tensor;

let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
let c = a.mul_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
// Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
§Panics

Panics if tensor shapes are not broadcast-compatible
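For example, shapes [2, 1] and [1, 3] broadcast to [2, 3], while [2, 3] and [2] are incompatible because the trailing dimensions 3 and 2 differ and neither is 1. The shape computation implied by the rules above can be sketched as follows (a hypothetical helper for illustration, not part of the train_station API):

fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = vec![0usize; n];
    for i in 0..n {
        // Align from the rightmost dimension; missing dimensions count as 1
        let da = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
        let db = if i < b.len() { b[b.len() - 1 - i] } else { 1 };
        out[n - 1 - i] = if da == db {
            da
        } else if da == 1 {
            db
        } else if db == 1 {
            da
        } else {
            return None; // incompatible: mul_tensor panics in this case
        };
    }
    Some(out)
}

assert_eq!(broadcast_shape(&[2, 1], &[1, 3]), Some(vec![2, 3]));
assert_eq!(broadcast_shape(&[2, 3], &[2]), None);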

Examples found in repository
examples/supervised_training/supervised_bce.rs (line 61)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
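This is the standard numerically stable rewrite of binary cross-entropy on logits: L = mean(relu(z) - z*y + log(1 + exp(-|z|))) is algebraically equal to -mean(y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))), but because exp is only ever applied to -|z| <= 0 it cannot overflow for large logits.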
examples/getting_started/tensor_basics.rs (line 104)
89fn demonstrate_basic_operations() {
90    println!("\n--- Basic Operations ---");
91
92    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95    // Addition
96    let sum = a.add_tensor(&b);
97    println!("A + B: {:?}", sum.data());
98
99    // Subtraction
100    let diff = a.sub_tensor(&b);
101    println!("A - B: {:?}", diff.data());
102
103    // Multiplication
104    let product = a.mul_tensor(&b);
105    println!("A * B: {:?}", product.data());
106
107    // Division
108    let quotient = a.div_tensor(&b);
109    println!("A / B: {:?}", quotient.data());
110
111    // Scalar operations
112    let scalar_add = a.add_scalar(5.0);
113    println!("A + 5.0: {:?}", scalar_add.data());
114
115    let scalar_mul = a.mul_scalar(2.0);
116    println!("A * 2.0: {:?}", scalar_mul.data());
117}
examples/getting_started/tensor_operators.rs (line 222)
203fn demonstrate_method_equivalence() {
204    println!("\n--- Operator vs Method Call Equivalence ---");
205
206    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209    // Addition: operator vs method
210    let operator_result = &a + &b;
211    let method_result = a.add_tensor(&b);
212
213    println!("A + B (operator): {:?}", operator_result.data());
214    println!("A.add_tensor(B): {:?}", method_result.data());
215    println!(
216        "Results are equal: {}",
217        operator_result.data() == method_result.data()
218    );
219
220    // Multiplication: operator vs method
221    let operator_result = &a * &b;
222    let method_result = a.mul_tensor(&b);
223
224    println!("A * B (operator): {:?}", operator_result.data());
225    println!("A.mul_tensor(B): {:?}", method_result.data());
226    println!(
227        "Results are equal: {}",
228        operator_result.data() == method_result.data()
229    );
230
231    // Scalar addition: operator vs method
232    let operator_result = &a + 5.0;
233    let method_result = a.add_scalar(5.0);
234
235    println!("A + 5.0 (operator): {:?}", operator_result.data());
236    println!("A.add_scalar(5.0): {:?}", method_result.data());
237    println!(
238        "Results are equal: {}",
239        operator_result.data() == method_result.data()
240    );
241}
examples/iterators/element_iteration.rs (line 252)
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209    println!("\n--- Advanced Iterator Patterns ---");
210
211    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212    println!("Input tensor: {:?}", tensor.data());
213
214    // Complex chain: enumerate -> filter -> map -> collect
215    println!("\nComplex chain (even indices only, add index to value):");
216    let result: Tensor = tensor
217        .iter()
218        .enumerate()
219        .filter(|(i, _)| i % 2 == 0) // Take even indices
220        .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221        .collect();
222    println!("  Result: {:?}", result.data());
223
224    // Using take and skip for windowing
225    println!("\nWindowing with take and skip:");
226    let window1: Tensor = tensor.iter().take(3).collect();
227    let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228    println!("  Window 1 (first 3): {:?}", window1.data());
229    println!("  Window 2 (middle 3): {:?}", window2.data());
230
231    // Using rev() for reverse iteration
232    println!("\nReverse iteration:");
233    let reversed: Tensor = tensor.iter().rev().collect();
234    println!("  Reversed: {:?}", reversed.data());
235
236    // Chaining with mathematical operations
237    println!("\nMathematical operation chain:");
238    let math_result: Tensor = tensor
239        .iter()
240        .map(|elem| elem.exp()) // e^x
241        .filter(|elem| elem.value() < 50.0) // Filter large values
242        .map(|elem| elem.log()) // ln(x)
243        .collect();
244    println!("  Math chain result: {:?}", math_result.data());
245
246    // Using zip for element-wise combinations
247    println!("\nElement-wise combination with zip:");
248    let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249    let combined: Tensor = tensor
250        .iter()
251        .zip(tensor2.iter())
252        .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253        .collect();
254    println!("  Combined: {:?}", combined.data());
255
256    Ok(())
257}
examples/RL_training/dqn.rs (line 472)
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
examples/RL_training/ppo_discrete.rs (line 503)
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
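For reference, the minibatch update in this example is the standard PPO clipped surrogate objective: actor_loss = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)) with probability ratio r and normalized advantages A, v_loss is a vf_coef-scaled squared error against the GAE returns, and ent is an optional entropy bonus weighted by ent_coef (zero here).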
Source

pub fn mul_scalar(&self, scalar: f32) -> Tensor

Broadcast multiplication with a scalar value.

Multiplies every element by the scalar: output[i] = self[i] * scalar

§Arguments
  • scalar - Value to multiply with each element
§Returns

A new tensor with each element multiplied by the scalar

§Examples
§Basic Scalar Multiplication
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.mul_scalar(10.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
§Negative Scalar Multiplication
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.mul_scalar(-2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
Examples found in repository
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 58)
53    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55        let scale = (1.0 / input_size as f32).sqrt();
56
57        let weight = Tensor::randn(vec![input_size, output_size], seed)
58            .mul_scalar(scale)
59            .with_requires_grad();
60        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62        Self {
63            weight,
64            bias,
65            input_size,
66            output_size,
67        }
68    }
examples/RL_training/dqn.rs (line 291)
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278    let mut total_sq = 0.0f32;
279    for p in parameters.iter() {
280        if let Some(g) = p.grad_owned() {
281            for &v in g.data() {
282                total_sq += v * v;
283            }
284        }
285    }
286    let norm = total_sq.sqrt();
287    if norm > max_norm {
288        let scale = max_norm / (norm + eps);
289        for p in parameters.iter_mut() {
290            if let Some(g) = p.grad_owned() {
291                p.set_grad(g.mul_scalar(scale));
292            }
293        }
294    }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298    let mut total_sq = 0.0f32;
299    for p in parameters.iter_mut() {
300        if let Some(g) = p.grad_owned() {
301            for &v in g.data() {
302                total_sq += v * v;
303            }
304        }
305    }
306    total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310    let _ng = NoGradTrack::new();
311    let mut total_sq = 0.0f32;
312    for p in parameters.iter_mut() {
313        for &v in p.data() {
314            total_sq += v * v;
315        }
316    }
317    total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
examples/RL_training/ppo_discrete.rs (line 266)
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
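Subtracting the per-row maximum before exponentiating is the usual log-sum-exp stabilization: log_softmax(z) = (z - m) - log(sum(exp(z - m))) is exact for any constant m, and choosing m = max(z) keeps every exponent argument non-positive so exp cannot overflow.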
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
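The two steps use the ReLU clamp identities max(x, a) = relu(x - a) + a and min(x, b) = b - relu(x - b): ge_low clamps the ratio from below at 1 - clip_eps, and the final expression clamps the result from above at 1 + clip_eps.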
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
examples/supervised_training/supervised_bce.rs (line 37)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44    // pred: [B,1] with sigmoid; threshold at 0.5
45    let p = pred.data();
46    let t = targets.data();
47    let mut correct = 0usize;
48    for i in 0..p.len() {
49        let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50        if (yhat - t[i]).abs() < 1e-6 {
51            correct += 1;
52        }
53    }
54    correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/supervised_training/supervised_classification.rs (line 37)
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24    let mut total_sq = 0.0f32;
25    for p in parameters.iter() {
26        if let Some(g) = p.grad_owned() {
27            for &v in g.data() {
28                total_sq += v * v;
29            }
30        }
31    }
32    let norm = total_sq.sqrt();
33    if norm > max_norm {
34        let scale = max_norm / (norm + eps);
35        for p in parameters.iter_mut() {
36            if let Some(g) = p.grad_owned() {
37                p.set_grad(g.mul_scalar(scale));
38            }
39        }
40    }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/supervised_training/supervised_regression.rs (line 36)
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23    let mut total_sq = 0.0f32;
24    for p in parameters.iter() {
25        if let Some(g) = p.grad_owned() {
26            for &v in g.data() {
27                total_sq += v * v;
28            }
29        }
30    }
31    let norm = total_sq.sqrt();
32    if norm > max_norm {
33        let scale = max_norm / (norm + eps);
34        for p in parameters.iter_mut() {
35            if let Some(g) = p.grad_owned() {
36                p.set_grad(g.mul_scalar(scale));
37            }
38        }
39    }
40}
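The three clip_gradients helpers above are identical and implement global-norm clipping: the L2 norm is taken over the concatenation of all parameter gradients, and rescaling is applied only when it exceeds the threshold:

\[
\lVert g\rVert_2 = \sqrt{\textstyle\sum_i g_i^2},
\qquad
g \leftarrow g \cdot \frac{c}{\lVert g\rVert_2 + \varepsilon}
\quad\text{if } \lVert g\rVert_2 > c,
\]

with c = max_norm. Because ε sits in the denominator, the clipped norm lands slightly below c rather than exactly on it; ε only guards against numerical issues and does not otherwise change the behavior.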
Source

impl Tensor

Source

pub fn pow_scalar(&self, exponent: f32) -> Tensor

Raises each element to a scalar power.

Computes element-wise power: output[i] = self[i]^exponent

§Arguments
  • exponent - The scalar exponent to raise each element to
§Returns

A new tensor with each element raised to the given power

§Examples
§Basic Scalar Power
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.pow_scalar(2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
§Square Root (Power 0.5)
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
let b = a.pow_scalar(0.5);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
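§Gradient Through a Power (sketch)

A minimal sketch of pow_scalar under gradient tracking, assuming it participates in gradtrack like the other operations on this page; the expected values follow from d/dx x^3 = 3x^2.

use train_station::Tensor;

let x = Tensor::from_slice(&[2.0, 3.0], vec![2])
    .unwrap()
    .with_requires_grad();
let mut loss = x.pow_scalar(3.0).sum();
loss.backward(None);
// d/dx x^3 = 3x^2, so the gradient should be [12.0, 27.0]
println!("{:?}", x.grad().map(|g| g.data()));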
Examples found in repository
examples/supervised_training/supervised_regression.rs (line 43)
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43    pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
More examples
examples/RL_training/dqn.rs (line 322)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
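This is the pseudo-Huber loss with unit scale; a quick expansion shows why it is used in place of MSE for TD errors:

\[
L(d) = \sqrt{1+d^2} - 1
\;\approx\;
\begin{cases}
\tfrac{1}{2}d^2 & |d| \ll 1,\\[2pt]
|d| - 1 & |d| \gg 1,
\end{cases}
\]

so it matches the squared error near zero but grows only linearly for large residuals, bounding the gradient magnitude by 1 and making the DQN update robust to outlier TD errors.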
examples/RL_training/ppo_continuous.rs (line 237)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2*pi)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (sample from detached scalar values so the rollout graph stays small; stored log_probs are reused in the update)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                // min(x, high) = high - relu(high - x)
527                let ratio_clipped =
528                    clip_high.sub_tensor(&clip_high.sub_tensor(&ratio_ge_low).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
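For reference, gaussian_log_prob and compute_gae in the example above implement the standard closed forms. The diagonal-Gaussian log-density, summed over action dimensions k, is

\[
\log p(a\mid\mu,\sigma)
= -\tfrac{1}{2} \sum_k \left[ \frac{(a_k-\mu_k)^2}{\sigma_k^2} + 2\log\sigma_k + \log 2\pi \right],
\]

and the GAE pass runs backwards over the rollout with

\[
\delta_t = r_t + \gamma V(s_{t+1})(1-d_t) - V(s_t),
\qquad
\hat A_t = \delta_t + \gamma\lambda\,(1-d_t)\,\hat A_{t+1},
\qquad
\hat R_t = \hat A_t + V(s_t),
\]

where d_t is the done flag; (1-d_t) zeroes both the bootstrap and the recursive term at episode boundaries, exactly as the not_done factor does in the code.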
examples/neural_networks/basic_transformer.rs (line 167)
148    pub fn train_non_autoregressive_steps(
149        &mut self,
150        src: &Tensor,
151        tgt: &Tensor,
152        steps: usize,
153        lr: f32,
154    ) {
155        let mut opt = Adam::with_learning_rate(lr);
156        {
157            let params_once = self.parameters();
158            for p in &params_once {
159                opt.add_parameter(p);
160            }
161        }
162        for step in 0..steps {
163            // forward + backward scope (immutable borrow)
164            {
165                let pred = self.forward(src, tgt);
166                let diff = pred.sub_tensor(tgt);
167                let mut loss = diff.pow_scalar(2.0).mean();
168                if step == 0 || step + 1 == steps {
169                    println!("NAR train step {}: loss={:.6}", step, loss.value());
170                }
171                loss.backward(None);
172            }
173            // step + zero_grad scope (mutable borrow)
174            let mut params_step = self.parameters();
175            opt.step(&mut params_step);
176            opt.zero_grad(&mut params_step);
177        }
178    }
179
180    /// Auto-regressive training (teacher forcing): predict next token with causal mask
181    pub fn train_autoregressive_steps(
182        &mut self,
183        src: &Tensor,
184        tgt: &Tensor,
185        steps: usize,
186        lr: f32,
187    ) {
188        let mut opt = Adam::with_learning_rate(lr);
189        {
190            let params_once = self.parameters();
191            for p in &params_once {
192                opt.add_parameter(p);
193            }
194        }
195
196        // Build encoder memory once (static dataset demo)
197        let mut memory = src.clone();
198        for enc in &self.encoders {
199            memory = enc.forward(&memory, None);
200        }
201
202        let (b, t, _e) = Self::triple(tgt);
203        // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204        let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205        for step in 0..steps {
206            // forward + backward scope
207            {
208                let mut out = tgt.clone();
209                for dec in &self.decoders {
210                    out = dec.forward(&out, &memory, Some(&causal), None);
211                }
212                let diff = out.sub_tensor(tgt);
213                let mut loss = diff.pow_scalar(2.0).mean();
214                if step == 0 || step + 1 == steps {
215                    println!("AR  train step {}: loss={:.6}", step, loss.value());
216                }
217                loss.backward(None);
218            }
219            let mut params_step = self.parameters();
220            opt.step(&mut params_step);
221            opt.zero_grad(&mut params_step);
222        }
223    }
examples/iterators/element_iteration.rs (line 138)
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132    println!("\n--- Standard Iterator Methods ---");
133
134    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136    // Using map for transformations
137    println!("\nMap transformation (square each element):");
138    let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139    println!("  Squared: {:?}", squared.data());
140
141    // Using enumerate for indexed operations
142    println!("\nEnumerate with indexed operations:");
143    let indexed: Tensor = tensor
144        .iter()
145        .enumerate()
146        .map(|(i, elem)| elem.add_scalar(i as f32))
147        .collect();
148    println!("  Indexed: {:?}", indexed.data());
149
150    // Using fold for reduction
151    println!("\nFold for sum calculation:");
152    let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153    println!("  Sum: {:.1}", sum);
154
155    // Using find for element search
156    println!("\nFind specific element:");
157    if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158        println!("  Found element with value 3.0: {:.1}", found.value());
159    }
160
161    // Using any/all for condition checking
162    println!("\nCondition checking:");
163    let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164    let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165    println!("  All positive: {}", all_positive);
166    println!("  Any > 4.0: {}", any_large);
167
168    Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
examples/optimizers/adam_configurations.rs (line 112)
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85    println!("--- Default Adam Configuration ---");
86
87    // Create a simple regression problem: y = 2*x + 1
88    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91    // Create model parameters
92    let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95    // Create Adam optimizer with default configuration
96    let mut optimizer = Adam::new();
97    optimizer.add_parameter(&weight);
98    optimizer.add_parameter(&bias);
99
100    println!("Default Adam configuration:");
101    println!("  Learning rate: {}", optimizer.learning_rate());
102    println!("  Initial weight: {:.6}", weight.value());
103    println!("  Initial bias: {:.6}", bias.value());
104
105    // Training loop
106    let num_epochs = 50;
107    let mut losses = Vec::new();
108
109    for epoch in 0..num_epochs {
110        // Forward pass
111        let y_pred = x_data.matmul(&weight) + &bias;
112        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114        // Backward pass
115        loss.backward(None);
116
117        // Optimizer step
118        optimizer.step(&mut [&mut weight, &mut bias]);
119        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121        losses.push(loss.value());
122
123        if epoch % 10 == 0 || epoch == num_epochs - 1 {
124            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125        }
126    }
127
128    // Evaluate final model
129    let _final_predictions = x_data.matmul(&weight) + &bias;
130    println!("\nFinal model:");
131    println!("  Learned weight: {:.6} (target: 2.0)", weight.value());
132    println!("  Learned bias: {:.6} (target: 1.0)", bias.value());
133    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
134
135    Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140    println!("\n--- Learning Rate Comparison ---");
141
142    let learning_rates = [0.001, 0.01, 0.1];
143    let mut results = Vec::new();
144
145    for &lr in &learning_rates {
146        println!("\nTesting learning rate: {}", lr);
147
148        let stats = train_with_config(TrainingConfig {
149            learning_rate: lr,
150            ..Default::default()
151        })?;
152
153        results.push((lr, stats.clone()));
154
155        println!("  Final loss: {:.6}", stats.final_loss);
156        println!("  Convergence epoch: {}", stats.convergence_epoch);
157    }
158
159    // Compare results
160    println!("\nLearning Rate Comparison Summary:");
161    for (lr, stats) in &results {
162        println!(
163            "  LR={:6}: Loss={:.6}, Converged@{}",
164            lr, stats.final_loss, stats.convergence_epoch
165        );
166    }
167
168    Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173    println!("\n--- Weight Decay Comparison ---");
174
175    let weight_decays = [0.0, 0.001, 0.01];
176    let mut results = Vec::new();
177
178    for &wd in &weight_decays {
179        println!("\nTesting weight decay: {}", wd);
180
181        let stats = train_with_config(TrainingConfig {
182            weight_decay: wd,
183            ..Default::default()
184        })?;
185
186        results.push((wd, stats.clone()));
187
188        println!("  Final loss: {:.6}", stats.final_loss);
189        println!("  Final weight norm: {:.6}", stats.weight_norm);
190    }
191
192    // Compare results
193    println!("\nWeight Decay Comparison Summary:");
194    for (wd, stats) in &results {
195        println!(
196            "  WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197            wd, stats.final_loss, stats.weight_norm
198        );
199    }
200
201    Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206    println!("\n--- Beta Parameter Tuning ---");
207
208    let beta_configs = [
209        (0.9, 0.999),  // Default
210        (0.8, 0.999),  // More aggressive momentum
211        (0.95, 0.999), // Less aggressive momentum
212        (0.9, 0.99),   // Faster second moment decay
213    ];
214
215    let mut results = Vec::new();
216
217    for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218        println!(
219            "\nTesting beta configuration {}: beta1={}, beta2={}",
220            i + 1,
221            beta1,
222            beta2
223        );
224
225        let config = TrainingConfig {
226            beta1: *beta1,
227            beta2: *beta2,
228            ..Default::default()
229        };
230
231        let stats = train_with_config(config)?;
232        results.push(((*beta1, *beta2), stats.clone()));
233
234        println!("  Final loss: {:.6}", stats.final_loss);
235        println!("  Convergence epoch: {}", stats.convergence_epoch);
236    }
237
238    // Compare results
239    println!("\nBeta Parameter Comparison Summary:");
240    for ((beta1, beta2), stats) in &results {
241        println!(
242            "  B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243            beta1, beta2, stats.final_loss, stats.convergence_epoch
244        );
245    }
246
247    Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252    println!("\n--- Configuration Benchmarking ---");
253
254    // Define configurations to benchmark
255    let configs = vec![
256        (
257            "Conservative",
258            TrainingConfig {
259                learning_rate: 0.001,
260                weight_decay: 0.001,
261                beta1: 0.95,
262                ..Default::default()
263            },
264        ),
265        (
266            "Balanced",
267            TrainingConfig {
268                learning_rate: 0.01,
269                weight_decay: 0.0,
270                beta1: 0.9,
271                ..Default::default()
272            },
273        ),
274        (
275            "Aggressive",
276            TrainingConfig {
277                learning_rate: 0.1,
278                weight_decay: 0.0,
279                beta1: 0.8,
280                ..Default::default()
281            },
282        ),
283    ];
284
285    let mut benchmark_results = Vec::new();
286
287    for (name, config) in configs {
288        println!("\nBenchmarking {} configuration:", name);
289
290        let start_time = std::time::Instant::now();
291        let stats = train_with_config(config.clone())?;
292        let elapsed = start_time.elapsed();
293
294        println!("  Training time: {:.2}ms", elapsed.as_millis());
295        println!("  Final loss: {:.6}", stats.final_loss);
296        println!("  Convergence: {} epochs", stats.convergence_epoch);
297
298        benchmark_results.push((name.to_string(), stats, elapsed));
299    }
300
301    // Summary
302    println!("\nBenchmarking Summary:");
303    for (name, stats, elapsed) in &benchmark_results {
304        println!(
305            "  {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306            name,
307            stats.final_loss,
308            elapsed.as_millis(),
309            stats.convergence_epoch
310        );
311    }
312
313    Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318    // Create training data
319    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322    // Create model parameters
323    let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326    // Create optimizer with custom configuration
327    let adam_config = AdamConfig {
328        learning_rate: config.learning_rate,
329        beta1: config.beta1,
330        beta2: config.beta2,
331        eps: 1e-8,
332        weight_decay: config.weight_decay,
333        amsgrad: false,
334    };
335
336    let mut optimizer = Adam::with_config(adam_config);
337    optimizer.add_parameter(&weight);
338    optimizer.add_parameter(&bias);
339
340    // Training loop
341    let mut losses = Vec::new();
342    let mut convergence_epoch = config.epochs;
343
344    for epoch in 0..config.epochs {
345        // Forward pass
346        let y_pred = x_data.matmul(&weight) + &bias;
347        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349        // Backward pass
350        loss.backward(None);
351
352        // Optimizer step
353        optimizer.step(&mut [&mut weight, &mut bias]);
354        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356        let loss_value = loss.value();
357        losses.push(loss_value);
358
359        // Check for convergence (loss < 0.01)
360        if loss_value < 0.01 && convergence_epoch == config.epochs {
361            convergence_epoch = epoch;
362        }
363    }
364
365    Ok(TrainingStats {
366        config,
367        final_loss: losses[losses.len() - 1],
368        loss_history: losses,
369        convergence_epoch,
370        weight_norm: weight.norm().value(),
371    })
372}
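As background for the learning-rate and beta experiments above, the standard Adam update that these hyperparameters control is

\[
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t,
\qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,
\]
\[
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}},
\qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}},
\qquad
\theta_t = \theta_{t-1} - \frac{\eta\,\hat m_t}{\sqrt{\hat v_t}+\epsilon}.
\]

Lower β₁ weights recent gradients more heavily (faster but noisier momentum); lower β₂ makes the per-parameter step size adapt faster. How weight_decay enters (added to g_t as an L2 term versus applied decoupled, AdamW-style) is not specified on this page, so treat that detail as implementation-defined.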
Source

pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor

Element-wise power with tensor exponents.

Computes element-wise power: output[i] = self[i]^exponent[i]

§Arguments
  • exponent - Tensor of exponents, must have the same shape as self
§Returns

A new tensor with each element raised to the corresponding power

§Examples
§Basic Tensor Power
use train_station::Tensor;

let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let result = base.pow_tensor(&exp);
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
§Mixed Exponents
use train_station::Tensor;

let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
let result = base.pow_tensor(&exp);
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
§Panics

Panics if tensor shapes don’t match

Source

impl Tensor

Source

pub fn relu(&self) -> Tensor

Element-wise ReLU (Rectified Linear Unit) activation.

Applies ReLU to each element: output[i] = max(0, self[i])

§Returns

A new tensor with ReLU applied to each element

§Examples
§Basic ReLU Activation
use train_station::Tensor;

let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
let b = a.relu();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
§Mixed Positive and Negative Values
use train_station::Tensor;

let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
let b = a.relu();
assert_eq!(b.shape().dims(), vec![5]);
assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
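§Gradient Through ReLU (sketch)

A minimal sketch of relu under gradient tracking, assuming it participates in gradtrack like the other operations on this page; ReLU passes gradient only where its input is positive.

use train_station::Tensor;

let x = Tensor::from_slice(&[-2.0, 3.0], vec![2])
    .unwrap()
    .with_requires_grad();
let mut loss = x.relu().sum();
loss.backward(None);
// Expect [0.0, 1.0]: the negative input contributes no gradient
println!("{:?}", x.grad().map(|g| g.data()));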
Examples found in repository
examples/neural_networks/feedforward_network.rs (line 55)
54    pub fn forward(input: &Tensor) -> Tensor {
55        input.relu()
56    }
More examples
examples/supervised_training/supervised_bce.rs (line 60)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/RL_training/ppo_continuous.rs (line 77)
68    fn forward(&self, input: &Tensor) -> Tensor {
69        let mut current: Option<Tensor> = None;
70        for (i, layer) in self.layers.iter().enumerate() {
71            let out = if i == 0 {
72                layer.forward(input)
73            } else {
74                layer.forward(current.as_ref().unwrap())
75            };
76            let is_last = i + 1 == self.layers.len();
77            let out = if !is_last { out.relu() } else { out };
78            current = Some(out);
79        }
80        current.expect("MLP has at least one layer")
81    }
82    fn parameters(&mut self) -> Vec<&mut Tensor> {
83        self.layers
84            .iter_mut()
85            .flat_map(|l| l.parameters())
86            .collect()
87    }
88}
89
90// -------------------------------
91// Actor: mean = MLP(state); log_std is a learnable parameter tensor
92// -------------------------------
93
94struct Actor {
95    net: Mlp,
96    log_std: Tensor, // shape [action_dim]
97}
98impl Actor {
99    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
100        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
101        let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
102            .unwrap()
103            .with_requires_grad();
104        Self { net, log_std }
105    }
106    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107        // Returns (mean [B, A], log_std [A])
108        let mean = self.net.forward(state);
109        (
110            mean,
111            self.log_std
112                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113        )
114    }
115    fn parameters(&mut self) -> Vec<&mut Tensor> {
116        let mut ps = self.net.parameters();
117        ps.push(&mut self.log_std);
118        ps
119    }
120}
121
122// -------------------------------
123// Critic: value function V(s)
124// -------------------------------
125
126struct Critic {
127    net: Mlp,
128}
129impl Critic {
130    fn new(state_dim: usize, seed: Option<u64>) -> Self {
131        Self {
132            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
133        }
134    }
135    fn forward(&self, state: &Tensor) -> Tensor {
136        self.net.forward(state)
137    }
138    fn parameters(&mut self) -> Vec<&mut Tensor> {
139        self.net.parameters()
140    }
141}
142
143// -------------------------------
144// Continuous YardEnv (same dynamics as TD3 env)
145// -------------------------------
146
147struct YardEnv {
148    pos: f32,
149    vel: f32,
150    steps: usize,
151    max_steps: usize,
152    rng: SmallRng,
153}
154impl YardEnv {
155    fn new(seed: u64) -> Self {
156        let mut e = Self {
157            pos: 0.0,
158            vel: 0.0,
159            steps: 0,
160            max_steps: 200,
161            rng: SmallRng::new(seed),
162        };
163        e.reset();
164        e
165    }
166    fn reset(&mut self) -> Tensor {
167        self.pos = (self.rng.next_f32() * 1.0) - 0.5;
168        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
169        self.steps = 0;
170        self.state_tensor()
171    }
172    fn state_tensor(&self) -> Tensor {
173        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
174    }
175    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
176        let a = action_value.clamp(-1.0, 1.0);
177        self.vel += 0.1 * a - 0.01 * self.pos;
178        self.pos += self.vel;
179        self.steps += 1;
180        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
181        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
182        (self.state_tensor(), reward, done)
183    }
184}
185
186// -------------------------------
187// Trajectory storage
188// -------------------------------
189
190struct RolloutBatch {
191    states: Vec<f32>,
192    actions: Vec<f32>,
193    log_probs: Vec<f32>,
194    rewards: Vec<f32>,
195    dones: Vec<f32>,
196    values: Vec<f32>,
197    next_states: Vec<f32>,
198    _state_dim: usize,
199}
200impl RolloutBatch {
201    fn new(capacity: usize, state_dim: usize) -> Self {
202        Self {
203            states: Vec::with_capacity(capacity * state_dim),
204            actions: Vec::with_capacity(capacity),
205            log_probs: Vec::with_capacity(capacity),
206            rewards: Vec::with_capacity(capacity),
207            dones: Vec::with_capacity(capacity),
208            values: Vec::with_capacity(capacity),
209            next_states: Vec::with_capacity(capacity * state_dim),
210            _state_dim: state_dim,
211        }
212    }
213
214    #[allow(clippy::too_many_arguments)]
215    fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
216        self.states.extend_from_slice(s);
217        self.actions.push(a);
218        self.log_probs.push(lp);
219        self.rewards.push(r);
220        self.dones.push(d);
221        self.values.push(v);
222        self.next_states.extend_from_slice(s2);
223    }
224
225    fn len(&self) -> usize {
226        self.actions.len()
227    }
228}
examples/RL_training/ppo_discrete.rs (line 69)
60    fn forward(&self, input: &Tensor) -> Tensor {
61        let mut current: Option<Tensor> = None;
62        for (i, layer) in self.layers.iter().enumerate() {
63            let out = if i == 0 {
64                layer.forward(input)
65            } else {
66                layer.forward(current.as_ref().unwrap())
67            };
68            let is_last = i + 1 == self.layers.len();
69            let out = if !is_last { out.relu() } else { out };
70            current = Some(out);
71        }
72        current.expect("MLP has at least one layer")
73    }
74    fn parameters(&mut self) -> Vec<&mut Tensor> {
75        self.layers
76            .iter_mut()
77            .flat_map(|l| l.parameters())
78            .collect()
79    }
80}
81
82// -------------------------------
83// Actor (logits) + Critic
84// -------------------------------
85
86struct Actor {
87    net: Mlp,
88}
89impl Actor {
90    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
91        Self {
92            net: Mlp::new(&[state_dim, 64, 64, action_dim], seed),
93        }
94    }
95    fn forward(&self, state: &Tensor) -> Tensor {
96        self.net.forward(state)
97    } // logits [B, A]
98    fn parameters(&mut self) -> Vec<&mut Tensor> {
99        self.net.parameters()
100    }
101}
102
103struct Critic {
104    net: Mlp,
105}
106impl Critic {
107    fn new(state_dim: usize, seed: Option<u64>) -> Self {
108        Self {
109            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
110        }
111    }
112    fn forward(&self, state: &Tensor) -> Tensor {
113        self.net.forward(state)
114    }
115    fn parameters(&mut self) -> Vec<&mut Tensor> {
116        self.net.parameters()
117    }
118}
119
120// -------------------------------
121// Discrete YardEnv (3 actions -> -1, 0, +1)
122// -------------------------------
123
124struct YardEnv {
125    pos: f32,
126    vel: f32,
127    steps: usize,
128    max_steps: usize,
129    rng: SmallRng,
130}
131impl YardEnv {
132    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];
133    fn new(seed: u64) -> Self {
134        let mut e = Self {
135            pos: 0.0,
136            vel: 0.0,
137            steps: 0,
138            max_steps: 200,
139            rng: SmallRng::new(seed),
140        };
141        e.reset();
142        e
143    }
144    fn reset(&mut self) -> Tensor {
145        self.pos = (self.rng.next_f32() * 1.0) - 0.5;
146        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
147        self.steps = 0;
148        self.state_tensor()
149    }
150    fn state_tensor(&self) -> Tensor {
151        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
152    }
153    fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
154        let a = Self::ACTIONS[action_idx.min(2)];
155        self.vel += 0.1 * a - 0.01 * self.pos;
156        self.pos += self.vel;
157        self.steps += 1;
158        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
159        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
160        (self.state_tensor(), reward, done)
161    }
162}
163
164// -------------------------------
165// Rollout storage
166// -------------------------------
167
168struct RolloutBatch {
169    states: Vec<f32>,
170    actions: Vec<usize>,
171    old_logps: Vec<f32>,
172    rewards: Vec<f32>,
173    dones: Vec<f32>,
174    values: Vec<f32>,
175    next_states: Vec<f32>,
176    _state_dim: usize,
177}
178impl RolloutBatch {
179    fn new(cap: usize, sd: usize) -> Self {
180        Self {
181            states: Vec::with_capacity(cap * sd),
182            actions: Vec::with_capacity(cap),
183            old_logps: Vec::with_capacity(cap),
184            rewards: Vec::with_capacity(cap),
185            dones: Vec::with_capacity(cap),
186            values: Vec::with_capacity(cap),
187            next_states: Vec::with_capacity(cap * sd),
188            _state_dim: sd,
189        }
190    }
191    #[allow(clippy::too_many_arguments)]
192    fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
193        self.states.extend_from_slice(s);
194        self.actions.push(a);
195        self.old_logps.push(lp);
196        self.rewards.push(r);
197        self.dones.push(d);
198        self.values.push(v);
199        self.next_states.extend_from_slice(s2);
200    }
201    fn len(&self) -> usize {
202        self.actions.len()
203    }
204}
205
206// -------------------------------
207// Helpers
208// -------------------------------
209
210#[allow(clippy::too_many_arguments)]
211fn compute_gae(
212    returns_out: &mut [f32],
213    adv_out: &mut [f32],
214    rewards: &[f32],
215    dones: &[f32],
216    values: &[f32],
217    next_values: &[f32],
218    gamma: f32,
219    lam: f32,
220) {
221    let n = rewards.len();
222    let mut gae = 0.0f32;
223    for t in (0..n).rev() {
224        let not_done = 1.0 - dones[t];
225        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
226        gae = delta + gamma * lam * not_done * gae;
227        adv_out[t] = gae;
228        returns_out[t] = gae + values[t];
229    }
230}
231
232fn normalize_in_place(x: &mut [f32], eps: f32) {
233    let n = x.len() as f32;
234    if n <= 1.0 {
235        return;
236    }
237    let mean = x.iter().copied().sum::<f32>() / n;
238    let var = x
239        .iter()
240        .map(|v| {
241            let d = v - mean;
242            d * d
243        })
244        .sum::<f32>()
245        / n;
246    let std = (var + eps).sqrt();
247    for v in x.iter_mut() {
248        *v = (*v - mean) / std;
249    }
250}
251
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253    let mut total_sq = 0.0f32;
254    for p in parameters.iter() {
255        if let Some(g) = p.grad_owned() {
256            for &v in g.data() {
257                total_sq += v * v;
258            }
259        }
260    }
261    let norm = total_sq.sqrt();
262    if norm > max_norm {
263        let scale = max_norm / (norm + eps);
264        for p in parameters.iter_mut() {
265            if let Some(g) = p.grad_owned() {
266                p.set_grad(g.mul_scalar(scale));
267            }
268        }
269    }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285    // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low); // max(ratio, low)
300    high.sub_tensor(&high.sub_tensor(&ge_low).relu()) // min(ge_low, high)
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365            // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/neural_networks/basic_encoder.rs (line 60)
53    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54        let attn = self.mha.forward(input, input, input, attn_mask);
55        let res1 = attn.add_tensor(input);
56
57        // Feed-forward network with ReLU and residual
58        let (b, t, e) = Self::triple(input);
59        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60        let hidden = self.ffn_in.forward(&x2d).relu();
61        let out2d = self.ffn_out.forward(&hidden);
62        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63        out.add_tensor(&res1)
64    }
examples/RL_training/dqn.rs (line 81)
71    fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
72        let mut current: Option<Tensor> = None;
73        for (i, layer) in self.layers.iter().enumerate() {
74            let out = if i == 0 {
75                layer.forward(input)
76            } else {
77                layer.forward(current.as_ref().unwrap())
78            };
79            let is_last = i + 1 == self.layers.len();
80            let out = if !is_last {
81                out.relu()
82            } else if let Some(act) = final_activation {
83                act(&out)
84            } else {
85                out
86            };
87            current = Some(out);
88        }
89        current.expect("MLP has at least one layer")
90    }
impl Tensor

pub fn sigmoid(&self) -> Tensor

Element-wise sigmoid activation function

Computes the sigmoid function for each element: output[i] = 1 / (1 + e^(-self[i]))

Uses a numerically stable implementation that avoids overflow for large-magnitude inputs by taking separate computation paths for positive and negative values.

§Returns

A new tensor with sigmoid applied to each element, values in range (0, 1)

§Performance Characteristics
  • Numerical Stability: Avoids overflow using stable implementation
  • Scalar Implementation: Optimized scalar computation for mathematical accuracy
  • Cache-friendly: Linear memory access patterns
  • Mathematical Accuracy: High-precision exponential and division operations
  • Gradient Tracking: Full gradtrack support with efficient gradient computation
§Implementation Details

Uses a numerically stable implementation:

  • For x ≥ 0: computes 1 / (1 + e^(-x)) to avoid overflow in e^x for large positive x
  • For x < 0: computes e^x / (1 + e^x) to avoid overflow in e^(-x) for large negative x

This ensures the result is always in the range (0, 1) without numerical overflow.
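For intuition, the two branches can be sketched as plain scalar Rust (an illustrative sketch only, not the crate's internal kernel):

// Numerically stable scalar sigmoid (illustrative sketch, not the library's internals)
fn stable_sigmoid(x: f32) -> f32 {
    if x >= 0.0 {
        // exp(-x) <= 1 for x >= 0, so the denominator cannot overflow
        1.0 / (1.0 + (-x).exp())
    } else {
        // exp(x) < 1 for x < 0, so numerator and denominator both stay bounded
        let e = x.exp();
        e / (1.0 + e)
    }
}

Both branches compute the same mathematical function; only the arrangement differs so that every intermediate value stays bounded.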
§Examples
§Basic Sigmoid Activation
use train_station::Tensor;

let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.sigmoid();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 0.26894143).abs() < 1e-6); // sigmoid(-1.0)
assert!((b.get(&[1]) - 0.5).abs() < 1e-6); // sigmoid(0.0)
assert!((b.get(&[2]) - 0.7310586).abs() < 1e-6); // sigmoid(1.0)
§Extreme Values
use train_station::Tensor;

let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
let b = a.sigmoid();
assert_eq!(b.shape().dims(), vec![2]);
assert!(b.get(&[0]) < 1e-4); // sigmoid(-10.0) ≈ 0
assert!(b.get(&[1]) > 0.9999); // sigmoid(10.0) ≈ 1
Examples found in repository
examples/supervised_training/supervised_bce.rs (line 142)
68pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69    println!("=== Supervised FFN Example (XOR) ===");
70
71    // Dataset: XOR (repeat to form a small batch)
72    let inputs: Vec<f32> = vec![
73        0.0, 0.0, // -> 0
74        0.0, 1.0, // -> 1
75        1.0, 0.0, // -> 1
76        1.0, 1.0, // -> 0
77    ];
78    let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80    // Repeat the base patterns to stabilize training
81    let repeats = 64usize; // effective batch = 4 * repeats = 256
82    let mut xs = Vec::with_capacity(repeats * inputs.len());
83    let mut ys = Vec::with_capacity(repeats * targets.len());
84    for _ in 0..repeats {
85        xs.extend_from_slice(&inputs);
86        ys.extend_from_slice(&targets);
87    }
88
89    let batch = xs.len() / 2; // two features
90    let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91    let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93    // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94    let cfg = FeedForwardConfig {
95        input_size: 2,
96        hidden_sizes: vec![32, 32],
97        output_size: 1,
98        use_bias: true,
99    };
100    let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102    // Optimizer and parameter linking
103    let mut opt = Adam::with_learning_rate(1e-3);
104    for p in net.parameters() {
105        opt.add_parameter(p);
106    }
107
108    let epochs = 1000usize;
109    let max_grad_norm = 1.0f32;
110    let mut best_loss = f32::INFINITY;
111    let mut best_acc = 0.0f32;
112
113    for e in 0..epochs {
114        // Zero grads each iteration
115        {
116            let mut params = net.parameters();
117            opt.zero_grad(&mut params);
118        }
119
120        // Forward -> logits; use numerically stable BCE-with-logits for loss
121        let logits = net.forward(&x_t);
122        let mut loss = bce_with_logits(&logits, &y_t);
123        loss.backward(None);
124
125        // Step only params with grads
126        {
127            let params = net.parameters();
128            let mut with_grads: Vec<&mut Tensor> = Vec::new();
129            for p in params {
130                if p.grad_owned().is_some() {
131                    with_grads.push(p);
132                }
133            }
134            if !with_grads.is_empty() {
135                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136                opt.step(&mut with_grads);
137                opt.zero_grad(&mut with_grads);
138            }
139        }
140
141        // Metrics (use sigmoid only for reporting accuracy)
142        let preds = logits.sigmoid();
143        let acc = accuracy(&preds, &y_t);
144        if loss.value() < best_loss {
145            best_loss = loss.value();
146        }
147        if acc > best_acc {
148            best_acc = acc;
149        }
150        if e % 10 == 0 || e + 1 == epochs {
151            println!(
152                "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153                e,
154                loss.value(),
155                acc,
156                best_loss,
157                best_acc
158            );
159        }
160
161        // Clear graphs to avoid stale accumulation across epochs
162        clear_all_graphs_known();
163    }
164
165    // Quick sanity check predictions
166    let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167    let out = net.forward(&test).sigmoid();
168    println!("predictions (approx): {:?}", out.data());
169
170    println!("=== Supervised training finished ===");
171    Ok(())
172}
impl Tensor

pub fn softmax(&self, dim: usize) -> Tensor

Computes softmax activation along the specified dimension

Applies the softmax function along dimension dim, transforming values into probabilities that sum to 1 along that dimension. Uses numerically stable computation to avoid overflow: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))

§Arguments
  • dim - Dimension along which to compute softmax (0-based indexing)
§Returns

A new tensor with softmax applied along the specified dimension. Values are in range (0, 1) and sum to 1 along dim.

§Performance Characteristics
  • Numerical Stability: Avoids overflow using max subtraction technique
  • Scalar Implementation: Optimized scalar computation for mathematical accuracy
  • Cache-friendly: Optimized memory access patterns for dimension operations
  • Mathematical Accuracy: High-precision exponential and division operations
  • GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details

Uses a numerically stable three-pass algorithm:

  1. Max Computation: Find the maximum value along the specified dimension
  2. Exponential Sum: Compute exp(x - max) and sum the results
  3. Normalization: Divide each exp(x - max) by the sum to get probabilities

This approach prevents overflow by subtracting the maximum value before computing exponentials, ensuring numerical stability for any input range.
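A minimal sketch of the three passes over one contiguous row (illustrative only; the actual implementation handles arbitrary dimensions and strides):

// Numerically stable three-pass softmax over a contiguous row (sketch)
fn softmax_row(row: &mut [f32]) {
    // Pass 1: maximum value, for numerical stability
    let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: exponentiate shifted values and accumulate the sum
    let mut sum = 0.0f32;
    for v in row.iter_mut() {
        *v = (*v - max).exp();
        sum += *v;
    }
    // Pass 3: normalize so the row sums to 1
    for v in row.iter_mut() {
        *v /= sum;
    }
}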

§Examples
§Basic Softmax Activation
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.softmax(0);
assert_eq!(b.shape().dims(), vec![3]);

// Verify probabilities sum to 1
let sum = b.get(&[0]) + b.get(&[1]) + b.get(&[2]);
assert!((sum - 1.0).abs() < 1e-6);

// Verify relative ordering is preserved
assert!(b.get(&[0]) < b.get(&[1]));
assert!(b.get(&[1]) < b.get(&[2]));
§2D Softmax Along Different Dimensions
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = a.softmax(0); // Softmax along first dimension
assert_eq!(b.shape().dims(), vec![2, 2]);

// Each column should sum to 1
let col1_sum = b.get(&[0, 0]) + b.get(&[1, 0]);
let col2_sum = b.get(&[0, 1]) + b.get(&[1, 1]);
assert!((col1_sum - 1.0).abs() < 1e-6);
assert!((col2_sum - 1.0).abs() < 1e-6);
§Panics
  • Panics if dim is out of bounds for the tensor’s rank
  • Panics if the dimension size is 0
Examples found in repository
examples/neural_networks/multi_head_attention.rs (line 108)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
More examples
examples/supervised_training/supervised_classification.rs (line 231)
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88    println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90    // Synthetic 2D inputs, 3 classes with linear-ish separations
91    let n = 1200usize;
92    let classes = 3usize;
93    let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94    let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96    // Simple RNG
97    let mut state: u64 = 424242;
98    let mut rand_f32 = || {
99        state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100        ((state >> 16) as u32) as f32 / (u32::MAX as f32)
101    };
102
103    for _ in 0..n {
104        let x1 = rand_f32() * 4.0 - 2.0;
105        let x2 = rand_f32() * 4.0 - 2.0;
106        // Class by quadrant-ish rule with noise
107        let mut c = if x1 + 0.5 * x2 > 0.5 {
108            0
109        } else if x1 - x2 < -0.5 {
110            1
111        } else {
112            2
113        };
114        if rand_f32() < 0.05 {
115            c = (c + 1) % classes;
116        }
117        xs.push(x1);
118        xs.push(x2);
119        ys.push(c);
120    }
121
122    // Normalize inputs per-feature to [-1, 1]
123    let mut min1 = f32::INFINITY;
124    let mut max1 = f32::NEG_INFINITY;
125    let mut min2 = f32::INFINITY;
126    let mut max2 = f32::NEG_INFINITY;
127    for i in (0..xs.len()).step_by(2) {
128        let a = xs[i];
129        let b = xs[i + 1];
130        if a < min1 {
131            min1 = a;
132        }
133        if a > max1 {
134            max1 = a;
135        }
136        if b < min2 {
137            min2 = b;
138        }
139        if b > max2 {
140            max2 = b;
141        }
142    }
143    let rng1 = (max1 - min1).max(1e-8);
144    let rng2 = (max2 - min2).max(1e-8);
145    for i in (0..xs.len()).step_by(2) {
146        let a = xs[i];
147        let b = xs[i + 1];
148        xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149        xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150    }
151
152    // Train/Val split (80/20)
153    let n_train = (n as f32 * 0.8) as usize;
154    let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155    let y_train = ys[..n_train].to_vec();
156    let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157    let y_val = ys[n_train..].to_vec();
158
159    // Model: 2 -> 64 -> 64 -> 3 (logits)
160    let cfg = FeedForwardConfig {
161        input_size: 2,
162        hidden_sizes: vec![64, 64],
163        output_size: classes,
164        use_bias: true,
165    };
166    let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168    // Optimizer
169    let mut opt = Adam::with_learning_rate(1e-3);
170    for p in net.parameters() {
171        opt.add_parameter(p);
172    }
173
174    let epochs = 300usize;
175    let max_grad_norm = 1.0f32;
176    let mut best_val_acc = 0.0f32;
177    let mut best_val_loss = f32::INFINITY;
178
179    for e in 0..epochs {
180        // Zero grads
181        {
182            let mut params = net.parameters();
183            opt.zero_grad(&mut params);
184        }
185
186        // Forward logits
187        let logits = net.forward(&x_train);
188        let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189        loss.backward(None);
190
191        // Step clipped
192        {
193            let params = net.parameters();
194            let mut with_grads: Vec<&mut Tensor> = Vec::new();
195            for p in params {
196                if p.grad_owned().is_some() {
197                    with_grads.push(p);
198                }
199            }
200            if !with_grads.is_empty() {
201                clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202                opt.step(&mut with_grads);
203                opt.zero_grad(&mut with_grads);
204            }
205        }
206
207        // Metrics
208        let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209        let val_logits = net.forward(&x_val);
210        let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211        let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212        if val_acc > best_val_acc {
213            best_val_acc = val_acc;
214        }
215        if val_loss < best_val_loss {
216            best_val_loss = val_loss;
217        }
218
219        if e % 10 == 0 || e + 1 == epochs {
220            println!(
221                "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222                e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223            );
224        }
225
226        clear_all_graphs_known();
227    }
228
229    // Quick sample preds via softmax
230    let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231    let sm = net.forward(&samples).softmax(1);
232    println!("sample class probs: {:?}", sm.data());
233
234    println!("=== Supervised classification finished ===");
235    Ok(())
236}
examples/RL_training/ppo_discrete.rs (line 364)
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365            // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
impl Tensor

pub fn sqrt(&self) -> Tensor

Element-wise square root

Computes the square root for each element: output[i] = sqrt(self[i])

Uses SIMD optimization when available for maximum performance, with automatic fallback to optimized scalar computation for non-SIMD hardware.

§Returns

A new tensor with the square root of each element

§Performance Characteristics
  • SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
  • Scalar Fallback: 4x unrolled scalar implementation for non-SIMD hardware
  • Cache-friendly: Linear memory access patterns
  • Mathematical Accuracy: High-precision square root computation
  • GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details

Automatically selects between SIMD and scalar implementations based on hardware capabilities. The SIMD implementation uses AVX2 vector square-root operations for optimal performance; the scalar implementation uses 4x unrolling for better instruction-level parallelism.
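A hedged sketch of this runtime-dispatch pattern (illustrative only; the crate's actual kernels and unrolling strategy are internal):

// Dispatch between a vectorized and a scalar sqrt kernel at runtime (sketch)
fn sqrt_slice(src: &[f32], dst: &mut [f32]) {
    assert_eq!(src.len(), dst.len());
    // Scalar fallback; a real AVX2 kernel would process 8-lane blocks with
    // _mm256_sqrt_ps behind an unsafe target_feature boundary.
    fn scalar(src: &[f32], dst: &mut [f32]) {
        for (d, &s) in dst.iter_mut().zip(src) {
            *d = s.sqrt();
        }
    }
    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            // The vectorized path would run here; this sketch falls back to scalar.
            scalar(src, dst);
            return;
        }
    }
    scalar(src, dst);
}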

§Examples
§Basic Square Root
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
let b = a.sqrt();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
§Zero and Special Values
use train_station::Tensor;

let a = Tensor::from_slice(&[0.0, 1.0, 16.0], vec![3]).unwrap();
let b = a.sqrt();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // sqrt(0.0) = 0.0
assert_eq!(b.get(&[1]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[2]), 4.0); // sqrt(16.0) = 4.0
§Note

Results are undefined for negative values (may produce NaN)

Examples found in repository
examples/supervised_training/supervised_regression.rs (line 47)
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
More examples
examples/RL_training/dqn.rs (line 324)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
examples/iterators/performance_optimization.rs (line 179)
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163    println!("\n--- Memory Optimization ---");
164
165    // Create a large tensor for memory testing
166    let size = 10000;
167    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168    let tensor = Tensor::from_slice(&data, vec![size])?;
169
170    println!("Processing tensor of size: {}", size);
171
172    // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173    println!("\nPattern 1: Streaming Processing");
174    let chunk_size = 1000;
175    let start = Instant::now();
176    let flattened = tensor.view(vec![size as i32]);
177    let _streamed_result: Tensor = flattened
178        .chunks(chunk_size)
179        .map(|c| c.pow_scalar(2.0).sqrt())
180        .collect_shape(vec![size]);
181    let streamed_time = start.elapsed();
182
183    // Pattern 2: Full processing
184    let start = Instant::now();
185    let _full_result: Tensor = tensor
186        .iter_elements()
187        .map(|elem| elem.pow_scalar(2.0).sqrt())
188        .collect_shape(vec![size]);
189    let full_time = start.elapsed();
190
191    println!("  Streaming time: {:?}", streamed_time);
192    println!("  Full processing time: {:?}", full_time);
193    println!(
194        "  Memory efficiency ratio: {:.2}x",
195        full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196    );
197
198    // Pattern 3: Lazy evaluation with take
199    println!("\nPattern 2: Lazy Evaluation");
200    let start = Instant::now();
201    let lazy_result: Tensor = tensor
202        .iter_elements()
203        .take(1000) // Only process first 1000 elements
204        .map(|elem| elem.pow_scalar(2.0).sqrt())
205        .collect_shape(vec![1000]);
206    let lazy_time = start.elapsed();
207
208    println!("  Lazy processing (1000 elements): {:?}", lazy_time);
209    println!("  Lazy result size: {}", lazy_result.size());
210
211    // Pattern 4: Memory-efficient filtering
212    println!("\nPattern 3: Memory-Efficient Filtering");
213    let start = Instant::now();
214    let filtered_result: Tensor = tensor
215        .iter_elements()
216        .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217        .map(|elem| elem.mul_scalar(2.0))
218        .collect();
219    let filtered_time = start.elapsed();
220
221    println!("  Filtered processing: {:?}", filtered_time);
222    println!(
223        "  Filtered result size: {} (reduced from {})",
224        filtered_result.size(),
225        size
226    );
227
228    Ok(())
229}
230
231/// Demonstrate large-scale processing techniques
232///
233/// Shows how to efficiently process very large datasets using
234/// iterator patterns and optimization strategies.
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236    println!("\n--- Large-Scale Processing ---");
237
238    // Simulate large dataset processing
239    let sizes = vec![10000, 50000, 100000];
240
241    for size in sizes {
242        println!("\nProcessing dataset of size: {}", size);
243
244        // Generate large dataset
245        let data: Vec<f32> = (0..size)
246            .map(|i| {
247                let x = i as f32 / size as f32;
248                x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249            })
250            .collect();
251
252        let tensor = Tensor::from_slice(&data, vec![size])?;
253
254        // Technique 1: Batch processing
255        let batch_size = 1000;
256        let start = Instant::now();
257
258        let mut batch_results = Vec::new();
259        for batch_start in (0..size).step_by(batch_size) {
260            let batch_end = (batch_start + batch_size).min(size);
261            let batch: Tensor = tensor
262                .iter_range(batch_start, batch_end)
263                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264                .collect();
265            batch_results.push(batch);
266        }
267        let batch_time = start.elapsed();
268
269        // Technique 2: Parallel-like processing with stride
270        let start = Instant::now();
271        let stride = 4;
272        let strided_result: Tensor = tensor
273            .iter()
274            .enumerate()
275            .filter(|(i, _)| i % stride == 0)
276            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277            .collect();
278        let strided_time = start.elapsed();
279
280        // Technique 3: Hierarchical processing
281        let start = Instant::now();
282        let coarse: Tensor = tensor
283            .iter()
284            .enumerate()
285            .filter(|(i, _)| i % 10 == 0) // Every 10th element
286            .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287            .collect();
288        let fine: Tensor = tensor
289            .iter()
290            .enumerate()
291            .filter(|(i, _)| i % 10 != 0) // Rest of elements
292            .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293            .collect();
294        let hierarchical_time = start.elapsed();
295
296        // Report performance
297        println!("  Batch processing: {:?}", batch_time);
298        println!("  Strided processing: {:?}", strided_time);
299        println!("  Hierarchical processing: {:?}", hierarchical_time);
300
301        // Memory usage analysis
302        let total_batches = size.div_ceil(batch_size);
303        println!("  Batch count: {}", total_batches);
304        println!("  Strided result size: {}", strided_result.size());
305        println!(
306            "  Hierarchical: coarse={}, fine={}",
307            coarse.size(),
308            fine.size()
309        );
310    }
311
312    Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320    println!("\n--- Optimization Techniques ---");
321
322    let size = 50000;
323    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324    let tensor = Tensor::from_slice(&data, vec![size])?;
325
326    println!("Optimizing processing for size: {}", size);
327
328    // Technique 1: Operation fusion
329    println!("\nTechnique 1: Operation Fusion");
330    let start = Instant::now();
331    let fused_result: Tensor = tensor
332        .iter()
333        .map(|elem| {
334            // Fuse multiple operations into single chain
335            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336        })
337        .collect();
338    let fused_time = start.elapsed();
339
340    // Technique 2: Conditional optimization
341    println!("\nTechnique 2: Conditional Optimization");
342    let start = Instant::now();
343    let conditional_result: Tensor = tensor
344        .iter()
345        .map(|elem| {
346            let val = elem.value();
347            if val < size as f32 / 2.0 {
348                elem.mul_scalar(2.0) // Simple operation for small values
349            } else {
350                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351            }
352        })
353        .collect();
354    let conditional_time = start.elapsed();
355
356    // Technique 3: Cache-friendly processing
357    println!("\nTechnique 3: Cache-Friendly Processing");
358    let start = Instant::now();
359    let cache_friendly_result: Tensor = tensor
360        .iter()
361        .take(1000) // Process in cache-friendly chunks
362        .map(|elem| elem.mul_scalar(2.0))
363        .collect();
364    let cache_friendly_time = start.elapsed();
365
366    // Technique 4: Memory pooling simulation
367    println!("\nTechnique 4: Memory Pooling Simulation");
368    let start = Instant::now();
369    let pooled_result: Tensor = tensor
370        .iter()
371        .enumerate()
372        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373        .map(|(_, elem)| elem.pow_scalar(2.0))
374        .collect();
375    let pooled_time = start.elapsed();
376
377    // Report optimization results
378    println!("  Fused operations: {:?}", fused_time);
379    println!("  Conditional optimization: {:?}", conditional_time);
380    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
381    println!("  Memory pooling simulation: {:?}", pooled_time);
382
383    // Performance analysis
384    let fastest = fused_time
385        .min(conditional_time)
386        .min(cache_friendly_time)
387        .min(pooled_time);
388    println!("  Fastest technique: {:?}", fastest);
389
390    // Memory efficiency analysis
391    println!("  Fused result size: {}", fused_result.size());
392    println!("  Conditional result size: {}", conditional_result.size());
393    println!(
394        "  Cache-friendly result size: {}",
395        cache_friendly_result.size()
396    );
397    println!("  Pooled result size: {}", pooled_result.size());
398
399    // Technique 5: Gradient optimization
400    println!("\nTechnique 5: Gradient Optimization");
401    let grad_tensor = tensor.with_requires_grad();
402    let start = Instant::now();
403
404    let grad_result: Tensor = grad_tensor
405        .iter()
406        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407        .collect();
408
409    let mut loss = grad_result.sum();
410    loss.backward(None);
411    let grad_time = start.elapsed();
412
413    println!("  Gradient computation: {:?}", grad_time);
414    println!(
415        "  Gradient tracking enabled: {}",
416        grad_result.requires_grad()
417    );
418
419    Ok(())
420}
examples/iterators/advanced_patterns.rs (line 191)
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve large negative values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
Source§

impl Tensor

Source

pub fn sub_tensor(&self, other: &Tensor) -> Tensor

Element-wise subtraction of another tensor, with broadcasting support

Performs element-wise subtraction with automatic broadcasting: output[i] = self[i] - other[i]

Broadcasting enables subtraction between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:

  • Dimensions are aligned from the rightmost dimension
  • Dimensions are compatible if they are equal, or one of them is 1
  • Missing dimensions are treated as 1
§Arguments
  • other - Tensor to subtract. Shapes must be broadcast-compatible.
§Returns

A new tensor containing the element-wise difference with broadcast result shape

§Performance Characteristics
  • Fast Path: Optimized for identical shapes to avoid broadcasting overhead
  • SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
  • Broadcasting: Efficient broadcasting for compatible shapes
  • Cache-friendly: Linear memory access patterns
  • GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details

Uses a fast path for identical shapes to avoid broadcasting overhead. For different shapes, performs broadcasting followed by optimized element-wise subtraction. Automatically selects between SIMD and scalar implementations based on hardware capabilities.

§Examples
§Same Shape Subtraction
use train_station::Tensor;

let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
§Broadcasting Subtraction
use train_station::Tensor;

let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
// Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
§Broadcasting a Single-Element Tensor
use train_station::Tensor;

let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
§Panics

Panics if tensor shapes are not broadcast-compatible

Examples found in repository
examples/supervised_training/supervised_regression.rs (line 43)
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43    pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
More examples
examples/supervised_training/supervised_bce.rs (line 65)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/supervised_training/supervised_classification.rs (line 52)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/RL_training/ppo_discrete.rs (line 280)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low); // max(ratio, low)
300    ge_low.sub_tensor(&ge_low.sub_tensor(&high).relu()) // min(x, high) = x - relu(x - high)
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/RL_training/ppo_continuous.rs (line 239)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2π) normalization term
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (sampling is detached so it does not grow the autograd graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    ratio_ge_low.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
examples/getting_started/tensor_basics.rs (line 100)
89fn demonstrate_basic_operations() {
90    println!("\n--- Basic Operations ---");
91
92    let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93    let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95    // Addition
96    let sum = a.add_tensor(&b);
97    println!("A + B: {:?}", sum.data());
98
99    // Subtraction
100    let diff = a.sub_tensor(&b);
101    println!("A - B: {:?}", diff.data());
102
103    // Multiplication
104    let product = a.mul_tensor(&b);
105    println!("A * B: {:?}", product.data());
106
107    // Division
108    let quotient = a.div_tensor(&b);
109    println!("A / B: {:?}", quotient.data());
110
111    // Scalar operations
112    let scalar_add = a.add_scalar(5.0);
113    println!("A + 5.0: {:?}", scalar_add.data());
114
115    let scalar_mul = a.mul_scalar(2.0);
116    println!("A * 2.0: {:?}", scalar_mul.data());
117}
Source

pub fn sub_scalar(&self, scalar: f32) -> Tensor

Element-wise subtraction of a scalar from this tensor

Performs element-wise subtraction of a scalar value: output[i] = self[i] - scalar

§Arguments
  • scalar - The scalar value to subtract from each element
§Returns

A new tensor with the scalar subtracted from each element

§Performance Characteristics
  • SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
  • Scalar Fallback: 4x unrolled scalar implementation for non-SIMD hardware
  • Cache-friendly: Linear memory access patterns
  • Mathematical Accuracy: High-precision subtraction computation
  • GradTrack Support: Full automatic differentiation with efficient gradient computation
§Examples
§Basic Scalar Subtraction
use train_station::Tensor;

let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
let b = a.sub_scalar(2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
§Negative Scalar Subtraction
use train_station::Tensor;

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.sub_scalar(-2.0); // Subtracting negative = adding
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
Examples found in repository
examples/RL_training/dqn.rs (line 325)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
More examples
examples/iterators/advanced_patterns.rs (line 114)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251    println!("\n--- Batch Operations ---");
252
253    // Create a larger dataset for batch processing
254    let size = 100;
255    let data: Vec<f32> = (0..size)
256        .map(|i| {
257            let x = i as f32 / size as f32;
258            x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259        })
260        .collect();
261
262    let tensor = Tensor::from_slice(&data, vec![size])?;
263    println!("Dataset size: {}", tensor.size());
264
265    // Batch processing with windowing (iterator views)
266    println!("\nBatch processing with sliding windows:");
267    let batch_size = 10;
268    let batches: Vec<Tensor> = tensor
269        .iter()
270        .collect::<Vec<_>>()
271        .chunks(batch_size)
272        .map(|chunk| {
273            // Process each batch independently
274            chunk
275                .iter()
276                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277                .collect()
278        })
279        .collect();
280
281    println!(
282        "  Processed {} batches of size {}",
283        batches.len(),
284        batch_size
285    );
286    for (i, batch) in batches.iter().enumerate() {
287        println!(
288            "    Batch {}: mean={:.3}, std={:.3}",
289            i,
290            batch.mean().value(),
291            batch.std().value()
292        );
293    }
294
295    // Parallel-like processing with stride
296    println!("\nStrided processing (every nth element):");
297    let stride = 5;
298    let strided: Tensor = tensor
299        .iter()
300        .enumerate()
301        .filter(|(i, _)| i % stride == 0)
302        .map(|(_, elem)| elem)
303        .collect();
304    println!("  Strided (every {}th): {:?}", stride, strided.data());
305
306    // Hierarchical processing
307    println!("\nHierarchical processing (coarse to fine):");
308    let coarse: Tensor = tensor
309        .iter()
310        .enumerate()
311        .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312        .map(|(_, elem)| elem)
313        .collect();
314
315    let fine: Tensor = tensor
316        .iter()
317        .enumerate()
318        .filter(|(i, _)| i % 4 != 0) // Take the rest
319        .map(|(_, elem)| elem)
320        .collect();
321
322    println!("  Coarse (every 4th): {:?}", coarse.data());
323    println!("  Fine (rest): {:?}", fine.data());
324
325    // Combine coarse and fine with different processing
326    let combined: Tensor = coarse
327        .iter()
328        .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329        .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330        .collect();
331    println!("  Combined: {:?}", combined.data());
332
333    Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}
Source§

impl Tensor

Source

pub fn tanh(&self) -> Tensor

Element-wise hyperbolic tangent activation

Computes hyperbolic tangent for each element: output[i] = tanh(self[i])

The hyperbolic tangent function maps any real number to the range (-1, 1), making it useful as an activation function in neural networks.

§Returns

A new tensor with tanh applied to each element, values in range (-1, 1)

§Performance Characteristics
  • High Precision: Accurate scalar implementation for mathematical validation
  • 4x Unrolling: Optimized scalar operations with instruction-level parallelism
  • Cache-friendly: Linear memory access patterns
  • Numerical Stability: Robust handling of extreme input values
  • GradTrack Support: Full automatic differentiation with efficient gradient computation
§Mathematical Properties
  • Range: Output values are in the range (-1, 1)
  • Symmetry: tanh(-x) = -tanh(x) (odd function)
  • Zero: tanh(0) = 0
  • Gradient: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x) (see the gradient check in the examples below)
§Examples
§Basic Hyperbolic Tangent
use train_station::Tensor;

let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.tanh();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - (-0.7615942)).abs() < 1e-6); // tanh(-1.0)
assert!((b.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
assert!((b.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
§Extreme Values
use train_station::Tensor;

let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
let b = a.tanh();
assert_eq!(b.shape().dims(), vec![2]);
assert!((b.get(&[0]) - (-1.0)).abs() < 1e-6); // tanh(-10.0) ≈ -1
assert!((b.get(&[1]) - 1.0).abs() < 1e-6); // tanh(10.0) ≈ 1
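§Gradient Check
A minimal sketch (not taken from the repository) verifying the gradient formula above; it assumes with_requires_grad, backward, and grad_owned behave as in the training examples elsewhere in this document.

use train_station::Tensor;

// d/dx tanh(x) = 1 - tanh²(x); check numerically at x = 0.5
let x = Tensor::from_slice(&[0.5], vec![1]).unwrap().with_requires_grad();
let mut y = x.tanh();
y.backward(None);
if let Some(g) = x.grad_owned() {
    let expected = 1.0 - 0.5f32.tanh().powi(2); // sech²(0.5)
    assert!((g.data()[0] - expected).abs() < 1e-5);
}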
Examples found in repository
examples/RL_training/td3.rs (line 55)
54fn tanh_bounded(x: &Tensor) -> Tensor {
55    x.tanh()
56}
Source§

impl Tensor

Source

pub fn argmax(&self) -> Tensor

Returns the index of the maximum value across all elements in the tensor

This operation finds the flat index (0-based) of the element with the highest value. If multiple elements have the same maximum value, the index of the first occurrence is returned. The output is a scalar tensor with shape [1] containing the index as a float.

This operation is non-differentiable and the output never requires gradients.

§Returns

A tensor with shape [1] containing the flat index of the maximum value

§Examples
use train_station::Tensor;

// 1D tensor
let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.shape().dims(), vec![1]);
assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
use train_station::Tensor;

// 2D tensor
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
use train_station::Tensor;

// Tied values return first occurrence
let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
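As a quick illustration of the non-differentiability note above (a sketch, assuming with_requires_grad and requires_grad behave as in the struct-level examples):

use train_station::Tensor;

// The output of argmax never tracks gradients, even when the input does
let t = Tensor::from_slice(&[1.0, 5.0, 3.0], vec![3])
    .unwrap()
    .with_requires_grad();
assert!(!t.argmax().requires_grad());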
Source

pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor

Returns the indices of maximum values along a specified dimension

This operation finds the indices of maximum values along the specified dimension. For each slice along the dimension, it returns the index of the maximum value. If multiple elements have the same maximum value, the index of the first occurrence is returned.

The output shape depends on the keepdim parameter:

  • If keepdim is true, the reduced dimension is kept with size 1
  • If keepdim is false, the reduced dimension is removed

This operation is non-differentiable and the output never requires gradients.

§Arguments
  • dim - The dimension along which to find argmax indices (0-based)
  • keepdim - Whether to keep the reduced dimension with size 1
§Returns

A tensor containing the indices of maximum values along the specified dimension

§Panics

Panics if dim is out of bounds for the tensor’s rank or if the dimension size is 0.

§Examples
use train_station::Tensor;

// 2D tensor: [[1.0, 3.0, 2.0],
//             [4.0, 0.0, 5.0]]
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();

// argmax along columns (dim=1)
let col_max_idx = tensor.argmax_dim(1, false);
assert_eq!(col_max_idx.shape().dims(), vec![2]);
assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
use train_station::Tensor;

// argmax along rows (dim=0) with keepdim
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
let row_max_idx = tensor.argmax_dim(0, true);
assert_eq!(row_max_idx.shape().dims(), vec![1, 3]);
assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
use train_station::Tensor;

// 1D tensor edge case
let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
let max_idx = tensor.argmax_dim(0, false);
assert_eq!(max_idx.shape().dims(), vec![1]); // Special case: becomes [1] not []
assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
Source§

impl Tensor

Source

pub fn argmin(&self) -> Tensor

Returns the index of the minimum value in the tensor

This method finds the flat index of the minimum value across all elements in the tensor. The result is a scalar tensor containing the index as a floating-point value. This operation is non-differentiable and the output never requires gradient tracking.

§Returns

A tensor with shape [1] containing the flat index of the minimum value as a f32. If the input tensor is empty, returns 0.0.

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
let min_index = tensor.argmin();
assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
use train_station::Tensor;

// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let min_index = empty_tensor.argmin();
assert_eq!(min_index.get(&[0]), 0.0);
Source

pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor

Returns the indices of minimum values along a specified dimension

This method finds the indices of minimum values along the specified dimension. The result contains the indices where the minimum values occur in that dimension. This operation is non-differentiable and the output never requires gradient tracking.

§Arguments
  • dim - The dimension along which to find minimum indices (0-based)
  • keepdim - Whether to keep the reduced dimension in the output shape
    • If true, the reduced dimension is kept with size 1
    • If false, the reduced dimension is removed from the output shape
§Returns

A tensor containing the indices of minimum values along the specified dimension. The output shape depends on keepdim:

  • If keepdim is true, the reduced dimension has size 1
  • If keepdim is false, the reduced dimension is removed
§Panics
  • If dim is out of bounds for the tensor’s rank
  • If the dimension to reduce has size 0
§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();

// Find minimum indices along dimension 1 (columns), keeping the dimension
let indices = tensor.argmin_dim(1, true);
assert_eq!(indices.shape().dims(), vec![2, 1]);
assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
use train_station::Tensor;

let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();

// Find minimum indices along dimension 1 (columns), removing the dimension
let indices = tensor.argmin_dim(1, false);
assert_eq!(indices.shape().dims(), vec![2]);
assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

// Find minimum index in a 1D tensor
let index = tensor.argmin_dim(0, false);
assert_eq!(index.shape().dims(), vec![1]);
assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
Source§

impl Tensor

Source

pub fn max(&self) -> Tensor

Computes the maximum value over all elements in the tensor

Returns a scalar tensor containing the maximum value. For empty tensors, returns negative infinity. This operation supports gradient tracking through the GradTrack system.

§Returns

A tensor with shape [1] containing the maximum value

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
let max_val = tensor.max();
assert_eq!(max_val.get(&[0]), 5.0);
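
The empty-tensor behavior noted above (negative infinity) mirrors the min example later in these docs:

use train_station::Tensor;

// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let max_val = empty_tensor.max();
assert_eq!(max_val.get(&[0]), f32::NEG_INFINITY);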
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation uses the saved input and output for an efficient backward pass.
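
For intuition, a minimal sketch of that backward pass, assuming the conventional max subgradient (with a unique maximum, the full gradient flows to the maximizing element):

use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4])
    .unwrap()
    .with_requires_grad();
let mut m = x.max();
m.backward(None);
let grad = x.grad_owned().expect("gradient should exist");
assert_eq!(grad.get(&[1]), 1.0); // 5.0 is the unique maximum
assert_eq!(grad.get(&[0]), 0.0); // non-maximal elements receive zero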

Source

pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the maximum value over specified dimensions

Reduces the tensor along the specified dimensions by computing the maximum value in each reduction group. The keepdim parameter determines whether reduced dimensions are kept with size 1 or removed entirely.

§Arguments
  • dims - Dimensions to reduce over (must be valid for the tensor’s rank)
  • keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns

A tensor with the specified dimensions reduced

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Max over columns (dim 1), keeping dimensions
let max_cols = tensor.max_dims(&[1], true);
assert_eq!(max_cols.shape().dims(), vec![2, 1]);
assert_eq!(max_cols.get(&[0, 0]), 3.0);
assert_eq!(max_cols.get(&[1, 0]), 6.0);

// Max over rows (dim 0), removing dimensions
let max_rows = tensor.max_dims(&[0], false);
assert_eq!(max_rows.shape().dims(), vec![3]);
assert_eq!(max_rows.get(&[0]), 4.0);
assert_eq!(max_rows.get(&[1]), 5.0);
assert_eq!(max_rows.get(&[2]), 6.0);
§Panics

Panics if:

  • dims is empty
  • Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation preserves the original input shape and handles broadcasting correctly.

Examples found in repository
examples/supervised_training/supervised_classification.rs (line 51)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/RL_training/ppo_discrete.rs (line 279)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
Source§

impl Tensor

Source

pub fn mean(&self) -> Tensor

Computes the arithmetic mean of all elements in the tensor

This method calculates the average value across all tensor elements by summing all values and dividing by the total number of elements. The result is a scalar tensor containing the mean value. This operation supports gradient tracking through the GradTrack system.

§Returns

A tensor with shape [1] containing the arithmetic mean of all elements. For empty tensors, returns 0.0 as a safe default.

§Performance Characteristics
  • Linear Time: O(n) complexity for computing the sum
  • Memory Efficient: Single pass through tensor data with SIMD-optimized accumulation
  • Numerical Stability: Uses direct accumulation for typical tensor sizes
  • Edge Case Handling: Returns 0.0 for empty tensors
§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let mean_val = tensor.mean();
assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
use train_station::Tensor;

// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let mean_val = empty_tensor.mean();
assert_eq!(mean_val.get(&[0]), 0.0);
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation distributes the gradient equally across all input elements.
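
For example, with four elements each input receives a gradient of 1/4:

use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
    .unwrap()
    .with_requires_grad();
let mut m = x.mean();
m.backward(None);
let grad = x.grad_owned().expect("gradient should exist");
assert_eq!(grad.get(&[0]), 0.25); // 1/n with n = 4
assert_eq!(grad.get(&[3]), 0.25);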

Examples found in repository
examples/supervised_training/supervised_regression.rs (line 43)
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43    pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47    mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
examples/RL_training/dqn.rs (line 326)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
examples/supervised_training/supervised_bce.rs (line 65)
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60    let relu_z = logits.relu();
61    let zy = logits.mul_tensor(targets);
62    // |z| = relu(z) + relu(-z)
63    let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64    let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65    relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
examples/supervised_training/supervised_classification.rs (line 58)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
examples/neural_networks/basic_encoder.rs (line 94)
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74    println!("=== Basic Encoder Example ===");
75
76    let batch = 2usize;
77    let seq = 6usize;
78    let embed = 32usize;
79    let heads = 4usize;
80
81    let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82    let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84    // Example: no mask (set Some(mask) to use masking)
85    let out = enc.forward(&input, None);
86    println!("Output shape: {:?}", out.shape().dims());
87
88    // Verify gradients/optimization
89    let mut opt = Adam::with_learning_rate(0.01);
90    let mut params = enc.parameters();
91    for p in &params {
92        opt.add_parameter(p);
93    }
94    let mut loss = out.mean();
95    loss.backward(None);
96    opt.step(&mut params);
97    opt.zero_grad(&mut params);
98    println!("Loss: {:.6}", loss.value());
99    println!("=== Done ===");
100    Ok(())
101}
examples/getting_started/tensor_basics.rs (line 194)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
Source

pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the arithmetic mean over specified dimensions

This method calculates the mean value along the specified dimensions by first computing the sum over those dimensions and then dividing by the product of the reduced dimension sizes. The keepdim parameter determines whether reduced dimensions are kept with size 1 or removed entirely.

§Arguments
  • dims - Dimensions to reduce over (must be valid for the tensor’s rank)
  • keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns

A tensor with the specified dimensions reduced by computing the mean. The output shape depends on keepdim:

  • If keepdim is true, reduced dimensions have size 1
  • If keepdim is false, reduced dimensions are removed
§Performance Characteristics
  • Efficient Implementation: Uses sum_dims followed by scalar multiplication (see the sketch after this list)
  • Memory Optimized: Leverages existing sum reduction for optimal performance
  • Shape Computation: Fast output shape calculation with dimension preservation
  • Numerical Stability: Maintains precision through direct computation
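
A minimal equivalence sketch of that composition, using only the sum_dims and mul_scalar operations documented elsewhere on this page:

use train_station::Tensor;

let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let direct = t.mean_dims(&[1], true);
let via_sum = t.sum_dims(&[1], true).mul_scalar(1.0 / 3.0); // scale by 1/reduced_size
assert!((direct.get(&[0, 0]) - via_sum.get(&[0, 0])).abs() < 1e-6);
assert!((direct.get(&[1, 0]) - via_sum.get(&[1, 0])).abs() < 1e-6);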
§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Mean over columns (dim 1), keeping dimensions
let mean_cols = tensor.mean_dims(&[1], true);
assert_eq!(mean_cols.shape().dims(), vec![2, 1]);
assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Mean over rows (dim 0), removing dimensions
let mean_rows = tensor.mean_dims(&[0], false);
assert_eq!(mean_rows.shape().dims(), vec![3]);
assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

// Mean over multiple dimensions
let mean_all = tensor.mean_dims(&[0, 1], false);
assert_eq!(mean_all.shape().dims(), vec![1]);
assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
§Panics

Panics if:

  • dims is empty
  • Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation preserves the original input shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
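
A hedged sketch of that backward pass (reducing to per-row means over 3 columns, so each element should receive 1/3 of the upstream gradient):

use train_station::Tensor;

let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3])
    .unwrap()
    .with_requires_grad();
// Sum the per-row means to a scalar so backward(None) applies
let mut loss = x.mean_dims(&[1], false).sum();
loss.backward(None);
let grad = x.grad_owned().expect("gradient should exist");
assert!((grad.get(&[0, 0]) - 1.0 / 3.0).abs() < 1e-6);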

Source§

impl Tensor

Source

pub fn min(&self) -> Tensor

Computes the minimum value over all elements in the tensor

Returns a scalar tensor containing the minimum value. For empty tensors, returns positive infinity. This operation supports gradient tracking through the GradTrack system.

§Returns

A tensor with shape [1] containing the minimum value

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
let min_val = tensor.min();
assert_eq!(min_val.get(&[0]), 1.0);
use train_station::Tensor;

// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let min_val = empty_tensor.min();
assert_eq!(min_val.get(&[0]), f32::INFINITY);
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation uses the saved input and output for an efficient backward pass.

Source

pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the minimum value over specified dimensions

Reduces the tensor along the specified dimensions by computing the minimum value in each reduction group. The keepdim parameter determines whether reduced dimensions are kept with size 1 or removed entirely.

§Arguments
  • dims - Dimensions to reduce over (must be valid for the tensor’s rank)
  • keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns

A tensor with the specified dimensions reduced

§Examples
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Min over columns (dim 1), keeping dimensions
let min_cols = tensor.min_dims(&[1], true);
assert_eq!(min_cols.shape().dims(), vec![2, 1]);
assert_eq!(min_cols.get(&[0, 0]), 1.0);
assert_eq!(min_cols.get(&[1, 0]), 4.0);
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();

// Min over rows (dim 0), removing dimensions
let min_rows = tensor.min_dims(&[0], false);
assert_eq!(min_rows.shape().dims(), vec![3]);
assert_eq!(min_rows.get(&[0]), 1.0);
assert_eq!(min_rows.get(&[1]), 2.0);
assert_eq!(min_rows.get(&[2]), 3.0);
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();

// Min over multiple dimensions
let min_all = tensor.min_dims(&[0, 1], false);
assert_eq!(min_all.shape().dims(), vec![1]);
assert_eq!(min_all.get(&[0]), 1.0);
§Panics

Panics if:

  • dims is empty
  • Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support

When requires_grad is true, this operation is tracked for automatic differentiation. The gradient computation preserves the original input shape and handles broadcasting correctly.

Source§

impl Tensor

Source

pub fn norm(&self) -> Tensor

Computes the L2 norm (Euclidean norm) over all elements

The L2 norm is calculated as sqrt(sum(x²)) where x represents each element in the tensor. This operation reduces the tensor to a scalar tensor with shape [1].

§Returns

A scalar tensor containing the L2 norm value

§Examples
use train_station::Tensor;

// Basic L2 norm calculation
let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
let norm = tensor.norm();
assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
use train_station::Tensor;

// L2 norm of a larger tensor
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let norm = tensor.norm();
// sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
let expected = 204.0_f32.sqrt();
assert!((norm.get(&[0]) - expected).abs() < 1e-5);
§Performance

Uses an optimized contiguous-tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration.

Examples found in repository
examples/getting_started/tensor_basics.rs (line 197)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
examples/optimizers/adam_configurations.rs (line 370)
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318    // Create training data
319    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322    // Create model parameters
323    let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326    // Create optimizer with custom configuration
327    let adam_config = AdamConfig {
328        learning_rate: config.learning_rate,
329        beta1: config.beta1,
330        beta2: config.beta2,
331        eps: 1e-8,
332        weight_decay: config.weight_decay,
333        amsgrad: false,
334    };
335
336    let mut optimizer = Adam::with_config(adam_config);
337    optimizer.add_parameter(&weight);
338    optimizer.add_parameter(&bias);
339
340    // Training loop
341    let mut losses = Vec::new();
342    let mut convergence_epoch = config.epochs;
343
344    for epoch in 0..config.epochs {
345        // Forward pass
346        let y_pred = x_data.matmul(&weight) + &bias;
347        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349        // Backward pass
350        loss.backward(None);
351
352        // Optimizer step
353        optimizer.step(&mut [&mut weight, &mut bias]);
354        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356        let loss_value = loss.value();
357        losses.push(loss_value);
358
359        // Check for convergence (loss < 0.01)
360        if loss_value < 0.01 && convergence_epoch == config.epochs {
361            convergence_epoch = epoch;
362        }
363    }
364
365    Ok(TrainingStats {
366        config,
367        final_loss: losses[losses.len() - 1],
368        loss_history: losses,
369        convergence_epoch,
370        weight_norm: weight.norm().value(),
371    })
372}
examples/getting_started/optimizer_basics.rs (line 214)
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177    println!("\n--- Advanced Training Patterns ---");
178
179    // Create a more complex model
180    let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181    let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183    // Create optimizer with different learning rate
184    let mut optimizer = Adam::with_learning_rate(0.005);
185    optimizer.add_parameter(&weight);
186    optimizer.add_parameter(&bias);
187
188    // Create training data: y = 2*x + [1, 3]
189    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190    let y_true = Tensor::from_slice(
191        &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192        vec![5, 2],
193    )
194    .unwrap();
195
196    println!("Advanced training with monitoring:");
197    println!("  Initial learning rate: {}", optimizer.learning_rate());
198
199    // Training loop with monitoring
200    let num_epochs = 50;
201    let mut losses = Vec::new();
202    let mut weight_norms = Vec::new();
203    let mut gradient_norms = Vec::new();
204
205    for epoch in 0..num_epochs {
206        // Forward pass
207        let y_pred = x_data.matmul(&weight) + &bias;
208        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210        // Backward pass
211        loss.backward(None);
212
213        // Compute gradient norm before optimizer step
214        let gradient_norm = weight.grad_owned().unwrap().norm();
215
216        // Optimizer step
217        optimizer.step(&mut [&mut weight, &mut bias]);
218        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220        // Learning rate scheduling: reduce every 10 epochs
221        if epoch > 0 && epoch % 10 == 0 {
222            let current_lr = optimizer.learning_rate();
223            let new_lr = current_lr * 0.5;
224            optimizer.set_learning_rate(new_lr);
225            println!(
226                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227                epoch, current_lr, new_lr
228            );
229        }
230
231        // Record metrics
232        losses.push(loss.value());
233        weight_norms.push(weight.norm().value());
234        gradient_norms.push(gradient_norm.value());
235
236        // Print detailed progress
237        if epoch % 10 == 0 || epoch == num_epochs - 1 {
238            println!(
239                "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240                epoch,
241                loss.value(),
242                weight.norm().value(),
243                gradient_norm.value()
244            );
245        }
246    }
247
248    println!("Final learning rate: {}", optimizer.learning_rate());
249
250    // Analyze training progression
251    let initial_loss = losses[0];
252    let final_loss = losses[losses.len() - 1];
253    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255    println!("\nTraining Analysis:");
256    println!("  Initial loss: {:.6}", initial_loss);
257    println!("  Final loss: {:.6}", final_loss);
258    println!("  Loss reduction: {:.1}%", loss_reduction);
259    println!("  Final weight norm: {:.6}", weight.norm().value());
260    println!("  Final bias: {:?}", bias.data());
261
262    Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267    println!("\n--- Learning Rate Scheduling ---");
268
269    // Create simple model
270    let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273    // Create optimizer with high initial learning rate
274    let mut optimizer = Adam::with_learning_rate(0.1);
275    optimizer.add_parameter(&weight);
276    optimizer.add_parameter(&bias);
277
278    // Simple data
279    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280    let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282    println!("Initial learning rate: {}", optimizer.learning_rate());
283
284    // Training loop with learning rate scheduling
285    let num_epochs = 50;
286    let mut losses = Vec::new();
287
288    for epoch in 0..num_epochs {
289        // Forward pass
290        let y_pred = x_data.matmul(&weight) + &bias;
291        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293        // Backward pass
294        loss.backward(None);
295
296        // Optimizer step
297        optimizer.step(&mut [&mut weight, &mut bias]);
298        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300        // Learning rate scheduling: reduce every 10 epochs
301        if epoch > 0 && epoch % 10 == 0 {
302            let current_lr = optimizer.learning_rate();
303            let new_lr = current_lr * 0.5;
304            optimizer.set_learning_rate(new_lr);
305            println!(
306                "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307                epoch, current_lr, new_lr
308            );
309        }
310
311        losses.push(loss.value());
312
313        // Print progress
314        if epoch % 10 == 0 || epoch == num_epochs - 1 {
315            println!(
316                "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317                epoch,
318                loss.value(),
319                optimizer.learning_rate()
320            );
321        }
322    }
323
324    println!("Final learning rate: {}", optimizer.learning_rate());
325
326    Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331    println!("\n--- Training Monitoring ---");
332
333    // Create model
334    let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335    let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337    // Create optimizer
338    let mut optimizer = Adam::with_learning_rate(0.01);
339    optimizer.add_parameter(&weight);
340    optimizer.add_parameter(&bias);
341
342    // Training data
343    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344    let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346    // Training loop with comprehensive monitoring
347    let num_epochs = 30;
348    let mut losses = Vec::new();
349    let mut weight_history = Vec::new();
350    let mut bias_history = Vec::new();
351
352    for epoch in 0..num_epochs {
353        // Forward pass
354        let y_pred = x_data.matmul(&weight) + &bias;
355        let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357        // Backward pass
358        loss.backward(None);
359
360        // Optimizer step
361        optimizer.step(&mut [&mut weight, &mut bias]);
362        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364        // Record history
365        losses.push(loss.value());
366        weight_history.push(weight.value());
367        bias_history.push(bias.value());
368
369        // Print detailed monitoring
370        if epoch % 5 == 0 || epoch == num_epochs - 1 {
371            println!(
372                "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373                epoch,
374                loss.value(),
375                weight.value(),
376                bias.value()
377            );
378        }
379    }
380
381    // Analyze training progression
382    println!("\nTraining Analysis:");
383    println!("  Initial loss: {:.6}", losses[0]);
384    println!("  Final loss: {:.6}", losses[losses.len() - 1]);
385    println!(
386        "  Loss reduction: {:.1}%",
387        (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388    );
389
390    // Compute statistics
391    let loss_mean = compute_mean(&losses);
392    let loss_std = compute_std(&losses);
393    let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394    let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396    println!("  Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397    println!("  Weight change: {:.6}", weight_change);
398    println!("  Bias change: {:.6}", bias_change);
399    println!("  Final weight norm: {:.6}", weight.norm().value());
400    println!("  Final bias: {:.6}", bias.value());
401
402    Ok(())
403}
Source

pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the L2 norm over specified dimensions

Reduces the tensor along the specified dimensions by computing the L2 norm of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.

§Arguments
  • dims - Dimension indices to reduce over (must be valid for the tensor’s rank)
  • keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns

A tensor with L2 norm computed over the specified dimensions

§Examples
use train_station::Tensor;

// Norm along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
let row_norms = matrix.norm_dims(&[1], true);
assert_eq!(row_norms.shape().dims(), vec![2, 1]);
assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
use train_station::Tensor;

// Norm along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
let col_norms = matrix.norm_dims(&[0], false);
assert_eq!(col_norms.shape().dims(), vec![2]);
assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
use train_station::Tensor;

// Norm over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let norm_all = tensor.norm_dims(&[0, 1], false);
assert_eq!(norm_all.shape().dims(), vec![1]);
// sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
§Panics
  • If dims is empty
  • If any dimension index is out of bounds for the tensor rank
§Performance

Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts.

Source§

impl Tensor

Source

pub fn std(&self) -> Tensor

Computes the standard deviation over all elements

The standard deviation is calculated as sqrt(variance) where variance is the mean of squared differences from the mean. This operation reduces the tensor to a scalar tensor with shape [1].

The implementation uses population standard deviation (divides by n rather than n-1) to match PyTorch’s default behavior.

§Returns

A scalar tensor containing the standard deviation value

§Examples
use train_station::Tensor;

// Basic standard deviation calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let std_dev = tensor.std();
assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
use train_station::Tensor;

// Standard deviation of a larger dataset
let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let std_dev = tensor.std();
// mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
let expected = 5.25_f32.sqrt();
assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
use train_station::Tensor;

// Standard deviation of constant values (should be 0)
let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
let std_dev = tensor.std();
assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
§Performance

Uses an optimized contiguous-tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration. The algorithm performs two passes: first to compute the mean, then to compute the variance.
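
Illustratively, a plain-Rust sketch of that two-pass scheme over a contiguous slice (a sketch, not the actual kernel, which also unrolls its loops):

/// Two-pass population standard deviation: mean first, then variance.
fn std_two_pass(xs: &[f32]) -> f32 {
    let n = xs.len() as f32;
    let mean = xs.iter().sum::<f32>() / n; // pass 1: mean
    let var = xs.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / n; // pass 2: divide by n (population)
    var.sqrt()
}

assert!((std_two_pass(&[1.0, 2.0, 3.0, 4.0]) - 1.118_034).abs() < 1e-5);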

Examples found in repository
examples/iterators/advanced_patterns.rs (line 111)
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88    println!("\n--- Data Processing Pipeline ---");
89
90    // Simulate raw sensor data with noise
91    let raw_data: Vec<f32> = (0..20)
92        .map(|i| {
93            let base = i as f32 * 0.5;
94            let noise = (i % 3) as f32 * 0.1;
95            base + noise
96        })
97        .collect();
98
99    let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100    println!("Raw sensor data: {:?}", tensor.data());
101
102    // Multi-stage processing pipeline
103    println!("\nProcessing pipeline:");
104    println!("1. Normalize data (z-score)");
105    println!("2. Apply smoothing filter");
106    println!("3. Detect outliers");
107    println!("4. Apply feature scaling");
108
109    // Stage 1: Normalization
110    let mean = tensor.mean().value();
111    let std = tensor.std().value();
112    let normalized: Tensor = tensor
113        .iter()
114        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115        .collect();
116    println!(
117        "  Normalized (mean={:.3}, std={:.3}): {:?}",
118        mean,
119        std,
120        normalized.data()
121    );
122
123    // Stage 2: Smoothing (simple moving average)
124    let smoothed: Tensor = normalized
125        .iter()
126        .enumerate()
127        .map(|(i, elem)| {
128            if i == 0 || i == normalized.size() - 1 {
129                elem.clone()
130            } else {
131                // Simple 3-point average
132                let prev = normalized.element_view(i - 1);
133                let next = normalized.element_view(i + 1);
134                elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135            }
136        })
137        .collect();
138    println!("  Smoothed: {:?}", smoothed.data());
139
140    // Stage 3: Outlier detection and removal
141    let outlier_threshold = 2.0;
142    let cleaned: Tensor = smoothed
143        .iter()
144        .filter(|elem| elem.value().abs() < outlier_threshold)
145        .collect();
146    println!(
147        "  Outliers removed (threshold={}): {:?}",
148        outlier_threshold,
149        cleaned.data()
150    );
151
152    // Stage 4: Feature scaling to [0, 1] range
153    let min_val = cleaned
154        .iter()
155        .map(|e| e.value())
156        .fold(f32::INFINITY, f32::min);
157    let max_val = cleaned
158        .iter()
159        .map(|e| e.value())
160        .fold(f32::NEG_INFINITY, f32::max);
161    let scaled: Tensor = cleaned
162        .iter()
163        .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164        .collect();
165    println!("  Scaled to [0,1]: {:?}", scaled.data());
166
167    Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175    println!("\n--- Conditional Processing ---");
176
177    // Create data with mixed characteristics
178    let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179    let tensor = Tensor::from_slice(&data, vec![10])?;
180    println!("Input data: {:?}", tensor.data());
181
182    // Conditional transformation based on sign
183    println!("\nConditional transformation (positive/negative handling):");
184    let processed: Tensor = tensor
185        .iter()
186        .map(|elem| {
187            let val = elem.value();
188            if val > 0.0 {
189                elem.pow_scalar(2.0) // Square positive values
190            } else {
191                elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192            }
193        })
194        .collect();
195    println!("  Processed: {:?}", processed.data());
196
197    // Adaptive filtering based on local statistics
198    println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199    let window_size = 3;
200    let adaptive_filtered: Tensor = tensor
201        .iter()
202        .enumerate()
203        .filter(|(i, elem)| {
204            let start = i.saturating_sub(window_size / 2);
205            let end = (i + window_size / 2 + 1).min(tensor.size());
206
207            // Calculate local mean and std
208            let local_values: Vec<f32> = (start..end)
209                .map(|j| tensor.element_view(j).value())
210                .collect();
211
212            let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213            let local_variance = local_values
214                .iter()
215                .map(|v| (v - local_mean).powi(2))
216                .sum::<f32>()
217                / local_values.len() as f32;
218            let local_std = local_variance.sqrt();
219
220            let threshold = local_mean + 2.0 * local_std;
221            elem.value() <= threshold
222        })
223        .map(|(_, elem)| elem)
224        .collect();
225    println!("  Adaptive filtered: {:?}", adaptive_filtered.data());
226
227    // Multi-condition processing
228    println!("\nMulti-condition processing:");
229    let multi_processed: Tensor = tensor
230        .iter()
231        .map(|elem| {
232            let val = elem.value();
233            match () {
234                _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235                _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236                _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237                _ => elem.clone(),                      // Keep others unchanged
238            }
239        })
240        .collect();
241    println!("  Multi-condition: {:?}", multi_processed.data());
242
243    Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251    println!("\n--- Batch Operations ---");
252
253    // Create a larger dataset for batch processing
254    let size = 100;
255    let data: Vec<f32> = (0..size)
256        .map(|i| {
257            let x = i as f32 / size as f32;
258            x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259        })
260        .collect();
261
262    let tensor = Tensor::from_slice(&data, vec![size])?;
263    println!("Dataset size: {}", tensor.size());
264
265    // Batch processing with windowing (iterator views)
266    println!("\nBatch processing with sliding windows:");
267    let batch_size = 10;
268    let batches: Vec<Tensor> = tensor
269        .iter()
270        .collect::<Vec<_>>()
271        .chunks(batch_size)
272        .map(|chunk| {
273            // Process each batch independently
274            chunk
275                .iter()
276                .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277                .collect()
278        })
279        .collect();
280
281    println!(
282        "  Processed {} batches of size {}",
283        batches.len(),
284        batch_size
285    );
286    for (i, batch) in batches.iter().enumerate() {
287        println!(
288            "    Batch {}: mean={:.3}, std={:.3}",
289            i,
290            batch.mean().value(),
291            batch.std().value()
292        );
293    }
294
295    // Parallel-like processing with stride
296    println!("\nStrided processing (every nth element):");
297    let stride = 5;
298    let strided: Tensor = tensor
299        .iter()
300        .enumerate()
301        .filter(|(i, _)| i % stride == 0)
302        .map(|(_, elem)| elem)
303        .collect();
304    println!("  Strided (every {}th): {:?}", stride, strided.data());
305
306    // Hierarchical processing
307    println!("\nHierarchical processing (coarse to fine):");
308    let coarse: Tensor = tensor
309        .iter()
310        .enumerate()
311        .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312        .map(|(_, elem)| elem)
313        .collect();
314
315    let fine: Tensor = tensor
316        .iter()
317        .enumerate()
318        .filter(|(i, _)| i % 4 != 0) // Take the rest
319        .map(|(_, elem)| elem)
320        .collect();
321
322    println!("  Coarse (every 4th): {:?}", coarse.data());
323    println!("  Fine (rest): {:?}", fine.data());
324
325    // Combine coarse and fine with different processing
326    let combined: Tensor = coarse
327        .iter()
328        .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329        .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330        .collect();
331    println!("  Combined: {:?}", combined.data());
332
333    Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341    println!("\n--- Real-world Scenarios ---");
342
343    // Scenario 1: Time series analysis
344    println!("\nScenario 1: Time Series Analysis");
345    let time_series: Vec<f32> = (0..24)
346        .map(|hour| {
347            let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348            base + (hour % 3) as f32 * 2.0 // Add some noise
349        })
350        .collect();
351
352    let series = Tensor::from_slice(&time_series, vec![24])?;
353    println!("  Time series (24 hours): {:?}", series.data());
354
355    // Calculate moving average with view-based iteration
356    let window_size = 3;
357    let moving_avg: Tensor = series
358        .iter()
359        .enumerate()
360        .map(|(i, _)| {
361            let start = i.saturating_sub(window_size / 2);
362            let end = (i + window_size / 2 + 1).min(series.size());
363            let window = series.iter_range(start, end);
364            window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365        })
366        .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367        .collect();
368    println!(
369        "  Moving average (window={}): {:?}",
370        window_size,
371        moving_avg.data()
372    );
373
374    // Inference pipeline with NoGrad + streaming
375    println!("\nInference pipeline (NoGrad + streaming)");
376    let features = Tensor::from_slice(
377        &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378        vec![6, 8],
379    )?;
380    let fast = with_no_grad(|| {
381        // Stream values directly, apply light affine, and collect back to same shape
382        features
383            .data()
384            .iter()
385            .copied()
386            .map(|x| 0.75 * x + 0.1)
387            .collect_shape(vec![6, 8])
388    });
389    println!(
390        "  NoGrad streamed transform shape: {:?}",
391        fast.shape().dims()
392    );
393
394    // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395    let per_row: Tensor = features
396        .iter()
397        .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398        .collect_shape(vec![6, 8]);
399    println!("  Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401    // Scenario 2: Feature engineering
402    println!("\nScenario 2: Feature Engineering");
403    let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404    println!("  Original features: {:?}", features.data());
405
406    // Create polynomial features
407    let poly_features: Tensor = features
408        .iter()
409        .flat_map(|elem| {
410            vec![
411                elem.clone(),         // x^1
412                elem.pow_scalar(2.0), // x^2
413                elem.pow_scalar(3.0), // x^3
414            ]
415        })
416        .collect();
417    println!(
418        "  Polynomial features (x, x^2, x^3): {:?}",
419        poly_features.data()
420    );
421
422    // Scenario 3: Data augmentation
423    println!("\nScenario 3: Data Augmentation");
424    let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425    println!("  Original data: {:?}", original.data());
426
427    // Augment with noise and scaling
428    let augmented: Tensor = original
429        .iter()
430        .flat_map(|elem| {
431            vec![
432                elem.clone(),         // Original
433                elem.add_scalar(0.1), // Add noise
434                elem.sub_scalar(0.1), // Subtract noise
435                elem.mul_scalar(1.1), // Scale up
436                elem.mul_scalar(0.9), // Scale down
437            ]
438        })
439        .collect();
440    println!("  Augmented data: {:?}", augmented.data());
441
442    // Scenario 4: Statistical analysis
443    println!("\nScenario 4: Statistical Analysis");
444    let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445    println!("  Sample data: {:?}", sample_data.data());
446
447    // Calculate various statistics
448    let mean = sample_data.mean().value();
449    let std = sample_data.std().value();
450    let min = sample_data
451        .iter()
452        .map(|e| e.value())
453        .fold(f32::INFINITY, f32::min);
454    let max = sample_data
455        .iter()
456        .map(|e| e.value())
457        .fold(f32::NEG_INFINITY, f32::max);
458
459    // Z-score normalization
460    let z_scores: Tensor = sample_data
461        .iter()
462        .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463        .collect();
464
465    println!(
466        "  Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467        mean, std, min, max
468    );
469    println!("  Z-scores: {:?}", z_scores.data());
470
471    Ok(())
472}
Source

pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the standard deviation over specified dimensions

Reduces the tensor along the specified dimensions by computing the standard deviation of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.

Uses population standard deviation (divides by n rather than n-1) to match PyTorch’s default behavior.

§Arguments
  • dims - Dimension indices to reduce over (must be valid for the tensor’s rank)
  • keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns

A tensor with standard deviation computed over the specified dimensions

§Examples
use train_station::Tensor;

// Standard deviation along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
let row_stds = matrix.std_dims(&[1], true);
assert_eq!(row_stds.shape().dims(), vec![2, 1]);
assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
use train_station::Tensor;

// Standard deviation along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_stds = matrix.std_dims(&[0], false);
assert_eq!(col_stds.shape().dims(), vec![2]);
// std([1, 3]) = 1.0, std([2, 4]) = 1.0
assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
use train_station::Tensor;

// Standard deviation over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let std_all = tensor.std_dims(&[0, 1], false);
assert_eq!(std_all.shape().dims(), vec![1]);
// std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
§Panics
  • If dims is empty
  • If any dimension index is out of bounds for the tensor rank
  • If the reduced size is 0 (invalid for standard deviation calculation)
§Performance

Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts. The algorithm performs two passes: first to compute means, then to compute variances.

Source§

impl Tensor

Source

pub fn sum(&self) -> Tensor

Returns the sum of all elements in the tensor

This operation computes the sum of all elements across all dimensions, reducing the tensor to a scalar value. The output is a tensor with shape [1] containing the sum as a float.

When requires_grad is enabled, this operation supports automatic gradient tracking through the GradTrack system.

§Returns

A tensor with shape [1] containing the sum of all elements

§Examples
use train_station::Tensor;

// Basic sum calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let total = tensor.sum();
assert_eq!(total.shape().dims(), vec![1]);
assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
use train_station::Tensor;

// Sum with gradient tracking
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
    .unwrap()
    .with_requires_grad();
let mut total = tensor.sum();
total.backward(None);
let grad = tensor.grad_owned().expect("gradient should exist");
// Gradient should be [1.0, 1.0, 1.0] for each element
assert_eq!(grad.get(&[0]), 1.0);
assert_eq!(grad.get(&[1]), 1.0);
assert_eq!(grad.get(&[2]), 1.0);
use train_station::Tensor;

// Sum of empty tensor
let tensor = Tensor::new(vec![0]);
let total = tensor.sum();
assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
§Performance

Uses an optimized contiguous-tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration.
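
As an illustration of the 4x unrolling idea, a sketch with four independent accumulators (not the actual SIMD kernel):

/// Sum with four accumulators so additions can proceed independently.
fn sum_unrolled(xs: &[f32]) -> f32 {
    let mut chunks = xs.chunks_exact(4);
    let mut acc = [0.0f32; 4];
    for c in &mut chunks {
        acc[0] += c[0];
        acc[1] += c[1];
        acc[2] += c[2];
        acc[3] += c[3];
    }
    let tail: f32 = chunks.remainder().iter().sum();
    acc[0] + acc[1] + acc[2] + acc[3] + tail
}

assert_eq!(sum_unrolled(&[1.0, 2.0, 3.0, 4.0, 5.0]), 15.0);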

Examples found in repository
examples/supervised_training/supervised_regression.rs (line 54)
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51    // R^2 = 1 - SS_res / SS_tot
52    let y = target;
53    let y_mean = y.mean();
54    let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55    let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56    let ss_res_v = ss_res.value();
57    let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58    1.0 - (ss_res_v / ss_tot_v)
59}
examples/getting_started/tensor_basics.rs (line 191)
179fn demonstrate_utility_functions() {
180    println!("\n--- Utility Functions ---");
181
182    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184    // Basic properties
185    println!("Shape: {:?}", tensor.shape().dims());
186    println!("Size: {}", tensor.size());
187    println!("Is contiguous: {}", tensor.is_contiguous());
188    println!("Device: {:?}", tensor.device());
189
190    // Mathematical operations
191    let sum = tensor.sum();
192    println!("Sum: {}", sum.value());
193
194    let mean = tensor.mean();
195    println!("Mean: {}", mean.value());
196
197    let norm = tensor.norm();
198    println!("Norm: {}", norm.value());
199
200    // Device placement
201    let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202    println!(
203        "CPU tensor: shape {:?}, device: {:?}",
204        cpu_tensor.shape().dims(),
205        cpu_tensor.device()
206    );
207}
examples/iterators/element_iteration.rs (line 195)
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176    println!("\n--- Gradient Tracking ---");
177
178    // Create a tensor with gradient tracking enabled
179    let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180    println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182    // Perform element-wise operations through iteration
183    let result: Tensor = tensor
184        .iter()
185        .map(|elem| {
186            // Apply a complex transformation: (x^2 + 1) * 2
187            elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188        })
189        .collect();
190
191    println!("Result tensor: {:?}", result.data());
192    println!("Result requires_grad: {}", result.requires_grad());
193
194    // Compute gradients
195    let mut loss = result.sum();
196    loss.backward(None);
197
198    println!("Loss: {:.6}", loss.value());
199    println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201    Ok(())
202}
examples/getting_started/serialization_basics.rs (line 137)
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110    println!("\n--- Optimizer Serialization ---");
111
112    // Create an optimizer with some parameters
113    let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114    let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116    let config = AdamConfig {
117        learning_rate: 0.001,
118        beta1: 0.9,
119        beta2: 0.999,
120        eps: 1e-8,
121        weight_decay: 0.0,
122        amsgrad: false,
123    };
124
125    let mut optimizer = Adam::with_config(config);
126    optimizer.add_parameter(&weight);
127    optimizer.add_parameter(&bias);
128
129    println!(
130        "Created optimizer with {} parameters",
131        optimizer.parameter_count()
132    );
133    println!("Learning rate: {}", optimizer.learning_rate());
134
135    // Simulate some training steps
136    for _ in 0..3 {
137        let mut loss = weight.sum() + bias.sum();
138        loss.backward(None);
139        optimizer.step(&mut [&mut weight, &mut bias]);
140        optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141    }
142
143    // Save optimizer state
144    let optimizer_path = "temp_optimizer.json";
145    optimizer.save_json(optimizer_path)?;
146    println!("Saved optimizer to: {}", optimizer_path);
147
148    // Load optimizer state
149    let loaded_optimizer = Adam::load_json(optimizer_path)?;
150    println!(
151        "Loaded optimizer with {} parameters",
152        loaded_optimizer.parameter_count()
153    );
154    println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156    // Verify optimizer state
157    assert_eq!(
158        optimizer.parameter_count(),
159        loaded_optimizer.parameter_count()
160    );
161    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162    println!("Optimizer serialization verification: PASSED");
163
164    Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169    println!("\n--- Format Comparison ---");
170
171    // Create a larger tensor for comparison
172    let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174    // Save in both formats
175    tensor.save_json("temp_comparison.json")?;
176    tensor.save_binary("temp_comparison.bin")?;
177
178    // Compare file sizes
179    let json_size = fs::metadata("temp_comparison.json")?.len();
180    let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182    println!("JSON file size: {} bytes", json_size);
183    println!("Binary file size: {} bytes", binary_size);
184    println!(
185        "Compression ratio: {:.2}x",
186        json_size as f64 / binary_size as f64
187    );
188
189    // Load and verify both formats
190    let json_tensor = Tensor::load_json("temp_comparison.json")?;
191    let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193    assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194    assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195    assert_eq!(tensor.data(), json_tensor.data());
196    assert_eq!(tensor.data(), binary_tensor.data());
197
198    println!("Format comparison verification: PASSED");
199
200    Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205    println!("\n--- Model Checkpointing ---");
206
207    // Create a simple model (weights and bias)
208    let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209    let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211    // Create optimizer
212    let mut optimizer = Adam::with_learning_rate(0.01);
213    optimizer.add_parameter(&weights);
214    optimizer.add_parameter(&bias);
215
216    println!("Initial weights: {:?}", weights.data());
217    println!("Initial bias: {:?}", bias.data());
218
219    // Simulate training
220    for epoch in 0..5 {
221        let mut loss = weights.sum() + bias.sum();
222        loss.backward(None);
223        optimizer.step(&mut [&mut weights, &mut bias]);
224        optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226        if epoch % 2 == 0 {
227            // Save checkpoint
228            let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229            fs::create_dir_all(&checkpoint_dir)?;
230
231            weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232            bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233            optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235            println!("Saved checkpoint for epoch {}", epoch);
236        }
237    }
238
239    // Load from checkpoint
240    let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241    let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242    let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244    println!("Loaded weights: {:?}", loaded_weights.data());
245    println!("Loaded bias: {:?}", loaded_bias.data());
246    println!(
247        "Loaded optimizer learning rate: {}",
248        loaded_optimizer.learning_rate()
249    );
250
251    // Verify checkpoint integrity
252    assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253    assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254    assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256    println!("Checkpointing verification: PASSED");
257
258    Ok(())
259}
examples/iterators/performance_optimization.rs (line 409)
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320    println!("\n--- Optimization Techniques ---");
321
322    let size = 50000;
323    let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324    let tensor = Tensor::from_slice(&data, vec![size])?;
325
326    println!("Optimizing processing for size: {}", size);
327
328    // Technique 1: Operation fusion
329    println!("\nTechnique 1: Operation Fusion");
330    let start = Instant::now();
331    let fused_result: Tensor = tensor
332        .iter()
333        .map(|elem| {
334            // Fuse multiple operations into single chain
335            elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336        })
337        .collect();
338    let fused_time = start.elapsed();
339
340    // Technique 2: Conditional optimization
341    println!("\nTechnique 2: Conditional Optimization");
342    let start = Instant::now();
343    let conditional_result: Tensor = tensor
344        .iter()
345        .map(|elem| {
346            let val = elem.value();
347            if val < size as f32 / 2.0 {
348                elem.mul_scalar(2.0) // Simple operation for small values
349            } else {
350                elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351            }
352        })
353        .collect();
354    let conditional_time = start.elapsed();
355
356    // Technique 3: Cache-friendly processing
357    println!("\nTechnique 3: Cache-Friendly Processing");
358    let start = Instant::now();
359    let cache_friendly_result: Tensor = tensor
360        .iter()
361        .take(1000) // Process in cache-friendly chunks
362        .map(|elem| elem.mul_scalar(2.0))
363        .collect();
364    let cache_friendly_time = start.elapsed();
365
366    // Technique 4: Memory pooling simulation
367    println!("\nTechnique 4: Memory Pooling Simulation");
368    let start = Instant::now();
369    let pooled_result: Tensor = tensor
370        .iter()
371        .enumerate()
372        .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373        .map(|(_, elem)| elem.pow_scalar(2.0))
374        .collect();
375    let pooled_time = start.elapsed();
376
377    // Report optimization results
378    println!("  Fused operations: {:?}", fused_time);
379    println!("  Conditional optimization: {:?}", conditional_time);
380    println!("  Cache-friendly processing: {:?}", cache_friendly_time);
381    println!("  Memory pooling simulation: {:?}", pooled_time);
382
383    // Performance analysis
384    let fastest = fused_time
385        .min(conditional_time)
386        .min(cache_friendly_time)
387        .min(pooled_time);
388    println!("  Fastest technique: {:?}", fastest);
389
390    // Memory efficiency analysis
391    println!("  Fused result size: {}", fused_result.size());
392    println!("  Conditional result size: {}", conditional_result.size());
393    println!(
394        "  Cache-friendly result size: {}",
395        cache_friendly_result.size()
396    );
397    println!("  Pooled result size: {}", pooled_result.size());
398
399    // Technique 5: Gradient optimization
400    println!("\nTechnique 5: Gradient Optimization");
401    let grad_tensor = tensor.with_requires_grad();
402    let start = Instant::now();
403
404    let grad_result: Tensor = grad_tensor
405        .iter()
406        .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407        .collect();
408
409    let mut loss = grad_result.sum();
410    loss.backward(None);
411    let grad_time = start.elapsed();
412
413    println!("  Gradient computation: {:?}", grad_time);
414    println!(
415        "  Gradient tracking enabled: {}",
416        grad_result.requires_grad()
417    );
418
419    Ok(())
420}
Source

pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Returns the sum of elements along specified dimensions

This operation computes the sum of elements along the specified dimensions, reducing the tensor while optionally preserving the reduced dimensions as size-1 dimensions.

The output shape depends on the keepdim parameter:

  • If keepdim is true, the reduced dimensions are kept with size 1
  • If keepdim is false, the reduced dimensions are removed

When requires_grad is enabled, this operation supports automatic gradient tracking through the GradTrack system.

§Arguments
  • dims - Slice of dimension indices to sum over (must be valid for tensor rank)
  • keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns

A tensor with sum computed over the specified dimensions

§Panics
  • If dims is empty
  • If any dimension index is out of bounds for the tensor rank
§Examples
use train_station::Tensor;

// Sum over dimension 0 (down each column) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_sums = matrix.sum_dims(&[0], false);
assert_eq!(col_sums.shape().dims(), vec![2]);
assert_eq!(col_sums.get(&[0]), 4.0); // 1 + 3 = 4
assert_eq!(col_sums.get(&[1]), 6.0); // 2 + 4 = 6
use train_station::Tensor;

// Sum over dimension 1 (across each row) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let row_sums = matrix.sum_dims(&[1], true);
assert_eq!(row_sums.shape().dims(), vec![2, 1]);
assert_eq!(row_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
assert_eq!(row_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
use train_station::Tensor;

// Sum over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let total = tensor.sum_dims(&[0, 1], false);
assert_eq!(total.shape().dims(), vec![1]);
assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
use train_station::Tensor;

// Sum with gradient tracking
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
    .unwrap()
    .with_requires_grad();
let mut col_sums = tensor.sum_dims(&[0], false);
col_sums.backward(None);
let grad = tensor.grad_owned().expect("gradient should exist");
// Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
assert_eq!(grad.get(&[0, 0]), 1.0);
assert_eq!(grad.get(&[0, 1]), 1.0);
assert_eq!(grad.get(&[1, 0]), 1.0);
assert_eq!(grad.get(&[1, 1]), 1.0);
§Performance

Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts.
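
To see why coordinate-based iteration handles both layouts, here is a hedged rank-2 sketch that resolves each logical coordinate through explicit strides; plain arrays stand in for the library's shape and stride types:

fn sum_dim0_strided(data: &[f32], shape: [usize; 2], strides: [usize; 2]) -> Vec<f32> {
    let mut out = vec![0.0f32; shape[1]];
    for j in 0..shape[1] {
        for i in 0..shape[0] {
            // The stride arithmetic is what makes this correct for both
            // contiguous (strides [cols, 1]) and transposed (strides [1, rows]) layouts
            out[j] += data[i * strides[0] + j * strides[1]];
        }
    }
    out
}

For the contiguous 2x2 matrix above (strides [2, 1]), this yields [4.0, 6.0], matching sum_dims(&[0], false).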

Examples found in repository
examples/supervised_training/supervised_classification.rs (line 54)
44fn cross_entropy_logits(
45    logits: &Tensor,
46    labels: &[usize],
47    batch: usize,
48    _num_classes: usize,
49) -> Tensor {
50    // log_softmax = logits - logsumexp(logits, dim=1)
51    let max_logits = logits.max_dims(&[1], true);
52    let shifted = logits.sub_tensor(&max_logits);
53    let exp = shifted.exp();
54    let sum_exp = exp.sum_dims(&[1], true);
55    let log_sum_exp = sum_exp.log();
56    let log_softmax = shifted.sub_tensor(&log_sum_exp);
57    let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58    ll.mul_scalar(-1.0).mean()
59}
More examples
examples/RL_training/ppo_discrete.rs (line 282)
273fn log_prob_actions(
274    logits: &Tensor,
275    actions: &[usize],
276    batch: usize,
277    _action_dim: usize,
278) -> Tensor {
279    let max_logits = logits.max_dims(&[1], true); // [B,1]
280    let shifted = logits.sub_tensor(&max_logits);
281    let exp = shifted.exp();
282    let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283    let log_sum_exp = sum_exp.log(); // [B,1]
284    let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285                                                        // gather selected action log-probs
286    log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291    new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296    let b = ratio.shape().dims()[0];
297    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304    let mut total_sq = 0.0f32;
305    for p in parameters.iter_mut() {
306        if let Some(g) = p.grad_owned() {
307            for &v in g.data() {
308                total_sq += v * v;
309            }
310        }
311    }
312    total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
examples/RL_training/ppo_continuous.rs (line 248)
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235    // All tensors shaped [B, A] (log_std is broadcastable)
236    let std = log_std.exp();
237    let var = std.pow_scalar(2.0);
238    let log_scale = log_std;
239    let diff = action.sub_tensor(mean);
240    let log_prob = diff
241        .pow_scalar(2.0)
242        .div_tensor(&var)
243        .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI)
244        .add_tensor(&log_scale.mul_scalar(2.0))
245        .mul_scalar(0.5)
246        .mul_scalar(-1.0);
247    // Sum across action dim (dim=1) -> [B,1]
248    log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253    returns_out: &mut [f32],
254    adv_out: &mut [f32],
255    rewards: &[f32],
256    dones: &[f32],
257    values: &[f32],
258    next_values: &[f32],
259    gamma: f32,
260    lam: f32,
261) {
262    let n = rewards.len();
263    let mut gae = 0.0f32;
264    for t in (0..n).rev() {
265        let not_done = 1.0 - dones[t];
266        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267        gae = delta + gamma * lam * not_done * gae;
268        adv_out[t] = gae;
269        returns_out[t] = gae + values[t];
270    }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274    let n = x.len() as f32;
275    if n <= 1.0 {
276        return;
277    }
278    let mean = x.iter().copied().sum::<f32>() / n;
279    let var = x
280        .iter()
281        .map(|v| {
282            let d = v - mean;
283            d * d
284        })
285        .sum::<f32>()
286        / n;
287    let std = (var + eps).sqrt();
288    for v in x.iter_mut() {
289        *v = (*v - mean) / std;
290    }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294    let mut total_sq = 0.0f32;
295    for p in parameters.iter() {
296        if let Some(g) = p.grad_owned() {
297            for &v in g.data() {
298                total_sq += v * v;
299            }
300        }
301    }
302    let norm = total_sq.sqrt();
303    if norm > max_norm {
304        let scale = max_norm / (norm + eps);
305        for p in parameters.iter_mut() {
306            if let Some(g) = p.grad_owned() {
307                p.set_grad(g.mul_scalar(scale));
308            }
309        }
310    }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314    let mut total_sq = 0.0f32;
315    for p in parameters.iter_mut() {
316        if let Some(g) = p.grad_owned() {
317            for &v in g.data() {
318                total_sq += v * v;
319            }
320        }
321    }
322    total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
Source§

impl Tensor

Source

pub fn var(&self) -> Tensor

Computes the variance over all elements

The variance is calculated as the mean of squared differences from the mean. This operation reduces the tensor to a scalar tensor with shape [1].

The implementation uses population variance (divides by n rather than n-1). Note that this differs from PyTorch's default, which applies Bessel's correction (n-1): for the values 1, 2, 3, 4 the population variance is 1.25, while the unbiased estimator would give 5/3 ≈ 1.667.

§Returns

A scalar tensor containing the variance value

§Examples
use train_station::Tensor;

// Basic variance calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let variance = tensor.var();
assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
use train_station::Tensor;

// Variance of a larger dataset
let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let variance = tensor.var();
// mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
use train_station::Tensor;

// Variance of constant values (should be 0)
let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
let variance = tensor.var();
assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
§Performance

Uses an optimized contiguous-tensor path with manual loop unrolling for better performance; non-contiguous tensors fall back to stride-aware iteration. The algorithm performs two passes: first to compute the mean, then to compute the variance.

Source

pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor

Computes the variance over specified dimensions

Reduces the tensor along the specified dimensions by computing the variance of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.

Uses population variance (divides by n rather than n-1); note this differs from PyTorch's default, which uses the unbiased (n-1) estimator.

§Arguments
  • dims - Slice of dimension indices to reduce over (must be valid for tensor rank)
  • keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns

A tensor with variance computed over the specified dimensions

§Examples
use train_station::Tensor;

// Variance along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
let row_vars = matrix.var_dims(&[1], true);
assert_eq!(row_vars.shape().dims(), vec![2, 1]);
assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
use train_station::Tensor;

// Variance along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_vars = matrix.var_dims(&[0], false);
assert_eq!(col_vars.shape().dims(), vec![2]);
// var([1, 3]) = 1.0, var([2, 4]) = 1.0
assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
use train_station::Tensor;

// Variance over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let var_all = tensor.var_dims(&[0, 1], false);
assert_eq!(var_all.shape().dims(), vec![1]);
// var([1, 2, 3, 4]) = 1.25
assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
§Panics
  • If dims is empty
  • If any dimension index is out of bounds for the tensor rank
  • If the reduced size is 0 (invalid for variance calculation)
§Performance

Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts. The algorithm performs two passes: first to compute means, then to compute variances.

Source§

impl Tensor

Source

pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor

Concatenate tensors along a given dimension

Joins multiple tensors along the specified dimension, creating a new tensor with the combined data. All input tensors must have the same rank and matching dimensions except for the concatenation dimension.

§Arguments
  • tensors - Slice of tensors to concatenate (must not be empty)
  • dim - Dimension along which to concatenate (must be < tensor rank)
§Returns

A new tensor containing the concatenated data with shape where the concatenation dimension is the sum of all input tensor sizes along that dimension.
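
The shape rule, and the validation listed under §Panics below, can be sketched with plain Vec<usize> shapes (an illustration, not the library's Shape type):

fn cat_output_shape(shapes: &[Vec<usize>], dim: usize) -> Vec<usize> {
    assert!(!shapes.is_empty(), "tensors must not be empty");
    let rank = shapes[0].len();
    assert!(dim < rank, "dim out of bounds for tensor rank");
    let mut out = shapes[0].clone();
    for s in &shapes[1..] {
        assert_eq!(s.len(), rank, "tensors must have the same rank");
        for d in 0..rank {
            // Every dimension except the concatenation dimension must match
            if d != dim {
                assert_eq!(s[d], out[d], "mismatched non-concatenation dimension");
            }
        }
        // The concatenation dimension is the sum of the inputs' sizes
        out[dim] += s[dim];
    }
    out
}

For instance, cat_output_shape(&[vec![2, 2], vec![2, 1]], 1) returns vec![2, 3], matching the 2D example below.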

§Panics
  • If tensors is empty
  • If dim is out of bounds for the tensor rank
  • If tensors have different ranks
  • If tensors have mismatched dimensions (except along concatenation dimension)
§Examples
use train_station::Tensor;

// Concatenate 1D tensors
let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
let result = Tensor::cat(&[a, b], 0);
assert_eq!(result.shape().dims(), vec![4]);
assert_eq!(result.get(&[0]), 1.0);
assert_eq!(result.get(&[1]), 2.0);
assert_eq!(result.get(&[2]), 3.0);
assert_eq!(result.get(&[3]), 4.0);
use train_station::Tensor;

// Concatenate 2D tensors along dimension 1
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
let result = Tensor::cat(&[a, b], 1);
assert_eq!(result.shape().dims(), vec![2, 3]);
assert_eq!(result.get(&[0, 0]), 1.0);
assert_eq!(result.get(&[0, 1]), 2.0);
assert_eq!(result.get(&[0, 2]), 5.0);
use train_station::Tensor;

// Concatenate with gradient tracking
let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
a.set_requires_grad(true);
b.set_requires_grad(true);

let result = Tensor::cat(&[a, b], 0);
assert!(result.requires_grad());
Examples found in repository
examples/RL_training/td3.rs (line 198)
193    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194        // Concatenate along feature dim (dim=1) for batched inputs
195        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198        let sa = Tensor::cat(&[s_view, a_view], 1);
199        self.net.forward(&sa, None)
200    }
Source§

impl Tensor

Source

pub fn contiguous(&self) -> Tensor

Creates a contiguous copy of the tensor

This operation ensures that the tensor data is stored in a linear, cache-friendly memory layout. If the tensor is already contiguous, this operation returns a clone. For non-contiguous tensors, it creates a new tensor with the same data but in contiguous memory layout.

The operation uses different optimization strategies based on tensor size:

  • Small tensors (≤64 elements): Simple coordinate-based copy
  • Medium tensors (65-1023 elements): Unrolled copy for better performance
  • Large tensors (≥1024 elements): Blocked copy with cache optimization
§Returns

A new tensor with contiguous memory layout containing the same data
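
As a rough illustration of the simple coordinate-based strategy used for small tensors (the unrolled and blocked paths refine the same idea), here is a hedged rank-2 sketch over plain slices:

fn copy_contiguous_2d(src: &[f32], shape: [usize; 2], strides: [usize; 2]) -> Vec<f32> {
    let mut dst = Vec::with_capacity(shape[0] * shape[1]);
    for i in 0..shape[0] {
        for j in 0..shape[1] {
            // Read through the source strides, write in row-major order
            dst.push(src[i * strides[0] + j * strides[1]]);
        }
    }
    dst
}

Given the buffer [1.0, 2.0, 3.0, 4.0] viewed as the transpose of a 2x2 matrix (shape [2, 2], strides [1, 2]), this produces [1.0, 3.0, 2.0, 4.0], matching the transpose example below.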

§Examples
use train_station::Tensor;

// Already contiguous tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let contiguous = tensor.contiguous();
assert!(contiguous.is_contiguous());
assert_eq!(contiguous.shape().dims(), vec![2, 2]);
use train_station::Tensor;

// Non-contiguous tensor from transpose
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = tensor.transpose(0, 1);
assert!(!transposed.is_contiguous());

let contiguous = transposed.contiguous();
assert!(contiguous.is_contiguous());
assert_eq!(contiguous.get(&[0, 0]), 1.0);
assert_eq!(contiguous.get(&[0, 1]), 3.0);
use train_station::Tensor;

// Preserves gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let contiguous = tensor.contiguous();
assert!(contiguous.requires_grad());
§Performance
  • Already contiguous: O(1) time complexity, returns a clone
  • Non-contiguous: O(n) time complexity with size-dependent optimizations
  • Memory usage: Creates a new tensor with the same size as the original
Examples found in repository
examples/neural_networks/basic_encoder.rs (line 59)
53    pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54        let attn = self.mha.forward(input, input, input, attn_mask);
55        let res1 = attn.add_tensor(input);
56
57        // Feed-forward network with ReLU and residual
58        let (b, t, e) = Self::triple(input);
59        let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60        let hidden = self.ffn_in.forward(&x2d).relu();
61        let out2d = self.ffn_out.forward(&hidden);
62        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63        out.add_tensor(&res1)
64    }
More examples
examples/neural_networks/basic_decoder.rs (line 70)
56    pub fn forward(
57        &self,
58        tgt: &Tensor,
59        memory: &Tensor,
60        causal_mask: Option<&Tensor>,
61        cross_mask: Option<&Tensor>,
62    ) -> Tensor {
63        let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64        let res1 = self_attn.add_tensor(tgt);
65
66        let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67        let res2 = cross.add_tensor(&res1);
68
69        let (b, t, e) = Self::triple(tgt);
70        let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71        let hidden = self.ffn_in.forward(&x2d).relu();
72        let out2d = self.ffn_out.forward(&hidden);
73        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74        out.add_tensor(&res2)
75    }
examples/neural_networks/multi_head_attention.rs (line 113)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
Source§

impl Tensor

Source

pub fn flatten(&self) -> Tensor

Flatten the tensor into a 1D representation

Transforms a multi-dimensional tensor into a 1D tensor by reshaping all dimensions into a single dimension. This is equivalent to reshape(vec![-1]) where -1 automatically calculates the size based on the total number of elements.

The flatten operation preserves the total number of elements while changing the tensor’s shape to have a single dimension. This is commonly used in neural networks to prepare tensor data for linear layers or feature extraction.

§Returns

A 1D tensor containing the same data as the original tensor, with shape [total_elements] where total_elements is the product of all original dimensions.

§Examples
use train_station::Tensor;

// Flatten a 2D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![4]);
assert_eq!(flattened.get(&[0]), 1.0);
assert_eq!(flattened.get(&[1]), 2.0);
assert_eq!(flattened.get(&[2]), 3.0);
assert_eq!(flattened.get(&[3]), 4.0);
use train_station::Tensor;

// Flatten a 3D tensor
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![12]);
assert_eq!(flattened.size(), 12);
use train_station::Tensor;

// Flatten with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let flattened = tensor.flatten();
assert!(flattened.requires_grad());
assert_eq!(flattened.shape().dims(), vec![4]);
use train_station::Tensor;

// Flatten an already 1D tensor (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![3]);
assert_eq!(flattened.size(), 3);
§Performance
  • Time Complexity: O(1) - Returns a view when possible
  • Memory Usage: No additional memory allocation for view operations
  • Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is equivalent to:

use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let flattened = tensor.reshape(vec![-1]);

Where -1 is a special value that automatically calculates the dimension size based on the total number of elements in the tensor.

Source§

impl Tensor

Source

pub fn permute(&self, dims: Vec<usize>) -> Tensor

Permute tensor dimensions according to specified order

Rearranges the dimensions of the tensor according to the provided dimension order. This operation returns a view with reordered strides, avoiding data copying while changing the logical arrangement of the tensor’s dimensions.

The permutation is specified as a vector where element i names the original dimension that moves to position i of the result, i.e. result dimension i equals original dimension dims[i]. For example, permute(vec![1, 0]) swaps the first two dimensions.
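
A hedged sketch of the underlying stride bookkeeping, with plain vectors standing in for the library's internals: the view's shape and strides are simply the originals read in permuted order, so no data moves.

fn permute_layout(shape: &[usize], strides: &[usize], dims: &[usize]) -> (Vec<usize>, Vec<usize>) {
    assert_eq!(dims.len(), shape.len(), "dims length must equal tensor rank");
    let new_shape: Vec<usize> = dims.iter().map(|&d| shape[d]).collect();
    let new_strides: Vec<usize> = dims.iter().map(|&d| strides[d]).collect();
    // Element (i0, i1, ...) of the view still resolves to flat offset
    // sum(i_k * new_strides[k]) in the original buffer
    (new_shape, new_strides)
}

For a contiguous [2, 3] tensor (strides [3, 1]), permute_layout(&[2, 3], &[3, 1], &[1, 0]) yields shape [3, 2] with strides [1, 3].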

§Arguments
  • dims - Vector specifying the new order of dimensions (must have length equal to tensor rank)
§Returns

A new tensor view with rearranged dimensions and correspondingly adjusted strides. The total number of elements remains unchanged.

§Panics
  • If dims length does not equal the tensor rank
  • If any dimension index is out of bounds for the tensor rank
  • If dims contains duplicate dimension indices
§Examples
use train_station::Tensor;

// Permute 2D tensor (swap dimensions)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let permuted = tensor.permute(vec![1, 0]);
assert_eq!(permuted.shape().dims(), vec![3, 2]);
assert_eq!(permuted.get(&[0, 0]), 1.0);
assert_eq!(permuted.get(&[1, 0]), 2.0);
assert_eq!(permuted.get(&[2, 1]), 6.0);
use train_station::Tensor;

// Permute 3D tensor (reorder dimensions)
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let permuted = tensor.permute(vec![2, 0, 1]);
assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
assert_eq!(permuted.size(), 24); // Total elements unchanged
use train_station::Tensor;

// Permute with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let permuted = tensor.permute(vec![1, 0]);
assert!(permuted.requires_grad());
assert_eq!(permuted.shape().dims(), vec![2, 2]);
use train_station::Tensor;

// Identity permutation (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let permuted = tensor.permute(vec![0, 1]);
assert_eq!(permuted.shape().dims(), vec![2, 2]);
assert_eq!(permuted.get(&[0, 0]), 1.0);
assert_eq!(permuted.get(&[1, 1]), 4.0);
§Performance
  • Time Complexity: O(1) - Returns a view with reordered strides
  • Memory Usage: No additional memory allocation (view operation)
  • Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is similar to transpose() but more general:

  • transpose(dim0, dim1) is equivalent to permute() with dim0 and dim1 swapped in an otherwise identity ordering (see the sketch below)
  • permute() can handle arbitrary dimension reordering for tensors of any rank
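For example, a minimal sketch of this equivalence, using only operations documented on this page:

use train_station::Tensor;

// Swapping dimensions 0 and 1 via transpose matches the equivalent permutation
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let t = x.transpose(0, 1);
let p = x.permute(vec![1, 0]);
assert_eq!(t.shape().dims(), p.shape().dims());
assert_eq!(t.get(&[2, 1]), p.get(&[2, 1]));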
§Memory Layout

The permuted tensor maintains the same underlying data but with reordered strides. This means the tensor becomes non-contiguous unless the permutation is the identity permutation.
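A minimal sketch of this effect, assuming is_contiguous() reflects the reordered strides and contiguous() materializes a packed copy:

use train_station::Tensor;

// A non-identity permutation yields a non-contiguous view; contiguous()
// produces a packed row-major copy with the same logical contents
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let permuted = tensor.permute(vec![1, 0]);
assert!(!permuted.is_contiguous());
let packed = permuted.contiguous();
assert!(packed.is_contiguous());
assert_eq!(packed.get(&[2, 1]), 6.0);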

Examples found in repository
examples/neural_networks/multi_head_attention.rs (line 112)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
124
125    fn project_qkv(
126        query: &Tensor,
127        key: &Tensor,
128        value: &Tensor,
129        q_proj: &LinearLayer,
130        k_proj: &LinearLayer,
131        v_proj: &LinearLayer,
132    ) -> (Tensor, Tensor, Tensor) {
133        let (bq, tq, eq) = Self::triple(query);
134        let (bk, tk, ek) = Self::triple(key);
135        let (_bv, tv, ev) = Self::triple(value);
136        assert!(eq == ek && ek == ev, "Q,K,V embed dims must match");
137        let q2d = query.view(vec![(bq * tq) as i32, eq as i32]);
138        let k2d = key.view(vec![(bk * tk) as i32, ek as i32]);
139        let v2d = value.view(vec![(_bv * tv) as i32, ev as i32]);
140        let q = q_proj
141            .forward(&q2d)
142            .view(vec![bq as i32, tq as i32, eq as i32]);
143        let k = k_proj
144            .forward(&k2d)
145            .view(vec![bk as i32, tk as i32, ek as i32]);
146        let v = v_proj
147            .forward(&v2d)
148            .view(vec![bk as i32, tv as i32, ev as i32]);
149        (q, k, v)
150    }
151
152    fn split_heads(x: &Tensor, b: usize, t: usize, h: usize, d: usize) -> Tensor {
153        x.view(vec![b as i32, t as i32, h as i32, d as i32])
154            .permute(vec![0, 2, 1, 3])
155    }
Source§

impl Tensor

Source

pub fn reshape(&self, new_shape: Vec<i32>) -> Tensor

Reshape the tensor to the specified dimensions

Changes the shape of the tensor while preserving the total number of elements. This operation returns a view when the tensor is contiguous, avoiding data copying. For non-contiguous tensors, data is copied to ensure the reshape is valid.

The reshape operation supports automatic dimension inference using -1, which allows one dimension to be automatically calculated based on the total number of elements and the other specified dimensions.

§Arguments
  • new_shape - Target shape for the tensor. Use -1 for one dimension to have it automatically inferred from the total size.
§Returns

A new tensor with the specified shape containing the same data as the original tensor.

§Panics
  • If more than one dimension is -1
  • If the total number of elements doesn’t match the original tensor
  • If any dimension size is zero, or negative other than -1
  • If the inferred dimension size is not a whole number
§Examples
use train_station::Tensor;

// Basic reshape
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let reshaped = tensor.reshape(vec![3, 2]);
assert_eq!(reshaped.shape().dims(), vec![3, 2]);
assert_eq!(reshaped.get(&[0, 0]), 1.0);
assert_eq!(reshaped.get(&[2, 1]), 6.0);
use train_station::Tensor;

// Using -1 for automatic dimension inference
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let reshaped = tensor.reshape(vec![2, -1]);
assert_eq!(reshaped.shape().dims(), vec![2, 2]);
assert_eq!(reshaped.get(&[0, 0]), 1.0);
assert_eq!(reshaped.get(&[1, 1]), 4.0);
use train_station::Tensor;

// Reshape with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let reshaped = tensor.reshape(vec![4]);
assert!(reshaped.requires_grad());
assert_eq!(reshaped.shape().dims(), vec![4]);
use train_station::Tensor;

// Reshape 3D tensor
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let reshaped = tensor.reshape(vec![6, 4]);
assert_eq!(reshaped.shape().dims(), vec![6, 4]);
assert_eq!(reshaped.size(), 24);
§Performance
  • Contiguous tensors: O(1) time complexity, returns a view
  • Non-contiguous tensors: O(n) time complexity with data copying
  • Memory usage: No additional allocation for view operations
  • Gradient tracking: Preserves gradient requirements and tracking
§Automatic Dimension Inference

When using -1 for a dimension, the size is automatically calculated:

use train_station::Tensor;

// For a tensor with 12 elements
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![3, 4]).unwrap();

let reshaped1 = tensor.reshape(vec![3, -1]);  // Results in shape [3, 4]
let reshaped2 = tensor.reshape(vec![-1, 6]);  // Results in shape [2, 6]
let reshaped3 = tensor.reshape(vec![-1]);     // Results in shape [12]
Examples found in repository
examples/RL_training/ppo_discrete.rs (line 476)
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}
More examples
examples/RL_training/ppo_continuous.rs (line 490)
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330    println!("=== PPO Continuous Example (YardEnv) ===");
331
332    let state_dim = 3usize;
333    let action_dim = 1usize;
334
335    // Hparams
336    let total_steps = std::env::var("PPO_STEPS")
337        .ok()
338        .and_then(|v| v.parse::<usize>().ok())
339        .unwrap_or(4000usize);
340    let horizon = 128usize; // rollout length per update
341    let epochs = 4usize; // PPO epochs per update
342    let mini_batch_size = 64usize; // minibatch from horizon
343    let gamma = 0.99f32;
344    let lam = 0.95f32; // GAE lambda
345    let clip_eps = 0.2f32;
346    let vf_coef = 0.5f32;
347    let ent_coef = 0.0f32;
348    let max_grad_norm = 1.0f32;
349
350    // Models
351    let mut actor = Actor::new(state_dim, action_dim, Some(101));
352    let mut critic = Critic::new(state_dim, Some(202));
353
354    // Opts
355    let mut actor_opt = Adam::with_learning_rate(3e-4);
356    for p in actor.parameters() {
357        actor_opt.add_parameter(p);
358    }
359    let mut critic_opt = Adam::with_learning_rate(3e-4);
360    for p in critic.parameters() {
361        critic_opt.add_parameter(p);
362    }
363
364    // Env and RNG
365    let mut env = YardEnv::new(42);
366    let mut rng = SmallRng::new(999);
367    let mut state = env.reset();
368
369    // Metrics
370    let mut episode_return = 0.0f32;
371    let mut episode = 0usize;
372    let mut ema_return: Option<f32> = None;
373    let ema_alpha = 0.05f32;
374    let mut best_return = f32::NEG_INFINITY;
375
376    let mut t = 0usize;
377    while t < total_steps {
378        // Collect a rollout
379        let mut batch = RolloutBatch::new(horizon, state_dim);
380        for _ in 0..horizon {
381            // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382            let (mean, log_std_row) = actor.forward(&state);
383            let mean_v = mean.data()[0];
384            let log_std_v = log_std_row.data()[0];
385            let std_v = log_std_v.exp();
386            let noise = rng.normal();
387            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389            // Build action tensor [1, A] for log_prob calculation with autograd
390            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392            let log_prob_v = log_prob_t.data()[0];
393
394            // Step env
395            let (next_state, reward, done) = env.step(action_v);
396            episode_return += reward;
397
398            // Value
399            let value_t = critic.forward(&state);
400            let value_v = value_t.data()[0];
401
402            // Push
403            batch.push(
404                state.data(),
405                action_v,
406                log_prob_v,
407                reward,
408                if done { 1.0 } else { 0.0 },
409                value_v,
410                next_state.data(),
411            );
412
413            // Reset
414            state = if done {
415                let st = env.reset();
416                ema_return = Some(match ema_return {
417                    None => episode_return,
418                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419                });
420                if episode_return > best_return {
421                    best_return = episode_return;
422                }
423                println!(
424                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425                    t,
426                    episode,
427                    episode_return,
428                    ema_return.unwrap_or(episode_return),
429                    best_return
430                );
431                episode_return = 0.0;
432                episode += 1;
433                st
434            } else {
435                next_state
436            };
437
438            t += 1;
439            if t >= total_steps {
440                break;
441            }
442        }
443
444        // Bootstrap next values for GAE
445        let next_values: Vec<f32> = {
446            let mut out = Vec::with_capacity(batch.len());
447            for i in 0..batch.len() {
448                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450                let v2 = critic.forward(&s2_t).data()[0];
451                out.push(v2);
452            }
453            out
454        };
455
456        // Compute returns and advantages
457        let mut returns = vec![0.0f32; batch.len()];
458        let mut adv = vec![0.0f32; batch.len()];
459        compute_gae(
460            &mut returns,
461            &mut adv,
462            &batch.rewards,
463            &batch.dones,
464            &batch.values,
465            &next_values,
466            gamma,
467            lam,
468        );
469        normalize_in_place(&mut adv, 1e-8);
470
471        // Prepare tensors for training
472        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478        // PPO epochs over the rollout
479        let num_minibatches = batch.len().div_ceil(mini_batch_size);
480        for e in 0..epochs {
481            for mb in 0..num_minibatches {
482                let start = mb * mini_batch_size;
483                let end = (start + mini_batch_size).min(batch.len());
484                if start >= end {
485                    break;
486                }
487
488                // Slice views
489                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491                let a_mb = actions_t
492                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
493                    .reshape(vec![(end - start) as i32, action_dim as i32]);
494                let oldlp_mb = old_logp_t
495                    .slice_view(start, 1, end - start)
496                    .reshape(vec![(end - start) as i32, 1]);
497                let ret_mb = returns_t
498                    .slice_view(start, 1, end - start)
499                    .reshape(vec![(end - start) as i32, 1]);
500                let adv_mb = adv_t
501                    .slice_view(start, 1, end - start)
502                    .reshape(vec![(end - start) as i32, 1]);
503
504                // Zero grads
505                {
506                    let mut ps = actor.parameters();
507                    actor_opt.zero_grad(&mut ps);
508                }
509                {
510                    let mut ps = critic.parameters();
511                    critic_opt.zero_grad(&mut ps);
512                }
513
514                // Forward actor and critic
515                let (mean_mb, log_std_row) = actor.forward(&s_mb);
516                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518                let clip_low =
519                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520                        .unwrap();
521                let clip_high =
522                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523                        .unwrap();
524                // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526                let ratio_clipped =
527                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528                let pg1 = ratio.mul_tensor(&adv_mb);
529                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532                let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534                let v_pred = critic.forward(&s_mb);
535                let v_loss = v_pred
536                    .sub_tensor(&ret_mb)
537                    .pow_scalar(2.0)
538                    .mean()
539                    .mul_scalar(vf_coef);
540
541                // Entropy (approx Gaussian entropy per action)
542                let entropy = log_std_row
543                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544                    .sum_dims(&[1], true)
545                    .mean()
546                    .mul_scalar(ent_coef);
547
548                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549                loss.backward(None);
550
551                // Step actor
552                {
553                    let params = actor.parameters();
554                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
555                    for p in params {
556                        if p.grad_owned().is_some() {
557                            with_grads.push(p);
558                        }
559                    }
560                    if !with_grads.is_empty() {
561                        let _ = grad_global_norm(&mut with_grads);
562                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563                        actor_opt.step(&mut with_grads);
564                        actor_opt.zero_grad(&mut with_grads);
565                    }
566                }
567
568                // Step critic
569                {
570                    let params = critic.parameters();
571                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
572                    for p in params {
573                        if p.grad_owned().is_some() {
574                            with_grads.push(p);
575                        }
576                    }
577                    if !with_grads.is_empty() {
578                        let _ = grad_global_norm(&mut with_grads);
579                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580                        critic_opt.step(&mut with_grads);
581                        critic_opt.zero_grad(&mut with_grads);
582                    }
583                }
584
585                // Occasionally log
586                if e == 0 && mb == 0 {
587                    println!(
588                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
589                        t,
590                        actor_loss.value(),
591                        v_loss.value()
592                    );
593                }
594
595                clear_all_graphs_known();
596            }
597        }
598    }
599
600    println!("=== PPO training finished ===");
601    Ok(())
602}
Source§

impl Tensor

Source

pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor>

Split tensor into chunks of equal size along specified dimension

Divides the tensor into multiple smaller tensors along the specified dimension, where each chunk (except possibly the last) has the same size. The last chunk may be smaller if the dimension size is not evenly divisible by the split size.

This operation returns a vector of tensors, where each tensor is a view or copy of a portion of the original tensor. The first chunk is returned as a view when possible (zero-copy), while subsequent chunks may require data copying for non-zero base offsets.

§Arguments
  • split_size - Size of each chunk along the specified dimension (must be > 0)
  • dim - Dimension along which to split the tensor (must be < tensor rank)
§Returns

A vector of tensors, each representing a chunk of the original tensor. The number of chunks is the ceiling of the dimension size divided by split_size, so the last chunk is smaller when the dimension is not evenly divisible.

§Panics
  • If tensor rank is 0 (scalar tensors cannot be split)
  • If dim is out of bounds for the tensor rank
  • If split_size is 0
§Examples
use train_station::Tensor;

// Split 2D tensor into equal chunks along dimension 1
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let parts = tensor.split(1, 1);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2, 1]);
assert_eq!(parts[1].shape().dims(), vec![2, 1]);
assert_eq!(parts[2].shape().dims(), vec![2, 1]);
assert_eq!(parts[0].get(&[0, 0]), 1.0);
assert_eq!(parts[1].get(&[0, 0]), 2.0);
assert_eq!(parts[2].get(&[1, 0]), 6.0);
use train_station::Tensor;

// Split with uneven division (last chunk smaller)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
let parts = tensor.split(2, 1);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![1, 2]);
assert_eq!(parts[1].shape().dims(), vec![1, 2]);
assert_eq!(parts[2].shape().dims(), vec![1, 1]); // Last chunk smaller
use train_station::Tensor;

// Split with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let parts = tensor.split(1, 1);
assert_eq!(parts.len(), 2);
assert!(parts[0].requires_grad());
assert!(parts[1].requires_grad());
use train_station::Tensor;

// Split 1D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
let parts = tensor.split(2, 0);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2]);
assert_eq!(parts[1].shape().dims(), vec![2]);
assert_eq!(parts[2].shape().dims(), vec![2]);
§Performance
  • First Chunk: O(1) - Returns a view when possible (zero-copy)
  • Subsequent Chunks: O(n) - May require data copying for non-zero offsets
  • Memory Usage: Minimal allocation for view operations, copying for non-zero offsets
  • Gradient Tracking: Each chunk preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • split_with_sizes() - More general version with explicit chunk sizes
  • cat() - Inverse operation that concatenates tensors back together
  • chunk() - Alternative splitting operation with different semantics
§Memory Layout

The first chunk maintains the same underlying data as a view when the base offset is zero. Subsequent chunks may require data copying to handle non-zero base offsets, ensuring proper memory layout.
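A minimal sketch of the view-versus-copy distinction described above; whichever path a chunk takes, its contents match the corresponding region of the original:

use train_station::Tensor;

// First chunk starts at offset zero (view candidate); the second starts
// mid-buffer and may be copied, but the values are identical either way
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
let parts = tensor.split(3, 0);
assert_eq!(parts[0].get(&[2]), 3.0);
assert_eq!(parts[1].get(&[0]), 4.0);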

Source

pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor>

Split tensor into chunks with explicit sizes along specified dimension

Divides the tensor into multiple smaller tensors along the specified dimension according to the provided size specifications. Each chunk has the exact size specified in the split_sizes array, and the sum of all sizes must equal the size of the specified dimension.

This operation provides precise control over the size of each resulting chunk, unlike split() which creates equal-sized chunks. The first chunk is returned as a view when possible (zero-copy), while subsequent chunks may require data copying for non-zero base offsets.

§Arguments
  • split_sizes - Array specifying the size of each chunk along the dimension
  • dim - Dimension along which to split the tensor (must be < tensor rank)
§Returns

A vector of tensors, each representing a chunk of the original tensor with the specified size. The number of chunks equals the length of split_sizes.

§Panics
  • If tensor rank is 0 (scalar tensors cannot be split)
  • If dim is out of bounds for the tensor rank
  • If sum of split_sizes does not equal the size of the specified dimension
§Examples
use train_station::Tensor;

// Split with explicit sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
let parts = tensor.split_with_sizes(&[2, 3], 1);
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].shape().dims(), vec![1, 2]);
assert_eq!(parts[1].shape().dims(), vec![1, 3]);
assert_eq!(parts[0].get(&[0, 0]), 1.0);
assert_eq!(parts[0].get(&[0, 1]), 2.0);
assert_eq!(parts[1].get(&[0, 0]), 3.0);
use train_station::Tensor;

// Split 2D tensor with different chunk sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let parts = tensor.split_with_sizes(&[1, 2], 1);
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].shape().dims(), vec![2, 1]);
assert_eq!(parts[1].shape().dims(), vec![2, 2]);
use train_station::Tensor;

// Split with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let parts = tensor.split_with_sizes(&[1, 1], 1);
assert_eq!(parts.len(), 2);
assert!(parts[0].requires_grad());
assert!(parts[1].requires_grad());
use train_station::Tensor;

// Split 1D tensor with explicit sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2]);
assert_eq!(parts[1].shape().dims(), vec![2]);
assert_eq!(parts[2].shape().dims(), vec![2]);
§Performance
  • First Chunk: O(1) - Returns a view when possible (zero-copy)
  • Subsequent Chunks: O(n) - May require data copying for non-zero offsets
  • Memory Usage: Minimal allocation for view operations, copying for non-zero offsets
  • Gradient Tracking: Each chunk preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • split() - Simplified version with equal-sized chunks
  • cat() - Inverse operation that concatenates tensors back together
  • chunk() - Alternative splitting operation with different semantics
§Memory Layout

The first chunk maintains the same underlying data as a view when the base offset is zero. Subsequent chunks may require data copying to handle non-zero base offsets, ensuring proper memory layout. Zero-sized chunks are handled by creating empty tensors with appropriate shapes.
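A hedged sketch of the zero-sized-chunk case; the exact shape reported for the empty tensor is an assumption, so only element counts are checked:

use train_station::Tensor;

// A zero entry in split_sizes yields an empty chunk along the split dimension
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
let parts = tensor.split_with_sizes(&[3, 0], 1);
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].size(), 3);
assert_eq!(parts[1].size(), 0); // empty tensor with appropriate shape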

Source§

impl Tensor

Source

pub fn squeeze(&self, dim: Option<usize>) -> Tensor

Remove dimensions of size 1 from the tensor

Removes singleton dimensions (dimensions with size 1) from the tensor, reducing its rank while preserving the total number of elements. This operation is useful for cleaning up tensor shapes and preparing data for operations that expect specific dimensionality.

The squeeze operation can remove either all size-1 dimensions or a specific dimension if it has size 1. When all dimensions are size 1, the result is a scalar tensor with shape [1] rather than an empty tensor to maintain mathematical consistency.

§Arguments
  • dim - Optional specific dimension to squeeze. If None, all size-1 dimensions are removed. If Some(d), only dimension d is removed if it has size 1.
§Returns

A new tensor with size-1 dimensions removed. The total number of elements remains unchanged.

§Panics
  • If dim is specified but out of bounds for the tensor rank
  • If dim is specified but the dimension does not have size 1
§Examples
use train_station::Tensor;

// Squeeze all size-1 dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[1]), 2.0);
assert_eq!(squeezed.get(&[2]), 3.0);
use train_station::Tensor;

// Squeeze specific dimension
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
let squeezed = tensor.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3, 1]);
assert_eq!(squeezed.get(&[0, 0]), 1.0);
assert_eq!(squeezed.get(&[1, 0]), 2.0);
assert_eq!(squeezed.get(&[2, 0]), 3.0);
use train_station::Tensor;

// Squeeze preserves data integrity
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![2, 2]);
assert_eq!(squeezed.size(), 4);
assert_eq!(squeezed.get(&[0, 0]), data[0]);
assert_eq!(squeezed.get(&[0, 1]), data[1]);
assert_eq!(squeezed.get(&[1, 0]), data[2]);
assert_eq!(squeezed.get(&[1, 1]), data[3]);
use train_station::Tensor;

// Handle edge case: all dimensions are size 1
let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![1]); // Not empty!
assert_eq!(squeezed.get(&[0]), 5.0);
use train_station::Tensor;

// Squeeze with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
tensor.set_requires_grad(true);

let squeezed = tensor.squeeze(None);
assert!(squeezed.requires_grad());
assert_eq!(squeezed.shape().dims(), vec![3]);
use train_station::Tensor;

// Squeeze and unsqueeze roundtrip
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);

let squeezed = unsqueezed.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[2]), 3.0);
§Performance
  • Time Complexity: O(1) - Returns a view through reshape operation
  • Memory Usage: No additional memory allocation (view operation)
  • Gradient Tracking: Preserves gradient requirements and tracking
  • Shape Transformation: Reduces tensor rank by removing singleton dimensions
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • unsqueeze() - Inverse operation that adds size-1 dimensions
  • reshape() - More general shape transformation operation
  • flatten() - Reduces tensor to 1D by combining all dimensions
§Memory Layout

The squeezed tensor maintains the same underlying data as the original tensor through the reshape operation. This ensures zero-copy behavior when the tensor is contiguous, with only the shape metadata being modified to reflect the reduced dimensionality.

§Edge Cases
  • All size-1 dimensions: Returns a tensor with shape [1] rather than an empty tensor to maintain mathematical consistency
  • No size-1 dimensions: Returns a tensor with the same shape as the input (see the sketch below)
  • Mixed dimensions: Only removes dimensions with size 1, preserving others
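A minimal sketch of the no-change case, assuming squeeze(None) leaves a tensor without singleton dimensions untouched:

use train_station::Tensor;

// No dimension has size 1, so squeeze(None) leaves the shape as-is
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![2, 2]);
assert_eq!(squeezed.get(&[1, 1]), 4.0);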
Source§

impl Tensor

Source

pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor

Stack a list of tensors along a new dimension

Combines multiple tensors by adding a new dimension at the specified position. All input tensors must have identical shapes, and the output tensor will have a new dimension of size equal to the number of input tensors. This operation is similar to PyTorch’s torch.stack function.

The stacking operation creates a new axis in the output tensor, unlike concatenation which operates along existing dimensions. This makes stacking useful for creating batch dimensions, combining feature maps, and implementing operations that require adding new tensor axes.

§Arguments
  • tensors - Array of tensors to stack. All tensors must have identical shapes.
  • dim - Index of the new axis in the output shape (0 <= dim <= rank)
§Returns

A new tensor with the stacked data. The output shape is the input shape with a new dimension of size tensors.len() inserted at position dim.

§Panics
  • If the tensor array is empty
  • If any tensor has a different shape than the first tensor
  • If dim is out of bounds (dim > rank of input tensors)
§Examples
use train_station::Tensor;

// Stack two 1D tensors along dimension 0
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let stacked = Tensor::stack(&[a, b], 0);
assert_eq!(stacked.shape().dims(), vec![2, 3]);
assert_eq!(stacked.get(&[0, 0]), 1.0);
assert_eq!(stacked.get(&[1, 2]), 6.0);
use train_station::Tensor;

// Stack multiple 2D tensors along dimension 1
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
let stacked = Tensor::stack(&[a, b, c], 1);
assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
use train_station::Tensor;

// Stack with gradient tracking
let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
a.set_requires_grad(true);
b.set_requires_grad(true);

let stacked = Tensor::stack(&[a, b], 0);
assert!(stacked.requires_grad());
assert_eq!(stacked.shape().dims(), vec![2, 2]);
use train_station::Tensor;

// Stack 3D tensors along the last dimension
let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
let stacked = Tensor::stack(&[a, b], 3);
assert_eq!(stacked.shape().dims(), vec![2, 2, 2, 2]);
assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
§Performance
  • Time Complexity: O(n) where n is the total number of elements
  • Memory Usage: Allocates new contiguous tensor for output
  • SIMD Optimization: Uses AVX2 acceleration for large block copies
  • Block-wise Copying: Optimized copying strategy for better cache performance
  • Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • cat() - Concatenates tensors along existing dimensions
  • unsqueeze() - Adds a single dimension of size 1
  • reshape() - Changes tensor shape without adding dimensions
§Memory Layout

The output tensor is always contiguous in row-major order, with the new dimension of size tensors.len() inserted at position dim. This ensures optimal performance for subsequent operations and maintains compatibility with SIMD optimizations.

§Gradient Computation

During backward passes, gradients are split along the stacked dimension and distributed back to the original input tensors. This is implemented using the same gradient function as concatenation, treating the stack operation as concatenation along a new axis.
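A hedged sketch of this flow, assuming leaf gradients remain retrievable through grad_owned() after backward, as in the repository training examples:

use train_station::Tensor;

// Stack two [2] tensors into a [2, 2] output; backward splits the upstream
// gradient along dim 0 and routes one [2] slice back to each input
let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
a.set_requires_grad(true);
let b = Tensor::ones(vec![2]);
let mut inputs = [a, b];
let stacked = Tensor::stack(&inputs, 0);
let mut loss = stacked.mean();
loss.backward(None);
assert!(inputs[0].grad_owned().is_some()); // gradient slice for the first input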

Source§

impl Tensor

Source

pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor

Transpose two dimensions of the tensor

Swaps two specified dimensions of the tensor, modifying the shape and memory access pattern. When possible, this operation returns a zero-copy view using stride manipulation. For complex cases or non-contiguous tensors, data is copied to ensure correct transposition.

The transpose operation is its own inverse - applying transpose twice with the same dimensions returns the original tensor.

§Arguments
  • dim0 - First dimension to swap (must be < tensor rank)
  • dim1 - Second dimension to swap (must be < tensor rank)
§Returns

A new tensor with the specified dimensions transposed. The total number of elements remains unchanged.

§Panics
  • If dim0 is out of bounds for the tensor rank
  • If dim1 is out of bounds for the tensor rank
§Examples
use train_station::Tensor;

// Basic 2D transpose
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let transposed = tensor.transpose(0, 1);
assert_eq!(transposed.shape().dims(), vec![3, 2]);
assert_eq!(transposed.get(&[0, 0]), 1.0);
assert_eq!(transposed.get(&[0, 1]), 4.0);
assert_eq!(transposed.get(&[1, 0]), 2.0);
assert_eq!(transposed.get(&[1, 1]), 5.0);
assert_eq!(transposed.get(&[2, 0]), 3.0);
assert_eq!(transposed.get(&[2, 1]), 6.0);
use train_station::Tensor;

// 3D tensor transpose
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let transposed = tensor.transpose(0, 1);
assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);
use train_station::Tensor;

// Transpose with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);

let transposed = tensor.transpose(0, 1);
assert!(transposed.requires_grad());
assert_eq!(transposed.shape().dims(), vec![2, 2]);
use train_station::Tensor;

// Transpose same dimension (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let result = tensor.transpose(1, 1);
assert_eq!(result.shape().dims(), tensor.shape().dims());
assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
use train_station::Tensor;

// Transpose is its own inverse
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = tensor.transpose(0, 1);
let double_transposed = transposed.transpose(0, 1);
assert_eq!(double_transposed.shape().dims(), tensor.shape().dims());
assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
§Performance
  • Contiguous tensors: O(1) time complexity, returns a view
  • Non-contiguous tensors: O(n) time complexity with data copying
  • Memory usage: No additional allocation for view operations
  • Gradient tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • t() - Convenience method for matrix transpose (last two dimensions)
  • permute() - More general dimension reordering operation
  • reshape() - Changes shape without changing dimension order
§Memory Layout

For contiguous tensors, transpose returns a view with modified strides, making the tensor non-contiguous. For non-contiguous tensors or complex cases, data is copied to ensure correct transposition.
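A minimal sketch of this layout behavior, assuming the transposed view of a contiguous tensor reports non-contiguous strides until materialized:

use train_station::Tensor;

// The transposed view reorders strides; contiguous() packs it back down
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let view = tensor.transpose(0, 1);
assert!(!view.is_contiguous());
let owned = view.contiguous();
assert!(owned.is_contiguous());
assert_eq!(owned.get(&[1, 0]), 2.0); // same logical element as the view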

Examples found in repository
examples/neural_networks/multi_head_attention.rs (line 91)
72    pub fn forward(
73        &self,
74        query: &Tensor,
75        key: &Tensor,
76        value: &Tensor,
77        attn_mask: Option<&Tensor>,
78    ) -> Tensor {
79        let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80        let (q, k, v) = qkv;
81
82        // Split heads: [b, t, e] -> [b, h, t, d]
83        let (b, tq, _e) = Self::triple(query);
84        let (_b2, tk, _e2) = Self::triple(key);
85        let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86        let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87        let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89        // Scaled dot-product attention
90        // logits: [b, h, tq, tk]
91        let k_t = k.transpose(2, 3);
92        let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93        if let Some(mask) = attn_mask {
94            let dims = mask.shape().dims().to_vec();
95            // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96            if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97                // Interpret mask > 0.5 as keep; we invert to build masked positions
98                let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99                // Apply masked fill on a flattened view, then reshape back
100                let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101                let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102                logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103            } else {
104                // Fallback: additive mask
105                logits = logits.add_tensor(mask);
106            }
107        }
108        let attn = logits.softmax(3);
109
110        // context: [b, h, tq, d]
111        let context = attn.matmul(&v);
112        let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113        let context = context.contiguous().view(vec![
114            b as i32,
115            tq as i32,
116            (self.num_heads * self.head_dim) as i32,
117        ]);
118
119        // Output projection (flatten to 2D, project, then restore 3D)
120        let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121        let out2d = self.out_proj.forward(&flat);
122        out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123    }
Source

pub fn t(&self) -> Tensor

Matrix transpose (transpose last two dimensions)

Convenience method for the common case of matrix transposition. For 2D tensors, this performs a standard matrix transpose. For higher-dimensional tensors, this transposes the last two dimensions, treating the tensor as a batch of matrices.

This method is equivalent to transpose(rank-2, rank-1) where rank is the number of dimensions in the tensor.

§Returns

A new tensor with the last two dimensions transposed

§Panics
  • If the tensor has fewer than 2 dimensions
§Examples
use train_station::Tensor;

// 2D matrix transpose
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = matrix.t();
assert_eq!(transposed.shape().dims(), vec![2, 2]);
assert_eq!(transposed.get(&[0, 0]), 1.0);
assert_eq!(transposed.get(&[0, 1]), 3.0);
assert_eq!(transposed.get(&[1, 0]), 2.0);
assert_eq!(transposed.get(&[1, 1]), 4.0);
use train_station::Tensor;

// 3D tensor (batch of matrices)
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
let transposed = tensor.t();
assert_eq!(transposed.shape().dims(), vec![2, 3, 2]);
use train_station::Tensor;

// Matrix transpose with gradient tracking
let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
matrix.set_requires_grad(true);

let transposed = matrix.t();
assert!(transposed.requires_grad());
assert_eq!(transposed.shape().dims(), vec![2, 2]);
§Performance
  • Time Complexity: Same as transpose() - O(1) for views, O(n) for copies
  • Memory Usage: Same as transpose() - no allocation for views
  • Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations

This operation is equivalent to:

use train_station::Tensor;

let tensor = Tensor::new(vec![2, 3, 4]);
let rank = tensor.shape().rank();
let transposed1 = tensor.t();
let transposed2 = tensor.transpose(rank - 2, rank - 1);
// transposed1 and transposed2 are identical
Source§

impl Tensor

Source

pub fn unsqueeze(&self, dim: usize) -> Tensor

Add a dimension of size 1 at the specified position

Inserts a new dimension of size 1 at the specified position in the tensor’s shape, increasing the rank by 1 while preserving the total number of elements. This operation is useful for preparing tensors for broadcasting, creating batch dimensions, and adapting tensor shapes for specific neural network operations.

The unsqueeze operation is the inverse of squeeze() - unsqueezing a dimension and then squeezing it at the same position returns the original tensor.

§Arguments
  • dim - Position to insert the new dimension (0 <= dim <= rank)
§Returns

A new tensor with an additional dimension of size 1 at the specified position. The total number of elements remains unchanged.

§Panics
  • If dim is out of bounds (dim > rank of the tensor)
§Examples
use train_station::Tensor;

// Add dimension at the beginning
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[0, 1]), 2.0);
assert_eq!(unsqueezed.get(&[0, 2]), 3.0);
use train_station::Tensor;

// Add dimension at the end
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(1);
assert_eq!(unsqueezed.shape().dims(), vec![3, 1]);
assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[1, 0]), 2.0);
assert_eq!(unsqueezed.get(&[2, 0]), 3.0);
use train_station::Tensor;

// Add dimension in the middle of 2D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let unsqueezed = tensor.unsqueeze(1);
assert_eq!(unsqueezed.shape().dims(), vec![2, 1, 2]);
assert_eq!(unsqueezed.get(&[0, 0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[0, 0, 1]), 2.0);
assert_eq!(unsqueezed.get(&[1, 0, 0]), 3.0);
assert_eq!(unsqueezed.get(&[1, 0, 1]), 4.0);
use train_station::Tensor;

// Unsqueeze preserves data integrity
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::from_slice(&data, vec![4]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 4]);
assert_eq!(unsqueezed.size(), 4);
for (i, &d) in data.iter().enumerate() {
    assert_eq!(unsqueezed.get(&[0, i]), d);
}
use train_station::Tensor;

// Unsqueeze with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
tensor.set_requires_grad(true);

let unsqueezed = tensor.unsqueeze(0);
assert!(unsqueezed.requires_grad());
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
use train_station::Tensor;

// Unsqueeze and squeeze roundtrip
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);

let squeezed = unsqueezed.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[2]), 3.0);
use train_station::Tensor;

// Multiple unsqueeze operations
let tensor = Tensor::from_slice(&[42.0], vec![1]).unwrap();
let unsqueezed1 = tensor.unsqueeze(0);
assert_eq!(unsqueezed1.shape().dims(), vec![1, 1]);

let unsqueezed2 = unsqueezed1.unsqueeze(0);
assert_eq!(unsqueezed2.shape().dims(), vec![1, 1, 1]);
assert_eq!(unsqueezed2.get(&[0, 0, 0]), 42.0);
§Performance
  • Time Complexity: O(1) - Returns a view through a reshape operation
  • Memory Usage: No additional memory allocation (view operation)
  • Gradient Tracking: Preserves gradient requirements and tracking
  • Shape Transformation: Increases tensor rank by one by inserting a singleton dimension
§Relationship to Other Operations

This operation is related to other tensor transformations:

  • squeeze() - Inverse operation that removes size-1 dimensions
  • reshape() - More general shape transformation operation
  • expand() - Broadcasts dimensions to larger sizes
§Memory Layout

The unsqueezed tensor maintains the same underlying data as the original tensor through the reshape operation. This ensures zero-copy behavior when the tensor is contiguous, with only the shape metadata being modified to reflect the increased dimensionality.
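
This zero-copy behavior can be checked with accessors shown elsewhere on this page; a minimal sketch, assuming a contiguous input stays contiguous through the reshape:

use train_station::Tensor;

// Unsqueeze changes only shape metadata; element count and contiguity are preserved
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 2, 2]);
assert_eq!(unsqueezed.size(), tensor.size()); // same number of elements
assert!(unsqueezed.is_contiguous());          // no copy expected for contiguous input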

§Broadcasting Applications

Unsqueeze is commonly used for broadcasting operations:

use train_station::Tensor;

// Prepare for broadcasting: [3] -> [1, 3] for row-wise operations
let vector = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let row_vector = vector.unsqueeze(0); // Shape: [1, 3]

// Prepare for broadcasting: [3] -> [3, 1] for column-wise operations
let column_vector = vector.unsqueeze(1); // Shape: [3, 1]
§Neural Network Applications

Unsqueeze is essential for neural network operations:

use train_station::Tensor;

// Single sample -> batch dimension for neural network input
let sample = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let batch = sample.unsqueeze(0); // Shape: [1, 3] for batch processing

// Add channel dimension for convolutional operations
let feature_map = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let with_channels = feature_map.unsqueeze(0); // Shape: [1, 2, 2] for conv layers

Trait Implementations§

Source§

impl Add<&Tensor> for Tensor

Source§

fn add(self, other: &Tensor) -> Tensor

Adds a tensor and a tensor reference element-wise

§Returns

A new tensor containing the element-wise sum

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add<&Tensor> for f32

Source§

fn add(self, tensor: &Tensor) -> Tensor

Adds a scalar to each element of the tensor (reference version)

§Returns

A new tensor with the scalar added to each element

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add<Tensor> for &Tensor

Source§

fn add(self, other: Tensor) -> Tensor

Adds a tensor reference and a tensor element-wise

§Returns

A new tensor containing the element-wise sum

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add<Tensor> for f32

Scalar-tensor addition operator implementations

Provides addition operations between scalars and tensors. All implementations delegate to the underlying add_scalar method.

Source§

fn add(self, tensor: Tensor) -> Tensor

Adds a scalar to each element of the tensor

§Returns

A new tensor with the scalar added to each element

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add<f32> for &Tensor

Source§

fn add(self, scalar: f32) -> Tensor

Adds a scalar to each element of the tensor (reference version)

§Returns

A new tensor with the scalar added to each element

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add<f32> for Tensor

Tensor-scalar addition operator implementations

Provides addition operations between tensors and scalars. All implementations delegate to the underlying add_scalar method.

Source§

fn add(self, scalar: f32) -> Tensor

Adds a scalar to each element of the tensor

§Returns

A new tensor with the scalar added to each element

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add for &Tensor

Source§

fn add(self, other: &Tensor) -> Tensor

Adds two tensors element-wise (reference version)

§Returns

A new tensor containing the element-wise sum

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl Add for Tensor

Tensor addition operator implementations

Provides addition operations between tensors with various reference combinations. All implementations delegate to the underlying add_tensor method for optimal performance.

Source§

fn add(self, other: Tensor) -> Tensor

Adds two tensors element-wise

§Returns

A new tensor containing the element-wise sum

Source§

type Output = Tensor

The resulting type after applying the + operator.
Source§

impl AddAssign<&Tensor> for Tensor

Source§

fn add_assign(&mut self, other: &Tensor)

Adds another tensor reference to this tensor in-place

Source§

impl AddAssign<f32> for Tensor

Tensor-scalar addition assignment operator implementations

Provides in-place addition operations between tensors and scalars.

Source§

fn add_assign(&mut self, scalar: f32)

Adds a scalar to each element of this tensor in-place

Source§

impl AddAssign for Tensor

Tensor addition assignment operator implementations

Provides in-place addition operations between tensors. All implementations delegate to the underlying add_tensor method.

Source§

fn add_assign(&mut self, other: Tensor)

Adds another tensor to this tensor in-place

Source§

impl Clone for Tensor

Clone implementation for Tensor

Creates a deep copy of the tensor data but resets gradtrack state (new tensor won’t track gradients unless explicitly set)
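
A brief sketch of the reset behavior:

use train_station::Tensor;

// The clone copies data but starts with gradient tracking disabled
let original = Tensor::ones(vec![2, 2]).with_requires_grad();
let copy = original.clone();
assert!(original.requires_grad());
assert!(!copy.requires_grad()); // gradtrack state is reset on clone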

Source§

fn clone(&self) -> Self

Returns a duplicate of the value. Read more
1.0.0 · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source. Read more
Source§

impl Debug for Tensor

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter. Read more
Source§

impl Div<&Tensor> for Tensor

Source§

fn div(self, other: &Tensor) -> Tensor

Divides a tensor by a tensor reference element-wise

§Returns

A new tensor containing the element-wise quotient

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div<&Tensor> for f32

Source§

fn div(self, tensor: &Tensor) -> Tensor

Divides a scalar by each element of the tensor (reference version)

§Returns

A new tensor with the scalar divided by each element

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div<Tensor> for &Tensor

Source§

fn div(self, other: Tensor) -> Tensor

Divides a tensor reference by a tensor element-wise

§Returns

A new tensor containing the element-wise quotient

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div<Tensor> for f32

Scalar-tensor division operator implementations

Provides division operations between scalars and tensors. Computes scalar / tensor by computing the reciprocal of the tensor and multiplying by the scalar.
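
A minimal sketch: each output element is the scalar times the reciprocal of the corresponding input element.

use train_station::Tensor;

// scalar / tensor: 2.0 / [1.0, 2.0, 4.0] == [2.0, 1.0, 0.5]
let t = Tensor::from_slice(&[1.0, 2.0, 4.0], vec![3]).unwrap();
let r = 2.0 / t;
assert_eq!(r.data(), &[2.0, 1.0, 0.5]);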

Source§

fn div(self, tensor: Tensor) -> Tensor

Divides a scalar by each element of the tensor

§Returns

A new tensor with the scalar divided by each element

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div<f32> for &Tensor

Source§

fn div(self, scalar: f32) -> Tensor

Divides each element of the tensor by a scalar (reference version)

§Returns

A new tensor with each element divided by the scalar

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div<f32> for Tensor

Tensor-scalar division operator implementations

Provides division operations between tensors and scalars. All implementations delegate to the underlying div_scalar method.

Source§

fn div(self, scalar: f32) -> Tensor

Divides each element of the tensor by a scalar

§Returns

A new tensor with each element divided by the scalar

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div for &Tensor

Source§

fn div(self, other: &Tensor) -> Tensor

Divides two tensors element-wise (reference version)

§Returns

A new tensor containing the element-wise quotient

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl Div for Tensor

Tensor division operator implementations

Provides element-wise division operations between tensors with various reference combinations. All implementations delegate to the underlying div_tensor method for optimal performance.

Source§

fn div(self, other: Tensor) -> Tensor

Divides two tensors element-wise

§Returns

A new tensor containing the element-wise quotient

Source§

type Output = Tensor

The resulting type after applying the / operator.
Source§

impl DivAssign<&Tensor> for Tensor

Source§

fn div_assign(&mut self, other: &Tensor)

Divides this tensor by another tensor reference in-place

Source§

impl DivAssign<f32> for Tensor

Tensor-scalar division assignment operator implementations

Provides in-place division operations between tensors and scalars.

Source§

fn div_assign(&mut self, scalar: f32)

Divides each element of this tensor by a scalar in-place

Source§

impl DivAssign for Tensor

Tensor division assignment operator implementations

Provides in-place division operations between tensors. All implementations delegate to the underlying div_tensor method.

Source§

fn div_assign(&mut self, other: Tensor)

Divides this tensor by another tensor in-place

Source§

impl From<Tensor> for Vec<f32>

Source§

fn from(tensor: Tensor) -> Vec<f32>

Convert this tensor into a Vec<f32> in row-major order.

  • Contiguous fast path: single optimized copy
  • Non-contiguous: materialize a contiguous copy first
  • Gradient metadata is ignored; this is a pure data extraction API
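
A minimal sketch of the extraction:

use train_station::Tensor;

// Extract tensor contents as a row-major Vec<f32>
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let values: Vec<f32> = tensor.into();
assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);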
Source§

impl From<Vec<f32>> for Tensor

Source§

fn from(v: Vec<f32>) -> Self

Create a Tensor from a Vec<f32> by copying into an aligned/padded allocation.

Note: We do not adopt the Vec’s allocation to preserve alignment and padding guarantees for SIMD operations. Use Tensor::into_vec() to extract data back.
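
A brief sketch, assuming the conversion yields a one-dimensional tensor with one element per input value:

use train_station::Tensor;

// The Vec's contents are copied into a new SIMD-aligned allocation
let t: Tensor = vec![1.0, 2.0, 3.0].into();
assert_eq!(t.size(), 3);
assert_eq!(t.get(&[1]), 2.0);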

Source§

impl FromFieldValue for Tensor

Source§

fn from_field_value( value: FieldValue, field_name: &str, ) -> SerializationResult<Self>

Convert FieldValue to Tensor for use as struct field

§Arguments
  • value - FieldValue containing tensor data
  • field_name - Name of the field for error reporting
§Returns

Tensor instance or error if deserialization fails

Source§

impl FromIterator<Tensor> for Tensor

Source§

fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self

Collect element view tensors back into a single tensor

This method reconstructs a tensor from an iterator of element view tensors. It includes optimizations for common patterns and maintains gradient tracking when appropriate.

The collection process automatically detects whether all elements are scalar views (shape [1]) and uses optimized collection strategies accordingly. Gradient tracking is preserved when any input element requires gradients.

§Performance
  • Optimized Collection: Specialized paths for scalar and mixed views
  • Memory Efficient: Direct memory copying without intermediate allocations
  • Gradient Preservation: Maintains gradtrack functionality when enabled
  • Shape Detection: Automatic detection of element shapes for optimization
§Implementation Details

The method performs the following steps:

  1. Element Collection: Gathers all element tensors from the iterator
  2. Shape Analysis: Determines if all elements are scalar views
  3. Optimized Path: Uses specialized collection for scalar views
  4. General Path: Handles mixed shapes by flattening into 1D tensor
  5. Gradient Setup: Preserves gradient tracking when appropriate
§Examples
§Basic Collection
use train_station::Tensor;

let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let doubled: Tensor = original.iter()
    .map(|elem| elem.mul_scalar(2.0))
    .collect();

assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);
§Collection with Gradient Tracking
use train_station::Tensor;

let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
    .unwrap()
    .with_requires_grad();

let result: Tensor = original.iter()
    .map(|elem| elem.mul_scalar(2.0))
    .collect();

assert!(result.requires_grad());
assert_eq!(result.data(), &[2.0, 4.0]);
§Empty Iterator Handling
use train_station::Tensor;

let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
assert_eq!(empty.size(), 0);
assert_eq!(empty.shape().dims(), vec![0]);
Source§

impl FromIterator<f32> for Tensor

Source§

fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self

Collect f32 values into a 1D contiguous, SIMD-aligned Tensor

  • Streams directly when iterator reports exact size_hint
  • Falls back to temporary Vec and optimized_copy otherwise
  • No gradient tracking is set on the result
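
This lets collect() build a tensor directly from an f32 iterator; a minimal sketch:

use train_station::Tensor;

// Ranges report an exact size_hint, so this takes the streaming path
let t: Tensor = (0..4).map(|i| i as f32).collect();
assert_eq!(t.size(), 4);
assert_eq!(t.data(), &[0.0, 1.0, 2.0, 3.0]);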
Source§

impl<'a> IntoIterator for &'a Tensor

High-performance iterator over tensor elements as view tensors

Each element becomes a proper Tensor view of shape [1] that can use all existing tensor operations and gradient tracking. Implements all standard iterator traits for maximum compatibility with Rust’s ecosystem.

This iterator provides zero-copy access to tensor elements through view tensors, enabling efficient element-wise operations while maintaining full compatibility with Rust’s standard library iterator methods.

§Performance

  • Zero-Copy Views: Each element is a view tensor sharing memory with source
  • O(1) Element Access: Constant-time view creation for each element
  • Memory Efficient: ~64 bytes overhead per element view
  • SIMD Compatible: All tensor operations use existing optimizations
  • Gradient Tracking: Full gradtrack support through element operations

§Implementation Details

The iterator creates lightweight view tensors on-demand, sharing the same memory allocation as the source tensor. This ensures zero-copy semantics while maintaining full tensor operation compatibility.

Each element view is created using Tensor::element_view(), which provides a true view of the underlying data without any copying. The view tensors support all standard tensor operations including gradient tracking.

§Standard Library Compatibility

This iterator implements all standard iterator traits:

  • Iterator: Basic iteration with next() and size_hint()
  • ExactSizeIterator: Precise size information with len()
  • DoubleEndedIterator: Reverse iteration with next_back()
  • FusedIterator: Fused iteration for better performance
  • IntoIterator: Automatic conversion for for loops

§Examples

§Basic Iteration
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

// Basic iteration
for element in tensor.iter() {
    println!("Element value: {}", element.value());
}

// Standard library methods
let sum: f32 = tensor.iter()
    .map(|elem| elem.value())
    .sum();

assert_eq!(sum, 6.0);
§Element Operations
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();

// Tensor operations on elements
let transformed: Tensor = tensor.iter()
    .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
    .collect();

assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
§Advanced Iterator Methods
use train_station::Tensor;

let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();

// Filter and transform
let result: Tensor = tensor.iter()
    .filter(|elem| elem.value() > 2.0)
    .map(|elem| elem.mul_scalar(10.0))
    .collect();

assert_eq!(result.data(), &[30.0, 40.0, 50.0]);

// Reverse iteration
let reversed: Tensor = tensor.iter().rev().collect();
assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);

Note: the IntoIterator implementation for &Tensor (used by for loops) iterates over the outermost dimension, yielding sub-tensor views; use iter() for the element-wise iteration described above.
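
A brief sketch of the dimension-wise behavior, assuming each yielded item is a view of one outermost slice:

use train_station::Tensor;

// A for loop over &Tensor visits the outermost dimension
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let mut rows = 0;
for row in &matrix {
    assert_eq!(row.shape().dims(), vec![2]); // each item is a row view of shape [2]
    rows += 1;
}
assert_eq!(rows, 2);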

Source§

type Item = Tensor

The type of the elements being iterated over.
Source§

type IntoIter = TensorDimIterator<'a>

Which kind of iterator are we turning this into?
Source§

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value. Read more
Source§

impl IntoIterator for Tensor

IntoIterator for owned Tensor: iterate outermost dimension producing sub-tensors. Enables .into_iter().flatten() patterns on owned tensors.
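
A minimal sketch of the consuming iteration, assuming each sub-tensor covers one outermost slice:

use train_station::Tensor;

// Owned iteration also walks the outermost dimension
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let rows: Vec<Tensor> = matrix.into_iter().collect();
assert_eq!(rows.len(), 2);
assert_eq!(rows[0].shape().dims(), vec![2]);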

Source§

type Item = Tensor

The type of the elements being iterated over.
Source§

type IntoIter = TensorDimOwnedIterator

Which kind of iterator are we turning this into?
Source§

fn into_iter(self) -> Self::IntoIter

Creates an iterator from a value. Read more
Source§

impl Mul<&Tensor> for Tensor

Source§

fn mul(self, other: &Tensor) -> Tensor

Multiplies a tensor and a tensor reference element-wise

§Returns

A new tensor containing the element-wise product

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul<&Tensor> for f32

Source§

fn mul(self, tensor: &Tensor) -> Tensor

Multiplies each element of the tensor by a scalar (reference version)

§Returns

A new tensor with each element multiplied by the scalar

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul<Tensor> for &Tensor

Source§

fn mul(self, other: Tensor) -> Tensor

Multiplies a tensor reference and a tensor element-wise

§Returns

A new tensor containing the element-wise product

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul<Tensor> for f32

Scalar-tensor multiplication operator implementations

Provides multiplication operations between scalars and tensors. All implementations delegate to the underlying mul_scalar method.

Source§

fn mul(self, tensor: Tensor) -> Tensor

Multiplies each element of the tensor by a scalar

§Returns

A new tensor with each element multiplied by the scalar

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul<f32> for &Tensor

Source§

fn mul(self, scalar: f32) -> Tensor

Multiplies each element of the tensor by a scalar (reference version)

§Returns

A new tensor with each element multiplied by the scalar

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul<f32> for Tensor

Tensor-scalar multiplication operator implementations

Provides multiplication operations between tensors and scalars. All implementations delegate to the underlying mul_scalar method.

Source§

fn mul(self, scalar: f32) -> Tensor

Multiplies each element of the tensor by a scalar

§Returns

A new tensor with each element multiplied by the scalar

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul for &Tensor

Source§

fn mul(self, other: &Tensor) -> Tensor

Multiplies two tensors element-wise (reference version)

§Returns

A new tensor containing the element-wise product

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl Mul for Tensor

Tensor multiplication operator implementations

Provides element-wise multiplication operations between tensors with various reference combinations. All implementations delegate to the underlying mul_tensor method for optimal performance.

Source§

fn mul(self, other: Tensor) -> Tensor

Multiplies two tensors element-wise

§Returns

A new tensor containing the element-wise product

Source§

type Output = Tensor

The resulting type after applying the * operator.
Source§

impl MulAssign<&Tensor> for Tensor

Source§

fn mul_assign(&mut self, other: &Tensor)

Multiplies this tensor by another tensor reference in-place

Source§

impl MulAssign<f32> for Tensor

Tensor-scalar multiplication assignment operator implementations

Provides in-place multiplication operations between tensors and scalars.
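
A minimal sketch of in-place scaling:

use train_station::Tensor;

// In-place scalar multiplication via the *= operator
let mut t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
t *= 2.0;
assert_eq!(t.data(), &[2.0, 4.0, 6.0]);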

Source§

fn mul_assign(&mut self, scalar: f32)

Multiplies each element of this tensor by a scalar in-place

Source§

impl MulAssign for Tensor

Tensor multiplication assignment operator implementations

Provides in-place multiplication operations between tensors. All implementations delegate to the underlying mul_tensor method.

Source§

fn mul_assign(&mut self, other: Tensor)

Multiplies this tensor by another tensor in-place

Source§

impl Neg for &Tensor

Source§

fn neg(self) -> Tensor

Negates each element of the tensor (reference version)

§Returns

A new tensor with each element negated

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Neg for Tensor

Tensor negation operator implementations

Provides unary negation operations for tensors. All implementations delegate to the underlying mul_scalar method with -1.0.
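
A minimal sketch:

use train_station::Tensor;

// Unary negation is equivalent to multiplying each element by -1.0
let t = Tensor::from_slice(&[1.0, -2.0, 3.0], vec![3]).unwrap();
let n = -t;
assert_eq!(n.data(), &[-1.0, 2.0, -3.0]);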

Source§

fn neg(self) -> Tensor

Negates each element of the tensor

§Returns

A new tensor with each element negated

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Serializable for Tensor

Source§

fn to_json(&self) -> SerializationResult<String>

Serialize the tensor to JSON format

This method converts the tensor into a human-readable JSON string representation that includes all tensor data, shape information, device placement, and gradtrack state. The JSON format is suitable for debugging, configuration files, and cross-language interoperability.

§Returns

JSON string representation of the tensor on success, or SerializationError on failure

§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;

let mut tensor = Tensor::zeros(vec![2, 3]);
tensor.set(&[0, 0], 1.0);
tensor.set(&[1, 2], 5.0);

let json = tensor.to_json().unwrap();
assert!(!json.is_empty());
assert!(json.contains("data"));
assert!(json.contains("shape"));
Source§

fn from_json(json: &str) -> SerializationResult<Self>

Deserialize a tensor from JSON format

This method parses a JSON string and reconstructs a tensor with all its data, shape information, device placement, and gradtrack state. The JSON must contain all necessary fields in the expected format.

§Arguments
  • json - JSON string containing serialized tensor data
§Returns

The deserialized tensor on success, or SerializationError on failure

§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;

let mut original = Tensor::ones(vec![2, 2]);
original.set(&[0, 1], 3.0);
original.set_requires_grad(true);

let json = original.to_json().unwrap();
let restored = Tensor::from_json(&json).unwrap();

assert_eq!(original.shape().dims(), restored.shape().dims());
assert_eq!(original.get(&[0, 1]), restored.get(&[0, 1]));
assert_eq!(original.requires_grad(), restored.requires_grad());
Source§

fn to_binary(&self) -> SerializationResult<Vec<u8>>

Serialize the tensor to binary format

This method converts the tensor into a compact binary representation optimized for storage and transmission. The binary format provides maximum performance and minimal file sizes, making it ideal for large tensors and production use.

§Returns

Binary representation of the tensor on success, or SerializationError on failure

§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;

let mut tensor = Tensor::zeros(vec![100, 100]);
for i in 0..10 {
    tensor.set(&[i, i], i as f32);
}

let binary = tensor.to_binary().unwrap();
assert!(!binary.is_empty());
// Binary format is more compact than JSON for large tensors
Source§

fn from_binary(data: &[u8]) -> SerializationResult<Self>

Deserialize a tensor from binary format

This method parses binary data and reconstructs a tensor with all its data, shape information, device placement, and gradtrack state. The binary data must contain complete serialized information in the expected format.

§Arguments
  • data - Binary data containing serialized tensor information
§Returns

The deserialized tensor on success, or SerializationError on failure

§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;

let mut original = Tensor::ones(vec![3, 4]);
original.set(&[2, 3], 7.5);
original.set_requires_grad(true);

let binary = original.to_binary().unwrap();
let restored = Tensor::from_binary(&binary).unwrap();

assert_eq!(original.shape().dims(), restored.shape().dims());
assert_eq!(original.get(&[2, 3]), restored.get(&[2, 3]));
assert_eq!(original.requires_grad(), restored.requires_grad());
Source§

fn save<P: AsRef<Path>>( &self, path: P, format: Format, ) -> SerializationResult<()>

Save the object to a file in the specified format Read more
Source§

fn save_to_writer<W: Write>( &self, writer: &mut W, format: Format, ) -> SerializationResult<()>

Save the object to a writer in the specified format Read more
Source§

fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self>

Load an object from a file in the specified format Read more
Source§

fn load_from_reader<R: Read>( reader: &mut R, format: Format, ) -> SerializationResult<Self>

Load an object from a reader in the specified format Read more
Source§

impl StructSerializable for Tensor

Source§

fn to_serializer(&self) -> StructSerializer

Convert Tensor to StructSerializer for serialization

Serializes tensor data, shape, device, and gradtrack state. Runtime state (id, grad, grad_fn, allocation_owner) is not serialized.

§Returns

StructSerializer containing all persistent tensor state

Source§

fn from_deserializer( deserializer: &mut StructDeserializer, ) -> SerializationResult<Self>

Create Tensor from StructDeserializer

Reconstructs tensor from serialized data, shape, device, and gradtrack state. Allocates new memory and generates new tensor ID.

§Arguments
  • deserializer - StructDeserializer containing tensor data
§Returns

Reconstructed Tensor instance or error if deserialization fails

Source§

fn save_json<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>

Saves the struct to a JSON file Read more
Source§

fn save_binary<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>

Saves the struct to a binary file Read more
Source§

fn load_json<P: AsRef<Path>>(path: P) -> SerializationResult<Self>

Loads the struct from a JSON file Read more
Source§

fn load_binary<P: AsRef<Path>>(path: P) -> SerializationResult<Self>

Loads the struct from a binary file Read more
Source§

fn to_json(&self) -> SerializationResult<String>

Converts the struct to a JSON string Read more
Source§

fn to_binary(&self) -> SerializationResult<Vec<u8>>

Converts the struct to binary data Read more
Source§

fn from_json(json: &str) -> SerializationResult<Self>

Creates the struct from a JSON string Read more
Source§

fn from_binary(data: &[u8]) -> SerializationResult<Self>

Creates the struct from binary data Read more
Source§

impl Sub<&Tensor> for Tensor

Source§

fn sub(self, other: &Tensor) -> Tensor

Subtracts a tensor reference from a tensor element-wise

§Returns

A new tensor containing the element-wise difference

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub<&Tensor> for f32

Source§

fn sub(self, tensor: &Tensor) -> Tensor

Subtracts each element of the tensor from the scalar (reference version)

§Returns

A new tensor with each element subtracted from the scalar

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub<Tensor> for &Tensor

Source§

fn sub(self, other: Tensor) -> Tensor

Subtracts a tensor from a tensor reference element-wise

§Returns

A new tensor containing the element-wise difference

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub<Tensor> for f32

Scalar-tensor subtraction operator implementations

Provides subtraction operations between scalars and tensors. Computes scalar - tensor by negating the tensor and adding the scalar.
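
A minimal sketch:

use train_station::Tensor;

// scalar - tensor negates each element, then adds the scalar
let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let r = 10.0 - t;
assert_eq!(r.data(), &[9.0, 8.0, 7.0]);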

Source§

fn sub(self, tensor: Tensor) -> Tensor

Subtracts each element of the tensor from the scalar

§Returns

A new tensor with each element subtracted from the scalar

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub<f32> for &Tensor

Source§

fn sub(self, scalar: f32) -> Tensor

Subtracts a scalar from each element of the tensor (reference version)

§Returns

A new tensor with the scalar subtracted from each element

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub<f32> for Tensor

Tensor-scalar subtraction operator implementations

Provides subtraction operations between tensors and scalars. All implementations delegate to the underlying sub_scalar method.

Source§

fn sub(self, scalar: f32) -> Tensor

Subtracts a scalar from each element of the tensor

§Returns

A new tensor with the scalar subtracted from each element

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub for &Tensor

Source§

fn sub(self, other: &Tensor) -> Tensor

Subtracts two tensors element-wise (reference version)

§Returns

A new tensor containing the element-wise difference

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl Sub for Tensor

Tensor subtraction operator implementations

Provides subtraction operations between tensors with various reference combinations. All implementations delegate to the underlying sub_tensor method for optimal performance.

Source§

fn sub(self, other: Tensor) -> Tensor

Subtracts two tensors element-wise

§Returns

A new tensor containing the element-wise difference

Source§

type Output = Tensor

The resulting type after applying the - operator.
Source§

impl SubAssign<&Tensor> for Tensor

Source§

fn sub_assign(&mut self, other: &Tensor)

Subtracts another tensor reference from this tensor in-place

Source§

impl SubAssign<f32> for Tensor

Tensor-scalar subtraction assignment operator implementations

Provides in-place subtraction operations between tensors and scalars.

Source§

fn sub_assign(&mut self, scalar: f32)

Subtracts a scalar from each element of this tensor in-place

Source§

impl SubAssign for Tensor

Tensor subtraction assignment operator implementations

Provides in-place subtraction operations between tensors. All implementations delegate to the underlying sub_tensor method.

Source§

fn sub_assign(&mut self, other: Tensor)

Subtracts another tensor from this tensor in-place

Source§

impl Send for Tensor

Source§

impl Sync for Tensor

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToFieldValue for T
where T: Serializable,

Source§

fn to_field_value(&self) -> FieldValue

Converts the value to a FieldValue for serialization Read more
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning. Read more
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning. Read more
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.