1use crate::actions::{EvaluationContext, RequestRule, ResponseRule};
2use crate::errors::{DoorkeeperError, ProxyError, WorkerError};
3use crate::frame::{
4 self, FrameOpcode, FrameParams, RequestFrame, ResponseFrame, ResponseOpcode,
5 read_response_frame, write_frame,
6};
7use crate::{RequestOpcode, TargetShard};
8use bytes::Bytes;
9use compression::no_compression;
10use scylla_cql::frame::types::read_string_multimap;
11use std::collections::HashMap;
12use std::env::VarError;
13use std::fmt::Display;
14use std::future::Future;
15use std::net::{IpAddr, Ipv4Addr, SocketAddr};
16use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, Ordering};
17use std::sync::{Arc, Mutex};
18use tokio::io::{AsyncRead, AsyncWrite};
19use tokio::net::{TcpListener, TcpSocket, TcpStream};
20use tokio::sync::mpsc::error::TryRecvError;
21use tokio::sync::{broadcast, mpsc};
22use tracing::{debug, error, info, trace, warn};
23
24type FinishWaiter = mpsc::Receiver<()>;
26type FinishGuard = mpsc::Sender<()>;
27
28type TerminateNotifier = tokio::sync::broadcast::Receiver<()>;
30type TerminateSignaler = tokio::sync::broadcast::Sender<()>;
31
32type ConnectionCloseNotifier = tokio::sync::broadcast::Receiver<()>;
35type ConnectionCloseSignaler = tokio::sync::broadcast::Sender<()>;
36
37type ErrorPropagator = mpsc::UnboundedSender<ProxyError>;
40type ErrorSink = mpsc::UnboundedReceiver<ProxyError>;
41
42struct ConnectionTracker {
49 active_count: std::sync::atomic::AtomicUsize,
50 notify: tokio::sync::Notify,
51}
52
53impl ConnectionTracker {
54 fn register_connection(self: &Arc<Self>) -> Arc<ConnectionLifetime> {
60 self.active_count.fetch_add(1, Ordering::Relaxed);
61 self.notify.notify_waiters();
62 Arc::new(ConnectionLifetime {
63 tracker: Arc::clone(self),
64 })
65 }
66}
67
68struct ConnectionLifetime {
74 tracker: Arc<ConnectionTracker>,
75}
76
77impl Drop for ConnectionLifetime {
78 fn drop(&mut self) {
79 self.tracker.active_count.fetch_sub(1, Ordering::Relaxed);
80 }
81}
82
83static HARDCODED_OPTIONS_PARAMS: FrameParams = FrameParams {
84 flags: 0,
85 version: 0x04,
86 stream: 0,
87};
88
/// How a doorkeeper learns the shard count of the real node it fronts.
#[derive(Clone, Copy, Debug)]
pub enum ShardAwareness {
    /// Do not track shards at all.
    Unaware,
    /// Ask the node (OPTIONS/SUPPORTED, `SCYLLA_NR_SHARDS`) for its shard count.
    QueryNode,
    /// Use the given shard count without asking the node.
    FixedNum(u16),
}

impl ShardAwareness {
    /// Returns `true` for every variant except `ShardAwareness::Unaware`.
    pub fn is_aware(&self) -> bool {
        match self {
            Self::Unaware => false,
            Self::QueryNode | Self::FixedNum(_) => true,
        }
    }
}
109
110enum NodeType {
130 Real {
131 real_addr: SocketAddr,
132 shard_awareness: ShardAwareness,
133 response_rules: Option<Vec<ResponseRule>>,
134 },
135 Simulated,
136}
137
138pub struct Node {
139 proxy_addr: SocketAddr,
140 request_rules: Option<Vec<RequestRule>>,
141 node_type: NodeType,
142}
143
144impl Node {
145 pub fn new(
147 real_addr: SocketAddr,
148 proxy_addr: SocketAddr,
149 shard_awareness: ShardAwareness,
150 request_rules: Option<Vec<RequestRule>>,
151 response_rules: Option<Vec<ResponseRule>>,
152 ) -> Self {
153 Self {
154 proxy_addr,
155 request_rules,
156 node_type: NodeType::Real {
157 real_addr,
158 shard_awareness,
159 response_rules,
160 },
161 }
162 }
163
164 pub fn new_dry_mode(proxy_addr: SocketAddr, request_rules: Option<Vec<RequestRule>>) -> Self {
166 Self {
167 proxy_addr,
168 request_rules,
169 node_type: NodeType::Simulated,
170 }
171 }
172
173 pub fn builder() -> NodeBuilder {
174 NodeBuilder {
175 real_addr: None,
176 proxy_addr: None,
177 shard_awareness: None,
178 request_rules: None,
179 response_rules: None,
180 }
181 }
182}
183
184pub struct NodeBuilder {
185 real_addr: Option<SocketAddr>,
186 proxy_addr: Option<SocketAddr>,
187 shard_awareness: Option<ShardAwareness>,
188 request_rules: Option<Vec<RequestRule>>,
189 response_rules: Option<Vec<ResponseRule>>,
190}
191
192impl NodeBuilder {
193 pub fn real_address(mut self, real_addr: SocketAddr) -> Self {
194 self.real_addr = Some(real_addr);
195 self
196 }
197
198 pub fn proxy_address(mut self, proxy_addr: SocketAddr) -> Self {
199 self.proxy_addr = Some(proxy_addr);
200 self
201 }
202
203 pub fn shard_awareness(mut self, shard_awareness: ShardAwareness) -> Self {
204 self.shard_awareness = Some(shard_awareness);
205 self
206 }
207
208 pub fn request_rules(mut self, request_rules: Vec<RequestRule>) -> Self {
209 self.request_rules = Some(request_rules);
210 self
211 }
212
213 pub fn response_rules(mut self, response_rules: Vec<ResponseRule>) -> Self {
214 self.response_rules = Some(response_rules);
215 self
216 }
217
218 pub fn build(self) -> Node {
220 Node {
221 proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
222 request_rules: self.request_rules,
223 node_type: NodeType::Real {
224 real_addr: self.real_addr.expect("Real addr is required!"),
225 shard_awareness: self.shard_awareness.expect("Shard awareness is required!"),
226 response_rules: self.response_rules,
227 },
228 }
229 }
230
231 pub fn build_dry_mode(self) -> Node {
233 Node {
234 proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
235 request_rules: self.request_rules,
236 node_type: NodeType::Simulated,
237 }
238 }
239}
240
/// Formats an optional real-node address, printing `<dry mode>` when absent.
#[derive(Clone, Copy)]
struct DisplayableRealAddrOption(Option<SocketAddr>);
impl Display for DisplayableRealAddrOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            Some(addr) => write!(f, "{addr}"),
            None => f.write_str("<dry mode>"),
        }
    }
}
252
253#[derive(Clone, Copy)]
254struct DisplayableShard(Option<TargetShard>);
255impl Display for DisplayableShard {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 if let Some(shard) = self.0 {
258 write!(f, "shard {shard}")
259 } else {
260 write!(f, "unknown shard")
261 }
262 }
263}
264
265enum InternalNode {
266 Real {
267 real_addr: SocketAddr,
268 proxy_addr: SocketAddr,
269 shard_awareness: ShardAwareness,
270 request_rules: Arc<Mutex<Vec<RequestRule>>>,
271 response_rules: Arc<Mutex<Vec<ResponseRule>>>,
272 },
273 Simulated {
274 proxy_addr: SocketAddr,
275 request_rules: Arc<Mutex<Vec<RequestRule>>>,
276 },
277}
278
279impl InternalNode {
280 fn proxy_addr(&self) -> SocketAddr {
281 match *self {
282 InternalNode::Real { proxy_addr, .. } => proxy_addr,
283 InternalNode::Simulated { proxy_addr, .. } => proxy_addr,
284 }
285 }
286 fn real_addr(&self) -> Option<SocketAddr> {
287 match *self {
288 InternalNode::Real { real_addr, .. } => Some(real_addr),
289 InternalNode::Simulated { .. } => None,
290 }
291 }
292 fn request_rules(&self) -> &Arc<Mutex<Vec<RequestRule>>> {
293 match self {
294 InternalNode::Real { request_rules, .. } => request_rules,
295 InternalNode::Simulated { request_rules, .. } => request_rules,
296 }
297 }
298}
299
300impl From<Node> for InternalNode {
301 fn from(node: Node) -> Self {
302 match node.node_type {
303 NodeType::Real {
304 real_addr,
305 shard_awareness,
306 response_rules,
307 } => InternalNode::Real {
308 real_addr,
309 proxy_addr: node.proxy_addr,
310 shard_awareness,
311 request_rules: node
312 .request_rules
313 .map(|rules| Arc::new(Mutex::new(rules)))
314 .unwrap_or_default(),
315 response_rules: response_rules
316 .map(|rules| Arc::new(Mutex::new(rules)))
317 .unwrap_or_default(),
318 },
319 NodeType::Simulated => InternalNode::Simulated {
320 proxy_addr: node.proxy_addr,
321 request_rules: node
322 .request_rules
323 .map(|rules| Arc::new(Mutex::new(rules)))
324 .unwrap_or_default(),
325 },
326 }
327 }
328}
329
330pub struct ProxyBuilder {
331 nodes: Vec<Node>,
332}
333
334impl ProxyBuilder {
335 pub fn with_node(mut self, node: Node) -> ProxyBuilder {
336 self.nodes.push(node);
337 self
338 }
339
340 pub fn build(self) -> Proxy {
341 Proxy::new(self.nodes)
342 }
343}
344
345pub struct Proxy {
346 nodes: Vec<InternalNode>,
347}
348
349impl Proxy {
350 pub fn new(nodes: impl IntoIterator<Item = Node>) -> Self {
351 Proxy {
352 nodes: nodes.into_iter().map(|node| node.into()).collect(),
353 }
354 }
355
356 pub fn builder() -> ProxyBuilder {
357 ProxyBuilder { nodes: vec![] }
358 }
359
360 pub fn translation_map(&self) -> HashMap<SocketAddr, SocketAddr> {
364 let mut translation_map = HashMap::new();
365 for node in self.nodes.iter() {
366 if let &InternalNode::Real {
367 real_addr,
368 proxy_addr,
369 ..
370 } = node
371 {
372 translation_map.insert(real_addr, proxy_addr);
373 let shard_aware_real_addr = SocketAddr::new(real_addr.ip(), 19042);
374 translation_map.insert(shard_aware_real_addr, proxy_addr);
375 }
376 }
377 translation_map
378 }
379
380 pub async fn run(self) -> Result<RunningProxy, DoorkeeperError> {
383 let (terminate_signaler, _t) = tokio::sync::broadcast::channel(1);
384 let (finish_guard, finish_waiter) = mpsc::channel(1);
385
386 let (error_propagator, error_sink) = mpsc::unbounded_channel();
387 let (doorkeepers, running_nodes): (Vec<_>, Vec<RunningNode>) = self
388 .nodes
389 .into_iter()
390 .map(|node| {
391 let connection_tracker = Arc::new(ConnectionTracker {
392 active_count: std::sync::atomic::AtomicUsize::new(0),
393 notify: tokio::sync::Notify::new(),
394 });
395 let cc_event_sender = Arc::new(Mutex::new(HashMap::new()));
396 let running = {
397 let (request_rules, response_rules) = match node {
398 InternalNode::Real {
399 ref request_rules,
400 ref response_rules,
401 ..
402 } => (request_rules, Some(response_rules)),
403 InternalNode::Simulated {
404 ref request_rules, ..
405 } => (request_rules, None),
406 };
407 RunningNode {
408 request_rules: request_rules.clone(),
409 response_rules: response_rules.cloned(),
410 connection_tracker: connection_tracker.clone(),
411 cc_event_sender: cc_event_sender.clone(),
412 }
413 };
414 (
415 Doorkeeper::spawn(
416 node,
417 terminate_signaler.clone(),
418 finish_guard.clone(),
419 error_propagator.clone(),
420 connection_tracker,
421 cc_event_sender,
422 ),
423 running,
424 )
425 })
426 .unzip();
427
428 for doorkeeper in doorkeepers {
429 doorkeeper.await?; }
431
432 Ok(RunningProxy {
433 terminate_signaler,
434 finish_waiter,
435 running_nodes,
436 error_sink,
437 })
438 }
439}
440
441pub struct RunningNode {
443 request_rules: Arc<Mutex<Vec<RequestRule>>>,
444 response_rules: Option<Arc<Mutex<Vec<ResponseRule>>>>,
445 connection_tracker: Arc<ConnectionTracker>,
446
447 cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
457}
458
459impl RunningNode {
460 pub fn change_request_rules(&mut self, rules: Option<Vec<RequestRule>>) {
462 *self.request_rules.lock().unwrap() = rules.unwrap_or_default();
463 }
464
465 pub fn append_request_rules(&mut self, mut rules: Vec<RequestRule>) {
467 self.request_rules.lock().unwrap().append(&mut rules);
468 }
469
470 pub fn prepend_request_rules(&mut self, rules: Vec<RequestRule>) {
472 let mut new_rules = rules;
473 let mut old_rules_guard = self.request_rules.lock().unwrap();
474 new_rules.append(&mut *old_rules_guard);
475 *old_rules_guard = new_rules;
476 }
477
478 pub fn change_response_rules(&mut self, rules: Option<Vec<ResponseRule>>) {
480 *self
481 .response_rules
482 .as_ref()
483 .expect("No response rules on a simulated node!")
484 .lock()
485 .unwrap() = rules.unwrap_or_default();
486 }
487
488 pub fn append_response_rules(&mut self, mut rules: Vec<ResponseRule>) {
490 self.response_rules
491 .as_ref()
492 .expect("No response rules on a simulated node!")
493 .lock()
494 .unwrap()
495 .append(&mut rules);
496 }
497
498 pub fn prepend_response_rules(&mut self, rules: Vec<ResponseRule>) {
500 let mut old_rules_guard = self
501 .response_rules
502 .as_ref()
503 .expect("No response rules on a simulated node!")
504 .lock()
505 .unwrap();
506 let mut new_rules = rules;
507 new_rules.append(&mut *old_rules_guard);
508 *old_rules_guard = new_rules;
509 }
510
511 pub async fn wait_for_connection(&self) {
517 loop {
518 let notified = self.connection_tracker.notify.notified();
522 if self.connection_tracker.active_count.load(Ordering::Relaxed) > 0 {
523 return;
524 }
525 notified.await;
526 }
527 }
528
529 pub fn inject_event_to_cc(&self, body: Bytes) -> bool {
539 let mut guard = self.cc_event_sender.lock().unwrap();
540 if guard.is_empty() {
541 return false;
542 }
543 let mut any_sent = false;
544 guard.retain(|_conn_no, tx| {
545 let frame = ResponseFrame {
546 params: FrameParams {
547 version: 4,
548 flags: 0,
549 stream: -1,
550 }
551 .for_response(),
552 opcode: ResponseOpcode::Event,
553 body: body.clone(),
554 };
555 let ok = tx.send(frame).is_ok();
556 any_sent |= ok;
557 ok
560 });
561 any_sent
562 }
563}
564
565pub struct RunningProxy {
567 terminate_signaler: TerminateSignaler,
568 finish_waiter: FinishWaiter,
569 pub running_nodes: Vec<RunningNode>,
570 error_sink: ErrorSink,
571}
572
573impl RunningProxy {
574 pub fn turn_off_rules(&mut self) {
576 for (request_rules, response_rules) in self
577 .running_nodes
578 .iter_mut()
579 .map(|node| (&node.request_rules, &node.response_rules))
580 {
581 request_rules.lock().unwrap().clear();
582 if let Some(response_rules) = response_rules {
583 response_rules.lock().unwrap().clear();
584 }
585 }
586 }
587
588 pub fn sanity_check(&mut self) -> Result<(), ProxyError> {
591 match self.error_sink.try_recv() {
592 Ok(err) => Err(err),
593 Err(TryRecvError::Empty) => Ok(()),
594 Err(TryRecvError::Disconnected) => {
595 Err(ProxyError::SanityCheckFailure)
597 }
598 }
599 }
600
601 pub async fn wait_for_error(&mut self) -> Option<ProxyError> {
603 self.error_sink.recv().await
604 }
605
606 pub async fn finish(mut self) -> Result<(), ProxyError> {
609 self.terminate_signaler.send(()).map_err(|err| {
610 ProxyError::AwaitFinishFailure(format!(
611 "Send error in terminate_signaler: {err} (bug!)"
612 ))
613 })?;
614 info!("Sent finish signal to proxy workers.");
615
616 std::mem::drop(self.terminate_signaler);
618
619 if self.finish_waiter.recv().await.is_some() {
620 unreachable!();
621 };
622 info!("All workers have finished.");
623
624 match self.error_sink.try_recv() {
625 Ok(err) => Err(err),
626 Err(TryRecvError::Disconnected) => Ok(()),
627 Err(TryRecvError::Empty) => {
628 unreachable!("Worker await logic bug!");
630 }
631 }
632 }
633
634 pub async fn wait_for_connection(&self) {
641 futures::future::select_all(
643 self.running_nodes
644 .iter()
645 .map(|n| Box::pin(n.wait_for_connection())),
646 )
647 .await;
648 }
649}
650
/// Per-node acceptor: listens on the node's proxy address and spawns a set of
/// worker tasks for every accepted driver connection.
///
/// NOTE: fields are dropped in declaration order, so `finish_guard` is released
/// before `error_propagator` when a doorkeeper is dropped.
struct Doorkeeper {
    node: InternalNode,
    listener: TcpListener,
    // Cloned/subscribed for every worker.
    terminate_signaler: TerminateSignaler,
    finish_guard: FinishGuard,
    // `None` until `update_shards_count` runs, and after it when the node is
    // shard-unaware or its shard count could not be obtained.
    shards_count: Option<u16>,
    error_propagator: ErrorPropagator,
    connection_tracker: Arc<ConnectionTracker>,
    cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
}
665
666impl Doorkeeper {
667 async fn spawn(
668 node: InternalNode,
669 terminate_signaler: TerminateSignaler,
670 finish_guard: FinishGuard,
671 error_propagator: ErrorPropagator,
672 connection_tracker: Arc<ConnectionTracker>,
673 cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
674 ) -> Result<(), DoorkeeperError> {
675 let listener = TcpListener::bind(node.proxy_addr())
676 .await
677 .map_err(|err| DoorkeeperError::DriverConnectionAttempt(node.proxy_addr(), err))?;
678
679 if let InternalNode::Real {
680 shard_awareness,
681 real_addr,
682 ..
683 } = node
684 {
685 info!(
686 "Spawned a {} doorkeeper for pair real:{} - proxy:{}.",
687 if shard_awareness.is_aware() {
688 "shard-aware"
689 } else {
690 "shard-unaware"
691 },
692 real_addr,
693 node.proxy_addr(),
694 );
695 } else {
696 info!(
697 "Spawned a dry-mode doorkeeper for proxy:{}.",
698 node.proxy_addr(),
699 )
700 };
701
702 let doorkeeper = Doorkeeper {
703 shards_count: None, node,
705 listener,
706 terminate_signaler,
707 finish_guard,
708 error_propagator,
709 connection_tracker,
710 cc_event_sender,
711 };
712 tokio::task::spawn(doorkeeper.run());
713 Ok(())
714 }
715
716 async fn run(mut self) {
717 self.update_shards_count().await;
718 let mut own_terminate_notifier = self.terminate_signaler.subscribe();
719 let (connection_close_tx, _connection_close_rx) = broadcast::channel::<()>(2);
720 let mut connection_no: usize = 0;
721 loop {
722 tokio::select! {
723 res = self.accept_connection(&connection_close_tx, connection_no) => {
724 match res {
725 Ok(()) => connection_no += 1,
726 Err(err) => {
727 error!(
728 "Error in doorkeeper with addr {} for node {}: {}",
729 self.node.proxy_addr(),
730 DisplayableRealAddrOption(self.node.real_addr()),
731 err
732 );
733 let _ = self.error_propagator.send(err.into());
734 break;
735 },
736 }
737 },
738 _terminate = own_terminate_notifier.recv() => break
739 }
740 }
741 debug!(
742 "Doorkeeper exits: proxy {}, node {}.",
743 self.node.proxy_addr(),
744 DisplayableRealAddrOption(self.node.real_addr())
745 );
746 }
747
748 async fn update_shards_count(&mut self) {
749 if let InternalNode::Real {
750 real_addr,
751 shard_awareness,
752 ..
753 } = self.node
754 {
755 self.shards_count = match shard_awareness {
756 ShardAwareness::Unaware => None,
757 ShardAwareness::FixedNum(shards_num) => Some(shards_num),
758 ShardAwareness::QueryNode => match self.obtain_shards_count(real_addr).await {
759 Ok(shards) => Some(shards),
760 Err(DoorkeeperError::ObtainingShardNumberNoShardInfo) => {
762 info!(
763 "Doorkeeper with addr {} found no shard info in node {}; falling back to ShardAwareness::Unaware",
764 self.node.proxy_addr(),
765 DisplayableRealAddrOption(self.node.real_addr()),
766 );
767 None
768 }
769 Err(e) => {
770 error!(
771 "Error in doorkeeper with addr {} while querying shard info from node {}: {}",
772 self.node.proxy_addr(),
773 DisplayableRealAddrOption(self.node.real_addr()),
774 e
775 );
776 None
777 }
778 },
779 }
780 }
781 }
782
783 async fn spawn_workers(
784 &mut self,
785 driver_addr: SocketAddr,
786 connection_close_tx: &ConnectionCloseSignaler,
787 connection_no: usize,
788 driver_stream: TcpStream,
789 cluster_stream: Option<TcpStream>,
790 shard: Option<TargetShard>,
791 ) {
792 let conn_lifetime = self.connection_tracker.register_connection();
798
799 let (driver_read, driver_write) = driver_stream.into_split();
800
801 let new_worker = || ProxyWorker {
802 terminate_notifier: self.terminate_signaler.subscribe(),
803 finish_guard: self.finish_guard.clone(),
804 connection_close_notifier: connection_close_tx.subscribe(),
805 error_propagator: self.error_propagator.clone(),
806 driver_addr,
807 real_addr: self.node.real_addr(),
808 proxy_addr: self.node.proxy_addr(),
809 shard,
810 };
811
812 let (tx_request, rx_request) = mpsc::unbounded_channel::<RequestFrame>();
813 let (tx_response, rx_response) = mpsc::unbounded_channel::<ResponseFrame>();
814 let (tx_cluster, rx_cluster) = mpsc::unbounded_channel::<RequestFrame>();
815 let (tx_driver, rx_driver) = mpsc::unbounded_channel::<ResponseFrame>();
816 let event_register_flag = Arc::new(AtomicBool::new(false));
817
818 let (
819 compression_writer_request_processor,
820 compression_reader_receiver_from_driver,
821 compression_reader_receiver_from_cluster,
822 compression_reader_sender_to_driver,
823 compression_reader_sender_to_cluster,
824 ) = compression::make_compression_infra();
825
826 {
827 let guard = Arc::clone(&conn_lifetime);
828 let worker = new_worker();
829 tokio::task::spawn(async move {
830 let _conn = guard;
831 worker
832 .receiver_from_driver(
833 driver_read,
834 tx_request,
835 compression_reader_receiver_from_driver,
836 )
837 .await;
838 });
839 }
840 {
841 let guard = Arc::clone(&conn_lifetime);
842 let worker = new_worker();
843 let conn_close_sub = connection_close_tx.subscribe();
844 let term_sub = self.terminate_signaler.subscribe();
845 tokio::task::spawn(async move {
846 let _conn = guard;
847 worker
848 .sender_to_driver(
849 driver_write,
850 rx_driver,
851 conn_close_sub,
852 term_sub,
853 compression_reader_sender_to_driver,
854 )
855 .await;
856 });
857 }
858 {
859 let guard = Arc::clone(&conn_lifetime);
860 let worker = new_worker();
861 let request_rules = Arc::clone(self.node.request_rules());
862 let conn_close = connection_close_tx.clone();
863 let event_flag = Arc::clone(&event_register_flag);
864 let tx_driver_clone = tx_driver.clone();
865 let tx_cluster_clone = tx_cluster.clone();
866 let cc_sender = Arc::clone(&self.cc_event_sender);
867 tokio::task::spawn(async move {
868 let _conn = guard;
869 worker
870 .request_processor(
871 rx_request,
872 tx_driver_clone,
873 tx_cluster_clone,
874 connection_no,
875 request_rules,
876 conn_close,
877 event_flag,
878 compression_writer_request_processor,
879 cc_sender,
880 )
881 .await;
882 });
883 }
884 if let InternalNode::Real {
885 ref response_rules, ..
886 } = self.node
887 {
888 let (cluster_read, cluster_write) = cluster_stream.unwrap().into_split();
889 {
890 let guard = Arc::clone(&conn_lifetime);
891 let worker = new_worker();
892 let conn_close_sub = connection_close_tx.subscribe();
893 let term_sub = self.terminate_signaler.subscribe();
894 tokio::task::spawn(async move {
895 let _conn = guard;
896 worker
897 .sender_to_cluster(
898 cluster_write,
899 rx_cluster,
900 conn_close_sub,
901 term_sub,
902 compression_reader_sender_to_cluster,
903 )
904 .await;
905 });
906 }
907 {
908 let guard = Arc::clone(&conn_lifetime);
909 let worker = new_worker();
910 tokio::task::spawn(async move {
911 let _conn = guard;
912 worker
913 .receiver_from_cluster(
914 cluster_read,
915 tx_response,
916 compression_reader_receiver_from_cluster,
917 )
918 .await;
919 });
920 }
921 {
922 let guard = conn_lifetime;
923 let worker = new_worker();
924 let response_rules = Arc::clone(response_rules);
925 let conn_close = connection_close_tx.clone();
926 let event_flag = Arc::clone(&event_register_flag);
927 tokio::task::spawn(async move {
928 let _conn = guard;
929 worker
930 .response_processor(
931 rx_response,
932 tx_driver,
933 tx_cluster,
934 connection_no,
935 response_rules,
936 conn_close,
937 event_flag,
938 )
939 .await;
940 });
941 }
942 }
943 debug!(
944 "Doorkeeper with addr {} of node {} spawned workers.",
945 self.node.proxy_addr(),
946 DisplayableRealAddrOption(self.node.real_addr())
947 );
948 }
949
950 async fn accept_connection(
951 &mut self,
952 connection_close_tx: &ConnectionCloseSignaler,
953 connection_no: usize,
954 ) -> Result<(), DoorkeeperError> {
955 let (driver_stream, driver_addr) = self.make_driver_stream(connection_no).await?;
956 let (cluster_stream, shard) = match self.node {
957 InternalNode::Real { real_addr, .. } => {
958 let (cluster_stream, shard) =
959 self.make_cluster_stream(driver_addr, real_addr).await?;
960 (Some(cluster_stream), shard)
961 }
962 InternalNode::Simulated { .. } => (None, None),
963 };
964
965 self.spawn_workers(
966 driver_addr,
967 connection_close_tx,
968 connection_no,
969 driver_stream,
970 cluster_stream,
971 shard,
972 )
973 .await;
974
975 Ok(())
976 }
977
978 async fn make_driver_stream(
979 &mut self,
980 connection_no: usize,
981 ) -> Result<(TcpStream, SocketAddr), DoorkeeperError> {
982 let (driver_stream, driver_addr) =
983 self.listener.accept().await.map_err(|err| {
984 DoorkeeperError::DriverConnectionAttempt(self.node.proxy_addr(), err)
985 })?;
986 info!(
987 "Connected driver from {} to {}, connection no={}.",
988 driver_addr,
989 self.node.proxy_addr(),
990 connection_no
991 );
992 Ok((driver_stream, driver_addr))
993 }
994
995 async fn make_cluster_stream(
996 &mut self,
997 driver_addr: SocketAddr,
998 real_addr: SocketAddr,
999 ) -> Result<(TcpStream, Option<TargetShard>), DoorkeeperError> {
1000 let mut cluster_stream = if let Some(shards) = self.shards_count {
1001 let socket = match self.node.proxy_addr().ip() {
1002 std::net::IpAddr::V4(_) => TcpSocket::new_v4(),
1003 std::net::IpAddr::V6(_) => TcpSocket::new_v6(),
1004 }
1005 .map_err(DoorkeeperError::SocketCreate)?;
1006
1007 let shard_preserving_addr = {
1008 let mut desired_addr =
1009 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), driver_addr.port());
1010 while socket.bind(desired_addr).is_err() {
1011 let next_port = self.next_port_to_same_shard(desired_addr.port());
1013 if next_port == driver_addr.port() {
1014 return Err(DoorkeeperError::NoMorePorts);
1015 }
1016 desired_addr.set_port(next_port);
1017 }
1018 desired_addr
1019 };
1020
1021 let stream = socket.connect(real_addr).await;
1022 if let Ok(ok) = &stream {
1023 info!(
1024 "Connected to the cluster from {} at {}, intended shard {}.",
1025 ok.local_addr().unwrap(),
1026 real_addr,
1027 shard_preserving_addr.port() % shards
1028 );
1029 }
1030 stream
1031 } else {
1032 let stream = TcpStream::connect(real_addr).await;
1033 if stream.is_ok() {
1034 info!("Connected to the cluster at {}.", real_addr);
1035 }
1036 stream
1037 }
1038 .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
1039
1040 let shard = if self.shards_count.is_some() {
1047 self.obtain_shard_number(real_addr, &mut cluster_stream)
1048 .await?
1049 } else {
1050 None
1051 };
1052
1053 Ok((cluster_stream, shard))
1054 }
1055
1056 fn next_port_to_same_shard(&self, port: u16) -> u16 {
1057 port.wrapping_add(self.shards_count.unwrap())
1058 }
1059
1060 async fn get_supported_options(
1061 connection: &mut TcpStream,
1062 ) -> Result<HashMap<String, Vec<String>>, DoorkeeperError> {
1063 write_frame(
1064 HARDCODED_OPTIONS_PARAMS,
1065 FrameOpcode::Request(RequestOpcode::Options),
1066 &Bytes::new(),
1067 connection,
1068 &no_compression(),
1069 )
1070 .await
1071 .map_err(DoorkeeperError::ObtainingShardNumber)?;
1072
1073 let supported_frame = read_response_frame(connection, &compression::no_compression())
1074 .await
1075 .map_err(DoorkeeperError::ObtainingShardNumberFrame)?;
1076
1077 let options = read_string_multimap(&mut supported_frame.body.as_ref())
1078 .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;
1079
1080 Ok(options)
1081 }
1082
1083 async fn obtain_shards_count(&self, real_addr: SocketAddr) -> Result<u16, DoorkeeperError> {
1084 let mut connection = TcpStream::connect(real_addr)
1085 .await
1086 .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
1087 let options = Self::get_supported_options(&mut connection).await?;
1088 let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
1089 let shards = match nr_shards_entry
1090 .and_then(|vec| vec.first())
1091 .ok_or(DoorkeeperError::ObtainingShardNumberNoShardInfo)?
1092 .parse::<u16>()
1093 .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)?
1094 {
1095 0u16 => Err(DoorkeeperError::ObtainingShardNumberGotZero),
1096 num => Ok(num),
1097 }?;
1098 info!("Obtained shards number on node {}: {}", real_addr, shards);
1099 Ok(shards)
1100 }
1101
1102 async fn obtain_shard_number(
1103 &self,
1104 real_addr: SocketAddr,
1105 connection: &mut TcpStream,
1106 ) -> Result<Option<TargetShard>, DoorkeeperError> {
1107 let options = Self::get_supported_options(connection).await?;
1108 let shard_entry = options.get("SCYLLA_SHARD");
1109 let shard = shard_entry
1110 .and_then(|vec| vec.first())
1111 .map(|s| {
1112 s.parse::<u16>()
1113 .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)
1114 })
1115 .transpose()?;
1116 info!("Connected to node {}, shard {:?}", real_addr, shard);
1117 Ok(shard)
1118 }
1119}
1120
1121mod compression {
1122 use std::error::Error;
1123 use std::sync::{Arc, OnceLock};
1124
1125 use bytes::Bytes;
1126 use scylla_cql::frame::frame_errors::{
1127 CqlRequestSerializationError, FrameBodyExtensionsParseError,
1128 };
1129 use scylla_cql::frame::request::{
1130 DeserializableRequest as _, RequestDeserializationError, Startup, options,
1131 };
1132 use scylla_cql::frame::{Compression, compress_append, decompress, flag};
1133 use tracing::{error, warn};
1134
1135 #[derive(Debug, thiserror::Error)]
1136 pub(crate) enum CompressionError {
1137 #[error("Snap compression error: {0}")]
1139 SnapCompressError(Arc<dyn Error + Sync + Send>),
1140
1141 #[error("Frame is to be compressed, but no compression negotiated for connection.")]
1143 NoCompressionNegotiated,
1144 }
1145
1146 type CompressionInfo = Arc<OnceLock<Option<Compression>>>;
1147
1148 #[derive(Debug, Clone)]
1153 pub(crate) struct CompressionWriter(CompressionInfo);
1154 impl CompressionWriter {
1155 pub(crate) fn set(
1156 &self,
1157 compression: Option<Compression>,
1158 ) -> Result<(), Option<Compression>> {
1159 self.0.set(compression)
1160 }
1161
1162 pub(crate) fn set_from_startup(
1163 &self,
1164 mut body: &[u8],
1165 ) -> Result<Option<Compression>, RequestDeserializationError> {
1166 let startup = Startup::deserialize_with_features(&mut body, &Default::default())?;
1167 let maybe_compression = startup.options.get(options::COMPRESSION);
1168 let maybe_compression = maybe_compression.and_then(|compression| {
1169 compression
1170 .parse::<Compression>()
1171 .inspect_err(|err| error!("STARTUP compression error: {}", err))
1172 .ok()
1173 });
1174 let _ = self.set(maybe_compression).inspect_err(|_| {
1175 warn!("Captured second or further STARTUP frame on the same connection")
1176 });
1177
1178 Ok(maybe_compression)
1179 }
1180 }
1181
1182 #[derive(Debug, Clone)]
1186 pub(crate) struct CompressionReader(CompressionInfo);
1187 impl CompressionReader {
1188 pub(crate) fn get(&self) -> Option<Option<Compression>> {
1193 self.0.get().copied()
1194 }
1195
1196 pub(crate) fn maybe_compress_body(
1197 &self,
1198 flags: u8,
1199 body: &[u8],
1200 ) -> Result<Option<Bytes>, CompressionError> {
1201 match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
1202 (true, Some(compression)) => {
1203 let mut buf = Vec::new();
1204 compress_append(body, compression, &mut buf).map_err(|err| {
1205 let CqlRequestSerializationError::SnapCompressError(err) = err else {
1206 unreachable!("BUG: compress_append returned variant different than SnapCompressError")
1207 };
1208 CompressionError::SnapCompressError(err)
1209 })?;
1210 Ok(Some(Bytes::from(buf)))
1211 }
1212 (true, None) => Err(CompressionError::NoCompressionNegotiated),
1213 (false, _) => Ok(None),
1214 }
1215 }
1216
1217 pub(crate) fn maybe_decompress_body(
1218 &self,
1219 flags: u8,
1220 body: Bytes,
1221 ) -> Result<Bytes, FrameBodyExtensionsParseError> {
1222 match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
1223 (true, Some(compression)) => decompress(&body, compression).map(Into::into),
1224 (true, None) => Err(FrameBodyExtensionsParseError::NoCompressionNegotiated),
1225 (false, _) => Ok(body),
1226 }
1227 }
1228 }
1229
1230 pub(crate) fn make_compression_infra() -> (
1231 CompressionWriter,
1232 CompressionReader,
1233 CompressionReader,
1234 CompressionReader,
1235 CompressionReader,
1236 ) {
1237 let info = Arc::new(OnceLock::new());
1238 (
1239 CompressionWriter(info.clone()),
1240 CompressionReader(info.clone()),
1241 CompressionReader(info.clone()),
1242 CompressionReader(info.clone()),
1243 CompressionReader(info),
1244 )
1245 }
1246
1247 fn mock_compression_reader(compression: Option<Compression>) -> CompressionReader {
1248 CompressionReader(Arc::new({
1249 let once = OnceLock::new();
1250 once.set(compression).unwrap();
1251 once
1252 }))
1253 }
1254
1255 pub(crate) fn no_compression() -> CompressionReader {
1257 mock_compression_reader(None)
1258 }
1259
1260 #[cfg(test)] pub(crate) fn with_compression(compression: Compression) -> CompressionReader {
1263 mock_compression_reader(Some(compression))
1264 }
1265}
1266pub(crate) use compression::{CompressionReader, CompressionWriter};
1267
/// Common state of one worker task; `Doorkeeper::spawn_workers` creates a
/// separate instance for every task of a connection.
///
/// NOTE: fields drop in declaration order; `exit` moves `finish_guard` out
/// explicitly.
struct ProxyWorker {
    // Receives the proxy-wide shutdown signal.
    terminate_notifier: TerminateNotifier,
    // Released when the worker exits.
    finish_guard: FinishGuard,
    // Subscribed to the doorkeeper's connection-close channel.
    connection_close_notifier: ConnectionCloseNotifier,
    // Where the worker's failure is reported.
    error_propagator: ErrorPropagator,
    driver_addr: SocketAddr,
    // `None` for simulated (dry-mode) nodes.
    real_addr: Option<SocketAddr>,
    proxy_addr: SocketAddr,
    // Shard reported by the real node for this connection, if known.
    shard: Option<TargetShard>,
}
1278
1279impl ProxyWorker {
1280 fn exit(self, duty: &'static str) {
1281 debug!(
1282 "Worker exits: [driver: {}, proxy: {}, node: {}, {}]::{}.",
1283 self.driver_addr,
1284 self.proxy_addr,
1285 DisplayableRealAddrOption(self.real_addr),
1286 DisplayableShard(self.shard),
1287 duty
1288 );
1289 std::mem::drop(self.finish_guard);
1290 }
1291
1292 async fn run_until_interrupted<F, Fut>(mut self, worker_name: &'static str, f: F)
1293 where
1294 F: FnOnce(SocketAddr, SocketAddr, Option<SocketAddr>) -> Fut,
1295 Fut: Future<Output = Result<(), ProxyError>>,
1296 {
1297 let fut = f(self.driver_addr, self.proxy_addr, self.real_addr);
1298
1299 tokio::select! {
1300 result = fut => {
1301 if let Err(err) = result {
1302 let _ = self.error_propagator.send(err);
1304 }
1305 }
1306 _ = self.terminate_notifier.recv() => (),
1307 _ = self.connection_close_notifier.recv() => (),
1308 }
1309 self.exit(worker_name);
1310 }
1311
1312 async fn receiver_from_driver(
1313 self,
1314 mut read_half: impl AsyncRead + Unpin,
1315 request_processor_tx: mpsc::UnboundedSender<RequestFrame>,
1316 compression: CompressionReader,
1317 ) {
1318 let shard = self.shard;
1319 self.run_until_interrupted(
1320 "receiver_from_driver",
1321 |driver_addr, proxy_addr, _real_addr| async move {
1322 loop {
1323 let frame = frame::read_request_frame(&mut read_half, &compression)
1324 .await
1325 .map_err(|err| {
1326 warn!("Request reception from {} error: {}", driver_addr, err);
1327 WorkerError::DriverDisconnected(driver_addr)
1328 })?;
1329
1330 debug!(
1331 "Intercepted Driver ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1332 driver_addr,
1333 proxy_addr,
1334 DisplayableShard(shard),
1335 &frame.opcode
1336 );
1337 if request_processor_tx.send(frame).is_err() {
1338 warn!("request_processor had exited.");
1339 return Result::<(), ProxyError>::Ok(());
1340 }
1341 }
1342 },
1343 )
1344 .await
1345 }
1346
1347 async fn receiver_from_cluster(
1348 self,
1349 mut read_half: impl AsyncRead + Unpin,
1350 response_processor_tx: mpsc::UnboundedSender<ResponseFrame>,
1351 compression: CompressionReader,
1352 ) {
1353 let shard = self.shard;
1354 self.run_until_interrupted(
1355 "receiver_from_cluster",
1356 |driver_addr, _proxy_addr, real_addr| async move {
1357 let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1358 loop {
1359 let frame = frame::read_response_frame(&mut read_half, &compression)
1360 .await
1361 .map_err(|err| {
1362 warn!("Response reception from {} error: {}", real_addr, err);
1363 WorkerError::NodeDisconnected(real_addr)
1364 })?;
1365
1366 debug!(
1367 "Intercepted Cluster ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1368 real_addr,
1369 DisplayableShard(shard),
1370 driver_addr,
1371 &frame.opcode
1372 );
1373
1374 if response_processor_tx.send(frame).is_err() {
1375 warn!("response_processor had exited.");
1376 return Ok::<(), ProxyError>(());
1377 }
1378 }
1379 },
1380 )
1381 .await;
1382 }
1383
1384 async fn sender_to_driver(
1385 self,
1386 mut write_half: impl AsyncWrite + Unpin,
1387 mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1388 mut connection_close_notifier: ConnectionCloseNotifier,
1389 mut terminate_notifier: TerminateNotifier,
1390 compression: CompressionReader,
1391 ) {
1392 let shard = self.shard;
1393 self.run_until_interrupted(
1394 "sender_to_driver",
1395 |driver_addr, proxy_addr, _real_addr| async move {
1396 loop {
1397 let response = match responses_rx.recv().await {
1398 Some(response) => response,
1399 None => {
1400 if terminate_notifier.try_recv().is_err()
1401 && connection_close_notifier.try_recv().is_err()
1402 {
1403 warn!("Response processor had exited");
1404 }
1405 return Ok(());
1406 }
1407 };
1408
1409 debug!(
1410 "Sending Proxy ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1411 proxy_addr,
1412 DisplayableShard(shard),
1413 driver_addr,
1414 &response.opcode
1415 );
1416 if response.write(&mut write_half, &compression).await.is_err() {
1417 if terminate_notifier.try_recv().is_err()
1418 && connection_close_notifier.try_recv().is_err()
1419 {
1420 warn!("Driver dropped connection");
1421 return Err(WorkerError::DriverDisconnected(driver_addr).into());
1422 }
1423 return Ok(());
1424 }
1425 }
1426 },
1427 )
1428 .await;
1429 }
1430
1431 async fn sender_to_cluster(
1432 self,
1433 mut write_half: impl AsyncWrite + Unpin,
1434 mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1435 mut connection_close_notifier: ConnectionCloseNotifier,
1436 mut terminate_notifier: TerminateNotifier,
1437 compression: CompressionReader,
1438 ) {
1439 let shard = self.shard;
1440 self.run_until_interrupted(
1441 "sender_to_driver",
1442 |_driver_addr, proxy_addr, real_addr| async move {
1443 let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1444 loop {
1445 let request = match requests_rx.recv().await {
1446 Some(request) => request,
1447 None => {
1448 if terminate_notifier.try_recv().is_err()
1449 && connection_close_notifier.try_recv().is_err()
1450 {
1451 warn!("Request processor had exited");
1452 }
1453 return Ok(());
1454 }
1455 };
1456
1457 debug!(
1458 "Sending Proxy ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1459 proxy_addr,
1460 real_addr,
1461 DisplayableShard(shard),
1462 &request.opcode
1463 );
1464
1465 if request.write(&mut write_half, &compression).await.is_err() {
1466 if terminate_notifier.try_recv().is_err()
1467 && connection_close_notifier.try_recv().is_err()
1468 {
1469 warn!("Node {} dropped connection", real_addr);
1470 return Err(WorkerError::NodeDisconnected(real_addr).into());
1471 }
1472 return Ok(());
1473 }
1474 }
1475 },
1476 )
1477 .await;
1478 }
1479
1480 #[expect(clippy::too_many_arguments)]
1481 async fn request_processor(
1482 self,
1483 mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1484 driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1485 cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1486 connection_no: usize,
1487 request_rules: Arc<Mutex<Vec<RequestRule>>>,
1488 connection_close_signaler: ConnectionCloseSignaler,
1489 event_registered_flag: Arc<AtomicBool>,
1490 compression: CompressionWriter,
1491 cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
1492 ) {
1493 let shard = self.shard;
1494 self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1495 'mainloop: loop {
1496 match requests_rx.recv().await {
1497 Some(request) => {
1498 if request.opcode == RequestOpcode::Register {
1499 event_registered_flag.store(true, Ordering::Relaxed);
1500 cc_event_sender.lock().unwrap().insert(connection_no, driver_tx.clone());
1504 info!(
1505 "REGISTER seen on connection {} ({} → {} ({})); registered cc_event_sender",
1506 connection_no,
1507 driver_addr,
1508 DisplayableRealAddrOption(real_addr),
1509 DisplayableShard(shard),
1510 );
1511 } else if request.opcode == RequestOpcode::Startup {
1512 match compression.set_from_startup(&request.body) {
1513 Err(err) => error!("Failed to deserialize STARTUP frame: {}", err),
1514 Ok(read_compression) => info!(
1515 "Intercepted STARTUP frame ({} -> {} ({})), so set compression accordingly to {:?}.",
1516 driver_addr,
1517 DisplayableRealAddrOption(real_addr),
1518 DisplayableShard(shard),
1519 read_compression
1520 )
1521 };
1522 }
1523
1524 let ctx = EvaluationContext {
1525 connection_seq_no: connection_no,
1526 opcode: FrameOpcode::Request(request.opcode),
1527 frame_body: request.body.clone(),
1528 connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1529 };
1530 let mut guard = request_rules.lock().unwrap();
1531 '_ruleloop: for (i, request_rule) in guard.iter_mut().enumerate() {
1532 if request_rule.0.eval(&ctx) {
1533 debug!("Applying rule no={} to request ({} -> {} ({})).", i, driver_addr, DisplayableRealAddrOption(real_addr), DisplayableShard(shard));
1534 debug!("-> Applied rule: {:?}", request_rule);
1535 debug!("-> To request: {:?}", ctx.opcode);
1536 trace!("{:?}", request);
1537
1538 if let Some(ref tx) = request_rule.1.feedback_channel {
1539 tx.send((request.clone(), shard)).unwrap_or_else(|err|
1540 warn!("Could not send received request as feedback: {}", err)
1541 );
1542 }
1543
1544 let request_rule = request_rule.clone();
1545 let to_addressee_action = request_rule.1.to_addressee;
1546 let to_sender_action = request_rule.1.to_sender;
1547 let drop_connection_action = request_rule.1.drop_connection;
1548
1549 let cluster_tx_clone = cluster_tx.clone();
1550 let request_clone = request.clone();
1551 let pass_action = async move {
1552 if let Some(ref pass_action) = to_addressee_action {
1553 if let Some(time) = pass_action.delay {
1554 tokio::time::sleep(time).await;
1555 }
1556 let passed_frame = match pass_action.msg_processor {
1557 Some(ref processor) => processor(request_clone),
1558 None => request_clone,
1559 };
1560 let _ = cluster_tx_clone.send(passed_frame);
1561 };
1562 };
1563
1564 let driver_tx_clone = driver_tx.clone();
1565 let request_clone = request.clone();
1566 let forge_action = async move {
1567 if let Some(ref forge_action) = to_sender_action {
1568 if let Some(time) = forge_action.delay {
1569 tokio::time::sleep(time).await;
1570 }
1571 let forged_frame = {
1572 let processor = forge_action.msg_processor.as_ref()
1573 .expect("Frame processor is required to forge a frame.");
1574 processor(request_clone)
1575 };
1576 let _ = driver_tx_clone.send(forged_frame);
1577 };
1578 };
1579
1580 let connection_close_signaler_clone =
1581 connection_close_signaler.clone();
1582 let drop_action = async move {
1583 if let Some(ref delay) = drop_connection_action {
1584 if let Some(time) = delay {
1585 tokio::time::sleep(*time).await;
1586 }
1587 info!(
1589 "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1590 driver_addr,
1591 DisplayableRealAddrOption(real_addr),
1592 DisplayableShard(shard),
1593 );
1594 let _ = connection_close_signaler_clone.send(());
1595 }
1596 };
1597
1598 tokio::task::spawn(async {
1599 futures::join!(pass_action, forge_action, drop_action);
1600 });
1601
1602 continue 'mainloop; }
1604 }
1605 let _ = cluster_tx.send(request); }
1607 None => {
1608 if event_registered_flag.load(Ordering::Relaxed) {
1612 cc_event_sender.lock().unwrap().remove(&connection_no);
1613 info!(
1614 "Control connection {} ({} → {} ({})) closed; removed cc_event_sender",
1615 connection_no,
1616 driver_addr,
1617 DisplayableRealAddrOption(real_addr),
1618 DisplayableShard(shard),
1619 );
1620 }
1621 return Ok(());
1622 }
1623 }
1624 }
1625 })
1626 .await;
1627 }
1628
1629 #[expect(clippy::too_many_arguments)]
1630 async fn response_processor(
1631 self,
1632 mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1633 driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1634 cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1635 connection_no: usize,
1636 response_rules: Arc<Mutex<Vec<ResponseRule>>>,
1637 connection_close_signaler: ConnectionCloseSignaler,
1638 event_registered_flag: Arc<AtomicBool>,
1639 ) {
1640 let shard = self.shard;
1641 self.run_until_interrupted("response_processor", |driver_addr, _, real_addr| async move {
1642 'mainloop: loop {
1643 match responses_rx.recv().await {
1644 Some(response) => {
1645 let ctx = EvaluationContext {
1646 connection_seq_no: connection_no,
1647 opcode: FrameOpcode::Response(response.opcode),
1648 frame_body: response.body.clone(),
1649 connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1650 };
1651 let mut guard = response_rules.lock().unwrap();
1652 '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() {
1653 if response_rule.0.eval(&ctx) {
1654 debug!("Applying rule no={} to response ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard));
1655 debug!("-> Applied rule: {:?}", response_rule);
1656 debug!("-> To response: {:?}", ctx.opcode);
1657 trace!("{:?}", response);
1658
1659 if let Some(ref tx) = response_rule.1.feedback_channel {
1660 tx.send((response.clone(), shard)).unwrap_or_else(|err| warn!(
1661 "Could not send received response as feedback: {}", err
1662 ));
1663 }
1664
1665 let response_rule = response_rule.clone();
1666 let to_addressee_action = response_rule.1.to_addressee;
1667 let to_sender_action = response_rule.1.to_sender;
1668 let drop_connection_action = response_rule.1.drop_connection;
1669
1670 let response_clone = response.clone();
1671 let driver_tx_clone = driver_tx.clone();
1672 let pass_action = async move {
1673 if let Some(ref pass_action) = to_addressee_action {
1674 if let Some(time) = pass_action.delay {
1675 tokio::time::sleep(time).await;
1676 }
1677 let passed_frame = match pass_action.msg_processor {
1678 Some(ref processor) => processor(response_clone),
1679 None => response_clone,
1680 };
1681 let _ = driver_tx_clone.send(passed_frame);
1682 };
1683 };
1684
1685 let response_clone = response.clone();
1686 let cluster_tx_clone = cluster_tx.clone();
1687 let forge_action = async move {
1688 if let Some(ref forge_action) = to_sender_action {
1689 if let Some(time) = forge_action.delay {
1690 tokio::time::sleep(time).await;
1691 }
1692 let forged_frame = {
1693 let processor = forge_action.msg_processor.as_ref()
1694 .expect("Frame processor is required to forge a frame.");
1695 processor(response_clone)
1696 };
1697 let _ = cluster_tx_clone.send(forged_frame);
1698 };
1699 };
1700
1701 let connection_close_signaler_clone =
1702 connection_close_signaler.clone();
1703 let drop_action = async move {
1704 if let Some(ref delay) = drop_connection_action {
1705 if let Some(time) = delay {
1706 tokio::time::sleep(*time).await;
1707 }
1708 info!(
1710 "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1711 driver_addr,
1712 real_addr.expect("BUG: response rules are unavailable for dry-mode proxy!"),
1713 DisplayableShard(shard)
1714 );
1715 let _ = connection_close_signaler_clone.send(());
1716 }
1717 };
1718
1719 tokio::task::spawn(async {
1720 futures::join!(pass_action, forge_action, drop_action);
1721 });
1722
1723 continue 'mainloop;
1724 }
1725 }
1726 let _ = driver_tx.send(response); }
1728 None => return Ok(()),
1729 }
1730 }
1731 })
1732 .await
1733 }
1734}
1735
1736pub fn get_exclusive_local_address() -> IpAddr {
1739 match std::env::var("NEXTEST_TEST_GLOBAL_SLOT") {
1740 Ok(slot) => {
1741 let slot: u16 = slot
1742 .parse()
1743 .unwrap_or_else(|e| panic!("Invalid slot {e:?}"));
1744 get_exclusive_local_address_nextest(slot)
1745 }
1746 Err(VarError::NotPresent) => get_exclusive_local_address_libtest(),
1747 Err(VarError::NotUnicode(e)) => panic!("Invalid slot {e:?}"),
1748 }
1749}
1750
/// Hands out successive addresses of `127.0.0.0/8` from a process-wide counter that
/// starts at 4242 (i.e. `127.0.16.146`).
///
/// # Panics
/// Panics once the counter exceeds the 24 bits available below the `127` octet.
fn get_exclusive_local_address_libtest() -> IpAddr {
    static ADDRESS_LOWER_THREE_OCTETS: AtomicU32 = AtomicU32::new(4242);
    const LAST_USABLE: u32 = u32::MAX >> 8;

    let counter = ADDRESS_LOWER_THREE_OCTETS.fetch_add(1, Ordering::Relaxed);
    if counter > LAST_USABLE {
        panic!("Loopback address pool for tests depleted");
    }
    // Big-endian bytes of a value <= 0x00FF_FFFF: the leading byte is always zero.
    let [_, high, middle, low] = counter.to_be_bytes();
    IpAddr::V4(Ipv4Addr::new(127, high, middle, low))
}
1766
/// Hands out a loopback IPv4 address for a test running under cargo-nextest.
///
/// The slot selects the range `127.H.L.0/24`, where `H.L` is the big-endian encoding of
/// `slot + FREE_RANGES`. Within that range the last octet counts down from 255 to 1 on
/// successive calls in this process.
///
/// # Panics
/// Panics once all 255 addresses of the range have been handed out, and on every call
/// after that. Also panics if `slot + FREE_RANGES` overflows `u16`.
fn get_exclusive_local_address_nextest(slot: u16) -> IpAddr {
    static ADDRESS_LOWER_OCTET: AtomicU8 = AtomicU8::new(255);
    const FREE_RANGES: u16 = 16;

    // Take the current last octet and decrement the counter, never below 0. A plain
    // `fetch_sub` would wrap 0 around to 255, so calls made after the first "depleted"
    // panic would silently hand out already-used addresses again.
    let next_address_lower = ADDRESS_LOWER_OCTET
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |lower| {
            lower.checked_sub(1)
        })
        .unwrap_or_else(|_| panic!("Loopback address pool for this test depleted"));

    let [range_high, range_low] = slot
        .checked_add(FREE_RANGES)
        .unwrap_or_else(|| panic!("Loopback address pool for tests depleted"))
        .to_be_bytes();

    IpAddr::V4(Ipv4Addr::new(
        127,
        range_high,
        range_low,
        next_address_lower,
    ))
}
1789
1790#[cfg(test)]
1791mod tests {
1792 use super::compression::no_compression;
1793 use super::*;
1794 use crate::errors::ReadFrameError;
1795 use crate::frame::{FrameType, read_frame, read_request_frame, read_response_frame};
1796 use crate::proxy::compression::with_compression;
1797 use crate::{
1798 Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction, setup_tracing,
1799 };
1800 use assert_matches::assert_matches;
1801 use bytes::{BufMut, BytesMut};
1802 use futures::future::{join, join3};
1803 use rand::RngCore;
1804 use scylla_cql::frame::request::options;
1805 use scylla_cql::frame::request::{SerializableRequest as _, Startup};
1806 use scylla_cql::frame::types::write_string_multimap;
1807 use scylla_cql::frame::{Compression, flag};
1808 use std::collections::HashMap;
1809 use std::mem;
1810 use std::str::FromStr;
1811 use std::time::Duration;
1812 use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
1813 use tokio::sync::oneshot;
1814
1815 fn random_body() -> Bytes {
1816 let body_len = (rand::random::<u32>() % 1000) as usize;
1817 let mut body = BytesMut::zeroed(body_len);
1818 rand::rng().fill_bytes(body.as_mut());
1819 body.freeze()
1820 }
1821
1822 async fn respond_with_supported(
1823 conn: &mut TcpStream,
1824 supported_options: &HashMap<String, Vec<String>>,
1825 compression: &CompressionReader,
1826 ) {
1827 let RequestFrame {
1828 params: recvd_params,
1829 opcode: recvd_opcode,
1830 body: recvd_body,
1831 } = read_request_frame(conn, compression).await.unwrap();
1832 assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS);
1833 assert_eq!(recvd_opcode, RequestOpcode::Options);
1834 assert_eq!(recvd_body, Bytes::new()); let mut body = BytesMut::new();
1837 write_string_multimap(supported_options, &mut body).unwrap();
1838
1839 let body = body.freeze();
1840
1841 write_frame(
1842 HARDCODED_OPTIONS_PARAMS.for_response(),
1843 FrameOpcode::Response(ResponseOpcode::Supported),
1844 &body,
1845 conn,
1846 &no_compression(),
1847 )
1848 .await
1849 .unwrap();
1850 }
1851
1852 fn supported_shards_count(shards_count: u16) -> HashMap<String, Vec<String>> {
1853 let mut sharded_info = HashMap::new();
1854 sharded_info.insert(
1855 String::from("SCYLLA_NR_SHARDS"),
1856 vec![shards_count.to_string()],
1857 );
1858 sharded_info
1859 }
1860
1861 fn supported_shard_number(shard_num: TargetShard) -> HashMap<String, Vec<String>> {
1862 let mut sharded_info = HashMap::new();
1863 sharded_info.insert(String::from("SCYLLA_SHARD"), vec![shard_num.to_string()]);
1864 sharded_info
1865 }
1866
1867 async fn respond_with_shards_count(
1868 conn: &mut TcpStream,
1869 shards_count: u16,
1870 compression: &CompressionReader,
1871 ) {
1872 respond_with_supported(conn, &supported_shards_count(shards_count), compression).await;
1873 }
1874
1875 async fn respond_with_shard_num(
1876 conn: &mut TcpStream,
1877 shard_num: TargetShard,
1878 compression: &CompressionReader,
1879 ) {
1880 respond_with_supported(conn, &supported_shard_number(shard_num), compression).await;
1881 }
1882
1883 fn next_local_address_with_port(port: u16) -> SocketAddr {
1884 SocketAddr::new(get_exclusive_local_address(), port)
1885 }
1886
    /// Checks that a proxy with no rules forwards a driver's frame to the node unchanged.
    ///
    /// A single-node proxy with the given `shard_awareness` is started, and an OPTIONS frame
    /// with a random body is sent through it. The mock node asserts that params, opcode and
    /// body arrive intact.
    async fn identity_proxy_does_not_mutate_frames(shard_awareness: ShardAwareness) {
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            shard_awareness,
            None,
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        // Driver side: connect to the proxy and send a single frame.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params, opcode, &body, &mut conn, &no_compression())
                .await
                .unwrap();
            conn
        };

        // Node side.
        let mock_node_action = async {
            // With `QueryNode`, one extra connection is accepted first and its OPTIONS
            // request is answered with a shard count of 1 (presumably the proxy's
            // shard-count query).
            if let ShardAwareness::QueryNode = shard_awareness {
                respond_with_shards_count(
                    &mut mock_node_listener.accept().await.unwrap().0,
                    1,
                    &no_compression(),
                )
                .await;
            }
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            // A shard-aware proxy sends OPTIONS on the forwarded connection before the
            // driver's frame; it is answered with shard number 1.
            if shard_awareness.is_aware() {
                respond_with_shard_num(&mut conn, 1, &no_compression()).await;
            }
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
            conn
        };

        // Named bindings (not `_`) keep both connections open until the proxy is finished.
        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
        running_proxy.finish().await.unwrap();
    }
1949
1950 #[tokio::test]
1951 async fn identity_shard_unaware_proxy_does_not_mutate_frames() {
1952 setup_tracing();
1953 identity_proxy_does_not_mutate_frames(ShardAwareness::Unaware).await
1954 }
1955
1956 #[tokio::test]
1957 async fn identity_shard_aware_proxy_does_not_mutate_frames() {
1958 setup_tracing();
1959 identity_proxy_does_not_mutate_frames(ShardAwareness::QueryNode).await
1960 }
1961
    /// Checks that, with a fixed shard count, the proxy keeps the driver's shard when
    /// connecting to the node.
    ///
    /// For shard counts 1 through 5, the source port the proxy uses towards the node must
    /// be congruent to the driver's source port modulo the shard count.
    #[tokio::test]
    async fn shard_aware_proxy_is_transparent_for_connection_to_shards() {
        setup_tracing();
        async fn test_for_shards_num(shards_num: u16) {
            let node1_real_addr = next_local_address_with_port(9876);
            let node1_proxy_addr = next_local_address_with_port(9876);
            let proxy = Proxy::new([Node::new(
                node1_real_addr,
                node1_proxy_addr,
                ShardAwareness::FixedNum(shards_num),
                None,
                None,
            )]);
            let running_proxy = proxy.run().await.unwrap();

            let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

            // Carries the driver's local address from the driver side to the node side.
            let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

            // Driver side: connect from an OS-chosen port (bound to port 0).
            let send_frame_to_shard = async {
                let socket = TcpSocket::new_v4().unwrap();
                socket
                    .bind(SocketAddr::from_str("0.0.0.0:0").unwrap())
                    .unwrap();
                let conn = socket.connect(node1_proxy_addr).await.unwrap();
                driver_addr_tx.send(conn.local_addr().unwrap()).unwrap();
                conn
            };

            // Node side: compare the proxy's source port with the driver's.
            let mock_node_action = async {
                let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                let driver_addr = driver_addr_rx.await.unwrap();
                assert_eq!(
                    driver_addr.port() % shards_num,
                    remote_addr.port() % shards_num
                );
                conn
            };

            let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
            running_proxy.finish().await.unwrap();
        }

        // Despite its name, `shard_num` is the shard count passed to the helper.
        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
2010
    /// Checks that a `QueryNode` proxy learns the shard count from the node and keeps the
    /// driver's shard.
    ///
    /// For every shard count 1 through 5 and every shard, the mock node first answers an
    /// OPTIONS query with the shard count. It then asserts that the proxy's source port
    /// towards it is congruent to the driver's port modulo the shard count.
    #[tokio::test]
    async fn shard_aware_proxy_queries_shards_number() {
        setup_tracing();
        async fn test_for_shards_num(shards_num: u16) {
            for shard_num in 0..shards_num {
                let node1_real_addr = next_local_address_with_port(9876);
                let node1_proxy_addr = next_local_address_with_port(9876);
                let proxy = Proxy::new([Node::new(
                    node1_real_addr,
                    node1_proxy_addr,
                    ShardAwareness::QueryNode,
                    None,
                    None,
                )]);
                let running_proxy = proxy.run().await.unwrap();

                let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

                let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

                // The driver port is chosen so that `port % shards_num == shard_num`, making
                // each iteration target a different shard.
                let mock_driver_addr = next_local_address_with_port(shards_num * 1234 + shard_num);
                let send_frame_to_shard = async {
                    let socket = TcpSocket::new_v4().unwrap();
                    socket
                        .bind(mock_driver_addr)
                        .unwrap_or_else(|_| panic!("driver_addr failed: {mock_driver_addr}"));
                    driver_addr_tx.send(socket.local_addr().unwrap()).unwrap();
                    socket.connect(node1_proxy_addr).await.unwrap()
                };

                // Node side: answer the shard-count query on the first connection, then
                // check the port of the forwarded one.
                let mock_node_action = async {
                    respond_with_shards_count(
                        &mut mock_node_listener.accept().await.unwrap().0,
                        shards_num,
                        &no_compression(),
                    )
                    .await;
                    let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                    let driver_addr = driver_addr_rx.await.unwrap();
                    assert_eq!(
                        driver_addr.port() % shards_num,
                        remote_addr.port() % shards_num
                    );
                    conn
                };

                let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
                running_proxy.finish().await.unwrap();
            }
        }

        // Despite its name, `shard_num` is the shard count passed to the helper.
        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
2066
    /// Checks that request rules can forge responses to the driver and let selected frames
    /// through.
    ///
    /// The rules, in priority order:
    /// - REGISTER gets a forged EVENT whose body is `test_msg`;
    /// - a body containing `this_shall_pass` matches a no-op reaction;
    /// - every other request gets a forged READY.
    ///
    /// The driver sends STARTUP, REGISTER and EXECUTE (the last carrying `this_shall_pass`).
    /// It expects READY and then EVENT back, while the node must receive the EXECUTE frame.
    /// After `finish`, both connections must read EOF.
    #[tokio::test]
    async fn forger_proxy_forges_response() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                // REGISTER: answer the driver directly with an EVENT carrying `test_msg`.
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                // Frames containing `this_shall_pass` match a no-op reaction; the node side
                // below expects such a frame to arrive unchanged.
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                // Everything else: answer the driver with an empty READY frame.
                RequestRule(
                    Condition::True,
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params1 = FrameParams {
            flags: 2,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        // The only body deliberately built to contain `this_shall_pass`.
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        // Driver side: send all three frames, then read the two forged responses.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
                .await
                .unwrap();
            write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
                .await
                .unwrap();

            // READY forged for the STARTUP frame.
            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params1.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Ready);
            assert_eq!(recvd_body, Bytes::new());

            // EVENT forged for the REGISTER frame.
            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params2.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Event);
            assert_eq!(recvd_body, Bytes::from_static(test_msg));

            conn
        };

        // Node side: the first (and only expected) frame is the passed-through EXECUTE.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params3);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3);
            assert_eq!(recvd_body, body3);

            conn
        };

        let (mut node_conn, mut driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        running_proxy.finish().await.unwrap();

        // `Ok(0)` is EOF: after `finish`, the proxy's ends of both connections are closed.
        assert_matches!(driver_conn.read(&mut [0u8; 1]).await, Ok(0));
        assert_matches!(node_conn.read(&mut [0u8; 1]).await, Ok(0));
    }
2203
    /// Checks that request rules can be replaced and turned off while the proxy runs.
    ///
    /// A frame passes with no rules, is swallowed after a drop-all rule is installed with
    /// `change_request_rules`, and passes again after `turn_off_rules`.
    #[tokio::test]
    async fn ad_hoc_rules_changing() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            None,
            None,
        )]);
        let mut running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        // One driver -> proxy -> node connection, reused by every step below.
        let (mut driver, mut node) = {
            let results = join(
                TcpStream::connect(node1_proxy_addr),
                mock_node_listener.accept(),
            )
            .await;
            (results.0.unwrap(), results.1.unwrap().0)
        };

        // Writes one frame from `driver` while concurrently reading the next request frame
        // that arrives at `node`; panics if the write fails.
        async fn request(
            driver: &mut TcpStream,
            node: &mut TcpStream,
            params: FrameParams,
            opcode: FrameOpcode,
            body: &Bytes,
        ) -> Result<RequestFrame, ReadFrameError> {
            let (send_res, recv_res) = join(
                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
                read_request_frame(node, &no_compression()),
            )
            .await;
            send_res.unwrap();
            recv_res
        }
        // No rules yet: the frame reaches the node unchanged.
        {
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }
        // Install a rule that drops every request.
        running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule(
            Condition::True,
            RequestReaction::drop_frame(),
        )]));

        {
            tokio::select! {
                // Nothing may reach the node within 20 ms while the drop-all rule is active.
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(20)) => (),
            };
        }

        running_proxy.turn_off_rules();

        // With rules turned off the frame passes again.
        {
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }

        running_proxy.finish().await.unwrap();
    }
2297
    /// Checks `Condition::TrueForLimitedTimes` through rule ordering.
    ///
    /// The first `FAILING_TRIES` frames must be dropped and the next `PASSING_TRIES` must
    /// reach the node. Later frames must be dropped again; three are tried.
    #[tokio::test]
    async fn limited_times_condition_expires() {
        setup_tracing();
        const FAILING_TRIES: usize = 4;
        const PASSING_TRIES: usize = 5;

        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                RequestRule(
                    // Presumably `TrueForLimitedTimes(n)` holds for its first n evaluations,
                    // so this rule starts dropping after FAILING_TRIES + PASSING_TRIES
                    // frames — TODO confirm against the condition's implementation.
                    Condition::not(Condition::TrueForLimitedTimes(
                        FAILING_TRIES + PASSING_TRIES,
                    )),
                    RequestReaction::drop_frame(),
                ),
                RequestRule(
                    // After the first FAILING_TRIES frames, match with a no-op reaction; the
                    // test expects these frames to reach the node.
                    Condition::not(Condition::TrueForLimitedTimes(FAILING_TRIES)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    // Fallback: drop whatever the rules above did not match.
                    Condition::True,
                    RequestReaction::drop_frame(),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);
        let body = random_body();

        // One driver -> proxy -> node connection, reused by every try below.
        let (mut driver, mut node) = {
            let results = join(
                TcpStream::connect(node1_proxy_addr),
                mock_node_listener.accept(),
            )
            .await;
            (results.0.unwrap(), results.1.unwrap().0)
        };

        // Writes one frame from `driver` while concurrently reading the next request frame
        // that arrives at `node`; panics if the write fails.
        async fn request(
            driver: &mut TcpStream,
            node: &mut TcpStream,
            params: FrameParams,
            opcode: FrameOpcode,
            body: &Bytes,
        ) -> Result<RequestFrame, ReadFrameError> {
            let (send_res, recv_res) = join(
                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
                read_request_frame(node, &no_compression()),
            )
            .await;
            send_res.unwrap();
            recv_res
        }

        // Phase 1: frames are dropped (nothing arrives within 10 ms).
        for _ in 0..FAILING_TRIES {
            tokio::select! {
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
            };
        }

        // Phase 2: frames pass through unchanged.
        for _ in 0..PASSING_TRIES {
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }

        // Phase 3: frames are dropped again.
        for _ in 0..3 {
            tokio::select! {
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
            };
        }

        running_proxy.finish().await.unwrap();
    }
2398
2399 #[tokio::test]
2400 async fn proxy_reports_requests_and_responses_as_feedback() {
2401 setup_tracing();
2402 let node1_real_addr = next_local_address_with_port(9876);
2403 let node1_proxy_addr = next_local_address_with_port(9876);
2404
2405 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2406 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2407 let proxy = Proxy::new([Node::new(
2408 node1_real_addr,
2409 node1_proxy_addr,
2410 ShardAwareness::Unaware,
2411 Some(vec![RequestRule(
2412 Condition::True,
2413 RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2414 )]),
2415 Some(vec![ResponseRule(
2416 Condition::True,
2417 ResponseReaction::drop_frame().with_feedback_when_performed(response_feedback_tx),
2418 )]),
2419 )]);
2420 let running_proxy = proxy.run().await.unwrap();
2421
2422 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2423
2424 let params = FrameParams {
2425 flags: 0,
2426 version: 0x04,
2427 stream: 0,
2428 };
2429 let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2430 let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2431
2432 let body = random_body();
2433
2434 let send_frame_to_shard = async {
2435 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2436 write_frame(params, request_opcode, &body, &mut conn, &no_compression())
2437 .await
2438 .unwrap();
2439 conn
2440 };
2441
2442 let mock_node_action = async {
2443 let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2444 write_frame(
2445 params.for_response(),
2446 response_opcode,
2447 &body,
2448 &mut conn,
2449 &no_compression(),
2450 )
2451 .await
2452 .unwrap();
2453 conn
2454 };
2455
2456 let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
2458
2459 let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2460 assert_eq!(feedback_request.params, params);
2461 assert_eq!(
2462 FrameOpcode::Request(feedback_request.opcode),
2463 request_opcode
2464 );
2465 assert_eq!(feedback_request.body, body);
2466 let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2467 assert_eq!(feedback_response.params, params.for_response());
2468 assert_eq!(
2469 FrameOpcode::Response(feedback_response.opcode),
2470 response_opcode
2471 );
2472 assert_eq!(feedback_response.body, body);
2473
2474 running_proxy.finish().await.unwrap();
2475 }
2476
2477 #[tokio::test]
2478 async fn sanity_check_reports_errors() {
2479 setup_tracing();
2480 let node1_real_addr = next_local_address_with_port(9876);
2481 let node1_proxy_addr = next_local_address_with_port(9876);
2482 let proxy = Proxy::new([Node::new(
2483 node1_real_addr,
2484 node1_proxy_addr,
2485 ShardAwareness::Unaware,
2486 None,
2487 None,
2488 )]);
2489 let mut running_proxy = proxy.run().await.unwrap();
2490
2491 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2492
2493 let send_frame_to_shard = async {
2494 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2495
2496 conn.write_all(b"uselessJunk").await.unwrap();
2497 conn
2498 };
2499
2500 let mock_node_action = async {
2501 let (conn, _) = mock_node_listener.accept().await.unwrap();
2502 conn
2503 };
2504
2505 let (node_conn, driver_conn) = join(mock_node_action, send_frame_to_shard).await;
2506
2507 running_proxy.sanity_check().unwrap();
2508
2509 mem::drop(driver_conn);
2510 assert_matches!(
2511 running_proxy.wait_for_error().await,
2512 Some(ProxyError::Worker(WorkerError::DriverDisconnected(_)))
2513 );
2514 running_proxy.sanity_check().unwrap();
2515
2516 mem::drop(node_conn);
2517 assert_matches!(
2518 running_proxy.wait_for_error().await,
2519 Some(ProxyError::Worker(WorkerError::NodeDisconnected(_)))
2520 );
2521 running_proxy.sanity_check().unwrap();
2522
2523 let _ = running_proxy.finish().await;
2525 }
2526
2527 #[tokio::test]
2528 async fn proxy_processes_requests_concurrently() {
2529 setup_tracing();
2530 let node1_real_addr = next_local_address_with_port(9876);
2531 let node1_proxy_addr = next_local_address_with_port(9876);
2532
2533 let delay = Duration::from_millis(60);
2534
2535 let proxy = Proxy::new([Node::new(
2536 node1_real_addr,
2537 node1_proxy_addr,
2538 ShardAwareness::Unaware,
2539 Some(vec![RequestRule(
2540 Condition::TrueForLimitedTimes(1),
2541 RequestReaction::delay(delay),
2542 )]),
2543 None,
2544 )]);
2545 let running_proxy = proxy.run().await.unwrap();
2546
2547 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2548
2549 let params1 = FrameParams {
2550 flags: 0,
2551 version: 0x04,
2552 stream: 0,
2553 };
2554 let opcode1 = FrameOpcode::Request(RequestOpcode::Options);
2555
2556 let body1 = random_body();
2557
2558 let params2 = FrameParams {
2559 flags: 0,
2560 version: 0x04,
2561 stream: 0,
2562 };
2563 let opcode2 = FrameOpcode::Request(RequestOpcode::Register);
2564
2565 let body2 = random_body();
2566
2567 let send_frame_to_shard = async {
2568 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2569
2570 write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
2571 .await
2572 .unwrap();
2573 write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
2574 .await
2575 .unwrap();
2576 conn
2577 };
2578
2579 let mock_node_action = async {
2580 let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2581 let RequestFrame {
2582 params: recvd_params,
2583 opcode: recvd_opcode,
2584 body: recvd_body,
2585 } = read_request_frame(&mut conn, &no_compression())
2586 .await
2587 .unwrap();
2588 assert_eq!(recvd_params, params2);
2589 assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2);
2590 assert_eq!(recvd_body, body2);
2591 conn
2592 };
2593
2594 let (_node_conn, _driver_conn) =
2596 tokio::time::timeout(delay, join(mock_node_action, send_frame_to_shard))
2597 .await
2598 .expect("Request processing was not concurrent");
2599 running_proxy.finish().await.unwrap();
2600 }
2601
2602 #[tokio::test]
2603 async fn dry_mode_proxy_drops_incoming_frames() {
2604 setup_tracing();
2605 let node1_proxy_addr = next_local_address_with_port(9876);
2606 let proxy = Proxy::new([Node::new_dry_mode(node1_proxy_addr, None)]);
2607 let running_proxy = proxy.run().await.unwrap();
2608
2609 let params = FrameParams {
2610 flags: 0,
2611 version: 0x04,
2612 stream: 0,
2613 };
2614 let opcode = FrameOpcode::Request(RequestOpcode::Options);
2615
2616 let body = random_body();
2617
2618 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2619
2620 write_frame(params, opcode, &body, &mut conn, &no_compression())
2621 .await
2622 .unwrap();
2623 tokio::time::sleep(Duration::from_millis(3)).await;
2625 running_proxy.finish().await.unwrap();
2626 }
2627
2628 #[tokio::test]
2629 async fn dry_mode_forger_proxy_forges_response() {
2630 setup_tracing();
2631 let node1_proxy_addr = next_local_address_with_port(9876);
2632
2633 let this_shall_pass = b"This.Shall.Pass.";
2634 let test_msg = b"Test";
2635
2636 let proxy = Proxy::new([Node::new_dry_mode(
2637 node1_proxy_addr,
2638 Some(vec![
2639 RequestRule(
2640 Condition::RequestOpcode(RequestOpcode::Register),
2641 RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
2642 ResponseFrame {
2643 params: params.for_response(),
2644 opcode: ResponseOpcode::Event,
2645 body: Bytes::from_static(test_msg),
2646 }
2647 })),
2648 ),
2649 RequestRule(
2650 Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
2651 RequestReaction::noop(),
2652 ),
2653 RequestRule(
2654 Condition::True, RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
2656 ResponseFrame {
2657 params: params.for_response(),
2658 opcode: ResponseOpcode::Ready,
2659 body: Bytes::new(),
2660 }
2661 })),
2662 ),
2663 ]),
2664 )]);
2665 let running_proxy = proxy.run().await.unwrap();
2666
2667 let params1 = FrameParams {
2668 flags: 2,
2669 version: 0x42,
2670 stream: 42,
2671 };
2672 let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);
2673
2674 let params2 = FrameParams {
2675 flags: 4,
2676 version: 0x04,
2677 stream: 17,
2678 };
2679 let opcode2 = FrameOpcode::Request(RequestOpcode::Register);
2680
2681 let params3 = FrameParams {
2682 flags: 8,
2683 version: 0x04,
2684 stream: 11,
2685 };
2686 let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);
2687
2688 let body1 = random_body();
2689 let body2 = random_body();
2690 let body3 = {
2691 let mut body = BytesMut::new();
2692 body.put(&b"uSeLeSs JuNk"[..]);
2693 body.put(&this_shall_pass[..]);
2694 body.freeze()
2695 };
2696
2697 let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2698
2699 write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
2700 .await
2701 .unwrap();
2702 write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
2703 .await
2704 .unwrap();
2705 write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
2706 .await
2707 .unwrap();
2708
2709 let ResponseFrame {
2710 params: recvd_params,
2711 opcode: recvd_opcode,
2712 body: recvd_body,
2713 } = read_response_frame(&mut conn, &no_compression())
2714 .await
2715 .unwrap();
2716 assert_eq!(recvd_params, params1.for_response());
2717 assert_eq!(recvd_opcode, ResponseOpcode::Ready);
2718 assert_eq!(recvd_body, Bytes::new());
2719
2720 let ResponseFrame {
2721 params: recvd_params,
2722 opcode: recvd_opcode,
2723 body: recvd_body,
2724 } = read_response_frame(&mut conn, &no_compression())
2725 .await
2726 .unwrap();
2727 assert_eq!(recvd_params, params2.for_response());
2728 assert_eq!(recvd_opcode, ResponseOpcode::Event);
2729 assert_eq!(recvd_body, Bytes::from_static(test_msg));
2730
2731 running_proxy.finish().await.unwrap();
2732
2733 assert_matches!(conn.read(&mut [0u8; 1]).await, Ok(0));
2734 }
2735
2736 #[tokio::test]
2740 async fn proxy_reports_target_shard_as_feedback() {
2741 setup_tracing();
2742
2743 let node_port = 10101;
2744 let node_real_addr = next_local_address_with_port(node_port);
2745 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
2746
2747 let params = FrameParams {
2748 flags: 0,
2749 version: 0x04,
2750 stream: 0,
2751 };
2752 let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2753 let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2754
2755 let body = random_body();
2756
2757 for shards_count in 2..9 {
2758 let driver1_shard = shards_count - 1;
2760 let driver2_shard = shards_count - 2;
2761 let node_proxy_addr = next_local_address_with_port(node_port);
2762
2763 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2764 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2765
2766 let proxy = Proxy::new([Node::new(
2767 node_real_addr,
2768 node_proxy_addr,
2769 ShardAwareness::FixedNum(shards_count),
2770 Some(vec![RequestRule(
2771 Condition::True,
2772 RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2773 )]),
2774 Some(vec![ResponseRule(
2775 Condition::True,
2776 ResponseReaction::drop_frame()
2777 .with_feedback_when_performed(response_feedback_tx),
2778 )]),
2779 )]);
2780 let running_proxy = proxy.run().await.unwrap();
2781
2782 fn draw_source_port_for_shard(shards_count: u16, shard: u16) -> u16 {
2784 assert!(shard < shards_count);
2785 49152u16.div_ceil(shards_count) * shards_count + shard
2786 }
2787
2788 async fn bind_socket_for_shard(shards_count: u16, shard: u16) -> TcpSocket {
2789 let socket = TcpSocket::new_v4().unwrap();
2790 let initial_port = draw_source_port_for_shard(shards_count, shard);
2791
2792 let mut desired_addr =
2793 SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), initial_port);
2794 while socket.bind(desired_addr).is_err() {
2795 let next_port = desired_addr.port().wrapping_add(shards_count);
2797 if next_port == initial_port {
2798 panic!("No more ports left");
2799 }
2800 desired_addr.set_port(next_port);
2801 }
2802
2803 socket
2804 }
2805
2806 let body_ref = &body;
2807 let send_frame_to_shard = |driver_shard: u16| async move {
2808 let socket = bind_socket_for_shard(shards_count, driver_shard).await;
2809 let mut conn = socket.connect(node_proxy_addr).await.unwrap();
2810
2811 write_frame(
2812 params,
2813 request_opcode,
2814 body_ref,
2815 &mut conn,
2816 &no_compression(),
2817 )
2818 .await
2819 .unwrap();
2820 conn
2821 };
2822
2823 let mock_driver1_action = send_frame_to_shard(driver1_shard);
2824 let mock_driver2_action = send_frame_to_shard(driver2_shard);
2825
2826 let mock_node_action = async {
2828 let mut conns_futs = (0..2)
2829 .map(|_| async {
2830 let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap();
2831 respond_with_shard_num(
2832 &mut conn,
2833 driver_addr.port() % shards_count,
2834 &no_compression(),
2835 )
2836 .await;
2837 write_frame(
2838 params.for_response(),
2839 response_opcode,
2840 body_ref,
2841 &mut conn,
2842 &no_compression(),
2843 )
2844 .await
2845 .unwrap();
2846 conn
2847 })
2848 .collect::<Vec<_>>();
2849 let conn2 = conns_futs.pop().unwrap().await;
2850 let conn1 = conns_futs.pop().unwrap().await;
2851 (conn1, conn2)
2852 };
2853
2854 let (_node_conns, _driver1_conn, _driver2_conn) =
2856 join3(mock_node_action, mock_driver1_action, mock_driver2_action).await;
2857
2858 let assert_feedback_request = |feedback_request: RequestFrame| {
2859 assert_eq!(feedback_request.params, params);
2860 assert_eq!(
2861 FrameOpcode::Request(feedback_request.opcode),
2862 request_opcode
2863 );
2864 assert_eq!(feedback_request.body, body);
2865 };
2866
2867 let assert_feedback_response = |feedback_response: ResponseFrame| {
2868 assert_eq!(feedback_response.params, params.for_response());
2869 assert_eq!(
2870 FrameOpcode::Response(feedback_response.opcode),
2871 response_opcode
2872 );
2873 assert_eq!(feedback_response.body, body);
2874 };
2875
2876 let (feedback_request, shard1) = request_feedback_rx.recv().await.unwrap();
2877 assert_feedback_request(feedback_request);
2878 let (feedback_request, shard2) = request_feedback_rx.recv().await.unwrap();
2879 assert_feedback_request(feedback_request);
2880 let (feedback_response, shard3) = response_feedback_rx.recv().await.unwrap();
2881 assert_feedback_response(feedback_response);
2882 let (feedback_response, shard4) = response_feedback_rx.recv().await.unwrap();
2883 assert_feedback_response(feedback_response);
2884
2885 let mut expected_shards = [driver1_shard, driver1_shard, driver2_shard, driver2_shard];
2887 expected_shards.sort_unstable();
2888
2889 let mut got_shards = [
2890 shard1.unwrap(),
2891 shard2.unwrap(),
2892 shard3.unwrap(),
2893 shard4.unwrap(),
2894 ];
2895 got_shards.sort_unstable();
2896
2897 assert_eq!(expected_shards, got_shards);
2898
2899 running_proxy.finish().await.unwrap();
2900 }
2901 }
2902
2903 #[tokio::test]
2904 async fn proxy_ignores_control_connection_messages() {
2905 setup_tracing();
2906 let node1_real_addr = next_local_address_with_port(9876);
2907 let node1_proxy_addr = next_local_address_with_port(9876);
2908
2909 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2910 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2911 let proxy = Proxy::new([Node::new(
2912 node1_real_addr,
2913 node1_proxy_addr,
2914 ShardAwareness::Unaware,
2915 Some(vec![RequestRule(
2916 Condition::not(Condition::ConnectionRegisteredAnyEvent),
2917 RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2918 )]),
2919 Some(vec![ResponseRule(
2920 Condition::not(Condition::ConnectionRegisteredAnyEvent),
2921 ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2922 )]),
2923 )]);
2924 let running_proxy = proxy.run().await.unwrap();
2925
2926 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2927
2928 let (mut client_socket, mut server_socket) = join(
2929 async { TcpStream::connect(node1_proxy_addr).await.unwrap() },
2930 async { mock_node_listener.accept().await.unwrap().0 },
2931 )
2932 .await;
2933
2934 async fn perform_reqest_response<'a>(
2935 req_opcode: RequestOpcode,
2936 resp_opcode: ResponseOpcode,
2937 client_socket_ref: &'a mut TcpStream,
2938 server_socket_ref: &'a mut TcpStream,
2939 body_base: &'a str,
2940 ) {
2941 let params = FrameParams {
2942 flags: 0,
2943 version: 0x04,
2944 stream: 0,
2945 };
2946
2947 write_frame(
2948 params,
2949 FrameOpcode::Request(req_opcode),
2950 (body_base.to_string() + "|request|").as_bytes(),
2951 client_socket_ref,
2952 &no_compression(),
2953 )
2954 .await
2955 .unwrap();
2956
2957 let received_request =
2958 read_frame(server_socket_ref, FrameType::Request, &no_compression())
2959 .await
2960 .unwrap();
2961 assert_eq!(received_request.1, FrameOpcode::Request(req_opcode));
2962
2963 write_frame(
2964 params.for_response(),
2965 FrameOpcode::Response(resp_opcode),
2966 (body_base.to_string() + "|response|").as_bytes(),
2967 server_socket_ref,
2968 &no_compression(),
2969 )
2970 .await
2971 .unwrap();
2972
2973 let received_response =
2974 read_frame(client_socket_ref, FrameType::Response, &no_compression())
2975 .await
2976 .unwrap();
2977 assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode));
2978 }
2979
2980 for i in 0..5 {
2982 perform_reqest_response(
2983 RequestOpcode::Query,
2984 ResponseOpcode::Result,
2985 &mut client_socket,
2986 &mut server_socket,
2987 &format!("message_before_{i}"),
2988 )
2989 .await
2990 }
2991
2992 perform_reqest_response(
2993 RequestOpcode::Register,
2994 ResponseOpcode::Result,
2995 &mut client_socket,
2996 &mut server_socket,
2997 "message_register",
2998 )
2999 .await;
3000
3001 for i in 0..5 {
3003 perform_reqest_response(
3004 RequestOpcode::Query,
3005 ResponseOpcode::Result,
3006 &mut client_socket,
3007 &mut server_socket,
3008 &format!("message_after_{i}"),
3009 )
3010 .await
3011 }
3012
3013 running_proxy.finish().await.unwrap();
3014
3015 for _ in 0..5 {
3016 let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
3017 assert_eq!(feedback_request.opcode, RequestOpcode::Query);
3018 let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
3019 assert_eq!(feedback_response.opcode, ResponseOpcode::Result);
3020 }
3021
3022 let _ = request_feedback_rx.try_recv().unwrap_err();
3024 let _ = response_feedback_rx.try_recv().unwrap_err();
3025 }
3026
3027 #[tokio::test]
3028 async fn proxy_compresses_and_decompresses_frames_iff_compression_negotiated() {
3029 setup_tracing();
3030 let node1_real_addr = next_local_address_with_port(9876);
3031 let node1_proxy_addr = next_local_address_with_port(9876);
3032
3033 let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
3034 let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
3035 let proxy = Proxy::builder()
3036 .with_node(
3037 Node::builder()
3038 .real_address(node1_real_addr)
3039 .proxy_address(node1_proxy_addr)
3040 .shard_awareness(ShardAwareness::Unaware)
3041 .request_rules(vec![RequestRule(
3042 Condition::True,
3043 RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
3044 )])
3045 .response_rules(vec![ResponseRule(
3046 Condition::True,
3047 ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
3048 )])
3049 .build(),
3050 )
3051 .build();
3052 let running_proxy = proxy.run().await.unwrap();
3053
3054 let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
3055
3056 const PARAMS_REQUEST_NO_COMPRESSION: FrameParams = FrameParams {
3057 flags: 0,
3058 version: 0x04,
3059 stream: 0,
3060 };
3061 const PARAMS_REQUEST_COMPRESSION: FrameParams = FrameParams {
3062 flags: flag::COMPRESSION,
3063 ..PARAMS_REQUEST_NO_COMPRESSION
3064 };
3065 const PARAMS_RESPONSE_NO_COMPRESSION: FrameParams =
3066 PARAMS_REQUEST_NO_COMPRESSION.for_response();
3067 const PARAMS_RESPONSE_COMPRESSION: FrameParams =
3068 PARAMS_REQUEST_NO_COMPRESSION.for_response();
3069
3070 let make_driver_conn = async { TcpStream::connect(node1_proxy_addr).await.unwrap() };
3071 let make_node_conn = async { mock_node_listener.accept().await.unwrap() };
3072
3073 let (mut driver_conn, (mut node_conn, _)) = join(make_driver_conn, make_node_conn).await;
3074
3075 {
3093 let sent_frame = RequestFrame {
3094 params: PARAMS_REQUEST_NO_COMPRESSION,
3095 opcode: RequestOpcode::Query,
3096 body: random_body(),
3097 };
3098
3099 sent_frame
3100 .write(&mut driver_conn, &no_compression())
3101 .await
3102 .unwrap();
3103
3104 let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3105 assert_eq!(captured_frame, sent_frame);
3106
3107 let received_frame = read_request_frame(&mut node_conn, &no_compression())
3108 .await
3109 .unwrap();
3110 assert_eq!(received_frame, sent_frame);
3111 }
3112
3113 {
3116 let sent_frame = ResponseFrame {
3117 params: PARAMS_RESPONSE_NO_COMPRESSION,
3118 opcode: ResponseOpcode::Result,
3119 body: random_body(),
3120 };
3121
3122 sent_frame
3123 .write(&mut node_conn, &no_compression())
3124 .await
3125 .unwrap();
3126
3127 let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
3128 assert_eq!(captured_frame, sent_frame);
3129
3130 let received_frame = read_response_frame(&mut driver_conn, &no_compression())
3131 .await
3132 .unwrap();
3133 assert_eq!(received_frame, sent_frame);
3134 }
3135
3136 {
3141 let startup_body = Startup {
3142 options: std::iter::once((
3143 options::COMPRESSION.into(),
3144 Compression::Lz4.as_str().into(),
3145 ))
3146 .collect(),
3147 }
3148 .to_bytes()
3149 .unwrap();
3150
3151 let sent_frame = RequestFrame {
3152 params: PARAMS_REQUEST_NO_COMPRESSION,
3153 opcode: RequestOpcode::Startup,
3154 body: startup_body,
3155 };
3156
3157 sent_frame
3158 .write(&mut driver_conn, &no_compression())
3159 .await
3160 .unwrap();
3161
3162 let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3163 assert_eq!(captured_frame, sent_frame);
3164
3165 let received_frame = read_request_frame(&mut node_conn, &no_compression())
3166 .await
3167 .unwrap();
3168 assert_eq!(received_frame, sent_frame);
3169 }
3170
3171 {
3174 let sent_frame = RequestFrame {
3175 params: PARAMS_REQUEST_COMPRESSION,
3176 opcode: RequestOpcode::Query,
3177 body: random_body(),
3178 };
3179
3180 sent_frame
3181 .write(&mut driver_conn, &with_compression(Compression::Lz4))
3182 .await
3183 .unwrap();
3184
3185 let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3186 assert_eq!(captured_frame, sent_frame);
3187
3188 let received_frame =
3189 read_request_frame(&mut node_conn, &with_compression(Compression::Lz4))
3190 .await
3191 .unwrap();
3192 assert_eq!(received_frame, sent_frame);
3193 }
3194
3195 {
3198 let sent_frame = ResponseFrame {
3199 params: PARAMS_RESPONSE_COMPRESSION,
3200 opcode: ResponseOpcode::Result,
3201 body: random_body(),
3202 };
3203
3204 sent_frame
3205 .write(&mut node_conn, &with_compression(Compression::Lz4))
3206 .await
3207 .unwrap();
3208
3209 let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
3210 assert_eq!(captured_frame, sent_frame);
3211
3212 let received_frame =
3213 read_response_frame(&mut driver_conn, &with_compression(Compression::Lz4))
3214 .await
3215 .unwrap();
3216 assert_eq!(received_frame, sent_frame);
3217 }
3218
3219 running_proxy.finish().await.unwrap();
3220 }
3221
3222 #[tokio::test]
3223 async fn connection_tracker_register_increments_and_drop_decrements() {
3224 let tracker = Arc::new(ConnectionTracker {
3225 active_count: std::sync::atomic::AtomicUsize::new(0),
3226 notify: tokio::sync::Notify::new(),
3227 });
3228
3229 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 0);
3230
3231 let guard1 = tracker.register_connection();
3232 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 1);
3233
3234 let guard2 = tracker.register_connection();
3235 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3236
3237 let guard1_clone = Arc::clone(&guard1);
3239 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3240
3241 drop(guard1_clone);
3243 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3244
3245 drop(guard1);
3247 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 1);
3248
3249 drop(guard2);
3250 assert_eq!(tracker.active_count.load(Ordering::Relaxed), 0);
3251 }
3252
3253 #[tokio::test]
3254 async fn connection_tracker_register_notifies_waiters() {
3255 let tracker = Arc::new(ConnectionTracker {
3256 active_count: std::sync::atomic::AtomicUsize::new(0),
3257 notify: tokio::sync::Notify::new(),
3258 });
3259
3260 let notified = tracker.notify.notified();
3262
3263 let _guard = tracker.register_connection();
3265
3266 tokio::time::timeout(Duration::from_millis(200), notified)
3268 .await
3269 .expect("notify was not triggered by register_connection");
3270 }
3271
3272 #[tokio::test]
3273 async fn wait_for_connection_returns_immediately_when_connected() {
3274 let node_real_addr = next_local_address_with_port(9876);
3275 let node_proxy_addr = next_local_address_with_port(9876);
3276 let proxy = Proxy::new([Node::new(
3277 node_real_addr,
3278 node_proxy_addr,
3279 ShardAwareness::Unaware,
3280 None,
3281 None,
3282 )]);
3283 let running_proxy = proxy.run().await.unwrap();
3284
3285 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3286
3287 let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3289 let (_backend_conn, _) = mock_node_listener.accept().await.unwrap();
3290
3291 tokio::time::timeout(
3293 Duration::from_millis(200),
3294 running_proxy.running_nodes[0].wait_for_connection(),
3295 )
3296 .await
3297 .expect("wait_for_connection timed out despite active connection");
3298
3299 tokio::time::timeout(
3301 Duration::from_millis(200),
3302 running_proxy.wait_for_connection(),
3303 )
3304 .await
3305 .expect("RunningProxy::wait_for_connection timed out despite active connection");
3306
3307 running_proxy.finish().await.unwrap();
3308 }
3309
3310 #[tokio::test]
3311 async fn wait_for_connection_blocks_until_driver_connects() {
3312 let node_real_addr = next_local_address_with_port(9876);
3313 let node_proxy_addr = next_local_address_with_port(9876);
3314 let proxy = Proxy::new([Node::new(
3315 node_real_addr,
3316 node_proxy_addr,
3317 ShardAwareness::Unaware,
3318 None,
3319 None,
3320 )]);
3321 let running_proxy = proxy.run().await.unwrap();
3322
3323 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3324
3325 let result = tokio::time::timeout(
3327 Duration::from_millis(20),
3328 running_proxy.running_nodes[0].wait_for_connection(),
3329 )
3330 .await;
3331 assert!(
3332 result.is_err(),
3333 "wait_for_connection returned before any driver connected"
3334 );
3335
3336 let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3338 let (_backend_conn, _) = mock_node_listener.accept().await.unwrap();
3339
3340 tokio::time::timeout(
3342 Duration::from_millis(200),
3343 running_proxy.running_nodes[0].wait_for_connection(),
3344 )
3345 .await
3346 .expect("wait_for_connection timed out after driver connected");
3347
3348 running_proxy.finish().await.unwrap();
3349 }
3350
3351 async fn send_register(driver_conn: &mut TcpStream, node_conn: &mut TcpStream) {
3356 let params = FrameParams {
3357 flags: 0,
3358 version: 0x04,
3359 stream: 0,
3360 };
3361 write_frame(
3362 params,
3363 FrameOpcode::Request(RequestOpcode::Register),
3364 b"",
3365 driver_conn,
3366 &no_compression(),
3367 )
3368 .await
3369 .unwrap();
3370
3371 let _req = read_request_frame(node_conn, &no_compression())
3374 .await
3375 .unwrap();
3376 }
3377
3378 #[tokio::test]
3379 async fn inject_event_to_cc_returns_false_when_no_control_connections() {
3380 setup_tracing();
3381 let node_real_addr = next_local_address_with_port(9876);
3382 let node_proxy_addr = next_local_address_with_port(9876);
3383 let proxy = Proxy::new([Node::new(
3384 node_real_addr,
3385 node_proxy_addr,
3386 ShardAwareness::Unaware,
3387 None,
3388 None,
3389 )]);
3390 let running_proxy = proxy.run().await.unwrap();
3391 let _mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3392
3393 assert!(
3395 !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"test")),
3396 "inject_event_to_cc should return false with no control connections"
3397 );
3398
3399 let _ = running_proxy.finish().await;
3401 }
3402
3403 #[tokio::test]
3404 async fn inject_event_to_cc_returns_false_when_connection_did_not_register() {
3405 setup_tracing();
3406 let node_real_addr = next_local_address_with_port(9876);
3407 let node_proxy_addr = next_local_address_with_port(9876);
3408 let proxy = Proxy::new([Node::new(
3409 node_real_addr,
3410 node_proxy_addr,
3411 ShardAwareness::Unaware,
3412 None,
3413 None,
3414 )]);
3415 let running_proxy = proxy.run().await.unwrap();
3416 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3417
3418 let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3420 let (_node_conn, _) = mock_node_listener.accept().await.unwrap();
3421
3422 assert!(
3424 !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"test")),
3425 "inject_event_to_cc should return false when no REGISTER was sent"
3426 );
3427
3428 running_proxy.finish().await.unwrap();
3429 }
3430
3431 #[tokio::test]
3432 async fn inject_event_to_cc_delivers_event_after_register() {
3433 setup_tracing();
3434 let node_real_addr = next_local_address_with_port(9876);
3435 let node_proxy_addr = next_local_address_with_port(9876);
3436 let proxy = Proxy::new([Node::new(
3437 node_real_addr,
3438 node_proxy_addr,
3439 ShardAwareness::Unaware,
3440 None,
3441 None,
3442 )]);
3443 let running_proxy = proxy.run().await.unwrap();
3444 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3445
3446 let (mut driver_conn, mut node_conn) = join(
3447 async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3448 async { mock_node_listener.accept().await.unwrap().0 },
3449 )
3450 .await;
3451
3452 send_register(&mut driver_conn, &mut node_conn).await;
3454
3455 let event_body = Bytes::from_static(b"injected_event_payload");
3457 assert!(
3458 running_proxy.running_nodes[0].inject_event_to_cc(event_body.clone()),
3459 "inject_event_to_cc should return true after REGISTER"
3460 );
3461
3462 let frame = tokio::time::timeout(
3464 Duration::from_millis(100),
3465 read_response_frame(&mut driver_conn, &no_compression()),
3466 )
3467 .await
3468 .expect("timed out waiting for injected event")
3469 .expect("failed to read injected event frame");
3470
3471 assert_eq!(frame.opcode, ResponseOpcode::Event);
3472 assert_eq!(frame.body, event_body);
3473 assert_eq!(frame.params.stream, -1);
3474
3475 running_proxy.finish().await.unwrap();
3476 }
3477
3478 #[tokio::test]
3479 async fn inject_event_to_cc_prunes_closed_connections() {
3480 setup_tracing();
3481 let node_real_addr = next_local_address_with_port(9876);
3482 let node_proxy_addr = next_local_address_with_port(9876);
3483 let proxy = Proxy::new([Node::new(
3484 node_real_addr,
3485 node_proxy_addr,
3486 ShardAwareness::Unaware,
3487 None,
3488 None,
3489 )]);
3490 let running_proxy = proxy.run().await.unwrap();
3491 let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3492
3493 let (mut driver_conn1, mut node_conn1) = join(
3495 async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3496 async { mock_node_listener.accept().await.unwrap().0 },
3497 )
3498 .await;
3499 send_register(&mut driver_conn1, &mut node_conn1).await;
3500
3501 let (mut driver_conn2, mut node_conn2) = join(
3503 async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3504 async { mock_node_listener.accept().await.unwrap().0 },
3505 )
3506 .await;
3507 send_register(&mut driver_conn2, &mut node_conn2).await;
3508
3509 assert!(running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev1")));
3511
3512 let f1 = read_response_frame(&mut driver_conn1, &no_compression())
3514 .await
3515 .unwrap();
3516 let f2 = read_response_frame(&mut driver_conn2, &no_compression())
3517 .await
3518 .unwrap();
3519 assert_eq!(f1.body, Bytes::from_static(b"ev1"));
3520 assert_eq!(f2.body, Bytes::from_static(b"ev1"));
3521
3522 drop(driver_conn1);
3524 drop(node_conn1);
3525
3526 tokio::time::sleep(Duration::from_millis(100)).await;
3529
3530 assert!(running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev2")));
3533
3534 let f2 = read_response_frame(&mut driver_conn2, &no_compression())
3535 .await
3536 .unwrap();
3537 assert_eq!(f2.body, Bytes::from_static(b"ev2"));
3538
3539 drop(driver_conn2);
3541 drop(node_conn2);
3542
3543 tokio::time::sleep(Duration::from_millis(100)).await;
3544
3545 assert!(
3547 !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev3")),
3548 "inject_event_to_cc should return false after all control connections closed"
3549 );
3550
3551 let _ = running_proxy.finish().await;
3554 }
3555}