1use super::{
2 metadata::smpl_metadata,
3 pose_override::PoseOverride,
4 types::{SmplType, UpAxis},
5};
6use crate::{codec::codec::SmplCodec, common::pose_parts::PosePart, smpl_h::smpl_h, smpl_x::smpl_x};
7use gloss_utils::nshare::ToNalgebra;
8use log::warn;
9use nalgebra as na;
10use nd::concatenate;
11use ndarray as nd;
12use ndarray::prelude::*;
13use smpl_utils::numerical::interpolate_angle;
/// A single body pose for a SMPL-family model: per-joint pose parameters plus
/// a root translation, tagged with the model type and up axis it belongs to.
#[derive(Clone, Debug)]
pub struct Pose {
    /// Per-joint pose parameters, one row per joint. `new_empty` allocates
    /// `NUM_JOINTS + 1` rows of 3 columns; `interpolate` reads each row as an
    /// axis-angle vector for every model type except `SmplType::SmplPP`.
    pub joint_poses: nd::Array2<f32>,
    /// Root translation (`new_empty` creates it with 3 elements).
    pub global_trans: nd::Array1<f32>,
    /// Initialised to `false` by the constructors in this file.
    /// NOTE(review): presumably toggles pose-corrective deformation downstream — confirm.
    pub enable_pose_corrective: bool,
    /// Up axis associated with this pose (`new_from_smpl_codec` uses `UpAxis::Y`).
    pub up_axis: UpAxis,
    /// Model family this pose belongs to; it selects the metadata used by
    /// `apply_mask` and the interpolation scheme used by `interpolate`.
    pub smpl_type: SmplType,
    /// Initialised to `None` by the constructors in this file.
    /// NOTE(review): presumably keeps the pose as it was before retargeting — confirm.
    pub non_retargeted_pose: Option<Box<Pose>>,
    /// Initialised to `false` by the constructors in this file.
    /// NOTE(review): presumably marks a pose that has been retargeted — confirm.
    pub retargeted: bool,
}
25impl Pose {
26 pub fn new(joint_poses: nd::Array2<f32>, global_trans: nd::Array1<f32>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
27 Self {
28 joint_poses,
29 global_trans,
30 enable_pose_corrective: false,
31 up_axis,
32 smpl_type,
33 non_retargeted_pose: None,
34 retargeted: false,
35 }
36 }
37 pub fn new_empty(up_axis: UpAxis, smpl_type: SmplType) -> Self {
38 let joint_poses = match smpl_type {
39 SmplType::SmplX => ndarray::Array2::<f32>::zeros((smpl_x::NUM_JOINTS + 1, 3)),
40 SmplType::SmplH => ndarray::Array2::<f32>::zeros((smpl_h::NUM_JOINTS + 1, 3)),
41 _ => panic!("{smpl_type:?} is not yet supported!"),
42 };
43 let global_trans = ndarray::Array1::<f32>::zeros(3);
44 Self {
45 joint_poses,
46 global_trans,
47 enable_pose_corrective: false,
48 up_axis,
49 smpl_type,
50 non_retargeted_pose: None,
51 retargeted: false,
52 }
53 }
54 #[allow(clippy::cast_sign_loss)]
58 pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
59 let nr_frames = codec.frame_count as u32;
60 assert_eq!(nr_frames, 1, "For a pose the nr of frames in the codec has to be 1");
61 let metadata = smpl_metadata(&codec.smpl_type());
62 let body_translation = codec
63 .body_translation
64 .as_ref()
65 .unwrap_or(&ndarray::Array2::<f32>::zeros((1, 3)))
66 .index_axis(nd::Axis(0), 0)
67 .to_owned();
68 let body_pose = codec.body_pose.as_ref()?.index_axis(nd::Axis(0), 0).to_owned();
69 let head_pose = codec
70 .head_pose
71 .as_ref()
72 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_face_joints, 3)))
73 .index_axis(nd::Axis(0), 0)
74 .into_owned();
75 let left_hand_pose = codec
76 .left_hand_pose
77 .as_ref()
78 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
79 .index_axis(nd::Axis(0), 0)
80 .into_owned();
81 let right_hand_pose = codec
82 .right_hand_pose
83 .as_ref()
84 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
85 .index_axis(nd::Axis(0), 0)
86 .into_owned();
87 let joint_poses = concatenate(
88 nd::Axis(0),
89 &[body_pose.view(), head_pose.view(), left_hand_pose.view(), right_hand_pose.view()],
90 )
91 .unwrap();
92 Some(Self::new(joint_poses, body_translation, UpAxis::Y, codec.smpl_type()))
93 }
94 #[cfg(not(target_arch = "wasm32"))]
96 #[allow(clippy::cast_possible_truncation)]
97 pub fn new_from_smpl_file(path: &str) -> Option<Self> {
98 let codec = SmplCodec::from_file(path);
99 Self::new_from_smpl_codec(&codec)
100 }
101 pub fn num_active_joints(&self) -> usize {
102 self.joint_poses.dim().0
103 }
104 pub fn apply_mask(&mut self, mask: &mut PoseOverride) {
105 let metadata = smpl_metadata(&self.smpl_type);
106 for part in &mask.denied_parts {
107 if *part == PosePart::RootTranslation {
108 self.global_trans.fill(0.0);
109 } else {
110 let range_of_body_part = metadata.parts2jointranges[*part].clone();
111 let num_joints = self.joint_poses.dim().0;
112 if range_of_body_part.start < num_joints {
113 let range_of_body_part_clamped = range_of_body_part.start..std::cmp::min(num_joints, range_of_body_part.end);
114 self.joint_poses.slice_mut(s![range_of_body_part_clamped, ..]).fill(0.0);
115 }
116 }
117 }
118 let range_left_hand = metadata.parts2jointranges[PosePart::LeftHand].clone();
119 let range_right_hand = metadata.parts2jointranges[PosePart::RightHand].clone();
120 if let Some(hand_type) = mask.overwrite_hands {
121 let original_left = self.joint_poses.slice(s![range_left_hand.clone(), ..]);
122 let original_right = self.joint_poses.slice(s![range_right_hand.clone(), ..]);
123 if mask.original_left_hand.is_none() {
124 mask.original_left_hand = Some(original_left.to_owned());
125 }
126 if mask.original_right_hand.is_none() {
127 mask.original_right_hand = Some(original_right.to_owned());
128 }
129 self.joint_poses
130 .slice_mut(s![range_left_hand, ..])
131 .assign(&metadata.hand_poses[hand_type].left);
132 self.joint_poses
133 .slice_mut(s![range_right_hand, ..])
134 .assign(&metadata.hand_poses[hand_type].right);
135 } else {
136 if let Some(left) = mask.original_left_hand.take() {
137 self.joint_poses.slice_mut(s![range_left_hand, ..]).assign(&left);
138 }
139 if let Some(right) = mask.original_right_hand.take() {
140 self.joint_poses.slice_mut(s![range_right_hand, ..]).assign(&right);
141 }
142 }
143 }
144 #[must_use]
146 pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> Pose {
147 if !(0.0..=1.0).contains(&other_weight) {
148 warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
149 }
150 let other_weight = other_weight.clamp(0.0, 1.0);
151 assert!(
152 self.smpl_type == other_pose.smpl_type,
153 "We can only interpolate to a pose of the same type. Origin: {:?}. Dest: {:?}",
154 self.smpl_type,
155 other_pose.smpl_type
156 );
157 let non_angle_indices = [27, 28, 37, 38];
158 if self.smpl_type == SmplType::SmplPP {
159 let cur_w = 1.0 - other_weight;
160 let mut new_joint_poses = self.joint_poses.clone();
161 for (i, ((cur_angle, other_angle), new_angle)) in self
162 .joint_poses
163 .iter()
164 .zip(other_pose.joint_poses.iter())
165 .zip(new_joint_poses.iter_mut())
166 .enumerate()
167 {
168 if non_angle_indices.contains(&i) {
169 *new_angle = cur_w * cur_angle + other_weight * other_angle;
170 } else {
171 *new_angle = interpolate_angle(*cur_angle, *other_angle, cur_w, other_weight);
172 }
173 }
174 let new_global_trans = cur_w * &self.global_trans + other_weight * &other_pose.global_trans;
175 return Pose::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type);
176 }
177 let cur_w = 1.0 - other_weight;
178 let new_global_trans = cur_w * &self.global_trans + other_weight * &other_pose.global_trans;
179 let cur_pose_axis_angle_na = self.joint_poses.view().into_nalgebra();
180 let other_pose_axis_angle_na = other_pose.joint_poses.view().into_nalgebra();
181 let mut new_joint_poses = nd::Array2::<f32>::zeros((self.num_active_joints(), 3));
182 for ((cur_axis, other_axis), mut new_joint) in cur_pose_axis_angle_na
183 .row_iter()
184 .zip(other_pose_axis_angle_na.row_iter())
185 .zip(new_joint_poses.axis_iter_mut(nd::Axis(0)))
186 {
187 let cur_vec = na::Vector3::new(cur_axis[0], cur_axis[1], cur_axis[2]);
188 let mut cur_q = na::UnitQuaternion::from_axis_angle(&na::UnitVector3::new_normalize(cur_vec), cur_axis.norm());
189 if cur_axis.norm() == 0.0 {
190 cur_q = na::UnitQuaternion::default();
191 }
192 let other_vec = na::Vector3::new(other_axis[0], other_axis[1], other_axis[2]);
193 let mut other_q = na::UnitQuaternion::from_axis_angle(&na::UnitVector3::new_normalize(other_vec), other_axis.norm());
194 if other_axis.norm() == 0.0 {
195 other_q = na::UnitQuaternion::default();
196 }
197 let new_q: na::Unit<na::Quaternion<f32>> = cur_q.slerp(&other_q, other_weight);
198 let axis_opt = new_q.axis();
199 let angle = new_q.angle();
200 if let Some(axis) = axis_opt {
201 new_joint.assign(&array![axis.x * angle, axis.y * angle, axis.z * angle]);
202 }
203 }
204 Pose::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type)
205 }
206}