smpl_core/common/betas.rs

1use crate::codec::codec::SmplCodec;
2use log::info;
3use ndarray as nd;
4use ndarray::prelude::*;
5use ndarray_npy::NpzReader;
6use smpl_utils::io::FileLoader;
7use std::io::{Read, Seek};
8/// Component for Smpl Betas or Shape Parameters
9#[derive(Clone)]
10pub struct Betas {
11    pub betas: nd::Array1<f32>,
12}
13impl Default for Betas {
14    fn default() -> Self {
15        let num_betas = 10;
16        let betas = ndarray::Array1::<f32>::zeros(num_betas);
17        Self { betas }
18    }
19}
20impl Betas {
21    pub fn new(betas: nd::Array1<f32>) -> Self {
22        Self { betas }
23    }
24    pub fn new_empty(num_betas: usize) -> Self {
25        let betas = ndarray::Array1::<f32>::zeros(num_betas);
26        Self { betas }
27    }
28    /// # Panics
29    /// Will panic if the file cannot be read
30    #[allow(clippy::cast_possible_truncation)]
31    fn new_from_npz_reader<R: Read + Seek>(npz: &mut NpzReader<R>, truncate_nr_betas: Option<usize>) -> Self {
32        info!("NPZ keys - {:?}", npz.names().unwrap());
33        let betas: nd::Array1<f64> = npz.by_name("betas").unwrap();
34        let mut betas = betas.mapv(|x| x as f32);
35        if let Some(truncate_nr_betas) = truncate_nr_betas {
36            if truncate_nr_betas < betas.len() {
37                betas = betas.slice(s![0..truncate_nr_betas]).to_owned();
38            }
39        }
40        Self { betas }
41    }
42    #[cfg(not(target_arch = "wasm32"))]
43    /// # Panics
44    /// Will panic if the file cannot be read
45    #[allow(clippy::cast_possible_truncation)]
46    pub fn new_from_npz(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
47        let mut npz = NpzReader::new(std::fs::File::open(npz_path).unwrap()).unwrap();
48        Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
49    }
50    /// # Panics
51    /// Will panic if the file cannot be read
52    #[allow(clippy::cast_possible_truncation)]
53    pub async fn new_from_npz_async(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
54        let reader = FileLoader::open(npz_path).await;
55        let mut npz = NpzReader::new(reader).unwrap();
56        Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
57    }
58    /// Create a new ``Betas`` component from a ``SmplCodec``
59    pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
60        codec.shape_parameters.as_ref().map(|betas| Self { betas: betas.clone() })
61    }
62    /// Create a new ``Betas`` component from a ``.smpl`` file
63    pub fn new_from_smpl_file(path: &str) -> Option<Self> {
64        let codec = SmplCodec::from_file(path);
65        Self::new_from_smpl_codec(&codec)
66    }
67}