tiny_recursive_rs/models/
loader.rs

//! Weight loading from safetensors files.
2use std::path::Path;
3use candle_core::{Device, DType};
4use candle_nn::VarBuilder;
5use crate::TRMConfig;
6use super::TinyRecursiveModel;
7
8/// Load model from safetensors file
9///
10/// # Arguments
11/// * `config` - Model configuration
12/// * `weights_path` - Path to safetensors file
13/// * `device` - Device to load model on
14///
15/// # Returns
16/// Loaded TinyRecursiveModel
17pub fn load_model<P: AsRef<Path>>(
18    config: TRMConfig,
19    weights_path: P,
20    device: &Device,
21) -> crate::Result<TinyRecursiveModel> {
22    // Load weights using Candle's built-in safetensors support
23    let dtype = DType::F32;
24    let vb = unsafe {
25        VarBuilder::from_mmaped_safetensors(
26            &[weights_path.as_ref()],
27            dtype,
28            device,
29        )?
30    };
31
32    TinyRecursiveModel::new(config, vb)
33}