// tch_plus/wrappers/scalar.rs
use crate::TchError;

5pub struct Scalar {
7 pub(super) c_scalar: *mut torch_sys_plus::C_scalar,
8}
9
10impl Scalar {
11 pub fn int(v: i64) -> Scalar {
13 let c_scalar = unsafe_torch!(torch_sys_plus::ats_int(v));
14 Scalar { c_scalar }
15 }
16
17 pub fn float(v: f64) -> Scalar {
19 let c_scalar = unsafe_torch!(torch_sys_plus::ats_float(v));
20 Scalar { c_scalar }
21 }
22
23 pub fn to_int(&self) -> Result<i64, TchError> {
25 let i = unsafe_torch_err!(torch_sys_plus::ats_to_int(self.c_scalar));
26 Ok(i)
27 }
28
29 pub fn to_float(&self) -> Result<f64, TchError> {
31 let f = unsafe_torch_err!(torch_sys_plus::ats_to_float(self.c_scalar));
32 Ok(f)
33 }
34
35 pub fn to_string(&self) -> Result<String, TchError> {
37 let s = unsafe_torch_err!({
38 super::utils::ptr_to_string(torch_sys_plus::ats_to_string(self.c_scalar))
39 });
40 match s {
41 None => Err(TchError::Kind("nullptr representation".to_string())),
42 Some(s) => Ok(s),
43 }
44 }
45}
46
47impl std::fmt::Debug for Scalar {
48 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
49 match self.to_string() {
50 Err(_) => write!(f, "err"),
51 Ok(s) => write!(f, "scalar<{s}>"),
52 }
53 }
54}
55
56impl Drop for Scalar {
57 fn drop(&mut self) {
58 unsafe_torch!(torch_sys_plus::ats_free(self.c_scalar))
59 }
60}
61
62impl From<i64> for Scalar {
63 fn from(v: i64) -> Scalar {
64 Scalar::int(v)
65 }
66}
67
68impl From<f64> for Scalar {
69 fn from(v: f64) -> Scalar {
70 Scalar::float(v)
71 }
72}
73
74impl From<Scalar> for i64 {
75 fn from(s: Scalar) -> i64 {
76 Self::from(&s)
77 }
78}
79
80impl From<Scalar> for f64 {
81 fn from(s: Scalar) -> f64 {
82 Self::from(&s)
83 }
84}
85
86impl From<&Scalar> for i64 {
87 fn from(s: &Scalar) -> i64 {
88 s.to_int().unwrap()
89 }
90}
91
92impl From<&Scalar> for f64 {
93 fn from(s: &Scalar) -> f64 {
94 s.to_float().unwrap()
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::Scalar;
    use std::f64::consts::PI;

    /// Round-trips a float and an int through `Scalar`, checks both numeric
    /// read-backs, and checks the `Debug` rendering of the float.
    #[test]
    fn scalar() {
        let pi = Scalar::float(PI);
        assert_eq!(i64::from(&pi), 3, "float scalar read back as i64");
        assert_eq!(f64::from(&pi), PI, "float scalar read back as f64");

        let leet = Scalar::int(1337);
        assert_eq!(i64::from(&leet), 1337, "int scalar read back as i64");
        assert_eq!(f64::from(&leet), 1337.0, "int scalar read back as f64");

        let rendered = format!("{pi:?}");
        assert_eq!(rendered, "scalar<3.14159>");
    }
}