1use arc_swap::ArcSwap;
4use futures_util::{SinkExt, StreamExt};
5use parking_lot::{Mutex, RwLock};
6use serde::{Deserialize, Serialize};
7use std::collections::VecDeque;
8use std::fs;
9use std::net::IpAddr;
10use std::path::PathBuf;
11use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::{Duration, Instant};
14use sysinfo::System;
15use tokio::sync::mpsc::error::TrySendError;
16use tokio::sync::{broadcast, mpsc};
17use tokio_tungstenite::tungstenite::client::IntoClientRequest;
18use tokio_tungstenite::tungstenite::Message;
19use tracing::{debug, error, info, warn};
20
21use super::blocklist::BlocklistCache;
22use super::config::HorizonConfig;
23use super::error::HorizonError;
24use super::types::{
25 AuthPayload, ConnectionState, HeartbeatPayload, HubMessage, SensorMessage, ThreatSignal,
26 PROTOCOL_VERSION,
27};
28use crate::access::{check_ssrf, SsrfCheckResult};
29use crate::config_manager::ConfigManager;
30use crate::utils::circuit_breaker::CircuitBreaker;
31use async_trait::async_trait;
32
/// Asynchronous destination for threat signals.
///
/// `HorizonClient` implements this by forwarding to its non-blocking `report_signal`.
#[async_trait]
pub trait SignalSink: Send + Sync {
    /// Submits `signal`; an `Err` carries a human-readable reason the signal was refused.
    async fn report_signal(&self, signal: ThreatSignal) -> Result<(), String>;
}
38
39#[async_trait]
40impl SignalSink for HorizonClient {
41 async fn report_signal(&self, signal: ThreatSignal) -> Result<(), String> {
42 if !self.circuit_breaker().allow_request().await {
43 return Err("Circuit breaker open".to_string());
44 }
45 self.report_signal(signal);
46 Ok(())
47 }
48}
49
/// Source of the host/process metrics embedded in every Horizon heartbeat.
///
/// The connection task calls each method once per heartbeat tick, directly on the async
/// task, so implementations should avoid blocking.
// NOTE(review): units (e.g. percent vs. fraction for cpu/memory/disk) are not established in
// this file — confirm against the Hub's heartbeat schema.
pub trait MetricsProvider: Send + Sync {
    /// Value sent as the heartbeat's `cpu` field.
    fn cpu_usage(&self) -> f64;
    /// Value sent as the heartbeat's `memory` field.
    fn memory_usage(&self) -> f64;
    /// Value sent as the heartbeat's `disk` field.
    fn disk_usage(&self) -> f64;
    /// Request count reported as `requests_last_minute`.
    fn requests_last_minute(&self) -> u64;
    /// Latency reported as `avg_latency_ms`.
    fn avg_latency_ms(&self) -> f64;
    /// Hash identifying the active configuration, reported as `config_hash`.
    fn config_hash(&self) -> String;
    /// Hash identifying the active rule set, reported as `rules_hash`.
    fn rules_hash(&self) -> String;
    /// Active connection count, or `None` when not tracked.
    fn active_connections(&self) -> Option<u32>;
}
61
/// `MetricsProvider` that reports zero/empty values; the default for a new `HorizonClient`.
pub struct NoopMetricsProvider;
64
65impl MetricsProvider for NoopMetricsProvider {
66 fn cpu_usage(&self) -> f64 {
67 0.0
68 }
69 fn memory_usage(&self) -> f64 {
70 0.0
71 }
72 fn disk_usage(&self) -> f64 {
73 0.0
74 }
75 fn requests_last_minute(&self) -> u64 {
76 0
77 }
78 fn avg_latency_ms(&self) -> f64 {
79 0.0
80 }
81 fn config_hash(&self) -> String {
82 String::new()
83 }
84 fn rules_hash(&self) -> String {
85 String::new()
86 }
87 fn active_connections(&self) -> Option<u32> {
88 None
89 }
90}
91
/// Live, lock-free counters shared (via `Arc`) between the client and its background tasks.
///
/// All updates and reads use `Ordering::Relaxed`; snapshot with `ClientStats::from`.
struct InternalStats {
    /// Signals accepted into the outgoing channel, either directly or from the retry queue.
    signals_sent: AtomicU64,
    // NOTE(review): not updated anywhere in this section; presumably incremented when the
    // Hub acknowledges signals (handle_hub_message) — confirm.
    signals_acked: AtomicU64,
    /// Signals parked in the retry queue because the channel was full (cumulative).
    signals_queued: AtomicU64,
    /// Signals discarded for any reason (full/closed queues, buffer overflow).
    signals_dropped: AtomicU64,
    /// Successful batch sends over the WebSocket.
    batches_sent: AtomicU64,
    /// Heartbeats written successfully.
    heartbeats_sent: AtomicU64,
    /// Heartbeats that failed to send.
    heartbeat_failures: AtomicU64,
    /// Current reconnect attempt number (stored, not accumulated, by the connection loop).
    reconnect_attempts: AtomicU32,
}
103
/// Serializable point-in-time snapshot of the client's `InternalStats` counters.
///
/// Field meanings mirror the identically named counters in `InternalStats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClientStats {
    pub signals_sent: u64,
    pub signals_acked: u64,
    pub signals_queued: u64,
    pub signals_dropped: u64,
    pub batches_sent: u64,
    pub heartbeats_sent: u64,
    pub heartbeat_failures: u64,
    pub reconnect_attempts: u32,
}
116
117impl From<&InternalStats> for ClientStats {
118 fn from(stats: &InternalStats) -> Self {
119 Self {
120 signals_sent: stats.signals_sent.load(Ordering::Relaxed),
121 signals_acked: stats.signals_acked.load(Ordering::Relaxed),
122 signals_queued: stats.signals_queued.load(Ordering::Relaxed),
123 signals_dropped: stats.signals_dropped.load(Ordering::Relaxed),
124 batches_sent: stats.batches_sent.load(Ordering::Relaxed),
125 heartbeats_sent: stats.heartbeats_sent.load(Ordering::Relaxed),
126 heartbeat_failures: stats.heartbeat_failures.load(Ordering::Relaxed),
127 reconnect_attempts: stats.reconnect_attempts.load(Ordering::Relaxed),
128 }
129 }
130}
131
/// Client that maintains this sensor's connection to the Horizon Hub.
///
/// Request-path queries (`is_blocked`, `report_signal`, `stats`) operate on shared state,
/// while the WebSocket itself is owned by background tasks spawned by `start`.
pub struct HorizonClient {
    /// Client configuration; cloned into the connection task on start.
    config: HorizonConfig,
    /// Connection state, shared with (and mostly written by) the connection loop.
    state: Arc<RwLock<ConnectionState>>,
    /// Local blocklist consulted by `is_ip_blocked` / `is_fingerprint_blocked`; shared with
    /// the connection task.
    blocklist: Arc<BlocklistCache>,
    /// Counters shared with the background tasks.
    stats: Arc<InternalStats>,
    /// Source of heartbeat metrics.
    metrics_provider: Arc<dyn MetricsProvider>,
    /// Sender for outgoing signals; `None` until `start` installs a channel.
    signal_tx: ArcSwap<Option<mpsc::Sender<ThreatSignal>>>,
    /// Overflow buffer for signals that found the channel full; drained by the retry task.
    signal_retry: Arc<Mutex<VecDeque<ThreatSignal>>>,
    /// Shutdown broadcaster; `None` before `start` and after `stop`.
    shutdown_tx: ArcSwap<Option<broadcast::Sender<()>>>,
    /// Tenant id reported by the Hub on successful authentication.
    tenant_id: Arc<RwLock<Option<String>>>,
    /// Capabilities reported by the Hub on successful authentication.
    capabilities: Arc<RwLock<Vec<String>>>,
    /// Optional `ConfigManager`, swappable at runtime and loaded by the connection task when
    /// handling Hub messages.
    config_manager: Arc<ArcSwap<Option<Arc<ConfigManager>>>>,
    /// Breaker fed by connection/auth/send outcomes in the connection task and consulted by
    /// the `SignalSink` implementation.
    circuit_breaker: Arc<CircuitBreaker>,
}
148
149impl HorizonClient {
150 pub fn new(config: HorizonConfig) -> Self {
152 let circuit_breaker = Arc::new(CircuitBreaker::new(5, Duration::from_secs(30)));
153 Self {
154 config,
155 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
156 blocklist: Arc::new(BlocklistCache::new()),
157 stats: Arc::new(InternalStats {
158 signals_sent: AtomicU64::new(0),
159 signals_acked: AtomicU64::new(0),
160 signals_queued: AtomicU64::new(0),
161 signals_dropped: AtomicU64::new(0),
162 batches_sent: AtomicU64::new(0),
163 heartbeats_sent: AtomicU64::new(0),
164 heartbeat_failures: AtomicU64::new(0),
165 reconnect_attempts: AtomicU32::new(0),
166 }),
167 metrics_provider: Arc::new(NoopMetricsProvider),
168 signal_tx: ArcSwap::from_pointee(None),
169 signal_retry: Arc::new(Mutex::new(VecDeque::new())),
170 shutdown_tx: ArcSwap::from_pointee(None),
171 tenant_id: Arc::new(RwLock::new(None)),
172 capabilities: Arc::new(RwLock::new(Vec::new())),
173 config_manager: Arc::new(ArcSwap::from_pointee(None)),
174 circuit_breaker,
175 }
176 }
177
178 pub fn with_metrics_provider(mut self, provider: Arc<dyn MetricsProvider>) -> Self {
180 self.metrics_provider = provider;
181 self
182 }
183
184 pub fn with_config_manager(mut self, manager: Arc<ConfigManager>) -> Self {
186 self.config_manager = Arc::new(ArcSwap::from_pointee(Some(manager)));
188 self
189 }
190
191 pub fn set_config_manager(&self, manager: Arc<ConfigManager>) {
193 self.config_manager.store(Arc::new(Some(manager)));
194 }
195
196 pub async fn start(&self) -> Result<(), HorizonError> {
198 if !self.config.enabled {
199 debug!("Horizon client disabled, skipping start");
200 return Ok(());
201 }
202
203 {
206 let mut state = self.state.write();
207 if *state != ConnectionState::Disconnected {
208 debug!("Horizon client already started (state: {:?})", *state);
209 return Ok(());
210 }
211 *state = ConnectionState::Connecting;
212 }
213
214 if let Err(e) = self.perform_start().await {
215 *self.state.write() = ConnectionState::Disconnected;
217 return Err(e);
218 }
219
220 Ok(())
221 }
222
223 async fn perform_start(&self) -> Result<(), HorizonError> {
225 self.config.validate()?;
226
227 if should_enforce_hub_url_ssrf() {
230 validate_hub_url_ssrf(&self.config.hub_url).await?;
231 }
232
233 let (signal_tx, signal_rx) = mpsc::channel::<ThreatSignal>(self.config.max_queued_signals);
235 let (shutdown_tx, _shutdown_rx) = broadcast::channel::<()>(1);
236
237 self.signal_tx.store(Arc::new(Some(signal_tx.clone())));
238 self.shutdown_tx.store(Arc::new(Some(shutdown_tx.clone())));
239
240 let params = ConnectionParams {
242 config: self.config.clone(),
243 state: Arc::clone(&self.state),
244 blocklist: Arc::clone(&self.blocklist),
245 stats: Arc::clone(&self.stats),
246 metrics_provider: Arc::clone(&self.metrics_provider),
247 tenant_id: Arc::clone(&self.tenant_id),
248 capabilities: Arc::clone(&self.capabilities),
249 config_manager: Arc::clone(&self.config_manager),
250 circuit_breaker: Arc::clone(&self.circuit_breaker),
251 };
252
253 let retry_queue = Arc::clone(&self.signal_retry);
254 let retry_stats = Arc::clone(&self.stats);
255 let retry_tx = signal_tx.clone();
256 let retry_limit = self.config.max_queued_signals;
257 let shutdown_rx_conn = shutdown_tx.subscribe();
258
259 tokio::spawn(async move {
260 connection_loop(params, signal_rx, shutdown_rx_conn).await;
261 });
262
263 let mut shutdown_rx_retry = shutdown_tx.subscribe();
264 tokio::spawn(async move {
265 let mut interval = tokio::time::interval(Duration::from_millis(250));
266 loop {
267 tokio::select! {
268 _ = interval.tick() => {
269 let mut queue = retry_queue.lock();
270 if queue.is_empty() {
271 continue;
272 }
273 while let Some(signal) = queue.pop_front() {
274 match retry_tx.try_send(signal) {
275 Ok(()) => {
276 retry_stats.signals_sent.fetch_add(1, Ordering::Relaxed);
277 }
278 Err(TrySendError::Full(signal)) => {
279 queue.push_front(signal);
280 break;
281 }
282 Err(TrySendError::Closed(_)) => {
283 retry_stats.signals_dropped.fetch_add(1, Ordering::Relaxed);
285 queue.clear();
286 break;
287 }
288 }
289 }
290 if queue.len() > retry_limit {
291 let overflow = queue.len() - retry_limit;
292 for _ in 0..overflow {
293 queue.pop_front();
294 }
295 retry_stats
297 .signals_dropped
298 .fetch_add(overflow as u64, Ordering::Relaxed);
299 }
300 }
301 _ = shutdown_rx_retry.recv() => break,
302 }
303 }
304 });
305
306 Ok(())
307 }
308
309 pub async fn stop(&self) {
311 if let Some(tx) = self.shutdown_tx.swap(Arc::new(None)).as_ref() {
312 let _ = tx.send(());
313 }
314 *self.state.write() = ConnectionState::Disconnected;
315 }
316
317 pub fn report_signal(&self, signal: ThreatSignal) {
319 if let Some(ref tx) = **self.signal_tx.load() {
320 match tx.try_send(signal) {
321 Ok(()) => {
322 self.stats.signals_sent.fetch_add(1, Ordering::Relaxed);
323 }
324 Err(TrySendError::Full(signal)) => {
325 let mut queue = self.signal_retry.lock();
326 if queue.len() >= self.config.max_queued_signals {
327 self.stats.signals_dropped.fetch_add(1, Ordering::Relaxed);
329 warn!("Signal queue full; dropping signal");
330 } else {
331 queue.push_back(signal);
332 self.stats.signals_queued.fetch_add(1, Ordering::Relaxed);
333 warn!("Signal queue full; queued for retry");
334 }
335 }
336 Err(TrySendError::Closed(_)) => {
337 self.stats.signals_dropped.fetch_add(1, Ordering::Relaxed);
339 warn!("Signal channel closed; dropping signal");
340 }
341 }
342 }
343 }
344
345 pub async fn flush_signals(&self) {
347 }
349
350 #[inline]
352 pub fn is_ip_blocked(&self, ip: &str) -> bool {
353 self.blocklist.is_ip_blocked(ip)
354 }
355
356 #[inline]
358 pub fn is_fingerprint_blocked(&self, fingerprint: &str) -> bool {
359 self.blocklist.is_fingerprint_blocked(fingerprint)
360 }
361
362 pub fn is_blocked(&self, ip: Option<&str>, fingerprint: Option<&str>) -> bool {
364 if let Some(ip) = ip {
365 if self.is_ip_blocked(ip) {
366 return true;
367 }
368 }
369 if let Some(fp) = fingerprint {
370 if self.is_fingerprint_blocked(fp) {
371 return true;
372 }
373 }
374 false
375 }
376
377 pub async fn connection_state(&self) -> ConnectionState {
379 *self.state.read()
380 }
381
382 pub async fn is_connected(&self) -> bool {
384 *self.state.read() == ConnectionState::Connected
385 }
386
387 pub fn blocklist_size(&self) -> usize {
389 self.blocklist.size()
390 }
391
392 pub fn blocklist(&self) -> &Arc<BlocklistCache> {
394 &self.blocklist
395 }
396
397 pub fn stats(&self) -> ClientStats {
399 ClientStats::from(self.stats.as_ref())
400 }
401
402 pub fn circuit_breaker(&self) -> Arc<CircuitBreaker> {
404 Arc::clone(&self.circuit_breaker)
405 }
406
407 pub async fn tenant_id(&self) -> Option<String> {
409 self.tenant_id.read().clone()
410 }
411
412 pub async fn capabilities(&self) -> Vec<String> {
414 self.capabilities.read().clone()
415 }
416}
417
/// Shared handles moved into the spawned connection loop by `HorizonClient::perform_start`.
///
/// Every `Arc` field is a clone of the corresponding `HorizonClient` field, so updates made by
/// the loop (state, tenant, capabilities, stats, blocklist) are visible through the client.
struct ConnectionParams {
    /// Copy of the client configuration taken at start.
    config: HorizonConfig,
    state: Arc<RwLock<ConnectionState>>,
    blocklist: Arc<BlocklistCache>,
    stats: Arc<InternalStats>,
    metrics_provider: Arc<dyn MetricsProvider>,
    tenant_id: Arc<RwLock<Option<String>>>,
    capabilities: Arc<RwLock<Vec<String>>>,
    /// Same swap cell as the client's, so `set_config_manager` is seen by the running loop.
    config_manager: Arc<ArcSwap<Option<Arc<ConfigManager>>>>,
    circuit_breaker: Arc<CircuitBreaker>,
}
430
431async fn connection_loop(
433 params: ConnectionParams,
434 mut signal_rx: mpsc::Receiver<ThreatSignal>,
435 mut shutdown_rx: broadcast::Receiver<()>,
436) {
437 let mut reconnect_delay = params.config.reconnect_delay_ms;
438 let mut attempt = 0u32;
439 let mut consecutive_failures = 0u32;
440 let mut circuit_open_until: Option<Instant> = None;
441 let mut pending_signals: VecDeque<ThreatSignal> = VecDeque::new();
442 let mut inflight_signals: VecDeque<ThreatSignal> = VecDeque::new();
443
444 loop {
445 if let Ok(()) | Err(broadcast::error::TryRecvError::Closed) = shutdown_rx.try_recv() {
447 info!("Horizon client shutdown requested");
448 *params.state.write() = ConnectionState::Disconnected;
449 return;
450 }
451
452 if let Some(until) = circuit_open_until {
454 let now = Instant::now();
455 if now < until {
456 *params.state.write() = ConnectionState::Degraded;
457 let remaining = until.saturating_duration_since(now);
458 tokio::select! {
459 _ = shutdown_rx.recv() => {
460 info!("Horizon client shutdown requested");
461 *params.state.write() = ConnectionState::Disconnected;
462 return;
463 }
464 _ = tokio::time::sleep(remaining) => {}
465 }
466 continue;
467 }
468
469 circuit_open_until = None;
470 info!("Horizon circuit breaker closed; resuming connection attempts");
471 }
472
473 if params.config.max_reconnect_attempts > 0
475 && attempt >= params.config.max_reconnect_attempts
476 {
477 error!("Max reconnect attempts reached");
478 *params.state.write() = ConnectionState::Error;
479 return;
480 }
481
482 *params.state.write() = ConnectionState::Connecting;
484 info!("Connecting to Hub: {}", params.config.hub_url);
485
486 match connect_and_run(
487 ¶ms,
488 &mut signal_rx,
489 &mut shutdown_rx,
490 &mut pending_signals,
491 &mut inflight_signals,
492 )
493 .await
494 {
495 ConnectionResult::Shutdown => {
496 info!("Horizon client shutdown");
497 *params.state.write() = ConnectionState::Disconnected;
498 return;
499 }
500 ConnectionResult::AuthFailed => {
501 error!("Authentication failed, not retrying");
502 *params.state.write() = ConnectionState::Error;
503 return;
504 }
505 ConnectionResult::Disconnected { had_connection } => {
506 requeue_inflight(
507 &mut pending_signals,
508 &mut inflight_signals,
509 params.config.max_queued_signals,
510 ¶ms.stats,
511 );
512 if had_connection {
513 attempt = 0;
514 reconnect_delay = params.config.reconnect_delay_ms;
515 consecutive_failures = 0;
516 }
517
518 attempt = attempt.saturating_add(1);
519 params
520 .stats
521 .reconnect_attempts
522 .store(attempt, Ordering::Relaxed);
523 consecutive_failures = consecutive_failures.saturating_add(1);
524
525 if params.config.circuit_breaker_threshold > 0
526 && consecutive_failures >= params.config.circuit_breaker_threshold
527 {
528 let cooldown =
529 Duration::from_millis(params.config.circuit_breaker_cooldown_ms.max(1));
530 circuit_open_until = Some(Instant::now() + cooldown);
531 *params.state.write() = ConnectionState::Degraded;
532 warn!(
533 "Horizon circuit breaker opened after {} consecutive failures; cooling down for {}ms",
534 consecutive_failures, cooldown.as_millis()
535 );
536 consecutive_failures = 0;
537 reconnect_delay = params.config.reconnect_delay_ms;
538 continue;
539 }
540
541 if attempt > 1 {
543 reconnect_delay = (reconnect_delay * 2).min(60_000);
544 }
545
546 let jitter_percent = fastrand::u32(0..50); let jitter_factor = 0.75 + (jitter_percent as f64 / 100.0);
549 let delay_with_jitter = (reconnect_delay as f64 * jitter_factor) as u64;
550
551 warn!(
552 "Disconnected, reconnecting in {}ms (attempt {}, base {}ms)",
553 delay_with_jitter, attempt, reconnect_delay
554 );
555 *params.state.write() = ConnectionState::Reconnecting;
556
557 tokio::time::sleep(Duration::from_millis(delay_with_jitter)).await;
558 }
559 ConnectionResult::Stopped => {
560 *params.state.write() = ConnectionState::Disconnected;
561 return;
562 }
563 }
564 }
565}
566
/// Outcome of one `connect_and_run` session, interpreted by `connection_loop`.
enum ConnectionResult {
    /// Shutdown was requested during the session; the loop exits `Disconnected`.
    Shutdown,
    /// The Hub rejected authentication; the loop exits in the `Error` state without retrying.
    AuthFailed,
    /// The session ended for any other reason; the loop reconnects with backoff.
    /// `had_connection` is `true` if authentication succeeded in this session.
    Disconnected { had_connection: bool },
    /// The signal channel closed (all senders dropped); the loop exits `Disconnected`.
    Stopped,
}
573
/// Decides whether `hub_url` must pass the SSRF check before connecting.
///
/// Setting `SYNAPSE_ALLOW_INTERNAL_HORIZON_URL` to `1`, `true` or `yes` (case-insensitive)
/// disables the check. Otherwise it is enforced only in builds without debug assertions.
fn should_enforce_hub_url_ssrf() -> bool {
    let override_enabled = match std::env::var("SYNAPSE_ALLOW_INTERNAL_HORIZON_URL") {
        Ok(raw) => {
            let lowered = raw.to_ascii_lowercase();
            lowered == "1" || lowered == "true" || lowered == "yes"
        }
        // Unset or non-UTF-8 values never disable the check.
        Err(_) => false,
    };

    !override_enabled && !cfg!(debug_assertions)
}
586
587async fn validate_hub_url_ssrf(hub_url: &str) -> Result<(), HorizonError> {
588 let url = reqwest::Url::parse(hub_url)
589 .map_err(|e| HorizonError::ConfigError(format!("Invalid hub_url '{}': {}", hub_url, e)))?;
590
591 let host = url
592 .host_str()
593 .ok_or_else(|| HorizonError::ConfigError("hub_url must include a hostname".to_string()))?;
594
595 if host.eq_ignore_ascii_case("localhost") {
596 return Err(HorizonError::ConfigError(
597 "hub_url resolves to localhost which is not allowed in production".to_string(),
598 ));
599 }
600
601 let port = url.port_or_known_default().unwrap_or(443);
602
603 if let Ok(ip) = host.parse::<IpAddr>() {
605 let result = check_ssrf(&ip);
606 if result.is_blocked() {
607 return Err(HorizonError::ConfigError(format!(
608 "hub_url targets blocked address {} ({:?})",
609 ip, result
610 )));
611 }
612 return Ok(());
613 }
614
615 let mut any = false;
617 let addrs = tokio::net::lookup_host((host, port)).await.map_err(|e| {
618 HorizonError::ConfigError(format!("Failed to resolve hub_url host '{}': {}", host, e))
619 })?;
620
621 for addr in addrs {
622 any = true;
623 let ip = addr.ip();
624 let result: SsrfCheckResult = check_ssrf(&ip);
625 if result.is_blocked() {
626 return Err(HorizonError::ConfigError(format!(
627 "hub_url resolves to blocked address {} ({:?})",
628 ip, result
629 )));
630 }
631 }
632
633 if !any {
634 return Err(HorizonError::ConfigError(format!(
635 "hub_url host '{}' did not resolve to any addresses",
636 host
637 )));
638 }
639
640 Ok(())
641}
642
643fn stash_pending(
645 pending: &mut VecDeque<ThreatSignal>,
646 batch: &mut Vec<ThreatSignal>,
647 max_size: usize,
648 stats: &Arc<InternalStats>,
649) {
650 if batch.is_empty() {
651 return;
652 }
653
654 let to_add = batch.len();
655 let current_size = pending.len();
656
657 if current_size + to_add > max_size {
658 let overflow = (current_size + to_add).saturating_sub(max_size);
659 let drop_from_pending = overflow.min(current_size);
661 if drop_from_pending > 0 {
662 for _ in 0..drop_from_pending {
663 pending.pop_front();
664 }
665 stats
667 .signals_dropped
668 .fetch_add(drop_from_pending as u64, Ordering::Relaxed);
669 }
670
671 let drop_from_batch = overflow.saturating_sub(drop_from_pending);
673 if drop_from_batch > 0 {
674 batch.drain(0..drop_from_batch);
675 stats
677 .signals_dropped
678 .fetch_add(drop_from_batch as u64, Ordering::Relaxed);
679 }
680
681 warn!(
682 "Signal buffer overflow ({} > {}); dropped {} oldest signals (FIFO)",
683 current_size + to_add,
684 max_size,
685 overflow
686 );
687 }
688
689 pending.extend(batch.drain(..));
690}
691
692fn requeue_inflight(
694 pending: &mut VecDeque<ThreatSignal>,
695 inflight: &mut VecDeque<ThreatSignal>,
696 max_size: usize,
697 stats: &Arc<InternalStats>,
698) {
699 if inflight.is_empty() {
700 return;
701 }
702
703 let to_add = inflight.len();
704 let current_size = pending.len();
705
706 if current_size + to_add > max_size {
707 let overflow = (current_size + to_add).saturating_sub(max_size);
708 let drop_count = overflow.min(current_size);
710 for _ in 0..drop_count {
711 pending.pop_front();
712 }
713 stats
715 .signals_dropped
716 .fetch_add(drop_count as u64, Ordering::Relaxed);
717
718 let remaining_overflow = overflow.saturating_sub(drop_count);
720 if remaining_overflow > 0 {
721 for _ in 0..remaining_overflow {
722 inflight.pop_front();
723 }
724 stats
726 .signals_dropped
727 .fetch_add(remaining_overflow as u64, Ordering::Relaxed);
728 }
729
730 warn!(
731 "Signal buffer overflow during requeue ({} > {}); dropped {} oldest signals (FIFO)",
732 current_size + to_add,
733 max_size,
734 overflow
735 );
736 }
737
738 let mut combined = VecDeque::with_capacity(pending.len() + inflight.len());
740 combined.extend(inflight.drain(..));
741 combined.extend(pending.drain(..));
742 *pending = combined;
743}
744
745async fn connect_and_run(
746 params: &ConnectionParams,
747 signal_rx: &mut mpsc::Receiver<ThreatSignal>,
748 shutdown_rx: &mut broadcast::Receiver<()>,
749 pending_signals: &mut VecDeque<ThreatSignal>,
750 inflight_signals: &mut VecDeque<ThreatSignal>,
751) -> ConnectionResult {
752 let mut had_connection = false;
753
754 let mut request = match params.config.hub_url.clone().into_client_request() {
756 Ok(req) => req,
757 Err(e) => {
758 error!("Failed to build WebSocket request: {}", e);
759 return ConnectionResult::Disconnected { had_connection };
760 }
761 };
762
763 if let Ok(value) = http::HeaderValue::from_str(&format!("Bearer {}", params.config.api_key)) {
764 request
765 .headers_mut()
766 .insert(http::header::AUTHORIZATION, value);
767 }
768
769 let ws_stream = match tokio_tungstenite::connect_async(request).await {
770 Ok((stream, _)) => stream,
771 Err(e) => {
772 error!("WebSocket connection failed: {}", e);
773 params.circuit_breaker.record_failure().await;
774 return ConnectionResult::Disconnected { had_connection };
775 }
776 };
777
778 let (mut ws_tx, mut ws_rx) = ws_stream.split();
779
780 *params.state.write() = ConnectionState::Authenticating;
782 let auth_msg = SensorMessage::Auth {
783 payload: AuthPayload {
784 api_key: params.config.api_key.clone(),
785 sensor_id: params.config.sensor_id.clone(),
786 sensor_name: params.config.sensor_name.clone(),
787 version: params.config.version.clone(),
788 protocol_version: Some(PROTOCOL_VERSION.to_string()),
789 },
790 };
791
792 if let Err(e) = ws_tx
793 .send(Message::Text(auth_msg.to_json().unwrap().into()))
794 .await
795 {
796 error!("Failed to send auth: {}", e);
797 return ConnectionResult::Disconnected { had_connection };
798 }
799
800 let auth_timeout = tokio::time::timeout(Duration::from_secs(10), ws_rx.next()).await;
801
802 match auth_timeout {
803 Ok(Some(Ok(Message::Text(text)))) => match HubMessage::from_json(&text) {
804 Ok(HubMessage::AuthSuccess {
805 sensor_id: _,
806 tenant_id: tid,
807 capabilities: caps,
808 protocol_version: negotiated_version,
809 }) => {
810 if let Some(ref pv) = negotiated_version {
811 info!("Authenticated with Hub (tenant: {}, protocol: {})", tid, pv);
812 } else {
813 info!("Authenticated with Hub (tenant: {})", tid);
814 }
815 params.circuit_breaker.record_success().await;
816 *params.tenant_id.write() = Some(tid);
817 *params.capabilities.write() = caps;
818 *params.state.write() = ConnectionState::Connected;
819 had_connection = true;
820
821 let _ = ws_tx
822 .send(Message::Text(
823 SensorMessage::BlocklistSync.to_json().unwrap().into(),
824 ))
825 .await;
826 }
827 Ok(HubMessage::AuthFailed { error }) => {
828 error!("Auth failed: {}", error);
829 return ConnectionResult::AuthFailed;
830 }
831 _ => {
832 error!("Unexpected auth response");
833 params.circuit_breaker.record_failure().await;
834 return ConnectionResult::Disconnected { had_connection };
835 }
836 },
837 _ => {
838 error!("Auth timeout or error");
839 params.circuit_breaker.record_failure().await;
840 return ConnectionResult::Disconnected { had_connection };
841 }
842 }
843
844 let mut heartbeat_interval =
846 tokio::time::interval(Duration::from_millis(params.config.heartbeat_interval_ms));
847 let mut signal_batch: Vec<ThreatSignal> = Vec::with_capacity(params.config.signal_batch_size);
848 let mut batch_timer =
849 tokio::time::interval(Duration::from_millis(params.config.signal_batch_delay_ms));
850
851 if !pending_signals.is_empty() {
852 signal_batch.extend(pending_signals.drain(..));
853 if let Err(e) = send_batch(
854 &mut ws_tx,
855 &mut signal_batch,
856 inflight_signals,
857 ¶ms.stats,
858 )
859 .await
860 {
861 error!("Failed to send buffered signals: {}", e);
862 params.circuit_breaker.record_failure().await;
863 stash_pending(
864 pending_signals,
865 &mut signal_batch,
866 params.config.max_queued_signals,
867 ¶ms.stats,
868 );
869 return ConnectionResult::Disconnected { had_connection };
870 }
871 }
872
873 loop {
874 tokio::select! {
875 _ = shutdown_rx.recv() => {
876 info!("Shutdown received");
877 let _ = ws_tx.close().await;
878 return ConnectionResult::Shutdown;
879 }
880
881 signal = signal_rx.recv() => {
882 match signal {
883 Some(sig) => {
884 signal_batch.push(sig);
885 if signal_batch.len() >= params.config.signal_batch_size {
886 if let Err(e) = send_batch(&mut ws_tx, &mut signal_batch, inflight_signals, ¶ms.stats).await {
887 error!("Failed to send batch: {}", e);
888 params.circuit_breaker.record_failure().await;
889 stash_pending(pending_signals, &mut signal_batch, params.config.max_queued_signals, ¶ms.stats);
890 return ConnectionResult::Disconnected { had_connection };
891 }
892 params.circuit_breaker.record_success().await;
893 }
894 }
895 None => {
896 return ConnectionResult::Stopped;
897 }
898 }
899 }
900
901 _ = batch_timer.tick() => {
902 if !signal_batch.is_empty() {
903 if let Err(e) = send_batch(&mut ws_tx, &mut signal_batch, inflight_signals, ¶ms.stats).await {
904 error!("Failed to send batch: {}", e);
905 params.circuit_breaker.record_failure().await;
906 stash_pending(pending_signals, &mut signal_batch, params.config.max_queued_signals, ¶ms.stats);
907 return ConnectionResult::Disconnected { had_connection };
908 }
909 params.circuit_breaker.record_success().await;
910 }
911 }
912
913 msg = ws_rx.next() => {
914 match msg {
915 Some(Ok(Message::Text(text))) => {
916 if let Ok(hub_msg) = HubMessage::from_json(&text) {
917 let cm = params.config_manager.load();
918 handle_hub_message(
919 hub_msg,
920 ¶ms.blocklist,
921 ¶ms.stats,
922 ¶ms.metrics_provider,
923 &**cm,
924 inflight_signals,
925 &mut ws_tx,
926 )
927 .await;
928 }
929 }
930 Some(Ok(Message::Ping(data))) => {
931 let _ = ws_tx.send(Message::Pong(data)).await;
932 }
933 Some(Ok(Message::Close(_))) | None => {
934 warn!("WebSocket closed");
935 stash_pending(pending_signals, &mut signal_batch, params.config.max_queued_signals, ¶ms.stats);
936 return ConnectionResult::Disconnected { had_connection };
937 }
938 Some(Err(e)) => {
939 error!("WebSocket error: {}", e);
940 stash_pending(pending_signals, &mut signal_batch, params.config.max_queued_signals, ¶ms.stats);
941 return ConnectionResult::Disconnected { had_connection };
942 }
943 _ => {}
944 }
945 }
946
947 _ = heartbeat_interval.tick() => {
948 let payload = HeartbeatPayload {
949 timestamp: chrono::Utc::now().timestamp_millis(),
950 status: "healthy".to_string(),
951 cpu: params.metrics_provider.cpu_usage(),
952 memory: params.metrics_provider.memory_usage(),
953 disk: params.metrics_provider.disk_usage(),
954 requests_last_minute: params.metrics_provider.requests_last_minute(),
955 avg_latency_ms: params.metrics_provider.avg_latency_ms(),
956 config_hash: params.metrics_provider.config_hash(),
957 rules_hash: params.metrics_provider.rules_hash(),
958 active_connections: params.metrics_provider.active_connections(),
959 blocklist_size: Some(params.blocklist.size()),
960 };
961
962 let msg = SensorMessage::Heartbeat { payload };
963 if let Err(e) = ws_tx.send(Message::Text(msg.to_json().unwrap().into())).await {
964 warn!("Failed to send heartbeat: {}", e);
965 params.stats.heartbeat_failures.fetch_add(1, Ordering::Relaxed);
966 } else {
967 params.stats.heartbeats_sent.fetch_add(1, Ordering::Relaxed);
968 debug!("Sent heartbeat");
969 }
970 }
971 }
972 }
973}
974
975async fn send_batch<S>(
976 ws_tx: &mut futures_util::stream::SplitSink<S, Message>,
977 batch: &mut Vec<ThreatSignal>,
978 inflight: &mut VecDeque<ThreatSignal>,
979 stats: &Arc<InternalStats>,
980) -> Result<(), HorizonError>
981where
982 S: futures_util::Sink<Message> + Unpin,
983 <S as futures_util::Sink<Message>>::Error: std::fmt::Display,
984{
985 if batch.is_empty() {
986 return Ok(());
987 }
988
989 let signals: Vec<ThreatSignal> = std::mem::take(batch);
990 let count = signals.len();
991
992 if count == 0 {
993 return Ok(());
994 }
995
996 for signal in &signals {
997 inflight.push_back(signal.clone());
998 }
999
1000 let msg = if count == 1 {
1001 SensorMessage::Signal {
1002 payload: signals.into_iter().next().unwrap(),
1003 }
1004 } else {
1005 SensorMessage::SignalBatch { payload: signals }
1006 };
1007
1008 ws_tx
1009 .send(Message::Text(msg.to_json()?))
1010 .await
1011 .map_err(|e| HorizonError::SendFailed(e.to_string()))?;
1012
1013 stats.batches_sent.fetch_add(1, Ordering::Relaxed);
1014 debug!("Sent batch of {} signals", count);
1015
1016 Ok(())
1017}
1018
1019use super::types::CommandAckPayload;
1020
1021async fn send_command_ack<S>(
1022 ws_tx: &mut futures_util::stream::SplitSink<S, Message>,
1023 command_id: String,
1024 result: Result<Option<serde_json::Value>, String>,
1025) where
1026 S: futures_util::Sink<Message> + Unpin,
1027 <S as futures_util::Sink<Message>>::Error: std::fmt::Display,
1028{
1029 let (success, message, result_value) = match result {
1030 Ok(result_value) => (true, None, result_value),
1031 Err(message) => (false, Some(message), None),
1032 };
1033
1034 let ack = SensorMessage::CommandAck {
1035 payload: CommandAckPayload {
1036 command_id,
1037 success,
1038 message,
1039 result: result_value,
1040 },
1041 };
1042
1043 if let Ok(json) = ack.to_json() {
1044 if let Err(e) = ws_tx.send(Message::Text(json)).await {
1045 error!("Failed to send command ack: {}", e);
1046 }
1047 }
1048}
1049
/// Makes `input` safe to embed in a file name.
///
/// ASCII letters, digits, `-` and `_` are kept; every other character (including path
/// separators, dots and non-ASCII characters) becomes a single `_`.
fn sanitize_filename_component(input: &str) -> String {
    let mut cleaned = String::with_capacity(input.len());
    for ch in input.chars() {
        let allowed = ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');
        cleaned.push(if allowed { ch } else { '_' });
    }
    cleaned
}
1062
1063fn stage_update_payload(
1064 command_id: &str,
1065 payload: &serde_json::Value,
1066) -> Result<serde_json::Value, String> {
1067 let update_dir =
1068 std::env::var("SYNAPSE_UPDATE_DIR").unwrap_or_else(|_| "/tmp/synapse-updates".to_string());
1069
1070 fs::create_dir_all(&update_dir)
1071 .map_err(|e| format!("Failed to create update dir {}: {}", update_dir, e))?;
1072
1073 let safe_id = sanitize_filename_component(command_id);
1074 let file_name = format!(
1075 "update-{}-{}.json",
1076 chrono::Utc::now().format("%Y%m%d-%H%M%S"),
1077 safe_id
1078 );
1079 let path = PathBuf::from(&update_dir).join(file_name);
1080
1081 let body = serde_json::to_string_pretty(payload)
1082 .map_err(|e| format!("Failed to serialize update payload: {}", e))?;
1083 fs::write(&path, body.as_bytes())
1084 .map_err(|e| format!("Failed to stage update payload: {}", e))?;
1085
1086 Ok(serde_json::json!({
1087 "staged": true,
1088 "path": path.to_string_lossy(),
1089 "bytes": body.len(),
1090 "update_dir": update_dir,
1091 "payload_version": payload.get("version").and_then(|value| value.as_str()),
1092 }))
1093}
1094
1095fn soft_restart(config_manager: &Option<Arc<ConfigManager>>) -> Result<serde_json::Value, String> {
1096 let manager = config_manager
1097 .as_ref()
1098 .ok_or_else(|| "ConfigManager not available".to_string())?;
1099
1100 let config = manager.get_full_config();
1101 let mutation = manager
1102 .update_full_config(config)
1103 .map_err(|e| e.to_string())?;
1104
1105 let rules = manager.list_rules();
1106 let rules_count = rules.len();
1107 let rules_loaded = manager
1108 .replace_rules(rules, None)
1109 .map_err(|e| e.to_string())?;
1110
1111 Ok(serde_json::json!({
1112 "restart_mode": "soft",
1113 "config_reloaded": true,
1114 "rules_loaded": rules_loaded,
1115 "rules_count": rules_count,
1116 "applied": mutation.applied,
1117 "persisted": mutation.persisted,
1118 "rebuild_required": mutation.rebuild_required,
1119 "warnings": mutation.warnings,
1120 }))
1121}
1122
1123fn collect_diagnostics(
1124 metrics_provider: &Arc<dyn MetricsProvider>,
1125 config_manager: &Option<Arc<ConfigManager>>,
1126 blocklist: &Arc<BlocklistCache>,
1127 stats: &Arc<InternalStats>,
1128 payload: &serde_json::Value,
1129) -> serde_json::Value {
1130 let include_config = payload
1131 .get("include_config")
1132 .and_then(|value| value.as_bool())
1133 .unwrap_or(false);
1134 let include_sites = payload
1135 .get("include_sites")
1136 .and_then(|value| value.as_bool())
1137 .unwrap_or(true);
1138 let include_rules = payload
1139 .get("include_rules")
1140 .and_then(|value| value.as_bool())
1141 .unwrap_or(false);
1142
1143 let mut sys = System::new_all();
1144 sys.refresh_all();
1145
1146 let system_info = serde_json::json!({
1147 "hostname": System::host_name().unwrap_or_default(),
1148 "os": System::name().unwrap_or_default(),
1149 "os_version": System::os_version().unwrap_or_default(),
1150 "kernel_version": System::kernel_version().unwrap_or_default(),
1151 "cpu_count": sys.cpus().len(),
1152 "total_memory_mb": sys.total_memory() / 1024 / 1024,
1153 "used_memory_mb": sys.used_memory() / 1024 / 1024,
1154 "uptime_secs": System::uptime(),
1155 });
1156
1157 let mut config_summary = serde_json::Map::new();
1158 let mut rules_summary = serde_json::Map::new();
1159
1160 if let Some(manager) = config_manager {
1161 let config = manager.get_full_config();
1162 let site_count = config.sites.len();
1163 let tls_sites = config
1164 .sites
1165 .iter()
1166 .filter(|site| site.tls.is_some())
1167 .count();
1168 let waf_sites = config
1169 .sites
1170 .iter()
1171 .filter(|site| site.waf.as_ref().map(|waf| waf.enabled).unwrap_or(false))
1172 .count();
1173 config_summary.insert("available".to_string(), serde_json::json!(true));
1174 config_summary.insert("site_count".to_string(), serde_json::json!(site_count));
1175 config_summary.insert("tls_site_count".to_string(), serde_json::json!(tls_sites));
1176 config_summary.insert(
1177 "waf_enabled_sites".to_string(),
1178 serde_json::json!(waf_sites),
1179 );
1180 if include_sites {
1181 let site_hostnames = config
1182 .sites
1183 .iter()
1184 .map(|site| site.hostname.clone())
1185 .collect::<Vec<_>>();
1186 config_summary.insert(
1187 "site_hostnames".to_string(),
1188 serde_json::json!(site_hostnames),
1189 );
1190 }
1191 if include_config {
1192 if let Ok(value) = serde_json::to_value(&config) {
1193 config_summary.insert("config".to_string(), value);
1194 }
1195 }
1196
1197 let rules = manager.list_rules();
1198 rules_summary.insert("count".to_string(), serde_json::json!(rules.len()));
1199 if include_rules {
1200 if let Ok(value) = serde_json::to_value(&rules) {
1201 rules_summary.insert("rules".to_string(), value);
1202 }
1203 }
1204 } else {
1205 config_summary.insert("available".to_string(), serde_json::json!(false));
1206 rules_summary.insert("count".to_string(), serde_json::json!(0));
1207 }
1208
1209 let stats_value = serde_json::to_value(ClientStats::from(stats.as_ref()))
1210 .unwrap_or_else(|_| serde_json::json!({}));
1211
1212 serde_json::json!({
1213 "generated_at": chrono::Utc::now().to_rfc3339(),
1214 "version": env!("CARGO_PKG_VERSION"),
1215 "system": system_info,
1216 "metrics": {
1217 "cpu": metrics_provider.cpu_usage(),
1218 "memory": metrics_provider.memory_usage(),
1219 "disk": metrics_provider.disk_usage(),
1220 "requests_last_minute": metrics_provider.requests_last_minute(),
1221 "avg_latency_ms": metrics_provider.avg_latency_ms(),
1222 "active_connections": metrics_provider.active_connections(),
1223 "config_hash": metrics_provider.config_hash(),
1224 "rules_hash": metrics_provider.rules_hash(),
1225 },
1226 "blocklist": { "size": blocklist.size() },
1227 "client_stats": stats_value,
1228 "config": serde_json::Value::Object(config_summary),
1229 "rules": serde_json::Value::Object(rules_summary),
1230 })
1231}
1232
/// Dispatches a single message received from the Horizon hub.
///
/// * Acknowledgements (`SignalAck`, `BatchAck`) bump `stats.signals_acked` and retire
///   entries from the front of `inflight_signals`.
/// * Blocklist messages (`BlocklistSnapshot`, `BlocklistUpdate`) are applied to
///   `blocklist`.
/// * Commands (`PushConfig`, `PushRules`, `Restart`, `CollectDiagnostics`, `Update`,
///   `SyncBlocklist`, and the legacy `RulesUpdate`) are executed and always answered
///   with `send_command_ack` over `ws_tx`, carrying either a JSON result or an error
///   string.
/// * Tunnel messages are rejected: this sensor does not implement tunnels.
///
/// # Parameters
/// * `msg` - the decoded hub message (consumed).
/// * `blocklist` - local blocklist cache updated by snapshot/update messages.
/// * `stats` - client counters; `signals_acked` is updated here, and the whole struct
///   is read for diagnostics reports.
/// * `metrics_provider` - source of the metrics included in diagnostics reports.
/// * `config_manager` - when `None`, config/rules/restart commands are acked with
///   `"ConfigManager not available"`.
/// * `inflight_signals` - signals sent but not yet acknowledged; acks pop from the
///   front, so presumably oldest first — confirm against the sender side.
/// * `ws_tx` - write half of the hub WebSocket, used for acks and direct replies.
///
/// Returns `()`: failures are logged and, for commands, reported back in the ack.
async fn handle_hub_message<S>(
    msg: HubMessage,
    blocklist: &Arc<BlocklistCache>,
    stats: &Arc<InternalStats>,
    metrics_provider: &Arc<dyn MetricsProvider>,
    config_manager: &Option<Arc<ConfigManager>>,
    inflight_signals: &mut VecDeque<ThreatSignal>,
    ws_tx: &mut futures_util::stream::SplitSink<S, Message>,
) where
    S: futures_util::Sink<Message> + Unpin,
    <S as futures_util::Sink<Message>>::Error: std::fmt::Display,
{
    match msg {
        // Acks are matched positionally: the ack's `sequence_id` is ignored and the front
        // of the inflight queue is retired. NOTE(review): this assumes the hub acks
        // signals in the order they were sent — confirm against the hub protocol.
        HubMessage::SignalAck { sequence_id: _ } => {
            stats.signals_acked.fetch_add(1, Ordering::Relaxed);
            if inflight_signals.pop_front().is_none() {
                warn!("Received signal ack but no inflight signals were tracked");
            }
        }
        HubMessage::BatchAck {
            count,
            sequence_id: _,
        } => {
            // The counter is bumped by the full `count`, even if fewer signals were
            // actually tracked as inflight.
            stats
                .signals_acked
                .fetch_add(count as u64, Ordering::Relaxed);
            debug!("Batch of {} signals acknowledged", count);
            // Retire up to `count` signals from the front of the queue; stop (and warn)
            // early if the queue runs dry.
            let mut remaining = count as usize;
            while remaining > 0 {
                if inflight_signals.pop_front().is_none() {
                    warn!(
                        "Received batch ack for {} signals but inflight queue was empty",
                        count
                    );
                    break;
                }
                remaining -= 1;
            }
        }
        // No reply is sent from here. NOTE(review): presumably liveness/pong handling
        // happens elsewhere (e.g. the WebSocket layer) — confirm.
        HubMessage::Ping { timestamp: _ } => {}
        // Full replacement of the local blocklist.
        HubMessage::BlocklistSnapshot {
            entries,
            sequence_id,
        } => {
            info!(
                "Received blocklist snapshot: {} entries (seq: {})",
                entries.len(),
                sequence_id
            );
            blocklist.load_snapshot(entries, sequence_id);
        }
        // Incremental changes to the local blocklist.
        HubMessage::BlocklistUpdate {
            updates,
            sequence_id,
        } => {
            debug!(
                "Received blocklist update: {} changes (seq: {})",
                updates.len(),
                sequence_id
            );
            blocklist.apply_updates(updates, sequence_id);
        }
        HubMessage::Error { error, code } => {
            warn!("Hub error: {} (code: {:?})", error, code);
        }
        // Legacy message: only logged; the `config` payload is deliberately ignored and
        // no ack is sent. `PushConfig` is the path that actually applies config.
        HubMessage::ConfigUpdate { config: _, version } => {
            info!(
                "Received config update (legacy direct) version: {}",
                version
            );
        }
        // Replace the full configuration with `payload.config`, parsed as
        // `crate::config::ConfigFile`. A payload that only carries an `action` is
        // rejected as unsupported. The ack is sent on every path.
        HubMessage::PushConfig {
            command_id,
            payload,
        } => {
            let version = payload
                .version
                .clone()
                .unwrap_or_else(|| "unknown".to_string());
            info!(
                "Received PushConfig command (id: {}, version: {})",
                command_id, version
            );

            let result = if let Some(manager) = config_manager {
                if let Some(config_value) = payload.config.as_ref() {
                    match serde_json::from_value::<crate::config::ConfigFile>(config_value.clone())
                    {
                        Ok(new_config) => match manager.update_full_config(new_config) {
                            Ok(result) => {
                                info!("Applied config update v{}", version);
                                Ok(Some(serde_json::json!({
                                    "applied": result.applied,
                                    "persisted": result.persisted,
                                    "rebuild_required": result.rebuild_required,
                                    "warnings": result.warnings,
                                })))
                            }
                            Err(e) => {
                                error!("Failed to apply config update v{}: {}", version, e);
                                Err(e.to_string())
                            }
                        },
                        Err(e) => {
                            error!("Failed to parse config update v{}: {}", version, e);
                            Err(e.to_string())
                        }
                    }
                } else if let Some(action) = payload.action.as_deref() {
                    Err(format!(
                        "push_config action '{}' not supported via hub",
                        action
                    ))
                } else {
                    Err("push_config payload missing config".to_string())
                }
            } else {
                warn!("Config update received but no ConfigManager available");
                Err("ConfigManager not available".to_string())
            };

            send_command_ack(ws_tx, command_id, result).await;
        }
        // Accepts either `{ "rules": [...], "hash": "..." }` or a bare JSON array of
        // rules. The rules are re-serialized to bytes and handed, with the optional
        // hash, to `update_waf_rules`.
        HubMessage::PushRules {
            command_id,
            payload,
        } => {
            info!("Received PushRules command (id: {})", command_id);

            let result = if let Some(manager) = config_manager {
                let rules_value = payload.get("rules").unwrap_or(&payload);
                let rules_hash = payload.get("hash").and_then(|value| value.as_str());
                if !rules_value.is_array() {
                    Err("push_rules payload missing rules array".to_string())
                } else {
                    match serde_json::to_vec(rules_value) {
                        Ok(rules_bytes) => match manager.update_waf_rules(&rules_bytes, rules_hash)
                        {
                            Ok(count) => {
                                info!("Applied push_rules: {} rules loaded", count);
                                Ok(Some(serde_json::json!({ "rules_loaded": count })))
                            }
                            Err(e) => {
                                error!("Failed to apply push_rules: {}", e);
                                Err(e.to_string())
                            }
                        },
                        Err(e) => {
                            error!("Failed to serialize push_rules payload: {}", e);
                            Err(e.to_string())
                        }
                    }
                }
            } else {
                warn!("PushRules received but no ConfigManager available");
                Err("ConfigManager not available".to_string())
            };

            send_command_ack(ws_tx, command_id, result).await;
        }
        // Every restart is performed as a soft restart (`soft_restart`), whatever
        // `payload.mode` says. The requested mode (default "soft") is echoed back as
        // `requested_mode` so the hub can see what was asked for versus what was done.
        HubMessage::Restart {
            command_id,
            payload,
        } => {
            info!("Received Restart command (id: {})", command_id);
            let requested_mode = payload
                .get("mode")
                .and_then(|value| value.as_str())
                .unwrap_or("soft");

            let result = match soft_restart(config_manager) {
                Ok(mut value) => {
                    if let Some(obj) = value.as_object_mut() {
                        obj.insert(
                            "requested_mode".to_string(),
                            serde_json::json!(requested_mode),
                        );
                    }
                    Ok(Some(value))
                }
                Err(e) => Err(e),
            };

            send_command_ack(ws_tx, command_id, result).await;
        }
        // Collecting diagnostics never fails; the report is always acked as success.
        HubMessage::CollectDiagnostics {
            command_id,
            payload,
        } => {
            info!("Received CollectDiagnostics command (id: {})", command_id);
            let result = Ok(Some(collect_diagnostics(
                metrics_provider,
                config_manager,
                blocklist,
                stats,
                &payload,
            )));
            send_command_ack(ws_tx, command_id, result).await;
        }
        // Update payloads are only staged to disk by `stage_update_payload`; nothing
        // here applies them.
        HubMessage::Update {
            command_id,
            payload,
        } => {
            info!("Received Update command (id: {})", command_id);
            let result = stage_update_payload(&command_id, &payload)
                .map(Some)
                .map_err(|e| e.to_string());
            send_command_ack(ws_tx, command_id, result).await;
        }
        // Asks the hub for a fresh blocklist by sending `SensorMessage::BlocklistSync`
        // on the same socket. The ack only reports that the request was sent; the
        // resulting snapshot arrives later as a separate message.
        HubMessage::SyncBlocklist {
            command_id,
            payload: _,
        } => {
            info!("Received SyncBlocklist command (id: {})", command_id);
            let result = match SensorMessage::BlocklistSync.to_json() {
                Ok(json) => {
                    if let Err(e) = ws_tx.send(Message::Text(json)).await {
                        Err(format!("Failed to request blocklist sync: {}", e))
                    } else {
                        Ok(None)
                    }
                }
                Err(e) => Err(format!("Failed to serialize blocklist sync: {}", e)),
            };

            send_command_ack(ws_tx, command_id, result).await;
        }
        // Legacy rules push with no command id and no hash. An ack is still sent,
        // under a synthesized id of the form `rules_update_{version}`.
        HubMessage::RulesUpdate { rules, version } => {
            info!("Received rules update (version: {})", version);

            let result = if let Some(manager) = config_manager {
                match serde_json::to_vec(&rules) {
                    Ok(rules_bytes) => match manager.update_waf_rules(&rules_bytes, None) {
                        Ok(count) => {
                            info!("Applied rules update v{}: {} rules loaded", version, count);
                            Ok(count)
                        }
                        Err(e) => {
                            error!("Failed to apply rules update v{}: {}", version, e);
                            Err(e.to_string())
                        }
                    },
                    Err(e) => {
                        error!("Failed to serialize rules for update v{}: {}", version, e);
                        Err(e.to_string())
                    }
                }
            } else {
                warn!("Rules update received but no ConfigManager available");
                Err("ConfigManager not available".to_string())
            };

            send_command_ack(
                ws_tx,
                format!("rules_update_{}", version),
                result.map(|count| Some(serde_json::json!({ "rules_loaded": count }))),
            )
            .await;
        }
        // Only logged here. NOTE(review): the "(redundant)" wording suggests auth is
        // normally resolved during the connection handshake, before this dispatcher
        // runs — confirm.
        HubMessage::AuthSuccess {
            tenant_id,
            sensor_id,
            capabilities,
            protocol_version,
        } => {
            info!(
                "Auth success: tenant={} sensor={} capabilities={:?} protocol={:?}",
                tenant_id, sensor_id, capabilities, protocol_version
            );
        }
        HubMessage::AuthFailed { error } => {
            error!("Auth failed (redundant): {}", error);
        }
        // Tunnels are unsupported. An open request is answered with a `TunnelError`
        // (code `TUNNEL_UNSUPPORTED`). Serialization or send failures of that reply
        // are ignored (best effort).
        HubMessage::TunnelOpen {
            tunnel_id,
            target_host,
            target_port,
        } => {
            warn!(
                "Tunnel open requested (id: {}, target: {}:{}) but tunnels are not supported",
                tunnel_id, target_host, target_port
            );
            let error_msg = SensorMessage::TunnelError {
                tunnel_id,
                code: "TUNNEL_UNSUPPORTED".to_string(),
                message: "This sensor does not support tunnel connections".to_string(),
            };
            if let Ok(json) = error_msg.to_json() {
                let _ = ws_tx.send(Message::Text(json)).await;
            }
        }
        // Close requests and data frames for tunnels are only logged; no reply is sent.
        HubMessage::TunnelClose { tunnel_id } => {
            warn!(
                "Tunnel close requested (id: {}) but tunnels are not supported",
                tunnel_id
            );
        }
        HubMessage::TunnelData { tunnel_id, .. } => {
            warn!(
                "Tunnel data received (id: {}) but tunnels are not supported",
                tunnel_id
            );
        }
    }
}
1538
#[cfg(test)]
mod tests {
    use super::*;

    // The no-op provider must report neutral values for every metric.
    #[test]
    fn test_noop_metrics_provider() {
        let provider = NoopMetricsProvider;
        assert_eq!(provider.cpu_usage(), 0.0);
        assert_eq!(provider.memory_usage(), 0.0);
        assert_eq!(provider.disk_usage(), 0.0);
        assert_eq!(provider.requests_last_minute(), 0);
        assert_eq!(provider.avg_latency_ms(), 0.0);
        assert!(provider.config_hash().is_empty());
        assert!(provider.rules_hash().is_empty());
        assert!(provider.active_connections().is_none());
    }

    #[test]
    fn test_client_stats_default() {
        let stats = ClientStats::default();
        assert_eq!(stats.signals_sent, 0);
        assert_eq!(stats.signals_acked, 0);
        assert_eq!(stats.signals_queued, 0);
        assert_eq!(stats.signals_dropped, 0);
        assert_eq!(stats.batches_sent, 0);
        assert_eq!(stats.heartbeats_sent, 0);
    }

    #[tokio::test]
    async fn test_client_disabled() {
        let config = HorizonConfig::default();
        let client = HorizonClient::new(config);

        assert!(client.start().await.is_ok());
    }

    #[tokio::test]
    async fn test_client_blocklist_lookup() {
        let config = HorizonConfig::default();
        let client = HorizonClient::new(config);

        client
            .blocklist
            .add(super::super::blocklist::BlocklistEntry {
                block_type: super::super::blocklist::BlockType::Ip,
                indicator: "192.168.1.100".to_string(),
                expires_at: None,
                source: "test".to_string(),
                reason: None,
                created_at: None,
            });

        assert!(client.is_ip_blocked("192.168.1.100"));
        assert!(!client.is_ip_blocked("192.168.1.101"));
    }

    // A Restart command on a sensor without a ConfigManager must fail with a clear
    // error (this string is forwarded verbatim to the hub in the command ack).
    #[test]
    fn test_soft_restart_without_config_manager() {
        let err = soft_restart(&None).expect_err("soft restart must fail without a ConfigManager");
        assert_eq!(err, "ConfigManager not available");
    }

    #[tokio::test]
    async fn test_validate_hub_url_ssrf_blocks_cloud_metadata() {
        let err = validate_hub_url_ssrf("wss://169.254.169.254/ws")
            .await
            .expect_err("expected metadata IP to be blocked");
        assert!(matches!(err, HorizonError::ConfigError(_)));
    }

    #[tokio::test]
    async fn test_validate_hub_url_ssrf_blocks_loopback() {
        let err = validate_hub_url_ssrf("ws://127.0.0.1:1234/ws")
            .await
            .expect_err("expected loopback IP to be blocked");
        assert!(matches!(err, HorizonError::ConfigError(_)));
    }

    #[tokio::test]
    async fn test_validate_hub_url_ssrf_allows_public_ip() {
        validate_hub_url_ssrf("wss://8.8.8.8/ws")
            .await
            .expect("expected public IP to be allowed");
    }
}