// smpl_core/common/expression.rs

1use crate::{common::types::FaceType, AppBackend};
2use burn::{
3    prelude::Backend,
4    tensor::{Float, Tensor},
5};
6use gloss_utils::bshare::ToBurn;
7use log::warn;
8use ndarray as nd;
9/// Component for Smpl Expressions or Expression Parameters
10#[derive(Clone)]
11pub struct ExpressionG<B: Backend> {
12    pub device: B::Device,
13    pub expr_coeffs: Tensor<B, 1>,
14    pub expr_type: FaceType,
15}
16impl<B: Backend> Default for ExpressionG<B> {
17    fn default() -> Self {
18        let device = B::Device::default();
19        let num_coeffs = 10;
20        let expr_coeffs = Tensor::<B, 1>::zeros([num_coeffs], &device);
21        Self {
22            device,
23            expr_coeffs,
24            expr_type: FaceType::SmplX,
25        }
26    }
27}
28impl<B: Backend> ExpressionG<B> {
29    pub fn new(expr_coeffs: Tensor<B, 1>, expr_type: FaceType) -> Self {
30        Self {
31            device: expr_coeffs.device(),
32            expr_coeffs,
33            expr_type,
34        }
35    }
36    pub fn new_empty(num_coeffs: usize, expr_type: FaceType) -> Self {
37        let device = B::Device::default();
38        let expr_coeffs = Tensor::<B, 1>::zeros([num_coeffs], &device);
39        Self {
40            device,
41            expr_coeffs,
42            expr_type,
43        }
44    }
45    pub fn new_from_ndarray(expr_coeffs: nd::Array1<f32>, expr_type: FaceType) -> Self {
46        let device = B::Device::default();
47        Self::new(expr_coeffs.into_burn(&device), expr_type)
48    }
49    #[must_use]
50    pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> Self {
51        if !(0.0..=1.0).contains(&other_weight) {
52            warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
53        }
54        let other_weight = other_weight.clamp(0.0, 1.0);
55        let cur_w = 1.0 - other_weight;
56        let new_expression = cur_w * self.expr_coeffs.clone() + other_weight * other_pose.expr_coeffs.clone();
57        Self::new(new_expression, self.expr_type)
58    }
59}
60/// ``ExpressionOffsets`` is the result of smpl.expression2offsets(expression)
61/// which contains vertex offset for that expression
62#[derive(Clone)]
63pub struct ExpressionOffsetsG<B: Backend> {
64    pub offsets: Tensor<B, 2, Float>,
65}
66pub type ExpressionOffsets = ExpressionOffsetsG<AppBackend>;
67pub type Expression = ExpressionG<AppBackend>;