skua_voice/input/adapters/
async_adapter.rs

1use crate::input::AudioStreamError;
2use async_trait::async_trait;
3use flume::{Receiver, RecvError, Sender, TryRecvError};
4use futures::{future::Either, stream::FuturesUnordered, FutureExt, StreamExt};
5use ringbuf::*;
6use std::{
7    io::{
8        Error as IoError,
9        ErrorKind as IoErrorKind,
10        Read,
11        Result as IoResult,
12        Seek,
13        SeekFrom,
14        Write,
15    },
16    sync::{
17        atomic::{AtomicBool, Ordering},
18        Arc,
19    },
20};
21use symphonia_core::io::MediaSource;
22use tokio::{
23    io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt},
24    sync::Notify,
25};
26
27struct AsyncAdapterSink {
28    bytes_in: HeapProducer<u8>,
29    req_rx: Receiver<AdapterRequest>,
30    resp_tx: Sender<AdapterResponse>,
31    stream: Box<dyn AsyncMediaSource>,
32    notify_rx: Arc<Notify>,
33}
34
35impl AsyncAdapterSink {
36    async fn launch(mut self) {
37        let mut inner_buf = [0u8; 32 * 1024];
38        let mut read_region = 0..0;
39        let mut hit_end = false;
40        let mut blocked = false;
41        let mut pause_buf_moves = false;
42        let mut seek_res = None;
43        let mut seen_bytes = 0;
44
45        loop {
46            // if read_region is empty, refill from src.
47            //  if that read is zero, tell other half.
48            // if WouldBlock, block on msg acquire,
49            // else non_block msg acquire.
50
51            if !pause_buf_moves {
52                if !hit_end && read_region.is_empty() {
53                    if let Ok(n) = self.stream.read(&mut inner_buf).await {
54                        read_region = 0..n;
55                        if n == 0 {
56                            drop(self.resp_tx.send_async(AdapterResponse::ReadZero).await);
57                            hit_end = true;
58                        }
59                        seen_bytes += n as u64;
60                    } else {
61                        match self.stream.try_resume(seen_bytes).await {
62                            Ok(s) => {
63                                self.stream = s;
64                            },
65                            Err(_e) => break,
66                        }
67                    }
68                }
69
70                while !read_region.is_empty() && !blocked {
71                    if let Ok(n_moved) = self
72                        .bytes_in
73                        .write(&inner_buf[read_region.start..read_region.end])
74                    {
75                        read_region.start += n_moved;
76                        drop(self.resp_tx.send_async(AdapterResponse::ReadOccurred).await);
77                    } else {
78                        blocked = true;
79                    }
80                }
81            }
82
83            let msg = if blocked || hit_end {
84                let mut fs = FuturesUnordered::new();
85                fs.push(Either::Left(self.req_rx.recv_async()));
86                fs.push(Either::Right(self.notify_rx.notified().map(|()| {
87                    let o: Result<AdapterRequest, RecvError> = Ok(AdapterRequest::Wake);
88                    o
89                })));
90
91                match fs.next().await {
92                    Some(Ok(a)) => a,
93                    _ => break,
94                }
95            } else {
96                match self.req_rx.try_recv() {
97                    Ok(a) => a,
98                    Err(TryRecvError::Empty) => continue,
99                    _ => break,
100                }
101            };
102
103            match msg {
104                AdapterRequest::Wake => blocked = false,
105                AdapterRequest::ByteLen => {
106                    drop(
107                        self.resp_tx
108                            .send_async(AdapterResponse::ByteLen(self.stream.byte_len().await))
109                            .await,
110                    );
111                },
112                AdapterRequest::Seek(pos) => {
113                    pause_buf_moves = true;
114                    drop(self.resp_tx.send_async(AdapterResponse::SeekClear).await);
115                    seek_res = Some(self.stream.seek(pos).await);
116                },
117                AdapterRequest::SeekCleared => {
118                    if let Some(res) = seek_res.take() {
119                        drop(
120                            self.resp_tx
121                                .send_async(AdapterResponse::SeekResult(res))
122                                .await,
123                        );
124                    }
125                    pause_buf_moves = false;
126                },
127            }
128        }
129    }
130}
131
132/// An adapter for converting an async media source into a synchronous one
133/// usable by symphonia.
134///
135/// This adapter takes a source implementing `AsyncRead`, and allows the receive side to
136/// pass along seek requests needed. This allows for passing bytes from exclusively `AsyncRead`
137/// streams (e.g., hyper HTTP sessions) to Songbird.
138pub struct AsyncAdapterStream {
139    bytes_out: HeapConsumer<u8>,
140    can_seek: bool,
141    // Note: these are Atomic just to work around the need for
142    // check_messages to take &self rather than &mut.
143    finalised: AtomicBool,
144    bytes_known_present: AtomicBool,
145    req_tx: Sender<AdapterRequest>,
146    resp_rx: Receiver<AdapterResponse>,
147    notify_tx: Arc<Notify>,
148}
149
150impl AsyncAdapterStream {
151    /// Wrap and pull from an async file stream, with an intermediate ring-buffer of size `buf_len`
152    /// between the async and sync halves.
153    #[must_use]
154    pub fn new(stream: Box<dyn AsyncMediaSource>, buf_len: usize) -> AsyncAdapterStream {
155        let (bytes_in, bytes_out) = SharedRb::new(buf_len).split();
156        let (resp_tx, resp_rx) = flume::unbounded();
157        let (req_tx, req_rx) = flume::unbounded();
158        let can_seek = stream.is_seekable();
159        let notify_rx = Arc::new(Notify::new());
160        let notify_tx = notify_rx.clone();
161
162        let sink = AsyncAdapterSink {
163            bytes_in,
164            req_rx,
165            resp_tx,
166            stream,
167            notify_rx,
168        };
169        let stream = AsyncAdapterStream {
170            bytes_out,
171            can_seek,
172            finalised: false.into(),
173            bytes_known_present: false.into(),
174            req_tx,
175            resp_rx,
176            notify_tx,
177        };
178
179        tokio::spawn(async move {
180            Box::pin(sink.launch()).await;
181        });
182
183        stream
184    }
185
186    fn handle_messages(&self, op: Operation) -> Option<AdapterResponse> {
187        loop {
188            let msg = if op.will_block() {
189                self.resp_rx.recv().ok()
190            } else {
191                self.resp_rx.try_recv().ok()
192            };
193
194            let Some(msg) = msg else { break None };
195
196            // state changes
197            match &msg {
198                AdapterResponse::ReadZero => {
199                    self.finalised.store(true, Ordering::Relaxed);
200                },
201                AdapterResponse::ReadOccurred => {
202                    self.bytes_known_present.store(true, Ordering::Relaxed);
203                },
204                _ => {},
205            }
206
207            if op.expected_msg(&msg) {
208                break Some(msg);
209            }
210        }
211    }
212
213    fn is_dropped_and_clear(&self) -> bool {
214        self.resp_rx.is_empty() && self.resp_rx.is_disconnected()
215    }
216
217    fn check_dropped(&self) -> IoResult<()> {
218        if self.is_dropped_and_clear() {
219            Err(IoError::new(
220                IoErrorKind::UnexpectedEof,
221                "Async half was dropped.",
222            ))
223        } else {
224            Ok(())
225        }
226    }
227}
228
229impl Read for AsyncAdapterStream {
230    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
231        loop {
232            let block = !(self.bytes_known_present.load(Ordering::Relaxed)
233                || self.finalised.load(Ordering::Relaxed));
234            drop(self.handle_messages(Operation::Read { block }));
235
236            match self.bytes_out.read(buf) {
237                Ok(n) => {
238                    self.notify_tx.notify_one();
239                    return Ok(n);
240                },
241                Err(e) if e.kind() == IoErrorKind::WouldBlock => {
242                    // receive side must ABSOLUTELY be unblocked here.
243                    self.notify_tx.notify_one();
244                    if self.finalised.load(Ordering::Relaxed) {
245                        return Ok(0);
246                    }
247                    self.bytes_known_present.store(false, Ordering::Relaxed);
248                    self.check_dropped()?;
249                },
250                a => {
251                    println!("Misc err {a:?}");
252                    return a;
253                },
254            }
255        }
256    }
257}
258
259impl Seek for AsyncAdapterStream {
260    fn seek(&mut self, pos: SeekFrom) -> IoResult<u64> {
261        if !self.can_seek {
262            return Err(IoError::new(
263                IoErrorKind::Unsupported,
264                "Async half does not support seek operations.",
265            ));
266        }
267
268        self.check_dropped()?;
269
270        _ = self.req_tx.send(AdapterRequest::Seek(pos));
271
272        // wait for async to tell us that it has stopped writing,
273        // then clear buf and allow async to write again.
274        self.finalised.store(false, Ordering::Relaxed);
275        match self.handle_messages(Operation::Seek) {
276            Some(AdapterResponse::SeekClear) => {},
277            None => self.check_dropped().map(|()| unreachable!())?,
278            _ => unreachable!(),
279        }
280
281        self.bytes_out.skip(self.bytes_out.capacity());
282
283        _ = self.req_tx.send(AdapterRequest::SeekCleared);
284
285        match self.handle_messages(Operation::Seek) {
286            Some(AdapterResponse::SeekResult(a)) => a,
287            None => self.check_dropped().map(|()| unreachable!()),
288            _ => unreachable!(),
289        }
290    }
291}
292
293impl MediaSource for AsyncAdapterStream {
294    fn is_seekable(&self) -> bool {
295        self.can_seek
296    }
297
298    fn byte_len(&self) -> Option<u64> {
299        self.check_dropped().ok()?;
300
301        _ = self.req_tx.send(AdapterRequest::ByteLen);
302
303        match self.handle_messages(Operation::Len) {
304            Some(AdapterResponse::ByteLen(a)) => a,
305            None => self.check_dropped().ok().map(|()| unreachable!()),
306            _ => unreachable!(),
307        }
308    }
309}
310
/// Messages handled by the async half's request loop.
enum AdapterRequest {
    /// Produced by the async half itself when its `Notify` fires; clears its "blocked" state.
    Wake,
    /// Seek the source to the given position.
    Seek(SeekFrom),
    /// The sync half has emptied the ring buffer following a `Seek`.
    SeekCleared,
    /// Query the source's length in bytes.
    ByteLen,
}
317
/// Messages sent from the async half back to the sync half.
enum AdapterResponse {
    /// Outcome of the seek performed on the source.
    SeekResult(IoResult<u64>),
    /// Sent when a `Seek` request is received, before the seek itself is performed.
    SeekClear,
    /// Length of the source in bytes, if known.
    ByteLen(Option<u64>),
    /// The source returned a zero-length read.
    ReadZero,
    /// Bytes were written into the ring buffer.
    ReadOccurred,
}
325
/// The operation the sync half is performing while it drains responses.
#[derive(Copy, Clone)]
enum Operation {
    /// A read; `block` selects whether to wait for a response or only poll.
    Read { block: bool },
    /// Either step of the seek exchange.
    Seek,
    /// A byte-length query.
    Len,
}
332
333impl Operation {
334    fn will_block(self) -> bool {
335        match self {
336            Self::Read { block } => block,
337            _ => true,
338        }
339    }
340
341    fn expected_msg(self, msg: &AdapterResponse) -> bool {
342        match self {
343            Self::Read { .. } => matches!(
344                msg,
345                AdapterResponse::ReadOccurred | AdapterResponse::ReadZero
346            ),
347            Self::Seek => matches!(
348                msg,
349                AdapterResponse::SeekResult(_) | AdapterResponse::SeekClear
350            ),
351            Self::Len => matches!(msg, AdapterResponse::ByteLen(_)),
352        }
353    }
354}
355
356/// An async port of symphonia's [`MediaSource`].
357///
358/// Streams which are not seekable should implement `AsyncSeek` such that all operations
359/// fail with `Unsupported`, and implement `fn is_seekable(&self) -> { false }`.
360///
361/// [`MediaSource`]: MediaSource
362#[async_trait]
363pub trait AsyncMediaSource: AsyncRead + AsyncSeek + Send + Sync + Unpin {
364    /// Returns if the source is seekable. This may be an expensive operation.
365    fn is_seekable(&self) -> bool;
366
367    /// Returns the length in bytes, if available. This may be an expensive operation.
368    async fn byte_len(&self) -> Option<u64>;
369
370    /// Tries to recreate this stream in event of an error, resuming from the given offset.
371    async fn try_resume(
372        &mut self,
373        _offset: u64,
374    ) -> Result<Box<dyn AsyncMediaSource>, AudioStreamError> {
375        Err(AudioStreamError::Unsupported)
376    }
377}