1use burn::tensor::{backend::Backend, Float, Int, Tensor};
2use gloss_utils::nshare::{RefNdarray2, ToNalgebra};
3use nalgebra as na;
4use nalgebra::clamp;
5use ndarray as nd;
6use ndarray::prelude::*;
7use std::{
8 f32::consts::PI,
9 ops::{Div, SubAssign},
10};
/// Parses a hexadecimal colour string such as `"#ff8000"` or `"ff8000"` into 8-bit RGB channels.
///
/// Any number of leading `#` characters is stripped. The channels are then read from byte ranges
/// `0..2`, `2..4` and `4..6` of the remainder; anything after the sixth byte is ignored.
///
/// The function never panics: a channel that is missing (input too short), that does not fall on
/// a character boundary, or that is not valid base-16 is returned as `0`.
pub fn hex_to_rgb(hex: &str) -> (u8, u8, u8) {
    let digits = hex.trim_start_matches('#');
    // `str::get` returns `None` instead of panicking when the range is out of bounds or splits a
    // multi-byte character, so malformed input degrades to 0 just like unparsable digits do.
    let channel = |range: std::ops::Range<usize>| -> u8 {
        digits
            .get(range)
            .and_then(|pair| u8::from_str_radix(pair, 16).ok())
            .unwrap_or(0)
    };
    (channel(0..2), channel(2..4), channel(4..6))
}
18pub fn hex_to_rgb_f32(hex: &str) -> (f32, f32, f32) {
19 let (r, g, b) = hex_to_rgb(hex);
20 (f32::from(r) / 255.0, f32::from(g) / 255.0, f32::from(b) / 255.0)
21}
/// Moves `cur_angle` towards `other_angle` by the fraction `other_w`, going the short way around
/// the circle.
///
/// The difference `other_angle - cur_angle` is shifted by one full turn (`2π`) towards zero when
/// its magnitude exceeds `π`; only a single turn is ever removed. `_cur_w` is accepted but not
/// used. The result is not wrapped back into any particular range.
pub fn interpolate_angle(cur_angle: f32, other_angle: f32, _cur_w: f32, other_w: f32) -> f32 {
    let full_turn = 2.0 * PI;
    let raw_delta = other_angle - cur_angle;
    let shortest_delta = if raw_delta.abs() > PI {
        if raw_delta > 0.0 {
            raw_delta - full_turn
        } else {
            raw_delta + full_turn
        }
    } else {
        raw_delta
    };
    cur_angle + other_w * shortest_delta
}
33pub fn interpolate_angle_tensor<B: Backend>(cur_angle: Tensor<B, 1>, other_angle: Tensor<B, 1>, _cur_w: f32, other_w: f32) -> Tensor<B, 1> {
34 let mut diff = other_angle - cur_angle.clone();
35 assert!(cur_angle.dims() == [1]);
36 let abs_diff = diff.clone().abs();
37 let needs_adjustment = abs_diff.greater_elem(PI);
38 let two_pi = Tensor::<B, 1>::from_floats([2.0 * PI], &cur_angle.device());
39 let neg_two_pi = Tensor::<B, 1>::from_floats([-2.0 * PI], &cur_angle.device());
40 let positive_mask = diff.clone().greater_elem(0.0);
41 let negative_mask = diff.clone().lower_elem(0.0);
42 let pos_adjustment = positive_mask.clone().float() * neg_two_pi.clone();
43 let neg_adjustment = negative_mask.clone().float() * two_pi.clone();
44 let total_adjustment = pos_adjustment + neg_adjustment;
45 let adjustment = needs_adjustment.float() * total_adjustment;
46 diff = diff + adjustment;
47 cur_angle + other_w * diff
48}
49pub fn axis_angle_to_quaternion<B: Backend>(axis_angle: Tensor<B, 2>) -> Tensor<B, 2> {
50 let eps = 1e-6f32;
51 let angle: Tensor<B, 1> = axis_angle.clone().powf_scalar(2.0).sum_dim(1).squeeze_dims(&[1]).sqrt();
52 let denom = angle.clone().unsqueeze_dim(1) + eps;
53 let axis = axis_angle / denom;
54 let half_angle = angle * 0.5;
55 let cos_half = half_angle.clone().cos();
56 let sin_half = half_angle.sin();
57 let qxyz = axis.clone().slice([0..axis.dims()[0], 0..3]) * sin_half.clone().unsqueeze_dim(1);
58 let qw = cos_half.unsqueeze_dim(1);
59 Tensor::cat(vec![qxyz, qw], 1)
60}
61pub fn quaternion_to_axis_angle<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
62 let eps = 1e-6f32;
63 let nr_rows = quat.dims()[0];
64 let qxyz = quat.clone().slice([0..nr_rows, 0..3]);
65 let qw = quat.slice([0..nr_rows, 3..4]).squeeze(1);
66 let vec_norm = qxyz.clone().powf_scalar(2.0).sum_dim(1).sqrt().squeeze_dims(&[1]);
67 let abs_qw = qw.abs();
68 let safe_qw = abs_qw.clone() + eps;
69 let half_angle_tan = vec_norm.clone() / safe_qw;
70 let small_rotation_mask = abs_qw.greater_elem(0.9);
71 let small_angle_approx = 2.0 * vec_norm.clone();
72 let x = half_angle_tan.clone();
73 let atan_approx = x.clone() / (1.0 + 0.28 * x.powf_scalar(2.0));
74 let large_angle_approx = 2.0 * atan_approx;
75 let small_mask_float = small_rotation_mask.clone().float();
76 let angle: Tensor<B, 1> = small_mask_float.clone() * small_angle_approx + (1.0 - small_mask_float) * large_angle_approx;
77 let small_angle_mask = vec_norm.clone().lower_elem(eps);
78 let safe_vec_norm = vec_norm.clone() + eps;
79 let angle_over_norm = angle.unsqueeze_dim(1) / safe_vec_norm.unsqueeze_dim(1);
80 let axis_angle = qxyz * angle_over_norm;
81 let small_angle_mask_3d = small_angle_mask.float().unsqueeze_dim(1);
82 (1.0 - small_angle_mask_3d) * axis_angle
83}
84pub fn quaternion_to_axis_angle_fast<B: Backend>(quat: Tensor<B, 2>) -> Tensor<B, 2> {
85 let eps = 1e-6f32;
86 let nr_rows = quat.dims()[0];
87 let qxyz = quat.clone().slice([0..nr_rows, 0..3]);
88 let qw = quat.slice([0..nr_rows, 3..4]).squeeze(1);
89 let one_minus_w: Tensor<B, 1> = 1.0 - qw.clone();
90 let sqrt_term = one_minus_w.sqrt();
91 let acos_w = sqrt_term * (1.570_728_8 + qw.clone() * (-0.212_114_4 + qw.clone() * (0.074_261_0 + qw.clone() * -0.018_729_3)));
92 let angle: Tensor<B, 1> = 2.0 * acos_w;
93 let one_min_square: Tensor<B, 1> = 1.0 - qw.clone() * qw.clone();
94 let sin_half_angle: Tensor<B, 1> = one_min_square.sqrt();
95 let denom = sin_half_angle + eps;
96 let axis = qxyz / denom.unsqueeze_dim(1);
97 axis * angle.unsqueeze_dim(1)
98}
99pub fn quaternion_interpolate_slerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
100 let eps = 1e-6f32;
101 let lhs_norm = lhs.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
102 let other_norm = other.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
103 let lhs_normalized = lhs / lhs_norm;
104 let other_normalized = other / other_norm;
105 let dot: Tensor<B, 1> = (lhs_normalized.clone() * other_normalized.clone()).sum_dim(1).squeeze_dims(&[1]);
106 let negative_dot_mask = dot.clone().lower_elem(0.0);
107 let negative_dot_mask_float: Tensor<B, 1> = negative_dot_mask.clone().float();
108 let dot_mask_float: Tensor<B, 1> = 1.0 - negative_dot_mask_float;
109 let sign_corrected_other =
110 negative_dot_mask.clone().float().unsqueeze_dim(1) * (-other_normalized.clone()) + dot_mask_float.unsqueeze_dim(1) * other_normalized.clone();
111 let corrected_dot: Tensor<B, 1> = dot.clone().abs();
112 let close_threshold = 0.9995f32;
113 let very_close_mask = corrected_dot.clone().greater_elem(close_threshold);
114 let lerp_result = lhs_normalized.clone() * (1.0 - other_weight) + sign_corrected_other.clone() * other_weight;
115 let lerp_norm = lerp_result.clone().powf_scalar(2.0).sum_dim(1).sqrt() + eps;
116 let lerp_normalized = lerp_result / lerp_norm;
117 let one_minus_dot_sq: Tensor<B, 1> = 1.0 - corrected_dot.clone().powf_scalar(2.0);
118 let sqrt_term = one_minus_dot_sq.sqrt();
119 let safe_dot = corrected_dot.clone() + eps;
120 let ratio = sqrt_term / safe_dot;
121 let theta_approx: Tensor<B, 1> = ratio.clone() / (1.0 + 0.28 * ratio.clone().powf_scalar(2.0));
122 let sin_theta = theta_approx.clone().sin();
123 let safe_sin_theta = sin_theta.clone() + eps;
124 let weight_lhs = ((1.0 - other_weight) * theta_approx.clone()).sin() / safe_sin_theta.clone();
125 let weight_other = (other_weight * theta_approx).sin() / safe_sin_theta;
126 let slerp_result = lhs_normalized.clone() * weight_lhs.unsqueeze_dim(1) + sign_corrected_other * weight_other.unsqueeze_dim(1);
127 let inv_very_close_mask_float: Tensor<B, 1> = 1.0 - very_close_mask.clone().float();
128 very_close_mask.clone().float().unsqueeze_dim(1) * lerp_normalized + inv_very_close_mask_float.unsqueeze_dim(1) * slerp_result
129}
130pub fn quaternion_interpolate_lerp<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
131 let eps = 1e-6f32;
132 let dot: Tensor<B, 1> = (lhs.clone() * other.clone()).sum_dim(1).squeeze_dims(&[1]);
133 let negative_dot_mask = dot.lower_elem(0.0);
134 let negative_dot_mask_float: Tensor<B, 1> = negative_dot_mask.float();
135 let positive_dot_mask_float: Tensor<B, 1> = 1.0 - negative_dot_mask_float.clone();
136 let sign_corrected_other = negative_dot_mask_float.clone().unsqueeze_dim(1) * (-other.clone()) + positive_dot_mask_float.unsqueeze_dim(1) * other;
137 let lerp_result = lhs * (1.0 - other_weight) + sign_corrected_other * other_weight;
138 let lerp_norm_sq = lerp_result.clone().powf_scalar(2.0).sum_dim(1);
139 lerp_result / (lerp_norm_sq.sqrt() + eps)
140}
141pub fn quaternion_interpolate_lerp_fast<B: Backend>(lhs: Tensor<B, 2>, other: Tensor<B, 2>, other_weight: f32) -> Tensor<B, 2> {
142 let eps = 1e-6f32;
143 let lerp_result = lhs * (1.0 - other_weight) + other * other_weight;
144 let lerp_norm_sq = lerp_result.clone().powf_scalar(2.0).sum_dim(1);
145 lerp_result / (lerp_norm_sq.sqrt() + eps)
146}
147pub fn map(value: f32, in_min: f32, in_max: f32, out_min: f32, out_max: f32) -> f32 {
148 let value_clamped = clamp(value, in_min, in_max);
149 out_min + (out_max - out_min) * (value_clamped - in_min) / (in_max - in_min)
150}
151pub fn smootherstep(low: f32, high: f32, val: f32) -> f32 {
152 let t = map(val, low, high, 0.0, 1.0);
153 t * t * t * (t * (t * 6.0 - 15.0) + 10.0)
154}
155pub fn batch_rodrigues(full_pose: &nd::Array2<f32>) -> nd::Array3<f32> {
156 let mut rotations_per_join = ndarray::Array3::<f32>::zeros((full_pose.shape()[0], 3, 3));
157 for (idx, v) in full_pose.axis_iter(nd::Axis(0)).enumerate() {
158 let angle = v.iter().map(|x| x * x).sum::<f32>().sqrt();
159 let rot_dir = full_pose.row(idx).to_owned().div(angle + 1e-6);
160 let cos = angle.cos();
161 let sin = angle.sin();
162 let (rx, ry, rz) = (rot_dir[0], rot_dir[1], rot_dir[2]);
163 let k = array![[0.0, -rz, ry], [rz, 0.0, -rx], [-ry, rx, 0.0]];
164 let identity = ndarray::Array2::<f32>::eye(3);
165 let rot_mat = identity + sin * k.clone() + (1.0 - cos) * k.dot(&k);
166 rotations_per_join.slice_mut(s![idx, .., ..]).assign(&rot_mat);
167 }
168 rotations_per_join
169}
170#[allow(clippy::let_and_return)]
171pub fn batch_rodrigues_burn<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
172 let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &full_pose.device());
173 let angle = full_pose.clone().powf_scalar(2.0).sum_dim(1).sqrt();
174 let denom = angle.clone() + eps.unsqueeze_dim(0);
175 let k = full_pose.clone() / denom;
176 let kx: Tensor<B, 1> = k.clone().slice_dim(1, 0..1).squeeze(1);
177 let ky: Tensor<B, 1> = k.clone().slice_dim(1, 1..2).squeeze(1);
178 let kz: Tensor<B, 1> = k.clone().slice_dim(1, 2..3).squeeze(1);
179 let zero: Tensor<B, 2> = Tensor::<B, 1, Float>::zeros_like(&kx).unsqueeze_dim(1);
180 let k11 = zero.clone();
181 let k12 = -kz.clone().unsqueeze_dim(1);
182 let k13 = ky.clone().unsqueeze_dim(1);
183 let k21 = kz.clone().unsqueeze_dim(1);
184 let k22 = zero.clone();
185 let k23 = -kx.clone().unsqueeze_dim(1);
186 let k31 = -ky.clone().unsqueeze_dim(1);
187 let k32 = kx.clone().unsqueeze_dim(1);
188 let k33 = zero;
189 let k_mat = Tensor::cat(
190 vec![
191 Tensor::cat(vec![k11, k12, k13], 1),
192 Tensor::cat(vec![k21, k22, k23], 1),
193 Tensor::cat(vec![k31, k32, k33], 1),
194 ],
195 1,
196 )
197 .reshape([-1, 3, 3]);
198 let cos = angle.clone().cos().unsqueeze_dim(2);
199 let sin = angle.clone().sin().unsqueeze_dim(2);
200 let eye = Tensor::<B, 2, Float>::eye(3, &full_pose.device()).unsqueeze_dim(0);
201 let eye = eye.repeat(&[full_pose.dims()[0], 1, 1]);
202 let k_sq = k_mat.clone().matmul(k_mat.clone());
203 let rot_mat = eye + sin * k_mat + (Tensor::ones_like(&cos) - cos) * k_sq;
204 rot_mat
205}
206#[allow(clippy::let_and_return)]
207pub fn batch_rodrigues_burn_2<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
208 let eps = Tensor::<B, 1, Float>::from_floats([1e-6], &full_pose.device());
209 let angle = full_pose.clone().powf_scalar(2.0).sum_dim(1).sqrt().squeeze(1);
210 let denom = angle.clone().unsqueeze_dim(1) + eps.unsqueeze_dim(0);
211 let k = full_pose.clone() / denom;
212 let kx: Tensor<B, 1> = k.clone().slice_dim(1, 0..1).squeeze(1);
213 let ky: Tensor<B, 1> = k.clone().slice_dim(1, 1..2).squeeze(1);
214 let kz: Tensor<B, 1> = k.clone().slice_dim(1, 2..3).squeeze(1);
215 let cos = angle.clone().cos();
216 let sin = angle.clone().sin();
217 let one = Tensor::<B, 1, Float>::ones_like(&cos);
218 let one_minus_cos = one.clone() - cos.clone();
219 let r11 = cos.clone() + one_minus_cos.clone() * kx.clone() * kx.clone();
220 let r12 = one_minus_cos.clone() * kx.clone() * ky.clone() - sin.clone() * kz.clone();
221 let r13 = one_minus_cos.clone() * kx.clone() * kz.clone() + sin.clone() * ky.clone();
222 let r21 = one_minus_cos.clone() * ky.clone() * kx.clone() + sin.clone() * kz.clone();
223 let r22 = cos.clone() + one_minus_cos.clone() * ky.clone() * ky.clone();
224 let r23 = one_minus_cos.clone() * ky.clone() * kz.clone() - sin.clone() * kx.clone();
225 let r31 = one_minus_cos.clone() * kz.clone() * kx.clone() - sin.clone() * ky.clone();
226 let r32 = one_minus_cos.clone() * kz.clone() * ky.clone() + sin.clone() * kx.clone();
227 let r33 = cos.clone() + one_minus_cos.clone() * kz.clone() * kz.clone();
228 let rot_mat = Tensor::stack(
229 vec![
230 Tensor::stack::<2>(vec![r11, r12, r13], 1),
231 Tensor::stack::<2>(vec![r21, r22, r23], 1),
232 Tensor::stack::<2>(vec![r31, r32, r33], 1),
233 ],
234 1,
235 );
236 rot_mat
237}
238#[allow(clippy::let_and_return)]
239pub fn batch_rodrigues_burn_3<B: Backend>(full_pose: &Tensor<B, 2, Float>) -> Tensor<B, 3, Float> {
240 let device = full_pose.device();
241 let angle: Tensor<B, 1> = full_pose.clone().powi_scalar(2).sum_dim(1).sqrt().squeeze(1);
242 let denom = angle.clone().unsqueeze_dim(1) + 1e-6;
243 let k = full_pose.clone() / denom;
244 let k_3_1 = k.clone().unsqueeze_dim(2);
245 let k_1_3 = k.clone().unsqueeze_dim(1);
246 let kk_t = k_3_1 * k_1_3;
247 let kx = k.clone().slice_dim(1, 0..1).squeeze(1);
248 let ky = k.clone().slice_dim(1, 1..2).squeeze(1);
249 let kz = k.clone().slice_dim(1, 2..3).squeeze(1);
250 let zero = Tensor::<B, 1, Float>::zeros_like(&kx);
251 let row1 = Tensor::stack::<2>(vec![zero.clone(), -kz.clone(), ky.clone()], 1);
252 let row2 = Tensor::stack::<2>(vec![kz.clone(), zero.clone(), -kx.clone()], 1);
253 let row3 = Tensor::stack::<2>(vec![-ky.clone(), kx.clone(), zero.clone()], 1);
254 let k = Tensor::stack(vec![row1, row2, row3], 1);
255 let cos: Tensor<B, 3> = angle.clone().cos().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
256 let sin: Tensor<B, 3> = angle.clone().sin().unsqueeze_dim::<2>(1).unsqueeze_dim(2);
257 let one_minus_cos = 1.0 - cos.clone();
258 let eye = Tensor::<B, 2, Float>::eye(3, &device).unsqueeze_dim(0);
259 let rot = cos * eye + one_minus_cos * kk_t + sin * k;
260 rot
261}
262pub fn euler2angleaxis(euler_x: f32, euler_y: f32, euler_z: f32) -> na::Vector3<f32> {
263 let c1 = f32::cos(euler_x / 2.0);
264 let c2 = f32::cos(euler_y / 2.0);
265 let c3 = f32::cos(euler_z / 2.0);
266 let s1: f32 = f32::sin(euler_x / 2.0);
267 let s2 = f32::sin(euler_y / 2.0);
268 let s3 = f32::sin(euler_z / 2.0);
269 let rot = na::Quaternion::new(
270 c1 * c2 * c3 - s1 * s2 * s3,
271 s1 * c2 * c3 + c1 * s2 * s3,
272 c1 * s2 * c3 - s1 * c2 * s3,
273 c1 * c2 * s3 + s1 * s2 * c3,
274 );
275 let rot = na::UnitQuaternion::new_normalize(rot);
276 rot.scaled_axis()
277}
278pub fn interpolate_axis_angle(this_axis: &nd::Array1<f32>, other_axis: &nd::Array1<f32>, other_weight: f32) -> nd::Array1<f32> {
280 let this_axis_na = this_axis.clone().into_nalgebra();
281 let other_axis_na = other_axis.clone().into_nalgebra();
282 let cur_r = na::Rotation3::new(this_axis_na.fixed_rows(0));
283 let other_r = na::Rotation3::new(other_axis_na.fixed_rows(0));
284 let new_r = cur_r.slerp(&other_r, other_weight);
285 let axis_angle = new_r.scaled_axis();
286 let new_axis_angle_nd = array![axis_angle.x, axis_angle.y, axis_angle.z];
287 new_axis_angle_nd
288}
289pub fn interpolate_axis_angle_batch(this_axis: &nd::Array2<f32>, other_axis: &nd::Array2<f32>, other_weight: f32) -> nd::Array2<f32> {
292 let this_axis_na = this_axis.clone().into_nalgebra();
293 let other_axis_na = other_axis.clone().into_nalgebra();
294 let mut new_axis_angles = nd::Array2::<f32>::zeros(this_axis_na.shape());
295 for ((this_axis, other_axis), mut new_joint) in this_axis_na
296 .row_iter()
297 .zip(other_axis_na.row_iter())
298 .zip(new_axis_angles.axis_iter_mut(nd::Axis(0)))
299 {
300 let cur_r = na::Rotation3::new(this_axis.transpose().fixed_rows(0));
301 let other_r = na::Rotation3::new(other_axis.transpose().fixed_rows(0));
302 let new_r = cur_r.slerp(&other_r, other_weight);
303 let axis_angle = new_r.scaled_axis();
304 new_joint.assign(&array![axis_angle.x, axis_angle.y, axis_angle.z]);
305 }
306 new_axis_angles
307}
308#[allow(clippy::missing_panics_doc)]
309#[allow(clippy::similar_names)]
310#[allow(clippy::cast_sign_loss)]
311pub fn batch_rigid_transform(
312 parent_idx_per_joint: Vec<u32>,
313 rot_mats: &nd::Array3<f32>,
314 joints: &nd::Array2<f32>,
315 num_joints: usize,
316) -> (nd::Array2<f32>, nd::Array3<f32>) {
317 let mut rel_joints = joints.clone();
318 let parent_idx_data_u32 = parent_idx_per_joint;
319 let parent_idx_per_joint = nd::Array1::from_vec(parent_idx_data_u32);
320 for (idx_cur, idx_parent) in parent_idx_per_joint.iter().enumerate().skip(1) {
321 let parent_joint_position = joints.row(*idx_parent as usize);
322 rel_joints.row_mut(idx_cur).sub_assign(&parent_joint_position);
323 }
324 let mut transforms_mat = ndarray::Array3::<f32>::zeros((num_joints + 1, 4, 4));
325 for idx in 0..=num_joints {
326 let rot = rot_mats.slice(s![idx, .., ..]).to_owned();
327 let t = rel_joints.row(idx).to_owned();
328 transforms_mat.slice_mut(s![idx, 0..3, 0..3]).assign(&rot);
329 transforms_mat.slice_mut(s![idx, 0..3, 3]).assign(&t);
330 transforms_mat.slice_mut(s![idx, 3, 0..4]).assign(&array![0.0, 0.0, 0.0, 1.0]);
331 }
332 let mut transform_chain = Vec::new();
333 transform_chain.push(transforms_mat.slice(s![0, 0..4, 0..4]).to_owned().into_shape_with_order((4, 4)).unwrap());
334 for i in 1..=num_joints {
335 let mat_1 = &transform_chain[parent_idx_per_joint[[i]] as usize];
336 let mat_2 = transforms_mat.slice(s![i, 0..4, 0..4]);
337 let curr_res = mat_1.dot(&mat_2);
338 transform_chain.push(curr_res);
339 }
340 let mut posed_joints = joints.clone();
341 for (i, tf) in transform_chain.iter().enumerate() {
342 let t = tf.slice(s![0..3, 3]);
343 posed_joints.row_mut(i).assign(&t);
344 }
345 let mut rel_transforms = ndarray::Array3::<f32>::zeros((num_joints + 1, 4, 4));
346 for (i, transform) in transform_chain.iter().enumerate() {
347 let (jx, jy, jz) = (joints.row(i)[0], joints.row(i)[1], joints.row(i)[2]);
348 let joint_homogen = array![jx, jy, jz, 0.0];
349 let transformed_joint = transform.dot(&joint_homogen);
350 let mut transformed_joint_4 = nd::Array2::<f32>::zeros((4, 4));
351 transformed_joint_4.slice_mut(s![0..4, 3]).assign(&transformed_joint);
352 transformed_joint_4 = transform - transformed_joint_4;
353 rel_transforms.slice_mut(s![i, .., ..]).assign(&transformed_joint_4);
354 }
355 (posed_joints, rel_transforms)
356}
357pub fn batch_rigid_transform_burn<B: Backend>(
359 parent_idx_per_joint_t: Tensor<B, 1, Int>,
360 parent_idx_per_joint: &nd::Array1<u32>,
361 rot_mats: Tensor<B, 3>,
362 joints: Tensor<B, 2>,
363) -> (Tensor<B, 2>, Tensor<B, 3>) {
364 let num_joints = joints.dims()[0];
365 let parent_idx_per_joint_t = parent_idx_per_joint_t.slice_fill(0..1, 0);
366 let parent_joints = joints.clone().select(0, parent_idx_per_joint_t);
367 let rel_joints = joints.clone() - parent_joints;
368 let rel_joints = rel_joints.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));
369 let eye_row = Tensor::zeros([num_joints, 1, 4], &joints.device());
370 let eye_row = eye_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
371 let t_col = rel_joints.reshape([num_joints, 3, 1]);
372 let upper = Tensor::cat(vec![rot_mats, t_col], 2);
373 let transforms = Tensor::cat(vec![upper, eye_row], 1);
374 let mut transform_chain: Vec<Tensor<B, 2>> = Vec::new();
375 #[allow(clippy::needless_range_loop)]
376 #[allow(clippy::single_range_in_vec_init)]
377 #[allow(clippy::range_plus_one)]
378 for j in 0..num_joints {
379 let parent = parent_idx_per_joint[j] as usize;
380 let t_j = transforms.clone().slice([j..j + 1]);
381 let t_j = t_j.squeeze(0);
382 if j == 0 {
383 transform_chain.push(t_j);
384 } else {
385 let parent_t = transform_chain[parent].clone();
386 transform_chain.push(parent_t.matmul(t_j));
387 }
388 }
389 let transform_chain = Tensor::stack(transform_chain, 0);
390 let posed_joints = transform_chain.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);
391 let joints_homo = joints.pad((0, 1, 0, 0), 0.0);
392 let joints_homo = joints_homo.unsqueeze_dim(2);
393 let transformed_joint: Tensor<B, 2> = transform_chain.clone().matmul(joints_homo).squeeze(2);
394 let mut transformed_joint_4 = Tensor::zeros_like(&transform_chain.clone());
395 transformed_joint_4 = transformed_joint_4.slice_assign([0..num_joints, 0..4, 3..4], transformed_joint.unsqueeze_dim(2));
396 let rel_transforms = transform_chain - transformed_joint_4;
397 (posed_joints, rel_transforms)
398}
399pub fn batch_rigid_transform_burn_fast<B: Backend>(
431 mut parent_idx_per_joint_t: Tensor<B, 1, Int>,
432 _parent_idx_per_joint: &nd::Array1<u32>,
433 rot_mats: Tensor<B, 3>,
434 joints: Tensor<B, 2>,
435) -> (Tensor<B, 2>, Tensor<B, 3>) {
436 let num_joints = joints.dims()[0];
437 parent_idx_per_joint_t = parent_idx_per_joint_t.clone().slice_fill(0..1, 0);
438 let parent_joints = joints.clone().select(0, parent_idx_per_joint_t.clone());
439 let mut rel_joints = joints.clone() - parent_joints;
440 rel_joints = rel_joints.slice_assign([0..1, 0..3], joints.clone().slice([0..1, 0..3]));
441 let t_col = rel_joints.reshape([num_joints, 3, 1]);
442 let upper = Tensor::cat(vec![rot_mats, t_col], 2);
443 let mut eye_row = Tensor::zeros([num_joints, 1, 4], &joints.device());
444 eye_row = eye_row.slice_fill([0..num_joints, 0..1, 3..4], 1.0);
445 let transforms = Tensor::cat(vec![upper, eye_row], 1);
446 let mut transform_chain = transforms.clone();
447 let identity = Tensor::eye(4, &joints.device()).unsqueeze_dim(0);
448 transform_chain = transform_chain.slice_assign([0..1, 0..4, 0..4], identity.clone());
449 let mut parent_pow = parent_idx_per_joint_t.clone();
450 #[allow(clippy::cast_possible_truncation)]
451 #[allow(clippy::cast_sign_loss)]
452 #[allow(clippy::cast_precision_loss)]
453 let max_steps = if num_joints <= 1 {
454 0usize
455 } else {
456 (num_joints as f32).log2().ceil() as usize
457 };
458 for _ in 0..max_steps {
459 let parent_transforms = transform_chain.clone().select(0, parent_pow.clone());
460 let new_chain = parent_transforms.matmul(transform_chain.clone());
461 parent_pow = parent_pow.clone().select(0, parent_pow.clone());
462 transform_chain = new_chain;
463 }
464 let root_transform = transforms.clone().slice([0..1, 0..4, 0..4]);
465 let transform_chain = root_transform.matmul(transform_chain);
466 let posed_joints = transform_chain.clone().slice([0..num_joints, 0..3, 3..4]).squeeze(2);
467 let joints_homo = joints.pad((0, 1, 0, 0), 0.0).unsqueeze_dim(2);
468 let transformed_joint: Tensor<B, 2> = transform_chain.clone().matmul(joints_homo).squeeze(2);
469 let mut transformed_joint_4 = Tensor::zeros_like(&transform_chain.clone());
470 transformed_joint_4 = transformed_joint_4.slice_assign([0..num_joints, 0..4, 3..4], transformed_joint.unsqueeze_dim(2));
471 let rel_transforms = transform_chain - transformed_joint_4;
472 (posed_joints, rel_transforms)
473}
474pub fn extract_extrinsics_from_rot_trans(translations: &ndarray::Array2<f32>, rotations: &ndarray::Array2<f32>) -> ndarray::Array3<f32> {
476 let num_frames = translations.shape()[0].min(rotations.shape()[0]);
477 let mut extrinsics = ndarray::Array3::<f32>::zeros((num_frames, 4, 4));
478 for frame in 0..num_frames {
479 let trans = nalgebra::Vector3::new(translations[(frame, 0)], translations[(frame, 1)], translations[(frame, 2)]);
480 let quat = nalgebra::UnitQuaternion::new_normalize(nalgebra::Quaternion::new(
481 rotations[(frame, 3)],
482 rotations[(frame, 0)],
483 rotations[(frame, 1)],
484 rotations[(frame, 2)],
485 ));
486 let transform = nalgebra::Isometry3::from_parts(trans.into(), quat);
487 let matrix_4x4 = transform.to_homogeneous();
488 extrinsics.slice_mut(s![frame, .., ..]).assign(&matrix_4x4.ref_ndarray2());
489 }
490 extrinsics
491}