1use super::utils::{path_to_cstring, ptr_to_string};
3use super::{device::Device, kind::Kind};
4use crate::{nn::Path, TchError, Tensor};
5use libc::{c_int, c_void};
6use std::borrow::Borrow;
7use std::convert::TryFrom;
8use torch_sys_plus::*;
9
/// A value that can be passed to or returned from a TorchScript module.
///
/// Each variant is converted to a C ivalue handle by `IValue::to_c` and decoded from one by
/// `IValue::from_c`.
#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum IValue {
    /// The absence of a value.
    None,
    Tensor(crate::Tensor),
    Double(f64),
    Int(i64),
    Bool(bool),
    /// Fixed-size tuple of arbitrary values.
    Tuple(Vec<IValue>),
    IntList(Vec<i64>),
    DoubleList(Vec<f64>),
    BoolList(Vec<bool>),
    String(String),
    StringList(Vec<String>),
    TensorList(Vec<crate::Tensor>),
    /// List whose elements may be any `IValue`.
    GenericList(Vec<IValue>),
    /// Dictionary represented as a vector of key/value pairs.
    GenericDict(Vec<(IValue, IValue)>),
    /// Opaque TorchScript object, held through an owned C handle.
    Object(Object),
}
33
34impl IValue {
35 fn type_str(self) -> &'static str {
36 match self {
37 IValue::None => "None",
38 IValue::Tensor(_) => "Tensor",
39 IValue::Double(_) => "Double",
40 IValue::Int(_) => "Int",
41 IValue::Bool(_) => "Bool",
42 IValue::Tuple(_) => "Tuple",
43 IValue::IntList(_) => "IntList",
44 IValue::DoubleList(_) => "DoubleList",
45 IValue::BoolList(_) => "BoolList",
46 IValue::String(_) => "String",
47 IValue::StringList(_) => "StringList",
48 IValue::TensorList(_) => "TensorList",
49 IValue::GenericList(_) => "GenericList",
50 IValue::GenericDict(_) => "GenericDict",
51 IValue::Object(_) => "Object",
52 }
53 }
54}
55
56impl From<()> for IValue {
57 fn from((): ()) -> Self {
58 IValue::None
59 }
60}
61
62impl<T1: Into<IValue>, T2: Into<IValue>> From<(T1, T2)> for IValue {
63 fn from((p1, p2): (T1, T2)) -> Self {
64 IValue::Tuple(vec![p1.into(), p2.into()])
65 }
66}
67
68impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>> From<(T1, T2, T3)> for IValue {
69 fn from((p1, p2, p3): (T1, T2, T3)) -> Self {
70 IValue::Tuple(vec![p1.into(), p2.into(), p3.into()])
71 }
72}
73
74impl<T1: Into<IValue>, T2: Into<IValue>, T3: Into<IValue>, T4: Into<IValue>> From<(T1, T2, T3, T4)>
75 for IValue
76{
77 fn from((p1, p2, p3, p4): (T1, T2, T3, T4)) -> Self {
78 IValue::Tuple(vec![p1.into(), p2.into(), p3.into(), p4.into()])
79 }
80}
81
82impl<T1, T2, T1E, T2E> TryFrom<IValue> for (T1, T2)
83where
84 T1: TryFrom<IValue, Error = T1E>,
85 TchError: From<T1E>,
86 T2: TryFrom<IValue, Error = T2E>,
87 TchError: From<T2E>,
88{
89 type Error = TchError;
90 fn try_from(value: IValue) -> Result<Self, TchError> {
91 match value {
92 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
93 if vec.len() == 2 {
94 let t2 = T2::try_from(vec.pop().unwrap())?;
95 let t1 = T1::try_from(vec.pop().unwrap())?;
96 Ok((t1, t2))
97 } else {
98 Err(TchError::Kind(format!(
99 "unable to unpack ivalue, expected a tuple of len 2 got {}",
100 vec.len()
101 )))
102 }
103 }
104 _ => Err(TchError::Kind(format!(
105 "unable to unpack ivalue, expected a tuple got {}",
106 value.type_str()
107 ))),
108 }
109 }
110}
111
112impl<T1, T2, T3, T1E, T2E, T3E> TryFrom<IValue> for (T1, T2, T3)
113where
114 T1: TryFrom<IValue, Error = T1E>,
115 TchError: From<T1E>,
116 T2: TryFrom<IValue, Error = T2E>,
117 TchError: From<T2E>,
118 T3: TryFrom<IValue, Error = T3E>,
119 TchError: From<T3E>,
120{
121 type Error = TchError;
122 fn try_from(value: IValue) -> Result<Self, TchError> {
123 match value {
124 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
125 if vec.len() == 3 {
126 let t3 = T3::try_from(vec.pop().unwrap())?;
127 let t2 = T2::try_from(vec.pop().unwrap())?;
128 let t1 = T1::try_from(vec.pop().unwrap())?;
129 Ok((t1, t2, t3))
130 } else {
131 Err(TchError::Kind(format!(
132 "unable to unpack ivalue, expected a tuple of len 3 got {}",
133 vec.len()
134 )))
135 }
136 }
137 _ => Err(TchError::Kind(format!(
138 "unable to unpack ivalue, expected a tuple got {}",
139 value.type_str()
140 ))),
141 }
142 }
143}
144
145impl<T1, T2, T3, T4, T1E, T2E, T3E, T4E> TryFrom<IValue> for (T1, T2, T3, T4)
146where
147 T1: TryFrom<IValue, Error = T1E>,
148 TchError: From<T1E>,
149 T2: TryFrom<IValue, Error = T2E>,
150 TchError: From<T2E>,
151 T3: TryFrom<IValue, Error = T3E>,
152 TchError: From<T3E>,
153 T4: TryFrom<IValue, Error = T4E>,
154 TchError: From<T4E>,
155{
156 type Error = TchError;
157 fn try_from(value: IValue) -> Result<Self, TchError> {
158 match value {
159 IValue::GenericList(mut vec) | IValue::Tuple(mut vec) => {
160 if vec.len() == 4 {
161 let t4 = T4::try_from(vec.pop().unwrap())?;
162 let t3 = T3::try_from(vec.pop().unwrap())?;
163 let t2 = T2::try_from(vec.pop().unwrap())?;
164 let t1 = T1::try_from(vec.pop().unwrap())?;
165 Ok((t1, t2, t3, t4))
166 } else {
167 Err(TchError::Kind(format!(
168 "unable to unpack ivalue, expected a tuple of len 4 got {}",
169 vec.len()
170 )))
171 }
172 }
173 _ => Err(TchError::Kind(format!(
174 "unable to unpack ivalue, expected a tuple got {}",
175 value.type_str()
176 ))),
177 }
178 }
179}
180
/// Generates the conversions between a Rust type and one `IValue` variant:
/// `From<$type_> for IValue`, `TryFrom<IValue> for $type_`, and
/// `TryFrom<IValue> for Option<$type_>` (where `IValue::None` maps to `None`).
macro_rules! impl_from {
    ($type_:ty, $cons:ident) => {
        impl From<$type_> for IValue {
            fn from(inner: $type_) -> Self {
                Self::$cons(inner)
            }
        }

        impl TryFrom<IValue> for $type_ {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<Self, Self::Error> {
                match value {
                    IValue::$cons(inner) => Ok(inner),
                    other => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} got {}",
                        std::stringify!($cons),
                        other.type_str()
                    ))),
                }
            }
        }

        impl TryFrom<IValue> for Option<$type_> {
            type Error = TchError;
            fn try_from(value: IValue) -> Result<Self, Self::Error> {
                match value {
                    IValue::None => Ok(None),
                    IValue::$cons(inner) => Ok(Some(inner)),
                    other => Err(TchError::Kind(format!(
                        "unable to unpack ivalue, expected {} or None got {}",
                        std::stringify!($cons),
                        other.type_str()
                    ))),
                }
            }
        }
    };
}
222
// Conversions between plain Rust values and their `IValue` variant: `From<T>` wraps the value,
// while `TryFrom<IValue>` for `T` and `Option<T>` unwrap it (see `impl_from!`).
impl_from!(i64, Int);
impl_from!(f64, Double);
impl_from!(bool, Bool);
impl_from!(String, String);
impl_from!(Tensor, Tensor);
impl_from!(Vec<i64>, IntList);
impl_from!(Vec<f64>, DoubleList);
impl_from!(Vec<bool>, BoolList);
impl_from!(Vec<String>, StringList);
impl_from!(Vec<crate::Tensor>, TensorList);
impl_from!(Vec<IValue>, GenericList);
impl_from!(Vec<(IValue, IValue)>, GenericDict);
impl_from!(Object, Object);
236
237impl From<&str> for IValue {
238 fn from(s: &str) -> Self {
239 IValue::String(s.to_string())
240 }
241}
242
243impl IValue {
244 #![allow(unused_unsafe)]
245 pub(super) fn to_c(&self) -> Result<*mut CIValue, TchError> {
246 let c = unsafe_torch_err!(match self {
247 IValue::Tensor(tensor) => ati_tensor(tensor.c_tensor),
248 IValue::Int(i) => ati_int(*i),
249 IValue::None => ati_none(),
250 IValue::Double(f) => ati_double(*f),
251 IValue::Bool(b) => ati_bool(i32::from(*b)),
252 IValue::Tuple(v) => {
253 let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
254 let tuple = ati_tuple(v.as_ptr(), v.len() as c_int);
255 for x in v {
256 ati_free(x);
257 }
258
259 tuple
260 }
261 IValue::GenericList(v) => {
262 let v = v.iter().map(Self::to_c).collect::<Result<Vec<_>, TchError>>()?;
263 let list = ati_generic_list(v.as_ptr(), v.len() as c_int);
264 for x in v {
265 ati_free(x);
266 }
267 list
268 }
269 IValue::IntList(v) => ati_int_list(v.as_ptr(), v.len() as c_int),
270 IValue::DoubleList(v) => ati_double_list(v.as_ptr(), v.len() as c_int),
271 IValue::BoolList(v) => {
272 let v: Vec<libc::c_char> = v.iter().map(|&b| libc::c_char::from(b)).collect();
273 ati_bool_list(v.as_ptr(), v.len() as c_int)
274 }
275 IValue::TensorList(v) => {
276 let v = v.iter().map(|t| t.c_tensor).collect::<Vec<_>>();
277 ati_tensor_list(v.as_ptr(), v.len() as c_int)
278 }
279 IValue::String(string) => {
280 let c_str = std::ffi::CString::new(string.as_str())?;
281 ati_string(c_str.as_ptr())
282 }
283 IValue::StringList(strings) => {
284 let mut v = vec![];
285 for s in strings {
286 v.push(std::ffi::CString::new(s.as_str())?);
287 }
288 let v_ptr: Vec<_> = v.iter().map(|s| s.as_ptr()).collect();
289 ati_string_list(v_ptr.as_ptr(), v.len() as c_int)
290 }
291 IValue::GenericDict(dict) => {
292 let v = dict
293 .iter()
294 .flat_map(|(k, v)| vec![Self::to_c(k), Self::to_c(v)])
295 .collect::<Result<Vec<_>, TchError>>()?;
296 let dict = ati_generic_dict(v.as_ptr(), dict.len() as c_int);
297 for x in v {
298 ati_free(x);
299 }
300 dict
301 }
302 IValue::Object(Object { c_ivalue }) => {
303 unsafe_torch_err!(ati_clone(*c_ivalue))
305 }
306 });
307 Ok(c)
308 }
309
310 pub(super) fn from_c(c_ivalue: *mut CIValue) -> Result<Self, TchError> {
312 let mut free = true;
313 let tag = unsafe_torch_err!(ati_tag(c_ivalue));
314 let v = match tag {
315 0 => IValue::None,
316 1 => {
317 let c_tensor = unsafe_torch_err!(ati_to_tensor(c_ivalue));
318 IValue::Tensor(crate::Tensor { c_tensor })
319 }
320 2 => IValue::Double(unsafe_torch_err!(ati_to_double(c_ivalue))),
321 3 => IValue::Int(unsafe_torch_err!(ati_to_int(c_ivalue))),
322 4 => {
323 let b = unsafe_torch_err!(ati_to_bool(c_ivalue));
324 if b < 0 {
325 return Err(TchError::Kind(format!("unexpected bool value {b}")));
326 }
327 IValue::Bool(b != 0)
328 }
329 5 => {
330 let len = unsafe_torch_err!(ati_tuple_length(c_ivalue));
331 let mut c_ivalues: Vec<_> =
332 (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
333 unsafe_torch_err!(ati_to_tuple(c_ivalue, c_ivalues.as_mut_ptr(), len));
334 let vec: Result<Vec<_>, _> =
335 c_ivalues.iter().map(|&c_ivalue| (Self::from_c(c_ivalue))).collect();
336 IValue::Tuple(vec?)
337 }
338 6 => {
339 let len = unsafe_torch_err!(ati_length(c_ivalue));
340 let mut c_array = vec![0i64; len as usize];
341 unsafe_torch_err!(ati_to_int_list(c_ivalue, c_array.as_mut_ptr(), len));
342 IValue::IntList(c_array)
343 }
344 7 => {
345 let len = unsafe_torch_err!(ati_length(c_ivalue));
346 let mut c_array = vec![0f64; len as usize];
347 unsafe_torch_err!(ati_to_double_list(c_ivalue, c_array.as_mut_ptr(), len));
348 IValue::DoubleList(c_array)
349 }
350 8 => {
351 let len = unsafe_torch_err!(ati_length(c_ivalue));
352 let mut c_array = vec![0_i8; len as usize];
353 let c_array_ptr = c_array.as_mut_ptr() as *mut libc::c_char;
354 unsafe_torch_err!(ati_to_bool_list(c_ivalue, c_array_ptr, len));
355 IValue::BoolList(c_array.iter().map(|&x| x != 0).collect())
356 }
357 9 => {
358 let ptr = unsafe_torch_err!(ati_to_string(c_ivalue));
359 let string = match unsafe { ptr_to_string(ptr) } {
360 None => return Err(TchError::Kind("nullptr representation".to_string())),
361 Some(s) => s,
362 };
363 IValue::String(string)
364 }
365 10 => {
366 let len = unsafe_torch_err!(ati_length(c_ivalue));
367 let mut c_tensors: Vec<_> =
368 (0..len).map(|_| std::ptr::null_mut::<C_tensor>()).collect();
369 unsafe_torch_err!(ati_to_tensor_list(c_ivalue, c_tensors.as_mut_ptr(), len));
370 let vec: Vec<_> = c_tensors.iter().map(|&c_tensor| (Tensor { c_tensor })).collect();
371 IValue::TensorList(vec)
372 }
373 12 => {
374 let len = unsafe_torch_err!(ati_length(c_ivalue));
375 let mut c_ivalues: Vec<_> =
376 (0..len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
377 unsafe_torch_err!(ati_to_generic_list(c_ivalue, c_ivalues.as_mut_ptr(), len));
378 let vec: Result<Vec<_>, _> =
379 c_ivalues.iter().map(|&c_ivalue| (Self::from_c(c_ivalue))).collect();
380 IValue::GenericList(vec?)
381 }
382 13 => {
383 let len = unsafe_torch_err!(ati_length(c_ivalue));
384 let mut c_ivalues: Vec<_> =
385 (0..2 * len).map(|_| std::ptr::null_mut::<CIValue>()).collect();
386 unsafe_torch_err!(ati_to_generic_dict(c_ivalue, c_ivalues.as_mut_ptr(), len));
387 let mut res: Vec<(IValue, IValue)> = vec![];
388 for i in 0..(len as usize) {
389 let key = Self::from_c(c_ivalues[2 * i])?;
390 let value = Self::from_c(c_ivalues[2 * i + 1])?;
391 res.push((key, value))
392 }
393 IValue::GenericDict(res)
394 }
395 14 => {
396 free = false;
397 IValue::Object(Object { c_ivalue })
398 }
399 _ => return Err(TchError::Kind(format!("unhandled tag {tag}"))),
400 };
401 if free {
402 unsafe_torch_err!(ati_free(c_ivalue));
403 }
404 Ok(v)
405 }
406}
407
/// A TorchScript module held through a raw C handle that is released on drop.
#[derive(Debug)]
pub struct CModule {
    // Raw handle to the C-side module; freed with `atm_free` in `Drop`.
    pub(super) c_module: *mut CModule_,
}

// NOTE(review): `Send`/`Sync` are asserted manually because the raw pointer field opts out of
// the auto traits. Soundness relies on libtorch modules being safe to share across threads
// through these entry points — confirm against libtorch's threading guarantees.
unsafe impl Send for CModule {}

unsafe impl Sync for CModule {}

impl Drop for CModule {
    /// Frees the underlying C module handle.
    fn drop(&mut self) {
        unsafe_torch!(atm_free(self.c_module))
    }
}
426
427impl CModule {
428 pub fn load<T: AsRef<std::path::Path>>(path: T) -> Result<CModule, TchError> {
430 let path = path_to_cstring(path)?;
431 let c_module = unsafe_torch_err!(atm_load(path.as_ptr()));
432 Ok(CModule { c_module })
433 }
434
435 pub fn load_on_device<T: AsRef<std::path::Path>>(
440 path: T,
441 device: Device,
442 ) -> Result<CModule, TchError> {
443 let path = path_to_cstring(path)?;
444 let c_module = unsafe_torch_err!(atm_load_on_device(path.as_ptr(), device.c_int()));
445 Ok(CModule { c_module })
446 }
447
448 pub fn load_data<T: std::io::Read>(f: &mut T) -> Result<CModule, TchError> {
450 let mut buffer = Vec::new();
451 f.read_to_end(&mut buffer)?;
452 let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
453 let c_module = unsafe_torch_err!(atm_load_str(buffer_ptr, buffer.len()));
454 Ok(CModule { c_module })
455 }
456
457 pub fn load_data_on_device<T: std::io::Read>(
462 f: &mut T,
463 device: Device,
464 ) -> Result<CModule, TchError> {
465 let mut buffer = Vec::new();
466 f.read_to_end(&mut buffer)?;
467 let buffer_ptr = buffer.as_ptr() as *const libc::c_char;
468 let c_module =
469 unsafe_torch_err!(atm_load_str_on_device(buffer_ptr, buffer.len(), device.c_int()));
470 Ok(CModule { c_module })
471 }
472
473 pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
476 let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
477 let c_tensor =
478 unsafe_torch_err!(atm_forward(self.c_module, ts.as_ptr(), ts.len() as c_int));
479 Ok(Tensor { c_tensor })
480 }
481
482 pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
485 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
486 let c_ivalue =
487 unsafe_torch_err!(atm_forward_(self.c_module, ts.as_ptr(), ts.len() as c_int));
488 for x in ts {
489 unsafe { ati_free(x) }
490 }
491 IValue::from_c(c_ivalue)
492 }
493
494 pub fn method_ts<T: Borrow<Tensor>>(
496 &self,
497 method_name: &str,
498 ts: &[T],
499 ) -> Result<Tensor, TchError> {
500 let ts: Vec<_> = ts.iter().map(|x| x.borrow().c_tensor).collect();
501 let method_name = std::ffi::CString::new(method_name)?;
502 let c_tensor = unsafe_torch_err!(atm_method(
503 self.c_module,
504 method_name.as_ptr(),
505 ts.as_ptr(),
506 ts.len() as c_int
507 ));
508 Ok(Tensor { c_tensor })
509 }
510
511 pub fn method_is<T: Borrow<IValue>>(
513 &self,
514 method_name: &str,
515 ts: &[T],
516 ) -> Result<IValue, TchError> {
517 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
518 let method_name = std::ffi::CString::new(method_name)?;
519 let c_ivalue = unsafe_torch_err!(atm_method_(
520 self.c_module,
521 method_name.as_ptr(),
522 ts.as_ptr(),
523 ts.len() as c_int
524 ));
525 for x in ts {
526 unsafe { ati_free(x) }
527 }
528 IValue::from_c(c_ivalue)
529 }
530
531 pub fn create_class_is<T: Borrow<IValue>>(
533 &self,
534 clz_name: &str,
535 ts: &[T],
536 ) -> Result<IValue, TchError> {
537 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
538 let clz_name = std::ffi::CString::new(clz_name)?;
539 let c_ivalue = unsafe_torch_err!(atm_create_class_(
540 self.c_module,
541 clz_name.as_ptr(),
542 ts.as_ptr(),
543 ts.len() as c_int
544 ));
545 for x in ts {
546 unsafe { ati_free(x) }
547 }
548 IValue::from_c(c_ivalue)
549 }
550
551 pub fn f_set_eval(&mut self) -> Result<(), TchError> {
553 unsafe_torch_err!(atm_eval(self.c_module));
554 Ok(())
555 }
556
557 pub fn set_eval(&mut self) {
559 self.f_set_eval().unwrap();
560 }
561
562 pub fn f_set_train(&mut self) -> Result<(), TchError> {
564 unsafe_torch_err!(atm_train(self.c_module));
565 Ok(())
566 }
567
568 pub fn set_train(&mut self) {
570 self.f_set_train().unwrap();
571 }
572
573 pub fn to(&mut self, device: Device, kind: Kind, non_blocking: bool) {
575 unsafe_torch!(atm_to(self.c_module, device.c_int(), kind.c_int(), non_blocking));
576 }
577
578 pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
580 let path = path_to_cstring(path)?;
581 unsafe_torch_err!(atm_save(self.c_module, path.as_ptr()));
582 Ok(())
583 }
584
585 pub fn named_parameters(&self) -> Result<Vec<(String, Tensor)>, TchError> {
587 let mut v: Vec<(String, Tensor)> = vec![];
588 unsafe_torch_err!(atm_named_parameters(
589 self.c_module,
590 &mut v as *mut _ as *mut c_void,
591 super::tensor::add_callback
592 ));
593 Ok(v)
594 }
595
596 pub fn create_by_tracing<F>(
599 modl_name: &str,
600 fn_name: &str,
601 inputs: &[Tensor],
602 closure: &mut F,
603 ) -> Result<CModule, TchError>
604 where
605 F: FnMut(&[Tensor]) -> Vec<Tensor>,
606 {
607 let modl_name = std::ffi::CString::new(modl_name)?;
608 let fn_name = std::ffi::CString::new(fn_name)?;
609 let c_inputs = inputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
610 let c_module = unsafe_torch_err!(atm_create_for_tracing(
611 modl_name.as_ptr(),
612 c_inputs.as_ptr(),
613 c_inputs.len() as c_int
614 ));
615 let outputs = closure(inputs);
616 let c_outputs = outputs.iter().map(|tensor| tensor.c_tensor).collect::<Vec<_>>();
617 unsafe_torch_err!(atm_end_tracing(
618 c_module,
619 fn_name.as_ptr(),
620 c_outputs.as_ptr(),
621 c_outputs.len() as c_int,
622 ));
623 Ok(CModule { c_module })
624 }
625}
626
627#[derive(Debug)]
632pub struct TrainableCModule {
633 pub(crate) inner: CModule,
634}
635
636impl TrainableCModule {
637 pub fn load<T: AsRef<std::path::Path>>(module_path: T, path: Path) -> Result<Self, TchError> {
642 let inner = CModule::load_on_device(module_path, path.device())?;
643 for (name, tensor) in inner.named_parameters()? {
644 let requires_grad = tensor.requires_grad();
645 let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
646 }
647 Ok(TrainableCModule { inner })
648 }
649
650 pub fn load_data<T: std::io::Read>(data: &mut T, path: Path) -> Result<Self, TchError> {
655 let inner = CModule::load_data_on_device(data, path.device())?;
656 for (name, tensor) in inner.named_parameters()? {
657 let requires_grad = tensor.requires_grad();
658 let _t = path.add(&name.replace('.', "_"), tensor, requires_grad);
659 }
660 Ok(TrainableCModule { inner })
661 }
662
663 pub fn save<T: AsRef<std::path::Path>>(&self, module_path: T) -> Result<(), TchError> {
664 self.inner.save(module_path)
665 }
666
667 pub fn f_set_train(&mut self) -> Result<(), TchError> {
669 self.inner.f_set_train()
670 }
671
672 pub fn set_train(&mut self) {
674 self.inner.set_train()
675 }
676
677 pub fn f_set_eval(&mut self) -> Result<(), TchError> {
679 self.inner.f_set_eval()
680 }
681
682 pub fn set_eval(&mut self) {
684 self.inner.set_eval()
685 }
686
687 pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError> {
689 self.inner.forward_ts(ts)
690 }
691
692 pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError> {
694 self.inner.forward_is(ts)
695 }
696
697 pub fn method_ts<T: Borrow<Tensor>>(
699 &self,
700 method_name: &str,
701 ts: &[T],
702 ) -> Result<Tensor, TchError> {
703 self.inner.method_ts(method_name, ts)
704 }
705
706 pub fn method_is<T: Borrow<IValue>>(
708 &self,
709 method_name: &str,
710 ts: &[T],
711 ) -> Result<IValue, TchError> {
712 self.inner.method_is(method_name, ts)
713 }
714}
715
716pub fn f_get_profiling_mode() -> Result<bool, TchError> {
718 Ok(unsafe_torch_err!(atm_get_profiling_mode()) != 0)
719}
720
721pub fn get_profiling_mode() -> bool {
723 f_get_profiling_mode().unwrap()
724}
725
726pub fn f_set_profiling_mode(b: bool) -> Result<(), TchError> {
728 unsafe_torch_err!(atm_set_profiling_mode(b as c_int));
729 Ok(())
730}
731
732pub fn set_profiling_mode(b: bool) {
734 f_set_profiling_mode(b).unwrap()
735}
736
737pub fn f_fuser_cuda_set_enabled(enabled: bool) -> Result<(), TchError> {
738 unsafe_torch_err!(atm_fuser_cuda_set_enabled(enabled));
739 Ok(())
740}
741
742pub fn fuser_cuda_set_enabled(enabled: bool) {
743 f_fuser_cuda_set_enabled(enabled).unwrap()
744}
745
746pub fn f_fuser_cuda_is_enabled() -> Result<bool, TchError> {
747 let b = unsafe_torch_err!(atm_fuser_cuda_is_enabled());
748 Ok(b)
749}
750
751pub fn fuser_cuda_is_enabled() -> bool {
752 f_fuser_cuda_is_enabled().unwrap()
753}
754
755pub fn f_set_tensor_expr_fuser_enabled(b: bool) -> Result<(), TchError> {
756 unsafe_torch_err!(atm_set_tensor_expr_fuser_enabled(b as c_int));
757 Ok(())
758}
759
760pub fn set_tensor_expr_fuser_enabled(b: bool) {
761 f_set_tensor_expr_fuser_enabled(b).unwrap()
762}
763
764pub fn f_get_tensor_expr_fuser_enabled() -> Result<bool, TchError> {
765 Ok(unsafe_torch_err!(atm_get_tensor_expr_fuser_enabled()))
766}
767
768pub fn get_tensor_expr_fuser_enabled() -> bool {
769 f_get_tensor_expr_fuser_enabled().unwrap()
770}
771
772pub fn f_set_graph_executor_optimize(b: bool) -> Result<(), TchError> {
780 unsafe_torch_err!(at_set_graph_executor_optimize(b));
781 Ok(())
782}
783
784pub fn set_graph_executor_optimize(b: bool) {
792 f_set_graph_executor_optimize(b).unwrap();
793}
794
/// A TorchScript object value (the payload of `IValue::Object`), held through an owned C
/// ivalue handle.
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Debug, PartialEq)]
pub struct Object {
    // Raw handle to the C ivalue. The derived `PartialEq` compares these pointers, not the
    // objects' contents.
    c_ivalue: *mut CIValue,
}
800
801impl Object {
802 pub fn method_is<T: Borrow<IValue>>(
805 &self,
806 method_name: &str,
807 ts: &[T],
808 ) -> Result<IValue, TchError> {
809 let ts = ts.iter().map(|x| x.borrow().to_c()).collect::<Result<Vec<_>, TchError>>()?;
810 let method_name = std::ffi::CString::new(method_name)?;
811 let c_ivalue = unsafe_torch_err!(ati_object_method_(
812 self.c_ivalue,
813 method_name.as_ptr(),
814 ts.as_ptr(),
815 ts.len() as c_int
816 ));
817 for x in ts {
818 unsafe { ati_free(x) }
819 }
820 IValue::from_c(c_ivalue)
821 }
822
823 pub fn getattr(&self, attr_name: &str) -> Result<IValue, TchError> {
825 let property_name = std::ffi::CString::new(attr_name)?;
826 let c_ivalue =
827 unsafe_torch_err!(ati_object_getattr_(self.c_ivalue, property_name.as_ptr()));
828 if c_ivalue.is_null() {
829 return Err(TchError::Torch(format!(
830 "Object.getattr(\"{attr_name}\") returned CIValue nullptr"
831 )));
832 }
833 IValue::from_c(c_ivalue)
834 }
835}
836
837impl Drop for Object {
838 fn drop(&mut self) {
839 unsafe_torch!(ati_free(self.c_ivalue))
840 }
841}
842
#[cfg(test)]
mod tests {
    use super::IValue;
    use std::f64::consts;

    /// Converts `value` to a C ivalue and back, asserting the result equals the input.
    fn round_trip<T: Into<IValue>>(value: T) {
        let before: IValue = value.into();
        let c_ivalue = before.to_c().unwrap();
        let after = IValue::from_c(c_ivalue).unwrap();
        assert_eq!(before, after);
    }

    #[test]
    fn ivalue_round_trip() {
        // Unit, booleans and integers.
        round_trip(());
        round_trip(true);
        round_trip(false);
        round_trip(-1);
        round_trip(42);
        round_trip(15);
        // Strings.
        round_trip(String::new());
        round_trip(String::from("foobar"));
        // Tuples and homogeneous lists.
        round_trip((42, consts::PI));
        round_trip(vec![42, 1337]);
        round_trip(vec![consts::E, consts::PI, 299792458.00001]);
        round_trip((vec![true, false, true, true], vec![consts::E, consts::PI, 299792458.00001]));
        // Generic list and dict.
        round_trip(vec![IValue::from(42), IValue::from("foobar")]);
        round_trip(vec![
            (IValue::from(42), IValue::from("foobar")),
            (IValue::from("foo"), IValue::from("bar")),
        ]);
    }
}