//! rustorch_core — `tensor.rs`: the core `Tensor` type (shape/strides/storage
//! plus autograd bookkeeping) and its operator surface.
use std::fmt;
use std::ops::{Add, Div, Mul, Sub};
use std::sync::{Arc, Mutex};

use parking_lot::{RwLockReadGuard, RwLockWriteGuard};
// use rand::Rng;
use rand_distr::{Distribution, Normal, Uniform};
// use rayon::prelude::*;
// use rayon::iter::{IntoParallelRefIterator, ParallelIterator, IndexedParallelIterator};
// use rayon::slice::ParallelSliceMut;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::autograd::BackwardOp;
use crate::storage::Storage;
14
/// A reference-counted handle to an n-dimensional `f32` array.
///
/// Cloning a `Tensor` is cheap: it bumps the `Arc` refcount and shares the
/// underlying [`TensorImpl`] (storage, shape, strides, autograd state).
#[derive(Clone, Debug)]
pub struct Tensor {
    pub(crate) inner: Arc<TensorImpl>,
}
19
/// Identity comparison: two `Tensor`s are equal only when they share the same
/// `TensorImpl` allocation (pointer equality) — NOT element-wise equality.
impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.inner, &other.inner)
    }
}
25
/// Shared backing state for a [`Tensor`] handle.
#[derive(Debug)]
pub(crate) struct TensorImpl {
    /// Flat element buffer (CPU `Vec<f32>`; with `wgpu_backend`, may be a GPU buffer).
    pub(crate) storage: Storage,
    /// Dimension sizes, outermost first.
    pub(crate) shape: Vec<usize>,
    /// Per-dimension step in *elements* (row-major when contiguous).
    pub(crate) strides: Vec<usize>,
    pub(crate) grad: Mutex<Option<Tensor>>, // Gradient accumulated by autograd
    pub(crate) requires_grad: bool,
    pub(crate) op: Option<Arc<dyn BackwardOp>>, // Operation that created this tensor
    pub(crate) is_leaf: bool,
}
36
/// Backward edge created by [`Tensor::to_cpu`]: routes the CPU-side gradient
/// back to the original GPU-resident tensor.
#[cfg(feature = "wgpu_backend")]
#[derive(Debug)]
struct ToCpuBackward {
    input: Tensor, // the GPU tensor that was downloaded
}
42
#[cfg(feature = "wgpu_backend")]
impl BackwardOp for ToCpuBackward {
    // Upload the incoming (CPU) gradient back to the GPU, accumulate it on the
    // original tensor, and continue the backward pass from there.
    fn backward(&self, grad: &Tensor) {
        if self.input.requires_grad() {
            let grad_wgpu = grad.to_wgpu();
            self.input.accumulate_grad(&grad_wgpu);
            self.input.backward_step();
        }
    }
}
53
54impl Tensor {
55    pub fn new(data: &[f32], shape: &[usize]) -> Self {
56        let size: usize = shape.iter().product();
57        if data.len() != size {
58            panic!(
59                "Data size {} does not match shape {:?} (expected {})",
60                data.len(),
61                shape,
62                size
63            );
64        }
65
66        let strides = Self::compute_strides(shape);
67        let storage = Storage::from_slice(data);
68
69        Self {
70            inner: Arc::new(TensorImpl {
71                storage,
72                shape: shape.to_vec(),
73                strides,
74                grad: Mutex::new(None),
75                requires_grad: false,
76                op: None,
77                is_leaf: true,
78            }),
79        }
80    }
81
    /// Wraps an existing `Storage` in a leaf tensor with default row-major
    /// strides, `requires_grad = false`, and no backward op.
    ///
    /// NOTE(review): the storage length is NOT validated against
    /// `shape.iter().product()` — callers must uphold that invariant; confirm.
    pub fn new_with_storage(storage: Storage, shape: &[usize]) -> Self {
        let strides = Self::compute_strides(shape);
        Self {
            inner: Arc::new(TensorImpl {
                storage,
                shape: shape.to_vec(),
                strides,
                grad: Mutex::new(None),
                requires_grad: false,
                op: None,
                is_leaf: true,
            }),
        }
    }
96
97    pub fn zeros(shape: &[usize]) -> Self {
98        let size: usize = shape.iter().product();
99        Self::new(&vec![0.0; size], shape)
100    }
101
102    pub fn full(shape: &[usize], value: f32) -> Self {
103        let size: usize = shape.iter().product();
104        let data = vec![value; size];
105        let storage = Storage::new(data);
106        Self::new_with_storage(storage, shape)
107    }
108
109    pub fn ones(shape: &[usize]) -> Self {
110        let size: usize = shape.iter().product();
111        Self::new(&vec![1.0; size], shape)
112    }
113
    /// Borrow of the underlying storage (shared by all clones of this tensor).
    pub fn storage(&self) -> &Storage {
        &self.inner.storage
    }
117
    /// Uploads this tensor to the WGPU device, returning a GPU-backed tensor.
    ///
    /// Already-GPU tensors are returned as cheap clones. Non-contiguous
    /// tensors are materialized with `contiguous()` first, so the uploaded
    /// buffer is row-major.
    ///
    /// NOTE(review): `requires_grad` is carried over but `op` is set to
    /// `None`, so no backward edge links the GPU tensor back to `self` —
    /// confirm this asymmetry with `to_cpu` (which attaches `ToCpuBackward`)
    /// is intentional.
    ///
    /// # Panics
    /// Panics if the global WGPU context has not been initialized.
    #[cfg(feature = "wgpu_backend")]
    pub fn to_wgpu(&self) -> Self {
        if let Some(_) = self.storage().wgpu_buffer() {
            return self.clone();
        }

        let contig = if self.is_contiguous() {
            self.clone()
        } else {
            self.contiguous()
        };

        let data = contig.data();
        let ctx = crate::backend::wgpu::get_context().expect("WGPU context not initialized");

        use wgpu::util::DeviceExt;
        let buffer = ctx
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Tensor Buffer"),
                contents: bytemuck::cast_slice(&data),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
            });

        let storage = Storage::new_wgpu(buffer, data.len(), 0);

        let inner = TensorImpl {
            storage,
            shape: contig.shape().to_vec(),
            strides: contig.strides().to_vec(),
            grad: Mutex::new(None),
            requires_grad: self.requires_grad(),
            op: None,
            is_leaf: self.inner.is_leaf,
        };

        Tensor {
            inner: Arc::new(inner),
        }
    }
160
    /// Downloads a GPU-backed tensor to host memory; CPU tensors are returned
    /// as cheap clones.
    ///
    /// Flow: flush pending GPU work, copy the storage buffer into a MAP_READ
    /// staging buffer, block until the map completes, then copy the bytes into
    /// a fresh CPU `Storage`. Shape and strides are preserved verbatim.
    ///
    /// If `self.requires_grad()`, a [`ToCpuBackward`] op is attached so
    /// gradients flow back to the GPU tensor.
    ///
    /// # Panics
    /// Panics if the WGPU context is missing or the buffer map fails.
    #[cfg(feature = "wgpu_backend")]
    pub fn to_cpu(&self) -> Self {
        if let Some(buffer) = self.storage().wgpu_buffer() {
            // Flush any pending commands to ensure buffer is ready
            crate::backend::wgpu::flush_queue();

            let ctx = crate::backend::wgpu::get_context().expect("WGPU context not initialized");

            let buf_size = buffer.size();

            let staging_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Staging Buffer"),
                size: buf_size,
                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            });

            let mut encoder = ctx
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("Download Encoder"),
                });
            encoder.copy_buffer_to_buffer(buffer, 0, &staging_buffer, 0, buf_size);
            ctx.queue.submit(Some(encoder.finish()));

            let buffer_slice = staging_buffer.slice(..);
            let (sender, receiver) = std::sync::mpsc::channel();
            buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

            // Block this thread until the map callback has fired.
            ctx.device.poll(wgpu::Maintain::Wait);
            receiver.recv().unwrap().unwrap();

            let data = buffer_slice.get_mapped_range();
            let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
            drop(data);
            staging_buffer.unmap();

            // Create CPU tensor with SAME shape and strides, but using the downloaded storage
            let storage = Storage::new(result);

            // We need to construct Tensor manually to preserve strides/offset
            let inner = TensorImpl {
                storage,
                shape: self.shape().to_vec(),
                strides: self.strides().to_vec(),
                grad: Mutex::new(None),
                requires_grad: self.requires_grad(),
                op: if self.requires_grad() {
                    Some(Arc::new(ToCpuBackward {
                        input: self.clone(),
                    }))
                } else {
                    None
                },
                is_leaf: self.inner.is_leaf,
            };
            return Self {
                inner: Arc::new(inner),
            };
        }
        // Not GPU-backed: nothing to download.
        self.clone()
    }
224
    /// Dimension sizes, outermost first.
    pub fn shape(&self) -> &[usize] {
        &self.inner.shape
    }
228
    /// Per-dimension step in elements (row-major when contiguous).
    pub fn strides(&self) -> &[usize] {
        &self.inner.strides
    }
232
    /// Consumes `self`, returning a tensor with the given `requires_grad`.
    ///
    /// Builds a fresh `TensorImpl` that shares the same storage; any
    /// accumulated gradient is dropped, while `op` and `is_leaf` carry over.
    pub fn set_requires_grad(self, requires_grad: bool) -> Self {
        let inner = &self.inner;
        let new_impl = TensorImpl {
            storage: inner.storage.clone(),
            shape: inner.shape.clone(),
            strides: inner.strides.clone(),
            grad: Mutex::new(None),
            requires_grad,
            op: inner.op.clone(),
            is_leaf: inner.is_leaf,
        };
        Self {
            inner: Arc::new(new_impl),
        }
    }
248
    /// In-place variant of `set_requires_grad`: mutates directly when the
    /// `Arc` is uniquely owned, otherwise rebuilds and swaps in a new handle.
    pub fn set_requires_grad_mut(&mut self, requires_grad: bool) {
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.requires_grad = requires_grad;
        } else {
            // Clone if shared
            *self = self.clone().set_requires_grad(requires_grad);
        }
    }
257
    /// Whether autograd should track operations on this tensor.
    pub fn requires_grad(&self) -> bool {
        self.inner.requires_grad
    }
261
    /// Read-locked view of the flat CPU buffer.
    ///
    /// NOTE(review): behavior for GPU-backed storage is defined by
    /// `Storage::data`; presumably callers must `to_cpu()` first — confirm.
    pub fn data(&self) -> RwLockReadGuard<'_, Vec<f32>> {
        self.inner.storage.data()
    }
265
    /// Write-locked view of the flat CPU buffer (see `data` for the GPU caveat).
    pub fn data_mut(&self) -> RwLockWriteGuard<'_, Vec<f32>> {
        self.inner.storage.data_mut()
    }
269
    /// Clone of the accumulated gradient, if any (cloning is a cheap Arc bump).
    pub fn grad(&self) -> Option<Tensor> {
        self.inner.grad.lock().unwrap().clone()
    }
273
    /// Clears the accumulated gradient (drops it rather than zero-filling).
    pub fn zero_grad(&self) {
        *self.inner.grad.lock().unwrap() = None;
    }
277
    /// Adds `grad` into this tensor's accumulated gradient, creating it on the
    /// first call (gradient accumulation).
    ///
    /// The grad mutex is held for the entire update, including the `add`.
    pub fn accumulate_grad(&self, grad: &Tensor) {
        let mut g = self.inner.grad.lock().unwrap();
        if let Some(existing) = &*g {
            #[cfg(feature = "wgpu_backend")]
            {
                // Move the incoming grad to the device of the existing one so
                // the `add` sees matching storage.
                let existing_is_wgpu = existing.storage().wgpu_buffer().is_some();
                let grad_is_wgpu = grad.storage().wgpu_buffer().is_some();

                if existing_is_wgpu && grad_is_wgpu {
                    *g = Some(existing.add(grad));
                } else if existing_is_wgpu {
                    *g = Some(existing.add(&grad.to_wgpu()));
                } else if grad_is_wgpu {
                    *g = Some(existing.add(&grad.to_cpu()));
                } else {
                    *g = Some(existing.add(grad));
                }
            }
            #[cfg(not(feature = "wgpu_backend"))]
            {
                *g = Some(existing.add(grad));
            }
        } else {
            // First gradient: a cheap handle clone (shared storage).
            *g = Some(grad.clone());
        }
    }
304
305    pub fn backward(&self) {
306        // Gradient of scalar output is 1.0
307        if self.shape().len() != 1 || self.shape()[0] != 1 {
308            // Usually backward() is called on scalar loss.
309            // If not scalar, PyTorch requires gradient argument.
310            // RusTorch: implicitly assume 1.0 if scalar?
311            // If tensor is not scalar, we should probably fill ones.
312            // But for simplicity, let's assume scalar 1.0 or Tensor::ones.
313        }
314
315        let grad = Tensor::ones(self.shape());
316        self.accumulate_grad(&grad);
317        self.backward_step();
318    }
319
320    pub fn backward_step(&self) {
321        if let Some(op) = &self.inner.op {
322            if let Some(grad) = self.grad() {
323                op.backward(&grad);
324            }
325        }
326    }
327
    /// Returns a new Tensor, detached from the current graph.
    /// The result will never require gradient.
    ///
    /// Storage is shared with `self`; only the autograd metadata is reset
    /// (no grad, no op, marked as a leaf).
    pub fn detach(&self) -> Tensor {
        Tensor {
            inner: Arc::new(TensorImpl {
                storage: self.inner.storage.clone(),
                shape: self.inner.shape.clone(),
                strides: self.inner.strides.clone(),
                grad: Mutex::new(None),
                requires_grad: false,
                op: None,
                is_leaf: true,
            }),
        }
    }
343
    /// Attaches the backward op that produced this tensor.
    ///
    /// Only valid while the `TensorImpl` is uniquely owned — i.e. during
    /// construction of a fresh tensor (as in `reshape`/`permute`, which build
    /// a new `Arc` before calling this). Once the `Arc` is shared, mutation is
    /// impossible, so we panic to surface the misuse.
    pub fn set_op(&mut self, op: Arc<dyn BackwardOp>) {
        if let Some(inner) = Arc::get_mut(&mut self.inner) {
            inner.op = Some(op);
        } else {
            panic!("Cannot set op on shared tensor storage wrapper");
        }
    }
357
    /// Matrix multiplication of `self` and `rhs`; delegates to `crate::ops::matmul`.
    pub fn matmul(&self, rhs: &Tensor) -> Tensor {
        crate::ops::matmul(self, rhs)
    }
361
    /// Shorthand for `transpose(0, 1)` (swap the first two dimensions).
    pub fn t(&self) -> Tensor {
        crate::ops::view::transpose(self, 0, 1) // Default to 2D transpose
    }
365
    /// Element-wise subtraction; delegates to `crate::ops::sub`.
    pub fn sub(&self, rhs: &Tensor) -> Tensor {
        crate::ops::sub(self, rhs)
    }
369
    /// Element-wise addition; delegates to `crate::ops::add`.
    pub fn add(&self, rhs: &Tensor) -> Tensor {
        crate::ops::add(self, rhs)
    }
373
    /// Element-wise negation; delegates to `crate::ops::neg`.
    pub fn neg(&self) -> Tensor {
        crate::ops::neg(self)
    }
377
    /// Element-wise ReLU activation; delegates to `crate::ops::relu`.
    pub fn relu(&self) -> Tensor {
        crate::ops::relu(self)
    }
381
    /// Element-wise sigmoid activation; delegates to `crate::ops::sigmoid`.
    pub fn sigmoid(&self) -> Tensor {
        crate::ops::sigmoid(self)
    }
385
    /// Element-wise tanh activation; delegates to `crate::ops::tanh`.
    pub fn tanh(&self) -> Tensor {
        crate::ops::tanh(self)
    }
389
    /// Softmax along `dim`. `dim` is `i64`; presumably negative values index
    /// from the back (PyTorch-style) — confirm in `crate::ops::softmax`.
    pub fn softmax(&self, dim: i64) -> Tensor {
        crate::ops::softmax(self, dim)
    }
393
    /// 2-D convolution of `self` with `weight` (no bias parameter here).
    /// `stride` and `padding` are per-axis pairs — presumably (height, width);
    /// confirm in `crate::ops::conv2d`.
    pub fn conv2d(
        &self,
        weight: &Tensor,
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Tensor {
        crate::ops::conv2d(self, weight, stride, padding)
    }
402
    /// 2-D max pooling with the given window, stride and padding pairs
    /// (presumably (height, width) — confirm in `crate::ops::max_pool2d`).
    pub fn max_pool2d(
        &self,
        kernel_size: (usize, usize),
        stride: (usize, usize),
        padding: (usize, usize),
    ) -> Tensor {
        crate::ops::max_pool2d(self, kernel_size, stride, padding)
    }
411
    /// 2-D batch normalization.
    ///
    /// * `gamma` / `beta` — optional learned scale/shift parameters.
    /// * `running_mean` / `running_var` — running statistics; presumably
    ///   updated in place with `momentum` when `training` is true (confirm in
    ///   `crate::ops::batch_norm2d`).
    /// * `eps` — numerical-stability term.
    #[allow(clippy::too_many_arguments)]
    pub fn batch_norm2d(
        &self,
        gamma: Option<&Tensor>,
        beta: Option<&Tensor>,
        running_mean: &Tensor,
        running_var: &Tensor,
        training: bool,
        momentum: f32,
        eps: f32,
    ) -> Tensor {
        crate::ops::batch_norm2d(
            self,
            gamma,
            beta,
            running_mean,
            running_var,
            training,
            momentum,
            eps,
        )
    }
434
    /// Layer normalization over the trailing `normalized_shape` dimensions,
    /// with optional affine `weight`/`bias` and stability term `eps`.
    pub fn layer_norm(
        &self,
        normalized_shape: &[usize],
        weight: Option<&Tensor>,
        bias: Option<&Tensor>,
        eps: f32,
    ) -> Tensor {
        crate::ops::layer_norm(self, normalized_shape, weight, bias, eps)
    }
444
    /// Reorders dimensions according to `dims`; delegates to `crate::ops::view::permute`.
    pub fn permute(&self, dims: &[usize]) -> Tensor {
        crate::ops::view::permute(self, dims)
    }
448
    /// Swaps dimensions `dim0` and `dim1`; delegates to `crate::ops::view::transpose`.
    pub fn transpose(&self, dim0: usize, dim1: usize) -> Tensor {
        crate::ops::view::transpose(self, dim0, dim1)
    }
452
    /// Returns a tensor whose elements are laid out in default row-major
    /// order; already-contiguous tensors are returned as cheap clones.
    ///
    /// On the WGPU backend a device-side gather kernel re-packs the buffer.
    /// NOTE(review): that path carries `requires_grad` over but attaches no
    /// backward op — confirm gradient flow matches the CPU path
    /// (`crate::ops::view::contiguous`).
    pub fn contiguous(&self) -> Tensor {
        if self.is_contiguous() {
            return self.clone();
        }

        #[cfg(feature = "wgpu_backend")]
        if let Some(input_buf) = self.storage().wgpu_buffer() {
            use crate::backend::wgpu::contiguous_wgpu;
            let output_buf = contiguous_wgpu(input_buf, self.shape(), self.strides());
            let size: usize = self.shape().iter().product();
            let storage = Storage::new_wgpu(output_buf, size, 0);
            let mut tensor = Tensor::new_with_storage(storage, self.shape());
            tensor.set_requires_grad_mut(self.requires_grad());
            return tensor;
        }

        crate::ops::view::contiguous(self)
    }
471
472    pub fn is_contiguous(&self) -> bool {
473        let default_strides = Self::compute_strides(self.shape());
474        if self.strides() != default_strides {
475            return false;
476        }
477        let expected_size: usize = self.shape().iter().product();
478        let actual_size = self.storage().len();
479        expected_size == actual_size
480    }
481
482    pub fn normal_(&self, mean: f32, std: f32) {
483        let mut guard = self.data_mut();
484        let mut rng = rand::thread_rng();
485        let normal = Normal::new(mean, std).unwrap();
486        for x in guard.iter_mut() {
487            *x = normal.sample(&mut rng);
488        }
489    }
490
491    pub fn uniform_(&self, low: f32, high: f32) {
492        let mut guard = self.data_mut();
493        let mut rng = rand::thread_rng();
494        let uniform = Uniform::new(low, high);
495        for x in guard.iter_mut() {
496            *x = uniform.sample(&mut rng);
497        }
498    }
499
500    pub fn fill_(&self, value: f32) {
501        let mut guard = self.data_mut();
502        for x in guard.iter_mut() {
503            *x = value;
504        }
505    }
506
507    pub fn reshape(&self, new_shape: &[usize]) -> Tensor {
508        let size: usize = self.shape().iter().product();
509        let new_size: usize = new_shape.iter().product();
510        if size != new_size {
511            panic!(
512                "Reshape: element count mismatch: {:?} vs {:?}",
513                self.shape(),
514                new_shape
515            );
516        }
517
518        let inner = &self.inner;
519        let strides = Self::compute_strides(new_shape);
520
521        // Share storage, create new TensorImpl
522        let mut tensor = Self {
523            inner: Arc::new(TensorImpl {
524                storage: inner.storage.clone(),
525                shape: new_shape.to_vec(),
526                strides,
527                grad: Mutex::new(None),
528                requires_grad: inner.requires_grad,
529                op: None,
530                is_leaf: false,
531            }),
532        };
533
534        if inner.requires_grad {
535            tensor.set_op(Arc::new(crate::ops::ReshapeBackward {
536                input_shape: inner.shape.clone(),
537                input: self.clone(),
538            }));
539        }
540
541        tensor
542    }
543
    /// Element-wise multiplication; delegates to `crate::ops::mul`.
    pub fn mul(&self, rhs: &Tensor) -> Tensor {
        crate::ops::mul(self, rhs)
    }
547
    /// Element-wise division; delegates to `crate::ops::div`.
    pub fn div(&self, rhs: &Tensor) -> Tensor {
        crate::ops::div(self, rhs)
    }
551
    /// Fused matmul + ReLU on the WGPU backend (single kernel, no bias term).
    #[cfg(feature = "wgpu_backend")]
    pub fn matmul_relu(&self, rhs: &Tensor) -> Tensor {
        crate::ops::matmul_fused(self, rhs, None, crate::backend::wgpu::Activation::ReLU)
    }
556
    /// CPU fallback for the fused matmul + ReLU: runs the two ops separately.
    #[cfg(not(feature = "wgpu_backend"))]
    pub fn matmul_relu(&self, rhs: &Tensor) -> Tensor {
        self.matmul(rhs).relu()
    }
561
    /// One SGD parameter update with learning rate `lr`; presumably returns
    /// `self - lr * grad` — confirm in `crate::ops::sgd_step`.
    pub fn sgd_step(&self, grad: &Tensor, lr: f32) -> Tensor {
        crate::ops::sgd_step(self, grad, lr)
    }
565
566    pub fn copy_(&self, src: &Tensor) {
567        #[cfg(feature = "wgpu_backend")]
568        if self.storage().device().is_wgpu() {
569            if let Some(dest_buf) = self.storage().wgpu_buffer() {
570                let ctx = crate::backend::wgpu::get_context().expect("WGPU context missing");
571
572                // Case 1: src is WGPU -> GPU Copy
573                if let Some(src_buf) = src.storage().wgpu_buffer() {
574                    let mut encoder =
575                        ctx.device
576                            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
577                                label: Some("Copy Encoder"),
578                            });
579                    encoder.copy_buffer_to_buffer(src_buf, 0, dest_buf, 0, dest_buf.size());
580                    ctx.queue.submit(Some(encoder.finish()));
581                    return;
582                }
583
584                // Case 2: src is CPU -> Upload
585                let src_cpu = if src.storage().device().is_wgpu() {
586                    src.to_cpu()
587                } else {
588                    src.clone()
589                };
590                let src_guard = src_cpu.data();
591                ctx.queue
592                    .write_buffer(dest_buf, 0, bytemuck::cast_slice(&src_guard));
593                return;
594            }
595        }
596
597        // Case 3: self is CPU -> Memcpy
598        let src_cpu = if src.storage().device().is_wgpu() {
599            src.to_cpu()
600        } else {
601            src.clone()
602        };
603        let mut dest_guard = self.data_mut();
604        let src_guard = src_cpu.data();
605        if dest_guard.len() != src_guard.len() {
606            // panic!("copy_: element count mismatch");
607            // Allow broadcasting? No, strict copy_ usually.
608        }
609        let len = std::cmp::min(dest_guard.len(), src_guard.len());
610        dest_guard[..len].copy_from_slice(&src_guard[..len]);
611    }
612
613    fn compute_strides(shape: &[usize]) -> Vec<usize> {
614        let mut strides = vec![0; shape.len()];
615        let mut stride = 1;
616        for i in (0..shape.len()).rev() {
617            strides[i] = stride;
618            stride *= shape[i];
619        }
620        strides
621    }
622
623    // pub fn expand(&self, target_shape: &[usize]) -> Tensor {
624    //    crate::broadcast::expand(self, target_shape)
625    // }
626
627    pub fn copy_from_slice(&self, src: &[f32]) {
628        let mut guard = self.data_mut();
629        let len = std::cmp::min(guard.len(), src.len());
630        guard[..len].copy_from_slice(&src[..len]);
631    }
632}
633
634// Implement arithmetic traits for &Tensor
635impl Add for &Tensor {
636    type Output = Tensor;
637    fn add(self, rhs: Self) -> Tensor {
638        self.add(rhs)
639    }
640}
641
// Tensor + Tensor (by value): borrows both operands and delegates to the
// inherent `Tensor::add` (which dispatches to `crate::ops::add`).
impl Add<Tensor> for Tensor {
    type Output = Tensor;
    fn add(self, rhs: Tensor) -> Tensor {
        Tensor::add(&self, &rhs)
    }
}
648
// Tensor - Tensor (by value): delegates to the inherent `Tensor::sub`.
impl Sub<Tensor> for Tensor {
    type Output = Tensor;
    fn sub(self, rhs: Tensor) -> Tensor {
        Tensor::sub(&self, &rhs)
    }
}
655
// Tensor * Tensor (by value): delegates to the inherent `Tensor::mul`.
impl Mul<Tensor> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: Tensor) -> Tensor {
        Tensor::mul(&self, &rhs)
    }
}
662
// Tensor / Tensor (by value): delegates to the inherent `Tensor::div`.
impl Div<Tensor> for Tensor {
    type Output = Tensor;
    fn div(self, rhs: Tensor) -> Tensor {
        Tensor::div(&self, &rhs)
    }
}
669
670impl Sub for &Tensor {
671    type Output = Tensor;
672    fn sub(self, rhs: Self) -> Tensor {
673        self.sub(rhs)
674    }
675}
676
677impl Mul for &Tensor {
678    type Output = Tensor;
679    fn mul(self, rhs: Self) -> Tensor {
680        self.mul(rhs)
681    }
682}
683
684impl Div for &Tensor {
685    type Output = Tensor;
686    fn div(self, rhs: Self) -> Tensor {
687        self.div(rhs)
688    }
689}
690
/// Plain serialization form of a tensor: shape plus a flat data vector.
/// The autograd graph and any accumulated gradient are not serialized.
#[cfg(feature = "serde")]
#[derive(Serialize, Deserialize)]
struct TensorData {
    shape: Vec<usize>,
    data: Vec<f32>,
    requires_grad: bool,
}
698
/// Serializes the tensor as a [`TensorData`] (`shape` + flat row-major data).
///
/// Fix: the previous version dumped the raw storage buffer, which only
/// matches `shape` for contiguous tensors — serializing a view (e.g. a
/// transpose) produced data in the wrong order (or the wrong length). We now
/// materialize a contiguous copy first; for already-contiguous tensors
/// `contiguous()` is a cheap clone.
#[cfg(feature = "serde")]
impl Serialize for Tensor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let contig = self.contiguous();
        let data = contig.data().clone();
        let tensor_data = TensorData {
            shape: contig.shape().to_vec(),
            data,
            requires_grad: self.requires_grad(),
        };
        tensor_data.serialize(serializer)
    }
}
714
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Tensor {
    /// Rebuilds a leaf CPU tensor from [`TensorData`]; panics (inside
    /// `Tensor::new`) if the stored data length does not match the shape.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let tensor_data = TensorData::deserialize(deserializer)?;
        let tensor = Tensor::new(&tensor_data.data, &tensor_data.shape)
            .set_requires_grad(tensor_data.requires_grad);
        Ok(tensor)
    }
}
727
728impl fmt::Display for Tensor {
729    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
730        let data = self.data();
731        let len = std::cmp::min(data.len(), 10);
732        write!(
733            f,
734            "Tensor(shape={:?}, data={:?})",
735            self.shape(),
736            &data[..len]
737        )
738    }
739}