// tch_plus/wrappers/jit.rs

1//! JIT interface to run model trained/saved using PyTorch Python API.
2use super::utils::{path_to_cstring, ptr_to_string};
3use super::{device::Device, kind::Kind};
4use crate::{nn::Path, TchError, Tensor};
5use libc::{c_int, c_void};
6use std::borrow::Borrow;
7use std::convert::TryFrom;
8use torch_sys_plus::*;
9
/// Argument and output values for JIT models. These represent arbitrary values,
/// e.g. tensors, atomic values, pairs of values, etc.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum IValue {
    /// The absent value (Python `None`).
    None,
    /// A single tensor.
    Tensor(crate::Tensor),
    /// A 64-bit floating point value.
    Double(f64),
    /// A 64-bit signed integer.
    Int(i64),
    /// A boolean value.
    Bool(bool),
    /// A fixed-arity tuple of arbitrary values.
    Tuple(Vec<IValue>),
    /// A homogeneous list of integers.
    IntList(Vec<i64>),
    /// A homogeneous list of floating point values.
    DoubleList(Vec<f64>),
    /// A homogeneous list of booleans.
    BoolList(Vec<bool>),
    /// A string value.
    String(String),
    /// A homogeneous list of strings.
    StringList(Vec<String>),
    /// A homogeneous list of tensors.
    TensorList(Vec<crate::Tensor>),
    /// A heterogeneous list of arbitrary values.
    GenericList(Vec<IValue>),
    // We use a vec to represent dictionaries as f64 does not implement
    // Eq or Hash out of the box in rust. TODO: improve this?
    /// A dictionary of arbitrary key/value pairs, stored as a list of pairs.
    GenericDict(Vec<(IValue, IValue)>),
    /// An opaque TorchScript object, e.g. an instance of a custom JIT class.
    Object(Object),
}
33
34impl IValue {
35    fn type_str(self) -> &'static str {
36        match self {
37            IValue::None => "None",
38            IValue::Tensor(_) => "Tensor",
39            IValue::Double(_) => "Double",
40            IValue::Int(_) => "Int",
41            IValue::Bool(_) => "Bool",
42            IValue::Tuple(_) => "Tuple",
43            IValue::IntList(_) => "IntList",
44            IValue::DoubleList(_) => "DoubleList",
45            IValue::BoolList(_) => "BoolList",
46            IValue::String(_) => "String",
47            IValue::StringList(_) => "StringList",
48            IValue::TensorList(_) => "TensorList",
49            IValue::GenericList(_) => "GenericList",
50            IValue::GenericDict(_) => "GenericDict",
51            IValue::Object(_) => "Object",
52        }
53    }
54}
55
56impl From<()> for IValue {
57    fn from((): ()) -> Self {
58        IValue::None
59    }
60}
61
62impl<T1: Into<IValue>, T2: Into<IValue>> From<(T1, T2)> for IValue {
63    fn from((p1, p2): (T1, T2)) -> Self {
64        IValue::Tuple(vec![p1.into(), p2.into()])
65    }
66}
67
68impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>> From<(T1, T2, T3)> for IValue {
69    fn from((p1, p2, p3): (T1, T2, T3)) -> Self {
70        IValue::Tuple(vec![p1.into(), p2.into(), p3.into()])
71    }
72}
73
74impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>, T4: Into<IValue>> From<(T1, T2, T3, T4)>
75    for IValue
76{
77    fn from((p1, p2, p3, p4): (T1, T2, T3, T4)) -> Self {
78        IValue::Tuple(vec![p1.into(), p2.into(), p3.into(), p4.into()])
79    }
80}
81
82impl<T1, T2, T1E, T2E> TryFrom<IValue> for (T1, T2)
83where
84    T1: TryFrom<IValue, Error = T1E>,
85    TchError: From<T1E>,
86    T2: TryFrom<IValue, Error = T2E>,
87    TchError: From<T2E>,
88{
89    type Error = TchError;
90    fn try_from(value: IValue) -> Result<Self, TchError> {
91        match value {
92            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
93                if vec.len() == 2 {
94                    let t2 = T2::try_from(vec.pop().unwrap())?;
95                    let t1 = T1::try_from(vec.pop().unwrap())?;
96                    Ok((t1, t2))
97                } else {
98                    Err(TchError::Kind(format!(
99                        "unable to unpack ivalue, expected a tuple of len 2 got {}",
100                        vec.len()
101                    )))
102                }
103            }
104            _ => Err(TchError::Kind(format!(
105                "unable to unpack ivalue, expected a tuple got {}",
106                value.type_str()
107            ))),
108        }
109    }
110}
111
112impl<T1, T2, T3, T1E, T2E, T3E> TryFrom<IValue> for (T1, T2, T3)
113where
114    T1: TryFrom<IValue, Error = T1E>,
115    TchError: From<T1E>,
116    T2: TryFrom<IValue, Error = T2E>,
117    TchError: From<T2E>,
118    T3: TryFrom<IValue, Error = T3E>,
119    TchError: From<T3E>,
120{
121    type Error = TchError;
122    fn try_from(value: IValue) -> Result<Self, TchError> {
123        match value {
124            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
125                if vec.len() == 3 {
126                    let t3 = T3::try_from(vec.pop().unwrap())?;
127                    let t2 = T2::try_from(vec.pop().unwrap())?;
128                    let t1 = T1::try_from(vec.pop().unwrap())?;
129                    Ok((t1, t2, t3))
130                } else {
131                    Err(TchError::Kind(format!(
132                        "unable to unpack ivalue, expected a tuple of len 3 got {}",
133                        vec.len()
134                    )))
135                }
136            }
137            _ => Err(TchError::Kind(format!(
138                "unable to unpack ivalue, expected a tuple got {}",
139                value.type_str()
140            ))),
141        }
142    }
143}
144
145impl<T1, T2, T3, T4, T1E, T2E, T3E, T4E> TryFrom<IValue> for (T1, T2, T3, T4)
146where
147    T1: TryFrom<IValue, Error = T1E>,
148    TchError: From<T1E>,
149    T2: TryFrom<IValue, Error = T2E>,
150    TchError: From<T2E>,
151    T3: TryFrom<IValue, Error = T3E>,
152    TchError: From<T3E>,
153    T4: TryFrom<IValue, Error = T4E>,
154    TchError: From<T4E>,
155{
156    type Error = TchError;
157    fn try_from(value: IValue) -> Result<Self, TchError> {
158        match value {
159            IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
160                if vec.len() == 4 {
161                    let t4 = T4::try_from(vec.pop().unwrap())?;
162                    let t3 = T3::try_from(vec.pop().unwrap())?;
163                    let t2 = T2::try_from(vec.pop().unwrap())?;
164                    let t1 = T1::try_from(vec.pop().unwrap())?;
165                    Ok((t1, t2, t3, t4))
166                } else {
167                    Err(TchError::Kind(format!(
168                        "unable to unpack ivalue, expected a tuple of len 4 got {}",
169                        vec.len()
170                    )))
171                }
172            }
173            _ => Err(TchError::Kind(format!(
174                "unable to unpack ivalue, expected a tuple got {}",
175                value.type_str()
176            ))),
177        }
178    }
179}
180
// Generates the conversions between a concrete Rust type and the matching
// `IValue` variant:
// - `From<$type_> for IValue` wraps the value in `IValue::$cons`.
// - `TryFrom<IValue> for $type_` unwraps it, erroring on any other variant.
// - `TryFrom<IValue> for Option<$type_>` additionally maps `IValue::None`
//   to `Ok(None)`.
macro_rules! impl_from {
    ($type_:ty, $cons:ident) => {
        impl From<$type_> for IValue {
            fn from(v: $type_) -> Self {
                IValue::$cons(v)
            }
        }

        impl TryFrom<IValue> for $type_ {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<$type_, TchError> {
                match value {
                    IValue::$cons(t) => Ok(t),
                    _ => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} got {}",
                        std::stringify!($cons),
                        value.type_str()
                    ))),
                }
            }
        }

        // A generic trait for Option<T> would seem nicer but because
        // of E0119, this is currently hard to do.
        // See https://github.com/rust-lang/rust/issues/50133
        impl TryFrom<IValue> for Option<$type_> {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<Self, TchError> {
                match value {
                    IValue::None => Ok(None),
                    IValue::$cons(t) => Ok(Some(t)),
                    _ => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} or None got {}",
                        std::stringify!($cons),
                        value.type_str()
                    ))),
                }
            }
        }
    };
}
222
// Round-trip conversions for every payload type with a dedicated variant.
impl_from!(i64, Int);
impl_from!(f64, Double);
impl_from!(bool, Bool);
impl_from!(String, String);
impl_from!(Tensor, Tensor);
impl_from!(Vec<i64>, IntList);
impl_from!(Vec<f64>, DoubleList);
impl_from!(Vec<bool>, BoolList);
impl_from!(Vec<String>, StringList);
impl_from!(Vec<crate::Tensor>, TensorList);
impl_from!(Vec<IValue>, GenericList);
impl_from!(Vec<(IValue, IValue)>, GenericDict);
impl_from!(Object, Object);
236
237impl From<&str> for IValue {
238    fn from(s: &str) -> Self {
239        IValue::String(s.to_string())
240    }
241}
242
impl IValue {
    #![allow(unused_unsafe)]
    /// Converts this value into a freshly allocated C ivalue.
    ///
    /// The caller takes ownership of the returned pointer and is responsible
    /// for releasing it with `ati_free` (or handing it to a C function that
    /// consumes it).
    pub(super) fn to_c(&self) -> Result<*mut CIValue, TchError> {
        let c = unsafe_torch_err!(match self {
            IValue::Tensor(tensor) => ati_tensor(tensor.c_tensor),
            IValue::Int(i) => ati_int(*i),
            IValue::None => ati_none(),
            IValue::Double(f) => ati_double(*f),
            IValue::Bool(b) => ati_bool(i32::from(*b)),
            IValue::Tuple(v) => {
                // Convert each element, then free the temporaries after the
                // tuple has been built — presumably the C side copies the
                // elements; TODO confirm against the shim implementation.
                let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
                let tuple = ati_tuple(v.as_ptr(), v.len() as c_int);
                for x in v {
                    ati_free(x);
                }

                tuple
            }
            IValue::GenericList(v) => {
                // Same pattern as Tuple: build, then release the temporaries.
                let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
                let list = ati_generic_list(v.as_ptr(), v.len() as c_int);
                for x in v {
                    ati_free(x);
                }
                list
            }
            IValue::IntList(v) => ati_int_list(v.as_ptr(), v.len() as c_int),
            IValue::DoubleList(v) => ati_double_list(v.as_ptr(), v.len() as c_int),
            IValue::BoolList(v) => {
                // Booleans are passed as c_char (0/1) across the FFI boundary.
                let v: Vec<libc::c_char> = v.iter().map(|&b| libc::c_char::from(b)).collect();
                ati_bool_list(v.as_ptr(), v.len() as c_int)
            }
            IValue::TensorList(v) => {
                let v = v.iter().map(|t| t.c_tensor).collect::<Vec<_>>();
                ati_tensor_list(v.as_ptr(), v.len() as c_int)
            }
            IValue::String(string) => {
                // Fails if the string contains an interior NUL byte.
                let c_str = std::ffi::CString::new(string.as_str())?;
                ati_string(c_str.as_ptr())
            }
            IValue::StringList(strings) => {
                // Keep the CStrings alive in `v` while their raw pointers in
                // `v_ptr` are handed to the C side.
                let mut v = vec![];
                for s in strings {
                    v.push(std::ffi::CString::new(s.as_str())?);
                }
                let v_ptr: Vec<_> = v.iter().map(|s| s.as_ptr()).collect();
                ati_string_list(v_ptr.as_ptr(), v.len() as c_int)
            }
            IValue::GenericDict(dict) => {
                // Flatten the pairs into [k0, v0, k1, v1, ...]; the length
                // passed to the C side is the number of pairs, not pointers.
                let v = dict
                    .iter()
                    .flat_map(|(k, v)| vec![Self::to_c(k), Self::to_c(v)])
                    .collect::<Result<Vec<_>, TchError>>()?;
                let dict = ati_generic_dict(v.as_ptr(), dict.len() as c_int);
                for x in v {
                    ati_free(x);
                }
                dict
            }
            IValue::Object(Object { c_ivalue }) => {
                // Clone the object if necessary before passing the pointer to the C++ side.
                unsafe_torch_err!(ati_clone(*c_ivalue))
            }
        });
        Ok(c)
    }

    // This consumes the pointer and frees the associated memory (unless it is an Object).
    /// Converts a C ivalue back into a Rust `IValue`.
    ///
    /// The tag values matched below must agree with those returned by the C
    /// shim's `ati_tag`; note that tag 11 is not handled here — presumably
    /// unused or unsupported, TODO confirm against the shim.
    pub(super) fn from_c(c_ivalue: *mut CIValue) -> Result<Self, TchError> {
        let mut free = true;
        let tag = unsafe_torch_err!(ati_tag(c_ivalue));
        let v = match tag {
            // None
            0 => IValue::None,
            // Tensor
            1 => {
                let c_tensor = unsafe_torch_err!(ati_to_tensor(c_ivalue));
                IValue::Tensor(crate::Tensor { c_tensor })
            }
            // Double
            2 => IValue::Double(unsafe_torch_err!(ati_to_double(c_ivalue))),
            // Int
            3 => IValue::Int(unsafe_torch_err!(ati_to_int(c_ivalue))),
            // Bool, encoded as an int; negative values signal an error.
            4 => {
                let b = unsafe_torch_err!(ati_to_bool(c_ivalue));
                if b < 0 {
                    return Err(TchError::Kind(format!("unexpected bool value {b}")));
                }
                IValue::Bool(b != 0)
            }
            // Tuple: the C side fills a pre-allocated array of element pointers.
            5 => {
                let len = unsafe_torch_err!(ati_tuple_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_tuple(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let vec: Result<Vec<_>, _> =
                    c_ivalues.iter().map(|&c_ivalue| (Self::from_c(c_ivalue))).collect();
                IValue::Tuple(vec?)
            }
            // IntList
            6 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0i64; len as usize];
                unsafe_torch_err!(ati_to_int_list(c_ivalue, c_array.as_mut_ptr(), len));
                IValue::IntList(c_array)
            }
            // DoubleList
            7 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0f64; len as usize];
                unsafe_torch_err!(ati_to_double_list(c_ivalue, c_array.as_mut_ptr(), len));
                IValue::DoubleList(c_array)
            }
            // BoolList, transferred as c_char values (0 = false).
            8 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_array = vec![0_i8; len as usize];
                let c_array_ptr = c_array.as_mut_ptr() as *mut libc::c_char;
                unsafe_torch_err!(ati_to_bool_list(c_ivalue, c_array_ptr, len));
                IValue::BoolList(c_array.iter().map(|&x| x != 0).collect())
            }
            // String
            9 => {
                let ptr = unsafe_torch_err!(ati_to_string(c_ivalue));
                let string = match unsafe { ptr_to_string(ptr) } {
                    None => return Err(TchError::Kind("nullptr representation".to_string())),
                    Some(s) => s,
                };
                IValue::String(string)
            }
            // TensorList
            10 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_tensors: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<C_tensor>()).collect();
                unsafe_torch_err!(ati_to_tensor_list(c_ivalue, c_tensors.as_mut_ptr(), len));
                let vec: Vec<_> = c_tensors.iter().map(|&c_tensor| (Tensor { c_tensor })).collect();
                IValue::TensorList(vec)
            }
            // GenericList
            12 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_generic_list(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let vec: Result<Vec<_>, _> =
                    c_ivalues.iter().map(|&c_ivalue| (Self::from_c(c_ivalue))).collect();
                IValue::GenericList(vec?)
            }
            // GenericDict: keys and values interleaved as [k0, v0, k1, v1, ...].
            13 => {
                let len = unsafe_torch_err!(ati_length(c_ivalue));
                let mut c_ivalues: Vec<_> =
                    (0..2 * len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
                unsafe_torch_err!(ati_to_generic_dict(c_ivalue, c_ivalues.as_mut_ptr(), len));
                let mut res: Vec<(IValue, IValue)> = vec![];
                for i in 0..(len as usize) {
                    let key = Self::from_c(c_ivalues[2 * i])?;
                    let value = Self::from_c(c_ivalues[2 * i + 1])?;
                    res.push((key, value))
                }
                IValue::GenericDict(res)
            }
            // Object: ownership of the pointer transfers to the `Object`,
            // which frees it in its own Drop — so skip the free below.
            14 => {
                free = false;
                IValue::Object(Object { c_ivalue })
            }
            _ => return Err(TchError::Kind(format!("unhandled tag {tag}"))),
        };
        if free {
            unsafe_torch_err!(ati_free(c_ivalue));
        }
        Ok(v)
    }
}
407
/// A jit PyTorch module.
///
/// These modules can be created via the
/// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
#[derive(Debug)]
pub struct CModule {
    /// Raw pointer to the underlying C++ module, owned by this struct and
    /// released via `atm_free` in the `Drop` implementation.
    pub(super) c_module: *mut CModule_,
}
416
// SAFETY(review): `CModule` only holds a raw pointer to a C++ module and all
// accesses go through the libtorch C API; presumably the module can be moved
// across threads — TODO confirm against the libtorch threading guarantees.
unsafe impl Send for CModule {}

// SAFETY(review): shared `&CModule` access is presumed safe because the
// underlying libtorch calls are internally synchronized — TODO confirm.
unsafe impl Sync for CModule {}
420
impl Drop for CModule {
    fn drop(&mut self) {
        // Release the underlying C++ module.
        unsafe_torch!(atm_free(self.c_module))
    }
}
426
impl CModule {
    /// Loads a PyTorch saved JIT model from a file.
    pub fn load<T: AsRef<std::path::Path>>(path: T) -> Result<CModule, TchError> {
        let path = path_to_cstring(path)?;
        let c_module = unsafe_torch_err!(atm_load(path.as_ptr()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a file onto the given device.
    ///
    /// This function loads the model directly on the specified device,
    /// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
    pub fn load_on_device<T: AsRef<std::path::Path>>(
        path: T,
        device: Device,
    ) -> Result<CModule, TchError> {
        let path = path_to_cstring(path)?;
        let c_module = unsafe_torch_err!(atm_load_on_device(path.as_ptr(), device.c_int()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a read instance.
    pub fn load_data<T: std::io::Read>(f: &mut T) -> Result<CModule, TchError> {
        // Read the whole stream into memory: the C entry point takes a
        // contiguous (pointer, length) buffer.
        let mut buffer = Vec::new();
        f.read_to_end(&mut buffer)?;
        let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
        let c_module = unsafe_torch_err!(atm_load_str(buffer_ptr, buffer.len()));
        Ok(CModule { c_module })
    }

    /// Loads a PyTorch saved JIT model from a read instance.
    ///
    /// This function loads the model directly on the specified device,
    /// which means it also allows loading a GPU model on the CPU without having a CUDA enabled GPU.
    pub fn load_data_on_device<T: std::io::Read>(
        f: &mut T,
        device: Device,
    ) -> Result<CModule, TchError> {
        let mut buffer = Vec::new();
        f.read_to_end(&mut buffer)?;
        let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
        let c_module =
            unsafe_torch_err!(atm_load_str_on_device(buffer_ptr, buffer.len(), device.c_int()));
        Ok(CModule { c_module })
    }

    /// Performs the forward pass for a model on some specified tensor inputs. This is equivalent
    /// to calling method_ts with the 'forward' method name, and returns a single tensor.
    pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
        let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
        let c_tensor =
            unsafe_torch_err!(atm_forward(self.c_module, ts.as_ptr(), ts.len() as c_int));
        Ok(Tensor { c_tensor })
    }

    /// Performs the forward pass for a model on some specified ivalue inputs. This is equivalent
    /// to calling method_is with the 'forward' method name, and returns an arbitrary ivalue.
    pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
        // Convert the inputs to temporary C ivalues, freed after the call.
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let c_ivalue =
            unsafe_torch_err!(atm_forward_(self.c_module, ts.as_ptr(), ts.len() as c_int));
        // NOTE(review): if the call above returns early with an error, the
        // temporaries in `ts` are never freed — confirm `unsafe_torch_err!`
        // semantics and whether this leak matters.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Runs a specified entry point for a model on some given tensor inputs.
    pub fn method_ts<T: Borrow<Tensor>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<Tensor, TchError> {
        let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
        // Fails if the method name contains an interior NUL byte.
        let method_name = std::ffi::CString::new(method_name)?;
        let c_tensor = unsafe_torch_err!(atm_method(
            self.c_module,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        Ok(Tensor { c_tensor })
    }

    /// Runs a specified entry point for a model on some given ivalue inputs.
    pub fn method_is<T: Borrow<IValue>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let method_name = std::ffi::CString::new(method_name)?;
        let c_ivalue = unsafe_torch_err!(atm_method_(
            self.c_module,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // NOTE(review): same potential leak-on-error as in `forward_is`.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Create a specified custom JIT class object with the given class name, eg: `__torch__.foo.Bar`
    pub fn create_class_is<T: Borrow<IValue>>(
        &self,
        clz_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        let clz_name = std::ffi::CString::new(clz_name)?;
        let c_ivalue = unsafe_torch_err!(atm_create_class_(
            self.c_module,
            clz_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // NOTE(review): same potential leak-on-error as in `forward_is`.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Switches the module to evaluation mode.
    pub fn f_set_eval(&mut self) -> Result<(), TchError> {
        unsafe_torch_err!(atm_eval(self.c_module));
        Ok(())
    }

    /// Switches the module to evaluation mode, panicking on failure.
    pub fn set_eval(&mut self) {
        self.f_set_eval().unwrap();
    }

    /// Switches the module to training mode.
    pub fn f_set_train(&mut self) -> Result<(), TchError> {
        unsafe_torch_err!(atm_train(self.c_module));
        Ok(())
    }

    /// Switches the module to training mode, panicking on failure.
    pub fn set_train(&mut self) {
        self.f_set_train().unwrap();
    }

    /// Moves the module to a different device and converts the kind.
    pub fn to(&mut self, device: Device, kind: Kind, non_blocking: bool) {
        unsafe_torch!(atm_to(self.c_module, device.c_int(), kind.c_int(), non_blocking));
    }

    /// Saves a module to a given path.
    pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
        let path = path_to_cstring(path)?;
        unsafe_torch_err!(atm_save(self.c_module, path.as_ptr()));
        Ok(())
    }

    /// Loads some named tensors from a module
    pub fn named_parameters(&self) -> Result<Vec<(String, Tensor)>, TchError> {
        let mut v: Vec<(String, Tensor)> = vec![];
        // The C side invokes `add_callback` for each parameter; the callback
        // pushes (name, tensor) pairs into `v` through the type-erased pointer.
        unsafe_torch_err!(atm_named_parameters(
            self.c_module,
            &mut v as *mut _ as *mut c_void,
            super::tensor::add_callback
        ));
        Ok(v)
    }

    /// Create a new module by tracing the application of the specified function on
    /// the given inputs.
    pub fn create_by_tracing<F>(
        modl_name: &str,
        fn_name: &str,
        inputs: &[Tensor],
        closure: &mut F,
    ) -> Result<CModule, TchError>
    where
        F: FnMut(&[Tensor]) -> Vec<Tensor>,
    {
        let modl_name = std::ffi::CString::new(modl_name)?;
        let fn_name = std::ffi::CString::new(fn_name)?;
        let c_inputs = inputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
        let c_module = unsafe_torch_err!(atm_create_for_tracing(
            modl_name.as_ptr(),
            c_inputs.as_ptr(),
            c_inputs.len() as c_int
        ));
        // The closure runs between `atm_create_for_tracing` and
        // `atm_end_tracing`, i.e. while tracing is presumably active so its
        // operations get recorded — TODO confirm against the shim.
        let outputs = closure(inputs);
        let c_outputs = outputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
        unsafe_torch_err!(atm_end_tracing(
            c_module,
            fn_name.as_ptr(),
            c_outputs.as_ptr(),
            c_outputs.len() as c_int,
        ));
        Ok(CModule { c_module })
    }
}
626
/// The trainable version of a jit PyTorch module.
///
/// These modules can be created via the
/// [TorchScript python api](https://pytorch.org/docs/stable/jit.html).
#[derive(Debug)]
pub struct TrainableCModule {
    /// The wrapped JIT module; most methods delegate to it directly.
    pub(crate) inner: CModule,
}
635
636impl TrainableCModule {
637    /// Loads a PyTorch saved JIT module from a file.
638    ///
639    /// This function also adds the tensors from the JIT module to the VarStore path
640    /// passed as argument so that the module can be trained.
641    pub fn load<T: AsRef<std::path::Path>>(module_path: T, path: Path) -> Result<Self, TchError> {
642        let inner = CModule::load_on_device(module_path, path.device())?;
643        for (name, tensor) in inner.named_parameters()? {
644            let requires_grad = tensor.requires_grad();
645            let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
646        }
647        Ok(TrainableCModule { inner })
648    }
649
650    /// Loads a PyTorch saved JIT model from a read instance.
651    ///
652    /// This function also adds the tensors from the JIT module to the VarStore path
653    /// passed as argument so that the module can be trained.
654    pub fn load_data<T: std::io::Read>(data: &mut T, path: Path) -> Result<Self, TchError> {
655        let inner = CModule::load_data_on_device(data, path.device())?;
656        for (name, tensor) in inner.named_parameters()? {
657            let requires_grad = tensor.requires_grad();
658            let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
659        }
660        Ok(TrainableCModule { inner })
661    }
662
663    pub fn save<T: AsRef<std::path::Path>>(&self, module_path: T) -> Result<(), TchError> {
664        self.inner.save(module_path)
665    }
666
667    /// Switches the module to training mode.
668    pub fn f_set_train(&mut self) -> Result<(), TchError> {
669        self.inner.f_set_train()
670    }
671
672    /// Switches the module to training mode.
673    pub fn set_train(&mut self) {
674        self.inner.set_train()
675    }
676
677    /// Switches the module to evaluation mode.
678    pub fn f_set_eval(&mut self) -> Result<(), TchError> {
679        self.inner.f_set_eval()
680    }
681
682    /// Switches the module to evaluation mode.
683    pub fn set_eval(&mut self) {
684        self.inner.set_eval()
685    }
686
687    /// Performs the forward pass for a model on some specified tensor inputs.
688    pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
689        self.inner.forward_ts(ts)
690    }
691
692    /// Performs the forward pass for a model on some specified ivalue inputs.
693    pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
694        self.inner.forward_is(ts)
695    }
696
697    /// Runs a specified entry point for a model on some given tensor inputs.
698    pub fn method_ts<T: Borrow<Tensor>>(
699        &self,
700        method_name: &str,
701        ts: &[T],
702    ) -> Result<Tensor, TchError> {
703        self.inner.method_ts(method_name, ts)
704    }
705
706    /// Runs a specified entry point for a model on some given ivalue inputs.
707    pub fn method_is<T: Borrow<IValue>>(
708        &self,
709        method_name: &str,
710        ts: &[T],
711    ) -> Result<IValue, TchError> {
712        self.inner.method_is(method_name, ts)
713    }
714}
715
716/// Returns whether profiling mode is set or not.
717pub fn f_get_profiling_mode() -> Result<bool, TchError> {
718    Ok(unsafe_torch_err!(atm_get_profiling_mode()) != 0)
719}
720
721/// Returns whether profiling mode is set or not.
722pub fn get_profiling_mode() -> bool {
723    f_get_profiling_mode().unwrap()
724}
725
726/// Activates or deactivates the profiling mode.
727pub fn f_set_profiling_mode(b: bool) -> Result<(), TchError> {
728    unsafe_torch_err!(atm_set_profiling_mode(b as c_int));
729    Ok(())
730}
731
732/// Activates or deactivates the profiling mode.
733pub fn set_profiling_mode(b: bool) {
734    f_set_profiling_mode(b).unwrap()
735}
736
737pub fn f_fuser_cuda_set_enabled(enabled: bool) -> Result<(), TchError> {
738    unsafe_torch_err!(atm_fuser_cuda_set_enabled(enabled));
739    Ok(())
740}
741
742pub fn fuser_cuda_set_enabled(enabled: bool) {
743    f_fuser_cuda_set_enabled(enabled).unwrap()
744}
745
746pub fn f_fuser_cuda_is_enabled() -> Result<bool, TchError> {
747    let b = unsafe_torch_err!(atm_fuser_cuda_is_enabled());
748    Ok(b)
749}
750
751pub fn fuser_cuda_is_enabled() -> bool {
752    f_fuser_cuda_is_enabled().unwrap()
753}
754
755pub fn f_set_tensor_expr_fuser_enabled(b: bool) -> Result<(), TchError> {
756    unsafe_torch_err!(atm_set_tensor_expr_fuser_enabled(b as c_int));
757    Ok(())
758}
759
760pub fn set_tensor_expr_fuser_enabled(b: bool) {
761    f_set_tensor_expr_fuser_enabled(b).unwrap()
762}
763
764pub fn f_get_tensor_expr_fuser_enabled() -> Result<bool, TchError> {
765    Ok(unsafe_torch_err!(atm_get_tensor_expr_fuser_enabled()))
766}
767
768pub fn get_tensor_expr_fuser_enabled() -> bool {
769    f_get_tensor_expr_fuser_enabled().unwrap()
770}
771
772/// Enables or disables the graph executor optimizer for the current thread.
773///
774/// # Arguments
775///
776/// * `b` - A boolean that if true enables the graph executor optimizer for the current thread.
777///
778/// This function returns an error if it is not possible to enable or disable the graph executor optimizer.
779pub fn f_set_graph_executor_optimize(b: bool) -> Result<(), TchError> {
780    unsafe_torch_err!(at_set_graph_executor_optimize(b));
781    Ok(())
782}
783
784/// Enables or disables the graph executor optimizer for the current thread.
785///
786/// # Arguments
787///
788/// * `b` - A boolean that if true enables the graph executor optimizer for the current thread.
789///
790/// This panics if it is not possible to enable or disable the graph executor optimizer.
791pub fn set_graph_executor_optimize(b: bool) {
792    f_set_graph_executor_optimize(b).unwrap();
793}
794
/// A TorchScript object, e.g. an instance of a custom JIT class.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, PartialEq)]
pub struct Object {
    /// Raw pointer to the C ivalue holding the object; owned by this struct
    /// and released via `ati_free` in the `Drop` implementation.
    c_ivalue: *mut CIValue,
}
800
impl Object {
    /// Applies the specified method to the object. The method takes as argument an arbitrary
    /// number of ivalues and returns an ivalue.
    pub fn method_is<T: Borrow<IValue>>(
        &self,
        method_name: &str,
        ts: &[T],
    ) -> Result<IValue, TchError> {
        // Convert the inputs to temporary C ivalues, freed after the call.
        let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
        // Fails if the method name contains an interior NUL byte.
        let method_name = std::ffi::CString::new(method_name)?;
        let c_ivalue = unsafe_torch_err!(ati_object_method_(
            self.c_ivalue,
            method_name.as_ptr(),
            ts.as_ptr(),
            ts.len() as c_int
        ));
        // NOTE(review): if the call above returns early with an error, the
        // temporaries in `ts` are never freed — confirm `unsafe_torch_err!`
        // semantics and whether this leak matters.
        for x in ts {
            unsafe { ati_free(x) }
        }
        IValue::from_c(c_ivalue)
    }

    /// Retrieves the specified attribute from an object as an ivalue.
    pub fn getattr(&self, attr_name: &str) -> Result<IValue, TchError> {
        let property_name = std::ffi::CString::new(attr_name)?;
        let c_ivalue =
            unsafe_torch_err!(ati_object_getattr_(self.c_ivalue, property_name.as_ptr()));
        // A null result signals a missing/unreadable attribute; surface it as
        // an error instead of handing a null pointer to `from_c`.
        if c_ivalue.is_null() {
            return Err(TchError::Torch(format!(
                "Object.getattr(\"{attr_name}\") returned CIValue nullptr"
            )));
        }
        IValue::from_c(c_ivalue)
    }
}
836
impl Drop for Object {
    fn drop(&mut self) {
        // Release the C ivalue owned by this object.
        unsafe_torch!(ati_free(self.c_ivalue))
    }
}
842
#[cfg(test)]
mod tests {
    use super::IValue;
    use std::f64::consts;

    // Converts a value to a C ivalue and back, asserting that the result is
    // identical to the input. Exercises both `IValue::to_c` and `IValue::from_c`.
    fn round_trip<T: Into<IValue>>(t: T) {
        let ivalue: IValue = t.into();
        let ivalue2 = IValue::from_c(ivalue.to_c().unwrap()).unwrap();
        assert_eq!(ivalue, ivalue2);
    }
    #[test]
    fn ivalue_round_trip() {
        // Cover atoms, strings, tuples, homogeneous lists, generic lists and dicts.
        round_trip(());
        round_trip(true);
        round_trip(false);
        round_trip(-1);
        round_trip(42);
        round_trip(15);
        round_trip("".to_string());
        round_trip("foobar".to_string());
        round_trip((42, consts::PI));
        round_trip(vec![42, 1337]);
        round_trip(vec![consts::E, consts::PI, 299792458.00001]);
        round_trip((vec![true, false, true, true], vec![consts::E, consts::PI, 299792458.00001]));
        round_trip(vec![IValue::from(42), IValue::from("foobar")]);
        round_trip(vec![
            (IValue::from(42), IValue::from("foobar")),
            (IValue::from("foo"), IValue::from("bar")),
        ]);
    }
}
873}