1use std::collections::HashMap;
2use std::path::Path;
3
4use snafu::{ResultExt, Snafu};
5use svod_dtype::DType;
6use svod_tensor::Tensor;
7
/// Map from a string key to a [`Tensor`].
///
/// [`load_safetensors`] fills it with one entry per tensor, keyed by the tensor's name in the
/// file; [`get_tensor`] looks entries up by key.
pub type StateDict = HashMap<String, Tensor>;
9
10#[derive(Debug, Snafu)]
11pub enum Error {
12 #[snafu(display("failed to read file: {source}"))]
13 Io { source: std::io::Error },
14 #[snafu(display("failed to deserialize safetensors"))]
15 Safetensors { source: safetensors::SafeTensorError },
16 #[snafu(display("unsupported dtype in safetensors: {dtype}"))]
17 UnsupportedDtype { dtype: String },
18 #[snafu(display("missing key in state dict: {key}"))]
19 MissingKey { key: String },
20 #[snafu(display("{source}"))]
21 Tensor {
22 #[snafu(source(from(svod_tensor::error::Error, Box::new)))]
23 source: Box<svod_tensor::error::Error>,
24 },
25}
26
/// Module-local shorthand for `std::result::Result` with [`Error`] as the error type.
type Result<T> = std::result::Result<T, Error>;
28
29pub fn load_safetensors(path: &Path) -> Result<StateDict> {
30 let data = std::fs::read(path).context(IoSnafu)?;
31 let tensors = safetensors::SafeTensors::deserialize(&data).context(SafetensorsSnafu)?;
32 let mut sd = StateDict::new();
33 for (name, view) in tensors.tensors() {
34 let dtype = convert_dtype(view.dtype())?;
35 let shape: Vec<usize> = view.shape().to_vec();
36 let tensor = Tensor::from_raw_bytes(view.data(), &shape, dtype).context(TensorSnafu)?;
37 sd.insert(name.to_string(), tensor);
38 }
39 Ok(sd)
40}
41
42fn convert_dtype(dt: safetensors::Dtype) -> Result<DType> {
43 use safetensors::Dtype as ST;
44 match dt {
45 ST::F32 => Ok(DType::Float32),
46 ST::F16 => Ok(DType::Float16),
47 ST::BF16 => Ok(DType::BFloat16),
48 ST::F64 => Ok(DType::Float64),
49 ST::I32 => Ok(DType::Int32),
50 ST::I64 => Ok(DType::Int64),
51 ST::I16 => Ok(DType::Int16),
52 ST::I8 => Ok(DType::Int8),
53 ST::U8 => Ok(DType::UInt8),
54 ST::BOOL => Ok(DType::Bool),
55 other => Err(Error::UnsupportedDtype { dtype: format!("{other:?}") }),
56 }
57}
58
/// Interface for types that can produce a [`StateDict`] and be updated from one.
pub trait HasStateDict {
    /// Returns a [`StateDict`] built from `self`.
    ///
    /// NOTE(review): `prefix` is presumably the key namespace that gets combined with field
    /// names via [`prefixed`], as the `state_field!` macro does — confirm against implementors.
    fn state_dict(&self, prefix: &str) -> StateDict;

    /// Updates `self` from the entries of `sd`.
    ///
    /// NOTE(review): presumably looks up keys built from `prefix` and fails with
    /// [`Error::MissingKey`] when one is absent, as the `load_state_field!` macro does —
    /// confirm against implementors.
    fn load_state_dict(&mut self, sd: &StateDict, prefix: &str) -> Result<()>;
}
63
64pub fn get_tensor(sd: &StateDict, key: &str) -> Result<Tensor> {
66 sd.get(key).cloned().ok_or_else(|| Error::MissingKey { key: key.to_string() })
67}
68
/// Builds a state-dict key by joining `prefix` and `name` with a `.`.
///
/// An empty `prefix` yields `name` unchanged, with no leading dot.
pub fn prefixed(prefix: &str, name: &str) -> String {
    match prefix {
        "" => name.to_owned(),
        _ => [prefix, name].join("."),
    }
}
73
/// Inserts a clone of each listed field into a state dict.
///
/// Arguments, in order: the map expression, the prefix expression, the identifier of the
/// value that owns the fields (e.g. `self`), and a bracketed, comma-separated list of field
/// names (a trailing comma is allowed).
///
/// For every field the key is `$crate::state::prefixed(prefix, "<field name>")` and the value
/// is `owner.field.clone()`. The map and prefix expressions are expanded once per field.
///
/// ```ignore
/// state_field!(sd, prefix, self, [weight, bias]);
/// ```
#[macro_export]
macro_rules! state_field {
    ($dict:expr, $pfx:expr, $owner:ident, [$($member:ident),+ $(,)?]) => {
        $(
            $dict.insert(
                $crate::state::prefixed($pfx, stringify!($member)),
                $owner.$member.clone(),
            );
        )+
    };
}
87
/// Assigns each listed field from the matching entry of a state dict.
///
/// Arguments, in order: the identifier of the value that owns the fields (e.g. `self`), the
/// state-dict expression (passed as-is to `$crate::state::get_tensor`), the prefix expression,
/// and a bracketed, comma-separated list of field names (a trailing comma is allowed).
///
/// For every field the key is `$crate::state::prefixed(prefix, "<field name>")`. The lookup
/// result is unwrapped with `?`, so the macro must be used where `?` can propagate the error
/// returned by `get_tensor`; fields listed before a failing one have already been assigned.
///
/// ```ignore
/// load_state_field!(self, sd, prefix, [weight, bias]);
/// ```
#[macro_export]
macro_rules! load_state_field {
    ($owner:ident, $dict:expr, $pfx:expr, [$($member:ident),+ $(,)?]) => {
        $(
            $owner.$member = $crate::state::get_tensor(
                $dict,
                &$crate::state::prefixed($pfx, stringify!($member)),
            )?;
        )+
    };
}