smpl_core/common/
betas.rs

1use crate::{codec::codec::SmplCodec, AppBackend};
2use burn::{prelude::Backend, tensor::Tensor};
3use gloss_utils::bshare::ToBurn;
4use log::info;
5use ndarray as nd;
6use ndarray::prelude::*;
7use ndarray_npy::NpzReader;
8use smpl_utils::io::FileLoader;
9use std::io::{Read, Seek};
10/// Component for Smpl Betas or Shape Parameters
11#[derive(Clone)]
12pub struct BetasG<B: Backend> {
13    pub device: B::Device,
14    pub betas: Tensor<B, 1>,
15}
16impl<B: Backend> Default for BetasG<B> {
17    fn default() -> Self {
18        let device = B::Device::default();
19        let num_betas = 10;
20        let betas = Tensor::<B, 1>::zeros([num_betas], &device);
21        Self { device, betas }
22    }
23}
24impl<B: Backend> BetasG<B> {
25    pub fn new(betas: Tensor<B, 1>) -> Self {
26        Self {
27            device: betas.device(),
28            betas,
29        }
30    }
31    pub fn new_empty(num_betas: usize) -> Self {
32        let device = B::Device::default();
33        let betas = Tensor::<B, 1>::zeros([num_betas], &device);
34        Self { device, betas }
35    }
36    pub fn new_from_ndarray(betas: nd::Array1<f32>) -> Self {
37        let device = B::Device::default();
38        let betas = betas.into_burn(&device);
39        Self::new(betas)
40    }
41    /// # Panics
42    /// Will panic if the file cannot be read
43    #[allow(clippy::cast_possible_truncation)]
44    fn new_from_npz_reader<R: Read + Seek>(npz: &mut NpzReader<R>, truncate_nr_betas: Option<usize>) -> Self {
45        info!("NPZ keys - {:?}", npz.names().unwrap());
46        let device = B::Device::default();
47        let betas: nd::Array1<f64> = npz.by_name("betas").unwrap();
48        let mut betas = betas.mapv(|x| x as f32);
49        if let Some(truncate_nr_betas) = truncate_nr_betas {
50            if truncate_nr_betas < betas.len() {
51                betas = betas.slice(s![0..truncate_nr_betas]).to_owned();
52            }
53        }
54        let betas = betas.into_burn(&device);
55        Self { device, betas }
56    }
57    #[cfg(not(target_arch = "wasm32"))]
58    /// # Panics
59    /// Will panic if the file cannot be read
60    #[allow(clippy::cast_possible_truncation)]
61    pub fn new_from_npz(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
62        let mut npz = NpzReader::new(std::fs::File::open(npz_path).unwrap()).unwrap();
63        Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
64    }
65    /// # Panics
66    /// Will panic if the file cannot be read
67    #[allow(clippy::cast_possible_truncation)]
68    pub async fn new_from_npz_async(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
69        let reader = FileLoader::open(npz_path).await;
70        let mut npz = NpzReader::new(reader).unwrap();
71        Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
72    }
73    /// Create a new ``Betas`` component from a ``SmplCodec``
74    pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
75        codec.shape_parameters.as_ref().map(|betas| Self::new_from_ndarray(betas.clone()))
76    }
77    /// Create a new ``Betas`` component from a ``.smpl`` file
78    pub fn new_from_smpl_file(path: &str) -> Option<Self> {
79        let codec = SmplCodec::from_file(path);
80        Self::new_from_smpl_codec(&codec)
81    }
82}
83pub type Betas = BetasG<AppBackend>;