use crate::{Sentinel, SharedFileType, WriteState};
use pin_project::{pin_project, pinned_drop};
use std::io::{Error, ErrorKind, SeekFrom};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io;
use tokio::io::{AsyncRead, AsyncSeek, ReadBuf};
use uuid::Uuid;
/// An async reader over a file whose contents may still be growing.
///
/// Reads are bounded by the write progress published in the shared
/// [`Sentinel`]'s state (`WriteState::Pending` / `Completed` / `Failed`).
/// When no further committed bytes are available, the reader parks its task
/// by registering its waker with the sentinel.
#[pin_project(PinnedDrop)]
pub struct SharedFileReader<T> {
    /// Key under which this reader's waker is registered with / removed from
    /// the sentinel.
    id: Uuid,
    /// The underlying file handle; structurally pinned so it can be polled.
    #[pin]
    file: T,
    /// Shared state: write progress, the original file handle used by
    /// `fork`, and the registered reader wakers. Cloned into every fork.
    sentinel: Arc<Sentinel<T>>,
    /// This reader's progress in bytes, compared against the committed
    /// length to bound each read.
    read: AtomicUsize,
}
/// Node identifier passed to `Uuid::now_v1` when minting reader ids.
///
/// NOTE(review): this is a fixed, arbitrary value rather than a real MAC
/// address; it presumably only needs to keep reader ids distinct within this
/// process — confirm before relying on the ids being globally unique.
static NODE_ID: &[u8; 6] = &[2, 3, 0, 6, 1, 2];
impl<T> SharedFileReader<T>
where
    T: SharedFileType<Type = T>,
{
    /// Wraps an already-opened `file` handle as a reader bound to `sentinel`.
    ///
    /// The reader receives a freshly generated v1 UUID as its id and starts
    /// with a read count of zero.
    pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
        let id = Uuid::now_v1(&NODE_ID);
        let read = AtomicUsize::new(0);
        Self {
            id,
            file,
            sentinel,
            read,
        }
    }

    /// Creates another reader over the same shared file.
    ///
    /// A new handle is obtained from the sentinel's original file via
    /// `open_ro` (presumably read-only). The fork shares this reader's
    /// sentinel, gets its own id, and starts with a read count of zero,
    /// independent of how far `self` has read.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `open_ro`.
    pub async fn fork(&self) -> Result<Self, T::OpenError> {
        // Generate the id before opening, matching the original evaluation order.
        let id = Uuid::now_v1(&NODE_ID);
        let file = self.sentinel.original.open_ro().await?;
        let sentinel = Arc::clone(&self.sentinel);
        Ok(Self {
            id,
            file,
            sentinel,
            read: AtomicUsize::new(0),
        })
    }
}
impl<T> SharedFileReader<T> {
    /// Reports what is currently known about the shared file's size.
    ///
    /// - Writing still pending: only the committed length is known, returned
    ///   as [`FileSize::AtLeast`].
    /// - Writing completed: the final length, returned as [`FileSize::Exactly`].
    /// - Writing failed: [`FileSize::Error`].
    pub fn file_size(&self) -> FileSize {
        let state = self.sentinel.state.load();
        match state {
            WriteState::Completed(total) => FileSize::Exactly(total),
            WriteState::Pending(committed, _) => FileSize::AtLeast(committed),
            WriteState::Failed => FileSize::Error,
        }
    }
}
/// Snapshot of a shared file's size, as returned by
/// `SharedFileReader::file_size`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare snapshots or use
/// them as keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum FileSize {
    /// Writing is still pending; at least this many bytes are committed.
    AtLeast(usize),
    /// Writing completed; the file is exactly this many bytes long.
    Exactly(usize),
    /// Writing failed; no size is available.
    Error,
}
#[pinned_drop]
impl<T> PinnedDrop for SharedFileReader<T> {
    fn drop(self: Pin<&mut Self>) {
        // Unregister this reader's waker from the shared sentinel so it is not
        // left behind after the reader is gone.
        let this = self.project();
        this.sentinel.remove_reader_waker(&this.id)
    }
}
impl<T> AsyncRead for SharedFileReader<T>
where
T: AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let read_so_far = self.read.load(Ordering::Acquire);
let current_total = match self.sentinel.state.load() {
WriteState::Pending(committed, _written) => {
if read_so_far == committed {
self.sentinel
.register_reader_waker(self.id.clone(), cx.waker());
return Poll::Pending;
}
committed
}
WriteState::Completed(count) => {
if read_so_far == count {
return Poll::Ready(Ok(()));
}
count
}
WriteState::Failed => {
return Poll::Ready(Err(Error::new(
ErrorKind::BrokenPipe,
ReadError::FileClosed,
)))
}
};
let read_at_most = (current_total - read_so_far).min(buf.remaining());
let mut smaller_buf = buf.take(read_at_most);
let read_offset = smaller_buf.filled().len();
let this = self.project();
if let Poll::Ready(result) = this.file.poll_read(cx, &mut smaller_buf) {
this.sentinel.remove_reader_waker(&this.id);
if let Err(e) = result {
return Poll::Ready(Err(e));
}
let read_now = smaller_buf.filled().len();
if read_now != read_offset {
buf.advance(read_now);
let read = read_so_far + (read_now - read_offset);
this.read.store(read, Ordering::Release);
return Poll::Ready(result);
}
match this.sentinel.state.load() {
WriteState::Pending(_, _) => {}
WriteState::Completed(_) => return Poll::Ready(Ok(())),
WriteState::Failed => {
return Poll::Ready(Err(Error::new(
ErrorKind::BrokenPipe,
ReadError::FileClosed,
)))
}
}
}
buf.advance(0);
this.sentinel
.register_reader_waker(this.id.clone(), cx.waker());
Poll::Pending
}
}
impl<T> AsyncSeek for SharedFileReader<T>
where
T: AsyncSeek,
{
fn start_seek(self: Pin<&mut Self>, position: SeekFrom) -> io::Result<()> {
let this = self.project();
this.file.start_seek(position)
}
fn poll_complete(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
let this = self.project();
this.file.poll_complete(cx)
}
}
/// Errors specific to reading a shared file.
#[derive(Debug, thiserror::Error)]
pub enum ReadError {
    /// An I/O error forwarded transparently: `Display` and `source` are
    /// delegated to the wrapped error.
    #[error(transparent)]
    Io(#[from] Error),
    /// The shared write state is `WriteState::Failed`. Readers in this file
    /// report it wrapped in an `io::Error` of kind `BrokenPipe`.
    #[error("The file was already closed")]
    FileClosed,
}