// smpl_utils/numerical.rs

1use burn::tensor::{backend::Backend, Float, Int, Tensor};
2use gloss_utils::nshare::{RefNdarray2, ToNalgebra};
3use nalgebra as na;
4use nalgebra::clamp;
5use ndarray as nd;
6use ndarray::prelude::*;
7use std::{
8    f32::consts::PI,
9    ops::{Div, SubAssign},
10};
/// Parses a `#RRGGBB` (or `RRGGBB`) hex colour string into its red, green and blue bytes.
///
/// Leading `#` characters are stripped, then only the first six hex digits are read; anything
/// after them (e.g. an alpha pair) is ignored. A channel whose two digits are missing, are not
/// valid hexadecimal, or do not fall on character boundaries decodes to `0` instead of panicking.
pub fn hex_to_rgb(hex: &str) -> (u8, u8, u8) {
    let hex = hex.trim_start_matches('#');
    // `str::get` yields `None` for out-of-range or non-char-boundary ranges, so short or
    // non-ASCII input degrades to 0 exactly like an unparsable channel.
    let channel = |range: std::ops::Range<usize>| {
        hex.get(range)
            .and_then(|digits| u8::from_str_radix(digits, 16).ok())
            .unwrap_or(0)
    };
    (channel(0..2), channel(2..4), channel(4..6))
}
18pub fn hex_to_rgb_f32(hex: &str) -> (f32, f32, f32) {
19    let (r, g, b) = hex_to_rgb(hex);
20    (f32::from(r) / 255.0, f32::from(g) / 255.0, f32::from(b) / 255.0)
21}
/// Moves `cur_angle` towards `other_angle` (radians) by the fraction `other_w` along the
/// shortest arc.
///
/// When `|other_angle - cur_angle| > PI` the difference is wrapped into `[-PI, PI)` first, for
/// any number of whole turns, so the interpolation goes the short way round even if the inputs
/// are several revolutions apart. Differences already within `[-PI, PI]` are used unchanged.
/// `_cur_w` is unused and kept only for interface compatibility. The result is not
/// re-normalised into any particular range.
pub fn interpolate_angle(cur_angle: f32, other_angle: f32, _cur_w: f32, other_w: f32) -> f32 {
    let full_turn = 2.0 * PI;
    let mut diff = other_angle - cur_angle;
    if diff.abs() > PI {
        // A single +-2*PI shift only suffices while |diff| <= 3*PI; `rem_euclid` removes any
        // number of whole turns.
        diff = (diff + PI).rem_euclid(full_turn) - PI;
    }
    cur_angle + other_w * diff
}
33pub fn interpolate_angle_tensor<B: Backend>(cur_angle: Tensor<B, 1>, other_angle: Tensor<B, 1>, _cur_w: f32, other_w: f32) -> Tensor<B, 1> {
34    let mut diff = other_angle - cur_angle.clone();
35    assert!(cur_angle.dims() == [1]);
36    let abs_diff = diff.clone().abs();
37    let needs_adjustment = abs_diff.greater_elem(PI);
38    let two_pi = Tensor::<B, 1>::from_floats([2.0 * PI], &cur_angle.device());
39    let neg_two_pi = Tensor::<B, 1>::from_floats([-2.0 * PI], &cur_angle.device());
40    let positive_mask = diff.clone().greater_elem(0.0);
41    let negative_mask = diff.clone().lower_elem(0.0);
42    let pos_adjustment = positive_mask.clone().float() * neg_two_pi.clone();
43    let neg_adjustment = negative_mask.clone().float() * two_pi.clone();
44    let total_adjustment = pos_adjustment + neg_adjustment;
45    let adjustment = needs_adjustment.float() * total_adjustment;
46    diff = diff + adjustment;
47    cur_angle + other_w * diff
48}
49pub fn axis_angle_to_quaternion<B: Backend>(axis_angle: Tensor<B, 2>) -> Tensor<B, 2> {
50    let eps = 1e-6f32;
51    let angle: Tensor<B, 1> = axis_angle.clone().powf_scalar(2.0).sum_dim(1).squeeze_dims(&[1]).sqrt();
52    let denom = angle.clone().unsqueeze_dim(1) + eps;
53    let axis = axis_angle / denom;
54    let half_angle = angle * 0.5;
55    let cos_half = half_angle.clone().cos();
56    let sin_half = half_angle.sin();
57    let qxyz = axis.clone().slice([0..axis.dims()[0], 0..3]) * sin_half.clone().unsqueeze_dim(1);
58    let qw = cos_half.unsqueeze_dim(1);
59    Tensor::cat(vec![qxyz, qw], 1)
60}
61pub fn quaternion_to_axis_angle<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
62    let eps = 1e-6f32;
63    let nr_rows = quat.dims()[0];
64    let qxyz = quat.clone().slice([0..nr_rows, 0..3]);
65    let qw = quat.slice([0..nr_rows, 3..4]).squeeze(1);
66    let vec_norm = qxyz.clone().powf_scalar(2.0).sum_dim(1).sqrt().squeeze_dims(&[1]);
67    let abs_qw = qw.abs();
68    let safe_qw = abs_qw.clone() + eps;
69    let half_angle_tan = vec_norm.clone() / safe_qw;
70    let small_rotation_mask = abs_qw.greater_elem(0.9);
71    let small_angle_approx = 2.0 * vec_norm.clone();
72    let x = half_angle_tan.clone();
73    let atan_approx = x.clone() / (1.0 + 0.28 * x.powf_scalar(2.0));
74    let large_angle_approx = 2.0 * atan_approx;
75    let small_mask_float = small_rotation_mask.clone().float();
76    let angle: Tensor<B, 1> = small_mask_float.clone() * small_angle_approx + (1.0 - small_mask_float) * large_angle_approx;
77    let small_angle_mask = vec_norm.clone().lower_elem(eps);
78    let safe_vec_norm = vec_norm.clone() + eps;
79    let angle_over_norm = angle.unsqueeze_dim(1) / safe_vec_norm.unsqueeze_dim(1);
80    let axis_angle = qxyz * angle_over_norm;
81    let small_angle_mask_3d = small_angle_mask.float().unsqueeze_dim(1);
82    (1.0 - small_angle_mask_3d) * axis_angle
83}
84pub fn quaternion_to_axis_angle_fast<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
85    let eps = 1e-6f32;
86    let nr_rows = quat.dims()[0];
87    let qxyz = quat.clone().slice([0..nr_rows, 0..3]);
88    let qw = quat.slice([0..nr_rows, 3..4]).squeeze(1);
89    let one_minus_w: Tensor<B, 1> = 1.0 - qw.clone();
90    let sqrt_term = one_minus_w.sqrt();
91    let acos_w = sqrt_term * (1.570_728_8 + qw.clone() * (-0.212_114_4 + qw.clone() * (0.074_261_0 + qw.clone() * -0.018_729_3)));
92    let angle: Tensor<B, 1> = 2.0 * acos_w;
93    let one_min_square: Tensor<B, 1> = 1.0 - qw.clone() * qw.clone();
94    let sin_half_angle: Tensor<B, 1> = one_min_square.sqrt();
95    let denom = sin_half_angle + eps;
96    let axis = qxyz / denom.unsqueeze_dim(1);
97    axis * angle.unsqueeze_dim(1)
98}
99pub fn quaternion_interpolate_slerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
100    let eps = 1e-6f32;
101    let lhs_norm = lhs.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
102    let other_norm = other.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
103    let lhs_normalized = lhs / lhs_norm;
104    let other_normalized = other / other_norm;
105    let dot: Tensor<B, 1> = (lhs_normalized.clone() * other_normalized.clone()).sum_dim(1).squeeze_dims(&[1]);
106    let negative_dot_mask = dot.clone().lower_elem(0.0);
107    let negative_dot_mask_float: Tensor<B, 1> = negative_dot_mask.clone().float();
108    let dot_mask_float: Tensor<B, 1> = 1.0 - negative_dot_mask_float;
109    let sign_corrected_other =
110        negative_dot_mask.clone().float().unsqueeze_dim(1) * (-other_normalized.clone()) + dot_mask_float.unsqueeze_dim(1) * other_normalized.clone();
111    let corrected_dot: Tensor<B, 1> = dot.clone().abs();
112    let close_threshold = 0.9995f32;
113    let very_close_mask = corrected_dot.clone().greater_elem(close_threshold);
114    let lerp_result = lhs_normalized.clone() * (1.0 - other_weight) + sign_corrected_other.clone() * other_weight;
115    let lerp_norm = lerp_result.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
116    let lerp_normalized = lerp_result / lerp_norm;
117    let one_minus_dot_sq: Tensor<B, 1> = 1.0 - corrected_dot.clone().powf_scalar(2.0);
118    let sqrt_term = one_minus_dot_sq.sqrt();
119    let safe_dot = corrected_dot.clone() + eps;
120    let ratio = sqrt_term / safe_dot;
121    let theta_approx: Tensor<B, 1> = ratio.clone() / (1.0 + 0.28 * ratio.clone().powf_scalar(2.0));
122    let sin_theta = theta_approx.clone().sin();
123    let safe_sin_theta = sin_theta.clone() + eps;
124    let weight_lhs = ((1.0 - other_weight) * theta_approx.clone()).sin() / safe_sin_theta.clone();
125    let weight_other = (other_weight * theta_approx).sin() / safe_sin_theta;
126    let slerp_result = lhs_normalized.clone() * weight_lhs.unsqueeze_dim(1) + sign_corrected_other * weight_other.unsqueeze_dim(1);
127    let inv_very_close_mask_float: Tensor<B, 1> = 1.0 - very_close_mask.clone().float();
128    very_close_mask.clone().float().unsqueeze_dim(1) * lerp_normalized + inv_very_close_mask_float.unsqueeze_dim(1) * slerp_result
129}
130pub fn quaternion_interpolate_lerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
131    let eps = 1e-6f32;
132    let dot: Tensor<B, 1> = (lhs.clone() * other.clone()).sum_dim(1).squeeze_dims(&[1]);
133    let negative_dot_mask = dot.lower_elem(0.0);
134    let negative_dot_mask_float: Tensor<B, 1> = negative_dot_mask.float();
135    let positive_dot_mask_float: Tensor<B, 1> = 1.0 - negative_dot_mask_float.clone();
136    let sign_corrected_other = negative_dot_mask_float.clone().unsqueeze_dim(1) * (-other.clone()) + positive_dot_mask_float.unsqueeze_dim(1) * other;
137    let lerp_result = lhs * (1.0 - other_weight) + sign_corrected_other * other_weight;
138    let lerp_norm_sq = lerp_result.clone().powf_scalar(2.0).sum_dim(1);
139    lerp_result / (lerp_norm_sq.sqrt() + eps)
140}
141pub fn quaternion_interpolate_lerp_fast<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
142    let eps = 1e-6f32;
143    let lerp_result = lhs * (1.0 - other_weight) + other * other_weight;
144    let lerp_norm_sq = lerp_result.clone().powf_scalar(2.0).sum_dim(1);
145    lerp_result / (lerp_norm_sq.sqrt() + eps)
146}
147pub fn map(value: f32, in_min: f32, in_max: f32, out_min: f32, out_max: f32) -> f32 {
148    let value_clamped = clamp(value, in_min, in_max);
149    out_min + (out_max - out_min) * (value_clamped - in_min) / (in_max - in_min)
150}
151pub fn smootherstep(low: f32, high: f32, val: f32) -> f32 {
152    let t = map(val, low, high, 0.0, 1.0);
153    t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
154}
155pub fn batch_rodrigues(full_pose: &nd::Array2<f32>) -> nd::Array3<f32> {
156    let mut rotations_per_join = ndarray::Array3::<f32>::zeros((full_pose.shape()[0], 3, 3));
157    for (idx, v) in full_pose.axis_iter(nd::Axis(0)).enumerate() {
158        let angle = v.iter().map(|x| x * x).sum::<f32>().sqrt();
159        let rot_dir = full_pose.row(idx).to_owned().div(angle + 1e-6);
160        let cos = angle.cos();
161        let sin = angle.sin();
162        let (rx, ry, rz) = (rot_dir[0], rot_dir[1], rot_dir[2]);
163        let k = array![[0.0, -rz, ry], [rz, 0.0, -rx], [-ry, rx, 0.0]];
164        let identity = ndarray::Array2::<f32>::eye(3);
165        let rot_mat = identity + sin * k.clone() + (1.0 - cos) * k.dot(&k);
166        rotations_per_join.slice_mut(s![idx, .., ..]).assign(&rot_mat);
167    }
168    rotations_per_join
169}
170#[allow(clippy::let_and_return)]
171pub fn batch_rodrigues_burn<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
172    let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &full_pose.device());
173    let angle = full_pose.clone().powf_scalar(2.0).sum_dim(1).sqrt();
174    let denom = angle.clone() + eps.unsqueeze_dim(0);
175    let k = full_pose.clone() / denom;
176    let kx: Tensor<B, 1> = k.clone().slice_dim(1, 0..1).squeeze(1);
177    let ky: Tensor<B, 1> = k.clone().slice_dim(1, 1..2).squeeze(1);
178    let kz: Tensor<B, 1> = k.clone().slice_dim(1, 2..3).squeeze(1);
179    let zero: Tensor<B, 2> = Tensor::<B, 1, Float>::zeros_like(&kx).unsqueeze_dim(1);
180    let k11 = zero.clone();
181    let k12 = -kz.clone().unsqueeze_dim(1);
182    let k13 = ky.clone().unsqueeze_dim(1);
183    let k21 = kz.clone().unsqueeze_dim(1);
184    let k22 = zero.clone();
185    let k23 = -kx.clone().unsqueeze_dim(1);
186    let k31 = -ky.clone().unsqueeze_dim(1);
187    let k32 = kx.clone().unsqueeze_dim(1);
188    let k33 = zero;
189    let k_mat = Tensor::cat(
190        vec![
191            Tensor::cat(vec![k11, k12, k13], 1),
192            Tensor::cat(vec![k21, k22, k23], 1),
193            Tensor::cat(vec![k31, k32, k33], 1),
194        ],
195        1,
196    )
197    .reshape([-1, 3, 3]);
198    let cos = angle.clone().cos().unsqueeze_dim(2);
199    let sin = angle.clone().sin().unsqueeze_dim(2);
200    let eye = Tensor::<B, 2, Float>::eye(3, &full_pose.device()).unsqueeze_dim(0);
201    let eye = eye.repeat(&[full_pose.dims()[0], 1, 1]);
202    let k_sq = k_mat.clone().matmul(k_mat.clone());
203    let rot_mat = eye + sin * k_mat + (Tensor::ones_like(&cos) - cos) * k_sq;
204    rot_mat
205}
206#[allow(clippy::let_and_return)]
207pub fn batch_rodrigues_burn_2<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
208    let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &full_pose.device());
209    let angle = full_pose.clone().powf_scalar(2.0).sum_dim(1).sqrt().squeeze(1);
210    let denom = angle.clone().unsqueeze_dim(1) + eps.unsqueeze_dim(0);
211    let k = full_pose.clone() / denom;
212    let kx: Tensor<B, 1> = k.clone().slice_dim(1, 0..1).squeeze(1);
213    let ky: Tensor<B, 1> = k.clone().slice_dim(1, 1..2).squeeze(1);
214    let kz: Tensor<B, 1> = k.clone().slice_dim(1, 2..3).squeeze(1);
215    let cos = angle.clone().cos();
216    let sin = angle.clone().sin();
217    let one = Tensor::<B, 1, Float>::ones_like(&cos);
218    let one_minus_cos = one.clone() - cos.clone();
219    let r11 = cos.clone() + one_minus_cos.clone() * kx.clone() * kx.clone();
220    let r12 = one_minus_cos.clone() * kx.clone() * ky.clone() - sin.clone() * kz.clone();
221    let r13 = one_minus_cos.clone() * kx.clone() * kz.clone() + sin.clone() * ky.clone();
222    let r21 = one_minus_cos.clone() * ky.clone() * kx.clone() + sin.clone() * kz.clone();
223    let r22 = cos.clone() + one_minus_cos.clone() * ky.clone() * ky.clone();
224    let r23 = one_minus_cos.clone() * ky.clone() * kz.clone() - sin.clone() * kx.clone();
225    let r31 = one_minus_cos.clone() * kz.clone() * kx.clone() - sin.clone() * ky.clone();
226    let r32 = one_minus_cos.clone() * kz.clone() * ky.clone() + sin.clone() * kx.clone();
227    let r33 = cos.clone() + one_minus_cos.clone() * kz.clone() * kz.clone();
228    let rot_mat = Tensor::stack(
229        vec![
230            Tensor::stack::<2>(vec![r11, r12, r13], 1),
231            Tensor::stack::<2>(vec![r21, r22, r23], 1),
232            Tensor::stack::<2>(vec![r31, r32, r33], 1),
233        ],
234        1,
235    );
236    rot_mat
237}
238#[allow(clippy::let_and_return)]
239pub fn batch_rodrigues_burn_3<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
240    let device = full_pose.device();
241    let angle: Tensor<B, 1> = full_pose.clone().powi_scalar(2).sum_dim(1).sqrt().squeeze(1);
242    let denom = angle.clone().unsqueeze_dim(1) + 1e-6;
243    let k = full_pose.clone() / denom;
244    let k_3_1 = k.clone().unsqueeze_dim(2);
245    let k_1_3 = k.clone().unsqueeze_dim(1);
246    let kk_t = k_3_1 * k_1_3;
247    let kx = k.clone().slice_dim(1, 0..1).squeeze(1);
248    let ky = k.clone().slice_dim(1, 1..2).squeeze(1);
249    let kz = k.clone().slice_dim(1, 2..3).squeeze(1);
250    let zero = Tensor::<B, 1, Float>::zeros_like(&kx);
251    let row1 = Tensor::stack::<2>(vec![zero.clone(), -kz.clone(), ky.clone()], 1);
252    let row2 = Tensor::stack::<2>(vec![kz.clone(), zero.clone(), -kx.clone()], 1);
253    let row3 = Tensor::stack::<2>(vec![-ky.clone(), kx.clone(), zero.clone()], 1);
254    let k = Tensor::stack(vec![row1, row2, row3], 1);
255    let cos: Tensor<B, 3> = angle.clone().cos().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
256    let sin: Tensor<B, 3> = angle.clone().sin().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
257    let one_minus_cos = 1.0 - cos.clone();
258    let eye = Tensor::<B, 2, Float>::eye(3, &device).unsqueeze_dim(0);
259    let rot = cos * eye + one_minus_cos * kk_t + sin * k;
260    rot
261}
262pub fn euler2angleaxis(euler_x: f32, euler_y: f32, euler_z: f32) -> na::Vector3<f32> {
263    let c1 = f32::cos(euler_x / 2.0);
264    let c2 = f32::cos(euler_y / 2.0);
265    let c3 = f32::cos(euler_z / 2.0);
266    let s1: f32 = f32::sin(euler_x / 2.0);
267    let s2 = f32::sin(euler_y / 2.0);
268    let s3 = f32::sin(euler_z / 2.0);
269    let rot = na::Quaternion::new(
270        c1 * c2 * c3 - s1 * s2 * s3,
271        s1 * c2 * c3 + c1 * s2 * s3,
272        c1 * s2 * c3 - s1 * c2 * s3,
273        c1 * c2 * s3 + s1 * s2 * c3,
274    );
275    let rot = na::UnitQuaternion::new_normalize(rot);
276    rot.scaled_axis()
277}
278/// Interpolates between two axis angles using a slerp
279pub fn interpolate_axis_angle(this_axis: &nd::Array1<f32>, other_axis: &nd::Array1<f32>, other_weight: f32) -> nd::Array1<f32> {
280    let this_axis_na = this_axis.clone().into_nalgebra();
281    let other_axis_na = other_axis.clone().into_nalgebra();
282    let cur_r = na::Rotation3::new(this_axis_na.fixed_rows(0));
283    let other_r = na::Rotation3::new(other_axis_na.fixed_rows(0));
284    let new_r = cur_r.slerp(&other_r, other_weight);
285    let axis_angle = new_r.scaled_axis();
286    let new_axis_angle_nd = array![axis_angle.x, axis_angle.y, axis_angle.z];
287    new_axis_angle_nd
288}
289/// Interpolates betwen batch of axis angles where the batch is shape
290/// [``nr_joints``, 3]
291pub fn interpolate_axis_angle_batch(this_axis: &nd::Array2<f32>, other_axis: &nd::Array2<f32>, other_weight: f32) -> nd::Array2<f32> {
292    let this_axis_na = this_axis.clone().into_nalgebra();
293    let other_axis_na = other_axis.clone().into_nalgebra();
294    let mut new_axis_angles = nd::Array2::<f32>::zeros(this_axis_na.shape());
295    for ((this_axis, other_axis), mut new_joint) in this_axis_na
296        .row_iter()
297        .zip(other_axis_na.row_iter())
298        .zip(new_axis_angles.axis_iter_mut(nd::Axis(0)))
299    {
300        let cur_r = na::Rotation3::new(this_axis.transpose().fixed_rows(0));
301        let other_r = na::Rotation3::new(other_axis.transpose().fixed_rows(0));
302        let new_r = cur_r.slerp(&other_r, other_weight);
303        let axis_angle = new_r.scaled_axis();
304        new_joint.assign(&array![axis_angle.x, axis_angle.y, axis_angle.z]);
305    }
306    new_axis_angles
307}
308#[allow(clippy::missing_panics_doc)]
309#[allow(clippy::similar_names)]
310#[allow(clippy::cast_sign_loss)]
311pub fn batch_rigid_transform(
312    parent_idx_per_joint: Vec<u32>,
313    rot_mats: &nd::Array3<f32>,
314    joints: &nd::Array2<f32>,
315    num_joints: usize,
316) -> (nd::Array2<f32>, nd::Array3<f32>) {
317    let mut rel_joints = joints.clone();
318    let parent_idx_data_u32 = parent_idx_per_joint;
319    let parent_idx_per_joint = nd::Array1::from_vec(parent_idx_data_u32);
320    for (idx_cur, idx_parent) in parent_idx_per_joint.iter().enumerate().skip(1) {
321        let parent_joint_position = joints.row(*idx_parent as usize);
322        rel_joints.row_mut(idx_cur).sub_assign(&parent_joint_position);
323    }
324    let mut transforms_mat = ndarray::Array3::<f32>::zeros((num_joints + 1, 4, 4));
325    for idx in 0..=num_joints {
326        let rot = rot_mats.slice(s![idx, .., ..]).to_owned();
327        let t = rel_joints.row(idx).to_owned();
328        transforms_mat.slice_mut(s![idx, 0..3, 0..3]).assign(&rot);
329        transforms_mat.slice_mut(s![idx, 0..3, 3]).assign(&t);
330        transforms_mat.slice_mut(s![idx, 3, 0..4]).assign(&array![0.0, 0.0, 0.0, 1.0]);
331    }
332    let mut transform_chain = Vec::new();
333    transform_chain.push(transforms_mat.slice(s![0, 0..4, 0..4]).to_owned().into_shape_with_order((4, 4)).unwrap());
334    for i in 1..=num_joints {
335        let mat_1 = &transform_chain[parent_idx_per_joint[[i]] as usize];
336        let mat_2 = transforms_mat.slice(s![i, 0..4, 0..4]);
337        let curr_res = mat_1.dot(&mat_2);
338        transform_chain.push(curr_res);
339    }
340    let mut posed_joints = joints.clone();
341    for (i, tf) in transform_chain.iter().enumerate() {
342        let t = tf.slice(s![0..3, 3]);
343        posed_joints.row_mut(i).assign(&t);
344    }
345    let mut rel_transforms = ndarray::Array3::<f32>::zeros((num_joints + 1, 4, 4));
346    for (i, transform) in transform_chain.iter().enumerate() {
347        let (jx, jy, jz) = (joints.row(i)[0], joints.row(i)[1], joints.row(i)[2]);
348        let joint_homogen = array![jx, jy, jz, 0.0];
349        let transformed_joint = transform.dot(&joint_homogen);
350        let mut transformed_joint_4 = nd::Array2::<f32>::zeros((4, 4));
351        transformed_joint_4.slice_mut(s![0..4, 3]).assign(&transformed_joint);
352        transformed_joint_4 = transform - transformed_joint_4;
353        rel_transforms.slice_mut(s![i, .., ..]).assign(&transformed_joint_4);
354    }
355    (posed_joints, rel_transforms)
356}
357/// Burn-only batch rigid transform
358pub fn batch_rigid_transform_burn<B: Backend>(
359    parent_idx_per_joint_t: Tensor<B, 1, Int>,
360    parent_idx_per_joint: &nd::Array1<u32>,
361    rot_mats: Tensor<B, 3>,
362    joints: Tensor<B, 2>,
363) -> (Tensor<B, 2>, Tensor<B, 3>) {
364    let num_joints = joints.dims()[0];
365    let parent_idx_per_joint_t = parent_idx_per_joint_t.slice_fill(0..1, 0);
366    let parent_joints = joints.clone().select(0, parent_idx_per_joint_t);
367    let rel_joints = joints.clone() - parent_joints;
368    let rel_joints = rel_joints.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));
369    let eye_row = Tensor::zeros([num_joints, 1, 4], &joints.device());
370    let eye_row = eye_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
371    let t_col = rel_joints.reshape([num_joints, 3, 1]);
372    let upper = Tensor::cat(vec![rot_mats, t_col], 2);
373    let transforms = Tensor::cat(vec![upper, eye_row], 1);
374    let mut transform_chain: Vec<Tensor<B, 2>> = Vec::new();
375    #[allow(clippy::needless_range_loop)]
376    #[allow(clippy::single_range_in_vec_init)]
377    #[allow(clippy::range_plus_one)]
378    for j in 0..num_joints {
379        let parent = parent_idx_per_joint[j] as usize;
380        let t_j = transforms.clone().slice([j..j + 1]);
381        let t_j = t_j.squeeze(0);
382        if j == 0 {
383            transform_chain.push(t_j);
384        } else {
385            let parent_t = transform_chain[parent].clone();
386            transform_chain.push(parent_t.matmul(t_j));
387        }
388    }
389    let transform_chain = Tensor::stack(transform_chain, 0);
390    let posed_joints = transform_chain.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);
391    let joints_homo = joints.pad((0, 1, 0, 0), 0.0);
392    let joints_homo = joints_homo.unsqueeze_dim(2);
393    let transformed_joint: Tensor<B, 2> = transform_chain.clone().matmul(joints_homo).squeeze(2);
394    let mut transformed_joint_4 = Tensor::zeros_like(&transform_chain.clone());
395    transformed_joint_4 = transformed_joint_4.slice_assign([0..num_joints, 0..4, 3..4], transformed_joint.unsqueeze_dim(2));
396    let rel_transforms = transform_chain - transformed_joint_4;
397    (posed_joints, rel_transforms)
398}
399/// Faster Burn implementation of `batch_rigid_transform` (single-skeleton version)
400/// - `parent_idx_per_joint_t`: Tensor<Int> shape [J] (index tensor)
401/// - `parent_idx_per_joint`: `ndarray::Array1`<u32> (cpu-side parent indices, same shape)
402/// - `rot_mats`: Tensor shape [J,3,3]
403/// - joints: Tensor shape [J,3]
404///   instead of doing a sequential loop over the 55 joints to accumulate the transforms, we do log J iterations.
405///   Assume the tree is like
406///   root (0)
407///     |
408///     1
409///     |
410///     2
411///     |
412///     3
413///   On the first iteration, each joint knows about its local transform
414///   chain[0] = L[0]    (root, special case)
415///   chain[1] = L[1]
416///   chain[2] = L[2]
417///   chain[3] = L[3]
418///   Then we accumulate the transform to the parent
419///   chain[j] = chain[parent[j]] · chain[j]
420///   root (0)          chain[0] = L[0]
421///     1               chain[1] = L[0]·L[1]
422///     2               chain[2] = L[1]·L[2]
423///     3               chain[3] = L[2]·L[3]
424///   Then we accumulate the transform to the grandparent
425///   chain[j] = chain[parent^2[j]] · chain[j]
426///   root (0)          chain[0] = L[0]
427///     1               chain[1] = L[0]·L[1]
428///     2               chain[2] = L[0]·L[1]·L[2]
429///     3               chain[3] = L[0]·L[1]·L[2]·L[3]
430pub fn batch_rigid_transform_burn_fast<B: Backend>(
431    mut parent_idx_per_joint_t: Tensor<B, 1, Int>,
432    _parent_idx_per_joint: &nd::Array1<u32>,
433    rot_mats: Tensor<B, 3>,
434    joints: Tensor<B, 2>,
435) -> (Tensor<B, 2>, Tensor<B, 3>) {
436    let num_joints = joints.dims()[0];
437    parent_idx_per_joint_t = parent_idx_per_joint_t.clone().slice_fill(0..1, 0);
438    let parent_joints = joints.clone().select(0, parent_idx_per_joint_t.clone());
439    let mut rel_joints = joints.clone() - parent_joints;
440    rel_joints = rel_joints.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));
441    let t_col = rel_joints.reshape([num_joints, 3, 1]);
442    let upper = Tensor::cat(vec![rot_mats, t_col], 2);
443    let mut eye_row = Tensor::zeros([num_joints, 1, 4], &joints.device());
444    eye_row = eye_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
445    let transforms = Tensor::cat(vec![upper, eye_row], 1);
446    let mut transform_chain = transforms.clone();
447    let identity = Tensor::eye(4, &joints.device()).unsqueeze_dim(0);
448    transform_chain = transform_chain.slice_assign([0..1, 0..4, 0..4], identity.clone());
449    let mut parent_pow = parent_idx_per_joint_t.clone();
450    #[allow(clippy::cast_possible_truncation)]
451    #[allow(clippy::cast_sign_loss)]
452    #[allow(clippy::cast_precision_loss)]
453    let max_steps = if num_joints <= 1 {
454        0usize
455    } else {
456        (num_joints as f32).log2().ceil() as usize
457    };
458    for _ in 0..max_steps {
459        let parent_transforms = transform_chain.clone().select(0, parent_pow.clone());
460        let new_chain = parent_transforms.matmul(transform_chain.clone());
461        parent_pow = parent_pow.clone().select(0, parent_pow.clone());
462        transform_chain = new_chain;
463    }
464    let root_transform = transforms.clone().slice([0..1, 0..4, 0..4]);
465    let transform_chain = root_transform.matmul(transform_chain);
466    let posed_joints = transform_chain.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);
467    let joints_homo = joints.pad((0, 1, 0, 0), 0.0).unsqueeze_dim(2);
468    let transformed_joint: Tensor<B, 2> = transform_chain.clone().matmul(joints_homo).squeeze(2);
469    let mut transformed_joint_4 = Tensor::zeros_like(&transform_chain.clone());
470    transformed_joint_4 = transformed_joint_4.slice_assign([0..num_joints, 0..4, 3..4], transformed_joint.unsqueeze_dim(2));
471    let rel_transforms = transform_chain - transformed_joint_4;
472    (posed_joints, rel_transforms)
473}
474/// Converts a 2D array of quaternions of shape Nx4 (each row being a quaternion in format xyzw) and a 2D array of translations of shape Nx3 to extrinsics as an 3D array of Nx4x4
475pub fn extract_extrinsics_from_rot_trans(translations: &ndarray::Array2<f32>, rotations: &ndarray::Array2<f32>) -> ndarray::Array3<f32> {
476    let num_frames = translations.shape()[0].min(rotations.shape()[0]);
477    let mut extrinsics = ndarray::Array3::<f32>::zeros((num_frames, 4, 4));
478    for frame in 0..num_frames {
479        let trans = nalgebra::Vector3::new(translations[(frame, 0)], translations[(frame, 1)], translations[(frame, 2)]);
480        let quat = nalgebra::UnitQuaternion::new_normalize(nalgebra::Quaternion::new(
481            rotations[(frame, 3)],
482            rotations[(frame, 0)],
483            rotations[(frame, 1)],
484            rotations[(frame, 2)],
485        ));
486        let transform = nalgebra::Isometry3::from_parts(trans.into(), quat);
487        let matrix_4x4 = transform.to_homogeneous();
488        extrinsics.slice_mut(s![frame, .., ..]).assign(&matrix_4x4.ref_ndarray2());
489    }
490    extrinsics
491}