smpl_gloss_integration/gltf.rs

use crate::scene::SceneAnimation;
use burn::backend::{Candle, NdArray, Wgpu};
use burn::prelude::Backend;
use gloss_img::dynamic_image::DynImage;
use gloss_renderer::{
    components::{DiffuseImg, MetalnessImg, Name, NormalImg, RoughnessImg},
    scene::Scene,
};
use gloss_utils::{
    bshare::{ToNalgebraFloat, ToNalgebraInt, ToNdArray},
    nshare::ToNalgebra,
};
use image::imageops::FilterType;
use log::info;
use nalgebra::DMatrix;
use ndarray::{self as nd, s};
use smpl_core::common::types::SmplType;
use smpl_core::{
    codec::gltf::GltfCodec,
    common::{metadata::smpl_metadata, pose::Pose, smpl_model::SmplCache, smpl_params::SmplParams},
    conversions::pose_remap::PoseRemap,
};
use smpl_core::{
    codec::{gltf::PerBodyData, scene::CameraTrack},
    common::{
        animation::Animation, betas::Betas, expression::Expression, pose_override::PoseOverride, pose_retarget::RetargetPoseYShift,
        smpl_model::SmplCacheDynamic, smpl_options::SmplOptions, types::UpAxis,
    },
};
use smpl_utils::array::{Gather2D, Gather3D};
use std::f32::consts::PI;

/// Creates a ``GltfCodec`` from a scene by extracting components from its entities.
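///
/// A minimal usage sketch; it assumes a ``Scene`` already populated with SMPL
/// bodies, and the entity name ``"body_0"`` is purely illustrative:
///
/// ```ignore
/// let codec = GltfCodec::from_scene(&scene, Some(2048), true);
/// let subset = GltfCodec::from_entities(&scene, None, false, vec!["body_0".to_string()]);
/// ```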
pub trait GltfCodecGloss {
    fn from_scene(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool) -> GltfCodec;
    fn from_entities(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool, entities: Vec<String>) -> GltfCodec;
}
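/// Prepares a texture image for export: optionally converts it to grayscale
/// and, if a maximum texture size is given, downscales images that exceed it.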
fn get_image(image: &DynImage, to_gray: bool, max_texture_size: Option<u32>) -> DynImage {
    let mut image = image.clone();
    if to_gray {
        image = image.grayscale();
    }
    match max_texture_size {
        Some(max_size) if image.width() > max_size => image.resize(max_size, max_size, FilterType::Gaussian),
        _ => image,
    }
}
/// Trait implementation for ``GltfCodec``
impl GltfCodecGloss for GltfCodec {
    /// Builds a ``GltfCodec`` from every body in the scene
    fn from_scene(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool) -> GltfCodec {
        let smpl_models = scene.get_resource::<&SmplCacheDynamic>().unwrap();
        match &*smpl_models {
            SmplCacheDynamic::NdArray(models) => from_scene_on_backend::<NdArray>(scene, models, max_texture_size, &None, export_camera),
            SmplCacheDynamic::Wgpu(models) => from_scene_on_backend::<Wgpu>(scene, models, max_texture_size, &None, export_camera),
            SmplCacheDynamic::Candle(models) => from_scene_on_backend::<Candle>(scene, models, max_texture_size, &None, export_camera),
        }
    }
    /// Builds a ``GltfCodec`` containing only the named entities
    fn from_entities(scene: &Scene, max_texture_size: Option<u32>, export_camera: bool, entities: Vec<String>) -> GltfCodec {
        let smpl_models = scene.get_resource::<&SmplCacheDynamic>().unwrap();
        match &*smpl_models {
            SmplCacheDynamic::NdArray(models) => from_scene_on_backend::<NdArray>(scene, models, max_texture_size, &Some(entities), export_camera),
            SmplCacheDynamic::Wgpu(models) => from_scene_on_backend::<Wgpu>(scene, models, max_texture_size, &Some(entities), export_camera),
            SmplCacheDynamic::Candle(models) => from_scene_on_backend::<Candle>(scene, models, max_texture_size, &Some(entities), export_camera),
        }
    }
}
/// Builds a ``GltfCodec`` from the scene on a generic Burn backend. Currently
/// supported backends: ``Candle``, ``NdArray``, and ``Wgpu``.
#[allow(clippy::too_many_lines)]
#[allow(clippy::trivially_copy_pass_by_ref)]
fn from_scene_on_backend<B: Backend>(
    scene: &Scene,
    smpl_models: &SmplCache<B>,
    max_texture_size: Option<u32>,
    entities: &Option<Vec<String>>,
    export_camera: bool,
) -> GltfCodec
where
    <B as Backend>::FloatTensorPrimitive<2>: Sync,
    <B as Backend>::IntTensorPrimitive<2>: Sync,
    B::QuantizedTensorPrimitive<1>: std::marker::Sync,
    B::QuantizedTensorPrimitive<2>: std::marker::Sync,
    B::QuantizedTensorPrimitive<3>: std::marker::Sync,
{
    let now = wasm_timer::Instant::now();
    let mut gltf_codec = GltfCodec::default();
    let scene_anim = scene.get_resource::<&SceneAnimation>().unwrap();
    let nr_frames = scene_anim.num_frames;
    let fps = scene_anim.config.fps;
    if export_camera {
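        // Export at most one camera track; if several entities carry a
        // CameraTrack, the last one queried wins.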
        let mut cameras_query = scene.world.query::<&CameraTrack>();
        for (_, camera_track) in cameras_query.iter() {
            gltf_codec.camera_track = Some(camera_track.clone());
        }
    }
    let mut query = scene.world.query::<(&SmplParams, &Name)>();
    let num_bodies = query.iter().len();
    gltf_codec.num_bodies = num_bodies;
    let mut should_export_posedirs = false;
    let mut should_export_exprdirs = false;
    let mut num_expression_blend_shapes = 0;
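    // First pass over the bodies: decide which morph targets (pose
    // correctives and expression blend shapes) the export needs.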
    for (entity, (smpl_params, _name)) in query.iter() {
        if scene.world.has::<Animation>(entity).unwrap() && smpl_params.enable_pose_corrective {
            should_export_posedirs = true;
        }
        let smpl_model = smpl_models.get_model_ref(smpl_params.smpl_type, smpl_params.gender).unwrap();
        if let Ok(anim) = scene.get_comp::<&Animation>(&entity) {
            if anim.has_expression() && smpl_model.get_expression_dirs().is_some() {
                should_export_exprdirs = true;
                num_expression_blend_shapes = smpl_model.get_expression_dirs().unwrap().shape().dims[1];
            }
        }
    }
    for (body_idx, (entity, (smpl_params, name))) in query.iter().enumerate() {
        if let Some(entities) = entities {
            if !entities.contains(&name.0) {
                continue;
            }
        }
        let smpl_version = smpl_params.smpl_type;
        let gender = smpl_params.gender as i32;
        let mut current_body = PerBodyData::default();
        assert!(smpl_version != SmplType::SmplPP, "GLTF export for SMPL++ is not supported yet!");
        let smpl_model = smpl_models.get_model_ref(smpl_params.smpl_type, smpl_params.gender).unwrap();
        let Ok(betas) = scene.get_comp::<&Betas>(&entity) else {
            panic!("Betas component does not exist!");
        };
        let default_pose = Pose::new_empty(UpAxis::Y, smpl_params.smpl_type);
        let default_expression = Expression::new_empty(10);
        let mut smpl_output = smpl_model.forward(&SmplOptions::default(), &betas, &default_pose, Some(&default_expression));
        smpl_output.compute_normals();
        smpl_output = smpl_model.create_body_with_uv(&smpl_output);
        let metadata = smpl_metadata(&smpl_params.smpl_type);
        let mut num_total_blendshapes = 0;
        if should_export_posedirs {
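            // +1 for the constant template-offset morph target used by the
            // pose-corrective weight remapping (see the animation export below).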
            num_total_blendshapes += metadata.num_pose_blend_shapes + 1;
        }
        if should_export_exprdirs {
            num_total_blendshapes += num_expression_blend_shapes;
        }
        gltf_codec.smpl_type = smpl_version;
        gltf_codec.gender = gender;
        current_body.pose = Some(default_pose.clone());
        gltf_codec.default_joint_poses = Some(default_pose.joint_poses.clone());
        current_body.body_translation = Some(default_pose.global_trans.to_shape((1, 3)).unwrap().to_owned());
        let verts_na = smpl_output.verts.to_nalgebra();
        let normals_na = smpl_output.normals.as_ref().expect("SMPL Output is missing normals!").to_nalgebra();
        let faces_na = smpl_output.faces.to_nalgebra();
        let uvs_na = smpl_output.uvs.as_ref().expect("SMPL Output is missing UVs!").to_nalgebra();
        current_body.positions = Some(verts_na);
        current_body.normals = Some(normals_na);
        gltf_codec.faces = Some(faces_na);
        gltf_codec.uvs = Some(uvs_na);
        let smpl_joints = smpl_output.joints.clone().to_ndarray();
        let joint_count = smpl_joints.shape()[0];
        let lbs_weights = smpl_model.lbs_weights_split().to_ndarray();
        let vertex_count = smpl_output.verts.dims()[0];
        let mut skin_vertex_index = DMatrix::<u32>::zeros(vertex_count, 4);
        let mut skin_vertex_weight = DMatrix::<f32>::zeros(vertex_count, 4);
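        // glTF skinning stores four influences per vertex (JOINTS_0 / WEIGHTS_0),
        // so keep only the four largest LBS weights for each vertex.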
        for (vertex_id, row) in lbs_weights.outer_iter().enumerate() {
            let mut vertex_weights: Vec<(usize, f32)> = row.iter().enumerate().map(|(index, &weight)| (index, weight)).collect();
            vertex_weights.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            assert!(vertex_weights.len() >= 4, "Expected at least 4 skinning weights per vertex");
            for (i, (index, weight)) in vertex_weights.iter().take(4).enumerate() {
                skin_vertex_index[(vertex_id, i)] = u32::try_from(*index).expect("Cannot convert to u32!");
                skin_vertex_weight[(vertex_id, i)] = *weight;
            }
        }
        gltf_codec.joint_index = Some(skin_vertex_index);
        gltf_codec.joint_weight = Some(skin_vertex_weight);
        let diffuse_img = scene.get_comp::<&DiffuseImg>(&entity);
        if let Ok(diffuse_img) = diffuse_img {
            if let Some(img) = &diffuse_img.generic_img.cpu_img {
                current_body.diffuse_textures = Some(get_image(img, false, max_texture_size));
            }
        }
        let normals_img = scene.get_comp::<&NormalImg>(&entity);
        if let Ok(normals_img) = normals_img {
            if let Some(img) = &normals_img.generic_img.cpu_img {
                current_body.normals_textures = Some(get_image(img, false, max_texture_size));
            }
        }
        let metalness_img = scene.get_comp::<&MetalnessImg>(&entity);
        if let Ok(metalness_img) = metalness_img {
            if let Some(img) = &metalness_img.generic_img.cpu_img {
                current_body.metalness_textures = Some(get_image(img, true, max_texture_size));
            }
        }
        let roughness_img = scene.get_comp::<&RoughnessImg>(&entity);
        if let Ok(roughness_img) = roughness_img {
            if let Some(img) = &roughness_img.generic_img.cpu_img {
                current_body.roughness_textures = Some(get_image(img, true, max_texture_size));
            }
        }
        if scene.world.has::<Pose>(entity).unwrap() && !scene.world.has::<Animation>(entity).unwrap() {
            let Ok(pose_ref) = scene.get_comp::<&Pose>(&entity) else {
                panic!("Pose component doesn't exist");
            };
            let current_pose: &Pose = &pose_ref;
            let current_body_translation = current_pose.global_trans.to_shape((1, 3)).unwrap().to_owned();
            current_body.pose = Some(current_pose.clone());
            current_body.body_translation = Some(current_body_translation);
            if smpl_params.enable_pose_corrective {
                let vertex_offsets_merged = smpl_model.compute_pose_correctives(current_pose).to_ndarray();
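                // Pose correctives are computed on the merged (seam-welded)
                // topology; gather them onto the UV-split vertex set used for export.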
                let mapping = &smpl_model.idx_split_2_merged_vec();
                let cols = vec![0, 1, 2];
                let vertex_offsets = vertex_offsets_merged.gather(mapping, &cols).into_nalgebra();
                current_body.positions = Some(current_body.positions.as_ref().unwrap() + vertex_offsets);
            }
        }
        #[allow(clippy::cast_precision_loss)]
        if scene.world.has::<Animation>(entity).unwrap() {
            info!("Processing Animation for body {:?}", body_idx);
            let anim = scene.get_comp::<&Animation>(&entity).unwrap();
            gltf_codec.frame_count = Some(nr_frames);
            let mut keyframe_times: Vec<f32> = Vec::new();
            let mut current_body_rotations = nd::Array3::<f32>::zeros((joint_count, nr_frames, 3));
            let mut current_body_translations = nd::Array2::<f32>::zeros((nr_frames, 3));
            let mut current_body_scales = nd::Array2::<f32>::zeros((nr_frames, 3));
            let mut current_per_frame_blend_weights = nd::Array2::<f32>::zeros((nr_frames, num_total_blendshapes));
            if should_export_posedirs || should_export_exprdirs {
                let mut full_morph_targets = nd::Array3::<f32>::zeros((num_total_blendshapes, vertex_count, 3));
                let mut running_idx_morph_target = 0;
                if should_export_posedirs {
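                    // glTF morph weights are exported in [0, 1], while the SMPL pose
                    // feature entries live in [-PI, PI]. Bake the affine remap into
                    // the targets: scale each blend shape B_k by 2*PI and add one
                    // constant target -PI * sum_k(B_k), driven with weight 1.0, so that
                    //   sum_k w_k * (2*PI*B_k) - PI * sum_k(B_k) = sum_k theta_k * B_k
                    // for w_k = (theta_k + PI) / (2*PI).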
                    let mut pose_morph_targets = nd::Array3::<f32>::zeros((metadata.num_pose_blend_shapes + 1, vertex_count, 3));
                    let nr_elem_merged = smpl_model.get_pose_dirs().dims()[0] / 3;
                    let pose_dirs_merged = smpl_model
                        .get_pose_dirs()
                        .to_ndarray()
                        .into_shape_with_order((nr_elem_merged, 3, metadata.num_pose_blend_shapes))
                        .unwrap();
                    let mapping = smpl_model.idx_split_2_merged_vec();
                    let cols = vec![0, 1, 2];
                    let depth = (0..metadata.num_pose_blend_shapes).collect::<Vec<_>>().into_boxed_slice();
                    let pose_blend_shapes = pose_dirs_merged
                        .gather(mapping, &cols, &depth)
                        .into_shape_with_order((vertex_count, 3, metadata.num_pose_blend_shapes))
                        .unwrap()
                        .permuted_axes([2, 0, 1]);
                    let morph_targets = (2.0 * PI) * pose_blend_shapes.clone();
                    let pi = nd::Array1::<f32>::from_elem(metadata.num_pose_blend_shapes, -PI);
                    let pi_array = pi.insert_axis(nd::Axis(1)).insert_axis(nd::Axis(2));
                    assert_eq!(pose_blend_shapes.shape()[0], pi_array.len());
                    let template_offset = (pose_blend_shapes * &pi_array).sum_axis(nd::Axis(0));
                    pose_morph_targets
                        .slice_mut(s![0..metadata.num_pose_blend_shapes, .., ..])
                        .assign(&morph_targets);
                    pose_morph_targets
                        .slice_mut(s![metadata.num_pose_blend_shapes, .., ..])
                        .assign(&template_offset);
                    #[allow(clippy::range_plus_one)]
                    full_morph_targets
                        .slice_mut(s![
                            running_idx_morph_target..running_idx_morph_target + metadata.num_pose_blend_shapes + 1,
                            ..,
                            ..
                        ])
                        .assign(&pose_morph_targets);
                    running_idx_morph_target += metadata.num_pose_blend_shapes + 1;
                }
                #[allow(unused_assignments)]
                if should_export_exprdirs {
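                    // Expression blend shapes are appended after the pose targets;
                    // their per-frame coefficients are written directly as morph
                    // weights, without remapping.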
                    if let Some(expr_dirs) = smpl_model.get_expression_dirs() {
                        let nr_elem_merged = expr_dirs.dims()[0] / 3;
                        let expression_dirs_merged = expr_dirs
                            .to_ndarray()
                            .into_shape_with_order((nr_elem_merged, 3, num_expression_blend_shapes))
                            .unwrap();
                        let mapping = smpl_model.idx_split_2_merged_vec();
                        let cols = vec![0, 1, 2];
                        let depth = (0..metadata.expression_space_dim).collect::<Vec<_>>().into_boxed_slice();
                        let expression_dirs_split = expression_dirs_merged
                            .gather(mapping, &cols, &depth)
                            .into_shape_with_order((vertex_count, 3, num_expression_blend_shapes))
                            .unwrap()
                            .permuted_axes([2, 0, 1]);
                        full_morph_targets
                            .slice_mut(s![
                                running_idx_morph_target..running_idx_morph_target + num_expression_blend_shapes,
                                ..,
                                ..
                            ])
                            .assign(&expression_dirs_split);
                        running_idx_morph_target += num_expression_blend_shapes;
                    }
                }
                gltf_codec.morph_targets = Some(full_morph_targets);
            }
            for global_frame_idx in 0..nr_frames {
                keyframe_times.push((global_frame_idx as f32) / fps);
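                // Frames outside this body's animation window are skipped; their
                // zero-initialized scale keeps the body collapsed (hidden) there,
                // and the final window frame is clamped to the last animation pose.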
                if global_frame_idx < anim.start_offset || global_frame_idx > anim.start_offset + anim.num_animation_frames() {
                    continue;
                }
                let mut local_frame_idx = global_frame_idx - anim.start_offset;
                if global_frame_idx == (anim.start_offset + anim.num_animation_frames()) {
                    local_frame_idx -= 1;
                }
                let mut pose = anim.get_pose_at_idx(local_frame_idx);
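                // The animation may use a different SMPL flavor than the body's
                // model; remap the pose joints accordingly.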
                let pose_remap = PoseRemap::new(pose.smpl_type, smpl_params.smpl_type);
                pose = pose_remap.remap(&pose);
                if let Ok(ref pose_mask) = scene.get_comp::<&PoseOverride>(&entity) {
                    let mut new_pose_mask = PoseOverride::clone(pose_mask);
                    pose.apply_mask(&mut new_pose_mask);
                }
                if let Ok(ref pose_retarget) = scene.get_comp::<&RetargetPoseYShift>(&entity) {
                    let mut pose_retarget_local = RetargetPoseYShift::clone(pose_retarget);
                    pose_retarget_local.apply(&mut pose);
                }
                current_body_rotations.slice_mut(s![.., global_frame_idx, ..]).assign(&pose.joint_poses);
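                // The exported skeleton root sits at the rest-pose root joint, so
                // offset the SMPL global translation by that joint's position.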
                let mut skeleton_root_translation = pose.global_trans.to_owned();
                let root_translation = smpl_output.joints.to_ndarray().slice(s![0, ..]).to_owned();
                skeleton_root_translation = skeleton_root_translation + root_translation;
                current_body_translations
                    .slice_mut(s![global_frame_idx, ..])
                    .assign(&skeleton_root_translation);
                if global_frame_idx < (anim.start_offset + anim.num_animation_frames()) {
                    current_body_scales.slice_mut(s![global_frame_idx, ..]).assign(&nd::Array1::ones(3));
                }
                let mut running_idx_morph_target = 0;
                if should_export_posedirs {
                    let pose_blend_weights = &smpl_model.compute_pose_feature(&pose);
                    let rescaled_pose_blend_weights = pose_blend_weights.map(|&elem| (elem + PI) / (2.0 * PI));
                    current_per_frame_blend_weights
                        .slice_mut(s![global_frame_idx, 0..metadata.num_pose_blend_shapes])
                        .assign(&rescaled_pose_blend_weights);
                    if global_frame_idx == (anim.start_offset + anim.num_animation_frames()) {
                        current_per_frame_blend_weights
                            .slice_mut(s![global_frame_idx..nr_frames, 0..metadata.num_pose_blend_shapes])
                            .assign(&rescaled_pose_blend_weights);
                    }
                    running_idx_morph_target += metadata.num_pose_blend_shapes + 1;
                }
                #[allow(unused_assignments)]
                if should_export_exprdirs {
                    let expr_opt = anim.get_expression_at_idx(local_frame_idx);
                    if let Some(expr) = expr_opt.as_ref() {
                        let max_nr_expr_coeffs = num_expression_blend_shapes.min(expr.expr_coeffs.len());
                        let expr_coeffs = expr.expr_coeffs.slice(s![0..max_nr_expr_coeffs]);
                        current_per_frame_blend_weights
                            .slice_mut(s![
                                global_frame_idx,
                                running_idx_morph_target..running_idx_morph_target + max_nr_expr_coeffs
                            ])
                            .assign(&expr_coeffs);
                    }
                    running_idx_morph_target += num_expression_blend_shapes;
                }
            }
            gltf_codec.keyframe_times = Some(keyframe_times);
            current_body.body_scales = Some(current_body_scales);
            current_body.body_translations = Some(current_body_translations);
            current_body.body_rotations = Some(current_body_rotations);
            if should_export_posedirs {
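                // Drive the constant template-offset target (the last pose morph
                // target) with weight 1.0 on every frame.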
                current_per_frame_blend_weights
                    .slice_mut(s![.., metadata.num_pose_blend_shapes])
                    .assign(&nd::Array1::<f32>::from_elem(nr_frames, 1.0));
            }
            if should_export_posedirs || should_export_exprdirs {
                current_body.per_frame_blend_weights = Some(current_per_frame_blend_weights);
            }
        }
        current_body.default_joint_translations = Some(smpl_joints);
        gltf_codec.per_body_data.push(current_body);
    }
    info!(
        "Writing {}-body scene to GltfCodec took {} seconds for {} frames",
        num_bodies,
        now.elapsed().as_secs(),
        nr_frames
    );
    gltf_codec
}