smpl_core/conversions/pose_chunked.rs

1use crate::common::{
2    metadata::SmplMetadata,
3    pose::PoseG,
4    pose_parts::PosePart,
5    types::{SmplType, UpAxis},
6};
7use burn::{
8    prelude::Backend,
9    tensor::{Float, Tensor},
10};
11use std::ops::Range;
12/// Chunk ``Pose`` into various pose parts
13#[derive(Debug)]
14pub struct PoseChunked<B: Backend> {
15    pub device: B::Device,
16    pub global_trans: Tensor<B, 2>,
17    pub global_orient: Option<Tensor<B, 2>>,
18    pub body_pose: Option<Tensor<B, 2>>,
19    pub left_hand_pose: Option<Tensor<B, 2>>,
20    pub right_hand_pose: Option<Tensor<B, 2>>,
21    pub jaw_pose: Option<Tensor<B, 2>>,
22    pub left_eye_pose: Option<Tensor<B, 2>>,
23    pub right_eye_pose: Option<Tensor<B, 2>>,
24    pub up_axis: UpAxis,
25    pub smpl_type: SmplType,
26}
27impl<B: Backend> Default for PoseChunked<B> {
28    fn default() -> Self {
29        let device = B::Device::default();
30        let global_trans = Tensor::<B, 2, Float>::zeros([1, 3], &device.clone());
31        Self {
32            device,
33            global_trans,
34            global_orient: None,
35            body_pose: None,
36            left_hand_pose: None,
37            right_hand_pose: None,
38            jaw_pose: None,
39            left_eye_pose: None,
40            right_eye_pose: None,
41            up_axis: UpAxis::Y,
42            smpl_type: SmplType::SmplX,
43        }
44    }
45}
46impl<B: Backend> PoseChunked<B> {
47    #[allow(clippy::missing_panics_doc)]
48    pub fn new(pose: &PoseG<B>, metadata: &SmplMetadata) -> Self {
49        if pose.smpl_type == SmplType::SmplPP {
50            return Self {
51                device: pose.device.clone(),
52                global_trans: pose.global_trans.clone().reshape([1, 3]),
53                global_orient: None,
54                body_pose: Some(pose.joint_poses.clone()),
55                left_hand_pose: None,
56                right_hand_pose: None,
57                jaw_pose: None,
58                left_eye_pose: None,
59                right_eye_pose: None,
60                up_axis: pose.up_axis,
61                smpl_type: pose.smpl_type,
62            };
63        }
64        let p2r = &metadata.parts2jointranges;
65        let joint_poses = &pose.joint_poses;
66        let max_range = 0..joint_poses.dims()[0];
67        let jdim = joint_poses.dims()[1];
68        #[allow(clippy::if_same_then_else)]
69        let slice_or_none = |joints: Tensor<B, 2>, slice: &Range<usize>, max: &Range<usize>, jdim: usize| -> Option<Tensor<B, 2>> {
70            if slice.end > max.end {
71                None
72            } else if slice.start == 0 && slice.end == 0 {
73                None
74            } else {
75                Some(joints.clone().slice([slice.start..slice.end, 0..jdim]))
76            }
77        };
78        let global_orient = slice_or_none(joint_poses.clone(), &p2r[PosePart::RootRotation], &max_range, jdim);
79        let body_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::Body], &max_range, jdim);
80        let left_hand_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::LeftHand], &max_range, jdim);
81        let right_hand_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::RightHand], &max_range, jdim);
82        let jaw_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::Jaw], &max_range, jdim);
83        let left_eye_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::LeftEye], &max_range, jdim);
84        let right_eye_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::RightEye], &max_range, jdim);
85        Self {
86            device: pose.device.clone(),
87            global_trans: pose.global_trans.clone().reshape([1, 3]),
88            global_orient,
89            body_pose,
90            left_hand_pose,
91            right_hand_pose,
92            jaw_pose,
93            left_eye_pose,
94            right_eye_pose,
95            up_axis: pose.up_axis,
96            smpl_type: pose.smpl_type,
97        }
98    }
99    #[allow(clippy::missing_panics_doc)]
100    pub fn to_pose(&self, metadata: &SmplMetadata, smpl_type: SmplType) -> PoseG<B> {
101        if smpl_type == SmplType::SmplPP {
102            let mut pose = PoseG::<B>::new_empty(self.up_axis, smpl_type);
103            let zeros = Tensor::<B, 2, Float>::zeros([46, 1], &self.device.clone());
104            pose.joint_poses = self.body_pose.as_ref().unwrap_or(&zeros).clone();
105            pose.global_trans = self.global_trans.clone().reshape([3]);
106            return pose;
107        }
108        let mut pose = PoseG::<B>::new_empty(self.up_axis, smpl_type);
109        let zeros = Tensor::<B, 2, Float>::zeros([1, 3], &self.device.clone());
110        pose.global_trans = self.global_trans.clone().reshape([3]);
111        let jdim = pose.joint_poses.dims()[1];
112        pose.joint_poses = pose.joint_poses.clone().slice_assign(
113            [metadata.parts2jointranges[PosePart::RootRotation].clone(), 0..jdim],
114            self.global_orient.as_ref().unwrap_or(&zeros).clone(),
115        );
116        pose.joint_poses = pose.joint_poses.clone().slice_assign(
117            [metadata.parts2jointranges[PosePart::Body].clone(), 0..jdim],
118            self.body_pose.as_ref().unwrap_or(&zeros).clone(),
119        );
120        pose.joint_poses = pose.joint_poses.clone().slice_assign(
121            [metadata.parts2jointranges[PosePart::LeftHand].clone(), 0..jdim],
122            self.left_hand_pose.as_ref().unwrap_or(&zeros).clone(),
123        );
124        pose.joint_poses = pose.joint_poses.clone().slice_assign(
125            [metadata.parts2jointranges[PosePart::RightHand].clone(), 0..jdim],
126            self.right_hand_pose.as_ref().unwrap_or(&zeros).clone(),
127        );
128        pose.joint_poses = pose.joint_poses.clone().slice_assign(
129            [metadata.parts2jointranges[PosePart::Jaw].clone(), 0..jdim],
130            self.jaw_pose.as_ref().unwrap_or(&zeros).clone(),
131        );
132        pose.joint_poses = pose.joint_poses.clone().slice_assign(
133            [metadata.parts2jointranges[PosePart::LeftEye].clone(), 0..jdim],
134            self.left_eye_pose.as_ref().unwrap_or(&zeros).clone(),
135        );
136        pose.joint_poses = pose.joint_poses.clone().slice_assign(
137            [metadata.parts2jointranges[PosePart::RightEye].clone(), 0..jdim],
138            self.right_eye_pose.as_ref().unwrap_or(&zeros).clone(),
139        );
140        pose
141    }
142}