// smpl_core/common/expression.rs
use burn::{
    prelude::Backend,
    tensor::{Float, Tensor},
};
use log::warn;
use ndarray as nd;

use crate::common::types::FaceType;
8#[derive(Clone)]
10pub struct Expression {
11 pub expr_coeffs: nd::Array1<f32>,
12 pub expr_type: FaceType,
13}
14impl Default for Expression {
15 fn default() -> Self {
16 let num_coeffs = 10;
17 let expr_coeffs = ndarray::Array1::<f32>::zeros(num_coeffs);
18 Self {
19 expr_coeffs,
20 expr_type: FaceType::SmplX,
21 }
22 }
23}
24impl Expression {
25 pub fn new(expr_coeffs: nd::Array1<f32>, expr_type: FaceType) -> Self {
26 Self { expr_coeffs, expr_type }
27 }
28 pub fn new_empty(num_coeffs: usize, expr_type: FaceType) -> Self {
29 let expr_coeffs = ndarray::Array1::<f32>::zeros(num_coeffs);
30 Self { expr_coeffs, expr_type }
31 }
32 #[must_use]
33 pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> Self {
34 if !(0.0..=1.0).contains(&other_weight) {
35 warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
36 }
37 let other_weight = other_weight.clamp(0.0, 1.0);
38 let cur_w = 1.0 - other_weight;
39 let new_expression = cur_w * &self.expr_coeffs + other_weight * &other_pose.expr_coeffs;
40 Self::new(new_expression, self.expr_type)
41 }
42}
43#[derive(Clone)]
46pub struct ExpressionOffsets<B: Backend> {
47 pub offsets: Tensor<B, 2, Float>,
48}