ruvector_scipix/preprocess/
segmentation.rs

1//! Text region detection and segmentation
2
3use super::{RegionType, Result, TextRegion};
4use image::GrayImage;
5use std::collections::{HashMap, HashSet};
6
7/// Find text regions in a binary or grayscale image
8///
9/// Uses connected component analysis and geometric heuristics to identify
10/// text regions and classify them by type (text, math, table, etc.)
11///
12/// # Arguments
13/// * `image` - Input grayscale or binary image
14/// * `min_region_size` - Minimum region area in pixels
15///
16/// # Returns
17/// Vector of detected text regions with bounding boxes
18///
19/// # Example
20/// ```no_run
21/// use ruvector_scipix::preprocess::segmentation::find_text_regions;
22/// # use image::GrayImage;
23/// # let image = GrayImage::new(100, 100);
24/// let regions = find_text_regions(&image, 100).unwrap();
25/// println!("Found {} regions", regions.len());
26/// ```
27pub fn find_text_regions(image: &GrayImage, min_region_size: u32) -> Result<Vec<TextRegion>> {
28    // Find connected components
29    let components = connected_components(image);
30
31    // Extract bounding boxes for each component
32    let bboxes = extract_bounding_boxes(&components);
33
34    // Filter by size and merge overlapping regions
35    let filtered = filter_by_size(bboxes, min_region_size);
36    let merged = merge_overlapping_regions(filtered, 10);
37
38    // Find text lines and group components
39    let text_lines = find_text_lines(image, &merged);
40
41    // Classify regions and create TextRegion objects
42    let regions = classify_regions(image, text_lines);
43
44    Ok(regions)
45}
46
47/// Connected component labeling using flood-fill algorithm
48///
49/// Returns labeled image where each connected component has a unique ID
50fn connected_components(image: &GrayImage) -> Vec<Vec<u32>> {
51    let (width, height) = image.dimensions();
52    let mut labels = vec![vec![0u32; width as usize]; height as usize];
53    let mut current_label = 1u32;
54
55    for y in 0..height {
56        for x in 0..width {
57            if labels[y as usize][x as usize] == 0 && image.get_pixel(x, y)[0] < 128 {
58                // Found unlabeled foreground pixel, start flood fill
59                flood_fill(image, &mut labels, x, y, current_label);
60                current_label += 1;
61            }
62        }
63    }
64
65    labels
66}
67
68/// Flood fill algorithm for connected component labeling
69fn flood_fill(
70    image: &GrayImage,
71    labels: &mut [Vec<u32>],
72    start_x: u32,
73    start_y: u32,
74    label: u32,
75) {
76    let (width, height) = image.dimensions();
77    let mut stack = vec![(start_x, start_y)];
78
79    while let Some((x, y)) = stack.pop() {
80        if x >= width || y >= height {
81            continue;
82        }
83
84        if labels[y as usize][x as usize] != 0 || image.get_pixel(x, y)[0] >= 128 {
85            continue;
86        }
87
88        labels[y as usize][x as usize] = label;
89
90        // Add 4-connected neighbors
91        if x > 0 {
92            stack.push((x - 1, y));
93        }
94        if x < width - 1 {
95            stack.push((x + 1, y));
96        }
97        if y > 0 {
98            stack.push((x, y - 1));
99        }
100        if y < height - 1 {
101            stack.push((x, y + 1));
102        }
103    }
104}
105
/// Compute the bounding box of every labeled component.
///
/// `labels` is a row-major grid (`labels[y][x]`) in which `0` means
/// background and any other value is a component ID. The returned map has one
/// entry per ID, holding `(x, y, width, height)` of the smallest axis-aligned
/// rectangle that contains all of that component's pixels.
fn extract_bounding_boxes(labels: &[Vec<u32>]) -> HashMap<u32, (u32, u32, u32, u32)> {
    // Per label: inclusive extents (left, top, right, bottom).
    let mut extents: HashMap<u32, (u32, u32, u32, u32)> = HashMap::new();

    let labeled_pixels = labels.iter().enumerate().flat_map(|(row_idx, row)| {
        row.iter()
            .enumerate()
            .filter(|&(_, &label)| label != 0)
            .map(move |(col_idx, &label)| (label, col_idx as u32, row_idx as u32))
    });

    for (label, px, py) in labeled_pixels {
        extents
            .entry(label)
            .and_modify(|extent| {
                extent.0 = extent.0.min(px);
                extent.1 = extent.1.min(py);
                extent.2 = extent.2.max(px);
                extent.3 = extent.3.max(py);
            })
            .or_insert((px, py, px, py));
    }

    // Inclusive extents -> (x, y, width, height); hence the `+ 1`.
    extents
        .into_iter()
        .map(|(label, (left, top, right, bottom))| {
            (label, (left, top, right - left + 1, bottom - top + 1))
        })
        .collect()
}
141
/// Keep only the bounding boxes whose area is at least `min_size` pixels.
///
/// Each box is `(x, y, width, height)`. The component labels (map keys) are
/// dropped; only the boxes are returned, in the map's iteration order, which
/// is unspecified.
///
/// The area is computed in `u64` so that very large boxes (where
/// `width * height` exceeds `u32::MAX`) are compared correctly instead of
/// overflowing.
fn filter_by_size(
    bboxes: HashMap<u32, (u32, u32, u32, u32)>,
    min_size: u32,
) -> Vec<(u32, u32, u32, u32)> {
    let min_area = u64::from(min_size);

    bboxes
        .into_values()
        .filter(|&(_, _, w, h)| u64::from(w) * u64::from(h) >= min_area)
        .collect()
}
152
153/// Merge overlapping or nearby regions
154///
155/// # Arguments
156/// * `regions` - Vector of bounding boxes (x, y, width, height)
157/// * `merge_distance` - Maximum distance to merge regions
158pub fn merge_overlapping_regions(
159    regions: Vec<(u32, u32, u32, u32)>,
160    merge_distance: u32,
161) -> Vec<(u32, u32, u32, u32)> {
162    if regions.is_empty() {
163        return regions;
164    }
165
166    let mut merged = Vec::new();
167    let mut used = HashSet::new();
168
169    for i in 0..regions.len() {
170        if used.contains(&i) {
171            continue;
172        }
173
174        let mut current = regions[i];
175        let mut changed = true;
176
177        while changed {
178            changed = false;
179
180            for j in (i + 1)..regions.len() {
181                if used.contains(&j) {
182                    continue;
183                }
184
185                if boxes_overlap_or_close(&current, &regions[j], merge_distance) {
186                    current = merge_boxes(&current, &regions[j]);
187                    used.insert(j);
188                    changed = true;
189                }
190            }
191        }
192
193        merged.push(current);
194        used.insert(i);
195    }
196
197    merged
198}
199
/// Check if two bounding boxes overlap or are close
///
/// Both boxes are `(x, y, width, height)`. Returns `true` when, on each axis,
/// the start of either box is no further than `distance` past the end of the
/// other, i.e. the boxes overlap, touch, or are separated by a gap of at most
/// `distance` pixels horizontally and vertically.
///
/// Additions saturate at `u32::MAX`, so a very large `distance` (or a box
/// near the top of the `u32` range) means "always within reach" instead of
/// overflowing.
fn boxes_overlap_or_close(
    box1: &(u32, u32, u32, u32),
    box2: &(u32, u32, u32, u32),
    distance: u32,
) -> bool {
    let (x1, y1, w1, h1) = *box1;
    let (x2, y2, w2, h2) = *box2;

    // Far edge of a span, pushed outwards by the allowed gap.
    let reach = |origin: u32, extent: u32| origin.saturating_add(extent).saturating_add(distance);

    let near_x = x1 <= reach(x2, w2) && x2 <= reach(x1, w1);
    let near_y = y1 <= reach(y2, h2) && y2 <= reach(y1, h1);

    near_x && near_y
}
220
/// Return the smallest box that contains both input boxes.
///
/// Inputs and output are `(x, y, width, height)`; a box covers the half-open
/// ranges `x..x + width` and `y..y + height`.
fn merge_boxes(
    box1: &(u32, u32, u32, u32),
    box2: &(u32, u32, u32, u32),
) -> (u32, u32, u32, u32) {
    let &(ax, ay, aw, ah) = box1;
    let &(bx, by, bw, bh) = box2;

    // Outer edges of the union on each axis.
    let left = ax.min(bx);
    let top = ay.min(by);
    let right = (ax + aw).max(bx + bw);
    let bottom = (ay + ah).max(by + bh);

    (left, top, right - left, bottom - top)
}
236
237/// Find text lines using projection profiles
238///
239/// Groups regions into lines based on vertical alignment
240pub fn find_text_lines(
241    _image: &GrayImage,
242    regions: &[(u32, u32, u32, u32)],
243) -> Vec<Vec<(u32, u32, u32, u32)>> {
244    if regions.is_empty() {
245        return Vec::new();
246    }
247
248    // Sort regions by y-coordinate
249    let mut sorted_regions = regions.to_vec();
250    sorted_regions.sort_by_key(|r| r.1);
251
252    let mut lines = Vec::new();
253    let mut current_line = vec![sorted_regions[0]];
254
255    for region in sorted_regions.iter().skip(1) {
256        let (_, y, _, h) = region;
257        let (_, prev_y, _, prev_h) = current_line.last().unwrap();
258
259        // Check if region is on the same line (vertical overlap)
260        let line_height = (*prev_h).max(*h);
261        let distance = if y > prev_y {
262            y - prev_y
263        } else {
264            prev_y - y
265        };
266
267        if distance < line_height / 2 {
268            current_line.push(*region);
269        } else {
270            lines.push(current_line.clone());
271            current_line = vec![*region];
272        }
273    }
274
275    if !current_line.is_empty() {
276        lines.push(current_line);
277    }
278
279    lines
280}
281
282/// Classify regions by type (text, math, table, etc.)
283fn classify_regions(
284    image: &GrayImage,
285    text_lines: Vec<Vec<(u32, u32, u32, u32)>>,
286) -> Vec<TextRegion> {
287    let mut regions = Vec::new();
288
289    for line in text_lines {
290        for bbox in line {
291            let (x, y, width, height) = bbox;
292
293            // Calculate features for classification
294            let aspect_ratio = width as f32 / height as f32;
295            let density = calculate_density(image, bbox);
296
297            // Simple heuristic classification
298            let region_type = if aspect_ratio > 10.0 {
299                // Very wide region might be a table or figure caption
300                RegionType::Table
301            } else if aspect_ratio < 0.5 && height > 50 {
302                // Tall region might be a figure
303                RegionType::Figure
304            } else if density > 0.3 && height < 30 {
305                // Dense, small region likely math
306                RegionType::Math
307            } else {
308                // Default to text
309                RegionType::Text
310            };
311
312            regions.push(TextRegion {
313                region_type,
314                bbox: (x, y, width, height),
315                confidence: 0.8, // Default confidence
316                text_height: height as f32,
317                baseline_angle: 0.0,
318            });
319        }
320    }
321
322    regions
323}
324
325/// Calculate pixel density in a region
326fn calculate_density(image: &GrayImage, bbox: (u32, u32, u32, u32)) -> f32 {
327    let (x, y, width, height) = bbox;
328    let total_pixels = (width * height) as f32;
329
330    if total_pixels == 0.0 {
331        return 0.0;
332    }
333
334    let mut foreground_pixels = 0;
335
336    for py in y..(y + height) {
337        for px in x..(x + width) {
338            if image.get_pixel(px, py)[0] < 128 {
339                foreground_pixels += 1;
340            }
341        }
342    }
343
344    foreground_pixels as f32 / total_pixels
345}
346
#[cfg(test)]
mod tests {
    use super::*;
    use image::Luma;

    /// Paint the half-open rectangle `xs` × `ys` black.
    fn fill_black(img: &mut GrayImage, xs: std::ops::Range<u32>, ys: std::ops::Range<u32>) {
        for y in ys {
            for x in xs.clone() {
                img.put_pixel(x, y, Luma([0]));
            }
        }
    }

    /// White 200×200 image with three black bars standing in for text lines.
    fn create_test_image_with_rectangles() -> GrayImage {
        let mut img = GrayImage::from_pixel(200, 200, Luma([255]));

        fill_black(&mut img, 20..100, 20..40);
        fill_black(&mut img, 20..120, 60..80);
        fill_black(&mut img, 20..80, 100..120);

        img
    }

    #[test]
    fn test_find_text_regions() {
        let img = create_test_image_with_rectangles();
        let outcome = find_text_regions(&img, 100);

        assert!(outcome.is_ok());
        let found = outcome.unwrap();

        // One region per bar at the very least.
        assert!(found.len() >= 3);

        for region in found {
            println!("Region: {:?} at {:?}", region.region_type, region.bbox);
        }
    }

    #[test]
    fn test_connected_components() {
        let img = create_test_image_with_rectangles();
        let label_grid = connected_components(&img);

        // At least one pixel must have received a component label.
        let highest = label_grid.iter().flatten().copied().max().unwrap_or(0);

        assert!(highest > 0);
    }

    #[test]
    fn test_merge_overlapping_regions() {
        let overlapping_a = (10, 10, 50, 20);
        let overlapping_b = (40, 10, 50, 20);
        let far_away = (100, 100, 30, 30);

        let merged = merge_overlapping_regions(vec![overlapping_a, overlapping_b, far_away], 10);

        // The overlapping pair collapses into one box; the distant one survives.
        assert_eq!(merged.len(), 2);
    }

    #[test]
    fn test_merge_boxes() {
        let (x, y, width, height) = merge_boxes(&(10, 10, 50, 20), &(40, 15, 30, 25));

        assert_eq!(x, 10); // min x
        assert_eq!(y, 10); // min y
        assert!(width >= 50);
        assert!(height >= 25);
    }

    #[test]
    fn test_boxes_overlap() {
        let left = (10, 10, 50, 20);
        let right = (40, 10, 50, 20);

        for distance in [0, 10] {
            assert!(boxes_overlap_or_close(&left, &right, distance));
        }
    }

    #[test]
    fn test_boxes_dont_overlap() {
        let near_origin = (10, 10, 20, 20);
        let far_corner = (100, 100, 20, 20);

        assert!(!boxes_overlap_or_close(&near_origin, &far_corner, 0));
    }

    #[test]
    fn test_find_text_lines() {
        // Two rows of two boxes each; the second box of a row sits 2px lower.
        let regions = vec![
            (10, 10, 50, 20),
            (70, 12, 50, 20),
            (10, 50, 50, 20),
            (70, 52, 50, 20),
        ];

        let img = GrayImage::new(200, 100);
        let lines = find_text_lines(&img, &regions);

        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0].len(), 2);
        assert_eq!(lines[1].len(), 2);
    }

    #[test]
    fn test_calculate_density() {
        let mut img = GrayImage::new(100, 100);

        // Checkerboard: every other pixel in the 20×20 block is black.
        for y in 10..30 {
            for x in 10..30 {
                let val = if (x + y) % 2 == 0 { 0 } else { 255 };
                img.put_pixel(x, y, Luma([val]));
            }
        }

        let density = calculate_density(&img, (10, 10, 20, 20));
        assert!((density - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_filter_by_size() {
        let bboxes: HashMap<u32, (u32, u32, u32, u32)> = [
            (1, (10, 10, 50, 50)),   // 2500 pixels
            (2, (100, 100, 10, 10)), // 100 pixels
            (3, (200, 200, 30, 30)), // 900 pixels
        ]
        .iter()
        .copied()
        .collect();

        let filtered = filter_by_size(bboxes, 500);

        // Only boxes 1 and 3 reach the 500-pixel threshold.
        assert_eq!(filtered.len(), 2);
    }
}
503}