smpl_core/smpl_x/
smpl_x_gpu.rs

1use crate::{
2    common::{
3        betas::Betas,
4        expression::Expression,
5        outputs::SmplOutputDynamic,
6        pose::Pose,
7        smpl_model::{FaceModel, SmplCacheDynamic, SmplModel},
8        smpl_options::SmplOptions,
9        types::{Gender, SmplType, UpAxis},
10    },
11    conversions::pose_remap::PoseRemap,
12};
13use burn::tensor::{backend::Backend, Float, Int, Tensor};
14use gloss_utils::bshare::{tensor_to_data_float, tensor_to_data_int, ToBurn};
15use gloss_utils::nshare::ToNalgebra;
16use log::{info, warn};
17use nalgebra as na;
18use ndarray as nd;
19use ndarray::prelude::*;
20use ndarray_npy::NpzReader;
21use smpl_utils::numerical::{batch_rigid_transform, batch_rodrigues};
22use smpl_utils::{array::Gather2D, io::FileLoader};
23use std::ops::Sub;
24use std::{
25    any::Any,
26    io::{Read, Seek},
27};
/// Number of body joints. The root joint is not included; code in this file
/// uses `NUM_JOINTS + 1` whenever the full joint count is needed.
pub const NUM_BODY_JOINTS: usize = 21;
/// Number of joints per hand.
pub const NUM_HAND_JOINTS: usize = 15;
/// Number of face joints.
pub const NUM_FACE_JOINTS: usize = 3;
/// Total number of non-root joints: body, both hands and face.
pub const NUM_JOINTS: usize = NUM_BODY_JOINTS + NUM_HAND_JOINTS * 2 + NUM_FACE_JOINTS;
/// Joint index of the neck (presumably in the SMPL-X joint ordering; confirm against the model files).
pub const NECK_IDX: usize = 12;
/// Number of vertices of the mesh without UV seams.
pub const NUM_VERTS: usize = 10475;
/// Number of vertices of the UV mesh.
pub const NUM_VERTS_UV_MESH: usize = 11307;
/// Number of triangles.
pub const NUM_FACES: usize = 20908;
/// Number of shape components.
pub const SHAPE_SPACE_DIM: usize = 300;
/// Number of expression components.
pub const EXPRESSION_SPACE_DIM: usize = 100;
/// Size of the combined shape + expression space.
pub const FULL_SHAPE_SPACE_DIM: usize = SHAPE_SPACE_DIM + EXPRESSION_SPACE_DIM;
/// Length of the pose feature: one flattened 3x3 rotation per non-root joint.
pub const NUM_POSE_BLEND_SHAPES: usize = NUM_JOINTS * 9;
40use burn::backend::{Candle, NdArray, Wgpu};
/// SMPL-X model whose burn backend is chosen at runtime.
///
/// Each variant wraps a [`SmplXGPU`] instantiated for one concrete backend.
#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
pub enum SmplXDynamic {
    /// Model living on the `NdArray` backend.
    NdArray(SmplXGPU<NdArray>),
    /// Model living on the `Wgpu` backend.
    Wgpu(SmplXGPU<Wgpu>),
    /// Model living on the `Candle` backend.
    Candle(SmplXGPU<Candle>),
}
48#[allow(clippy::return_self_not_must_use)]
49impl SmplXDynamic {
50    #[cfg(not(target_arch = "wasm32"))]
51    pub fn new_from_npz(models: &SmplCacheDynamic, path: &str, gender: Gender, max_num_betas: usize, num_expression_components: usize) -> Self {
52        match models {
53            SmplCacheDynamic::Wgpu(_) => {
54                info!("Initializing with Wgpu Backend");
55                let model = SmplXGPU::<Wgpu>::new_from_npz(path, gender, max_num_betas, num_expression_components);
56                SmplXDynamic::Wgpu(model)
57            }
58            SmplCacheDynamic::NdArray(_) => {
59                info!("Initializing with NdArray Backend");
60                let model = SmplXGPU::<NdArray>::new_from_npz(path, gender, max_num_betas, num_expression_components);
61                SmplXDynamic::NdArray(model)
62            }
63            SmplCacheDynamic::Candle(_) => {
64                info!("Initializing with Candle Backend");
65                let model = SmplXGPU::<Candle>::new_from_npz(path, gender, max_num_betas, num_expression_components);
66                SmplXDynamic::Candle(model)
67            }
68        }
69    }
70    pub fn new_from_reader<R: Read + Seek>(
71        models: &SmplCacheDynamic,
72        reader: R,
73        gender: Gender,
74        max_num_betas: usize,
75        max_num_expression_components: usize,
76    ) -> Self {
77        match models {
78            SmplCacheDynamic::Wgpu(_) => {
79                info!("Initializing from reader with Wgpu Backend");
80                let model = SmplXGPU::<Wgpu>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
81                SmplXDynamic::Wgpu(model)
82            }
83            SmplCacheDynamic::NdArray(_) => {
84                info!("Initializing from reader with NdArray Backend");
85                let model = SmplXGPU::<NdArray>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
86                SmplXDynamic::NdArray(model)
87            }
88            SmplCacheDynamic::Candle(_) => {
89                info!("Initializing from reader with Candle Backend");
90                let model = SmplXGPU::<Candle>::new_from_reader(reader, gender, max_num_betas, max_num_expression_components);
91                SmplXDynamic::Candle(model)
92            }
93        }
94    }
95    pub async fn new_from_npz_async(
96        models: &SmplCacheDynamic,
97        path: &str,
98        gender: Gender,
99        max_num_betas: usize,
100        num_expression_components: usize,
101    ) -> Self {
102        match models {
103            SmplCacheDynamic::Wgpu(_) => {
104                info!("Initializing with Wgpu Backend");
105                let model = SmplXGPU::<Wgpu>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
106                SmplXDynamic::Wgpu(model)
107            }
108            SmplCacheDynamic::NdArray(_) => {
109                info!("Initializing with NdArray Backend");
110                let model = SmplXGPU::<NdArray>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
111                SmplXDynamic::NdArray(model)
112            }
113            SmplCacheDynamic::Candle(_) => {
114                info!("Initializing with Candle Backend");
115                let model = SmplXGPU::<Candle>::new_from_npz_async(path, gender, max_num_betas, num_expression_components).await;
116                SmplXDynamic::Candle(model)
117            }
118        }
119    }
120}
/// SMPL-X model data stored as burn tensors on backend `B`, together with a
/// few CPU-side (nalgebra / ndarray) copies of the same data.
#[derive(Clone)]
pub struct SmplXGPU<B: Backend> {
    /// Device on which the tensors of this model were created.
    pub device: B::Device,
    /// Model family; the constructor always sets this to `SmplType::SmplX`.
    pub smpl_type: SmplType,
    /// Gender this model was loaded for.
    pub gender: Gender,
    /// Template vertex positions, added to the shape offsets (`[NUM_VERTS, 3]`).
    pub verts_template: Tensor<B, 2, Float>,
    /// Triangle indices of the mesh without UV seams.
    pub faces: Tensor<B, 2, Int>,
    /// Triangle indices of the UV mesh.
    pub faces_uv_mesh: Tensor<B, 2, Int>,
    /// UV coordinates of the UV mesh.
    pub uv: Tensor<B, 2, Float>,
    /// Shape blend directions, flattened to `[NUM_VERTS * 3, num_betas]`.
    pub shape_dirs: Tensor<B, 2, Float>,
    /// Expression blend directions, flattened to `[NUM_VERTS * 3, num_expression_components]`, if available.
    pub expression_dirs: Option<Tensor<B, 2, Float>>,
    /// Pose corrective directions, flattened to `[NUM_VERTS * 3, NUM_JOINTS * 9]`, if available.
    pub pose_dirs: Option<Tensor<B, 2, Float>>,
    /// Matrix that maps vertices to joint positions by matrix multiplication.
    pub joint_regressor: Tensor<B, 2, Float>,
    /// Parent index for each of the `NUM_JOINTS + 1` joints.
    pub parent_idx_per_joint: Tensor<B, 1, Int>,
    /// Linear blend skinning weights, one row per vertex of the mesh without UV seams.
    pub lbs_weights: Tensor<B, 2, Float>,
    /// Tensor of ones with shape `[NUM_VERTS, 1]`.
    pub verts_ones: Tensor<B, 2, Float>,
    /// For every UV-mesh vertex, the index of the corresponding vertex in the mesh without UV seams.
    pub idx_vuv_2_vnouv: Tensor<B, 1, Int>,
    /// nalgebra copy of `faces`.
    pub faces_na: na::DMatrix<u32>,
    /// nalgebra copy of `faces_uv_mesh`.
    pub faces_uv_mesh_na: na::DMatrix<u32>,
    /// nalgebra copy of `uv`.
    pub uv_na: na::DMatrix<f32>,
    /// `idx_vuv_2_vnouv` as a plain vector of indices.
    pub idx_vuv_2_vnouv_vec: Vec<usize>,
    /// Skinning weights gathered per UV-mesh vertex (rows selected through `idx_vuv_2_vnouv`).
    pub lbs_weights_split: Tensor<B, 2>,
    /// ndarray copy of `lbs_weights`.
    pub lbs_weights_nd: nd::ArcArray2<f32>,
    /// ndarray copy of `lbs_weights_split`.
    pub lbs_weights_split_nd: nd::ArcArray2<f32>,
}
146impl<B: Backend> SmplXGPU<B> {
147    /// # Panics
148    /// Will panic if the matrices don't match the expected sizes
149    #[allow(clippy::too_many_arguments)]
150    #[allow(clippy::too_many_lines)]
151    pub fn new_from_matrices(
152        gender: Gender,
153        verts_template: &nd::Array2<f32>,
154        faces: &nd::Array2<u32>,
155        faces_uv_mesh: &nd::Array2<u32>,
156        uv: &nd::Array2<f32>,
157        shape_dirs: &nd::Array3<f32>,
158        expression_dirs: Option<nd::Array3<f32>>,
159        pose_dirs: Option<nd::Array3<f32>>,
160        joint_regressor: &nd::Array2<f32>,
161        parent_idx_per_joint: &nd::Array1<u32>,
162        lbs_weights: nd::Array2<f32>,
163        max_num_betas: usize,
164        max_num_expression_components: usize,
165    ) -> Self {
166        let device = B::Device::default();
167        let b_verts_template = verts_template.to_burn(&device);
168        let b_faces = faces.to_burn(&device);
169        let b_faces_uv_mesh = faces_uv_mesh.to_burn(&device);
170        let b_uv = uv.to_burn(&device);
171        let actual_num_betas = max_num_betas.min(shape_dirs.shape()[2]);
172        let shape_dirs = shape_dirs
173            .slice_axis(Axis(2), ndarray::Slice::from(0..actual_num_betas))
174            .to_owned()
175            .into_shape_with_order((NUM_VERTS * 3, actual_num_betas))
176            .unwrap();
177        let b_shape_dirs = shape_dirs.to_burn(&device);
178        let b_expression_dirs = expression_dirs.map(|expression_dirs| {
179            let actual_num_expression_components = max_num_expression_components.min(expression_dirs.shape()[2]);
180            let expression_dirs = expression_dirs
181                .slice_axis(nd::Axis(2), nd::Slice::from(0..actual_num_expression_components))
182                .into_shape_with_order((NUM_VERTS * 3, actual_num_expression_components))
183                .unwrap()
184                .to_owned();
185            expression_dirs.to_burn(&device)
186        });
187        let b_pose_dirs = pose_dirs.map(|pose_dirs| {
188            let pose_dirs = pose_dirs.into_shape_with_order((NUM_VERTS * 3, NUM_JOINTS * 9)).unwrap();
189            pose_dirs.to_burn(&device)
190        });
191        let b_joint_regressor = joint_regressor.to_burn(&device);
192        let b_parent_idx_per_joint = parent_idx_per_joint.to_burn(&device).reshape([NUM_JOINTS + 1]);
193        let b_lbs_weights = lbs_weights.to_burn(&device);
194        #[allow(clippy::cast_possible_wrap)]
195        let faces_uv_mesh_i32: nd::Array2<i32> = faces_uv_mesh.mapv(|x| x as i32);
196        let ft: nd::ArcArray2<i32> = faces_uv_mesh_i32.into();
197        let max_v_uv_idx = *ft.iter().max_by_key(|&x| x).unwrap();
198        let max_v_uv_idx_usize = usize::try_from(max_v_uv_idx).unwrap_or_else(|_| panic!("Cannot cast max_v_uv_idx to usize"));
199        let mut idx_vuv_2_vnouv = nd::ArcArray1::<i32>::zeros(max_v_uv_idx_usize + 1);
200        for (fuv, fnouv) in ft.axis_iter(nd::Axis(0)).zip(faces.axis_iter(nd::Axis(0))) {
201            let uv_0 = fuv[[0]];
202            let uv_1 = fuv[[1]];
203            let uv_2 = fuv[[2]];
204            let nouv_0 = fnouv[[0]];
205            let nouv_1 = fnouv[[1]];
206            let nouv_2 = fnouv[[2]];
207            idx_vuv_2_vnouv[usize::try_from(uv_0).unwrap_or_else(|_| panic!("Cannot cast uv_0 to usize"))] =
208                i32::try_from(nouv_0).unwrap_or_else(|_| panic!("Cannot cast nouv_0 to i32"));
209            idx_vuv_2_vnouv[usize::try_from(uv_1).unwrap_or_else(|_| panic!("Cannot cast uv_1 to usize"))] =
210                i32::try_from(nouv_1).unwrap_or_else(|_| panic!("Cannot cast nouv_1 to i32"));
211            idx_vuv_2_vnouv[usize::try_from(uv_2).unwrap_or_else(|_| panic!("Cannot cast uv_2 to usize"))] =
212                i32::try_from(nouv_2).unwrap_or_else(|_| panic!("Cannot cast nouv_2 to i32"));
213        }
214        let idx_vuv_2_vnouv_vec: Vec<i32> = idx_vuv_2_vnouv.mapv(|x| x).into_raw_vec_and_offset().0;
215        let idx_vuv_2_vnouv_slice: &[i32] = &idx_vuv_2_vnouv_vec;
216        let b_idx_vuv_2_vnouv = Tensor::<B, 1, Int>::from_ints(idx_vuv_2_vnouv_slice, &device);
217        let idx_vuv_2_vnouv_vec: Vec<usize> = idx_vuv_2_vnouv
218            .to_vec()
219            .iter()
220            .map(|&x| usize::try_from(x).unwrap_or_else(|_| panic!("Cannot cast negative value to usize")))
221            .collect();
222        let faces_na = faces.view().into_nalgebra().clone_owned().map(|x| x);
223        let faces_uv_mesh_na = ft
224            .view()
225            .into_nalgebra()
226            .clone_owned()
227            .map(|x| u32::try_from(x).unwrap_or_else(|_| panic!("Cannot cast value to u32")));
228        let uv_na = uv.view().into_nalgebra().clone_owned();
229        let cols: Vec<usize> = (0..lbs_weights.ncols()).collect();
230        let lbs_weights_split: nd::ArcArray2<f32> = lbs_weights.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
231        let b_lbs_weights_split =
232            Tensor::<B, 1>::from_floats(lbs_weights_split.as_slice().unwrap(), &device).reshape([idx_vuv_2_vnouv_vec.len(), NUM_JOINTS + 1]);
233        let verts_ones = Tensor::<B, 2>::ones([NUM_VERTS, 1], &device);
234        let lbs_weights_nd: nd::ArcArray2<f32> = lbs_weights.into();
235        let cols: Vec<usize> = (0..lbs_weights_nd.ncols()).collect();
236        let lbs_weights_split_nd = lbs_weights_nd.to_owned().gather(&idx_vuv_2_vnouv_vec, &cols).into();
237        info!("Initialised burn on Backend: {:?}", B::name());
238        info!("Device: {:?}", &device);
239        Self {
240            smpl_type: SmplType::SmplX,
241            gender,
242            device,
243            verts_template: b_verts_template,
244            faces: b_faces,
245            faces_uv_mesh: b_faces_uv_mesh,
246            uv: b_uv,
247            shape_dirs: b_shape_dirs,
248            expression_dirs: b_expression_dirs,
249            pose_dirs: b_pose_dirs,
250            joint_regressor: b_joint_regressor,
251            parent_idx_per_joint: b_parent_idx_per_joint,
252            lbs_weights: b_lbs_weights,
253            verts_ones,
254            idx_vuv_2_vnouv: b_idx_vuv_2_vnouv,
255            faces_na,
256            faces_uv_mesh_na,
257            uv_na,
258            idx_vuv_2_vnouv_vec,
259            lbs_weights_split: b_lbs_weights_split,
260            lbs_weights_nd,
261            lbs_weights_split_nd,
262        }
263    }
264    /// # Panics
265    /// Will panic if the path cannot be opened
266    fn new_from_npz_reader<R: Read + Seek>(
267        npz: &mut NpzReader<R>,
268        gender: Gender,
269        max_num_betas: usize,
270        max_num_expression_components: usize,
271    ) -> Self {
272        let verts_template: nd::Array2<f32> = npz.by_name("v_template").unwrap();
273        let faces: nd::Array2<u32> = npz.by_name("f").unwrap();
274        let uv: nd::Array2<f32> = npz.by_name("vt").unwrap();
275        let full_shape_dirs: nd::Array3<f32> = npz.by_name("shapedirs").unwrap();
276        let (shape_dirs, expression_dirs) = if let Ok(expression_dirs) = npz.by_name("expressiondirs") {
277            (full_shape_dirs, Some(expression_dirs))
278        } else {
279            let num_available_betas = full_shape_dirs.shape()[2];
280            let num_full_betas = 300;
281            let num_betas_to_use = num_full_betas.min(max_num_betas).min(num_available_betas);
282            let shape_dirs = full_shape_dirs.slice_axis(nd::Axis(2), nd::Slice::from(0..num_betas_to_use)).to_owned();
283            let expression_dirs = if full_shape_dirs.shape()[2] > 300 {
284                Some(
285                    full_shape_dirs
286                        .slice_axis(nd::Axis(2), nd::Slice::from(300..300 + max_num_expression_components.min(100)))
287                        .to_owned(),
288                )
289            } else {
290                None
291            };
292            (shape_dirs, expression_dirs)
293        };
294        let pose_dirs: Option<nd::Array3<f32>> = npz.by_name("posedirs").ok();
295        let joint_regressor: nd::Array2<f32> = npz.by_name("J_regressor").unwrap();
296        let parent_idx_per_joint: nd::Array2<i32> = npz.by_name("kintree_table").unwrap();
297        #[allow(clippy::cast_sign_loss)]
298        let parent_idx_per_joint = parent_idx_per_joint.mapv(|x| x as u32);
299        let parent_idx_per_joint = parent_idx_per_joint
300            .slice_axis(nd::Axis(0), nd::Slice::from(0..1))
301            .to_owned()
302            .into_shape_with_order(NUM_JOINTS + 1)
303            .unwrap();
304        let lbs_weights: nd::Array2<f32> = npz.by_name("weights").unwrap();
305        let ft: nd::Array2<u32> = npz.by_name("ft").unwrap();
306        if pose_dirs.is_none() {
307            warn!("No pose_dirs loaded from npz");
308        }
309        Self::new_from_matrices(
310            gender,
311            &verts_template,
312            &faces,
313            &ft,
314            &uv,
315            &shape_dirs,
316            expression_dirs,
317            pose_dirs,
318            &joint_regressor,
319            &parent_idx_per_joint,
320            lbs_weights,
321            max_num_betas,
322            max_num_expression_components,
323        )
324    }
325    #[cfg(not(target_arch = "wasm32"))]
326    /// # Panics
327    /// Will panic if the path cannot be opened
328    pub fn new_from_npz(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
329        let mut npz = NpzReader::new(std::fs::File::open(model_path).unwrap()).unwrap();
330        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
331    }
332    /// # Panics
333    /// Will panic if the path cannot be opened
334    /// Will panic if the translation and rotation do not cover the same number
335    /// of timesteps
336    #[allow(clippy::cast_possible_truncation)]
337    pub async fn new_from_npz_async(model_path: &str, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
338        let reader = FileLoader::open(model_path).await;
339        let mut npz = NpzReader::new(reader).unwrap();
340        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
341    }
342    /// # Panics
343    /// Will panic if the path cannot be opened
344    /// Will panic if the translation and rotation do not cover the same number
345    /// of timesteps
346    #[allow(clippy::cast_possible_truncation)]
347    pub fn new_from_reader<R: Read + Seek>(reader: R, gender: Gender, max_num_betas: usize, max_num_expression_components: usize) -> Self {
348        let mut npz = NpzReader::new(reader).unwrap();
349        Self::new_from_npz_reader(&mut npz, gender, max_num_betas, max_num_expression_components)
350    }
351    /// # Panics
352    /// Will panic if the path cannot be opened
353    /// Will panic if the translation and rotation do not cover the same number
354    /// of timesteps
355    #[allow(clippy::cast_possible_truncation)]
356    pub fn read_pose_dirs_from_reader<R: Read + Seek>(reader: R, device: &B::Device) -> Tensor<B, 2, Float> {
357        let mut npz = NpzReader::new(reader).unwrap();
358        let pose_dirs: Option<nd::Array3<f32>> = Some(npz.by_name("pose_dirs").unwrap());
359        let b_pose_dirs =
360            pose_dirs.map(|pose_dirs| Tensor::<B, 1>::from_floats(pose_dirs.as_slice().unwrap(), device).reshape([NUM_VERTS * 3, NUM_JOINTS * 9]));
361        b_pose_dirs.unwrap()
362    }
363}
364impl<B: Backend> FaceModel<B> for SmplXGPU<B>
365where
366    B::IntTensorPrimitive<1>: Sync,
367    B::IntTensorPrimitive<2>: Sync,
368    B::FloatTensorPrimitive<2>: Sync,
369    B::QuantizedTensorPrimitive<2>: std::marker::Sync,
370{
371    #[allow(clippy::missing_panics_doc)]
372    #[allow(non_snake_case)]
373    #[allow(clippy::let_and_return)]
374    fn expression2offsets(&self, expression: &Expression) -> Tensor<B, 2, Float> {
375        let device = self.verts_template.device();
376        let offsets = if let Some(ref expression_dirs) = self.expression_dirs {
377            let input_nr_expression_coeffs = expression.expr_coeffs.len();
378            let model_nr_expression_coeffs = expression_dirs.shape().dims[1];
379            let nr_expression_coeffs = input_nr_expression_coeffs.min(model_nr_expression_coeffs);
380            let expr_sliced = expression.expr_coeffs.slice(s![0..nr_expression_coeffs]);
381            let expr_tensor = Tensor::<B, 1, Float>::from_floats(expr_sliced.as_slice().unwrap(), &device);
382            let expression_dirs_sliced = expression_dirs.clone().slice([0..expression_dirs.dims()[0], 0..nr_expression_coeffs]);
383            let v_expr_offsets = expression_dirs_sliced.matmul(expr_tensor.reshape([-1, 1]));
384            v_expr_offsets.reshape([NUM_VERTS, 3])
385        } else {
386            Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &device)
387        };
388        offsets
389    }
390}
391impl<B: Backend> SmplModel<B> for SmplXGPU<B>
392where
393    B::FloatTensorPrimitive<2>: Sync,
394    B::IntTensorPrimitive<2>: Sync,
395    B::IntTensorPrimitive<1>: Sync,
396    B::QuantizedTensorPrimitive<1>: std::marker::Sync,
397    B::QuantizedTensorPrimitive<2>: std::marker::Sync,
398{
    /// Returns a boxed clone of this model as a trait object.
    fn clone_dyn(&self) -> Box<dyn SmplModel<B>> {
        Box::new(self.clone())
    }
    /// Exposes the model as `Any` so callers can downcast to the concrete type.
    fn as_any(&self) -> &dyn Any {
        self
    }
    /// Returns the model family stored in this model.
    fn smpl_type(&self) -> SmplType {
        self.smpl_type
    }
    /// Returns the gender this model was created with.
    fn gender(&self) -> Gender {
        self.gender
    }
    /// Returns this model as its own face model.
    fn get_face_model(&self) -> &dyn FaceModel<B> {
        self
    }
414    #[allow(clippy::missing_panics_doc)]
415    #[allow(non_snake_case)]
416    fn forward(&self, options: &SmplOptions, betas: &Betas, pose_raw: &Pose, expression: Option<&Expression>) -> SmplOutputDynamic<B> {
417        let mut verts_t_pose = self.betas2verts(betas);
418        if let Some(expression) = expression {
419            verts_t_pose = verts_t_pose + self.expression2offsets(expression);
420        }
421        let pose_remap = PoseRemap::new(pose_raw.smpl_type, SmplType::SmplX);
422        let pose = pose_remap.remap(pose_raw);
423        let joints_t_pose = self.verts2joints(verts_t_pose.clone());
424        if options.enable_pose_corrective {
425            let verts_offset = self.compute_pose_correctives(&pose);
426            verts_t_pose = verts_t_pose + verts_offset;
427        }
428        let (verts_posed_nd, _, _, joints_posed) = self.apply_pose(&verts_t_pose, None, None, &joints_t_pose, &self.lbs_weights, &pose);
429        SmplOutputDynamic {
430            verts: verts_posed_nd,
431            faces: self.faces.clone(),
432            normals: None,
433            uvs: None,
434            joints: joints_posed,
435        }
436    }
437    fn create_body_with_uv(&self, smpl_merged: &SmplOutputDynamic<B>) -> SmplOutputDynamic<B> {
438        let cols_tensor = Tensor::<B, 1, Int>::from_ints([0, 1, 2], &self.device);
439        let mapping_tensor = self.idx_split_2_merged();
440        let v_burn_split = smpl_merged.verts.clone().select(0, mapping_tensor.clone());
441        let v_burn_split = v_burn_split.select(1, cols_tensor.clone());
442        let n_burn_split = smpl_merged
443            .normals
444            .as_ref()
445            .map(|n| n.clone().select(0, mapping_tensor).select(1, cols_tensor));
446        SmplOutputDynamic {
447            verts: v_burn_split,
448            faces: self.faces_uv_mesh.clone(),
449            normals: n_burn_split,
450            uvs: Some(self.uv.clone()),
451            joints: smpl_merged.joints.clone(),
452        }
453    }
454    #[allow(clippy::missing_panics_doc)]
455    #[allow(non_snake_case)]
456    #[allow(clippy::let_and_return)]
457    fn betas2verts(&self, betas: &Betas) -> Tensor<B, 2, Float> {
458        let device = self.verts_template.device();
459        let input_nr_betas = betas.betas.len();
460        let model_nr_betas = self.shape_dirs.shape().dims[1];
461        let nr_betas = input_nr_betas.min(model_nr_betas);
462        let betas_sliced = betas.betas.slice(s![0..nr_betas]);
463        let betas_tensor = Tensor::<B, 1, Float>::from_floats(betas_sliced.as_slice().unwrap(), &device);
464        let shape_dirs_sliced = self.shape_dirs.clone().slice([0..self.shape_dirs.dims()[0], 0..nr_betas]);
465        let v_beta_offsets = shape_dirs_sliced.matmul(betas_tensor.reshape([-1, 1]));
466        let v_beta_offsets_reshaped = v_beta_offsets.reshape([NUM_VERTS, 3]);
467        let verts_t_pose = v_beta_offsets_reshaped.add(self.verts_template.clone());
468        verts_t_pose
469    }
470    fn verts2joints(&self, verts_t_pose: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
471        self.joint_regressor.clone().matmul(verts_t_pose)
472    }
473    #[allow(clippy::missing_panics_doc)]
474    fn compute_pose_correctives(&self, pose: &Pose) -> Tensor<B, 2, Float> {
475        let offsets = if let Some(pose_dirs) = &self.pose_dirs {
476            let full_pose = &pose.joint_poses;
477            assert!(
478                full_pose.dim().0 == NUM_JOINTS + 1,
479                "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
480                full_pose.dim().0,
481                NUM_JOINTS + 1
482            );
483            let mut rot_mats = batch_rodrigues(full_pose);
484            let identity = ndarray::Array2::<f32>::eye(3);
485            let pose_feature = (rot_mats.slice_mut(s![1.., .., ..]).sub(&identity))
486                .into_shape_with_order(NUM_JOINTS * 9)
487                .unwrap();
488            let b_pose_feature = Tensor::<B, 1, Float>::from_floats(pose_feature.as_slice().unwrap(), &self.device).reshape([NUM_JOINTS * 9, 1]);
489            let new_pose_dirs = pose_dirs.clone();
490            let all_pose_offsets = new_pose_dirs.matmul(b_pose_feature);
491            all_pose_offsets.reshape([NUM_VERTS, 3])
492        } else {
493            Tensor::<B, 2, Float>::zeros([NUM_VERTS, 3], &self.device)
494        };
495        offsets
496    }
497    #[allow(clippy::missing_panics_doc)]
498    fn compute_pose_feature(&self, pose: &Pose) -> nd::Array1<f32> {
499        let full_pose = &pose.joint_poses;
500        assert!(
501            full_pose.dim().0 == NUM_JOINTS + 1,
502            "The pose does not have the correct number of joints for this model. Maybe you need to add a PoseRemapper component?\n {:?} != {:?}",
503            full_pose.dim().0,
504            NUM_JOINTS + 1
505        );
506        let mut rot_mats = batch_rodrigues(full_pose);
507        let identity = ndarray::Array2::<f32>::eye(3);
508        let pose_feature = (rot_mats.slice_mut(s![1.., .., ..]).sub(&identity))
509            .into_shape_with_order(NUM_JOINTS * 9)
510            .unwrap();
511        pose_feature
512    }
513    #[allow(clippy::missing_panics_doc)]
514    #[allow(non_snake_case)]
515    #[allow(clippy::cast_precision_loss)]
516    #[allow(clippy::cast_sign_loss)]
517    #[allow(clippy::too_many_lines)]
518    #[allow(clippy::similar_names)]
519    fn apply_pose(
520        &self,
521        verts_t_pose: &Tensor<B, 2, Float>,
522        normals: Option<&Tensor<B, 2, Float>>,
523        tangents: Option<&Tensor<B, 2, Float>>,
524        joints: &Tensor<B, 2, Float>,
525        lbs_weights: &Tensor<B, 2, Float>,
526        pose: &Pose,
527    ) -> (
528        Tensor<B, 2, Float>,
529        Option<Tensor<B, 2, Float>>,
530        Option<Tensor<B, 2, Float>>,
531        Tensor<B, 2, Float>,
532    ) {
533        assert!(
534            verts_t_pose.shape().dims[0] == lbs_weights.shape().dims[0],
535            "Verts and LBS weights should match"
536        );
537        let full_pose = &pose.joint_poses;
538        assert!(
539            full_pose.shape()[0] == NUM_JOINTS + 1,
540            "The pose does not have the correct number of joints for this model."
541        );
542        let rot_mats = batch_rodrigues(full_pose);
543        let joints_data = tensor_to_data_float(joints);
544        let shape = joints.shape().dims;
545        let nd_joints = nd::Array2::from_shape_vec((shape[0], shape[1]), joints_data).expect("Shape mismatch during tensor to ndarray conversion");
546        let parent_idx_data_i32: Vec<i32> = tensor_to_data_int(&self.parent_idx_per_joint);
547        let parent_idx_data_u32: Vec<u32> = parent_idx_data_i32.into_iter().map(|x| x as u32).collect();
548        let (posed_joints_nd, rel_transforms_nd) = batch_rigid_transform(parent_idx_data_u32, &rot_mats, &nd_joints, NUM_JOINTS);
549        let posed_joints = posed_joints_nd.to_burn(&self.device);
550        let nr_verts = verts_t_pose.shape().dims[0];
551        let nr_joints = posed_joints.shape().dims[0];
552        let v_posed = verts_t_pose.clone();
553        let W = lbs_weights;
554        let A_nd = rel_transforms_nd.into_shape_with_order((NUM_JOINTS + 1, 16)).unwrap();
555        let A = A_nd.to_burn(&self.device);
556        let T = W.clone().matmul(A).reshape([nr_verts, 4, 4]);
557        let dims_3 = 3;
558        let rot0 = T.clone().slice([0..nr_verts, 0..1, 0..dims_3]).squeeze(1);
559        let rot1 = T.clone().slice([0..nr_verts, 1..2, 0..dims_3]).squeeze(1);
560        let rot2 = T.clone().slice([0..nr_verts, 2..3, 0..dims_3]).squeeze(1);
561        let trans: Tensor<B, 2> = T.slice([0..nr_verts, 0..dims_3, 3..4]).squeeze(2);
562        let verts_final_0 = rot0.clone().mul(v_posed.clone()).sum_dim(1);
563        let verts_final_1 = rot1.clone().mul(v_posed.clone()).sum_dim(1);
564        let verts_final_2 = rot2.clone().mul(v_posed.clone()).sum_dim(1);
565        let verts_final = Tensor::<B, 1>::stack(vec![verts_final_0.squeeze(1), verts_final_1.squeeze(1), verts_final_2.squeeze(1)], 1);
566        let verts_final = verts_final.add(trans);
567        let mut normals_final = if let Some(normals) = normals {
568            let normals_0 = rot0.clone().mul(normals.clone()).sum_dim(1);
569            let normals_1 = rot1.clone().mul(normals.clone()).sum_dim(1);
570            let normals_2 = rot2.clone().mul(normals.clone()).sum_dim(1);
571            let normals_final = Tensor::<B, 1>::stack(vec![normals_0.squeeze(1), normals_1.squeeze(1), normals_2.squeeze(1)], 1);
572            Some(normals_final)
573        } else {
574            None
575        };
576        let mut tangents_final = if let Some(tangents) = tangents {
577            let tangents_3 = tangents.clone().slice([0..nr_verts, 0..3]);
578            let tangents_0 = rot0.mul(tangents_3.clone()).sum_dim(1);
579            let tangents_1 = rot1.mul(tangents_3.clone()).sum_dim(1);
580            let tangents_2 = rot2.mul(tangents_3.clone()).sum_dim(1);
581            let handedness: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 3..4]).squeeze(1);
582            let tangents_final = Tensor::<B, 1>::stack(vec![tangents_0.squeeze(1), tangents_1.squeeze(1), tangents_2.squeeze(1), handedness], 1);
583            Some(tangents_final)
584        } else {
585            None
586        };
587        let trans_pose_nd = pose.global_trans.clone();
588        let trans_pose = trans_pose_nd.to_burn(&self.device);
589        let trans_pose_broadcasted_v = trans_pose.clone().reshape([1, 3]).expand(verts_final.shape());
590        let trans_pose_broadcasted_p = trans_pose.reshape([1, 3]).expand(posed_joints.shape());
591        let mut verts_final_modified = verts_final.clone().add(trans_pose_broadcasted_v.clone());
592        let mut posed_joints_modified = posed_joints.clone().add(trans_pose_broadcasted_p.clone());
593        if pose.up_axis == UpAxis::Z {
594            let vcol0: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 0..1]).squeeze(1);
595            let vcol1: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 1..2]).squeeze(1);
596            let vcol2: Tensor<B, 1> = verts_final_modified.clone().slice([0..nr_verts, 2..3]).squeeze(1);
597            let verts_new_col1 = vcol2;
598            let verts_new_col2 = vcol1.mul_scalar(-1.0);
599            verts_final_modified = Tensor::stack::<2>(vec![vcol0, verts_new_col1, verts_new_col2], 1);
600            let jcol0: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 0..1]).squeeze(1);
601            let jcol1: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 1..2]).squeeze(1);
602            let jcol2: Tensor<B, 1> = posed_joints_modified.clone().slice([0..nr_joints, 2..3]).squeeze(1);
603            let joints_new_col1 = jcol2;
604            let joints_new_col2 = jcol1.mul_scalar(-1.0);
605            posed_joints_modified = Tensor::stack::<2>(vec![jcol0, joints_new_col1, joints_new_col2], 1);
606            if let Some(ref mut normals) = normals_final {
607                let ncol0: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 0..1]).squeeze(1);
608                let ncol1: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 1..2]).squeeze(1);
609                let ncol2: Tensor<B, 1> = normals.clone().slice([0..nr_verts, 2..3]).squeeze(1);
610                let normals_new_col1 = ncol2;
611                let normals_new_col2 = ncol1.mul_scalar(-1.0);
612                let normals_final_modified = Tensor::stack::<2>(vec![ncol0, normals_new_col1, normals_new_col2], 1);
613                *normals = normals_final_modified;
614            }
615            if let Some(ref mut tangents) = tangents_final {
616                let tcol0: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 0..1]).squeeze(1);
617                let tcol1: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 1..2]).squeeze(1);
618                let tcol2: Tensor<B, 1> = tangents.clone().slice([0..nr_verts, 2..3]).squeeze(1);
619                let tangents_new_col1 = tcol2;
620                let tangents_new_col2 = tcol1.mul_scalar(-1.0);
621                let handedness = tangents.clone().slice([0..nr_verts, 3..4]).squeeze(1);
622                let tangents_final_modified = Tensor::stack::<2>(vec![tcol0, tangents_new_col1, tangents_new_col2, handedness], 1);
623                *tangents = tangents_final_modified;
624            }
625        }
626        (verts_final_modified, normals_final.clone(), tangents_final.clone(), posed_joints_modified)
627    }
    /// Returns the faces of the mesh without UV seams.
    fn faces(&self) -> &Tensor<B, 2, Int> {
        &self.faces
    }
    /// Returns the faces of the UV mesh.
    fn faces_uv(&self) -> &Tensor<B, 2, Int> {
        &self.faces_uv_mesh
    }
    /// Returns the UV coordinates of the UV mesh.
    fn uv(&self) -> &Tensor<B, 2, Float> {
        &self.uv
    }
    /// Returns a clone of the skinning weights of the mesh without UV seams.
    fn lbs_weights(&self) -> Tensor<B, 2, Float> {
        self.lbs_weights.clone()
    }
    /// Returns a clone of the skinning weights gathered per UV-mesh vertex.
    fn lbs_weights_split(&self) -> Tensor<B, 2, Float> {
        self.lbs_weights_split.clone()
    }
    /// Returns a clone of the lookup tensor that maps UV-mesh vertex indices to
    /// vertex indices of the mesh without UV seams.
    fn idx_split_2_merged(&self) -> Tensor<B, 1, Int> {
        self.idx_vuv_2_vnouv.clone()
    }
    /// Returns the same lookup as `idx_split_2_merged`, as a vector of indices.
    fn idx_split_2_merged_vec(&self) -> &Vec<usize> {
        &self.idx_vuv_2_vnouv_vec
    }
    /// Replaces the pose corrective directions of this model.
    fn set_pose_dirs(&mut self, pose_dirs: Tensor<B, 2, Float>) {
        self.pose_dirs = Some(pose_dirs);
    }
652    fn get_pose_dirs(&self) -> Tensor<B, 2, Float> {
653        if let Some(pose_dirs_tensor) = self.pose_dirs.clone() {
654            pose_dirs_tensor
655        } else {
656            panic!("pose_dirs is not available!");
657        }
658    }
    /// Returns a clone of the expression directions, or `None` if the model has none.
    fn get_expression_dirs(&self) -> Option<Tensor<B, 2, Float>> {
        self.expression_dirs.clone()
    }
662}