1use crate::{
2 common::{
3 betas::Betas,
4 expression::Expression,
5 outputs::SmplOutputDynamic,
6 pose::Pose,
7 smpl_model::{SmplCacheDynamic, SmplModel},
8 smpl_options::SmplOptions,
9 types::{Gender, SmplType, UpAxis},
10 },
11 conversions::pose_remap::PoseRemap,
12};
13use burn::tensor::{backend::Backend, Float, Int, Tensor};
14use gloss_utils::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn};
15use gloss_utils::nshare::ToNalgebra;
16use log::{info, warn};
17use nalgebra as na;
18use ndarray as nd;
19use ndarray::prelude::*;
20use ndarray_npy::NpzReader;
21use smpl_utils::numerical::{batch_rigid_transform, batch_rodrigues};
22use smpl_utils::{array::Gather2D, io::FileLoader};
23use std::ops::Sub;
24use std::{
25 any::Any,
26 io::{Read, Seek},
27};
/// Number of body joints counted in [`NUM_JOINTS`].
pub const NUM_BODY_JOINTS: usize = 21;
/// Number of joints per hand; counted twice (once per hand) in [`NUM_JOINTS`].
pub const NUM_HAND_JOINTS: usize = 15;
/// Number of face joints counted in [`NUM_JOINTS`].
/// NOTE(review): presumably jaw and both eyes - confirm against the model definition.
pub const NUM_FACE_JOINTS: usize = 3;
/// Total joint count: body + both hands + face (21 + 2 * 15 + 3 = 54).
///
/// Poses handled by this model are asserted to have `NUM_JOINTS + 1` rows; the
/// first row is skipped when building the pose feature (presumably it is the
/// root / global orientation - confirm).
pub const NUM_JOINTS: usize = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS;
/// Index of the neck joint.
/// NOTE(review): not referenced in this part of the file; confirm which joint ordering it indexes.
pub const NECK_IDX: usize = 12;
/// Number of vertices of the model mesh; per-vertex tensors are reshaped to `[NUM_VERTS, 3]`.
pub const NUM_VERTS: usize = 10475;
/// Presumably the vertex count of the UV-split mesh - not referenced in this part of the file.
pub const NUM_VERTS_UV_MESH: usize = 11307;
/// Presumably the number of triangles of the mesh - not referenced in this part of the file.
pub const NUM_FACES: usize = 20908;
/// Size of the combined shape + expression space (equals `SHAPE_SPACE_DIM + EXPRESSION_SPACE_DIM`).
pub const FULL_SHAPE_SPACE_DIM: usize = 400;
/// Number of shape (beta) components; the npz loader uses the same value (300) as split point.
pub const SHAPE_SPACE_DIM: usize = 300;
/// Number of expression components; the npz loader caps expression components at the same value (100).
pub const EXPRESSION_SPACE_DIM: usize = 100;
/// Length of the pose feature vector: 9 rotation-matrix entries for each of the `NUM_JOINTS` joints.
pub const NUM_POSE_BLEND_SHAPES: usize = NUM_JOINTS * 9;
40use burn::backend::{Candle, NdArray, Wgpu};
/// SMPL-X model whose burn backend is selected at runtime.
///
/// Each variant wraps a [`SmplXGPU`] instantiated for one concrete backend
/// (`NdArray`, `Wgpu` or `Candle`).
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum SmplXDynamic {
    /// Model stored on the `NdArray` backend.
    NdArray(SmplXGPU<NdArray>),
    /// Model stored on the `Wgpu` backend.
    Wgpu(SmplXGPU<Wgpu>),
    /// Model stored on the `Candle` backend.
    Candle(SmplXGPU<Candle>),
}
48#[allow(clippy::return_self_not_must_use)]
49impl SmplXDynamic {
50 #[cfg(not(target_arch = "wasm32"))]
51 pub fn new_from_npz(models: &SmplCacheDynamic, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) -> Self {
52 match models {
53 SmplCacheDynamic::Wgpu(_) => {
54 info!("Initializing with Wgpu Backend");
55 let model = SmplXGPU::<Wgpu>::new_from_npz(path, gender, max_num_betas, num_expression_components);
56 SmplXDynamic::Wgpu(model)
57 }
58 SmplCacheDynamic::NdArray(_) => {
59 info!("Initializing with NdArray Backend");
60 let model = SmplXGPU::<NdArray>::new_from_npz(path, gender, max_num_betas, num_expression_components);
61 SmplXDynamic::NdArray(model)
62 }
63 SmplCacheDynamic::Candle(_) => {
64 info!("Initializing with Candle Backend");
65 let model = SmplXGPU::<Candle>::new_from_npz(path, gender, max_num_betas, num_expression_components);
66 SmplXDynamic::Candle(model)
67 }
68 }
69 }
70 pub fn new_from_reader<R: Read + Seek>(
71 models: &SmplCacheDynamic,
72 reader: R,
73 gender: Gender,
74 max_num_betas: usize,
75 max_num_expression_components: usize,
76 ) -> Self {
77 match models {
78 SmplCacheDynamic::Wgpu(_) => {
79 info!("Initializing from reader with Wgpu Backend");
80 let model = SmplXGPU::<Wgpu>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
81 SmplXDynamic::Wgpu(model)
82 }
83 SmplCacheDynamic::NdArray(_) => {
84 info!("Initializing from reader with NdArray Backend");
85 let model = SmplXGPU::<NdArray>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
86 SmplXDynamic::NdArray(model)
87 }
88 SmplCacheDynamic::Candle(_) => {
89 info!("Initializing from reader with Candle Backend");
90 let model = SmplXGPU::<Candle>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
91 SmplXDynamic::Candle(model)
92 }
93 }
94 }
95 pub async fn new_from_npz_async(
96 models: &SmplCacheDynamic,
97 path: &str,
98 gender: Gender,
99 max_num_betas: usize,
100 num_expression_components: usize,
101 ) -> Self {
102 match models {
103 SmplCacheDynamic::Wgpu(_) => {
104 info!("Initializing with Wgpu Backend");
105 let model = SmplXGPU::<Wgpu>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
106 SmplXDynamic::Wgpu(model)
107 }
108 SmplCacheDynamic::NdArray(_) => {
109 info!("Initializing with NdArray Backend");
110 let model = SmplXGPU::<NdArray>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
111 SmplXDynamic::NdArray(model)
112 }
113 SmplCacheDynamic::Candle(_) => {
114 info!("Initializing with Candle Backend");
115 let model = SmplXGPU::<Candle>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
116 SmplXDynamic::Candle(model)
117 }
118 }
119 }
120}
/// SMPL-X model data stored as tensors on burn backend `B`, together with a
/// few host-side (`nalgebra` / `ndarray`) copies of the same data.
#[derive(Clone)]
pub struct SmplXGPU<B: Backend> {
    /// Device on which the tensors below were created.
    pub device: B::Device,
    /// Model family; set to `SmplType::SmplX` by the constructor.
    pub smpl_type: SmplType,
    /// Gender passed to the constructor.
    pub gender: Gender,
    /// Template vertices; shape offsets reshaped to `[NUM_VERTS, 3]` are added to it.
    pub verts_template: Tensor<B, 2, Float>,
    /// Faces of the mesh without UV seams.
    pub faces: Tensor<B, 2, Int>,
    /// Faces of the UV-split mesh (indices into the UV-split vertex list).
    pub faces_uv_mesh: Tensor<B, 2, Int>,
    /// UV coordinates as loaded from the model file.
    pub uv: Tensor<B, 2, Float>,
    /// Shape basis flattened to `[NUM_VERTS * 3, num_betas]`.
    pub shape_dirs: Tensor<B, 2, Float>,
    /// Expression basis flattened to `[NUM_VERTS * 3, num_expression_components]`, if available.
    pub expression_dirs: Option<Tensor<B, 2, Float>>,
    /// Pose-corrective basis flattened to `[NUM_VERTS * 3, NUM_JOINTS * 9]`, if available.
    pub pose_dirs: Option<Tensor<B, 2, Float>>,
    /// Matrix that maps vertices to joint positions via `joint_regressor.matmul(verts)`.
    pub joint_regressor: Tensor<B, 2, Float>,
    /// Parent index for each of the `NUM_JOINTS + 1` joints.
    pub parent_idx_per_joint: Tensor<B, 1, Int>,
    /// Linear-blend-skinning weights; one row per vertex of the mesh without UV seams.
    pub lbs_weights: Tensor<B, 2, Float>,
    /// Tensor of ones with shape `[NUM_VERTS, 1]`.
    pub verts_ones: Tensor<B, 2, Float>,
    /// For every UV-split vertex, the index of the corresponding vertex in the mesh without UV seams.
    pub idx_vuv_2_vnouv: Tensor<B, 1, Int>,
    /// Host-side `nalgebra` copy of `faces`.
    pub faces_na: na::DMatrix<u32>,
    /// Host-side `nalgebra` copy of `faces_uv_mesh`.
    pub faces_uv_mesh_na: na::DMatrix<u32>,
    /// Host-side `nalgebra` copy of `uv`.
    pub uv_na: na::DMatrix<f32>,
    /// Host-side copy of `idx_vuv_2_vnouv` as `usize` indices.
    pub idx_vuv_2_vnouv_vec: Vec<usize>,
    /// Rows of `lbs_weights` gathered through `idx_vuv_2_vnouv`, i.e. weights per UV-split vertex.
    pub lbs_weights_split: Tensor<B, 2>,
    /// Host-side `ndarray` copy of `lbs_weights`.
    pub lbs_weights_nd: nd::ArcArray2<f32>,
    /// Host-side `ndarray` copy of `lbs_weights_split`.
    pub lbs_weights_split_nd: nd::ArcArray2<f32>,
}
146impl<B: Backend> SmplXGPU<B> {
147 #[allow(clippy::too_many_arguments)]
150 #[allow(clippy::too_many_lines)]
151 pub fn new_from_matrices(
152 gender: Gender,
153 verts_template: &nd::Array2<f32>,
154 faces: &nd::Array2<u32>,
155 faces_uv_mesh: &nd::Array2<u32>,
156 uv: &nd::Array2<f32>,
157 shape_dirs: &nd::Array3<f32>,
158 expression_dirs: Option<nd::Array3<f32>>,
159 pose_dirs: Option<nd::Array3<f32>>,
160 joint_regressor: &nd::Array2<f32>,
161 parent_idx_per_joint: &nd::Array1<u32>,
162 lbs_weights: nd::Array2<f32>,
163 max_num_betas: usize,
164 max_num_expression_components: usize,
165 ) -> Self {
166 let device = B::Device::default();
167 let b_verts_template = verts_template.to_burn(&device);
168 let b_faces = faces.to_burn(&device);
169 let b_faces_uv_mesh = faces_uv_mesh.to_burn(&device);
170 let b_uv = uv.to_burn(&device);
171 let shape_dirs = shape_dirs
172 .slice_axis(Axis(2), ndarray::Slice::from(0..max_num_betas))
173 .to_owned()
174 .into_shape_with_order((NUM_VERTS * 3, max_num_betas))
175 .unwrap();
176 let b_shape_dirs = shape_dirs.to_burn(&device);
177 let b_expression_dirs = expression_dirs.map(|expression_dirs| {
178 let expression_dirs = expression_dirs
179 .slice_axis(nd::Axis(2), nd::Slice::from(0..max_num_expression_components))
180 .into_shape_with_order((NUM_VERTS * 3, max_num_expression_components))
181 .unwrap()
182 .to_owned();
183 expression_dirs.to_burn(&device)
184 });
185 let b_pose_dirs = pose_dirs.map(|pose_dirs| {
186 let pose_dirs = pose_dirs.into_shape_with_order((NUM_VERTS * 3, NUM_JOINTS * 9)).unwrap();
187 pose_dirs.to_burn(&device)
188 });
189 let b_joint_regressor = joint_regressor.to_burn(&device);
190 let b_parent_idx_per_joint = parent_idx_per_joint.to_burn(&device).reshape([NUM_JOINTS + 1]);
191 let b_lbs_weights = lbs_weights.to_burn(&device);
192 #[allow(clippy::cast_possible_wrap)]
193 let faces_uv_mesh_i32: nd::Array2<i32> = faces_uv_mesh.mapv(|x| x as i32);
194 let ft: nd::ArcArray2<i32> = faces_uv_mesh_i32.into();
195 let max_v_uv_idx = *ft.iter().max_by_key(|&x| x).unwrap();
196 let max_v_uv_idx_usize = usize::try_from(max_v_uv_idx).unwrap_or_else(|_| panic!("Cannot cast max_v_uv_idx to usize"));
197 let mut idx_vuv_2_vnouv = nd::ArcArray1::<i32>::zeros(max_v_uv_idx_usize + 1);
198 for (fuv, fnouv) in ft.axis_iter(nd::Axis(0)).zip(faces.axis_iter(nd::Axis(0))) {
199 let uv_0 = fuv[[0]];
200 let uv_1 = fuv[[1]];
201 let uv_2 = fuv[[2]];
202 let nouv_0 = fnouv[[0]];
203 let nouv_1 = fnouv[[1]];
204 let nouv_2 = fnouv[[2]];
205 idx_vuv_2_vnouv[usize::try_from(uv_0).unwrap_or_else(|_| panic!("Cannot cast uv_0 to usize"))] =
206 i32::try_from(nouv_0).unwrap_or_else(|_| panic!("Cannot cast nouv_0 to i32"));
207 idx_vuv_2_vnouv[usize::try_from(uv_1).unwrap_or_else(|_| panic!("Cannot cast uv_1 to usize"))] =
208 i32::try_from(nouv_1).unwrap_or_else(|_| panic!("Cannot cast nouv_1 to i32"));
209 idx_vuv_2_vnouv[usize::try_from(uv_2).unwrap_or_else(|_| panic!("Cannot cast uv_2 to usize"))] =
210 i32::try_from(nouv_2).unwrap_or_else(|_| panic!("Cannot cast nouv_2 to i32"));
211 }
212 let idx_vuv_2_vnouv_vec: Vec<i32> = idx_vuv_2_vnouv.mapv(|x| x).into_raw_vec_and_offset().0;
213 let idx_vuv_2_vnouv_slice: &[i32] = &idx_vuv_2_vnouv_vec;
214 let b_idx_vuv_2_vnouv = Tensor::<B, 1, Int>::from_ints(idx_vuv_2_vnouv_slice, &device);
215 let idx_vuv_2_vnouv_vec: Vec<usize> = idx_vuv_2_vnouv
216 .to_vec()
217 .iter()
218 .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot cast negative value to usize")))
219 .collect();
220 let faces_na = faces.view().into_nalgebra().clone_owned().map(|x| x);
221 let faces_uv_mesh_na = ft
222 .view()
223 .into_nalgebra()
224 .clone_owned()
225 .map(|x| u32::try_from(x).unwrap_or_else(|_| panic!("Cannot cast value to u32")));
226 let uv_na = uv.view().into_nalgebra().clone_owned();
227 let cols: Vec<usize> = (0..lbs_weights.ncols()).collect();
228 let lbs_weights_split: nd::ArcArray2<f32> = lbs_weights.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
229 let b_lbs_weights_split =
230 Tensor::<B, 1>::from_floats(lbs_weights_split.as_slice().unwrap(), &device).reshape([idx_vuv_2_vnouv_vec.len(), NUM_JOINTS + 1]);
231 let verts_ones = Tensor::<B, 2>::ones([NUM_VERTS, 1], &device);
232 let lbs_weights_nd: nd::ArcArray2<f32> = lbs_weights.into();
233 let cols: Vec<usize> = (0..lbs_weights_nd.ncols()).collect();
234 let lbs_weights_split_nd = lbs_weights_nd.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
235 info!("Initialised burn on Backend: {:?}", B::name());
236 info!("Device: {:?}", &device);
237 Self {
238 smpl_type: SmplType::SmplX,
239 gender,
240 device,
241 verts_template: b_verts_template,
242 faces: b_faces,
243 faces_uv_mesh: b_faces_uv_mesh,
244 uv: b_uv,
245 shape_dirs: b_shape_dirs,
246 expression_dirs: b_expression_dirs,
247 pose_dirs: b_pose_dirs,
248 joint_regressor: b_joint_regressor,
249 parent_idx_per_joint: b_parent_idx_per_joint,
250 lbs_weights: b_lbs_weights,
251 verts_ones,
252 idx_vuv_2_vnouv: b_idx_vuv_2_vnouv,
253 faces_na,
254 faces_uv_mesh_na,
255 uv_na,
256 idx_vuv_2_vnouv_vec,
257 lbs_weights_split: b_lbs_weights_split,
258 lbs_weights_nd,
259 lbs_weights_split_nd,
260 }
261 }
262 fn new_from_npz_reader<R: Read + Seek>(
265 npz: &mut NpzReader<R>,
266 gender: Gender,
267 max_num_betas: usize,
268 max_num_expression_components: usize,
269 ) -> Self {
270 let verts_template: nd::Array2<f32> = npz.by_name("v_template").unwrap();
271 let faces: nd::Array2<u32> = npz.by_name("f").unwrap();
272 let uv: nd::Array2<f32> = npz.by_name("vt").unwrap();
273 let full_shape_dirs: nd::Array3<f32> = npz.by_name("shapedirs").unwrap();
274 let (shape_dirs, expression_dirs) = if let Ok(expression_dirs) = npz.by_name("expressiondirs") {
275 (full_shape_dirs, Some(expression_dirs))
276 } else {
277 let shape_dirs = full_shape_dirs
278 .slice_axis(nd::Axis(2), nd::Slice::from(0..max_num_betas.min(300)))
279 .to_owned();
280 let expression_dirs = if full_shape_dirs.shape()[2] > 300 {
281 Some(
282 full_shape_dirs
283 .slice_axis(nd::Axis(2), nd::Slice::from(300..300 + max_num_expression_components.min(100)))
284 .to_owned(),
285 )
286 } else {
287 None
288 };
289 (shape_dirs, expression_dirs)
290 };
291 let pose_dirs: Option<nd::Array3<f32>> = npz.by_name("posedirs").ok();
292 let joint_regressor: nd::Array2<f32> = npz.by_name("J_regressor").unwrap();
293 let parent_idx_per_joint: nd::Array2<i32> = npz.by_name("kintree_table").unwrap();
294 #[allow(clippy::cast_sign_loss)]
295 let parent_idx_per_joint = parent_idx_per_joint.mapv(|x| x as u32);
296 let parent_idx_per_joint = parent_idx_per_joint
297 .slice_axis(nd::Axis(0), nd::Slice::from(0..1))
298 .to_owned()
299 .into_shape_with_order(NUM_JOINTS + 1)
300 .unwrap();
301 let lbs_weights: nd::Array2<f32> = npz.by_name("weights").unwrap();
302 let ft: nd::Array2<u32> = npz.by_name("ft").unwrap();
303 if pose_dirs.is_none() {
304 warn!("No pose_dirs loaded from npz");
305 }
306 Self::new_from_matrices(
307 gender,
308 &verts_template,
309 &faces,
310 &ft,
311 &uv,
312 &shape_dirs,
313 expression_dirs,
314 pose_dirs,
315 &joint_regressor,
316 &parent_idx_per_joint,
317 lbs_weights,
318 max_num_betas,
319 max_num_expression_components,
320 )
321 }
322 #[cfg(not(target_arch = "wasm32"))]
323 pub fn new_from_npz(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
326 let mut npz = NpzReader::new(std::fs::File::open(model_path).unwrap()).unwrap();
327 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
328 }
329 #[allow(clippy::cast_possible_truncation)]
334 pub async fn new_from_npz_async(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
335 let reader = FileLoader::open(model_path).await;
336 let mut npz = NpzReader::new(reader).unwrap();
337 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
338 }
339 #[allow(clippy::cast_possible_truncation)]
344 pub fn new_from_reader<R: Read + Seek>(reader: R, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
345 let mut npz = NpzReader::new(reader).unwrap();
346 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
347 }
348 #[allow(clippy::cast_possible_truncation)]
353 pub fn read_pose_dirs_from_reader<R: Read + Seek>(reader: R, device: &B::Device) -> Tensor<B, 2, Float> {
354 let mut npz = NpzReader::new(reader).unwrap();
355 let pose_dirs: Option<nd::Array3<f32>> = Some(npz.by_name("pose_dirs").unwrap());
356 let b_pose_dirs =
357 pose_dirs.map(|pose_dirs| Tensor::<B, 1>::from_floats(pose_dirs.as_slice().unwrap(), device).reshape([NUM_VERTS * 3, NUM_JOINTS * 9]));
358 b_pose_dirs.unwrap()
359 }
360}
361impl<B: Backend> SmplModel<B> for SmplXGPU<B>
362where
363 B::FloatTensorPrimitive<2>: Sync,
364 B::IntTensorPrimitive<2>: Sync,
365 B::IntTensorPrimitive<1>: Sync,
366 B::QuantizedTensorPrimitive<1>: std::marker::Sync,
367 B::QuantizedTensorPrimitive<2>: std::marker::Sync,
368{
    /// Returns a boxed clone of this model as a `SmplModel` trait object.
    fn clone_dyn(&self) -> Box<dyn SmplModel<B>> {
        Box::new(self.clone())
    }
    /// Exposes the model as `&dyn Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Returns the model family stored in this instance.
    fn smpl_type(&self) -> SmplType {
        self.smpl_type
    }
    /// Returns the gender this model was created with.
    fn gender(&self) -> Gender {
        self.gender
    }
381 #[allow(clippy::missing_panics_doc)]
382 #[allow(non_snake_case)]
383 fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B> {
384 let mut verts_t_pose = self.betas2verts(betas);
385 if let Some(expression) = expression {
386 verts_t_pose = verts_t_pose + self.expression2offsets(expression);
387 }
388 let pose_remap = PoseRemap::new(pose_raw.smpl_type, SmplType::SmplX);
389 let pose = pose_remap.remap(pose_raw);
390 let joints_t_pose = self.verts2joints(verts_t_pose.clone());
391 if options.enable_pose_corrective {
392 let verts_offset = self.compute_pose_correctives(&pose);
393 verts_t_pose = verts_t_pose + verts_offset;
394 }
395 let (verts_posed_nd, _, _, joints_posed) = self.apply_pose(&verts_t_pose, None, None, &joints_t_pose, &self.lbs_weights, &pose);
396 SmplOutputDynamic {
397 verts: verts_posed_nd,
398 faces: self.faces.clone(),
399 normals: None,
400 uvs: None,
401 joints: joints_posed,
402 }
403 }
404 fn create_body_with_uv(&self, smpl_merged: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B> {
405 let cols_tensor = Tensor::<B, 1, Int>::from_ints([0, 1, 2], &self.device);
406 let mapping_tensor = self.idx_split_2_merged();
407 let v_burn_split = smpl_merged.verts.clone().select(0, mapping_tensor.clone());
408 let v_burn_split = v_burn_split.select(1, cols_tensor.clone());
409 let n_burn_split = smpl_merged
410 .normals
411 .as_ref()
412 .map(|n| n.clone().select(0, mapping_tensor).select(1, cols_tensor));
413 SmplOutputDynamic {
414 verts: v_burn_split,
415 faces: self.faces_uv_mesh.clone(),
416 normals: n_burn_split,
417 uvs: Some(self.uv.clone()),
418 joints: smpl_merged.joints.clone(),
419 }
420 }
421 #[allow(clippy::missing_panics_doc)]
422 #[allow(non_snake_case)]
423 #[allow(clippy::let_and_return)]
424 fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float> {
425 let device = self.verts_template.device();
426 let betas_slice = betas.betas.as_slice().unwrap();
427 let betas_tensor = Tensor::<B, 1, Float>::from_floats(betas_slice, &device);
428 let input_nr_betas = betas_tensor.shape().dims[0];
429 let shape_dirs_sliced = self.shape_dirs.clone().slice([0..self.shape_dirs.dims()[0], 0..input_nr_betas]);
430 let v_beta_offsets = shape_dirs_sliced.matmul(betas_tensor.reshape([input_nr_betas, 1]));
431 let v_beta_offsets_reshaped = v_beta_offsets.reshape([NUM_VERTS, 3]);
432 let verts_t_pose = v_beta_offsets_reshaped.add(self.verts_template.clone());
433 verts_t_pose
434 }
435 #[allow(clippy::missing_panics_doc)]
436 #[allow(non_snake_case)]
437 #[allow(clippy::let_and_return)]
438 fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float> {
439 let device = self.verts_template.device();
440 let offsets = if let Some(ref expression_dirs) = self.expression_dirs {
441 let input_nr_expression_coeffs = expression.expr_coeffs.len();
442 let expression_dirs_sliced = expression_dirs
443 .clone()
444 .slice([0..expression_dirs.dims()[0], 0..input_nr_expression_coeffs]);
445 let expr_coeffs_tensor = Tensor::<B, 1, Float>::from_floats(expression.expr_coeffs.as_slice().unwrap(), &device);
446 let v_expr_offsets = expression_dirs_sliced.matmul(expr_coeffs_tensor.reshape([input_nr_expression_coeffs, 1]));
447 v_expr_offsets.reshape([NUM_VERTS, 3])
448 } else {
449 Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &device)
450 };
451 offsets
452 }
453 fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
454 self.joint_regressor.clone().matmul(verts_t_pose)
455 }
456 #[allow(clippy::missing_panics_doc)]
457 fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float> {
458 let offsets = if let Some(pose_dirs) = &self.pose_dirs {
459 let full_pose = &pose.joint_poses;
460 assert!(
461 full_pose.dim().0 == NUM_JOINTS + 1,
462 "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
463 full_pose.dim().0,
464 NUM_JOINTS + 1
465 );
466 let mut rot_mats = batch_rodrigues(full_pose);
467 let identity = ndarray::Array2::<f32>::eye(3);
468 let pose_feature = (rot_mats.slice_mut(s![1.., .., ..]).sub(&identity))
469 .into_shape_with_order(NUM_JOINTS * 9)
470 .unwrap();
471 let b_pose_feature = Tensor::<B, 1, Float>::from_floats(pose_feature.as_slice().unwrap(), &self.device).reshape([NUM_JOINTS * 9, 1]);
472 let new_pose_dirs = pose_dirs.clone();
473 let all_pose_offsets = new_pose_dirs.matmul(b_pose_feature);
474 all_pose_offsets.reshape([NUM_VERTS, 3])
475 } else {
476 Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &self.device)
477 };
478 offsets
479 }
480 #[allow(clippy::missing_panics_doc)]
481 fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32> {
482 let full_pose = &pose.joint_poses;
483 assert!(
484 full_pose.dim().0 == NUM_JOINTS + 1,
485 "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
486 full_pose.dim().0,
487 NUM_JOINTS + 1
488 );
489 let mut rot_mats = batch_rodrigues(full_pose);
490 let identity = ndarray::Array2::<f32>::eye(3);
491 let pose_feature = (rot_mats.slice_mut(s![1.., .., ..]).sub(&identity))
492 .into_shape_with_order(NUM_JOINTS * 9)
493 .unwrap();
494 pose_feature
495 }
    /// Poses T-pose vertices (and optionally normals and tangents) with linear blend skinning.
    ///
    /// Steps performed:
    /// 1. Joint poses are converted to rotation matrices with `batch_rodrigues`.
    /// 2. Joints are copied to the host and run through `batch_rigid_transform`
    ///    together with the parent indices, giving posed joints and one relative 4x4
    ///    transform per joint.
    /// 3. Per-vertex transforms are the `lbs_weights`-weighted sums of those transforms.
    /// 4. Vertices get rotation + translation; normals get the rotation only; tangents
    ///    get the rotation on their first three columns while the fourth column
    ///    (named `handedness` here) is passed through.
    /// 5. `pose.global_trans` is added to vertices and joints.
    /// 6. If `pose.up_axis` is `UpAxis::Z`, every output is remapped `(x, y, z) -> (x, z, -y)`.
    ///
    /// Returns `(verts, normals, tangents, joints)`; normals/tangents are `None` when not supplied.
    ///
    /// # Panics
    ///
    /// Panics if `verts_t_pose` and `lbs_weights` differ in row count or if the pose
    /// does not have `NUM_JOINTS + 1` joints.
    #[allow(clippy::missing_panics_doc)]
    #[allow(non_snake_case)]
    #[allow(clippy::cast_precision_loss)]
    #[allow(clippy::cast_sign_loss)]
    #[allow(clippy::too_many_lines)]
    #[allow(clippy::similar_names)]
    fn apply_pose(
        &self,
        verts_t_pose: &Tensor<B, 2, Float>,
        normals: Option<&Tensor<B, 2, Float>>,
        tangents: Option<&Tensor<B, 2, Float>>,
        joints: &Tensor<B, 2, Float>,
        lbs_weights: &Tensor<B, 2, Float>,
        pose: &Pose,
    ) -> (
        Tensor<B, 2, Float>,
        Option<Tensor<B, 2, Float>>,
        Option<Tensor<B, 2, Float>>,
        Tensor<B, 2, Float>,
    ) {
        assert!(
            verts_t_pose.shape().dims[0] == lbs_weights.shape().dims[0],
            "Verts and LBS weights should match"
        );
        let full_pose = &pose.joint_poses;
        assert!(
            full_pose.shape()[0] == NUM_JOINTS + 1,
            "The pose does not have the correct number of joints for this model."
        );
        let rot_mats = batch_rodrigues(full_pose);
        // The rigid transform runs on host-side ndarray data, so joints and parent
        // indices are read back from their tensors first.
        let joints_data = tensor_to_data_float(joints);
        let shape = joints.shape().dims;
        let nd_joints = nd::Array2::from_shape_vec((shape[0], shape[1]), joints_data).expect("Shape mismatch during tensor to ndarray conversion");
        let parent_idx_data_i32: Vec<i32> = tensor_to_data_int(&self.parent_idx_per_joint);
        let parent_idx_data_u32: Vec<u32> = parent_idx_data_i32.into_iter().map(|x| x as u32).collect();
        let (posed_joints_nd, rel_transforms_nd) = batch_rigid_transform(parent_idx_data_u32, &rot_mats, &nd_joints, NUM_JOINTS);
        let posed_joints = posed_joints_nd.to_burn(&self.device);
        let nr_verts = verts_t_pose.shape().dims[0];
        let nr_joints = posed_joints.shape().dims[0];
        let v_posed = verts_t_pose.clone();
        let W = lbs_weights;
        // Flatten each joint's 4x4 transform to 16 values so that blending is one matmul.
        let A_nd = rel_transforms_nd.into_shape_with_order((NUM_JOINTS + 1, 16)).unwrap();
        let A = A_nd.to_burn(&self.device);
        let T = W.clone().matmul(A).reshape([nr_verts, 4, 4]);
        let dims_3 = 3;
        // Rows 0..3 of the upper-left 3x3 block, and the last column as translation.
        let rot0 = T.clone().slice([0..nr_verts, 0..1, 0..dims_3]).squeeze(1);
        let rot1 = T.clone().slice([0..nr_verts, 1..2, 0..dims_3]).squeeze(1);
        let rot2 = T.clone().slice([0..nr_verts, 2..3, 0..dims_3]).squeeze(1);
        let trans: Tensor<B, 2> = T.slice([0..nr_verts, 0..dims_3, 3..4]).squeeze(2);
        // Per-vertex matrix-vector product written as three row-wise dot products.
        let verts_final_0 = rot0.clone().mul(v_posed.clone()).sum_dim(1);
        let verts_final_1 = rot1.clone().mul(v_posed.clone()).sum_dim(1);
        let verts_final_2 = rot2.clone().mul(v_posed.clone()).sum_dim(1);
        let verts_final = Tensor::<B, 1>::stack(vec![verts_final_0.squeeze(1), verts_final_1.squeeze(1), verts_final_2.squeeze(1)], 1);
        let verts_final = verts_final.add(trans);
        // Normals are directions: rotate only, no translation.
        let mut normals_final = if let Some(normals) = normals {
            let normals_0 = rot0.clone().mul(normals.clone()).sum_dim(1);
            let normals_1 = rot1.clone().mul(normals.clone()).sum_dim(1);
            let normals_2 = rot2.clone().mul(normals.clone()).sum_dim(1);
            let normals_final = Tensor::<B, 1>::stack(vec![normals_0.squeeze(1), normals_1.squeeze(1), normals_2.squeeze(1)], 1);
            Some(normals_final)
        } else {
            None
        };
        // Tangents: rotate the first three columns, keep the fourth column as is.
        let mut tangents_final = if let Some(tangents) = tangents {
            let tangents_3 = tangents.clone().slice([0..nr_verts, 0..3]);
            let tangents_0 = rot0.mul(tangents_3.clone()).sum_dim(1);
            let tangents_1 = rot1.mul(tangents_3.clone()).sum_dim(1);
            let tangents_2 = rot2.mul(tangents_3.clone()).sum_dim(1);
            let handedness: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 3..4]).squeeze(1);
            let tangents_final = Tensor::<B, 1>::stack(vec![tangents_0.squeeze(1), tangents_1.squeeze(1), tangents_2.squeeze(1), handedness], 1);
            Some(tangents_final)
        } else {
            None
        };
        // Apply the global translation of the pose to vertices and joints.
        let trans_pose_nd = pose.global_trans.clone();
        let trans_pose = trans_pose_nd.to_burn(&self.device);
        let trans_pose_broadcasted_v = trans_pose.clone().reshape([1, 3]).expand(verts_final.shape());
        let trans_pose_broadcasted_p = trans_pose.reshape([1, 3]).expand(posed_joints.shape());
        let mut verts_final_modified = verts_final.clone().add(trans_pose_broadcasted_v.clone());
        let mut posed_joints_modified = posed_joints.clone().add(trans_pose_broadcasted_p.clone());
        // For Z-up poses remap every output as (x, y, z) -> (x, z, -y).
        if pose.up_axis == UpAxis::Z {
            let vcol0: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 0..1]).squeeze(1);
            let vcol1: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 1..2]).squeeze(1);
            let vcol2: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 2..3]).squeeze(1);
            let verts_new_col1 = vcol2;
            let verts_new_col2 = vcol1.mul_scalar(-1.0);
            verts_final_modified = Tensor::stack::<2>(vec![vcol0, verts_new_col1, verts_new_col2], 1);
            let jcol0: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 0..1]).squeeze(1);
            let jcol1: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 1..2]).squeeze(1);
            let jcol2: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 2..3]).squeeze(1);
            let joints_new_col1 = jcol2;
            let joints_new_col2 = jcol1.mul_scalar(-1.0);
            posed_joints_modified = Tensor::stack::<2>(vec![jcol0, joints_new_col1, joints_new_col2], 1);
            if let Some(ref mut normals) = normals_final {
                let ncol0: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 0..1]).squeeze(1);
                let ncol1: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 1..2]).squeeze(1);
                let ncol2: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 2..3]).squeeze(1);
                let normals_new_col1 = ncol2;
                let normals_new_col2 = ncol1.mul_scalar(-1.0);
                let normals_final_modified = Tensor::stack::<2>(vec![ncol0, normals_new_col1, normals_new_col2], 1);
                *normals = normals_final_modified;
            }
            if let Some(ref mut tangents) = tangents_final {
                let tcol0: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 0..1]).squeeze(1);
                let tcol1: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 1..2]).squeeze(1);
                let tcol2: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 2..3]).squeeze(1);
                let tangents_new_col1 = tcol2;
                let tangents_new_col2 = tcol1.mul_scalar(-1.0);
                let handedness = tangents.clone().slice([0..nr_verts, 3..4]).squeeze(1);
                let tangents_final_modified = Tensor::stack::<2>(vec![tcol0, tangents_new_col1, tangents_new_col2, handedness], 1);
                *tangents = tangents_final_modified;
            }
        }
        (verts_final_modified, normals_final.clone(), tangents_final.clone(), posed_joints_modified)
    }
    /// Borrows the face tensor of the mesh without UV seams.
    fn faces(&self) -> &Tensor<B, 2, Int> {
        &self.faces
    }
    /// Borrows the face tensor of the UV-split mesh.
    fn faces_uv(&self) -> &Tensor<B, 2, Int> {
        &self.faces_uv_mesh
    }
    /// Borrows the UV coordinate tensor.
    fn uv(&self) -> &Tensor<B, 2, Float> {
        &self.uv
    }
    /// Returns a clone of the skinning-weight tensor of the mesh without UV seams.
    fn lbs_weights(&self) -> Tensor<B, 2, Float> {
        self.lbs_weights.clone()
    }
    /// Returns a clone of the skinning-weight tensor gathered per UV-split vertex.
    fn lbs_weights_split(&self) -> Tensor<B, 2, Float> {
        self.lbs_weights_split.clone()
    }
    /// Returns a clone of the index tensor mapping each UV-split vertex to its vertex
    /// in the mesh without UV seams.
    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int> {
        self.idx_vuv_2_vnouv.clone()
    }
    /// Borrows the host-side copy of the split-to-merged vertex index table.
    fn idx_split_2_merged_vec(&self) -> &Vec<usize> {
        &self.idx_vuv_2_vnouv_vec
    }
    /// Installs (or replaces) the pose-corrective basis of this model.
    fn set_pose_dirs(&mut self, pose_dirs: Tensor<B, 2, Float>) {
        self.pose_dirs = Some(pose_dirs);
    }
635 fn get_pose_dirs(&self) -> Tensor<B, 2, Float> {
636 if let Some(pose_dirs_tensor) = self.pose_dirs.clone() {
637 pose_dirs_tensor
638 } else {
639 panic!("pose_dirs is not available!");
640 }
641 }
    /// Returns a clone of the expression basis, or `None` if the model has none.
    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>> {
        self.expression_dirs.clone()
    }
645}