1use std::cmp;
9use std::io;
10use std::io::{IoSliceMut, Read, Seek};
11use std::ops::Sub;
12
13use super::SeekBuffered;
14use super::{MediaSource, ReadBytes};
15
/// Builds the error returned whenever the stream ends before a request could be satisfied.
#[inline(always)]
fn unexpected_eof_error<T>() -> io::Result<T> {
    Err(io::ErrorKind::UnexpectedEof.into())
}
20
/// Buffering options for a `MediaSourceStream`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MediaSourceStreamOptions {
    /// Length of the ring buffer in bytes.
    ///
    /// `MediaSourceStream::new` panics unless this is a power of two strictly greater than
    /// 32 KiB (the largest single read issued to the inner source).
    pub buffer_len: usize,
}
26
27impl Default for MediaSourceStreamOptions {
28 fn default() -> Self {
29 MediaSourceStreamOptions { buffer_len: 64 * 1024 }
30 }
31}
32
33pub struct MediaSourceStream<'s> {
53 inner: Box<dyn MediaSource + 's>,
55 ring: Box<[u8]>,
57 ring_mask: usize,
59 read_pos: usize,
61 write_pos: usize,
63 read_block_len: usize,
65 abs_pos: u64,
67 rel_pos: u64,
70}
71
72impl<'s> MediaSourceStream<'s> {
73 const MIN_BLOCK_LEN: usize = 1 * 1024;
74 const MAX_BLOCK_LEN: usize = 32 * 1024;
75
76 pub fn new(source: Box<dyn MediaSource + 's>, options: MediaSourceStreamOptions) -> Self {
77 assert!(options.buffer_len.count_ones() == 1);
79 assert!(options.buffer_len > Self::MAX_BLOCK_LEN);
80
81 MediaSourceStream {
82 inner: source,
83 ring: vec![0; options.buffer_len].into_boxed_slice(),
84 ring_mask: options.buffer_len - 1,
85 read_pos: 0,
86 write_pos: 0,
87 read_block_len: Self::MIN_BLOCK_LEN,
88 abs_pos: 0,
89 rel_pos: 0,
90 }
91 }
92
93 #[inline(always)]
96 fn is_buffer_exhausted(&self) -> bool {
97 self.read_pos == self.write_pos
98 }
99
100 fn fetch(&mut self) -> io::Result<()> {
102 if self.is_buffer_exhausted() {
104 let (vec1, vec0) = self.ring.split_at_mut(self.write_pos);
107
108 let actual_read_len = if vec0.len() >= self.read_block_len {
112 self.inner.read(&mut vec0[..self.read_block_len])?
113 }
114 else {
115 let rem = self.read_block_len - vec0.len();
117
118 let ring_vectors = &mut [IoSliceMut::new(vec0), IoSliceMut::new(&mut vec1[..rem])];
119
120 self.inner.read_vectored(ring_vectors)?
121 };
122
123 self.write_pos = (self.write_pos + actual_read_len) & self.ring_mask;
125
126 self.abs_pos += actual_read_len as u64;
128 self.rel_pos += actual_read_len as u64;
129
130 self.read_block_len = cmp::min(self.read_block_len << 1, Self::MAX_BLOCK_LEN);
133 }
134
135 Ok(())
136 }
137
138 fn fetch_or_eof(&mut self) -> io::Result<()> {
141 self.fetch()?;
142
143 if self.is_buffer_exhausted() {
144 return unexpected_eof_error();
145 }
146
147 Ok(())
148 }
149
150 #[inline(always)]
152 fn consume(&mut self, len: usize) {
153 self.read_pos = (self.read_pos + len) & self.ring_mask;
154 }
155
156 #[inline(always)]
158 fn continguous_buf(&self) -> &[u8] {
159 if self.write_pos >= self.read_pos {
160 &self.ring[self.read_pos..self.write_pos]
161 }
162 else {
163 &self.ring[self.read_pos..]
164 }
165 }
166
167 fn reset(&mut self, pos: u64) {
169 self.read_pos = 0;
170 self.write_pos = 0;
171 self.read_block_len = Self::MIN_BLOCK_LEN;
172 self.abs_pos = pos;
173 self.rel_pos = 0;
174 }
175}
176
177impl MediaSource for MediaSourceStream<'_> {
178 #[inline]
179 fn is_seekable(&self) -> bool {
180 self.inner.is_seekable()
181 }
182
183 #[inline]
184 fn byte_len(&self) -> Option<u64> {
185 self.inner.byte_len()
186 }
187}
188
189impl io::Read for MediaSourceStream<'_> {
190 fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
191 let read_len = buf.len();
192
193 while !buf.is_empty() {
194 self.fetch()?;
196
197 match self.continguous_buf().read(buf) {
200 Ok(0) => break,
201 Ok(count) => {
202 buf = &mut buf[count..];
203 self.consume(count);
204 }
205 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
206 Err(e) => return Err(e),
207 }
208 }
209
210 Ok(read_len - buf.len())
213 }
214}
215
216impl io::Seek for MediaSourceStream<'_> {
217 fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
218 let pos = match pos {
223 io::SeekFrom::Current(0) => return Ok(self.pos()),
224 io::SeekFrom::Current(delta_pos) => {
225 let delta = delta_pos - self.unread_buffer_len() as i64;
226 self.inner.seek(io::SeekFrom::Current(delta))
227 }
228 _ => self.inner.seek(pos),
229 }?;
230
231 self.reset(pos);
232
233 Ok(pos)
234 }
235}
236
237impl ReadBytes for MediaSourceStream<'_> {
238 #[inline(always)]
239 fn read_byte(&mut self) -> io::Result<u8> {
240 if self.is_buffer_exhausted() {
244 self.fetch_or_eof()?;
245 }
246
247 let value = self.ring[self.read_pos];
248 self.consume(1);
249
250 Ok(value)
251 }
252
253 fn read_double_bytes(&mut self) -> io::Result<[u8; 2]> {
254 let mut bytes = [0; 2];
255
256 let buf = self.continguous_buf();
257
258 if buf.len() >= 2 {
259 bytes.copy_from_slice(&buf[..2]);
260 self.consume(2);
261 }
262 else {
263 for byte in bytes.iter_mut() {
264 *byte = self.read_byte()?;
265 }
266 };
267
268 Ok(bytes)
269 }
270
271 fn read_triple_bytes(&mut self) -> io::Result<[u8; 3]> {
272 let mut bytes = [0; 3];
273
274 let buf = self.continguous_buf();
275
276 if buf.len() >= 3 {
277 bytes.copy_from_slice(&buf[..3]);
278 self.consume(3);
279 }
280 else {
281 for byte in bytes.iter_mut() {
282 *byte = self.read_byte()?;
283 }
284 };
285 Ok(bytes)
286 }
287
288 fn read_quad_bytes(&mut self) -> io::Result<[u8; 4]> {
289 let mut bytes = [0; 4];
290
291 let buf = self.continguous_buf();
292
293 if buf.len() >= 4 {
294 bytes.copy_from_slice(&buf[..4]);
295 self.consume(4);
296 }
297 else {
298 for byte in bytes.iter_mut() {
299 *byte = self.read_byte()?;
300 }
301 };
302 Ok(bytes)
303 }
304
305 fn read_buf(&mut self, buf: &mut [u8]) -> io::Result<usize> {
306 let read = self.read(buf)?;
308
309 if !buf.is_empty() && read == 0 { unexpected_eof_error() } else { Ok(read) }
313 }
314
315 fn read_buf_exact(&mut self, mut buf: &mut [u8]) -> io::Result<()> {
316 while !buf.is_empty() {
317 match self.read(buf) {
318 Ok(0) => break,
319 Ok(count) => {
320 buf = &mut buf[count..];
321 }
322 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
323 Err(e) => return Err(e),
324 }
325 }
326
327 if !buf.is_empty() { unexpected_eof_error() } else { Ok(()) }
328 }
329
330 fn scan_bytes_aligned<'a>(
331 &mut self,
332 _: &[u8],
333 _: usize,
334 _: &'a mut [u8],
335 ) -> io::Result<&'a mut [u8]> {
336 unimplemented!();
338 }
339
340 fn ignore_bytes(&mut self, mut count: u64) -> io::Result<()> {
341 let ring_len = self.ring.len() as u64;
345
346 while count >= 2 * ring_len && self.is_seekable() {
348 let delta = count.clamp(0, i64::MAX as u64).sub(ring_len);
349 self.seek(io::SeekFrom::Current(delta as i64))?;
350 count -= delta;
351 }
352
353 while count > 0 {
355 self.fetch_or_eof()?;
356 let discard_count = cmp::min(self.unread_buffer_len() as u64, count);
357 self.consume(discard_count as usize);
358 count -= discard_count;
359 }
360 Ok(())
361 }
362
363 fn pos(&self) -> u64 {
364 self.abs_pos - self.unread_buffer_len() as u64
365 }
366}
367
368impl SeekBuffered for MediaSourceStream<'_> {
369 fn ensure_seekback_buffer(&mut self, len: usize) {
370 let ring_len = self.ring.len();
371
372 let new_ring_len = (Self::MAX_BLOCK_LEN + len).next_power_of_two();
376
377 if ring_len < new_ring_len {
379 let mut new_ring = vec![0; new_ring_len].into_boxed_slice();
381
382 let (vec0, vec1) = if self.write_pos >= self.read_pos {
384 (&self.ring[self.read_pos..self.write_pos], None)
385 }
386 else {
387 (&self.ring[self.read_pos..], Some(&self.ring[..self.write_pos]))
388 };
389
390 let vec0_len = vec0.len();
392 new_ring[..vec0_len].copy_from_slice(vec0);
393
394 self.write_pos = if let Some(vec1) = vec1 {
395 let total_len = vec0_len + vec1.len();
396 new_ring[vec0_len..total_len].copy_from_slice(vec1);
397 total_len
398 }
399 else {
400 vec0_len
401 };
402
403 self.ring = new_ring;
404 self.ring_mask = new_ring_len - 1;
405 self.read_pos = 0;
406 }
407 }
408
409 fn unread_buffer_len(&self) -> usize {
410 if self.write_pos >= self.read_pos {
411 self.write_pos - self.read_pos
412 }
413 else {
414 self.write_pos + (self.ring.len() - self.read_pos)
415 }
416 }
417
418 fn read_buffer_len(&self) -> usize {
419 let unread_len = self.unread_buffer_len();
420
421 cmp::min(self.ring.len(), self.rel_pos as usize) - unread_len
422 }
423
424 fn seek_buffered(&mut self, pos: u64) -> u64 {
425 let old_pos = self.pos();
426
427 let delta = if pos > old_pos {
429 assert!(pos - old_pos < isize::MAX as u64);
430 (pos - old_pos) as isize
431 }
432 else if pos < old_pos {
433 assert!(old_pos - pos < isize::MAX as u64);
435 -((old_pos - pos) as isize)
436 }
437 else {
438 0
439 };
440
441 self.seek_buffered_rel(delta)
442 }
443
444 fn seek_buffered_rel(&mut self, delta: isize) -> u64 {
445 if delta < 0 {
446 let abs_delta = cmp::min((-delta) as usize, self.read_buffer_len());
447 self.read_pos = (self.read_pos + self.ring.len() - abs_delta) & self.ring_mask;
448 }
449 else if delta > 0 {
450 let abs_delta = cmp::min(delta as usize, self.unread_buffer_len());
451 self.read_pos = (self.read_pos + abs_delta) & self.ring_mask;
452 }
453
454 self.pos()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::{MediaSourceStream, ReadBytes, SeekBuffered};
    use std::io::{Cursor, Read};

    /// Deterministically generates `len` pseudo-random bytes from a fixed-seed 32-bit linear
    /// congruential generator. The fixed seed keeps the expected values asserted below stable.
    fn generate_random_bytes(len: usize) -> Box<[u8]> {
        let mut lcg: u32 = 0xec57c4bf;

        let mut bytes = vec![0; len];

        // Each generator step supplies (up to) four output bytes in little-endian order.
        for quad in bytes.chunks_mut(4) {
            lcg = lcg.wrapping_mul(1664525).wrapping_add(1013904223);
            // Note: `src` is the output slot and `dest` the generated byte.
            for (src, dest) in quad.iter_mut().zip(&lcg.to_le_bytes()) {
                *src = *dest;
            }
        }

        bytes.into_boxed_slice()
    }

    /// Checks single/double/triple/quad-byte reads, interleaved with small skips, against the
    /// source data. 96 KiB of single-byte reads is more than the default 64 KiB ring, so ring
    /// wrap-around is exercised.
    #[test]
    fn verify_mss_read() {
        let data = generate_random_bytes(5 * 96 * 1024);

        let ms = Cursor::new(data.clone());
        let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());

        // `buf` tracks the expected remaining data as the stream is advanced.
        let mut buf = &data[..];

        for byte in &buf[..96 * 1024] {
            assert_eq!(*byte, mss.read_byte().unwrap());
        }

        mss.ignore_bytes(11).unwrap();

        buf = &buf[11 + (96 * 1024)..];

        for bytes in buf[..2 * 48 * 1024].chunks_exact(2) {
            assert_eq!(bytes, &mss.read_double_bytes().unwrap());
        }

        mss.ignore_bytes(33).unwrap();

        buf = &buf[33 + (2 * 48 * 1024)..];

        for bytes in buf[..3 * 32 * 1024].chunks_exact(3) {
            assert_eq!(bytes, &mss.read_triple_bytes().unwrap());
        }

        mss.ignore_bytes(55).unwrap();

        buf = &buf[55 + (3 * 32 * 1024)..];

        for bytes in buf[..4 * 24 * 1024].chunks_exact(4) {
            assert_eq!(bytes, &mss.read_quad_bytes().unwrap());
        }
    }

    /// Checks that `read_to_end` returns exactly the source data.
    #[test]
    fn verify_mss_read_to_end() {
        let data = generate_random_bytes(5 * 96 * 1024);

        let ms = Cursor::new(data.clone());
        let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());
        let mut output: Vec<u8> = Vec::new();
        assert_eq!(mss.read_to_end(&mut output).unwrap(), data.len());
        assert_eq!(output.into_boxed_slice(), data);
    }

    /// Checks buffered relative seeking backwards and forwards, and the read/unread buffer
    /// accounting around it.
    #[test]
    fn verify_mss_seek_buffered() {
        let data = generate_random_bytes(1024 * 1024);

        let ms = Cursor::new(data);
        let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());

        assert_eq!(mss.read_buffer_len(), 0);
        assert_eq!(mss.unread_buffer_len(), 0);

        mss.ignore_bytes(5122).unwrap();

        assert_eq!(5122, mss.pos());
        assert_eq!(mss.read_buffer_len(), 5122);

        // Byte at position 5122; read again after seeking back and forth.
        let upper = mss.read_byte().unwrap();

        // Position is now 5123; step back 1000 bytes.
        assert_eq!(mss.seek_buffered_rel(-1000), 4123);
        assert_eq!(mss.pos(), 4123);
        assert_eq!(mss.read_buffer_len(), 4123);

        // Step forward to just before `upper`.
        assert_eq!(mss.seek_buffered_rel(999), 5122);
        assert_eq!(mss.pos(), 5122);
        assert_eq!(mss.read_buffer_len(), 5122);

        assert_eq!(upper, mss.read_byte().unwrap());
    }

    /// Checks big-endian numeric reads against values derived from the fixed-seed data.
    #[test]
    fn verify_reading_be() {
        let data = generate_random_bytes(1024 * 1024);

        let ms = Cursor::new(data);
        let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());

        // Offset by 2 so the reads below are not 4-byte aligned.
        mss.ignore_bytes(2).unwrap();

        assert_eq!(mss.read_be_f32().unwrap(), -72818055000000000000000000000.0);
        assert_eq!(mss.read_be_f64().unwrap(), -0.000000000000011582640453292664);

        assert_eq!(mss.read_be_u16().unwrap(), 32624);
        assert_eq!(mss.read_be_u24().unwrap(), 6739677);
        assert_eq!(mss.read_be_u32().unwrap(), 1569552917);
        assert_eq!(mss.read_be_u64().unwrap(), 6091217585348000864);
    }

    /// Checks little-endian numeric reads against values derived from the fixed-seed data.
    #[test]
    fn verify_reading_le() {
        let data = generate_random_bytes(1024 * 1024);

        let ms = Cursor::new(data);
        let mut mss = MediaSourceStream::new(Box::new(ms), Default::default());

        mss.ignore_bytes(1024).unwrap();

        assert_eq!(mss.read_f32().unwrap(), -0.00000000000000000000000000048426285);
        assert_eq!(mss.read_f64().unwrap(), -6444325820119113.0);

        assert_eq!(mss.read_u16().unwrap(), 36195);
        assert_eq!(mss.read_u24().unwrap(), 6710386);
        assert_eq!(mss.read_u32().unwrap(), 2378776723);
        assert_eq!(mss.read_u64().unwrap(), 5170196279331153683);
    }
}