use std::fs;

use train_station::{
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    NoGradTrack, Tensor,
};

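/// A minimal fully connected layer: `output = input.matmul(&weight) + bias`,
/// with `weight` of shape `[input_size, output_size]` and `bias` of shape
/// `[output_size]`.
///
/// Usage sketch (mirroring the demos in this file):
///
/// ```text
/// let layer = LinearLayer::new(3, 2, Some(42));
/// let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
/// let output = layer.forward(&input); // shape [1, 2]
/// ```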
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
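    /// Builds a layer with `randn` weights scaled by `1/sqrt(input_size)` and a
    /// zero-initialized bias; both tensors are marked as requiring gradients.
    /// A fixed `seed` makes the initialization reproducible.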
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

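    /// Applies the affine transform `input.matmul(&weight) + bias`. The output
    /// participates in the autograd graph when gradient tracking is enabled.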
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

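    /// Runs `forward` under a `NoGradTrack` guard, which disables gradient
    /// tracking for the duration of the call; intended for inference.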
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

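    /// Returns the trainable tensors as mutable references, in the form the
    /// optimizer's `step` and `zero_grad` expect.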
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

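    /// Writes the layer as two JSON files, `<path>_weight.json` and
    /// `<path>_bias.json`, creating parent directories if necessary.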
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

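    /// Reads the two JSON files written by `save_json` and re-enables gradient
    /// tracking on the loaded tensors. The caller supplies the sizes, since
    /// only the raw weight and bias tensors are stored.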
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

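    /// Number of trainable scalars: `input_size * output_size` weights plus
    /// `output_size` biases.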
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!("  Input size: {}", layer.input_size);
    println!("  Output size: {}", layer.output_size);
    println!("  Parameter count: {}", layer.parameter_count());
    println!("  Weight shape: {:?}", layer.weight.shape().dims);
    println!("  Bias shape: {:?}", layer.bias.shape().dims);
    println!("  Weight requires grad: {}", layer.weight.requires_grad());
    println!("  Bias requires grad: {}", layer.bias.requires_grad());
}

fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!("  Input shape: {:?}", batch_input.shape().dims);
    println!("  Output shape: {:?}", batch_output.shape().dims);
    println!("  Output requires grad: {}", batch_output.requires_grad());
}

fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!("  Input: {:?}", input.data());
    println!("  Output: {:?}", output.data());
    println!("  Output requires grad: {}", output.requires_grad());

    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        "  Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!("  No grad requires grad: {}", output.requires_grad());
    println!(
        "  With grad requires grad: {}",
        output_with_grad.requires_grad()
    );
}

fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    let mut layer = LinearLayer::new(2, 1, Some(45));

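    // Four training samples drawn from the target function y = 2*x1 + 3*x2 + 1.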
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0,
            2.0, 1.0,
            1.0, 2.0,
            2.0, 2.0,
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!("  X shape: {:?}", x_data.shape().dims);
    println!("  Y shape: {:?}", y_true.shape().dims);
    println!("  Target function: y = 2*x1 + 3*x2 + 1");

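    // Conventional Adam hyperparameters; only the learning rate is tuned here.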
    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

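    // Full-batch training: every epoch sees all four samples at once.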
    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        let y_pred = layer.forward(&x_data);

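        // Mean squared error: mean((y_pred - y_true)^2).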
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

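        // Backpropagate to populate gradients on the weight and bias.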
        loss.backward(None);

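        // Take an Adam step, then clear gradients before the next epoch.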
        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!("  Learned weights: {:?}", layer.weight.data());
    println!("  Learned bias: {:?}", layer.bias.data());
    println!("  Target weights: [2.0, 3.0]");
    println!("  Target bias: [1.0]");

    println!("  Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            "    Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!("  Initial loss: {:.6}", initial_loss);
    println!("  Final loss: {:.6}", final_loss);
    println!("  Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!("  Input shape: {:?}", single_input.shape().dims);
    println!("  Output shape: {:?}", single_output.shape().dims);
    println!("  Output: {:?}", single_output.data());

    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0,
            5.0, 6.0, 7.0, 8.0,
            9.0, 10.0, 11.0, 12.0,
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!("  Input shape: {:?}", batch_input.shape().dims);
    println!("  Output shape: {:?}", batch_output.shape().dims);

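    // The first row of the batch output should match the single-sample output.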
    let _first_batch_sample = batch_output.view(vec![3, 3]);
    let first_sample_data = &batch_output.data()[0..3];
    let single_sample_data = single_output.data();

357 println!("Consistency check:");
358 println!(" Single output: {:?}", single_sample_data);
359 println!(" First batch sample: {:?}", first_sample_data);
360 println!(
361 " Match: {}",
362 single_sample_data
363 .iter()
364 .zip(first_sample_data.iter())
365 .all(|(a, b)| (a - b).abs() < 1e-6)
366 );
367}
368
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!("  Weight: {:?}", original_layer.weight.data());
    println!("  Bias: {:?}", original_layer.bias.data());

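    // Round-trip through JSON and check that the loaded layer behaves identically.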
    original_layer.save_json("temp_linear_layer")?;

    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!("  Weight: {:?}", loaded_layer.weight.data());
    println!("  Bias: {:?}", loaded_layer.bias.data());

    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!("  Original output: {:?}", original_output.data());
    println!("  Loaded output: {:?}", loaded_output.data());
    println!(
        "  Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims, vec![3, 2]);
        assert_eq!(layer.bias.shape().dims, vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims, vec![1, 1]);
        assert!(!output.requires_grad());
    }

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims, vec![2, 1]);
    }

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        assert_eq!(original.weight.shape().dims, loaded.weight.shape().dims);
        assert_eq!(original.bias.shape().dims, loaded.bias.shape().dims);

        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}