use std::fs;

use train_station::{
    optimizers::{Adam, Optimizer},
    serialization::StructSerializable,
    NoGradTrack, Tensor,
};

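/// ReLU activation function, exposed as a stateless unit struct.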
pub struct ReLU;

impl ReLU {
    pub fn forward(input: &Tensor) -> Tensor {
        input.relu()
    }

    pub fn forward_no_grad(input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        Self::forward(input)
    }
}

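/// A fully connected layer: `output = input.matmul(weight) + bias`.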
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
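        // Scale initial weights by 1/sqrt(fan_in) so early activations keep a stable magnitude.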
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }
}

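/// Architecture description: input width, hidden layer widths, and output width.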
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    pub input_size: usize,
    pub hidden_sizes: Vec<usize>,
    pub output_size: usize,
    pub use_bias: bool,
}

impl Default for FeedForwardConfig {
    fn default() -> Self {
        Self {
            input_size: 4,
            hidden_sizes: vec![8, 4],
            output_size: 2,
            use_bias: true,
        }
    }
}

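/// A multi-layer perceptron: linear layers with ReLU between the hidden layers.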
pub struct FeedForwardNetwork {
    layers: Vec<LinearLayer>,
    config: FeedForwardConfig,
}

impl FeedForwardNetwork {
    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut current_seed = seed;

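        // Advance the seed per layer so layers do not start with identical weights.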
        for &hidden_size in &config.hidden_sizes {
            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
            current_size = hidden_size;
            current_seed = current_seed.map(|s| s + 1);
        }

        layers.push(LinearLayer::new(
            current_size,
            config.output_size,
            current_seed,
        ));

        Self { layers, config }
    }

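    /// Forward pass: linear + ReLU through the hidden layers, then a plain linear output layer.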
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let mut x = input.clone();

        for layer in &self.layers[..self.layers.len() - 1] {
            x = layer.forward(&x);
            x = ReLU::forward(&x);
        }

        if let Some(final_layer) = self.layers.last() {
            x = final_layer.forward(&x);
        }

        x
    }

    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for layer in &mut self.layers {
            params.extend(layer.parameters());
        }
        params
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

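    /// Expected parameter count from the config: `in * out` weights plus `out` biases per layer.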
    pub fn parameter_count(&self) -> usize {
        let mut count = 0;
        let mut current_size = self.config.input_size;

        for &hidden_size in &self.config.hidden_sizes {
            count += current_size * hidden_size + hidden_size;
            current_size = hidden_size;
        }

        count += current_size * self.config.output_size + self.config.output_size;

        count
    }

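    /// Saves each layer's weight and bias as separate JSON files, using `path` as a filename prefix.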
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        for (i, layer) in self.layers.iter().enumerate() {
            let layer_path = format!("{}_layer_{}", path, i);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            layer.weight.save_json(&weight_path)?;
            layer.bias.save_json(&bias_path)?;
        }

        println!(
            "Saved feed-forward network to {} ({} layers)",
            path,
            self.layers.len()
        );
        Ok(())
    }

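    /// Rebuilds a network from the JSON files written by `save_json`; `config` must describe the saved architecture.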
    pub fn load_json(
        path: &str,
        config: FeedForwardConfig,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let mut layers = Vec::new();
        let mut current_size = config.input_size;
        let mut layer_idx = 0;

        for &hidden_size in &config.hidden_sizes {
            let layer_path = format!("{}_layer_{}", path, layer_idx);
            let weight_path = format!("{}_weight.json", layer_path);
            let bias_path = format!("{}_bias.json", layer_path);

            let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
            let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

            layers.push(LinearLayer {
                weight,
                bias,
                input_size: current_size,
                output_size: hidden_size,
            });

            current_size = hidden_size;
            layer_idx += 1;
        }

        let layer_path = format!("{}_layer_{}", path, layer_idx);
        let weight_path = format!("{}_weight.json", layer_path);
        let bias_path = format!("{}_bias.json", layer_path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        layers.push(LinearLayer {
            weight,
            bias,
            input_size: current_size,
            output_size: config.output_size,
        });

        Ok(Self { layers, config })
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Feed-Forward Network Example ===\n");

    demonstrate_network_creation();
    demonstrate_forward_pass();
    demonstrate_configurable_architectures();
    demonstrate_training_workflow()?;
    demonstrate_comprehensive_training()?;
    demonstrate_network_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

fn demonstrate_network_creation() {
    println!("--- Network Creation ---");

    let config = FeedForwardConfig::default();
    let network = FeedForwardNetwork::new(config.clone(), Some(42));

    println!("Default network configuration:");
    println!(" Input size: {}", config.input_size);
    println!(" Hidden sizes: {:?}", config.hidden_sizes);
    println!(" Output size: {}", config.output_size);
    println!(" Number of layers: {}", network.num_layers());
    println!(" Total parameters: {}", network.parameter_count());

    let configs = [
        FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![4],
            output_size: 1,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 8,
            hidden_sizes: vec![16, 8, 4],
            output_size: 3,
            use_bias: true,
        },
        FeedForwardConfig {
            input_size: 10,
            hidden_sizes: vec![20, 15, 10, 5],
            output_size: 2,
            use_bias: true,
        },
    ];

    for (i, config) in configs.iter().enumerate() {
        let network = FeedForwardNetwork::new(config.clone(), Some(42 + i as u64));
        println!("\nCustom network {}:", i + 1);
        println!(
            " Architecture: {} -> {:?} -> {}",
            config.input_size, config.hidden_sizes, config.output_size
        );
        println!(" Layers: {}", network.num_layers());
        println!(" Parameters: {}", network.parameter_count());
    }
}

fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![5, 3],
        output_size: 2,
        use_bias: true,
    };
    let network = FeedForwardNetwork::new(config, Some(43));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = network.forward(&input);

    println!("Single input forward pass:");
    println!(" Input shape: {:?}", input.shape().dims);
    println!(" Output shape: {:?}", output.shape().dims);
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0,
            4.0, 5.0, 6.0,
            7.0, 8.0, 9.0,
        ],
        vec![3, 3],
    )
    .unwrap();
    let batch_output = network.forward(&batch_input);

    println!("Batch input forward pass:");
    println!(" Input shape: {:?}", batch_input.shape().dims);
    println!(" Output shape: {:?}", batch_output.shape().dims);
    println!(" Output requires grad: {}", batch_output.requires_grad());

    let output_no_grad = network.forward_no_grad(&input);
    println!("No-grad comparison:");
    println!(" Same values: {}", output.data() == output_no_grad.data());
    println!(" With grad requires grad: {}", output.requires_grad());
    println!(
        " No grad requires grad: {}",
        output_no_grad.requires_grad()
    );
}

fn demonstrate_configurable_architectures() {
    println!("\n--- Configurable Architectures ---");

    let architectures = vec![
        ("Shallow", vec![8]),
        ("Medium", vec![16, 8]),
        ("Deep", vec![32, 16, 8, 4]),
        ("Wide", vec![64, 32]),
        ("Bottleneck", vec![16, 4, 16]),
    ];

    for (name, hidden_sizes) in architectures {
        let config = FeedForwardConfig {
            input_size: 10,
            hidden_sizes,
            output_size: 3,
            use_bias: true,
        };

        let network = FeedForwardNetwork::new(config.clone(), Some(44));

        let test_input = Tensor::randn(vec![5, 10], Some(45));
        let output = network.forward_no_grad(&test_input);

        println!("{} network:", name);
        println!(" Architecture: 10 -> {:?} -> 3", config.hidden_sizes);
        println!(" Parameters: {}", network.parameter_count());
        println!(" Test output shape: {:?}", output.shape().dims);
        println!(
            " Output range: [{:.3}, {:.3}]",
            output.data().iter().fold(f32::INFINITY, |a, &b| a.min(b)),
            output
                .data()
                .iter()
                .fold(f32::NEG_INFINITY, |a, &b| a.max(b))
        );
    }
}

fn demonstrate_training_workflow() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Workflow ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 3],
        output_size: 1,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(46));

    println!("Training network: 2 -> [4, 3] -> 1");

    let x_data = Tensor::from_slice(
        &[
            0.0, 0.0,
            0.0, 1.0,
            1.0, 0.0,
            1.0, 1.0,
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[0.0, 1.0, 1.0, 0.0], vec![4, 1]).unwrap();

    println!("Training on XOR problem:");
    println!(" Input shape: {:?}", x_data.shape().dims);
    println!(" Target shape: {:?}", y_true.shape().dims);

    let mut optimizer = Adam::with_learning_rate(0.1);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 50;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

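        // Mean squared error: mean((prediction - target)^2).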
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 10 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:2}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);
    println!("\nFinal predictions vs targets:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let target = y_true.data()[i];
        let input_x = x_data.data()[i * 2];
        let input_y = x_data.data()[i * 2 + 1];
        println!(
            " [{:.0}, {:.0}] -> pred: {:.3}, target: {:.0}, error: {:.3}",
            input_x,
            input_y,
            pred,
            target,
            (pred - target).abs()
        );
    }

    Ok(())
}

fn demonstrate_comprehensive_training() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Comprehensive Training (100+ Steps) ---");

    let config = FeedForwardConfig {
        input_size: 3,
        hidden_sizes: vec![8, 6, 4],
        output_size: 2,
        use_bias: true,
    };
    let mut network = FeedForwardNetwork::new(config, Some(47));

    println!("Network architecture: 3 -> [8, 6, 4] -> 2");
    println!("Total parameters: {}", network.parameter_count());

    let num_samples = 32;
    let mut x_vec = Vec::new();
    let mut y_vec = Vec::new();

    for i in 0..num_samples {
        // Synthetic regression data: deterministic inputs with linear and product targets.
        let x1 = (i as f32 / num_samples as f32) * 2.0 - 1.0;
        let x2 = ((i * 2) as f32 / num_samples as f32) * 2.0 - 1.0;
        let x3 = ((i * 3) as f32 / num_samples as f32) * 2.0 - 1.0;

        let y1 = x1 + 2.0 * x2 - x3;
        let y2 = x1 * x2 + x3;

        x_vec.extend_from_slice(&[x1, x2, x3]);
        y_vec.extend_from_slice(&[y1, y2]);
    }

    let x_data = Tensor::from_slice(&x_vec, vec![num_samples, 3]).unwrap();
    let y_true = Tensor::from_slice(&y_vec, vec![num_samples, 2]).unwrap();

    println!("Training data:");
    println!(" {} samples", num_samples);
    println!(" Input shape: {:?}", x_data.shape().dims);
    println!(" Target shape: {:?}", y_true.shape().dims);

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    let num_epochs = 150;
    let mut losses = Vec::new();
    let mut best_loss = f32::INFINITY;
    let mut patience_counter = 0;
    let patience = 20;

    println!("Starting comprehensive training...");

    for epoch in 0..num_epochs {
        let y_pred = network.forward(&x_data);

        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        let current_loss = loss.value();
        losses.push(current_loss);

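        // Step decay: shrink the learning rate by 20% every 30 epochs.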
        if epoch > 0 && epoch % 30 == 0 {
            let new_lr = optimizer.learning_rate() * 0.8;
            optimizer.set_learning_rate(new_lr);
            println!(" Reduced learning rate to {:.4}", new_lr);
        }

        if current_loss < best_loss {
            best_loss = current_loss;
            patience_counter = 0;
        } else {
            patience_counter += 1;
        }

        if epoch % 25 == 0 || epoch == num_epochs - 1 {
            println!(
                "Epoch {:3}: Loss = {:.6}, LR = {:.4}, Best = {:.6}",
                epoch,
                current_loss,
                optimizer.learning_rate(),
                best_loss
            );
        }

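        // Early stopping: give up after `patience` epochs without improvement (but never before epoch 50).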
        if patience_counter >= patience && epoch > 50 {
            println!("Early stopping at epoch {} (patience exceeded)", epoch);
            break;
        }
    }

    let final_predictions = network.forward_no_grad(&x_data);

    let final_loss = losses[losses.len() - 1];
    let initial_loss = losses[0];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining completed!");
    println!(" Initial loss: {:.6}", initial_loss);
    println!(" Final loss: {:.6}", final_loss);
    println!(" Best loss: {:.6}", best_loss);
    println!(" Loss reduction: {:.1}%", loss_reduction);
    println!(" Final learning rate: {:.4}", optimizer.learning_rate());

    println!("\nSample predictions (first 5):");
    for i in 0..5.min(num_samples) {
        let pred1 = final_predictions.data()[i * 2];
        let pred2 = final_predictions.data()[i * 2 + 1];
        let true1 = y_true.data()[i * 2];
        let true2 = y_true.data()[i * 2 + 1];

        println!(
            " Sample {}: pred=[{:.3}, {:.3}], true=[{:.3}, {:.3}], error=[{:.3}, {:.3}]",
            i + 1,
            pred1,
            pred2,
            true1,
            true2,
            (pred1 - true1).abs(),
            (pred2 - true2).abs()
        );
    }

    Ok(())
}

fn demonstrate_network_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Network Serialization ---");

    let config = FeedForwardConfig {
        input_size: 2,
        hidden_sizes: vec![4, 2],
        output_size: 1,
        use_bias: true,
    };
    let mut original_network = FeedForwardNetwork::new(config.clone(), Some(48));

    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_network.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    for _ in 0..20 {
        let y_pred = original_network.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_network.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_network.forward_no_grad(&test_input);

    println!("Original network output: {:?}", original_output.data());

    original_network.save_json("temp_feedforward_network")?;

    let loaded_network = FeedForwardNetwork::load_json("temp_feedforward_network", config)?;
    let loaded_output = loaded_network.forward_no_grad(&test_input);

    println!("Loaded network output: {:?}", loaded_output.data());

    let match_check = original_output
        .data()
        .iter()
        .zip(loaded_output.data().iter())
        .all(|(a, b)| (a - b).abs() < 1e-6);

    println!(
        "Serialization verification: {}",
        if match_check { "PASSED" } else { "FAILED" }
    );

    Ok(())
}

fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    for i in 0..10 {
        let weight_file = format!("temp_feedforward_network_layer_{}_weight.json", i);
        let bias_file = format!("temp_feedforward_network_layer_{}_bias.json", i);

        if fs::metadata(&weight_file).is_ok() {
            fs::remove_file(&weight_file)?;
            println!("Removed: {}", weight_file);
        }
        if fs::metadata(&bias_file).is_ok() {
            fs::remove_file(&bias_file)?;
            println!("Removed: {}", bias_file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_relu_activation() {
        let input = Tensor::from_slice(&[-2.0, -1.0, 0.0, 1.0, 2.0], vec![1, 5]).unwrap();
        let output = ReLU::forward(&input);
        let expected = vec![0.0, 0.0, 0.0, 1.0, 2.0];

        assert_eq!(output.data(), &expected);
    }

    #[test]
    fn test_network_creation() {
        let config = FeedForwardConfig {
            input_size: 3,
            hidden_sizes: vec![5, 4],
            output_size: 2,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(42));

        assert_eq!(network.num_layers(), 3);
        assert_eq!(network.parameter_count(), 3 * 5 + 5 + 5 * 4 + 4 + 4 * 2 + 2);
    }

    #[test]
    fn test_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(43));

        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = network.forward(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_batch_forward_pass() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let network = FeedForwardNetwork::new(config, Some(44));

        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = network.forward(&batch_input);

        assert_eq!(output.shape().dims, vec![2, 1]);
    }

    #[test]
    fn test_no_grad_forward() {
        let config = FeedForwardConfig::default();
        let network = FeedForwardNetwork::new(config, Some(45));

        let input = Tensor::randn(vec![1, 4], Some(46));
        let output = network.forward_no_grad(&input);

        assert!(!output.requires_grad());
    }

    #[test]
    fn test_parameter_collection() {
        let config = FeedForwardConfig {
            input_size: 2,
            hidden_sizes: vec![3],
            output_size: 1,
            use_bias: true,
        };
        let mut network = FeedForwardNetwork::new(config, Some(47));

        let params = network.parameters();
        // One weight and one bias per layer: 2 layers x 2 tensors.
        assert_eq!(params.len(), 4);
    }
}