titan_client/tcp/tcp_client_blocking.rs
1use std::{
2 io::{BufRead, BufReader, Write},
3 net::{TcpStream, ToSocketAddrs},
4 sync::{
5 atomic::{AtomicBool, Ordering},
6 mpsc, Arc, Mutex,
7 },
8 thread::{self, JoinHandle},
9 time::Duration,
10};
11
12use serde_json;
13use thiserror::Error;
14use titan_types_api::TcpSubscriptionRequest;
15use titan_types_core::Event;
16use tracing::{debug, error, info, warn};
17
18use crate::tcp::reconnection::ReconnectionManager;
19
20use super::{
21 connection_status::{ConnectionStatus, ConnectionStatusTracker},
22 reconnection,
23};
24
25#[derive(Debug, Error)]
26pub enum TcpClientError {
27 #[error("io error: {0}")]
28 IOError(#[from] std::io::Error),
29 #[error("serde error: {0}")]
30 SerdeError(#[from] serde_json::Error),
31 #[error("address parse error: {0}")]
32 AddrParseError(String),
33}
/// Tuning knobs for the blocking TCP client: reconnection backoff, connect timeout,
/// read-buffer sizing and ping/pong keep-alive.
#[derive(Debug, Clone)]
pub struct TcpClientConfig {
    /// Starting delay between reconnection attempts; the backoff grows from here.
    pub base_reconnect_interval: Duration,
    /// Upper bound on the delay between reconnection attempts.
    pub max_reconnect_interval: Duration,
    /// How many reconnection attempts to make before giving up; `None` retries forever.
    pub max_reconnect_attempts: Option<u32>,
    /// Maximum time to wait for a single TCP connect to succeed.
    pub connection_timeout: Duration,
    /// Number of bytes pre-allocated for the line read buffer.
    pub read_buffer_capacity: usize,
    /// Largest message (in bytes) the client is willing to buffer.
    pub max_buffer_size: usize,
    /// How often a `PING` is sent to the server.
    pub ping_interval: Duration,
    /// How long to wait for the matching `PONG` before treating the connection as dead.
    pub pong_timeout: Duration,
}
55
56impl Default for TcpClientConfig {
57 fn default() -> Self {
58 TcpClientConfig {
59 base_reconnect_interval: Duration::from_secs(1),
60 max_reconnect_interval: Duration::from_secs(60),
61 max_reconnect_attempts: None,
62 connection_timeout: Duration::from_secs(30),
63 read_buffer_capacity: 4096, // 4KB initial capacity
64 max_buffer_size: 10 * 1024 * 1024, // 10MB max buffer size
65 ping_interval: Duration::from_secs(30), // Send ping every 30 seconds
66 pong_timeout: Duration::from_secs(10), // Wait 10 seconds for pong response
67 }
68 }
69}
70
/// Synchronous TCP subscription listener with reconnection support.
///
/// Connects to the TCP server at `addr` and sends the given subscription request
/// (encoded as JSON). It then spawns a dedicated thread that reads lines from the TCP
/// connection. If the connection drops or an error occurs, it will attempt to reconnect
/// according to the configuration settings.
///
/// # Thread Management
///
/// This client spawns a background thread to handle the TCP connection and event processing.
/// To ensure proper cleanup, you should call `shutdown_and_join()` when you're done with the
/// client. If you don't call this method, the background thread is still signaled to shut
/// down when the `TcpClient` is dropped. However, the drop does not wait for it, so the
/// thread may keep running briefly after the client is gone.
///
/// ```
/// # use client::tcp_client_blocking::{TcpClient, TcpClientConfig};
/// # fn main() {
/// let client = TcpClient::new(TcpClientConfig::default());
/// // Use the client...
///
/// // When done, ensure clean shutdown
/// client.shutdown_and_join();
/// # }
/// ```
#[cfg(feature = "tcp_client_blocking")]
pub struct TcpClient {
    /// Shared with the worker thread; storing `true` asks the worker to stop.
    shutdown_flag: Arc<AtomicBool>,
    /// Settings cloned into every worker spawned by `subscribe`.
    config: TcpClientConfig,
    /// Records the connection status and hands out status listeners; cloned into the worker.
    status_tracker: ConnectionStatusTracker,
    /// Handle of the current worker thread, if one has been spawned and not yet joined.
    worker_thread: Mutex<Option<JoinHandle<()>>>,
}
103
#[cfg(feature = "tcp_client_blocking")]
impl TcpClient {
    /// Creates a new TCP client with the given configuration.
    ///
    /// The client starts idle. No connection is opened and no worker thread is spawned
    /// until [`TcpClient::subscribe`] is called.
    pub fn new(config: TcpClientConfig) -> Self {
        Self {
            shutdown_flag: Arc::new(AtomicBool::new(false)),
            config,
            status_tracker: ConnectionStatusTracker::new(),
            worker_thread: Mutex::new(None),
        }
    }

    /// Returns the current connection status as recorded by the status tracker.
    pub fn get_status(&self) -> ConnectionStatus {
        self.status_tracker.get_status()
    }

    /// Creates a channel on which connection status updates are delivered.
    ///
    /// The sending half is registered with the client's status tracker through
    /// `register_listener`. Which updates are forwarded (and whether repeats are filtered)
    /// is decided by `ConnectionStatusTracker`.
    pub fn create_status_subscriber(&self) -> mpsc::Receiver<ConnectionStatus> {
        let (tx, rx) = mpsc::channel();
        self.status_tracker.register_listener(tx);
        rx
    }

    /// Returns `true` if the client currently holds a worker thread handle.
    ///
    /// A handle is stored by [`TcpClient::subscribe`]. It is removed by [`TcpClient::join`],
    /// [`TcpClient::shutdown_and_join`] or a later `subscribe`.
    ///
    /// The thread behind a held handle may already have exited on its own, for example
    /// after exhausting its reconnection attempts or after its event receiver was dropped.
    /// Returns `false` if the lock guarding the handle is poisoned.
    pub fn has_active_thread(&self) -> bool {
        match self.worker_thread.lock() {
            Ok(lock) => lock.is_some(),
            Err(_) => {
                error!("Failed to acquire worker thread lock");
                false
            }
        }
    }

    /// Subscribe to events from the given address.
    ///
    /// Spawns a background thread that connects to `addr`, sends `subscription_request`
    /// and forwards every received event to the returned channel. The thread reconnects
    /// according to the client's [`TcpClientConfig`].
    ///
    /// If a worker thread is already held, it is signaled to shut down and joined before
    /// the new one is created. The receiver from the previous call stops receiving events
    /// once that worker exits.
    ///
    /// # Errors
    ///
    /// Returns [`TcpClientError::IOError`] if the worker-thread lock is poisoned. Errors
    /// from starting the new worker are propagated, e.g. [`TcpClientError::AddrParseError`]
    /// when `addr` cannot be resolved.
    pub fn subscribe(
        &self,
        addr: String,
        subscription_request: TcpSubscriptionRequest,
    ) -> Result<mpsc::Receiver<Event>, TcpClientError> {
        // The lock is held for the whole call, including the join below. This serializes
        // concurrent `subscribe` calls, so the flag reset further down can never race with
        // a worker that is still being asked to stop.
        let mut worker_lock = self.worker_thread.lock().map_err(|_| {
            TcpClientError::IOError(std::io::Error::other(
                "Failed to acquire worker thread lock",
            ))
        })?;

        // If there is an active thread, signal it to shut down and wait for it.
        if let Some(handle) = worker_lock.take() {
            info!("Stopping existing subscription thread before re-subscribing...");
            self.shutdown_flag.store(true, Ordering::SeqCst);
            match handle.join() {
                Ok(_) => info!("Successfully joined existing worker thread"),
                Err(e) => error!("Failed to join existing worker thread: {:?}", e),
            }
        }

        // The previous worker (if any) is gone, so the shared flag can be re-armed.
        self.shutdown_flag.store(false, Ordering::SeqCst);

        let (rx, handle) = subscribe(
            addr,
            subscription_request,
            self.shutdown_flag.clone(),
            self.config.clone(),
            self.status_tracker.clone(),
        )?;

        // Keep the handle so a later `join`/`shutdown_and_join` can wait for the thread.
        *worker_lock = Some(handle);

        Ok(rx)
    }

    /// Signals the client to shut down and stop any reconnection attempts.
    ///
    /// Marks the status as `Disconnected` and sets the shared shutdown flag. It does not
    /// wait for the worker thread to finish.
    pub fn shutdown(&self) {
        self.status_tracker
            .update_status(ConnectionStatus::Disconnected);
        self.shutdown_flag.store(true, Ordering::SeqCst);
    }

    /// Signals the client to shut down and waits for the worker thread to complete.
    ///
    /// Returns `true` if a worker thread was joined successfully. Returns `false` if there
    /// was no worker, the worker panicked, or the handle lock is poisoned.
    pub fn shutdown_and_join(&self) -> bool {
        self.shutdown();
        self.join()
    }

    /// Waits for the worker thread to complete, without signaling it to stop.
    ///
    /// Returns `true` if a worker thread was joined successfully. Returns `false` if there
    /// was no worker, the worker panicked, or the handle lock is poisoned. Without a prior
    /// [`TcpClient::shutdown`], this blocks until the worker exits on its own.
    pub fn join(&self) -> bool {
        let mut worker_lock = match self.worker_thread.lock() {
            Ok(lock) => lock,
            Err(e) => {
                error!("Failed to acquire worker thread lock: {}", e);
                return false;
            }
        };

        // The lock is held while joining so a concurrent `subscribe` cannot re-arm the
        // shutdown flag underneath the thread being joined.
        let Some(handle) = worker_lock.take() else {
            // No thread to join.
            return false;
        };

        match handle.join() {
            Ok(_) => {
                info!("Successfully joined worker thread");
                true
            }
            Err(e) => {
                error!("Failed to join worker thread: {:?}", e);
                false
            }
        }
    }
}
241
#[cfg(feature = "tcp_client_blocking")]
impl Drop for TcpClient {
    /// Signals the worker thread to stop without waiting for it.
    ///
    /// The destructor deliberately does not join. The worker may be inside a connect
    /// attempt that is bounded only by `connection_timeout`, and blocking a drop for that
    /// long would be surprising.
    ///
    /// Any still-held `JoinHandle` is dropped here, which detaches the thread. The thread
    /// exits on its own once it observes the shutdown flag. Call `shutdown_and_join`
    /// before dropping to wait for a clean exit.
    fn drop(&mut self) {
        self.shutdown();

        // `&mut self` guarantees exclusive access, so no locking is required. A poisoned
        // mutex is ignored: its handle is dropped (detached) with the mutex anyway.
        if let Ok(slot) = self.worker_thread.get_mut() {
            // Taking and dropping the handle detaches the thread rather than joining it.
            if slot.take().is_some() {
                info!("TcpClient dropped, thread will continue running until shutdown completes");
            }
        }
    }
}
260
261fn subscribe(
262 addr: String,
263 subscription_request: TcpSubscriptionRequest,
264 shutdown_flag: Arc<AtomicBool>,
265 config: TcpClientConfig,
266 status_tracker: ConnectionStatusTracker,
267) -> Result<(mpsc::Receiver<Event>, JoinHandle<()>), TcpClientError> {
268 // Create a standard mpsc channel to forward events.
269 let (tx, rx) = mpsc::channel::<Event>();
270
271 let address = addr
272 .to_socket_addrs()
273 .map_err(|_| TcpClientError::AddrParseError(format!("Invalid address: {}", addr)))?
274 .next()
275 .ok_or(TcpClientError::AddrParseError(format!(
276 "Invalid address: {}",
277 addr
278 )))?;
279
280 // Set initial status to Connecting
281 status_tracker.update_status(ConnectionStatus::Connecting);
282
283 // Create the reconnection manager
284 let reconnection_config = reconnection::from_tcp_client_config(&config);
285
286 let handle = thread::spawn(move || {
287 // Create a status updater for use in the thread
288 let update_status = status_tracker.create_updater();
289
290 // Create the reconnection manager
291 let mut reconnection_manager = ReconnectionManager::new(reconnection_config);
292
293 loop {
294 if shutdown_flag.load(Ordering::SeqCst) {
295 info!("Shutdown flag set. Exiting subscription thread.");
296 // Set status to disconnected
297 update_status(ConnectionStatus::Disconnected);
298 break;
299 }
300
301 // Try to connect to the server.
302 info!("Attempting to connect to {}...", addr);
303 // Ensure status is set to Connecting
304 update_status(ConnectionStatus::Connecting);
305
306 let connect_result = TcpStream::connect_timeout(&address, config.connection_timeout);
307
308 match connect_result {
309 Ok(mut stream) => {
310 info!("Connected to server at {}", addr);
311 // Update status to Connected
312 update_status(ConnectionStatus::Connected);
313
314 // Reset reconnection attempts after successful connection
315 reconnection_manager.reset();
316
317 // Set read timeout - use shorter timeout to allow for ping checks
318 if let Err(e) = stream.set_read_timeout(Some(Duration::from_millis(500))) {
319 error!("Failed to set read timeout: {}", e);
320 continue;
321 }
322
323 // Set write timeout
324 if let Err(e) = stream.set_write_timeout(Some(Duration::from_secs(5))) {
325 error!("Failed to set write timeout: {}", e);
326 continue;
327 }
328
329 // Clone the stream for reading.
330 let reader_stream = match stream.try_clone() {
331 Ok(rs) => rs,
332 Err(e) => {
333 error!("Failed to clone stream: {}", e);
334 continue;
335 }
336 };
337 let mut reader = BufReader::new(reader_stream);
338
339 // Serialize and send the subscription request.
340 match serde_json::to_string(&subscription_request) {
341 Ok(req_json) => {
342 if let Err(e) = stream.write_all(req_json.as_bytes()) {
343 error!("Failed to send subscription request: {}", e);
344 continue;
345 }
346 if let Err(e) = stream.write_all(b"\n") {
347 error!("Failed to send newline: {}", e);
348 continue;
349 }
350 if let Err(e) = stream.flush() {
351 error!("Failed to flush stream: {}", e);
352 continue;
353 }
354 }
355 Err(e) => {
356 error!("Failed to serialize subscription request: {}", e);
357 break;
358 }
359 }
360
361 // Initialize the byte buffer with the configured capacity
362 let mut byte_buf = Vec::with_capacity(config.read_buffer_capacity);
363
364 // Ping-pong state tracking
365 let mut last_ping_time = std::time::Instant::now();
366 let mut last_pong_time = std::time::Instant::now();
367 let mut awaiting_pong = false;
368
369 // Inner loop: read events from the connection with ping/pong support
370 loop {
371 if shutdown_flag.load(Ordering::SeqCst) {
372 info!("Shutdown flag set. Exiting inner read loop.");
373 update_status(ConnectionStatus::Disconnected);
374 break;
375 }
376
377 // Current time
378 let now = std::time::Instant::now();
379
380 // Handle ping-pong logic
381 if now.duration_since(last_ping_time) >= config.ping_interval {
382 if awaiting_pong {
383 // Check if we've exceeded the pong timeout
384 if now.duration_since(last_pong_time) >= config.pong_timeout {
385 warn!("Pong response timed out after {:?}, considering connection dead",
386 now.duration_since(last_pong_time));
387 update_status(ConnectionStatus::Reconnecting);
388 break;
389 }
390 } else {
391 // Time to send a ping
392 match stream.write_all(b"PING\n") {
393 Ok(_) => {
394 if let Err(e) = stream.flush() {
395 error!("Failed to flush after PING: {}", e);
396 update_status(ConnectionStatus::Reconnecting);
397 break;
398 }
399 last_ping_time = now;
400 awaiting_pong = true;
401 }
402 Err(e) => {
403 error!("Failed to send PING: {}", e);
404 update_status(ConnectionStatus::Reconnecting);
405 break;
406 }
407 }
408 }
409 }
410
411 // Set read timeout to allow for ping checks and shutdown signals
412 if let Err(e) = stream.set_read_timeout(Some(Duration::from_millis(50))) {
413 error!("Failed to set read timeout: {}", e);
414 update_status(ConnectionStatus::Reconnecting);
415 break;
416 }
417
418 // Try to read until newline
419 match reader.read_until(b'\n', &mut byte_buf) {
420 Ok(0) => {
421 // Connection closed by server
422 warn!("TCP connection closed by server. Attempting to reconnect.");
423 update_status(ConnectionStatus::Reconnecting);
424 break;
425 }
426 Ok(n) if n > 0 => {
427 // Note: read_until includes the delimiter in the buffer.
428 // Trim whitespace and the trailing newline before processing.
429 let message_bytes = byte_buf.trim_ascii_end();
430
431 if !message_bytes.is_empty() {
432 // Check if this is a PONG response
433 if message_bytes == b"PONG" {
434 if awaiting_pong {
435 awaiting_pong = false;
436 last_pong_time = std::time::Instant::now();
437 debug!("Received PONG");
438 }
439 } else {
440 // Check if message size exceeds limit *before* parsing JSON
441 if message_bytes.len() > config.max_buffer_size {
442 error!(
443 "Received message exceeds maximum allowed size ({}), skipping. Message starts with: {:?}",
444 config.max_buffer_size,
445 String::from_utf8_lossy(&message_bytes[..std::cmp::min(message_bytes.len(), 50)]) // Log first 50 bytes
446 );
447 // Don't break, just clear buffer and continue reading the next message.
448 } else {
449 // Try to parse as an event from the byte slice
450 match serde_json::from_slice::<Event>(message_bytes) {
451 Ok(event) => {
452 // Any successful message means the connection is alive
453 last_pong_time = std::time::Instant::now();
454 awaiting_pong = false; // Reset awaiting_pong if we received a valid event
455
456 if tx.send(event).is_err() {
457 error!("Receiver dropped. Exiting subscription thread.");
458 update_status(
459 ConnectionStatus::Disconnected,
460 ); // Set status before returning
461 return; // Exit the thread
462 }
463 }
464 Err(e) => {
465 error!(
466 "Failed to parse event: {}. Raw data (first 100 bytes): {:?}",
467 e,
468 String::from_utf8_lossy(&message_bytes[..std::cmp::min(message_bytes.len(), 100)])
469 );
470 // Consider if this error should cause a reconnect or just skip
471 // For now, let's try reconnecting on parse error for safety
472 update_status(ConnectionStatus::Reconnecting);
473 break; // Trigger reconnect on parse error
474 }
475 }
476 }
477 }
478 }
479 // Clear the buffer for the next message AFTER processing the current one
480 byte_buf.clear();
481 }
482 Ok(_) => {
483 // n == 0, should be handled by Ok(0) case, safety belt
484 byte_buf.clear();
485 }
486 Err(e) => {
487 if e.kind() == std::io::ErrorKind::TimedOut
488 || e.kind() == std::io::ErrorKind::WouldBlock
489 {
490 // Expected timeout - continue the loop to check ping/shutdown
491 continue;
492 } else {
493 // Real error
494 error!("Error reading from TCP socket using read_until: {}", e);
495 update_status(ConnectionStatus::Reconnecting);
496 break; // Break inner loop to trigger reconnect
497 }
498 }
499 }
500
501 // Check if buffer capacity is exceeding limits (less critical with clear(), but good safety)
502 if byte_buf.capacity() > config.max_buffer_size {
503 error!("Buffer capacity exceeded maximum allowed size ({}), resetting connection.", config.max_buffer_size);
504 update_status(ConnectionStatus::Reconnecting);
505 break;
506 }
507 } // end inner loop for current connection
508
509 // When we exit the inner loop (connection lost or shutdown)
510 // Update status to reconnecting only if not shutting down
511 if !shutdown_flag.load(Ordering::SeqCst) {
512 update_status(ConnectionStatus::Reconnecting);
513 }
514 }
515 Err(e) => {
516 error!("Failed to connect to {}: {}", addr, e);
517 // Set status to reconnecting since we're going to try again
518 update_status(ConnectionStatus::Reconnecting);
519 }
520 }
521
522 // Before attempting reconnect, check shutdown flag again
523 if shutdown_flag.load(Ordering::SeqCst) {
524 update_status(ConnectionStatus::Disconnected);
525 break;
526 }
527
528 // Get the next delay from the reconnection manager
529 match reconnection_manager.next_delay() {
530 Some(wait_time) => {
531 info!(
532 "Reconnecting in {:?}... (attempt {}/{:?})",
533 wait_time,
534 reconnection_manager.current_attempt(),
535 reconnection_manager.config().max_attempts
536 );
537 // Use a flag-aware sleep
538 let sleep_start = std::time::Instant::now();
539 while sleep_start.elapsed() < wait_time {
540 if shutdown_flag.load(Ordering::SeqCst) {
541 info!("Shutdown detected during reconnect wait.");
542 update_status(ConnectionStatus::Disconnected);
543 return; // Exit thread immediately
544 }
545 thread::sleep(Duration::from_millis(50)); // Check flag periodically
546 }
547 }
548 None => {
549 error!(
550 "Reached maximum reconnection attempts ({}). Exiting.",
551 reconnection_manager.config().max_attempts.unwrap_or(0)
552 );
553 // Set status to disconnected when max attempts reached
554 update_status(ConnectionStatus::Disconnected);
555 break;
556 }
557 }
558 }
559 info!("Exiting TCP subscription thread.");
560 // Ensure status is Disconnected when thread exits naturally
561 update_status(ConnectionStatus::Disconnected);
562 });
563
564 Ok((rx, handle))
565}
566
#[cfg(test)]
mod tests {
    //! Tests that drive the blocking client against throwaway servers on 127.0.0.1.

    use super::*;
    use std::io::{BufRead, Write};
    use std::net::{SocketAddr, TcpListener};
    use std::thread;
    use std::time::Duration;
    use titan_types_core::EventType;

    /// Starts a one-shot server.
    ///
    /// It reads the subscription request, replies with a single `TransactionsAdded` event
    /// and keeps the socket open for about 500ms. The bound address is sent on `ready_tx`.
    fn start_test_server(ready_tx: std::sync::mpsc::Sender<SocketAddr>) -> JoinHandle<()> {
        thread::spawn(move || {
            // Bind to a random available port
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();

            // Notify the test that we're ready and send the address
            ready_tx.send(addr).unwrap();

            // Accept one connection
            if let Ok((mut stream, _)) = listener.accept() {
                let mut reader = BufReader::new(stream.try_clone().unwrap());
                let mut request_buf = Vec::new();

                // Read the subscription request
                match reader.read_until(b'\n', &mut request_buf) {
                    Ok(n) if n > 0 => {
                        let request_bytes = request_buf.trim_ascii_end();
                        println!(
                            "Server received request: {}",
                            String::from_utf8_lossy(request_bytes)
                        );

                        // Add a small delay to ensure the client is ready to receive
                        thread::sleep(Duration::from_millis(50));

                        // Send a sample event - using correct format for Event
                        let event = r#"{"type":"TransactionsAdded","data": {"txids":["1111111111111111111111111111111111111111111111111111111111111111"]}}"#;
                        if let Err(e) = stream.write_all(event.as_bytes()) {
                            println!("Server write error: {}", e);
                            return;
                        }
                        if let Err(e) = stream.write_all(b"\n") {
                            println!("Server write error: {}", e);
                            return;
                        }
                        if let Err(e) = stream.flush() {
                            println!("Server flush error: {}", e);
                            return;
                        }

                        // Keep the connection open for a while to ensure the client can read the response
                        thread::sleep(Duration::from_millis(500));
                    }
                    Ok(0) => println!("Server: Client disconnected before sending request"),
                    Err(e) => println!("Test server read error: {}", e),
                    _ => println!("Server: Unexpected read result for request"),
                }
            }
        })
    }

    #[test]
    fn test_connection_status_transitions() {
        // Create a channel to sync with the test server
        let (ready_tx, ready_rx) = std::sync::mpsc::channel();

        // Start a test server
        let server_handle = start_test_server(ready_tx);

        // Wait for the server to be ready and get its address
        let server_addr = ready_rx.recv_timeout(Duration::from_secs(5)).unwrap();

        // Create a client with short timeout
        let config = TcpClientConfig {
            connection_timeout: Duration::from_secs(1),
            max_reconnect_attempts: Some(1),
            base_reconnect_interval: Duration::from_millis(100),
            ..TcpClientConfig::default()
        };
        let client = TcpClient::new(config);

        // Initially disconnected
        assert_eq!(client.get_status(), ConnectionStatus::Disconnected);

        // Subscribe - this should connect
        let subscription_request = TcpSubscriptionRequest {
            subscribe: vec![EventType::TransactionsAdded],
        };

        // Keep the receiver alive: if it were dropped, the worker would exit as soon as it
        // tried to forward the server's event.
        let _rx = client
            .subscribe(format!("{}", server_addr), subscription_request)
            .unwrap();

        // Give it time to connect
        thread::sleep(Duration::from_millis(100));

        // Should be connected now
        assert_eq!(client.get_status(), ConnectionStatus::Connected);

        // Shutdown the client
        client.shutdown_and_join();

        // Check the client is disconnected
        assert_eq!(client.get_status(), ConnectionStatus::Disconnected);

        // Wait for the server to finish
        server_handle.join().unwrap();
    }

    #[test]
    fn test_receive_events() {
        // Create a channel to sync with the test server
        let (ready_tx, ready_rx) = std::sync::mpsc::channel();

        // Start a test server
        let server_handle = start_test_server(ready_tx);

        // Wait for the server to be ready and get its address
        let server_addr = ready_rx.recv_timeout(Duration::from_secs(5)).unwrap();

        // Create a client with short timeout
        let config = TcpClientConfig {
            connection_timeout: Duration::from_secs(1),
            max_reconnect_attempts: Some(1),
            base_reconnect_interval: Duration::from_millis(100),
            ..TcpClientConfig::default()
        };
        let client = TcpClient::new(config);

        // Subscribe to receive events
        let subscription_request = TcpSubscriptionRequest {
            subscribe: vec![EventType::TransactionsAdded],
        };

        let rx = client
            .subscribe(format!("{}", server_addr), subscription_request)
            .unwrap();

        // Give it time to establish connection
        thread::sleep(Duration::from_millis(200));

        // Try to receive an event with timeout
        let event = rx.recv_timeout(Duration::from_secs(2));
        assert!(event.is_ok(), "Should have received an event");

        match event.unwrap() {
            Event::TransactionsAdded { txids } => {
                assert_eq!(txids.len(), 1);
                assert_eq!(
                    txids[0].to_string(),
                    "1111111111111111111111111111111111111111111111111111111111111111"
                );
            }
            other => panic!("Received unexpected event type: {:?}", other),
        }

        // Shutdown the client
        client.shutdown_and_join();

        // Wait for the server to finish
        server_handle.join().unwrap();
    }

    #[test]
    fn test_connection_error_handling() {
        // Create a client with short timeout
        let config = TcpClientConfig {
            connection_timeout: Duration::from_secs(1),
            max_reconnect_attempts: Some(2),
            base_reconnect_interval: Duration::from_millis(100),
            ..TcpClientConfig::default()
        };
        let client = TcpClient::new(config);

        // Initially disconnected
        assert_eq!(client.get_status(), ConnectionStatus::Disconnected);

        // Try to connect to a non-existent server
        let subscription_request = TcpSubscriptionRequest {
            subscribe: vec![EventType::TransactionsAdded],
        };

        let _rx = client
            .subscribe("127.0.0.1:1".to_string(), subscription_request)
            .unwrap();

        // Give it time to attempt connection and reconnection
        thread::sleep(Duration::from_millis(500));

        // Should be in reconnecting state or disconnected if it already gave up
        let status = client.get_status();
        assert!(
            status == ConnectionStatus::Reconnecting || status == ConnectionStatus::Disconnected,
            "Expected Reconnecting or Disconnected state, got {:?}",
            status
        );

        // Shutdown the client
        client.shutdown_and_join();

        // Check the client is disconnected
        assert_eq!(client.get_status(), ConnectionStatus::Disconnected);
    }

    #[test]
    fn test_resource_cleanup() {
        // Create a client
        let client = TcpClient::new(TcpClientConfig::default());

        // Subscribe to a non-existent server
        let subscription_request = TcpSubscriptionRequest {
            subscribe: vec![EventType::TransactionsAdded],
        };

        let rx = client
            .subscribe("127.0.0.1:1".to_string(), subscription_request)
            .unwrap();

        // Verify we have an active thread
        assert!(client.has_active_thread());

        // Drop the receiver channel
        drop(rx);

        // Give the thread a moment to notice the receiver is dropped (if applicable)
        thread::sleep(Duration::from_millis(50));

        // Shutdown and join the client
        let joined = client.shutdown_and_join();
        assert!(joined, "Should have successfully joined the worker thread");

        // Verify we no longer have an active thread
        assert!(!client.has_active_thread());
    }

    #[test]
    fn test_resubscribe_stops_previous_thread() {
        // Use short timeouts so the test runs quickly even if connections fail
        let config = TcpClientConfig {
            connection_timeout: Duration::from_millis(100),
            max_reconnect_attempts: Some(2),
            base_reconnect_interval: Duration::from_millis(50),
            ..TcpClientConfig::default()
        };
        let client = TcpClient::new(config);

        // First subscription to a non-existent server; this will start a worker thread
        let _rx1 = client
            .subscribe(
                "127.0.0.1:1".to_string(),
                TcpSubscriptionRequest {
                    subscribe: vec![EventType::TransactionsAdded],
                },
            )
            .unwrap();

        // Give the worker thread a moment to start
        thread::sleep(Duration::from_millis(50));
        assert!(
            client.has_active_thread(),
            "Expected an active worker thread after first subscribe"
        );

        // Second subscription should signal the previous worker to shut down, join it,
        // and then start a new worker thread.
        let _rx2 = client
            .subscribe(
                "127.0.0.1:1".to_string(),
                TcpSubscriptionRequest {
                    subscribe: vec![EventType::TransactionsAdded],
                },
            )
            .unwrap();

        // After re-subscribing we should still have exactly one active worker thread,
        // and the call should not hang (which would indicate the old thread didn't stop).
        assert!(
            client.has_active_thread(),
            "Expected an active worker thread after re-subscribing"
        );

        // Finally, shutting down the client should cleanly join the active worker
        let joined = client.shutdown_and_join();
        assert!(joined, "Expected worker thread to be joined after shutdown");
        assert!(
            !client.has_active_thread(),
            "Expected no active worker thread after shutdown"
        );
    }

    /// Starts a server that sends one event after the subscription request.
    ///
    /// For about 5 seconds afterwards it answers every `PING` line with `PONG`. The bound
    /// address is sent on `ready_tx`.
    fn start_ping_pong_server(ready_tx: std::sync::mpsc::Sender<SocketAddr>) -> JoinHandle<()> {
        thread::spawn(move || {
            // Bind to a random available port
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let addr = listener.local_addr().unwrap();

            // Notify the test that we're ready and send the address
            ready_tx.send(addr).unwrap();

            // Accept one connection
            if let Ok((mut stream, _)) = listener.accept() {
                // Set a read timeout so we don't block forever
                stream
                    .set_read_timeout(Some(Duration::from_millis(200)))
                    .unwrap();

                // Create a buffer reader
                let mut reader = BufReader::new(stream.try_clone().unwrap());
                let mut line_buf = Vec::new(); // Use Vec<u8> for read_until

                // Read the subscription request
                match reader.read_until(b'\n', &mut line_buf) {
                    Ok(n) if n > 0 => {
                        let request_bytes = line_buf.trim_ascii_end();
                        println!(
                            "Ping-pong server received request: {}",
                            String::from_utf8_lossy(request_bytes)
                        );

                        // Send a sample event
                        let event = r#"{"type":"TransactionsAdded","data": {"txids":["1111111111111111111111111111111111111111111111111111111111111111"]}}"#;
                        if let Err(e) = stream.write_all(event.as_bytes()) {
                            println!("Server write error: {}", e);
                            return;
                        }
                        if let Err(e) = stream.write_all(b"\n") {
                            println!("Server write error: {}", e);
                            return;
                        }
                        if let Err(e) = stream.flush() {
                            println!("Server flush error: {}", e);
                            return;
                        }
                        println!("Ping-pong server sent initial event");
                    }
                    Ok(0) => {
                        println!("Ping-pong server: Client disconnected early");
                        return;
                    }
                    _ => {
                        println!("Ping-pong server failed to read subscription request");
                        return;
                    }
                }

                // Clear line for next reads
                line_buf.clear();

                // Keep handling ping/pong for a while
                let start = std::time::Instant::now();
                let timeout = Duration::from_secs(5); // Run for 5 seconds

                while start.elapsed() < timeout {
                    match reader.read_until(b'\n', &mut line_buf) {
                        Ok(0) => {
                            println!("Ping-pong server: client closed connection");
                            break;
                        }
                        Ok(n) if n > 0 => {
                            let trimmed_line = line_buf.trim_ascii_end();
                            println!(
                                "Ping-pong server received: {}",
                                String::from_utf8_lossy(trimmed_line)
                            );

                            if trimmed_line == b"PING" {
                                println!("Ping-pong server sending PONG");
                                if let Err(e) = stream.write_all(b"PONG\n") {
                                    println!("Ping-pong server failed to send PONG: {}", e);
                                    break;
                                }
                                if let Err(e) = stream.flush() {
                                    println!("Ping-pong server failed to flush PONG: {}", e);
                                    break;
                                }
                            }
                            line_buf.clear(); // Clear buffer after processing
                        }
                        Ok(_) => {
                            /* n==0 case handled above */
                            line_buf.clear();
                        }
                        Err(e)
                            if e.kind() == std::io::ErrorKind::WouldBlock
                                || e.kind() == std::io::ErrorKind::TimedOut =>
                        {
                            // Expected timeout. `read_until` leaves any partially received
                            // line in `line_buf`; keep it so the next read can complete it.
                        }
                        Err(e) => {
                            println!("Ping-pong server error: {}", e);
                            break;
                        }
                    }

                    // Small sleep to prevent tight loop
                    thread::sleep(Duration::from_millis(50));
                }

                println!("Ping-pong server shutting down");
            }
        })
    }

    #[test]
    fn test_ping_pong_mechanism() {
        // Create a channel to sync with the test server
        let (ready_tx, ready_rx) = std::sync::mpsc::channel();

        // Start a ping-pong test server
        let server_handle = start_ping_pong_server(ready_tx);

        // Wait for the server to be ready and get its address
        let server_addr = ready_rx.recv_timeout(Duration::from_secs(5)).unwrap();

        // Create a client with short ping interval for faster testing
        let config = TcpClientConfig {
            connection_timeout: Duration::from_secs(1),
            max_reconnect_attempts: Some(1),
            base_reconnect_interval: Duration::from_millis(100),
            ping_interval: Duration::from_millis(500), // Short ping interval for testing
            pong_timeout: Duration::from_millis(1000), // 1 second timeout
            ..TcpClientConfig::default()
        };
        let client = TcpClient::new(config);

        // Subscribe to receive events
        let subscription_request = TcpSubscriptionRequest {
            subscribe: vec![EventType::TransactionsAdded],
        };

        let _rx = client
            .subscribe(format!("{}", server_addr), subscription_request)
            .unwrap();

        // Give it time to establish connection
        thread::sleep(Duration::from_millis(200));

        // Verify connection status is connected
        assert_eq!(
            client.get_status(),
            ConnectionStatus::Connected,
            "Client should be connected"
        );

        // Wait long enough for multiple ping/pong cycles
        thread::sleep(Duration::from_secs(2));

        // Verify still connected after ping/pong cycles
        assert_eq!(
            client.get_status(),
            ConnectionStatus::Connected,
            "Client should remain connected after ping/pong exchanges"
        );

        // Shutdown the client
        client.shutdown_and_join();

        // Wait for the server to finish
        server_handle.join().unwrap();
    }
}