scylla_proxy/
proxy.rs

1use crate::actions::{EvaluationContext, RequestRule, ResponseRule};
2use crate::errors::{DoorkeeperError, ProxyError, WorkerError};
3use crate::frame::{
4    self, read_response_frame, write_frame, FrameOpcode, FrameParams, RequestFrame, ResponseFrame,
5};
6use crate::{RequestOpcode, TargetShard};
7use bytes::Bytes;
8use scylla_cql::frame::types::read_string_multimap;
9use std::collections::HashMap;
10use std::fmt::Display;
11use std::future::Future;
12use std::net::{IpAddr, Ipv4Addr, SocketAddr};
13use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
14use std::sync::{Arc, Mutex};
15use tokio::io::{AsyncRead, AsyncWrite};
16use tokio::net::{TcpListener, TcpSocket, TcpStream};
17use tokio::sync::mpsc::error::TryRecvError;
18use tokio::sync::{broadcast, mpsc};
19use tracing::{debug, error, info, trace, warn};
20
// Used to notify the user that the proxy finished - this happens when all Senders are dropped.
// Each worker holds a FinishGuard clone; once the last clone is dropped,
// the FinishWaiter's recv() resolves to None, signalling full shutdown.
type FinishWaiter = mpsc::Receiver<()>;
type FinishGuard = mpsc::Sender<()>;

// Used to tell all the proxy workers to stop when the user requests that with [RunningProxy::finish()].
type TerminateNotifier = tokio::sync::broadcast::Receiver<()>;
type TerminateSignaler = tokio::sync::broadcast::Sender<()>;

// Used to tell all proxy workers working on same connection to stop when
// a rule being applied has connection drop set.
type ConnectionCloseNotifier = tokio::sync::broadcast::Receiver<()>;
type ConnectionCloseSignaler = tokio::sync::broadcast::Sender<()>;

// Used to gather errors from all proxy workers and propagate them to the proxy user,
// returning the first of them from [RunningProxy::finish()].
type ErrorPropagator = mpsc::UnboundedSender<ProxyError>;
type ErrorSink = mpsc::UnboundedReceiver<ProxyError>;

// Frame header used for the internal OPTIONS requests the proxy issues when
// querying a node for sharding information (CQL protocol v4, stream 0, no flags).
static HARDCODED_OPTIONS_PARAMS: FrameParams = FrameParams {
    flags: 0,
    version: 0x04,
    stream: 0,
};
44
/// Specifies proxy's behaviour regarding shard awareness.
#[derive(Clone, Copy, Debug)]
pub enum ShardAwareness {
    /// Acts as if the connection was made to the shard-unaware port.
    Unaware,
    /// The first time the driver attempts to connect to the particular node (through proxy),
    /// the related node is first queried on a temporary connection for its number of shards,
    /// and only then establishes another connection for the driver's real communication with the node.
    /// If the queried node does not provide sharding info (e.g. in case of a Cassandra node),
    /// then this mode behaves as Unaware.
    QueryNode,
    /// Binds to the port that is the same as the driver's port modulo the provided number of shards.
    FixedNum(u16),
}

impl ShardAwareness {
    /// Returns true for every mode except [ShardAwareness::Unaware].
    pub fn is_aware(&self) -> bool {
        match self {
            Self::Unaware => false,
            Self::QueryNode | Self::FixedNum(_) => true,
        }
    }
}
65
/// Node can be either Real (truly backed by a Scylla node) or Simulated
/// (driver believes it's real, but we merely simulate it with the proxy).
/// In Simulated mode, no node address is provided and proxy does not attempt
/// to establish connection with a Scylla node.
///
/// For Real node, all workers are created, so such frame flow is possible:
/// [driver] -> receiver_from_driver -> requests_processor -> sender_to_cluster -> [node] (ordinary request flow)
///                                        |                    /\
///     (forging response) ++--------------+      +-------------++ (forging request)
///                        \/                     |
/// [driver] <- sender_to_driver <- response_processor <- receiver_from_cluster <- [node] (ordinary response flow)
///
/// For Simulated node, it looks like this:
/// [driver] -> receiver_from_driver -> requests_processor -+
///                                                         |   (forging response)
/// [driver] <- sender_to_driver <--------------------------+
///
/// For Real node, the default reaction to a frame is to pass it to its intended addressee.
/// For Simulated node, the default reaction to a request is to drop it.
enum NodeType {
    Real {
        // Address of the actual Scylla node backing this proxy endpoint.
        real_addr: SocketAddr,
        // How (and whether) the proxy preserves shard routing for this node.
        shard_awareness: ShardAwareness,
        // Rules applied to frames flowing node -> driver; None means no rules.
        response_rules: Option<Vec<ResponseRule>>,
    },
    Simulated,
}
93
/// User-facing configuration of a single proxied node; converted into the
/// internal representation when the proxy starts.
pub struct Node {
    // Address the proxy listens on for driver connections to this node.
    proxy_addr: SocketAddr,
    // Rules applied to frames flowing driver -> node; None means no rules.
    request_rules: Option<Vec<RequestRule>>,
    // Real vs Simulated, plus the Real-only configuration.
    node_type: NodeType,
}
99
100impl Node {
101    /// Creates an abstract node that is backed by a real Scylla node.
102    pub fn new(
103        real_addr: SocketAddr,
104        proxy_addr: SocketAddr,
105        shard_awareness: ShardAwareness,
106        request_rules: Option<Vec<RequestRule>>,
107        response_rules: Option<Vec<ResponseRule>>,
108    ) -> Self {
109        Self {
110            proxy_addr,
111            request_rules,
112            node_type: NodeType::Real {
113                real_addr,
114                shard_awareness,
115                response_rules,
116            },
117        }
118    }
119
120    /// Creates a simulated node that is not backed by any real Scylla node.
121    pub fn new_dry_mode(proxy_addr: SocketAddr, request_rules: Option<Vec<RequestRule>>) -> Self {
122        Self {
123            proxy_addr,
124            request_rules,
125            node_type: NodeType::Simulated,
126        }
127    }
128
129    pub fn builder() -> NodeBuilder {
130        NodeBuilder {
131            real_addr: None,
132            proxy_addr: None,
133            shard_awareness: None,
134            request_rules: None,
135            response_rules: None,
136        }
137    }
138}
139
/// Builder for [Node]; obtained via [Node::builder].
/// `real_addr` and `shard_awareness` matter only when building a real node;
/// `proxy_addr` is required by both build methods.
pub struct NodeBuilder {
    real_addr: Option<SocketAddr>,
    proxy_addr: Option<SocketAddr>,
    shard_awareness: Option<ShardAwareness>,
    request_rules: Option<Vec<RequestRule>>,
    response_rules: Option<Vec<ResponseRule>>,
}
147
impl NodeBuilder {
    /// Sets the real node's address (required for [NodeBuilder::build]).
    pub fn real_address(mut self, real_addr: SocketAddr) -> Self {
        self.real_addr = Some(real_addr);
        self
    }

    /// Sets the address the proxy listens on (required for both build methods).
    pub fn proxy_address(mut self, proxy_addr: SocketAddr) -> Self {
        self.proxy_addr = Some(proxy_addr);
        self
    }

    /// Sets the shard-awareness mode (required for [NodeBuilder::build]).
    pub fn shard_awareness(mut self, shard_awareness: ShardAwareness) -> Self {
        self.shard_awareness = Some(shard_awareness);
        self
    }

    /// Sets rules applied to frames flowing driver -> node.
    pub fn request_rules(mut self, request_rules: Vec<RequestRule>) -> Self {
        self.request_rules = Some(request_rules);
        self
    }

    /// Sets rules applied to frames flowing node -> driver (real nodes only).
    pub fn response_rules(mut self, response_rules: Vec<ResponseRule>) -> Self {
        self.response_rules = Some(response_rules);
        self
    }

    /// Creates an abstract node that is backed by a real Scylla node.
    ///
    /// # Panics
    /// Panics if `proxy_addr`, `real_addr` or `shard_awareness` was not set.
    pub fn build(self) -> Node {
        Node {
            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
            request_rules: self.request_rules,
            node_type: NodeType::Real {
                real_addr: self.real_addr.expect("Real addr is required!"),
                shard_awareness: self.shard_awareness.expect("Shard awareness is required!"),
                response_rules: self.response_rules,
            },
        }
    }

    /// Creates a simulated node that is not backed by any real Scylla node.
    ///
    /// # Panics
    /// Panics if `proxy_addr` was not set.
    pub fn build_dry_mode(self) -> Node {
        Node {
            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
            request_rules: self.request_rules,
            node_type: NodeType::Simulated,
        }
    }
}
196
// Logging helper: prints the real address, or "<dry mode>" for simulated nodes.
#[derive(Clone, Copy)]
struct DisplayableRealAddrOption(Option<SocketAddr>);

impl Display for DisplayableRealAddrOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            Some(addr) => write!(f, "{}", addr),
            None => write!(f, "<dry mode>"),
        }
    }
}
208
209#[derive(Clone, Copy)]
210struct DisplayableShard(Option<TargetShard>);
211impl Display for DisplayableShard {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        if let Some(shard) = self.0 {
214            write!(f, "shard {}", shard)
215        } else {
216            write!(f, "unknown shard")
217        }
218    }
219}
220
/// Internal, runtime representation of a [Node]: rule lists are wrapped in
/// `Arc<Mutex<..>>` so they can be shared with [RunningNode] handles and
/// replaced while the proxy is running.
enum InternalNode {
    Real {
        real_addr: SocketAddr,
        proxy_addr: SocketAddr,
        shard_awareness: ShardAwareness,
        request_rules: Arc<Mutex<Vec<RequestRule>>>,
        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
    },
    Simulated {
        proxy_addr: SocketAddr,
        request_rules: Arc<Mutex<Vec<RequestRule>>>,
    },
}
234
235impl InternalNode {
236    fn proxy_addr(&self) -> SocketAddr {
237        match *self {
238            InternalNode::Real { proxy_addr, .. } => proxy_addr,
239            InternalNode::Simulated { proxy_addr, .. } => proxy_addr,
240        }
241    }
242    fn real_addr(&self) -> Option<SocketAddr> {
243        match *self {
244            InternalNode::Real { real_addr, .. } => Some(real_addr),
245            InternalNode::Simulated { .. } => None,
246        }
247    }
248    fn request_rules(&self) -> &Arc<Mutex<Vec<RequestRule>>> {
249        match self {
250            InternalNode::Real { request_rules, .. } => request_rules,
251            InternalNode::Simulated { request_rules, .. } => request_rules,
252        }
253    }
254}
255
256impl From<Node> for InternalNode {
257    fn from(node: Node) -> Self {
258        match node.node_type {
259            NodeType::Real {
260                real_addr,
261                shard_awareness,
262                response_rules,
263            } => InternalNode::Real {
264                real_addr,
265                proxy_addr: node.proxy_addr,
266                shard_awareness,
267                request_rules: node
268                    .request_rules
269                    .map(|rules| Arc::new(Mutex::new(rules)))
270                    .unwrap_or_default(),
271                response_rules: response_rules
272                    .map(|rules| Arc::new(Mutex::new(rules)))
273                    .unwrap_or_default(),
274            },
275            NodeType::Simulated => InternalNode::Simulated {
276                proxy_addr: node.proxy_addr,
277                request_rules: node
278                    .request_rules
279                    .map(|rules| Arc::new(Mutex::new(rules)))
280                    .unwrap_or_default(),
281            },
282        }
283    }
284}
285
286pub struct ProxyBuilder {
287    nodes: Vec<Node>,
288}
289
290impl ProxyBuilder {
291    pub fn with_node(mut self, node: Node) -> ProxyBuilder {
292        self.nodes.push(node);
293        self
294    }
295
296    pub fn build(self) -> Proxy {
297        Proxy::new(self.nodes)
298    }
299}
300
/// A configured proxy: one internal node per proxied endpoint.
/// Created via [Proxy::new] or [Proxy::builder]; started with [Proxy::run].
pub struct Proxy {
    nodes: Vec<InternalNode>,
}
304
305impl Proxy {
306    pub fn new(nodes: impl IntoIterator<Item = Node>) -> Self {
307        Proxy {
308            nodes: nodes.into_iter().map(|node| node.into()).collect(),
309        }
310    }
311
312    pub fn builder() -> ProxyBuilder {
313        ProxyBuilder { nodes: vec![] }
314    }
315
316    /// Build a translation map based on provided proxy and node addresses.
317    /// The map can be passed to `Session` `address_translator()` to ensure
318    /// that the driver contacts the nodes through the proxy (and not directly).
319    pub fn translation_map(&self) -> HashMap<SocketAddr, SocketAddr> {
320        let mut translation_map = HashMap::new();
321        for node in self.nodes.iter() {
322            if let &InternalNode::Real {
323                real_addr,
324                proxy_addr,
325                ..
326            } = node
327            {
328                translation_map.insert(real_addr, proxy_addr);
329                let shard_aware_real_addr = SocketAddr::new(real_addr.ip(), 19042);
330                translation_map.insert(shard_aware_real_addr, proxy_addr);
331            }
332        }
333        translation_map
334    }
335
336    /// Runs the [Proxy], i.e. makes it ready for accepting drivers' connections.
337    /// Returns a [RunningProxy] handle that can be used to stop the proxy or change the rules.
338    pub async fn run(self) -> Result<RunningProxy, DoorkeeperError> {
339        let (terminate_signaler, _t) = tokio::sync::broadcast::channel(1);
340        let (finish_guard, finish_waiter) = mpsc::channel(1);
341
342        let (error_propagator, error_sink) = mpsc::unbounded_channel();
343        let (doorkeepers, running_nodes): (Vec<_>, Vec<RunningNode>) = self
344            .nodes
345            .into_iter()
346            .map(|node| {
347                let running = {
348                    let (request_rules, response_rules) = match node {
349                        InternalNode::Real {
350                            ref request_rules,
351                            ref response_rules,
352                            ..
353                        } => (request_rules, Some(response_rules)),
354                        InternalNode::Simulated {
355                            ref request_rules, ..
356                        } => (request_rules, None),
357                    };
358                    RunningNode {
359                        request_rules: request_rules.clone(),
360                        response_rules: response_rules.cloned(),
361                    }
362                };
363                (
364                    Doorkeeper::spawn(
365                        node,
366                        terminate_signaler.clone(),
367                        finish_guard.clone(),
368                        error_propagator.clone(),
369                    ),
370                    running,
371                )
372            })
373            .unzip();
374
375        for doorkeeper in doorkeepers {
376            doorkeeper.await?; // await doorkeeper creation, including binding to a socket
377        }
378
379        Ok(RunningProxy {
380            terminate_signaler,
381            finish_waiter,
382            running_nodes,
383            error_sink,
384        })
385    }
386}
387
/// A handle that can be used to change the rules regarding the particular node.
pub struct RunningNode {
    // Shared with the node's request processor; replacing the contents changes behavior live.
    request_rules: Arc<Mutex<Vec<RequestRule>>>,
    // None for simulated nodes, which have no response flow.
    response_rules: Option<Arc<Mutex<Vec<ResponseRule>>>>,
}
393
394impl RunningNode {
395    /// Replaces the previous request rules with the new ones.
396    pub fn change_request_rules(&mut self, rules: Option<Vec<RequestRule>>) {
397        *self.request_rules.lock().unwrap() = rules.unwrap_or_default();
398    }
399
400    /// Replaces the previous response rules with the new ones.
401    pub fn change_response_rules(&mut self, rules: Option<Vec<ResponseRule>>) {
402        *self
403            .response_rules
404            .as_ref()
405            .expect("No response rules on a simulated node!")
406            .lock()
407            .unwrap() = rules.unwrap_or_default();
408    }
409}
410
/// A handle that can be used to stop the proxy or change the rules.
pub struct RunningProxy {
    // Broadcasts the single termination signal to every worker.
    terminate_signaler: TerminateSignaler,
    // Resolves (to None) once all workers have dropped their FinishGuards.
    finish_waiter: FinishWaiter,
    /// Per-node handles for swapping request/response rules at runtime.
    pub running_nodes: Vec<RunningNode>,
    // Receives errors propagated from workers; the first one wins in finish().
    error_sink: ErrorSink,
}
418
impl RunningProxy {
    /// Disables all the rules in the proxy, effectively making it a pass-through-only proxy.
    pub fn turn_off_rules(&mut self) {
        for (request_rules, response_rules) in self
            .running_nodes
            .iter_mut()
            .map(|node| (&node.request_rules, &node.response_rules))
        {
            request_rules.lock().unwrap().clear();
            // Simulated nodes carry no response rules (None) - skip them.
            if let Some(response_rules) = response_rules {
                response_rules.lock().unwrap().clear();
            }
        }
    }

    /// Attempts to fetch the first error that has occurred in proxy since last check.
    /// If no errors occurred, returns Ok(()).
    pub fn sanity_check(&mut self) -> Result<(), ProxyError> {
        match self.error_sink.try_recv() {
            Ok(err) => Err(err),
            Err(TryRecvError::Empty) => Ok(()),
            Err(TryRecvError::Disconnected) => {
                // As we haven't awaited finish of all workers yet, there must be a faulty case without proper error handling.
                Err(ProxyError::SanityCheckFailure)
            }
        }
    }

    /// Waits until an error occurs in proxy. Returns `None` if the proxy
    /// finishes (all error propagators are dropped) with no errors occurred.
    pub async fn wait_for_error(&mut self) -> Option<ProxyError> {
        self.error_sink.recv().await
    }

    /// Requests termination of all proxy workers and awaits its completion.
    /// Returns the first error that occurred in proxy.
    pub async fn finish(mut self) -> Result<(), ProxyError> {
        self.terminate_signaler.send(()).map_err(|err| {
            ProxyError::AwaitFinishFailure(format!(
                "Send error in terminate_signaler: {} (bug!)",
                err
            ))
        })?;
        info!("Sent finish signal to proxy workers.");

        // This to make sure that also workers not-yet-spawned when terminate signal was sent will terminate.
        std::mem::drop(self.terminate_signaler);

        // recv() yields None once every FinishGuard (one per worker) is dropped;
        // nothing is ever sent on this channel, so Some(_) is impossible.
        if self.finish_waiter.recv().await.is_some() {
            unreachable!();
        };
        info!("All workers have finished.");

        match self.error_sink.try_recv() {
            Ok(err) => Err(err),
            Err(TryRecvError::Disconnected) => Ok(()),
            Err(TryRecvError::Empty) => {
                // As we have already awaited finish of all workers, there must be a logic bug.
                unreachable!("Worker await logic bug!");
            }
        }
    }
}
481
/// A worker corresponding to a particular node. It listens in a loop for driver's connections
/// on specified proxy bind address, respects ports regarding advanced shard-awareness (if set),
/// to this end obtaining number of shards from the node (if set), then establishes connection
/// to the node, spawns workers for this connection and continues to listen.
struct Doorkeeper {
    node: InternalNode,
    // Listener bound to the node's proxy address.
    listener: TcpListener,
    terminate_signaler: TerminateSignaler,
    finish_guard: FinishGuard,
    // Number of shards of the backing node, if known (see ShardAwareness).
    shards_count: Option<u16>,
    error_propagator: ErrorPropagator,
}
494
impl Doorkeeper {
    /// Binds a TCP listener on the node's proxy address, logs the doorkeeper's
    /// configuration and spawns its accept loop ([Doorkeeper::run]) as a task.
    /// Returns only after binding succeeded, so callers know the socket is ready.
    async fn spawn(
        node: InternalNode,
        terminate_signaler: TerminateSignaler,
        finish_guard: FinishGuard,
        error_propagator: ErrorPropagator,
    ) -> Result<(), DoorkeeperError> {
        // NOTE(review): a bind failure is wrapped in DriverConnectionAttempt,
        // which reads as an accept error - confirm this variant is intended here.
        let listener = TcpListener::bind(node.proxy_addr())
            .await
            .map_err(|err| DoorkeeperError::DriverConnectionAttempt(node.proxy_addr(), err))?;

        if let InternalNode::Real {
            shard_awareness,
            real_addr,
            ..
        } = node
        {
            info!(
                "Spawned a {} doorkeeper for pair real:{} - proxy:{}.",
                if shard_awareness.is_aware() {
                    "shard-aware"
                } else {
                    "shard-unaware"
                },
                real_addr,
                node.proxy_addr(),
            );
        } else {
            info!(
                "Spawned a dry-mode doorkeeper for proxy:{}.",
                node.proxy_addr(),
            )
        };

        let doorkeeper = Doorkeeper {
            shards_count: None, // temporarily, until Doorkeeper examines its ShardAwareness
            node,
            listener,
            terminate_signaler,
            finish_guard,
            error_propagator,
        };
        tokio::task::spawn(doorkeeper.run());
        Ok(())
    }

    /// The accept loop: first resolves the shards count (if shard-aware), then
    /// keeps accepting driver connections until an error occurs or termination
    /// is signalled on the proxy-wide broadcast channel.
    async fn run(mut self) {
        self.update_shards_count().await;
        let mut own_terminate_notifier = self.terminate_signaler.subscribe();
        // Close-signal channel shared by all connections accepted by this
        // doorkeeper; workers subscribe to it per connection.
        let (connection_close_tx, _connection_close_rx) = broadcast::channel::<()>(2);
        let mut connection_no: usize = 0;
        loop {
            tokio::select! {
                res = self.accept_connection(&connection_close_tx, connection_no) => {
                    match res {
                        Ok(()) => connection_no += 1,
                        Err(err) => {
                            error!(
                                "Error in doorkeeper with addr {} for node {}: {}",
                                self.node.proxy_addr(),
                                DisplayableRealAddrOption(self.node.real_addr()),
                                err
                            );
                            // Send may fail only if the user dropped the sink; ignore.
                            let _ = self.error_propagator.send(err.into());
                            break;
                        },
                    }
                },
                _terminate = own_terminate_notifier.recv() => break
            }
        }
        debug!(
            "Doorkeeper exits: proxy {}, node {}.",
            self.node.proxy_addr(),
            DisplayableRealAddrOption(self.node.real_addr())
        );
    }

    /// Resolves `self.shards_count` according to the node's [ShardAwareness]:
    /// `None` for Unaware, the fixed number for FixedNum, and a value queried
    /// from the node for QueryNode (falling back to `None` when the node
    /// offers no shard info or the query fails).
    async fn update_shards_count(&mut self) {
        if let InternalNode::Real {
            real_addr,
            shard_awareness,
            ..
        } = self.node
        {
            self.shards_count = match shard_awareness {
                ShardAwareness::Unaware => None,
                ShardAwareness::FixedNum(shards_num) => Some(shards_num),
                ShardAwareness::QueryNode => match self.obtain_shards_count(real_addr).await {
                    Ok(shards) => Some(shards),
                    // If a node offers no sharding info, change proxy ShardAwareness to Unaware.
                    Err(DoorkeeperError::ObtainingShardNumberNoShardInfo) => {
                        info!(
                            "Doorkeeper with addr {} found no shard info in node {}; falling back to ShardAwareness::Unaware",
                            self.node.proxy_addr(),
                            DisplayableRealAddrOption(self.node.real_addr()),
                        );
                        None
                    }
                    Err(e) => {
                        error!(
                            "Error in doorkeeper with addr {} while querying shard info from node {}: {}",
                            self.node.proxy_addr(),
                            DisplayableRealAddrOption(self.node.real_addr()),
                            e
                        );
                        None
                    }
                },
            }
        }
    }

    /// Spawns the worker tasks for one accepted connection. For a real node
    /// all six workers are started (both directions plus both processors);
    /// for a simulated node only the driver-facing half is spawned.
    async fn spawn_workers(
        &mut self,
        driver_addr: SocketAddr,
        connection_close_tx: &ConnectionCloseSignaler,
        connection_no: usize,
        driver_stream: TcpStream,
        cluster_stream: Option<TcpStream>,
        shard: Option<TargetShard>,
    ) {
        let (driver_read, driver_write) = driver_stream.into_split();

        // Template for the common worker state; each worker receives fresh
        // subscriptions to the termination and connection-close channels.
        let new_worker = || ProxyWorker {
            terminate_notifier: self.terminate_signaler.subscribe(),
            finish_guard: self.finish_guard.clone(),
            connection_close_notifier: connection_close_tx.subscribe(),
            error_propagator: self.error_propagator.clone(),
            driver_addr,
            real_addr: self.node.real_addr(),
            proxy_addr: self.node.proxy_addr(),
            shard,
        };

        // Channels wiring the workers together (see the diagram on NodeType).
        let (tx_request, rx_request) = mpsc::unbounded_channel::<RequestFrame>();
        let (tx_response, rx_response) = mpsc::unbounded_channel::<ResponseFrame>();
        let (tx_cluster, rx_cluster) = mpsc::unbounded_channel::<RequestFrame>();
        let (tx_driver, rx_driver) = mpsc::unbounded_channel::<ResponseFrame>();
        let event_register_flag = Arc::new(AtomicBool::new(false));

        tokio::task::spawn(new_worker().receiver_from_driver(driver_read, tx_request));
        tokio::task::spawn(new_worker().sender_to_driver(
            driver_write,
            rx_driver,
            connection_close_tx.subscribe(),
            self.terminate_signaler.subscribe(),
        ));
        tokio::task::spawn(new_worker().request_processor(
            rx_request,
            tx_driver.clone(),
            tx_cluster.clone(),
            connection_no,
            self.node.request_rules().clone(),
            connection_close_tx.clone(),
            event_register_flag.clone(),
        ));
        if let InternalNode::Real {
            ref response_rules, ..
        } = self.node
        {
            // Real nodes always come with Some(cluster_stream) (see accept_connection),
            // so this unwrap cannot fail.
            let (cluster_read, cluster_write) = cluster_stream.unwrap().into_split();
            tokio::task::spawn(new_worker().sender_to_cluster(
                cluster_write,
                rx_cluster,
                connection_close_tx.subscribe(),
                self.terminate_signaler.subscribe(),
            ));
            tokio::task::spawn(new_worker().receiver_from_cluster(cluster_read, tx_response));
            tokio::task::spawn(new_worker().response_processor(
                rx_response,
                tx_driver,
                tx_cluster,
                connection_no,
                response_rules.clone(),
                connection_close_tx.clone(),
                event_register_flag.clone(),
            ));
        }
        debug!(
            "Doorkeeper with addr {} of node {} spawned workers.",
            self.node.proxy_addr(),
            DisplayableRealAddrOption(self.node.real_addr())
        );
    }

    /// Accepts one driver connection, opens the matching cluster connection
    /// (real nodes only) and spawns the workers serving this connection pair.
    async fn accept_connection(
        &mut self,
        connection_close_tx: &ConnectionCloseSignaler,
        connection_no: usize,
    ) -> Result<(), DoorkeeperError> {
        let (driver_stream, driver_addr) = self.make_driver_stream(connection_no).await?;
        let (cluster_stream, shard) = match self.node {
            InternalNode::Real { real_addr, .. } => {
                let (cluster_stream, shard) =
                    self.make_cluster_stream(driver_addr, real_addr).await?;
                (Some(cluster_stream), shard)
            }
            InternalNode::Simulated { .. } => (None, None),
        };

        self.spawn_workers(
            driver_addr,
            connection_close_tx,
            connection_no,
            driver_stream,
            cluster_stream,
            shard,
        )
        .await;

        Ok(())
    }

    /// Awaits the next driver connection on the proxy listener and logs it.
    async fn make_driver_stream(
        &mut self,
        connection_no: usize,
    ) -> Result<(TcpStream, SocketAddr), DoorkeeperError> {
        let (driver_stream, driver_addr) =
            self.listener.accept().await.map_err(|err| {
                DoorkeeperError::DriverConnectionAttempt(self.node.proxy_addr(), err)
            })?;
        info!(
            "Connected driver from {} to {}, connection no={}.",
            driver_addr,
            self.node.proxy_addr(),
            connection_no
        );
        Ok((driver_stream, driver_addr))
    }

    /// Opens the proxy-to-node connection. When the shards count is known,
    /// the local socket is bound to a port that maps to the same shard as the
    /// driver's port, and the node is then queried for the shard it actually
    /// assigned to the connection.
    async fn make_cluster_stream(
        &mut self,
        driver_addr: SocketAddr,
        real_addr: SocketAddr,
    ) -> Result<(TcpStream, Option<TargetShard>), DoorkeeperError> {
        let mut cluster_stream = if let Some(shards) = self.shards_count {
            let socket = match self.node.proxy_addr().ip() {
                std::net::IpAddr::V4(_) => TcpSocket::new_v4(),
                std::net::IpAddr::V6(_) => TcpSocket::new_v6(),
            }
            .map_err(DoorkeeperError::SocketCreate)?;

            // NOTE(review): the bind address below is always IPv4 (0.0.0.0),
            // even when the socket above was created as IPv6 - confirm that
            // IPv6 proxy addresses are handled correctly here.
            let shard_preserving_addr = {
                let mut desired_addr =
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), driver_addr.port());
                while socket.bind(desired_addr).is_err() {
                    // in search for a port that translates to the desired shard
                    let next_port = self.next_port_to_same_shard(desired_addr.port());
                    if next_port == driver_addr.port() {
                        // Wrapped around the whole port space without a free port.
                        return Err(DoorkeeperError::NoMorePorts);
                    }
                    desired_addr.set_port(next_port);
                }
                desired_addr
            };

            socket.connect(real_addr).await.map(|ok| {
                info!(
                    "Connected to the cluster from {} at {}, intended shard {}.",
                    ok.local_addr().unwrap(),
                    real_addr,
                    shard_preserving_addr.port() % shards
                );
                ok
            })
        } else {
            TcpStream::connect(real_addr).await.map(|ok| {
                info!("Connected to the cluster at {}.", real_addr);
                ok
            })
        }
        .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;

        // If ShardAwareness is aware (QueryNode or FixedNum variants) and the
        // proxy succeeded to know the shards count (in FixedNum we get it for
        // free, in QueryNode the initial Options query succeeded and Supported
        // contained SCYLLA_SHARDS_NUM), then upon opening each connection to the
        // node, the proxy issues another Options requests and acknowledges the
        // shard it got connected to.
        let shard = if self.shards_count.is_some() {
            self.obtain_shard_number(real_addr, &mut cluster_stream)
                .await?
        } else {
            None
        };

        Ok((cluster_stream, shard))
    }

    /// Next candidate local port mapping to the same shard (ports map to
    /// shards modulo the shards count, so stepping by the count preserves it).
    ///
    /// # Panics
    /// Panics if `self.shards_count` is `None`; callers invoke this only after
    /// checking the count is known.
    fn next_port_to_same_shard(&self, port: u16) -> u16 {
        port.wrapping_add(self.shards_count.unwrap())
    }

    /// Sends an OPTIONS request on `connection` and parses the SUPPORTED
    /// response into its string multimap of options.
    async fn get_supported_options(
        connection: &mut TcpStream,
    ) -> Result<HashMap<String, Vec<String>>, DoorkeeperError> {
        write_frame(
            HARDCODED_OPTIONS_PARAMS,
            FrameOpcode::Request(RequestOpcode::Options),
            &Bytes::new(),
            connection,
        )
        .await
        .map_err(DoorkeeperError::ObtainingShardNumber)?;

        let supported_frame = read_response_frame(connection)
            .await
            .map_err(DoorkeeperError::ObtainingShardNumberFrame)?;

        let options = read_string_multimap(&mut supported_frame.body.as_ref())
            .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;

        Ok(options)
    }

    /// Connects to the node on a temporary connection and reads its total
    /// shards count from the SCYLLA_NR_SHARDS supported option.
    /// Errors if the option is missing, unparsable, or zero.
    async fn obtain_shards_count(&self, real_addr: SocketAddr) -> Result<u16, DoorkeeperError> {
        let mut connection = TcpStream::connect(real_addr)
            .await
            .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
        let options = Self::get_supported_options(&mut connection).await?;
        let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
        let shards = match nr_shards_entry
            .and_then(|vec| vec.first())
            .ok_or(DoorkeeperError::ObtainingShardNumberNoShardInfo)?
            .parse::<u16>()
            .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)?
        {
            0u16 => Err(DoorkeeperError::ObtainingShardNumberGotZero),
            num => Ok(num),
        }?;
        info!("Obtained shards number on node {}: {}", real_addr, shards);
        Ok(shards)
    }

    /// Reads the shard this particular `connection` landed on, from the
    /// node's SCYLLA_SHARD supported option. Returns Ok(None) when the node
    /// does not report a shard (e.g. Cassandra).
    async fn obtain_shard_number(
        &self,
        real_addr: SocketAddr,
        connection: &mut TcpStream,
    ) -> Result<Option<TargetShard>, DoorkeeperError> {
        let options = Self::get_supported_options(connection).await?;
        let shard_entry = options.get("SCYLLA_SHARD");
        let shard = shard_entry
            .and_then(|vec| vec.first())
            .map(|s| {
                s.parse::<u16>()
                    .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)
            })
            .transpose()?;
        info!("Connected to node {}, shard {:?}", real_addr, shard);
        Ok(shard)
    }
}
848
/// Per-connection state shared by every worker task (receivers, senders and
/// frame processors) serving a single proxied connection.
struct ProxyWorker {
    // Tells this worker that the whole proxy is shutting down.
    terminate_notifier: TerminateNotifier,
    // Dropped on worker exit, so the proxy's finish waiter can detect
    // that all workers have stopped.
    finish_guard: FinishGuard,
    // Tells this worker that its connection is being dropped by a rule.
    connection_close_notifier: ConnectionCloseNotifier,
    // Forwards this worker's errors to the proxy user.
    error_propagator: ErrorPropagator,
    // Address of the connected driver.
    driver_addr: SocketAddr,
    // Address of the real node; `None` when there is no real node to talk to
    // (cf. the dry-mode expect in response_processor).
    real_addr: Option<SocketAddr>,
    // Address the proxy exposes for this node.
    proxy_addr: SocketAddr,
    // Shard this connection is bound to, if shard awareness is enabled.
    shard: Option<TargetShard>,
}
859
860impl ProxyWorker {
861    fn exit(self, duty: &'static str) {
862        debug!(
863            "Worker exits: [driver: {}, proxy: {}, node: {}, {}]::{}.",
864            self.driver_addr,
865            self.proxy_addr,
866            DisplayableRealAddrOption(self.real_addr),
867            DisplayableShard(self.shard),
868            duty
869        );
870        std::mem::drop(self.finish_guard);
871    }
872
873    async fn run_until_interrupted<F, Fut>(mut self, worker_name: &'static str, f: F)
874    where
875        F: FnOnce(SocketAddr, SocketAddr, Option<SocketAddr>) -> Fut,
876        Fut: Future<Output = Result<(), ProxyError>>,
877    {
878        let fut = f(self.driver_addr, self.proxy_addr, self.real_addr);
879
880        tokio::select! {
881            result = fut => {
882                if let Err(err) = result {
883                    // error_propagator could be a field
884                    let _ = self.error_propagator.send(err);
885                }
886            }
887            _ = self.terminate_notifier.recv() => (),
888            _ = self.connection_close_notifier.recv() => (),
889        }
890        self.exit(worker_name);
891    }
892
893    async fn receiver_from_driver(
894        self,
895        mut read_half: (impl AsyncRead + Unpin),
896        request_processor_tx: mpsc::UnboundedSender<RequestFrame>,
897    ) {
898        let shard = self.shard;
899        self.run_until_interrupted(
900            "receiver_from_driver",
901            |driver_addr, proxy_addr, _real_addr| async move {
902                loop {
903                    let frame = frame::read_request_frame(&mut read_half)
904                        .await
905                        .map_err(|err| {
906                            warn!("Request reception from {} error: {}", driver_addr, err);
907                            WorkerError::DriverDisconnected(driver_addr)
908                        })?;
909
910                    debug!(
911                        "Intercepted Driver ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
912                        driver_addr,
913                        proxy_addr,
914                        DisplayableShard(shard),
915                        &frame.opcode
916                    );
917                    if request_processor_tx.send(frame).is_err() {
918                        warn!("request_processor had exited.");
919                        return Result::<(), ProxyError>::Ok(());
920                    }
921                }
922            },
923        )
924        .await
925    }
926
927    async fn receiver_from_cluster(
928        self,
929        mut read_half: (impl AsyncRead + Unpin),
930        response_processor_tx: mpsc::UnboundedSender<ResponseFrame>,
931    ) {
932        let shard = self.shard;
933        self.run_until_interrupted(
934            "receiver_from_cluster",
935            |driver_addr, _proxy_addr, real_addr| async move {
936                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
937                loop {
938                    let frame =
939                        frame::read_response_frame(&mut read_half)
940                            .await
941                            .map_err(|err| {
942                                warn!("Response reception from {} error: {}", real_addr, err);
943                                WorkerError::NodeDisconnected(real_addr)
944                            })?;
945
946                    debug!(
947                        "Intercepted Cluster ({}) -> Driver ({}) ({}) frame. opcode: {:?}.",
948                        real_addr,
949                        driver_addr,
950                        DisplayableShard(shard),
951                        &frame.opcode
952                    );
953
954                    if response_processor_tx.send(frame).is_err() {
955                        warn!("response_processor had exited.");
956                        return Ok::<(), ProxyError>(());
957                    }
958                }
959            },
960        )
961        .await;
962    }
963
964    async fn sender_to_driver(
965        self,
966        mut write_half: (impl AsyncWrite + Unpin),
967        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
968        mut connection_close_notifier: ConnectionCloseNotifier,
969        mut terminate_notifier: TerminateNotifier,
970    ) {
971        let shard = self.shard;
972        self.run_until_interrupted(
973            "sender_to_driver",
974            |driver_addr, proxy_addr, _real_addr| async move {
975                loop {
976                    let response = match responses_rx.recv().await {
977                        Some(response) => response,
978                        None => {
979                            if terminate_notifier.try_recv().is_err()
980                                && connection_close_notifier.try_recv().is_err()
981                            {
982                                warn!("Response processor had exited");
983                            }
984                            return Ok(());
985                        }
986                    };
987
988                    debug!(
989                        "Sending Proxy ({}) -> Driver ({}) ({}) frame. opcode: {:?}.",
990                        proxy_addr,
991                        driver_addr,
992                        DisplayableShard(shard),
993                        &response.opcode
994                    );
995                    if response.write(&mut write_half).await.is_err() {
996                        if terminate_notifier.try_recv().is_err()
997                            && connection_close_notifier.try_recv().is_err()
998                        {
999                            warn!("Driver dropped connection");
1000                            return Err(WorkerError::DriverDisconnected(driver_addr).into());
1001                        }
1002                        return Ok(());
1003                    }
1004                }
1005            },
1006        )
1007        .await;
1008    }
1009
1010    async fn sender_to_cluster(
1011        self,
1012        mut write_half: (impl AsyncWrite + Unpin),
1013        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1014        mut connection_close_notifier: ConnectionCloseNotifier,
1015        mut terminate_notifier: TerminateNotifier,
1016    ) {
1017        let shard = self.shard;
1018        self.run_until_interrupted(
1019            "sender_to_driver",
1020            |_driver_addr, proxy_addr, real_addr| async move {
1021                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1022                loop {
1023                    let request = match requests_rx.recv().await {
1024                        Some(request) => request,
1025                        None => {
1026                            if terminate_notifier.try_recv().is_err()
1027                                && connection_close_notifier.try_recv().is_err()
1028                            {
1029                                warn!("Request processor had exited");
1030                            }
1031                            return Ok(());
1032                        }
1033                    };
1034
1035                    debug!(
1036                        "Sending Proxy ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1037                        proxy_addr,
1038                        real_addr,
1039                        DisplayableShard(shard),
1040                        &request.opcode
1041                    );
1042
1043                    if request.write(&mut write_half).await.is_err() {
1044                        if terminate_notifier.try_recv().is_err()
1045                            && connection_close_notifier.try_recv().is_err()
1046                        {
1047                            warn!("Node {} dropped connection", real_addr);
1048                            return Err(WorkerError::NodeDisconnected(real_addr).into());
1049                        }
1050                        return Ok(());
1051                    }
1052                }
1053            },
1054        )
1055        .await;
1056    }
1057
1058    #[allow(clippy::too_many_arguments)]
1059    async fn request_processor(
1060        self,
1061        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1062        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1063        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1064        connection_no: usize,
1065        request_rules: Arc<Mutex<Vec<RequestRule>>>,
1066        connection_close_signaler: ConnectionCloseSignaler,
1067        event_registered_flag: Arc<AtomicBool>,
1068    ) {
1069        let shard = self.shard;
1070        self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1071            'mainloop: loop {
1072                match requests_rx.recv().await {
1073                    Some(request) => {
1074                        if request.opcode == RequestOpcode::Register {
1075                            event_registered_flag.store(true, Ordering::Relaxed);
1076                        }
1077                        let ctx = EvaluationContext {
1078                            connection_seq_no: connection_no,
1079                            opcode: FrameOpcode::Request(request.opcode),
1080                            frame_body: request.body.clone(),
1081                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1082                        };
1083                        let mut guard = request_rules.lock().unwrap();
1084                        '_ruleloop: for (i, request_rule) in guard.iter_mut().enumerate() {
1085                            if request_rule.0.eval(&ctx) {
1086                                info!("Applying rule no={} to request ({} -> {} ({})).", i, driver_addr, DisplayableRealAddrOption(real_addr), DisplayableShard(shard));
1087                                debug!("-> Applied rule: {:?}", request_rule);
1088                                debug!("-> To request: {:?}", ctx.opcode);
1089                                trace!("{:?}", request);
1090
1091                                if let Some(ref tx) = request_rule.1.feedback_channel {
1092                                    tx.send((request.clone(), shard)).unwrap_or_else(|err|
1093                                        warn!("Could not send received request as feedback: {}", err)
1094                                    );
1095                                }
1096
1097                                let request_rule = request_rule.clone();
1098                                let to_addressee_action = request_rule.1.to_addressee;
1099                                let to_sender_action = request_rule.1.to_sender;
1100                                let drop_connection_action = request_rule.1.drop_connection;
1101
1102                                let cluster_tx_clone = cluster_tx.clone();
1103                                let request_clone = request.clone();
1104                                let pass_action = async move {
1105                                    if let Some(ref pass_action) = to_addressee_action {
1106                                        if let Some(time) = pass_action.delay {
1107                                            tokio::time::sleep(time).await;
1108                                        }
1109                                        let passed_frame = match pass_action.msg_processor {
1110                                            Some(ref processor) => processor(request_clone),
1111                                            None => request_clone,
1112                                        };
1113                                        let _ = cluster_tx_clone.send(passed_frame);
1114                                    };
1115                                };
1116
1117                                let driver_tx_clone = driver_tx.clone();
1118                                let request_clone = request.clone();
1119                                let forge_action = async move {
1120                                    if let Some(ref forge_action) = to_sender_action {
1121                                        if let Some(time) = forge_action.delay {
1122                                            tokio::time::sleep(time).await;
1123                                        }
1124                                        let forged_frame = {
1125                                            let processor = forge_action.msg_processor.as_ref()
1126                                                .expect("Frame processor is required to forge a frame.");
1127                                            processor(request_clone)
1128                                        };
1129                                        let _ = driver_tx_clone.send(forged_frame);
1130                                    };
1131                                };
1132
1133                                let connection_close_signaler_clone =
1134                                    connection_close_signaler.clone();
1135                                let drop_action = async move {
1136                                    if let Some(ref delay) = drop_connection_action {
1137                                        if let Some(ref time) = delay {
1138                                            tokio::time::sleep(*time).await;
1139                                        }
1140                                        // close connection.
1141                                        info!(
1142                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1143                                            driver_addr,
1144                                            DisplayableRealAddrOption(real_addr),
1145                                            DisplayableShard(shard),
1146                                        );
1147                                        let _ = connection_close_signaler_clone.send(());
1148                                    }
1149                                };
1150
1151                                tokio::task::spawn(async {
1152                                    futures::join!(pass_action, forge_action, drop_action);
1153                                });
1154
1155                                continue 'mainloop; // only one rule can be applied to one frame
1156                            }
1157                        }
1158                        let _ = cluster_tx.send(request); // default action
1159                    }
1160                    None => return Ok(()),
1161                }
1162            }
1163        })
1164        .await;
1165    }
1166
1167    #[allow(clippy::too_many_arguments)]
1168    async fn response_processor(
1169        self,
1170        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1171        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1172        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1173        connection_no: usize,
1174        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
1175        connection_close_signaler: ConnectionCloseSignaler,
1176        event_registered_flag: Arc<AtomicBool>,
1177    ) {
1178        let shard = self.shard;
1179        self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1180            'mainloop: loop {
1181                match responses_rx.recv().await {
1182                    Some(response) => {
1183                        let ctx = EvaluationContext {
1184                            connection_seq_no: connection_no,
1185                            opcode: FrameOpcode::Response(response.opcode),
1186                            frame_body: response.body.clone(),
1187                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1188                        };
1189                        let mut guard = response_rules.lock().unwrap();
1190                        '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() {
1191                            if response_rule.0.eval(&ctx) {
1192                                info!("Applying rule no={} to request ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard));
1193                                debug!("-> Applied rule: {:?}", response_rule);
1194                                debug!("-> To response: {:?}", ctx.opcode);
1195                                trace!("{:?}", response);
1196
1197                                if let Some(ref tx) = response_rule.1.feedback_channel {
1198                                    tx.send((response.clone(), shard)).unwrap_or_else(|err| warn!(
1199                                        "Could not send received response as feedback: {}", err
1200                                    ));
1201                                }
1202
1203                                let response_rule = response_rule.clone();
1204                                let to_addressee_action = response_rule.1.to_addressee;
1205                                let to_sender_action = response_rule.1.to_sender;
1206                                let drop_connection_action = response_rule.1.drop_connection;
1207
1208                                let response_clone = response.clone();
1209                                let driver_tx_clone = driver_tx.clone();
1210                                let pass_action = async move {
1211                                    if let Some(ref pass_action) = to_addressee_action {
1212                                        if let Some(time) = pass_action.delay {
1213                                            tokio::time::sleep(time).await;
1214                                        }
1215                                        let passed_frame = match pass_action.msg_processor {
1216                                            Some(ref processor) => processor(response_clone),
1217                                            None => response_clone,
1218                                        };
1219                                        let _ = driver_tx_clone.send(passed_frame);
1220                                    };
1221                                };
1222
1223                                let response_clone = response.clone();
1224                                let cluster_tx_clone = cluster_tx.clone();
1225                                let forge_action = async move {
1226                                    if let Some(ref forge_action) = to_sender_action {
1227                                        if let Some(time) = forge_action.delay {
1228                                            tokio::time::sleep(time).await;
1229                                        }
1230                                        let forged_frame = {
1231                                            let processor = forge_action.msg_processor.as_ref()
1232                                                .expect("Frame processor is required to forge a frame.");
1233                                            processor(response_clone)
1234                                        };
1235                                        let _ = cluster_tx_clone.send(forged_frame);
1236                                    };
1237                                };
1238
1239                                let connection_close_signaler_clone =
1240                                    connection_close_signaler.clone();
1241                                let drop_action = async move {
1242                                    if let Some(ref delay) = drop_connection_action {
1243                                        if let Some(ref time) = delay {
1244                                            tokio::time::sleep(*time).await;
1245                                        }
1246                                        // close connection.
1247                                        info!(
1248                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1249                                            driver_addr,
1250                                            real_addr.expect("BUG: response rules are unavailable for dry-mode proxy!"),
1251                                            DisplayableShard(shard)
1252                                        );
1253                                        let _ = connection_close_signaler_clone.send(());
1254                                    }
1255                                };
1256
1257                                tokio::task::spawn(async {
1258                                    futures::join!(pass_action, forge_action, drop_action);
1259                                });
1260
1261                                continue 'mainloop;
1262                            }
1263                        }
1264                        let _ = driver_tx.send(response); // default action
1265                    }
1266                    None => return Ok(()),
1267                }
1268            }
1269        })
1270        .await
1271    }
1272}
1273
// Returns next free IP address for another proxy instance.
// Useful for concurrent testing.
#[doc(hidden)]
pub fn get_exclusive_local_address() -> IpAddr {
    // Starting high enough reduces the possibility of clashes with
    // user-taken addresses.
    static ADDRESS_LOWER_THREE_OCTETS: AtomicU32 = AtomicU32::new(4242);
    let next_addr = ADDRESS_LOWER_THREE_OCTETS.fetch_add(1, Ordering::Relaxed);
    assert!(
        next_addr <= (u32::MAX >> 8),
        "Loopback address pool for tests depleted"
    );
    // Spread the counter over the lower three octets of a 127.x.y.z address.
    let [_, b2, b1, b0] = next_addr.to_be_bytes();
    IpAddr::V4(Ipv4Addr::new(127, b2, b1, b0))
}
1292
1293#[cfg(test)]
1294mod tests {
1295    use super::*;
1296    use crate::frame::{read_frame, read_request_frame, FrameType};
1297    use crate::{
1298        setup_tracing, Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction,
1299    };
1300    use assert_matches::assert_matches;
1301    use bytes::{BufMut, BytesMut};
1302    use futures::future::{join, join3};
1303    use rand::RngCore;
1304    use scylla_cql::frame::frame_errors::FrameError;
1305    use scylla_cql::frame::types::write_string_multimap;
1306    use std::collections::HashMap;
1307    use std::mem;
1308    use std::str::FromStr;
1309    use std::time::Duration;
1310    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
1311    use tokio::sync::oneshot;
1312
1313    fn random_body() -> Bytes {
1314        let body_len = (rand::random::<u32>() % 1000) as usize;
1315        let mut body = BytesMut::zeroed(body_len);
1316        rand::thread_rng().fill_bytes(body.as_mut());
1317        body.freeze()
1318    }
1319
1320    async fn respond_with_supported(
1321        conn: &mut TcpStream,
1322        supported_options: &HashMap<String, Vec<String>>,
1323    ) {
1324        let RequestFrame {
1325            params: recvd_params,
1326            opcode: recvd_opcode,
1327            body: recvd_body,
1328        } = read_request_frame(conn).await.unwrap();
1329        assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS);
1330        assert_eq!(recvd_opcode, RequestOpcode::Options);
1331        assert_eq!(recvd_body, Bytes::new()); // body should be empty
1332
1333        let mut body = BytesMut::new();
1334        write_string_multimap(supported_options, &mut body).unwrap();
1335
1336        let body = body.freeze();
1337
1338        write_frame(
1339            HARDCODED_OPTIONS_PARAMS.for_response(),
1340            FrameOpcode::Response(ResponseOpcode::Supported),
1341            &body,
1342            conn,
1343        )
1344        .await
1345        .unwrap();
1346    }
1347
1348    fn supported_shards_count(shards_count: u16) -> HashMap<String, Vec<String>> {
1349        let mut sharded_info = HashMap::new();
1350        sharded_info.insert(
1351            String::from("SCYLLA_NR_SHARDS"),
1352            vec![shards_count.to_string()],
1353        );
1354        sharded_info
1355    }
1356
1357    fn supported_shard_number(shard_num: TargetShard) -> HashMap<String, Vec<String>> {
1358        let mut sharded_info = HashMap::new();
1359        sharded_info.insert(String::from("SCYLLA_SHARD"), vec![shard_num.to_string()]);
1360        sharded_info
1361    }
1362
1363    async fn respond_with_shards_count(conn: &mut TcpStream, shards_count: u16) {
1364        respond_with_supported(conn, &supported_shards_count(shards_count)).await;
1365    }
1366
1367    async fn respond_with_shard_num(conn: &mut TcpStream, shard_num: TargetShard) {
1368        respond_with_supported(conn, &supported_shard_number(shard_num)).await;
1369    }
1370
1371    fn next_local_address_with_port(port: u16) -> SocketAddr {
1372        SocketAddr::new(get_exclusive_local_address(), port)
1373    }
1374
    /// End-to-end check: a frame sent by a mock driver through a rule-less
    /// proxy must reach the mock node byte-identical.
    async fn identity_proxy_does_not_mutate_frames(shard_awareness: ShardAwareness) {
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            shard_awareness,
            None,
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        // Mock driver: connect through the proxy and send a single frame.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params, opcode, &body, &mut conn).await.unwrap();
            conn
        };

        // Mock node: serve the shard-awareness handshake (if any),
        // then receive the frame and verify it was not mutated.
        let mock_node_action = async {
            if let ShardAwareness::QueryNode = shard_awareness {
                // In QueryNode mode the proxy first opens a separate connection
                // to ask for the shards count.
                respond_with_shards_count(&mut mock_node_listener.accept().await.unwrap().0, 1)
                    .await;
            }
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            if shard_awareness.is_aware() {
                // On the proxied connection, the proxy asks which shard it landed on.
                respond_with_shard_num(&mut conn, 1).await;
            }
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
            conn
        };

        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
        running_proxy.finish().await.unwrap();
    }
1429
    // Identity-proxy passthrough, with shard awareness disabled.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn identity_shard_unaware_proxy_does_not_mutate_frames() {
        setup_tracing();
        identity_proxy_does_not_mutate_frames(ShardAwareness::Unaware).await
    }
1436
    // Identity-proxy passthrough, with the shards count queried from the node.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn identity_shard_aware_proxy_does_not_mutate_frames() {
        setup_tracing();
        identity_proxy_does_not_mutate_frames(ShardAwareness::QueryNode).await
    }
1443
    // For a fixed shards count, the proxy must connect to the node from a port
    // that maps to the same shard as the driver's source port, i.e. the two
    // ports must be congruent modulo the shards count.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn shard_aware_proxy_is_transparent_for_connection_to_shards() {
        setup_tracing();
        async fn test_for_shards_num(shards_num: u16) {
            let node1_real_addr = next_local_address_with_port(9876);
            let node1_proxy_addr = next_local_address_with_port(9876);
            let proxy = Proxy::new([Node::new(
                node1_real_addr,
                node1_proxy_addr,
                ShardAwareness::FixedNum(shards_num),
                None,
                None,
            )]);
            let running_proxy = proxy.run().await.unwrap();

            let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

            // Used to pass the driver's (ephemeral) source address to the
            // mock-node task for the congruence check.
            let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

            // Mock driver: connect from an OS-assigned port and report it.
            let send_frame_to_shard = async {
                let socket = TcpSocket::new_v4().unwrap();
                socket
                    .bind(SocketAddr::from_str("0.0.0.0:0").unwrap())
                    .unwrap();
                let conn = socket.connect(node1_proxy_addr).await.unwrap();
                driver_addr_tx.send(conn.local_addr().unwrap()).unwrap();
                conn
            };

            // Mock node: verify the proxy's source port targets the same shard
            // as the driver's source port.
            let mock_node_action = async {
                let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                let driver_addr = driver_addr_rx.await.unwrap();
                assert_eq!(
                    driver_addr.port() % shards_num,
                    remote_addr.port() % shards_num
                );
                conn
            };

            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
            let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
            running_proxy.finish().await.unwrap();
        }

        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
1493
1494    #[tokio::test]
1495    #[ntest::timeout(1000)]
1496    async fn shard_aware_proxy_queries_shards_number() {
1497        setup_tracing();
1498        async fn test_for_shards_num(shards_num: u16) {
1499            for shard_num in 0..shards_num {
1500                let node1_real_addr = next_local_address_with_port(9876);
1501                let node1_proxy_addr = next_local_address_with_port(9876);
1502                let proxy = Proxy::new([Node::new(
1503                    node1_real_addr,
1504                    node1_proxy_addr,
1505                    ShardAwareness::QueryNode,
1506                    None,
1507                    None,
1508                )]);
1509                let running_proxy = proxy.run().await.unwrap();
1510
1511                let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1512
1513                let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();
1514
1515                let mock_driver_addr = next_local_address_with_port(shards_num * 1234 + shard_num);
1516                let send_frame_to_shard = async {
1517                    let socket = TcpSocket::new_v4().unwrap();
1518                    socket
1519                        .bind(mock_driver_addr)
1520                        .unwrap_or_else(|_| panic!("driver_addr failed: {}", mock_driver_addr));
1521                    driver_addr_tx.send(socket.local_addr().unwrap()).unwrap();
1522                    socket.connect(node1_proxy_addr).await.unwrap()
1523                };
1524
1525                let mock_node_action = async {
1526                    respond_with_shards_count(
1527                        &mut mock_node_listener.accept().await.unwrap().0,
1528                        shards_num,
1529                    )
1530                    .await;
1531                    let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
1532                    let driver_addr = driver_addr_rx.await.unwrap();
1533                    assert_eq!(
1534                        driver_addr.port() % shards_num,
1535                        remote_addr.port() % shards_num
1536                    );
1537                    conn
1538                };
1539
1540                let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1541                running_proxy.finish().await.unwrap();
1542            }
1543        }
1544
1545        for shard_num in 1..6 {
1546            test_for_shards_num(shard_num).await;
1547        }
1548    }
1549
    /// Checks request-rule response forging on a non-dry proxy:
    /// * REGISTER requests get a forged EVENT response with `test_msg` as body,
    /// * requests whose body contains `this_shall_pass` are passed through to the node,
    /// * all other requests get a forged READY response with an empty body.
    /// Also asserts that after `finish()` both ends observe a clean EOF.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn forger_proxy_forges_response() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                // Rule 1: forge an EVENT response to every REGISTER request.
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                // Rule 2: let frames containing the magic marker through untouched.
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                // Rule 3: forge a READY response for everything else.
                RequestRule(
                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        // Frame 1 (STARTUP): matches rule 3 -> forged READY.
        let params1 = FrameParams {
            flags: 3,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        // Frame 2 (REGISTER): matches rule 1 -> forged EVENT.
        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        // Frame 3 (EXECUTE, body contains the marker): matches rule 2 -> passed to the node.
        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        // Driver side: send all three frames, then expect the two forged responses
        // (READY for frame 1, EVENT for frame 2).
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn)
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn)
                .await
                .unwrap();
            write_frame(params3, opcode3, &body3, &mut conn)
                .await
                .unwrap();

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params1.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Ready);
            assert_eq!(recvd_body, Bytes::new());

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params2.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Event);
            assert_eq!(recvd_body, Bytes::from_static(test_msg));

            conn
        };

        // Node side: only the pass-through frame (frame 3) must arrive.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params3);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3);
            assert_eq!(recvd_body, body3);

            conn
        };

        let (mut node_conn, mut driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        running_proxy.finish().await.unwrap();

        // After shutdown both sides should see a clean EOF (read of 0 bytes).
        assert_matches!(driver_conn.read(&mut [0u8; 1]).await, Ok(0));
        assert_matches!(node_conn.read(&mut [0u8; 1]).await, Ok(0));
    }
1681
1682    #[tokio::test]
1683    #[ntest::timeout(1000)]
1684    async fn ad_hoc_rules_changing() {
1685        setup_tracing();
1686        let node1_real_addr = next_local_address_with_port(9876);
1687        let node1_proxy_addr = next_local_address_with_port(9876);
1688        let proxy = Proxy::new([Node::new(
1689            node1_real_addr,
1690            node1_proxy_addr,
1691            ShardAwareness::Unaware,
1692            None,
1693            None,
1694        )]);
1695        let mut running_proxy = proxy.run().await.unwrap();
1696
1697        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1698
1699        let params = FrameParams {
1700            flags: 0,
1701            version: 0x04,
1702            stream: 0,
1703        };
1704        let opcode = FrameOpcode::Request(RequestOpcode::Options);
1705
1706        let body = random_body();
1707
1708        let (mut driver, mut node) = {
1709            let results = join(
1710                TcpStream::connect(node1_proxy_addr),
1711                mock_node_listener.accept(),
1712            )
1713            .await;
1714            (results.0.unwrap(), results.1.unwrap().0)
1715        };
1716
1717        async fn request(
1718            driver: &mut TcpStream,
1719            node: &mut TcpStream,
1720            params: FrameParams,
1721            opcode: FrameOpcode,
1722            body: &Bytes,
1723        ) -> Result<RequestFrame, FrameError> {
1724            let (send_res, recv_res) = join(
1725                write_frame(params, opcode, &body.clone(), driver),
1726                read_request_frame(node),
1727            )
1728            .await;
1729            send_res.unwrap();
1730            recv_res
1731        }
1732        {
1733            // one run still without custom rules
1734            let RequestFrame {
1735                params: recvd_params,
1736                opcode: recvd_opcode,
1737                body: recvd_body,
1738            } = request(&mut driver, &mut node, params, opcode, &body)
1739                .await
1740                .unwrap();
1741            assert_eq!(recvd_params, params);
1742            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1743            assert_eq!(recvd_body, body);
1744        }
1745        running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule(
1746            Condition::True,
1747            RequestReaction::drop_frame(),
1748        )]));
1749
1750        {
1751            // one run with custom rules
1752            tokio::select! {
1753                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1754                _ = tokio::time::sleep(std::time::Duration::from_millis(20)) => (),
1755            };
1756        }
1757
1758        running_proxy.turn_off_rules();
1759
1760        {
1761            // one run already without custom rules
1762            let RequestFrame {
1763                params: recvd_params,
1764                opcode: recvd_opcode,
1765                body: recvd_body,
1766            } = request(&mut driver, &mut node, params, opcode, &body)
1767                .await
1768                .unwrap();
1769            assert_eq!(recvd_params, params);
1770            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1771            assert_eq!(recvd_body, body);
1772        }
1773
1774        running_proxy.finish().await.unwrap();
1775    }
1776
1777    #[tokio::test]
1778    #[ntest::timeout(2000)]
1779    async fn limited_times_condition_expires() {
1780        setup_tracing();
1781        const FAILING_TRIES: usize = 4;
1782        const PASSING_TRIES: usize = 5;
1783
1784        let node1_real_addr = next_local_address_with_port(9876);
1785        let node1_proxy_addr = next_local_address_with_port(9876);
1786        let proxy = Proxy::new([Node::new(
1787            node1_real_addr,
1788            node1_proxy_addr,
1789            ShardAwareness::Unaware,
1790            Some(vec![
1791                RequestRule(
1792                    // this will be always fired after first PASSING_TRIES + FAILING_TRIES
1793                    Condition::not(Condition::TrueForLimitedTimes(
1794                        FAILING_TRIES + PASSING_TRIES,
1795                    )),
1796                    RequestReaction::drop_frame(),
1797                ),
1798                RequestRule(
1799                    // this will be fired for PASSING_TRIES after first FAILING_TRIES
1800                    Condition::not(Condition::TrueForLimitedTimes(FAILING_TRIES)),
1801                    RequestReaction::noop(),
1802                ),
1803                RequestRule(
1804                    // this will be fired for first FAILING_TRIES
1805                    Condition::True,
1806                    RequestReaction::drop_frame(),
1807                ),
1808            ]),
1809            None,
1810        )]);
1811        let running_proxy = proxy.run().await.unwrap();
1812
1813        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1814
1815        let params = FrameParams {
1816            flags: 0,
1817            version: 0x04,
1818            stream: 0,
1819        };
1820        let opcode = FrameOpcode::Request(RequestOpcode::Options);
1821        let body = random_body();
1822
1823        let (mut driver, mut node) = {
1824            let results = join(
1825                TcpStream::connect(node1_proxy_addr),
1826                mock_node_listener.accept(),
1827            )
1828            .await;
1829            (results.0.unwrap(), results.1.unwrap().0)
1830        };
1831
1832        async fn request(
1833            driver: &mut TcpStream,
1834            node: &mut TcpStream,
1835            params: FrameParams,
1836            opcode: FrameOpcode,
1837            body: &Bytes,
1838        ) -> Result<RequestFrame, FrameError> {
1839            let (send_res, recv_res) = join(
1840                write_frame(params, opcode, &body.clone(), driver),
1841                read_request_frame(node),
1842            )
1843            .await;
1844            send_res.unwrap();
1845            recv_res
1846        }
1847
1848        for _ in 0..FAILING_TRIES {
1849            tokio::select! {
1850                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1851                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
1852            };
1853        }
1854
1855        for _ in 0..PASSING_TRIES {
1856            let RequestFrame {
1857                params: recvd_params,
1858                opcode: recvd_opcode,
1859                body: recvd_body,
1860            } = request(&mut driver, &mut node, params, opcode, &body)
1861                .await
1862                .unwrap();
1863            assert_eq!(recvd_params, params);
1864            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
1865            assert_eq!(recvd_body, body);
1866        }
1867
1868        for _ in 0..3 {
1869            // any further number of requests should fail
1870            tokio::select! {
1871                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
1872                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
1873            };
1874        }
1875
1876        running_proxy.finish().await.unwrap();
1877    }
1878
    /// Checks that request and response rules with `with_feedback_when_performed`
    /// report the intercepted frames (together with a shard value) through the
    /// provided mpsc channels when the rules fire.
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn proxy_reports_requests_and_responses_as_feedback() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        // Feedback channels through which the proxy reports intercepted frames.
        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            // Drop every request, but report it on the request feedback channel.
            Some(vec![RequestRule(
                Condition::True,
                RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
            )]),
            // Drop every response, but report it on the response feedback channel.
            Some(vec![ResponseRule(
                Condition::True,
                ResponseReaction::drop_frame().with_feedback_when_performed(response_feedback_tx),
            )]),
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);

        let body = random_body();

        // Driver side: send one request frame through the proxy.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
            write_frame(params, request_opcode, &body, &mut conn)
                .await
                .unwrap();
            conn
        };

        // Node side: send one response frame back through the proxy.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            write_frame(params.for_response(), response_opcode, &body, &mut conn)
                .await
                .unwrap();
            conn
        };

        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        // Both frames were dropped by the rules, but must be reported verbatim
        // through the feedback channels.
        let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
        assert_eq!(feedback_request.params, params);
        assert_eq!(
            FrameOpcode::Request(feedback_request.opcode),
            request_opcode
        );
        assert_eq!(feedback_request.body, body);
        let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
        assert_eq!(feedback_response.params, params.for_response());
        assert_eq!(
            FrameOpcode::Response(feedback_response.opcode),
            response_opcode
        );
        assert_eq!(feedback_response.body, body);

        running_proxy.finish().await.unwrap();
    }
1951
1952    #[tokio::test]
1953    #[ntest::timeout(1000)]
1954    async fn sanity_check_reports_errors() {
1955        setup_tracing();
1956        let node1_real_addr = next_local_address_with_port(9876);
1957        let node1_proxy_addr = next_local_address_with_port(9876);
1958        let proxy = Proxy::new([Node::new(
1959            node1_real_addr,
1960            node1_proxy_addr,
1961            ShardAwareness::Unaware,
1962            None,
1963            None,
1964        )]);
1965        let mut running_proxy = proxy.run().await.unwrap();
1966
1967        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
1968
1969        let send_frame_to_shard = async {
1970            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
1971
1972            conn.write_all(b"uselessJunk").await.unwrap();
1973            conn
1974        };
1975
1976        let mock_node_action = async {
1977            let (conn, _) = mock_node_listener.accept().await.unwrap();
1978            conn
1979        };
1980
1981        let (node_conn, driver_conn) = join(mock_node_action, send_frame_to_shard).await;
1982
1983        running_proxy.sanity_check().unwrap();
1984
1985        mem::drop(driver_conn);
1986        assert_matches!(
1987            running_proxy.wait_for_error().await,
1988            Some(ProxyError::Worker(WorkerError::DriverDisconnected(_)))
1989        );
1990        running_proxy.sanity_check().unwrap();
1991
1992        mem::drop(node_conn);
1993        assert_matches!(
1994            running_proxy.wait_for_error().await,
1995            Some(ProxyError::Worker(WorkerError::NodeDisconnected(_)))
1996        );
1997        running_proxy.sanity_check().unwrap();
1998
1999        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2000        let _ = running_proxy.finish().await;
2001    }
2002
    /// Checks that the proxy handles frames on one connection concurrently:
    /// the first frame is artificially delayed by a rule, yet the second frame
    /// must still reach the node before that delay elapses (enforced by wrapping
    /// the whole exchange in a timeout equal to the delay).
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn proxy_processes_requests_concurrently() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let delay = Duration::from_millis(30);

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            // Delay only the first frame; subsequent frames are unaffected.
            Some(vec![RequestRule(
                Condition::TrueForLimitedTimes(1),
                RequestReaction::delay(delay),
            )]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params1 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Options);

        let body1 = random_body();

        let params2 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        let body2 = random_body();

        // Driver side: send both frames back-to-back on one connection.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn)
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn)
                .await
                .unwrap();
            conn
        };

        // Node side: the first frame that arrives must be frame 2, because
        // frame 1 is still being delayed by the rule.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn).await.unwrap();
            assert_eq!(recvd_params, params2);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2);
            assert_eq!(recvd_body, body2);
            conn
        };

        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
        let (_node_conn, _driver_conn) =
            tokio::time::timeout(delay, join(mock_node_action, send_frame_to_shard))
                .await
                .expect("Request processing was not concurrent");
        running_proxy.finish().await.unwrap();
    }
2076
2077    #[tokio::test]
2078    #[ntest::timeout(1000)]
2079    async fn dry_mode_proxy_drops_incoming_frames() {
2080        setup_tracing();
2081        let node1_proxy_addr = next_local_address_with_port(9876);
2082        let proxy = Proxy::new([Node::new_dry_mode(node1_proxy_addr, None)]);
2083        let running_proxy = proxy.run().await.unwrap();
2084
2085        let params = FrameParams {
2086            flags: 0,
2087            version: 0x04,
2088            stream: 0,
2089        };
2090        let opcode = FrameOpcode::Request(RequestOpcode::Options);
2091
2092        let body = random_body();
2093
2094        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2095
2096        write_frame(params, opcode, &body, &mut conn).await.unwrap();
2097        // We assert that after sufficiently long time, no error happens inside proxy.
2098        tokio::time::sleep(Duration::from_millis(3)).await;
2099        running_proxy.finish().await.unwrap();
2100    }
2101
    /// Dry-mode counterpart of `forger_proxy_forges_response`: with no backing
    /// node, the proxy forges an EVENT response for REGISTER requests, a READY
    /// response for all other requests, and silently drops frames whose body
    /// contains the pass-through marker (there is no node to forward them to).
    #[tokio::test]
    #[ntest::timeout(1000)]
    async fn dry_mode_forger_proxy_forges_response() {
        setup_tracing();
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new_dry_mode(
            node1_proxy_addr,
            Some(vec![
                // Rule 1: forge an EVENT response to every REGISTER request.
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                // Rule 2: frames containing the magic marker get no reaction.
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                // Rule 3: forge a READY response for everything else.
                RequestRule(
                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
        )]);
        let running_proxy = proxy.run().await.unwrap();

        // Frame 1 (STARTUP): matches rule 3 -> forged READY.
        let params1 = FrameParams {
            flags: 3,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        // Frame 2 (REGISTER): matches rule 1 -> forged EVENT.
        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        // Frame 3 (EXECUTE, body contains the marker): matches rule 2 -> no response.
        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

        write_frame(params1, opcode1, &body1, &mut conn)
            .await
            .unwrap();
        write_frame(params2, opcode2, &body2, &mut conn)
            .await
            .unwrap();
        write_frame(params3, opcode3, &body3, &mut conn)
            .await
            .unwrap();

        // Expect the forged READY response for frame 1...
        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn).await.unwrap();
        assert_eq!(recvd_params, params1.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Ready);
        assert_eq!(recvd_body, Bytes::new());

        // ...and the forged EVENT response for frame 2. Frame 3 yields nothing.
        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn).await.unwrap();
        assert_eq!(recvd_params, params2.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Event);
        assert_eq!(recvd_body, Bytes::from_static(test_msg));

        running_proxy.finish().await.unwrap();

        // After shutdown the driver should see a clean EOF (read of 0 bytes).
        assert_matches!(conn.read(&mut [0u8; 1]).await, Ok(0));
    }
2206
2207    // The test asserts that once a (mock) driver connects to the proxy from some port,
2208    // the proxy will connect to a shard corresponding to that port and that the target
2209    // shard number will be sent through the feedback channel.
2210    #[tokio::test]
2211    #[ntest::timeout(1000)]
2212    async fn proxy_reports_target_shard_as_feedback() {
2213        setup_tracing();
2214
2215        let node_port = 10101;
2216        let node_real_addr = next_local_address_with_port(node_port);
2217        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
2218
2219        let params = FrameParams {
2220            flags: 0,
2221            version: 0x04,
2222            stream: 0,
2223        };
2224        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2225        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2226
2227        let body = random_body();
2228
2229        for shards_count in 2..9 {
2230            // Two driver connections are simulated, each to a different shard.
2231            let driver1_shard = shards_count - 1;
2232            let driver2_shard = shards_count - 2;
2233            let node_proxy_addr = next_local_address_with_port(node_port);
2234
2235            let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2236            let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2237
2238            let proxy = Proxy::new([Node::new(
2239                node_real_addr,
2240                node_proxy_addr,
2241                ShardAwareness::FixedNum(shards_count),
2242                Some(vec![RequestRule(
2243                    Condition::True,
2244                    RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2245                )]),
2246                Some(vec![ResponseRule(
2247                    Condition::True,
2248                    ResponseReaction::drop_frame()
2249                        .with_feedback_when_performed(response_feedback_tx),
2250                )]),
2251            )]);
2252            let running_proxy = proxy.run().await.unwrap();
2253
2254            /// Choose a source port `p` such that `shard == shard_of_source_port(p)`.
2255            fn draw_source_port_for_shard(shards_count: u16, shard: u16) -> u16 {
2256                assert!(shard < shards_count);
2257                (49152 + shards_count - 1) / shards_count * shards_count + shard
2258            }
2259
2260            async fn bind_socket_for_shard(shards_count: u16, shard: u16) -> TcpSocket {
2261                let socket = TcpSocket::new_v4().unwrap();
2262                let initial_port = draw_source_port_for_shard(shards_count, shard);
2263
2264                let mut desired_addr =
2265                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), initial_port);
2266                while socket.bind(desired_addr).is_err() {
2267                    // in search for a port that translates to the desired shard
2268                    let next_port = desired_addr.port().wrapping_add(shards_count);
2269                    if next_port == initial_port {
2270                        panic!("No more ports left");
2271                    }
2272                    desired_addr.set_port(next_port);
2273                }
2274
2275                socket
2276            }
2277
2278            let body_ref = &body;
2279            let send_frame_to_shard = |driver_shard: u16| async move {
2280                let socket = bind_socket_for_shard(shards_count, driver_shard).await;
2281                let mut conn = socket.connect(node_proxy_addr).await.unwrap();
2282
2283                write_frame(params, request_opcode, body_ref, &mut conn)
2284                    .await
2285                    .unwrap();
2286                conn
2287            };
2288
2289            let mock_driver1_action = send_frame_to_shard(driver1_shard);
2290            let mock_driver2_action = send_frame_to_shard(driver2_shard);
2291
2292            // Accepts two connections and sends a response to each of them.
2293            let mock_node_action = async {
2294                let mut conns_futs = (0..2)
2295                    .map(|_| async {
2296                        let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap();
2297                        respond_with_shard_num(&mut conn, driver_addr.port() % shards_count).await;
2298                        write_frame(params.for_response(), response_opcode, body_ref, &mut conn)
2299                            .await
2300                            .unwrap();
2301                        conn
2302                    })
2303                    .collect::<Vec<_>>();
2304                let conn2 = conns_futs.pop().unwrap().await;
2305                let conn1 = conns_futs.pop().unwrap().await;
2306                (conn1, conn2)
2307            };
2308
2309            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2310            let (_node_conns, _driver1_conn, _driver2_conn) =
2311                join3(mock_node_action, mock_driver1_action, mock_driver2_action).await;
2312
2313            let assert_feedback_request = |feedback_request: RequestFrame| {
2314                assert_eq!(feedback_request.params, params);
2315                assert_eq!(
2316                    FrameOpcode::Request(feedback_request.opcode),
2317                    request_opcode
2318                );
2319                assert_eq!(feedback_request.body, body);
2320            };
2321
2322            let assert_feedback_response = |feedback_response: ResponseFrame| {
2323                assert_eq!(feedback_response.params, params.for_response());
2324                assert_eq!(
2325                    FrameOpcode::Response(feedback_response.opcode),
2326                    response_opcode
2327                );
2328                assert_eq!(feedback_response.body, body);
2329            };
2330
2331            let (feedback_request, shard1) = request_feedback_rx.recv().await.unwrap();
2332            assert_feedback_request(feedback_request);
2333            let (feedback_request, shard2) = request_feedback_rx.recv().await.unwrap();
2334            assert_feedback_request(feedback_request);
2335            let (feedback_response, shard3) = response_feedback_rx.recv().await.unwrap();
2336            assert_feedback_response(feedback_response);
2337            let (feedback_response, shard4) = response_feedback_rx.recv().await.unwrap();
2338            assert_feedback_response(feedback_response);
2339
2340            // expected: {driver1_shard request, driver1_shard response, driver2_shard request, driver2_shard response}
2341            let mut expected_shards = [driver1_shard, driver1_shard, driver2_shard, driver2_shard];
2342            expected_shards.sort_unstable();
2343
2344            let mut got_shards = [
2345                shard1.unwrap(),
2346                shard2.unwrap(),
2347                shard3.unwrap(),
2348                shard4.unwrap(),
2349            ];
2350            got_shards.sort_unstable();
2351
2352            assert_eq!(expected_shards, got_shards);
2353
2354            running_proxy.finish().await.unwrap();
2355        }
2356    }
2357
2358    #[tokio::test]
2359    #[ntest::timeout(1000)]
2360    async fn proxy_ignores_control_connection_messages() {
2361        setup_tracing();
2362        let node1_real_addr = next_local_address_with_port(9876);
2363        let node1_proxy_addr = next_local_address_with_port(9876);
2364
2365        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2366        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2367        let proxy = Proxy::new([Node::new(
2368            node1_real_addr,
2369            node1_proxy_addr,
2370            ShardAwareness::Unaware,
2371            Some(vec![RequestRule(
2372                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2373                RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2374            )]),
2375            Some(vec![ResponseRule(
2376                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2377                ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2378            )]),
2379        )]);
2380        let running_proxy = proxy.run().await.unwrap();
2381
2382        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2383
2384        let (mut client_socket, mut server_socket) = join(
2385            async { TcpStream::connect(node1_proxy_addr).await.unwrap() },
2386            async { mock_node_listener.accept().await.unwrap().0 },
2387        )
2388        .await;
2389
2390        async fn perform_reqest_response<'a>(
2391            req_opcode: RequestOpcode,
2392            resp_opcode: ResponseOpcode,
2393            client_socket_ref: &'a mut TcpStream,
2394            server_socket_ref: &'a mut TcpStream,
2395            body_base: &'a str,
2396        ) {
2397            let params = FrameParams {
2398                flags: 0,
2399                version: 0x04,
2400                stream: 0,
2401            };
2402
2403            write_frame(
2404                params,
2405                FrameOpcode::Request(req_opcode),
2406                &(body_base.to_string() + "|request|").into(),
2407                client_socket_ref,
2408            )
2409            .await
2410            .unwrap();
2411
2412            let received_request = read_frame(server_socket_ref, FrameType::Request)
2413                .await
2414                .unwrap();
2415            assert_eq!(received_request.1, FrameOpcode::Request(req_opcode));
2416
2417            write_frame(
2418                params.for_response(),
2419                FrameOpcode::Response(resp_opcode),
2420                &(body_base.to_string() + "|response|").into(),
2421                server_socket_ref,
2422            )
2423            .await
2424            .unwrap();
2425
2426            let received_response = read_frame(client_socket_ref, FrameType::Response)
2427                .await
2428                .unwrap();
2429            assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode));
2430        }
2431
2432        // Messages before REGISTER should be fed back to channels
2433        for i in 0..5 {
2434            perform_reqest_response(
2435                RequestOpcode::Query,
2436                ResponseOpcode::Result,
2437                &mut client_socket,
2438                &mut server_socket,
2439                &format!("message_before_{i}"),
2440            )
2441            .await
2442        }
2443
2444        perform_reqest_response(
2445            RequestOpcode::Register,
2446            ResponseOpcode::Result,
2447            &mut client_socket,
2448            &mut server_socket,
2449            "message_register",
2450        )
2451        .await;
2452
2453        // Messages after REGISTER should be passed through without feedback
2454        for i in 0..5 {
2455            perform_reqest_response(
2456                RequestOpcode::Query,
2457                ResponseOpcode::Result,
2458                &mut client_socket,
2459                &mut server_socket,
2460                &format!("message_after_{i}"),
2461            )
2462            .await
2463        }
2464
2465        running_proxy.finish().await.unwrap();
2466
2467        for _ in 0..5 {
2468            let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2469            assert_eq!(feedback_request.opcode, RequestOpcode::Query);
2470            let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2471            assert_eq!(feedback_response.opcode, ResponseOpcode::Result);
2472        }
2473
2474        // Response to REGISTER and further requests / responses should be ignored
2475        let _ = request_feedback_rx.try_recv().unwrap_err();
2476        let _ = response_feedback_rx.try_recv().unwrap_err();
2477    }
2478}