use std::fs;

use train_station::{
    gradtrack::NoGradTrack,
    optimizers::{Adam, AdamConfig, Optimizer},
    serialization::StructSerializable,
    Tensor,
};

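/// A minimal fully-connected (linear) layer: `output = input.matmul(weight) + bias`.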
#[derive(Debug)]
pub struct LinearLayer {
    pub weight: Tensor,
    pub bias: Tensor,
    pub input_size: usize,
    pub output_size: usize,
}

impl LinearLayer {
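    /// Creates a layer with seeded random weights scaled by `1/sqrt(input_size)` and a zero bias.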
    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
        // Scale initial weights by 1/sqrt(input_size) to keep early activations stable.
        let scale = (1.0 / input_size as f32).sqrt();

        let weight = Tensor::randn(vec![input_size, output_size], seed)
            .mul_scalar(scale)
            .with_requires_grad();
        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();

        Self {
            weight,
            bias,
            input_size,
            output_size,
        }
    }

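    /// Forward pass with gradient tracking: `input.matmul(weight) + bias`.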
    pub fn forward(&self, input: &Tensor) -> Tensor {
        let output = input.matmul(&self.weight);
        output.add_tensor(&self.bias)
    }

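    /// Forward pass with gradient tracking disabled for the duration of the call.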
    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
        // Gradient tracking stays off until the guard is dropped at the end of this scope.
        let _guard = NoGradTrack::new();
        self.forward(input)
    }

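    /// Returns mutable references to the trainable parameters (weight, then bias).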
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        vec![&mut self.weight, &mut self.bias]
    }

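    /// Saves the layer as two JSON files: `{path}_weight.json` and `{path}_bias.json`.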
    pub fn save_json(&self, path: &str) -> Result<(), Box<dyn std::error::Error>> {
        // Create any missing parent directories before writing.
        if let Some(parent) = std::path::Path::new(path).parent() {
            fs::create_dir_all(parent)?;
        }

        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        self.weight.save_json(&weight_path)?;
        self.bias.save_json(&bias_path)?;

        println!("Saved linear layer to {} (weight and bias)", path);
        Ok(())
    }

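    /// Loads a layer saved by `save_json`, re-enabling gradient tracking on both tensors.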
    pub fn load_json(
        path: &str,
        input_size: usize,
        output_size: usize,
    ) -> Result<Self, Box<dyn std::error::Error>> {
        let weight_path = format!("{}_weight.json", path);
        let bias_path = format!("{}_bias.json", path);

        let weight = Tensor::load_json(&weight_path)?.with_requires_grad();
        let bias = Tensor::load_json(&bias_path)?.with_requires_grad();

        Ok(Self {
            weight,
            bias,
            input_size,
            output_size,
        })
    }

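    /// Total trainable scalars: `input_size * output_size` weights plus `output_size` biases.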
    pub fn parameter_count(&self) -> usize {
        self.input_size * self.output_size + self.output_size
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Linear Layer Example ===\n");

    demonstrate_layer_creation();
    demonstrate_forward_pass();
    demonstrate_forward_pass_no_grad();
    demonstrate_training_loop()?;
    demonstrate_single_vs_batch_inference();
    demonstrate_serialization()?;
    cleanup_temp_files()?;

    println!("\n=== Example completed successfully! ===");
    Ok(())
}

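/// Builds a small layer and prints its shapes, parameter count, and gradient flags.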
fn demonstrate_layer_creation() {
    println!("--- Layer Creation ---");

    let layer = LinearLayer::new(3, 2, Some(42));

    println!("Created linear layer:");
    println!(" Input size: {}", layer.input_size);
    println!(" Output size: {}", layer.output_size);
    println!(" Parameter count: {}", layer.parameter_count());
    println!(" Weight shape: {:?}", layer.weight.shape().dims());
    println!(" Bias shape: {:?}", layer.bias.shape().dims());
    println!(" Weight requires grad: {}", layer.weight.requires_grad());
    println!(" Bias requires grad: {}", layer.bias.requires_grad());
}

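/// Runs single-sample and batched forward passes with gradient tracking enabled.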
fn demonstrate_forward_pass() {
    println!("\n--- Forward Pass (with gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(43));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward(&input);

    println!("Single input:");
    println!(" Input: {:?}", input.data());
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]).unwrap();
    let batch_output = layer.forward(&batch_input);

    println!("Batch input:");
    println!(" Input shape: {:?}", batch_input.shape().dims());
    println!(" Output shape: {:?}", batch_output.shape().dims());
    println!(" Output requires grad: {}", batch_output.requires_grad());
}

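/// Shows that a no-grad forward pass matches the tracked pass in values but not in grad state.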
fn demonstrate_forward_pass_no_grad() {
    println!("\n--- Forward Pass (no gradients) ---");

    let layer = LinearLayer::new(3, 2, Some(44));

    let input = Tensor::from_slice(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
    let output = layer.forward_no_grad(&input);

    println!("Single input (no grad):");
    println!(" Input: {:?}", input.data());
    println!(" Output: {:?}", output.data());
    println!(" Output requires grad: {}", output.requires_grad());

    let output_with_grad = layer.forward(&input);
    println!("Comparison:");
    println!(
        " Same values: {}",
        output.data() == output_with_grad.data()
    );
    println!(" No grad requires grad: {}", output.requires_grad());
    println!(
        " With grad requires grad: {}",
        output_with_grad.requires_grad()
    );
}

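/// Fits `y = 2*x1 + 3*x2 + 1` with Adam and reports the loss trajectory.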
fn demonstrate_training_loop() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Training Loop ---");

    let mut layer = LinearLayer::new(2, 1, Some(45));

    // Four (x1, x2) samples drawn from the target function y = 2*x1 + 3*x2 + 1.
    let x_data = Tensor::from_slice(
        &[
            1.0, 1.0, // y = 6
            2.0, 1.0, // y = 8
            1.0, 2.0, // y = 9
            2.0, 2.0, // y = 11
        ],
        vec![4, 2],
    )
    .unwrap();

    let y_true = Tensor::from_slice(&[6.0, 8.0, 9.0, 11.0], vec![4, 1]).unwrap();

    println!("Training data:");
    println!(" X shape: {:?}", x_data.shape().dims());
    println!(" Y shape: {:?}", y_true.shape().dims());
    println!(" Target function: y = 2*x1 + 3*x2 + 1");

    let config = AdamConfig {
        learning_rate: 0.01,
        beta1: 0.9,
        beta2: 0.999,
        eps: 1e-8,
        weight_decay: 0.0,
        amsgrad: false,
    };

    let mut optimizer = Adam::with_config(config);
    let params = layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    println!("Optimizer setup complete. Starting training...");

    let num_epochs = 100;
    let mut losses = Vec::new();

    for epoch in 0..num_epochs {
        let y_pred = layer.forward(&x_data);

        // Mean squared error: mean((y_pred - y_true)^2).
        let diff = y_pred.sub_tensor(&y_true);
        let mut loss = diff.pow_scalar(2.0).mean();

        loss.backward(None);

        let mut params = layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);

        losses.push(loss.value());

        if epoch % 20 == 0 || epoch == num_epochs - 1 {
            println!("Epoch {:3}: Loss = {:.6}", epoch, loss.value());
        }
    }

    let final_predictions = layer.forward_no_grad(&x_data);

    println!("\nFinal model evaluation:");
    println!(" Learned weights: {:?}", layer.weight.data());
    println!(" Learned bias: {:?}", layer.bias.data());
    println!(" Target weights: [2.0, 3.0]");
    println!(" Target bias: [1.0]");

    println!(" Predictions vs True:");
    for i in 0..4 {
        let pred = final_predictions.data()[i];
        let true_val = y_true.data()[i];
        println!(
            " Sample {}: pred={:.3}, true={:.1}, error={:.3}",
            i + 1,
            pred,
            true_val,
            (pred - true_val).abs()
        );
    }

    let initial_loss = losses[0];
    let final_loss = losses[losses.len() - 1];
    let loss_reduction = (initial_loss - final_loss) / initial_loss * 100.0;

    println!("\nTraining Analysis:");
    println!(" Initial loss: {:.6}", initial_loss);
    println!(" Final loss: {:.6}", final_loss);
    println!(" Loss reduction: {:.1}%", loss_reduction);

    Ok(())
}

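/// Checks that the first row of a batched inference matches single-sample inference.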
fn demonstrate_single_vs_batch_inference() {
    println!("\n--- Single vs Batch Inference ---");

    let layer = LinearLayer::new(4, 3, Some(46));

    println!("Single inference:");
    let single_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
    let single_output = layer.forward_no_grad(&single_input);
    println!(" Input shape: {:?}", single_input.shape().dims());
    println!(" Output shape: {:?}", single_output.shape().dims());
    println!(" Output: {:?}", single_output.data());

    println!("Batch inference:");
    let batch_input = Tensor::from_slice(
        &[
            1.0, 2.0, 3.0, 4.0, // sample 1 (same values as the single input)
            5.0, 6.0, 7.0, 8.0, // sample 2
            9.0, 10.0, 11.0, 12.0, // sample 3
        ],
        vec![3, 4],
    )
    .unwrap();
    let batch_output = layer.forward_no_grad(&batch_input);
    println!(" Input shape: {:?}", batch_input.shape().dims());
    println!(" Output shape: {:?}", batch_output.shape().dims());

    // Compare the first batch row against the single-sample output.
    let _first_batch_sample = batch_output.view(vec![3, 3]);
    let first_sample_data = &batch_output.data()[0..3];
    let single_sample_data = single_output.data();

    println!("Consistency check:");
    println!(" Single output: {:?}", single_sample_data);
    println!(" First batch sample: {:?}", first_sample_data);
    println!(
        " Match: {}",
        single_sample_data
            .iter()
            .zip(first_sample_data.iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );
}

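/// Trains a layer briefly, saves it to JSON, reloads it, and verifies identical outputs.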
fn demonstrate_serialization() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Serialization ---");

    let mut original_layer = LinearLayer::new(2, 1, Some(47));

    let x_data = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
    let y_true = Tensor::from_slice(&[5.0, 11.0], vec![2, 1]).unwrap();

    let mut optimizer = Adam::with_learning_rate(0.01);
    let params = original_layer.parameters();
    for param in &params {
        optimizer.add_parameter(param);
    }

    // Train briefly so the saved weights differ from their initialization.
    for _ in 0..10 {
        let y_pred = original_layer.forward(&x_data);
        let mut loss = (y_pred.sub_tensor(&y_true)).pow_scalar(2.0).mean();
        loss.backward(None);

        let mut params = original_layer.parameters();
        optimizer.step(&mut params);
        optimizer.zero_grad(&mut params);
    }

    println!("Original layer trained");
    println!(" Weight: {:?}", original_layer.weight.data());
    println!(" Bias: {:?}", original_layer.bias.data());

    original_layer.save_json("temp_linear_layer")?;

    let loaded_layer = LinearLayer::load_json("temp_linear_layer", 2, 1)?;

    println!("Loaded layer");
    println!(" Weight: {:?}", loaded_layer.weight.data());
    println!(" Bias: {:?}", loaded_layer.bias.data());

    // The reloaded layer should produce identical outputs for the same input.
    let test_input = Tensor::from_slice(&[1.0, 1.0], vec![1, 2]).unwrap();
    let original_output = original_layer.forward_no_grad(&test_input);
    let loaded_output = loaded_layer.forward_no_grad(&test_input);

    println!("Consistency check:");
    println!(" Original output: {:?}", original_output.data());
    println!(" Loaded output: {:?}", loaded_output.data());
    println!(
        " Match: {}",
        original_output
            .data()
            .iter()
            .zip(loaded_output.data().iter())
            .all(|(a, b)| (a - b).abs() < 1e-6)
    );

    println!("Serialization verification: PASSED");

    Ok(())
}

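/// Removes the JSON files written by `demonstrate_serialization`.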
fn cleanup_temp_files() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n--- Cleanup ---");

    let files_to_remove = [
        "temp_linear_layer_weight.json",
        "temp_linear_layer_bias.json",
    ];

    for file in &files_to_remove {
        if fs::metadata(file).is_ok() {
            fs::remove_file(file)?;
            println!("Removed: {}", file);
        }
    }

    println!("Cleanup completed");
    Ok(())
}

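// Unit tests covering construction, forward passes, batching, parameter counting, and serialization.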
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_creation() {
        let layer = LinearLayer::new(3, 2, Some(42));
        assert_eq!(layer.input_size, 3);
        assert_eq!(layer.output_size, 2);
        assert_eq!(layer.weight.shape().dims(), vec![3, 2]);
        assert_eq!(layer.bias.shape().dims(), vec![2]);
        assert!(layer.weight.requires_grad());
        assert!(layer.bias.requires_grad());
    }

    #[test]
    fn test_forward_pass() {
        let layer = LinearLayer::new(2, 1, Some(43));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(output.requires_grad());
    }

    #[test]
    fn test_forward_pass_no_grad() {
        let layer = LinearLayer::new(2, 1, Some(44));
        let input = Tensor::from_slice(&[1.0, 2.0], vec![1, 2]).unwrap();
        let output = layer.forward_no_grad(&input);

        assert_eq!(output.shape().dims(), vec![1, 1]);
        assert!(!output.requires_grad());
    }

    #[test]
    fn test_batch_inference() {
        let layer = LinearLayer::new(2, 1, Some(45));
        let batch_input = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let output = layer.forward(&batch_input);

        assert_eq!(output.shape().dims(), vec![2, 1]);
    }

    #[test]
    fn test_parameter_count() {
        let layer = LinearLayer::new(3, 2, Some(46));
        assert_eq!(layer.parameter_count(), 3 * 2 + 2);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = LinearLayer::new(2, 1, Some(47));

        original.save_json("test_layer").unwrap();
        let loaded = LinearLayer::load_json("test_layer", 2, 1).unwrap();

        assert_eq!(original.weight.shape().dims(), loaded.weight.shape().dims());
        assert_eq!(original.bias.shape().dims(), loaded.bias.shape().dims());

        assert_eq!(original.weight.data(), loaded.weight.data());
        assert_eq!(original.bias.data(), loaded.bias.data());

        // Clean up the files written by the roundtrip.
        let _ = fs::remove_file("test_layer_weight.json");
        let _ = fs::remove_file("test_layer_bias.json");
    }
}
524}