1use super::{
2 betas::Betas,
3 expression::Expression,
4 outputs::SmplOutputDynamic,
5 pose::Pose,
6 smpl_options::SmplOptions,
7 types::{Gender, SmplType},
8};
9use crate::smpl_x::smpl_x_gpu::SmplXDynamic;
10use burn::{
11 backend::{Candle, NdArray, Wgpu},
12 prelude::Backend,
13 tensor::{Float, Int, Tensor},
14};
15use dyn_clone::DynClone;
16use enum_map::EnumMap;
17use gloss_utils::tensor::BurnBackend;
18use ndarray as nd;
19use std::any::Any;
/// Expression ("face") component of a model on burn backend `B`.
///
/// The `Any` and `DynClone` supertraits allow trait objects of this type to be
/// downcast and cloned through a `Box`.
pub trait FaceModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
    /// Converts `expression` into a rank-2 float tensor.
    ///
    /// NOTE(review): presumably per-vertex offsets shaped (num_verts, 3); confirm
    /// against the implementors.
    fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float>;
}
23impl<B: Backend> Clone for Box<dyn FaceModel<B>> {
24 #[allow(unconditional_recursion)]
25 fn clone(&self) -> Box<dyn FaceModel<B>> {
26 self.clone()
27 }
28}
/// Interface shared by SMPL-family body models on burn backend `B`.
///
/// The `Any` supertrait together with `as_any` lets callers downcast to the
/// concrete model type; `DynClone` allows cloning through a `Box`.
pub trait SmplModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
    /// Model family of this instance (e.g. `SmplType::SmplX`).
    fn smpl_type(&self) -> SmplType;
    /// Gender variant of this instance.
    fn gender(&self) -> Gender;
    /// Evaluates the model for the given shape (`betas`), raw pose and optional
    /// expression, producing a backend-tagged output.
    ///
    /// NOTE(review): the exact pipeline (shaping, correctives, skinning) is defined
    /// by each implementor; confirm there.
    fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B>;
    /// Builds a new output from `smpl_output` that carries UV information.
    ///
    /// NOTE(review): presumably re-indexes the mesh onto the split (UV-seam)
    /// vertex set; confirm against the implementation.
    fn create_body_with_uv(&self, smpl_output: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B>;
    /// Returns the expression/face component of this model.
    fn get_face_model(&self) -> &dyn FaceModel<B>;
    /// Maps shape coefficients to a rank-2 vertex tensor.
    ///
    /// NOTE(review): presumably the shaped T-pose vertices, (num_verts, 3).
    fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float>;
    /// Computes joint positions (rank-2 tensor) from T-pose vertices.
    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float>;
    /// Computes pose-dependent corrective offsets for `pose` as a rank-2 tensor.
    ///
    /// NOTE(review): presumably per-vertex offsets; confirm shape.
    fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float>;
    /// Builds a flat `f32` pose feature vector on the host (ndarray).
    ///
    /// NOTE(review): presumably the input to the pose blend shapes.
    fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32>;
    /// Poses the given T-pose geometry using `joints`, `lbs_weights` and `pose`.
    ///
    /// Optional `normals` / `tangents` inputs correspond to the optional outputs.
    /// NOTE(review): the tuple is presumably
    /// `(posed_verts, posed_normals, posed_tangents, posed_joints)`; confirm.
    // The 4-tuple return type trips `clippy::type_complexity`; silenced deliberately.
    #[allow(clippy::type_complexity)]
    fn apply_pose(
        &self,
        verts_t_pose: &Tensor<B, 2, Float>,
        normals: Option<&Tensor<B, 2, Float>>,
        tangents: Option<&Tensor<B, 2, Float>>,
        joints: &Tensor<B, 2, Float>,
        lbs_weights: &Tensor<B, 2, Float>,
        pose: &Pose,
    ) -> (
        Tensor<B, 2, Float>,
        Option<Tensor<B, 2, Float>>,
        Option<Tensor<B, 2, Float>>,
        Tensor<B, 2, Float>,
    );
    /// Face index tensor (rank 2, integer) of the mesh.
    fn faces(&self) -> &Tensor<B, 2, Int>;
    /// Face index tensor (rank 2, integer) into the UV vertex set.
    fn faces_uv(&self) -> &Tensor<B, 2, Int>;
    /// UV coordinates as a rank-2 float tensor.
    fn uv(&self) -> &Tensor<B, 2, Float>;
    /// Skinning (LBS) weights as a rank-2 float tensor.
    fn lbs_weights(&self) -> Tensor<B, 2, Float>;
    /// Skinning (LBS) weights for the split vertex set.
    ///
    /// NOTE(review): presumably vertices duplicated along UV seams; confirm.
    fn lbs_weights_split(&self) -> Tensor<B, 2, Float>;
    /// Index mapping from split vertices to merged vertices, as a tensor.
    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int>;
    /// Index mapping from split vertices to merged vertices, as a host `Vec`.
    fn idx_split_2_merged_vec(&self) -> &Vec<usize>;
    /// Replaces the pose blend-shape directions.
    fn set_pose_dirs(&mut self, posedirs: Tensor<B, 2, Float>);
    /// Returns the pose blend-shape directions.
    fn get_pose_dirs(&self) -> Tensor<B, 2, Float>;
    /// Returns the expression blend-shape directions, if this model has any.
    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>>;
    /// Clones this model into a new boxed trait object.
    fn clone_dyn(&self) -> Box<dyn SmplModel<B>>;
    /// Exposes this model as `&dyn Any` so callers can downcast it.
    fn as_any(&self) -> &dyn Any;
}
69impl<B: Backend> Clone for Box<dyn SmplModel<B>> {
70 #[allow(unconditional_recursion)]
71 fn clone(&self) -> Box<dyn SmplModel<B>> {
72 self.clone()
73 }
74}
/// Per-gender slots holding loaded models of a single `SmplType`.
///
/// Every slot starts as `None` (derived `Default`) until a model is added.
#[derive(Default, Clone)]
pub struct Gender2Model<B: Backend> {
    gender_to_model: EnumMap<Gender, Option<Box<dyn SmplModel<B>>>>,
}
/// Per-gender file paths registered for lazily loading models of a single
/// `SmplType`; `None` means no path has been registered for that gender.
#[derive(Default, Clone)]
pub struct Gender2Path {
    gender_to_path: EnumMap<Gender, Option<String>>,
}
/// A [`SmplCache`] bound at runtime to one of the supported burn backends
/// (NdArray, Wgpu or Candle).
// Each variant stores a whole cache inline; the size lint is silenced deliberately.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum SmplCacheDynamic {
    NdArray(SmplCache<NdArray>),
    Wgpu(SmplCache<Wgpu>),
    Candle(SmplCache<Candle>),
}
93impl Default for SmplCacheDynamic {
94 fn default() -> Self {
95 SmplCacheDynamic::Candle(SmplCache::default())
96 }
97}
98impl SmplCacheDynamic {
99 pub fn get_backend(&self) -> BurnBackend {
101 match self {
102 SmplCacheDynamic::NdArray(_) => BurnBackend::NdArray,
103 SmplCacheDynamic::Wgpu(_) => BurnBackend::Wgpu,
104 SmplCacheDynamic::Candle(_) => BurnBackend::Candle,
105 }
106 }
107 pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
109 match self {
110 SmplCacheDynamic::NdArray(models) => models.has_model(smpl_type, gender),
111 SmplCacheDynamic::Wgpu(models) => models.has_model(smpl_type, gender),
112 SmplCacheDynamic::Candle(models) => models.has_model(smpl_type, gender),
113 }
114 }
115 pub fn remove_all_models(&mut self) {
117 match self {
118 SmplCacheDynamic::NdArray(models) => models.remove_all_models(),
119 SmplCacheDynamic::Wgpu(models) => models.remove_all_models(),
120 SmplCacheDynamic::Candle(models) => models.remove_all_models(),
121 }
122 }
123 pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
124 match self {
125 SmplCacheDynamic::NdArray(models) => models.has_lazy_loading(smpl_type, gender),
126 SmplCacheDynamic::Wgpu(models) => models.has_lazy_loading(smpl_type, gender),
127 SmplCacheDynamic::Candle(models) => models.has_lazy_loading(smpl_type, gender),
128 }
129 }
130 pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
131 match self {
132 SmplCacheDynamic::NdArray(models) => models.get_lazy_loading(smpl_type, gender),
133 SmplCacheDynamic::Wgpu(models) => models.get_lazy_loading(smpl_type, gender),
134 SmplCacheDynamic::Candle(models) => models.get_lazy_loading(smpl_type, gender),
135 }
136 }
137 pub fn lazy_load_defaults(&mut self) {
139 self.set_lazy_loading(SmplType::SmplX, Gender::Neutral, "./data/smplx/SMPLX_neutral_array_f32_slim.npz");
140 self.set_lazy_loading(SmplType::SmplX, Gender::Male, "./data/smplx/SMPLX_male_array_f32_slim.npz");
141 self.set_lazy_loading(SmplType::SmplX, Gender::Female, "./data/smplx/SMPLX_female_array_f32_slim.npz");
142 }
143 pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
145 match self {
146 SmplCacheDynamic::NdArray(models) => models.set_lazy_loading(smpl_type, gender, path),
147 SmplCacheDynamic::Wgpu(models) => models.set_lazy_loading(smpl_type, gender, path),
148 SmplCacheDynamic::Candle(models) => models.set_lazy_loading(smpl_type, gender, path),
149 }
150 }
151 pub fn add_model_from_dynamic_device(&mut self, model: SmplXDynamic, cache_models: bool) {
153 match (self, model) {
154 (SmplCacheDynamic::NdArray(models), SmplXDynamic::NdArray(model_ndarray)) => {
155 models.add_model(model_ndarray, cache_models);
156 }
157 (SmplCacheDynamic::Wgpu(models), SmplXDynamic::Wgpu(model_wgpu)) => {
158 models.add_model(model_wgpu, cache_models);
159 }
160 (SmplCacheDynamic::Candle(models), SmplXDynamic::Candle(model_candle)) => {
161 models.add_model(model_candle, cache_models);
162 }
163 _ => {
164 eprintln!("Model and backend type mismatch!");
165 }
166 }
167 }
168
169 #[cfg(not(target_arch = "wasm32"))]
170 pub fn add_model_from_type(&mut self, smpl_type: SmplType, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) {
171 match smpl_type {
172 SmplType::SmplX => {
173 let new_model = SmplXDynamic::new_from_npz(self, path, gender, max_num_betas, num_expression_components);
174 self.add_model_from_dynamic_device(new_model, true);
175 }
176 _ => panic!("Model loading for {smpl_type:?} if not supported yet!"),
177 };
178 }
179}
/// Backend-specific cache of loaded SMPL-family models, plus file paths
/// registered for lazy loading, both keyed by `SmplType` and then `Gender`.
#[derive(Default, Clone)]
pub struct SmplCache<B: Backend> {
    // Loaded models, one slot per (type, gender).
    type_to_model: EnumMap<SmplType, Gender2Model<B>>,
    // Registered lazy-loading file paths, one slot per (type, gender).
    type_to_path: EnumMap<SmplType, Gender2Path>,
}
187impl<B: Backend> SmplCache<B> {
188 pub fn add_model<T: SmplModel<B> + FaceModel<B>>(&mut self, model: T, cache_models: bool) {
189 let smpl_type = model.smpl_type();
190 let gender = model.gender();
191 if !cache_models {
192 self.type_to_model = EnumMap::default();
193 }
194 self.type_to_model[smpl_type].gender_to_model[gender] = Some(Box::new(model));
195 }
196 pub fn remove_all_models(&mut self) {
197 self.type_to_model = EnumMap::default();
198 }
199 #[allow(clippy::borrowed_box)]
200 pub fn get_model_box_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&Box<dyn SmplModel<B>>> {
201 self.type_to_model[smpl_type].gender_to_model[gender].as_ref()
202 }
203 #[allow(clippy::redundant_closure_for_method_calls)]
204 pub fn get_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn SmplModel<B>> {
205 let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
206 let model = opt.as_ref().map(|x| x.as_ref());
207 model
208 }
209 #[allow(clippy::redundant_closure_for_method_calls)]
210 pub fn get_face_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn FaceModel<B>> {
211 let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
212 opt.as_ref().map(|model| model.get_face_model())
213 }
214 #[allow(clippy::redundant_closure_for_method_calls)]
215 pub fn get_model_mut(&mut self, smpl_type: SmplType, gender: Gender) -> Option<&mut dyn SmplModel<B>> {
216 let opt = &mut self.type_to_model[smpl_type].gender_to_model[gender];
217 let model = opt.as_mut().map(|x| x.as_mut());
218 model
219 }
220 pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
221 self.type_to_model[smpl_type].gender_to_model[gender].is_some()
222 }
223 pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
224 self.type_to_path[smpl_type].gender_to_path[gender].is_some()
225 }
226 pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
227 self.type_to_path[smpl_type].gender_to_path[gender].clone()
228 }
229 pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
230 self.type_to_path[smpl_type].gender_to_path[gender] = Some(path.to_string());
231 assert!(
232 std::path::Path::new(&path).exists(),
233 "File at path {path} does not exist. Please follow the data download instructions in the README."
234 );
235 }
236}