// wgpu_3dgs_core/source_format/spz.rs

1use std::{
2    io::{Read, Write},
3    ops::RangeInclusive,
4};
5
6use flate2::{read::GzDecoder, write::GzEncoder};
7use itertools::Itertools;
8
9use crate::{
10    Gaussian, GaussianToSpzOptions, IterGaussian, ReadIterGaussian, SpzGaussiansFromIterError,
11    WriteIterGaussian,
12};
13
// Generates the per-field SPZ types for a field that has several possible encodings.
//
// For an invocation named `$name` whose variants are `$variant`, each optionally carrying a
// payload type `$ty`, this expands to:
// - `SpzGaussian<name>`: a single value with an owned payload.
// - `SpzGaussian<name>Ref<'a>`: a single value borrowing its payload.
// - `SpzGaussian<name>Iter<'a>`: an iterator of `SpzGaussian<name>Ref`s, backed by a slice
//   iterator for payload variants.
// - `SpzGaussians<name>s`: a collection storing one `Vec<$ty>` for payload variants, with
//   `len`, `is_empty` and `iter`.
// - `FromIterator<SpzGaussian<name>>` for
//   `Result<SpzGaussians<name>s, SpzGaussiansCollectError<SpzGaussian<name>>>`.
//
// `#[docname = "..."]` is spliced into the generated doc strings, and each optional
// `#[doc = "..."]` is copied onto the matching variant of every generated enum.
macro_rules! gaussian_field {
    (
        #[docname = $docname:literal]
        $name:ident {
            $(
                $(#[doc = $doc:literal])?
                $variant:ident $(($ty:ty))?
            ),+ $(,)?
        }
    ) => {
        paste::paste! {
            // Helper used to mention `$ty` inside an optional `$(...)?` repetition, so that the
            // repetition only expands for variants that carry a payload.
            // `noop!($ty)` expands to nothing; `noop!($ty _)` expands to the `_` pattern.
            macro_rules! noop {
                ($tt:tt) => {};
                ($tt:tt _) => {_};
            }

            #[doc = "A single SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " field."]
            #[derive(Debug, Clone, PartialEq)]
            pub enum [< SpzGaussian $name >]  {
                $(
                    $(#[doc = $doc])?
                    $variant $(($ty))?,
                )+
            }

            #[doc = "Reference to SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " field."]
            #[derive(Debug, Clone, Copy, PartialEq)]
            pub enum [< SpzGaussian $name Ref>]<'a> {
                $(
                    $(#[doc = $doc])?
                    $variant $((&'a $ty))?,
                )+
            }

            #[doc = "Iterator over SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " references."]
            pub enum [< SpzGaussian $name Iter >]<'a> {
                $(
                    $(#[doc = $doc])?
                    $variant $((std::slice::Iter<'a, $ty>))?,
                )+
            }

            // NOTE(review): for a variant without a payload (e.g. `Sh`'s `Zero`), `next` always
            // returns `Some(..)`, so the iterator never ends. Meanwhile `size_hint` (and therefore
            // `ExactSizeIterator::len`) reports 0 for it. It is presumably only meant to be zipped
            // with finite iterators; do not `collect`/`count` it on its own. Confirm with callers.
            impl<'a> Iterator for [< SpzGaussian $name Iter >]<'a> {
                type Item = [< SpzGaussian $name Ref >]<'a>;

                fn next(&mut self) -> Option<Self::Item> {
                    // Payload variants advance their slice iterator; unit variants always
                    // yield the matching unit reference variant.
                    macro_rules! body {
                        ($variant_:ident, $ty_:ty, $iter:expr) => {
                            $iter.next().map(|v| [< SpzGaussian $name Ref >]:: $variant_ (v))
                        };
                        ($variant_:ident) => {
                            Some([< SpzGaussian $name Ref >]:: $variant_)
                        };
                    }

                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name Iter >]:: $variant $( (iter @ noop!($ty _)) )? => {
                                body!($variant $(, $ty, iter )?)
                            }
                        )+
                    }
                }

                fn size_hint(&self) -> (usize, Option<usize>) {
                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name Iter >]:: $variant $( (iter @ noop!($ty _)) )? => {
                                // Unit variants keep this 0; payload variants shadow it with
                                // the remaining slice length.
                                #[allow(unused_variables)]
                                let len = 0;
                                $(
                                    noop!($ty);
                                    let len = iter.len();
                                )?
                                (len, Some(len))
                            }
                        )+
                    }
                }
            }

            impl<'a> ExactSizeIterator for [< SpzGaussian $name Iter >]<'a> {}

            #[doc = "Representation of SPZ Gaussians "]
            #[doc = $docname]
            #[doc = "s."]
            #[derive(Debug, Clone, PartialEq)]
            pub enum [< SpzGaussians $name s>] {
                $(
                    $(#[doc = $doc])?
                    $variant $((Vec<$ty>))?,
                )+
            }

            impl [< SpzGaussians $name s>] {
                /// Get the number of elements.
                ///
                /// Unit variants store no elements and always report 0.
                pub fn len(&self) -> usize {
                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussians $name s>]:: $variant $( (vec @ noop!($ty _)) )? => {
                                // Same shadowing trick as in `size_hint`: 0 for unit variants,
                                // the vector length otherwise.
                                #[allow(unused_variables)]
                                let len = 0;
                                $(
                                    noop!($ty);
                                    let len = vec.len();
                                )?
                                len
                            }
                        )+
                    }
                }

                /// Check if empty.
                pub fn is_empty(&self) -> bool {
                    self.len() == 0
                }

                /// Get an iterator over references.
                pub fn iter<'a>(&'a self) -> [< SpzGaussian $name Iter >]<'a> {
                    // Payload variants wrap the vector's slice iterator; unit variants map
                    // to the unit iterator variant.
                    macro_rules! body {
                        ($variant_:ident, $ty_:ty, $vec:expr) => {
                            [< SpzGaussian $name Iter >]:: $variant_ ( $vec.iter() )
                        };
                        ($variant_:ident) => {
                            [< SpzGaussian $name Iter >]:: $variant_
                        };
                    }

                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussians $name s>]:: $variant $( (vec @ noop!($ty _)) )? => {
                                body!($variant $(, $ty, vec )?)
                            }
                        )+
                    }
                }
            }

            // Collects single values into the matching collection variant.
            // - The first element decides the variant.
            // - An empty input yields `EmptyIterator`.
            // - A later element with a different variant yields `InvalidMixedVariant`, carrying
            //   the first value and the offending one.
            //
            // `first_value` is used both in the `once(..)` seed and inside the closure. That
            // relies on the payload types being `Copy` (every payload in this file is an integer
            // array). For unit variants, `first_value`/`value` stay `()` and the resulting
            // `Vec<()>` is discarded.
            impl FromIterator<[< SpzGaussian $name >]> for Result<
                [< SpzGaussians $name s>],
                $crate::error::SpzGaussiansCollectError<[< SpzGaussian $name >]>
            > {
                fn from_iter<I: IntoIterator<Item = [< SpzGaussian $name >]>>(iter: I) -> Self {
                    let mut iter = iter.into_iter();
                    let Some(first) = iter.next() else {
                        return Err($crate::error::SpzGaussiansCollectError::EmptyIterator);
                    };

                    #[allow(unused_variables)]
                    let first_value = ();
                    match first {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name >]:: $variant $( (first_value @ noop!($ty _)) )? => {
                                #[allow(unused_variables)]
                                let value = ();
                                #[allow(unused_variables)]
                                let vec = std::iter::once(Ok(first_value))
                                    .chain(
                                        iter.map(|v| {
                                            match v {
                                                [< SpzGaussian $name >]:: $variant $( (
                                                    value @ noop!($ty _)
                                                ) )? => Ok(value),
                                                other => Err(
                                                    $crate::error::SpzGaussiansCollectError::InvalidMixedVariant {
                                                        first_variant: [< SpzGaussian $name >]:: $variant $( (
                                                            { noop!($ty); first_value }
                                                        ) )?,
                                                        current_variant: other,
                                                    }
                                                ),
                                            }
                                        })
                                    )
                                    .collect::<Result<Vec<_>, _>>()?;
                                Ok([< SpzGaussians $name s>]:: $variant $( ({ noop!($ty); vec }) )?)
                            }
                        )+
                    }
                }
            }
        }
    }
}
209
210gaussian_field! {
211    #[docname = "position"]
212    Position {
213        #[doc = "(x, y, z) each as 16-bit floating point."]
214        Float16([u16; 3]),
215        #[doc = "(x, y, z) each as 24-bit fixed point signed integer."]
216        FixedPoint24([[u8; 3]; 3]),
217    }
218}
219
220gaussian_field! {
221    #[docname = "rotation"]
222    Rotation {
223        #[doc = "(x, y, z) each as 8-bit signed integer."]
224        QuatFirstThree([u8; 3]),
225        #[doc = "Smallest 3 components each as 10-bit signed integer. 2 bits for index of omitted component."]
226        QuatSmallestThree([u8; 4]),
227    }
228}
229
230gaussian_field! {
231    #[docname = "SH coefficients"]
232    Sh {
233        Zero,
234        One([[u8; 3]; 3]),
235        Two([[u8; 3]; 8]),
236        Three([[u8; 3]; 15]),
237    }
238}
239
240impl SpzGaussianSh {
241    /// Get the SH degree.
242    pub fn degree(&self) -> SpzGaussianShDegree {
243        match self {
244            SpzGaussianSh::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
245            SpzGaussianSh::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
246            SpzGaussianSh::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
247            SpzGaussianSh::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
248        }
249    }
250
251    /// Get an iterator over SH coefficients.
252    pub fn iter(&self) -> impl Iterator<Item = &[u8; 3]> {
253        match self {
254            SpzGaussianSh::Zero => [].iter(),
255            SpzGaussianSh::One(sh) => sh.iter(),
256            SpzGaussianSh::Two(sh) => sh.iter(),
257            SpzGaussianSh::Three(sh) => sh.iter(),
258        }
259    }
260
261    /// Get an iterator over mutable SH coefficients.
262    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut [u8; 3]> {
263        match self {
264            SpzGaussianSh::Zero => [].iter_mut(),
265            SpzGaussianSh::One(sh) => sh.iter_mut(),
266            SpzGaussianSh::Two(sh) => sh.iter_mut(),
267            SpzGaussianSh::Three(sh) => sh.iter_mut(),
268        }
269    }
270}
271
272impl SpzGaussianShRef<'_> {
273    /// Get the SH degree.
274    pub fn degree(&self) -> SpzGaussianShDegree {
275        match self {
276            SpzGaussianShRef::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
277            SpzGaussianShRef::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
278            SpzGaussianShRef::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
279            SpzGaussianShRef::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
280        }
281    }
282
283    /// Get an iterator over SH coefficients.
284    pub fn iter(&self) -> impl Iterator<Item = &[u8; 3]> + '_ {
285        match self {
286            SpzGaussianShRef::Zero => [].iter(),
287            SpzGaussianShRef::One(sh) => sh.iter(),
288            SpzGaussianShRef::Two(sh) => sh.iter(),
289            SpzGaussianShRef::Three(sh) => sh.iter(),
290        }
291    }
292}
293
294impl SpzGaussiansShs {
295    /// Get the SH degree.
296    pub fn degree(&self) -> SpzGaussianShDegree {
297        match self {
298            SpzGaussiansShs::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
299            SpzGaussiansShs::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
300            SpzGaussiansShs::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
301            SpzGaussiansShs::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
302        }
303    }
304}
305
/// The SPZ Gaussian spherical harmonics degrees.
///
/// Valid values are `0..=3`; use [`SpzGaussianShDegree::new`] to have the value checked.
///
/// NOTE: this type derives [`bytemuck::Pod`], so it can also be produced from arbitrary
/// bytes. This happens, for example, when a whole [`SpzGaussiansHeaderPod`] is cast from a
/// byte buffer. Such a value is not checked and may lie outside `0..=3`.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpzGaussianShDegree(u8);
310
311impl SpzGaussianShDegree {
312    /// Create a new SPZ Gaussian SH degree.
313    ///
314    /// Returns [`None`] if the degree is not in the range of [`SpzGaussiansHeader::SUPPORTED_SH_DEGREES`].
315    pub const fn new(sh_deg: u8) -> Option<Self> {
316        match sh_deg {
317            0..=3 => Some(Self(sh_deg)),
318            _ => None,
319        }
320    }
321
322    /// Create a new SPZ Gaussian SH degree without checking.
323    ///
324    /// # Safety
325    ///
326    /// The degree must be in the range of [`SpzGaussiansHeader::SUPPORTED_SH_DEGREES`].
327    pub const unsafe fn new_unchecked(sh_deg: u8) -> Self {
328        Self(sh_deg)
329    }
330
331    /// Get the degree.
332    pub const fn get(&self) -> u8 {
333        self.0
334    }
335
336    /// Get the number of SH coefficients.
337    pub const fn num_coefficients(&self) -> usize {
338        match self.0 {
339            0 => 0,
340            1 => 3,
341            2 => 8,
342            3 => 15,
343            _ => unreachable!(),
344        }
345    }
346}
347
348impl Default for SpzGaussianShDegree {
349    fn default() -> Self {
350        // SAFETY: 3 is in the range of [0, 3].
351        unsafe { Self::new_unchecked(3) }
352    }
353}
354
/// A single SPZ Gaussian.
///
/// This is usually only used for [`SpzGaussians::from_iter`]. That function requires every
/// Gaussian to use the same position, rotation and SH variants.
#[derive(Debug, Clone, PartialEq)]
pub struct SpzGaussian {
    /// Position; the variant selects the encoding.
    pub position: SpzGaussianPosition,
    /// `(x, y, z)` each as 8-bit log-encoded integer.
    pub scale: [u8; 3],
    /// Rotation; the variant selects the encoding.
    pub rotation: SpzGaussianRotation,
    /// 8-bit unsigned integer.
    pub alpha: u8,
    /// `(r, g, b)` each as 8-bit unsigned integer.
    pub color: [u8; 3],
    /// SH coefficients; the variant determines the SH degree.
    pub sh: SpzGaussianSh,
}
367
368impl SpzGaussian {
369    /// Get a [`SpzGaussianRef`] reference to this Gaussian.
370    pub fn as_ref(&self) -> SpzGaussianRef<'_> {
371        SpzGaussianRef {
372            position: match &self.position {
373                SpzGaussianPosition::Float16(v) => SpzGaussianPositionRef::Float16(v),
374                SpzGaussianPosition::FixedPoint24(v) => SpzGaussianPositionRef::FixedPoint24(v),
375            },
376            scale: &self.scale,
377            rotation: match &self.rotation {
378                SpzGaussianRotation::QuatFirstThree(v) => SpzGaussianRotationRef::QuatFirstThree(v),
379                SpzGaussianRotation::QuatSmallestThree(v) => {
380                    SpzGaussianRotationRef::QuatSmallestThree(v)
381                }
382            },
383            alpha: &self.alpha,
384            color: &self.color,
385            sh: match &self.sh {
386                SpzGaussianSh::Zero => SpzGaussianShRef::Zero,
387                SpzGaussianSh::One(v) => SpzGaussianShRef::One(v),
388                SpzGaussianSh::Two(v) => SpzGaussianShRef::Two(v),
389                SpzGaussianSh::Three(v) => SpzGaussianShRef::Three(v),
390            },
391        }
392    }
393}
394
/// Reference to a SPZ Gaussian.
///
/// Borrowed counterpart of [`SpzGaussian`]. Obtain one with [`SpzGaussian::as_ref`] and
/// convert it back with [`SpzGaussianRef::to_inner_owned`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpzGaussianRef<'a> {
    /// Position; the variant selects the encoding.
    pub position: SpzGaussianPositionRef<'a>,
    /// `(x, y, z)` each as 8-bit log-encoded integer.
    pub scale: &'a [u8; 3],
    /// Rotation; the variant selects the encoding.
    pub rotation: SpzGaussianRotationRef<'a>,
    /// 8-bit unsigned integer.
    pub alpha: &'a u8,
    /// `(r, g, b)` each as 8-bit unsigned integer.
    pub color: &'a [u8; 3],
    /// SH coefficients; the variant determines the SH degree.
    pub sh: SpzGaussianShRef<'a>,
}
405
406impl SpzGaussianRef<'_> {
407    /// Convert to [`SpzGaussian`].
408    pub fn to_inner_owned(&self) -> SpzGaussian {
409        SpzGaussian {
410            position: match self.position {
411                SpzGaussianPositionRef::Float16(v) => SpzGaussianPosition::Float16(*v),
412                SpzGaussianPositionRef::FixedPoint24(v) => SpzGaussianPosition::FixedPoint24(*v),
413            },
414            scale: *self.scale,
415            rotation: match self.rotation {
416                SpzGaussianRotationRef::QuatFirstThree(v) => {
417                    SpzGaussianRotation::QuatFirstThree(*v)
418                }
419                SpzGaussianRotationRef::QuatSmallestThree(v) => {
420                    SpzGaussianRotation::QuatSmallestThree(*v)
421                }
422            },
423            alpha: *self.alpha,
424            color: *self.color,
425            sh: match self.sh {
426                SpzGaussianShRef::Zero => SpzGaussianSh::Zero,
427                SpzGaussianShRef::One(v) => SpzGaussianSh::One(*v),
428                SpzGaussianShRef::Two(v) => SpzGaussianSh::Two(*v),
429                SpzGaussianShRef::Three(v) => SpzGaussianSh::Three(*v),
430            },
431        }
432    }
433}
434
/// Header of SPZ Gaussians file.
///
/// This is the raw, unvalidated layout. With `#[repr(C)]` it is 16 bytes: three `u32`s
/// followed by four single-byte fields. Validate it with [`SpzGaussiansHeader::try_from_pod`].
///
/// NOTE(review): this struct is converted to and from bytes with `bytemuck`, which
/// reinterprets the `u32` fields in the platform's native byte order. The SPZ format is
/// presumably little-endian, so big-endian targets would misread it; confirm if that matters.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpzGaussiansHeaderPod {
    /// Magic number, expected to be [`SpzGaussiansHeader::MAGIC`].
    pub magic: u32,
    /// SPZ format version.
    pub version: u32,
    /// Number of Gaussians.
    pub num_points: u32,
    /// SH degree (not checked until validated by [`SpzGaussiansHeader::try_from_pod`]).
    pub sh_degree: SpzGaussianShDegree,
    /// Number of fractional bits (see [`SpzGaussiansHeader::fractional_bits`]).
    pub fractional_bits: u8,
    /// Bit flags; bit `0x1` is the antialiased flag.
    pub flags: u8,
    /// Reserved; [`SpzGaussiansHeader::new`] writes `0`.
    pub reserved: u8,
}
447
/// Header of SPZ Gaussians file.
///
/// This is the validated version of [`SpzGaussiansHeaderPod`]. It is simply a wrapper around
/// [`SpzGaussiansHeaderPod`] that ensures the values are valid. We could also implement
/// specialized structs for each field, but that would be overkill for now.
///
/// The wrapped pod is private, so outside this module a header can only be obtained through
/// its constructors, such as [`SpzGaussiansHeader::new`] or
/// [`SpzGaussiansHeader::try_from_pod`]. These validate the pod (e.g. its magic number and
/// version). [`SpzGaussiansHeader::set_num_points`] may change the point count afterwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SpzGaussiansHeader(SpzGaussiansHeaderPod);
455
456impl SpzGaussiansHeader {
457    /// The magic number for SPZ Gaussians files.
458    pub const MAGIC: u32 = 0x5053474e; // "NGSP"
459
460    /// The supported SPZ versions.
461    pub const SUPPORTED_VERSIONS: RangeInclusive<u32> = 1..=3;
462
463    /// The supported SH degrees.
464    pub const SUPPORTED_SH_DEGREES: RangeInclusive<u8> = 0..=3;
465
466    /// Create a [`SpzGaussiansHeader`].
467    ///
468    /// Returns an error if the header is invalid.
469    pub fn new(
470        version: u32,
471        num_points: u32,
472        sh_degree: SpzGaussianShDegree,
473        fractional_bits: u8,
474        antialiased: bool,
475    ) -> Result<Self, std::io::Error> {
476        Self::try_from_pod(SpzGaussiansHeaderPod {
477            magic: Self::MAGIC,
478            version,
479            num_points,
480            sh_degree,
481            fractional_bits,
482            flags: if antialiased { 0x1 } else { 0x0 },
483            reserved: 0,
484        })
485    }
486
487    /// Validate and create a validated SPZ Gaussians header.
488    pub fn try_from_pod(pod: SpzGaussiansHeaderPod) -> Result<Self, std::io::Error> {
489        if pod.magic != Self::MAGIC {
490            return Err(std::io::Error::new(
491                std::io::ErrorKind::InvalidData,
492                format!(
493                    "Invalid SPZ magic number: {:X}, expected {:X}",
494                    pod.magic,
495                    Self::MAGIC
496                ),
497            ));
498        }
499
500        if !Self::SUPPORTED_VERSIONS.contains(&pod.version) {
501            return Err(std::io::Error::new(
502                std::io::ErrorKind::InvalidData,
503                format!(
504                    "Unsupported SPZ version: {}, expected one of {:?}",
505                    pod.version,
506                    Self::SUPPORTED_VERSIONS
507                ),
508            ));
509        }
510
511        Ok(Self(pod))
512    }
513
514    /// Create a default [`SpzGaussiansHeader`] from number of points and SH degree.
515    pub fn default(num_points: u32) -> Result<Self, std::io::Error> {
516        Self::new(
517            Self::SUPPORTED_VERSIONS
518                .last()
519                .expect("at least one supported version"),
520            num_points,
521            SpzGaussianShDegree::default(),
522            12,
523            false,
524        )
525    }
526
527    /// Get the [`SpzGaussiansHeaderPod`].
528    pub fn as_pod(&self) -> &SpzGaussiansHeaderPod {
529        &self.0
530    }
531
532    /// Get the version of the SPZ file.
533    pub fn version(&self) -> u32 {
534        self.0.version
535    }
536
537    /// Set the number of points.
538    ///
539    /// Setting the number of points does not invalidate the header.
540    pub fn set_num_points(&mut self, num_points: u32) {
541        self.0.num_points = num_points;
542    }
543
544    /// Get the number of points in the SPZ file.
545    pub fn num_points(&self) -> usize {
546        self.0.num_points as usize
547    }
548
549    /// Get the SH degree of the SPZ file.
550    pub fn sh_degree(&self) -> SpzGaussianShDegree {
551        self.0.sh_degree
552    }
553
554    /// Get the number of SH coefficients.
555    pub fn sh_num_coefficients(&self) -> usize {
556        self.0.sh_degree.num_coefficients()
557    }
558
559    /// Get the number of fractional bits.
560    pub fn fractional_bits(&self) -> u8 {
561        self.0.fractional_bits
562    }
563
564    /// Check if the antialiased flag is set.
565    pub fn is_antialiased(&self) -> bool {
566        (self.0.flags & 0x1) != 0
567    }
568
569    /// Check if float16 encoding is used.
570    pub fn uses_float16(&self) -> bool {
571        self.version() == 1
572    }
573
574    /// Check if quaternion smallest three encoding is used.
575    pub fn uses_quat_smallest_three(&self) -> bool {
576        self.version() >= 3
577    }
578}
579
580impl SpzGaussiansPositions {
581    /// Read positions from reader.
582    pub fn read_from(
583        reader: &mut impl Read,
584        count: usize,
585        uses_float16: bool,
586    ) -> Result<Self, std::io::Error> {
587        if uses_float16 {
588            let mut positions = vec![[0u16; 3]; count];
589            reader.read_exact(bytemuck::cast_slice_mut(&mut positions))?;
590            Ok(SpzGaussiansPositions::Float16(positions))
591        } else {
592            let mut positions = vec![[[0u8; 3]; 3]; count];
593            reader.read_exact(bytemuck::cast_slice_mut(&mut positions))?;
594            Ok(SpzGaussiansPositions::FixedPoint24(positions))
595        }
596    }
597
598    /// Write positions to writer.
599    pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
600        match self {
601            SpzGaussiansPositions::Float16(positions) => {
602                writer.write_all(bytemuck::cast_slice(positions))
603            }
604            SpzGaussiansPositions::FixedPoint24(positions) => {
605                writer.write_all(bytemuck::cast_slice(positions))
606            }
607        }
608    }
609}
610
611impl SpzGaussiansRotations {
612    /// Read rotations from reader.
613    pub fn read_from(
614        reader: &mut impl Read,
615        count: usize,
616        uses_quat_smallest_three: bool,
617    ) -> Result<Self, std::io::Error> {
618        if !uses_quat_smallest_three {
619            let mut rots = vec![[0u8; 3]; count];
620            reader.read_exact(bytemuck::cast_slice_mut(&mut rots))?;
621            Ok(SpzGaussiansRotations::QuatFirstThree(rots))
622        } else {
623            let mut rots = vec![[0u8; 4]; count];
624            reader.read_exact(bytemuck::cast_slice_mut(&mut rots))?;
625            Ok(SpzGaussiansRotations::QuatSmallestThree(rots))
626        }
627    }
628
629    /// Write rotations to writer.
630    pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
631        match self {
632            SpzGaussiansRotations::QuatFirstThree(rots) => {
633                writer.write_all(bytemuck::cast_slice(rots))
634            }
635            SpzGaussiansRotations::QuatSmallestThree(rots) => {
636                writer.write_all(bytemuck::cast_slice(rots))
637            }
638        }
639    }
640}
641
642impl SpzGaussiansShs {
643    /// Read SH coefficients from reader.
644    pub fn read_from(
645        reader: &mut impl Read,
646        count: usize,
647        sh_degree: SpzGaussianShDegree,
648    ) -> Result<Self, std::io::Error> {
649        match sh_degree.get() {
650            0 => Ok(SpzGaussiansShs::Zero),
651            1 => {
652                let mut sh_coeffs = vec![[[0u8; 3]; 3]; count];
653                reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
654                Ok(SpzGaussiansShs::One(sh_coeffs))
655            }
656            2 => {
657                let mut sh_coeffs = vec![[[0u8; 3]; 8]; count];
658                reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
659                Ok(SpzGaussiansShs::Two(sh_coeffs))
660            }
661            3 => {
662                let mut sh_coeffs = vec![[[0u8; 3]; 15]; count];
663                reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
664                Ok(SpzGaussiansShs::Three(sh_coeffs))
665            }
666            _ => {
667                // SAFETY: SpzGaussianShDegree guarantees the degree is in [0, 3].
668                unreachable!()
669            }
670        }
671    }
672
673    /// Write SH coefficients to writer.
674    pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
675        match self {
676            SpzGaussiansShs::Zero => Ok(()),
677            SpzGaussiansShs::One(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
678            SpzGaussiansShs::Two(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
679            SpzGaussiansShs::Three(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
680        }
681    }
682}
683
/// A collection of Gaussians in SPZ format.
///
/// Each field stores one attribute for all Gaussians. In a decompressed SPZ buffer the
/// sections appear in the order positions, alphas, colors, scales, rotations, SH coefficients
/// (see [`SpzGaussians::read_gaussians`]), which differs from the field order here.
///
/// The fields are public, so nothing keeps their lengths in sync with
/// [`SpzGaussiansHeader::num_points`] after construction. [`SpzGaussians::from_iter`] checks
/// the count when building a value.
#[derive(Debug, Clone, PartialEq)]
pub struct SpzGaussians {
    /// Validated header; [`SpzGaussians::len`] reports its point count.
    pub header: SpzGaussiansHeader,

    /// One position per Gaussian; the variant follows [`SpzGaussiansHeader::uses_float16`].
    pub positions: SpzGaussiansPositions,

    /// `(x, y, z)` each as 8-bit log-encoded integer.
    pub scales: Vec<[u8; 3]>,

    /// One rotation per Gaussian; the variant follows
    /// [`SpzGaussiansHeader::uses_quat_smallest_three`].
    pub rotations: SpzGaussiansRotations,

    /// 8-bit unsigned integer.
    pub alphas: Vec<u8>,

    /// `(r, g, b)` each as 8-bit unsigned integer.
    pub colors: Vec<[u8; 3]>,

    /// SH coefficients; the variant follows [`SpzGaussiansHeader::sh_degree`]
    /// ([`SpzGaussiansShs::Zero`] stores nothing).
    pub shs: SpzGaussiansShs,
}
704
705impl SpzGaussians {
706    /// Get the number of Gaussians.
707    pub fn len(&self) -> usize {
708        self.header.num_points()
709    }
710
711    /// Check if there are no Gaussians.
712    pub fn is_empty(&self) -> bool {
713        self.len() == 0
714    }
715
716    /// Read a SPZ from a decompressed buffer.
717    ///
718    /// `reader` should be decompressed SPZ buffer.
719    pub fn read_decompressed(reader: &mut impl Read) -> Result<Self, std::io::Error> {
720        let header = Self::read_header(reader)?;
721        Self::read_gaussians(reader, header)
722    }
723
724    /// Read a SPZ header.
725    ///
726    /// `reader` should be decompressed SPZ buffer.
727    pub fn read_header(reader: &mut impl Read) -> Result<SpzGaussiansHeader, std::io::Error> {
728        let mut header_bytes = [0u8; std::mem::size_of::<SpzGaussiansHeaderPod>()];
729        reader.read_exact(&mut header_bytes)?;
730        let header: SpzGaussiansHeaderPod = bytemuck::cast(header_bytes);
731        SpzGaussiansHeader::try_from_pod(header)
732    }
733
734    /// Read the SPZ Gaussians.
735    ///
736    /// `reader` should be decompressed SPZ buffer positioned after the header.
737    ///
738    /// `header` may be parsed by calling [`SpzGaussians::read_header`].
739    pub fn read_gaussians(
740        reader: &mut impl Read,
741        header: SpzGaussiansHeader,
742    ) -> Result<Self, std::io::Error> {
743        let count = header.num_points();
744        let uses_float16 = header.uses_float16();
745        let uses_quat_smallest_three = header.uses_quat_smallest_three();
746
747        let positions = SpzGaussiansPositions::read_from(reader, count, uses_float16)?;
748
749        let mut alphas = vec![0u8; count];
750        reader.read_exact(bytemuck::cast_slice_mut(&mut alphas))?;
751
752        let mut colors = vec![[0u8; 3]; count];
753        reader.read_exact(bytemuck::cast_slice_mut(&mut colors))?;
754
755        let mut scales = vec![[0u8; 3]; count];
756        reader.read_exact(bytemuck::cast_slice_mut(&mut scales))?;
757
758        let rotations = SpzGaussiansRotations::read_from(reader, count, uses_quat_smallest_three)?;
759
760        let shs = SpzGaussiansShs::read_from(reader, count, header.sh_degree())?;
761
762        Ok(SpzGaussians {
763            header,
764            positions,
765            scales,
766            rotations,
767            alphas,
768            colors,
769            shs,
770        })
771    }
772
773    /// Write the Gaussians to a SPZ buffer.
774    ///
775    /// `writer` will receive the decompressed SPZ buffer.
776    pub fn write_decompressed(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
777        writer.write_all(bytemuck::cast_slice(std::slice::from_ref(
778            self.header.as_pod(),
779        )))?;
780
781        self.positions.write_to(writer)?;
782
783        writer.write_all(bytemuck::cast_slice(&self.alphas))?;
784
785        writer.write_all(bytemuck::cast_slice(&self.colors))?;
786
787        writer.write_all(bytemuck::cast_slice(&self.scales))?;
788
789        self.rotations.write_to(writer)?;
790
791        self.shs.write_to(writer)?;
792
793        Ok(())
794    }
795
796    /// Convert from a slice of [`Gaussian`]s.
797    pub fn from_gaussians(gaussians: impl IntoIterator<Item = impl AsRef<Gaussian>>) -> Self {
798        Self::from_gaussians_with_options(
799            gaussians,
800            &SpzGaussiansFromGaussianSliceOptions::default(),
801        )
802        .expect("valid default options")
803    }
804
805    /// Convert from a slice of [`Gaussian`]s with options.
806    pub fn from_gaussians_with_options(
807        gaussians: impl IntoIterator<Item = impl AsRef<Gaussian>>,
808        options: &SpzGaussiansFromGaussianSliceOptions,
809    ) -> Result<Self, std::io::Error> {
810        let mut header = SpzGaussiansHeader::new(
811            options.version,
812            0,
813            options.sh_degree,
814            options.fractional_bits,
815            options.antialiased,
816        )?;
817
818        let gaussians = gaussians
819            .into_iter()
820            .map(|g| {
821                g.as_ref().to_spz(
822                    &header,
823                    &GaussianToSpzOptions {
824                        sh_quantize_bits: options.sh_quantize_bits,
825                    },
826                )
827            })
828            .collect::<Vec<_>>();
829
830        header.set_num_points(gaussians.len() as u32);
831
832        Ok(Self::from_iter(header, gaussians)
833            .expect("gaussians from valid Gaussians with valid header are valid"))
834    }
835
836    /// Convert from an [`IntoIterator`] of [`SpzGaussian`]s.
837    pub fn from_iter(
838        header: SpzGaussiansHeader,
839        iter: impl IntoIterator<Item = SpzGaussian>,
840    ) -> Result<Self, SpzGaussiansFromIterError> {
841        let (positions, scales, rotations, alphas, colors, shs) = iter
842            .into_iter()
843            .map(|spz| {
844                (
845                    spz.position,
846                    spz.scale,
847                    spz.rotation,
848                    spz.alpha,
849                    spz.color,
850                    spz.sh,
851                )
852            })
853            .multiunzip::<(Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>)>();
854
855        let positions = positions
856            .into_iter()
857            .collect::<Result<_, _>>()
858            .map_err(SpzGaussiansFromIterError::InvalidMixedPositionVariant)?;
859
860        let rotations = rotations
861            .into_iter()
862            .collect::<Result<_, _>>()
863            .map_err(SpzGaussiansFromIterError::InvalidMixedRotationVariant)?;
864
865        let shs = shs
866            .into_iter()
867            .collect::<Result<_, _>>()
868            .map_err(SpzGaussiansFromIterError::InvalidMixedShVariant)?;
869
870        if positions.len() != header.num_points() {
871            return Err(SpzGaussiansFromIterError::CountMismatch {
872                actual_count: positions.len(),
873                header_count: header.num_points(),
874            });
875        }
876
877        if matches!(positions, SpzGaussiansPositions::Float16(_)) != header.uses_float16() {
878            return Err(SpzGaussiansFromIterError::PositionFloat16Mismatch {
879                is_float16: matches!(positions, SpzGaussiansPositions::Float16(_)),
880                header_uses_float16: header.uses_float16(),
881            });
882        }
883
884        if matches!(rotations, SpzGaussiansRotations::QuatSmallestThree(_))
885            != header.uses_quat_smallest_three()
886        {
887            return Err(
888                SpzGaussiansFromIterError::RotationQuatSmallestThreeMismatch {
889                    is_quat_smallest_three: matches!(
890                        rotations,
891                        SpzGaussiansRotations::QuatSmallestThree(_)
892                    ),
893                    header_uses_quat_smallest_three: header.uses_quat_smallest_three(),
894                },
895            );
896        }
897
898        if shs.degree() != header.sh_degree() {
899            return Err(SpzGaussiansFromIterError::ShDegreeMismatch {
900                sh_degree: shs.degree(),
901                header_sh_degree: header.sh_degree(),
902            });
903        }
904
905        Ok(SpzGaussians {
906            header,
907            positions,
908            scales,
909            rotations,
910            alphas,
911            colors,
912            shs,
913        })
914    }
915
916    /// Get an iterator over Gaussian references.
917    pub fn iter<'a>(&'a self) -> impl ExactSizeIterator<Item = SpzGaussianRef<'a>> + 'a {
918        itertools::izip!(
919            self.positions.iter(),
920            self.scales.iter(),
921            self.rotations.iter(),
922            self.alphas.iter(),
923            self.colors.iter(),
924            self.shs.iter()
925        )
926        .map(
927            |(position, scale, rotation, alpha, color, sh)| SpzGaussianRef {
928                position,
929                scale,
930                rotation,
931                alpha,
932                color,
933                sh,
934            },
935        )
936    }
937}
938
939impl IterGaussian for SpzGaussians {
940    fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
941        self.iter().map(|spz| Gaussian::from_spz(spz, &self.header))
942    }
943}
944
945impl ReadIterGaussian for SpzGaussians {
946    fn read_from(reader: &mut impl std::io::BufRead) -> std::io::Result<Self> {
947        let mut decoder = GzDecoder::new(reader);
948        Self::read_decompressed(&mut decoder)
949    }
950}
951
952impl WriteIterGaussian for SpzGaussians {
953    fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
954        let mut encoder = GzEncoder::new(writer, flate2::Compression::default());
955        self.write_decompressed(&mut encoder)?;
956        encoder.finish()?;
957        Ok(())
958    }
959}
960
961impl<G: AsRef<Gaussian>> FromIterator<G> for SpzGaussians {
962    fn from_iter<T: IntoIterator<Item = G>>(iter: T) -> Self {
963        Self::from_gaussians(iter)
964    }
965}
966
/// Options for [`SpzGaussians::from_gaussians_with_options`].
///
/// The fields are not validated when this struct is built.
///
/// `version`, `sh_degree`, `fractional_bits` and `antialiased` are passed to
/// [`SpzGaussiansHeader::new`]; any error it returns is propagated by
/// [`SpzGaussians::from_gaussians_with_options`]. `sh_quantize_bits` is forwarded unchanged
/// through [`GaussianToSpzOptions`].
///
/// The default values are taken from the default SPZ header and the default
/// [`GaussianToSpzOptions`].
#[derive(Debug, Clone)]
pub struct SpzGaussiansFromGaussianSliceOptions {
    /// Version to use, written into the SPZ header.
    pub version: u32,

    /// SH degree to use, written into the SPZ header.
    pub sh_degree: SpzGaussianShDegree,

    /// Number of fractional bits to use for position fixed point encoding.
    pub fractional_bits: u8,

    /// Whether to use antialiased encoding (passed to the header as its antialiased flag).
    pub antialiased: bool,

    /// The quantization bits for each SH degree, forwarded as
    /// `GaussianToSpzOptions::sh_quantize_bits`.
    ///
    /// NOTE(review): presumably one entry per SH degree 1..=3; confirm the indexing against
    /// `Gaussian::to_spz`.
    pub sh_quantize_bits: [u32; 3],
}
987
988impl Default for SpzGaussiansFromGaussianSliceOptions {
989    fn default() -> Self {
990        let default_header = SpzGaussiansHeader::default(0).expect("default header");
991        let default_gaussian_to_spz_options = GaussianToSpzOptions::default();
992        Self {
993            version: default_header.version(),
994            sh_degree: default_header.sh_degree(),
995            fractional_bits: default_header.fractional_bits(),
996            antialiased: default_header.is_antialiased(),
997            sh_quantize_bits: default_gaussian_to_spz_options.sh_quantize_bits,
998        }
999    }
1000}