smpl_core/common/
betas.rs1use crate::codec::codec::SmplCodec;
2use log::info;
3use ndarray as nd;
4use ndarray::prelude::*;
5use ndarray_npy::NpzReader;
6use smpl_utils::io::FileLoader;
7use std::io::{Read, Seek};
8#[derive(Clone)]
10pub struct Betas {
11 pub betas: nd::Array1<f32>,
12}
13impl Default for Betas {
14 fn default() -> Self {
15 let num_betas = 10;
16 let betas = ndarray::Array1::<f32>::zeros(num_betas);
17 Self { betas }
18 }
19}
20impl Betas {
21 pub fn new(betas: nd::Array1<f32>) -> Self {
22 Self { betas }
23 }
24 pub fn new_empty(num_betas: usize) -> Self {
25 let betas = ndarray::Array1::<f32>::zeros(num_betas);
26 Self { betas }
27 }
28 #[allow(clippy::cast_possible_truncation)]
31 fn new_from_npz_reader<R: Read + Seek>(npz: &mut NpzReader<R>, truncate_nr_betas: Option<usize>) -> Self {
32 info!("NPZ keys - {:?}", npz.names().unwrap());
33 let betas: nd::Array1<f64> = npz.by_name("betas").unwrap();
34 let mut betas = betas.mapv(|x| x as f32);
35 if let Some(truncate_nr_betas) = truncate_nr_betas {
36 if truncate_nr_betas < betas.len() {
37 betas = betas.slice(s![0..truncate_nr_betas]).to_owned();
38 }
39 }
40 Self { betas }
41 }
42 #[cfg(not(target_arch = "wasm32"))]
43 #[allow(clippy::cast_possible_truncation)]
46 pub fn new_from_npz(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
47 let mut npz = NpzReader::new(std::fs::File::open(npz_path).unwrap()).unwrap();
48 Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
49 }
50 #[allow(clippy::cast_possible_truncation)]
53 pub async fn new_from_npz_async(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
54 let reader = FileLoader::open(npz_path).await;
55 let mut npz = NpzReader::new(reader).unwrap();
56 Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
57 }
58 pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
60 codec.shape_parameters.as_ref().map(|betas| Self { betas: betas.clone() })
61 }
62 pub fn new_from_smpl_file(path: &str) -> Option<Self> {
64 let codec = SmplCodec::from_file(path);
65 Self::new_from_smpl_codec(&codec)
66 }
67}