// feedforward_network/feedforward_network.rs
// (example source file for the `feedforward_network` example)

1//! Feed-Forward Network Example
2//!
3//! This example demonstrates how to build a configurable multi-layer feed-forward network:
4//! - Configurable input/output sizes and number of hidden layers
5//! - Using the basic linear layer as building blocks
6//! - ReLU activation function implementation
7//! - Complete training workflow with gradient computation
8//! - Comprehensive training loop with 100+ steps
9//! - Proper gradient tracking and computation graph connectivity
10//! - Zero gradient management between training steps
11//!
12//! # Learning Objectives
13//!
14//! - Understand multi-layer network architecture design
15//! - Learn to compose linear layers into deeper networks
16//! - Implement activation functions and their gradient properties
17//! - Master training workflows with proper gradient management
18//! - Explore configurable network architectures
19//! - Understand gradient flow through multiple layers
20//!
21//! # Prerequisites
22//!
23//! - Basic Rust knowledge
24//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
25//! - Familiarity with neural network concepts
26//! - Knowledge of the basic linear layer (see basic_linear_layer.rs)
27//! - Understanding of activation functions and backpropagation
28//!
29//! # Usage
30//!
31//! ```bash
32//! cargo run --example feedforward_network
33//! ```
34
35use std::fs;
36use train_station::{
37    optimizers::{Adam, Optimizer},
38    serialization::StructSerializable,
39    NoGradTrack, Tensor,
40};
41
42/// ReLU activation function
43pub struct ReLU;
44
45impl ReLU {
46    /// Apply ReLU activation: max(0, x)
47    pub fn forward(input: &Tensor) -> Tensor {
48        input.relu()
49    }
50
51    /// Apply ReLU activation without gradients
52    pub fn forward_no_grad(input: &Tensor) -> Tensor {
53        let _guard = NoGradTrack::new();
54        Self::forward(input)
55    }
56}
57
58/// A basic linear layer implementation (reused from basic_linear_layer.rs)
59#[derive(Debug)]
60pub struct LinearLayer {
61    pub weight: Tensor,
62    pub bias: Tensor,
63    pub input_size: usize,
64    pub output_size: usize,
65}
66
67impl LinearLayer {
68    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
69        let scale = (1.0 / input_size as f32).sqrt();
70
71        let weight = Tensor::randn(vec![input_size, output_size], seed)
72            .mul_scalar(scale)
73            .with_requires_grad();
74        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
75
76        Self {
77            weight,
78            bias,
79            input_size,
80            output_size,
81        }
82    }
83
84    pub fn forward(&self, input: &Tensor) -> Tensor {
85        let output = input.matmul(&self.weight);
86        output.add_tensor(&self.bias)
87    }
88
89    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
90        let _guard = NoGradTrack::new();
91        self.forward(input)
92    }
93
94    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
95        vec![&mut self.weight, &mut self.bias]
96    }
97}
98
/// Configuration for feed-forward network
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    /// Number of input features per sample
    pub input_size: usize,
    /// Widths of the hidden layers, ordered from input side to output side
    pub hidden_sizes: Vec<usize>,
    /// Number of output features per sample
    pub output_size: usize,
    /// NOTE(review): this flag is currently not honored — `LinearLayer::new`
    /// always creates a bias tensor regardless of this setting
    pub use_bias: bool,
}
107
impl Default for FeedForwardConfig {
    /// Default architecture: 4 -> [8, 4] -> 2, with bias enabled.
    fn default() -> Self {
        Self {
            input_size: 4,
            hidden_sizes: vec![8, 4],
            output_size: 2,
            use_bias: true,
        }
    }
}
118
/// A configurable feed-forward neural network
pub struct FeedForwardNetwork {
    /// Hidden layers followed by the output layer (always at least one layer)
    layers: Vec<LinearLayer>,
    /// Architecture description used to build (and reload) the network
    config: FeedForwardConfig,
}
124
125impl FeedForwardNetwork {
126    /// Create a new feed-forward network with the given configuration
127    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
128        let mut layers = Vec::new();
129        let mut current_size = config.input_size;
130        let mut current_seed = seed;
131
132        // Create hidden layers
133        for &hidden_size in &config.hidden_sizes {
134            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
135            current_size = hidden_size;
136            current_seed = current_seed.map(|s| s + 1);
137        }
138
139        // Create output layer
140        layers.push(LinearLayer::new(
141            current_size,
142            config.output_size,
143            current_seed,
144        ));
145
146        Self { layers, config }
147    }
148
149    /// Forward pass through the entire network
150    pub fn forward(&self, input: &Tensor) -> Tensor {
151        let mut x = input.clone();
152
153        // Pass through all layers except the last one with ReLU activation
154        for layer in &self.layers[..self.layers.len() - 1] {
155            x = layer.forward(&x);
156            x = ReLU::forward(&x);
157        }
158
159        // Final layer without activation (raw logits)
160        if let Some(final_layer) = self.layers.last() {
161            x = final_layer.forward(&x);
162        }
163
164        x
165    }
166
167    /// Forward pass without gradients (for inference)
168    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
169        let _guard = NoGradTrack::new();
170        self.forward(input)
171    }
172
173    /// Get all parameters for optimization
174    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
175        let mut params = Vec::new();
176        for layer in &mut self.layers {
177            params.extend(layer.parameters());
178        }
179        params
180    }
181
182    /// Get the number of layers
183    pub fn num_layers(&self) -> usize {
184        self.layers.len()
185    }
186
187    /// Get the total number of parameters
188    pub fn parameter_count(&self) -> usize {
189        let mut count = 0;
190        let mut current_size = self.config.input_size;
191
192        for &hidden_size in &self.config.hidden_sizes {
193            count += current_size * hidden_size + hidden_size; // weights + bias
194            current_size = hidden_size;
195        }
196
197        // Output layer
198        count += current_size * self.config.output_size + self.config.output_size;
199
200        count
201    }
202
203    /// Save network parameters to JSON
204    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
205        if let Some(parent) = std::path::Path::new(path).parent() {
206            fs::create_dir_all(parent)?;
207        }
208
209        for (i, layer) in self.layers.iter().enumerate() {
210            let layer_path = format!("{}_layer_{}", path, i);
211            let weight_path = format!("{}_weight.json", layer_path);
212            let bias_path = format!("{}_bias.json", layer_path);
213
214            layer.weight.save_json(&weight_path)?;
215            layer.bias.save_json(&bias_path)?;
216        }
217
218        println!(
219            "Saved feed-forward network to {} ({} layers)",
220            path,
221            self.layers.len()
222        );
223        Ok(())
224    }
225
226    /// Load network parameters from JSON
227    pub fn load_json(
228        path: &str,
229        config: FeedForwardConfig,
230    ) -> Result<Self, Box<dyn std::error::Error>> {
231        let mut layers = Vec::new();
232        let mut current_size = config.input_size;
233        let mut layer_idx = 0;
234
235        // Load hidden layers
236        for &hidden_size in &config.hidden_sizes {
237            let layer_path = format!("{}_layer_{}", path, layer_idx);
238            let weight_path = format!("{}_weight.json", layer_path);
239            let bias_path = format!("{}_bias.json", layer_path);
240
241            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
242            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
243
244            layers.push(LinearLayer {
245                weight,
246                bias,
247                input_size: current_size,
248                output_size: hidden_size,
249            });
250
251            current_size = hidden_size;
252            layer_idx += 1;
253        }
254
255        // Load output layer
256        let layer_path = format!("{}_layer_{}", path, layer_idx);
257        let weight_path = format!("{}_weight.json", layer_path);
258        let bias_path = format!("{}_bias.json", layer_path);
259
260        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
261        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
262
263        layers.push(LinearLayer {
264            weight,
265            bias,
266            input_size: current_size,
267            output_size: config.output_size,
268        });
269
270        Ok(Self { layers, config })
271    }
272}
273
/// Entry point: runs each demonstration in order, then removes the temporary
/// files created by the serialization demo.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Feed-Forward Network Example ===\n");

    demonstrate_network_creation();
    demonstrate_forward_pass();
    demonstrate_configurable_architectures();
    demonstrate_training_workflow()?;
    demonstrate_comprehensive_training()?;
    demonstrate_network_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}
288
289/// Demonstrate creating different network configurations
290fn demonstrate_network_creation() {
291    println!("--- Network Creation ---");
292
293    // Default configuration
294    let config = FeedForwardConfig::default();
295    let network = FeedForwardNetwork::new(config.clone(), Some(42));
296
297    println!("Default network configuration:");
298    println!("  Input size: {}", config.input_size);
299    println!("  Hidden sizes: {:?}", config.hidden_sizes);
300    println!("  Output size: {}", config.output_size);
301    println!("  Number of layers: {}", network.num_layers());
302    println!("  Total parameters: {}", network.parameter_count());
303
304    // Custom configurations
305    let configs = [
306        FeedForwardConfig {
307            input_size: 2,
308            hidden_sizes: vec![4],
309            output_size: 1,
310            use_bias: true,
311        },
312        FeedForwardConfig {
313            input_size: 8,
314            hidden_sizes: vec![16, 8, 4],
315            output_size: 3,
316            use_bias: true,
317        },
318        FeedForwardConfig {
319            input_size: 10,
320            hidden_sizes: vec![20, 15, 10, 5],
321            output_size: 2,
322            use_bias: true,
323        },
324    ];
325
326    for (i, config) in configs.iter().enumerate() {
327        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
328        println!("\nCustom network {}:", i + 1);
329        println!(
330            "  Architecture: {} -> {:?} -> {}",
331            config.input_size, config.hidden_sizes, config.output_size
332        );
333        println!("  Layers: {}", network.num_layers());
334        println!("  Parameters: {}", network.parameter_count());
335    }
336}
337
338/// Demonstrate forward pass through the network
339fn demonstrate_forward_pass() {
340    println!("\n--- Forward Pass ---");
341
342    let config = FeedForwardConfig {
343        input_size: 3,
344        hidden_sizes: vec![5, 3],
345        output_size: 2,
346        use_bias: true,
347    };
348    let network = FeedForwardNetwork::new(config, Some(43));
349
350    // Single input
351    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
352    let output = network.forward(&input);
353
354    println!("Single input forward pass:");
355    println!("  Input shape: {:?}", input.shape().dims);
356    println!("  Output shape: {:?}", output.shape().dims);
357    println!("  Output: {:?}", output.data());
358    println!("  Output requires grad: {}", output.requires_grad());
359
360    // Batch input
361    let batch_input = Tensor::from_slice(
362        &[
363            1.0, 2.0, 3.0, // Sample 1
364            4.0, 5.0, 6.0, // Sample 2
365            7.0, 8.0, 9.0, // Sample 3
366        ],
367        vec![3, 3],
368    )
369    .unwrap();
370    let batch_output = network.forward(&batch_input);
371
372    println!("Batch input forward pass:");
373    println!("  Input shape: {:?}", batch_input.shape().dims);
374    println!("  Output shape: {:?}", batch_output.shape().dims);
375    println!("  Output requires grad: {}", batch_output.requires_grad());
376
377    // Compare with no-grad version
378    let output_no_grad = network.forward_no_grad(&input);
379    println!("No-grad comparison:");
380    println!("  Same values: {}", output.data() == output_no_grad.data());
381    println!("  With grad requires grad: {}", output.requires_grad());
382    println!(
383        "  No grad requires grad: {}",
384        output_no_grad.requires_grad()
385    );
386}
387
388/// Demonstrate different configurable architectures
389fn demonstrate_configurable_architectures() {
390    println!("\n--- Configurable Architectures ---");
391
392    let architectures = vec![
393        ("Shallow", vec![8]),
394        ("Medium", vec![16, 8]),
395        ("Deep", vec![32, 16, 8, 4]),
396        ("Wide", vec![64, 32]),
397        ("Bottleneck", vec![16, 4, 16]),
398    ];
399
400    for (name, hidden_sizes) in architectures {
401        let config = FeedForwardConfig {
402            input_size: 10,
403            hidden_sizes,
404            output_size: 3,
405            use_bias: true,
406        };
407
408        let network = FeedForwardNetwork::new(config.clone(), Some(44));
409
410        // Test forward pass
411        let test_input = Tensor::randn(vec![5, 10], Some(45)); // Batch of 5
412        let output = network.forward_no_grad(&test_input);
413
414        println!("{} network:", name);
415        println!("  Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
416        println!("  Parameters: {}", network.parameter_count());
417        println!("  Test output shape: {:?}", output.shape().dims);
418        println!(
419            "  Output range: [{:.3}, {:.3}]",
420            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
421            output
422                .data()
423                .iter()
424                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
425        );
426    }
427}
428
/// Demonstrate basic training workflow
///
/// Trains a small 2 -> [4, 3] -> 1 network on the XOR problem with Adam and
/// an MSE loss, printing the loss every 10 epochs and the final predictions.
fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Workflow ---");

    // Create a simple classification network
    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 3],
        output_size: 1,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(46));

    println!("Training network: 2 -> [4, 3] -> 1");

    // Create simple binary classification data: XOR problem
    let x_data = Tensor::from_slice(
        &[
            0.0, 0.0, // -> 0
            0.0, 1.0, // -> 1
            1.0, 0.0, // -> 1
            1.0, 1.0, // -> 0
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();

    println!("Training on XOR problem:");
    println!("  Input shape: {:?}", x_data.shape().dims);
    println!("  Target shape: {:?}", y_true.shape().dims);

    // Create optimizer and register every trainable parameter with it
    let mut optimizer = Adam::with_learning_rate(0.1);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Training loop
    let num_epochs = 50;
    // Loss history (kept for illustration; not read back in this demo)
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = network.forward(&x_data);

        // Compute loss: MSE (mean of squared prediction errors)
        let diff = y_pred.sub_tensor(&y_true);
        // `loss` is mutable because backward() is called on it below
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step and zero grad; parameters are re-borrowed here
        // because the forward pass above needed `network` immutably
        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        // Print progress
        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
        }
    }

    // Test final model
    let final_predictions = network.forward_no_grad(&x_data);
    println!("\nFinal predictions vs targets:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let target = y_true.data()[i];
        // Inputs are stored row-major: sample i occupies elements [2i, 2i+1]
        let input_x = x_data.data()[i * 2];
        let input_y = x_data.data()[i * 2 + 1];
        println!(
            "  [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
            input_x,
            input_y,
            pred,
            target,
            (pred - target).abs()
        );
    }

    Ok(())
}
517
/// Demonstrate comprehensive training with 100+ steps
///
/// Trains a 3 -> [8, 6, 4] -> 2 regression network on a synthetic dataset,
/// with step-wise learning-rate decay and patience-based early stopping.
fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Comprehensive Training (100+ Steps) ---");

    // Create a regression network
    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![8, 6, 4],
        output_size: 2,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(47));

    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
    println!("Total parameters: {}", network.parameter_count());

    // Create synthetic regression data
    // Target function: [y1, y2] = [x1 + 2*x2 - x3, x1*x2 + x3]
    let num_samples = 32;
    let mut x_vec = Vec::new();
    let mut y_vec = Vec::new();

    for i in 0..num_samples {
        // Inputs are deterministic ramps; note x2 and x3 exceed 1.0 for large
        // i (e.g. x2 reaches (62/32)*2 - 1), only x1 stays within [-1, 1]
        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0; // [-1, 1]
        let x2 = ((i * 2) as f32 / num_samples as f32) * 2.0 - 1.0;
        let x3 = ((i * 3) as f32 / num_samples as f32) * 2.0 - 1.0;

        let y1 = x1 + 2.0 * x2 - x3;
        let y2 = x1 * x2 + x3;

        x_vec.extend_from_slice(&[x1, x2, x3]);
        y_vec.extend_from_slice(&[y1, y2]);
    }

    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();

    println!("Training data:");
    println!("  {} samples", num_samples);
    println!("  Input shape: {:?}", x_data.shape().dims);
    println!("  Target shape: {:?}", y_true.shape().dims);

    // Create optimizer with learning rate scheduling
    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Comprehensive training loop (150 epochs)
    let num_epochs = 150;
    let mut losses = Vec::new();
    // Early-stopping state: best loss seen so far and epochs since improvement
    let mut best_loss = f32::INFINITY;
    let mut patience_counter = 0;
    let patience = 20;

    println!("Starting comprehensive training...");

    for epoch in 0..num_epochs {
        // Forward pass
        let y_pred = network.forward(&x_data);

        // Compute loss: MSE
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backward pass
        loss.backward(None);

        // Optimizer step and zero grad
        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        let current_loss = loss.value();
        losses.push(current_loss);

        // Learning rate scheduling: multiply LR by 0.8 every 30 epochs
        if epoch > 0 && epoch % 30 == 0 {
            let new_lr = optimizer.learning_rate() * 0.8;
            optimizer.set_learning_rate(new_lr);
            println!("  Reduced learning rate to {:.4}", new_lr);
        }

        // Early stopping logic: reset patience on any improvement
        if current_loss < best_loss {
            best_loss = current_loss;
            patience_counter = 0;
        } else {
            patience_counter += 1;
        }

        // Print progress
        if epoch % 25 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
                epoch,
                current_loss,
                optimizer.learning_rate(),
                best_loss
            );
        }

        // Early stopping (only allowed after a 50-epoch warm-up)
        if patience_counter >= patience && epoch > 50 {
            println!("Early stopping at epoch {} (patience exceeded)", epoch);
            break;
        }
    }

    // Final evaluation
    let final_predictions = network.forward_no_grad(&x_data);

    // Compute final metrics
    let final_loss = losses[losses.len() - 1];
    let initial_loss = losses[0];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining completed!");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Best loss: {:.6}", best_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);
    println!("  Final learning rate: {:.4}", optimizer.learning_rate());

    // Sample predictions analysis
    println!("\nSample predictions (first 5):");
    for i in 0..5.min(num_samples) {
        // Predictions/targets are row-major: sample i holds elements [2i, 2i+1]
        let pred1 = final_predictions.data()[i * 2];
        let pred2 = final_predictions.data()[i * 2 + 1];
        let true1 = y_true.data()[i * 2];
        let true2 = y_true.data()[i * 2 + 1];

        println!(
            "  Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
            i + 1,
            pred1,
            pred2,
            true1,
            true2,
            (pred1 - true1).abs(),
            (pred2 - true2).abs()
        );
    }

    Ok(())
}
665
/// Demonstrate network serialization
///
/// Trains a small network briefly, saves it to JSON, reloads it, and checks
/// that the reloaded network reproduces the original's outputs.
fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Network Serialization ---");

    // Create and train a network
    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 2],
        output_size: 1,
        use_bias: true,
    };
    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));

    // Quick training
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // 20 quick MSE steps — enough to move the weights off initialization
    for _ in 0..20 {
        let y_pred = original_network.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    // Test original network
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_network.forward_no_grad(&test_input);

    println!("Original network output: {:?}", original_output.data());

    // Save network (one weight/bias JSON file pair per layer)
    original_network.save_json("temp_feedforward_network")?;

    // Load network using the same config it was saved with
    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    let loaded_output = loaded_network.forward_no_grad(&test_input);

    println!("Loaded network output: {:?}", loaded_output.data());

    // Verify consistency: outputs must agree element-wise within 1e-6
    let match_check = original_output
        .data()
        .iter()
        .zip(loaded_output.data().iter())
        .all(|(a, b)| (a - b).abs() < 1e-6);

    println!(
        "Serialization verification: {}",
        if match_check { "PASSED" } else { "FAILED" }
    );

    Ok(())
}
728
/// Clean up temporary files created by the serialization demo.
///
/// Attempts to remove each possible layer file directly instead of checking
/// for existence first (the original metadata-then-remove pattern was a
/// check-then-act race). A missing file is not an error; any other I/O
/// failure is propagated to the caller.
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    // Remove network files; assume at most 10 layers were saved.
    for i in 0..10 {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        for file in [&weight_file, &bias_file] {
            match fs::remove_file(file) {
                Ok(()) => println!("Removed: {}", file),
                // File was never created (or already removed) — nothing to do.
                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
                Err(e) => return Err(e.into()),
            }
        }
    }

    println!("Cleanup completed");
    Ok(())
}
752
#[cfg(test)]
mod tests {
    use super::*;

    // ReLU zeroes negatives and passes non-negatives through unchanged.
    #[test]
    fn test_relu_activation() {
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let output = ReLU::forward(&input);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];

        assert_eq!(output.data(), &expected);
    }

    // Layer count is hidden layers + 1; parameter count matches the
    // hand-computed sum of weight matrices and bias vectors.
    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        assert_eq!(network.num_layers(), 3); // 2 hidden + 1 output
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
        // weights + biases
    }

    // A single sample yields a [1, output_size] tensor that tracks gradients.
    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(output.requires_grad());
    }

    // A batch of 2 samples yields a [2, output_size] tensor.
    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch_input);

        assert_eq!(output.shape().dims, vec![2, 1]);
    }

    // Inference-mode forward must not track gradients.
    #[test]
    fn test_no_grad_forward() {
        let config = FeedForwardConfig::default();
        let network = FeedForwardNetwork::new(config, Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    // One hidden + one output layer yields 4 parameter tensors
    // (2 weights + 2 biases).
    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        let params = network.parameters();
        assert_eq!(params.len(), 4); // 2 layers * 2 parameters (weight + bias) each
    }
}