Skip to main content

ts_analyzer/
reader.rs

1//! A module for reading the transport stream.
2use std::io::Read;
3use std::io::Seek;
4use std::io::SeekFrom;
5
6#[cfg(feature = "tracing")]
7use tracing::debug;
8#[cfg(any(feature = "tracing", feature = "read_amount_mb"))]
9use tracing::info;
10#[cfg(feature = "tracing")]
11use tracing::trace;
12
13use crate::ErrorKind;
14use crate::helpers::tracked_payload::TrackedPayload;
15use crate::packet::PACKET_SIZE;
16use crate::packet::TsPacket;
17use crate::packet::header::SYNC_BYTE;
18
19/// Struct used for holding information related to reading the transport stream.
20#[derive(Debug, getset::Getters)]
21pub struct TsReader<T> {
22    /// Buffered reader for the transport stream file.
23    reader: T,
24    /// Sync byte alignment. A Sync byte should be found every `PACKET_SIZE`
25    /// away.
26    sync_alignment: u64,
27    /// Counter of the number of packets read
28    #[get = "pub"]
29    packets_read: u64,
30    /// PIDs that should be tracked when querying for packets or payloads.
31    ///
32    /// If empty, all PIDs are tracked. This will use more memory as there are
33    /// more incomplete payloads to keep track of.
34    tracked_pids: Vec<u16>,
35    /// Payloads that are currently being tracked by the reader.
36    tracked_payloads: Vec<TrackedPayload>,
37
38    /// Number of bytes last reported to have been read
39    last_reported_read_amount: u64,
40}
41
42impl<T> TsReader<T>
43where
44    T: Read + Seek,
45{
46    pub fn iter_packets(&'_ mut self) -> TsPackets<'_, T> {
47        TsPackets(self)
48    }
49
50    pub fn iter_payloads(&'_ mut self) -> TsPayloads<'_, T> {
51        TsPayloads(self)
52    }
53
54    pub fn stream_position(&mut self) -> u64 {
55        self.reader.stream_position().unwrap()
56    }
57
58    /// Create a new TSReader instance using the given file.
59    ///
60    /// This function also finds the first SYNC byte, so we can determine the
61    /// alignment of the transport packets.
62    /// # Parameters
63    /// - `buf_reader`: a buffered reader that contains transport stream data.
64    pub fn new(mut buf_reader: T) -> Result<Self, ErrorKind> {
65        #[cfg(feature = "tracing")]
66        trace!("Attempting to create new TS packet");
67        // Find the first sync byte, so we can search easier by doing simple
68        // `PACKET_SIZE` buffer reads.
69        let mut read_buf = [0];
70        let sync_alignment: u64;
71
72        loop {
73            let count = buf_reader.read(&mut read_buf)?;
74
75            // Return a `NoSyncByteFound` error if no SYNC byte could be found
76            // in the reader.
77            if count == 0 {
78                return Err(ErrorKind::NoSyncByteFound);
79            }
80
81            // Run through this loop until we find a sync byte.
82            if read_buf[0] != SYNC_BYTE {
83                continue;
84            }
85
86            // Note the location of this SYNC byte for later
87            let sync_pos = buf_reader
88                .stream_position()
89                .expect("Couldn't get stream position from BufReader");
90
91            #[cfg(feature = "tracing")]
92            trace!("SYNC found at position {}", sync_pos);
93
94            // If we think this is the correct alignment because we have found a
95            // SYNC byte we need to verify that this is correct by
96            // seeking 1 `PACKET_SIZE` away and verifying a SYNC
97            // byte is there. If there isn't one there then this is simply the
98            // same data as a SYNC byte by coincidence, and we need
99            // to keep looking.
100            //
101            // WARN: There is always the possibility that we hit a `0x47` in the
102            // payload, seek 1 `PACKET_SIZE` further, and find another `0x47`
103            // but I don't have a way of accounting for that, so we're going
104            // with blind hope that this case doesn't get seen.
105            buf_reader.seek_relative(PACKET_SIZE as i64 - 1)?;
106            let count = buf_reader.read(&mut read_buf)?;
107
108            // If we run out of data to read while trying to verify that the
109            // SYNC byte is actually a SYNC byte and isn't part of a
110            // payload then we'll simply assume that it really is a
111            // SYNC byte as we have nothing else to go off of.
112            if count == 0 {
113                #[cfg(feature = "tracing")]
114                debug!("Could not find SYNC byte in stream");
115                return Err(ErrorKind::NoSyncByteFound);
116            }
117
118            // Seek back to the original location for later reading.
119            buf_reader.seek(SeekFrom::Start(sync_pos - 1))?;
120
121            // If the byte 1 `PACKET_SIZE` away is also a SYNC byte we can be
122            // relatively sure that this alignment is correct.
123            if read_buf[0] == SYNC_BYTE {
124                sync_alignment = sync_pos;
125                break;
126            }
127        }
128
129        Ok(TsReader {
130            reader: buf_reader,
131            sync_alignment,
132            packets_read: 0,
133            tracked_pids: Vec::new(),
134            tracked_payloads: Vec::new(),
135            last_reported_read_amount: 0,
136        })
137    }
138
139    /// Read the next packet from the transport stream file.
140    ///
141    /// This function returns `None` for any `Err` in order to prevent the need
142    /// for `.unwrap()` calls in more concise code.
143    /// # Returns
144    /// `Some(TSPacket)` if the next transport stream packet could be parsed
145    /// from the file. `None` if the next transport stream packet could not
146    /// be parsed from the file for any reason. This includes if the entire
147    /// file has been fully read.
148    pub fn next_packet_unchecked(&mut self) -> Option<TsPacket> {
149        self.next_packet().unwrap_or(None)
150    }
151
152    /// Read the next packet from the transport stream file.
153    /// # Returns
154    /// `Ok(Some(TSPacket))` if the next transport stream packet could be parsed
155    /// from the file. `Ok(None)` if there was no issue reading the file and
156    /// no more TS packets can be read.
157    pub fn next_packet(&mut self) -> Result<Option<TsPacket>, ErrorKind> {
158        let mut packet_buf = [0; PACKET_SIZE];
159        loop {
160            if let Err(e) = self.reader.read_exact(&mut packet_buf) {
161                if e.kind() == std::io::ErrorKind::UnexpectedEof {
162                    #[cfg(feature = "tracing")]
163                    info!("Finished reading file");
164                    return Ok(None);
165                }
166
167                return Err(e.into());
168            }
169
170            if (cfg!(feature = "tracing") || cfg!(feature = "read_amount_mb"))
171                && let Ok(position) = self.reader.stream_position()
172            {
173                #[cfg(feature = "tracing")]
174                trace!("Seek position in stream: {}", position);
175
176                if cfg!(feature = "read_amount_mb") {
177                    let amount = position / (1000 * 1000);
178                    if self.last_reported_read_amount < amount {
179                        #[cfg(feature = "tracing")]
180                        info!("Read {}MB from stream", amount);
181                        self.last_reported_read_amount = amount;
182                    }
183                }
184            }
185
186            self.packets_read += 1;
187            #[cfg(feature = "tracing")]
188            trace!("Packets read from input: {}", self.packets_read);
189
190            let packet = match TsPacket::from_bytes(&mut packet_buf) {
191                Ok(packet) => packet,
192                Err(e) => {
193                    #[cfg(feature = "tracing")]
194                    debug!(
195                        "Got error when trying to parse next packet from bytes {:2X?}",
196                        packet_buf
197                    );
198                    return Err(e);
199                }
200            };
201
202            // We should only return a packet if it is in the tracked PIDs (or
203            // there are no tracked PIDs)
204            if !self.tracked_pids.is_empty()
205                && !self.tracked_pids.contains(&packet.header().pid())
206            {
207                continue;
208            }
209
210            #[cfg(feature = "tracing")]
211            debug!("Returning TS packet for PID: {}", packet.header().pid());
212            return Ok(Some(packet));
213        }
214    }
215
216    /// Read the next payload from the transport stream file.
217    ///
218    /// This function returns `None` for any `Err` in order to prevent the need
219    /// for `.unwrap()` calls in more concise code.
220    /// # Returns
221    /// `Some(TSPayload)` if the next transport stream packet could be parsed
222    /// from the file. `None` if the next transport stream payload could not
223    /// be parsed from the file for any reason. This includes if the entire
224    /// file has been fully read.
225    pub fn next_payload_unchecked(&mut self) -> Option<Vec<u8>> {
226        self.next_payload().unwrap_or(None)
227    }
228
229    /// Read the next full payload from the file.
230    ///
231    /// This function parses through all transport stream packets, stores them
232    /// in a buffer and concatenates their payloads together once a payload
233    /// has been complete.
234    ///
235    /// NOTE: I make the assumption that all packets containing KLV data are PSI
236    /// packets and therefore the first byte of the payload indicates when the
237    /// start of the new payload is.
238    ///
239    /// NOTE: By looking at the source I determined that `mpeg2ts` classifies
240    /// the payload type by the PID and only checks against known, reserved
241    /// PIDs, while all other are instantiated as
242    /// `TransportPayload::Raw(Bytes)`. Because all PIDs for the data we are
243    /// looking for are going to be non-constant, we can just discard all
244    /// variants besides `TransportPayload::Raw`.
245    ///
246    /// TODO: A performance enhancment that could be made is to look for the
247    /// first `TransportPayload::PAT` payload and use that (and the
248    /// `TransportPayload::PMT` that it points to) to determine what PIDs are
249    /// PSI and which are PET. We can then disregard all PET streams instead
250    /// of naively treating them as PSI streams and trying to find a KLV
251    /// value in it.
252    pub fn next_payload(&mut self) -> Result<Option<Vec<u8>>, ErrorKind> {
253        #[cfg(feature = "tracing")]
254        trace!("Getting next payload");
255        loop {
256            let possible_packet = self.next_packet()?;
257
258            let Some(packet) = possible_packet else {
259                return Ok(None);
260            };
261
262            // Add this packet's payload to the tracked payload and retrieve the
263            // completed payload if it exists.
264            let payload = self.add_tracked_payload(packet);
265            if payload.is_some() {
266                return Ok(payload);
267            }
268        }
269    }
270
271    /// Return the alignment of the SYNC bytes in this reader.
272    pub fn sync_byte_alignment(&self) -> u64 {
273        self.sync_alignment
274    }
275
276    /// Add a PID to the tracking list.
277    ///
278    /// Only tracked PIDs are returned when running methods that gather packets
279    /// or payloads. If no PIDs are set to be tracked, then all PIDs are
280    /// tracked.
281    pub fn add_tracked_pid(&mut self, pid: u16) {
282        self.tracked_pids.push(pid);
283    }
284
285    /// Remove this PID from being tracked.
286    ///
287    /// Only tracked PIDs are returned when running methods that gather packets
288    /// or payloads. If no PIDs are set to be tracked, then all PIDs are
289    /// tracked.
290    pub fn remove_tracked_pid(&mut self, pid: u16) {
291        self.tracked_pids.retain(|vec_pid| *vec_pid != pid);
292    }
293
294    /// Add payload data from a packet to the tracked payloads list.
295    fn add_tracked_payload(&mut self, packet: TsPacket) -> Option<Vec<u8>> {
296        // Check to see if we already have an TrackedPayload object for this
297        // item PID
298        let pid = &packet.header().pid();
299
300        if let Some(ref mut tracked_payload) =
301            self.tracked_payloads.iter_mut().find(|tp| &tp.pid() == pid)
302        {
303            let payload = packet.to_payload()?;
304            return tracked_payload.add(payload);
305        }
306
307        // We cannot possibly know that a payload is complete from the first
308        // packet. In order to know that a payload is fully contained in
309        // 1 packet we need to see the `PUSI` flag set in
310        // the next packet so there is no reason to check if the packet is
311        // complete when creating a new TrackedPayload.
312
313        if let Ok(tp) = TrackedPayload::from_packet(packet) {
314            self.tracked_payloads.push(tp);
315        };
316
317        None
318    }
319}
320
321pub struct TsPackets<'a, T>(&'a mut TsReader<T>);
322impl<'a, T> Iterator for TsPackets<'a, T>
323where
324    T: Read + Seek,
325{
326    type Item = Result<TsPacket, ErrorKind>;
327
328    fn next(&mut self) -> Option<Self::Item> {
329        self.0.next_packet().transpose()
330    }
331}
332
333#[derive(derive_more::Deref)]
334pub struct TsPayloads<'a, T>(&'a mut TsReader<T>);
335impl<'a, T> Iterator for TsPayloads<'a, T>
336where
337    T: Read + Seek,
338{
339    type Item = Result<Vec<u8>, ErrorKind>;
340
341    fn next(&mut self) -> Option<Self::Item> {
342        self.0.next_payload().transpose()
343    }
344}