smpl_core/common/smpl_model.rs

1use super::{
2    betas::Betas,
3    expression::Expression,
4    outputs::SmplOutputDynamic,
5    pose::Pose,
6    smpl_options::SmplOptions,
7    types::{Gender, SmplType},
8};
9use crate::smpl_x::smpl_x_gpu::SmplXDynamic;
10use burn::{
11    backend::{Candle, NdArray, Wgpu},
12    prelude::Backend,
13    tensor::{Float, Int, Tensor},
14};
15use dyn_clone::DynClone;
16use enum_map::EnumMap;
17use gloss_utils::tensor::BurnBackend;
18use ndarray as nd;
19use std::any::Any;
20pub trait FaceModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
21    fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float>;
22}
23impl<B: Backend> Clone for Box<dyn FaceModel<B>> {
24    #[allow(unconditional_recursion)]
25    fn clone(&self) -> Box<dyn FaceModel<B>> {
26        self.clone()
27    }
28}
29/// Trait for a Smpl based model. Smpl-rs expects all Smpl models to implement
30/// this.
31pub trait SmplModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
32    fn smpl_type(&self) -> SmplType;
33    fn gender(&self) -> Gender;
34    fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B>;
35    fn create_body_with_uv(&self, smpl_output: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B>;
36    fn get_face_model(&self) -> &dyn FaceModel<B>;
37    fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float>;
38    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float>;
39    fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float>;
40    fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32>;
41    #[allow(clippy::type_complexity)]
42    fn apply_pose(
43        &self,
44        verts_t_pose: &Tensor<B, 2, Float>,
45        normals: Option<&Tensor<B, 2, Float>>,
46        tangents: Option<&Tensor<B, 2, Float>>,
47        joints: &Tensor<B, 2, Float>,
48        lbs_weights: &Tensor<B, 2, Float>,
49        pose: &Pose,
50    ) -> (
51        Tensor<B, 2, Float>,
52        Option<Tensor<B, 2, Float>>,
53        Option<Tensor<B, 2, Float>>,
54        Tensor<B, 2, Float>,
55    );
56    fn faces(&self) -> &Tensor<B, 2, Int>;
57    fn faces_uv(&self) -> &Tensor<B, 2, Int>;
58    fn uv(&self) -> &Tensor<B, 2, Float>;
59    fn lbs_weights(&self) -> Tensor<B, 2, Float>;
60    fn lbs_weights_split(&self) -> Tensor<B, 2, Float>;
61    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int>;
62    fn idx_split_2_merged_vec(&self) -> &Vec<usize>;
63    fn set_pose_dirs(&mut self, posedirs: Tensor<B, 2, Float>);
64    fn get_pose_dirs(&self) -> Tensor<B, 2, Float>;
65    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>>;
66    fn clone_dyn(&self) -> Box<dyn SmplModel<B>>;
67    fn as_any(&self) -> &dyn Any;
68}
69impl<B: Backend> Clone for Box<dyn SmplModel<B>> {
70    #[allow(unconditional_recursion)]
71    fn clone(&self) -> Box<dyn SmplModel<B>> {
72        self.clone()
73    }
74}
75/// A mapping from ``Gender`` to ``SmplModel``
76#[derive(Default, Clone)]
77pub struct Gender2Model<B: Backend> {
78    gender_to_model: EnumMap<Gender, Option<Box<dyn SmplModel<B>>>>,
79}
80#[derive(Default, Clone)]
81pub struct Gender2Path {
82    gender_to_path: EnumMap<Gender, Option<String>>,
83}
84/// A Dynamic Backend Cache for storing and easy access to ``SmplModels``
85/// This internally uses ``SmplCache<B>``
86#[allow(clippy::large_enum_variant)]
87#[derive(Clone)]
88pub enum SmplCacheDynamic {
89    NdArray(SmplCache<NdArray>),
90    Wgpu(SmplCache<Wgpu>),
91    Candle(SmplCache<Candle>),
92}
93impl Default for SmplCacheDynamic {
94    fn default() -> Self {
95        SmplCacheDynamic::Candle(SmplCache::default())
96    }
97}
98impl SmplCacheDynamic {
99    /// Get the Burn Backend the Cache was created using
100    pub fn get_backend(&self) -> BurnBackend {
101        match self {
102            SmplCacheDynamic::NdArray(_) => BurnBackend::NdArray,
103            SmplCacheDynamic::Wgpu(_) => BurnBackend::Wgpu,
104            SmplCacheDynamic::Candle(_) => BurnBackend::Candle,
105        }
106    }
107    /// Check whether the Cache has a certain model
108    pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
109        match self {
110            SmplCacheDynamic::NdArray(models) => models.has_model(smpl_type, gender),
111            SmplCacheDynamic::Wgpu(models) => models.has_model(smpl_type, gender),
112            SmplCacheDynamic::Candle(models) => models.has_model(smpl_type, gender),
113        }
114    }
115    /// Clear the Cache
116    pub fn remove_all_models(&mut self) {
117        match self {
118            SmplCacheDynamic::NdArray(models) => models.remove_all_models(),
119            SmplCacheDynamic::Wgpu(models) => models.remove_all_models(),
120            SmplCacheDynamic::Candle(models) => models.remove_all_models(),
121        }
122    }
123    pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
124        match self {
125            SmplCacheDynamic::NdArray(models) => models.has_lazy_loading(smpl_type, gender),
126            SmplCacheDynamic::Wgpu(models) => models.has_lazy_loading(smpl_type, gender),
127            SmplCacheDynamic::Candle(models) => models.has_lazy_loading(smpl_type, gender),
128        }
129    }
130    pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
131        match self {
132            SmplCacheDynamic::NdArray(models) => models.get_lazy_loading(smpl_type, gender),
133            SmplCacheDynamic::Wgpu(models) => models.get_lazy_loading(smpl_type, gender),
134            SmplCacheDynamic::Candle(models) => models.get_lazy_loading(smpl_type, gender),
135        }
136    }
137    /// Set lazy loading using default paths
138    pub fn lazy_load_defaults(&mut self) {
139        self.set_lazy_loading(SmplType::SmplX, Gender::Neutral, "./data/smplx/SMPLX_neutral_array_f32_slim.npz");
140        self.set_lazy_loading(SmplType::SmplX, Gender::Male, "./data/smplx/SMPLX_male_array_f32_slim.npz");
141        self.set_lazy_loading(SmplType::SmplX, Gender::Female, "./data/smplx/SMPLX_female_array_f32_slim.npz");
142    }
143    /// Set lazy loading explicitly
144    pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
145        match self {
146            SmplCacheDynamic::NdArray(models) => models.set_lazy_loading(smpl_type, gender, path),
147            SmplCacheDynamic::Wgpu(models) => models.set_lazy_loading(smpl_type, gender, path),
148            SmplCacheDynamic::Candle(models) => models.set_lazy_loading(smpl_type, gender, path),
149        }
150    }
151    /// Add a Smpl Model created on a certain Burn Backend
152    pub fn add_model_from_dynamic_device(&mut self, model: SmplXDynamic, cache_models: bool) {
153        match (self, model) {
154            (SmplCacheDynamic::NdArray(models), SmplXDynamic::NdArray(model_ndarray)) => {
155                models.add_model(model_ndarray, cache_models);
156            }
157            (SmplCacheDynamic::Wgpu(models), SmplXDynamic::Wgpu(model_wgpu)) => {
158                models.add_model(model_wgpu, cache_models);
159            }
160            (SmplCacheDynamic::Candle(models), SmplXDynamic::Candle(model_candle)) => {
161                models.add_model(model_candle, cache_models);
162            }
163            _ => {
164                eprintln!("Model and backend type mismatch!");
165            }
166        }
167    }
168
169    #[cfg(not(target_arch = "wasm32"))]
170    pub fn add_model_from_type(&mut self, smpl_type: SmplType, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) {
171        match smpl_type {
172            SmplType::SmplX => {
173                let new_model = SmplXDynamic::new_from_npz(self, path, gender, max_num_betas, num_expression_components);
174                self.add_model_from_dynamic_device(new_model, true);
175            }
176            _ => panic!("Model loading for {smpl_type:?} if not supported yet!"),
177        };
178    }
179}
180/// A Cache for storing and easy access to ``SmplModels`` which is generic over
181/// Burn Backend
182#[derive(Default, Clone)]
183pub struct SmplCache<B: Backend> {
184    type_to_model: EnumMap<SmplType, Gender2Model<B>>,
185    type_to_path: EnumMap<SmplType, Gender2Path>,
186}
187impl<B: Backend> SmplCache<B> {
188    pub fn add_model<T: SmplModel<B> + FaceModel<B>>(&mut self, model: T, cache_models: bool) {
189        let smpl_type = model.smpl_type();
190        let gender = model.gender();
191        if !cache_models {
192            self.type_to_model = EnumMap::default();
193        }
194        self.type_to_model[smpl_type].gender_to_model[gender] = Some(Box::new(model));
195    }
196    pub fn remove_all_models(&mut self) {
197        self.type_to_model = EnumMap::default();
198    }
199    #[allow(clippy::borrowed_box)]
200    pub fn get_model_box_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&Box<dyn SmplModel<B>>> {
201        self.type_to_model[smpl_type].gender_to_model[gender].as_ref()
202    }
203    #[allow(clippy::redundant_closure_for_method_calls)]
204    pub fn get_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn SmplModel<B>> {
205        let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
206        let model = opt.as_ref().map(|x| x.as_ref());
207        model
208    }
209    #[allow(clippy::redundant_closure_for_method_calls)]
210    pub fn get_face_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn FaceModel<B>> {
211        let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
212        opt.as_ref().map(|model| model.get_face_model())
213    }
214    #[allow(clippy::redundant_closure_for_method_calls)]
215    pub fn get_model_mut(&mut self, smpl_type: SmplType, gender: Gender) -> Option<&mut dyn SmplModel<B>> {
216        let opt = &mut self.type_to_model[smpl_type].gender_to_model[gender];
217        let model = opt.as_mut().map(|x| x.as_mut());
218        model
219    }
220    pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
221        self.type_to_model[smpl_type].gender_to_model[gender].is_some()
222    }
223    pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
224        self.type_to_path[smpl_type].gender_to_path[gender].is_some()
225    }
226    pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
227        self.type_to_path[smpl_type].gender_to_path[gender].clone()
228    }
229    pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
230        self.type_to_path[smpl_type].gender_to_path[gender] = Some(path.to_string());
231        assert!(
232            std::path::Path::new(&path).exists(),
233            "File at path {path} does not exist. Please follow the data download instructions in the README."
234        );
235    }
236}