basic_linear_layer/basic_linear_layer.rs

//! Basic Linear Layer Example
//!
//! This example demonstrates how to implement a basic linear layer in Train Station:
//! - Creating a linear layer struct with trainable parameters
//! - Forward pass implementation with matrix multiplication and bias addition
//! - Forward pass without gradient tracking using NoGradTrack
//! - Training loop with gradient computation and optimization
//! - Single inference and batch inference patterns
//! - Serialization for saving and loading layer parameters
//!
//! # Learning Objectives
//!
//! - Understand linear layer implementation using Tensor operations
//! - Learn gradient-aware and gradient-free forward passes
//! - Implement complete training workflows with optimization
//! - Explore serialization patterns for model persistence
//! - Compare single vs batch inference performance
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with neural network concepts
//! - Knowledge of gradient descent and backpropagation
//!
//! # Usage
//!
//! ```bash
//! cargo run --example basic_linear_layer
//! ```

use std::fs;
use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

/// A basic linear layer implementation
#[derive(Debug)]
pub struct LinearLayer {
    /// Weight matrix [input_size, output_size]
    pub weight: Tensor,
    /// Bias vector [output_size]
    pub bias: Tensor,
    /// Number of input features
    pub input_size: usize,
    /// Number of output features
    pub output_size: usize,
}

impl LinearLayer {
    /// Create a new linear layer with random initialization
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        // Fan-in scaling (the simplified Xavier/Glorot rule): scale weights by
        // sqrt(1/input_size) so output variance stays comparable to input variance
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    /// Forward pass: output = input @ weight + bias
    pub fn forward(&self, input: &Tensor) -> Tensor {
        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
        let output = input.matmul(&self.weight);
        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
        output.add_tensor(&self.bias)
    }

    /// Forward pass without gradients (for inference)
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        // RAII guard: gradient tracking is disabled for this scope and
        // restored when `_guard` is dropped
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    /// Get all parameters for optimization
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Save layer parameters to JSON
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Create directory if it doesn't exist
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

    /// Load layer parameters from JSON
    ///
    /// The caller-provided `input_size` and `output_size` are trusted to match
    /// the saved tensors; no shape validation is performed here.
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

    /// Get parameter count: weights (input_size * output_size) plus biases (output_size)
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate creating a linear layer
fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!("  Input size: {}", layer.input_size);
    println!("  Output size: {}", layer.output_size);
    println!("  Parameter count: {}", layer.parameter_count());
    println!("  Weight shape: {:?}", layer.weight.shape().dims());
    println!("  Bias shape: {:?}", layer.bias.shape().dims());
    println!("  Weight requires grad: {}", layer.weight.requires_grad());
    println!("  Bias requires grad: {}", layer.bias.requires_grad());
}

/// Demonstrate forward pass with gradient tracking
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());
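    // Sanity check (a sketch): recompute the single-sample output by hand,
    // assuming `data()` exposes elements in row-major order (the same layout
    // assumption the batch-consistency check later in this file relies on).
    let x = input.data();
    let w = layer.weight.data();
    let b = layer.bias.data();
    let mut manual = vec![0.0f32; layer.output_size];
    for (j, out) in manual.iter_mut().enumerate() {
        // Dot product of the input row with weight column j, plus bias j
        for i in 0..layer.input_size {
            *out += x[i] * w[i * layer.output_size + j];
        }
        *out += b[j];
    }
    println!("  Manual recomputation: {:?}", manual);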

    // Batch input
    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());
    println!("  Output requires grad: {}", batch_output.requires_grad());
}

/// Demonstrate forward pass without gradient tracking
fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Compare with the gradient-tracked version: values are identical,
    // only the gradient tracking differs
    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        "  Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!("  No grad requires grad: {}", output.requires_grad());
    println!(
        "  With grad requires grad: {}",
        output_with_grad.requires_grad()
    );
}

/// Demonstrate complete training loop
fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    // Create layer and training data
    let mut layer = LinearLayer::new(2, 1, Some(45));

    // Simple regression task: y = 2*x1 + 3*x2 + 1
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0, // x1=1, x2=1 -> y=6
            2.0, 1.0, // x1=2, x2=1 -> y=8
            1.0, 2.0, // x1=1, x2=2 -> y=9
            2.0, 2.0, // x1=2, x2=2 -> y=11
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!("  X shape: {:?}", x_data.shape().dims());
    println!("  Y shape: {:?}", y_true.shape().dims());
    println!("  Target function: y = 2*x1 + 3*x2 + 1");

    // Create optimizer (standard Adam defaults apart from the learning rate)
    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

    // Training loop
    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = layer.forward(&x_data);

        // Mean squared error loss: mean((y_pred - y_true)^2)
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass: populates gradients for weight and bias
        loss.backward(None);

        // Optimizer step, then clear gradients for the next iteration
        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        // Print progress
        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Evaluate final model
    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!("  Learned weights: {:?}", layer.weight.data());
    println!("  Learned bias: {:?}", layer.bias.data());
    println!("  Target weights: [2.0, 3.0]");
    println!("  Target bias: [1.0]");

    println!("  Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }
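    // A quick numeric check (a sketch): compare the learned parameters against
    // the known target function y = 2*x1 + 3*x2 + 1
    let w = layer.weight.data();
    let b = layer.bias.data();
    println!(
        "  Parameter errors: w1={:.3}, w2={:.3}, b={:.3}",
        (w[0] - 2.0).abs(),
        (w[1] - 3.0).abs(),
        (b[0] - 1.0).abs()
    );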

    // Training analysis
    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

/// Demonstrate single vs batch inference
fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    // Single inference
    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!("  Input shape: {:?}", single_input.shape().dims());
    println!("  Output shape: {:?}", single_output.shape().dims());
    println!("  Output: {:?}", single_output.data());

    // Batch inference
    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0, // Sample 1
            5.0, 6.0, 7.0, 8.0, // Sample 2
            9.0, 10.0, 11.0, 12.0, // Sample 3
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());

    // Verify batch consistency: the first row of the [3, 3] batch output
    // (first 3 elements in row-major order) should match the single-sample output
    let first_sample_data = &batch_output.data()[0..3];
    let single_sample_data = single_output.data();

    println!("Consistency check:");
    println!("  Single output: {:?}", single_sample_data);
    println!("  First batch sample: {:?}", first_sample_data);
    println!(
        "  Match: {}",
        single_sample_data
            .iter()
            .zip(first_sample_data.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );
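    // The module docs promise a performance comparison, so here is a rough
    // wall-clock sketch (not a rigorous benchmark: no warm-up, single run)
    let iterations = 1_000;

    let start = std::time::Instant::now();
    for _ in 0..iterations {
        let _ = layer.forward_no_grad(&single_input);
    }
    let single_time = start.elapsed();

    let start = std::time::Instant::now();
    for _ in 0..iterations {
        let _ = layer.forward_no_grad(&batch_input);
    }
    let batch_time = start.elapsed();

    println!("Timing over {} calls:", iterations);
    println!("  Single-sample calls: {:?} (1 sample each)", single_time);
    println!("  Batch calls:         {:?} (3 samples each)", batch_time);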
}

/// Demonstrate serialization and loading
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    // Create and train a simple layer
    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    // Simple training data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Train for a few epochs
    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!("  Weight: {:?}", original_layer.weight.data());
    println!("  Bias: {:?}", original_layer.bias.data());

    // Save layer
    original_layer.save_json("temp_linear_layer")?;

    // Load layer
    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!("  Weight: {:?}", loaded_layer.weight.data());
    println!("  Bias: {:?}", loaded_layer.bias.data());

    // Verify consistency
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!("  Original output: {:?}", original_output.data());
    println!("  Loaded output: {:?}", loaded_output.data());
    println!(
        "  Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

/// Clean up temporary files
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims(), vec![3, 2]);
        assert_eq!(layer.bias.shape().dims(), vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(!output.requires_grad());
    }
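    // A small added check (a sketch): forward and forward_no_grad should
    // produce identical values; only the gradient tracking differs
    #[test]
    fn test_no_grad_matches_grad_values() {
        let layer = LinearLayer::new(2, 2, Some(48));
        let input = Tensor::from_slice(&[0.5, -1.5], vec![1, 2]).unwrap();
        let with_grad = layer.forward(&input);
        let without_grad = layer.forward_no_grad(&input);
        assert_eq!(with_grad.data(), without_grad.data());
    }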

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }
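    // A sketch of composing two layers into a tiny MLP (this example defines
    // no activation function, so the stack stays linear); checks only that the
    // intermediate and final shapes line up
    #[test]
    fn test_two_layer_composition() {
        let layer1 = LinearLayer::new(4, 3, Some(49));
        let layer2 = LinearLayer::new(3, 2, Some(50));
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();

        let hidden = layer1.forward_no_grad(&input);
        let output = layer2.forward_no_grad(&hidden);

        assert_eq!(hidden.shape().dims(), vec![1, 3]);
        assert_eq!(output.shape().dims(), vec![1, 2]);
    }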

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2); // weights + bias
    }
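    // A numeric check (a sketch), assuming `data()` exposes elements in
    // row-major order: recompute input @ weight + bias with plain loops and
    // compare against the layer's forward pass
    #[test]
    fn test_forward_matches_manual_computation() {
        let layer = LinearLayer::new(3, 2, Some(51));
        let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let output = layer.forward_no_grad(&input);

        let x = input.data();
        let w = layer.weight.data();
        let b = layer.bias.data();
        for j in 0..layer.output_size {
            let mut expected = b[j];
            for i in 0..layer.input_size {
                expected += x[i] * w[i * layer.output_size + j];
            }
            assert!((output.data()[j] - expected).abs() < 1e-5);
        }
    }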

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        // Save and load
        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        // Verify shapes
        assert_eq!(original.weight.shape().dims(), loaded.weight.shape().dims());
        assert_eq!(original.bias.shape().dims(), loaded.bias.shape().dims());

        // Verify data
        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        // Cleanup
        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}