basic_linear_layer/basic_linear_layer.rs

//! Basic Linear Layer Example
//!
//! This example demonstrates how to implement a basic linear layer in Train Station:
//! - Creating a linear layer struct with trainable parameters
//! - Forward pass implementation with matrix multiplication and bias addition
//! - Forward pass without gradient tracking using NoGradTrack
//! - Training loop with gradient computation and optimization
//! - Single inference and batch inference patterns
//! - Serialization for saving and loading layer parameters
//!
//! # Learning Objectives
//!
//! - Understand linear layer implementation using Tensor operations
//! - Learn gradient-aware and gradient-free forward passes
//! - Implement complete training workflows with optimization
//! - Explore serialization patterns for model persistence
//! - Compare single and batch inference and verify their consistency
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with neural network concepts
//! - Knowledge of gradient descent and backpropagation
//!
//! # Usage
//!
//! ```bash
//! cargo run --example basic_linear_layer
//! ```
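//!
//! # Core Computation
//!
//! The layer computes `output = input @ weight + bias`, with shapes
//! `[batch_size, input_size] @ [input_size, output_size] + [output_size]
//! -> [batch_size, output_size]`. A minimal sketch of the call pattern
//! defined in this file (shapes and values are arbitrary):
//!
//! ```ignore
//! let layer = LinearLayer::new(3, 2, Some(42));
//! let x = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
//! let y = layer.forward(&x); // shape [1, 2]
//! ```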

use std::fs;
use train_station::{
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    NoGradTrack, Tensor,
};

/// A basic linear layer implementation
#[derive(Debug)]
pub struct LinearLayer {
    /// Weight matrix [input_size, output_size]
    pub weight: Tensor,
    /// Bias vector [output_size]
    pub bias: Tensor,
    /// Number of input features
    pub input_size: usize,
    /// Number of output features
    pub output_size: usize,
}

impl LinearLayer {
    /// Create a new linear layer with random initialization
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        // Xavier/Glorot initialization: scale by sqrt(1/input_size)
        let scale = (1.0 / input_size as f32).sqrt();

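        // Why sqrt(1/input_size): for y_j = sum_i x_i * w_ij with unit-variance
        // inputs, Var(y_j) is roughly input_size * Var(w). Drawing w from N(0, 1)
        // and scaling by sqrt(1/input_size) gives Var(w) = 1/input_size, so the
        // output variance stays close to the input variance at initialization.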
        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    /// Forward pass: output = input @ weight + bias
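    ///
    /// Expects a 2-D `input` of shape `[batch_size, input_size]` and returns
    /// a `[batch_size, output_size]` tensor; the bias vector broadcasts over
    /// the batch dimension.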
    pub fn forward(&self, input: &Tensor) -> Tensor {
        // Matrix multiplication: [batch_size, input_size] @ [input_size, output_size] = [batch_size, output_size]
        let output = input.matmul(&self.weight);
        // Add bias: [batch_size, output_size] + [output_size] = [batch_size, output_size]
        output.add_tensor(&self.bias)
    }

    /// Forward pass without gradients (for inference)
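    ///
    /// `NoGradTrack` is an RAII guard: gradient tracking stays disabled while
    /// `_guard` is alive, so the returned tensor does not require grad.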
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    /// Get all parameters for optimization
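    ///
    /// Returns the tensors in a fixed order (weight, then bias). A usage
    /// sketch mirroring the training loop below:
    ///
    /// ```ignore
    /// let mut params = layer.parameters();
    /// optimizer.step(&mut params);
    /// optimizer.zero_grad(&mut params);
    /// ```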
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

    /// Save layer parameters to JSON
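    ///
    /// Writes two sibling files, `{path}_weight.json` and `{path}_bias.json`,
    /// creating parent directories as needed.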
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Create directory if it doesn't exist
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

    /// Load layer parameters from JSON
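    ///
    /// Expects the `{path}_weight.json` and `{path}_bias.json` files produced
    /// by `save_json` and re-enables gradient tracking on the loaded tensors.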
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

    /// Get parameter count
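    ///
    /// Counts `input_size * output_size` weights plus `output_size` biases;
    /// for example, a 3 -> 2 layer has `3 * 2 + 2 = 8` parameters.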
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate creating a linear layer
fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!("  Input size: {}", layer.input_size);
    println!("  Output size: {}", layer.output_size);
    println!("  Parameter count: {}", layer.parameter_count());
    println!("  Weight shape: {:?}", layer.weight.shape().dims);
    println!("  Bias shape: {:?}", layer.bias.shape().dims);
    println!("  Weight requires grad: {}", layer.weight.requires_grad());
    println!("  Bias requires grad: {}", layer.bias.requires_grad());
}

/// Demonstrate forward pass with gradient tracking
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Batch input
    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!("  Input shape: {:?}", batch_input.shape().dims);
    println!("  Output shape: {:?}", batch_output.shape().dims);
    println!("  Output requires grad: {}", batch_output.requires_grad());
}

/// Demonstrate forward pass without gradient tracking
fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Compare with grad version
    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        "  Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!("  no-grad output requires grad: {}", output.requires_grad());
    println!(
        "  with-grad output requires grad: {}",
        output_with_grad.requires_grad()
    );
}

/// Demonstrate complete training loop
fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    // Create layer and training data
    let mut layer = LinearLayer::new(2, 1, Some(45));

    // Simple regression task: y = 2*x1 + 3*x2 + 1
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0, // x1=1, x2=1 -> y=6
            2.0, 1.0, // x1=2, x2=1 -> y=8
            1.0, 2.0, // x1=1, x2=2 -> y=9
            2.0, 2.0, // x1=2, x2=2 -> y=11
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!("  X shape: {:?}", x_data.shape().dims);
    println!("  Y shape: {:?}", y_true.shape().dims);
    println!("  Target function: y = 2*x1 + 3*x2 + 1");

    // Configure Adam: lr = 0.01; the remaining fields match the standard Adam defaults
    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

    // Training loop
    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = layer.forward(&x_data);

        // Compute loss: MSE = mean((y_pred - y_true)^2)
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step
        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        // Print progress
        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Evaluate final model
    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!("  Learned weights: {:?}", layer.weight.data());
    println!("  Learned bias: {:?}", layer.bias.data());
    println!("  Target weights: [2.0, 3.0]");
    println!("  Target bias: [1.0]");

    println!("  Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    // Training analysis
    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

/// Demonstrate single vs batch inference
fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    // Single inference
    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!("  Input shape: {:?}", single_input.shape().dims);
    println!("  Output shape: {:?}", single_output.shape().dims);
    println!("  Output: {:?}", single_output.data());

    // Batch inference
    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0, // Sample 1
            5.0, 6.0, 7.0, 8.0, // Sample 2
            9.0, 10.0, 11.0, 12.0, // Sample 3
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!("  Input shape: {:?}", batch_input.shape().dims);
    println!("  Output shape: {:?}", batch_output.shape().dims);

    // Verify batch consistency: the first sample should match single inference.
    // In the row-major output buffer, it occupies the first `output_size` (3) elements.
    let first_sample_data = &batch_output.data()[0..3];
    let single_sample_data = single_output.data();

    println!("Consistency check:");
    println!("  Single output: {:?}", single_sample_data);
    println!("  First batch sample: {:?}", first_sample_data);
    println!(
        "  Match: {}",
        single_sample_data
            .iter()
            .zip(first_sample_data.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );
}

/// Demonstrate serialization and loading
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    // Create and train a simple layer
    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    // Simple training data
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Train for a few epochs
    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = y_pred.sub_tensor(&y_true).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!("  Weight: {:?}", original_layer.weight.data());
    println!("  Bias: {:?}", original_layer.bias.data());

    // Save layer
    original_layer.save_json("temp_linear_layer")?;

    // Load layer
    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!("  Weight: {:?}", loaded_layer.weight.data());
    println!("  Bias: {:?}", loaded_layer.bias.data());

    // Verify consistency
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!("  Original output: {:?}", original_output.data());
    println!("  Loaded output: {:?}", loaded_output.data());
    println!(
        "  Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

/// Clean up temporary files
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims, vec![3, 2]);
        assert_eq!(layer.bias.shape().dims, vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(!output.requires_grad());
    }

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims, vec![2, 1]);
    }

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2); // weights + bias
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        // Save and load
        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        // Verify shapes
        assert_eq!(original.weight.shape().dims, loaded.weight.shape().dims);
        assert_eq!(original.bias.shape().dims, loaded.bias.shape().dims);

        // Verify data
        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        // Cleanup
        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}