// ruranges_core/spliced_subsequence.rs

1use radsort::sort_by_key;
2
3use crate::{
4    ruranges_structs::{GroupType, PositionType, SplicedSubsequenceInterval},
5    sorts::build_sorted_subsequence_intervals,
6};
7
8/// (idxs, starts, ends, strands) for exactly one (start,end) slice
9fn global_shift<T: PositionType>(starts: &[T], ends: &[T]) -> T {
10    let mut min_coord = T::zero();
11    for &v in starts {
12        if v < min_coord {
13            min_coord = v;
14        }
15    }
16    for &v in ends {
17        if v < min_coord {
18            min_coord = v;
19        }
20    }
21    if min_coord < T::zero() {
22        -min_coord
23    } else {
24        T::zero()
25    }
26}
27
28/// (idxs, starts, ends, strands) for **one** (start,end) slice
29pub fn spliced_subseq<G: GroupType, T: PositionType>(
30    chrs: &[G],
31    starts: &[T],
32    ends: &[T],
33    strand_flags: &[bool],
34    start: T,
35    end: Option<T>,
36    force_plus_strand: bool,
37) -> (Vec<u32>, Vec<T>, Vec<T>, Vec<bool>) {
38    // ────────────────────────── 1. pre-processing: apply global shift ─────
39    let shift = global_shift(starts, ends);
40
41    // Either borrow the original slices (shift == 0) or build shifted copies.
42    // `tmp_storage` keeps the vectors alive for as long as we need the slices.
43    let (starts_slice, ends_slice);
44    let _tmp_storage: Option<(Vec<T>, Vec<T>)>;
45
46    if shift > T::zero() {
47        let mut s = Vec::with_capacity(starts.len());
48        let mut e = Vec::with_capacity(ends.len());
49        for i in 0..starts.len() {
50            s.push(starts[i] + shift);
51            e.push(ends[i] + shift);
52        }
53        _tmp_storage = Some((s, e));
54        let (s_ref, e_ref) = _tmp_storage.as_ref().unwrap();
55        starts_slice = s_ref.as_slice();
56        ends_slice = e_ref.as_slice();
57    } else {
58        _tmp_storage = None;
59        starts_slice = starts;
60        ends_slice = ends;
61    }
62    // ───────────────────────────────────────────────────────────────────────
63
64    // ────────────── helper struct local to this function ───────────────────
65    struct OutRec<T: PositionType> {
66        idx: u32,
67        start: T,
68        end: T,
69        strand: bool,
70    }
71
72    // Build sorted interval vector (caller guarantees same grouping rules).
73    let mut intervals =
74        build_sorted_subsequence_intervals(chrs, starts_slice, ends_slice, strand_flags);
75
76    // Early-exit when nothing to do
77    if intervals.is_empty() {
78        return (Vec::new(), Vec::new(), Vec::new(), Vec::new());
79    }
80
81    let mut out_recs: Vec<OutRec<T>> = Vec::with_capacity(intervals.len());
82
83    let mut group_buf: Vec<SplicedSubsequenceInterval<G, T>> = Vec::new();
84    let mut current_chr = intervals[0].chr;
85    let mut running_sum = T::zero();
86
87    // ───────── helper: finalise one transcript/group ───────────────────────
88    let mut finalize_group = |group: &mut [SplicedSubsequenceInterval<G, T>]| {
89        if group.is_empty() {
90            return;
91        }
92
93        // total spliced length
94        let total_len = group.last().unwrap().temp_cumsum;
95        let end_val = end.unwrap_or(total_len);
96
97        // translate negative offsets
98        let global_start = if start < T::zero() {
99            total_len + start
100        } else {
101            start
102        };
103        let global_end = if end_val < T::zero() {
104            total_len + end_val
105        } else {
106            end_val
107        };
108
109        let group_forward = group[0].forward_strand;
110
111        // per-exon closure so we don’t duplicate maths
112        let mut process_iv = |iv: &mut SplicedSubsequenceInterval<G, T>| {
113            let cumsum_start = iv.temp_cumsum - iv.temp_length;
114            let cumsum_end = iv.temp_cumsum;
115
116            let mut st = iv.start;
117            let mut en = iv.end;
118
119            // coordinate arithmetic orientation
120            let processed_forward = force_plus_strand || iv.forward_strand;
121
122            if processed_forward {
123                let shift = global_start - cumsum_start;
124                if shift > T::zero() {
125                    st = st + shift;
126                }
127                let shift = cumsum_end - global_end;
128                if shift > T::zero() {
129                    en = en - shift;
130                }
131            } else {
132                let shift = global_start - cumsum_start;
133                if shift > T::zero() {
134                    en = en - shift;
135                }
136                let shift = cumsum_end - global_end;
137                if shift > T::zero() {
138                    st = st + shift;
139                }
140            }
141
142            // keep only non-empty pieces
143            if st < en {
144                out_recs.push(OutRec {
145                    idx: iv.idx,
146                    start: st,
147                    end: en,
148                    strand: iv.forward_strand == processed_forward, // (+)*(+) or (−)*(−) → '+'
149                });
150            }
151        };
152
153        // walk exons in transcription order
154        if group_forward {
155            for iv in group.iter_mut() {
156                process_iv(iv);
157            }
158        } else {
159            for iv in group.iter_mut().rev() {
160                process_iv(iv);
161            }
162        }
163    };
164    // ───────────────────────────────────────────────────────────────────────
165
166    // single linear scan over all exons
167    for mut iv in intervals.into_iter() {
168        iv.start = iv.start.abs();
169        iv.end = iv.end.abs();
170
171        // new chromosome ⇒ flush buffer
172        if iv.chr != current_chr {
173            finalize_group(&mut group_buf);
174            group_buf.clear();
175            running_sum = T::zero();
176            current_chr = iv.chr;
177        }
178
179        iv.temp_length = iv.end - iv.start;
180        iv.temp_cumsum = running_sum + iv.temp_length;
181        running_sum = iv.temp_cumsum;
182
183        group_buf.push(iv);
184    }
185    finalize_group(&mut group_buf);
186
187    // restore original row order
188    sort_by_key(&mut out_recs, |r| r.idx);
189
190    // ───────── explode OutRec list into parallel result vectors ────────────
191    let mut out_idxs = Vec::with_capacity(out_recs.len());
192    let mut out_starts = Vec::with_capacity(out_recs.len());
193    let mut out_ends = Vec::with_capacity(out_recs.len());
194    let mut out_strands = Vec::with_capacity(out_recs.len());
195
196    for rec in out_recs {
197        out_idxs.push(rec.idx);
198        out_starts.push(rec.start);
199        out_ends.push(rec.end);
200        out_strands.push(rec.strand);
201    }
202
203    // ─────────────────────────── 3. post-processing: undo shift ────────────
204    if shift > T::zero() {
205        for v in &mut out_starts {
206            *v = *v - shift;
207        }
208        for v in &mut out_ends {
209            *v = *v - shift;
210        }
211    }
212    // ───────────────────────────────────────────────────────────────────────
213
214    (out_idxs, out_starts, out_ends, out_strands)
215}
216
217pub fn spliced_subseq_multi<G: GroupType, T: PositionType>(
218    chrs: &[G],
219    starts: &[T],
220    ends: &[T],
221    strand_flags: &[bool],
222    slice_starts: &[T],
223    slice_ends: &[Option<T>],
224    force_plus_strand: bool,
225) -> (Vec<u32>, Vec<T>, Vec<T>, Vec<bool>) {
226    assert_eq!(chrs.len(), starts.len());
227    assert_eq!(starts.len(), ends.len());
228    assert_eq!(ends.len(), strand_flags.len());
229    assert_eq!(strand_flags.len(), slice_starts.len());
230    assert_eq!(slice_starts.len(), slice_ends.len());
231
232    let shift = global_shift(starts, ends);
233
234    let (starts_slice, ends_slice);
235    let _tmp_storage: Option<(Vec<T>, Vec<T>)>;
236    if shift > T::zero() {
237        let mut s = Vec::with_capacity(starts.len());
238        let mut e = Vec::with_capacity(ends.len());
239        for i in 0..starts.len() {
240            s.push(starts[i] + shift);
241            e.push(ends[i] + shift);
242        }
243        _tmp_storage = Some((s, e));
244        let (s_ref, e_ref) = _tmp_storage.as_ref().unwrap();
245        starts_slice = s_ref.as_slice();
246        ends_slice = e_ref.as_slice();
247    } else {
248        _tmp_storage = None;
249        starts_slice = starts;
250        ends_slice = ends;
251    }
252
253    struct OutRec<T: PositionType> {
254        idx: u32,
255        start: T,
256        end: T,
257        strand: bool,
258    }
259
260    let mut intervals =
261        build_sorted_subsequence_intervals(chrs, starts_slice, ends_slice, strand_flags);
262
263    if intervals.is_empty() {
264        return (Vec::new(), Vec::new(), Vec::new(), Vec::new());
265    }
266
267    let mut out_recs: Vec<OutRec<T>> = Vec::with_capacity(intervals.len());
268    let mut group_buf: Vec<SplicedSubsequenceInterval<G, T>> = Vec::new();
269    let mut current_chr = intervals[0].chr;
270    let mut running_sum = T::zero();
271    let mut current_slice_start: T = slice_starts[intervals[0].idx as usize];
272    let mut current_slice_end: Option<T> = slice_ends[intervals[0].idx as usize];
273
274    let mut finalize_group =
275        |group: &mut [SplicedSubsequenceInterval<G, T>], slice_start: T, slice_end: Option<T>| {
276            if group.is_empty() {
277                return;
278            }
279
280            let total_len = group.last().unwrap().temp_cumsum;
281            let end_val = slice_end.unwrap_or(total_len);
282
283            let global_start = if slice_start < T::zero() {
284                total_len + slice_start
285            } else {
286                slice_start
287            };
288            let global_end = if end_val < T::zero() {
289                total_len + end_val
290            } else {
291                end_val
292            };
293
294            let group_forward = group[0].forward_strand;
295
296            let mut process_iv = |iv: &mut SplicedSubsequenceInterval<G, T>| {
297                let cumsum_start = iv.temp_cumsum - iv.temp_length;
298                let cumsum_end = iv.temp_cumsum;
299
300                let mut st = iv.start;
301                let mut en = iv.end;
302
303                let processed_forward = force_plus_strand || iv.forward_strand;
304
305                if processed_forward {
306                    let shift = global_start - cumsum_start;
307                    if shift > T::zero() {
308                        st = st + shift;
309                    }
310                    let shift = cumsum_end - global_end;
311                    if shift > T::zero() {
312                        en = en - shift;
313                    }
314                } else {
315                    let shift = global_start - cumsum_start;
316                    if shift > T::zero() {
317                        en = en - shift;
318                    }
319                    let shift = cumsum_end - global_end;
320                    if shift > T::zero() {
321                        st = st + shift;
322                    }
323                }
324
325                if st < en {
326                    out_recs.push(OutRec {
327                        idx: iv.idx,
328                        start: st,
329                        end: en,
330                        strand: iv.forward_strand == processed_forward,
331                    });
332                }
333            };
334
335            if group_forward {
336                for iv in group.iter_mut() {
337                    process_iv(iv);
338                }
339            } else {
340                for iv in group.iter_mut().rev() {
341                    process_iv(iv);
342                }
343            }
344        };
345
346    for mut iv in intervals.into_iter() {
347        iv.start = iv.start.abs();
348        iv.end = iv.end.abs();
349
350        if iv.chr != current_chr {
351            finalize_group(&mut group_buf, current_slice_start, current_slice_end);
352            group_buf.clear();
353            running_sum = T::zero();
354            current_chr = iv.chr;
355            current_slice_start = slice_starts[iv.idx as usize];
356            current_slice_end = slice_ends[iv.idx as usize];
357        }
358
359        iv.temp_length = iv.end - iv.start;
360        iv.temp_cumsum = running_sum + iv.temp_length;
361        running_sum = iv.temp_cumsum;
362
363        group_buf.push(iv);
364    }
365    finalize_group(&mut group_buf, current_slice_start, current_slice_end);
366
367    sort_by_key(&mut out_recs, |r| r.idx);
368
369    let mut out_idxs = Vec::with_capacity(out_recs.len());
370    let mut out_starts = Vec::with_capacity(out_recs.len());
371    let mut out_ends = Vec::with_capacity(out_recs.len());
372    let mut out_strands = Vec::with_capacity(out_recs.len());
373
374    for rec in out_recs {
375        out_idxs.push(rec.idx);
376        out_starts.push(rec.start);
377        out_ends.push(rec.end);
378        out_strands.push(rec.strand);
379    }
380
381    if shift > T::zero() {
382        for v in &mut out_starts {
383            *v = *v - shift;
384        }
385        for v in &mut out_ends {
386            *v = *v - shift;
387        }
388    }
389
390    (out_idxs, out_starts, out_ends, out_strands)
391}