1use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
2use std::fmt;
3use std::ops::{Add, Div, Mul, Sub};
4use std::sync::{Arc, Mutex};
5use rand_distr::{Distribution, Normal, Uniform};
7use crate::autograd::BackwardOp;
11use crate::storage::Storage;
12#[cfg(feature = "serde")]
13use serde::{Deserialize, Deserializer, Serialize, Serializer};
14
/// Reference-counted handle to a tensor. Cloning a `Tensor` is cheap: it
/// bumps the `Arc` refcount and shares the same underlying `TensorImpl`
/// (storage, shape, gradient slot, and autograd op).
#[derive(Clone, Debug)]
pub struct Tensor {
    pub(crate) inner: Arc<TensorImpl>,
}
19
20impl PartialEq for Tensor {
21 fn eq(&self, other: &Self) -> bool {
22 Arc::ptr_eq(&self.inner, &other.inner)
23 }
24}
25
/// Shared tensor state living behind the `Arc` inside `Tensor`.
#[derive(Debug)]
pub(crate) struct TensorImpl {
    // Element buffer: CPU vec or, with `wgpu_backend`, a GPU buffer.
    pub(crate) storage: Storage,
    // Dimension sizes.
    pub(crate) shape: Vec<usize>,
    // Per-dimension element strides (row-major by default; see
    // `Tensor::compute_strides`).
    pub(crate) strides: Vec<usize>,
    // Accumulated gradient, filled in during backprop.
    pub(crate) grad: Mutex<Option<Tensor>>,
    // Whether autograd tracks operations on this tensor.
    pub(crate) requires_grad: bool,
    // Backward op that produced this tensor; `None` for leaves.
    pub(crate) op: Option<Arc<dyn BackwardOp>>,
    // True for user-created tensors, false for op outputs / views.
    pub(crate) is_leaf: bool,
}
36
/// Backward node recorded by `Tensor::to_cpu` so that gradients computed on
/// the CPU copy flow back to the original GPU-resident tensor.
#[cfg(feature = "wgpu_backend")]
#[derive(Debug)]
struct ToCpuBackward {
    // The GPU tensor that was downloaded.
    input: Tensor,
}
42
#[cfg(feature = "wgpu_backend")]
impl BackwardOp for ToCpuBackward {
    // Upload the incoming (CPU) gradient back to the GPU, accumulate it on
    // the original tensor, and continue the backward walk from there.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            let grad_wgpu = grad.to_wgpu();
            self.input.accumulate_grad(&grad_wgpu);
            self.input.backward_step();
        }
    }
}
53
54impl Tensor {
55 pub fn new(data: &[f32], shape: &[usize]) -> Self {
56 let size: usize = shape.iter().product();
57 if data.len() != size {
58 panic!(
59 "Data size {} does not match shape {:?} (expected {})",
60 data.len(),
61 shape,
62 size
63 );
64 }
65
66 let strides = Self::compute_strides(shape);
67 let storage = Storage::from_slice(data);
68
69 Self {
70 inner: Arc::new(TensorImpl {
71 storage,
72 shape: shape.to_vec(),
73 strides,
74 grad: Mutex::new(None),
75 requires_grad: false,
76 op: None,
77 is_leaf: true,
78 }),
79 }
80 }
81
82 pub fn new_with_storage(storage: Storage, shape: &[usize]) -> Self {
83 let strides = Self::compute_strides(shape);
84 Self {
85 inner: Arc::new(TensorImpl {
86 storage,
87 shape: shape.to_vec(),
88 strides,
89 grad: Mutex::new(None),
90 requires_grad: false,
91 op: None,
92 is_leaf: true,
93 }),
94 }
95 }
96
97 pub fn zeros(shape: &[usize]) -> Self {
98 let size: usize = shape.iter().product();
99 Self::new(&vec![0.0; size], shape)
100 }
101
102 pub fn full(shape: &[usize], value: f32) -> Self {
103 let size: usize = shape.iter().product();
104 let data = vec![value; size];
105 let storage = Storage::new(data);
106 Self::new_with_storage(storage, shape)
107 }
108
109 pub fn ones(shape: &[usize]) -> Self {
110 let size: usize = shape.iter().product();
111 Self::new(&vec![1.0; size], shape)
112 }
113
    /// Borrow the underlying storage (CPU buffer or GPU buffer handle).
    pub fn storage(&self) -> &Storage {
        &self.inner.storage
    }
117
    /// Uploads this tensor to the GPU, returning a wgpu-backed copy.
    ///
    /// Returns a cheap clone if the storage is already a wgpu buffer.
    /// Non-contiguous tensors are materialized first so the uploaded bytes
    /// are in row-major order.
    ///
    /// # Panics
    /// Panics if the global WGPU context has not been initialized.
    #[cfg(feature = "wgpu_backend")]
    pub fn to_wgpu(&self) -> Self {
        if let Some(_) = self.storage().wgpu_buffer() {
            return self.clone();
        }

        let contig = if self.is_contiguous() {
            self.clone()
        } else {
            self.contiguous()
        };

        let data = contig.data();
        let ctx = crate::backend::wgpu::get_context().expect("WGPU context not initialized");

        use wgpu::util::DeviceExt;
        let buffer = ctx
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Tensor Buffer"),
                contents: bytemuck::cast_slice(&data),
                // STORAGE for compute shaders; COPY_SRC/COPY_DST so the
                // buffer can be downloaded (`to_cpu`) and overwritten (`copy_`).
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
            });

        let storage = Storage::new_wgpu(buffer, data.len(), 0);

        // NOTE(review): unlike `to_cpu`, no backward op is recorded here, so
        // gradients do not flow back through this upload — confirm intended.
        let inner = TensorImpl {
            storage,
            shape: contig.shape().to_vec(),
            strides: contig.strides().to_vec(),
            grad: Mutex::new(None),
            requires_grad: self.requires_grad(),
            op: None,
            is_leaf: self.inner.is_leaf,
        };

        Tensor {
            inner: Arc::new(inner),
        }
    }
160
    /// Downloads a wgpu-backed tensor to host memory.
    ///
    /// Returns a clone unchanged if the storage is already on the CPU. When
    /// a download happens and `requires_grad` is set, a `ToCpuBackward` op
    /// is recorded so gradients flow back to the GPU tensor.
    ///
    /// # Panics
    /// Panics if the WGPU context is missing or the buffer mapping fails.
    #[cfg(feature = "wgpu_backend")]
    pub fn to_cpu(&self) -> Self {
        if let Some(buffer) = self.storage().wgpu_buffer() {
            // Make sure pending GPU work writing this buffer is submitted.
            crate::backend::wgpu::flush_queue();

            let ctx = crate::backend::wgpu::get_context().expect("WGPU context not initialized");

            let buf_size = buffer.size();

            // STORAGE buffers cannot be mapped directly; copy into a
            // MAP_READ staging buffer first.
            let staging_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Staging Buffer"),
                size: buf_size,
                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            let mut encoder = ctx
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("Download Encoder"),
                });
            encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, buf_size);
            ctx.queue.submit(Some(encoder.finish()));

            // Synchronous map: request the map, poll the device until work
            // completes, then wait for the callback's result on the channel.
            let buffer_slice = staging_buffer.slice(..);
            let (sender, receiver) = std::sync::mpsc::channel();
            buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

            ctx.device.poll(wgpu::Maintain::Wait);
            receiver.recv().unwrap().unwrap();

            let data = buffer_slice.get_mapped_range();
            let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
            // The mapped range must be dropped before unmap is legal.
            drop(data);
            staging_buffer.unmap();

            let storage = Storage::new(result);

            let inner = TensorImpl {
                storage,
                shape: self.shape().to_vec(),
                strides: self.strides().to_vec(),
                grad: Mutex::new(None),
                requires_grad: self.requires_grad(),
                // Record the device transfer in the autograd graph only when
                // gradients are actually required.
                op: if self.requires_grad() {
                    Some(Arc::new(ToCpuBackward {
                        input: self.clone(),
                    }))
                } else {
                    None
                },
                is_leaf: self.inner.is_leaf,
            };
            return Self {
                inner: Arc::new(inner),
            };
        }
        // Already CPU-resident.
        self.clone()
    }
224
    /// Dimension sizes of this tensor.
    pub fn shape(&self) -> &[usize] {
        &self.inner.shape
    }

    /// Per-dimension element strides.
    pub fn strides(&self) -> &[usize] {
        &self.inner.strides
    }
232
    /// Builder-style setter for `requires_grad`. Returns a new handle that
    /// shares this tensor's storage; the gradient slot is reset to `None`
    /// while the recorded op and leaf flag carry over unchanged.
    pub fn set_requires_grad(self, requires_grad: bool) -> Self {
        let inner = &self.inner;
        let new_impl = TensorImpl {
            storage: inner.storage.clone(),
            shape: inner.shape.clone(),
            strides: inner.strides.clone(),
            grad: Mutex::new(None),
            requires_grad,
            op: inner.op.clone(),
            is_leaf: inner.is_leaf,
        };
        Self {
            inner: Arc::new(new_impl),
        }
    }
248
    /// In-place variant of `set_requires_grad`.
    ///
    /// When this handle uniquely owns its `TensorImpl`, the flag is flipped
    /// directly; otherwise a fresh impl is built via `set_requires_grad`.
    /// NOTE(review): the fallback path also resets the grad slot, so the two
    /// paths are not perfectly equivalent — confirm that is acceptable.
    pub fn set_requires_grad_mut(&mut self, requires_grad: bool) {
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.requires_grad = requires_grad;
        } else {
            *self = self.clone().set_requires_grad(requires_grad);
        }
    }
257
    /// Whether autograd tracks operations on this tensor.
    pub fn requires_grad(&self) -> bool {
        self.inner.requires_grad
    }

    /// Read-locked view of the element buffer.
    /// NOTE(review): behavior of `Storage::data` for wgpu-backed storage is
    /// not visible here — presumably callers download first via `to_cpu`.
    pub fn data(&self) -> RwLockReadGuard<'_, Vec<f32>> {
        self.inner.storage.data()
    }

    /// Write-locked view of the element buffer.
    pub fn data_mut(&self) -> RwLockWriteGuard<'_, Vec<f32>> {
        self.inner.storage.data_mut()
    }
269
    /// Clone of the currently accumulated gradient, if any.
    pub fn grad(&self) -> Option<Tensor> {
        self.inner.grad.lock().unwrap().clone()
    }

    /// Clears the accumulated gradient.
    pub fn zero_grad(&self) {
        *self.inner.grad.lock().unwrap() = None;
    }
277
    /// Adds `grad` into this tensor's gradient slot (sum-accumulation).
    /// The first gradient is stored as-is; subsequent ones are added on.
    pub fn accumulate_grad(&self, grad: &Tensor) {
        let mut g = self.inner.grad.lock().unwrap();
        if let Some(existing) = &*g {
            #[cfg(feature = "wgpu_backend")]
            {
                // Move the incoming grad to the device of the existing grad
                // before adding, so the add runs on a single backend.
                let existing_is_wgpu = existing.storage().wgpu_buffer().is_some();
                let grad_is_wgpu = grad.storage().wgpu_buffer().is_some();

                if existing_is_wgpu && grad_is_wgpu {
                    *g = Some(existing.add(grad));
                } else if existing_is_wgpu {
                    // Existing on GPU, incoming on CPU: upload the incoming.
                    *g = Some(existing.add(&grad.to_wgpu()));
                } else if grad_is_wgpu {
                    // Existing on CPU, incoming on GPU: download the incoming.
                    *g = Some(existing.add(&grad.to_cpu()));
                } else {
                    *g = Some(existing.add(grad));
                }
            }
            #[cfg(not(feature = "wgpu_backend"))]
            {
                *g = Some(existing.add(grad));
            }
        } else {
            *g = Some(grad.clone());
        }
    }
304
305 pub fn backward(&self) {
306 if self.shape().len() != 1 || self.shape()[0] != 1 {
308 }
314
315 let grad = Tensor::ones(self.shape());
316 self.accumulate_grad(&grad);
317 self.backward_step();
318 }
319
    /// Propagates this tensor's accumulated gradient one step backward: if
    /// an op produced this tensor and a gradient is present, invoke that
    /// op's `backward`. Leaves (no op) and grad-less tensors are a no-op.
    pub fn backward_step(&self) {
        if let Some(op) = &self.inner.op {
            if let Some(grad) = self.grad() {
                op.backward(&grad);
            }
        }
    }
327
    /// Returns a tensor sharing this tensor's storage but detached from the
    /// autograd graph: a leaf with no op, no gradient, and
    /// `requires_grad = false`.
    pub fn detach(&self) -> Tensor {
        Tensor {
            inner: Arc::new(TensorImpl {
                storage: self.inner.storage.clone(),
                shape: self.inner.shape.clone(),
                strides: self.inner.strides.clone(),
                grad: Mutex::new(None),
                requires_grad: false,
                op: None,
                is_leaf: true,
            }),
        }
    }
343
    /// Records the backward op that produced this tensor.
    ///
    /// # Panics
    /// Panics if the `TensorImpl` is shared (another handle holds the Arc),
    /// since the graph edge can only be attached to a uniquely owned node.
    pub fn set_op(&mut self, op: Arc<dyn BackwardOp>) {
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.op = Some(op);
        } else {
            panic!("Cannot set op on shared tensor storage wrapper");
        }
    }
357
    /// Matrix multiplication: `self @ rhs`.
    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        crate::ops::matmul(self, rhs)
    }

    /// Shorthand for transposing dimensions 0 and 1.
    pub fn t(&self) -> Tensor {
        crate::ops::view::transpose(self, 0, 1)
    }

    /// Elementwise subtraction.
    pub fn sub(&self, rhs: &Tensor) -> Tensor {
        crate::ops::sub(self, rhs)
    }

    /// Elementwise addition.
    pub fn add(&self, rhs: &Tensor) -> Tensor {
        crate::ops::add(self, rhs)
    }

    /// Elementwise negation.
    pub fn neg(&self) -> Tensor {
        crate::ops::neg(self)
    }

    /// Elementwise ReLU activation.
    pub fn relu(&self) -> Tensor {
        crate::ops::relu(self)
    }

    /// Elementwise sigmoid activation.
    pub fn sigmoid(&self) -> Tensor {
        crate::ops::sigmoid(self)
    }

    /// Elementwise tanh activation.
    pub fn tanh(&self) -> Tensor {
        crate::ops::tanh(self)
    }

    /// Softmax along `dim` (i64 presumably to allow negative axis indices —
    /// see `crate::ops::softmax` for the exact convention).
    pub fn softmax(&self, dim: i64) -> Tensor {
        crate::ops::softmax(self, dim)
    }
393
    /// 2D convolution with the given weight, stride, and padding.
    pub fn conv2d(
        &self,
        weight: &Tensor,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Tensor {
        crate::ops::conv2d(self, weight, stride, padding)
    }

    /// 2D max pooling.
    pub fn max_pool2d(
        &self,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Tensor {
        crate::ops::max_pool2d(self, kernel_size, stride, padding)
    }

    /// Batch normalization over a 4D activation; `training` selects batch
    /// statistics vs. the running estimates (exact semantics are defined by
    /// `crate::ops::batch_norm2d`).
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm2d(
        &self,
        gamma: Option<&Tensor>,
        beta: Option<&Tensor>,
        running_mean: &Tensor,
        running_var: &Tensor,
        training: bool,
        momentum: f32,
        eps: f32,
    ) -> Tensor {
        crate::ops::batch_norm2d(
            self,
            gamma,
            beta,
            running_mean,
            running_var,
            training,
            momentum,
            eps,
        )
    }

    /// Layer normalization over the trailing `normalized_shape` dimensions.
    pub fn layer_norm(
        &self,
        normalized_shape: &[usize],
        weight: Option<&Tensor>,
        bias: Option<&Tensor>,
        eps: f32,
    ) -> Tensor {
        crate::ops::layer_norm(self, normalized_shape, weight, bias, eps)
    }

    /// Reorders dimensions according to `dims` (a permutation of the axes).
    pub fn permute(&self, dims: &[usize]) -> Tensor {
        crate::ops::view::permute(self, dims)
    }

    /// Swaps dimensions `dim0` and `dim1` (delegates to the view op).
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
        crate::ops::view::transpose(self, dim0, dim1)
    }
452
    /// Returns a tensor with default row-major layout, copying if needed.
    /// Already-contiguous tensors are returned as cheap clones.
    pub fn contiguous(&self) -> Tensor {
        if self.is_contiguous() {
            return self.clone();
        }

        // GPU path: rearrange on-device rather than round-tripping to host.
        #[cfg(feature = "wgpu_backend")]
        if let Some(input_buf) = self.storage().wgpu_buffer() {
            use crate::backend::wgpu::contiguous_wgpu;
            let output_buf = contiguous_wgpu(input_buf, self.shape(), self.strides());
            let size: usize = self.shape().iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, self.shape());
            tensor.set_requires_grad_mut(self.requires_grad());
            return tensor;
        }

        crate::ops::view::contiguous(self)
    }
471
    /// True when this tensor has default row-major strides AND its storage
    /// length equals the element count of its shape. The length check means
    /// a view into a larger buffer is reported as non-contiguous —
    /// NOTE(review): confirm that treatment is intended for offset views.
    pub fn is_contiguous(&self) -> bool {
        let default_strides = Self::compute_strides(self.shape());
        if self.strides() != default_strides {
            return false;
        }
        let expected_size: usize = self.shape().iter().product();
        let actual_size = self.storage().len();
        expected_size == actual_size
    }
481
    /// Fills the tensor in place with samples from N(`mean`, `std`).
    /// NOTE(review): these in-place fills write the CPU-side buffer via
    /// `data_mut`; behavior for wgpu-backed storage is not visible here.
    ///
    /// # Panics
    /// Panics if `std` is rejected by `Normal::new` (e.g. negative or NaN).
    pub fn normal_(&self, mean: f32, std: f32) {
        let mut guard = self.data_mut();
        let mut rng = rand::thread_rng();
        let normal = Normal::new(mean, std).unwrap();
        for x in guard.iter_mut() {
            *x = normal.sample(&mut rng);
        }
    }

    /// Fills the tensor in place with uniform samples from `[low, high)`.
    pub fn uniform_(&self, low: f32, high: f32) {
        let mut guard = self.data_mut();
        let mut rng = rand::thread_rng();
        let uniform = Uniform::new(low, high);
        for x in guard.iter_mut() {
            *x = uniform.sample(&mut rng);
        }
    }

    /// Sets every element to `value`.
    pub fn fill_(&self, value: f32) {
        let mut guard = self.data_mut();
        for x in guard.iter_mut() {
            *x = value;
        }
    }
506
507 pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
508 let size: usize = self.shape().iter().product();
509 let new_size: usize = new_shape.iter().product();
510 if size != new_size {
511 panic!(
512 "Reshape: element count mismatch: {:?} vs {:?}",
513 self.shape(),
514 new_shape
515 );
516 }
517
518 let inner = &self.inner;
519 let strides = Self::compute_strides(new_shape);
520
521 let mut tensor = Self {
523 inner: Arc::new(TensorImpl {
524 storage: inner.storage.clone(),
525 shape: new_shape.to_vec(),
526 strides,
527 grad: Mutex::new(None),
528 requires_grad: inner.requires_grad,
529 op: None,
530 is_leaf: false,
531 }),
532 };
533
534 if inner.requires_grad {
535 tensor.set_op(Arc::new(crate::ops::ReshapeBackward {
536 input_shape: inner.shape.clone(),
537 input: self.clone(),
538 }));
539 }
540
541 tensor
542 }
543
    /// Elementwise multiplication.
    pub fn mul(&self, rhs: &Tensor) -> Tensor {
        crate::ops::mul(self, rhs)
    }

    /// Elementwise division.
    pub fn div(&self, rhs: &Tensor) -> Tensor {
        crate::ops::div(self, rhs)
    }

    /// Fused matmul + ReLU on the GPU backend (single kernel, no bias).
    #[cfg(feature = "wgpu_backend")]
    pub fn matmul_relu(&self, rhs: &Tensor) -> Tensor {
        crate::ops::matmul_fused(self, rhs, None, crate::backend::wgpu::Activation::ReLU)
    }

    /// CPU fallback for `matmul_relu`: plain matmul followed by ReLU.
    #[cfg(not(feature = "wgpu_backend"))]
    pub fn matmul_relu(&self, rhs: &Tensor) -> Tensor {
        self.matmul(rhs).relu()
    }

    /// One SGD parameter update — presumably `self - lr * grad`; see
    /// `crate::ops::sgd_step` for the exact semantics.
    pub fn sgd_step(&self, grad: &Tensor, lr: f32) -> Tensor {
        crate::ops::sgd_step(self, grad, lr)
    }
565
566 pub fn copy_(&self, src: &Tensor) {
567 #[cfg(feature = "wgpu_backend")]
568 if self.storage().device().is_wgpu() {
569 if let Some(dest_buf) = self.storage().wgpu_buffer() {
570 let ctx = crate::backend::wgpu::get_context().expect("WGPU context missing");
571
572 if let Some(src_buf) = src.storage().wgpu_buffer() {
574 let mut encoder =
575 ctx.device
576 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
577 label: Some("Copy Encoder"),
578 });
579 encoder.copy_buffer_to_buffer(src_buf, 0, dest_buf, 0, dest_buf.size());
580 ctx.queue.submit(Some(encoder.finish()));
581 return;
582 }
583
584 let src_cpu = if src.storage().device().is_wgpu() {
586 src.to_cpu()
587 } else {
588 src.clone()
589 };
590 let src_guard = src_cpu.data();
591 ctx.queue
592 .write_buffer(dest_buf, 0, bytemuck::cast_slice(&src_guard));
593 return;
594 }
595 }
596
597 let src_cpu = if src.storage().device().is_wgpu() {
599 src.to_cpu()
600 } else {
601 src.clone()
602 };
603 let mut dest_guard = self.data_mut();
604 let src_guard = src_cpu.data();
605 if dest_guard.len() != src_guard.len() {
606 }
609 let len = std::cmp::min(dest_guard.len(), src_guard.len());
610 dest_guard[..len].copy_from_slice(&src_guard[..len]);
611 }
612
613 fn compute_strides(shape: &[usize]) -> Vec<usize> {
614 let mut strides = vec![0; shape.len()];
615 let mut stride = 1;
616 for i in (0..shape.len()).rev() {
617 strides[i] = stride;
618 stride *= shape[i];
619 }
620 strides
621 }
622
623 pub fn copy_from_slice(&self, src: &[f32]) {
628 let mut guard = self.data_mut();
629 let len = std::cmp::min(guard.len(), src.len());
630 guard[..len].copy_from_slice(&src[..len]);
631 }
632}
633
634impl Add for &Tensor {
636 type Output = Tensor;
637 fn add(self, rhs: Self) -> Tensor {
638 self.add(rhs)
639 }
640}
641
642impl Add<Tensor> for Tensor {
643 type Output = Tensor;
644 fn add(self, rhs: Tensor) -> Tensor {
645 Tensor::add(&self, &rhs)
646 }
647}
648
649impl Sub<Tensor> for Tensor {
650 type Output = Tensor;
651 fn sub(self, rhs: Tensor) -> Tensor {
652 Tensor::sub(&self, &rhs)
653 }
654}
655
656impl Mul<Tensor> for Tensor {
657 type Output = Tensor;
658 fn mul(self, rhs: Tensor) -> Tensor {
659 Tensor::mul(&self, &rhs)
660 }
661}
662
663impl Div<Tensor> for Tensor {
664 type Output = Tensor;
665 fn div(self, rhs: Tensor) -> Tensor {
666 Tensor::div(&self, &rhs)
667 }
668}
669
670impl Sub for &Tensor {
671 type Output = Tensor;
672 fn sub(self, rhs: Self) -> Tensor {
673 self.sub(rhs)
674 }
675}
676
677impl Mul for &Tensor {
678 type Output = Tensor;
679 fn mul(self, rhs: Self) -> Tensor {
680 self.mul(rhs)
681 }
682}
683
684impl Div for &Tensor {
685 type Output = Tensor;
686 fn div(self, rhs: Self) -> Tensor {
687 self.div(rhs)
688 }
689}
690
/// Flat, backend-agnostic snapshot of a tensor used for serde: shape plus
/// row-major element data plus the `requires_grad` flag. Gradients and the
/// autograd graph are intentionally not serialized.
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
struct TensorData {
    shape: Vec<usize>,
    data: Vec<f32>,
    requires_grad: bool,
}
698
#[cfg(feature = "serde")]
impl Serialize for Tensor {
    /// Serializes shape, element data, and the `requires_grad` flag through
    /// the intermediate `TensorData` struct.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `.clone()` auto-derefs through the read guard, cloning the Vec<f32>.
        let data = self.data().clone();
        let tensor_data = TensorData {
            shape: self.shape().to_vec(),
            data,
            requires_grad: self.requires_grad(),
        };
        tensor_data.serialize(serializer)
    }
}
714
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Tensor {
    /// Rebuilds a leaf CPU tensor from a serialized `TensorData` snapshot.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let tensor_data = TensorData::deserialize(deserializer)?;
        // NOTE(review): `Tensor::new` panics on a shape/data length mismatch
        // rather than returning an error — consider mapping that to D::Error
        // for untrusted input.
        let tensor = Tensor::new(&tensor_data.data, &tensor_data.shape)
            .set_requires_grad(tensor_data.requires_grad);
        Ok(tensor)
    }
}
727
728impl fmt::Display for Tensor {
729 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
730 let data = self.data();
731 let len = std::cmp::min(data.len(), 10);
732 write!(
733 f,
734 "Tensor(shape={:?}, data={:?})",
735 self.shape(),
736 &data[..len]
737 )
738 }
739}