pub fn batch_rigid_transform( parent_idx_per_joint: Vec<u32>, rot_mats: &Array3<f32>, joints: &Array2<f32>, num_joints: usize, ) -> (Array2<f32>, Array3<f32>)