Skip to main content

ruranges_core/
nearest.rs

1use std::str::FromStr;
2
3use radsort::sort_by_key;
4
5use crate::{
6    overlaps::overlaps,
7    ruranges_structs::{GroupType, MinEvent, Nearest, OverlapPair, PositionType},
8    sorts::build_sorted_events_single_collection_separate_outputs,
9};
10
11/// For each MinEvent in `sorted_ends`, find up to `k` *unique positions*
12/// in `sorted_starts2` that lie to the right (including equal position on the
13/// same chromosome). If multiple entries in `sorted_starts2` share the same
14/// position, they all get reported, but they count as one unique position.
15pub fn nearest_intervals_to_the_right<C: GroupType, T: PositionType>(
16    sorted_ends: Vec<MinEvent<C, T>>,
17    sorted_starts2: Vec<MinEvent<C, T>>,
18    k: usize,
19) -> Vec<Nearest<T>> {
20    // We might need more than `sorted_ends.len()` because each end could
21    // contribute up to `k` *unique positions* (potentially multiplied by the
22    // number of intervals sharing those positions). So we set capacity
23    // accordingly.
24    // This is not strictly required, but it helps performance to reserve enough space.
25    let mut output = Vec::with_capacity(sorted_ends.len().saturating_mul(k));
26
27    let n_starts = sorted_starts2.len();
28
29    // `j` will track our position in sorted_starts2 as we move through sorted_ends.
30    let mut j = 0usize;
31
32    // Iterate over each 'end' event
33    for end in &sorted_ends {
34        let end_chr = end.chr;
35        let end_pos = end.pos;
36
37        // Advance `j` so that sorted_starts2[j] is the first start
38        // that is >= end_pos on the same chrom (or beyond).
39        // Because both arrays are sorted, we never need to move `j` backward.
40        while j < n_starts {
41            let start = &sorted_starts2[j];
42            if start.chr < end_chr {
43                // still on a smaller chromosome; move j forward
44                j += 1;
45            } else if start.chr == end_chr && start.pos < end_pos {
46                // same chrom but still to the left; move j forward
47                j += 1;
48            } else {
49                // now start.chr > end_chr (i.e. next chromosome) OR
50                // start.chr == end_chr && start.pos >= end_pos
51                // -> we've reached a region that is "to the right" or next chrom
52                break;
53            }
54        }
55
56        // Now collect up to k unique positions (on the same chromosome).
57        let mut unique_count = 0;
58        let mut last_pos: Option<T> = None;
59
60        // We'll scan from `j` onward, but we do NOT move `j` itself
61        // because the next 'end' might need a similar or slightly advanced position.
62        // Instead, we use `local_idx` to look ahead for this specific end.
63        let mut local_idx = j;
64        while local_idx < n_starts {
65            let start = &sorted_starts2[local_idx];
66
67            // If we've passed beyond the chromosome of this end, we won't find
68            // any more right-side intervals for this end.
69            if start.chr != end_chr {
70                break;
71            }
72
73            // Check if we're at a new unique position
74            if last_pos.map_or(true, |lp| start.pos != lp) {
75                unique_count += 1;
76                if unique_count > k {
77                    // we've reached the limit of k unique positions
78                    break;
79                }
80                last_pos = Some(start.pos);
81            }
82
83            // This start is included in the results
84            let distance = start.pos - end_pos + T::one(); // can be 0 or positive
85            output.push(Nearest {
86                distance,
87                idx: end.idx,
88                idx2: start.idx,
89            });
90
91            local_idx += 1;
92        }
93    }
94
95    output
96}
97
98/// For each MinEvent in `sorted_ends`, find up to `k` *unique positions*
99/// in `sorted_starts2` that lie to the left (strictly smaller position on
100/// the same chromosome). If multiple entries in `sorted_starts2` share
101/// the same position, they all get reported, but they count as one
102/// unique position in the limit `k`.
103pub fn nearest_intervals_to_the_left<C: GroupType, T: PositionType>(
104    sorted_ends: Vec<MinEvent<C, T>>,
105    sorted_starts2: Vec<MinEvent<C, T>>,
106    k: usize,
107) -> Vec<Nearest<T>> {
108    // The max possible size is (number of ends) * (k + duplicates at each of those k positions).
109    // We reserve a rough upper bound for efficiency.
110    let mut output = Vec::with_capacity(sorted_ends.len().saturating_mul(k));
111
112    let n_starts = sorted_starts2.len();
113    let mut j = 0_usize; // Points into sorted_starts2
114
115    for end in &sorted_ends {
116        let end_chr = end.chr;
117        let end_pos = end.pos;
118
119        // Move `j` forward so that:
120        // - All start events at indices < j have start.chr < end_chr
121        //   OR (start.chr == end_chr && start.pos < end_pos).
122        // - Equivalently, sorted_starts2[j] is the *first* event that is NOT
123        //   strictly to the left of `end`.
124        while j < n_starts {
125            let start = &sorted_starts2[j];
126            if start.chr < end_chr {
127                // still a smaller chromosome => definitely to the left
128                j += 1;
129            } else if start.chr == end_chr && start.pos < end_pos {
130                // same chrom, smaller position => to the left
131                j += 1;
132            } else {
133                // we've reached a start that is not to the left
134                break;
135            }
136        }
137
138        // Now, everything in [0..j) is strictly to the left of `end`.
139        // We'll look backwards from j-1 to gather up to k unique positions
140        // on the same chromosome.
141        if j == 0 {
142            // No intervals to the left; skip
143            continue;
144        }
145
146        let mut local_idx = j - 1;
147        let mut unique_count = 0;
148        let mut last_pos: Option<T> = None;
149
150        // Descend from j-1 down to 0 (or until we break).
151        loop {
152            let start = &sorted_starts2[local_idx];
153
154            // Must match the same chromosome
155            if start.chr != end_chr {
156                break;
157            }
158
159            // Check if we have a new (unique) position
160            if last_pos.map_or(true, |lp| start.pos != lp) {
161                unique_count += 1;
162                if unique_count > k {
163                    break;
164                }
165                last_pos = Some(start.pos);
166            }
167
168            // Calculate the distance (end.pos - start.pos)
169            // Here, start.pos < end.pos by definition if we get here.
170            let distance = end_pos - start.pos + T::one();
171            output.push(Nearest {
172                distance,
173                idx: end.idx,    // the 'end' event's idx
174                idx2: start.idx, // the 'start' event's idx
175            });
176
177            if local_idx == 0 {
178                break;
179            }
180            local_idx -= 1;
181        }
182    }
183
184    output
185}
186
/// Which side of a query interval is searched for neighbours.
///
/// Parsed case-insensitively from `"forward"`, `"backward"` or `"any"` via
/// [`FromStr`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    Forward,
    Backward,
    Any,
}

impl FromStr for Direction {
    type Err = &'static str;

    /// Converts a direction name into a [`Direction`], ignoring letter case.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid direction string"` for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let names = [
            ("forward", Self::Forward),
            ("backward", Self::Backward),
            ("any", Self::Any),
        ];
        let lowered = s.to_lowercase();

        names
            .iter()
            .find(|(name, _)| *name == lowered)
            .map(|&(_, direction)| direction)
            .ok_or("Invalid direction string")
    }
}
207
208pub fn nearest<C: GroupType, T: PositionType>(
209    chrs: &[C],
210    starts: &[T],
211    ends: &[T],
212    chrs2: &[C],
213    starts2: &[T],
214    ends2: &[T],
215    slack: T,
216    k: usize,
217    include_overlaps: bool,
218    direction: &str,
219) -> (Vec<u32>, Vec<u32>, Vec<T>) {
220    let dir = Direction::from_str(direction).unwrap();
221
222    let sorted_starts = build_sorted_events_single_collection_separate_outputs(chrs, starts, slack);
223    let sorted_ends = build_sorted_events_single_collection_separate_outputs(chrs, ends, slack);
224
225    let sorted_starts2 =
226        build_sorted_events_single_collection_separate_outputs(chrs2, starts2, T::zero());
227    let sorted_ends2 =
228        build_sorted_events_single_collection_separate_outputs(chrs2, ends2, T::zero());
229
230    let overlaps = if include_overlaps {
231        let (idx, idx2) = overlaps(
232            chrs, starts, ends, chrs2, starts2, ends2, slack, "all", true, false,
233        );
234        idx.into_iter()
235            .zip(idx2)
236            .map(|(idx, idx2)| OverlapPair { idx, idx2 })
237            .collect()
238    } else {
239        Vec::new()
240    };
241    let nearest_left = if dir == Direction::Backward || dir == Direction::Any {
242        let mut tmp = nearest_intervals_to_the_left(sorted_starts, sorted_ends2, k);
243        radsort::sort_by_key(&mut tmp, |n| (n.idx, n.distance));
244        tmp
245    } else {
246        Vec::new()
247    };
248    let nearest_right = if dir == Direction::Forward || dir == Direction::Any {
249        let mut tmp = nearest_intervals_to_the_right(sorted_ends, sorted_starts2, k);
250        radsort::sort_by_key(&mut tmp, |n| (n.idx, n.distance));
251        tmp
252    } else {
253        Vec::new()
254    };
255
256    let merged = merge_three_way_by_index_distance(&overlaps, &nearest_left, &nearest_right, k);
257    merged
258}
259
260/// Merges three sources of intervals, grouped by `idx` (i.e. `idx1` in overlaps).
261/// For each unique `idx`, it returns up to `k` *distinct* distances (including
262/// all intervals at those distances). Overlaps are treated as distance=0 (or 1).
263///
264/// The data is assumed to be sorted in ascending order by `(idx, distance)`.
265pub fn merge_three_way_by_index_distance<T: PositionType>(
266    overlaps: &[OverlapPair],     // sorted by idx1
267    nearest_left: &[Nearest<T>],  // sorted by (idx, distance)
268    nearest_right: &[Nearest<T>], // sorted by (idx, distance)
269    k: usize,
270) -> (Vec<u32>, Vec<u32>, Vec<T>) {
271    // We'll return tuples: (idx, idx2, distance).
272    // You can adapt if you want a custom struct instead.
273    let mut results = Vec::new();
274
275    // Pointers over each input
276    let (mut i, mut j, mut r) = (0_usize, 0_usize, 0_usize);
277
278    // Outer loop: pick the smallest index among the three lists
279    while i < overlaps.len() || j < nearest_left.len() || r < nearest_right.len() {
280        // Current index (None if that list is exhausted)
281        let idx_o = overlaps.get(i).map(|o| o.idx);
282        let idx_l = nearest_left.get(j).map(|n| n.idx);
283        let idx_r = nearest_right.get(r).map(|n| n.idx);
284
285        // If all three are None, we're done
286        let current_idx = match (idx_o, idx_l, idx_r) {
287            (None, None, None) => break,
288            (Some(a), Some(b), Some(c)) => a.min(b.min(c)),
289            (Some(a), Some(b), None) => a.min(b),
290            (Some(a), None, Some(c)) => a.min(c),
291            (None, Some(b), Some(c)) => b.min(c),
292            (Some(a), None, None) => a,
293            (None, Some(b), None) => b,
294            (None, None, Some(c)) => c,
295        };
296
297        // Gather all overlaps for current_idx
298        let i_start = i;
299        while i < overlaps.len() && overlaps[i].idx == current_idx {
300            i += 1;
301        }
302        let overlaps_slice = &overlaps[i_start..i];
303
304        // Gather all nearest_left for current_idx
305        let j_start = j;
306        while j < nearest_left.len() && nearest_left[j].idx == current_idx {
307            j += 1;
308        }
309        let left_slice = &nearest_left[j_start..j];
310
311        // Gather all nearest_right for current_idx
312        let r_start = r;
313        while r < nearest_right.len() && nearest_right[r].idx == current_idx {
314            r += 1;
315        }
316        let right_slice = &nearest_right[r_start..r];
317
318        // Now we have three *already-sorted* slices (by distance) for this index:
319        //  1) overlaps_slice (distance=0 or 1, or if you store it in OverlapPair, read it)
320        //  2) left_slice (sorted ascending by distance)
321        //  3) right_slice (sorted ascending by distance)
322        //
323        // We'll do a 3-way merge *by distance*, collecting up to k *distinct* distances.
324        // If you store overlap distances in OverlapPair, you can read them;
325        // otherwise, assume overlap distance=0.
326
327        let mut used_distances = std::collections::HashSet::new();
328        let mut distinct_count = 0;
329
330        let (mut oi, mut lj, mut rr) = (0, 0, 0);
331
332        // Helper closures to peek distance from each slice
333        let overlap_dist = |_ix: usize| -> T {
334            // If you store distance in OverlapPair, return that. Otherwise 0 or 1.
335            // For the example, let's assume actual Overlap distance=0:
336            T::zero()
337        };
338        let left_dist = |ix: usize| -> T { left_slice[ix].distance };
339        let right_dist = |ix: usize| -> T { right_slice[ix].distance };
340
341        // Inner loop: pick the next *smallest* distance among the three slices
342        while oi < overlaps_slice.len() || lj < left_slice.len() || rr < right_slice.len() {
343            // Peek next distance (or i64::MAX if none)
344            let d_o = if oi < overlaps_slice.len() {
345                overlap_dist(oi)
346            } else {
347                T::max_value()
348            };
349            let d_l = if lj < left_slice.len() {
350                left_dist(lj)
351            } else {
352                T::max_value()
353            };
354            let d_r = if rr < right_slice.len() {
355                right_dist(rr)
356            } else {
357                T::max_value()
358            };
359
360            let smallest = d_o.min(d_l.min(d_r));
361            if smallest == T::max_value() {
362                // no more items
363                break;
364            }
365
366            // We'll pull everything from Overlaps that has distance == smallest
367            while oi < overlaps_slice.len() {
368                let dcur = overlap_dist(oi);
369                if dcur == smallest {
370                    // If this is a *new* distance (not in used_distances),
371                    // we check if it would exceed k distinct distances
372                    if !used_distances.contains(&dcur) {
373                        distinct_count += 1;
374                        if distinct_count > k {
375                            // no new distances allowed
376                            break;
377                        }
378                        used_distances.insert(dcur);
379                    }
380                    // Add to result
381                    let OverlapPair { idx, idx2 } = overlaps_slice[oi];
382                    results.push(Nearest {
383                        idx: idx,
384                        idx2: idx2,
385                        distance: T::zero(),
386                    });
387                    oi += 1;
388                } else {
389                    break;
390                }
391            }
392            if distinct_count > k {
393                break;
394            }
395
396            // Pull everything from Left that has distance == smallest
397            while lj < left_slice.len() {
398                let dcur = left_dist(lj);
399                if dcur == smallest {
400                    if !used_distances.contains(&dcur) {
401                        distinct_count += 1;
402                        if distinct_count > k {
403                            break;
404                        }
405                        used_distances.insert(dcur);
406                    }
407                    results.push(left_slice[lj]);
408                    lj += 1;
409                } else {
410                    break;
411                }
412            }
413            if distinct_count > k {
414                break;
415            }
416
417            // Pull everything from Right that has distance == smallest
418            while rr < right_slice.len() {
419                let dcur = right_dist(rr);
420                if dcur == smallest {
421                    if !used_distances.contains(&dcur) {
422                        distinct_count += 1;
423                        if distinct_count > k {
424                            break;
425                        }
426                        used_distances.insert(dcur);
427                    }
428                    results.push(right_slice[rr]);
429                    rr += 1;
430                } else {
431                    break;
432                }
433            }
434            if distinct_count > k {
435                break;
436            }
437        }
438        // done collecting up to k distinct distances for this index
439    }
440
441    sort_by_key(&mut results, |n| (n.idx, n.distance, n.idx2));
442
443    let mut out_idxs = Vec::with_capacity(results.len());
444    let mut out_idxs2 = Vec::with_capacity(results.len());
445    let mut out_distances = Vec::with_capacity(results.len());
446
447    for rec in results {
448        out_idxs.push(rec.idx);
449        out_idxs2.push(rec.idx2);
450        out_distances.push(rec.distance);
451    }
452
453    (out_idxs, out_idxs2, out_distances)
454}