xycut_plus_plus/
core.rs

1use core::f32;
2
3use crate::histogram::{build_horizontal_histogram, build_vertical_histogram, find_largest_gap};
4use crate::matching::partition_by_mask;
5use crate::traits::{BoundingBox, SemanticLabel};
6use crate::utils::compute_distance_with_early_exit;
7
/// Tunable parameters for the XY-Cut reading-order algorithm.
#[derive(Debug, Clone, PartialEq)]
pub struct XYCutConfig {
    /// Minimum gap size (in pixels) to consider for cutting.
    pub min_cut_threshold: f32,

    /// Projection-histogram resolution, expressed in bins per pixel.
    ///
    /// The cut search multiplies a pixel extent by this factor to obtain the
    /// number of histogram bins, so `0.5` means one bin for every two pixels.
    pub histogram_resolution_scale: f32,

    /// Tolerance (in pixels) between vertical centers below which two
    /// elements are treated as belonging to the same row.
    pub same_row_tolerance: f32,
}
20
21impl Default for XYCutConfig {
22    fn default() -> Self {
23        Self {
24            min_cut_threshold: 15.0,
25            histogram_resolution_scale: 0.5, // 1 bin per 2 pixels
26            same_row_tolerance: 10.0,
27        }
28    }
29}
30
31pub struct XYCutPlusPlus {
32    config: XYCutConfig,
33}
34
35impl XYCutPlusPlus {
36    pub fn new(config: XYCutConfig) -> Self {
37        Self { config }
38    }
39
40    /// Main entry point: compute reading order for elements
41    pub fn compute_order<T: BoundingBox>(
42        &self,
43        elements: &[T],
44        x_min: f32,
45        y_min: f32,
46        x_max: f32,
47        y_max: f32,
48    ) -> Vec<usize> {
49        // Validate empty input
50        if elements.is_empty() {
51            return Vec::new();
52        }
53
54        let page_width = x_max - x_min;
55        let page_height = y_max - y_min;
56
57        // Validate page dimensions
58        if !page_width.is_finite()
59            || !page_height.is_finite()
60            || page_width <= 0.0
61            || page_height <= 0.0
62        {
63            eprintln!(
64                "Warning: Invalid page dimensions ({}, {})",
65                page_width, page_height
66            );
67
68            return Vec::new();
69        }
70
71        let partition = partition_by_mask(elements, page_width, page_height);
72        let regular_order =
73            self.recursive_cut(&partition.regular_elements, x_min, y_min, x_max, y_max);
74
75        self.merged_masked_elements(
76            &partition.regular_elements,
77            &regular_order,
78            &partition.masked_elements,
79        )
80    }
81
82    // TODO: Add this function before recursive_cut
83    /// Calculate density ratio τd (tau_d) from Equation 4-5
84    /// τd = Σ(w_k^(Cc) / h_k^(Cc)) / Σ(w_k^(Cs) / h_k^(Cs))
85    fn compute_density_ratio<T: BoundingBox>(elements: &[T]) -> f32 {
86        let mut cross_layout_density = 0.0; // Cc - wide elements
87        let mut single_layout_density = 0.0; // Cs - narrow elements
88
89        for element in elements {
90            let (x1, y1, x2, y2) = element.bounds();
91            let width = x2 - x1;
92            let height = y2 - y1;
93
94            // Avoid division by zero
95            if height == 0.0 {
96                continue;
97            }
98
99            let aspect_ratio = width / height;
100
101            // Use semantic label instead of width threshold
102            match element.semantic_label() {
103                SemanticLabel::CrossLayout => cross_layout_density += aspect_ratio,
104                _ => single_layout_density += aspect_ratio,
105            }
106        }
107
108        // Return the ratio τd = cross_layout_density / single_layout_density
109        // Handle division by zero: if single_layout_density == 0.0, return 1.0
110        if single_layout_density == 0.0 {
111            return 1.0;
112        }
113
114        cross_layout_density / single_layout_density
115    }
116
117    fn recursive_cut<T: BoundingBox>(
118        &self,
119        elements: &[T],
120        x_min: f32,
121        y_min: f32,
122        x_max: f32,
123        y_max: f32,
124    ) -> Vec<usize> {
125        if elements.is_empty() {
126            return Vec::new();
127        }
128        if elements.len() == 1 {
129            return vec![elements[0].id()];
130        }
131
132        // Equation 4: Calculate density ration τd
133        let tau_d = Self::compute_density_ratio(elements);
134
135        // Equation 5: Use XY-Cut (vertical first) if τd > 0.9
136        let try_vertical_first = tau_d > 0.9;
137
138        if try_vertical_first {
139            // Try vertical cut first for multi-column layouts
140            if let Some(x_cut) = self.find_vertical_cut(elements, x_min, x_max) {
141                eprintln!(
142                    "  [XYCut] Vertical cut at x={:.0}, splitting {} elements (multi-column)",
143                    x_cut,
144                    elements.len()
145                );
146                let (left, right) = self.split_vertical(elements, x_cut);
147                eprintln!(
148                    "    → Left: {} elements, Right: {} elements",
149                    left.len(),
150                    right.len()
151                );
152                let mut result = Vec::new();
153                result.extend(self.recursive_cut(&left, x_min, y_min, x_cut, y_max));
154                result.extend(self.recursive_cut(&right, x_cut, y_min, x_max, y_max));
155                return result;
156            }
157        }
158
159        // Try horizontal cut first (top-to-bottom reading)
160        if let Some(y_cut) = self.find_horizontal_cut(elements, y_min, y_max) {
161            eprintln!(
162                "  [XYCut] Horizontal cut at y={:.0}, splitting {} elements",
163                y_cut,
164                elements.len()
165            );
166            let (top, bottom) = self.split_horizontal(elements, y_cut);
167            eprintln!(
168                "    → Top: {} elements, Bottom: {} elements",
169                top.len(),
170                bottom.len()
171            );
172            let mut result = Vec::new();
173            result.extend(self.recursive_cut(&top, x_min, y_min, x_max, y_cut));
174            result.extend(self.recursive_cut(&bottom, x_min, y_cut, x_max, y_max));
175            return result;
176        }
177
178        // Try vertical cut (left-to-right for multi-column)
179        if let Some(x_cut) = self.find_vertical_cut(elements, x_min, x_max) {
180            eprintln!(
181                "  [XYCut] Vertical cut at x={:.0}, splitting {} elements",
182                x_cut,
183                elements.len()
184            );
185            let (left, right) = self.split_vertical(elements, x_cut);
186            eprintln!(
187                "    → Left: {} elements, Right: {} elements",
188                left.len(),
189                right.len()
190            );
191            let mut result = Vec::new();
192            result.extend(self.recursive_cut(&left, x_min, y_min, x_cut, y_max));
193            result.extend(self.recursive_cut(&right, x_cut, y_min, x_max, y_max));
194            return result;
195        }
196
197        // No valid cuts found - sort by position
198        eprintln!(
199            "  [XYCut] No cuts found, sorting {} elements by position",
200            elements.len()
201        );
202        self.sort_by_position(elements)
203    }
204
205    /// Find horizontal cut position using projection histogram
206    /// Returns y-coordinate where to split, or None if no good cut found
207    fn find_horizontal_cut<T: BoundingBox>(
208        &self,
209        elements: &[T],
210        y_min: f32,
211        y_max: f32,
212    ) -> Option<f32> {
213        let resolution = ((y_max - y_min) * self.config.histogram_resolution_scale) as usize;
214        let histogram = build_horizontal_histogram(elements, y_min, y_max, resolution);
215
216        let min_gap_bins =
217            (self.config.min_cut_threshold * self.config.histogram_resolution_scale) as usize;
218
219        let bin_index = find_largest_gap(&histogram, min_gap_bins);
220
221        if let Some(bin_index) = bin_index {
222            let y_coord = y_min + (bin_index as f32 / resolution as f32) * (y_max - y_min);
223            return Some(y_coord);
224        }
225
226        None
227    }
228
229    /// Find vertical cut position using projection histogram
230    /// Returns x-coordinate where to split, or None if no good cut found
231    fn find_vertical_cut<T: BoundingBox>(
232        &self,
233        elements: &[T],
234        x_min: f32,
235        x_max: f32,
236    ) -> Option<f32> {
237        let resolution = ((x_max - x_min) * self.config.histogram_resolution_scale) as usize;
238        let histogram = build_vertical_histogram(elements, x_min, x_max, resolution);
239
240        let min_gap_bins =
241            (self.config.min_cut_threshold * self.config.histogram_resolution_scale) as usize;
242
243        // Debug: show histogram for large element counts
244        if elements.len() > 15 {
245            eprintln!(
246                "    [Histogram] Vertical: {} bins, min_gap={}, x_range={:.0}-{:.0}",
247                resolution, min_gap_bins, x_min, x_max
248            );
249        }
250
251        let bin_index = find_largest_gap(&histogram, min_gap_bins);
252        if let Some(bin_index) = bin_index {
253            let x_coord = x_min + (bin_index as f32 / resolution as f32) * (x_max - x_min);
254            if elements.len() > 15 {
255                eprintln!(
256                    "    [Histogram] Found gap at bin {}, x={:.0}",
257                    bin_index, x_coord
258                );
259            }
260            return Some(x_coord);
261        }
262
263        None
264    }
265
266    /// Split elements into top and bottom groups based on y-coordinate cut
267    fn split_horizontal<T: BoundingBox>(&self, elements: &[T], y_cut: f32) -> (Vec<T>, Vec<T>) {
268        let mut top: Vec<T> = Vec::new();
269        let mut bottom: Vec<T> = Vec::new();
270
271        for element in elements.iter() {
272            if element.center().1 < y_cut {
273                top.push(element.clone());
274            } else {
275                bottom.push(element.clone())
276            }
277        }
278
279        (top, bottom)
280    }
281
282    /// Split elements into left and right groups based on x-coordinate cut
283    fn split_vertical<T: BoundingBox>(&self, elements: &[T], x_cut: f32) -> (Vec<T>, Vec<T>) {
284        let mut left: Vec<T> = Vec::new();
285        let mut right: Vec<T> = Vec::new();
286
287        for element in elements.iter() {
288            if element.center().0 < x_cut {
289                left.push(element.clone());
290            } else {
291                right.push(element.clone());
292            }
293        }
294
295        (left, right)
296    }
297
298    /// Fallback sorting when no valid cuts found
299    /// Sort by y-position first (top to bottom), then x-position (left to right)
300    fn sort_by_position<T: BoundingBox>(&self, elements: &[T]) -> Vec<usize> {
301        let mut indexed: Vec<(usize, T)> = elements
302            .iter()
303            .enumerate()
304            .map(|(i, bbox)| (i, bbox.clone()))
305            .collect();
306
307        indexed.sort_by(|a, b| {
308            let y_diff = (a.1.center().1 - b.1.center().1).abs();
309            if y_diff < self.config.same_row_tolerance {
310                // Same row - sort by x
311                a.1.center()
312                    .0
313                    .partial_cmp(&b.1.center().0)
314                    .unwrap_or(std::cmp::Ordering::Equal)
315            } else {
316                // Different rows - sort by y
317                a.1.center()
318                    .1
319                    .partial_cmp(&b.1.center().1)
320                    .unwrap_or(std::cmp::Ordering::Equal)
321            }
322        });
323
324        indexed.iter().map(|(_, bbox)| bbox.id()).collect()
325    }
326
327    fn compute_page_width<T: BoundingBox>(&self, elements: &[T]) -> f32 {
328        if elements.is_empty() {
329            return 0.0;
330        }
331        let x_min = elements
332            .iter()
333            .map(|e| e.bounds().0)
334            .fold(f32::INFINITY, f32::min);
335        let x_max = elements
336            .iter()
337            .map(|e| e.bounds().2)
338            .fold(f32::NEG_INFINITY, f32::max);
339
340        x_max - x_min
341    }
342
343    fn merged_masked_elements<T: BoundingBox>(
344        &self,
345        regular_elements: &[T],
346        regular_order: &[usize],
347        masked_elements: &[T],
348    ) -> Vec<usize> {
349        // Start with regular order as base
350        let mut result: Vec<usize> = regular_order.to_vec();
351
352        let mut priority_groups: Vec<Vec<T>> = vec![Vec::new(); 4];
353        for element in masked_elements {
354            let priority = Self::label_priority(element.semantic_label()) as usize;
355            if priority < 4 {
356                priority_groups[priority].push(element.clone());
357            }
358        }
359
360        // Process each priority group in order (CrossLayout → Title → Vision → Regular)
361        for mut group in priority_groups {
362            // Within each priority group, sort by reading order (y, then x)
363            group.sort_by(|a, b| {
364                let y_diff = (a.center().1 - b.center().1).abs();
365                if y_diff < self.config.same_row_tolerance {
366                    a.center()
367                        .0
368                        .partial_cmp(&b.center().0)
369                        .unwrap_or(std::cmp::Ordering::Equal)
370                } else {
371                    a.center()
372                        .1
373                        .partial_cmp(&b.center().1)
374                        .unwrap_or(std::cmp::Ordering::Equal)
375                }
376            });
377
378            // Process each element in this priority group
379            for masked in &group {
380                // Find the best insertion position using 4-component distance metric
381                let mut best_distance = f32::INFINITY;
382                let mut best_position: Option<usize> = None;
383
384                // Get masked element's semantic priority for constraint checking
385                let masked_priority = Self::label_priority(masked.semantic_label());
386
387                // Search through result to handle growing array correctly
388                for (idx, &elem_id) in result.iter().enumerate() {
389                    // Find the element - could be regular OR previously inserted masked
390                    let candidate = regular_elements
391                        .iter()
392                        .find(|e| e.id() == elem_id)
393                        .cloned()
394                        .or_else(|| {
395                            // Also check masked elements from ALL groups
396                            masked_elements.iter().find(|e| e.id() == elem_id).cloned()
397                        });
398
399                    if let Some(candidate) = candidate {
400                        // Enforce L'o ⪰ l constraint (Equation 7)
401                        let candidate_priority = Self::label_priority(candidate.semantic_label());
402                        if candidate_priority < masked_priority {
403                            continue;
404                        }
405
406                        // Use 4-component distance metric
407                        let distance =
408                            compute_distance_with_early_exit(masked, &candidate, best_distance);
409                        if distance < best_distance {
410                            best_distance = distance;
411                            best_position = Some(idx);
412                        }
413                    }
414                }
415
416                if let Some(position) = best_position {
417                    eprintln!(
418                        "  [INSERT] Masked element {} ({:?}) -> position {} (before element {})",
419                        masked.id(),
420                        masked.semantic_label(),
421                        position,
422                        result[position]
423                    );
424                    result.insert(position, masked.id());
425                } else {
426                    // No valid match found - append to end as a fallback
427                    eprintln!(
428                        "⚠️  No valid insertion for element {} ({:?}), appending",
429                        masked.id(),
430                        masked.semantic_label()
431                    );
432                    result.push(masked.id());
433                }
434            }
435        }
436        result
437    }
438
439    /// Get priority value for semantic label (lower = higher priority)
440    fn label_priority(label: SemanticLabel) -> u8 {
441        match label {
442            SemanticLabel::CrossLayout => 0,
443            SemanticLabel::HorizontalTitle => 1,
444            SemanticLabel::VerticalTitle => 1,
445            SemanticLabel::Vision => 2,
446            SemanticLabel::Regular => 3,
447        }
448    }
449}