1use std::io::BufRead;
2
3use glam::*;
4
5use crate::{
6 PlyGaussianPod, PlyGaussians, SpzGaussian, SpzGaussianPosition, SpzGaussianPositionRef,
7 SpzGaussianRef, SpzGaussianRotation, SpzGaussianRotationRef, SpzGaussianSh, SpzGaussians,
8 SpzGaussiansHeader,
9};
10
11pub trait IterGaussian: FromIterator<Gaussian> {
13 fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_;
15}
16
17impl IterGaussian for Vec<Gaussian> {
18 fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
19 self.iter().copied()
20 }
21}
22
23pub trait ReadIterGaussian: IterGaussian {
25 fn read_from(reader: &mut impl BufRead) -> std::io::Result<Self>;
27
28 fn read_from_file(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> {
30 let file = std::fs::File::open(path)?;
31 let mut reader = std::io::BufReader::new(file);
32 Self::read_from(&mut reader)
33 }
34}
35
36pub trait WriteIterGaussian: IterGaussian {
38 fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()>;
40
41 fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
43 let file = std::fs::File::create(path)?;
44 let mut writer = std::io::BufWriter::new(file);
45 self.write_to(&mut writer)
46 }
47}
48
/// A single Gaussian in the crate's internal representation.
///
/// PLY and SPZ records are converted to and from this type (see `Gaussian::from_ply`,
/// `Gaussian::to_ply`, `Gaussian::from_spz`, `Gaussian::to_spz`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Gaussian {
    /// Rotation quaternion (normalized when produced by `from_ply`).
    pub rot: Quat,
    /// Position.
    pub pos: Vec3,
    /// Color bytes: `xyz` derived from the DC color coefficient, `w` holds opacity (alpha).
    pub color: U8Vec4,
    /// The 15 higher-order spherical-harmonic coefficients, one `Vec3` per coefficient.
    pub sh: [Vec3; 15],
    /// Per-axis scale, stored linearly (`from_ply` exponentiates the PLY value).
    pub scale: Vec3,
}
61
62impl Gaussian {
63 pub const SH0_TO_LINEAR_FACTOR: f32 = 0.2820948;
65
66 pub const SPZ_SH0_TO_LINEAR_FACTOR: f32 = 0.15;
68
69 pub fn from_ply(ply: &PlyGaussianPod) -> Self {
71 let pos = Vec3::from_array(ply.pos);
72
73 let rot = Quat::from_xyzw(ply.rot[1], ply.rot[2], ply.rot[3], ply.rot[0]).normalize();
74
75 let scale = Vec3::from_array(ply.scale).exp();
76
77 let color = ((Vec3::from_array(ply.color) * Self::SH0_TO_LINEAR_FACTOR + Vec3::splat(0.5))
78 * 255.0)
79 .extend((1.0 / (1.0 + (-ply.alpha).exp())) * 255.0)
80 .clamp(Vec4::splat(0.0), Vec4::splat(255.0))
81 .as_u8vec4();
82
83 let sh = std::array::from_fn(|i| Vec3::new(ply.sh[i], ply.sh[i + 15], ply.sh[i + 30]));
84
85 Self {
86 rot,
87 pos,
88 color,
89 sh,
90 scale,
91 }
92 }
93
94 pub fn to_ply(&self) -> PlyGaussianPod {
96 let pos = self.pos.to_array();
97
98 let rot = [self.rot.w, self.rot.x, self.rot.y, self.rot.z];
99
100 let scale = self.scale.map(|x| x.ln()).to_array();
101
102 let rgba = self.color.as_vec4() / 255.0;
103 let color = ((rgba.xyz() - Vec3::splat(0.5)) / Self::SH0_TO_LINEAR_FACTOR).to_array();
104
105 let alpha = -(1.0 / rgba.w - 1.0).ln();
106
107 let mut sh = [0.0; 3 * 15];
108 for i in 0..15 {
109 sh[i] = self.sh[i].x;
110 sh[i + 15] = self.sh[i].y;
111 sh[i + 30] = self.sh[i].z;
112 }
113
114 let normal = [0.0, 0.0, 1.0];
115
116 PlyGaussianPod {
117 pos,
118 normal,
119 color,
120 sh,
121 alpha,
122 scale,
123 rot,
124 }
125 }
126
127 const SPZ_COLOR_TO_LINEAR_FRAC_A_B: f32 =
128 Gaussian::SH0_TO_LINEAR_FACTOR / Gaussian::SPZ_SH0_TO_LINEAR_FACTOR;
129 const SPZ_COLOR_TO_LINEAR_FRAC_F2_F1: f32 = 0.5 * 255.0;
130 const SPZ_COLOR_TO_LINEAR_C: f32 =
131 (1.0 - Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B) * Self::SPZ_COLOR_TO_LINEAR_FRAC_F2_F1;
132
133 pub fn from_spz(spz: SpzGaussianRef, header: &SpzGaussiansHeader) -> Self {
135 let pos = match spz.position {
136 SpzGaussianPositionRef::Float16(pos) => {
137 let unpacked = pos.map(|c| half::f16::from_bits(c).to_f32_const());
139 Vec3::from_array(unpacked)
140 }
141 SpzGaussianPositionRef::FixedPoint24(pos) => {
142 let scale = 1.0 / (1 << header.fractional_bits()) as f32;
143 let unpacked = pos.map(|c| {
144 let mut fixed32: i32 = c[0] as i32;
145 fixed32 |= (c[1] as i32) << 8;
146 fixed32 |= (c[2] as i32) << 16;
147 fixed32 |= if fixed32 & 0x800000 != 0 {
148 0xff000000u32 as i32
149 } else {
150 0
151 };
152 fixed32 as f32 * scale
153 });
154 Vec3::from_array(unpacked)
155 }
156 };
157
158 let scale = Vec3::from_array(spz.scale.map(|c| c as f32 / 16.0 - 10.0)).exp();
159
160 let rot = match spz.rotation {
161 SpzGaussianRotationRef::QuatFirstThree(quat) => {
162 let xyz = Vec3::from(quat.map(|c| c as f32 / 127.5 - 1.0));
163 let w = (1.0 - xyz.length_squared()).max(0.0).sqrt();
164 Quat::from_xyzw(xyz.x, xyz.y, xyz.z, w)
165 }
166 SpzGaussianRotationRef::QuatSmallestThree(quat) => {
167 let mut comp: u32 = quat[0] as u32
168 | ((quat[1] as u32) << 8)
169 | ((quat[2] as u32) << 16)
170 | ((quat[3] as u32) << 24);
171
172 const C_MASK: u32 = (1 << 9) - 1;
173
174 let largest_index = (comp >> 30) as usize;
175 let mut sum_squares = 0.0f32;
176 let mut comps = std::array::from_fn(|i| {
177 if i == largest_index {
178 return 0.0;
179 }
180
181 let mag = comp & C_MASK;
182 let neg_bit = (comp >> 9) & 1;
183 comp >>= 10;
184
185 let value = std::f32::consts::FRAC_1_SQRT_2
186 * (mag as f32 / C_MASK as f32)
187 * if neg_bit != 0 { -1.0 } else { 1.0 };
188 sum_squares += value * value;
189
190 value
191 });
192
193 comps[largest_index] = (1.0 - sum_squares).max(0.0).sqrt();
194
195 Quat::from_array(comps)
196 }
197 };
198
199 let color = U8Vec3::from_array(spz.color.map(|c| {
200 (c as f32 * Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B + Self::SPZ_COLOR_TO_LINEAR_C)
201 .clamp(0.0, 255.0) as u8
202 }))
203 .extend(*spz.alpha);
204
205 let mut sh = [Vec3::ZERO; 15];
206 for (src, dst) in spz.sh.iter().zip(sh.iter_mut()) {
207 *dst = Vec3::from_array(src.map(|c| (c as f32 - 128.0) / 128.0));
208 }
209
210 Self {
211 rot,
212 pos,
213 color,
214 sh,
215 scale,
216 }
217 }
218
219 pub fn to_spz(
228 &self,
229 header: &SpzGaussiansHeader,
230 options: &GaussianToSpzOptions,
231 ) -> SpzGaussian {
232 let position = if header.uses_float16() {
233 let packed = self
234 .pos
235 .to_array()
236 .map(|c| half::f16::from_f32_const(c).to_bits());
237 SpzGaussianPosition::Float16(packed)
238 } else {
239 let scale = (1 << header.fractional_bits()) as f32;
240 let packed = self.pos.to_array().map(|c| {
241 let fixed32 = (c * scale).round() as i32;
242 [
243 (fixed32 & 0xff) as u8,
244 ((fixed32 >> 8) & 0xff) as u8,
245 ((fixed32 >> 16) & 0xff) as u8,
246 ]
247 });
248 SpzGaussianPosition::FixedPoint24(packed)
249 };
250
251 let scale = self
252 .scale
253 .to_array()
254 .map(|c| ((c.ln() + 10.0) * 16.0).round().clamp(0.0, 255.0) as u8);
255
256 let rotation = if header.uses_quat_smallest_three() {
257 let rot = self.rot.normalize().to_array();
258 let largest_index = rot
259 .into_iter()
260 .map(f32::abs)
261 .enumerate()
262 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
263 .expect("quaternion has at least one component")
264 .0;
265
266 const C_MASK: u32 = (1 << 9) - 1;
267
268 let negate = (rot[largest_index] < 0.0) as u32;
269
270 let mut comp = largest_index as u32;
271 for (i, &value) in rot.iter().enumerate() {
272 if i == largest_index {
273 continue;
274 }
275
276 let neg_bit = (value < 0.0) as u32 ^ negate;
277 let mag = (C_MASK as f32 * (value.abs() * std::f32::consts::SQRT_2) + 0.5)
278 .clamp(0.0, C_MASK as f32 - 1.0) as u32;
279 comp = (comp << 10) | (neg_bit << 9) | mag;
280 }
281
282 SpzGaussianRotation::QuatSmallestThree([
283 (comp & 0xff) as u8,
284 ((comp >> 8) & 0xff) as u8,
285 ((comp >> 16) & 0xff) as u8,
286 ((comp >> 24) & 0xff) as u8,
287 ])
288 } else {
289 let rot = self.rot.normalize();
290 let rot = if rot.w < 0.0 { -rot } else { rot };
291 let packed = rot
292 .xyz()
293 .to_array()
294 .map(|c| ((c + 1.0) * 127.5).round().clamp(0.0, 255.0) as u8);
295 SpzGaussianRotation::QuatFirstThree(packed)
296 };
297
298 let alpha = self.color.w;
299
300 let color = self
301 .color
302 .map(|c| {
303 ((c as f32 - Self::SPZ_COLOR_TO_LINEAR_C) / Self::SPZ_COLOR_TO_LINEAR_FRAC_A_B)
304 .clamp(0.0, 255.0) as u8
305 })
306 .xyz()
307 .to_array();
308
309 let sh = match header.sh_degree().get() {
310 0 => SpzGaussianSh::Zero,
311 deg @ 1..=3 => {
312 let mut sh = match deg {
313 1 => SpzGaussianSh::One([[0; 3]; 3]),
314 2 => SpzGaussianSh::Two([[0; 3]; 8]),
315 3 => SpzGaussianSh::Three([[0; 3]; 15]),
316 _ => unreachable!(),
317 };
318
319 fn quantize_sh(x: f32, bucket_size: u32) -> u8 {
320 let q = (x * 128.0 + 128.0).round() as u32;
321 let q = if bucket_size >= 8 {
322 q
323 } else {
324 (q + bucket_size / 2) / bucket_size * bucket_size
325 };
326 q.clamp(0, 255) as u8
327 }
328
329 for (src, dst) in self.sh.iter().zip(sh.iter_mut()) {
330 let bucket_size = options
331 .sh_bucket_size(deg)
332 .expect("header SH degree is valid");
333 *dst = src.to_array().map(|x| quantize_sh(x, bucket_size));
334 }
335
336 sh
337 }
338 _ => {
339 unreachable!()
341 }
342 };
343
344 SpzGaussian {
345 position,
346 scale,
347 rotation,
348 color,
349 alpha,
350 sh,
351 }
352 }
353}
354
355impl AsRef<Gaussian> for Gaussian {
359 fn as_ref(&self) -> &Gaussian {
360 self
361 }
362}
363
/// Quantization settings used when converting a `Gaussian` to SPZ.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GaussianToSpzOptions {
    /// Bits kept for spherical-harmonic coefficients of degree 1, 2 and 3 respectively
    /// (index 0 is degree 1). Values of 8 or more keep full byte precision.
    pub sh_quantize_bits: [u32; 3],
}

impl GaussianToSpzOptions {
    /// Returns the quantization bit count for SH `degree`, or `None` if `degree` is not in
    /// `1..=3`.
    pub fn sh_bits(&self, degree: u8) -> Option<u32> {
        match degree {
            1..=3 => Some(self.sh_quantize_bits[usize::from(degree) - 1]),
            _ => None,
        }
    }

    /// Returns the quantization bucket size `2^(8 - bits)` for SH `degree`, or `None` if
    /// `degree` is not in `1..=3`.
    ///
    /// Bit counts of 8 or more yield a bucket size of 1 (no extra quantization) instead of
    /// underflowing the subtraction.
    pub fn sh_bucket_size(&self, degree: u8) -> Option<u32> {
        self.sh_bits(degree)
            .map(|bits| 1 << 8u32.saturating_sub(bits))
    }
}

impl Default for GaussianToSpzOptions {
    /// 5 bits for degree-1 coefficients and 4 bits for degrees 2 and 3.
    fn default() -> Self {
        Self {
            sh_quantize_bits: [5, 4, 4],
        }
    }
}
393
/// Identifies which storage representation a `Gaussians` value uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GaussiansSource {
    /// Decoded Gaussians held in a plain `Vec<Gaussian>`.
    Internal,
    /// PLY-backed storage (`PlyGaussians`).
    Ply,
    /// SPZ-backed storage (`SpzGaussians`).
    Spz,
}
401
402impl From<&Gaussians> for GaussiansSource {
403 fn from(value: &Gaussians) -> Self {
404 match value {
405 Gaussians::Internal(_) => GaussiansSource::Internal,
406 Gaussians::Ply(_) => GaussiansSource::Ply,
407 Gaussians::Spz(_) => GaussiansSource::Spz,
408 }
409 }
410}
411
/// A collection of Gaussians in one of several storage representations.
///
/// `Internal` holds decoded `Gaussian`s directly and cannot be read from or written to a file
/// or buffer. `Ply` and `Spz` keep the format-specific storage and support file/buffer I/O.
#[derive(Debug, Clone, PartialEq)]
pub enum Gaussians {
    /// Decoded Gaussians.
    Internal(Vec<Gaussian>),
    /// PLY-format storage.
    Ply(PlyGaussians),
    /// SPZ-format storage.
    Spz(SpzGaussians),
}
424
425impl Gaussians {
426 pub fn from_gaussians_iter(
428 iter: impl Iterator<Item = Gaussian>,
429 source: GaussiansSource,
430 ) -> Self {
431 match source {
432 GaussiansSource::Internal => Gaussians::Internal(iter.collect()),
433 GaussiansSource::Ply => Gaussians::Ply(iter.collect()),
434 GaussiansSource::Spz => Gaussians::Spz(iter.collect()),
435 }
436 }
437
438 pub fn source(&self) -> GaussiansSource {
440 GaussiansSource::from(self)
441 }
442
443 pub fn len(&self) -> usize {
445 match self {
446 Gaussians::Internal(gaussians) => gaussians.len(),
447 Gaussians::Ply(ply_gaussians) => ply_gaussians.len(),
448 Gaussians::Spz(spz_gaussians) => spz_gaussians.len(),
449 }
450 }
451
452 pub fn is_empty(&self) -> bool {
454 self.len() == 0
455 }
456
457 pub fn read_from_file(
459 path: impl AsRef<std::path::Path>,
460 source: GaussiansSource,
461 ) -> std::io::Result<Self> {
462 match source {
463 GaussiansSource::Internal => Err(std::io::Error::new(
464 std::io::ErrorKind::InvalidInput,
465 "cannot read Internal Gaussians from file",
466 )),
467 GaussiansSource::Ply => {
468 let ply_gaussians = PlyGaussians::read_from_file(path)?;
469 Ok(Gaussians::Ply(ply_gaussians))
470 }
471 GaussiansSource::Spz => {
472 let spz_gaussians = SpzGaussians::read_from_file(path)?;
473 Ok(Gaussians::Spz(spz_gaussians))
474 }
475 }
476 }
477
478 pub fn read_from(reader: &mut impl BufRead, source: GaussiansSource) -> std::io::Result<Self> {
480 match source {
481 GaussiansSource::Internal => Err(std::io::Error::new(
482 std::io::ErrorKind::InvalidInput,
483 "cannot read Internal Gaussians from buffer",
484 )),
485 GaussiansSource::Ply => {
486 let ply_gaussians = PlyGaussians::read_from(reader)?;
487 Ok(Gaussians::Ply(ply_gaussians))
488 }
489 GaussiansSource::Spz => {
490 let spz_gaussians = SpzGaussians::read_from(reader)?;
491 Ok(Gaussians::Spz(spz_gaussians))
492 }
493 }
494 }
495
496 pub fn write_to_file(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
498 match self {
499 Gaussians::Internal(_) => Err(std::io::Error::new(
500 std::io::ErrorKind::InvalidInput,
501 "cannot write Internal Gaussians to file",
502 )),
503 Gaussians::Ply(ply_gaussians) => ply_gaussians.write_to_file(path),
504 Gaussians::Spz(spz_gaussians) => spz_gaussians.write_to_file(path),
505 }
506 }
507
508 pub fn write_to(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
510 match self {
511 Gaussians::Internal(_) => Err(std::io::Error::new(
512 std::io::ErrorKind::InvalidInput,
513 "cannot write Internal Gaussians to buffer",
514 )),
515 Gaussians::Ply(ply_gaussians) => ply_gaussians.write_to(writer),
516 Gaussians::Spz(spz_gaussians) => spz_gaussians.write_to(writer),
517 }
518 }
519}
520
521impl From<Vec<Gaussian>> for Gaussians {
522 fn from(value: Vec<Gaussian>) -> Self {
523 Gaussians::Internal(value)
524 }
525}
526
527impl From<PlyGaussians> for Gaussians {
528 fn from(value: PlyGaussians) -> Self {
529 Gaussians::Ply(value)
530 }
531}
532
533impl From<SpzGaussians> for Gaussians {
534 fn from(value: SpzGaussians) -> Self {
535 Gaussians::Spz(value)
536 }
537}
538
539impl IterGaussian for Gaussians {
540 fn iter_gaussian(&self) -> impl ExactSizeIterator<Item = Gaussian> + '_ {
541 match self {
542 Gaussians::Internal(gaussians) => GaussiansIter::Internal(gaussians.iter_gaussian()),
543 Gaussians::Ply(ply_gaussians) => GaussiansIter::Ply(ply_gaussians.iter_gaussian()),
544 Gaussians::Spz(spz_gaussians) => GaussiansIter::Spz(spz_gaussians.iter_gaussian()),
545 }
546 }
547}
548
549impl FromIterator<Gaussian> for Gaussians {
550 fn from_iter<T: IntoIterator<Item = Gaussian>>(iter: T) -> Self {
551 Gaussians::Internal(iter.into_iter().collect())
552 }
553}
554
555pub trait IteratorGaussianExt: Iterator<Item = Gaussian> + Sized {
557 fn collect_gaussians(self, source: GaussiansSource) -> Gaussians {
559 Gaussians::from_gaussians_iter(self, source)
560 }
561}
562
563impl<T: Iterator<Item = Gaussian>> IteratorGaussianExt for T {}
564
/// Iterator over the Gaussians of a `Gaussians` value.
///
/// There is one variant per storage representation, so each variant keeps its own concrete
/// iterator type. The type parameters are unconstrained here; the `Iterator` and
/// `ExactSizeIterator` impls require each one to iterate over `Gaussian`.
#[derive(Debug, Clone)]
pub enum GaussiansIter<InternalIter, PlyIter, SpzIter> {
    /// Iterator over the `Internal` representation.
    Internal(InternalIter),
    /// Iterator over the `Ply` representation.
    Ply(PlyIter),
    /// Iterator over the `Spz` representation.
    Spz(SpzIter),
}
576
577impl<
578 InternalIter: Iterator<Item = Gaussian>,
579 PlyIter: Iterator<Item = Gaussian>,
580 SpzIter: Iterator<Item = Gaussian>,
581> Iterator for GaussiansIter<InternalIter, PlyIter, SpzIter>
582{
583 type Item = Gaussian;
584
585 fn next(&mut self) -> Option<Self::Item> {
586 match self {
587 GaussiansIter::Internal(iter) => iter.next(),
588 GaussiansIter::Ply(iter) => iter.next(),
589 GaussiansIter::Spz(iter) => iter.next(),
590 }
591 }
592
593 fn size_hint(&self) -> (usize, Option<usize>) {
594 match self {
595 GaussiansIter::Internal(iter) => iter.size_hint(),
596 GaussiansIter::Ply(iter) => iter.size_hint(),
597 GaussiansIter::Spz(iter) => iter.size_hint(),
598 }
599 }
600}
601
602impl<
603 InternalIter: ExactSizeIterator<Item = Gaussian>,
604 PlyIter: ExactSizeIterator<Item = Gaussian>,
605 SpzIter: ExactSizeIterator<Item = Gaussian>,
606> ExactSizeIterator for GaussiansIter<InternalIter, PlyIter, SpzIter>
607{
608}