tokenizers/utils/
parallelism.rs

//!
//! This module defines helpers that let iteration run either in parallel (through Rayon) or serially.
//!
4
5use rayon::iter::IterBridge;
6use rayon::prelude::*;
7use rayon_cond::CondIterator;
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::AtomicU8;
10use std::sync::atomic::Ordering;
11
12// Re-export rayon current_num_threads
13pub use rayon::current_num_threads;
14
/// Name of the environment variable that switches parallel iteration on or off.
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

/// Flag raised (never lowered in this module) when a helper records that it took the
/// parallel path; queried by `has_parallelism_been_used`.
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
/// Process-local override written by `set_parallelism`:
/// `0` = no override, `1` = forced serial, `2` = forced parallel.
static PARALLELISM: AtomicU8 = AtomicU8::new(0);
19
20/// Check if the TOKENIZERS_PARALLELISM env variable has been explicitly set
21pub fn is_parallelism_configured() -> bool {
22    std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some()
23}
24
25/// Check if at some point we used a parallel iterator
26pub fn has_parallelism_been_used() -> bool {
27    USED_PARALLELISM.load(Ordering::SeqCst)
28}
29
30/// Get internally set parallelism
31fn get_override_parallelism() -> Option<bool> {
32    match PARALLELISM.load(Ordering::SeqCst) {
33        0 => None,
34        1 => Some(false),
35        2 => Some(true),
36        _ => unreachable!(),
37    }
38}
39
40/// Get the currently set value for `TOKENIZERS_PARALLELISM` env variable
41fn get_env_parallelism() -> bool {
42    match std::env::var(ENV_VARIABLE) {
43        Ok(mut v) => {
44            v.make_ascii_lowercase();
45            !matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0")
46        }
47        Err(_) => true, // If we couldn't get the variable, we use the default
48    }
49}
50
51pub fn get_parallelism() -> bool {
52    if let Some(parallel) = get_override_parallelism() {
53        parallel
54    } else {
55        get_env_parallelism()
56    }
57}
58
59/// Set the value for `TOKENIZERS_PARALLELISM` for the current process
60pub fn set_parallelism(val: bool) {
61    PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst);
62}
63
64/// Allows to convert into an iterator that can be executed either parallelly or serially.
65///
66/// The choice is made according to the currently set `TOKENIZERS_PARALLELISM` environment variable.
67/// This variable can have one of the following values
68///   - False => "" (empty value), "false", "f", "off", "no", "n", "0"
69///   - True => Any other value
70///
71pub trait MaybeParallelIterator<P, S>
72where
73    P: ParallelIterator,
74    S: Iterator<Item = P::Item>,
75{
76    /// Convert ourself in a CondIterator, that will be executed either in parallel or serially,
77    /// based solely on the `TOKENIZERS_PARALLELISM` environment variable
78    fn into_maybe_par_iter(self) -> CondIterator<P, S>;
79    /// Convert ourself in a CondIterator, that will be executed either in parallel or serially,
80    /// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool.
81    /// Both must be true to run with parallelism activated.
82    fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
83}
84
85impl<P, S, I> MaybeParallelIterator<P, S> for I
86where
87    I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>,
88    P: ParallelIterator,
89    S: Iterator<Item = P::Item>,
90{
91    fn into_maybe_par_iter(self) -> CondIterator<P, S> {
92        let parallelism = get_parallelism();
93        if parallelism {
94            USED_PARALLELISM.store(true, Ordering::SeqCst);
95        }
96        CondIterator::new(self, parallelism)
97    }
98
99    fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
100        if cond {
101            self.into_maybe_par_iter()
102        } else {
103            CondIterator::from_serial(self)
104        }
105    }
106}
107
108/// Shared reference version of MaybeParallelIterator, works the same but returns an iterator
109/// over references, does not consume self
110pub trait MaybeParallelRefIterator<'data, P, S>
111where
112    P: ParallelIterator,
113    S: Iterator<Item = P::Item>,
114    P::Item: 'data,
115{
116    fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
117    fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
118}
119
120impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
121where
122    &'data I: MaybeParallelIterator<P, S>,
123    P: ParallelIterator,
124    S: Iterator<Item = P::Item>,
125    P::Item: 'data,
126{
127    fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
128        self.into_maybe_par_iter()
129    }
130
131    fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
132        self.into_maybe_par_iter_cond(cond)
133    }
134}
135
136/// Exclusive reference version of MaybeParallelIterator, works the same but returns an iterator
137/// over mutable references, does not consume self
138pub trait MaybeParallelRefMutIterator<'data, P, S>
139where
140    P: ParallelIterator,
141    S: Iterator<Item = P::Item>,
142    P::Item: 'data,
143{
144    fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
145    fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
146}
147
148impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
149where
150    &'data mut I: MaybeParallelIterator<P, S>,
151    P: ParallelIterator,
152    S: Iterator<Item = P::Item>,
153    P::Item: 'data,
154{
155    fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
156        self.into_maybe_par_iter()
157    }
158
159    fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
160        self.into_maybe_par_iter_cond(cond)
161    }
162}
163
164/// Converts any serial iterator into a CondIterator, that can either run parallelly or serially.
165pub trait MaybeParallelBridge<T, S>
166where
167    S: Iterator<Item = T> + Send,
168    T: Send,
169{
170    fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
171    fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
172}
173
174impl<T, S> MaybeParallelBridge<T, S> for S
175where
176    S: Iterator<Item = T> + Send,
177    T: Send,
178{
179    fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
180        let iter = CondIterator::from_serial(self);
181
182        if get_parallelism() {
183            USED_PARALLELISM.store(true, Ordering::SeqCst);
184            CondIterator::from_parallel(iter.into_parallel().right().unwrap())
185        } else {
186            iter
187        }
188    }
189
190    fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
191        if cond {
192            self.maybe_par_bridge()
193        } else {
194            CondIterator::from_serial(self)
195        }
196    }
197}
198
199/// Allows to convert into `chunks` that can be executed either parallelly or serially.
200pub trait MaybeParallelSlice<'data, T>
201where
202    T: Sync,
203{
204    /// Create a CondIterator, that will be executed either in parallel or serially,
205    /// based solely on the `TOKENIZERS_PARALLELISM` environment variable
206    fn maybe_par_chunks(
207        &'_ self,
208        chunk_size: usize,
209    ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
210    /// Create a CondIterator, that will be executed either in parallel or serially,
211    /// based on both the `TOKENIZERS_PARALLELISM` environment variable and the provided bool.
212    /// Both must be true to run with parallelism activated.
213    fn maybe_par_chunks_cond(
214        &'_ self,
215        cond: bool,
216        chunk_size: usize,
217    ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
218}
219
220impl<T> MaybeParallelSlice<'_, T> for [T]
221where
222    T: Sync,
223{
224    fn maybe_par_chunks(
225        &'_ self,
226        chunk_size: usize,
227    ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
228        let parallelism = get_parallelism();
229        if parallelism {
230            CondIterator::from_parallel(self.par_chunks(chunk_size))
231        } else {
232            CondIterator::from_serial(self.chunks(chunk_size))
233        }
234    }
235    fn maybe_par_chunks_cond(
236        &'_ self,
237        cond: bool,
238        chunk_size: usize,
239    ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
240        if cond {
241            self.maybe_par_chunks(chunk_size)
242        } else {
243            CondIterator::from_serial(self.chunks(chunk_size))
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_maybe_parallel_iterator() {
        let mut values = vec![1u32, 2, 3, 4, 5, 6];

        let initial_sum: u32 = values.maybe_par_iter().sum();
        assert_eq!(initial_sum, 21);

        let doubled_sum: u32 = values
            .maybe_par_iter_mut()
            .map(|x| {
                *x *= 2;
                *x
            })
            .sum();
        assert_eq!(doubled_sum, 42);

        // The in-place doubling must be visible through both the borrowing and consuming paths.
        assert_eq!(values.maybe_par_iter().sum::<u32>(), 42);
        assert_eq!(values.into_maybe_par_iter().sum::<u32>(), 42);
    }

    #[test]
    fn test_maybe_parallel_slice() {
        let data = [1, 2, 3, 4, 5];
        let expected: Vec<&[i32]> = vec![&[1, 2][..], &[3, 4][..], &[5][..]];

        let chunks: Vec<_> = data.maybe_par_chunks(2).collect();
        assert_eq!(chunks, expected);
    }
}