1use std::{ops::Deref, sync::Arc};
2
3use concurrent_queue::{ConcurrentQueue, PushError};
4use onetime::{channel, RecvError, Sender};
5use thiserror::Error;
6
7#[derive(Debug)]
8struct Waiters<T>(ConcurrentQueue<Sender<T>>);
9
10impl<T> Default for Waiters<T> {
11 fn default() -> Self {
12 Self(ConcurrentQueue::unbounded())
13 }
14}
15
16impl<T> Deref for Waiters<T> {
17 type Target = ConcurrentQueue<Sender<T>>;
18
19 fn deref(&self) -> &Self::Target {
20 &self.0
21 }
22}
23
24#[derive(Debug)]
26pub struct Customer<T> {
27 waiters: Arc<Waiters<T>>,
28}
29
30impl<T> Clone for Customer<T> {
31 fn clone(&self) -> Self {
32 Self { waiters: self.waiters.clone() }
33 }
34}
35
36impl<T> Customer<T> {
37 pub async fn request(&self) -> Result<T, RequestError> {
42 let (tx, rx) = channel();
43 self.waiters.push(tx)?;
44 rx.recv().await.map_err(Into::into)
45 }
46}
47
48#[derive(Debug)]
51pub struct Vendor<T> {
52 waiters: Arc<Waiters<T>>,
53}
54
55impl<T> Clone for Vendor<T> {
56 fn clone(&self) -> Self {
57 Self { waiters: self.waiters.clone() }
58 }
59}
60
61impl<T> Default for Vendor<T> {
62 fn default() -> Self {
63 Self { waiters: Arc::default() }
64 }
65}
66
67impl<T> Vendor<T> {
68 pub fn new() -> Self {
69 Self::default()
70 }
71
72 pub fn customer(&self) -> Customer<T> {
74 Customer { waiters: self.waiters.clone() }
75 }
76
77 pub fn send(&self, resource: T)
81 where
82 T: Clone,
83 {
84 if self.waiters_count() == 1 {
85 if let Ok(waiter) = self.waiters.pop() {
86 let _ = waiter.send(resource);
87 }
88 } else {
89 for _ in 0..self.waiters_count() - 1 {
90 if let Ok(waiter) = self.waiters.pop() {
91 let _ = waiter.send(resource.clone());
92 }
93 }
94
95 if let Ok(waiter) = self.waiters.pop() {
96 let _ = waiter.send(resource);
97 }
98 }
99 }
100
101 pub fn waiters_count(&self) -> usize {
102 self.waiters.len()
103 }
104
105 pub fn has_waiters(&self) -> bool {
108 self.waiters_count() > 0
109 }
110}
111
112#[derive(Debug, Error)]
113#[error("failed to request resource")]
114pub enum RequestError {
115 Push,
116 Recv,
117}
118
119impl<T> From<PushError<T>> for RequestError {
120 fn from(_: PushError<T>) -> Self {
121 Self::Push
122 }
123}
124
125impl From<RecvError> for RequestError {
126 fn from(_: RecvError) -> Self {
127 Self::Recv
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;

    /// Two customers request a resource and a single `send` satisfies both.
    #[test]
    fn it_works() {
        smol::block_on(async move {
            let vendor = Vendor::new();
            let customer1 = vendor.customer();
            let customer2 = vendor.customer();

            let t1 = smol::spawn(async move {
                assert!(matches!(customer1.request().await, Ok("ok")));
            });

            let t2 = smol::spawn(async move {
                assert!(matches!(customer2.request().await, Ok("ok")));
            });

            let t3 = smol::spawn(async move {
                // The three tasks are scheduled independently, so the sender
                // may run first. Wait until both requests are queued;
                // otherwise `send` would find fewer than two waiters and a
                // customer would wait forever.
                while vendor.waiters_count() < 2 {
                    smol::future::yield_now().await;
                }
                vendor.send("ok");
            });

            t1.await;
            t2.await;
            t3.await;
        });
    }
}