tantivy/core/
executor.rs

1use std::sync::Arc;
2
3#[cfg(feature = "quickwit")]
4use futures_util::{future::Either, FutureExt};
5
6use crate::TantivyError;
7
8/// Executor makes it possible to run tasks in single thread or
9/// in a thread pool.
10#[derive(Clone)]
11pub enum Executor {
12    /// Single thread variant of an Executor
13    SingleThread,
14    /// Thread pool variant of an Executor
15    ThreadPool(Arc<rayon::ThreadPool>),
16}
17
#[cfg(feature = "quickwit")]
impl From<Arc<rayon::ThreadPool>> for Executor {
    /// Wraps an existing shared rayon pool into the `ThreadPool` variant.
    fn from(pool: Arc<rayon::ThreadPool>) -> Executor {
        Self::ThreadPool(pool)
    }
}
24
25impl Executor {
26    /// Creates an Executor that performs all task in the caller thread.
27    pub fn single_thread() -> Executor {
28        Executor::SingleThread
29    }
30
31    /// Creates an Executor that dispatches the tasks in a thread pool.
32    pub fn multi_thread(num_threads: usize, prefix: &'static str) -> crate::Result<Executor> {
33        let pool = rayon::ThreadPoolBuilder::new()
34            .num_threads(num_threads)
35            .thread_name(move |num| format!("{prefix}{num}"))
36            .build()?;
37        Ok(Executor::ThreadPool(Arc::new(pool)))
38    }
39
40    /// Perform a map in the thread pool.
41    ///
42    /// Regardless of the executor (`SingleThread` or `ThreadPool`), panics in the task
43    /// will propagate to the caller.
44    pub fn map<A, R, F>(&self, f: F, args: impl Iterator<Item = A>) -> crate::Result<Vec<R>>
45    where
46        A: Send,
47        R: Send,
48        F: Sized + Sync + Fn(A) -> crate::Result<R>,
49    {
50        match self {
51            Executor::SingleThread => args.map(f).collect::<crate::Result<_>>(),
52            Executor::ThreadPool(pool) => {
53                let args: Vec<A> = args.collect();
54                let num_fruits = args.len();
55                let fruit_receiver = {
56                    let (fruit_sender, fruit_receiver) = crossbeam_channel::unbounded();
57                    pool.scope(|scope| {
58                        for (idx, arg) in args.into_iter().enumerate() {
59                            // We name references for f and fruit_sender_ref because we do not
60                            // want these two to be moved into the closure.
61                            let f_ref = &f;
62                            let fruit_sender_ref = &fruit_sender;
63                            scope.spawn(move |_| {
64                                let fruit = f_ref(arg);
65                                if let Err(err) = fruit_sender_ref.send((idx, fruit)) {
66                                    error!(
67                                        "Failed to send search task. It probably means all search \
68                                         threads have panicked. {err:?}"
69                                    );
70                                }
71                            });
72                        }
73                    });
74                    fruit_receiver
75                    // This ends the scope of fruit_sender.
76                    // This is important as it makes it possible for the fruit_receiver iteration to
77                    // terminate.
78                };
79                let mut result_placeholders: Vec<Option<R>> =
80                    std::iter::repeat_with(|| None).take(num_fruits).collect();
81                for (pos, fruit_res) in fruit_receiver {
82                    let fruit = fruit_res?;
83                    result_placeholders[pos] = Some(fruit);
84                }
85                let results: Vec<R> = result_placeholders.into_iter().flatten().collect();
86                if results.len() != num_fruits {
87                    return Err(TantivyError::InternalError(
88                        "One of the mapped execution failed.".to_string(),
89                    ));
90                }
91                Ok(results)
92            }
93        }
94    }
95
96    /// Spawn a task on the pool, returning a future completing on task success.
97    ///
98    /// If the task panics, returns `Err(())`.
99    #[cfg(feature = "quickwit")]
100    pub fn spawn_blocking<T: Send + 'static>(
101        &self,
102        cpu_intensive_task: impl FnOnce() -> T + Send + 'static,
103    ) -> impl std::future::Future<Output = Result<T, ()>> {
104        match self {
105            Executor::SingleThread => Either::Left(std::future::ready(Ok(cpu_intensive_task()))),
106            Executor::ThreadPool(pool) => {
107                let (sender, receiver) = oneshot::channel();
108                pool.spawn(|| {
109                    if sender.is_closed() {
110                        return;
111                    }
112                    let task_result = cpu_intensive_task();
113                    let _ = sender.send(task_result);
114                });
115
116                let res = receiver.map(|res| res.map_err(|_| ()));
117                Either::Right(res)
118            }
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::Executor;

    /// A panic inside the mapped closure reaches the caller, message included,
    /// when the executor runs inline.
    #[test]
    #[should_panic(expected = "panic should propagate")]
    fn test_panic_propagates_single_thread() {
        let executor = Executor::single_thread();
        let _result: Vec<usize> = executor
            .map(|_| panic!("panic should propagate"), vec![0].into_iter())
            .unwrap();
    }

    /// A panic inside the mapped closure also reaches the caller when the work is
    /// dispatched on a pool.
    #[test]
    #[should_panic] //< unfortunately the panic message is not propagated
    fn test_panic_propagates_multi_thread() {
        let executor = Executor::multi_thread(1, "search-test").unwrap();
        let _result: Vec<usize> = executor
            .map(|_| panic!("panic should propagate"), vec![0].into_iter())
            .unwrap();
    }

    /// The inline executor returns one output per input, in input order.
    #[test]
    fn test_map_singlethread() {
        let doubled: Vec<usize> = Executor::single_thread()
            .map(|i| Ok(i * 2), 0..1_000)
            .unwrap();
        assert_eq!(doubled.len(), 1_000);
        for (i, value) in doubled.into_iter().enumerate() {
            assert_eq!(value, i * 2);
        }
    }

    /// The pooled executor returns one output per input, in input order.
    #[test]
    fn test_map_multithread() {
        let doubled: Vec<usize> = Executor::multi_thread(3, "search-test")
            .unwrap()
            .map(|i| Ok(i * 2), 0..10)
            .unwrap();
        assert_eq!(doubled.len(), 10);
        for (i, value) in doubled.into_iter().enumerate() {
            assert_eq!(value, i * 2);
        }
    }

    /// Dropping the futures returned by `spawn_blocking` keeps (most of) the
    /// associated tasks from running.
    #[cfg(feature = "quickwit")]
    #[test]
    fn test_cancel_cpu_intensive_tasks() {
        use std::sync::atomic::{AtomicU64, Ordering};
        use std::sync::Arc;

        // Number of tasks of each group that actually started running.
        let kept_started: Arc<AtomicU64> = Default::default();
        let dropped_started: Arc<AtomicU64> = Default::default();

        let mut kept_futures = Vec::new();
        let mut dropped_futures = Vec::new();

        // Zero-capacity channel: each `send` below completes only once a running
        // task receives it, which lets that task finish.
        let (release_tx, release_rx) = crossbeam_channel::bounded::<()>(0);
        let release_rx = Arc::new(release_rx);
        let executor = Executor::multi_thread(3, "search-test").unwrap();
        for _ in 0..1000 {
            let started = kept_started.clone();
            let gate = release_rx.clone();
            kept_futures.push(executor.spawn_blocking(move || {
                started.fetch_add(1, Ordering::SeqCst);
                let _ = gate.recv();
            }));

            let started = dropped_started.clone();
            let gate = release_rx.clone();
            dropped_futures.push(executor.spawn_blocking(move || {
                started.fetch_add(1, Ordering::SeqCst);
                let _ = gate.recv();
            }));
        }

        // We execute 100 futures.
        for _ in 0..100 {
            release_tx.send(()).unwrap();
        }

        let kept_before = kept_started.load(Ordering::SeqCst);
        let dropped_before = dropped_started.load(Ordering::SeqCst);
        assert!(kept_before >= 30);
        assert!(dropped_before >= 30);

        // From here on, tasks of the second group should be skipped.
        drop(dropped_futures);

        // We execute 100 futures.
        for _ in 0..100 {
            release_tx.send(()).unwrap();
        }

        // The slack of 6 leaves room for tasks that were already in flight.
        let kept_after = kept_started.load(Ordering::SeqCst);
        assert!(kept_after >= kept_before + 100 - 6);

        let dropped_after = dropped_started.load(Ordering::SeqCst);
        assert!(dropped_after <= dropped_before + 6);
    }
}