wgpu_3dgs_core/
gaussian.rs

1use std::io::BufRead;
2
3use glam::*;
4
5use crate::{
6    PlyGaussianPod, PlyGaussians, SpzGaussian, SpzGaussianPosition, SpzGaussianPositionRef,
7    SpzGaussianRef, SpzGaussianRotation, SpzGaussianRotationRef, SpzGaussianSh, SpzGaussians,
8    SpzGaussiansHeader,
9};
10
/// An iterable collection of [`Gaussian`].
///
/// Through the [`FromIterator`] supertrait, every implementor can also be collected from an
/// iterator of [`Gaussian`].
pub trait IterGaussian: FromIterator<Gaussian> {
    /// Iterate over the collection, yielding every element as an owned [`Gaussian`].
    ///
    /// The returned iterator borrows `self` and knows its exact length.
    fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_;
}
16
17impl IterGaussian for Vec<Gaussian> {
18    fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
19        self.iter().copied()
20    }
21}
22
23/// A trait of representing a [`IterGaussian`] that can be read from a buffer.
24pub trait ReadIterGaussian: IterGaussian {
25    /// Read from a buffer.
26    fn read_from(reader: &mut impl BufRead) -> std::io::Result<Self>;
27
28    /// Read from a file.
29    fn read_from_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
30        let file = std::fs::File::open(path)?;
31        let mut reader = std::io::BufReader::new(file);
32        Self::read_from(&mut reader)
33    }
34}
35
36/// A trait of representing a [`IterGaussian`] that can be written to a buffer.
37pub trait WriteIterGaussian: IterGaussian {
38    /// Write to a buffer.
39    fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()>;
40
41    /// Write to a file.
42    fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
43        let file = std::fs::File::create(path)?;
44        let mut writer = std::io::BufWriter::new(file);
45        self.write_to(&mut writer)
46    }
47}
48
/// The Gaussian.
///
/// This is an intermediate representation used by the CPU to convert to
/// [`GaussianPod`](crate::GaussianPod).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Gaussian {
    /// The rotation quaternion.
    pub rot: Quat,
    /// The position.
    pub pos: Vec3,
    /// The color: red, green and blue in `x`, `y`, `z` and the opacity in `w`, each in
    /// `0..=255`.
    pub color: U8Vec4,
    /// The 15 SH coefficients above degree 0 (3 for degree 1, 5 for degree 2, 7 for degree 3),
    /// one [`Vec3`] of per-channel values per coefficient.
    pub sh: [Vec3; 15],
    /// The scale in linear space, i.e. the exponential of the log-scale stored by the PLY and
    /// SPZ formats.
    pub scale: Vec3,
}
61
62impl Gaussian {
    /// The constant to convert from SH coefficient at degree 0 to linear color.
    ///
    /// [`Gaussian::from_ply`] multiplies the degree 0 coefficient by this factor and adds
    /// `0.5` to obtain a normalized color channel.
    pub const SH0_TO_LINEAR_FACTOR: f32 = 0.2820948;

    /// The constant to convert from SH coefficient at degree 0 to linear color in SPZ.
    ///
    /// Only its ratio to [`Gaussian::SH0_TO_LINEAR_FACTOR`] is used when converting between
    /// SPZ color bytes and [`Gaussian::color`].
    pub const SPZ_SH0_TO_LINEAR_FACTOR: f32 = 0.15;
68
69    /// Convert from [`PlyGaussianPod`].
70    pub fn from_ply(ply: &PlyGaussianPod) -> Self {
71        let pos = Vec3::from_array(ply.pos);
72
73        let rot = Quat::from_xyzw(ply.rot[1], ply.rot[2], ply.rot[3], ply.rot[0]).normalize();
74
75        let scale = Vec3::from_array(ply.scale).exp();
76
77        let color = ((Vec3::from_array(ply.color) * Self::SH0_TO_LINEAR_FACTOR + Vec3::splat(0.5))
78            * 255.0)
79            .extend((1.0 / (1.0 + (-ply.alpha).exp())) * 255.0)
80            .clamp(Vec4::splat(0.0), Vec4::splat(255.0))
81            .as_u8vec4();
82
83        let sh = std::array::from_fn(|i| Vec3::new(ply.sh[i], ply.sh[i + 15], ply.sh[i + 30]));
84
85        Self {
86            rot,
87            pos,
88            color,
89            sh,
90            scale,
91        }
92    }
93
94    /// Convert to [`PlyGaussianPod`].
95    pub fn to_ply(&self) -> PlyGaussianPod {
96        let pos = self.pos.to_array();
97
98        let rot = [self.rot.w, self.rot.x, self.rot.y, self.rot.z];
99
100        let scale = self.scale.map(|x| x.ln()).to_array();
101
102        let rgba = self.color.as_vec4() / 255.0;
103        let color = ((rgba.xyz() - Vec3::splat(0.5)) / Self::SH0_TO_LINEAR_FACTOR).to_array();
104
105        let alpha = -(1.0 / rgba.w - 1.0).ln();
106
107        let mut sh = [0.0; 3 * 15];
108        for i in 0..15 {
109            sh[i] = self.sh[i].x;
110            sh[i + 15] = self.sh[i].y;
111            sh[i + 30] = self.sh[i].z;
112        }
113
114        let normal = [0.0, 0.0, 1.0];
115
116        PlyGaussianPod {
117            pos,
118            normal,
119            color,
120            sh,
121            alpha,
122            scale,
123            rot,
124        }
125    }
126
    // Constants of the affine map between an SPZ color byte and a `Gaussian::color` channel:
    //
    //     linear = spz * SPZ_COLOR_TO_LINEAR_FRAC_A_B + SPZ_COLOR_TO_LINEAR_C
    //
    // Both encodings are `sh0 * factor * 255 + 0.5 * 255` with their own factor, so the slope
    // is the ratio of the two factors and the offset is `(1 - slope) * 0.5 * 255`.

    /// Slope of the SPZ color byte to linear color byte map.
    const SPZ_COLOR_TO_LINEAR_FRAC_A_B: f32 =
        Gaussian::SH0_TO_LINEAR_FACTOR / Gaussian::SPZ_SH0_TO_LINEAR_FACTOR;
    /// The `0.5` color offset expressed in byte units.
    const SPZ_COLOR_TO_LINEAR_FRAC_F2_F1: f32 = 0.5 * 255.0;
    /// Offset of the SPZ color byte to linear color byte map.
    const SPZ_COLOR_TO_LINEAR_C: f32 =
        (1.0 - Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B) * Self::SPZ_COLOR_TO_LINEAR_FRAC_F2_F1;
132
133    /// Convert from [`SpzGaussianRef`].
134    pub fn from_spz(spz: SpzGaussianRef, header: &SpzGaussiansHeader) -> Self {
135        let pos = match spz.position {
136            SpzGaussianPositionRef::Float16(pos) => {
137                // The Niantic SPZ format matches the `half` crate's f16 const conversion.
138                let unpacked = pos.map(|c| half::f16::from_bits(c).to_f32_const());
139                Vec3::from_array(unpacked)
140            }
141            SpzGaussianPositionRef::FixedPoint24(pos) => {
142                let scale = 1.0 / (1 << header.fractional_bits()) as f32;
143                let unpacked = pos.map(|c| {
144                    let mut fixed32: i32 = c[0] as i32;
145                    fixed32 |= (c[1] as i32) << 8;
146                    fixed32 |= (c[2] as i32) << 16;
147                    fixed32 |= if fixed32 & 0x800000 != 0 {
148                        0xff000000u32 as i32
149                    } else {
150                        0
151                    };
152                    fixed32 as f32 * scale
153                });
154                Vec3::from_array(unpacked)
155            }
156        };
157
158        let scale = Vec3::from_array(spz.scale.map(|c| c as f32 / 16.0 - 10.0)).exp();
159
160        let rot = match spz.rotation {
161            SpzGaussianRotationRef::QuatFirstThree(quat) => {
162                let xyz = Vec3::from(quat.map(|c| c as f32 / 127.5 - 1.0));
163                let w = (1.0 - xyz.length_squared()).max(0.0).sqrt();
164                Quat::from_xyzw(xyz.x, xyz.y, xyz.z, w)
165            }
166            SpzGaussianRotationRef::QuatSmallestThree(quat) => {
167                let mut comp: u32 = quat[0] as u32
168                    | ((quat[1] as u32) << 8)
169                    | ((quat[2] as u32) << 16)
170                    | ((quat[3] as u32) << 24);
171
172                const C_MASK: u32 = (1 << 9) - 1;
173
174                let largest_index = (comp >> 30) as usize;
175                let mut sum_squares = 0.0f32;
176                let mut comps = std::array::from_fn(|i| {
177                    if i == largest_index {
178                        return 0.0;
179                    }
180
181                    let mag = comp & C_MASK;
182                    let neg_bit = (comp >> 9) & 1;
183                    comp >>= 10;
184
185                    let value = std::f32::consts::FRAC_1_SQRT_2
186                        * (mag as f32 / C_MASK as f32)
187                        * if neg_bit != 0 { -1.0 } else { 1.0 };
188                    sum_squares += value * value;
189
190                    value
191                });
192
193                comps[largest_index] = (1.0 - sum_squares).max(0.0).sqrt();
194
195                Quat::from_array(comps)
196            }
197        };
198
199        let color = U8Vec3::from_array(spz.color.map(|c| {
200            (c as f32 * Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B + Self::SPZ_COLOR_TO_LINEAR_C)
201                .clamp(0.0, 255.0) as u8
202        }))
203        .extend(*spz.alpha);
204
205        let mut sh = [Vec3::ZERO; 15];
206        for (src, dst) in spz.sh.iter().zip(sh.iter_mut()) {
207            *dst = Vec3::from_array(src.map(|c| (c as f32 - 128.0) / 128.0));
208        }
209
210        Self {
211            rot,
212            pos,
213            color,
214            sh,
215            scale,
216        }
217    }
218
219    /// Convert to [`SpzGaussian`].
220    ///
221    /// User usually don't need to call this directly due to the overhead of constructing a
222    /// valid [`SpzGaussiansHeader`]. Instead, use one of the following methods to convert a
223    /// collection of [`Gaussian`] to [`SpzGaussians`](crate::SpzGaussians) properly:
224    ///
225    /// - [`SpzGaussians::from_gaussians`](crate::SpzGaussians::from_gaussians)
226    /// - [`SpzGaussians::from_gaussians_with_options`](crate::SpzGaussians::from_gaussians_with_options)
227    pub fn to_spz(
228        &self,
229        header: &SpzGaussiansHeader,
230        options: &GaussianToSpzOptions,
231    ) -> SpzGaussian {
232        let position = if header.uses_float16() {
233            let packed = self
234                .pos
235                .to_array()
236                .map(|c| half::f16::from_f32_const(c).to_bits());
237            SpzGaussianPosition::Float16(packed)
238        } else {
239            let scale = (1 << header.fractional_bits()) as f32;
240            let packed = self.pos.to_array().map(|c| {
241                let fixed32 = (c * scale).round() as i32;
242                [
243                    (fixed32 & 0xff) as u8,
244                    ((fixed32 >> 8) & 0xff) as u8,
245                    ((fixed32 >> 16) & 0xff) as u8,
246                ]
247            });
248            SpzGaussianPosition::FixedPoint24(packed)
249        };
250
251        let scale = self
252            .scale
253            .to_array()
254            .map(|c| ((c.ln() + 10.0) * 16.0).round().clamp(0.0, 255.0) as u8);
255
256        let rotation = if header.uses_quat_smallest_three() {
257            let rot = self.rot.normalize().to_array();
258            let largest_index = rot
259                .into_iter()
260                .map(f32::abs)
261                .enumerate()
262                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263                .expect("quaternion has at least one component")
264                .0;
265
266            const C_MASK: u32 = (1 << 9) - 1;
267
268            let negate = (rot[largest_index] < 0.0) as u32;
269
270            let mut comp = largest_index as u32;
271            for (i, &value) in rot.iter().enumerate() {
272                if i == largest_index {
273                    continue;
274                }
275
276                let neg_bit = (value < 0.0) as u32 ^ negate;
277                let mag = (C_MASK as f32 * (value.abs() * std::f32::consts::SQRT_2) + 0.5)
278                    .clamp(0.0, C_MASK as f32 - 1.0) as u32;
279                comp = (comp << 10) | (neg_bit << 9) | mag;
280            }
281
282            SpzGaussianRotation::QuatSmallestThree([
283                (comp & 0xff) as u8,
284                ((comp >> 8) & 0xff) as u8,
285                ((comp >> 16) & 0xff) as u8,
286                ((comp >> 24) & 0xff) as u8,
287            ])
288        } else {
289            let rot = self.rot.normalize();
290            let rot = if rot.w < 0.0 { -rot } else { rot };
291            let packed = rot
292                .xyz()
293                .to_array()
294                .map(|c| ((c + 1.0) * 127.5).round().clamp(0.0, 255.0) as u8);
295            SpzGaussianRotation::QuatFirstThree(packed)
296        };
297
298        let alpha = self.color.w;
299
300        let color = self
301            .color
302            .map(|c| {
303                ((c as f32 - Self::SPZ_COLOR_TO_LINEAR_C) / Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B)
304                    .clamp(0.0, 255.0) as u8
305            })
306            .xyz()
307            .to_array();
308
309        let sh = match header.sh_degree().get() {
310            0 => SpzGaussianSh::Zero,
311            deg @ 1..=3 => {
312                let mut sh = match deg {
313                    1 => SpzGaussianSh::One([[0; 3]; 3]),
314                    2 => SpzGaussianSh::Two([[0; 3]; 8]),
315                    3 => SpzGaussianSh::Three([[0; 3]; 15]),
316                    _ => unreachable!(),
317                };
318
319                fn quantize_sh(x: f32, bucket_size: u32) -> u8 {
320                    let q = (x * 128.0 + 128.0).round() as u32;
321                    let q = if bucket_size >= 8 {
322                        q
323                    } else {
324                        (q + bucket_size / 2) / bucket_size * bucket_size
325                    };
326                    q.clamp(0, 255) as u8
327                }
328
329                for (src, dst) in self.sh.iter().zip(sh.iter_mut()) {
330                    let bucket_size = options
331                        .sh_bucket_size(deg)
332                        .expect("header SH degree is valid");
333                    *dst = src.to_array().map(|x| quantize_sh(x, bucket_size));
334                }
335
336                sh
337            }
338            _ => {
339                // SAFETY: SpzGaussianShDegree is guaranteed to be in [0, 3].
340                unreachable!()
341            }
342        };
343
344        SpzGaussian {
345            position,
346            scale,
347            rotation,
348            color,
349            alpha,
350            sh,
351        }
352    }
353}
354
// `AsRef<Gaussian>` is implemented for `Gaussian` (and is thereby also available on `&Gaussian`
// through the standard library's blanket impl for references) because other source formats
// frequently use `from_iter` over either owned or borrowed Gaussians.

impl AsRef<Gaussian> for Gaussian {
    fn as_ref(&self) -> &Gaussian {
        self
    }
}
363
/// Extra options for [`Gaussian::to_spz`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GaussianToSpzOptions {
    /// The quantization bits for each SH degree.
    ///
    /// Index `0` is for degree 1, index `1` for degree 2 and index `2` for degree 3. Values of
    /// 8 or more keep the full 8-bit precision.
    pub sh_quantize_bits: [u32; 3],
}

impl GaussianToSpzOptions {
    /// Get the bits for the given SH degree.
    ///
    /// Returns [`None`] if `degree` is not in `1..=3`.
    pub fn sh_bits(&self, degree: u8) -> Option<u32> {
        let index = usize::from(degree).checked_sub(1)?;
        self.sh_quantize_bits.get(index).copied()
    }

    /// Get the quantization bucket size for the given SH degree.
    ///
    /// The bucket size is `2^(8 - bits)`. When more than 8 bits are requested it saturates at
    /// `1` (no quantization) instead of overflowing.
    ///
    /// Returns [`None`] if `degree` is not in `1..=3`.
    pub fn sh_bucket_size(&self, degree: u8) -> Option<u32> {
        self.sh_bits(degree)
            .map(|bits| 1u32 << 8u32.saturating_sub(bits))
    }
}

impl Default for GaussianToSpzOptions {
    /// 5 bits for degree 1, 4 bits for degrees 2 and 3.
    fn default() -> Self {
        Self {
            sh_quantize_bits: [5, 4, 4],
        }
    }
}
393
/// A discriminant representation of [`Gaussians`].
///
/// It names the variant of [`Gaussians`] without carrying any data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GaussiansSource {
    /// Corresponds to [`Gaussians::Internal`].
    Internal,
    /// Corresponds to [`Gaussians::Ply`].
    Ply,
    /// Corresponds to [`Gaussians::Spz`].
    Spz,
}
401
402impl From<&Gaussians> for GaussiansSource {
403    fn from(value: &Gaussians) -> Self {
404        match value {
405            Gaussians::Internal(_) => GaussiansSource::Internal,
406            Gaussians::Ply(_) => GaussiansSource::Ply,
407            Gaussians::Spz(_) => GaussiansSource::Spz,
408        }
409    }
410}
411
/// A unified Gaussian representation.
///
/// [`Gaussians::Internal`] variant contains Gaussians in the [`Gaussian`] format, which is the one
/// converted to [`GaussianPod`](crate::GaussianPod) directly.
///
/// Other variants contain Gaussians in their respective source file formats.
#[derive(Debug, Clone, PartialEq)]
pub enum Gaussians {
    /// Gaussians in the [`Gaussian`] format.
    Internal(Vec<Gaussian>),
    /// Gaussians in the PLY format.
    Ply(PlyGaussians),
    /// Gaussians in the SPZ format.
    Spz(SpzGaussians),
}
424
425impl Gaussians {
426    /// Create a collection of Gaussians from an iterator of [`Gaussian`] with the given source.
427    pub fn from_gaussians_iter(
428        iter: impl Iterator<Item = Gaussian>,
429        source: GaussiansSource,
430    ) -> Self {
431        match source {
432            GaussiansSource::Internal => Gaussians::Internal(iter.collect()),
433            GaussiansSource::Ply => Gaussians::Ply(iter.collect()),
434            GaussiansSource::Spz => Gaussians::Spz(iter.collect()),
435        }
436    }
437
438    /// Get the source representation of the Gaussians.
439    pub fn source(&self) -> GaussiansSource {
440        GaussiansSource::from(self)
441    }
442
443    /// Get the number of Gaussians.
444    pub fn len(&self) -> usize {
445        match self {
446            Gaussians::Internal(gaussians) => gaussians.len(),
447            Gaussians::Ply(ply_gaussians) => ply_gaussians.len(),
448            Gaussians::Spz(spz_gaussians) => spz_gaussians.len(),
449        }
450    }
451
452    /// Check if there is no Gaussian.
453    pub fn is_empty(&self) -> bool {
454        self.len() == 0
455    }
456
457    /// Read from a file with the given source.
458    pub fn read_from_file(
459        path: impl AsRef<std::path::Path>,
460        source: GaussiansSource,
461    ) -> std::io::Result<Self> {
462        match source {
463            GaussiansSource::Internal => Err(std::io::Error::new(
464                std::io::ErrorKind::InvalidInput,
465                "cannot read Internal Gaussians from file",
466            )),
467            GaussiansSource::Ply => {
468                let ply_gaussians = PlyGaussians::read_from_file(path)?;
469                Ok(Gaussians::Ply(ply_gaussians))
470            }
471            GaussiansSource::Spz => {
472                let spz_gaussians = SpzGaussians::read_from_file(path)?;
473                Ok(Gaussians::Spz(spz_gaussians))
474            }
475        }
476    }
477
478    /// Read from a buffer with the given source.
479    pub fn read_from(reader: &mut impl BufRead, source: GaussiansSource) -> std::io::Result<Self> {
480        match source {
481            GaussiansSource::Internal => Err(std::io::Error::new(
482                std::io::ErrorKind::InvalidInput,
483                "cannot read Internal Gaussians from buffer",
484            )),
485            GaussiansSource::Ply => {
486                let ply_gaussians = PlyGaussians::read_from(reader)?;
487                Ok(Gaussians::Ply(ply_gaussians))
488            }
489            GaussiansSource::Spz => {
490                let spz_gaussians = SpzGaussians::read_from(reader)?;
491                Ok(Gaussians::Spz(spz_gaussians))
492            }
493        }
494    }
495
496    /// Write to a file with the given source.
497    pub fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
498        match self {
499            Gaussians::Internal(_) => Err(std::io::Error::new(
500                std::io::ErrorKind::InvalidInput,
501                "cannot write Internal Gaussians to file",
502            )),
503            Gaussians::Ply(ply_gaussians) => ply_gaussians.write_to_file(path),
504            Gaussians::Spz(spz_gaussians) => spz_gaussians.write_to_file(path),
505        }
506    }
507
508    /// Write to a buffer with the given source.
509    pub fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
510        match self {
511            Gaussians::Internal(_) => Err(std::io::Error::new(
512                std::io::ErrorKind::InvalidInput,
513                "cannot write Internal Gaussians to buffer",
514            )),
515            Gaussians::Ply(ply_gaussians) => ply_gaussians.write_to(writer),
516            Gaussians::Spz(spz_gaussians) => spz_gaussians.write_to(writer),
517        }
518    }
519}
520
521impl From<Vec<Gaussian>> for Gaussians {
522    fn from(value: Vec<Gaussian>) -> Self {
523        Gaussians::Internal(value)
524    }
525}
526
527impl From<PlyGaussians> for Gaussians {
528    fn from(value: PlyGaussians) -> Self {
529        Gaussians::Ply(value)
530    }
531}
532
533impl From<SpzGaussians> for Gaussians {
534    fn from(value: SpzGaussians) -> Self {
535        Gaussians::Spz(value)
536    }
537}
538
539impl IterGaussian for Gaussians {
540    fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
541        match self {
542            Gaussians::Internal(gaussians) => GaussiansIter::Internal(gaussians.iter_gaussian()),
543            Gaussians::Ply(ply_gaussians) => GaussiansIter::Ply(ply_gaussians.iter_gaussian()),
544            Gaussians::Spz(spz_gaussians) => GaussiansIter::Spz(spz_gaussians.iter_gaussian()),
545        }
546    }
547}
548
549impl FromIterator<Gaussian> for Gaussians {
550    fn from_iter<T: IntoIterator<Item = Gaussian>>(iter: T) -> Self {
551        Gaussians::Internal(iter.into_iter().collect())
552    }
553}
554
/// Trait to extend [`Iterator`] of [`Gaussian`] to collect into [`Gaussians`].
///
/// It is implemented for every iterator yielding [`Gaussian`].
pub trait IteratorGaussianExt: Iterator<Item = Gaussian> + Sized {
    /// Collect the iterator into [`Gaussians`] with the given source.
    ///
    /// This forwards to [`Gaussians::from_gaussians_iter`].
    fn collect_gaussians(self, source: GaussiansSource) -> Gaussians {
        Gaussians::from_gaussians_iter(self, source)
    }
}

// Blanket implementation: any iterator of `Gaussian` gets `collect_gaussians`.
impl<T: Iterator<Item = Gaussian>> IteratorGaussianExt for T {}
564
/// Iterator for [`Gaussians`].
///
/// It has one variant per [`Gaussians`] variant, each wrapping the iterator of that
/// representation, so that iterating any of them yields a single concrete type.
#[derive(Debug, Clone)]
pub enum GaussiansIter<
    InternalIter: Iterator<Item = Gaussian>,
    PlyIter: Iterator<Item = Gaussian>,
    SpzIter: Iterator<Item = Gaussian>,
> {
    /// Iterator over [`Gaussians::Internal`].
    Internal(InternalIter),
    /// Iterator over [`Gaussians::Ply`].
    Ply(PlyIter),
    /// Iterator over [`Gaussians::Spz`].
    Spz(SpzIter),
}
576
577impl<
578    InternalIter: Iterator<Item = Gaussian>,
579    PlyIter: Iterator<Item = Gaussian>,
580    SpzIter: Iterator<Item = Gaussian>,
581> Iterator for GaussiansIter<InternalIter, PlyIter, SpzIter>
582{
583    type Item = Gaussian;
584
585    fn next(&mut self) -> Option<Self::Item> {
586        match self {
587            GaussiansIter::Internal(iter) => iter.next(),
588            GaussiansIter::Ply(iter) => iter.next(),
589            GaussiansIter::Spz(iter) => iter.next(),
590        }
591    }
592
593    fn size_hint(&self) -> (usize, Option<usize>) {
594        match self {
595            GaussiansIter::Internal(iter) => iter.size_hint(),
596            GaussiansIter::Ply(iter) => iter.size_hint(),
597            GaussiansIter::Spz(iter) => iter.size_hint(),
598        }
599    }
600}
601
602impl<
603    InternalIter: ExactSizeIterator<Item = Gaussian>,
604    PlyIter: ExactSizeIterator<Item = Gaussian>,
605    SpzIter: ExactSizeIterator<Item = Gaussian>,
606> ExactSizeIterator for GaussiansIter<InternalIter, PlyIter, SpzIter>
607{
608}