1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
use super::utils::{path_to_cstring, ptr_to_string};
use super::{
    device::{Cuda, Device},
    kind,
    kind::Kind,
};
use crate::TchError;
use libc::{c_char, c_int, c_void};
use std::borrow::Borrow;
use std::path::Path;
use torch_sys::*;

/// A tensor object.
///
/// Wraps a raw `C_tensor` handle obtained from the C API; the handle is released
/// with `at_free` when the value is dropped (see the `Drop` impl for `Tensor`).
#[must_use]
pub struct Tensor {
    // Raw handle to the C++ tensor. Each `Tensor` value frees its handle on drop,
    // so the handle must not be shared between two `Tensor` values.
    pub(super) c_tensor: *mut C_tensor,
}

// Raw pointers are `!Send` by default; this impl asserts that the wrapped handle
// may be moved to another thread.
// NOTE(review): soundness relies on libtorch tensor handles being safe to use from
// a thread other than the one that created them — confirm.
unsafe impl Send for Tensor {}

/// C callback used by `Tensor::load_multi` / `Tensor::load_multi_with_device` to
/// collect named tensors.
///
/// `data` must be the `*mut Vec<(String, Tensor)>` passed by those loaders, `name`
/// a NUL-terminated C string, and ownership of `c_tensor` is transferred to the
/// newly created `Tensor`. The `|` separators used on disk are turned back into `.`.
pub extern "C" fn add_callback(data: *mut c_void, name: *const c_char, c_tensor: *mut C_tensor) {
    // SAFETY: the caller passes a valid, NUL-terminated C string.
    let name = unsafe { std::ffi::CStr::from_ptr(name) };
    // Decode lossily instead of unwrapping: a panic here would unwind out of an
    // `extern "C"` function into C++ code.
    let name = name.to_string_lossy().replace('|', ".");
    // SAFETY: `data` is the `&mut Vec<(String, Tensor)>` handed to the C loader by
    // `load_multi` / `load_multi_with_device`, and it outlives the load call.
    let v: &mut Vec<(String, Tensor)> = unsafe { &mut *(data as *mut Vec<(String, Tensor)>) };
    v.push((name, Tensor { c_tensor }))
}

impl Tensor {
    /// Creates a new tensor.
    pub fn new() -> Tensor {
        let c_tensor = unsafe_torch!(at_new_tensor());
        Tensor { c_tensor }
    }

    /// Returns the number of dimensions of the tensor.
    pub fn dim(&self) -> usize {
        unsafe_torch!(at_dim(self.c_tensor))
    }

    /// Returns the shape of the input tensor, one entry per dimension.
    pub fn size(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        // The buffer is sized from `at_dim` so `at_shape` can write one entry per dimension.
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_shape(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    /// Returns the tensor size for single dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly one dimension.
    pub fn size1(&self) -> Result<i64, TchError> {
        match self.size().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {:?}", size))),
        }
    }

    /// Returns the tensor sizes for two dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly two dimensions.
    pub fn size2(&self) -> Result<(i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!(
                "expected two dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor sizes for three dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly three dimensions.
    pub fn size3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!(
                "expected three dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor sizes for four dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly four dimensions.
    pub fn size4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!(
                "expected four dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor sizes for five dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly five dimensions.
    pub fn size5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!(
                "expected five dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor sizes for six dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly six dimensions.
    pub fn size6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.size().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!(
                "expected six dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the stride of the input tensor, one entry per dimension.
    pub fn stride(&self) -> Vec<i64> {
        let dim = unsafe_torch!(at_dim(self.c_tensor));
        // The buffer is sized from `at_dim` so `at_stride` can write one entry per dimension.
        let mut sz = vec![0i64; dim];
        unsafe_torch!(at_stride(self.c_tensor, sz.as_mut_ptr()));
        sz
    }

    /// Returns the tensor strides for single dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly one dimension.
    pub fn stride1(&self) -> Result<i64, TchError> {
        match self.stride().as_slice() {
            &[s0] => Ok(s0),
            size => Err(TchError::Shape(format!("expected one dim, got {:?}", size))),
        }
    }

    /// Returns the tensor strides for two dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly two dimensions.
    pub fn stride2(&self) -> Result<(i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1] => Ok((s0, s1)),
            size => Err(TchError::Shape(format!(
                "expected two dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor strides for three dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly three dimensions.
    pub fn stride3(&self) -> Result<(i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2] => Ok((s0, s1, s2)),
            size => Err(TchError::Shape(format!(
                "expected three dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor strides for four dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly four dimensions.
    pub fn stride4(&self) -> Result<(i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3] => Ok((s0, s1, s2, s3)),
            size => Err(TchError::Shape(format!(
                "expected four dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor strides for five dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly five dimensions.
    pub fn stride5(&self) -> Result<(i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4] => Ok((s0, s1, s2, s3, s4)),
            size => Err(TchError::Shape(format!(
                "expected five dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the tensor strides for six dimension tensors.
    ///
    /// Returns a `TchError::Shape` if the tensor does not have exactly six dimensions.
    pub fn stride6(&self) -> Result<(i64, i64, i64, i64, i64, i64), TchError> {
        match self.stride().as_slice() {
            &[s0, s1, s2, s3, s4, s5] => Ok((s0, s1, s2, s3, s4, s5)),
            size => Err(TchError::Shape(format!(
                "expected six dims, got {:?}",
                size
            ))),
        }
    }

    /// Returns the kind of elements stored in the input tensor. Returns
    /// an error on undefined tensors and unsupported data types.
    pub fn f_kind(&self) -> Result<Kind, TchError> {
        let kind = unsafe_torch!(at_scalar_type(self.c_tensor));
        Kind::of_c_int(kind)
    }

    /// Returns the kind of elements stored in the input tensor.
    ///
    /// # Panics
    ///
    /// Panics on undefined tensors and unsupported data types; see `f_kind`
    /// for the non-panicking version.
    pub fn kind(&self) -> Kind {
        self.f_kind().unwrap()
    }

    /// Returns the device on which the input tensor is located.
    pub fn device(&self) -> Device {
        let device = unsafe_torch!(at_device(self.c_tensor));
        Device::of_c_int(device)
    }

    /// Prints the input tensor.
    ///
    /// Caution: this uses the C++ printer which prints the whole tensor even if
    /// it is very large.
    pub fn print(&self) {
        unsafe_torch!(at_print(self.c_tensor))
    }

    /// Returns the element at the multi-dimensional index `idx`, converted to `f64`.
    ///
    /// An error is returned if the C++ call fails. NOTE(review): presumably an
    /// empty `idx` reads the value of a single-element tensor, and an `idx` that does
    /// not match the tensor's rank or bounds is reported as an error — confirm.
    pub fn f_double_value(&self, idx: &[i64]) -> Result<f64, TchError> {
        Ok(unsafe_torch_err!({
            at_double_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    /// Returns the element at the multi-dimensional index `idx`, converted to `i64`.
    ///
    /// An error is returned if the C++ call fails (see `f_double_value`).
    pub fn f_int64_value(&self, idx: &[i64]) -> Result<i64, TchError> {
        Ok(unsafe_torch_err!({
            at_int64_value_at_indexes(self.c_tensor, idx.as_ptr(), idx.len() as i32)
        }))
    }

    /// Returns the element at index `idx` as an `f64`. Panics if `f_double_value` fails.
    pub fn double_value(&self, idx: &[i64]) -> f64 {
        self.f_double_value(idx).unwrap()
    }

    /// Returns the element at index `idx` as an `i64`. Panics if `f_int64_value` fails.
    pub fn int64_value(&self, idx: &[i64]) -> i64 {
        self.f_int64_value(idx).unwrap()
    }

    /// Returns true if gradients are currently tracked for this tensor.
    pub fn requires_grad(&self) -> bool {
        unsafe_torch!(at_requires_grad(self.c_tensor)) != 0
    }

    /// Returns the address of the first element of this tensor.
    pub fn data_ptr(&self) -> *mut c_void {
        unsafe_torch!(at_data_ptr(self.c_tensor))
    }

    /// Returns true if the tensor is defined.
    pub fn defined(&self) -> bool {
        unsafe_torch!(at_defined(self.c_tensor) != 0)
    }

    /// Returns true if the tensor is compatible with MKL-DNN (oneDNN).
    pub fn is_mkldnn(&self) -> bool {
        unsafe_torch!(at_is_mkldnn(self.c_tensor) != 0)
    }

    /// Returns true if the tensor is sparse.
    pub fn is_sparse(&self) -> bool {
        unsafe_torch!(at_is_sparse(self.c_tensor) != 0)
    }

    /// Zeroes the gradient tensor attached to this tensor if defined.
    pub fn zero_grad(&mut self) {
        let mut grad = self.grad();
        if grad.defined() {
            let _ = grad.detach_().zero_();
        }
    }

    /// Runs the backward pass, populating the gradient tensors for tensors
    /// which gradients are tracked.
    ///
    /// Gradients tracking can be turned on via `set_requires_grad`.
    pub fn f_backward(&self) -> Result<(), TchError> {
        unsafe_torch_err!(at_backward(self.c_tensor, 0, 0));
        Ok(())
    }

    /// Runs the backward pass, populating the gradient tensors for tensors
    /// which gradients are tracked.
    ///
    /// Gradients tracking can be turned on via `set_requires_grad`.
    /// Panics if the C++ api returns an exception.
    pub fn backward(&self) {
        self.f_backward().unwrap()
    }

    /// Runs a backward pass from `tensors` and returns one gradient tensor per
    /// element of `inputs`, in the same order.
    ///
    /// NOTE(review): `keep_graph` / `create_graph` are presumably forwarded as
    /// libtorch's `retain_graph` / `create_graph` options — confirm.
    pub fn f_run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Result<Vec<Tensor>, TchError>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        // One output slot per input; the C side fills each with a new tensor handle.
        let mut outputs = vec![std::ptr::null_mut(); inputs.len()];
        let tensors: Vec<_> = tensors.iter().map(|x| x.borrow().c_tensor).collect();
        let inputs: Vec<_> = inputs.iter().map(|x| x.borrow().c_tensor).collect();
        unsafe_torch_err!(at_run_backward(
            tensors.as_ptr(),
            tensors.len() as c_int,
            inputs.as_ptr(),
            inputs.len() as c_int,
            outputs.as_mut_ptr(),
            keep_graph as c_int,
            create_graph as c_int,
        ));
        Ok(outputs
            .into_iter()
            .map(|c_tensor| Tensor { c_tensor })
            .collect())
    }

    /// Panicking version of `f_run_backward`.
    pub fn run_backward<T1, T2>(
        tensors: &[T1],
        inputs: &[T2],
        keep_graph: bool,
        create_graph: bool,
    ) -> Vec<Tensor>
    where
        T1: Borrow<Tensor>,
        T2: Borrow<Tensor>,
    {
        Tensor::f_run_backward(tensors, inputs, keep_graph, create_graph).unwrap()
    }

    /// Copies `numel` elements from `self` to the byte buffer `dst`.
    ///
    /// `dst` must hold at least `numel * elt_size_in_bytes` bytes for the tensor's
    /// kind; otherwise a `TchError::Shape` is returned.
    pub fn f_copy_data_u8(&self, dst: &mut [u8], numel: usize) -> Result<(), TchError> {
        let elt_size_in_bytes = self.f_kind()?.elt_size_in_bytes();
        // Checked so that a wrapped product cannot slip past the length check below.
        let required = numel.checked_mul(elt_size_in_bytes).ok_or_else(|| {
            TchError::Shape(format!("byte size overflow for {} elements", numel))
        })?;
        if dst.len() < required {
            return Err(TchError::Shape(format!(
                "slice len {} < {} bytes required for {} elements",
                dst.len(),
                required,
                numel
            )));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            elt_size_in_bytes,
        ));
        Ok(())
    }

    /// Unscale tensor while checking for infinities.
    ///
    /// `found_inf` is a singleton tensor that is used to record the
    /// presence of infinite values. `inv_scale` is a scalar containing
    /// the inverse scaling factor. This method is only available
    /// for CUDA tensors.
    pub fn f_internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) -> Result<(), TchError> {
        unsafe_torch_err!(at__amp_non_finite_check_and_unscale(
            self.c_tensor,
            found_inf.c_tensor,
            inv_scale.c_tensor
        ));

        Ok(())
    }

    /// Unscale tensor while checking for infinities.
    ///
    /// `found_inf` is a singleton tensor that is used to record the
    /// presence of infinite values. `inv_scale` is a scalar containing
    /// the inverse scaling factor. This method is only available
    /// for CUDA tensors. Panics if the C++ api returns an exception.
    pub fn internal_amp_non_finite_check_and_unscale(
        &mut self,
        found_inf: &mut Tensor,
        inv_scale: &Tensor,
    ) {
        self.f_internal_amp_non_finite_check_and_unscale(found_inf, inv_scale)
            .unwrap()
    }

    /// Copies `numel` elements from `self` to `dst`. Panics on error.
    pub fn copy_data_u8(&self, dst: &mut [u8], numel: usize) {
        self.f_copy_data_u8(dst, numel).unwrap()
    }

    /// Copies `numel` elements from `self` to `dst`.
    ///
    /// Returns a `TchError::Kind` if `T` does not match the tensor's kind and a
    /// `TchError::Shape` if `dst` holds fewer than `numel` elements.
    pub fn f_copy_data<T: kind::Element>(
        &self,
        dst: &mut [T],
        numel: usize,
    ) -> Result<(), TchError> {
        let kind = self.f_kind()?;
        if T::KIND != kind {
            return Err(TchError::Kind(format!(
                "incoherent elt kind, {:?} != {:?}",
                kind,
                T::KIND
            )));
        }
        if dst.len() < numel {
            return Err(TchError::Shape(format!("slice len < {}", numel)));
        }
        unsafe_torch_err!(at_copy_data(
            self.c_tensor,
            dst.as_mut_ptr() as *const c_void,
            numel,
            T::KIND.elt_size_in_bytes(),
        ));
        Ok(())
    }

    /// Copies `numel` elements from `self` to `dst`. Panics on error.
    pub fn copy_data<T: kind::Element>(&self, dst: &mut [T], numel: usize) {
        self.f_copy_data(dst, numel).unwrap()
    }

    /// Returns the total number of elements stored in a tensor.
    pub fn numel(&self) -> usize {
        self.size().iter().product::<i64>() as usize
    }

    // This is similar to vec_... but faster as it directly blits the data.
    /// Converts a slice to a one-dimensional tensor.
    pub fn f_of_slice<T: kind::Element>(data: &[T]) -> Result<Tensor, TchError> {
        let data_len = data.len();
        let data = data.as_ptr() as *const c_void;
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            [data_len as i64].as_ptr(),
            1,
            T::KIND.elt_size_in_bytes(),
            T::KIND.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

    /// Converts a slice to a one-dimensional tensor. Panics on error.
    pub fn of_slice<T: kind::Element>(data: &[T]) -> Tensor {
        Self::f_of_slice(data).unwrap()
    }

    /// Converts some byte data to a tensor with some specified kind and shape.
    ///
    /// Returns a `TchError::Shape` if `size` contains a negative dimension, if the
    /// byte size overflows, or if `data` is shorter than
    /// `product(size) * kind.elt_size_in_bytes()` bytes.
    pub fn f_of_data_size(data: &[u8], size: &[i64], kind: Kind) -> Result<Tensor, TchError> {
        let elt_size_in_bytes = kind.elt_size_in_bytes();
        // `at_tensor_of_data` is not told `data.len()`, so validate the buffer here:
        // a short slice would otherwise be read out of bounds from a safe function.
        let required = size
            .iter()
            .try_fold(elt_size_in_bytes as u64, |acc, &dim| {
                if dim < 0 {
                    None
                } else {
                    acc.checked_mul(dim as u64)
                }
            })
            .ok_or_else(|| TchError::Shape(format!("invalid tensor size {:?}", size)))?;
        if (data.len() as u64) < required {
            return Err(TchError::Shape(format!(
                "data len {} < {} bytes required for size {:?} and kind {:?}",
                data.len(),
                required,
                size,
                kind
            )));
        }
        let data = data.as_ptr() as *const c_void;
        let c_tensor = unsafe_torch_err!(at_tensor_of_data(
            data,
            size.as_ptr(),
            size.len(),
            elt_size_in_bytes,
            kind.c_int(),
        ));
        Ok(Tensor { c_tensor })
    }

    /// Creates a tensor from data that is assumed to be initialized.
    /// Resize operations are not allowed on this tensor without copying the data first.
    ///
    /// # Safety
    ///
    /// `data` must point to initialized memory that is valid for the layout
    /// described by `size`, `strides` and `kind` on `device`; invalid memory is
    /// undefined behavior, not a panic. NOTE(review): the buffer is presumably
    /// not copied, in which case it must also outlive the returned tensor —
    /// confirm against `at_tensor_of_blob`.
    pub unsafe fn f_of_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Result<Tensor, TchError> {
        let data = data as *const c_void;
        #[allow(unused_unsafe)]
        let c_tensor = unsafe_torch_err!(at_tensor_of_blob(
            data,
            size.as_ptr(),
            size.len(),
            strides.as_ptr(),
            strides.len(),
            kind.c_int(),
            device.c_int()
        ));
        Ok(Tensor { c_tensor })
    }

    /// Creates a tensor from data that is assumed to be initialized.
    /// Resize operations are not allowed on this tensor without copying the data first.
    /// Panics if the C++ api returns an exception.
    ///
    /// # Safety
    ///
    /// Same contract as `f_of_blob`: `data` must point to initialized memory valid
    /// for the given `size`, `strides`, `kind` and `device`.
    pub unsafe fn of_blob(
        data: *const u8,
        size: &[i64],
        strides: &[i64],
        kind: Kind,
        device: Device,
    ) -> Tensor {
        Self::f_of_blob(data, size, strides, kind, device).unwrap()
    }

    /// Converts some byte data to a tensor with some specified kind and shape.
    /// Panics on error (see `f_of_data_size`).
    pub fn of_data_size(data: &[u8], size: &[i64], kind: Kind) -> Tensor {
        Self::f_of_data_size(data, size, kind).unwrap()
    }

    /// Returns a new tensor that share storage with the input tensor.
    pub fn shallow_clone(&self) -> Tensor {
        let c_tensor = unsafe_torch!(at_shallow_clone(self.c_tensor));
        Tensor { c_tensor }
    }

    /// Gets the sub-tensor at the given index.
    ///
    /// Returns a `TchError::Shape` if `index` does not fit in a C `int`.
    pub fn f_get(&self, index: i64) -> Result<Tensor, TchError> {
        // `at_get` takes a C int; reject indexes that an `as` cast would silently wrap.
        if !(i64::from(c_int::MIN)..=i64::from(c_int::MAX)).contains(&index) {
            return Err(TchError::Shape(format!(
                "index {} does not fit in a C int",
                index
            )));
        }
        let c_tensor = unsafe_torch_err!(at_get(self.c_tensor, index as c_int));
        Ok(Tensor { c_tensor })
    }

    /// Gets the sub-tensor at the given index. Panics on error.
    pub fn get(&self, index: i64) -> Tensor {
        self.f_get(index).unwrap()
    }

    /// Copies values from the argument tensor to the input tensor.
    pub fn f_copy_(&mut self, src: &Tensor) -> Result<(), TchError> {
        unsafe_torch_err!(at_copy_(self.c_tensor, src.c_tensor));
        Ok(())
    }

    /// Copies values from the argument tensor to the input tensor. Panics on error.
    pub fn copy_(&mut self, src: &Tensor) {
        self.f_copy_(src).unwrap()
    }

    /// Loads a tensor from a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load<T: AsRef<Path>>(path: T) -> Result<Tensor, TchError> {
        let path = path_to_cstring(path)?;
        let c_tensor = unsafe_torch_err!(at_load(path.as_ptr()));
        Ok(Tensor { c_tensor })
    }

    /// Saves a tensor to a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn save<T: AsRef<Path>>(&self, path: T) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        unsafe_torch_err!(at_save(self.c_tensor, path.as_ptr()));
        Ok(())
    }

    /// Saves some named tensors to a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    /// `.` in names is stored as `|`; `load_multi` converts it back.
    pub fn save_multi<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
        named_tensors: &[(S, T)],
        path: P,
    ) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        let c_tensors = named_tensors
            .iter()
            .map(|nt| nt.1.as_ref().c_tensor)
            .collect::<Vec<_>>();
        let names = named_tensors
            .iter()
            .map(|nt| nt.0.as_ref().replace('.', "|").into_bytes())
            .map(std::ffi::CString::new)
            .collect::<Result<Vec<_>, _>>()?;
        // `names` owns the C strings and outlives the FFI call below.
        let name_ptrs = names.iter().map(|n| n.as_ptr()).collect::<Vec<_>>();
        unsafe_torch_err!(at_save_multi(
            c_tensors.as_ptr(),
            name_ptrs.as_ptr(),
            names.len() as i32,
            path.as_ptr(),
        ));
        Ok(())
    }

    /// Loads some named tensors from a file.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load_multi<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        // `add_callback` pushes each loaded (name, tensor) pair into `v`.
        unsafe_torch_err!(at_load_callback(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback
        ));
        Ok(v)
    }

    /// Loads some named tensors from a file to a given device.
    ///
    /// The file format is the same as the one used by the PyTorch C++ API.
    pub fn load_multi_with_device<T: AsRef<Path>>(
        path: T,
        device: Device,
    ) -> Result<Vec<(String, Tensor)>, TchError> {
        let path = path_to_cstring(path)?;
        let mut v: Vec<(String, Tensor)> = vec![];
        // `add_callback` pushes each loaded (name, tensor) pair into `v`.
        unsafe_torch_err!(at_load_callback_with_device(
            path.as_ptr(),
            &mut v as *mut _ as *mut c_void,
            add_callback,
            device.c_int(),
        ));
        Ok(v)
    }

    /// Returns a string representation for the tensor.
    ///
    /// The representation will contain all the tensor elements hence may be huge for
    /// large tensors. `lw` is passed to the C++ printer as a C `int`.
    pub fn to_string(&self, lw: i64) -> Result<String, TchError> {
        let s = unsafe_torch_err!(ptr_to_string(torch_sys::at_to_string(
            self.c_tensor,
            lw as c_int
        )));
        match s {
            None => Err(TchError::Kind("nullptr representation".to_string())),
            Some(s) => Ok(s),
        }
    }
}

impl Default for Tensor {
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Tensor {
    /// Hands the wrapped C tensor handle back to `at_free`.
    fn drop(&mut self) {
        let handle = self.c_tensor;
        unsafe_torch!(at_free(handle));
    }
}

/// Clears the autocast cache through `at_autocast_clear_cache`.
fn autocast_clear_cache() {
    unsafe_torch!(at_autocast_clear_cache());
}

/// Decrements the autocast nesting counter and returns the value reported by
/// libtorch (`autocast` treats 0 as leaving the outermost region).
fn autocast_decrement_nesting() -> isize {
    let depth = unsafe_torch!(at_autocast_decrement_nesting());
    depth as isize
}

/// Increments the autocast nesting counter and returns the value reported by libtorch.
fn autocast_increment_nesting() -> isize {
    let depth = unsafe_torch!(at_autocast_increment_nesting());
    depth as isize
}

/// Reports whether autocast is currently enabled (non-zero C result).
fn autocast_is_enabled() -> bool {
    let flag = unsafe_torch!(at_autocast_is_enabled());
    flag != 0
}

/// Enables or disables autocast; the C status code is mapped to a `bool`.
fn autocast_set_enabled(b: bool) -> bool {
    let status = unsafe_torch!(at_autocast_set_enabled(if b { 1 } else { 0 }));
    status != 0
}

/// Runs a closure in mixed precision.
///
/// When CUDA is not available `f` is simply called. Otherwise autocast is set to
/// `enabled` for the duration of `f`, and the nesting depth and previous enabled
/// flag are restored afterwards — including when `f` panics, since the restore
/// happens in a guard's `Drop`.
pub fn autocast<T, F>(enabled: bool, f: F) -> T
where
    F: FnOnce() -> T,
{
    // Check whether we are using CUDA; without it, run the closure unchanged.
    if !Cuda::is_available() {
        return f();
    }

    /// Undoes the autocast state changes made below when dropped.
    struct RestoreAutocast {
        prev: bool,
    }

    impl Drop for RestoreAutocast {
        fn drop(&mut self) {
            // Clear the cache only when leaving the outermost autocast region.
            if autocast_decrement_nesting() == 0 {
                autocast_clear_cache();
            }
            autocast_set_enabled(self.prev);
        }
    }

    let prev = autocast_is_enabled();
    autocast_set_enabled(enabled);
    autocast_increment_nesting();
    // Named binding keeps the guard alive until after `f` has returned or unwound.
    let _restore = RestoreAutocast { prev };

    f()
}

/// Sets the gradient-tracking mode through `at_grad_set_enabled`.
///
/// Callers in this module treat the returned flag as the mode that was active
/// before the call and use it to restore that mode later.
fn grad_set_enabled(b: bool) -> bool {
    let status = unsafe_torch!(at_grad_set_enabled(if b { 1 } else { 0 }));
    status != 0
}

/// Runs a closure without keeping track of gradients.
pub fn no_grad<T, F>(f: F) -> T
where
    F: FnOnce() -> T,
{
    let prev = grad_set_enabled(false);
    let result = f();
    let _false = grad_set_enabled(prev);
    result
}

/// Runs a closure explicitly keeping track of gradients, this could be
/// run within a no_grad closure for example.
///
/// The previous gradient-tracking mode is restored when `f` returns and also if
/// `f` panics, because the restore is performed by a guard's `Drop`.
pub fn with_grad<T, F>(f: F) -> T
where
    F: FnOnce() -> T,
{
    /// Restores the captured grad mode when dropped, including during unwinding.
    struct RestoreGradMode(bool);

    impl Drop for RestoreGradMode {
        fn drop(&mut self) {
            let _ = grad_set_enabled(self.0);
        }
    }

    // Named binding keeps the guard alive until after `f` has returned or unwound.
    let _restore = RestoreGradMode(grad_set_enabled(true));
    f()
}

/// A RAII guard that prevents gradient tracking until it is dropped.
///
/// On drop, the gradient-tracking mode captured when the guard was created is
/// restored. Bind the guard to a named variable (e.g. `let _guard = ...`):
/// `let _ = ...` drops it, and re-enables tracking, immediately.
#[must_use = "gradient tracking is restored as soon as the guard is dropped"]
pub struct NoGradGuard {
    // Mode returned by `grad_set_enabled(false)` at creation, restored on drop.
    enabled: bool,
}

/// Disables gradient tracking, this will be enabled back when the
/// returned value gets deallocated.
///
/// Keep the returned guard in a named binding for as long as tracking should
/// stay disabled; discarding it restores the previous mode immediately.
#[must_use = "dropping the guard immediately restores gradient tracking"]
pub fn no_grad_guard() -> NoGradGuard {
    NoGradGuard {
        enabled: grad_set_enabled(false),
    }
}

impl std::convert::AsRef<Tensor> for Tensor {
    /// Returns `self`, so APIs generic over `T: AsRef<Tensor>` (such as
    /// `Tensor::save_multi`) accept tensors directly.
    fn as_ref(&self) -> &Self {
        self
    }
}

impl Drop for NoGradGuard {
    /// Puts back the gradient-tracking mode captured when the guard was built.
    fn drop(&mut self) {
        let previous = self.enabled;
        let _ = grad_set_enabled(previous);
    }
}

/// How a loss function combines its per-element losses into its output.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Reduction {
    /// Do not reduce.
    None,
    /// Mean of losses.
    Mean,
    /// Sum of losses.
    Sum,
    /// Escape hatch in case new options become available.
    Other(i64),
}

impl Reduction {
    /// Returns the integer code passed to libtorch for this reduction.
    ///
    /// The fixed codes must stay in sync with
    /// pytorch/aten/src/ATen/core/Reduction.h; `Other` forwards its raw code as-is.
    pub fn to_int(self) -> i64 {
        match self {
            Self::Other(raw) => raw,
            Self::None => 0,
            Self::Mean => 1,
            Self::Sum => 2,
        }
    }
}