smpl_core/smpl_x/
smpl_x_gpu.rs

1use crate::AppBackend;
2use crate::{
3    common::{
4        betas::BetasG,
5        expression::ExpressionG,
6        outputs::SmplOutputG,
7        pose::PoseG,
8        smpl_model::{FaceModel, SmplModel},
9        smpl_options::SmplOptions,
10        types::{Gender, SmplType, UpAxis},
11    },
12    conversions::pose_remap::PoseRemap,
13};
14use burn::tensor::{backend::Backend, Float, Int, Tensor};
15use gloss_geometry::csr::{VertexFaceCSR, VertexFaceCSRBurn};
16use gloss_utils::bshare::ToBurn;
17use gloss_utils::nshare::ToNalgebra;
18use log::{info, warn};
19use nalgebra as na;
20use ndarray as nd;
21use ndarray::prelude::*;
22use ndarray_npy::NpzReader;
23use smpl_utils::{
24    array::Gather2D,
25    io::FileLoader,
26    numerical::{batch_rigid_transform_burn_fast, batch_rodrigues_burn_3},
27};
28use std::{
29    any::Any,
30    io::{Read, Seek},
31};
/// Number of body joints counted in `NUM_JOINTS`.
pub const NUM_BODY_JOINTS: usize = 21;
/// Number of joints per hand.
pub const NUM_HAND_JOINTS: usize = 15;
/// Number of face joints.
pub const NUM_FACE_JOINTS: usize = 3;
/// Body + both hands + face joints.
///
/// Joint-indexed arrays of the model (poses, LBS weight columns, parent
/// indices) hold `NUM_JOINTS + 1` entries.
pub const NUM_JOINTS: usize = NUM_FACE_JOINTS + NUM_BODY_JOINTS + NUM_HAND_JOINTS * 2;
/// Neck joint index.
pub const NECK_IDX: usize = 12;
/// Vertex count of the (merged) template mesh.
pub const NUM_VERTS: usize = 10475;
/// Vertex count of the UV-split mesh.
pub const NUM_VERTS_UV_MESH: usize = 11307;
/// Triangle count of the mesh.
pub const NUM_FACES: usize = 20908;
/// Number of shape (beta) components.
pub const SHAPE_SPACE_DIM: usize = 300;
/// Number of expression components.
pub const EXPRESSION_SPACE_DIM: usize = 100;
/// Size of a combined `shapedirs` basis: shape components followed by expression components.
pub const FULL_SHAPE_SPACE_DIM: usize = SHAPE_SPACE_DIM + EXPRESSION_SPACE_DIM;
/// Pose-corrective features: one flattened 3x3 rotation term per joint in `NUM_JOINTS`.
pub const NUM_POSE_BLEND_SHAPES: usize = 9 * NUM_JOINTS;
44#[derive(Clone)]
45pub struct SmplXGPUG<B: Backend> {
46    pub device: B::Device,
47    pub smpl_type: SmplType,
48    pub gender: Gender,
49    pub verts_template: Tensor<B, 2, Float>,
50    pub faces: Tensor<B, 2, Int>,
51    pub faces_uv_mesh: Tensor<B, 2, Int>,
52    pub uv: Tensor<B, 2, Float>,
53    pub shape_dirs: Tensor<B, 2, Float>,
54    pub expression_dirs: Option<Tensor<B, 2, Float>>,
55    pub pose_dirs: Option<Tensor<B, 2, Float>>,
56    pub joint_regressor: Tensor<B, 2, Float>,
57    pub parent_idx_per_joint_nd: nd::Array1<u32>,
58    pub parent_idx_per_joint: Tensor<B, 1, Int>,
59    pub lbs_weights: Tensor<B, 2, Float>,
60    pub verts_ones: Tensor<B, 2, Float>,
61    pub idx_vuv_2_vnouv: Tensor<B, 1, Int>,
62    pub faces_na: na::DMatrix<u32>,
63    pub faces_uv_mesh_na: na::DMatrix<u32>,
64    pub uv_na: na::DMatrix<f32>,
65    pub idx_vuv_2_vnouv_vec: Vec<usize>,
66    pub lbs_weights_split: Tensor<B, 2>,
67    pub lbs_weights_nd: nd::ArcArray2<f32>,
68    pub lbs_weights_split_nd: nd::ArcArray2<f32>,
69    pub vertex_face_csr: VertexFaceCSRBurn<B>,
70    pub vertex_face_uv_csr: VertexFaceCSRBurn<B>,
71}
72impl<B: Backend> SmplXGPUG<B> {
73    /// # Panics
74    /// Will panic if the matrices don't match the expected sizes
75    #[allow(clippy::too_many_arguments)]
76    #[allow(clippy::too_many_lines)]
77    pub fn new_from_matrices(
78        gender: Gender,
79        verts_template: &nd::Array2<f32>,
80        faces: &nd::Array2<u32>,
81        faces_uv_mesh: &nd::Array2<u32>,
82        uv: &nd::Array2<f32>,
83        shape_dirs: &nd::Array3<f32>,
84        expression_dirs: Option<nd::Array3<f32>>,
85        pose_dirs: Option<nd::Array3<f32>>,
86        joint_regressor: &nd::Array2<f32>,
87        parent_idx_per_joint: &nd::Array1<u32>,
88        lbs_weights: nd::Array2<f32>,
89        max_num_betas: usize,
90        max_num_expression_components: usize,
91    ) -> Self {
92        let device = B::Device::default();
93        let b_verts_template = verts_template.to_burn(&device);
94        let b_faces = faces.to_burn(&device);
95        let b_faces_uv_mesh = faces_uv_mesh.to_burn(&device);
96        let b_uv = uv.to_burn(&device);
97        let actual_num_betas = max_num_betas.min(shape_dirs.shape()[2]);
98        let shape_dirs = shape_dirs
99            .slice_axis(Axis(2), ndarray::Slice::from(0..actual_num_betas))
100            .to_owned()
101            .into_shape_with_order((NUM_VERTS * 3, actual_num_betas))
102            .unwrap();
103        let b_shape_dirs = shape_dirs.to_burn(&device);
104        let b_expression_dirs = expression_dirs.map(|expression_dirs| {
105            let actual_num_expression_components = max_num_expression_components.min(expression_dirs.shape()[2]);
106            let expression_dirs = expression_dirs
107                .slice_axis(nd::Axis(2), nd::Slice::from(0..actual_num_expression_components))
108                .into_shape_with_order((NUM_VERTS * 3, actual_num_expression_components))
109                .unwrap()
110                .to_owned();
111            expression_dirs.to_burn(&device)
112        });
113        let b_pose_dirs = pose_dirs.map(|pose_dirs| {
114            let pose_dirs = pose_dirs.into_shape_with_order((NUM_VERTS * 3, NUM_JOINTS * 9)).unwrap();
115            pose_dirs.to_burn(&device)
116        });
117        let b_joint_regressor = joint_regressor.to_burn(&device);
118        let b_parent_idx_per_joint = parent_idx_per_joint.to_burn(&device).reshape([NUM_JOINTS + 1]);
119        let b_lbs_weights = lbs_weights.to_burn(&device);
120        #[allow(clippy::cast_possible_wrap)]
121        let faces_uv_mesh_i32: nd::Array2<i32> = faces_uv_mesh.mapv(|x| x as i32);
122        let ft: nd::ArcArray2<i32> = faces_uv_mesh_i32.into();
123        let max_v_uv_idx = *ft.iter().max_by_key(|&x| x).unwrap();
124        let max_v_uv_idx_usize = usize::try_from(max_v_uv_idx).unwrap_or_else(|_| panic!("Cannot cast max_v_uv_idx to usize"));
125        let mut idx_vuv_2_vnouv = nd::ArcArray1::<i32>::zeros(max_v_uv_idx_usize + 1);
126        for (fuv, fnouv) in ft.axis_iter(nd::Axis(0)).zip(faces.axis_iter(nd::Axis(0))) {
127            let uv_0 = fuv[[0]];
128            let uv_1 = fuv[[1]];
129            let uv_2 = fuv[[2]];
130            let nouv_0 = fnouv[[0]];
131            let nouv_1 = fnouv[[1]];
132            let nouv_2 = fnouv[[2]];
133            idx_vuv_2_vnouv[usize::try_from(uv_0).unwrap_or_else(|_| panic!("Cannot cast uv_0 to usize"))] =
134                i32::try_from(nouv_0).unwrap_or_else(|_| panic!("Cannot cast nouv_0 to i32"));
135            idx_vuv_2_vnouv[usize::try_from(uv_1).unwrap_or_else(|_| panic!("Cannot cast uv_1 to usize"))] =
136                i32::try_from(nouv_1).unwrap_or_else(|_| panic!("Cannot cast nouv_1 to i32"));
137            idx_vuv_2_vnouv[usize::try_from(uv_2).unwrap_or_else(|_| panic!("Cannot cast uv_2 to usize"))] =
138                i32::try_from(nouv_2).unwrap_or_else(|_| panic!("Cannot cast nouv_2 to i32"));
139        }
140        let idx_vuv_2_vnouv_vec: Vec<i32> = idx_vuv_2_vnouv.mapv(|x| x).into_raw_vec_and_offset().0;
141        let idx_vuv_2_vnouv_slice: &[i32] = &idx_vuv_2_vnouv_vec;
142        let b_idx_vuv_2_vnouv = Tensor::<B, 1, Int>::from_ints(idx_vuv_2_vnouv_slice, &device);
143        let idx_vuv_2_vnouv_vec: Vec<usize> = idx_vuv_2_vnouv
144            .to_vec()
145            .iter()
146            .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot cast negative value to usize")))
147            .collect();
148        let faces_na = faces.view().into_nalgebra().clone_owned().map(|x| x);
149        let faces_uv_mesh_na = ft
150            .view()
151            .into_nalgebra()
152            .clone_owned()
153            .map(|x| u32::try_from(x).unwrap_or_else(|_| panic!("Cannot cast value to u32")));
154        let uv_na = uv.view().into_nalgebra().clone_owned();
155        let cols: Vec<usize> = (0..lbs_weights.ncols()).collect();
156        let lbs_weights_split: nd::ArcArray2<f32> = lbs_weights.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
157        let b_lbs_weights_split =
158            Tensor::<B, 1>::from_floats(lbs_weights_split.as_slice().unwrap(), &device).reshape([idx_vuv_2_vnouv_vec.len(), NUM_JOINTS + 1]);
159        let verts_ones = Tensor::<B, 2>::ones([NUM_VERTS, 1], &device);
160        let lbs_weights_nd: nd::ArcArray2<f32> = lbs_weights.into();
161        let cols: Vec<usize> = (0..lbs_weights_nd.ncols()).collect();
162        let lbs_weights_split_nd = lbs_weights_nd.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
163        let vertex_face_csr = VertexFaceCSR::from_faces(&faces_na.clone());
164        let vertex_face_csr_burn = vertex_face_csr.to_burn(&device);
165        let vertex_face_uv_csr = VertexFaceCSR::from_faces(&faces_uv_mesh_na.clone());
166        let vertex_face_uv_csr_burn = vertex_face_uv_csr.to_burn(&device);
167        info!("Initialised burn on Backend: {:?}", B::name(&device));
168        info!("Device: {:?}", &device);
169        Self {
170            smpl_type: SmplType::SmplX,
171            gender,
172            device,
173            verts_template: b_verts_template,
174            faces: b_faces,
175            faces_uv_mesh: b_faces_uv_mesh,
176            uv: b_uv,
177            shape_dirs: b_shape_dirs,
178            expression_dirs: b_expression_dirs,
179            pose_dirs: b_pose_dirs,
180            joint_regressor: b_joint_regressor,
181            parent_idx_per_joint: b_parent_idx_per_joint,
182            parent_idx_per_joint_nd: parent_idx_per_joint.clone(),
183            lbs_weights: b_lbs_weights,
184            verts_ones,
185            idx_vuv_2_vnouv: b_idx_vuv_2_vnouv,
186            faces_na,
187            faces_uv_mesh_na,
188            uv_na,
189            idx_vuv_2_vnouv_vec,
190            lbs_weights_split: b_lbs_weights_split,
191            lbs_weights_nd,
192            lbs_weights_split_nd,
193            vertex_face_csr: vertex_face_csr_burn,
194            vertex_face_uv_csr: vertex_face_uv_csr_burn,
195        }
196    }
197    /// # Panics
198    /// Will panic if the path cannot be opened
199    fn new_from_npz_reader<R: Read + Seek>(
200        npz: &mut NpzReader<R>,
201        gender: Gender,
202        max_num_betas: usize,
203        max_num_expression_components: usize,
204    ) -> Self {
205        let verts_template: nd::Array2<f32> = npz.by_name("v_template").unwrap();
206        let faces: nd::Array2<u32> = npz.by_name("f").unwrap();
207        let uv: nd::Array2<f32> = npz.by_name("vt").unwrap();
208        let full_shape_dirs: nd::Array3<f32> = npz.by_name("shapedirs").unwrap();
209        let (shape_dirs, expression_dirs) = if let Ok(expression_dirs) = npz.by_name("expressiondirs") {
210            (full_shape_dirs, Some(expression_dirs))
211        } else {
212            let num_available_betas = full_shape_dirs.shape()[2];
213            let num_full_betas = 300;
214            let num_betas_to_use = num_full_betas.min(max_num_betas).min(num_available_betas);
215            let shape_dirs = full_shape_dirs.slice_axis(nd::Axis(2), nd::Slice::from(0..num_betas_to_use)).to_owned();
216            let expression_dirs = if full_shape_dirs.shape()[2] > 300 {
217                Some(
218                    full_shape_dirs
219                        .slice_axis(nd::Axis(2), nd::Slice::from(300..300 + max_num_expression_components.min(100)))
220                        .to_owned(),
221                )
222            } else {
223                None
224            };
225            (shape_dirs, expression_dirs)
226        };
227        let pose_dirs: Option<nd::Array3<f32>> = npz.by_name("posedirs").ok();
228        let joint_regressor: nd::Array2<f32> = npz.by_name("J_regressor").unwrap();
229        let parent_idx_per_joint: nd::Array2<i32> = npz.by_name("kintree_table").unwrap();
230        #[allow(clippy::cast_sign_loss)]
231        let parent_idx_per_joint = parent_idx_per_joint.mapv(|x| x as u32);
232        let parent_idx_per_joint = parent_idx_per_joint
233            .slice_axis(nd::Axis(0), nd::Slice::from(0..1))
234            .to_owned()
235            .into_shape_with_order(NUM_JOINTS + 1)
236            .unwrap();
237        let lbs_weights: nd::Array2<f32> = npz.by_name("weights").unwrap();
238        let ft: nd::Array2<u32> = npz.by_name("ft").unwrap();
239        if pose_dirs.is_none() {
240            warn!("No pose_dirs loaded from npz");
241        }
242        Self::new_from_matrices(
243            gender,
244            &verts_template,
245            &faces,
246            &ft,
247            &uv,
248            &shape_dirs,
249            expression_dirs,
250            pose_dirs,
251            &joint_regressor,
252            &parent_idx_per_joint,
253            lbs_weights,
254            max_num_betas,
255            max_num_expression_components,
256        )
257    }
258    #[cfg(not(target_arch = "wasm32"))]
259    /// # Panics
260    /// Will panic if the path cannot be opened
261    pub fn new_from_npz(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
262        let mut npz = NpzReader::new(std::fs::File::open(model_path).unwrap()).unwrap();
263        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
264    }
265    /// # Panics
266    /// Will panic if the path cannot be opened
267    /// Will panic if the translation and rotation do not cover the same number
268    /// of timesteps
269    #[allow(clippy::cast_possible_truncation)]
270    pub async fn new_from_npz_async(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
271        let reader = FileLoader::open(model_path).await;
272        let mut npz = NpzReader::new(reader).unwrap();
273        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
274    }
275    /// # Panics
276    /// Will panic if the path cannot be opened
277    /// Will panic if the translation and rotation do not cover the same number
278    /// of timesteps
279    #[allow(clippy::cast_possible_truncation)]
280    pub fn new_from_reader<R: Read + Seek>(reader: R, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
281        let mut npz = NpzReader::new(reader).unwrap();
282        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
283    }
284    /// # Panics
285    /// Will panic if the path cannot be opened
286    /// Will panic if the translation and rotation do not cover the same number
287    /// of timesteps
288    #[allow(clippy::cast_possible_truncation)]
289    pub fn read_pose_dirs_from_reader<R: Read + Seek>(reader: R, device: &B::Device) -> Tensor<B, 2, Float> {
290        let mut npz = NpzReader::new(reader).unwrap();
291        let pose_dirs: Option<nd::Array3<f32>> = Some(npz.by_name("pose_dirs").unwrap());
292        let b_pose_dirs =
293            pose_dirs.map(|pose_dirs| Tensor::<B, 1>::from_floats(pose_dirs.as_slice().unwrap(), device).reshape([NUM_VERTS * 3, NUM_JOINTS * 9]));
294        b_pose_dirs.unwrap()
295    }
296}
297impl<B: Backend> FaceModel<B> for SmplXGPUG<B> {
298    #[allow(clippy::missing_panics_doc)]
299    #[allow(non_snake_case)]
300    #[allow(clippy::let_and_return)]
301    fn expression2offsets(&self, expression: &ExpressionG<B>) -> Tensor<B, 2, Float> {
302        let device = self.verts_template.device();
303        let offsets = if let Some(ref expression_dirs) = self.expression_dirs {
304            let input_nr_expression_coeffs = expression.expr_coeffs.dims()[0];
305            let model_nr_expression_coeffs = expression_dirs.shape().dims[1];
306            let nr_expression_coeffs = input_nr_expression_coeffs.min(model_nr_expression_coeffs);
307            #[allow(clippy::single_range_in_vec_init)]
308            let expr_sliced = expression.expr_coeffs.clone().slice([0..nr_expression_coeffs]);
309            let expression_dirs_sliced = expression_dirs.clone().slice([0..expression_dirs.dims()[0], 0..nr_expression_coeffs]);
310            let v_expr_offsets = expression_dirs_sliced.matmul(expr_sliced.reshape([-1, 1]));
311            v_expr_offsets.reshape([NUM_VERTS, 3])
312        } else {
313            Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &device)
314        };
315        offsets
316    }
317    fn get_face_model(&self) -> &dyn FaceModel<B> {
318        self
319    }
320}
321impl<B: Backend> SmplModel<B> for SmplXGPUG<B> {
322    fn clone_dyn(&self) -> Box<dyn SmplModel<B>> {
323        Box::new(self.clone())
324    }
325    fn as_any(&self) -> &dyn Any {
326        self
327    }
328    fn smpl_type(&self) -> SmplType {
329        self.smpl_type
330    }
331    fn gender(&self) -> Gender {
332        self.gender
333    }
334    fn device(&self) -> B::Device {
335        self.device.clone()
336    }
337    fn get_face_model(&self) -> &dyn FaceModel<B> {
338        self
339    }
340    #[allow(clippy::missing_panics_doc)]
341    #[allow(non_snake_case)]
342    fn forward(&self, options: &SmplOptions, betas: &BetasG<B>, pose_raw: &PoseG<B>, expression: Option<&ExpressionG<B>>) -> SmplOutputG<B> {
343        let mut verts_t_pose = self.betas2verts(betas);
344        if let Some(expression) = expression {
345            verts_t_pose = verts_t_pose + self.expression2offsets(expression);
346        }
347        let pose_remap = PoseRemap::new(pose_raw.smpl_type, SmplType::SmplX);
348        let pose = pose_remap.remap(pose_raw);
349        let joints_t_pose = self.verts2joints(verts_t_pose.clone());
350        if options.enable_pose_corrective {
351            let verts_offset = self.compute_pose_correctives(&pose);
352            verts_t_pose = verts_t_pose + verts_offset;
353        }
354        let (verts_posed_nd, joints_posed) = self.apply_pose(&verts_t_pose, &joints_t_pose, &self.lbs_weights, &pose);
355        SmplOutputG {
356            verts: verts_posed_nd,
357            faces: self.faces.clone(),
358            normals: None,
359            uvs: None,
360            joints: joints_posed,
361        }
362    }
363    fn create_body_with_uv(&self, smpl_merged: &SmplOutputG<B>) -> SmplOutputG<B> {
364        let cols_tensor = Tensor::<B, 1, Int>::from_ints([0, 1, 2], &self.device);
365        let mapping_tensor = self.idx_split_2_merged();
366        let v_burn_split = smpl_merged.verts.clone().select(0, mapping_tensor.clone());
367        let v_burn_split = v_burn_split.select(1, cols_tensor.clone());
368        let n_burn_split = smpl_merged
369            .normals
370            .as_ref()
371            .map(|n| n.clone().select(0, mapping_tensor).select(1, cols_tensor));
372        SmplOutputG {
373            verts: v_burn_split,
374            faces: self.faces_uv_mesh.clone(),
375            normals: n_burn_split,
376            uvs: Some(self.uv.clone()),
377            joints: smpl_merged.joints.clone(),
378        }
379    }
380    #[allow(clippy::missing_panics_doc)]
381    #[allow(non_snake_case)]
382    #[allow(clippy::let_and_return)]
383    fn betas2verts(&self, betas: &BetasG<B>) -> Tensor<B, 2, Float> {
384        let input_nr_betas = betas.betas.dims()[0];
385        let model_nr_betas = self.shape_dirs.shape().dims[1];
386        let nr_betas = input_nr_betas.min(model_nr_betas);
387        #[allow(clippy::single_range_in_vec_init)]
388        let betas_sliced = betas.betas.clone().slice([0..nr_betas]);
389        let shape_dirs_sliced = self.shape_dirs.clone().slice([0..self.shape_dirs.dims()[0], 0..nr_betas]);
390        let v_beta_offsets = shape_dirs_sliced.matmul(betas_sliced.reshape([-1, 1]));
391        let v_beta_offsets_reshaped = v_beta_offsets.reshape([NUM_VERTS, 3]);
392        let verts_t_pose = v_beta_offsets_reshaped.add(self.verts_template.clone());
393        verts_t_pose
394    }
395    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
396        self.joint_regressor.clone().matmul(verts_t_pose)
397    }
398    #[allow(clippy::missing_panics_doc)]
399    fn compute_pose_correctives(&self, pose: &PoseG<B>) -> Tensor<B, 2, Float> {
400        if let Some(pose_dirs) = &self.pose_dirs {
401            let full_pose = &pose.joint_poses;
402            assert!(
403                full_pose.dims()[0] == NUM_JOINTS + 1,
404                "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
405                full_pose.dims()[0],
406                NUM_JOINTS + 1
407            );
408            let b_pose_feature = self.compute_pose_feature(pose);
409            let b_pose_feature = b_pose_feature.reshape([NUM_JOINTS * 9, 1]);
410            let new_pose_dirs = pose_dirs.clone();
411            let all_pose_offsets = new_pose_dirs.matmul(b_pose_feature);
412            all_pose_offsets.reshape([NUM_VERTS, 3])
413        } else {
414            Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &self.device)
415        }
416    }
417    #[allow(clippy::missing_panics_doc)]
418    fn compute_pose_feature(&self, pose: &PoseG<B>) -> Tensor<B, 1> {
419        let full_pose = &pose.joint_poses;
420        assert!(
421            full_pose.dims()[0] == NUM_JOINTS + 1,
422            "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
423            full_pose.dims()[0],
424            NUM_JOINTS + 1
425        );
426        let rot_mats = batch_rodrigues_burn_3(full_pose);
427        let identity = Tensor::<B, 2>::eye(3, &self.device());
428        (rot_mats.clone().slice([1..rot_mats.dims()[0], 0..3, 0..3]) - identity.unsqueeze_dim(0)).reshape([NUM_JOINTS * 9])
429    }
430    #[allow(clippy::missing_panics_doc)]
431    #[allow(non_snake_case)]
432    #[allow(clippy::cast_precision_loss)]
433    #[allow(clippy::cast_sign_loss)]
434    #[allow(clippy::too_many_lines)]
435    #[allow(clippy::similar_names)]
436    fn apply_pose(
437        &self,
438        verts_t_pose: &Tensor<B, 2, Float>,
439        joints: &Tensor<B, 2, Float>,
440        lbs_weights: &Tensor<B, 2, Float>,
441        pose: &PoseG<B>,
442    ) -> (Tensor<B, 2, Float>, Tensor<B, 2, Float>) {
443        assert!(
444            verts_t_pose.shape().dims[0] == lbs_weights.shape().dims[0],
445            "Verts and LBS weights should match"
446        );
447        let full_pose = &pose.joint_poses;
448        assert!(
449            full_pose.dims()[0] == NUM_JOINTS + 1,
450            "The pose does not have the correct number of joints for this model."
451        );
452        let full_pose: Tensor<B, 2> = pose.joint_poses.clone();
453        let rot_mats_t = batch_rodrigues_burn_3(&full_pose);
454        let (posed_joints, rel_transforms) = batch_rigid_transform_burn_fast(
455            self.parent_idx_per_joint.clone(),
456            &self.parent_idx_per_joint_nd,
457            rot_mats_t,
458            joints.clone(),
459        );
460        let nr_verts = verts_t_pose.shape().dims[0];
461        let A = rel_transforms.reshape([NUM_JOINTS + 1, 16]);
462        let T = lbs_weights.clone().matmul(A).reshape([nr_verts, 4, 4]);
463        let ones = Tensor::ones([nr_verts, 1], &self.device);
464        let v_posed_h = Tensor::cat(vec![verts_t_pose.clone(), ones], 1).unsqueeze_dim(2);
465        let verts_final_h = T.matmul(v_posed_h).squeeze(2);
466        let verts_final = verts_final_h.slice([0..nr_verts, 0..3]);
467        let trans_pose = pose.global_trans.clone().reshape([1, 3]);
468        let mut verts_final = verts_final.clone() + trans_pose.clone();
469        let mut posed_joints = posed_joints.clone() + trans_pose.clone();
470        if pose.up_axis == UpAxis::Z {
471            let vcol0: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 0..1]).squeeze(1);
472            let vcol1: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 1..2]).squeeze(1);
473            let vcol2: Tensor<B, 1> = verts_final.clone().slice([0..nr_verts, 2..3]).squeeze(1);
474            let verts_new_col1 = vcol2;
475            let verts_new_col2 = vcol1.mul_scalar(-1.0);
476            verts_final = Tensor::stack::<2>(vec![vcol0, verts_new_col1, verts_new_col2], 1);
477            let nr_joints = posed_joints.shape().dims[0];
478            let jcol0: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 0..1]).squeeze(1);
479            let jcol1: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 1..2]).squeeze(1);
480            let jcol2: Tensor<B, 1> = posed_joints.clone().slice([0..nr_joints, 2..3]).squeeze(1);
481            let joints_new_col1 = jcol2;
482            let joints_new_col2 = jcol1.mul_scalar(-1.0);
483            posed_joints = Tensor::stack::<2>(vec![jcol0, joints_new_col1, joints_new_col2], 1);
484        }
485        (verts_final, posed_joints)
486    }
487    fn faces(&self) -> &Tensor<B, 2, Int> {
488        &self.faces
489    }
490    fn faces_uv(&self) -> &Tensor<B, 2, Int> {
491        &self.faces_uv_mesh
492    }
493    fn uv(&self) -> &Tensor<B, 2, Float> {
494        &self.uv
495    }
496    fn lbs_weights(&self) -> Tensor<B, 2, Float> {
497        self.lbs_weights.clone()
498    }
499    fn lbs_weights_split(&self) -> Tensor<B, 2, Float> {
500        self.lbs_weights_split.clone()
501    }
502    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int> {
503        self.idx_vuv_2_vnouv.clone()
504    }
505    fn idx_split_2_merged_vec(&self) -> &Vec<usize> {
506        &self.idx_vuv_2_vnouv_vec
507    }
508    fn set_pose_dirs(&mut self, pose_dirs: Tensor<B, 2, Float>) {
509        self.pose_dirs = Some(pose_dirs);
510    }
511    fn get_pose_dirs(&self) -> Tensor<B, 2, Float> {
512        if let Some(pose_dirs_tensor) = self.pose_dirs.clone() {
513            pose_dirs_tensor
514        } else {
515            panic!("pose_dirs is not available!");
516        }
517    }
518    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>> {
519        self.expression_dirs.clone()
520    }
521    fn vertex_face_csr(&self) -> Option<VertexFaceCSRBurn<B>> {
522        Some(self.vertex_face_csr.clone())
523    }
524    fn vertex_face_uv_csr(&self) -> Option<VertexFaceCSRBurn<B>> {
525        Some(self.vertex_face_uv_csr.clone())
526    }
527}
528pub type SmplXGPU = SmplXGPUG<AppBackend>;