dqn/neural_networks/basic_linear_layer.rs

//! Basic Linear Layer Example
//!
//! This example demonstrates how to implement a basic linear layer in Train Station:
//! - Creating a linear layer struct with trainable parameters
//! - Forward pass implementation with matrix multiplication and bias addition
//! - Forward pass without gradient tracking using NoGradTrack
//! - Training loop with gradient computation and optimization
//! - Single inference and batch inference patterns
//! - Serialization for saving and loading layer parameters
//!
//! # Learning Objectives
//!
//! - Understand linear layer implementation using Tensor operations
//! - Learn gradient-aware and gradient-free forward passes
//! - Implement complete training workflows with optimization
//! - Explore serialization patterns for model persistence
//! - Compare single vs batch inference performance
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with neural network concepts
//! - Knowledge of gradient descent and backpropagation
//!
//! # Usage
//!
//! ```bash
//! cargo run --example basic_linear_layer
//! ```
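//!
//! # The Math
//!
//! A linear layer computes `y = x @ W + b`, where `x` is `[batch_size, input_size]`,
//! `W` is `[input_size, output_size]`, and `b` is `[output_size]` (broadcast across
//! the batch), producing `y` of shape `[batch_size, output_size]`.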

use std::fs;
use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

/// A basic linear layer implementation
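///
/// # Example
///
/// A minimal usage sketch (mirrors the `forward` shape contract below):
///
/// ```ignore
/// let layer = LinearLayer::new(3, 2, Some(42));
/// let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
/// let y = layer.forward(&x); // shape [1, 2]
/// ```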
#[derive(Debug)]
pub struct LinearLayer {
    /// Weight matrix [input_size, output_size]
    pub weight: Tensor,
    /// Bias vector [output_size]
    pub bias: Tensor,
    /// Number of input features
    pub input_size: usize,
    /// Number of output features
    pub output_size: usize,
}

impl LinearLayer {
    /// Create a new linear layer with random initialization
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
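        // so that the output variance roughly matches the input variance at init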
        let scale = (1.0 / input_size as f32).sqrt();

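        // `with_requires_grad` marks these tensors for gradtrack so that
        // `backward()` accumulates gradients into them during training.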
        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    /// Forward pass: output = input @ weight + bias
    pub fn forward(&self, input: &Tensor) -> Tensor {
        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
        let output = input.matmul(&self.weight);
        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
        output.add_tensor(&self.bias)
    }

    /// Forward pass without gradients (for inference)
    #[allow(unused)]
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
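        // RAII guard: gradient tracking stays disabled until `_guard` drops
        // at the end of this function, so no graph is recorded.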
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    /// Get all parameters for optimization
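    ///
    /// Returns mutable references so the optimizer can update the weights in place.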
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Save layer parameters to JSON
    #[allow(unused)]
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Create directory if it doesn't exist
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

    /// Load layer parameters from JSON
    #[allow(unused)]
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

    /// Get parameter count
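    ///
    /// `input_size * output_size` weights plus `output_size` biases.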
    #[allow(unused)]
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate creating a linear layer
fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!("  Input size: {}", layer.input_size);
    println!("  Output size: {}", layer.output_size);
    println!("  Parameter count: {}", layer.parameter_count());
    println!("  Weight shape: {:?}", layer.weight.shape().dims());
    println!("  Bias shape: {:?}", layer.bias.shape().dims());
    println!("  Weight requires grad: {}", layer.weight.requires_grad());
    println!("  Bias requires grad: {}", layer.bias.requires_grad());
}

/// Demonstrate forward pass with gradient tracking
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Batch input
    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());
    println!("  Output requires grad: {}", batch_output.requires_grad());
}

/// Demonstrate forward pass without gradient tracking
fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Compare with the gradient-tracking version
    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        "  Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!("  No grad requires grad: {}", output.requires_grad());
    println!(
        "  With grad requires grad: {}",
        output_with_grad.requires_grad()
    );
}

/// Demonstrate a complete training loop
fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    // Create layer and training data
    let mut layer = LinearLayer::new(2, 1, Some(45));

    // Simple regression task: y = 2*x1 + 3*x2 + 1
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0, // x1=1, x2=1 -> y=6
            2.0, 1.0, // x1=2, x2=1 -> y=8
            1.0, 2.0, // x1=1, x2=2 -> y=9
            2.0, 2.0, // x1=2, x2=2 -> y=11
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!("  X shape: {:?}", x_data.shape().dims());
    println!("  Y shape: {:?}", y_true.shape().dims());
    println!("  Target function: y = 2*x1 + 3*x2 + 1");

    // Create optimizer
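    // beta1, beta2, and eps below are the standard Adam defaults from the
    // original paper; only the learning rate is tuned for this toy problem.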
    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

    // Training loop
    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = layer.forward(&x_data);

        // Compute loss: MSE
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
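        // `None` supplies no custom upstream gradient; by the usual
        // convention a scalar loss is seeded with 1.0.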
        loss.backward(None);

        // Optimizer step
        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        // Print progress
        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Evaluate final model
    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!("  Learned weights: {:?}", layer.weight.data());
    println!("  Learned bias: {:?}", layer.bias.data());
    println!("  Target weights: [2.0, 3.0]");
    println!("  Target bias: [1.0]");

    println!("  Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    // Training analysis
    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

/// Demonstrate single vs batch inference
fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    // Single inference
    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!("  Input shape: {:?}", single_input.shape().dims());
    println!("  Output shape: {:?}", single_output.shape().dims());
    println!("  Output: {:?}", single_output.data());

    // Batch inference
    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0, // Sample 1
            5.0, 6.0, 7.0, 8.0, // Sample 2
            9.0, 10.0, 11.0, 12.0, // Sample 3
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());

    // Verify batch consistency: the first batch row should match the single inference
    let first_sample_data = &batch_output.data()[0..3]; // Row 0 of the [3, 3] output
    let single_sample_data = single_output.data();

    println!("Consistency check:");
    println!("  Single output: {:?}", single_sample_data);
    println!("  First batch sample: {:?}", first_sample_data);
    println!(
        "  Match: {}",
        single_sample_data
            .iter()
            .zip(first_sample_data.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );
}

/// Demonstrate serialization and loading
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    // Create and train a simple layer
    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    // Simple training data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Train for a few epochs
    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!("  Weight: {:?}", original_layer.weight.data());
    println!("  Bias: {:?}", original_layer.bias.data());

    // Save layer
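    // save_json writes two files: <path>_weight.json and <path>_bias.json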
    original_layer.save_json("temp_linear_layer")?;

    // Load layer
    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!("  Weight: {:?}", loaded_layer.weight.data());
    println!("  Bias: {:?}", loaded_layer.bias.data());

    // Verify consistency
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!("  Original output: {:?}", original_output.data());
    println!("  Loaded output: {:?}", loaded_output.data());
    println!(
        "  Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

/// Clean up temporary files
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims(), vec![3, 2]);
        assert_eq!(layer.bias.shape().dims(), vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(!output.requires_grad());
    }

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2); // weights + bias
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        // Save and load
        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        // Verify shapes
        assert_eq!(original.weight.shape().dims(), loaded.weight.shape().dims());
        assert_eq!(original.bias.shape().dims(), loaded.bias.shape().dims());

        // Verify data
        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        // Cleanup
        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}