1use super::{
2 expression::ExpressionG,
3 pose::PoseG,
4 smpl_options::SmplOptions,
5 types::{Gender, SmplType},
6};
7use crate::{
8 common::{betas::BetasG, outputs::SmplOutputG},
9 AppBackend,
10};
11use burn::{
12 prelude::Backend,
13 tensor::{Float, Int, Tensor},
14};
15use dyn_clone::DynClone;
16use enum_map::EnumMap;
17use gloss_geometry::csr::VertexFaceCSRBurn;
18use std::any::Any;
19pub trait FaceModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
20 fn expression2offsets(&self, expression: &ExpressionG<B>) -> Tensor<B, 2, Float>;
21 fn get_face_model(&self) -> &dyn FaceModel<B>;
22}
23impl<B: Backend> Clone for Box<dyn FaceModel<B>> {
24 #[allow(unconditional_recursion)]
25 fn clone(&self) -> Box<dyn FaceModel<B>> {
26 self.clone()
27 }
28}
29pub trait SmplModel<B: Backend>: Send + Sync + 'static + Any + DynClone {
32 fn smpl_type(&self) -> SmplType;
33 fn gender(&self) -> Gender;
34 fn device(&self) -> B::Device;
35 fn forward(&self, options: &SmplOptions, betas: &BetasG<B>, pose_raw: &PoseG<B>, expression: Option<&ExpressionG<B>>) -> SmplOutputG<B>;
36 fn create_body_with_uv(&self, smpl_output: &SmplOutputG<B>) -> SmplOutputG<B>;
37 fn get_face_model(&self) -> &dyn FaceModel<B>;
38 fn betas2verts(&self, betas: &BetasG<B>) -> Tensor<B, 2, Float>;
39 fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float>;
40 fn compute_pose_correctives(&self, pose: &PoseG<B>) -> Tensor<B, 2, Float>;
41 fn compute_pose_feature(&self, pose: &PoseG<B>) -> Tensor<B, 1, Float>;
42 #[allow(clippy::type_complexity)]
43 fn apply_pose(
44 &self,
45 verts_t_pose: &Tensor<B, 2, Float>,
46 joints: &Tensor<B, 2, Float>,
47 lbs_weights: &Tensor<B, 2, Float>,
48 pose: &PoseG<B>,
49 ) -> (Tensor<B, 2, Float>, Tensor<B, 2, Float>);
50 fn faces(&self) -> &Tensor<B, 2, Int>;
51 fn faces_uv(&self) -> &Tensor<B, 2, Int>;
52 fn uv(&self) -> &Tensor<B, 2, Float>;
53 fn lbs_weights(&self) -> Tensor<B, 2, Float>;
54 fn lbs_weights_split(&self) -> Tensor<B, 2, Float>;
55 fn idx_split_2_merged(&self) -> Tensor<B, 1, Int>;
56 fn idx_split_2_merged_vec(&self) -> &Vec<usize>;
57 fn set_pose_dirs(&mut self, posedirs: Tensor<B, 2, Float>);
58 fn get_pose_dirs(&self) -> Tensor<B, 2, Float>;
59 fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>>;
60 fn vertex_face_csr(&self) -> Option<VertexFaceCSRBurn<B>>;
61 fn vertex_face_uv_csr(&self) -> Option<VertexFaceCSRBurn<B>>;
62 fn clone_dyn(&self) -> Box<dyn SmplModel<B>>;
63 fn as_any(&self) -> &dyn Any;
64}
65impl<B: Backend> Clone for Box<dyn SmplModel<B>> {
66 #[allow(unconditional_recursion)]
67 fn clone(&self) -> Box<dyn SmplModel<B>> {
68 self.clone()
69 }
70}
71#[derive(Default, Clone)]
73pub struct Gender2Model<B: Backend> {
74 gender_to_model: EnumMap<Gender, Option<Box<dyn SmplModel<B>>>>,
75}
76#[derive(Default, Clone)]
77pub struct Gender2Path {
78 gender_to_path: EnumMap<Gender, Option<String>>,
79}
80#[derive(Default, Clone)]
83pub struct SmplCacheG<B: Backend> {
84 type_to_model: EnumMap<SmplType, Gender2Model<B>>,
85 type_to_path: EnumMap<SmplType, Gender2Path>,
86}
87impl<B: Backend> SmplCacheG<B> {
88 pub fn add_model<T: SmplModel<B> + FaceModel<B>>(&mut self, model: T, cache_models: bool) {
89 let smpl_type = model.smpl_type();
90 let gender = model.gender();
91 if !cache_models {
92 self.type_to_model = EnumMap::default();
93 }
94 self.type_to_model[smpl_type].gender_to_model[gender] = Some(Box::new(model));
95 }
96 pub fn remove_all_models(&mut self) {
97 self.type_to_model = EnumMap::default();
98 }
99 #[allow(clippy::borrowed_box)]
100 pub fn get_model_box_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&Box<dyn SmplModel<B>>> {
101 self.type_to_model[smpl_type].gender_to_model[gender].as_ref()
102 }
103 #[allow(clippy::redundant_closure_for_method_calls)]
104 pub fn get_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn SmplModel<B>> {
105 let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
106 let model = opt.as_ref().map(|x| x.as_ref());
107 model
108 }
109 #[allow(clippy::redundant_closure_for_method_calls)]
110 pub fn get_face_model_ref(&self, smpl_type: SmplType, gender: Gender) -> Option<&dyn FaceModel<B>> {
111 let opt = &self.type_to_model[smpl_type].gender_to_model[gender];
112 opt.as_ref().map(|model| model.get_face_model())
113 }
114 #[allow(clippy::redundant_closure_for_method_calls)]
115 pub fn get_model_mut(&mut self, smpl_type: SmplType, gender: Gender) -> Option<&mut dyn SmplModel<B>> {
116 let opt = &mut self.type_to_model[smpl_type].gender_to_model[gender];
117 let model = opt.as_mut().map(|x| x.as_mut());
118 model
119 }
120 pub fn has_model(&self, smpl_type: SmplType, gender: Gender) -> bool {
121 self.type_to_model[smpl_type].gender_to_model[gender].is_some()
122 }
123 pub fn has_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> bool {
124 self.type_to_path[smpl_type].gender_to_path[gender].is_some()
125 }
126 pub fn get_lazy_loading(&self, smpl_type: SmplType, gender: Gender) -> Option<String> {
127 self.type_to_path[smpl_type].gender_to_path[gender].clone()
128 }
129 pub fn lazy_load_defaults(&mut self) {
130 self.set_lazy_loading(SmplType::SmplX, Gender::Neutral, "./data/smplx/SMPLX_neutral_array_f32_slim.npz");
131 self.set_lazy_loading(SmplType::SmplX, Gender::Male, "./data/smplx/SMPLX_male_array_f32_slim.npz");
132 self.set_lazy_loading(SmplType::SmplX, Gender::Female, "./data/smplx/SMPLX_female_array_f32_slim.npz");
133 }
134 pub fn set_lazy_loading(&mut self, smpl_type: SmplType, gender: Gender, path: &str) {
135 self.type_to_path[smpl_type].gender_to_path[gender] = Some(path.to_string());
136 #[cfg(not(target_arch = "wasm32"))]
137 assert!(
138 std::path::Path::new(&path).exists(),
139 "File at path {path} does not exist. Please follow the data download instructions in the README."
140 );
141 }
142 #[cfg(not(target_arch = "wasm32"))]
143 pub fn add_model_from_type(&mut self, smpl_type: SmplType, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) {
144 match smpl_type {
145 SmplType::SmplX => {
146 use crate::smpl_x::smpl_x_gpu::SmplXGPUG;
147 let new_model = SmplXGPUG::new_from_npz(path, gender, max_num_betas, num_expression_components);
148 self.add_model(new_model, true);
149 }
150 _ => panic!("Model loading for {smpl_type:?} if not supported yet!"),
151 };
152 }
153}
/// [`SmplCacheG`] specialised to the crate-wide [`AppBackend`].
pub type SmplCache = SmplCacheG<AppBackend>;