#[cfg(not(loom))]
use std::sync::atomic::Ordering;
use std::{marker::PhantomData, mem, ops::Range};
use async_local::LocalRef;
#[cfg(loom)]
use loom::sync::atomic::Ordering;
#[cfg(not(loom))]
use tokio::task::{spawn_blocking, JoinHandle};
use crate::{
  queue::{Inner, TaskQueue, INDEX_SHIFT, PHASE},
  task::TaskRef,
  BufferCell,
};
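/// An exclusive claim over a not-yet-bounded region of the queue's task buffer.
///
/// The upper bound of the batch stays open until [`into_assignment`](Self::into_assignment)
/// is called or the value is dropped, allowing concurrently enqueued tasks to
/// join the batch in the meantime. A sketch of a typical consumer; `MyQueue`,
/// `respond`, and the `batch_process` hook are illustrative, not part of this
/// module:
///
/// ```ignore
/// async fn batch_process<const N: usize>(
///   batch: PendingAssignment<'_, MyQueue, N>,
/// ) -> CompletionReceipt<MyQueue> {
///   batch.into_assignment().map(respond)
/// }
/// ```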
pub struct PendingAssignment<'a, T: TaskQueue, const N: usize> {
  base_slot: usize,
  queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
}
impl<'a, T, const N: usize> PendingAssignment<'a, T, N>
where
  T: TaskQueue,
{
  pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
    PendingAssignment { base_slot, queue }
  }
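  /// Closes the batch by flipping the `PHASE` bit on the queue's slot counter;
  /// the previous counter value, shifted by `INDEX_SHIFT` to drop the tag bits,
  /// becomes the exclusive upper index of the assignment.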
  fn set_assignment_bounds(&self) -> Range<usize> {
    let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);
    (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
  }
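  /// Fixes the bounds of the batch and converts it into a [`TaskAssignment`];
  /// `mem::forget` skips the `Drop` impl, which would otherwise close and
  /// discard the batch.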
  pub fn into_assignment(self) -> TaskAssignment<'a, T, N> {
    let task_range = self.set_assignment_bounds();
    let queue = self.queue;
    mem::forget(self);
    TaskAssignment::new(task_range, queue)
  }
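  /// Moves the batch onto a blocking thread via `spawn_blocking` and awaits the
  /// resulting [`CompletionReceipt`], for batch processing that would otherwise
  /// block the async executor.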
  #[cfg(not(loom))]
  pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
  where
    F: for<'b> FnOnce(PendingAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
  {
    // Erase the lifetime so the value satisfies `spawn_blocking`'s `'static`
    // bound; this relies on `async_local`'s guarantee that the referenced
    // thread-local outlives every runtime thread.
    let batch: PendingAssignment<'_, T, N> = unsafe { mem::transmute(self) };
    spawn_blocking(move || f(batch)).await.unwrap()
  }
}
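// Safety: this handle has exclusive access to its claimed buffer region, all
// other shared state in `Inner` is synchronized with atomics, and `async_local`
// permits a `LocalRef` to be moved to and used from other threads.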
unsafe impl<'a, T, const N: usize> Send for PendingAssignment<'a, T, N> where T: TaskQueue {}
unsafe impl<'a, T, const N: usize> Sync for PendingAssignment<'a, T, N> where T: TaskQueue {}
impl<'a, T, const N: usize> Drop for PendingAssignment<'a, T, N>
where
  T: TaskQueue,
{
  fn drop(&mut self) {
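    // Close the batch and let the temporary `TaskAssignment` drop immediately,
    // releasing the queued tasks and the claimed buffer region.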
    let task_range = self.set_assignment_bounds();
    let queue = self.queue;
    TaskAssignment::new(task_range, queue);
  }
}
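/// A batch of tasks with fixed bounds and exclusive access to its range of the
/// queue's ring buffer.
///
/// Consume it with [`map`](Self::map) or
/// [`resolve_with_iter`](Self::resolve_with_iter) to resolve every task and
/// obtain a [`CompletionReceipt`]; dropping it instead discards the tasks
/// without resolving them.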
pub struct TaskAssignment<'a, T: TaskQueue, const N: usize> {
  task_range: Range<usize>,
  queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
}
impl<'a, T, const N: usize> TaskAssignment<'a, T, N>
where
  T: TaskQueue,
{
  fn new(task_range: Range<usize>, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
    TaskAssignment { task_range, queue }
  }
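  /// Borrows the assigned tasks as a contiguous run plus the wrapped-around
  /// remainder of the ring buffer; the index masks assume `N` is a power of two.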
  pub fn as_slices(&self) -> (&[TaskRef<T>], &[TaskRef<T>]) {
    let start = self.task_range.start & (N - 1);
    let end = self.task_range.end & (N - 1);
    if end > start {
      unsafe { (self.queue.buffer.get_unchecked(start..end), &[]) }
    } else {
      unsafe {
        (
          self.queue.buffer.get_unchecked(start..N),
          self.queue.buffer.get_unchecked(0..end),
        )
      }
    }
  }
  pub fn tasks(&self) -> impl Iterator<Item = &TaskRef<T>> {
    let tasks = self.as_slices();
    tasks.0.iter().chain(tasks.1.iter())
  }
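  /// Resolves tasks in order against the values yielded by `iter`. `zip` stops
  /// at the shorter side, so callers must supply at least one value per task;
  /// any excess tasks are neither resolved nor dropped. A sketch, where
  /// `compute` stands in for caller-provided logic:
  ///
  /// ```ignore
  /// let assignment = batch.into_assignment();
  /// let values: Vec<_> = assignment.tasks().map(compute).collect();
  /// assignment.resolve_with_iter(values);
  /// ```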
  pub fn resolve_with_iter<I>(self, iter: I) -> CompletionReceipt<T>
  where
    I: IntoIterator<Item = T::Value>,
  {
    self.tasks().zip(iter).for_each(|(task_ref, value)| unsafe {
      drop(task_ref.take_task_unchecked());
      task_ref.resolve_unchecked(value);
    });
    self.into_completion_receipt()
  }
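  /// Takes each task out of the buffer, applies `op`, and resolves the task
  /// with the result.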
  pub fn map<F>(self, op: F) -> CompletionReceipt<T>
  where
    F: Fn(T::Task) -> T::Value + Sync,
  {
    self.tasks().for_each(|task_ref| unsafe {
      let task = task_ref.take_task_unchecked();
      task_ref.resolve_unchecked(op(task));
    });
    self.into_completion_receipt()
  }
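  /// Returns the claimed region to the queue so producers can reuse it.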
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.task_range.start & (N - 1));
  }
  fn into_completion_receipt(self) -> CompletionReceipt<T> {
    self.deoccupy_buffer();
    mem::forget(self);
    CompletionReceipt::new()
  }
  #[cfg(not(loom))]
  pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
  where
    F: for<'b> FnOnce(TaskAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
  {
    // As in `PendingAssignment::with_blocking`, erase the lifetime to satisfy
    // the `'static` bound on `spawn_blocking`.
    let batch: TaskAssignment<'_, T, N> = unsafe { mem::transmute(self) };
    spawn_blocking(move || f(batch)).await.unwrap()
  }
}
impl<'a, T, const N: usize> Drop for TaskAssignment<'a, T, N>
where
  T: TaskQueue,
{
  fn drop(&mut self) {
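    // An unconsumed assignment drops its tasks unresolved before releasing the
    // buffer region.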
    self
      .tasks()
      .for_each(|task_ref| unsafe { drop(task_ref.take_task_unchecked()) });
    self.deoccupy_buffer();
  }
}
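// Safety: same reasoning as the `Send`/`Sync` impls for `PendingAssignment`.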
unsafe impl<'a, T, const N: usize> Send for TaskAssignment<'a, T, N> where T: TaskQueue {}
unsafe impl<'a, T, const N: usize> Sync for TaskAssignment<'a, T, N> where T: TaskQueue {}
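/// A zero-sized marker produced once every task in an assignment has been
/// resolved; requiring one from batch-processing code makes it hard to forget
/// to resolve a batch.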
pub struct CompletionReceipt<T: TaskQueue>(PhantomData<T>);
impl<T> CompletionReceipt<T>
where
  T: TaskQueue,
{
  fn new() -> Self {
    CompletionReceipt(PhantomData)
  }
}
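/// The value-oriented counterpart of [`PendingAssignment`]: an exclusive claim
/// over buffered values whose upper bound is not yet fixed. Convert it with
/// [`into_bounded`](Self::into_bounded) to read or take the values.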
pub struct UnboundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  base_slot: usize,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}
impl<'a, T, const N: usize> UnboundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
    UnboundedRange { base_slot, queue }
  }
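  /// Closes the batch exactly as `PendingAssignment::set_assignment_bounds`
  /// does: flip `PHASE`, then derive the index range from the previous slot value.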
  fn set_bounds(&self) -> Range<usize> {
    let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);
    (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
  }
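  /// Fixes the upper bound and converts the claim into a [`BoundedRange`],
  /// skipping the `Drop` impl that would discard the values.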
  pub fn into_bounded(self) -> BoundedRange<'a, T, N> {
    let range = self.set_bounds();
    let queue = self.queue;
    mem::forget(self);
    BoundedRange::new(range, queue)
  }
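  /// Hands the range to a blocking thread and returns the `JoinHandle` rather
  /// than awaiting it. A sketch:
  ///
  /// ```ignore
  /// let handle = range.with_blocking(|range| range.into_bounded().to_vec());
  /// let values = handle.await.unwrap();
  /// ```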
  #[cfg(not(loom))]
  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
  where
    F: for<'b> FnOnce(UnboundedRange<'b, T, N>) -> R + Send + 'static,
    R: Send + 'static,
  {
    // Lifetime erasure, as in the `with_blocking` methods above.
    let batch: UnboundedRange<'_, T, N> = unsafe { mem::transmute(self) };
    spawn_blocking(move || f(batch))
  }
}
impl<'a, T, const N: usize> Drop for UnboundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
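    // Close the batch, drop every buffered value in place, then release the
    // region.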
    let task_range = self.set_bounds();
    let start_index = task_range.start & (N - 1);
    let queue = self.queue;
    for index in task_range {
      unsafe {
        queue.with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
      }
    }
    self.queue.deoccupy_region(start_index);
  }
}
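// Safety: see the `Send`/`Sync` impls on `PendingAssignment`.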
unsafe impl<'a, T, const N: usize> Send for UnboundedRange<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
unsafe impl<'a, T, const N: usize> Sync for UnboundedRange<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
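/// A fixed, exclusively held range of buffered values. Borrow them with
/// [`as_slices`](Self::as_slices) or [`iter`](Self::iter), move them out with
/// [`to_vec`](Self::to_vec) or by-value iteration, or drop the range to drop
/// the values in place.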
pub struct BoundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  range: Range<usize>,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}
impl<'a, T, const N: usize> BoundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn new(range: Range<usize>, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
    BoundedRange { range, queue }
  }
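  /// Borrows the values as a contiguous run plus the ring wrap-around. The
  /// transmutes assume `BufferCell<T>` is layout-compatible with `T` and that
  /// every cell in the claimed range is initialized.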
  #[cfg(not(loom))]
  pub fn as_slices(&self) -> (&[T], &[T]) {
    let start = self.range.start & (N - 1);
    let end = self.range.end & (N - 1);
    if end > start {
      unsafe {
        mem::transmute::<(&[BufferCell<T>], &[BufferCell<T>]), _>((
          self.queue.buffer.get_unchecked(start..end),
          &[],
        ))
      }
    } else {
      unsafe {
        mem::transmute((
          self.queue.buffer.get_unchecked(start..N),
          self.queue.buffer.get_unchecked(0..end),
        ))
      }
    }
  }
  #[cfg(not(loom))]
  pub fn iter(&self) -> impl Iterator<Item = &T> {
    let tasks = self.as_slices();
    tasks.0.iter().chain(tasks.1.iter())
  }
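  /// Moves every value into a `Vec` by bitwise copy. The region is then
  /// released without running destructors on the cells (`mem::forget` skips
  /// `Drop`), since ownership of the values has moved to the `Vec`. A sketch;
  /// `pop_batch` stands in for whatever API yields an [`UnboundedRange`]:
  ///
  /// ```ignore
  /// let values: Vec<u64> = pop_batch().into_bounded().to_vec();
  /// ```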
  #[cfg(not(loom))]
  pub fn to_vec(self) -> Vec<T> {
    let items = self.as_slices();
    let front_len = items.0.len();
    let back_len = items.1.len();
    let total_len = front_len + back_len;
    let mut buffer = Vec::with_capacity(total_len);
    unsafe {
      std::ptr::copy_nonoverlapping(items.0.as_ptr(), buffer.as_mut_ptr(), front_len);
      if back_len > 0 {
        std::ptr::copy_nonoverlapping(
          items.1.as_ptr(),
          buffer.as_mut_ptr().add(front_len),
          back_len,
        );
      }
      buffer.set_len(total_len);
    }
    self.deoccupy_buffer();
    mem::forget(self);
    buffer
  }
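  /// Runs `f` with this range on a blocking thread, dispatching through the
  /// queue's own `with_blocking` rather than calling `spawn_blocking` directly.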
  #[cfg(not(loom))]
  pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
  where
    F: for<'b> FnOnce(BoundedRange<'b, T, N>) -> R + Send + 'static,
    R: Send + 'static,
  {
    let batch: BoundedRange<'_, T, N> = unsafe { mem::transmute(self) };
    batch.queue.with_blocking(move |_| f(batch))
  }
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.range.start & (N - 1));
  }
}
impl<'a, T, const N: usize> Drop for BoundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
    for index in self.range.clone() {
      unsafe {
        self
          .queue
          .with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
      }
    }
    self.deoccupy_buffer();
  }
}
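// Safety: see the `Send`/`Sync` impls on `UnboundedRange`.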
unsafe impl<'a, T, const N: usize> Send for BoundedRange<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
unsafe impl<'a, T, const N: usize> Sync for BoundedRange<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
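/// An owning iterator over a claimed range: each call to `next` reads one value
/// out of its buffer cell. On drop, any values not yet yielded are drained and
/// dropped before the region is released.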
pub struct BufferIter<'a, T: Send + Sync + Sized + 'static, const N: usize> {
  current: usize,
  range: Range<usize>,
  queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
}
impl<'a, T, const N: usize> BufferIter<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn deoccupy_buffer(&self) {
    self.queue.deoccupy_region(self.range.start & (N - 1));
  }
}
impl<'a, T, const N: usize> Iterator for BufferIter<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  type Item = T;
  fn next(&mut self) -> Option<Self::Item> {
    if self.current < self.range.end {
      let task = unsafe {
        self
          .queue
          .with_buffer_cell(|cell| (*cell).assume_init_read(), self.current & (N - 1))
      };
      self.current += 1;
      Some(task)
    } else {
      None
    }
  }
}
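// Safety: the iterator owns its claimed region; see the impls on `BoundedRange`.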
unsafe impl<'a, T, const N: usize> Send for BufferIter<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
unsafe impl<'a, T, const N: usize> Sync for BufferIter<'a, T, N> where
  T: Send + Sync + Sized + 'static
{
}
impl<'a, T, const N: usize> Drop for BufferIter<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  fn drop(&mut self) {
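    // Drain (and drop) any values not yet yielded before releasing the region.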
    while self.next().is_some() {}
    self.deoccupy_buffer();
  }
}
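// By-value iteration transfers ownership of the claimed region to `BufferIter`;
// `mem::forget` prevents `BoundedRange::drop` from also dropping the values.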
impl<'a, T, const N: usize> IntoIterator for BoundedRange<'a, T, N>
where
  T: Send + Sync + Sized + 'static,
{
  type Item = T;
  type IntoIter = BufferIter<'a, T, N>;
  fn into_iter(self) -> Self::IntoIter {
    let iter = BufferIter {
      current: self.range.start,
      range: self.range.clone(),
      queue: self.queue,
    };
    mem::forget(self);
    iter
  }
}