
strike48_connector/multi/mod.rs

//! Multi-registration connector runner.
//!
//! Lets one host process register `N` independently-approvable connectors
//! against a single Matrix server while sharing the underlying transport:
//!
//! - **gRPC**: one TCP+TLS connection, N HTTP/2 streams (one per registration).
//!   Lazily opens additional channels when `max_streams_per_channel` is hit
//!   (e.g. 250 registrations at the default cap of 80 spread across four
//!   channels).
//! - **WebSocket**: HTTP/1.1 — no native multiplexing — falls back to N
//!   independent WS connections in one process. Same public API.
//!
//! From the Matrix server's point of view, each registration is a normal
//! `Connect` RPC. There are zero server-side changes.
//! ## Reconnect behaviour
//!
//! The runner has two layers of reconnect, owned by different actors:
//!
//! - **Stream-level (per registration)**: implemented in the internal
//!   `registration_runner` module. When a stream ends (server closes, network
//!   blip, heartbeat timeout) the runner sleeps with exponential backoff +
//!   jitter (capped at `MultiTransportOptions::max_backoff_delay_ms`; see the
//!   sketch after this list) and opens a fresh stream over the existing
//!   channel. Fully shutdown-aware — never blocks shutdown for more than one
//!   backoff slice. Each reconnect bumps the registration's
//!   `successful_reconnects` / `total_disconnects` metrics.
//!
//! - **Channel-level (per HTTP/2 connection, gRPC only)**: delegated to
//!   `tonic::transport::Channel`. The channel is configured with HTTP/2
//!   keepalive and dialled **eagerly** via `endpoint.connect().await` the
//!   first time a registration needs it. tonic auto-recovers on transient
//!   TCP / TLS failures by re-dialing internally on the next request; the
//!   SDK does **not** explicitly recreate channels today.
//!
//!   In practice this means: if the underlying TCP connection breaks, all
//!   N registrations using that channel will see their streams close. The
//!   per-registration loop opens a fresh stream, which in turn forces the
//!   channel to redial. End state: connections recover transparently
//!   without involvement from this module.
//!
//!   If a channel ever goes **permanently dead** (e.g. DNS now points to
//!   an unreachable host and the lazy redial keeps failing), every
//!   registration on that channel will spend its time backing off. The
//!   [`MultiConnectorRunner::shutdown_handle`] still works, but recovery
//!   for the current process is best-effort. A future improvement would
//!   be to track per-channel consecutive-failure counts and rebuild the
//!   channel after a threshold; tracked under
//!   `connector-sdk-rust-channel-reconnect-policy`.
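//!
//! A minimal sketch of the stream-level backoff policy (the `next_backoff`
//! helper is hypothetical; the real logic lives in the internal
//! `registration_runner` module). Jitter is drawn uniformly from
//! `0..=reconnect_jitter_ms`; it is passed in as a plain argument here so
//! the example stays deterministic:
//!
//! ```
//! use std::time::Duration;
//!
//! fn next_backoff(attempt: u32, base_ms: u64, jitter_ms: u64, cap_ms: u64) -> Duration {
//!     // Exponential scaling, saturating so large attempt counts cannot overflow.
//!     let scaled = base_ms.saturating_mul(1u64 << attempt.min(16));
//!     // Jitter is added first, then the result is clamped, so the cap is a
//!     // strict upper bound on the wait between attempts.
//!     Duration::from_millis(scaled.saturating_add(jitter_ms).min(cap_ms))
//! }
//!
//! // Defaults: base 500 ms, jitter up to 500 ms, cap 60 s.
//! assert_eq!(next_backoff(4, 500, 250, 60_000), Duration::from_millis(8_250));
//! assert_eq!(next_backoff(32, 500, 250, 60_000), Duration::from_millis(60_000));
//! ```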
//!
//! ## Backward compatibility
//!
//! This module is **purely additive**. The existing single-registration
//! [`crate::ConnectorRunner`] API and behaviour are unchanged.
//!
//! ## Example
//!
//! ```no_run
//! # async fn run() -> strike48_connector::Result<()> {
//! use strike48_connector::{
//!     BaseConnector, ConnectorConfig, ConnectorRegistration, MultiConnectorRunner,
//!     MultiTransportOptions, Result, TransportType,
//! };
//!
//! struct Echo;
//! impl BaseConnector for Echo {
//!     fn connector_type(&self) -> &str { "echo" }
//!     fn version(&self) -> &str { "1.0.0" }
//!     fn execute(
//!         &self,
//!         req: serde_json::Value,
//!         _: Option<&str>,
//!     ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>> {
//!         Box::pin(async move { Ok(req) })
//!     }
//! }
//!
//! let opts = MultiTransportOptions::builder()
//!     .host("localhost:50061")
//!     .transport_type(TransportType::Grpc)
//!     .build();
//!
//! let registrations = (0..3).map(|i| {
//!     ConnectorRegistration::new(
//!         ConnectorConfig {
//!             tenant_id: "demo-org".into(),
//!             connector_type: "echo".into(),
//!             instance_id: format!("echo-{i}"),
//!             ..ConnectorConfig::default()
//!         },
//!         Echo,
//!     )
//! }).collect::<Vec<_>>();
//!
//! let runner = MultiConnectorRunner::new(opts, registrations);
//! let _shutdown = runner.shutdown_handle();
//! runner.run().await?;
//! # Ok(()) }
//! ```

use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

use tokio::sync::{Mutex, RwLock, Semaphore, watch};

use crate::connector::{BaseConnector, ConnectorConfig, ShutdownHandle};
use crate::error::{ConnectorError, Result};
use crate::logger::Logger;
use crate::transport::TransportType;
use crate::types::ConnectorMetrics;

mod registration_runner;
mod shared_channel;

use registration_runner::RegistrationRunner;
use shared_channel::SharedChannel;

// =============================================================================
// Public types
// =============================================================================

/// Transport-level configuration shared by every registration in a
/// [`MultiConnectorRunner`]. Per-registration identity (tenant, type, instance,
/// auth token, ...) lives on each [`ConnectorRegistration`]'s [`ConnectorConfig`].
///
/// Marked `#[non_exhaustive]` so future fields can be added without a breaking
/// change. This also means downstream crates cannot construct it with a struct
/// literal (not even via `..Default::default()`): use
/// [`MultiTransportOptions::builder`] (preferred), or start from
/// [`MultiTransportOptions::default`] and assign the public fields.
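///
/// A sketch of both construction paths (the host below is a placeholder):
///
/// ```
/// use strike48_connector::{MultiTransportOptions, TransportType};
///
/// // Builder path (preferred).
/// let opts = MultiTransportOptions::builder()
///     .host("matrix.example:50061")
///     .use_tls(true)
///     .build();
///
/// // Default-then-mutate path; the fields are public even though the
/// // struct itself is `#[non_exhaustive]`.
/// let mut opts2 = MultiTransportOptions::default();
/// opts2.transport_type = TransportType::Grpc;
/// opts2.host = "matrix.example:50061".into();
/// # let _ = (opts, opts2);
/// ```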
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct MultiTransportOptions {
    /// Server host:port (e.g. `localhost:50061` for gRPC, `localhost:4000` for WS).
    pub host: String,
    /// Whether to use TLS for the transport.
    pub use_tls: bool,
    /// Transport scheme.
    pub transport_type: TransportType,

    /// Soft cap on concurrent gRPC streams per channel before the runner
    /// opens an additional channel. Defaults to `80` to leave headroom under
    /// the typical server limit of 100 concurrent streams (Cowboy's default,
    /// and RFC 7540's recommended minimum). Ignored for WebSocket.
    pub max_streams_per_channel: usize,

    /// Initial connect timeout (ms). Default 10_000.
    pub connect_timeout_ms: u64,

    /// Enable channel-level reconnect on transport failure. Default `true`.
    pub reconnect_enabled: bool,
    /// Base reconnect backoff (ms). Default 500.
    pub reconnect_delay_ms: u64,
    /// Max reconnect backoff (ms). Default 60_000.
    pub max_backoff_delay_ms: u64,
    /// Reconnect jitter (ms). Default 500.
    pub reconnect_jitter_ms: u64,

    /// Per-registration maximum number of in-flight `ExecuteRequest`s being
    /// processed by the user `BaseConnector::execute()` callback. When the
    /// limit is reached, additional `ExecuteRequest`s queue on a semaphore
    /// until a permit is released. Mirrors
    /// [`crate::ConnectorConfig::max_concurrent_requests`] for the
    /// single-runner path. Default `100`.
    pub max_concurrent_requests: usize,

    /// Per-registration outbound heartbeat interval. `None` (default) means
    /// use the SDK default of 30s, which matches the Matrix server's
    /// session-reaper expectation.
    ///
    /// Only tune this when running against a Matrix deployment with a
    /// non-default heartbeat configuration, or for flaky-network testing.
    pub heartbeat_interval: Option<Duration>,

    /// Per-registration heartbeat watchdog timeout. If no `HeartbeatResponse`
    /// arrives within this window the runner declares the stream dead and
    /// reconnects. `None` (default) means use the SDK default of 45s.
    ///
    /// Only tune this when running against a Matrix deployment with a
    /// non-default heartbeat configuration, or for flaky-network testing.
    pub heartbeat_timeout: Option<Duration>,
}

impl MultiTransportOptions {
    /// Start a builder with sensible defaults (gRPC, plaintext, localhost:50061).
    pub fn builder() -> MultiTransportOptionsBuilder {
        MultiTransportOptionsBuilder::default()
    }
}

impl Default for MultiTransportOptions {
    fn default() -> Self {
        Self {
            host: "localhost:50061".to_string(),
            use_tls: false,
            transport_type: TransportType::Grpc,
            max_streams_per_channel: 80,
            connect_timeout_ms: 10_000,
            reconnect_enabled: true,
            reconnect_delay_ms: 500,
            max_backoff_delay_ms: 60_000,
            reconnect_jitter_ms: 500,
            max_concurrent_requests: 100,
            heartbeat_interval: None,
            heartbeat_timeout: None,
        }
    }
}

/// Validate a heartbeat (interval, timeout) pair and emit a `tracing::warn!`
/// if the timeout is shorter than the interval — that misconfigures the
/// watchdog, which can then declare the stream dead before the first
/// heartbeat reply has had a chance to arrive. Returns `false` when the
/// pair is misconfigured.
///
/// Extracted as a free function so it can be unit-tested in isolation.
fn validate_heartbeat_pair(interval: Option<Duration>, timeout: Option<Duration>) -> bool {
    match (interval, timeout) {
        (Some(i), Some(t)) if t < i => {
            tracing::warn!(
                target: "strike48_connector::heartbeat",
                interval_ms = i.as_millis() as u64,
                timeout_ms = t.as_millis() as u64,
                "heartbeat_timeout < heartbeat_interval; the watchdog can fire before the first heartbeat reply has a chance to arrive"
            );
            false
        }
        _ => true,
    }
}

/// Fluent builder for [`MultiTransportOptions`].
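///
/// A small sketch mirroring the heartbeat roundtrip the unit tests cover:
///
/// ```
/// use std::time::Duration;
/// use strike48_connector::MultiTransportOptions;
///
/// let opts = MultiTransportOptions::builder()
///     .heartbeat_interval(Duration::from_secs(30))
///     .heartbeat_timeout(Duration::from_secs(45))
///     .build();
/// assert_eq!(opts.heartbeat_timeout, Some(Duration::from_secs(45)));
/// ```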
#[derive(Debug, Clone, Default)]
pub struct MultiTransportOptionsBuilder {
    inner: Option<MultiTransportOptions>,
}

impl MultiTransportOptionsBuilder {
    fn opts(&mut self) -> &mut MultiTransportOptions {
        self.inner
            .get_or_insert_with(MultiTransportOptions::default)
    }

    /// Set the server host:port.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.opts().host = host.into();
        self
    }

    /// Set whether to use TLS.
    pub fn use_tls(mut self, use_tls: bool) -> Self {
        self.opts().use_tls = use_tls;
        self
    }

    /// Set the transport scheme.
    pub fn transport_type(mut self, t: TransportType) -> Self {
        self.opts().transport_type = t;
        self
    }

    /// Override the soft cap on concurrent gRPC streams per channel (gRPC only).
    pub fn max_streams_per_channel(mut self, n: usize) -> Self {
        self.opts().max_streams_per_channel = n;
        self
    }

    /// Override the initial connect timeout (ms).
    pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
        self.opts().connect_timeout_ms = ms;
        self
    }

    /// Override the per-registration `max_concurrent_requests` cap.
    pub fn max_concurrent_requests(mut self, n: usize) -> Self {
        self.opts().max_concurrent_requests = n.max(1);
        self
    }

    /// Enable or disable automatic reconnection on stream loss.
    pub fn reconnect_enabled(mut self, enabled: bool) -> Self {
        self.opts().reconnect_enabled = enabled;
        self
    }

    /// Initial reconnect delay (ms) — the base for exponential backoff.
    pub fn reconnect_delay_ms(mut self, ms: u64) -> Self {
        self.opts().reconnect_delay_ms = ms;
        self
    }

    /// Hard cap on reconnect backoff (ms). Jitter is applied first, then
    /// the result is clamped to this value, so the cap is a strict upper
    /// bound on the wait between attempts.
    pub fn max_backoff_delay_ms(mut self, ms: u64) -> Self {
        self.opts().max_backoff_delay_ms = ms;
        self
    }

    /// Per-attempt jitter range (ms). A uniformly-random value in
    /// `0..=ms` is added to the scaled backoff before capping.
    pub fn reconnect_jitter_ms(mut self, ms: u64) -> Self {
        self.opts().reconnect_jitter_ms = ms;
        self
    }

    /// Override the per-registration heartbeat interval.
    ///
    /// Defaults to **30s** (the Matrix server's session-reaper expectation).
    /// Only tune this when running against a Matrix deployment with a
    /// non-default heartbeat configuration, or for flaky-network testing.
    /// If the resulting `heartbeat_timeout` is shorter than the interval,
    /// a `tracing::warn!` is emitted at builder time but the value is
    /// still applied (the runner uses whatever is configured).
    pub fn heartbeat_interval(mut self, d: Duration) -> Self {
        self.opts().heartbeat_interval = Some(d);
        let _ = validate_heartbeat_pair(
            self.opts().heartbeat_interval,
            self.opts().heartbeat_timeout,
        );
        self
    }

    /// Override the per-registration heartbeat watchdog timeout.
    ///
    /// Defaults to **45s** (Matrix server default + slack for one missed
    /// reply). Only tune this when running against a Matrix deployment
    /// with a non-default heartbeat configuration, or for flaky-network
    /// testing. Misconfigured pairs (`timeout < interval`) emit a
    /// `tracing::warn!` at builder time but are still applied.
    pub fn heartbeat_timeout(mut self, d: Duration) -> Self {
        self.opts().heartbeat_timeout = Some(d);
        let _ = validate_heartbeat_pair(
            self.opts().heartbeat_interval,
            self.opts().heartbeat_timeout,
        );
        self
    }

    /// Build the options.
    pub fn build(mut self) -> MultiTransportOptions {
        self.inner.take().unwrap_or_default()
    }
}

/// One logical connector to run inside a [`MultiConnectorRunner`].
///
/// `config.host`, `config.use_tls`, and `config.transport_type` are ignored —
/// the transport is governed by [`MultiTransportOptions`]. All other fields
/// (`tenant_id`, `connector_type`, `instance_id`, `auth_token`,
/// `display_name`, `tags`, `metadata`, `max_concurrent_requests`,
/// `metrics_*`) apply to this registration only.
///
/// Marked `#[non_exhaustive]` so future fields (e.g. per-registration
/// behaviour overrides) can be added without a breaking change; downstream
/// crates cannot construct it with a struct literal in the first place. Use
/// [`ConnectorRegistration::new`] or [`ConnectorRegistration::from_arc`].
#[non_exhaustive]
pub struct ConnectorRegistration {
    pub config: ConnectorConfig,
    pub connector: Arc<dyn BaseConnector>,
}

impl ConnectorRegistration {
    /// Build a registration from a `ConnectorConfig` and any
    /// [`BaseConnector`] implementor — the conversion to
    /// `Arc<dyn BaseConnector>` happens internally so callers don't have
    /// to write `Arc::new(...) as Arc<dyn BaseConnector>` themselves.
    ///
    /// ```
    /// use strike48_connector::{
    ///     BaseConnector, ConnectorConfig, ConnectorRegistration, Result,
    /// };
    ///
    /// struct Echo;
    /// impl BaseConnector for Echo {
    ///     fn connector_type(&self) -> &str { "echo" }
    ///     fn version(&self) -> &str { "1.0.0" }
    ///     fn execute(
    ///         &self,
    ///         req: serde_json::Value,
    ///         _: Option<&str>,
    ///     ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>> {
    ///         Box::pin(async move { Ok(req) })
    ///     }
    /// }
    ///
    /// let cfg = ConnectorConfig {
    ///     tenant_id: "demo".into(),
    ///     connector_type: "echo".into(),
    ///     instance_id: "echo-1".into(),
    ///     ..ConnectorConfig::default()
    /// };
    /// let reg = ConnectorRegistration::new(cfg, Echo);
    /// assert_eq!(reg.config.connector_type, "echo");
    /// ```
    pub fn new<T>(config: ConnectorConfig, connector: T) -> Self
    where
        T: BaseConnector + 'static,
    {
        Self {
            config,
            connector: Arc::new(connector) as Arc<dyn BaseConnector>,
        }
    }

    /// Build a registration from a config and an already-erased
    /// `Arc<dyn BaseConnector>`. Useful when callers have a heterogeneous
    /// list of differently-typed connectors that they have already boxed.
    pub fn from_arc(config: ConnectorConfig, connector: Arc<dyn BaseConnector>) -> Self {
        Self { config, connector }
    }
}

impl std::fmt::Debug for ConnectorRegistration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectorRegistration")
            .field("config", &self.config)
            .field(
                "connector",
                &format_args!(
                    "Arc<dyn BaseConnector>(\"{}\")",
                    self.connector.connector_type()
                ),
            )
            .finish()
    }
}

/// Stable identity for a registration. Matches the `tenant.type.instance`
/// triple the Matrix server uses to key a `ConnectorSession`.
///
/// Marked `#[non_exhaustive]` so additional identity dimensions (e.g. an
/// optional region tag) can be added without breaking downstream
/// pattern-match exhaustiveness.
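///
/// A small sketch of the display form (this assumes `RegistrationKey` is
/// re-exported at the crate root like the other multi types):
///
/// ```
/// use strike48_connector::{ConnectorConfig, RegistrationKey};
///
/// let cfg = ConnectorConfig {
///     tenant_id: "t".into(),
///     connector_type: "c".into(),
///     instance_id: "i".into(),
///     ..ConnectorConfig::default()
/// };
/// // The display form matches the server-side session key.
/// assert_eq!(RegistrationKey::from_config(&cfg).to_string(), "t.c.i");
/// ```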
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub struct RegistrationKey {
    pub tenant_id: String,
    pub connector_type: String,
    pub instance_id: String,
}

impl RegistrationKey {
    pub fn from_config(config: &ConnectorConfig) -> Self {
        Self {
            tenant_id: config.tenant_id.clone(),
            connector_type: config.connector_type.clone(),
            instance_id: config.instance_id.clone(),
        }
    }
}

impl std::fmt::Display for RegistrationKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "{}.{}.{}",
            self.tenant_id, self.connector_type, self.instance_id
        )
    }
}

// =============================================================================
// MultiConnectorRunner
// =============================================================================

struct RegistrationEntry {
    key: RegistrationKey,
    config: ConnectorConfig,
    connector: Arc<dyn BaseConnector>,
    metrics: Arc<Mutex<ConnectorMetrics>>,
}

/// Runs `N` independently-approvable connectors over a shared transport.
///
/// See module-level docs for transport semantics. From the Matrix server's
/// point of view, each registration is a normal `Connect` RPC.
pub struct MultiConnectorRunner {
    opts: MultiTransportOptions,
    registrations: RwLock<Vec<RegistrationEntry>>,
    shutdown_requested: Arc<AtomicBool>,
    running: Arc<AtomicBool>,
}

impl MultiConnectorRunner {
    /// Create a new runner. Construction does not open any connections —
    /// transport is established lazily by [`MultiConnectorRunner::run`].
    ///
    /// Duplicate registrations (same `tenant.type.instance`) are rejected by
    /// [`MultiConnectorRunner::add`]; duplicates passed in here are reduced
    /// to the first occurrence and logged.
    pub fn new(opts: MultiTransportOptions, registrations: Vec<ConnectorRegistration>) -> Self {
        let mut entries: Vec<RegistrationEntry> = Vec::with_capacity(registrations.len());
        for ConnectorRegistration { config, connector } in registrations {
            let key = RegistrationKey::from_config(&config);
            if entries.iter().any(|e| e.key == key) {
                tracing::warn!(
                    target: "strike48_connector::multi",
                    registration = %key,
                    "duplicate registration ignored"
                );
                continue;
            }
            entries.push(RegistrationEntry {
                key,
                config,
                connector,
                metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
            });
        }

        Self {
            opts,
            registrations: RwLock::new(entries),
            shutdown_requested: Arc::new(AtomicBool::new(false)),
            running: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Append a registration. Only valid before [`MultiConnectorRunner::run`]
    /// is called; returns an error if `run()` has already started or if the
    /// registration's `tenant.type.instance` collides with an existing one.
    ///
    /// `running` is checked **inside** the same write-lock critical section
    /// that `run()` uses to snapshot the registration list, so a concurrent
    /// `add` either lands before the snapshot (and is driven) or returns
    /// [`ConnectorError::AlreadyRunning`]. There is no window where the
    /// registration is silently dropped.
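    ///
    /// A minimal sketch (hypothetical caller; `add` must run before `run()`
    /// starts):
    ///
    /// ```no_run
    /// # async fn pre_run(
    /// #     runner: &strike48_connector::MultiConnectorRunner,
    /// #     reg: strike48_connector::ConnectorRegistration,
    /// # ) -> strike48_connector::Result<()> {
    /// runner.add(reg).await?;
    /// # Ok(()) }
    /// ```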
    pub async fn add(&self, registration: ConnectorRegistration) -> Result<()> {
        let key = RegistrationKey::from_config(&registration.config);
        let mut regs = self.registrations.write().await;
        if self.running.load(Ordering::SeqCst) {
            return Err(ConnectorError::AlreadyRunning);
        }
        if regs.iter().any(|e| e.key == key) {
            return Err(ConnectorError::InvalidConfig(format!(
                "duplicate registration: {key}"
            )));
        }
        regs.push(RegistrationEntry {
            key,
            config: registration.config,
            connector: registration.connector,
            metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
        });
        Ok(())
    }

    /// Get a [`ShutdownHandle`] that signals every registration to exit.
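    ///
    /// A typical wiring, assuming tokio's `signal` feature is enabled
    /// (a sketch, not the only option):
    ///
    /// ```no_run
    /// # async fn wire(runner: strike48_connector::MultiConnectorRunner) -> strike48_connector::Result<()> {
    /// let handle = runner.shutdown_handle();
    /// tokio::spawn(async move {
    ///     // First Ctrl-C flips the shared shutdown flag for every registration.
    ///     let _ = tokio::signal::ctrl_c().await;
    ///     handle.shutdown();
    /// });
    /// runner.run().await?;
    /// # Ok(()) }
    /// ```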
    pub fn shutdown_handle(&self) -> ShutdownHandle {
        ShutdownHandle::from_flag(self.shutdown_requested.clone())
    }

    /// Snapshot of the registered keys (in insertion order).
    pub async fn registrations(&self) -> Vec<RegistrationKey> {
        self.registrations
            .read()
            .await
            .iter()
            .map(|e| e.key.clone())
            .collect()
    }

    /// Per-registration metrics snapshot.
    ///
    /// Each registration owns its own [`ConnectorMetrics`] (no global
    /// singleton), so values are independent across registrations sharing the
    /// same transport.
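    ///
    /// A minimal sketch of consuming the snapshot (assumes `ConnectorMetrics`
    /// derives `Debug`):
    ///
    /// ```no_run
    /// # async fn dump(runner: &strike48_connector::MultiConnectorRunner) {
    /// for (key, metrics) in runner.metrics_snapshot().await {
    ///     // `key` renders as `tenant.type.instance`.
    ///     println!("{key}: {metrics:?}");
    /// }
    /// # }
    /// ```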
    pub async fn metrics_snapshot(&self) -> HashMap<RegistrationKey, ConnectorMetrics> {
        let regs = self.registrations.read().await;
        let mut out = HashMap::with_capacity(regs.len());
        for entry in regs.iter() {
            let snapshot = entry.metrics.lock().await.clone();
            out.insert(entry.key.clone(), snapshot);
        }
        out
    }

    /// Run all registrations to completion or until shutdown.
    ///
    /// Returns once every registration has exited. Individual registration
    /// failures are logged but do not abort the runner unless
    /// [`MultiTransportOptions::reconnect_enabled`] is `false`.
    pub async fn run(&self) -> Result<()> {
        let logger = Logger::new("multi");

        // Take the write lock so concurrent `add()` calls observe `running`
        // atomically with the snapshot — see `add` for the partner half of
        // this protocol. We hold the lock only long enough to flip `running`
        // and clone the entries.
        let entries: Vec<RegistrationEntry> = {
            let regs = self.registrations.write().await;
            if self
                .running
                .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                .is_err()
            {
                return Err(ConnectorError::AlreadyRunning);
            }
            regs.iter()
                .map(|e| RegistrationEntry {
                    key: e.key.clone(),
                    config: e.config.clone(),
                    connector: e.connector.clone(),
                    metrics: e.metrics.clone(),
                })
                .collect()
        };

        // Pre-signalled shutdown is a clean no-op exit.
        if self.shutdown_requested.load(Ordering::SeqCst) {
            logger.debug("shutdown signalled before run; exiting");
            self.running.store(false, Ordering::SeqCst);
            return Ok(());
        }

        if entries.is_empty() {
            logger.warn("no registrations configured; run() exiting immediately");
            self.running.store(false, Ordering::SeqCst);
            return Ok(());
        }

        let result = match self.opts.transport_type {
            TransportType::Grpc => self.run_grpc(entries, logger).await,
            TransportType::WebSocket => self.run_websocket(entries, logger).await,
        };
        self.running.store(false, Ordering::SeqCst);
        result
    }

    async fn run_grpc(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
        let shared = Arc::new(SharedChannel::new(self.opts.clone()));
        let mut tasks = Vec::with_capacity(entries.len());

        for entry in entries {
            let runner = RegistrationRunner {
                key: entry.key.clone(),
                config: Arc::new(RwLock::new(entry.config)),
                connector: entry.connector,
                shared_channel: shared.clone(),
                shutdown: self.shutdown_requested.clone(),
                metrics: entry.metrics,
                opts: self.opts.clone(),
                request_semaphore: Arc::new(Semaphore::new(
                    self.opts.max_concurrent_requests.max(1),
                )),
                session_token: Arc::new(RwLock::new(None)),
            };
            tasks.push(tokio::spawn(async move { runner.run().await }));
        }

        for task in tasks {
            match task.await {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    logger.warn(&format!("registration runner exited with error: {e}"));
                }
                Err(join_err) => {
                    logger.error("registration task panicked", &join_err.to_string());
                }
            }
        }

        Ok(())
    }

    /// WebSocket has no native multiplexing (HTTP/1.1), so each logical
    /// registration uses its own `WebSocketTransport` underneath. We fan out
    /// to N independent [`crate::ConnectorRunner`]s — same public ergonomics
    /// as gRPC mode, but the transport count == registration count. This keeps
    /// the API symmetric and lets the existing single-runner reconnect /
    /// auth / metrics paths apply unchanged.
    async fn run_websocket(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
        use crate::ConnectorRunner;

        // One watch::channel for the whole runner: when the multi-runner's
        // shutdown flag is flipped we bump this channel; every child task
        // observes the change immediately via `changed().await` and signals
        // its own ConnectorRunner's shutdown handle. No per-registration
        // polling task — those leaked indefinitely on the previous
        // `tokio::spawn(loop { sleep(100ms); load(...) })` design.
        let (shutdown_tx, shutdown_rx_template) = watch::channel(false);

        // Bridge the AtomicBool flag (kept for backward compatibility with
        // ShutdownHandle) to the watch sender. Owned by `run_websocket` and
        // dropped when this function returns; that drop closes the watch
        // channel and lets every child observe completion if it hadn't
        // already.
        let multi_shutdown = self.shutdown_requested.clone();
        let bridge_tx = shutdown_tx.clone();
        let bridge = tokio::spawn(async move {
            // Coarse poll on the AtomicBool — only ONE such task per runner,
            // not one per registration. Exits cleanly when shutdown fires.
            while !multi_shutdown.load(Ordering::SeqCst) {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                if bridge_tx.is_closed() {
                    // All receivers gone (run_websocket returned via the
                    // happy path before shutdown was ever signalled).
                    return;
                }
            }
            let _ = bridge_tx.send(true);
        });

        let mut tasks = Vec::with_capacity(entries.len());

        for entry in entries {
            let mut config = entry.config.clone();
            config.transport_type = TransportType::WebSocket;
            config.host = self.opts.host.clone();
            config.use_tls = self.opts.use_tls;
            config.reconnect_enabled = self.opts.reconnect_enabled;
            config.reconnect_delay_ms = self.opts.reconnect_delay_ms;
            config.max_backoff_delay_ms = self.opts.max_backoff_delay_ms;
            config.reconnect_jitter_ms = self.opts.reconnect_jitter_ms;

            let runner = ConnectorRunner::new(config, entry.connector);
            let child_shutdown = runner.shutdown_handle();
            let mut shutdown_rx = shutdown_rx_template.clone();
            let key = entry.key.clone();

            tasks.push(tokio::spawn(async move {
                let mut runner_fut = Box::pin(runner.run());
                let res = loop {
                    tokio::select! {
                        biased;
                        // Propagate shutdown to the child runner the
                        // instant the multi-runner flag flips. After
                        // signalling we keep awaiting the runner's own
                        // exit so its drain logic runs.
                        changed = shutdown_rx.changed() => {
                            match changed {
                                Ok(()) if *shutdown_rx.borrow() => {
                                    child_shutdown.shutdown();
                                }
                                Err(_) => {
                                    // Sender dropped; treat as "no further
                                    // shutdown will arrive". Wait for the
                                    // runner to finish on its own.
                                    break runner_fut.await;
                                }
                                _ => {}
                            }
                        }
                        result = &mut runner_fut => break result,
                    }
                };
                (key, res)
            }));
        }

        // Drop the template receiver so the only live receivers are the
        // ones inside the spawned tasks. This lets the bridge task observe
        // `is_closed()` once every task has exited.
        drop(shutdown_rx_template);

        for task in tasks {
            match task.await {
                Ok((_key, Ok(()))) => {}
                Ok((key, Err(e))) => {
                    logger.warn(&format!("ws registration {key} exited with error: {e}"));
                }
                Err(join_err) => {
                    logger.error("ws registration task panicked", &join_err.to_string());
                }
            }
        }

        // Tear down the bridge task. If it exited on its own we just await
        // a finished JoinHandle (cheap). Otherwise abort + await yields a
        // JoinError we ignore — the task is short-lived and trivial.
        bridge.abort();
        let _ = bridge.await;

        // The sender had to stay alive until every child exited; now that
        // they all have, drop it explicitly.
        drop(shutdown_tx);

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ConnectorBehavior;

    struct DummyConnector;
    impl BaseConnector for DummyConnector {
        fn connector_type(&self) -> &str {
            "dummy"
        }
        fn version(&self) -> &str {
            "0.0.0"
        }
        fn execute(
            &self,
            _: serde_json::Value,
            _: Option<&str>,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>,
        > {
            Box::pin(async { Ok(serde_json::json!({})) })
        }
        fn behavior(&self) -> ConnectorBehavior {
            ConnectorBehavior::Tool
        }
    }

    fn reg(tenant: &str, ty: &str, inst: &str) -> ConnectorRegistration {
        ConnectorRegistration::new(
            ConnectorConfig {
                tenant_id: tenant.into(),
                connector_type: ty.into(),
                instance_id: inst.into(),
                ..ConnectorConfig::default()
            },
            DummyConnector,
        )
    }

    #[test]
    fn options_builder_defaults_match_default_impl() {
        let built = MultiTransportOptions::builder().build();
        let defaulted = MultiTransportOptions::default();
        assert_eq!(built.host, defaulted.host);
        assert_eq!(built.use_tls, defaulted.use_tls);
        assert_eq!(
            built.max_streams_per_channel,
            defaulted.max_streams_per_channel
        );
        assert_eq!(built.transport_type, defaulted.transport_type);
    }

    #[test]
    fn options_builder_heartbeat_roundtrip() {
        let opts = MultiTransportOptions::builder()
            .heartbeat_interval(Duration::from_secs(5))
            .heartbeat_timeout(Duration::from_secs(15))
            .build();
        assert_eq!(opts.heartbeat_interval, Some(Duration::from_secs(5)));
        assert_eq!(opts.heartbeat_timeout, Some(Duration::from_secs(15)));
    }

    #[test]
    fn options_builder_heartbeat_defaults_are_none() {
        let opts = MultiTransportOptions::builder().build();
        assert!(opts.heartbeat_interval.is_none());
        assert!(opts.heartbeat_timeout.is_none());
    }

    #[test]
    fn validate_heartbeat_pair_flags_misordered_pair() {
        // timeout < interval → invalid (warning emitted; we just check return)
        assert!(!validate_heartbeat_pair(
            Some(Duration::from_secs(30)),
            Some(Duration::from_secs(10))
        ));
        // Properly ordered pair → ok.
        assert!(validate_heartbeat_pair(
            Some(Duration::from_secs(30)),
            Some(Duration::from_secs(45))
        ));
        // Either side missing → ok (caller intends SDK default for the other).
        assert!(validate_heartbeat_pair(None, Some(Duration::from_secs(5))));
        assert!(validate_heartbeat_pair(Some(Duration::from_secs(5)), None));
    }

    #[test]
    fn options_builder_overrides_apply() {
        let opts = MultiTransportOptions::builder()
            .host("h:1")
            .use_tls(true)
            .max_streams_per_channel(42)
            .transport_type(TransportType::WebSocket)
            .build();
        assert_eq!(opts.host, "h:1");
        assert!(opts.use_tls);
        assert_eq!(opts.max_streams_per_channel, 42);
        assert_eq!(opts.transport_type, TransportType::WebSocket);
    }

    #[tokio::test]
    async fn registration_key_from_config_matches_display_form() {
        let r = reg("t", "c", "i");
        let k = RegistrationKey::from_config(&r.config);
        assert_eq!(k.to_string(), "t.c.i");
    }

    #[tokio::test]
    async fn duplicate_registrations_in_new_are_collapsed() {
        let runner = MultiConnectorRunner::new(
            MultiTransportOptions::default(),
            vec![reg("t", "c", "i"), reg("t", "c", "i"), reg("t", "c", "j")],
        );
        let keys = runner.registrations().await;
        assert_eq!(keys.len(), 2, "second duplicate should be dropped");
        assert_eq!(keys[0].instance_id, "i");
        assert_eq!(keys[1].instance_id, "j");
    }

    #[tokio::test]
    async fn add_rejects_duplicates() {
        let runner =
            MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
        let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
        assert!(matches!(err, ConnectorError::InvalidConfig(_)));
    }

    #[tokio::test]
    async fn add_after_run_starts_is_rejected() {
        let runner =
            MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
        runner.running.store(true, Ordering::SeqCst);
        let err = runner.add(reg("t", "c", "j")).await.unwrap_err();
        assert!(matches!(&err, ConnectorError::AlreadyRunning));
    }

    #[tokio::test]
    async fn add_rejects_duplicate_with_invalid_config() {
        let runner =
            MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
        let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
        assert!(matches!(&err, ConnectorError::InvalidConfig(m) if m.contains("duplicate")));
    }

    #[tokio::test]
    async fn shutdown_handle_signals_internal_flag() {
        let runner =
            MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
        let h = runner.shutdown_handle();
        assert!(!runner.shutdown_requested.load(Ordering::SeqCst));
        h.shutdown();
        assert!(runner.shutdown_requested.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn run_with_empty_registrations_is_ok() {
        let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
        runner.run().await.expect("empty run should succeed");
    }

    #[tokio::test]
    async fn run_with_pre_signalled_shutdown_is_ok() {
        let runner =
            MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
        runner.shutdown_handle().shutdown();
        runner
            .run()
            .await
            .expect("pre-signalled shutdown should be a clean Ok exit");
    }

    #[tokio::test]
    async fn run_websocket_accepts_config_and_shuts_down_cleanly() {
        // WS path fans out to N independent ConnectorRunners under the hood.
        // We can't run a real WS server here, so we verify the multi-runner
        // accepts WS config and exits cleanly when shutdown is signalled
        // before run starts.
        let opts = MultiTransportOptions::builder()
            .transport_type(TransportType::WebSocket)
            // Almost certainly closed; with reconnect enabled the runner
            // would retry forever, but shutdown is pre-signalled below.
            .host("localhost:65535")
            .build();
        let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
        runner.shutdown_handle().shutdown();
        runner
            .run()
            .await
            .expect("pre-signalled WS shutdown should be a clean Ok exit");
    }

    #[tokio::test]
    async fn run_websocket_shutdown_does_not_leak_watcher_tasks() {
        // Regression test for the per-registration watcher-task leak. With
        // reconnect disabled the child ConnectorRunner exits on its own
        // (connection refused, fail-fast). After it returns, the multi
        // run_websocket must tear its bridge task down so we can shut down
        // promptly and not leave a spinning watcher behind.
        let opts = MultiTransportOptions::builder()
            .transport_type(TransportType::WebSocket)
            // Port 1 on localhost is almost always closed, so the connect
            // is refused immediately.
            .host("127.0.0.1:1")
            .reconnect_enabled(false)
            .build();

        let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
        let shutdown = runner.shutdown_handle();

        // Trigger shutdown shortly after starting; the runner must observe
        // it via the watch channel and exit promptly even if the child
        // ConnectorRunner is still mid-handshake.
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            shutdown.shutdown();
        });

        let res = tokio::time::timeout(std::time::Duration::from_secs(5), runner.run()).await;
        assert!(
            res.is_ok(),
            "run_websocket must exit within 5s of shutdown signal"
        );
        res.unwrap()
            .expect("run_websocket should return Ok after clean shutdown");
    }

    #[tokio::test]
    async fn add_races_with_run_either_lands_or_rejects_never_silently_dropped() {
        // Regression test for the TOCTOU between `add()`'s `running.load()`
        // and `run()`'s snapshot. Either the registration must be visible
        // to the runner (driven), or `add()` must return AlreadyRunning —
        // we must never observe "added but not driven".
        //
        // We can't drive a real run() without a server, so we exploit the
        // public surface: spawn run() against an empty pre-shutdown runner
        // (which exits cleanly without touching any registrations) and
        // race add() against it. The invariant is that `add()`'s outcome
        // must be consistent with `registrations()` *as seen after `run()`
        // returns*.
        for _ in 0..200 {
            let opts = MultiTransportOptions::default();
            let runner = std::sync::Arc::new(MultiConnectorRunner::new(opts, vec![]));
            // Pre-signal shutdown so run() exits without needing a server.
            runner.shutdown_handle().shutdown();

            let r1 = runner.clone();
            let run_task = tokio::spawn(async move { r1.run().await });

            // Yield once to let run() take the write lock first sometimes.
            tokio::task::yield_now().await;

            let add_res = runner.add(reg("t", "c", "i")).await;
            run_task.await.expect("run join").expect("run ok");

            match add_res {
                Ok(()) => {
                    // Must be visible in the registration list.
                    let keys = runner.registrations().await;
                    assert!(
                        keys.iter().any(|k| k.instance_id == "i"),
                        "add() succeeded but registration is not visible"
                    );
                }
                Err(ConnectorError::AlreadyRunning) => {
                    // Must NOT be in the list — rejection means not added.
                    let keys = runner.registrations().await;
                    assert!(
                        !keys.iter().any(|k| k.instance_id == "i"),
                        "add() returned AlreadyRunning but registration was still inserted"
                    );
                }
                Err(other) => panic!("unexpected add() error: {other:?}"),
            }
        }
    }

    #[tokio::test]
    async fn run_called_twice_returns_already_running() {
        let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
        // Manually flip running so the second call hits the AlreadyRunning path
        // without us needing to race with a real run(). We can't naturally
        // observe two concurrent run()s without a running server — this test
        // covers only the precondition.
        runner.running.store(true, Ordering::SeqCst);
        let err = runner.run().await.unwrap_err();
        assert!(matches!(err, ConnectorError::AlreadyRunning));
    }
}