smpl_core/common/
betas.rs1use crate::{codec::codec::SmplCodec, AppBackend};
2use burn::{prelude::Backend, tensor::Tensor};
3use gloss_utils::bshare::ToBurn;
4use log::info;
5use ndarray as nd;
6use ndarray::prelude::*;
7use ndarray_npy::NpzReader;
8use smpl_utils::io::FileLoader;
9use std::io::{Read, Seek};
10#[derive(Clone)]
12pub struct BetasG<B: Backend> {
13 pub device: B::Device,
14 pub betas: Tensor<B, 1>,
15}
16impl<B: Backend> Default for BetasG<B> {
17 fn default() -> Self {
18 let device = B::Device::default();
19 let num_betas = 10;
20 let betas = Tensor::<B, 1>::zeros([num_betas], &device);
21 Self { device, betas }
22 }
23}
24impl<B: Backend> BetasG<B> {
25 pub fn new(betas: Tensor<B, 1>) -> Self {
26 Self {
27 device: betas.device(),
28 betas,
29 }
30 }
31 pub fn new_empty(num_betas: usize) -> Self {
32 let device = B::Device::default();
33 let betas = Tensor::<B, 1>::zeros([num_betas], &device);
34 Self { device, betas }
35 }
36 pub fn new_from_ndarray(betas: nd::Array1<f32>) -> Self {
37 let device = B::Device::default();
38 let betas = betas.into_burn(&device);
39 Self::new(betas)
40 }
41 #[allow(clippy::cast_possible_truncation)]
44 fn new_from_npz_reader<R: Read + Seek>(npz: &mut NpzReader<R>, truncate_nr_betas: Option<usize>) -> Self {
45 info!("NPZ keys - {:?}", npz.names().unwrap());
46 let device = B::Device::default();
47 let betas: nd::Array1<f64> = npz.by_name("betas").unwrap();
48 let mut betas = betas.mapv(|x| x as f32);
49 if let Some(truncate_nr_betas) = truncate_nr_betas {
50 if truncate_nr_betas < betas.len() {
51 betas = betas.slice(s![0..truncate_nr_betas]).to_owned();
52 }
53 }
54 let betas = betas.into_burn(&device);
55 Self { device, betas }
56 }
57 #[cfg(not(target_arch = "wasm32"))]
58 #[allow(clippy::cast_possible_truncation)]
61 pub fn new_from_npz(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
62 let mut npz = NpzReader::new(std::fs::File::open(npz_path).unwrap()).unwrap();
63 Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
64 }
65 #[allow(clippy::cast_possible_truncation)]
68 pub async fn new_from_npz_async(npz_path: &str, truncate_nr_betas: Option<usize>) -> Self {
69 let reader = FileLoader::open(npz_path).await;
70 let mut npz = NpzReader::new(reader).unwrap();
71 Self::new_from_npz_reader(&mut npz, truncate_nr_betas)
72 }
73 pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
75 codec.shape_parameters.as_ref().map(|betas| Self::new_from_ndarray(betas.clone()))
76 }
77 pub fn new_from_smpl_file(path: &str) -> Option<Self> {
79 let codec = SmplCodec::from_file(path);
80 Self::new_from_smpl_codec(&codec)
81 }
82}
83pub type Betas = BetasG<AppBackend>;