Skip to main content

tenflowers_neural/depth_estimation/
mod.rs

//! Depth Estimation & 3D Vision.
//!
//! Production-grade monocular/stereo depth estimation, depth completion,
//! volumetric convolutions, implicit neural fields, and panoptic segmentation.
//!
//! - **DepthEncoder**: Multi-scale feature extraction (4 levels, 1/2..1/16)
//! - **DptDecoder**: Dense Prediction Transformer decoder (reassemble + fusion)
//! - **MonocularDepthEstimator**: Full DPT-style pipeline (scale-invariant + gradient loss)
//! - **StereoMatcher**: Absolute-difference cost volume + soft-argmin disparity regression
//! - **DepthCompletion**: Sparse-to-dense depth with confidence-guided propagation
//! - **DeConv3d**: Full 3D convolution (N, C, D, H, W) with stride/padding/dilation
//! - **DeImplicitNeuralField**: NeRF-inspired positional-encoding MLP + volume rendering
//! - **DePanopticHead**: Panoptic segmentation (semantic + instance + fusion)
//! - **DeDepthMetrics** / **DeDepthReport**: AbsRel, SqRel, RMSE, delta thresholds, SILog

16use scirs2_core::random::{rngs::StdRng, Rng, SeedableRng};
17use scirs2_core::RngExt;
18use tenflowers_core::{Result, TensorError};
19
20// ─────────────────────────────────────────────────────────────────────────────
21// Internal helpers
22// ─────────────────────────────────────────────────────────────────────────────
23
/// Rectified linear unit: `x` when positive, `0.0` otherwise (NaN maps to `0.0`).
#[inline]
fn de_relu(x: f64) -> f64 {
    f64::max(x, 0.0)
}
28
/// Logistic sigmoid `1 / (1 + e^-x)`, evaluated without overflow.
///
/// Negative inputs use the algebraically equivalent `e^x / (1 + e^x)` so that
/// `exp` is never called on a large positive argument.
#[inline]
fn de_sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        let ex = x.exp();
        return ex / (1.0 + ex);
    }
    1.0 / (1.0 + (-x).exp())
}
38
/// Softplus `ln(1 + e^x)`, a smooth, positive approximation of ReLU.
///
/// For `x > 20` the input is returned unchanged, because the correction term
/// `ln(1 + e^-x)` is below about 2.1e-9. Otherwise `ln_1p(e^x)` is evaluated.
/// This keeps full relative precision when `e^x` is tiny. The naive
/// `(1.0 + x.exp()).ln()` rounds to exactly `0.0` once `x` drops below
/// roughly -37, which breaks the positivity the depth head relies on.
/// Only when `e^x` itself underflows (x below about -745) is `0.0` returned.
#[inline]
fn de_softplus(x: f64) -> f64 {
    if x > 20.0 {
        x
    } else {
        x.exp().ln_1p()
    }
}
47
48fn de_sample_normal(rng: &mut impl Rng) -> f64 {
49    let u1: f64 = rng.random::<f64>().max(1e-12);
50    let u2: f64 = rng.random::<f64>();
51    (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
52}
53
54/// Xavier-normal initialization for a weight matrix [rows x cols].
55fn de_xavier_init(rows: usize, cols: usize, seed: u64) -> Vec<Vec<f64>> {
56    let mut rng = StdRng::seed_from_u64(seed);
57    let scale = (2.0 / (rows + cols) as f64).sqrt();
58    (0..rows)
59        .map(|_| {
60            (0..cols)
61                .map(|_| de_sample_normal(&mut rng) * scale)
62                .collect()
63        })
64        .collect()
65}
66
67/// Flat Xavier-normal initialization.
68fn de_xavier_init_flat(size: usize, fan_in: usize, fan_out: usize, seed: u64) -> Vec<f64> {
69    let mut rng = StdRng::seed_from_u64(seed);
70    let scale = (2.0 / (fan_in + fan_out) as f64).sqrt();
71    (0..size)
72        .map(|_| de_sample_normal(&mut rng) * scale)
73        .collect()
74}
75
/// Dot product of two slices; trailing elements of the longer one are ignored.
fn de_dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b.iter()).map(|(&x, &y)| x * y);
    products.sum()
}
79
80fn de_matvec(mat: &[Vec<f64>], v: &[f64]) -> Vec<f64> {
81    mat.iter().map(|row| de_dot(row, v)).collect()
82}
83
84fn de_linear(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
85    let out = de_matvec(w, x);
86    out.iter().zip(b).map(|(o, bi)| o + bi).collect()
87}
88
89fn de_linear_relu(w: &[Vec<f64>], b: &[f64], x: &[f64]) -> Vec<f64> {
90    de_linear(w, b, x).into_iter().map(de_relu).collect()
91}
92
/// Normalise a slice to zero mean and unit variance using its own statistics.
///
/// The mean and (biased, population) variance are computed over *all*
/// elements of `x` on every call. There are no running statistics and no
/// learnable scale or shift, so on a single feature map this behaves like
/// instance normalisation rather than inference-mode batch normalisation.
/// `1e-5` is added to the variance so a constant input maps to zeros instead
/// of dividing by zero.
///
/// Returns an empty vector for empty input.
fn de_batch_norm(x: &[f64]) -> Vec<f64> {
    if x.is_empty() {
        return Vec::new();
    }
    let n = x.len() as f64;
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n;
    let inv_std = 1.0 / (var + 1e-5).sqrt();
    x.iter().map(|&v| (v - mean) * inv_std).collect()
}
104
/// 2-D cross-correlation of a single-channel `h x w` image with a `k x k` kernel.
///
/// `input` is row-major with `h * w` values and `kernel` is row-major with
/// `k * k` values. The image is implicitly zero-padded by `pad` pixels on every
/// side. The (unflipped) kernel is evaluated every `stride` pixels, which is
/// the usual CNN convention.
///
/// Returns `(output, out_h, out_w)` with `out_h = (h + 2 * pad - k) / stride + 1`,
/// and likewise for `out_w`. If the kernel does not fit inside the padded image
/// along either axis, there is no valid output position and `(vec![], 0, 0)`
/// is returned. Without this check the unsigned size computation would
/// underflow, panicking in debug builds and wrapping to a nonsensical size in
/// release builds.
///
/// # Panics
///
/// Panics if `stride == 0`, or if `input` / `kernel` hold fewer than
/// `h * w` / `k * k` values.
fn de_conv2d_single(
    input: &[f64],
    h: usize,
    w: usize,
    kernel: &[f64],
    k: usize,
    stride: usize,
    pad: usize,
) -> (Vec<f64>, usize, usize) {
    assert!(stride > 0, "de_conv2d_single: stride must be > 0");
    let span_h = (h + 2 * pad).checked_sub(k);
    let span_w = (w + 2 * pad).checked_sub(k);
    let (Some(span_h), Some(span_w)) = (span_h, span_w) else {
        return (Vec::new(), 0, 0);
    };
    let oh = span_h / stride + 1;
    let ow = span_w / stride + 1;
    let mut out = vec![0.0; oh * ow];
    for oi in 0..oh {
        for oj in 0..ow {
            let mut val = 0.0;
            for ki in 0..k {
                // `ii`/`jj` are coordinates in the padded image; taps landing in
                // the zero border contribute nothing and are skipped.
                let ii = oi * stride + ki;
                if ii < pad || ii - pad >= h {
                    continue;
                }
                let row_start = (ii - pad) * w;
                for kj in 0..k {
                    let jj = oj * stride + kj;
                    if jj >= pad && jj - pad < w {
                        val += input[row_start + (jj - pad)] * kernel[ki * k + kj];
                    }
                }
            }
            out[oi * ow + oj] = val;
        }
    }
    (out, oh, ow)
}
135
/// Non-overlapping 2-D max-pool with a square `pool x pool` window and stride `pool`.
///
/// The output is `(h / pool) x (w / pool)` (floor division); trailing rows and
/// columns that cannot fill a whole window are dropped. Window positions past
/// the end of `input` are skipped. A window with no readable element yields
/// `f64::NEG_INFINITY`. Returns `(output, out_h, out_w)`.
fn de_maxpool2d(input: &[f64], h: usize, w: usize, pool: usize) -> (Vec<f64>, usize, usize) {
    let (out_h, out_w) = (h / pool, w / pool);
    let mut pooled = Vec::with_capacity(out_h * out_w);
    for row in 0..out_h {
        for col in 0..out_w {
            let window_max = (0..pool)
                .flat_map(|dr| {
                    (0..pool).map(move |dc| (row * pool + dr) * w + (col * pool + dc))
                })
                .filter_map(|flat| input.get(flat).copied())
                .fold(f64::NEG_INFINITY, |best, v| if v > best { v } else { best });
            pooled.push(window_max);
        }
    }
    (pooled, out_h, out_w)
}
155
/// Upsample a single-channel `h x w` map by 2x along both axes with bilinear
/// interpolation.
///
/// Output pixel `(i, j)` samples the source at `(i / 2, j / 2)` in source-pixel
/// units. Even outputs therefore copy a source pixel, and odd outputs average
/// it with its lower/right neighbour. Neighbours past the last row or column
/// are clamped to the edge. Returns `(output, 2 * h, 2 * w)`.
fn de_upsample_2x(input: &[f64], h: usize, w: usize) -> (Vec<f64>, usize, usize) {
    let (out_h, out_w) = (2 * h, 2 * w);
    let mut out = Vec::with_capacity(out_h * out_w);
    for i in 0..out_h {
        // i / 2 is exact, so the fractional offset is either 0 or 0.5.
        let top = i / 2;
        let bottom = (top + 1).min(h - 1);
        let fy = if i % 2 == 0 { 0.0 } else { 0.5 };
        for j in 0..out_w {
            let left = j / 2;
            let right = (left + 1).min(w - 1);
            let fx = if j % 2 == 0 { 0.0 } else { 0.5 };
            let value = input[top * w + left] * (1.0 - fy) * (1.0 - fx)
                + input[top * w + right] * (1.0 - fy) * fx
                + input[bottom * w + left] * fy * (1.0 - fx)
                + input[bottom * w + right] * fy * fx;
            out.push(value);
        }
    }
    (out, out_h, out_w)
}
180
/// Element-wise sum of two slices, truncated to the shorter length.
fn de_vec_add(a: &[f64], b: &[f64]) -> Vec<f64> {
    let len = a.len().min(b.len());
    (0..len).map(|i| a[i] + b[i]).collect()
}
185
186// ─────────────────────────────────────────────────────────────────────────────
187// §1  DepthEncoder — Multi-scale feature extraction (4 levels)
188// ─────────────────────────────────────────────────────────────────────────────
189
/// Parameters of one [`DepthEncoder`] level:
/// Conv → BN → ReLU → Conv → BN → ReLU → MaxPool.
#[derive(Debug, Clone)]
pub struct DeEncoderBlock {
    /// First convolution kernel, row-major `k x k` (flattened).
    pub conv1: Vec<f64>,
    /// Second convolution kernel, row-major `k x k` (flattened).
    pub conv2: Vec<f64>,
    /// Kernel side length.
    pub k: usize,
    /// Convolution stride. `DepthEncoder::new` sets this to 1; spatial
    /// downsampling is done by the max-pool instead.
    pub stride: usize,
    /// Zero padding applied on every side of each convolution input
    /// (`DepthEncoder::new` sets `k / 2`).
    pub pad: usize,
    /// Max-pool window size and stride (2 halves each spatial dimension).
    pub pool_size: usize,
}
206
/// Construction parameters for [`DepthEncoder`].
#[derive(Debug, Clone)]
pub struct DepthEncoderConfig {
    /// Height of the single-channel input image, in pixels.
    pub input_h: usize,
    /// Width of the single-channel input image, in pixels.
    pub input_w: usize,
    /// Side length of the square convolution kernels.
    pub kernel_size: usize,
    /// Seed for the deterministic weight initialisation.
    pub seed: u64,
}

impl Default for DepthEncoderConfig {
    /// A 64x64 input, 3x3 kernels, seed 42.
    fn default() -> Self {
        let (input_h, input_w) = (64, 64);
        Self {
            seed: 42,
            kernel_size: 3,
            input_w,
            input_h,
        }
    }
}
230
/// Multi-scale feature extractor producing 4 levels at 1/2, 1/4, 1/8, 1/16.
///
/// Each level applies two convolutions (Xavier init). Each convolution is
/// followed by a per-map normalisation and ReLU, and the level ends with a
/// 2x max-pool. The fractions are exact when the kernel size is odd (so
/// same-padding preserves size) and the input dimensions are divisible by 16;
/// otherwise each pool floors.
#[derive(Debug, Clone)]
pub struct DepthEncoder {
    /// One block per pyramid level, applied in order (finest first).
    pub blocks: Vec<DeEncoderBlock>,
    /// Expected input height; `forward` requires exactly `input_h * input_w` values.
    pub input_h: usize,
    /// Expected input width; `forward` requires exactly `input_h * input_w` values.
    pub input_w: usize,
}
241
/// A multi-scale feature pyramid produced by [`DepthEncoder`].
///
/// The three vectors are parallel: entry `i` of each one describes level `i`.
#[derive(Debug, Clone)]
pub struct DeMultiScaleFeatures {
    /// Row-major single-channel feature map per level (flattened), from
    /// finest (1/2) to coarsest (1/16).
    pub features: Vec<Vec<f64>>,
    /// Spatial height of each level's map.
    pub heights: Vec<usize>,
    /// Spatial width of each level's map.
    pub widths: Vec<usize>,
}
252
253impl DepthEncoder {
254    /// Create a new depth encoder with Xavier-initialized conv kernels.
255    pub fn new(config: &DepthEncoderConfig) -> Result<Self> {
256        if config.input_h < 16 || config.input_w < 16 {
257            return Err(TensorError::compute_error_simple(
258                "DepthEncoder: input_h and input_w must be >= 16".to_string(),
259            ));
260        }
261        let k = config.kernel_size;
262        let kk = k * k;
263        let mut blocks = Vec::with_capacity(4);
264        let mut seed = config.seed;
265        for _level in 0..4 {
266            let conv1 = de_xavier_init_flat(kk, kk, kk, seed);
267            seed = seed.wrapping_add(1);
268            let conv2 = de_xavier_init_flat(kk, kk, kk, seed);
269            seed = seed.wrapping_add(1);
270            blocks.push(DeEncoderBlock {
271                conv1,
272                conv2,
273                k,
274                stride: 1,
275                pad: k / 2,
276                pool_size: 2,
277            });
278        }
279        Ok(Self {
280            blocks,
281            input_h: config.input_h,
282            input_w: config.input_w,
283        })
284    }
285
286    /// Extract multi-scale features from a single-channel input.
287    ///
288    /// `input` must have length `input_h * input_w`.
289    pub fn forward(&self, input: &[f64]) -> Result<DeMultiScaleFeatures> {
290        let expected = self.input_h * self.input_w;
291        if input.len() != expected {
292            return Err(TensorError::compute_error_simple(
293                "DepthEncoder: input length mismatch".to_string(),
294            ));
295        }
296        let mut features = Vec::with_capacity(4);
297        let mut heights = Vec::with_capacity(4);
298        let mut widths = Vec::with_capacity(4);
299
300        let mut current = input.to_vec();
301        let mut h = self.input_h;
302        let mut w = self.input_w;
303
304        for block in &self.blocks {
305            // Conv1 → BN → ReLU
306            let (c1, h1, w1) = de_conv2d_single(
307                &current,
308                h,
309                w,
310                &block.conv1,
311                block.k,
312                block.stride,
313                block.pad,
314            );
315            let bn1 = de_batch_norm(&c1);
316            let r1: Vec<f64> = bn1.into_iter().map(de_relu).collect();
317
318            // Conv2 → BN → ReLU
319            let (c2, h2, w2) =
320                de_conv2d_single(&r1, h1, w1, &block.conv2, block.k, block.stride, block.pad);
321            let bn2 = de_batch_norm(&c2);
322            let r2: Vec<f64> = bn2.into_iter().map(de_relu).collect();
323
324            // MaxPool
325            let (pooled, ph, pw) = de_maxpool2d(&r2, h2, w2, block.pool_size);
326
327            features.push(pooled.clone());
328            heights.push(ph);
329            widths.push(pw);
330
331            current = pooled;
332            h = ph;
333            w = pw;
334        }
335
336        Ok(DeMultiScaleFeatures {
337            features,
338            heights,
339            widths,
340        })
341    }
342}
343
344// ─────────────────────────────────────────────────────────────────────────────
345// §2  DptDecoder — Dense Prediction Transformer decoder
346// ─────────────────────────────────────────────────────────────────────────────
347
/// Construction parameters for [`DptDecoder`].
#[derive(Debug, Clone)]
pub struct DptDecoderConfig {
    /// Number of pyramid levels the decoder is built for (typically 4, matching
    /// the encoder's output).
    pub n_levels: usize,
    /// Input/output dimension handed to each reassembly layer.
    pub reassemble_dim: usize,
    /// Channel count used to scale the head's weight initialisation.
    pub head_channels: usize,
    /// Seed for the deterministic weight initialisation.
    pub seed: u64,
}

impl Default for DptDecoderConfig {
    /// Four levels, reassembly dimension 32, 16 head channels, seed 100.
    fn default() -> Self {
        let n_levels = 4;
        Self {
            seed: 100,
            head_channels: 16,
            reassemble_dim: 32,
            n_levels,
        }
    }
}
371
/// A reassembly layer applied to one pyramid level before fusion.
///
/// When built by `DeReassembleLayer::new`, `weight` holds `in_dim * out_dim`
/// Xavier-initialised values. The current single-channel `forward` uses only
/// `weight[0]`, as a per-pixel scale.
#[derive(Debug, Clone)]
pub struct DeReassembleLayer {
    /// Projection weights (flattened); only the first entry is read by `forward`.
    pub weight: Vec<f64>,
    /// Scalar bias added to every pixel.
    pub bias: f64,
    /// Nominal input dimension (not consulted by `forward`).
    pub in_dim: usize,
    /// Nominal output dimension (not consulted by `forward`).
    pub out_dim: usize,
}
380
381impl DeReassembleLayer {
382    fn new(in_dim: usize, out_dim: usize, seed: u64) -> Self {
383        let weight = de_xavier_init_flat(in_dim * out_dim, in_dim, out_dim, seed);
384        Self {
385            weight,
386            bias: 0.0,
387            in_dim,
388            out_dim,
389        }
390    }
391
392    /// Project each pixel's feature through a linear layer.
393    fn forward(&self, input: &[f64], h: usize, w: usize) -> Result<Vec<f64>> {
394        // input: [h * w] (single channel) → project to out_dim channels per pixel
395        // For simplicity, treat as 1-channel and scale/bias per pixel
396        let n_pixels = h * w;
397        if input.len() != n_pixels {
398            return Err(TensorError::compute_error_simple(
399                "DeReassembleLayer: input size mismatch".to_string(),
400            ));
401        }
402        // Single-channel projection: simple scale + bias
403        let scale = if self.weight.is_empty() {
404            1.0
405        } else {
406            self.weight[0]
407        };
408        let out: Vec<f64> = input.iter().map(|&v| v * scale + self.bias).collect();
409        Ok(out)
410    }
411}
412
/// RefineNet-style residual convolution block used during fusion.
#[derive(Debug, Clone)]
pub struct DeRefineBlock {
    /// Row-major `k x k` convolution kernel (flattened).
    pub conv_kernel: Vec<f64>,
    /// Kernel side length.
    pub k: usize,
    /// Zero padding on each side (`DeRefineBlock::new` sets `k / 2`).
    pub pad: usize,
}
420
421impl DeRefineBlock {
422    fn new(k: usize, seed: u64) -> Self {
423        let kk = k * k;
424        let conv_kernel = de_xavier_init_flat(kk, kk, kk, seed);
425        Self {
426            conv_kernel,
427            k,
428            pad: k / 2,
429        }
430    }
431
432    /// Refine: conv → BN → ReLU, then residual add with input.
433    fn forward(&self, input: &[f64], h: usize, w: usize) -> Result<(Vec<f64>, usize, usize)> {
434        let (conv_out, oh, ow) =
435            de_conv2d_single(input, h, w, &self.conv_kernel, self.k, 1, self.pad);
436        let bn = de_batch_norm(&conv_out);
437        let activated: Vec<f64> = bn.into_iter().map(de_relu).collect();
438        // Residual connection
439        let result = if activated.len() == input.len() {
440            de_vec_add(&activated, input)
441        } else {
442            activated
443        };
444        Ok((result, oh, ow))
445    }
446}
447
/// Dense Prediction Transformer (DPT)-style decoder.
///
/// Reassembles multi-scale features from [`DepthEncoder`], fuses them
/// iteratively from coarsest to finest with RefineNet blocks, and
/// produces a final depth map through a 2-layer convolution head.
#[derive(Debug, Clone)]
pub struct DptDecoder {
    /// One reassembly layer per pyramid level (index = level, finest first).
    pub reassemble_layers: Vec<DeReassembleLayer>,
    /// One refinement block per pyramid level (index = level, finest first).
    pub refine_blocks: Vec<DeRefineBlock>,
    /// First head kernel, row-major `head_k x head_k`; followed by ReLU.
    pub head_conv1: Vec<f64>,
    /// Second head kernel, row-major `head_k x head_k`; followed by softplus.
    pub head_conv2: Vec<f64>,
    /// Head kernel side length (3 when built by `DptDecoder::new`).
    pub head_k: usize,
    /// Head convolution padding (`head_k / 2` when built by `DptDecoder::new`).
    pub head_pad: usize,
}
462
463impl DptDecoder {
464    /// Create a new DPT decoder.
465    pub fn new(config: &DptDecoderConfig) -> Result<Self> {
466        if config.n_levels == 0 {
467            return Err(TensorError::compute_error_simple(
468                "DptDecoder: n_levels must be > 0".to_string(),
469            ));
470        }
471        let mut seed = config.seed;
472        let mut reassemble_layers = Vec::with_capacity(config.n_levels);
473        let mut refine_blocks = Vec::with_capacity(config.n_levels);
474
475        for _i in 0..config.n_levels {
476            reassemble_layers.push(DeReassembleLayer::new(
477                config.reassemble_dim,
478                config.reassemble_dim,
479                seed,
480            ));
481            seed = seed.wrapping_add(1);
482            refine_blocks.push(DeRefineBlock::new(3, seed));
483            seed = seed.wrapping_add(1);
484        }
485
486        let hk = 3;
487        let hkk = hk * hk;
488        let head_conv1 = de_xavier_init_flat(hkk, config.head_channels, config.head_channels, seed);
489        seed = seed.wrapping_add(1);
490        let head_conv2 = de_xavier_init_flat(hkk, config.head_channels, 1, seed);
491
492        Ok(Self {
493            reassemble_layers,
494            refine_blocks,
495            head_conv1,
496            head_conv2,
497            head_k: hk,
498            head_pad: hk / 2,
499        })
500    }
501
502    /// Decode multi-scale features to a depth map.
503    ///
504    /// Returns `(depth_map, height, width)`.
505    pub fn forward(&self, ms_features: &DeMultiScaleFeatures) -> Result<(Vec<f64>, usize, usize)> {
506        let n = ms_features.features.len();
507        if n == 0 {
508            return Err(TensorError::compute_error_simple(
509                "DptDecoder: empty features".to_string(),
510            ));
511        }
512
513        // Start from coarsest level
514        let last = n - 1;
515        let reassembled = self.reassemble_layers[last].forward(
516            &ms_features.features[last],
517            ms_features.heights[last],
518            ms_features.widths[last],
519        )?;
520        let (mut fused, mut fh, mut fw) = self.refine_blocks[last].forward(
521            &reassembled,
522            ms_features.heights[last],
523            ms_features.widths[last],
524        )?;
525
526        // Iteratively fuse from coarser to finer
527        for level in (0..last).rev() {
528            // Upsample current fused to match finer level
529            let (up, uh, uw) = de_upsample_2x(&fused, fh, fw);
530            let reassembled_level = self.reassemble_layers[level].forward(
531                &ms_features.features[level],
532                ms_features.heights[level],
533                ms_features.widths[level],
534            )?;
535
536            // Add upsampled fused + reassembled finer features
537            let combined = if up.len() == reassembled_level.len() {
538                de_vec_add(&up, &reassembled_level)
539            } else {
540                // Size mismatch: use the smaller
541                let min_len = up.len().min(reassembled_level.len());
542                up[..min_len]
543                    .iter()
544                    .zip(&reassembled_level[..min_len])
545                    .map(|(a, b)| a + b)
546                    .collect()
547            };
548
549            let target_h = ms_features.heights[level];
550            let target_w = ms_features.widths[level];
551            let actual_len = combined.len();
552            let expected_len = target_h * target_w;
553
554            let adjusted = if actual_len >= expected_len {
555                combined[..expected_len].to_vec()
556            } else {
557                let mut v = combined;
558                v.resize(expected_len, 0.0);
559                v
560            };
561
562            let (refined, rh, rw) =
563                self.refine_blocks[level].forward(&adjusted, target_h, target_w)?;
564            fused = refined;
565            fh = rh;
566            fw = rw;
567        }
568
569        // Head: Conv1 → ReLU → Conv2 → softplus (ensure positive depth)
570        let (h1, h1h, h1w) = de_conv2d_single(
571            &fused,
572            fh,
573            fw,
574            &self.head_conv1,
575            self.head_k,
576            1,
577            self.head_pad,
578        );
579        let h1_act: Vec<f64> = h1.into_iter().map(de_relu).collect();
580        let (h2, h2h, h2w) = de_conv2d_single(
581            &h1_act,
582            h1h,
583            h1w,
584            &self.head_conv2,
585            self.head_k,
586            1,
587            self.head_pad,
588        );
589        let depth: Vec<f64> = h2.into_iter().map(de_softplus).collect();
590
591        Ok((depth, h2h, h2w))
592    }
593}
594
595// ─────────────────────────────────────────────────────────────────────────────
596// §3  MonocularDepthEstimator — Full DPT-style pipeline
597// ─────────────────────────────────────────────────────────────────────────────
598
/// How `MonocularDepthEstimator::predict_depth` post-processes the decoder output.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DepthMode {
    /// Return the decoder output unchanged, interpreted as metric depth in
    /// absolute units.
    Metric,
    /// Min-max normalise the decoder output to `[0, 1]` (relative/ordinal depth).
    Relative,
}
607
/// Full monocular depth estimation pipeline.
///
/// Combines [`DepthEncoder`] and [`DptDecoder`] into an end-to-end system
/// with scale-invariant loss (Eigen 2014) and gradient-matching loss.
#[derive(Debug, Clone)]
pub struct MonocularDepthEstimator {
    /// Multi-scale feature extractor.
    pub encoder: DepthEncoder,
    /// Decoder turning encoder features into a depth map.
    pub decoder: DptDecoder,
    /// Output post-processing mode.
    pub mode: DepthMode,
    /// Weight on the squared-mean term in `Var(d) + lambda * Mean(d)^2`
    /// (set to 0.5 by `new`). The associated loss functions do not read this
    /// field; pass it explicitly to `scale_invariant_loss_lambda`.
    pub si_lambda: f64,
}
620
621impl MonocularDepthEstimator {
622    /// Create a new monocular depth estimator.
623    pub fn new(
624        encoder_config: &DepthEncoderConfig,
625        decoder_config: &DptDecoderConfig,
626        mode: DepthMode,
627    ) -> Result<Self> {
628        let encoder = DepthEncoder::new(encoder_config)?;
629        let decoder = DptDecoder::new(decoder_config)?;
630        Ok(Self {
631            encoder,
632            decoder,
633            mode,
634            si_lambda: 0.5,
635        })
636    }
637
638    /// Predict depth from a single-channel image.
639    ///
640    /// Returns `(depth_map, height, width)`.
641    pub fn predict_depth(&self, image: &[f64]) -> Result<(Vec<f64>, usize, usize)> {
642        let features = self.encoder.forward(image)?;
643        let (mut depth, h, w) = self.decoder.forward(&features)?;
644
645        if self.mode == DepthMode::Relative {
646            // Normalize to [0, 1]
647            let min_val = depth.iter().cloned().fold(f64::INFINITY, f64::min);
648            let max_val = depth.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
649            let range = (max_val - min_val).max(1e-8);
650            for v in &mut depth {
651                *v = (*v - min_val) / range;
652            }
653        }
654
655        Ok((depth, h, w))
656    }
657
658    /// Scale-invariant loss (Eigen et al. 2014).
659    ///
660    /// `d_i = log(pred_i) - log(gt_i)`, loss = Var(d) + lambda * Mean(d)^2
661    pub fn scale_invariant_loss(pred: &[f64], gt: &[f64]) -> Result<f64> {
662        if pred.len() != gt.len() || pred.is_empty() {
663            return Err(TensorError::compute_error_simple(
664                "scale_invariant_loss: size mismatch or empty".to_string(),
665            ));
666        }
667        let n = pred.len() as f64;
668        let d: Vec<f64> = pred
669            .iter()
670            .zip(gt)
671            .map(|(&p, &g)| (p.max(1e-8)).ln() - (g.max(1e-8)).ln())
672            .collect();
673        let mean_d: f64 = d.iter().sum::<f64>() / n;
674        let var_d: f64 = d.iter().map(|&di| (di - mean_d).powi(2)).sum::<f64>() / n;
675        Ok(var_d + 0.5 * mean_d * mean_d)
676    }
677
678    /// Scale-invariant loss with configurable lambda.
679    pub fn scale_invariant_loss_lambda(pred: &[f64], gt: &[f64], lambda: f64) -> Result<f64> {
680        if pred.len() != gt.len() || pred.is_empty() {
681            return Err(TensorError::compute_error_simple(
682                "scale_invariant_loss_lambda: size mismatch or empty".to_string(),
683            ));
684        }
685        let n = pred.len() as f64;
686        let d: Vec<f64> = pred
687            .iter()
688            .zip(gt)
689            .map(|(&p, &g)| (p.max(1e-8)).ln() - (g.max(1e-8)).ln())
690            .collect();
691        let mean_d: f64 = d.iter().sum::<f64>() / n;
692        let var_d: f64 = d.iter().map(|&di| (di - mean_d).powi(2)).sum::<f64>() / n;
693        Ok(var_d + lambda * mean_d * mean_d)
694    }
695
696    /// Gradient matching loss (edge-aware smoothness).
697    ///
698    /// Penalizes differences between spatial gradients of predicted and GT depth.
699    pub fn gradient_matching_loss(pred: &[f64], gt: &[f64], h: usize, w: usize) -> Result<f64> {
700        if pred.len() != h * w || gt.len() != h * w {
701            return Err(TensorError::compute_error_simple(
702                "gradient_matching_loss: size mismatch".to_string(),
703            ));
704        }
705        if h < 2 || w < 2 {
706            return Err(TensorError::compute_error_simple(
707                "gradient_matching_loss: need h >= 2 and w >= 2".to_string(),
708            ));
709        }
710
711        let mut loss: f64 = 0.0;
712        let mut count: f64 = 0.0;
713
714        // Horizontal gradients
715        for i in 0..h {
716            for j in 0..(w - 1) {
717                let pred_dx = pred[i * w + j + 1] - pred[i * w + j];
718                let gt_dx = gt[i * w + j + 1] - gt[i * w + j];
719                loss += (pred_dx - gt_dx).powi(2);
720                count += 1.0;
721            }
722        }
723
724        // Vertical gradients
725        for i in 0..(h - 1) {
726            for j in 0..w {
727                let pred_dy = pred[(i + 1) * w + j] - pred[i * w + j];
728                let gt_dy = gt[(i + 1) * w + j] - gt[i * w + j];
729                loss += (pred_dy - gt_dy).powi(2);
730                count += 1.0;
731            }
732        }
733
734        Ok(loss / count.max(1.0))
735    }
736}
737
738// ─────────────────────────────────────────────────────────────────────────────
739// §4  StereoMatcher — Stereo vision depth estimation
740// ─────────────────────────────────────────────────────────────────────────────
741
/// Construction parameters for [`StereoMatcher`].
#[derive(Debug, Clone)]
pub struct StereoMatcherConfig {
    /// Number of candidate disparities searched (`0..max_disparity`).
    pub max_disparity: usize,
    /// Camera focal length, in pixels.
    pub focal_length: f64,
    /// Distance between the two cameras; depth is reported in the same units.
    pub baseline: f64,
    /// Radius `r` of the `(2r + 1) x (2r + 1)` cost-aggregation window.
    pub agg_radius: usize,
    /// Random seed (the matching pipeline itself does not read it).
    pub seed: u64,
}

impl Default for StereoMatcherConfig {
    /// 32 disparities, 500 px focal length, 0.12 baseline, 5x5 window, seed 200.
    fn default() -> Self {
        let agg_radius = 2;
        Self {
            seed: 200,
            agg_radius,
            baseline: 0.12,
            focal_length: 500.0,
            max_disparity: 32,
        }
    }
}
768
/// Stereo vision depth estimation.
///
/// The matcher proceeds in three steps:
/// 1. Build a cost volume that scores every candidate disparity by the
///    negative absolute difference between left and shifted right pixels.
/// 2. Apply window-based cost aggregation.
/// 3. Use soft-argmin regression for sub-pixel disparity.
#[derive(Debug, Clone)]
pub struct StereoMatcher {
    /// Matcher configuration; validated by `StereoMatcher::new` but public,
    /// so it can be changed afterwards.
    pub config: StereoMatcherConfig,
}
778
779impl StereoMatcher {
780    pub fn new(config: StereoMatcherConfig) -> Result<Self> {
781        if config.max_disparity == 0 {
782            return Err(TensorError::compute_error_simple(
783                "StereoMatcher: max_disparity must be > 0".to_string(),
784            ));
785        }
786        if config.focal_length <= 0.0 || config.baseline <= 0.0 {
787            return Err(TensorError::compute_error_simple(
788                "StereoMatcher: focal_length and baseline must be positive".to_string(),
789            ));
790        }
791        Ok(Self { config })
792    }
793
794    /// Build the cost volume: correlation between left and right at each disparity.
795    ///
796    /// Returns `[max_disparity][h * w]` cost volume.
797    pub fn build_cost_volume(
798        &self,
799        left: &[f64],
800        right: &[f64],
801        h: usize,
802        w: usize,
803    ) -> Result<Vec<Vec<f64>>> {
804        let expected = h * w;
805        if left.len() != expected || right.len() != expected {
806            return Err(TensorError::compute_error_simple(
807                "StereoMatcher: left/right size mismatch".to_string(),
808            ));
809        }
810
811        let max_d = self.config.max_disparity;
812        let mut cost_vol = Vec::with_capacity(max_d);
813
814        for d in 0..max_d {
815            let mut cost = vec![0.0; h * w];
816            for i in 0..h {
817                for j in 0..w {
818                    if j >= d {
819                        // Correlation (negative L1 distance for matching cost)
820                        cost[i * w + j] = -(left[i * w + j] - right[i * w + (j - d)]).abs();
821                    } else {
822                        cost[i * w + j] = f64::NEG_INFINITY;
823                    }
824                }
825            }
826            cost_vol.push(cost);
827        }
828
829        Ok(cost_vol)
830    }
831
832    /// Aggregate costs using a local window.
833    pub fn aggregate_costs(&self, cost_vol: &[Vec<f64>], h: usize, w: usize) -> Vec<Vec<f64>> {
834        let r = self.config.agg_radius;
835        let max_d = cost_vol.len();
836        let mut agg = Vec::with_capacity(max_d);
837
838        for d_idx in 0..max_d {
839            let mut agg_d = vec![0.0; h * w];
840            for i in 0..h {
841                for j in 0..w {
842                    let mut sum = 0.0;
843                    let mut cnt = 0.0;
844                    let i_start = i.saturating_sub(r);
845                    let j_start = j.saturating_sub(r);
846                    let i_end = (i + r + 1).min(h);
847                    let j_end = (j + r + 1).min(w);
848                    for ii in i_start..i_end {
849                        for jj in j_start..j_end {
850                            let val = cost_vol[d_idx][ii * w + jj];
851                            if val > f64::NEG_INFINITY {
852                                sum += val;
853                                cnt += 1.0;
854                            }
855                        }
856                    }
857                    agg_d[i * w + j] = if cnt > 0.0 {
858                        sum / cnt
859                    } else {
860                        f64::NEG_INFINITY
861                    };
862                }
863            }
864            agg.push(agg_d);
865        }
866
867        agg
868    }
869
870    /// Soft-argmin disparity regression (differentiable).
871    ///
872    /// Applies softmax over the disparity dimension and computes the
873    /// expected disparity value.
874    pub fn soft_argmin_disparity(&self, cost_vol: &[Vec<f64>], h: usize, w: usize) -> Vec<f64> {
875        let max_d = cost_vol.len();
876        let n = h * w;
877        let mut disparity = vec![0.0; n];
878
879        for pixel in 0..n {
880            // Collect costs for this pixel across disparities
881            let mut costs: Vec<f64> = (0..max_d).map(|d| cost_vol[d][pixel]).collect();
882
883            // Softmax (negative costs → use costs directly as logits)
884            let max_c = costs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
885            let mut sum_exp = 0.0;
886            for c in &mut costs {
887                if *c > f64::NEG_INFINITY {
888                    *c = (*c - max_c).exp();
889                } else {
890                    *c = 0.0;
891                }
892                sum_exp += *c;
893            }
894            if sum_exp > 0.0 {
895                for c in &mut costs {
896                    *c /= sum_exp;
897                }
898            }
899
900            // Expected disparity
901            let d_val: f64 = costs
902                .iter()
903                .enumerate()
904                .map(|(d, &w_d)| d as f64 * w_d)
905                .sum();
906            disparity[pixel] = d_val;
907        }
908
909        disparity
910    }
911
912    /// Full stereo matching pipeline.
913    ///
914    /// Returns `(disparity_map, depth_map)` each of size `h * w`.
915    pub fn match_stereo(
916        &self,
917        left: &[f64],
918        right: &[f64],
919        h: usize,
920        w: usize,
921    ) -> Result<(Vec<f64>, Vec<f64>)> {
922        let cost_vol = self.build_cost_volume(left, right, h, w)?;
923        let agg_costs = self.aggregate_costs(&cost_vol, h, w);
924        let disparity = self.soft_argmin_disparity(&agg_costs, h, w);
925
926        // Convert disparity to depth: depth = focal_length * baseline / disparity
927        let depth: Vec<f64> = disparity
928            .iter()
929            .map(|&d| {
930                if d > 1e-6 {
931                    self.config.focal_length * self.config.baseline / d
932                } else {
933                    0.0
934                }
935            })
936            .collect();
937
938        Ok((disparity, depth))
939    }
940
941    /// Convert a disparity value to depth.
942    pub fn disparity_to_depth(&self, disparity: f64) -> f64 {
943        if disparity > 1e-6 {
944            self.config.focal_length * self.config.baseline / disparity
945        } else {
946            0.0
947        }
948    }
949}
950
951// ─────────────────────────────────────────────────────────────────────────────
952// §5  DepthCompletion — Sparse-to-dense depth completion
953// ─────────────────────────────────────────────────────────────────────────────
954
/// Fusion strategy for depth completion.
///
/// In `DepthCompletion::complete`, only `Early` changes behaviour: it enables
/// RGB-guided confidence modulation. `Late` ignores the RGB guide.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeCompletionFusion {
    /// Concatenate sparse depth + RGB → single encoder.
    Early,
    /// Separate encoders for depth and RGB, merge late.
    Late,
}
963
964/// Configuration for depth completion.
965#[derive(Debug, Clone)]
966pub struct DepthCompletionConfig {
967    /// Spatial height.
968    pub h: usize,
969    /// Spatial width.
970    pub w: usize,
971    /// Number of propagation iterations.
972    pub n_iterations: usize,
973    /// Fusion strategy.
974    pub fusion: DeCompletionFusion,
975    /// Confidence decay factor per iteration.
976    pub confidence_decay: f64,
977    /// Seed.
978    pub seed: u64,
979}
980
981impl Default for DepthCompletionConfig {
982    fn default() -> Self {
983        Self {
984            h: 32,
985            w: 32,
986            n_iterations: 5,
987            fusion: DeCompletionFusion::Early,
988            confidence_decay: 0.9,
989            seed: 300,
990        }
991    }
992}
993
994/// Sparse-to-dense depth completion.
995///
996/// Uses confidence-guided propagation: sparse depth observations with
997/// known confidence are iteratively diffused to neighboring pixels using
998/// a mask-propagating convolution scheme.
999#[derive(Debug, Clone)]
1000pub struct DepthCompletion {
1001    pub config: DepthCompletionConfig,
1002    /// Propagation kernel (3x3).
1003    pub prop_kernel: [f64; 9],
1004}
1005
1006impl DepthCompletion {
1007    pub fn new(config: DepthCompletionConfig) -> Result<Self> {
1008        if config.h == 0 || config.w == 0 {
1009            return Err(TensorError::compute_error_simple(
1010                "DepthCompletion: h and w must be > 0".to_string(),
1011            ));
1012        }
1013        // Normalized averaging kernel
1014        let prop_kernel = [1.0 / 9.0; 9];
1015        Ok(Self {
1016            config,
1017            prop_kernel,
1018        })
1019    }
1020
1021    /// Sparse convolution with mask propagation.
1022    ///
1023    /// Only pixels with confidence > 0 contribute to the convolution.
1024    fn sparse_conv(
1025        &self,
1026        depth: &[f64],
1027        confidence: &[f64],
1028        h: usize,
1029        w: usize,
1030    ) -> (Vec<f64>, Vec<f64>) {
1031        let n = h * w;
1032        let mut new_depth = vec![0.0; n];
1033        let mut new_conf = vec![0.0; n];
1034
1035        for i in 0..h {
1036            for j in 0..w {
1037                let mut weighted_sum = 0.0;
1038                let mut conf_sum = 0.0;
1039                let mut k_idx = 0;
1040
1041                for ki in 0..3_usize {
1042                    for kj in 0..3_usize {
1043                        let ni = (i as isize) + (ki as isize) - 1;
1044                        let nj = (j as isize) + (kj as isize) - 1;
1045                        if ni >= 0 && ni < h as isize && nj >= 0 && nj < w as isize {
1046                            let idx = ni as usize * w + nj as usize;
1047                            if confidence[idx] > 0.0 {
1048                                weighted_sum +=
1049                                    depth[idx] * confidence[idx] * self.prop_kernel[k_idx];
1050                                conf_sum += confidence[idx] * self.prop_kernel[k_idx];
1051                            }
1052                        }
1053                        k_idx += 1;
1054                    }
1055                }
1056
1057                let pixel_idx = i * w + j;
1058                if confidence[pixel_idx] > 0.5 {
1059                    // Keep original observed values with high confidence
1060                    new_depth[pixel_idx] = depth[pixel_idx];
1061                    new_conf[pixel_idx] = confidence[pixel_idx];
1062                } else if conf_sum > 0.0 {
1063                    new_depth[pixel_idx] = weighted_sum / conf_sum;
1064                    new_conf[pixel_idx] = conf_sum * self.config.confidence_decay;
1065                }
1066            }
1067        }
1068
1069        (new_depth, new_conf)
1070    }
1071
1072    /// Complete sparse depth to dense.
1073    ///
1074    /// `sparse_depth`: depth values (0 for unknown).
1075    /// `confidence_map`: confidence per pixel (0..1).
1076    /// `rgb_guide`: optional RGB guidance (3 channels, length 3 * h * w), used for
1077    ///              edge-aware weighting in early fusion mode.
1078    pub fn complete(
1079        &self,
1080        sparse_depth: &[f64],
1081        confidence_map: &[f64],
1082        rgb_guide: Option<&[f64]>,
1083    ) -> Result<Vec<f64>> {
1084        let h = self.config.h;
1085        let w = self.config.w;
1086        let n = h * w;
1087
1088        if sparse_depth.len() != n || confidence_map.len() != n {
1089            return Err(TensorError::compute_error_simple(
1090                "DepthCompletion: sparse_depth/confidence size mismatch".to_string(),
1091            ));
1092        }
1093
1094        let mut depth = sparse_depth.to_vec();
1095        let mut conf = confidence_map.to_vec();
1096
1097        // Early fusion: incorporate RGB as edge-aware guidance
1098        if self.config.fusion == DeCompletionFusion::Early {
1099            if let Some(rgb) = rgb_guide {
1100                if rgb.len() >= n {
1101                    // Use luminance channel to modulate confidence at edges
1102                    for i in 0..n {
1103                        let lum = rgb[i]; // use first channel as proxy
1104                                          // Reduce confidence at strong edges (high gradient)
1105                        if i >= w && i + w < n && i + 1 < n && i > 0 {
1106                            let grad: f64 = (rgb[i + 1] - rgb[i - 1]).abs();
1107                            if grad > 0.1 {
1108                                conf[i] *= 0.8;
1109                            }
1110                        }
1111                        let _ = lum;
1112                    }
1113                }
1114            }
1115        }
1116
1117        // Iterative confidence-guided propagation
1118        for _iter in 0..self.config.n_iterations {
1119            let (new_depth, new_conf) = self.sparse_conv(&depth, &conf, h, w);
1120            depth = new_depth;
1121            conf = new_conf;
1122        }
1123
1124        Ok(depth)
1125    }
1126}
1127
1128// ─────────────────────────────────────────────────────────────────────────────
1129// §6  DeConv3d — 3D Convolution layer
1130// ─────────────────────────────────────────────────────────────────────────────
1131
/// Hyper-parameters of a `DeConv3d` layer.
///
/// Every per-axis array is ordered `[depth, height, width]`.
#[derive(Debug, Clone)]
pub struct DeConv3dConfig {
    /// Number of channels in the input volume.
    pub in_channels: usize,
    /// Number of output channels (filters).
    pub out_channels: usize,
    /// Kernel extent along each axis.
    pub kernel_size: [usize; 3],
    /// Step between consecutive kernel placements along each axis.
    pub stride: [usize; 3],
    /// Implicit zero padding added to both sides of each axis.
    pub padding: [usize; 3],
    /// Spacing between kernel taps along each axis.
    pub dilation: [usize; 3],
    /// Seed for weight initialisation.
    pub seed: u64,
}

impl Default for DeConv3dConfig {
    /// Single-channel 3x3x3 convolution (stride 1, padding 1, dilation 1), seed `400`.
    fn default() -> Self {
        let unit = [1; 3];
        Self {
            in_channels: 1,
            out_channels: 1,
            kernel_size: [3; 3],
            stride: unit,
            padding: unit,
            dilation: unit,
            seed: 400,
        }
    }
}
1164
1165/// Full 3D convolution layer.
1166///
1167/// Input shape: `[N, C_in, D, H, W]` (flattened).
1168/// Output shape: `[N, C_out, D', H', W']`.
1169#[derive(Debug, Clone)]
1170pub struct DeConv3d {
1171    pub config: DeConv3dConfig,
1172    /// Weights: `[out_channels][in_channels * kD * kH * kW]`.
1173    pub weights: Vec<Vec<f64>>,
1174    /// Bias: `[out_channels]`.
1175    pub bias: Vec<f64>,
1176}
1177
1178impl DeConv3d {
1179    pub fn new(config: DeConv3dConfig) -> Result<Self> {
1180        let [kd, kh, kw] = config.kernel_size;
1181        let fan_in = config.in_channels * kd * kh * kw;
1182        let fan_out = config.out_channels * kd * kh * kw;
1183        let weights = de_xavier_init(config.out_channels, fan_in, config.seed);
1184        let bias = vec![0.0; config.out_channels];
1185        Ok(Self {
1186            config,
1187            weights,
1188            bias,
1189        })
1190    }
1191
1192    /// Compute output spatial dimensions.
1193    pub fn output_shape(&self, d: usize, h: usize, w: usize) -> (usize, usize, usize) {
1194        let [kd, kh, kw] = self.config.kernel_size;
1195        let [sd, sh, sw] = self.config.stride;
1196        let [pd, ph, pw] = self.config.padding;
1197        let [dd, dh, dw] = self.config.dilation;
1198
1199        let od = (d + 2 * pd - dd * (kd - 1) - 1) / sd + 1;
1200        let oh = (h + 2 * ph - dh * (kh - 1) - 1) / sh + 1;
1201        let ow = (w + 2 * pw - dw * (kw - 1) - 1) / sw + 1;
1202        (od, oh, ow)
1203    }
1204
1205    /// Forward pass for a single sample (batch_size = 1).
1206    ///
1207    /// `input`: flattened `[C_in, D, H, W]`.
1208    /// `shape`: `[C_in, D, H, W]`.
1209    ///
1210    /// Returns `(output, [C_out, D', H', W'])`.
1211    pub fn forward(
1212        &self,
1213        input: &[f64],
1214        in_c: usize,
1215        d: usize,
1216        h: usize,
1217        w: usize,
1218    ) -> Result<(Vec<f64>, usize, usize, usize, usize)> {
1219        if in_c != self.config.in_channels {
1220            return Err(TensorError::compute_error_simple(
1221                "DeConv3d: in_channels mismatch".to_string(),
1222            ));
1223        }
1224        let expected = in_c * d * h * w;
1225        if input.len() != expected {
1226            return Err(TensorError::compute_error_simple(
1227                "DeConv3d: input length mismatch".to_string(),
1228            ));
1229        }
1230
1231        let [kd, kh, kw] = self.config.kernel_size;
1232        let [sd, sh, sw] = self.config.stride;
1233        let [pd, ph, pw] = self.config.padding;
1234        let [dd, dh, dw] = self.config.dilation;
1235        let (od, oh, ow) = self.output_shape(d, h, w);
1236        let out_c = self.config.out_channels;
1237
1238        let mut output = vec![0.0; out_c * od * oh * ow];
1239
1240        for oc in 0..out_c {
1241            for oi in 0..od {
1242                for oj in 0..oh {
1243                    for ok in 0..ow {
1244                        let mut val = self.bias[oc];
1245                        for ic in 0..in_c {
1246                            for ki in 0..kd {
1247                                for kj in 0..kh {
1248                                    for kk in 0..kw {
1249                                        let id = oi * sd + ki * dd;
1250                                        let ih = oj * sh + kj * dh;
1251                                        let iw = ok * sw + kk * dw;
1252
1253                                        if id >= pd
1254                                            && ih >= ph
1255                                            && iw >= pw
1256                                            && id - pd < d
1257                                            && ih - ph < h
1258                                            && iw - pw < w
1259                                        {
1260                                            let in_idx = ic * (d * h * w)
1261                                                + (id - pd) * (h * w)
1262                                                + (ih - ph) * w
1263                                                + (iw - pw);
1264                                            let w_idx =
1265                                                ic * (kd * kh * kw) + ki * (kh * kw) + kj * kw + kk;
1266                                            val += input[in_idx] * self.weights[oc][w_idx];
1267                                        }
1268                                    }
1269                                }
1270                            }
1271                        }
1272                        let out_idx = oc * (od * oh * ow) + oi * (oh * ow) + oj * ow + ok;
1273                        output[out_idx] = val;
1274                    }
1275                }
1276            }
1277        }
1278
1279        Ok((output, out_c, od, oh, ow))
1280    }
1281}
1282
1283// ─────────────────────────────────────────────────────────────────────────────
1284// §7  DeImplicitNeuralField — NeRF-inspired implicit representation
1285// ─────────────────────────────────────────────────────────────────────────────
1286
/// Hyper-parameters for `DeImplicitNeuralField`.
#[derive(Debug, Clone)]
pub struct DeImplicitFieldConfig {
    /// Number of positional-encoding frequency bands `L` (frequencies
    /// `2^0 * pi .. 2^(L-1) * pi`).
    pub n_freqs: usize,
    /// Width of every hidden MLP layer.
    pub hidden_dim: usize,
    /// Number of hidden layers.
    ///
    /// NOTE(review): `DeImplicitNeuralField::new` always builds at least one
    /// layer, even when this is `0`.
    pub n_layers: usize,
    /// Number of stratified samples per ray.
    ///
    /// NOTE(review): `render_ray` takes its own `n_samples` argument; confirm
    /// whether this field is read anywhere.
    pub n_samples: usize,
    /// Seed for weight initialisation and ray-sample jitter.
    pub seed: u64,
}

impl Default for DeImplicitFieldConfig {
    /// Six frequency bands, four layers of width 64, 32 samples per ray, seed `500`.
    fn default() -> Self {
        DeImplicitFieldConfig {
            seed: 500,
            n_samples: 32,
            n_layers: 4,
            hidden_dim: 64,
            n_freqs: 6,
        }
    }
}
1313
1314/// NeRF-inspired implicit neural representation.
1315///
1316/// Encodes 3D coordinates with positional encoding, processes through an MLP
1317/// to predict density and color, and renders rays via volume rendering.
1318#[derive(Debug, Clone)]
1319pub struct DeImplicitNeuralField {
1320    pub config: DeImplicitFieldConfig,
1321    /// MLP weights: `[n_layers][(out_dim, in_dim)]` each layer.
1322    pub layers: Vec<(Vec<Vec<f64>>, Vec<f64>)>,
1323    /// Output head: density (1) + color (3).
1324    pub output_head: (Vec<Vec<f64>>, Vec<f64>),
1325}
1326
1327impl DeImplicitNeuralField {
1328    pub fn new(config: DeImplicitFieldConfig) -> Result<Self> {
1329        if config.n_freqs == 0 || config.hidden_dim == 0 {
1330            return Err(TensorError::compute_error_simple(
1331                "DeImplicitNeuralField: n_freqs and hidden_dim must be > 0".to_string(),
1332            ));
1333        }
1334
1335        // Input dim: 3 coords * (2 * n_freqs) = 6 * n_freqs
1336        let input_dim = 3 * 2 * config.n_freqs;
1337        let mut layers = Vec::with_capacity(config.n_layers);
1338        let mut seed = config.seed;
1339
1340        // First layer
1341        let w0 = de_xavier_init(config.hidden_dim, input_dim, seed);
1342        let b0 = vec![0.0; config.hidden_dim];
1343        layers.push((w0, b0));
1344        seed = seed.wrapping_add(1);
1345
1346        // Hidden layers
1347        for _ in 1..config.n_layers {
1348            let w = de_xavier_init(config.hidden_dim, config.hidden_dim, seed);
1349            let b = vec![0.0; config.hidden_dim];
1350            layers.push((w, b));
1351            seed = seed.wrapping_add(1);
1352        }
1353
1354        // Output head: hidden_dim → 4 (density + RGB)
1355        let w_out = de_xavier_init(4, config.hidden_dim, seed);
1356        let b_out = vec![0.0; 4];
1357
1358        Ok(Self {
1359            config,
1360            layers,
1361            output_head: (w_out, b_out),
1362        })
1363    }
1364
1365    /// Positional encoding: gamma(p) = [sin(2^k * pi * p), cos(2^k * pi * p)] for k=0..L-1.
1366    pub fn positional_encode(&self, coords: &[f64; 3]) -> Vec<f64> {
1367        let mut encoded = Vec::with_capacity(3 * 2 * self.config.n_freqs);
1368        for &c in coords {
1369            for k in 0..self.config.n_freqs {
1370                let freq = (2.0_f64).powi(k as i32) * std::f64::consts::PI;
1371                encoded.push((freq * c).sin());
1372                encoded.push((freq * c).cos());
1373            }
1374        }
1375        encoded
1376    }
1377
1378    /// Query the field at a 3D point.
1379    ///
1380    /// Returns `(density, [r, g, b])`.
1381    pub fn query(&self, coords: &[f64; 3]) -> Result<(f64, [f64; 3])> {
1382        let mut x = self.positional_encode(coords);
1383
1384        for (w, b) in &self.layers {
1385            if w.is_empty() || w[0].len() != x.len() {
1386                // Skip layers with dimension mismatch during partial init
1387                continue;
1388            }
1389            x = de_linear_relu(w, b, &x);
1390        }
1391
1392        let (ref w_out, ref b_out) = self.output_head;
1393        if w_out.is_empty() || w_out[0].len() != x.len() {
1394            return Err(TensorError::compute_error_simple(
1395                "DeImplicitNeuralField: output head dimension mismatch".to_string(),
1396            ));
1397        }
1398        let out = de_linear(w_out, b_out, &x);
1399
1400        if out.len() < 4 {
1401            return Err(TensorError::compute_error_simple(
1402                "DeImplicitNeuralField: output too short".to_string(),
1403            ));
1404        }
1405
1406        let density = de_softplus(out[0]);
1407        let color = [de_sigmoid(out[1]), de_sigmoid(out[2]), de_sigmoid(out[3])];
1408
1409        Ok((density, color))
1410    }
1411
1412    /// Render a single ray via volume rendering.
1413    ///
1414    /// `C = sum_i T_i * alpha_i * c_i` where
1415    /// `T_i = prod_{j<i} (1 - alpha_j)`, `alpha = 1 - exp(-sigma * delta)`.
1416    ///
1417    /// Returns `(color: [f64; 3], expected_depth: f64)`.
1418    pub fn render_ray(
1419        &self,
1420        origin: &[f64; 3],
1421        direction: &[f64; 3],
1422        near: f64,
1423        far: f64,
1424        n_samples: usize,
1425    ) -> Result<([f64; 3], f64)> {
1426        if near >= far {
1427            return Err(TensorError::compute_error_simple(
1428                "render_ray: near must be < far".to_string(),
1429            ));
1430        }
1431        if n_samples == 0 {
1432            return Err(TensorError::compute_error_simple(
1433                "render_ray: n_samples must be > 0".to_string(),
1434            ));
1435        }
1436
1437        let n = n_samples;
1438        let step = (far - near) / n as f64;
1439
1440        let mut color = [0.0; 3];
1441        let mut depth = 0.0;
1442        let mut transmittance = 1.0;
1443
1444        // Stratified sampling along the ray
1445        let mut rng = StdRng::seed_from_u64(self.config.seed);
1446
1447        for i in 0..n {
1448            let t_base = near + i as f64 * step;
1449            let jitter = rng.random::<f64>() * step;
1450            let t = t_base + jitter;
1451            let delta = step;
1452
1453            let point = [
1454                origin[0] + t * direction[0],
1455                origin[1] + t * direction[1],
1456                origin[2] + t * direction[2],
1457            ];
1458
1459            let (sigma, c) = self.query(&point)?;
1460            let alpha = 1.0 - (-sigma * delta).exp();
1461            let weight = transmittance * alpha;
1462
1463            color[0] += weight * c[0];
1464            color[1] += weight * c[1];
1465            color[2] += weight * c[2];
1466            depth += weight * t;
1467
1468            transmittance *= 1.0 - alpha;
1469            if transmittance < 1e-6 {
1470                break;
1471            }
1472        }
1473
1474        Ok((color, depth))
1475    }
1476}
1477
1478// ─────────────────────────────────────────────────────────────────────────────
1479// §8  DePanopticHead — Panoptic segmentation
1480// ─────────────────────────────────────────────────────────────────────────────
1481
/// Output of the panoptic head: three row-major label maps of `height * width` entries.
#[derive(Debug, Clone)]
pub struct DePanopticResult {
    /// Semantic class index per pixel.
    pub semantic_map: Vec<usize>,
    /// Instance index per pixel; `0` means no instance (background/stuff).
    pub instance_map: Vec<usize>,
    /// Combined label per pixel: `class_id * 1000 + instance_id`.
    pub panoptic_map: Vec<usize>,
    /// Number of rows.
    pub height: usize,
    /// Number of columns.
    pub width: usize,
}
1496
/// Hyper-parameters for `DePanopticHead`.
#[derive(Debug, Clone)]
pub struct DePanopticConfig {
    /// Number of semantic classes; `DePanopticHead::new` rejects `0`.
    pub n_classes: usize,
    /// Upper bound on the number of instance IDs assigned per image.
    pub max_instances: usize,
    /// Length of each per-pixel feature vector.
    pub feature_dim: usize,
    /// IoU threshold for overlap resolution.
    ///
    /// NOTE(review): `DePanopticHead`'s methods do not appear to read this
    /// field; confirm whether it is still used.
    pub overlap_threshold: f64,
    /// Seed for weight initialisation.
    pub seed: u64,
}

impl Default for DePanopticConfig {
    /// Ten classes, up to 50 instances, 32-dim features, threshold `0.5`, seed `600`.
    fn default() -> Self {
        DePanopticConfig {
            seed: 600,
            overlap_threshold: 0.5,
            feature_dim: 32,
            max_instances: 50,
            n_classes: 10,
        }
    }
}
1523
1524/// Panoptic segmentation head.
1525///
1526/// Combines a semantic branch (per-pixel class logits) with an instance branch
1527/// (center prediction + offset regression) and fuses them into a panoptic map.
1528#[derive(Debug, Clone)]
1529pub struct DePanopticHead {
1530    pub config: DePanopticConfig,
1531    /// Semantic branch: feature_dim → n_classes.
1532    pub semantic_weights: Vec<Vec<f64>>,
1533    pub semantic_bias: Vec<f64>,
1534    /// Instance center weights: feature_dim → 1 (center-ness score).
1535    pub center_weights: Vec<Vec<f64>>,
1536    pub center_bias: Vec<f64>,
1537    /// Offset weights: feature_dim → 2 (x, y offsets to center).
1538    pub offset_weights: Vec<Vec<f64>>,
1539    pub offset_bias: Vec<f64>,
1540}
1541
1542impl DePanopticHead {
1543    pub fn new(config: DePanopticConfig) -> Result<Self> {
1544        if config.n_classes == 0 {
1545            return Err(TensorError::compute_error_simple(
1546                "DePanopticHead: n_classes must be > 0".to_string(),
1547            ));
1548        }
1549        let mut seed = config.seed;
1550        let sem_w = de_xavier_init(config.n_classes, config.feature_dim, seed);
1551        let sem_b = vec![0.0; config.n_classes];
1552        seed = seed.wrapping_add(1);
1553
1554        let ctr_w = de_xavier_init(1, config.feature_dim, seed);
1555        let ctr_b = vec![0.0; 1];
1556        seed = seed.wrapping_add(1);
1557
1558        let off_w = de_xavier_init(2, config.feature_dim, seed);
1559        let off_b = vec![0.0; 2];
1560
1561        Ok(Self {
1562            config,
1563            semantic_weights: sem_w,
1564            semantic_bias: sem_b,
1565            center_weights: ctr_w,
1566            center_bias: ctr_b,
1567            offset_weights: off_w,
1568            offset_bias: off_b,
1569        })
1570    }
1571
1572    /// Compute semantic logits for a feature vector.
1573    pub fn semantic_logits(&self, features: &[f64]) -> Vec<f64> {
1574        de_linear(&self.semantic_weights, &self.semantic_bias, features)
1575    }
1576
1577    /// Compute center-ness score for a feature vector.
1578    pub fn center_score(&self, features: &[f64]) -> f64 {
1579        let out = de_linear(&self.center_weights, &self.center_bias, features);
1580        if out.is_empty() {
1581            0.0
1582        } else {
1583            de_sigmoid(out[0])
1584        }
1585    }
1586
1587    /// Compute offset prediction for a feature vector.
1588    pub fn offset_pred(&self, features: &[f64]) -> [f64; 2] {
1589        let out = de_linear(&self.offset_weights, &self.offset_bias, features);
1590        if out.len() >= 2 {
1591            [out[0], out[1]]
1592        } else {
1593            [0.0, 0.0]
1594        }
1595    }
1596
1597    /// Full panoptic segmentation from per-pixel feature maps.
1598    ///
1599    /// `feature_map`: `[h * w][feature_dim]` feature vectors.
1600    pub fn forward(
1601        &self,
1602        feature_map: &[Vec<f64>],
1603        h: usize,
1604        w: usize,
1605    ) -> Result<DePanopticResult> {
1606        let n = h * w;
1607        if feature_map.len() != n {
1608            return Err(TensorError::compute_error_simple(
1609                "DePanopticHead: feature_map length mismatch".to_string(),
1610            ));
1611        }
1612
1613        let mut semantic_map = vec![0usize; n];
1614        let mut center_scores = vec![0.0f64; n];
1615        let mut offsets = vec![[0.0f64; 2]; n];
1616
1617        for (idx, feat) in feature_map.iter().enumerate() {
1618            // Semantic: argmax
1619            let logits = self.semantic_logits(feat);
1620            let mut best_class = 0;
1621            let mut best_score = f64::NEG_INFINITY;
1622            for (c, &s) in logits.iter().enumerate() {
1623                if s > best_score {
1624                    best_score = s;
1625                    best_class = c;
1626                }
1627            }
1628            semantic_map[idx] = best_class;
1629
1630            // Instance
1631            center_scores[idx] = self.center_score(feat);
1632            offsets[idx] = self.offset_pred(feat);
1633        }
1634
1635        // Instance grouping via center voting
1636        let instance_map = self.group_instances(&center_scores, &offsets, &semantic_map, h, w);
1637
1638        // Panoptic fusion
1639        let panoptic_map: Vec<usize> = semantic_map
1640            .iter()
1641            .zip(&instance_map)
1642            .map(|(&sem, &inst)| sem * 1000 + inst)
1643            .collect();
1644
1645        Ok(DePanopticResult {
1646            semantic_map,
1647            instance_map,
1648            panoptic_map,
1649            height: h,
1650            width: w,
1651        })
1652    }
1653
1654    /// Group pixels into instances using center voting and NMS-style merging.
1655    fn group_instances(
1656        &self,
1657        center_scores: &[f64],
1658        offsets: &[[f64; 2]],
1659        semantic_map: &[usize],
1660        h: usize,
1661        w: usize,
1662    ) -> Vec<usize> {
1663        let n = h * w;
1664        let mut instance_map = vec![0usize; n];
1665        let center_threshold = 0.3;
1666
1667        // Find center candidates (pixels with high center-ness)
1668        let mut centers: Vec<(usize, f64)> = Vec::new();
1669        for (idx, &score) in center_scores.iter().enumerate() {
1670            if score > center_threshold {
1671                centers.push((idx, score));
1672            }
1673        }
1674
1675        // Sort by score descending
1676        centers.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1677
1678        // Assign instance IDs via greedy assignment
1679        let mut next_instance = 1usize;
1680        let max_inst = self.config.max_instances;
1681
1682        for &(center_idx, _score) in &centers {
1683            if next_instance > max_inst {
1684                break;
1685            }
1686            if instance_map[center_idx] != 0 {
1687                continue;
1688            }
1689
1690            let center_y = center_idx / w;
1691            let center_x = center_idx % w;
1692            let center_class = semantic_map[center_idx];
1693
1694            // Assign nearby pixels that vote for this center
1695            for idx in 0..n {
1696                if instance_map[idx] != 0 {
1697                    continue;
1698                }
1699                if semantic_map[idx] != center_class {
1700                    continue;
1701                }
1702                let py = idx / w;
1703                let px = idx % w;
1704                let voted_y = py as f64 + offsets[idx][1];
1705                let voted_x = px as f64 + offsets[idx][0];
1706                let dist = ((voted_y - center_y as f64).powi(2)
1707                    + (voted_x - center_x as f64).powi(2))
1708                .sqrt();
1709                if dist < 3.0 {
1710                    instance_map[idx] = next_instance;
1711                }
1712            }
1713
1714            next_instance += 1;
1715        }
1716
1717        instance_map
1718    }
1719}
1720
1721// ─────────────────────────────────────────────────────────────────────────────
1722// §9  DeDepthMetrics — Comprehensive depth evaluation
1723// ─────────────────────────────────────────────────────────────────────────────
1724
/// Aggregated depth-estimation evaluation metrics.
#[derive(Debug, Clone)]
pub struct DeDepthReport {
    /// Absolute relative error, `mean(|pred - gt| / gt)`.
    pub abs_rel: f64,
    /// Squared relative error, `mean((pred - gt)^2 / gt)`.
    pub sq_rel: f64,
    /// Root mean squared error of `pred - gt`.
    pub rmse: f64,
    /// Root mean squared error of `ln(pred) - ln(gt)`.
    pub rmse_log: f64,
    /// Scale-invariant log error (SILog).
    pub si_log: f64,
    /// Delta accuracy at threshold `1.25` (documented as a percentage).
    pub delta_1: f64,
    /// Delta accuracy at threshold `1.25^2` (documented as a percentage).
    pub delta_2: f64,
    /// Delta accuracy at threshold `1.25^3` (documented as a percentage).
    pub delta_3: f64,
    /// Number of pixels that contributed to the metrics.
    pub n_valid: usize,
}
1747
/// Namespace for stateless depth-evaluation metric functions
/// (`abs_rel`, `sq_rel`, `rmse`, `rmse_log`, `si_log`, ...).
pub struct DeDepthMetrics;
1750
1751impl DeDepthMetrics {
1752    /// Compute absolute relative error: mean(|pred - gt| / gt).
1753    pub fn abs_rel(pred: &[f64], gt: &[f64]) -> Result<f64> {
1754        let (sum, count) = Self::valid_pairs(pred, gt)?;
1755        if count == 0 {
1756            return Err(TensorError::compute_error_simple(
1757                "abs_rel: no valid pixels".to_string(),
1758            ));
1759        }
1760        let val: f64 = sum.iter().map(|&(p, g)| (p - g).abs() / g).sum::<f64>() / count as f64;
1761        Ok(val)
1762    }
1763
1764    /// Compute squared relative error: mean((pred - gt)^2 / gt).
1765    pub fn sq_rel(pred: &[f64], gt: &[f64]) -> Result<f64> {
1766        let (pairs, count) = Self::valid_pairs(pred, gt)?;
1767        if count == 0 {
1768            return Err(TensorError::compute_error_simple(
1769                "sq_rel: no valid pixels".to_string(),
1770            ));
1771        }
1772        let val: f64 = pairs.iter().map(|&(p, g)| (p - g).powi(2) / g).sum::<f64>() / count as f64;
1773        Ok(val)
1774    }
1775
1776    /// Compute RMSE.
1777    pub fn rmse(pred: &[f64], gt: &[f64]) -> Result<f64> {
1778        let (pairs, count) = Self::valid_pairs(pred, gt)?;
1779        if count == 0 {
1780            return Err(TensorError::compute_error_simple(
1781                "rmse: no valid pixels".to_string(),
1782            ));
1783        }
1784        let mse: f64 = pairs.iter().map(|&(p, g)| (p - g).powi(2)).sum::<f64>() / count as f64;
1785        Ok(mse.sqrt())
1786    }
1787
1788    /// Compute RMSE on log depths.
1789    pub fn rmse_log(pred: &[f64], gt: &[f64]) -> Result<f64> {
1790        let (pairs, count) = Self::valid_pairs(pred, gt)?;
1791        if count == 0 {
1792            return Err(TensorError::compute_error_simple(
1793                "rmse_log: no valid pixels".to_string(),
1794            ));
1795        }
1796        let mse: f64 = pairs
1797            .iter()
1798            .map(|&(p, g)| (p.ln() - g.ln()).powi(2))
1799            .sum::<f64>()
1800            / count as f64;
1801        Ok(mse.sqrt())
1802    }
1803
1804    /// Scale-Invariant Log Error (SILog).
1805    ///
1806    /// SILog = sqrt(Var(d) + 0 * mean(d)^2) where d_i = log(pred) - log(gt).
1807    /// Common definition: SILog = 100 * sqrt(E[d^2] - E\[d\]^2).
1808    pub fn si_log(pred: &[f64], gt: &[f64]) -> Result<f64> {
1809        let (pairs, count) = Self::valid_pairs(pred, gt)?;
1810        if count == 0 {
1811            return Err(TensorError::compute_error_simple(
1812                "si_log: no valid pixels".to_string(),
1813            ));
1814        }
1815        let n = count as f64;
1816        let d: Vec<f64> = pairs.iter().map(|&(p, g)| p.ln() - g.ln()).collect();
1817        let mean_d: f64 = d.iter().sum::<f64>() / n;
1818        let mean_d2: f64 = d.iter().map(|&di| di * di).sum::<f64>() / n;
1819        let val = (mean_d2 - mean_d * mean_d).max(0.0).sqrt() * 100.0;
1820        Ok(val)
1821    }
1822
1823    /// Delta threshold metric: percentage of pixels where
1824    /// max(pred/gt, gt/pred) < threshold.
1825    pub fn delta_threshold(pred: &[f64], gt: &[f64], threshold: f64) -> Result<f64> {
1826        let (pairs, count) = Self::valid_pairs(pred, gt)?;
1827        if count == 0 {
1828            return Err(TensorError::compute_error_simple(
1829                "delta_threshold: no valid pixels".to_string(),
1830            ));
1831        }
1832        let good = pairs
1833            .iter()
1834            .filter(|&&(p, g)| {
1835                let ratio = (p / g).max(g / p);
1836                ratio < threshold
1837            })
1838            .count();
1839        Ok(good as f64 / count as f64)
1840    }
1841
1842    /// Compute a comprehensive depth report.
1843    pub fn evaluate(pred: &[f64], gt: &[f64]) -> Result<DeDepthReport> {
1844        let abs_rel = Self::abs_rel(pred, gt)?;
1845        let sq_rel = Self::sq_rel(pred, gt)?;
1846        let rmse = Self::rmse(pred, gt)?;
1847        let rmse_log = Self::rmse_log(pred, gt)?;
1848        let si_log = Self::si_log(pred, gt)?;
1849        let delta_1 = Self::delta_threshold(pred, gt, 1.25)?;
1850        let delta_2 = Self::delta_threshold(pred, gt, 1.25 * 1.25)?;
1851        let delta_3 = Self::delta_threshold(pred, gt, 1.25 * 1.25 * 1.25)?;
1852        let (_, n_valid) = Self::valid_pairs(pred, gt)?;
1853
1854        Ok(DeDepthReport {
1855            abs_rel,
1856            sq_rel,
1857            rmse,
1858            rmse_log,
1859            si_log,
1860            delta_1,
1861            delta_2,
1862            delta_3,
1863            n_valid,
1864        })
1865    }
1866
1867    /// Extract valid prediction/GT pairs (both > 0).
1868    fn valid_pairs(pred: &[f64], gt: &[f64]) -> Result<(Vec<(f64, f64)>, usize)> {
1869        if pred.len() != gt.len() {
1870            return Err(TensorError::compute_error_simple(
1871                "DeDepthMetrics: pred/gt length mismatch".to_string(),
1872            ));
1873        }
1874        let pairs: Vec<(f64, f64)> = pred
1875            .iter()
1876            .zip(gt)
1877            .filter(|(&p, &g)| p > 0.0 && g > 0.0)
1878            .map(|(&p, &g)| (p, g))
1879            .collect();
1880        let count = pairs.len();
1881        Ok((pairs, count))
1882    }
1883}
1884
1885// ─────────────────────────────────────────────────────────────────────────────
1886// Tests
1887// ─────────────────────────────────────────────────────────────────────────────
1888
// Unit tests are declared out-of-line (in a separate `tests` module file) and
// are compiled only for test builds.
#[cfg(test)]
mod tests;