1use crate::common::{
2 metadata::SmplMetadata,
3 pose::PoseG,
4 pose_parts::PosePart,
5 types::{SmplType, UpAxis},
6};
7use burn::{
8 prelude::Backend,
9 tensor::{Float, Tensor},
10};
11use std::ops::Range;
12#[derive(Debug)]
14pub struct PoseChunked<B: Backend> {
15 pub device: B::Device,
16 pub global_trans: Tensor<B, 2>,
17 pub global_orient: Option<Tensor<B, 2>>,
18 pub body_pose: Option<Tensor<B, 2>>,
19 pub left_hand_pose: Option<Tensor<B, 2>>,
20 pub right_hand_pose: Option<Tensor<B, 2>>,
21 pub jaw_pose: Option<Tensor<B, 2>>,
22 pub left_eye_pose: Option<Tensor<B, 2>>,
23 pub right_eye_pose: Option<Tensor<B, 2>>,
24 pub up_axis: UpAxis,
25 pub smpl_type: SmplType,
26}
27impl<B: Backend> Default for PoseChunked<B> {
28 fn default() -> Self {
29 let device = B::Device::default();
30 let global_trans = Tensor::<B, 2, Float>::zeros([1, 3], &device.clone());
31 Self {
32 device,
33 global_trans,
34 global_orient: None,
35 body_pose: None,
36 left_hand_pose: None,
37 right_hand_pose: None,
38 jaw_pose: None,
39 left_eye_pose: None,
40 right_eye_pose: None,
41 up_axis: UpAxis::Y,
42 smpl_type: SmplType::SmplX,
43 }
44 }
45}
46impl<B: Backend> PoseChunked<B> {
47 #[allow(clippy::missing_panics_doc)]
48 pub fn new(pose: &PoseG<B>, metadata: &SmplMetadata) -> Self {
49 if pose.smpl_type == SmplType::SmplPP {
50 return Self {
51 device: pose.device.clone(),
52 global_trans: pose.global_trans.clone().reshape([1, 3]),
53 global_orient: None,
54 body_pose: Some(pose.joint_poses.clone()),
55 left_hand_pose: None,
56 right_hand_pose: None,
57 jaw_pose: None,
58 left_eye_pose: None,
59 right_eye_pose: None,
60 up_axis: pose.up_axis,
61 smpl_type: pose.smpl_type,
62 };
63 }
64 let p2r = &metadata.parts2jointranges;
65 let joint_poses = &pose.joint_poses;
66 let max_range = 0..joint_poses.dims()[0];
67 let jdim = joint_poses.dims()[1];
68 #[allow(clippy::if_same_then_else)]
69 let slice_or_none = |joints: Tensor<B, 2>, slice: &Range<usize>, max: &Range<usize>, jdim: usize| -> Option<Tensor<B, 2>> {
70 if slice.end > max.end {
71 None
72 } else if slice.start == 0 && slice.end == 0 {
73 None
74 } else {
75 Some(joints.clone().slice([slice.start..slice.end, 0..jdim]))
76 }
77 };
78 let global_orient = slice_or_none(joint_poses.clone(), &p2r[PosePart::RootRotation], &max_range, jdim);
79 let body_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::Body], &max_range, jdim);
80 let left_hand_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::LeftHand], &max_range, jdim);
81 let right_hand_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::RightHand], &max_range, jdim);
82 let jaw_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::Jaw], &max_range, jdim);
83 let left_eye_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::LeftEye], &max_range, jdim);
84 let right_eye_pose = slice_or_none(joint_poses.clone(), &p2r[PosePart::RightEye], &max_range, jdim);
85 Self {
86 device: pose.device.clone(),
87 global_trans: pose.global_trans.clone().reshape([1, 3]),
88 global_orient,
89 body_pose,
90 left_hand_pose,
91 right_hand_pose,
92 jaw_pose,
93 left_eye_pose,
94 right_eye_pose,
95 up_axis: pose.up_axis,
96 smpl_type: pose.smpl_type,
97 }
98 }
99 #[allow(clippy::missing_panics_doc)]
100 pub fn to_pose(&self, metadata: &SmplMetadata, smpl_type: SmplType) -> PoseG<B> {
101 if smpl_type == SmplType::SmplPP {
102 let mut pose = PoseG::<B>::new_empty(self.up_axis, smpl_type);
103 let zeros = Tensor::<B, 2, Float>::zeros([46, 1], &self.device.clone());
104 pose.joint_poses = self.body_pose.as_ref().unwrap_or(&zeros).clone();
105 pose.global_trans = self.global_trans.clone().reshape([3]);
106 return pose;
107 }
108 let mut pose = PoseG::<B>::new_empty(self.up_axis, smpl_type);
109 let zeros = Tensor::<B, 2, Float>::zeros([1, 3], &self.device.clone());
110 pose.global_trans = self.global_trans.clone().reshape([3]);
111 let jdim = pose.joint_poses.dims()[1];
112 pose.joint_poses = pose.joint_poses.clone().slice_assign(
113 [metadata.parts2jointranges[PosePart::RootRotation].clone(), 0..jdim],
114 self.global_orient.as_ref().unwrap_or(&zeros).clone(),
115 );
116 pose.joint_poses = pose.joint_poses.clone().slice_assign(
117 [metadata.parts2jointranges[PosePart::Body].clone(), 0..jdim],
118 self.body_pose.as_ref().unwrap_or(&zeros).clone(),
119 );
120 pose.joint_poses = pose.joint_poses.clone().slice_assign(
121 [metadata.parts2jointranges[PosePart::LeftHand].clone(), 0..jdim],
122 self.left_hand_pose.as_ref().unwrap_or(&zeros).clone(),
123 );
124 pose.joint_poses = pose.joint_poses.clone().slice_assign(
125 [metadata.parts2jointranges[PosePart::RightHand].clone(), 0..jdim],
126 self.right_hand_pose.as_ref().unwrap_or(&zeros).clone(),
127 );
128 pose.joint_poses = pose.joint_poses.clone().slice_assign(
129 [metadata.parts2jointranges[PosePart::Jaw].clone(), 0..jdim],
130 self.jaw_pose.as_ref().unwrap_or(&zeros).clone(),
131 );
132 pose.joint_poses = pose.joint_poses.clone().slice_assign(
133 [metadata.parts2jointranges[PosePart::LeftEye].clone(), 0..jdim],
134 self.left_eye_pose.as_ref().unwrap_or(&zeros).clone(),
135 );
136 pose.joint_poses = pose.joint_poses.clone().slice_assign(
137 [metadata.parts2jointranges[PosePart::RightEye].clone(), 0..jdim],
138 self.right_eye_pose.as_ref().unwrap_or(&zeros).clone(),
139 );
140 pose
141 }
142}