Skip to main content

threecrate_simplification/
clustering.rs

//! Clustering-based mesh simplification
//!
//! Implements the Rossignac & Borrel (1993) vertex clustering algorithm with
//! extensions for adaptive octree-based clustering, multiple representative
//! selection strategies, boundary-aware clustering, and sharp feature preservation.
6
7use crate::MeshSimplifier;
8use nalgebra::{Matrix4, Vector4};
9use std::collections::{HashMap, HashSet};
10use threecrate_core::{Error, Point3f, Result, TriangleMesh, Vector3f};
11
12// ============================================================
13// Configuration Types
14// ============================================================
15
16/// Strategy for selecting the representative vertex within a cluster.
17#[derive(Debug, Clone, Copy, PartialEq)]
18pub enum RepresentativeStrategy {
19    /// Arithmetic mean of all vertex positions in the cluster.
20    Centroid,
21    /// Weighted average using vertex valence (number of adjacent faces).
22    WeightedAverage,
23    /// Position that minimizes the summed quadric error for the cluster.
24    MinimumError,
25}
26
27/// Clustering grid mode.
28#[derive(Debug, Clone, Copy, PartialEq)]
29pub enum ClusteringMode {
30    /// Uniform grid with automatically computed cell size from reduction ratio.
31    Uniform,
32    /// Adaptive octree that subdivides cells exceeding the error threshold.
33    Adaptive {
34        max_depth: u32,
35        error_threshold: f64,
36    },
37}
38
39// ============================================================
40// Bounding Box
41// ============================================================
42
43#[derive(Debug, Clone, Copy)]
44struct BBox {
45    min: [f64; 3],
46    max: [f64; 3],
47}
48
49impl BBox {
50    fn from_vertices(vertices: &[Point3f]) -> Self {
51        let mut min = [f64::MAX; 3];
52        let mut max = [f64::MIN; 3];
53        for v in vertices {
54            for i in 0..3 {
55                let c = v[i] as f64;
56                if c < min[i] {
57                    min[i] = c;
58                }
59                if c > max[i] {
60                    max[i] = c;
61                }
62            }
63        }
64        BBox { min, max }
65    }
66
67    fn size(&self) -> [f64; 3] {
68        [
69            self.max[0] - self.min[0],
70            self.max[1] - self.min[1],
71            self.max[2] - self.min[2],
72        ]
73    }
74
75    fn max_extent(&self) -> f64 {
76        let s = self.size();
77        s[0].max(s[1]).max(s[2])
78    }
79
80    fn center(&self) -> [f64; 3] {
81        [
82            (self.min[0] + self.max[0]) * 0.5,
83            (self.min[1] + self.max[1]) * 0.5,
84            (self.min[2] + self.max[2]) * 0.5,
85        ]
86    }
87}
88
89// ============================================================
90// Octree Node (for adaptive clustering)
91// ============================================================
92
93#[derive(Debug)]
94struct OctreeNode {
95    bbox: BBox,
96    children: Option<Box<[OctreeNode; 8]>>,
97    vertex_indices: Vec<usize>,
98    depth: u32,
99}
100
101impl OctreeNode {
102    fn new(bbox: BBox, depth: u32) -> Self {
103        OctreeNode {
104            bbox,
105            children: None,
106            vertex_indices: Vec::new(),
107            depth,
108        }
109    }
110
111    fn contains(&self, p: &Point3f) -> bool {
112        let eps = 1e-6;
113        (p.x as f64) >= self.bbox.min[0] - eps
114            && (p.x as f64) <= self.bbox.max[0] + eps
115            && (p.y as f64) >= self.bbox.min[1] - eps
116            && (p.y as f64) <= self.bbox.max[1] + eps
117            && (p.z as f64) >= self.bbox.min[2] - eps
118            && (p.z as f64) <= self.bbox.max[2] + eps
119    }
120
121    fn subdivide(&mut self) {
122        let c = self.bbox.center();
123        let mn = self.bbox.min;
124        let mx = self.bbox.max;
125        let d = self.depth + 1;
126
127        let children = [
128            OctreeNode::new(BBox { min: [mn[0], mn[1], mn[2]], max: [c[0], c[1], c[2]] }, d),
129            OctreeNode::new(BBox { min: [c[0], mn[1], mn[2]], max: [mx[0], c[1], c[2]] }, d),
130            OctreeNode::new(BBox { min: [mn[0], c[1], mn[2]], max: [c[0], mx[1], c[2]] }, d),
131            OctreeNode::new(BBox { min: [c[0], c[1], mn[2]], max: [mx[0], mx[1], c[2]] }, d),
132            OctreeNode::new(BBox { min: [mn[0], mn[1], c[2]], max: [c[0], c[1], mx[2]] }, d),
133            OctreeNode::new(BBox { min: [c[0], mn[1], c[2]], max: [mx[0], c[1], mx[2]] }, d),
134            OctreeNode::new(BBox { min: [mn[0], c[1], c[2]], max: [c[0], mx[1], mx[2]] }, d),
135            OctreeNode::new(BBox { min: [c[0], c[1], c[2]], max: [mx[0], mx[1], mx[2]] }, d),
136        ];
137        self.children = Some(Box::new(children));
138    }
139
140    /// Insert a vertex index; if adaptive, subdivide when the cluster error exceeds
141    /// the threshold and depth < max_depth.
142    fn insert(
143        &mut self,
144        vi: usize,
145        positions: &[Point3f],
146        quadrics: &[Matrix4<f64>],
147        max_depth: u32,
148        error_threshold: f64,
149    ) {
150        if let Some(ref mut children) = self.children {
151            for child in children.iter_mut() {
152                if child.contains(&positions[vi]) {
153                    child.insert(vi, positions, quadrics, max_depth, error_threshold);
154                    return;
155                }
156            }
157            // Fallback: add to self if no child contains the point
158            self.vertex_indices.push(vi);
159            return;
160        }
161
162        self.vertex_indices.push(vi);
163
164        // Consider subdividing
165        if self.vertex_indices.len() > 1 && self.depth < max_depth {
166            let cluster_error = compute_cluster_quadric_error(
167                &self.vertex_indices,
168                positions,
169                quadrics,
170            );
171            if cluster_error > error_threshold {
172                self.subdivide();
173                let verts = std::mem::take(&mut self.vertex_indices);
174                for v in verts {
175                    self.insert(v, positions, quadrics, max_depth, error_threshold);
176                }
177            }
178        }
179    }
180
181    /// Collect all leaf clusters (non-empty vertex lists).
182    fn collect_clusters(&self, out: &mut Vec<Vec<usize>>) {
183        if let Some(ref children) = self.children {
184            for child in children.iter() {
185                child.collect_clusters(out);
186            }
187        }
188        if !self.vertex_indices.is_empty() {
189            out.push(self.vertex_indices.clone());
190        }
191    }
192}
193
194// ============================================================
195// Helpers
196// ============================================================
197
198fn compute_plane(v0: &Point3f, v1: &Point3f, v2: &Point3f) -> Vector4<f64> {
199    let e1 = v1 - v0;
200    let e2 = v2 - v0;
201    let n = e1.cross(&e2).normalize();
202    if !n.iter().all(|x| x.is_finite()) {
203        return Vector4::new(0.0, 0.0, 1.0, 0.0);
204    }
205    let d = -n.dot(&v0.coords);
206    Vector4::new(n.x as f64, n.y as f64, n.z as f64, d as f64)
207}
208
209fn plane_to_quadric(p: &Vector4<f64>) -> Matrix4<f64> {
210    let (a, b, c, d) = (p[0], p[1], p[2], p[3]);
211    Matrix4::new(
212        a * a, a * b, a * c, a * d,
213        a * b, b * b, b * c, b * d,
214        a * c, b * c, c * c, c * d,
215        a * d, b * d, c * d, d * d,
216    )
217}
218
219fn compute_quadrics(mesh: &TriangleMesh) -> Vec<Matrix4<f64>> {
220    let mut quadrics = vec![Matrix4::zeros(); mesh.vertices.len()];
221    for face in &mesh.faces {
222        let plane = compute_plane(
223            &mesh.vertices[face[0]],
224            &mesh.vertices[face[1]],
225            &mesh.vertices[face[2]],
226        );
227        let q = plane_to_quadric(&plane);
228        for &vi in face {
229            quadrics[vi] += q;
230        }
231    }
232    quadrics
233}
234
235fn compute_vertex_valence(mesh: &TriangleMesh) -> Vec<usize> {
236    let mut valence = vec![0usize; mesh.vertices.len()];
237    for face in &mesh.faces {
238        for &vi in face {
239            valence[vi] += 1;
240        }
241    }
242    valence
243}
244
245fn quadric_error_at(pos: &Point3f, q: &Matrix4<f64>) -> f64 {
246    let v = Vector4::new(pos.x as f64, pos.y as f64, pos.z as f64, 1.0);
247    (v.transpose() * q * v)[0].max(0.0)
248}
249
250fn compute_cluster_quadric_error(
251    indices: &[usize],
252    positions: &[Point3f],
253    quadrics: &[Matrix4<f64>],
254) -> f64 {
255    if indices.is_empty() {
256        return 0.0;
257    }
258    // Compute centroid
259    let mut cx = 0.0f64;
260    let mut cy = 0.0f64;
261    let mut cz = 0.0f64;
262    for &vi in indices {
263        cx += positions[vi].x as f64;
264        cy += positions[vi].y as f64;
265        cz += positions[vi].z as f64;
266    }
267    let n = indices.len() as f64;
268    let centroid = Point3f::new((cx / n) as f32, (cy / n) as f32, (cz / n) as f32);
269
270    // Sum quadric errors at the centroid
271    let mut total = 0.0;
272    for &vi in indices {
273        total += quadric_error_at(&centroid, &quadrics[vi]);
274    }
275    total
276}
277
278fn find_boundary_vertices(mesh: &TriangleMesh) -> HashSet<usize> {
279    let mut edge_count: HashMap<(usize, usize), usize> = HashMap::new();
280    for face in &mesh.faces {
281        let edges = [
282            (face[0].min(face[1]), face[0].max(face[1])),
283            (face[1].min(face[2]), face[1].max(face[2])),
284            (face[2].min(face[0]), face[2].max(face[0])),
285        ];
286        for &e in &edges {
287            *edge_count.entry(e).or_insert(0) += 1;
288        }
289    }
290    let mut boundary = HashSet::new();
291    for ((v1, v2), count) in &edge_count {
292        if *count == 1 {
293            boundary.insert(*v1);
294            boundary.insert(*v2);
295        }
296    }
297    boundary
298}
299
300/// Detect vertices on sharp features using the dihedral angle between
301/// adjacent face normals.
302fn find_feature_vertices(mesh: &TriangleMesh, angle_threshold: f32) -> HashSet<usize> {
303    let cos_threshold = angle_threshold.cos();
304
305    // Compute face normals
306    let face_normals: Vec<Vector3f> = mesh
307        .faces
308        .iter()
309        .map(|f| {
310            let e1 = mesh.vertices[f[1]] - mesh.vertices[f[0]];
311            let e2 = mesh.vertices[f[2]] - mesh.vertices[f[0]];
312            let n = e1.cross(&e2);
313            let len = n.magnitude();
314            if len > 1e-12 {
315                n / len
316            } else {
317                Vector3f::new(0.0, 0.0, 1.0)
318            }
319        })
320        .collect();
321
322    // Build edge -> face adjacency
323    let mut edge_faces: HashMap<(usize, usize), Vec<usize>> = HashMap::new();
324    for (fi, face) in mesh.faces.iter().enumerate() {
325        let edges = [
326            (face[0].min(face[1]), face[0].max(face[1])),
327            (face[1].min(face[2]), face[1].max(face[2])),
328            (face[2].min(face[0]), face[2].max(face[0])),
329        ];
330        for &e in &edges {
331            edge_faces.entry(e).or_default().push(fi);
332        }
333    }
334
335    let mut feature_verts = HashSet::new();
336    for ((v1, v2), faces) in &edge_faces {
337        if faces.len() == 2 {
338            let dot = face_normals[faces[0]].dot(&face_normals[faces[1]]);
339            if dot < cos_threshold {
340                feature_verts.insert(*v1);
341                feature_verts.insert(*v2);
342            }
343        }
344    }
345    feature_verts
346}
347
348// ============================================================
349// Representative Selection
350// ============================================================
351
352fn select_representative(
353    cluster: &[usize],
354    positions: &[Point3f],
355    quadrics: &[Matrix4<f64>],
356    valence: &[usize],
357    strategy: RepresentativeStrategy,
358) -> Point3f {
359    match strategy {
360        RepresentativeStrategy::Centroid => {
361            let mut cx = 0.0f64;
362            let mut cy = 0.0f64;
363            let mut cz = 0.0f64;
364            for &vi in cluster {
365                cx += positions[vi].x as f64;
366                cy += positions[vi].y as f64;
367                cz += positions[vi].z as f64;
368            }
369            let n = cluster.len() as f64;
370            Point3f::new((cx / n) as f32, (cy / n) as f32, (cz / n) as f32)
371        }
372        RepresentativeStrategy::WeightedAverage => {
373            let mut wx = 0.0f64;
374            let mut wy = 0.0f64;
375            let mut wz = 0.0f64;
376            let mut w_total = 0.0f64;
377            for &vi in cluster {
378                let w = (valence[vi].max(1)) as f64;
379                wx += positions[vi].x as f64 * w;
380                wy += positions[vi].y as f64 * w;
381                wz += positions[vi].z as f64 * w;
382                w_total += w;
383            }
384            if w_total > 0.0 {
385                Point3f::new(
386                    (wx / w_total) as f32,
387                    (wy / w_total) as f32,
388                    (wz / w_total) as f32,
389                )
390            } else {
391                positions[cluster[0]]
392            }
393        }
394        RepresentativeStrategy::MinimumError => {
395            // Sum the quadrics for all vertices in the cluster
396            let mut q_sum = Matrix4::zeros();
397            for &vi in cluster {
398                q_sum += quadrics[vi];
399            }
400
401            // Try to solve for the optimal position via the quadric
402            let q3 = q_sum.fixed_view::<3, 3>(0, 0);
403            let q1 = q_sum.fixed_view::<3, 1>(0, 3);
404
405            if let Some(inv) = q3.try_inverse() {
406                let p = -inv * q1;
407                let candidate = Point3f::new(p[0] as f32, p[1] as f32, p[2] as f32);
408                if candidate.x.is_finite() && candidate.y.is_finite() && candidate.z.is_finite() {
409                    return candidate;
410                }
411            }
412
413            // Fallback: pick the vertex with minimum quadric error
414            let mut best_vi = cluster[0];
415            let mut best_err = f64::MAX;
416            for &vi in cluster {
417                let err = quadric_error_at(&positions[vi], &q_sum);
418                if err < best_err {
419                    best_err = err;
420                    best_vi = vi;
421                }
422            }
423            positions[best_vi]
424        }
425    }
426}
427
428// ============================================================
429// Clustering Simplifier
430// ============================================================
431
432/// Clustering-based mesh simplifier.
433///
434/// Uses vertex clustering (Rossignac & Borrel 1993) to rapidly simplify meshes.
435/// Supports uniform grid and adaptive octree clustering modes, multiple
436/// representative selection strategies, boundary preservation, and sharp
437/// feature maintenance.
438pub struct ClusteringSimplifier {
439    /// Clustering grid mode (uniform or adaptive octree).
440    pub mode: ClusteringMode,
441    /// Strategy for choosing the representative position of each cluster.
442    pub representative_strategy: RepresentativeStrategy,
443    /// If true, boundary vertices are clustered only with other boundary
444    /// vertices in the same cell, preventing boundary drift.
445    pub preserve_boundary: bool,
446    /// Dihedral angle threshold (in radians) for sharp feature detection.
447    /// Edges with dihedral angles exceeding this are treated as features.
448    pub feature_angle_threshold: f32,
449}
450
451impl Default for ClusteringSimplifier {
452    fn default() -> Self {
453        Self {
454            mode: ClusteringMode::Uniform,
455            representative_strategy: RepresentativeStrategy::Centroid,
456            preserve_boundary: true,
457            feature_angle_threshold: 45.0_f32.to_radians(),
458        }
459    }
460}
461
462impl ClusteringSimplifier {
463    pub fn new() -> Self {
464        Self::default()
465    }
466
467    pub fn with_params(
468        mode: ClusteringMode,
469        representative_strategy: RepresentativeStrategy,
470        preserve_boundary: bool,
471        feature_angle_threshold: f32,
472    ) -> Self {
473        Self {
474            mode,
475            representative_strategy,
476            preserve_boundary,
477            feature_angle_threshold,
478        }
479    }
480
481    /// Compute uniform grid cell size from the desired reduction ratio and mesh bounding box.
482    /// Handles degenerate (planar/linear) meshes by using only non-degenerate dimensions.
483    fn compute_cell_size(bbox: &BBox, num_vertices: usize, reduction_ratio: f32) -> f64 {
484        let target_clusters = ((1.0 - reduction_ratio) * num_vertices as f32).max(1.0) as f64;
485        let s = bbox.size();
486        let eps = 1e-6;
487
488        // Use only non-degenerate dimensions for cell size calculation
489        let extents: Vec<f64> = s.iter().filter(|&&d| d > eps).copied().collect();
490        let dim = extents.len();
491
492        if dim == 0 {
493            // All vertices at same point; any cell size works
494            return 1.0;
495        }
496
497        // product(extents) / cell_size^dim ≈ target_clusters
498        let product: f64 = extents.iter().product();
499        (product / target_clusters).powf(1.0 / dim as f64)
500    }
501
502    /// Assign vertices to uniform grid cells, returning a map of cell key -> vertex indices.
503    fn uniform_clustering(
504        &self,
505        mesh: &TriangleMesh,
506        cell_size: f64,
507        bbox: &BBox,
508        boundary_verts: &HashSet<usize>,
509        feature_verts: &HashSet<usize>,
510    ) -> Vec<Vec<usize>> {
511        // Cell key: (ix, iy, iz, class) where class separates boundary/feature/interior
512        let mut cells: HashMap<(i64, i64, i64, u8), Vec<usize>> = HashMap::new();
513
514        for (vi, v) in mesh.vertices.iter().enumerate() {
515            let ix = ((v.x as f64 - bbox.min[0]) / cell_size).floor() as i64;
516            let iy = ((v.y as f64 - bbox.min[1]) / cell_size).floor() as i64;
517            let iz = ((v.z as f64 - bbox.min[2]) / cell_size).floor() as i64;
518
519            let class = if self.preserve_boundary && boundary_verts.contains(&vi) {
520                1u8
521            } else if feature_verts.contains(&vi) {
522                2u8
523            } else {
524                0u8
525            };
526
527            cells.entry((ix, iy, iz, class)).or_default().push(vi);
528        }
529
530        cells.into_values().collect()
531    }
532
533    /// Assign vertices using adaptive octree clustering.
534    fn adaptive_clustering(
535        &self,
536        mesh: &TriangleMesh,
537        bbox: &BBox,
538        quadrics: &[Matrix4<f64>],
539        max_depth: u32,
540        error_threshold: f64,
541        boundary_verts: &HashSet<usize>,
542        feature_verts: &HashSet<usize>,
543    ) -> Vec<Vec<usize>> {
544        // Separate protected vertices (boundary/feature) from interior
545        let mut protected: HashMap<(i64, i64, i64, u8), Vec<usize>> = HashMap::new();
546        let mut interior_indices: Vec<usize> = Vec::new();
547
548        // Use a coarse grid for protected vertex grouping
549        let extent = bbox.max_extent().max(1e-6);
550        let coarse_size = extent / (1 << max_depth.min(6)) as f64;
551
552        for vi in 0..mesh.vertices.len() {
553            let is_boundary = self.preserve_boundary && boundary_verts.contains(&vi);
554            let is_feature = feature_verts.contains(&vi);
555
556            if is_boundary || is_feature {
557                let v = &mesh.vertices[vi];
558                let ix = ((v.x as f64 - bbox.min[0]) / coarse_size).floor() as i64;
559                let iy = ((v.y as f64 - bbox.min[1]) / coarse_size).floor() as i64;
560                let iz = ((v.z as f64 - bbox.min[2]) / coarse_size).floor() as i64;
561                let class = if is_boundary { 1u8 } else { 2u8 };
562                protected.entry((ix, iy, iz, class)).or_default().push(vi);
563            } else {
564                interior_indices.push(vi);
565            }
566        }
567
568        // Build octree for interior vertices
569        // Make bbox slightly larger to ensure all points are inside
570        let padded_bbox = BBox {
571            min: [bbox.min[0] - 1e-4, bbox.min[1] - 1e-4, bbox.min[2] - 1e-4],
572            max: [bbox.max[0] + 1e-4, bbox.max[1] + 1e-4, bbox.max[2] + 1e-4],
573        };
574        let mut root = OctreeNode::new(padded_bbox, 0);
575
576        for &vi in &interior_indices {
577            root.insert(vi, &mesh.vertices, quadrics, max_depth, error_threshold);
578        }
579
580        let mut clusters: Vec<Vec<usize>> = Vec::new();
581        root.collect_clusters(&mut clusters);
582
583        // Add protected clusters
584        for (_, verts) in protected {
585            if !verts.is_empty() {
586                clusters.push(verts);
587            }
588        }
589
590        clusters
591    }
592
593    /// Build the simplified mesh from clusters.
594    fn build_simplified_mesh(
595        &self,
596        mesh: &TriangleMesh,
597        clusters: &[Vec<usize>],
598        quadrics: &[Matrix4<f64>],
599        valence: &[usize],
600    ) -> TriangleMesh {
601        // Map each original vertex to its cluster index
602        let mut vertex_to_cluster: Vec<usize> = vec![0; mesh.vertices.len()];
603        for (ci, cluster) in clusters.iter().enumerate() {
604            for &vi in cluster {
605                vertex_to_cluster[vi] = ci;
606            }
607        }
608
609        // Compute representative positions
610        let representatives: Vec<Point3f> = clusters
611            .iter()
612            .map(|cluster| {
613                if cluster.len() == 1 {
614                    mesh.vertices[cluster[0]]
615                } else {
616                    select_representative(
617                        cluster,
618                        &mesh.vertices,
619                        quadrics,
620                        valence,
621                        self.representative_strategy,
622                    )
623                }
624            })
625            .collect();
626
627        // Remap faces, filtering degenerate triangles
628        let mut new_faces: Vec<[usize; 3]> = Vec::new();
629        let mut seen_faces: HashSet<[usize; 3]> = HashSet::new();
630
631        for face in &mesh.faces {
632            let nv0 = vertex_to_cluster[face[0]];
633            let nv1 = vertex_to_cluster[face[1]];
634            let nv2 = vertex_to_cluster[face[2]];
635
636            // Skip degenerate triangles
637            if nv0 == nv1 || nv1 == nv2 || nv2 == nv0 {
638                continue;
639            }
640
641            // Canonical ordering to deduplicate
642            let mut sorted = [nv0, nv1, nv2];
643            sorted.sort();
644            if seen_faces.insert(sorted) {
645                new_faces.push([nv0, nv1, nv2]);
646            }
647        }
648
649        // Compact: only include clusters that appear in at least one face
650        let mut used_clusters: HashSet<usize> = HashSet::new();
651        for face in &new_faces {
652            for &vi in face {
653                used_clusters.insert(vi);
654            }
655        }
656
657        let mut old_to_new: HashMap<usize, usize> = HashMap::new();
658        let mut new_vertices: Vec<Point3f> = Vec::new();
659        let mut new_normals: Option<Vec<Vector3f>> = mesh.normals.as_ref().map(|_| Vec::new());
660        let mut new_colors: Option<Vec<[u8; 3]>> = mesh.colors.as_ref().map(|_| Vec::new());
661
662        for (ci, cluster) in clusters.iter().enumerate() {
663            if !used_clusters.contains(&ci) {
664                continue;
665            }
666            let new_idx = new_vertices.len();
667            old_to_new.insert(ci, new_idx);
668            new_vertices.push(representatives[ci]);
669
670            // Interpolate normals: average normals of cluster members
671            if let Some(ref normals) = mesh.normals {
672                let mut avg = Vector3f::new(0.0, 0.0, 0.0);
673                for &vi in cluster {
674                    avg += normals[vi];
675                }
676                let len = avg.magnitude();
677                if len > 1e-12 {
678                    avg /= len;
679                }
680                new_normals.as_mut().unwrap().push(avg);
681            }
682
683            // Interpolate colors: average colors of cluster members
684            if let Some(ref colors) = mesh.colors {
685                let mut r = 0u32;
686                let mut g = 0u32;
687                let mut b = 0u32;
688                for &vi in cluster {
689                    r += colors[vi][0] as u32;
690                    g += colors[vi][1] as u32;
691                    b += colors[vi][2] as u32;
692                }
693                let n = cluster.len() as u32;
694                new_colors
695                    .as_mut()
696                    .unwrap()
697                    .push([(r / n) as u8, (g / n) as u8, (b / n) as u8]);
698            }
699        }
700
701        // Remap face indices
702        let remapped_faces: Vec<[usize; 3]> = new_faces
703            .iter()
704            .filter_map(|f| {
705                match (old_to_new.get(&f[0]), old_to_new.get(&f[1]), old_to_new.get(&f[2])) {
706                    (Some(&a), Some(&b), Some(&c)) if a != b && b != c && c != a => {
707                        Some([a, b, c])
708                    }
709                    _ => None,
710                }
711            })
712            .collect();
713
714        let mut result = TriangleMesh::from_vertices_and_faces(new_vertices, remapped_faces);
715        if let Some(normals) = new_normals {
716            result.set_normals(normals);
717        }
718        if let Some(colors) = new_colors {
719            result.set_colors(colors);
720        }
721        result
722    }
723}
724
725impl MeshSimplifier for ClusteringSimplifier {
726    fn simplify(&self, mesh: &TriangleMesh, reduction_ratio: f32) -> Result<TriangleMesh> {
727        if mesh.is_empty() {
728            return Err(Error::InvalidData("Mesh is empty".to_string()));
729        }
730        if !(0.0..=1.0).contains(&reduction_ratio) {
731            return Err(Error::InvalidData(
732                "Reduction ratio must be between 0.0 and 1.0".to_string(),
733            ));
734        }
735        if reduction_ratio == 0.0 {
736            return Ok(mesh.clone());
737        }
738
739        let bbox = BBox::from_vertices(&mesh.vertices);
740        let quadrics = compute_quadrics(mesh);
741        let valence = compute_vertex_valence(mesh);
742        let boundary_verts = if self.preserve_boundary {
743            find_boundary_vertices(mesh)
744        } else {
745            HashSet::new()
746        };
747        let feature_verts = find_feature_vertices(mesh, self.feature_angle_threshold);
748
749        let clusters = match self.mode {
750            ClusteringMode::Uniform => {
751                let cell_size =
752                    Self::compute_cell_size(&bbox, mesh.vertices.len(), reduction_ratio);
753                self.uniform_clustering(mesh, cell_size, &bbox, &boundary_verts, &feature_verts)
754            }
755            ClusteringMode::Adaptive {
756                max_depth,
757                error_threshold,
758            } => self.adaptive_clustering(
759                mesh,
760                &bbox,
761                &quadrics,
762                max_depth,
763                error_threshold,
764                &boundary_verts,
765                &feature_verts,
766            ),
767        };
768
769        Ok(self.build_simplified_mesh(mesh, &clusters, &quadrics, &valence))
770    }
771}
772
773#[cfg(test)]
774mod tests {
775    use super::*;
776    use nalgebra::Point3;
777
778    fn make_single_triangle() -> TriangleMesh {
779        TriangleMesh::from_vertices_and_faces(
780            vec![
781                Point3::new(0.0, 0.0, 0.0),
782                Point3::new(1.0, 0.0, 0.0),
783                Point3::new(0.5, 1.0, 0.0),
784            ],
785            vec![[0, 1, 2]],
786        )
787    }
788
789    fn make_tetrahedron() -> TriangleMesh {
790        TriangleMesh::from_vertices_and_faces(
791            vec![
792                Point3::new(0.0, 0.0, 0.0),
793                Point3::new(1.0, 0.0, 0.0),
794                Point3::new(0.5, 1.0, 0.0),
795                Point3::new(0.5, 0.5, 1.0),
796            ],
797            vec![[0, 2, 1], [0, 1, 3], [0, 3, 2], [1, 2, 3]],
798        )
799    }
800
801    fn make_plane_grid(size: usize) -> TriangleMesh {
802        let mut vertices = Vec::new();
803        for y in 0..size {
804            for x in 0..size {
805                vertices.push(Point3::new(x as f32, y as f32, 0.0));
806            }
807        }
808        let mut faces = Vec::new();
809        for y in 0..(size - 1) {
810            for x in 0..(size - 1) {
811                let tl = y * size + x;
812                let tr = tl + 1;
813                let bl = (y + 1) * size + x;
814                let br = bl + 1;
815                faces.push([tl, bl, tr]);
816                faces.push([tr, bl, br]);
817            }
818        }
819        TriangleMesh::from_vertices_and_faces(vertices, faces)
820    }
821
822    fn make_curved_surface(size: usize) -> TriangleMesh {
823        let mut vertices = Vec::new();
824        for y in 0..size {
825            for x in 0..size {
826                let fx = x as f32 / (size - 1) as f32 * std::f32::consts::PI;
827                let fy = y as f32 / (size - 1) as f32 * std::f32::consts::PI;
828                vertices.push(Point3::new(
829                    x as f32,
830                    y as f32,
831                    (fx.sin() * fy.sin()) * 2.0,
832                ));
833            }
834        }
835        let mut faces = Vec::new();
836        for y in 0..(size - 1) {
837            for x in 0..(size - 1) {
838                let tl = y * size + x;
839                let tr = tl + 1;
840                let bl = (y + 1) * size + x;
841                let br = bl + 1;
842                faces.push([tl, bl, tr]);
843                faces.push([tr, bl, br]);
844            }
845        }
846        TriangleMesh::from_vertices_and_faces(vertices, faces)
847    }
848
849    fn make_sharp_edge_mesh() -> TriangleMesh {
850        // Two planes meeting at a 90-degree angle along the x-axis
851        TriangleMesh::from_vertices_and_faces(
852            vec![
853                // Bottom plane (z=0)
854                Point3::new(0.0, 0.0, 0.0),
855                Point3::new(1.0, 0.0, 0.0),
856                Point3::new(2.0, 0.0, 0.0),
857                Point3::new(0.0, 1.0, 0.0),
858                Point3::new(1.0, 1.0, 0.0),
859                Point3::new(2.0, 1.0, 0.0),
860                // Top plane (going up at 90 degrees from the y=1 edge)
861                Point3::new(0.0, 1.0, 1.0),
862                Point3::new(1.0, 1.0, 1.0),
863                Point3::new(2.0, 1.0, 1.0),
864            ],
865            vec![
866                // Bottom plane faces
867                [0, 1, 3],
868                [1, 4, 3],
869                [1, 2, 4],
870                [2, 5, 4],
871                // Top plane faces
872                [3, 4, 6],
873                [4, 7, 6],
874                [4, 5, 7],
875                [5, 8, 7],
876            ],
877        )
878    }
879
880    // ---- Construction tests ----
881
882    #[test]
883    fn test_creation() {
884        let s = ClusteringSimplifier::new();
885        assert!(s.preserve_boundary);
886        assert_eq!(s.mode, ClusteringMode::Uniform);
887        assert_eq!(s.representative_strategy, RepresentativeStrategy::Centroid);
888    }
889
890    #[test]
891    fn test_with_params() {
892        let s = ClusteringSimplifier::with_params(
893            ClusteringMode::Adaptive {
894                max_depth: 6,
895                error_threshold: 0.01,
896            },
897            RepresentativeStrategy::MinimumError,
898            false,
899            30.0_f32.to_radians(),
900        );
901        assert!(!s.preserve_boundary);
902        assert_eq!(
903            s.mode,
904            ClusteringMode::Adaptive {
905                max_depth: 6,
906                error_threshold: 0.01,
907            }
908        );
909        assert_eq!(
910            s.representative_strategy,
911            RepresentativeStrategy::MinimumError
912        );
913    }
914
915    // ---- Validation tests ----
916
917    #[test]
918    fn test_empty_mesh() {
919        let s = ClusteringSimplifier::new();
920        let mesh = TriangleMesh::new();
921        assert!(s.simplify(&mesh, 0.5).is_err());
922    }
923
924    #[test]
925    fn test_invalid_reduction_ratio() {
926        let s = ClusteringSimplifier::new();
927        let mesh = make_single_triangle();
928        assert!(s.simplify(&mesh, -0.1).is_err());
929        assert!(s.simplify(&mesh, 1.1).is_err());
930    }
931
932    #[test]
933    fn test_zero_reduction() {
934        let s = ClusteringSimplifier::new();
935        let mesh = make_single_triangle();
936        let result = s.simplify(&mesh, 0.0).unwrap();
937        assert_eq!(result.vertex_count(), 3);
938        assert_eq!(result.face_count(), 1);
939    }
940
941    // ---- Uniform clustering tests ----
942
943    #[test]
944    fn test_single_triangle_uniform() {
945        let s = ClusteringSimplifier::new();
946        let mesh = make_single_triangle();
947        let result = s.simplify(&mesh, 0.5).unwrap();
948        // Should still produce valid output
949        assert!(result.vertex_count() > 0);
950    }
951
952    #[test]
953    fn test_tetrahedron_uniform() {
954        let s = ClusteringSimplifier::with_params(
955            ClusteringMode::Uniform,
956            RepresentativeStrategy::Centroid,
957            false,
958            std::f32::consts::PI,
959        );
960        let mesh = make_tetrahedron();
961        let result = s.simplify(&mesh, 0.5).unwrap();
962        assert!(result.vertex_count() <= mesh.vertex_count());
963    }
964
965    #[test]
966    fn test_planar_grid_uniform() {
967        let s = ClusteringSimplifier::new();
968        let mesh = make_plane_grid(6);
969        let original_faces = mesh.face_count();
970        assert_eq!(original_faces, 50);
971
972        let result = s.simplify(&mesh, 0.5).unwrap();
973        assert!(result.face_count() < original_faces);
974        assert!(result.face_count() > 0);
975    }
976
977    #[test]
978    fn test_curved_surface_uniform() {
979        let s = ClusteringSimplifier::new();
980        let mesh = make_curved_surface(8);
981        let original_faces = mesh.face_count();
982
983        let result = s.simplify(&mesh, 0.5).unwrap();
984        assert!(result.face_count() < original_faces);
985        assert!(result.face_count() > 0);
986    }
987
988    #[test]
989    fn test_large_grid_uniform() {
990        let s = ClusteringSimplifier::new();
991        let mesh = make_plane_grid(11);
992        let original = mesh.face_count(); // 200 faces
993
994        let result = s.simplify(&mesh, 0.5).unwrap();
995        assert!(result.face_count() < original);
996        assert!(result.face_count() > 0);
997        assert!(result.vertex_count() > 0);
998    }
999
1000    // ---- Adaptive clustering tests ----
1001
1002    #[test]
1003    fn test_planar_grid_adaptive() {
1004        let s = ClusteringSimplifier::with_params(
1005            ClusteringMode::Adaptive {
1006                max_depth: 4,
1007                error_threshold: 0.01,
1008            },
1009            RepresentativeStrategy::Centroid,
1010            true,
1011            45.0_f32.to_radians(),
1012        );
1013        let mesh = make_plane_grid(6);
1014        let original_faces = mesh.face_count();
1015
1016        let result = s.simplify(&mesh, 0.5).unwrap();
1017        assert!(result.face_count() <= original_faces);
1018        assert!(result.face_count() > 0);
1019    }
1020
1021    #[test]
1022    fn test_curved_surface_adaptive() {
1023        let s = ClusteringSimplifier::with_params(
1024            ClusteringMode::Adaptive {
1025                max_depth: 5,
1026                error_threshold: 0.1,
1027            },
1028            RepresentativeStrategy::MinimumError,
1029            true,
1030            45.0_f32.to_radians(),
1031        );
1032        let mesh = make_curved_surface(8);
1033        let original_faces = mesh.face_count();
1034
1035        let result = s.simplify(&mesh, 0.5).unwrap();
1036        assert!(result.face_count() <= original_faces);
1037        assert!(result.face_count() > 0);
1038    }
1039
1040    // ---- Representative strategy tests ----
1041
1042    #[test]
1043    fn test_centroid_strategy() {
1044        let s = ClusteringSimplifier::with_params(
1045            ClusteringMode::Uniform,
1046            RepresentativeStrategy::Centroid,
1047            false,
1048            std::f32::consts::PI,
1049        );
1050        let mesh = make_plane_grid(6);
1051        let result = s.simplify(&mesh, 0.5).unwrap();
1052        assert!(result.face_count() > 0);
1053    }
1054
1055    #[test]
1056    fn test_weighted_average_strategy() {
1057        let s = ClusteringSimplifier::with_params(
1058            ClusteringMode::Uniform,
1059            RepresentativeStrategy::WeightedAverage,
1060            false,
1061            std::f32::consts::PI,
1062        );
1063        let mesh = make_plane_grid(6);
1064        let result = s.simplify(&mesh, 0.5).unwrap();
1065        assert!(result.face_count() > 0);
1066    }
1067
1068    #[test]
1069    fn test_minimum_error_strategy() {
1070        let s = ClusteringSimplifier::with_params(
1071            ClusteringMode::Uniform,
1072            RepresentativeStrategy::MinimumError,
1073            false,
1074            std::f32::consts::PI,
1075        );
1076        let mesh = make_plane_grid(6);
1077        let result = s.simplify(&mesh, 0.5).unwrap();
1078        assert!(result.face_count() > 0);
1079    }
1080
1081    // ---- Boundary preservation tests ----
1082
1083    #[test]
1084    fn test_boundary_preservation_uniform() {
1085        let s = ClusteringSimplifier::with_params(
1086            ClusteringMode::Uniform,
1087            RepresentativeStrategy::Centroid,
1088            true,
1089            45.0_f32.to_radians(),
1090        );
1091        let mesh = make_plane_grid(6);
1092
1093        let original_boundary: HashSet<(i32, i32, i32)> = {
1094            let size = 6;
1095            let mut set = HashSet::new();
1096            for i in 0..size {
1097                for j in 0..size {
1098                    if i == 0 || i == size - 1 || j == 0 || j == size - 1 {
1099                        let idx = i * size + j;
1100                        let p = mesh.vertices[idx];
1101                        set.insert((
1102                            (p.x * 100.0) as i32,
1103                            (p.y * 100.0) as i32,
1104                            (p.z * 100.0) as i32,
1105                        ));
1106                    }
1107                }
1108            }
1109            set
1110        };
1111
1112        let result = s.simplify(&mesh, 0.5).unwrap();
1113        let result_positions: HashSet<(i32, i32, i32)> = result
1114            .vertices
1115            .iter()
1116            .map(|p| {
1117                (
1118                    (p.x * 100.0) as i32,
1119                    (p.y * 100.0) as i32,
1120                    (p.z * 100.0) as i32,
1121                )
1122            })
1123            .collect();
1124
1125        // Boundary vertices should be mostly preserved (clustered separately)
1126        let preserved = original_boundary.intersection(&result_positions).count();
1127        let ratio = preserved as f32 / original_boundary.len() as f32;
1128        assert!(
1129            ratio > 0.5,
1130            "Expected >50% boundary preservation, got {:.1}%",
1131            ratio * 100.0
1132        );
1133    }
1134
1135    // ---- Sharp feature tests ----
1136
1137    #[test]
1138    fn test_sharp_feature_detection() {
1139        let mesh = make_sharp_edge_mesh();
1140        let feature_verts = find_feature_vertices(&mesh, 45.0_f32.to_radians());
1141        // Vertices along the sharp edge (indices 3, 4, 5) should be detected
1142        assert!(
1143            !feature_verts.is_empty(),
1144            "Should detect feature vertices at the 90-degree edge"
1145        );
1146    }
1147
1148    #[test]
1149    fn test_sharp_feature_preservation() {
1150        let s = ClusteringSimplifier::with_params(
1151            ClusteringMode::Uniform,
1152            RepresentativeStrategy::Centroid,
1153            true,
1154            45.0_f32.to_radians(),
1155        );
1156        let mesh = make_sharp_edge_mesh();
1157        let result = s.simplify(&mesh, 0.3).unwrap();
1158        assert!(result.face_count() > 0);
1159        assert!(result.vertex_count() > 0);
1160    }
1161
1162    // ---- Attribute preservation tests ----
1163
1164    #[test]
1165    fn test_attribute_preservation_normals() {
1166        let mut mesh = make_plane_grid(5);
1167        let normals: Vec<Vector3f> = (0..mesh.vertex_count())
1168            .map(|_| Vector3f::new(0.0, 0.0, 1.0))
1169            .collect();
1170        mesh.set_normals(normals);
1171
1172        let s = ClusteringSimplifier::new();
1173        let result = s.simplify(&mesh, 0.3).unwrap();
1174        assert!(result.normals.is_some(), "normals should be preserved");
1175        let result_normals = result.normals.as_ref().unwrap();
1176        assert_eq!(result_normals.len(), result.vertex_count());
1177        for n in result_normals {
1178            assert!(n.z > 0.9, "normal z should be close to 1.0, got {}", n.z);
1179        }
1180    }
1181
1182    #[test]
1183    fn test_attribute_preservation_colors() {
1184        let mut mesh = make_plane_grid(5);
1185        let colors: Vec<[u8; 3]> = (0..mesh.vertex_count()).map(|_| [128, 64, 200]).collect();
1186        mesh.set_colors(colors);
1187
1188        let s = ClusteringSimplifier::new();
1189        let result = s.simplify(&mesh, 0.3).unwrap();
1190        assert!(result.colors.is_some(), "colors should be preserved");
1191        assert_eq!(result.colors.as_ref().unwrap().len(), result.vertex_count());
1192    }
1193
    // ---- Comparison tests (every clustering strategy and mode produces valid, reduced output) ----
1195
1196    #[test]
1197    fn test_all_strategies_produce_valid_output() {
1198        let mesh = make_curved_surface(8);
1199        let strategies = [
1200            RepresentativeStrategy::Centroid,
1201            RepresentativeStrategy::WeightedAverage,
1202            RepresentativeStrategy::MinimumError,
1203        ];
1204
1205        for strategy in &strategies {
1206            let s = ClusteringSimplifier::with_params(
1207                ClusteringMode::Uniform,
1208                *strategy,
1209                true,
1210                45.0_f32.to_radians(),
1211            );
1212            let result = s.simplify(&mesh, 0.5).unwrap();
1213            assert!(
1214                result.face_count() > 0,
1215                "Strategy {:?} produced empty mesh",
1216                strategy
1217            );
1218            assert!(
1219                result.vertex_count() > 0,
1220                "Strategy {:?} produced no vertices",
1221                strategy
1222            );
1223            assert!(
1224                result.face_count() < mesh.face_count(),
1225                "Strategy {:?} did not reduce faces",
1226                strategy
1227            );
1228        }
1229    }
1230
1231    #[test]
1232    fn test_both_modes_produce_valid_output() {
1233        let mesh = make_curved_surface(8);
1234
1235        let uniform = ClusteringSimplifier::new();
1236        let adaptive = ClusteringSimplifier::with_params(
1237            ClusteringMode::Adaptive {
1238                max_depth: 4,
1239                error_threshold: 0.1,
1240            },
1241            RepresentativeStrategy::Centroid,
1242            true,
1243            45.0_f32.to_radians(),
1244        );
1245
1246        let r1 = uniform.simplify(&mesh, 0.5).unwrap();
1247        let r2 = adaptive.simplify(&mesh, 0.5).unwrap();
1248
1249        assert!(r1.face_count() > 0);
1250        assert!(r2.face_count() > 0);
1251        assert!(r1.face_count() < mesh.face_count());
1252        assert!(r2.face_count() < mesh.face_count());
1253    }
1254}