smpl_core/common/
expression.rs1use burn::{
2 prelude::Backend,
3 tensor::{Float, Tensor},
4};
5use log::warn;
6use ndarray as nd;
7#[derive(Clone)]
9pub struct Expression {
10 pub expr_coeffs: nd::Array1<f32>,
11}
12impl Default for Expression {
13 fn default() -> Self {
14 let num_coeffs = 10;
15 let expr_coeffs = ndarray::Array1::<f32>::zeros(num_coeffs);
16 Self { expr_coeffs }
17 }
18}
19impl Expression {
20 pub fn new(expr_coeffs: nd::Array1<f32>) -> Self {
21 Self { expr_coeffs }
22 }
23 pub fn new_empty(num_coeffs: usize) -> Self {
24 let expr_coeffs = ndarray::Array1::<f32>::zeros(num_coeffs);
25 Self { expr_coeffs }
26 }
27 #[must_use]
28 pub fn interpolate(&self, other_pose: &Self, other_weight: f32) -> Self {
29 if !(0.0..=1.0).contains(&other_weight) {
30 warn!("pose interpolation weight is outside the [0,1] range, will clamp. Weight is {other_weight}");
31 }
32 let other_weight = other_weight.clamp(0.0, 1.0);
33 let cur_w = 1.0 - other_weight;
34 let new_expression = cur_w * &self.expr_coeffs + other_weight * &other_pose.expr_coeffs;
35 Self::new(new_expression)
36 }
37}
38#[derive(Clone)]
41pub struct ExpressionOffsets<B: Backend> {
42 pub offsets: Tensor<B, 2, Float>,
43}