feedforward_network/
feedforward_network.rs

//! Feed-Forward Network Example
//!
//! This example demonstrates how to build a configurable multi-layer feed-forward network:
//! - Configurable input/output sizes and number of hidden layers
//! - Using the basic linear layer as building blocks
//! - ReLU activation function implementation
//! - Complete training workflow with gradient computation
//! - Comprehensive training loop with 100+ steps
//! - Proper gradient tracking and computation graph connectivity
//! - Zero gradient management between training steps
//!
//! # Learning Objectives
//!
//! - Understand multi-layer network architecture design
//! - Learn to compose linear layers into deeper networks
//! - Implement activation functions and their gradient properties
//! - Master training workflows with proper gradient management
//! - Explore configurable network architectures
//! - Understand gradient flow through multiple layers
//!
//! # Prerequisites
//!
//! - Basic Rust knowledge
//! - Understanding of tensor basics (see getting_started/tensor_basics.rs)
//! - Familiarity with neural network concepts
//! - Knowledge of the basic linear layer (see basic_linear_layer.rs)
//! - Understanding of activation functions and backpropagation
//!
//! # Usage
//!
//! ```bash
//! cargo run --example feedforward_network
//! ```

35use std::fs;
36use train_station::{
37    gradtrack::NoGradTrack,
38    optimizers::{Adam, Optimizer},
39    serialization::StructSerializable,
40    Tensor,
41};
42
43// Reuse LinearLayer from the shared example to avoid duplication
44#[allow(clippy::duplicate_mod)]
45#[path = "basic_linear_layer.rs"]
46mod basic_linear_layer;
47use basic_linear_layer::LinearLayer;
48
49/// ReLU activation function
50pub struct ReLU;
51
52impl ReLU {
53    /// Apply ReLU activation: max(0, x)
54    pub fn forward(input: &Tensor) -> Tensor {
55        input.relu()
56    }
57
58    /// Apply ReLU activation without gradients
59    #[allow(unused)]
60    pub fn forward_no_grad(input: &Tensor) -> Tensor {
61        let _guard = NoGradTrack::new();
62        Self::forward(input)
63    }
64}
65
/// Architecture description for a feed-forward network.
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    /// Number of features in each input sample.
    pub input_size: usize,
    /// Width of each hidden layer, in order from input to output.
    pub hidden_sizes: Vec<usize>,
    /// Number of features produced by the final layer.
    pub output_size: usize,
    /// Whether layers carry a bias term.
    /// NOTE(review): currently unused — layers are always built with a bias.
    #[allow(unused)]
    pub use_bias: bool,
}
75
76impl Default for FeedForwardConfig {
77    fn default() -> Self {
78        Self {
79            input_size: 4,
80            hidden_sizes: vec![8, 4],
81            output_size: 2,
82            use_bias: true,
83        }
84    }
85}
86
87/// A configurable feed-forward neural network
88pub struct FeedForwardNetwork {
89    layers: Vec<LinearLayer>,
90    config: FeedForwardConfig,
91}
92
93impl FeedForwardNetwork {
94    /// Create a new feed-forward network with the given configuration
95    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
96        let mut layers = Vec::new();
97        let mut current_size = config.input_size;
98        let mut current_seed = seed;
99
100        // Create hidden layers
101        for &hidden_size in &config.hidden_sizes {
102            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
103            current_size = hidden_size;
104            current_seed = current_seed.map(|s| s + 1);
105        }
106
107        // Create output layer
108        layers.push(LinearLayer::new(
109            current_size,
110            config.output_size,
111            current_seed,
112        ));
113
114        Self { layers, config }
115    }
116
117    /// Forward pass through the entire network
118    pub fn forward(&self, input: &Tensor) -> Tensor {
119        let mut x = input.clone();
120
121        // Pass through all layers except the last one with ReLU activation
122        for layer in &self.layers[..self.layers.len() - 1] {
123            x = layer.forward(&x);
124            x = ReLU::forward(&x);
125        }
126
127        // Final layer without activation (raw logits)
128        if let Some(final_layer) = self.layers.last() {
129            x = final_layer.forward(&x);
130        }
131
132        x
133    }
134
135    /// Forward pass without gradients (for inference)
136    #[allow(unused)]
137    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
138        let _guard = NoGradTrack::new();
139        self.forward(input)
140    }
141
142    /// Get all parameters for optimization
143    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
144        let mut params = Vec::new();
145        for layer in &mut self.layers {
146            params.extend(layer.parameters());
147        }
148        params
149    }
150
151    /// Get the number of layers
152    #[allow(unused)]
153    pub fn num_layers(&self) -> usize {
154        self.layers.len()
155    }
156
157    /// Get the total number of parameters
158    #[allow(unused)]
159    pub fn parameter_count(&self) -> usize {
160        let mut count = 0;
161        let mut current_size = self.config.input_size;
162
163        for &hidden_size in &self.config.hidden_sizes {
164            count += current_size * hidden_size + hidden_size; // weights + bias
165            current_size = hidden_size;
166        }
167
168        // Output layer
169        count += current_size * self.config.output_size + self.config.output_size;
170
171        count
172    }
173
174    /// Save network parameters to JSON
175    #[allow(unused)]
176    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
177        if let Some(parent) = std::path::Path::new(path).parent() {
178            fs::create_dir_all(parent)?;
179        }
180
181        for (i, layer) in self.layers.iter().enumerate() {
182            let layer_path = format!("{}_layer_{}", path, i);
183            let weight_path = format!("{}_weight.json", layer_path);
184            let bias_path = format!("{}_bias.json", layer_path);
185
186            layer.weight.save_json(&weight_path)?;
187            layer.bias.save_json(&bias_path)?;
188        }
189
190        println!(
191            "Saved feed-forward network to {} ({} layers)",
192            path,
193            self.layers.len()
194        );
195        Ok(())
196    }
197
198    /// Load network parameters from JSON
199    #[allow(unused)]
200    pub fn load_json(
201        path: &str,
202        config: FeedForwardConfig,
203    ) -> Result<Self, Box<dyn std::error::Error>> {
204        let mut layers = Vec::new();
205        let mut current_size = config.input_size;
206        let mut layer_idx = 0;
207
208        // Load hidden layers
209        for &hidden_size in &config.hidden_sizes {
210            let layer_path = format!("{}_layer_{}", path, layer_idx);
211            let weight_path = format!("{}_weight.json", layer_path);
212            let bias_path = format!("{}_bias.json", layer_path);
213
214            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
215            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
216
217            layers.push(LinearLayer {
218                weight,
219                bias,
220                input_size: current_size,
221                output_size: hidden_size,
222            });
223
224            current_size = hidden_size;
225            layer_idx += 1;
226        }
227
228        // Load output layer
229        let layer_path = format!("{}_layer_{}", path, layer_idx);
230        let weight_path = format!("{}_weight.json", layer_path);
231        let bias_path = format!("{}_bias.json", layer_path);
232
233        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
234        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();
235
236        layers.push(LinearLayer {
237            weight,
238            bias,
239            input_size: current_size,
240            output_size: config.output_size,
241        });
242
243        Ok(Self { layers, config })
244    }
245}
246
247#[allow(unused)]
248fn main() -> Result<(), Box<dyn std::error::Error>> {
249    println!("=== Feed-Forward Network Example ===\n");
250
251    demonstrate_network_creation();
252    demonstrate_forward_pass();
253    demonstrate_configurable_architectures();
254    demonstrate_training_workflow()?;
255    demonstrate_comprehensive_training()?;
256    demonstrate_network_serialization()?;
257    cleanup_temp_files()?;
258
259    println!("\n=== Example completed successfully! ===");
260    Ok(())
261}
262
263/// Demonstrate creating different network configurations
264fn demonstrate_network_creation() {
265    println!("--- Network Creation ---");
266
267    // Default configuration
268    let config = FeedForwardConfig::default();
269    let network = FeedForwardNetwork::new(config.clone(), Some(42));
270
271    println!("Default network configuration:");
272    println!("  Input size: {}", config.input_size);
273    println!("  Hidden sizes: {:?}", config.hidden_sizes);
274    println!("  Output size: {}", config.output_size);
275    println!("  Number of layers: {}", network.num_layers());
276    println!("  Total parameters: {}", network.parameter_count());
277
278    // Custom configurations
279    let configs = [
280        FeedForwardConfig {
281            input_size: 2,
282            hidden_sizes: vec![4],
283            output_size: 1,
284            use_bias: true,
285        },
286        FeedForwardConfig {
287            input_size: 8,
288            hidden_sizes: vec![16, 8, 4],
289            output_size: 3,
290            use_bias: true,
291        },
292        FeedForwardConfig {
293            input_size: 10,
294            hidden_sizes: vec![20, 15, 10, 5],
295            output_size: 2,
296            use_bias: true,
297        },
298    ];
299
300    for (i, config) in configs.iter().enumerate() {
301        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
302        println!("\nCustom network {}:", i + 1);
303        println!(
304            "  Architecture: {} -> {:?} -> {}",
305            config.input_size, config.hidden_sizes, config.output_size
306        );
307        println!("  Layers: {}", network.num_layers());
308        println!("  Parameters: {}", network.parameter_count());
309    }
310}
311
312/// Demonstrate forward pass through the network
313fn demonstrate_forward_pass() {
314    println!("\n--- Forward Pass ---");
315
316    let config = FeedForwardConfig {
317        input_size: 3,
318        hidden_sizes: vec![5, 3],
319        output_size: 2,
320        use_bias: true,
321    };
322    let network = FeedForwardNetwork::new(config, Some(43));
323
324    // Single input
325    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
326    let output = network.forward(&input);
327
328    println!("Single input forward pass:");
329    println!("  Input shape: {:?}", input.shape().dims());
330    println!("  Output shape: {:?}", output.shape().dims());
331    println!("  Output: {:?}", output.data());
332    println!("  Output requires grad: {}", output.requires_grad());
333
334    // Batch input
335    let batch_input = Tensor::from_slice(
336        &[
337            1.0, 2.0, 3.0, // Sample 1
338            4.0, 5.0, 6.0, // Sample 2
339            7.0, 8.0, 9.0, // Sample 3
340        ],
341        vec![3, 3],
342    )
343    .unwrap();
344    let batch_output = network.forward(&batch_input);
345
346    println!("Batch input forward pass:");
347    println!("  Input shape: {:?}", batch_input.shape().dims());
348    println!("  Output shape: {:?}", batch_output.shape().dims());
349    println!("  Output requires grad: {}", batch_output.requires_grad());
350
351    // Compare with no-grad version
352    let output_no_grad = network.forward_no_grad(&input);
353    println!("No-grad comparison:");
354    println!("  Same values: {}", output.data() == output_no_grad.data());
355    println!("  With grad requires grad: {}", output.requires_grad());
356    println!(
357        "  No grad requires grad: {}",
358        output_no_grad.requires_grad()
359    );
360}
361
362/// Demonstrate different configurable architectures
363fn demonstrate_configurable_architectures() {
364    println!("\n--- Configurable Architectures ---");
365
366    let architectures = vec![
367        ("Shallow", vec![8]),
368        ("Medium", vec![16, 8]),
369        ("Deep", vec![32, 16, 8, 4]),
370        ("Wide", vec![64, 32]),
371        ("Bottleneck", vec![16, 4, 16]),
372    ];
373
374    for (name, hidden_sizes) in architectures {
375        let config = FeedForwardConfig {
376            input_size: 10,
377            hidden_sizes,
378            output_size: 3,
379            use_bias: true,
380        };
381
382        let network = FeedForwardNetwork::new(config.clone(), Some(44));
383
384        // Test forward pass
385        let test_input = Tensor::randn(vec![5, 10], Some(45)); // Batch of 5
386        let output = network.forward_no_grad(&test_input);
387
388        println!("{} network:", name);
389        println!("  Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
390        println!("  Parameters: {}", network.parameter_count());
391        println!("  Test output shape: {:?}", output.shape().dims());
392        println!(
393            "  Output range: [{:.3}, {:.3}]",
394            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
395            output
396                .data()
397                .iter()
398                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
399        );
400    }
401}
402
403/// Demonstrate basic training workflow
404fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
405    println!("\n--- Training Workflow ---");
406
407    // Create a simple classification network
408    let config = FeedForwardConfig {
409        input_size: 2,
410        hidden_sizes: vec![4, 3],
411        output_size: 1,
412        use_bias: true,
413    };
414    let mut network = FeedForwardNetwork::new(config, Some(46));
415
416    println!("Training network: 2 -> [4, 3] -> 1");
417
418    // Create simple binary classification data: XOR problem
419    let x_data = Tensor::from_slice(
420        &[
421            0.0, 0.0, // -> 0
422            0.0, 1.0, // -> 1
423            1.0, 0.0, // -> 1
424            1.0, 1.0, // -> 0
425        ],
426        vec![4, 2],
427    )
428    .unwrap();
429
430    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();
431
432    println!("Training on XOR problem:");
433    println!("  Input shape: {:?}", x_data.shape().dims());
434    println!("  Target shape: {:?}", y_true.shape().dims());
435
436    // Create optimizer
437    let mut optimizer = Adam::with_learning_rate(0.1);
438    let params = network.parameters();
439    for param in &params {
440        optimizer.add_parameter(param);
441    }
442
443    // Training loop
444    let num_epochs = 50;
445    let mut losses = Vec::new();
446
447    for epoch in 0..num_epochs {
448        // Forward pass
449        let y_pred = network.forward(&x_data);
450
451        // Compute loss: MSE
452        let diff = y_pred.sub_tensor(&y_true);
453        let mut loss = diff.pow_scalar(2.0).mean();
454
455        // Backward pass
456        loss.backward(None);
457
458        // Optimizer step and zero grad
459        let mut params = network.parameters();
460        optimizer.step(&mut params);
461        optimizer.zero_grad(&mut params);
462
463        losses.push(loss.value());
464
465        // Print progress
466        if epoch % 10 == 0 || epoch == num_epochs - 1 {
467            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
468        }
469    }
470
471    // Test final model
472    let final_predictions = network.forward_no_grad(&x_data);
473    println!("\nFinal predictions vs targets:");
474    for i in 0..4 {
475        let pred = final_predictions.data()[i];
476        let target = y_true.data()[i];
477        let input_x = x_data.data()[i * 2];
478        let input_y = x_data.data()[i * 2 + 1];
479        println!(
480            "  [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
481            input_x,
482            input_y,
483            pred,
484            target,
485            (pred - target).abs()
486        );
487    }
488
489    Ok(())
490}
491
492/// Demonstrate comprehensive training with 100+ steps
493fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
494    println!("\n--- Comprehensive Training (100+ Steps) ---");
495
496    // Create a regression network
497    let config = FeedForwardConfig {
498        input_size: 3,
499        hidden_sizes: vec![8, 6, 4],
500        output_size: 2,
501        use_bias: true,
502    };
503    let mut network = FeedForwardNetwork::new(config, Some(47));
504
505    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
506    println!("Total parameters: {}", network.parameter_count());
507
508    // Create synthetic regression data
509    // Target function: [y1, y2] = [x1 + 2*x2 - x3, x1*x2 + x3]
510    let num_samples = 32;
511    let mut x_vec = Vec::new();
512    let mut y_vec = Vec::new();
513
514    for i in 0..num_samples {
515        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0; // [-1, 1]
516        let x2 = ((i * 2) as f32 / num_samples as f32) * 2.0 - 1.0;
517        let x3 = ((i * 3) as f32 / num_samples as f32) * 2.0 - 1.0;
518
519        let y1 = x1 + 2.0 * x2 - x3;
520        let y2 = x1 * x2 + x3;
521
522        x_vec.extend_from_slice(&[x1, x2, x3]);
523        y_vec.extend_from_slice(&[y1, y2]);
524    }
525
526    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
527    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();
528
529    println!("Training data:");
530    println!("  {} samples", num_samples);
531    println!("  Input shape: {:?}", x_data.shape().dims());
532    println!("  Target shape: {:?}", y_true.shape().dims());
533
534    // Create optimizer with learning rate scheduling
535    let mut optimizer = Adam::with_learning_rate(0.01);
536    let params = network.parameters();
537    for param in &params {
538        optimizer.add_parameter(param);
539    }
540
541    // Comprehensive training loop (150 epochs)
542    let num_epochs = 150;
543    let mut losses = Vec::new();
544    let mut best_loss = f32::INFINITY;
545    let mut patience_counter = 0;
546    let patience = 20;
547
548    println!("Starting comprehensive training...");
549
550    for epoch in 0..num_epochs {
551        // Forward pass
552        let y_pred = network.forward(&x_data);
553
554        // Compute loss: MSE
555        let diff = y_pred.sub_tensor(&y_true);
556        let mut loss = diff.pow_scalar(2.0).mean();
557
558        // Backward pass
559        loss.backward(None);
560
561        // Optimizer step and zero grad
562        let mut params = network.parameters();
563        optimizer.step(&mut params);
564        optimizer.zero_grad(&mut params);
565
566        let current_loss = loss.value();
567        losses.push(current_loss);
568
569        // Learning rate scheduling
570        if epoch > 0 && epoch % 30 == 0 {
571            let new_lr = optimizer.learning_rate() * 0.8;
572            optimizer.set_learning_rate(new_lr);
573            println!("  Reduced learning rate to {:.4}", new_lr);
574        }
575
576        // Early stopping logic
577        if current_loss < best_loss {
578            best_loss = current_loss;
579            patience_counter = 0;
580        } else {
581            patience_counter += 1;
582        }
583
584        // Print progress
585        if epoch % 25 == 0 || epoch == num_epochs - 1 {
586            println!(
587                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
588                epoch,
589                current_loss,
590                optimizer.learning_rate(),
591                best_loss
592            );
593        }
594
595        // Early stopping
596        if patience_counter >= patience && epoch > 50 {
597            println!("Early stopping at epoch {} (patience exceeded)", epoch);
598            break;
599        }
600    }
601
602    // Final evaluation
603    let final_predictions = network.forward_no_grad(&x_data);
604
605    // Compute final metrics
606    let final_loss = losses[losses.len() - 1];
607    let initial_loss = losses[0];
608    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;
609
610    println!("\nTraining completed!");
611    println!("  Initial loss: {:.6}", initial_loss);
612    println!("  Final loss: {:.6}", final_loss);
613    println!("  Best loss: {:.6}", best_loss);
614    println!("  Loss reduction: {:.1}%", loss_reduction);
615    println!("  Final learning rate: {:.4}", optimizer.learning_rate());
616
617    // Sample predictions analysis
618    println!("\nSample predictions (first 5):");
619    for i in 0..5.min(num_samples) {
620        let pred1 = final_predictions.data()[i * 2];
621        let pred2 = final_predictions.data()[i * 2 + 1];
622        let true1 = y_true.data()[i * 2];
623        let true2 = y_true.data()[i * 2 + 1];
624
625        println!(
626            "  Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
627            i + 1,
628            pred1,
629            pred2,
630            true1,
631            true2,
632            (pred1 - true1).abs(),
633            (pred2 - true2).abs()
634        );
635    }
636
637    Ok(())
638}
639
640/// Demonstrate network serialization
641fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
642    println!("\n--- Network Serialization ---");
643
644    // Create and train a network
645    let config = FeedForwardConfig {
646        input_size: 2,
647        hidden_sizes: vec![4, 2],
648        output_size: 1,
649        use_bias: true,
650    };
651    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));
652
653    // Quick training
654    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
655    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();
656
657    let mut optimizer = Adam::with_learning_rate(0.01);
658    let params = original_network.parameters();
659    for param in &params {
660        optimizer.add_parameter(param);
661    }
662
663    for _ in 0..20 {
664        let y_pred = original_network.forward(&x_data);
665        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
666        loss.backward(None);
667
668        let mut params = original_network.parameters();
669        optimizer.step(&mut params);
670        optimizer.zero_grad(&mut params);
671    }
672
673    // Test original network
674    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
675    let original_output = original_network.forward_no_grad(&test_input);
676
677    println!("Original network output: {:?}", original_output.data());
678
679    // Save network
680    original_network.save_json("temp_feedforward_network")?;
681
682    // Load network
683    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
684    let loaded_output = loaded_network.forward_no_grad(&test_input);
685
686    println!("Loaded network output: {:?}", loaded_output.data());
687
688    // Verify consistency
689    let match_check = original_output
690        .data()
691        .iter()
692        .zip(loaded_output.data().iter())
693        .all(|(a, b)| (a - b).abs() < 1e-6);
694
695    println!(
696        "Serialization verification: {}",
697        if match_check { "PASSED" } else { "FAILED" }
698    );
699
700    Ok(())
701}
702
/// Clean up temporary files.
///
/// Removes the per-layer weight/bias JSON files written by
/// `demonstrate_network_serialization`. Scans layer indices upward and
/// stops at the first index where neither file exists, so it works for
/// any network depth (the previous version assumed at most 10 layers).
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    for i in 0.. {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        let weight_exists = fs::metadata(&weight_file).is_ok();
        let bias_exists = fs::metadata(&bias_file).is_ok();

        // No files at this index means we are past the last saved layer.
        if !weight_exists && !bias_exists {
            break;
        }

        if weight_exists {
            fs::remove_file(&weight_file)?;
            println!("Removed: {}", weight_file);
        }
        if bias_exists {
            fs::remove_file(&bias_file)?;
            println!("Removed: {}", bias_file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}
726
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_activation() {
        // Negative values clamp to zero; non-negative values pass through.
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let activated = ReLU::forward(&input);

        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];
        assert_eq!(activated.data(), &expected);
    }

    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        // Two hidden layers plus the output layer.
        assert_eq!(network.num_layers(), 3);
        // Per layer: in * out weights plus out biases.
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
    }

    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let sample = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&sample);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_no_grad_forward() {
        let network = FeedForwardNetwork::new(FeedForwardConfig::default(), Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        // Two layers, each contributing a weight and a bias tensor.
        assert_eq!(network.parameters().len(), 4);
    }
}