tch_plus/wrappers/
scalar.rs

1//! Scalar elements.
2
3use crate::TchError;
4
5/// A single scalar value.
6pub struct Scalar {
7    pub(super) c_scalar: *mut torch_sys_plus::C_scalar,
8}
9
10impl Scalar {
11    /// Creates an integer scalar.
12    pub fn int(v: i64) -> Scalar {
13        let c_scalar = unsafe_torch!(torch_sys_plus::ats_int(v));
14        Scalar { c_scalar }
15    }
16
17    /// Creates a float scalar scalar.
18    pub fn float(v: f64) -> Scalar {
19        let c_scalar = unsafe_torch!(torch_sys_plus::ats_float(v));
20        Scalar { c_scalar }
21    }
22
23    /// Returns an integer value.
24    pub fn to_int(&self) -> Result<i64, TchError> {
25        let i = unsafe_torch_err!(torch_sys_plus::ats_to_int(self.c_scalar));
26        Ok(i)
27    }
28
29    /// Returns a float value.
30    pub fn to_float(&self) -> Result<f64, TchError> {
31        let f = unsafe_torch_err!(torch_sys_plus::ats_to_float(self.c_scalar));
32        Ok(f)
33    }
34
35    /// Returns a string representation of the scalar.
36    pub fn to_string(&self) -> Result<String, TchError> {
37        let s = unsafe_torch_err!({
38            super::utils::ptr_to_string(torch_sys_plus::ats_to_string(self.c_scalar))
39        });
40        match s {
41            None => Err(TchError::Kind("nullptr representation".to_string())),
42            Some(s) => Ok(s),
43        }
44    }
45}
46
47impl std::fmt::Debug for Scalar {
48    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49        match self.to_string() {
50            Err(_) => write!(f, "err"),
51            Ok(s) => write!(f, "scalar<{s}>"),
52        }
53    }
54}
55
56impl Drop for Scalar {
57    fn drop(&mut self) {
58        unsafe_torch!(torch_sys_plus::ats_free(self.c_scalar))
59    }
60}
61
62impl From<i64> for Scalar {
63    fn from(v: i64) -> Scalar {
64        Scalar::int(v)
65    }
66}
67
68impl From<f64> for Scalar {
69    fn from(v: f64) -> Scalar {
70        Scalar::float(v)
71    }
72}
73
74impl From<Scalar> for i64 {
75    fn from(s: Scalar) -> i64 {
76        Self::from(&s)
77    }
78}
79
80impl From<Scalar> for f64 {
81    fn from(s: Scalar) -> f64 {
82        Self::from(&s)
83    }
84}
85
86impl From<&Scalar> for i64 {
87    fn from(s: &Scalar) -> i64 {
88        s.to_int().unwrap()
89    }
90}
91
92impl From<&Scalar> for f64 {
93    fn from(s: &Scalar) -> f64 {
94        s.to_float().unwrap()
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::Scalar;

    /// Round-trips one float and one integer scalar through both numeric
    /// conversions, then checks the debug rendering of the float.
    #[test]
    fn scalar() {
        use std::f64::consts::PI;

        let leet = Scalar::int(1337);
        let pi = Scalar::float(PI);

        // Float-backed scalar: the integer read gives 3, the float read is exact.
        assert_eq!(i64::from(&pi), 3);
        assert_eq!(f64::from(&pi), PI);

        // Integer-backed scalar reads back unchanged in both representations.
        assert_eq!(i64::from(&leet), 1337);
        assert_eq!(f64::from(&leet), 1337.);

        assert_eq!(format!("{pi:?}"), "scalar<3.14159>");
    }
}