#[cfg(feature = "gpu")]
pub mod gpu;

pub mod dtype;
pub mod storage;

pub use dtype::DType;
pub use storage::Storage;

#[cfg(feature = "gpu")]
pub use gpu::{
    GpuBuffer, GpuContext, get_gpu_context, gpu_cleanup, gpu_compact, gpu_pending_count,
    gpu_pool_stats, gpu_sync, gpu_sync_threshold, is_gpu_available,
};

pub mod autograd;
pub mod data;
pub mod device;
pub mod io;
pub mod nn;
pub mod ops;
pub mod tensor;
pub mod utils;

pub use autograd::GradFn;
pub use device::Device;
pub use nn::layers::Dropout;
pub use nn::layers::flatten::Flatten;
pub use nn::{
    Adam, BatchNorm1d, BatchNorm2d, Conv2d, ConvTranspose2d, Embedding, LSTMCell, Linear,
    MaxPool2d, Module, PixelShuffle, ReLU, SGD, Sequential, SequentialBuilder, Sigmoid, Tanh,
};
pub use tensor::{RawTensor, Tensor, TensorOps};

pub use tensor::{
    DataLoader, bce_loss, bce_with_logits_loss, check_gradients, check_gradients_simple,
    cross_entropy_loss, kl_divergence_gaussian, manual_seed, max_dim, mse_loss, new_tensor,
    nll_loss, ones, rand, randn, randn_like, softmax, sum_dim, zeros,
};

pub use data::{load_mnist_images, load_mnist_labels, normalize, to_one_hot};
pub use io::{
    TypedTensorData, load_safetensors, load_safetensors_raw, load_safetensors_with_mapping,
    load_state_dict_with_mapping, mapping, save_safetensors, save_safetensors_typed,
};
pub use utils::ProgressBar;

pub use ops::{
    BinaryGradFn, BinaryOp, MatMulGradFn, MaxReduceGradFn, MeanGradFn, MovementGradFn, MovementOp,
    MulAccGradFn, ReduceOp, SumGradFn, TernaryOp, UnaryGradFn, UnaryOp, WhereGradFn,
};

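// A minimal end-to-end sketch of the re-exported autograd API, kept as a tiny test so it
// stays compilable. The module name `usage_sketch` is illustrative; the body only uses
// calls that are exercised elsewhere in this file.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn forward_and_backward_round_trip() {
        // y = sum(x * x), so dy/dx = 2x.
        let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
        let y = x.elem_mul(&x).sum();
        y.backward();
        assert_eq!(x.grad(), Some(vec![2.0, 4.0, 6.0]));
    }
}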
72#[cfg(test)]
81mod tests {
82 use super::*;
83
84 #[test]
85 fn test_basic_add_backward() {
86 let a = RawTensor::new(vec![2.0], &[1], true);
87 let b = RawTensor::new(vec![3.0], &[1], true);
88 let c = a.add(&b);
89 c.backward();
90
91 assert_eq!(a.grad(), Some(vec![1.0]));
92 assert_eq!(b.grad(), Some(vec![1.0]));
93 }
94
95 #[test]
96 fn test_enhanced_device_safety() {
97 let a = RawTensor::new(vec![2.0], &[1], true);
98 let b = RawTensor::new(vec![3.0], &[1], true);
99 let c = a.add(&b);
100 c.backward();
101
102 let cpu_device = Device::CPU;
104 assert!(cpu_device.is_cpu());
105 assert!(!cpu_device.is_gpu());
106 assert_eq!(cpu_device.name(), "CPU");
107
108 let gpu_device = Device::GPU("CUDA".to_string());
109 assert!(!gpu_device.is_cpu());
110 assert!(gpu_device.is_gpu());
111 assert_eq!(gpu_device.name(), "CUDA");
112
113 assert_eq!(a.grad(), Some(vec![1.0]));
115 assert_eq!(b.grad(), Some(vec![1.0]));
116 }
117
118 #[test]
119 fn test_multiply_backward() {
120 let a = RawTensor::new(vec![3.0], &[1], true);
121 let b = RawTensor::new(vec![4.0], &[1], true);
122 let c = a.elem_mul(&b);
123 c.backward();
124
125 assert_eq!(a.grad(), Some(vec![4.0]));
126 assert_eq!(b.grad(), Some(vec![3.0]));
127 }
128
129 #[test]
130 fn test_chain_rule() {
131 let a = RawTensor::new(vec![2.0], &[1], true);
132 let b = RawTensor::new(vec![3.0], &[1], true);
133 let c = a.add(&b);
134 let d = c.elem_mul(&a);
135 d.backward();
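        // d = (a + b) * a, so dd/da = 2a + b = 7 and dd/db = a = 2.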
136
137 assert_eq!(a.grad(), Some(vec![7.0]));
138 assert_eq!(b.grad(), Some(vec![2.0]));
139 }
140
141 #[test]
142 fn test_sum_backward() {
143 let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
144 let loss = a.sum();
145 loss.backward();
146
147 assert_eq!(a.grad(), Some(vec![1.0, 1.0, 1.0]));
148 }
149
150 #[test]
151 fn test_multidim_ops() {
152 let a = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
153 let b = RawTensor::new(vec![0.5, 0.5, 0.5, 0.5], &[2, 2], true);
154 let c = a.elem_mul(&b);
155 let loss = c.sum();
156 loss.backward();
157
158 assert_eq!(a.grad(), Some(vec![0.5, 0.5, 0.5, 0.5]));
159 assert_eq!(b.grad(), Some(vec![1.0, 2.0, 3.0, 4.0]));
160 }
161}
162
163#[cfg(test)]
164mod unary_tests {
165 use super::*;
166 use approx::assert_relative_eq;
167
168 #[test]
169 fn test_neg_forward_backward() {
170 let x = RawTensor::new(vec![2.0, -3.0], &[2], true);
171 let y = x.neg();
172
173 assert_eq!(y.borrow().data, vec![-2.0, 3.0]);
175
176 y.backward();
178 assert_eq!(x.grad(), Some(vec![-1.0, -1.0]));
179 }
180
181 #[test]
182 fn test_sqrt_chain() {
183 let x = RawTensor::new(vec![4.0], &[1], true);
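        // z = sqrt(x) * sqrt(x) = x, so dz/dx = 1.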
        let y = x.sqrt();
        let z = y.elem_mul(&y);
        z.backward();
187
188 assert_relative_eq!(
190 x.grad().unwrap().first().copied().unwrap_or(f32::NAN),
191 1.0,
192 epsilon = 1e-6
193 );
194 }
195
196 #[test]
197 fn test_exp2_log2_inverse() {
198 let x = RawTensor::new(vec![2.0], &[1], true);
        let y = x.exp2().log2();
        y.backward();
201
202 assert_relative_eq!(
203 y.borrow().data.first().copied().unwrap_or(f32::NAN),
204 2.0,
205 epsilon = 1e-6
206 );
207 assert_relative_eq!(
209 x.grad().unwrap().first().copied().unwrap_or(f32::NAN),
210 1.0,
211 epsilon = 1e-6
212 );
213 }
214}
215
216#[cfg(test)]
217mod binary_tests {
218 use super::*;
219
220 #[test]
221 fn test_div_backward() {
222 let x = RawTensor::new(vec![6.0], &[1], true);
223 let y = RawTensor::new(vec![2.0], &[1], true);
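        // For z = x / y: dz/dx = 1 / y = 0.5 and dz/dy = -x / y^2 = -1.5.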
        let z = x.div(&y);
        z.backward();
226
227 assert_eq!(x.grad(), Some(vec![0.5]));
229 assert_eq!(y.grad(), Some(vec![-1.5]));
231 }
232
233 #[test]
234 fn test_max_backward() {
235 let x = RawTensor::new(vec![3.0, 1.0], &[2], true);
236 let y = RawTensor::new(vec![2.0, 4.0], &[2], true);
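        // max_elem routes the gradient to whichever operand is larger at each position.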
237 let z = x.max_elem(&y);
238 let loss = z.sum();
239 loss.backward();
240
241 assert_eq!(x.grad(), Some(vec![1.0, 0.0]));
243 assert_eq!(y.grad(), Some(vec![0.0, 1.0]));
244 }
245}
246
247#[cfg(test)]
248mod reduce_tests {
249 use super::*;
250
251 #[test]
252 fn test_reduce_max_backward() {
253 let x = RawTensor::new(vec![1.0, 5.0, 3.0], &[3], true);
        let y = x.max_reduce();
        y.backward();
256
257 assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0]));
259 }
260}
261
262#[cfg(test)]
263mod ternary_tests {
264 use super::*;
265
266 #[test]
267 fn test_mulacc_backward() {
268 let x = RawTensor::new(vec![2.0], &[1], true);
270 let y = RawTensor::new(vec![3.0], &[1], true);
271 let w = RawTensor::new(vec![1.0], &[1], true);
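        // mulacc computes x * y + w, so dz/dx = y, dz/dy = x, and dz/dw = 1.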
        let z = x.mulacc(&y, &w);
        z.backward();

        assert_eq!(x.grad(), Some(vec![3.0]));
        assert_eq!(y.grad(), Some(vec![2.0]));
        assert_eq!(w.grad(), Some(vec![1.0]));
    }
279
280 #[test]
281 fn test_where_backward() {
282 let cond = RawTensor::new(vec![1.0, 0.0], &[2], false);
283 let x = RawTensor::new(vec![10.0, 20.0], &[2], true);
284 let y = RawTensor::new(vec![30.0, 40.0], &[2], true);
        let z = cond.where_op(&x, &y);
        z.backward();
287
        assert_eq!(x.grad(), Some(vec![1.0, 0.0]));
        assert_eq!(y.grad(), Some(vec![0.0, 1.0]));
    }
291
292 #[test]
293 fn test_where_broadcast_backward() {
294 let cond = RawTensor::new(vec![1.0, 0.0], &[2, 1], false);
296 let true_branch = RawTensor::new(vec![10.0, 11.0, 12.0, 20.0, 21.0, 22.0], &[2, 3], true);
297 let false_branch = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
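        // cond has shape [2, 1]: row 0 selects the true branch, row 1 the false branch,
        // and the false branch's gradient is reduced back over its broadcast row.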
298 let out = cond.where_op(&true_branch, &false_branch);
299 let loss = out.sum();
300 loss.backward();
301
302 assert_eq!(true_branch.grad(), Some(vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0]));
304 assert_eq!(false_branch.grad(), Some(vec![1.0, 1.0, 1.0]));
305 }
306}
307
308#[cfg(test)]
309mod movement_tests {
310 use super::*;
311
312 #[test]
313 fn test_reshape_backward() {
314 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
315 let y = x.reshape(&[2, 2]);
316 let loss = y.sum();
317 loss.backward();
318
319 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
321 }
322
323 #[test]
324 fn test_permute_backward() {
325 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
        let y = x.permute(&[1, 0]);
        let loss = y.sum();
328 loss.backward();
329
330 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
332 }
333}
334
335#[cfg(test)]
336mod misc_tests {
337 use super::*;
338 #[test]
341 fn test_linear_layer() {
342 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
344 let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], true);
345 let b = RawTensor::new(vec![0.1, 0.2], &[1, 2], true);
346
        let y = x.matmul(&w);
        let out = y.add(&b);
349 let loss = out.sum();
350
351 loss.backward();
352
353 assert!(x.grad().is_some());
355 assert!(w.grad().is_some());
356 assert!(b.grad().is_some());
357
358 assert_eq!(b.grad(), Some(vec![1.0, 1.0]));
360 }
361
362 #[test]
363 fn test_tensor_zero_grad() {
364 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
365 let loss = x.sum();
366 loss.backward();
367
368 assert!(x.grad().is_some());
369
370 x.zero_grad();
371 assert!(x.grad().is_none());
372
373 let rewind = x.sum();
374 rewind.backward();
375 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0]));
376 }
377
378 #[test]
381 fn test_broadcast_shape() {
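        // Shapes broadcast from the trailing dimension; size-1 axes stretch to match.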
382 let shape = RawTensor::broadcast_shape(&[3, 1], &[1, 4]);
384 assert_eq!(shape, vec![3, 4]);
385
386 let shape = RawTensor::broadcast_shape(&[5, 3, 1], &[1, 4]);
388 assert_eq!(shape, vec![5, 3, 4]);
389
390 let shape = RawTensor::broadcast_shape(&[1], &[3, 4]);
392 assert_eq!(shape, vec![3, 4]);
393
394 let shape = RawTensor::broadcast_shape(&[3, 4], &[4]);
396 assert_eq!(shape, vec![3, 4]);
397 }
398
399 #[test]
400 #[should_panic(expected = "Cannot broadcast")]
401 fn test_broadcast_incompatible() {
402 let _ = RawTensor::broadcast_shape(&[3, 2], &[4, 3]);
403 }
404
405 #[test]
406 fn test_broadcast_add_scalar() {
407 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
409 let scalar = RawTensor::new(vec![10.0], &[1], true);
410 let y = x.add(&scalar);
411
412 assert_eq!(y.borrow().shape, vec![2, 3]);
413 assert_eq!(y.borrow().data, vec![11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);
414
415 y.backward();
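        // The scalar was broadcast over all six elements, so its gradient accumulates to 6.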
416
417 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
419 assert_eq!(scalar.grad(), Some(vec![6.0]));
421 }
422
423 #[test]
424 fn test_broadcast_mul_vector() {
425 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
427 let v = RawTensor::new(vec![2.0, 3.0, 4.0], &[3], true);
428 let y = x.elem_mul(&v);
429
430 assert_eq!(y.borrow().shape, vec![2, 3]);
431 assert_eq!(y.borrow().data, vec![2.0, 6.0, 12.0, 8.0, 15.0, 24.0]);
432
433 y.backward();
434
435 assert_eq!(x.grad(), Some(vec![2.0, 3.0, 4.0, 2.0, 3.0, 4.0]));
        assert_eq!(v.grad(), Some(vec![5.0, 7.0, 9.0]));
    }
440
441 #[test]
442 fn test_broadcast_add_matrix() {
443 let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
445 let y = RawTensor::new(vec![10.0, 20.0, 30.0], &[1, 3], true);
446 let z = x.add(&y);
447
448 assert_eq!(z.borrow().shape, vec![2, 3]);
449 assert_eq!(z.borrow().data, vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0]);
450
451 z.backward();
452
453 assert_eq!(x.grad(), Some(vec![3.0, 3.0]));
455 assert_eq!(y.grad(), Some(vec![2.0, 2.0, 2.0]));
457 }
458
459 #[test]
460 fn test_broadcast_batch_bias() {
461 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], true);
463 let bias = RawTensor::new(vec![0.5, 1.0], &[2], true);
464 let y = x.add(&bias);
465
466 assert_eq!(y.borrow().shape, vec![3, 2]);
467 assert_eq!(y.borrow().data, vec![1.5, 3.0, 3.5, 5.0, 5.5, 7.0]);
468
469 let loss = y.sum();
470 loss.backward();
471
472 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
474 assert_eq!(bias.grad(), Some(vec![3.0, 3.0]));
476 }
477
478 #[test]
479 fn test_broadcast_div() {
480 let x = RawTensor::new(vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0], &[2, 3], true);
482 let y = RawTensor::new(vec![2.0, 2.0, 2.0], &[1, 3], true);
483 let z = x.div(&y);
484
485 assert_eq!(z.borrow().shape, vec![2, 3]);
486 assert_eq!(z.borrow().data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
487
488 z.backward();
489
490 assert_eq!(x.grad(), Some(vec![0.5, 0.5, 0.5, 0.5, 0.5, 0.5]));
492 assert_eq!(y.grad(), Some(vec![-2.5, -3.5, -4.5]));
497 }
498
499 #[test]
500 fn test_broadcast_3d() {
501 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 2, 3], true);
503 let y = RawTensor::new(vec![10.0, 20.0], &[2, 1], true);
504 let z = x.add(&y);
505
506 assert_eq!(z.borrow().shape, vec![1, 2, 3]);
507 assert_eq!(z.borrow().data, vec![11.0, 12.0, 13.0, 24.0, 25.0, 26.0]);
510
511 z.backward();
512
513 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
515 assert_eq!(y.grad(), Some(vec![3.0, 3.0]));
517 }
518
519 #[test]
520 fn test_broadcast_max() {
521 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], true);
523 let y = RawTensor::new(vec![2.0, 3.0, 4.0], &[3], true);
524 let z = x.max_elem(&y);
525
526 assert_eq!(z.borrow().shape, vec![2, 3]);
527 assert_eq!(z.borrow().data, vec![2.0, 5.0, 4.0, 4.0, 3.0, 6.0]);
530
531 z.backward();
532
533 assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]));
536 assert_eq!(y.grad(), Some(vec![1.0, 1.0, 1.0]));
542 }
543
544 #[test]
545 fn test_broadcast_bias_add() {
546 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
549 let w = RawTensor::new(
550 vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2],
551 &[3, 4],
552 true,
553 );
554 let b = RawTensor::new(vec![0.01, 0.02, 0.03, 0.04], &[4], true);
555
556 let y = x.matmul(&w);
        let z = y.add(&b);
        let loss = z.sum();
559
560 loss.backward();
561
562 assert!(x.grad().is_some());
564 assert!(w.grad().is_some());
565 assert!(b.grad().is_some());
566
567 assert_eq!(b.grad(), Some(vec![2.0, 2.0, 2.0, 2.0]));
570 }
571
572 #[test]
573 fn test_batched_matmul() {
574 let x = RawTensor::ones(&[2, 2, 3]);
576 let y = RawTensor::ones(&[2, 3, 2]);
577 let z = x.matmul(&y);
578
579 assert_eq!(z.borrow().shape, vec![2, 2, 2]);
580 assert_eq!(z.borrow().data.first().copied().unwrap_or(f32::NAN), 3.0);
582 assert_eq!(z.borrow().data.get(7).copied().unwrap_or(f32::NAN), 3.0);
583 }
584 #[test]
585 #[allow(clippy::identity_op)]
586 fn test_batched_matmul_broadcasting() {
        let a_data = vec![1.0; 2 * 1 * 2 * 3];
        let b_data = vec![2.0; 1 * 2 * 3 * 1];
        let a = RawTensor::new(a_data, &[2, 1, 2, 3], true);
595 let b = RawTensor::new(b_data, &[1, 2, 3, 1], true);
596
597 let c = a.matmul(&b);
598
599 assert_eq!(c.borrow().shape, vec![2, 2, 2, 1]);
601
602 assert_eq!(c.borrow().data.first().copied().unwrap_or(f32::NAN), 6.0);
604
605 let loss = c.sum();
606 loss.backward();
607
608 assert!(a.grad().is_some());
612 assert!(b.grad().is_some());
613 }
614
615 #[test]
616 fn test_matmul_matrix_vector_backward() {
617 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[3, 2], true);
619 let v = RawTensor::new(vec![0.5, -1.0], &[2], true);
620
621 let z = x.matmul(&v);
623 let loss = z.sum();
625 loss.backward();
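        // With loss = sum(x @ v): dL/dx[i][j] = v[j], and dL/dv[j] is the column sum of x
        // (1 + 3 + 5 = 9 and 2 + 4 + 6 = 12).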
626
627 assert_eq!(x.grad(), Some(vec![0.5, -1.0, 0.5, -1.0, 0.5, -1.0]));
629 assert_eq!(v.grad(), Some(vec![9.0, 12.0]));
632 }
633
634 #[test]
635 fn test_dot_backward() {
636 let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
638 let b = RawTensor::new(vec![4.0, 5.0, 6.0], &[3], true);
639
640 let loss = a.matmul(&b);
642 loss.backward();
643
644 assert_eq!(a.grad(), Some(vec![4.0, 5.0, 6.0]));
646 assert_eq!(b.grad(), Some(vec![1.0, 2.0, 3.0]));
648 }
649
650 #[test]
651 fn test_gradcheck_matrix_vector_matmul() {
652 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
654 let v = RawTensor::new(vec![0.3, -0.7], &[2], false);
655 let passed = RawTensor::check_gradients_simple(&x, |t| t.matmul(&v).sum());
656 assert!(passed, "Matrix-vector matmul gradient check failed");
657 }
658
659 #[test]
660 fn test_broadcast_sub() {
661 let x = RawTensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2], true);
663 let y = RawTensor::new(vec![1.0, 2.0], &[2], true);
664 let z = x.sub(&y);
665
666 assert_eq!(z.borrow().shape, vec![2, 2]);
667 assert_eq!(z.borrow().data, vec![4.0, 4.0, 6.0, 6.0]);
668
669 z.backward();
670
671 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0]));
672 assert_eq!(y.grad(), Some(vec![-2.0, -2.0]));
674 }
675
676 #[test]
679 fn test_gradcheck_unary_ops() {
680 let x = RawTensor::new(vec![4.0, 9.0, 16.0], &[3], true);
682 let passed = RawTensor::check_gradients_simple(&x, |t| {
683 let y = t.sqrt();
684 y.sum()
685 });
686 assert!(passed, "Sqrt gradient check failed");
687
688 let x = RawTensor::new(vec![0.5, 1.0, 1.5], &[3], true);
690 let passed = RawTensor::check_gradients_simple(&x, |t| {
691 let y = t.sin();
692 y.sum()
693 });
694 assert!(passed, "Sin gradient check failed");
695
696 let x = RawTensor::new(vec![0.0, 1.0, -1.0], &[3], true);
698 let passed = RawTensor::check_gradients_simple(&x, |t| {
699 let y = t.sigmoid();
700 y.sum()
701 });
702 assert!(passed, "Sigmoid gradient check failed");
703 }
704
705 #[test]
706 fn test_gradcheck_binary_ops() {
707 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
709 let y = RawTensor::new(vec![4.0, 5.0, 6.0], &[3], false);
710 let passed = RawTensor::check_gradients_simple(&x, |t| {
711 let z = t.add(&y);
712 z.sum()
713 });
714 assert!(passed, "Add gradient check failed");
715
716 let x = RawTensor::new(vec![2.0, 3.0], &[2], true);
718 let y = RawTensor::new(vec![4.0, 5.0], &[2], false);
719 let passed = RawTensor::check_gradients_simple(&x, |t| {
720 let z = t.elem_mul(&y);
721 z.sum()
722 });
723 assert!(passed, "Mul gradient check failed");
724
725 let x = RawTensor::new(vec![6.0, 8.0], &[2], true);
727 let y = RawTensor::new(vec![2.0, 4.0], &[2], false);
728 let passed = RawTensor::check_gradients_simple(&x, |t| {
729 let z = t.div(&y);
730 z.sum()
731 });
732 assert!(passed, "Div gradient check failed");
733 }
734
735 #[test]
736 fn test_gradcheck_matmul() {
737 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
739 let w = RawTensor::new(
740 vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
741 &[3, 3],
742 false,
743 );
744 let passed = RawTensor::check_gradients_simple(&x, |t| {
745 let y = t.matmul(&w);
746 y.sum()
747 });
748 assert!(passed, "Matmul gradient check failed");
749 }
750
751 #[test]
752 fn test_gradcheck_broadcast() {
753 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
755 let y = RawTensor::new(vec![0.5], &[1], false);
756 let passed = RawTensor::check_gradients_simple(&x, |t| {
757 let z = t.elem_mul(&y);
758 z.sum()
759 });
760 assert!(passed, "Broadcast gradient check failed");
761
762 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
764 let y = RawTensor::new(vec![0.5, 1.0], &[2], false);
765 let passed = RawTensor::check_gradients_simple(&x, |t| {
766 let z = t.add(&y);
767 z.sum()
768 });
769 assert!(passed, "Matrix broadcast gradient check failed");
770 }
771
772 #[test]
773 fn test_gradcheck_movement_ops() {
774 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
776 let passed = RawTensor::check_gradients_simple(&x, |t| {
777 let y = t.reshape(&[2, 2]);
778 y.sum()
779 });
780 assert!(passed, "Reshape gradient check failed");
781
782 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
784 let passed = RawTensor::check_gradients_simple(&x, |t| {
785 let y = t.permute(&[1, 0]);
786 y.sum()
787 });
788 assert!(passed, "Permute gradient check failed");
789
790 let x = RawTensor::new(vec![1.0, 2.0], &[2], true);
792 let passed = RawTensor::check_gradients_simple(&x, |t| {
793 let y = t.pad(&[(1, 1)]);
794 y.sum()
795 });
796 assert!(passed, "Pad gradient check failed");
797
798 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
800 let passed = RawTensor::check_gradients_simple(&x, |t| {
801 let y = t.shrink(&[(1, 3)]);
802 y.sum()
803 });
804 assert!(passed, "Shrink gradient check failed");
805 }
806
807 #[test]
808 fn test_gradcheck_reduce_ops() {
809 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[4], true);
811 let passed = RawTensor::check_gradients_simple(&x, |t| t.mean());
812 assert!(passed, "Mean gradient check failed");
813
814 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0], &[4], true);
816 let passed = RawTensor::check_gradients_simple(&x, |t| t.max_reduce());
817 assert!(passed, "Max gradient check failed");
818 }
819
820 #[test]
821 fn test_gradcheck_ternary_ops() {
822 let x = RawTensor::new(vec![1.0, 2.0], &[2], true);
824 let y = RawTensor::new(vec![3.0, 4.0], &[2], false);
825 let z = RawTensor::new(vec![0.5, 1.0], &[2], false);
826 let passed = RawTensor::check_gradients_simple(&x, |t| {
827 let out = t.mulacc(&y, &z);
828 out.sum()
829 });
830 assert!(passed, "MulAcc gradient check failed");
831 }
832
833 #[test]
834 fn test_gradcheck_complex_chain() {
835 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
837 let w = RawTensor::new(vec![0.5, 1.0, 1.5], &[3], false);
838
839 let passed = RawTensor::check_gradients_simple(&x, |t| {
840 let prod = t.elem_mul(&w);
842 let y = prod.sigmoid();
843 y.sum()
844 });
845 assert!(passed, "Complex chain gradient check failed");
846 }
847
848 #[test]
849 fn test_gradcheck_neural_network_layer() {
850 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
852 let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], false);
853 let b = RawTensor::new(vec![0.1, 0.2], &[2], false);
854
855 let passed = RawTensor::check_gradients_simple(&x, |t| {
856 let y = t.matmul(&w);
857 let z = y.add(&b);
858 z.sum()
859 });
860 assert!(passed, "Neural network layer gradient check failed");
861 }
862
863 #[test]
864 fn test_gradcheck_with_tolerance() {
865 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
867
868 let (max_err, mean_err, passed) = RawTensor::check_gradients(
869 &x,
870 |t| {
871 let y = t.relu();
872 y.sum()
873 },
            1e-5,
            1e-2,
        );
877
878 assert!(passed, "Custom tolerance gradient check failed");
879 println!(
880 "ReLU gradcheck: max_err={:.6e}, mean_err={:.6e}",
881 max_err, mean_err
882 );
883 }
884
885 #[test]
886 fn test_gradcheck_multidim() {
887 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
889
890 let passed = RawTensor::check_gradients_simple(&x, |t| {
891 let y = t.sqrt();
892 let z = y.elem_mul(t);
893 z.sum()
894 });
895 assert!(passed, "Multidim gradient check failed");
896 }
897
898 #[test]
899 fn test_gradcheck_expand() {
900 let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
901 let passed = RawTensor::check_gradients_simple(&x, |t| {
902 let y = t.expand(&[2, 3]);
903 y.sum()
904 });
905 assert!(passed, "Expand gradient check failed");
906 }
907
908 #[test]
909 fn test_gradcheck_transpose() {
910 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
912 let passed = RawTensor::check_gradients_simple(&x, |t| {
913 let y = t.transpose();
914 let w = RawTensor::new(vec![0.5, 1.0, 1.5, 2.0, 2.5, 3.0], &[3, 2], false);
916 y.elem_mul(&w).sum()
917 });
918 assert!(passed, "Transpose gradient check failed");
919 }
920
921 #[test]
922 fn test_gradcheck_pad() {
923 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
924 let passed = RawTensor::check_gradients_simple(&x, |t| {
925 let y = t.pad(&[(1, 1)]);
926 y.sum()
927 });
928 assert!(passed, "Pad gradient check failed");
929 }
930
931 #[test]
932 fn test_gradcheck_shrink() {
933 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true);
934 let passed = RawTensor::check_gradients_simple(&x, |t| {
935 let y = t.shrink(&[(1, 4)]);
936 y.sum()
937 });
938 assert!(passed, "Shrink gradient check failed");
939 }
940
941 #[test]
942 fn test_gradcheck_stride() {
943 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], true);
944 let passed = RawTensor::check_gradients_simple(&x, |t| {
945 let y = t.stride_op(&[2]);
946 y.sum()
947 });
948 assert!(passed, "Stride gradient check failed");
949 }
950
951 #[test]
952 fn test_gradcheck_matmul_vec() {
953 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], true);
955 let w = RawTensor::new(vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6], &[3, 2], false);
956 let passed = RawTensor::check_gradients_simple(&x, |t| {
957 let y = t.matmul(&w);
958 y.sum()
959 });
960 assert!(passed, "Vec-mat matmul gradient check failed");
961 }
962 #[test]
963 fn test_broadcast_3d_fix() {
964 let x = RawTensor::new(vec![10.0, 20.0], &[2, 1], true);
966 let y = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[1, 2, 3], true);
967 let z = x.add(&y);
968
969 assert_eq!(z.borrow().shape, vec![1, 2, 3]);
970 assert_eq!(z.borrow().data, vec![11.0, 12.0, 13.0, 24.0, 25.0, 26.0]);
973
974 z.backward();
        assert_eq!(x.grad(), Some(vec![3.0, 3.0]));
        assert_eq!(y.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
977 }
978
979 #[test]
980 fn test_broadcast_batch_channels() {
981 let x = RawTensor::new((0..16).map(|i| i as f32).collect(), &[2, 2, 2, 2], true);
983 let bias = RawTensor::new(vec![0.1, 0.2], &[2, 1, 1], true);
984 let z = x.add(&bias);
985
986 assert_eq!(z.borrow().shape, vec![2, 2, 2, 2]);
987 let loss = z.sum();
988 loss.backward();
989
990 assert_eq!(bias.grad(), Some(vec![8.0, 8.0]));
992 }
993
994 #[test]
995 fn test_gradcheck_broadcast_3d() {
996 let x = RawTensor::new(vec![1.0, 2.0], &[2, 1], true);
997 let y = RawTensor::new(vec![0.5; 6], &[1, 2, 3], false);
998 let passed = RawTensor::check_gradients_simple(&x, |t| t.add(&y).sum());
999 assert!(passed, "3D broadcast gradcheck failed");
1000 }
1001 #[test]
1002 fn test_sequential_forward() {
1003 let model = Sequential::new(vec![
1004 Box::new(Linear::new(3, 4, true)),
1005 Box::new(ReLU),
1006 Box::new(Linear::new(4, 2, true)),
1007 ]);
1008
1009 let x = RawTensor::new(vec![1.0, 2.0, 3.0], &[1, 3], true);
1010 let y = model.forward(&x);
1011
1012 assert_eq!(y.borrow().shape, vec![1, 2]);
1013
1014 let loss = y.sum();
1015 loss.backward();
1016
1017 for param in model.parameters() {
1019 assert!(param.grad().is_some(), "Missing gradient");
1020 }
1021 }
1022
1023 #[test]
1024 fn test_sequential_zero_grad() {
1025 let mut model = Sequential::new(vec![Box::new(Linear::new(2, 3, true))]);
1026
1027 let x = RawTensor::new(vec![1.0, 2.0], &[1, 2], true);
1028 model.forward(&x).sum().backward();
1029
1030 assert!(model.parameters().first().unwrap().grad().is_some());
1032
1033 model.zero_grad();
1034
1035 for p in model.parameters() {
1037 assert!(p.grad().is_none());
1038 }
1039 }
1040 #[test]
1041 fn test_adam_converges_faster() {
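        // Fit y = 2x on ten points with a single Linear(1, 1) starting from weight 0;
        // Adam at lr 0.5 is expected to push the MSE below 0.01 within 50 steps.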
        let x_data: Vec<f32> = (0..10).map(|i| i as f32 * 0.1).collect();
        let y_data: Vec<f32> = x_data.iter().map(|v| v * 2.0).collect();
1049
1050 let x = RawTensor::new(x_data.clone(), &[10, 1], false);
1051 let y = RawTensor::new(y_data.clone(), &[10, 1], false);
1052
1053 let layer = Linear::new(1, 1, false);
1056 *layer.weight.borrow_mut().data.first_mut().unwrap() = 0.0;
1058
1059 let model = Sequential::new(vec![Box::new(layer)]);
1060
1061 let params = model.parameters();
1062 let mut opt = Adam::new(params, 0.5, (0.9, 0.999), 1e-8, 0.0);
1064
1065 let mut losses = vec![];
1066 for _ in 0..50 {
1067 opt.zero_grad();
1068
1069 let pred = model.forward(&x);
1070 let loss = RawTensor::mse_loss(&pred, &y);
1071 loss.backward();
1072 opt.step();
1073
1074 losses.push(loss.borrow().data.first().copied().unwrap_or(f32::NAN));
1075 }
1076
1077 let final_loss = *losses.last().unwrap();
1078 assert!(
1079 final_loss < 0.01,
1080 "Adam failed simple regression convergence: {:.6}",
1081 final_loss
1082 );
1083 }
1084 #[test]
1085 fn test_adam_vs_sgd() {
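        // Train the same 2-8-1 ReLU MLP on XOR targets with Adam and with plain SGD,
        // then compare the final MSE of the two runs.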
        crate::manual_seed(42);

        fn train_model(use_adam: bool) -> f32 {
1089 let x_data = vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0];
1090 let x = RawTensor::new(x_data, &[4, 2], false);
1091 let y_data = vec![0.0, 1.0, 1.0, 0.0];
1092 let y = RawTensor::new(y_data, &[4], false);
1093
1094 let model = Sequential::new(vec![
1095 Box::new(Linear::new(2, 8, true)),
1096 Box::new(ReLU),
1097 Box::new(Linear::new(8, 1, true)),
1098 ]);
1099
1100 let params = model.parameters();
1101
1102 if use_adam {
1103 let mut opt = Adam::new(params, 0.05, (0.9, 0.999), 1e-8, 0.0);
1104 for _ in 0..50 {
1105 opt.zero_grad();
1106 let pred = model.forward(&x).reshape(&[4]);
1107 let loss = RawTensor::mse_loss(&pred, &y);
1108 loss.backward();
1109 opt.step();
1110 }
1111 } else {
1112 let mut opt = SGD::new(params, 0.01, 0.0, 0.0);
1113 for _ in 0..50 {
1114 opt.zero_grad();
1115 let pred = model.forward(&x).reshape(&[4]);
1116 let loss = RawTensor::mse_loss(&pred, &y);
1117 loss.backward();
1118 opt.step();
1119 }
1120 }
1121
1122 let pred = model.forward(&x).reshape(&[4]);
1124 RawTensor::mse_loss(&pred, &y)
1125 .borrow()
1126 .data
1127 .first()
1128 .copied()
1129 .unwrap_or(f32::NAN)
1130 }
1131
1132 let adam_loss = train_model(true);
1133 let sgd_loss = train_model(false);
1134
1135 println!(
1136 "Adam final loss: {:.6}, SGD final loss: {:.6}",
1137 adam_loss, sgd_loss
1138 );
1139
1140 assert!(adam_loss < sgd_loss * 0.75, "Adam not outperforming SGD");
1142 }
1143 #[test]
1144 fn test_dataloader_iteration() {
1145 let data = (0..16).map(|i| i as f32).collect();
1147 let targets = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
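        // Eight samples of shape [2] with scalar targets, batched three at a time;
        // the trailing flag presumably controls shuffling (disabled here).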
1148
1149 let mut loader = DataLoader::new(
1150 data,
1151 targets,
            &[2],
            &[1],
            3,
            false,
        );
1157
1158 let (x, y) = loader.next().unwrap();
1160 assert_eq!(x.borrow().shape, vec![3, 2]);
1161 assert_eq!(x.borrow().data, vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]);
1162 assert_eq!(y.borrow().shape, vec![3, 1]);
1163
1164 let (x, _y) = loader.next().unwrap();
1166 assert_eq!(x.borrow().shape, vec![3, 2]);
1167
1168 let (x, _y) = loader.next().unwrap();
1170 assert_eq!(x.borrow().shape, vec![2, 2]);
1171
1172 assert!(loader.next().is_none());
1174
1175 loader.reset();
1177 let (x, _) = loader.next().unwrap();
1178 assert_eq!(x.borrow().shape, vec![3, 2]);
1179 }
1180
1181 #[test]
1182 fn test_dataloader_in_training_loop() {
        let data = vec![0.0; 40];
        let targets = vec![1.0; 10];
1185
1186 let model = Sequential::new(vec![Box::new(Linear::new(4, 2, true))]);
1187
1188 let mut opt = SGD::new(model.parameters(), 0.1, 0.0, 0.0);
1189
1190 for epoch in 0..2 {
1191 let loader = DataLoader::new(data.clone(), targets.clone(), &[4], &[1], 3, false);
1192
1193 for (batch_x, _batch_y) in loader {
1194 opt.zero_grad();
1195 let pred = model.forward(&batch_x);
1196 let loss = pred.sum();
1198 loss.backward();
1199 opt.step();
1200 }
1201
1202 println!("Epoch {} complete", epoch);
1203 }
1204 }
1205 #[test]
1206 fn bench_matmul_speedup() {
1207 use std::time::Instant;
1208
1209 let a = vec![1.0; 256 * 256];
1210 let b = vec![1.0; 256 * 256];
1211
1212 let start = Instant::now();
1213 let _ = RawTensor::matmul_raw(&a, &b, 256, 256, 256);
1214 let duration = start.elapsed();
1215
1216 println!("256x256 matmul: {:?}", duration);
1217 let max_duration_ms: u128 = if cfg!(all(feature = "accelerate", target_os = "macos")) {
1218 if cfg!(debug_assertions) { 50 } else { 10 }
1219 } else if cfg!(debug_assertions) {
1220 250
1221 } else {
1222 100
1223 };
1224
1225 assert!(
1226 duration.as_millis() < max_duration_ms,
1227 "Matmul took {:?} (> {}ms threshold for this build configuration)",
1228 duration,
1229 max_duration_ms
1230 );
1231 }
1232 #[test]
1233 fn test_batchnorm_working() {
1234 let mut bn = BatchNorm2d::new(3);
1235 let x = RawTensor::randn(&[2, 3, 4, 4]);
1236
1237 bn.train(true);
1239 let y = bn.forward(&x);
1240 assert_eq!(y.borrow().shape, vec![2, 3, 4, 4]);
1241
1242 bn.train(false);
1244 let y2 = bn.forward(&x);
1245 assert_eq!(y2.borrow().shape, vec![2, 3, 4, 4]);
1246 }
1247
1248 #[test]
1249 fn test_batchnorm1d_forward_shape() {
1250 let bn = BatchNorm1d::new(32);
        let x = RawTensor::randn(&[8, 32]);
        let y = bn.forward(&x);
1253 assert_eq!(y.borrow().shape, vec![8, 32]);
1254 }
1255
1256 #[test]
1257 fn test_batchnorm1d_train_vs_test_mode() {
1258 let mut bn = BatchNorm1d::new(4);
1259
1260 bn.train(true);
1262 for _ in 0..5 {
1263 let x = RawTensor::randn(&[16, 4]);
1264 let _ = bn.forward(&x);
1265 }
1266
1267 let test_input = RawTensor::randn(&[8, 4]);
1269 bn.train(true);
1270 let y_train = bn.forward(&test_input);
1271
1272 bn.train(false);
1273 let y_test = bn.forward(&test_input);
1274
1275 let train_data = &y_train.borrow().data;
1278 let test_data = &y_test.borrow().data;
1279 let differs = train_data
1280 .iter()
1281 .zip(test_data.iter())
1282 .any(|(a, b)| (a - b).abs() > 1e-5);
1283 assert!(differs, "Train and test outputs should differ");
1284 }
1285
1286 #[test]
1287 fn test_batchnorm1d_parameters() {
1288 let bn = BatchNorm1d::new(16);
1289 let params = bn.parameters();
1290 assert_eq!(params.len(), 2);
1292 assert_eq!(params.first().unwrap().borrow().shape, vec![16]);
1294 assert_eq!(params.get(1).unwrap().borrow().shape, vec![16]);
1296 }
1297
1298 #[test]
1299 fn test_pixelshuffle_forward_shape() {
1300 let layer = PixelShuffle::new(3);
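        // PixelShuffle(r) maps [N, C*r^2, H, W] to [N, C, H*r, W*r]; here 36 channels
        // with r = 3 become 4 channels at triple the spatial resolution.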
        let x = RawTensor::randn(&[2, 36, 4, 4]);
        let y = layer.forward(&x);
1304 assert_eq!(y.borrow().shape, vec![2, 4, 12, 12]);
1305
1306 let layer2 = PixelShuffle::new(2);
        let x2 = RawTensor::randn(&[1, 12, 8, 8]);
        let y2 = layer2.forward(&x2);
1310 assert_eq!(y2.borrow().shape, vec![1, 3, 16, 16]);
1311 }
1312
1313 #[test]
1314 fn test_pixelshuffle_backward_flow() {
1315 let layer = PixelShuffle::new(2);
        let x = RawTensor::randn(&[2, 4, 3, 3]);
        x.borrow_mut().requires_grad = true;
1318
1319 let y = layer.forward(&x);
1320 assert_eq!(y.borrow().shape, vec![2, 1, 6, 6]);
1321
1322 let loss = y.sum();
1323 loss.backward();
1324
1325 let grad = x.grad();
1326 assert!(
1327 grad.is_some(),
1328 "Gradient should flow back through PixelShuffle"
1329 );
1330 assert_eq!(grad.unwrap().len(), 72);
1332 }
1333
1334 #[test]
1335 fn test_pixelshuffle_values() {
1336 let layer = PixelShuffle::new(2);
1340 #[rustfmt::skip]
1341 let data = vec![
1342 1.0, 2.0,
1344 3.0, 4.0,
1345 5.0, 6.0,
1347 7.0, 8.0,
1348 9.0, 10.0,
1350 11.0, 12.0,
1351 13.0, 14.0,
1353 15.0, 16.0,
1354 ];
1355 let x = RawTensor::new(data, &[1, 4, 2, 2], false);
1356 let y = layer.forward(&x);
1357
1358 assert_eq!(y.borrow().shape, vec![1, 1, 4, 4]);
1359
1360 let out_data = &y.borrow().data;
1363 assert_eq!(out_data.len(), 16);
1364
1365 let input_sum: f32 = (1..=16).map(|x| x as f32).sum();
1367 let output_sum: f32 = out_data.iter().sum();
1368 assert!((input_sum - output_sum).abs() < 1e-5);
1369 }
1370
1371 #[test]
1372 fn test_embedding_forward_shape() {
1373 let embedding = Embedding::new(100, 32);
1374 let indices = vec![5, 12, 7, 99];
1375 let output = embedding.forward(&indices);
1376 assert_eq!(output.borrow().shape, vec![4, 32]);
1377 }
1378
1379 #[test]
1380 fn test_embedding_backward_flow() {
1381 let embedding = Embedding::new(50, 16);
        let indices = vec![3, 10, 3];
        let output = embedding.forward(&indices);
1384
1385 let loss = output.sum();
1387 loss.backward();
1388
1389 let grad = embedding.weight.grad();
1391 assert!(grad.is_some(), "Weight should have gradients");
1392
1393 let grad_data = grad.unwrap();
1394 let grad_at_idx3_sum: f32 = (0..16)
1397 .map(|d| grad_data.get(3 * 16 + d).copied().unwrap_or(f32::NAN))
1398 .sum();
        // Index 3 appears twice and each of its 16 dims receives grad 1.0, so the sum is 2 * 16.
        let expected_sum = 2.0 * 16.0;
        assert!(
1401 (grad_at_idx3_sum - expected_sum).abs() < 1e-4,
1402 "Expected accumulated grad sum {}, got {}",
1403 expected_sum,
1404 grad_at_idx3_sum
1405 );
1406 }
1407
1408 #[test]
1409 fn test_embedding_gradient_accumulation() {
1410 let embedding = Embedding::new(10, 4);
        let indices = vec![2, 5, 2, 2];
        let output = embedding.forward(&indices);
1413
1414 let loss = output.sum();
1415 loss.backward();
1416
1417 let grad = embedding.weight.grad().unwrap();
1418 for d in 0..4 {
1420 let grad_val = grad.get(2 * 4 + d).copied().unwrap_or(f32::NAN);
1421 assert!(
1422 (grad_val - 3.0).abs() < 1e-5,
1423 "Expected grad 3.0 for index 2, got {}",
1424 grad_val
1425 );
1426 }
1427
1428 for d in 0..4 {
1430 let grad_val = grad.get(5 * 4 + d).copied().unwrap_or(f32::NAN);
1431 assert!(
1432 (grad_val - 1.0).abs() < 1e-5,
1433 "Expected grad 1.0 for index 5, got {}",
1434 grad_val
1435 );
1436 }
1437 }
1438}
1439
1440#[cfg(test)]
1441mod axis_reduce_tests {
1442 use super::*;
1443
1444 #[test]
1445 fn test_sum_dim_basic() {
1446 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
1448 let y = RawTensor::sum_dim(&x, 1, false);
1449
1450 assert_eq!(y.borrow().shape, vec![2]);
        assert_eq!(y.borrow().data, vec![6.0, 15.0]);
    }
1453
1454 #[test]
1455 fn test_sum_dim_keepdim() {
1456 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], false);
1457 let y = RawTensor::sum_dim(&x, 0, true);
1458
1459 assert_eq!(y.borrow().shape, vec![1, 2]);
        assert_eq!(y.borrow().data, vec![4.0, 6.0]);
    }
1462
1463 #[test]
1464 fn test_sum_dim_backward() {
1465 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
        let y = RawTensor::sum_dim(&x, 1, false);
        y.backward();
1468
1469 assert_eq!(x.grad(), Some(vec![1.0, 1.0, 1.0, 1.0, 1.0, 1.0]));
1471 }
1472
1473 #[test]
1474 fn test_max_dim_basic() {
1475 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0, 8.0, 4.0], &[2, 3], false);
1476 let y = RawTensor::max_dim(&x, 1, false);
1477
1478 assert_eq!(y.borrow().shape, vec![2]);
        assert_eq!(y.borrow().data, vec![5.0, 8.0]);
    }
1481
1482 #[test]
1483 fn test_max_dim_backward() {
1484 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0, 8.0, 4.0], &[2, 3], true);
1485 let y = RawTensor::max_dim(&x, 1, false);
1486 y.backward();
1487
1488 assert_eq!(x.grad(), Some(vec![0.0, 1.0, 0.0, 0.0, 1.0, 0.0]));
1490 }
1491
1492 #[test]
1493 fn test_gradcheck_sum_dim() {
1494 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
1495 let passed =
1496 RawTensor::check_gradients_simple(&x, |t| RawTensor::sum_dim(t, 0, false).sum());
1497 assert!(passed, "sum_dim gradient check failed");
1498 }
1499
1500 #[test]
1501 fn test_gradcheck_max_dim() {
1502 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 2.0], &[2, 2], true);
1503 let passed =
1504 RawTensor::check_gradients_simple(&x, |t| RawTensor::max_dim(t, 1, false).sum());
1505 assert!(passed, "max_dim gradient check failed");
1506 }
1507
1508 #[test]
1509 fn test_softmax_forward() {
1510 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
1512 let y = RawTensor::softmax(&x, 1);
1513
1514 let data = y.borrow();
1516 let row0_sum: f32 = data.data.get(0..3).unwrap().iter().sum();
1517 let row1_sum: f32 = data.data.get(3..6).unwrap().iter().sum();
1518
1519 approx::assert_relative_eq!(row0_sum, 1.0, epsilon = 1e-6);
1520 approx::assert_relative_eq!(row1_sum, 1.0, epsilon = 1e-6);
1521 }
1522
1523 #[test]
1524 fn test_gradcheck_softmax() {
1525 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], true);
1526 let passed = RawTensor::check_gradients_simple(&x, |t| RawTensor::softmax(t, 1).sum());
1527 assert!(passed, "Softmax gradient check failed");
1528 }
1529
1530 #[test]
1531 fn test_cross_entropy_loss() {
1532 let logits = RawTensor::new(vec![2.0, 1.0, 0.5, 2.5], &[2, 2], true);
1534 let targets = RawTensor::new(vec![1.0, 0.0, 0.0, 1.0], &[2, 2], false);
1535
1536 let loss = RawTensor::cross_entropy_loss(&logits, &targets);
1537 loss.backward();
1538
1539 assert_eq!(loss.borrow().shape, vec![1]);
1541 assert!(loss.borrow().data.first().copied().unwrap_or(f32::NAN) > 0.0);
1542
1543 assert_eq!(logits.grad().unwrap().len(), 4);
1545 }
1546 #[test]
1547 fn test_dropout_train_eval() {
1548 let mut dropout = Dropout::new(0.5);
1549 let x = RawTensor::ones(&[1000]);
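        // With p = 0.5 in train mode, roughly half the inputs are zeroed and the
        // survivors are scaled by 1 / (1 - p) = 2.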
1550
1551 dropout.train(true);
1553 let y = dropout.forward(&x);
1554 let y_data = &y.borrow().data;
1555 let num_zeros = y_data.iter().filter(|&&v| v == 0.0).count();
1556
1557 assert!(
1559 num_zeros > 400 && num_zeros < 600,
1560 "Dropout ratio off: {}",
1561 num_zeros
1562 );
1563
1564 let non_zeros_correct = y_data.iter().all(|&v| v == 0.0 || v == 2.0);
1566 assert!(non_zeros_correct, "Dropout scaling incorrect");
1567
1568 dropout.eval();
1570 let y_eval = dropout.forward(&x);
1571 let eval_correct = y_eval.borrow().data.iter().all(|&v| v == 1.0);
1572 assert!(eval_correct, "Dropout eval mode should be identity");
1573 }
1574
1575 #[test]
1576 fn test_weight_decay_sgd() {
1577 let w = RawTensor::new(vec![1.0], &[1], true);
1578 let mut opt = SGD::new(vec![w.clone()], 0.1, 0.0, 0.1);
1581
        w.borrow_mut().grad = Some(Storage::cpu(vec![0.0]));
        opt.step();
1584
1585 let new_val = w.borrow().data.first().copied().unwrap_or(f32::NAN);
1586 approx::assert_relative_eq!(new_val, 0.99, epsilon = 1e-6);
1587 }
1588
1589 #[test]
1590 fn test_mean_dim() {
1591 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
1595
1596 let m = x.mean_dim(1, false);
1598 assert_eq!(m.borrow().shape, vec![2]);
1599 assert!((m.borrow().data.first().copied().unwrap_or(f32::NAN) - 2.0).abs() < 1e-6);
1600 assert!((m.borrow().data.get(1).copied().unwrap_or(f32::NAN) - 5.0).abs() < 1e-6);
1601
1602 m.sum().backward();
1604 let grads = x.grad().unwrap();
1606 for g in grads {
1607 assert!((g - 1.0 / 3.0).abs() < 1e-6);
1608 }
1609 }
1610
1611 #[test]
1620 fn test_external_model_loading_integration() {
1621 use crate::io::{TensorData, load_state_dict, mapping::StateDictMapper, save_state_dict};
1622 use crate::nn::{Linear, Module, ReLU, Sequential};
1623 use std::collections::BTreeMap;
1624
1625 let mut pytorch_state = BTreeMap::new();
1628
1629 pytorch_state.insert(
1631 "fc1.weight".to_string(),
1632 TensorData {
1633 data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
1635 shape: vec![3, 2],
1636 },
1637 );
1638 pytorch_state.insert(
1639 "fc1.bias".to_string(),
1640 TensorData {
1641 data: vec![0.1, 0.2, 0.3],
1642 shape: vec![3],
1643 },
1644 );
1645
1646 pytorch_state.insert(
1648 "fc2.weight".to_string(),
1649 TensorData {
1650 data: vec![0.5, 0.6, 0.7],
1651 shape: vec![1, 3],
1652 },
1653 );
1654 pytorch_state.insert(
1655 "fc2.bias".to_string(),
1656 TensorData {
1657 data: vec![0.01],
1658 shape: vec![1],
1659 },
1660 );
1661
1662 let temp_path = std::env::temp_dir().join("test_pytorch_model.bin");
1664 save_state_dict(&pytorch_state, temp_path.to_str().unwrap()).unwrap();
1665
1666 let loaded = load_state_dict(temp_path.to_str().unwrap()).unwrap();
1668
1669 let mapper = StateDictMapper::new()
1671 .rename("fc1.weight", "encoder.weight")
1672 .rename("fc1.bias", "encoder.bias")
1673 .rename("fc2.weight", "decoder.weight")
1674 .rename("fc2.bias", "decoder.bias")
            .transpose("encoder.weight")
            .transpose("decoder.weight");
        let volta_state = mapper.map(loaded);
1679
1680 assert!(volta_state.contains_key("encoder.weight"));
1682 assert!(volta_state.contains_key("decoder.weight"));
1683 assert_eq!(volta_state.get("encoder.weight").unwrap().shape, vec![2, 3]);
1684 assert_eq!(volta_state.get("decoder.weight").unwrap().shape, vec![3, 1]);
1685
1686 let encoder_weight = &volta_state.get("encoder.weight").unwrap().data;
1690 assert_eq!(encoder_weight.first().copied().unwrap_or(f32::NAN), 1.0);
1691 assert_eq!(encoder_weight.get(1).copied().unwrap_or(f32::NAN), 3.0);
1692 assert_eq!(encoder_weight.get(2).copied().unwrap_or(f32::NAN), 5.0);
1693 assert_eq!(encoder_weight.get(3).copied().unwrap_or(f32::NAN), 2.0);
1694 assert_eq!(encoder_weight.get(4).copied().unwrap_or(f32::NAN), 4.0);
1695 assert_eq!(encoder_weight.get(5).copied().unwrap_or(f32::NAN), 6.0);
1696
1697 let mut model = Sequential::builder()
1699 .add_named("encoder", Box::new(Linear::new(2, 3, true)))
1700 .add_unnamed(Box::new(ReLU))
1701 .add_named("decoder", Box::new(Linear::new(3, 1, true)))
1702 .build();
1703
1704 model.load_state_dict(&volta_state);
1706
1707 let input = RawTensor::new(vec![1.0, 1.0], &[1, 2], false);
1709 let output = model.forward(&input);
1710
1711 assert_eq!(output.borrow().shape, vec![1, 1]);
1713
1714 assert!(model.get_named("encoder").is_some());
1716 assert!(model.get_named("decoder").is_some());
1717 assert!(model.get_named("nonexistent").is_none());
1718
1719 let names = model.layer_names();
1721 assert_eq!(names.first().copied().unwrap_or(None), Some("encoder"));
        assert_eq!(names.get(1).copied().unwrap_or(None), None);
        assert_eq!(names.get(2).copied().unwrap_or(None), Some("decoder"));
1724 }
1725}
1726
1727#[cfg(test)]
1728mod gpu_tests {
1729 use super::*;
1730 use approx::assert_abs_diff_eq;
1731
1732 #[test]
1733 fn test_device_gpu_returns_none_when_disabled() {
1734 let gpu = Device::gpu();
1736 if cfg!(feature = "gpu") {
1737 if is_gpu_available() {
1740 assert!(gpu.is_some());
1741 assert!(gpu.unwrap().is_gpu());
1742 }
1743 } else {
1744 assert!(gpu.is_none());
1746 }
1747 }
1748
1749 #[test]
1750 fn test_to_device_cpu_to_cpu() {
1751 let t = RawTensor::new(vec![1.0, 2.0, 3.0], &[3], false);
1752 let t_cpu = t.to_device(Device::CPU);
1753
1754 assert_eq!(t_cpu.borrow().device, Device::CPU);
1756 assert_eq!(t_cpu.borrow().data.to_vec(), vec![1.0, 2.0, 3.0]);
1757 }
1758
1759 #[cfg(feature = "gpu")]
1760 #[test]
1761 fn test_to_device_cpu_to_gpu() {
1762 if !is_gpu_available() {
1763 return; }
1765
1766 let t = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2], false);
1767 let gpu_device = Device::gpu().expect("GPU should be available");
1768 let t_gpu = t.to_device(gpu_device.clone());
1769
1770 assert!(t_gpu.borrow().device.is_gpu());
1772 assert_eq!(t_gpu.borrow().device.name(), gpu_device.name());
1773
1774 assert_eq!(t_gpu.borrow().data.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
1776 assert_eq!(t_gpu.borrow().shape, vec![2, 2]);
1777 }
1778
1779 #[cfg(feature = "gpu")]
1780 #[test]
1781 fn test_to_device_gpu_to_cpu() {
1782 if !is_gpu_available() {
1783 return; }
1785
1786 let gpu_device = Device::gpu().expect("GPU should be available");
1787 let t = RawTensor::new(vec![5.0, 6.0, 7.0], &[3], false);
1788 let t_gpu = t.to_device(gpu_device.clone());
1789
1790 let t_cpu = t_gpu.to_device(Device::CPU);
1792
1793 assert!(t_cpu.borrow().device.is_cpu());
1794 assert_eq!(t_cpu.borrow().data.to_vec(), vec![5.0, 6.0, 7.0]);
1795 }
1796
1797 #[cfg(feature = "gpu")]
1798 #[test]
1799 fn test_to_device_preserves_autograd_metadata() {
1800 if !is_gpu_available() {
1801 return; }
1803
1804 let a = RawTensor::new(vec![2.0], &[1], true);
1806 let b = RawTensor::new(vec![3.0], &[1], true);
1807 let c = a.add(&b);
1808
1809 let gpu_device = Device::gpu().expect("GPU should be available");
1811 let c_gpu = c.to_device(gpu_device);
1812
1813 assert!(c_gpu.borrow().requires_grad);
1815 assert!(!c_gpu.borrow().parents.is_empty());
1816 assert!(c_gpu.borrow().grad_fn.is_some());
1817
1818 }
1821
1822 #[cfg(feature = "gpu")]
1823 #[test]
1824 fn test_matmul_backward_gpu() {
1825 if !is_gpu_available() {
1826 return; }
1828
1829 let gpu_device = Device::gpu().expect("GPU should be available");
1830
1831 let a = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
1833 .to_device(gpu_device.clone());
1834 let b = RawTensor::new(
1835 vec![
1836 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1837 ],
1838 &[3, 4],
1839 true,
1840 )
1841 .to_device(gpu_device.clone());
1842 let c = a.matmul(&b);
1843
1844 assert!(c.borrow().device.is_gpu());
1846
1847 c.backward();
1849
1850 {
1852 let a_ref = a.borrow();
1853 let b_ref = b.borrow();
1854 let a_grad = a_ref.grad.as_ref().expect("a should have grad");
1855 let b_grad = b_ref.grad.as_ref().expect("b should have grad");
1856 assert!(a_grad.is_gpu());
1857 assert!(b_grad.is_gpu());
1858 }
1859
1860 let a_cpu = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true);
1862 let b_cpu = RawTensor::new(
1863 vec![
1864 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1865 ],
1866 &[3, 4],
1867 true,
1868 );
1869 let c_cpu = a_cpu.matmul(&b_cpu);
1870 c_cpu.backward();
1871
1872 let a_grad_data;
1873 let b_grad_data;
1874 {
1875 let a_ref = a.borrow();
1876 let b_ref = b.borrow();
1877 a_grad_data = a_ref.grad.as_ref().unwrap().to_vec();
1878 b_grad_data = b_ref.grad.as_ref().unwrap().to_vec();
1879 }
1880
1881 let a_grad_cpu_data;
1882 let b_grad_cpu_data;
1883 {
1884 let a_ref = a_cpu.borrow();
1885 let b_ref = b_cpu.borrow();
1886 a_grad_cpu_data = a_ref.grad.as_ref().unwrap().to_vec();
1887 b_grad_cpu_data = b_ref.grad.as_ref().unwrap().to_vec();
1888 }
1889
1890 assert_eq!(a_grad_data.len(), a_grad_cpu_data.len());
1891 assert_eq!(b_grad_data.len(), b_grad_cpu_data.len());
1892
1893 for (gpu_val, cpu_val) in a_grad_data.iter().zip(a_grad_cpu_data.iter()) {
1895 assert!((gpu_val - cpu_val).abs() < 1e-5);
1896 }
1897 for (gpu_val, cpu_val) in b_grad_data.iter().zip(b_grad_cpu_data.iter()) {
1898 assert!((gpu_val - cpu_val).abs() < 1e-5);
1899 }
1900 }
1901
1902 #[cfg(feature = "gpu")]
1903 #[test]
1904 fn test_linear_layer_backward_gpu() {
1905 if !is_gpu_available() {
1906 return; }
1908
1909 use crate::nn::{Linear, Module};
1910
1911 let gpu_device = Device::gpu().expect("GPU should be available");
1912
1913 let layer = Linear::new(4, 3, true);
1914
1915 let params = layer.parameters();
        for param in &params {
1918 let p = RawTensor::to_device(param, gpu_device.clone());
1919 *param.borrow_mut() = p.borrow().clone();
1920 }
1921
1922 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4], true)
1923 .to_device(gpu_device);
1924 let out = layer.forward(&x);
1925 let loss = out.sum();
1926
1927 loss.backward();
1928
1929 let params = layer.parameters();
1931 assert!(!params.is_empty());
1932
1933 for param in params {
1934 let param_ref = param.borrow();
            if let Some(grad) = &param_ref.grad {
1936 assert!(grad.is_gpu(), "Gradient should be on GPU");
1937 }
1938 }
1939 }
1940
1941 #[cfg(feature = "gpu")]
1942 #[test]
1943 fn test_sum_backward_gpu() {
1944 if !is_gpu_available() {
1945 return;
1946 }
1947
1948 let gpu_device = Device::gpu().expect("GPU should be available");
1949
1950 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
1951 .to_device(gpu_device.clone());
1952 let sum_result = x.sum();
1953
1954 assert!(sum_result.borrow().device.is_gpu());
1956
1957 sum_result.backward();
1959
1960 let grad_data;
1962 {
1963 let x_ref = x.borrow();
1964 let x_grad = x_ref.grad.as_ref().expect("x should have grad");
1965 assert!(x_grad.is_gpu());
1966 grad_data = x_grad.to_vec();
1967 }
1968
1969 assert_eq!(grad_data.len(), 6);
1970 for &val in &grad_data {
1971 assert!((val - 1.0).abs() < 1e-5);
1972 }
1973 }
1974
1975 #[cfg(feature = "gpu")]
1976 #[test]
1977 fn test_mean_backward_gpu() {
1978 if !is_gpu_available() {
1979 return;
1980 }
1981
1982 let gpu_device = Device::gpu().expect("GPU should be available");
1983
1984 let x = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], true)
1985 .to_device(gpu_device.clone());
1986 let mean_result = x.mean();
1987
1988 assert!(mean_result.borrow().device.is_gpu());
1990
1991 mean_result.backward();
1993
1994 let grad_data;
1996 {
1997 let x_ref = x.borrow();
1998 let x_grad = x_ref.grad.as_ref().expect("x should have grad");
1999 assert!(x_grad.is_gpu());
2000 grad_data = x_grad.to_vec();
2001 }
2002
2003 assert_eq!(grad_data.len(), 6);
2004 let expected = 1.0 / 6.0;
2005 for &val in &grad_data {
2006 assert!((val - expected).abs() < 1e-5);
2007 }
2008 }
2009
2010 #[cfg(feature = "gpu")]
2011 #[test]
2012 fn test_max_backward_gpu() {
2013 if !is_gpu_available() {
2014 return;
2015 }
2016
2017 let gpu_device = Device::gpu().expect("GPU should be available");
2018
2019 let x = RawTensor::new(vec![1.0, 5.0, 3.0, 4.0, 2.0, 6.0], &[2, 3], true)
2020 .to_device(gpu_device.clone());
2021 let max_result = x.max_reduce();
2022
2023 assert!(max_result.borrow().device.is_gpu());
2025
2026 max_result.backward();
2028
2029 let grad_data;
2031 {
2032 let x_ref = x.borrow();
2033 let x_grad = x_ref.grad.as_ref().expect("x should have grad");
2034 assert!(x_grad.is_gpu());
2035 grad_data = x_grad.to_vec();
2036 }
2037
2038 assert_eq!(grad_data.len(), 6);
2039
2040 for (i, &val) in grad_data.iter().enumerate() {
2042 if i == 5 {
2043 assert!((val - 1.0).abs() < 1e-5);
2044 } else {
2045 assert!(val.abs() < 1e-5);
2046 }
2047 }
2048 }
2049
2050 #[cfg(feature = "gpu")]
2051 #[test]
2052 fn test_reduction_backward_gpu_cpu_equivalence() {
2053 if !is_gpu_available() {
2054 return;
2055 }
2056
2057 let gpu_device = Device::gpu().expect("GPU should be available");
2058
2059 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
2060 let shape = &[2, 4];
2061
2062 let x_cpu = RawTensor::new(data.clone(), shape, true);
2064 let sum_cpu = x_cpu.sum();
2065 sum_cpu.backward();
2066 let sum_grad_cpu = x_cpu.borrow().grad.as_ref().unwrap().to_vec();
2067
2068 let x_gpu = RawTensor::new(data.clone(), shape, true).to_device(gpu_device.clone());
2069 let sum_gpu = x_gpu.sum();
2070 sum_gpu.backward();
2071 let sum_grad_gpu = x_gpu.borrow().grad.as_ref().unwrap().to_vec();
2072
2073 assert_eq!(sum_grad_cpu.len(), sum_grad_gpu.len());
2074 for (cpu_val, gpu_val) in sum_grad_cpu.iter().zip(sum_grad_gpu.iter()) {
2075 assert!((cpu_val - gpu_val).abs() < 1e-5);
2076 }
2077
2078 let x_cpu2 = RawTensor::new(data.clone(), shape, true);
2080 let mean_cpu = x_cpu2.mean();
2081 mean_cpu.backward();
2082 let mean_grad_cpu = x_cpu2.borrow().grad.as_ref().unwrap().to_vec();
2083
2084 let x_gpu2 = RawTensor::new(data.clone(), shape, true).to_device(gpu_device.clone());
2085 let mean_gpu = x_gpu2.mean();
2086 mean_gpu.backward();
2087 let mean_grad_gpu = x_gpu2.borrow().grad.as_ref().unwrap().to_vec();
2088
2089 assert_eq!(mean_grad_cpu.len(), mean_grad_gpu.len());
2090 for (cpu_val, gpu_val) in mean_grad_cpu.iter().zip(mean_grad_gpu.iter()) {
2091 assert!((cpu_val - gpu_val).abs() < 1e-5);
2092 }
2093
2094 let x_cpu3 = RawTensor::new(data.clone(), shape, true);
2096 let max_cpu = x_cpu3.max_reduce();
2097 max_cpu.backward();
2098 let max_grad_cpu = x_cpu3.borrow().grad.as_ref().unwrap().to_vec();
2099
2100 let x_gpu3 = RawTensor::new(data.clone(), shape, true).to_device(gpu_device);
2101 let max_gpu = x_gpu3.max_reduce();
2102 max_gpu.backward();
2103 let max_grad_gpu = x_gpu3.borrow().grad.as_ref().unwrap().to_vec();
2104
2105 assert_eq!(max_grad_cpu.len(), max_grad_gpu.len());
2106 for (cpu_val, gpu_val) in max_grad_cpu.iter().zip(max_grad_gpu.iter()) {
2107 assert!((cpu_val - gpu_val).abs() < 1e-5);
2108 }
2109 }
2110
2111 #[cfg(feature = "gpu")]
2112 #[test]
2113 fn test_sgd_optimizer_gpu() {
2114 use crate::nn::optim::SGD;
2115
2116 if !is_gpu_available() {
2117 return;
2118 }
2119
2120 let gpu_device = Device::gpu().expect("GPU should be available");
2121
2122 let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());
2124
2125 let mut opt = SGD::new(vec![param.clone()], 0.01, 0.0, 0.0);
2127
2128 {
2130 let mut p = param.borrow_mut();
2131 let grad_data = vec![0.1; 10];
2132 p.grad = Some(Storage::gpu(grad_data));
2133 }
2134
2135 let param_before = param.borrow().data.to_vec();
2137 opt.step();
2138 let param_after = param.borrow().data.to_vec();
2139
2140 for (before, after) in param_before.iter().zip(param_after.iter()) {
2142 assert!((after - (before - 0.001)).abs() < 1e-5);
2143 }
2144 }
2145
2146 #[cfg(feature = "gpu")]
2147 #[test]
2148 fn test_adam_optimizer_gpu() {
2149 use crate::nn::optim::Adam;
2150
2151 if !is_gpu_available() {
2152 return;
2153 }
2154
2155 let gpu_device = Device::gpu().expect("GPU should be available");
2156
2157 let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());
2159
2160 let mut opt = Adam::new(vec![param.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
2162
2163 {
2165 let mut p = param.borrow_mut();
2166 let grad_data = vec![0.1; 10];
2167 p.grad = Some(Storage::gpu(grad_data));
2168 }
2169
2170 let param_before = param.borrow().data.to_vec();
2172 opt.step();
2173 let param_after = param.borrow().data.to_vec();
2174
2175 for (before, after) in param_before.iter().zip(param_after.iter()) {
2177 assert_ne!(before, after, "Parameter should change after Adam step");
2178 }
2179 }
2180
2181 #[cfg(feature = "gpu")]
2182 #[test]
2183 fn test_sgd_momentum_optimizer_gpu() {
2184 use crate::nn::optim::SGD;
2185
2186 if !is_gpu_available() {
2187 return;
2188 }
2189
2190 let gpu_device = Device::gpu().expect("GPU should be available");
2191
2192 let param = RawTensor::new(vec![0.5; 10], &[10], true).to_device(gpu_device.clone());
2194
2195 let mut opt = SGD::new(vec![param.clone()], 0.01, 0.9, 0.0);
2197
2198 {
2200 let mut p = param.borrow_mut();
2201 let grad_data = vec![0.1; 10];
2202 p.grad = Some(Storage::gpu(grad_data));
2203 }
2204
2205 let param_before = param.borrow().data.to_vec();
2207 opt.step();
2208 let param_after_step1 = param.borrow().data.to_vec();
2209
2210 for (before, after) in param_before.iter().zip(param_after_step1.iter()) {
2212 assert_ne!(before, after);
2213 }
2214
2215 {
2217 let mut p = param.borrow_mut();
2218 p.grad = Some(Storage::gpu(vec![0.1; 10]));
2219 }
2220
2221 let param_after_step2;
2223 {
2224 let p = param.borrow();
2225 param_after_step2 = p.data.to_vec();
2226 }
2227
2228 let changes_step1: Vec<f32> = param_after_step1
2231 .iter()
2232 .zip(param_before.iter())
2233 .map(|(a, b)| a - b)
2234 .collect();
2235 let changes_step2: Vec<f32> = param_after_step2
2236 .iter()
2237 .zip(param_after_step1.iter())
2238 .map(|(a, b)| a - b)
2239 .collect();
2240
2241 for (c1, c2) in changes_step1.iter().zip(changes_step2.iter()) {
2243 assert_ne!(c1, c2);
2244 }
2245 }
2246
2247 #[cfg(feature = "gpu")]
2248 #[test]
2249 fn test_optimizer_gpu_cpu_equivalence() {
2250 use crate::nn::optim::Adam;
2251
2252 if !is_gpu_available() {
2253 return;
2254 }
2255
2256 let gpu_device = Device::gpu().expect("GPU should be available");
2257
2258 let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
2259 let grad_data = vec![0.1, 0.2, 0.3, 0.4, 0.5];
2260
2261 let param_cpu = RawTensor::new(data.clone(), &[5], true);
2263 let mut opt_cpu = Adam::new(vec![param_cpu.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
2264 {
2265 let mut p = param_cpu.borrow_mut();
2266 p.grad = Some(Storage::cpu(grad_data.clone()));
2267 }
2268
2269 let param_gpu = RawTensor::new(data, &[5], true).to_device(gpu_device.clone());
2271 let mut opt_gpu = Adam::new(vec![param_gpu.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
2272 {
2273 let mut p = param_gpu.borrow_mut();
2274 p.grad = Some(Storage::gpu(grad_data));
2276 }
2277
2278 opt_cpu.step();
2280 opt_gpu.step();
2281
2282 let result_cpu = param_cpu.borrow().data.to_vec();
2284 let result_gpu = param_gpu.borrow().data.to_vec();
2285
2286 assert_eq!(result_cpu.len(), result_gpu.len());
2287 for (cpu_val, gpu_val) in result_cpu.iter().zip(result_gpu.iter()) {
2288 assert!(
2289 (cpu_val - gpu_val).abs() < 1e-4,
2290 "CPU={cpu_val}, GPU={gpu_val}"
2291 );
2292 }
2293 }
2294
2295 #[cfg(feature = "gpu")]
2296 #[test]
2297 fn test_optimizer_state_stays_on_gpu() {
2298 use crate::nn::optim::Adam;
2299
2300 if !is_gpu_available() {
2301 return;
2302 }
2303
2304 let gpu_device = Device::gpu().expect("GPU should be available");
2305
2306 let param = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true).to_device(gpu_device);
2308
2309 let mut opt = Adam::new(vec![param.clone()], 0.01, (0.9, 0.999), 1e-8, 0.0);
2311
2312 {
2318 let mut p = param.borrow_mut();
2319 p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
2320 }
2321
2322 for _ in 0..5 {
2324 {
2326 let mut p = param.borrow_mut();
2327 p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
2328 }
2329 opt.step();
2330 }
2331
2332 let result = param.borrow().data.to_vec();
2338 for val in result.iter() {
2339 assert_ne!(*val, 0.0, "Parameter should have been updated");
2340 }
2341 }
2342
2343 #[cfg(feature = "gpu")]
2344 #[test]
2345 fn test_sgd_momentum_state_stays_on_gpu() {
2346 use crate::nn::optim::SGD;
2347
2348 if !is_gpu_available() {
2349 return;
2350 }
2351
2352 let gpu_device = Device::gpu().expect("GPU should be available");
2353
2354 let param = RawTensor::new(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5], true).to_device(gpu_device);
2356
2357 let mut opt = SGD::new(vec![param.clone()], 0.01, 0.9, 0.0);
2359
2360 for _ in 0..5 {
2362 {
2363 let mut p = param.borrow_mut();
2364 p.grad = Some(Storage::gpu(vec![0.1, 0.2, 0.3, 0.4, 0.5]));
2365 }
2366 opt.step();
2367 }
2368
2369 let result = param.borrow().data.to_vec();
2371 for val in result.iter() {
2372 assert_ne!(*val, 0.0, "Parameter should have been updated");
2373 }
2374 }
2375
2376 #[cfg(feature = "gpu")]
2377 #[test]
2378 fn test_gpu_binary_backward_broadcast_add() {
2379 if !is_gpu_available() {
2380 return;
2381 }
2382
2383 let gpu_device = Device::gpu().expect("GPU should be available");
2384
2385 let a = RawTensor::new(vec![1.0, 2.0, 3.0], &[3, 1], true).to_device(gpu_device.clone());
2389 let b = RawTensor::new(vec![10.0, 20.0, 30.0, 40.0], &[1, 4], true)
2390 .to_device(gpu_device.clone());
2391
2392 let c = a.add(&b);
2393 c.backward();
2394
2395 assert!(a.grad().is_some());
2397 assert!(b.grad().is_some());
2398
2399 let a_grad = a.grad().unwrap();
2400 let b_grad = b.grad().unwrap();
2401
2402 assert_abs_diff_eq!(
2404 a_grad.first().copied().unwrap_or(f32::NAN),
2405 4.0,
2406 epsilon = 1e-3
2407 );
2408 assert_abs_diff_eq!(
2409 a_grad.get(1).copied().unwrap_or(f32::NAN),
2410 4.0,
2411 epsilon = 1e-3
2412 );
2413 assert_abs_diff_eq!(
2414 a_grad.get(2).copied().unwrap_or(f32::NAN),
2415 4.0,
2416 epsilon = 1e-3
2417 );
2418
2419 assert_abs_diff_eq!(
2421 b_grad.first().copied().unwrap_or(f32::NAN),
2422 3.0,
2423 epsilon = 1e-3
2424 );
2425 assert_abs_diff_eq!(
2426 b_grad.get(1).copied().unwrap_or(f32::NAN),
2427 3.0,
2428 epsilon = 1e-3
2429 );
2430 assert_abs_diff_eq!(
2431 b_grad.get(2).copied().unwrap_or(f32::NAN),
2432 3.0,
2433 epsilon = 1e-3
2434 );
2435 assert_abs_diff_eq!(
2436 b_grad.get(3).copied().unwrap_or(f32::NAN),
2437 3.0,
2438 epsilon = 1e-3
2439 );
2440 }
2441
2442 #[cfg(feature = "gpu")]
2443 #[test]
2444 fn test_gpu_binary_backward_broadcast_mul() {
2445 if !is_gpu_available() {
2446 return;
2447 }
2448
2449 let gpu_device = Device::gpu().expect("GPU should be available");
2450
2451 let a = RawTensor::new(vec![1.0, 2.0], &[2, 1], true).to_device(gpu_device.clone());
2453 let b = RawTensor::new(vec![10.0, 20.0, 30.0], &[1, 3], true).to_device(gpu_device.clone());
2454
2455 let c = a.elem_mul(&b);
2456 c.backward();
2457
2458 assert!(a.grad().is_some());
2460 assert!(b.grad().is_some());
2461
2462 let a_grad = a.grad().unwrap();
2463 let b_grad = b.grad().unwrap();
2464
2465 for g in a_grad.iter() {
2467 assert!(g.is_finite(), "a_grad contains non-finite value");
2468 }
2469 for g in b_grad.iter() {
2470 assert!(g.is_finite(), "b_grad contains non-finite value");
2471 }
2472 }
2473
2474 #[cfg(feature = "gpu")]
2475 #[test]
2476 fn test_gpu_binary_backward_broadcast_stress() {
2477 if !is_gpu_available() {
2478 return;
2479 }
2480
2481 let gpu_device = Device::gpu().expect("GPU should be available");
2482
2483 let a = RawTensor::new(vec![1.0], &[1], true).to_device(gpu_device.clone());
2485 let b = RawTensor::new(
2486 (0..1000).map(|i| i as f32).collect::<Vec<_>>(),
2487 &[1000],
2488 true,
2489 )
2490 .to_device(gpu_device.clone());
2491
2492 let c = a.add(&b);
2493 c.backward();
2494
2495 let a_grad = a.grad().unwrap();
2497 assert_abs_diff_eq!(
2498 a_grad.first().copied().unwrap_or(f32::NAN),
2499 1000.0,
2500 epsilon = 1e-2
2501 );
2502 }
2503}