use std::fs;

use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

/// Rectified Linear Unit (ReLU) activation: max(0, x), applied element-wise.
pub struct ReLU;

impl ReLU {
    /// Applies ReLU with gradient tracking left as-is.
    pub fn forward(input: &Tensor) -> Tensor {
        input.relu()
    }

    /// Applies ReLU with gradient tracking disabled for the duration of the call.
    pub fn forward_no_grad(input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        Self::forward(input)
    }
}
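
// Note on the `forward_no_grad` helpers used throughout this file:
// `NoGradTrack::new()` is used here as an RAII guard, so it is bound to a
// named placeholder (`_guard`) to keep it alive for the whole scope.
// Binding it to plain `_` would drop the guard immediately and leave
// gradient tracking enabled for the rest of the call.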

/// Fully connected layer computing `y = x * W + b`.
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
    /// Creates a layer with weights scaled by 1/sqrt(input_size)
    /// (Xavier-style initialization) and a zero-initialized bias.
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }
}
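
// A minimal usage sketch for `LinearLayer` (not called from `main`); it only
// uses Tensor operations already exercised elsewhere in this file. A [1, 3]
// input through a 3 -> 2 layer yields a [1, 2] output.
#[allow(dead_code)]
fn linear_layer_usage_sketch() {
    let layer = LinearLayer::new(3, 2, Some(0));
    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);
    assert_eq!(output.shape().dims(), vec![1, 2]);
}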

/// Configuration describing a feed-forward architecture.
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    pub input_size: usize,
    pub hidden_sizes: Vec<usize>,
    pub output_size: usize,
    // Carried in the config for completeness; `LinearLayer` always creates
    // a bias, so this flag is not consulted in this example.
    pub use_bias: bool,
}

impl Default for FeedForwardConfig {
    fn default() -> Self {
        Self {
            input_size: 4,
            hidden_sizes: vec![8, 4],
            output_size: 2,
            use_bias: true,
        }
    }
}

/// Multi-layer perceptron: linear layers with ReLU activations between them.
pub struct FeedForwardNetwork {
    layers: Vec<LinearLayer>,
    config: FeedForwardConfig,
}

impl FeedForwardNetwork {
    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut current_seed = seed;

        // Hidden layers, each seeded differently so their weights differ.
        for &hidden_size in &config.hidden_sizes {
            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
            current_size = hidden_size;
            current_seed = current_seed.map(|s| s + 1);
        }

        // Output layer.
        layers.push(LinearLayer::new(
            current_size,
            config.output_size,
            current_seed,
        ));

        Self { layers, config }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut x = input.clone();

        // ReLU after every layer except the last.
        for layer in &self.layers[..self.layers.len() - 1] {
            x = layer.forward(&x);
            x = ReLU::forward(&x);
        }

        if let Some(final_layer) = self.layers.last() {
            x = final_layer.forward(&x);
        }

        x
    }

    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for layer in &mut self.layers {
            params.extend(layer.parameters());
        }
        params
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Total parameter count: weights (in * out) plus biases (out) per layer.
    pub fn parameter_count(&self) -> usize {
        let mut count = 0;
        let mut current_size = self.config.input_size;

        for &hidden_size in &self.config.hidden_sizes {
            count += current_size * hidden_size + hidden_size;
            current_size = hidden_size;
        }

        count += current_size * self.config.output_size + self.config.output_size;

        count
    }
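
    // Worked example with the default config (4 -> [8, 4] -> 2):
    // (4*8 + 8) + (8*4 + 4) + (4*2 + 2) = 40 + 36 + 10 = 86 parameters.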

    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        // Each layer's weight and bias tensors are written as separate JSON files.
        for (i, layer) in self.layers.iter().enumerate() {
            let layer_path = format!("{}_layer_{}", path, i);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            layer.weight.save_json(&weight_path)?;
            layer.bias.save_json(&bias_path)?;
        }

        println!(
            "Saved feed-forward network to {} ({} layers)",
            path,
            self.layers.len()
        );
        Ok(())
    }

    /// Rebuilds a network from tensors saved by `save_json`. The supplied
    /// `config` must match the architecture that was saved.
    pub fn load_json(
        path: &str,
        config: FeedForwardConfig,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut layer_idx = 0;

        // Hidden layers.
        for &hidden_size in &config.hidden_sizes {
            let layer_path = format!("{}_layer_{}", path, layer_idx);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

            layers.push(LinearLayer {
                weight,
                bias,
                input_size: current_size,
                output_size: hidden_size,
            });

            current_size = hidden_size;
            layer_idx += 1;
        }

        // Output layer.
        let layer_path = format!("{}_layer_{}", path, layer_idx);
        let weight_path = format!("{}_weight.json", layer_path);
        let bias_path = format!("{}_bias.json", layer_path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        layers.push(LinearLayer {
            weight,
            bias,
            input_size: current_size,
            output_size: config.output_size,
        });

        Ok(Self { layers, config })
    }
}
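
// On-disk layout produced by `save_json` above (and expected by `load_json`):
// one pair of JSON files per layer. For a base path of "model" and two layers:
// model_layer_0_weight.json, model_layer_0_bias.json,
// model_layer_1_weight.json, model_layer_1_bias.json.
// Only the tensors are persisted; the architecture must be supplied again via
// `FeedForwardConfig` when loading.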

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Feed-Forward Network Example ===\n");

    demonstrate_network_creation();
    demonstrate_forward_pass();
    demonstrate_configurable_architectures();
    demonstrate_training_workflow()?;
    demonstrate_comprehensive_training()?;
    demonstrate_network_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

fn demonstrate_network_creation() {
    println!("--- Network Creation ---");

    let config = FeedForwardConfig::default();
    let network = FeedForwardNetwork::new(config.clone(), Some(42));

    println!("Default network configuration:");
    println!(" Input size: {}", config.input_size);
    println!(" Hidden sizes: {:?}", config.hidden_sizes);
    println!(" Output size: {}", config.output_size);
    println!(" Number of layers: {}", network.num_layers());
    println!(" Total parameters: {}", network.parameter_count());

    let configs = [
        FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![4],
            output_size: 1,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 8,
            hidden_sizes: vec![16, 8, 4],
            output_size: 3,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 10,
            hidden_sizes: vec![20, 15, 10, 5],
            output_size: 2,
            use_bias: true,
        },
    ];

    for (i, config) in configs.iter().enumerate() {
        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
        println!("\nCustom network {}:", i + 1);
        println!(
            " Architecture: {} -> {:?} -> {}",
            config.input_size, config.hidden_sizes, config.output_size
        );
        println!(" Layers: {}", network.num_layers());
        println!(" Parameters: {}", network.parameter_count());
    }
}

fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![5, 3],
        output_size: 2,
        use_bias: true,
    };
    let network = FeedForwardNetwork::new(config, Some(43));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = network.forward(&input);

    println!("Single input forward pass:");
    println!(" Input shape: {:?}", input.shape().dims());
    println!(" Output shape: {:?}", output.shape().dims());
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    // A batch of three rows flows through the same network unchanged.
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, // row 1
            4.0, 5.0, 6.0, // row 2
            7.0, 8.0, 9.0, // row 3
        ],
        vec![3, 3],
    )
    .unwrap();
    let batch_output = network.forward(&batch_input);

    println!("Batch input forward pass:");
    println!(" Input shape: {:?}", batch_input.shape().dims());
    println!(" Output shape: {:?}", batch_output.shape().dims());
    println!(" Output requires grad: {}", batch_output.requires_grad());

    let output_no_grad = network.forward_no_grad(&input);
    println!("No-grad comparison:");
    println!(" Same values: {}", output.data() == output_no_grad.data());
    println!(" With grad requires grad: {}", output.requires_grad());
    println!(
        " No grad requires grad: {}",
        output_no_grad.requires_grad()
    );
}

fn demonstrate_configurable_architectures() {
    println!("\n--- Configurable Architectures ---");

    let architectures = vec![
        ("Shallow", vec![8]),
        ("Medium", vec![16, 8]),
        ("Deep", vec![32, 16, 8, 4]),
        ("Wide", vec![64, 32]),
        ("Bottleneck", vec![16, 4, 16]),
    ];

    for (name, hidden_sizes) in architectures {
        let config = FeedForwardConfig {
            input_size: 10,
            hidden_sizes,
            output_size: 3,
            use_bias: true,
        };

        let network = FeedForwardNetwork::new(config.clone(), Some(44));

        let test_input = Tensor::randn(vec![5, 10], Some(45));
        let output = network.forward_no_grad(&test_input);

        println!("{} network:", name);
        println!(" Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
        println!(" Parameters: {}", network.parameter_count());
        println!(" Test output shape: {:?}", output.shape().dims());
        println!(
            " Output range: [{:.3}, {:.3}]",
            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
            output
                .data()
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        );
    }
}

fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Workflow ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 3],
        output_size: 1,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(46));

    println!("Training network: 2 -> [4, 3] -> 1");

    // The four XOR input pairs and their targets.
    let x_data = Tensor::from_slice(
        &[
            0.0, 0.0, // sample 1
            0.0, 1.0, // sample 2
            1.0, 0.0, // sample 3
            1.0, 1.0, // sample 4
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();

    println!("Training on XOR problem:");
    println!(" Input shape: {:?}", x_data.shape().dims());
    println!(" Target shape: {:?}", y_true.shape().dims());

    let mut optimizer = Adam::with_learning_rate(0.1);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

        // Mean squared error: mean((pred - target)^2).
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);
    println!("\nFinal predictions vs targets:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let target = y_true.data()[i];
        let input_x = x_data.data()[i * 2];
        let input_y = x_data.data()[i * 2 + 1];
        println!(
            " [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
            input_x,
            input_y,
            pred,
            target,
            (pred - target).abs()
        );
    }

    Ok(())
}

fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Comprehensive Training (100+ Steps) ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![8, 6, 4],
        output_size: 2,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(47));

    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
    println!("Total parameters: {}", network.parameter_count());

    // Synthetic regression data: two targets derived from three inputs.
    let num_samples = 32;
    let mut x_vec = Vec::new();
    let mut y_vec = Vec::new();

    for i in 0..num_samples {
        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0;
        let x2 = ((i * 2) as f32 / num_samples as f32) * 2.0 - 1.0;
        let x3 = ((i * 3) as f32 / num_samples as f32) * 2.0 - 1.0;

        let y1 = x1 + 2.0 * x2 - x3;
        let y2 = x1 * x2 + x3;

        x_vec.extend_from_slice(&[x1, x2, x3]);
        y_vec.extend_from_slice(&[y1, y2]);
    }

    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();

    println!("Training data:");
    println!(" {} samples", num_samples);
    println!(" Input shape: {:?}", x_data.shape().dims());
    println!(" Target shape: {:?}", y_true.shape().dims());

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 150;
    let mut losses = Vec::new();
    let mut best_loss = f32::INFINITY;
    let mut patience_counter = 0;
    let patience = 20;

    println!("Starting comprehensive training...");

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        let current_loss = loss.value();
        losses.push(current_loss);

        // Step-decay schedule: multiply the learning rate by 0.8 every 30 epochs.
        if epoch > 0 && epoch % 30 == 0 {
            let new_lr = optimizer.learning_rate() * 0.8;
            optimizer.set_learning_rate(new_lr);
            println!(" Reduced learning rate to {:.4}", new_lr);
        }

        // Early-stopping bookkeeping: track the best loss seen so far.
        if current_loss < best_loss {
            best_loss = current_loss;
            patience_counter = 0;
        } else {
            patience_counter += 1;
        }

        if epoch % 25 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
                epoch,
                current_loss,
                optimizer.learning_rate(),
                best_loss
            );
        }

        if patience_counter >= patience && epoch > 50 {
            println!("Early stopping at epoch {} (patience exceeded)", epoch);
            break;
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);

    let final_loss = losses[losses.len() - 1];
    let initial_loss = losses[0];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining completed!");
    println!(" Initial loss: {:.6}", initial_loss);
    println!(" Final loss: {:.6}", final_loss);
    println!(" Best loss: {:.6}", best_loss);
    println!(" Loss reduction: {:.1}%", loss_reduction);
    println!(" Final learning rate: {:.4}", optimizer.learning_rate());

    println!("\nSample predictions (first 5):");
    for i in 0..5.min(num_samples) {
        let pred1 = final_predictions.data()[i * 2];
        let pred2 = final_predictions.data()[i * 2 + 1];
        let true1 = y_true.data()[i * 2];
        let true2 = y_true.data()[i * 2 + 1];

        println!(
            " Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
            i + 1,
            pred1,
            pred2,
            true1,
            true2,
            (pred1 - true1).abs(),
            (pred2 - true2).abs()
        );
    }

    Ok(())
}

fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Network Serialization ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 2],
        output_size: 1,
        use_bias: true,
    };
    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));

    // Briefly train so the saved weights differ from the random initialization.
    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    for _ in 0..20 {
        let y_pred = original_network.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_network.forward_no_grad(&test_input);

    println!("Original network output: {:?}", original_output.data());

    original_network.save_json("temp_feedforward_network")?;

    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    let loaded_output = loaded_network.forward_no_grad(&test_input);

    println!("Loaded network output: {:?}", loaded_output.data());

    // The round trip should reproduce the outputs to within float tolerance.
    let match_check = original_output
        .data()
        .iter()
        .zip(loaded_output.data().iter())
        .all(|(a, b)| (a - b).abs() < 1e-6);

    println!(
        "Serialization verification: {}",
        if match_check { "PASSED" } else { "FAILED" }
    );

    Ok(())
}

fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    // Remove any layer files written by the serialization demo; ten indices
    // comfortably covers the architectures used above.
    for i in 0..10 {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        if fs::metadata(&weight_file).is_ok() {
            fs::remove_file(&weight_file)?;
            println!("Removed: {}", weight_file);
        }
        if fs::metadata(&bias_file).is_ok() {
            fs::remove_file(&bias_file)?;
            println!("Removed: {}", bias_file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_activation() {
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let output = ReLU::forward(&input);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];

        assert_eq!(output.data(), &expected);
    }

    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        // Two hidden layers plus the output layer.
        assert_eq!(network.num_layers(), 3);
        // Weights plus biases: (3*5 + 5) + (5*4 + 4) + (4*2 + 2).
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
    }

    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_no_grad_forward() {
        let config = FeedForwardConfig::default();
        let network = FeedForwardNetwork::new(config, Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        // One hidden layer and one output layer, two tensors each.
        let params = network.parameters();
        assert_eq!(params.len(), 4);
    }
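
    // An extra, hedged consistency check: the layer's forward pass should
    // equal the matmul-plus-bias composition it is defined as. It relies only
    // on Tensor operations already exercised elsewhere in this file.
    #[test]
    fn test_linear_forward_matches_composition() {
        let layer = LinearLayer::new(2, 3, Some(48));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();

        let via_layer = layer.forward(&input);
        let manual = input.matmul(&layer.weight).add_tensor(&layer.bias);

        assert_eq!(via_layer.data(), manual.data());
    }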
}