pub struct Tensor { /* private fields */ }
High-performance multi-dimensional tensor with automatic differentiation support
The core data structure for machine learning operations, designed for maximum performance with zero-cost abstractions. Supports arbitrary dimensionality, SIMD optimization, gradient tracking, device placement, and natural mathematical expressions through operator overloading.
§Key Features
- Raw Pointer Storage: Zero-overhead memory access for maximum performance
- SIMD Optimization: AVX2 alignment and vectorized operations
- Memory Efficiency: Optimized alignment strategies for different tensor sizes
- gradtrack Integration: Built-in gradient tracking and computation
- Device Support: CPU and future CUDA device placement
- View Tensors: Zero-copy tensor views with shared memory management
- Thread Safety: Send + Sync implementation for concurrent usage
- Operator Overloading: Natural mathematical expressions (+, -, *, /, +=, -=, *=, /=)
§Memory Layout
Tensors use row-major memory layout with size-dependent alignment (see the sketch after this list):
- Small tensors (≤8 elements): 16-byte SSE alignment
- Medium tensors (8-1024 elements): 32-byte AVX2 alignment
- Large tensors (>1024 elements): 64-byte cache-line alignment
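The thresholds above can be summarized with a small helper. The function below is purely illustrative (it is not part of the crate API; the actual allocation logic is internal):
// Hypothetical helper mirroring the documented thresholds; illustration only
fn alignment_for(elements: usize) -> usize {
    if elements <= 8 {
        16 // SSE alignment for small tensors
    } else if elements <= 1024 {
        32 // AVX2 alignment for medium tensors
    } else {
        64 // cache-line alignment for large tensors
    }
}
assert_eq!(alignment_for(6), 16);
assert_eq!(alignment_for(32 * 32), 32);
assert_eq!(alignment_for(1000 * 1000), 64);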
§Performance Characteristics
- Memory Overhead: ~64 bytes per tensor (excluding data)
- SIMD Ready: Properly aligned for vectorized operations
- Cache Friendly: Optimized memory layout for CPU cache hierarchies
- Zero-Cost Views: View tensors share memory without copying
- Thread Safe: Atomic ID generation and lock-free operations
- Operator Performance: Zero-cost operator overloading for mathematical expressions
§Safety
This struct uses unsafe code for performance. The following invariants must be maintained:
- data must be valid for shape.size elements
- data must be properly aligned for f32
- data must not be aliased while the tensor exists
- shape.size must match the actual allocated memory
- allocation_owner must be valid if present
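These invariants are upheld automatically when tensors are created and initialized through the public API. A minimal sketch, assuming only the constructors shown in the examples on this page:
use train_station::Tensor;
// Prefer constructors that return initialized data
let zeros = Tensor::zeros(vec![2, 3]);
let from_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
// When new() is used, initialize the data before any read
let mut raw = Tensor::new(vec![2, 3]);
raw.fill(0.0);
assert_eq!(raw.size(), zeros.size());
assert_eq!(from_data.size(), 4);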
§Examples
§Basic Tensor Operations
use train_station::Tensor;
// Create tensors with different configurations
let tensor = Tensor::new(vec![2, 3]);
let tensor_with_grad = Tensor::ones(vec![10, 10]).with_requires_grad();
// Access tensor properties
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);
assert!(tensor.is_contiguous());
§Operator Overloading
use train_station::Tensor;
// Create tensors for operations
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
// Tensor operations with operators
let result = a.clone() + b.clone(); // Tensor addition
let result = a.clone() * b.clone(); // Element-wise multiplication
let result = a.clone() - b.clone(); // Tensor subtraction
let result = a.clone() / b.clone(); // Element-wise division
// Scalar operations
let result = a.clone() + 5.0; // Tensor + scalar
let result = 5.0 + a.clone(); // Scalar + tensor
let result = a.clone() * 3.0; // Tensor * scalar
let result = 3.0 * a.clone(); // Scalar * tensor
// Compound expressions
let result = (a.clone() + b.clone()) * 2.0 - 1.0; // Complex mathematical expressions
// Assignment operators
let mut c = a.clone();
c += b.clone(); // In-place addition
c *= 2.0; // In-place scalar multiplication
// Negation
let result = -a; // Negate all elements
§Thread Safety
This type is Send + Sync and can be safely shared between threads.
All operations are thread-safe through atomic ID generation and
thread-local gradtrack storage.
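A minimal sketch of concurrent read-only access, assuming only the constructors and accessors shown on this page: because the tensor is Send + Sync, an Arc-wrapped tensor can be read from several threads.
use std::sync::Arc;
use std::thread;
use train_station::Tensor;
// Share a read-only tensor across threads via Arc
let shared = Arc::new(Tensor::ones(vec![4, 4]));
let handles: Vec<_> = (0..4)
    .map(|_| {
        let tensor = Arc::clone(&shared);
        thread::spawn(move || tensor.size())
    })
    .collect();
for handle in handles {
    assert_eq!(handle.join().unwrap(), 16);
}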
Implementations§
impl Tensor
pub fn capacity_elems(&self) -> usize
Returns the allocated capacity in elements, which may be padded beyond logical size
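A brief illustrative sketch of the expected relationship between logical size and allocated capacity (the exact amount of padding is an implementation detail and may vary):
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3]);
// Logical size is 6 elements; the allocation may include alignment padding
assert_eq!(tensor.size(), 6);
assert!(tensor.capacity_elems() >= tensor.size());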
pub fn new(shape_dims: Vec<usize>) -> Self
Creates a new tensor with the specified shape and optimized memory layout
Allocates memory with size-dependent alignment for optimal performance:
- Small tensors (≤8 elements): 16-byte SSE alignment
- Medium tensors (8-1024 elements): 32-byte AVX2 alignment
- Large tensors (>1024 elements): 64-byte cache-line alignment
§Arguments
shape_dims - Vector of dimension sizes defining the tensor shape
§Returns
A new tensor with uninitialized data. The data must be initialized before use to avoid undefined behavior.
§Performance
- Memory Allocation: Single allocation with optimized alignment
- SIMD Ready: Properly aligned for vectorized operations
- Cache Friendly: Optimized for CPU cache hierarchies
- Thread Safe: Atomic ID generation for gradtrack tracking
§Safety
The returned tensor contains uninitialized memory. You must initialize the data before performing any operations that read from it.
§Examples
use train_station::Tensor;
// Create tensors of different sizes
let small_tensor = Tensor::new(vec![2, 3]); // 16-byte alignment
let medium_tensor = Tensor::new(vec![32, 32]); // 32-byte alignment
let large_tensor = Tensor::new(vec![1000, 1000]); // 64-byte alignment
// Initialize data before use
let mut tensor = Tensor::new(vec![2, 3]);
tensor.fill(0.0); // Initialize with zeros
Examples found in repository?
42fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 // Create tensors with different initializations
46 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 // Create tensor from slice
61 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 // Create tensor with specific value
70 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 // Create tensor with random data
80 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}
pub fn shape(&self) -> &Shape
Returns the shape and dimensional information of the tensor
Provides access to the tensor’s dimensions, size, strides, and memory layout information. This is used for shape validation, memory access calculations, and optimization decisions.
§Returns
Reference to the tensor’s shape information containing dimensions, size, strides, and memory layout type.
§Performance
- Time Complexity: O(1) - direct field access
- Memory: No allocation - returns reference to existing data
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3, 4]);
let shape = tensor.shape();
assert_eq!(shape.dims(), vec![2, 3, 4]);
assert_eq!(shape.size(), 24);
assert_eq!(shape.rank(), 3);
Examples found in repository?
77 fn triple(t: &Tensor) -> (usize, usize, usize) {
78 let d = t.shape().dims();
79 (d[0], d[1], d[2])
80 }
81}
82
83#[allow(unused)]
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102 for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}
More examples
66 fn triple(t: &Tensor) -> (usize, usize, usize) {
67 let d = t.shape().dims();
68 (d[0], d[1], d[2])
69 }
70}
71
72#[allow(unused)]
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
153fn demonstrate_layer_creation() {
154 println!("--- Layer Creation ---");
155
156 let layer = LinearLayer::new(3, 2, Some(42));
157
158 println!("Created linear layer:");
159 println!(" Input size: {}", layer.input_size);
160 println!(" Output size: {}", layer.output_size);
161 println!(" Parameter count: {}", layer.parameter_count());
162 println!(" Weight shape: {:?}", layer.weight.shape().dims());
163 println!(" Bias shape: {:?}", layer.bias.shape().dims());
164 println!(" Weight requires grad: {}", layer.weight.requires_grad());
165 println!(" Bias requires grad: {}", layer.bias.requires_grad());
166}
167
168/// Demonstrate forward pass with gradient tracking
169fn demonstrate_forward_pass() {
170 println!("\n--- Forward Pass (with gradients) ---");
171
172 let layer = LinearLayer::new(3, 2, Some(43));
173
174 // Single input
175 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176 let output = layer.forward(&input);
177
178 println!("Single input:");
179 println!(" Input: {:?}", input.data());
180 println!(" Output: {:?}", output.data());
181 println!(" Output requires grad: {}", output.requires_grad());
182
183 // Batch input
184 let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185 let batch_output = layer.forward(&batch_input);
186
187 println!("Batch input:");
188 println!(" Input shape: {:?}", batch_input.shape().dims());
189 println!(" Output shape: {:?}", batch_output.shape().dims());
190 println!(" Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195 println!("\n--- Forward Pass (no gradients) ---");
196
197 let layer = LinearLayer::new(3, 2, Some(44));
198
199 // Single input
200 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201 let output = layer.forward_no_grad(&input);
202
203 println!("Single input (no grad):");
204 println!(" Input: {:?}", input.data());
205 println!(" Output: {:?}", output.data());
206 println!(" Output requires grad: {}", output.requires_grad());
207
208 // Compare with grad version
209 let output_with_grad = layer.forward(&input);
210 println!("Comparison:");
211 println!(
212 " Same values: {}",
213 output.data() == output_with_grad.data()
214 );
215 println!(" No grad requires grad: {}", output.requires_grad());
216 println!(
217 " With grad requires grad: {}",
218 output_with_grad.requires_grad()
219 );
220}
221
222/// Demonstrate complete training loop
223fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
224 println!("\n--- Training Loop ---");
225
226 // Create layer and training data
227 let mut layer = LinearLayer::new(2, 1, Some(45));
228
229 // Simple regression task: y = 2*x1 + 3*x2 + 1
230 let x_data = Tensor::from_slice(
231 &[
232 1.0, 1.0, // x1=1, x2=1 -> y=6
233 2.0, 1.0, // x1=2, x2=1 -> y=8
234 1.0, 2.0, // x1=1, x2=2 -> y=9
235 2.0, 2.0, // x1=2, x2=2 -> y=11
236 ],
237 vec![4, 2],
238 )
239 .unwrap();
240
241 let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();
242
243 println!("Training data:");
244 println!(" X shape: {:?}", x_data.shape().dims());
245 println!(" Y shape: {:?}", y_true.shape().dims());
246 println!(" Target function: y = 2*x1 + 3*x2 + 1");
247
248 // Create optimizer
249 let config = AdamConfig {
250 learning_rate: 0.01,
251 beta1: 0.9,
252 beta2: 0.999,
253 eps: 1e-8,
254 weight_decay: 0.0,
255 amsgrad: false,
256 };
257
258 let mut optimizer = Adam::with_config(config);
259 let params = layer.parameters();
260 for param in &params {
261 optimizer.add_parameter(param);
262 }
263
264 println!("Optimizer setup complete. Starting training...");
265
266 // Training loop
267 let num_epochs = 100;
268 let mut losses = Vec::new();
269
270 for epoch in 0..num_epochs {
271 // Forward pass
272 let y_pred = layer.forward(&x_data);
273
274 // Compute loss: MSE
275 let diff = y_pred.sub_tensor(&y_true);
276 let mut loss = diff.pow_scalar(2.0).mean();
277
278 // Backward pass
279 loss.backward(None);
280
281 // Optimizer step
282 let mut params = layer.parameters();
283 optimizer.step(&mut params);
284 optimizer.zero_grad(&mut params);
285
286 losses.push(loss.value());
287
288 // Print progress
289 if epoch % 20 == 0 || epoch == num_epochs - 1 {
290 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
291 }
292 }
293
294 // Evaluate final model
295 let final_predictions = layer.forward_no_grad(&x_data);
296
297 println!("\nFinal model evaluation:");
298 println!(" Learned weights: {:?}", layer.weight.data());
299 println!(" Learned bias: {:?}", layer.bias.data());
300 println!(" Target weights: [2.0, 3.0]");
301 println!(" Target bias: [1.0]");
302
303 println!(" Predictions vs True:");
304 for i in 0..4 {
305 let pred = final_predictions.data()[i];
306 let true_val = y_true.data()[i];
307 println!(
308 " Sample {}: pred={:.3}, true={:.1}, error={:.3}",
309 i + 1,
310 pred,
311 true_val,
312 (pred - true_val).abs()
313 );
314 }
315
316 // Training analysis
317 let initial_loss = losses[0];
318 let final_loss = losses[losses.len() - 1];
319 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
320
321 println!("\nTraining Analysis:");
322 println!(" Initial loss: {:.6}", initial_loss);
323 println!(" Final loss: {:.6}", final_loss);
324 println!(" Loss reduction: {:.1}%", loss_reduction);
325
326 Ok(())
327}
328
329/// Demonstrate single vs batch inference
330fn demonstrate_single_vs_batch_inference() {
331 println!("\n--- Single vs Batch Inference ---");
332
333 let layer = LinearLayer::new(4, 3, Some(46));
334
335 // Single inference
336 println!("Single inference:");
337 let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338 let single_output = layer.forward_no_grad(&single_input);
339 println!(" Input shape: {:?}", single_input.shape().dims());
340 println!(" Output shape: {:?}", single_output.shape().dims());
341 println!(" Output: {:?}", single_output.data());
342
343 // Batch inference
344 println!("Batch inference:");
345 let batch_input = Tensor::from_slice(
346 &[
347 1.0, 2.0, 3.0, 4.0, // Sample 1
348 5.0, 6.0, 7.0, 8.0, // Sample 2
349 9.0, 10.0, 11.0, 12.0, // Sample 3
350 ],
351 vec![3, 4],
352 )
353 .unwrap();
354 let batch_output = layer.forward_no_grad(&batch_input);
355 println!(" Input shape: {:?}", batch_input.shape().dims());
356 println!(" Output shape: {:?}", batch_output.shape().dims());
357
358 // Verify batch consistency - first sample should match single inference
359 let _first_batch_sample = batch_output.view(vec![3, 3]); // Reshape to access first sample
360 let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361 let single_sample_data = single_output.data();
362
363 println!("Consistency check:");
364 println!(" Single output: {:?}", single_sample_data);
365 println!(" First batch sample: {:?}", first_sample_data);
366 println!(
367 " Match: {}",
368 single_sample_data
369 .iter()
370 .zip(first_sample_data.iter())
371 .all(|(a, b)| (a - b).abs() < 1e-6)
372 );
373}
- examples/iterators/element_iteration.rs
- examples/getting_started/tensor_basics.rs
- examples/getting_started/tensor_operators.rs
- examples/getting_started/optimizer_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/neural_networks/basic_transformer.rs
- examples/getting_started/serialization_basics.rs
- examples/neural_networks/multi_head_attention.rs
- examples/iterators/advanced_patterns.rs
pub fn size(&self) -> usize
Returns the total number of elements in the tensor
Provides the total count of elements across all dimensions. This is used for memory allocation, iteration bounds, and performance optimization.
§Returns
Total number of elements as usize
§Performance
- Time Complexity: O(1) - direct field access
- Memory: No allocation - returns stored value
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3, 4]);
assert_eq!(tensor.size(), 24); // 2 * 3 * 4
let scalar = Tensor::new(vec![1]);
assert_eq!(scalar.size(), 1);
let empty = Tensor::new(vec![0]);
assert_eq!(empty.size(), 0);
Examples found in repository?
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
More examples
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163 println!("\n--- Memory Optimization ---");
164
165 // Create a large tensor for memory testing
166 let size = 10000;
167 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168 let tensor = Tensor::from_slice(&data, vec![size])?;
169
170 println!("Processing tensor of size: {}", size);
171
172 // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173 println!("\nPattern 1: Streaming Processing");
174 let chunk_size = 1000;
175 let start = Instant::now();
176 let flattened = tensor.view(vec![size as i32]);
177 let _streamed_result: Tensor = flattened
178 .chunks(chunk_size)
179 .map(|c| c.pow_scalar(2.0).sqrt())
180 .collect_shape(vec![size]);
181 let streamed_time = start.elapsed();
182
183 // Pattern 2: Full processing
184 let start = Instant::now();
185 let _full_result: Tensor = tensor
186 .iter_elements()
187 .map(|elem| elem.pow_scalar(2.0).sqrt())
188 .collect_shape(vec![size]);
189 let full_time = start.elapsed();
190
191 println!(" Streaming time: {:?}", streamed_time);
192 println!(" Full processing time: {:?}", full_time);
193 println!(
194 " Memory efficiency ratio: {:.2}x",
195 full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196 );
197
198 // Pattern 3: Lazy evaluation with take
199 println!("\nPattern 2: Lazy Evaluation");
200 let start = Instant::now();
201 let lazy_result: Tensor = tensor
202 .iter_elements()
203 .take(1000) // Only process first 1000 elements
204 .map(|elem| elem.pow_scalar(2.0).sqrt())
205 .collect_shape(vec![1000]);
206 let lazy_time = start.elapsed();
207
208 println!(" Lazy processing (1000 elements): {:?}", lazy_time);
209 println!(" Lazy result size: {}", lazy_result.size());
210
211 // Pattern 4: Memory-efficient filtering
212 println!("\nPattern 3: Memory-Efficient Filtering");
213 let start = Instant::now();
214 let filtered_result: Tensor = tensor
215 .iter_elements()
216 .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217 .map(|elem| elem.mul_scalar(2.0))
218 .collect();
219 let filtered_time = start.elapsed();
220
221 println!(" Filtered processing: {:?}", filtered_time);
222 println!(
223 " Filtered result size: {} (reduced from {})",
224 filtered_result.size(),
225 size
226 );
227
228 Ok(())
229}
230
231/// Demonstrate large-scale processing techniques
232///
233/// Shows how to efficiently process very large datasets using
234/// iterator patterns and optimization strategies.
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236 println!("\n--- Large-Scale Processing ---");
237
238 // Simulate large dataset processing
239 let sizes = vec![10000, 50000, 100000];
240
241 for size in sizes {
242 println!("\nProcessing dataset of size: {}", size);
243
244 // Generate large dataset
245 let data: Vec<f32> = (0..size)
246 .map(|i| {
247 let x = i as f32 / size as f32;
248 x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249 })
250 .collect();
251
252 let tensor = Tensor::from_slice(&data, vec![size])?;
253
254 // Technique 1: Batch processing
255 let batch_size = 1000;
256 let start = Instant::now();
257
258 let mut batch_results = Vec::new();
259 for batch_start in (0..size).step_by(batch_size) {
260 let batch_end = (batch_start + batch_size).min(size);
261 let batch: Tensor = tensor
262 .iter_range(batch_start, batch_end)
263 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264 .collect();
265 batch_results.push(batch);
266 }
267 let batch_time = start.elapsed();
268
269 // Technique 2: Parallel-like processing with stride
270 let start = Instant::now();
271 let stride = 4;
272 let strided_result: Tensor = tensor
273 .iter()
274 .enumerate()
275 .filter(|(i, _)| i % stride == 0)
276 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect();
278 let strided_time = start.elapsed();
279
280 // Technique 3: Hierarchical processing
281 let start = Instant::now();
282 let coarse: Tensor = tensor
283 .iter()
284 .enumerate()
285 .filter(|(i, _)| i % 10 == 0) // Every 10th element
286 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287 .collect();
288 let fine: Tensor = tensor
289 .iter()
290 .enumerate()
291 .filter(|(i, _)| i % 10 != 0) // Rest of elements
292 .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293 .collect();
294 let hierarchical_time = start.elapsed();
295
296 // Report performance
297 println!(" Batch processing: {:?}", batch_time);
298 println!(" Strided processing: {:?}", strided_time);
299 println!(" Hierarchical processing: {:?}", hierarchical_time);
300
301 // Memory usage analysis
302 let total_batches = size.div_ceil(batch_size);
303 println!(" Batch count: {}", total_batches);
304 println!(" Strided result size: {}", strided_result.size());
305 println!(
306 " Hierarchical: coarse={}, fine={}",
307 coarse.size(),
308 fine.size()
309 );
310 }
311
312 Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320 println!("\n--- Optimization Techniques ---");
321
322 let size = 50000;
323 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324 let tensor = Tensor::from_slice(&data, vec![size])?;
325
326 println!("Optimizing processing for size: {}", size);
327
328 // Technique 1: Operation fusion
329 println!("\nTechnique 1: Operation Fusion");
330 let start = Instant::now();
331 let fused_result: Tensor = tensor
332 .iter()
333 .map(|elem| {
334 // Fuse multiple operations into single chain
335 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336 })
337 .collect();
338 let fused_time = start.elapsed();
339
340 // Technique 2: Conditional optimization
341 println!("\nTechnique 2: Conditional Optimization");
342 let start = Instant::now();
343 let conditional_result: Tensor = tensor
344 .iter()
345 .map(|elem| {
346 let val = elem.value();
347 if val < size as f32 / 2.0 {
348 elem.mul_scalar(2.0) // Simple operation for small values
349 } else {
350 elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351 }
352 })
353 .collect();
354 let conditional_time = start.elapsed();
355
356 // Technique 3: Cache-friendly processing
357 println!("\nTechnique 3: Cache-Friendly Processing");
358 let start = Instant::now();
359 let cache_friendly_result: Tensor = tensor
360 .iter()
361 .take(1000) // Process in cache-friendly chunks
362 .map(|elem| elem.mul_scalar(2.0))
363 .collect();
364 let cache_friendly_time = start.elapsed();
365
366 // Technique 4: Memory pooling simulation
367 println!("\nTechnique 4: Memory Pooling Simulation");
368 let start = Instant::now();
369 let pooled_result: Tensor = tensor
370 .iter()
371 .enumerate()
372 .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373 .map(|(_, elem)| elem.pow_scalar(2.0))
374 .collect();
375 let pooled_time = start.elapsed();
376
377 // Report optimization results
378 println!(" Fused operations: {:?}", fused_time);
379 println!(" Conditional optimization: {:?}", conditional_time);
380 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
381 println!(" Memory pooling simulation: {:?}", pooled_time);
382
383 // Performance analysis
384 let fastest = fused_time
385 .min(conditional_time)
386 .min(cache_friendly_time)
387 .min(pooled_time);
388 println!(" Fastest technique: {:?}", fastest);
389
390 // Memory efficiency analysis
391 println!(" Fused result size: {}", fused_result.size());
392 println!(" Conditional result size: {}", conditional_result.size());
393 println!(
394 " Cache-friendly result size: {}",
395 cache_friendly_result.size()
396 );
397 println!(" Pooled result size: {}", pooled_result.size());
398
399 // Technique 5: Gradient optimization
400 println!("\nTechnique 5: Gradient Optimization");
401 let grad_tensor = tensor.with_requires_grad();
402 let start = Instant::now();
403
404 let grad_result: Tensor = grad_tensor
405 .iter()
406 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407 .collect();
408
409 let mut loss = grad_result.sum();
410 loss.backward(None);
411 let grad_time = start.elapsed();
412
413 println!(" Gradient computation: {:?}", grad_time);
414 println!(
415 " Gradient tracking enabled: {}",
416 grad_result.requires_grad()
417 );
418
419 Ok(())
420}
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251 println!("\n--- Batch Operations ---");
252
253 // Create a larger dataset for batch processing
254 let size = 100;
255 let data: Vec<f32> = (0..size)
256 .map(|i| {
257 let x = i as f32 / size as f32;
258 x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259 })
260 .collect();
261
262 let tensor = Tensor::from_slice(&data, vec![size])?;
263 println!("Dataset size: {}", tensor.size());
264
265 // Batch processing with windowing (iterator views)
266 println!("\nBatch processing with sliding windows:");
267 let batch_size = 10;
268 let batches: Vec<Tensor> = tensor
269 .iter()
270 .collect::<Vec<_>>()
271 .chunks(batch_size)
272 .map(|chunk| {
273 // Process each batch independently
274 chunk
275 .iter()
276 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect()
278 })
279 .collect();
280
281 println!(
282 " Processed {} batches of size {}",
283 batches.len(),
284 batch_size
285 );
286 for (i, batch) in batches.iter().enumerate() {
287 println!(
288 " Batch {}: mean={:.3}, std={:.3}",
289 i,
290 batch.mean().value(),
291 batch.std().value()
292 );
293 }
294
295 // Parallel-like processing with stride
296 println!("\nStrided processing (every nth element):");
297 let stride = 5;
298 let strided: Tensor = tensor
299 .iter()
300 .enumerate()
301 .filter(|(i, _)| i % stride == 0)
302 .map(|(_, elem)| elem)
303 .collect();
304 println!(" Strided (every {}th): {:?}", stride, strided.data());
305
306 // Hierarchical processing
307 println!("\nHierarchical processing (coarse to fine):");
308 let coarse: Tensor = tensor
309 .iter()
310 .enumerate()
311 .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312 .map(|(_, elem)| elem)
313 .collect();
314
315 let fine: Tensor = tensor
316 .iter()
317 .enumerate()
318 .filter(|(i, _)| i % 4 != 0) // Take the rest
319 .map(|(_, elem)| elem)
320 .collect();
321
322 println!(" Coarse (every 4th): {:?}", coarse.data());
323 println!(" Fine (rest): {:?}", fine.data());
324
325 // Combine coarse and fine with different processing
326 let combined: Tensor = coarse
327 .iter()
328 .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329 .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330 .collect();
331 println!(" Combined: {:?}", combined.data());
332
333 Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
pub fn device(&self) -> Device
Returns the device where this tensor is located
Provides the physical location of the tensor data (CPU/GPU). This determines which operations can be performed on the tensor and where computations will be executed.
§Returns
Device enum indicating the tensor’s physical location
§Performance
- Time Complexity: O(1) - direct field access
- Memory: No allocation - returns stored value
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3]);
assert!(tensor.device().is_cpu());
assert!(!tensor.device().is_cuda());
Examples found in repository?
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
pub fn new_on_device(shape_dims: Vec<usize>, device: Device) -> Self
Creates a new tensor with the specified shape on a specific device
Allocates memory on the specified device with the same optimized alignment
strategy as new(). Currently supports CPU device with future CUDA support.
§Arguments
shape_dims - Vector of dimension sizes defining the tensor shape
device - The device where the tensor should be allocated
§Returns
A new tensor with uninitialized data on the specified device
§Performance
- Memory Allocation: Device-specific allocation with optimized alignment
- SIMD Ready: Properly aligned for vectorized operations on target device
- Thread Safe: Atomic ID generation for gradtrack tracking
§Panics
Panics if the specified device is not supported (e.g., CUDA without feature flag)
§Examples
use train_station::Tensor;
let tensor = Tensor::new_on_device(vec![2, 3], train_station::Device::cpu());
assert!(tensor.device().is_cpu());
assert_eq!(tensor.size(), 6);
pub fn with_requires_grad(self) -> Self
Enable gradient computation for this tensor
Builder method that enables automatic gradient tracking for this tensor. When enabled, all operations involving this tensor will be recorded in the computation graph for gradient computation during backward pass.
§Returns
self with gradient tracking enabled
§Performance
- Time Complexity: O(1) - simple field assignment
- Memory: No additional allocation
- Overhead: Minimal gradtrack tracking overhead when gradients computed
§Examples
use train_station::Tensor;
let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(tensor.requires_grad());
Examples found in repository?
53 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54 // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55 let scale = (1.0 / input_size as f32).sqrt();
56
57 let weight = Tensor::randn(vec![input_size, output_size], seed)
58 .mul_scalar(scale)
59 .with_requires_grad();
60 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62 Self {
63 weight,
64 bias,
65 input_size,
66 output_size,
67 }
68 }
69
70 /// Forward pass: output = input @ weight + bias
71 pub fn forward(&self, input: &Tensor) -> Tensor {
72 // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
73 let output = input.matmul(&self.weight);
74 // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
75 output.add_tensor(&self.bias)
76 }
77
78 /// Forward pass without gradients (for inference)
79 #[allow(unused)]
80 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
81 let _guard = NoGradTrack::new();
82 self.forward(input)
83 }
84
85 /// Get all parameters for optimization
86 pub fn parameters(&mut self) -> Vec<&mut Tensor> {
87 vec![&mut self.weight, &mut self.bias]
88 }
89
90 /// Save layer parameters to JSON
91 #[allow(unused)]
92 pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
93 // Create directory if it doesn't exist
94 if let Some(parent) = std::path::Path::new(path).parent() {
95 fs::create_dir_all(parent)?;
96 }
97
98 let weight_path = format!("{}_weight.json", path);
99 let bias_path = format!("{}_bias.json", path);
100
101 self.weight.save_json(&weight_path)?;
102 self.bias.save_json(&bias_path)?;
103
104 println!("Saved linear layer to {} (weight and bias)", path);
105 Ok(())
106 }
107
108 /// Load layer parameters from JSON
109 #[allow(unused)]
110 pub fn load_json(
111 path: &str,
112 input_size: usize,
113 output_size: usize,
114 ) -> Result<Self, Box<dyn std::error::Error>> {
115 let weight_path = format!("{}_weight.json", path);
116 let bias_path = format!("{}_bias.json", path);
117
118 let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
119 let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
120
121 Ok(Self {
122 weight,
123 bias,
124 input_size,
125 output_size,
126 })
127}
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
200 pub fn load_json(
201 path: &str,
202 config: FeedForwardConfig,
203 ) -> Result<Self, Box<dyn std::error::Error>> {
204 let mut layers = Vec::new();
205 let mut current_size = config.input_size;
206 let mut layer_idx = 0;
207
208 // Load hidden layers
209 for &hidden_size in &config.hidden_sizes {
210 let layer_path = format!("{}_layer_{}", path, layer_idx);
211 let weight_path = format!("{}_weight.json", layer_path);
212 let bias_path = format!("{}_bias.json", layer_path);
213
214 let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
215 let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
216
217 layers.push(LinearLayer {
218 weight,
219 bias,
220 input_size: current_size,
221 output_size: hidden_size,
222 });
223
224 current_size = hidden_size;
225 layer_idx += 1;
226 }
227
228 // Load output layer
229 let layer_path = format!("{}_layer_{}", path, layer_idx);
230 let weight_path = format!("{}_weight.json", layer_path);
231 let bias_path = format!("{}_bias.json", layer_path);
232
233 let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
234 let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
235
236 layers.push(LinearLayer {
237 weight,
238 bias,
239 input_size: current_size,
240 output_size: config.output_size,
241 });
242
243 Ok(Self { layers, config })
244}
47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}
103
104/// Demonstrate simple linear regression training
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106 println!("\n--- Linear Regression Training ---");
107
108 // Create model parameters
109 let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112 // Create optimizer
113 let mut optimizer = Adam::with_learning_rate(0.01);
114 optimizer.add_parameter(&weight);
115 optimizer.add_parameter(&bias);
116
117 // Create simple training data: y = 2*x + 1
118 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121 println!("Training data:");
122 println!(" X: {:?}", x_data.data());
123 println!(" Y: {:?}", y_true.data());
124 println!(" Target: y = 2*x + 1");
125
126 // Training loop
127 let num_epochs = 100;
128 let mut losses = Vec::new();
129
130 for epoch in 0..num_epochs {
131 // Forward pass: y_pred = x * weight + bias
132 let y_pred = x_data.matmul(&weight) + &bias;
133
134 // Compute loss: MSE
135 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137 // Backward pass
138 loss.backward(None);
139
140 // Optimizer step
141 optimizer.step(&mut [&mut weight, &mut bias]);
142 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144 losses.push(loss.value());
145
146 // Print progress every 20 epochs
147 if epoch % 20 == 0 || epoch == num_epochs - 1 {
148 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149 }
150 }
151
152 // Evaluate final model
153 let final_predictions = x_data.matmul(&weight) + &bias;
154 println!("\nFinal model evaluation:");
155 println!(" Learned weight: {:.6}", weight.value());
156 println!(" Learned bias: {:.6}", bias.value());
157 println!(" Predictions vs True:");
158
159 for i in 0..5 {
160 let x1 = x_data.data()[i];
161 let pred = final_predictions.data()[i];
162 let true_val = y_true.data()[i];
163 println!(
164 " x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165 x1,
166 pred,
167 true_val,
168 (pred - true_val).abs()
169 );
170 }
171
172 Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177 println!("\n--- Advanced Training Patterns ---");
178
179 // Create a more complex model
180 let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181 let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183 // Create optimizer with different learning rate
184 let mut optimizer = Adam::with_learning_rate(0.005);
185 optimizer.add_parameter(&weight);
186 optimizer.add_parameter(&bias);
187
188 // Create training data: y = 2*x + [1, 3]
189 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190 let y_true = Tensor::from_slice(
191 &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192 vec![5, 2],
193 )
194 .unwrap();
195
196 println!("Advanced training with monitoring:");
197 println!(" Initial learning rate: {}", optimizer.learning_rate());
198
199 // Training loop with monitoring
200 let num_epochs = 50;
201 let mut losses = Vec::new();
202 let mut weight_norms = Vec::new();
203 let mut gradient_norms = Vec::new();
204
205 for epoch in 0..num_epochs {
206 // Forward pass
207 let y_pred = x_data.matmul(&weight) + &bias;
208 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210 // Backward pass
211 loss.backward(None);
212
213 // Compute gradient norm before optimizer step
214 let gradient_norm = weight.grad_owned().unwrap().norm();
215
216 // Optimizer step
217 optimizer.step(&mut [&mut weight, &mut bias]);
218 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220 // Learning rate scheduling: reduce every 10 epochs
221 if epoch > 0 && epoch % 10 == 0 {
222 let current_lr = optimizer.learning_rate();
223 let new_lr = current_lr * 0.5;
224 optimizer.set_learning_rate(new_lr);
225 println!(
226 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227 epoch, current_lr, new_lr
228 );
229 }
230
231 // Record metrics
232 losses.push(loss.value());
233 weight_norms.push(weight.norm().value());
234 gradient_norms.push(gradient_norm.value());
235
236 // Print detailed progress
237 if epoch % 10 == 0 || epoch == num_epochs - 1 {
238 println!(
239 "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240 epoch,
241 loss.value(),
242 weight.norm().value(),
243 gradient_norm.value()
244 );
245 }
246 }
247
248 println!("Final learning rate: {}", optimizer.learning_rate());
249
250 // Analyze training progression
251 let initial_loss = losses[0];
252 let final_loss = losses[losses.len() - 1];
253 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255 println!("\nTraining Analysis:");
256 println!(" Initial loss: {:.6}", initial_loss);
257 println!(" Final loss: {:.6}", final_loss);
258 println!(" Loss reduction: {:.1}%", loss_reduction);
259 println!(" Final weight norm: {:.6}", weight.norm().value());
260 println!(" Final bias: {:?}", bias.data());
261
262 Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267 println!("\n--- Learning Rate Scheduling ---");
268
269 // Create simple model
270 let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273 // Create optimizer with high initial learning rate
274 let mut optimizer = Adam::with_learning_rate(0.1);
275 optimizer.add_parameter(&weight);
276 optimizer.add_parameter(&bias);
277
278 // Simple data
279 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280 let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282 println!("Initial learning rate: {}", optimizer.learning_rate());
283
284 // Training loop with learning rate scheduling
285 let num_epochs = 50;
286 let mut losses = Vec::new();
287
288 for epoch in 0..num_epochs {
289 // Forward pass
290 let y_pred = x_data.matmul(&weight) + &bias;
291 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293 // Backward pass
294 loss.backward(None);
295
296 // Optimizer step
297 optimizer.step(&mut [&mut weight, &mut bias]);
298 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300 // Learning rate scheduling: reduce every 10 epochs
301 if epoch > 0 && epoch % 10 == 0 {
302 let current_lr = optimizer.learning_rate();
303 let new_lr = current_lr * 0.5;
304 optimizer.set_learning_rate(new_lr);
305 println!(
306 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307 epoch, current_lr, new_lr
308 );
309 }
310
311 losses.push(loss.value());
312
313 // Print progress
314 if epoch % 10 == 0 || epoch == num_epochs - 1 {
315 println!(
316 "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317 epoch,
318 loss.value(),
319 optimizer.learning_rate()
320 );
321 }
322 }
323
324 println!("Final learning rate: {}", optimizer.learning_rate());
325
326 Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331 println!("\n--- Training Monitoring ---");
332
333 // Create model
334 let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337 // Create optimizer
338 let mut optimizer = Adam::with_learning_rate(0.01);
339 optimizer.add_parameter(&weight);
340 optimizer.add_parameter(&bias);
341
342 // Training data
343 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346 // Training loop with comprehensive monitoring
347 let num_epochs = 30;
348 let mut losses = Vec::new();
349 let mut weight_history = Vec::new();
350 let mut bias_history = Vec::new();
351
352 for epoch in 0..num_epochs {
353 // Forward pass
354 let y_pred = x_data.matmul(&weight) + &bias;
355 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357 // Backward pass
358 loss.backward(None);
359
360 // Optimizer step
361 optimizer.step(&mut [&mut weight, &mut bias]);
362 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364 // Record history
365 losses.push(loss.value());
366 weight_history.push(weight.value());
367 bias_history.push(bias.value());
368
369 // Print detailed monitoring
370 if epoch % 5 == 0 || epoch == num_epochs - 1 {
371 println!(
372 "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373 epoch,
374 loss.value(),
375 weight.value(),
376 bias.value()
377 );
378 }
379 }
380
381 // Analyze training progression
382 println!("\nTraining Analysis:");
383 println!(" Initial loss: {:.6}", losses[0]);
384 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
385 println!(
386 " Loss reduction: {:.1}%",
387 (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388 );
389
390 // Compute statistics
391 let loss_mean = compute_mean(&losses);
392 let loss_std = compute_std(&losses);
393 let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394 let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396 println!(" Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397 println!(" Weight change: {:.6}", weight_change);
398 println!(" Bias change: {:.6}", bias_change);
399 println!(" Final weight norm: {:.6}", weight.norm().value());
400 println!(" Final bias: {:.6}", bias.value());
401
402 Ok(())
403}
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 // Create an optimizer with some parameters
113 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 // Simulate some training steps
136 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 // Save optimizer state
144 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 // Load optimizer state
149 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 // Verify optimizer state
157 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169 println!("\n--- Format Comparison ---");
170
171 // Create a larger tensor for comparison
172 let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174 // Save in both formats
175 tensor.save_json("temp_comparison.json")?;
176 tensor.save_binary("temp_comparison.bin")?;
177
178 // Compare file sizes
179 let json_size = fs::metadata("temp_comparison.json")?.len();
180 let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182 println!("JSON file size: {} bytes", json_size);
183 println!("Binary file size: {} bytes", binary_size);
184 println!(
185 "Compression ratio: {:.2}x",
186 json_size as f64 / binary_size as f64
187 );
188
189 // Load and verify both formats
190 let json_tensor = Tensor::load_json("temp_comparison.json")?;
191 let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193 assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194 assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195 assert_eq!(tensor.data(), json_tensor.data());
196 assert_eq!(tensor.data(), binary_tensor.data());
197
198 println!("Format comparison verification: PASSED");
199
200 Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205 println!("\n--- Model Checkpointing ---");
206
207 // Create a simple model (weights and bias)
208 let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209 let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211 // Create optimizer
212 let mut optimizer = Adam::with_learning_rate(0.01);
213 optimizer.add_parameter(&weights);
214 optimizer.add_parameter(&bias);
215
216 println!("Initial weights: {:?}", weights.data());
217 println!("Initial bias: {:?}", bias.data());
218
219 // Simulate training
220 for epoch in 0..5 {
221 let mut loss = weights.sum() + bias.sum();
222 loss.backward(None);
223 optimizer.step(&mut [&mut weights, &mut bias]);
224 optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226 if epoch % 2 == 0 {
227 // Save checkpoint
228 let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229 fs::create_dir_all(&checkpoint_dir)?;
230
231 weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232 bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233 optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235 println!("Saved checkpoint for epoch {}", epoch);
236 }
237 }
238
239 // Load from checkpoint
240 let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241 let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242 let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244 println!("Loaded weights: {:?}", loaded_weights.data());
245 println!("Loaded bias: {:?}", loaded_bias.data());
246 println!(
247 "Loaded optimizer learning rate: {}",
248 loaded_optimizer.learning_rate()
249 );
250
251 // Verify checkpoint integrity
252 assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253 assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256 println!("Checkpointing verification: PASSED");
257
258 Ok(())
258}
pub fn set_requires_grad(&mut self, requires_grad: bool)
Set gradient tracking for this tensor
Controls whether the gradtrack system tracks operations on this tensor and computes gradients during backward pass. When disabled, clears any existing gradients and gradient functions.
§Arguments
requires_grad - Whether to track gradients for this tensor
§Performance
- Time Complexity: O(1) - simple field assignment
- Memory: May free gradient storage when disabled
- Overhead: Zero overhead when gradients disabled
§Examples
use train_station::Tensor;
let mut tensor = Tensor::ones(vec![2, 3]);
tensor.set_requires_grad(true);
assert!(tensor.requires_grad());
// Disable gradient tracking
tensor.set_requires_grad(false);
assert!(!tensor.requires_grad());
Examples found in repository
100 fn set_requires_grad_all(&mut self, enable: bool) {
101 for l in &mut self.layers {
102 l.weight.set_requires_grad(enable);
103 l.bias.set_requires_grad(enable);
104 }
105 }
106
107 // In-place copy (preserve tensor IDs and optimizer links on targets)
108 fn copy_from(&mut self, other: &Self) {
109 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110 {
111 let src = s.weight.data();
112 let dst = t.weight.data_mut();
113 dst.copy_from_slice(src);
114 }
115 {
116 let src = s.bias.data();
117 let dst = t.bias.data_mut();
118 dst.copy_from_slice(src);
119 }
120 t.weight.set_requires_grad(false);
121 t.bias.set_requires_grad(false);
122 }
123    }
More examples
107 fn set_requires_grad_all(&mut self, enable: bool) {
108 for l in &mut self.layers {
109 l.weight.set_requires_grad(enable);
110 l.bias.set_requires_grad(enable);
111 }
112 }
113
114 fn copy_from(&mut self, other: &Self) {
115 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116 {
117 let src = s.weight.data();
118 let dst = t.weight.data_mut();
119 dst.copy_from_slice(src);
120 }
121 {
122 let src = s.bias.data();
123 let dst = t.bias.data_mut();
124 dst.copy_from_slice(src);
125 }
126 t.weight.set_requires_grad(false);
127 t.bias.set_requires_grad(false);
128 }
129 }
130
131 fn soft_update_from(&mut self, source: &Self, tau: f32) {
132 let _ng = NoGradTrack::new();
133 for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134 // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135 let new_w = t
136 .weight
137 .mul_scalar(1.0 - tau)
138 .add_tensor(&s.weight.mul_scalar(tau));
139 let new_b = t
140 .bias
141 .mul_scalar(1.0 - tau)
142 .add_tensor(&s.bias.mul_scalar(tau));
143 {
144 let src = new_w.data();
145 let dst = t.weight.data_mut();
146 dst.copy_from_slice(src);
147 }
148 {
149 let src = new_b.data();
150 let dst = t.bias.data_mut();
151 dst.copy_from_slice(src);
152 }
153 t.weight.set_requires_grad(false);
154 t.bias.set_requires_grad(false);
155 }
156    }
pub fn retain_grad(self) -> Self
Mark this tensor to retain gradients after backward, even if it is non-leaf.
Builder-style API: returns self with retain_grad=true.
Call materialize_grad() or grad_or_fetch() after backward to copy the
accumulated gradient from the GradGraph into self.grad so grad() works.
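A minimal sketch of the builder-style flow, assuming the mul_scalar, sum, backward, and materialize_grad APIs documented elsewhere on this page:
use train_station::Tensor;
let x = Tensor::ones(vec![2, 2]).with_requires_grad();
// retain_grad() marks the non-leaf result so its gradient can be cached after backward
let mut y = x.mul_scalar(2.0).retain_grad();
let mut loss = y.sum();
loss.backward(None);
y.materialize_grad(); // copy the accumulated gradient into y.grad
assert!(y.grad().is_some());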
pub fn retain_grad_(&mut self, enable: bool)
In-place variant to enable or disable gradient retention for non-leaf tensors
pub fn requires_grad(&self) -> bool
Check if this tensor requires gradients
§Returns
true if gradient tracking is enabled for this tensor
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3]);
assert!(!tensor.requires_grad());
let grad_tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(grad_tensor.requires_grad());
Examples found in repository
153fn demonstrate_layer_creation() {
154 println!("--- Layer Creation ---");
155
156 let layer = LinearLayer::new(3, 2, Some(42));
157
158 println!("Created linear layer:");
159 println!(" Input size: {}", layer.input_size);
160 println!(" Output size: {}", layer.output_size);
161 println!(" Parameter count: {}", layer.parameter_count());
162 println!(" Weight shape: {:?}", layer.weight.shape().dims());
163 println!(" Bias shape: {:?}", layer.bias.shape().dims());
164 println!(" Weight requires grad: {}", layer.weight.requires_grad());
165 println!(" Bias requires grad: {}", layer.bias.requires_grad());
166}
167
168/// Demonstrate forward pass with gradient tracking
169fn demonstrate_forward_pass() {
170 println!("\n--- Forward Pass (with gradients) ---");
171
172 let layer = LinearLayer::new(3, 2, Some(43));
173
174 // Single input
175 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176 let output = layer.forward(&input);
177
178 println!("Single input:");
179 println!(" Input: {:?}", input.data());
180 println!(" Output: {:?}", output.data());
181 println!(" Output requires grad: {}", output.requires_grad());
182
183 // Batch input
184 let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185 let batch_output = layer.forward(&batch_input);
186
187 println!("Batch input:");
188 println!(" Input shape: {:?}", batch_input.shape().dims());
189 println!(" Output shape: {:?}", batch_output.shape().dims());
190 println!(" Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195 println!("\n--- Forward Pass (no gradients) ---");
196
197 let layer = LinearLayer::new(3, 2, Some(44));
198
199 // Single input
200 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201 let output = layer.forward_no_grad(&input);
202
203 println!("Single input (no grad):");
204 println!(" Input: {:?}", input.data());
205 println!(" Output: {:?}", output.data());
206 println!(" Output requires grad: {}", output.requires_grad());
207
208 // Compare with grad version
209 let output_with_grad = layer.forward(&input);
210 println!("Comparison:");
211 println!(
212 " Same values: {}",
213 output.data() == output_with_grad.data()
214 );
215 println!(" No grad requires grad: {}", output.requires_grad());
216 println!(
217 " With grad requires grad: {}",
218 output_with_grad.requires_grad()
219 );
220}More examples
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}313fn demonstrate_forward_pass() {
314 println!("\n--- Forward Pass ---");
315
316 let config = FeedForwardConfig {
317 input_size: 3,
318 hidden_sizes: vec![5, 3],
319 output_size: 2,
320 use_bias: true,
321 };
322 let network = FeedForwardNetwork::new(config, Some(43));
323
324 // Single input
325 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
326 let output = network.forward(&input);
327
328 println!("Single input forward pass:");
329 println!(" Input shape: {:?}", input.shape().dims());
330 println!(" Output shape: {:?}", output.shape().dims());
331 println!(" Output: {:?}", output.data());
332 println!(" Output requires grad: {}", output.requires_grad());
333
334 // Batch input
335 let batch_input = Tensor::from_slice(
336 &[
337 1.0, 2.0, 3.0, // Sample 1
338 4.0, 5.0, 6.0, // Sample 2
339 7.0, 8.0, 9.0, // Sample 3
340 ],
341 vec![3, 3],
342 )
343 .unwrap();
344 let batch_output = network.forward(&batch_input);
345
346 println!("Batch input forward pass:");
347 println!(" Input shape: {:?}", batch_input.shape().dims());
348 println!(" Output shape: {:?}", batch_output.shape().dims());
349 println!(" Output requires grad: {}", batch_output.requires_grad());
350
351 // Compare with no-grad version
352 let output_no_grad = network.forward_no_grad(&input);
353 println!("No-grad comparison:");
354 println!(" Same values: {}", output.data() == output_no_grad.data());
355 println!(" With grad requires grad: {}", output.requires_grad());
356 println!(
357 " No grad requires grad: {}",
358 output_no_grad.requires_grad()
359 );
360}319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320 println!("\n--- Optimization Techniques ---");
321
322 let size = 50000;
323 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324 let tensor = Tensor::from_slice(&data, vec![size])?;
325
326 println!("Optimizing processing for size: {}", size);
327
328 // Technique 1: Operation fusion
329 println!("\nTechnique 1: Operation Fusion");
330 let start = Instant::now();
331 let fused_result: Tensor = tensor
332 .iter()
333 .map(|elem| {
334 // Fuse multiple operations into single chain
335 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336 })
337 .collect();
338 let fused_time = start.elapsed();
339
340 // Technique 2: Conditional optimization
341 println!("\nTechnique 2: Conditional Optimization");
342 let start = Instant::now();
343 let conditional_result: Tensor = tensor
344 .iter()
345 .map(|elem| {
346 let val = elem.value();
347 if val < size as f32 / 2.0 {
348 elem.mul_scalar(2.0) // Simple operation for small values
349 } else {
350 elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351 }
352 })
353 .collect();
354 let conditional_time = start.elapsed();
355
356 // Technique 3: Cache-friendly processing
357 println!("\nTechnique 3: Cache-Friendly Processing");
358 let start = Instant::now();
359 let cache_friendly_result: Tensor = tensor
360 .iter()
361 .take(1000) // Process in cache-friendly chunks
362 .map(|elem| elem.mul_scalar(2.0))
363 .collect();
364 let cache_friendly_time = start.elapsed();
365
366 // Technique 4: Memory pooling simulation
367 println!("\nTechnique 4: Memory Pooling Simulation");
368 let start = Instant::now();
369 let pooled_result: Tensor = tensor
370 .iter()
371 .enumerate()
372 .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373 .map(|(_, elem)| elem.pow_scalar(2.0))
374 .collect();
375 let pooled_time = start.elapsed();
376
377 // Report optimization results
378 println!(" Fused operations: {:?}", fused_time);
379 println!(" Conditional optimization: {:?}", conditional_time);
380 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
381 println!(" Memory pooling simulation: {:?}", pooled_time);
382
383 // Performance analysis
384 let fastest = fused_time
385 .min(conditional_time)
386 .min(cache_friendly_time)
387 .min(pooled_time);
388 println!(" Fastest technique: {:?}", fastest);
389
390 // Memory efficiency analysis
391 println!(" Fused result size: {}", fused_result.size());
392 println!(" Conditional result size: {}", conditional_result.size());
393 println!(
394 " Cache-friendly result size: {}",
395 cache_friendly_result.size()
396 );
397 println!(" Pooled result size: {}", pooled_result.size());
398
399 // Technique 5: Gradient optimization
400 println!("\nTechnique 5: Gradient Optimization");
401 let grad_tensor = tensor.with_requires_grad();
402 let start = Instant::now();
403
404 let grad_result: Tensor = grad_tensor
405 .iter()
406 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407 .collect();
408
409 let mut loss = grad_result.sum();
410 loss.backward(None);
411 let grad_time = start.elapsed();
412
413 println!(" Gradient computation: {:?}", grad_time);
414 println!(
415 " Gradient tracking enabled: {}",
416 grad_result.requires_grad()
417 );
418
419 Ok(())
420}
pub fn grad(&self) -> Option<&Tensor>
Get a reference to this tensor’s locally cached gradient (if any)
This accessor returns only the gradient cached on this tensor’s grad field.
It does NOT query the global/autograd gradient storage. For leaf tensors,
gradients are accumulated and tracked by the grad engine and are not
automatically written back to the local grad field.
To make grad() return Some(&Tensor):
- Enable retention on this tensor (typically for non-leaf tensors) with retain_grad_(&mut tensor, true) or tensor.retain_grad()
- After backward(), call tensor.materialize_grad() or tensor.grad_or_fetch() to copy the accumulated gradient from the autograd engine into self.grad
If you want to read gradients without caching them locally, prefer
grad_owned, which consults the global gradient
storage.
§Returns
Optional reference to the locally cached gradient tensor, or None if
not materialized on this tensor.
§Examples
use train_station::Tensor;
let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);
// Without materialization, grad() typically returns None for leaves
assert!(x.grad().is_none());
// Materialize to cache locally so grad() works
x.retain_grad_(true);
x.materialize_grad();
assert!(x.grad().is_some());
Examples found in repository
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
pub fn materialize_grad(&mut self) -> bool
Fetch the accumulated gradient after backward and cache it on this tensor if retain_grad is enabled.
Returns true if a gradient was found and cached. After a successful call,
grad() will return Some(&Tensor) even for non-leaf tensors.
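A short sketch of the expected flow, assuming retention is enabled as in the grad() example above:
use train_station::Tensor;
let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
x.retain_grad_(true);
let mut loss = x.sum();
loss.backward(None);
let cached = x.materialize_grad(); // true when a gradient was found and cached locally
assert!(cached);
assert!(x.grad().is_some());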
pub fn grad_or_fetch(&mut self) -> Option<&Tensor>
Convenience accessor: if retain_grad is enabled, fetch and cache the gradient
on first access so callers can immediately get a reference.
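A hedged sketch combining fetch and local caching in one call, assuming the same retention setup as in the examples above:
use train_station::Tensor;
let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
x.retain_grad_(true);
let mut loss = x.sum();
loss.backward(None);
// First access fetches from the autograd engine and caches the gradient locally
let g = x.grad_or_fetch().expect("gradient should be available after backward");
assert_eq!(g.shape().dims(), vec![2, 3]);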
pub fn grad_owned(&self) -> Option<Tensor>
Get the accumulated gradient from the autograd engine as an owned tensor
This accessor queries the global/shared gradient storage for this tensor’s
ID and returns the current accumulated gradient by value. It complements
grad, which only returns a locally cached reference.
- Works for leaf tensors without any prior materialization step
- Does not modify or clear internal gradient state
- Suitable for optimizers and logging that need the latest gradients
§Returns
Some(Tensor) containing the current accumulated gradient when available,
otherwise None.
§Examples
use train_station::Tensor;
let mut x = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut loss = x.sum();
loss.backward(None);
// Fetch directly from autograd storage (no materialization required)
let g = x.grad_owned().unwrap();
assert_eq!(g.shape().dims(), vec![2, 3]);
Examples found in repository
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278 let mut total_sq = 0.0f32;
279 for p in parameters.iter() {
280 if let Some(g) = p.grad_owned() {
281 for &v in g.data() {
282 total_sq += v * v;
283 }
284 }
285 }
286 let norm = total_sq.sqrt();
287 if norm > max_norm {
288 let scale = max_norm / (norm + eps);
289 for p in parameters.iter_mut() {
290 if let Some(g) = p.grad_owned() {
291 p.set_grad(g.mul_scalar(scale));
292 }
293 }
294 }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298 let mut total_sq = 0.0f32;
299 for p in parameters.iter_mut() {
300 if let Some(g) = p.grad_owned() {
301 for &v in g.data() {
302 total_sq += v * v;
303 }
304 }
305 }
306 total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310 let _ng = NoGradTrack::new();
311 let mut total_sq = 0.0f32;
312 for p in parameters.iter_mut() {
313 for &v in p.data() {
314 total_sq += v * v;
315 }
316 }
317 total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322 diff.pow_scalar(2.0)
323 .add_scalar(1.0)
324 .sqrt()
325 .sub_scalar(1.0)
326 .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(¬_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
More examples
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44 // pred: [B,1] with sigmoid; threshold at 0.5
45 let p = pred.data();
46 let t = targets.data();
47 let mut correct = 0usize;
48 for i in 0..p.len() {
49 let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50 if (yhat - t[i]).abs() < 1e-6 {
51 correct += 1;
52 }
53 }
54 correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60 let relu_z = logits.relu();
61 let zy = logits.mul_tensor(targets);
62 // |z| = relu(z) + relu(-z)
63 let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64 let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65 relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
67
68pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69 println!("=== Supervised FFN Example (XOR) ===");
70
71 // Dataset: XOR (repeat to form a small batch)
72 let inputs: Vec<f32> = vec![
73 0.0, 0.0, // -> 0
74 0.0, 1.0, // -> 1
75 1.0, 0.0, // -> 1
76 1.0, 1.0, // -> 0
77 ];
78 let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80 // Repeat the base patterns to stabilize training
81 let repeats = 64usize; // effective batch = 4 * repeats = 256
82 let mut xs = Vec::with_capacity(repeats * inputs.len());
83 let mut ys = Vec::with_capacity(repeats * targets.len());
84 for _ in 0..repeats {
85 xs.extend_from_slice(&inputs);
86 ys.extend_from_slice(&targets);
87 }
88
89 let batch = xs.len() / 2; // two features
90 let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91 let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93 // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94 let cfg = FeedForwardConfig {
95 input_size: 2,
96 hidden_sizes: vec![32, 32],
97 output_size: 1,
98 use_bias: true,
99 };
100 let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102 // Optimizer and parameter linking
103 let mut opt = Adam::with_learning_rate(1e-3);
104 for p in net.parameters() {
105 opt.add_parameter(p);
106 }
107
108 let epochs = 1000usize;
109 let max_grad_norm = 1.0f32;
110 let mut best_loss = f32::INFINITY;
111 let mut best_acc = 0.0f32;
112
113 for e in 0..epochs {
114 // Zero grads each iteration
115 {
116 let mut params = net.parameters();
117 opt.zero_grad(&mut params);
118 }
119
120 // Forward -> logits; use numerically stable BCE-with-logits for loss
121 let logits = net.forward(&x_t);
122 let mut loss = bce_with_logits(&logits, &y_t);
123 loss.backward(None);
124
125 // Step only params with grads
126 {
127 let params = net.parameters();
128 let mut with_grads: Vec<&mut Tensor> = Vec::new();
129 for p in params {
130 if p.grad_owned().is_some() {
131 with_grads.push(p);
132 }
133 }
134 if !with_grads.is_empty() {
135 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136 opt.step(&mut with_grads);
137 opt.zero_grad(&mut with_grads);
138 }
139 }
140
141 // Metrics (use sigmoid only for reporting accuracy)
142 let preds = logits.sigmoid();
143 let acc = accuracy(&preds, &y_t);
144 if loss.value() < best_loss {
145 best_loss = loss.value();
146 }
147 if acc > best_acc {
148 best_acc = acc;
149 }
150 if e % 10 == 0 || e + 1 == epochs {
151 println!(
152 "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153 e,
154 loss.value(),
155 acc,
156 best_loss,
157 best_acc
158 );
159 }
160
161 // Clear graphs to avoid stale accumulation across epochs
162 clear_all_graphs_known();
163 }
164
165 // Quick sanity check predictions
166 let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167 let out = net.forward(&test).sigmoid();
168 println!("predictions (approx): {:?}", out.data());
169
170 println!("=== Supervised training finished ===");
171 Ok(())
172}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
60
61fn accuracy_from_logits(
62 logits: &Tensor,
63 labels: &[usize],
64 batch: usize,
65 num_classes: usize,
66) -> f32 {
67 let row = logits.data();
68 let mut correct = 0usize;
69 for (i, &label) in labels.iter().enumerate().take(batch) {
70 let base = i * num_classes;
71 let mut best_j = 0usize;
72 let mut best_v = row[base];
73 for j in 1..num_classes {
74 let v = row[base + j];
75 if v > best_v {
76 best_v = v;
77 best_j = j;
78 }
79 }
80 if best_j == label {
81 correct += 1;
82 }
83 }
84 correct as f32 / batch as f32
85}
86
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88 println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90 // Synthetic 2D inputs, 3 classes with linear-ish separations
91 let n = 1200usize;
92 let classes = 3usize;
93 let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94 let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96 // Simple RNG
97 let mut state: u64 = 424242;
98 let mut rand_f32 = || {
99 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100 (state >> 16) as f32 / (u32::MAX as f32)
101 };
102
103 for _ in 0..n {
104 let x1 = rand_f32() * 4.0 - 2.0;
105 let x2 = rand_f32() * 4.0 - 2.0;
106 // Class by quadrant-ish rule with noise
107 let mut c = if x1 + 0.5 * x2 > 0.5 {
108 0
109 } else if x1 - x2 < -0.5 {
110 1
111 } else {
112 2
113 };
114 if rand_f32() < 0.05 {
115 c = (c + 1) % classes;
116 }
117 xs.push(x1);
118 xs.push(x2);
119 ys.push(c);
120 }
121
122 // Normalize inputs per-feature to [-1, 1]
123 let mut min1 = f32::INFINITY;
124 let mut max1 = f32::NEG_INFINITY;
125 let mut min2 = f32::INFINITY;
126 let mut max2 = f32::NEG_INFINITY;
127 for i in (0..xs.len()).step_by(2) {
128 let a = xs[i];
129 let b = xs[i + 1];
130 if a < min1 {
131 min1 = a;
132 }
133 if a > max1 {
134 max1 = a;
135 }
136 if b < min2 {
137 min2 = b;
138 }
139 if b > max2 {
140 max2 = b;
141 }
142 }
143 let rng1 = (max1 - min1).max(1e-8);
144 let rng2 = (max2 - min2).max(1e-8);
145 for i in (0..xs.len()).step_by(2) {
146 let a = xs[i];
147 let b = xs[i + 1];
148 xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149 xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150 }
151
152 // Train/Val split (80/20)
153 let n_train = (n as f32 * 0.8) as usize;
154 let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155 let y_train = ys[..n_train].to_vec();
156 let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157 let y_val = ys[n_train..].to_vec();
158
159 // Model: 2 -> 64 -> 64 -> 3 (logits)
160 let cfg = FeedForwardConfig {
161 input_size: 2,
162 hidden_sizes: vec![64, 64],
163 output_size: classes,
164 use_bias: true,
165 };
166 let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168 // Optimizer
169 let mut opt = Adam::with_learning_rate(1e-3);
170 for p in net.parameters() {
171 opt.add_parameter(p);
172 }
173
174 let epochs = 300usize;
175 let max_grad_norm = 1.0f32;
176 let mut best_val_acc = 0.0f32;
177 let mut best_val_loss = f32::INFINITY;
178
179 for e in 0..epochs {
180 // Zero grads
181 {
182 let mut params = net.parameters();
183 opt.zero_grad(&mut params);
184 }
185
186 // Forward logits
187 let logits = net.forward(&x_train);
188 let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189 loss.backward(None);
190
191 // Step clipped
192 {
193 let params = net.parameters();
194 let mut with_grads: Vec<&mut Tensor> = Vec::new();
195 for p in params {
196 if p.grad_owned().is_some() {
197 with_grads.push(p);
198 }
199 }
200 if !with_grads.is_empty() {
201 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202 opt.step(&mut with_grads);
203 opt.zero_grad(&mut with_grads);
204 }
205 }
206
207 // Metrics
208 let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209 let val_logits = net.forward(&x_val);
210 let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211 let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212 if val_acc > best_val_acc {
213 best_val_acc = val_acc;
214 }
215 if val_loss < best_val_loss {
216 best_val_loss = val_loss;
217 }
218
219 if e % 10 == 0 || e + 1 == epochs {
220 println!(
221 "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222 e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223 );
224 }
225
226 clear_all_graphs_known();
227 }
228
229 // Quick sample preds via softmax
230 let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231 let sm = net.forward(&samples).softmax(1);
232 println!("sample class probs: {:?}", sm.data());
233
234 println!("=== Supervised classification finished ===");
235 Ok(())
236}
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23 let mut total_sq = 0.0f32;
24 for p in parameters.iter() {
25 if let Some(g) = p.grad_owned() {
26 for &v in g.data() {
27 total_sq += v * v;
28 }
29 }
30 }
31 let norm = total_sq.sqrt();
32 if norm > max_norm {
33 let scale = max_norm / (norm + eps);
34 for p in parameters.iter_mut() {
35 if let Some(g) = p.grad_owned() {
36 p.set_grad(g.mul_scalar(scale));
37 }
38 }
39 }
40}
41
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43 pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47 mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51 // R^2 = 1 - SS_res / SS_tot
52 let y = target;
53 let y_mean = y.mean();
54 let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55 let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56 let ss_res_v = ss_res.value();
57 let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58 1.0 - (ss_res_v / ss_tot_v)
59}
60
61pub fn main() -> Result<(), Box<dyn std::error::Error>> {
62 println!("=== Supervised Regression Example (MSE) ===");
63
64 // Generate simple synthetic data: y = 2*x1 - 3*x2 + 0.5 + noise
65 let n = 1024usize;
66 let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
67 let mut ys: Vec<f32> = Vec::with_capacity(n);
68 // Simple LCG RNG for reproducibility
69 let mut state: u64 = 123456789;
70 let mut rand_f32 = || {
71 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
72 (state >> 16) as f32 / (u32::MAX as f32)
73 };
74 for _ in 0..n {
75 let x1 = rand_f32() * 2.0 - 1.0;
76 let x2 = rand_f32() * 2.0 - 1.0;
77 let noise = (rand_f32() * 2.0 - 1.0) * 0.05;
78 let y = 2.0 * x1 - 3.0 * x2 + 0.5 + noise;
79 xs.push(x1);
80 xs.push(x2);
81 ys.push(y);
82 }
83
84 // Normalize targets to [-1, 1] (max-abs scaling) for reasonable loss magnitudes
85 let mut max_abs = 0.0f32;
86 for &v in &ys {
87 let a = v.abs();
88 if a > max_abs {
89 max_abs = a;
90 }
91 }
92 if max_abs < 1e-8 {
93 max_abs = 1.0;
94 }
95 for v in ys.iter_mut() {
96 *v /= max_abs;
97 }
98
99 // Normalize inputs per-feature to [-1, 1] (min-max scaling)
100 let mut min1 = f32::INFINITY;
101 let mut max1 = f32::NEG_INFINITY;
102 let mut min2 = f32::INFINITY;
103 let mut max2 = f32::NEG_INFINITY;
104 for i in (0..xs.len()).step_by(2) {
105 let a = xs[i];
106 let b = xs[i + 1];
107 if a < min1 {
108 min1 = a;
109 }
110 if a > max1 {
111 max1 = a;
112 }
113 if b < min2 {
114 min2 = b;
115 }
116 if b > max2 {
117 max2 = b;
118 }
119 }
120 let rng1 = (max1 - min1).max(1e-8);
121 let rng2 = (max2 - min2).max(1e-8);
122 for i in (0..xs.len()).step_by(2) {
123 let a = xs[i];
124 let b = xs[i + 1];
125 xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
126 xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
127 }
128
129 // Train/Val split (80/20)
130 let n_train = (n as f32 * 0.8) as usize;
131 let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
132 let y_train = Tensor::from_slice(&ys[..n_train], vec![n_train, 1]).unwrap();
133 let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
134 let y_val = Tensor::from_slice(&ys[n_train..], vec![n - n_train, 1]).unwrap();
135
136 // Model config: 2 -> 64 -> 64 -> 1
137 let cfg = FeedForwardConfig {
138 input_size: 2,
139 hidden_sizes: vec![64, 64],
140 output_size: 1,
141 use_bias: true,
142 };
143 let mut net = FeedForwardNetwork::new(cfg, Some(2025));
144
145 // Optimizer and parameter linking
146 let mut opt = Adam::with_learning_rate(1e-3);
147 for p in net.parameters() {
148 opt.add_parameter(p);
149 }
150
151 let epochs = 400usize;
152 let max_grad_norm = 1.0f32;
153 let mut best_val_rmse = f32::INFINITY;
154 let mut best_val_r2 = -f32::INFINITY;
155
156 for e in 0..epochs {
157 // Zero grads
158 {
159 let mut params = net.parameters();
160 opt.zero_grad(&mut params);
161 }
162
163 // Forward
164 let pred = net.forward(&x_train);
165 let mut loss = mse(&pred, &y_train);
166 loss.backward(None);
167
168 // Step
169 {
170 let params = net.parameters();
171 let mut with_grads: Vec<&mut Tensor> = Vec::new();
172 for p in params {
173 if p.grad_owned().is_some() {
174 with_grads.push(p);
175 }
176 }
177 if !with_grads.is_empty() {
178 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
179 opt.step(&mut with_grads);
180 opt.zero_grad(&mut with_grads);
181 }
182 }
183
184 // Metrics
185 let train_rmse = rmse(&pred, &y_train);
186 let train_r2 = r2_score(&pred, &y_train);
187 let val_pred = net.forward(&x_val);
188 let val_rmse = rmse(&val_pred, &y_val);
189 let val_r2 = r2_score(&val_pred, &y_val);
190 if val_rmse < best_val_rmse {
191 best_val_rmse = val_rmse;
192 }
193 if val_r2 > best_val_r2 {
194 best_val_r2 = val_r2;
195 }
196
197 if e % 20 == 0 || e + 1 == epochs {
198 // Clamp displayed R^2 to avoid huge negative prints on early epochs
199 let train_r2_disp = train_r2.max(-10.0);
200 let val_r2_disp = val_r2.max(-10.0);
201 println!(
202 "epoch {:4} | train_rmse={:.4} r2={:.3} | val_rmse={:.4} r2={:.3} | best_val_rmse={:.4} best_val_r2={:.3}",
203 e, train_rmse, train_r2_disp, val_rmse, val_r2_disp, best_val_rmse, best_val_r2
204 );
205 }
206
207 clear_all_graphs_known();
208 }
209
210 // Quick sanity predictions on small samples
211 let sample = Tensor::from_slice(&[0.5, -0.25, -0.8, 0.3], vec![2, 2]).unwrap();
212 let sample_pred = net.forward(&sample);
213 println!("samples pred: {:?}", sample_pred.data());
214
215 println!("=== Supervised regression finished ===");
216 Ok(())
217}
Sourcepub fn id(&self) -> usize
pub fn id(&self) -> usize
Get the unique ID of this tensor
Returns the unique identifier assigned to this tensor during creation. This ID is used by the gradtrack system for gradient tracking and for tensor identification.
§Returns
Unique tensor ID as usize
§Examples
use train_station::Tensor;
let tensor1 = Tensor::new(vec![2, 3]);
let tensor2 = Tensor::new(vec![2, 3]);
assert_ne!(tensor1.id(), tensor2.id()); // Each tensor has unique ID
Sourcepub fn detach(&self) -> Self
pub fn detach(&self) -> Self
Detach this tensor from the computation graph
Returns a new tensor with the same data but no gradient tracking. This is useful when you want to use a tensor for inference without affecting the computation graph.
§Returns
A new tensor with the same data but gradient tracking disabled
§Examples
use train_station::Tensor;
let tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
let detached = tensor.detach();
assert!(!detached.requires_grad());
assert_eq!(tensor.size(), detached.size());
Sourcepub fn detach_(&mut self)
pub fn detach_(&mut self)
Detach this tensor from the computation graph in place
Similar to detach() but modifies this tensor in place. This is useful when you want to disable gradient tracking for the current tensor without creating a copy.
§Examples
use train_station::Tensor;
let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
assert!(tensor.requires_grad());
tensor.detach_();
assert!(!tensor.requires_grad());
Sourcepub fn backward(&mut self, grad_output: Option<Tensor>)
pub fn backward(&mut self, grad_output: Option<Tensor>)
Entry point for backward pass on this tensor
Computes gradients for all tensors in the computation graph that have
requires_grad set to true. This is the main entry point for automatic
differentiation.
§Arguments
grad_output - Optional gradient tensor for the output. If None, assumes the tensor is a scalar (e.g., loss value) and uses a tensor of ones.
§Examples
use train_station::Tensor;
let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
let mut result = tensor.add_scalar(5.0);
result.backward(None);
// Note: Gradient computation depends on the gradtrack system implementation
Examples found in repository?
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91    for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
More examples
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102    for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
148    pub fn train_non_autoregressive_steps(
149 &mut self,
150 src: &Tensor,
151 tgt: &Tensor,
152 steps: usize,
153 lr: f32,
154 ) {
155 let mut opt = Adam::with_learning_rate(lr);
156 {
157 let params_once = self.parameters();
158            for p in &params_once {
159 opt.add_parameter(p);
160 }
161 }
162 for step in 0..steps {
163 // forward + backward scope (immutable borrow)
164 {
165 let pred = self.forward(src, tgt);
166 let diff = pred.sub_tensor(tgt);
167 let mut loss = diff.pow_scalar(2.0).mean();
168 if step == 0 || step + 1 == steps {
169 println!("NAR train step {}: loss={:.6}", step, loss.value());
170 }
171 loss.backward(None);
172 }
173 // step + zero_grad scope (mutable borrow)
174 let mut params_step = self.parameters();
175 opt.step(&mut params_step);
176 opt.zero_grad(&mut params_step);
177 }
178 }
179
180 /// Auto-regressive training (teacher forcing): predict next token with causal mask
181 pub fn train_autoregressive_steps(
182 &mut self,
183 src: &Tensor,
184 tgt: &Tensor,
185 steps: usize,
186 lr: f32,
187 ) {
188 let mut opt = Adam::with_learning_rate(lr);
189 {
190 let params_once = self.parameters();
191            for p in &params_once {
192 opt.add_parameter(p);
193 }
194 }
195
196 // Build encoder memory once (static dataset demo)
197 let mut memory = src.clone();
198 for enc in &self.encoders {
199 memory = enc.forward(&memory, None);
200 }
201
202 let (b, t, _e) = Self::triple(tgt);
203 // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204 let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205 for step in 0..steps {
206 // forward + backward scope
207 {
208 let mut out = tgt.clone();
209 for dec in &self.decoders {
210 out = dec.forward(&out, &memory, Some(&causal), None);
211 }
212 let diff = out.sub_tensor(tgt);
213 let mut loss = diff.pow_scalar(2.0).mean();
214 if step == 0 || step + 1 == steps {
215 println!("AR train step {}: loss={:.6}", step, loss.value());
216 }
217 loss.backward(None);
218 }
219 let mut params_step = self.parameters();
220 opt.step(&mut params_step);
221 opt.zero_grad(&mut params_step);
222 }
223 }
224
225 fn triple(t: &Tensor) -> (usize, usize, usize) {
226 let d = t.shape().dims();
227 (d[0], d[1], d[2])
228 }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232 println!("=== Basic Transformer Example ===");
233
234 let batch = 2usize;
235 let src_len = 8usize;
236 let tgt_len = 6usize;
237 let embed = 32usize;
238 let heads = 4usize;
239 let layers = 2usize;
240
241 let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242 let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244 let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245 let out = trf.forward(&src, &tgt);
246 println!("Output shape: {:?}", out.shape().dims());
247
248 // Quick optimization step
249 let mut opt = Adam::with_learning_rate(0.005);
250 let mut params = trf.parameters();
251    for p in &params {
252 opt.add_parameter(p);
253 }
254 let mut loss = out.mean();
255 loss.backward(None);
256 opt.step(&mut params);
257 opt.zero_grad(&mut params);
258 println!("Loss: {:.6}", loss.value());
259
260 // Demo: non auto-regressive inference (single pass)
261 let nar = trf.infer_non_autoregressive(&src, tgt_len);
262 println!("NAR output shape: {:?}", nar.shape().dims());
263
264 // Demo: auto-regressive inference (toy)
265 let ar = trf.infer_autoregressive(&src, 3);
266 println!("AR output shape: {:?}", ar.shape().dims());
267
268 // NAR training demo
269 let nar_tgt = tgt.clone();
270 trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272 // AR training demo (teacher-forced)
273 let ar_tgt = tgt.clone();
274 trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275 println!("=== Done ===");
276 Ok(())
277}
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 // Create an optimizer with some parameters
113 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 // Simulate some training steps
136 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 // Save optimizer state
144 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 // Load optimizer state
149 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 // Verify optimizer state
157 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169 println!("\n--- Format Comparison ---");
170
171 // Create a larger tensor for comparison
172 let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174 // Save in both formats
175 tensor.save_json("temp_comparison.json")?;
176 tensor.save_binary("temp_comparison.bin")?;
177
178 // Compare file sizes
179 let json_size = fs::metadata("temp_comparison.json")?.len();
180 let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182 println!("JSON file size: {} bytes", json_size);
183 println!("Binary file size: {} bytes", binary_size);
184 println!(
185 "Compression ratio: {:.2}x",
186 json_size as f64 / binary_size as f64
187 );
188
189 // Load and verify both formats
190 let json_tensor = Tensor::load_json("temp_comparison.json")?;
191 let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193 assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194 assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195 assert_eq!(tensor.data(), json_tensor.data());
196 assert_eq!(tensor.data(), binary_tensor.data());
197
198 println!("Format comparison verification: PASSED");
199
200 Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205 println!("\n--- Model Checkpointing ---");
206
207 // Create a simple model (weights and bias)
208 let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209 let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211 // Create optimizer
212 let mut optimizer = Adam::with_learning_rate(0.01);
213 optimizer.add_parameter(&weights);
214 optimizer.add_parameter(&bias);
215
216 println!("Initial weights: {:?}", weights.data());
217 println!("Initial bias: {:?}", bias.data());
218
219 // Simulate training
220 for epoch in 0..5 {
221 let mut loss = weights.sum() + bias.sum();
222 loss.backward(None);
223 optimizer.step(&mut [&mut weights, &mut bias]);
224 optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226 if epoch % 2 == 0 {
227 // Save checkpoint
228 let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229 fs::create_dir_all(&checkpoint_dir)?;
230
231 weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232 bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233 optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235 println!("Saved checkpoint for epoch {}", epoch);
236 }
237 }
238
239 // Load from checkpoint
240 let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241 let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242 let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244 println!("Loaded weights: {:?}", loaded_weights.data());
245 println!("Loaded bias: {:?}", loaded_bias.data());
246 println!(
247 "Loaded optimizer learning rate: {}",
248 loaded_optimizer.learning_rate()
249 );
250
251 // Verify checkpoint integrity
252 assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253 assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256 println!("Checkpointing verification: PASSED");
257
258 Ok(())
259}
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180 // Simple causal mask for target self-attention shape [b, h, tq, tk]
181 let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182 // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183 // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203    for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}
- examples/optimizers/adam_configurations.rs
- examples/optimizers/learning_rate_scheduling.rs
- examples/getting_started/optimizer_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/supervised_training/supervised_bce.rs
- examples/iterators/performance_optimization.rs
- examples/supervised_training/supervised_classification.rs
- examples/supervised_training/supervised_regression.rs
- examples/RL_training/dqn.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/ppo_continuous.rs
- examples/RL_training/td3.rs
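For a non-scalar output you can pass the upstream gradient explicitly instead of relying on the implicit tensor of ones. A minimal sketch, assuming the usual convention that grad_output matches the output's shape:
use train_station::Tensor;
let x = Tensor::ones(vec![2, 2]).with_requires_grad();
let mut y = x.mul_scalar(3.0); // non-scalar output
// Upstream gradient with the same shape as y (assumed convention)
y.backward(Some(Tensor::ones(vec![2, 2])));
println!("dL/dx: {:?}", x.grad().map(|g| g.data()));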
Sourcepub unsafe fn as_ptr(&self) -> *const f32
pub unsafe fn as_ptr(&self) -> *const f32
Returns a raw pointer to the tensor data for unsafe operations
§Safety
This is unsafe because it provides direct access to the underlying memory. The caller must ensure:
- The tensor is not dropped while the pointer is used
- No concurrent mutable access occurs
- Bounds are respected
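A minimal usage sketch that upholds the invariants above (the tensor outlives the pointer, no concurrent mutation, reads stay in bounds); for ordinary reads the safe data() accessor is preferable:
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let first = unsafe {
    // Valid only while `tensor` is alive; the read stays within tensor.size()
    *tensor.as_ptr()
};
assert_eq!(first, 1.0);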
Sourcepub unsafe fn as_mut_ptr(&mut self) -> *mut f32
pub unsafe fn as_mut_ptr(&mut self) -> *mut f32
Returns a mutable raw pointer to the tensor data for unsafe operations
§Safety
This is unsafe because it provides direct mutable access to the underlying memory. The caller must ensure:
- The tensor is not dropped while the pointer is used
- No concurrent access occurs
- Bounds are respected
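A matching sketch for in-place writes through the mutable pointer, again requiring exclusive access for the pointer's lifetime:
use train_station::Tensor;
let mut tensor = Tensor::zeros(vec![2, 2]);
unsafe {
    // Exclusive access: nothing else reads or writes while the pointer is live
    let ptr = tensor.as_mut_ptr();
    *ptr.add(3) = 7.0; // last element in row-major order
}
assert_eq!(tensor.get(&[1, 1]), 7.0);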
Sourcepub fn grad_fn(&self) -> &GradFn
pub fn grad_fn(&self) -> &GradFn
Get a reference to the gradient function (for gradtrack)
Returns a reference to the gradient function associated with this tensor. This is used internally by the gradtrack system to compute gradients.
§Returns
Reference to the gradient function
§Implementation Details
This method is used by the gradtrack engine to access the gradient computation function during the backward pass.
Sourcepub fn set_grad(&mut self, grad: Tensor)
pub fn set_grad(&mut self, grad: Tensor)
Set gradient from external source
Sets the gradient tensor for this tensor. This is used internally by the gradtrack system to set gradients during the backward pass.
§Arguments
grad - The gradient tensor to set
§Implementation Details
This method is used internally by the gradtrack engine to set gradients during the backward pass. It only sets the gradient if gradient tracking is enabled for this tensor.
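The repository examples below use set_grad together with grad_owned to rescale gradients for clipping. A minimal sketch of that pattern:
use train_station::Tensor;
let mut param = Tensor::ones(vec![2, 2]).with_requires_grad();
let mut loss = param.sum();
loss.backward(None);
// Replace the accumulated gradient with a scaled copy (as the clipping helpers do)
if let Some(g) = param.grad_owned() {
    param.set_grad(g.mul_scalar(0.5));
}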
Examples found in repository?
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278 let mut total_sq = 0.0f32;
279 for p in parameters.iter() {
280 if let Some(g) = p.grad_owned() {
281 for &v in g.data() {
282 total_sq += v * v;
283 }
284 }
285 }
286 let norm = total_sq.sqrt();
287 if norm > max_norm {
288 let scale = max_norm / (norm + eps);
289 for p in parameters.iter_mut() {
290 if let Some(g) = p.grad_owned() {
291 p.set_grad(g.mul_scalar(scale));
292 }
293 }
294 }
295}
More examples
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23 let mut total_sq = 0.0f32;
24 for p in parameters.iter() {
25 if let Some(g) = p.grad_owned() {
26 for &v in g.data() {
27 total_sq += v * v;
28 }
29 }
30 }
31 let norm = total_sq.sqrt();
32 if norm > max_norm {
33 let scale = max_norm / (norm + eps);
34 for p in parameters.iter_mut() {
35 if let Some(g) = p.grad_owned() {
36 p.set_grad(g.mul_scalar(scale));
37 }
38 }
39 }
40}
Sourcepub fn zero_grad(&mut self)
pub fn zero_grad(&mut self)
Clear accumulated gradients for this tensor
This method is used by optimizers to zero gradients before each backward pass. It clears any accumulated gradients, allowing for fresh gradient computation.
§Examples
use train_station::Tensor;
let mut tensor = Tensor::ones(vec![2, 3]).with_requires_grad();
tensor.set_grad(Tensor::ones(vec![2, 3]));
assert!(tensor.grad().is_some());
tensor.zero_grad();
assert!(tensor.grad().is_none());
Sourcepub fn is_contiguous(&self) -> bool
pub fn is_contiguous(&self) -> bool
Checks if the tensor data is stored contiguously in memory
§Returns
true if the tensor data is contiguous, enabling optimized SIMD operations
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3, 4]);
assert!(tensor.is_contiguous());
Examples found in repository?
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
Sourcepub fn memory_offset(&self, indices: &[usize]) -> usize
pub fn memory_offset(&self, indices: &[usize]) -> usize
Calculates the linear memory offset for given multi-dimensional indices
§Arguments
indices - Vector of indices for each dimension
§Returns
Linear memory offset for direct memory access
§Examples
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3, 4]);
let offset = tensor.memory_offset(&[1, 2, 3]);
// offset = 1*12 + 2*4 + 3*1 = 23
Examples found in repository?
76 pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77 let (b, _s, e) = Self::triple(src);
78 let mut memory = src.clone();
79 for enc in &self.encoders {
80 memory = enc.forward(&memory, None);
81 }
82
83 let mut out_seq: Vec<Tensor> = Vec::new();
84 // Start token: zeros
85 let mut current = Tensor::zeros(vec![b, 1, e]);
86 for _step in 0..max_steps {
87 // Build causal mask for length t
88 let t = current.shape().dims()[1];
89 let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90 // Upper triangle as false -> masked for all batches and heads
91 for bb in 0..b {
92 for hh in 0..self.num_heads {
93 for i in 0..t {
94 for j in (i + 1)..t {
95 let offset = causal.memory_offset(&[bb, hh, i, j]);
96 let data = causal.data_mut();
97 data[offset] = 0.0;
98 }
99 }
100 }
101 }
102 let mut step_out = current.clone();
103 for dec in &self.decoders {
104 step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105 }
106 // (Toy) append placeholder token; real models would project last token
107 out_seq.push(step_out.clone());
108 // Append a zero token to grow sequence by 1 for next causal computation
109 current = Tensor::zeros(vec![b, t + 1, e]);
110 }
111 // Simple return of final sequence placeholder
112 current
113 }
114
115 /// Non auto-regressive inference: single forward pass
116 pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117 let (b, _s, e) = Self::triple(src);
118 let mut memory = src.clone();
119 for enc in &self.encoders {
120 memory = enc.forward(&memory, None);
121 }
122 let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123 let mut out = tgt.clone();
124 for dec in &self.decoders {
125 out = dec.forward(&out, &memory, None, None);
126 }
127 out
128 }
129
130 /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131 fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132 let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133 for bb in 0..batch {
134 for hh in 0..heads {
135 for i in 0..t {
136 for j in (i + 1)..t {
137 let offset = mask.memory_offset(&[bb, hh, i, j]);
138 let data = mask.data_mut();
139 data[offset] = 0.0;
140 }
141 }
142 }
143 }
144 mask
145}
More examples
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180 // Simple causal mask for target self-attention shape [b, h, tq, tk]
181 let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182 // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183 // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203    for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}
Sourcepub fn memory_alignment(&self) -> usize
pub fn memory_alignment(&self) -> usize
Gets the memory alignment of the tensor data
§Returns
The memory alignment in bytes (typically 32 for SIMD optimization)
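A minimal sketch, assuming only that the reported alignment is a power of two in bytes:
use train_station::Tensor;
let tensor = Tensor::new(vec![16, 16]);
let align = tensor.memory_alignment();
assert!(align.is_power_of_two());
println!("data alignment: {} bytes", align);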
Sourcepub fn is_broadcastable_with(&self, other: &Tensor) -> bool
pub fn is_broadcastable_with(&self, other: &Tensor) -> bool
Checks if this tensor is broadcastable with another tensor
§Arguments
other - The other tensor to check broadcasting compatibility
§Returns
true if the tensors are broadcastable according to NumPy broadcasting rules
§Examples
use train_station::Tensor;
let a = Tensor::new(vec![2, 3, 4]);
let b = Tensor::new(vec![1, 3, 4]);
assert!(a.is_broadcastable_with(&b));
Sourcepub fn memory_footprint(&self) -> usize
pub fn memory_footprint(&self) -> usize
Sourcepub fn get(&self, indices: &[usize]) -> f32
pub fn get(&self, indices: &[usize]) -> f32
Get a single element from the tensor at the specified indices
§Arguments
indices - Multi-dimensional indices to access the element
§Returns
The value at the specified position
§Panics
Panics if indices are out of bounds or indices length doesn’t match tensor rank
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let value = tensor.get(&[0, 1]);
assert_eq!(value, 2.0);
Examples found in repository?
156fn demonstrate_data_access() {
157 println!("\n--- Data Access ---");
158
159 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
160
161 // Access individual elements
162 println!("Element [0, 0]: {}", tensor.get(&[0, 0]));
163 println!("Element [0, 1]: {}", tensor.get(&[0, 1]));
164 println!("Element [1, 0]: {}", tensor.get(&[1, 0]));
165 println!("Element [1, 1]: {}", tensor.get(&[1, 1]));
166
167 // Access data as slice
168 let data = tensor.data();
169 println!("Data as slice: {:?}", data);
170
171 // Iterate over elements
172 println!("Elements:");
173 for (i, &value) in data.iter().enumerate() {
174 println!(" [{}]: {}", i, value);
175 }
176}
Sourcepub fn set(&mut self, indices: &[usize], value: f32)
pub fn set(&mut self, indices: &[usize], value: f32)
Set a single element in the tensor at the specified indices
§Arguments
indices - Multi-dimensional indices to set the element
value - The value to set
§Panics
Panics if indices are out of bounds or indices length doesn’t match tensor rank
§Examples
use train_station::Tensor;
let mut tensor = Tensor::new(vec![2, 2]);
tensor.set(&[0, 1], 42.0);
assert_eq!(tensor.get(&[0, 1]), 42.0);
Sourcepub fn data(&self) -> &[f32]
pub fn data(&self) -> &[f32]
Returns a safe slice of the tensor’s underlying data
Provides safe access to the tensor’s data without requiring unsafe pointer operations. This is the preferred way to access tensor data for reading values, comparisons, and other operations that don’t require direct pointer manipulation.
§Returns
A slice containing all tensor elements in row-major order
§Performance
- Zero-Cost: Direct slice creation with no copying
- Cache-Friendly: Sequential memory access pattern
- Safe: No unsafe code required for basic data access
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let data = tensor.data();
// Safe indexing and comparisons
assert_eq!(data[0], 1.0);
assert_eq!(data.len(), tensor.size());
Examples found in repository?
108 fn copy_from(&mut self, other: &Self) {
109 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110 {
111 let src = s.weight.data();
112 let dst = t.weight.data_mut();
113 dst.copy_from_slice(src);
114 }
115 {
116 let src = s.bias.data();
117 let dst = t.bias.data_mut();
118 dst.copy_from_slice(src);
119 }
120 t.weight.set_requires_grad(false);
121 t.bias.set_requires_grad(false);
122 }
123 }
124}
125
126// -------------------------------
127// Q-Network (state -> Q-values over actions)
128// -------------------------------
129
130struct QNet {
131 net: Mlp,
132}
133
134impl QNet {
135 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
136 let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
137 Self { net }
138 }
139 fn forward(&self, state: &Tensor) -> Tensor {
140 self.net.forward(state, None)
141 }
142 fn parameters(&mut self) -> Vec<&mut Tensor> {
143 self.net.parameters()
144 }
145 fn set_requires_grad_all(&mut self, enable: bool) {
146 self.net.set_requires_grad_all(enable);
147 }
148}
149
150// -------------------------------
151// Discrete YardEnv (3 actions: -1, 0, +1)
152// -------------------------------
153
154struct YardEnv {
155 pos: f32,
156 vel: f32,
157 steps: usize,
158 max_steps: usize,
159 rng: SmallRng,
160}
161
162impl YardEnv {
163 const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];
164
165 fn new(seed: u64) -> Self {
166 let mut env = Self {
167 pos: 0.0,
168 vel: 0.0,
169 steps: 0,
170 max_steps: 200,
171 rng: SmallRng::new(seed),
172 };
173 env.reset();
174 env
175 }
176
177 fn reset(&mut self) -> Tensor {
178 self.pos = self.rng.uniform(-0.5, 0.5);
179 self.vel = self.rng.uniform(-0.1, 0.1);
180 self.steps = 0;
181 self.state_tensor()
182 }
183
184 fn state_tensor(&self) -> Tensor {
185 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
186 }
187
188 fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
189 let a = Self::ACTIONS[action_index.min(2)];
190 self.vel += 0.1 * a - 0.01 * self.pos;
191 self.pos += self.vel;
192 self.steps += 1;
193 let reward = -(self.pos * self.pos) - 0.05 * (a * a);
194 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
195 (self.state_tensor(), reward, done)
196 }
197}
198
199// -------------------------------
200// Replay Buffer
201// -------------------------------
202
203struct ReplayBuffer {
204 capacity: usize,
205 size: usize,
206 pos: usize,
207 state_dim: usize,
208 states: Vec<f32>,
209 actions: Vec<usize>,
210 rewards: Vec<f32>,
211 dones: Vec<f32>,
212 next_states: Vec<f32>,
213}
214
215impl ReplayBuffer {
216 fn new(capacity: usize, state_dim: usize) -> Self {
217 Self {
218 capacity,
219 size: 0,
220 pos: 0,
221 state_dim,
222 states: vec![0.0; capacity * state_dim],
223 actions: vec![0usize; capacity],
224 rewards: vec![0.0; capacity],
225 dones: vec![0.0; capacity],
226 next_states: vec![0.0; capacity * state_dim],
227 }
228 }
229
230 fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
231 let i = self.pos;
232 let so = i * self.state_dim;
233 self.states[so..so + self.state_dim].copy_from_slice(s);
234 self.actions[i] = a_idx;
235 self.rewards[i] = r;
236 self.dones[i] = d;
237 self.next_states[so..so + self.state_dim].copy_from_slice(s2);
238 self.pos = (self.pos + 1) % self.capacity;
239 self.size = self.size.saturating_add(1).min(self.capacity);
240 }
241
242 fn can_sample(&self, batch_size: usize) -> bool {
243 self.size >= batch_size
244 }
245
246 fn sample(
247 &self,
248 batch_size: usize,
249 rng: &mut SmallRng,
250 ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
251 let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
252 let mut a_idx = Vec::with_capacity(batch_size);
253 let mut r_vec = Vec::with_capacity(batch_size);
254 let mut d_vec = Vec::with_capacity(batch_size);
255 let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
256 for _ in 0..batch_size {
257 let idx = rng.sample_index(self.size);
258 let so = idx * self.state_dim;
259 s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
260 a_idx.push(self.actions[idx]);
261 r_vec.push(self.rewards[idx]);
262 d_vec.push(self.dones[idx]);
263 s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
264 }
265 let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
266 let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
267 let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
268 let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
269 (s, a_idx, r, d, s2)
270 }
271}
272
273// -------------------------------
274// Helpers
275// -------------------------------
276
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278 let mut total_sq = 0.0f32;
279 for p in parameters.iter() {
280 if let Some(g) = p.grad_owned() {
281 for &v in g.data() {
282 total_sq += v * v;
283 }
284 }
285 }
286 let norm = total_sq.sqrt();
287 if norm > max_norm {
288 let scale = max_norm / (norm + eps);
289 for p in parameters.iter_mut() {
290 if let Some(g) = p.grad_owned() {
291 p.set_grad(g.mul_scalar(scale));
292 }
293 }
294 }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298 let mut total_sq = 0.0f32;
299 for p in parameters.iter_mut() {
300 if let Some(g) = p.grad_owned() {
301 for &v in g.data() {
302 total_sq += v * v;
303 }
304 }
305 }
306 total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310 let _ng = NoGradTrack::new();
311 let mut total_sq = 0.0f32;
312 for p in parameters.iter_mut() {
313 for &v in p.data() {
314 total_sq += v * v;
315 }
316 }
317 total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322 diff.pow_scalar(2.0)
323 .add_scalar(1.0)
324 .sqrt()
325 .sub_scalar(1.0)
326 .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(¬_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
More examples
114 fn copy_from(&mut self, other: &Self) {
115 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116 {
117 let src = s.weight.data();
118 let dst = t.weight.data_mut();
119 dst.copy_from_slice(src);
120 }
121 {
122 let src = s.bias.data();
123 let dst = t.bias.data_mut();
124 dst.copy_from_slice(src);
125 }
126 t.weight.set_requires_grad(false);
127 t.bias.set_requires_grad(false);
128 }
129 }
130
131 fn soft_update_from(&mut self, source: &Self, tau: f32) {
132 let _ng = NoGradTrack::new();
133 for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134 // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135 let new_w = t
136 .weight
137 .mul_scalar(1.0 - tau)
138 .add_tensor(&s.weight.mul_scalar(tau));
139 let new_b = t
140 .bias
141 .mul_scalar(1.0 - tau)
142 .add_tensor(&s.bias.mul_scalar(tau));
143 {
144 let src = new_w.data();
145 let dst = t.weight.data_mut();
146 dst.copy_from_slice(src);
147 }
148 {
149 let src = new_b.data();
150 let dst = t.bias.data_mut();
151 dst.copy_from_slice(src);
152 }
153 t.weight.set_requires_grad(false);
154 t.bias.set_requires_grad(false);
155 }
156 }
157}
158
159// -------------------------------
160// Actor and Critic
161// -------------------------------
162
163struct Actor {
164 net: Mlp,
165}
166
167impl Actor {
168 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
169 // Smaller net for faster demo: sd -> 64 -> 64 -> ad, tanh output
170 let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
171 Self { net }
172 }
173 fn forward(&self, state: &Tensor) -> Tensor {
174 self.net.forward(state, Some(tanh_bounded))
175 }
176 fn parameters(&mut self) -> Vec<&mut Tensor> {
177 self.net.parameters()
178 }
179 fn set_requires_grad_all(&mut self, enable: bool) {
180 self.net.set_requires_grad_all(enable);
181 }
182}
183
184struct Critic {
185 net: Mlp,
186}
187
188impl Critic {
189 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
190 let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
191 Self { net }
192 }
193 fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194 // Concatenate along feature dim (dim=1) for batched inputs
195 // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196 let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197 let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198 let sa = Tensor::cat(&[s_view, a_view], 1);
199 self.net.forward(&sa, None)
200 }
201 fn parameters(&mut self) -> Vec<&mut Tensor> {
202 self.net.parameters()
203 }
204 fn set_requires_grad_all(&mut self, enable: bool) {
205 self.net.set_requires_grad_all(enable);
206 }
207}
208
209// -------------------------------
210// Simple continuous control environment: YardEnv
211// State: normalized features [pos/3, clamp(vel/1, -1..1), bias(=0)] ; Action: scalar in [-1, 1]
212// Dynamics: vel += 0.1*act - 0.01*pos; pos += vel
213// Reward: -(pos^2) - 0.1*act^2 ; Episode ends if |pos| > 3 or step >= max_steps
214// -------------------------------
215
216struct YardEnv {
217 pos: f32,
218 vel: f32,
219 steps: usize,
220 max_steps: usize,
221 rng: SmallRng,
222}
223
224impl YardEnv {
225 fn new(seed: u64) -> Self {
226 let mut env = Self {
227 pos: 0.0,
228 vel: 0.0,
229 steps: 0,
230 max_steps: 200,
231 rng: SmallRng::new(seed),
232 };
233 env.reset();
234 env
235 }
236
237 fn reset(&mut self) -> Tensor {
238 self.pos = self.rng.uniform(-0.5, 0.5);
239 self.vel = self.rng.uniform(-0.1, 0.1);
240 self.steps = 0;
241 self.state_tensor()
242 }
243
244 fn state_tensor(&self) -> Tensor {
245 // Normalize to keep critic inputs bounded:
246 // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247 // - Velocity scaled by 1.0 and clamped to [-1,1]
248 let pos_n = self.pos / 3.0;
249 let vel_n = self.vel.clamp(-1.0, 1.0);
250 Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251 }
252
253 fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254 let a = action_value.clamp(-1.0, 1.0);
255 self.vel += 0.1 * a - 0.01 * self.pos;
256 self.pos += self.vel;
257 self.steps += 1;
258
259 let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261 (self.state_tensor(), reward, done)
262 }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270 capacity: usize,
271 size: usize,
272 pos: usize,
273 state_dim: usize,
274 action_dim: usize,
275 states: Vec<f32>,
276 actions: Vec<f32>,
277 rewards: Vec<f32>,
278 dones: Vec<f32>,
279 next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283 fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284 Self {
285 capacity,
286 size: 0,
287 pos: 0,
288 state_dim,
289 action_dim,
290 states: vec![0.0; capacity * state_dim],
291 actions: vec![0.0; capacity * action_dim],
292 rewards: vec![0.0; capacity],
293 dones: vec![0.0; capacity],
294 next_states: vec![0.0; capacity * state_dim],
295 }
296 }
297
298 fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299 let i = self.pos;
300 let so = i * self.state_dim;
301 let ao = i * self.action_dim;
302 self.states[so..so + self.state_dim].copy_from_slice(s);
303 self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304 self.rewards[i] = r;
305 self.dones[i] = d;
306 self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308 self.pos = (self.pos + 1) % self.capacity;
309 self.size = self.size.saturating_add(1).min(self.capacity);
310 }
311
312 fn can_sample(&self, batch_size: usize) -> bool {
313 self.size >= batch_size
314 }
315
316 fn sample(
317 &self,
318 batch_size: usize,
319 rng: &mut SmallRng,
320 ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321 let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322 let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323 let mut r_vec = Vec::with_capacity(batch_size);
324 let mut d_vec = Vec::with_capacity(batch_size);
325 let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327 for _ in 0..batch_size {
328 let idx = rng.sample_index(self.size);
329 let so = idx * self.state_dim;
330 let ao = idx * self.action_dim;
331 s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332 a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333 r_vec.push(self.rewards[idx]);
334 d_vec.push(self.dones[idx]);
335 s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336 }
337
338 let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339 let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340 let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341 let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342 let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343 (s, a, r, d, s2)
344 }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352 // Compute global L2 norm of all grads
353 let mut total_sq = 0.0f32;
354 for p in parameters.iter() {
355 if let Some(g) = p.grad_owned() {
356 for &v in g.data() {
357 total_sq += v * v;
358 }
359 }
360 }
361 let norm = total_sq.sqrt();
362 if norm > max_norm {
363 let scale = max_norm / (norm + eps);
364 for p in parameters.iter_mut() {
365 if let Some(g) = p.grad_owned() {
366 let scaled = g.mul_scalar(scale);
367 p.set_grad(scaled);
368 }
369 }
370 }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375 let mut total_sq = 0.0f32;
376 for p in parameters.iter_mut() {
377 if let Some(g) = p.grad_owned() {
378 for &v in g.data() {
379 total_sq += v * v;
380 }
381 }
382 }
383 total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388 let _ng = NoGradTrack::new();
389 let mut total_sq = 0.0f32;
390 for p in parameters.iter_mut() {
391 for &v in p.data() {
392 total_sq += v * v;
393 }
394 }
395 total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403 println!("=== TD3 Example (YardEnv) ===");
404
405 // Environment / problem dims
406 let state_dim = 3usize;
407 let action_dim = 1usize;
408
409 // Hyperparameters (small for demo)
410 let gamma = 0.99f32;
411 let tau = 0.005f32; // Polyak
412 let policy_noise = 0.2f32; // target smoothing noise stddev
413 let exploration_noise = 0.1f32; // behavior policy noise stddev
414 let policy_delay = 2usize;
415 let batch_size = 64usize;
416 let start_steps = 500usize; // random exploration steps
417 let total_steps = 1500usize;
418 let max_grad_norm = 1.0f32;
419
420 // Models
421 let mut actor = Actor::new(state_dim, action_dim, Some(11));
422 let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423 actor_targ.net.copy_from(&actor.net);
424 actor_targ.set_requires_grad_all(false);
425
426 let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427 let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428 let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429 let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430 critic1_targ.net.copy_from(&critic1.net);
431 critic2_targ.net.copy_from(&critic2.net);
432 critic1_targ.set_requires_grad_all(false);
433 critic2_targ.set_requires_grad_all(false);
434
435 // Optimizers
436 let mut actor_opt = Adam::with_learning_rate(1e-3);
437 for p in actor.parameters() {
438 actor_opt.add_parameter(p);
439 }
440
441 let mut critic_opt = Adam::with_learning_rate(1e-4);
442 for p in critic1.parameters() {
443 critic_opt.add_parameter(p);
444 }
445 for p in critic2.parameters() {
446 critic_opt.add_parameter(p);
447 }
448
449 // Replay buffer and env
450 let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451 let mut env = YardEnv::new(1234);
452 let mut rng = SmallRng::new(987654321);
453
454 // Reset & metric trackers
455 let mut state = env.reset(); // [1, state_dim]
456 let mut episode_return = 0.0f32;
457 let mut episode = 0usize;
458 let mut ema_return: Option<f32> = None;
459 let ema_alpha = 0.05f32; // smooth short-term
460 let mut best_return = f32::NEG_INFINITY;
461 let mut policy_updates: usize = 0;
462
463 for t in 0..total_steps {
464 // Select action
465 let action_tensor = if t < start_steps {
466 let a = rng.uniform(-1.0, 1.0);
467 Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468 } else {
469 // Behavior policy with exploration noise
470 let _ng = NoGradTrack::new();
471 let det = actor.forward(&state);
472 let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473 tanh_bounded(&det.add_tensor(&noise))
474 };
475 let action_value = action_tensor.data()[0];
476
477 // Environment step
478 let (next_state, reward, done) = env.step(action_value);
479 episode_return += reward;
480
481 // Store transition
482 let s_slice = state.data().to_vec();
483 let a_slice = action_tensor.data().to_vec();
484 let s2_slice = next_state.data().to_vec();
485 rb.push(
486 &s_slice,
487 &a_slice,
488 reward,
489 if done { 1.0 } else { 0.0 },
490 &s2_slice,
491 );
492
493 state = if done {
494 let st = env.reset();
495 // Metrics: update EMA and best
496 ema_return = Some(match ema_return {
497 None => episode_return,
498 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499 });
500 if episode_return > best_return {
501 best_return = episode_return;
502 }
503 println!(
504 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505 t,
506 episode,
507 episode_return,
508 ema_return.unwrap_or(episode_return),
509 best_return,
510 rb.size,
511 policy_updates
512 );
513 episode_return = 0.0;
514 episode += 1;
515 st
516 } else {
517 next_state
518 };
519
520 // Training
521 if rb.can_sample(batch_size) {
522 // Sample batch
523 let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525 // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
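            // Taking the elementwise min of Q1' and Q2' is TD3's clipped double-Q estimate, which guards against value overestimation.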
526 let target_q = {
527 let _ng = NoGradTrack::new();
528 // Target actions with smoothing noise (tanh bounds)
529 let noise =
530 Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531 let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532 let q1_t = critic1_targ.forward(&s2, &a_targ);
533 let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535 // Elementwise min via data() since this path is no-grad
536 let q1d = q1_t.data();
537 let q2d = q2_t.data();
538 let mut min_vec = Vec::with_capacity(batch_size);
539 for i in 0..batch_size {
540 let v1 = q1d[i];
541 let v2 = q2d[i];
542 min_vec.push(v1.min(v2));
543 }
544 let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547 };
548
549 // Critic update (both critics)
550 // Zero grads in a short scope, then drop borrows before forward
551 {
552 let mut params = {
553 let c_params = critic1.parameters();
554 let c2_params = critic2.parameters();
555 let mut tmp: Vec<&mut Tensor> = Vec::new();
556 tmp.extend(c_params);
557 tmp.extend(c2_params);
558 tmp
559 };
560 critic_opt.zero_grad(&mut params);
561 }
562
563 // Forward current Q estimates
564 let q1 = critic1.forward(&s, &a);
565 let q2 = critic2.forward(&s, &a);
566 let diff1 = q1.sub_tensor(&target_q);
567 let diff2 = q2.sub_tensor(&target_q);
568 let mut critic_loss = diff1
569 .pow_scalar(2.0)
570 .mean()
571 .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573 // Backward
574 critic_loss.backward(None);
575
576 // Optional gradient clipping + step (only for params that received grads)
577 {
578 let params = {
579 let c_params = critic1.parameters();
580 let c2_params = critic2.parameters();
581 let mut tmp: Vec<&mut Tensor> = Vec::new();
582 tmp.extend(c_params);
583 tmp.extend(c2_params);
584 tmp
585 };
586 let mut with_grads: Vec<&mut Tensor> = Vec::new();
587 for p in params {
588 if p.grad_owned().is_some() {
589 with_grads.push(p);
590 }
591 }
592 if !with_grads.is_empty() {
593 // Pre-step metrics
594 let grad_norm_before = grad_global_norm(&mut with_grads);
595 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596 critic_opt.step(&mut with_grads);
597 critic_opt.zero_grad(&mut with_grads);
598
599 // Post-step metrics (param norm)
600 let mut for_norm_params = {
601 let c_params = critic1.parameters();
602 let c2_params = critic2.parameters();
603 let mut tmp: Vec<&mut Tensor> = Vec::new();
604 tmp.extend(c_params);
605 tmp.extend(c2_params);
606 tmp
607 };
608 let param_norm = params_l2_norm(&mut for_norm_params);
609
610 // Print compact critic metrics occasionally
611 if t % 100 == 0 {
612 let q1_mean = q1.mean().value();
613 let q2_mean = q2.mean().value();
614 let tq_mean = target_q.mean().value();
615 println!(
616 "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617 t,
618 critic_loss.value(),
619 q1_mean,
620 q2_mean,
621 tq_mean,
622 grad_norm_before,
623 param_norm
624 );
625 }
626 }
627 }
628
629 // Delayed policy update
630 if t % policy_delay == 0 {
631 // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632 // Zero actor grads before backward
633 {
634 let mut a_params: Vec<&mut Tensor> = actor.parameters();
635 actor_opt.zero_grad(&mut a_params);
636 }
637
638 let a_pred = actor.forward(&s);
639 let q_for_actor = critic1.forward(&s, &a_pred);
640 let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641 actor_loss.backward(None);
642
643 {
644 let a_params: Vec<&mut Tensor> = actor.parameters();
645 let mut with_grads: Vec<&mut Tensor> = Vec::new();
646 for p in a_params {
647 if p.grad_owned().is_some() {
648 with_grads.push(p);
649 }
650 }
651 if !with_grads.is_empty() {
652 let grad_norm_before = grad_global_norm(&mut with_grads);
653 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654 actor_opt.step(&mut with_grads);
655 actor_opt.zero_grad(&mut with_grads);
656
657 // Post-step param norm
658 let mut for_norm_params = actor.parameters();
659 let param_norm = params_l2_norm(&mut for_norm_params);
660
661 policy_updates += 1;
662 if t % 200 == 0 {
663 println!(
664 "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665 t,
666 actor_loss.value(),
667 grad_norm_before,
668 param_norm,
669 actor_opt.learning_rate(),
670 critic_opt.learning_rate(),
671 policy_updates
672 );
673 }
674 }
675 }
676
677 // Target updates (Polyak averaging, no grad)
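            // Polyak update: target <- (1 - tau) * target + tau * online, applied to every weight and bias.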
678 actor_targ.net.soft_update_from(&actor.net, tau);
679 critic1_targ.net.soft_update_from(&critic1.net, tau);
680 critic2_targ.net.soft_update_from(&critic2.net, tau);
681 }
682
683 // Clear entire graphs to avoid stale accumulation across iterations
684 clear_all_graphs_known();
685 }
686 }
687
688 println!("=== TD3 training finished ===");
689 Ok(())
690}

293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (sampling is detached so the rollout does not grow the graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
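        // compute_gae is assumed to follow the standard GAE recursion:
        //   delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
        //   A_t = delta_t + gamma * lambda * (1 - d_t) * A_{t+1}, with returns_t = A_t + V(s_t)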
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
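                // Identities used below: max(x, low) = relu(x - low) + low and min(x, high) = high - relu(high - x).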
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&clip_high.sub_tensor(&ratio_ge_low).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
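                // For a Gaussian policy, H = 0.5 * ln(2 * pi * e * sigma^2) = log_std + 0.5 * ln(2 * pi * e) per action dimension.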
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}

252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
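// Uses max(x, low) = relu(x - low) + low followed by min(x, high) = high - relu(high - x).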
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&high.sub_tensor(&ge_low).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}

23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44 // pred: [B,1] with sigmoid; threshold at 0.5
45 let p = pred.data();
46 let t = targets.data();
47 let mut correct = 0usize;
48 for i in 0..p.len() {
49 let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50 if (yhat - t[i]).abs() < 1e-6 {
51 correct += 1;
52 }
53 }
54 correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
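// This is the standard stable form of -[y*ln(sigmoid(z)) + (1-y)*ln(1-sigmoid(z))]; working with |z| keeps exp() from overflowing for large logits.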
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60 let relu_z = logits.relu();
61 let zy = logits.mul_tensor(targets);
62 // |z| = relu(z) + relu(-z)
63 let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64 let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65 relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
67
68pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69 println!("=== Supervised FFN Example (XOR) ===");
70
71 // Dataset: XOR (repeat to form a small batch)
72 let inputs: Vec<f32> = vec![
73 0.0, 0.0, // -> 0
74 0.0, 1.0, // -> 1
75 1.0, 0.0, // -> 1
76 1.0, 1.0, // -> 0
77 ];
78 let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80 // Repeat the base patterns to stabilize training
81 let repeats = 64usize; // effective batch = 4 * repeats = 256
82 let mut xs = Vec::with_capacity(repeats * inputs.len());
83 let mut ys = Vec::with_capacity(repeats * targets.len());
84 for _ in 0..repeats {
85 xs.extend_from_slice(&inputs);
86 ys.extend_from_slice(&targets);
87 }
88
89 let batch = xs.len() / 2; // two features
90 let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91 let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93 // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94 let cfg = FeedForwardConfig {
95 input_size: 2,
96 hidden_sizes: vec![32, 32],
97 output_size: 1,
98 use_bias: true,
99 };
100 let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102 // Optimizer and parameter linking
103 let mut opt = Adam::with_learning_rate(1e-3);
104 for p in net.parameters() {
105 opt.add_parameter(p);
106 }
107
108 let epochs = 1000usize;
109 let max_grad_norm = 1.0f32;
110 let mut best_loss = f32::INFINITY;
111 let mut best_acc = 0.0f32;
112
113 for e in 0..epochs {
114 // Zero grads each iteration
115 {
116 let mut params = net.parameters();
117 opt.zero_grad(&mut params);
118 }
119
120 // Forward -> logits; use numerically stable BCE-with-logits for loss
121 let logits = net.forward(&x_t);
122 let mut loss = bce_with_logits(&logits, &y_t);
123 loss.backward(None);
124
125 // Step only params with grads
126 {
127 let params = net.parameters();
128 let mut with_grads: Vec<&mut Tensor> = Vec::new();
129 for p in params {
130 if p.grad_owned().is_some() {
131 with_grads.push(p);
132 }
133 }
134 if !with_grads.is_empty() {
135 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136 opt.step(&mut with_grads);
137 opt.zero_grad(&mut with_grads);
138 }
139 }
140
141 // Metrics (use sigmoid only for reporting accuracy)
142 let preds = logits.sigmoid();
143 let acc = accuracy(&preds, &y_t);
144 if loss.value() < best_loss {
145 best_loss = loss.value();
146 }
147 if acc > best_acc {
148 best_acc = acc;
149 }
150 if e % 10 == 0 || e + 1 == epochs {
151 println!(
152 "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153 e,
154 loss.value(),
155 acc,
156 best_loss,
157 best_acc
158 );
159 }
160
161 // Clear graphs to avoid stale accumulation across epochs
162 clear_all_graphs_known();
163 }
164
165 // Quick sanity check predictions
166 let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167 let out = net.forward(&test).sigmoid();
168 println!("predictions (approx): {:?}", out.data());
169
170 println!("=== Supervised training finished ===");
171 Ok(())
172}

23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
60
61fn accuracy_from_logits(
62 logits: &Tensor,
63 labels: &[usize],
64 batch: usize,
65 num_classes: usize,
66) -> f32 {
67 let row = logits.data();
68 let mut correct = 0usize;
69 for (i, &label) in labels.iter().enumerate().take(batch) {
70 let base = i * num_classes;
71 let mut best_j = 0usize;
72 let mut best_v = row[base];
73 for j in 1..num_classes {
74 let v = row[base + j];
75 if v > best_v {
76 best_v = v;
77 best_j = j;
78 }
79 }
80 if best_j == label {
81 correct += 1;
82 }
83 }
84 correct as f32 / batch as f32
85}
86
87pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88 println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90 // Synthetic 2D inputs, 3 classes with linear-ish separations
91 let n = 1200usize;
92 let classes = 3usize;
93 let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94 let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96 // Simple RNG
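    // Tiny linear congruential generator, good enough for synthesizing this toy dataset.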
97 let mut state: u64 = 424242;
98 let mut rand_f32 = || {
99 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100 ((state >> 16) as u32) as f32 / (u32::MAX as f32)
101 };
102
103 for _ in 0..n {
104 let x1 = rand_f32() * 4.0 - 2.0;
105 let x2 = rand_f32() * 4.0 - 2.0;
106 // Class by quadrant-ish rule with noise
107 let mut c = if x1 + 0.5 * x2 > 0.5 {
108 0
109 } else if x1 - x2 < -0.5 {
110 1
111 } else {
112 2
113 };
114 if rand_f32() < 0.05 {
115 c = (c + 1) % classes;
116 }
117 xs.push(x1);
118 xs.push(x2);
119 ys.push(c);
120 }
121
122 // Normalize inputs per-feature to [-1, 1]
123 let mut min1 = f32::INFINITY;
124 let mut max1 = f32::NEG_INFINITY;
125 let mut min2 = f32::INFINITY;
126 let mut max2 = f32::NEG_INFINITY;
127 for i in (0..xs.len()).step_by(2) {
128 let a = xs[i];
129 let b = xs[i + 1];
130 if a < min1 {
131 min1 = a;
132 }
133 if a > max1 {
134 max1 = a;
135 }
136 if b < min2 {
137 min2 = b;
138 }
139 if b > max2 {
140 max2 = b;
141 }
142 }
143 let rng1 = (max1 - min1).max(1e-8);
144 let rng2 = (max2 - min2).max(1e-8);
145 for i in (0..xs.len()).step_by(2) {
146 let a = xs[i];
147 let b = xs[i + 1];
148 xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149 xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150 }
151
152 // Train/Val split (80/20)
153 let n_train = (n as f32 * 0.8) as usize;
154 let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155 let y_train = ys[..n_train].to_vec();
156 let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157 let y_val = ys[n_train..].to_vec();
158
159 // Model: 2 -> 64 -> 64 -> 3 (logits)
160 let cfg = FeedForwardConfig {
161 input_size: 2,
162 hidden_sizes: vec![64, 64],
163 output_size: classes,
164 use_bias: true,
165 };
166 let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168 // Optimizer
169 let mut opt = Adam::with_learning_rate(1e-3);
170 for p in net.parameters() {
171 opt.add_parameter(p);
172 }
173
174 let epochs = 300usize;
175 let max_grad_norm = 1.0f32;
176 let mut best_val_acc = 0.0f32;
177 let mut best_val_loss = f32::INFINITY;
178
179 for e in 0..epochs {
180 // Zero grads
181 {
182 let mut params = net.parameters();
183 opt.zero_grad(&mut params);
184 }
185
186 // Forward logits
187 let logits = net.forward(&x_train);
188 let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189 loss.backward(None);
190
191 // Step clipped
192 {
193 let params = net.parameters();
194 let mut with_grads: Vec<&mut Tensor> = Vec::new();
195 for p in params {
196 if p.grad_owned().is_some() {
197 with_grads.push(p);
198 }
199 }
200 if !with_grads.is_empty() {
201 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202 opt.step(&mut with_grads);
203 opt.zero_grad(&mut with_grads);
204 }
205 }
206
207 // Metrics
208 let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209 let val_logits = net.forward(&x_val);
210 let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211 let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212 if val_acc > best_val_acc {
213 best_val_acc = val_acc;
214 }
215 if val_loss < best_val_loss {
216 best_val_loss = val_loss;
217 }
218
219 if e % 10 == 0 || e + 1 == epochs {
220 println!(
221 "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222 e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223 );
224 }
225
226 clear_all_graphs_known();
227 }
228
229 // Quick sample preds via softmax
230 let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231 let sm = net.forward(&samples).softmax(1);
232 println!("sample class probs: {:?}", sm.data());
233
234 println!("=== Supervised classification finished ===");
235 Ok(())
236}

- examples/supervised_training/supervised_regression.rs
- examples/getting_started/tensor_operators.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/iterators/element_iteration.rs
- examples/getting_started/tensor_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/getting_started/serialization_basics.rs
- examples/getting_started/optimizer_basics.rs
- examples/neural_networks/multi_head_attention.rs
- examples/iterators/advanced_patterns.rs
- examples/iterators/performance_optimization.rs
pub fn data_mut(&mut self) -> &mut [f32]
Returns a mutable slice of the tensor’s underlying data
Provides safe mutable access to the tensor’s data without requiring unsafe pointer operations. Use this for in-place modifications of tensor values.
§Returns
A mutable slice containing all tensor elements in row-major order
§Performance
- Zero-Cost: Direct slice creation with no copying
- Cache-Friendly: Sequential memory access pattern
- Safe: No unsafe code required for basic data modification
§Examples
use train_station::Tensor;
let mut tensor = Tensor::new(vec![2, 2]);
let data = tensor.data_mut();
// Safe indexing for modification
data[0] = 1.0;
data[1] = 2.0;
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[0, 1]), 2.0);
Examples found in repository
108 fn copy_from(&mut self, other: &Self) {
109 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
110 {
111 let src = s.weight.data();
112 let dst = t.weight.data_mut();
113 dst.copy_from_slice(src);
114 }
115 {
116 let src = s.bias.data();
117 let dst = t.bias.data_mut();
118 dst.copy_from_slice(src);
119 }
120 t.weight.set_requires_grad(false);
121 t.bias.set_requires_grad(false);
122 }
123 }
More examples
114 fn copy_from(&mut self, other: &Self) {
115 for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
116 {
117 let src = s.weight.data();
118 let dst = t.weight.data_mut();
119 dst.copy_from_slice(src);
120 }
121 {
122 let src = s.bias.data();
123 let dst = t.bias.data_mut();
124 dst.copy_from_slice(src);
125 }
126 t.weight.set_requires_grad(false);
127 t.bias.set_requires_grad(false);
128 }
129 }
130
131 fn soft_update_from(&mut self, source: &Self, tau: f32) {
132 let _ng = NoGradTrack::new();
133 for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134 // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135 let new_w = t
136 .weight
137 .mul_scalar(1.0 - tau)
138 .add_tensor(&s.weight.mul_scalar(tau));
139 let new_b = t
140 .bias
141 .mul_scalar(1.0 - tau)
142 .add_tensor(&s.bias.mul_scalar(tau));
143 {
144 let src = new_w.data();
145 let dst = t.weight.data_mut();
146 dst.copy_from_slice(src);
147 }
148 {
149 let src = new_b.data();
150 let dst = t.bias.data_mut();
151 dst.copy_from_slice(src);
152 }
153 t.weight.set_requires_grad(false);
154 t.bias.set_requires_grad(false);
155 }
156 }

42fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 // Create tensors with different initializations
46 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 // Create tensor from slice
61 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 // Create tensor with specific value
70 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 // Create tensor with random data
80 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}

76 pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77 let (b, _s, e) = Self::triple(src);
78 let mut memory = src.clone();
79 for enc in &self.encoders {
80 memory = enc.forward(&memory, None);
81 }
82
83 let mut out_seq: Vec<Tensor> = Vec::new();
84 // Start token: zeros
85 let mut current = Tensor::zeros(vec![b, 1, e]);
86 for _step in 0..max_steps {
87 // Build causal mask for length t
88 let t = current.shape().dims()[1];
89 let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90 // Upper triangle as false -> masked for all batches and heads
91 for bb in 0..b {
92 for hh in 0..self.num_heads {
93 for i in 0..t {
94 for j in (i + 1)..t {
95 let offset = causal.memory_offset(&[bb, hh, i, j]);
96 let data = causal.data_mut();
97 data[offset] = 0.0;
98 }
99 }
100 }
101 }
102 let mut step_out = current.clone();
103 for dec in &self.decoders {
104 step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105 }
106 // (Toy) append placeholder token; real models would project last token
107 out_seq.push(step_out.clone());
108 // Append a zero token to grow sequence by 1 for next causal computation
109 current = Tensor::zeros(vec![b, t + 1, e]);
110 }
111 // Simple return of final sequence placeholder
112 current
113 }
114
115 /// Non auto-regressive inference: single forward pass
116 pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117 let (b, _s, e) = Self::triple(src);
118 let mut memory = src.clone();
119 for enc in &self.encoders {
120 memory = enc.forward(&memory, None);
121 }
122 let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123 let mut out = tgt.clone();
124 for dec in &self.decoders {
125 out = dec.forward(&out, &memory, None, None);
126 }
127 out
128 }
129
130 /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131 fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132 let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133 for bb in 0..batch {
134 for hh in 0..heads {
135 for i in 0..t {
136 for j in (i + 1)..t {
137 let offset = mask.memory_offset(&[bb, hh, i, j]);
138 let data = mask.data_mut();
139 data[offset] = 0.0;
140 }
141 }
142 }
143 }
144 mask
145 }

165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180 // Simple causal mask for target self-attention shape [b, h, tq, tk]
181 let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182 // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183 // Here, just demonstrate mask broadcast/add mechanics with a light mask on the first head
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203 for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}

pub fn value(&self) -> f32
Extracts the scalar value from a single-element tensor
This method provides a convenient way to extract the scalar value from tensors that contain exactly one element. This is commonly used with element iterator results and scalar tensor operations.
§Returns
The scalar value contained in this tensor
§Panics
Panics if the tensor does not contain exactly one element
§Examples
use train_station::Tensor;
// Single-element tensor
let scalar = Tensor::from_slice(&[42.0], vec![1]).unwrap();
assert_eq!(scalar.value(), 42.0);
Examples found in repository
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47 mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51 // R^2 = 1 - SS_res / SS_tot
52 let y = target;
53 let y_mean = y.mean();
54 let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55 let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56 let ss_res_v = ss_res.value();
57 let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58 1.0 - (ss_res_v / ss_tot_v)
59}
More examples
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}

179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}

84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102 for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}

148 pub fn train_non_autoregressive_steps(
149 &mut self,
150 src: &Tensor,
151 tgt: &Tensor,
152 steps: usize,
153 lr: f32,
154 ) {
155 let mut opt = Adam::with_learning_rate(lr);
156 {
157 let params_once = self.parameters();
158 for p in &params_once {
159 opt.add_parameter(p);
160 }
161 }
162 for step in 0..steps {
163 // forward + backward scope (immutable borrow)
164 {
165 let pred = self.forward(src, tgt);
166 let diff = pred.sub_tensor(tgt);
167 let mut loss = diff.pow_scalar(2.0).mean();
168 if step == 0 || step + 1 == steps {
169 println!("NAR train step {}: loss={:.6}", step, loss.value());
170 }
171 loss.backward(None);
172 }
173 // step + zero_grad scope (mutable borrow)
174 let mut params_step = self.parameters();
175 opt.step(&mut params_step);
176 opt.zero_grad(&mut params_step);
177 }
178 }
179
180 /// Auto-regressive training (teacher forcing): predict next token with causal mask
181 pub fn train_autoregressive_steps(
182 &mut self,
183 src: &Tensor,
184 tgt: &Tensor,
185 steps: usize,
186 lr: f32,
187 ) {
188 let mut opt = Adam::with_learning_rate(lr);
189 {
190 let params_once = self.parameters();
191 for p in &params_once {
192 opt.add_parameter(p);
193 }
194 }
195
196 // Build encoder memory once (static dataset demo)
197 let mut memory = src.clone();
198 for enc in &self.encoders {
199 memory = enc.forward(&memory, None);
200 }
201
202 let (b, t, _e) = Self::triple(tgt);
203 // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204 let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205 for step in 0..steps {
206 // forward + backward scope
207 {
208 let mut out = tgt.clone();
209 for dec in &self.decoders {
210 out = dec.forward(&out, &memory, Some(&causal), None);
211 }
212 let diff = out.sub_tensor(tgt);
213 let mut loss = diff.pow_scalar(2.0).mean();
214 if step == 0 || step + 1 == steps {
215 println!("AR train step {}: loss={:.6}", step, loss.value());
216 }
217 loss.backward(None);
218 }
219 let mut params_step = self.parameters();
220 opt.step(&mut params_step);
221 opt.zero_grad(&mut params_step);
222 }
223 }
224
225 fn triple(t: &Tensor) -> (usize, usize, usize) {
226 let d = t.shape().dims();
227 (d[0], d[1], d[2])
228 }
229}
230
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232 println!("=== Basic Transformer Example ===");
233
234 let batch = 2usize;
235 let src_len = 8usize;
236 let tgt_len = 6usize;
237 let embed = 32usize;
238 let heads = 4usize;
239 let layers = 2usize;
240
241 let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242 let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244 let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245 let out = trf.forward(&src, &tgt);
246 println!("Output shape: {:?}", out.shape().dims());
247
248 // Quick optimization step
249 let mut opt = Adam::with_learning_rate(0.005);
250 let mut params = trf.parameters();
251 for p in &params {
252 opt.add_parameter(p);
253 }
254 let mut loss = out.mean();
255 loss.backward(None);
256 opt.step(&mut params);
257 opt.zero_grad(&mut params);
258 println!("Loss: {:.6}", loss.value());
259
260 // Demo: non auto-regressive inference (single pass)
261 let nar = trf.infer_non_autoregressive(&src, tgt_len);
262 println!("NAR output shape: {:?}", nar.shape().dims());
263
264 // Demo: auto-regressive inference (toy)
265 let ar = trf.infer_autoregressive(&src, 3);
266 println!("AR output shape: {:?}", ar.shape().dims());
267
268 // NAR training demo
269 let nar_tgt = tgt.clone();
270 trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272 // AR training demo (teacher-forced)
273 let ar_tgt = tgt.clone();
274 trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275 println!("=== Done ===");
276 Ok(())
277}

93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94 println!("\n--- Basic Element Iteration ---");
95
96 // Create a simple tensor for demonstration
97 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98 println!("Original tensor: {:?}", tensor.data());
99
100 // Basic iteration with for loop
101 println!("\nBasic iteration with for loop:");
102 for (i, element) in tensor.iter().enumerate() {
103 println!(
104 " Element {}: value = {:.1}, shape = {:?}",
105 i,
106 element.value(),
107 element.shape().dims()
108 );
109 }
110
111 // Element-wise transformation
112 println!("\nElement-wise transformation (2x + 1):");
113 let transformed: Tensor = tensor
114 .iter()
115 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116 .collect();
117 println!(" Result: {:?}", transformed.data());
118
119 // Filtering elements
120 println!("\nFiltering elements (values > 3.0):");
121 let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122 println!(" Filtered: {:?}", filtered.data());
123
124 Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132 println!("\n--- Standard Iterator Methods ---");
133
134 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136 // Using map for transformations
137 println!("\nMap transformation (square each element):");
138 let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139 println!(" Squared: {:?}", squared.data());
140
141 // Using enumerate for indexed operations
142 println!("\nEnumerate with indexed operations:");
143 let indexed: Tensor = tensor
144 .iter()
145 .enumerate()
146 .map(|(i, elem)| elem.add_scalar(i as f32))
147 .collect();
148 println!(" Indexed: {:?}", indexed.data());
149
150 // Using fold for reduction
151 println!("\nFold for sum calculation:");
152 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153 println!(" Sum: {:.1}", sum);
154
155 // Using find for element search
156 println!("\nFind specific element:");
157 if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158 println!(" Found element with value 3.0: {:.1}", found.value());
159 }
160
161 // Using any/all for condition checking
162 println!("\nCondition checking:");
163 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165 println!(" All positive: {}", all_positive);
166 println!(" Any > 4.0: {}", any_large);
167
168 Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
- examples/neural_networks/multi_head_attention.rs
- examples/optimizers/adam_configurations.rs
- examples/optimizers/learning_rate_scheduling.rs
- examples/getting_started/optimizer_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/iterators/advanced_patterns.rs
- examples/iterators/performance_optimization.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/supervised_training/supervised_bce.rs
- examples/supervised_training/supervised_classification.rs
- examples/RL_training/dqn.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/ppo_continuous.rs
- examples/RL_training/td3.rs
Sourcepub fn view(&self, new_shape: Vec<i32>) -> Tensor
pub fn view(&self, new_shape: Vec<i32>) -> Tensor
Create a view with a new shape (requires contiguous memory)
Behaves like PyTorch view: tensor must be contiguous and the total
number of elements must remain the same. Supports -1 inference for one dimension.
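For example, a minimal sketch of -1 inference (the data and shapes here are illustrative):
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
// The -1 dimension is inferred from the total element count (6 / 2 = 3)
let y = x.view(vec![2, -1]);
assert_eq!(y.shape().dims(), vec![2, 3]);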
§Arguments
new_shape - New shape for the tensor (can contain -1 for inference)
§Returns
A tensor viewing the same data with a new shape
§Examples
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let y = x.view(vec![2, 2]);
assert_eq!(y.shape().dims(), vec![2, 2]);
Examples found in repository
More examples
53 pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54 let attn = self.mha.forward(input, input, input, attn_mask);
55 let res1 = attn.add_tensor(input);
56
57 // Feed-forward network with ReLU and residual
58 let (b, t, e) = Self::triple(input);
59 let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60 let hidden = self.ffn_in.forward(&x2d).relu();
61 let out2d = self.ffn_out.forward(&hidden);
62 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63 out.add_tensor(&res1)
64 }
56 pub fn forward(
57 &self,
58 tgt: &Tensor,
59 memory: &Tensor,
60 causal_mask: Option<&Tensor>,
61 cross_mask: Option<&Tensor>,
62 ) -> Tensor {
63 let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64 let res1 = self_attn.add_tensor(tgt);
65
66 let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67 let res2 = cross.add_tensor(&res1);
68
69 let (b, t, e) = Self::triple(tgt);
70 let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71 let hidden = self.ffn_in.forward(&x2d).relu();
72 let out2d = self.ffn_out.forward(&hidden);
73 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74 out.add_tensor(&res2)
75 }
120fn demonstrate_shape_operations() {
121 println!("\n--- Shape Operations ---");
122
123 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
124 println!(
125 "Original: shape {:?}, data: {:?}",
126 tensor.shape().dims(),
127 tensor.data()
128 );
129
130 // Reshape (view)
131 let reshaped = tensor.view(vec![3, 2]);
132 println!(
133 "Reshaped to [3, 2]: shape {:?}, data: {:?}",
134 reshaped.shape().dims(),
135 reshaped.data()
136 );
137
138 // Create a different shaped tensor for demonstration
139 let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
140 println!(
141 "2D tensor: shape {:?}, data: {:?}",
142 tensor_2d.shape().dims(),
143 tensor_2d.data()
144 );
145
146 // Create a 1D tensor
147 let tensor_1d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
148 println!(
149 "1D tensor: shape {:?}, data: {:?}",
150 tensor_1d.shape().dims(),
151 tensor_1d.data()
152 );
153}
330fn demonstrate_single_vs_batch_inference() {
331 println!("\n--- Single vs Batch Inference ---");
332
333 let layer = LinearLayer::new(4, 3, Some(46));
334
335 // Single inference
336 println!("Single inference:");
337 let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338 let single_output = layer.forward_no_grad(&single_input);
339 println!(" Input shape: {:?}", single_input.shape().dims());
340 println!(" Output shape: {:?}", single_output.shape().dims());
341 println!(" Output: {:?}", single_output.data());
342
343 // Batch inference
344 println!("Batch inference:");
345 let batch_input = Tensor::from_slice(
346 &[
347 1.0, 2.0, 3.0, 4.0, // Sample 1
348 5.0, 6.0, 7.0, 8.0, // Sample 2
349 9.0, 10.0, 11.0, 12.0, // Sample 3
350 ],
351 vec![3, 4],
352 )
353 .unwrap();
354 let batch_output = layer.forward_no_grad(&batch_input);
355 println!(" Input shape: {:?}", batch_input.shape().dims());
356 println!(" Output shape: {:?}", batch_output.shape().dims());
357
358 // Verify batch consistency - first sample should match single inference
359 let _first_batch_sample = batch_output.view(vec![3, 3]); // Reshape to access first sample
360 let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361 let single_sample_data = single_output.data();
362
363 println!("Consistency check:");
364 println!(" Single output: {:?}", single_sample_data);
365 println!(" First batch sample: {:?}", first_sample_data);
366 println!(
367 " Match: {}",
368 single_sample_data
369 .iter()
370 .zip(first_sample_data.iter())
371 .all(|(a, b)| (a - b).abs() < 1e-6)
372 );
373}
Sourcepub fn element_view(&self, index: usize) -> Tensor
pub fn element_view(&self, index: usize) -> Tensor
Create an element view for the specified index
Returns a scalar tensor (shape [1]) that views a single element of the source tensor. Maintains gradient tracking.
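A minimal sketch of gradient flow through element views (assuming gradients scatter back only to the viewed positions, as described for select):
use train_station::Tensor;
let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
// Combine two element views; the graph stays connected to the source tensor
let mut s = t.element_view(0).add_tensor(&t.element_view(2));
s.backward(None);
// Only the viewed positions receive gradient
let grad = t.grad_owned().expect("gradient missing");
assert_eq!(grad.get(&[0]), 1.0);
assert_eq!(grad.get(&[1]), 0.0);
assert_eq!(grad.get(&[2]), 1.0);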
§Arguments
index - Linear index of the element to view
§Returns
A scalar tensor viewing the specified element
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let element = tensor.element_view(1);
assert_eq!(element.value(), 2.0);
Examples found in repository
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
Sourcepub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor
pub fn slice_view(&self, start: usize, step: usize, length: usize) -> Tensor
Create a slice view of the tensor
Returns a view of a contiguous or strided slice of the source tensor.
§Arguments
start - Starting index
step - Step size (1 for contiguous)
length - Number of elements
§Returns
A tensor viewing the specified slice
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
let slice = tensor.slice_view(1, 2, 2); // [2.0, 4.0]
assert_eq!(slice.get(&[0]), 2.0);
assert_eq!(slice.get(&[1]), 4.0);
Examples found in repository
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
More examples
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
Sourcepub fn allocation_owner(&self) -> Option<&Arc<Allocation>>
pub fn allocation_owner(&self) -> Option<&Arc<Allocation>>
Get the allocation owner for this tensor
Returns the shared allocation owner if this tensor is a view, or None if this tensor owns its memory directly.
§Returns
Optional reference to the allocation owner
§Implementation Details
This method is used internally to manage memory lifecycle for tensor views. It helps determine whether a tensor shares memory with another tensor.
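§Example
A minimal sketch (assuming a reshape created with view shares its allocation with the source tensor, as described above):
use train_station::Tensor;
let base = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let reshaped = base.view(vec![2, 2]);
// The owning tensor reports None; the view reports the shared allocation owner
assert!(base.allocation_owner().is_none());
assert!(reshaped.allocation_owner().is_some());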
Sourcepub fn new_uninitialized(shape_dims: Vec<usize>) -> Self
pub fn new_uninitialized(shape_dims: Vec<usize>) -> Self
Create a new tensor with uninitialized memory
This method allocates memory for a tensor without initializing it to any value. This is useful for performance-critical operations where the memory will be immediately overwritten, such as matrix multiplication results.
§Safety
The caller must ensure that all memory is written before reading from the tensor. Reading from uninitialized memory is undefined behavior.
§Arguments
shape_dims - The dimensions of the tensor
§Returns
A tensor with uninitialized memory
§Performance
- Zero Initialization Cost: Skips memory initialization for maximum performance
- SIMD Ready: Properly aligned for vectorized operations
- Memory Efficient: Uses optimized alignment strategies
§Example
use train_station::Tensor;
// Create uninitialized tensor for matmul result
let mut result = Tensor::new_uninitialized(vec![100, 100]);
// Initialize the memory before use
for value in result.data_mut() {
*value = 0.0;
}
Sourcepub fn new_uninitialized_aligned(
shape_dims: Vec<usize>,
alignment_bytes: usize,
) -> Self
pub fn new_uninitialized_aligned( shape_dims: Vec<usize>, alignment_bytes: usize, ) -> Self
Create a new uninitialized tensor with an explicit alignment request (in bytes)
This is intended for internal high-performance paths (e.g., packed GEMM panels) where stronger alignment such as 64 bytes is desired even on AVX2 systems.
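§Example
A minimal sketch mirroring the new_uninitialized example above; the 64-byte alignment request and shape are illustrative:
use train_station::Tensor;
// Request 64-byte (cache-line) alignment for a packed panel buffer
let mut panel = Tensor::new_uninitialized_aligned(vec![64, 64], 64);
// The memory is uninitialized: write every element before reading it
for value in panel.data_mut() {
    *value = 0.0;
}
assert_eq!(panel.size(), 4096);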
Source§impl Tensor
impl Tensor
Sourcepub fn gather(
&self,
dim: usize,
indices: &[usize],
index_shape: &[usize],
) -> Tensor
pub fn gather( &self, dim: usize, indices: &[usize], index_shape: &[usize], ) -> Tensor
Gather values along a dimension using a tensor of indices
This operation extracts elements from the input tensor based on indices provided along a specified dimension. The output tensor has the same shape as the index tensor, with each element taken from the input tensor at the corresponding position with the index value substituted for the specified dimension.
The gather operation is commonly used in machine learning for operations like embedding lookups, attention mechanisms, and advanced indexing patterns.
§Arguments
dim - The dimension along which to gather values (must be < tensor rank)
indices - Flattened indices buffer containing the positions to gather from
index_shape - Shape of the indices tensor and output tensor
§Returns
A new tensor with shape index_shape containing the gathered values
§Examples
§Basic Gather Operation
use train_station::Tensor;
// Create a 2x3 tensor: [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]]
let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap();
// Gather along dimension 1 (columns) with indices [2, 0, 1, 1]
let indices = [2, 0, 1, 1];
let index_shape = [2, 2];
let result = tensor.gather(1, &indices, &index_shape);
// Result shape is [2, 2]
assert_eq!(result.shape().dims(), vec![2, 2]);
// Row 0: indices [2, 0] -> [0.2, 0.0]
assert!((result.get(&[0, 0]) - 0.2).abs() < 1e-6);
assert!((result.get(&[0, 1]) - 0.0).abs() < 1e-6);
// Row 1: indices [1, 1] -> [0.4, 0.4]
assert!((result.get(&[1, 0]) - 0.4).abs() < 1e-6);
assert!((result.get(&[1, 1]) - 0.4).abs() < 1e-6);
§Gather with Gradient Tracking
use train_station::Tensor;
let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
.with_requires_grad();
let indices = [1, 1, 0, 2];
let index_shape = [2, 2];
let mut result = tensor.gather(1, &indices, &index_shape);
// Compute gradients
result.backward(None);
let grad = tensor.grad_owned().expect("gradient missing");
// Verify gradient accumulation for repeated indices
assert!((grad.get(&[0, 1]) - 2.0).abs() < 1e-6); // Index 1 used twice in row 0
§Performance Characteristics
- Time Complexity: O(n) where n is the number of elements in the output
- Memory Usage: Creates a new tensor with the same size as the index tensor
- Optimization: Uses precomputed strides for efficient memory access
- GradTrack Overhead: Minimal overhead when gradient tracking is enabled
§Implementation Details
The gather operation works by:
- Validating input dimensions and index bounds
- Creating an output tensor with the specified index shape
- Iterating through all positions in the output tensor
- Computing source offsets using the input tensor’s strides
- Copying values from the input tensor to the output tensor
- Registering the operation for gradient computation if needed
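The steps above reduce to a plain-Rust reference loop for the 2-D case along dim = 1; this sketch is illustrative only and is not part of the Tensor API:
/// Reference semantics of gather along dim = 1 for a row-major [rows, cols] buffer.
fn gather_dim1_reference(
    input: &[f32],
    cols: usize,
    indices: &[usize],
    out_rows: usize,
    out_cols: usize,
) -> Vec<f32> {
    let mut out = vec![0.0; out_rows * out_cols];
    for r in 0..out_rows {
        for c in 0..out_cols {
            // Substitute the index value for the gathered dimension
            let src_col = indices[r * out_cols + c];
            out[r * out_cols + c] = input[r * cols + src_col];
        }
    }
    out
}
// For the 2x3 example above with indices [2, 0, 1, 1] and index_shape [2, 2]:
// row 0 -> [0.2, 0.0], row 1 -> [0.4, 0.4]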
§Safety
This function performs bounds checking to ensure:
- The specified dimension is within the tensor’s rank
- All indices are within bounds for the specified dimension
- The index shape is compatible with the input tensor shape
- The indices buffer length matches the product of index shape dimensions
§Panics
This function will panic if:
- dim is greater than or equal to the tensor’s rank
- Any index in indices is out of bounds for the specified dimension
- The index_shape rank doesn’t match the input tensor’s rank
- The index_shape dimensions don’t match the input tensor (except along dim)
- The indices length doesn’t equal the product of index_shape dimensions
Examples found in repository
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
More examples
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(¬_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
Source§impl Tensor
impl Tensor
Sourcepub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor
pub fn index_select(&self, dim: usize, indices: &[usize]) -> Tensor
Select elements along a dimension using a list of indices
This operation extracts elements from the input tensor along a specified dimension using the provided indices. The output tensor has the same shape as the input except along the specified dimension, where the size becomes the length of the indices array.
The index_select operation is commonly used for extracting specific rows, columns, or slices from tensors, and is particularly useful in machine learning for operations like embedding lookups and attention mechanisms.
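For example, a minimal embedding-lookup sketch (the table values and token ids below are illustrative):
use train_station::Tensor;
// Toy embedding table: 4 tokens, 3-dimensional embeddings
let table = Tensor::from_slice(
    &[
        0.0, 0.1, 0.2, // token 0
        1.0, 1.1, 1.2, // token 1
        2.0, 2.1, 2.2, // token 2
        3.0, 3.1, 3.2, // token 3
    ],
    vec![4, 3],
).unwrap();
// Select the rows for token ids [2, 0, 3]
let looked_up = table.index_select(0, &[2, 0, 3]);
assert_eq!(looked_up.shape().dims(), vec![3, 3]);
assert_eq!(looked_up.get(&[0, 0]), 2.0); // first selected row is token 2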
§Arguments
dim - The dimension along which to select elements (must be < tensor rank)
indices - Array of indices specifying which elements to select along dim
§Returns
A new tensor with the same shape as the input except along dim, where the
size is indices.len()
§Examples
§Basic Index Selection
use train_station::Tensor;
// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
// Select columns 2 and 0 from dimension 1
let result = tensor.index_select(1, &[2, 0]);
// Result shape is [2, 2] (same as input except dim 1 is now 2)
assert_eq!(result.shape().dims(), vec![2, 2]);
// Row 0: selected columns [2, 0] -> [2.0, 0.0]
assert_eq!(result.get(&[0, 0]), 2.0);
assert_eq!(result.get(&[0, 1]), 0.0);
// Row 1: selected columns [2, 0] -> [5.0, 3.0]
assert_eq!(result.get(&[1, 0]), 5.0);
assert_eq!(result.get(&[1, 1]), 3.0);
§Index Selection with Gradient Tracking
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap()
.with_requires_grad();
// Select specific elements with gradient tracking enabled
let mut result = tensor.index_select(1, &[1, 2]);
result.backward(None);
// Verify gradients are computed correctly
let grad = tensor.grad_owned().expect("gradient missing");
assert_eq!(grad.shape().dims(), vec![2, 3]);
§Selecting Rows from a Matrix
use train_station::Tensor;
// Create a 3x2 matrix
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]).unwrap();
// Select rows 2 and 0 (dimension 0)
let result = tensor.index_select(0, &[2, 0]);
// Result shape is [2, 2]
assert_eq!(result.shape().dims(), vec![2, 2]);
// Selected rows: row 2 [5.0, 6.0], row 0 [1.0, 2.0]
assert_eq!(result.get(&[0, 0]), 5.0); // First row of result (was row 2)
assert_eq!(result.get(&[0, 1]), 6.0);
assert_eq!(result.get(&[1, 0]), 1.0); // Second row of result (was row 0)
assert_eq!(result.get(&[1, 1]), 2.0);
§Performance Characteristics
- Time Complexity: O(n) where n is the number of elements in the output tensor
- Memory Usage: Creates a new tensor with size equal to the output shape
- Optimization: Uses precomputed strides for efficient memory access
- GradTrack Overhead: Minimal overhead when gradient tracking is enabled
- Memory Layout: Output tensor is always contiguous for optimal performance
§Implementation Details
The index_select operation works by:
- Validating the dimension and index bounds
- Computing the output shape (same as input except along dim)
- Creating a new contiguous output tensor
- Iterating through all positions in the output tensor using nested loops:
  - Outer loop: iterate over dimensions before dim
  - Middle loop: iterate over the selected indices
  - Inner loop: iterate over dimensions after dim
- Computing source offsets using the input tensor’s strides
- Copying values from input to output tensor
- Registering the operation for gradient computation if needed
§Safety
This function performs comprehensive bounds checking to ensure:
- The specified dimension is within the tensor’s rank
- All indices are within bounds for the specified dimension
- Memory access is safe through proper offset calculations
§Panics
This function will panic if:
- dim is greater than or equal to the tensor’s rank
- Any index in indices is out of bounds for the specified dimension
§Thread Safety
This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates a new output tensor.
Source§impl Tensor
impl Tensor
Sourcepub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor
pub fn masked_fill(&self, mask: &[bool], value: f32) -> Tensor
Fill masked elements with a specified value
This operation returns a copy of the input tensor where elements are replaced by the specified value wherever the corresponding boolean mask is true. Elements where the mask is false retain their original values from the input tensor.
The masked_fill operation is commonly used in machine learning for operations like masking attention weights, zeroing out specific elements, and implementing dropout-like functionality.
§Arguments
mask - Boolean array with the same length as the number of tensor elements
value - The value to fill masked positions with
§Returns
A new tensor with the same shape as the input, where masked elements are
replaced by value and unmasked elements retain their original values
§Examples
§Basic Masked Fill
use train_station::Tensor;
// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
// Create a mask: [false, true, false, true, false, true]
let mask = [false, true, false, true, false, true];
let result = tensor.masked_fill(&mask, -1.0);
// Result: [[0.0, -1.0, 2.0], [-1.0, 4.0, -1.0]]
assert_eq!(result.shape().dims(), vec![2, 3]);
assert_eq!(result.get(&[0, 0]), 0.0); // Unmasked
assert_eq!(result.get(&[0, 1]), -1.0); // Masked
assert_eq!(result.get(&[0, 2]), 2.0); // Unmasked
assert_eq!(result.get(&[1, 0]), -1.0); // Masked
assert_eq!(result.get(&[1, 1]), 4.0); // Unmasked
assert_eq!(result.get(&[1, 2]), -1.0); // Masked
§Masked Fill with Gradient Tracking
use train_station::Tensor;
let tensor = Tensor::from_slice(&[0.0, 0.1, 0.2, 0.3, 0.4, 0.5], vec![2, 3]).unwrap()
.with_requires_grad();
// Create a mask with some true values
let mask = [false, true, false, true, false, false];
let mut result = tensor.masked_fill(&mask, 5.0);
// Compute gradients
result.backward(None);
let grad = tensor.grad_owned().expect("gradient missing");
// Gradients should be zero where mask is true, 1 elsewhere
assert_eq!(grad.shape().dims(), vec![2, 3]);
assert!((grad.get(&[0, 0]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
assert!((grad.get(&[0, 1]) - 0.0).abs() < 1e-6); // Masked: no gradient
assert!((grad.get(&[0, 2]) - 1.0).abs() < 1e-6); // Unmasked: gradient flows
§Zeroing Out Specific Elements
use train_station::Tensor;
// Create a tensor with some values
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Create a mask to zero out every other element
let mask = [true, false, true, false, true, false];
let result = tensor.masked_fill(&mask, 0.0);
// Result: [[0.0, 2.0, 0.0], [4.0, 0.0, 6.0]]
assert_eq!(result.get(&[0, 0]), 0.0); // Zeroed
assert_eq!(result.get(&[0, 1]), 2.0); // Kept
assert_eq!(result.get(&[0, 2]), 0.0); // Zeroed
assert_eq!(result.get(&[1, 0]), 4.0); // Kept
assert_eq!(result.get(&[1, 1]), 0.0); // Zeroed
assert_eq!(result.get(&[1, 2]), 6.0); // Kept
§Performance Characteristics
- Time Complexity: O(n) where n is the number of elements in the tensor
- Memory Usage: Creates a new tensor with the same size as the input
- Optimization: Uses efficient stride-based iteration for non-contiguous tensors
- GradTrack Overhead: Minimal overhead when gradient tracking is enabled
- Memory Layout: Output tensor is always contiguous for optimal performance
§Implementation Details
The masked_fill operation works by:
- Validating that the mask length equals the number of tensor elements
- Creating a new contiguous output tensor with the same shape
- Iterating through all elements in logical order
- For each element, checking the corresponding mask value:
- If mask is true: use the fill value
- If mask is false: copy the original value from input tensor
- Computing source offsets using the input tensor’s shape for non-contiguous tensors
- Registering the operation for gradient computation if needed
§Safety
This function performs bounds checking to ensure:
- The mask length equals the number of tensor elements
- Memory access is safe through proper offset calculations
- The operation handles both contiguous and non-contiguous tensors correctly
§Panics
This function will panic if:
- The mask length does not equal the number of tensor elements
§Thread Safety
This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates a new output tensor.
§GradTrack Behavior
When gradient tracking is enabled:
- Gradients do not flow through masked positions (they are zeroed)
- Gradients flow normally through unmasked positions
- This behavior is useful for implementing operations like dropout
Examples found in repository
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123}
Source§impl Tensor
impl Tensor
Sourcepub fn select(&self, dim: usize, index: usize) -> Tensor
pub fn select(&self, dim: usize, index: usize) -> Tensor
Select a slice along a given dimension at a specific index
This operation extracts a slice from the input tensor by fixing a specific dimension at a given index. The result is a tensor with one fewer dimension than the input, containing the selected slice.
The select operation returns a view (zero-copy) when the base offset is zero, otherwise it creates a contiguous copy to ensure correctness. This operation is commonly used for extracting specific rows, columns, or slices from tensors.
§Arguments
dim - The dimension along which to select (must be < tensor rank)
index - The index within the specified dimension to select (must be < dim size)
§Returns
A tensor with the selected slice. The result has the same shape as the input except with the specified dimension removed.
§Performance
Returns a view when possible (base offset is zero) to avoid copying. On non-zero offsets, falls back to a contiguous copy for correctness. Gradients propagate back to the selected slice when GradTrack is enabled.
§Examples
§Basic Row Selection
use train_station::Tensor;
// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
// Select row 1 (dimension 0, index 1)
let result = tensor.select(0, 1);
// Result shape is [3] (dimension 0 removed)
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 3.0); // First element of row 1
assert_eq!(result.get(&[1]), 4.0); // Second element of row 1
assert_eq!(result.get(&[2]), 5.0); // Third element of row 1
§Column Selection
use train_station::Tensor;
// Create a 2x3 tensor: [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
// Select column 1 (dimension 1, index 1)
let result = tensor.select(1, 1);
// Result shape is [2] (dimension 1 removed)
assert_eq!(result.shape().dims(), vec![2]);
assert_eq!(result.get(&[0]), 1.0); // Column 1, row 0
assert_eq!(result.get(&[1]), 4.0); // Column 1, row 1
§Select with Gradient Tracking
use train_station::Tensor;
let tensor = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0], vec![2, 2]).unwrap()
.with_requires_grad();
// Select row 1 with gradient tracking enabled
let mut result = tensor.select(0, 1);
result.backward(None);
// Verify gradients are computed correctly
let grad = tensor.grad_owned().expect("gradient missing");
assert_eq!(grad.shape().dims(), vec![2, 2]);
// Only row 1 receives gradients
assert_eq!(grad.get(&[0, 0]), 0.0); // Row 0: no gradient
assert_eq!(grad.get(&[0, 1]), 0.0); // Row 0: no gradient
assert_eq!(grad.get(&[1, 0]), 1.0); // Row 1: gradient flows
assert_eq!(grad.get(&[1, 1]), 1.0); // Row 1: gradient flows
§Performance Characteristics
- Time Complexity: O(n) where n is the number of elements in the selected slice
- Memory Usage: Zero-copy view when base offset is zero, otherwise creates a copy
- Optimization: Uses efficient stride-based access for non-contiguous tensors
- GradTrack Overhead: Minimal overhead when gradient tracking is enabled
- Memory Layout: Result is contiguous when a copy is made, view otherwise
§Implementation Details
The select operation works by:
- Validating the dimension and index bounds
- Computing the new shape by removing the selected dimension
- Computing the new strides by removing the selected dimension’s stride
- Calculating the base offset for the selected slice
- If base offset is zero: creating a view with adjusted shape/strides
- If base offset is non-zero: creating a contiguous copy of the slice
- Registering the operation for gradient computation if needed
§Safety
This function performs comprehensive bounds checking to ensure:
- The tensor has non-zero rank
- The specified dimension is within the tensor’s rank
- The index is within bounds for the specified dimension
- Memory access is safe through proper offset calculations
§Panics
This function will panic if:
- The tensor has zero rank
- dim is greater than or equal to the tensor’s rank
- index is greater than or equal to the size of the specified dimension
§Thread Safety
This function is thread-safe and can be called concurrently on different tensors. The operation does not modify the input tensor and creates either a view or a new tensor.
§View vs Copy Behavior
- View (zero-copy): When the base offset is zero, returns a view that shares the same memory as the input tensor with adjusted shape and strides
- Copy: When the base offset is non-zero, creates a contiguous copy to ensure correctness across all operations
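Assuming the rule above, a minimal illustration using only the public API (the numerical results are identical either way; only the memory behavior differs):
use train_station::Tensor;
let t = Tensor::from_slice(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
// Index 0 along dimension 0 starts at memory offset 0, so this can be a zero-copy view
let first_row = t.select(0, 0);
// Index 1 along dimension 0 starts at offset 3, so a contiguous copy is made
let second_row = t.select(0, 1);
assert_eq!(first_row.get(&[0]), 0.0);
assert_eq!(second_row.get(&[0]), 3.0);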
§GradTrack Behavior
When gradient tracking is enabled:
- Gradients are scattered back to the selected slice in the input tensor
- Other positions in the input tensor receive zero gradients
- This behavior ensures correct gradient flow for the selected elements
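Conceptually, the backward pass scatters the incoming gradient into a zero-filled buffer of the input’s shape. The following is an illustrative, self-contained sketch of that mapping for a contiguous row-major input (scatter_select_grad and row_major_strides are hypothetical helpers, not the crate’s internals):
fn row_major_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
fn scatter_select_grad(grad_out: &[f32], input_shape: &[usize], dim: usize, index: usize) -> Vec<f32> {
    let in_strides = row_major_strides(input_shape);
    let base = index * in_strides[dim];
    // The output layout is the input's with the selected dimension removed
    let mut out_shape = input_shape.to_vec();
    let mut out_strides = in_strides.clone();
    out_shape.remove(dim);
    out_strides.remove(dim);
    let out_decode = row_major_strides(&out_shape);
    let mut grad_in = vec![0.0f32; input_shape.iter().product()];
    for (flat, &g) in grad_out.iter().enumerate() {
        let mut offset = base;
        for k in 0..out_shape.len() {
            let idx_k = (flat / out_decode[k]) % out_shape[k];
            offset += idx_k * out_strides[k];
        }
        grad_in[offset] += g;
    }
    grad_in
}
// Gradient of ones flowing back through select(0, 1) on a [2, 2] input:
let grad_in = scatter_select_grad(&[1.0, 1.0], &[2, 2], 0, 1);
assert_eq!(grad_in, vec![0.0, 0.0, 1.0, 1.0]); // only row 1 receives gradients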
Source§impl Tensor
impl Tensor
Sourcepub fn zeros(shape_dims: Vec<usize>) -> Self
pub fn zeros(shape_dims: Vec<usize>) -> Self
Creates a new tensor filled with zeros
Convenience constructor that creates a tensor and initializes all elements to zero. Uses optimized SIMD operations for efficient zero initialization.
§Arguments
- shape_dims - Vector of dimension sizes defining the tensor shape
§Returns
A new tensor with all elements initialized to zero
§Performance
- Memory Allocation: Single allocation with optimized alignment
- Initialization: SIMD-optimized zero filling for large tensors
- Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
let tensor = Tensor::zeros(vec![2, 3]);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);
// Verify all elements are zero
assert_eq!(tensor.get(&[0, 0]), 0.0);
assert_eq!(tensor.get(&[1, 2]), 0.0);
Examples found in repository?
53 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54 // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55 let scale = (1.0 / input_size as f32).sqrt();
56
57 let weight = Tensor::randn(vec![input_size, output_size], seed)
58 .mul_scalar(scale)
59 .with_requires_grad();
60 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62 Self {
63 weight,
64 bias,
65 input_size,
66 output_size,
67 }
68 }
More examples
42fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 // Create tensors with different initializations
46 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 // Create tensor from slice
61 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 // Create tensor with specific value
70 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 // Create tensor with random data
80 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}
47fn demonstrate_basic_optimizer_setup() {
48 println!("--- Basic Optimizer Setup ---");
49
50 // Create parameters that require gradients
51 let weight = Tensor::randn(vec![3, 2], Some(42)).with_requires_grad();
52 let bias = Tensor::zeros(vec![2]).with_requires_grad();
53
54 println!("Created parameters:");
55 println!(
56 " Weight: shape {:?}, requires_grad: {}",
57 weight.shape().dims(),
58 weight.requires_grad()
59 );
60 println!(
61 " Bias: shape {:?}, requires_grad: {}",
62 bias.shape().dims(),
63 bias.requires_grad()
64 );
65
66 // Create Adam optimizer with default configuration
67 let mut optimizer = Adam::new();
68 println!(
69 "Created Adam optimizer with learning rate: {}",
70 optimizer.learning_rate()
71 );
72
73 // Add parameters to optimizer
74 optimizer.add_parameter(&weight);
75 optimizer.add_parameter(&bias);
76 println!(
77 "Added {} parameters to optimizer",
78 optimizer.parameter_count()
79 );
80
81 // Create optimizer with custom configuration
82 let config = AdamConfig {
83 learning_rate: 0.01,
84 beta1: 0.9,
85 beta2: 0.999,
86 eps: 1e-8,
87 weight_decay: 0.0,
88 amsgrad: false,
89 };
90
91 let mut custom_optimizer = Adam::with_config(config);
92 custom_optimizer.add_parameter(&weight);
93 custom_optimizer.add_parameter(&bias);
94
95 println!(
96 "Created custom optimizer with learning rate: {}",
97 custom_optimizer.learning_rate()
98 );
99
100 // Demonstrate parameter linking
101 println!("Parameter linking completed successfully");
102}
103
104/// Demonstrate simple linear regression training
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106 println!("\n--- Linear Regression Training ---");
107
108 // Create model parameters
109 let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112 // Create optimizer
113 let mut optimizer = Adam::with_learning_rate(0.01);
114 optimizer.add_parameter(&weight);
115 optimizer.add_parameter(&bias);
116
117 // Create simple training data: y = 2*x + 1
118 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121 println!("Training data:");
122 println!(" X: {:?}", x_data.data());
123 println!(" Y: {:?}", y_true.data());
124 println!(" Target: y = 2*x + 1");
125
126 // Training loop
127 let num_epochs = 100;
128 let mut losses = Vec::new();
129
130 for epoch in 0..num_epochs {
131 // Forward pass: y_pred = x * weight + bias
132 let y_pred = x_data.matmul(&weight) + &bias;
133
134 // Compute loss: MSE
135 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137 // Backward pass
138 loss.backward(None);
139
140 // Optimizer step
141 optimizer.step(&mut [&mut weight, &mut bias]);
142 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144 losses.push(loss.value());
145
146 // Print progress every 20 epochs
147 if epoch % 20 == 0 || epoch == num_epochs - 1 {
148 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149 }
150 }
151
152 // Evaluate final model
153 let final_predictions = x_data.matmul(&weight) + &bias;
154 println!("\nFinal model evaluation:");
155 println!(" Learned weight: {:.6}", weight.value());
156 println!(" Learned bias: {:.6}", bias.value());
157 println!(" Predictions vs True:");
158
159 for i in 0..5 {
160 let x1 = x_data.data()[i];
161 let pred = final_predictions.data()[i];
162 let true_val = y_true.data()[i];
163 println!(
164 " x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165 x1,
166 pred,
167 true_val,
168 (pred - true_val).abs()
169 );
170 }
171
172 Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177 println!("\n--- Advanced Training Patterns ---");
178
179 // Create a more complex model
180 let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181 let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183 // Create optimizer with different learning rate
184 let mut optimizer = Adam::with_learning_rate(0.005);
185 optimizer.add_parameter(&weight);
186 optimizer.add_parameter(&bias);
187
188 // Create training data: y = 2*x + [1, 3]
189 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190 let y_true = Tensor::from_slice(
191 &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192 vec![5, 2],
193 )
194 .unwrap();
195
196 println!("Advanced training with monitoring:");
197 println!(" Initial learning rate: {}", optimizer.learning_rate());
198
199 // Training loop with monitoring
200 let num_epochs = 50;
201 let mut losses = Vec::new();
202 let mut weight_norms = Vec::new();
203 let mut gradient_norms = Vec::new();
204
205 for epoch in 0..num_epochs {
206 // Forward pass
207 let y_pred = x_data.matmul(&weight) + &bias;
208 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210 // Backward pass
211 loss.backward(None);
212
213 // Compute gradient norm before optimizer step
214 let gradient_norm = weight.grad_owned().unwrap().norm();
215
216 // Optimizer step
217 optimizer.step(&mut [&mut weight, &mut bias]);
218 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220 // Learning rate scheduling: reduce every 10 epochs
221 if epoch > 0 && epoch % 10 == 0 {
222 let current_lr = optimizer.learning_rate();
223 let new_lr = current_lr * 0.5;
224 optimizer.set_learning_rate(new_lr);
225 println!(
226 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227 epoch, current_lr, new_lr
228 );
229 }
230
231 // Record metrics
232 losses.push(loss.value());
233 weight_norms.push(weight.norm().value());
234 gradient_norms.push(gradient_norm.value());
235
236 // Print detailed progress
237 if epoch % 10 == 0 || epoch == num_epochs - 1 {
238 println!(
239 "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240 epoch,
241 loss.value(),
242 weight.norm().value(),
243 gradient_norm.value()
244 );
245 }
246 }
247
248 println!("Final learning rate: {}", optimizer.learning_rate());
249
250 // Analyze training progression
251 let initial_loss = losses[0];
252 let final_loss = losses[losses.len() - 1];
253 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255 println!("\nTraining Analysis:");
256 println!(" Initial loss: {:.6}", initial_loss);
257 println!(" Final loss: {:.6}", final_loss);
258 println!(" Loss reduction: {:.1}%", loss_reduction);
259 println!(" Final weight norm: {:.6}", weight.norm().value());
260 println!(" Final bias: {:?}", bias.data());
261
262 Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267 println!("\n--- Learning Rate Scheduling ---");
268
269 // Create simple model
270 let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273 // Create optimizer with high initial learning rate
274 let mut optimizer = Adam::with_learning_rate(0.1);
275 optimizer.add_parameter(&weight);
276 optimizer.add_parameter(&bias);
277
278 // Simple data
279 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280 let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282 println!("Initial learning rate: {}", optimizer.learning_rate());
283
284 // Training loop with learning rate scheduling
285 let num_epochs = 50;
286 let mut losses = Vec::new();
287
288 for epoch in 0..num_epochs {
289 // Forward pass
290 let y_pred = x_data.matmul(&weight) + &bias;
291 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293 // Backward pass
294 loss.backward(None);
295
296 // Optimizer step
297 optimizer.step(&mut [&mut weight, &mut bias]);
298 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300 // Learning rate scheduling: reduce every 10 epochs
301 if epoch > 0 && epoch % 10 == 0 {
302 let current_lr = optimizer.learning_rate();
303 let new_lr = current_lr * 0.5;
304 optimizer.set_learning_rate(new_lr);
305 println!(
306 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307 epoch, current_lr, new_lr
308 );
309 }
310
311 losses.push(loss.value());
312
313 // Print progress
314 if epoch % 10 == 0 || epoch == num_epochs - 1 {
315 println!(
316 "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317 epoch,
318 loss.value(),
319 optimizer.learning_rate()
320 );
321 }
322 }
323
324 println!("Final learning rate: {}", optimizer.learning_rate());
325
326 Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331 println!("\n--- Training Monitoring ---");
332
333 // Create model
334 let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337 // Create optimizer
338 let mut optimizer = Adam::with_learning_rate(0.01);
339 optimizer.add_parameter(&weight);
340 optimizer.add_parameter(&bias);
341
342 // Training data
343 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346 // Training loop with comprehensive monitoring
347 let num_epochs = 30;
348 let mut losses = Vec::new();
349 let mut weight_history = Vec::new();
350 let mut bias_history = Vec::new();
351
352 for epoch in 0..num_epochs {
353 // Forward pass
354 let y_pred = x_data.matmul(&weight) + &bias;
355 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357 // Backward pass
358 loss.backward(None);
359
360 // Optimizer step
361 optimizer.step(&mut [&mut weight, &mut bias]);
362 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364 // Record history
365 losses.push(loss.value());
366 weight_history.push(weight.value());
367 bias_history.push(bias.value());
368
369 // Print detailed monitoring
370 if epoch % 5 == 0 || epoch == num_epochs - 1 {
371 println!(
372 "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373 epoch,
374 loss.value(),
375 weight.value(),
376 bias.value()
377 );
378 }
379 }
380
381 // Analyze training progression
382 println!("\nTraining Analysis:");
383 println!(" Initial loss: {:.6}", losses[0]);
384 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
385 println!(
386 " Loss reduction: {:.1}%",
387 (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388 );
389
390 // Compute statistics
391 let loss_mean = compute_mean(&losses);
392 let loss_std = compute_std(&losses);
393 let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394 let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396 println!(" Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397 println!(" Weight change: {:.6}", weight_change);
398 println!(" Bias change: {:.6}", bias_change);
399 println!(" Final weight norm: {:.6}", weight.norm().value());
400 println!(" Final bias: {:.6}", bias.value());
401
402 Ok(())
403}
76 pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77 let (b, _s, e) = Self::triple(src);
78 let mut memory = src.clone();
79 for enc in &self.encoders {
80 memory = enc.forward(&memory, None);
81 }
82
83 let mut out_seq: Vec<Tensor> = Vec::new();
84 // Start token: zeros
85 let mut current = Tensor::zeros(vec![b, 1, e]);
86 for _step in 0..max_steps {
87 // Build causal mask for length t
88 let t = current.shape().dims()[1];
89 let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90 // Upper triangle as false -> masked for all batches and heads
91 for bb in 0..b {
92 for hh in 0..self.num_heads {
93 for i in 0..t {
94 for j in (i + 1)..t {
95 let offset = causal.memory_offset(&[bb, hh, i, j]);
96 let data = causal.data_mut();
97 data[offset] = 0.0;
98 }
99 }
100 }
101 }
102 let mut step_out = current.clone();
103 for dec in &self.decoders {
104 step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105 }
106 // (Toy) append placeholder token; real models would project last token
107 out_seq.push(step_out.clone());
108 // Append a zero token to grow sequence by 1 for next causal computation
109 current = Tensor::zeros(vec![b, t + 1, e]);
110 }
111 // Simple return of final sequence placeholder
112 current
113 }
114
115 /// Non auto-regressive inference: single forward pass
116 pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117 let (b, _s, e) = Self::triple(src);
118 let mut memory = src.clone();
119 for enc in &self.encoders {
120 memory = enc.forward(&memory, None);
121 }
122 let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123 let mut out = tgt.clone();
124 for dec in &self.decoders {
125 out = dec.forward(&out, &memory, None, None);
126 }
127 out
128 }
165fn main() -> Result<(), Box<dyn std::error::Error>> {
166 println!("=== Multi-Head Attention Example ===");
167
168 let batch = 2usize;
169 let src_len = 5usize;
170 let tgt_len = 4usize;
171 let embed = 16usize;
172 let heads = 4usize;
173
174 let query = Tensor::randn(vec![batch, tgt_len, embed], Some(7));
175 let key = Tensor::randn(vec![batch, src_len, embed], Some(8));
176 let value = Tensor::randn(vec![batch, src_len, embed], Some(9));
177
178 let mut mha = MultiHeadAttention::new(embed, heads, Some(42));
179
180 // Simple causal mask for target self-attention shape [b, h, tq, tk]
181 let mut mask = Tensor::zeros(vec![batch, heads, tgt_len, src_len]);
182 // Disallow attending to future positions when tgt_len <= src_len by adding -1e9
183 // Here, just demonstrate mask broadcast/add mechanics with a light mask on last head
184 if src_len >= tgt_len {
185 // set upper triangle to a large negative value for head 0
186 for i in 0..tgt_len {
187 for j in (i + 1)..src_len {
188 let idx = [0usize, 0usize, i, j];
189 // Quick set via data_mut using a slice view
190 let offset = mask.memory_offset(&idx);
191 let data = mask.data_mut();
192 data[offset] = -1e9;
193 }
194 }
195 }
196
197 let out = mha.forward(&query, &key, &value, Some(&mask));
198 println!("Output shape: {:?}", out.shape().dims());
199
200 // Tiny training step to confirm gradients are wired
201 let mut optimizer = Adam::with_learning_rate(0.01);
202 let mut params = mha.parameters();
203 for p in &params {
204 optimizer.add_parameter(p);
205 }
206
207 // Dummy loss = mean of output
208 let mut loss = out.mean();
209 loss.backward(None);
210 optimizer.step(&mut params);
211 optimizer.zero_grad(&mut params);
212
213 println!("Loss: {:.6}", loss.value());
214 println!("=== Done ===");
215 Ok(())
216}
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85 println!("--- Default Adam Configuration ---");
86
87 // Create a simple regression problem: y = 2*x + 1
88 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91 // Create model parameters
92 let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95 // Create Adam optimizer with default configuration
96 let mut optimizer = Adam::new();
97 optimizer.add_parameter(&weight);
98 optimizer.add_parameter(&bias);
99
100 println!("Default Adam configuration:");
101 println!(" Learning rate: {}", optimizer.learning_rate());
102 println!(" Initial weight: {:.6}", weight.value());
103 println!(" Initial bias: {:.6}", bias.value());
104
105 // Training loop
106 let num_epochs = 50;
107 let mut losses = Vec::new();
108
109 for epoch in 0..num_epochs {
110 // Forward pass
111 let y_pred = x_data.matmul(&weight) + &bias;
112 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114 // Backward pass
115 loss.backward(None);
116
117 // Optimizer step
118 optimizer.step(&mut [&mut weight, &mut bias]);
119 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121 losses.push(loss.value());
122
123 if epoch % 10 == 0 || epoch == num_epochs - 1 {
124 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125 }
126 }
127
128 // Evaluate final model
129 let _final_predictions = x_data.matmul(&weight) + &bias;
130 println!("\nFinal model:");
131 println!(" Learned weight: {:.6} (target: 2.0)", weight.value());
132 println!(" Learned bias: {:.6} (target: 1.0)", bias.value());
133 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
134
135 Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140 println!("\n--- Learning Rate Comparison ---");
141
142 let learning_rates = [0.001, 0.01, 0.1];
143 let mut results = Vec::new();
144
145 for &lr in &learning_rates {
146 println!("\nTesting learning rate: {}", lr);
147
148 let stats = train_with_config(TrainingConfig {
149 learning_rate: lr,
150 ..Default::default()
151 })?;
152
153 results.push((lr, stats.clone()));
154
155 println!(" Final loss: {:.6}", stats.final_loss);
156 println!(" Convergence epoch: {}", stats.convergence_epoch);
157 }
158
159 // Compare results
160 println!("\nLearning Rate Comparison Summary:");
161 for (lr, stats) in &results {
162 println!(
163 " LR={:6}: Loss={:.6}, Converged@{}",
164 lr, stats.final_loss, stats.convergence_epoch
165 );
166 }
167
168 Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173 println!("\n--- Weight Decay Comparison ---");
174
175 let weight_decays = [0.0, 0.001, 0.01];
176 let mut results = Vec::new();
177
178 for &wd in &weight_decays {
179 println!("\nTesting weight decay: {}", wd);
180
181 let stats = train_with_config(TrainingConfig {
182 weight_decay: wd,
183 ..Default::default()
184 })?;
185
186 results.push((wd, stats.clone()));
187
188 println!(" Final loss: {:.6}", stats.final_loss);
189 println!(" Final weight norm: {:.6}", stats.weight_norm);
190 }
191
192 // Compare results
193 println!("\nWeight Decay Comparison Summary:");
194 for (wd, stats) in &results {
195 println!(
196 " WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197 wd, stats.final_loss, stats.weight_norm
198 );
199 }
200
201 Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206 println!("\n--- Beta Parameter Tuning ---");
207
208 let beta_configs = [
209 (0.9, 0.999), // Default
210 (0.8, 0.999), // More aggressive momentum
211 (0.95, 0.999), // Less aggressive momentum
212 (0.9, 0.99), // Faster second moment decay
213 ];
214
215 let mut results = Vec::new();
216
217 for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218 println!(
219 "\nTesting beta configuration {}: beta1={}, beta2={}",
220 i + 1,
221 beta1,
222 beta2
223 );
224
225 let config = TrainingConfig {
226 beta1: *beta1,
227 beta2: *beta2,
228 ..Default::default()
229 };
230
231 let stats = train_with_config(config)?;
232 results.push(((*beta1, *beta2), stats.clone()));
233
234 println!(" Final loss: {:.6}", stats.final_loss);
235 println!(" Convergence epoch: {}", stats.convergence_epoch);
236 }
237
238 // Compare results
239 println!("\nBeta Parameter Comparison Summary:");
240 for ((beta1, beta2), stats) in &results {
241 println!(
242 " B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243 beta1, beta2, stats.final_loss, stats.convergence_epoch
244 );
245 }
246
247 Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252 println!("\n--- Configuration Benchmarking ---");
253
254 // Define configurations to benchmark
255 let configs = vec![
256 (
257 "Conservative",
258 TrainingConfig {
259 learning_rate: 0.001,
260 weight_decay: 0.001,
261 beta1: 0.95,
262 ..Default::default()
263 },
264 ),
265 (
266 "Balanced",
267 TrainingConfig {
268 learning_rate: 0.01,
269 weight_decay: 0.0,
270 beta1: 0.9,
271 ..Default::default()
272 },
273 ),
274 (
275 "Aggressive",
276 TrainingConfig {
277 learning_rate: 0.1,
278 weight_decay: 0.0,
279 beta1: 0.8,
280 ..Default::default()
281 },
282 ),
283 ];
284
285 let mut benchmark_results = Vec::new();
286
287 for (name, config) in configs {
288 println!("\nBenchmarking {} configuration:", name);
289
290 let start_time = std::time::Instant::now();
291 let stats = train_with_config(config.clone())?;
292 let elapsed = start_time.elapsed();
293
294 println!(" Training time: {:.2}ms", elapsed.as_millis());
295 println!(" Final loss: {:.6}", stats.final_loss);
296 println!(" Convergence: {} epochs", stats.convergence_epoch);
297
298 benchmark_results.push((name.to_string(), stats, elapsed));
299 }
300
301 // Summary
302 println!("\nBenchmarking Summary:");
303 for (name, stats, elapsed) in &benchmark_results {
304 println!(
305 " {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306 name,
307 stats.final_loss,
308 elapsed.as_millis(),
309 stats.convergence_epoch
310 );
311 }
312
313 Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 // Create training data
319 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 // Create model parameters
323 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 // Create optimizer with custom configuration
327 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 // Training loop
341 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 // Forward pass
346 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 // Backward pass
350 loss.backward(None);
351
352 // Optimizer step
353 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 // Check for convergence (loss < 0.01)
360 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372}
Sourcepub fn ones(shape_dims: Vec<usize>) -> Self
pub fn ones(shape_dims: Vec<usize>) -> Self
Creates a new tensor filled with ones
Convenience constructor that creates a tensor and initializes all elements to one. Uses optimized SIMD operations for efficient initialization.
§Arguments
- shape_dims - Vector of dimension sizes defining the tensor shape
§Returns
A new tensor with all elements initialized to one
§Performance
- Memory Allocation: Single allocation with optimized alignment
- Initialization: SIMD-optimized one filling for large tensors
- Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
let tensor = Tensor::ones(vec![2, 3]);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);
// Verify all elements are one
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 2]), 1.0);
Examples found in repository?
42fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 // Create tensors with different initializations
46 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 // Create tensor from slice
61 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 // Create tensor with specific value
70 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 // Create tensor with random data
80 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}
More examples
76 pub fn infer_autoregressive(&self, src: &Tensor, max_steps: usize) -> Tensor {
77 let (b, _s, e) = Self::triple(src);
78 let mut memory = src.clone();
79 for enc in &self.encoders {
80 memory = enc.forward(&memory, None);
81 }
82
83 let mut out_seq: Vec<Tensor> = Vec::new();
84 // Start token: zeros
85 let mut current = Tensor::zeros(vec![b, 1, e]);
86 for _step in 0..max_steps {
87 // Build causal mask for length t
88 let t = current.shape().dims()[1];
89 let mut causal = Tensor::ones(vec![b, self.num_heads, t, t]);
90 // Upper triangle as false -> masked for all batches and heads
91 for bb in 0..b {
92 for hh in 0..self.num_heads {
93 for i in 0..t {
94 for j in (i + 1)..t {
95 let offset = causal.memory_offset(&[bb, hh, i, j]);
96 let data = causal.data_mut();
97 data[offset] = 0.0;
98 }
99 }
100 }
101 }
102 let mut step_out = current.clone();
103 for dec in &self.decoders {
104 step_out = dec.forward(&step_out, &memory, Some(&causal), None);
105 }
106 // (Toy) append placeholder token; real models would project last token
107 out_seq.push(step_out.clone());
108 // Append a zero token to grow sequence by 1 for next causal computation
109 current = Tensor::zeros(vec![b, t + 1, e]);
110 }
111 // Simple return of final sequence placeholder
112 current
113 }
114
115 /// Non auto-regressive inference: single forward pass
116 pub fn infer_non_autoregressive(&self, src: &Tensor, tgt_len: usize) -> Tensor {
117 let (b, _s, e) = Self::triple(src);
118 let mut memory = src.clone();
119 for enc in &self.encoders {
120 memory = enc.forward(&memory, None);
121 }
122 let tgt = Tensor::zeros(vec![b, tgt_len, e]);
123 let mut out = tgt.clone();
124 for dec in &self.decoders {
125 out = dec.forward(&out, &memory, None, None);
126 }
127 out
128 }
129
130 /// Helper: build boolean-like causal mask [b, heads, t, t] with 1.0 keep, 0.0 masked
131 fn build_causal_mask_static(batch: usize, heads: usize, t: usize) -> Tensor {
132 let mut mask = Tensor::ones(vec![batch, heads, t, t]);
133 for bb in 0..batch {
134 for hh in 0..heads {
135 for i in 0..t {
136 for j in (i + 1)..t {
137 let offset = mask.memory_offset(&[bb, hh, i, j]);
138 let data = mask.data_mut();
139 data[offset] = 0.0;
140 }
141 }
142 }
143 }
144 mask
145 }
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403 println!("=== TD3 Example (YardEnv) ===");
404
405 // Environment / problem dims
406 let state_dim = 3usize;
407 let action_dim = 1usize;
408
409 // Hyperparameters (small for demo)
410 let gamma = 0.99f32;
411 let tau = 0.005f32; // Polyak
412 let policy_noise = 0.2f32; // target smoothing noise stddev
413 let exploration_noise = 0.1f32; // behavior policy noise stddev
414 let policy_delay = 2usize;
415 let batch_size = 64usize;
416 let start_steps = 500usize; // random exploration steps
417 let total_steps = 1500usize;
418 let max_grad_norm = 1.0f32;
419
420 // Models
421 let mut actor = Actor::new(state_dim, action_dim, Some(11));
422 let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423 actor_targ.net.copy_from(&actor.net);
424 actor_targ.set_requires_grad_all(false);
425
426 let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427 let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428 let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429 let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430 critic1_targ.net.copy_from(&critic1.net);
431 critic2_targ.net.copy_from(&critic2.net);
432 critic1_targ.set_requires_grad_all(false);
433 critic2_targ.set_requires_grad_all(false);
434
435 // Optimizers
436 let mut actor_opt = Adam::with_learning_rate(1e-3);
437 for p in actor.parameters() {
438 actor_opt.add_parameter(p);
439 }
440
441 let mut critic_opt = Adam::with_learning_rate(1e-4);
442 for p in critic1.parameters() {
443 critic_opt.add_parameter(p);
444 }
445 for p in critic2.parameters() {
446 critic_opt.add_parameter(p);
447 }
448
449 // Replay buffer and env
450 let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451 let mut env = YardEnv::new(1234);
452 let mut rng = SmallRng::new(987654321);
453
454 // Reset & metric trackers
455 let mut state = env.reset(); // [1, state_dim]
456 let mut episode_return = 0.0f32;
457 let mut episode = 0usize;
458 let mut ema_return: Option<f32> = None;
459 let ema_alpha = 0.05f32; // smooth short-term
460 let mut best_return = f32::NEG_INFINITY;
461 let mut policy_updates: usize = 0;
462
463 for t in 0..total_steps {
464 // Select action
465 let action_tensor = if t < start_steps {
466 let a = rng.uniform(-1.0, 1.0);
467 Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468 } else {
469 // Behavior policy with exploration noise
470 let _ng = NoGradTrack::new();
471 let det = actor.forward(&state);
472 let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473 tanh_bounded(&det.add_tensor(&noise))
474 };
475 let action_value = action_tensor.data()[0];
476
477 // Environment step
478 let (next_state, reward, done) = env.step(action_value);
479 episode_return += reward;
480
481 // Store transition
482 let s_slice = state.data().to_vec();
483 let a_slice = action_tensor.data().to_vec();
484 let s2_slice = next_state.data().to_vec();
485 rb.push(
486 &s_slice,
487 &a_slice,
488 reward,
489 if done { 1.0 } else { 0.0 },
490 &s2_slice,
491 );
492
493 state = if done {
494 let st = env.reset();
495 // Metrics: update EMA and best
496 ema_return = Some(match ema_return {
497 None => episode_return,
498 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499 });
500 if episode_return > best_return {
501 best_return = episode_return;
502 }
503 println!(
504 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505 t,
506 episode,
507 episode_return,
508 ema_return.unwrap_or(episode_return),
509 best_return,
510 rb.size,
511 policy_updates
512 );
513 episode_return = 0.0;
514 episode += 1;
515 st
516 } else {
517 next_state
518 };
519
520 // Training
521 if rb.can_sample(batch_size) {
522 // Sample batch
523 let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525 // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526 let target_q = {
527 let _ng = NoGradTrack::new();
528 // Target actions with smoothing noise (tanh bounds)
529 let noise =
530 Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531 let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532 let q1_t = critic1_targ.forward(&s2, &a_targ);
533 let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535 // Elementwise min via data() since this path is no-grad
536 let q1d = q1_t.data();
537 let q2d = q2_t.data();
538 let mut min_vec = Vec::with_capacity(batch_size);
539 for i in 0..batch_size {
540 let v1 = q1d[i];
541 let v2 = q2d[i];
542 min_vec.push(v1.min(v2));
543 }
544 let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547 };
548
549 // Critic update (both critics)
550 // Zero grads in a short scope, then drop borrows before forward
551 {
552 let mut params = {
553 let c_params = critic1.parameters();
554 let c2_params = critic2.parameters();
555 let mut tmp: Vec<&mut Tensor> = Vec::new();
556 tmp.extend(c_params);
557 tmp.extend(c2_params);
558 tmp
559 };
560 critic_opt.zero_grad(&mut params);
561 }
562
563 // Forward current Q estimates
564 let q1 = critic1.forward(&s, &a);
565 let q2 = critic2.forward(&s, &a);
566 let diff1 = q1.sub_tensor(&target_q);
567 let diff2 = q2.sub_tensor(&target_q);
568 let mut critic_loss = diff1
569 .pow_scalar(2.0)
570 .mean()
571 .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573 // Backward
574 critic_loss.backward(None);
575
576 // Optional gradient clipping + step (only for params that received grads)
577 {
578 let params = {
579 let c_params = critic1.parameters();
580 let c2_params = critic2.parameters();
581 let mut tmp: Vec<&mut Tensor> = Vec::new();
582 tmp.extend(c_params);
583 tmp.extend(c2_params);
584 tmp
585 };
586 let mut with_grads: Vec<&mut Tensor> = Vec::new();
587 for p in params {
588 if p.grad_owned().is_some() {
589 with_grads.push(p);
590 }
591 }
592 if !with_grads.is_empty() {
593 // Pre-step metrics
594 let grad_norm_before = grad_global_norm(&mut with_grads);
595 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596 critic_opt.step(&mut with_grads);
597 critic_opt.zero_grad(&mut with_grads);
598
599 // Post-step metrics (param norm)
600 let mut for_norm_params = {
601 let c_params = critic1.parameters();
602 let c2_params = critic2.parameters();
603 let mut tmp: Vec<&mut Tensor> = Vec::new();
604 tmp.extend(c_params);
605 tmp.extend(c2_params);
606 tmp
607 };
608 let param_norm = params_l2_norm(&mut for_norm_params);
609
610 // Print compact critic metrics occasionally
611 if t % 100 == 0 {
612 let q1_mean = q1.mean().value();
613 let q2_mean = q2.mean().value();
614 let tq_mean = target_q.mean().value();
615 println!(
616 "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617 t,
618 critic_loss.value(),
619 q1_mean,
620 q2_mean,
621 tq_mean,
622 grad_norm_before,
623 param_norm
624 );
625 }
626 }
627 }
628
629 // Delayed policy update
630 if t % policy_delay == 0 {
631 // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632 // Zero actor grads before backward
633 {
634 let mut a_params: Vec<&mut Tensor> = actor.parameters();
635 actor_opt.zero_grad(&mut a_params);
636 }
637
638 let a_pred = actor.forward(&s);
639 let q_for_actor = critic1.forward(&s, &a_pred);
640 let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641 actor_loss.backward(None);
642
643 {
644 let a_params: Vec<&mut Tensor> = actor.parameters();
645 let mut with_grads: Vec<&mut Tensor> = Vec::new();
646 for p in a_params {
647 if p.grad_owned().is_some() {
648 with_grads.push(p);
649 }
650 }
651 if !with_grads.is_empty() {
652 let grad_norm_before = grad_global_norm(&mut with_grads);
653 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654 actor_opt.step(&mut with_grads);
655 actor_opt.zero_grad(&mut with_grads);
656
657 // Post-step param norm
658 let mut for_norm_params = actor.parameters();
659 let param_norm = params_l2_norm(&mut for_norm_params);
660
661 policy_updates += 1;
662 if t % 200 == 0 {
663 println!(
664 "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665 t,
666 actor_loss.value(),
667 grad_norm_before,
668 param_norm,
669 actor_opt.learning_rate(),
670 critic_opt.learning_rate(),
671 policy_updates
672 );
673 }
674 }
675 }
676
677 // Target updates (Polyak averaging, no grad)
678 actor_targ.net.soft_update_from(&actor.net, tau);
679 critic1_targ.net.soft_update_from(&critic1.net, tau);
680 critic2_targ.net.soft_update_from(&critic2.net, tau);
681 }
682
683 // Clear entire graphs to avoid stale accumulation across iterations
684 clear_all_graphs_known();
685 }
686 }
687
688 println!("=== TD3 training finished ===");
689 Ok(())
690}
Sourcepub fn zeros_on_device(shape_dims: Vec<usize>, device: Device) -> Self
pub fn zeros_on_device(shape_dims: Vec<usize>, device: Device) -> Self
Creates a new tensor filled with zeros on a specific device
Convenience constructor that creates a tensor on the specified device and initializes all elements to zero. Uses optimized SIMD operations for efficient zero initialization.
§Arguments
- shape_dims - Vector of dimension sizes defining the tensor shape
- device - The device where the tensor should be allocated
§Returns
A new tensor with all elements initialized to zero
§Performance
- Memory Allocation: Device-specific allocation with optimized alignment
- Initialization: SIMD-optimized zero filling for large tensors
- Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
use train_station::Device;
let tensor = Tensor::zeros_on_device(vec![2, 2], Device::cpu());
assert_eq!(tensor.device(), Device::cpu());
assert_eq!(tensor.size(), 4);
// Verify all elements are zero
assert_eq!(tensor.get(&[0, 0]), 0.0);
assert_eq!(tensor.get(&[1, 1]), 0.0);
Examples found in repository?
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
Sourcepub fn ones_on_device(shape_dims: Vec<usize>, device: Device) -> Self
pub fn ones_on_device(shape_dims: Vec<usize>, device: Device) -> Self
Creates a new tensor filled with ones on a specific device
Convenience constructor that creates a tensor on the specified device and initializes all elements to one. Uses optimized SIMD operations for efficient initialization.
§Arguments
- shape_dims - Vector of dimension sizes defining the tensor shape
- device - The device where the tensor should be allocated
§Returns
A new tensor with all elements initialized to one
§Performance
- Memory Allocation: Device-specific allocation with optimized alignment
- Initialization: SIMD-optimized one filling for large tensors
- Thread Safe: Atomic ID generation for gradtrack tracking
§Examples
use train_station::Tensor;
use train_station::Device;
let tensor = Tensor::ones_on_device(vec![2, 2], Device::cpu());
assert_eq!(tensor.device(), Device::cpu());
assert_eq!(tensor.size(), 4);
// Verify all elements are one
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 1]), 1.0);
Sourcepub fn fill(&mut self, value: f32)
pub fn fill(&mut self, value: f32)
Fills the tensor with a constant value using SIMD optimization
Efficiently initializes all elements of the tensor to the specified value. Uses SIMD operations for large tensors to maximize performance.
§Arguments
- value - The value to fill the tensor with
§Performance
- SIMD Optimization: Uses AVX2 for large tensors when available
- Unrolled Loops: 4x unrolling for better instruction throughput
- Memory Bandwidth: Optimized for maximum memory bandwidth utilization
§Examples
use train_station::Tensor;
let mut tensor = Tensor::new(vec![2, 3]);
tensor.fill(42.0);
// Verify all elements are 42.0
assert_eq!(tensor.get(&[0, 0]), 42.0);
assert_eq!(tensor.get(&[1, 2]), 42.0);
§Zero-Sized Tensor Handling
use train_station::Tensor;
let mut empty_tensor = Tensor::new(vec![0]);
empty_tensor.fill(42.0); // Should not panic
assert_eq!(empty_tensor.size(), 0);
Source§impl Tensor
impl Tensor
Sourcepub fn from_slice(data: &[f32], shape_dims: Vec<usize>) -> Result<Self, String>
pub fn from_slice(data: &[f32], shape_dims: Vec<usize>) -> Result<Self, String>
Creates a tensor from a slice of data
Creates a new tensor with the specified shape and copies data from the provided slice. Validates that the data size matches the tensor shape before performing the copy operation.
This method provides an efficient way to create tensors from existing data sources while ensuring data integrity and proper memory management.
§Arguments
- data - Slice of f32 values to copy into the tensor
- shape_dims - Vector of dimension sizes defining the tensor shape
§Returns
- Ok(Tensor) - Successfully created tensor with copied data
- Err(String) - Error if data size doesn’t match shape
§Performance
- Memory Copy: Efficient non-overlapping copy using SIMD when possible
- Validation: Fast size validation before allocation
- Alignment: Proper memory alignment for optimal performance
- Large Data: Optimized handling of large datasets
§Examples
§Basic Usage
use train_station::Tensor;
let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let tensor = Tensor::from_slice(&data, vec![2, 3]).unwrap();
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.get(&[0, 0]), 1.0);
assert_eq!(tensor.get(&[1, 2]), 6.0);
§Multi-Dimensional Data
use train_station::Tensor;
// 1D tensor
let data_1d = [1.0, 2.0, 3.0];
let tensor_1d = Tensor::from_slice(&data_1d, vec![3]).unwrap();
assert_eq!(tensor_1d.shape().dims(), vec![3]);
assert_eq!(tensor_1d.get(&[1]), 2.0);
// 3D tensor
let data_3d = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let tensor_3d = Tensor::from_slice(&data_3d, vec![2, 2, 2]).unwrap();
assert_eq!(tensor_3d.shape().dims(), vec![2, 2, 2]);
assert_eq!(tensor_3d.get(&[0, 0, 0]), 1.0);
assert_eq!(tensor_3d.get(&[1, 1, 1]), 8.0);
§Error Handling
use train_station::Tensor;
// Size mismatch error
let data = [1.0, 2.0, 3.0];
let result = Tensor::from_slice(&data, vec![2, 2]);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.contains("Data size 3 doesn't match shape size 4"));
§Zero-Sized Tensors
use train_station::Tensor;
// Handle empty tensors gracefully
let data: [f32; 0] = [];
let tensor = Tensor::from_slice(&data, vec![0]).unwrap();
assert_eq!(tensor.size(), 0);
assert_eq!(tensor.shape().dims(), vec![0]);
§Large Data Sets
use train_station::Tensor;
// Efficient handling of large datasets
let size = 1000;
let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![size]).unwrap();
assert_eq!(tensor.size(), size);
assert_eq!(tensor.get(&[0]), 0.0);
assert_eq!(tensor.get(&[100]), 100.0);
assert_eq!(tensor.get(&[999]), 999.0);
§Implementation Details
This method performs the following steps:
- Shape Validation: Creates a Shape object and validates dimensions
- Size Check: Ensures data length matches the calculated tensor size
- Memory Allocation: Allocates tensor memory with proper alignment
- Data Copy: Uses efficient non-overlapping memory copy operation
- Return: Returns the created tensor or descriptive error message
The memory copy operation uses std::ptr::copy_nonoverlapping for
maximum performance and safety, ensuring no data corruption occurs
during the copy process.
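As a hedged illustration of the size check (validate below is a hypothetical helper, not the crate’s API), the rule reduces to comparing the slice length against the product of the requested dimensions:
fn validate(data_len: usize, shape_dims: &[usize]) -> Result<(), String> {
    let expected: usize = shape_dims.iter().product();
    if data_len == expected {
        Ok(())
    } else {
        Err(format!(
            "Data size {} doesn't match shape size {}",
            data_len, expected
        ))
    }
}
assert!(validate(6, &[2, 3]).is_ok());
assert_eq!(
    validate(3, &[2, 2]).unwrap_err(),
    "Data size 3 doesn't match shape size 4"
);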
Examples found in repository?
184 fn state_tensor(&self) -> Tensor {
185 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
186 }
187
188 fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
189 let a = Self::ACTIONS[action_index.min(2)];
190 self.vel += 0.1 * a - 0.01 * self.pos;
191 self.pos += self.vel;
192 self.steps += 1;
193 let reward = -(self.pos * self.pos) - 0.05 * (a * a);
194 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
195 (self.state_tensor(), reward, done)
196 }
197}
198
199// -------------------------------
200// Replay Buffer
201// -------------------------------
202
203struct ReplayBuffer {
204 capacity: usize,
205 size: usize,
206 pos: usize,
207 state_dim: usize,
208 states: Vec<f32>,
209 actions: Vec<usize>,
210 rewards: Vec<f32>,
211 dones: Vec<f32>,
212 next_states: Vec<f32>,
213}
214
215impl ReplayBuffer {
216 fn new(capacity: usize, state_dim: usize) -> Self {
217 Self {
218 capacity,
219 size: 0,
220 pos: 0,
221 state_dim,
222 states: vec![0.0; capacity * state_dim],
223 actions: vec![0usize; capacity],
224 rewards: vec![0.0; capacity],
225 dones: vec![0.0; capacity],
226 next_states: vec![0.0; capacity * state_dim],
227 }
228 }
229
230 fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
231 let i = self.pos;
232 let so = i * self.state_dim;
233 self.states[so..so + self.state_dim].copy_from_slice(s);
234 self.actions[i] = a_idx;
235 self.rewards[i] = r;
236 self.dones[i] = d;
237 self.next_states[so..so + self.state_dim].copy_from_slice(s2);
238 self.pos = (self.pos + 1) % self.capacity;
239 self.size = self.size.saturating_add(1).min(self.capacity);
240 }
241
242 fn can_sample(&self, batch_size: usize) -> bool {
243 self.size >= batch_size
244 }
245
246 fn sample(
247 &self,
248 batch_size: usize,
249 rng: &mut SmallRng,
250 ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
251 let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
252 let mut a_idx = Vec::with_capacity(batch_size);
253 let mut r_vec = Vec::with_capacity(batch_size);
254 let mut d_vec = Vec::with_capacity(batch_size);
255 let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
256 for _ in 0..batch_size {
257 let idx = rng.sample_index(self.size);
258 let so = idx * self.state_dim;
259 s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
260 a_idx.push(self.actions[idx]);
261 r_vec.push(self.rewards[idx]);
262 d_vec.push(self.dones[idx]);
263 s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
264 }
265 let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
266 let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
267 let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
268 let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
269 (s, a_idx, r, d, s2)
270 }
More examples
150 fn state_tensor(&self) -> Tensor {
151 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
152 }
153 fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
154 let a = Self::ACTIONS[action_idx.min(2)];
155 self.vel += 0.1 * a - 0.01 * self.pos;
156 self.pos += self.vel;
157 self.steps += 1;
158 let reward = -(self.pos * self.pos) - 0.05 * (a * a);
159 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
160 (self.state_tensor(), reward, done)
161 }
162}
163
164// -------------------------------
165// Rollout storage
166// -------------------------------
167
168struct RolloutBatch {
169 states: Vec<f32>,
170 actions: Vec<usize>,
171 old_logps: Vec<f32>,
172 rewards: Vec<f32>,
173 dones: Vec<f32>,
174 values: Vec<f32>,
175 next_states: Vec<f32>,
176 _state_dim: usize,
177}
178impl RolloutBatch {
179 fn new(cap: usize, sd: usize) -> Self {
180 Self {
181 states: Vec::with_capacity(cap * sd),
182 actions: Vec::with_capacity(cap),
183 old_logps: Vec::with_capacity(cap),
184 rewards: Vec::with_capacity(cap),
185 dones: Vec::with_capacity(cap),
186 values: Vec::with_capacity(cap),
187 next_states: Vec::with_capacity(cap * sd),
188 _state_dim: sd,
189 }
190 }
191 #[allow(clippy::too_many_arguments)]
192 fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
193 self.states.extend_from_slice(s);
194 self.actions.push(a);
195 self.old_logps.push(lp);
196 self.rewards.push(r);
197 self.dones.push(d);
198 self.values.push(v);
199 self.next_states.extend_from_slice(s2);
200 }
201 fn len(&self) -> usize {
202 self.actions.len()
203 }
204}
205
206// -------------------------------
207// Helpers
208// -------------------------------
209
210#[allow(clippy::too_many_arguments)]
211fn compute_gae(
212 returns_out: &mut [f32],
213 adv_out: &mut [f32],
214 rewards: &[f32],
215 dones: &[f32],
216 values: &[f32],
217 next_values: &[f32],
218 gamma: f32,
219 lam: f32,
220) {
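    // GAE recurrence, evaluated backward in time:
    //   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    //   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    //   R_t     = A_t + V(s_t)   (returns used as the value-function target)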
221 let n = rewards.len();
222 let mut gae = 0.0f32;
223 for t in (0..n).rev() {
224 let not_done = 1.0 - dones[t];
225 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
226 gae = delta + gamma * lam * not_done * gae;
227 adv_out[t] = gae;
228 returns_out[t] = gae + values[t];
229 }
230}
231
232fn normalize_in_place(x: &mut [f32], eps: f32) {
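    // Standardize in place: subtract the mean and divide by sqrt(var + eps);
    // used here to normalize advantages before the PPO update.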
233 let n = x.len() as f32;
234 if n <= 1.0 {
235 return;
236 }
237 let mean = x.iter().copied().sum::<f32>() / n;
238 let var = x
239 .iter()
240 .map(|v| {
241 let d = v - mean;
242 d * d
243 })
244 .sum::<f32>()
245 / n;
246 let std = (var + eps).sqrt();
247 for v in x.iter_mut() {
248 *v = (*v - mean) / std;
249 }
250}
251
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
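    // Numerically stable log-softmax: subtract each row's max before exponentiating,
    //   log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    // then gather the entry for each selected action.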
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
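    // Elementwise identities used below:
    //   max(x, low)  = relu(x - low) + low
    //   min(x, high) = high - relu(x - high)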
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
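            // Inverse-CDF categorical sampling: draw u ~ U(0,1) and pick the first
            // action whose cumulative probability exceeds u.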
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
99 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
100 let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
101 let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
102 .unwrap()
103 .with_requires_grad();
104 Self { net, log_std }
105 }
106 fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107 // Returns (mean [B, A], log_std [A])
108 let mean = self.net.forward(state);
109 (
110 mean,
111 self.log_std
112 .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113 )
114 }
115 fn parameters(&mut self) -> Vec<&mut Tensor> {
116 let mut ps = self.net.parameters();
117 ps.push(&mut self.log_std);
118 ps
119 }
120}
121
122// -------------------------------
123// Critic: value function V(s)
124// -------------------------------
125
126struct Critic {
127 net: Mlp,
128}
129impl Critic {
130 fn new(state_dim: usize, seed: Option<u64>) -> Self {
131 Self {
132 net: Mlp::new(&[state_dim, 64, 64, 1], seed),
133 }
134 }
135 fn forward(&self, state: &Tensor) -> Tensor {
136 self.net.forward(state)
137 }
138 fn parameters(&mut self) -> Vec<&mut Tensor> {
139 self.net.parameters()
140 }
141}
142
143// -------------------------------
144// Continuous YardEnv (same dynamics as TD3 env)
145// -------------------------------
146
147struct YardEnv {
148 pos: f32,
149 vel: f32,
150 steps: usize,
151 max_steps: usize,
152 rng: SmallRng,
153}
154impl YardEnv {
155 fn new(seed: u64) -> Self {
156 let mut e = Self {
157 pos: 0.0,
158 vel: 0.0,
159 steps: 0,
160 max_steps: 200,
161 rng: SmallRng::new(seed),
162 };
163 e.reset();
164 e
165 }
166 fn reset(&mut self) -> Tensor {
167 self.pos = (self.rng.next_f32() * 1.0) - 0.5;
168 self.vel = (self.rng.next_f32() * 0.2) - 0.1;
169 self.steps = 0;
170 self.state_tensor()
171 }
172 fn state_tensor(&self) -> Tensor {
173 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
174 }
175 fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
176 let a = action_value.clamp(-1.0, 1.0);
177 self.vel += 0.1 * a - 0.01 * self.pos;
178 self.pos += self.vel;
179 self.steps += 1;
180 let reward = -(self.pos * self.pos) - 0.1 * (a * a);
181 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
182 (self.state_tensor(), reward, done)
183 }
184}
185
186// -------------------------------
187// Trajectory storage
188// -------------------------------
189
190struct RolloutBatch {
191 states: Vec<f32>,
192 actions: Vec<f32>,
193 log_probs: Vec<f32>,
194 rewards: Vec<f32>,
195 dones: Vec<f32>,
196 values: Vec<f32>,
197 next_states: Vec<f32>,
198 _state_dim: usize,
199}
200impl RolloutBatch {
201 fn new(capacity: usize, state_dim: usize) -> Self {
202 Self {
203 states: Vec::with_capacity(capacity * state_dim),
204 actions: Vec::with_capacity(capacity),
205 log_probs: Vec::with_capacity(capacity),
206 rewards: Vec::with_capacity(capacity),
207 dones: Vec::with_capacity(capacity),
208 values: Vec::with_capacity(capacity),
209 next_states: Vec::with_capacity(capacity * state_dim),
210 _state_dim: state_dim,
211 }
212 }
213
214 #[allow(clippy::too_many_arguments)]
215 fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
216 self.states.extend_from_slice(s);
217 self.actions.push(a);
218 self.log_probs.push(lp);
219 self.rewards.push(r);
220 self.dones.push(d);
221 self.values.push(v);
222 self.next_states.extend_from_slice(s2);
223 }
224
225 fn len(&self) -> usize {
226 self.actions.len()
227 }
228}
229
230// -------------------------------
231// Math helpers
232// -------------------------------
233
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
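    // Per-dimension Gaussian log-density:
    //   log N(a | mean, std^2) = -0.5 * ((a - mean)^2 / var + 2 * log_std + ln(2*pi))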
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar((2.0 * std::f32::consts::PI).ln())
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward; the action is sampled on the CPU so sampling itself does not grow the autograd graph (the stored log_probs are reused during updates)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
244 fn state_tensor(&self) -> Tensor {
245 // Normalize to keep critic inputs bounded:
246 // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247 // - Velocity scaled by 1.0 and clamped to [-1,1]
248 let pos_n = self.pos / 3.0;
249 let vel_n = self.vel.clamp(-1.0, 1.0);
250 Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251 }
252
253 fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254 let a = action_value.clamp(-1.0, 1.0);
255 self.vel += 0.1 * a - 0.01 * self.pos;
256 self.pos += self.vel;
257 self.steps += 1;
258
259 let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261 (self.state_tensor(), reward, done)
262 }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270 capacity: usize,
271 size: usize,
272 pos: usize,
273 state_dim: usize,
274 action_dim: usize,
275 states: Vec<f32>,
276 actions: Vec<f32>,
277 rewards: Vec<f32>,
278 dones: Vec<f32>,
279 next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283 fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284 Self {
285 capacity,
286 size: 0,
287 pos: 0,
288 state_dim,
289 action_dim,
290 states: vec![0.0; capacity * state_dim],
291 actions: vec![0.0; capacity * action_dim],
292 rewards: vec![0.0; capacity],
293 dones: vec![0.0; capacity],
294 next_states: vec![0.0; capacity * state_dim],
295 }
296 }
297
298 fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299 let i = self.pos;
300 let so = i * self.state_dim;
301 let ao = i * self.action_dim;
302 self.states[so..so + self.state_dim].copy_from_slice(s);
303 self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304 self.rewards[i] = r;
305 self.dones[i] = d;
306 self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308 self.pos = (self.pos + 1) % self.capacity;
309 self.size = self.size.saturating_add(1).min(self.capacity);
310 }
311
312 fn can_sample(&self, batch_size: usize) -> bool {
313 self.size >= batch_size
314 }
315
316 fn sample(
317 &self,
318 batch_size: usize,
319 rng: &mut SmallRng,
320 ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321 let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322 let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323 let mut r_vec = Vec::with_capacity(batch_size);
324 let mut d_vec = Vec::with_capacity(batch_size);
325 let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327 for _ in 0..batch_size {
328 let idx = rng.sample_index(self.size);
329 let so = idx * self.state_dim;
330 let ao = idx * self.action_dim;
331 s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332 a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333 r_vec.push(self.rewards[idx]);
334 d_vec.push(self.dones[idx]);
335 s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336 }
337
338 let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339 let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340 let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341 let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342 let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343 (s, a, r, d, s2)
344 }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352 // Compute global L2 norm of all grads
353 let mut total_sq = 0.0f32;
354 for p in parameters.iter() {
355 if let Some(g) = p.grad_owned() {
356 for &v in g.data() {
357 total_sq += v * v;
358 }
359 }
360 }
361 let norm = total_sq.sqrt();
362 if norm > max_norm {
363 let scale = max_norm / (norm + eps);
364 for p in parameters.iter_mut() {
365 if let Some(g) = p.grad_owned() {
366 let scaled = g.mul_scalar(scale);
367 p.set_grad(scaled);
368 }
369 }
370 }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375 let mut total_sq = 0.0f32;
376 for p in parameters.iter_mut() {
377 if let Some(g) = p.grad_owned() {
378 for &v in g.data() {
379 total_sq += v * v;
380 }
381 }
382 }
383 total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388 let _ng = NoGradTrack::new();
389 let mut total_sq = 0.0f32;
390 for p in parameters.iter_mut() {
391 for &v in p.data() {
392 total_sq += v * v;
393 }
394 }
395 total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403 println!("=== TD3 Example (YardEnv) ===");
404
405 // Environment / problem dims
406 let state_dim = 3usize;
407 let action_dim = 1usize;
408
409 // Hyperparameters (small for demo)
410 let gamma = 0.99f32;
411 let tau = 0.005f32; // Polyak
412 let policy_noise = 0.2f32; // target smoothing noise stddev
413 let exploration_noise = 0.1f32; // behavior policy noise stddev
414 let policy_delay = 2usize;
415 let batch_size = 64usize;
416 let start_steps = 500usize; // random exploration steps
417 let total_steps = 1500usize;
418 let max_grad_norm = 1.0f32;
419
420 // Models
421 let mut actor = Actor::new(state_dim, action_dim, Some(11));
422 let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423 actor_targ.net.copy_from(&actor.net);
424 actor_targ.set_requires_grad_all(false);
425
426 let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427 let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428 let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429 let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430 critic1_targ.net.copy_from(&critic1.net);
431 critic2_targ.net.copy_from(&critic2.net);
432 critic1_targ.set_requires_grad_all(false);
433 critic2_targ.set_requires_grad_all(false);
434
435 // Optimizers
436 let mut actor_opt = Adam::with_learning_rate(1e-3);
437 for p in actor.parameters() {
438 actor_opt.add_parameter(p);
439 }
440
441 let mut critic_opt = Adam::with_learning_rate(1e-4);
442 for p in critic1.parameters() {
443 critic_opt.add_parameter(p);
444 }
445 for p in critic2.parameters() {
446 critic_opt.add_parameter(p);
447 }
448
449 // Replay buffer and env
450 let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451 let mut env = YardEnv::new(1234);
452 let mut rng = SmallRng::new(987654321);
453
454 // Reset & metric trackers
455 let mut state = env.reset(); // [1, state_dim]
456 let mut episode_return = 0.0f32;
457 let mut episode = 0usize;
458 let mut ema_return: Option<f32> = None;
459 let ema_alpha = 0.05f32; // smooth short-term
460 let mut best_return = f32::NEG_INFINITY;
461 let mut policy_updates: usize = 0;
462
463 for t in 0..total_steps {
464 // Select action
465 let action_tensor = if t < start_steps {
466 let a = rng.uniform(-1.0, 1.0);
467 Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468 } else {
469 // Behavior policy with exploration noise
470 let _ng = NoGradTrack::new();
471 let det = actor.forward(&state);
472 let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473 tanh_bounded(&det.add_tensor(&noise))
474 };
475 let action_value = action_tensor.data()[0];
476
477 // Environment step
478 let (next_state, reward, done) = env.step(action_value);
479 episode_return += reward;
480
481 // Store transition
482 let s_slice = state.data().to_vec();
483 let a_slice = action_tensor.data().to_vec();
484 let s2_slice = next_state.data().to_vec();
485 rb.push(
486 &s_slice,
487 &a_slice,
488 reward,
489 if done { 1.0 } else { 0.0 },
490 &s2_slice,
491 );
492
493 state = if done {
494 let st = env.reset();
495 // Metrics: update EMA and best
496 ema_return = Some(match ema_return {
497 None => episode_return,
498 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499 });
500 if episode_return > best_return {
501 best_return = episode_return;
502 }
503 println!(
504 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505 t,
506 episode,
507 episode_return,
508 ema_return.unwrap_or(episode_return),
509 best_return,
510 rb.size,
511 policy_updates
512 );
513 episode_return = 0.0;
514 episode += 1;
515 st
516 } else {
517 next_state
518 };
519
520 // Training
521 if rb.can_sample(batch_size) {
522 // Sample batch
523 let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525 // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526 let target_q = {
527 let _ng = NoGradTrack::new();
528 // Target actions with smoothing noise (tanh bounds)
529 let noise =
530 Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531 let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532 let q1_t = critic1_targ.forward(&s2, &a_targ);
533 let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535 // Elementwise min via data() since this path is no-grad
536 let q1d = q1_t.data();
537 let q2d = q2_t.data();
538 let mut min_vec = Vec::with_capacity(batch_size);
539 for i in 0..batch_size {
540 let v1 = q1d[i];
541 let v2 = q2d[i];
542 min_vec.push(v1.min(v2));
543 }
544 let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547 };
548
549 // Critic update (both critics)
550 // Zero grads in a short scope, then drop borrows before forward
551 {
552 let mut params = {
553 let c_params = critic1.parameters();
554 let c2_params = critic2.parameters();
555 let mut tmp: Vec<&mut Tensor> = Vec::new();
556 tmp.extend(c_params);
557 tmp.extend(c2_params);
558 tmp
559 };
560 critic_opt.zero_grad(&mut params);
561 }
562
563 // Forward current Q estimates
564 let q1 = critic1.forward(&s, &a);
565 let q2 = critic2.forward(&s, &a);
566 let diff1 = q1.sub_tensor(&target_q);
567 let diff2 = q2.sub_tensor(&target_q);
568 let mut critic_loss = diff1
569 .pow_scalar(2.0)
570 .mean()
571 .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573 // Backward
574 critic_loss.backward(None);
575
576 // Optional gradient clipping + step (only for params that received grads)
577 {
578 let params = {
579 let c_params = critic1.parameters();
580 let c2_params = critic2.parameters();
581 let mut tmp: Vec<&mut Tensor> = Vec::new();
582 tmp.extend(c_params);
583 tmp.extend(c2_params);
584 tmp
585 };
586 let mut with_grads: Vec<&mut Tensor> = Vec::new();
587 for p in params {
588 if p.grad_owned().is_some() {
589 with_grads.push(p);
590 }
591 }
592 if !with_grads.is_empty() {
593 // Pre-step metrics
594 let grad_norm_before = grad_global_norm(&mut with_grads);
595 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596 critic_opt.step(&mut with_grads);
597 critic_opt.zero_grad(&mut with_grads);
598
599 // Post-step metrics (param norm)
600 let mut for_norm_params = {
601 let c_params = critic1.parameters();
602 let c2_params = critic2.parameters();
603 let mut tmp: Vec<&mut Tensor> = Vec::new();
604 tmp.extend(c_params);
605 tmp.extend(c2_params);
606 tmp
607 };
608 let param_norm = params_l2_norm(&mut for_norm_params);
609
610 // Print compact critic metrics occasionally
611 if t % 100 == 0 {
612 let q1_mean = q1.mean().value();
613 let q2_mean = q2.mean().value();
614 let tq_mean = target_q.mean().value();
615 println!(
616 "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617 t,
618 critic_loss.value(),
619 q1_mean,
620 q2_mean,
621 tq_mean,
622 grad_norm_before,
623 param_norm
624 );
625 }
626 }
627 }
628
629 // Delayed policy update
630 if t % policy_delay == 0 {
631 // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632 // Zero actor grads before backward
633 {
634 let mut a_params: Vec<&mut Tensor> = actor.parameters();
635 actor_opt.zero_grad(&mut a_params);
636 }
637
638 let a_pred = actor.forward(&s);
639 let q_for_actor = critic1.forward(&s, &a_pred);
640 let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641 actor_loss.backward(None);
642
643 {
644 let a_params: Vec<&mut Tensor> = actor.parameters();
645 let mut with_grads: Vec<&mut Tensor> = Vec::new();
646 for p in a_params {
647 if p.grad_owned().is_some() {
648 with_grads.push(p);
649 }
650 }
651 if !with_grads.is_empty() {
652 let grad_norm_before = grad_global_norm(&mut with_grads);
653 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654 actor_opt.step(&mut with_grads);
655 actor_opt.zero_grad(&mut with_grads);
656
657 // Post-step param norm
658 let mut for_norm_params = actor.parameters();
659 let param_norm = params_l2_norm(&mut for_norm_params);
660
661 policy_updates += 1;
662 if t % 200 == 0 {
663 println!(
664 "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665 t,
666 actor_loss.value(),
667 grad_norm_before,
668 param_norm,
669 actor_opt.learning_rate(),
670 critic_opt.learning_rate(),
671 policy_updates
672 );
673 }
674 }
675 }
676
677 // Target updates (Polyak averaging, no grad)
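                // Polyak update rule: theta_target <- (1 - tau) * theta_target + tau * theta_online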
678 actor_targ.net.soft_update_from(&actor.net, tau);
679 critic1_targ.net.soft_update_from(&critic1.net, tau);
680 critic2_targ.net.soft_update_from(&critic2.net, tau);
681 }
682
683 // Clear entire graphs to avoid stale accumulation across iterations
684 clear_all_graphs_known();
685 }
686 }
687
688 println!("=== TD3 training finished ===");
689 Ok(())
690}
46fn demonstrate_basic_operators() {
47 println!("--- Basic Tensor-Tensor Operators ---");
48
49 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
50 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
51
52 println!("Tensor A: {:?}", a.data());
53 println!("Tensor B: {:?}", b.data());
54
55 // Addition
56 let c = &a + &b;
57 println!("A + B: {:?}", c.data());
58
59 // Subtraction
60 let d = &a - &b;
61 println!("A - B: {:?}", d.data());
62
63 // Multiplication
64 let e = &a * &b;
65 println!("A * B: {:?}", e.data());
66
67 // Division
68 let f = &a / &b;
69 println!("A / B: {:?}", f.data());
70}
71
72/// Demonstrate tensor-scalar operators
73fn demonstrate_scalar_operators() {
74 println!("\n--- Tensor-Scalar Operators ---");
75
76 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
77 println!("Original tensor: {:?}", tensor.data());
78
79 // Tensor + scalar
80 let result1 = &tensor + 5.0;
81 println!("Tensor + 5.0: {:?}", result1.data());
82
83 // Scalar + tensor
84 let result2 = 5.0 + &tensor;
85 println!("5.0 + Tensor: {:?}", result2.data());
86
87 // Tensor - scalar
88 let result3 = &tensor - 2.0;
89 println!("Tensor - 2.0: {:?}", result3.data());
90
91 // Tensor * scalar
92 let result4 = &tensor * 3.0;
93 println!("Tensor * 3.0: {:?}", result4.data());
94
95 // Scalar * tensor
96 let result5 = 3.0 * &tensor;
97 println!("3.0 * Tensor: {:?}", result5.data());
98
99 // Tensor / scalar
100 let result6 = &tensor / 2.0;
101 println!("Tensor / 2.0: {:?}", result6.data());
102}
103
104/// Demonstrate assignment operators
105fn demonstrate_operator_assignment() {
106 println!("\n--- Assignment Operators ---");
107
108 let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
109 println!("Original tensor: {:?}", tensor.data());
110
111 // In-place addition
112 tensor += 5.0;
113 println!("After += 5.0: {:?}", tensor.data());
114
115 // In-place subtraction
116 tensor -= 2.0;
117 println!("After -= 2.0: {:?}", tensor.data());
118
119 // In-place multiplication
120 tensor *= 3.0;
121 println!("After *= 3.0: {:?}", tensor.data());
122
123 // In-place division
124 tensor /= 2.0;
125 println!("After /= 2.0: {:?}", tensor.data());
126}
127
128/// Demonstrate operator chaining and complex expressions
129fn demonstrate_operator_chaining() {
130 println!("\n--- Operator Chaining ---");
131
132 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
133 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
134 let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
135
136 println!("Tensor A: {:?}", a.data());
137 println!("Tensor B: {:?}", b.data());
138 println!("Tensor C: {:?}", c.data());
139
140 // Complex expression: (A + B) * C - 5
141 let result = (&a + &b) * &c - 5.0;
142 println!("(A + B) * C - 5: {:?}", result.data());
143
144 // Another complex expression: A * 2 + B / 2
145 let result2 = &a * 2.0 + &b / 2.0;
146 println!("A * 2 + B / 2: {:?}", result2.data());
147
148 // Negation and addition: -A + B * C
149 let result3 = -&a + &b * &c;
150 println!("-A + B * C: {:?}", result3.data());
151
152 // Division with parentheses: (A + B) / (C - 1)
153 let result4 = (&a + &b) / (&c - 1.0);
154 println!("(A + B) / (C - 1): {:?}", result4.data());
155}
156
157/// Demonstrate broadcasting behavior
158fn demonstrate_broadcasting() {
159 println!("\n--- Broadcasting ---");
160
161 // 2D tensor
162 let tensor_2d = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
163 println!(
164 "2D tensor: shape {:?}, data: {:?}",
165 tensor_2d.shape().dims(),
166 tensor_2d.data()
167 );
168
169 // 1D tensor (will be broadcasted)
170 let tensor_1d = Tensor::from_slice(&[10.0, 20.0], vec![2]).unwrap();
171 println!(
172 "1D tensor: shape {:?}, data: {:?}",
173 tensor_1d.shape().dims(),
174 tensor_1d.data()
175 );
176
177 // Broadcasting addition
178 let broadcast_sum = &tensor_2d + &tensor_1d;
179 println!(
180 "Broadcast sum: shape {:?}, data: {:?}",
181 broadcast_sum.shape().dims(),
182 broadcast_sum.data()
183 );
184
185 // Broadcasting multiplication
186 let broadcast_mul = &tensor_2d * &tensor_1d;
187 println!(
188 "Broadcast multiplication: shape {:?}, data: {:?}",
189 broadcast_mul.shape().dims(),
190 broadcast_mul.data()
191 );
192
193 // Broadcasting with scalar
194 let broadcast_scalar = &tensor_2d + 100.0;
195 println!(
196 "Broadcast scalar: shape {:?}, data: {:?}",
197 broadcast_scalar.shape().dims(),
198 broadcast_scalar.data()
199 );
200}
201
202/// Demonstrate equivalence between operators and method calls
203fn demonstrate_method_equivalence() {
204 println!("\n--- Operator vs Method Call Equivalence ---");
205
206 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209 // Addition: operator vs method
210 let operator_result = &a + &b;
211 let method_result = a.add_tensor(&b);
212
213 println!("A + B (operator): {:?}", operator_result.data());
214 println!("A.add_tensor(B): {:?}", method_result.data());
215 println!(
216 "Results are equal: {}",
217 operator_result.data() == method_result.data()
218 );
219
220 // Multiplication: operator vs method
221 let operator_result = &a * &b;
222 let method_result = a.mul_tensor(&b);
223
224 println!("A * B (operator): {:?}", operator_result.data());
225 println!("A.mul_tensor(B): {:?}", method_result.data());
226 println!(
227 "Results are equal: {}",
228 operator_result.data() == method_result.data()
229 );
230
231 // Scalar addition: operator vs method
232 let operator_result = &a + 5.0;
233 let method_result = a.add_scalar(5.0);
234
235 println!("A + 5.0 (operator): {:?}", operator_result.data());
236 println!("A.add_scalar(5.0): {:?}", method_result.data());
237 println!(
238 "Results are equal: {}",
239 operator_result.data() == method_result.data()
240 );
241}
169fn demonstrate_forward_pass() {
170 println!("\n--- Forward Pass (with gradients) ---");
171
172 let layer = LinearLayer::new(3, 2, Some(43));
173
174 // Single input
175 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
176 let output = layer.forward(&input);
177
178 println!("Single input:");
179 println!(" Input: {:?}", input.data());
180 println!(" Output: {:?}", output.data());
181 println!(" Output requires grad: {}", output.requires_grad());
182
183 // Batch input
184 let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
185 let batch_output = layer.forward(&batch_input);
186
187 println!("Batch input:");
188 println!(" Input shape: {:?}", batch_input.shape().dims());
189 println!(" Output shape: {:?}", batch_output.shape().dims());
190 println!(" Output requires grad: {}", batch_output.requires_grad());
191}
192
193/// Demonstrate forward pass without gradient tracking
194fn demonstrate_forward_pass_no_grad() {
195 println!("\n--- Forward Pass (no gradients) ---");
196
197 let layer = LinearLayer::new(3, 2, Some(44));
198
199 // Single input
200 let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
201 let output = layer.forward_no_grad(&input);
202
203 println!("Single input (no grad):");
204 println!(" Input: {:?}", input.data());
205 println!(" Output: {:?}", output.data());
206 println!(" Output requires grad: {}", output.requires_grad());
207
208 // Compare with grad version
209 let output_with_grad = layer.forward(&input);
210 println!("Comparison:");
211 println!(
212 " Same values: {}",
213 output.data() == output_with_grad.data()
214 );
215 println!(" No grad requires grad: {}", output.requires_grad());
216 println!(
217 " With grad requires grad: {}",
218 output_with_grad.requires_grad()
219 );
220}
221
222/// Demonstrate complete training loop
223fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
224 println!("\n--- Training Loop ---");
225
226 // Create layer and training data
227 let mut layer = LinearLayer::new(2, 1, Some(45));
228
229 // Simple regression task: y = 2*x1 + 3*x2 + 1
230 let x_data = Tensor::from_slice(
231 &[
232 1.0, 1.0, // x1=1, x2=1 -> y=6
233 2.0, 1.0, // x1=2, x2=1 -> y=8
234 1.0, 2.0, // x1=1, x2=2 -> y=9
235 2.0, 2.0, // x1=2, x2=2 -> y=11
236 ],
237 vec![4, 2],
238 )
239 .unwrap();
240
241 let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();
242
243 println!("Training data:");
244 println!(" X shape: {:?}", x_data.shape().dims());
245 println!(" Y shape: {:?}", y_true.shape().dims());
246 println!(" Target function: y = 2*x1 + 3*x2 + 1");
247
248 // Create optimizer
249 let config = AdamConfig {
250 learning_rate: 0.01,
251 beta1: 0.9,
252 beta2: 0.999,
253 eps: 1e-8,
254 weight_decay: 0.0,
255 amsgrad: false,
256 };
257
258 let mut optimizer = Adam::with_config(config);
259 let params = layer.parameters();
260 for param in &params {
261 optimizer.add_parameter(param);
262 }
263
264 println!("Optimizer setup complete. Starting training...");
265
266 // Training loop
267 let num_epochs = 100;
268 let mut losses = Vec::new();
269
270 for epoch in 0..num_epochs {
271 // Forward pass
272 let y_pred = layer.forward(&x_data);
273
274 // Compute loss: MSE
275 let diff = y_pred.sub_tensor(&y_true);
276 let mut loss = diff.pow_scalar(2.0).mean();
277
278 // Backward pass
279 loss.backward(None);
280
281 // Optimizer step
282 let mut params = layer.parameters();
283 optimizer.step(&mut params);
284 optimizer.zero_grad(&mut params);
285
286 losses.push(loss.value());
287
288 // Print progress
289 if epoch % 20 == 0 || epoch == num_epochs - 1 {
290 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
291 }
292 }
293
294 // Evaluate final model
295 let final_predictions = layer.forward_no_grad(&x_data);
296
297 println!("\nFinal model evaluation:");
298 println!(" Learned weights: {:?}", layer.weight.data());
299 println!(" Learned bias: {:?}", layer.bias.data());
300 println!(" Target weights: [2.0, 3.0]");
301 println!(" Target bias: [1.0]");
302
303 println!(" Predictions vs True:");
304 for i in 0..4 {
305 let pred = final_predictions.data()[i];
306 let true_val = y_true.data()[i];
307 println!(
308 " Sample {}: pred={:.3}, true={:.1}, error={:.3}",
309 i + 1,
310 pred,
311 true_val,
312 (pred - true_val).abs()
313 );
314 }
315
316 // Training analysis
317 let initial_loss = losses[0];
318 let final_loss = losses[losses.len() - 1];
319 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
320
321 println!("\nTraining Analysis:");
322 println!(" Initial loss: {:.6}", initial_loss);
323 println!(" Final loss: {:.6}", final_loss);
324 println!(" Loss reduction: {:.1}%", loss_reduction);
325
326 Ok(())
327}
328
329/// Demonstrate single vs batch inference
330fn demonstrate_single_vs_batch_inference() {
331 println!("\n--- Single vs Batch Inference ---");
332
333 let layer = LinearLayer::new(4, 3, Some(46));
334
335 // Single inference
336 println!("Single inference:");
337 let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
338 let single_output = layer.forward_no_grad(&single_input);
339 println!(" Input shape: {:?}", single_input.shape().dims());
340 println!(" Output shape: {:?}", single_output.shape().dims());
341 println!(" Output: {:?}", single_output.data());
342
343 // Batch inference
344 println!("Batch inference:");
345 let batch_input = Tensor::from_slice(
346 &[
347 1.0, 2.0, 3.0, 4.0, // Sample 1
348 5.0, 6.0, 7.0, 8.0, // Sample 2
349 9.0, 10.0, 11.0, 12.0, // Sample 3
350 ],
351 vec![3, 4],
352 )
353 .unwrap();
354 let batch_output = layer.forward_no_grad(&batch_input);
355 println!(" Input shape: {:?}", batch_input.shape().dims());
356 println!(" Output shape: {:?}", batch_output.shape().dims());
357
358 // Verify batch consistency - first sample should match single inference
359 let _first_batch_sample = batch_output.view(vec![3, 3]); // Reshape to access first sample
360 let first_sample_data = &batch_output.data()[0..3]; // First 3 elements
361 let single_sample_data = single_output.data();
362
363 println!("Consistency check:");
364 println!(" Single output: {:?}", single_sample_data);
365 println!(" First batch sample: {:?}", first_sample_data);
366 println!(
367 " Match: {}",
368 single_sample_data
369 .iter()
370 .zip(first_sample_data.iter())
371 .all(|(a, b)| (a - b).abs() < 1e-6)
372 );
373}
374
375/// Demonstrate serialization and loading
376fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
377 println!("\n--- Serialization ---");
378
379 // Create and train a simple layer
380 let mut original_layer = LinearLayer::new(2, 1, Some(47));
381
382 // Simple training data
383 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
384 let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();
385
386 let mut optimizer = Adam::with_learning_rate(0.01);
387 let params = original_layer.parameters();
388 for param in &params {
389 optimizer.add_parameter(param);
390 }
391
392 // Train for a few epochs
393 for _ in 0..10 {
394 let y_pred = original_layer.forward(&x_data);
395 let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
396 loss.backward(None);
397
398 let mut params = original_layer.parameters();
399 optimizer.step(&mut params);
400 optimizer.zero_grad(&mut params);
401 }
402
403 println!("Original layer trained");
404 println!(" Weight: {:?}", original_layer.weight.data());
405 println!(" Bias: {:?}", original_layer.bias.data());
406
407 // Save layer
408 original_layer.save_json("temp_linear_layer")?;
409
410 // Load layer
411 let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;
412
413 println!("Loaded layer");
414 println!(" Weight: {:?}", loaded_layer.weight.data());
415 println!(" Bias: {:?}", loaded_layer.bias.data());
416
417 // Verify consistency
418 let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
419 let original_output = original_layer.forward_no_grad(&test_input);
420 let loaded_output = loaded_layer.forward_no_grad(&test_input);
421
422 println!("Consistency check:");
423 println!(" Original output: {:?}", original_output.data());
424 println!(" Loaded output: {:?}", loaded_output.data());
425 println!(
426 " Match: {}",
427 original_output
428 .data()
429 .iter()
430 .zip(loaded_output.data().iter())
431 .all(|(a, b)| (a - b).abs() < 1e-6)
432 );
433
434 println!("Serialization verification: PASSED");
435
436 Ok(())
437}
- examples/iterators/element_iteration.rs
- examples/getting_started/tensor_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/getting_started/serialization_basics.rs
- examples/optimizers/adam_configurations.rs
- examples/optimizers/learning_rate_scheduling.rs
- examples/getting_started/optimizer_basics.rs
- examples/iterators/advanced_patterns.rs
- examples/iterators/performance_optimization.rs
- examples/supervised_training/supervised_bce.rs
- examples/supervised_training/supervised_classification.rs
- examples/supervised_training/supervised_regression.rs
Source§impl Tensor
pub fn randn(shape_dims: Vec<usize>, seed: Option<u64>) -> Self
Creates a tensor with normally distributed random values (mean=0, std=1)
Similar to PyTorch’s torch.randn(), creates a tensor filled with random
values drawn from a standard normal distribution (mean=0, standard deviation=1).
Uses Box-Muller transform for efficient normal distribution generation.
This method provides high-quality random number generation with optional reproducibility through seed-based generation. The generated values follow a standard normal distribution suitable for machine learning applications.
§Arguments
- shape_dims - Vector of dimension sizes defining the tensor shape
- seed - Optional seed for reproducible random generation
§Returns
A new tensor with normally distributed random values
§Performance
- Box-Muller Transform: Efficient normal distribution generation
- SIMD Optimization: Vectorized operations for large tensors
- Memory Efficient: Single-pass generation with optimized allocation
- Thread Safe: Uses thread-local random state
§Examples
§Basic Usage
use train_station::Tensor;
// Create a 2x3 tensor with random normal values
let tensor = Tensor::randn(vec![2, 3], None);
assert_eq!(tensor.size(), 6);
assert_eq!(tensor.shape().dims(), vec![2, 3]);
// Verify random values are generated
let first_value = tensor.get(&[0, 0]);
assert!(first_value != 0.0); // Should be random
§Reproducible Generation
use train_station::Tensor;
// Create with fixed seed for reproducible results
let tensor1 = Tensor::randn(vec![100], Some(42));
let tensor2 = Tensor::randn(vec![100], Some(42));
// tensor1 and tensor2 will have identical values
for i in 0..tensor1.size() {
assert!((tensor1.get(&[i]) - tensor2.get(&[i])).abs() < 1e-6);
}
§Statistical Properties
use train_station::Tensor;
// Generate large tensor for statistical analysis
let tensor = Tensor::randn(vec![1000], Some(42));
assert_eq!(tensor.size(), 1000);
// Check that values are reasonable (within 4 standard deviations)
let mut min_val = f32::INFINITY;
let mut max_val = f32::NEG_INFINITY;
let mut sum = 0.0;
for i in 0..tensor.size() {
let val = tensor.get(&[i]);
min_val = min_val.min(val);
max_val = max_val.max(val);
sum += val;
}
let mean = sum / tensor.size() as f32;
// Mean should be close to 0, values should be within reasonable bounds
assert!(mean.abs() < 0.1, "Mean should be close to 0, got {}", mean);
assert!(min_val > -4.0, "Values should not be too negative, min: {}", min_val);
assert!(max_val < 4.0, "Values should not be too positive, max: {}", max_val);
§Zero-Sized Tensors
use train_station::Tensor;
// Handle empty tensors gracefully
let tensor = Tensor::randn(vec![0], Some(42));
assert_eq!(tensor.size(), 0);
assert_eq!(tensor.shape().dims(), vec![0]);
§Implementation Details
This method uses the Box-Muller transform to generate normally distributed random variables from uniform random variables. The process involves:
- Random Number Generation: Uses Xorshift algorithm for uniform random numbers
- Box-Muller Transform: Converts uniform random variables to normal distribution
- SIMD Optimization: Vectorized operations for large tensors when available
- Numerical Stability: Robust handling of edge cases and potential NaN values
The Box-Muller transform ensures that the generated values follow a true normal distribution with mean=0 and standard deviation=1, making it suitable for machine learning applications requiring normally distributed random values.
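As a rough illustration of the transform (a standalone sketch, not the crate's internal implementation), two independent uniform samples u1, u2 in (0, 1] map to one standard-normal sample via z = sqrt(-2 ln u1) * cos(2 * pi * u2):
use std::f32::consts::PI;

/// Illustrative Box-Muller transform: maps two uniforms in (0, 1] to a
/// standard-normal sample (mean 0, std 1). Not the crate's internal code.
fn box_muller(u1: f32, u2: f32) -> f32 {
    (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
}

fn main() {
    let z = box_muller(0.37, 0.81);
    println!("sample: {}", z);
}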
Examples found in repository?
53 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54 // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55 let scale = (1.0 / input_size as f32).sqrt();
56
57 let weight = Tensor::randn(vec![input_size, output_size], seed)
58 .mul_scalar(scale)
59 .with_requires_grad();
60 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62 Self {
63 weight,
64 bias,
65 input_size,
66 output_size,
67 }
68 }
More examples
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in &params {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
84fn main() -> Result<(), Box<dyn std::error::Error>> {
85 println!("=== Basic Decoder Example ===");
86
87 let batch = 2usize;
88 let src = 7usize;
89 let tgt = 5usize;
90 let embed = 32usize;
91 let heads = 4usize;
92
93 let memory = Tensor::randn(vec![batch, src, embed], Some(21));
94 let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));
95
96 let mut dec = DecoderBlock::new(embed, heads, Some(456));
97 let out = dec.forward(&tgt_in, &memory, None, None);
98 println!("Output shape: {:?}", out.shape().dims());
99
100 let mut opt = Adam::with_learning_rate(0.01);
101 let mut params = dec.parameters();
102 for p in &params {
103 opt.add_parameter(p);
104 }
105 let mut loss = out.mean();
106 loss.backward(None);
107 opt.step(&mut params);
108 opt.zero_grad(&mut params);
109 println!("Loss: {:.6}", loss.value());
110 println!("=== Done ===");
111 Ok(())
112}
42fn demonstrate_tensor_creation() {
43 println!("--- Tensor Creation ---");
44
45 // Create tensors with different initializations
46 let zeros = Tensor::zeros(vec![2, 3]);
47 println!(
48 "Zeros tensor: shape {:?}, data: {:?}",
49 zeros.shape().dims(),
50 zeros.data()
51 );
52
53 let ones = Tensor::ones(vec![3, 2]);
54 println!(
55 "Ones tensor: shape {:?}, data: {:?}",
56 ones.shape().dims(),
57 ones.data()
58 );
59
60 // Create tensor from slice
61 let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
62 let from_slice = Tensor::from_slice(&data, vec![2, 3]).unwrap();
63 println!(
64 "From slice: shape {:?}, data: {:?}",
65 from_slice.shape().dims(),
66 from_slice.data()
67 );
68
69 // Create tensor with specific value
70 let mut filled = Tensor::new(vec![2, 2]);
71 {
72 let data = filled.data_mut();
73 for value in data.iter_mut() {
74 *value = 42.0;
75 }
76 }
77 println!("Filled with 42: {:?}", filled.data());
78
79 // Create tensor with random data
80 let random = Tensor::randn(vec![2, 2], Some(42));
81 println!(
82 "Random tensor: shape {:?}, data: {:?}",
83 random.shape().dims(),
84 random.data()
85 );
86}
363fn demonstrate_configurable_architectures() {
364 println!("\n--- Configurable Architectures ---");
365
366 let architectures = vec![
367 ("Shallow", vec![8]),
368 ("Medium", vec![16, 8]),
369 ("Deep", vec![32, 16, 8, 4]),
370 ("Wide", vec![64, 32]),
371 ("Bottleneck", vec![16, 4, 16]),
372 ];
373
374 for (name, hidden_sizes) in architectures {
375 let config = FeedForwardConfig {
376 input_size: 10,
377 hidden_sizes,
378 output_size: 3,
379 use_bias: true,
380 };
381
382 let network = FeedForwardNetwork::new(config.clone(), Some(44));
383
384 // Test forward pass
385 let test_input = Tensor::randn(vec![5, 10], Some(45)); // Batch of 5
386 let output = network.forward_no_grad(&test_input);
387
388 println!("{} network:", name);
389 println!(" Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
390 println!(" Parameters: {}", network.parameter_count());
391 println!(" Test output shape: {:?}", output.shape().dims());
392 println!(
393 " Output range: [{:.3}, {:.3}]",
394 output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
395 output
396 .data()
397 .iter()
398 .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
399 );
400 }
401}
231fn main() -> Result<(), Box<dyn std::error::Error>> {
232 println!("=== Basic Transformer Example ===");
233
234 let batch = 2usize;
235 let src_len = 8usize;
236 let tgt_len = 6usize;
237 let embed = 32usize;
238 let heads = 4usize;
239 let layers = 2usize;
240
241 let src = Tensor::randn(vec![batch, src_len, embed], Some(1001));
242 let tgt = Tensor::randn(vec![batch, tgt_len, embed], Some(1002));
243
244 let mut trf = BasicTransformer::new(embed, heads, layers, Some(999));
245 let out = trf.forward(&src, &tgt);
246 println!("Output shape: {:?}", out.shape().dims());
247
248 // Quick optimization step
249 let mut opt = Adam::with_learning_rate(0.005);
250 let mut params = trf.parameters();
251 for p in &params {
252 opt.add_parameter(p);
253 }
254 let mut loss = out.mean();
255 loss.backward(None);
256 opt.step(&mut params);
257 opt.zero_grad(&mut params);
258 println!("Loss: {:.6}", loss.value());
259
260 // Demo: non auto-regressive inference (single pass)
261 let nar = trf.infer_non_autoregressive(&src, tgt_len);
262 println!("NAR output shape: {:?}", nar.shape().dims());
263
264 // Demo: auto-regressive inference (toy)
265 let ar = trf.infer_autoregressive(&src, 3);
266 println!("AR output shape: {:?}", ar.shape().dims());
267
268 // NAR training demo
269 let nar_tgt = tgt.clone();
270 trf.train_non_autoregressive_steps(&src, &nar_tgt, 3, 0.01);
271
272 // AR training demo (teacher-forced)
273 let ar_tgt = tgt.clone();
274 trf.train_autoregressive_steps(&src, &ar_tgt, 3, 0.01);
275 println!("=== Done ===");
276 Ok(())
277}
Sourcepub fn fill_randn(&mut self, seed: Option<u64>)
pub fn fill_randn(&mut self, seed: Option<u64>)
Fills the tensor with normally distributed random values
Internal method that fills an existing tensor with random values from a standard normal distribution. Uses Box-Muller transform for efficiency and provides SIMD optimization for large tensors.
This method is used internally by randn() and provides the core
random number generation functionality with optimized performance
characteristics.
§Arguments
seed- Optional seed for reproducible random generation
§Performance
- Box-Muller Transform: Generates pairs of normal random variables
- SIMD Optimization: Vectorized operations when possible
- Memory Efficient: Single-pass generation
- Unrolled Loops: 4x unrolling for better instruction throughput
§Implementation Details
The method performs the following steps:
- Zero-sized Check: Returns early for empty tensors
- RNG Initialization: Creates Xorshift RNG with seed or system time
- SIMD Detection: Checks for AVX2 availability for optimized path
- Generation: Uses SIMD or scalar path based on hardware support
- Completion: Fills all tensor elements with normal random values
The method automatically handles hardware capabilities and falls back to scalar operations when SIMD is not available, ensuring compatibility across different CPU architectures.
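§Examples
A minimal usage sketch (illustrative, not taken from the crate's doctests), showing in-place re-filling of an existing buffer; the reproducibility claim assumes the seeded Xorshift generation described above:
use train_station::Tensor;
let mut t = Tensor::zeros(vec![4, 4]);
t.fill_randn(Some(7));
let first = t.get(&[0, 0]);
// Re-filling with the same seed reproduces the same values
t.fill_randn(Some(7));
assert_eq!(t.get(&[0, 0]), first);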
Source§impl Tensor
impl Tensor
Sourcepub fn chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_>
pub fn chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_>
Standard slice-like chunks iterator. Use this instead of iter_chunks.
Iterates over contiguous or view-backed slices of the tensor with the specified chunk size. In no-grad fast mode, a single contiguous owner may be materialized to optimize subsequent views.
§Arguments
chunk_size- Number of elements per chunk (must be > 0)
§Examples
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let y = t.chunks(2).map(|c| c.mul_scalar(2.0)).collect_shape(vec![6]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
Examples found in repository?
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163 println!("\n--- Memory Optimization ---");
164
165 // Create a large tensor for memory testing
166 let size = 10000;
167 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168 let tensor = Tensor::from_slice(&data, vec![size])?;
169
170 println!("Processing tensor of size: {}", size);
171
172 // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173 println!("\nPattern 1: Streaming Processing");
174 let chunk_size = 1000;
175 let start = Instant::now();
176 let flattened = tensor.view(vec![size as i32]);
177 let _streamed_result: Tensor = flattened
178 .chunks(chunk_size)
179 .map(|c| c.pow_scalar(2.0).sqrt())
180 .collect_shape(vec![size]);
181 let streamed_time = start.elapsed();
182
183 // Pattern 2: Full processing
184 let start = Instant::now();
185 let _full_result: Tensor = tensor
186 .iter_elements()
187 .map(|elem| elem.pow_scalar(2.0).sqrt())
188 .collect_shape(vec![size]);
189 let full_time = start.elapsed();
190
191 println!(" Streaming time: {:?}", streamed_time);
192 println!(" Full processing time: {:?}", full_time);
193 println!(
194 " Memory efficiency ratio: {:.2}x",
195 full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196 );
197
198 // Pattern 3: Lazy evaluation with take
199 println!("\nPattern 2: Lazy Evaluation");
200 let start = Instant::now();
201 let lazy_result: Tensor = tensor
202 .iter_elements()
203 .take(1000) // Only process first 1000 elements
204 .map(|elem| elem.pow_scalar(2.0).sqrt())
205 .collect_shape(vec![1000]);
206 let lazy_time = start.elapsed();
207
208 println!(" Lazy processing (1000 elements): {:?}", lazy_time);
209 println!(" Lazy result size: {}", lazy_result.size());
210
211 // Pattern 4: Memory-efficient filtering
212 println!("\nPattern 3: Memory-Efficient Filtering");
213 let start = Instant::now();
214 let filtered_result: Tensor = tensor
215 .iter_elements()
216 .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217 .map(|elem| elem.mul_scalar(2.0))
218 .collect();
219 let filtered_time = start.elapsed();
220
221 println!(" Filtered processing: {:?}", filtered_time);
222 println!(
223 " Filtered result size: {} (reduced from {})",
224 filtered_result.size(),
225 size
226 );
227
228 Ok(())
229}
Sourcepub fn chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_>
pub fn chunks_exact(&self, chunk_size: usize) -> TensorChunksExactIterator<'_>
Standard slice-like exact chunks iterator. Use this instead of iter_chunks_exact.
Yields only the exact chunks of size chunk_size, exposing any remainder
via remainder(). See chunks() for a variant that yields the remainder as
the last (smaller) chunk.
§Examples
use train_station::Tensor;
let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
let mut it = t.chunks_exact(2);
assert_eq!(it.next().unwrap().data(), &[1.0, 2.0]);
assert_eq!(it.next().unwrap().data(), &[3.0, 4.0]);
assert_eq!(it.remainder().data(), &[5.0]);
pub fn iter_chunks(&self, chunk_size: usize) -> TensorChunksIterator<'_>
pub fn iter_chunks_exact( &self, chunk_size: usize, ) -> TensorChunksExactIterator<'_>
Source§impl Tensor
impl Tensor
Sourcepub fn collect_into_shape<I: IntoIterator<Item = Tensor>>(
iter: I,
dims: Vec<usize>,
) -> Tensor
pub fn collect_into_shape<I: IntoIterator<Item = Tensor>>( iter: I, dims: Vec<usize>, ) -> Tensor
Collects the yielded tensors into a single tensor with the target shape, copying data in iterator order. The copy is SIMD-optimized when available; the total element count must match the target shape (asserted).
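§Examples
A minimal usage sketch (illustrative, not from the crate's doctests), assembling two row tensors into one [2, 3] tensor in iterator order:
use train_station::Tensor;
let rows = vec![
    Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap(),
    Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap(),
];
// Total element count (6) must match the target shape [2, 3]
let out = Tensor::collect_into_shape(rows, vec![2, 3]);
assert_eq!(out.shape().dims(), vec![2, 3]);
assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);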
Source§impl Tensor
impl Tensor
Sourcepub fn collect_values_shape<I: IntoIterator<Item = f32>>(
iter: I,
dims: Vec<usize>,
) -> Tensor
pub fn collect_values_shape<I: IntoIterator<Item = f32>>( iter: I, dims: Vec<usize>, ) -> Tensor
Inherent helper to collect any iterator of f32 into a shaped tensor.
This mirrors the ValuesCollectExt::collect_shape functionality but does
not require importing the extension trait.
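§Examples
A minimal usage sketch (illustrative, not from the crate's doctests):
use train_station::Tensor;
// Collect a plain f32 iterator into a [2, 3] tensor without importing ValuesCollectExt
let t = Tensor::collect_values_shape((0..6).map(|i| i as f32), vec![2, 3]);
assert_eq!(t.shape().dims(), vec![2, 3]);
assert_eq!(t.data(), &[0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);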
Source§impl Tensor
impl Tensor
Sourcepub fn iter_elements(&self) -> TensorElementIterator<'_>
pub fn iter_elements(&self) -> TensorElementIterator<'_>
Create an iterator over scalar elements (flattened view)
Each yielded item is a [1]-shaped Tensor view that shares storage with
the source. This iterator is GradTrack-aware; element operations propagate
gradients to the original tensor when gradients are enabled.
§Returns
An iterator producing scalar view tensors in row-major order.
§Examples
Collect transformed elements back to the original shape using collect_shape:
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let y = x
.iter_elements()
.map(|e| e.mul_scalar(2.0))
.collect_shape(vec![2, 2]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0]);
Examples found in repository?
87fn demonstrate_performance_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Performance Benchmarking ---");
89
90 // Create test data of different sizes
91 let sizes = vec![100, 1000, 10000];
92
93 for size in sizes {
94 println!("\nBenchmarking with tensor size: {}", size);
95
96 // Generate test data
97 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
98 let tensor = Tensor::from_slice(&data, vec![size])?;
99
100 // Benchmark 1: Direct tensor operations
101 let start = Instant::now();
102 let direct_result = tensor.mul_scalar(2.0).add_scalar(1.0);
103 let direct_time = start.elapsed();
104
105 // Benchmark 2: Iterator-based operations (grad-enabled views, flatten + collect_shape)
106 let start = Instant::now();
107 let iterator_result: Tensor = tensor
108 .iter_elements()
109 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
110 .collect_shape(vec![size]);
111 let iterator_time = start.elapsed();
112
113 // Benchmark 3: Chained iterator operations
114 let start = Instant::now();
115 let _chained_result: Tensor = tensor
116 .iter_elements()
117 .map(|elem| elem.mul_scalar(2.0))
118 .filter(|elem| elem.value() > size as f32)
119 .map(|elem| elem.add_scalar(1.0))
120 .collect();
121 let chained_time = start.elapsed();
122
123 // Benchmark 4: NoGrad raw data streaming + collect_shape
124 let start = Instant::now();
125 let _streamed: Tensor = with_no_grad(|| {
126 tensor
127 .data()
128 .iter()
129 .copied()
130 .map(|x| 2.0 * x + 1.0)
131 .collect_shape(vec![size])
132 });
133 let streaming_time = start.elapsed();
134
135 // Report results
136 println!(" Direct operations: {:?}", direct_time);
137 println!(" Iterator operations: {:?}", iterator_time);
138 println!(" Chained operations: {:?}", chained_time);
139 println!(" NoGrad streaming (data.iter): {:?}", streaming_time);
140
141 // Verify correctness
142 assert_eq!(direct_result.data(), iterator_result.data());
143 println!(
144 " Results match: {}",
145 direct_result.data() == iterator_result.data()
146 );
147
148 // Performance ratios
149 let ratio = iterator_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
150 let ratio_stream = streaming_time.as_nanos() as f64 / direct_time.as_nanos() as f64;
151 println!(" Iterator/Direct ratio: {:.2}x", ratio);
152 println!(" Streaming/Direct ratio: {:.2}x", ratio_stream);
153 }
154
155 Ok(())
156}
157
158/// Demonstrate memory optimization patterns
159///
160/// Shows memory-efficient processing patterns and techniques
161/// for minimizing memory usage while maintaining performance.
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163 println!("\n--- Memory Optimization ---");
164
165 // Create a large tensor for memory testing
166 let size = 10000;
167 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168 let tensor = Tensor::from_slice(&data, vec![size])?;
169
170 println!("Processing tensor of size: {}", size);
171
172 // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173 println!("\nPattern 1: Streaming Processing");
174 let chunk_size = 1000;
175 let start = Instant::now();
176 let flattened = tensor.view(vec![size as i32]);
177 let _streamed_result: Tensor = flattened
178 .chunks(chunk_size)
179 .map(|c| c.pow_scalar(2.0).sqrt())
180 .collect_shape(vec![size]);
181 let streamed_time = start.elapsed();
182
183 // Pattern 2: Full processing
184 let start = Instant::now();
185 let _full_result: Tensor = tensor
186 .iter_elements()
187 .map(|elem| elem.pow_scalar(2.0).sqrt())
188 .collect_shape(vec![size]);
189 let full_time = start.elapsed();
190
191 println!(" Streaming time: {:?}", streamed_time);
192 println!(" Full processing time: {:?}", full_time);
193 println!(
194 " Memory efficiency ratio: {:.2}x",
195 full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196 );
197
198 // Pattern 3: Lazy evaluation with take
199 println!("\nPattern 2: Lazy Evaluation");
200 let start = Instant::now();
201 let lazy_result: Tensor = tensor
202 .iter_elements()
203 .take(1000) // Only process first 1000 elements
204 .map(|elem| elem.pow_scalar(2.0).sqrt())
205 .collect_shape(vec![1000]);
206 let lazy_time = start.elapsed();
207
208 println!(" Lazy processing (1000 elements): {:?}", lazy_time);
209 println!(" Lazy result size: {}", lazy_result.size());
210
211 // Pattern 4: Memory-efficient filtering
212 println!("\nPattern 3: Memory-Efficient Filtering");
213 let start = Instant::now();
214 let filtered_result: Tensor = tensor
215 .iter_elements()
216 .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217 .map(|elem| elem.mul_scalar(2.0))
218 .collect();
219 let filtered_time = start.elapsed();
220
221 println!(" Filtered processing: {:?}", filtered_time);
222 println!(
223 " Filtered result size: {} (reduced from {})",
224 filtered_result.size(),
225 size
226 );
227
228 Ok(())
229}
Sourcepub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_>
pub fn iter_range(&self, start: usize, end: usize) -> TensorElementIterator<'_>
Create an iterator over a clamped range of elements
Produces scalar view tensors from start..end (clamped to [0, size]).
§Arguments
start- Start index (inclusive)
end- End index (exclusive)
§Examples
use train_station::Tensor;
let x = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let vals: Vec<f32> = x.iter_range(2, 5).map(|e| e.value()).collect();
assert_eq!(vals, vec![2.0, 3.0, 4.0]);
Examples found in repository?
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236 println!("\n--- Large-Scale Processing ---");
237
238 // Simulate large dataset processing
239 let sizes = vec![10000, 50000, 100000];
240
241 for size in sizes {
242 println!("\nProcessing dataset of size: {}", size);
243
244 // Generate large dataset
245 let data: Vec<f32> = (0..size)
246 .map(|i| {
247 let x = i as f32 / size as f32;
248 x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249 })
250 .collect();
251
252 let tensor = Tensor::from_slice(&data, vec![size])?;
253
254 // Technique 1: Batch processing
255 let batch_size = 1000;
256 let start = Instant::now();
257
258 let mut batch_results = Vec::new();
259 for batch_start in (0..size).step_by(batch_size) {
260 let batch_end = (batch_start + batch_size).min(size);
261 let batch: Tensor = tensor
262 .iter_range(batch_start, batch_end)
263 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264 .collect();
265 batch_results.push(batch);
266 }
267 let batch_time = start.elapsed();
268
269 // Technique 2: Parallel-like processing with stride
270 let start = Instant::now();
271 let stride = 4;
272 let strided_result: Tensor = tensor
273 .iter()
274 .enumerate()
275 .filter(|(i, _)| i % stride == 0)
276 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect();
278 let strided_time = start.elapsed();
279
280 // Technique 3: Hierarchical processing
281 let start = Instant::now();
282 let coarse: Tensor = tensor
283 .iter()
284 .enumerate()
285 .filter(|(i, _)| i % 10 == 0) // Every 10th element
286 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287 .collect();
288 let fine: Tensor = tensor
289 .iter()
290 .enumerate()
291 .filter(|(i, _)| i % 10 != 0) // Rest of elements
292 .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293 .collect();
294 let hierarchical_time = start.elapsed();
295
296 // Report performance
297 println!(" Batch processing: {:?}", batch_time);
298 println!(" Strided processing: {:?}", strided_time);
299 println!(" Hierarchical processing: {:?}", hierarchical_time);
300
301 // Memory usage analysis
302 let total_batches = size.div_ceil(batch_size);
303 println!(" Batch count: {}", total_batches);
304 println!(" Strided result size: {}", strided_result.size());
305 println!(
306 " Hierarchical: coarse={}, fine={}",
307 coarse.size(),
308 fine.size()
309 );
310 }
311
312 Ok(())
313}
More examples
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
Source§impl Tensor
impl Tensor
Sourcepub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_>
pub fn iter_dim(&self, dim: usize) -> TensorDimIterator<'_>
Iterate over sub-tensors along a specific dimension.
Produces view tensors by slicing along the given dimension; each item has that dimension removed (rank - 1). Views share storage and preserve gradient tracking semantics.
§Arguments
dim- Dimension to iterate over
§Examples
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let t = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
let out = t.iter_dim(0).map(|row| row.add_scalar(1.0)).collect_shape(vec![2, 3]);
assert_eq!(out.data(), &[2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
Sourcepub fn iter(&self) -> TensorDimIterator<'_>
pub fn iter(&self) -> TensorDimIterator<'_>
Default iterator over the outermost dimension, yielding sub-tensors (N-D) or scalar views (1-D).
This is equivalent to iter_dim(0) with an important optimization:
- For 1-D tensors, it yields scalar element views of shape [1] (same as iter_flat()) to maximize GradTrack cooperation and collection performance.
- For N-D tensors (rank > 1), it yields sub-tensors with the outermost dimension removed (rank − 1), suitable for row/batch-wise processing.
Views share storage with the source tensor and preserve gradient tracking semantics.
Use collect_shape([..]) to reconstruct shape efficiently after per-item transforms.
§Returns
An iterator producing view tensors for each slice along the outermost dimension (or scalar views for 1-D).
§Examples
1-D: element views (shape [1]) and shape-preserving collection
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let v = Tensor::from_slice(&(0..6).map(|i| i as f32).collect::<Vec<_>>(), vec![6]).unwrap();
let out = v.iter().map(|e| e.add_scalar(1.0)).collect_shape(vec![6]);
assert_eq!(out.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
2-D: row-wise transforms and shape-preserving collection
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let m = Tensor::from_slice(&(1..=6).map(|i| i as f32).collect::<Vec<_>>(), vec![2, 3]).unwrap();
let y = m.iter().map(|row| row.mul_scalar(2.0)).collect_shape(vec![2, 3]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0, 10.0, 12.0]);
Examples found in repository?
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94 println!("\n--- Basic Element Iteration ---");
95
96 // Create a simple tensor for demonstration
97 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98 println!("Original tensor: {:?}", tensor.data());
99
100 // Basic iteration with for loop
101 println!("\nBasic iteration with for loop:");
102 for (i, element) in tensor.iter().enumerate() {
103 println!(
104 " Element {}: value = {:.1}, shape = {:?}",
105 i,
106 element.value(),
107 element.shape().dims()
108 );
109 }
110
111 // Element-wise transformation
112 println!("\nElement-wise transformation (2x + 1):");
113 let transformed: Tensor = tensor
114 .iter()
115 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116 .collect();
117 println!(" Result: {:?}", transformed.data());
118
119 // Filtering elements
120 println!("\nFiltering elements (values > 3.0):");
121 let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122 println!(" Filtered: {:?}", filtered.data());
123
124 Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132 println!("\n--- Standard Iterator Methods ---");
133
134 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136 // Using map for transformations
137 println!("\nMap transformation (square each element):");
138 let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139 println!(" Squared: {:?}", squared.data());
140
141 // Using enumerate for indexed operations
142 println!("\nEnumerate with indexed operations:");
143 let indexed: Tensor = tensor
144 .iter()
145 .enumerate()
146 .map(|(i, elem)| elem.add_scalar(i as f32))
147 .collect();
148 println!(" Indexed: {:?}", indexed.data());
149
150 // Using fold for reduction
151 println!("\nFold for sum calculation:");
152 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153 println!(" Sum: {:.1}", sum);
154
155 // Using find for element search
156 println!("\nFind specific element:");
157 if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158 println!(" Found element with value 3.0: {:.1}", found.value());
159 }
160
161 // Using any/all for condition checking
162 println!("\nCondition checking:");
163 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165 println!(" All positive: {}", all_positive);
166 println!(" Any > 4.0: {}", any_large);
167
168 Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
258
259/// Demonstrate per-row transforms with shape-preserving collection
260///
261/// Shows how to use `iter()` over the outer dimension on a 2D tensor and
262/// `collect_shape([..])` to maintain the original shape after mapping.
263fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
264 println!("\n--- Row-wise iteration with collect_shape ---");
265 let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
266 println!("Input shape: {:?}", mat.shape().dims());
267
268 // Map each row: 1.1*x + 0.5, then collect back to [3,4]
269 let out: Tensor = mat
270 .iter()
271 .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
272 .collect_shape(vec![3, 4]);
273 println!(" Output shape: {:?}", out.shape().dims());
274
275 Ok(())
276}
277
278/// Demonstrate NoGrad fast paths and raw data streaming
279///
280/// Highlights how to get maximum performance in inference by:
281/// - Disabling gradient tracking with `with_no_grad`
282/// - Iterating raw values via `tensor.data().iter().copied()`
283/// - Using `collect_shape` to stream directly into destination tensors
284fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
285 println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");
286
287 let input = Tensor::from_slice(
288 &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
289 vec![4, 6],
290 )?;
291 println!("Input shape: {:?}", input.shape().dims());
292
293 // NoGrad: stream values directly and reshape
294 let out = with_no_grad(|| {
295 input
296 .data()
297 .iter()
298 .copied()
299 .map(|x| 1.2 * x - 0.3)
300 .collect_shape(vec![4, 6])
301 });
302 println!(
303 " NoGrad streamed map (1.2x-0.3) -> shape {:?}",
304 out.shape().dims()
305 );
306
307 // Compare to view-based element iteration in NoGrad
308 let out_view: Tensor = with_no_grad(|| {
309 input
310 .iter()
311 .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
312 .collect_shape(vec![4, 6])
313 });
314 println!(
315 " NoGrad view-based map shape {:?}",
316 out_view.shape().dims()
317 );
318
319 // Quick parity check
320 assert_eq!(out.data(), out_view.data());
321 println!(" Parity check passed.");
322
323 // Show simple flatten + collect back to a different shape
324 let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
325 println!(
326 " Reshaped via streaming collect_shape: {:?}",
327 reshaped.shape().dims()
328 );
329
330 Ok(())
331}
More examples
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251 println!("\n--- Batch Operations ---");
252
253 // Create a larger dataset for batch processing
254 let size = 100;
255 let data: Vec<f32> = (0..size)
256 .map(|i| {
257 let x = i as f32 / size as f32;
258 x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259 })
260 .collect();
261
262 let tensor = Tensor::from_slice(&data, vec![size])?;
263 println!("Dataset size: {}", tensor.size());
264
265 // Batch processing with windowing (iterator views)
266 println!("\nBatch processing with sliding windows:");
267 let batch_size = 10;
268 let batches: Vec<Tensor> = tensor
269 .iter()
270 .collect::<Vec<_>>()
271 .chunks(batch_size)
272 .map(|chunk| {
273 // Process each batch independently
274 chunk
275 .iter()
276 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect()
278 })
279 .collect();
280
281 println!(
282 " Processed {} batches of size {}",
283 batches.len(),
284 batch_size
285 );
286 for (i, batch) in batches.iter().enumerate() {
287 println!(
288 " Batch {}: mean={:.3}, std={:.3}",
289 i,
290 batch.mean().value(),
291 batch.std().value()
292 );
293 }
294
295 // Parallel-like processing with stride
296 println!("\nStrided processing (every nth element):");
297 let stride = 5;
298 let strided: Tensor = tensor
299 .iter()
300 .enumerate()
301 .filter(|(i, _)| i % stride == 0)
302 .map(|(_, elem)| elem)
303 .collect();
304 println!(" Strided (every {}th): {:?}", stride, strided.data());
305
306 // Hierarchical processing
307 println!("\nHierarchical processing (coarse to fine):");
308 let coarse: Tensor = tensor
309 .iter()
310 .enumerate()
311 .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312 .map(|(_, elem)| elem)
313 .collect();
314
315 let fine: Tensor = tensor
316 .iter()
317 .enumerate()
318 .filter(|(i, _)| i % 4 != 0) // Take the rest
319 .map(|(_, elem)| elem)
320 .collect();
321
322 println!(" Coarse (every 4th): {:?}", coarse.data());
323 println!(" Fine (rest): {:?}", fine.data());
324
325 // Combine coarse and fine with different processing
326 let combined: Tensor = coarse
327 .iter()
328 .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329 .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330 .collect();
331 println!(" Combined: {:?}", combined.data());
332
333 Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236 println!("\n--- Large-Scale Processing ---");
237
238 // Simulate large dataset processing
239 let sizes = vec![10000, 50000, 100000];
240
241 for size in sizes {
242 println!("\nProcessing dataset of size: {}", size);
243
244 // Generate large dataset
245 let data: Vec<f32> = (0..size)
246 .map(|i| {
247 let x = i as f32 / size as f32;
248 x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249 })
250 .collect();
251
252 let tensor = Tensor::from_slice(&data, vec![size])?;
253
254 // Technique 1: Batch processing
255 let batch_size = 1000;
256 let start = Instant::now();
257
258 let mut batch_results = Vec::new();
259 for batch_start in (0..size).step_by(batch_size) {
260 let batch_end = (batch_start + batch_size).min(size);
261 let batch: Tensor = tensor
262 .iter_range(batch_start, batch_end)
263 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264 .collect();
265 batch_results.push(batch);
266 }
267 let batch_time = start.elapsed();
268
269 // Technique 2: Parallel-like processing with stride
270 let start = Instant::now();
271 let stride = 4;
272 let strided_result: Tensor = tensor
273 .iter()
274 .enumerate()
275 .filter(|(i, _)| i % stride == 0)
276 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect();
278 let strided_time = start.elapsed();
279
280 // Technique 3: Hierarchical processing
281 let start = Instant::now();
282 let coarse: Tensor = tensor
283 .iter()
284 .enumerate()
285 .filter(|(i, _)| i % 10 == 0) // Every 10th element
286 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287 .collect();
288 let fine: Tensor = tensor
289 .iter()
290 .enumerate()
291 .filter(|(i, _)| i % 10 != 0) // Rest of elements
292 .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293 .collect();
294 let hierarchical_time = start.elapsed();
295
296 // Report performance
297 println!(" Batch processing: {:?}", batch_time);
298 println!(" Strided processing: {:?}", strided_time);
299 println!(" Hierarchical processing: {:?}", hierarchical_time);
300
301 // Memory usage analysis
302 let total_batches = size.div_ceil(batch_size);
303 println!(" Batch count: {}", total_batches);
304 println!(" Strided result size: {}", strided_result.size());
305 println!(
306 " Hierarchical: coarse={}, fine={}",
307 coarse.size(),
308 fine.size()
309 );
310 }
311
312 Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320 println!("\n--- Optimization Techniques ---");
321
322 let size = 50000;
323 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324 let tensor = Tensor::from_slice(&data, vec![size])?;
325
326 println!("Optimizing processing for size: {}", size);
327
328 // Technique 1: Operation fusion
329 println!("\nTechnique 1: Operation Fusion");
330 let start = Instant::now();
331 let fused_result: Tensor = tensor
332 .iter()
333 .map(|elem| {
334 // Fuse multiple operations into single chain
335 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336 })
337 .collect();
338 let fused_time = start.elapsed();
339
340 // Technique 2: Conditional optimization
341 println!("\nTechnique 2: Conditional Optimization");
342 let start = Instant::now();
343 let conditional_result: Tensor = tensor
344 .iter()
345 .map(|elem| {
346 let val = elem.value();
347 if val < size as f32 / 2.0 {
348 elem.mul_scalar(2.0) // Simple operation for small values
349 } else {
350 elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351 }
352 })
353 .collect();
354 let conditional_time = start.elapsed();
355
356 // Technique 3: Cache-friendly processing
357 println!("\nTechnique 3: Cache-Friendly Processing");
358 let start = Instant::now();
359 let cache_friendly_result: Tensor = tensor
360 .iter()
361 .take(1000) // Process in cache-friendly chunks
362 .map(|elem| elem.mul_scalar(2.0))
363 .collect();
364 let cache_friendly_time = start.elapsed();
365
366 // Technique 4: Memory pooling simulation
367 println!("\nTechnique 4: Memory Pooling Simulation");
368 let start = Instant::now();
369 let pooled_result: Tensor = tensor
370 .iter()
371 .enumerate()
372 .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373 .map(|(_, elem)| elem.pow_scalar(2.0))
374 .collect();
375 let pooled_time = start.elapsed();
376
377 // Report optimization results
378 println!(" Fused operations: {:?}", fused_time);
379 println!(" Conditional optimization: {:?}", conditional_time);
380 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
381 println!(" Memory pooling simulation: {:?}", pooled_time);
382
383 // Performance analysis
384 let fastest = fused_time
385 .min(conditional_time)
386 .min(cache_friendly_time)
387 .min(pooled_time);
388 println!(" Fastest technique: {:?}", fastest);
389
390 // Memory efficiency analysis
391 println!(" Fused result size: {}", fused_result.size());
392 println!(" Conditional result size: {}", conditional_result.size());
393 println!(
394 " Cache-friendly result size: {}",
395 cache_friendly_result.size()
396 );
397 println!(" Pooled result size: {}", pooled_result.size());
398
399 // Technique 5: Gradient optimization
400 println!("\nTechnique 5: Gradient Optimization");
401 let grad_tensor = tensor.with_requires_grad();
402 let start = Instant::now();
403
404 let grad_result: Tensor = grad_tensor
405 .iter()
406 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407 .collect();
408
409 let mut loss = grad_result.sum();
410 loss.backward(None);
411 let grad_time = start.elapsed();
412
413 println!(" Gradient computation: {:?}", grad_time);
414 println!(
415 " Gradient tracking enabled: {}",
416 grad_result.requires_grad()
417 );
418
419 Ok(())
420}
Sourcepub fn outer_iter(&self) -> TensorDimIterator<'_>
pub fn outer_iter(&self) -> TensorDimIterator<'_>
Explicit alias for outermost-dimension iteration of sub-tensors.
Equivalent to iter_dim(0).
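§Examples
A minimal usage sketch mirroring the iter_dim(0) example above (illustrative only):
use train_station::tensor::TensorCollectExt;
use train_station::Tensor;
let m = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let y = m.outer_iter().map(|row| row.mul_scalar(2.0)).collect_shape(vec![2, 2]);
assert_eq!(y.data(), &[2.0, 4.0, 6.0, 8.0]);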
Source§impl Tensor
impl Tensor
Sourcepub fn windows(&self, window_size: usize) -> TensorWindowsIterator<'_>
pub fn windows(&self, window_size: usize) -> TensorWindowsIterator<'_>
Overlapping windows iterator with step=1. Use this instead of iter_windows.
Produces overlapping linear windows as view tensors. In no-grad fast mode, a contiguous owner may be materialized once for faster subsequent views.
§Arguments
window_size- Length of each window (> 0)
§Examples
use train_station::Tensor;
let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let v: Vec<f32> = t.windows(3).map(|w| w.sum().value()).collect();
assert_eq!(v, vec![6.0, 9.0]);
Sourcepub fn windows_step(
&self,
window_size: usize,
step: usize,
) -> TensorWindowsIterator<'_>
pub fn windows_step( &self, window_size: usize, step: usize, ) -> TensorWindowsIterator<'_>
Overlapping windows iterator with custom step. Use this instead of iter_windows_step.
Produces windows starting at positions 0, step, 2*step, ... up to the last
valid start. Reverse iteration yields the same sequence in reverse.
§Arguments
window_size- Length of each window (> 0)
step- Step between consecutive window starts (> 0)
§Examples
use train_station::Tensor;
let t = Tensor::from_slice(&(1..=8).map(|i| i as f32).collect::<Vec<_>>(), vec![8]).unwrap();
let wins: Vec<Tensor> = t.windows_step(3, 2).collect();
assert_eq!(wins[0].data(), &[1.0, 2.0, 3.0]);
assert_eq!(wins[1].data(), &[3.0, 4.0, 5.0]);
assert_eq!(wins[2].data(), &[5.0, 6.0, 7.0]);
pub fn iter_windows(&self, window_size: usize) -> TensorWindowsIterator<'_>
pub fn iter_windows_step( &self, window_size: usize, step: usize, ) -> TensorWindowsIterator<'_>
Source§impl Tensor
impl Tensor
Sourcepub fn add_tensor(&self, other: &Tensor) -> Tensor
pub fn add_tensor(&self, other: &Tensor) -> Tensor
Element-wise addition with another tensor with broadcasting support.
Performs element-wise addition with automatic broadcasting: output[i] = self[i] + other[i]
Broadcasting enables addition between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:
- Dimensions are aligned from the rightmost dimension
- Dimensions are compatible if they are equal, or one of them is 1
- Missing dimensions are treated as 1
§Arguments
other- Tensor to add. Shapes must be broadcast-compatible.
§Returns
A new tensor containing the element-wise sum with broadcast result shape
§Examples
§Same Shape Addition
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 5.0);
assert_eq!(c.get(&[1]), 7.0);
assert_eq!(c.get(&[2]), 9.0);
§Broadcasting Addition
use train_station::Tensor;
// Broadcasting: [2, 1] + [1, 3] -> [2, 3]
let a = Tensor::from_slice(&[1.0, 2.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 11.0);
assert_eq!(c.get(&[0, 1]), 21.0);
assert_eq!(c.get(&[1, 0]), 12.0);
assert_eq!(c.get(&[1, 1]), 22.0);
§Scalar Broadcasting
use train_station::Tensor;
// Scalar broadcasting: [2, 3] + scalar -> [2, 3]
let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[5.0], vec![1]).unwrap();
let c = a.add_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 6.0);
assert_eq!(c.get(&[1, 2]), 6.0);
§Panics
Panics if tensor shapes are not broadcast-compatible
Examples found in repository?
More examples
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
53 pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54 let attn = self.mha.forward(input, input, input, attn_mask);
55 let res1 = attn.add_tensor(input);
56
57 // Feed-forward network with ReLU and residual
58 let (b, t, e) = Self::triple(input);
59 let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60 let hidden = self.ffn_in.forward(&x2d).relu();
61 let out2d = self.ffn_out.forward(&hidden);
62 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63 out.add_tensor(&res1)
64 }
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI)
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
56 pub fn forward(
57 &self,
58 tgt: &Tensor,
59 memory: &Tensor,
60 causal_mask: Option<&Tensor>,
61 cross_mask: Option<&Tensor>,
62 ) -> Tensor {
63 let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64 let res1 = self_attn.add_tensor(tgt);
65
66 let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67 let res2 = cross.add_tensor(&res1);
68
69 let (b, t, e) = Self::triple(tgt);
70 let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71 let hidden = self.ffn_in.forward(&x2d).relu();
72 let out2d = self.ffn_out.forward(&hidden);
73 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74 out.add_tensor(&res2)
75 }
Sourcepub fn add_scalar(&self, scalar: f32) -> Tensor
pub fn add_scalar(&self, scalar: f32) -> Tensor
Broadcast addition with a scalar value.
Adds the scalar to every element: output[i] = self[i] + scalar
§Arguments
scalar- Value to add to each element
§Returns
A new tensor with the scalar added to each element
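§Examples
A short usage sketch, consistent with the scalar addition shown in the other examples on this page:
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.add_scalar(5.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 6.0);
assert_eq!(b.get(&[1]), 7.0);
assert_eq!(b.get(&[2]), 8.0);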
Examples found in repository?
More examples
89fn demonstrate_basic_operations() {
90 println!("\n--- Basic Operations ---");
91
92 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95 // Addition
96 let sum = a.add_tensor(&b);
97 println!("A + B: {:?}", sum.data());
98
99 // Subtraction
100 let diff = a.sub_tensor(&b);
101 println!("A - B: {:?}", diff.data());
102
103 // Multiplication
104 let product = a.mul_tensor(&b);
105 println!("A * B: {:?}", product.data());
106
107 // Division
108 let quotient = a.div_tensor(&b);
109 println!("A / B: {:?}", quotient.data());
110
111 // Scalar operations
112 let scalar_add = a.add_scalar(5.0);
113 println!("A + 5.0: {:?}", scalar_add.data());
114
115 let scalar_mul = a.mul_scalar(2.0);
116 println!("A * 2.0: {:?}", scalar_mul.data());
117}
93fn demonstrate_basic_iteration() -> Result<(), Box<dyn std::error::Error>> {
94 println!("\n--- Basic Element Iteration ---");
95
96 // Create a simple tensor for demonstration
97 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
98 println!("Original tensor: {:?}", tensor.data());
99
100 // Basic iteration with for loop
101 println!("\nBasic iteration with for loop:");
102 for (i, element) in tensor.iter().enumerate() {
103 println!(
104 " Element {}: value = {:.1}, shape = {:?}",
105 i,
106 element.value(),
107 element.shape().dims()
108 );
109 }
110
111 // Element-wise transformation
112 println!("\nElement-wise transformation (2x + 1):");
113 let transformed: Tensor = tensor
114 .iter()
115 .map(|elem| elem.mul_scalar(2.0).add_scalar(1.0))
116 .collect();
117 println!(" Result: {:?}", transformed.data());
118
119 // Filtering elements
120 println!("\nFiltering elements (values > 3.0):");
121 let filtered: Tensor = tensor.iter().filter(|elem| elem.value() > 3.0).collect();
122 println!(" Filtered: {:?}", filtered.data());
123
124 Ok(())
125}
126
127/// Demonstrate standard iterator trait methods
128///
129/// Shows compatibility with Rust's standard library iterator methods
130/// and demonstrates various functional programming patterns.
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132 println!("\n--- Standard Iterator Methods ---");
133
134 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136 // Using map for transformations
137 println!("\nMap transformation (square each element):");
138 let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139 println!(" Squared: {:?}", squared.data());
140
141 // Using enumerate for indexed operations
142 println!("\nEnumerate with indexed operations:");
143 let indexed: Tensor = tensor
144 .iter()
145 .enumerate()
146 .map(|(i, elem)| elem.add_scalar(i as f32))
147 .collect();
148 println!(" Indexed: {:?}", indexed.data());
149
150 // Using fold for reduction
151 println!("\nFold for sum calculation:");
152 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153 println!(" Sum: {:.1}", sum);
154
155 // Using find for element search
156 println!("\nFind specific element:");
157 if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158 println!(" Found element with value 3.0: {:.1}", found.value());
159 }
160
161 // Using any/all for condition checking
162 println!("\nCondition checking:");
163 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165 println!(" All positive: {}", all_positive);
166 println!(" Any > 4.0: {}", any_large);
167
168 Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
203
204/// Demonstrate advanced iterator patterns
205///
206/// Shows complex iterator chains and advanced functional programming
207/// patterns for sophisticated data processing workflows.
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
258
259/// Demonstrate per-row transforms with shape-preserving collection
260///
261/// Shows how to use `iter()` over the outer dimension on a 2D tensor and
262/// `collect_shape([..])` to maintain the original shape after mapping.
263fn demonstrate_row_wise_collect_shape() -> Result<(), Box<dyn std::error::Error>> {
264 println!("\n--- Row-wise iteration with collect_shape ---");
265 let mat = Tensor::from_slice(&(1..=12).map(|x| x as f32).collect::<Vec<_>>(), vec![3, 4])?;
266 println!("Input shape: {:?}", mat.shape().dims());
267
268 // Map each row: 1.1*x + 0.5, then collect back to [3,4]
269 let out: Tensor = mat
270 .iter()
271 .map(|row| row.mul_scalar(1.1).add_scalar(0.5))
272 .collect_shape(vec![3, 4]);
273 println!(" Output shape: {:?}", out.shape().dims());
274
275 Ok(())
276}
277
278/// Demonstrate NoGrad fast paths and raw data streaming
279///
280/// Highlights how to get maximum performance in inference by:
281/// - Disabling gradient tracking with `with_no_grad`
282/// - Iterating raw values via `tensor.data().iter().copied()`
283/// - Using `collect_shape` to stream directly into destination tensors
284fn demonstrate_nograd_and_streaming() -> Result<(), Box<dyn std::error::Error>> {
285 println!("\n--- NoGrad & Streaming (Inference Fast Paths) ---");
286
287 let input = Tensor::from_slice(
288 &(0..24).map(|i| i as f32 * 0.25).collect::<Vec<_>>(),
289 vec![4, 6],
290 )?;
291 println!("Input shape: {:?}", input.shape().dims());
292
293 // NoGrad: stream values directly and reshape
294 let out = with_no_grad(|| {
295 input
296 .data()
297 .iter()
298 .copied()
299 .map(|x| 1.2 * x - 0.3)
300 .collect_shape(vec![4, 6])
301 });
302 println!(
303 " NoGrad streamed map (1.2x-0.3) -> shape {:?}",
304 out.shape().dims()
305 );
306
307 // Compare to view-based element iteration in NoGrad
308 let out_view: Tensor = with_no_grad(|| {
309 input
310 .iter()
311 .map(|e| e.mul_scalar(1.2).add_scalar(-0.3))
312 .collect_shape(vec![4, 6])
313 });
314 println!(
315 " NoGrad view-based map shape {:?}",
316 out_view.shape().dims()
317 );
318
319 // Quick parity check
320 assert_eq!(out.data(), out_view.data());
321 println!(" Parity check passed.");
322
323 // Show simple flatten + collect back to a different shape
324 let reshaped = with_no_grad(|| input.data().iter().copied().collect_shape(vec![6, 4]));
325 println!(
326 " Reshaped via streaming collect_shape: {:?}",
327 reshaped.shape().dims()
328 );
329
330 Ok(())
331}
203fn demonstrate_method_equivalence() {
204 println!("\n--- Operator vs Method Call Equivalence ---");
205
206 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209 // Addition: operator vs method
210 let operator_result = &a + &b;
211 let method_result = a.add_tensor(&b);
212
213 println!("A + B (operator): {:?}", operator_result.data());
214 println!("A.add_tensor(B): {:?}", method_result.data());
215 println!(
216 "Results are equal: {}",
217 operator_result.data() == method_result.data()
218 );
219
220 // Multiplication: operator vs method
221 let operator_result = &a * &b;
222 let method_result = a.mul_tensor(&b);
223
224 println!("A * B (operator): {:?}", operator_result.data());
225 println!("A.mul_tensor(B): {:?}", method_result.data());
226 println!(
227 "Results are equal: {}",
228 operator_result.data() == method_result.data()
229 );
230
231 // Scalar addition: operator vs method
232 let operator_result = &a + 5.0;
233 let method_result = a.add_scalar(5.0);
234
235 println!("A + 5.0 (operator): {:?}", operator_result.data());
236 println!("A.add_scalar(5.0): {:?}", method_result.data());
237 println!(
238 "Results are equal: {}",
239 operator_result.data() == method_result.data()
240 );
241}
Source§impl Tensor
impl Tensor
Sourcepub fn div_tensor(&self, other: &Tensor) -> Tensor
pub fn div_tensor(&self, other: &Tensor) -> Tensor
Element-wise division by another tensor, supporting broadcasting.
Performs element-wise division with automatic broadcasting: output[i] = self[i] / other[i]
Broadcasting enables division between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:
- Dimensions are aligned from the rightmost dimension
- Dimensions are compatible if they are equal, or one of them is 1
- Missing dimensions are treated as 1
§Arguments
other- Tensor to divide by. Shapes must be broadcast-compatible.
§Returns
A new tensor containing the element-wise quotient with broadcast result shape
§Examples
§Same Shape Division
use train_station::Tensor;
let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[2.0, 4.0, 5.0], vec![3]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 5.0);
assert_eq!(c.get(&[1]), 5.0);
assert_eq!(c.get(&[2]), 6.0);
§Broadcasting Division
use train_station::Tensor;
// Broadcasting: [2, 1] / [1, 3] -> [2, 3]
let a = Tensor::from_slice(&[10.0, 20.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 5.0], vec![1, 3]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 10.0);
assert_eq!(c.get(&[0, 1]), 5.0);
assert_eq!(c.get(&[1, 0]), 20.0);
assert_eq!(c.get(&[1, 1]), 10.0);
§Scalar Division
use train_station::Tensor;
// Scalar division: [2, 3] / scalar -> [2, 3]
let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[2.0], vec![1]).unwrap();
let c = a.div_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 0.5);
assert_eq!(c.get(&[1, 2]), 0.5);
§Panics
Panics if tensor shapes are not broadcast-compatible or if dividing by zero
Examples found in repository?
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI)
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
More examples
89fn demonstrate_basic_operations() {
90 println!("\n--- Basic Operations ---");
91
92 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95 // Addition
96 let sum = a.add_tensor(&b);
97 println!("A + B: {:?}", sum.data());
98
99 // Subtraction
100 let diff = a.sub_tensor(&b);
101 println!("A - B: {:?}", diff.data());
102
103 // Multiplication
104 let product = a.mul_tensor(&b);
105 println!("A * B: {:?}", product.data());
106
107 // Division
108 let quotient = a.div_tensor(&b);
109 println!("A / B: {:?}", quotient.data());
110
111 // Scalar operations
112 let scalar_add = a.add_scalar(5.0);
113 println!("A + 5.0: {:?}", scalar_add.data());
114
115 let scalar_mul = a.mul_scalar(2.0);
116 println!("A * 2.0: {:?}", scalar_mul.data());
117}
Sourcepub fn div_scalar(&self, scalar: f32) -> Tensor
pub fn div_scalar(&self, scalar: f32) -> Tensor
Broadcast division with a scalar value.
Divides every element by the scalar: output[i] = self[i] / scalar
§Arguments
scalar- Value to divide each element by (must not be zero)
§Returns
A new tensor with each element divided by the scalar
§Examples
§Basic Scalar Division
use train_station::Tensor;
let a = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![3]).unwrap();
let b = a.div_scalar(10.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0);
assert_eq!(b.get(&[1]), 2.0);
assert_eq!(b.get(&[2]), 3.0);
§Multi-dimensional Scalar Division
use train_station::Tensor;
let a = Tensor::ones(vec![2, 3]);
let b = a.div_scalar(2.0);
assert_eq!(b.shape().dims(), vec![2, 3]);
assert_eq!(b.get(&[0, 0]), 0.5);
assert_eq!(b.get(&[1, 2]), 0.5);
§Panics
Panics if scalar is zero
Examples found in repository?
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123 }
More examples
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251 println!("\n--- Batch Operations ---");
252
253 // Create a larger dataset for batch processing
254 let size = 100;
255 let data: Vec<f32> = (0..size)
256 .map(|i| {
257 let x = i as f32 / size as f32;
258 x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259 })
260 .collect();
261
262 let tensor = Tensor::from_slice(&data, vec![size])?;
263 println!("Dataset size: {}", tensor.size());
264
265 // Batch processing with windowing (iterator views)
266 println!("\nBatch processing with sliding windows:");
267 let batch_size = 10;
268 let batches: Vec<Tensor> = tensor
269 .iter()
270 .collect::<Vec<_>>()
271 .chunks(batch_size)
272 .map(|chunk| {
273 // Process each batch independently
274 chunk
275 .iter()
276 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect()
278 })
279 .collect();
280
281 println!(
282 " Processed {} batches of size {}",
283 batches.len(),
284 batch_size
285 );
286 for (i, batch) in batches.iter().enumerate() {
287 println!(
288 " Batch {}: mean={:.3}, std={:.3}",
289 i,
290 batch.mean().value(),
291 batch.std().value()
292 );
293 }
294
295 // Parallel-like processing with stride
296 println!("\nStrided processing (every nth element):");
297 let stride = 5;
298 let strided: Tensor = tensor
299 .iter()
300 .enumerate()
301 .filter(|(i, _)| i % stride == 0)
302 .map(|(_, elem)| elem)
303 .collect();
304 println!(" Strided (every {}th): {:?}", stride, strided.data());
305
306 // Hierarchical processing
307 println!("\nHierarchical processing (coarse to fine):");
308 let coarse: Tensor = tensor
309 .iter()
310 .enumerate()
311 .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312 .map(|(_, elem)| elem)
313 .collect();
314
315 let fine: Tensor = tensor
316 .iter()
317 .enumerate()
318 .filter(|(i, _)| i % 4 != 0) // Take the rest
319 .map(|(_, elem)| elem)
320 .collect();
321
322 println!(" Coarse (every 4th): {:?}", coarse.data());
323 println!(" Fine (rest): {:?}", fine.data());
324
325 // Combine coarse and fine with different processing
326 let combined: Tensor = coarse
327 .iter()
328 .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329 .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330 .collect();
331 println!(" Combined: {:?}", combined.data());
332
333 Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
Source§impl Tensor
impl Tensor
Sourcepub fn exp(&self) -> Tensor
pub fn exp(&self) -> Tensor
Element-wise exponential function.
Computes e^x for each element: output[i] = e^(self[i])
§Returns
A new tensor with the exponential of each element
§Examples
§Basic Exponential
use train_station::Tensor;
let a = Tensor::from_slice(&[0.0, 1.0, 2.0], vec![3]).unwrap();
let b = a.exp();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // e^0 = 1
assert!((b.get(&[1]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
assert!((b.get(&[2]) - 7.38906).abs() < 1e-5); // e^2 ≈ 7.38906
§Negative Values
use train_station::Tensor;
let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.exp();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 0.36788).abs() < 1e-5); // e^(-1) ≈ 0.36788
assert_eq!(b.get(&[1]), 1.0); // e^0 = 1
assert!((b.get(&[2]) - 2.71828).abs() < 1e-5); // e^1 ≈ 2.71828
Examples found in repository?
More examples
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243         .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2π) constant of the Gaussian log-density
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
Source§impl Tensor
impl Tensor
Sourcepub fn leaky_relu(&self, negative_slope: f32) -> Tensor
pub fn leaky_relu(&self, negative_slope: f32) -> Tensor
Element-wise Leaky ReLU activation.
Applies Leaky ReLU to each element: output[i] = max(0, self[i]) + negative_slope * min(0, self[i])
Unlike standard ReLU, it allows a small, non-zero gradient when the unit is not active.
§Arguments
negative_slope - Slope for negative values (typically small, e.g., 0.01 or 0.1)
§Returns
A new tensor with Leaky ReLU applied to each element
§Examples
§Basic Leaky ReLU
use train_station::Tensor;
let a = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0], vec![4]).unwrap();
let b = a.leaky_relu(0.1);
assert_eq!(b.shape().dims(), vec![4]);
assert!((b.get(&[0]) - (-0.2)).abs() < 1e-6); // -2.0 * 0.1 = -0.2
assert!((b.get(&[1]) - (-0.1)).abs() < 1e-6); // -1.0 * 0.1 = -0.1
assert_eq!(b.get(&[2]), 0.0); // max(0, 0) = 0
assert_eq!(b.get(&[3]), 1.0); // max(0, 1) = 1
§Different Negative Slopes
use train_station::Tensor;
let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.leaky_relu(0.01); // Smaller negative slope
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - (-0.01)).abs() < 1e-6); // -1.0 * 0.01 = -0.01
assert_eq!(b.get(&[1]), 0.0); // max(0, 0) = 0
assert_eq!(b.get(&[2]), 1.0); // max(0, 1) = 1
Source§impl Tensor
impl Tensor
Sourcepub fn log(&self) -> Tensor
pub fn log(&self) -> Tensor
Element-wise natural logarithm.
Computes the natural logarithm for each element: output[i] = ln(self[i])
§Returns
A new tensor with the natural logarithm of each element
§Examples
§Basic Natural Logarithm
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.71828, 7.38906], vec![3]).unwrap();
let b = a.log();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // ln(1) = 0
assert!((b.get(&[1]) - 1.0).abs() < 1e-5); // ln(e) ≈ 1
assert!((b.get(&[2]) - 2.0).abs() < 1e-5); // ln(e^2) ≈ 2
§Mathematical Properties
use train_station::Tensor;
let a = Tensor::from_slice(&[4.0, 8.0, 16.0], vec![3]).unwrap();
let b = a.log();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 1.38629).abs() < 1e-5); // ln(4) ≈ 1.38629
assert!((b.get(&[1]) - 2.07944).abs() < 1e-5); // ln(8) ≈ 2.07944
assert!((b.get(&[2]) - 2.77259).abs() < 1e-5); // ln(16) ≈ 2.77259
§Panics
Panics if any element is non-positive (x <= 0)
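A common way to avoid this panic, used by the PPO discrete example below, is to offset values by a small epsilon before taking the logarithm. A minimal sketch of that guard (the 1e-8 offset is illustrative):
use train_station::Tensor;
// Probabilities may contain exact zeros, so shift by a tiny epsilon before log()
let probs = Tensor::from_slice(&[0.7, 0.3, 0.0], vec![3]).unwrap();
let safe_log = probs.add_scalar(1e-8).log();
assert_eq!(safe_log.shape().dims(), vec![3]);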
Examples found in repository?
More examples
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
Source§impl Tensor
impl Tensor
Sourcepub fn matmul(&self, other: &Tensor) -> Tensor
pub fn matmul(&self, other: &Tensor) -> Tensor
Matrix multiplication with intelligent kernel dispatch
Performs matrix multiplication using optimized SIMD kernels selected based on:
- Runtime SIMD capability (AVX512/AVX2/SSE2/Scalar)
- Matrix operation type (1D@1D, 1D@2D, 2D@1D, 2D@2D, ND@ND)
- Matrix size classification (Small/Medium/Large)
- Memory alignment characteristics
§Arguments
other - Right-hand side tensor for multiplication
§Returns
Result tensor whose shape depends on the operation type (for example, 2D @ 2D yields a matrix, while 1D @ 1D yields a scalar)
§Panics
Panics if tensor shapes are incompatible for matrix multiplication
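Shapes are compatible when the inner dimensions agree: [m, k] @ [k, n] produces [m, n], so [2, 3] @ [3, 4] is valid while [2, 3] @ [4, 4] panics. A minimal sketch of the valid case:
use train_station::Tensor;
// Shared inner dimension (3) matches, so the product has shape [2, 4]
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let b = Tensor::from_slice(&[1.0; 12], vec![3, 4]).unwrap();
let c = a.matmul(&b);
assert_eq!(c.shape().dims(), vec![2, 4]);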
§Examples
use train_station::Tensor;
// 2D matrix multiplication
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
let result = a.matmul(&b);
assert_eq!(result.shape().dims(), vec![2, 2]);
// 1D dot product
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let result = a.matmul(&b);
assert_eq!(result.shape().dims(), vec![]); // Scalar
Examples found in repository?
More examples
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85 println!("--- Default Adam Configuration ---");
86
87 // Create a simple regression problem: y = 2*x + 1
88 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91 // Create model parameters
92 let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95 // Create Adam optimizer with default configuration
96 let mut optimizer = Adam::new();
97 optimizer.add_parameter(&weight);
98 optimizer.add_parameter(&bias);
99
100 println!("Default Adam configuration:");
101 println!(" Learning rate: {}", optimizer.learning_rate());
102 println!(" Initial weight: {:.6}", weight.value());
103 println!(" Initial bias: {:.6}", bias.value());
104
105 // Training loop
106 let num_epochs = 50;
107 let mut losses = Vec::new();
108
109 for epoch in 0..num_epochs {
110 // Forward pass
111 let y_pred = x_data.matmul(&weight) + &bias;
112 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114 // Backward pass
115 loss.backward(None);
116
117 // Optimizer step
118 optimizer.step(&mut [&mut weight, &mut bias]);
119 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121 losses.push(loss.value());
122
123 if epoch % 10 == 0 || epoch == num_epochs - 1 {
124 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125 }
126 }
127
128 // Evaluate final model
129 let _final_predictions = x_data.matmul(&weight) + &bias;
130 println!("\nFinal model:");
131 println!(" Learned weight: {:.6} (target: 2.0)", weight.value());
132 println!(" Learned bias: {:.6} (target: 1.0)", bias.value());
133 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
134
135 Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140 println!("\n--- Learning Rate Comparison ---");
141
142 let learning_rates = [0.001, 0.01, 0.1];
143 let mut results = Vec::new();
144
145 for &lr in &learning_rates {
146 println!("\nTesting learning rate: {}", lr);
147
148 let stats = train_with_config(TrainingConfig {
149 learning_rate: lr,
150 ..Default::default()
151 })?;
152
153 results.push((lr, stats.clone()));
154
155 println!(" Final loss: {:.6}", stats.final_loss);
156 println!(" Convergence epoch: {}", stats.convergence_epoch);
157 }
158
159 // Compare results
160 println!("\nLearning Rate Comparison Summary:");
161 for (lr, stats) in &results {
162 println!(
163 " LR={:6}: Loss={:.6}, Converged@{}",
164 lr, stats.final_loss, stats.convergence_epoch
165 );
166 }
167
168 Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173 println!("\n--- Weight Decay Comparison ---");
174
175 let weight_decays = [0.0, 0.001, 0.01];
176 let mut results = Vec::new();
177
178 for &wd in &weight_decays {
179 println!("\nTesting weight decay: {}", wd);
180
181 let stats = train_with_config(TrainingConfig {
182 weight_decay: wd,
183 ..Default::default()
184 })?;
185
186 results.push((wd, stats.clone()));
187
188 println!(" Final loss: {:.6}", stats.final_loss);
189 println!(" Final weight norm: {:.6}", stats.weight_norm);
190 }
191
192 // Compare results
193 println!("\nWeight Decay Comparison Summary:");
194 for (wd, stats) in &results {
195 println!(
196 " WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197 wd, stats.final_loss, stats.weight_norm
198 );
199 }
200
201 Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206 println!("\n--- Beta Parameter Tuning ---");
207
208 let beta_configs = [
209 (0.9, 0.999), // Default
210 (0.8, 0.999), // More aggressive momentum
211 (0.95, 0.999), // Less aggressive momentum
212 (0.9, 0.99), // Faster second moment decay
213 ];
214
215 let mut results = Vec::new();
216
217 for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218 println!(
219 "\nTesting beta configuration {}: beta1={}, beta2={}",
220 i + 1,
221 beta1,
222 beta2
223 );
224
225 let config = TrainingConfig {
226 beta1: *beta1,
227 beta2: *beta2,
228 ..Default::default()
229 };
230
231 let stats = train_with_config(config)?;
232 results.push(((*beta1, *beta2), stats.clone()));
233
234 println!(" Final loss: {:.6}", stats.final_loss);
235 println!(" Convergence epoch: {}", stats.convergence_epoch);
236 }
237
238 // Compare results
239 println!("\nBeta Parameter Comparison Summary:");
240 for ((beta1, beta2), stats) in &results {
241 println!(
242 " B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243 beta1, beta2, stats.final_loss, stats.convergence_epoch
244 );
245 }
246
247 Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252 println!("\n--- Configuration Benchmarking ---");
253
254 // Define configurations to benchmark
255 let configs = vec![
256 (
257 "Conservative",
258 TrainingConfig {
259 learning_rate: 0.001,
260 weight_decay: 0.001,
261 beta1: 0.95,
262 ..Default::default()
263 },
264 ),
265 (
266 "Balanced",
267 TrainingConfig {
268 learning_rate: 0.01,
269 weight_decay: 0.0,
270 beta1: 0.9,
271 ..Default::default()
272 },
273 ),
274 (
275 "Aggressive",
276 TrainingConfig {
277 learning_rate: 0.1,
278 weight_decay: 0.0,
279 beta1: 0.8,
280 ..Default::default()
281 },
282 ),
283 ];
284
285 let mut benchmark_results = Vec::new();
286
287 for (name, config) in configs {
288 println!("\nBenchmarking {} configuration:", name);
289
290 let start_time = std::time::Instant::now();
291 let stats = train_with_config(config.clone())?;
292 let elapsed = start_time.elapsed();
293
294 println!(" Training time: {:.2}ms", elapsed.as_millis());
295 println!(" Final loss: {:.6}", stats.final_loss);
296 println!(" Convergence: {} epochs", stats.convergence_epoch);
297
298 benchmark_results.push((name.to_string(), stats, elapsed));
299 }
300
301 // Summary
302 println!("\nBenchmarking Summary:");
303 for (name, stats, elapsed) in &benchmark_results {
304 println!(
305 " {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306 name,
307 stats.final_loss,
308 elapsed.as_millis(),
309 stats.convergence_epoch
310 );
311 }
312
313 Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 // Create training data
319 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 // Create model parameters
323 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 // Create optimizer with custom configuration
327 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 // Training loop
341 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 // Forward pass
346 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 // Backward pass
350 loss.backward(None);
351
352 // Optimizer step
353 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 // Check for convergence (loss < 0.01)
360 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372}
319fn train_with_scheduler(
320 scheduler: &mut dyn LearningRateScheduler,
321 num_epochs: usize,
322) -> Result<TrainingStats, Box<dyn std::error::Error>> {
323 // Create training data: y = 2*x + 1
324 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
325 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
326
327 // Create model parameters
328 let mut weight = Tensor::randn(vec![1, 1], Some(456)).with_requires_grad();
329 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
330
331 // Create optimizer with initial learning rate
332 let mut optimizer = Adam::with_learning_rate(0.05);
333 optimizer.add_parameter(&weight);
334 optimizer.add_parameter(&bias);
335
336 // Training loop
337 let mut losses = Vec::new();
338 let mut lr_history = Vec::new();
339 let mut convergence_epoch = num_epochs;
340
341 for epoch in 0..num_epochs {
342 // Forward pass
343 let y_pred = x_data.matmul(&weight) + &bias;
344 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
345
346 // Backward pass
347 loss.backward(None);
348
349 // Update learning rate using scheduler
350 let current_lr = optimizer.learning_rate();
351 let new_lr = scheduler.step(current_lr, epoch, loss.value());
352
353 if (new_lr - current_lr).abs() > 1e-8 {
354 optimizer.set_learning_rate(new_lr);
355 }
356
357 // Optimizer step
358 optimizer.step(&mut [&mut weight, &mut bias]);
359 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
360
361 let loss_value = loss.value();
362 losses.push(loss_value);
363 lr_history.push(new_lr);
364
365 // Check for convergence
366 if loss_value < 0.01 && convergence_epoch == num_epochs {
367 convergence_epoch = epoch;
368 }
369 }
370
371 Ok(TrainingStats {
372 scheduler_name: scheduler.name().to_string(),
373 final_loss: losses[losses.len() - 1],
374 lr_history,
375 loss_history: losses,
376 convergence_epoch,
377 })
378}
105fn demonstrate_linear_regression() -> Result<(), Box<dyn std::error::Error>> {
106 println!("\n--- Linear Regression Training ---");
107
108 // Create model parameters
109 let mut weight = Tensor::randn(vec![1, 1], Some(43)).with_requires_grad();
110 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
111
112 // Create optimizer
113 let mut optimizer = Adam::with_learning_rate(0.01);
114 optimizer.add_parameter(&weight);
115 optimizer.add_parameter(&bias);
116
117 // Create simple training data: y = 2*x + 1
118 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
119 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
120
121 println!("Training data:");
122 println!(" X: {:?}", x_data.data());
123 println!(" Y: {:?}", y_true.data());
124 println!(" Target: y = 2*x + 1");
125
126 // Training loop
127 let num_epochs = 100;
128 let mut losses = Vec::new();
129
130 for epoch in 0..num_epochs {
131 // Forward pass: y_pred = x * weight + bias
132 let y_pred = x_data.matmul(&weight) + &bias;
133
134 // Compute loss: MSE
135 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
136
137 // Backward pass
138 loss.backward(None);
139
140 // Optimizer step
141 optimizer.step(&mut [&mut weight, &mut bias]);
142 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
143
144 losses.push(loss.value());
145
146 // Print progress every 20 epochs
147 if epoch % 20 == 0 || epoch == num_epochs - 1 {
148 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
149 }
150 }
151
152 // Evaluate final model
153 let final_predictions = x_data.matmul(&weight) + &bias;
154 println!("\nFinal model evaluation:");
155 println!(" Learned weight: {:.6}", weight.value());
156 println!(" Learned bias: {:.6}", bias.value());
157 println!(" Predictions vs True:");
158
159 for i in 0..5 {
160 let x1 = x_data.data()[i];
161 let pred = final_predictions.data()[i];
162 let true_val = y_true.data()[i];
163 println!(
164 " x={:.1}: pred={:.3}, true={:.1}, error={:.3}",
165 x1,
166 pred,
167 true_val,
168 (pred - true_val).abs()
169 );
170 }
171
172 Ok(())
173}
174
175/// Demonstrate advanced training patterns
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177 println!("\n--- Advanced Training Patterns ---");
178
179 // Create a more complex model
180 let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181 let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183 // Create optimizer with different learning rate
184 let mut optimizer = Adam::with_learning_rate(0.005);
185 optimizer.add_parameter(&weight);
186 optimizer.add_parameter(&bias);
187
188 // Create training data: y = 2*x + [1, 3]
189 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190 let y_true = Tensor::from_slice(
191 &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192 vec![5, 2],
193 )
194 .unwrap();
195
196 println!("Advanced training with monitoring:");
197 println!(" Initial learning rate: {}", optimizer.learning_rate());
198
199 // Training loop with monitoring
200 let num_epochs = 50;
201 let mut losses = Vec::new();
202 let mut weight_norms = Vec::new();
203 let mut gradient_norms = Vec::new();
204
205 for epoch in 0..num_epochs {
206 // Forward pass
207 let y_pred = x_data.matmul(&weight) + &bias;
208 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210 // Backward pass
211 loss.backward(None);
212
213 // Compute gradient norm before optimizer step
214 let gradient_norm = weight.grad_owned().unwrap().norm();
215
216 // Optimizer step
217 optimizer.step(&mut [&mut weight, &mut bias]);
218 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220 // Learning rate scheduling: reduce every 10 epochs
221 if epoch > 0 && epoch % 10 == 0 {
222 let current_lr = optimizer.learning_rate();
223 let new_lr = current_lr * 0.5;
224 optimizer.set_learning_rate(new_lr);
225 println!(
226 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227 epoch, current_lr, new_lr
228 );
229 }
230
231 // Record metrics
232 losses.push(loss.value());
233 weight_norms.push(weight.norm().value());
234 gradient_norms.push(gradient_norm.value());
235
236 // Print detailed progress
237 if epoch % 10 == 0 || epoch == num_epochs - 1 {
238 println!(
239 "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240 epoch,
241 loss.value(),
242 weight.norm().value(),
243 gradient_norm.value()
244 );
245 }
246 }
247
248 println!("Final learning rate: {}", optimizer.learning_rate());
249
250 // Analyze training progression
251 let initial_loss = losses[0];
252 let final_loss = losses[losses.len() - 1];
253 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255 println!("\nTraining Analysis:");
256 println!(" Initial loss: {:.6}", initial_loss);
257 println!(" Final loss: {:.6}", final_loss);
258 println!(" Loss reduction: {:.1}%", loss_reduction);
259 println!(" Final weight norm: {:.6}", weight.norm().value());
260 println!(" Final bias: {:?}", bias.data());
261
262 Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267 println!("\n--- Learning Rate Scheduling ---");
268
269 // Create simple model
270 let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273 // Create optimizer with high initial learning rate
274 let mut optimizer = Adam::with_learning_rate(0.1);
275 optimizer.add_parameter(&weight);
276 optimizer.add_parameter(&bias);
277
278 // Simple data
279 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280 let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282 println!("Initial learning rate: {}", optimizer.learning_rate());
283
284 // Training loop with learning rate scheduling
285 let num_epochs = 50;
286 let mut losses = Vec::new();
287
288 for epoch in 0..num_epochs {
289 // Forward pass
290 let y_pred = x_data.matmul(&weight) + &bias;
291 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293 // Backward pass
294 loss.backward(None);
295
296 // Optimizer step
297 optimizer.step(&mut [&mut weight, &mut bias]);
298 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300 // Learning rate scheduling: reduce every 10 epochs
301 if epoch > 0 && epoch % 10 == 0 {
302 let current_lr = optimizer.learning_rate();
303 let new_lr = current_lr * 0.5;
304 optimizer.set_learning_rate(new_lr);
305 println!(
306 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307 epoch, current_lr, new_lr
308 );
309 }
310
311 losses.push(loss.value());
312
313 // Print progress
314 if epoch % 10 == 0 || epoch == num_epochs - 1 {
315 println!(
316 "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317 epoch,
318 loss.value(),
319 optimizer.learning_rate()
320 );
321 }
322 }
323
324 println!("Final learning rate: {}", optimizer.learning_rate());
325
326 Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331 println!("\n--- Training Monitoring ---");
332
333 // Create model
334 let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337 // Create optimizer
338 let mut optimizer = Adam::with_learning_rate(0.01);
339 optimizer.add_parameter(&weight);
340 optimizer.add_parameter(&bias);
341
342 // Training data
343 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346 // Training loop with comprehensive monitoring
347 let num_epochs = 30;
348 let mut losses = Vec::new();
349 let mut weight_history = Vec::new();
350 let mut bias_history = Vec::new();
351
352 for epoch in 0..num_epochs {
353 // Forward pass
354 let y_pred = x_data.matmul(&weight) + &bias;
355 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357 // Backward pass
358 loss.backward(None);
359
360 // Optimizer step
361 optimizer.step(&mut [&mut weight, &mut bias]);
362 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364 // Record history
365 losses.push(loss.value());
366 weight_history.push(weight.value());
367 bias_history.push(bias.value());
368
369 // Print detailed monitoring
370 if epoch % 5 == 0 || epoch == num_epochs - 1 {
371 println!(
372 "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373 epoch,
374 loss.value(),
375 weight.value(),
376 bias.value()
377 );
378 }
379 }
380
381 // Analyze training progression
382 println!("\nTraining Analysis:");
383 println!(" Initial loss: {:.6}", losses[0]);
384 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
385 println!(
386 " Loss reduction: {:.1}%",
387 (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388 );
389
390 // Compute statistics
391 let loss_mean = compute_mean(&losses);
392 let loss_std = compute_std(&losses);
393 let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394 let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396 println!(" Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397 println!(" Weight change: {:.6}", weight_change);
398 println!(" Bias change: {:.6}", bias_change);
399 println!(" Final weight norm: {:.6}", weight.norm().value());
400 println!(" Final bias: {:.6}", bias.value());
401
402 Ok(())
403}
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123}
Source§impl Tensor
impl Tensor
Sourcepub fn mul_tensor(&self, other: &Tensor) -> Tensor
pub fn mul_tensor(&self, other: &Tensor) -> Tensor
Element-wise multiplication with another tensor with broadcasting support.
Performs element-wise multiplication with automatic broadcasting: output[i] = self[i] * other[i]
Broadcasting enables multiplication between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:
- Dimensions are aligned from the rightmost dimension
- Dimensions are compatible if they are equal, or one of them is 1
- Missing dimensions are treated as 1
§Arguments
other - Tensor to multiply. Shapes must be broadcast-compatible.
§Returns
A new tensor containing the element-wise product with broadcast result shape
§Examples
§Same Shape Multiplication
use train_station::Tensor;
let a = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0], vec![3]).unwrap();
let c = a.mul_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 10.0); // 2.0 * 5.0
assert_eq!(c.get(&[1]), 18.0); // 3.0 * 6.0
assert_eq!(c.get(&[2]), 28.0); // 4.0 * 7.0
§Broadcasting Multiplication
use train_station::Tensor;
let a = Tensor::from_slice(&[2.0, 3.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[10.0, 20.0, 30.0], vec![1, 3]).unwrap();
let c = a.mul_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
// Result: [[20.0, 40.0, 60.0], [30.0, 60.0, 90.0]]
assert_eq!(c.get(&[0, 0]), 20.0); // 2.0 * 10.0
assert_eq!(c.get(&[0, 1]), 40.0); // 2.0 * 20.0
assert_eq!(c.get(&[1, 0]), 30.0); // 3.0 * 10.0
§Panics
Panics if tensor shapes are not broadcast-compatible
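Under the rightmost-alignment rule above, a [3] tensor broadcasts across each row of a [2, 3] tensor, while a pairing such as [2, 3] with [4] panics. A minimal sketch of the compatible case:
use train_station::Tensor;
// [2, 3] * [3]: trailing dimensions match, so the [3] values multiply both rows
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let b = Tensor::from_slice(&[10.0, 100.0, 1000.0], vec![3]).unwrap();
let c = a.mul_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[1, 2]), 6000.0); // 6.0 * 1000.0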
Examples found in repository?
More examples
89fn demonstrate_basic_operations() {
90 println!("\n--- Basic Operations ---");
91
92 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95 // Addition
96 let sum = a.add_tensor(&b);
97 println!("A + B: {:?}", sum.data());
98
99 // Subtraction
100 let diff = a.sub_tensor(&b);
101 println!("A - B: {:?}", diff.data());
102
103 // Multiplication
104 let product = a.mul_tensor(&b);
105 println!("A * B: {:?}", product.data());
106
107 // Division
108 let quotient = a.div_tensor(&b);
109 println!("A / B: {:?}", quotient.data());
110
111 // Scalar operations
112 let scalar_add = a.add_scalar(5.0);
113 println!("A + 5.0: {:?}", scalar_add.data());
114
115 let scalar_mul = a.mul_scalar(2.0);
116 println!("A * 2.0: {:?}", scalar_mul.data());
117}
203fn demonstrate_method_equivalence() {
204 println!("\n--- Operator vs Method Call Equivalence ---");
205
206 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
207 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
208
209 // Addition: operator vs method
210 let operator_result = &a + &b;
211 let method_result = a.add_tensor(&b);
212
213 println!("A + B (operator): {:?}", operator_result.data());
214 println!("A.add_tensor(B): {:?}", method_result.data());
215 println!(
216 "Results are equal: {}",
217 operator_result.data() == method_result.data()
218 );
219
220 // Multiplication: operator vs method
221 let operator_result = &a * &b;
222 let method_result = a.mul_tensor(&b);
223
224 println!("A * B (operator): {:?}", operator_result.data());
225 println!("A.mul_tensor(B): {:?}", method_result.data());
226 println!(
227 "Results are equal: {}",
228 operator_result.data() == method_result.data()
229 );
230
231 // Scalar addition: operator vs method
232 let operator_result = &a + 5.0;
233 let method_result = a.add_scalar(5.0);
234
235 println!("A + 5.0 (operator): {:?}", operator_result.data());
236 println!("A.add_scalar(5.0): {:?}", method_result.data());
237 println!(
238 "Results are equal: {}",
239 operator_result.data() == method_result.data()
240 );
241}
208fn demonstrate_advanced_patterns() -> Result<(), Box<dyn std::error::Error>> {
209 println!("\n--- Advanced Iterator Patterns ---");
210
211 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6])?;
212 println!("Input tensor: {:?}", tensor.data());
213
214 // Complex chain: enumerate -> filter -> map -> collect
215 println!("\nComplex chain (even indices only, add index to value):");
216 let result: Tensor = tensor
217 .iter()
218 .enumerate()
219 .filter(|(i, _)| i % 2 == 0) // Take even indices
220 .map(|(i, elem)| elem.add_scalar(i as f32)) // Add index to value
221 .collect();
222 println!(" Result: {:?}", result.data());
223
224 // Using take and skip for windowing
225 println!("\nWindowing with take and skip:");
226 let window1: Tensor = tensor.iter().take(3).collect();
227 let window2: Tensor = tensor.iter().skip(2).take(3).collect();
228 println!(" Window 1 (first 3): {:?}", window1.data());
229 println!(" Window 2 (middle 3): {:?}", window2.data());
230
231 // Using rev() for reverse iteration
232 println!("\nReverse iteration:");
233 let reversed: Tensor = tensor.iter().rev().collect();
234 println!(" Reversed: {:?}", reversed.data());
235
236 // Chaining with mathematical operations
237 println!("\nMathematical operation chain:");
238 let math_result: Tensor = tensor
239 .iter()
240 .map(|elem| elem.exp()) // e^x
241 .filter(|elem| elem.value() < 50.0) // Filter large values
242 .map(|elem| elem.log()) // ln(x)
243 .collect();
244 println!(" Math chain result: {:?}", math_result.data());
245
246 // Using zip for element-wise combinations
247 println!("\nElement-wise combination with zip:");
248 let tensor2 = Tensor::from_slice(&[10.0, 20.0, 30.0, 40.0, 50.0, 60.0], vec![6])?;
249 let combined: Tensor = tensor
250 .iter()
251 .zip(tensor2.iter())
252 .map(|(a, b)| a.mul_tensor(&b)) // Element-wise multiplication
253 .collect();
254 println!(" Combined: {:?}", combined.data());
255
256 Ok(())
257}
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
Sourcepub fn mul_scalar(&self, scalar: f32) -> Tensor
pub fn mul_scalar(&self, scalar: f32) -> Tensor
Broadcast multiplication with a scalar value.
Multiplies every element by the scalar: output[i] = self[i] * scalar
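Because the operation is a uniform element-wise scale, the gradient with respect to each element is simply the scalar. A minimal sketch of that interaction with gradient tracking, assuming the backward/grad API shown in the gradient-tracking example further down this page:
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
let mut loss = x.mul_scalar(3.0).sum();
loss.backward(None);
// d/dx (3 * x) = 3, so every gradient entry should be 3.0
println!("grad: {:?}", x.grad().map(|g| g.data()));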
§Arguments
scalar - Value to multiply with each element
§Returns
A new tensor with each element multiplied by the scalar
§Examples
§Basic Scalar Multiplication
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.mul_scalar(10.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 10.0); // 1.0 * 10.0
assert_eq!(b.get(&[1]), 20.0); // 2.0 * 10.0
assert_eq!(b.get(&[2]), 30.0); // 3.0 * 10.0
§Negative Scalar Multiplication
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.mul_scalar(-2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), -2.0); // 1.0 * -2.0
assert_eq!(b.get(&[1]), -4.0); // 2.0 * -2.0
assert_eq!(b.get(&[2]), -6.0); // 3.0 * -2.0
Examples found in repository?
53 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
54 // Xavier/Glorot initialization: scale by sqrt(1/input_size)
55 let scale = (1.0 / input_size as f32).sqrt();
56
57 let weight = Tensor::randn(vec![input_size, output_size], seed)
58 .mul_scalar(scale)
59 .with_requires_grad();
60 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
61
62 Self {
63 weight,
64 bias,
65 input_size,
66 output_size,
67 }
68 }
More examples
277fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
278 let mut total_sq = 0.0f32;
279 for p in parameters.iter() {
280 if let Some(g) = p.grad_owned() {
281 for &v in g.data() {
282 total_sq += v * v;
283 }
284 }
285 }
286 let norm = total_sq.sqrt();
287 if norm > max_norm {
288 let scale = max_norm / (norm + eps);
289 for p in parameters.iter_mut() {
290 if let Some(g) = p.grad_owned() {
291 p.set_grad(g.mul_scalar(scale));
292 }
293 }
294 }
295}
296
297fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
298 let mut total_sq = 0.0f32;
299 for p in parameters.iter_mut() {
300 if let Some(g) = p.grad_owned() {
301 for &v in g.data() {
302 total_sq += v * v;
303 }
304 }
305 }
306 total_sq.sqrt()
307}
308
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310 let _ng = NoGradTrack::new();
311 let mut total_sq = 0.0f32;
312 for p in parameters.iter_mut() {
313 for &v in p.data() {
314 total_sq += v * v;
315 }
316 }
317 total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322 diff.pow_scalar(2.0)
323 .add_scalar(1.0)
324 .sqrt()
325 .sub_scalar(1.0)
326 .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&high.sub_tensor(&ge_low).relu()) // min(ge_low, high) = high - relu(high - ge_low)
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43fn accuracy(pred: &Tensor, targets: &Tensor) -> f32 {
44 // pred: [B,1] with sigmoid; threshold at 0.5
45 let p = pred.data();
46 let t = targets.data();
47 let mut correct = 0usize;
48 for i in 0..p.len() {
49 let yhat = if p[i] >= 0.5 { 1.0 } else { 0.0 };
50 if (yhat - t[i]).abs() < 1e-6 {
51 correct += 1;
52 }
53 }
54 correct as f32 / (p.len() as f32)
55}
56
57// Numerically stable BCE with logits:
58// L = mean( relu(z) - z*y + log(1 + exp(-|z|)) )
59fn bce_with_logits(logits: &Tensor, targets: &Tensor) -> Tensor {
60 let relu_z = logits.relu();
61 let zy = logits.mul_tensor(targets);
62 // |z| = relu(z) + relu(-z)
63 let abs_z = relu_z.add_tensor(&logits.mul_scalar(-1.0).relu());
64 let log_term = abs_z.mul_scalar(-1.0).exp().add_scalar(1.0).log();
65 relu_z.sub_tensor(&zy).add_tensor(&log_term).mean()
66}
23fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
24 let mut total_sq = 0.0f32;
25 for p in parameters.iter() {
26 if let Some(g) = p.grad_owned() {
27 for &v in g.data() {
28 total_sq += v * v;
29 }
30 }
31 }
32 let norm = total_sq.sqrt();
33 if norm > max_norm {
34 let scale = max_norm / (norm + eps);
35 for p in parameters.iter_mut() {
36 if let Some(g) = p.grad_owned() {
37 p.set_grad(g.mul_scalar(scale));
38 }
39 }
40 }
41}
42
43// Cross-entropy over logits: CE = -mean(log_softmax(logits)[range, labels])
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
22fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
23 let mut total_sq = 0.0f32;
24 for p in parameters.iter() {
25 if let Some(g) = p.grad_owned() {
26 for &v in g.data() {
27 total_sq += v * v;
28 }
29 }
30 }
31 let norm = total_sq.sqrt();
32 if norm > max_norm {
33 let scale = max_norm / (norm + eps);
34 for p in parameters.iter_mut() {
35 if let Some(g) = p.grad_owned() {
36 p.set_grad(g.mul_scalar(scale));
37 }
38 }
39 }
40}
Source§impl Tensor
impl Tensor
Sourcepub fn pow_scalar(&self, exponent: f32) -> Tensor
pub fn pow_scalar(&self, exponent: f32) -> Tensor
Raises each element to a scalar power.
Computes element-wise power: output[i] = self[i]^exponent
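The element-wise rule carries through to gradients (d/dx x^n = n * x^(n-1)). A minimal sketch, assuming the backward/grad API shown in the gradient-tracking example further down this page:
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
let mut loss = x.pow_scalar(3.0).sum();
loss.backward(None);
// d/dx sum(x^3) = 3 * x^2, so the expected gradient is [3.0, 12.0, 27.0]
println!("grad: {:?}", x.grad().map(|g| g.data()));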
§Arguments
exponent - The scalar exponent to raise each element to
§Returns
A new tensor with each element raised to the given power
§Examples
§Basic Scalar Power
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.pow_scalar(2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // 1.0^2 = 1.0
assert_eq!(b.get(&[1]), 4.0); // 2.0^2 = 4.0
assert_eq!(b.get(&[2]), 9.0); // 3.0^2 = 9.0
§Square Root (Power 0.5)
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
let b = a.pow_scalar(0.5);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
Examples found in repository?
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43 pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47 mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51 // R^2 = 1 - SS_res / SS_tot
52 let y = target;
53 let y_mean = y.mean();
54 let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55 let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56 let ss_res_v = ss_res.value();
57 let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58 1.0 - (ss_res_v / ss_tot_v)
59}
More examples
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2*pi) term of the Gaussian log-density
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&clip_high.sub_tensor(&ratio_ge_low).relu()); // min(ratio_ge_low, clip_high)
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
148 pub fn train_non_autoregressive_steps(
149 &mut self,
150 src: &Tensor,
151 tgt: &Tensor,
152 steps: usize,
153 lr: f32,
154 ) {
155 let mut opt = Adam::with_learning_rate(lr);
156 {
157 let params_once = self.parameters();
158 for p in &params_once {
159 opt.add_parameter(p);
160 }
161 }
162 for step in 0..steps {
163 // forward + backward scope (immutable borrow)
164 {
165 let pred = self.forward(src, tgt);
166 let diff = pred.sub_tensor(tgt);
167 let mut loss = diff.pow_scalar(2.0).mean();
168 if step == 0 || step + 1 == steps {
169 println!("NAR train step {}: loss={:.6}", step, loss.value());
170 }
171 loss.backward(None);
172 }
173 // step + zero_grad scope (mutable borrow)
174 let mut params_step = self.parameters();
175 opt.step(&mut params_step);
176 opt.zero_grad(&mut params_step);
177 }
178 }
179
180 /// Auto-regressive training (teacher forcing): predict next token with causal mask
181 pub fn train_autoregressive_steps(
182 &mut self,
183 src: &Tensor,
184 tgt: &Tensor,
185 steps: usize,
186 lr: f32,
187 ) {
188 let mut opt = Adam::with_learning_rate(lr);
189 {
190 let params_once = self.parameters();
191 for p in &params_once {
192 opt.add_parameter(p);
193 }
194 }
195
196 // Build encoder memory once (static dataset demo)
197 let mut memory = src.clone();
198 for enc in &self.encoders {
199 memory = enc.forward(&memory, None);
200 }
201
202 let (b, t, _e) = Self::triple(tgt);
203 // Predict y[t] from y[:t] using causal mask; here we simply predict full seq with mask
204 let causal = Self::build_causal_mask_static(b, self.num_heads, t);
205 for step in 0..steps {
206 // forward + backward scope
207 {
208 let mut out = tgt.clone();
209 for dec in &self.decoders {
210 out = dec.forward(&out, &memory, Some(&causal), None);
211 }
212 let diff = out.sub_tensor(tgt);
213 let mut loss = diff.pow_scalar(2.0).mean();
214 if step == 0 || step + 1 == steps {
215 println!("AR train step {}: loss={:.6}", step, loss.value());
216 }
217 loss.backward(None);
218 }
219 let mut params_step = self.parameters();
220 opt.step(&mut params_step);
221 opt.zero_grad(&mut params_step);
222 }
223 }
131fn demonstrate_standard_methods() -> Result<(), Box<dyn std::error::Error>> {
132 println!("\n--- Standard Iterator Methods ---");
133
134 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
135
136 // Using map for transformations
137 println!("\nMap transformation (square each element):");
138 let squared: Tensor = tensor.iter().map(|elem| elem.pow_scalar(2.0)).collect();
139 println!(" Squared: {:?}", squared.data());
140
141 // Using enumerate for indexed operations
142 println!("\nEnumerate with indexed operations:");
143 let indexed: Tensor = tensor
144 .iter()
145 .enumerate()
146 .map(|(i, elem)| elem.add_scalar(i as f32))
147 .collect();
148 println!(" Indexed: {:?}", indexed.data());
149
150 // Using fold for reduction
151 println!("\nFold for sum calculation:");
152 let sum: f32 = tensor.iter().fold(0.0, |acc, elem| acc + elem.value());
153 println!(" Sum: {:.1}", sum);
154
155 // Using find for element search
156 println!("\nFind specific element:");
157 if let Some(found) = tensor.iter().find(|elem| elem.value() == 3.0) {
158 println!(" Found element with value 3.0: {:.1}", found.value());
159 }
160
161 // Using any/all for condition checking
162 println!("\nCondition checking:");
163 let all_positive = tensor.iter().all(|elem| elem.value() > 0.0);
164 let any_large = tensor.iter().any(|elem| elem.value() > 4.0);
165 println!(" All positive: {}", all_positive);
166 println!(" Any > 4.0: {}", any_large);
167
168 Ok(())
169}
170
171/// Demonstrate gradient tracking through element operations
172///
173/// Shows how gradient tracking works seamlessly through iterator
174/// operations, maintaining the computational graph for backpropagation.
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
84fn demonstrate_default_adam() -> Result<(), Box<dyn std::error::Error>> {
85 println!("--- Default Adam Configuration ---");
86
87 // Create a simple regression problem: y = 2*x + 1
88 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
89 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
90
91 // Create model parameters
92 let mut weight = Tensor::randn(vec![1, 1], Some(42)).with_requires_grad();
93 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
94
95 // Create Adam optimizer with default configuration
96 let mut optimizer = Adam::new();
97 optimizer.add_parameter(&weight);
98 optimizer.add_parameter(&bias);
99
100 println!("Default Adam configuration:");
101 println!(" Learning rate: {}", optimizer.learning_rate());
102 println!(" Initial weight: {:.6}", weight.value());
103 println!(" Initial bias: {:.6}", bias.value());
104
105 // Training loop
106 let num_epochs = 50;
107 let mut losses = Vec::new();
108
109 for epoch in 0..num_epochs {
110 // Forward pass
111 let y_pred = x_data.matmul(&weight) + &bias;
112 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
113
114 // Backward pass
115 loss.backward(None);
116
117 // Optimizer step
118 optimizer.step(&mut [&mut weight, &mut bias]);
119 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
120
121 losses.push(loss.value());
122
123 if epoch % 10 == 0 || epoch == num_epochs - 1 {
124 println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
125 }
126 }
127
128 // Evaluate final model
129 let _final_predictions = x_data.matmul(&weight) + &bias;
130 println!("\nFinal model:");
131 println!(" Learned weight: {:.6} (target: 2.0)", weight.value());
132 println!(" Learned bias: {:.6} (target: 1.0)", bias.value());
133 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
134
135 Ok(())
136}
137
138/// Demonstrate learning rate comparison
139fn demonstrate_learning_rate_comparison() -> Result<(), Box<dyn std::error::Error>> {
140 println!("\n--- Learning Rate Comparison ---");
141
142 let learning_rates = [0.001, 0.01, 0.1];
143 let mut results = Vec::new();
144
145 for &lr in &learning_rates {
146 println!("\nTesting learning rate: {}", lr);
147
148 let stats = train_with_config(TrainingConfig {
149 learning_rate: lr,
150 ..Default::default()
151 })?;
152
153 results.push((lr, stats.clone()));
154
155 println!(" Final loss: {:.6}", stats.final_loss);
156 println!(" Convergence epoch: {}", stats.convergence_epoch);
157 }
158
159 // Compare results
160 println!("\nLearning Rate Comparison Summary:");
161 for (lr, stats) in &results {
162 println!(
163 " LR={:6}: Loss={:.6}, Converged@{}",
164 lr, stats.final_loss, stats.convergence_epoch
165 );
166 }
167
168 Ok(())
169}
170
171/// Demonstrate weight decay comparison
172fn demonstrate_weight_decay_comparison() -> Result<(), Box<dyn std::error::Error>> {
173 println!("\n--- Weight Decay Comparison ---");
174
175 let weight_decays = [0.0, 0.001, 0.01];
176 let mut results = Vec::new();
177
178 for &wd in &weight_decays {
179 println!("\nTesting weight decay: {}", wd);
180
181 let stats = train_with_config(TrainingConfig {
182 weight_decay: wd,
183 ..Default::default()
184 })?;
185
186 results.push((wd, stats.clone()));
187
188 println!(" Final loss: {:.6}", stats.final_loss);
189 println!(" Final weight norm: {:.6}", stats.weight_norm);
190 }
191
192 // Compare results
193 println!("\nWeight Decay Comparison Summary:");
194 for (wd, stats) in &results {
195 println!(
196 " WD={:6}: Loss={:.6}, Weight Norm={:.6}",
197 wd, stats.final_loss, stats.weight_norm
198 );
199 }
200
201 Ok(())
202}
203
204/// Demonstrate beta parameter tuning
205fn demonstrate_beta_parameter_tuning() -> Result<(), Box<dyn std::error::Error>> {
206 println!("\n--- Beta Parameter Tuning ---");
207
208 let beta_configs = [
209 (0.9, 0.999), // Default
210 (0.8, 0.999), // More aggressive momentum
211 (0.95, 0.999), // Less aggressive momentum
212 (0.9, 0.99), // Faster second moment decay
213 ];
214
215 let mut results = Vec::new();
216
217 for (i, (beta1, beta2)) in beta_configs.iter().enumerate() {
218 println!(
219 "\nTesting beta configuration {}: beta1={}, beta2={}",
220 i + 1,
221 beta1,
222 beta2
223 );
224
225 let config = TrainingConfig {
226 beta1: *beta1,
227 beta2: *beta2,
228 ..Default::default()
229 };
230
231 let stats = train_with_config(config)?;
232 results.push(((*beta1, *beta2), stats.clone()));
233
234 println!(" Final loss: {:.6}", stats.final_loss);
235 println!(" Convergence epoch: {}", stats.convergence_epoch);
236 }
237
238 // Compare results
239 println!("\nBeta Parameter Comparison Summary:");
240 for ((beta1, beta2), stats) in &results {
241 println!(
242 " B1={:4}, B2={:5}: Loss={:.6}, Converged@{}",
243 beta1, beta2, stats.final_loss, stats.convergence_epoch
244 );
245 }
246
247 Ok(())
248}
249
250/// Demonstrate configuration benchmarking
251fn demonstrate_configuration_benchmarking() -> Result<(), Box<dyn std::error::Error>> {
252 println!("\n--- Configuration Benchmarking ---");
253
254 // Define configurations to benchmark
255 let configs = vec![
256 (
257 "Conservative",
258 TrainingConfig {
259 learning_rate: 0.001,
260 weight_decay: 0.001,
261 beta1: 0.95,
262 ..Default::default()
263 },
264 ),
265 (
266 "Balanced",
267 TrainingConfig {
268 learning_rate: 0.01,
269 weight_decay: 0.0,
270 beta1: 0.9,
271 ..Default::default()
272 },
273 ),
274 (
275 "Aggressive",
276 TrainingConfig {
277 learning_rate: 0.1,
278 weight_decay: 0.0,
279 beta1: 0.8,
280 ..Default::default()
281 },
282 ),
283 ];
284
285 let mut benchmark_results = Vec::new();
286
287 for (name, config) in configs {
288 println!("\nBenchmarking {} configuration:", name);
289
290 let start_time = std::time::Instant::now();
291 let stats = train_with_config(config.clone())?;
292 let elapsed = start_time.elapsed();
293
294 println!(" Training time: {:.2}ms", elapsed.as_millis());
295 println!(" Final loss: {:.6}", stats.final_loss);
296 println!(" Convergence: {} epochs", stats.convergence_epoch);
297
298 benchmark_results.push((name.to_string(), stats, elapsed));
299 }
300
301 // Summary
302 println!("\nBenchmarking Summary:");
303 for (name, stats, elapsed) in &benchmark_results {
304 println!(
305 " {:12}: Loss={:.6}, Time={:4}ms, Converged@{}",
306 name,
307 stats.final_loss,
308 elapsed.as_millis(),
309 stats.convergence_epoch
310 );
311 }
312
313 Ok(())
314}
315
316/// Helper function to train with specific configuration
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 // Create training data
319 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 // Create model parameters
323 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 // Create optimizer with custom configuration
327 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 // Training loop
341 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 // Forward pass
346 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 // Backward pass
350 loss.backward(None);
351
352 // Optimizer step
353 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 // Check for convergence (loss < 0.01)
360 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372 }
- examples/optimizers/learning_rate_scheduling.rs
- examples/getting_started/optimizer_basics.rs
- examples/iterators/performance_optimization.rs
- examples/neural_networks/feedforward_network.rs
- examples/iterators/advanced_patterns.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/td3.rs
Sourcepub fn pow_tensor(&self, exponent: &Tensor) -> Tensor
pub fn pow_tensor(&self, exponent: &Tensor) -> Tensor
Element-wise power with tensor exponents.
Computes element-wise power: output[i] = self[i]^exponent[i]
§Arguments
exponent - Tensor of exponents, must have the same shape as self
§Returns
A new tensor with each element raised to the corresponding power
§Examples
§Basic Tensor Power
use train_station::Tensor;
let base = Tensor::from_slice(&[2.0, 3.0, 4.0], vec![3]).unwrap();
let exp = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let result = base.pow_tensor(&exp);
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 2.0); // 2.0^1.0 = 2.0
assert_eq!(result.get(&[1]), 9.0); // 3.0^2.0 = 9.0
assert_eq!(result.get(&[2]), 64.0); // 4.0^3.0 = 64.0
§Mixed Exponents
use train_station::Tensor;
let base = Tensor::from_slice(&[4.0, 9.0, 16.0], vec![3]).unwrap();
let exp = Tensor::from_slice(&[0.5, 1.0, 2.0], vec![3]).unwrap();
let result = base.pow_tensor(&exp);
assert_eq!(result.shape().dims(), vec![3]);
assert_eq!(result.get(&[0]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(result.get(&[1]), 9.0); // 9.0^1.0 = 9.0
assert_eq!(result.get(&[2]), 256.0); // 16.0^2.0 = 256.0
§Panics
Panics if tensor shapes don’t match
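The training examples elsewhere in this documentation differentiate through pow_scalar, so pow_tensor can be expected to participate in gradtrack the same way. A hedged sketch (assuming gradient support for pow_tensor and the accessors with_requires_grad, backward, and grad_owned shown in those examples):
use train_station::Tensor;
let base = Tensor::from_slice(&[2.0, 3.0], vec![2]).unwrap().with_requires_grad();
let exp = Tensor::from_slice(&[2.0, 2.0], vec![2]).unwrap();
let mut loss = base.pow_tensor(&exp).mean();
loss.backward(None);
// d(mean(base^2))/d(base_i) = base_i, so the gradient should equal the base values
if let Some(grad) = base.grad_owned() {
    println!("gradient: {:?}", grad.data());
}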
Source§impl Tensor
impl Tensor
Sourcepub fn relu(&self) -> Tensor
pub fn relu(&self) -> Tensor
Element-wise ReLU (Rectified Linear Unit) activation.
Applies ReLU to each element: output[i] = max(0, self[i])
§Returns
A new tensor with ReLU applied to each element
§Examples
§Basic ReLU Activation
use train_station::Tensor;
let a = Tensor::from_slice(&[-1.0, 0.0, 2.5], vec![3]).unwrap();
let b = a.relu();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // max(0, -1.0) = 0.0
assert_eq!(b.get(&[1]), 0.0); // max(0, 0.0) = 0.0
assert_eq!(b.get(&[2]), 2.5); // max(0, 2.5) = 2.5
§Mixed Positive and Negative Values
use train_station::Tensor;
let a = Tensor::from_slice(&[-5.0, -0.1, 0.0, 0.1, 5.0], vec![5]).unwrap();
let b = a.relu();
assert_eq!(b.shape().dims(), vec![5]);
assert_eq!(b.get(&[0]), 0.0); // max(0, -5.0) = 0.0
assert_eq!(b.get(&[1]), 0.0); // max(0, -0.1) = 0.0
assert_eq!(b.get(&[2]), 0.0); // max(0, 0.0) = 0.0
assert_eq!(b.get(&[3]), 0.1); // max(0, 0.1) = 0.1
assert_eq!(b.get(&[4]), 5.0); // max(0, 5.0) = 5.0
Examples found in repository?
More examples
68 fn forward(&self, input: &Tensor) -> Tensor {
69 let mut current: Option<Tensor> = None;
70 for (i, layer) in self.layers.iter().enumerate() {
71 let out = if i == 0 {
72 layer.forward(input)
73 } else {
74 layer.forward(current.as_ref().unwrap())
75 };
76 let is_last = i + 1 == self.layers.len();
77 let out = if !is_last { out.relu() } else { out };
78 current = Some(out);
79 }
80 current.expect("MLP has at least one layer")
81 }
82 fn parameters(&mut self) -> Vec<&mut Tensor> {
83 self.layers
84 .iter_mut()
85 .flat_map(|l| l.parameters())
86 .collect()
87 }
88}
89
90// -------------------------------
91// Actor: mean = MLP(state); log_std is a learnable parameter tensor
92// -------------------------------
93
94struct Actor {
95 net: Mlp,
96 log_std: Tensor, // shape [action_dim]
97}
98impl Actor {
99 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
100 let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
101 let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
102 .unwrap()
103 .with_requires_grad();
104 Self { net, log_std }
105 }
106 fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
107 // Returns (mean [B, A], log_std [A])
108 let mean = self.net.forward(state);
109 (
110 mean,
111 self.log_std
112 .view(vec![1, self.log_std.shape().dims()[0] as i32]),
113 )
114 }
115 fn parameters(&mut self) -> Vec<&mut Tensor> {
116 let mut ps = self.net.parameters();
117 ps.push(&mut self.log_std);
118 ps
119 }
120}
121
122// -------------------------------
123// Critic: value function V(s)
124// -------------------------------
125
126struct Critic {
127 net: Mlp,
128}
129impl Critic {
130 fn new(state_dim: usize, seed: Option<u64>) -> Self {
131 Self {
132 net: Mlp::new(&[state_dim, 64, 64, 1], seed),
133 }
134 }
135 fn forward(&self, state: &Tensor) -> Tensor {
136 self.net.forward(state)
137 }
138 fn parameters(&mut self) -> Vec<&mut Tensor> {
139 self.net.parameters()
140 }
141}
142
143// -------------------------------
144// Continuous YardEnv (same dynamics as TD3 env)
145// -------------------------------
146
147struct YardEnv {
148 pos: f32,
149 vel: f32,
150 steps: usize,
151 max_steps: usize,
152 rng: SmallRng,
153}
154impl YardEnv {
155 fn new(seed: u64) -> Self {
156 let mut e = Self {
157 pos: 0.0,
158 vel: 0.0,
159 steps: 0,
160 max_steps: 200,
161 rng: SmallRng::new(seed),
162 };
163 e.reset();
164 e
165 }
166 fn reset(&mut self) -> Tensor {
167 self.pos = (self.rng.next_f32() * 1.0) - 0.5;
168 self.vel = (self.rng.next_f32() * 0.2) - 0.1;
169 self.steps = 0;
170 self.state_tensor()
171 }
172 fn state_tensor(&self) -> Tensor {
173 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
174 }
175 fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
176 let a = action_value.clamp(-1.0, 1.0);
177 self.vel += 0.1 * a - 0.01 * self.pos;
178 self.pos += self.vel;
179 self.steps += 1;
180 let reward = -(self.pos * self.pos) - 0.1 * (a * a);
181 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
182 (self.state_tensor(), reward, done)
183 }
184}
185
186// -------------------------------
187// Trajectory storage
188// -------------------------------
189
190struct RolloutBatch {
191 states: Vec<f32>,
192 actions: Vec<f32>,
193 log_probs: Vec<f32>,
194 rewards: Vec<f32>,
195 dones: Vec<f32>,
196 values: Vec<f32>,
197 next_states: Vec<f32>,
198 _state_dim: usize,
199}
200impl RolloutBatch {
201 fn new(capacity: usize, state_dim: usize) -> Self {
202 Self {
203 states: Vec::with_capacity(capacity * state_dim),
204 actions: Vec::with_capacity(capacity),
205 log_probs: Vec::with_capacity(capacity),
206 rewards: Vec::with_capacity(capacity),
207 dones: Vec::with_capacity(capacity),
208 values: Vec::with_capacity(capacity),
209 next_states: Vec::with_capacity(capacity * state_dim),
210 _state_dim: state_dim,
211 }
212 }
213
214 #[allow(clippy::too_many_arguments)]
215 fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
216 self.states.extend_from_slice(s);
217 self.actions.push(a);
218 self.log_probs.push(lp);
219 self.rewards.push(r);
220 self.dones.push(d);
221 self.values.push(v);
222 self.next_states.extend_from_slice(s2);
223 }
224
225 fn len(&self) -> usize {
226 self.actions.len()
227 }
228}
229
230// -------------------------------
231// Math helpers
232// -------------------------------
233
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI)
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602 }
60 fn forward(&self, input: &Tensor) -> Tensor {
61 let mut current: Option<Tensor> = None;
62 for (i, layer) in self.layers.iter().enumerate() {
63 let out = if i == 0 {
64 layer.forward(input)
65 } else {
66 layer.forward(current.as_ref().unwrap())
67 };
68 let is_last = i + 1 == self.layers.len();
69 let out = if !is_last { out.relu() } else { out };
70 current = Some(out);
71 }
72 current.expect("MLP has at least one layer")
73 }
74 fn parameters(&mut self) -> Vec<&mut Tensor> {
75 self.layers
76 .iter_mut()
77 .flat_map(|l| l.parameters())
78 .collect()
79 }
80}
81
82// -------------------------------
83// Actor (logits) + Critic
84// -------------------------------
85
86struct Actor {
87 net: Mlp,
88}
89impl Actor {
90 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
91 Self {
92 net: Mlp::new(&[state_dim, 64, 64, action_dim], seed),
93 }
94 }
95 fn forward(&self, state: &Tensor) -> Tensor {
96 self.net.forward(state)
97 } // logits [B, A]
98 fn parameters(&mut self) -> Vec<&mut Tensor> {
99 self.net.parameters()
100 }
101}
102
103struct Critic {
104 net: Mlp,
105}
106impl Critic {
107 fn new(state_dim: usize, seed: Option<u64>) -> Self {
108 Self {
109 net: Mlp::new(&[state_dim, 64, 64, 1], seed),
110 }
111 }
112 fn forward(&self, state: &Tensor) -> Tensor {
113 self.net.forward(state)
114 }
115 fn parameters(&mut self) -> Vec<&mut Tensor> {
116 self.net.parameters()
117 }
118}
119
120// -------------------------------
121// Discrete YardEnv (3 actions -> -1, 0, +1)
122// -------------------------------
123
124struct YardEnv {
125 pos: f32,
126 vel: f32,
127 steps: usize,
128 max_steps: usize,
129 rng: SmallRng,
130}
131impl YardEnv {
132 const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];
133 fn new(seed: u64) -> Self {
134 let mut e = Self {
135 pos: 0.0,
136 vel: 0.0,
137 steps: 0,
138 max_steps: 200,
139 rng: SmallRng::new(seed),
140 };
141 e.reset();
142 e
143 }
144 fn reset(&mut self) -> Tensor {
145 self.pos = (self.rng.next_f32() * 1.0) - 0.5;
146 self.vel = (self.rng.next_f32() * 0.2) - 0.1;
147 self.steps = 0;
148 self.state_tensor()
149 }
150 fn state_tensor(&self) -> Tensor {
151 Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
152 }
153 fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
154 let a = Self::ACTIONS[action_idx.min(2)];
155 self.vel += 0.1 * a - 0.01 * self.pos;
156 self.pos += self.vel;
157 self.steps += 1;
158 let reward = -(self.pos * self.pos) - 0.05 * (a * a);
159 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
160 (self.state_tensor(), reward, done)
161 }
162}
163
164// -------------------------------
165// Rollout storage
166// -------------------------------
167
168struct RolloutBatch {
169 states: Vec<f32>,
170 actions: Vec<usize>,
171 old_logps: Vec<f32>,
172 rewards: Vec<f32>,
173 dones: Vec<f32>,
174 values: Vec<f32>,
175 next_states: Vec<f32>,
176 _state_dim: usize,
177}
178impl RolloutBatch {
179 fn new(cap: usize, sd: usize) -> Self {
180 Self {
181 states: Vec::with_capacity(cap * sd),
182 actions: Vec::with_capacity(cap),
183 old_logps: Vec::with_capacity(cap),
184 rewards: Vec::with_capacity(cap),
185 dones: Vec::with_capacity(cap),
186 values: Vec::with_capacity(cap),
187 next_states: Vec::with_capacity(cap * sd),
188 _state_dim: sd,
189 }
190 }
191 #[allow(clippy::too_many_arguments)]
192 fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
193 self.states.extend_from_slice(s);
194 self.actions.push(a);
195 self.old_logps.push(lp);
196 self.rewards.push(r);
197 self.dones.push(d);
198 self.values.push(v);
199 self.next_states.extend_from_slice(s2);
200 }
201 fn len(&self) -> usize {
202 self.actions.len()
203 }
204}
205
206// -------------------------------
207// Helpers
208// -------------------------------
209
210#[allow(clippy::too_many_arguments)]
211fn compute_gae(
212 returns_out: &mut [f32],
213 adv_out: &mut [f32],
214 rewards: &[f32],
215 dones: &[f32],
216 values: &[f32],
217 next_values: &[f32],
218 gamma: f32,
219 lam: f32,
220) {
221 let n = rewards.len();
222 let mut gae = 0.0f32;
223 for t in (0..n).rev() {
224 let not_done = 1.0 - dones[t];
225 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
226 gae = delta + gamma * lam * not_done * gae;
227 adv_out[t] = gae;
228 returns_out[t] = gae + values[t];
229 }
230}
231
232fn normalize_in_place(x: &mut [f32], eps: f32) {
233 let n = x.len() as f32;
234 if n <= 1.0 {
235 return;
236 }
237 let mean = x.iter().copied().sum::<f32>() / n;
238 let var = x
239 .iter()
240 .map(|v| {
241 let d = v - mean;
242 d * d
243 })
244 .sum::<f32>()
245 / n;
246 let std = (var + eps).sqrt();
247 for v in x.iter_mut() {
248 *v = (*v - mean) / std;
249 }
250}
251
252fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
253 let mut total_sq = 0.0f32;
254 for p in parameters.iter() {
255 if let Some(g) = p.grad_owned() {
256 for &v in g.data() {
257 total_sq += v * v;
258 }
259 }
260 }
261 let norm = total_sq.sqrt();
262 if norm > max_norm {
263 let scale = max_norm / (norm + eps);
264 for p in parameters.iter_mut() {
265 if let Some(g) = p.grad_owned() {
266 p.set_grad(g.mul_scalar(scale));
267 }
268 }
269 }
270}
271
272// log-softmax for selected actions: given logits [B,A] and actions Vec<usize> -> log_prob [B,1]
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579 }
53 pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54 let attn = self.mha.forward(input, input, input, attn_mask);
55 let res1 = attn.add_tensor(input);
56
57 // Feed-forward network with ReLU and residual
58 let (b, t, e) = Self::triple(input);
59 let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60 let hidden = self.ffn_in.forward(&x2d).relu();
61 let out2d = self.ffn_out.forward(&hidden);
62 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63 out.add_tensor(&res1)
64 }
71 fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
72 let mut current: Option<Tensor> = None;
73 for (i, layer) in self.layers.iter().enumerate() {
74 let out = if i == 0 {
75 layer.forward(input)
76 } else {
77 layer.forward(current.as_ref().unwrap())
78 };
79 let is_last = i + 1 == self.layers.len();
80 let out = if !is_last {
81 out.relu()
82 } else if let Some(act) = final_activation {
83 act(&out)
84 } else {
85 out
86 };
87 current = Some(out);
88 }
89 current.expect("MLP has at least one layer")
90 }
Source§impl Tensor
impl Tensor
Sourcepub fn sigmoid(&self) -> Tensor
pub fn sigmoid(&self) -> Tensor
Element-wise sigmoid activation function
Computes the sigmoid function for each element: output[i] = 1 / (1 + e^(-self[i]))
Uses a numerically stable implementation that avoids overflow for large positive/negative values by using different computation paths for positive and negative inputs.
§Returns
A new tensor with sigmoid applied to each element, values in range (0, 1)
§Performance Characteristics
- Numerical Stability: Avoids overflow using stable implementation
- Scalar Implementation: Optimized scalar computation for mathematical accuracy
- Cache-friendly: Linear memory access patterns
- Mathematical Accuracy: High-precision exponential and division operations
- Gradient Tracking: Full gradtrack support with efficient gradient computation
§Implementation Details
Uses a numerically stable implementation:
- For x ≥ 0: computes 1 / (1 + e^(-x)) to avoid overflow in e^x for large positive x
- For x < 0: computes e^x / (1 + e^x) to avoid overflow in e^(-x) for large negative x
This ensures the result is always in the range (0, 1) without numerical overflow.
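For readers who want the branch selection spelled out, here is a minimal standalone sketch of the stable formula described above (illustrative only; this is not the crate's internal code):
fn stable_sigmoid(x: f32) -> f32 {
    if x >= 0.0 {
        // e^(-x) <= 1 for non-negative x, so the denominator cannot overflow
        1.0 / (1.0 + (-x).exp())
    } else {
        // e^x <= 1 for negative x; dividing by (1 + e^x) keeps the result in (0, 1)
        let e = x.exp();
        e / (1.0 + e)
    }
}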
§Examples
§Basic Sigmoid Activation
use train_station::Tensor;
let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.sigmoid();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - 0.26894143).abs() < 1e-6); // sigmoid(-1.0)
assert!((b.get(&[1]) - 0.5).abs() < 1e-6); // sigmoid(0.0)
assert!((b.get(&[2]) - 0.7310586).abs() < 1e-6); // sigmoid(1.0)
§Extreme Values
use train_station::Tensor;
let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
let b = a.sigmoid();
assert_eq!(b.shape().dims(), vec![2]);
assert!(b.get(&[0]) < 1e-4); // sigmoid(-10.0) ≈ 0
assert!(b.get(&[1]) > 0.9999); // sigmoid(10.0) ≈ 1
Examples found in repository?
68 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
69 println!("=== Supervised FFN Example (XOR) ===");
70
71 // Dataset: XOR (repeat to form a small batch)
72 let inputs: Vec<f32> = vec![
73 0.0, 0.0, // -> 0
74 0.0, 1.0, // -> 1
75 1.0, 0.0, // -> 1
76 1.0, 1.0, // -> 0
77 ];
78 let targets: Vec<f32> = vec![0.0, 1.0, 1.0, 0.0];
79
80 // Repeat the base patterns to stabilize training
81 let repeats = 64usize; // effective batch = 4 * repeats = 256
82 let mut xs = Vec::with_capacity(repeats * inputs.len());
83 let mut ys = Vec::with_capacity(repeats * targets.len());
84 for _ in 0..repeats {
85 xs.extend_from_slice(&inputs);
86 ys.extend_from_slice(&targets);
87 }
88
89 let batch = xs.len() / 2; // two features
90 let x_t = Tensor::from_slice(&xs, vec![batch, 2]).unwrap();
91 let y_t = Tensor::from_slice(&ys, vec![batch, 1]).unwrap();
92
93 // Model config: 2 -> 32 -> 32 -> 1, final sigmoid via loss path
94 let cfg = FeedForwardConfig {
95 input_size: 2,
96 hidden_sizes: vec![32, 32],
97 output_size: 1,
98 use_bias: true,
99 };
100 let mut net = FeedForwardNetwork::new(cfg, Some(777));
101
102 // Optimizer and parameter linking
103 let mut opt = Adam::with_learning_rate(1e-3);
104 for p in net.parameters() {
105 opt.add_parameter(p);
106 }
107
108 let epochs = 1000usize;
109 let max_grad_norm = 1.0f32;
110 let mut best_loss = f32::INFINITY;
111 let mut best_acc = 0.0f32;
112
113 for e in 0..epochs {
114 // Zero grads each iteration
115 {
116 let mut params = net.parameters();
117 opt.zero_grad(&mut params);
118 }
119
120 // Forward -> logits; use numerically stable BCE-with-logits for loss
121 let logits = net.forward(&x_t);
122 let mut loss = bce_with_logits(&logits, &y_t);
123 loss.backward(None);
124
125 // Step only params with grads
126 {
127 let params = net.parameters();
128 let mut with_grads: Vec<&mut Tensor> = Vec::new();
129 for p in params {
130 if p.grad_owned().is_some() {
131 with_grads.push(p);
132 }
133 }
134 if !with_grads.is_empty() {
135 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
136 opt.step(&mut with_grads);
137 opt.zero_grad(&mut with_grads);
138 }
139 }
140
141 // Metrics (use sigmoid only for reporting accuracy)
142 let preds = logits.sigmoid();
143 let acc = accuracy(&preds, &y_t);
144 if loss.value() < best_loss {
145 best_loss = loss.value();
146 }
147 if acc > best_acc {
148 best_acc = acc;
149 }
150 if e % 10 == 0 || e + 1 == epochs {
151 println!(
152 "epoch {:4} | loss={:.5} acc={:.3} | best_loss={:.5} best_acc={:.3}",
153 e,
154 loss.value(),
155 acc,
156 best_loss,
157 best_acc
158 );
159 }
160
161 // Clear graphs to avoid stale accumulation across epochs
162 clear_all_graphs_known();
163 }
164
165 // Quick sanity check predictions
166 let test = Tensor::from_slice(&inputs, vec![4, 2]).unwrap();
167 let out = net.forward(&test).sigmoid();
168 println!("predictions (approx): {:?}", out.data());
169
170 println!("=== Supervised training finished ===");
171 Ok(())
172 }
Source§impl Tensor
impl Tensor
Sourcepub fn softmax(&self, dim: usize) -> Tensor
pub fn softmax(&self, dim: usize) -> Tensor
Computes softmax activation along the specified dimension
Applies the softmax function along dimension dim, transforming values into
probabilities that sum to 1 along that dimension. Uses numerically stable
computation to avoid overflow: softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
§Arguments
dim - Dimension along which to compute softmax (0-based indexing)
§Returns
A new tensor with softmax applied along the specified dimension.
Values are in range (0, 1) and sum to 1 along dim.
§Performance Characteristics
- Numerical Stability: Avoids overflow using max subtraction technique
- Scalar Implementation: Optimized scalar computation for mathematical accuracy
- Cache-friendly: Optimized memory access patterns for dimension operations
- Mathematical Accuracy: High-precision exponential and division operations
- GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details
Uses a numerically stable three-pass algorithm:
- Max Computation: Find the maximum value along the specified dimension
- Exponential Sum: Compute exp(x - max) and sum the results
- Normalization: Divide each exp(x - max) by the sum to get probabilities
This approach prevents overflow by subtracting the maximum value before computing exponentials, ensuring numerical stability for any input range.
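A minimal standalone sketch of the same three-pass idea over a flat slice of values (illustrative only; the crate applies it along an arbitrary dimension):
fn softmax_1d(x: &[f32]) -> Vec<f32> {
    // Pass 1: find the maximum value to shift by
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    // Pass 2: compute exp(x - max) and accumulate the sum
    let exps: Vec<f32> = x.iter().map(|&v| (v - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    // Pass 3: normalize so the outputs sum to 1
    exps.iter().map(|&e| e / sum).collect()
}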
§Examples
§Basic Softmax Activation
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.softmax(0);
assert_eq!(b.shape().dims(), vec![3]);
// Verify probabilities sum to 1
let sum = b.get(&[0]) + b.get(&[1]) + b.get(&[2]);
assert!((sum - 1.0).abs() < 1e-6);
// Verify relative ordering is preserved
assert!(b.get(&[0]) < b.get(&[1]));
assert!(b.get(&[1]) < b.get(&[2]));
§2D Softmax Along Different Dimensions
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = a.softmax(0); // Softmax along first dimension
assert_eq!(b.shape().dims(), vec![2, 2]);
// Each column should sum to 1
let col1_sum = b.get(&[0, 0]) + b.get(&[1, 0]);
let col2_sum = b.get(&[0, 1]) + b.get(&[1, 1]);
assert!((col1_sum - 1.0).abs() < 1e-6);
assert!((col2_sum - 1.0).abs() < 1e-6);
§Panics
- Panics if dim is out of bounds for the tensor’s rank
- Panics if the dimension size is 0
Examples found in repository?
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123 }
More examples
87 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
88 println!("=== Supervised Classification Example (Cross-Entropy) ===");
89
90 // Synthetic 2D inputs, 3 classes with linear-ish separations
91 let n = 1200usize;
92 let classes = 3usize;
93 let mut xs: Vec<f32> = Vec::with_capacity(n * 2);
94 let mut ys: Vec<usize> = Vec::with_capacity(n);
95
96 // Simple RNG
97 let mut state: u64 = 424242;
98 let mut rand_f32 = || {
99 state = state.wrapping_mul(1664525).wrapping_add(1013904223);
100 (state >> 16) as f32 / (u32::MAX as f32)
101 };
102
103 for _ in 0..n {
104 let x1 = rand_f32() * 4.0 - 2.0;
105 let x2 = rand_f32() * 4.0 - 2.0;
106 // Class by quadrant-ish rule with noise
107 let mut c = if x1 + 0.5 * x2 > 0.5 {
108 0
109 } else if x1 - x2 < -0.5 {
110 1
111 } else {
112 2
113 };
114 if rand_f32() < 0.05 {
115 c = (c + 1) % classes;
116 }
117 xs.push(x1);
118 xs.push(x2);
119 ys.push(c);
120 }
121
122 // Normalize inputs per-feature to [-1, 1]
123 let mut min1 = f32::INFINITY;
124 let mut max1 = f32::NEG_INFINITY;
125 let mut min2 = f32::INFINITY;
126 let mut max2 = f32::NEG_INFINITY;
127 for i in (0..xs.len()).step_by(2) {
128 let a = xs[i];
129 let b = xs[i + 1];
130 if a < min1 {
131 min1 = a;
132 }
133 if a > max1 {
134 max1 = a;
135 }
136 if b < min2 {
137 min2 = b;
138 }
139 if b > max2 {
140 max2 = b;
141 }
142 }
143 let rng1 = (max1 - min1).max(1e-8);
144 let rng2 = (max2 - min2).max(1e-8);
145 for i in (0..xs.len()).step_by(2) {
146 let a = xs[i];
147 let b = xs[i + 1];
148 xs[i] = 2.0 * (a - min1) / rng1 - 1.0;
149 xs[i + 1] = 2.0 * (b - min2) / rng2 - 1.0;
150 }
151
152 // Train/Val split (80/20)
153 let n_train = (n as f32 * 0.8) as usize;
154 let x_train = Tensor::from_slice(&xs[..n_train * 2], vec![n_train, 2]).unwrap();
155 let y_train = ys[..n_train].to_vec();
156 let x_val = Tensor::from_slice(&xs[n_train * 2..], vec![n - n_train, 2]).unwrap();
157 let y_val = ys[n_train..].to_vec();
158
159 // Model: 2 -> 64 -> 64 -> 3 (logits)
160 let cfg = FeedForwardConfig {
161 input_size: 2,
162 hidden_sizes: vec![64, 64],
163 output_size: classes,
164 use_bias: true,
165 };
166 let mut net = FeedForwardNetwork::new(cfg, Some(303));
167
168 // Optimizer
169 let mut opt = Adam::with_learning_rate(1e-3);
170 for p in net.parameters() {
171 opt.add_parameter(p);
172 }
173
174 let epochs = 300usize;
175 let max_grad_norm = 1.0f32;
176 let mut best_val_acc = 0.0f32;
177 let mut best_val_loss = f32::INFINITY;
178
179 for e in 0..epochs {
180 // Zero grads
181 {
182 let mut params = net.parameters();
183 opt.zero_grad(&mut params);
184 }
185
186 // Forward logits
187 let logits = net.forward(&x_train);
188 let mut loss = cross_entropy_logits(&logits, &y_train, n_train, classes);
189 loss.backward(None);
190
191 // Step clipped
192 {
193 let params = net.parameters();
194 let mut with_grads: Vec<&mut Tensor> = Vec::new();
195 for p in params {
196 if p.grad_owned().is_some() {
197 with_grads.push(p);
198 }
199 }
200 if !with_grads.is_empty() {
201 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
202 opt.step(&mut with_grads);
203 opt.zero_grad(&mut with_grads);
204 }
205 }
206
207 // Metrics
208 let train_acc = accuracy_from_logits(&logits, &y_train, n_train, classes);
209 let val_logits = net.forward(&x_val);
210 let val_loss = cross_entropy_logits(&val_logits, &y_val, n - n_train, classes).value();
211 let val_acc = accuracy_from_logits(&val_logits, &y_val, n - n_train, classes);
212 if val_acc > best_val_acc {
213 best_val_acc = val_acc;
214 }
215 if val_loss < best_val_loss {
216 best_val_loss = val_loss;
217 }
218
219 if e % 10 == 0 || e + 1 == epochs {
220 println!(
221 "epoch {:4} | loss={:.4} acc={:.3} | val_loss={:.4} val_acc={:.3} | best_val_acc={:.3}",
222 e, loss.value(), train_acc, val_loss, val_acc, best_val_acc
223 );
224 }
225
226 clear_all_graphs_known();
227 }
228
229 // Quick sample preds via softmax
230 let samples = Tensor::from_slice(&[-1.0, -1.0, 0.0, 0.0, 1.0, 1.0], vec![3, 2]).unwrap();
231 let sm = net.forward(&samples).softmax(1);
232 println!("sample class probs: {:?}", sm.data());
233
234 println!("=== Supervised classification finished ===");
235 Ok(())
236 }
319 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
Source§impl Tensor
impl Tensor
Sourcepub fn sqrt(&self) -> Tensor
pub fn sqrt(&self) -> Tensor
Element-wise square root
Computes the square root for each element: output[i] = sqrt(self[i])
Uses SIMD optimization when available for maximum performance, with automatic fallback to optimized scalar computation for non-SIMD hardware.
§Returns
A new tensor with the square root of each element
§Performance Characteristics
- SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
- Scalar Fallback: 4x unrolled scalar implementation for non-SIMD hardware
- Cache-friendly: Linear memory access patterns
- Mathematical Accuracy: High-precision square root computation
- GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details
Automatically selects between SIMD and scalar implementations based on hardware capabilities. SIMD implementation uses AVX2 vector square root operations for optimal performance. Scalar implementation uses 4x unrolling for better instruction-level parallelism.
§Examples
§Basic Square Root
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 4.0, 9.0], vec![3]).unwrap();
let b = a.sqrt();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[1]), 2.0); // sqrt(4.0) = 2.0
assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0
§Zero and Special Values
use train_station::Tensor;
let a = Tensor::from_slice(&[0.0, 1.0, 16.0], vec![3]).unwrap();
let b = a.sqrt();
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 0.0); // sqrt(0.0) = 0.0
assert_eq!(b.get(&[1]), 1.0); // sqrt(1.0) = 1.0
assert_eq!(b.get(&[2]), 4.0); // sqrt(16.0) = 4.0
§Note
Results are undefined for negative values (may produce NaN)
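If negative inputs are possible, one way to avoid NaN is to clamp them to zero with relu() before calling sqrt(). A minimal usage sketch (a guard pattern, not behavior of sqrt itself):
use train_station::Tensor;
// Clamp negatives to zero first so sqrt never sees a negative input
let a = Tensor::from_slice(&[-4.0, 0.0, 9.0], vec![3]).unwrap();
let b = a.relu().sqrt();
assert_eq!(b.get(&[0]), 0.0); // relu(-4.0) = 0.0, sqrt(0.0) = 0.0
assert_eq!(b.get(&[2]), 3.0); // sqrt(9.0) = 3.0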
Examples found in repository?
More examples
162fn demonstrate_memory_optimization() -> Result<(), Box<dyn std::error::Error>> {
163 println!("\n--- Memory Optimization ---");
164
165 // Create a large tensor for memory testing
166 let size = 10000;
167 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
168 let tensor = Tensor::from_slice(&data, vec![size])?;
169
170 println!("Processing tensor of size: {}", size);
171
172 // Pattern 1: Streaming processing with iterator chunks (process in blocks, collect with shape)
173 println!("\nPattern 1: Streaming Processing");
174 let chunk_size = 1000;
175 let start = Instant::now();
176 let flattened = tensor.view(vec![size as i32]);
177 let _streamed_result: Tensor = flattened
178 .chunks(chunk_size)
179 .map(|c| c.pow_scalar(2.0).sqrt())
180 .collect_shape(vec![size]);
181 let streamed_time = start.elapsed();
182
183 // Pattern 2: Full processing
184 let start = Instant::now();
185 let _full_result: Tensor = tensor
186 .iter_elements()
187 .map(|elem| elem.pow_scalar(2.0).sqrt())
188 .collect_shape(vec![size]);
189 let full_time = start.elapsed();
190
191 println!(" Streaming time: {:?}", streamed_time);
192 println!(" Full processing time: {:?}", full_time);
193 println!(
194 " Memory efficiency ratio: {:.2}x",
195 full_time.as_nanos() as f64 / streamed_time.as_nanos() as f64
196 );
197
198 // Pattern 3: Lazy evaluation with take
199 println!("\nPattern 2: Lazy Evaluation");
200 let start = Instant::now();
201 let lazy_result: Tensor = tensor
202 .iter_elements()
203 .take(1000) // Only process first 1000 elements
204 .map(|elem| elem.pow_scalar(2.0).sqrt())
205 .collect_shape(vec![1000]);
206 let lazy_time = start.elapsed();
207
208 println!(" Lazy processing (1000 elements): {:?}", lazy_time);
209 println!(" Lazy result size: {}", lazy_result.size());
210
211 // Pattern 4: Memory-efficient filtering
212 println!("\nPattern 3: Memory-Efficient Filtering");
213 let start = Instant::now();
214 let filtered_result: Tensor = tensor
215 .iter_elements()
216 .filter(|elem| elem.value() > size as f32 / 2.0) // Keep only large values
217 .map(|elem| elem.mul_scalar(2.0))
218 .collect();
219 let filtered_time = start.elapsed();
220
221 println!(" Filtered processing: {:?}", filtered_time);
222 println!(
223 " Filtered result size: {} (reduced from {})",
224 filtered_result.size(),
225 size
226 );
227
228 Ok(())
229}
230
231/// Demonstrate large-scale processing techniques
232///
233/// Shows how to efficiently process very large datasets using
234/// iterator patterns and optimization strategies.
235fn demonstrate_large_scale_processing() -> Result<(), Box<dyn std::error::Error>> {
236 println!("\n--- Large-Scale Processing ---");
237
238 // Simulate large dataset processing
239 let sizes = vec![10000, 50000, 100000];
240
241 for size in sizes {
242 println!("\nProcessing dataset of size: {}", size);
243
244 // Generate large dataset
245 let data: Vec<f32> = (0..size)
246 .map(|i| {
247 let x = i as f32 / size as f32;
248 x * x + 0.1 * (i % 10) as f32 // Quadratic with noise
249 })
250 .collect();
251
252 let tensor = Tensor::from_slice(&data, vec![size])?;
253
254 // Technique 1: Batch processing
255 let batch_size = 1000;
256 let start = Instant::now();
257
258 let mut batch_results = Vec::new();
259 for batch_start in (0..size).step_by(batch_size) {
260 let batch_end = (batch_start + batch_size).min(size);
261 let batch: Tensor = tensor
262 .iter_range(batch_start, batch_end)
263 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
264 .collect();
265 batch_results.push(batch);
266 }
267 let batch_time = start.elapsed();
268
269 // Technique 2: Parallel-like processing with stride
270 let start = Instant::now();
271 let stride = 4;
272 let strided_result: Tensor = tensor
273 .iter()
274 .enumerate()
275 .filter(|(i, _)| i % stride == 0)
276 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect();
278 let strided_time = start.elapsed();
279
280 // Technique 3: Hierarchical processing
281 let start = Instant::now();
282 let coarse: Tensor = tensor
283 .iter()
284 .enumerate()
285 .filter(|(i, _)| i % 10 == 0) // Every 10th element
286 .map(|(_, elem)| elem.pow_scalar(2.0).add_scalar(1.0))
287 .collect();
288 let fine: Tensor = tensor
289 .iter()
290 .enumerate()
291 .filter(|(i, _)| i % 10 != 0) // Rest of elements
292 .map(|(_, elem)| elem.pow_scalar(1.5).add_scalar(0.5))
293 .collect();
294 let hierarchical_time = start.elapsed();
295
296 // Report performance
297 println!(" Batch processing: {:?}", batch_time);
298 println!(" Strided processing: {:?}", strided_time);
299 println!(" Hierarchical processing: {:?}", hierarchical_time);
300
301 // Memory usage analysis
302 let total_batches = size.div_ceil(batch_size);
303 println!(" Batch count: {}", total_batches);
304 println!(" Strided result size: {}", strided_result.size());
305 println!(
306 " Hierarchical: coarse={}, fine={}",
307 coarse.size(),
308 fine.size()
309 );
310 }
311
312 Ok(())
313}
314
315/// Demonstrate advanced optimization techniques
316///
317/// Shows sophisticated optimization strategies and techniques
318/// for maximizing performance in tensor iterator operations.
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320 println!("\n--- Optimization Techniques ---");
321
322 let size = 50000;
323 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324 let tensor = Tensor::from_slice(&data, vec![size])?;
325
326 println!("Optimizing processing for size: {}", size);
327
328 // Technique 1: Operation fusion
329 println!("\nTechnique 1: Operation Fusion");
330 let start = Instant::now();
331 let fused_result: Tensor = tensor
332 .iter()
333 .map(|elem| {
334 // Fuse multiple operations into single chain
335 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336 })
337 .collect();
338 let fused_time = start.elapsed();
339
340 // Technique 2: Conditional optimization
341 println!("\nTechnique 2: Conditional Optimization");
342 let start = Instant::now();
343 let conditional_result: Tensor = tensor
344 .iter()
345 .map(|elem| {
346 let val = elem.value();
347 if val < size as f32 / 2.0 {
348 elem.mul_scalar(2.0) // Simple operation for small values
349 } else {
350 elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351 }
352 })
353 .collect();
354 let conditional_time = start.elapsed();
355
356 // Technique 3: Cache-friendly processing
357 println!("\nTechnique 3: Cache-Friendly Processing");
358 let start = Instant::now();
359 let cache_friendly_result: Tensor = tensor
360 .iter()
361 .take(1000) // Process in cache-friendly chunks
362 .map(|elem| elem.mul_scalar(2.0))
363 .collect();
364 let cache_friendly_time = start.elapsed();
365
366 // Technique 4: Memory pooling simulation
367 println!("\nTechnique 4: Memory Pooling Simulation");
368 let start = Instant::now();
369 let pooled_result: Tensor = tensor
370 .iter()
371 .enumerate()
372 .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373 .map(|(_, elem)| elem.pow_scalar(2.0))
374 .collect();
375 let pooled_time = start.elapsed();
376
377 // Report optimization results
378 println!(" Fused operations: {:?}", fused_time);
379 println!(" Conditional optimization: {:?}", conditional_time);
380 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
381 println!(" Memory pooling simulation: {:?}", pooled_time);
382
383 // Performance analysis
384 let fastest = fused_time
385 .min(conditional_time)
386 .min(cache_friendly_time)
387 .min(pooled_time);
388 println!(" Fastest technique: {:?}", fastest);
389
390 // Memory efficiency analysis
391 println!(" Fused result size: {}", fused_result.size());
392 println!(" Conditional result size: {}", conditional_result.size());
393 println!(
394 " Cache-friendly result size: {}",
395 cache_friendly_result.size()
396 );
397 println!(" Pooled result size: {}", pooled_result.size());
398
399 // Technique 5: Gradient optimization
400 println!("\nTechnique 5: Gradient Optimization");
401 let grad_tensor = tensor.with_requires_grad();
402 let start = Instant::now();
403
404 let grad_result: Tensor = grad_tensor
405 .iter()
406 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407 .collect();
408
409 let mut loss = grad_result.sum();
410 loss.backward(None);
411 let grad_time = start.elapsed();
412
413 println!(" Gradient computation: {:?}", grad_time);
414 println!(
415 " Gradient tracking enabled: {}",
416 grad_result.requires_grad()
417 );
418
419 Ok(())
420}
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
Source§impl Tensor
impl Tensor
Sourcepub fn sub_tensor(&self, other: &Tensor) -> Tensor
pub fn sub_tensor(&self, other: &Tensor) -> Tensor
Element-wise subtraction with another tensor with broadcasting support
Performs element-wise subtraction with automatic broadcasting: output[i] = self[i] - other[i]
Broadcasting enables subtraction between tensors of different but compatible shapes. Compatible shapes follow NumPy broadcasting rules:
- Dimensions are aligned from the rightmost dimension
- Dimensions are compatible if they are equal, or one of them is 1
- Missing dimensions are treated as 1
§Arguments
other - Tensor to subtract. Shapes must be broadcast-compatible.
§Returns
A new tensor containing the element-wise difference with broadcast result shape
§Performance Characteristics
- Fast Path: Optimized for identical shapes to avoid broadcasting overhead
- SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
- Broadcasting: Efficient broadcasting for compatible shapes
- Cache-friendly: Linear memory access patterns
- GradTrack Support: Full automatic differentiation with efficient gradient computation
§Implementation Details
Uses a fast path for identical shapes to avoid broadcasting overhead. For different shapes, performs broadcasting followed by optimized element-wise subtraction. Automatically selects between SIMD and scalar implementations based on hardware capabilities.
§Examples
§Same Shape Subtraction
use train_station::Tensor;
let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![3]);
assert_eq!(c.get(&[0]), 4.0); // 5.0 - 1.0
assert_eq!(c.get(&[1]), 5.0); // 7.0 - 2.0
assert_eq!(c.get(&[2]), 6.0); // 9.0 - 3.0
§Broadcasting Subtraction
use train_station::Tensor;
let a = Tensor::from_slice(&[5.0, 10.0], vec![2, 1]).unwrap();
let b = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
// Result: [[4.0, 3.0, 2.0], [9.0, 8.0, 7.0]]
assert_eq!(c.get(&[0, 0]), 4.0); // 5.0 - 1.0
assert_eq!(c.get(&[0, 1]), 3.0); // 5.0 - 2.0
assert_eq!(c.get(&[1, 0]), 9.0); // 10.0 - 1.0
§Scalar Subtraction
use train_station::Tensor;
let a = Tensor::ones(vec![2, 3]);
let b = Tensor::from_slice(&[0.5], vec![1]).unwrap();
let c = a.sub_tensor(&b);
assert_eq!(c.shape().dims(), vec![2, 3]);
assert_eq!(c.get(&[0, 0]), 0.5); // 1.0 - 0.5
§Panics
Panics if tensor shapes are not broadcast-compatible
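An additional broadcasting sketch, relying on sum_dims and div_scalar as they are used elsewhere on this page: centering each column of a [2, 3] tensor by subtracting its [1, 3] row of column means.
use train_station::Tensor;
// Center each column of a [2, 3] tensor by subtracting its column mean
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 3.0, 4.0, 5.0], vec![2, 3]).unwrap();
let col_means = x.sum_dims(&[0], true).div_scalar(2.0); // shape [1, 3]
let centered = x.sub_tensor(&col_means);                // broadcasts to [2, 3]
assert_eq!(centered.get(&[0, 0]), -1.0); // 1.0 - 2.0
assert_eq!(centered.get(&[1, 0]), 1.0);  // 3.0 - 2.0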
Examples found in repository?
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43 pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47 mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51 // R^2 = 1 - SS_res / SS_tot
52 let y = target;
53 let y_mean = y.mean();
54 let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55 let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56 let ss_res_v = ss_res.value();
57 let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58 1.0 - (ss_res_v / ss_tot_v)
59}
More examples
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243            .add_scalar((2.0 * std::f32::consts::PI).ln()) // ln(2*pi) term of the Gaussian log-density
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
89fn demonstrate_basic_operations() {
90 println!("\n--- Basic Operations ---");
91
92 let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
93 let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
94
95 // Addition
96 let sum = a.add_tensor(&b);
97 println!("A + B: {:?}", sum.data());
98
99 // Subtraction
100 let diff = a.sub_tensor(&b);
101 println!("A - B: {:?}", diff.data());
102
103 // Multiplication
104 let product = a.mul_tensor(&b);
105 println!("A * B: {:?}", product.data());
106
107 // Division
108 let quotient = a.div_tensor(&b);
109 println!("A / B: {:?}", quotient.data());
110
111 // Scalar operations
112 let scalar_add = a.add_scalar(5.0);
113 println!("A + 5.0: {:?}", scalar_add.data());
114
115 let scalar_mul = a.mul_scalar(2.0);
116 println!("A * 2.0: {:?}", scalar_mul.data());
117}
Sourcepub fn sub_scalar(&self, scalar: f32) -> Tensor
pub fn sub_scalar(&self, scalar: f32) -> Tensor
Element-wise subtraction of a scalar from this tensor
Performs element-wise subtraction of a scalar value: output[i] = self[i] - scalar
§Arguments
scalar - The scalar value to subtract from each element
§Returns
A new tensor with the scalar subtracted from each element
§Performance Characteristics
- SIMD Optimization: AVX2-optimized with 32-element blocks and 4x unrolling
- Scalar Fallback: 4x unrolled scalar implementation for non-SIMD hardware
- Cache-friendly: Linear memory access patterns
- Mathematical Accuracy: High-precision subtraction computation
- GradTrack Support: Full automatic differentiation with efficient gradient computation
§Examples
§Basic Scalar Subtraction
use train_station::Tensor;
let a = Tensor::from_slice(&[5.0, 7.0, 9.0], vec![3]).unwrap();
let b = a.sub_scalar(2.0);
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 3.0); // 5.0 - 2.0
assert_eq!(b.get(&[1]), 5.0); // 7.0 - 2.0
assert_eq!(b.get(&[2]), 7.0); // 9.0 - 2.0
§Negative Scalar Subtraction
use train_station::Tensor;
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = a.sub_scalar(-2.0); // Subtracting negative = adding
assert_eq!(b.shape().dims(), vec![3]);
assert_eq!(b.get(&[0]), 3.0); // 1.0 - (-2.0) = 3.0
assert_eq!(b.get(&[1]), 4.0); // 2.0 - (-2.0) = 4.0
assert_eq!(b.get(&[2]), 5.0); // 3.0 - (-2.0) = 5.0
Examples found in repository?
More examples
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251 println!("\n--- Batch Operations ---");
252
253 // Create a larger dataset for batch processing
254 let size = 100;
255 let data: Vec<f32> = (0..size)
256 .map(|i| {
257 let x = i as f32 / size as f32;
258 x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259 })
260 .collect();
261
262 let tensor = Tensor::from_slice(&data, vec![size])?;
263 println!("Dataset size: {}", tensor.size());
264
265 // Batch processing with windowing (iterator views)
266 println!("\nBatch processing with sliding windows:");
267 let batch_size = 10;
268 let batches: Vec<Tensor> = tensor
269 .iter()
270 .collect::<Vec<_>>()
271 .chunks(batch_size)
272 .map(|chunk| {
273 // Process each batch independently
274 chunk
275 .iter()
276 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect()
278 })
279 .collect();
280
281 println!(
282 " Processed {} batches of size {}",
283 batches.len(),
284 batch_size
285 );
286 for (i, batch) in batches.iter().enumerate() {
287 println!(
288 " Batch {}: mean={:.3}, std={:.3}",
289 i,
290 batch.mean().value(),
291 batch.std().value()
292 );
293 }
294
295 // Parallel-like processing with stride
296 println!("\nStrided processing (every nth element):");
297 let stride = 5;
298 let strided: Tensor = tensor
299 .iter()
300 .enumerate()
301 .filter(|(i, _)| i % stride == 0)
302 .map(|(_, elem)| elem)
303 .collect();
304 println!(" Strided (every {}th): {:?}", stride, strided.data());
305
306 // Hierarchical processing
307 println!("\nHierarchical processing (coarse to fine):");
308 let coarse: Tensor = tensor
309 .iter()
310 .enumerate()
311 .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312 .map(|(_, elem)| elem)
313 .collect();
314
315 let fine: Tensor = tensor
316 .iter()
317 .enumerate()
318 .filter(|(i, _)| i % 4 != 0) // Take the rest
319 .map(|(_, elem)| elem)
320 .collect();
321
322 println!(" Coarse (every 4th): {:?}", coarse.data());
323 println!(" Fine (rest): {:?}", fine.data());
324
325 // Combine coarse and fine with different processing
326 let combined: Tensor = coarse
327 .iter()
328 .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329 .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330 .collect();
331 println!(" Combined: {:?}", combined.data());
332
333 Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
Source§impl Tensor
impl Tensor
Sourcepub fn tanh(&self) -> Tensor
pub fn tanh(&self) -> Tensor
Element-wise hyperbolic tangent activation
Computes hyperbolic tangent for each element: output[i] = tanh(self[i])
The hyperbolic tangent function maps any real number to the range (-1, 1), making it useful as an activation function in neural networks.
§Returns
A new tensor with tanh applied to each element, values in range (-1, 1)
§Performance Characteristics
- High Precision: Accurate scalar implementation for mathematical validation
- 4x Unrolling: Optimized scalar operations with instruction-level parallelism
- Cache-friendly: Linear memory access patterns
- Numerical Stability: Robust handling of extreme input values
- GradTrack Support: Full automatic differentiation with efficient gradient computation
§Mathematical Properties
- Range: Output values are in the range (-1, 1)
- Symmetry: tanh(-x) = -tanh(x) (odd function)
- Zero: tanh(0) = 0
- Gradient: ∂tanh(x)/∂x = 1 - tanh²(x) = sech²(x)
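A quick check of the odd-symmetry property above, written as a sketch using only methods documented on this page:
use train_station::Tensor;
// tanh(-x) == -tanh(x) for every element
let x = Tensor::from_slice(&[0.25, 0.5, 2.0], vec![3]).unwrap();
let pos = x.tanh();
let neg = x.mul_scalar(-1.0).tanh();
for i in 0..3 {
    assert!((pos.get(&[i]) + neg.get(&[i])).abs() < 1e-6);
}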
§Examples
§Basic Hyperbolic Tangent
use train_station::Tensor;
let a = Tensor::from_slice(&[-1.0, 0.0, 1.0], vec![3]).unwrap();
let b = a.tanh();
assert_eq!(b.shape().dims(), vec![3]);
assert!((b.get(&[0]) - (-0.7615942)).abs() < 1e-6); // tanh(-1.0)
assert!((b.get(&[1]) - 0.0).abs() < 1e-6); // tanh(0.0)
assert!((b.get(&[2]) - 0.7615942).abs() < 1e-6); // tanh(1.0)
§Extreme Values
use train_station::Tensor;
let a = Tensor::from_slice(&[-10.0, 10.0], vec![2]).unwrap();
let b = a.tanh();
assert_eq!(b.shape().dims(), vec![2]);
assert!((b.get(&[0]) - (-1.0)).abs() < 1e-6); // tanh(-10.0) ≈ -1
assert!((b.get(&[1]) - 1.0).abs() < 1e-6); // tanh(10.0) ≈ 1
Source§impl Tensor
impl Tensor
Sourcepub fn argmax(&self) -> Tensor
pub fn argmax(&self) -> Tensor
Returns the index of the maximum value across all elements in the tensor
This operation finds the flat index (0-based) of the element with the highest value. If multiple elements have the same maximum value, the index of the first occurrence is returned. The output is a scalar tensor with shape [1] containing the index as a float.
This operation is non-differentiable and the output never requires gradients.
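A common use is picking the predicted class from a row of logits. A minimal sketch (the softmax call is optional here, since it does not change which index is largest):
use train_station::Tensor;
// Pick the predicted class index for a single row of logits
let logits = Tensor::from_slice(&[0.1, 2.5, -0.3], vec![1, 3]).unwrap();
let probs = logits.softmax(1);   // monotonic, so the argmax is unchanged
let pred = probs.argmax();       // flat index of the largest probability
assert_eq!(pred.get(&[0]), 1.0); // class 1 has the highest score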
§Returns
A tensor with shape [1] containing the flat index of the maximum value
§Examples
use train_station::Tensor;
// 1D tensor
let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![4]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.shape().dims(), vec![1]);
assert_eq!(max_idx.get(&[0]), 1.0); // Index 1 has value 5.0
use train_station::Tensor;
// 2D tensor
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.get(&[0]), 5.0); // Flat index 5 has value 5.0
use train_station::Tensor;
// Tied values return first occurrence
let tensor = Tensor::from_slice(&[3.0, 5.0, 5.0, 2.0], vec![4]).unwrap();
let max_idx = tensor.argmax();
assert_eq!(max_idx.get(&[0]), 1.0); // First occurrence of 5.0 at index 1
Sourcepub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor
pub fn argmax_dim(&self, dim: usize, keepdim: bool) -> Tensor
Returns the indices of maximum values along a specified dimension
This operation finds the indices of maximum values along the specified dimension. For each slice along the dimension, it returns the index of the maximum value. If multiple elements have the same maximum value, the index of the first occurrence is returned.
The output shape depends on the keepdim parameter:
- If keepdim is true, the reduced dimension is kept with size 1
- If keepdim is false, the reduced dimension is removed
This operation is non-differentiable and the output never requires gradients.
§Arguments
dim - The dimension along which to find argmax indices (0-based)
keepdim - Whether to keep the reduced dimension with size 1
§Returns
A tensor containing the indices of maximum values along the specified dimension
§Panics
Panics if dim is out of bounds for the tensor’s rank or if the dimension size is 0.
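In addition to the examples below, a per-row prediction sketch for a small batch of logits:
use train_station::Tensor;
// Predicted class per row for a [2, 3] batch of logits
let logits = Tensor::from_slice(&[0.1, 2.5, -0.3, 1.0, 0.0, 3.0], vec![2, 3]).unwrap();
let preds = logits.argmax_dim(1, false); // shape [2]
assert_eq!(preds.shape().dims(), vec![2]);
assert_eq!(preds.get(&[0]), 1.0); // row 0: max is 2.5 at index 1
assert_eq!(preds.get(&[1]), 2.0); // row 1: max is 3.0 at index 2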
§Examples
use train_station::Tensor;
// 2D tensor: [[1.0, 3.0, 2.0],
// [4.0, 0.0, 5.0]]
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
// argmax along columns (dim=1)
let col_max_idx = tensor.argmax_dim(1, false);
assert_eq!(col_max_idx.shape().dims(), vec![2]);
assert_eq!(col_max_idx.get(&[0]), 1.0); // Row 0: max at index 1 (value 3.0)
assert_eq!(col_max_idx.get(&[1]), 2.0); // Row 1: max at index 2 (value 5.0)
use train_station::Tensor;
// argmax along rows (dim=0) with keepdim
let tensor = Tensor::from_slice(&[1.0, 3.0, 2.0, 4.0, 0.0, 5.0], vec![2, 3]).unwrap();
let row_max_idx = tensor.argmax_dim(0, true);
assert_eq!(row_max_idx.shape().dims(), vec![1, 3]);
assert_eq!(row_max_idx.get(&[0, 0]), 1.0); // Col 0: max at index 1 (value 4.0)
assert_eq!(row_max_idx.get(&[0, 1]), 0.0); // Col 1: max at index 0 (value 3.0)
assert_eq!(row_max_idx.get(&[0, 2]), 1.0); // Col 2: max at index 1 (value 5.0)
use train_station::Tensor;
// 1D tensor edge case
let tensor = Tensor::from_slice(&[5.0, 1.0, 8.0, 3.0], vec![4]).unwrap();
let max_idx = tensor.argmax_dim(0, false);
assert_eq!(max_idx.shape().dims(), vec![1]); // Special case: becomes [1] not []
assert_eq!(max_idx.get(&[0]), 2.0); // Index 2 has maximum value 8.0
Source§impl Tensor
impl Tensor
Sourcepub fn argmin(&self) -> Tensor
pub fn argmin(&self) -> Tensor
Returns the index of the minimum value in the tensor
This method finds the flat index of the minimum value across all elements in the tensor. The result is a scalar tensor containing the index as a floating-point value. This operation is non-differentiable and the output never requires gradient tracking.
§Returns
A tensor with shape [1] containing the flat index of the minimum value
as a f32. If the input tensor is empty, returns 0.0.
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0], vec![4]).unwrap();
let min_index = tensor.argmin();
assert_eq!(min_index.get(&[0]), 1.0); // -2.0 is at index 1
use train_station::Tensor;
// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let min_index = empty_tensor.argmin();
assert_eq!(min_index.get(&[0]), 0.0);
pub fn argmin_dim(&self, dim: usize, keepdim: bool) -> Tensor
Returns the indices of minimum values along a specified dimension
This method finds the indices of minimum values along the specified dimension. The result contains the indices where the minimum values occur in that dimension. This operation is non-differentiable and the output never requires gradient tracking.
§Arguments
- dim - The dimension along which to find minimum indices (0-based)
- keepdim - Whether to keep the reduced dimension in the output shape
  - If true, the reduced dimension is kept with size 1
  - If false, the reduced dimension is removed from the output shape
§Returns
A tensor containing the indices of minimum values along the specified dimension.
The output shape depends on keepdim:
- If keepdim is true, the reduced dimension has size 1
- If keepdim is false, the reduced dimension is removed
§Panics
- If dim is out of bounds for the tensor’s rank
- If the dimension to reduce has size 0
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
// Find minimum indices along dimension 1 (columns), keeping the dimension
let indices = tensor.argmin_dim(1, true);
assert_eq!(indices.shape().dims(), vec![2, 1]);
assert_eq!(indices.get(&[0, 0]), 1.0); // -2.0 is at index 1 in first row
assert_eq!(indices.get(&[1, 0]), 2.0); // -3.0 is at index 2 in second row
use train_station::Tensor;
let tensor = Tensor::from_slice(&[3.0, -2.0, 5.0, -1.0, 0.0, -3.0], vec![2, 3]).unwrap();
// Find minimum indices along dimension 1 (columns), removing the dimension
let indices = tensor.argmin_dim(1, false);
assert_eq!(indices.shape().dims(), vec![2]);
assert_eq!(indices.get(&[0]), 1.0); // -2.0 is at index 1 in first row
assert_eq!(indices.get(&[1]), 2.0); // -3.0 is at index 2 in second row
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
// Find minimum index in a 1D tensor
let index = tensor.argmin_dim(0, false);
assert_eq!(index.shape().dims(), vec![1]);
assert_eq!(index.get(&[0]), 0.0); // 1.0 is at index 0
impl Tensor
pub fn max(&self) -> Tensor
Computes the maximum value over all elements in the tensor
Returns a scalar tensor containing the maximum value. For empty tensors, returns negative infinity. This operation supports gradient tracking through the GradTrack system.
§Returns
A tensor with shape [1] containing the maximum value
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
let max_val = tensor.max();
assert_eq!(max_val.get(&[0]), 5.0);
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation uses the saved input and output
for efficient backward pass.
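The backward pass can be exercised directly. The following is an illustrative sketch rather than an official example; it assumes the gradient of max is routed to the element that achieved the maximum and that backward(None) seeds the gradient with 1.0, as in the sum gradient example further down this page.
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2])
    .unwrap()
    .with_requires_grad();
let mut max_val = x.max();
max_val.backward(None);
// Inspect how the gradient is distributed over the input (index 1 holds the maximum 5.0)
let grad = x.grad_owned().expect("gradient should exist");
println!("gradient of max: {:?}", grad.data());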
pub fn max_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the maximum value over specified dimensions
Reduces the tensor along the specified dimensions by computing the maximum
value in each reduction group. The keepdim parameter determines whether
reduced dimensions are kept with size 1 or removed entirely.
§Arguments
- dims - Dimensions to reduce over (must be valid for the tensor’s rank)
- keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns
A tensor with the specified dimensions reduced
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Max over columns (dim 1), keeping dimensions
let max_cols = tensor.max_dims(&[1], true);
assert_eq!(max_cols.shape().dims(), vec![2, 1]);
assert_eq!(max_cols.get(&[0, 0]), 3.0);
assert_eq!(max_cols.get(&[1, 0]), 6.0);
// Max over rows (dim 0), removing dimensions
let max_rows = tensor.max_dims(&[0], false);
assert_eq!(max_rows.shape().dims(), vec![3]);
assert_eq!(max_rows.get(&[0]), 4.0);
assert_eq!(max_rows.get(&[1]), 5.0);
assert_eq!(max_rows.get(&[2]), 6.0);
§Panics
Panics if:
- dims is empty
- Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation preserves the original input
shape and handles broadcasting correctly.
Examples found in repository
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
More examples
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
impl Tensor
pub fn mean(&self) -> Tensor
Computes the arithmetic mean of all elements in the tensor
This method calculates the average value across all tensor elements by summing all values and dividing by the total number of elements. The result is a scalar tensor containing the mean value. This operation supports gradient tracking through the GradTrack system.
§Returns
A tensor with shape [1] containing the arithmetic mean of all elements.
For empty tensors, returns 0.0 as a safe default.
§Performance Characteristics
- Linear Time: O(n) complexity for computing the sum
- Memory Efficient: Single pass through tensor data with SIMD-optimized accumulation
- Numerical Stability: Uses direct accumulation for typical tensor sizes
- Edge Case Handling: Returns 0.0 for empty tensors
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let mean_val = tensor.mean();
assert_eq!(mean_val.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
use train_station::Tensor;
// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let mean_val = empty_tensor.mean();
assert_eq!(mean_val.get(&[0]), 0.0);
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation distributes the gradient equally
across all input elements.
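As an illustrative sketch of the behavior described above (each element receives 1/n of the upstream gradient; this assumes backward(None) seeds the gradient with 1.0, as in the sum gradient example further down this page):
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4])
    .unwrap()
    .with_requires_grad();
let mut m = x.mean();
m.backward(None);
let grad = x.grad_owned().expect("gradient should exist");
// Four elements, so each gradient entry is 1/4 = 0.25
assert!((grad.get(&[0]) - 0.25).abs() < 1e-6);
assert!((grad.get(&[3]) - 0.25).abs() < 1e-6);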
Examples found in repository
42fn mse(pred: &Tensor, target: &Tensor) -> Tensor {
43 pred.sub_tensor(target).pow_scalar(2.0).mean()
44}
45
46fn rmse(pred: &Tensor, target: &Tensor) -> f32 {
47 mse(pred, target).sqrt().value()
48}
49
50fn r2_score(pred: &Tensor, target: &Tensor) -> f32 {
51 // R^2 = 1 - SS_res / SS_tot
52 let y = target;
53 let y_mean = y.mean();
54 let ss_res = pred.sub_tensor(y).pow_scalar(2.0).sum();
55 let ss_tot = y.sub_tensor(&y_mean).pow_scalar(2.0).sum();
56 let ss_res_v = ss_res.value();
57 let ss_tot_v = ss_tot.value().max(1e-12); // avoid divide by zero
58 1.0 - (ss_res_v / ss_tot_v)
59}
More examples
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322 diff.pow_scalar(2.0)
323 .add_scalar(1.0)
324 .sqrt()
325 .sub_scalar(1.0)
326 .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(¬_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
73fn main() -> Result<(), Box<dyn std::error::Error>> {
74 println!("=== Basic Encoder Example ===");
75
76 let batch = 2usize;
77 let seq = 6usize;
78 let embed = 32usize;
79 let heads = 4usize;
80
81 let input = Tensor::randn(vec![batch, seq, embed], Some(11));
82 let mut enc = EncoderBlock::new(embed, heads, Some(123));
83
84 // Example: no mask (set Some(mask) to use masking)
85 let out = enc.forward(&input, None);
86 println!("Output shape: {:?}", out.shape().dims());
87
88 // Verify gradients/optimization
89 let mut opt = Adam::with_learning_rate(0.01);
90 let mut params = enc.parameters();
91 for p in ¶ms {
92 opt.add_parameter(p);
93 }
94 let mut loss = out.mean();
95 loss.backward(None);
96 opt.step(&mut params);
97 opt.zero_grad(&mut params);
98 println!("Loss: {:.6}", loss.value());
99 println!("=== Done ===");
100 Ok(())
101}
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
- examples/neural_networks/basic_decoder.rs
- examples/neural_networks/basic_transformer.rs
- examples/neural_networks/multi_head_attention.rs
- examples/optimizers/adam_configurations.rs
- examples/optimizers/learning_rate_scheduling.rs
- examples/getting_started/optimizer_basics.rs
- examples/neural_networks/feedforward_network.rs
- examples/iterators/advanced_patterns.rs
- examples/RL_training/../neural_networks/basic_linear_layer.rs
- examples/RL_training/ppo_discrete.rs
- examples/RL_training/ppo_continuous.rs
- examples/RL_training/td3.rs
pub fn mean_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the arithmetic mean over specified dimensions
This method calculates the mean value along the specified dimensions by first
computing the sum over those dimensions and then dividing by the product of
the reduced dimension sizes. The keepdim parameter determines whether
reduced dimensions are kept with size 1 or removed entirely.
§Arguments
- dims - Dimensions to reduce over (must be valid for the tensor’s rank)
- keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns
A tensor with the specified dimensions reduced by computing the mean.
The output shape depends on keepdim:
- If keepdim is true, reduced dimensions have size 1
- If keepdim is false, reduced dimensions are removed
§Performance Characteristics
- Efficient Implementation: Uses sum_dims followed by scalar multiplication
- Memory Optimized: Leverages existing sum reduction for optimal performance
- Shape Computation: Fast output shape calculation with dimension preservation
- Numerical Stability: Maintains precision through direct computation
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Mean over columns (dim 1), keeping dimensions
let mean_cols = tensor.mean_dims(&[1], true);
assert_eq!(mean_cols.shape().dims(), vec![2, 1]);
assert_eq!(mean_cols.get(&[0, 0]), 2.0); // (1+2+3)/3 = 2.0
assert_eq!(mean_cols.get(&[1, 0]), 5.0); // (4+5+6)/3 = 5.0
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Mean over rows (dim 0), removing dimensions
let mean_rows = tensor.mean_dims(&[0], false);
assert_eq!(mean_rows.shape().dims(), vec![3]);
assert_eq!(mean_rows.get(&[0]), 2.5); // (1+4)/2 = 2.5
assert_eq!(mean_rows.get(&[1]), 3.5); // (2+5)/2 = 3.5
assert_eq!(mean_rows.get(&[2]), 4.5); // (3+6)/2 = 4.5
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
// Mean over multiple dimensions
let mean_all = tensor.mean_dims(&[0, 1], false);
assert_eq!(mean_all.shape().dims(), vec![1]);
assert_eq!(mean_all.get(&[0]), 2.5); // (1+2+3+4)/4 = 2.5
§Panics
Panics if:
- dims is empty
- Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation preserves the original input
shape and handles broadcasting correctly through the ReduceMeanDims gradient function.
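As a small consistency sketch of the sum_dims-plus-scaling implementation described under Performance Characteristics (an illustration, not an official example):
use train_station::Tensor;
let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Mean over dim 1 should agree with the sum over dim 1 scaled by 1/3
let via_mean = t.mean_dims(&[1], true);
let via_sum = t.sum_dims(&[1], true).mul_scalar(1.0 / 3.0);
assert!((via_mean.get(&[0, 0]) - via_sum.get(&[0, 0])).abs() < 1e-6);
assert!((via_mean.get(&[1, 0]) - via_sum.get(&[1, 0])).abs() < 1e-6);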
impl Tensor
pub fn min(&self) -> Tensor
Computes the minimum value over all elements in the tensor
Returns a scalar tensor containing the minimum value. For empty tensors, returns positive infinity. This operation supports gradient tracking through the GradTrack system.
§Returns
A tensor with shape [1] containing the minimum value
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2]).unwrap();
let min_val = tensor.min();
assert_eq!(min_val.get(&[0]), 1.0);
use train_station::Tensor;
// Empty tensor case
let empty_tensor = Tensor::new(vec![0]);
let min_val = empty_tensor.min();
assert_eq!(min_val.get(&[0]), f32::INFINITY);
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation uses the saved input and output
for efficient backward pass.
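A minimal sketch mirroring the max case, assuming the gradient is routed to the element that achieved the minimum:
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 5.0, 3.0, 2.0], vec![2, 2])
    .unwrap()
    .with_requires_grad();
let mut min_val = x.min();
min_val.backward(None);
if let Some(grad) = x.grad_owned() {
    // Index 0 holds the minimum 1.0, so the gradient is expected to concentrate there
    println!("gradient of min: {:?}", grad.data());
}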
pub fn min_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the minimum value over specified dimensions
Reduces the tensor along the specified dimensions by computing the minimum
value in each reduction group. The keepdim parameter determines whether
reduced dimensions are kept with size 1 or removed entirely.
§Arguments
- dims - Dimensions to reduce over (must be valid for the tensor’s rank)
- keepdim - If true, reduced dimensions are kept with size 1; if false, they are removed
§Returns
A tensor with the specified dimensions reduced
§Examples
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Min over columns (dim 1), keeping dimensions
let min_cols = tensor.min_dims(&[1], true);
assert_eq!(min_cols.shape().dims(), vec![2, 1]);
assert_eq!(min_cols.get(&[0, 0]), 1.0);
assert_eq!(min_cols.get(&[1, 0]), 4.0);
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
// Min over rows (dim 0), removing dimensions
let min_rows = tensor.min_dims(&[0], false);
assert_eq!(min_rows.shape().dims(), vec![3]);
assert_eq!(min_rows.get(&[0]), 1.0);
assert_eq!(min_rows.get(&[1]), 2.0);
assert_eq!(min_rows.get(&[2]), 3.0);
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
// Min over multiple dimensions
let min_all = tensor.min_dims(&[0, 1], false);
assert_eq!(min_all.shape().dims(), vec![1]);
assert_eq!(min_all.get(&[0]), 1.0);
§Panics
Panics if:
- dims is empty
- Any dimension in dims is out of bounds for the tensor’s rank
§GradTrack Support
When requires_grad is true, this operation is tracked for automatic
differentiation. The gradient computation preserves the original input
shape and handles broadcasting correctly.
impl Tensor
pub fn norm(&self) -> Tensor
Computes the L2 norm (Euclidean norm) over all elements
The L2 norm is calculated as sqrt(sum(x²)) where x represents each element in the tensor. This operation reduces the tensor to a scalar tensor with shape [1].
§Returns
A scalar tensor containing the L2 norm value
§Examples
use train_station::Tensor;
// Basic L2 norm calculation
let tensor = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
let norm = tensor.norm();
assert!((norm.get(&[0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²) = 5
use train_station::Tensor;
// L2 norm of a larger tensor
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let norm = tensor.norm();
// sqrt(1² + 2² + 3² + 4² + 5² + 6² + 7² + 8²) = sqrt(204) ≈ 14.283
let expected = 204.0_f32.sqrt();
assert!((norm.get(&[0]) - expected).abs() < 1e-5);
§Performance
Uses optimized contiguous tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration.
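For intuition, the result should agree with composing the element-wise and reduction operations documented elsewhere on this page; the sketch below is an illustration that assumes pow_scalar, sum, and sqrt chain as shown in the repository examples:
use train_station::Tensor;
let t = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
let direct = t.norm();
// sqrt(sum(x^2)) rebuilt from pow_scalar, sum, and sqrt
let composed = t.pow_scalar(2.0).sum().sqrt();
assert!((direct.get(&[0]) - composed.get(&[0])).abs() < 1e-6);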
Examples found in repository
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
More examples
317fn train_with_config(config: TrainingConfig) -> Result<TrainingStats, Box<dyn std::error::Error>> {
318 // Create training data
319 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
320 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0, 11.0], vec![5, 1]).unwrap();
321
322 // Create model parameters
323 let mut weight = Tensor::randn(vec![1, 1], Some(123)).with_requires_grad();
324 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
325
326 // Create optimizer with custom configuration
327 let adam_config = AdamConfig {
328 learning_rate: config.learning_rate,
329 beta1: config.beta1,
330 beta2: config.beta2,
331 eps: 1e-8,
332 weight_decay: config.weight_decay,
333 amsgrad: false,
334 };
335
336 let mut optimizer = Adam::with_config(adam_config);
337 optimizer.add_parameter(&weight);
338 optimizer.add_parameter(&bias);
339
340 // Training loop
341 let mut losses = Vec::new();
342 let mut convergence_epoch = config.epochs;
343
344 for epoch in 0..config.epochs {
345 // Forward pass
346 let y_pred = x_data.matmul(&weight) + &bias;
347 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
348
349 // Backward pass
350 loss.backward(None);
351
352 // Optimizer step
353 optimizer.step(&mut [&mut weight, &mut bias]);
354 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
355
356 let loss_value = loss.value();
357 losses.push(loss_value);
358
359 // Check for convergence (loss < 0.01)
360 if loss_value < 0.01 && convergence_epoch == config.epochs {
361 convergence_epoch = epoch;
362 }
363 }
364
365 Ok(TrainingStats {
366 config,
367 final_loss: losses[losses.len() - 1],
368 loss_history: losses,
369 convergence_epoch,
370 weight_norm: weight.norm().value(),
371 })
372}
176fn demonstrate_advanced_training() -> Result<(), Box<dyn std::error::Error>> {
177 println!("\n--- Advanced Training Patterns ---");
178
179 // Create a more complex model
180 let mut weight = Tensor::randn(vec![1, 2], Some(44)).with_requires_grad();
181 let mut bias = Tensor::zeros(vec![2]).with_requires_grad();
182
183 // Create optimizer with different learning rate
184 let mut optimizer = Adam::with_learning_rate(0.005);
185 optimizer.add_parameter(&weight);
186 optimizer.add_parameter(&bias);
187
188 // Create training data: y = 2*x + [1, 3]
189 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5, 1]).unwrap();
190 let y_true = Tensor::from_slice(
191 &[3.0, 5.0, 7.0, 9.0, 11.0, 6.0, 8.0, 10.0, 12.0, 14.0],
192 vec![5, 2],
193 )
194 .unwrap();
195
196 println!("Advanced training with monitoring:");
197 println!(" Initial learning rate: {}", optimizer.learning_rate());
198
199 // Training loop with monitoring
200 let num_epochs = 50;
201 let mut losses = Vec::new();
202 let mut weight_norms = Vec::new();
203 let mut gradient_norms = Vec::new();
204
205 for epoch in 0..num_epochs {
206 // Forward pass
207 let y_pred = x_data.matmul(&weight) + &bias;
208 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
209
210 // Backward pass
211 loss.backward(None);
212
213 // Compute gradient norm before optimizer step
214 let gradient_norm = weight.grad_owned().unwrap().norm();
215
216 // Optimizer step
217 optimizer.step(&mut [&mut weight, &mut bias]);
218 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
219
220 // Learning rate scheduling: reduce every 10 epochs
221 if epoch > 0 && epoch % 10 == 0 {
222 let current_lr = optimizer.learning_rate();
223 let new_lr = current_lr * 0.5;
224 optimizer.set_learning_rate(new_lr);
225 println!(
226 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
227 epoch, current_lr, new_lr
228 );
229 }
230
231 // Record metrics
232 losses.push(loss.value());
233 weight_norms.push(weight.norm().value());
234 gradient_norms.push(gradient_norm.value());
235
236 // Print detailed progress
237 if epoch % 10 == 0 || epoch == num_epochs - 1 {
238 println!(
239 "Epoch {:2}: Loss = {:.6}, Weight Norm = {:.6}, Gradient Norm = {:.6}",
240 epoch,
241 loss.value(),
242 weight.norm().value(),
243 gradient_norm.value()
244 );
245 }
246 }
247
248 println!("Final learning rate: {}", optimizer.learning_rate());
249
250 // Analyze training progression
251 let initial_loss = losses[0];
252 let final_loss = losses[losses.len() - 1];
253 let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
254
255 println!("\nTraining Analysis:");
256 println!(" Initial loss: {:.6}", initial_loss);
257 println!(" Final loss: {:.6}", final_loss);
258 println!(" Loss reduction: {:.1}%", loss_reduction);
259 println!(" Final weight norm: {:.6}", weight.norm().value());
260 println!(" Final bias: {:?}", bias.data());
261
262 Ok(())
263}
264
265/// Demonstrate learning rate scheduling
266fn demonstrate_learning_rate_scheduling() -> Result<(), Box<dyn std::error::Error>> {
267 println!("\n--- Learning Rate Scheduling ---");
268
269 // Create simple model
270 let mut weight = Tensor::randn(vec![1, 1], Some(45)).with_requires_grad();
271 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
272
273 // Create optimizer with high initial learning rate
274 let mut optimizer = Adam::with_learning_rate(0.1);
275 optimizer.add_parameter(&weight);
276 optimizer.add_parameter(&bias);
277
278 // Simple data
279 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3, 1]).unwrap();
280 let y_true = Tensor::from_slice(&[2.0, 4.0, 6.0], vec![3, 1]).unwrap();
281
282 println!("Initial learning rate: {}", optimizer.learning_rate());
283
284 // Training loop with learning rate scheduling
285 let num_epochs = 50;
286 let mut losses = Vec::new();
287
288 for epoch in 0..num_epochs {
289 // Forward pass
290 let y_pred = x_data.matmul(&weight) + &bias;
291 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
292
293 // Backward pass
294 loss.backward(None);
295
296 // Optimizer step
297 optimizer.step(&mut [&mut weight, &mut bias]);
298 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
299
300 // Learning rate scheduling: reduce every 10 epochs
301 if epoch > 0 && epoch % 10 == 0 {
302 let current_lr = optimizer.learning_rate();
303 let new_lr = current_lr * 0.5;
304 optimizer.set_learning_rate(new_lr);
305 println!(
306 "Epoch {:2}: Reduced learning rate from {:.3} to {:.3}",
307 epoch, current_lr, new_lr
308 );
309 }
310
311 losses.push(loss.value());
312
313 // Print progress
314 if epoch % 10 == 0 || epoch == num_epochs - 1 {
315 println!(
316 "Epoch {:2}: Loss = {:.6}, LR = {:.3}",
317 epoch,
318 loss.value(),
319 optimizer.learning_rate()
320 );
321 }
322 }
323
324 println!("Final learning rate: {}", optimizer.learning_rate());
325
326 Ok(())
327}
328
329/// Demonstrate training monitoring and analysis
330fn demonstrate_training_monitoring() -> Result<(), Box<dyn std::error::Error>> {
331 println!("\n--- Training Monitoring ---");
332
333 // Create model
334 let mut weight = Tensor::randn(vec![1, 1], Some(46)).with_requires_grad();
335 let mut bias = Tensor::zeros(vec![1]).with_requires_grad();
336
337 // Create optimizer
338 let mut optimizer = Adam::with_learning_rate(0.01);
339 optimizer.add_parameter(&weight);
340 optimizer.add_parameter(&bias);
341
342 // Training data
343 let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4, 1]).unwrap();
344 let y_true = Tensor::from_slice(&[3.0, 5.0, 7.0, 9.0], vec![4, 1]).unwrap();
345
346 // Training loop with comprehensive monitoring
347 let num_epochs = 30;
348 let mut losses = Vec::new();
349 let mut weight_history = Vec::new();
350 let mut bias_history = Vec::new();
351
352 for epoch in 0..num_epochs {
353 // Forward pass
354 let y_pred = x_data.matmul(&weight) + &bias;
355 let mut loss = (&y_pred - &y_true).pow_scalar(2.0).mean();
356
357 // Backward pass
358 loss.backward(None);
359
360 // Optimizer step
361 optimizer.step(&mut [&mut weight, &mut bias]);
362 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
363
364 // Record history
365 losses.push(loss.value());
366 weight_history.push(weight.value());
367 bias_history.push(bias.value());
368
369 // Print detailed monitoring
370 if epoch % 5 == 0 || epoch == num_epochs - 1 {
371 println!(
372 "Epoch {:2}: Loss = {:.6}, Weight = {:.6}, Bias = {:.6}",
373 epoch,
374 loss.value(),
375 weight.value(),
376 bias.value()
377 );
378 }
379 }
380
381 // Analyze training progression
382 println!("\nTraining Analysis:");
383 println!(" Initial loss: {:.6}", losses[0]);
384 println!(" Final loss: {:.6}", losses[losses.len() - 1]);
385 println!(
386 " Loss reduction: {:.1}%",
387 (losses[0] - losses[losses.len() - 1]) / losses[0] * 100.0
388 );
389
390 // Compute statistics
391 let loss_mean = compute_mean(&losses);
392 let loss_std = compute_std(&losses);
393 let weight_change = (weight_history[weight_history.len() - 1] - weight_history[0]).abs();
394 let bias_change = (bias_history[bias_history.len() - 1] - bias_history[0]).abs();
395
396 println!(" Average loss: {:.6} ± {:.6}", loss_mean, loss_std);
397 println!(" Weight change: {:.6}", weight_change);
398 println!(" Bias change: {:.6}", bias_change);
399 println!(" Final weight norm: {:.6}", weight.norm().value());
400 println!(" Final bias: {:.6}", bias.value());
401
402 Ok(())
403}
pub fn norm_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the L2 norm over specified dimensions
Reduces the tensor along the specified dimensions by computing the L2 norm of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.
§Arguments
- dims - Vector of dimension indices to reduce over (must be valid for tensor rank)
- keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns
A tensor with L2 norm computed over the specified dimensions
§Examples
use train_station::Tensor;
// Norm along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
let row_norms = matrix.norm_dims(&[1], true);
assert_eq!(row_norms.shape().dims(), vec![2, 1]);
assert!((row_norms.get(&[0, 0]) - 5.0).abs() < 1e-6); // sqrt(3² + 4²)
assert!((row_norms.get(&[1, 0]) - 5.0).abs() < 1e-6); // sqrt(0² + 5²)
use train_station::Tensor;
// Norm along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
let col_norms = matrix.norm_dims(&[0], false);
assert_eq!(col_norms.shape().dims(), vec![2]);
assert!((col_norms.get(&[0]) - 3.0).abs() < 1e-6); // sqrt(3² + 0²)
assert!((col_norms.get(&[1]) - 6.403).abs() < 1e-3); // sqrt(4² + 5²)
use train_station::Tensor;
// Norm over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let norm_all = tensor.norm_dims(&[0, 1], false);
assert_eq!(norm_all.shape().dims(), vec![1]);
// sqrt(1² + 2² + 3² + 4²) = sqrt(30) ≈ 5.477
assert!((norm_all.get(&[0]) - 30.0_f32.sqrt()).abs() < 1e-5);
§Panics
- If dims is empty
- If any dimension index is out of bounds for the tensor rank
§Performance
Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts.
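The same decomposition works per dimension; an illustrative sketch assuming pow_scalar, sum_dims, and sqrt compose as in the other examples on this page:
use train_station::Tensor;
let m = Tensor::from_slice(&[3.0, 4.0, 0.0, 5.0], vec![2, 2]).unwrap();
let direct = m.norm_dims(&[1], true);
// Per-row L2 norm rebuilt from pow_scalar, sum_dims, and sqrt
let composed = m.pow_scalar(2.0).sum_dims(&[1], true).sqrt();
assert!((direct.get(&[0, 0]) - composed.get(&[0, 0])).abs() < 1e-6);
assert!((direct.get(&[1, 0]) - composed.get(&[1, 0])).abs() < 1e-6);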
impl Tensor
pub fn std(&self) -> Tensor
Computes the standard deviation over all elements
The standard deviation is calculated as sqrt(variance) where variance is the mean of squared differences from the mean. This operation reduces the tensor to a scalar tensor with shape [1].
The implementation uses population standard deviation (divides by n rather than n-1) to match PyTorch’s default behavior.
§Returns
A scalar tensor containing the standard deviation value
§Examples
use train_station::Tensor;
// Basic standard deviation calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let std_dev = tensor.std();
assert!((std_dev.get(&[0]) - 1.118_034).abs() < 1e-5);
use train_station::Tensor;
// Standard deviation of a larger dataset
let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let std_dev = tensor.std();
// mean=4.5, var=5.25, std=sqrt(5.25)≈2.291
let expected = 5.25_f32.sqrt();
assert!((std_dev.get(&[0]) - expected).abs() < 1e-5);
use train_station::Tensor;
// Standard deviation of constant values (should be 0)
let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
let std_dev = tensor.std();
assert!((std_dev.get(&[0]) - 0.0).abs() < 1e-6);
§Performance
Uses optimized contiguous tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration. The algorithm performs two passes: first to compute the mean, then to compute the variance.
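The two-pass definition can be cross-checked against the documented building blocks; this is an illustrative sketch assuming sub_scalar, pow_scalar, mean, and sqrt compose as shown elsewhere on this page:
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let direct = x.std();
// Population standard deviation rebuilt by hand: sqrt(mean((x - mean(x))^2))
let mu = x.mean().value();
let composed = x.sub_scalar(mu).pow_scalar(2.0).mean().sqrt();
assert!((direct.get(&[0]) - composed.get(&[0])).abs() < 1e-5);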
Examples found in repository
87fn demonstrate_data_pipeline() -> Result<(), Box<dyn std::error::Error>> {
88 println!("\n--- Data Processing Pipeline ---");
89
90 // Simulate raw sensor data with noise
91 let raw_data: Vec<f32> = (0..20)
92 .map(|i| {
93 let base = i as f32 * 0.5;
94 let noise = (i % 3) as f32 * 0.1;
95 base + noise
96 })
97 .collect();
98
99 let tensor = Tensor::from_slice(&raw_data, vec![20])?;
100 println!("Raw sensor data: {:?}", tensor.data());
101
102 // Multi-stage processing pipeline
103 println!("\nProcessing pipeline:");
104 println!("1. Normalize data (z-score)");
105 println!("2. Apply smoothing filter");
106 println!("3. Detect outliers");
107 println!("4. Apply feature scaling");
108
109 // Stage 1: Normalization
110 let mean = tensor.mean().value();
111 let std = tensor.std().value();
112 let normalized: Tensor = tensor
113 .iter()
114 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
115 .collect();
116 println!(
117 " Normalized (mean={:.3}, std={:.3}): {:?}",
118 mean,
119 std,
120 normalized.data()
121 );
122
123 // Stage 2: Smoothing (simple moving average)
124 let smoothed: Tensor = normalized
125 .iter()
126 .enumerate()
127 .map(|(i, elem)| {
128 if i == 0 || i == normalized.size() - 1 {
129 elem.clone()
130 } else {
131 // Simple 3-point average
132 let prev = normalized.element_view(i - 1);
133 let next = normalized.element_view(i + 1);
134 elem.add_tensor(&prev).add_tensor(&next).div_scalar(3.0)
135 }
136 })
137 .collect();
138 println!(" Smoothed: {:?}", smoothed.data());
139
140 // Stage 3: Outlier detection and removal
141 let outlier_threshold = 2.0;
142 let cleaned: Tensor = smoothed
143 .iter()
144 .filter(|elem| elem.value().abs() < outlier_threshold)
145 .collect();
146 println!(
147 " Outliers removed (threshold={}): {:?}",
148 outlier_threshold,
149 cleaned.data()
150 );
151
152 // Stage 4: Feature scaling to [0, 1] range
153 let min_val = cleaned
154 .iter()
155 .map(|e| e.value())
156 .fold(f32::INFINITY, f32::min);
157 let max_val = cleaned
158 .iter()
159 .map(|e| e.value())
160 .fold(f32::NEG_INFINITY, f32::max);
161 let scaled: Tensor = cleaned
162 .iter()
163 .map(|elem| elem.sub_scalar(min_val).div_scalar(max_val - min_val))
164 .collect();
165 println!(" Scaled to [0,1]: {:?}", scaled.data());
166
167 Ok(())
168}
169
170/// Demonstrate conditional processing patterns
171///
172/// Shows how to implement dynamic filtering and transformation
173/// based on data characteristics and conditions.
174fn demonstrate_conditional_processing() -> Result<(), Box<dyn std::error::Error>> {
175 println!("\n--- Conditional Processing ---");
176
177 // Create data with mixed characteristics
178 let data = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
179 let tensor = Tensor::from_slice(&data, vec![10])?;
180 println!("Input data: {:?}", tensor.data());
181
182 // Conditional transformation based on sign
183 println!("\nConditional transformation (positive/negative handling):");
184 let processed: Tensor = tensor
185 .iter()
186 .map(|elem| {
187 let val = elem.value();
188 if val > 0.0 {
189 elem.pow_scalar(2.0) // Square positive values
190 } else {
191 elem.mul_scalar(-1.0).sqrt() // Square root of absolute negative values
192 }
193 })
194 .collect();
195 println!(" Processed: {:?}", processed.data());
196
197 // Adaptive filtering based on local statistics
198 println!("\nAdaptive filtering (remove values > 2 std from local mean):");
199 let window_size = 3;
200 let adaptive_filtered: Tensor = tensor
201 .iter()
202 .enumerate()
203 .filter(|(i, elem)| {
204 let start = i.saturating_sub(window_size / 2);
205 let end = (i + window_size / 2 + 1).min(tensor.size());
206
207 // Calculate local mean and std
208 let local_values: Vec<f32> = (start..end)
209 .map(|j| tensor.element_view(j).value())
210 .collect();
211
212 let local_mean = local_values.iter().sum::<f32>() / local_values.len() as f32;
213 let local_variance = local_values
214 .iter()
215 .map(|v| (v - local_mean).powi(2))
216 .sum::<f32>()
217 / local_values.len() as f32;
218 let local_std = local_variance.sqrt();
219
220 let threshold = local_mean + 2.0 * local_std;
221 elem.value() <= threshold
222 })
223 .map(|(_, elem)| elem)
224 .collect();
225 println!(" Adaptive filtered: {:?}", adaptive_filtered.data());
226
227 // Multi-condition processing
228 println!("\nMulti-condition processing:");
229 let multi_processed: Tensor = tensor
230 .iter()
231 .map(|elem| {
232 let val = elem.value();
233 match () {
234 _ if val > 5.0 => elem.mul_scalar(2.0), // Double large values
235 _ if val < -5.0 => elem.div_scalar(2.0), // Halve small values
236 _ if val.abs() < 2.0 => elem.add_scalar(1.0), // Add 1 to small values
237 _ => elem.clone(), // Keep others unchanged
238 }
239 })
240 .collect();
241 println!(" Multi-condition: {:?}", multi_processed.data());
242
243 Ok(())
244}
245
246/// Demonstrate batch processing operations
247///
248/// Shows efficient processing of large datasets using iterator
249/// patterns and batch operations for performance optimization.
250fn demonstrate_batch_operations() -> Result<(), Box<dyn std::error::Error>> {
251 println!("\n--- Batch Operations ---");
252
253 // Create a larger dataset for batch processing
254 let size = 100;
255 let data: Vec<f32> = (0..size)
256 .map(|i| {
257 let x = i as f32 / size as f32;
258 x * x + 0.1 * (i % 7) as f32 // Quadratic with some noise
259 })
260 .collect();
261
262 let tensor = Tensor::from_slice(&data, vec![size])?;
263 println!("Dataset size: {}", tensor.size());
264
265 // Batch processing with windowing (iterator views)
266 println!("\nBatch processing with sliding windows:");
267 let batch_size = 10;
268 let batches: Vec<Tensor> = tensor
269 .iter()
270 .collect::<Vec<_>>()
271 .chunks(batch_size)
272 .map(|chunk| {
273 // Process each batch independently
274 chunk
275 .iter()
276 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
277 .collect()
278 })
279 .collect();
280
281 println!(
282 " Processed {} batches of size {}",
283 batches.len(),
284 batch_size
285 );
286 for (i, batch) in batches.iter().enumerate() {
287 println!(
288 " Batch {}: mean={:.3}, std={:.3}",
289 i,
290 batch.mean().value(),
291 batch.std().value()
292 );
293 }
294
295 // Parallel-like processing with stride
296 println!("\nStrided processing (every nth element):");
297 let stride = 5;
298 let strided: Tensor = tensor
299 .iter()
300 .enumerate()
301 .filter(|(i, _)| i % stride == 0)
302 .map(|(_, elem)| elem)
303 .collect();
304 println!(" Strided (every {}th): {:?}", stride, strided.data());
305
306 // Hierarchical processing
307 println!("\nHierarchical processing (coarse to fine):");
308 let coarse: Tensor = tensor
309 .iter()
310 .enumerate()
311 .filter(|(i, _)| i % 4 == 0) // Take every 4th element
312 .map(|(_, elem)| elem)
313 .collect();
314
315 let fine: Tensor = tensor
316 .iter()
317 .enumerate()
318 .filter(|(i, _)| i % 4 != 0) // Take the rest
319 .map(|(_, elem)| elem)
320 .collect();
321
322 println!(" Coarse (every 4th): {:?}", coarse.data());
323 println!(" Fine (rest): {:?}", fine.data());
324
325 // Combine coarse and fine with different processing
326 let combined: Tensor = coarse
327 .iter()
328 .map(|elem| elem.mul_scalar(2.0)) // Scale coarse
329 .chain(fine.iter().map(|elem| elem.div_scalar(2.0))) // Scale fine
330 .collect();
331 println!(" Combined: {:?}", combined.data());
332
333 Ok(())
334}
335
336/// Demonstrate real-world processing scenarios
337///
338/// Shows practical applications of iterator patterns for
339/// common data processing tasks in machine learning and analytics.
340fn demonstrate_real_world_scenarios() -> Result<(), Box<dyn std::error::Error>> {
341 println!("\n--- Real-world Scenarios ---");
342
343 // Scenario 1: Time series analysis
344 println!("\nScenario 1: Time Series Analysis");
345 let time_series: Vec<f32> = (0..24)
346 .map(|hour| {
347 let base = 20.0 + 10.0 * (hour as f32 * std::f32::consts::PI / 12.0).sin();
348 base + (hour % 3) as f32 * 2.0 // Add some noise
349 })
350 .collect();
351
352 let series = Tensor::from_slice(&time_series, vec![24])?;
353 println!(" Time series (24 hours): {:?}", series.data());
354
355 // Calculate moving average with view-based iteration
356 let window_size = 3;
357 let moving_avg: Tensor = series
358 .iter()
359 .enumerate()
360 .map(|(i, _)| {
361 let start = i.saturating_sub(window_size / 2);
362 let end = (i + window_size / 2 + 1).min(series.size());
363 let window = series.iter_range(start, end);
364 window.fold(0.0, |acc, elem| acc + elem.value()) / (end - start) as f32
365 })
366 .map(|val| Tensor::from_slice(&[val], vec![1]).unwrap())
367 .collect();
368 println!(
369 " Moving average (window={}): {:?}",
370 window_size,
371 moving_avg.data()
372 );
373
374 // Inference pipeline with NoGrad + streaming
375 println!("\nInference pipeline (NoGrad + streaming)");
376 let features = Tensor::from_slice(
377 &(0..48).map(|i| i as f32 * 0.125).collect::<Vec<_>>(),
378 vec![6, 8],
379 )?;
380 let fast = with_no_grad(|| {
381 // Stream values directly, apply light affine, and collect back to same shape
382 features
383 .data()
384 .iter()
385 .copied()
386 .map(|x| 0.75 * x + 0.1)
387 .collect_shape(vec![6, 8])
388 });
389 println!(
390 " NoGrad streamed transform shape: {:?}",
391 fast.shape().dims()
392 );
393
394 // Row-wise iteration with shape-preserving collection (GradTrack-friendly)
395 let per_row: Tensor = features
396 .iter()
397 .map(|row| row.mul_scalar(0.5).add_scalar(2.0))
398 .collect_shape(vec![6, 8]);
399 println!(" Row-wise mapped shape: {:?}", per_row.shape().dims());
400
401 // Scenario 2: Feature engineering
402 println!("\nScenario 2: Feature Engineering");
403 let features = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5])?;
404 println!(" Original features: {:?}", features.data());
405
406 // Create polynomial features
407 let poly_features: Tensor = features
408 .iter()
409 .flat_map(|elem| {
410 vec![
411 elem.clone(), // x^1
412 elem.pow_scalar(2.0), // x^2
413 elem.pow_scalar(3.0), // x^3
414 ]
415 })
416 .collect();
417 println!(
418 " Polynomial features (x, x^2, x^3): {:?}",
419 poly_features.data()
420 );
421
422 // Scenario 3: Data augmentation
423 println!("\nScenario 3: Data Augmentation");
424 let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?;
425 println!(" Original data: {:?}", original.data());
426
427 // Augment with noise and scaling
428 let augmented: Tensor = original
429 .iter()
430 .flat_map(|elem| {
431 vec![
432 elem.clone(), // Original
433 elem.add_scalar(0.1), // Add noise
434 elem.sub_scalar(0.1), // Subtract noise
435 elem.mul_scalar(1.1), // Scale up
436 elem.mul_scalar(0.9), // Scale down
437 ]
438 })
439 .collect();
440 println!(" Augmented data: {:?}", augmented.data());
441
442 // Scenario 4: Statistical analysis
443 println!("\nScenario 4: Statistical Analysis");
444 let sample_data = Tensor::from_slice(&[1.1, 2.3, 1.8, 2.1, 1.9, 2.0, 1.7, 2.2], vec![8])?;
445 println!(" Sample data: {:?}", sample_data.data());
446
447 // Calculate various statistics
448 let mean = sample_data.mean().value();
449 let std = sample_data.std().value();
450 let min = sample_data
451 .iter()
452 .map(|e| e.value())
453 .fold(f32::INFINITY, f32::min);
454 let max = sample_data
455 .iter()
456 .map(|e| e.value())
457 .fold(f32::NEG_INFINITY, f32::max);
458
459 // Z-score normalization
460 let z_scores: Tensor = sample_data
461 .iter()
462 .map(|elem| elem.sub_scalar(mean).div_scalar(std))
463 .collect();
464
465 println!(
466 " Statistics: mean={:.3}, std={:.3}, min={:.3}, max={:.3}",
467 mean, std, min, max
468 );
469 println!(" Z-scores: {:?}", z_scores.data());
470
471 Ok(())
472}
pub fn std_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the standard deviation over specified dimensions
Reduces the tensor along the specified dimensions by computing the standard deviation of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.
Uses population standard deviation (divides by n rather than n-1) to match PyTorch’s default behavior.
§Arguments
- dims - Vector of dimension indices to reduce over (must be valid for tensor rank)
- keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns
A tensor with standard deviation computed over the specified dimensions
§Examples
use train_station::Tensor;
// Standard deviation along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
let row_stds = matrix.std_dims(&[1], true);
assert_eq!(row_stds.shape().dims(), vec![2, 1]);
assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0
use train_station::Tensor;
// Standard deviation along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_stds = matrix.std_dims(&[0], false);
assert_eq!(col_stds.shape().dims(), vec![2]);
// std([1, 3]) = 1.0, std([2, 4]) = 1.0
assert!((col_stds.get(&[0]) - 1.0).abs() < 1e-6);
assert!((col_stds.get(&[1]) - 1.0).abs() < 1e-6);
use train_station::Tensor;
// Standard deviation over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let std_all = tensor.std_dims(&[0, 1], false);
assert_eq!(std_all.shape().dims(), vec![1]);
// std([1, 2, 3, 4]) = sqrt(1.25) ≈ 1.118
assert!((std_all.get(&[0]) - 1.25_f32.sqrt()).abs() < 1e-5);
§Panics
- If dims is empty
- If any dimension index is out of bounds for the tensor rank
- If the reduced size is 0 (invalid for standard deviation calculation)
§Performance
Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts. The algorithm performs two passes: first to compute means, then to compute variances.
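A per-row standardization sketch; this illustration assumes sub_tensor broadcasts a keepdim result across the reduced dimension, as it does in the log-softmax repository examples above:
use train_station::Tensor;
let x = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
// Center each row, then inspect the per-row spread
let centered = x.sub_tensor(&x.mean_dims(&[1], true)); // broadcasts the [2, 1] row means
let row_stds = x.std_dims(&[1], true);
assert_eq!(centered.shape().dims(), vec![2, 2]);
assert!((row_stds.get(&[0, 0]) - 1.0).abs() < 1e-6); // std([1, 3]) = 1.0
assert!((row_stds.get(&[1, 0]) - 0.0).abs() < 1e-6); // std([2, 2]) = 0.0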
impl Tensor
pub fn sum(&self) -> Tensor
Returns the sum of all elements in the tensor
This operation computes the sum of all elements across all dimensions, reducing the tensor to a scalar value. The output is a tensor with shape [1] containing the sum as a float.
When requires_grad is enabled, this operation supports automatic gradient
tracking through the GradTrack system.
§Returns
A tensor with shape [1] containing the sum of all elements
§Examples
use train_station::Tensor;
// Basic sum calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let total = tensor.sum();
assert_eq!(total.shape().dims(), vec![1]);
assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
use train_station::Tensor;
// Sum with gradient tracking
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])
.unwrap()
.with_requires_grad();
let mut total = tensor.sum();
total.backward(None);
let grad = tensor.grad_owned().expect("gradient should exist");
// Gradient should be [1.0, 1.0, 1.0] for each element
assert_eq!(grad.get(&[0]), 1.0);
assert_eq!(grad.get(&[1]), 1.0);
assert_eq!(grad.get(&[2]), 1.0);
use train_station::Tensor;
// Sum of empty tensor
let tensor = Tensor::new(vec![0]);
let total = tensor.sum();
assert_eq!(total.get(&[0]), 0.0); // Sum of empty tensor is 0
§Performance
Uses optimized contiguous tensor path with 4x loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration.
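As a quick cross-check against mean (an illustration, assuming div_scalar and size behave as documented on this page):
use train_station::Tensor;
let t = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
// mean should equal sum divided by the element count
let scaled_sum = t.sum().div_scalar(t.size() as f32);
assert!((scaled_sum.get(&[0]) - t.mean().get(&[0])).abs() < 1e-6);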
Examples found in repository
More examples
179fn demonstrate_utility_functions() {
180 println!("\n--- Utility Functions ---");
181
182 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
183
184 // Basic properties
185 println!("Shape: {:?}", tensor.shape().dims());
186 println!("Size: {}", tensor.size());
187 println!("Is contiguous: {}", tensor.is_contiguous());
188 println!("Device: {:?}", tensor.device());
189
190 // Mathematical operations
191 let sum = tensor.sum();
192 println!("Sum: {}", sum.value());
193
194 let mean = tensor.mean();
195 println!("Mean: {}", mean.value());
196
197 let norm = tensor.norm();
198 println!("Norm: {}", norm.value());
199
200 // Device placement
201 let cpu_tensor = Tensor::zeros_on_device(vec![3, 3], train_station::Device::cpu());
202 println!(
203 "CPU tensor: shape {:?}, device: {:?}",
204 cpu_tensor.shape().dims(),
205 cpu_tensor.device()
206 );
207}
175fn demonstrate_gradient_tracking() -> Result<(), Box<dyn std::error::Error>> {
176 println!("\n--- Gradient Tracking ---");
177
178 // Create a tensor with gradient tracking enabled
179 let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3])?.with_requires_grad();
180 println!("Input tensor (requires_grad): {:?}", tensor.data());
181
182 // Perform element-wise operations through iteration
183 let result: Tensor = tensor
184 .iter()
185 .map(|elem| {
186 // Apply a complex transformation: (x^2 + 1) * 2
187 elem.pow_scalar(2.0).add_scalar(1.0).mul_scalar(2.0)
188 })
189 .collect();
190
191 println!("Result tensor: {:?}", result.data());
192 println!("Result requires_grad: {}", result.requires_grad());
193
194 // Compute gradients
195 let mut loss = result.sum();
196 loss.backward(None);
197
198 println!("Loss: {:.6}", loss.value());
199 println!("Input gradients: {:?}", tensor.grad().map(|g| g.data()));
200
201 Ok(())
202}
109fn demonstrate_optimizer_serialization() -> Result<(), Box<dyn std::error::Error>> {
110 println!("\n--- Optimizer Serialization ---");
111
112 // Create an optimizer with some parameters
113 let mut weight = Tensor::randn(vec![2, 2], Some(42)).with_requires_grad();
114 let mut bias = Tensor::randn(vec![2], Some(43)).with_requires_grad();
115
116 let config = AdamConfig {
117 learning_rate: 0.001,
118 beta1: 0.9,
119 beta2: 0.999,
120 eps: 1e-8,
121 weight_decay: 0.0,
122 amsgrad: false,
123 };
124
125 let mut optimizer = Adam::with_config(config);
126 optimizer.add_parameter(&weight);
127 optimizer.add_parameter(&bias);
128
129 println!(
130 "Created optimizer with {} parameters",
131 optimizer.parameter_count()
132 );
133 println!("Learning rate: {}", optimizer.learning_rate());
134
135 // Simulate some training steps
136 for _ in 0..3 {
137 let mut loss = weight.sum() + bias.sum();
138 loss.backward(None);
139 optimizer.step(&mut [&mut weight, &mut bias]);
140 optimizer.zero_grad(&mut [&mut weight, &mut bias]);
141 }
142
143 // Save optimizer state
144 let optimizer_path = "temp_optimizer.json";
145 optimizer.save_json(optimizer_path)?;
146 println!("Saved optimizer to: {}", optimizer_path);
147
148 // Load optimizer state
149 let loaded_optimizer = Adam::load_json(optimizer_path)?;
150 println!(
151 "Loaded optimizer with {} parameters",
152 loaded_optimizer.parameter_count()
153 );
154 println!("Learning rate: {}", loaded_optimizer.learning_rate());
155
156 // Verify optimizer state
157 assert_eq!(
158 optimizer.parameter_count(),
159 loaded_optimizer.parameter_count()
160 );
161 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
162 println!("Optimizer serialization verification: PASSED");
163
164 Ok(())
165}
166
167/// Demonstrate format comparison and performance characteristics
168fn demonstrate_format_comparison() -> Result<(), Box<dyn std::error::Error>> {
169 println!("\n--- Format Comparison ---");
170
171 // Create a larger tensor for comparison
172 let tensor = Tensor::randn(vec![10, 10], Some(44));
173
174 // Save in both formats
175 tensor.save_json("temp_comparison.json")?;
176 tensor.save_binary("temp_comparison.bin")?;
177
178 // Compare file sizes
179 let json_size = fs::metadata("temp_comparison.json")?.len();
180 let binary_size = fs::metadata("temp_comparison.bin")?.len();
181
182 println!("JSON file size: {} bytes", json_size);
183 println!("Binary file size: {} bytes", binary_size);
184 println!(
185 "Compression ratio: {:.2}x",
186 json_size as f64 / binary_size as f64
187 );
188
189 // Load and verify both formats
190 let json_tensor = Tensor::load_json("temp_comparison.json")?;
191 let binary_tensor = Tensor::load_binary("temp_comparison.bin")?;
192
193 assert_eq!(tensor.shape().dims(), json_tensor.shape().dims());
194 assert_eq!(tensor.shape().dims(), binary_tensor.shape().dims());
195 assert_eq!(tensor.data(), json_tensor.data());
196 assert_eq!(tensor.data(), binary_tensor.data());
197
198 println!("Format comparison verification: PASSED");
199
200 Ok(())
201}
202
203/// Demonstrate a basic model checkpointing workflow
204fn demonstrate_model_checkpointing() -> Result<(), Box<dyn std::error::Error>> {
205 println!("\n--- Model Checkpointing ---");
206
207 // Create a simple model (weights and bias)
208 let mut weights = Tensor::randn(vec![2, 1], Some(45)).with_requires_grad();
209 let mut bias = Tensor::randn(vec![1], Some(46)).with_requires_grad();
210
211 // Create optimizer
212 let mut optimizer = Adam::with_learning_rate(0.01);
213 optimizer.add_parameter(&weights);
214 optimizer.add_parameter(&bias);
215
216 println!("Initial weights: {:?}", weights.data());
217 println!("Initial bias: {:?}", bias.data());
218
219 // Simulate training
220 for epoch in 0..5 {
221 let mut loss = weights.sum() + bias.sum();
222 loss.backward(None);
223 optimizer.step(&mut [&mut weights, &mut bias]);
224 optimizer.zero_grad(&mut [&mut weights, &mut bias]);
225
226 if epoch % 2 == 0 {
227 // Save checkpoint
228 let checkpoint_dir = format!("checkpoint_epoch_{}", epoch);
229 fs::create_dir_all(&checkpoint_dir)?;
230
231 weights.save_json(format!("{}/weights.json", checkpoint_dir))?;
232 bias.save_json(format!("{}/bias.json", checkpoint_dir))?;
233 optimizer.save_json(format!("{}/optimizer.json", checkpoint_dir))?;
234
235 println!("Saved checkpoint for epoch {}", epoch);
236 }
237 }
238
239 // Load from checkpoint
240 let loaded_weights = Tensor::load_json("checkpoint_epoch_4/weights.json")?;
241 let loaded_bias = Tensor::load_json("checkpoint_epoch_4/bias.json")?;
242 let loaded_optimizer = Adam::load_json("checkpoint_epoch_4/optimizer.json")?;
243
244 println!("Loaded weights: {:?}", loaded_weights.data());
245 println!("Loaded bias: {:?}", loaded_bias.data());
246 println!(
247 "Loaded optimizer learning rate: {}",
248 loaded_optimizer.learning_rate()
249 );
250
251 // Verify checkpoint integrity
252 assert_eq!(weights.shape().dims(), loaded_weights.shape().dims());
253 assert_eq!(bias.shape().dims(), loaded_bias.shape().dims());
254 assert_eq!(optimizer.learning_rate(), loaded_optimizer.learning_rate());
255
256 println!("Checkpointing verification: PASSED");
257
258 Ok(())
259}
319fn demonstrate_optimization_techniques() -> Result<(), Box<dyn std::error::Error>> {
320 println!("\n--- Optimization Techniques ---");
321
322 let size = 50000;
323 let data: Vec<f32> = (0..size).map(|i| i as f32).collect();
324 let tensor = Tensor::from_slice(&data, vec![size])?;
325
326 println!("Optimizing processing for size: {}", size);
327
328 // Technique 1: Operation fusion
329 println!("\nTechnique 1: Operation Fusion");
330 let start = Instant::now();
331 let fused_result: Tensor = tensor
332 .iter()
333 .map(|elem| {
334 // Fuse multiple operations into single chain
335 elem.mul_scalar(2.0).add_scalar(1.0).pow_scalar(2.0).sqrt()
336 })
337 .collect();
338 let fused_time = start.elapsed();
339
340 // Technique 2: Conditional optimization
341 println!("\nTechnique 2: Conditional Optimization");
342 let start = Instant::now();
343 let conditional_result: Tensor = tensor
344 .iter()
345 .map(|elem| {
346 let val = elem.value();
347 if val < size as f32 / 2.0 {
348 elem.mul_scalar(2.0) // Simple operation for small values
349 } else {
350 elem.pow_scalar(2.0).sqrt() // Complex operation for large values
351 }
352 })
353 .collect();
354 let conditional_time = start.elapsed();
355
356 // Technique 3: Cache-friendly processing
357 println!("\nTechnique 3: Cache-Friendly Processing");
358 let start = Instant::now();
359 let cache_friendly_result: Tensor = tensor
360 .iter()
361 .take(1000) // Process in cache-friendly chunks
362 .map(|elem| elem.mul_scalar(2.0))
363 .collect();
364 let cache_friendly_time = start.elapsed();
365
366 // Technique 4: Memory pooling simulation
367 println!("\nTechnique 4: Memory Pooling Simulation");
368 let start = Instant::now();
369 let pooled_result: Tensor = tensor
370 .iter()
371 .enumerate()
372 .filter(|(i, _)| i % 100 == 0) // Process every 100th element
373 .map(|(_, elem)| elem.pow_scalar(2.0))
374 .collect();
375 let pooled_time = start.elapsed();
376
377 // Report optimization results
378 println!(" Fused operations: {:?}", fused_time);
379 println!(" Conditional optimization: {:?}", conditional_time);
380 println!(" Cache-friendly processing: {:?}", cache_friendly_time);
381 println!(" Memory pooling simulation: {:?}", pooled_time);
382
383 // Performance analysis
384 let fastest = fused_time
385 .min(conditional_time)
386 .min(cache_friendly_time)
387 .min(pooled_time);
388 println!(" Fastest technique: {:?}", fastest);
389
390 // Memory efficiency analysis
391 println!(" Fused result size: {}", fused_result.size());
392 println!(" Conditional result size: {}", conditional_result.size());
393 println!(
394 " Cache-friendly result size: {}",
395 cache_friendly_result.size()
396 );
397 println!(" Pooled result size: {}", pooled_result.size());
398
399 // Technique 5: Gradient optimization
400 println!("\nTechnique 5: Gradient Optimization");
401 let grad_tensor = tensor.with_requires_grad();
402 let start = Instant::now();
403
404 let grad_result: Tensor = grad_tensor
405 .iter()
406 .map(|elem| elem.pow_scalar(2.0).add_scalar(1.0))
407 .collect();
408
409 let mut loss = grad_result.sum();
410 loss.backward(None);
411 let grad_time = start.elapsed();
412
413 println!(" Gradient computation: {:?}", grad_time);
414 println!(
415 " Gradient tracking enabled: {}",
416 grad_result.requires_grad()
417 );
418
419 Ok(())
420}
Sourcepub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
pub fn sum_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Returns the sum of elements along specified dimensions
This operation computes the sum of elements along the specified dimensions, reducing the tensor while optionally preserving the reduced dimensions as size-1 dimensions.
The output shape depends on the keepdim parameter:
- If keepdim is true, the reduced dimensions are kept with size 1
- If keepdim is false, the reduced dimensions are removed
When requires_grad is enabled, this operation supports automatic gradient
tracking through the GradTrack system.
§Arguments
- dims - Vector of dimension indices to sum over (must be valid for tensor rank)
- keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns
A tensor with sum computed over the specified dimensions
§Panics
- If dims is empty
- If any dimension index is out of bounds for the tensor rank
§Examples
use train_station::Tensor;
// Sum along rows (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let row_sums = matrix.sum_dims(&[0], false);
assert_eq!(row_sums.shape().dims(), vec![2]);
assert_eq!(row_sums.get(&[0]), 4.0); // 1 + 3 = 4
assert_eq!(row_sums.get(&[1]), 6.0); // 2 + 4 = 6
use train_station::Tensor;
// Sum along columns (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_sums = matrix.sum_dims(&[1], true);
assert_eq!(col_sums.shape().dims(), vec![2, 1]);
assert_eq!(col_sums.get(&[0, 0]), 3.0); // 1 + 2 = 3
assert_eq!(col_sums.get(&[1, 0]), 7.0); // 3 + 4 = 7
use train_station::Tensor;
// Sum over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let total = tensor.sum_dims(&[0, 1], false);
assert_eq!(total.shape().dims(), vec![1]);
assert_eq!(total.get(&[0]), 10.0); // 1 + 2 + 3 + 4 = 10
use train_station::Tensor;
// Sum with gradient tracking
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2])
.unwrap()
.with_requires_grad();
let mut row_sums = tensor.sum_dims(&[0], false);
row_sums.backward(None);
let grad = tensor.grad_owned().expect("gradient should exist");
// Gradient should be [1.0, 1.0, 1.0, 1.0] for each element
assert_eq!(grad.get(&[0, 0]), 1.0);
assert_eq!(grad.get(&[0, 1]), 1.0);
assert_eq!(grad.get(&[1, 0]), 1.0);
assert_eq!(grad.get(&[1, 1]), 1.0);
§Performance
Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts.
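Because the reduction iterates by coordinates rather than by raw memory offset, it can be applied to non-contiguous views as well. A minimal sketch, assuming transpose behaves as documented elsewhere on this page:
use train_station::Tensor;
// Sum a transposed (non-contiguous) view along dimension 0
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = matrix.transpose(0, 1); // logical values [[1, 3], [2, 4]]
let sums = transposed.sum_dims(&[0], false);
assert_eq!(sums.shape().dims(), vec![2]);
assert_eq!(sums.get(&[0]), 3.0); // 1 + 2
assert_eq!(sums.get(&[1]), 7.0); // 3 + 4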
Examples found in repository?
44fn cross_entropy_logits(
45 logits: &Tensor,
46 labels: &[usize],
47 batch: usize,
48 _num_classes: usize,
49) -> Tensor {
50 // log_softmax = logits - logsumexp(logits, dim=1)
51 let max_logits = logits.max_dims(&[1], true);
52 let shifted = logits.sub_tensor(&max_logits);
53 let exp = shifted.exp();
54 let sum_exp = exp.sum_dims(&[1], true);
55 let log_sum_exp = sum_exp.log();
56 let log_softmax = shifted.sub_tensor(&log_sum_exp);
57 let ll = log_softmax.gather(1, labels, &[batch, 1]); // selected log-probs
58 ll.mul_scalar(-1.0).mean()
59}
More examples
273fn log_prob_actions(
274 logits: &Tensor,
275 actions: &[usize],
276 batch: usize,
277 _action_dim: usize,
278) -> Tensor {
279 let max_logits = logits.max_dims(&[1], true); // [B,1]
280 let shifted = logits.sub_tensor(&max_logits);
281 let exp = shifted.exp();
282 let sum_exp = exp.sum_dims(&[1], true); // [B,1]
283 let log_sum_exp = sum_exp.log(); // [B,1]
284 let log_softmax = shifted.sub_tensor(&log_sum_exp); // [B,A]
285 // gather selected action log-probs
286 log_softmax.gather(1, actions, &[batch, 1])
287}
288
289// probability ratio = exp(new_logp - old_logp)
290fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
291 new_logp.sub_tensor(old_logp).exp()
292}
293
294// Clamp ratio to [1-clip, 1+clip] using ReLU-based clamp (no custom ops)
295fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
296 let b = ratio.shape().dims()[0];
297 let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
298 let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
299 let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
300 high.sub_tensor(&ge_low.sub_tensor(&high).relu())
301}
302
303fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
304 let mut total_sq = 0.0f32;
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 for &v in g.data() {
308 total_sq += v * v;
309 }
310 }
311 }
312 total_sq.sqrt()
313}
314
315// -------------------------------
316// Main
317// -------------------------------
318
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
234fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
235 // All tensors shaped [B, A] (log_std is broadcastable)
236 let std = log_std.exp();
237 let var = std.pow_scalar(2.0);
238 let log_scale = log_std;
239 let diff = action.sub_tensor(mean);
240 let log_prob = diff
241 .pow_scalar(2.0)
242 .div_tensor(&var)
243 .add_scalar(std::f32::consts::LN_2 + std::f32::consts::PI)
244 .add_tensor(&log_scale.mul_scalar(2.0))
245 .mul_scalar(0.5)
246 .mul_scalar(-1.0);
247 // Sum across action dim (dim=1) -> [B,1]
248 log_prob.sum_dims(&[1], true)
249}
250
251#[allow(clippy::too_many_arguments)]
252fn compute_gae(
253 returns_out: &mut [f32],
254 adv_out: &mut [f32],
255 rewards: &[f32],
256 dones: &[f32],
257 values: &[f32],
258 next_values: &[f32],
259 gamma: f32,
260 lam: f32,
261) {
262 let n = rewards.len();
263 let mut gae = 0.0f32;
264 for t in (0..n).rev() {
265 let not_done = 1.0 - dones[t];
266 let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
267 gae = delta + gamma * lam * not_done * gae;
268 adv_out[t] = gae;
269 returns_out[t] = gae + values[t];
270 }
271}
272
273fn normalize_in_place(x: &mut [f32], eps: f32) {
274 let n = x.len() as f32;
275 if n <= 1.0 {
276 return;
277 }
278 let mean = x.iter().copied().sum::<f32>() / n;
279 let var = x
280 .iter()
281 .map(|v| {
282 let d = v - mean;
283 d * d
284 })
285 .sum::<f32>()
286 / n;
287 let std = (var + eps).sqrt();
288 for v in x.iter_mut() {
289 *v = (*v - mean) / std;
290 }
291}
292
293fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
294 let mut total_sq = 0.0f32;
295 for p in parameters.iter() {
296 if let Some(g) = p.grad_owned() {
297 for &v in g.data() {
298 total_sq += v * v;
299 }
300 }
301 }
302 let norm = total_sq.sqrt();
303 if norm > max_norm {
304 let scale = max_norm / (norm + eps);
305 for p in parameters.iter_mut() {
306 if let Some(g) = p.grad_owned() {
307 p.set_grad(g.mul_scalar(scale));
308 }
309 }
310 }
311}
312
313fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
314 let mut total_sq = 0.0f32;
315 for p in parameters.iter_mut() {
316 if let Some(g) = p.grad_owned() {
317 for &v in g.data() {
318 total_sq += v * v;
319 }
320 }
321 }
322 total_sq.sqrt()
323}
324
325// -------------------------------
326// Main
327// -------------------------------
328
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381 // Policy forward (detached sampling to not blow graph; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
Source§impl Tensor
impl Tensor
Sourcepub fn var(&self) -> Tensor
pub fn var(&self) -> Tensor
Computes the variance over all elements
The variance is calculated as the mean of squared differences from the mean. This operation reduces the tensor to a single-element tensor with shape [1].
The implementation uses population variance (divides by n rather than n-1) to match PyTorch’s default behavior.
§Returns
A scalar tensor containing the variance value
§Examples
use train_station::Tensor;
// Basic variance calculation
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let variance = tensor.var();
assert!((variance.get(&[0]) - 1.25).abs() < 1e-5);
use train_station::Tensor;
// Variance of a larger dataset
let data = vec![1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0];
let tensor = Tensor::from_slice(&data, vec![2, 2, 2]).unwrap();
let variance = tensor.var();
// mean=4.5, var=mean([3.5², 1.5², 0.5², 2.5², 2.5², 0.5², 1.5², 3.5²]) = 5.25
assert!((variance.get(&[0]) - 5.25).abs() < 1e-5);
use train_station::Tensor;
// Variance of constant values (should be 0)
let tensor = Tensor::from_slice(&[5.0, 5.0, 5.0, 5.0], vec![4]).unwrap();
let variance = tensor.var();
assert!((variance.get(&[0]) - 0.0).abs() < 1e-6);
§Performance
Uses optimized contiguous tensor path with manual loop unrolling for better performance. Non-contiguous tensors use stride-aware iteration. The algorithm performs two passes: first to compute the mean, then to compute the variance.
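If a sample (Bessel-corrected) estimate is needed instead, it can be derived from the population variance by rescaling with n / (n - 1). A minimal sketch using only the var and size methods documented on this page:
use train_station::Tensor;
// Population variance of [1, 2, 3, 4] is 1.25; the sample estimate rescales by n / (n - 1)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let n = tensor.size() as f32;
let sample_var = tensor.var().get(&[0]) * n / (n - 1.0);
assert!((sample_var - 1.25 * 4.0 / 3.0).abs() < 1e-5);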
Sourcepub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
pub fn var_dims(&self, dims: &[usize], keepdim: bool) -> Tensor
Computes the variance over specified dimensions
Reduces the tensor along the specified dimensions by computing the variance of each slice. The result maintains the original tensor structure with reduced dimensions optionally preserved as size-1 dimensions.
Uses population variance (divides by n rather than n-1) to match PyTorch’s default behavior.
§Arguments
- dims - Vector of dimension indices to reduce over (must be valid for tensor rank)
- keepdim - Whether to keep reduced dimensions as size-1 dimensions
§Returns
A tensor with variance computed over the specified dimensions
§Examples
use train_station::Tensor;
// Variance along rows (dimension 1) with keepdim=true
let matrix = Tensor::from_slice(&[1.0, 3.0, 2.0, 2.0], vec![2, 2]).unwrap();
let row_vars = matrix.var_dims(&[1], true);
assert_eq!(row_vars.shape().dims(), vec![2, 1]);
assert!((row_vars.get(&[0, 0]) - 1.0).abs() < 1e-6); // var([1, 3]) = 1.0
assert!((row_vars.get(&[1, 0]) - 0.0).abs() < 1e-6); // var([2, 2]) = 0.0
use train_station::Tensor;
// Variance along columns (dimension 0) with keepdim=false
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let col_vars = matrix.var_dims(&[0], false);
assert_eq!(col_vars.shape().dims(), vec![2]);
// var([1, 3]) = 1.0, var([2, 4]) = 1.0
assert!((col_vars.get(&[0]) - 1.0).abs() < 1e-6);
assert!((col_vars.get(&[1]) - 1.0).abs() < 1e-6);
use train_station::Tensor;
// Variance over multiple dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let var_all = tensor.var_dims(&[0, 1], false);
assert_eq!(var_all.shape().dims(), vec![1]);
// var([1, 2, 3, 4]) = 1.25
assert!((var_all.get(&[0]) - 1.25).abs() < 1e-5);
§Panics
- If dims is empty
- If any dimension index is out of bounds for the tensor rank
- If the reduced size is 0 (invalid for variance calculation)
§Performance
Uses efficient coordinate-based iteration that works correctly with both contiguous and non-contiguous tensor layouts. The algorithm performs two passes: first to compute means, then to compute variances.
Source§impl Tensor
impl Tensor
Sourcepub fn cat(tensors: &[Tensor], dim: usize) -> Tensor
pub fn cat(tensors: &[Tensor], dim: usize) -> Tensor
Concatenate tensors along a given dimension
Joins multiple tensors along the specified dimension, creating a new tensor with the combined data. All input tensors must have the same rank and matching dimensions except for the concatenation dimension.
§Arguments
- tensors - Slice of tensors to concatenate (must not be empty)
- dim - Dimension along which to concatenate (must be < tensor rank)
§Returns
A new tensor containing the concatenated data with shape where the concatenation dimension is the sum of all input tensor sizes along that dimension.
§Panics
- If tensors is empty
- If dim is out of bounds for the tensor rank
- If tensors have different ranks
- If tensors have mismatched dimensions (except along concatenation dimension)
§Examples
use train_station::Tensor;
// Concatenate 1D tensors
let a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
let result = Tensor::cat(&[a, b], 0);
assert_eq!(result.shape().dims(), vec![4]);
assert_eq!(result.get(&[0]), 1.0);
assert_eq!(result.get(&[1]), 2.0);
assert_eq!(result.get(&[2]), 3.0);
assert_eq!(result.get(&[3]), 4.0);
use train_station::Tensor;
// Concatenate 2D tensors along dimension 1
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0], vec![2, 1]).unwrap();
let result = Tensor::cat(&[a, b], 1);
assert_eq!(result.shape().dims(), vec![2, 3]);
assert_eq!(result.get(&[0, 0]), 1.0);
assert_eq!(result.get(&[0, 1]), 2.0);
assert_eq!(result.get(&[0, 2]), 5.0);
use train_station::Tensor;
// Concatenate with gradient tracking
let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
a.set_requires_grad(true);
b.set_requires_grad(true);
let result = Tensor::cat(&[a, b], 0);
assert!(result.requires_grad());
Source§impl Tensor
impl Tensor
Sourcepub fn contiguous(&self) -> Tensor
pub fn contiguous(&self) -> Tensor
Creates a contiguous copy of the tensor
This operation ensures that the tensor data is stored in a linear, cache-friendly memory layout. If the tensor is already contiguous, this operation returns a clone. For non-contiguous tensors, it creates a new tensor with the same data but in contiguous memory layout.
The operation uses different optimization strategies based on tensor size:
- Small tensors (≤64 elements): Simple coordinate-based copy
- Medium tensors (65-1023 elements): Unrolled copy for better performance
- Large tensors (≥1024 elements): Blocked copy with cache optimization
§Returns
A new tensor with contiguous memory layout containing the same data
§Examples
use train_station::Tensor;
// Already contiguous tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let contiguous = tensor.contiguous();
assert!(contiguous.is_contiguous());
assert_eq!(contiguous.shape().dims(), vec![2, 2]);
use train_station::Tensor;
// Non-contiguous tensor from transpose
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = tensor.transpose(0, 1);
assert!(!transposed.is_contiguous());
let contiguous = transposed.contiguous();
assert!(contiguous.is_contiguous());
assert_eq!(contiguous.get(&[0, 0]), 1.0);
assert_eq!(contiguous.get(&[0, 1]), 3.0);
use train_station::Tensor;
// Preserves gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let contiguous = tensor.contiguous();
assert!(contiguous.requires_grad());
§Performance
- Already contiguous: O(1) time complexity, returns a clone
- Non-contiguous: O(n) time complexity with size-dependent optimizations
- Memory usage: Creates a new tensor with the same size as the original
Examples found in repository?
53 pub fn forward(&self, input: &Tensor, attn_mask: Option<&Tensor>) -> Tensor {
54 let attn = self.mha.forward(input, input, input, attn_mask);
55 let res1 = attn.add_tensor(input);
56
57 // Feed-forward network with ReLU and residual
58 let (b, t, e) = Self::triple(input);
59 let x2d = res1.contiguous().view(vec![(b * t) as i32, e as i32]);
60 let hidden = self.ffn_in.forward(&x2d).relu();
61 let out2d = self.ffn_out.forward(&hidden);
62 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
63 out.add_tensor(&res1)
64 }
More examples
56 pub fn forward(
57 &self,
58 tgt: &Tensor,
59 memory: &Tensor,
60 causal_mask: Option<&Tensor>,
61 cross_mask: Option<&Tensor>,
62 ) -> Tensor {
63 let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
64 let res1 = self_attn.add_tensor(tgt);
65
66 let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
67 let res2 = cross.add_tensor(&res1);
68
69 let (b, t, e) = Self::triple(tgt);
70 let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
71 let hidden = self.ffn_in.forward(&x2d).relu();
72 let out2d = self.ffn_out.forward(&hidden);
73 let out = out2d.view(vec![b as i32, t as i32, e as i32]);
74 out.add_tensor(&res2)
75 }
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123 }
Source§impl Tensor
impl Tensor
Sourcepub fn flatten(&self) -> Tensor
pub fn flatten(&self) -> Tensor
Flatten the tensor into a 1D representation
Transforms a multi-dimensional tensor into a 1D tensor by reshaping
all dimensions into a single dimension. This is equivalent to
reshape(vec![-1]) where -1 automatically calculates the size
based on the total number of elements.
The flatten operation preserves the total number of elements while changing the tensor’s shape to have a single dimension. This is commonly used in neural networks to prepare tensor data for linear layers or feature extraction.
§Returns
A 1D tensor containing the same data as the original tensor, with
shape [total_elements] where total_elements is the product of
all original dimensions.
§Examples
use train_station::Tensor;
// Flatten a 2D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![4]);
assert_eq!(flattened.get(&[0]), 1.0);
assert_eq!(flattened.get(&[1]), 2.0);
assert_eq!(flattened.get(&[2]), 3.0);
assert_eq!(flattened.get(&[3]), 4.0);
use train_station::Tensor;
// Flatten a 3D tensor
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![12]);
assert_eq!(flattened.size(), 12);
use train_station::Tensor;
// Flatten with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let flattened = tensor.flatten();
assert!(flattened.requires_grad());
assert_eq!(flattened.shape().dims(), vec![4]);
use train_station::Tensor;
// Flatten an already 1D tensor (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let flattened = tensor.flatten();
assert_eq!(flattened.shape().dims(), vec![3]);
assert_eq!(flattened.size(), 3);
§Performance
- Time Complexity: O(1) - Returns a view when possible
- Memory Usage: No additional memory allocation for view operations
- Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is equivalent to:
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let flattened = tensor.reshape(vec![-1]);
Where -1 is a special value that automatically calculates the
dimension size based on the total number of elements in the tensor.
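As a usage note for the neural-network case mentioned above, per-example features are typically flattened and then regrouped into a [batch, features] matrix before a linear layer. A minimal sketch combining flatten with reshape (documented below); the [2, 6] batch split is illustrative only:
use train_station::Tensor;
// Flatten a [2, 2, 3] activation tensor, then regroup it as [batch, features] = [2, 6]
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let activations = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
let flat = activations.flatten();
assert_eq!(flat.shape().dims(), vec![12]);
let per_example = flat.reshape(vec![2, -1]);
assert_eq!(per_example.shape().dims(), vec![2, 6]);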
Source§impl Tensor
impl Tensor
Sourcepub fn permute(&self, dims: Vec<usize>) -> Tensor
pub fn permute(&self, dims: Vec<usize>) -> Tensor
Permute tensor dimensions according to specified order
Rearranges the dimensions of the tensor according to the provided dimension order. This operation returns a view with reordered strides, avoiding data copying while changing the logical arrangement of the tensor’s dimensions.
The permutation is specified as a vector where each element represents
the new position of the corresponding dimension from the original tensor.
For example, permute(vec![1, 0]) swaps the first two dimensions.
§Arguments
- dims - Vector specifying the new order of dimensions (must have length equal to tensor rank)
§Returns
A new tensor view with rearranged dimensions and correspondingly adjusted strides. The total number of elements remains unchanged.
§Panics
- If dims length does not equal the tensor rank
- If any dimension index is out of bounds for the tensor rank
- If dims contains duplicate dimension indices
§Examples
use train_station::Tensor;
// Permute 2D tensor (swap dimensions)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let permuted = tensor.permute(vec![1, 0]);
assert_eq!(permuted.shape().dims(), vec![3, 2]);
assert_eq!(permuted.get(&[0, 0]), 1.0);
assert_eq!(permuted.get(&[1, 0]), 2.0);
assert_eq!(permuted.get(&[2, 1]), 6.0);
use train_station::Tensor;
// Permute 3D tensor (reorder dimensions)
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let permuted = tensor.permute(vec![2, 0, 1]);
assert_eq!(permuted.shape().dims(), vec![4, 2, 3]);
assert_eq!(permuted.size(), 24); // Total elements unchanged
use train_station::Tensor;
// Permute with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let permuted = tensor.permute(vec![1, 0]);
assert!(permuted.requires_grad());
assert_eq!(permuted.shape().dims(), vec![2, 2]);
use train_station::Tensor;
// Identity permutation (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let permuted = tensor.permute(vec![0, 1]);
assert_eq!(permuted.shape().dims(), vec![2, 2]);
assert_eq!(permuted.get(&[0, 0]), 1.0);
assert_eq!(permuted.get(&[1, 1]), 4.0);
§Performance
- Time Complexity: O(1) - Returns a view with reordered strides
- Memory Usage: No additional memory allocation (view operation)
- Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is similar to transpose() but more general:
- transpose(dim0, dim1) is equivalent to permute() with a swap of two dimensions
- permute() can handle arbitrary dimension reordering for tensors of any rank
§Memory Layout
The permuted tensor maintains the same underlying data but with reordered strides. This means the tensor becomes non-contiguous unless the permutation is the identity permutation.
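A minimal sketch of this contiguity behavior, combining permute with the contiguous operation documented above:
use train_station::Tensor;
// A non-identity permutation reorders strides, so the view is no longer contiguous
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let permuted = tensor.permute(vec![1, 0]);
assert!(!permuted.is_contiguous());
// A contiguous copy repacks the same logical values into row-major order
let packed = permuted.contiguous();
assert!(packed.is_contiguous());
assert_eq!(packed.shape().dims(), vec![3, 2]);
assert_eq!(packed.get(&[2, 1]), 6.0);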
Examples found in repository?
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123 }
124
125 fn project_qkv(
126 query: &Tensor,
127 key: &Tensor,
128 value: &Tensor,
129 q_proj: &LinearLayer,
130 k_proj: &LinearLayer,
131 v_proj: &LinearLayer,
132 ) -> (Tensor, Tensor, Tensor) {
133 let (bq, tq, eq) = Self::triple(query);
134 let (bk, tk, ek) = Self::triple(key);
135 let (_bv, tv, ev) = Self::triple(value);
136 assert!(eq == ek && ek == ev, "Q,K,V embed dims must match");
137 let q2d = query.view(vec![(bq * tq) as i32, eq as i32]);
138 let k2d = key.view(vec![(bk * tk) as i32, ek as i32]);
139 let v2d = value.view(vec![(_bv * tv) as i32, ev as i32]);
140 let q = q_proj
141 .forward(&q2d)
142 .view(vec![bq as i32, tq as i32, eq as i32]);
143 let k = k_proj
144 .forward(&k2d)
145 .view(vec![bk as i32, tk as i32, ek as i32]);
146 let v = v_proj
147 .forward(&v2d)
148 .view(vec![bk as i32, tv as i32, ev as i32]);
149 (q, k, v)
150 }
151
152 fn split_heads(x: &Tensor, b: usize, t: usize, h: usize, d: usize) -> Tensor {
153 x.view(vec![b as i32, t as i32, h as i32, d as i32])
154 .permute(vec![0, 2, 1, 3])
155 }Source§impl Tensor
impl Tensor
Sourcepub fn reshape(&self, new_shape: Vec<i32>) -> Tensor
pub fn reshape(&self, new_shape: Vec<i32>) -> Tensor
Reshape the tensor to the specified dimensions
Changes the shape of the tensor while preserving the total number of elements. This operation returns a view when the tensor is contiguous, avoiding data copying. For non-contiguous tensors, data is copied to ensure the reshape is valid.
The reshape operation supports automatic dimension inference using -1, which allows one dimension to be automatically calculated based on the total number of elements and the other specified dimensions.
§Arguments
- new_shape - Target shape for the tensor. Use -1 for one dimension to have it automatically inferred from the total size.
§Returns
A new tensor with the specified shape containing the same data as the original tensor.
§Panics
- If more than one dimension is -1
- If the total number of elements doesn’t match the original tensor
- If any dimension size is 0 or less than -1
- If the inferred dimension size is not a whole number
§Examples
use train_station::Tensor;
// Basic reshape
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let reshaped = tensor.reshape(vec![3, 2]);
assert_eq!(reshaped.shape().dims(), vec![3, 2]);
assert_eq!(reshaped.get(&[0, 0]), 1.0);
assert_eq!(reshaped.get(&[2, 1]), 6.0);
use train_station::Tensor;
// Using -1 for automatic dimension inference
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![4]).unwrap();
let reshaped = tensor.reshape(vec![2, -1]);
assert_eq!(reshaped.shape().dims(), vec![2, 2]);
assert_eq!(reshaped.get(&[0, 0]), 1.0);
assert_eq!(reshaped.get(&[1, 1]), 4.0);
use train_station::Tensor;
// Reshape with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let reshaped = tensor.reshape(vec![4]);
assert!(reshaped.requires_grad());
assert_eq!(reshaped.shape().dims(), vec![4]);
use train_station::Tensor;
// Reshape 3D tensor
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let reshaped = tensor.reshape(vec![6, 4]);
assert_eq!(reshaped.shape().dims(), vec![6, 4]);
assert_eq!(reshaped.size(), 24);
§Performance
- Contiguous tensors: O(1) time complexity, returns a view
- Non-contiguous tensors: O(n) time complexity with data copying
- Memory usage: No additional allocation for view operations
- Gradient tracking: Preserves gradient requirements and tracking
§Automatic Dimension Inference
When using -1 for a dimension, the size is automatically calculated:
use train_station::Tensor;
// For a tensor with 12 elements
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![3, 4]).unwrap();
let reshaped1 = tensor.reshape(vec![3, -1]); // Results in shape [3, 4]
let reshaped2 = tensor.reshape(vec![-1, 6]); // Results in shape [2, 6]
let reshaped3 = tensor.reshape(vec![-1]); // Results in shape [12]
Examples found in repository?
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}
More examples
329pub fn main() -> Result<(), Box<dyn std::error::Error>> {
330 println!("=== PPO Continuous Example (YardEnv) ===");
331
332 let state_dim = 3usize;
333 let action_dim = 1usize;
334
335 // Hparams
336 let total_steps = std::env::var("PPO_STEPS")
337 .ok()
338 .and_then(|v| v.parse::<usize>().ok())
339 .unwrap_or(4000usize);
340 let horizon = 128usize; // rollout length per update
341 let epochs = 4usize; // PPO epochs per update
342 let mini_batch_size = 64usize; // minibatch from horizon
343 let gamma = 0.99f32;
344 let lam = 0.95f32; // GAE lambda
345 let clip_eps = 0.2f32;
346 let vf_coef = 0.5f32;
347 let ent_coef = 0.0f32;
348 let max_grad_norm = 1.0f32;
349
350 // Models
351 let mut actor = Actor::new(state_dim, action_dim, Some(101));
352 let mut critic = Critic::new(state_dim, Some(202));
353
354 // Opts
355 let mut actor_opt = Adam::with_learning_rate(3e-4);
356 for p in actor.parameters() {
357 actor_opt.add_parameter(p);
358 }
359 let mut critic_opt = Adam::with_learning_rate(3e-4);
360 for p in critic.parameters() {
361 critic_opt.add_parameter(p);
362 }
363
364 // Env and RNG
365 let mut env = YardEnv::new(42);
366 let mut rng = SmallRng::new(999);
367 let mut state = env.reset();
368
369 // Metrics
370 let mut episode_return = 0.0f32;
371 let mut episode = 0usize;
372 let mut ema_return: Option<f32> = None;
373 let ema_alpha = 0.05f32;
374 let mut best_return = f32::NEG_INFINITY;
375
376 let mut t = 0usize;
377 while t < total_steps {
378 // Collect a rollout
379 let mut batch = RolloutBatch::new(horizon, state_dim);
380 for _ in 0..horizon {
381                // Policy forward (sampling is detached so the graph does not grow; we use stored log_probs)
382 let (mean, log_std_row) = actor.forward(&state);
383 let mean_v = mean.data()[0];
384 let log_std_v = log_std_row.data()[0];
385 let std_v = log_std_v.exp();
386 let noise = rng.normal();
387 let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);
388
389 // Build action tensor [1, A] for log_prob calculation with autograd
390 let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
391 let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
392 let log_prob_v = log_prob_t.data()[0];
393
394 // Step env
395 let (next_state, reward, done) = env.step(action_v);
396 episode_return += reward;
397
398 // Value
399 let value_t = critic.forward(&state);
400 let value_v = value_t.data()[0];
401
402 // Push
403 batch.push(
404 state.data(),
405 action_v,
406 log_prob_v,
407 reward,
408 if done { 1.0 } else { 0.0 },
409 value_v,
410 next_state.data(),
411 );
412
413 // Reset
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return
430 );
431 episode_return = 0.0;
432 episode += 1;
433 st
434 } else {
435 next_state
436 };
437
438 t += 1;
439 if t >= total_steps {
440 break;
441 }
442 }
443
444 // Bootstrap next values for GAE
445 let next_values: Vec<f32> = {
446 let mut out = Vec::with_capacity(batch.len());
447 for i in 0..batch.len() {
448 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
449 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
450 let v2 = critic.forward(&s2_t).data()[0];
451 out.push(v2);
452 }
453 out
454 };
455
456 // Compute returns and advantages
457 let mut returns = vec![0.0f32; batch.len()];
458 let mut adv = vec![0.0f32; batch.len()];
459 compute_gae(
460 &mut returns,
461 &mut adv,
462 &batch.rewards,
463 &batch.dones,
464 &batch.values,
465 &next_values,
466 gamma,
467 lam,
468 );
469 normalize_in_place(&mut adv, 1e-8);
470
471 // Prepare tensors for training
472 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
473 let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
474 let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
475 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
476 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
477
478 // PPO epochs over the rollout
479 let num_minibatches = batch.len().div_ceil(mini_batch_size);
480 for e in 0..epochs {
481 for mb in 0..num_minibatches {
482 let start = mb * mini_batch_size;
483 let end = (start + mini_batch_size).min(batch.len());
484 if start >= end {
485 break;
486 }
487
488 // Slice views
489 let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
490 let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
491 let a_mb = actions_t
492 .slice_view(start * action_dim, 1, (end - start) * action_dim)
493 .reshape(vec![(end - start) as i32, action_dim as i32]);
494 let oldlp_mb = old_logp_t
495 .slice_view(start, 1, end - start)
496 .reshape(vec![(end - start) as i32, 1]);
497 let ret_mb = returns_t
498 .slice_view(start, 1, end - start)
499 .reshape(vec![(end - start) as i32, 1]);
500 let adv_mb = adv_t
501 .slice_view(start, 1, end - start)
502 .reshape(vec![(end - start) as i32, 1]);
503
504 // Zero grads
505 {
506 let mut ps = actor.parameters();
507 actor_opt.zero_grad(&mut ps);
508 }
509 {
510 let mut ps = critic.parameters();
511 critic_opt.zero_grad(&mut ps);
512 }
513
514 // Forward actor and critic
515 let (mean_mb, log_std_row) = actor.forward(&s_mb);
516 let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
517 let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new-old)
518 let clip_low =
519 Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
520 .unwrap();
521 let clip_high =
522 Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
523 .unwrap();
524 // ratio_clipped = min(max(ratio, low), high) using ReLU identities
525 let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
526 let ratio_clipped =
527 clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
528 let pg1 = ratio.mul_tensor(&adv_mb);
529 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
530 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
531 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
532 let actor_loss = actor_min.mul_scalar(-1.0).mean();
533
534 let v_pred = critic.forward(&s_mb);
535 let v_loss = v_pred
536 .sub_tensor(&ret_mb)
537 .pow_scalar(2.0)
538 .mean()
539 .mul_scalar(vf_coef);
540
541 // Entropy (approx Gaussian entropy per action)
542 let entropy = log_std_row
543 .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
544 .sum_dims(&[1], true)
545 .mean()
546 .mul_scalar(ent_coef);
547
548 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
549 loss.backward(None);
550
551 // Step actor
552 {
553 let params = actor.parameters();
554 let mut with_grads: Vec<&mut Tensor> = Vec::new();
555 for p in params {
556 if p.grad_owned().is_some() {
557 with_grads.push(p);
558 }
559 }
560 if !with_grads.is_empty() {
561 let _ = grad_global_norm(&mut with_grads);
562 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
563 actor_opt.step(&mut with_grads);
564 actor_opt.zero_grad(&mut with_grads);
565 }
566 }
567
568 // Step critic
569 {
570 let params = critic.parameters();
571 let mut with_grads: Vec<&mut Tensor> = Vec::new();
572 for p in params {
573 if p.grad_owned().is_some() {
574 with_grads.push(p);
575 }
576 }
577 if !with_grads.is_empty() {
578 let _ = grad_global_norm(&mut with_grads);
579 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
580 critic_opt.step(&mut with_grads);
581 critic_opt.zero_grad(&mut with_grads);
582 }
583 }
584
585 // Occasionally log
586 if e == 0 && mb == 0 {
587 println!(
588 "update@t={} | actor_loss={:.4} v_loss={:.4}",
589 t,
590 actor_loss.value(),
591 v_loss.value()
592 );
593 }
594
595 clear_all_graphs_known();
596 }
597 }
598 }
599
600 println!("=== PPO training finished ===");
601 Ok(())
602}
Source§impl Tensor
impl Tensor
Sourcepub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor>
pub fn split(&self, split_size: usize, dim: usize) -> Vec<Tensor>
Split tensor into chunks of equal size along specified dimension
Divides the tensor into multiple smaller tensors along the specified dimension, where each chunk (except possibly the last) has the same size. The last chunk may be smaller if the dimension size is not evenly divisible by the split size.
This operation returns a vector of tensors, where each tensor is a view or copy of a portion of the original tensor. The first chunk is returned as a view when possible (zero-copy), while subsequent chunks may require data copying for non-zero base offsets.
§Arguments
- split_size - Size of each chunk along the specified dimension (must be > 0)
- dim - Dimension along which to split the tensor (must be < tensor rank)
§Returns
A vector of tensors, each representing a chunk of the original tensor. The number of chunks depends on the dimension size and split size.
§Panics
- If tensor rank is 0 (scalar tensors cannot be split)
- If dim is out of bounds for the tensor rank
- If split_size is 0
§Examples
use train_station::Tensor;
// Split 2D tensor into equal chunks along dimension 1
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let parts = tensor.split(1, 1);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2, 1]);
assert_eq!(parts[1].shape().dims(), vec![2, 1]);
assert_eq!(parts[2].shape().dims(), vec![2, 1]);
assert_eq!(parts[0].get(&[0, 0]), 1.0);
assert_eq!(parts[1].get(&[0, 0]), 2.0);
assert_eq!(parts[2].get(&[1, 0]), 6.0);
use train_station::Tensor;
// Split with uneven division (last chunk smaller)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
let parts = tensor.split(2, 1);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![1, 2]);
assert_eq!(parts[1].shape().dims(), vec![1, 2]);
assert_eq!(parts[2].shape().dims(), vec![1, 1]); // Last chunk smaller
use train_station::Tensor;
// Split with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let parts = tensor.split(1, 1);
assert_eq!(parts.len(), 2);
assert!(parts[0].requires_grad());
assert!(parts[1].requires_grad());
use train_station::Tensor;
// Split 1D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
let parts = tensor.split(2, 0);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2]);
assert_eq!(parts[1].shape().dims(), vec![2]);
assert_eq!(parts[2].shape().dims(), vec![2]);
§Performance
- First Chunk: O(1) - Returns a view when possible (zero-copy)
- Subsequent Chunks: O(n) - May require data copying for non-zero offsets
- Memory Usage: Minimal allocation for view operations, copying for non-zero offsets
- Gradient Tracking: Each chunk preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is related to other tensor transformations:
- split_with_sizes() - More general version with explicit chunk sizes
- cat() - Inverse operation that concatenates tensors back together
- chunk() - Alternative splitting operation with different semantics
§Memory Layout
The first chunk maintains the same underlying data as a view when the base offset is zero. Subsequent chunks may require data copying to handle non-zero base offsets, ensuring proper memory layout.
Sourcepub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor>
pub fn split_with_sizes(&self, split_sizes: &[usize], dim: usize) -> Vec<Tensor>
Split tensor into chunks with explicit sizes along specified dimension
Divides the tensor into multiple smaller tensors along the specified
dimension according to the provided size specifications. Each chunk
has the exact size specified in the split_sizes array, and the sum
of all sizes must equal the size of the specified dimension.
This operation provides precise control over the size of each resulting
chunk, unlike split() which creates equal-sized chunks. The first
chunk is returned as a view when possible (zero-copy), while subsequent
chunks may require data copying for non-zero base offsets.
§Arguments
- split_sizes - Array specifying the size of each chunk along the dimension
- dim - Dimension along which to split the tensor (must be < tensor rank)
§Returns
A vector of tensors, each representing a chunk of the original tensor
with the specified size. The number of chunks equals the length of split_sizes.
§Panics
- If tensor rank is 0 (scalar tensors cannot be split)
- If dim is out of bounds for the tensor rank
- If sum of split_sizes does not equal the size of the specified dimension
§Examples
use train_station::Tensor;
// Split with explicit sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![1, 5]).unwrap();
let parts = tensor.split_with_sizes(&[2, 3], 1);
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].shape().dims(), vec![1, 2]);
assert_eq!(parts[1].shape().dims(), vec![1, 3]);
assert_eq!(parts[0].get(&[0, 0]), 1.0);
assert_eq!(parts[0].get(&[0, 1]), 2.0);
assert_eq!(parts[1].get(&[0, 0]), 3.0);
use train_station::Tensor;
// Split 2D tensor with different chunk sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let parts = tensor.split_with_sizes(&[1, 2], 1);
assert_eq!(parts.len(), 2);
assert_eq!(parts[0].shape().dims(), vec![2, 1]);
assert_eq!(parts[1].shape().dims(), vec![2, 2]);
use train_station::Tensor;
// Split with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let parts = tensor.split_with_sizes(&[1, 1], 1);
assert_eq!(parts.len(), 2);
assert!(parts[0].requires_grad());
assert!(parts[1].requires_grad());
use train_station::Tensor;
// Split 1D tensor with explicit sizes
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]).unwrap();
let parts = tensor.split_with_sizes(&[2, 2, 2], 0);
assert_eq!(parts.len(), 3);
assert_eq!(parts[0].shape().dims(), vec![2]);
assert_eq!(parts[1].shape().dims(), vec![2]);
assert_eq!(parts[2].shape().dims(), vec![2]);
§Performance
- First Chunk: O(1) - Returns a view when possible (zero-copy)
- Subsequent Chunks: O(n) - May require data copying for non-zero offsets
- Memory Usage: Minimal allocation for view operations, copying for non-zero offsets
- Gradient Tracking: Each chunk preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is related to other tensor transformations:
- split() - Simplified version with equal-sized chunks
- cat() - Inverse operation that concatenates tensors back together
- chunk() - Alternative splitting operation with different semantics
§Memory Layout
The first chunk maintains the same underlying data as a view when the base offset is zero. Subsequent chunks may require data copying to handle non-zero base offsets, ensuring proper memory layout. Zero-sized chunks are handled by creating empty tensors with appropriate shapes.
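The zero-sized-chunk behavior described above can be exercised directly. The following is a small sketch (not taken from the crate's doc tests) that assumes a zero entry is accepted as long as the sizes still sum to the dimension size:
use train_station::Tensor;
// Split a [1, 3] tensor into a zero-sized chunk and a full chunk (0 + 3 = 3)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
let parts = tensor.split_with_sizes(&[0, 3], 1);
assert_eq!(parts.len(), 2);
// The first chunk is an empty tensor (assumed behavior per the note above)
assert_eq!(parts[0].size(), 0);
// The second chunk carries all of the data
assert_eq!(parts[1].shape().dims(), vec![1, 3]);
assert_eq!(parts[1].get(&[0, 2]), 3.0);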
Source§impl Tensor
impl Tensor
Sourcepub fn squeeze(&self, dim: Option<usize>) -> Tensor
pub fn squeeze(&self, dim: Option<usize>) -> Tensor
Remove dimensions of size 1 from the tensor
Removes singleton dimensions (dimensions with size 1) from the tensor, reducing its rank while preserving the total number of elements. This operation is useful for cleaning up tensor shapes and preparing data for operations that expect specific dimensionality.
The squeeze operation can remove either all size-1 dimensions or a
specific dimension if it has size 1. When all dimensions are size 1,
the result is a scalar tensor with shape [1] rather than an empty
tensor to maintain mathematical consistency.
§Arguments
- dim - Optional specific dimension to squeeze. If None, all size-1 dimensions are removed. If Some(d), only dimension d is removed if it has size 1.
§Returns
A new tensor with size-1 dimensions removed. The total number of elements remains unchanged.
§Panics
- If dim is specified but out of bounds for the tensor rank
- If dim is specified but the dimension does not have size 1
§Examples
use train_station::Tensor;
// Squeeze all size-1 dimensions
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[1]), 2.0);
assert_eq!(squeezed.get(&[2]), 3.0);
use train_station::Tensor;
// Squeeze specific dimension
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
let squeezed = tensor.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3, 1]);
assert_eq!(squeezed.get(&[0, 0]), 1.0);
assert_eq!(squeezed.get(&[1, 0]), 2.0);
assert_eq!(squeezed.get(&[2, 0]), 3.0);
use train_station::Tensor;
// Squeeze preserves data integrity
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::from_slice(&data, vec![1, 2, 1, 2]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![2, 2]);
assert_eq!(squeezed.size(), 4);
assert_eq!(squeezed.get(&[0, 0]), data[0]);
assert_eq!(squeezed.get(&[0, 1]), data[1]);
assert_eq!(squeezed.get(&[1, 0]), data[2]);
assert_eq!(squeezed.get(&[1, 1]), data[3]);
use train_station::Tensor;
// Handle edge case: all dimensions are size 1
let tensor = Tensor::from_slice(&[5.0], vec![1, 1, 1]).unwrap();
let squeezed = tensor.squeeze(None);
assert_eq!(squeezed.shape().dims(), vec![1]); // Not empty!
assert_eq!(squeezed.get(&[0]), 5.0);
use train_station::Tensor;
// Squeeze with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3, 1]).unwrap();
tensor.set_requires_grad(true);
let squeezed = tensor.squeeze(None);
assert!(squeezed.requires_grad());
assert_eq!(squeezed.shape().dims(), vec![3]);
use train_station::Tensor;
// Squeeze and unsqueeze roundtrip
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
let squeezed = unsqueezed.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[2]), 3.0);
§Performance
- Time Complexity: O(1) - Returns a view through reshape operation
- Memory Usage: No additional memory allocation (view operation)
- Gradient Tracking: Preserves gradient requirements and tracking
- Shape Transformation: Reduces tensor rank by removing singleton dimensions
§Relationship to Other Operations
This operation is related to other tensor transformations:
- unsqueeze() - Inverse operation that adds size-1 dimensions
- reshape() - More general shape transformation operation
- flatten() - Reduces tensor to 1D by combining all dimensions
§Memory Layout
The squeezed tensor maintains the same underlying data as the original tensor through the reshape operation. This ensures zero-copy behavior when the tensor is contiguous, with only the shape metadata being modified to reflect the reduced dimensionality.
§Edge Cases
- All size-1 dimensions: Returns a tensor with shape [1] rather than an empty tensor to maintain mathematical consistency
- No size-1 dimensions: Returns a tensor with the same shape as the input
- Mixed dimensions: Only removes dimensions with size 1, preserving others
Source§impl Tensor
impl Tensor
Sourcepub fn stack(tensors: &[Tensor], dim: usize) -> Tensor
pub fn stack(tensors: &[Tensor], dim: usize) -> Tensor
Stack a list of tensors along a new dimension
Combines multiple tensors by adding a new dimension at the specified
position. All input tensors must have identical shapes, and the output
tensor will have a new dimension of size equal to the number of input
tensors. This operation is similar to PyTorch’s torch.stack function.
The stacking operation creates a new axis in the output tensor, unlike concatenation which operates along existing dimensions. This makes stacking useful for creating batch dimensions, combining feature maps, and implementing operations that require adding new tensor axes.
§Arguments
- tensors - Array of tensors to stack. All tensors must have identical shapes.
- dim - Index of the new axis in the output shape (0 <= dim <= rank)
§Returns
A new tensor with the stacked data. The output shape is the input shape
with a new dimension of size tensors.len() inserted at position dim.
§Panics
- If the tensor array is empty
- If any tensor has a different shape than the first tensor
- If dim is out of bounds (dim > rank of input tensors)
§Examples
use train_station::Tensor;
// Stack two 1D tensors along dimension 0
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], vec![3]).unwrap();
let stacked = Tensor::stack(&[a, b], 0);
assert_eq!(stacked.shape().dims(), vec![2, 3]);
assert_eq!(stacked.get(&[0, 0]), 1.0);
assert_eq!(stacked.get(&[1, 2]), 6.0);
use train_station::Tensor;
// Stack multiple 2D tensors along dimension 1
let a = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let b = Tensor::from_slice(&[5.0, 6.0, 7.0, 8.0], vec![2, 2]).unwrap();
let c = Tensor::from_slice(&[9.0, 10.0, 11.0, 12.0], vec![2, 2]).unwrap();
let stacked = Tensor::stack(&[a, b, c], 1);
assert_eq!(stacked.shape().dims(), vec![2, 3, 2]);
assert_eq!(stacked.get(&[0, 0, 0]), 1.0);
assert_eq!(stacked.get(&[1, 2, 1]), 12.0);
use train_station::Tensor;
// Stack with gradient tracking
let mut a = Tensor::from_slice(&[1.0, 2.0], vec![2]).unwrap();
let mut b = Tensor::from_slice(&[3.0, 4.0], vec![2]).unwrap();
a.set_requires_grad(true);
b.set_requires_grad(true);
let stacked = Tensor::stack(&[a, b], 0);
assert!(stacked.requires_grad());
assert_eq!(stacked.shape().dims(), vec![2, 2]);
use train_station::Tensor;
// Stack 3D tensors along the last dimension
let data1: Vec<f32> = (0..8).map(|i| i as f32).collect();
let data2: Vec<f32> = (8..16).map(|i| i as f32).collect();
let a = Tensor::from_slice(&data1, vec![2, 2, 2]).unwrap();
let b = Tensor::from_slice(&data2, vec![2, 2, 2]).unwrap();
let stacked = Tensor::stack(&[a, b], 3);
assert_eq!(stacked.shape().dims(), vec![2, 2, 2, 2]);
assert_eq!(stacked.get(&[0, 0, 0, 0]), 0.0);
assert_eq!(stacked.get(&[1, 1, 1, 1]), 15.0);
§Performance
- Time Complexity: O(n) where n is the total number of elements
- Memory Usage: Allocates new contiguous tensor for output
- SIMD Optimization: Uses AVX2 acceleration for large block copies
- Block-wise Copying: Optimized copying strategy for better cache performance
- Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is related to other tensor transformations:
- cat() - Concatenates tensors along existing dimensions
- unsqueeze() - Adds a single dimension of size 1
- reshape() - Changes tensor shape without adding dimensions
§Memory Layout
The output tensor is always contiguous, with elements arranged so that the stacked dimension is the fastest-changing index. This ensures optimal performance for subsequent operations and maintains compatibility with SIMD optimizations.
§Gradient Computation
During backward passes, gradients are split along the stacked dimension and distributed back to the original input tensors. This is implemented using the same gradient function as concatenation, treating the stack operation as concatenation along a new axis.
Source§impl Tensor
impl Tensor
Sourcepub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor
pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor
Transpose two dimensions of the tensor
Swaps two specified dimensions of the tensor, modifying the shape and memory access pattern. When possible, this operation returns a zero-copy view using stride manipulation. For complex cases or non-contiguous tensors, data is copied to ensure correct transposition.
The transpose operation is its own inverse - applying transpose twice with the same dimensions returns the original tensor.
§Arguments
- dim0 - First dimension to swap (must be < tensor rank)
- dim1 - Second dimension to swap (must be < tensor rank)
§Returns
A new tensor with the specified dimensions transposed. The total number of elements remains unchanged.
§Panics
- If dim0 is out of bounds for the tensor rank
- If dim1 is out of bounds for the tensor rank
§Examples
use train_station::Tensor;
// Basic 2D transpose
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let transposed = tensor.transpose(0, 1);
assert_eq!(transposed.shape().dims(), vec![3, 2]);
assert_eq!(transposed.get(&[0, 0]), 1.0);
assert_eq!(transposed.get(&[0, 1]), 4.0);
assert_eq!(transposed.get(&[1, 0]), 2.0);
assert_eq!(transposed.get(&[1, 1]), 5.0);
assert_eq!(transposed.get(&[2, 0]), 3.0);
assert_eq!(transposed.get(&[2, 1]), 6.0);
use train_station::Tensor;
// 3D tensor transpose
let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 3, 4]).unwrap();
let transposed = tensor.transpose(0, 1);
assert_eq!(transposed.shape().dims(), vec![3, 2, 4]);
use train_station::Tensor;
// Transpose with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
tensor.set_requires_grad(true);
let transposed = tensor.transpose(0, 1);
assert!(transposed.requires_grad());
assert_eq!(transposed.shape().dims(), vec![2, 2]);
use train_station::Tensor;
// Transpose same dimension (no change)
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let result = tensor.transpose(1, 1);
assert_eq!(result.shape().dims(), tensor.shape().dims());
assert_eq!(result.get(&[0, 0]), tensor.get(&[0, 0]));
use train_station::Tensor;
// Transpose is its own inverse
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = tensor.transpose(0, 1);
let double_transposed = transposed.transpose(0, 1);
assert_eq!(double_transposed.shape().dims(), tensor.shape().dims());
assert_eq!(double_transposed.get(&[0, 0]), tensor.get(&[0, 0]));
§Performance
- Contiguous tensors: O(1) time complexity, returns a view
- Non-contiguous tensors: O(n) time complexity with data copying
- Memory usage: No additional allocation for view operations
- Gradient tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is related to other tensor transformations:
- t() - Convenience method for matrix transpose (last two dimensions)
- permute() - More general dimension reordering operation
- reshape() - Changes shape without changing dimension order
§Memory Layout
For contiguous tensors, transpose returns a view with modified strides, making the tensor non-contiguous. For non-contiguous tensors or complex cases, data is copied to ensure correct transposition.
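The view behavior described above can be checked directly. This is a small sketch (not part of the crate's doc tests), assuming the non-contiguity of the returned view matches the description above:
use train_station::Tensor;
// A contiguous 2x3 tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
assert!(tensor.is_contiguous());
// Transpose returns a strided view, so the result is no longer contiguous
let transposed = tensor.transpose(0, 1);
assert!(!transposed.is_contiguous());
// contiguous() materializes the data back into row-major storage
let materialized = transposed.contiguous();
assert!(materialized.is_contiguous());
assert_eq!(materialized.get(&[0, 1]), 4.0);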
Examples found in repository?
72 pub fn forward(
73 &self,
74 query: &Tensor,
75 key: &Tensor,
76 value: &Tensor,
77 attn_mask: Option<&Tensor>,
78 ) -> Tensor {
79 let qkv = Self::project_qkv(query, key, value, &self.q_proj, &self.k_proj, &self.v_proj);
80 let (q, k, v) = qkv;
81
82 // Split heads: [b, t, e] -> [b, h, t, d]
83 let (b, tq, _e) = Self::triple(query);
84 let (_b2, tk, _e2) = Self::triple(key);
85 let q = Self::split_heads(&q, b, tq, self.num_heads, self.head_dim);
86 let k = Self::split_heads(&k, b, tk, self.num_heads, self.head_dim);
87 let v = Self::split_heads(&v, b, tk, self.num_heads, self.head_dim);
88
89 // Scaled dot-product attention
90 // logits: [b, h, tq, tk]
91 let k_t = k.transpose(2, 3);
92 let mut logits = q.matmul(&k_t).div_scalar((self.head_dim as f32).sqrt());
93 if let Some(mask) = attn_mask {
94 let dims = mask.shape().dims().to_vec();
95 // If boolean-like mask matching [b,h,tq,tk], apply masked_fill
96 if dims.len() == 4 && dims[0] == b && dims[1] == self.num_heads && dims[2] == tq {
97 // Interpret mask > 0.5 as keep; we invert to build masked positions
98 let cond: Vec<bool> = mask.data().iter().map(|&v| v < 0.5).collect();
99 // Apply masked fill on a flattened view, then reshape back
100 let flat_logits = logits.view(vec![(b * self.num_heads * tq * tk) as i32]);
101 let filled = flat_logits.masked_fill(&cond, f32::NEG_INFINITY);
102 logits = filled.view(vec![b as i32, self.num_heads as i32, tq as i32, tk as i32]);
103 } else {
104 // Fallback: additive mask
105 logits = logits.add_tensor(mask);
106 }
107 }
108 let attn = logits.softmax(3);
109
110 // context: [b, h, tq, d]
111 let context = attn.matmul(&v);
112 let context = context.permute(vec![0, 2, 1, 3]); // [b, tq, h, d]
113 let context = context.contiguous().view(vec![
114 b as i32,
115 tq as i32,
116 (self.num_heads * self.head_dim) as i32,
117 ]);
118
119 // Output projection (flatten to 2D, project, then restore 3D)
120 let flat = context.view(vec![(b * tq) as i32, self.embed_dim as i32]);
121 let out2d = self.out_proj.forward(&flat);
122 out2d.view(vec![b as i32, tq as i32, self.embed_dim as i32])
123 }
Sourcepub fn t(&self) -> Tensor
pub fn t(&self) -> Tensor
Matrix transpose (transpose last two dimensions)
Convenience method for the common case of matrix transposition. For 2D tensors, this performs a standard matrix transpose. For higher-dimensional tensors, this transposes the last two dimensions, treating the tensor as a batch of matrices.
This method is equivalent to transpose(rank-2, rank-1) where
rank is the number of dimensions in the tensor.
§Returns
A new tensor with the last two dimensions transposed
§Panics
- If the tensor has less than 2 dimensions
§Examples
use train_station::Tensor;
// 2D matrix transpose
let matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let transposed = matrix.t();
assert_eq!(transposed.shape().dims(), vec![2, 2]);
assert_eq!(transposed.get(&[0, 0]), 1.0);
assert_eq!(transposed.get(&[0, 1]), 3.0);
assert_eq!(transposed.get(&[1, 0]), 2.0);
assert_eq!(transposed.get(&[1, 1]), 4.0);
use train_station::Tensor;
// 3D tensor (batch of matrices)
let data: Vec<f32> = (0..12).map(|i| i as f32).collect();
let tensor = Tensor::from_slice(&data, vec![2, 2, 3]).unwrap();
let transposed = tensor.t();
assert_eq!(transposed.shape().dims(), vec![2, 3, 2]);
use train_station::Tensor;
// Matrix transpose with gradient tracking
let mut matrix = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
matrix.set_requires_grad(true);
let transposed = matrix.t();
assert!(transposed.requires_grad());
assert_eq!(transposed.shape().dims(), vec![2, 2]);
§Performance
- Time Complexity: Same as transpose() - O(1) for views, O(n) for copies
- Memory Usage: Same as transpose() - no allocation for views
- Gradient Tracking: Preserves gradient requirements and tracking
§Relationship to Other Operations
This operation is equivalent to:
use train_station::Tensor;
let tensor = Tensor::new(vec![2, 3, 4]);
let rank = tensor.shape().rank();
let transposed1 = tensor.t();
let transposed2 = tensor.transpose(rank - 2, rank - 1);
// transposed1 and transposed2 are identical
Source§impl Tensor
impl Tensor
Sourcepub fn unsqueeze(&self, dim: usize) -> Tensor
pub fn unsqueeze(&self, dim: usize) -> Tensor
Add a dimension of size 1 at the specified position
Inserts a new dimension of size 1 at the specified position in the tensor’s shape, increasing the rank by 1 while preserving the total number of elements. This operation is useful for preparing tensors for broadcasting, creating batch dimensions, and adapting tensor shapes for specific neural network operations.
The unsqueeze operation is the inverse of squeeze() - unsqueezing
a dimension and then squeezing it at the same position returns the
original tensor.
§Arguments
- dim - Position to insert the new dimension (0 <= dim <= rank)
§Returns
A new tensor with an additional dimension of size 1 at the specified position. The total number of elements remains unchanged.
§Panics
- If dim is out of bounds (dim > rank of the tensor)
§Examples
use train_station::Tensor;
// Add dimension at the beginning
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[0, 1]), 2.0);
assert_eq!(unsqueezed.get(&[0, 2]), 3.0);
use train_station::Tensor;
// Add dimension at the end
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(1);
assert_eq!(unsqueezed.shape().dims(), vec![3, 1]);
assert_eq!(unsqueezed.get(&[0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[1, 0]), 2.0);
assert_eq!(unsqueezed.get(&[2, 0]), 3.0);
use train_station::Tensor;
// Add dimension in the middle of 2D tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let unsqueezed = tensor.unsqueeze(1);
assert_eq!(unsqueezed.shape().dims(), vec![2, 1, 2]);
assert_eq!(unsqueezed.get(&[0, 0, 0]), 1.0);
assert_eq!(unsqueezed.get(&[0, 0, 1]), 2.0);
assert_eq!(unsqueezed.get(&[1, 0, 0]), 3.0);
assert_eq!(unsqueezed.get(&[1, 0, 1]), 4.0);
use train_station::Tensor;
// Unsqueeze preserves data integrity
let data = vec![1.0, 2.0, 3.0, 4.0];
let tensor = Tensor::from_slice(&data, vec![4]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 4]);
assert_eq!(unsqueezed.size(), 4);
for (i, &d) in data.iter().enumerate() {
assert_eq!(unsqueezed.get(&[0, i]), d);
}
use train_station::Tensor;
// Unsqueeze with gradient tracking
let mut tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
tensor.set_requires_grad(true);
let unsqueezed = tensor.unsqueeze(0);
assert!(unsqueezed.requires_grad());
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
use train_station::Tensor;
// Unsqueeze and squeeze roundtrip
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let unsqueezed = tensor.unsqueeze(0);
assert_eq!(unsqueezed.shape().dims(), vec![1, 3]);
let squeezed = unsqueezed.squeeze(Some(0));
assert_eq!(squeezed.shape().dims(), vec![3]);
assert_eq!(squeezed.get(&[0]), 1.0);
assert_eq!(squeezed.get(&[2]), 3.0);
use train_station::Tensor;
// Multiple unsqueeze operations
let tensor = Tensor::from_slice(&[42.0], vec![1]).unwrap();
let unsqueezed1 = tensor.unsqueeze(0);
assert_eq!(unsqueezed1.shape().dims(), vec![1, 1]);
let unsqueezed2 = unsqueezed1.unsqueeze(0);
assert_eq!(unsqueezed2.shape().dims(), vec![1, 1, 1]);
assert_eq!(unsqueezed2.get(&[0, 0, 0]), 42.0);
§Performance
- Time Complexity: O(1) - Returns a view through reshape operation
- Memory Usage: No additional memory allocation (view operation)
- Gradient Tracking: Preserves gradient requirements and tracking
- Shape Transformation: Increases tensor rank by adding singleton dimensions
§Relationship to Other Operations
This operation is related to other tensor transformations:
- squeeze() - Inverse operation that removes size-1 dimensions
- reshape() - More general shape transformation operation
- expand() - Broadcasts dimensions to larger sizes
§Memory Layout
The unsqueezed tensor maintains the same underlying data as the original tensor through the reshape operation. This ensures zero-copy behavior when the tensor is contiguous, with only the shape metadata being modified to reflect the increased dimensionality.
§Broadcasting Applications
Unsqueeze is commonly used for broadcasting operations:
use train_station::Tensor;
// Prepare for broadcasting: [3] -> [1, 3] for row-wise operations
let vector = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let row_vector = vector.unsqueeze(0); // Shape: [1, 3]
// Prepare for broadcasting: [3] -> [3, 1] for column-wise operations
let column_vector = vector.unsqueeze(1); // Shape: [3, 1]
§Neural Network Applications
Unsqueeze is essential for neural network operations:
use train_station::Tensor;
// Single sample -> batch dimension for neural network input
let sample = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let batch = sample.unsqueeze(0); // Shape: [1, 3] for batch processing
// Add channel dimension for convolutional operations
let feature_map = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
let with_channels = feature_map.unsqueeze(0); // Shape: [1, 2, 2] for conv layers
Trait Implementations§
Source§impl Add<Tensor> for f32
Scalar-tensor addition operator implementations
impl Add<Tensor> for f32
Scalar-tensor addition operator implementations
Provides addition operations between scalars and tensors.
All implementations delegate to the underlying add_scalar method.
Source§impl Add<f32> for Tensor
Tensor-scalar addition operator implementations
impl Add<f32> for Tensor
Tensor-scalar addition operator implementations
Provides addition operations between tensors and scalars.
All implementations delegate to the underlying add_scalar method.
Source§impl Add for Tensor
Tensor addition operator implementations
impl Add for Tensor
Tensor addition operator implementations
Provides addition operations between tensors with various reference combinations.
All implementations delegate to the underlying add_tensor method for optimal performance.
Source§impl AddAssign<&Tensor> for Tensor
impl AddAssign<&Tensor> for Tensor
Source§fn add_assign(&mut self, other: &Tensor)
fn add_assign(&mut self, other: &Tensor)
Adds another tensor reference to this tensor in-place
Source§impl AddAssign<f32> for Tensor
Tensor-scalar addition assignment operator implementations
impl AddAssign<f32> for Tensor
Tensor-scalar addition assignment operator implementations
Provides in-place addition operations between tensors and scalars.
Source§fn add_assign(&mut self, scalar: f32)
fn add_assign(&mut self, scalar: f32)
Adds a scalar to each element of this tensor in-place
Source§impl AddAssign for Tensor
Tensor addition assignment operator implementations
impl AddAssign for Tensor
Tensor addition assignment operator implementations
Provides in-place addition operations between tensors.
All implementations delegate to the underlying add_tensor method.
Source§fn add_assign(&mut self, other: Tensor)
fn add_assign(&mut self, other: Tensor)
Adds another tensor to this tensor in-place
Source§impl Clone for Tensor
Clone implementation for Tensor
impl Clone for Tensor
Clone implementation for Tensor
Creates a deep copy of the tensor data but resets gradtrack state (new tensor won’t track gradients unless explicitly set)
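A minimal sketch of this behavior (not taken from the crate's doc tests): the clone keeps the data but does not inherit gradient tracking.
use train_station::Tensor;
// Original tensor tracks gradients
let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap().with_requires_grad();
assert!(original.requires_grad());
// Clone copies the data but resets the gradtrack state
let copy = original.clone();
assert_eq!(copy.data(), original.data());
assert!(!copy.requires_grad());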
Source§impl Div<Tensor> for f32
Scalar-tensor division operator implementations
impl Div<Tensor> for f32
Scalar-tensor division operator implementations
Provides division operations between scalars and tensors.
Computes scalar / tensor by computing the reciprocal of the tensor and multiplying by the scalar.
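For illustration, a small sketch (not from the crate's doc tests) of the scalar / tensor form; powers of two are used so the reciprocal-based computation is exact:
use train_station::Tensor;
// Scalar divided by a tensor: each result element is scalar / element
let t = Tensor::from_slice(&[1.0, 2.0, 4.0], vec![3]).unwrap();
let result = 8.0 / t;
assert_eq!(result.data(), &[8.0, 4.0, 2.0]);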
Source§impl Div<f32> for Tensor
Tensor-scalar division operator implementations
impl Div<f32> for Tensor
Tensor-scalar division operator implementations
Provides division operations between tensors and scalars.
All implementations delegate to the underlying div_scalar method.
Source§impl Div for Tensor
Tensor division operator implementations
impl Div for Tensor
Tensor division operator implementations
Provides element-wise division operations between tensors with various reference combinations.
All implementations delegate to the underlying div_tensor method for optimal performance.
Source§impl DivAssign<&Tensor> for Tensor
impl DivAssign<&Tensor> for Tensor
Source§fn div_assign(&mut self, other: &Tensor)
fn div_assign(&mut self, other: &Tensor)
Divides this tensor by another tensor reference in-place
Source§impl DivAssign<f32> for Tensor
Tensor-scalar division assignment operator implementations
impl DivAssign<f32> for Tensor
Tensor-scalar division assignment operator implementations
Provides in-place division operations between tensors and scalars.
Source§fn div_assign(&mut self, scalar: f32)
fn div_assign(&mut self, scalar: f32)
Divides each element of this tensor by a scalar in-place
Source§impl DivAssign for Tensor
Tensor division assignment operator implementations
impl DivAssign for Tensor
Tensor division assignment operator implementations
Provides in-place division operations between tensors.
All implementations delegate to the underlying div_tensor method.
Source§fn div_assign(&mut self, other: Tensor)
fn div_assign(&mut self, other: Tensor)
Divides this tensor by another tensor in-place
Source§impl FromFieldValue for Tensor
impl FromFieldValue for Tensor
Source§fn from_field_value(
value: FieldValue,
field_name: &str,
) -> SerializationResult<Self>
fn from_field_value( value: FieldValue, field_name: &str, ) -> SerializationResult<Self>
Source§impl FromIterator<Tensor> for Tensor
impl FromIterator<Tensor> for Tensor
Source§fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self
fn from_iter<I: IntoIterator<Item = Tensor>>(iter: I) -> Self
Collect element view tensors back into a single tensor
This method reconstructs a tensor from an iterator of element view tensors. It includes optimizations for common patterns and maintains gradient tracking when appropriate.
The collection process automatically detects whether all elements are scalar
views (shape [1]) and uses optimized collection strategies accordingly.
Gradient tracking is preserved when any input element requires gradients.
§Performance
- Optimized Collection: Specialized paths for scalar and mixed views
- Memory Efficient: Direct memory copying without intermediate allocations
- Gradient Preservation: Maintains gradtrack functionality when enabled
- Shape Detection: Automatic detection of element shapes for optimization
§Implementation Details
The method performs the following steps:
- Element Collection: Gathers all element tensors from the iterator
- Shape Analysis: Determines if all elements are scalar views
- Optimized Path: Uses specialized collection for scalar views
- General Path: Handles mixed shapes by flattening into 1D tensor
- Gradient Setup: Preserves gradient tracking when appropriate
§Examples
§Basic Collection
use train_station::Tensor;
let original = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let doubled: Tensor = original.iter()
.map(|elem| elem.mul_scalar(2.0))
.collect();
assert_eq!(doubled.data(), &[2.0, 4.0, 6.0]);§Collection with Gradient Tracking
use train_station::Tensor;
let original = Tensor::from_slice(&[1.0, 2.0], vec![2])
.unwrap()
.with_requires_grad();
let result: Tensor = original.iter()
.map(|elem| elem.mul_scalar(2.0))
.collect();
assert!(result.requires_grad());
assert_eq!(result.data(), &[2.0, 4.0]);§Empty Iterator Handling
use train_station::Tensor;
let empty: Tensor = Vec::<Tensor>::new().into_iter().collect();
assert_eq!(empty.size(), 0);
assert_eq!(empty.shape().dims(), vec![0]);Source§impl FromIterator<f32> for Tensor
impl FromIterator<f32> for Tensor
Source§fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self
fn from_iter<I: IntoIterator<Item = f32>>(iter: I) -> Self
Collect f32 values into a 1D contiguous, SIMD-aligned Tensor
- Streams directly when iterator reports exact size_hint
- Falls back to temporary Vec and optimized_copy otherwise
- No gradient tracking is set on the result
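A small usage sketch (illustrative values, not from the crate's doc tests) of collecting plain f32 values into a tensor:
use train_station::Tensor;
// Collect an exact-size iterator directly into a 1D tensor
let tensor: Tensor = (0..4).map(|i| i as f32).collect();
assert_eq!(tensor.shape().dims(), vec![4]);
assert_eq!(tensor.data(), &[0.0, 1.0, 2.0, 3.0]);
// The collected tensor does not track gradients by default
assert!(!tensor.requires_grad());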
Source§impl<'a> IntoIterator for &'a Tensor
High-performance iterator over tensor elements as view tensors
impl<'a> IntoIterator for &'a Tensor
High-performance iterator over tensor elements as view tensors
Each element becomes a proper Tensor view of shape [1] that can use
all existing tensor operations and gradient tracking. Implements all
standard iterator traits for maximum compatibility with Rust’s ecosystem.
This iterator provides zero-copy access to tensor elements through view tensors, enabling efficient element-wise operations while maintaining full compatibility with Rust’s standard library iterator methods.
§Performance
- Zero-Copy Views: Each element is a view tensor sharing memory with source
- O(1) Element Access: Constant-time view creation for each element
- Memory Efficient: ~64 bytes overhead per element view
- SIMD Compatible: All tensor operations use existing optimizations
- Gradient Tracking: Full gradtrack support through element operations
§Implementation Details
The iterator creates lightweight view tensors on-demand, sharing the same memory allocation as the source tensor. This ensures zero-copy semantics while maintaining full tensor operation compatibility.
Each element view is created using Tensor::element_view(), which provides
a true view of the underlying data without any copying. The view tensors
support all standard tensor operations including gradient tracking.
§Standard Library Compatibility
This iterator implements all standard iterator traits:
- Iterator: Basic iteration with next() and size_hint()
- ExactSizeIterator: Precise size information with len()
- DoubleEndedIterator: Reverse iteration with next_back()
- FusedIterator: Fused iteration for better performance
- IntoIterator: Automatic conversion for for loops
§Examples
§Basic Iteration
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
// Basic iteration
for element in tensor.iter() {
println!("Element value: {}", element.value());
}
// Standard library methods
let sum: f32 = tensor.iter()
.map(|elem| elem.value())
.sum();
assert_eq!(sum, 6.0);
§Element Operations
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
// Tensor operations on elements
let transformed: Tensor = tensor.iter()
.map(|elem| elem.mul_scalar(2.0).add_scalar(1.0)) // 2x + 1
.collect();
assert_eq!(transformed.data(), &[3.0, 5.0, 7.0]);
§Advanced Iterator Methods
use train_station::Tensor;
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0], vec![5]).unwrap();
// Filter and transform
let result: Tensor = tensor.iter()
.filter(|elem| elem.value() > 2.0)
.map(|elem| elem.mul_scalar(10.0))
.collect();
assert_eq!(result.data(), &[30.0, 40.0, 50.0]);
// Reverse iteration
let reversed: Tensor = tensor.iter().rev().collect();
assert_eq!(reversed.data(), &[5.0, 4.0, 3.0, 2.0, 1.0]);
IntoIterator for &Tensor now iterates outermost dimension, yielding sub-tensors (views)
Source§impl IntoIterator for Tensor
IntoIterator for owned Tensor: iterate outermost dimension producing sub-tensors.
Enables .into_iter().flatten() patterns on owned tensors.
impl IntoIterator for Tensor
IntoIterator for owned Tensor: iterate outermost dimension producing sub-tensors.
Enables .into_iter().flatten() patterns on owned tensors.
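A sketch of this pattern (not from the crate's doc tests; the exact sub-tensor shapes are an assumption based on the outermost-dimension description above):
use train_station::Tensor;
// Iterate the outermost dimension of an owned [2, 3] tensor
let tensor = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
let rows: Vec<Tensor> = tensor.into_iter().collect();
// One sub-tensor per slot of the outermost dimension
assert_eq!(rows.len(), 2);
// Each sub-tensor is assumed to hold that row's three elements
assert_eq!(rows[0].size(), 3);
assert_eq!(rows[1].size(), 3);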
Source§impl Mul<Tensor> for f32
Scalar-tensor multiplication operator implementations
impl Mul<Tensor> for f32
Scalar-tensor multiplication operator implementations
Provides multiplication operations between scalars and tensors.
All implementations delegate to the underlying mul_scalar method.
Source§impl Mul<f32> for Tensor
Tensor-scalar multiplication operator implementations
impl Mul<f32> for Tensor
Tensor-scalar multiplication operator implementations
Provides multiplication operations between tensors and scalars.
All implementations delegate to the underlying mul_scalar method.
Source§impl Mul for Tensor
Tensor multiplication operator implementations
impl Mul for Tensor
Tensor multiplication operator implementations
Provides element-wise multiplication operations between tensors with various reference combinations.
All implementations delegate to the underlying mul_tensor method for optimal performance.
Source§impl MulAssign<&Tensor> for Tensor
impl MulAssign<&Tensor> for Tensor
Source§fn mul_assign(&mut self, other: &Tensor)
fn mul_assign(&mut self, other: &Tensor)
Multiplies this tensor by another tensor reference in-place
Source§impl MulAssign<f32> for Tensor
Tensor-scalar multiplication assignment operator implementations
impl MulAssign<f32> for Tensor
Tensor-scalar multiplication assignment operator implementations
Provides in-place multiplication operations between tensors and scalars.
Source§fn mul_assign(&mut self, scalar: f32)
fn mul_assign(&mut self, scalar: f32)
Multiplies each element of this tensor by a scalar in-place
Source§impl MulAssign for Tensor
Tensor multiplication assignment operator implementations
impl MulAssign for Tensor
Tensor multiplication assignment operator implementations
Provides in-place multiplication operations between tensors.
All implementations delegate to the underlying mul_tensor method.
Source§fn mul_assign(&mut self, other: Tensor)
fn mul_assign(&mut self, other: Tensor)
Multiplies this tensor by another tensor in-place
Source§impl Neg for Tensor
Tensor negation operator implementations
impl Neg for Tensor
Tensor negation operator implementations
Provides unary negation operations for tensors.
All implementations delegate to the underlying mul_scalar method with -1.0.
Source§impl Serializable for Tensor
impl Serializable for Tensor
Source§fn to_json(&self) -> SerializationResult<String>
fn to_json(&self) -> SerializationResult<String>
Serialize the tensor to JSON format
This method converts the tensor into a human-readable JSON string representation that includes all tensor data, shape information, device placement, and gradtrack state. The JSON format is suitable for debugging, configuration files, and cross-language interoperability.
§Returns
JSON string representation of the tensor on success, or SerializationError on failure
§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;
let mut tensor = Tensor::zeros(vec![2, 3]);
tensor.set(&[0, 0], 1.0);
tensor.set(&[1, 2], 5.0);
let json = tensor.to_json().unwrap();
assert!(!json.is_empty());
assert!(json.contains("data"));
assert!(json.contains("shape"));Source§fn from_json(json: &str) -> SerializationResult<Self>
fn from_json(json: &str) -> SerializationResult<Self>
Deserialize a tensor from JSON format
This method parses a JSON string and reconstructs a tensor with all its data, shape information, device placement, and gradtrack state. The JSON must contain all necessary fields in the expected format.
§Arguments
- json - JSON string containing serialized tensor data
§Returns
The deserialized tensor on success, or SerializationError on failure
§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;
let mut original = Tensor::ones(vec![2, 2]);
original.set(&[0, 1], 3.0);
original.set_requires_grad(true);
let json = original.to_json().unwrap();
let restored = Tensor::from_json(&json).unwrap();
assert_eq!(original.shape().dims(), restored.shape().dims());
assert_eq!(original.get(&[0, 1]), restored.get(&[0, 1]));
assert_eq!(original.requires_grad(), restored.requires_grad());
Source§fn to_binary(&self) -> SerializationResult<Vec<u8>>
fn to_binary(&self) -> SerializationResult<Vec<u8>>
Serialize the tensor to binary format
This method converts the tensor into a compact binary representation optimized for storage and transmission. The binary format provides maximum performance and minimal file sizes, making it ideal for large tensors and production use.
§Returns
Binary representation of the tensor on success, or SerializationError on failure
§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;
let mut tensor = Tensor::zeros(vec![100, 100]);
for i in 0..10 {
tensor.set(&[i, i], i as f32);
}
let binary = tensor.to_binary().unwrap();
assert!(!binary.is_empty());
// Binary format is more compact than JSON for large tensors
Source§fn from_binary(data: &[u8]) -> SerializationResult<Self>
fn from_binary(data: &[u8]) -> SerializationResult<Self>
Deserialize a tensor from binary format
This method parses binary data and reconstructs a tensor with all its data, shape information, device placement, and gradtrack state. The binary data must contain complete serialized information in the expected format.
§Arguments
- data - Binary data containing serialized tensor information
§Returns
The deserialized tensor on success, or SerializationError on failure
§Examples
use train_station::Tensor;
use train_station::serialization::Serializable;
let mut original = Tensor::ones(vec![3, 4]);
original.set(&[2, 3], 7.5);
original.set_requires_grad(true);
let binary = original.to_binary().unwrap();
let restored = Tensor::from_binary(&binary).unwrap();
assert_eq!(original.shape().dims(), restored.shape().dims());
assert_eq!(original.get(&[2, 3]), restored.get(&[2, 3]));
assert_eq!(original.requires_grad(), restored.requires_grad());
Source§fn save<P: AsRef<Path>>(
&self,
path: P,
format: Format,
) -> SerializationResult<()>
fn save<P: AsRef<Path>>( &self, path: P, format: Format, ) -> SerializationResult<()>
Source§fn save_to_writer<W: Write>(
&self,
writer: &mut W,
format: Format,
) -> SerializationResult<()>
fn save_to_writer<W: Write>( &self, writer: &mut W, format: Format, ) -> SerializationResult<()>
Source§fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self>
fn load<P: AsRef<Path>>(path: P, format: Format) -> SerializationResult<Self>
Source§fn load_from_reader<R: Read>(
reader: &mut R,
format: Format,
) -> SerializationResult<Self>
fn load_from_reader<R: Read>( reader: &mut R, format: Format, ) -> SerializationResult<Self>
Source§impl StructSerializable for Tensor
impl StructSerializable for Tensor
Source§fn to_serializer(&self) -> StructSerializer
fn to_serializer(&self) -> StructSerializer
Convert Tensor to StructSerializer for serialization
Serializes tensor data, shape, device, and gradtrack state. Runtime state (id, grad, grad_fn, allocation_owner) is not serialized.
§Returns
StructSerializer containing all persistent tensor state
Source§fn from_deserializer(
deserializer: &mut StructDeserializer,
) -> SerializationResult<Self>
fn from_deserializer( deserializer: &mut StructDeserializer, ) -> SerializationResult<Self>
Create Tensor from StructDeserializer
Reconstructs tensor from serialized data, shape, device, and gradtrack state. Allocates new memory and generates new tensor ID.
§Arguments
- deserializer - StructDeserializer containing tensor data
§Returns
Reconstructed Tensor instance or error if deserialization fails
Source§fn save_json<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>
fn save_json<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>
Source§fn save_binary<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>
fn save_binary<P: AsRef<Path>>(&self, path: P) -> SerializationResult<()>
Source§fn load_json<P: AsRef<Path>>(path: P) -> SerializationResult<Self>
fn load_json<P: AsRef<Path>>(path: P) -> SerializationResult<Self>
Source§fn load_binary<P: AsRef<Path>>(path: P) -> SerializationResult<Self>
fn load_binary<P: AsRef<Path>>(path: P) -> SerializationResult<Self>
Source§fn to_json(&self) -> SerializationResult<String>
fn to_json(&self) -> SerializationResult<String>
Source§fn to_binary(&self) -> SerializationResult<Vec<u8>>
fn to_binary(&self) -> SerializationResult<Vec<u8>>
Source§fn from_json(json: &str) -> SerializationResult<Self>
fn from_json(json: &str) -> SerializationResult<Self>
Source§fn from_binary(data: &[u8]) -> SerializationResult<Self>
fn from_binary(data: &[u8]) -> SerializationResult<Self>
Source§impl Sub<Tensor> for f32
Scalar-tensor subtraction operator implementations
impl Sub<Tensor> for f32
Scalar-tensor subtraction operator implementations
Provides subtraction operations between scalars and tensors.
Computes scalar - tensor by negating the tensor and adding the scalar.
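A small sketch (not from the crate's doc tests) of the scalar - tensor form:
use train_station::Tensor;
// Scalar minus tensor: each result element is scalar - element
let t = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![3]).unwrap();
let result = 10.0 - t;
assert_eq!(result.data(), &[9.0, 8.0, 7.0]);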
Source§impl Sub<f32> for Tensor
Tensor-scalar subtraction operator implementations
impl Sub<f32> for Tensor
Tensor-scalar subtraction operator implementations
Provides subtraction operations between tensors and scalars.
All implementations delegate to the underlying sub_scalar method.
Source§impl Sub for Tensor
Tensor subtraction operator implementations
impl Sub for Tensor
Tensor subtraction operator implementations
Provides subtraction operations between tensors with various reference combinations.
All implementations delegate to the underlying sub_tensor method for optimal performance.
Source§impl SubAssign<&Tensor> for Tensor
impl SubAssign<&Tensor> for Tensor
Source§fn sub_assign(&mut self, other: &Tensor)
fn sub_assign(&mut self, other: &Tensor)
Subtracts another tensor reference from this tensor in-place
Source§impl SubAssign<f32> for Tensor
Tensor-scalar subtraction assignment operator implementations
impl SubAssign<f32> for Tensor
Tensor-scalar subtraction assignment operator implementations
Provides in-place subtraction operations between tensors and scalars.
Source§fn sub_assign(&mut self, scalar: f32)
fn sub_assign(&mut self, scalar: f32)
Subtracts a scalar from each element of this tensor in-place
Source§impl SubAssign for Tensor
Tensor subtraction assignment operator implementations
impl SubAssign for Tensor
Tensor subtraction assignment operator implementations
Provides in-place subtraction operations between tensors.
All implementations delegate to the underlying sub_tensor method.
Source§fn sub_assign(&mut self, other: Tensor)
fn sub_assign(&mut self, other: Tensor)
Subtracts another tensor from this tensor in-place