Skip to main content

scylla_proxy/
proxy.rs

use crate::actions::{EvaluationContext, RequestRule, ResponseRule};
use crate::errors::{DoorkeeperError, ProxyError, WorkerError};
use crate::frame::{
    self, FrameOpcode, FrameParams, RequestFrame, ResponseFrame, ResponseOpcode,
    read_response_frame, write_frame,
};
use crate::{RequestOpcode, TargetShard};
use bytes::Bytes;
use compression::no_compression;
use scylla_cql::frame::types::read_string_multimap;
use std::collections::HashMap;
use std::env::VarError;
use std::fmt::Display;
use std::future::Future;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicU32, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio::sync::mpsc::error::TryRecvError;
use tokio::sync::{Notify, broadcast, mpsc};
use tracing::{debug, error, info, trace, warn};
23
24// Used to notify the user that the proxy finished - this happens when all Senders are dropped.
25type FinishWaiter = mpsc::Receiver<()>;
26type FinishGuard = mpsc::Sender<()>;
27
28// Used to tell all the proxy workers to stop when the user requests that with [RunningProxy::finish()].
29type TerminateNotifier = tokio::sync::broadcast::Receiver<()>;
30type TerminateSignaler = tokio::sync::broadcast::Sender<()>;
31
32// Used to tell all proxy workers working on same connection to stop when
33// a rule being applied has connection drop set.
34type ConnectionCloseNotifier = tokio::sync::broadcast::Receiver<()>;
35type ConnectionCloseSignaler = tokio::sync::broadcast::Sender<()>;
36
37// Used to gather errors from all proxy workers and propagate them to the proxy user,
38// returning the first of them from [RunningProxy::finish()].
39type ErrorPropagator = mpsc::UnboundedSender<ProxyError>;
40type ErrorSink = mpsc::UnboundedReceiver<ProxyError>;
41
42/// Tracks the number of active driver connections to a proxy node.
43///
44/// Shared between the [`Doorkeeper`] (which increments on accept) and
45/// [`RunningNode`] (which exposes [`RunningNode::wait_for_connection`]).
46/// Each accepted connection gets an [`Arc<ConnectionLifetime>`] guard that
47/// decrements the count when the last worker for that connection exits.
48struct ConnectionTracker {
49    active_count: std::sync::atomic::AtomicUsize,
50    notify: tokio::sync::Notify,
51}
52
53impl ConnectionTracker {
54    /// Track a new connection.
55    ///
56    /// Increments active count and creates a per-connection lifetime guard.
57    /// When the guard ([ConnectionLifetime]) is Drop'ed,
58    /// the active count is decremented.
59    fn register_connection(self: &Arc<Self>) -> Arc<ConnectionLifetime> {
60        self.active_count.fetch_add(1, Ordering::Relaxed);
61        self.notify.notify_waiters();
62        Arc::new(ConnectionLifetime {
63            tracker: Arc::clone(self),
64        })
65    }
66}
67
/// Per-connection guard that decrements the active connection count on drop.
///
/// Wrapped in [`Arc`] and cloned to each worker task spawned for a connection.
/// When the last worker finishes and drops its clone, the inner value drops
/// and the active count is decremented.
struct ConnectionLifetime {
    // Back-reference to the tracker whose count this guard decrements on drop.
    tracker: Arc<ConnectionTracker>,
}
76
impl Drop for ConnectionLifetime {
    fn drop(&mut self) {
        // Undo the increment done in `ConnectionTracker::register_connection`.
        // Relaxed matches the Relaxed increment and load used elsewhere; the
        // counter is not used to synchronize any other memory.
        self.tracker.active_count.fetch_sub(1, Ordering::Relaxed);
    }
}
82
// Frame params for a proxy-initiated OPTIONS request: CQL v4 (0x04) request
// frame, no flags, stream id 0.
// NOTE(review): the consumer is outside this chunk (the name suggests the
// shard-count OPTIONS query) — confirm before relying on this description.
static HARDCODED_OPTIONS_PARAMS: FrameParams = FrameParams {
    flags: 0,
    version: 0x04,
    stream: 0,
};
88
/// Specifies proxy's behaviour regarding shard awareness.
#[derive(Clone, Copy, Debug)]
pub enum ShardAwareness {
    /// Acts as if the connection was made to the shard-unaware port.
    Unaware,
    /// The first time the driver attempts to connect to the particular node (through proxy),
    /// the related node is first queried on a temporary connection for its number of shards,
    /// and only then establishes another connection for the driver's real communication with the node.
    /// If the queried node does not provide sharding info (e.g. in case of a Cassandra node),
    /// then this mode behaves as Unaware.
    QueryNode,
    /// Binds to the port that is the same as the driver's port modulo the provided number of shards.
    FixedNum(u16),
}

impl ShardAwareness {
    /// Returns `true` for every mode except [`ShardAwareness::Unaware`].
    pub fn is_aware(&self) -> bool {
        match self {
            Self::Unaware => false,
            Self::QueryNode | Self::FixedNum(_) => true,
        }
    }
}
109
/// Node can be either Real (truly backed by a Scylla node) or Simulated
/// (driver believes it's real, but we merely simulate it with the proxy).
/// In Simulated mode, no node address is provided and proxy does not attempt
/// to establish connection with a Scylla node.
///
/// For Real node, all workers are created, so such frame flow is possible:
/// [driver] -> receiver_from_driver -> requests_processor -> sender_to_cluster -> [node] (ordinary request flow)
///                                        |                    /\
///     (forging response) ++--------------+      +-------------++ (forging request)
///                        \/                     |
/// [driver] <- sender_to_driver <- response_processor <- receiver_from_cluster <- [node] (ordinary response flow)
///
/// For Simulated node, it looks like this:
/// [driver] -> receiver_from_driver -> requests_processor -+
///                                                         |   (forging response)
/// [driver] <- sender_to_driver <--------------------------+
///
/// For Real node, the default reaction to a frame is to pass it to its intended addressee.
/// For Simulated node, the default reaction to a request is to drop it.
enum NodeType {
    Real {
        // Address of the backing Scylla node.
        real_addr: SocketAddr,
        // How (and whether) the proxy mimics shard-aware ports.
        shard_awareness: ShardAwareness,
        // Rules applied to responses from the node; `None` means no rules.
        response_rules: Option<Vec<ResponseRule>>,
    },
    Simulated,
}
137
/// A node the proxy exposes on `proxy_addr`; behaviour depends on
/// [`NodeType`] (pass-through to a real node vs. fully simulated).
pub struct Node {
    // Address the proxy listens on for driver connections to this node.
    proxy_addr: SocketAddr,
    // Rules applied to requests from the driver; `None` means no rules.
    request_rules: Option<Vec<RequestRule>>,
    node_type: NodeType,
}
143
144impl Node {
145    /// Creates an abstract node that is backed by a real Scylla node.
146    pub fn new(
147        real_addr: SocketAddr,
148        proxy_addr: SocketAddr,
149        shard_awareness: ShardAwareness,
150        request_rules: Option<Vec<RequestRule>>,
151        response_rules: Option<Vec<ResponseRule>>,
152    ) -> Self {
153        Self {
154            proxy_addr,
155            request_rules,
156            node_type: NodeType::Real {
157                real_addr,
158                shard_awareness,
159                response_rules,
160            },
161        }
162    }
163
164    /// Creates a simulated node that is not backed by any real Scylla node.
165    pub fn new_dry_mode(proxy_addr: SocketAddr, request_rules: Option<Vec<RequestRule>>) -> Self {
166        Self {
167            proxy_addr,
168            request_rules,
169            node_type: NodeType::Simulated,
170        }
171    }
172
173    pub fn builder() -> NodeBuilder {
174        NodeBuilder {
175            real_addr: None,
176            proxy_addr: None,
177            shard_awareness: None,
178            request_rules: None,
179            response_rules: None,
180        }
181    }
182}
183
/// Builder for [`Node`]; see [`NodeBuilder::build`] / [`NodeBuilder::build_dry_mode`]
/// for which fields are mandatory in each mode.
pub struct NodeBuilder {
    real_addr: Option<SocketAddr>,
    proxy_addr: Option<SocketAddr>,
    shard_awareness: Option<ShardAwareness>,
    request_rules: Option<Vec<RequestRule>>,
    response_rules: Option<Vec<ResponseRule>>,
}
191
192impl NodeBuilder {
193    pub fn real_address(mut self, real_addr: SocketAddr) -> Self {
194        self.real_addr = Some(real_addr);
195        self
196    }
197
198    pub fn proxy_address(mut self, proxy_addr: SocketAddr) -> Self {
199        self.proxy_addr = Some(proxy_addr);
200        self
201    }
202
203    pub fn shard_awareness(mut self, shard_awareness: ShardAwareness) -> Self {
204        self.shard_awareness = Some(shard_awareness);
205        self
206    }
207
208    pub fn request_rules(mut self, request_rules: Vec<RequestRule>) -> Self {
209        self.request_rules = Some(request_rules);
210        self
211    }
212
213    pub fn response_rules(mut self, response_rules: Vec<ResponseRule>) -> Self {
214        self.response_rules = Some(response_rules);
215        self
216    }
217
218    /// Creates an abstract node that is backed by a real Scylla node.
219    pub fn build(self) -> Node {
220        Node {
221            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
222            request_rules: self.request_rules,
223            node_type: NodeType::Real {
224                real_addr: self.real_addr.expect("Real addr is required!"),
225                shard_awareness: self.shard_awareness.expect("Shard awareness is required!"),
226                response_rules: self.response_rules,
227            },
228        }
229    }
230
231    /// Creates a simulated node that is not backed by any real Scylla node.
232    pub fn build_dry_mode(self) -> Node {
233        Node {
234            proxy_addr: self.proxy_addr.expect("Proxy addr is required!"),
235            request_rules: self.request_rules,
236            node_type: NodeType::Simulated,
237        }
238    }
239}
240
/// Logging helper: renders an optional real-node address, printing
/// "<dry mode>" when there is no backing node.
#[derive(Clone, Copy)]
struct DisplayableRealAddrOption(Option<SocketAddr>);

impl Display for DisplayableRealAddrOption {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.0 {
            Some(addr) => write!(f, "{addr}"),
            None => write!(f, "<dry mode>"),
        }
    }
}
252
253#[derive(Clone, Copy)]
254struct DisplayableShard(Option<TargetShard>);
255impl Display for DisplayableShard {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        if let Some(shard) = self.0 {
258            write!(f, "shard {shard}")
259        } else {
260            write!(f, "unknown shard")
261        }
262    }
263}
264
/// Internal counterpart of [`Node`]: rule lists are wrapped in
/// `Arc<Mutex<...>>` so that the proxy workers and the user's
/// [`RunningNode`] handle can share and mutate them concurrently.
enum InternalNode {
    Real {
        real_addr: SocketAddr,
        proxy_addr: SocketAddr,
        shard_awareness: ShardAwareness,
        request_rules: Arc<Mutex<Vec<RequestRule>>>,
        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
    },
    Simulated {
        proxy_addr: SocketAddr,
        request_rules: Arc<Mutex<Vec<RequestRule>>>,
    },
}
278
279impl InternalNode {
280    fn proxy_addr(&self) -> SocketAddr {
281        match *self {
282            InternalNode::Real { proxy_addr, .. } => proxy_addr,
283            InternalNode::Simulated { proxy_addr, .. } => proxy_addr,
284        }
285    }
286    fn real_addr(&self) -> Option<SocketAddr> {
287        match *self {
288            InternalNode::Real { real_addr, .. } => Some(real_addr),
289            InternalNode::Simulated { .. } => None,
290        }
291    }
292    fn request_rules(&self) -> &Arc<Mutex<Vec<RequestRule>>> {
293        match self {
294            InternalNode::Real { request_rules, .. } => request_rules,
295            InternalNode::Simulated { request_rules, .. } => request_rules,
296        }
297    }
298}
299
300impl From<Node> for InternalNode {
301    fn from(node: Node) -> Self {
302        match node.node_type {
303            NodeType::Real {
304                real_addr,
305                shard_awareness,
306                response_rules,
307            } => InternalNode::Real {
308                real_addr,
309                proxy_addr: node.proxy_addr,
310                shard_awareness,
311                request_rules: node
312                    .request_rules
313                    .map(|rules| Arc::new(Mutex::new(rules)))
314                    .unwrap_or_default(),
315                response_rules: response_rules
316                    .map(|rules| Arc::new(Mutex::new(rules)))
317                    .unwrap_or_default(),
318            },
319            NodeType::Simulated => InternalNode::Simulated {
320                proxy_addr: node.proxy_addr,
321                request_rules: node
322                    .request_rules
323                    .map(|rules| Arc::new(Mutex::new(rules)))
324                    .unwrap_or_default(),
325            },
326        }
327    }
328}
329
/// Builder collecting [`Node`]s before constructing a [`Proxy`].
pub struct ProxyBuilder {
    nodes: Vec<Node>,
}
333
334impl ProxyBuilder {
335    pub fn with_node(mut self, node: Node) -> ProxyBuilder {
336        self.nodes.push(node);
337        self
338    }
339
340    pub fn build(self) -> Proxy {
341        Proxy::new(self.nodes)
342    }
343}
344
/// The proxy itself: the set of (real or simulated) nodes it fronts.
pub struct Proxy {
    nodes: Vec<InternalNode>,
}
348
impl Proxy {
    /// Creates a proxy over the given nodes, converting each to its
    /// internal (shared-rules) representation.
    pub fn new(nodes: impl IntoIterator<Item = Node>) -> Self {
        Proxy {
            nodes: nodes.into_iter().map(|node| node.into()).collect(),
        }
    }

    /// Returns an empty [`ProxyBuilder`].
    pub fn builder() -> ProxyBuilder {
        ProxyBuilder { nodes: vec![] }
    }

    /// Build a translation map based on provided proxy and node addresses.
    /// The map can be passed to `Session` `address_translator()` to ensure
    /// that the driver contacts the nodes through the proxy (and not directly).
    pub fn translation_map(&self) -> HashMap<SocketAddr, SocketAddr> {
        let mut translation_map = HashMap::new();
        for node in self.nodes.iter() {
            if let &InternalNode::Real {
                real_addr,
                proxy_addr,
                ..
            } = node
            {
                translation_map.insert(real_addr, proxy_addr);
                // Also map the node's shard-aware address (19042 is ScyllaDB's
                // conventional shard-aware port) to the same proxy address.
                let shard_aware_real_addr = SocketAddr::new(real_addr.ip(), 19042);
                translation_map.insert(shard_aware_real_addr, proxy_addr);
            }
        }
        translation_map
    }

    /// Runs the [Proxy], i.e. makes it ready for accepting drivers' connections.
    /// Returns a [RunningProxy] handle that can be used to stop the proxy or change the rules.
    pub async fn run(self) -> Result<RunningProxy, DoorkeeperError> {
        // Workers hold clones of `finish_guard`; once every clone is dropped,
        // `finish_waiter.recv()` returns `None`, signalling "all finished".
        let (terminate_signaler, _t) = tokio::sync::broadcast::channel(1);
        let (finish_guard, finish_waiter) = mpsc::channel(1);

        let (error_propagator, error_sink) = mpsc::unbounded_channel();
        // For each node: spawn its doorkeeper and build the user's RunningNode
        // handle; both sides share the same rule lists, tracker and event map.
        let (doorkeepers, running_nodes): (Vec<_>, Vec<RunningNode>) = self
            .nodes
            .into_iter()
            .map(|node| {
                let connection_tracker = Arc::new(ConnectionTracker {
                    active_count: std::sync::atomic::AtomicUsize::new(0),
                    notify: tokio::sync::Notify::new(),
                });
                let cc_event_sender = Arc::new(Mutex::new(HashMap::new()));
                let running = {
                    // Simulated nodes never see responses, hence no response rules.
                    let (request_rules, response_rules) = match node {
                        InternalNode::Real {
                            ref request_rules,
                            ref response_rules,
                            ..
                        } => (request_rules, Some(response_rules)),
                        InternalNode::Simulated {
                            ref request_rules, ..
                        } => (request_rules, None),
                    };
                    RunningNode {
                        request_rules: request_rules.clone(),
                        response_rules: response_rules.cloned(),
                        connection_tracker: connection_tracker.clone(),
                        cc_event_sender: cc_event_sender.clone(),
                    }
                };
                (
                    Doorkeeper::spawn(
                        node,
                        terminate_signaler.clone(),
                        finish_guard.clone(),
                        error_propagator.clone(),
                        connection_tracker,
                        cc_event_sender,
                    ),
                    running,
                )
            })
            .unzip();

        for doorkeeper in doorkeepers {
            doorkeeper.await?; // await doorkeeper creation, including binding to a socket
        }

        Ok(RunningProxy {
            terminate_signaler,
            finish_waiter,
            running_nodes,
            error_sink,
        })
    }
}
440
/// A handle that can be used to change the rules regarding the particular node.
pub struct RunningNode {
    // Shared with the node's workers; mutex-guarded so rules can be
    // swapped at runtime.
    request_rules: Arc<Mutex<Vec<RequestRule>>>,
    // `None` for simulated nodes, which never process responses.
    response_rules: Option<Arc<Mutex<Vec<ResponseRule>>>>,
    // Shared with the doorkeeper; backs `wait_for_connection`.
    connection_tracker: Arc<ConnectionTracker>,

    /// Senders to the driver-facing sockets of all control connections (those
    /// that sent REGISTER). Keyed by `connection_no` so that each connection
    /// can be individually removed on close without affecting others.
    ///
    /// Populated by `request_processor` when it sees a REGISTER frame,
    /// entries removed when the corresponding connection closes.
    ///
    /// Used by [`inject_event_to_cc`](Self::inject_event_to_cc) to push
    /// unsolicited EVENT frames to every registered control connection.
    cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
}
458
459impl RunningNode {
460    /// Replaces the previous request rules with the new ones.
461    pub fn change_request_rules(&mut self, rules: Option<Vec<RequestRule>>) {
462        *self.request_rules.lock().unwrap() = rules.unwrap_or_default();
463    }
464
465    /// Adds new request rules to the end of the list (so with lowest priority)
466    pub fn append_request_rules(&mut self, mut rules: Vec<RequestRule>) {
467        self.request_rules.lock().unwrap().append(&mut rules);
468    }
469
470    /// Adds new request rules to the beginning of the list (so with highest priority)
471    pub fn prepend_request_rules(&mut self, rules: Vec<RequestRule>) {
472        let mut new_rules = rules;
473        let mut old_rules_guard = self.request_rules.lock().unwrap();
474        new_rules.append(&mut *old_rules_guard);
475        *old_rules_guard = new_rules;
476    }
477
478    /// Replaces the previous response rules with the new ones.
479    pub fn change_response_rules(&mut self, rules: Option<Vec<ResponseRule>>) {
480        *self
481            .response_rules
482            .as_ref()
483            .expect("No response rules on a simulated node!")
484            .lock()
485            .unwrap() = rules.unwrap_or_default();
486    }
487
488    /// Adds new response rules to the end of the list (so with lowest priority)
489    pub fn append_response_rules(&mut self, mut rules: Vec<ResponseRule>) {
490        self.response_rules
491            .as_ref()
492            .expect("No response rules on a simulated node!")
493            .lock()
494            .unwrap()
495            .append(&mut rules);
496    }
497
498    /// Adds new response rules to the beginning of the list (so with highest priority)
499    pub fn prepend_response_rules(&mut self, rules: Vec<ResponseRule>) {
500        let mut old_rules_guard = self
501            .response_rules
502            .as_ref()
503            .expect("No response rules on a simulated node!")
504            .lock()
505            .unwrap();
506        let mut new_rules = rules;
507        new_rules.append(&mut *old_rules_guard);
508        *old_rules_guard = new_rules;
509    }
510
511    /// Waits until at least one driver connection is active on this node.
512    ///
513    /// Returns immediately if there is already an active connection.
514    /// Otherwise blocks until a new connection is accepted by the
515    /// node's doorkeeper.
516    pub async fn wait_for_connection(&self) {
517        loop {
518            // Prepare the notification future BEFORE checking the count
519            // to avoid a race where a connection arrives between the check
520            // and the await.
521            let notified = self.connection_tracker.notify.notified();
522            if self.connection_tracker.active_count.load(Ordering::Relaxed) > 0 {
523                return;
524            }
525            notified.await;
526        }
527    }
528
529    /// Injects a CQL EVENT frame into all registered control connections.
530    ///
531    /// Builds an EVENT frame (stream = −1, flags = 0, opcode = Event) with the
532    /// supplied `body` and sends it to every driver that has sent REGISTER on
533    /// this node. Dead senders (closed connections) are pruned automatically.
534    ///
535    /// Returns `true` if the frame was successfully enqueued to at least one
536    /// control connection, `false` if no control connections are currently
537    /// registered or all sends failed.
538    pub fn inject_event_to_cc(&self, body: Bytes) -> bool {
539        let mut guard = self.cc_event_sender.lock().unwrap();
540        if guard.is_empty() {
541            return false;
542        }
543        let mut any_sent = false;
544        guard.retain(|_conn_no, tx| {
545            let frame = ResponseFrame {
546                params: FrameParams {
547                    version: 4,
548                    flags: 0,
549                    stream: -1,
550                }
551                .for_response(),
552                opcode: ResponseOpcode::Event,
553                body: body.clone(),
554            };
555            let ok = tx.send(frame).is_ok();
556            any_sent |= ok;
557            // Remove dead senders (connection already closed on the
558            // receiving end).
559            ok
560        });
561        any_sent
562    }
563}
564
/// A handle that can be used to stop the proxy or change the rules.
pub struct RunningProxy {
    // Broadcasts the stop signal to every worker in `finish()`.
    terminate_signaler: TerminateSignaler,
    // Reports closure once every worker has dropped its finish guard.
    finish_waiter: FinishWaiter,
    // One handle per node, in the same order the nodes were supplied.
    pub running_nodes: Vec<RunningNode>,
    // Receives the errors propagated by workers.
    error_sink: ErrorSink,
}
572
impl RunningProxy {
    /// Disables all the rules in the proxy, effectively making it a pass-through-only proxy.
    pub fn turn_off_rules(&mut self) {
        for (request_rules, response_rules) in self
            .running_nodes
            .iter_mut()
            .map(|node| (&node.request_rules, &node.response_rules))
        {
            request_rules.lock().unwrap().clear();
            // Simulated nodes have no response rules to clear.
            if let Some(response_rules) = response_rules {
                response_rules.lock().unwrap().clear();
            }
        }
    }

    /// Attempts to fetch the first error that has occurred in proxy since last check.
    /// If no errors occurred, returns Ok(()).
    pub fn sanity_check(&mut self) -> Result<(), ProxyError> {
        match self.error_sink.try_recv() {
            Ok(err) => Err(err),
            Err(TryRecvError::Empty) => Ok(()),
            Err(TryRecvError::Disconnected) => {
                // As we haven't awaited finish of all workers yet, there must be a faulty case without proper error handling.
                Err(ProxyError::SanityCheckFailure)
            }
        }
    }

    /// Waits until an error occurs in proxy.
    /// Returns `None` if the proxy finishes (all workers drop their error
    /// senders) without any error having occurred.
    pub async fn wait_for_error(&mut self) -> Option<ProxyError> {
        self.error_sink.recv().await
    }

    /// Requests termination of all proxy workers and awaits its completion.
    /// Returns the first error that occurred in proxy.
    pub async fn finish(mut self) -> Result<(), ProxyError> {
        self.terminate_signaler.send(()).map_err(|err| {
            ProxyError::AwaitFinishFailure(format!(
                "Send error in terminate_signaler: {err} (bug!)"
            ))
        })?;
        info!("Sent finish signal to proxy workers.");

        // This to make sure that also workers not-yet-spawned when terminate signal was sent will terminate.
        std::mem::drop(self.terminate_signaler);

        // `Some` would mean some worker actually sent on the finish channel;
        // the guards appear to be drop-only (TODO confirm outside this view),
        // so recv() is expected to yield `None` once all guards are dropped.
        if self.finish_waiter.recv().await.is_some() {
            unreachable!();
        };
        info!("All workers have finished.");

        match self.error_sink.try_recv() {
            Ok(err) => Err(err),
            // Disconnected: every worker dropped its error sender without
            // reporting anything - a clean shutdown.
            Err(TryRecvError::Disconnected) => Ok(()),
            Err(TryRecvError::Empty) => {
                // As we have already awaited finish of all workers, there must be a logic bug.
                unreachable!("Worker await logic bug!");
            }
        }
    }

    /// Waits until at least one driver connection is active on any node
    /// in this proxy.
    ///
    /// For single-node proxies (typical in client-routes tests), this waits
    /// until a driver has connected to the sole node. For multi-node
    /// proxies, returns as soon as any node receives a connection.
    pub async fn wait_for_connection(&self) {
        // Build a future for each node and race them.
        futures::future::select_all(
            self.running_nodes
                .iter()
                .map(|n| Box::pin(n.wait_for_connection())),
        )
        .await;
    }
}
650
/// A worker corresponding to a particular node. It listens in a loop for driver's connections
/// on specified proxy bind address, respects ports regarding advanced shard-awareness (if set),
/// to this end obtaining number of shards from the node (if set), then establishes connection
/// to the node, spawns workers for this connection and continues to listen.
struct Doorkeeper {
    node: InternalNode,
    // Listener bound to the node's proxy address (bound in `spawn`).
    listener: TcpListener,
    terminate_signaler: TerminateSignaler,
    finish_guard: FinishGuard,
    // `None` until `update_shards_count` runs, and when shard-unaware.
    shards_count: Option<u16>,
    error_propagator: ErrorPropagator,
    // Shared with the node's `RunningNode` handle (backs `wait_for_connection`).
    connection_tracker: Arc<ConnectionTracker>,
    // Shared registry of control-connection event senders (see `RunningNode`).
    cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
}
665
666impl Doorkeeper {
    /// Binds a listener on the node's proxy address and spawns the
    /// doorkeeper's accept loop as a detached background task.
    ///
    /// Returns once binding succeeded (or failed), so callers can await
    /// socket creation before considering the proxy "running".
    async fn spawn(
        node: InternalNode,
        terminate_signaler: TerminateSignaler,
        finish_guard: FinishGuard,
        error_propagator: ErrorPropagator,
        connection_tracker: Arc<ConnectionTracker>,
        cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
    ) -> Result<(), DoorkeeperError> {
        // NOTE(review): this is a bind failure, yet it is reported as
        // `DriverConnectionAttempt` - confirm that variant is intended here.
        let listener = TcpListener::bind(node.proxy_addr())
            .await
            .map_err(|err| DoorkeeperError::DriverConnectionAttempt(node.proxy_addr(), err))?;

        if let InternalNode::Real {
            shard_awareness,
            real_addr,
            ..
        } = node
        {
            info!(
                "Spawned a {} doorkeeper for pair real:{} - proxy:{}.",
                if shard_awareness.is_aware() {
                    "shard-aware"
                } else {
                    "shard-unaware"
                },
                real_addr,
                node.proxy_addr(),
            );
        } else {
            info!(
                "Spawned a dry-mode doorkeeper for proxy:{}.",
                node.proxy_addr(),
            )
        };

        let doorkeeper = Doorkeeper {
            shards_count: None, // temporarily, until Doorkeeper examines its ShardAwareness
            node,
            listener,
            terminate_signaler,
            finish_guard,
            error_propagator,
            connection_tracker,
            cc_event_sender,
        };
        // The accept loop runs detached; its errors are reported through
        // `error_propagator` rather than through this handle.
        tokio::task::spawn(doorkeeper.run());
        Ok(())
    }
715
    /// The doorkeeper's main loop: accepts driver connections until either
    /// an accept error occurs or proxy termination is signalled.
    async fn run(mut self) {
        self.update_shards_count().await;
        let mut own_terminate_notifier = self.terminate_signaler.subscribe();
        // Channel used to tell the workers of a connection to stop when a
        // rule with connection-drop fires (see ConnectionCloseNotifier).
        let (connection_close_tx, _connection_close_rx) = broadcast::channel::<()>(2);
        // Monotonically increasing per-connection id (keys e.g. the
        // control-connection event sender map).
        let mut connection_no: usize = 0;
        loop {
            tokio::select! {
                res = self.accept_connection(&connection_close_tx, connection_no) => {
                    match res {
                        Ok(()) => connection_no += 1,
                        Err(err) => {
                            error!(
                                "Error in doorkeeper with addr {} for node {}: {}",
                                self.node.proxy_addr(),
                                DisplayableRealAddrOption(self.node.real_addr()),
                                err
                            );
                            // Report to the proxy user; the send result is
                            // ignored in case the user dropped the sink.
                            let _ = self.error_propagator.send(err.into());
                            break;
                        },
                    }
                },
                _terminate = own_terminate_notifier.recv() => break
            }
        }
        debug!(
            "Doorkeeper exits: proxy {}, node {}.",
            self.node.proxy_addr(),
            DisplayableRealAddrOption(self.node.real_addr())
        );
    }
747
    /// Resolves `self.shards_count` according to the node's [`ShardAwareness`].
    ///
    /// No-op for simulated nodes. For `QueryNode`, asks the real node for its
    /// shard count; on any failure it falls back to unaware behaviour (`None`).
    async fn update_shards_count(&mut self) {
        if let InternalNode::Real {
            real_addr,
            shard_awareness,
            ..
        } = self.node
        {
            self.shards_count = match shard_awareness {
                ShardAwareness::Unaware => None,
                ShardAwareness::FixedNum(shards_num) => Some(shards_num),
                ShardAwareness::QueryNode => match self.obtain_shards_count(real_addr).await {
                    Ok(shards) => Some(shards),
                    // If a node offers no sharding info, change proxy ShardAwareness to Unaware.
                    Err(DoorkeeperError::ObtainingShardNumberNoShardInfo) => {
                        info!(
                            "Doorkeeper with addr {} found no shard info in node {}; falling back to ShardAwareness::Unaware",
                            self.node.proxy_addr(),
                            DisplayableRealAddrOption(self.node.real_addr()),
                        );
                        None
                    }
                    // Any other error while querying: log it and likewise
                    // fall back to unaware mode.
                    Err(e) => {
                        error!(
                            "Error in doorkeeper with addr {} while querying shard info from node {}: {}",
                            self.node.proxy_addr(),
                            DisplayableRealAddrOption(self.node.real_addr()),
                            e
                        );
                        None
                    }
                },
            }
        }
    }
782
    /// Spawns the set of worker tasks that service a single proxied connection.
    ///
    /// Always spawned (both real and simulated nodes):
    /// - `receiver_from_driver`: reads request frames off the driver socket,
    /// - `sender_to_driver`: writes response frames back to the driver,
    /// - `request_processor`: applies request rules and routes frames.
    ///
    /// Additionally, only for a real node (which has a cluster-facing stream):
    /// - `sender_to_cluster`: writes request frames to the node,
    /// - `receiver_from_cluster`: reads response frames off the node socket,
    /// - `response_processor`: applies response rules and routes frames.
    async fn spawn_workers(
        &mut self,
        driver_addr: SocketAddr,
        connection_close_tx: &ConnectionCloseSignaler,
        connection_no: usize,
        driver_stream: TcpStream,
        cluster_stream: Option<TcpStream>,
        shard: Option<TargetShard>,
    ) {
        // Track a new connection: increment active count and create a
        // per-connection lifetime guard. The guard is wrapped in Arc and
        // cloned to each spawned worker task. When the last worker for this
        // connection exits, the Arc's inner ConnectionLifetime drops and
        // decrements the active count.
        let conn_lifetime = self.connection_tracker.register_connection();

        let (driver_read, driver_write) = driver_stream.into_split();

        // Each call builds a fresh worker handle: every worker gets its own
        // subscriptions to the terminate/connection-close broadcasts and its
        // own clone of the finish guard and error propagator.
        let new_worker = || ProxyWorker {
            terminate_notifier: self.terminate_signaler.subscribe(),
            finish_guard: self.finish_guard.clone(),
            connection_close_notifier: connection_close_tx.subscribe(),
            error_propagator: self.error_propagator.clone(),
            driver_addr,
            real_addr: self.node.real_addr(),
            proxy_addr: self.node.proxy_addr(),
            shard,
        };

        // Channel topology between the workers of this connection:
        //   receiver_from_driver  --tx_request-->  request_processor
        //   receiver_from_cluster --tx_response--> response_processor
        //   processors            --tx_cluster-->  sender_to_cluster
        //   processors            --tx_driver-->   sender_to_driver
        let (tx_request, rx_request) = mpsc::unbounded_channel::<RequestFrame>();
        let (tx_response, rx_response) = mpsc::unbounded_channel::<ResponseFrame>();
        let (tx_cluster, rx_cluster) = mpsc::unbounded_channel::<RequestFrame>();
        let (tx_driver, rx_driver) = mpsc::unbounded_channel::<ResponseFrame>();
        // Set to true by the request processor once a REGISTER frame is seen.
        let event_register_flag = Arc::new(AtomicBool::new(false));

        // One shared compression slot per connection: a single writer (the
        // request processor sets it from the STARTUP frame) and one reader
        // per frame-(de)serializing worker.
        let (
            compression_writer_request_processor,
            compression_reader_receiver_from_driver,
            compression_reader_receiver_from_cluster,
            compression_reader_sender_to_driver,
            compression_reader_sender_to_cluster,
        ) = compression::make_compression_infra();

        // Worker: driver socket -> request channel.
        {
            let guard = Arc::clone(&conn_lifetime);
            let worker = new_worker();
            tokio::task::spawn(async move {
                let _conn = guard;
                worker
                    .receiver_from_driver(
                        driver_read,
                        tx_request,
                        compression_reader_receiver_from_driver,
                    )
                    .await;
            });
        }
        // Worker: response channel -> driver socket.
        {
            let guard = Arc::clone(&conn_lifetime);
            let worker = new_worker();
            let conn_close_sub = connection_close_tx.subscribe();
            let term_sub = self.terminate_signaler.subscribe();
            tokio::task::spawn(async move {
                let _conn = guard;
                worker
                    .sender_to_driver(
                        driver_write,
                        rx_driver,
                        conn_close_sub,
                        term_sub,
                        compression_reader_sender_to_driver,
                    )
                    .await;
            });
        }
        // Worker: applies request rules; may forge responses to the driver,
        // pass requests on to the cluster, or drop the connection.
        {
            let guard = Arc::clone(&conn_lifetime);
            let worker = new_worker();
            let request_rules = Arc::clone(self.node.request_rules());
            let conn_close = connection_close_tx.clone();
            let event_flag = Arc::clone(&event_register_flag);
            let tx_driver_clone = tx_driver.clone();
            let tx_cluster_clone = tx_cluster.clone();
            let cc_sender = Arc::clone(&self.cc_event_sender);
            tokio::task::spawn(async move {
                let _conn = guard;
                worker
                    .request_processor(
                        rx_request,
                        tx_driver_clone,
                        tx_cluster_clone,
                        connection_no,
                        request_rules,
                        conn_close,
                        event_flag,
                        compression_writer_request_processor,
                        cc_sender,
                    )
                    .await;
            });
        }
        // The cluster-facing workers exist only for real nodes; a simulated
        // node has no cluster stream (accept_connection passes None for it).
        if let InternalNode::Real {
            ref response_rules, ..
        } = self.node
        {
            // Real nodes are always handed a cluster stream by accept_connection.
            let (cluster_read, cluster_write) = cluster_stream.unwrap().into_split();
            // Worker: request channel -> cluster socket.
            {
                let guard = Arc::clone(&conn_lifetime);
                let worker = new_worker();
                let conn_close_sub = connection_close_tx.subscribe();
                let term_sub = self.terminate_signaler.subscribe();
                tokio::task::spawn(async move {
                    let _conn = guard;
                    worker
                        .sender_to_cluster(
                            cluster_write,
                            rx_cluster,
                            conn_close_sub,
                            term_sub,
                            compression_reader_sender_to_cluster,
                        )
                        .await;
                });
            }
            // Worker: cluster socket -> response channel.
            {
                let guard = Arc::clone(&conn_lifetime);
                let worker = new_worker();
                tokio::task::spawn(async move {
                    let _conn = guard;
                    worker
                        .receiver_from_cluster(
                            cluster_read,
                            tx_response,
                            compression_reader_receiver_from_cluster,
                        )
                        .await;
                });
            }
            // Worker: applies response rules; may forge requests to the
            // cluster, pass responses on to the driver, or drop the connection.
            {
                // Last worker takes ownership of the remaining guard handle.
                let guard = conn_lifetime;
                let worker = new_worker();
                let response_rules = Arc::clone(response_rules);
                let conn_close = connection_close_tx.clone();
                let event_flag = Arc::clone(&event_register_flag);
                tokio::task::spawn(async move {
                    let _conn = guard;
                    worker
                        .response_processor(
                            rx_response,
                            tx_driver,
                            tx_cluster,
                            connection_no,
                            response_rules,
                            conn_close,
                            event_flag,
                        )
                        .await;
                });
            }
        }
        debug!(
            "Doorkeeper with addr {} of node {} spawned workers.",
            self.node.proxy_addr(),
            DisplayableRealAddrOption(self.node.real_addr())
        );
    }
949
950    async fn accept_connection(
951        &mut self,
952        connection_close_tx: &ConnectionCloseSignaler,
953        connection_no: usize,
954    ) -> Result<(), DoorkeeperError> {
955        let (driver_stream, driver_addr) = self.make_driver_stream(connection_no).await?;
956        let (cluster_stream, shard) = match self.node {
957            InternalNode::Real { real_addr, .. } => {
958                let (cluster_stream, shard) =
959                    self.make_cluster_stream(driver_addr, real_addr).await?;
960                (Some(cluster_stream), shard)
961            }
962            InternalNode::Simulated { .. } => (None, None),
963        };
964
965        self.spawn_workers(
966            driver_addr,
967            connection_close_tx,
968            connection_no,
969            driver_stream,
970            cluster_stream,
971            shard,
972        )
973        .await;
974
975        Ok(())
976    }
977
978    async fn make_driver_stream(
979        &mut self,
980        connection_no: usize,
981    ) -> Result<(TcpStream, SocketAddr), DoorkeeperError> {
982        let (driver_stream, driver_addr) =
983            self.listener.accept().await.map_err(|err| {
984                DoorkeeperError::DriverConnectionAttempt(self.node.proxy_addr(), err)
985            })?;
986        info!(
987            "Connected driver from {} to {}, connection no={}.",
988            driver_addr,
989            self.node.proxy_addr(),
990            connection_no
991        );
992        Ok((driver_stream, driver_addr))
993    }
994
995    async fn make_cluster_stream(
996        &mut self,
997        driver_addr: SocketAddr,
998        real_addr: SocketAddr,
999    ) -> Result<(TcpStream, Option<TargetShard>), DoorkeeperError> {
1000        let mut cluster_stream = if let Some(shards) = self.shards_count {
1001            let socket = match self.node.proxy_addr().ip() {
1002                std::net::IpAddr::V4(_) => TcpSocket::new_v4(),
1003                std::net::IpAddr::V6(_) => TcpSocket::new_v6(),
1004            }
1005            .map_err(DoorkeeperError::SocketCreate)?;
1006
1007            let shard_preserving_addr = {
1008                let mut desired_addr =
1009                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), driver_addr.port());
1010                while socket.bind(desired_addr).is_err() {
1011                    // in search for a port that translates to the desired shard
1012                    let next_port = self.next_port_to_same_shard(desired_addr.port());
1013                    if next_port == driver_addr.port() {
1014                        return Err(DoorkeeperError::NoMorePorts);
1015                    }
1016                    desired_addr.set_port(next_port);
1017                }
1018                desired_addr
1019            };
1020
1021            let stream = socket.connect(real_addr).await;
1022            if let Ok(ok) = &stream {
1023                info!(
1024                    "Connected to the cluster from {} at {}, intended shard {}.",
1025                    ok.local_addr().unwrap(),
1026                    real_addr,
1027                    shard_preserving_addr.port() % shards
1028                );
1029            }
1030            stream
1031        } else {
1032            let stream = TcpStream::connect(real_addr).await;
1033            if stream.is_ok() {
1034                info!("Connected to the cluster at {}.", real_addr);
1035            }
1036            stream
1037        }
1038        .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
1039
1040        // If ShardAwareness is aware (QueryNode or FixedNum variants) and the
1041        // proxy succeeded to know the shards count (in FixedNum we get it for
1042        // free, in QueryNode the initial Options query succeeded and Supported
1043        // contained SCYLLA_SHARDS_NUM), then upon opening each connection to the
1044        // node, the proxy issues another Options requests and acknowledges the
1045        // shard it got connected to.
1046        let shard = if self.shards_count.is_some() {
1047            self.obtain_shard_number(real_addr, &mut cluster_stream)
1048                .await?
1049        } else {
1050            None
1051        };
1052
1053        Ok((cluster_stream, shard))
1054    }
1055
1056    fn next_port_to_same_shard(&self, port: u16) -> u16 {
1057        port.wrapping_add(self.shards_count.unwrap())
1058    }
1059
1060    async fn get_supported_options(
1061        connection: &mut TcpStream,
1062    ) -> Result<HashMap<String, Vec<String>>, DoorkeeperError> {
1063        write_frame(
1064            HARDCODED_OPTIONS_PARAMS,
1065            FrameOpcode::Request(RequestOpcode::Options),
1066            &Bytes::new(),
1067            connection,
1068            &no_compression(),
1069        )
1070        .await
1071        .map_err(DoorkeeperError::ObtainingShardNumber)?;
1072
1073        let supported_frame = read_response_frame(connection, &compression::no_compression())
1074            .await
1075            .map_err(DoorkeeperError::ObtainingShardNumberFrame)?;
1076
1077        let options = read_string_multimap(&mut supported_frame.body.as_ref())
1078            .map_err(DoorkeeperError::ObtainingShardNumberParseOptions)?;
1079
1080        Ok(options)
1081    }
1082
1083    async fn obtain_shards_count(&self, real_addr: SocketAddr) -> Result<u16, DoorkeeperError> {
1084        let mut connection = TcpStream::connect(real_addr)
1085            .await
1086            .map_err(|err| DoorkeeperError::NodeConnectionAttempt(real_addr, err))?;
1087        let options = Self::get_supported_options(&mut connection).await?;
1088        let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
1089        let shards = match nr_shards_entry
1090            .and_then(|vec| vec.first())
1091            .ok_or(DoorkeeperError::ObtainingShardNumberNoShardInfo)?
1092            .parse::<u16>()
1093            .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)?
1094        {
1095            0u16 => Err(DoorkeeperError::ObtainingShardNumberGotZero),
1096            num => Ok(num),
1097        }?;
1098        info!("Obtained shards number on node {}: {}", real_addr, shards);
1099        Ok(shards)
1100    }
1101
1102    async fn obtain_shard_number(
1103        &self,
1104        real_addr: SocketAddr,
1105        connection: &mut TcpStream,
1106    ) -> Result<Option<TargetShard>, DoorkeeperError> {
1107        let options = Self::get_supported_options(connection).await?;
1108        let shard_entry = options.get("SCYLLA_SHARD");
1109        let shard = shard_entry
1110            .and_then(|vec| vec.first())
1111            .map(|s| {
1112                s.parse::<u16>()
1113                    .map_err(DoorkeeperError::ObtainingShardNumberParseShardNumber)
1114            })
1115            .transpose()?;
1116        info!("Connected to node {}, shard {:?}", real_addr, shard);
1117        Ok(shard)
1118    }
1119}
1120
1121mod compression {
1122    use std::error::Error;
1123    use std::sync::{Arc, OnceLock};
1124
1125    use bytes::Bytes;
1126    use scylla_cql::frame::frame_errors::{
1127        CqlRequestSerializationError, FrameBodyExtensionsParseError,
1128    };
1129    use scylla_cql::frame::request::{
1130        DeserializableRequest as _, RequestDeserializationError, Startup, options,
1131    };
1132    use scylla_cql::frame::{Compression, compress_append, decompress, flag};
1133    use tracing::{error, warn};
1134
1135    #[derive(Debug, thiserror::Error)]
1136    pub(crate) enum CompressionError {
1137        /// Body Snap compression failed.
1138        #[error("Snap compression error: {0}")]
1139        SnapCompressError(Arc<dyn Error + Sync + Send>),
1140
1141        /// Frame is to be compressed, but no compression was negotiated for the connection.
1142        #[error("Frame is to be compressed, but no compression negotiated for connection.")]
1143        NoCompressionNegotiated,
1144    }
1145
1146    type CompressionInfo = Arc<OnceLock<Option<Compression>>>;
1147
1148    /// The write end of compression config for a connection.
1149    ///
1150    /// Used by the request processor upon STARTUP frame captured
1151    /// and compression setting retrieved from it.
1152    #[derive(Debug, Clone)]
1153    pub(crate) struct CompressionWriter(CompressionInfo);
1154    impl CompressionWriter {
1155        pub(crate) fn set(
1156            &self,
1157            compression: Option<Compression>,
1158        ) -> Result<(), Option<Compression>> {
1159            self.0.set(compression)
1160        }
1161
1162        pub(crate) fn set_from_startup(
1163            &self,
1164            mut body: &[u8],
1165        ) -> Result<Option<Compression>, RequestDeserializationError> {
1166            let startup = Startup::deserialize_with_features(&mut body, &Default::default())?;
1167            let maybe_compression = startup.options.get(options::COMPRESSION);
1168            let maybe_compression = maybe_compression.and_then(|compression| {
1169                compression
1170                    .parse::<Compression>()
1171                    .inspect_err(|err| error!("STARTUP compression error: {}", err))
1172                    .ok()
1173            });
1174            let _ = self.set(maybe_compression).inspect_err(|_| {
1175                warn!("Captured second or further STARTUP frame on the same connection")
1176            });
1177
1178            Ok(maybe_compression)
1179        }
1180    }
1181
1182    /// The read end of compression config for a connection.
1183    ///
1184    /// Used by frame (de)serializers.
1185    #[derive(Debug, Clone)]
1186    pub(crate) struct CompressionReader(CompressionInfo);
1187    impl CompressionReader {
1188        /// Return the compression negotiated for the connection.
1189        ///
1190        /// Outer Option signifies whether the negotiation took place,
1191        /// inner Option is the compression (or lack of it) negotiated.
1192        pub(crate) fn get(&self) -> Option<Option<Compression>> {
1193            self.0.get().copied()
1194        }
1195
1196        pub(crate) fn maybe_compress_body(
1197            &self,
1198            flags: u8,
1199            body: &[u8],
1200        ) -> Result<Option<Bytes>, CompressionError> {
1201            match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
1202                (true, Some(compression)) => {
1203                    let mut buf = Vec::new();
1204                    compress_append(body, compression, &mut buf).map_err(|err| {
1205                        let CqlRequestSerializationError::SnapCompressError(err) = err else {
1206                            unreachable!("BUG: compress_append returned variant different than SnapCompressError")
1207                        };
1208                        CompressionError::SnapCompressError(err)
1209                    })?;
1210                    Ok(Some(Bytes::from(buf)))
1211                }
1212                (true, None) => Err(CompressionError::NoCompressionNegotiated),
1213                (false, _) => Ok(None),
1214            }
1215        }
1216
1217        pub(crate) fn maybe_decompress_body(
1218            &self,
1219            flags: u8,
1220            body: Bytes,
1221        ) -> Result<Bytes, FrameBodyExtensionsParseError> {
1222            match (flags & flag::COMPRESSION != 0, self.get().flatten()) {
1223                (true, Some(compression)) => decompress(&body, compression).map(Into::into),
1224                (true, None) => Err(FrameBodyExtensionsParseError::NoCompressionNegotiated),
1225                (false, _) => Ok(body),
1226            }
1227        }
1228    }
1229
1230    pub(crate) fn make_compression_infra() -> (
1231        CompressionWriter,
1232        CompressionReader,
1233        CompressionReader,
1234        CompressionReader,
1235        CompressionReader,
1236    ) {
1237        let info = Arc::new(OnceLock::new());
1238        (
1239            CompressionWriter(info.clone()),
1240            CompressionReader(info.clone()),
1241            CompressionReader(info.clone()),
1242            CompressionReader(info.clone()),
1243            CompressionReader(info),
1244        )
1245    }
1246
1247    fn mock_compression_reader(compression: Option<Compression>) -> CompressionReader {
1248        CompressionReader(Arc::new({
1249            let once = OnceLock::new();
1250            once.set(compression).unwrap();
1251            once
1252        }))
1253    }
1254
1255    // Compression explicitly turned off.
1256    pub(crate) fn no_compression() -> CompressionReader {
1257        mock_compression_reader(None)
1258    }
1259
1260    // Compression explicitly turned on.
1261    #[cfg(test)] // Currently only used for tests.
1262    pub(crate) fn with_compression(compression: Compression) -> CompressionReader {
1263        mock_compression_reader(Some(compression))
1264    }
1265}
1266pub(crate) use compression::{CompressionReader, CompressionWriter};
1267
/// One of the tasks servicing a single proxied connection (a reader, writer,
/// or processor for one direction of traffic). All workers of one connection
/// share the same signaling endpoints so they start and stop together.
struct ProxyWorker {
    // Fires when the whole proxy is shut down (RunningProxy::finish()).
    terminate_notifier: TerminateNotifier,
    // Dropped on worker exit; the proxy's finish waiter resolves once all guards are gone.
    finish_guard: FinishGuard,
    // Fires when a rule with connection-drop set is applied on this connection.
    connection_close_notifier: ConnectionCloseNotifier,
    // Forwards worker errors to the proxy user.
    error_propagator: ErrorPropagator,
    // Peer address of the driver connected to the proxy's listener.
    driver_addr: SocketAddr,
    // Address of the real node; `None` for simulated nodes.
    real_addr: Option<SocketAddr>,
    // Address the proxy listens on for this node.
    proxy_addr: SocketAddr,
    // Shard this connection targets, when shard-awareness established it.
    shard: Option<TargetShard>,
}
1278
1279impl ProxyWorker {
1280    fn exit(self, duty: &'static str) {
1281        debug!(
1282            "Worker exits: [driver: {}, proxy: {}, node: {}, {}]::{}.",
1283            self.driver_addr,
1284            self.proxy_addr,
1285            DisplayableRealAddrOption(self.real_addr),
1286            DisplayableShard(self.shard),
1287            duty
1288        );
1289        std::mem::drop(self.finish_guard);
1290    }
1291
1292    async fn run_until_interrupted<F, Fut>(mut self, worker_name: &'static str, f: F)
1293    where
1294        F: FnOnce(SocketAddr, SocketAddr, Option<SocketAddr>) -> Fut,
1295        Fut: Future<Output = Result<(), ProxyError>>,
1296    {
1297        let fut = f(self.driver_addr, self.proxy_addr, self.real_addr);
1298
1299        tokio::select! {
1300            result = fut => {
1301                if let Err(err) = result {
1302                    // error_propagator could be a field
1303                    let _ = self.error_propagator.send(err);
1304                }
1305            }
1306            _ = self.terminate_notifier.recv() => (),
1307            _ = self.connection_close_notifier.recv() => (),
1308        }
1309        self.exit(worker_name);
1310    }
1311
1312    async fn receiver_from_driver(
1313        self,
1314        mut read_half: impl AsyncRead + Unpin,
1315        request_processor_tx: mpsc::UnboundedSender<RequestFrame>,
1316        compression: CompressionReader,
1317    ) {
1318        let shard = self.shard;
1319        self.run_until_interrupted(
1320            "receiver_from_driver",
1321            |driver_addr, proxy_addr, _real_addr| async move {
1322                loop {
1323                    let frame = frame::read_request_frame(&mut read_half, &compression)
1324                        .await
1325                        .map_err(|err| {
1326                            warn!("Request reception from {} error: {}", driver_addr, err);
1327                            WorkerError::DriverDisconnected(driver_addr)
1328                        })?;
1329
1330                    debug!(
1331                        "Intercepted Driver ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1332                        driver_addr,
1333                        proxy_addr,
1334                        DisplayableShard(shard),
1335                        &frame.opcode
1336                    );
1337                    if request_processor_tx.send(frame).is_err() {
1338                        warn!("request_processor had exited.");
1339                        return Result::<(), ProxyError>::Ok(());
1340                    }
1341                }
1342            },
1343        )
1344        .await
1345    }
1346
1347    async fn receiver_from_cluster(
1348        self,
1349        mut read_half: impl AsyncRead + Unpin,
1350        response_processor_tx: mpsc::UnboundedSender<ResponseFrame>,
1351        compression: CompressionReader,
1352    ) {
1353        let shard = self.shard;
1354        self.run_until_interrupted(
1355            "receiver_from_cluster",
1356            |driver_addr, _proxy_addr, real_addr| async move {
1357                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1358                loop {
1359                    let frame = frame::read_response_frame(&mut read_half, &compression)
1360                        .await
1361                        .map_err(|err| {
1362                            warn!("Response reception from {} error: {}", real_addr, err);
1363                            WorkerError::NodeDisconnected(real_addr)
1364                        })?;
1365
1366                    debug!(
1367                        "Intercepted Cluster ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1368                        real_addr,
1369                        DisplayableShard(shard),
1370                        driver_addr,
1371                        &frame.opcode
1372                    );
1373
1374                    if response_processor_tx.send(frame).is_err() {
1375                        warn!("response_processor had exited.");
1376                        return Ok::<(), ProxyError>(());
1377                    }
1378                }
1379            },
1380        )
1381        .await;
1382    }
1383
1384    async fn sender_to_driver(
1385        self,
1386        mut write_half: impl AsyncWrite + Unpin,
1387        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
1388        mut connection_close_notifier: ConnectionCloseNotifier,
1389        mut terminate_notifier: TerminateNotifier,
1390        compression: CompressionReader,
1391    ) {
1392        let shard = self.shard;
1393        self.run_until_interrupted(
1394            "sender_to_driver",
1395            |driver_addr, proxy_addr, _real_addr| async move {
1396                loop {
1397                    let response = match responses_rx.recv().await {
1398                        Some(response) => response,
1399                        None => {
1400                            if terminate_notifier.try_recv().is_err()
1401                                && connection_close_notifier.try_recv().is_err()
1402                            {
1403                                warn!("Response processor had exited");
1404                            }
1405                            return Ok(());
1406                        }
1407                    };
1408
1409                    debug!(
1410                        "Sending Proxy ({}) ({}) -> Driver ({}) frame. opcode: {:?}.",
1411                        proxy_addr,
1412                        DisplayableShard(shard),
1413                        driver_addr,
1414                        &response.opcode
1415                    );
1416                    if response.write(&mut write_half, &compression).await.is_err() {
1417                        if terminate_notifier.try_recv().is_err()
1418                            && connection_close_notifier.try_recv().is_err()
1419                        {
1420                            warn!("Driver dropped connection");
1421                            return Err(WorkerError::DriverDisconnected(driver_addr).into());
1422                        }
1423                        return Ok(());
1424                    }
1425                }
1426            },
1427        )
1428        .await;
1429    }
1430
1431    async fn sender_to_cluster(
1432        self,
1433        mut write_half: impl AsyncWrite + Unpin,
1434        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1435        mut connection_close_notifier: ConnectionCloseNotifier,
1436        mut terminate_notifier: TerminateNotifier,
1437        compression: CompressionReader,
1438    ) {
1439        let shard = self.shard;
1440        self.run_until_interrupted(
1441            "sender_to_driver",
1442            |_driver_addr, proxy_addr, real_addr| async move {
1443                let real_addr = real_addr.expect("BUG: no real_addr in cluster worker");
1444                loop {
1445                    let request = match requests_rx.recv().await {
1446                        Some(request) => request,
1447                        None => {
1448                            if terminate_notifier.try_recv().is_err()
1449                                && connection_close_notifier.try_recv().is_err()
1450                            {
1451                                warn!("Request processor had exited");
1452                            }
1453                            return Ok(());
1454                        }
1455                    };
1456
1457                    debug!(
1458                        "Sending Proxy ({}) -> Cluster ({}) ({}) frame. opcode: {:?}.",
1459                        proxy_addr,
1460                        real_addr,
1461                        DisplayableShard(shard),
1462                        &request.opcode
1463                    );
1464
1465                    if request.write(&mut write_half, &compression).await.is_err() {
1466                        if terminate_notifier.try_recv().is_err()
1467                            && connection_close_notifier.try_recv().is_err()
1468                        {
1469                            warn!("Node {} dropped connection", real_addr);
1470                            return Err(WorkerError::NodeDisconnected(real_addr).into());
1471                        }
1472                        return Ok(());
1473                    }
1474                }
1475            },
1476        )
1477        .await;
1478    }
1479
1480    #[expect(clippy::too_many_arguments)]
1481    async fn request_processor(
1482        self,
1483        mut requests_rx: mpsc::UnboundedReceiver<RequestFrame>,
1484        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
1485        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
1486        connection_no: usize,
1487        request_rules: Arc<Mutex<Vec<RequestRule>>>,
1488        connection_close_signaler: ConnectionCloseSignaler,
1489        event_registered_flag: Arc<AtomicBool>,
1490        compression: CompressionWriter,
1491        cc_event_sender: Arc<Mutex<HashMap<usize, mpsc::UnboundedSender<ResponseFrame>>>>,
1492    ) {
1493        let shard = self.shard;
1494        self.run_until_interrupted("request_processor", |driver_addr, _, real_addr| async move {
1495            'mainloop: loop {
1496                match requests_rx.recv().await {
1497                    Some(request) => {
1498                        if request.opcode == RequestOpcode::Register {
1499                            event_registered_flag.store(true, Ordering::Relaxed);
1500                            // Expose this connection's driver-facing sender so
1501                            // that RunningNode::inject_event_to_cc() can push
1502                            // unsolicited EVENT frames to this control connection.
1503                            cc_event_sender.lock().unwrap().insert(connection_no, driver_tx.clone());
1504                            info!(
1505                                "REGISTER seen on connection {} ({} →  {} ({})); registered cc_event_sender",
1506                                connection_no,
1507                                driver_addr,
1508                                DisplayableRealAddrOption(real_addr),
1509                                DisplayableShard(shard),
1510                            );
1511                        } else if request.opcode == RequestOpcode::Startup {
1512                            match compression.set_from_startup(&request.body) {
1513                                Err(err) => error!("Failed to deserialize STARTUP frame: {}", err),
1514                                Ok(read_compression) => info!(
1515                                    "Intercepted STARTUP frame ({} -> {} ({})), so set compression accordingly to {:?}.",
1516                                    driver_addr,
1517                                    DisplayableRealAddrOption(real_addr),
1518                                    DisplayableShard(shard),
1519                                    read_compression
1520                                )
1521                            };
1522                        }
1523
1524                        let ctx = EvaluationContext {
1525                            connection_seq_no: connection_no,
1526                            opcode: FrameOpcode::Request(request.opcode),
1527                            frame_body: request.body.clone(),
1528                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
1529                        };
1530                        let mut guard = request_rules.lock().unwrap();
1531                        '_ruleloop: for (i, request_rule) in guard.iter_mut().enumerate() {
1532                            if request_rule.0.eval(&ctx) {
1533                                debug!("Applying rule no={} to request ({} -> {} ({})).", i, driver_addr, DisplayableRealAddrOption(real_addr), DisplayableShard(shard));
1534                                debug!("-> Applied rule: {:?}", request_rule);
1535                                debug!("-> To request: {:?}", ctx.opcode);
1536                                trace!("{:?}", request);
1537
1538                                if let Some(ref tx) = request_rule.1.feedback_channel {
1539                                    tx.send((request.clone(), shard)).unwrap_or_else(|err|
1540                                        warn!("Could not send received request as feedback: {}", err)
1541                                    );
1542                                }
1543
1544                                let request_rule = request_rule.clone();
1545                                let to_addressee_action = request_rule.1.to_addressee;
1546                                let to_sender_action = request_rule.1.to_sender;
1547                                let drop_connection_action = request_rule.1.drop_connection;
1548
1549                                let cluster_tx_clone = cluster_tx.clone();
1550                                let request_clone = request.clone();
1551                                let pass_action = async move {
1552                                    if let Some(ref pass_action) = to_addressee_action {
1553                                        if let Some(time) = pass_action.delay {
1554                                            tokio::time::sleep(time).await;
1555                                        }
1556                                        let passed_frame = match pass_action.msg_processor {
1557                                            Some(ref processor) => processor(request_clone),
1558                                            None => request_clone,
1559                                        };
1560                                        let _ = cluster_tx_clone.send(passed_frame);
1561                                    };
1562                                };
1563
1564                                let driver_tx_clone = driver_tx.clone();
1565                                let request_clone = request.clone();
1566                                let forge_action = async move {
1567                                    if let Some(ref forge_action) = to_sender_action {
1568                                        if let Some(time) = forge_action.delay {
1569                                            tokio::time::sleep(time).await;
1570                                        }
1571                                        let forged_frame = {
1572                                            let processor = forge_action.msg_processor.as_ref()
1573                                                .expect("Frame processor is required to forge a frame.");
1574                                            processor(request_clone)
1575                                        };
1576                                        let _ = driver_tx_clone.send(forged_frame);
1577                                    };
1578                                };
1579
1580                                let connection_close_signaler_clone =
1581                                    connection_close_signaler.clone();
1582                                let drop_action = async move {
1583                                    if let Some(ref delay) = drop_connection_action {
1584                                        if let Some(time) = delay {
1585                                            tokio::time::sleep(*time).await;
1586                                        }
1587                                        // close connection.
1588                                        info!(
1589                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
1590                                            driver_addr,
1591                                            DisplayableRealAddrOption(real_addr),
1592                                            DisplayableShard(shard),
1593                                        );
1594                                        let _ = connection_close_signaler_clone.send(());
1595                                    }
1596                                };
1597
1598                                tokio::task::spawn(async {
1599                                    futures::join!(pass_action, forge_action, drop_action);
1600                                });
1601
1602                                continue 'mainloop; // only one rule can be applied to one frame
1603                            }
1604                        }
1605                        let _ = cluster_tx.send(request); // default action
1606                    }
1607                    None => {
1608                        // Connection closed. If this was the control
1609                        // connection (REGISTER was seen), remove only this
1610                        // connection's sender from the shared map.
1611                        if event_registered_flag.load(Ordering::Relaxed) {
1612                            cc_event_sender.lock().unwrap().remove(&connection_no);
1613                            info!(
1614                                "Control connection {} ({} →  {} ({})) closed; removed cc_event_sender",
1615                                connection_no,
1616                                driver_addr,
1617                                DisplayableRealAddrOption(real_addr),
1618                                DisplayableShard(shard),
1619                            );
1620                        }
1621                        return Ok(());
1622                    }
1623                }
1624            }
1625        })
1626        .await;
1627    }
1628
    /// Worker task that relays CQL response frames from the cluster towards
    /// the driver, applying at most one matching [`ResponseRule`] per frame.
    ///
    /// Frames arrive on `responses_rx`. For each frame, the rule list is
    /// scanned in order and the FIRST rule whose condition matches is applied;
    /// a matched rule may (in any combination, each with an optional delay):
    /// - pass a possibly-transformed frame on to the driver (`driver_tx`),
    /// - forge a request back to the cluster (`cluster_tx`),
    /// - drop the whole connection via `connection_close_signaler`.
    /// If no rule matches, the frame is forwarded to the driver verbatim.
    /// Runs until `responses_rx` closes or the proxy is interrupted.
    #[expect(clippy::too_many_arguments)]
    async fn response_processor(
        self,
        mut responses_rx: mpsc::UnboundedReceiver<ResponseFrame>,
        driver_tx: mpsc::UnboundedSender<ResponseFrame>,
        cluster_tx: mpsc::UnboundedSender<RequestFrame>,
        connection_no: usize,
        response_rules: Arc<Mutex<Vec<ResponseRule>>>,
        connection_close_signaler: ConnectionCloseSignaler,
        event_registered_flag: Arc<AtomicBool>,
    ) {
        let shard = self.shard;
        self.run_until_interrupted("response_processor", |driver_addr, _, real_addr| async move {
            'mainloop: loop {
                match responses_rx.recv().await {
                    Some(response) => {
                        // Snapshot the data that rule conditions are evaluated against.
                        let ctx = EvaluationContext {
                            connection_seq_no: connection_no,
                            opcode: FrameOpcode::Response(response.opcode),
                            frame_body: response.body.clone(),
                            connection_has_events: event_registered_flag.load(Ordering::Relaxed),
                        };
                        // NOTE: std Mutex guard is held only across this synchronous
                        // scan — no `.await` occurs while it is alive.
                        let mut guard = response_rules.lock().unwrap();
                        '_ruleloop: for (i, response_rule) in guard.iter_mut().enumerate() {
                            if response_rule.0.eval(&ctx) {
                                debug!("Applying rule no={} to response ({} -> {} ({})).", i, DisplayableRealAddrOption(real_addr), driver_addr, DisplayableShard(shard));
                                debug!("-> Applied rule: {:?}", response_rule);
                                debug!("-> To response: {:?}", ctx.opcode);
                                trace!("{:?}", response);

                                // Best-effort: report the intercepted frame to the test code,
                                // if the rule has a feedback channel attached.
                                if let Some(ref tx) = response_rule.1.feedback_channel {
                                    tx.send((response.clone(), shard)).unwrap_or_else(|err| warn!(
                                        "Could not send received response as feedback: {}", err
                                    ));
                                }

                                // Clone the rule so its actions can be moved into the
                                // spawned task after the guard is released.
                                let response_rule = response_rule.clone();
                                let to_addressee_action = response_rule.1.to_addressee;
                                let to_sender_action = response_rule.1.to_sender;
                                let drop_connection_action = response_rule.1.drop_connection;

                                // Action 1: (optionally delayed, optionally transformed)
                                // forward the response to the driver.
                                let response_clone = response.clone();
                                let driver_tx_clone = driver_tx.clone();
                                let pass_action = async move {
                                    if let Some(ref pass_action) = to_addressee_action {
                                        if let Some(time) = pass_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let passed_frame = match pass_action.msg_processor {
                                            Some(ref processor) => processor(response_clone),
                                            None => response_clone,
                                        };
                                        // Send failure means the peer half is gone; ignore.
                                        let _ = driver_tx_clone.send(passed_frame);
                                    };
                                };

                                // Action 2: forge a request back to the cluster.
                                // A processor is mandatory here: a response cannot be
                                // echoed back as a request unmodified.
                                let response_clone = response.clone();
                                let cluster_tx_clone = cluster_tx.clone();
                                let forge_action = async move {
                                    if let Some(ref forge_action) = to_sender_action {
                                        if let Some(time) = forge_action.delay {
                                            tokio::time::sleep(time).await;
                                        }
                                        let forged_frame = {
                                            let processor = forge_action.msg_processor.as_ref()
                                                .expect("Frame processor is required to forge a frame.");
                                            processor(response_clone)
                                        };
                                        let _ = cluster_tx_clone.send(forged_frame);
                                    };
                                };

                                // Action 3: (optionally delayed) drop the connection by
                                // notifying all workers of this connection.
                                let connection_close_signaler_clone =
                                    connection_close_signaler.clone();
                                let drop_action = async move {
                                    if let Some(ref delay) = drop_connection_action {
                                        if let Some(time) = delay {
                                            tokio::time::sleep(*time).await;
                                        }
                                        // close connection.
                                        info!(
                                            "Dropping connection between {} and {} ({}) (as requested by a proxy rule)!",
                                            driver_addr,
                                            real_addr.expect("BUG: response rules are unavailable for dry-mode proxy!"),
                                            DisplayableShard(shard)
                                        );
                                        let _ = connection_close_signaler_clone.send(());
                                    }
                                };

                                // Run the three actions concurrently off the hot path, so
                                // delays do not stall processing of subsequent frames.
                                tokio::task::spawn(async {
                                    futures::join!(pass_action, forge_action, drop_action);
                                });

                                continue 'mainloop;
                            }
                        }
                        let _ = driver_tx.send(response); // default action
                    }
                    // Channel closed: upstream worker finished, so shut down cleanly.
                    None => return Ok(()),
                }
            }
        })
        .await
    }
1734}
1735
1736// Returns next free IP address for another proxy instance.
1737// Useful for concurrent testing.
1738pub fn get_exclusive_local_address() -> IpAddr {
1739    match std::env::var("NEXTEST_TEST_GLOBAL_SLOT") {
1740        Ok(slot) => {
1741            let slot: u16 = slot
1742                .parse()
1743                .unwrap_or_else(|e| panic!("Invalid slot {e:?}"));
1744            get_exclusive_local_address_nextest(slot)
1745        }
1746        Err(VarError::NotPresent) => get_exclusive_local_address_libtest(),
1747        Err(VarError::NotUnicode(e)) => panic!("Invalid slot {e:?}"),
1748    }
1749}
1750
// Hands out successive loopback addresses 127.x.y.z for plain-libtest runs.
fn get_exclusive_local_address_libtest() -> IpAddr {
    // A big enough number reduces possibility of clashes with user-taken addresses:
    static ADDRESS_LOWER_THREE_OCTETS: AtomicU32 = AtomicU32::new(4242);
    let addr_index = ADDRESS_LOWER_THREE_OCTETS.fetch_add(1, Ordering::Relaxed);
    // Only the lower 24 bits fit into the three variable octets.
    if addr_index > (u32::MAX >> 8) {
        panic!("Loopback address pool for tests depleted");
    }
    // Big-endian byte order maps the counter onto octets most-significant first.
    let [_, second, third, fourth] = addr_index.to_be_bytes();
    IpAddr::V4(Ipv4Addr::new(127, second, third, fourth))
}
1766
// Hands out loopback addresses for cargo-nextest runs: the test slot selects
// the 127.<hi>.<lo> range, and a per-process counter selects the last octet.
fn get_exclusive_local_address_nextest(slot: u16) -> IpAddr {
    static ADDRESS_LOWER_OCTET: AtomicU8 = AtomicU8::new(255);
    // This is a heuristic to avoid using low addresses, which I think have
    // a higher chance of being taken.
    const FREE_RANGES: u16 = 16;
    let lower_octet = ADDRESS_LOWER_OCTET.fetch_sub(1, Ordering::Relaxed);
    if lower_octet == 0 {
        panic!("Loopback address pool for this test depleted");
    }

    let range = match slot.checked_add(FREE_RANGES) {
        Some(range) => range,
        None => panic!("Loopback address pool for tests depleted"),
    };
    // Big-endian split: high byte becomes the second octet, low byte the third.
    let [range_hi, range_lo] = range.to_be_bytes();

    IpAddr::V4(Ipv4Addr::new(127, range_hi, range_lo, lower_octet))
}
1789
1790#[cfg(test)]
1791mod tests {
1792    use super::compression::no_compression;
1793    use super::*;
1794    use crate::errors::ReadFrameError;
1795    use crate::frame::{FrameType, read_frame, read_request_frame, read_response_frame};
1796    use crate::proxy::compression::with_compression;
1797    use crate::{
1798        Condition, Reaction as _, RequestReaction, ResponseOpcode, ResponseReaction, setup_tracing,
1799    };
1800    use assert_matches::assert_matches;
1801    use bytes::{BufMut, BytesMut};
1802    use futures::future::{join, join3};
1803    use rand::RngCore;
1804    use scylla_cql::frame::request::options;
1805    use scylla_cql::frame::request::{SerializableRequest as _, Startup};
1806    use scylla_cql::frame::types::write_string_multimap;
1807    use scylla_cql::frame::{Compression, flag};
1808    use std::collections::HashMap;
1809    use std::mem;
1810    use std::str::FromStr;
1811    use std::time::Duration;
1812    use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
1813    use tokio::sync::oneshot;
1814
1815    fn random_body() -> Bytes {
1816        let body_len = (rand::random::<u32>() % 1000) as usize;
1817        let mut body = BytesMut::zeroed(body_len);
1818        rand::rng().fill_bytes(body.as_mut());
1819        body.freeze()
1820    }
1821
1822    async fn respond_with_supported(
1823        conn: &mut TcpStream,
1824        supported_options: &HashMap<String, Vec<String>>,
1825        compression: &CompressionReader,
1826    ) {
1827        let RequestFrame {
1828            params: recvd_params,
1829            opcode: recvd_opcode,
1830            body: recvd_body,
1831        } = read_request_frame(conn, compression).await.unwrap();
1832        assert_eq!(recvd_params, HARDCODED_OPTIONS_PARAMS);
1833        assert_eq!(recvd_opcode, RequestOpcode::Options);
1834        assert_eq!(recvd_body, Bytes::new()); // body should be empty
1835
1836        let mut body = BytesMut::new();
1837        write_string_multimap(supported_options, &mut body).unwrap();
1838
1839        let body = body.freeze();
1840
1841        write_frame(
1842            HARDCODED_OPTIONS_PARAMS.for_response(),
1843            FrameOpcode::Response(ResponseOpcode::Supported),
1844            &body,
1845            conn,
1846            &no_compression(),
1847        )
1848        .await
1849        .unwrap();
1850    }
1851
1852    fn supported_shards_count(shards_count: u16) -> HashMap<String, Vec<String>> {
1853        let mut sharded_info = HashMap::new();
1854        sharded_info.insert(
1855            String::from("SCYLLA_NR_SHARDS"),
1856            vec![shards_count.to_string()],
1857        );
1858        sharded_info
1859    }
1860
1861    fn supported_shard_number(shard_num: TargetShard) -> HashMap<String, Vec<String>> {
1862        let mut sharded_info = HashMap::new();
1863        sharded_info.insert(String::from("SCYLLA_SHARD"), vec![shard_num.to_string()]);
1864        sharded_info
1865    }
1866
1867    async fn respond_with_shards_count(
1868        conn: &mut TcpStream,
1869        shards_count: u16,
1870        compression: &CompressionReader,
1871    ) {
1872        respond_with_supported(conn, &supported_shards_count(shards_count), compression).await;
1873    }
1874
1875    async fn respond_with_shard_num(
1876        conn: &mut TcpStream,
1877        shard_num: TargetShard,
1878        compression: &CompressionReader,
1879    ) {
1880        respond_with_supported(conn, &supported_shard_number(shard_num), compression).await;
1881    }
1882
1883    fn next_local_address_with_port(port: u16) -> SocketAddr {
1884        SocketAddr::new(get_exclusive_local_address(), port)
1885    }
1886
    /// Shared test body: starts a proxy with no rules between a mock driver and
    /// a mock node, sends one OPTIONS frame through it, and asserts the frame
    /// arrives at the node bit-for-bit identical (params, opcode, body).
    ///
    /// `shard_awareness` selects how the proxy learns the shard layout; for the
    /// shard-aware variants the mock node must additionally answer the proxy's
    /// own OPTIONS probes before the driver's frame can flow.
    async fn identity_proxy_does_not_mutate_frames(shard_awareness: ShardAwareness) {
        // Same port, but distinct exclusive loopback IPs for node and proxy.
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            shard_awareness,
            None,
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        // Driver side: connect through the proxy and write one frame.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params, opcode, &body, &mut conn, &no_compression())
                .await
                .unwrap();
            conn
        };

        // Node side: serve the proxy's shard-discovery handshake (if any),
        // then read the relayed frame and verify it is unchanged.
        let mock_node_action = async {
            if let ShardAwareness::QueryNode = shard_awareness {
                // First connection is the proxy's own shards-count probe.
                respond_with_shards_count(
                    &mut mock_node_listener.accept().await.unwrap().0,
                    1,
                    &no_compression(),
                )
                .await;
            }
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            if shard_awareness.is_aware() {
                respond_with_shard_num(&mut conn, 1, &no_compression()).await;
            }
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
            conn
        };

        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
        running_proxy.finish().await.unwrap();
    }
1949
1950    #[tokio::test]
1951    async fn identity_shard_unaware_proxy_does_not_mutate_frames() {
1952        setup_tracing();
1953        identity_proxy_does_not_mutate_frames(ShardAwareness::Unaware).await
1954    }
1955
1956    #[tokio::test]
1957    async fn identity_shard_aware_proxy_does_not_mutate_frames() {
1958        setup_tracing();
1959        identity_proxy_does_not_mutate_frames(ShardAwareness::QueryNode).await
1960    }
1961
    /// Verifies that a shard-aware proxy (with a fixed shard count) preserves
    /// the shard-selection property of the driver's source port: the proxy's
    /// outgoing connection to the node must land on the same shard
    /// (port % shards_num) as the driver's connection to the proxy.
    #[tokio::test]
    async fn shard_aware_proxy_is_transparent_for_connection_to_shards() {
        setup_tracing();
        async fn test_for_shards_num(shards_num: u16) {
            let node1_real_addr = next_local_address_with_port(9876);
            let node1_proxy_addr = next_local_address_with_port(9876);
            let proxy = Proxy::new([Node::new(
                node1_real_addr,
                node1_proxy_addr,
                ShardAwareness::FixedNum(shards_num),
                None,
                None,
            )]);
            let running_proxy = proxy.run().await.unwrap();

            let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

            // Used to pass the driver's ephemeral local address to the
            // node-side assertion below.
            let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

            let send_frame_to_shard = async {
                let socket = TcpSocket::new_v4().unwrap();
                // Bind to an OS-chosen ephemeral port.
                socket
                    .bind(SocketAddr::from_str("0.0.0.0:0").unwrap())
                    .unwrap();
                let conn = socket.connect(node1_proxy_addr).await.unwrap();
                driver_addr_tx.send(conn.local_addr().unwrap()).unwrap();
                conn
            };

            let mock_node_action = async {
                let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                let driver_addr = driver_addr_rx.await.unwrap();
                // Scylla's shard-aware port scheme: shard = source port % shard count.
                assert_eq!(
                    driver_addr.port() % shards_num,
                    remote_addr.port() % shards_num
                );
                conn
            };

            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
            let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
            running_proxy.finish().await.unwrap();
        }

        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
2010
    /// Verifies that a `ShardAwareness::QueryNode` proxy first queries the
    /// node for its shard count (via OPTIONS/SUPPORTED), then honours the
    /// driver's shard choice: the proxy's connection to the node must land on
    /// the same shard (port % shards_num) as the driver's fixed source port.
    #[tokio::test]
    async fn shard_aware_proxy_queries_shards_number() {
        setup_tracing();
        async fn test_for_shards_num(shards_num: u16) {
            // Exercise every shard index for the given shard count.
            for shard_num in 0..shards_num {
                let node1_real_addr = next_local_address_with_port(9876);
                let node1_proxy_addr = next_local_address_with_port(9876);
                let proxy = Proxy::new([Node::new(
                    node1_real_addr,
                    node1_proxy_addr,
                    ShardAwareness::QueryNode,
                    None,
                    None,
                )]);
                let running_proxy = proxy.run().await.unwrap();

                let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

                let (driver_addr_tx, driver_addr_rx) = oneshot::channel::<SocketAddr>();

                // Fixed (non-ephemeral) driver port chosen so that
                // port % shards_num == shard_num.
                let mock_driver_addr = next_local_address_with_port(shards_num * 1234 + shard_num);
                let send_frame_to_shard = async {
                    let socket = TcpSocket::new_v4().unwrap();
                    socket
                        .bind(mock_driver_addr)
                        .unwrap_or_else(|_| panic!("driver_addr failed: {mock_driver_addr}"));
                    driver_addr_tx.send(socket.local_addr().unwrap()).unwrap();
                    socket.connect(node1_proxy_addr).await.unwrap()
                };

                let mock_node_action = async {
                    // First connection is the proxy's shards-count probe.
                    respond_with_shards_count(
                        &mut mock_node_listener.accept().await.unwrap().0,
                        shards_num,
                        &no_compression(),
                    )
                    .await;
                    // Second connection is the relayed driver connection.
                    let (conn, remote_addr) = mock_node_listener.accept().await.unwrap();
                    let driver_addr = driver_addr_rx.await.unwrap();
                    assert_eq!(
                        driver_addr.port() % shards_num,
                        remote_addr.port() % shards_num
                    );
                    conn
                };

                let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
                running_proxy.finish().await.unwrap();
            }
        }

        for shard_num in 1..6 {
            test_for_shards_num(shard_num).await;
        }
    }
2066
    /// Verifies request-rule matching and response forging end-to-end:
    /// - REGISTER requests get a forged EVENT response (rule 1),
    /// - requests whose body contains the magic marker pass through untouched
    ///   (rule 2),
    /// - everything else gets a forged READY response (rule 3, catch-all).
    /// Three frames are sent; the driver must receive the two forged
    /// responses, and the node must receive only the passed-through frame.
    #[tokio::test]
    async fn forger_proxy_forges_response() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        // Body marker that rule 2 matches on to let a frame through.
        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                // Rule 1: REGISTER -> forged EVENT response with a fixed body.
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                // Rule 2: frames carrying the marker pass through unchanged.
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        // Frame 1: STARTUP — no rule-1/2 match, so rule 3 forges READY.
        let params1 = FrameParams {
            flags: 2,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        // Frame 2: REGISTER — rule 1 forges an EVENT.
        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        // Frame 3: EXECUTE with the marker in its body — rule 2 passes it on.
        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        // Driver side: send all three frames, then expect the two forged
        // responses (READY for frame 1, EVENT for frame 2).
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
                .await
                .unwrap();
            write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
                .await
                .unwrap();

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params1.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Ready);
            assert_eq!(recvd_body, Bytes::new());

            let ResponseFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_response_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params2.for_response());
            assert_eq!(recvd_opcode, ResponseOpcode::Event);
            assert_eq!(recvd_body, Bytes::from_static(test_msg));

            conn
        };

        // Node side: only the passed-through EXECUTE frame must arrive.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params3);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode3);
            assert_eq!(recvd_body, body3);

            conn
        };

        let (mut node_conn, mut driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        running_proxy.finish().await.unwrap();

        // After the proxy finishes, both connections must observe clean EOF.
        assert_matches!(driver_conn.read(&mut [0u8; 1]).await, Ok(0));
        assert_matches!(node_conn.read(&mut [0u8; 1]).await, Ok(0));
    }
2203
    // Verifies that request rules can be installed on an already-running proxy
    // (via `change_request_rules`) and later removed (via `turn_off_rules`),
    // and that the change takes effect for subsequent frames on an existing connection.
    #[tokio::test]
    async fn ad_hoc_rules_changing() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        // The proxy starts with no rules at all.
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            None,
            None,
        )]);
        let mut running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);

        let body = random_body();

        // Establish one driver->proxy connection and the matching proxy->node connection.
        let (mut driver, mut node) = {
            let results = join(
                TcpStream::connect(node1_proxy_addr),
                mock_node_listener.accept(),
            )
            .await;
            (results.0.unwrap(), results.1.unwrap().0)
        };

        // Writes a frame on the driver side while concurrently reading on the
        // node side; returns the frame the mock node received (if any).
        async fn request(
            driver: &mut TcpStream,
            node: &mut TcpStream,
            params: FrameParams,
            opcode: FrameOpcode,
            body: &Bytes,
        ) -> Result<RequestFrame, ReadFrameError> {
            let (send_res, recv_res) = join(
                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
                read_request_frame(node, &no_compression()),
            )
            .await;
            send_res.unwrap();
            recv_res
        }
        {
            // one run still without custom rules
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }
        // Install an unconditional drop rule on the fly.
        running_proxy.running_nodes[0].change_request_rules(Some(vec![RequestRule(
            Condition::True,
            RequestReaction::drop_frame(),
        )]));

        {
            // one run with custom rules
            // The frame must now be dropped: if the mock node receives it
            // within 20 ms, the newly installed rule did not take effect.
            tokio::select! {
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(20)) => (),
            };
        }

        // Remove all custom rules again.
        running_proxy.turn_off_rules();

        {
            // one run already without custom rules
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }

        running_proxy.finish().await.unwrap();
    }
2297
    // Exercises `Condition::TrueForLimitedTimes` combined with `Condition::not`.
    // Only the first matching rule in the list is applied, so the three rules
    // below partition the request sequence into drop / pass / drop phases.
    #[tokio::test]
    async fn limited_times_condition_expires() {
        setup_tracing();
        const FAILING_TRIES: usize = 4;
        const PASSING_TRIES: usize = 5;

        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![
                RequestRule(
                    // this will be always fired after first PASSING_TRIES + FAILING_TRIES
                    Condition::not(Condition::TrueForLimitedTimes(
                        FAILING_TRIES + PASSING_TRIES,
                    )),
                    RequestReaction::drop_frame(),
                ),
                RequestRule(
                    // this will be fired for PASSING_TRIES after first FAILING_TRIES
                    Condition::not(Condition::TrueForLimitedTimes(FAILING_TRIES)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    // this will be fired for first FAILING_TRIES
                    Condition::True,
                    RequestReaction::drop_frame(),
                ),
            ]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode = FrameOpcode::Request(RequestOpcode::Options);
        let body = random_body();

        // One driver->proxy connection and the matching proxy->node connection.
        let (mut driver, mut node) = {
            let results = join(
                TcpStream::connect(node1_proxy_addr),
                mock_node_listener.accept(),
            )
            .await;
            (results.0.unwrap(), results.1.unwrap().0)
        };

        // Writes a frame on the driver side while concurrently reading on the
        // node side; returns the frame the mock node received (if any).
        async fn request(
            driver: &mut TcpStream,
            node: &mut TcpStream,
            params: FrameParams,
            opcode: FrameOpcode,
            body: &Bytes,
        ) -> Result<RequestFrame, ReadFrameError> {
            let (send_res, recv_res) = join(
                write_frame(params, opcode, &body.clone(), driver, &no_compression()),
                read_request_frame(node, &no_compression()),
            )
            .await;
            send_res.unwrap();
            recv_res
        }

        // Phase 1: the catch-all drop rule fires, so nothing reaches the node.
        for _ in 0..FAILING_TRIES {
            tokio::select! {
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
            };
        }

        // Phase 2: the noop rule now matches first, so frames pass through unmodified.
        for _ in 0..PASSING_TRIES {
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = request(&mut driver, &mut node, params, opcode, &body)
                .await
                .unwrap();
            assert_eq!(recvd_params, params);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode);
            assert_eq!(recvd_body, body);
        }

        // Phase 3: the first (expired-counter) drop rule matches from now on.
        for _ in 0..3 {
            // any further number of requests should fail
            tokio::select! {
                res = request(&mut driver, &mut node, params, opcode, &body) => panic!("Rules did not work: received response {:?}", res),
                _ = tokio::time::sleep(std::time::Duration::from_millis(10)) => (),
            };
        }

        running_proxy.finish().await.unwrap();
    }
2398
2399    #[tokio::test]
2400    async fn proxy_reports_requests_and_responses_as_feedback() {
2401        setup_tracing();
2402        let node1_real_addr = next_local_address_with_port(9876);
2403        let node1_proxy_addr = next_local_address_with_port(9876);
2404
2405        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2406        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2407        let proxy = Proxy::new([Node::new(
2408            node1_real_addr,
2409            node1_proxy_addr,
2410            ShardAwareness::Unaware,
2411            Some(vec![RequestRule(
2412                Condition::True,
2413                RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
2414            )]),
2415            Some(vec![ResponseRule(
2416                Condition::True,
2417                ResponseReaction::drop_frame().with_feedback_when_performed(response_feedback_tx),
2418            )]),
2419        )]);
2420        let running_proxy = proxy.run().await.unwrap();
2421
2422        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2423
2424        let params = FrameParams {
2425            flags: 0,
2426            version: 0x04,
2427            stream: 0,
2428        };
2429        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
2430        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);
2431
2432        let body = random_body();
2433
2434        let send_frame_to_shard = async {
2435            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2436            write_frame(params, request_opcode, &body, &mut conn, &no_compression())
2437                .await
2438                .unwrap();
2439            conn
2440        };
2441
2442        let mock_node_action = async {
2443            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
2444            write_frame(
2445                params.for_response(),
2446                response_opcode,
2447                &body,
2448                &mut conn,
2449                &no_compression(),
2450            )
2451            .await
2452            .unwrap();
2453            conn
2454        };
2455
2456        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
2457        let (_node_conn, _driver_conn) = join(mock_node_action, send_frame_to_shard).await;
2458
2459        let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
2460        assert_eq!(feedback_request.params, params);
2461        assert_eq!(
2462            FrameOpcode::Request(feedback_request.opcode),
2463            request_opcode
2464        );
2465        assert_eq!(feedback_request.body, body);
2466        let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
2467        assert_eq!(feedback_response.params, params.for_response());
2468        assert_eq!(
2469            FrameOpcode::Response(feedback_response.opcode),
2470            response_opcode
2471        );
2472        assert_eq!(feedback_response.body, body);
2473
2474        running_proxy.finish().await.unwrap();
2475    }
2476
    // Checks that worker errors (driver/node disconnects) are propagated through
    // `wait_for_error()` and that `sanity_check()` keeps passing along the way.
    #[tokio::test]
    async fn sanity_check_reports_errors() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);
        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            None,
            None,
        )]);
        let mut running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            // Raw bytes, not a valid frame — the proxy must not crash on them.
            conn.write_all(b"uselessJunk").await.unwrap();
            conn
        };

        let mock_node_action = async {
            let (conn, _) = mock_node_listener.accept().await.unwrap();
            conn
        };

        let (node_conn, driver_conn) = join(mock_node_action, send_frame_to_shard).await;

        // Nothing has gone wrong yet.
        running_proxy.sanity_check().unwrap();

        // Closing the driver side must surface as a DriverDisconnected worker error.
        mem::drop(driver_conn);
        assert_matches!(
            running_proxy.wait_for_error().await,
            Some(ProxyError::Worker(WorkerError::DriverDisconnected(_)))
        );
        // NOTE(review): sanity_check() still passes after the error was consumed
        // via wait_for_error() — disconnects are apparently not fatal to the proxy.
        running_proxy.sanity_check().unwrap();

        // Closing the node side must surface as a NodeDisconnected worker error.
        mem::drop(node_conn);
        assert_matches!(
            running_proxy.wait_for_error().await,
            Some(ProxyError::Worker(WorkerError::NodeDisconnected(_)))
        );
        running_proxy.sanity_check().unwrap();

        // Both connections were already dropped above, so finish() may report
        // leftover disconnect errors — they are deliberately ignored here.
        let _ = running_proxy.finish().await;
    }
2526
    // The first frame is delayed by `delay`, yet the whole exchange is wrapped in
    // a timeout of exactly `delay`. The test can only pass if the second frame
    // overtakes the delayed first one, i.e. the proxy processes frames concurrently.
    #[tokio::test]
    async fn proxy_processes_requests_concurrently() {
        setup_tracing();
        let node1_real_addr = next_local_address_with_port(9876);
        let node1_proxy_addr = next_local_address_with_port(9876);

        let delay = Duration::from_millis(60);

        let proxy = Proxy::new([Node::new(
            node1_real_addr,
            node1_proxy_addr,
            ShardAwareness::Unaware,
            Some(vec![RequestRule(
                // Applies to exactly one (the first) frame.
                Condition::TrueForLimitedTimes(1),
                RequestReaction::delay(delay),
            )]),
            None,
        )]);
        let running_proxy = proxy.run().await.unwrap();

        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();

        let params1 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Options);

        let body1 = random_body();

        let params2 = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        let body2 = random_body();

        // Driver side: send both frames back-to-back on one connection.
        let send_frame_to_shard = async {
            let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

            write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
                .await
                .unwrap();
            write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
                .await
                .unwrap();
            conn
        };

        // Node side: the FIRST frame to arrive must be frame 2, because frame 1
        // is being held back by the delay rule.
        let mock_node_action = async {
            let (mut conn, _) = mock_node_listener.accept().await.unwrap();
            let RequestFrame {
                params: recvd_params,
                opcode: recvd_opcode,
                body: recvd_body,
            } = read_request_frame(&mut conn, &no_compression())
                .await
                .unwrap();
            assert_eq!(recvd_params, params2);
            assert_eq!(FrameOpcode::Request(recvd_opcode), opcode2);
            assert_eq!(recvd_body, body2);
            conn
        };

        // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
        let (_node_conn, _driver_conn) =
            tokio::time::timeout(delay, join(mock_node_action, send_frame_to_shard))
                .await
                .expect("Request processing was not concurrent");
        running_proxy.finish().await.unwrap();
    }
2601
2602    #[tokio::test]
2603    async fn dry_mode_proxy_drops_incoming_frames() {
2604        setup_tracing();
2605        let node1_proxy_addr = next_local_address_with_port(9876);
2606        let proxy = Proxy::new([Node::new_dry_mode(node1_proxy_addr, None)]);
2607        let running_proxy = proxy.run().await.unwrap();
2608
2609        let params = FrameParams {
2610            flags: 0,
2611            version: 0x04,
2612            stream: 0,
2613        };
2614        let opcode = FrameOpcode::Request(RequestOpcode::Options);
2615
2616        let body = random_body();
2617
2618        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();
2619
2620        write_frame(params, opcode, &body, &mut conn, &no_compression())
2621            .await
2622            .unwrap();
2623        // We assert that after sufficiently long time, no error happens inside proxy.
2624        tokio::time::sleep(Duration::from_millis(3)).await;
2625        running_proxy.finish().await.unwrap();
2626    }
2627
    // In dry mode there is no real node behind the proxy, so the only responses
    // a driver can ever see are the ones forged by request rules. Three requests
    // are sent; the forged responses must come back in request order.
    #[tokio::test]
    async fn dry_mode_forger_proxy_forges_response() {
        setup_tracing();
        let node1_proxy_addr = next_local_address_with_port(9876);

        let this_shall_pass = b"This.Shall.Pass.";
        let test_msg = b"Test";

        let proxy = Proxy::new([Node::new_dry_mode(
            node1_proxy_addr,
            Some(vec![
                // REGISTER requests get a forged EVENT response carrying `test_msg`.
                RequestRule(
                    Condition::RequestOpcode(RequestOpcode::Register),
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Event,
                            body: Bytes::from_static(test_msg),
                        }
                    })),
                ),
                // Requests whose body contains the marker are noop'ed — and with a
                // dry-mode node (no backend) that means no response at all.
                RequestRule(
                    Condition::BodyContainsCaseSensitive(Box::new(*this_shall_pass)),
                    RequestReaction::noop(),
                ),
                RequestRule(
                    Condition::True, // only the first matching rule is applied, so "True" covers all remaining cases
                    RequestReaction::forge_response(Arc::new(|RequestFrame { params, .. }| {
                        ResponseFrame {
                            params: params.for_response(),
                            opcode: ResponseOpcode::Ready,
                            body: Bytes::new(),
                        }
                    })),
                ),
            ]),
        )]);
        let running_proxy = proxy.run().await.unwrap();

        // Deliberately non-default params to verify they are echoed back
        // through `for_response()` in the forged frames.
        let params1 = FrameParams {
            flags: 2,
            version: 0x42,
            stream: 42,
        };
        let opcode1 = FrameOpcode::Request(RequestOpcode::Startup);

        let params2 = FrameParams {
            flags: 4,
            version: 0x04,
            stream: 17,
        };
        let opcode2 = FrameOpcode::Request(RequestOpcode::Register);

        let params3 = FrameParams {
            flags: 8,
            version: 0x04,
            stream: 11,
        };
        let opcode3 = FrameOpcode::Request(RequestOpcode::Execute);

        let body1 = random_body();
        let body2 = random_body();
        // Body 3 embeds the marker so it hits the noop rule.
        let body3 = {
            let mut body = BytesMut::new();
            body.put(&b"uSeLeSs JuNk"[..]);
            body.put(&this_shall_pass[..]);
            body.freeze()
        };

        let mut conn = TcpStream::connect(node1_proxy_addr).await.unwrap();

        write_frame(params1, opcode1, &body1, &mut conn, &no_compression())
            .await
            .unwrap();
        write_frame(params2, opcode2, &body2, &mut conn, &no_compression())
            .await
            .unwrap();
        write_frame(params3, opcode3, &body3, &mut conn, &no_compression())
            .await
            .unwrap();

        // Frame 1 (STARTUP) matched the catch-all rule -> forged READY.
        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn, &no_compression())
            .await
            .unwrap();
        assert_eq!(recvd_params, params1.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Ready);
        assert_eq!(recvd_body, Bytes::new());

        // Frame 2 (REGISTER) matched the first rule -> forged EVENT with `test_msg`.
        // Frame 3 matched the noop rule, so no third response is expected.
        let ResponseFrame {
            params: recvd_params,
            opcode: recvd_opcode,
            body: recvd_body,
        } = read_response_frame(&mut conn, &no_compression())
            .await
            .unwrap();
        assert_eq!(recvd_params, params2.for_response());
        assert_eq!(recvd_opcode, ResponseOpcode::Event);
        assert_eq!(recvd_body, Bytes::from_static(test_msg));

        running_proxy.finish().await.unwrap();

        // After the proxy finished, the driver-side connection sees a clean EOF.
        assert_matches!(conn.read(&mut [0u8; 1]).await, Ok(0));
    }
2735
    // The test asserts that once a (mock) driver connects to the proxy from some port,
    // the proxy will connect to a shard corresponding to that port and that the target
    // shard number will be sent through the feedback channel.
    #[tokio::test]
    async fn proxy_reports_target_shard_as_feedback() {
        setup_tracing();

        let node_port = 10101;
        let node_real_addr = next_local_address_with_port(node_port);
        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();

        let params = FrameParams {
            flags: 0,
            version: 0x04,
            stream: 0,
        };
        let request_opcode = FrameOpcode::Request(RequestOpcode::Options);
        let response_opcode = FrameOpcode::Response(ResponseOpcode::Ready);

        let body = random_body();

        for shards_count in 2..9 {
            // Two driver connections are simulated, each to a different shard.
            let driver1_shard = shards_count - 1;
            let driver2_shard = shards_count - 2;
            let node_proxy_addr = next_local_address_with_port(node_port);

            let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
            let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();

            // Drop rules in both directions, with feedback channels attached,
            // so every frame is reported together with the target shard.
            let proxy = Proxy::new([Node::new(
                node_real_addr,
                node_proxy_addr,
                ShardAwareness::FixedNum(shards_count),
                Some(vec![RequestRule(
                    Condition::True,
                    RequestReaction::drop_frame().with_feedback_when_performed(request_feedback_tx),
                )]),
                Some(vec![ResponseRule(
                    Condition::True,
                    ResponseReaction::drop_frame()
                        .with_feedback_when_performed(response_feedback_tx),
                )]),
            )]);
            let running_proxy = proxy.run().await.unwrap();

            /// Choose a source port `p` such that `shard == shard_of_source_port(p)`.
            // The shard of a source port is its value modulo `shards_count` (see the
            // `% shards_count` below), so this picks the smallest port >= 49152
            // (start of the ephemeral range) congruent to `shard`.
            fn draw_source_port_for_shard(shards_count: u16, shard: u16) -> u16 {
                assert!(shard < shards_count);
                49152u16.div_ceil(shards_count) * shards_count + shard
            }

            // Binds a socket to a local port mapping to `shard`, stepping by
            // `shards_count` (which preserves the shard) until a free port is found.
            async fn bind_socket_for_shard(shards_count: u16, shard: u16) -> TcpSocket {
                let socket = TcpSocket::new_v4().unwrap();
                let initial_port = draw_source_port_for_shard(shards_count, shard);

                let mut desired_addr =
                    SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), initial_port);
                while socket.bind(desired_addr).is_err() {
                    // in search for a port that translates to the desired shard
                    let next_port = desired_addr.port().wrapping_add(shards_count);
                    if next_port == initial_port {
                        panic!("No more ports left");
                    }
                    desired_addr.set_port(next_port);
                }

                socket
            }

            let body_ref = &body;
            // Connects from a source port mapping to `driver_shard` and sends one request.
            let send_frame_to_shard = |driver_shard: u16| async move {
                let socket = bind_socket_for_shard(shards_count, driver_shard).await;
                let mut conn = socket.connect(node_proxy_addr).await.unwrap();

                write_frame(
                    params,
                    request_opcode,
                    body_ref,
                    &mut conn,
                    &no_compression(),
                )
                .await
                .unwrap();
                conn
            };

            let mock_driver1_action = send_frame_to_shard(driver1_shard);
            let mock_driver2_action = send_frame_to_shard(driver2_shard);

            // Accepts two connections and sends a response to each of them.
            let mock_node_action = async {
                let mut conns_futs = (0..2)
                    .map(|_| async {
                        let (mut conn, driver_addr) = mock_node_listener.accept().await.unwrap();
                        // Announce the shard implied by the proxy's source port,
                        // as a real shard-aware Scylla node would.
                        respond_with_shard_num(
                            &mut conn,
                            driver_addr.port() % shards_count,
                            &no_compression(),
                        )
                        .await;
                        write_frame(
                            params.for_response(),
                            response_opcode,
                            body_ref,
                            &mut conn,
                            &no_compression(),
                        )
                        .await
                        .unwrap();
                        conn
                    })
                    .collect::<Vec<_>>();
                let conn2 = conns_futs.pop().unwrap().await;
                let conn1 = conns_futs.pop().unwrap().await;
                (conn1, conn2)
            };

            // we keep the connections open until proxy finishes to let it perform clean exit with no disconnects
            let (_node_conns, _driver1_conn, _driver2_conn) =
                join3(mock_node_action, mock_driver1_action, mock_driver2_action).await;

            let assert_feedback_request = |feedback_request: RequestFrame| {
                assert_eq!(feedback_request.params, params);
                assert_eq!(
                    FrameOpcode::Request(feedback_request.opcode),
                    request_opcode
                );
                assert_eq!(feedback_request.body, body);
            };

            let assert_feedback_response = |feedback_response: ResponseFrame| {
                assert_eq!(feedback_response.params, params.for_response());
                assert_eq!(
                    FrameOpcode::Response(feedback_response.opcode),
                    response_opcode
                );
                assert_eq!(feedback_response.body, body);
            };

            let (feedback_request, shard1) = request_feedback_rx.recv().await.unwrap();
            assert_feedback_request(feedback_request);
            let (feedback_request, shard2) = request_feedback_rx.recv().await.unwrap();
            assert_feedback_request(feedback_request);
            let (feedback_response, shard3) = response_feedback_rx.recv().await.unwrap();
            assert_feedback_response(feedback_response);
            let (feedback_response, shard4) = response_feedback_rx.recv().await.unwrap();
            assert_feedback_response(feedback_response);

            // expected: {driver1_shard request, driver1_shard response, driver2_shard request, driver2_shard response}
            // The two connections race, so compare shard multisets via sorting
            // rather than relying on arrival order.
            let mut expected_shards = [driver1_shard, driver1_shard, driver2_shard, driver2_shard];
            expected_shards.sort_unstable();

            let mut got_shards = [
                shard1.unwrap(),
                shard2.unwrap(),
                shard3.unwrap(),
                shard4.unwrap(),
            ];
            got_shards.sort_unstable();

            assert_eq!(expected_shards, got_shards);

            running_proxy.finish().await.unwrap();
        }
    }
2902
2903    #[tokio::test]
2904    async fn proxy_ignores_control_connection_messages() {
2905        setup_tracing();
2906        let node1_real_addr = next_local_address_with_port(9876);
2907        let node1_proxy_addr = next_local_address_with_port(9876);
2908
2909        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
2910        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
2911        let proxy = Proxy::new([Node::new(
2912            node1_real_addr,
2913            node1_proxy_addr,
2914            ShardAwareness::Unaware,
2915            Some(vec![RequestRule(
2916                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2917                RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
2918            )]),
2919            Some(vec![ResponseRule(
2920                Condition::not(Condition::ConnectionRegisteredAnyEvent),
2921                ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
2922            )]),
2923        )]);
2924        let running_proxy = proxy.run().await.unwrap();
2925
2926        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
2927
2928        let (mut client_socket, mut server_socket) = join(
2929            async { TcpStream::connect(node1_proxy_addr).await.unwrap() },
2930            async { mock_node_listener.accept().await.unwrap().0 },
2931        )
2932        .await;
2933
2934        async fn perform_reqest_response<'a>(
2935            req_opcode: RequestOpcode,
2936            resp_opcode: ResponseOpcode,
2937            client_socket_ref: &'a mut TcpStream,
2938            server_socket_ref: &'a mut TcpStream,
2939            body_base: &'a str,
2940        ) {
2941            let params = FrameParams {
2942                flags: 0,
2943                version: 0x04,
2944                stream: 0,
2945            };
2946
2947            write_frame(
2948                params,
2949                FrameOpcode::Request(req_opcode),
2950                (body_base.to_string() + "|request|").as_bytes(),
2951                client_socket_ref,
2952                &no_compression(),
2953            )
2954            .await
2955            .unwrap();
2956
2957            let received_request =
2958                read_frame(server_socket_ref, FrameType::Request, &no_compression())
2959                    .await
2960                    .unwrap();
2961            assert_eq!(received_request.1, FrameOpcode::Request(req_opcode));
2962
2963            write_frame(
2964                params.for_response(),
2965                FrameOpcode::Response(resp_opcode),
2966                (body_base.to_string() + "|response|").as_bytes(),
2967                server_socket_ref,
2968                &no_compression(),
2969            )
2970            .await
2971            .unwrap();
2972
2973            let received_response =
2974                read_frame(client_socket_ref, FrameType::Response, &no_compression())
2975                    .await
2976                    .unwrap();
2977            assert_eq!(received_response.1, FrameOpcode::Response(resp_opcode));
2978        }
2979
2980        // Messages before REGISTER should be fed back to channels
2981        for i in 0..5 {
2982            perform_reqest_response(
2983                RequestOpcode::Query,
2984                ResponseOpcode::Result,
2985                &mut client_socket,
2986                &mut server_socket,
2987                &format!("message_before_{i}"),
2988            )
2989            .await
2990        }
2991
2992        perform_reqest_response(
2993            RequestOpcode::Register,
2994            ResponseOpcode::Result,
2995            &mut client_socket,
2996            &mut server_socket,
2997            "message_register",
2998        )
2999        .await;
3000
3001        // Messages after REGISTER should be passed through without feedback
3002        for i in 0..5 {
3003            perform_reqest_response(
3004                RequestOpcode::Query,
3005                ResponseOpcode::Result,
3006                &mut client_socket,
3007                &mut server_socket,
3008                &format!("message_after_{i}"),
3009            )
3010            .await
3011        }
3012
3013        running_proxy.finish().await.unwrap();
3014
3015        for _ in 0..5 {
3016            let (feedback_request, _shard) = request_feedback_rx.recv().await.unwrap();
3017            assert_eq!(feedback_request.opcode, RequestOpcode::Query);
3018            let (feedback_response, _shard) = response_feedback_rx.recv().await.unwrap();
3019            assert_eq!(feedback_response.opcode, ResponseOpcode::Result);
3020        }
3021
3022        // Response to REGISTER and further requests / responses should be ignored
3023        let _ = request_feedback_rx.try_recv().unwrap_err();
3024        let _ = response_feedback_rx.try_recv().unwrap_err();
3025    }
3026
3027    #[tokio::test]
3028    async fn proxy_compresses_and_decompresses_frames_iff_compression_negotiated() {
3029        setup_tracing();
3030        let node1_real_addr = next_local_address_with_port(9876);
3031        let node1_proxy_addr = next_local_address_with_port(9876);
3032
3033        let (request_feedback_tx, mut request_feedback_rx) = mpsc::unbounded_channel();
3034        let (response_feedback_tx, mut response_feedback_rx) = mpsc::unbounded_channel();
3035        let proxy = Proxy::builder()
3036            .with_node(
3037                Node::builder()
3038                    .real_address(node1_real_addr)
3039                    .proxy_address(node1_proxy_addr)
3040                    .shard_awareness(ShardAwareness::Unaware)
3041                    .request_rules(vec![RequestRule(
3042                        Condition::True,
3043                        RequestReaction::noop().with_feedback_when_performed(request_feedback_tx),
3044                    )])
3045                    .response_rules(vec![ResponseRule(
3046                        Condition::True,
3047                        ResponseReaction::noop().with_feedback_when_performed(response_feedback_tx),
3048                    )])
3049                    .build(),
3050            )
3051            .build();
3052        let running_proxy = proxy.run().await.unwrap();
3053
3054        let mock_node_listener = TcpListener::bind(node1_real_addr).await.unwrap();
3055
3056        const PARAMS_REQUEST_NO_COMPRESSION: FrameParams = FrameParams {
3057            flags: 0,
3058            version: 0x04,
3059            stream: 0,
3060        };
3061        const PARAMS_REQUEST_COMPRESSION: FrameParams = FrameParams {
3062            flags: flag::COMPRESSION,
3063            ..PARAMS_REQUEST_NO_COMPRESSION
3064        };
3065        const PARAMS_RESPONSE_NO_COMPRESSION: FrameParams =
3066            PARAMS_REQUEST_NO_COMPRESSION.for_response();
3067        const PARAMS_RESPONSE_COMPRESSION: FrameParams =
3068            PARAMS_REQUEST_NO_COMPRESSION.for_response();
3069
3070        let make_driver_conn = async { TcpStream::connect(node1_proxy_addr).await.unwrap() };
3071        let make_node_conn = async { mock_node_listener.accept().await.unwrap() };
3072
3073        let (mut driver_conn, (mut node_conn, _)) = join(make_driver_conn, make_node_conn).await;
3074
3075        /* Outline of the test:
3076         * 1. "driver" sends an uncompressed, e.g., QUERY frame, feedback returns its uncompressed body,
3077         *    and "node" receives the uncompressed frame.
3078         * 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body,
3079         *    and "driver" receives the uncompressed frame.
3080         * 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body,
3081         *    and "node" receives the uncompressed frame. This step also triggers `CompressionWriter::set()`
3082         *    in the proxy, so the associated `CompressionReader`s are notified about it (and can use
3083         *    the negotiated compression algorithm to (de)compress the frames sent in steps 4. and 5.).
3084         * 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body,
3085         *    and "node" receives the compressed frame.
3086         * 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body,
3087         *    and "driver" receives the compressed frame.
3088         */
3089
3090        // 1. "driver" sends an uncompressed, e.g., QUERY frame, feedback returns its uncompressed body,
3091        //    and "node" receives the uncompressed frame.
3092        {
3093            let sent_frame = RequestFrame {
3094                params: PARAMS_REQUEST_NO_COMPRESSION,
3095                opcode: RequestOpcode::Query,
3096                body: random_body(),
3097            };
3098
3099            sent_frame
3100                .write(&mut driver_conn, &no_compression())
3101                .await
3102                .unwrap();
3103
3104            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3105            assert_eq!(captured_frame, sent_frame);
3106
3107            let received_frame = read_request_frame(&mut node_conn, &no_compression())
3108                .await
3109                .unwrap();
3110            assert_eq!(received_frame, sent_frame);
3111        }
3112
3113        // 2. "node" responds with an uncompressed RESULT frame, feedback returns its uncompressed body,
3114        //    and "driver" receives the uncompressed frame.
3115        {
3116            let sent_frame = ResponseFrame {
3117                params: PARAMS_RESPONSE_NO_COMPRESSION,
3118                opcode: ResponseOpcode::Result,
3119                body: random_body(),
3120            };
3121
3122            sent_frame
3123                .write(&mut node_conn, &no_compression())
3124                .await
3125                .unwrap();
3126
3127            let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
3128            assert_eq!(captured_frame, sent_frame);
3129
3130            let received_frame = read_response_frame(&mut driver_conn, &no_compression())
3131                .await
3132                .unwrap();
3133            assert_eq!(received_frame, sent_frame);
3134        }
3135
3136        // 3. "driver" sends an uncompressed STARTUP frame, feedback returns its uncompressed body,
3137        //    and "node" receives the uncompressed frame. This step also triggers `CompressionWriter::set()`
3138        //    in the proxy, so the associated `CompressionReader`s are notified about it (and can use
3139        //    the negotiated compression algorithm to (de)compress the frames sent in steps 4. and 5.).
3140        {
3141            let startup_body = Startup {
3142                options: std::iter::once((
3143                    options::COMPRESSION.into(),
3144                    Compression::Lz4.as_str().into(),
3145                ))
3146                .collect(),
3147            }
3148            .to_bytes()
3149            .unwrap();
3150
3151            let sent_frame = RequestFrame {
3152                params: PARAMS_REQUEST_NO_COMPRESSION,
3153                opcode: RequestOpcode::Startup,
3154                body: startup_body,
3155            };
3156
3157            sent_frame
3158                .write(&mut driver_conn, &no_compression())
3159                .await
3160                .unwrap();
3161
3162            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3163            assert_eq!(captured_frame, sent_frame);
3164
3165            let received_frame = read_request_frame(&mut node_conn, &no_compression())
3166                .await
3167                .unwrap();
3168            assert_eq!(received_frame, sent_frame);
3169        }
3170
3171        // 4. "driver" sends a compressed, e.g., QUERY frame, feedback returns its uncompressed body,
3172        //    and "node" receives the compressed frame.
3173        {
3174            let sent_frame = RequestFrame {
3175                params: PARAMS_REQUEST_COMPRESSION,
3176                opcode: RequestOpcode::Query,
3177                body: random_body(),
3178            };
3179
3180            sent_frame
3181                .write(&mut driver_conn, &with_compression(Compression::Lz4))
3182                .await
3183                .unwrap();
3184
3185            let (captured_frame, _) = request_feedback_rx.recv().await.unwrap();
3186            assert_eq!(captured_frame, sent_frame);
3187
3188            let received_frame =
3189                read_request_frame(&mut node_conn, &with_compression(Compression::Lz4))
3190                    .await
3191                    .unwrap();
3192            assert_eq!(received_frame, sent_frame);
3193        }
3194
3195        // 5. "node" responds with a compressed RESULT frame, feedback returns its uncompressed body,
3196        //    and "driver" receives the compressed frame.
3197        {
3198            let sent_frame = ResponseFrame {
3199                params: PARAMS_RESPONSE_COMPRESSION,
3200                opcode: ResponseOpcode::Result,
3201                body: random_body(),
3202            };
3203
3204            sent_frame
3205                .write(&mut node_conn, &with_compression(Compression::Lz4))
3206                .await
3207                .unwrap();
3208
3209            let (captured_frame, _) = response_feedback_rx.recv().await.unwrap();
3210            assert_eq!(captured_frame, sent_frame);
3211
3212            let received_frame =
3213                read_response_frame(&mut driver_conn, &with_compression(Compression::Lz4))
3214                    .await
3215                    .unwrap();
3216            assert_eq!(received_frame, sent_frame);
3217        }
3218
3219        running_proxy.finish().await.unwrap();
3220    }
3221
3222    #[tokio::test]
3223    async fn connection_tracker_register_increments_and_drop_decrements() {
3224        let tracker = Arc::new(ConnectionTracker {
3225            active_count: std::sync::atomic::AtomicUsize::new(0),
3226            notify: tokio::sync::Notify::new(),
3227        });
3228
3229        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 0);
3230
3231        let guard1 = tracker.register_connection();
3232        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 1);
3233
3234        let guard2 = tracker.register_connection();
3235        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3236
3237        // Cloning the Arc does NOT increment the count — only register_connection does.
3238        let guard1_clone = Arc::clone(&guard1);
3239        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3240
3241        // Dropping the clone doesn't decrement yet (the original is still alive).
3242        drop(guard1_clone);
3243        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 2);
3244
3245        // Dropping the last Arc for connection 1 decrements.
3246        drop(guard1);
3247        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 1);
3248
3249        drop(guard2);
3250        assert_eq!(tracker.active_count.load(Ordering::Relaxed), 0);
3251    }
3252
3253    #[tokio::test]
3254    async fn connection_tracker_register_notifies_waiters() {
3255        let tracker = Arc::new(ConnectionTracker {
3256            active_count: std::sync::atomic::AtomicUsize::new(0),
3257            notify: tokio::sync::Notify::new(),
3258        });
3259
3260        // Set up a waiter BEFORE registering a connection.
3261        let notified = tracker.notify.notified();
3262
3263        // register_connection should notify.
3264        let _guard = tracker.register_connection();
3265
3266        // The notification should resolve immediately (no timeout needed).
3267        tokio::time::timeout(Duration::from_millis(200), notified)
3268            .await
3269            .expect("notify was not triggered by register_connection");
3270    }
3271
3272    #[tokio::test]
3273    async fn wait_for_connection_returns_immediately_when_connected() {
3274        let node_real_addr = next_local_address_with_port(9876);
3275        let node_proxy_addr = next_local_address_with_port(9876);
3276        let proxy = Proxy::new([Node::new(
3277            node_real_addr,
3278            node_proxy_addr,
3279            ShardAwareness::Unaware,
3280            None,
3281            None,
3282        )]);
3283        let running_proxy = proxy.run().await.unwrap();
3284
3285        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3286
3287        // Connect a driver and accept the backend connection.
3288        let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3289        let (_backend_conn, _) = mock_node_listener.accept().await.unwrap();
3290
3291        // wait_for_connection should return promptly (connection is already active).
3292        tokio::time::timeout(
3293            Duration::from_millis(200),
3294            running_proxy.running_nodes[0].wait_for_connection(),
3295        )
3296        .await
3297        .expect("wait_for_connection timed out despite active connection");
3298
3299        // Same via RunningProxy::wait_for_connection.
3300        tokio::time::timeout(
3301            Duration::from_millis(200),
3302            running_proxy.wait_for_connection(),
3303        )
3304        .await
3305        .expect("RunningProxy::wait_for_connection timed out despite active connection");
3306
3307        running_proxy.finish().await.unwrap();
3308    }
3309
3310    #[tokio::test]
3311    async fn wait_for_connection_blocks_until_driver_connects() {
3312        let node_real_addr = next_local_address_with_port(9876);
3313        let node_proxy_addr = next_local_address_with_port(9876);
3314        let proxy = Proxy::new([Node::new(
3315            node_real_addr,
3316            node_proxy_addr,
3317            ShardAwareness::Unaware,
3318            None,
3319            None,
3320        )]);
3321        let running_proxy = proxy.run().await.unwrap();
3322
3323        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3324
3325        // Before any driver connects, wait_for_connection should NOT return.
3326        let result = tokio::time::timeout(
3327            Duration::from_millis(20),
3328            running_proxy.running_nodes[0].wait_for_connection(),
3329        )
3330        .await;
3331        assert!(
3332            result.is_err(),
3333            "wait_for_connection returned before any driver connected"
3334        );
3335
3336        // Now connect a driver and accept the backend side.
3337        let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3338        let (_backend_conn, _) = mock_node_listener.accept().await.unwrap();
3339
3340        // wait_for_connection should now resolve.
3341        tokio::time::timeout(
3342            Duration::from_millis(200),
3343            running_proxy.running_nodes[0].wait_for_connection(),
3344        )
3345        .await
3346        .expect("wait_for_connection timed out after driver connected");
3347
3348        running_proxy.finish().await.unwrap();
3349    }
3350
3351    /// Helper: send a REGISTER frame from the driver side and wait until
3352    /// the mock node receives it. This is enough for the proxy's
3353    /// request_processor to register the cc_event_sender for that
3354    /// connection — no response is needed.
3355    async fn send_register(driver_conn: &mut TcpStream, node_conn: &mut TcpStream) {
3356        let params = FrameParams {
3357            flags: 0,
3358            version: 0x04,
3359            stream: 0,
3360        };
3361        write_frame(
3362            params,
3363            FrameOpcode::Request(RequestOpcode::Register),
3364            b"",
3365            driver_conn,
3366            &no_compression(),
3367        )
3368        .await
3369        .unwrap();
3370
3371        // Wait until the mock node receives the REGISTER so we know the
3372        // proxy has already processed it and registered the sender.
3373        let _req = read_request_frame(node_conn, &no_compression())
3374            .await
3375            .unwrap();
3376    }
3377
3378    #[tokio::test]
3379    async fn inject_event_to_cc_returns_false_when_no_control_connections() {
3380        setup_tracing();
3381        let node_real_addr = next_local_address_with_port(9876);
3382        let node_proxy_addr = next_local_address_with_port(9876);
3383        let proxy = Proxy::new([Node::new(
3384            node_real_addr,
3385            node_proxy_addr,
3386            ShardAwareness::Unaware,
3387            None,
3388            None,
3389        )]);
3390        let running_proxy = proxy.run().await.unwrap();
3391        let _mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3392
3393        // No connections at all — inject should return false.
3394        assert!(
3395            !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"test")),
3396            "inject_event_to_cc should return false with no control connections"
3397        );
3398
3399        // finish() may report errors because no real node accepted connections.
3400        let _ = running_proxy.finish().await;
3401    }
3402
3403    #[tokio::test]
3404    async fn inject_event_to_cc_returns_false_when_connection_did_not_register() {
3405        setup_tracing();
3406        let node_real_addr = next_local_address_with_port(9876);
3407        let node_proxy_addr = next_local_address_with_port(9876);
3408        let proxy = Proxy::new([Node::new(
3409            node_real_addr,
3410            node_proxy_addr,
3411            ShardAwareness::Unaware,
3412            None,
3413            None,
3414        )]);
3415        let running_proxy = proxy.run().await.unwrap();
3416        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3417
3418        // Connect but do NOT send REGISTER.
3419        let _driver_conn = TcpStream::connect(node_proxy_addr).await.unwrap();
3420        let (_node_conn, _) = mock_node_listener.accept().await.unwrap();
3421
3422        // Connection exists but hasn't sent REGISTER — inject should return false.
3423        assert!(
3424            !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"test")),
3425            "inject_event_to_cc should return false when no REGISTER was sent"
3426        );
3427
3428        running_proxy.finish().await.unwrap();
3429    }
3430
3431    #[tokio::test]
3432    async fn inject_event_to_cc_delivers_event_after_register() {
3433        setup_tracing();
3434        let node_real_addr = next_local_address_with_port(9876);
3435        let node_proxy_addr = next_local_address_with_port(9876);
3436        let proxy = Proxy::new([Node::new(
3437            node_real_addr,
3438            node_proxy_addr,
3439            ShardAwareness::Unaware,
3440            None,
3441            None,
3442        )]);
3443        let running_proxy = proxy.run().await.unwrap();
3444        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3445
3446        let (mut driver_conn, mut node_conn) = join(
3447            async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3448            async { mock_node_listener.accept().await.unwrap().0 },
3449        )
3450        .await;
3451
3452        // Complete the REGISTER handshake so the proxy registers the cc sender.
3453        send_register(&mut driver_conn, &mut node_conn).await;
3454
3455        // Inject an event.
3456        let event_body = Bytes::from_static(b"injected_event_payload");
3457        assert!(
3458            running_proxy.running_nodes[0].inject_event_to_cc(event_body.clone()),
3459            "inject_event_to_cc should return true after REGISTER"
3460        );
3461
3462        // Read the injected frame on the driver side.
3463        let frame = tokio::time::timeout(
3464            Duration::from_millis(100),
3465            read_response_frame(&mut driver_conn, &no_compression()),
3466        )
3467        .await
3468        .expect("timed out waiting for injected event")
3469        .expect("failed to read injected event frame");
3470
3471        assert_eq!(frame.opcode, ResponseOpcode::Event);
3472        assert_eq!(frame.body, event_body);
3473        assert_eq!(frame.params.stream, -1);
3474
3475        running_proxy.finish().await.unwrap();
3476    }
3477
3478    #[tokio::test]
3479    async fn inject_event_to_cc_prunes_closed_connections() {
3480        setup_tracing();
3481        let node_real_addr = next_local_address_with_port(9876);
3482        let node_proxy_addr = next_local_address_with_port(9876);
3483        let proxy = Proxy::new([Node::new(
3484            node_real_addr,
3485            node_proxy_addr,
3486            ShardAwareness::Unaware,
3487            None,
3488            None,
3489        )]);
3490        let running_proxy = proxy.run().await.unwrap();
3491        let mock_node_listener = TcpListener::bind(node_real_addr).await.unwrap();
3492
3493        // Establish first control connection.
3494        let (mut driver_conn1, mut node_conn1) = join(
3495            async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3496            async { mock_node_listener.accept().await.unwrap().0 },
3497        )
3498        .await;
3499        send_register(&mut driver_conn1, &mut node_conn1).await;
3500
3501        // Establish second control connection.
3502        let (mut driver_conn2, mut node_conn2) = join(
3503            async { TcpStream::connect(node_proxy_addr).await.unwrap() },
3504            async { mock_node_listener.accept().await.unwrap().0 },
3505        )
3506        .await;
3507        send_register(&mut driver_conn2, &mut node_conn2).await;
3508
3509        // Both connections registered — inject should succeed.
3510        assert!(running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev1")));
3511
3512        // Read event from both.
3513        let f1 = read_response_frame(&mut driver_conn1, &no_compression())
3514            .await
3515            .unwrap();
3516        let f2 = read_response_frame(&mut driver_conn2, &no_compression())
3517            .await
3518            .unwrap();
3519        assert_eq!(f1.body, Bytes::from_static(b"ev1"));
3520        assert_eq!(f2.body, Bytes::from_static(b"ev1"));
3521
3522        // Close connection 1 (both sides).
3523        drop(driver_conn1);
3524        drop(node_conn1);
3525
3526        // Give the proxy a moment to detect the closed connection and clean up
3527        // its cc_event_sender entry.
3528        tokio::time::sleep(Duration::from_millis(100)).await;
3529
3530        // Inject again — should still succeed via connection 2, and prune
3531        // the dead sender for connection 1.
3532        assert!(running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev2")));
3533
3534        let f2 = read_response_frame(&mut driver_conn2, &no_compression())
3535            .await
3536            .unwrap();
3537        assert_eq!(f2.body, Bytes::from_static(b"ev2"));
3538
3539        // Close connection 2 as well.
3540        drop(driver_conn2);
3541        drop(node_conn2);
3542
3543        tokio::time::sleep(Duration::from_millis(100)).await;
3544
3545        // Now all control connections are gone — inject should return false.
3546        assert!(
3547            !running_proxy.running_nodes[0].inject_event_to_cc(Bytes::from_static(b"ev3")),
3548            "inject_event_to_cc should return false after all control connections closed"
3549        );
3550
3551        // finish() may report DriverDisconnected errors from the intentionally
3552        // dropped connections — that's expected.
3553        let _ = running_proxy.finish().await;
3554    }
3555}