use super::stream::ReadSeekAdapter;
use super::utils::{path_to_cstring, ptr_to_string};
use super::{
    device::{Cuda, Device},
    kind,
    kind::Kind,
};
use crate::TchError;
use libc::{c_char, c_int, c_void};
use std::borrow::Borrow;
use std::io::{Read, Seek, Write};
use std::path::Path;
use torch_sys_plus::io::ReadStream;
use torch_sys_plus::*;

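/// A tensor object backed by a raw pointer to the underlying C++ tensor.
/// The handle is owned: dropping the `Tensor` frees the C++ side.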
#[must_use]
pub struct Tensor {
    pub(super) c_tensor: *mut C_tensor,
}

unsafe impl Send for Tensor {}

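/// Callback invoked by the C loading helpers below for each named tensor;
/// the `|` separator used in saved archives is mapped back to `.` here.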
pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) {
    let name = unsafe { std::ffi::CStr::from_ptr(name).to_str().unwrap() };
    let name = name.replace('|', ".");
    let v: &mut Vec<(String, Tensor)> = unsafe { &mut *(data as *mut Vec<(String, Tensor)>) };
    v.push((name, Tensor { c_tensor }))
}

impl Tensor {
    pub fn new() -> Tensor {
        let c_tensor = unsafe_torch!(at_new_tensor());
        Tensor { c_tensor }
    }

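    /// Wraps a raw C tensor pointer without copying.
    ///
    /// # Safety
    /// `c_tensor` must point to a valid tensor; the returned `Tensor`
    /// takes ownership and frees it on drop.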
    pub unsafe fn from_ptr(c_tensor: *mut C_tensor) -> Self {
        Self { c_tensor }
    }

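    /// Creates a new handle sharing the storage behind `c_tensor`.
    ///
    /// # Safety
    /// `c_tensor` must point to a valid tensor.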
    pub unsafe fn clone_from_ptr(c_tensor: *mut C_tensor) -> Self {
        let c_tensor = at_shallow_clone(c_tensor);
        crate::wrappers::utils::read_and_clean_error().unwrap();
        Self { c_tensor }
    }

    pub fn as_ptr(&self) -> *const C_tensor {
        self.c_tensor
    }

    pub fn as_mut_ptr(&mut self) -> *mut C_tensor {
        self.c_tensor
    }

    pub fn dim(&self) -> usize {
        unsafe_torch!(at_dim(self.c_tensor))
    }

    pub fn size(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_shape(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    pub fn size1(&self) -> Result<i64, TchError> {
        match self.size().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
        }
    }

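    /// Returns the two sizes of a two-dimensional tensor, failing on any
    /// other rank. A minimal sketch using only helpers from this module:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1i64, 2, 3, 4]);
    /// assert_eq!(t.size1().unwrap(), 4); // 1-D tensor of length 4
    /// assert!(t.size2().is_err());       // not two-dimensional
    /// ```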
    pub fn size2(&self) -> Result<(i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
        }
    }

    pub fn size3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
        }
    }

    pub fn size4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
        }
    }

    pub fn size5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
        }
    }

    pub fn size6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
        }
    }

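    /// Returns the stride for each dimension, in number of elements (not
    /// bytes), matching the libtorch convention.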
    pub fn stride(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_stride(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    pub fn stride1(&self) -> Result<i64, TchError> {
        match self.stride().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {size:?}"))),
        }
    }

    pub fn stride2(&self) -> Result<(i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!("expected two dims, got {size:?}"))),
        }
    }

    pub fn stride3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!("expected three dims, got {size:?}"))),
        }
    }

    pub fn stride4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!("expected four dims, got {size:?}"))),
        }
    }

    pub fn stride5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!("expected five dims, got {size:?}"))),
        }
    }

    pub fn stride6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!("expected six dims, got {size:?}"))),
        }
    }

    pub fn f_kind(&self) -> Result<Kind, TchError> {
        let kind = unsafe_torch!(at_scalar_type(self.c_tensor));
        Kind::from_c_int(kind)
    }

    pub fn kind(&self) -> Kind {
        self.f_kind().unwrap()
    }

    pub fn device(&self) -> Device {
        let device = unsafe_torch!(at_device(self.c_tensor));
        Device::from_c_int(device)
    }

    pub fn print(&self) {
        unsafe_torch!(at_print(self.c_tensor))
    }

    pub fn f_double_value(&self, idx: &[i64]) -> Result<f64, TchError> {
        Ok(unsafe_torch_err!({
            at_double_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    pub fn f_int64_value(&self, idx: &[i64]) -> Result<i64, TchError> {
        Ok(unsafe_torch_err!({
            at_int64_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    pub fn double_value(&self, idx: &[i64]) -> f64 {
        self.f_double_value(idx).unwrap()
    }

    pub fn int64_value(&self, idx: &[i64]) -> i64 {
        self.f_int64_value(idx).unwrap()
    }

    pub fn requires_grad(&self) -> bool {
        unsafe_torch!(at_requires_grad(self.c_tensor)) != 0
    }

    pub fn data_ptr(&self) -> *mut c_void {
        unsafe_torch!(at_data_ptr(self.c_tensor))
    }

    pub fn defined(&self) -> bool {
        unsafe_torch!(at_defined(self.c_tensor) != 0)
    }

    pub fn is_mkldnn(&self) -> bool {
        unsafe_torch!(at_is_mkldnn(self.c_tensor) != 0)
    }

    pub fn is_sparse(&self) -> bool {
        unsafe_torch!(at_is_sparse(self.c_tensor) != 0)
    }

    pub fn is_contiguous(&self) -> bool {
        unsafe_torch!(at_is_contiguous(self.c_tensor) != 0)
    }

    pub fn zero_grad(&mut self) {
        let mut grad = self.grad();
        if grad.defined() {
            let _ = grad.detach_().zero_();
        }
    }

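    /// Computes gradients for this (scalar) tensor by backpropagation. A
    /// hedged sketch, assuming the usual tch-style generated ops
    /// (`set_requires_grad`, operator overloads, `grad`) are available:
    ///
    /// ```ignore
    /// let x = Tensor::from_slice(&[3.0f32]).set_requires_grad(true);
    /// let y = &x * &x;
    /// y.backward();
    /// let dy_dx = x.grad(); // d(x*x)/dx = 2x = 6.0
    /// ```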
    pub fn f_backward(&self) -> Result<(), TchError> {
        unsafe_torch_err!(at_backward(self.c_tensor, 0, 0));
        Ok(())
    }

    pub fn backward(&self) {
        self.f_backward().unwrap()
    }

    pub fn f_backward_with_grad(
        &self,
        grad: &Self,
        keep_graph: bool,
        create_graph: bool,
    ) -> Result<(), TchError> {
        let keep_graph = if keep_graph { 1 } else { 0 };
        let create_graph = if create_graph { 1 } else { 0 };
        unsafe_torch_err!(at_backward_with_grad(
            self.c_tensor,
            grad.c_tensor,
            keep_graph,
            create_graph
        ));
        Ok(())
    }

    pub fn backward_with_grad(&self, grad: &Self, keep_graph: bool, create_graph: bool) {
        self.f_backward_with_grad(grad, keep_graph, create_graph).unwrap()
    }

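    /// Computes and returns the gradients of `tensors` with respect to
    /// `inputs`. A hedged sketch with illustrative `loss` and `weight`
    /// tensors:
    ///
    /// ```ignore
    /// let grads = Tensor::run_backward(&[&loss], &[&weight], false, false);
    /// ```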
    pub fn f_run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Result<Vec<Tensor>, TchError>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        let mut outputs = vec![std::ptr::null_mut(); inputs.len()];
        let tensors: Vec<_> = tensors.iter().map(|x| x.borrow().c_tensor).collect();
        let inputs: Vec<_> = inputs.iter().map(|x| x.borrow().c_tensor).collect();
        unsafe_torch_err!(at_run_backward(
            tensors.as_ptr(),
            tensors.len() as c_int,
            inputs.as_ptr(),
            inputs.len() as c_int,
            outputs.as_mut_ptr(),
            keep_graph as c_int,
            create_graph as c_int,
        ));
        Ok(outputs.into_iter().map(|c_tensor| Tensor { c_tensor }).collect())
    }

    pub fn run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Vec<Tensor>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        Tensor::f_run_backward(tensors, inputs, keep_graph, create_graph).unwrap()
    }

    pub fn f_copy_data_u8(&self, dst: &mut [u8], numel: usize) -> Result<(), TchError> {
        let elt_size_in_bytes = self.f_kind()?.elt_size_in_bytes();
        if dst.len() < numel * elt_size_in_bytes {
            return Err(TchError::Shape(format!("slice len < {}", numel * elt_size_in_bytes)));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            elt_size_in_bytes,
        ));
        Ok(())
    }

    pub fn f_internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) -> Result<(), TchError> {
        unsafe_torch_err!(at__amp_non_finite_check_and_unscale(
            self.c_tensor,
            found_inf.c_tensor,
            inv_scale.c_tensor
        ));

        Ok(())
    }

    pub fn internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) {
        self.f_internal_amp_non_finite_check_and_unscale(found_inf, inv_scale).unwrap()
    }

    pub fn copy_data_u8(&self, dst: &mut [u8], numel: usize) {
        self.f_copy_data_u8(dst, numel).unwrap()
    }

    pub fn f_copy_data<T: kind::Element>(
        &self,
        dst: &mut [T],
        numel: usize,
    ) -> Result<(), TchError> {
        let kind = self.f_kind()?;
        if T::KIND != kind {
            return Err(TchError::Kind(format!(
                "incoherent elt kind, {kind:?} != {:?}",
                T::KIND
            )));
        }
        if dst.len() < numel {
            return Err(TchError::Shape(format!("slice len < {numel}")));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            T::KIND.elt_size_in_bytes(),
        ));
        Ok(())
    }

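    /// Copies `numel` elements from the tensor into `dst`, whose element
    /// kind must match the tensor's. A sketch:
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    /// let mut v = vec![0f32; t.numel()];
    /// t.copy_data(&mut v, t.numel());
    /// assert_eq!(v, [1.0, 2.0, 3.0]);
    /// ```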
    pub fn copy_data<T: kind::Element>(&self, dst: &mut [T], numel: usize) {
        self.f_copy_data(dst, numel).unwrap()
    }

    pub fn numel(&self) -> usize {
        self.size().iter().product::<i64>() as usize
    }

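    /// Builds a one-dimensional tensor holding a copy of `data`; the
    /// element kind is derived from `T`.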
    pub fn f_from_slice<T: kind::Element>(data: &[T]) -> Result<Tensor, TchError> {
        let data_len = data.len();
        let data = data.as_ptr() as *const c_void;
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            [data_len as i64].as_ptr(),
            1,
            T::KIND.elt_size_in_bytes(),
            T::KIND.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

    pub fn from_slice<T: kind::Element>(data: &[T]) -> Tensor {
        Self::f_from_slice(data).unwrap()
    }

    pub fn f_from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Result<Tensor, TchError> {
        let data = data.as_ptr() as *const c_void;
        let elt_size_in_bytes = kind.elt_size_in_bytes();
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            size.as_ptr(),
            size.len(),
            elt_size_in_bytes,
            kind.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

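    /// Creates a tensor backed by the caller-provided buffer, without
    /// copying.
    ///
    /// # Safety
    /// `data` must be valid for the given `size` and `strides` and must
    /// outlive the returned tensor; the buffer is not freed on drop.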
    pub unsafe fn f_from_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Result<Tensor, TchError> {
        let data = data as *const c_void;
        #[allow(unused_unsafe)]
        let c_tensor = unsafe_torch_err!(at_tensor_of_blob(
            data,
            size.as_ptr(),
            size.len(),
            strides.as_ptr(),
            strides.len(),
            kind.c_int(),
            device.c_int()
        ));
        Ok(Tensor { c_tensor })
    }

    pub unsafe fn from_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Tensor {
        Self::f_from_blob(data, size, strides, kind, device).unwrap()
    }

    pub fn from_data_size(data: &[u8], size: &[i64], kind: Kind) -> Tensor {
        Self::f_from_data_size(data, size, kind).unwrap()
    }

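    /// Returns a new handle onto the same underlying storage: in-place
    /// changes made through one handle are visible through the other.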
    pub fn shallow_clone(&self) -> Tensor {
        let c_tensor = unsafe_torch!(at_shallow_clone(self.c_tensor));
        Tensor { c_tensor }
    }

    pub fn f_get(&self, index: i64) -> Result<Tensor, TchError> {
        let c_tensor = unsafe_torch_err!(at_get(self.c_tensor, index as c_int));
        Ok(Tensor { c_tensor })
    }

    pub fn get(&self, index: i64) -> Tensor {
        self.f_get(index).unwrap()
    }

    pub fn f_copy_(&mut self, src: &Tensor) -> Result<(), TchError> {
        unsafe_torch_err!(at_copy_(self.c_tensor, src.c_tensor));
        Ok(())
    }

    pub fn copy_(&mut self, src: &Tensor) {
        self.f_copy_(src).unwrap()
    }

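    /// Loads a tensor from the file at `path`, in the format produced by
    /// `save`. A round-trip sketch (the path is illustrative):
    ///
    /// ```ignore
    /// let t = Tensor::from_slice(&[1i64, 2, 3]);
    /// t.save("t.pt")?;
    /// let u = Tensor::load("t.pt")?;
    /// assert_eq!(u.size1()?, 3);
    /// ```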
    pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
        let path = path_to_cstring(path)?;
        let c_tensor = unsafe_torch_err!(at_load(path.as_ptr()));
        Ok(Tensor { c_tensor })
    }

    pub fn load_from_stream<T: Read + Seek>(stream: T) -> Result<Tensor, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let c_tensor =
            unsafe_torch_err!(at_load_from_stream(Box::into_raw(boxed_stream) as *mut c_void));
        Ok(Tensor { c_tensor })
    }

    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        unsafe_torch_err!(at_save(self.c_tensor, path.as_ptr()));
        Ok(())
    }

    pub fn save_to_stream<W: Write>(&self, stream: W) -> Result<(), TchError> {
        let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
        unsafe_torch_err!(at_save_to_stream(
            self.c_tensor,
            Box::into_raw(boxed_stream) as *mut c_void,
        ));
        Ok(())
    }

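    /// Saves several named tensors to a file readable by `load_multi`.
    /// Note that `.` in names is stored as `|` and mapped back on load.
    /// A sketch with illustrative `w`/`b` tensors:
    ///
    /// ```ignore
    /// Tensor::save_multi(&[("w", &w), ("b", &b)], "weights.pt")?;
    /// for (name, t) in Tensor::load_multi("weights.pt")? {
    ///     println!("{name}: {:?}", t.size());
    /// }
    /// ```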
    pub fn save_multi<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
        named_tensors: &[(S, T)],
        path: P,
    ) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
        let names = named_tensors
            .iter()
            .map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
            .map(std::ffi::CString::new)
            .collect::<Result<Vec<_>, _>>()?;
        let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
        unsafe_torch_err!(at_save_multi(
            c_tensors.as_ptr(),
            name_ptrs.as_ptr(),
            names.len() as i32,
            path.as_ptr(),
        ));
        Ok(())
    }

    pub fn save_multi_to_stream<S: AsRef<str>, T: AsRef<Tensor>, W: Write>(
        named_tensors: &[(S, T)],
        stream: W,
    ) -> Result<(), TchError> {
        let boxed_stream: Box<Box<dyn Write>> = Box::new(Box::new(stream));
        let c_tensors = named_tensors.iter().map(|nt| nt.1.as_ref().c_tensor).collect::<Vec<_>>();
        let names = named_tensors
            .iter()
            .map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
            .map(std::ffi::CString::new)
            .collect::<Result<Vec<_>, _>>()?;
        let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
        unsafe_torch_err!(at_save_multi_to_stream(
            c_tensors.as_ptr(),
            name_ptrs.as_ptr(),
            names.len() as i32,
            Box::into_raw(boxed_stream) as *mut c_void,
        ));
        Ok(())
    }

    pub fn load_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_callback(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback
        ));
        Ok(v)
    }

    pub fn load_multi_with_device<T: AsRef<Path>>(
        path: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_callback_with_device(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback,
            device.c_int(),
        ));
        Ok(v)
    }

    pub fn loadz_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_loadz_callback(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback
        ));
        Ok(v)
    }

    pub fn loadz_multi_with_device<T: AsRef<Path>>(
        path: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_loadz_callback_with_device(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback,
            device.c_int(),
        ));
        Ok(v)
    }

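    /// Loads named tensors from a seekable stream. A sketch that reads a
    /// whole file into memory first (the path is illustrative):
    ///
    /// ```ignore
    /// let bytes = std::fs::read("weights.pt")?;
    /// let named = Tensor::load_multi_from_stream(std::io::Cursor::new(bytes))?;
    /// ```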
    pub fn load_multi_from_stream<T: Read + Seek>(
        stream: T,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_from_stream_callback(
            Box::into_raw(boxed_stream) as *mut c_void,
            &mut v as *mut _ as *mut c_void,
            add_callback,
            false,
            0,
        ));
        Ok(v)
    }

    pub fn load_multi_from_stream_with_device<T: Read + Seek>(
        stream: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let adapter = ReadSeekAdapter::new(stream);
        let boxed_stream: Box<Box<dyn ReadStream>> = Box::new(Box::new(adapter));
        let mut v: Vec<(String, Tensor)> = vec![];
        unsafe_torch_err!(at_load_from_stream_callback(
            Box::into_raw(boxed_stream) as *mut c_void,
            &mut v as *mut _ as *mut c_void,
            add_callback,
            true,
            device.c_int(),
        ));
        Ok(v)
    }

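    /// Returns a textual representation of the tensor; `lw` caps the line
    /// width of the output.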
    pub fn to_string(&self, lw: i64) -> Result<String, TchError> {
        let s = unsafe_torch_err!(ptr_to_string(torch_sys_plus::at_to_string(
            self.c_tensor,
            lw as c_int
        )));
        match s {
            None => Err(TchError::Kind("nullptr representation".to_string())),
            Some(s) => Ok(s),
        }
    }
}

impl Default for Tensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Tensor {
    fn drop(&mut self) {
        unsafe_torch!(at_free(self.c_tensor))
    }
}

fn autocast_clear_cache() {
    unsafe_torch!(at_autocast_clear_cache())
}

fn autocast_decrement_nesting() -> isize {
    unsafe_torch!(at_autocast_decrement_nesting() as isize)
}

fn autocast_increment_nesting() -> isize {
    unsafe_torch!(at_autocast_increment_nesting() as isize)
}

fn autocast_is_enabled() -> bool {
    unsafe_torch!(at_autocast_is_enabled() != 0)
}

fn autocast_set_enabled(b: bool) -> bool {
    unsafe_torch!(at_autocast_set_enabled(i32::from(b)) != 0)
}

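/// Runs `f` with CUDA automatic mixed precision set to `enabled`, then
/// restores the previous autocast state; when CUDA is unavailable this is
/// a plain call to `f`. A hedged sketch (`model` and `images` are
/// illustrative):
///
/// ```ignore
/// let logits = autocast(true, || model.forward(&images));
/// ```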
pub fn autocast<T, F>(enabled: bool, f: F) -> T
where
    F: FnOnce() -> T,
{
    if !Cuda::is_available() {
        return f();
    }

    let prev = autocast_is_enabled();
    autocast_set_enabled(enabled);
    autocast_increment_nesting();

    let result = f();

    if autocast_decrement_nesting() == 0 {
        autocast_clear_cache();
    }
    autocast_set_enabled(prev);

    result
}

fn grad_set_enabled(b: bool) -> bool {
    unsafe_torch!(at_grad_set_enabled(i32::from(b)) != 0)
}

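/// Runs `f` with gradient tracking disabled, then restores the previous
/// mode. A sketch (`model` and `x` are illustrative):
///
/// ```ignore
/// let prediction = no_grad(|| model.forward(&x));
/// ```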
pub fn no_grad<T, F>(f: F) -> T
where
    F: FnOnce() -> T,
{
    let prev = grad_set_enabled(false);
    let result = f();
    let _ = grad_set_enabled(prev);
    result
}

pub fn with_grad<T, F>(f: F) -> T
where
    F: FnOnce() -> T,
{
    let prev = grad_set_enabled(true);
    let result = f();
    let _ = grad_set_enabled(prev);
    result
}

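/// An RAII guard returned by `no_grad_guard`: gradient tracking stays
/// disabled while the guard is alive and the previous mode is restored on
/// drop.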
pub struct NoGradGuard {
    enabled: bool,
}

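/// Disables gradient tracking until the returned guard is dropped. A
/// sketch:
///
/// ```ignore
/// {
///     let _guard = no_grad_guard();
///     // no gradient bookkeeping happens here
/// }
/// // the previous grad mode is restored once the guard is dropped
/// ```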
pub fn no_grad_guard() -> NoGradGuard {
    NoGradGuard { enabled: grad_set_enabled(false) }
}

impl std::convert::AsRef<Tensor> for Tensor {
    fn as_ref(&self) -> &Self {
        self
    }
}

impl Drop for NoGradGuard {
    fn drop(&mut self) {
        let _enabled = grad_set_enabled(self.enabled);
    }
}

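/// Specifies how a loss is reduced over its elements; `to_int` yields the
/// integer codes expected by the libtorch reduction arguments.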
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction {
    None,
    Mean,
    Sum,
    Other(i64),
}

impl Reduction {
    pub fn to_int(self) -> i64 {
        match self {
            Reduction::None => 0,
            Reduction::Mean => 1,
            Reduction::Sum => 2,
            Reduction::Other(i) => i,
        }
    }
}