Skip to main content

tenflowers_neural/neural_rendering/
mod.rs

1//! Neural Rendering & 3D Vision.
2//!
3//! Implements cutting-edge neural rendering techniques:
4//!
5//! - **NeRF**: [`PositionalEncoding`], [`NeRFMlp`], [`VolumeRenderer`], [`RayMarcher`], [`NeRFLoss`]
6//! - **3D Gaussian Splatting**: [`Gaussian3D`], [`GaussianSplatRenderer`], [`Gaussian2D`], [`GaussianOptimizer`], [`QuaternionOps`]
7//! - **Implicit Neural Representations**: [`SirenLayer`], [`SirenNetwork`], [`NeuralSdf`], [`InstantNgp`], [`OccupancyNetwork`]
8//! - **Scene Understanding**: [`DepthEstimationNet`], [`SurfaceNormalEstimator`], [`SemanticNerfDecoder`], [`PanopticLiftingHead`], [`SceneFlowEstimator`]
9//! - **Camera & View Synthesis**: [`CameraModel`], [`MultiViewConsistencyLoss`], [`PoseEstimator`], [`ViewInterpolator`], [`CameraOptimizer`]
10
11pub mod extensions;
12pub use extensions::*;
13
14pub mod advanced;
15// Explicit re-exports from advanced to avoid name conflicts with mod-level types.
16// Note: advanced::Gaussian3D is a different type (with SH coefficients) and is
17// accessed as neural_rendering::advanced::Gaussian3D to avoid collision.
18pub use advanced::{
19    DeformationField, DynamicNerf, GaussianDensification, GaussianSplatter,
20    NrcCache, NrMetrics, Reservoir, ReSTIR, ShCoefficients,
21};
22
23#[cfg(test)]
24pub mod tests;
25
26use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
27use scirs2_core::RngExt;
28use tenflowers_core::{Result, TensorError};
29
30// ─────────────────────────────────────────────────────────────────────────────
31// Shared utilities
32// ─────────────────────────────────────────────────────────────────────────────
33
/// Rectified linear unit: `max(x, 0)`.
#[inline]
pub(crate) fn relu(x: f64) -> f64 {
    f64::max(x, 0.0)
}
38
/// Logistic sigmoid `1 / (1 + e^{-x})`.
///
/// Each branch only exponentiates a non-positive number, so neither can
/// overflow for large `|x|`.
#[inline]
pub(crate) fn sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        // e^x lies in (0, 1) here.
        let ex = x.exp();
        return ex / (ex + 1.0);
    }
    (1.0 + (-x).exp()).recip()
}
48
/// Softplus `ln(1 + e^x)`, a smooth non-negative approximation of ReLU.
///
/// Evaluated as `max(x, 0) + ln(1 + e^{-|x|})`. The naive form overflows
/// (`exp` returns `inf` for `x` above ~709, so the result is `inf`). It also
/// loses all precision for very negative `x`, where `1 + e^x` rounds to `1`
/// and the result collapses to `0`. This form is finite for every finite `x`
/// and keeps the tiny positive tail.
#[inline]
pub(crate) fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}
53
/// Inner product of `a` and `b`; trailing elements of the longer slice are ignored.
pub(crate) fn dot(a: &[f64], b: &[f64]) -> f64 {
    std::iter::zip(a, b).map(|(x, y)| x * y).sum()
}
57
58pub(crate) fn matvec(mat: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
59    mat.iter().map(|row| dot(row, v)).collect()
60}
61
62pub(crate) fn rand_weight(rows: usize, cols: usize, seed: u64) -> Vec<Vec<f64>> {
63    let mut rng = StdRng::seed_from_u64(seed);
64    let scale = (2.0 / cols as f64).sqrt();
65    (0..rows)
66        .map(|_| {
67            (0..cols)
68                .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * scale)
69                .collect()
70        })
71        .collect()
72}
73
74pub(crate) fn linear(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
75    let out = matvec(w, x);
76    out.iter().zip(b).map(|(o, bi)| o + bi).collect()
77}
78
79pub(crate) fn linear_relu(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
80    linear(w, b, x).into_iter().map(relu).collect()
81}
82
/// Scale a 3-vector to unit length.
///
/// Vectors whose norm is below `1e-300` are returned unchanged.
pub(crate) fn normalize3(v: [f64; 3]) -> [f64; 3] {
    let [x, y, z] = v;
    let len = (x * x + y * y + z * z).sqrt();
    if len < 1e-300 {
        return v;
    }
    [x / len, y / len, z / len]
}
91
/// Cross product `a × b` of two 3-vectors.
pub(crate) fn cross3(a: [f64; 3], b: [f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
99
100// ─────────────────────────────────────────────────────────────────────────────
101// §1  Neural Radiance Fields (NeRF)
102// ─────────────────────────────────────────────────────────────────────────────
103
/// Fourier positional encoding used by NeRF.
///
/// For a scalar `x` and `L = n_freqs` frequency bands the encoding is
/// `[sin(2^0 x), cos(2^0 x), sin(2^1 x), cos(2^1 x), …, sin(2^{L-1} x), cos(2^{L-1} x)]`,
/// which has length `2 * n_freqs`. When `include_input` is set, the raw value
/// `x` is placed first, giving length `2 * n_freqs + 1`.
#[derive(Debug, Clone)]
pub struct PositionalEncoding {
    /// Number of frequency bands L.
    pub n_freqs: usize,
    /// Whether to include the raw input (identity) in the encoding.
    pub include_input: bool,
}

impl PositionalEncoding {
    /// Encoding with `n_freqs` bands and no identity term.
    pub fn new(n_freqs: usize) -> Self {
        Self {
            n_freqs,
            include_input: false,
        }
    }

    /// Encoding with `n_freqs` bands, preceded by the raw input value.
    pub fn with_identity(n_freqs: usize) -> Self {
        Self {
            n_freqs,
            include_input: true,
        }
    }

    /// Encode one scalar.
    ///
    /// The result has [`Self::output_dim_scalar`] elements.
    pub fn encode_scalar(&self, x: f64) -> Vec<f64> {
        let mut features = Vec::with_capacity(self.output_dim_scalar());
        if self.include_input {
            features.push(x);
        }
        for band in 0..self.n_freqs {
            let arg = (2.0_f64).powi(band as i32) * x;
            features.push(arg.sin());
            features.push(arg.cos());
        }
        features
    }

    /// Encode every element of `x` independently.
    ///
    /// The per-element encodings are concatenated in input order.
    pub fn encode(&self, x: &[f64]) -> Vec<f64> {
        let mut out = Vec::with_capacity(self.output_dim(x.len()));
        for &value in x {
            out.extend(self.encode_scalar(value));
        }
        out
    }

    /// Length of the encoding of a single scalar.
    pub fn output_dim_scalar(&self) -> usize {
        let identity = usize::from(self.include_input);
        2 * self.n_freqs + identity
    }

    /// Length of the encoding of an `input_len`-element vector.
    pub fn output_dim(&self, input_len: usize) -> usize {
        input_len * self.output_dim_scalar()
    }
}
162
163/// NeRF MLP: (point, direction) → (RGB, density σ).
164/// 4-layer backbone with skip connection + density head + direction-conditioned colour head.
165#[derive(Debug, Clone)]
166pub struct NeRFMlp {
167    pub point_enc: PositionalEncoding,
168    pub dir_enc: PositionalEncoding,
169    backbone_w: Vec<Vec<Vec<f64>>>,
170    backbone_b: Vec<Vec<f64>>,
171    density_w: Vec<Vec<f64>>,
172    density_b: Vec<f64>,
173    feature_w: Vec<Vec<f64>>,
174    feature_b: Vec<f64>,
175    colour_w: Vec<Vec<f64>>,
176    colour_b: Vec<f64>,
177    hidden_dim: usize,
178}
179
180impl NeRFMlp {
181    /// Build NeRF MLP. `point_n_freqs`/`dir_n_freqs`: encoding bands; `hidden_dim`: layer width.
182    pub fn new(point_n_freqs: usize, dir_n_freqs: usize, hidden_dim: usize) -> Self {
183        let point_enc = PositionalEncoding::new(point_n_freqs);
184        let dir_enc = PositionalEncoding::new(dir_n_freqs);
185        let pt_dim = point_enc.output_dim(3);
186        let dir_dim = dir_enc.output_dim(3);
187
188        let mut seed = 42_u64;
189        let mut mk = |r: usize, c: usize| {
190            let w = rand_weight(r, c, seed);
191            seed += 1;
192            (w, vec![0.0_f64; r])
193        };
194
195        // 4 backbone layers; layer 0 takes pt_dim, layers 1–3 take hidden_dim
196        // layer 2 (index 2) gets pt_dim concatenated → input = hidden_dim + pt_dim
197        let mut backbone_w = Vec::new();
198        let mut backbone_b = Vec::new();
199        for i in 0..4 {
200            let in_dim = if i == 0 {
201                pt_dim
202            } else if i == 2 {
203                hidden_dim + pt_dim
204            } else {
205                hidden_dim
206            };
207            let (w, b) = mk(hidden_dim, in_dim);
208            backbone_w.push(w);
209            backbone_b.push(b);
210        }
211
212        let (density_w, density_b) = mk(1, hidden_dim);
213        let (feature_w, feature_b) = mk(hidden_dim, hidden_dim);
214        let (colour_w, colour_b) = mk(3, hidden_dim + dir_dim);
215
216        Self {
217            point_enc,
218            dir_enc,
219            backbone_w,
220            backbone_b,
221            density_w,
222            density_b,
223            feature_w,
224            feature_b,
225            colour_w,
226            colour_b,
227            hidden_dim,
228        }
229    }
230
231    /// Forward pass.
232    ///
233    /// Returns `(rgb, sigma)` where `rgb` is `[r,g,b]` ∈ \[0,1\]³ and
234    /// `sigma` ≥ 0 is the volumetric density.
235    pub fn forward(&self, point: &[f64], direction: &[f64]) -> Result<([f64; 3], f64)> {
236        if point.len() < 3 {
237            return Err(TensorError::invalid_argument_op(
238                "nerf_mlp",
239                "point must have at least 3 elements",
240            ));
241        }
242        if direction.len() < 3 {
243            return Err(TensorError::invalid_argument_op(
244                "nerf_mlp",
245                "direction must have at least 3 elements",
246            ));
247        }
248
249        let pt_enc = self.point_enc.encode(&point[..3]);
250        let dir_enc = self.dir_enc.encode(&direction[..3]);
251
252        // backbone with skip connection
253        let mut h = pt_enc.clone();
254        for i in 0..4 {
255            if i == 2 {
256                // skip: concatenate original point encoding
257                let mut inp = h.clone();
258                inp.extend_from_slice(&pt_enc);
259                h = linear_relu(&self.backbone_w[i], &self.backbone_b[i], &inp);
260            } else {
261                h = linear_relu(&self.backbone_w[i], &self.backbone_b[i], &h);
262            }
263        }
264
265        // density head (softplus for non-negativity)
266        let raw_sigma = linear(&self.density_w, &self.density_b, &h)[0];
267        let sigma = softplus(raw_sigma);
268
269        // feature → colour head conditioned on direction
270        let feat = linear_relu(&self.feature_w, &self.feature_b, &h);
271        let mut feat_dir = feat;
272        feat_dir.extend_from_slice(&dir_enc);
273        let rgb_raw = linear(&self.colour_w, &self.colour_b, &feat_dir);
274
275        let rgb = [
276            sigmoid(rgb_raw[0]),
277            sigmoid(rgb_raw.get(1).copied().unwrap_or(0.0)),
278            sigmoid(rgb_raw.get(2).copied().unwrap_or(0.0)),
279        ];
280
281        Ok((rgb, sigma))
282    }
283}
284
/// Alpha-compositing volume renderer implementing the discrete NeRF rendering equation.
///
/// Per sample: `α_i = 1 - exp(-σ_i δ_i)` and `w_i = T_i α_i`, where
/// `T_i = Π_{j<i} (1 - α_j)`. The pixel colour is `C = Σ w_i c_i + T_final · background`.
#[derive(Debug, Clone)]
pub struct VolumeRenderer {
    /// Background colour blended with `(1 - accumulated_alpha)`.
    pub background: [f64; 3],
}

impl VolumeRenderer {
    /// Renderer with a black background.
    pub fn new() -> Self {
        Self {
            background: [0.0; 3],
        }
    }

    /// Renderer with the given background colour.
    pub fn with_background(background: [f64; 3]) -> Self {
        Self { background }
    }

    /// Composite one ray.
    ///
    /// `samples` holds `(sigma, delta, rgb)` entries ordered near→far.
    ///
    /// Returns the composited colour and one weight per sample. Compositing
    /// stops early once transmittance drops below `1e-10`; the remaining
    /// samples keep a weight of `0.0`.
    pub fn render(&self, samples: &[(f64, f64, [f64; 3])]) -> ([f64; 3], Vec<f64>) {
        let mut weights = vec![0.0_f64; samples.len()];
        let mut colour = [0.0_f64; 3];
        let mut trans = 1.0_f64;

        for (slot, &(sigma, delta, rgb)) in weights.iter_mut().zip(samples) {
            let alpha = 1.0 - (-sigma * delta).exp();
            let w = trans * alpha;
            *slot = w;
            for (acc, c) in colour.iter_mut().zip(rgb) {
                *acc += w * c;
            }
            trans *= 1.0 - alpha;
            if trans < 1e-10 {
                break;
            }
        }

        // Whatever light was not absorbed comes from the background.
        for (acc, bg) in colour.iter_mut().zip(self.background) {
            *acc += trans * bg;
        }

        (colour, weights)
    }

    /// Composite `(sigma, rgb)` samples that are evenly spaced over a unit-length ray.
    ///
    /// Uses `δ = 1/(n-1)` for `n > 1` samples and `δ = 1` for a single sample.
    /// An empty slice yields the background colour.
    pub fn render_uniform(&self, samples: &[(f64, [f64; 3])]) -> [f64; 3] {
        let delta = match samples.len() {
            0 => return self.background,
            1 => 1.0,
            n => 1.0 / (n as f64 - 1.0),
        };
        let with_delta: Vec<(f64, f64, [f64; 3])> = samples
            .iter()
            .map(|&(sigma, rgb)| (sigma, delta, rgb))
            .collect();
        let (colour, _weights) = self.render(&with_delta);
        colour
    }
}

impl Default for VolumeRenderer {
    fn default() -> Self {
        Self::new()
    }
}
358
359/// Stratified ray marcher: samples `n_samples` points near→far along a ray.
360#[derive(Debug, Clone)]
361pub struct RayMarcher {
362    pub near: f64,
363    pub far: f64,
364    pub n_samples: usize,
365    pub stratified: bool,
366}
367
368impl RayMarcher {
369    /// Create a new stratified ray marcher.
370    pub fn new(near: f64, far: f64, n_samples: usize) -> Self {
371        Self {
372            near,
373            far,
374            n_samples,
375            stratified: true,
376        }
377    }
378
379    /// Returns `Vec<(point, t)>` sampled stratified along the ray.
380    pub fn sample_points(
381        &self,
382        origin: [f64; 3],
383        direction: [f64; 3],
384        rng: &mut StdRng,
385    ) -> Vec<([f64; 3], f64)> {
386        let n = self.n_samples.max(1);
387        let step = (self.far - self.near) / n as f64;
388        (0..n)
389            .map(|i| {
390                let t_low = self.near + i as f64 * step;
391                let jitter: f64 = if self.stratified {
392                    rng.random::<f64>() * step
393                } else {
394                    step * 0.5
395                };
396                let t = (t_low + jitter).min(self.far);
397                let pt = [
398                    origin[0] + t * direction[0],
399                    origin[1] + t * direction[1],
400                    origin[2] + t * direction[2],
401                ];
402                (pt, t)
403            })
404            .collect()
405    }
406}
407
408/// Photometric NeRF training loss.
409///
410/// L = (1/N) Σ_i ‖ĉ_i − c_i‖² + λ_depth · depth_consistency_term
411#[derive(Debug, Clone)]
412pub struct NeRFLoss {
413    /// Weight for the optional depth consistency term.
414    pub lambda_depth: f64,
415}
416
417impl NeRFLoss {
418    /// Create a new NeRF loss with no depth regularisation.
419    pub fn new() -> Self {
420        Self { lambda_depth: 0.0 }
421    }
422    /// Create a new NeRF loss with depth regularisation weight `lambda_depth`.
423    pub fn with_depth(lambda_depth: f64) -> Self {
424        Self { lambda_depth }
425    }
426
427    /// Photometric MSE: (1/N) Σ ‖ĉ − c‖².
428    pub fn photometric_mse(&self, rendered: &[[f64; 3]], target: &[[f64; 3]]) -> Result<f64> {
429        if rendered.len() != target.len() {
430            return Err(TensorError::invalid_argument_op(
431                "nerf_loss",
432                "rendered and target must have the same length",
433            ));
434        }
435        if rendered.is_empty() {
436            return Ok(0.0);
437        }
438        let n = rendered.len() as f64;
439        let mse = rendered
440            .iter()
441            .zip(target)
442            .map(|(r, t)| (r[0] - t[0]).powi(2) + (r[1] - t[1]).powi(2) + (r[2] - t[2]).powi(2))
443            .sum::<f64>()
444            / n;
445        Ok(mse)
446    }
447
448    /// Full loss = MSE + λ_depth · depth_consistency.
449    pub fn compute(
450        &self,
451        rendered: &[[f64; 3]],
452        target: &[[f64; 3]],
453        depth_rendered: Option<&[f64]>,
454        depth_target: Option<&[f64]>,
455    ) -> Result<f64> {
456        let photo = self.photometric_mse(rendered, target)?;
457        let depth_term = match (depth_rendered, depth_target) {
458            (Some(dr), Some(dt)) if self.lambda_depth > 0.0 && dr.len() == dt.len() => {
459                let n = dr.len().max(1) as f64;
460                let d = dr.iter().zip(dt).map(|(a, b)| (a - b).powi(2)).sum::<f64>() / n;
461                self.lambda_depth * d
462            }
463            _ => 0.0,
464        };
465        Ok(photo + depth_term)
466    }
467}
468
469impl Default for NeRFLoss {
470    fn default() -> Self {
471        Self::new()
472    }
473}
474
475// ─────────────────────────────────────────────────────────────────────────────
476// §2  3D Gaussian Splatting
477// ─────────────────────────────────────────────────────────────────────────────
478
/// Stateless helpers for quaternions stored scalar-first as `[w, x, y, z]`.
#[derive(Debug, Clone)]
pub struct QuaternionOps;

impl QuaternionOps {
    /// Scale `q` to unit length.
    ///
    /// Quaternions with norm below `1e-300` map to the identity `[1, 0, 0, 0]`.
    pub fn normalize(q: [f64; 4]) -> [f64; 4] {
        let [w, x, y, z] = q;
        let norm = (w * w + x * x + y * y + z * z).sqrt();
        if norm < 1e-300 {
            [1.0, 0.0, 0.0, 0.0]
        } else {
            [w / norm, x / norm, y / norm, z / norm]
        }
    }

    /// Hamilton product `a ⊗ b`.
    pub fn multiply(a: [f64; 4], b: [f64; 4]) -> [f64; 4] {
        let [w1, x1, y1, z1] = a;
        let [w2, x2, y2, z2] = b;
        let w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2;
        let x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2;
        let y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2;
        let z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2;
        [w, x, y, z]
    }

    /// Row-major 3×3 rotation matrix for `q`.
    ///
    /// `q` is normalised first.
    pub fn to_rotation_matrix(q: [f64; 4]) -> [[f64; 3]; 3] {
        let [w, x, y, z] = Self::normalize(q);
        let (xx, yy, zz) = (x * x, y * y, z * z);
        let (xy, xz, yz) = (x * y, x * z, y * z);
        let (wx, wy, wz) = (w * x, w * y, w * z);
        [
            [1.0 - 2.0 * (yy + zz), 2.0 * (xy - wz), 2.0 * (xz + wy)],
            [2.0 * (xy + wz), 1.0 - 2.0 * (xx + zz), 2.0 * (yz - wx)],
            [2.0 * (xz - wy), 2.0 * (yz + wx), 1.0 - 2.0 * (xx + yy)],
        ]
    }
}
528
529/// A single 3D Gaussian primitive for 3DGS.
530#[derive(Debug, Clone)]
531pub struct Gaussian3D {
532    /// Centre position `[x, y, z]`.
533    pub center: [f64; 3],
534    /// Log-scale parameters `[sx, sy, sz]` (exponentiated to give actual scale).
535    pub log_scale: [f64; 3],
536    /// Unit quaternion `[w, x, y, z]` representing orientation.
537    pub rotation: [f64; 4],
538    /// Logit-opacity (sigmoid → opacity in \[0,1\]).
539    pub logit_opacity: f64,
540    /// Spherical-harmonic DC coefficient as RGB colour.
541    pub color: [f64; 3],
542}
543
544impl Gaussian3D {
545    /// Create a new 3D Gaussian at the given center with uniform scale and color.
546    pub fn new(center: [f64; 3], scale: f64, color: [f64; 3]) -> Self {
547        let log_s = scale.abs().max(1e-10).ln();
548        Self {
549            center,
550            log_scale: [log_s; 3],
551            rotation: [1.0, 0.0, 0.0, 0.0],
552            logit_opacity: 0.0,
553            color,
554        }
555    }
556
557    /// Get the scale of this Gaussian.
558    pub fn scale(&self) -> [f64; 3] {
559        [
560            self.log_scale[0].exp(),
561            self.log_scale[1].exp(),
562            self.log_scale[2].exp(),
563        ]
564    }
565    /// Get the opacity (sigmoid of logit).
566    pub fn opacity(&self) -> f64 {
567        sigmoid(self.logit_opacity)
568    }
569
570    /// Compute 3D covariance: Σ = R diag(s²) Rᵀ.
571    pub fn covariance3d(&self) -> [[f64; 3]; 3] {
572        let r = QuaternionOps::to_rotation_matrix(self.rotation);
573        let s = self.scale();
574        // R * diag(s^2) * R^T
575        let mut cov = [[0.0_f64; 3]; 3];
576        for i in 0..3 {
577            for j in 0..3 {
578                cov[i][j] = (0..3).map(|k| r[i][k] * s[k] * s[k] * r[j][k]).sum();
579            }
580        }
581        cov
582    }
583}
584
/// 2D Gaussian splat on the image plane.
#[derive(Debug, Clone)]
pub struct Gaussian2D {
    /// Centre in pixel coordinates `[x, y]`.
    pub center: [f64; 2],
    /// 2×2 covariance in pixel units.
    pub cov2d: [[f64; 2]; 2],
    /// Opacity multiplier applied to the falloff when compositing.
    pub opacity: f64,
    /// RGB colour.
    pub color: [f64; 3],
    /// Camera-space depth (for back-to-front sorting).
    pub depth: f64,
}

impl Gaussian2D {
    /// Unnormalised falloff `exp(-½ dᵀ Σ⁻¹ d)` at pixel `(px, py)`.
    ///
    /// `d` is the offset from `center`. Returns `0.0` when `|det Σ| < 1e-12`,
    /// i.e. when the covariance is treated as singular.
    pub(crate) fn eval(&self, px: f64, py: f64) -> f64 {
        let [[a, b], [c, d]] = self.cov2d;
        let det = a * d - b * c;
        if det.abs() < 1e-12 {
            return 0.0;
        }
        // Entries of Σ⁻¹ via the 2×2 adjugate.
        let (ia, ib, ic, id) = (d / det, -b / det, -c / det, a / det);
        let dx = px - self.center[0];
        let dy = py - self.center[1];
        let mahalanobis = dx * (ia * dx + ib * dy) + dy * (ic * dx + id * dy);
        (-0.5 * mahalanobis).exp()
    }
}
612
613/// Tile-based 3DGS renderer.
614#[derive(Debug, Clone)]
615pub struct GaussianSplatRenderer {
616    pub focal_length: f64,
617    pub width: usize,
618    pub height: usize,
619}
620
621impl GaussianSplatRenderer {
622    /// Create a new Gaussian splat renderer.
623    pub fn new(focal_length: f64, width: usize, height: usize) -> Self {
624        Self {
625            focal_length,
626            width,
627            height,
628        }
629    }
630
631    /// Project a 3D Gaussian to 2D (pinhole model, simple translation).
632    pub fn project(&self, g: &Gaussian3D, camera_pos: [f64; 3]) -> Option<Gaussian2D> {
633        let cx = g.center[0] - camera_pos[0];
634        let cy = g.center[1] - camera_pos[1];
635        let cz = g.center[2] - camera_pos[2];
636        if cz <= 0.001 {
637            return None;
638        }
639
640        let f = self.focal_length;
641        let px = f * cx / cz + self.width as f64 * 0.5;
642        let py = f * cy / cz + self.height as f64 * 0.5;
643
644        // Jacobian of perspective projection (2×3)
645        let j = [
646            [f / cz, 0.0, -f * cx / (cz * cz)],
647            [0.0, f / cz, -f * cy / (cz * cz)],
648        ];
649
650        // 3D covariance
651        let sigma3 = g.covariance3d();
652
653        // Σ2D = J Σ3D Jᵀ
654        // tmp[2×3] = J * Σ3D
655        let mut tmp = [[0.0_f64; 3]; 2];
656        for i in 0..2 {
657            for k in 0..3 {
658                tmp[i][k] = j[i]
659                    .iter()
660                    .enumerate()
661                    .map(|(l, &jil)| jil * sigma3[l][k])
662                    .sum();
663            }
664        }
665        // cov2d[2×2] = tmp * Jᵀ
666        let mut cov2d = [[0.0_f64; 2]; 2];
667        for i in 0..2 {
668            for k in 0..2 {
669                cov2d[i][k] = (0..3).map(|l| tmp[i][l] * j[k][l]).sum::<f64>();
670            }
671        }
672        // Add small regularisation to diagonal
673        cov2d[0][0] += 0.3;
674        cov2d[1][1] += 0.3;
675
676        Some(Gaussian2D {
677            center: [px, py],
678            cov2d,
679            opacity: g.opacity(),
680            color: g.color,
681            depth: cz,
682        })
683    }
684
685    /// Render a single pixel `(px, py)` by alpha-blending sorted 2D Gaussians.
686    ///
687    /// Gaussians must be sorted back-to-front (largest depth first).
688    pub fn render_pixel(&self, gaussians_2d: &[Gaussian2D], px: f64, py: f64) -> [f64; 3] {
689        let mut colour = [0.0_f64; 3];
690        let mut transmittance = 1.0_f64;
691        for g in gaussians_2d {
692            let alpha = g.opacity * g.eval(px, py);
693            let w = transmittance * alpha;
694            for ch in 0..3 {
695                colour[ch] += w * g.color[ch];
696            }
697            transmittance *= 1.0 - alpha;
698            if transmittance < 1e-4 {
699                break;
700            }
701        }
702        colour
703    }
704
705    /// Render a full image given a set of 3D Gaussians and camera position.
706    pub fn render_image(
707        &self,
708        gaussians: &[Gaussian3D],
709        camera_pos: [f64; 3],
710    ) -> Vec<Vec<[f64; 3]>> {
711        let mut g2d: Vec<Gaussian2D> = gaussians
712            .iter()
713            .filter_map(|g| self.project(g, camera_pos))
714            .collect();
715        // Sort back-to-front
716        g2d.sort_by(|a, b| {
717            b.depth
718                .partial_cmp(&a.depth)
719                .unwrap_or(std::cmp::Ordering::Equal)
720        });
721
722        (0..self.height)
723            .map(|y| {
724                (0..self.width)
725                    .map(|x| self.render_pixel(&g2d, x as f64, y as f64))
726                    .collect()
727            })
728            .collect()
729    }
730}
731
732/// Finite-difference gradient optimiser for 3DGS: update centres, densify, prune.
733#[derive(Debug, Clone)]
734pub struct GaussianOptimizer {
735    pub lr: f64,
736    pub eps: f64,
737    pub grad_accum: Vec<f64>,
738    pub step: usize,
739}
740
741impl GaussianOptimizer {
742    /// Create a new Gaussian optimizer with the given learning rate.
743    pub fn new(lr: f64) -> Self {
744        Self {
745            lr,
746            eps: 1e-4,
747            grad_accum: vec![],
748            step: 0,
749        }
750    }
751
752    /// Gradient descent on Gaussian centres. Returns loss before step.
753    #[allow(clippy::ptr_arg)]
754    pub fn step_once(
755        &mut self,
756        renderer: &GaussianSplatRenderer,
757        gaussians: &mut Vec<Gaussian3D>,
758        camera_pos: [f64; 3],
759        target: &[[f64; 3]],
760    ) -> f64 {
761        let w = renderer.width;
762        let h = renderer.height;
763        self.grad_accum.resize(gaussians.len(), 0.0);
764
765        let image0 = renderer.render_image(gaussians, camera_pos);
766        let loss0 = mse_image(&image0, target, w, h);
767
768        for gi in 0..gaussians.len() {
769            let mut grad_norm = 0.0_f64;
770            for ax in 0..3 {
771                let orig = gaussians[gi].center[ax];
772                gaussians[gi].center[ax] = orig + self.eps;
773                let img_p = renderer.render_image(gaussians, camera_pos);
774                let loss_p = mse_image(&img_p, target, w, h);
775                gaussians[gi].center[ax] = orig;
776
777                let grad = (loss_p - loss0) / self.eps;
778                grad_norm += grad.abs();
779                gaussians[gi].center[ax] -= self.lr * grad;
780            }
781            self.grad_accum[gi] += grad_norm;
782        }
783        self.step += 1;
784        loss0
785    }
786
787    /// Split Gaussians with accumulated gradient > `threshold` into two.
788    pub fn densify(&mut self, gaussians: &mut Vec<Gaussian3D>, threshold: f64) {
789        let mut new_gaussians = Vec::new();
790        let mut to_keep = Vec::new();
791        for (gi, g) in gaussians.iter().enumerate() {
792            let acc = self.grad_accum.get(gi).copied().unwrap_or(0.0);
793            if acc > threshold {
794                let s = g.scale()[0] * 0.5;
795                let log_s = s.max(1e-10).ln();
796                // Copy 1: slightly positive offset
797                let mut g1 = g.clone();
798                g1.center[0] += s;
799                g1.log_scale = [log_s; 3];
800                // Copy 2: slightly negative offset
801                let mut g2 = g.clone();
802                g2.center[0] -= s;
803                g2.log_scale = [log_s; 3];
804                new_gaussians.push(g1);
805                new_gaussians.push(g2);
806            } else {
807                to_keep.push(gi);
808            }
809        }
810        let kept: Vec<Gaussian3D> = to_keep.iter().map(|&i| gaussians[i].clone()).collect();
811        *gaussians = kept;
812        gaussians.extend(new_gaussians);
813        self.grad_accum = vec![0.0; gaussians.len()];
814    }
815
816    /// Remove Gaussians with opacity below `opacity_threshold`.
817    pub fn prune(&mut self, gaussians: &mut Vec<Gaussian3D>, opacity_threshold: f64) {
818        gaussians.retain(|g| g.opacity() >= opacity_threshold);
819        self.grad_accum.resize(gaussians.len(), 0.0);
820    }
821}
822
/// Mean squared error between a rendered image and a flattened target.
///
/// `image` is indexed `[row][col]`. `target` is flattened row-major with width
/// `w`. Squared errors are summed over RGB and divided by `w * h` (at least 1).
/// Pixels with no corresponding `target` entry are skipped.
pub(crate) fn mse_image(image: &[Vec<[f64; 3]>], target: &[[f64; 3]], w: usize, h: usize) -> f64 {
    let denom = (w * h).max(1) as f64;
    let total = image
        .iter()
        .enumerate()
        .flat_map(|(r, row)| row.iter().enumerate().map(move |(c, px)| (r * w + c, px)))
        .filter_map(|(idx, px)| target.get(idx).map(|t| (px, t)))
        .fold(0.0, |acc, (p, t)| {
            acc + ((p[0] - t[0]).powi(2) + (p[1] - t[1]).powi(2) + (p[2] - t[2]).powi(2))
        });
    total / denom
}
837
838// ─────────────────────────────────────────────────────────────────────────────
839// §3  Implicit Neural Representations
840// ─────────────────────────────────────────────────────────────────────────────
841
842/// SIREN layer: `sin(ω₀ · (Wx + b))` (Sitzmann et al., 2020).
843#[derive(Debug, Clone)]
844pub struct SirenLayer {
845    pub weight: Vec<Vec<f64>>,
846    pub bias: Vec<f64>,
847    pub in_dim: usize,
848    pub out_dim: usize,
849}
850
851impl SirenLayer {
852    /// `is_first`: use `U(-1/in, 1/in)` init (first layer); else SIREN-scaled Xavier.
853    pub fn new(in_dim: usize, out_dim: usize, omega_0: f64, seed: u64, is_first: bool) -> Self {
854        let mut rng = StdRng::seed_from_u64(seed);
855        let weight = if is_first {
856            let bound = 1.0 / in_dim as f64;
857            (0..out_dim)
858                .map(|_| {
859                    (0..in_dim)
860                        .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * bound)
861                        .collect()
862                })
863                .collect()
864        } else {
865            let bound = (6.0_f64 / in_dim as f64).sqrt() / omega_0;
866            (0..out_dim)
867                .map(|_| {
868                    (0..in_dim)
869                        .map(|_| (rng.random::<f64>() * 2.0 - 1.0) * bound)
870                        .collect()
871                })
872                .collect()
873        };
874        let bias = vec![0.0; out_dim];
875        Self {
876            weight,
877            bias,
878            in_dim,
879            out_dim,
880        }
881    }
882
883    /// Forward pass: `sin(ω₀ · (W x + b))`.
884    pub fn forward(&self, x: &[f64], omega_0: f64) -> Vec<f64> {
885        let pre = linear(&self.weight, &self.bias, x);
886        pre.into_iter().map(|v| (omega_0 * v).sin()).collect()
887    }
888}
889
890/// Stack of SIREN layers (all layers use sin; last layer is linear).
891#[derive(Debug, Clone)]
892pub struct SirenNetwork {
893    pub layers: Vec<SirenLayer>,
894    pub omega_0: f64,
895}
896
897impl SirenNetwork {
898    /// Create a new SIREN network.
899    pub fn new(in_dim: usize, hidden_dims: &[usize], out_dim: usize, omega_0: f64) -> Self {
900        let mut layers = Vec::new();
901        let mut seed = 100_u64;
902        let mut prev = in_dim;
903        for (i, &h) in hidden_dims.iter().enumerate() {
904            layers.push(SirenLayer::new(prev, h, omega_0, seed, i == 0));
905            seed += 1;
906            prev = h;
907        }
908        layers.push(SirenLayer::new(prev, out_dim, omega_0, seed, false));
909        Self { layers, omega_0 }
910    }
911
912    /// Forward pass through the SIREN network.
913    pub fn forward(&self, x: &[f64]) -> Vec<f64> {
914        let n = self.layers.len();
915        let mut h = x.to_vec();
916        for (i, layer) in self.layers.iter().enumerate() {
917            h = if i < n - 1 {
918                layer.forward(&h, self.omega_0)
919            } else {
920                linear(&layer.weight, &layer.bias, &h)
921            };
922        }
923        h
924    }
925}
926
927/// Neural SDF: SIREN predicting signed distance. Negative inside, positive outside.
928#[derive(Debug, Clone)]
929pub struct NeuralSdf {
930    network: SirenNetwork,
931}
932
933impl NeuralSdf {
934    /// Create a new neural SDF.
935    pub fn new(hidden_dim: usize, n_layers: usize) -> Self {
936        let hidden_dims = vec![hidden_dim; n_layers.max(1) - 1];
937        Self {
938            network: SirenNetwork::new(3, &hidden_dims, 1, 30.0),
939        }
940    }
941    /// Forward pass: returns the signed distance at the given 3D point.
942    pub fn forward(&self, point: &[f64]) -> f64 {
943        self.network.forward(point).first().copied().unwrap_or(0.0)
944    }
945}
946
947/// Instant NGP multi-resolution hash encoding (Müller et al., 2022).
948/// Trilinear interpolation over a spatial hash table per level.
949#[derive(Debug, Clone)]
950pub struct InstantNgp {
951    pub feature_dim: usize,
952    tables: Vec<Vec<Vec<f64>>>,
953    resolutions: Vec<usize>,
954}
955
956impl InstantNgp {
957    /// `resolutions`: grid sizes per level; `table_size`: hash entries; `feature_dim`: features/entry.
958    pub fn new(resolutions: &[usize], table_size: usize, feature_dim: usize) -> Self {
959        let mut rng = StdRng::seed_from_u64(999);
960        let tables: Vec<Vec<Vec<f64>>> = resolutions
961            .iter()
962            .map(|_| {
963                (0..table_size)
964                    .map(|_| {
965                        (0..feature_dim)
966                            .map(|_| rng.random::<f64>() * 0.001)
967                            .collect()
968                    })
969                    .collect()
970            })
971            .collect();
972        Self {
973            feature_dim,
974            tables,
975            resolutions: resolutions.to_vec(),
976        }
977    }
978
979    fn hash(ix: usize, iy: usize, iz: usize, table_size: usize) -> usize {
980        (ix ^ iy.wrapping_mul(2654435761) ^ iz.wrapping_mul(805459861)) % table_size
981    }
982
983    /// Encode 3D point in `[0,1]³` → concatenated multi-level features.
984    pub fn encode(&self, point: &[f64]) -> Vec<f64> {
985        let px = point.first().copied().unwrap_or(0.0).clamp(0.0, 1.0);
986        let py = point.get(1).copied().unwrap_or(0.0).clamp(0.0, 1.0);
987        let pz = point.get(2).copied().unwrap_or(0.0).clamp(0.0, 1.0);
988
989        let mut out = Vec::with_capacity(self.feature_dim * self.tables.len());
990        let table_size = self.tables.first().map(|t| t.len()).unwrap_or(1);
991
992        for (level_idx, &res_u) in self.resolutions.iter().enumerate() {
993            let res = res_u as f64;
994            let (fx, fy, fz) = (px * res, py * res, pz * res);
995            let (ix, iy, iz) = (
996                fx.floor() as usize,
997                fy.floor() as usize,
998                fz.floor() as usize,
999            );
1000            let (tx, ty, tz) = (fx.fract(), fy.fract(), fz.fract());
1001            let mut feat = vec![0.0_f64; self.feature_dim];
1002            for dz in 0..2usize {
1003                for dy in 0..2usize {
1004                    for dx in 0..2usize {
1005                        let h = Self::hash(ix + dx, iy + dy, iz + dz, table_size);
1006                        let w = (if dx == 0 { 1.0 - tx } else { tx })
1007                            * (if dy == 0 { 1.0 - ty } else { ty })
1008                            * (if dz == 0 { 1.0 - tz } else { tz });
1009                        for (k, &v) in self.tables[level_idx][h].iter().enumerate() {
1010                            feat[k] += w * v;
1011                        }
1012                    }
1013                }
1014            }
1015            out.extend_from_slice(&feat);
1016        }
1017        out
1018    }
1019
1020    /// Output dimension of the encoding.
1021    pub fn output_dim(&self) -> usize {
1022        self.tables.len() * self.feature_dim
1023    }
1024}
1025
1026/// MLP predicting occupancy probability ∈ \[0,1\] at arbitrary 3D query points.
1027#[derive(Debug, Clone)]
1028pub struct OccupancyNetwork {
1029    layers: Vec<(Vec<Vec<f64>>, Vec<f64>)>,
1030}
1031
1032impl OccupancyNetwork {
1033    /// Create a new occupancy network.
1034    pub fn new(in_dim: usize, hidden_dims: &[usize], seed: u64) -> Self {
1035        let mut s = seed;
1036        let mut layers = Vec::new();
1037        let mut prev = in_dim;
1038        for &h in hidden_dims {
1039            layers.push((rand_weight(h, prev, s), vec![0.0; h]));
1040            s += 1;
1041            prev = h;
1042        }
1043        layers.push((rand_weight(1, prev, s), vec![0.0]));
1044        Self { layers }
1045    }
1046
1047    /// Forward pass: returns occupancy probability in \[0,1\].
1048    pub fn forward(&self, point: &[f64]) -> f64 {
1049        let n = self.layers.len();
1050        let mut h = point.to_vec();
1051        for (i, (w, b)) in self.layers.iter().enumerate() {
1052            if i < n - 1 {
1053                h = linear_relu(w, b, &h);
1054            } else {
1055                h = linear(w, b, &h);
1056            }
1057        }
1058        sigmoid(h.first().copied().unwrap_or(0.0))
1059    }
1060}