Struct tch::TrainableCModule
pub struct TrainableCModule { /* private fields */ }
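The trainable version of a JIT PyTorch module.
These modules can be created via the TorchScript Python API, for example by scripting or tracing a model with torch.jit.script or torch.jit.trace and saving it with torch.jit.save.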
Implementations
impl TrainableCModule
pub fn load<T: AsRef<Path>>( module_path: T, path: Path<'_> ) -> Result<Self, TchError>
Loads a JIT module saved by PyTorch from the file at module_path.
This function also registers the module's tensors in the VarStore under the path argument, so that the module can be trained.
pub fn load_data<T: Read>(data: &mut T, path: Path<'_>) -> Result<Self, TchError>
Loads a PyTorch saved JIT module from a Read instance.
This function also registers the module's tensors in the VarStore under the path argument, so that the module can be trained.
pub fn save<T: AsRef<Path>>(&self, module_path: T) -> Result<(), TchError>
Saves the module to the file at module_path.
pub fn f_set_train(&mut self) -> Result<(), TchError>
Switches the module to training mode.
pub fn f_set_eval(&mut self) -> Result<(), TchError>
Switches the module to evaluation mode.
pub fn forward_ts<T: Borrow<Tensor>>(&self, ts: &[T]) -> Result<Tensor, TchError>
Performs the forward pass of the module on the given tensor inputs.
pub fn forward_is<T: Borrow<IValue>>(&self, ts: &[T]) -> Result<IValue, TchError>
Performs the forward pass of the module on the given IValue inputs.