1use std::{
2 io::{Read, Write},
3 ops::RangeInclusive,
4};
5
6use flate2::{read::GzDecoder, write::GzEncoder};
7use itertools::Itertools;
8
9use crate::{
10 Gaussian, GaussianToSpzOptions, IterGaussian, ReadIterGaussian, SpzGaussiansFromIterError,
11 WriteIterGaussian,
12};
13
// Generates the family of types modelling one SPZ Gaussian field that may be stored in
// several encodings. For a field `Name` with the given variants this expands to:
//
// - `SpzGaussianName`: one owned value,
// - `SpzGaussianNameRef<'a>`: one borrowed value,
// - `SpzGaussianNameIter<'a>`: iterator yielding `SpzGaussianNameRef`,
// - `SpzGaussiansNames`: all values of the field, one `Vec` per encoding,
// - `FromIterator<SpzGaussianName>` for `Result<SpzGaussiansNames, SpzGaussiansCollectError<_>>`.
//
// A variant may carry a payload type (`Variant(Ty)`) or be a unit variant (`Variant`).
macro_rules! gaussian_field {
    (
        #[docname = $docname:literal]
        $name:ident {
            $(
                $(#[doc = $doc:literal])?
                $variant:ident $(($ty:ty))?
            ),+ $(,)?
        }
    ) => {
        paste::paste! {
            // Lets the expansion mention `$ty` inside an optional repetition without emitting
            // it: `noop!($ty)` expands to nothing and `noop!($ty _)` expands to `_`.
            macro_rules! noop {
                ($tt:tt) => {};
                ($tt:tt _) => {_};
            }

            #[doc = "A single SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " field."]
            #[derive(Debug, Clone, PartialEq)]
            pub enum [< SpzGaussian $name >] {
                $(
                    $(#[doc = $doc])?
                    $variant $(($ty))?,
                )+
            }

            #[doc = "Reference to SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " field."]
            #[derive(Debug, Clone, Copy, PartialEq)]
            pub enum [< SpzGaussian $name Ref>]<'a> {
                $(
                    $(#[doc = $doc])?
                    $variant $((&'a $ty))?,
                )+
            }

            #[doc = "Iterator over SPZ Gaussian "]
            #[doc = $docname]
            #[doc = " references."]
            pub enum [< SpzGaussian $name Iter >]<'a> {
                $(
                    $(#[doc = $doc])?
                    $variant $((std::slice::Iter<'a, $ty>))?,
                )+
            }

            impl<'a> Iterator for [< SpzGaussian $name Iter >]<'a> {
                type Item = [< SpzGaussian $name Ref >]<'a>;

                fn next(&mut self) -> Option<Self::Item> {
                    // Payload variants forward to the wrapped slice iterator; unit variants
                    // have nothing to walk and therefore yield the unit reference forever.
                    macro_rules! body {
                        ($variant_:ident, $ty_:ty, $iter:expr) => {
                            $iter.next().map(|v| [< SpzGaussian $name Ref >]:: $variant_ (v))
                        };
                        ($variant_:ident) => {
                            Some([< SpzGaussian $name Ref >]:: $variant_)
                        };
                    }

                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name Iter >]:: $variant $( (iter @ noop!($ty _)) )? => {
                                body!($variant $(, $ty, iter )?)
                            }
                        )+
                    }
                }

                fn size_hint(&self) -> (usize, Option<usize>) {
                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name Iter >]:: $variant $( (iter @ noop!($ty _)) )? => {
                                // Unit variants never run out of items (see `next`), so report
                                // the largest representable length. Reporting 0 here made any
                                // zip with this iterator claim a length of 0 even though it
                                // yields one item per element of the other iterators.
                                #[allow(unused_variables)]
                                let len = usize::MAX;
                                $(
                                    noop!($ty);
                                    let len = iter.len();
                                )?
                                (len, Some(len))
                            }
                        )+
                    }
                }
            }

            impl<'a> ExactSizeIterator for [< SpzGaussian $name Iter >]<'a> {}

            #[doc = "Representation of SPZ Gaussians "]
            #[doc = $docname]
            #[doc = "s."]
            #[derive(Debug, Clone, PartialEq)]
            pub enum [< SpzGaussians $name s>] {
                $(
                    $(#[doc = $doc])?
                    $variant $((Vec<$ty>))?,
                )+
            }

            impl [< SpzGaussians $name s>] {
                /// Returns the number of stored values; unit variants store none and report 0.
                pub fn len(&self) -> usize {
                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussians $name s>]:: $variant $( (vec @ noop!($ty _)) )? => {
                                #[allow(unused_variables)]
                                let len = 0;
                                $(
                                    noop!($ty);
                                    let len = vec.len();
                                )?
                                len
                            }
                        )+
                    }
                }

                /// Returns `true` when [`Self::len`] is 0.
                pub fn is_empty(&self) -> bool {
                    self.len() == 0
                }

                /// Returns an iterator over references to the stored values.
                ///
                /// For unit variants the iterator yields the unit reference without end.
                pub fn iter<'a>(&'a self) -> [< SpzGaussian $name Iter >]<'a> {
                    macro_rules! body {
                        ($variant_:ident, $ty_:ty, $vec:expr) => {
                            [< SpzGaussian $name Iter >]:: $variant_ ( $vec.iter() )
                        };
                        ($variant_:ident) => {
                            [< SpzGaussian $name Iter >]:: $variant_
                        };
                    }

                    match self {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussians $name s>]:: $variant $( (vec @ noop!($ty _)) )? => {
                                body!($variant $(, $ty, vec )?)
                            }
                        )+
                    }
                }
            }

            // Collects single fields into the per-encoding container. The first item decides
            // the variant; an empty iterator or a later item of a different variant is an error.
            impl FromIterator<[< SpzGaussian $name >]> for Result<
                [< SpzGaussians $name s>],
                $crate::error::SpzGaussiansCollectError<[< SpzGaussian $name >]>
            > {
                fn from_iter<I: IntoIterator<Item = [< SpzGaussian $name >]>>(iter: I) -> Self {
                    let mut iter = iter.into_iter();
                    let Some(first) = iter.next() else {
                        return Err($crate::error::SpzGaussiansCollectError::EmptyIterator);
                    };

                    // Placeholder bindings so that unit variants, which bind no payload in the
                    // patterns below, still have `first_value` / `value` in scope.
                    #[allow(unused_variables)]
                    let first_value = ();
                    match first {
                        $(
                            #[allow(clippy::redundant_pattern)]
                            [< SpzGaussian $name >]:: $variant $( (first_value @ noop!($ty _)) )? => {
                                #[allow(unused_variables)]
                                let value = ();
                                #[allow(unused_variables)]
                                let vec = std::iter::once(Ok(first_value))
                                    .chain(
                                        iter.map(|v| {
                                            match v {
                                                [< SpzGaussian $name >]:: $variant $( (
                                                    value @ noop!($ty _)
                                                ) )? => Ok(value),
                                                other => Err(
                                                    $crate::error::SpzGaussiansCollectError::InvalidMixedVariant {
                                                        first_variant: [< SpzGaussian $name >]:: $variant $( (
                                                            { noop!($ty); first_value }
                                                        ) )?,
                                                        current_variant: other,
                                                    }
                                                ),
                                            }
                                        })
                                    )
                                    .collect::<Result<Vec<_>, _>>()?;
                                Ok([< SpzGaussians $name s>]:: $variant $( ({ noop!($ty); vec }) )?)
                            }
                        )+
                    }
                }
            }
        }
    }
}
209
// Position field: expands to `SpzGaussianPosition`, `SpzGaussianPositionRef`,
// `SpzGaussianPositionIter` and `SpzGaussiansPositions`.
gaussian_field! {
    #[docname = "position"]
    Position {
        #[doc = "(x, y, z) each as 16-bit floating point."]
        Float16([u16; 3]),
        #[doc = "(x, y, z) each as 24-bit fixed point signed integer."]
        FixedPoint24([[u8; 3]; 3]),
    }
}
219
// Rotation field: expands to `SpzGaussianRotation`, `SpzGaussianRotationRef`,
// `SpzGaussianRotationIter` and `SpzGaussiansRotations`.
gaussian_field! {
    #[docname = "rotation"]
    Rotation {
        #[doc = "(x, y, z) each as 8-bit signed integer."]
        QuatFirstThree([u8; 3]),
        #[doc = "Smallest 3 components each as 10-bit signed integer. 2 bits for index of omitted component."]
        QuatSmallestThree([u8; 4]),
    }
}
229
// Spherical harmonics field: expands to `SpzGaussianSh`, `SpzGaussianShRef`,
// `SpzGaussianShIter` and `SpzGaussiansShs`.
//
// The variant name is the SH degree; the array lengths (3, 8, 15) are the per-degree
// coefficient counts, each coefficient being a `[u8; 3]` triplet. `Zero` carries no data.
gaussian_field! {
    #[docname = "SH coefficients"]
    Sh {
        Zero,
        One([[u8; 3]; 3]),
        Two([[u8; 3]; 8]),
        Three([[u8; 3]; 15]),
    }
}
239
240impl SpzGaussianSh {
241 pub fn degree(&self) -> SpzGaussianShDegree {
243 match self {
244 SpzGaussianSh::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
245 SpzGaussianSh::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
246 SpzGaussianSh::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
247 SpzGaussianSh::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
248 }
249 }
250
251 pub fn iter(&self) -> impl Iterator<Item = &[u8; 3]> {
253 match self {
254 SpzGaussianSh::Zero => [].iter(),
255 SpzGaussianSh::One(sh) => sh.iter(),
256 SpzGaussianSh::Two(sh) => sh.iter(),
257 SpzGaussianSh::Three(sh) => sh.iter(),
258 }
259 }
260
261 pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut [u8; 3]> {
263 match self {
264 SpzGaussianSh::Zero => [].iter_mut(),
265 SpzGaussianSh::One(sh) => sh.iter_mut(),
266 SpzGaussianSh::Two(sh) => sh.iter_mut(),
267 SpzGaussianSh::Three(sh) => sh.iter_mut(),
268 }
269 }
270}
271
272impl SpzGaussianShRef<'_> {
273 pub fn degree(&self) -> SpzGaussianShDegree {
275 match self {
276 SpzGaussianShRef::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
277 SpzGaussianShRef::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
278 SpzGaussianShRef::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
279 SpzGaussianShRef::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
280 }
281 }
282
283 pub fn iter(&self) -> impl Iterator<Item = &[u8; 3]> + '_ {
285 match self {
286 SpzGaussianShRef::Zero => [].iter(),
287 SpzGaussianShRef::One(sh) => sh.iter(),
288 SpzGaussianShRef::Two(sh) => sh.iter(),
289 SpzGaussianShRef::Three(sh) => sh.iter(),
290 }
291 }
292}
293
294impl SpzGaussiansShs {
295 pub fn degree(&self) -> SpzGaussianShDegree {
297 match self {
298 SpzGaussiansShs::Zero => unsafe { SpzGaussianShDegree::new_unchecked(0) },
299 SpzGaussiansShs::One(_) => unsafe { SpzGaussianShDegree::new_unchecked(1) },
300 SpzGaussiansShs::Two(_) => unsafe { SpzGaussianShDegree::new_unchecked(2) },
301 SpzGaussiansShs::Three(_) => unsafe { SpzGaussianShDegree::new_unchecked(3) },
302 }
303 }
304}
305
/// Spherical harmonics degree of SPZ Gaussians, stored as a single byte.
///
/// [`SpzGaussianShDegree::new`] only accepts values in `0..=3`.
// NOTE(review): the `bytemuck::Pod` derive presumably allows safe byte casts to produce a
// value outside `0..=3` without going through `new` — confirm that such values are validated
// before `num_coefficients` is called.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpzGaussianShDegree(u8);
310
311impl SpzGaussianShDegree {
312 pub const fn new(sh_deg: u8) -> Option<Self> {
316 match sh_deg {
317 0..=3 => Some(Self(sh_deg)),
318 _ => None,
319 }
320 }
321
322 pub const unsafe fn new_unchecked(sh_deg: u8) -> Self {
328 Self(sh_deg)
329 }
330
331 pub const fn get(&self) -> u8 {
333 self.0
334 }
335
336 pub const fn num_coefficients(&self) -> usize {
338 match self.0 {
339 0 => 0,
340 1 => 3,
341 2 => 8,
342 3 => 15,
343 _ => unreachable!(),
344 }
345 }
346}
347
348impl Default for SpzGaussianShDegree {
349 fn default() -> Self {
350 unsafe { Self::new_unchecked(3) }
352 }
353}
354
/// A single Gaussian in its SPZ (quantized) representation, owning all of its fields.
///
/// Use [`SpzGaussian::as_ref`] to obtain the borrowed counterpart, [`SpzGaussianRef`].
#[derive(Debug, Clone, PartialEq)]
pub struct SpzGaussian {
    /// Position, in one of the encodings of [`SpzGaussianPosition`].
    pub position: SpzGaussianPosition,
    /// Scale, one byte per component.
    pub scale: [u8; 3],
    /// Rotation, in one of the encodings of [`SpzGaussianRotation`].
    pub rotation: SpzGaussianRotation,
    /// Alpha as a single byte.
    pub alpha: u8,
    /// Color, one byte per component.
    pub color: [u8; 3],
    /// Spherical harmonics coefficients; the variant determines the degree.
    pub sh: SpzGaussianSh,
}
367
368impl SpzGaussian {
369 pub fn as_ref(&self) -> SpzGaussianRef<'_> {
371 SpzGaussianRef {
372 position: match &self.position {
373 SpzGaussianPosition::Float16(v) => SpzGaussianPositionRef::Float16(v),
374 SpzGaussianPosition::FixedPoint24(v) => SpzGaussianPositionRef::FixedPoint24(v),
375 },
376 scale: &self.scale,
377 rotation: match &self.rotation {
378 SpzGaussianRotation::QuatFirstThree(v) => SpzGaussianRotationRef::QuatFirstThree(v),
379 SpzGaussianRotation::QuatSmallestThree(v) => {
380 SpzGaussianRotationRef::QuatSmallestThree(v)
381 }
382 },
383 alpha: &self.alpha,
384 color: &self.color,
385 sh: match &self.sh {
386 SpzGaussianSh::Zero => SpzGaussianShRef::Zero,
387 SpzGaussianSh::One(v) => SpzGaussianShRef::One(v),
388 SpzGaussianSh::Two(v) => SpzGaussianShRef::Two(v),
389 SpzGaussianSh::Three(v) => SpzGaussianShRef::Three(v),
390 },
391 }
392 }
393}
394
/// Borrowed view of a single SPZ Gaussian; every field is a reference.
///
/// Use [`SpzGaussianRef::to_inner_owned`] to copy the referenced data into an owned
/// [`SpzGaussian`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpzGaussianRef<'a> {
    /// Position, in one of the encodings of [`SpzGaussianPositionRef`].
    pub position: SpzGaussianPositionRef<'a>,
    /// Scale, one byte per component.
    pub scale: &'a [u8; 3],
    /// Rotation, in one of the encodings of [`SpzGaussianRotationRef`].
    pub rotation: SpzGaussianRotationRef<'a>,
    /// Alpha as a single byte.
    pub alpha: &'a u8,
    /// Color, one byte per component.
    pub color: &'a [u8; 3],
    /// Spherical harmonics coefficients; the variant determines the degree.
    pub sh: SpzGaussianShRef<'a>,
}
405
406impl SpzGaussianRef<'_> {
407 pub fn to_inner_owned(&self) -> SpzGaussian {
409 SpzGaussian {
410 position: match self.position {
411 SpzGaussianPositionRef::Float16(v) => SpzGaussianPosition::Float16(*v),
412 SpzGaussianPositionRef::FixedPoint24(v) => SpzGaussianPosition::FixedPoint24(*v),
413 },
414 scale: *self.scale,
415 rotation: match self.rotation {
416 SpzGaussianRotationRef::QuatFirstThree(v) => {
417 SpzGaussianRotation::QuatFirstThree(*v)
418 }
419 SpzGaussianRotationRef::QuatSmallestThree(v) => {
420 SpzGaussianRotation::QuatSmallestThree(*v)
421 }
422 },
423 alpha: *self.alpha,
424 color: *self.color,
425 sh: match self.sh {
426 SpzGaussianShRef::Zero => SpzGaussianSh::Zero,
427 SpzGaussianShRef::One(v) => SpzGaussianSh::One(*v),
428 SpzGaussianShRef::Two(v) => SpzGaussianSh::Two(*v),
429 SpzGaussianShRef::Three(v) => SpzGaussianSh::Three(*v),
430 },
431 }
432 }
433}
434
/// Plain-old-data layout of the SPZ header: three `u32` fields followed by four bytes
/// (16 bytes in total with `repr(C)`).
///
/// This type performs no validation; [`SpzGaussiansHeader::try_from_pod`] wraps it after
/// checking its contents.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, bytemuck::Pod, bytemuck::Zeroable)]
pub struct SpzGaussiansHeaderPod {
    /// Magic number; expected to equal [`SpzGaussiansHeader::MAGIC`].
    pub magic: u32,
    /// Format version; expected to be in [`SpzGaussiansHeader::SUPPORTED_VERSIONS`].
    pub version: u32,
    /// Number of Gaussians.
    pub num_points: u32,
    /// Spherical harmonics degree.
    pub sh_degree: SpzGaussianShDegree,
    /// Number of fractional bits.
    // NOTE(review): presumably applies to the fixed-point position encoding — confirm.
    pub fractional_bits: u8,
    /// Bit flags; bit `0x1` marks the Gaussians as antialiased.
    pub flags: u8,
    /// Reserved byte; written as 0 by [`SpzGaussiansHeader::new`].
    pub reserved: u8,
}
447
/// Validated SPZ header wrapping a [`SpzGaussiansHeaderPod`].
///
/// Constructed through [`SpzGaussiansHeader::new`], [`SpzGaussiansHeader::default`] or
/// [`SpzGaussiansHeader::try_from_pod`], all of which check the header contents.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SpzGaussiansHeader(SpzGaussiansHeaderPod);
455
456impl SpzGaussiansHeader {
457 pub const MAGIC: u32 = 0x5053474e; pub const SUPPORTED_VERSIONS: RangeInclusive<u32> = 1..=3;
462
463 pub const SUPPORTED_SH_DEGREES: RangeInclusive<u8> = 0..=3;
465
466 pub fn new(
470 version: u32,
471 num_points: u32,
472 sh_degree: SpzGaussianShDegree,
473 fractional_bits: u8,
474 antialiased: bool,
475 ) -> Result<Self, std::io::Error> {
476 Self::try_from_pod(SpzGaussiansHeaderPod {
477 magic: Self::MAGIC,
478 version,
479 num_points,
480 sh_degree,
481 fractional_bits,
482 flags: if antialiased { 0x1 } else { 0x0 },
483 reserved: 0,
484 })
485 }
486
487 pub fn try_from_pod(pod: SpzGaussiansHeaderPod) -> Result<Self, std::io::Error> {
489 if pod.magic != Self::MAGIC {
490 return Err(std::io::Error::new(
491 std::io::ErrorKind::InvalidData,
492 format!(
493 "Invalid SPZ magic number: {:X}, expected {:X}",
494 pod.magic,
495 Self::MAGIC
496 ),
497 ));
498 }
499
500 if !Self::SUPPORTED_VERSIONS.contains(&pod.version) {
501 return Err(std::io::Error::new(
502 std::io::ErrorKind::InvalidData,
503 format!(
504 "Unsupported SPZ version: {}, expected one of {:?}",
505 pod.version,
506 Self::SUPPORTED_VERSIONS
507 ),
508 ));
509 }
510
511 Ok(Self(pod))
512 }
513
514 pub fn default(num_points: u32) -> Result<Self, std::io::Error> {
516 Self::new(
517 Self::SUPPORTED_VERSIONS
518 .last()
519 .expect("at least one supported version"),
520 num_points,
521 SpzGaussianShDegree::default(),
522 12,
523 false,
524 )
525 }
526
527 pub fn as_pod(&self) -> &SpzGaussiansHeaderPod {
529 &self.0
530 }
531
532 pub fn version(&self) -> u32 {
534 self.0.version
535 }
536
537 pub fn set_num_points(&mut self, num_points: u32) {
541 self.0.num_points = num_points;
542 }
543
544 pub fn num_points(&self) -> usize {
546 self.0.num_points as usize
547 }
548
549 pub fn sh_degree(&self) -> SpzGaussianShDegree {
551 self.0.sh_degree
552 }
553
554 pub fn sh_num_coefficients(&self) -> usize {
556 self.0.sh_degree.num_coefficients()
557 }
558
559 pub fn fractional_bits(&self) -> u8 {
561 self.0.fractional_bits
562 }
563
564 pub fn is_antialiased(&self) -> bool {
566 (self.0.flags & 0x1) != 0
567 }
568
569 pub fn uses_float16(&self) -> bool {
571 self.version() == 1
572 }
573
574 pub fn uses_quat_smallest_three(&self) -> bool {
576 self.version() >= 3
577 }
578}
579
580impl SpzGaussiansPositions {
581 pub fn read_from(
583 reader: &mut impl Read,
584 count: usize,
585 uses_float16: bool,
586 ) -> Result<Self, std::io::Error> {
587 if uses_float16 {
588 let mut positions = vec![[0u16; 3]; count];
589 reader.read_exact(bytemuck::cast_slice_mut(&mut positions))?;
590 Ok(SpzGaussiansPositions::Float16(positions))
591 } else {
592 let mut positions = vec![[[0u8; 3]; 3]; count];
593 reader.read_exact(bytemuck::cast_slice_mut(&mut positions))?;
594 Ok(SpzGaussiansPositions::FixedPoint24(positions))
595 }
596 }
597
598 pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
600 match self {
601 SpzGaussiansPositions::Float16(positions) => {
602 writer.write_all(bytemuck::cast_slice(positions))
603 }
604 SpzGaussiansPositions::FixedPoint24(positions) => {
605 writer.write_all(bytemuck::cast_slice(positions))
606 }
607 }
608 }
609}
610
611impl SpzGaussiansRotations {
612 pub fn read_from(
614 reader: &mut impl Read,
615 count: usize,
616 uses_quat_smallest_three: bool,
617 ) -> Result<Self, std::io::Error> {
618 if !uses_quat_smallest_three {
619 let mut rots = vec![[0u8; 3]; count];
620 reader.read_exact(bytemuck::cast_slice_mut(&mut rots))?;
621 Ok(SpzGaussiansRotations::QuatFirstThree(rots))
622 } else {
623 let mut rots = vec![[0u8; 4]; count];
624 reader.read_exact(bytemuck::cast_slice_mut(&mut rots))?;
625 Ok(SpzGaussiansRotations::QuatSmallestThree(rots))
626 }
627 }
628
629 pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
631 match self {
632 SpzGaussiansRotations::QuatFirstThree(rots) => {
633 writer.write_all(bytemuck::cast_slice(rots))
634 }
635 SpzGaussiansRotations::QuatSmallestThree(rots) => {
636 writer.write_all(bytemuck::cast_slice(rots))
637 }
638 }
639 }
640}
641
642impl SpzGaussiansShs {
643 pub fn read_from(
645 reader: &mut impl Read,
646 count: usize,
647 sh_degree: SpzGaussianShDegree,
648 ) -> Result<Self, std::io::Error> {
649 match sh_degree.get() {
650 0 => Ok(SpzGaussiansShs::Zero),
651 1 => {
652 let mut sh_coeffs = vec![[[0u8; 3]; 3]; count];
653 reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
654 Ok(SpzGaussiansShs::One(sh_coeffs))
655 }
656 2 => {
657 let mut sh_coeffs = vec![[[0u8; 3]; 8]; count];
658 reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
659 Ok(SpzGaussiansShs::Two(sh_coeffs))
660 }
661 3 => {
662 let mut sh_coeffs = vec![[[0u8; 3]; 15]; count];
663 reader.read_exact(bytemuck::cast_slice_mut(&mut sh_coeffs))?;
664 Ok(SpzGaussiansShs::Three(sh_coeffs))
665 }
666 _ => {
667 unreachable!()
669 }
670 }
671 }
672
673 pub fn write_to(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
675 match self {
676 SpzGaussiansShs::Zero => Ok(()),
677 SpzGaussiansShs::One(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
678 SpzGaussiansShs::Two(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
679 SpzGaussiansShs::Three(sh_coeffs) => writer.write_all(bytemuck::cast_slice(sh_coeffs)),
680 }
681 }
682}
683
/// A collection of Gaussians in SPZ representation, stored field by field.
///
/// Each field container holds one entry per Gaussian (except SH degree zero, which stores
/// no data). [`SpzGaussians::iter`] zips the fields back into per-Gaussian references.
#[derive(Debug, Clone, PartialEq)]
pub struct SpzGaussians {
    /// Header describing the version, count, SH degree and flags.
    pub header: SpzGaussiansHeader,

    /// Positions of all Gaussians.
    pub positions: SpzGaussiansPositions,

    /// Scales of all Gaussians.
    pub scales: Vec<[u8; 3]>,

    /// Rotations of all Gaussians.
    pub rotations: SpzGaussiansRotations,

    /// Alphas of all Gaussians.
    pub alphas: Vec<u8>,

    /// Colors of all Gaussians.
    pub colors: Vec<[u8; 3]>,

    /// SH coefficients of all Gaussians.
    pub shs: SpzGaussiansShs,
}
704
705impl SpzGaussians {
706 pub fn len(&self) -> usize {
708 self.header.num_points()
709 }
710
711 pub fn is_empty(&self) -> bool {
713 self.len() == 0
714 }
715
716 pub fn read_decompressed(reader: &mut impl Read) -> Result<Self, std::io::Error> {
720 let header = Self::read_header(reader)?;
721 Self::read_gaussians(reader, header)
722 }
723
724 pub fn read_header(reader: &mut impl Read) -> Result<SpzGaussiansHeader, std::io::Error> {
728 let mut header_bytes = [0u8; std::mem::size_of::<SpzGaussiansHeaderPod>()];
729 reader.read_exact(&mut header_bytes)?;
730 let header: SpzGaussiansHeaderPod = bytemuck::cast(header_bytes);
731 SpzGaussiansHeader::try_from_pod(header)
732 }
733
734 pub fn read_gaussians(
740 reader: &mut impl Read,
741 header: SpzGaussiansHeader,
742 ) -> Result<Self, std::io::Error> {
743 let count = header.num_points();
744 let uses_float16 = header.uses_float16();
745 let uses_quat_smallest_three = header.uses_quat_smallest_three();
746
747 let positions = SpzGaussiansPositions::read_from(reader, count, uses_float16)?;
748
749 let mut alphas = vec![0u8; count];
750 reader.read_exact(bytemuck::cast_slice_mut(&mut alphas))?;
751
752 let mut colors = vec![[0u8; 3]; count];
753 reader.read_exact(bytemuck::cast_slice_mut(&mut colors))?;
754
755 let mut scales = vec![[0u8; 3]; count];
756 reader.read_exact(bytemuck::cast_slice_mut(&mut scales))?;
757
758 let rotations = SpzGaussiansRotations::read_from(reader, count, uses_quat_smallest_three)?;
759
760 let shs = SpzGaussiansShs::read_from(reader, count, header.sh_degree())?;
761
762 Ok(SpzGaussians {
763 header,
764 positions,
765 scales,
766 rotations,
767 alphas,
768 colors,
769 shs,
770 })
771 }
772
773 pub fn write_decompressed(&self, writer: &mut impl Write) -> Result<(), std::io::Error> {
777 writer.write_all(bytemuck::cast_slice(std::slice::from_ref(
778 self.header.as_pod(),
779 )))?;
780
781 self.positions.write_to(writer)?;
782
783 writer.write_all(bytemuck::cast_slice(&self.alphas))?;
784
785 writer.write_all(bytemuck::cast_slice(&self.colors))?;
786
787 writer.write_all(bytemuck::cast_slice(&self.scales))?;
788
789 self.rotations.write_to(writer)?;
790
791 self.shs.write_to(writer)?;
792
793 Ok(())
794 }
795
796 pub fn from_gaussians(gaussians: impl IntoIterator<Item = impl AsRef<Gaussian>>) -> Self {
798 Self::from_gaussians_with_options(
799 gaussians,
800 &SpzGaussiansFromGaussianSliceOptions::default(),
801 )
802 .expect("valid default options")
803 }
804
805 pub fn from_gaussians_with_options(
807 gaussians: impl IntoIterator<Item = impl AsRef<Gaussian>>,
808 options: &SpzGaussiansFromGaussianSliceOptions,
809 ) -> Result<Self, std::io::Error> {
810 let mut header = SpzGaussiansHeader::new(
811 options.version,
812 0,
813 options.sh_degree,
814 options.fractional_bits,
815 options.antialiased,
816 )?;
817
818 let gaussians = gaussians
819 .into_iter()
820 .map(|g| {
821 g.as_ref().to_spz(
822 &header,
823 &GaussianToSpzOptions {
824 sh_quantize_bits: options.sh_quantize_bits,
825 },
826 )
827 })
828 .collect::<Vec<_>>();
829
830 header.set_num_points(gaussians.len() as u32);
831
832 Ok(Self::from_iter(header, gaussians)
833 .expect("gaussians from valid Gaussians with valid header are valid"))
834 }
835
836 pub fn from_iter(
838 header: SpzGaussiansHeader,
839 iter: impl IntoIterator<Item = SpzGaussian>,
840 ) -> Result<Self, SpzGaussiansFromIterError> {
841 let (positions, scales, rotations, alphas, colors, shs) = iter
842 .into_iter()
843 .map(|spz| {
844 (
845 spz.position,
846 spz.scale,
847 spz.rotation,
848 spz.alpha,
849 spz.color,
850 spz.sh,
851 )
852 })
853 .multiunzip::<(Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>, Vec<_>)>();
854
855 let positions = positions
856 .into_iter()
857 .collect::<Result<_, _>>()
858 .map_err(SpzGaussiansFromIterError::InvalidMixedPositionVariant)?;
859
860 let rotations = rotations
861 .into_iter()
862 .collect::<Result<_, _>>()
863 .map_err(SpzGaussiansFromIterError::InvalidMixedRotationVariant)?;
864
865 let shs = shs
866 .into_iter()
867 .collect::<Result<_, _>>()
868 .map_err(SpzGaussiansFromIterError::InvalidMixedShVariant)?;
869
870 if positions.len() != header.num_points() {
871 return Err(SpzGaussiansFromIterError::CountMismatch {
872 actual_count: positions.len(),
873 header_count: header.num_points(),
874 });
875 }
876
877 if matches!(positions, SpzGaussiansPositions::Float16(_)) != header.uses_float16() {
878 return Err(SpzGaussiansFromIterError::PositionFloat16Mismatch {
879 is_float16: matches!(positions, SpzGaussiansPositions::Float16(_)),
880 header_uses_float16: header.uses_float16(),
881 });
882 }
883
884 if matches!(rotations, SpzGaussiansRotations::QuatSmallestThree(_))
885 != header.uses_quat_smallest_three()
886 {
887 return Err(
888 SpzGaussiansFromIterError::RotationQuatSmallestThreeMismatch {
889 is_quat_smallest_three: matches!(
890 rotations,
891 SpzGaussiansRotations::QuatSmallestThree(_)
892 ),
893 header_uses_quat_smallest_three: header.uses_quat_smallest_three(),
894 },
895 );
896 }
897
898 if shs.degree() != header.sh_degree() {
899 return Err(SpzGaussiansFromIterError::ShDegreeMismatch {
900 sh_degree: shs.degree(),
901 header_sh_degree: header.sh_degree(),
902 });
903 }
904
905 Ok(SpzGaussians {
906 header,
907 positions,
908 scales,
909 rotations,
910 alphas,
911 colors,
912 shs,
913 })
914 }
915
916 pub fn iter<'a>(&'a self) -> impl ExactSizeIterator<Item = SpzGaussianRef<'a>> + 'a {
918 itertools::izip!(
919 self.positions.iter(),
920 self.scales.iter(),
921 self.rotations.iter(),
922 self.alphas.iter(),
923 self.colors.iter(),
924 self.shs.iter()
925 )
926 .map(
927 |(position, scale, rotation, alpha, color, sh)| SpzGaussianRef {
928 position,
929 scale,
930 rotation,
931 alpha,
932 color,
933 sh,
934 },
935 )
936 }
937}
938
939impl IterGaussian for SpzGaussians {
940 fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
941 self.iter().map(|spz| Gaussian::from_spz(spz, &self.header))
942 }
943}
944
945impl ReadIterGaussian for SpzGaussians {
946 fn read_from(reader: &mut impl std::io::BufRead) -> std::io::Result<Self> {
947 let mut decoder = GzDecoder::new(reader);
948 Self::read_decompressed(&mut decoder)
949 }
950}
951
952impl WriteIterGaussian for SpzGaussians {
953 fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
954 let mut encoder = GzEncoder::new(writer, flate2::Compression::default());
955 self.write_decompressed(&mut encoder)?;
956 encoder.finish()?;
957 Ok(())
958 }
959}
960
961impl<G: AsRef<Gaussian>> FromIterator<G> for SpzGaussians {
962 fn from_iter<T: IntoIterator<Item = G>>(iter: T) -> Self {
963 Self::from_gaussians(iter)
964 }
965}
966
/// Options for [`SpzGaussians::from_gaussians_with_options`].
///
/// The first four fields are passed to [`SpzGaussiansHeader::new`]; `sh_quantize_bits` is
/// passed on to the per-Gaussian conversion.
#[derive(Debug, Clone)]
pub struct SpzGaussiansFromGaussianSliceOptions {
    /// SPZ format version of the resulting header.
    pub version: u32,

    /// Spherical harmonics degree of the resulting header.
    pub sh_degree: SpzGaussianShDegree,

    /// Number of fractional bits of the resulting header.
    pub fractional_bits: u8,

    /// Whether the resulting header is flagged as antialiased.
    pub antialiased: bool,

    /// Quantization bits forwarded as `GaussianToSpzOptions::sh_quantize_bits`.
    // NOTE(review): presumably one entry per SH degree 1..=3 — confirm against
    // `GaussianToSpzOptions`.
    pub sh_quantize_bits: [u32; 3],
}
987
988impl Default for SpzGaussiansFromGaussianSliceOptions {
989 fn default() -> Self {
990 let default_header = SpzGaussiansHeader::default(0).expect("default header");
991 let default_gaussian_to_spz_options = GaussianToSpzOptions::default();
992 Self {
993 version: default_header.version(),
994 sh_degree: default_header.sh_degree(),
995 fractional_bits: default_header.fractional_bits(),
996 antialiased: default_header.is_antialiased(),
997 sh_quantize_bits: default_gaussian_to_spz_options.sh_quantize_bits,
998 }
999 }
1000}