feedforward_network/feedforward_network.rs

//! Feed-Forward Network Example
//!
//! This example demonstrates how to build a configurable multi-layer feed-forward network:
//! - Configurable input/output sizes and number of hidden layers
//! - Using the basic linear layer as a building block
//! - ReLU activation function implementation
//! - Complete training workflow with gradient computation
//! - Comprehensive training loop with 100+ steps
//! - Proper gradient tracking and computation graph connectivity
//! - Zero gradient management between training steps
//!
//! # Learning Objectives
//!
//! - Understand multi-layer network architecture design
//! - Learn to compose linear layers into deeper networks
//! - Implement activation functions and their gradient properties
//! - Master training workflows with proper gradient management
//! - Explore configurable network architectures
//! - Understand gradient flow through multiple layers
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with neural network concepts
//! - Knowledge of the basic linear layer (see basic_linear_layer.rs)
//! - Understanding of activation functions and backpropagation
//!
//! # Usage
//!
//! ```bash
//! cargo run --example feedforward_network
//! ```

use std::fs;
use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

/// ReLU activation function
pub struct ReLU;

impl ReLU {
    /// Apply ReLU activation: max(0, x)
    pub fn forward(input: &Tensor) -> Tensor {
        input.relu()
    }

    /// Apply ReLU activation without gradients
    pub fn forward_no_grad(input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        Self::forward(input)
    }
}
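
// Note on ReLU gradients: the derivative is 1 for positive inputs and 0
// otherwise, so backpropagation passes gradients only through units that were
// active in the forward pass. For example, input [-2.0, 3.0] maps to
// [0.0, 3.0], and an upstream gradient of [1.0, 1.0] reaches the input as
// [0.0, 1.0].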

/// A basic linear layer implementation (reused from basic_linear_layer.rs)
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        // Scale initial weights by 1/sqrt(input_size) so activation magnitudes
        // stay comparable regardless of layer width (LeCun-style init).
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }
}

/// Configuration for feed-forward network
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    pub input_size: usize,
    pub hidden_sizes: Vec<usize>,
    pub output_size: usize,
    /// Currently informational: `LinearLayer` always includes a bias term.
    pub use_bias: bool,
}

impl Default for FeedForwardConfig {
    fn default() -> Self {
        Self {
            input_size: 4,
            hidden_sizes: vec![8, 4],
            output_size: 2,
            use_bias: true,
        }
    }
}
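
// With the default config above (4 -> [8, 4] -> 2), FeedForwardNetwork::new
// below builds three linear layers shaped [4 x 8], [8 x 4], and [4 x 2]; each
// hidden layer is followed by a ReLU, and the final layer emits raw outputs.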

/// A configurable feed-forward neural network
pub struct FeedForwardNetwork {
    layers: Vec<LinearLayer>,
    config: FeedForwardConfig,
}
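
// Typical usage (a sketch; `input` stands for any tensor shaped
// [batch, config.input_size]):
//
//   let network = FeedForwardNetwork::new(FeedForwardConfig::default(), Some(42));
//   let output = network.forward(&input);          // gradient-tracked, for training
//   let output = network.forward_no_grad(&input);  // untracked, for inference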

impl FeedForwardNetwork {
    /// Create a new feed-forward network with the given configuration
    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut current_seed = seed;

        // Create hidden layers, advancing the seed so each layer gets distinct weights
        for &hidden_size in &config.hidden_sizes {
            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
            current_size = hidden_size;
            current_seed = current_seed.map(|s| s + 1);
        }

        // Create output layer
        layers.push(LinearLayer::new(
            current_size,
            config.output_size,
            current_seed,
        ));

        Self { layers, config }
    }

    /// Forward pass through the entire network
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut x = input.clone();

        // Pass through all layers except the last one with ReLU activation
        for layer in &self.layers[..self.layers.len() - 1] {
            x = layer.forward(&x);
            x = ReLU::forward(&x);
        }

        // Final layer without activation (raw logits)
        if let Some(final_layer) = self.layers.last() {
            x = final_layer.forward(&x);
        }

        x
    }

    /// Forward pass without gradients (for inference)
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    /// Get all parameters for optimization
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for layer in &mut self.layers {
            params.extend(layer.parameters());
        }
        params
    }

    /// Get the number of layers
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Get the total number of parameters
    pub fn parameter_count(&self) -> usize {
        let mut count = 0;
        let mut current_size = self.config.input_size;

        for &hidden_size in &self.config.hidden_sizes {
            count += current_size * hidden_size + hidden_size; // weights + bias
            current_size = hidden_size;
        }

        // Output layer
        count += current_size * self.config.output_size + self.config.output_size;

        count
    }
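
    // Worked example: with the default config (4 -> [8, 4] -> 2), this is
    // 4*8 + 8 + 8*4 + 4 + 4*2 + 2 = 86 parameters.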

    /// Save network parameters to JSON
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        for (i, layer) in self.layers.iter().enumerate() {
            let layer_path = format!("{}_layer_{}", path, i);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            layer.weight.save_json(&weight_path)?;
            layer.bias.save_json(&bias_path)?;
        }

        println!(
            "Saved feed-forward network to {} ({} layers)",
            path,
            self.layers.len()
        );
        Ok(())
    }
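
    // With path "net", save_json writes "net_layer_0_weight.json",
    // "net_layer_0_bias.json", "net_layer_1_weight.json", and so on;
    // load_json below expects the same naming scheme and a matching config.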

    /// Load network parameters from JSON
    pub fn load_json(
        path: &str,
        config: FeedForwardConfig,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut layer_idx = 0;

        // Load hidden layers
        for &hidden_size in &config.hidden_sizes {
            let layer_path = format!("{}_layer_{}", path, layer_idx);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

            layers.push(LinearLayer {
                weight,
                bias,
                input_size: current_size,
                output_size: hidden_size,
            });

            current_size = hidden_size;
            layer_idx += 1;
        }

        // Load output layer
        let layer_path = format!("{}_layer_{}", path, layer_idx);
        let weight_path = format!("{}_weight.json", layer_path);
        let bias_path = format!("{}_bias.json", layer_path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        layers.push(LinearLayer {
            weight,
            bias,
            input_size: current_size,
            output_size: config.output_size,
        });

        Ok(Self { layers, config })
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Feed-Forward Network Example ===\n");

    demonstrate_network_creation();
    demonstrate_forward_pass();
    demonstrate_configurable_architectures();
    demonstrate_training_workflow()?;
    demonstrate_comprehensive_training()?;
    demonstrate_network_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

/// Demonstrate creating different network configurations
fn demonstrate_network_creation() {
    println!("--- Network Creation ---");

    // Default configuration
    let config = FeedForwardConfig::default();
    let network = FeedForwardNetwork::new(config.clone(), Some(42));

    println!("Default network configuration:");
    println!("  Input size: {}", config.input_size);
    println!("  Hidden sizes: {:?}", config.hidden_sizes);
    println!("  Output size: {}", config.output_size);
    println!("  Number of layers: {}", network.num_layers());
    println!("  Total parameters: {}", network.parameter_count());

    // Custom configurations
    let configs = [
        FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![4],
            output_size: 1,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 8,
            hidden_sizes: vec![16, 8, 4],
            output_size: 3,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 10,
            hidden_sizes: vec![20, 15, 10, 5],
            output_size: 2,
            use_bias: true,
        },
    ];

    for (i, config) in configs.iter().enumerate() {
        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
        println!("\nCustom network {}:", i + 1);
        println!(
            "  Architecture: {} -> {:?} -> {}",
            config.input_size, config.hidden_sizes, config.output_size
        );
        println!("  Layers: {}", network.num_layers());
        println!("  Parameters: {}", network.parameter_count());
    }
}

/// Demonstrate forward pass through the network
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![5, 3],
        output_size: 2,
        use_bias: true,
    };
    let network = FeedForwardNetwork::new(config, Some(43));

    // Single input
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = network.forward(&input);

    println!("Single input forward pass:");
    println!("  Input shape: {:?}", input.shape().dims());
    println!("  Output shape: {:?}", output.shape().dims());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Batch input
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, // Sample 1
            4.0, 5.0, 6.0, // Sample 2
            7.0, 8.0, 9.0, // Sample 3
        ],
        vec![3, 3],
    )
    .unwrap();
    let batch_output = network.forward(&batch_input);

    println!("\nBatch input forward pass:");
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());
    println!("  Output requires grad: {}", batch_output.requires_grad());

    // Compare with no-grad version
    let output_no_grad = network.forward_no_grad(&input);
    println!("\nNo-grad comparison:");
    println!("  Same values: {}", output.data() == output_no_grad.data());
    println!("  With grad requires grad: {}", output.requires_grad());
    println!(
        "  No grad requires grad: {}",
        output_no_grad.requires_grad()
    );
}

/// Demonstrate different configurable architectures
fn demonstrate_configurable_architectures() {
    println!("\n--- Configurable Architectures ---");

    let architectures = vec![
        ("Shallow", vec![8]),
        ("Medium", vec![16, 8]),
        ("Deep", vec![32, 16, 8, 4]),
        ("Wide", vec![64, 32]),
        ("Bottleneck", vec![16, 4, 16]),
    ];

    for (name, hidden_sizes) in architectures {
        let config = FeedForwardConfig {
            input_size: 10,
            hidden_sizes,
            output_size: 3,
            use_bias: true,
        };

        let network = FeedForwardNetwork::new(config.clone(), Some(44));

        // Test forward pass
        let test_input = Tensor::randn(vec![5, 10], Some(45)); // Batch of 5
        let output = network.forward_no_grad(&test_input);

        println!("{} network:", name);
        println!("  Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
        println!("  Parameters: {}", network.parameter_count());
        println!("  Test output shape: {:?}", output.shape().dims());
        println!(
            "  Output range: [{:.3}, {:.3}]",
            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
            output
                .data()
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        );
    }
}

/// Demonstrate basic training workflow
fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Workflow ---");

    // Create a simple classification network
    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 3],
        output_size: 1,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(46));

    println!("Training network: 2 -> [4, 3] -> 1");

    // Create simple binary classification data: XOR problem
    let x_data = Tensor::from_slice(
        &[
            0.0, 0.0, // -> 0
            0.0, 1.0, // -> 1
            1.0, 0.0, // -> 1
            1.0, 1.0, // -> 0
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();
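
    // XOR is not linearly separable, so the hidden ReLU layers are what make
    // this problem learnable; a single linear layer could not fit these targets.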

    println!("Training on XOR problem:");
    println!("  Input shape: {:?}", x_data.shape().dims());
    println!("  Target shape: {:?}", y_true.shape().dims());

    // Create optimizer
    let mut optimizer = Adam::with_learning_rate(0.1);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Training loop
    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = network.forward(&x_data);

        // Compute loss: MSE
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step and zero grad
        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        // Print progress
        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Test final model
    let final_predictions = network.forward_no_grad(&x_data);
    println!("\nFinal predictions vs targets:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let target = y_true.data()[i];
        let input_x = x_data.data()[i * 2];
        let input_y = x_data.data()[i * 2 + 1];
        println!(
            "  [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
            input_x,
            input_y,
            pred,
            target,
            (pred - target).abs()
        );
    }

    Ok(())
}

/// Demonstrate comprehensive training with 100+ steps
fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Comprehensive Training (100+ Steps) ---");

    // Create a regression network
    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![8, 6, 4],
        output_size: 2,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(47));

    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
    println!("Total parameters: {}", network.parameter_count());

    // Create synthetic regression data
    // Target function: [y1, y2] = [x1 + 2*x2 - x3, x1*x2 + x3]
    let num_samples = 32;
    let mut x_vec = Vec::new();
    let mut y_vec = Vec::new();

    for i in 0..num_samples {
        // Keep all inputs in [-1, 1); the modulo wraps the faster-moving
        // features back into range instead of letting them grow past 1.
        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0;
        let x2 = (((i * 2) % num_samples) as f32 / num_samples as f32) * 2.0 - 1.0;
        let x3 = (((i * 3) % num_samples) as f32 / num_samples as f32) * 2.0 - 1.0;

        let y1 = x1 + 2.0 * x2 - x3;
        let y2 = x1 * x2 + x3;

        x_vec.extend_from_slice(&[x1, x2, x3]);
        y_vec.extend_from_slice(&[y1, y2]);
    }

    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();

    println!("Training data:");
    println!("  {} samples", num_samples);
    println!("  Input shape: {:?}", x_data.shape().dims());
    println!("  Target shape: {:?}", y_true.shape().dims());

    // Create optimizer with learning rate scheduling
    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Comprehensive training loop (150 epochs)
    let num_epochs = 150;
    let mut losses = Vec::new();
    let mut best_loss = f32::INFINITY;
    let mut patience_counter = 0;
    let patience = 20;

    println!("Starting comprehensive training...");

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = network.forward(&x_data);

        // Compute loss: MSE
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step and zero grad
        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        let current_loss = loss.value();
        losses.push(current_loss);

        // Learning rate scheduling: step decay, multiplying the rate by 0.8
        // every 30 epochs
        if epoch > 0 && epoch % 30 == 0 {
            let new_lr = optimizer.learning_rate() * 0.8;
            optimizer.set_learning_rate(new_lr);
            println!("  Reduced learning rate to {:.4}", new_lr);
        }

        // Early stopping logic
        if current_loss < best_loss {
            best_loss = current_loss;
            patience_counter = 0;
        } else {
            patience_counter += 1;
        }

        // Print progress
        if epoch % 25 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
                epoch,
                current_loss,
                optimizer.learning_rate(),
                best_loss
            );
        }

        // Early stopping
        if patience_counter >= patience && epoch > 50 {
            println!("Early stopping at epoch {} (patience exceeded)", epoch);
            break;
        }
    }

    // Final evaluation
    let final_predictions = network.forward_no_grad(&x_data);

    // Compute final metrics
    let final_loss = losses[losses.len() - 1];
    let initial_loss = losses[0];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining completed!");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Best loss: {:.6}", best_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);
    println!("  Final learning rate: {:.4}", optimizer.learning_rate());

    // Sample predictions analysis
    println!("\nSample predictions (first 5):");
    for i in 0..5.min(num_samples) {
        let pred1 = final_predictions.data()[i * 2];
        let pred2 = final_predictions.data()[i * 2 + 1];
        let true1 = y_true.data()[i * 2];
        let true2 = y_true.data()[i * 2 + 1];

        println!(
            "  Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
            i + 1,
            pred1,
            pred2,
            true1,
            true2,
            (pred1 - true1).abs(),
            (pred2 - true2).abs()
        );
    }

    Ok(())
}

/// Demonstrate network serialization
fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Network Serialization ---");

    // Create and train a network
    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 2],
        output_size: 1,
        use_bias: true,
    };
    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));

    // Quick training
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    for _ in 0..20 {
        let y_pred = original_network.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    // Test original network
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_network.forward_no_grad(&test_input);

    println!("Original network output: {:?}", original_output.data());

    // Save network
    original_network.save_json("temp_feedforward_network")?;

    // Load network
    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    let loaded_output = loaded_network.forward_no_grad(&test_input);

    println!("Loaded network output: {:?}", loaded_output.data());

    // Verify consistency
    let match_check = original_output
        .data()
        .iter()
        .zip(loaded_output.data().iter())
        .all(|(a, b)| (a - b).abs() < 1e-6);

    println!(
        "Serialization verification: {}",
        if match_check { "PASSED" } else { "FAILED" }
    );

    Ok(())
}

/// Clean up temporary files
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    // Remove network files (assume at most 10 layers)
    for i in 0..10 {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        if fs::metadata(&weight_file).is_ok() {
            fs::remove_file(&weight_file)?;
            println!("Removed: {}", weight_file);
        }
        if fs::metadata(&bias_file).is_ok() {
            fs::remove_file(&bias_file)?;
            println!("Removed: {}", bias_file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_activation() {
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let output = ReLU::forward(&input);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];

        assert_eq!(output.data(), &expected);
    }

    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        assert_eq!(network.num_layers(), 3); // 2 hidden + 1 output
        // weights + biases
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
    }

    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_no_grad_forward() {
        let config = FeedForwardConfig::default();
        let network = FeedForwardNetwork::new(config, Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        let params = network.parameters();
        assert_eq!(params.len(), 4); // 2 layers * 2 parameters (weight + bias) each
    }
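
    // Added check (uses only APIs already exercised above): parameter_count()
    // for the default 4 -> [8, 4] -> 2 config should match the hand-computed
    // total.
    #[test]
    fn test_default_parameter_count() {
        let network = FeedForwardNetwork::new(FeedForwardConfig::default(), Some(48));
        // 4*8 + 8 + 8*4 + 4 + 4*2 + 2 = 86
        assert_eq!(network.parameter_count(), 86);
    }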
}