tiny_recursive_rs/models/
loader.rs1use std::path::Path;
3use candle_core::{Device, DType};
4use candle_nn::VarBuilder;
5use crate::TRMConfig;
6use super::TinyRecursiveModel;
7
8pub fn load_model<P: AsRef<Path>>(
18 config: TRMConfig,
19 weights_path: P,
20 device: &Device,
21) -> crate::Result<TinyRecursiveModel> {
22 let dtype = DType::F32;
24 let vb = unsafe {
25 VarBuilder::from_mmaped_safetensors(
26 &[weights_path.as_ref()],
27 dtype,
28 device,
29 )?
30 };
31
32 TinyRecursiveModel::new(config, vb)
33}