use rayon::{
    iter::plumbing::*,
    prelude::{IndexedParallelIterator, ParallelIterator},
};
use zng_app_context::LocalContext;
pub trait ParallelIteratorExt: ParallelIterator {
    fn with_ctx(self) -> ParallelIteratorWithCtx<Self> {
        ParallelIteratorWithCtx {
            base: self,
            ctx: LocalContext::capture(),
        }
    }
}
impl<I: ParallelIterator> ParallelIteratorExt for I {}
pub struct ParallelIteratorWithCtx<I> {
    base: I,
    ctx: LocalContext,
}
impl<T, I> ParallelIterator for ParallelIteratorWithCtx<I>
where
    T: Send,
    I: ParallelIterator<Item = T>,
{
    type Item = T;
    fn drive_unindexed<C>(mut self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let consumer = ParallelCtxConsumer {
            base: consumer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.drive_unindexed(consumer))
    }
    fn opt_len(&self) -> Option<usize> {
        self.base.opt_len()
    }
}
impl<I: IndexedParallelIterator> IndexedParallelIterator for ParallelIteratorWithCtx<I> {
    fn len(&self) -> usize {
        self.base.len()
    }
    fn drive<C: Consumer<Self::Item>>(mut self, consumer: C) -> C::Result {
        let consumer = ParallelCtxConsumer {
            base: consumer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.drive(consumer))
    }
    fn with_producer<CB: ProducerCallback<Self::Item>>(mut self, callback: CB) -> CB::Output {
        let callback = ParallelCtxProducerCallback {
            base: callback,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.with_producer(callback))
    }
}
struct ParallelCtxConsumer<C> {
    base: C,
    ctx: LocalContext,
}
impl<T, C> Consumer<T> for ParallelCtxConsumer<C>
where
    C: Consumer<T>,
    T: Send,
{
    type Folder = ParallelCtxFolder<C::Folder>;
    type Reducer = ParallelCtxReducer<C::Reducer>;
    type Result = C::Result;
    fn split_at(mut self, index: usize) -> (Self, Self, Self::Reducer) {
        let (left, right, reducer) = self.ctx.with_context(|| self.base.split_at(index));
        let reducer = ParallelCtxReducer {
            base: reducer,
            ctx: self.ctx.clone(),
        };
        let left = Self {
            base: left,
            ctx: self.ctx.clone(),
        };
        let right = Self {
            base: right,
            ctx: self.ctx,
        };
        (left, right, reducer)
    }
    fn into_folder(mut self) -> Self::Folder {
        let base = self.ctx.with_context(|| self.base.into_folder());
        ParallelCtxFolder { base, ctx: self.ctx }
    }
    fn full(&self) -> bool {
        self.base.full()
    }
}
impl<T, C> UnindexedConsumer<T> for ParallelCtxConsumer<C>
where
    C: UnindexedConsumer<T>,
    T: Send,
{
    fn split_off_left(&self) -> Self {
        Self {
            base: self.base.split_off_left(),
            ctx: self.ctx.clone(),
        }
    }
    fn to_reducer(&self) -> Self::Reducer {
        ParallelCtxReducer {
            base: self.base.to_reducer(),
            ctx: self.ctx.clone(),
        }
    }
}
struct ParallelCtxFolder<F> {
    base: F,
    ctx: LocalContext,
}
impl<Item, F> Folder<Item> for ParallelCtxFolder<F>
where
    F: Folder<Item>,
{
    type Result = F::Result;
    fn consume(mut self, item: Item) -> Self {
        let base = self.ctx.with_context(move || self.base.consume(item));
        Self { base, ctx: self.ctx }
    }
    fn complete(mut self) -> Self::Result {
        self.ctx.with_context(|| self.base.complete())
    }
    fn full(&self) -> bool {
        self.base.full()
    }
}
struct ParallelCtxReducer<R> {
    base: R,
    ctx: LocalContext,
}
impl<Result, R> Reducer<Result> for ParallelCtxReducer<R>
where
    R: Reducer<Result>,
{
    fn reduce(mut self, left: Result, right: Result) -> Result {
        self.ctx.with_context(move || self.base.reduce(left, right))
    }
}
struct ParallelCtxProducerCallback<C> {
    base: C,
    ctx: LocalContext,
}
impl<T, C: ProducerCallback<T>> ProducerCallback<T> for ParallelCtxProducerCallback<C> {
    type Output = C::Output;
    fn callback<P>(mut self, producer: P) -> Self::Output
    where
        P: Producer<Item = T>,
    {
        let producer = ParallelCtxProducer {
            base: producer,
            ctx: self.ctx.clone(),
        };
        self.ctx.with_context(move || self.base.callback(producer))
    }
}
struct ParallelCtxProducer<P> {
    base: P,
    ctx: LocalContext,
}
impl<P: Producer> Producer for ParallelCtxProducer<P> {
    type Item = P::Item;
    type IntoIter = P::IntoIter;
    fn into_iter(mut self) -> Self::IntoIter {
        self.ctx.with_context(|| self.base.into_iter())
    }
    fn split_at(mut self, index: usize) -> (Self, Self) {
        let (left, right) = self.ctx.with_context(|| self.base.split_at(index));
        (
            Self {
                base: left,
                ctx: self.ctx.clone(),
            },
            Self {
                base: right,
                ctx: self.ctx,
            },
        )
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        Arc,
    };
    use super::*;
    use rayon::prelude::*;
    use zng_app_context::*;
    context_local! {
        static VALUE: u32 = 0u32;
    }
    #[test]
    fn map_and_sum_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));
        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            (0..1000)
                .into_par_iter()
                .with_ctx()
                .map(|_| {
                    if thread_id != std::thread::current().id() {
                        used_other_thread.store(true, Ordering::Relaxed);
                    }
                    *VALUE.get()
                })
                .sum()
        });
        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }
    #[test]
    fn for_each_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));
        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));
            (0..1000).into_par_iter().with_ctx().for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });
        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }
    #[test]
    fn chain_for_each_with_context() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));
        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));
            let a = (0..500).into_par_iter();
            let b = (0..500).into_par_iter();
            a.chain(b).with_ctx().for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });
        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }
    #[test]
    fn chain_for_each_with_context_inverted() {
        let _app = LocalContext::start_app(AppId::new_unique());
        let thread_id = std::thread::current().id();
        let used_other_thread = Arc::new(AtomicBool::new(false));
        let sum: u32 = VALUE.with_context(&mut Some(Arc::new(1)), || {
            let sum = Arc::new(AtomicU32::new(0));
            let a = (0..500).into_par_iter().with_ctx();
            let b = (0..500).into_par_iter().with_ctx();
            a.chain(b).for_each(|_| {
                if thread_id != std::thread::current().id() {
                    used_other_thread.store(true, Ordering::Relaxed);
                }
                sum.fetch_add(*VALUE.get(), Ordering::Relaxed);
            });
            sum.load(Ordering::Relaxed)
        });
        assert_eq!(sum, 1000);
        assert!(used_other_thread.load(Ordering::Relaxed));
    }
}