smpl_gloss_integration/gltf.rs

use crate::scene::SceneAnimation;
use burn::backend::{Candle, NdArray, Wgpu};
use burn::prelude::Backend;
use gloss_img::dynamic_image::DynImage;
use gloss_renderer::{
    components::{DiffuseImg, MetalnessImg, Name, NormalImg, RoughnessImg},
    scene::Scene,
};
use gloss_utils::{
    bshare::{ToNalgebraFloat, ToNalgebraInt, ToNdArray},
    nshare::ToNalgebra,
};
use image::imageops::FilterType;
use log::info;
use nalgebra::DMatrix;
use ndarray::{self as nd, s};
use smpl_core::{
    codec::{
        gltf::{GltfCodec, PerBodyData},
        scene::CameraTrack,
    },
    common::{
        animation::Animation,
        betas::Betas,
        expression::Expression,
        metadata::smpl_metadata,
        pose::Pose,
        pose_override::PoseOverride,
        pose_retarget::RetargetPoseYShift,
        smpl_model::{SmplCache, SmplCacheDynamic},
        smpl_options::SmplOptions,
        smpl_params::SmplParams,
        types::{FaceType, SmplType, UpAxis},
    },
    conversions::pose_remap::PoseRemap,
};
use smpl_utils::array::{Gather2D, Gather3D};
use std::f32::consts::PI;
/// Creates a ``GltfCodec`` from a ``Scene`` by extracting components from its entities
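///
/// A minimal usage sketch (hypothetical crate path and scene setup; marked
/// `ignore` since it needs a fully populated ``Scene`` with loaded SMPL models):
///
/// ```ignore
/// use smpl_gloss_integration::gltf::GltfCodecGloss;
///
/// // Cap textures at 2048 px and export the camera track, if any.
/// let codec = GltfCodec::from_scene(&scene, Some(2048), true);
/// assert_eq!(codec.per_body_data.len(), codec.num_bodies);
/// ```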
pub trait GltfCodecGloss {
    fn from_scene(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool) -> GltfCodec;
    fn from_entities(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool, entities: Vec<String>) -> GltfCodec;
}
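/// Returns a copy of `image`, optionally converted to grayscale, and downscaled
/// with a Gaussian filter when its width exceeds `max_texture_size`.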
fn get_image(image: &DynImage, to_gray: bool, max_texture_size: Option<u32>) -> DynImage {
    let mut image = image.clone();
    if to_gray {
        image = image.grayscale();
    }
    match max_texture_size {
        Some(force_image_size) if image.width() > force_image_size => image.resize(force_image_size, force_image_size, FilterType::Gaussian),
        _ => image,
    }
}
/// Trait implementation for ``GltfCodec``
impl GltfCodecGloss for GltfCodec {
    /// Builds a ``GltfCodec`` from all bodies in the scene
    fn from_scene(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool) -> GltfCodec {
        let smpl_models = scene.get_resource::<&SmplCacheDynamic>().unwrap();
        match &*smpl_models {
            SmplCacheDynamic::NdArray(models) => from_scene_on_backend::<NdArray>(scene, models, max_texture_size, None, export_camera),
            SmplCacheDynamic::Wgpu(models) => from_scene_on_backend::<Wgpu>(scene, models, max_texture_size, None, export_camera),
            SmplCacheDynamic::Candle(models) => from_scene_on_backend::<Candle>(scene, models, max_texture_size, None, export_camera),
        }
    }
    /// Builds a ``GltfCodec`` from only the named entities in the scene
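    ///
    /// A sketch with hypothetical entity names (`ignore`d for the same reason as
    /// the trait-level example):
    ///
    /// ```ignore
    /// let names = vec!["body_0".to_string(), "body_1".to_string()];
    /// let codec = GltfCodec::from_entities(&scene, None, false, names);
    /// ```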
    fn from_entities(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool, entities: Vec<String>) -> GltfCodec {
        let smpl_models = scene.get_resource::<&SmplCacheDynamic>().unwrap();
        match &*smpl_models {
            SmplCacheDynamic::NdArray(models) => from_scene_on_backend::<NdArray>(scene, models, max_texture_size, Some(&entities), export_camera),
            SmplCacheDynamic::Wgpu(models) => from_scene_on_backend::<Wgpu>(scene, models, max_texture_size, Some(&entities), export_camera),
            SmplCacheDynamic::Candle(models) => from_scene_on_backend::<Candle>(scene, models, max_texture_size, Some(&entities), export_camera),
        }
    }
}
/// Builds a ``GltfCodec`` from the scene on a generic Burn backend. Currently
/// supported backends: ``Candle``, ``NdArray``, and ``Wgpu``.
#[allow(clippy::too_many_lines)]
#[allow(clippy::trivially_copy_pass_by_ref)]
fn from_scene_on_backend<B: Backend>(
    scene: &Scene,
    smpl_models: &SmplCache<B>,
    max_texture_size: Option<u32>,
    entities: Option<&Vec<String>>,
    export_camera: bool,
) -> GltfCodec
where
    <B as Backend>::FloatTensorPrimitive<2>: Sync,
    <B as Backend>::IntTensorPrimitive<2>: Sync,
    B::QuantizedTensorPrimitive<1>: Sync,
    B::QuantizedTensorPrimitive<2>: Sync,
    B::QuantizedTensorPrimitive<3>: Sync,
{
    let now = wasm_timer::Instant::now();
    let mut gltf_codec = GltfCodec::default();
    let mut nr_frames = 0;
    if export_camera {
        let mut cameras_query = scene.world.query::<&CameraTrack>();
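        // Only one camera track is exported; if several entities carry a
        // `CameraTrack`, the last one iterated wins.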
        for (_, camera_track) in cameras_query.iter() {
            gltf_codec.camera_track = Some(camera_track.clone());
        }
    }
    let mut query = scene.world.query::<(&SmplParams, &Name)>();
    let num_bodies = query.iter().len();
    gltf_codec.num_bodies = num_bodies;
    let mut should_export_posedirs = false;
    let mut should_export_exprdirs = false;

    let mut num_expression_blend_shapes = 0;
    for (entity, (smpl_params, _name)) in query.iter() {
        if scene.world.has::<Animation>(entity).unwrap() && smpl_params.enable_pose_corrective {
            should_export_posedirs = true;
        }
        let smpl_model = smpl_models.get_model_ref(smpl_params.smpl_type, smpl_params.gender).unwrap();
        if let Ok(anim) = scene.get_comp::<&Animation>(&entity) {
            if anim.has_expression() && smpl_model.get_expression_dirs().is_some() {
                should_export_exprdirs = true;
                num_expression_blend_shapes = smpl_model.get_expression_dirs().unwrap().shape().dims[1];
            }
        }
    }

    for (body_idx, (entity, (smpl_params, name))) in query.iter().enumerate() {
        if let Some(entities) = entities {
            if !entities.contains(&name.0) {
                continue;
            }
        }
        let smpl_version = smpl_params.smpl_type;
        let gender = smpl_params.gender as i32;
        let mut current_body = PerBodyData::default();
        assert!(smpl_version != SmplType::SmplPP, "GLTF export for SMPL++ is not supported yet!");
        let smpl_model = smpl_models.get_model_ref(smpl_params.smpl_type, smpl_params.gender).unwrap();
        let Ok(betas) = scene.get_comp::<&Betas>(&entity) else {
            panic!("Betas component does not exist!");
        };
        let default_pose = Pose::new_empty(UpAxis::Y, smpl_params.smpl_type);
        let default_expression = Expression::new_empty(10, FaceType::SmplX);
        let mut smpl_output = smpl_model.forward(&SmplOptions::default(), &betas, &default_pose, Some(&default_expression));
        smpl_output.compute_normals();
        smpl_output = smpl_model.create_body_with_uv(&smpl_output);
        let metadata = smpl_metadata(&smpl_params.smpl_type);
        let mut num_total_blendshapes = 0;
        if should_export_posedirs {
            num_total_blendshapes += metadata.num_pose_blend_shapes + 1;
        }
        if should_export_exprdirs {
            num_total_blendshapes += num_expression_blend_shapes + 1;
        }
        gltf_codec.smpl_type = smpl_version;
        gltf_codec.gender = gender;
        current_body.pose = Some(default_pose.clone());
        gltf_codec.default_joint_poses = Some(default_pose.joint_poses.clone());
        current_body.body_translation = Some(default_pose.global_trans.to_shape((1, 3)).unwrap().to_owned());
        let verts_na = smpl_output.verts.to_nalgebra();
        let normals_na = smpl_output.normals.as_ref().expect("SMPL Output is missing normals!").to_nalgebra();
        let faces_na = smpl_output.faces.to_nalgebra();
        let uvs_na = smpl_output.uvs.as_ref().expect("SMPL Output is missing UVs!").to_nalgebra();
        current_body.positions = Some(verts_na);
        current_body.normals = Some(normals_na);
        gltf_codec.faces = Some(faces_na);
        gltf_codec.uvs = Some(uvs_na);
        let smpl_joints = smpl_output.joints.clone().to_ndarray();
        let joint_count = smpl_joints.shape()[0];
        let lbs_weights = smpl_model.lbs_weights_split().to_ndarray();
        let vertex_count = smpl_output.verts.dims()[0];
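        // A single glTF joint/weight set (JOINTS_0/WEIGHTS_0) carries at most 4
        // influences per vertex, so keep only the 4 largest LBS weights for each
        // vertex. Note that the truncated weights are not renormalized here.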
        let mut skin_vertex_index = DMatrix::<u32>::zeros(vertex_count, 4);
        let mut skin_vertex_weight = DMatrix::<f32>::zeros(vertex_count, 4);
        for (vertex_id, row) in lbs_weights.outer_iter().enumerate() {
            let mut vertex_weights: Vec<(usize, f32)> = row.iter().enumerate().map(|(index, &weight)| (index, weight)).collect();
            vertex_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            assert!(vertex_weights.len() >= 4, "expected at least 4 skinning weights per vertex");
            for (i, (index, weight)) in vertex_weights.iter().take(4).enumerate() {
                skin_vertex_index[(vertex_id, i)] = u32::try_from(*index).expect("Cannot convert to u32!");
                skin_vertex_weight[(vertex_id, i)] = *weight;
            }
        }
        gltf_codec.joint_index = Some(skin_vertex_index);
        gltf_codec.joint_weight = Some(skin_vertex_weight);
        let diffuse_img = scene.get_comp::<&DiffuseImg>(&entity);
        if let Ok(diffuse_img) = diffuse_img {
            if let Some(img) = &diffuse_img.generic_img.cpu_img {
                current_body.diffuse_textures = Some(get_image(img, false, max_texture_size));
            }
        }
        let normals_img = scene.get_comp::<&NormalImg>(&entity);
        if let Ok(normals_img) = normals_img {
            if let Some(img) = &normals_img.generic_img.cpu_img {
                current_body.normals_textures = Some(get_image(img, false, max_texture_size));
            }
        }
        let metalness_img = scene.get_comp::<&MetalnessImg>(&entity);
        if let Ok(metalness_img) = metalness_img {
            if let Some(img) = &metalness_img.generic_img.cpu_img {
                current_body.metalness_textures = Some(get_image(img, true, max_texture_size));
            }
        }
        let roughness_img = scene.get_comp::<&RoughnessImg>(&entity);
        if let Ok(roughness_img) = roughness_img {
            if let Some(img) = &roughness_img.generic_img.cpu_img {
                current_body.roughness_textures = Some(get_image(img, true, max_texture_size));
            }
        }
        if scene.world.has::<Pose>(entity).unwrap() && !scene.world.has::<Animation>(entity).unwrap() {
            let Ok(pose_ref) = scene.get_comp::<&Pose>(&entity) else {
                panic!("Pose component does not exist!");
            };
            let current_pose: &Pose = &pose_ref;
            let current_body_translation = current_pose.global_trans.to_shape((1, 3)).unwrap().to_owned();
            current_body.pose = Some(current_pose.clone());
            current_body.body_translation = Some(current_body_translation);
            if smpl_params.enable_pose_corrective {
                let vertex_offsets_merged = smpl_model.compute_pose_correctives(current_pose).to_ndarray();
                let mapping = &smpl_model.idx_split_2_merged_vec();
                let cols = vec![0, 1, 2];
                let vertex_offsets = vertex_offsets_merged.gather(mapping, &cols).into_nalgebra();
                current_body.positions = Some(current_body.positions.as_ref().unwrap() + vertex_offsets);
            }
        }
        #[allow(clippy::cast_precision_loss)]
        if scene.world.has::<Animation>(entity).unwrap() {
            let scene_anim = scene.get_resource::<&SceneAnimation>().unwrap();
            nr_frames = scene_anim.num_frames;
            let fps = scene_anim.config.fps;
            info!("Processing animation for body {body_idx}");
            let anim = scene.get_comp::<&Animation>(&entity).unwrap();
            gltf_codec.frame_count = Some(nr_frames);
            let mut keyframe_times: Vec<f32> = Vec::new();
            let mut current_body_rotations = nd::Array3::<f32>::zeros((joint_count, nr_frames, 3));
            let mut current_body_translations = nd::Array2::<f32>::zeros((nr_frames, 3));
            let mut current_body_scales = nd::Array2::<f32>::zeros((nr_frames, 3));
            let mut current_per_frame_blend_weights = nd::Array2::<f32>::zeros((nr_frames, num_total_blendshapes));
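            // Morph targets are packed as [pose blend shapes..., pose template offset,
            // expression blend shapes..., expression template offset]; the per-frame
            // weights in `current_per_frame_blend_weights` follow the same layout.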
            if should_export_posedirs || should_export_exprdirs {
                let mut full_morph_targets = nd::Array3::<f32>::zeros((num_total_blendshapes, vertex_count, 3));
                let mut running_idx_morph_target = 0;
                if should_export_posedirs {
                    let mut pose_morph_targets = nd::Array3::<f32>::zeros((metadata.num_pose_blend_shapes + 1, vertex_count, 3));
                    let nr_elem_merged = smpl_model.get_pose_dirs().dims()[0] / 3;
                    let pose_dirs_merged = smpl_model
                        .get_pose_dirs()
                        .to_ndarray()
                        .into_shape_with_order((nr_elem_merged, 3, metadata.num_pose_blend_shapes))
                        .unwrap();
                    let mapping = smpl_model.idx_split_2_merged_vec();
                    let cols = vec![0, 1, 2];
                    let depth = (0..metadata.num_pose_blend_shapes).collect::<Vec<_>>().into_boxed_slice();
                    let pose_blend_shapes = pose_dirs_merged
                        .gather(mapping, &cols, &depth)
                        .into_shape_with_order((vertex_count, 3, metadata.num_pose_blend_shapes))
                        .unwrap()
                        .permuted_axes([2, 0, 1]);
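                    // Besides the raw pose blend shapes, an extra "template offset" target
                    // equal to sum(-PI * B_i) is appended (a similar one with bound -7.0 is
                    // built for expressions below), presumably so that players which remap
                    // the rotation-driven weights into a non-negative range still
                    // reconstruct the same deformation.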
                    let pi = nd::Array1::<f32>::from_elem(metadata.num_pose_blend_shapes, -PI);
                    let pi_array = pi.insert_axis(nd::Axis(1)).insert_axis(nd::Axis(2));
                    assert_eq!(pose_blend_shapes.shape()[0], pi_array.len());
                    let template_offset = (pose_blend_shapes.clone() * &pi_array).sum_axis(nd::Axis(0));
                    pose_morph_targets
                        .slice_mut(s![0..metadata.num_pose_blend_shapes, .., ..])
                        .assign(&pose_blend_shapes);
                    pose_morph_targets
                        .slice_mut(s![metadata.num_pose_blend_shapes, .., ..])
                        .assign(&template_offset);
                    #[allow(clippy::range_plus_one)]
                    full_morph_targets
                        .slice_mut(s![
                            running_idx_morph_target..running_idx_morph_target + metadata.num_pose_blend_shapes + 1,
                            ..,
                            ..
                        ])
                        .assign(&pose_morph_targets);
                    running_idx_morph_target += metadata.num_pose_blend_shapes + 1;
                    gltf_codec.num_pose_morph_targets = metadata.num_pose_blend_shapes + 1;
                }
                #[allow(unused_assignments)]
                if should_export_exprdirs {
                    if let Some(expr_dirs) = smpl_model.get_expression_dirs() {
                        let mut expression_morph_targets = nd::Array3::<f32>::zeros((num_expression_blend_shapes + 1, vertex_count, 3));
                        let nr_elem_merged = expr_dirs.dims()[0] / 3;
                        let expression_dirs_merged = expr_dirs
                            .to_ndarray()
                            .into_shape_with_order((nr_elem_merged, 3, num_expression_blend_shapes))
                            .unwrap();
                        let mapping = smpl_model.idx_split_2_merged_vec();
                        let cols = vec![0, 1, 2];
                        let depth = (0..num_expression_blend_shapes).collect::<Vec<_>>().into_boxed_slice();
                        let expression_dirs_split = expression_dirs_merged
                            .gather(mapping, &cols, &depth)
                            .into_shape_with_order((vertex_count, 3, num_expression_blend_shapes))
                            .unwrap()
                            .permuted_axes([2, 0, 1]);
                        let expression_bounds = nd::Array1::<f32>::from_elem(num_expression_blend_shapes, -7.0);
                        let expression_bounds_array = expression_bounds.insert_axis(nd::Axis(1)).insert_axis(nd::Axis(2));
                        assert_eq!(expression_dirs_split.shape()[0], expression_bounds_array.len());
                        let template_offset = (expression_dirs_split.clone() * &expression_bounds_array).sum_axis(nd::Axis(0));
                        expression_morph_targets
                            .slice_mut(s![0..num_expression_blend_shapes, .., ..])
                            .assign(&expression_dirs_split);
                        expression_morph_targets
                            .slice_mut(s![num_expression_blend_shapes, .., ..])
                            .assign(&template_offset);
                        #[allow(clippy::range_plus_one)]
                        full_morph_targets
                            .slice_mut(s![
                                running_idx_morph_target..running_idx_morph_target + num_expression_blend_shapes + 1,
                                ..,
                                ..
                            ])
                            .assign(&expression_morph_targets);
                        running_idx_morph_target += num_expression_blend_shapes + 1;
                        gltf_codec.num_expression_morph_targets = num_expression_blend_shapes + 1;
                    }
                }
                gltf_codec.morph_targets = Some(full_morph_targets);
            }
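            // Sample the animation once per global scene frame: remap the pose between
            // SMPL variants, apply any pose override / retargeting, then record the
            // per-frame translation, rotation, scale, and morph weights.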
            for global_frame_idx in 0..nr_frames {
                keyframe_times.push((global_frame_idx as f32) / fps);
                if global_frame_idx < anim.start_offset || global_frame_idx > anim.start_offset + anim.num_animation_frames() {
                    continue;
                }
                let mut local_frame_idx = global_frame_idx - anim.start_offset;
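                // On the frame just past the clip's end, clamp to the last animation
                // frame so the exported pose holds instead of indexing out of range.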
                if global_frame_idx == (anim.start_offset + anim.num_animation_frames()) {
                    local_frame_idx -= 1;
                }
                let mut pose = anim.get_pose_at_idx(local_frame_idx);
                let pose_remap = PoseRemap::new(pose.smpl_type, smpl_params.smpl_type);
                pose = pose_remap.remap(&pose);
                if let Ok(ref pose_mask) = scene.get_comp::<&PoseOverride>(&entity) {
                    let mut new_pose_mask = PoseOverride::clone(pose_mask);
                    pose.apply_mask(&mut new_pose_mask);
                }
                if let Ok(ref pose_retarget) = scene.get_comp::<&RetargetPoseYShift>(&entity) {
                    let mut pose_retarget_local = RetargetPoseYShift::clone(pose_retarget);
                    pose_retarget_local.apply(&mut pose);
                }
                current_body_rotations.slice_mut(s![.., global_frame_idx, ..]).assign(&pose.joint_poses);
                let root_translation = smpl_output.joints.to_ndarray().slice(s![0, ..]).to_owned();
                let skeleton_root_translation = pose.global_trans.to_owned() + root_translation;
                current_body_translations
                    .slice_mut(s![global_frame_idx, ..])
                    .assign(&skeleton_root_translation);
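                // Frames outside the clip's range keep their zero scale, which
                // effectively hides the body before and after its animation.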
                if global_frame_idx < (anim.start_offset + anim.num_animation_frames()) {
                    current_body_scales.slice_mut(s![global_frame_idx, ..]).assign(&nd::Array1::ones(3));
                }
                let mut running_idx_morph_target = 0;
                if should_export_posedirs {
                    let pose_blend_weights = &smpl_model.compute_pose_feature(&pose);
                    current_per_frame_blend_weights
                        .slice_mut(s![global_frame_idx, 0..metadata.num_pose_blend_shapes])
                        .assign(pose_blend_weights);
                    if global_frame_idx == (anim.start_offset + anim.num_animation_frames()) {
                        current_per_frame_blend_weights
                            .slice_mut(s![global_frame_idx..nr_frames, 0..metadata.num_pose_blend_shapes])
                            .assign(pose_blend_weights);
                    }
                    running_idx_morph_target += metadata.num_pose_blend_shapes + 1;
                }
                #[allow(unused_assignments)]
                if should_export_exprdirs {
                    let expr_opt = anim.get_expression_at_idx(local_frame_idx);
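                    // Use only as many coefficients as both the animation frame and the
                    // model's expression space provide.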
                    if let Some(expr) = expr_opt.as_ref() {
                        let max_nr_expr_coeffs = num_expression_blend_shapes.min(expr.expr_coeffs.len());
                        let expr_coeffs = expr.expr_coeffs.slice(s![0..max_nr_expr_coeffs]);
                        current_per_frame_blend_weights
                            .slice_mut(s![
                                global_frame_idx,
                                running_idx_morph_target..running_idx_morph_target + max_nr_expr_coeffs
                            ])
                            .assign(&expr_coeffs);
                    }
                    running_idx_morph_target += num_expression_blend_shapes + 1;
                }
            }
            gltf_codec.keyframe_times = Some(keyframe_times);
            current_body.body_scales = Some(current_body_scales);
            current_body.body_translations = Some(current_body_translations);
            current_body.body_rotations = Some(current_body_rotations);
            if should_export_posedirs {
                current_per_frame_blend_weights
                    .slice_mut(s![.., metadata.num_pose_blend_shapes])
                    .assign(&nd::Array1::<f32>::zeros(nr_frames));
            }
            if should_export_exprdirs {
                // When pose dirs are also exported, the expression targets sit after the
                // pose targets and their template offset; otherwise they start at 0.
                let expr_offset_col = if should_export_posedirs {
                    metadata.num_pose_blend_shapes + 1 + num_expression_blend_shapes
                } else {
                    num_expression_blend_shapes
                };
                current_per_frame_blend_weights
                    .slice_mut(s![.., expr_offset_col])
                    .assign(&nd::Array1::<f32>::zeros(nr_frames));
            }
            if should_export_posedirs || should_export_exprdirs {
                current_body.per_frame_blend_weights = Some(current_per_frame_blend_weights);
            }
        }
        current_body.default_joint_translations = Some(smpl_joints);
        gltf_codec.per_body_data.push(current_body);
    }
    info!(
        "Wrote {} body scene to GltfCodec: took {} seconds for {} frames",
        num_bodies,
        now.elapsed().as_secs(),
        nr_frames
    );
    gltf_codec
}