smpl_core/common/smpl_model.rs

1use super::{
2    betas::Betas,
3    expression::Expression,
4    outputs::SmplOutputDynamic,
5    pose::Pose,
6    smpl_options::SmplOptions,
7    types::{Gender, SmplType},
8};
9use crate::smpl_x::smpl_x_gpu::SmplXDynamic;
10use burn::{
11    backend::{Candle, NdArray, Wgpu},
12    prelude::Backend,
13    tensor::{Float, Int, Tensor},
14};
15use dyn_clone::DynClone;
16use enum_map::EnumMap;
17use gloss_utils::tensor::BurnBackend;
18use ndarray as nd;
19use std::any::Any;
/// Trait for a Smpl based model. Smpl-rs expects all Smpl models to implement
/// this.
///
/// Implementors are stored as `Box<dyn SmplModel<B>>` (see `Gender2Model` and
/// `SmplCache::add_model`), so the trait has to stay object-safe. The
/// `'static` bound is what allows any implementor to be boxed that way.
///
/// Other supertraits: `Send + Sync` (NOTE(review): presumably so boxed models
/// can be shared across threads by callers outside this file — confirm),
/// `Any` together with [`SmplModel::as_any`] (presumably for downcasting to a
/// concrete model type), and `DynClone` so trait objects can be cloned.
pub trait SmplModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
    /// Model family of this instance. `SmplCache::add_model` uses it, together
    /// with [`SmplModel::gender`], to pick the slot the model is stored in.
    fn smpl_type(&self) -> SmplType;
    /// Gender variant of this instance. Used with [`SmplModel::smpl_type`] to
    /// pick the cache slot.
    fn gender(&self) -> Gender;
    /// Evaluate the model for the given `options`, `betas`, raw pose and
    /// optional `expression`, producing a backend-generic output.
    fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B>;
    /// Derive a new output from `smpl_output`.
    /// NOTE(review): presumably re-arranges the output to match the UV layout
    /// exposed by [`SmplModel::uv`] / [`SmplModel::faces_uv`] — confirm against
    /// implementations.
    fn create_body_with_uv(&self, smpl_output: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B>;
    /// Map `expression` to a rank-2 float tensor (presumably per-vertex
    /// offsets, as the name suggests — confirm).
    fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float>;
    /// Map `betas` to a rank-2 float tensor (presumably T-pose vertices, as the
    /// name suggests — confirm).
    fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float>;
    /// Map T-pose vertices (`verts_t_pose`, taken by value) to a rank-2 float
    /// tensor (presumably joint positions — confirm).
    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float>;
    /// Compute a rank-2 float tensor from `pose` (presumably pose-dependent
    /// corrective offsets — confirm).
    fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float>;
    /// Compute a flat `f32` feature vector from `pose` as a CPU-side
    /// `ndarray` array (not a Burn tensor).
    fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32>;
    /// Apply `pose` to the given T-pose geometry using `joints` and
    /// `lbs_weights`. Normals and tangents are optional inputs and come back as
    /// optional outputs.
    /// NOTE(review): the tuple is presumably (posed verts, posed normals,
    /// posed tangents, posed joints), mirroring the inputs — confirm against
    /// implementations.
    // The four-element tuple return type trips clippy's complexity lint.
    #[allow(clippy::type_complexity)]
    fn apply_pose(
        &self,
        verts_t_pose: &Tensor<B, 2, Float>,
        normals: Option<&Tensor<B, 2, Float>>,
        tangents: Option<&Tensor<B, 2, Float>>,
        joints: &Tensor<B, 2, Float>,
        lbs_weights: &Tensor<B, 2, Float>,
        pose: &Pose,
    ) -> (
        Tensor<B, 2, Float>,
        Option<Tensor<B, 2, Float>>,
        Option<Tensor<B, 2, Float>>,
        Tensor<B, 2, Float>,
    );
    /// Borrow the integer face-index tensor.
    fn faces(&self) -> &Tensor<B, 2, Int>;
    /// Borrow the integer face-index tensor that indexes into [`SmplModel::uv`].
    fn faces_uv(&self) -> &Tensor<B, 2, Int>;
    /// Borrow the UV-coordinate tensor.
    fn uv(&self) -> &Tensor<B, 2, Float>;
    /// Return the linear-blend-skinning weights (owned tensor).
    fn lbs_weights(&self) -> Tensor<B, 2, Float>;
    /// Return the skinning weights for the "split" vertex layout (owned
    /// tensor). NOTE(review): "split" presumably refers to vertices duplicated
    /// along UV seams — confirm.
    fn lbs_weights_split(&self) -> Tensor<B, 2, Float>;
    /// Index map from the split vertex layout to the merged one, as a rank-1
    /// integer tensor.
    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int>;
    /// Same index map as [`SmplModel::idx_split_2_merged`], borrowed as a
    /// host-side `Vec<usize>`.
    fn idx_split_2_merged_vec(&self) -> &Vec<usize>;
    /// Replace the pose blend directions with `posedirs`.
    fn set_pose_dirs(&mut self, posedirs: Tensor<B, 2, Float>);
    /// Return the pose blend directions (owned tensor).
    fn get_pose_dirs(&self) -> Tensor<B, 2, Float>;
    /// Return the expression blend directions, if the model has any
    /// (presumably `None` for models without expression support — confirm).
    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>>;
    /// Clone `self` into a new boxed trait object.
    fn clone_dyn(&self) -> Box<dyn SmplModel<B>>;
    /// Expose `self` as `&dyn Any` (presumably so callers can downcast to the
    /// concrete model type).
    fn as_any(&self) -> &dyn Any;
}
60impl<B: Backend> Clone for Box<dyn SmplModel<B>> {
61    #[allow(unconditional_recursion)]
62    fn clone(&self) -> Box<dyn SmplModel<B>> {
63        self.clone()
64    }
65}
66/// A mapping from ``Gender`` to ``SmplModel``
67#[derive(Default, Clone)]
68pub struct Gender2Model<B: Backend> {
69    gender_to_model: EnumMap<Gender, Option<Box<dyn SmplModel<B>>>>,
70}
71#[derive(Default, Clone)]
72pub struct Gender2Path {
73    gender_to_path: EnumMap<Gender, Option<String>>,
74}
75/// A Dynamic Backend Cache for storing and easy access to ``SmplModels``
76/// This internally uses ``SmplCache<B>``
77#[allow(clippy::large_enum_variant)]
78#[derive(Clone)]
79pub enum SmplCacheDynamic {
80    NdArray(SmplCache<NdArray>),
81    Wgpu(SmplCache<Wgpu>),
82    Candle(SmplCache<Candle>),
83}
84impl Default for SmplCacheDynamic {
85    fn default() -> Self {
86        SmplCacheDynamic::Candle(SmplCache::default())
87    }
88}
89impl SmplCacheDynamic {
90    /// Get the Burn Backend the Cache was created using
91    pub fn get_backend(&self) -> BurnBackend {
92        match self {
93            SmplCacheDynamic::NdArray(_) => BurnBackend::NdArray,
94            SmplCacheDynamic::Wgpu(_) => BurnBackend::Wgpu,
95            SmplCacheDynamic::Candle(_) => BurnBackend::Candle,
96        }
97    }
98    /// Check whether the Cache has a certain model
99    pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
100        match self {
101            SmplCacheDynamic::NdArray(models) => models.has_model(smpl_type, gender),
102            SmplCacheDynamic::Wgpu(models) => models.has_model(smpl_type, gender),
103            SmplCacheDynamic::Candle(models) => models.has_model(smpl_type, gender),
104        }
105    }
106    /// Clear the Cache
107    pub fn remove_all_models(&mut self) {
108        match self {
109            SmplCacheDynamic::NdArray(models) => models.remove_all_models(),
110            SmplCacheDynamic::Wgpu(models) => models.remove_all_models(),
111            SmplCacheDynamic::Candle(models) => models.remove_all_models(),
112        }
113    }
114    pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
115        match self {
116            SmplCacheDynamic::NdArray(models) => models.has_lazy_loading(smpl_type, gender),
117            SmplCacheDynamic::Wgpu(models) => models.has_lazy_loading(smpl_type, gender),
118            SmplCacheDynamic::Candle(models) => models.has_lazy_loading(smpl_type, gender),
119        }
120    }
121    pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
122        match self {
123            SmplCacheDynamic::NdArray(models) => models.get_lazy_loading(smpl_type, gender),
124            SmplCacheDynamic::Wgpu(models) => models.get_lazy_loading(smpl_type, gender),
125            SmplCacheDynamic::Candle(models) => models.get_lazy_loading(smpl_type, gender),
126        }
127    }
128    /// Set lazy loading using default paths
129    pub fn lazy_load_defaults(&mut self) {
130        self.set_lazy_loading(SmplType::SmplX, Gender::Neutral, "./data/smplx/SMPLX_neutral_array_f32_slim.npz");
131        self.set_lazy_loading(SmplType::SmplX, Gender::Male, "./data/smplx/SMPLX_male_array_f32_slim.npz");
132        self.set_lazy_loading(SmplType::SmplX, Gender::Female, "./data/smplx/SMPLX_female_array_f32_slim.npz");
133    }
134    /// Set lazy loading explicitly
135    pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
136        match self {
137            SmplCacheDynamic::NdArray(models) => models.set_lazy_loading(smpl_type, gender, path),
138            SmplCacheDynamic::Wgpu(models) => models.set_lazy_loading(smpl_type, gender, path),
139            SmplCacheDynamic::Candle(models) => models.set_lazy_loading(smpl_type, gender, path),
140        }
141    }
142    /// Add a Smpl Model created on a certain Burn Backend
143    pub fn add_model_from_dynamic_device(&mut self, model: SmplXDynamic, cache_models: bool) {
144        match (self, model) {
145            (SmplCacheDynamic::NdArray(models), SmplXDynamic::NdArray(model_ndarray)) => {
146                models.add_model(model_ndarray, cache_models);
147            }
148            (SmplCacheDynamic::Wgpu(models), SmplXDynamic::Wgpu(model_wgpu)) => {
149                models.add_model(model_wgpu, cache_models);
150            }
151            (SmplCacheDynamic::Candle(models), SmplXDynamic::Candle(model_candle)) => {
152                models.add_model(model_candle, cache_models);
153            }
154            _ => {
155                eprintln!("Model and backend type mismatch!");
156            }
157        }
158    }
159
160    #[cfg(not(target_arch = "wasm32"))]
161    pub fn add_model_from_type(&mut self, smpl_type: SmplType, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) {
162        match smpl_type {
163            SmplType::SmplX => {
164                let new_model = SmplXDynamic::new_from_npz(self, path, gender, max_num_betas, num_expression_components);
165                self.add_model_from_dynamic_device(new_model, true);
166            }
167            _ => panic!("Model loading for {smpl_type:?} if not supported yet!"),
168        };
169    }
170}
171/// A Cache for storing and easy access to ``SmplModels`` which is generic over
172/// Burn Backend
173#[derive(Default, Clone)]
174pub struct SmplCache<B: Backend> {
175    type_to_model: EnumMap<SmplType, Gender2Model<B>>,
176    type_to_path: EnumMap<SmplType, Gender2Path>,
177}
178impl<B: Backend> SmplCache<B> {
179    pub fn add_model<T: SmplModel<B>>(&mut self, model: T, cache_models: bool) {
180        let smpl_type = model.smpl_type();
181        let gender = model.gender();
182        if !cache_models {
183            self.type_to_model = EnumMap::default();
184        }
185        self.type_to_model[smpl_type].gender_to_model[gender] = Some(Box::new(model));
186    }
187    pub fn remove_all_models(&mut self) {
188        self.type_to_model = EnumMap::default();
189    }
190    #[allow(clippy::borrowed_box)]
191    pub fn get_model_box_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&Box<dyn SmplModel<B>>> {
192        self.type_to_model[smpl_type].gender_to_model[gender].as_ref()
193    }
194    #[allow(clippy::redundant_closure_for_method_calls)]
195    pub fn get_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn SmplModel<B>> {
196        let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
197        let model = opt.as_ref().map(|x| x.as_ref());
198        model
199    }
200    #[allow(clippy::redundant_closure_for_method_calls)]
201    pub fn get_model_mut(&mut self, smpl_type: SmplType, gender: Gender) -> Option<&mut dyn SmplModel<B>> {
202        let opt = &mut self.type_to_model[smpl_type].gender_to_model[gender];
203        let model = opt.as_mut().map(|x| x.as_mut());
204        model
205    }
206    pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
207        self.type_to_model[smpl_type].gender_to_model[gender].is_some()
208    }
209    pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
210        self.type_to_path[smpl_type].gender_to_path[gender].is_some()
211    }
212    pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
213        self.type_to_path[smpl_type].gender_to_path[gender].clone()
214    }
215    pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
216        self.type_to_path[smpl_type].gender_to_path[gender] = Some(path.to_string());
217        assert!(
218            std::path::Path::new(&path).exists(),
219            "File at path {path} does not exist. Please follow the data download instructions in the README."
220        );
221    }
222}