1#![cfg(feature = "tokio")]
2use std::fmt::{Debug, Formatter};
3use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
5use std::time::{Duration, Instant};
6
7use crate::encode::{Value, VoltError};
8use crate::node::{ConnInfo, NodeOpt};
9use crate::procedure_invocation::new_procedure_invocation;
10use crate::protocol::{PING_HANDLE, build_auth_message, parse_auth_response};
11use crate::response::VoltResponseInfo;
12use crate::table::{VoltTable, new_volt_table};
13use crate::volt_param;
14use byteorder::{BigEndian, ByteOrder};
15use bytes::{Buf, BytesMut};
16use dashmap::DashMap;
17use tokio::io::{AsyncReadExt, AsyncWriteExt, ReadHalf, WriteHalf};
18use tokio::net::TcpStream;
19use tokio::sync::{mpsc, watch};
20use tokio::time::timeout;
21
// Hard cap on a single incoming response frame; larger length prefixes are
// treated as corrupt (see `async_job`).
const MAX_MESSAGE_SIZE: usize = 50 * 1024 * 1024;
// Capacity of the mpsc channel feeding the background writer task.
const WRITE_BUFFER_SIZE: usize = 1024;
// The writer coalesces queued payloads until the batch reaches this many
// bytes before issuing a socket write.
const BATCH_WRITE_THRESHOLD: usize = 8192;
// Default per-call timeout; also used (doubled) as the reaper cutoff in
// `spawn_timeout_checker`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
#[allow(dead_code)]
// NOTE(review): declared but never referenced — presumably intended for a
// keepalive/ping loop; confirm before removing.
const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(60);

#[allow(dead_code)]
// Commands consumed by the background writer task (see `spawn_writer`).
enum WriteCommand {
    // Raw serialized bytes to append to the current write batch.
    Data(Vec<u8>),
    // Push any batched bytes to the socket and flush it.
    Flush,
}
36
#[allow(dead_code)]
// Bookkeeping for one in-flight procedure invocation, keyed by `handle` in
// `AsyncNode::requests`.
struct AsyncNetworkRequest {
    // Client-assigned handle echoed back by the server to match responses.
    handle: i64,
    // True for procedure invocations (always set in `call_sp_with_timeout`).
    query: bool,
    // Marked synchronous by the only constructor site; not read elsewhere here.
    sync: bool,
    // Serialized size of the invocation (`proc.slen`); not read elsewhere here.
    num_bytes: i32,
    // Where the parsed response table is delivered; dropping this wakes the
    // waiting receiver with `None`.
    channel: mpsc::Sender<VoltTable>,
    // Creation time, used by the timeout reaper and the Debug impl.
    created_at: Instant,
}
47
48impl Debug for AsyncNetworkRequest {
49 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("AsyncNetworkRequest")
51 .field("handle", &self.handle)
52 .field("query", &self.query)
53 .field("age_ms", &self.created_at.elapsed().as_millis())
54 .finish()
55 }
56}
57
/// Asynchronous connection to a single VoltDB node.
///
/// All socket I/O is owned by background tasks spawned in [`AsyncNode::new`];
/// this handle only enqueues writes and tracks in-flight requests.
pub struct AsyncNode {
    // Queue feeding the background writer task.
    write_tx: mpsc::Sender<WriteCommand>,
    // Connection metadata parsed from the server's auth response.
    info: ConnInfo,
    // In-flight requests keyed by client handle; the reader task removes
    // entries as responses arrive, the reaper removes stale ones.
    requests: Arc<DashMap<i64, AsyncNetworkRequest>>,
    // Broadcasts `true` to ask all background tasks to stop.
    stop: Arc<watch::Sender<bool>>,
    // Monotonic source of client handles (see `get_sequence`).
    counter: Arc<AtomicI64>,
    // Count of requests currently awaiting a response.
    pending_requests: Arc<AtomicUsize>,
}
73
74impl Debug for AsyncNode {
75 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("AsyncNode")
77 .field(
78 "pending_requests",
79 &self.pending_requests.load(Ordering::Relaxed),
80 )
81 .field("total_requests", &self.requests.len())
82 .finish()
83 }
84}
85
impl AsyncNode {
    /// Connects to the node described by `opt`, performs the authentication
    /// handshake directly on the new socket, then splits the stream and
    /// spawns the writer, reader, and timeout-reaper background tasks.
    ///
    /// # Errors
    /// Propagates socket errors, auth-message construction errors, and
    /// auth-response parse errors.
    pub async fn new(opt: NodeOpt) -> Result<AsyncNode, VoltError> {
        let addr = format!("{}:{}", opt.ip_port.ip_host, opt.ip_port.port);

        let auth_msg = build_auth_message(opt.user.as_deref(), opt.pass.as_deref())?;

        let mut stream = TcpStream::connect(&addr).await?;

        // Small latency-sensitive messages: disable Nagle batching.
        stream.set_nodelay(true)?;
        stream.write_all(&auth_msg).await?;
        stream.flush().await?;

        // The auth response is framed by a 4-byte big-endian length prefix.
        let mut len_buf = [0u8; 4];
        stream.read_exact(&mut len_buf).await?;
        let read = BigEndian::read_u32(&len_buf) as usize;

        // NOTE(review): unlike the steady-state path in `async_job`, this
        // length is not bounded by MAX_MESSAGE_SIZE — confirm that is intended.
        let mut all = vec![0; read];
        stream.read_exact(&mut all).await?;

        let info = parse_auth_response(&all)?;

        // Split so reads and writes proceed from independent tasks.
        let (read_half, write_half) = tokio::io::split(stream);

        let requests = Arc::new(DashMap::new());
        let (stop_tx, stop_rx) = watch::channel(false);
        let (write_tx, write_rx) = mpsc::channel(WRITE_BUFFER_SIZE);

        let node = AsyncNode {
            stop: Arc::new(stop_tx),
            write_tx,
            info,
            requests: requests.clone(),
            counter: Arc::new(AtomicI64::new(1)),
            pending_requests: Arc::new(AtomicUsize::new(0)),
        };

        node.spawn_writer(write_half, write_rx, stop_rx.clone());
        node.spawn_reader(read_half, stop_rx.clone());
        node.spawn_timeout_checker(stop_rx);

        Ok(node)
    }

    /// Returns the next client handle. Relaxed ordering suffices: only
    /// uniqueness matters, not ordering relative to other memory operations.
    #[inline]
    pub fn get_sequence(&self) -> i64 {
        self.counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Number of requests currently awaiting a response.
    #[inline]
    pub fn pending_count(&self) -> usize {
        self.pending_requests.load(Ordering::Relaxed)
    }

    /// Connection metadata captured from the auth handshake.
    pub fn conn_info(&self) -> &ConnInfo {
        &self.info
    }

    /// Lists stored procedures via the `@SystemCatalog` system procedure.
    pub async fn list_procedures(&self) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@SystemCatalog", volt_param!("PROCEDURES"))
            .await
    }

    /// Invokes stored procedure `query` with `param` using the default
    /// timeout; returns a receiver that yields the single response table.
    pub async fn call_sp(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp_with_timeout(query, param, DEFAULT_TIMEOUT)
            .await
    }

    /// Serializes and enqueues one procedure invocation, registering it in
    /// the request map so the reader can route the response.
    ///
    /// NOTE(review): `_timeout_duration` is currently ignored — stale
    /// requests are only reaped by the background checker at a fixed
    /// `DEFAULT_TIMEOUT * 2` cutoff. Confirm whether per-call timeouts were
    /// intended to be honored here.
    pub async fn call_sp_with_timeout(
        &self,
        query: &str,
        param: Vec<&dyn Value>,
        _timeout_duration: Duration,
    ) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let req = self.get_sequence();
        let mut proc = new_procedure_invocation(req, false, &param, query);

        // Capacity 1 is enough: each invocation yields exactly one response.
        let (tx, rx) = mpsc::channel(1);

        let seq = AsyncNetworkRequest {
            query: true,
            handle: req,
            num_bytes: proc.slen,
            sync: true,
            channel: tx,
            created_at: Instant::now(),
        };

        // Register before sending so the reader can never observe a response
        // for a handle it does not know about.
        self.requests.insert(req, seq);
        self.pending_requests.fetch_add(1, Ordering::Relaxed);

        let bs = proc.bytes();
        // NOTE(review): if this send fails, the entry registered above is not
        // removed here; the timeout reaper eventually cleans it up.
        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;

        Ok(rx)
    }

    /// Uploads a JAR of stored-procedure classes via `@UpdateClasses`.
    pub async fn upload_jar(&self, bs: Vec<u8>) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        self.call_sp("@UpdateClasses", volt_param!(bs, "")).await
    }

    /// Runs an ad-hoc SQL statement through the `@AdHoc` system procedure.
    pub async fn query(&self, sql: &str) -> Result<mpsc::Receiver<VoltTable>, VoltError> {
        let mut zero_vec: Vec<&dyn Value> = Vec::new();
        zero_vec.push(&sql);
        self.call_sp("@AdHoc", zero_vec).await
    }

    /// Sends a fire-and-forget `@Ping`. The reader discards the matching
    /// response (see the PING_HANDLE check in `async_job`), so `Ok` here only
    /// means the bytes were queued, not that the server answered.
    pub async fn ping(&self) -> Result<(), VoltError> {
        let zero_vec: Vec<&dyn Value> = Vec::new();
        let mut proc = new_procedure_invocation(PING_HANDLE, false, &zero_vec, "@Ping");
        let bs = proc.bytes();

        self.write_tx
            .send(WriteCommand::Data(bs))
            .await
            .map_err(|_| VoltError::connection_closed())?;

        Ok(())
    }

    /// Signals all background tasks to stop. Pending requests are cleaned up
    /// by the reader task on its way out (see `cleanup_requests`).
    pub async fn shutdown(&self) -> Result<(), VoltError> {
        let _ = self.stop.send(true);
        Ok(())
    }

    /// Spawns the task that owns the write half of the socket. Payloads are
    /// coalesced up to BATCH_WRITE_THRESHOLD bytes before each socket write
    /// to reduce the number of syscalls.
    fn spawn_writer(
        &self,
        mut write_half: WriteHalf<TcpStream>,
        mut write_rx: mpsc::Receiver<WriteCommand>,
        mut stop_rx: watch::Receiver<bool>,
    ) {
        tokio::spawn(async move {
            let mut batch_buffer = Vec::with_capacity(BATCH_WRITE_THRESHOLD * 2);

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    cmd = write_rx.recv() => {
                        match cmd {
                            Some(WriteCommand::Data(bytes)) => {
                                batch_buffer.extend_from_slice(&bytes);

                                // Opportunistically drain whatever is already
                                // queued (non-blocking) to grow the batch.
                                while batch_buffer.len() < BATCH_WRITE_THRESHOLD {
                                    match write_rx.try_recv() {
                                        Ok(WriteCommand::Data(more_bytes)) => {
                                            batch_buffer.extend_from_slice(&more_bytes);
                                        }
                                        Ok(WriteCommand::Flush) => break,
                                        // Empty or closed: stop batching.
                                        Err(_) => break,
                                    }
                                }

                                if let Err(e) = write_half.write_all(&batch_buffer).await {
                                    eprintln!("Write error: {}", e);
                                    break;
                                }
                                batch_buffer.clear();
                            }
                            Some(WriteCommand::Flush) => {
                                if !batch_buffer.is_empty() {
                                    if let Err(e) = write_half.write_all(&batch_buffer).await {
                                        eprintln!("Flush error: {}", e);
                                        break;
                                    }
                                    batch_buffer.clear();
                                }
                                let _ = write_half.flush().await;
                            }
                            // Channel closed: every sender was dropped.
                            None => break,
                        }
                    }
                }
            }

            eprintln!("Writer task terminated");
        });
    }

    /// Spawns the task that owns the read half of the socket, looping on
    /// `async_job` until shutdown or a connection error, then failing any
    /// still-pending requests.
    fn spawn_reader(&self, mut read_half: ReadHalf<TcpStream>, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let reason = loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break "shutdown requested";
                        }
                    }
                    result = Self::async_job(&mut read_half, &requests, &pending_requests) => {
                        if let Err(e) = result {
                            // Suppress the error log during deliberate shutdown.
                            if !*stop_rx.borrow() {
                                eprintln!("Read error: {}", e);
                            }
                            break "connection error";
                        }
                    }
                }
            };

            Self::cleanup_requests(&requests, &pending_requests, reason).await;
        });
    }

    /// Spawns the reaper that, every 5 seconds, drops requests older than
    /// twice the default timeout.
    fn spawn_timeout_checker(&self, mut stop_rx: watch::Receiver<bool>) {
        let requests = Arc::clone(&self.requests);
        let pending_requests = Arc::clone(&self.pending_requests);

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(5));

            loop {
                tokio::select! {
                    _ = stop_rx.changed() => {
                        if *stop_rx.borrow() {
                            break;
                        }
                    }
                    _ = interval.tick() => {
                        let now = Instant::now();
                        let mut expired = Vec::new();

                        // Collect keys first, remove after: removing while an
                        // iterator holds a DashMap shard lock can deadlock.
                        for entry in requests.iter() {
                            let age = now.duration_since(entry.created_at);
                            if age > DEFAULT_TIMEOUT * 2 {
                                expired.push(*entry.key());
                            }
                        }

                        for handle in expired {
                            if let Some((_, req)) = requests.remove(&handle) {
                                pending_requests.fetch_sub(1, Ordering::Relaxed);
                                eprintln!("Request {} timed out after {:?}", handle,
                                    now.duration_since(req.created_at));
                            }
                        }
                    }
                }
            }
        });
    }

    /// Reads and dispatches exactly one response frame from the socket.
    ///
    /// # Errors
    /// Fails on socket errors or when the length prefix exceeds
    /// MAX_MESSAGE_SIZE; the caller (`spawn_reader`) treats any error as
    /// fatal for the connection.
    async fn async_job(
        tcp: &mut ReadHalf<TcpStream>,
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
    ) -> Result<(), VoltError> {
        // Frames are prefixed with a 4-byte big-endian length.
        let mut len_buf = [0u8; 4];
        tcp.read_exact(&mut len_buf).await?;
        let msg_len = BigEndian::read_u32(&len_buf) as usize;

        // Guard against corrupt/hostile length prefixes before allocating.
        if msg_len > MAX_MESSAGE_SIZE {
            return Err(VoltError::MessageTooLarge(msg_len));
        }

        if msg_len == 0 {
            return Ok(());
        }

        let mut buf = BytesMut::with_capacity(msg_len);
        buf.resize(msg_len, 0);
        tcp.read_exact(&mut buf).await?;

        // The leading byte is discarded — presumably the protocol version;
        // confirm against the wire spec. The i64 that follows is the client
        // handle assigned in `call_sp_with_timeout`.
        let _ = buf.get_u8();
        let handle = buf.get_i64();

        // Ping replies carry nothing anyone is waiting on.
        if handle == PING_HANDLE {
            return Ok(());
        }

        if let Some((_, req)) = requests.remove(&handle) {
            pending_requests.fetch_sub(1, Ordering::Relaxed);

            // Freeze into an immutable buffer that can move into the task.
            let frozen_buf = buf.freeze();

            // Parse off the read loop so a slow parse cannot stall the socket.
            tokio::spawn(async move {
                match Self::parse_response(frozen_buf, handle) {
                    Ok(table) => {
                        // The receiver may already be gone; ignore the error.
                        let _ = req.channel.send(table).await;
                    }
                    Err(e) => {
                        eprintln!("Parse error for handle {}: {}", handle, e);
                    }
                }
            });
        } else {
            // Response arrived for a handle we no longer track (e.g. reaped).
            eprintln!("Received response for unknown handle: {}", handle);
        }

        Ok(())
    }

    /// Decodes a response body (already positioned past the leading byte and
    /// handle) into a `VoltTable`.
    fn parse_response(buf: bytes::Bytes, handle: i64) -> Result<VoltTable, VoltError> {
        let mut byte_buf = bytebuffer::ByteBuffer::from_bytes(&buf[..]);
        let info = VoltResponseInfo::new(&mut byte_buf, handle)?;
        let table = new_volt_table(&mut byte_buf, info)?;
        Ok(table)
    }

    /// Drops every tracked request and resets the pending counter. Dropping
    /// a request drops its channel sender, which wakes any waiting receiver
    /// with `None`.
    async fn cleanup_requests(
        requests: &Arc<DashMap<i64, AsyncNetworkRequest>>,
        pending_requests: &Arc<AtomicUsize>,
        reason: &str,
    ) {
        let pending_count = requests.len();

        if pending_count > 0 {
            eprintln!("Cleaning up {} pending requests: {}", pending_count, reason);
        }

        requests.clear();
        pending_requests.store(0, Ordering::Relaxed);
    }
}
461
462pub async fn async_block_for_result(
464 rx: &mut mpsc::Receiver<VoltTable>,
465) -> Result<VoltTable, VoltError> {
466 match rx.recv().await {
467 Some(table) => Ok(table), None => Err(VoltError::ConnectionNotAvailable),
469 }
470}
471
472pub async fn async_block_for_result_with_timeout(
474 rx: &mut mpsc::Receiver<VoltTable>,
475 timeout_duration: Duration,
476) -> Result<VoltTable, VoltError> {
477 match timeout(timeout_duration, rx.recv()).await {
478 Ok(Some(mut table)) => match table.has_error() {
479 None => Ok(table),
480 Some(err) => Err(err),
481 },
482 Ok(None) => Err(VoltError::ConnectionNotAvailable),
483 Err(_) => Err(VoltError::Timeout),
484 }
485}
486
487impl VoltError {
489 pub fn message_too_large(size: usize) -> Self {
490 VoltError::MessageTooLarge(size)
491 }
492
493 pub fn connection_closed() -> Self {
494 VoltError::ConnectionClosed
495 }
496
497 pub fn timeout() -> Self {
498 VoltError::Timeout
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `AsyncNode` with no live connection behind it,
    /// suitable only for exercising pure bookkeeping accessors.
    fn bare_node(counter_start: i64, pending: usize) -> AsyncNode {
        AsyncNode {
            write_tx: mpsc::channel(1).0,
            info: ConnInfo::default(),
            requests: Arc::new(DashMap::new()),
            stop: Arc::new(watch::channel(false).0),
            counter: Arc::new(AtomicI64::new(counter_start)),
            pending_requests: Arc::new(AtomicUsize::new(pending)),
        }
    }

    #[tokio::test]
    async fn test_sequence_generation() {
        let node = bare_node(1, 0);
        let first = node.get_sequence();
        let second = node.get_sequence();
        assert_eq!(second, first + 1);
    }

    #[tokio::test]
    async fn test_pending_count() {
        let node = bare_node(1, 5);
        assert_eq!(node.pending_count(), 5);
    }
}