1use super::{
2 betas::Betas,
3 expression::Expression,
4 outputs::SmplOutputDynamic,
5 pose::Pose,
6 smpl_options::SmplOptions,
7 types::{Gender, SmplType},
8};
9use crate::smpl_x::smpl_x_gpu::SmplXDynamic;
10use burn::{
11 backend::{Candle, NdArray, Wgpu},
12 prelude::Backend,
13 tensor::{Float, Int, Tensor},
14};
15use dyn_clone::DynClone;
16use enum_map::EnumMap;
17use gloss_utils::tensor::BurnBackend;
18use ndarray as nd;
19use std::any::Any;
/// Backend-generic interface implemented by SMPL-family body models.
///
/// The `Send + Sync + 'static + Any + DynClone` supertraits let a model be stored as
/// `Box<dyn SmplModel<B>>`, shared between threads, cloned through the box, and
/// recovered as its concrete type via [`SmplModel::as_any`].
pub trait SmplModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
    /// Model family of this instance (e.g. `SmplType::SmplX`).
    fn smpl_type(&self) -> SmplType;
    /// Gender variant this instance represents.
    fn gender(&self) -> Gender;
    /// Evaluates the model for the given shape (`betas`), pose and optional expression,
    /// producing an output on backend `B`.
    ///
    /// NOTE(review): how `options` changes the computation is implementation-defined — confirm per model.
    fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B>;
    /// Derives from `smpl_output` an output that carries UV data.
    ///
    /// NOTE(review): presumably remaps vertices into the UV-split layout (cf. `idx_split_2_merged`) — confirm.
    fn create_body_with_uv(&self, smpl_output: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B>;
    /// Converts expression coefficients into vertex offsets (assumed `[num_verts, 3]` — TODO confirm).
    fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float>;
    /// Converts shape coefficients into T-pose vertices (assumed `[num_verts, 3]` — TODO confirm).
    fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float>;
    /// Computes joint positions from T-pose vertices.
    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float>;
    /// Computes pose-dependent corrective offsets for `pose`.
    fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float>;
    /// Computes the pose feature vector for `pose` as an `ndarray` vector.
    fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32>;
    /// Poses T-pose geometry; takes `lbs_weights`, so presumably via linear blend skinning.
    ///
    /// Returns a 4-tuple, presumably `(posed_verts, posed_normals, posed_tangents, posed_joints)`
    /// where the optional entries mirror the optional `normals` / `tangents` inputs — TODO confirm.
    #[allow(clippy::type_complexity)]
    fn apply_pose(
        &self,
        verts_t_pose: &Tensor<B, 2, Float>,
        normals: Option<&Tensor<B, 2, Float>>,
        tangents: Option<&Tensor<B, 2, Float>>,
        joints: &Tensor<B, 2, Float>,
        lbs_weights: &Tensor<B, 2, Float>,
        pose: &Pose,
    ) -> (
        Tensor<B, 2, Float>,
        Option<Tensor<B, 2, Float>>,
        Option<Tensor<B, 2, Float>>,
        Tensor<B, 2, Float>,
    );
    /// Face index tensor of the mesh (presumably triangles — confirm).
    fn faces(&self) -> &Tensor<B, 2, Int>;
    /// Face index tensor into the UV coordinates.
    fn faces_uv(&self) -> &Tensor<B, 2, Int>;
    /// UV coordinates of the mesh.
    fn uv(&self) -> &Tensor<B, 2, Float>;
    /// Skinning weights per vertex.
    fn lbs_weights(&self) -> Tensor<B, 2, Float>;
    /// Skinning weights for the UV-split vertex layout (presumably — confirm).
    fn lbs_weights_split(&self) -> Tensor<B, 2, Float>;
    /// Index tensor mapping split vertices back to merged vertices (presumably — confirm).
    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int>;
    /// Same mapping as [`SmplModel::idx_split_2_merged`], as a host-side `Vec`.
    fn idx_split_2_merged_vec(&self) -> &Vec<usize>;
    /// Replaces the pose-dependent blend directions.
    fn set_pose_dirs(&mut self, posedirs: Tensor<B, 2, Float>);
    /// Returns the pose-dependent blend directions.
    fn get_pose_dirs(&self) -> Tensor<B, 2, Float>;
    /// Returns the expression blend directions, if this model has any.
    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>>;
    /// Clones this model into a new boxed trait object.
    fn clone_dyn(&self) -> Box<dyn SmplModel<B>>;
    /// Returns `self` as `&dyn Any`, allowing a downcast to the concrete model type.
    fn as_any(&self) -> &dyn Any;
}
60impl<B: Backend> Clone for Box<dyn SmplModel<B>> {
61 #[allow(unconditional_recursion)]
62 fn clone(&self) -> Box<dyn SmplModel<B>> {
63 self.clone()
64 }
65}
66#[derive(Default, Clone)]
68pub struct Gender2Model<B: Backend> {
69 gender_to_model: EnumMap<Gender, Option<Box<dyn SmplModel<B>>>>,
70}
71#[derive(Default, Clone)]
72pub struct Gender2Path {
73 gender_to_path: EnumMap<Gender, Option<String>>,
74}
75#[allow(clippy::large_enum_variant)]
78#[derive(Clone)]
79pub enum SmplCacheDynamic {
80 NdArray(SmplCache<NdArray>),
81 Wgpu(SmplCache<Wgpu>),
82 Candle(SmplCache<Candle>),
83}
84impl Default for SmplCacheDynamic {
85 fn default() -> Self {
86 SmplCacheDynamic::Candle(SmplCache::default())
87 }
88}
89impl SmplCacheDynamic {
90 pub fn get_backend(&self) -> BurnBackend {
92 match self {
93 SmplCacheDynamic::NdArray(_) => BurnBackend::NdArray,
94 SmplCacheDynamic::Wgpu(_) => BurnBackend::Wgpu,
95 SmplCacheDynamic::Candle(_) => BurnBackend::Candle,
96 }
97 }
98 pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
100 match self {
101 SmplCacheDynamic::NdArray(models) => models.has_model(smpl_type, gender),
102 SmplCacheDynamic::Wgpu(models) => models.has_model(smpl_type, gender),
103 SmplCacheDynamic::Candle(models) => models.has_model(smpl_type, gender),
104 }
105 }
106 pub fn remove_all_models(&mut self) {
108 match self {
109 SmplCacheDynamic::NdArray(models) => models.remove_all_models(),
110 SmplCacheDynamic::Wgpu(models) => models.remove_all_models(),
111 SmplCacheDynamic::Candle(models) => models.remove_all_models(),
112 }
113 }
114 pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
115 match self {
116 SmplCacheDynamic::NdArray(models) => models.has_lazy_loading(smpl_type, gender),
117 SmplCacheDynamic::Wgpu(models) => models.has_lazy_loading(smpl_type, gender),
118 SmplCacheDynamic::Candle(models) => models.has_lazy_loading(smpl_type, gender),
119 }
120 }
121 pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
122 match self {
123 SmplCacheDynamic::NdArray(models) => models.get_lazy_loading(smpl_type, gender),
124 SmplCacheDynamic::Wgpu(models) => models.get_lazy_loading(smpl_type, gender),
125 SmplCacheDynamic::Candle(models) => models.get_lazy_loading(smpl_type, gender),
126 }
127 }
128 pub fn lazy_load_defaults(&mut self) {
130 self.set_lazy_loading(SmplType::SmplX, Gender::Neutral, "./data/smplx/SMPLX_neutral_array_f32_slim.npz");
131 self.set_lazy_loading(SmplType::SmplX, Gender::Male, "./data/smplx/SMPLX_male_array_f32_slim.npz");
132 self.set_lazy_loading(SmplType::SmplX, Gender::Female, "./data/smplx/SMPLX_female_array_f32_slim.npz");
133 }
134 pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
136 match self {
137 SmplCacheDynamic::NdArray(models) => models.set_lazy_loading(smpl_type, gender, path),
138 SmplCacheDynamic::Wgpu(models) => models.set_lazy_loading(smpl_type, gender, path),
139 SmplCacheDynamic::Candle(models) => models.set_lazy_loading(smpl_type, gender, path),
140 }
141 }
142 pub fn add_model_from_dynamic_device(&mut self, model: SmplXDynamic, cache_models: bool) {
144 match (self, model) {
145 (SmplCacheDynamic::NdArray(models), SmplXDynamic::NdArray(model_ndarray)) => {
146 models.add_model(model_ndarray, cache_models);
147 }
148 (SmplCacheDynamic::Wgpu(models), SmplXDynamic::Wgpu(model_wgpu)) => {
149 models.add_model(model_wgpu, cache_models);
150 }
151 (SmplCacheDynamic::Candle(models), SmplXDynamic::Candle(model_candle)) => {
152 models.add_model(model_candle, cache_models);
153 }
154 _ => {
155 eprintln!("Model and backend type mismatch!");
156 }
157 }
158 }
159
160 #[cfg(not(target_arch = "wasm32"))]
161 pub fn add_model_from_type(&mut self, smpl_type: SmplType, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) {
162 match smpl_type {
163 SmplType::SmplX => {
164 let new_model = SmplXDynamic::new_from_npz(self, path, gender, max_num_betas, num_expression_components);
165 self.add_model_from_dynamic_device(new_model, true);
166 }
167 _ => panic!("Model loading for {smpl_type:?} if not supported yet!"),
168 };
169 }
170}
171#[derive(Default, Clone)]
174pub struct SmplCache<B: Backend> {
175 type_to_model: EnumMap<SmplType, Gender2Model<B>>,
176 type_to_path: EnumMap<SmplType, Gender2Path>,
177}
178impl<B: Backend> SmplCache<B> {
179 pub fn add_model<T: SmplModel<B>>(&mut self, model: T, cache_models: bool) {
180 let smpl_type = model.smpl_type();
181 let gender = model.gender();
182 if !cache_models {
183 self.type_to_model = EnumMap::default();
184 }
185 self.type_to_model[smpl_type].gender_to_model[gender] = Some(Box::new(model));
186 }
187 pub fn remove_all_models(&mut self) {
188 self.type_to_model = EnumMap::default();
189 }
190 #[allow(clippy::borrowed_box)]
191 pub fn get_model_box_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&Box<dyn SmplModel<B>>> {
192 self.type_to_model[smpl_type].gender_to_model[gender].as_ref()
193 }
194 #[allow(clippy::redundant_closure_for_method_calls)]
195 pub fn get_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn SmplModel<B>> {
196 let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
197 let model = opt.as_ref().map(|x| x.as_ref());
198 model
199 }
200 #[allow(clippy::redundant_closure_for_method_calls)]
201 pub fn get_model_mut(&mut self, smpl_type: SmplType, gender: Gender) -> Option<&mut dyn SmplModel<B>> {
202 let opt = &mut self.type_to_model[smpl_type].gender_to_model[gender];
203 let model = opt.as_mut().map(|x| x.as_mut());
204 model
205 }
206 pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
207 self.type_to_model[smpl_type].gender_to_model[gender].is_some()
208 }
209 pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
210 self.type_to_path[smpl_type].gender_to_path[gender].is_some()
211 }
212 pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
213 self.type_to_path[smpl_type].gender_to_path[gender].clone()
214 }
215 pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
216 self.type_to_path[smpl_type].gender_to_path[gender] = Some(path.to_string());
217 assert!(
218 std::path::Path::new(&path).exists(),
219 "File at path {path} does not exist. Please follow the data download instructions in the README."
220 );
221 }
222}