titan_rust_client/
queue.rs1use std::collections::VecDeque;
4use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
5use std::sync::Arc;
6
7use titan_api_types::ws::v1::{RequestData, StreamDataPayload, SwapQuoteRequest, SwapQuotes};
8use tokio::sync::{mpsc, oneshot, Mutex, Notify};
9
10use crate::connection::Connection;
11use crate::error::TitanClientError;
12use crate::stream::QuoteStream;
13
14struct QueuedRequest {
16 request: SwapQuoteRequest,
17 result_tx: oneshot::Sender<Result<QuoteStream, TitanClientError>>,
18}
19
20pub struct StreamManager {
22 max_concurrent: AtomicU32,
23 active_count: AtomicU32,
24 queue: Mutex<VecDeque<QueuedRequest>>,
25 connection: Arc<Connection>,
26 slot_available: Notify,
27 queue_worker_active: AtomicBool,
28}
29
30impl StreamManager {
31 pub fn new(connection: Arc<Connection>, max_concurrent: u32) -> Arc<Self> {
33 Arc::new(Self {
34 max_concurrent: AtomicU32::new(max_concurrent),
35 active_count: AtomicU32::new(0),
36 queue: Mutex::new(VecDeque::new()),
37 connection,
38 slot_available: Notify::new(),
39 queue_worker_active: AtomicBool::new(false),
40 })
41 }
42
43 pub fn set_max_concurrent(&self, max: u32) {
45 self.max_concurrent.store(max, Ordering::SeqCst);
46 self.slot_available.notify_waiters();
48 }
49
50 pub fn active_count(&self) -> u32 {
52 self.active_count.load(Ordering::SeqCst)
53 }
54
55 pub async fn queue_len(&self) -> usize {
57 self.queue.lock().await.len()
58 }
59
60 #[tracing::instrument(skip_all)]
62 pub async fn request_stream(
63 self: &Arc<Self>,
64 request: SwapQuoteRequest,
65 ) -> Result<QuoteStream, TitanClientError> {
66 let max = self.max_concurrent.load(Ordering::SeqCst);
68 let current = self.active_count.load(Ordering::SeqCst);
69
70 if current < max {
71 if self
73 .active_count
74 .compare_exchange(current, current + 1, Ordering::SeqCst, Ordering::SeqCst)
75 .is_ok()
76 {
77 return self.start_stream_internal(request).await;
78 }
79 }
80
81 let (result_tx, result_rx) = oneshot::channel();
83 {
84 let mut queue = self.queue.lock().await;
85 queue.push_back(QueuedRequest { request, result_tx });
86 }
87
88 if self
90 .queue_worker_active
91 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
92 .is_ok()
93 {
94 let manager = self.clone();
95 tokio::spawn(async move {
96 manager.process_queue().await;
97 });
98 }
99
100 result_rx.await.map_err(|_| {
102 TitanClientError::Unexpected(anyhow::anyhow!("Stream request cancelled"))
103 })?
104 }
105
106 pub fn stream_ended(&self) {
108 self.active_count.fetch_sub(1, Ordering::SeqCst);
109 self.slot_available.notify_one();
110 }
111
112 async fn process_queue(self: &Arc<Self>) {
114 loop {
115 let max = self.max_concurrent.load(Ordering::SeqCst);
116 let current = self.active_count.load(Ordering::SeqCst);
117
118 if current >= max {
119 self.slot_available.notified().await;
121 continue;
122 }
123
124 let queued = {
126 let mut queue = self.queue.lock().await;
127 queue.pop_front()
128 };
129
130 let Some(queued) = queued else {
131 self.queue_worker_active.store(false, Ordering::SeqCst);
133 let has_work = !self.queue.lock().await.is_empty();
134 if has_work
135 && self
136 .queue_worker_active
137 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
138 .is_ok()
139 {
140 continue;
141 }
142 break;
143 };
144
145 if self
147 .active_count
148 .compare_exchange(current, current + 1, Ordering::SeqCst, Ordering::SeqCst)
149 .is_err()
150 {
151 let mut queue = self.queue.lock().await;
153 queue.push_front(queued);
154 continue;
155 }
156
157 let result = self.start_stream_internal(queued.request).await;
159 let _ = queued.result_tx.send(result);
160 }
161 }
162
163 async fn start_stream_internal(
165 self: &Arc<Self>,
166 request: SwapQuoteRequest,
167 ) -> Result<QuoteStream, TitanClientError> {
168 let slot_released = Arc::new(AtomicBool::new(false));
170
171 let response = self
172 .connection
173 .send_request(RequestData::NewSwapQuoteStream(request.clone()))
174 .await
175 .inspect_err(|_| {
176 if slot_released
178 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
179 .is_ok()
180 {
181 self.active_count.fetch_sub(1, Ordering::SeqCst);
182 self.slot_available.notify_one();
183 }
184 })?;
185
186 let stream_id = response
187 .stream
188 .ok_or_else(|| {
189 if slot_released
190 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
191 .is_ok()
192 {
193 self.active_count.fetch_sub(1, Ordering::SeqCst);
194 self.slot_available.notify_one();
195 }
196 TitanClientError::Unexpected(anyhow::anyhow!(
197 "NewSwapQuoteStream response missing stream info"
198 ))
199 })?
200 .id;
201
202 let effective_stream_id = Arc::new(AtomicU32::new(stream_id));
204 let stopped_flag = Arc::new(AtomicBool::new(false));
205
206 let on_end_slot_released = slot_released.clone();
208 let on_end_manager: Arc<Self> = self.clone();
209 let on_end_stopped = stopped_flag.clone();
210 let on_end: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
211 on_end_stopped.store(true, Ordering::SeqCst);
212 if on_end_slot_released
213 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
214 .is_ok()
215 {
216 on_end_manager.stream_ended();
217 }
218 });
219
220 let (raw_tx, mut raw_rx) = mpsc::channel::<titan_api_types::ws::v1::StreamData>(32);
222 let (quotes_tx, quotes_rx) = mpsc::channel::<SwapQuotes>(32);
223
224 self.connection
226 .register_stream(
227 stream_id,
228 request,
229 raw_tx,
230 Some(on_end),
231 Some(effective_stream_id.clone()),
232 stopped_flag.clone(),
233 )
234 .await;
235
236 let adapter_connection = self.connection.clone();
238 let adapter_effective_id = effective_stream_id.clone();
239 tokio::spawn(async move {
240 while let Some(data) = raw_rx.recv().await {
241 match data.payload {
242 StreamDataPayload::SwapQuotes(quotes) => {
243 if quotes_tx.send(quotes).await.is_err() {
244 let eid = adapter_effective_id.load(Ordering::SeqCst);
245 adapter_connection.unregister_stream(eid).await;
246 break;
247 }
248 }
249 StreamDataPayload::Other(_) => {
250 tracing::warn!("Received unexpected stream data payload type");
251 }
252 }
253 }
254 });
255
256 Ok(QuoteStream::new_managed(
257 stream_id,
258 quotes_rx,
259 self.connection.clone(),
260 Some(self.clone()),
261 effective_stream_id,
262 stopped_flag,
263 slot_released,
264 ))
265 }
266}