termusic_stream/
source.rs1use async_trait::async_trait;
2use bytes::Bytes;
3use futures::{Stream, StreamExt};
4use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard};
5use rangemap::RangeSet;
6use std::{
7 error::Error,
8 fs::File,
9 io::{self, BufWriter, Seek, SeekFrom, Write},
10 sync::{
11 atomic::{AtomicI64, Ordering},
12 Arc,
13 },
14};
15use tap::TapFallible;
16use tokio::sync::mpsc;
17use tracing::{debug, error, info, trace};
18
19#[async_trait]
20pub trait SourceStream:
21 Stream<Item = Result<Bytes, Self::Error>> + Unpin + Send + Sync + Sized + 'static
22{
23 type Url: Send;
24 type Error: Error + Send;
25
26 async fn create(
27 url: Self::Url,
28 is_radio: bool,
29 radio_title: Arc<Mutex<String>>,
30 ) -> io::Result<Self>;
31 async fn content_length(&self) -> Option<u64>;
32 async fn seek_range(&mut self, start: u64, end: Option<u64>) -> io::Result<()>;
33}
34
35#[derive(Debug, Clone)]
36pub struct SourceHandle {
37 downloaded: Arc<RwLock<RangeSet<u64>>>,
38 requested_position: Arc<AtomicI64>,
39 position_reached: Arc<(Mutex<Waiter>, Condvar)>,
40 content_length_retrieved: Arc<(Mutex<bool>, Condvar)>,
41 content_length: Arc<AtomicI64>,
42 seek_tx: mpsc::Sender<u64>,
43}
44
45impl SourceHandle {
46 pub fn downloaded(&self) -> RwLockReadGuard<rangemap::RangeSet<u64>> {
47 self.downloaded.read()
48 }
49
50 pub fn request_position(&self, position: u64) {
51 self.requested_position
52 .store(position as i64, Ordering::SeqCst);
53 }
54
55 pub fn wait_for_requested_position(&self) {
56 let (mutex, cvar) = &*self.position_reached;
57 let mut waiter = mutex.lock();
58 if !waiter.stream_done {
59 debug!("Waiting for requested position");
60 cvar.wait_while(&mut waiter, |waiter| {
61 !waiter.stream_done && !waiter.position_reached
62 });
63 if !waiter.stream_done {
64 waiter.position_reached = false;
65 }
66 debug!("Position reached");
67 }
68 }
69
70 pub fn seek(&self, position: u64) {
71 self.seek_tx.try_send(position).ok();
72 }
73
74 pub fn content_length(&self) -> Option<u64> {
75 let (mutex, cvar) = &*self.content_length_retrieved;
76 let mut done = mutex.lock();
77 if !*done {
78 cvar.wait_while(&mut done, |done| !*done);
79 }
80 let length = self.content_length.load(Ordering::SeqCst);
81 if length > -1 {
82 Some(length as u64)
83 } else {
84 None
85 }
86 }
87}
88
/// State shared between the downloader and readers blocked on a position.
#[derive(Default, Debug)]
struct Waiter {
    /// Set by the downloader once the requested position has been written;
    /// cleared again by the reader that consumed the notification.
    position_reached: bool,
    /// Set by the downloader when it stops, after which waiting readers are
    /// released.
    stream_done: bool,
}
94
95pub struct Source {
96 writer: BufWriter<File>,
97 downloaded: Arc<RwLock<RangeSet<u64>>>,
98 requested_position: Arc<AtomicI64>,
99 position_reached: Arc<(Mutex<Waiter>, Condvar)>,
100 content_length_retrieved: Arc<(Mutex<bool>, Condvar)>,
101 content_length: Arc<AtomicI64>,
102 seek_tx: mpsc::Sender<u64>,
103 seek_rx: mpsc::Receiver<u64>,
104}
105
/// Number of bytes (256 KiB) the downloader writes before it leaves the
/// prefetch phase and starts servicing seek requests.
const PREFETCH_BYTES: u64 = 256 * 1024;
107
108impl Source {
109 pub fn new(tempfile: File) -> Self {
110 let (seek_tx, seek_rx) = mpsc::channel(32);
111 Self {
112 writer: BufWriter::new(tempfile),
113 downloaded: Default::default(),
114 requested_position: Arc::new(AtomicI64::new(-1)),
115 position_reached: Default::default(),
116 content_length_retrieved: Default::default(),
117 seek_tx,
118 seek_rx,
119 content_length: Default::default(),
120 }
121 }
122
123 pub async fn download<S: SourceStream>(
124 mut self,
125 mut stream: S,
126 radio_downloaded: Arc<Mutex<u64>>,
127 ) -> io::Result<()> {
128 info!("Starting file download");
129 let content_length = stream.content_length().await;
130 if let Some(content_length) = content_length {
131 self.content_length
132 .swap(content_length as i64, Ordering::SeqCst);
133 } else {
134 self.content_length.swap(-1, Ordering::SeqCst);
135 }
136 {
137 let (mutex, cvar) = &*self.content_length_retrieved;
138 *mutex.lock() = true;
139 cvar.notify_all();
140 }
141 loop {
142 if let Some(Ok(bytes)) = stream
143 .next()
144 .await
145 .map(|b| b.tap_err(|e| error!("Error reading stream: {e}")))
146 {
147 self.writer.write_all(&bytes)?;
148 let stream_position = self.writer.stream_position()?;
149 trace!("Prefetch: {}/{} bytes", stream_position, PREFETCH_BYTES);
150 if stream_position >= PREFETCH_BYTES {
151 self.downloaded.write().insert(0..stream_position);
152 break;
153 }
154 } else {
155 info!("File shorter than prefetch length");
156 self.writer.flush()?;
157 self.downloaded
158 .write()
159 .insert(0..self.writer.stream_position()?);
160 let (mutex, cvar) = &*self.position_reached;
161 (mutex.lock()).stream_done = true;
162 cvar.notify_all();
163 return Ok(());
164 }
165 }
166 info!("Prefetch complete");
167 loop {
168 tokio::select! {
169 bytes = stream.next() => {
170 if let Some(Ok(bytes)) =
171 bytes.map(|b| b.tap_err(|e| error!("Error reading from stream: {e}"))) {
172 let position = self.writer.stream_position()?;
173 *radio_downloaded.lock() = position;
174 self.writer.write_all(&bytes)?;
175 let new_position = self.writer.stream_position()?;
176 self.downloaded.write().insert(position .. new_position);
180 let requested = self.requested_position.load(Ordering::SeqCst);
181 if requested > -1 {
182 debug!("downloader: requested {requested} current {}", new_position);
183 }
184 if requested > -1 && new_position as i64 >= requested {
185 info!("Notifying requested position reached: {requested}. New position: {new_position}");
186 self.requested_position.store(-1, Ordering::SeqCst);
187 let (mutex, cvar) = &*self.position_reached;
188 (mutex.lock()).position_reached = true;
189 cvar.notify_all();
190 }
191 } else {
192 info!("Stream finished downloading");
193 if let Some(content_length) = content_length {
194 let gap = {
195 let downloaded = self.downloaded.read();
196 let range = 0 .. content_length;
197 let mut gaps = downloaded.gaps(&range);
198 gaps.next()
199 };
200 if let Some(gap) = gap {
201 debug!("Downloading missing stream chunk: {gap:?}.");
202 stream.seek_range(gap.start, Some(gap.end)).await?;
203 self.writer.seek(SeekFrom::Start(gap.start))?;
204 continue;
205 }
206 }
207 self.writer.flush()?;
208 let (mutex, cvar) = &*self.position_reached;
209 (mutex.lock()).stream_done = true;
210 cvar.notify_all();
211 return Ok(());
212 }
213 },
214 pos = self.seek_rx.recv() => {
215 if let Some(pos) = pos {
216 debug!("Received seek position {pos}");
217 let do_seek = {
218 let downloaded = self.downloaded.read();
219 if let Some(range) = downloaded.get(&pos) {
220 !range.contains(&self.writer.stream_position()?)
221 } else {
222 true
223 }
224 };
225 if do_seek {
226 debug!("Seek position not yet downloaded");
227 stream.seek_range(pos, None).await?;
228 self.writer.seek(SeekFrom::Start(pos))?;
229 }
230 }
231 }
232 }
233 }
234 }
235
236 pub fn source_handle(&self) -> SourceHandle {
237 SourceHandle {
238 downloaded: self.downloaded.clone(),
239 requested_position: self.requested_position.clone(),
240 position_reached: self.position_reached.clone(),
241 seek_tx: self.seek_tx.clone(),
242 content_length_retrieved: self.content_length_retrieved.clone(),
243 content_length: self.content_length.clone(),
244 }
245 }
246}