//! tch_plus/wrappers/tensor.rs — Rust wrappers around the C++ tensor API.

1use super::stream::ReadSeekAdapter;
2use super::utils::{path_to_cstring, ptr_to_string};
3use super::{
4    device::{Cuda, Device},
5    kind,
6    kind::Kind,
7};
8use crate::TchError;
9use libc::{c_char, c_int, c_void};
10use std::borrow::Borrow;
11use std::io::{Read, Seek, Write};
12use std::path::Path;
13use torch_sys_plus::io::ReadStream;
14use torch_sys_plus::*;
15
/// A tensor object.
///
/// This is a thin wrapper around a raw pointer to the underlying C++
/// tensor; the pointer is released in the `Drop` implementation.
#[must_use]
pub struct Tensor {
    // Raw pointer to the underlying C++ tensor, owned by this struct.
    pub(super) c_tensor: *mut C_tensor,
}
21
// SAFETY: presumably the underlying C++ tensor may be moved to another
// thread as long as it is only accessed from one thread at a time —
// NOTE(review): confirm against the C++ API's thread-safety guarantees.
// `Sync` is deliberately not implemented here.
unsafe impl Send for Tensor {}
23
24pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) {
25    let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() };
26    let name = name.replace('|', ".");
27    let v: &mut Vec<(String, Tensor)> = unsafe { &mut *(data as *mut Vec<(String, Tensor)>) };
28    v.push((name, Tensor { c_tensor }))
29}
30
impl Tensor {
    /// Creates a new tensor.
    pub fn new() -> Tensor {
        let c_tensor = unsafe_torch!(at_new_tensor());
        Tensor { c_tensor }
    }

    /// Creates a new tensor from the pointer to an existing C++ tensor.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the pointer outlives the Rust
    /// object.
    pub unsafe fn from_ptr(c_tensor: *mut C_tensor) -> Self {
        Self { c_tensor }
    }

    /// Creates a new tensor from the pointer to an existing C++ tensor.
    ///
    /// # Safety
    ///
    /// A shallow copy of the pointer is made so there is no need for
    /// this pointer to remain valid for the whole lifetime of the Rust
    /// object.
    pub unsafe fn clone_from_ptr(c_tensor: *mut C_tensor) -> Self {
        let c_tensor = at_shallow_clone(c_tensor);
        // Surface any C++ exception recorded during the shallow clone.
        crate::wrappers::utils::read_and_clean_error().unwrap();
        Self { c_tensor }
    }

    /// Returns a pointer to the underlying C++ tensor.
    ///
    /// The caller must ensure that the Rust tensor object outlives
    /// this pointer.
    pub fn as_ptr(&self) -> *const C_tensor {
        self.c_tensor
    }

    /// Returns a mutable pointer to the underlying C++ tensor.
    ///
    /// The caller must ensure that the Rust tensor object outlives
    /// this pointer.
    pub fn as_mut_ptr(&mut self) -> *mut C_tensor {
        self.c_tensor
    }

    /// Returns the number of dimensions of the tensor.
    pub fn dim(&self) -> usize {
        unsafe_torch!(at_dim(self.c_tensor))
    }

    /// Returns the shape of the input tensor, one element per dimension.
    pub fn size(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        // `at_shape` fills the buffer with one size per dimension.
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_shape(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    /// Returns the tensor size for single dimension tensors.
    pub fn size1(&self) -> Result<i64, TchError> {
        match self.size().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
        }
    }

    /// Returns the tensor sizes for two dimension tensors.
    pub fn size2(&self) -> Result<(i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
        }
    }

    /// Returns the tensor sizes for three dimension tensors.
    pub fn size3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
        }
    }

    /// Returns the tensor sizes for four dimension tensors.
    pub fn size4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
        }
    }

    /// Returns the tensor sizes for five dimension tensors.
    pub fn size5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
        }
    }

    /// Returns the tensor sizes for six dimension tensors.
    pub fn size6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
        }
    }

    /// Returns the stride of the input tensor, one element per dimension.
    pub fn stride(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        // `at_stride` fills the buffer with one stride per dimension.
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_stride(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    /// Returns the tensor strides for single dimension tensors.
    pub fn stride1(&self) -> Result<i64, TchError> {
        match self.stride().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
        }
    }

    /// Returns the tensor strides for two dimension tensors.
    pub fn stride2(&self) -> Result<(i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
        }
    }

    /// Returns the tensor strides for three dimension tensors.
    pub fn stride3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
        }
    }

    /// Returns the tensor strides for four dimension tensors.
    pub fn stride4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
        }
    }

    /// Returns the tensor strides for five dimension tensors.
    pub fn stride5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
        }
    }

    /// Returns the tensor strides for six dimension tensors.
    pub fn stride6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
        }
    }

    /// Returns the kind of elements stored in the input tensor. Returns
    /// an error on undefined tensors and unsupported data types.
    pub fn f_kind(&self) -> Result<Kind, TchError> {
        let kind = unsafe_torch!(at_scalar_type(self.c_tensor));
        Kind::from_c_int(kind)
    }

    /// Returns the kind of elements stored in the input tensor. Panics
    /// on undefined tensors and unsupported data types.
    pub fn kind(&self) -> Kind {
        self.f_kind().unwrap()
    }

    /// Returns the device on which the input tensor is located.
    pub fn device(&self) -> Device {
        let device = unsafe_torch!(at_device(self.c_tensor));
        Device::from_c_int(device)
    }

    /// Prints the input tensor.
    ///
    /// Caution: this uses the C++ printer which prints the whole tensor even if
    /// it is very large.
    pub fn print(&self) {
        unsafe_torch!(at_print(self.c_tensor))
    }

    /// Returns a double value on tensors holding a single element. An error is
    /// returned otherwise.
    ///
    /// `idx` selects the element, one index per dimension.
    pub fn f_double_value(&self, idx: &[i64]) -> Result<f64, TchError> {
        Ok(unsafe_torch_err!({
            at_double_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    /// Returns an int value on tensors holding a single element. An error is
    /// returned otherwise.
    ///
    /// `idx` selects the element, one index per dimension.
    pub fn f_int64_value(&self, idx: &[i64]) -> Result<i64, TchError> {
        Ok(unsafe_torch_err!({
            at_int64_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    /// Returns a double value on tensors holding a single element. Panics otherwise.
    pub fn double_value(&self, idx: &[i64]) -> f64 {
        self.f_double_value(idx).unwrap()
    }

    /// Returns an int value on tensors holding a single element. Panics otherwise.
    pub fn int64_value(&self, idx: &[i64]) -> i64 {
        self.f_int64_value(idx).unwrap()
    }

    /// Returns true if gradients are currently tracked for this tensor.
    pub fn requires_grad(&self) -> bool {
        unsafe_torch!(at_requires_grad(self.c_tensor)) != 0
    }

    /// Returns the address of the first element of this tensor.
    pub fn data_ptr(&self) -> *mut c_void {
        unsafe_torch!(at_data_ptr(self.c_tensor))
    }

    /// Returns true if the tensor is defined.
    pub fn defined(&self) -> bool {
        unsafe_torch!(at_defined(self.c_tensor) != 0)
    }

    /// Returns true if the tensor is compatible with MKL-DNN (oneDNN).
    pub fn is_mkldnn(&self) -> bool {
        unsafe_torch!(at_is_mkldnn(self.c_tensor) != 0)
    }

    /// Returns true if the tensor is sparse.
    pub fn is_sparse(&self) -> bool {
        unsafe_torch!(at_is_sparse(self.c_tensor) != 0)
    }

    /// Returns true if the tensor is contiguous.
    pub fn is_contiguous(&self) -> bool {
        unsafe_torch!(at_is_contiguous(self.c_tensor) != 0)
    }

    /// Zeroes the gradient tensor attached to this tensor if defined.
    pub fn zero_grad(&mut self) {
        let mut grad = self.grad();
        if grad.defined() {
            // Detach the gradient tensor before zeroing it in place.
            let _ = grad.detach_().zero_();
        }
    }

    /// Runs the backward pass, populating the gradient tensors for tensors
    /// which gradients are tracked.
    ///
    /// Gradients tracking can be turned on via `set_requires_grad`.
    pub fn f_backward(&self) -> Result<(), TchError> {
        // The two zero flags presumably correspond to the keep_graph /
        // create_graph options used by `f_backward_with_grad`, both
        // disabled here — NOTE(review): confirm against torch_sys_plus.
        unsafe_torch_err!(at_backward(self.c_tensor, 0, 0));
        Ok(())
    }

    /// Runs the backward pass, populating the gradient tensors for tensors
    /// which gradients are tracked.
    ///
    /// Gradients tracking can be turned on via `set_requires_grad`.
    /// Panics if the C++ api returns an exception.
    pub fn backward(&self) {
        self.f_backward().unwrap()
    }

    /// Runs the backward pass, populating the gradient tensors for tensors
    /// which gradients are tracked.
    ///
    /// `grad` is used as the initial gradient for `self`; the
    /// `keep_graph` and `create_graph` flags are forwarded to the C++ API.
    ///
    /// Gradients tracking can be turned on via `set_requires_grad`.
    pub fn f_backward_with_grad(&self, grad: &Self, keep_graph: bool, create_graph: bool) -> Result<(), TchError> {
        // Convert the booleans to the C int convention (0/1).
        let keep_graph = if keep_graph { 1 } else { 0 };
        let create_graph = if create_graph { 1 } else { 0 };
        unsafe_torch_err!(at_backward_with_grad(self.c_tensor, grad.c_tensor, keep_graph, create_graph));
        Ok(())
    }

    /// Same as [`Self::f_backward_with_grad`] but panics if the C++ api
    /// returns an exception.
    pub fn backward_with_grad(&self, grad: &Self, keep_graph: bool, create_graph: bool) {
        self.f_backward_with_grad(grad, keep_graph, create_graph).unwrap()
    }

    /// Runs a backward pass on `tensors`, returning the gradients with
    /// respect to each element of `inputs` (one output tensor per input).
    pub fn f_run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Result<Vec<Tensor>, TchError>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        // The C++ side writes one gradient pointer per input into `outputs`.
        let mut outputs = vec![std::ptr::null_mut(); inputs.len()];
        let tensors: Vec<_> = tensors.iter().map(|x| x.borrow().c_tensor).collect();
        let inputs: Vec<_> = inputs.iter().map(|x| x.borrow().c_tensor).collect();
        unsafe_torch_err!(at_run_backward(
            tensors.as_ptr(),
            tensors.len() as c_int,
            inputs.as_ptr(),
            inputs.len() as c_int,
            outputs.as_mut_ptr(),
            keep_graph as c_int,
            create_graph as c_int,
        ));
        Ok(outputs.into_iter().map(|c_tensor| Tensor { c_tensor }).collect())
    }

    /// Same as [`Self::f_run_backward`] but panics if the C++ api returns
    /// an exception.
    pub fn run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Vec<Tensor>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        Tensor::f_run_backward(tensors, inputs, keep_graph, create_graph).unwrap()
    }

    /// Copies `numel` elements from `self` to `dst`.
    ///
    /// `dst` must hold at least `numel` elements worth of bytes, i.e.
    /// `numel` times the element size of this tensor's kind.
    pub fn f_copy_data_u8(&self, dst: &mut [u8], numel: usize) -> Result<(), TchError> {
        let elt_size_in_bytes = self.f_kind()?.elt_size_in_bytes();
        if dst.len() < numel * elt_size_in_bytes {
            return Err(TchError::Shape(format!("slice len < {numel}")));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            elt_size_in_bytes,
        ));
        Ok(())
    }

    /// Unscale tensor while checking for infinities.
    ///
    /// `found_inf` is a singleton tensor that is used to record the
    /// presence of infinite values. `inv_scale` is a scalar containing
    /// the inverse scaling factor. This method is only available
    /// for CUDA tensors.
    pub fn f_internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) -> Result<(), TchError> {
        unsafe_torch_err!(at__amp_non_finite_check_and_unscale(
            self.c_tensor,
            found_inf.c_tensor,
            inv_scale.c_tensor
        ));

        Ok(())
    }

    /// Unscale tensor while checking for infinities.
    ///
    /// `found_inf` is a singleton tensor that is used to record the
    /// presence of infinite values. `inv_scale` is a scalar containing
    /// the inverse scaling factor. This method is only available
    /// for CUDA tensors. Panics if the C++ api returns an exception.
    pub fn internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) {
        self.f_internal_amp_non_finite_check_and_unscale(found_inf, inv_scale).unwrap()
    }

    /// Copies `numel` elements from `self` to `dst`. Panics on error.
    pub fn copy_data_u8(&self, dst: &mut [u8], numel: usize) {
        self.f_copy_data_u8(dst, numel).unwrap()
    }

    /// Copies `numel` elements from `self` to `dst`.
    ///
    /// Fails if the element type `T` does not match the tensor kind or
    /// if `dst` holds fewer than `numel` elements.
    pub fn f_copy_data<T: kind::Element>(
        &self,
        dst: &mut [T],
        numel: usize,
    ) -> Result<(), TchError> {
        if T::KIND != self.f_kind()? {
            return Err(TchError::Kind(format!(
                "incoherent elt kind, {:?} != {:?}",
                self.f_kind(),
                T::KIND
            )));
        }
        if dst.len() < numel {
            return Err(TchError::Shape(format!("slice len < {numel}")));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            T::KIND.elt_size_in_bytes(),
        ));
        Ok(())
    }

    /// Copies `numel` elements from `self` to `dst`. Panics on error.
    pub fn copy_data<T: kind::Element>(&self, dst: &mut [T], numel: usize) {
        self.f_copy_data(dst, numel).unwrap()
    }

    /// Returns the total number of elements stored in a tensor.
    pub fn numel(&self) -> usize {
        self.size().iter().product::<i64>() as usize
    }

    // This is similar to vec_... but faster as it directly blits the data.
    /// Converts a slice to a one-dimensional tensor of `data.len()` elements.
    pub fn f_from_slice<T: kind::Element>(data: &[T]) -> Result<Tensor, TchError> {
        let data_len = data.len();
        let data = data.as_ptr() as *const c_void;
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            [data_len as i64].as_ptr(),
            1,
            T::KIND.elt_size_in_bytes(),
            T::KIND.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

    /// Converts a slice to a tensor. Panics on error.
    pub fn from_slice<T: kind::Element>(data: &[T]) -> Tensor {
        Self::f_from_slice(data).unwrap()
    }

    /// Converts some byte data to a tensor with some specified kind and shape.
    ///
    /// NOTE(review): the byte length of `data` is not validated against
    /// `size` and the element size here — presumably checked on the C++ side.
    pub fn f_from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Result<Tensor, TchError> {
        let data = data.as_ptr() as *const c_void;
        let elt_size_in_bytes = kind.elt_size_in_bytes();
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            size.as_ptr(),
            size.len(),
            elt_size_in_bytes,
            kind.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

    /// Creates a tensor from data that is assumed to be initialized.
    /// Resize operations are not allowed on this tensor without copying the data first.
    /// An empty strides slice will result in using the default strides.
    /// # Safety
    ///   Behavior is undefined if `data` points to invalid data.
    pub unsafe fn f_from_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Result<Tensor, TchError> {
        let data = data as *const c_void;
        // The macro expands to an `unsafe` block which is redundant inside
        // this `unsafe fn`, hence the lint allowance.
        #[allow(unused_unsafe)]
        let c_tensor = unsafe_torch_err!(at_tensor_of_blob(
            data,
            size.as_ptr(),
            size.len(),
            strides.as_ptr(),
            strides.len(),
            kind.c_int(),
            device.c_int()
        ));
        Ok(Tensor { c_tensor })
    }

    /// Creates a tensor from data that is assumed to be initialized.
    /// Resize operations are not allowed on this tensor without copying the data first.
    /// An empty strides slice will result in using the default strides.
    /// Panics on error.
    /// # Safety
    ///   Behavior is undefined if `data` points to invalid data.
    pub unsafe fn from_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Tensor {
        Self::f_from_blob(data, size, strides, kind, device).unwrap()
    }

    /// Converts some byte data to a tensor with some specified kind and shape.
    /// Panics on error.
    pub fn from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Tensor {
        Self::f_from_data_size(data, size, kind).unwrap()
    }

    /// Returns a new tensor that share storage with the input tensor.
    pub fn shallow_clone(&self) -> Tensor {
        let c_tensor = unsafe_torch!(at_shallow_clone(self.c_tensor));
        Tensor { c_tensor }
    }

    /// Gets the sub-tensor at the given index along the first dimension.
    pub fn f_get(&self, index: i64) -> Result<Tensor, TchError> {
        // NOTE(review): the index is narrowed from i64 to c_int for the
        // FFI call, so very large indexes would be truncated.
        let c_tensor = unsafe_torch_err!(at_get(self.c_tensor, index as c_int));
        Ok(Tensor { c_tensor })
    }

    /// Gets the sub-tensor at the given index. Panics on error.
    pub fn get(&self, index: i64) -> Tensor {
        self.f_get(index).unwrap()
    }

    /// Copies values from the argument tensor to the input tensor.
    pub fn f_copy_(&mut self, src: &Tensor) -> Result<(), TchError> {
        unsafe_torch_err!(at_copy_(self.c_tensor, src.c_tensor));
        Ok(())
    }

    /// Copies values from the argument tensor to the input tensor.
    /// Panics on error.
    pub fn copy_(&mut self, src: &Tensor) {
        self.f_copy_(src).unwrap()
    }

    /// Loads a tensor from a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
        let path = path_to_cstring(path)?;
        let c_tensor = unsafe_torch_err!(at_load(path.as_ptr()));
        Ok(Tensor { c_tensor })
    }

    /// Loads a tensor from a stream.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load_from_stream<T: Read + Seek>(stream: T) -> Result<Tensor, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        // Ownership of the double-boxed stream is transferred through the
        // raw pointer; presumably the C++ side is responsible for freeing
        // it — NOTE(review): confirm, otherwise this leaks on error.
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let c_tensor =
            unsafe_torch_err!(at_load_from_stream(Box::into_raw(boxed_stream) as *mut c_void,));
        Ok(Tensor { c_tensor })
    }

    /// Saves a tensor to a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        unsafe_torch_err!(at_save(self.c_tensor, path.as_ptr()));
        Ok(())
    }

    /// Saves a tensor to a stream.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn save_to_stream<W: Write>(&self, stream: W) -> Result<(), TchError> {
        // Ownership of the double-boxed writer is transferred through the
        // raw pointer; presumably freed by the C++ side — TODO confirm.
        let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
        unsafe_torch_err!(at_save_to_stream(
            self.c_tensor,
            Box::into_raw(boxed_stream) as *mut c_void,
        ));
        Ok(())
    }

    /// Saves some named tensors to a file
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    /// Dots in tensor names are stored as `|` and mapped back to dots when
    /// loading (see `add_callback`).
    pub fn save_multi<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
        named_tensors: &[(S, T)],
        path: P,
    ) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
        let names = named_tensors
            .iter()
            .map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
            .map(std::ffi::CString::new)
            .collect::<Result<Vec<_>, _>>()?;
        let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
        unsafe_torch_err!(at_save_multi(
            c_tensors.as_ptr(),
            name_ptrs.as_ptr(),
            names.len() as i32,
            path.as_ptr(),
        ));
        Ok(())
    }

    /// Saves some named tensors to a stream
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn save_multi_to_stream<S: AsRef<str>, T: AsRef<Tensor>, W: Write>(
        named_tensors: &[(S, T)],
        stream: W,
    ) -> Result<(), TchError> {
        // Ownership of the double-boxed writer is transferred through the
        // raw pointer; presumably freed by the C++ side — TODO confirm.
        let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
        let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
        let names = named_tensors
            .iter()
            .map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
            .map(std::ffi::CString::new)
            .collect::<Result<Vec<_>, _>>()?;
        let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
        unsafe_torch_err!(at_save_multi_to_stream(
            c_tensors.as_ptr(),
            name_ptrs.as_ptr(),
            names.len() as i32,
            Box::into_raw(boxed_stream) as *mut c_void,
        ));
        Ok(())
    }

    /// Loads some named tensors from a file
    ///
    /// The file format is the same as the one used for modules in the PyTorch C++ API.
    /// It commonly uses the .ot extension.
    pub fn load_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        // `add_callback` pushes each (name, tensor) pair into this vector.
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_callback(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback
        ));
        Ok(v)
    }

    /// Loads some named tensors from a file to a given device
    ///
    /// The file format is the same as the one used for modules in the PyTorch C++ API.
    /// It commonly uses the .ot extension.
    pub fn load_multi_with_device<T: AsRef<Path>>(
        path: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_callback_with_device(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback,
            device.c_int(),
        ));
        Ok(v)
    }

    /// Loads some named tensors from a zip file
    ///
    /// The expected file format is a zip archive containing a data.pkl file describing
    /// the embedded tensors. These are commonly used with the .bin extension to export
    /// PyTorch models and weights using the Python api.
    pub fn loadz_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_loadz_callback(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback
        ));
        Ok(v)
    }

    /// Loads some named tensors from a zip file to a given device
    ///
    /// The expected file format is a zip archive containing a data.pkl file describing
    /// the embedded tensors. These are commonly used with the .bin extension to export
    /// PyTorch models and weights using the Python api.
    pub fn loadz_multi_with_device<T: AsRef<Path>>(
        path: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_loadz_callback_with_device(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback,
            device.c_int(),
        ));
        Ok(v)
    }

    /// Loads some named tensors from a stream
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load_multi_from_stream<T: Read + Seek>(
        stream: T,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let mut v: Vec<(String, Tensor)> = vec![];
        // The trailing `false, 0` mirrors the `_with_device` variant below:
        // no device override is requested here.
        unsafe_torch_err!(at_load_from_stream_callback(
            Box::into_raw(boxed_stream) as *mut c_void,
            &mut v as *mut _ as *mut c_void,
            add_callback,
            false,
            0,
        ));
        Ok(v)
    }

    /// Loads some named tensors from a stream to a given device
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load_multi_from_stream_with_device<T: Read + Seek>(
        stream: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_from_stream_callback(
            Box::into_raw(boxed_stream) as *mut c_void,
            &mut v as *mut _ as *mut c_void,
            add_callback,
            true,
            device.c_int(),
        ));
        Ok(v)
    }

    /// Returns a string representation for the tensor.
    ///
    /// The representation will contain all the tensor element hence may be huge for
    /// large tensors. `lw` is presumably the maximum line width used by the
    /// C++ printer — NOTE(review): confirm against torch_sys_plus.
    ///
    /// NOTE(review): this inherent method shadows the `ToString::to_string`
    /// that would come from a `Display` impl.
    pub fn to_string(&self, lw: i64) -> Result<String, TchError> {
        let s =
            unsafe_torch_err!(ptr_to_string(torch_sys_plus::at_to_string(self.c_tensor, lw as c_int)));
        match s {
            None => Err(TchError::Kind("nullptr representation".to_string())),
            Some(s) => Ok(s),
        }
    }
}
764
765impl Default for Tensor {
766    fn default() -> Self {
767        Self::new()
768    }
769}
770
impl Drop for Tensor {
    // Release the underlying C++ tensor object when the wrapper is dropped.
    fn drop(&mut self) {
        unsafe_torch!(at_free(self.c_tensor))
    }
}
776
/// Clears the autocast internal cache (private helper, called when the
/// outermost autocast scope exits).
fn autocast_clear_cache() {
    unsafe_torch!(at_autocast_clear_cache())
}
780
/// Decrements the autocast nesting depth and returns the new value
/// (private helper).
fn autocast_decrement_nesting() -> isize {
    unsafe_torch!(at_autocast_decrement_nesting() as isize)
}
784
/// Increments the autocast nesting depth and returns the new value
/// (private helper).
fn autocast_increment_nesting() -> isize {
    unsafe_torch!(at_autocast_increment_nesting() as isize)
}
788
/// Returns whether autocast is currently enabled (private helper).
fn autocast_is_enabled() -> bool {
    unsafe_torch!(at_autocast_is_enabled() != 0)
}
792
/// Enables or disables autocast, returning the previous state
/// (private helper; the returned flag is used by `autocast` to restore
/// the prior setting).
fn autocast_set_enabled(b: bool) -> bool {
    unsafe_torch!(at_autocast_set_enabled(i32::from(b)) != 0)
}
796
797/// Runs a closure in mixed precision.
798pub fn autocast<T, F>(enabled: bool, f: F) -> T
799where
800    F: FnOnce() -> T,
801{
802    if !Cuda::is_available() {
803        return f();
804    }
805
806    // Check whether we are using CUDA.
807    let prev = autocast_is_enabled();
808    autocast_set_enabled(enabled);
809    autocast_increment_nesting();
810
811    let result = f();
812
813    if autocast_decrement_nesting() == 0 {
814        autocast_clear_cache();
815    }
816    autocast_set_enabled(prev);
817
818    result
819}
820
/// Enables or disables gradient tracking, returning the previous state
/// (private helper; the returned flag is used by `no_grad` / `with_grad`
/// and `NoGradGuard` to restore the prior setting).
fn grad_set_enabled(b: bool) -> bool {
    unsafe_torch!(at_grad_set_enabled(i32::from(b)) != 0)
}
824
825/// Runs a closure without keeping track of gradients.
826pub fn no_grad<T, F>(f: F) -> T
827where
828    F: FnOnce() -> T,
829{
830    let prev = grad_set_enabled(false);
831    let result = f();
832    let _false = grad_set_enabled(prev);
833    result
834}
835
836/// Runs a closure explicitly keeping track of gradients, this could be
837/// run within a no_grad closure for example.
838pub fn with_grad<T, F>(f: F) -> T
839where
840    F: FnOnce() -> T,
841{
842    let prev = grad_set_enabled(true);
843    let result = f();
844    let _false = grad_set_enabled(prev);
845    result
846}
847
/// A RAII guard that prevents gradient tracking until deallocated.
pub struct NoGradGuard {
    // Whether gradient tracking was enabled when the guard was created;
    // restored in the `Drop` implementation.
    enabled: bool,
}
852
853/// Disables gradient tracking, this will be enabled back when the
854/// returned value gets deallocated.
855/// Note that it is important to bind this to a name like `_guard`
856/// and not to `_` as the latter would immediately drop the guard.
857/// See <https://internals.rust-lang.org/t/pre-rfc-must-bind/12658/46>
858/// for more details.
859pub fn no_grad_guard() -> NoGradGuard {
860    NoGradGuard { enabled: grad_set_enabled(false) }
861}
862
// Lets APIs taking `impl AsRef<Tensor>` accept a `Tensor` value directly.
impl std::convert::AsRef<Tensor> for Tensor {
    fn as_ref(&self) -> &Self {
        self
    }
}
868
impl Drop for NoGradGuard {
    // Restore the gradient tracking state captured when the guard was
    // created.
    fn drop(&mut self) {
        let _enabled = grad_set_enabled(self.enabled);
    }
}
874
/// The reduction applied to a loss.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction {
    /// Do not reduce.
    None,
    /// Mean of losses.
    Mean,
    /// Sum of losses.
    Sum,
    /// Escape hatch in case new options become available.
    Other(i64),
}

impl Reduction {
    /// Converts this reduction to the integer constant expected by the
    /// C++ API. This has to stay in sync with
    /// pytorch/aten/src/ATen/core/Reduction.h.
    pub fn to_int(self) -> i64 {
        match self {
            Self::None => 0,
            Self::Mean => 1,
            Self::Sum => 2,
            Self::Other(i) => i,
        }
    }
}