smpl_core/common/
pose.rs

1use super::{
2    metadata::smpl_metadata,
3    pose_override::PoseOverride,
4    types::{SmplType, UpAxis},
5};
6use crate::AppBackend;
7use crate::{codec::codec::SmplCodec, common::pose_parts::PosePart, smpl_h::smpl_h, smpl_x::smpl_x};
8use burn::{prelude::Backend, tensor::Tensor};
9use gloss_utils::bshare::{ToBurn, ToNdArray};
10use log::warn;
11use nd::concatenate;
12use ndarray as nd;
13use smpl_utils::numerical::interpolate_angle_tensor;
/// Component for pose.
///
/// Holds the per-joint pose tensor and the global translation of a model of type
/// [`SmplType`], plus retargeting bookkeeping. The constructors in `impl PoseG`
/// initialise `enable_pose_corrective` and `retargeted` to `false` and
/// `non_retargeted_pose` to `None`.
#[derive(Clone, Debug)]
pub struct PoseG<B: Backend> {
    /// Device used when new tensors are created for this pose (e.g. when a mask is applied).
    pub device: B::Device,
    /// Joint poses, one row per joint; the rows belonging to each body part are given by
    /// the model metadata (`parts2jointranges`).
    /// NOTE(review): rows are presumably axis-angle rotations with 3 columns — the
    /// non-SMPL++ interpolation path feeds them to `axis_angle_to_quaternion`; confirm for SMPL++.
    pub joint_poses: Tensor<B, 2>,
    /// Global (root) translation; 3 elements for poses built by `new_empty`.
    pub global_trans: Tensor<B, 1>,
    /// NOTE(review): presumably toggles pose-corrective terms in a downstream consumer
    /// that is not visible here; always initialised to `false` by the constructors.
    pub enable_pose_corrective: bool,
    /// Up-axis convention of the pose (poses decoded from an `SmplCodec` use `UpAxis::Y`).
    pub up_axis: UpAxis,
    /// Model type; selects the metadata (joint ranges, hand presets) used with this pose.
    pub smpl_type: SmplType,
    /// NOTE(review): presumably a copy of the pose before retargeting was applied —
    /// confirm against the retargeting code; the constructors set it to `None`.
    pub non_retargeted_pose: Option<Box<PoseG<B>>>,
    /// Whether this pose has been retargeted; the constructors set it to `false`.
    pub retargeted: bool,
}
26impl<B: Backend> PoseG<B> {
27    pub fn new(joint_poses: Tensor<B, 2>, global_trans: Tensor<B, 1>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
28        Self {
29            device: joint_poses.device(),
30            joint_poses,
31            global_trans,
32            enable_pose_corrective: false,
33            up_axis,
34            smpl_type,
35            non_retargeted_pose: None,
36            retargeted: false,
37        }
38    }
39    pub fn new_empty(up_axis: UpAxis, smpl_type: SmplType) -> Self {
40        let device = B::Device::default();
41        let joint_poses = match smpl_type {
42            SmplType::SmplX => Tensor::<B, 2>::zeros([smpl_x::NUM_JOINTS + 1, 3], &device),
43            SmplType::SmplH => Tensor::<B, 2>::zeros([smpl_h::NUM_JOINTS + 1, 3], &device),
44            _ => panic!("{smpl_type:?} is not yet supported!"),
45        };
46        let global_trans = Tensor::<B, 1>::zeros([3], &device);
47        Self {
48            device,
49            joint_poses,
50            global_trans,
51            enable_pose_corrective: false,
52            up_axis,
53            smpl_type,
54            non_retargeted_pose: None,
55            retargeted: false,
56        }
57    }
58    pub fn new_from_ndarray(joint_poses: nd::Array2<f32>, global_trans: nd::Array1<f32>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
59        let device = B::Device::default();
60        Self {
61            device: device.clone(),
62            joint_poses: joint_poses.into_burn(&device.clone()),
63            global_trans: global_trans.into_burn(&device),
64            enable_pose_corrective: false,
65            up_axis,
66            smpl_type,
67            non_retargeted_pose: None,
68            retargeted: false,
69        }
70    }
71    /// Create a new ``Pose`` component from ``SmplCodec``
72    /// # Panics
73    /// Will panic if the ``nr_frames`` is different than 1
74    #[allow(clippy::cast_sign_loss)]
75    pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
76        let nr_frames = codec.frame_count as u32;
77        assert_eq!(nr_frames, 1, "For a pose the nr of frames in the codec has to be 1");
78        let metadata = smpl_metadata(&codec.smpl_type());
79        let body_translation = codec
80            .body_translation
81            .as_ref()
82            .unwrap_or(&ndarray::Array2::<f32>::zeros((1, 3)))
83            .index_axis(nd::Axis(0), 0)
84            .to_owned();
85        let body_pose = codec.body_pose.as_ref()?.index_axis(nd::Axis(0), 0).to_owned();
86        let head_pose = codec
87            .head_pose
88            .as_ref()
89            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_face_joints, 3)))
90            .index_axis(nd::Axis(0), 0)
91            .into_owned();
92        let left_hand_pose = codec
93            .left_hand_pose
94            .as_ref()
95            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
96            .index_axis(nd::Axis(0), 0)
97            .into_owned();
98        let right_hand_pose = codec
99            .right_hand_pose
100            .as_ref()
101            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
102            .index_axis(nd::Axis(0), 0)
103            .into_owned();
104        let joint_poses = concatenate(
105            nd::Axis(0),
106            &[body_pose.view(), head_pose.view(), left_hand_pose.view(), right_hand_pose.view()],
107        )
108        .unwrap();
109        Some(Self::new_from_ndarray(joint_poses, body_translation, UpAxis::Y, codec.smpl_type()))
110    }
111    /// Create new ``Pose`` component from ``.smpl`` file
112    #[cfg(not(target_arch = "wasm32"))]
113    #[allow(clippy::cast_possible_truncation)]
114    pub fn new_from_smpl_file(path: &str) -> Option<Self> {
115        let codec = SmplCodec::from_file(path);
116        Self::new_from_smpl_codec(&codec)
117    }
118    pub fn num_active_joints(&self) -> usize {
119        self.joint_poses.dims()[0]
120    }
121    pub fn apply_mask(&mut self, mask: &mut PoseOverride) {
122        let metadata = smpl_metadata(&self.smpl_type);
123        let dim_joint = self.joint_poses.dims()[1];
124        for part in &mask.denied_parts {
125            if *part == PosePart::RootTranslation {
126                self.global_trans = self.global_trans.clone().slice_fill([..], 0.0);
127            } else {
128                let range_of_body_part = metadata.parts2jointranges[*part].clone();
129                let num_joints = self.joint_poses.dims()[0];
130                if range_of_body_part.start < num_joints {
131                    let range_of_body_part_clamped = range_of_body_part.start..std::cmp::min(num_joints, range_of_body_part.end);
132                    self.joint_poses = self.joint_poses.clone().slice_fill([range_of_body_part_clamped, 0..dim_joint], 0.0);
133                }
134            }
135        }
136        let range_left_hand = metadata.parts2jointranges[PosePart::LeftHand].clone();
137        let range_right_hand = metadata.parts2jointranges[PosePart::RightHand].clone();
138        if let Some(hand_type) = mask.overwrite_hands {
139            let original_left = self.joint_poses.clone().slice([range_left_hand.clone(), 0..dim_joint]);
140            let original_right = self.joint_poses.clone().slice([range_right_hand.clone(), 0..dim_joint]);
141            if mask.original_left_hand.is_none() {
142                mask.original_left_hand = Some(original_left.clone().to_ndarray());
143            }
144            if mask.original_right_hand.is_none() {
145                mask.original_right_hand = Some(original_right.clone().to_ndarray());
146            }
147            self.joint_poses = self
148                .joint_poses
149                .clone()
150                .slice_assign([range_left_hand, 0..dim_joint], metadata.hand_poses[hand_type].left.to_burn(&self.device));
151            self.joint_poses = self.joint_poses.clone().slice_assign(
152                [range_right_hand, 0..dim_joint],
153                metadata.hand_poses[hand_type].right.to_burn(&self.device),
154            );
155        } else {
156            if let Some(left) = mask.original_left_hand.take() {
157                self.joint_poses = self
158                    .joint_poses
159                    .clone()
160                    .slice_assign([range_left_hand, 0..dim_joint], left.to_burn(&self.device));
161            }
162            if let Some(right) = mask.original_right_hand.take() {
163                self.joint_poses = self
164                    .joint_poses
165                    .clone()
166                    .slice_assign([range_right_hand, 0..dim_joint], right.to_burn(&self.device));
167            }
168        }
169    }
170    /// Interpolate between 2 poses
171    #[must_use]
172    pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> PoseG<B> {
173        if !(0.0..=1.0).contains(&other_weight) {
174            warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
175        }
176        let other_weight = other_weight.clamp(0.0, 1.0);
177        assert!(
178            self.smpl_type == other_pose.smpl_type,
179            "We can only interpolate to a pose of the same type. Origin: {:?}. Dest: {:?}",
180            self.smpl_type,
181            other_pose.smpl_type
182        );
183        let cur_w = 1.0 - other_weight;
184        if self.smpl_type == SmplType::SmplPP {
185            let non_angle_indices = [27, 28, 37, 38];
186            let dim_joint = self.joint_poses.dims()[1];
187            let mut new_joint_poses = self.joint_poses.clone();
188            #[allow(clippy::range_plus_one)]
189            for (i, (cur_angle, other_angle)) in self
190                .joint_poses
191                .clone()
192                .iter_dim(0)
193                .zip(other_pose.joint_poses.clone().iter_dim(0))
194                .enumerate()
195            {
196                if non_angle_indices.contains(&i) {
197                    new_joint_poses = new_joint_poses
198                        .clone()
199                        .slice_assign([i..i + 1, 0..dim_joint], cur_w * cur_angle + other_weight * other_angle);
200                } else {
201                    let new_val = interpolate_angle_tensor(cur_angle.squeeze(0), other_angle.squeeze(0), cur_w, other_weight);
202                    new_joint_poses = new_joint_poses.clone().slice_assign([i..i + 1, 0..dim_joint], new_val.unsqueeze());
203                }
204            }
205            let new_global_trans = cur_w * self.global_trans.clone() + other_weight * other_pose.global_trans.clone();
206            return PoseG::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type);
207        }
208        let new_global_trans = cur_w * self.global_trans.clone() + other_weight * other_pose.global_trans.clone();
209        let all_joints = Tensor::cat(vec![self.joint_poses.clone(), other_pose.joint_poses.clone()], 0);
210        let all_quats = smpl_utils::numerical::axis_angle_to_quaternion(all_joints);
211        let vec_quats = all_quats.split(self.joint_poses.dims()[0], 0);
212        let cur_quats = vec_quats[0].clone();
213        let other_quats = vec_quats[1].clone();
214        let interpolated_quats = smpl_utils::numerical::quaternion_interpolate_lerp_fast(cur_quats, other_quats, other_weight);
215        let new_joint_poses = smpl_utils::numerical::quaternion_to_axis_angle_fast(interpolated_quats);
216        PoseG::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type)
217    }
218}
/// [`PoseG`] instantiated with the application's backend ([`AppBackend`]).
pub type Pose = PoseG<AppBackend>;