1use super::{
2 metadata::smpl_metadata,
3 pose_override::PoseOverride,
4 types::{SmplType, UpAxis},
5};
6use crate::AppBackend;
7use crate::{codec::codec::SmplCodec, common::pose_parts::PosePart, smpl_h::smpl_h, smpl_x::smpl_x};
8use burn::{prelude::Backend, tensor::Tensor};
9use gloss_utils::bshare::{ToBurn, ToNdArray};
10use log::warn;
11use nd::concatenate;
12use ndarray as nd;
13use smpl_utils::numerical::interpolate_angle_tensor;
/// A single body pose for an SMPL-family model, stored as burn tensors on backend `B`.
#[derive(Clone, Debug)]
pub struct PoseG<B: Backend> {
    /// Device the pose was created for (`new` takes it from `joint_poses`).
    pub device: B::Device,
    /// Per-joint pose values, one row per joint (`num_active_joints` returns the row count).
    /// For every model type except `SmplPP`, `interpolate` treats the rows as axis-angle vectors.
    pub joint_poses: Tensor<B, 2>,
    /// Global (root) translation; zeroed when `PosePart::RootTranslation` is denied by a mask.
    /// The constructors that build it themselves create a 3-element vector.
    pub global_trans: Tensor<B, 1>,
    /// Flag carried with the pose; every constructor sets it to `false`.
    /// NOTE(review): its consumer is not visible here — presumably enables pose correctives; confirm.
    pub enable_pose_corrective: bool,
    /// Up-axis convention of this pose.
    pub up_axis: UpAxis,
    /// SMPL model variant whose joint layout `joint_poses` follows.
    pub smpl_type: SmplType,
    /// `None` from every constructor.
    /// NOTE(review): presumably holds the pose before retargeting (inferred from the name); confirm.
    pub non_retargeted_pose: Option<Box<PoseG<B>>>,
    /// `false` from every constructor.
    /// NOTE(review): presumably marks a pose produced by retargeting (inferred from the name); confirm.
    pub retargeted: bool,
}
26impl<B: Backend> PoseG<B> {
27 pub fn new(joint_poses: Tensor<B, 2>, global_trans: Tensor<B, 1>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
28 Self {
29 device: joint_poses.device(),
30 joint_poses,
31 global_trans,
32 enable_pose_corrective: false,
33 up_axis,
34 smpl_type,
35 non_retargeted_pose: None,
36 retargeted: false,
37 }
38 }
39 pub fn new_empty(up_axis: UpAxis, smpl_type: SmplType) -> Self {
40 let device = B::Device::default();
41 let joint_poses = match smpl_type {
42 SmplType::SmplX => Tensor::<B, 2>::zeros([smpl_x::NUM_JOINTS + 1, 3], &device),
43 SmplType::SmplH => Tensor::<B, 2>::zeros([smpl_h::NUM_JOINTS + 1, 3], &device),
44 _ => panic!("{smpl_type:?} is not yet supported!"),
45 };
46 let global_trans = Tensor::<B, 1>::zeros([3], &device);
47 Self {
48 device,
49 joint_poses,
50 global_trans,
51 enable_pose_corrective: false,
52 up_axis,
53 smpl_type,
54 non_retargeted_pose: None,
55 retargeted: false,
56 }
57 }
58 pub fn new_from_ndarray(joint_poses: nd::Array2<f32>, global_trans: nd::Array1<f32>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
59 let device = B::Device::default();
60 Self {
61 device: device.clone(),
62 joint_poses: joint_poses.into_burn(&device.clone()),
63 global_trans: global_trans.into_burn(&device),
64 enable_pose_corrective: false,
65 up_axis,
66 smpl_type,
67 non_retargeted_pose: None,
68 retargeted: false,
69 }
70 }
71 #[allow(clippy::cast_sign_loss)]
75 pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
76 let nr_frames = codec.frame_count as u32;
77 assert_eq!(nr_frames, 1, "For a pose the nr of frames in the codec has to be 1");
78 let metadata = smpl_metadata(&codec.smpl_type());
79 let body_translation = codec
80 .body_translation
81 .as_ref()
82 .unwrap_or(&ndarray::Array2::<f32>::zeros((1, 3)))
83 .index_axis(nd::Axis(0), 0)
84 .to_owned();
85 let body_pose = codec.body_pose.as_ref()?.index_axis(nd::Axis(0), 0).to_owned();
86 let head_pose = codec
87 .head_pose
88 .as_ref()
89 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_face_joints, 3)))
90 .index_axis(nd::Axis(0), 0)
91 .into_owned();
92 let left_hand_pose = codec
93 .left_hand_pose
94 .as_ref()
95 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
96 .index_axis(nd::Axis(0), 0)
97 .into_owned();
98 let right_hand_pose = codec
99 .right_hand_pose
100 .as_ref()
101 .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
102 .index_axis(nd::Axis(0), 0)
103 .into_owned();
104 let joint_poses = concatenate(
105 nd::Axis(0),
106 &[body_pose.view(), head_pose.view(), left_hand_pose.view(), right_hand_pose.view()],
107 )
108 .unwrap();
109 Some(Self::new_from_ndarray(joint_poses, body_translation, UpAxis::Y, codec.smpl_type()))
110 }
111 #[cfg(not(target_arch = "wasm32"))]
113 #[allow(clippy::cast_possible_truncation)]
114 pub fn new_from_smpl_file(path: &str) -> Option<Self> {
115 let codec = SmplCodec::from_file(path);
116 Self::new_from_smpl_codec(&codec)
117 }
118 pub fn num_active_joints(&self) -> usize {
119 self.joint_poses.dims()[0]
120 }
121 pub fn apply_mask(&mut self, mask: &mut PoseOverride) {
122 let metadata = smpl_metadata(&self.smpl_type);
123 let dim_joint = self.joint_poses.dims()[1];
124 for part in &mask.denied_parts {
125 if *part == PosePart::RootTranslation {
126 self.global_trans = self.global_trans.clone().slice_fill([..], 0.0);
127 } else {
128 let range_of_body_part = metadata.parts2jointranges[*part].clone();
129 let num_joints = self.joint_poses.dims()[0];
130 if range_of_body_part.start < num_joints {
131 let range_of_body_part_clamped = range_of_body_part.start..std::cmp::min(num_joints, range_of_body_part.end);
132 self.joint_poses = self.joint_poses.clone().slice_fill([range_of_body_part_clamped, 0..dim_joint], 0.0);
133 }
134 }
135 }
136 let range_left_hand = metadata.parts2jointranges[PosePart::LeftHand].clone();
137 let range_right_hand = metadata.parts2jointranges[PosePart::RightHand].clone();
138 if let Some(hand_type) = mask.overwrite_hands {
139 let original_left = self.joint_poses.clone().slice([range_left_hand.clone(), 0..dim_joint]);
140 let original_right = self.joint_poses.clone().slice([range_right_hand.clone(), 0..dim_joint]);
141 if mask.original_left_hand.is_none() {
142 mask.original_left_hand = Some(original_left.clone().to_ndarray());
143 }
144 if mask.original_right_hand.is_none() {
145 mask.original_right_hand = Some(original_right.clone().to_ndarray());
146 }
147 self.joint_poses = self
148 .joint_poses
149 .clone()
150 .slice_assign([range_left_hand, 0..dim_joint], metadata.hand_poses[hand_type].left.to_burn(&self.device));
151 self.joint_poses = self.joint_poses.clone().slice_assign(
152 [range_right_hand, 0..dim_joint],
153 metadata.hand_poses[hand_type].right.to_burn(&self.device),
154 );
155 } else {
156 if let Some(left) = mask.original_left_hand.take() {
157 self.joint_poses = self
158 .joint_poses
159 .clone()
160 .slice_assign([range_left_hand, 0..dim_joint], left.to_burn(&self.device));
161 }
162 if let Some(right) = mask.original_right_hand.take() {
163 self.joint_poses = self
164 .joint_poses
165 .clone()
166 .slice_assign([range_right_hand, 0..dim_joint], right.to_burn(&self.device));
167 }
168 }
169 }
170 #[must_use]
172 pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> PoseG<B> {
173 if !(0.0..=1.0).contains(&other_weight) {
174 warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
175 }
176 let other_weight = other_weight.clamp(0.0, 1.0);
177 assert!(
178 self.smpl_type == other_pose.smpl_type,
179 "We can only interpolate to a pose of the same type. Origin: {:?}. Dest: {:?}",
180 self.smpl_type,
181 other_pose.smpl_type
182 );
183 let cur_w = 1.0 - other_weight;
184 if self.smpl_type == SmplType::SmplPP {
185 let non_angle_indices = [27, 28, 37, 38];
186 let dim_joint = self.joint_poses.dims()[1];
187 let mut new_joint_poses = self.joint_poses.clone();
188 #[allow(clippy::range_plus_one)]
189 for (i, (cur_angle, other_angle)) in self
190 .joint_poses
191 .clone()
192 .iter_dim(0)
193 .zip(other_pose.joint_poses.clone().iter_dim(0))
194 .enumerate()
195 {
196 if non_angle_indices.contains(&i) {
197 new_joint_poses = new_joint_poses
198 .clone()
199 .slice_assign([i..i + 1, 0..dim_joint], cur_w * cur_angle + other_weight * other_angle);
200 } else {
201 let new_val = interpolate_angle_tensor(cur_angle.squeeze(0), other_angle.squeeze(0), cur_w, other_weight);
202 new_joint_poses = new_joint_poses.clone().slice_assign([i..i + 1, 0..dim_joint], new_val.unsqueeze());
203 }
204 }
205 let new_global_trans = cur_w * self.global_trans.clone() + other_weight * other_pose.global_trans.clone();
206 return PoseG::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type);
207 }
208 let new_global_trans = cur_w * self.global_trans.clone() + other_weight * other_pose.global_trans.clone();
209 let all_joints = Tensor::cat(vec![self.joint_poses.clone(), other_pose.joint_poses.clone()], 0);
210 let all_quats = smpl_utils::numerical::axis_angle_to_quaternion(all_joints);
211 let vec_quats = all_quats.split(self.joint_poses.dims()[0], 0);
212 let cur_quats = vec_quats[0].clone();
213 let other_quats = vec_quats[1].clone();
214 let interpolated_quats = smpl_utils::numerical::quaternion_interpolate_lerp_fast(cur_quats, other_quats, other_weight);
215 let new_joint_poses = smpl_utils::numerical::quaternion_to_axis_angle_fast(interpolated_quats);
216 PoseG::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type)
217 }
218}
/// [`PoseG`] specialised to the application's backend, [`AppBackend`].
pub type Pose = PoseG<AppBackend>;