smpl_core/common/pose.rs

1use super::{
2    metadata::smpl_metadata,
3    pose_override::PoseOverride,
4    types::{SmplType, UpAxis},
5};
6use crate::{codec::codec::SmplCodec, common::pose_parts::PosePart, smpl_h::smpl_h, smpl_x::smpl_x};
7use gloss_utils::nshare::ToNalgebra;
8use log::warn;
9use nalgebra as na;
10use nd::concatenate;
11use ndarray as nd;
12use ndarray::prelude::*;
13use smpl_utils::numerical::interpolate_angle;
/// Component for pose: per-joint values plus a root translation for one SMPL-family body.
///
/// Built with [`Pose::new`], [`Pose::new_empty`] or from a `.smpl` codec, and can be masked in
/// place with [`Pose::apply_mask`] or blended with [`Pose::interpolate`].
#[derive(Clone, Debug)]
pub struct Pose {
    /// One row per joint. `new_empty` allocates 3 values per row; `interpolate` treats each row
    /// as an axis-angle vector for every model type except `SmplPP`, whose values it blends
    /// element-wise.
    pub joint_poses: nd::Array2<f32>,
    /// Root translation; `new_empty` uses 3 zeros and `apply_mask` zeroes it when
    /// `PosePart::RootTranslation` is denied.
    pub global_trans: nd::Array1<f32>,
    /// Set to `false` by the constructors in this file and not read here.
    /// NOTE(review): presumably toggles pose-corrective blend shapes downstream — confirm.
    pub enable_pose_corrective: bool,
    /// Up-axis convention of this pose (poses decoded from a codec use `UpAxis::Y`).
    pub up_axis: UpAxis,
    /// Model type; selects the metadata (joint ranges, hand presets) used by `apply_mask`.
    pub smpl_type: SmplType,
    /// `None` after construction and not touched elsewhere in this file.
    /// NOTE(review): presumably holds the pose as it was before retargeting — confirm.
    pub non_retargeted_pose: Option<Box<Pose>>,
    /// Whether this pose has been retargeted; the constructors here set it to `false`.
    pub retargeted: bool,
}
25impl Pose {
26    pub fn new(joint_poses: nd::Array2<f32>, global_trans: nd::Array1<f32>, up_axis: UpAxis, smpl_type: SmplType) -> Self {
27        Self {
28            joint_poses,
29            global_trans,
30            enable_pose_corrective: false,
31            up_axis,
32            smpl_type,
33            non_retargeted_pose: None,
34            retargeted: false,
35        }
36    }
37    pub fn new_empty(up_axis: UpAxis, smpl_type: SmplType) -> Self {
38        let joint_poses = match smpl_type {
39            SmplType::SmplX => ndarray::Array2::<f32>::zeros((smpl_x::NUM_JOINTS + 1, 3)),
40            SmplType::SmplH => ndarray::Array2::<f32>::zeros((smpl_h::NUM_JOINTS + 1, 3)),
41            _ => panic!("{smpl_type:?} is not yet supported!"),
42        };
43        let global_trans = ndarray::Array1::<f32>::zeros(3);
44        Self {
45            joint_poses,
46            global_trans,
47            enable_pose_corrective: false,
48            up_axis,
49            smpl_type,
50            non_retargeted_pose: None,
51            retargeted: false,
52        }
53    }
54    /// Create a new ``Pose`` component from ``SmplCodec``
55    /// # Panics
56    /// Will panic if the ``nr_frames`` is different than 1
57    #[allow(clippy::cast_sign_loss)]
58    pub fn new_from_smpl_codec(codec: &SmplCodec) -> Option<Self> {
59        let nr_frames = codec.frame_count as u32;
60        assert_eq!(nr_frames, 1, "For a pose the nr of frames in the codec has to be 1");
61        let metadata = smpl_metadata(&codec.smpl_type());
62        let body_translation = codec
63            .body_translation
64            .as_ref()
65            .unwrap_or(&ndarray::Array2::<f32>::zeros((1, 3)))
66            .index_axis(nd::Axis(0), 0)
67            .to_owned();
68        let body_pose = codec.body_pose.as_ref()?.index_axis(nd::Axis(0), 0).to_owned();
69        let head_pose = codec
70            .head_pose
71            .as_ref()
72            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_face_joints, 3)))
73            .index_axis(nd::Axis(0), 0)
74            .into_owned();
75        let left_hand_pose = codec
76            .left_hand_pose
77            .as_ref()
78            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
79            .index_axis(nd::Axis(0), 0)
80            .into_owned();
81        let right_hand_pose = codec
82            .right_hand_pose
83            .as_ref()
84            .unwrap_or(&ndarray::Array3::<f32>::zeros((1, metadata.num_hand_joints, 3)))
85            .index_axis(nd::Axis(0), 0)
86            .into_owned();
87        let joint_poses = concatenate(
88            nd::Axis(0),
89            &[body_pose.view(), head_pose.view(), left_hand_pose.view(), right_hand_pose.view()],
90        )
91        .unwrap();
92        Some(Self::new(joint_poses, body_translation, UpAxis::Y, codec.smpl_type()))
93    }
94    /// Create new ``Pose`` component from ``.smpl`` file
95    #[cfg(not(target_arch = "wasm32"))]
96    #[allow(clippy::cast_possible_truncation)]
97    pub fn new_from_smpl_file(path: &str) -> Option<Self> {
98        let codec = SmplCodec::from_file(path);
99        Self::new_from_smpl_codec(&codec)
100    }
101    pub fn num_active_joints(&self) -> usize {
102        self.joint_poses.dim().0
103    }
104    pub fn apply_mask(&mut self, mask: &mut PoseOverride) {
105        let metadata = smpl_metadata(&self.smpl_type);
106        for part in &mask.denied_parts {
107            if *part == PosePart::RootTranslation {
108                self.global_trans.fill(0.0);
109            } else {
110                let range_of_body_part = metadata.parts2jointranges[*part].clone();
111                let num_joints = self.joint_poses.dim().0;
112                if range_of_body_part.start < num_joints {
113                    let range_of_body_part_clamped = range_of_body_part.start..std::cmp::min(num_joints, range_of_body_part.end);
114                    self.joint_poses.slice_mut(s![range_of_body_part_clamped, ..]).fill(0.0);
115                }
116            }
117        }
118        let range_left_hand = metadata.parts2jointranges[PosePart::LeftHand].clone();
119        let range_right_hand = metadata.parts2jointranges[PosePart::RightHand].clone();
120        if let Some(hand_type) = mask.overwrite_hands {
121            let original_left = self.joint_poses.slice(s![range_left_hand.clone(), ..]);
122            let original_right = self.joint_poses.slice(s![range_right_hand.clone(), ..]);
123            if mask.original_left_hand.is_none() {
124                mask.original_left_hand = Some(original_left.to_owned());
125            }
126            if mask.original_right_hand.is_none() {
127                mask.original_right_hand = Some(original_right.to_owned());
128            }
129            self.joint_poses
130                .slice_mut(s![range_left_hand, ..])
131                .assign(&metadata.hand_poses[hand_type].left);
132            self.joint_poses
133                .slice_mut(s![range_right_hand, ..])
134                .assign(&metadata.hand_poses[hand_type].right);
135        } else {
136            if let Some(left) = mask.original_left_hand.take() {
137                self.joint_poses.slice_mut(s![range_left_hand, ..]).assign(&left);
138            }
139            if let Some(right) = mask.original_right_hand.take() {
140                self.joint_poses.slice_mut(s![range_right_hand, ..]).assign(&right);
141            }
142        }
143    }
144    /// Interpolate between 2 poses
145    #[must_use]
146    pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> Pose {
147        if !(0.0..=1.0).contains(&other_weight) {
148            warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
149        }
150        let other_weight = other_weight.clamp(0.0, 1.0);
151        assert!(
152            self.smpl_type == other_pose.smpl_type,
153            "We can only interpolate to a pose of the same type. Origin: {:?}. Dest: {:?}",
154            self.smpl_type,
155            other_pose.smpl_type
156        );
157        let non_angle_indices = [27, 28, 37, 38];
158        if self.smpl_type == SmplType::SmplPP {
159            let cur_w = 1.0 - other_weight;
160            let mut new_joint_poses = self.joint_poses.clone();
161            for (i, ((cur_angle, other_angle), new_angle)) in self
162                .joint_poses
163                .iter()
164                .zip(other_pose.joint_poses.iter())
165                .zip(new_joint_poses.iter_mut())
166                .enumerate()
167            {
168                if non_angle_indices.contains(&i) {
169                    *new_angle = cur_w * cur_angle + other_weight * other_angle;
170                } else {
171                    *new_angle = interpolate_angle(*cur_angle, *other_angle, cur_w, other_weight);
172                }
173            }
174            let new_global_trans = cur_w * &self.global_trans + other_weight * &other_pose.global_trans;
175            return Pose::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type);
176        }
177        let cur_w = 1.0 - other_weight;
178        let new_global_trans = cur_w * &self.global_trans + other_weight * &other_pose.global_trans;
179        let cur_pose_axis_angle_na = self.joint_poses.view().into_nalgebra();
180        let other_pose_axis_angle_na = other_pose.joint_poses.view().into_nalgebra();
181        let mut new_joint_poses = nd::Array2::<f32>::zeros((self.num_active_joints(), 3));
182        for ((cur_axis, other_axis), mut new_joint) in cur_pose_axis_angle_na
183            .row_iter()
184            .zip(other_pose_axis_angle_na.row_iter())
185            .zip(new_joint_poses.axis_iter_mut(nd::Axis(0)))
186        {
187            let cur_vec = na::Vector3::new(cur_axis[0], cur_axis[1], cur_axis[2]);
188            let mut cur_q = na::UnitQuaternion::from_axis_angle(&na::UnitVector3::new_normalize(cur_vec), cur_axis.norm());
189            if cur_axis.norm() == 0.0 {
190                cur_q = na::UnitQuaternion::default();
191            }
192            let other_vec = na::Vector3::new(other_axis[0], other_axis[1], other_axis[2]);
193            let mut other_q = na::UnitQuaternion::from_axis_angle(&na::UnitVector3::new_normalize(other_vec), other_axis.norm());
194            if other_axis.norm() == 0.0 {
195                other_q = na::UnitQuaternion::default();
196            }
197            let new_q: na::Unit<na::Quaternion<f32>> = cur_q.slerp(&other_q, other_weight);
198            let axis_opt = new_q.axis();
199            let angle = new_q.angle();
200            if let Some(axis) = axis_opt {
201                new_joint.assign(&array![axis.x * angle, axis.y * angle, axis.z * angle]);
202            }
203        }
204        Pose::new(new_joint_poses, new_global_trans, self.up_axis, self.smpl_type)
205    }
206}