tsg_core/graph/
node.rs

1use std::fmt;
2use std::str::FromStr;
3
4use crate::graph::Attribute;
5use ahash::HashMap;
6use anyhow::Context;
7use anyhow::Result;
8use bon::Builder;
9use bon::builder;
10use bstr::BString;
11use bstr::ByteSlice;
12use rayon::prelude::*;
13use serde_json::json;
14use std::io;
15use tracing::debug;
16
17/// Represents a simple interval with start and end positions.
18///
19/// An interval is defined by two positions:
20/// - `start`: The inclusive beginning position of the interval
21/// - `end`: The exclusive ending position of the interval
22///
23/// The interval spans from `start` (inclusive) to `end` (exclusive).
24#[derive(Debug, Builder, Clone)]
25pub struct Interval {
26    pub start: usize,
27    pub end: usize,
28}
29
30impl Interval {
31    /// Returns the length of the interval.
32    ///
33    /// The length is calculated as `end - start`.
34    ///
35    /// # Returns
36    /// The length of the interval as a `usize`.
37    pub fn span(&self) -> usize {
38        self.end - self.start
39    }
40}
41
42impl FromStr for Interval {
43    type Err = io::Error;
44
45    fn from_str(s: &str) -> Result<Self, Self::Err> {
46        let parts: Vec<&str> = s.split('-').collect();
47        if parts.len() != 2 {
48            return Err(io::Error::new(
49                io::ErrorKind::InvalidData,
50                format!("Invalid exon coordinates format: {}", s),
51            ));
52        }
53
54        let start = parts[0].parse::<usize>().map_err(|e| {
55            io::Error::new(
56                io::ErrorKind::InvalidData,
57                format!("Invalid start coordinate: {}", e),
58            )
59        })?;
60
61        let end = parts[1].parse::<usize>().map_err(|e| {
62            io::Error::new(
63                io::ErrorKind::InvalidData,
64                format!("Invalid end coordinate: {}", e),
65            )
66        })?;
67
68        Ok(Self { start, end })
69    }
70}
71
72#[derive(Debug, Builder, Clone, Default)]
73/// Represents a collection of exons, which are contiguous regions within genomic sequences.
74///
75/// Exons are the parts of a gene's DNA that code for proteins, and they're separated by
76/// non-coding regions called introns. This structure stores a collection of exons as intervals.
77///
78/// # Fields
79///
80/// * `exons` - A vector of intervals representing the positions of exons within a genomic sequence.
81pub struct Exons {
82    pub exons: Vec<Interval>,
83}
84
85impl FromStr for Exons {
86    type Err = io::Error;
87    fn from_str(s: &str) -> Result<Self, Self::Err> {
88        let exons = s
89            .split(',')
90            .map(|x| x.parse())
91            .collect::<Result<Vec<Interval>, Self::Err>>()?;
92        Ok(Exons { exons })
93    }
94}
95
96impl fmt::Display for Exons {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        let exons = self
99            .exons
100            .iter()
101            .map(|x| format!("{}-{}", x.start, x.end))
102            .collect::<Vec<String>>()
103            .join(",");
104        write!(f, "{}", exons)
105    }
106}
107
108/// Methods for working with exon structures.
109///
110/// # Methods
111///
112/// - `introns()` - Calculates the intron intervals between exons
113/// - `is_empty()` - Checks if there are no exons
114/// - `len()` - Returns the number of exons
115/// - `span()` - Calculates the total number of bases covered by all exons
116/// - `first_exon()` - Gets a reference to the first exon
117/// - `last_exon()` - Gets a reference to the last exon
118///
119/// # Panics
120///
121/// - `first_exon()` will panic if there are no exons
122/// - `last_exon()` will panic if there are no exons
123impl Exons {
124    /// Returns a vector of intervals representing introns.
125    ///
126    /// Introns are the regions between consecutive exons. For each pair of adjacent exons,
127    /// an intron is created starting at the position immediately after the end of the first exon
128    /// and ending at the position immediately before the start of the second exon.
129    ///
130    /// # Returns
131    /// A `Vec<Interval>` containing all introns between exons in this structure.
132    pub fn introns(&self) -> Vec<Interval> {
133        let mut introns = Vec::with_capacity(self.exons.len().saturating_sub(1));
134        for i in 0..self.exons.len().saturating_sub(1) {
135            introns.push(Interval {
136                start: self.exons[i].end + 1,
137                end: self.exons[i + 1].start,
138            });
139        }
140        introns
141    }
142
143    /// Checks if the exon collection is empty.
144    ///
145    /// # Returns
146    /// `true` if there are no exons, `false` otherwise.
147    pub fn is_empty(&self) -> bool {
148        self.exons.is_empty()
149    }
150
151    /// Returns the number of exons.
152    ///
153    /// # Returns
154    /// The count of exons as a `usize`.
155    pub fn len(&self) -> usize {
156        self.exons.len()
157    }
158
159    /// Calculates the total span (combined length) of all exons.
160    ///
161    /// The span is computed by summing the lengths of all intervals,
162    /// where each interval length is calculated as `end - start + 1`.
163    ///
164    /// # Returns
165    /// The total span as a `usize`.
166    pub fn span(&self) -> usize {
167        self.exons.iter().map(|e| e.span()).sum()
168    }
169
170    /// Returns a reference to the first exon.
171    ///
172    /// # Panics
173    /// Panics if the exon collection is empty.
174    ///
175    /// # Returns
176    /// A reference to the first `Interval` in the exon collection.
177    pub fn first_exon(&self) -> &Interval {
178        &self.exons[0]
179    }
180
181    /// Returns a reference to the last exon.
182    ///
183    /// # Panics
184    /// Panics if the exon collection is empty.
185    ///
186    /// # Returns
187    /// A reference to the last `Interval` in the exon collection.
188    pub fn last_exon(&self) -> &Interval {
189        &self.exons[self.exons.len() - 1]
190    }
191}
192
193#[derive(Debug, Clone, Builder, PartialEq)]
194#[builder(on(BString, into))]
195#[builder(on(ReadIdentity, into))]
196pub struct ReadData {
197    pub id: BString,
198    pub identity: ReadIdentity,
199}
200
201impl fmt::Display for ReadData {
202    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203        write!(f, "{}:{:?}", self.id, self.identity)
204    }
205}
206
207impl FromStr for ReadData {
208    type Err = io::Error;
209
210    fn from_str(s: &str) -> Result<Self, Self::Err> {
211        // <id>:<identity>
212        let fields: Vec<&str> = s.split(':').collect();
213        if fields.len() != 2 {
214            return Err(io::Error::new(
215                io::ErrorKind::InvalidData,
216                format!("Invalid read line format: {}", s),
217            ));
218        }
219
220        let id: BString = fields[0].into();
221        let identity = fields[1].parse()?;
222        Ok(Self { id, identity })
223    }
224}
225
/// Identity code attached to a read.
///
/// Serialized as the two-letter codes `SO`, `IN` and `SI`, which the `Display` and
/// `FromStr` implementations write and accept. The enum is a fieldless value type, so it
/// derives `Copy`, `Eq` and `Hash` and can be used freely as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReadIdentity {
    /// Source.
    SO,
    /// Intermediate.
    IN,
    /// Sink.
    SI,
}
232
233impl fmt::Display for ReadIdentity {
234    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
235        match self {
236            ReadIdentity::SO => write!(f, "SO"),
237            ReadIdentity::IN => write!(f, "IN"),
238            ReadIdentity::SI => write!(f, "SI"),
239        }
240    }
241}
242
243impl FromStr for ReadIdentity {
244    type Err = io::Error;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        match s {
248            "SO" => Ok(ReadIdentity::SO),
249            "IN" => Ok(ReadIdentity::IN),
250            "SI" => Ok(ReadIdentity::SI),
251            _ => Err(io::Error::new(
252                io::ErrorKind::InvalidData,
253                format!("Invalid read identity: {}", s),
254            )),
255        }
256    }
257}
258
259impl From<&str> for ReadIdentity {
260    fn from(s: &str) -> Self {
261        s.parse().unwrap()
262    }
263}
264
/// DNA strand orientation, written as `+` (forward) or `-` (reverse).
///
/// The default is [`Strand::Forward`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Strand {
    /// Forward strand, `+`.
    #[default]
    Forward,
    /// Reverse strand, `-`.
    Reverse,
}
272
273impl FromStr for Strand {
274    type Err = anyhow::Error;
275
276    fn from_str(s: &str) -> Result<Self, Self::Err> {
277        match s {
278            "+" => Ok(Strand::Forward),
279            "-" => Ok(Strand::Reverse),
280            _ => Err(anyhow::anyhow!("Invalid strand: {}", s)),
281        }
282    }
283}
284
285impl fmt::Display for Strand {
286    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
287        match self {
288            Strand::Forward => write!(f, "+"),
289            Strand::Reverse => write!(f, "-"),
290        }
291    }
292}
293
294/// Node in the transcript segment graph
295#[derive(Debug, Clone, Default, Builder)]
296#[builder(on(BString, into))]
297pub struct NodeData {
298    pub id: BString,
299    pub reference_id: BString,
300    pub strand: Strand,
301    pub exons: Exons,
302    pub reads: Vec<ReadData>,
303    pub sequence: Option<BString>,
304    pub attributes: HashMap<BString, Attribute>,
305}
306
307impl NodeData {
308    pub fn reference_start(&self) -> usize {
309        self.exons.first_exon().start
310    }
311    pub fn reference_end(&self) -> usize {
312        self.exons.last_exon().end
313    }
314    /// Converts the node data to a JSON representation
315    ///
316    /// # Arguments
317    /// * `attributes` - Optional additional attributes to include in the JSON
318    ///
319    /// # Returns
320    /// A JSON value representing the node
321    pub fn to_json(&self, attributes: Option<&[Attribute]>) -> Result<serde_json::Value> {
322        let mut data = json!({
323            "chrom": self.reference_id.to_str().unwrap(),
324            "ref_start": self.reference_start(),
325            "ref_end": self.reference_end(),
326            "strand": self.strand.to_string(),
327            "exons": format!("[{}]",  self.exons.to_string()),
328            "reads": self.reads.par_iter().map(|r| format!("{}", r) ).collect::<Vec<_>>(),
329            "id": self.id.to_str().unwrap(),
330        });
331
332        for attr in self.attributes.values() {
333            data[attr.tag.to_str().unwrap()] = match attr.attribute_type {
334                'f' => attr.as_float()?.into(),
335                'i' => attr.as_int()?.into(),
336                _ => attr.value.to_str().unwrap().into(),
337            };
338        }
339
340        if let Some(attributes) = attributes.as_ref() {
341            for attr in attributes.iter() {
342                data[attr.tag.to_str().unwrap()] = match attr.attribute_type {
343                    'f' => attr.as_float()?.into(),
344                    'i' => attr.as_int()?.into(),
345                    _ => attr.value.to_str().unwrap().into(),
346                };
347            }
348        }
349        let json = json!({"data": data});
350        Ok(json)
351    }
352
353    pub fn to_gtf(&self, attributes: Option<&[Attribute]>) -> Result<BString> {
354        // chr1    scannls exon    173867960       173867991       .       -       .       exon_id "001"; segment_id "0001"; ptc "1"; ptf "1.0"; transcript_id "3x1"; gene_id "3";
355        let mut res = vec![];
356        for (idx, exon) in self.exons.exons.iter().enumerate() {
357            let mut gtf = String::from("");
358            gtf.push_str(self.reference_id.to_str().unwrap());
359            gtf.push_str("\ttsg\texon\t");
360            gtf.push_str(&format!("{}\t{}\t", exon.start, exon.end));
361            gtf.push_str(".\t");
362            gtf.push_str(self.strand.to_string().as_str());
363            gtf.push_str("\t.\t");
364            gtf.push_str(format!("exon_id \"{:03}\"; ", idx + 1).as_str());
365
366            for attr in self.attributes.values() {
367                gtf.push_str(format!("{} \"{}\"; ", attr.tag, attr.value).as_str());
368            }
369
370            if let Some(attributes) = attributes.as_ref() {
371                for attr in attributes.iter().rev() {
372                    gtf.push_str(format!("{} \"{}\"; ", attr.tag, attr.value).as_str());
373                }
374            }
375            res.push(gtf);
376        }
377        Ok(res.join("\n").into())
378    }
379}
380
381impl fmt::Display for NodeData {
382    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
383        write!(
384            f,
385            "N\t{}\t{}:{}:{}\t{}\t{}",
386            self.id,
387            self.reference_id,
388            self.strand,
389            self.exons,
390            self.reads
391                .iter()
392                .map(|r| r.to_string())
393                .collect::<Vec<_>>()
394                .join(","),
395            self.sequence.as_ref().unwrap_or(&"".into())
396        )
397    }
398}
399
400impl FromStr for NodeData {
401    type Err = io::Error;
402
403    fn from_str(s: &str) -> Result<Self, Self::Err> {
404        // N  <rid>:<id>  <chrom>:<strand>:<exons>  <reads>  [<seq>]
405        let fields: Vec<&str> = s.split_whitespace().collect();
406        if fields.len() < 4 {
407            return Err(io::Error::new(
408                io::ErrorKind::InvalidData,
409                format!("Invalid node line format: {}", s),
410            ));
411        }
412
413        debug!("Parsing node: {}", s);
414        let id: BString = fields[1].into();
415
416        let reference_and_exons: Vec<&str> = fields[2].split(":").collect();
417        let reference_id = reference_and_exons[0].into();
418        let strand = reference_and_exons[1].parse().map_err(|e| {
419            io::Error::new(
420                io::ErrorKind::InvalidData,
421                format!("Failed to parse strand: {}", e),
422            )
423        })?;
424        let exons = reference_and_exons[2].parse().map_err(|e| {
425            io::Error::new(
426                io::ErrorKind::InvalidData,
427                format!("Failed to parse exons: {}", e),
428            )
429        })?;
430
431        let reads = fields[3]
432            .split(',')
433            .map(|s| s.parse().context("failed to parse reads").unwrap())
434            .collect::<Vec<_>>();
435
436        let sequence = if fields.len() > 4 && !fields[4].is_empty() {
437            Some(fields[4].into())
438        } else {
439            None
440        };
441
442        Ok(NodeData {
443            id,
444            reference_id,
445            strand,
446            exons,
447            reads,
448            sequence,
449            ..Default::default()
450        })
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use ahash::HashMapExt;

    /// Exon list shared by the `Exons` tests.
    const THREE_EXONS: &str = "100-200,300-400,500-600";

    /// Two exons (100-200 and 300-400) shared by the `NodeData` tests.
    fn two_exons() -> Exons {
        Exons {
            exons: vec![
                Interval {
                    start: 100,
                    end: 200,
                },
                Interval {
                    start: 300,
                    end: 400,
                },
            ],
        }
    }

    /// Builds a string-typed (`'Z'`) attribute.
    fn z_attr(tag: &str, value: &str) -> Attribute {
        Attribute {
            tag: tag.into(),
            attribute_type: 'Z',
            value: value.into(),
        }
    }

    /// Builds a map of string-typed attributes keyed by their tag.
    fn z_attr_map(pairs: &[(&str, &str)]) -> HashMap<BString, Attribute> {
        let mut map = HashMap::new();
        for &(tag, value) in pairs {
            map.insert(tag.into(), z_attr(tag, value));
        }
        map
    }

    #[test]
    fn test_node_from_str() {
        let node = NodeData::from_str("N\tn1\tchr1:+:1000-2000\tread1:SO").unwrap();
        assert_eq!(node.id, "n1");
    }

    #[test]
    fn test_exons_introns() {
        let introns = Exons::from_str(THREE_EXONS).unwrap().introns();
        let bounds: Vec<(usize, usize)> = introns.iter().map(|i| (i.start, i.end)).collect();
        assert_eq!(bounds, vec![(201, 300), (401, 500)]);
    }

    #[test]
    fn test_exons_len() {
        assert_eq!(Exons::from_str(THREE_EXONS).unwrap().len(), 3);
    }

    #[test]
    fn test_exons_span() {
        // Each exon spans end - start = 100 positions; three of them make 300.
        assert_eq!(Exons::from_str(THREE_EXONS).unwrap().span(), 300);
    }

    #[test]
    fn test_exons_first_last() {
        let exons = Exons::from_str(THREE_EXONS).unwrap();
        let first = exons.first_exon();
        let last = exons.last_exon();
        assert_eq!((first.start, first.end), (100, 200));
        assert_eq!((last.start, last.end), (500, 600));
    }

    #[test]
    fn test_node_reference_start_end() {
        let node = NodeData {
            id: "node1".into(),
            reference_id: "chr1".into(),
            exons: two_exons(),
            ..Default::default()
        };

        assert_eq!((node.reference_start(), node.reference_end()), (100, 400));
    }

    #[test]
    fn test_node_to_json() -> Result<()> {
        let node = NodeData {
            id: "node1".into(),
            reference_id: "chr1".into(),
            strand: Strand::Forward,
            exons: two_exons(),
            reads: vec![
                ReadData::builder().id("read1").identity("SO").build(),
                ReadData::builder().id("read2").identity("IN").build(),
            ],
            attributes: z_attr_map(&[("ptc", "1"), ("ptf", "0.0")]),
            ..Default::default()
        };

        let json = node.to_json(None)?;
        println!("{}", json);

        // The payload lives under a top-level "data" key.
        assert!(json.get("data").is_some());
        let data = json["data"].as_object().unwrap();

        let expected_strings = [
            ("chrom", "chr1"),
            ("strand", "+"),
            ("id", "node1"),
            ("ptc", "1"),
            ("ptf", "0.0"),
        ];
        for (key, expected) in expected_strings {
            assert_eq!(data[key], expected, "field {}", key);
        }
        assert_eq!(data["ref_start"], 100);
        assert_eq!(data["ref_end"], 400);

        // Extra attributes passed by the caller appear alongside the node's own.
        let extras = vec![z_attr("is_head", "true")];
        let json_with_attrs = node.to_json(Some(&extras))?;
        assert_eq!(json_with_attrs["data"].as_object().unwrap()["is_head"], "true");

        println!("{}", json_with_attrs);

        Ok(())
    }

    #[test]
    fn test_node_to_gtf() -> Result<()> {
        let node = NodeData {
            id: "node1".into(),
            reference_id: "chr1".into(),
            strand: Strand::Forward,
            exons: two_exons(),
            attributes: z_attr_map(&[("segment_id", "001")]),
            ..Default::default()
        };

        let gtf = node.to_gtf(None)?;
        let lines: Vec<&str> = gtf.to_str().unwrap().lines().collect();

        assert_eq!(lines.len(), 2);
        assert!(lines[0].starts_with("chr1\ttsg\texon\t100\t200\t.\t+\t.\texon_id \"001\""));
        assert!(lines[0].contains("segment_id \"001\""));
        assert!(lines[1].starts_with("chr1\ttsg\texon\t300\t400\t.\t+\t.\texon_id \"002\""));

        // Every exon line carries the caller-supplied extra attributes.
        let extras = vec![z_attr("transcript_id", "1")];
        let gtf_with_attrs = node.to_gtf(Some(&extras))?;
        let text = gtf_with_attrs.to_str().unwrap();
        for line in text.split('\n').take(2) {
            assert!(line.contains("transcript_id \"1\""));
        }

        Ok(())
    }
}