scylla_proxy/
proxy.rs

1use crate::actions::{EvaluationContext, RequestRule, ResponseRule};
2use crate::errors::{DoorkeeperError, ProxyError, WorkerError};
3use crate::frame::{
4    self, FrameOpcode, FrameParams, RequestFrame, ResponseFrame, read_response_frame, write_frame,
5};
6use crate::{RequestOpcode, TargetShard};
7use bytes::Bytes;
8use compression::no_compression;
9use scylla_cql::frame::types::read_string_multimap;
10use std::collections::HashMap;
11use std::fmt::Display;
12use std::future::Future;
13use std::net::{IpAddr, Ipv4Addr, SocketAddr};
14use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
15use std::sync::{Arc, Mutex};
16use tokio::io::{AsyncRead, AsyncWrite};
17use tokio::net::{TcpListener, TcpSocket, TcpStream};
18use tokio::sync::mpsc::error::TryRecvError;
19use tokio::sync::{broadcast, mpsc};
20use tracing::{debug, error, info, trace, warn};
21
22// Used to notify the user that the proxy finished - this happens when all Senders are dropped.
23type FinishWaiter = mpsc::Receiver<()>;
24type FinishGuard = mpsc::Sender<()>;
25
26// Used to tell all the proxy workers to stop when the user requests that with [RunningProxy::finish()].
27type TerminateNotifier = tokio::sync::broadcast::Receiver<()>;
28type TerminateSignaler = tokio::sync::broadcast::Sender<()>;
29
30// Used to tell all proxy workers working on same connection to stop when
31// a rule being applied has connection drop set.
32type ConnectionCloseNotifier = tokio::sync::broadcast::Receiver<()>;
33type ConnectionCloseSignaler = tokio::sync::broadcast::Sender<()>;
34
35// Used to gather errors from all proxy workers and propagate them to the proxy user,
36// returning the first of them from [RunningProxy::finish()].
37type ErrorPropagator = mpsc::UnboundedSender<ProxyError>;
38type ErrorSink = mpsc::UnboundedReceiver<ProxyError>;
39
40static HARDCODED_OPTIONS_PARAMS: FrameParams = FrameParams {
41    flags: 0,
42    version: 0x04,
43    stream: 0,
44};
45
/// Specifies proxy's behaviour regarding shard awareness.
#[derive(Clone, Copy, Debug)]
pub enum ShardAwareness {
    /// Acts as if the connection was made to the shard-unaware port.
    Unaware,
    /// The first time the driver attempts to connect to the particular node (through proxy),
    /// the related node is first queried on a temporary connection for its number of shards,
    /// and only then establishes another connection for the driver's real communication with the node.
    /// If the queried node does not provide sharding info (e.g. in case of a Cassandra node),
    /// then this mode behaves as Unaware.
    QueryNode,
    /// Binds to the port that is the same as the driver's port modulo the provided number of shards.
    FixedNum(u16),
}

impl ShardAwareness {
    /// Returns `true` for every mode except [`ShardAwareness::Unaware`].
    pub fn is_aware(&self) -> bool {
        match self {
            Self::Unaware => false,
            Self::QueryNode | Self::FixedNum(_) => true,
        }
    }
}
66
67/// Node can be either Real (truly backed by a Scylla node) or Simulated
68/// (driver believes it's real, but we merely simulate it with the proxy).
69/// In Simulated mode, no node address is provided and proxy does not attempt
70/// to establish connection with a Scylla node.
71///
72/// For Real node, all workers are created, so such frame flow is possible:
73/// [driver] -> receiver_from_driver -> requests_processor -> sender_to_cluster -> [node] (ordinary request flow)
74///                                        |                    /\
75///     (forging response) ++--------------+      +-------------++ (forging request)
76///                        \/                     |
77/// [driver] <- sender_to_driver <- response_processor <- receiver_from_cluster <- [node] (ordinary response flow)
78///
79/// For Simulated node, it looks like this:
80/// [driver] -> receiver_from_driver -> requests_processor -+
81///                                                         |   (forging response)
82/// [driver] <- sender_to_driver <--------------------------+
83///
84/// For Real node, the default reaction to a frame is to pass it to its intended addresse.
85/// For Simulated node, the default reaction to a request is to drop it.
86enum NodeType {
87    Real {
88        real_addr: SocketAddr,
89        shard_awareness: ShardAwareness,
90        response_rules: Option<Vec<ResponseRule>>,
91    },
92    Simulated,
93}
94
95pub struct Node {
96    proxy_addr: SocketAddr,
97    request_rules: Option<Vec<RequestRule>>,
98    node_type: NodeType,
99}
100
101impl Node {
102    /// Creates an abstract node that is backed by a real Scylla node.
103    pub fn new(
104        real_addr: SocketAddr,
105        proxy_addr: SocketAddr,
106        shard_awareness: ShardAwareness,
107        request_rules: Option<Vec<RequestRule>>,
108        response_rules: Option<Vec<ResponseRule>>,
109    ) -> Self {
110        Self {
111            proxy_addr,
112            request_rules,
113            node_type: NodeType::Real {
114                real_addr,
115                shard_awareness,
116                response_rules,
117            },
118        }
119    }
120
121    /// Creates a simulated node that is not backed by any real Scylla node.
122    pub fn new_dry_mode(proxy_addr: SocketAddr, request_rules: Option<Vec<RequestRule>>) -> Self {
123        Self {
124            proxy_addr,
125            request_rules,
126            node_type: NodeType::Simulated,
127        }
128    }
129
130    pub fn builder() -> NodeBuilder {
131        NodeBuilder {
132            real_addr: None,
133            proxy_addr: None,
134            shard_awareness: None,
135            request_rules: None,
136            response_rules: None,
137        }
138    }
139}
140
141pub struct NodeBuilder {
142    real_addr: Option<SocketAddr>,
143    proxy_addr: Option<SocketAddr>,
144    shard_awareness: Option<ShardAwareness>,
145    request_rules: Option<Vec<RequestRule>>,
146    response_rules: Option<Vec<ResponseRule>>,
147}
148
149impl NodeBuilder {
150    pub fn real_address(mut self, real_addr: SocketAddr) -> Self {
151        self.real_addr = Some(real_addr);
152        self
153    }
154
155    pub fn proxy_address(mut self, proxy_addr: SocketAddr) -> Self {
156        self.proxy_addr = Some(proxy_addr);
157        self
158    }
159
160    pub fn shard_awareness(mut self, shard_awareness: ShardAwareness) -> Self {
161        self.shard_awareness = Some(shard_awareness);
162        self
163    }
164
165    pub fn request_rules(mut self, request_rules: Vec<RequestRule>) -> Self {
166        self.request_rules = Some(request_rules);
167        self
168    }
169
170    pub fn response_rules(mut self, response_rules: Vec<ResponseRule>) -> Self {
171        self.response_rules = Some(response_rules);
172        self
173    }
174
175    /// Creates an abstract node that is backed by a real Scylla node.
176    pub fn build(self) -> Node {
177        Node {
178            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
179            request_rules: self.request_rules,
180            node_type: NodeType::Real {
181                real_addr: self.real_addr.expect("Real addr is required!"),
182                shard_awareness: self.shard_awareness.expect("Shard awareness is required!"),
183                response_rules: self.response_rules,
184            },
185        }
186    }
187
188    /// Creates a simulated node that is not backed by any real Scylla node.
189    pub fn build_dry_mode(self) -> Node {
190        Node {
191            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
192            request_rules: self.request_rules,
193            node_type: NodeType::Simulated,
194        }
195    }
196}
197
/// Wrapper that prints the real node address, or `<dry mode>` when there is none.
#[derive(Clone, Copy)]
struct DisplayableRealAddrOption(Option<SocketAddr>);
impl Display for DisplayableRealAddrOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            Some(addr) => write!(f, "{addr}"),
            None => f.write_str("<dry mode>"),
        }
    }
}
209
210#[derive(Clone, Copy)]
211struct DisplayableShard(Option<TargetShard>);
212impl Display for DisplayableShard {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        if let Some(shard) = self.0 {
215            write!(f, "shard {shard}")
216        } else {
217            write!(f, "unknown shard")
218        }
219    }
220}
221
222enum InternalNode {
223    Real {
224        real_addr: SocketAddr,
225        proxy_addr: SocketAddr,
226        shard_awareness: ShardAwareness,
227        request_rules: Arc<Mutex<Vec<RequestRule>>>,
228        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
229    },
230    Simulated {
231        proxy_addr: SocketAddr,
232        request_rules: Arc<Mutex<Vec<RequestRule>>>,
233    },
234}
235
236impl InternalNode {
237    fn proxy_addr(&self) -> SocketAddr {
238        match *self {
239            InternalNode::Real { proxy_addr, .. } => proxy_addr,
240            InternalNode::Simulated { proxy_addr, .. } => proxy_addr,
241        }
242    }
243    fn real_addr(&self) -> Option<SocketAddr> {
244        match *self {
245            InternalNode::Real { real_addr, .. } => Some(real_addr),
246            InternalNode::Simulated { .. } => None,
247        }
248    }
249    fn request_rules(&self) -> &Arc<Mutex<Vec<RequestRule>>> {
250        match self {
251            InternalNode::Real { request_rules, .. } => request_rules,
252            InternalNode::Simulated { request_rules, .. } => request_rules,
253        }
254    }
255}
256
257impl From<Node> for InternalNode {
258    fn from(node: Node) -> Self {
259        match node.node_type {
260            NodeType::Real {
261                real_addr,
262                shard_awareness,
263                response_rules,
264            } => InternalNode::Real {
265                real_addr,
266                proxy_addr: node.proxy_addr,
267                shard_awareness,
268                request_rules: node
269                    .request_rules
270                    .map(|rules| Arc::new(Mutex::new(rules)))
271                    .unwrap_or_default(),
272                response_rules: response_rules
273                    .map(|rules| Arc::new(Mutex::new(rules)))
274                    .unwrap_or_default(),
275            },
276            NodeType::Simulated => InternalNode::Simulated {
277                proxy_addr: node.proxy_addr,
278                request_rules: node
279                    .request_rules
280                    .map(|rules| Arc::new(Mutex::new(rules)))
281                    .unwrap_or_default(),
282            },
283        }
284    }
285}
286
287pub struct ProxyBuilder {
288    nodes: Vec<Node>,
289}
290
291impl ProxyBuilder {
292    pub fn with_node(mut self, node: Node) -> ProxyBuilder {
293        self.nodes.push(node);
294        self
295    }
296
297    pub fn build(self) -> Proxy {
298        Proxy::new(self.nodes)
299    }
300}
301
302pub struct Proxy {
303    nodes: Vec<InternalNode>,
304}
305
306impl Proxy {
307    pub fn new(nodes: impl IntoIterator<Item = Node>) -> Self {
308        Proxy {
309            nodes: nodes.into_iter().map(|node| node.into()).collect(),
310        }
311    }
312
313    pub fn builder() -> ProxyBuilder {
314        ProxyBuilder { nodes: vec![] }
315    }
316
317    /// Build a translation map based on provided proxy and node addresses.
318    /// The map can be passed to `Session` `address_translator()` to ensure
319    /// that the driver contacts the nodes through the proxy (and not directly).
320    pub fn translation_map(&self) -> HashMap<SocketAddr, SocketAddr> {
321        let mut translation_map = HashMap::new();
322        for node in self.nodes.iter() {
323            if let &InternalNode::Real {
324                real_addr,
325                proxy_addr,
326                ..
327            } = node
328            {
329                translation_map.insert(real_addr, proxy_addr);
330                let shard_aware_real_addr = SocketAddr::new(real_addr.ip(), 19042);
331                translation_map.insert(shard_aware_real_addr, proxy_addr);
332            }
333        }
334        translation_map
335    }
336
337    /// Runs the [Proxy], i.e. makes it ready for accepting drivers' connections.
338    /// Returns a [RunningProxy] handle that can be used to stop the proxy or change the rules.
339    pub async fn run(self) -> Result<RunningProxy, DoorkeeperError> {
340        let (terminate_signaler, _t) = tokio::sync::broadcast::channel(1);
341        let (finish_guard, finish_waiter) = mpsc::channel(1);
342
343        let (error_propagator, error_sink) = mpsc::unbounded_channel();
344        let (doorkeepers, running_nodes): (Vec<_>, Vec<RunningNode>) = self
345            .nodes
346            .into_iter()
347            .map(|node| {
348                let running = {
349                    let (request_rules, response_rules) = match node {
350                        InternalNode::Real {
351                            ref request_rules,
352                            ref response_rules,
353                            ..
354                        } => (request_rules, Some(response_rules)),
355                        InternalNode::Simulated {
356                            ref request_rules, ..
357                        } => (request_rules, None),
358                    };
359                    RunningNode {
360                        request_rules: request_rules.clone(),
361                        response_rules: response_rules.cloned(),
362                    }
363                };
364                (
365                    Doorkeeper::spawn(
366                        node,
367                        terminate_signaler.clone(),
368                        finish_guard.clone(),
369                        error_propagator.clone(),
370                    ),
371                    running,
372                )
373            })
374            .unzip();
375
376        for doorkeeper in doorkeepers {
377            doorkeeper.await?; // await doorkeeper creation, including binding to a socket
378        }
379
380        Ok(RunningProxy {
381            terminate_signaler,
382            finish_waiter,
383            running_nodes,
384            error_sink,
385        })
386    }
387}
388
389/// A handle that can be used to change the rules regarding the particular node.
390pub struct RunningNode {
391    request_rules: Arc<Mutex<Vec<RequestRule>>>,
392    response_rules: Option<Arc<Mutex<Vec<ResponseRule>>>>,
393}
394
395impl RunningNode {
396    /// Replaces the previous request rules with the new ones.
397    pub fn change_request_rules(&mut self, rules: Option<Vec<RequestRule>>) {
398        *self.request_rules.lock().unwrap() = rules.unwrap_or_default();
399    }
400
401    /// Adds new request rules to the end of the list (so with lowest priority)
402    pub fn append_request_rules(&mut self, mut rules: Vec<RequestRule>) {
403        self.request_rules.lock().unwrap().append(&mut rules);
404    }
405
406    /// Adds new request rules to the beginning of the list (so with highest priority)
407    pub fn prepend_request_rules(&mut self, rules: Vec<RequestRule>) {
408        let mut new_rules = rules;
409        let mut old_rules_guard = self.request_rules.lock().unwrap();
410        new_rules.append(&mut *old_rules_guard);
411        *old_rules_guard = new_rules;
412    }
413
414    /// Replaces the previous response rules with the new ones.
415    pub fn change_response_rules(&mut self, rules: Option<Vec<ResponseRule>>) {
416        *self
417            .response_rules
418            .as_ref()
419            .expect("No response rules on a simulated node!")
420            .lock()
421            .unwrap() = rules.unwrap_or_default();
422    }
423
424    /// Adds new response rules to the end of the list (so with lowest priority)
425    pub fn append_response_rules(&mut self, mut rules: Vec<ResponseRule>) {
426        self.response_rules
427            .as_ref()
428            .expect("No response rules on a simulated node!")
429            .lock()
430            .unwrap()
431            .append(&mut rules);
432    }
433
434    /// Adds new response rules to the beginning of the list (so with highest priority)
435    pub fn prepend_response_rules(&mut self, rules: Vec<ResponseRule>) {
436        let mut old_rules_guard = self
437            .response_rules
438            .as_ref()
439            .expect("No response rules on a simulated node!")
440            .lock()
441            .unwrap();
442        let mut new_rules = rules;
443        new_rules.append(&mut *old_rules_guard);
444        *old_rules_guard = new_rules;
445    }
446}
447
448/// A handle that can be used to stop the proxy or change the rules.
449pub struct RunningProxy {
450    terminate_signaler: TerminateSignaler,
451    finish_waiter: FinishWaiter,
452    pub running_nodes: Vec<RunningNode>,
453    error_sink: ErrorSink,
454}
455
456impl RunningProxy {
457    /// Disables all the rules in the proxy, effectively making it a pass-through-only proxy.
458    pub fn turn_off_rules(&mut self) {
459        for (request_rules, response_rules) in self
460            .running_nodes
461            .iter_mut()
462            .map(|node| (&node.request_rules, &node.response_rules))
463        {
464            request_rules.lock().unwrap().clear();
465            if let Some(response_rules) = response_rules {
466                response_rules.lock().unwrap().clear();
467            }
468        }
469    }
470
471    /// Attempts to fetch the first error that has occurred in proxy since last check.
472    /// If no errors occurred, returns Ok(()).
473    pub fn sanity_check(&mut self) -> Result<(), ProxyError> {
474        match self.error_sink.try_recv() {
475            Ok(err) => Err(err),
476            Err(TryRecvError::Empty) => Ok(()),
477            Err(TryRecvError::Disconnected) => {
478                // As we haven't awaited finish of all workers yet, there must be a faulty case without proper error handling.
479                Err(ProxyError::SanityCheckFailure)
480            }
481        }
482    }
483
484    /// Waits until an error occurs in proxy. If proxy finishes with no errors occurred, returns Err(()).
485    pub async fn wait_for_error(&mut self) -> Option<ProxyError> {
486        self.error_sink.recv().await
487    }
488
489    /// Requests termination of all proxy workers and awaits its completion.
490    /// Returns the first error that occurred in proxy.
491    pub async fn finish(mut self) -> Result<(), ProxyError> {
492        self.terminate_signaler.send(()).map_err(|err| {
493            ProxyError::AwaitFinishFailure(format!(
494                "Send error in terminate_signaler: {err} (bug!)"
495            ))
496        })?;
497        info!("Sent finish signal to proxy workers.");
498
499        // This to make sure that also workers not-yet-spawned when terminate signal was sent will terminate.
500        std::mem::drop(self.terminate_signaler);
501
502        if self.finish_waiter.recv().await.is_some() {
503            unreachable!();
504        };
505        info!("All workers have finished.");
506
507        match self.error_sink.try_recv() {
508            Ok(err) => Err(err),
509            Err(TryRecvError::Disconnected) => Ok(()),
510            Err(TryRecvError::Empty) => {
511                // As we have already awaited finish of all workers, there must be a logic bug.
512                unreachable!("Worker await logic bug!");
513            }
514        }
515    }
516}
517
/// A worker corresponding to a particular node. It listens in a loop for driver's connections
/// on specified proxy bind address, respects ports regarding advanced shard-awareness (if set),
/// to this end obtaining number of shards from the node (if set), then establishes connection
/// to the node, spawns workers for this connection and continues to listen.
struct Doorkeeper {
    // The node served by this doorkeeper.
    node: InternalNode,
    // Listener bound to the node's proxy address; drivers connect here.
    listener: TcpListener,
    // Source of terminate notifiers for the doorkeeper itself and for the workers it spawns.
    terminate_signaler: TerminateSignaler,
    // Cloned into every spawned worker.
    finish_guard: FinishGuard,
    // `None` until `update_shards_count()` runs; stays `None` when shard awareness
    // is off or the shard count could not be obtained.
    shards_count: Option<u16>,
    // Channel through which the doorkeeper and its workers report errors.
    error_propagator: ErrorPropagator,
}
530
531impl Doorkeeper {
532    async fn spawn(
533        node: InternalNode,
534        terminate_signaler: TerminateSignaler,
535        finish_guard: FinishGuard,
536        error_propagator: ErrorPropagator,
537    ) -> Result<(), DoorkeeperError> {
538        let listener = TcpListener::bind(node.proxy_addr())
539            .await
540            .map_err(|err| DoorkeeperError::DriverConnectionAttempt(node.proxy_addr(), err))?;
541
542        if let InternalNode::Real {
543            shard_awareness,
544            real_addr,
545            ..
546        } = node
547        {
548            info!(
549                "Spawned a {} doorkeeper for pair real:{} - proxy:{}.",
550                if shard_awareness.is_aware() {
551                    "shard-aware"
552                } else {
553                    "shard-unaware"
554                },
555                real_addr,
556                node.proxy_addr(),
557            );
558        } else {
559            info!(
560                "Spawned a dry-mode doorkeeper for proxy:{}.",
561                node.proxy_addr(),
562            )
563        };
564
565        let doorkeeper = Doorkeeper {
566            shards_count: None, // temporarily, until Doorkeeper examines its ShardAwareness
567            node,
568            listener,
569            terminate_signaler,
570            finish_guard,
571            error_propagator,
572        };
573        tokio::task::spawn(doorkeeper.run());
574        Ok(())
575    }
576
577    async fn run(mut self) {
578        self.update_shards_count().await;
579        let mut own_terminate_notifier = self.terminate_signaler.subscribe();
580        let (connection_close_tx, _connection_close_rx) = broadcast::channel::<()>(2);
581        let mut connection_no: usize = 0;
582        loop {
583            tokio::select! {
584                res = self.accept_connection(&connection_close_tx, connection_no) => {
585                    match res {
586                        Ok(()) => connection_no += 1,
587                        Err(err) => {
588                            error!(
589                                "Error in doorkeeper with addr {} for node {}: {}",
590                                self.node.proxy_addr(),
591                                DisplayableRealAddrOption(self.node.real_addr()),
592                                err
593                            );
594                            let _ = self.error_propagator.send(err.into());
595                            break;
596                        },
597                    }
598                },
599                _terminate = own_terminate_notifier.recv() => break
600            }
601        }
602        debug!(
603            "Doorkeeper exits: proxy {}, node {}.",
604            self.node.proxy_addr(),
605            DisplayableRealAddrOption(self.node.real_addr())
606        );
607    }
608
609    async fn update_shards_count(&mut self) {
610        if let InternalNode::Real {
611            real_addr,
612            shard_awareness,
613            ..
614        } = self.node
615        {
616            self.shards_count = match shard_awareness {
617                ShardAwareness::Unaware => None,
618                ShardAwareness::FixedNum(shards_num) => Some(shards_num),
619                ShardAwareness::QueryNode => match self.obtain_shards_count(real_addr).await {
620                    Ok(shards) => Some(shards),
621                    // If a node offers no sharding info, change proxy ShardAwareness to Unaware.
622                    Err(DoorkeeperError::ObtainingShardNumberNoShardInfo) => {
623                        info!(
624                            "Doorkeeper with addr {} found no shard info in node {}; falling back to ShardAwareness::Unaware",
625                            self.node.proxy_addr(),
626                            DisplayableRealAddrOption(self.node.real_addr()),
627                        );
628                        None
629                    }
630                    Err(e) => {
631                        error!(
632                            "Error in doorkeeper with addr {} while querying shard info from node {}: {}",
633                            self.node.proxy_addr(),
634                            DisplayableRealAddrOption(self.node.real_addr()),
635                            e
636                        );
637                        None
638                    }
639                },
640            }
641        }
642    }
643
644    async fn spawn_workers(
645        &mut self,
646        driver_addr: SocketAddr,
647        connection_close_tx: &ConnectionCloseSignaler,
648        connection_no: usize,
649        driver_stream: TcpStream,
650        cluster_stream: Option<TcpStream>,
651        shard: Option<TargetShard>,
652    ) {
653        let (driver_read, driver_write) = driver_stream.into_split();
654
655        let new_worker = || ProxyWorker {
656            terminate_notifier: self.terminate_signaler.subscribe(),
657            finish_guard: self.finish_guard.clone(),
658            connection_close_notifier: connection_close_tx.subscribe(),
659            error_propagator: self.error_propagator.clone(),
660            driver_addr,
661            real_addr: self.node.real_addr(),
662            proxy_addr: self.node.proxy_addr(),
663            shard,
664        };
665
666        let (tx_request, rx_request) = mpsc::unbounded_channel::<RequestFrame>();
667        let (tx_response, rx_response) = mpsc::unbounded_channel::<ResponseFrame>();
668        let (tx_cluster, rx_cluster) = mpsc::unbounded_channel::<RequestFrame>();
669        let (tx_driver, rx_driver) = mpsc::unbounded_channel::<ResponseFrame>();
670        let event_register_flag = Arc::new(AtomicBool::new(false));
671
672        let (
673            compression_writer_request_processor,
674            compression_reader_receiver_from_driver,
675            compression_reader_receiver_from_cluster,
676            compression_reader_sender_to_driver,
677            compression_reader_sender_to_cluster,
678        ) = compression::make_compression_infra();
679
680        tokio::task::spawn(new_worker().receiver_from_driver(
681            driver_read,
682            tx_request,
683            compression_reader_receiver_from_driver,
684        ));
685        tokio::task::spawn(new_worker().sender_to_driver(
686            driver_write,
687            rx_driver,
688            connection_close_tx.subscribe(),
689            self.terminate_signaler.subscribe(),
690            compression_reader_sender_to_driver,
691        ));
692        tokio::task::spawn(new_worker().request_processor(
693            rx_request,
694            tx_driver.clone(),
695            tx_cluster.clone(),
696            connection_no,
697            self.node.request_rules().clone(),
698            connection_close_tx.clone(),
699            event_register_flag.clone(),
700            compression_writer_request_processor,
701        ));
702        if let InternalNode::Real {
703            ref response_rules, ..
704        } = self.node
705        {
706            let (cluster_read, cluster_write) = cluster_stream.unwrap().into_split();
707            tokio::task::spawn(new_worker().sender_to_cluster(
708                cluster_write,
709                rx_cluster,
710                connection_close_tx.subscribe(),
711                self.terminate_signaler.subscribe(),
712                compression_reader_sender_to_cluster,
713            ));
714            tokio::task::spawn(new_worker().receiver_from_cluster(
715                cluster_read,
716                tx_response,
717                compression_reader_receiver_from_cluster,
718            ));
719            tokio::task::spawn(new_worker().response_processor(
720                rx_response,
721                tx_driver,
722                tx_cluster,
723                connection_no,
724                response_rules.clone(),
725                connection_close_tx.clone(),
726                event_register_flag.clone(),
727            ));
728        }
729        debug!(
730            "Doorkeeper with addr {} of node {} spawned workers.",
731            self.node.proxy_addr(),
732            DisplayableRealAddrOption(self.node.real_addr())
733        );
734    }
735
736    async fn accept_connection(
737        &mut self,
738        connection_close_tx: &ConnectionCloseSignaler,
739        connection_no: usize,
740    ) -> Result<(), DoorkeeperError> {
741        let (driver_stream, driver_addr) = self.make_driver_stream(connection_no).await?;
742        let (cluster_stream, shard) = match self.node {
743            InternalNode::Real { real_addr, .. } => {
744                let (cluster_stream, shard) =
745                    self.make_cluster_stream(driver_addr, real_addr).await?;
746                (Some(cluster_stream), shard)
747            }
748            InternalNode::Simulated { .. } => (None, None),
749        };
750
751        self.spawn_workers(
752            driver_addr,
753            connection_close_tx,
754            connection_no,
755            driver_stream,
756            cluster_stream,
757            shard,
758        )
759        .await;
760
761        Ok(())
762    }
763
764    async fn make_driver_stream(
765        &mut self,
766        connection_no: usize,
767    ) -> Result<(TcpStream, SocketAddr), DoorkeeperError> {
768        let (driver_stream, driver_addr) =
769            self.listener.accept().await.map_err(|err| {
770                DoorkeeperError::DriverConnectionAttempt(self.node.proxy_addr(), err)
771            })?;
772        info!(
773            "Connected driver from {} to {}, connection no={}.",
774            driver_addr,
775            self.node.proxy_addr(),
776            connection_no
777        );
778        Ok((driver_stream, driver_addr))
779    }
780
781    async fn make_cluster_stream(
782        &mut self,
783        driver_addr: SocketAddr,
784        real_addr: SocketAddr,
785    ) -> Result<(TcpStream, Option<TargetShard>), DoorkeeperError> {
786        let mut cluster_stream = if let Some(shards) = self.shards_count {
787            let socket = match self.node.proxy_addr().ip() {
788                std::net::IpAddr::V4(_) => TcpSocket::new_v4(),
789                std::net::IpAddr::V6(_) => TcpSocket::new_v6(),
790            }
791            .map_err(DoorkeeperError::SocketCreate)?;
792
793            let shard_preserving_addr = {
794                let mut desired_addr =
795                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), driver_addr.port());
796                while socket.bind(desired_addr).is_err() {
797                    // in search for a port that translates to the desired shard
798                    let next_port = self.next_port_to_same_shard(desired_addr.port());
799                    if next_port == driver_addr.port() {
800                        return Err(DoorkeeperError::NoMorePorts);
801                    }
802                    desired_addr.set_port(next_port);
803                }
804                desired_addr
805            };
806
807            let stream = socket.connect(real_addr).await;
808            if let Ok(ok) = &stream {
809                info!(
810                    "Connected to the cluster from {} at {}, intended shard {}.",
811                    ok.local_addr().unwrap(),
812                    real_addr,
813                    shard_preserving_addr.port() % shards
814                );
815            }
816            stream
817        } else {
818            let stream = TcpStream::connect(real_addr).await;
819            if stream.is_ok() {
820                info!("Connected to the cluster at {}.", real_addr);
821            }
822            stream
823        }
824        .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
825
826        // If ShardAwareness is aware (QueryNode or FixedNum variants) and the
827        // proxy succeeded to know the shards count (in FixedNum we get it for
828        // free, in QueryNode the initial Options query succeeded and Supported
829        // contained SCYLLA_SHARDS_NUM), then upon opening each connection to the
830        // node, the proxy issues another Options requests and acknowledges the
831        // shard it got connected to.
832        let shard = if self.shards_count.is_some() {
833            self.obtain_shard_number(real_addr, &mut cluster_stream)
834                .await?
835        } else {
836            None
837        };
838
839        Ok((cluster_stream, shard))
840    }
841
842    fn next_port_to_same_shard(&self, port: u16) -> u16 {
843        port.wrapping_add(self.shards_count.unwrap())
844    }
845
846    async fn get_supported_options(
847        connection: &mut TcpStream,
848    ) -> Result<HashMap<String, Vec<String>>, DoorkeeperError> {
849        write_frame(
850            HARDCODED_OPTIONS_PARAMS,
851            FrameOpcode::Request(RequestOpcode::Options),
852            &Bytes::new(),
853            connection,
854            &no_compression(),
855        )
856        .await
857        .map_err(DoorkeeperError::ObtainingShardNumber)?;
858
859        let supported_frame = read_response_frame(connection, &compression::no_compression())
860            .await
861            .map_err(DoorkeeperError::ObtainingShardNumberFrame)?;
862
863        let options = read_string_multimap(&mut supported_frame.body.as_ref())
864            .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;
865
866        Ok(options)
867    }
868
869    async fn obtain_shards_count(&self, real_addr: SocketAddr) -> Result<u16, DoorkeeperError> {
870        let mut connection = TcpStream::connect(real_addr)
871            .await
872            .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
873        let options = Self::get_supported_options(&mut connection).await?;
874        let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
875        let shards = match nr_shards_entry
876            .and_then(|vec| vec.first())
877            .ok_or(DoorkeeperError::ObtainingShardNumberNoShardInfo)?
878            .parse::<u16>()
879            .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)?
880        {
881            0u16 => Err(DoorkeeperError::ObtainingShardNumberGotZero),
882            num => Ok(num),
883        }?;
884        info!("Obtained shards number on node {}: {}", real_addr, shards);
885        Ok(shards)
886    }
887
888    async fn obtain_shard_number(
889        &self,
890        real_addr: SocketAddr,
891        connection: &mut TcpStream,
892    ) -> Result<Option<TargetShard>, DoorkeeperError> {
893        let options = Self::get_supported_options(connection).await?;
894        let shard_entry = options.get("SCYLLA_SHARD");
895        let shard = shard_entry
896            .and_then(|vec| vec.first())
897            .map(|s| {
898                s.parse::<u16>()
899                    .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)
900            })
901            .transpose()?;
902        info!("Connected to node {}, shard {:?}", real_addr, shard);
903        Ok(shard)
904    }
905}
906
907mod compression {
908    use std::error::Error;
909    use std::sync::{Arc, OnceLock};
910
911    use bytes::Bytes;
912    use scylla_cql::frame::frame_errors::{
913        CqlRequestSerializationError, FrameBodyExtensionsParseError,
914    };
915    use scylla_cql::frame::request::{
916        DeserializableRequest as _, RequestDeserializationError, Startup, options,
917    };
918    use scylla_cql::frame::{Compression, compress_append, decompress, flag};
919    use tracing::{error, warn};
920
    /// Errors that can occur when the proxy compresses a frame body
    /// (see `CompressionReader::maybe_compress_body`).
    #[derive(Debug, thiserror::Error)]
    pub(crate) enum CompressionError {
        /// Body Snap compression failed.
        #[error("Snap compression error: {0}")]
        SnapCompressError(Arc<dyn Error + Sync + Send>),

        /// Frame is to be compressed, but no compression was negotiated for the connection.
        #[error("Frame is to be compressed, but no compression negotiated for connection.")]
        NoCompressionNegotiated,
    }
931
    /// Shared, write-once cell with the compression setting of one connection.
    ///
    /// The cell is empty until a setting is stored; a stored `None` means that
    /// no compression is used.
    type CompressionInfo = Arc<OnceLock<Option<Compression>>>;
933
934    /// The write end of compression config for a connection.
935    ///
936    /// Used by the request processor upon STARTUP frame captured
937    /// and compression setting retrieved from it.
938    #[derive(Debug, Clone)]
939    pub(crate) struct CompressionWriter(CompressionInfo);
940    impl CompressionWriter {
941        pub(crate) fn set(
942            &self,
943            compression: Option<Compression>,
944        ) -> Result<(), Option<Compression>> {
945            self.0.set(compression)
946        }
947
948        pub(crate) fn set_from_startup(
949            &self,
950            mut body: &[u8],
951        ) -> Result<Option<Compression>, RequestDeserializationError> {
952            let startup = Startup::deserialize_with_features(&mut body, &Default::default())?;
953            let maybe_compression = startup.options.get(options::COMPRESSION);
954            let maybe_compression = maybe_compression.and_then(|compression| {
955                compression
956                    .parse::<Compression>()
957                    .inspect_err(|err| error!("STARTUP compression error: {}", err))
958                    .ok()
959            });
960            let _ = self.set(maybe_compression).inspect_err(|_| {
961                warn!("Captured second or further STARTUP frame on the same connection")
962            });
963
964            Ok(maybe_compression)
965        }
966    }
967
968    /// The read end of compression config for a connection.
969    ///
970    /// Used by frame (de)serializers.
971    #[derive(Debug, Clone)]
972    pub(crate) struct CompressionReader(CompressionInfo);
973    impl CompressionReader {
974        /// Return the compression negotiated for the connection.
975        ///
976        /// Outer Option signifies whether the negotiation took place,
977        /// inner Option is the compression (or lack of it) negotiated.
978        pub(crate) fn get(&self) -> Option<Option<Compression>> {
979            self.0.get().copied()
980        }
981
982        pub(crate) fn maybe_compress_body(
983            &self,
984            flags: u8,
985            body: &[u8],
986        ) -> Result<Option<Bytes>, CompressionError> {
987            match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
988                (true, Some(compression)) => {
989                    let mut buf = Vec::new();
990                    compress_append(body, compression, &mut buf).map_err(|err| {
991                        let CqlRequestSerializationError::SnapCompressError(err) = err else {
992                            unreachable!("BUG: compress_append returned variant different than SnapCompressError")
993                        };
994                        CompressionError::SnapCompressError(err)
995                    })?;
996                    Ok(Some(Bytes::from(buf)))
997                }
998                (true, None) => Err(CompressionError::NoCompressionNegotiated),
999                (false, _) => Ok(None),
1000            }
1001        }
1002
1003        pub(crate) fn maybe_decompress_body(
1004            &self,
1005            flags: u8,
1006            body: Bytes,
1007        ) -> Result<Bytes, FrameBodyExtensionsParseError> {
1008            match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
1009                (true, Some(compression)) => decompress(&body, compression).map(Into::into),
1010                (true, None) => Err(FrameBodyExtensionsParseError::NoCompressionNegotiated),
1011                (false, _) => Ok(body),
1012            }
1013        }
1014    }
1015
1016    pub(crate) fn make_compression_infra() -> (
1017        CompressionWriter,
1018        CompressionReader,
1019        CompressionReader,
1020        CompressionReader,
1021        CompressionReader,
1022    ) {
1023        let info = Arc::new(OnceLock::new());
1024        (
1025            CompressionWriter(info.clone()),
1026            CompressionReader(info.clone()),
1027            CompressionReader(info.clone()),
1028            CompressionReader(info.clone()),
1029            CompressionReader(info),
1030        )
1031    }
1032
1033    fn mock_compression_reader(compression: Option<Compression>) -> CompressionReader {
1034        CompressionReader(Arc::new({
1035            let once = OnceLock::new();
1036            once.set(compression).unwrap();
1037            once
1038        }))
1039    }
1040
    /// Returns a reader whose setting is fixed to "no compression".
    ///
    /// Compression explicitly turned off.
    pub(crate) fn no_compression() -> CompressionReader {
        mock_compression_reader(None)
    }
1045
    /// Returns a reader whose setting is fixed to the given `compression`.
    ///
    /// Compression explicitly turned on.
    #[cfg(test)] // Currently only used for tests.
    pub(crate) fn with_compression(compression: Compression) -> CompressionReader {
        mock_compression_reader(Some(compression))
    }
1051}
1052pub(crate) use compression::{CompressionReader, CompressionWriter};
1053
/// State shared by every kind of proxy worker task (receivers, senders and
/// frame processors); each worker is run through `run_until_interrupted`.
struct ProxyWorker {
    // Fires when the user asks the whole proxy to stop.
    terminate_notifier: TerminateNotifier,
    // Dropped when the worker exits; the proxy counts as finished once all guards are gone.
    finish_guard: FinishGuard,
    // Fires when the worker's connection is to be closed because of a rule.
    connection_close_notifier: ConnectionCloseNotifier,
    // Channel through which the worker reports its error to the proxy user.
    error_propagator: ErrorPropagator,
    // Address of the driver side of the connection.
    driver_addr: SocketAddr,
    // Address of the real node; presumably `None` when the proxy runs in dry mode - TODO confirm.
    real_addr: Option<SocketAddr>,
    // Address the proxy itself is reachable at.
    proxy_addr: SocketAddr,
    // Shard the connection to the node is served by, if known.
    shard: Option<TargetShard>,
}
1064
1065impl ProxyWorker {
1066    fn exit(self, duty: &'static str) {
1067        debug!(
1068            "Worker exits: [driver: {}, proxy: {}, node: {}, {}]::{}.",
1069            self.driver_addr,
1070            self.proxy_addr,
1071            DisplayableRealAddrOption(self.real_addr),
1072            DisplayableShard(self.shard),
1073            duty
1074        );
1075        std::mem::drop(self.finish_guard);
1076    }
1077
1078    async fn run_until_interrupted<F, Fut>(mut self, worker_name: &'static str, f: F)
1079    where
1080        F: FnOnce(SocketAddr, SocketAddr, Option<SocketAddr>) -> Fut,
1081        Fut: Future<Output = Result<(), ProxyError>>,
1082    {
1083        let fut = f(self.driver_addr, self.proxy_addr, self.real_addr);
1084
1085        tokio::select! {
1086            result = fut => {
1087                if let Err(err) = result {
1088                    // error_propagator could be a field
1089                    let _ = self.error_propagator.send(err);
1090                }
1091            }
1092            _ = self.terminate_notifier.recv() => (),
1093            _ = self.connection_close_notifier.recv() => (),
1094        }
1095        self.exit(worker_name);
1096    }
1097
1098    async fn receiver_from_driver(
1099        self,
1100        mut read_half: impl AsyncRead + Unpin,
1101        request_processor_tx: mpsc::UnboundedSender<RequestFrame>,
1102        compression: CompressionReader,
1103    ) {
1104        let shard = self.shard;
1105        self.run_until_interrupted(
1106            "receiver_from_driver",
1107            |driver_addr, proxy_addr, _real_addr| async move {
1108                loop {
1109                    let frame = frame::read_request_frame(&mut read_half, &compression)
1110                        .await
1111                        .map_err(|err| {
1112                            warn!("Request reception from {} error: {}", driver_addr, err);
1113                            WorkerError::DriverDisconnected(driver_addr)
1114                        })?;
1115
1116                    debug!(
1117                        "Intercepted Driver ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1118                        driver_addr,
1119                        proxy_addr,
1120                        DisplayableShard(shard),
1121                        &frame.opcode
1122                    );
1123                    if request_processor_tx.send(frame).is_err() {
1124                        warn!("request_processor had exited.");
1125                        return Result::<(), ProxyError>::Ok(());
1126                    }
1127                }
1128            },
1129        )
1130        .await
1131    }
1132
1133    async fn receiver_from_cluster(
1134        self,
1135        mut read_half: impl AsyncRead + Unpin,
1136        response_processor_tx: mpsc::UnboundedSender<ResponseFrame>,
1137        compression: CompressionReader,
1138    ) {
1139        let shard = self.shard;
1140        self.run_until_interrupted(
1141            "receiver_from_cluster",
1142            |driver_addr, _proxy_addr, real_addr| async move {
1143                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1144                loop {
1145                    let frame = frame::read_response_frame(&mut read_half, &compression)
1146                        .await
1147                        .map_err(|err| {
1148                            warn!("Response reception from {} error: {}", real_addr, err);
1149                            WorkerError::NodeDisconnected(real_addr)
1150                        })?;
1151
1152                    debug!(
1153                        "Intercepted Cluster ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1154                        real_addr,
1155                        DisplayableShard(shard),
1156                        driver_addr,
1157                        &frame.opcode
1158                    );
1159
1160                    if response_processor_tx.send(frame).is_err() {
1161                        warn!("response_processor had exited.");
1162                        return Ok::<(), ProxyError>(());
1163                    }
1164                }
1165            },
1166        )
1167        .await;
1168    }
1169
1170    async fn sender_to_driver(
1171        self,
1172        mut write_half: impl AsyncWrite + Unpin,
1173        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1174        mut connection_close_notifier: ConnectionCloseNotifier,
1175        mut terminate_notifier: TerminateNotifier,
1176        compression: CompressionReader,
1177    ) {
1178        let shard = self.shard;
1179        self.run_until_interrupted(
1180            "sender_to_driver",
1181            |driver_addr, proxy_addr, _real_addr| async move {
1182                loop {
1183                    let response = match responses_rx.recv().await {
1184                        Some(response) => response,
1185                        None => {
1186                            if terminate_notifier.try_recv().is_err()
1187                                && connection_close_notifier.try_recv().is_err()
1188                            {
1189                                warn!("Response processor had exited");
1190                            }
1191                            return Ok(());
1192                        }
1193                    };
1194
1195                    debug!(
1196                        "Sending Proxy ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1197                        proxy_addr,
1198                        DisplayableShard(shard),
1199                        driver_addr,
1200                        &response.opcode
1201                    );
1202                    if response.write(&mut write_half, &compression).await.is_err() {
1203                        if terminate_notifier.try_recv().is_err()
1204                            && connection_close_notifier.try_recv().is_err()
1205                        {
1206                            warn!("Driver dropped connection");
1207                            return Err(WorkerError::DriverDisconnected(driver_addr).into());
1208                        }
1209                        return Ok(());
1210                    }
1211                }
1212            },
1213        )
1214        .await;
1215    }
1216
1217    async fn sender_to_cluster(
1218        self,
1219        mut write_half: impl AsyncWrite + Unpin,
1220        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1221        mut connection_close_notifier: ConnectionCloseNotifier,
1222        mut terminate_notifier: TerminateNotifier,
1223        compression: CompressionReader,
1224    ) {
1225        let shard = self.shard;
1226        self.run_until_interrupted(
1227            "sender_to_driver",
1228            |_driver_addr, proxy_addr, real_addr| async move {
1229                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1230                loop {
1231                    let request = match requests_rx.recv().await {
1232                        Some(request) => request,
1233                        None => {
1234                            if terminate_notifier.try_recv().is_err()
1235                                && connection_close_notifier.try_recv().is_err()
1236                            {
1237                                warn!("Request processor had exited");
1238                            }
1239                            return Ok(());
1240                        }
1241                    };
1242
1243                    debug!(
1244                        "Sending Proxy ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1245                        proxy_addr,
1246                        real_addr,
1247                        DisplayableShard(shard),
1248                        &request.opcode
1249                    );
1250
1251                    if request.write(&mut write_half, &compression).await.is_err() {
1252                        if terminate_notifier.try_recv().is_err()
1253                            && connection_close_notifier.try_recv().is_err()
1254                        {
1255                            warn!("Node {} dropped connection", real_addr);
1256                            return Err(WorkerError::NodeDisconnected(real_addr).into());
1257                        }
1258                        return Ok(());
1259                    }
1260                }
1261            },
1262        )
1263        .await;
1264    }
1265
    /// Worker that applies the request rules to every request frame received
    /// over `requests_rx`.
    ///
    /// For each frame the rules are checked in order and only the first one
    /// whose condition evaluates to true is applied. Its actions run in a
    /// separately spawned task:
    /// - pass: send the (optionally processed) request to `cluster_tx`,
    /// - forge: send a frame produced from the request to `driver_tx`,
    /// - drop: signal `connection_close_signaler`.
    ///
    /// Each action may be preceded by its own delay. A frame matching no rule
    /// is sent to `cluster_tx` unchanged. The worker's job ends once
    /// `requests_rx` is closed.
    ///
    /// Additionally, a REGISTER request sets `event_registered_flag`, and a
    /// STARTUP request is used to store the connection's compression setting
    /// through `compression`.
    #[expect(clippy::too_many_arguments)]
    async fn request_processor(
        self,
        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
        connection_no: usize,
        request_rules: Arc<Mutex<Vec<RequestRule>>>,
        connection_close_signaler: ConnectionCloseSignaler,
        event_registered_flag: Arc<AtomicBool>,
        compression: CompressionWriter,
    ) {
        let shard = self.shard;
        self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
            'mainloop: loop {
                match requests_rx.recv().await {
                    Some(request) => {
                        // Track connection-level state derived from the requests themselves.
                        if request.opcode == RequestOpcode::Register {
                            event_registered_flag.store(true, Ordering::Relaxed);
                        } else if request.opcode == RequestOpcode::Startup {
                            match compression.set_from_startup(&request.body) {
                                Err(err) => error!("Failed to deserialize STARTUP frame: {}", err),
                                Ok(read_compression) => info!(
                                    "Intercepted STARTUP frame ({} -> {} ({})), so set compression accordingly to {:?}.",
                                    driver_addr,
                                    DisplayableRealAddrOption(real_addr),
                                    DisplayableShard(shard),
                                    read_compression
                                )
                            };
                        }

                        // Everything a rule condition gets to see about this frame.
                        let ctx = EvaluationContext {
                            connection_seq_no: connection_no,
                            opcode: FrameOpcode::Request(request.opcode),
                            frame_body: request.body.clone(),
                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
                        };
                        // The guard is released before the next `.await`: either by
                        // `continue 'mainloop` or at the end of this match arm.
                        let mut guard = request_rules.lock().unwrap();
                        '_ruleloop: for (i, request_rule) in guard.iter_mut().enumerate() {
                            if request_rule.0.eval(&ctx) {
                                info!("Applying rule no={} to request ({} -> {} ({})).", i, driver_addr, DisplayableRealAddrOption(real_addr), DisplayableShard(shard));
                                debug!("-> Applied rule: {:?}", request_rule);
                                debug!("-> To request: {:?}", ctx.opcode);
                                trace!("{:?}", request);

                                // Report the matched request to the rule's feedback channel, if any.
                                if let Some(ref tx) = request_rule.1.feedback_channel {
                                    tx.send((request.clone(), shard)).unwrap_or_else(|err|
                                        warn!("Could not send received request as feedback: {}", err)
                                    );
                                }

                                // Work on a copy of the rule, so that the spawned task
                                // does not borrow from the locked rule list.
                                let request_rule = request_rule.clone();
                                let to_addressee_action = request_rule.1.to_addressee;
                                let to_sender_action = request_rule.1.to_sender;
                                let drop_connection_action = request_rule.1.drop_connection;

                                // Pass action: forward the request (possibly altered) to the cluster.
                                let cluster_tx_clone = cluster_tx.clone();
                                let request_clone = request.clone();
                                let pass_action = async move {
                                    if let Some(ref pass_action) = to_addressee_action {
                                        if let Some(time) = pass_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let passed_frame = match pass_action.msg_processor {
                                            Some(ref processor) => processor(request_clone),
                                            None => request_clone,
                                        };
                                        let _ = cluster_tx_clone.send(passed_frame);
                                    };
                                };

                                // Forge action: answer the driver with a frame built from the request.
                                let driver_tx_clone = driver_tx.clone();
                                let request_clone = request.clone();
                                let forge_action = async move {
                                    if let Some(ref forge_action) = to_sender_action {
                                        if let Some(time) = forge_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let forged_frame = {
                                            let processor = forge_action.msg_processor.as_ref()
                                                .expect("Frame processor is required to forge a frame.");
                                            processor(request_clone)
                                        };
                                        let _ = driver_tx_clone.send(forged_frame);
                                    };
                                };

                                // Drop action: tell the connection's workers to stop.
                                let connection_close_signaler_clone =
                                    connection_close_signaler.clone();
                                let drop_action = async move {
                                    if let Some(ref delay) = drop_connection_action {
                                        if let Some(time) = delay {
                                            tokio::time::sleep(*time).await;
                                        }
                                        // close connection.
                                        info!(
                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
                                            driver_addr,
                                            DisplayableRealAddrOption(real_addr),
                                            DisplayableShard(shard),
                                        );
                                        let _ = connection_close_signaler_clone.send(());
                                    }
                                };

                                // Run the three actions concurrently in a detached task, so that
                                // their delays do not hold up the processing of further frames.
                                tokio::task::spawn(async {
                                    futures::join!(pass_action, forge_action, drop_action);
                                });

                                continue 'mainloop; // only one rule can be applied to one frame
                            }
                        }
                        let _ = cluster_tx.send(request); // default action
                    }
                    // The channel is closed: no more requests will arrive.
                    None => return Ok(()),
                }
            }
        })
        .await;
    }
1387
    /// Worker that applies the response rules to every response frame received
    /// over `responses_rx`.
    ///
    /// For each frame the rules are checked in order and only the first one
    /// whose condition evaluates to true is applied. Its actions run in a
    /// separately spawned task:
    /// - pass: send the (optionally processed) response to `driver_tx`,
    /// - forge: send a frame produced from the response to `cluster_tx`,
    /// - drop: signal `connection_close_signaler`.
    ///
    /// Each action may be preceded by its own delay. A frame matching no rule
    /// is sent to `driver_tx` unchanged. The worker's job ends once
    /// `responses_rx` is closed.
    #[expect(clippy::too_many_arguments)]
    async fn response_processor(
        self,
        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
        connection_no: usize,
        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
        connection_close_signaler: ConnectionCloseSignaler,
        event_registered_flag: Arc<AtomicBool>,
    ) {
        let shard = self.shard;
        self.run_until_interrupted("response_processor", |driver_addr, _, real_addr| async move {
            'mainloop: loop {
                match responses_rx.recv().await {
                    Some(response) => {
                        // Everything a rule condition gets to see about this frame.
                        let ctx = EvaluationContext {
                            connection_seq_no: connection_no,
                            opcode: FrameOpcode::Response(response.opcode),
                            frame_body: response.body.clone(),
                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
                        };
                        // The guard is released before the next `.await`: either by
                        // `continue 'mainloop` or at the end of this match arm.
                        let mut guard = response_rules.lock().unwrap();
                        '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() {
                            if response_rule.0.eval(&ctx) {
                                info!("Applying rule no={} to response ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard));
                                debug!("-> Applied rule: {:?}", response_rule);
                                debug!("-> To response: {:?}", ctx.opcode);
                                trace!("{:?}", response);

                                // Report the matched response to the rule's feedback channel, if any.
                                if let Some(ref tx) = response_rule.1.feedback_channel {
                                    tx.send((response.clone(), shard)).unwrap_or_else(|err| warn!(
                                        "Could not send received response as feedback: {}", err
                                    ));
                                }

                                // Work on a copy of the rule, so that the spawned task
                                // does not borrow from the locked rule list.
                                let response_rule = response_rule.clone();
                                let to_addressee_action = response_rule.1.to_addressee;
                                let to_sender_action = response_rule.1.to_sender;
                                let drop_connection_action = response_rule.1.drop_connection;

                                // Pass action: forward the response (possibly altered) to the driver.
                                let response_clone = response.clone();
                                let driver_tx_clone = driver_tx.clone();
                                let pass_action = async move {
                                    if let Some(ref pass_action) = to_addressee_action {
                                        if let Some(time) = pass_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let passed_frame = match pass_action.msg_processor {
                                            Some(ref processor) => processor(response_clone),
                                            None => response_clone,
                                        };
                                        let _ = driver_tx_clone.send(passed_frame);
                                    };
                                };

                                // Forge action: send the cluster a frame built from the response.
                                let response_clone = response.clone();
                                let cluster_tx_clone = cluster_tx.clone();
                                let forge_action = async move {
                                    if let Some(ref forge_action) = to_sender_action {
                                        if let Some(time) = forge_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let forged_frame = {
                                            let processor = forge_action.msg_processor.as_ref()
                                                .expect("Frame processor is required to forge a frame.");
                                            processor(response_clone)
                                        };
                                        let _ = cluster_tx_clone.send(forged_frame);
                                    };
                                };

                                // Drop action: tell the connection's workers to stop.
                                let connection_close_signaler_clone =
                                    connection_close_signaler.clone();
                                let drop_action = async move {
                                    if let Some(ref delay) = drop_connection_action {
                                        if let Some(time) = delay {
                                            tokio::time::sleep(*time).await;
                                        }
                                        // close connection.
                                        info!(
                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
                                            driver_addr,
                                            real_addr.expect("BUG: response rules are unavailable for dry-mode proxy!"),
                                            DisplayableShard(shard)
                                        );
                                        let _ = connection_close_signaler_clone.send(());
                                    }
                                };

                                // Run the three actions concurrently in a detached task, so that
                                // their delays do not hold up the processing of further frames.
                                tokio::task::spawn(async {
                                    futures::join!(pass_action, forge_action, drop_action);
                                });

                                // Only one rule can be applied to one frame.
                                continue 'mainloop;
                            }
                        }
                        let _ = driver_tx.send(response); // default action
                    }
                    // The channel is closed: no more responses will arrive.
                    None => return Ok(()),
                }
            }
        })
        .await
    }
1493}
1494
/// Returns the next free loopback IP address for another proxy instance.
///
/// Each call hands out a distinct `127.x.y.z` address derived from a process-wide
/// counter, which is useful for tests running concurrently.
///
/// # Panics
///
/// Panics when the counter no longer fits in the lower three octets of the address.
pub fn get_exclusive_local_address() -> IpAddr {
    // Starting from a big enough number reduces the possibility of clashes
    // with user-taken addresses.
    static ADDRESS_LOWER_THREE_OCTETS: AtomicU32 = AtomicU32::new(4242);

    let counter = ADDRESS_LOWER_THREE_OCTETS.fetch_add(1, Ordering::Relaxed);
    assert!(
        counter <= (u32::MAX >> 8),
        "Loopback address pool for tests depleted"
    );

    // Big-endian view of the counter: the most significant byte is unused,
    // the remaining three become the second, third and fourth octet.
    let [_, second, third, fourth] = counter.to_be_bytes();
    IpAddr::V4(Ipv4Addr::new(127, second, third, fourth))
}
1512
1513#[cfg(test)]
1514mod tests {
1515    use super::compression::no_compression;
1516    use super::*;
1517    use crate::errors::ReadFrameError;
1518    use crate::frame::{FrameType, read_frame, read_request_frame};
1519    use crate::proxy::compression::with_compression;
1520    use crate::{
1521        Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction, setup_tracing,
1522    };
1523    use assert_matches::assert_matches;
1524    use bytes::{BufMut, BytesMut};
1525    use futures::future::{join, join3};
1526    use rand::RngCore;
1527    use scylla_cql::frame::request::options;
1528    use scylla_cql::frame::request::{SerializableRequest as _, Startup};
1529    use scylla_cql::frame::types::write_string_multimap;
1530    use scylla_cql::frame::{Compression, flag};
1531    use std::collections::HashMap;
1532    use std::mem;
1533    use std::str::FromStr;
1534    use std::time::Duration;
1535    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
1536    use tokio::sync::oneshot;
1537
1538    fn random_body() -> Bytes {
1539        let body_len = (rand::random::<u32>() % 1000) as usize;
1540        let mut body = BytesMut::zeroed(body_len);
1541        rand::rng().fill_bytes(body.as_mut());
1542        body.freeze()
1543    }
1544
1545    async fn respond_with_supported(
1546        conn: &mut TcpStream,
1547        supported_options: &HashMap<String, Vec<String>>,
1548        compression: &CompressionReader,
1549    ) {
1550        let RequestFrame {
1551            params: recvd_params,
1552            opcode: recvd_opcode,
1553            body: recvd_body,
1554        } = read_request_frame(conn, compression).await.unwrap();
1555        assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS);
1556        assert_eq!(recvd_opcode, RequestOpcode::Options);
1557        assert_eq!(recvd_body, Bytes::new()); // body should be empty
1558
1559        let mut body = BytesMut::new();
1560        write_string_multimap(supported_options, &mut body).unwrap();
1561
1562        let body = body.freeze();
1563
1564        write_frame(
1565            HARDCODED_OPTIONS_PARAMS.for_response(),
1566            FrameOpcode::Response(ResponseOpcode::Supported),
1567            &body,
1568            conn,
1569            &no_compression(),
1570        )
1571        .await
1572        .unwrap();
1573    }
1574
1575    fn supported_shards_count(shards_count: u16) -> HashMap<String, Vec<String>> {
1576        let mut sharded_info = HashMap::new();
1577        sharded_info.insert(
1578            String::from("SCYLLA_NR_SHARDS"),
1579            vec![shards_count.to_string()],
1580        );
1581        sharded_info
1582    }
1583
1584    fn supported_shard_number(shard_num: TargetShard) -> HashMap<String, Vec<String>> {
1585        let mut sharded_info = HashMap::new();
1586        sharded_info.insert(String::from("SCYLLA_SHARD"), vec![shard_num.to_string()]);
1587        sharded_info
1588    }
1589
1590    async fn respond_with_shards_count(
1591        conn: &mut TcpStream,
1592        shards_count: u16,
1593        compression: &CompressionReader,
1594    ) {
1595        respond_with_supported(conn, &supported_shards_count(shards_count), compression).await;
1596    }
1597
1598    async fn respond_with_shard_num(
1599        conn: &mut TcpStream,
1600        shard_num: TargetShard,
1601        compression: &CompressionReader,
1602    ) {
1603        respond_with_supported(conn, &supported_shard_number(shard_num), compression).await;
1604    }
1605
1606    fn next_local_address_with_port(port: u16) -> SocketAddr {
1607        SocketAddr::new(get_exclusive_local_address(), port)
1608    }
1609
1610    async fn identity_proxy_does_not_mutate_frames(shard_awareness: ShardAwareness) {
1611        let node1_real_addr = next_local_address_with_port(9876);
1612        let node1_proxy_addr = next_local_address_with_port(9876);
1613        let proxy = Proxy::new([Node::new(
1614            node1_real_addr,
1615            node1_proxy_addr,
1616            shard_awareness,
1617            None,
1618            None,
1619        )]);
1620        let running_proxy = proxy.run().await.unwrap();
1621
1622        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1623
1624        let params = FrameParams {
1625            flags: 0,
1626            version: 0x04,
1627            stream: 0,
1628        };
1629        let opcode = FrameOpcode::Request(RequestOpcode::Options);
1630
1631        let body = random_body();
1632
1633        let send_frame_to_shard = async {
1634            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
1635
1636            write_frame(params, opcode, &body, &mut conn, &no_compression())
1637                .await
1638                .unwrap();
1639            conn
1640        };
1641
1642        let mock_node_action = async {
1643            if let ShardAwareness::QueryNode = shard_awareness {
1644                respond_with_shards_count(
1645                    &mut mock_node_listener.accept().await.unwrap().0,
1646                    1,
1647                    &no_compression(),
1648                )
1649                .await;
1650            }
1651            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
1652            if shard_awareness.is_aware() {
1653                respond_with_shard_num(&mut conn, 1, &no_compression()).await;
1654            }
1655            let RequestFrame {
1656                params: recvd_params,
1657                opcode: recvd_opcode,
1658                body: recvd_body,
1659            } = read_request_frame(&mut conn, &no_compression())
1660                .await
1661                .unwrap();
1662            assert_eq!(recvd_params, params);
1663            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1664            assert_eq!(recvd_body, body);
1665            conn
1666        };
1667
1668        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
1669        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1670        running_proxy.finish().await.unwrap();
1671    }
1672
1673    #[tokio::test]
1674    #[ntest::timeout(1000)]
1675    async fn identity_shard_unaware_proxy_does_not_mutate_frames() {
1676        setup_tracing();
1677        identity_proxy_does_not_mutate_frames(ShardAwareness::Unaware).await
1678    }
1679
1680    #[tokio::test]
1681    #[ntest::timeout(1000)]
1682    async fn identity_shard_aware_proxy_does_not_mutate_frames() {
1683        setup_tracing();
1684        identity_proxy_does_not_mutate_frames(ShardAwareness::QueryNode).await
1685    }
1686
1687    #[tokio::test]
1688    #[ntest::timeout(1000)]
1689    async fn shard_aware_proxy_is_transparent_for_connection_to_shards() {
1690        setup_tracing();
1691        async fn test_for_shards_num(shards_num: u16) {
1692            let node1_real_addr = next_local_address_with_port(9876);
1693            let node1_proxy_addr = next_local_address_with_port(9876);
1694            let proxy = Proxy::new([Node::new(
1695                node1_real_addr,
1696                node1_proxy_addr,
1697                ShardAwareness::FixedNum(shards_num),
1698                None,
1699                None,
1700            )]);
1701            let running_proxy = proxy.run().await.unwrap();
1702
1703            let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1704
1705            let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();
1706
1707            let send_frame_to_shard = async {
1708                let socket = TcpSocket::new_v4().unwrap();
1709                socket
1710                    .bind(SocketAddr::from_str("0.0.0.0:0").unwrap())
1711                    .unwrap();
1712                let conn = socket.connect(node1_proxy_addr).await.unwrap();
1713                driver_addr_tx.send(conn.local_addr().unwrap()).unwrap();
1714                conn
1715            };
1716
1717            let mock_node_action = async {
1718                let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
1719                let driver_addr = driver_addr_rx.await.unwrap();
1720                assert_eq!(
1721                    driver_addr.port() % shards_num,
1722                    remote_addr.port() % shards_num
1723                );
1724                conn
1725            };
1726
1727            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
1728            let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1729            running_proxy.finish().await.unwrap();
1730        }
1731
1732        for shard_num in 1..6 {
1733            test_for_shards_num(shard_num).await;
1734        }
1735    }
1736
1737    #[tokio::test]
1738    #[ntest::timeout(1000)]
1739    async fn shard_aware_proxy_queries_shards_number() {
1740        setup_tracing();
1741        async fn test_for_shards_num(shards_num: u16) {
1742            for shard_num in 0..shards_num {
1743                let node1_real_addr = next_local_address_with_port(9876);
1744                let node1_proxy_addr = next_local_address_with_port(9876);
1745                let proxy = Proxy::new([Node::new(
1746                    node1_real_addr,
1747                    node1_proxy_addr,
1748                    ShardAwareness::QueryNode,
1749                    None,
1750                    None,
1751                )]);
1752                let running_proxy = proxy.run().await.unwrap();
1753
1754                let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1755
1756                let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();
1757
1758                let mock_driver_addr = next_local_address_with_port(shards_num * 1234 + shard_num);
1759                let send_frame_to_shard = async {
1760                    let socket = TcpSocket::new_v4().unwrap();
1761                    socket
1762                        .bind(mock_driver_addr)
1763                        .unwrap_or_else(|_| panic!("driver_addr failed: {mock_driver_addr}"));
1764                    driver_addr_tx.send(socket.local_addr().unwrap()).unwrap();
1765                    socket.connect(node1_proxy_addr).await.unwrap()
1766                };
1767
1768                let mock_node_action = async {
1769                    respond_with_shards_count(
1770                        &mut mock_node_listener.accept().await.unwrap().0,
1771                        shards_num,
1772                        &no_compression(),
1773                    )
1774                    .await;
1775                    let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
1776                    let driver_addr = driver_addr_rx.await.unwrap();
1777                    assert_eq!(
1778                        driver_addr.port() % shards_num,
1779                        remote_addr.port() % shards_num
1780                    );
1781                    conn
1782                };
1783
1784                let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1785                running_proxy.finish().await.unwrap();
1786            }
1787        }
1788
1789        for shard_num in 1..6 {
1790            test_for_shards_num(shard_num).await;
1791        }
1792    }
1793
1794    #[tokio::test]
1795    #[ntest::timeout(1000)]
1796    async fn forger_proxy_forges_response() {
1797        setup_tracing();
1798        let node1_real_addr = next_local_address_with_port(9876);
1799        let node1_proxy_addr = next_local_address_with_port(9876);
1800
1801        let this_shall_pass = b"This.Shall.Pass.";
1802        let test_msg = b"Test";
1803
1804        let proxy = Proxy::new([Node::new(
1805            node1_real_addr,
1806            node1_proxy_addr,
1807            ShardAwareness::Unaware,
1808            Some(vec![
1809                RequestRule(
1810                    Condition::RequestOpcode(RequestOpcode::Register),
1811                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
1812                        ResponseFrame {
1813                            params: params.for_response(),
1814                            opcode: ResponseOpcode::Event,
1815                            body: Bytes::from_static(test_msg),
1816                        }
1817                    })),
1818                ),
1819                RequestRule(
1820                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
1821                    RequestReaction::noop(),
1822                ),
1823                RequestRule(
1824                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
1825                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
1826                        ResponseFrame {
1827                            params: params.for_response(),
1828                            opcode: ResponseOpcode::Ready,
1829                            body: Bytes::new(),
1830                        }
1831                    })),
1832                ),
1833            ]),
1834            None,
1835        )]);
1836        let running_proxy = proxy.run().await.unwrap();
1837
1838        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1839
1840        let params1 = FrameParams {
1841            flags: 2,
1842            version: 0x42,
1843            stream: 42,
1844        };
1845        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);
1846
1847        let params2 = FrameParams {
1848            flags: 4,
1849            version: 0x04,
1850            stream: 17,
1851        };
1852        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);
1853
1854        let params3 = FrameParams {
1855            flags: 8,
1856            version: 0x04,
1857            stream: 11,
1858        };
1859        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);
1860
1861        let body1 = random_body();
1862        let body2 = random_body();
1863        let body3 = {
1864            let mut body = BytesMut::new();
1865            body.put(&b"uSeLeSs JuNk"[..]);
1866            body.put(&this_shall_pass[..]);
1867            body.freeze()
1868        };
1869
1870        let send_frame_to_shard = async {
1871            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
1872
1873            write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
1874                .await
1875                .unwrap();
1876            write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
1877                .await
1878                .unwrap();
1879            write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
1880                .await
1881                .unwrap();
1882
1883            let ResponseFrame {
1884                params: recvd_params,
1885                opcode: recvd_opcode,
1886                body: recvd_body,
1887            } = read_response_frame(&mut conn, &no_compression())
1888                .await
1889                .unwrap();
1890            assert_eq!(recvd_params, params1.for_response());
1891            assert_eq!(recvd_opcode, ResponseOpcode::Ready);
1892            assert_eq!(recvd_body, Bytes::new());
1893
1894            let ResponseFrame {
1895                params: recvd_params,
1896                opcode: recvd_opcode,
1897                body: recvd_body,
1898            } = read_response_frame(&mut conn, &no_compression())
1899                .await
1900                .unwrap();
1901            assert_eq!(recvd_params, params2.for_response());
1902            assert_eq!(recvd_opcode, ResponseOpcode::Event);
1903            assert_eq!(recvd_body, Bytes::from_static(test_msg));
1904
1905            conn
1906        };
1907
1908        let mock_node_action = async {
1909            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
1910            let RequestFrame {
1911                params: recvd_params,
1912                opcode: recvd_opcode,
1913                body: recvd_body,
1914            } = read_request_frame(&mut conn, &no_compression())
1915                .await
1916                .unwrap();
1917            assert_eq!(recvd_params, params3);
1918            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3);
1919            assert_eq!(recvd_body, body3);
1920
1921            conn
1922        };
1923
1924        let (mut node_conn, mut driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1925
1926        running_proxy.finish().await.unwrap();
1927
1928        assert_matches!(driver_conn.read(&mut [0u8; 1]).await, Ok(0));
1929        assert_matches!(node_conn.read(&mut [0u8; 1]).await, Ok(0));
1930    }
1931
1932    #[tokio::test]
1933    #[ntest::timeout(1000)]
1934    async fn ad_hoc_rules_changing() {
1935        setup_tracing();
1936        let node1_real_addr = next_local_address_with_port(9876);
1937        let node1_proxy_addr = next_local_address_with_port(9876);
1938        let proxy = Proxy::new([Node::new(
1939            node1_real_addr,
1940            node1_proxy_addr,
1941            ShardAwareness::Unaware,
1942            None,
1943            None,
1944        )]);
1945        let mut running_proxy = proxy.run().await.unwrap();
1946
1947        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1948
1949        let params = FrameParams {
1950            flags: 0,
1951            version: 0x04,
1952            stream: 0,
1953        };
1954        let opcode = FrameOpcode::Request(RequestOpcode::Options);
1955
1956        let body = random_body();
1957
1958        let (mut driver, mut node) = {
1959            let results = join(
1960                TcpStream::connect(node1_proxy_addr),
1961                mock_node_listener.accept(),
1962            )
1963            .await;
1964            (results.0.unwrap(), results.1.unwrap().0)
1965        };
1966
1967        async fn request(
1968            driver: &mut TcpStream,
1969            node: &mut TcpStream,
1970            params: FrameParams,
1971            opcode: FrameOpcode,
1972            body: &Bytes,
1973        ) -> Result<RequestFrame, ReadFrameError> {
1974            let (send_res, recv_res) = join(
1975                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
1976                read_request_frame(node, &no_compression()),
1977            )
1978            .await;
1979            send_res.unwrap();
1980            recv_res
1981        }
1982        {
1983            // one run still without custom rules
1984            let RequestFrame {
1985                params: recvd_params,
1986                opcode: recvd_opcode,
1987                body: recvd_body,
1988            } = request(&mut driver, &mut node, params, opcode, &body)
1989                .await
1990                .unwrap();
1991            assert_eq!(recvd_params, params);
1992            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1993            assert_eq!(recvd_body, body);
1994        }
1995        running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule(
1996            Condition::True,
1997            RequestReaction::drop_frame(),
1998        )]));
1999
2000        {
2001            // one run with custom rules
2002            tokio::select! {
2003                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
2004                _ = tokio::time::sleep(std::time::Duration::from_millis(20)) => (),
2005            };
2006        }
2007
2008        running_proxy.turn_off_rules();
2009
2010        {
2011            // one run already without custom rules
2012            let RequestFrame {
2013                params: recvd_params,
2014                opcode: recvd_opcode,
2015                body: recvd_body,
2016            } = request(&mut driver, &mut node, params, opcode, &body)
2017                .await
2018                .unwrap();
2019            assert_eq!(recvd_params, params);
2020            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
2021            assert_eq!(recvd_body, body);
2022        }
2023
2024        running_proxy.finish().await.unwrap();
2025    }
2026
2027    #[tokio::test]
2028    #[ntest::timeout(2000)]
2029    async fn limited_times_condition_expires() {
2030        setup_tracing();
2031        const FAILING_TRIES: usize = 4;
2032        const PASSING_TRIES: usize = 5;
2033
2034        let node1_real_addr = next_local_address_with_port(9876);
2035        let node1_proxy_addr = next_local_address_with_port(9876);
2036        let proxy = Proxy::new([Node::new(
2037            node1_real_addr,
2038            node1_proxy_addr,
2039            ShardAwareness::Unaware,
2040            Some(vec![
2041                RequestRule(
2042                    // this will be always fired after first PASSING_TRIES + FAILING_TRIES
2043                    Condition::not(Condition::TrueForLimitedTimes(
2044                        FAILING_TRIES + PASSING_TRIES,
2045                    )),
2046                    RequestReaction::drop_frame(),
2047                ),
2048                RequestRule(
2049                    // this will be fired for PASSING_TRIES after first FAILING_TRIES
2050                    Condition::not(Condition::TrueForLimitedTimes(FAILING_TRIES)),
2051                    RequestReaction::noop(),
2052                ),
2053                RequestRule(
2054                    // this will be fired for first FAILING_TRIES
2055                    Condition::True,
2056                    RequestReaction::drop_frame(),
2057                ),
2058            ]),
2059            None,
2060        )]);
2061        let running_proxy = proxy.run().await.unwrap();
2062
2063        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2064
2065        let params = FrameParams {
2066            flags: 0,
2067            version: 0x04,
2068            stream: 0,
2069        };
2070        let opcode = FrameOpcode::Request(RequestOpcode::Options);
2071        let body = random_body();
2072
2073        let (mut driver, mut node) = {
2074            let results = join(
2075                TcpStream::connect(node1_proxy_addr),
2076                mock_node_listener.accept(),
2077            )
2078            .await;
2079            (results.0.unwrap(), results.1.unwrap().0)
2080        };
2081
2082        async fn request(
2083            driver: &mut TcpStream,
2084            node: &mut TcpStream,
2085            params: FrameParams,
2086            opcode: FrameOpcode,
2087            body: &Bytes,
2088        ) -> Result<RequestFrame, ReadFrameError> {
2089            let (send_res, recv_res) = join(
2090                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
2091                read_request_frame(node, &no_compression()),
2092            )
2093            .await;
2094            send_res.unwrap();
2095            recv_res
2096        }
2097
2098        for _ in 0..FAILING_TRIES {
2099            tokio::select! {
2100                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
2101                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
2102            };
2103        }
2104
2105        for _ in 0..PASSING_TRIES {
2106            let RequestFrame {
2107                params: recvd_params,
2108                opcode: recvd_opcode,
2109                body: recvd_body,
2110            } = request(&mut driver, &mut node, params, opcode, &body)
2111                .await
2112                .unwrap();
2113            assert_eq!(recvd_params, params);
2114            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
2115            assert_eq!(recvd_body, body);
2116        }
2117
2118        for _ in 0..3 {
2119            // any further number of requests should fail
2120            tokio::select! {
2121                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
2122                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
2123            };
2124        }
2125
2126        running_proxy.finish().await.unwrap();
2127    }
2128
2129    #[tokio::test]
2130    #[ntest::timeout(1000)]
2131    async fn proxy_reports_requests_and_responses_as_feedback() {
2132        setup_tracing();
2133        let node1_real_addr = next_local_address_with_port(9876);
2134        let node1_proxy_addr = next_local_address_with_port(9876);
2135
2136        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2137        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2138        let proxy = Proxy::new([Node::new(
2139            node1_real_addr,
2140            node1_proxy_addr,
2141            ShardAwareness::Unaware,
2142            Some(vec![RequestRule(
2143                Condition::True,
2144                RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2145            )]),
2146            Some(vec![ResponseRule(
2147                Condition::True,
2148                ResponseReaction::drop_frame().with_feedback_when_performed(response_feedback_tx),
2149            )]),
2150        )]);
2151        let running_proxy = proxy.run().await.unwrap();
2152
2153        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2154
2155        let params = FrameParams {
2156            flags: 0,
2157            version: 0x04,
2158            stream: 0,
2159        };
2160        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2161        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2162
2163        let body = random_body();
2164
2165        let send_frame_to_shard = async {
2166            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2167            write_frame(params, request_opcode, &body, &mut conn, &no_compression())
2168                .await
2169                .unwrap();
2170            conn
2171        };
2172
2173        let mock_node_action = async {
2174            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2175            write_frame(
2176                params.for_response(),
2177                response_opcode,
2178                &body,
2179                &mut conn,
2180                &no_compression(),
2181            )
2182            .await
2183            .unwrap();
2184            conn
2185        };
2186
2187        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2188        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
2189
2190        let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2191        assert_eq!(feedback_request.params, params);
2192        assert_eq!(
2193            FrameOpcode::Request(feedback_request.opcode),
2194            request_opcode
2195        );
2196        assert_eq!(feedback_request.body, body);
2197        let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2198        assert_eq!(feedback_response.params, params.for_response());
2199        assert_eq!(
2200            FrameOpcode::Response(feedback_response.opcode),
2201            response_opcode
2202        );
2203        assert_eq!(feedback_response.body, body);
2204
2205        running_proxy.finish().await.unwrap();
2206    }
2207
2208    #[tokio::test]
2209    #[ntest::timeout(1000)]
2210    async fn sanity_check_reports_errors() {
2211        setup_tracing();
2212        let node1_real_addr = next_local_address_with_port(9876);
2213        let node1_proxy_addr = next_local_address_with_port(9876);
2214        let proxy = Proxy::new([Node::new(
2215            node1_real_addr,
2216            node1_proxy_addr,
2217            ShardAwareness::Unaware,
2218            None,
2219            None,
2220        )]);
2221        let mut running_proxy = proxy.run().await.unwrap();
2222
2223        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2224
2225        let send_frame_to_shard = async {
2226            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2227
2228            conn.write_all(b"uselessJunk").await.unwrap();
2229            conn
2230        };
2231
2232        let mock_node_action = async {
2233            let (conn, _) = mock_node_listener.accept().await.unwrap();
2234            conn
2235        };
2236
2237        let (node_conn, driver_conn) = join(mock_node_action, send_frame_to_shard).await;
2238
2239        running_proxy.sanity_check().unwrap();
2240
2241        mem::drop(driver_conn);
2242        assert_matches!(
2243            running_proxy.wait_for_error().await,
2244            Some(ProxyError::Worker(WorkerError::DriverDisconnected(_)))
2245        );
2246        running_proxy.sanity_check().unwrap();
2247
2248        mem::drop(node_conn);
2249        assert_matches!(
2250            running_proxy.wait_for_error().await,
2251            Some(ProxyError::Worker(WorkerError::NodeDisconnected(_)))
2252        );
2253        running_proxy.sanity_check().unwrap();
2254
2255        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2256        let _ = running_proxy.finish().await;
2257    }
2258
2259    #[tokio::test]
2260    #[ntest::timeout(1000)]
2261    async fn proxy_processes_requests_concurrently() {
2262        setup_tracing();
2263        let node1_real_addr = next_local_address_with_port(9876);
2264        let node1_proxy_addr = next_local_address_with_port(9876);
2265
2266        let delay = Duration::from_millis(60);
2267
2268        let proxy = Proxy::new([Node::new(
2269            node1_real_addr,
2270            node1_proxy_addr,
2271            ShardAwareness::Unaware,
2272            Some(vec![RequestRule(
2273                Condition::TrueForLimitedTimes(1),
2274                RequestReaction::delay(delay),
2275            )]),
2276            None,
2277        )]);
2278        let running_proxy = proxy.run().await.unwrap();
2279
2280        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2281
2282        let params1 = FrameParams {
2283            flags: 0,
2284            version: 0x04,
2285            stream: 0,
2286        };
2287        let opcode1 = FrameOpcode::Request(RequestOpcode::Options);
2288
2289        let body1 = random_body();
2290
2291        let params2 = FrameParams {
2292            flags: 0,
2293            version: 0x04,
2294            stream: 0,
2295        };
2296        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);
2297
2298        let body2 = random_body();
2299
2300        let send_frame_to_shard = async {
2301            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2302
2303            write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
2304                .await
2305                .unwrap();
2306            write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
2307                .await
2308                .unwrap();
2309            conn
2310        };
2311
2312        let mock_node_action = async {
2313            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2314            let RequestFrame {
2315                params: recvd_params,
2316                opcode: recvd_opcode,
2317                body: recvd_body,
2318            } = read_request_frame(&mut conn, &no_compression())
2319                .await
2320                .unwrap();
2321            assert_eq!(recvd_params, params2);
2322            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2);
2323            assert_eq!(recvd_body, body2);
2324            conn
2325        };
2326
2327        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2328        let (_node_conn, _driver_conn) =
2329            tokio::time::timeout(delay, join(mock_node_action, send_frame_to_shard))
2330                .await
2331                .expect("Request processing was not concurrent");
2332        running_proxy.finish().await.unwrap();
2333    }
2334
2335    #[tokio::test]
2336    #[ntest::timeout(1000)]
2337    async fn dry_mode_proxy_drops_incoming_frames() {
2338        setup_tracing();
2339        let node1_proxy_addr = next_local_address_with_port(9876);
2340        let proxy = Proxy::new([Node::new_dry_mode(node1_proxy_addr, None)]);
2341        let running_proxy = proxy.run().await.unwrap();
2342
2343        let params = FrameParams {
2344            flags: 0,
2345            version: 0x04,
2346            stream: 0,
2347        };
2348        let opcode = FrameOpcode::Request(RequestOpcode::Options);
2349
2350        let body = random_body();
2351
2352        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2353
2354        write_frame(params, opcode, &body, &mut conn, &no_compression())
2355            .await
2356            .unwrap();
2357        // We assert that after sufficiently long time, no error happens inside proxy.
2358        tokio::time::sleep(Duration::from_millis(3)).await;
2359        running_proxy.finish().await.unwrap();
2360    }
2361
2362    #[tokio::test]
2363    #[ntest::timeout(1000)]
2364    async fn dry_mode_forger_proxy_forges_response() {
2365        setup_tracing();
2366        let node1_proxy_addr = next_local_address_with_port(9876);
2367
2368        let this_shall_pass = b"This.Shall.Pass.";
2369        let test_msg = b"Test";
2370
2371        let proxy = Proxy::new([Node::new_dry_mode(
2372            node1_proxy_addr,
2373            Some(vec![
2374                RequestRule(
2375                    Condition::RequestOpcode(RequestOpcode::Register),
2376                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
2377                        ResponseFrame {
2378                            params: params.for_response(),
2379                            opcode: ResponseOpcode::Event,
2380                            body: Bytes::from_static(test_msg),
2381                        }
2382                    })),
2383                ),
2384                RequestRule(
2385                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
2386                    RequestReaction::noop(),
2387                ),
2388                RequestRule(
2389                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
2390                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
2391                        ResponseFrame {
2392                            params: params.for_response(),
2393                            opcode: ResponseOpcode::Ready,
2394                            body: Bytes::new(),
2395                        }
2396                    })),
2397                ),
2398            ]),
2399        )]);
2400        let running_proxy = proxy.run().await.unwrap();
2401
2402        let params1 = FrameParams {
2403            flags: 2,
2404            version: 0x42,
2405            stream: 42,
2406        };
2407        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);
2408
2409        let params2 = FrameParams {
2410            flags: 4,
2411            version: 0x04,
2412            stream: 17,
2413        };
2414        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);
2415
2416        let params3 = FrameParams {
2417            flags: 8,
2418            version: 0x04,
2419            stream: 11,
2420        };
2421        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);
2422
2423        let body1 = random_body();
2424        let body2 = random_body();
2425        let body3 = {
2426            let mut body = BytesMut::new();
2427            body.put(&b"uSeLeSs JuNk"[..]);
2428            body.put(&this_shall_pass[..]);
2429            body.freeze()
2430        };
2431
2432        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2433
2434        write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
2435            .await
2436            .unwrap();
2437        write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
2438            .await
2439            .unwrap();
2440        write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
2441            .await
2442            .unwrap();
2443
2444        let ResponseFrame {
2445            params: recvd_params,
2446            opcode: recvd_opcode,
2447            body: recvd_body,
2448        } = read_response_frame(&mut conn, &no_compression())
2449            .await
2450            .unwrap();
2451        assert_eq!(recvd_params, params1.for_response());
2452        assert_eq!(recvd_opcode, ResponseOpcode::Ready);
2453        assert_eq!(recvd_body, Bytes::new());
2454
2455        let ResponseFrame {
2456            params: recvd_params,
2457            opcode: recvd_opcode,
2458            body: recvd_body,
2459        } = read_response_frame(&mut conn, &no_compression())
2460            .await
2461            .unwrap();
2462        assert_eq!(recvd_params, params2.for_response());
2463        assert_eq!(recvd_opcode, ResponseOpcode::Event);
2464        assert_eq!(recvd_body, Bytes::from_static(test_msg));
2465
2466        running_proxy.finish().await.unwrap();
2467
2468        assert_matches!(conn.read(&mut [0u8; 1]).await, Ok(0));
2469    }
2470
2471    // The test asserts that once a (mock) driver connects to the proxy from some port,
2472    // the proxy will connect to a shard corresponding to that port and that the target
2473    // shard number will be sent through the feedback channel.
2474    #[tokio::test]
2475    #[ntest::timeout(1000)]
2476    async fn proxy_reports_target_shard_as_feedback() {
2477        setup_tracing();
2478
2479        let node_port = 10101;
2480        let node_real_addr = next_local_address_with_port(node_port);
2481        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
2482
2483        let params = FrameParams {
2484            flags: 0,
2485            version: 0x04,
2486            stream: 0,
2487        };
2488        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2489        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2490
2491        let body = random_body();
2492
2493        for shards_count in 2..9 {
2494            // Two driver connections are simulated, each to a different shard.
2495            let driver1_shard = shards_count - 1;
2496            let driver2_shard = shards_count - 2;
2497            let node_proxy_addr = next_local_address_with_port(node_port);
2498
2499            let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2500            let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2501
2502            let proxy = Proxy::new([Node::new(
2503                node_real_addr,
2504                node_proxy_addr,
2505                ShardAwareness::FixedNum(shards_count),
2506                Some(vec![RequestRule(
2507                    Condition::True,
2508                    RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2509                )]),
2510                Some(vec![ResponseRule(
2511                    Condition::True,
2512                    ResponseReaction::drop_frame()
2513                        .with_feedback_when_performed(response_feedback_tx),
2514                )]),
2515            )]);
2516            let running_proxy = proxy.run().await.unwrap();
2517
2518            /// Choose a source port `p` such that `shard == shard_of_source_port(p)`.
2519            fn draw_source_port_for_shard(shards_count: u16, shard: u16) -> u16 {
2520                assert!(shard < shards_count);
2521                49152u16.div_ceil(shards_count) * shards_count + shard
2522            }
2523
2524            async fn bind_socket_for_shard(shards_count: u16, shard: u16) -> TcpSocket {
2525                let socket = TcpSocket::new_v4().unwrap();
2526                let initial_port = draw_source_port_for_shard(shards_count, shard);
2527
2528                let mut desired_addr =
2529                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), initial_port);
2530                while socket.bind(desired_addr).is_err() {
2531                    // in search for a port that translates to the desired shard
2532                    let next_port = desired_addr.port().wrapping_add(shards_count);
2533                    if next_port == initial_port {
2534                        panic!("No more ports left");
2535                    }
2536                    desired_addr.set_port(next_port);
2537                }
2538
2539                socket
2540            }
2541
2542            let body_ref = &body;
2543            let send_frame_to_shard = |driver_shard: u16| async move {
2544                let socket = bind_socket_for_shard(shards_count, driver_shard).await;
2545                let mut conn = socket.connect(node_proxy_addr).await.unwrap();
2546
2547                write_frame(
2548                    params,
2549                    request_opcode,
2550                    body_ref,
2551                    &mut conn,
2552                    &no_compression(),
2553                )
2554                .await
2555                .unwrap();
2556                conn
2557            };
2558
2559            let mock_driver1_action = send_frame_to_shard(driver1_shard);
2560            let mock_driver2_action = send_frame_to_shard(driver2_shard);
2561
2562            // Accepts two connections and sends a response to each of them.
2563            let mock_node_action = async {
2564                let mut conns_futs = (0..2)
2565                    .map(|_| async {
2566                        let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap();
2567                        respond_with_shard_num(
2568                            &mut conn,
2569                            driver_addr.port() % shards_count,
2570                            &no_compression(),
2571                        )
2572                        .await;
2573                        write_frame(
2574                            params.for_response(),
2575                            response_opcode,
2576                            body_ref,
2577                            &mut conn,
2578                            &no_compression(),
2579                        )
2580                        .await
2581                        .unwrap();
2582                        conn
2583                    })
2584                    .collect::<Vec<_>>();
2585                let conn2 = conns_futs.pop().unwrap().await;
2586                let conn1 = conns_futs.pop().unwrap().await;
2587                (conn1, conn2)
2588            };
2589
2590            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2591            let (_node_conns, _driver1_conn, _driver2_conn) =
2592                join3(mock_node_action, mock_driver1_action, mock_driver2_action).await;
2593
2594            let assert_feedback_request = |feedback_request: RequestFrame| {
2595                assert_eq!(feedback_request.params, params);
2596                assert_eq!(
2597                    FrameOpcode::Request(feedback_request.opcode),
2598                    request_opcode
2599                );
2600                assert_eq!(feedback_request.body, body);
2601            };
2602
2603            let assert_feedback_response = |feedback_response: ResponseFrame| {
2604                assert_eq!(feedback_response.params, params.for_response());
2605                assert_eq!(
2606                    FrameOpcode::Response(feedback_response.opcode),
2607                    response_opcode
2608                );
2609                assert_eq!(feedback_response.body, body);
2610            };
2611
2612            let (feedback_request, shard1) = request_feedback_rx.recv().await.unwrap();
2613            assert_feedback_request(feedback_request);
2614            let (feedback_request, shard2) = request_feedback_rx.recv().await.unwrap();
2615            assert_feedback_request(feedback_request);
2616            let (feedback_response, shard3) = response_feedback_rx.recv().await.unwrap();
2617            assert_feedback_response(feedback_response);
2618            let (feedback_response, shard4) = response_feedback_rx.recv().await.unwrap();
2619            assert_feedback_response(feedback_response);
2620
2621            // expected: {driver1_shard request, driver1_shard response, driver2_shard request, driver2_shard response}
2622            let mut expected_shards = [driver1_shard, driver1_shard, driver2_shard, driver2_shard];
2623            expected_shards.sort_unstable();
2624
2625            let mut got_shards = [
2626                shard1.unwrap(),
2627                shard2.unwrap(),
2628                shard3.unwrap(),
2629                shard4.unwrap(),
2630            ];
2631            got_shards.sort_unstable();
2632
2633            assert_eq!(expected_shards, got_shards);
2634
2635            running_proxy.finish().await.unwrap();
2636        }
2637    }
2638
2639    #[tokio::test]
2640    #[ntest::timeout(1000)]
2641    async fn proxy_ignores_control_connection_messages() {
2642        setup_tracing();
2643        let node1_real_addr = next_local_address_with_port(9876);
2644        let node1_proxy_addr = next_local_address_with_port(9876);
2645
2646        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2647        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2648        let proxy = Proxy::new([Node::new(
2649            node1_real_addr,
2650            node1_proxy_addr,
2651            ShardAwareness::Unaware,
2652            Some(vec![RequestRule(
2653                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2654                RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2655            )]),
2656            Some(vec![ResponseRule(
2657                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2658                ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2659            )]),
2660        )]);
2661        let running_proxy = proxy.run().await.unwrap();
2662
2663        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2664
2665        let (mut client_socket, mut server_socket) = join(
2666            async { TcpStream::connect(node1_proxy_addr).await.unwrap() },
2667            async { mock_node_listener.accept().await.unwrap().0 },
2668        )
2669        .await;
2670
2671        async fn perform_reqest_response<'a>(
2672            req_opcode: RequestOpcode,
2673            resp_opcode: ResponseOpcode,
2674            client_socket_ref: &'a mut TcpStream,
2675            server_socket_ref: &'a mut TcpStream,
2676            body_base: &'a str,
2677        ) {
2678            let params = FrameParams {
2679                flags: 0,
2680                version: 0x04,
2681                stream: 0,
2682            };
2683
2684            write_frame(
2685                params,
2686                FrameOpcode::Request(req_opcode),
2687                (body_base.to_string() + "|request|").as_bytes(),
2688                client_socket_ref,
2689                &no_compression(),
2690            )
2691            .await
2692            .unwrap();
2693
2694            let received_request =
2695                read_frame(server_socket_ref, FrameType::Request, &no_compression())
2696                    .await
2697                    .unwrap();
2698            assert_eq!(received_request.1, FrameOpcode::Request(req_opcode));
2699
2700            write_frame(
2701                params.for_response(),
2702                FrameOpcode::Response(resp_opcode),
2703                (body_base.to_string() + "|response|").as_bytes(),
2704                server_socket_ref,
2705                &no_compression(),
2706            )
2707            .await
2708            .unwrap();
2709
2710            let received_response =
2711                read_frame(client_socket_ref, FrameType::Response, &no_compression())
2712                    .await
2713                    .unwrap();
2714            assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode));
2715        }
2716
2717        // Messages before REGISTER should be fed back to channels
2718        for i in 0..5 {
2719            perform_reqest_response(
2720                RequestOpcode::Query,
2721                ResponseOpcode::Result,
2722                &mut client_socket,
2723                &mut server_socket,
2724                &format!("message_before_{i}"),
2725            )
2726            .await
2727        }
2728
2729        perform_reqest_response(
2730            RequestOpcode::Register,
2731            ResponseOpcode::Result,
2732            &mut client_socket,
2733            &mut server_socket,
2734            "message_register",
2735        )
2736        .await;
2737
2738        // Messages after REGISTER should be passed through without feedback
2739        for i in 0..5 {
2740            perform_reqest_response(
2741                RequestOpcode::Query,
2742                ResponseOpcode::Result,
2743                &mut client_socket,
2744                &mut server_socket,
2745                &format!("message_after_{i}"),
2746            )
2747            .await
2748        }
2749
2750        running_proxy.finish().await.unwrap();
2751
2752        for _ in 0..5 {
2753            let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2754            assert_eq!(feedback_request.opcode, RequestOpcode::Query);
2755            let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2756            assert_eq!(feedback_response.opcode, ResponseOpcode::Result);
2757        }
2758
2759        // Response to REGISTER and further requests / responses should be ignored
2760        let _ = request_feedback_rx.try_recv().unwrap_err();
2761        let _ = response_feedback_rx.try_recv().unwrap_err();
2762    }
2763
2764    #[tokio::test]
2765    #[ntest::timeout(1000)]
2766    async fn proxy_compresses_and_decompresses_frames_iff_compression_negotiated() {
2767        setup_tracing();
2768        let node1_real_addr = next_local_address_with_port(9876);
2769        let node1_proxy_addr = next_local_address_with_port(9876);
2770
2771        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2772        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2773        let proxy = Proxy::builder()
2774            .with_node(
2775                Node::builder()
2776                    .real_address(node1_real_addr)
2777                    .proxy_address(node1_proxy_addr)
2778                    .shard_awareness(ShardAwareness::Unaware)
2779                    .request_rules(vec![RequestRule(
2780                        Condition::True,
2781                        RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2782                    )])
2783                    .response_rules(vec![ResponseRule(
2784                        Condition::True,
2785                        ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2786                    )])
2787                    .build(),
2788            )
2789            .build();
2790        let running_proxy = proxy.run().await.unwrap();
2791
2792        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2793
2794        const PARAMS_REQUEST_NO_COMPRESSION: FrameParams = FrameParams {
2795            flags: 0,
2796            version: 0x04,
2797            stream: 0,
2798        };
2799        const PARAMS_REQUEST_COMPRESSION: FrameParams = FrameParams {
2800            flags: flag::COMPRESSION,
2801            ..PARAMS_REQUEST_NO_COMPRESSION
2802        };
2803        const PARAMS_RESPONSE_NO_COMPRESSION: FrameParams =
2804            PARAMS_REQUEST_NO_COMPRESSION.for_response();
2805        const PARAMS_RESPONSE_COMPRESSION: FrameParams =
2806            PARAMS_REQUEST_NO_COMPRESSION.for_response();
2807
2808        let make_driver_conn = async { TcpStream::connect(node1_proxy_addr).await.unwrap() };
2809        let make_node_conn = async { mock_node_listener.accept().await.unwrap() };
2810
2811        let (mut driver_conn, (mut node_conn, _)) = join(make_driver_conn, make_node_conn).await;
2812
2813        /* Outline of the test:
2814         * 1. "driver" sends an uncompressed, e.g., QUERY frame, feedback returns its uncompressed body,
2815         *    and "node" receives the uncompressed frame.
2816         * 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body,
2817         *    and "driver" receives the uncompressed frame.
2818         * 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body,
2819         *    and "node" receives the uncompressed frame. This step also triggers `CompressionWriter::set()`
2820         *    in the proxy, so the associated `CompressionReader`s are notified about it (and can use
2821         *    the negotiated compression algorithm to (de)compress the frames sent in steps 4. and 5.).
2822         * 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body,
2823         *    and "node" receives the compressed frame.
2824         * 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body,
2825         *    and "driver" receives the compressed frame.
2826         */
2827
2828        // 1. "driver" sends an uncompressed, e.g., QUERY frame, feedback returns its uncompressed body,
2829        //    and "node" receives the uncompressed frame.
2830        {
2831            let sent_frame = RequestFrame {
2832                params: PARAMS_REQUEST_NO_COMPRESSION,
2833                opcode: RequestOpcode::Query,
2834                body: random_body(),
2835            };
2836
2837            sent_frame
2838                .write(&mut driver_conn, &no_compression())
2839                .await
2840                .unwrap();
2841
2842            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
2843            assert_eq!(captured_frame, sent_frame);
2844
2845            let received_frame = read_request_frame(&mut node_conn, &no_compression())
2846                .await
2847                .unwrap();
2848            assert_eq!(received_frame, sent_frame);
2849        }
2850
2851        // 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body,
2852        //    and "driver" receives the uncompressed frame.
2853        {
2854            let sent_frame = ResponseFrame {
2855                params: PARAMS_RESPONSE_NO_COMPRESSION,
2856                opcode: ResponseOpcode::Result,
2857                body: random_body(),
2858            };
2859
2860            sent_frame
2861                .write(&mut node_conn, &no_compression())
2862                .await
2863                .unwrap();
2864
2865            let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
2866            assert_eq!(captured_frame, sent_frame);
2867
2868            let received_frame = read_response_frame(&mut driver_conn, &no_compression())
2869                .await
2870                .unwrap();
2871            assert_eq!(received_frame, sent_frame);
2872        }
2873
2874        // 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body,
2875        //    and "node" receives the uncompressed frame. This step also triggers `CompressionWriter::set()`
2876        //    in the proxy, so the associated `CompressionReader`s are notified about it (and can use
2877        //    the negotiated compression algorithm to (de)compress the frames sent in steps 4. and 5.).
2878        {
2879            let startup_body = Startup {
2880                options: std::iter::once((
2881                    options::COMPRESSION.into(),
2882                    Compression::Lz4.as_str().into(),
2883                ))
2884                .collect(),
2885            }
2886            .to_bytes()
2887            .unwrap();
2888
2889            let sent_frame = RequestFrame {
2890                params: PARAMS_REQUEST_NO_COMPRESSION,
2891                opcode: RequestOpcode::Startup,
2892                body: startup_body,
2893            };
2894
2895            sent_frame
2896                .write(&mut driver_conn, &no_compression())
2897                .await
2898                .unwrap();
2899
2900            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
2901            assert_eq!(captured_frame, sent_frame);
2902
2903            let received_frame = read_request_frame(&mut node_conn, &no_compression())
2904                .await
2905                .unwrap();
2906            assert_eq!(received_frame, sent_frame);
2907        }
2908
2909        // 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body,
2910        //    and "node" receives the compressed frame.
2911        {
2912            let sent_frame = RequestFrame {
2913                params: PARAMS_REQUEST_COMPRESSION,
2914                opcode: RequestOpcode::Query,
2915                body: random_body(),
2916            };
2917
2918            sent_frame
2919                .write(&mut driver_conn, &with_compression(Compression::Lz4))
2920                .await
2921                .unwrap();
2922
2923            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
2924            assert_eq!(captured_frame, sent_frame);
2925
2926            let received_frame =
2927                read_request_frame(&mut node_conn, &with_compression(Compression::Lz4))
2928                    .await
2929                    .unwrap();
2930            assert_eq!(received_frame, sent_frame);
2931        }
2932
2933        // 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body,
2934        //    and "driver" receives the compressed frame.
2935        {
2936            let sent_frame = ResponseFrame {
2937                params: PARAMS_RESPONSE_COMPRESSION,
2938                opcode: ResponseOpcode::Result,
2939                body: random_body(),
2940            };
2941
2942            sent_frame
2943                .write(&mut node_conn, &with_compression(Compression::Lz4))
2944                .await
2945                .unwrap();
2946
2947            let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
2948            assert_eq!(captured_frame, sent_frame);
2949
2950            let received_frame =
2951                read_response_frame(&mut driver_conn, &with_compression(Compression::Lz4))
2952                    .await
2953                    .unwrap();
2954            assert_eq!(received_frame, sent_frame);
2955        }
2956
2957        running_proxy.finish().await.unwrap();
2958    }
2959}