use crate::{FilePath, Sentinel, SharedFileType, WriteState};
use crossbeam::atomic::AtomicCell;
use pin_project::{pin_project, pinned_drop};
use std::io::{Error, ErrorKind, IoSlice};
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io;
use tokio::io::AsyncWrite;
/// Writing half of a shared file.
///
/// Wraps the underlying file `T` and records write progress (bytes written and
/// bytes committed) in the [`Sentinel`] that is shared with readers. Dropping the
/// writer finalizes that shared state (see the `PinnedDrop` impl).
#[pin_project(PinnedDrop)]
pub struct SharedFileWriter<T> {
    /// The underlying file; structurally pinned so its `AsyncWrite` methods can be polled.
    #[pin]
    file: T,
    /// State shared with readers: write progress plus the means to wake them.
    sentinel: Arc<Sentinel<T>>,
}
impl<T> SharedFileWriter<T> {
pub(crate) fn new(file: T, sentinel: Arc<Sentinel<T>>) -> Self {
Self { file, sentinel }
}
pub fn file_path(&self) -> &PathBuf
where
T: FilePath,
{
self.file.file_path()
}
pub async fn sync_all(&self) -> Result<(), T::SyncError>
where
T: SharedFileType,
{
self.file.sync_all().await?;
Self::sync_committed_and_written(&self.sentinel);
self.sentinel.wake_readers();
Ok(())
}
pub async fn sync_data(&self) -> Result<(), T::SyncError>
where
T: SharedFileType,
{
self.file.sync_data().await?;
Self::sync_committed_and_written(&self.sentinel);
self.sentinel.wake_readers();
Ok(())
}
pub async fn complete(self) -> Result<(), CompleteWritingError>
where
T: SharedFileType,
{
if let Err(_) = self.sync_all().await {
return Err(CompleteWritingError::SyncError);
}
self.complete_no_sync()
}
pub fn complete_no_sync(self) -> Result<(), CompleteWritingError> {
self.finalize_state()
}
fn sync_committed_and_written(sentinel: &Arc<Sentinel<T>>) {
match sentinel.state.load() {
WriteState::Pending(_committed, written) => {
sentinel.state.store(WriteState::Pending(written, written));
}
WriteState::Completed(_) => {}
WriteState::Failed => {}
}
}
fn finalize_state(&self) -> Result<(), CompleteWritingError> {
let result = match self.sentinel.state.load() {
WriteState::Pending(_committed, written) => {
assert_eq!(_committed, written, "The number of committed bytes is less than the number of written bytes - call sync before dropping");
self.sentinel.state.store(WriteState::Completed(written));
Ok(())
}
WriteState::Completed(_) => Ok(()),
WriteState::Failed => Err(CompleteWritingError::FileWritingFailed),
};
self.sentinel.wake_readers();
result
}
fn update_state(state: &AtomicCell<WriteState>, written: usize) -> Result<usize, Error> {
match state.load() {
WriteState::Pending(committed, previously_written) => {
let count = previously_written + written;
state.store(WriteState::Pending(committed, count));
Ok(count)
}
WriteState::Completed(count) => {
if written != 0 {
return Err(Error::new(ErrorKind::BrokenPipe, WriteError::FileClosed));
}
Ok(count)
}
WriteState::Failed => Err(Error::from(ErrorKind::Other)),
}
}
fn handle_poll_write_result(
sentinel: &Sentinel<T>,
poll: Poll<Result<usize, Error>>,
) -> Poll<Result<usize, Error>> {
match poll {
Poll::Ready(result) => match result {
Ok(written) => match Self::update_state(&sentinel.state, written) {
Ok(_) => Poll::Ready(Ok(written)),
Err(e) => Poll::Ready(Err(e)),
},
Err(e) => {
sentinel.state.store(WriteState::Failed);
sentinel.wake_readers();
Poll::Ready(Err(e))
}
},
Poll::Pending => Poll::Pending,
}
}
}
#[pinned_drop]
impl<T> PinnedDrop for SharedFileWriter<T> {
fn drop(mut self: Pin<&mut Self>) {
self.finalize_state().ok();
}
}
#[derive(Debug, thiserror::Error)]
pub enum CompleteWritingError {
#[error(transparent)]
Io(#[from] Error),
#[error("Writing to the file failed")]
FileWritingFailed,
#[error("Failed to synchronize the file with the underlying buffer")]
SyncError,
}
#[derive(Debug, thiserror::Error)]
pub enum WriteError {
#[error(transparent)]
Io(#[from] Error),
#[error("The file was already closed")]
FileClosed,
}
/// Forwards writes to the inner file while tracking progress in the shared state.
impl<T> AsyncWrite for SharedFileWriter<T>
where
    T: AsyncWrite,
{
    /// Writes to the inner file. Successful writes count as written but not
    /// yet committed; errors mark the shared state as failed.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.project();
        let poll = this.file.poll_write(cx, buf);
        Self::handle_poll_write_result(&this.sentinel, poll)
    }

    /// Flushes the inner file. On success all written bytes become committed;
    /// on error the state is marked as failed. Readers are woken either way.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.project();
        let result = match this.file.poll_flush(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };
        match &result {
            Ok(()) => Self::sync_committed_and_written(&this.sentinel),
            Err(_) => this.sentinel.state.store(WriteState::Failed),
        }
        this.sentinel.wake_readers();
        Poll::Ready(result)
    }

    /// Shuts down the inner file. On success a pending state becomes
    /// `Completed`; on error it becomes `Failed`. Readers are woken either way.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let this = self.project();
        let result = match this.file.poll_shutdown(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };
        match &result {
            Ok(()) => {
                // Per the `AsyncWrite` contract, a successful shutdown implies the
                // inner writer flushed, so all written bytes are committed now. This
                // holds even if our own `poll_flush` was never called. Asserting
                // committed == written here would therefore be wrong.
                if let WriteState::Pending(_, written) = this.sentinel.state.load() {
                    this.sentinel.state.store(WriteState::Completed(written));
                }
            }
            Err(_) => this.sentinel.state.store(WriteState::Failed),
        }
        // Readers waiting for completion or failure must observe the new state.
        this.sentinel.wake_readers();
        Poll::Ready(result)
    }

    /// Vectored variant of `poll_write`, with the same state tracking.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, Error>> {
        let this = self.project();
        let poll = this.file.poll_write_vectored(cx, bufs);
        Self::handle_poll_write_result(&this.sentinel, poll)
    }

    /// Reports whether the inner file has an efficient vectored-write implementation.
    fn is_write_vectored(&self) -> bool {
        self.file.is_write_vectored()
    }
}