1#[cfg(not(loom))]
2use std::sync::atomic::Ordering;
3use std::{
4 marker::PhantomData,
5 mem::{self, needs_drop},
6 ops::Range,
7};
8
9use async_local::LocalRef;
10#[cfg(loom)]
11use loom::sync::atomic::Ordering;
12#[cfg(not(loom))]
13use tokio::task::{spawn_blocking, JoinHandle};
14
15use crate::{
16 queue::{Inner, TaskQueue, INDEX_SHIFT, PHASE},
17 task::TaskRef,
18 BufferCell,
19};
20
21pub struct PendingAssignment<'a, T: TaskQueue, const N: usize> {
23 base_slot: usize,
24 queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
25}
26
27impl<'a, T, const N: usize> PendingAssignment<'a, T, N>
28where
29 T: TaskQueue,
30{
31 pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
32 PendingAssignment { base_slot, queue }
33 }
34
35 #[inline(always)]
36 fn set_assignment_bounds(&self) -> Range<usize> {
37 let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);
38
39 (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
40 }
41
42 pub fn into_assignment(self) -> TaskAssignment<'a, T, N> {
47 let task_range = self.set_assignment_bounds();
48 let queue = self.queue;
49
50 mem::forget(self);
51
52 TaskAssignment::new(task_range, queue)
53 }
54
55 #[cfg(not(loom))]
57 pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
58 where
59 F: for<'b> FnOnce(PendingAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
60 {
61 let batch: PendingAssignment<'_, T, N> = unsafe { std::mem::transmute(self) };
62 tokio::task::spawn_blocking(move || f(batch)).await.unwrap()
63 }
64}
65
66unsafe impl<T, const N: usize> Send for PendingAssignment<'_, T, N> where T: TaskQueue {}
67unsafe impl<T, const N: usize> Sync for PendingAssignment<'_, T, N> where T: TaskQueue {}
68
69impl<T, const N: usize> Drop for PendingAssignment<'_, T, N>
70where
71 T: TaskQueue,
72{
73 fn drop(&mut self) {
74 let task_range = self.set_assignment_bounds();
75 let queue = self.queue;
76
77 TaskAssignment::new(task_range, queue);
78 }
79}
80
81pub struct TaskAssignment<'a, T: TaskQueue, const N: usize> {
83 task_range: Range<usize>,
84 queue: LocalRef<'a, Inner<TaskRef<T>, N>>,
85}
86
87impl<'a, T, const N: usize> TaskAssignment<'a, T, N>
88where
89 T: TaskQueue,
90{
91 fn new(task_range: Range<usize>, queue: LocalRef<'a, Inner<TaskRef<T>, N>>) -> Self {
92 TaskAssignment { task_range, queue }
93 }
94
95 pub fn as_slices(&self) -> (&[TaskRef<T>], &[TaskRef<T>]) {
97 let start = self.task_range.start & (N - 1);
98 let end = self.task_range.end & (N - 1);
99
100 if end > start {
101 unsafe { (self.queue.buffer.get_unchecked(start..end), &[]) }
102 } else {
103 unsafe {
104 (
105 self.queue.buffer.get_unchecked(start..N),
106 self.queue.buffer.get_unchecked(0..end),
107 )
108 }
109 }
110 }
111
112 pub fn tasks(&self) -> impl Iterator<Item = &TaskRef<T>> {
114 let tasks = self.as_slices();
115 tasks.0.iter().chain(tasks.1.iter())
116 }
117
118 pub fn resolve_with_iter<I>(self, iter: I) -> CompletionReceipt<T>
120 where
121 I: IntoIterator<Item = T::Value>,
122 {
123 self.tasks().zip(iter).for_each(|(task_ref, value)| unsafe {
124 if needs_drop::<T::Task>() {
125 drop(task_ref.take_task_unchecked());
126 }
127
128 task_ref.resolve_unchecked(value);
129 });
130
131 self.into_completion_receipt()
132 }
133
134 pub fn map<F>(self, op: F) -> CompletionReceipt<T>
136 where
137 F: Fn(T::Task) -> T::Value + Sync,
138 {
139 self.tasks().for_each(|task_ref| unsafe {
140 let task = task_ref.take_task_unchecked();
141 task_ref.resolve_unchecked(op(task));
142 });
143
144 self.into_completion_receipt()
145 }
146
147 #[inline(always)]
148 fn deoccupy_buffer(&self) {
149 self.queue.deoccupy_region(self.task_range.start & (N - 1));
150 }
151
152 fn into_completion_receipt(self) -> CompletionReceipt<T> {
153 self.deoccupy_buffer();
154
155 mem::forget(self);
156
157 CompletionReceipt::new()
158 }
159
160 #[cfg(not(loom))]
162 pub async fn with_blocking<F>(self, f: F) -> CompletionReceipt<T>
163 where
164 F: for<'b> FnOnce(TaskAssignment<'b, T, N>) -> CompletionReceipt<T> + Send + 'static,
165 {
166 let batch: TaskAssignment<'_, T, N> = unsafe { std::mem::transmute(self) };
167 tokio::task::spawn_blocking(move || f(batch)).await.unwrap()
168 }
169}
170
171impl<T, const N: usize> Drop for TaskAssignment<'_, T, N>
172where
173 T: TaskQueue,
174{
175 fn drop(&mut self) {
176 if needs_drop::<T::Task>() {
177 self
178 .tasks()
179 .for_each(|task_ref| unsafe { drop(task_ref.take_task_unchecked()) });
180 }
181
182 self.deoccupy_buffer();
183 }
184}
185
186unsafe impl<T, const N: usize> Send for TaskAssignment<'_, T, N> where T: TaskQueue {}
187unsafe impl<T, const N: usize> Sync for TaskAssignment<'_, T, N> where T: TaskQueue {}
188
189pub struct CompletionReceipt<T: TaskQueue>(PhantomData<T>);
191
192impl<T> CompletionReceipt<T>
193where
194 T: TaskQueue,
195{
196 fn new() -> Self {
197 CompletionReceipt(PhantomData)
198 }
199}
200pub struct UnboundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
203 base_slot: usize,
204 queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
205}
206
207impl<'a, T, const N: usize> UnboundedRange<'a, T, N>
208where
209 T: Send + Sync + Sized + 'static,
210{
211 pub(crate) fn new(base_slot: usize, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
212 UnboundedRange { base_slot, queue }
213 }
214
215 #[inline(always)]
216 fn set_bounds(&self) -> Range<usize> {
217 let end_slot = self.queue.slot.fetch_xor(PHASE, Ordering::Relaxed);
218 (self.base_slot >> INDEX_SHIFT)..(end_slot >> INDEX_SHIFT)
219 }
220
221 pub fn into_bounded(self) -> BoundedRange<'a, T, N> {
223 let range = self.set_bounds();
224 let queue = self.queue;
225
226 mem::forget(self);
227
228 BoundedRange::new(range, queue)
229 }
230
231 #[cfg(not(loom))]
233 pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
234 where
235 F: for<'b> FnOnce(UnboundedRange<'b, T, N>) -> R + Send + 'static,
236 R: Send + 'static,
237 {
238 let batch: UnboundedRange<'_, T, N> = unsafe { std::mem::transmute(self) };
239 spawn_blocking(move || f(batch))
240 }
241}
242
243impl<T, const N: usize> Drop for UnboundedRange<'_, T, N>
244where
245 T: Send + Sync + Sized + 'static,
246{
247 fn drop(&mut self) {
248 let task_range = self.set_bounds();
249 let start_index = task_range.start & (N - 1);
250
251 let queue = self.queue;
252
253 if needs_drop::<T>() {
254 for index in task_range {
255 unsafe {
256 queue.with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
257 }
258 }
259 }
260
261 self.queue.deoccupy_region(start_index);
262 }
263}
264
265unsafe impl<T, const N: usize> Send for UnboundedRange<'_, T, N> where
266 T: Send + Sync + Sized + 'static
267{
268}
269unsafe impl<T, const N: usize> Sync for UnboundedRange<'_, T, N> where
270 T: Send + Sync + Sized + 'static
271{
272}
273
274pub struct BoundedRange<'a, T: Send + Sync + Sized + 'static, const N: usize> {
277 range: Range<usize>,
278 queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
279}
280
281impl<'a, T, const N: usize> BoundedRange<'a, T, N>
282where
283 T: Send + Sync + Sized + 'static,
284{
285 fn new(range: Range<usize>, queue: LocalRef<'a, Inner<BufferCell<T>, N>>) -> Self {
286 BoundedRange { range, queue }
287 }
288
289 #[cfg(not(loom))]
291 pub fn as_slices(&self) -> (&[T], &[T]) {
292 let start = self.range.start & (N - 1);
293 let end = self.range.end & (N - 1);
294
295 if end > start {
296 unsafe {
297 mem::transmute::<(&[BufferCell<T>], &[BufferCell<T>]), _>((
298 self.queue.buffer.get_unchecked(start..end),
299 &[],
300 ))
301 }
302 } else {
303 unsafe {
304 mem::transmute((
305 self.queue.buffer.get_unchecked(start..N),
306 self.queue.buffer.get_unchecked(0..end),
307 ))
308 }
309 }
310 }
311
312 #[cfg(not(loom))]
314 pub fn iter(&self) -> impl Iterator<Item = &T> {
315 let tasks = self.as_slices();
316 tasks.0.iter().chain(tasks.1.iter())
317 }
318
319 #[cfg(not(loom))]
320 pub fn to_vec(self) -> Vec<T> {
321 let items = self.as_slices();
322 let front_len = items.0.len();
323 let back_len = items.1.len();
324 let total_len = front_len + back_len;
325 let mut buffer = Vec::new();
326 buffer.reserve_exact(total_len);
327
328 unsafe {
329 std::ptr::copy_nonoverlapping(items.0.as_ptr(), buffer.as_mut_ptr(), front_len);
330 if back_len > 0 {
331 std::ptr::copy_nonoverlapping(
332 items.1.as_ptr(),
333 buffer.as_mut_ptr().add(front_len),
334 back_len,
335 );
336 }
337 buffer.set_len(total_len);
338 }
339
340 self.deoccupy_buffer();
341
342 mem::forget(self);
343
344 buffer
345 }
346
347 #[cfg(not(loom))]
349 pub fn with_blocking<F, R>(self, f: F) -> JoinHandle<R>
350 where
351 F: for<'b> FnOnce(BoundedRange<'b, T, N>) -> R + Send + 'static,
352 R: Send + 'static,
353 {
354 let batch: BoundedRange<'_, T, N> = unsafe { std::mem::transmute(self) };
355 batch.queue.with_blocking(move |_| f(batch))
356 }
357
358 #[inline(always)]
359 fn deoccupy_buffer(&self) {
360 self.queue.deoccupy_region(self.range.start & (N - 1));
361 }
362}
363
364impl<T, const N: usize> Drop for BoundedRange<'_, T, N>
365where
366 T: Send + Sync + Sized + 'static,
367{
368 fn drop(&mut self) {
369 if needs_drop::<T>() {
370 for index in self.range.clone() {
371 unsafe {
372 self
373 .queue
374 .with_buffer_cell(|cell| (*cell).assume_init_drop(), index & (N - 1));
375 }
376 }
377 }
378
379 self.deoccupy_buffer();
380 }
381}
382
383unsafe impl<T, const N: usize> Send for BoundedRange<'_, T, N> where T: Send + Sync + Sized + 'static
384{}
385unsafe impl<T, const N: usize> Sync for BoundedRange<'_, T, N> where T: Send + Sync + Sized + 'static
386{}
387
388pub struct BufferIter<'a, T: Send + Sync + Sized + 'static, const N: usize> {
390 current: usize,
391 range: Range<usize>,
392 queue: LocalRef<'a, Inner<BufferCell<T>, N>>,
393}
394
395impl<T, const N: usize> BufferIter<'_, T, N>
396where
397 T: Send + Sync + Sized + 'static,
398{
399 fn deoccupy_buffer(&self) {
400 self.queue.deoccupy_region(self.range.start & (N - 1));
401 }
402}
403
404impl<T, const N: usize> Iterator for BufferIter<'_, T, N>
405where
406 T: Send + Sync + Sized + 'static,
407{
408 type Item = T;
409
410 fn next(&mut self) -> Option<Self::Item> {
411 if self.current < self.range.end {
412 let task = unsafe {
413 self
414 .queue
415 .with_buffer_cell(|cell| (*cell).assume_init_read(), self.current & (N - 1))
416 };
417
418 self.current += 1;
419
420 Some(task)
421 } else {
422 None
423 }
424 }
425}
426
427unsafe impl<T, const N: usize> Send for BufferIter<'_, T, N> where T: Send + Sync + Sized + 'static {}
428unsafe impl<T, const N: usize> Sync for BufferIter<'_, T, N> where T: Send + Sync + Sized + 'static {}
429
430impl<T, const N: usize> Drop for BufferIter<'_, T, N>
431where
432 T: Send + Sync + Sized + 'static,
433{
434 fn drop(&mut self) {
435 if needs_drop::<T>() {
436 while self.next().is_some() {}
437 }
438 self.deoccupy_buffer();
439 }
440}
441
442impl<'a, T, const N: usize> IntoIterator for BoundedRange<'a, T, N>
443where
444 T: Send + Sync + Sized + 'static,
445{
446 type Item = T;
447 type IntoIter = BufferIter<'a, T, N>;
448
449 fn into_iter(self) -> Self::IntoIter {
450 let iter = BufferIter {
451 current: self.range.start,
452 range: self.range.clone(),
453 queue: self.queue,
454 };
455
456 mem::forget(self);
457
458 iter
459 }
460}