1use crate::AppBackend;
2use crate::{
3 common::{
4 betas::BetasG,
5 expression::ExpressionG,
6 outputs::SmplOutputG,
7 pose::PoseG,
8 smpl_model::{FaceModel, SmplModel},
9 smpl_options::SmplOptions,
10 types::{Gender, SmplType, UpAxis},
11 },
12 conversions::pose_remap::PoseRemap,
13};
14use burn::tensor::{backend::Backend, Float, Int, Tensor};
15use gloss_geometry::csr::{VertexFaceCSR, VertexFaceCSRBurn};
16use gloss_utils::bshare::ToBurn;
17use gloss_utils::nshare::ToNalgebra;
18use log::{info, warn};
19use nalgebra as na;
20use ndarray as nd;
21use ndarray::prelude::*;
22use ndarray_npy::NpzReader;
23use smpl_utils::{
24 array::Gather2D,
25 io::FileLoader,
26 numerical::{batch_rigid_transform_burn_fast, batch_rodrigues_burn_3},
27};
28use std::{
29 any::Any,
30 io::{Read, Seek},
31};
// ---- Skeleton -------------------------------------------------------------

/// Body joints counted in [`NUM_JOINTS`].
pub const NUM_BODY_JOINTS: usize = 21;
/// Joints per hand (two hands are counted in [`NUM_JOINTS`]).
pub const NUM_HAND_JOINTS: usize = 15;
/// Face joints counted in [`NUM_JOINTS`].
pub const NUM_FACE_JOINTS: usize = 3;
/// Joints excluding index 0. The model's tensors hold `NUM_JOINTS + 1` joints
/// (presumably the extra one is the root).
pub const NUM_JOINTS: usize = NUM_BODY_JOINTS + 2 * NUM_HAND_JOINTS + NUM_FACE_JOINTS;
/// Skeleton index of the neck joint.
pub const NECK_IDX: usize = 12;

// ---- Mesh -----------------------------------------------------------------

/// Vertex count of the mesh.
pub const NUM_VERTS: usize = 10475;
/// Vertex count of the UV mesh, in which a mesh vertex may appear more than once.
pub const NUM_VERTS_UV_MESH: usize = 11307;
/// Triangle count of the mesh.
pub const NUM_FACES: usize = 20908;

// ---- Blend-shape spaces ---------------------------------------------------

/// Shape components stored ahead of the expression components.
pub const SHAPE_SPACE_DIM: usize = 300;
/// Expression components stored after the shape components.
pub const EXPRESSION_SPACE_DIM: usize = 100;
/// Combined shape + expression components.
pub const FULL_SHAPE_SPACE_DIM: usize = SHAPE_SPACE_DIM + EXPRESSION_SPACE_DIM;
/// Pose-corrective inputs: 9 entries of a 3x3 rotation matrix per joint in [`NUM_JOINTS`].
pub const NUM_POSE_BLEND_SHAPES: usize = NUM_JOINTS * 9;
44#[derive(Clone)]
45pub struct SmplXGPUG<B: Backend> {
46 pub device: B::Device,
47 pub smpl_type: SmplType,
48 pub gender: Gender,
49 pub verts_template: Tensor<B, 2, Float>,
50 pub faces: Tensor<B, 2, Int>,
51 pub faces_uv_mesh: Tensor<B, 2, Int>,
52 pub uv: Tensor<B, 2, Float>,
53 pub shape_dirs: Tensor<B, 2, Float>,
54 pub expression_dirs: Option<Tensor<B, 2, Float>>,
55 pub pose_dirs: Option<Tensor<B, 2, Float>>,
56 pub joint_regressor: Tensor<B, 2, Float>,
57 pub parent_idx_per_joint_nd: nd::Array1<u32>,
58 pub parent_idx_per_joint: Tensor<B, 1, Int>,
59 pub lbs_weights: Tensor<B, 2, Float>,
60 pub verts_ones: Tensor<B, 2, Float>,
61 pub idx_vuv_2_vnouv: Tensor<B, 1, Int>,
62 pub faces_na: na::DMatrix<u32>,
63 pub faces_uv_mesh_na: na::DMatrix<u32>,
64 pub uv_na: na::DMatrix<f32>,
65 pub idx_vuv_2_vnouv_vec: Vec<usize>,
66 pub lbs_weights_split: Tensor<B, 2>,
67 pub lbs_weights_nd: nd::ArcArray2<f32>,
68 pub lbs_weights_split_nd: nd::ArcArray2<f32>,
69 pub vertex_face_csr: VertexFaceCSRBurn<B>,
70 pub vertex_face_uv_csr: VertexFaceCSRBurn<B>,
71}
72impl<B: Backend> SmplXGPUG<B> {
73 #[allow(clippy::too_many_arguments)]
76 #[allow(clippy::too_many_lines)]
77 pub fn new_from_matrices(
78 gender: Gender,
79 verts_template: &nd::Array2<f32>,
80 faces: &nd::Array2<u32>,
81 faces_uv_mesh: &nd::Array2<u32>,
82 uv: &nd::Array2<f32>,
83 shape_dirs: &nd::Array3<f32>,
84 expression_dirs: Option<nd::Array3<f32>>,
85 pose_dirs: Option<nd::Array3<f32>>,
86 joint_regressor: &nd::Array2<f32>,
87 parent_idx_per_joint: &nd::Array1<u32>,
88 lbs_weights: nd::Array2<f32>,
89 max_num_betas: usize,
90 max_num_expression_components: usize,
91 ) -> Self {
92 let device = B::Device::default();
93 let b_verts_template = verts_template.to_burn(&device);
94 let b_faces = faces.to_burn(&device);
95 let b_faces_uv_mesh = faces_uv_mesh.to_burn(&device);
96 let b_uv = uv.to_burn(&device);
97 let actual_num_betas = max_num_betas.min(shape_dirs.shape()[2]);
98 let shape_dirs = shape_dirs
99 .slice_axis(Axis(2), ndarray::Slice::from(0..actual_num_betas))
100 .to_owned()
101 .into_shape_with_order((NUM_VERTS * 3, actual_num_betas))
102 .unwrap();
103 let b_shape_dirs = shape_dirs.to_burn(&device);
104 let b_expression_dirs = expression_dirs.map(|expression_dirs| {
105 let actual_num_expression_components = max_num_expression_components.min(expression_dirs.shape()[2]);
106 let expression_dirs = expression_dirs
107 .slice_axis(nd::Axis(2), nd::Slice::from(0..actual_num_expression_components))
108 .into_shape_with_order((NUM_VERTS * 3, actual_num_expression_components))
109 .unwrap()
110 .to_owned();
111 expression_dirs.to_burn(&device)
112 });
113 let b_pose_dirs = pose_dirs.map(|pose_dirs| {
114 let pose_dirs = pose_dirs.into_shape_with_order((NUM_VERTS * 3, NUM_JOINTS * 9)).unwrap();
115 pose_dirs.to_burn(&device)
116 });
117 let b_joint_regressor = joint_regressor.to_burn(&device);
118 let b_parent_idx_per_joint = parent_idx_per_joint.to_burn(&device).reshape([NUM_JOINTS + 1]);
119 let b_lbs_weights = lbs_weights.to_burn(&device);
120 #[allow(clippy::cast_possible_wrap)]
121 let faces_uv_mesh_i32: nd::Array2<i32> = faces_uv_mesh.mapv(|x| x as i32);
122 let ft: nd::ArcArray2<i32> = faces_uv_mesh_i32.into();
123 let max_v_uv_idx = *ft.iter().max_by_key(|&x| x).unwrap();
124 let max_v_uv_idx_usize = usize::try_from(max_v_uv_idx).unwrap_or_else(|_| panic!("Cannot cast max_v_uv_idx to usize"));
125 let mut idx_vuv_2_vnouv = nd::ArcArray1::<i32>::zeros(max_v_uv_idx_usize + 1);
126 for (fuv, fnouv) in ft.axis_iter(nd::Axis(0)).zip(faces.axis_iter(nd::Axis(0))) {
127 let uv_0 = fuv[[0]];
128 let uv_1 = fuv[[1]];
129 let uv_2 = fuv[[2]];
130 let nouv_0 = fnouv[[0]];
131 let nouv_1 = fnouv[[1]];
132 let nouv_2 = fnouv[[2]];
133 idx_vuv_2_vnouv[usize::try_from(uv_0).unwrap_or_else(|_| panic!("Cannot cast uv_0 to usize"))] =
134 i32::try_from(nouv_0).unwrap_or_else(|_| panic!("Cannot cast nouv_0 to i32"));
135 idx_vuv_2_vnouv[usize::try_from(uv_1).unwrap_or_else(|_| panic!("Cannot cast uv_1 to usize"))] =
136 i32::try_from(nouv_1).unwrap_or_else(|_| panic!("Cannot cast nouv_1 to i32"));
137 idx_vuv_2_vnouv[usize::try_from(uv_2).unwrap_or_else(|_| panic!("Cannot cast uv_2 to usize"))] =
138 i32::try_from(nouv_2).unwrap_or_else(|_| panic!("Cannot cast nouv_2 to i32"));
139 }
140 let idx_vuv_2_vnouv_vec: Vec<i32> = idx_vuv_2_vnouv.mapv(|x| x).into_raw_vec_and_offset().0;
141 let idx_vuv_2_vnouv_slice: &[i32] = &idx_vuv_2_vnouv_vec;
142 let b_idx_vuv_2_vnouv = Tensor::<B, 1, Int>::from_ints(idx_vuv_2_vnouv_slice, &device);
143 let idx_vuv_2_vnouv_vec: Vec<usize> = idx_vuv_2_vnouv
144 .to_vec()
145 .iter()
146 .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot cast negative value to usize")))
147 .collect();
148 let faces_na = faces.view().into_nalgebra().clone_owned().map(|x| x);
149 let faces_uv_mesh_na = ft
150 .view()
151 .into_nalgebra()
152 .clone_owned()
153 .map(|x| u32::try_from(x).unwrap_or_else(|_| panic!("Cannot cast value to u32")));
154 let uv_na = uv.view().into_nalgebra().clone_owned();
155 let cols: Vec<usize> = (0..lbs_weights.ncols()).collect();
156 let lbs_weights_split: nd::ArcArray2<f32> = lbs_weights.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
157 let b_lbs_weights_split =
158 Tensor::<B, 1>::from_floats(lbs_weights_split.as_slice().unwrap(), &device).reshape([idx_vuv_2_vnouv_vec.len(), NUM_JOINTS + 1]);
159 let verts_ones = Tensor::<B, 2>::ones([NUM_VERTS, 1], &device);
160 let lbs_weights_nd: nd::ArcArray2<f32> = lbs_weights.into();
161 let cols: Vec<usize> = (0..lbs_weights_nd.ncols()).collect();
162 let lbs_weights_split_nd = lbs_weights_nd.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
163 let vertex_face_csr = VertexFaceCSR::from_faces(&faces_na.clone());
164 let vertex_face_csr_burn = vertex_face_csr.to_burn(&device);
165 let vertex_face_uv_csr = VertexFaceCSR::from_faces(&faces_uv_mesh_na.clone());
166 let vertex_face_uv_csr_burn = vertex_face_uv_csr.to_burn(&device);
167 info!("Initialised burn on Backend: {:?}", B::name(&device));
168 info!("Device: {:?}", &device);
169 Self {
170 smpl_type: SmplType::SmplX,
171 gender,
172 device,
173 verts_template: b_verts_template,
174 faces: b_faces,
175 faces_uv_mesh: b_faces_uv_mesh,
176 uv: b_uv,
177 shape_dirs: b_shape_dirs,
178 expression_dirs: b_expression_dirs,
179 pose_dirs: b_pose_dirs,
180 joint_regressor: b_joint_regressor,
181 parent_idx_per_joint: b_parent_idx_per_joint,
182 parent_idx_per_joint_nd: parent_idx_per_joint.clone(),
183 lbs_weights: b_lbs_weights,
184 verts_ones,
185 idx_vuv_2_vnouv: b_idx_vuv_2_vnouv,
186 faces_na,
187 faces_uv_mesh_na,
188 uv_na,
189 idx_vuv_2_vnouv_vec,
190 lbs_weights_split: b_lbs_weights_split,
191 lbs_weights_nd,
192 lbs_weights_split_nd,
193 vertex_face_csr: vertex_face_csr_burn,
194 vertex_face_uv_csr: vertex_face_uv_csr_burn,
195 }
196 }
197 fn new_from_npz_reader<R: Read + Seek>(
200 npz: &mut NpzReader<R>,
201 gender: Gender,
202 max_num_betas: usize,
203 max_num_expression_components: usize,
204 ) -> Self {
205 let verts_template: nd::Array2<f32> = npz.by_name("v_template").unwrap();
206 let faces: nd::Array2<u32> = npz.by_name("f").unwrap();
207 let uv: nd::Array2<f32> = npz.by_name("vt").unwrap();
208 let full_shape_dirs: nd::Array3<f32> = npz.by_name("shapedirs").unwrap();
209 let (shape_dirs, expression_dirs) = if let Ok(expression_dirs) = npz.by_name("expressiondirs") {
210 (full_shape_dirs, Some(expression_dirs))
211 } else {
212 let num_available_betas = full_shape_dirs.shape()[2];
213 let num_full_betas = 300;
214 let num_betas_to_use = num_full_betas.min(max_num_betas).min(num_available_betas);
215 let shape_dirs = full_shape_dirs.slice_axis(nd::Axis(2), nd::Slice::from(0..num_betas_to_use)).to_owned();
216 let expression_dirs = if full_shape_dirs.shape()[2] > 300 {
217 Some(
218 full_shape_dirs
219 .slice_axis(nd::Axis(2), nd::Slice::from(300..300 + max_num_expression_components.min(100)))
220 .to_owned(),
221 )
222 } else {
223 None
224 };
225 (shape_dirs, expression_dirs)
226 };
227 let pose_dirs: Option<nd::Array3<f32>> = npz.by_name("posedirs").ok();
228 let joint_regressor: nd::Array2<f32> = npz.by_name("J_regressor").unwrap();
229 let parent_idx_per_joint: nd::Array2<i32> = npz.by_name("kintree_table").unwrap();
230 #[allow(clippy::cast_sign_loss)]
231 let parent_idx_per_joint = parent_idx_per_joint.mapv(|x| x as u32);
232 let parent_idx_per_joint = parent_idx_per_joint
233 .slice_axis(nd::Axis(0), nd::Slice::from(0..1))
234 .to_owned()
235 .into_shape_with_order(NUM_JOINTS + 1)
236 .unwrap();
237 let lbs_weights: nd::Array2<f32> = npz.by_name("weights").unwrap();
238 let ft: nd::Array2<u32> = npz.by_name("ft").unwrap();
239 if pose_dirs.is_none() {
240 warn!("No pose_dirs loaded from npz");
241 }
242 Self::new_from_matrices(
243 gender,
244 &verts_template,
245 &faces,
246 &ft,
247 &uv,
248 &shape_dirs,
249 expression_dirs,
250 pose_dirs,
251 &joint_regressor,
252 &parent_idx_per_joint,
253 lbs_weights,
254 max_num_betas,
255 max_num_expression_components,
256 )
257 }
258 #[cfg(not(target_arch = "wasm32"))]
259 pub fn new_from_npz(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
262 let mut npz = NpzReader::new(std::fs::File::open(model_path).unwrap()).unwrap();
263 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
264 }
265 #[allow(clippy::cast_possible_truncation)]
270 pub async fn new_from_npz_async(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
271 let reader = FileLoader::open(model_path).await;
272 let mut npz = NpzReader::new(reader).unwrap();
273 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
274 }
275 #[allow(clippy::cast_possible_truncation)]
280 pub fn new_from_reader<R: Read + Seek>(reader: R, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
281 let mut npz = NpzReader::new(reader).unwrap();
282 Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
283 }
284 #[allow(clippy::cast_possible_truncation)]
289 pub fn read_pose_dirs_from_reader<R: Read + Seek>(reader: R, device: &B::Device) -> Tensor<B, 2, Float> {
290 let mut npz = NpzReader::new(reader).unwrap();
291 let pose_dirs: Option<nd::Array3<f32>> = Some(npz.by_name("pose_dirs").unwrap());
292 let b_pose_dirs =
293 pose_dirs.map(|pose_dirs| Tensor::<B, 1>::from_floats(pose_dirs.as_slice().unwrap(), device).reshape([NUM_VERTS * 3, NUM_JOINTS * 9]));
294 b_pose_dirs.unwrap()
295 }
296}
297impl<B: Backend> FaceModel<B> for SmplXGPUG<B> {
298 #[allow(clippy::missing_panics_doc)]
299 #[allow(non_snake_case)]
300 #[allow(clippy::let_and_return)]
301 fn expression2offsets(&self, expression: &ExpressionG<B>) -> Tensor<B, 2, Float> {
302 let device = self.verts_template.device();
303 let offsets = if let Some(ref expression_dirs) = self.expression_dirs {
304 let input_nr_expression_coeffs = expression.expr_coeffs.dims()[0];
305 let model_nr_expression_coeffs = expression_dirs.shape().dims[1];
306 let nr_expression_coeffs = input_nr_expression_coeffs.min(model_nr_expression_coeffs);
307 #[allow(clippy::single_range_in_vec_init)]
308 let expr_sliced = expression.expr_coeffs.clone().slice([0..nr_expression_coeffs]);
309 let expression_dirs_sliced = expression_dirs.clone().slice([0..expression_dirs.dims()[0], 0..nr_expression_coeffs]);
310 let v_expr_offsets = expression_dirs_sliced.matmul(expr_sliced.reshape([-1, 1]));
311 v_expr_offsets.reshape([NUM_VERTS, 3])
312 } else {
313 Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &device)
314 };
315 offsets
316 }
317 fn get_face_model(&self) -> &dyn FaceModel<B> {
318 self
319 }
320}
321impl<B: Backend> SmplModel<B> for SmplXGPUG<B> {
322 fn clone_dyn(&self) -> Box<dyn SmplModel<B>> {
323 Box::new(self.clone())
324 }
325 fn as_any(&self) -> &dyn Any {
326 self
327 }
328 fn smpl_type(&self) -> SmplType {
329 self.smpl_type
330 }
331 fn gender(&self) -> Gender {
332 self.gender
333 }
334 fn device(&self) -> B::Device {
335 self.device.clone()
336 }
337 fn get_face_model(&self) -> &dyn FaceModel<B> {
338 self
339 }
340 #[allow(clippy::missing_panics_doc)]
341 #[allow(non_snake_case)]
342 fn forward(&self, options: &SmplOptions, betas: &BetasG<B>, pose_raw: &PoseG<B>, expression: Option<&ExpressionG<B>>) -> SmplOutputG<B> {
343 let mut verts_t_pose = self.betas2verts(betas);
344 if let Some(expression) = expression {
345 verts_t_pose = verts_t_pose + self.expression2offsets(expression);
346 }
347 let pose_remap = PoseRemap::new(pose_raw.smpl_type, SmplType::SmplX);
348 let pose = pose_remap.remap(pose_raw);
349 let joints_t_pose = self.verts2joints(verts_t_pose.clone());
350 if options.enable_pose_corrective {
351 let verts_offset = self.compute_pose_correctives(&pose);
352 verts_t_pose = verts_t_pose + verts_offset;
353 }
354 let (verts_posed_nd, joints_posed) = self.apply_pose(&verts_t_pose, &joints_t_pose, &self.lbs_weights, &pose);
355 SmplOutputG {
356 verts: verts_posed_nd,
357 faces: self.faces.clone(),
358 normals: None,
359 uvs: None,
360 joints: joints_posed,
361 }
362 }
363 fn create_body_with_uv(&self, smpl_merged: &SmplOutputG<B>) -> SmplOutputG<B> {
364 let cols_tensor = Tensor::<B, 1, Int>::from_ints([0, 1, 2], &self.device);
365 let mapping_tensor = self.idx_split_2_merged();
366 let v_burn_split = smpl_merged.verts.clone().select(0, mapping_tensor.clone());
367 let v_burn_split = v_burn_split.select(1, cols_tensor.clone());
368 let n_burn_split = smpl_merged
369 .normals
370 .as_ref()
371 .map(|n| n.clone().select(0, mapping_tensor).select(1, cols_tensor));
372 SmplOutputG {
373 verts: v_burn_split,
374 faces: self.faces_uv_mesh.clone(),
375 normals: n_burn_split,
376 uvs: Some(self.uv.clone()),
377 joints: smpl_merged.joints.clone(),
378 }
379 }
380 #[allow(clippy::missing_panics_doc)]
381 #[allow(non_snake_case)]
382 #[allow(clippy::let_and_return)]
383 fn betas2verts(&self, betas: &BetasG<B>) -> Tensor<B, 2, Float> {
384 let input_nr_betas = betas.betas.dims()[0];
385 let model_nr_betas = self.shape_dirs.shape().dims[1];
386 let nr_betas = input_nr_betas.min(model_nr_betas);
387 #[allow(clippy::single_range_in_vec_init)]
388 let betas_sliced = betas.betas.clone().slice([0..nr_betas]);
389 let shape_dirs_sliced = self.shape_dirs.clone().slice([0..self.shape_dirs.dims()[0], 0..nr_betas]);
390 let v_beta_offsets = shape_dirs_sliced.matmul(betas_sliced.reshape([-1, 1]));
391 let v_beta_offsets_reshaped = v_beta_offsets.reshape([NUM_VERTS, 3]);
392 let verts_t_pose = v_beta_offsets_reshaped.add(self.verts_template.clone());
393 verts_t_pose
394 }
395 fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
396 self.joint_regressor.clone().matmul(verts_t_pose)
397 }
398 #[allow(clippy::missing_panics_doc)]
399 fn compute_pose_correctives(&self, pose: &PoseG<B>) -> Tensor<B, 2, Float> {
400 if let Some(pose_dirs) = &self.pose_dirs {
401 let full_pose = &pose.joint_poses;
402 assert!(
403 full_pose.dims()[0] == NUM_JOINTS + 1,
404 "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
405 full_pose.dims()[0],
406 NUM_JOINTS + 1
407 );
408 let b_pose_feature = self.compute_pose_feature(pose);
409 let b_pose_feature = b_pose_feature.reshape([NUM_JOINTS * 9, 1]);
410 let new_pose_dirs = pose_dirs.clone();
411 let all_pose_offsets = new_pose_dirs.matmul(b_pose_feature);
412 all_pose_offsets.reshape([NUM_VERTS, 3])
413 } else {
414 Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &self.device)
415 }
416 }
417 #[allow(clippy::missing_panics_doc)]
418 fn compute_pose_feature(&self, pose: &PoseG<B>) -> Tensor<B, 1> {
419 let full_pose = &pose.joint_poses;
420 assert!(
421 full_pose.dims()[0] == NUM_JOINTS + 1,
422 "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
423 full_pose.dims()[0],
424 NUM_JOINTS + 1
425 );
426 let rot_mats = batch_rodrigues_burn_3(full_pose);
427 let identity = Tensor::<B, 2>::eye(3, &self.device());
428 (rot_mats.clone().slice([1..rot_mats.dims()[0], 0..3, 0..3]) - identity.unsqueeze_dim(0)).reshape([NUM_JOINTS * 9])
429 }
430 #[allow(clippy::missing_panics_doc)]
431 #[allow(non_snake_case)]
432 #[allow(clippy::cast_precision_loss)]
433 #[allow(clippy::cast_sign_loss)]
434 #[allow(clippy::too_many_lines)]
435 #[allow(clippy::similar_names)]
436 fn apply_pose(
437 &self,
438 verts_t_pose: &Tensor<B, 2, Float>,
439 joints: &Tensor<B, 2, Float>,
440 lbs_weights: &Tensor<B, 2, Float>,
441 pose: &PoseG<B>,
442 ) -> (Tensor<B, 2, Float>, Tensor<B, 2, Float>) {
443 assert!(
444 verts_t_pose.shape().dims[0] == lbs_weights.shape().dims[0],
445 "Verts and LBS weights should match"
446 );
447 let full_pose = &pose.joint_poses;
448 assert!(
449 full_pose.dims()[0] == NUM_JOINTS + 1,
450 "The pose does not have the correct number of joints for this model."
451 );
452 let full_pose: Tensor<B, 2> = pose.joint_poses.clone();
453 let rot_mats_t = batch_rodrigues_burn_3(&full_pose);
454 let (posed_joints, rel_transforms) = batch_rigid_transform_burn_fast(
455 self.parent_idx_per_joint.clone(),
456 &self.parent_idx_per_joint_nd,
457 rot_mats_t,
458 joints.clone(),
459 );
460 let nr_verts = verts_t_pose.shape().dims[0];
461 let A = rel_transforms.reshape([NUM_JOINTS + 1, 16]);
462 let T = lbs_weights.clone().matmul(A).reshape([nr_verts, 4, 4]);
463 let ones = Tensor::ones([nr_verts, 1], &self.device);
464 let v_posed_h = Tensor::cat(vec![verts_t_pose.clone(), ones], 1).unsqueeze_dim(2);
465 let verts_final_h = T.matmul(v_posed_h).squeeze(2);
466 let verts_final = verts_final_h.slice([0..nr_verts, 0..3]);
467 let trans_pose = pose.global_trans.clone().reshape([1, 3]);
468 let mut verts_final = verts_final.clone() + trans_pose.clone();
469 let mut posed_joints = posed_joints.clone() + trans_pose.clone();
470 if pose.up_axis == UpAxis::Z {
471 let vcol0: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 0..1]).squeeze(1);
472 let vcol1: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 1..2]).squeeze(1);
473 let vcol2: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 2..3]).squeeze(1);
474 let verts_new_col1 = vcol2;
475 let verts_new_col2 = vcol1.mul_scalar(-1.0);
476 verts_final = Tensor::stack::<2>(vec![vcol0, verts_new_col1, verts_new_col2], 1);
477 let nr_joints = posed_joints.shape().dims[0];
478 let jcol0: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 0..1]).squeeze(1);
479 let jcol1: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 1..2]).squeeze(1);
480 let jcol2: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 2..3]).squeeze(1);
481 let joints_new_col1 = jcol2;
482 let joints_new_col2 = jcol1.mul_scalar(-1.0);
483 posed_joints = Tensor::stack::<2>(vec![jcol0, joints_new_col1, joints_new_col2], 1);
484 }
485 (verts_final, posed_joints)
486 }
487 fn faces(&self) -> &Tensor<B, 2, Int> {
488 &self.faces
489 }
490 fn faces_uv(&self) -> &Tensor<B, 2, Int> {
491 &self.faces_uv_mesh
492 }
493 fn uv(&self) -> &Tensor<B, 2, Float> {
494 &self.uv
495 }
496 fn lbs_weights(&self) -> Tensor<B, 2, Float> {
497 self.lbs_weights.clone()
498 }
499 fn lbs_weights_split(&self) -> Tensor<B, 2, Float> {
500 self.lbs_weights_split.clone()
501 }
502 fn idx_split_2_merged(&self) -> Tensor<B, 1, Int> {
503 self.idx_vuv_2_vnouv.clone()
504 }
505 fn idx_split_2_merged_vec(&self) -> &Vec<usize> {
506 &self.idx_vuv_2_vnouv_vec
507 }
508 fn set_pose_dirs(&mut self, pose_dirs: Tensor<B, 2, Float>) {
509 self.pose_dirs = Some(pose_dirs);
510 }
511 fn get_pose_dirs(&self) -> Tensor<B, 2, Float> {
512 if let Some(pose_dirs_tensor) = self.pose_dirs.clone() {
513 pose_dirs_tensor
514 } else {
515 panic!("pose_dirs is not available!");
516 }
517 }
518 fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>> {
519 self.expression_dirs.clone()
520 }
521 fn vertex_face_csr(&self) -> Option<VertexFaceCSRBurn<B>> {
522 Some(self.vertex_face_csr.clone())
523 }
524 fn vertex_face_uv_csr(&self) -> Option<VertexFaceCSRBurn<B>> {
525 Some(self.vertex_face_uv_csr.clone())
526 }
527}
/// `SmplXGPUG` instantiated with the crate's `AppBackend`.
pub type SmplXGPU = SmplXGPUG<AppBackend>;