1use rayon::iter::IterBridge;
6use rayon::prelude::*;
7use rayon_cond::CondIterator;
8use std::sync::atomic::AtomicBool;
9use std::sync::atomic::AtomicU8;
10use std::sync::atomic::Ordering;
11
12pub use rayon::current_num_threads;
14
/// Name of the environment variable read by this module to decide whether
/// parallel iteration is enabled when no programmatic override is set.
pub const ENV_VARIABLE: &str = "TOKENIZERS_PARALLELISM";

/// Flag set to `true` by `into_maybe_par_iter` and `maybe_par_bridge` when
/// they take the parallel path. Nothing in the visible code resets it.
static USED_PARALLELISM: AtomicBool = AtomicBool::new(false);
/// Programmatic override, encoded as `0` = no override, `1` = parallelism
/// forced off, `2` = parallelism forced on (see `get_override_parallelism`
/// and `set_parallelism`).
static PARALLELISM: AtomicU8 = AtomicU8::new(0);
19
20pub fn is_parallelism_configured() -> bool {
22 std::env::var(ENV_VARIABLE).is_ok() || get_override_parallelism().is_some()
23}
24
/// Returns the current value of the `USED_PARALLELISM` flag.
///
/// The flag is set by `into_maybe_par_iter` and `maybe_par_bridge` when they
/// choose the parallel path; it is read here with `SeqCst` ordering.
pub fn has_parallelism_been_used() -> bool {
    USED_PARALLELISM.load(Ordering::SeqCst)
}
29
30fn get_override_parallelism() -> Option<bool> {
32 match PARALLELISM.load(Ordering::SeqCst) {
33 0 => None,
34 1 => Some(false),
35 2 => Some(true),
36 _ => unreachable!(),
37 }
38}
39
40fn get_env_parallelism() -> bool {
42 match std::env::var(ENV_VARIABLE) {
43 Ok(mut v) => {
44 v.make_ascii_lowercase();
45 !matches!(v.as_ref(), "" | "off" | "false" | "f" | "no" | "n" | "0")
46 }
47 Err(_) => true, }
49}
50
51pub fn get_parallelism() -> bool {
52 if let Some(parallel) = get_override_parallelism() {
53 parallel
54 } else {
55 get_env_parallelism()
56 }
57}
58
59pub fn set_parallelism(val: bool) {
61 PARALLELISM.store(if val { 2 } else { 1 }, Ordering::SeqCst);
62}
63
64pub trait MaybeParallelIterator<P, S>
72where
73 P: ParallelIterator,
74 S: Iterator<Item = P::Item>,
75{
76 fn into_maybe_par_iter(self) -> CondIterator<P, S>;
79 fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S>;
83}
84
85impl<P, S, I> MaybeParallelIterator<P, S> for I
86where
87 I: IntoParallelIterator<Iter = P, Item = P::Item> + IntoIterator<IntoIter = S, Item = S::Item>,
88 P: ParallelIterator,
89 S: Iterator<Item = P::Item>,
90{
91 fn into_maybe_par_iter(self) -> CondIterator<P, S> {
92 let parallelism = get_parallelism();
93 if parallelism {
94 USED_PARALLELISM.store(true, Ordering::SeqCst);
95 }
96 CondIterator::new(self, parallelism)
97 }
98
99 fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator<P, S> {
100 if cond {
101 self.into_maybe_par_iter()
102 } else {
103 CondIterator::from_serial(self)
104 }
105 }
106}
107
108pub trait MaybeParallelRefIterator<'data, P, S>
111where
112 P: ParallelIterator,
113 S: Iterator<Item = P::Item>,
114 P::Item: 'data,
115{
116 fn maybe_par_iter(&'data self) -> CondIterator<P, S>;
117 fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S>;
118}
119
120impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefIterator<'data, P, S> for I
121where
122 &'data I: MaybeParallelIterator<P, S>,
123 P: ParallelIterator,
124 S: Iterator<Item = P::Item>,
125 P::Item: 'data,
126{
127 fn maybe_par_iter(&'data self) -> CondIterator<P, S> {
128 self.into_maybe_par_iter()
129 }
130
131 fn maybe_par_iter_cond(&'data self, cond: bool) -> CondIterator<P, S> {
132 self.into_maybe_par_iter_cond(cond)
133 }
134}
135
136pub trait MaybeParallelRefMutIterator<'data, P, S>
139where
140 P: ParallelIterator,
141 S: Iterator<Item = P::Item>,
142 P::Item: 'data,
143{
144 fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S>;
145 fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S>;
146}
147
148impl<'data, P, S, I: 'data + ?Sized> MaybeParallelRefMutIterator<'data, P, S> for I
149where
150 &'data mut I: MaybeParallelIterator<P, S>,
151 P: ParallelIterator,
152 S: Iterator<Item = P::Item>,
153 P::Item: 'data,
154{
155 fn maybe_par_iter_mut(&'data mut self) -> CondIterator<P, S> {
156 self.into_maybe_par_iter()
157 }
158
159 fn maybe_par_iter_mut_cond(&'data mut self, cond: bool) -> CondIterator<P, S> {
160 self.into_maybe_par_iter_cond(cond)
161 }
162}
163
164pub trait MaybeParallelBridge<T, S>
166where
167 S: Iterator<Item = T> + Send,
168 T: Send,
169{
170 fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S>;
171 fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S>;
172}
173
174impl<T, S> MaybeParallelBridge<T, S> for S
175where
176 S: Iterator<Item = T> + Send,
177 T: Send,
178{
179 fn maybe_par_bridge(self) -> CondIterator<IterBridge<S>, S> {
180 let iter = CondIterator::from_serial(self);
181
182 if get_parallelism() {
183 USED_PARALLELISM.store(true, Ordering::SeqCst);
184 CondIterator::from_parallel(iter.into_parallel().right().unwrap())
185 } else {
186 iter
187 }
188 }
189
190 fn maybe_par_bridge_cond(self, cond: bool) -> CondIterator<IterBridge<S>, S> {
191 if cond {
192 self.maybe_par_bridge()
193 } else {
194 CondIterator::from_serial(self)
195 }
196 }
197}
198
199pub trait MaybeParallelSlice<'data, T>
201where
202 T: Sync,
203{
204 fn maybe_par_chunks(
207 &'_ self,
208 chunk_size: usize,
209 ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
210 fn maybe_par_chunks_cond(
214 &'_ self,
215 cond: bool,
216 chunk_size: usize,
217 ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>>;
218}
219
220impl<T> MaybeParallelSlice<'_, T> for [T]
221where
222 T: Sync,
223{
224 fn maybe_par_chunks(
225 &'_ self,
226 chunk_size: usize,
227 ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
228 let parallelism = get_parallelism();
229 if parallelism {
230 CondIterator::from_parallel(self.par_chunks(chunk_size))
231 } else {
232 CondIterator::from_serial(self.chunks(chunk_size))
233 }
234 }
235 fn maybe_par_chunks_cond(
236 &'_ self,
237 cond: bool,
238 chunk_size: usize,
239 ) -> CondIterator<rayon::slice::Chunks<'_, T>, std::slice::Chunks<'_, T>> {
240 if cond {
241 self.maybe_par_chunks(chunk_size)
242 } else {
243 CondIterator::from_serial(self.chunks(chunk_size))
244 }
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises the by-reference, by-mutable-reference and by-value
    /// adapters on a `Vec`.
    #[test]
    fn test_maybe_parallel_iterator() {
        let mut values = vec![1u32, 2, 3, 4, 5, 6];

        // Shared borrow: the sum of the untouched values.
        let initial_total = values.maybe_par_iter().sum::<u32>();
        assert_eq!(initial_total, 21);

        // Mutable borrow: double every element in place and sum the results.
        let doubled_total = values
            .maybe_par_iter_mut()
            .map(|item| {
                *item *= 2;
                *item
            })
            .sum::<u32>();
        assert_eq!(doubled_total, 42);

        // The in-place mutation must be visible through a fresh shared borrow...
        let total_after_mutation = values.maybe_par_iter().sum::<u32>();
        assert_eq!(total_after_mutation, 42);

        // ...and when the vector is consumed by value.
        let owned_total = values.into_maybe_par_iter().sum::<u32>();
        assert_eq!(owned_total, 42);
    }

    /// Chunking keeps element order and leaves a shorter final chunk.
    #[test]
    fn test_maybe_parallel_slice() {
        let data = [1, 2, 3, 4, 5];

        let pieces: Vec<_> = data.maybe_par_chunks(2).collect();
        let expected = vec![&[1, 2][..], &[3, 4], &[5]];
        assert_eq!(pieces, expected);
    }
}