//! Feed-forward network example.
//!
//! Builds a configurable multi-layer perceptron from linear layers and ReLU
//! activations, trains it with Adam on XOR and synthetic regression data,
//! and round-trips the learned weights through JSON serialization.

use std::fs;

use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

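// Reuse the `LinearLayer` type from the sibling `basic_linear_layer` example
// rather than duplicating its implementation here.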
#[allow(clippy::duplicate_mod)]
#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

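/// Rectified Linear Unit activation: `relu(x) = max(0, x)`, element-wise.
///
/// A minimal usage sketch (values illustrative):
/// ```ignore
/// let x = Tensor::from_slice(&[-1.0, 0.5], vec![1, 2]).unwrap();
/// let y = ReLU::forward(&x); // y.data() == [0.0, 0.5]
/// ```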
pub struct ReLU;

impl ReLU {
    /// Applies ReLU element-wise via the tensor's built-in `relu` op.
    pub fn forward(input: &Tensor) -> Tensor {
        input.relu()
    }

    /// Same as `forward`, but with gradient tracking suppressed for the call.
    #[allow(unused)]
    pub fn forward_no_grad(input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        Self::forward(input)
    }
}

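/// Architecture hyperparameters: one hidden linear layer per entry in
/// `hidden_sizes`, plus a final projection to `output_size`.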
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    pub input_size: usize,
    pub hidden_sizes: Vec<usize>,
    pub output_size: usize,
    #[allow(unused)]
    pub use_bias: bool,
}

impl Default for FeedForwardConfig {
    fn default() -> Self {
        Self {
            input_size: 4,
            hidden_sizes: vec![8, 4],
            output_size: 2,
            use_bias: true,
        }
    }
}

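/// A feed-forward (fully-connected) network built from `LinearLayer`s, with
/// ReLU after every hidden layer and a linear output layer.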
pub struct FeedForwardNetwork {
    layers: Vec<LinearLayer>,
    config: FeedForwardConfig,
}

impl FeedForwardNetwork {
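    /// Builds the layer stack described by `config`. When a seed is supplied,
    /// it is offset by one per layer so each layer draws distinct but
    /// reproducible initial weights.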
    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut current_seed = seed;

        // Hidden layers, each followed by ReLU in `forward`.
        for &hidden_size in &config.hidden_sizes {
            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
            current_size = hidden_size;
            current_seed = current_seed.map(|s| s + 1);
        }

        // Final projection to the output size, with no activation.
        layers.push(LinearLayer::new(
            current_size,
            config.output_size,
            current_seed,
        ));

        Self { layers, config }
    }

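    /// Forward pass: `linear -> ReLU` through every hidden layer, then the
    /// final linear layer with no activation (raw outputs, e.g. for
    /// regression).
    ///
    /// Shape sketch under the default config (illustrative):
    /// ```ignore
    /// let net = FeedForwardNetwork::new(FeedForwardConfig::default(), Some(0));
    /// let x = Tensor::randn(vec![1, 4], Some(1));
    /// let y = net.forward(&x); // shape [1, 2]
    /// ```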
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut x = input.clone();

        // All layers except the last are followed by ReLU.
        for layer in &self.layers[..self.layers.len() - 1] {
            x = layer.forward(&x);
            x = ReLU::forward(&x);
        }

        // The output layer stays linear.
        if let Some(final_layer) = self.layers.last() {
            x = final_layer.forward(&x);
        }

        x
    }

    /// Forward pass with gradient tracking disabled, for inference and
    /// evaluation.
    #[allow(unused)]
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

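    /// Collects mutable references to every weight and bias, in layer order,
    /// for registration with an optimizer.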
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for layer in &mut self.layers {
            params.extend(layer.parameters());
        }
        params
    }

    /// Number of linear layers (hidden layers plus the output layer).
    #[allow(unused)]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

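    /// Computes the parameter count analytically: each linear layer
    /// contributes `in * out` weights plus `out` biases. For the default
    /// config (4 -> [8, 4] -> 2): 4*8 + 8 + 8*4 + 4 + 4*2 + 2 = 86.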
    #[allow(unused)]
    pub fn parameter_count(&self) -> usize {
        let mut count = 0;
        let mut current_size = self.config.input_size;

        for &hidden_size in &self.config.hidden_sizes {
            count += current_size * hidden_size + hidden_size;
            current_size = hidden_size;
        }

        count += current_size * self.config.output_size + self.config.output_size;

        count
    }

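    /// Saves each layer's weight and bias to separate JSON files named
    /// `{path}_layer_{i}_weight.json` and `{path}_layer_{i}_bias.json`.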
    #[allow(unused)]
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        // Each layer's weight and bias go to their own JSON file.
        for (i, layer) in self.layers.iter().enumerate() {
            let layer_path = format!("{}_layer_{}", path, i);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            layer.weight.save_json(&weight_path)?;
            layer.bias.save_json(&bias_path)?;
        }

        println!(
            "Saved feed-forward network to {} ({} layers)",
            path,
            self.layers.len()
        );
        Ok(())
    }

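    /// Rebuilds a network from files written by `save_json`. The same
    /// `config` used at save time must be supplied, because the architecture
    /// itself is not serialized.
    ///
    /// ```ignore
    /// let net = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    /// ```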
    #[allow(unused)]
    pub fn load_json(
        path: &str,
        config: FeedForwardConfig,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut layer_idx = 0;

        // Hidden layers.
        for &hidden_size in &config.hidden_sizes {
            let layer_path = format!("{}_layer_{}", path, layer_idx);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

            layers.push(LinearLayer {
                weight,
                bias,
                input_size: current_size,
                output_size: hidden_size,
            });

            current_size = hidden_size;
            layer_idx += 1;
        }

        // Output layer.
        let layer_path = format!("{}_layer_{}", path, layer_idx);
        let weight_path = format!("{}_weight.json", layer_path);
        let bias_path = format!("{}_bias.json", layer_path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        layers.push(LinearLayer {
            weight,
            bias,
            input_size: current_size,
            output_size: config.output_size,
        });

        Ok(Self { layers, config })
    }
}

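/// Walks through network creation, forward passes, training, and JSON
/// serialization, then removes the temporary files it created.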
#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Feed-Forward Network Example ===\n");

    demonstrate_network_creation();
    demonstrate_forward_pass();
    demonstrate_configurable_architectures();
    demonstrate_training_workflow()?;
    demonstrate_comprehensive_training()?;
    demonstrate_network_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

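/// Shows how `FeedForwardConfig` determines layer and parameter counts.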
fn demonstrate_network_creation() {
    println!("--- Network Creation ---");

    // Default configuration: 4 -> [8, 4] -> 2.
    let config = FeedForwardConfig::default();
    let network = FeedForwardNetwork::new(config.clone(), Some(42));

    println!("Default network configuration:");
    println!("  Input size: {}", config.input_size);
    println!("  Hidden sizes: {:?}", config.hidden_sizes);
    println!("  Output size: {}", config.output_size);
    println!("  Number of layers: {}", network.num_layers());
    println!("  Total parameters: {}", network.parameter_count());

    // A few custom architectures of varying depth and width.
    let configs = [
        FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![4],
            output_size: 1,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 8,
            hidden_sizes: vec![16, 8, 4],
            output_size: 3,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 10,
            hidden_sizes: vec![20, 15, 10, 5],
            output_size: 2,
            use_bias: true,
        },
    ];

    for (i, config) in configs.iter().enumerate() {
        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
        println!("\nCustom network {}:", i + 1);
        println!(
            "  Architecture: {} -> {:?} -> {}",
            config.input_size, config.hidden_sizes, config.output_size
        );
        println!("  Layers: {}", network.num_layers());
        println!("  Parameters: {}", network.parameter_count());
    }
}

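/// Runs single-sample and batched forward passes, and contrasts gradient
/// tracking with the no-grad variant.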
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![5, 3],
        output_size: 2,
        use_bias: true,
    };
    let network = FeedForwardNetwork::new(config, Some(43));

    // Single sample: shape [1, 3].
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = network.forward(&input);

    println!("Single input forward pass:");
    println!("  Input shape: {:?}", input.shape().dims());
    println!("  Output shape: {:?}", output.shape().dims());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    // Batch of three samples: shape [3, 3].
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, // sample 1
            4.0, 5.0, 6.0, // sample 2
            7.0, 8.0, 9.0, // sample 3
        ],
        vec![3, 3],
    )
    .unwrap();
    let batch_output = network.forward(&batch_input);

    println!("Batch input forward pass:");
    println!("  Input shape: {:?}", batch_input.shape().dims());
    println!("  Output shape: {:?}", batch_output.shape().dims());
    println!("  Output requires grad: {}", batch_output.requires_grad());

    // The no-grad path must produce identical values without tracking.
    let output_no_grad = network.forward_no_grad(&input);
    println!("No-grad comparison:");
    println!("  Same values: {}", output.data() == output_no_grad.data());
    println!("  With grad requires grad: {}", output.requires_grad());
    println!(
        "  No grad requires grad: {}",
        output_no_grad.requires_grad()
    );
}

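/// Instantiates several architectures from the same config type and probes
/// each with a random batch.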
fn demonstrate_configurable_architectures() {
    println!("\n--- Configurable Architectures ---");

    let architectures = vec![
        ("Shallow", vec![8]),
        ("Medium", vec![16, 8]),
        ("Deep", vec![32, 16, 8, 4]),
        ("Wide", vec![64, 32]),
        ("Bottleneck", vec![16, 4, 16]),
    ];

    for (name, hidden_sizes) in architectures {
        let config = FeedForwardConfig {
            input_size: 10,
            hidden_sizes,
            output_size: 3,
            use_bias: true,
        };

        let network = FeedForwardNetwork::new(config.clone(), Some(44));

        // Probe each architecture with a random batch of five samples.
        let test_input = Tensor::randn(vec![5, 10], Some(45));
        let output = network.forward_no_grad(&test_input);

        println!("{} network:", name);
        println!("  Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
        println!("  Parameters: {}", network.parameter_count());
        println!("  Test output shape: {:?}", output.shape().dims());
        println!(
            "  Output range: [{:.3}, {:.3}]",
            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
            output
                .data()
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        );
    }
}

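/// Trains the network on the XOR problem with Adam and a mean-squared-error
/// loss, following the usual loop: forward, loss, backward, step, zero grads.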
fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Workflow ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 3],
        output_size: 1,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(46));

    println!("Training network: 2 -> [4, 3] -> 1");

    // XOR truth table: four input pairs and their targets.
    let x_data = Tensor::from_slice(
        &[
            0.0, 0.0, // [0, 0] -> 0
            0.0, 1.0, // [0, 1] -> 1
            1.0, 0.0, // [1, 0] -> 1
            1.0, 1.0, // [1, 1] -> 0
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();

    println!("Training on XOR problem:");
    println!("  Input shape: {:?}", x_data.shape().dims());
    println!("  Target shape: {:?}", y_true.shape().dims());

    // Register every parameter with the optimizer once, up front.
    let mut optimizer = Adam::with_learning_rate(0.1);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

        // Mean squared error loss.
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);
    println!("\nFinal predictions vs targets:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let target = y_true.data()[i];
        let input_x = x_data.data()[i * 2];
        let input_y = x_data.data()[i * 2 + 1];
        println!(
            "  [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
            input_x,
            input_y,
            pred,
            target,
            (pred - target).abs()
        );
    }

    Ok(())
}

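/// A longer regression run (up to 150 epochs) with step-based learning-rate
/// decay and patience-based early stopping.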
fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Comprehensive Training (100+ Steps) ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![8, 6, 4],
        output_size: 2,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(47));

    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
    println!("Total parameters: {}", network.parameter_count());

    // Synthetic regression data: deterministic inputs in [-1, 1].
    let num_samples = 32;
    let mut x_vec = Vec::new();
    let mut y_vec = Vec::new();

    for i in 0..num_samples {
        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0;
        let x2 = ((i * 2) as f32 / num_samples as f32) * 2.0 - 1.0;
        let x3 = ((i * 3) as f32 / num_samples as f32) * 2.0 - 1.0;

        // Targets: one linear and one mildly nonlinear combination.
        let y1 = x1 + 2.0 * x2 - x3;
        let y2 = x1 * x2 + x3;

        x_vec.extend_from_slice(&[x1, x2, x3]);
        y_vec.extend_from_slice(&[y1, y2]);
    }

    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();

    println!("Training data:");
    println!("  {} samples", num_samples);
    println!("  Input shape: {:?}", x_data.shape().dims());
    println!("  Target shape: {:?}", y_true.shape().dims());

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 150;
    let mut losses = Vec::new();
    let mut best_loss = f32::INFINITY;
    let mut patience_counter = 0;
    let patience = 20;

    println!("Starting comprehensive training...");

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

        // Mean squared error loss.
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        let current_loss = loss.value();
        losses.push(current_loss);

        // Step decay: shrink the learning rate by 20% every 30 epochs.
        if epoch > 0 && epoch % 30 == 0 {
            let new_lr = optimizer.learning_rate() * 0.8;
            optimizer.set_learning_rate(new_lr);
            println!("  Reduced learning rate to {:.4}", new_lr);
        }

        // Track the best loss seen so far for early stopping.
        if current_loss < best_loss {
            best_loss = current_loss;
            patience_counter = 0;
        } else {
            patience_counter += 1;
        }

        if epoch % 25 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
                epoch,
                current_loss,
                optimizer.learning_rate(),
                best_loss
            );
        }

        if patience_counter >= patience && epoch > 50 {
            println!("Early stopping at epoch {} (patience exceeded)", epoch);
            break;
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);

    let final_loss = losses[losses.len() - 1];
    let initial_loss = losses[0];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining completed!");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Best loss: {:.6}", best_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);
    println!("  Final learning rate: {:.4}", optimizer.learning_rate());

    println!("\nSample predictions (first 5):");
    for i in 0..5.min(num_samples) {
        let pred1 = final_predictions.data()[i * 2];
        let pred2 = final_predictions.data()[i * 2 + 1];
        let true1 = y_true.data()[i * 2];
        let true2 = y_true.data()[i * 2 + 1];

        println!(
            "  Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
            i + 1,
            pred1,
            pred2,
            true1,
            true2,
            (pred1 - true1).abs(),
            (pred2 - true2).abs()
        );
    }

    Ok(())
}

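/// Round-trips a briefly trained network through JSON and verifies that the
/// loaded copy reproduces the original outputs to within 1e-6.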
fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Network Serialization ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 2],
        output_size: 1,
        use_bias: true,
    };
    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));

    // Train briefly so the saved weights differ from initialization.
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    for _ in 0..20 {
        let y_pred = original_network.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    // Record a reference output before saving.
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_network.forward_no_grad(&test_input);

    println!("Original network output: {:?}", original_output.data());

    original_network.save_json("temp_feedforward_network")?;

    // Reload with the same config and compare outputs element-wise.
    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    let loaded_output = loaded_network.forward_no_grad(&test_input);

    println!("Loaded network output: {:?}", loaded_output.data());

    let match_check = original_output
        .data()
        .iter()
        .zip(loaded_output.data().iter())
        .all(|(a, b)| (a - b).abs() < 1e-6);

    println!(
        "Serialization verification: {}",
        if match_check { "PASSED" } else { "FAILED" }
    );

    Ok(())
}

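/// Removes the temporary JSON files created by the serialization demo.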
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    // Probe a generous range of layer indices; missing files are skipped.
    for i in 0..10 {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        if fs::metadata(&weight_file).is_ok() {
            fs::remove_file(&weight_file)?;
            println!("Removed: {}", weight_file);
        }
        if fs::metadata(&bias_file).is_ok() {
            fs::remove_file(&bias_file)?;
            println!("Removed: {}", bias_file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_activation() {
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let output = ReLU::forward(&input);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];

        assert_eq!(output.data(), &expected);
    }

    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        // Two hidden layers plus the output layer.
        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
    }

    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_no_grad_forward() {
        let config = FeedForwardConfig::default();
        let network = FeedForwardNetwork::new(config, Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        // Two layers, each contributing a weight and a bias tensor.
        let params = network.parameters();
        assert_eq!(params.len(), 4);
    }
}