use std::fs;
use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

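/// A fully connected (linear) layer computing `output = input.matmul(weight) + bias`.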
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
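    /// Creates a layer with seeded random-normal weights scaled by `1/sqrt(input_size)`
    /// and a zero-initialized bias, both with gradient tracking enabled.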
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
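        // Scale initial weights by 1/sqrt(input_size) to keep early outputs at unit scale.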
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

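    /// Forward pass: matrix-multiplies the input by the weights, then adds the bias.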
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

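    /// Forward pass with gradient tracking disabled via a `NoGradTrack` guard (for inference).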
    #[allow(unused)]
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

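    /// Returns mutable references to the trainable parameters (weight and bias).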
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

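    /// Saves the layer as two JSON files: `{path}_weight.json` and `{path}_bias.json`.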
    #[allow(unused)]
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

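    /// Loads a layer saved by `save_json`, re-enabling gradient tracking on both tensors.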
    #[allow(unused)]
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

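    /// Total number of trainable scalars: `input_size * output_size` weights plus `output_size` biases.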
    #[allow(unused)]
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

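/// Entry point: runs each demonstration in turn, then removes the temporary files.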
#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

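/// Creates a 3->2 layer and prints its shapes, parameter count, and gradient-tracking flags.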
fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!(" Input size: {}", layer.input_size);
    println!(" Output size: {}", layer.output_size);
    println!(" Parameter count: {}", layer.parameter_count());
    println!(" Weight shape: {:?}", layer.weight.shape().dims());
    println!(" Bias shape: {:?}", layer.bias.shape().dims());
    println!(" Weight requires grad: {}", layer.weight.requires_grad());
    println!(" Bias requires grad: {}", layer.bias.requires_grad());
}

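/// Runs the forward pass on a single sample and on a batch, with gradient tracking enabled.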
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!(" Input: {:?}", input.data());
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!(" Input shape: {:?}", batch_input.shape().dims());
    println!(" Output shape: {:?}", batch_output.shape().dims());
    println!(" Output requires grad: {}", batch_output.requires_grad());
}

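/// Shows that `forward_no_grad` matches `forward` numerically while leaving `requires_grad` off.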
fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!(" Input: {:?}", input.data());
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        " Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!(" No grad requires grad: {}", output.requires_grad());
    println!(
        " With grad requires grad: {}",
        output_with_grad.requires_grad()
    );
}

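/// Trains the layer with Adam to fit y = 2*x1 + 3*x2 + 1 on four samples and reports the loss curve.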
fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    let mut layer = LinearLayer::new(2, 1, Some(45));

    // Four (x1, x2) samples, one pair per row.
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0, // sample 1
            2.0, 1.0, // sample 2
            1.0, 2.0, // sample 3
            2.0, 2.0, // sample 4
        ],
        vec![4, 2],
    )
    .unwrap();

    // Targets from y = 2*x1 + 3*x2 + 1.
    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!(" X shape: {:?}", x_data.shape().dims());
    println!(" Y shape: {:?}", y_true.shape().dims());
    println!(" Target function: y = 2*x1 + 3*x2 + 1");

    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    // Register each trainable parameter with the optimizer.
    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        // Forward pass over the whole batch.
        let y_pred = layer.forward(&x_data);

        // Mean squared error loss.
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        // Backpropagate, apply the update, then clear gradients for the next epoch.
        loss.backward(None);

        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!(" Learned weights: {:?}", layer.weight.data());
    println!(" Learned bias: {:?}", layer.bias.data());
    println!(" Target weights: [2.0, 3.0]");
    println!(" Target bias: [1.0]");

    println!(" Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            " Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!(" Initial loss: {:.6}", initial_loss);
    println!(" Final loss: {:.6}", final_loss);
    println!(" Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

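/// Checks that a sample processed alone matches the same sample processed inside a batch.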
fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!(" Input shape: {:?}", single_input.shape().dims());
    println!(" Output shape: {:?}", single_output.shape().dims());
    println!(" Output: {:?}", single_output.data());

    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0, // sample 1 (same values as the single input)
            5.0, 6.0, 7.0, 8.0, // sample 2
            9.0, 10.0, 11.0, 12.0, // sample 3
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!(" Input shape: {:?}", batch_input.shape().dims());
    println!(" Output shape: {:?}", batch_output.shape().dims());

    // Compare the first row of the batch output against the single-sample output.
    let _first_batch_sample = batch_output.view(vec![3, 3]);
    let first_sample_data = &batch_output.data()[0..3];
    let single_sample_data = single_output.data();

    println!("Consistency check:");
    println!(" Single output: {:?}", single_sample_data);
    println!(" First batch sample: {:?}", first_sample_data);
    println!(
        " Match: {}",
        single_sample_data
            .iter()
            .zip(first_sample_data.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );
}

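/// Trains a layer briefly, saves it to JSON, reloads it, and verifies identical outputs.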
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Train briefly so the saved weights differ from their initialization.
    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!(" Weight: {:?}", original_layer.weight.data());
    println!(" Bias: {:?}", original_layer.bias.data());

    original_layer.save_json("temp_linear_layer")?;

    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!(" Weight: {:?}", loaded_layer.weight.data());
    println!(" Bias: {:?}", loaded_layer.bias.data());

    // The same input through both layers should produce identical outputs.
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!(" Original output: {:?}", original_output.data());
    println!(" Loaded output: {:?}", loaded_output.data());
    println!(
        " Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

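/// Removes the JSON files written by the serialization demonstration.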
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims(), vec![3, 2]);
        assert_eq!(layer.bias.shape().dims(), vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(!output.requires_grad());
    }

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        assert_eq!(original.weight.shape().dims(), loaded.weight.shape().dims());
        assert_eq!(original.bias.shape().dims(), loaded.bias.shape().dims());

        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}