strike48_connector/multi/mod.rs
1//! Multi-registration connector runner.
2//!
3//! Lets one host process register `N` independently-approvable connectors
4//! against a single Matrix server while sharing the underlying transport:
5//!
6//! - **gRPC**: one TCP+TLS connection, N HTTP/2 streams (one per registration).
7//! Lazily opens additional channels when `max_streams_per_channel` is hit.
8//! - **WebSocket**: HTTP/1.1 — no native multiplexing — falls back to N
9//! independent WS connections in one process. Same public API.
10//!
11//! From the Matrix server's point of view, each registration is a normal
12//! `Connect` RPC. There are zero server-side changes.
13//!
14//! ## Reconnect behaviour
15//!
16//! The runner has two layers of reconnect, owned by different actors:
17//!
18//! - **Stream-level (per registration)**: implemented in the internal
19//! `registration_runner` module. When a stream ends (server closes, network
//! blip, heartbeat timeout) the runner sleeps with exponential backoff +
//! jitter (capped at `MultiTransportOptions::max_backoff_delay_ms`) and
22//! opens a fresh stream over the existing channel. Fully shutdown-aware
23//! — never blocks shutdown for more than one backoff slice. Each
24//! reconnect bumps the registration's `successful_reconnects` /
25//! `total_disconnects` metrics.
26//!
27//! - **Channel-level (per HTTP/2 connection, gRPC only)**: delegated to
28//! `tonic::transport::Channel`. The channel is configured with HTTP/2
29//! keepalive and dialled **eagerly** via `endpoint.connect().await` the
30//! first time a registration needs it. tonic auto-recovers on transient
31//! TCP / TLS failures by re-dialing internally on the next request; the
32//! SDK does **not** explicitly recreate channels today.
33//!
34//! In practice this means: if the underlying TCP connection breaks, all
35//! N registrations using that channel will see their streams close. The
36//! per-registration loop opens a fresh stream, which in turn forces the
37//! channel to redial. End state: connections recover transparently
38//! without involvement from this module.
39//!
40//! If a channel ever goes **permanently dead** (e.g. DNS now points to
41//! an unreachable host and the lazy redial keeps failing), every
42//! registration on that channel will spend its time backing off. The
43//! [`MultiConnectorRunner::shutdown_handle`] still works, but recovery
44//! for the current process is best-effort. A future improvement would
45//! be to track per-channel consecutive-failure counts and rebuild the
46//! channel after a threshold; tracked under
47//! `connector-sdk-rust-channel-reconnect-policy`.
48//!
49//! ## Backward compatibility
50//!
51//! This module is **purely additive**. The existing single-registration
52//! [`crate::ConnectorRunner`] API and behaviour are unchanged.
53//!
54//! ## Example
55//!
56//! ```no_run
57//! # async fn run() -> strike48_connector::Result<()> {
//! use strike48_connector::{
60//! BaseConnector, ConnectorConfig, ConnectorRegistration, MultiConnectorRunner,
61//! MultiTransportOptions, Result, TransportType,
62//! };
63//!
64//! struct Echo;
65//! impl BaseConnector for Echo {
66//! fn connector_type(&self) -> &str { "echo" }
67//! fn version(&self) -> &str { "1.0.0" }
68//! fn execute(
69//! &self,
70//! req: serde_json::Value,
71//! _: Option<&str>,
72//! ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>> {
73//! Box::pin(async move { Ok(req) })
74//! }
75//! }
76//!
77//! let opts = MultiTransportOptions::builder()
78//! .host("localhost:50061")
79//! .transport_type(TransportType::Grpc)
80//! .build();
81//!
82//! let registrations = (0..3).map(|i| {
83//! ConnectorRegistration::new(
84//! ConnectorConfig {
85//! tenant_id: "demo-org".into(),
86//! connector_type: "echo".into(),
87//! instance_id: format!("echo-{i}"),
88//! ..ConnectorConfig::default()
89//! },
90//! Echo,
91//! )
92//! }).collect::<Vec<_>>();
93//!
94//! let runner = MultiConnectorRunner::new(opts, registrations);
95//! let _shutdown = runner.shutdown_handle();
96//! runner.run().await?;
97//! # Ok(()) }
98//! ```
99
100use std::collections::HashMap;
101use std::sync::Arc;
102use std::sync::atomic::{AtomicBool, Ordering};
103use std::time::Duration;
104
105use tokio::sync::{Mutex, RwLock, Semaphore, watch};
106
107use crate::connector::{BaseConnector, ConnectorConfig, ShutdownHandle};
108use crate::error::{ConnectorError, Result};
109use crate::logger::Logger;
110use crate::transport::TransportType;
111use crate::types::ConnectorMetrics;
112
113mod registration_runner;
114mod shared_channel;
115
116use registration_runner::RegistrationRunner;
117use shared_channel::SharedChannel;
118
119// =============================================================================
120// Public types
121// =============================================================================
122
/// Transport-level configuration shared by every registration in a
/// [`MultiConnectorRunner`]. Per-registration identity (tenant, type, instance,
/// auth token, ...) lives on each [`ConnectorRegistration`]'s [`ConnectorConfig`].
///
/// Marked `#[non_exhaustive]` so future fields can be added without breaking
/// downstream struct-literal construction. Use
/// [`MultiTransportOptions::builder`] (preferred) or
/// `MultiTransportOptions { ..MultiTransportOptions::default() }`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct MultiTransportOptions {
    /// Server host:port (e.g. `localhost:50061` for gRPC, `localhost:4000` for WS).
    pub host: String,
    /// Whether to use TLS for the transport.
    pub use_tls: bool,
    /// Transport scheme.
    pub transport_type: TransportType,

    /// Soft cap on concurrent gRPC streams per channel before the runner
    /// opens an additional channel. Defaults to `80` to leave headroom under
    /// the typical Cowboy/RFC 7540 default of 100. Ignored for WebSocket.
    pub max_streams_per_channel: usize,

    /// Initial connect timeout (ms). Default 10_000.
    pub connect_timeout_ms: u64,

    /// Enable channel-level reconnect on transport failure. Default `true`.
    /// (In WebSocket mode this is forwarded into each child runner's config.)
    pub reconnect_enabled: bool,
    /// Base reconnect backoff (ms) — the exponential-backoff starting point
    /// for each registration's reconnect loop. Default 500.
    pub reconnect_delay_ms: u64,
    /// Max reconnect backoff (ms) — strict upper bound after jitter. Default 60_000.
    pub max_backoff_delay_ms: u64,
    /// Reconnect jitter (ms) added per attempt before capping. Default 500.
    pub reconnect_jitter_ms: u64,

    /// Per-registration maximum number of in-flight `ExecuteRequest`s being
    /// processed by the user `BaseConnector::execute()` callback. When the
    /// limit is reached, additional `ExecuteRequest`s queue on a semaphore
    /// until a permit is released. Mirrors
    /// [`crate::ConnectorConfig::max_concurrent_requests`] for the
    /// single-runner path. Default `100`.
    pub max_concurrent_requests: usize,

    /// Per-registration outbound heartbeat interval. `None` (default) means
    /// use the SDK default of 30s, which matches the Matrix server's
    /// session-reaper expectation.
    ///
    /// Only tune this when running against a Matrix deployment with a
    /// non-default heartbeat configuration, or for flaky-network testing.
    pub heartbeat_interval: Option<Duration>,

    /// Per-registration heartbeat watchdog timeout. If no `HeartbeatResponse`
    /// arrives within this window the runner declares the stream dead and
    /// reconnects. `None` (default) means use the SDK default of 45s.
    ///
    /// Only tune this when running against a Matrix deployment with a
    /// non-default heartbeat configuration, or for flaky-network testing.
    pub heartbeat_timeout: Option<Duration>,
}
182
183impl MultiTransportOptions {
184 /// Start a builder with sensible defaults (gRPC, plaintext, localhost:50061).
185 pub fn builder() -> MultiTransportOptionsBuilder {
186 MultiTransportOptionsBuilder::default()
187 }
188}
189
190impl Default for MultiTransportOptions {
191 fn default() -> Self {
192 Self {
193 host: "localhost:50061".to_string(),
194 use_tls: false,
195 transport_type: TransportType::Grpc,
196 max_streams_per_channel: 80,
197 connect_timeout_ms: 10_000,
198 reconnect_enabled: true,
199 reconnect_delay_ms: 500,
200 max_backoff_delay_ms: 60_000,
201 reconnect_jitter_ms: 500,
202 max_concurrent_requests: 100,
203 heartbeat_interval: None,
204 heartbeat_timeout: None,
205 }
206 }
207}
208
209/// Validate a heartbeat (interval, timeout) pair and emit a `tracing::warn!`
210/// if the timeout is shorter than the interval — that misconfigures the
211/// watchdog (the very first tick can fire after the timeout has already
212/// elapsed). Returns `false` when the pair is misconfigured.
213///
214/// Extracted as a free function so it can be unit-tested in isolation.
215fn validate_heartbeat_pair(interval: Option<Duration>, timeout: Option<Duration>) -> bool {
216 match (interval, timeout) {
217 (Some(i), Some(t)) if t < i => {
218 tracing::warn!(
219 target: "strike48_connector::heartbeat",
220 interval_ms = i.as_millis() as u64,
221 timeout_ms = t.as_millis() as u64,
222 "heartbeat_timeout < heartbeat_interval; the watchdog can fire before the first heartbeat reply has a chance to arrive"
223 );
224 false
225 }
226 _ => true,
227 }
228}
229
/// Fluent builder for [`MultiTransportOptions`].
#[derive(Debug, Clone, Default)]
pub struct MultiTransportOptionsBuilder {
    // `None` until the first setter runs; `build()` falls back to
    // `MultiTransportOptions::default()` when no setter was ever called.
    inner: Option<MultiTransportOptions>,
}
235
236impl MultiTransportOptionsBuilder {
237 fn opts(&mut self) -> &mut MultiTransportOptions {
238 self.inner
239 .get_or_insert_with(MultiTransportOptions::default)
240 }
241
242 /// Set the server host:port.
243 pub fn host(mut self, host: impl Into<String>) -> Self {
244 self.opts().host = host.into();
245 self
246 }
247
248 /// Set whether to use TLS.
249 pub fn use_tls(mut self, use_tls: bool) -> Self {
250 self.opts().use_tls = use_tls;
251 self
252 }
253
254 /// Set the transport scheme.
255 pub fn transport_type(mut self, t: TransportType) -> Self {
256 self.opts().transport_type = t;
257 self
258 }
259
260 /// Override the soft cap on concurrent gRPC streams per channel (gRPC only).
261 pub fn max_streams_per_channel(mut self, n: usize) -> Self {
262 self.opts().max_streams_per_channel = n;
263 self
264 }
265
266 /// Override the initial connect timeout (ms).
267 pub fn connect_timeout_ms(mut self, ms: u64) -> Self {
268 self.opts().connect_timeout_ms = ms;
269 self
270 }
271
272 /// Override the per-registration `max_concurrent_requests` cap.
273 pub fn max_concurrent_requests(mut self, n: usize) -> Self {
274 self.opts().max_concurrent_requests = n.max(1);
275 self
276 }
277
278 /// Enable or disable automatic reconnection on stream loss.
279 pub fn reconnect_enabled(mut self, enabled: bool) -> Self {
280 self.opts().reconnect_enabled = enabled;
281 self
282 }
283
284 /// Initial reconnect delay (ms) — the base for exponential backoff.
285 pub fn reconnect_delay_ms(mut self, ms: u64) -> Self {
286 self.opts().reconnect_delay_ms = ms;
287 self
288 }
289
290 /// Hard cap on reconnect backoff (ms). Jitter is applied first, then
291 /// the result is clamped to this value, so the cap is a strict upper
292 /// bound on the wait between attempts.
293 pub fn max_backoff_delay_ms(mut self, ms: u64) -> Self {
294 self.opts().max_backoff_delay_ms = ms;
295 self
296 }
297
298 /// Per-attempt jitter range (ms). A uniformly-random value in
299 /// `0..=ms` is added to the scaled backoff before capping.
300 pub fn reconnect_jitter_ms(mut self, ms: u64) -> Self {
301 self.opts().reconnect_jitter_ms = ms;
302 self
303 }
304
305 /// Override the per-registration heartbeat interval.
306 ///
307 /// Defaults to **30s** (the Matrix server's session-reaper expectation).
308 /// Only tune this when running against a Matrix deployment with a
309 /// non-default heartbeat configuration, or for flaky-network testing.
310 /// If the resulting `heartbeat_timeout` is shorter than the interval,
311 /// a `tracing::warn!` is emitted at builder time but the value is
312 /// still applied (the runner uses whatever is configured).
313 pub fn heartbeat_interval(mut self, d: Duration) -> Self {
314 self.opts().heartbeat_interval = Some(d);
315 let _ = validate_heartbeat_pair(
316 self.opts().heartbeat_interval,
317 self.opts().heartbeat_timeout,
318 );
319 self
320 }
321
322 /// Override the per-registration heartbeat watchdog timeout.
323 ///
324 /// Defaults to **45s** (Matrix server default + slack for one missed
325 /// reply). Only tune this when running against a Matrix deployment
326 /// with a non-default heartbeat configuration, or for flaky-network
327 /// testing. Misconfigured pairs (`timeout < interval`) emit a
328 /// `tracing::warn!` at builder time but are still applied.
329 pub fn heartbeat_timeout(mut self, d: Duration) -> Self {
330 self.opts().heartbeat_timeout = Some(d);
331 let _ = validate_heartbeat_pair(
332 self.opts().heartbeat_interval,
333 self.opts().heartbeat_timeout,
334 );
335 self
336 }
337
338 /// Build the options.
339 pub fn build(mut self) -> MultiTransportOptions {
340 self.inner.take().unwrap_or_default()
341 }
342}
343
/// One logical connector to run inside a [`MultiConnectorRunner`].
///
/// `config.host`, `config.use_tls`, and `config.transport_type` are ignored —
/// the transport is governed by [`MultiTransportOptions`]. All other fields
/// (`tenant_id`, `connector_type`, `instance_id`, `auth_token`,
/// `display_name`, `tags`, `metadata`, `max_concurrent_requests`,
/// `metrics_*`) apply to this registration only.
///
/// Marked `#[non_exhaustive]` so future fields (e.g. per-registration
/// behaviour overrides) can be added without breaking downstream
/// struct-literal construction. Prefer [`ConnectorRegistration::new`].
#[non_exhaustive]
pub struct ConnectorRegistration {
    /// Identity and per-registration settings (transport fields are ignored).
    pub config: ConnectorConfig,
    /// The user implementation this registration drives.
    pub connector: Arc<dyn BaseConnector>,
}
360
361impl ConnectorRegistration {
362 /// Build a registration from a `ConnectorConfig` and any
363 /// [`BaseConnector`] implementor — the conversion to
364 /// `Arc<dyn BaseConnector>` happens internally so callers don't have
365 /// to write `Arc::new(...) as Arc<dyn BaseConnector>` themselves.
366 ///
367 /// ```
368 /// use std::sync::Arc;
369 /// use strike48_connector::{
370 /// BaseConnector, ConnectorConfig, ConnectorRegistration, Result,
371 /// };
372 ///
373 /// struct Echo;
374 /// impl BaseConnector for Echo {
375 /// fn connector_type(&self) -> &str { "echo" }
376 /// fn version(&self) -> &str { "1.0.0" }
377 /// fn execute(
378 /// &self,
379 /// req: serde_json::Value,
380 /// _: Option<&str>,
381 /// ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>> {
382 /// Box::pin(async move { Ok(req) })
383 /// }
384 /// }
385 ///
386 /// let cfg = ConnectorConfig {
387 /// tenant_id: "demo".into(),
388 /// connector_type: "echo".into(),
389 /// instance_id: "echo-1".into(),
390 /// ..ConnectorConfig::default()
391 /// };
392 /// let reg = ConnectorRegistration::new(cfg, Echo);
393 /// assert_eq!(reg.config.connector_type, "echo");
394 /// # let _ = reg;
395 /// ```
396 pub fn new<T>(config: ConnectorConfig, connector: T) -> Self
397 where
398 T: BaseConnector + 'static,
399 {
400 Self {
401 config,
402 connector: Arc::new(connector) as Arc<dyn BaseConnector>,
403 }
404 }
405
406 /// Build a registration from a config and an already-erased
407 /// `Arc<dyn BaseConnector>`. Useful when callers have a heterogeneous
408 /// list of differently-typed connectors that they have already boxed.
409 pub fn from_arc(config: ConnectorConfig, connector: Arc<dyn BaseConnector>) -> Self {
410 Self { config, connector }
411 }
412}
413
414impl std::fmt::Debug for ConnectorRegistration {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("ConnectorRegistration")
417 .field("config", &self.config)
418 .field(
419 "connector",
420 &format_args!(
421 "Arc<dyn BaseConnector>(\"{}\")",
422 self.connector.connector_type()
423 ),
424 )
425 .finish()
426 }
427}
428
/// Stable identity for a registration. Matches the `tenant.type.instance`
/// triple the Matrix server uses to key a `ConnectorSession`.
///
/// Marked `#[non_exhaustive]` so additional identity dimensions (e.g. an
/// optional region tag) can be added without breaking downstream
/// pattern-match exhaustiveness.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
#[non_exhaustive]
pub struct RegistrationKey {
    /// Tenant (organisation) the registration belongs to.
    pub tenant_id: String,
    /// Connector type, e.g. `"echo"`.
    pub connector_type: String,
    /// Discriminator distinguishing multiple instances of the same
    /// (tenant, type) pair.
    pub instance_id: String,
}
442
443impl RegistrationKey {
444 pub fn from_config(config: &ConnectorConfig) -> Self {
445 Self {
446 tenant_id: config.tenant_id.clone(),
447 connector_type: config.connector_type.clone(),
448 instance_id: config.instance_id.clone(),
449 }
450 }
451}
452
453impl std::fmt::Display for RegistrationKey {
454 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
455 write!(
456 f,
457 "{}.{}.{}",
458 self.tenant_id, self.connector_type, self.instance_id
459 )
460 }
461}
462
463// =============================================================================
464// MultiConnectorRunner
465// =============================================================================
466
/// Internal pairing of a registration's identity, config, connector impl,
/// and its private metrics cell. `run()` snapshots the list by building
/// field-wise clones of these entries.
struct RegistrationEntry {
    /// Stable `tenant.type.instance` identity.
    key: RegistrationKey,
    /// Full per-registration config (identity + behaviour fields).
    config: ConnectorConfig,
    /// User connector implementation, shared with the spawned runner task.
    connector: Arc<dyn BaseConnector>,
    /// Per-registration metrics; `Arc` so `run()`'s snapshot and
    /// `metrics_snapshot()` observe the same counters.
    metrics: Arc<Mutex<ConnectorMetrics>>,
}
473
/// Runs `N` independently-approvable connectors over a shared transport.
///
/// See module-level docs for transport semantics. From the Matrix server's
/// point of view, each registration is a normal `Connect` RPC.
pub struct MultiConnectorRunner {
    /// Transport configuration shared by every registration.
    opts: MultiTransportOptions,
    /// Registration list. The write lock doubles as the mutual-exclusion
    /// point between `add()` and `run()`'s snapshot — see those methods.
    registrations: RwLock<Vec<RegistrationEntry>>,
    /// Flipped by [`ShutdownHandle`]s returned from `shutdown_handle()`.
    shutdown_requested: Arc<AtomicBool>,
    /// True while `run()` is driving registrations; guards double-`run()`
    /// and late `add()` calls.
    running: Arc<AtomicBool>,
}
484
485impl MultiConnectorRunner {
486 /// Create a new runner. Construction does not open any connections —
487 /// transport is established lazily by [`MultiConnectorRunner::run`].
488 ///
489 /// Duplicate registrations (same `tenant.type.instance`) are rejected by
490 /// [`MultiConnectorRunner::add`]; duplicates passed in here are reduced
491 /// to the first occurrence and logged.
492 pub fn new(opts: MultiTransportOptions, registrations: Vec<ConnectorRegistration>) -> Self {
493 let mut entries: Vec<RegistrationEntry> = Vec::with_capacity(registrations.len());
494 for ConnectorRegistration { config, connector } in registrations {
495 let key = RegistrationKey::from_config(&config);
496 if entries.iter().any(|e| e.key == key) {
497 tracing::warn!(
498 target: "strike48_connector::multi",
499 registration = %key,
500 "duplicate registration ignored"
501 );
502 continue;
503 }
504 entries.push(RegistrationEntry {
505 key,
506 config,
507 connector,
508 metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
509 });
510 }
511
512 Self {
513 opts,
514 registrations: RwLock::new(entries),
515 shutdown_requested: Arc::new(AtomicBool::new(false)),
516 running: Arc::new(AtomicBool::new(false)),
517 }
518 }
519
520 /// Append a registration. Only valid before [`MultiConnectorRunner::run`]
521 /// is called; returns an error if `run()` has already started or if the
522 /// registration's `tenant.type.instance` collides with an existing one.
523 ///
524 /// `running` is checked **inside** the same write-lock critical section
525 /// that `run()` uses to snapshot the registration list, so a concurrent
526 /// `add` either lands before the snapshot (and is driven) or returns
527 /// [`ConnectorError::AlreadyRunning`]. There is no window where the
528 /// registration is silently dropped.
529 pub async fn add(&self, registration: ConnectorRegistration) -> Result<()> {
530 let key = RegistrationKey::from_config(®istration.config);
531 let mut regs = self.registrations.write().await;
532 if self.running.load(Ordering::SeqCst) {
533 return Err(ConnectorError::AlreadyRunning);
534 }
535 if regs.iter().any(|e| e.key == key) {
536 return Err(ConnectorError::InvalidConfig(format!(
537 "duplicate registration: {key}"
538 )));
539 }
540 regs.push(RegistrationEntry {
541 key,
542 config: registration.config,
543 connector: registration.connector,
544 metrics: Arc::new(Mutex::new(ConnectorMetrics::default())),
545 });
546 Ok(())
547 }
548
549 /// Get a [`ShutdownHandle`] that signals every registration to exit.
550 pub fn shutdown_handle(&self) -> ShutdownHandle {
551 ShutdownHandle::from_flag(self.shutdown_requested.clone())
552 }
553
554 /// Snapshot of the registered keys (in insertion order).
555 pub async fn registrations(&self) -> Vec<RegistrationKey> {
556 self.registrations
557 .read()
558 .await
559 .iter()
560 .map(|e| e.key.clone())
561 .collect()
562 }
563
564 /// Per-registration metrics snapshot.
565 ///
566 /// Each registration owns its own [`ConnectorMetrics`] (no global
567 /// singleton), so values are independent across registrations sharing the
568 /// same transport.
569 pub async fn metrics_snapshot(&self) -> HashMap<RegistrationKey, ConnectorMetrics> {
570 let regs = self.registrations.read().await;
571 let mut out = HashMap::with_capacity(regs.len());
572 for entry in regs.iter() {
573 let snapshot = entry.metrics.lock().await.clone();
574 out.insert(entry.key.clone(), snapshot);
575 }
576 out
577 }
578
    /// Run all registrations to completion or until shutdown.
    ///
    /// Returns once every registration has exited. Individual registration
    /// failures are logged but do not abort the runner unless
    /// [`MultiTransportOptions::reconnect_enabled`] is `false`.
    pub async fn run(&self) -> Result<()> {
        let logger = Logger::new("multi");

        // Take the write lock so concurrent `add()` calls observe `running`
        // atomically with the snapshot — see `add` for the partner half of
        // this protocol. We hold the lock only long enough to flip `running`
        // and clone the entries.
        let entries: Vec<RegistrationEntry> = {
            let regs = self.registrations.write().await;
            // CAS rather than store: a second concurrent `run()` loses the
            // race and gets `AlreadyRunning` instead of double-driving.
            if self
                .running
                .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                .is_err()
            {
                return Err(ConnectorError::AlreadyRunning);
            }
            regs.iter()
                .map(|e| RegistrationEntry {
                    key: e.key.clone(),
                    config: e.config.clone(),
                    connector: e.connector.clone(),
                    metrics: e.metrics.clone(),
                })
                .collect()
        };

        // Pre-signalled shutdown is a clean no-op exit.
        if self.shutdown_requested.load(Ordering::SeqCst) {
            logger.debug("shutdown signalled before run; exiting");
            self.running.store(false, Ordering::SeqCst);
            return Ok(());
        }

        if entries.is_empty() {
            logger.warn("no registrations configured; run() exiting immediately");
            self.running.store(false, Ordering::SeqCst);
            return Ok(());
        }

        let result = match self.opts.transport_type {
            TransportType::Grpc => self.run_grpc(entries, logger).await,
            TransportType::WebSocket => self.run_websocket(entries, logger).await,
        };
        // `running` is reset on every exit path so the runner can be reused.
        self.running.store(false, Ordering::SeqCst);
        result
    }
630
631 async fn run_grpc(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
632 let shared = Arc::new(SharedChannel::new(self.opts.clone()));
633 let mut tasks = Vec::with_capacity(entries.len());
634
635 for entry in entries {
636 let runner = RegistrationRunner {
637 key: entry.key.clone(),
638 config: Arc::new(RwLock::new(entry.config)),
639 connector: entry.connector,
640 shared_channel: shared.clone(),
641 shutdown: self.shutdown_requested.clone(),
642 metrics: entry.metrics,
643 opts: self.opts.clone(),
644 request_semaphore: Arc::new(Semaphore::new(
645 self.opts.max_concurrent_requests.max(1),
646 )),
647 session_token: Arc::new(RwLock::new(None)),
648 };
649 tasks.push(tokio::spawn(async move { runner.run().await }));
650 }
651
652 for task in tasks {
653 match task.await {
654 Ok(Ok(())) => {}
655 Ok(Err(e)) => {
656 logger.warn(&format!("registration runner exited with error: {e}"));
657 }
658 Err(join_err) => {
659 logger.error("registration task panicked", &join_err.to_string());
660 }
661 }
662 }
663
664 Ok(())
665 }
666
    /// WebSocket has no native multiplexing (HTTP/1.1), so each logical
    /// registration uses its own `WebSocketTransport` underneath. We fan out
    /// to N independent [`crate::ConnectorRunner`]s — same public ergonomics
    /// as gRPC mode, but the transport count == registration count. This keeps
    /// the API symmetric and lets the existing single-runner reconnect /
    /// auth / metrics paths apply unchanged.
    async fn run_websocket(&self, entries: Vec<RegistrationEntry>, logger: Logger) -> Result<()> {
        use crate::ConnectorRunner;

        // One watch::channel for the whole runner: when the multi-runner's
        // shutdown flag is flipped we bump this channel; every child task
        // observes the change immediately via `changed().await` and signals
        // its own ConnectorRunner's shutdown handle. No per-registration
        // polling task — those leaked indefinitely on the previous
        // `tokio::spawn(loop { sleep(100ms); load(...) })` design.
        let (shutdown_tx, shutdown_rx_template) = watch::channel(false);

        // Bridge the AtomicBool flag (kept for backward compatibility with
        // ShutdownHandle) to the watch sender. Owned by `run_websocket` and
        // dropped when this function returns; that drop closes the watch
        // channel and lets every child observe completion if it hadn't
        // already.
        let multi_shutdown = self.shutdown_requested.clone();
        let bridge_tx = shutdown_tx.clone();
        let bridge = tokio::spawn(async move {
            // Coarse poll on the AtomicBool — only ONE such task per runner,
            // not one per registration. Exits cleanly when shutdown fires.
            while !multi_shutdown.load(Ordering::SeqCst) {
                tokio::time::sleep(std::time::Duration::from_millis(100)).await;
                if bridge_tx.is_closed() {
                    // All receivers gone (run_websocket returned via the
                    // happy path before shutdown was ever signalled).
                    return;
                }
            }
            let _ = bridge_tx.send(true);
        });

        let mut tasks = Vec::with_capacity(entries.len());

        for entry in entries {
            // Transport fields come from the shared MultiTransportOptions;
            // identity/auth fields stay per-registration on the cloned config.
            let mut config = entry.config.clone();
            config.transport_type = TransportType::WebSocket;
            config.host = self.opts.host.clone();
            config.use_tls = self.opts.use_tls;
            config.reconnect_enabled = self.opts.reconnect_enabled;
            config.reconnect_delay_ms = self.opts.reconnect_delay_ms;
            config.max_backoff_delay_ms = self.opts.max_backoff_delay_ms;
            config.reconnect_jitter_ms = self.opts.reconnect_jitter_ms;

            let runner = ConnectorRunner::new(config, entry.connector);
            let child_shutdown = runner.shutdown_handle();
            let mut shutdown_rx = shutdown_rx_template.clone();
            let key = entry.key.clone();

            tasks.push(tokio::spawn(async move {
                let mut runner_fut = Box::pin(runner.run());
                let res = loop {
                    tokio::select! {
                        // `biased` polls the shutdown branch first so a
                        // pending shutdown is never starved by runner output.
                        biased;
                        // Propagate shutdown to the child runner the
                        // instant the multi-runner flag flips. After
                        // signalling we keep awaiting the runner's own
                        // exit so its drain logic runs.
                        changed = shutdown_rx.changed() => {
                            match changed {
                                Ok(()) if *shutdown_rx.borrow() => {
                                    child_shutdown.shutdown();
                                }
                                Err(_) => {
                                    // Sender dropped; treat as "no further
                                    // shutdown will arrive". Wait for the
                                    // runner to finish on its own.
                                    break runner_fut.await;
                                }
                                _ => {}
                            }
                        }
                        result = &mut runner_fut => break result,
                    }
                };
                (key, res)
            }));
        }

        // Drop the template receiver so the only live receivers are the
        // ones inside the spawned tasks. This lets the bridge task observe
        // `is_closed()` once every task has exited.
        drop(shutdown_rx_template);

        for task in tasks {
            match task.await {
                Ok((_key, Ok(()))) => {}
                Ok((key, Err(e))) => {
                    logger.warn(&format!("ws registration {key} exited with error: {e}"));
                }
                Err(join_err) => {
                    logger.error("ws registration task panicked", &join_err.to_string());
                }
            }
        }

        // Tear down the bridge task. If it exited on its own we just await
        // a finished JoinHandle (cheap). Otherwise abort + await yields a
        // JoinError we ignore — the task is short-lived and trivial.
        bridge.abort();
        let _ = bridge.await;

        // Keep the sender alive for the whole function; dropping here.
        drop(shutdown_tx);

        Ok(())
    }
780}
781
782#[cfg(test)]
783mod tests {
784 use super::*;
785 use crate::types::ConnectorBehavior;
786
    /// Minimal `BaseConnector` stand-in for these tests: ignores its
    /// request and always resolves to an empty JSON object.
    struct DummyConnector;
    impl BaseConnector for DummyConnector {
        fn connector_type(&self) -> &str {
            "dummy"
        }
        fn version(&self) -> &str {
            "0.0.0"
        }
        fn execute(
            &self,
            _: serde_json::Value,
            _: Option<&str>,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + '_>,
        > {
            // Payload is ignored; every call succeeds with `{}`.
            Box::pin(async { Ok(serde_json::json!({})) })
        }
        fn behavior(&self) -> ConnectorBehavior {
            ConnectorBehavior::Tool
        }
    }
808
809 fn reg(tenant: &str, ty: &str, inst: &str) -> ConnectorRegistration {
810 ConnectorRegistration::new(
811 ConnectorConfig {
812 tenant_id: tenant.into(),
813 connector_type: ty.into(),
814 instance_id: inst.into(),
815 ..ConnectorConfig::default()
816 },
817 DummyConnector,
818 )
819 }
820
821 #[test]
822 fn options_builder_defaults_match_default_impl() {
823 let built = MultiTransportOptions::builder().build();
824 let defaulted = MultiTransportOptions::default();
825 assert_eq!(built.host, defaulted.host);
826 assert_eq!(built.use_tls, defaulted.use_tls);
827 assert_eq!(
828 built.max_streams_per_channel,
829 defaulted.max_streams_per_channel
830 );
831 assert_eq!(built.transport_type, defaulted.transport_type);
832 }
833
834 #[test]
835 fn options_builder_heartbeat_roundtrip() {
836 let opts = MultiTransportOptions::builder()
837 .heartbeat_interval(Duration::from_secs(5))
838 .heartbeat_timeout(Duration::from_secs(15))
839 .build();
840 assert_eq!(opts.heartbeat_interval, Some(Duration::from_secs(5)));
841 assert_eq!(opts.heartbeat_timeout, Some(Duration::from_secs(15)));
842 }
843
844 #[test]
845 fn options_builder_heartbeat_defaults_are_none() {
846 let opts = MultiTransportOptions::builder().build();
847 assert!(opts.heartbeat_interval.is_none());
848 assert!(opts.heartbeat_timeout.is_none());
849 }
850
851 #[test]
852 fn validate_heartbeat_pair_flags_misordered_pair() {
853 // timeout < interval → invalid (warning emitted; we just check return)
854 assert!(!validate_heartbeat_pair(
855 Some(Duration::from_secs(30)),
856 Some(Duration::from_secs(10))
857 ));
858 // Properly ordered pair → ok.
859 assert!(validate_heartbeat_pair(
860 Some(Duration::from_secs(30)),
861 Some(Duration::from_secs(45))
862 ));
863 // Either side missing → ok (caller intends SDK default for the other).
864 assert!(validate_heartbeat_pair(None, Some(Duration::from_secs(5))));
865 assert!(validate_heartbeat_pair(Some(Duration::from_secs(5)), None));
866 }
867
868 #[test]
869 fn options_builder_overrides_apply() {
870 let opts = MultiTransportOptions::builder()
871 .host("h:1")
872 .use_tls(true)
873 .max_streams_per_channel(42)
874 .transport_type(TransportType::WebSocket)
875 .build();
876 assert_eq!(opts.host, "h:1");
877 assert!(opts.use_tls);
878 assert_eq!(opts.max_streams_per_channel, 42);
879 assert_eq!(opts.transport_type, TransportType::WebSocket);
880 }
881
882 #[tokio::test]
883 async fn registration_key_from_config_matches_display_form() {
884 let r = reg("t", "c", "i");
885 let k = RegistrationKey::from_config(&r.config);
886 assert_eq!(k.to_string(), "t.c.i");
887 }
888
889 #[tokio::test]
890 async fn duplicate_registrations_in_new_are_collapsed() {
891 let runner = MultiConnectorRunner::new(
892 MultiTransportOptions::default(),
893 vec![reg("t", "c", "i"), reg("t", "c", "i"), reg("t", "c", "j")],
894 );
895 let keys = runner.registrations().await;
896 assert_eq!(keys.len(), 2, "second duplicate should be dropped");
897 assert_eq!(keys[0].instance_id, "i");
898 assert_eq!(keys[1].instance_id, "j");
899 }
900
901 #[tokio::test]
902 async fn add_rejects_duplicates() {
903 let runner =
904 MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
905 let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
906 assert!(matches!(err, ConnectorError::InvalidConfig(_)));
907 }
908
909 #[tokio::test]
910 async fn add_after_run_starts_is_rejected() {
911 let runner =
912 MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
913 runner.running.store(true, Ordering::SeqCst);
914 let err = runner.add(reg("t", "c", "j")).await.unwrap_err();
915 assert!(matches!(&err, ConnectorError::AlreadyRunning));
916 }
917
918 #[tokio::test]
919 async fn add_rejects_duplicate_with_invalid_config() {
920 let runner =
921 MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
922 let err = runner.add(reg("t", "c", "i")).await.unwrap_err();
923 assert!(matches!(&err, ConnectorError::InvalidConfig(m) if m.contains("duplicate")));
924 }
925
926 #[tokio::test]
927 async fn shutdown_handle_signals_internal_flag() {
928 let runner =
929 MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
930 let h = runner.shutdown_handle();
931 assert!(!runner.shutdown_requested.load(Ordering::SeqCst));
932 h.shutdown();
933 assert!(runner.shutdown_requested.load(Ordering::SeqCst));
934 }
935
936 #[tokio::test]
937 async fn run_with_empty_registrations_is_ok() {
938 let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
939 runner.run().await.expect("empty run should succeed");
940 }
941
942 #[tokio::test]
943 async fn run_with_pre_signalled_shutdown_is_ok() {
944 let runner =
945 MultiConnectorRunner::new(MultiTransportOptions::default(), vec![reg("t", "c", "i")]);
946 runner.shutdown_handle().shutdown();
947 runner
948 .run()
949 .await
950 .expect("pre-signalled shutdown should be a clean Ok exit");
951 }
952
953 #[tokio::test]
954 async fn run_websocket_accepts_config_and_shuts_down_cleanly() {
955 // WS path fans out to N independent ConnectorRunners under the hood.
956 // We can't run a real WS server here, so we verify the multi-runner
957 // accepts WS config and exits cleanly when shutdown is signalled
958 // before run starts.
959 let opts = MultiTransportOptions::builder()
960 .transport_type(TransportType::WebSocket)
961 .host("localhost:65535") // unreachable; reconnect would loop forever
962 .build();
963 let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
964 runner.shutdown_handle().shutdown();
965 runner
966 .run()
967 .await
968 .expect("pre-signalled WS shutdown should be a clean Ok exit");
969 }
970
971 #[tokio::test]
972 async fn run_websocket_shutdown_does_not_leak_watcher_tasks() {
973 // Regression test for the per-registration watcher-task leak. With
974 // reconnect disabled the child ConnectorRunner exits on its own
975 // (unreachable host, fail-fast). After it returns, the multi
976 // run_websocket must tear its bridge task down so we can shut down
977 // promptly and not leave a spinning watcher behind.
978 let opts = MultiTransportOptions::builder()
979 .transport_type(TransportType::WebSocket)
980 // Localhost:1 is RFC-reserved and reliably refuses connections.
981 .host("127.0.0.1:1")
982 .build();
983 let mut opts = opts;
984 opts.reconnect_enabled = false;
985
986 let runner = MultiConnectorRunner::new(opts, vec![reg("t", "c", "i")]);
987 let shutdown = runner.shutdown_handle();
988
989 // Trigger shutdown shortly after starting; the runner must observe
990 // it via the watch channel and exit promptly even if the child
991 // ConnectorRunner is still mid-handshake.
992 tokio::spawn(async move {
993 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
994 shutdown.shutdown();
995 });
996
997 let res = tokio::time::timeout(std::time::Duration::from_secs(5), runner.run()).await;
998 assert!(
999 res.is_ok(),
1000 "run_websocket must exit within 5s of shutdown signal"
1001 );
1002 res.unwrap()
1003 .expect("run_websocket should return Ok after clean shutdown");
1004 }
1005
1006 #[tokio::test]
1007 async fn add_races_with_run_either_lands_or_rejects_never_silently_dropped() {
1008 // Regression test for the TOCTOU between `add()`'s `running.load()`
1009 // and `run()`'s snapshot. Either the registration must be visible
1010 // to the runner (driven), or `add()` must return AlreadyRunning —
1011 // we must never observe "added but not driven".
1012 //
1013 // We can't drive a real run() without a server, so we exploit the
1014 // public surface: spawn run() against an empty pre-shutdown runner
1015 // (which exits cleanly without touching any registrations) and
1016 // race add() against it. The invariant is that `add()`'s outcome
1017 // must be consistent with `registrations()` *as seen after `run()`
1018 // returns*.
1019 for _ in 0..200 {
1020 let opts = MultiTransportOptions::default();
1021 let runner = std::sync::Arc::new(MultiConnectorRunner::new(opts, vec![]));
1022 // Pre-signal shutdown so run() exits without needing a server.
1023 runner.shutdown_handle().shutdown();
1024
1025 let r1 = runner.clone();
1026 let run_task = tokio::spawn(async move { r1.run().await });
1027
1028 // Yield once to let run() take the write lock first sometimes.
1029 tokio::task::yield_now().await;
1030
1031 let add_res = runner.add(reg("t", "c", "i")).await;
1032 run_task.await.expect("run join").expect("run ok");
1033
1034 match add_res {
1035 Ok(()) => {
1036 // Must be visible in the registration list.
1037 let keys = runner.registrations().await;
1038 assert!(
1039 keys.iter().any(|k| k.instance_id == "i"),
1040 "add() succeeded but registration is not visible"
1041 );
1042 }
1043 Err(ConnectorError::AlreadyRunning) => {
1044 // Must NOT be in the list — rejection means not added.
1045 let keys = runner.registrations().await;
1046 assert!(
1047 !keys.iter().any(|k| k.instance_id == "i"),
1048 "add() returned AlreadyRunning but registration was still inserted"
1049 );
1050 }
1051 Err(other) => panic!("unexpected add() error: {other:?}"),
1052 }
1053 }
1054 }
1055
1056 #[tokio::test]
1057 async fn run_called_twice_returns_already_running() {
1058 let runner = MultiConnectorRunner::new(MultiTransportOptions::default(), vec![]);
1059 // Manually flip running so the second call hits the AlreadyRunning path
1060 // without us needing to race with a real run(). We can't naturally
1061 // observe two concurrent run()s without a running server — this test
1062 // covers only the precondition.
1063 runner.running.store(true, Ordering::SeqCst);
1064 let err = runner.run().await.unwrap_err();
1065 assert!(matches!(err, ConnectorError::AlreadyRunning));
1066 }
1067}