1use axum::{
7 extract::{
8 ws::{close_code, CloseFrame, Message, WebSocket},
9 Path, State, WebSocketUpgrade,
10 },
11 http::{header, HeaderMap, Method, StatusCode, Uri},
12 response::{Html as AxumHtml, IntoResponse, Response},
13 routing::get,
14 Json, Router,
15};
16use base64::{
17 engine::general_purpose::{STANDARD, URL_SAFE_NO_PAD},
18 Engine as _,
19};
20use futures_util::{SinkExt, StreamExt};
21use hmac::{Hmac, Mac};
22use serde::Serialize;
23use serde_json::Value as JsonValue;
24use sha2::Sha256;
25use shelly::{
26 ClientMessage, LiveSession, LiveView, PubSub, PubSubCommand, ResumeStatus, RuntimeCommand,
27 ServerMessage, TelemetryEvent, TelemetryEventKind, TelemetrySink, INTERNAL_RENDER_FLUSH_EVENT,
28};
29use std::{
30 collections::{BTreeMap, HashMap, HashSet},
31 fs,
32 path::PathBuf,
33 sync::{Arc, Mutex as StdMutex},
34 time::{SystemTime, UNIX_EPOCH},
35};
36use tokio::{
37 io::AsyncWriteExt,
38 sync::{mpsc, Mutex},
39 task::JoinHandle,
40 time::{sleep, timeout, Duration, Instant, MissedTickBehavior},
41};
42use tower_http::trace::TraceLayer;
43use tracing::{debug, error, info, warn};
44use tracing_subscriber::EnvFilter;
45use uuid::Uuid;
46
// Browser-side runtime script, embedded into the binary at compile time.
const CLIENT_JS: &str = include_str!("../assets/shelly_liveview.js");
// Size caps: 64 KiB per message and 5 MiB per upload (presumably bytes; confirm at the use sites).
const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024;
const DEFAULT_MAX_UPLOAD_SIZE: u64 = 5 * 1024 * 1024;
// Protocol identifier string.
const PROTOCOL_VERSION: &str = "shelly/1";
// Resume and connect-handshake timing. Constants suffixed `_MS` are milliseconds.
const DEFAULT_RESUME_TTL_MS: u64 = 120_000;
const DEFAULT_CONNECT_HANDSHAKE_TIMEOUT_MS: u64 = 5_000;
// Reconnect and heartbeat defaults that seed `ClientRuntimeConfig::default()`.
const DEFAULT_RECONNECT_BASE_MS: u64 = 750;
const DEFAULT_RECONNECT_MAX_MS: u64 = 30_000;
const DEFAULT_RECONNECT_JITTER_MS: u64 = 250;
const DEFAULT_HEARTBEAT_INTERVAL_MS: u64 = 15_000;
const DEFAULT_HEARTBEAT_TIMEOUT_MS: u64 = 10_000;
// Durable-session defaults: lease lifetime (ms) and journal length limit.
const DEFAULT_DURABLE_LEASE_TTL_MS: u64 = 300_000;
const DEFAULT_DURABLE_JOURNAL_LIMIT: usize = 512;
// Outbound queue and batching defaults.
const DEFAULT_OUTBOUND_QUEUE_CAPACITY: usize = 1024;
const DEFAULT_OUTBOUND_BATCH_MAX_MESSAGES: usize = 32;
const DEFAULT_OUTBOUND_BATCH_MAX_BYTES: usize = 64 * 1024;
const DEFAULT_OUTBOUND_BATCH_FLUSH_INTERVAL_MS: u64 = 2;
// Render cadence default (ms). NOTE(review): the meaning of 0 is decided where this is consumed.
const DEFAULT_RENDER_CADENCE_MS: u64 = 0;
// Overload accounting window and the per-session / per-tenant budgets.
const DEFAULT_OVERLOAD_WINDOW_MS: u64 = 1_000;
const DEFAULT_SESSION_QUEUE_DEPTH_BUDGET: usize = 768;
const DEFAULT_SESSION_BYTES_PER_SEC_BUDGET: usize = 512 * 1024;
const DEFAULT_SESSION_CPU_MS_PER_SEC_BUDGET: u64 = 250;
const DEFAULT_TENANT_QUEUE_DEPTH_BUDGET: usize = 4_096;
const DEFAULT_TENANT_BYTES_PER_SEC_BUDGET: usize = 4 * 1024 * 1024;
const DEFAULT_TENANT_CPU_MS_PER_SEC_BUDGET: u64 = 1_500;
// Tenant quota window (ms) and the session / event quotas per window.
const DEFAULT_TENANT_QUOTA_WINDOW_MS: u64 = 60_000;
const DEFAULT_TENANT_SESSION_QUOTA_PER_WINDOW: usize = 1_024;
const DEFAULT_TENANT_EVENT_QUOTA_PER_WINDOW: usize = 250_000;
// `Alt-Svc` value advertising HTTP/3 on port 443 with `ma=86400`.
const DEFAULT_HTTP3_ALT_SVC: &str = r#"h3=":443"; ma=86400"#;
76
/// HMAC instantiated with SHA-256; used by `TokenSigner` to sign and verify tokens.
type HmacSha256 = Hmac<Sha256>;
/// Shared, thread-safe callback mapping a rate-limit context to a `bool`.
/// NOTE(review): which value means "allowed" is decided at the call site.
type RateLimitHook = Arc<dyn Fn(&RateLimitContext) -> bool + Send + Sync>;
/// Shared, thread-safe callback producing an `AuthorizationDecision` for a context.
type AuthorizationHook = Arc<dyn Fn(&AuthorizationContext) -> AuthorizationDecision + Send + Sync>;
/// Shared, thread-safe callback producing a `QuotaDecision` for a context.
type QuotaPolicyHook = Arc<dyn Fn(&QuotaContext) -> QuotaDecision + Send + Sync>;
/// Shared, thread-safe callback that receives an overload context together with an existing
/// decision and returns the decision to use.
type OverloadPolicyHook =
    Arc<dyn Fn(&OverloadContext, &OverloadDecision) -> OverloadDecision + Send + Sync>;
/// Shared, thread-safe callback deciding whether a durable session may be placed on a node.
type DurablePlacementHook =
    Arc<dyn Fn(&DurablePlacementContext) -> DurablePlacementDecision + Send + Sync>;

/// Shared, thread-safe factory that builds a fresh boxed [`LiveView`] on every call.
pub type LiveViewFactory = Arc<dyn Fn() -> Box<dyn LiveView> + Send + Sync>;
88
89#[derive(Clone)]
90struct AppState {
91 routes: Arc<Vec<LiveRoute>>,
92 target_id: String,
93 max_message_size: usize,
94 pubsub: PubSub,
95 uploads: UploadConfig,
96 security: SecurityConfig,
97 telemetry: Arc<TelemetryPipeline>,
98 reconnect: ReconnectConfig,
99 distributed: DistributedConfig,
100 durable: DurableRuntimeConfig,
101 outbound: OutboundConfig,
102 render: RenderConfig,
103 overload: OverloadConfig,
104 transport_http: HttpTransportConfig,
105}
106
107impl AppState {
108 fn route_for(&self, path: &str) -> Option<MatchedRoute> {
109 self.routes.iter().find_map(|route| route.match_path(path))
110 }
111}
112
113impl TokenSigner {
114 fn ephemeral() -> Self {
115 let secret = format!("{}:{}", Uuid::new_v4(), Uuid::new_v4()).into_bytes();
116 Self {
117 secret: Arc::new(secret),
118 }
119 }
120
121 fn new(secret: impl Into<Vec<u8>>) -> Self {
122 Self {
123 secret: Arc::new(secret.into()),
124 }
125 }
126
127 fn sign_session(&self, session_id: &str, path: &str, node_id: &str) -> String {
128 self.sign(&serde_json::json!({
129 "kind": "session",
130 "session_id": session_id,
131 "path": path,
132 "node_id": node_id,
133 }))
134 }
135
136 fn verify_session(&self, token: &str) -> Option<SignedSession> {
137 let payload = self.verify(token)?;
138 if payload.get("kind")?.as_str()? != "session" {
139 return None;
140 }
141
142 Some(SignedSession {
143 session_id: payload.get("session_id")?.as_str()?.to_string(),
144 path: payload.get("path")?.as_str()?.to_string(),
145 node_id: payload
146 .get("node_id")
147 .and_then(serde_json::Value::as_str)
148 .map(ToString::to_string),
149 })
150 }
151
152 fn sign_resume(&self, session_id: &str, path: &str) -> String {
153 self.sign(&serde_json::json!({
154 "kind": "resume",
155 "session_id": session_id,
156 "path": path,
157 "nonce": Uuid::new_v4().to_string(),
158 }))
159 }
160
161 fn verify_resume(&self, token: &str) -> Option<SignedResume> {
162 let payload = self.verify(token)?;
163 if payload.get("kind")?.as_str()? != "resume" {
164 return None;
165 }
166
167 Some(SignedResume {
168 session_id: payload.get("session_id")?.as_str()?.to_string(),
169 path: payload.get("path")?.as_str()?.to_string(),
170 nonce: payload.get("nonce")?.as_str()?.to_string(),
171 })
172 }
173
174 fn sign_csrf(&self, session_id: &str, path: &str) -> String {
175 self.sign(&serde_json::json!({
176 "kind": "csrf",
177 "session_id": session_id,
178 "path": path,
179 "nonce": Uuid::new_v4().to_string(),
180 }))
181 }
182
183 fn verify_csrf(&self, token: &str, session_id: &str, path: &str) -> bool {
184 let Some(payload) = self.verify(token) else {
185 return false;
186 };
187
188 payload.get("kind").and_then(serde_json::Value::as_str) == Some("csrf")
189 && payload
190 .get("session_id")
191 .and_then(serde_json::Value::as_str)
192 == Some(session_id)
193 && payload.get("path").and_then(serde_json::Value::as_str) == Some(path)
194 && payload
195 .get("nonce")
196 .and_then(serde_json::Value::as_str)
197 .is_some()
198 }
199
200 fn sign(&self, payload: &serde_json::Value) -> String {
201 let payload = serde_json::to_vec(payload).expect("token payload should serialize");
202 let signature = self.mac(&payload);
203 format!(
204 "{}.{}",
205 URL_SAFE_NO_PAD.encode(payload),
206 URL_SAFE_NO_PAD.encode(signature)
207 )
208 }
209
210 fn verify(&self, token: &str) -> Option<serde_json::Value> {
211 let (payload, signature) = token.split_once('.')?;
212 let payload = URL_SAFE_NO_PAD.decode(payload).ok()?;
213 let signature = URL_SAFE_NO_PAD.decode(signature).ok()?;
214 let mut mac = HmacSha256::new_from_slice(self.secret.as_slice()).ok()?;
215 mac.update(&payload);
216 mac.verify_slice(&signature).ok()?;
217 serde_json::from_slice(&payload).ok()
218 }
219
220 fn mac(&self, payload: &[u8]) -> Vec<u8> {
221 let mut mac = HmacSha256::new_from_slice(self.secret.as_slice())
222 .expect("HMAC accepts arbitrary key lengths");
223 mac.update(payload);
224 mac.finalize().into_bytes().to_vec()
225 }
226}
227
/// Public router type holding the registered live routes and the runtime configuration.
///
/// Its fields mirror `AppState`, except that `routes` is an owned `Vec` here.
/// NOTE(review): presumably converted into `AppState` when the router is built — confirm.
pub struct ShellyRouter {
    routes: Vec<LiveRoute>,
    target_id: String,
    max_message_size: usize,
    pubsub: PubSub,
    uploads: UploadConfig,
    security: SecurityConfig,
    telemetry: Arc<TelemetryPipeline>,
    reconnect: ReconnectConfig,
    distributed: DistributedConfig,
    durable: DurableRuntimeConfig,
    outbound: OutboundConfig,
    render: RenderConfig,
    overload: OverloadConfig,
    transport_http: HttpTransportConfig,
}
245
/// A registered route: the pattern string, its parsed segments, and the view factory.
#[derive(Clone)]
struct LiveRoute {
    // Pattern string as registered.
    pattern: String,
    // Parsed form of the pattern (static and parameter segments).
    segments: Vec<RouteSegment>,
    // Factory for the view served by this route.
    factory: LiveViewFactory,
}
252
/// One segment of a route pattern.
#[derive(Debug, Clone, PartialEq, Eq)]
enum RouteSegment {
    // Literal segment text.
    Static(String),
    // Named parameter segment; the payload is the parameter name (presumably — confirm in the parser).
    Param(String),
}
258
/// Result of matching a request path against a `LiveRoute`.
#[derive(Clone)]
struct MatchedRoute {
    // Pattern of the route that matched.
    pattern: String,
    // Path that was matched.
    path: String,
    // Parameter values keyed by name, kept in sorted key order by the `BTreeMap`.
    params: BTreeMap<String, String>,
    // Factory for the matched route's view.
    factory: LiveViewFactory,
}
266
/// Per-connection configuration bundle.
///
/// Carries the same shared configuration as `AppState` plus connection-specific values
/// (`route_path`, `signed_session_id`, `correlation`).
/// NOTE(review): presumably handed to the socket handler on upgrade — confirm.
#[derive(Clone)]
struct SocketConfig {
    routes: Arc<Vec<LiveRoute>>,
    target_id: String,
    route_path: String,
    signed_session_id: String,
    max_message_size: usize,
    pubsub: PubSub,
    uploads: UploadConfig,
    security: SecurityConfig,
    telemetry: Arc<TelemetryPipeline>,
    correlation: CorrelationContext,
    reconnect: ReconnectConfig,
    distributed: DistributedConfig,
    durable: DurableRuntimeConfig,
    outbound: OutboundConfig,
    render: RenderConfig,
    overload: OverloadConfig,
}
286
/// Session-affinity setting for distributed deployments.
///
/// `DistributedConfig::default()` uses `Disabled`.
/// NOTE(review): enforcement of `Required` happens elsewhere — confirm its exact semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionAffinityMode {
    Disabled,
    Required,
}
292
/// Inputs describing a session-affinity check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionAffinityContext {
    pub session_id: String,
    pub route_path: String,
    /// Id of the node evaluating the check.
    pub current_node_id: String,
    /// Node id carried by the token, when present (session tokens may omit `node_id`).
    pub token_node_id: Option<String>,
}
300
/// What to do when the outbound queue overflows. Defaults to `Disconnect`.
///
/// NOTE(review): the variant names suggest "close the connection" versus "discard the newest
/// message"; the enforcement code is not in this part of the file — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OutboundOverflowPolicy {
    #[default]
    Disconnect,
    DropNewest,
}
310
/// Policy passed to `DurableSessionStore::acquire_lease` for leases held by another node.
/// Defaults to `AllowExpired`.
///
/// In the in-memory and file stores, `Force` is the only variant that bypasses the
/// "owned by another node and not yet expired" conflict check.
/// NOTE(review): those stores treat `Deny` and `AllowExpired` identically (an expired foreign
/// lease is taken over under both). Confirm whether `Deny` is enforced elsewhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum DurableTakeoverPolicy {
    Deny,
    #[default]
    AllowExpired,
    Force,
}
322
/// Inputs handed to a `DurablePlacementHook` when deciding where a durable session may run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DurablePlacementContext {
    pub session_id: String,
    pub route_path: String,
    /// Id of the node evaluating the placement.
    pub current_node_id: String,
    /// Preferred node, when one is known.
    pub preferred_node_id: Option<String>,
}
331
/// Outcome of a durable placement check.
///
/// `code` and `message` are populated only for denials created through [`Self::deny`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DurablePlacementDecision {
    pub allowed: bool,
    pub code: Option<String>,
    pub message: Option<String>,
}

impl DurablePlacementDecision {
    /// Creates an allowing decision with no code or message.
    pub fn allow() -> Self {
        let (code, message) = (None, None);
        Self {
            allowed: true,
            code,
            message,
        }
    }

    /// Creates a denying decision carrying a code and a message.
    pub fn deny(code: impl Into<String>, message: impl Into<String>) -> Self {
        let code = Some(code.into());
        let message = Some(message.into());
        Self {
            allowed: false,
            code,
            message,
        }
    }
}
357
/// Error returned by durable session stores: a short code plus a descriptive message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DurableStoreError {
    pub code: String,
    pub message: String,
}

impl DurableStoreError {
    /// Builds an error from anything convertible into the code and message strings.
    fn new(code: impl Into<String>, message: impl Into<String>) -> Self {
        let (code, message) = (code.into(), message.into());
        Self { code, message }
    }
}
373
/// Request to acquire the durable lease for a session on behalf of a node.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DurableLeaseRequest {
    pub session_id: String,
    /// Node asking for ownership.
    pub node_id: String,
    /// Requested lease lifetime in milliseconds. The in-memory and file stores raise values
    /// below 1000 to 1000.
    pub ttl_ms: u64,
    /// How a lease currently held by another node is treated.
    pub takeover_policy: DurableTakeoverPolicy,
}
382
/// Result of a successful lease acquisition or renewal.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DurableLeaseGrant {
    /// Token that must accompany later renew, release and journal-append calls.
    pub fence_token: u64,
    pub owner_node_id: String,
    /// Previous owner when the lease moved between nodes; `None` otherwise (always `None`
    /// on renewals).
    pub transferred_from: Option<String>,
    /// Expiry as Unix time in milliseconds.
    pub expires_at_unix_ms: u64,
}
391
/// Persisted snapshot of a durable session's identifying state.
///
/// The stores in this file save and return it verbatim; they do not interpret the fields.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct DurableSessionSnapshot {
    pub route_path: String,
    pub route_pattern: String,
    pub target_id: String,
    pub revision: u64,
    pub resume_token: String,
    pub owner_node_id: String,
    /// Unix time in milliseconds (per the field name; set by the caller).
    pub updated_at_unix_ms: u64,
}
403
/// One journaled client message for a durable session.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct DurableJournalEntry {
    /// Per-session sequence number. The stores in this file start at 1 and increment by one
    /// for each appended entry.
    pub sequence: u64,
    pub message: ClientMessage,
    /// Time the store recorded the entry, as Unix milliseconds.
    pub recorded_at_unix_ms: u64,
}
411
/// A session's saved snapshot together with its retained journal entries, as returned by
/// `DurableSessionStore::load_record`.
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct DurableSessionRecord {
    pub snapshot: DurableSessionSnapshot,
    pub journal: Vec<DurableJournalEntry>,
}
418
/// Storage backend for durable sessions: lease ownership, snapshots, a bounded journal and
/// per-node drain flags.
///
/// The per-method notes describe what the in-memory and file stores in this file do; other
/// implementations should confirm the expected contract with their callers.
pub trait DurableSessionStore: Send + Sync {
    /// Grants (or re-grants) the lease for `request.session_id` to `request.node_id`.
    ///
    /// The bundled stores fail with `node_draining` for a draining node and with
    /// `lease_conflict` when another node holds an unexpired lease and the policy is not
    /// `Force`.
    fn acquire_lease(
        &self,
        request: DurableLeaseRequest,
    ) -> Result<DurableLeaseGrant, DurableStoreError>;
    /// Extends the lease held by `node_id` under `fence_token` by `ttl_ms` milliseconds.
    ///
    /// The bundled stores fail with `lease_not_found` or `lease_not_owner`.
    fn renew_lease(
        &self,
        session_id: &str,
        node_id: &str,
        fence_token: u64,
        ttl_ms: u64,
    ) -> Result<DurableLeaseGrant, DurableStoreError>;
    /// Drops the lease if it is still held by `node_id` under `fence_token`; otherwise a no-op.
    fn release_lease(&self, session_id: &str, node_id: &str, fence_token: u64);
    /// Returns the saved snapshot and journal, or `None` when no snapshot has been saved.
    fn load_record(&self, session_id: &str) -> Option<DurableSessionRecord>;
    /// Stores `snapshot` as the session's current snapshot. Failures are not reported to the
    /// caller.
    fn save_snapshot(&self, session_id: &str, snapshot: DurableSessionSnapshot);
    /// Appends `message` to the session journal on behalf of the lease owner and keeps at most
    /// `max_journal` entries (the bundled stores treat 0 as 1 and drop the oldest first).
    fn append_journal_entry(
        &self,
        session_id: &str,
        node_id: &str,
        fence_token: u64,
        message: ClientMessage,
        max_journal: usize,
    ) -> Result<DurableJournalEntry, DurableStoreError>;
    /// Marks or unmarks `node_id` as draining.
    fn set_node_draining(&self, node_id: &str, draining: bool);
    /// Reports whether `node_id` is currently marked as draining.
    fn is_node_draining(&self, node_id: &str) -> bool;
}
446
/// `DurableSessionStore` that keeps everything in process memory.
///
/// Clones share the same underlying state through the `Arc`.
#[derive(Default, Clone)]
pub struct InMemoryDurableSessionStore {
    inner: Arc<StdMutex<InMemoryDurableSessionStoreState>>,
}
451
/// Mutex-protected state of the in-memory store.
#[derive(Default)]
struct InMemoryDurableSessionStoreState {
    // Per-session lease, snapshot and journal, keyed by session id.
    sessions: HashMap<String, InMemoryDurableSessionEntry>,
    // Ids of nodes currently marked as draining.
    draining_nodes: HashSet<String>,
}
457
/// Everything the in-memory store tracks for one session.
#[derive(Clone)]
struct InMemoryDurableSessionEntry {
    // Current lease, if any node holds one.
    lease: Option<InMemoryDurableLease>,
    // Most recently saved snapshot, if any.
    snapshot: Option<DurableSessionSnapshot>,
    // Retained journal entries, oldest first.
    journal: Vec<DurableJournalEntry>,
    // Sequence number given to the next appended journal entry (starts at 1).
    next_sequence: u64,
}
465
/// Lease held by a node on a session in the in-memory store.
#[derive(Clone)]
struct InMemoryDurableLease {
    owner_node_id: String,
    // Token the owner must present when renewing, releasing or appending.
    fence_token: u64,
    // Expiry as Unix time in milliseconds.
    expires_at_unix_ms: u64,
}
472
473impl InMemoryDurableSessionStore {
474 pub fn new() -> Self {
475 Self::default()
476 }
477}
478
479impl DurableSessionStore for InMemoryDurableSessionStore {
480 fn acquire_lease(
481 &self,
482 request: DurableLeaseRequest,
483 ) -> Result<DurableLeaseGrant, DurableStoreError> {
484 let mut state = self.inner.lock().map_err(|_| {
485 DurableStoreError::new("durable_store_poisoned", "durable store mutex poisoned")
486 })?;
487
488 if state.draining_nodes.contains(&request.node_id) {
489 return Err(DurableStoreError::new(
490 "node_draining",
491 format!(
492 "node {} is draining and cannot acquire durable session ownership",
493 request.node_id
494 ),
495 ));
496 }
497
498 let now = now_unix_ms();
499 let entry = state
500 .sessions
501 .entry(request.session_id.clone())
502 .or_insert_with(|| InMemoryDurableSessionEntry {
503 lease: None,
504 snapshot: None,
505 journal: Vec::new(),
506 next_sequence: 1,
507 });
508
509 let mut transferred_from = None;
510 let mut next_fence = entry
511 .lease
512 .as_ref()
513 .map(|lease| lease.fence_token + 1)
514 .unwrap_or(1);
515
516 if let Some(existing) = entry.lease.as_ref() {
517 let is_owner = existing.owner_node_id == request.node_id;
518 let is_expired = existing.expires_at_unix_ms <= now;
519 if !is_owner && !is_expired && request.takeover_policy != DurableTakeoverPolicy::Force {
520 return Err(DurableStoreError::new(
521 "lease_conflict",
522 format!(
523 "durable lease is owned by node {} until {}",
524 existing.owner_node_id, existing.expires_at_unix_ms
525 ),
526 ));
527 }
528 if !is_owner {
529 transferred_from = Some(existing.owner_node_id.clone());
530 } else {
531 next_fence = existing.fence_token;
532 }
533 }
534
535 let expires_at_unix_ms = now + request.ttl_ms.max(1_000);
536 entry.lease = Some(InMemoryDurableLease {
537 owner_node_id: request.node_id.clone(),
538 fence_token: next_fence,
539 expires_at_unix_ms,
540 });
541
542 Ok(DurableLeaseGrant {
543 fence_token: next_fence,
544 owner_node_id: request.node_id,
545 transferred_from,
546 expires_at_unix_ms,
547 })
548 }
549
550 fn renew_lease(
551 &self,
552 session_id: &str,
553 node_id: &str,
554 fence_token: u64,
555 ttl_ms: u64,
556 ) -> Result<DurableLeaseGrant, DurableStoreError> {
557 let mut state = self.inner.lock().map_err(|_| {
558 DurableStoreError::new("durable_store_poisoned", "durable store mutex poisoned")
559 })?;
560 let entry = state
561 .sessions
562 .get_mut(session_id)
563 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
564 let lease = entry
565 .lease
566 .as_mut()
567 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
568 if lease.owner_node_id != node_id || lease.fence_token != fence_token {
569 return Err(DurableStoreError::new(
570 "lease_not_owner",
571 "cannot renew durable lease from non-owner node",
572 ));
573 }
574 lease.expires_at_unix_ms = now_unix_ms() + ttl_ms.max(1_000);
575 Ok(DurableLeaseGrant {
576 fence_token: lease.fence_token,
577 owner_node_id: lease.owner_node_id.clone(),
578 transferred_from: None,
579 expires_at_unix_ms: lease.expires_at_unix_ms,
580 })
581 }
582
583 fn release_lease(&self, session_id: &str, node_id: &str, fence_token: u64) {
584 let Ok(mut state) = self.inner.lock() else {
585 return;
586 };
587 let Some(entry) = state.sessions.get_mut(session_id) else {
588 return;
589 };
590 let should_release = entry
591 .lease
592 .as_ref()
593 .map(|lease| lease.owner_node_id == node_id && lease.fence_token == fence_token)
594 .unwrap_or(false);
595 if should_release {
596 entry.lease = None;
597 }
598 }
599
600 fn load_record(&self, session_id: &str) -> Option<DurableSessionRecord> {
601 let Ok(state) = self.inner.lock() else {
602 return None;
603 };
604 let entry = state.sessions.get(session_id)?;
605 let snapshot = entry.snapshot.clone()?;
606 Some(DurableSessionRecord {
607 snapshot,
608 journal: entry.journal.clone(),
609 })
610 }
611
612 fn save_snapshot(&self, session_id: &str, snapshot: DurableSessionSnapshot) {
613 let Ok(mut state) = self.inner.lock() else {
614 return;
615 };
616 let entry = state
617 .sessions
618 .entry(session_id.to_string())
619 .or_insert_with(|| InMemoryDurableSessionEntry {
620 lease: None,
621 snapshot: None,
622 journal: Vec::new(),
623 next_sequence: 1,
624 });
625 entry.snapshot = Some(snapshot);
626 }
627
628 fn append_journal_entry(
629 &self,
630 session_id: &str,
631 node_id: &str,
632 fence_token: u64,
633 message: ClientMessage,
634 max_journal: usize,
635 ) -> Result<DurableJournalEntry, DurableStoreError> {
636 let mut state = self.inner.lock().map_err(|_| {
637 DurableStoreError::new("durable_store_poisoned", "durable store mutex poisoned")
638 })?;
639 let entry = state
640 .sessions
641 .entry(session_id.to_string())
642 .or_insert_with(|| InMemoryDurableSessionEntry {
643 lease: None,
644 snapshot: None,
645 journal: Vec::new(),
646 next_sequence: 1,
647 });
648 let lease = entry
649 .lease
650 .as_ref()
651 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
652 if lease.owner_node_id != node_id || lease.fence_token != fence_token {
653 return Err(DurableStoreError::new(
654 "lease_not_owner",
655 "cannot append durable journal from non-owner node",
656 ));
657 }
658
659 let journal_entry = DurableJournalEntry {
660 sequence: entry.next_sequence,
661 message,
662 recorded_at_unix_ms: now_unix_ms(),
663 };
664 entry.next_sequence += 1;
665 entry.journal.push(journal_entry.clone());
666 let max_journal = max_journal.max(1);
667 if entry.journal.len() > max_journal {
668 let trim_count = entry.journal.len() - max_journal;
669 entry.journal.drain(0..trim_count);
670 }
671 Ok(journal_entry)
672 }
673
674 fn set_node_draining(&self, node_id: &str, draining: bool) {
675 let Ok(mut state) = self.inner.lock() else {
676 return;
677 };
678 if draining {
679 state.draining_nodes.insert(node_id.to_string());
680 } else {
681 state.draining_nodes.remove(node_id);
682 }
683 }
684
685 fn is_node_draining(&self, node_id: &str) -> bool {
686 let Ok(state) = self.inner.lock() else {
687 return false;
688 };
689 state.draining_nodes.contains(node_id)
690 }
691}
692
/// `DurableSessionStore` that persists one JSON file per session under `<root_dir>/sessions`.
///
/// Clones share the same lock and draining set. The draining set lives only in memory: it is
/// never written to disk by this type.
#[derive(Clone)]
pub struct FileDurableSessionStore {
    // Directory that contains the `sessions` subdirectory.
    root_dir: Arc<PathBuf>,
    // Serializes record reads and writes across clones of this store within the process.
    lock: Arc<StdMutex<()>>,
    // Ids of nodes currently marked as draining (in-memory only).
    draining_nodes: Arc<StdMutex<HashSet<String>>>,
}
704
/// On-disk JSON shape of a session: lease, snapshot, journal and the next sequence number.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct FileDurableSessionRecord {
    lease: Option<FileDurableLease>,
    snapshot: Option<DurableSessionSnapshot>,
    journal: Vec<DurableJournalEntry>,
    // Sequence number given to the next appended journal entry.
    next_sequence: u64,
}

// Hand-written rather than derived because sequences start at 1, not 0.
impl Default for FileDurableSessionRecord {
    fn default() -> Self {
        Self {
            lease: None,
            snapshot: None,
            journal: Vec::new(),
            next_sequence: 1,
        }
    }
}
723
/// Lease as stored inside a session's JSON file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct FileDurableLease {
    owner_node_id: String,
    // Token the owner must present when renewing, releasing or appending.
    fence_token: u64,
    // Expiry as Unix time in milliseconds.
    expires_at_unix_ms: u64,
}
730
731impl FileDurableSessionStore {
732 pub fn new(root_dir: impl Into<PathBuf>) -> Self {
733 Self {
734 root_dir: Arc::new(root_dir.into()),
735 lock: Arc::new(StdMutex::new(())),
736 draining_nodes: Arc::new(StdMutex::new(HashSet::new())),
737 }
738 }
739
740 fn sessions_dir(&self) -> PathBuf {
741 self.root_dir.join("sessions")
742 }
743
744 fn session_path(&self, session_id: &str) -> PathBuf {
745 let encoded = URL_SAFE_NO_PAD.encode(session_id.as_bytes());
746 self.sessions_dir().join(format!("{encoded}.json"))
747 }
748
749 fn ensure_sessions_dir(&self) -> Result<(), DurableStoreError> {
750 fs::create_dir_all(self.sessions_dir()).map_err(|error| {
751 DurableStoreError::new(
752 "durable_store_io",
753 format!("failed to create durable sessions dir: {error}"),
754 )
755 })
756 }
757
758 fn lock_guard(&self) -> Result<std::sync::MutexGuard<'_, ()>, DurableStoreError> {
759 self.lock.lock().map_err(|_| {
760 DurableStoreError::new("durable_store_poisoned", "durable store mutex poisoned")
761 })
762 }
763
764 fn read_record_locked(
765 &self,
766 session_id: &str,
767 ) -> Result<Option<FileDurableSessionRecord>, DurableStoreError> {
768 let path = self.session_path(session_id);
769 if !path.exists() {
770 return Ok(None);
771 }
772 let raw = fs::read_to_string(&path).map_err(|error| {
773 DurableStoreError::new(
774 "durable_store_io",
775 format!("failed to read durable record for session {session_id}: {error}"),
776 )
777 })?;
778 let record = serde_json::from_str::<FileDurableSessionRecord>(&raw).map_err(|error| {
779 DurableStoreError::new(
780 "durable_store_corrupt",
781 format!("failed to decode durable record for session {session_id}: {error}"),
782 )
783 })?;
784 Ok(Some(record))
785 }
786
787 fn read_or_default_locked(
788 &self,
789 session_id: &str,
790 ) -> Result<FileDurableSessionRecord, DurableStoreError> {
791 Ok(self.read_record_locked(session_id)?.unwrap_or_default())
792 }
793
794 fn write_record_locked(
795 &self,
796 session_id: &str,
797 record: &FileDurableSessionRecord,
798 ) -> Result<(), DurableStoreError> {
799 self.ensure_sessions_dir()?;
800 let path = self.session_path(session_id);
801 let temp_path = path.with_extension("json.tmp");
802 let encoded = serde_json::to_string(record).map_err(|error| {
803 DurableStoreError::new(
804 "durable_store_encode",
805 format!("failed to encode durable record for session {session_id}: {error}"),
806 )
807 })?;
808 fs::write(&temp_path, encoded).map_err(|error| {
809 DurableStoreError::new(
810 "durable_store_io",
811 format!(
812 "failed to write durable record temp file for session {session_id}: {error}"
813 ),
814 )
815 })?;
816 fs::rename(&temp_path, &path).map_err(|error| {
817 DurableStoreError::new(
818 "durable_store_io",
819 format!("failed to persist durable record for session {session_id}: {error}"),
820 )
821 })?;
822 Ok(())
823 }
824}
825
826impl DurableSessionStore for FileDurableSessionStore {
827 fn acquire_lease(
828 &self,
829 request: DurableLeaseRequest,
830 ) -> Result<DurableLeaseGrant, DurableStoreError> {
831 let _guard = self.lock_guard()?;
832 if self.is_node_draining(&request.node_id) {
833 return Err(DurableStoreError::new(
834 "node_draining",
835 format!(
836 "node {} is draining and cannot acquire durable session ownership",
837 request.node_id
838 ),
839 ));
840 }
841
842 let mut record = self.read_or_default_locked(&request.session_id)?;
843 let now = now_unix_ms();
844 let mut transferred_from = None;
845 let mut next_fence = record
846 .lease
847 .as_ref()
848 .map(|lease| lease.fence_token + 1)
849 .unwrap_or(1);
850
851 if let Some(existing) = record.lease.as_ref() {
852 let is_owner = existing.owner_node_id == request.node_id;
853 let is_expired = existing.expires_at_unix_ms <= now;
854 if !is_owner && !is_expired && request.takeover_policy != DurableTakeoverPolicy::Force {
855 return Err(DurableStoreError::new(
856 "lease_conflict",
857 format!(
858 "durable lease is owned by node {} until {}",
859 existing.owner_node_id, existing.expires_at_unix_ms
860 ),
861 ));
862 }
863 if !is_owner {
864 transferred_from = Some(existing.owner_node_id.clone());
865 } else {
866 next_fence = existing.fence_token;
867 }
868 }
869
870 let expires_at_unix_ms = now + request.ttl_ms.max(1_000);
871 record.lease = Some(FileDurableLease {
872 owner_node_id: request.node_id.clone(),
873 fence_token: next_fence,
874 expires_at_unix_ms,
875 });
876 self.write_record_locked(&request.session_id, &record)?;
877
878 Ok(DurableLeaseGrant {
879 fence_token: next_fence,
880 owner_node_id: request.node_id,
881 transferred_from,
882 expires_at_unix_ms,
883 })
884 }
885
886 fn renew_lease(
887 &self,
888 session_id: &str,
889 node_id: &str,
890 fence_token: u64,
891 ttl_ms: u64,
892 ) -> Result<DurableLeaseGrant, DurableStoreError> {
893 let _guard = self.lock_guard()?;
894 let mut record = self
895 .read_record_locked(session_id)?
896 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
897 let lease = record
898 .lease
899 .as_mut()
900 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
901 if lease.owner_node_id != node_id || lease.fence_token != fence_token {
902 return Err(DurableStoreError::new(
903 "lease_not_owner",
904 "cannot renew durable lease from non-owner node",
905 ));
906 }
907 lease.expires_at_unix_ms = now_unix_ms() + ttl_ms.max(1_000);
908 let grant = DurableLeaseGrant {
909 fence_token: lease.fence_token,
910 owner_node_id: lease.owner_node_id.clone(),
911 transferred_from: None,
912 expires_at_unix_ms: lease.expires_at_unix_ms,
913 };
914 self.write_record_locked(session_id, &record)?;
915 Ok(grant)
916 }
917
918 fn release_lease(&self, session_id: &str, node_id: &str, fence_token: u64) {
919 let Ok(_guard) = self.lock_guard() else {
920 return;
921 };
922 let Ok(Some(mut record)) = self.read_record_locked(session_id) else {
923 return;
924 };
925 let should_release = record
926 .lease
927 .as_ref()
928 .map(|lease| lease.owner_node_id == node_id && lease.fence_token == fence_token)
929 .unwrap_or(false);
930 if !should_release {
931 return;
932 }
933 record.lease = None;
934 if let Err(error) = self.write_record_locked(session_id, &record) {
935 warn!(
936 session_id,
937 code = error.code,
938 message = error.message,
939 "failed to persist durable lease release"
940 );
941 }
942 }
943
944 fn load_record(&self, session_id: &str) -> Option<DurableSessionRecord> {
945 let _guard = match self.lock_guard() {
946 Ok(guard) => guard,
947 Err(error) => {
948 warn!(
949 session_id,
950 code = error.code,
951 message = error.message,
952 "failed to lock durable file store"
953 );
954 return None;
955 }
956 };
957 let record = match self.read_record_locked(session_id) {
958 Ok(record) => record,
959 Err(error) => {
960 warn!(
961 session_id,
962 code = error.code,
963 message = error.message,
964 "failed to load durable record"
965 );
966 return None;
967 }
968 };
969 let record = record?;
970 let snapshot = record.snapshot?;
971 Some(DurableSessionRecord {
972 snapshot,
973 journal: record.journal,
974 })
975 }
976
977 fn save_snapshot(&self, session_id: &str, snapshot: DurableSessionSnapshot) {
978 let Ok(_guard) = self.lock_guard() else {
979 return;
980 };
981 let mut record = match self.read_or_default_locked(session_id) {
982 Ok(record) => record,
983 Err(error) => {
984 warn!(
985 session_id,
986 code = error.code,
987 message = error.message,
988 "failed to load durable record before snapshot save"
989 );
990 return;
991 }
992 };
993 record.snapshot = Some(snapshot);
994 if let Err(error) = self.write_record_locked(session_id, &record) {
995 warn!(
996 session_id,
997 code = error.code,
998 message = error.message,
999 "failed to persist durable snapshot"
1000 );
1001 }
1002 }
1003
1004 fn append_journal_entry(
1005 &self,
1006 session_id: &str,
1007 node_id: &str,
1008 fence_token: u64,
1009 message: ClientMessage,
1010 max_journal: usize,
1011 ) -> Result<DurableJournalEntry, DurableStoreError> {
1012 let _guard = self.lock_guard()?;
1013 let mut record = self.read_or_default_locked(session_id)?;
1014 let lease = record
1015 .lease
1016 .as_ref()
1017 .ok_or_else(|| DurableStoreError::new("lease_not_found", "durable lease not found"))?;
1018 if lease.owner_node_id != node_id || lease.fence_token != fence_token {
1019 return Err(DurableStoreError::new(
1020 "lease_not_owner",
1021 "cannot append durable journal from non-owner node",
1022 ));
1023 }
1024
1025 let journal_entry = DurableJournalEntry {
1026 sequence: record.next_sequence,
1027 message,
1028 recorded_at_unix_ms: now_unix_ms(),
1029 };
1030 record.next_sequence += 1;
1031 record.journal.push(journal_entry.clone());
1032 let max_journal = max_journal.max(1);
1033 if record.journal.len() > max_journal {
1034 let trim_count = record.journal.len() - max_journal;
1035 record.journal.drain(0..trim_count);
1036 }
1037 self.write_record_locked(session_id, &record)?;
1038 Ok(journal_entry)
1039 }
1040
1041 fn set_node_draining(&self, node_id: &str, draining: bool) {
1042 let Ok(mut draining_nodes) = self.draining_nodes.lock() else {
1043 return;
1044 };
1045 if draining {
1046 draining_nodes.insert(node_id.to_string());
1047 } else {
1048 draining_nodes.remove(node_id);
1049 }
1050 }
1051
1052 fn is_node_draining(&self, node_id: &str) -> bool {
1053 let Ok(draining_nodes) = self.draining_nodes.lock() else {
1054 return false;
1055 };
1056 draining_nodes.contains(node_id)
1057 }
1058}
1059
1060#[derive(Clone)]
1061struct DistributedConfig {
1062 node_id: String,
1063 session_affinity: SessionAffinityMode,
1064}
1065
1066impl Default for DistributedConfig {
1067 fn default() -> Self {
1068 Self {
1069 node_id: format!("node-{}", Uuid::new_v4()),
1070 session_affinity: SessionAffinityMode::Disabled,
1071 }
1072 }
1073}
1074
/// Transport a client can use to talk to the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransportMode {
    WebSocket,
    ServerSentEvents,
    LongPoll,
}

impl TransportMode {
    /// Short identifier for the mode: `websocket`, `sse` or `long_poll`.
    fn as_str(self) -> &'static str {
        match self {
            TransportMode::WebSocket => "websocket",
            TransportMode::ServerSentEvents => "sse",
            TransportMode::LongPoll => "long_poll",
        }
    }
}
1095
1096#[derive(Debug, Clone)]
1097struct HttpTransportConfig {
1098 http3_alt_svc: Option<String>,
1099 emit_diagnostics_header: bool,
1100 primary_mode: TransportMode,
1101 fallback_modes: Vec<TransportMode>,
1102 progressive_enhancement: bool,
1103}
1104
1105impl Default for HttpTransportConfig {
1106 fn default() -> Self {
1107 Self {
1108 http3_alt_svc: None,
1109 emit_diagnostics_header: false,
1110 primary_mode: TransportMode::WebSocket,
1111 fallback_modes: Vec::new(),
1112 progressive_enhancement: true,
1113 }
1114 }
1115}
1116
1117impl HttpTransportConfig {
1118 fn fallback_names(&self) -> Vec<&'static str> {
1119 self.fallback_modes
1120 .iter()
1121 .copied()
1122 .map(TransportMode::as_str)
1123 .collect()
1124 }
1125
1126 fn fallback_json(&self) -> String {
1127 serde_json::to_string(&self.fallback_names()).expect("transport modes serialize")
1128 }
1129
1130 fn diagnostics_header(&self) -> String {
1131 let h3 = if self.http3_alt_svc.is_some() {
1132 "advertised"
1133 } else {
1134 "disabled"
1135 };
1136 let fallback = if self.fallback_modes.is_empty() {
1137 "none".to_string()
1138 } else {
1139 self.fallback_names().join(",")
1140 };
1141 let progressive = if self.progressive_enhancement {
1142 "enabled"
1143 } else {
1144 "disabled"
1145 };
1146 format!(
1147 "{}; fallback={fallback}; progressive={progressive}; h3={h3}",
1148 self.primary_mode.as_str()
1149 )
1150 }
1151}
1152
/// Serializable summary of the transport setup.
/// NOTE(review): presumably returned to clients as JSON — confirm at the handler.
#[derive(Debug, Serialize)]
struct TransportCapabilities {
    protocol: &'static str,
    // Primary transport mode name.
    primary: &'static str,
    // Fallback transport mode names.
    fallbacks: Vec<&'static str>,
    progressive_enhancement: bool,
    http3_advertised: bool,
}
1161
1162#[derive(Debug, Clone)]
1163struct ClientRuntimeConfig {
1164 reconnect_base_ms: u64,
1165 reconnect_max_ms: u64,
1166 reconnect_jitter_ms: u64,
1167 heartbeat_interval_ms: u64,
1168 heartbeat_timeout_ms: u64,
1169}
1170
1171impl Default for ClientRuntimeConfig {
1172 fn default() -> Self {
1173 Self {
1174 reconnect_base_ms: DEFAULT_RECONNECT_BASE_MS,
1175 reconnect_max_ms: DEFAULT_RECONNECT_MAX_MS,
1176 reconnect_jitter_ms: DEFAULT_RECONNECT_JITTER_MS,
1177 heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
1178 heartbeat_timeout_ms: DEFAULT_HEARTBEAT_TIMEOUT_MS,
1179 }
1180 }
1181}
1182
1183#[derive(Clone)]
1184struct ReconnectConfig {
1185 resume_ttl: Duration,
1186 connect_handshake_timeout: Duration,
1187 client: ClientRuntimeConfig,
1188 snapshots: Arc<Mutex<HashMap<String, ResumeSnapshot>>>,
1189}
1190
1191impl Default for ReconnectConfig {
1192 fn default() -> Self {
1193 Self {
1194 resume_ttl: Duration::from_millis(DEFAULT_RESUME_TTL_MS),
1195 connect_handshake_timeout: Duration::from_millis(DEFAULT_CONNECT_HANDSHAKE_TIMEOUT_MS),
1196 client: ClientRuntimeConfig::default(),
1197 snapshots: Arc::new(Mutex::new(HashMap::new())),
1198 }
1199 }
1200}
1201
1202#[derive(Clone)]
1203struct DurableRuntimeConfig {
1204 store: Option<Arc<dyn DurableSessionStore>>,
1205 lease_ttl: Duration,
1206 journal_limit: usize,
1207 takeover_policy: DurableTakeoverPolicy,
1208 drain_mode: bool,
1209 placement_hook: Option<DurablePlacementHook>,
1210}
1211
1212impl Default for DurableRuntimeConfig {
1213 fn default() -> Self {
1214 Self {
1215 store: None,
1216 lease_ttl: Duration::from_millis(DEFAULT_DURABLE_LEASE_TTL_MS),
1217 journal_limit: DEFAULT_DURABLE_JOURNAL_LIMIT,
1218 takeover_policy: DurableTakeoverPolicy::default(),
1219 drain_mode: false,
1220 placement_hook: None,
1221 }
1222 }
1223}
1224
1225#[derive(Debug, Clone)]
1226struct OutboundConfig {
1227 queue_capacity: usize,
1228 batch_max_messages: usize,
1229 batch_max_bytes: usize,
1230 batch_flush_interval: Duration,
1231 overflow_policy: OutboundOverflowPolicy,
1232}
1233
1234impl Default for OutboundConfig {
1235 fn default() -> Self {
1236 Self {
1237 queue_capacity: DEFAULT_OUTBOUND_QUEUE_CAPACITY,
1238 batch_max_messages: DEFAULT_OUTBOUND_BATCH_MAX_MESSAGES,
1239 batch_max_bytes: DEFAULT_OUTBOUND_BATCH_MAX_BYTES,
1240 batch_flush_interval: Duration::from_millis(DEFAULT_OUTBOUND_BATCH_FLUSH_INTERVAL_MS),
1241 overflow_policy: OutboundOverflowPolicy::default(),
1242 }
1243 }
1244}
1245
1246#[derive(Debug, Clone)]
1247struct RenderConfig {
1248 default_cadence_ms: u64,
1249}
1250
1251impl Default for RenderConfig {
1252 fn default() -> Self {
1253 Self {
1254 default_cadence_ms: DEFAULT_RENDER_CADENCE_MS,
1255 }
1256 }
1257}
1258
/// Saved state for a session that may be resumed; values of this type are
/// stored in `ReconnectConfig::snapshots`.
struct ResumeSnapshot {
    /// The live session being kept.
    session: LiveSession,
    /// Route pattern associated with the session.
    route_pattern: String,
    /// Resume token associated with this snapshot.
    /// NOTE(review): presumably compared against the token a reconnecting
    /// client presents — confirm at the use site.
    resume_token: String,
    /// Instant after which the snapshot is considered expired.
    expires_at: Instant,
}
1265
/// Values carried by a connect handshake.
///
/// NOTE(review): parsing and validation live outside this section; the
/// field notes below are inferred from names and types — confirm.
struct ConnectHandshake {
    /// Session id supplied by the client, if any.
    client_session_id: Option<String>,
    /// Revision number reported by the client.
    client_revision: u64,
    /// Resume token supplied by the client, if any.
    resume_token: Option<String>,
    /// Tenant id, if any.
    tenant_id: Option<String>,
    // Optional tracing / correlation identifiers.
    trace_id: Option<String>,
    span_id: Option<String>,
    parent_span_id: Option<String>,
    correlation_id: Option<String>,
    request_id: Option<String>,
}
1277
1278#[derive(Debug, Clone, Default, PartialEq, Eq)]
1279struct CorrelationContext {
1280 trace_id: String,
1281 span_id: String,
1282 parent_span_id: Option<String>,
1283 correlation_id: Option<String>,
1284 request_id: Option<String>,
1285}
1286
1287impl CorrelationContext {
1288 fn with_new_span(&self) -> Self {
1289 let mut next = self.clone();
1290 next.parent_span_id = Some(self.span_id.clone());
1291 next.span_id = generate_span_id();
1292 next
1293 }
1294
1295 fn apply_to_event(&self, event: TelemetryEvent) -> TelemetryEvent {
1296 let mut enriched = event
1297 .with_trace_id(self.trace_id.clone())
1298 .with_span_id(self.span_id.clone());
1299 if let Some(parent_span_id) = &self.parent_span_id {
1300 enriched = enriched.with_parent_span_id(parent_span_id.clone());
1301 }
1302 if let Some(correlation_id) = &self.correlation_id {
1303 enriched = enriched.with_correlation_id(correlation_id.clone());
1304 }
1305 if let Some(request_id) = &self.request_id {
1306 enriched = enriched.with_request_id(request_id.clone());
1307 }
1308 enriched
1309 }
1310}
1311
1312#[derive(Clone)]
1313struct UploadConfig {
1314 max_file_size: u64,
1315 allowed_content_types: Arc<Vec<String>>,
1316 temp_dir: Arc<PathBuf>,
1317}
1318
1319impl Default for UploadConfig {
1320 fn default() -> Self {
1321 Self {
1322 max_file_size: DEFAULT_MAX_UPLOAD_SIZE,
1323 allowed_content_types: Arc::new(Vec::new()),
1324 temp_dir: Arc::new(std::env::temp_dir().join("shelly-liveview-uploads")),
1325 }
1326 }
1327}
1328
/// Bookkeeping for a single upload.
///
/// NOTE(review): the upload handlers are outside this section; field notes
/// are inferred from names and types — confirm.
struct UploadEntry {
    /// Event name associated with the upload.
    event: String,
    /// Optional event target.
    target: Option<String>,
    /// File name.
    name: String,
    /// Declared size.
    size: u64,
    /// Declared content type, if any.
    content_type: Option<String>,
    /// Filesystem path of the upload.
    path: PathBuf,
    /// Open async file handle.
    file: tokio::fs::File,
    /// Presumably the amount received so far — TODO confirm.
    received: u64,
}
1339
/// Borrowed bundle of per-connection state.
///
/// NOTE(review): presumably handed to the text-message handler; the consumer
/// is outside this section — confirm.
struct TextDispatch<'a> {
    /// Mutable handle to the current route pattern.
    current_route_pattern: &'a mut String,
    /// Registered live routes.
    routes: &'a [LiveRoute],
    target_id: &'a str,
    session_id: &'a str,
    upload_config: &'a UploadConfig,
    /// Upload entries, keyed by a string.
    uploads: &'a mut HashMap<String, UploadEntry>,
    telemetry: &'a Arc<dyn TelemetrySink>,
}
1349
/// Parameters of a request to start an upload.
///
/// NOTE(review): the parser and consumer are outside this section — confirm
/// field semantics there.
struct UploadStartRequest {
    upload_id: String,
    event: String,
    target: Option<String>,
    name: String,
    size: u64,
    content_type: Option<String>,
}
1358
1359impl LiveRoute {
1360 fn new(pattern: String, factory: LiveViewFactory) -> Self {
1361 let segments = route_segments(&pattern);
1362 Self {
1363 pattern,
1364 segments,
1365 factory,
1366 }
1367 }
1368
1369 fn match_path(&self, path: &str) -> Option<MatchedRoute> {
1370 let path_segments = path_segments(path);
1371 if self.segments.len() != path_segments.len() {
1372 return None;
1373 }
1374
1375 let mut params = BTreeMap::new();
1376 for (pattern, actual) in self.segments.iter().zip(path_segments) {
1377 match pattern {
1378 RouteSegment::Static(expected) if expected == actual => {}
1379 RouteSegment::Static(_) => return None,
1380 RouteSegment::Param(name) => {
1381 params.insert(name.clone(), actual.to_string());
1382 }
1383 }
1384 }
1385
1386 Some(MatchedRoute {
1387 pattern: self.pattern.clone(),
1388 path: path.to_string(),
1389 params,
1390 factory: self.factory.clone(),
1391 })
1392 }
1393}
1394
/// Context passed to the rate-limiter hook registered with
/// `ShellyRouter::with_rate_limiter`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitContext {
    /// Route path of the message.
    pub route_path: String,
    /// Session id of the message.
    pub session_id: String,
    /// Kind of message.
    pub message_kind: &'static str,
}
1402
/// Operation a security, quota or overload decision is made for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SecurityOperation {
    Connect,
    Event,
    PatchUrl,
    Navigate,
    UploadStart,
    UploadChunk,
    UploadComplete,
    Ping,
    Binary,
}

impl SecurityOperation {
    /// Returns the snake_case name of the operation.
    fn as_str(self) -> &'static str {
        match self {
            SecurityOperation::Connect => "connect",
            SecurityOperation::Event => "event",
            SecurityOperation::PatchUrl => "patch_url",
            SecurityOperation::Navigate => "navigate",
            SecurityOperation::UploadStart => "upload_start",
            SecurityOperation::UploadChunk => "upload_chunk",
            SecurityOperation::UploadComplete => "upload_complete",
            SecurityOperation::Ping => "ping",
            SecurityOperation::Binary => "binary",
        }
    }

    /// `true` for events, URL patches, navigation and the three upload
    /// operations; `false` for connect, ping and binary.
    fn is_mutating(self) -> bool {
        match self {
            Self::Event
            | Self::PatchUrl
            | Self::Navigate
            | Self::UploadStart
            | Self::UploadChunk
            | Self::UploadComplete => true,
            Self::Connect | Self::Ping | Self::Binary => false,
        }
    }
}
1444
/// Context passed to the authorization hook registered with
/// `ShellyRouter::with_authorization_hook`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationContext {
    pub route_path: String,
    pub session_id: String,
    /// Tenant id, when one is known.
    pub tenant_id: Option<String>,
    pub message_kind: &'static str,
    /// Operation being authorized.
    pub operation: SecurityOperation,
    /// Event name, when applicable.
    pub event_name: Option<String>,
    /// Event target, when applicable.
    pub event_target: Option<String>,
}
1456
/// Outcome of an authorization check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthorizationDecision {
    /// Whether the operation is permitted.
    pub allowed: bool,
    /// Denial code; `None` when allowed.
    pub code: Option<String>,
    /// Denial message; `None` when allowed.
    pub message: Option<String>,
}

impl AuthorizationDecision {
    /// A permitting decision with no code or message.
    pub fn allow() -> Self {
        let (code, message) = (None, None);
        Self {
            allowed: true,
            code,
            message,
        }
    }

    /// A denying decision carrying `code` and `message`.
    pub fn deny(code: impl Into<String>, message: impl Into<String>) -> Self {
        let code = Some(code.into());
        let message = Some(message.into());
        Self {
            allowed: false,
            code,
            message,
        }
    }
}
1482
/// Context passed to quota policies (see `TenantQuotaPolicy::evaluate` and
/// `ShellyRouter::with_quota_policy`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuotaContext {
    pub route_path: String,
    /// Session id; `TenantQuotaPolicy` counts distinct values on connect.
    pub session_id: String,
    /// Tenant id, when one is known.
    pub tenant_id: Option<String>,
    pub message_kind: &'static str,
    /// Operation being evaluated.
    pub operation: SecurityOperation,
    /// Event name, when applicable.
    pub event_name: Option<String>,
}
1493
/// Outcome of a quota check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuotaDecision {
    /// Whether the operation fits within quota.
    pub allowed: bool,
    /// Denial code; `None` when allowed.
    pub code: Option<String>,
    /// Denial message; `None` when allowed.
    pub message: Option<String>,
}

impl QuotaDecision {
    /// A permitting decision with no code or message.
    pub fn allow() -> Self {
        let (code, message) = (None, None);
        Self {
            allowed: true,
            code,
            message,
        }
    }

    /// A denying decision carrying `code` and `message`.
    pub fn deny(code: impl Into<String>, message: impl Into<String>) -> Self {
        let code = Some(code.into());
        let message = Some(message.into());
        Self {
            allowed: false,
            code,
            message,
        }
    }
}
1519
1520#[derive(Debug, Clone, PartialEq, Eq)]
1521pub struct TenantQuotaBudgets {
1522 pub max_sessions_per_window: usize,
1523 pub max_events_per_window: usize,
1524 pub require_tenant_id: bool,
1525}
1526
1527impl Default for TenantQuotaBudgets {
1528 fn default() -> Self {
1529 Self {
1530 max_sessions_per_window: DEFAULT_TENANT_SESSION_QUOTA_PER_WINDOW,
1531 max_events_per_window: DEFAULT_TENANT_EVENT_QUOTA_PER_WINDOW,
1532 require_tenant_id: true,
1533 }
1534 }
1535}
1536
/// Usage counters for one tenant within the current quota window.
#[derive(Debug, Clone)]
struct TenantQuotaWindow {
    /// When the current window began.
    started_at: Instant,
    /// Session ids recorded in this window.
    sessions: HashSet<String>,
    /// Events counted in this window.
    events: usize,
}

impl TenantQuotaWindow {
    /// Creates an empty window starting at `now`.
    fn new(now: Instant) -> Self {
        Self {
            sessions: HashSet::new(),
            events: 0,
            started_at: now,
        }
    }

    /// Resets the window to an empty one starting at `now` once at least
    /// `window` has elapsed since it began; otherwise leaves it untouched.
    fn roll_if_needed(&mut self, now: Instant, window: Duration) {
        let elapsed = now.duration_since(self.started_at);
        if elapsed < window {
            return;
        }
        *self = Self::new(now);
    }
}
1559
/// Mutable state behind `TenantQuotaPolicy`: one usage window per tenant id.
#[derive(Debug, Default, Clone)]
struct TenantQuotaState {
    /// Usage windows keyed by tenant id.
    windows: HashMap<String, TenantQuotaWindow>,
}
1564
1565#[derive(Clone)]
1566pub struct TenantQuotaPolicy {
1567 budgets: TenantQuotaBudgets,
1568 window: Duration,
1569 state: Arc<StdMutex<TenantQuotaState>>,
1570}
1571
1572impl TenantQuotaPolicy {
1573 pub fn new() -> Self {
1574 Self {
1575 budgets: TenantQuotaBudgets::default(),
1576 window: Duration::from_millis(DEFAULT_TENANT_QUOTA_WINDOW_MS),
1577 state: Arc::new(StdMutex::new(TenantQuotaState::default())),
1578 }
1579 }
1580
1581 pub fn with_budgets(mut self, budgets: TenantQuotaBudgets) -> Self {
1582 self.budgets = budgets;
1583 self
1584 }
1585
1586 pub fn with_window_ms(mut self, window_ms: u64) -> Self {
1587 self.window = Duration::from_millis(window_ms.max(1));
1588 self
1589 }
1590
1591 pub fn evaluate(&self, context: &QuotaContext) -> QuotaDecision {
1592 let tenant_id = normalize_tenant_id_ref(context.tenant_id.as_deref());
1593 if self.budgets.require_tenant_id && tenant_id.is_none() {
1594 return QuotaDecision::deny(
1595 "tenant_context_required",
1596 "tenant_id is required for quota policy evaluation",
1597 );
1598 }
1599 let Some(tenant_id) = tenant_id else {
1600 return QuotaDecision::allow();
1601 };
1602
1603 let now = Instant::now();
1604 let Ok(mut state) = self.state.lock() else {
1605 return QuotaDecision::deny(
1606 "tenant_quota_unavailable",
1607 "tenant quota state is unavailable",
1608 );
1609 };
1610 let window = state
1611 .windows
1612 .entry(tenant_id)
1613 .or_insert_with(|| TenantQuotaWindow::new(now));
1614 window.roll_if_needed(now, self.window);
1615
1616 if matches!(context.operation, SecurityOperation::Connect) {
1617 let is_new_session = window.sessions.insert(context.session_id.clone());
1618 if is_new_session && window.sessions.len() > self.budgets.max_sessions_per_window.max(1)
1619 {
1620 return QuotaDecision::deny(
1621 "tenant_session_quota_exceeded",
1622 "tenant session quota exceeded",
1623 );
1624 }
1625 return QuotaDecision::allow();
1626 }
1627
1628 if context.operation.is_mutating() {
1629 let next_events = window.events.saturating_add(1);
1630 if next_events > self.budgets.max_events_per_window.max(1) {
1631 return QuotaDecision::deny(
1632 "tenant_event_quota_exceeded",
1633 "tenant event quota exceeded",
1634 );
1635 }
1636 window.events = next_events;
1637 }
1638
1639 QuotaDecision::allow()
1640 }
1641}
1642
1643impl Default for TenantQuotaPolicy {
1644 fn default() -> Self {
1645 Self::new()
1646 }
1647}
1648
/// Priority class attached to a request for overload decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OverloadPriority {
    /// Rendered as `"interactive"`.
    Interactive,
    /// Rendered as `"background"`.
    Background,
}

impl OverloadPriority {
    /// Returns the lowercase name of the priority.
    fn as_str(self) -> &'static str {
        match self {
            OverloadPriority::Interactive => "interactive",
            OverloadPriority::Background => "background",
        }
    }
}
1664
/// Policy selecting how load is shed under overload.
///
/// NOTE(review): the behavioural difference between the variants is
/// implemented outside this section — confirm before relying on the names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OverloadShedPolicy {
    /// The default policy.
    #[default]
    PreferInteractive,
    Strict,
}
1674
1675#[derive(Debug, Clone, PartialEq, Eq)]
1677pub struct OverloadBudgets {
1678 pub session_queue_depth: usize,
1679 pub session_bytes_per_sec: usize,
1680 pub session_cpu_ms_per_sec: u64,
1681 pub tenant_queue_depth: usize,
1682 pub tenant_bytes_per_sec: usize,
1683 pub tenant_cpu_ms_per_sec: u64,
1684}
1685
1686impl Default for OverloadBudgets {
1687 fn default() -> Self {
1688 Self {
1689 session_queue_depth: DEFAULT_SESSION_QUEUE_DEPTH_BUDGET,
1690 session_bytes_per_sec: DEFAULT_SESSION_BYTES_PER_SEC_BUDGET,
1691 session_cpu_ms_per_sec: DEFAULT_SESSION_CPU_MS_PER_SEC_BUDGET,
1692 tenant_queue_depth: DEFAULT_TENANT_QUEUE_DEPTH_BUDGET,
1693 tenant_bytes_per_sec: DEFAULT_TENANT_BYTES_PER_SEC_BUDGET,
1694 tenant_cpu_ms_per_sec: DEFAULT_TENANT_CPU_MS_PER_SEC_BUDGET,
1695 }
1696 }
1697}
1698
/// Context passed to the overload policy hook registered with
/// `ShellyRouter::with_overload_policy_hook`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OverloadContext {
    pub route_path: String,
    pub session_id: String,
    /// Tenant id, when one is known.
    pub tenant_id: Option<String>,
    pub message_kind: &'static str,
    /// Operation being evaluated.
    pub operation: SecurityOperation,
    /// Event name, when applicable.
    pub event_name: Option<String>,
    /// Priority class of the request.
    pub priority: OverloadPriority,
    /// Queue depth observed for the request.
    pub queue_depth: usize,
    /// Capacity of that queue.
    pub queue_capacity: usize,
    /// Size of the inbound message in bytes.
    pub inbound_bytes: usize,
}
1713
/// Outcome of an overload check: allow, allow with a throttle delay, or shed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OverloadDecision {
    /// Whether the request may proceed.
    pub allowed: bool,
    /// Throttle delay in milliseconds; `0` means none.
    pub throttle_ms: u64,
    /// Machine-readable code; `None` for a plain allow.
    pub code: Option<String>,
    /// Human-readable message; `None` for a plain allow.
    pub message: Option<String>,
    /// Reason supplied by the policy; `None` for a plain allow.
    pub reason: Option<String>,
}

impl OverloadDecision {
    /// Allows the request with no throttling and no metadata.
    pub fn allow() -> Self {
        Self {
            allowed: true,
            throttle_ms: 0,
            code: None,
            message: None,
            reason: None,
        }
    }

    /// Allows the request with a throttle of `ms` milliseconds, code
    /// `overload_throttle`, and the given reason.
    pub fn throttle(ms: u64, reason: impl Into<String>) -> Self {
        let reason = Some(reason.into());
        Self {
            allowed: true,
            throttle_ms: ms,
            code: Some(String::from("overload_throttle")),
            message: Some(String::from("overload throttling applied")),
            reason,
        }
    }

    /// Rejects the request with code `overload_shed` and the given reason.
    pub fn shed(reason: impl Into<String>) -> Self {
        let reason = Some(reason.into());
        Self {
            allowed: false,
            throttle_ms: 0,
            code: Some(String::from("overload_shed")),
            message: Some(String::from("server overloaded; request shed by policy")),
            reason,
        }
    }
}
1755
/// Usage counters accumulated within one overload accounting window.
#[derive(Debug, Clone)]
struct OverloadBudgetWindow {
    /// When the current window began.
    started_at: Instant,
    events: usize,
    bytes: usize,
    cpu_ms: u64,
    queue_depth_peak: usize,
}

impl OverloadBudgetWindow {
    /// Creates a zeroed window starting at `now`.
    fn new(now: Instant) -> Self {
        Self {
            events: 0,
            bytes: 0,
            cpu_ms: 0,
            queue_depth_peak: 0,
            started_at: now,
        }
    }

    /// Resets all counters and restarts the window at `now` once at least
    /// `window` has elapsed since it began; otherwise leaves it untouched.
    fn roll_window_if_needed(&mut self, now: Instant, window: Duration) {
        let elapsed = now.duration_since(self.started_at);
        if elapsed < window {
            return;
        }
        *self = Self::new(now);
    }
}
1782
/// Mutable overload accounting state: one budget window per session and one
/// per tenant.
#[derive(Debug, Default)]
struct OverloadState {
    /// Windows keyed by a string — presumably the session id; TODO confirm.
    session_windows: HashMap<String, OverloadBudgetWindow>,
    /// Windows keyed by a string — presumably the tenant id; TODO confirm.
    tenant_windows: HashMap<String, OverloadBudgetWindow>,
}
1788
1789#[derive(Clone)]
1790struct OverloadConfig {
1791 budgets: OverloadBudgets,
1792 shed_policy: OverloadShedPolicy,
1793 policy_hook: Option<OverloadPolicyHook>,
1794 window: Duration,
1795 state: Arc<Mutex<OverloadState>>,
1796}
1797
1798impl Default for OverloadConfig {
1799 fn default() -> Self {
1800 Self {
1801 budgets: OverloadBudgets::default(),
1802 shed_policy: OverloadShedPolicy::default(),
1803 policy_hook: None,
1804 window: Duration::from_millis(DEFAULT_OVERLOAD_WINDOW_MS),
1805 state: Arc::new(Mutex::new(OverloadState::default())),
1806 }
1807 }
1808}
1809
1810#[derive(Clone)]
1811struct SecurityConfig {
1812 signer: TokenSigner,
1813 allowed_origins: Arc<Vec<String>>,
1814 rate_limiter: Option<RateLimitHook>,
1815 authorization: Option<AuthorizationHook>,
1816 quota_policy: Option<QuotaPolicyHook>,
1817}
1818
1819impl Default for SecurityConfig {
1820 fn default() -> Self {
1821 Self {
1822 signer: TokenSigner::ephemeral(),
1823 allowed_origins: Arc::new(Vec::new()),
1824 rate_limiter: None,
1825 authorization: None,
1826 quota_policy: None,
1827 }
1828 }
1829}
1830
/// Output format for console logging.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsoleLogFormat {
    /// Rendered as `"json"`.
    Json,
    /// Rendered as `"pretty"`.
    Pretty,
}

impl ConsoleLogFormat {
    /// Parses a format name, ignoring surrounding whitespace and ASCII case.
    ///
    /// `pretty`, `text` and `plain` select [`ConsoleLogFormat::Pretty`];
    /// anything else (including the empty string) selects JSON.
    fn parse(raw: &str) -> Self {
        let token = raw.trim();
        let is_pretty = ["pretty", "text", "plain"]
            .iter()
            .any(|alias| token.eq_ignore_ascii_case(alias));
        if is_pretty {
            Self::Pretty
        } else {
            Self::Json
        }
    }

    /// Returns the canonical lowercase name of the format.
    fn as_str(self) -> &'static str {
        match self {
            ConsoleLogFormat::Json => "json",
            ConsoleLogFormat::Pretty => "pretty",
        }
    }
}
1855
1856pub fn init_console_logging(service_name: impl AsRef<str>) -> Result<(), String> {
1863 let format = std::env::var("SHELLY_LOG_FORMAT")
1864 .map(|raw| ConsoleLogFormat::parse(&raw))
1865 .unwrap_or(ConsoleLogFormat::Json);
1866 let log_filter =
1867 std::env::var("RUST_LOG").unwrap_or_else(|_| "shelly_axum=info,info".to_string());
1868 let env_name = std::env::var("SHELLY_ENV").unwrap_or_else(|_| "development".to_string());
1869 let env_filter = EnvFilter::new(log_filter);
1870
1871 match format {
1872 ConsoleLogFormat::Json => tracing_subscriber::fmt()
1873 .with_env_filter(env_filter)
1874 .with_target(true)
1875 .json()
1876 .flatten_event(true)
1877 .try_init(),
1878 ConsoleLogFormat::Pretty => tracing_subscriber::fmt()
1879 .with_env_filter(env_filter)
1880 .with_target(true)
1881 .pretty()
1882 .try_init(),
1883 }
1884 .map_err(|err| format!("failed to initialize Shelly logging subscriber: {err}"))?;
1885
1886 info!(
1887 target: "shelly.startup",
1888 service = %service_name.as_ref(),
1889 environment = %env_name,
1890 log_format = %format.as_str(),
1891 "Shelly logging initialized"
1892 );
1893 Ok(())
1894}
1895
/// Destination format for telemetry events.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TelemetryExporter {
    /// Structured `tracing` fields.
    Tracing,
    /// A JSON payload shaped as an OpenTelemetry-style log record.
    OpenTelemetryJson,
    /// A JSON payload tagged with an Axiom dataset and optional org id.
    AxiomJson {
        dataset: String,
        org_id: Option<String>,
    },
}

/// Telemetry settings: the reporting service name and the exporter to use.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TelemetryConfig {
    pub service_name: String,
    pub exporter: TelemetryExporter,
}

impl TelemetryConfig {
    /// Configuration using [`TelemetryExporter::Tracing`].
    pub fn tracing(service_name: impl Into<String>) -> Self {
        let service_name = service_name.into();
        Self {
            service_name,
            exporter: TelemetryExporter::Tracing,
        }
    }

    /// Configuration using [`TelemetryExporter::OpenTelemetryJson`].
    pub fn otel_json(service_name: impl Into<String>) -> Self {
        let service_name = service_name.into();
        Self {
            service_name,
            exporter: TelemetryExporter::OpenTelemetryJson,
        }
    }

    /// Configuration using [`TelemetryExporter::AxiomJson`] with the given
    /// dataset and optional org id.
    pub fn axiom_json(
        service_name: impl Into<String>,
        dataset: impl Into<String>,
        org_id: Option<String>,
    ) -> Self {
        let service_name = service_name.into();
        let exporter = TelemetryExporter::AxiomJson {
            dataset: dataset.into(),
            org_id,
        };
        Self {
            service_name,
            exporter,
        }
    }
}
1946
1947#[derive(Clone)]
1948struct TelemetryPipeline {
1949 config: Option<TelemetryConfig>,
1950}
1951
1952impl TelemetryPipeline {
1953 fn disabled() -> Self {
1954 Self { config: None }
1955 }
1956
1957 fn enabled(config: TelemetryConfig) -> Self {
1958 Self {
1959 config: Some(config),
1960 }
1961 }
1962}
1963
impl TelemetrySink for TelemetryPipeline {
    /// Emits `event` through the configured exporter.
    ///
    /// Does nothing and returns `Ok(())` when the pipeline has no
    /// configuration. Every exporter writes through `tracing::info!`; they
    /// differ in target and payload shape. This implementation always
    /// returns `Ok(())`.
    fn emit(&self, event: TelemetryEvent) -> Result<(), String> {
        let Some(config) = &self.config else {
            return Ok(());
        };
        // Attributes go through `redacted_attributes` before export.
        // NOTE(review): presumably this masks sensitive values — confirm in
        // the helper.
        let attributes = redacted_attributes(&event.attributes);

        match &config.exporter {
            TelemetryExporter::Tracing => {
                // Flat structured fields. Missing optional strings are
                // logged as "-" and missing optional numbers as 0.
                tracing::info!(
                    target: "shelly.telemetry",
                    schema = "shelly.telemetry.v1",
                    service = %config.service_name,
                    kind = ?event.kind,
                    trace_id = event.trace_id.as_deref().unwrap_or("-"),
                    span_id = event.span_id.as_deref().unwrap_or("-"),
                    parent_span_id = event.parent_span_id.as_deref().unwrap_or("-"),
                    correlation_id = event.correlation_id.as_deref().unwrap_or("-"),
                    request_id = event.request_id.as_deref().unwrap_or("-"),
                    session_id = event.session_id.as_deref().unwrap_or("-"),
                    route_path = event.route_path.as_deref().unwrap_or("-"),
                    event_name = event.event_name.as_deref().unwrap_or("-"),
                    ok = event.ok,
                    latency_ms = event.latency_ms.unwrap_or(0),
                    bytes = event.bytes.unwrap_or(0),
                    count = event.count.unwrap_or(0),
                    attributes = ?attributes,
                    "shelly telemetry event"
                );
            }
            TelemetryExporter::OpenTelemetryJson => {
                // One JSON document logged as the `payload` field; optional
                // values are serialized as-is rather than replaced.
                let payload = serde_json::json!({
                    "schema": "shelly.telemetry.v1",
                    "resource": {
                        "service.name": config.service_name,
                    },
                    "otel_log": {
                        "severity_text": "INFO",
                        "body": "shelly telemetry event",
                        "trace_id": event.trace_id,
                        "span_id": event.span_id,
                        "attributes": {
                            "event.kind": event.kind,
                            "event.session_id": event.session_id,
                            "event.route_path": event.route_path,
                            "event.event_name": event.event_name,
                            "event.ok": event.ok,
                            "event.latency_ms": event.latency_ms,
                            "event.bytes": event.bytes,
                            "event.count": event.count,
                            "event.parent_span_id": event.parent_span_id,
                            "event.correlation_id": event.correlation_id,
                            "event.request_id": event.request_id,
                            "event.attributes": attributes,
                        }
                    }
                });
                tracing::info!(target: "shelly.telemetry.otel", payload = %payload);
            }
            TelemetryExporter::AxiomJson { dataset, org_id } => {
                // One JSON document logged as the `payload` field, tagged
                // with the dataset, org id and a `_time` timestamp.
                let payload = serde_json::json!({
                    "schema": "shelly.telemetry.v1",
                    "_time": chrono_like_timestamp(),
                    "dataset": dataset,
                    "org_id": org_id,
                    "service": config.service_name,
                    "event": {
                        "kind": event.kind,
                        "trace_id": event.trace_id,
                        "span_id": event.span_id,
                        "parent_span_id": event.parent_span_id,
                        "correlation_id": event.correlation_id,
                        "request_id": event.request_id,
                        "session_id": event.session_id,
                        "route_path": event.route_path,
                        "event_name": event.event_name,
                        "ok": event.ok,
                        "latency_ms": event.latency_ms,
                        "bytes": event.bytes,
                        "count": event.count,
                        "attributes": attributes,
                    },
                });
                tracing::info!(target: "shelly.telemetry.axiom", payload = %payload);
            }
        }

        Ok(())
    }
}
2054
2055#[derive(Clone)]
2056struct SessionTelemetrySink {
2057 inner: Arc<TelemetryPipeline>,
2058 correlation: CorrelationContext,
2059}
2060
2061impl SessionTelemetrySink {
2062 fn new(inner: Arc<TelemetryPipeline>, correlation: CorrelationContext) -> Self {
2063 Self { inner, correlation }
2064 }
2065}
2066
2067impl TelemetrySink for SessionTelemetrySink {
2068 fn emit(&self, event: TelemetryEvent) -> Result<(), String> {
2069 self.inner.emit(self.correlation.apply_to_event(event))
2070 }
2071}
2072
/// Holds the secret bytes used for signing tokens; clones share the same
/// secret through the `Arc`.
///
/// NOTE(review): the file imports `Hmac` and `Sha256`, so this is presumably
/// an HMAC-SHA256 key — confirm in the signing methods, which are outside
/// this section.
#[derive(Clone)]
struct TokenSigner {
    secret: Arc<Vec<u8>>,
}
2077
/// Session claims: session id, path and optional node id.
///
/// NOTE(review): presumably the payload of a signed session token; encoding
/// and verification are outside this section — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SignedSession {
    session_id: String,
    path: String,
    node_id: Option<String>,
}
2084
/// Resume claims: session id, path and a nonce.
///
/// NOTE(review): presumably the payload of a signed resume token; encoding
/// and verification are outside this section — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
struct SignedResume {
    session_id: String,
    path: String,
    nonce: String,
}
2091
/// Borrowed bundle of values describing a page shell: target element,
/// initial HTML, session and CSRF tokens, tracing ids, client timings and
/// transport settings.
///
/// NOTE(review): the renderer that consumes this struct is outside this
/// section — confirm how each value is embedded and escaped.
struct ShellConfig<'a> {
    target_id: &'a str,
    inner_html: &'a str,
    path: &'a str,
    title: &'a str,
    session_id: &'a str,
    session_token: &'a str,
    csrf_token: &'a str,
    protocol: &'a str,
    trace_id: &'a str,
    span_id: &'a str,
    correlation_id: Option<&'a str>,
    request_id: Option<&'a str>,
    // Client reconnect / heartbeat timings, in milliseconds.
    reconnect_base_ms: u64,
    reconnect_max_ms: u64,
    reconnect_jitter_ms: u64,
    heartbeat_interval_ms: u64,
    heartbeat_timeout_ms: u64,
    // Transport settings, already rendered as strings.
    transport_mode: &'a str,
    transport_fallbacks: &'a str,
    progressive_enhancement: bool,
}
2114
2115impl Default for ShellyRouter {
2116 fn default() -> Self {
2117 Self::new()
2118 }
2119}
2120
2121impl ShellyRouter {
2122 pub fn new() -> Self {
2124 Self {
2125 routes: Vec::new(),
2126 target_id: "shelly-root".to_string(),
2127 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
2128 pubsub: PubSub::default(),
2129 uploads: UploadConfig::default(),
2130 security: SecurityConfig::default(),
2131 telemetry: Arc::new(TelemetryPipeline::disabled()),
2132 reconnect: ReconnectConfig::default(),
2133 distributed: DistributedConfig::default(),
2134 durable: DurableRuntimeConfig::default(),
2135 outbound: OutboundConfig::default(),
2136 render: RenderConfig::default(),
2137 overload: OverloadConfig::default(),
2138 transport_http: HttpTransportConfig::default(),
2139 }
2140 }
2141
2142 pub fn with_target_id(mut self, target_id: impl Into<String>) -> Self {
2144 self.target_id = target_id.into();
2145 self
2146 }
2147
    /// Sets the maximum message size (`DEFAULT_MAX_MESSAGE_SIZE` by default).
    ///
    /// The value is stored as given; no bounds are applied here.
    pub fn with_max_message_size(mut self, max_message_size: usize) -> Self {
        self.max_message_size = max_message_size;
        self
    }
2153
    /// Replaces the router's `PubSub` instance with `pubsub`.
    pub fn with_pubsub(mut self, pubsub: PubSub) -> Self {
        self.pubsub = pubsub;
        self
    }
2159
    /// Sets the maximum upload file size (`DEFAULT_MAX_UPLOAD_SIZE` by
    /// default). The value is stored as given.
    pub fn with_max_upload_size(mut self, max_file_size: u64) -> Self {
        self.uploads.max_file_size = max_file_size;
        self
    }
2165
2166 pub fn with_allowed_upload_content_type(mut self, content_type: impl Into<String>) -> Self {
2168 let mut allowed = (*self.uploads.allowed_content_types).clone();
2169 allowed.push(content_type.into());
2170 self.uploads.allowed_content_types = Arc::new(allowed);
2171 self
2172 }
2173
2174 pub fn with_upload_temp_dir(mut self, temp_dir: impl Into<PathBuf>) -> Self {
2176 self.uploads.temp_dir = Arc::new(temp_dir.into());
2177 self
2178 }
2179
2180 pub fn with_secret(mut self, secret: impl Into<Vec<u8>>) -> Self {
2182 self.security.signer = TokenSigner::new(secret);
2183 self
2184 }
2185
2186 pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
2190 let mut origins = (*self.security.allowed_origins).clone();
2191 origins.push(origin.into());
2192 self.security.allowed_origins = Arc::new(origins);
2193 self
2194 }
2195
2196 pub fn with_rate_limiter<F>(mut self, rate_limiter: F) -> Self
2201 where
2202 F: Fn(&RateLimitContext) -> bool + Send + Sync + 'static,
2203 {
2204 self.security.rate_limiter = Some(Arc::new(rate_limiter));
2205 self
2206 }
2207
2208 pub fn with_authorization_hook<F>(mut self, authorization: F) -> Self
2214 where
2215 F: Fn(&AuthorizationContext) -> AuthorizationDecision + Send + Sync + 'static,
2216 {
2217 self.security.authorization = Some(Arc::new(authorization));
2218 self
2219 }
2220
2221 pub fn with_quota_policy<F>(mut self, quota_policy: F) -> Self
2227 where
2228 F: Fn(&QuotaContext) -> QuotaDecision + Send + Sync + 'static,
2229 {
2230 self.security.quota_policy = Some(Arc::new(quota_policy));
2231 self
2232 }
2233
2234 pub fn with_tenant_quota_policy(mut self, tenant_quota_policy: TenantQuotaPolicy) -> Self {
2236 self.security
2237 .quota_policy
2238 .replace(Arc::new(move |ctx| tenant_quota_policy.evaluate(ctx)));
2239 self
2240 }
2241
2242 pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self {
2244 self.telemetry = Arc::new(TelemetryPipeline::enabled(config));
2245 self
2246 }
2247
2248 pub fn with_reconnect_backoff(mut self, base_ms: u64, max_ms: u64, jitter_ms: u64) -> Self {
2250 self.reconnect.client.reconnect_base_ms = base_ms.max(100);
2251 self.reconnect.client.reconnect_max_ms =
2252 max_ms.max(self.reconnect.client.reconnect_base_ms);
2253 self.reconnect.client.reconnect_jitter_ms = jitter_ms;
2254 self
2255 }
2256
2257 pub fn with_heartbeat(mut self, interval_ms: u64, timeout_ms: u64) -> Self {
2259 self.reconnect.client.heartbeat_interval_ms = interval_ms.max(1_000);
2260 self.reconnect.client.heartbeat_timeout_ms = timeout_ms.max(1_000);
2261 self
2262 }
2263
2264 pub fn with_resume_ttl_ms(mut self, ttl_ms: u64) -> Self {
2266 self.reconnect.resume_ttl = Duration::from_millis(ttl_ms.max(1_000));
2267 self
2268 }
2269
2270 pub fn with_connect_handshake_timeout_ms(mut self, timeout_ms: u64) -> Self {
2272 self.reconnect.connect_handshake_timeout = Duration::from_millis(timeout_ms.max(500));
2273 self
2274 }
2275
2276 pub fn with_http3_alt_svc(mut self, alt_svc: impl Into<String>) -> Self {
2280 let alt_svc = alt_svc.into();
2281 let alt_svc = alt_svc.trim();
2282 self.transport_http.http3_alt_svc = if alt_svc.is_empty() {
2283 None
2284 } else {
2285 Some(alt_svc.to_string())
2286 };
2287 self
2288 }
2289
2290 pub fn with_http3_default_alt_svc(mut self) -> Self {
2292 self.transport_http.http3_alt_svc = Some(DEFAULT_HTTP3_ALT_SVC.to_string());
2293 self
2294 }
2295
    /// Enables or disables the transport diagnostics header (off by
    /// default).
    pub fn with_transport_diagnostics_header(mut self, enabled: bool) -> Self {
        self.transport_http.emit_diagnostics_header = enabled;
        self
    }
2301
2302 pub fn with_transport_modes(
2308 mut self,
2309 primary: TransportMode,
2310 fallback_modes: impl IntoIterator<Item = TransportMode>,
2311 ) -> Self {
2312 let mut fallbacks = Vec::new();
2313 for mode in fallback_modes {
2314 if mode != primary && !fallbacks.contains(&mode) {
2315 fallbacks.push(mode);
2316 }
2317 }
2318 self.transport_http.primary_mode = primary;
2319 self.transport_http.fallback_modes = fallbacks;
2320 self
2321 }
2322
2323 pub fn with_progressive_transport_fallbacks(mut self) -> Self {
2325 self.transport_http.fallback_modes =
2326 vec![TransportMode::ServerSentEvents, TransportMode::LongPoll];
2327 self.transport_http.progressive_enhancement = true;
2328 self
2329 }
2330
    /// Sets the progressive-enhancement flag (on by default).
    pub fn with_progressive_enhancement(mut self, enabled: bool) -> Self {
        self.transport_http.progressive_enhancement = enabled;
        self
    }
2336
2337 pub fn with_node_id(mut self, node_id: impl Into<String>) -> Self {
2339 self.distributed.node_id = node_id.into();
2340 self
2341 }
2342
    /// Sets the session affinity mode (`SessionAffinityMode::Disabled` by
    /// default).
    pub fn with_session_affinity_mode(mut self, mode: SessionAffinityMode) -> Self {
        self.distributed.session_affinity = mode;
        self
    }
2348
    /// Installs a durable session store (none by default).
    pub fn with_durable_session_store(mut self, store: Arc<dyn DurableSessionStore>) -> Self {
        self.durable.store = Some(store);
        self
    }
2354
2355 pub fn with_durable_lease_ttl_ms(mut self, ttl_ms: u64) -> Self {
2357 self.durable.lease_ttl = Duration::from_millis(ttl_ms.max(1_000));
2358 self
2359 }
2360
2361 pub fn with_durable_journal_limit(mut self, max_entries: usize) -> Self {
2363 self.durable.journal_limit = max_entries.max(1);
2364 self
2365 }
2366
    /// Sets the durable takeover policy
    /// (`DurableTakeoverPolicy::default()` by default).
    pub fn with_durable_takeover_policy(mut self, policy: DurableTakeoverPolicy) -> Self {
        self.durable.takeover_policy = policy;
        self
    }
2372
    /// Sets the durable drain flag (off by default).
    ///
    /// `into_router` passes this flag to the durable store's
    /// `set_node_draining` when a store is configured.
    pub fn with_durable_drain_mode(mut self, draining: bool) -> Self {
        self.durable.drain_mode = draining;
        self
    }
2380
2381 pub fn with_durable_placement_hook<F>(mut self, placement_hook: F) -> Self
2383 where
2384 F: Fn(&DurablePlacementContext) -> DurablePlacementDecision + Send + Sync + 'static,
2385 {
2386 self.durable.placement_hook = Some(Arc::new(placement_hook));
2387 self
2388 }
2389
2390 pub fn with_outbound_queue_capacity(mut self, queue_capacity: usize) -> Self {
2392 self.outbound.queue_capacity = queue_capacity.max(1);
2393 self
2394 }
2395
2396 pub fn with_outbound_batching(
2398 mut self,
2399 batch_max_messages: usize,
2400 batch_max_bytes: usize,
2401 flush_interval_ms: u64,
2402 ) -> Self {
2403 self.outbound.batch_max_messages = batch_max_messages.max(1);
2404 self.outbound.batch_max_bytes = batch_max_bytes.max(256);
2405 self.outbound.batch_flush_interval = Duration::from_millis(flush_interval_ms.max(1));
2406 self
2407 }
2408
2409 pub fn with_outbound_overflow_policy(mut self, policy: OutboundOverflowPolicy) -> Self {
2411 self.outbound.overflow_policy = policy;
2412 self
2413 }
2414
2415 pub fn with_default_render_cadence_ms(mut self, cadence_ms: u64) -> Self {
2420 self.render.default_cadence_ms = cadence_ms;
2421 self
2422 }
2423
2424 pub fn with_overload_budgets(mut self, budgets: OverloadBudgets) -> Self {
2426 self.overload.budgets = budgets;
2427 self
2428 }
2429
2430 pub fn with_overload_shed_policy(mut self, policy: OverloadShedPolicy) -> Self {
2432 self.overload.shed_policy = policy;
2433 self
2434 }
2435
2436 pub fn with_overload_policy_hook<F>(mut self, policy_hook: F) -> Self
2438 where
2439 F: Fn(&OverloadContext, &OverloadDecision) -> OverloadDecision + Send + Sync + 'static,
2440 {
2441 self.overload.policy_hook = Some(Arc::new(policy_hook));
2442 self
2443 }
2444
2445 pub fn live<V, F>(mut self, path: impl Into<String>, factory: F) -> Self
2447 where
2448 V: LiveView,
2449 F: Fn() -> V + Send + Sync + 'static,
2450 {
2451 let factory: LiveViewFactory = Arc::new(move || Box::new(factory()) as Box<dyn LiveView>);
2452 let pattern = normalize_path(path.into());
2453 self.routes.push(LiveRoute::new(pattern, factory));
2454 self
2455 }
2456
2457 pub fn into_router(self) -> Router {
2459 if let Some(store) = self.durable.store.as_ref() {
2460 store.set_node_draining(&self.distributed.node_id, self.durable.drain_mode);
2461 }
2462 let state = Arc::new(AppState {
2463 routes: Arc::new(self.routes),
2464 target_id: self.target_id,
2465 max_message_size: self.max_message_size,
2466 pubsub: self.pubsub,
2467 uploads: self.uploads,
2468 security: self.security,
2469 telemetry: self.telemetry,
2470 reconnect: self.reconnect,
2471 distributed: self.distributed,
2472 durable: self.durable,
2473 outbound: self.outbound,
2474 render: self.render,
2475 overload: self.overload,
2476 transport_http: self.transport_http,
2477 });
2478
2479 Router::new()
2480 .route("/__shelly/client.js", get(client_js))
2481 .route("/__shelly/transport", get(transport_capabilities))
2482 .route("/__shelly/ws/{*view_path}", get(ws_handler))
2483 .fallback(get(initial_handler))
2484 .layer(TraceLayer::new_for_http())
2485 .with_state(state)
2486 }
2487}
2488
/// Ensures a route path starts with `/`.
///
/// An empty path becomes `/`, a path that already starts with `/` is returned
/// untouched, and anything else gets a `/` prepended.
fn normalize_path(path: String) -> String {
    match path.chars().next() {
        None => "/".to_string(),
        Some('/') => path,
        Some(_) => format!("/{path}"),
    }
}
2500
2501fn route_segments(path: &str) -> Vec<RouteSegment> {
2502 path.trim_matches('/')
2503 .split('/')
2504 .filter(|segment| !segment.is_empty())
2505 .map(|segment| {
2506 if let Some(name) = segment.strip_prefix(':') {
2507 RouteSegment::Param(name.to_string())
2508 } else {
2509 RouteSegment::Static(segment.to_string())
2510 }
2511 })
2512 .collect()
2513}
2514
/// Returns the non-empty `/`-separated segments of `path`, borrowed from the input.
fn path_segments(path: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    for piece in path.split('/') {
        if !piece.is_empty() {
            segments.push(piece);
        }
    }
    segments
}
2521
2522fn log_http_ingress(
2523 endpoint: &'static str,
2524 method: &Method,
2525 uri: &Uri,
2526 headers: &HeaderMap,
2527) -> CorrelationContext {
2528 let correlation = correlation_from_headers(headers);
2529 let user_agent = headers
2530 .get(header::USER_AGENT)
2531 .and_then(|value| value.to_str().ok())
2532 .unwrap_or("-");
2533 let origin = headers
2534 .get(header::ORIGIN)
2535 .and_then(|value| value.to_str().ok())
2536 .unwrap_or("-");
2537 let host = headers
2538 .get(header::HOST)
2539 .and_then(|value| value.to_str().ok())
2540 .unwrap_or("-");
2541 info!(
2542 target: "shelly.incoming.http",
2543 schema = "shelly.incoming.http.v1",
2544 endpoint,
2545 method = %method,
2546 path = uri.path(),
2547 query = sanitize_query_for_log(uri.query()),
2548 trace_id = %correlation.trace_id,
2549 span_id = %correlation.span_id,
2550 parent_span_id = correlation.parent_span_id.as_deref().unwrap_or("-"),
2551 correlation_id = correlation.correlation_id.as_deref().unwrap_or("-"),
2552 request_id = correlation.request_id.as_deref().unwrap_or("-"),
2553 user_agent,
2554 origin,
2555 host,
2556 "Shelly HTTP request received"
2557 );
2558 correlation
2559}
2560
2561fn correlation_from_headers(headers: &HeaderMap) -> CorrelationContext {
2562 let request_id = header_value(headers, "x-request-id");
2563 let correlation_id = header_value(headers, "x-correlation-id").or_else(|| request_id.clone());
2564 let explicit_trace_id = header_value(headers, "x-trace-id")
2565 .and_then(|value| normalize_hex_id(&value, 32))
2566 .or_else(|| {
2567 header_value(headers, "trace-id").and_then(|value| normalize_hex_id(&value, 32))
2568 });
2569 let (trace_from_parent, parent_span_id) = headers
2570 .get("traceparent")
2571 .and_then(|value| value.to_str().ok())
2572 .and_then(parse_traceparent)
2573 .unwrap_or((String::new(), String::new()));
2574
2575 let trace_id = if let Some(value) = explicit_trace_id {
2576 value
2577 } else if !trace_from_parent.is_empty() {
2578 trace_from_parent
2579 } else {
2580 generate_trace_id()
2581 };
2582 let parent_span_id = if parent_span_id.is_empty() {
2583 None
2584 } else {
2585 Some(parent_span_id)
2586 };
2587
2588 CorrelationContext {
2589 trace_id,
2590 span_id: generate_span_id(),
2591 parent_span_id,
2592 correlation_id,
2593 request_id,
2594 }
2595}
2596
2597fn correlation_from_query(
2598 query: &HashMap<String, String>,
2599 fallback: &CorrelationContext,
2600) -> CorrelationContext {
2601 let request_id = query
2602 .get("request_id")
2603 .and_then(|value| non_empty_text(value))
2604 .or_else(|| fallback.request_id.clone());
2605 let correlation_id = query
2606 .get("correlation_id")
2607 .and_then(|value| non_empty_text(value))
2608 .or_else(|| fallback.correlation_id.clone())
2609 .or_else(|| request_id.clone());
2610 let trace_id = query
2611 .get("trace_id")
2612 .and_then(|value| normalize_hex_id(value, 32))
2613 .unwrap_or_else(|| fallback.trace_id.clone());
2614 let parent_span_id = query
2615 .get("parent_span_id")
2616 .and_then(|value| normalize_hex_id(value, 16))
2617 .or_else(|| Some(fallback.span_id.clone()));
2618
2619 CorrelationContext {
2620 trace_id,
2621 span_id: generate_span_id(),
2622 parent_span_id,
2623 correlation_id,
2624 request_id,
2625 }
2626}
2627
2628fn header_value(headers: &HeaderMap, name: &str) -> Option<String> {
2629 headers
2630 .get(name)
2631 .and_then(|value| value.to_str().ok())
2632 .and_then(non_empty_text)
2633}
2634
/// Trims `value` and returns it as an owned string, or `None` when nothing
/// remains after trimming.
fn non_empty_text(value: &str) -> Option<String> {
    match value.trim() {
        "" => None,
        trimmed => Some(trimmed.to_owned()),
    }
}
2643
/// Normalizes a hexadecimal identifier.
///
/// Dashes are skipped, every other character must be an ASCII hex digit, and
/// the digits are lower-cased. Returns `None` when a non-hex character is found
/// or when the number of digits differs from `expected_len`.
fn normalize_hex_id(value: &str, expected_len: usize) -> Option<String> {
    let digits = value
        .chars()
        .filter(|ch| *ch != '-')
        .map(|ch| ch.is_ascii_hexdigit().then(|| ch.to_ascii_lowercase()))
        .collect::<Option<String>>()?;
    (digits.len() == expected_len).then_some(digits)
}
2661
2662fn parse_traceparent(value: &str) -> Option<(String, String)> {
2663 let mut parts = value.trim().split('-');
2664 let _version = parts.next()?;
2665 let trace_id = normalize_hex_id(parts.next()?, 32)?;
2666 let span_id = normalize_hex_id(parts.next()?, 16)?;
2667 let _flags = parts.next()?;
2668 Some((trace_id, span_id))
2669}
2670
2671fn generate_trace_id() -> String {
2672 Uuid::new_v4().simple().to_string()
2673}
2674
2675fn generate_span_id() -> String {
2676 let hex = Uuid::new_v4().simple().to_string();
2677 hex[..16].to_string()
2678}
2679
2680fn sanitize_query_for_log(query: Option<&str>) -> String {
2681 let Some(query) = query else {
2682 return String::new();
2683 };
2684 if query.is_empty() {
2685 return String::new();
2686 }
2687
2688 query
2689 .split('&')
2690 .filter(|pair| !pair.is_empty())
2691 .map(|pair| {
2692 let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
2693 if is_sensitive_key(key) {
2694 format!("{key}=<redacted>")
2695 } else if value.is_empty() {
2696 key.to_string()
2697 } else {
2698 format!("{key}={value}")
2699 }
2700 })
2701 .collect::<Vec<_>>()
2702 .join("&")
2703}
2704
/// Reports whether a key names a value that must not be logged.
///
/// The comparison ignores ASCII case. A key is sensitive when it equals one of
/// the fixed names below, contains `secret` anywhere (which also covers a
/// `_secret` suffix), or ends with `_token`.
fn is_sensitive_key(key: &str) -> bool {
    const EXACT_NAMES: [&str; 6] = [
        "session",
        "csrf",
        "token",
        "authorization",
        "api_key",
        "password",
    ];

    let lowered = key.to_ascii_lowercase();
    EXACT_NAMES.contains(&lowered.as_str())
        || lowered.contains("secret")
        || lowered.ends_with("_token")
}
2714
2715fn redacted_attributes(
2716 attributes: &serde_json::Map<String, JsonValue>,
2717) -> serde_json::Map<String, JsonValue> {
2718 attributes
2719 .iter()
2720 .map(|(key, value)| {
2721 if is_sensitive_key(key) {
2722 (key.clone(), JsonValue::String("<redacted>".to_string()))
2723 } else {
2724 (key.clone(), value.clone())
2725 }
2726 })
2727 .collect()
2728}
2729
2730fn apply_http_transport_headers(headers: &mut HeaderMap, config: &HttpTransportConfig) {
2731 if let Some(value) = config.http3_alt_svc.as_deref() {
2732 match header::HeaderValue::from_str(value) {
2733 Ok(value) => {
2734 headers.insert("alt-svc", value);
2735 }
2736 Err(_) => {
2737 warn!(value, "invalid Alt-Svc header value configured");
2738 }
2739 }
2740 }
2741
2742 if config.emit_diagnostics_header {
2743 let diagnostics = config.diagnostics_header();
2744 if let Ok(value) = header::HeaderValue::from_str(&diagnostics) {
2745 headers.insert("x-shelly-transport", value);
2746 }
2747 let progressive = if config.progressive_enhancement {
2748 header::HeaderValue::from_static("enabled")
2749 } else {
2750 header::HeaderValue::from_static("disabled")
2751 };
2752 headers.insert("x-shelly-progressive-enhancement", progressive);
2753 }
2754}
2755
2756fn with_http_transport_headers(mut response: Response, config: &HttpTransportConfig) -> Response {
2757 apply_http_transport_headers(response.headers_mut(), config);
2758 response
2759}
2760
2761async fn client_js(
2762 State(state): State<Arc<AppState>>,
2763 method: Method,
2764 headers: HeaderMap,
2765 uri: Uri,
2766) -> impl IntoResponse {
2767 let _ = log_http_ingress("client_js", &method, &uri, &headers);
2768 with_http_transport_headers(
2769 (
2770 [("content-type", "text/javascript; charset=utf-8")],
2771 CLIENT_JS,
2772 )
2773 .into_response(),
2774 &state.transport_http,
2775 )
2776}
2777
2778async fn transport_capabilities(
2779 State(state): State<Arc<AppState>>,
2780 method: Method,
2781 headers: HeaderMap,
2782 uri: Uri,
2783) -> Response {
2784 let _ = log_http_ingress("transport_capabilities", &method, &uri, &headers);
2785 let capabilities = TransportCapabilities {
2786 protocol: PROTOCOL_VERSION,
2787 primary: state.transport_http.primary_mode.as_str(),
2788 fallbacks: state.transport_http.fallback_names(),
2789 progressive_enhancement: state.transport_http.progressive_enhancement,
2790 http3_advertised: state.transport_http.http3_alt_svc.is_some(),
2791 };
2792 with_http_transport_headers(Json(capabilities).into_response(), &state.transport_http)
2793}
2794
2795async fn initial_handler(
2796 State(state): State<Arc<AppState>>,
2797 method: Method,
2798 headers: HeaderMap,
2799 uri: Uri,
2800) -> Response {
2801 let correlation = log_http_ingress("initial_render", &method, &uri, &headers);
2802 let path = uri.path();
2803 let Some(route) = state.route_for(path) else {
2804 warn!(path, "Shelly route not found");
2805 return with_http_transport_headers(
2806 (
2807 StatusCode::NOT_FOUND,
2808 format!("No Shelly route registered for {path}"),
2809 )
2810 .into_response(),
2811 &state.transport_http,
2812 );
2813 };
2814
2815 debug!(path, "rendering initial Shelly route");
2816 let mut session = LiveSession::new_with_route(
2817 (route.factory)(),
2818 state.target_id.clone(),
2819 route.path.clone(),
2820 route.params,
2821 );
2822 let session_telemetry: Arc<dyn TelemetrySink> = Arc::new(SessionTelemetrySink::new(
2823 state.telemetry.clone(),
2824 correlation.with_new_span(),
2825 ));
2826 session.set_telemetry_sink(session_telemetry);
2827 if let Err(err) = session.mount() {
2828 error!(path, ?err, "Shelly mount failed during initial render");
2829 return with_http_transport_headers(
2830 (
2831 StatusCode::INTERNAL_SERVER_ERROR,
2832 format!("Shelly mount failed: {err}"),
2833 )
2834 .into_response(),
2835 &state.transport_http,
2836 );
2837 }
2838
2839 let session_token =
2840 state
2841 .security
2842 .signer
2843 .sign_session(session.session_id(), path, &state.distributed.node_id);
2844 let csrf_token = state.security.signer.sign_csrf(session.session_id(), path);
2845 let transport_fallbacks = state.transport_http.fallback_json();
2846 let body = render_shell(ShellConfig {
2847 target_id: &state.target_id,
2848 inner_html: session.render_html().as_str(),
2849 path,
2850 title: "Shelly LiveView",
2851 session_id: session.session_id(),
2852 session_token: &session_token,
2853 csrf_token: &csrf_token,
2854 protocol: PROTOCOL_VERSION,
2855 trace_id: &correlation.trace_id,
2856 span_id: &correlation.span_id,
2857 correlation_id: correlation.correlation_id.as_deref(),
2858 request_id: correlation.request_id.as_deref(),
2859 reconnect_base_ms: state.reconnect.client.reconnect_base_ms,
2860 reconnect_max_ms: state.reconnect.client.reconnect_max_ms,
2861 reconnect_jitter_ms: state.reconnect.client.reconnect_jitter_ms,
2862 heartbeat_interval_ms: state.reconnect.client.heartbeat_interval_ms,
2863 heartbeat_timeout_ms: state.reconnect.client.heartbeat_timeout_ms,
2864 transport_mode: state.transport_http.primary_mode.as_str(),
2865 transport_fallbacks: &transport_fallbacks,
2866 progressive_enhancement: state.transport_http.progressive_enhancement,
2867 });
2868
2869 with_http_transport_headers(AxumHtml(body).into_response(), &state.transport_http)
2870}
2871
2872async fn ws_handler(
2873 Path(view_path): Path<String>,
2874 State(state): State<Arc<AppState>>,
2875 method: Method,
2876 headers: HeaderMap,
2877 uri: Uri,
2878 ws: WebSocketUpgrade,
2879) -> Response {
2880 let route_path = if view_path == "__root__" {
2881 "/".to_string()
2882 } else {
2883 normalize_path(view_path)
2884 };
2885 let query = parse_query(uri.query().unwrap_or_default());
2886 let ingress_correlation = log_http_ingress("ws_upgrade", &method, &uri, &headers);
2887 let socket_correlation = correlation_from_query(&query, &ingress_correlation);
2888 info!(
2889 target: "shelly.incoming.http",
2890 schema = "shelly.incoming.http.v1",
2891 transport = "ws",
2892 stage = "upgrade_request",
2893 route = %route_path,
2894 trace_id = %socket_correlation.trace_id,
2895 span_id = %socket_correlation.span_id,
2896 parent_span_id = socket_correlation.parent_span_id.as_deref().unwrap_or("-"),
2897 correlation_id = socket_correlation.correlation_id.as_deref().unwrap_or("-"),
2898 request_id = socket_correlation.request_id.as_deref().unwrap_or("-"),
2899 protocol = query
2900 .get("protocol")
2901 .map(String::as_str)
2902 .unwrap_or("unknown"),
2903 has_session_token = query.contains_key("session"),
2904 has_csrf_token = query.contains_key("csrf"),
2905 "Shelly websocket upgrade request received"
2906 );
2907
2908 let Some(route) = state.route_for(&route_path) else {
2909 warn!(path = route_path, "Shelly websocket route not found");
2910 return with_http_transport_headers(
2911 (
2912 StatusCode::NOT_FOUND,
2913 format!("No Shelly route registered for {route_path}"),
2914 )
2915 .into_response(),
2916 &state.transport_http,
2917 );
2918 };
2919
2920 if !origin_allowed(&headers, &state.security) {
2921 emit_security_audit(
2922 &state.telemetry,
2923 &socket_correlation,
2924 &route_path,
2925 None,
2926 None,
2927 "origin_check",
2928 false,
2929 Some("origin_rejected"),
2930 SecurityOperation::Connect,
2931 "text",
2932 None,
2933 Some("origin rejected by websocket policy"),
2934 );
2935 warn!(
2936 path = route_path,
2937 "Shelly websocket rejected cross-origin request"
2938 );
2939 return with_http_transport_headers(
2940 (
2941 StatusCode::FORBIDDEN,
2942 "Shelly WebSocket origin rejected".to_string(),
2943 )
2944 .into_response(),
2945 &state.transport_http,
2946 );
2947 }
2948
2949 if query.get("protocol").map(String::as_str) != Some(PROTOCOL_VERSION) {
2950 emit_security_audit(
2951 &state.telemetry,
2952 &socket_correlation,
2953 &route_path,
2954 None,
2955 None,
2956 "protocol_check",
2957 false,
2958 Some("unsupported_protocol"),
2959 SecurityOperation::Connect,
2960 "text",
2961 None,
2962 Some("unsupported websocket protocol version"),
2963 );
2964 warn!(
2965 path = route_path,
2966 "Shelly websocket rejected unsupported protocol"
2967 );
2968 return with_http_transport_headers(
2969 (
2970 StatusCode::BAD_REQUEST,
2971 format!("Shelly WebSocket requires protocol {PROTOCOL_VERSION}"),
2972 )
2973 .into_response(),
2974 &state.transport_http,
2975 );
2976 }
2977
2978 let Some(signed_session) = verify_websocket_tokens(&state.security.signer, &query, &route_path)
2979 else {
2980 emit_security_audit(
2981 &state.telemetry,
2982 &socket_correlation,
2983 &route_path,
2984 None,
2985 None,
2986 "token_check",
2987 false,
2988 Some("signed_session_required"),
2989 SecurityOperation::Connect,
2990 "text",
2991 None,
2992 Some("signed session or csrf token validation failed"),
2993 );
2994 warn!(
2995 path = route_path,
2996 "Shelly websocket rejected unsigned or invalid session"
2997 );
2998 return with_http_transport_headers(
2999 (
3000 StatusCode::UNAUTHORIZED,
3001 "Shelly WebSocket requires a valid signed session".to_string(),
3002 )
3003 .into_response(),
3004 &state.transport_http,
3005 );
3006 };
3007 if let Some(affinity_context) =
3008 session_affinity_mismatch(&state.distributed, &signed_session, &route_path)
3009 {
3010 emit_security_audit(
3011 &state.telemetry,
3012 &socket_correlation,
3013 &route_path,
3014 Some(&affinity_context.session_id),
3015 None,
3016 "session_affinity",
3017 false,
3018 Some("session_affinity_mismatch"),
3019 SecurityOperation::Connect,
3020 "text",
3021 None,
3022 Some("session affinity policy rejected websocket ownership"),
3023 );
3024 warn!(
3025 path = affinity_context.route_path,
3026 session_id = affinity_context.session_id,
3027 current_node_id = affinity_context.current_node_id,
3028 token_node_id = affinity_context.token_node_id.as_deref().unwrap_or("-"),
3029 "Shelly websocket rejected session-affinity mismatch"
3030 );
3031 return with_http_transport_headers(
3032 (
3033 StatusCode::CONFLICT,
3034 "Shelly WebSocket session-affinity mismatch".to_string(),
3035 )
3036 .into_response(),
3037 &state.transport_http,
3038 );
3039 }
3040
3041 let config = SocketConfig {
3042 routes: state.routes.clone(),
3043 target_id: state.target_id.clone(),
3044 route_path: route_path.clone(),
3045 signed_session_id: signed_session.session_id,
3046 max_message_size: state.max_message_size,
3047 pubsub: state.pubsub.clone(),
3048 uploads: state.uploads.clone(),
3049 security: state.security.clone(),
3050 telemetry: state.telemetry.clone(),
3051 correlation: socket_correlation,
3052 reconnect: state.reconnect.clone(),
3053 distributed: state.distributed.clone(),
3054 durable: state.durable.clone(),
3055 outbound: state.outbound.clone(),
3056 render: state.render.clone(),
3057 overload: state.overload.clone(),
3058 };
3059 with_http_transport_headers(
3060 ws.on_upgrade(move |socket| handle_socket(socket, route, config))
3061 .into_response(),
3062 &state.transport_http,
3063 )
3064}
3065
3066fn cleanup_expired_snapshots(snapshots: &mut HashMap<String, ResumeSnapshot>) {
3067 let now = Instant::now();
3068 snapshots.retain(|_, snapshot| snapshot.expires_at > now);
3069}
3070
3071fn session_affinity_mismatch(
3072 distributed: &DistributedConfig,
3073 signed_session: &SignedSession,
3074 route_path: &str,
3075) -> Option<SessionAffinityContext> {
3076 if distributed.session_affinity != SessionAffinityMode::Required {
3077 return None;
3078 }
3079 if signed_session.node_id.as_deref() == Some(distributed.node_id.as_str()) {
3080 return None;
3081 }
3082 Some(SessionAffinityContext {
3083 session_id: signed_session.session_id.clone(),
3084 route_path: route_path.to_string(),
3085 current_node_id: distributed.node_id.clone(),
3086 token_node_id: signed_session.node_id.clone(),
3087 })
3088}
3089
3090async fn take_resume_snapshot(
3091 reconnect: &ReconnectConfig,
3092 session_id: &str,
3093) -> Option<ResumeSnapshot> {
3094 let mut snapshots = reconnect.snapshots.lock().await;
3095 cleanup_expired_snapshots(&mut snapshots);
3096 snapshots.remove(session_id)
3097}
3098
3099async fn put_resume_snapshot(
3100 reconnect: &ReconnectConfig,
3101 session_id: String,
3102 snapshot: ResumeSnapshot,
3103) {
3104 let mut snapshots = reconnect.snapshots.lock().await;
3105 cleanup_expired_snapshots(&mut snapshots);
3106 snapshots.insert(session_id, snapshot);
3107}
3108
3109async fn stash_reconnect_snapshot(
3110 reconnect: &ReconnectConfig,
3111 session_id: String,
3112 session: LiveSession,
3113 route_pattern: String,
3114 resume_token: Option<String>,
3115) {
3116 let Some(resume_token) = resume_token else {
3117 return;
3118 };
3119 let snapshot = ResumeSnapshot {
3120 session,
3121 route_pattern,
3122 resume_token,
3123 expires_at: Instant::now() + reconnect.resume_ttl,
3124 };
3125 put_resume_snapshot(reconnect, session_id, snapshot).await;
3126}
3127
/// Lease details kept after a durable lease has been acquired for a session.
///
/// Both values are copied from the lease returned by the durable store and are
/// passed back to the store when journaling entries or releasing the lease.
#[derive(Debug, Clone)]
struct DurableLeaseHandle {
    // Node id reported by the store as the lease owner.
    owner_node_id: String,
    // Fence token reported by the store for this lease.
    fence_token: u64,
}
3133
/// Outcome of rebuilding a session from a durable record.
struct DurableRecoveryResult {
    // Resume snapshot holding the rebuilt session.
    snapshot: ResumeSnapshot,
    // Number of journal entries that were replayed into the session.
    replayed_entries: usize,
    // Owner node id recorded in the stored snapshot the session was rebuilt from.
    source_owner_node_id: String,
}
3139
3140fn durable_error_message(error: &DurableStoreError) -> ServerMessage {
3141 ServerMessage::Error {
3142 message: error.message.clone(),
3143 code: Some(error.code.clone()),
3144 }
3145}
3146
3147fn durable_message_should_journal(message: &ClientMessage) -> bool {
3148 match message {
3149 ClientMessage::Event { event, .. } => event != INTERNAL_RENDER_FLUSH_EVENT,
3150 ClientMessage::PatchUrl { .. } | ClientMessage::Navigate { .. } => true,
3151 _ => false,
3152 }
3153}
3154
3155fn contains_server_error(messages: &[ServerMessage]) -> bool {
3156 messages
3157 .iter()
3158 .any(|message| matches!(message, ServerMessage::Error { .. }))
3159}
3160
3161fn durable_save_snapshot(
3162 durable: &DurableRuntimeConfig,
3163 lease: &DurableLeaseHandle,
3164 session: &LiveSession,
3165 route_pattern: &str,
3166 target_id: &str,
3167 resume_token: &str,
3168) {
3169 let Some(store) = durable.store.as_ref() else {
3170 return;
3171 };
3172 store.save_snapshot(
3173 session.session_id(),
3174 DurableSessionSnapshot {
3175 route_path: session.route_path().to_string(),
3176 route_pattern: route_pattern.to_string(),
3177 target_id: target_id.to_string(),
3178 revision: session.revision(),
3179 resume_token: resume_token.to_string(),
3180 owner_node_id: lease.owner_node_id.clone(),
3181 updated_at_unix_ms: now_unix_ms(),
3182 },
3183 );
3184}
3185
3186fn durable_append_journal_entry(
3187 durable: &DurableRuntimeConfig,
3188 lease: &DurableLeaseHandle,
3189 session_id: &str,
3190 message: &ClientMessage,
3191) -> Result<Option<DurableJournalEntry>, DurableStoreError> {
3192 if !durable_message_should_journal(message) {
3193 return Ok(None);
3194 }
3195 let Some(store) = durable.store.as_ref() else {
3196 return Ok(None);
3197 };
3198 let entry = store.append_journal_entry(
3199 session_id,
3200 &lease.owner_node_id,
3201 lease.fence_token,
3202 message.clone(),
3203 durable.journal_limit,
3204 )?;
3205 Ok(Some(entry))
3206}
3207
3208fn durable_release_lease(
3209 durable: &DurableRuntimeConfig,
3210 session_id: &str,
3211 lease: Option<&DurableLeaseHandle>,
3212) {
3213 let (Some(store), Some(lease)) = (durable.store.as_ref(), lease) else {
3214 return;
3215 };
3216 store.release_lease(session_id, &lease.owner_node_id, lease.fence_token);
3217}
3218
3219fn durable_acquire_lease(
3220 durable: &DurableRuntimeConfig,
3221 distributed: &DistributedConfig,
3222 session_id: &str,
3223 route_path: &str,
3224) -> Result<Option<DurableLeaseHandle>, DurableStoreError> {
3225 let Some(store) = durable.store.as_ref() else {
3226 return Ok(None);
3227 };
3228
3229 if durable.drain_mode || store.is_node_draining(&distributed.node_id) {
3230 return Err(DurableStoreError::new(
3231 "node_draining",
3232 format!(
3233 "node {} is draining and cannot accept durable sessions",
3234 distributed.node_id
3235 ),
3236 ));
3237 }
3238
3239 if let Some(placement_hook) = durable.placement_hook.as_ref() {
3240 let preferred_node_id = store
3241 .load_record(session_id)
3242 .map(|record| record.snapshot.owner_node_id);
3243 let decision = placement_hook(&DurablePlacementContext {
3244 session_id: session_id.to_string(),
3245 route_path: route_path.to_string(),
3246 current_node_id: distributed.node_id.clone(),
3247 preferred_node_id,
3248 });
3249 if !decision.allowed {
3250 return Err(DurableStoreError::new(
3251 decision
3252 .code
3253 .unwrap_or_else(|| "placement_rejected".to_string()),
3254 decision
3255 .message
3256 .unwrap_or_else(|| "durable placement policy rejected ownership".to_string()),
3257 ));
3258 }
3259 }
3260
3261 let lease = store.acquire_lease(DurableLeaseRequest {
3262 session_id: session_id.to_string(),
3263 node_id: distributed.node_id.clone(),
3264 ttl_ms: durable.lease_ttl.as_millis() as u64,
3265 takeover_policy: durable.takeover_policy,
3266 })?;
3267 Ok(Some(DurableLeaseHandle {
3268 owner_node_id: lease.owner_node_id,
3269 fence_token: lease.fence_token,
3270 }))
3271}
3272
3273fn replay_durable_journal_entry(
3274 session: &mut LiveSession,
3275 current_route_pattern: &mut String,
3276 routes: &[LiveRoute],
3277 target_id: &str,
3278 session_id: &str,
3279 telemetry: &Arc<dyn TelemetrySink>,
3280 entry: &DurableJournalEntry,
3281) -> Result<(), String> {
3282 let messages = match &entry.message {
3283 ClientMessage::PatchUrl { to } => {
3284 handle_patch_url(session, current_route_pattern, routes, to)
3285 }
3286 ClientMessage::Navigate { to } => handle_navigate(
3287 session,
3288 current_route_pattern,
3289 routes,
3290 target_id,
3291 session_id,
3292 telemetry,
3293 to,
3294 ),
3295 ClientMessage::Event { .. } => session.handle_client_message(entry.message.clone()),
3296 _ => Vec::new(),
3297 };
3298 if let Some(ServerMessage::Error { code, message }) = messages
3299 .into_iter()
3300 .find(|message| matches!(message, ServerMessage::Error { .. }))
3301 {
3302 return Err(format!(
3303 "durable journal replay failed: code={} message={}",
3304 code.unwrap_or_else(|| "-".to_string()),
3305 message
3306 ));
3307 }
3308 Ok(())
3309}
3310
3311fn recover_from_durable_record(
3312 config: &SocketConfig,
3313 session_id: &str,
3314 telemetry: &Arc<dyn TelemetrySink>,
3315) -> Option<DurableRecoveryResult> {
3316 let store = config.durable.store.as_ref()?;
3317 let record = store.load_record(session_id)?;
3318 let DurableSessionRecord {
3319 snapshot: stored_snapshot,
3320 mut journal,
3321 } = record;
3322 let recovered_path = normalize_path(stored_snapshot.route_path.clone());
3323 let matched_route = config
3324 .routes
3325 .iter()
3326 .find_map(|route| route.match_path(&recovered_path))
3327 .or_else(|| {
3328 config
3329 .routes
3330 .iter()
3331 .find_map(|route| route.match_path(&config.route_path))
3332 })?;
3333
3334 let mut current_route_pattern = matched_route.pattern.clone();
3335 let mut session = LiveSession::new_with_route_and_session_id(
3336 (matched_route.factory)(),
3337 session_id.to_string(),
3338 config.target_id.clone(),
3339 matched_route.path,
3340 matched_route.params,
3341 );
3342 session.set_telemetry_sink(telemetry.clone());
3343 if session.mount().is_err() {
3344 return None;
3345 }
3346
3347 let mut replayed_entries = 0usize;
3348 journal.sort_by_key(|entry| entry.sequence);
3349 for entry in journal {
3350 if replay_durable_journal_entry(
3351 &mut session,
3352 &mut current_route_pattern,
3353 &config.routes,
3354 &config.target_id,
3355 session_id,
3356 telemetry,
3357 &entry,
3358 )
3359 .is_err()
3360 {
3361 return None;
3362 }
3363 replayed_entries += 1;
3364 }
3365
3366 Some(DurableRecoveryResult {
3367 snapshot: ResumeSnapshot {
3368 session,
3369 route_pattern: current_route_pattern,
3370 resume_token: stored_snapshot.resume_token.clone(),
3371 expires_at: Instant::now() + config.reconnect.resume_ttl,
3372 },
3373 replayed_entries,
3374 source_owner_node_id: stored_snapshot.owner_node_id,
3375 })
3376}
3377
3378fn correlation_from_connect(
3379 base: &CorrelationContext,
3380 connect: &ConnectHandshake,
3381) -> CorrelationContext {
3382 let trace_id = connect
3383 .trace_id
3384 .as_deref()
3385 .and_then(|value| normalize_hex_id(value, 32))
3386 .unwrap_or_else(|| base.trace_id.clone());
3387 let parent_span_id = connect
3388 .span_id
3389 .as_deref()
3390 .and_then(|value| normalize_hex_id(value, 16))
3391 .or_else(|| {
3392 connect
3393 .parent_span_id
3394 .as_deref()
3395 .and_then(|value| normalize_hex_id(value, 16))
3396 })
3397 .or_else(|| Some(base.span_id.clone()));
3398 let correlation_id = connect
3399 .correlation_id
3400 .as_deref()
3401 .and_then(non_empty_text)
3402 .or_else(|| base.correlation_id.clone());
3403 let request_id = connect
3404 .request_id
3405 .as_deref()
3406 .and_then(non_empty_text)
3407 .or_else(|| base.request_id.clone());
3408
3409 CorrelationContext {
3410 trace_id,
3411 span_id: generate_span_id(),
3412 parent_span_id,
3413 correlation_id,
3414 request_id,
3415 }
3416}
3417
3418fn reconnect_event(
3419 session_id: &str,
3420 route_path: &str,
3421 correlation: &CorrelationContext,
3422 ok: bool,
3423 status: ResumeStatus,
3424 reason: Option<&str>,
3425) -> TelemetryEvent {
3426 let mut event = TelemetryEvent::new(TelemetryEventKind::Connect)
3427 .with_session(session_id.to_string())
3428 .with_route(route_path.to_string())
3429 .with_ok(ok)
3430 .with_attribute(
3431 "resume_status".to_string(),
3432 JsonValue::String(
3433 match status {
3434 ResumeStatus::Fresh => "fresh",
3435 ResumeStatus::Resumed => "resumed",
3436 ResumeStatus::Fallback => "fallback",
3437 }
3438 .to_string(),
3439 ),
3440 );
3441 event = correlation.apply_to_event(event);
3442 if let Some(reason) = reason {
3443 event = event.with_attribute(
3444 "resume_reason".to_string(),
3445 JsonValue::String(reason.to_string()),
3446 );
3447 }
3448 event
3449}
3450
3451async fn read_connect_handshake(
3452 receiver: &mut futures_util::stream::SplitStream<WebSocket>,
3453 timeout_window: Duration,
3454) -> Result<ConnectHandshake, ServerMessage> {
3455 let maybe_frame = timeout(timeout_window, receiver.next())
3456 .await
3457 .map_err(|_| ServerMessage::Error {
3458 message: "connect handshake timed out".to_string(),
3459 code: Some("connect_timeout".to_string()),
3460 })?;
3461 let Some(frame) = maybe_frame else {
3462 return Err(ServerMessage::Error {
3463 message: "websocket closed before connect handshake".to_string(),
3464 code: Some("connect_missing".to_string()),
3465 });
3466 };
3467 let frame = frame.map_err(|err| ServerMessage::Error {
3468 message: format!("websocket receive error during connect handshake: {err}"),
3469 code: Some("connect_io_error".to_string()),
3470 })?;
3471 let Message::Text(text) = frame else {
3472 return Err(ServerMessage::Error {
3473 message: "first websocket frame must be a connect message".to_string(),
3474 code: Some("connect_required".to_string()),
3475 });
3476 };
3477
3478 let parsed = serde_json::from_str::<ClientMessage>(text.as_str()).map_err(|err| {
3479 ServerMessage::Error {
3480 message: format!("invalid protocol message: {err}"),
3481 code: Some("invalid_protocol".to_string()),
3482 }
3483 })?;
3484
3485 let ClientMessage::Connect {
3486 protocol,
3487 session_id,
3488 last_revision,
3489 resume_token,
3490 tenant_id,
3491 trace_id,
3492 span_id,
3493 parent_span_id,
3494 correlation_id,
3495 request_id,
3496 } = parsed
3497 else {
3498 return Err(ServerMessage::Error {
3499 message: "first websocket message must be type connect".to_string(),
3500 code: Some("connect_required".to_string()),
3501 });
3502 };
3503
3504 if protocol != PROTOCOL_VERSION {
3505 return Err(ServerMessage::Error {
3506 message: format!(
3507 "unsupported protocol in connect: expected {PROTOCOL_VERSION}, got {protocol}"
3508 ),
3509 code: Some("unsupported_protocol".to_string()),
3510 });
3511 }
3512
3513 Ok(ConnectHandshake {
3514 client_session_id: session_id,
3515 client_revision: last_revision.unwrap_or(0),
3516 resume_token,
3517 tenant_id: normalize_tenant_id(tenant_id),
3518 trace_id,
3519 span_id,
3520 parent_span_id,
3521 correlation_id,
3522 request_id,
3523 })
3524}
3525
3526async fn handle_socket(socket: WebSocket, route: MatchedRoute, config: SocketConfig) {
3527 let (mut sender, mut receiver) = socket.split();
3528 let mut current_route_pattern: String;
3529 let (pubsub_tx, mut pubsub_rx) = mpsc::unbounded_channel();
3530 let (runtime_tx, mut runtime_rx) = mpsc::unbounded_channel();
3531 let mut pubsub_tasks = Vec::new();
3532 let mut runtime_tasks = HashMap::new();
3533 let mut subscribed_topics = HashSet::new();
3534 let mut uploads = HashMap::new();
3535 let session_id = config.signed_session_id.clone();
3536 let connect =
3537 match read_connect_handshake(&mut receiver, config.reconnect.connect_handshake_timeout)
3538 .await
3539 {
3540 Ok(connect) => connect,
3541 Err(message) => {
3542 emit_security_audit(
3543 &config.telemetry,
3544 &config.correlation,
3545 &config.route_path,
3546 Some(&session_id),
3547 None,
3548 "connect_handshake",
3549 false,
3550 Some("connect_required"),
3551 SecurityOperation::Connect,
3552 "text",
3553 None,
3554 Some("first websocket message was not a valid connect handshake"),
3555 );
3556 warn!(
3557 path = config.route_path,
3558 session_id, "Shelly websocket connect handshake failed"
3559 );
3560 let _ = send_server_message(&mut sender, &message).await;
3561 return;
3562 }
3563 };
3564
3565 if connect
3566 .client_session_id
3567 .as_deref()
3568 .map(|id| id != session_id.as_str())
3569 .unwrap_or(false)
3570 {
3571 emit_security_audit(
3572 &config.telemetry,
3573 &config.correlation,
3574 &config.route_path,
3575 Some(&session_id),
3576 connect.tenant_id.as_deref(),
3577 "connect_session_match",
3578 false,
3579 Some("session_mismatch"),
3580 SecurityOperation::Connect,
3581 "text",
3582 None,
3583 Some("connect session id does not match signed session"),
3584 );
3585 let _ = send_server_message(
3586 &mut sender,
3587 &ServerMessage::Error {
3588 message: "connect session id does not match signed session".to_string(),
3589 code: Some("session_mismatch".to_string()),
3590 },
3591 )
3592 .await;
3593 return;
3594 }
3595 let connection_correlation = correlation_from_connect(&config.correlation, &connect);
3596 let session_telemetry: Arc<dyn TelemetrySink> = Arc::new(SessionTelemetrySink::new(
3597 config.telemetry.clone(),
3598 connection_correlation.clone(),
3599 ));
3600 if let Some(rejection) = quota_denied(
3601 &config.security,
3602 &config.route_path,
3603 &session_id,
3604 connect.tenant_id.as_deref(),
3605 "text",
3606 SecurityOperation::Connect,
3607 None,
3608 ) {
3609 emit_security_audit(
3610 &config.telemetry,
3611 &connection_correlation,
3612 &config.route_path,
3613 Some(&session_id),
3614 connect.tenant_id.as_deref(),
3615 "quota_policy",
3616 false,
3617 Some("quota_exceeded"),
3618 SecurityOperation::Connect,
3619 "text",
3620 None,
3621 Some("tenant/session quota policy rejected connect handshake"),
3622 );
3623 let _ = send_server_message(&mut sender, &rejection).await;
3624 return;
3625 }
3626 if let Some(rejection) = authorization_denied(
3627 &config.security,
3628 AuthorizationInput {
3629 route_path: &config.route_path,
3630 session_id: &session_id,
3631 tenant_id: connect.tenant_id.as_deref(),
3632 message_kind: "text",
3633 operation: SecurityOperation::Connect,
3634 event_name: None,
3635 event_target: None,
3636 },
3637 ) {
3638 emit_security_audit(
3639 &config.telemetry,
3640 &connection_correlation,
3641 &config.route_path,
3642 Some(&session_id),
3643 connect.tenant_id.as_deref(),
3644 "authorization_policy",
3645 false,
3646 Some("unauthorized"),
3647 SecurityOperation::Connect,
3648 "text",
3649 None,
3650 Some("tenant/session authorization policy rejected connect handshake"),
3651 );
3652 let _ = send_server_message(&mut sender, &rejection).await;
3653 return;
3654 }
3655 emit_security_audit(
3656 &config.telemetry,
3657 &connection_correlation,
3658 &config.route_path,
3659 Some(&session_id),
3660 connect.tenant_id.as_deref(),
3661 "connect_gate",
3662 true,
3663 Some("allowed"),
3664 SecurityOperation::Connect,
3665 "text",
3666 None,
3667 Some("connect gate accepted websocket session"),
3668 );
3669
3670 let durable_lease = match durable_acquire_lease(
3671 &config.durable,
3672 &config.distributed,
3673 &session_id,
3674 &config.route_path,
3675 ) {
3676 Ok(lease) => lease,
3677 Err(err) => {
3678 warn!(
3679 path = config.route_path,
3680 session_id,
3681 code = err.code,
3682 message = err.message,
3683 "Shelly durable ownership rejected websocket connect"
3684 );
3685 let _ = send_server_message(&mut sender, &durable_error_message(&err)).await;
3686 return;
3687 }
3688 };
3689
3690 let mut resume_status = ResumeStatus::Fresh;
3691 let mut resume_reason: Option<String> = None;
3692 let hello_revision: u64;
3693 let mut hello_server_revision: Option<u64> = None;
3694 let mut reconciliation_patch: Option<ServerMessage> = None;
3695 let mut recovered_from_durable = false;
3696 let mut durable_replay_metadata: Option<(usize, String)> = None;
3697
3698 let recoverable_snapshot = if let Some(snapshot) =
3699 take_resume_snapshot(&config.reconnect, &session_id).await
3700 {
3701 Some(snapshot)
3702 } else if let Some(recovery) =
3703 recover_from_durable_record(&config, &session_id, &session_telemetry)
3704 {
3705 recovered_from_durable = true;
3706 durable_replay_metadata = Some((recovery.replayed_entries, recovery.source_owner_node_id));
3707 Some(recovery.snapshot)
3708 } else {
3709 None
3710 };
3711
3712 let mut session = if let Some(snapshot) = recoverable_snapshot {
3713 let token_reason = match connect.resume_token.as_deref() {
3714 None => Some("resume_token_missing"),
3715 Some(token) => match config.security.signer.verify_resume(token) {
3716 None => Some("resume_token_invalid"),
3717 Some(signed_resume) => {
3718 if signed_resume.session_id != session_id
3719 || signed_resume.path != config.route_path
3720 {
3721 Some("resume_token_scope_mismatch")
3722 } else if token != snapshot.resume_token {
3723 Some("resume_token_stale")
3724 } else {
3725 None
3726 }
3727 }
3728 },
3729 };
3730
3731 if let Some(reason) = token_reason {
3732 resume_status = ResumeStatus::Fallback;
3733 resume_reason = Some(reason.to_string());
3734 let mut fresh = LiveSession::new_with_route_and_session_id(
3735 (route.factory)(),
3736 session_id.clone(),
3737 config.target_id.clone(),
3738 route.path.clone(),
3739 route.params.clone(),
3740 );
3741 fresh.set_tenant_id_optional(connect.tenant_id.clone());
3742 fresh.set_telemetry_sink(session_telemetry.clone());
3743 if let Err(err) = fresh.mount() {
3744 error!(
3745 path = config.route_path,
3746 session_id,
3747 ?err,
3748 "Shelly websocket mount failed"
3749 );
3750 let _ = send_server_message(
3751 &mut sender,
3752 &ServerMessage::Error {
3753 message: format!("mount failed: {err}"),
3754 code: Some("mount_failed".to_string()),
3755 },
3756 )
3757 .await;
3758 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
3759 return;
3760 }
3761 current_route_pattern = route.pattern.clone();
3762 hello_revision = fresh.revision();
3763 fresh
3764 } else {
3765 let mut resumed = snapshot.session;
3766 resumed.set_telemetry_sink(session_telemetry.clone());
3767 let server_revision = resumed.revision();
3768 if connect.client_revision > server_revision {
3769 resume_status = ResumeStatus::Fallback;
3770 resume_reason = Some("resume_revision_ahead".to_string());
3771 let mut fresh = LiveSession::new_with_route_and_session_id(
3772 (route.factory)(),
3773 session_id.clone(),
3774 config.target_id.clone(),
3775 route.path.clone(),
3776 route.params.clone(),
3777 );
3778 fresh.set_tenant_id_optional(connect.tenant_id.clone());
3779 fresh.set_telemetry_sink(session_telemetry.clone());
3780 if let Err(err) = fresh.mount() {
3781 error!(
3782 path = config.route_path,
3783 session_id,
3784 ?err,
3785 "Shelly websocket mount failed"
3786 );
3787 let _ = send_server_message(
3788 &mut sender,
3789 &ServerMessage::Error {
3790 message: format!("mount failed: {err}"),
3791 code: Some("mount_failed".to_string()),
3792 },
3793 )
3794 .await;
3795 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
3796 return;
3797 }
3798 current_route_pattern = route.pattern.clone();
3799 hello_revision = fresh.revision();
3800 fresh
3801 } else {
3802 resume_status = ResumeStatus::Resumed;
3803 current_route_pattern = snapshot.route_pattern;
3804 if connect.client_revision < server_revision {
3805 hello_revision = connect.client_revision;
3806 hello_server_revision = Some(server_revision);
3807 reconciliation_patch = Some(resumed.render_snapshot_patch());
3808 if recovered_from_durable {
3809 if let Some((replayed_entries, source_owner)) =
3810 durable_replay_metadata.as_ref()
3811 {
3812 resume_reason = Some(format!(
3813 "durable_revision_reconciled:from={source_owner}:entries={replayed_entries}"
3814 ));
3815 } else {
3816 resume_reason = Some("durable_revision_reconciled".to_string());
3817 }
3818 } else {
3819 resume_reason = Some("revision_reconciled".to_string());
3820 }
3821 } else {
3822 hello_revision = server_revision;
3823 if recovered_from_durable {
3824 if let Some((replayed_entries, source_owner)) =
3825 durable_replay_metadata.as_ref()
3826 {
3827 resume_reason = Some(format!(
3828 "durable_replay:from={source_owner}:entries={replayed_entries}"
3829 ));
3830 } else {
3831 resume_reason = Some("durable_replay".to_string());
3832 }
3833 }
3834 }
3835 resumed
3836 }
3837 }
3838 } else {
3839 if connect.resume_token.is_some() || connect.client_revision > 0 {
3840 resume_status = ResumeStatus::Fallback;
3841 resume_reason = Some("resume_snapshot_missing".to_string());
3842 }
3843 let mut fresh = LiveSession::new_with_route_and_session_id(
3844 (route.factory)(),
3845 session_id.clone(),
3846 config.target_id.clone(),
3847 route.path.clone(),
3848 route.params.clone(),
3849 );
3850 fresh.set_tenant_id_optional(connect.tenant_id.clone());
3851 fresh.set_telemetry_sink(session_telemetry.clone());
3852 if let Err(err) = fresh.mount() {
3853 error!(
3854 path = config.route_path,
3855 session_id,
3856 ?err,
3857 "Shelly websocket mount failed"
3858 );
3859 let _ = send_server_message(
3860 &mut sender,
3861 &ServerMessage::Error {
3862 message: format!("mount failed: {err}"),
3863 code: Some("mount_failed".to_string()),
3864 },
3865 )
3866 .await;
3867 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
3868 return;
3869 }
3870 current_route_pattern = route.pattern.clone();
3871 hello_revision = fresh.revision();
3872 fresh
3873 };
3874 if tenant_context_conflict(session.tenant_id(), connect.tenant_id.as_deref()) {
3875 emit_security_audit(
3876 &config.telemetry,
3877 &connection_correlation,
3878 &config.route_path,
3879 Some(&session_id),
3880 session.tenant_id(),
3881 "tenant_policy",
3882 false,
3883 Some("tenant_mismatch"),
3884 SecurityOperation::Connect,
3885 "text",
3886 None,
3887 Some("connect tenant_id does not match restored session tenant context"),
3888 );
3889 let _ = send_server_message(
3890 &mut sender,
3891 &ServerMessage::Error {
3892 message: "connect tenant_id does not match restored session tenant context"
3893 .to_string(),
3894 code: Some("tenant_mismatch".to_string()),
3895 },
3896 )
3897 .await;
3898 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
3899 return;
3900 }
3901 if session.tenant_id().is_none() {
3902 session.set_tenant_id_optional(connect.tenant_id.clone());
3903 }
3904 session.set_default_render_cadence_ms(config.render.default_cadence_ms);
3905
3906 let resume_token = if matches!(resume_status, ResumeStatus::Resumed) {
3907 connect.resume_token.clone().unwrap_or_else(|| {
3908 config
3909 .security
3910 .signer
3911 .sign_resume(session.session_id(), session.route_path())
3912 })
3913 } else {
3914 config
3915 .security
3916 .signer
3917 .sign_resume(session.session_id(), session.route_path())
3918 };
3919 let resume_expires_in_ms = config.reconnect.resume_ttl.as_millis() as u64;
3920 let hello = ServerMessage::Hello {
3921 session_id: session.session_id().to_string(),
3922 target: config.target_id.clone(),
3923 revision: hello_revision,
3924 protocol: PROTOCOL_VERSION.to_string(),
3925 server_revision: hello_server_revision,
3926 resume_status: Some(resume_status),
3927 resume_reason: resume_reason.clone(),
3928 resume_token: Some(resume_token.clone()),
3929 resume_expires_in_ms: Some(resume_expires_in_ms),
3930 };
3931
3932 process_pubsub_commands(
3933 session.drain_pubsub_commands(),
3934 &config.pubsub,
3935 &pubsub_tx,
3936 &mut pubsub_tasks,
3937 &mut subscribed_topics,
3938 &session_id,
3939 &config.distributed.node_id,
3940 &config.telemetry,
3941 &connection_correlation,
3942 );
3943 process_runtime_commands(
3944 session.drain_runtime_commands(),
3945 &runtime_tx,
3946 &mut runtime_tasks,
3947 );
3948
3949 info!(
3950 path = config.route_path,
3951 trace_id = %connection_correlation.trace_id,
3952 span_id = %connection_correlation.span_id,
3953 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
3954 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
3955 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
3956 session_id, "Shelly websocket connected"
3957 );
3958 let _ = config.telemetry.emit(reconnect_event(
3959 &session_id,
3960 session.route_path(),
3961 &connection_correlation,
3962 matches!(resume_status, ResumeStatus::Fresh | ResumeStatus::Resumed),
3963 resume_status,
3964 resume_reason.as_deref(),
3965 ));
3966
3967 let mut active_resume_token = Some(resume_token.clone());
3968 let (outbound_tx, outbound_rx) = mpsc::channel(config.outbound.queue_capacity);
3969 let mut writer_task = Some(tokio::spawn(run_outbound_writer(
3970 sender,
3971 outbound_rx,
3972 config.outbound.clone(),
3973 config.telemetry.clone(),
3974 connection_correlation.clone(),
3975 config.route_path.clone(),
3976 session_id.clone(),
3977 )));
3978 if let (Some(lease), Some(token)) = (durable_lease.as_ref(), active_resume_token.as_deref()) {
3979 durable_save_snapshot(
3980 &config.durable,
3981 lease,
3982 &session,
3983 ¤t_route_pattern,
3984 &config.target_id,
3985 token,
3986 );
3987 }
3988
3989 match queue_server_message(
3990 &outbound_tx,
3991 &hello,
3992 &config.outbound,
3993 &config.telemetry,
3994 &connection_correlation,
3995 &config.route_path,
3996 &session_id,
3997 ) {
3998 OutboundQueuePush::Queued | OutboundQueuePush::Dropped => {}
3999 OutboundQueuePush::Disconnect => {
4000 warn!(
4001 path = config.route_path,
4002 session_id, "Shelly websocket closed before hello"
4003 );
4004 if let Some(task) = writer_task.take() {
4005 task.abort();
4006 let _ = task.await;
4007 }
4008 unregister_pubsub_presence(
4009 &config.pubsub,
4010 &subscribed_topics,
4011 &session_id,
4012 &config.distributed.node_id,
4013 );
4014 abort_pubsub_tasks(pubsub_tasks);
4015 abort_runtime_tasks(runtime_tasks);
4016 stash_reconnect_snapshot(
4017 &config.reconnect,
4018 session_id.clone(),
4019 session,
4020 current_route_pattern.clone(),
4021 active_resume_token.take(),
4022 )
4023 .await;
4024 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
4025 return;
4026 }
4027 }
4028
4029 let should_skip_initial_patch = matches!(resume_status, ResumeStatus::Resumed)
4030 && connect.client_revision == session.revision();
4031 if !should_skip_initial_patch {
4032 let initial_patch = reconciliation_patch.unwrap_or_else(|| session.render_patch());
4033 match queue_server_message(
4034 &outbound_tx,
4035 &initial_patch,
4036 &config.outbound,
4037 &config.telemetry,
4038 &connection_correlation,
4039 &config.route_path,
4040 &session_id,
4041 ) {
4042 OutboundQueuePush::Queued | OutboundQueuePush::Dropped => {}
4043 OutboundQueuePush::Disconnect => {
4044 warn!(
4045 path = config.route_path,
4046 session_id, "Shelly websocket closed before initial patch"
4047 );
4048 if let Some(task) = writer_task.take() {
4049 task.abort();
4050 let _ = task.await;
4051 }
4052 unregister_pubsub_presence(
4053 &config.pubsub,
4054 &subscribed_topics,
4055 &session_id,
4056 &config.distributed.node_id,
4057 );
4058 abort_pubsub_tasks(pubsub_tasks);
4059 abort_runtime_tasks(runtime_tasks);
4060 stash_reconnect_snapshot(
4061 &config.reconnect,
4062 session_id.clone(),
4063 session,
4064 current_route_pattern.clone(),
4065 active_resume_token.take(),
4066 )
4067 .await;
4068 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
4069 return;
4070 }
4071 }
4072 }
4073 if let (Some(lease), Some(token)) = (durable_lease.as_ref(), active_resume_token.as_deref()) {
4074 durable_save_snapshot(
4075 &config.durable,
4076 lease,
4077 &session,
4078 ¤t_route_pattern,
4079 &config.target_id,
4080 token,
4081 );
4082 }
4083
4084 macro_rules! abort_with_resume {
4085 () => {{
4086 unregister_pubsub_presence(
4087 &config.pubsub,
4088 &subscribed_topics,
4089 &session_id,
4090 &config.distributed.node_id,
4091 );
4092 abort_pubsub_tasks(pubsub_tasks);
4093 abort_runtime_tasks(runtime_tasks);
4094 stash_reconnect_snapshot(
4095 &config.reconnect,
4096 session_id.clone(),
4097 session,
4098 current_route_pattern.clone(),
4099 active_resume_token.take(),
4100 )
4101 .await;
4102 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
4103 if let Some(task) = writer_task.take() {
4104 task.abort();
4105 let _ = task.await;
4106 }
4107 return;
4108 }};
4109 }
4110
4111 loop {
4112 tokio::select! {
4113 maybe_message = pubsub_rx.recv() => {
4114 let Some(message) = maybe_message else {
4115 break;
4116 };
4117 let topic = message.topic;
4118
4119 for server_message in message.messages {
4120 match queue_server_message(
4121 &outbound_tx,
4122 &server_message,
4123 &config.outbound,
4124 &config.telemetry,
4125 &connection_correlation,
4126 &config.route_path,
4127 &session_id,
4128 ) {
4129 OutboundQueuePush::Queued => {}
4130 OutboundQueuePush::Dropped => {
4131 warn!(
4132 path = config.route_path,
4133 session_id,
4134 topic,
4135 "Shelly websocket dropped PubSub message due outbound backpressure"
4136 );
4137 }
4138 OutboundQueuePush::Disconnect => {
4139 warn!(
4140 path = config.route_path,
4141 session_id,
4142 topic,
4143 "Shelly websocket closed while sending PubSub message"
4144 );
4145 abort_with_resume!();
4146 }
4147 }
4148 }
4149 }
4150 maybe_runtime = runtime_rx.recv() => {
4151 let Some(runtime_event) = maybe_runtime else {
4152 break;
4153 };
4154 runtime_tasks.retain(|_, task| !task.is_finished());
4155
4156 let runtime_client_message = ClientMessage::Event {
4157 event: runtime_event.event,
4158 target: runtime_event.target,
4159 value: runtime_event.value,
4160 metadata: runtime_event.metadata.into_iter().collect(),
4161 };
4162 let runtime_priority = overload_priority_for_client_message(&runtime_client_message);
4163 let runtime_tenant_id = effective_tenant_id_for_message(
4164 session.tenant_id(),
4165 tenant_id_for_client_message(&runtime_client_message).as_deref(),
4166 );
4167 let (runtime_operation, runtime_event_name, _) =
4168 security_operation_for_message(&runtime_client_message);
4169 let runtime_overload_context = OverloadContext {
4170 route_path: config.route_path.clone(),
4171 session_id: session_id.clone(),
4172 tenant_id: runtime_tenant_id.clone(),
4173 message_kind: "runtime",
4174 operation: runtime_operation,
4175 event_name: runtime_event_name.map(ToString::to_string),
4176 priority: runtime_priority,
4177 queue_depth: outbound_queue_depth(&outbound_tx, config.outbound.queue_capacity),
4178 queue_capacity: config.outbound.queue_capacity,
4179 inbound_bytes: serde_json::to_vec(&runtime_client_message)
4180 .map(|bytes| bytes.len())
4181 .unwrap_or_default(),
4182 };
4183 let runtime_overload_decision =
4184 overload_decision_for_dispatch(&config.overload, &runtime_overload_context).await;
4185 emit_overload_telemetry(
4186 &config.telemetry,
4187 &connection_correlation,
4188 &runtime_overload_context,
4189 &runtime_overload_decision,
4190 );
4191 if !runtime_overload_decision.allowed {
4192 warn!(
4193 path = config.route_path,
4194 session_id,
4195 priority = runtime_priority.as_str(),
4196 reason = runtime_overload_decision.reason.as_deref().unwrap_or("shed"),
4197 "Shelly runtime event shed by overload policy"
4198 );
4199 if runtime_priority == OverloadPriority::Interactive
4200 && matches!(
4201 queue_server_message(
4202 &outbound_tx,
4203 &overload_decision_to_error(&runtime_overload_decision),
4204 &config.outbound,
4205 &config.telemetry,
4206 &connection_correlation,
4207 &config.route_path,
4208 &session_id,
4209 ),
4210 OutboundQueuePush::Disconnect
4211 )
4212 {
4213 abort_with_resume!();
4214 }
4215 continue;
4216 }
4217 apply_overload_throttle(&runtime_overload_decision).await;
4218 let runtime_dispatch_started = Instant::now();
4219 let messages = session.handle_client_message(runtime_client_message.clone());
4220 overload_record_cpu_usage(
4221 &config.overload,
4222 &session_id,
4223 runtime_tenant_id.as_deref(),
4224 runtime_dispatch_started
4225 .elapsed()
4226 .as_millis()
4227 .max(1)
4228 .try_into()
4229 .unwrap_or(u64::MAX),
4230 )
4231 .await;
4232 process_pubsub_commands(
4233 session.drain_pubsub_commands(),
4234 &config.pubsub,
4235 &pubsub_tx,
4236 &mut pubsub_tasks,
4237 &mut subscribed_topics,
4238 &session_id,
4239 &config.distributed.node_id,
4240 &config.telemetry,
4241 &connection_correlation,
4242 );
4243 process_runtime_commands(
4244 session.drain_runtime_commands(),
4245 &runtime_tx,
4246 &mut runtime_tasks,
4247 );
4248 if let (Some(lease), Some(active_resume_token)) =
4249 (durable_lease.as_ref(), active_resume_token.as_deref())
4250 {
4251 if !contains_server_error(&messages) {
4252 if let Err(err) = durable_append_journal_entry(
4253 &config.durable,
4254 lease,
4255 &session_id,
4256 &runtime_client_message,
4257 ) {
4258 let rejection = durable_error_message(&err);
4259 if matches!(
4260 queue_server_message(
4261 &outbound_tx,
4262 &rejection,
4263 &config.outbound,
4264 &config.telemetry,
4265 &connection_correlation,
4266 &config.route_path,
4267 &session_id,
4268 ),
4269 OutboundQueuePush::Disconnect
4270 ) {
4271 abort_with_resume!();
4272 }
4273 abort_with_resume!();
4274 }
4275 durable_save_snapshot(
4276 &config.durable,
4277 lease,
4278 &session,
4279 ¤t_route_pattern,
4280 &config.target_id,
4281 active_resume_token,
4282 );
4283 }
4284 }
4285
4286 for message in messages {
4287 match queue_server_message(
4288 &outbound_tx,
4289 &message,
4290 &config.outbound,
4291 &config.telemetry,
4292 &connection_correlation,
4293 &config.route_path,
4294 &session_id,
4295 ) {
4296 OutboundQueuePush::Queued => {}
4297 OutboundQueuePush::Dropped => {
4298 warn!(
4299 path = config.route_path,
4300 session_id,
4301 "Shelly websocket dropped runtime response due outbound backpressure"
4302 );
4303 }
4304 OutboundQueuePush::Disconnect => {
4305 warn!(
4306 path = config.route_path,
4307 session_id, "Shelly websocket closed while sending runtime response"
4308 );
4309 abort_with_resume!();
4310 }
4311 }
4312 }
4313 }
4314 result = receiver.next() => {
4315 let Some(result) = result else {
4316 break;
4317 };
4318
4319 let message = match result {
4320 Ok(message) => message,
4321 Err(err) => {
4322 error!(?err, "websocket receive error");
4323 break;
4324 }
4325 };
4326
4327 match message {
4328 Message::Text(text) => {
4329 let text = text.as_str();
4330 if rate_limited(&config.security, &config.route_path, &session_id, "text") {
4331 emit_security_audit(
4332 &config.telemetry,
4333 &connection_correlation,
4334 &config.route_path,
4335 Some(&session_id),
4336 session.tenant_id(),
4337 "rate_limiter",
4338 false,
4339 Some("rate_limited"),
4340 SecurityOperation::Event,
4341 "text",
4342 None,
4343 Some("rate limiter rejected text frame"),
4344 );
4345 warn!(
4346 path = config.route_path,
4347 session_id, "Shelly websocket rate limited text message"
4348 );
4349 if matches!(
4350 queue_server_message(
4351 &outbound_tx,
4352 &rate_limited_error(),
4353 &config.outbound,
4354 &config.telemetry,
4355 &connection_correlation,
4356 &config.route_path,
4357 &session_id,
4358 ),
4359 OutboundQueuePush::Disconnect
4360 ) {
4361 abort_with_resume!();
4362 }
4363 continue;
4364 }
4365
4366 if text.len() > config.max_message_size {
4367 warn!(
4368 path = config.route_path,
4369 session_id,
4370 size = text.len(),
4371 max_message_size = config.max_message_size,
4372 "Shelly websocket text payload too large"
4373 );
4374 let _ = queue_server_message(
4375 &outbound_tx,
4376 &payload_too_large_error(text.len(), config.max_message_size),
4377 &config.outbound,
4378 &config.telemetry,
4379 &connection_correlation,
4380 &config.route_path,
4381 &session_id,
4382 );
4383 let _ = queue_close_frame(
4384 &outbound_tx,
4385 Some(payload_too_large_close()),
4386 &config.outbound,
4387 &config.telemetry,
4388 &connection_correlation,
4389 &config.route_path,
4390 &session_id,
4391 );
4392 unregister_pubsub_presence(
4393 &config.pubsub,
4394 &subscribed_topics,
4395 &session_id,
4396 &config.distributed.node_id,
4397 );
4398 abort_pubsub_tasks(pubsub_tasks);
4399 abort_runtime_tasks(runtime_tasks);
4400 if let Some(task) = writer_task.take() {
4401 let _ = timeout(Duration::from_millis(25), task).await;
4402 }
4403 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
4404 return;
4405 }
4406
4407 let parsed_client_message = serde_json::from_str::<ClientMessage>(text);
4408 log_ws_text_ingress(
4409 &config.route_path,
4410 &session_id,
4411 text.len(),
4412 parsed_client_message.as_ref(),
4413 &connection_correlation,
4414 );
4415 let mut overload_tenant_id: Option<String> = None;
4416 let mut overload_dispatch_started: Option<Instant> = None;
4417 if let Ok(client_message) = parsed_client_message.as_ref() {
4418 let (operation, event_name, event_target) =
4419 security_operation_for_message(client_message);
4420 let message_tenant_id = tenant_id_for_client_message(client_message);
4421 if tenant_context_conflict(
4422 session.tenant_id(),
4423 message_tenant_id.as_deref(),
4424 ) {
4425 let rejection = ServerMessage::Error {
4426 message:
4427 "tenant context mismatch for session-bound message"
4428 .to_string(),
4429 code: Some("tenant_mismatch".to_string()),
4430 };
4431 emit_security_audit(
4432 &config.telemetry,
4433 &connection_correlation,
4434 &config.route_path,
4435 Some(&session_id),
4436 session.tenant_id(),
4437 "tenant_policy",
4438 false,
4439 Some("tenant_mismatch"),
4440 operation,
4441 "text",
4442 event_name,
4443 Some("message tenant_id does not match session tenant context"),
4444 );
4445 if matches!(
4446 queue_server_message(
4447 &outbound_tx,
4448 &rejection,
4449 &config.outbound,
4450 &config.telemetry,
4451 &connection_correlation,
4452 &config.route_path,
4453 &session_id,
4454 ),
4455 OutboundQueuePush::Disconnect
4456 ) {
4457 abort_with_resume!();
4458 }
4459 continue;
4460 }
4461 let effective_tenant_id = effective_tenant_id_for_message(
4462 session.tenant_id(),
4463 message_tenant_id.as_deref(),
4464 );
4465 if let Some(rejection) = quota_denied(
4466 &config.security,
4467 &config.route_path,
4468 &session_id,
4469 effective_tenant_id.as_deref(),
4470 "text",
4471 operation,
4472 event_name,
4473 ) {
4474 emit_security_audit(
4475 &config.telemetry,
4476 &connection_correlation,
4477 &config.route_path,
4478 Some(&session_id),
4479 effective_tenant_id.as_deref(),
4480 "quota_policy",
4481 false,
4482 Some("quota_exceeded"),
4483 operation,
4484 "text",
4485 event_name,
4486 Some("tenant/session quota policy rejected message"),
4487 );
4488 if matches!(
4489 queue_server_message(
4490 &outbound_tx,
4491 &rejection,
4492 &config.outbound,
4493 &config.telemetry,
4494 &connection_correlation,
4495 &config.route_path,
4496 &session_id,
4497 ),
4498 OutboundQueuePush::Disconnect
4499 ) {
4500 abort_with_resume!();
4501 }
4502 continue;
4503 }
4504 if let Some(rejection) = authorization_denied(
4505 &config.security,
4506 AuthorizationInput {
4507 route_path: &config.route_path,
4508 session_id: &session_id,
4509 tenant_id: effective_tenant_id.as_deref(),
4510 message_kind: "text",
4511 operation,
4512 event_name,
4513 event_target,
4514 },
4515 ) {
4516 emit_security_audit(
4517 &config.telemetry,
4518 &connection_correlation,
4519 &config.route_path,
4520 Some(&session_id),
4521 effective_tenant_id.as_deref(),
4522 "authorization_policy",
4523 false,
4524 Some("unauthorized"),
4525 operation,
4526 "text",
4527 event_name,
4528 Some("tenant/session authorization policy rejected message"),
4529 );
4530 if matches!(
4531 queue_server_message(
4532 &outbound_tx,
4533 &rejection,
4534 &config.outbound,
4535 &config.telemetry,
4536 &connection_correlation,
4537 &config.route_path,
4538 &session_id,
4539 ),
4540 OutboundQueuePush::Disconnect
4541 ) {
4542 abort_with_resume!();
4543 }
4544 continue;
4545 }
4546 let overload_context = OverloadContext {
4547 route_path: config.route_path.clone(),
4548 session_id: session_id.clone(),
4549 tenant_id: effective_tenant_id.clone(),
4550 message_kind: "text",
4551 operation,
4552 event_name: event_name.map(ToString::to_string),
4553 priority: overload_priority_for_client_message(client_message),
4554 queue_depth: outbound_queue_depth(&outbound_tx, config.outbound.queue_capacity),
4555 queue_capacity: config.outbound.queue_capacity,
4556 inbound_bytes: text.len(),
4557 };
4558 let overload_decision =
4559 overload_decision_for_dispatch(&config.overload, &overload_context).await;
4560 emit_overload_telemetry(
4561 &config.telemetry,
4562 &connection_correlation,
4563 &overload_context,
4564 &overload_decision,
4565 );
4566 if !overload_decision.allowed {
4567 warn!(
4568 path = config.route_path,
4569 session_id,
4570 priority = overload_context.priority.as_str(),
4571 reason = overload_decision.reason.as_deref().unwrap_or("shed"),
4572 "Shelly websocket message shed by overload policy"
4573 );
4574 if matches!(
4575 queue_server_message(
4576 &outbound_tx,
4577 &overload_decision_to_error(&overload_decision),
4578 &config.outbound,
4579 &config.telemetry,
4580 &connection_correlation,
4581 &config.route_path,
4582 &session_id,
4583 ),
4584 OutboundQueuePush::Disconnect
4585 ) {
4586 abort_with_resume!();
4587 }
4588 continue;
4589 }
4590 apply_overload_throttle(&overload_decision).await;
4591 overload_tenant_id = overload_context.tenant_id.clone();
4592 overload_dispatch_started = Some(Instant::now());
4593 if operation.is_mutating() {
4594 emit_security_audit(
4595 &config.telemetry,
4596 &connection_correlation,
4597 &config.route_path,
4598 Some(&session_id),
4599 effective_tenant_id.as_deref(),
4600 "mutation_dispatch",
4601 true,
4602 Some("allowed"),
4603 operation,
4604 "text",
4605 event_name,
4606 Some("tenant/session policy accepted mutating operation"),
4607 );
4608 }
4609 }
4610 let journal_candidate = parsed_client_message.as_ref().ok().cloned();
4611 let messages = messages_for_client_message_with_navigation(
4612 &mut session,
4613 parsed_client_message,
4614 TextDispatch {
4615 current_route_pattern: &mut current_route_pattern,
4616 routes: &config.routes,
4617 target_id: &config.target_id,
4618 session_id: &session_id,
4619 upload_config: &config.uploads,
4620 uploads: &mut uploads,
4621 telemetry: &session_telemetry,
4622 },
4623 )
4624 .await;
4625 if let Some(started) = overload_dispatch_started {
4626 overload_record_cpu_usage(
4627 &config.overload,
4628 &session_id,
4629 overload_tenant_id.as_deref(),
4630 started
4631 .elapsed()
4632 .as_millis()
4633 .max(1)
4634 .try_into()
4635 .unwrap_or(u64::MAX),
4636 )
4637 .await;
4638 }
4639 process_pubsub_commands(
4640 session.drain_pubsub_commands(),
4641 &config.pubsub,
4642 &pubsub_tx,
4643 &mut pubsub_tasks,
4644 &mut subscribed_topics,
4645 &session_id,
4646 &config.distributed.node_id,
4647 &config.telemetry,
4648 &connection_correlation,
4649 );
4650 process_runtime_commands(
4651 session.drain_runtime_commands(),
4652 &runtime_tx,
4653 &mut runtime_tasks,
4654 );
4655 if let (Some(lease), Some(client_message), Some(active_resume_token)) = (
4656 durable_lease.as_ref(),
4657 journal_candidate.as_ref(),
4658 active_resume_token.as_deref(),
4659 ) {
4660 if durable_message_should_journal(client_message)
4661 && !contains_server_error(&messages)
4662 {
4663 if let Err(err) = durable_append_journal_entry(
4664 &config.durable,
4665 lease,
4666 &session_id,
4667 client_message,
4668 ) {
4669 let rejection = durable_error_message(&err);
4670 if matches!(
4671 queue_server_message(
4672 &outbound_tx,
4673 &rejection,
4674 &config.outbound,
4675 &config.telemetry,
4676 &connection_correlation,
4677 &config.route_path,
4678 &session_id,
4679 ),
4680 OutboundQueuePush::Disconnect
4681 ) {
4682 abort_with_resume!();
4683 }
4684 abort_with_resume!();
4685 }
4686 durable_save_snapshot(
4687 &config.durable,
4688 lease,
4689 &session,
4690 ¤t_route_pattern,
4691 &config.target_id,
4692 active_resume_token,
4693 );
4694 }
4695 }
4696
4697 for message in messages {
4698 emit_upload_lifecycle_telemetry(
4699 &config.telemetry,
4700 session.session_id(),
4701 session.route_path(),
4702 &message,
4703 &connection_correlation,
4704 );
4705 match queue_server_message(
4706 &outbound_tx,
4707 &message,
4708 &config.outbound,
4709 &config.telemetry,
4710 &connection_correlation,
4711 &config.route_path,
4712 &session_id,
4713 ) {
4714 OutboundQueuePush::Queued => {}
4715 OutboundQueuePush::Dropped => {
4716 warn!(
4717 path = config.route_path,
4718 session_id,
4719 "Shelly websocket dropped response due outbound backpressure"
4720 );
4721 }
4722 OutboundQueuePush::Disconnect => {
4723 warn!(
4724 path = config.route_path,
4725 session_id, "Shelly websocket closed while sending response"
4726 );
4727 abort_with_resume!();
4728 }
4729 }
4730 }
4731 }
4732 Message::Binary(payload) if payload.len() > config.max_message_size => {
4733 info!(
4734 target: "shelly.incoming.ws",
4735 schema = "shelly.incoming.ws.v1",
4736 route = %config.route_path,
4737 session_id,
4738 trace_id = %connection_correlation.trace_id,
4739 span_id = %connection_correlation.span_id,
4740 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
4741 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
4742 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
4743 frame_kind = "binary",
4744 bytes = payload.len(),
4745 "Shelly websocket frame received"
4746 );
4747 if rate_limited(&config.security, &config.route_path, &session_id, "binary") {
4748 emit_security_audit(
4749 &config.telemetry,
4750 &connection_correlation,
4751 &config.route_path,
4752 Some(&session_id),
4753 session.tenant_id(),
4754 "rate_limiter",
4755 false,
4756 Some("rate_limited"),
4757 SecurityOperation::Binary,
4758 "binary",
4759 None,
4760 Some("rate limiter rejected binary frame"),
4761 );
4762 warn!(
4763 path = config.route_path,
4764 session_id, "Shelly websocket rate limited binary message"
4765 );
4766 if matches!(
4767 queue_server_message(
4768 &outbound_tx,
4769 &rate_limited_error(),
4770 &config.outbound,
4771 &config.telemetry,
4772 &connection_correlation,
4773 &config.route_path,
4774 &session_id,
4775 ),
4776 OutboundQueuePush::Disconnect
4777 ) {
4778 abort_with_resume!();
4779 }
4780 continue;
4781 }
4782
4783 warn!(
4784 path = config.route_path,
4785 session_id,
4786 size = payload.len(),
4787 max_message_size = config.max_message_size,
4788 "Shelly websocket binary payload too large"
4789 );
4790 let _ = queue_server_message(
4791 &outbound_tx,
4792 &payload_too_large_error(payload.len(), config.max_message_size),
4793 &config.outbound,
4794 &config.telemetry,
4795 &connection_correlation,
4796 &config.route_path,
4797 &session_id,
4798 );
4799 let _ = queue_close_frame(
4800 &outbound_tx,
4801 Some(payload_too_large_close()),
4802 &config.outbound,
4803 &config.telemetry,
4804 &connection_correlation,
4805 &config.route_path,
4806 &session_id,
4807 );
4808 abort_with_resume!();
4809 }
4810 Message::Binary(_) => {
4811 info!(
4812 target: "shelly.incoming.ws",
4813 schema = "shelly.incoming.ws.v1",
4814 route = %config.route_path,
4815 session_id,
4816 trace_id = %connection_correlation.trace_id,
4817 span_id = %connection_correlation.span_id,
4818 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
4819 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
4820 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
4821 frame_kind = "binary",
4822 "Shelly websocket frame received"
4823 );
4824 if rate_limited(&config.security, &config.route_path, &session_id, "binary") {
4825 emit_security_audit(
4826 &config.telemetry,
4827 &connection_correlation,
4828 &config.route_path,
4829 Some(&session_id),
4830 session.tenant_id(),
4831 "rate_limiter",
4832 false,
4833 Some("rate_limited"),
4834 SecurityOperation::Binary,
4835 "binary",
4836 None,
4837 Some("rate limiter rejected binary frame"),
4838 );
4839 warn!(
4840 path = config.route_path,
4841 session_id, "Shelly websocket rate limited binary message"
4842 );
4843 if matches!(
4844 queue_server_message(
4845 &outbound_tx,
4846 &rate_limited_error(),
4847 &config.outbound,
4848 &config.telemetry,
4849 &connection_correlation,
4850 &config.route_path,
4851 &session_id,
4852 ),
4853 OutboundQueuePush::Disconnect
4854 ) {
4855 abort_with_resume!();
4856 }
4857 continue;
4858 }
4859 if let Some(rejection) = quota_denied(
4860 &config.security,
4861 &config.route_path,
4862 &session_id,
4863 session.tenant_id(),
4864 "binary",
4865 SecurityOperation::Binary,
4866 None,
4867 ) {
4868 emit_security_audit(
4869 &config.telemetry,
4870 &connection_correlation,
4871 &config.route_path,
4872 Some(&session_id),
4873 session.tenant_id(),
4874 "quota_policy",
4875 false,
4876 Some("quota_exceeded"),
4877 SecurityOperation::Binary,
4878 "binary",
4879 None,
4880 Some("tenant/session quota policy rejected binary frame"),
4881 );
4882 if matches!(
4883 queue_server_message(
4884 &outbound_tx,
4885 &rejection,
4886 &config.outbound,
4887 &config.telemetry,
4888 &connection_correlation,
4889 &config.route_path,
4890 &session_id,
4891 ),
4892 OutboundQueuePush::Disconnect
4893 ) {
4894 abort_with_resume!();
4895 }
4896 continue;
4897 }
4898 if let Some(rejection) = authorization_denied(
4899 &config.security,
4900 AuthorizationInput {
4901 route_path: &config.route_path,
4902 session_id: &session_id,
4903 tenant_id: session.tenant_id(),
4904 message_kind: "binary",
4905 operation: SecurityOperation::Binary,
4906 event_name: None,
4907 event_target: None,
4908 },
4909 ) {
4910 emit_security_audit(
4911 &config.telemetry,
4912 &connection_correlation,
4913 &config.route_path,
4914 Some(&session_id),
4915 session.tenant_id(),
4916 "authorization_policy",
4917 false,
4918 Some("unauthorized"),
4919 SecurityOperation::Binary,
4920 "binary",
4921 None,
4922 Some("tenant/session authorization policy rejected binary frame"),
4923 );
4924 if matches!(
4925 queue_server_message(
4926 &outbound_tx,
4927 &rejection,
4928 &config.outbound,
4929 &config.telemetry,
4930 &connection_correlation,
4931 &config.route_path,
4932 &session_id,
4933 ),
4934 OutboundQueuePush::Disconnect
4935 ) {
4936 abort_with_resume!();
4937 }
4938 continue;
4939 }
4940
4941 warn!(
4942 path = config.route_path,
4943 session_id, "Shelly websocket rejected binary message"
4944 );
4945 if matches!(
4946 queue_server_message(
4947 &outbound_tx,
4948 &unsupported_binary_error(),
4949 &config.outbound,
4950 &config.telemetry,
4951 &connection_correlation,
4952 &config.route_path,
4953 &session_id,
4954 ),
4955 OutboundQueuePush::Disconnect
4956 ) {
4957 abort_with_resume!();
4958 }
4959 }
4960 Message::Ping(payload) => {
4961 info!(
4962 target: "shelly.incoming.ws",
4963 schema = "shelly.incoming.ws.v1",
4964 route = %config.route_path,
4965 session_id,
4966 trace_id = %connection_correlation.trace_id,
4967 span_id = %connection_correlation.span_id,
4968 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
4969 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
4970 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
4971 frame_kind = "ping",
4972 bytes = payload.len(),
4973 "Shelly websocket frame received"
4974 );
4975 if matches!(
4976 queue_pong_frame(
4977 &outbound_tx,
4978 payload.to_vec(),
4979 &config.outbound,
4980 &config.telemetry,
4981 &connection_correlation,
4982 &config.route_path,
4983 &session_id,
4984 ),
4985 OutboundQueuePush::Disconnect
4986 ) {
4987 abort_with_resume!();
4988 }
4989 }
4990 Message::Close(frame) => {
4991 info!(
4992 target: "shelly.incoming.ws",
4993 schema = "shelly.incoming.ws.v1",
4994 route = %config.route_path,
4995 session_id,
4996 trace_id = %connection_correlation.trace_id,
4997 span_id = %connection_correlation.span_id,
4998 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
4999 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
5000 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
5001 frame_kind = "close",
5002 close_frame = ?frame,
5003 "Shelly websocket frame received"
5004 );
5005 break
5006 }
5007 Message::Pong(payload) => {
5008 info!(
5009 target: "shelly.incoming.ws",
5010 schema = "shelly.incoming.ws.v1",
5011 route = %config.route_path,
5012 session_id,
5013 trace_id = %connection_correlation.trace_id,
5014 span_id = %connection_correlation.span_id,
5015 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
5016 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
5017 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
5018 frame_kind = "pong",
5019 bytes = payload.len(),
5020 "Shelly websocket frame received"
5021 );
5022 }
5023 }
5024 }
5025 }
5026 }
5027
5028 unregister_pubsub_presence(
5029 &config.pubsub,
5030 &subscribed_topics,
5031 &session_id,
5032 &config.distributed.node_id,
5033 );
5034 abort_pubsub_tasks(pubsub_tasks);
5035 abort_runtime_tasks(runtime_tasks);
5036 drop(outbound_tx);
5037 if let Some(task) = writer_task.take() {
5038 let _ = task.await;
5039 }
5040 let final_route_path = session.route_path().to_string();
5041 stash_reconnect_snapshot(
5042 &config.reconnect,
5043 session_id.clone(),
5044 session,
5045 current_route_pattern.clone(),
5046 active_resume_token.take(),
5047 )
5048 .await;
5049 durable_release_lease(&config.durable, &session_id, durable_lease.as_ref());
5050 info!(
5051 path = config.route_path,
5052 trace_id = %connection_correlation.trace_id,
5053 span_id = %connection_correlation.span_id,
5054 parent_span_id = connection_correlation.parent_span_id.as_deref().unwrap_or("-"),
5055 correlation_id = connection_correlation.correlation_id.as_deref().unwrap_or("-"),
5056 request_id = connection_correlation.request_id.as_deref().unwrap_or("-"),
5057 session_id, "Shelly websocket disconnected"
5058 );
5059 let _ = config.telemetry.emit(
5060 connection_correlation.apply_to_event(
5061 TelemetryEvent::new(TelemetryEventKind::Disconnect)
5062 .with_session(session_id)
5063 .with_route(final_route_path),
5064 ),
5065 );
5066}
5067
5068#[allow(clippy::too_many_arguments)]
5069fn process_pubsub_commands(
5070 commands: Vec<PubSubCommand>,
5071 pubsub: &PubSub,
5072 sender: &mpsc::UnboundedSender<shelly::PubSubMessage>,
5073 tasks: &mut Vec<JoinHandle<()>>,
5074 subscribed_topics: &mut HashSet<String>,
5075 session_id: &str,
5076 node_id: &str,
5077 telemetry: &Arc<TelemetryPipeline>,
5078 correlation: &CorrelationContext,
5079) {
5080 for command in commands {
5081 match command {
5082 PubSubCommand::Subscribe { topic } => {
5083 if subscribed_topics.insert(topic.clone()) {
5084 pubsub.register_presence(topic.clone(), session_id, node_id);
5085 let mut subscription = pubsub.subscribe(topic);
5086 let sender = sender.clone();
5087 tasks.push(tokio::spawn(async move {
5088 loop {
5089 match subscription.recv().await {
5090 Ok(message) => {
5091 if sender.send(message).is_err() {
5092 break;
5093 }
5094 }
5095 Err(shelly::PubSubReceiveError::Lagged(skipped)) => {
5096 warn!(skipped, "Shelly PubSub subscriber lagged");
5097 }
5098 Err(shelly::PubSubReceiveError::Closed) => break,
5099 }
5100 }
5101 }));
5102 }
5103 }
5104 PubSubCommand::Broadcast { topic, messages } => {
5105 let message_count = messages.len();
5106 let recipients = pubsub.broadcast(topic.clone(), messages);
5107 let _ = telemetry.emit(
5108 correlation.apply_to_event(
5109 TelemetryEvent::new(TelemetryEventKind::PubSubFanout)
5110 .with_ok(true)
5111 .with_count(recipients)
5112 .with_attribute("topic".to_string(), serde_json::Value::String(topic))
5113 .with_attribute(
5114 "message_count".to_string(),
5115 JsonValue::Number(serde_json::Number::from(message_count)),
5116 ),
5117 ),
5118 );
5119 }
5120 }
5121 }
5122}
5123
5124fn process_runtime_commands(
5125 commands: Vec<RuntimeCommand>,
5126 sender: &mpsc::UnboundedSender<shelly::RuntimeEvent>,
5127 tasks: &mut HashMap<String, JoinHandle<()>>,
5128) {
5129 tasks.retain(|_, task| !task.is_finished());
5130
5131 for command in commands {
5132 match command {
5133 RuntimeCommand::ScheduleOnce {
5134 id,
5135 delay_ms,
5136 dispatch,
5137 } => {
5138 if let Some(task) = tasks.remove(&id) {
5139 task.abort();
5140 }
5141 let sender = sender.clone();
5142 tasks.insert(
5143 id,
5144 tokio::spawn(async move {
5145 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
5146 let _ = sender.send(dispatch);
5147 }),
5148 );
5149 }
5150 RuntimeCommand::ScheduleInterval {
5151 id,
5152 every_ms,
5153 dispatch,
5154 } => {
5155 if let Some(task) = tasks.remove(&id) {
5156 task.abort();
5157 }
5158 let sender = sender.clone();
5159 tasks.insert(
5160 id,
5161 tokio::spawn(async move {
5162 let interval_ms = every_ms.max(1);
5163 let mut interval = tokio::time::interval_at(
5164 Instant::now() + Duration::from_millis(interval_ms),
5165 Duration::from_millis(interval_ms),
5166 );
5167 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
5168 loop {
5169 interval.tick().await;
5170 if sender.send(dispatch.clone()).is_err() {
5171 break;
5172 }
5173 }
5174 }),
5175 );
5176 }
5177 RuntimeCommand::Cancel { id } => {
5178 if let Some(task) = tasks.remove(&id) {
5179 task.abort();
5180 }
5181 }
5182 }
5183 }
5184}
5185
5186fn unregister_pubsub_presence(
5187 pubsub: &PubSub,
5188 subscribed_topics: &HashSet<String>,
5189 session_id: &str,
5190 node_id: &str,
5191) {
5192 for topic in subscribed_topics {
5193 pubsub.unregister_presence(topic.clone(), session_id.to_string(), node_id.to_string());
5194 }
5195}
5196
5197fn abort_pubsub_tasks(tasks: Vec<JoinHandle<()>>) {
5198 for task in tasks {
5199 task.abort();
5200 }
5201}
5202
5203fn abort_runtime_tasks(tasks: HashMap<String, JoinHandle<()>>) {
5204 for task in tasks.into_values() {
5205 task.abort();
5206 }
5207}
5208
/// Returns the current wall-clock time as a decimal string of milliseconds elapsed
/// since the Unix epoch.
///
/// If the system clock reports a time before the epoch the result is `"0"`.
fn chrono_like_timestamp() -> String {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_millis().to_string()
}
5217
/// Returns the current wall-clock time in milliseconds since the Unix epoch, or `0`
/// when the system clock reports a time before the epoch.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
5224
/// Test-only helper: parses `text` as a [`ClientMessage`] and lets `session` handle it.
///
/// A payload that fails to parse yields a single `invalid_protocol` error message
/// instead of reaching the session.
#[cfg(test)]
fn messages_for_text_payload(session: &mut LiveSession, text: &str) -> Vec<ServerMessage> {
    let message = match serde_json::from_str::<ClientMessage>(text) {
        Ok(message) => message,
        Err(err) => {
            return vec![ServerMessage::Error {
                message: format!("invalid protocol message: {err}"),
                code: Some("invalid_protocol".to_string()),
            }];
        }
    };
    session.handle_client_message(message)
}
5235
/// Flattened, log-friendly view of a client protocol message.
///
/// Only the fields that apply to the message type are populated; every other optional
/// field stays `None`.
#[derive(Debug)]
struct ClientMessageLogFields {
    /// Protocol message type label, e.g. `"connect"`, `"event"` or `"upload_chunk"`.
    message_type: &'static str,
    /// Session id carried by a connect message.
    connect_session_id: Option<String>,
    /// Last revision carried by a connect message.
    connect_last_revision: Option<u64>,
    /// For connect messages: whether a non-empty resume token was supplied.
    has_resume_token: Option<bool>,
    /// Event name of an event message.
    event_name: Option<String>,
    /// Target of an event or upload-start message.
    event_target: Option<String>,
    /// Destination of a patch-url or navigate message.
    navigation_to: Option<String>,
    /// Upload id of any upload message.
    upload_id: Option<String>,
    /// Event name announced by an upload-start message.
    upload_event: Option<String>,
}
5248
5249fn client_message_log_fields(message: &ClientMessage) -> ClientMessageLogFields {
5250 match message {
5251 ClientMessage::Connect {
5252 session_id,
5253 last_revision,
5254 resume_token,
5255 ..
5256 } => ClientMessageLogFields {
5257 message_type: "connect",
5258 connect_session_id: session_id.clone(),
5259 connect_last_revision: *last_revision,
5260 has_resume_token: Some(
5261 resume_token
5262 .as_ref()
5263 .map(|token| !token.is_empty())
5264 .unwrap_or(false),
5265 ),
5266 event_name: None,
5267 event_target: None,
5268 navigation_to: None,
5269 upload_id: None,
5270 upload_event: None,
5271 },
5272 ClientMessage::Event { event, target, .. } => ClientMessageLogFields {
5273 message_type: "event",
5274 connect_session_id: None,
5275 connect_last_revision: None,
5276 has_resume_token: None,
5277 event_name: Some(event.clone()),
5278 event_target: target.clone(),
5279 navigation_to: None,
5280 upload_id: None,
5281 upload_event: None,
5282 },
5283 ClientMessage::Ping { .. } => ClientMessageLogFields {
5284 message_type: "ping",
5285 connect_session_id: None,
5286 connect_last_revision: None,
5287 has_resume_token: None,
5288 event_name: None,
5289 event_target: None,
5290 navigation_to: None,
5291 upload_id: None,
5292 upload_event: None,
5293 },
5294 ClientMessage::PatchUrl { to } => ClientMessageLogFields {
5295 message_type: "patch_url",
5296 connect_session_id: None,
5297 connect_last_revision: None,
5298 has_resume_token: None,
5299 event_name: None,
5300 event_target: None,
5301 navigation_to: Some(to.clone()),
5302 upload_id: None,
5303 upload_event: None,
5304 },
5305 ClientMessage::Navigate { to } => ClientMessageLogFields {
5306 message_type: "navigate",
5307 connect_session_id: None,
5308 connect_last_revision: None,
5309 has_resume_token: None,
5310 event_name: None,
5311 event_target: None,
5312 navigation_to: Some(to.clone()),
5313 upload_id: None,
5314 upload_event: None,
5315 },
5316 ClientMessage::UploadStart {
5317 upload_id,
5318 event,
5319 target,
5320 ..
5321 } => ClientMessageLogFields {
5322 message_type: "upload_start",
5323 connect_session_id: None,
5324 connect_last_revision: None,
5325 has_resume_token: None,
5326 event_name: None,
5327 event_target: target.clone(),
5328 navigation_to: None,
5329 upload_id: Some(upload_id.clone()),
5330 upload_event: Some(event.clone()),
5331 },
5332 ClientMessage::UploadChunk { upload_id, .. } => ClientMessageLogFields {
5333 message_type: "upload_chunk",
5334 connect_session_id: None,
5335 connect_last_revision: None,
5336 has_resume_token: None,
5337 event_name: None,
5338 event_target: None,
5339 navigation_to: None,
5340 upload_id: Some(upload_id.clone()),
5341 upload_event: None,
5342 },
5343 ClientMessage::UploadComplete { upload_id } => ClientMessageLogFields {
5344 message_type: "upload_complete",
5345 connect_session_id: None,
5346 connect_last_revision: None,
5347 has_resume_token: None,
5348 event_name: None,
5349 event_target: None,
5350 navigation_to: None,
5351 upload_id: Some(upload_id.clone()),
5352 upload_event: None,
5353 },
5354 }
5355}
5356
5357fn log_ws_text_ingress(
5358 route_path: &str,
5359 session_id: &str,
5360 payload_size: usize,
5361 message: Result<&ClientMessage, &serde_json::Error>,
5362 correlation: &CorrelationContext,
5363) {
5364 match message {
5365 Ok(message) => {
5366 let fields = client_message_log_fields(message);
5367 info!(
5368 target: "shelly.incoming.ws",
5369 schema = "shelly.incoming.ws.v1",
5370 route = route_path,
5371 session_id,
5372 trace_id = %correlation.trace_id,
5373 span_id = %correlation.span_id,
5374 parent_span_id = correlation.parent_span_id.as_deref().unwrap_or("-"),
5375 correlation_id = correlation.correlation_id.as_deref().unwrap_or("-"),
5376 request_id = correlation.request_id.as_deref().unwrap_or("-"),
5377 frame_kind = "text",
5378 bytes = payload_size,
5379 protocol_message_type = fields.message_type,
5380 connect_session_id = ?fields.connect_session_id,
5381 connect_last_revision = ?fields.connect_last_revision,
5382 has_resume_token = ?fields.has_resume_token,
5383 event_name = ?fields.event_name,
5384 event_target = ?fields.event_target,
5385 navigation_to = ?fields.navigation_to,
5386 upload_id = ?fields.upload_id,
5387 upload_event = ?fields.upload_event,
5388 "Shelly websocket frame received"
5389 );
5390 }
5391 Err(err) => {
5392 warn!(
5393 target: "shelly.incoming.ws",
5394 schema = "shelly.incoming.ws.v1",
5395 route = route_path,
5396 session_id,
5397 trace_id = %correlation.trace_id,
5398 span_id = %correlation.span_id,
5399 parent_span_id = correlation.parent_span_id.as_deref().unwrap_or("-"),
5400 correlation_id = correlation.correlation_id.as_deref().unwrap_or("-"),
5401 request_id = correlation.request_id.as_deref().unwrap_or("-"),
5402 frame_kind = "text",
5403 bytes = payload_size,
5404 protocol_message_type = "invalid_protocol",
5405 parse_error = %err,
5406 "Shelly websocket frame failed protocol parsing"
5407 );
5408 }
5409 }
5410}
5411
5412async fn messages_for_client_message_with_navigation(
5413 session: &mut LiveSession,
5414 message: Result<ClientMessage, serde_json::Error>,
5415 dispatch: TextDispatch<'_>,
5416) -> Vec<ServerMessage> {
5417 match message {
5418 Ok(ClientMessage::PatchUrl { to }) => handle_patch_url(
5419 session,
5420 dispatch.current_route_pattern,
5421 dispatch.routes,
5422 &to,
5423 ),
5424 Ok(ClientMessage::Navigate { to }) => handle_navigate(
5425 session,
5426 dispatch.current_route_pattern,
5427 dispatch.routes,
5428 dispatch.target_id,
5429 dispatch.session_id,
5430 dispatch.telemetry,
5431 &to,
5432 ),
5433 Ok(ClientMessage::UploadStart {
5434 upload_id,
5435 event,
5436 target,
5437 name,
5438 size,
5439 content_type,
5440 }) => {
5441 handle_upload_start(
5442 dispatch.uploads,
5443 dispatch.upload_config,
5444 UploadStartRequest {
5445 upload_id,
5446 event,
5447 target,
5448 name,
5449 size,
5450 content_type,
5451 },
5452 )
5453 .await
5454 }
5455 Ok(ClientMessage::UploadChunk {
5456 upload_id,
5457 offset,
5458 data,
5459 }) => handle_upload_chunk(dispatch.uploads, &upload_id, offset, &data).await,
5460 Ok(ClientMessage::UploadComplete { upload_id }) => {
5461 handle_upload_complete(session, dispatch.uploads, &upload_id).await
5462 }
5463 Ok(message) => session.handle_client_message(message),
5464 Err(err) => vec![ServerMessage::Error {
5465 message: format!("invalid protocol message: {err}"),
5466 code: Some("invalid_protocol".to_string()),
5467 }],
5468 }
5469}
5470
5471async fn handle_upload_start(
5472 uploads: &mut HashMap<String, UploadEntry>,
5473 config: &UploadConfig,
5474 request: UploadStartRequest,
5475) -> Vec<ServerMessage> {
5476 if request.size > config.max_file_size {
5477 return vec![upload_error(
5478 request.upload_id,
5479 "upload exceeds configured size limit",
5480 "upload_too_large",
5481 )];
5482 }
5483
5484 if !upload_content_type_allowed(config, request.content_type.as_deref()) {
5485 return vec![upload_error(
5486 request.upload_id,
5487 "upload content type is not allowed",
5488 "upload_type_not_allowed",
5489 )];
5490 }
5491
5492 if let Err(err) = tokio::fs::create_dir_all(config.temp_dir.as_ref()).await {
5493 return vec![upload_error(
5494 request.upload_id,
5495 format!("upload temp directory failed: {err}"),
5496 "upload_temp_failed",
5497 )];
5498 }
5499
5500 let path = config
5501 .temp_dir
5502 .join(format!("{}.upload", request.upload_id));
5503 let file = match tokio::fs::File::create(&path).await {
5504 Ok(file) => file,
5505 Err(err) => {
5506 return vec![upload_error(
5507 request.upload_id,
5508 format!("upload temp file failed: {err}"),
5509 "upload_temp_failed",
5510 )];
5511 }
5512 };
5513
5514 let upload_id = request.upload_id;
5515 let size = request.size;
5516 uploads.insert(
5517 upload_id.clone(),
5518 UploadEntry {
5519 event: request.event,
5520 target: request.target,
5521 name: request.name,
5522 size,
5523 content_type: request.content_type,
5524 path,
5525 file,
5526 received: 0,
5527 },
5528 );
5529
5530 vec![ServerMessage::UploadProgress {
5531 upload_id,
5532 received: 0,
5533 total: size,
5534 }]
5535}
5536
5537async fn handle_upload_chunk(
5538 uploads: &mut HashMap<String, UploadEntry>,
5539 upload_id: &str,
5540 offset: u64,
5541 data: &str,
5542) -> Vec<ServerMessage> {
5543 let Some(entry) = uploads.get_mut(upload_id) else {
5544 return vec![upload_error(
5545 upload_id,
5546 "unknown upload",
5547 "upload_not_found",
5548 )];
5549 };
5550
5551 if offset != entry.received {
5552 return vec![upload_error(
5553 upload_id,
5554 "upload chunk offset is out of order",
5555 "upload_offset_mismatch",
5556 )];
5557 }
5558
5559 let bytes = match STANDARD.decode(data) {
5560 Ok(bytes) => bytes,
5561 Err(_) => {
5562 return vec![upload_error(
5563 upload_id,
5564 "upload chunk is not valid base64",
5565 "upload_invalid_chunk",
5566 )];
5567 }
5568 };
5569
5570 let next_received = entry.received + bytes.len() as u64;
5571 if next_received > entry.size {
5572 return vec![upload_error(
5573 upload_id,
5574 "upload chunk exceeds negotiated size",
5575 "upload_too_large",
5576 )];
5577 }
5578
5579 if let Err(err) = entry.file.write_all(&bytes).await {
5580 return vec![upload_error(
5581 upload_id,
5582 format!("upload chunk write failed: {err}"),
5583 "upload_write_failed",
5584 )];
5585 }
5586
5587 entry.received = next_received;
5588 vec![ServerMessage::UploadProgress {
5589 upload_id: upload_id.to_string(),
5590 received: entry.received,
5591 total: entry.size,
5592 }]
5593}
5594
5595async fn handle_upload_complete(
5596 session: &mut LiveSession,
5597 uploads: &mut HashMap<String, UploadEntry>,
5598 upload_id: &str,
5599) -> Vec<ServerMessage> {
5600 let Some(mut entry) = uploads.remove(upload_id) else {
5601 return vec![upload_error(
5602 upload_id,
5603 "unknown upload",
5604 "upload_not_found",
5605 )];
5606 };
5607
5608 if entry.received != entry.size {
5609 let _ = tokio::fs::remove_file(&entry.path).await;
5610 return vec![upload_error(
5611 upload_id,
5612 "upload completed before all bytes were received",
5613 "upload_incomplete",
5614 )];
5615 }
5616
5617 if let Err(err) = entry.file.flush().await {
5618 let _ = tokio::fs::remove_file(&entry.path).await;
5619 return vec![upload_error(
5620 upload_id,
5621 format!("upload flush failed: {err}"),
5622 "upload_write_failed",
5623 )];
5624 }
5625
5626 let mut messages = vec![ServerMessage::UploadComplete {
5627 upload_id: upload_id.to_string(),
5628 name: entry.name.clone(),
5629 size: entry.size,
5630 content_type: entry.content_type.clone(),
5631 }];
5632 messages.extend(session.handle_client_message(ClientMessage::Event {
5633 event: entry.event,
5634 target: entry.target,
5635 value: serde_json::json!({
5636 "upload_id": upload_id,
5637 "name": entry.name,
5638 "size": entry.size,
5639 "content_type": entry.content_type,
5640 "path": entry.path.display().to_string(),
5641 }),
5642 metadata: serde_json::Map::from_iter([
5643 (
5644 "tag".to_string(),
5645 serde_json::Value::String("INPUT".to_string()),
5646 ),
5647 (
5648 "name".to_string(),
5649 serde_json::Value::String("upload".to_string()),
5650 ),
5651 ]),
5652 }));
5653 messages
5654}
5655
5656fn upload_content_type_allowed(config: &UploadConfig, content_type: Option<&str>) -> bool {
5657 config.allowed_content_types.is_empty()
5658 || content_type
5659 .map(|value| {
5660 config
5661 .allowed_content_types
5662 .iter()
5663 .any(|allowed| allowed == value)
5664 })
5665 .unwrap_or(false)
5666}
5667
5668fn upload_error(
5669 upload_id: impl Into<String>,
5670 message: impl Into<String>,
5671 code: impl Into<String>,
5672) -> ServerMessage {
5673 ServerMessage::UploadError {
5674 upload_id: upload_id.into(),
5675 message: message.into(),
5676 code: Some(code.into()),
5677 }
5678}
5679
5680fn emit_upload_lifecycle_telemetry(
5681 telemetry: &Arc<TelemetryPipeline>,
5682 session_id: &str,
5683 route_path: &str,
5684 message: &ServerMessage,
5685 correlation: &CorrelationContext,
5686) {
5687 let Some(event) = upload_lifecycle_event(session_id, route_path, message, correlation) else {
5688 return;
5689 };
5690 let _ = telemetry.emit(event);
5691}
5692
5693fn upload_lifecycle_event(
5694 session_id: &str,
5695 route_path: &str,
5696 message: &ServerMessage,
5697 correlation: &CorrelationContext,
5698) -> Option<TelemetryEvent> {
5699 let event = match message {
5700 ServerMessage::UploadProgress {
5701 upload_id,
5702 received,
5703 total,
5704 } => {
5705 let phase = if *received == 0 { "start" } else { "progress" };
5706 TelemetryEvent::new(TelemetryEventKind::UploadLifecycle)
5707 .with_session(session_id.to_string())
5708 .with_route(route_path.to_string())
5709 .with_ok(true)
5710 .with_bytes(*received as usize)
5711 .with_count(1)
5712 .with_attribute("phase".to_string(), JsonValue::String(phase.to_string()))
5713 .with_attribute(
5714 "upload_id".to_string(),
5715 JsonValue::String(upload_id.clone()),
5716 )
5717 .with_attribute(
5718 "received".to_string(),
5719 JsonValue::Number(serde_json::Number::from(*received)),
5720 )
5721 .with_attribute(
5722 "total".to_string(),
5723 JsonValue::Number(serde_json::Number::from(*total)),
5724 )
5725 }
5726 ServerMessage::UploadComplete {
5727 upload_id,
5728 name,
5729 size,
5730 content_type,
5731 } => TelemetryEvent::new(TelemetryEventKind::UploadLifecycle)
5732 .with_session(session_id.to_string())
5733 .with_route(route_path.to_string())
5734 .with_ok(true)
5735 .with_bytes(*size as usize)
5736 .with_count(1)
5737 .with_attribute(
5738 "phase".to_string(),
5739 JsonValue::String("complete".to_string()),
5740 )
5741 .with_attribute(
5742 "upload_id".to_string(),
5743 JsonValue::String(upload_id.clone()),
5744 )
5745 .with_attribute("name".to_string(), JsonValue::String(name.clone()))
5746 .with_attribute("size".to_string(), JsonValue::Number((*size).into()))
5747 .with_attribute(
5748 "content_type".to_string(),
5749 content_type
5750 .as_ref()
5751 .map(|value| JsonValue::String(value.clone()))
5752 .unwrap_or(JsonValue::Null),
5753 ),
5754 ServerMessage::UploadError {
5755 upload_id,
5756 code,
5757 message,
5758 } => TelemetryEvent::new(TelemetryEventKind::UploadLifecycle)
5759 .with_session(session_id.to_string())
5760 .with_route(route_path.to_string())
5761 .with_ok(false)
5762 .with_count(1)
5763 .with_attribute("phase".to_string(), JsonValue::String("error".to_string()))
5764 .with_attribute(
5765 "upload_id".to_string(),
5766 JsonValue::String(upload_id.clone()),
5767 )
5768 .with_attribute("message".to_string(), JsonValue::String(message.clone()))
5769 .with_attribute(
5770 "code".to_string(),
5771 code.as_ref()
5772 .map(|value| JsonValue::String(value.clone()))
5773 .unwrap_or(JsonValue::Null),
5774 ),
5775 _ => return None,
5776 };
5777
5778 Some(correlation.apply_to_event(event))
5779}
5780
5781fn handle_patch_url(
5782 session: &mut LiveSession,
5783 current_route_pattern: &mut String,
5784 routes: &[LiveRoute],
5785 to: &str,
5786) -> Vec<ServerMessage> {
5787 let Some(path) = internal_path(to) else {
5788 return vec![navigation_error(
5789 "invalid_navigation",
5790 format!("internal navigation requires an absolute path: {to}"),
5791 )];
5792 };
5793
5794 let Some(route) = routes.iter().find_map(|route| route.match_path(&path)) else {
5795 return vec![navigation_error(
5796 "navigation_route_not_found",
5797 format!("no Shelly route registered for {path}"),
5798 )];
5799 };
5800
5801 if route.pattern != *current_route_pattern {
5802 return vec![navigation_error(
5803 "navigation_route_mismatch",
5804 format!(
5805 "shelly-patch cannot switch from {current_route_pattern} to {}",
5806 route.pattern
5807 ),
5808 )];
5809 }
5810
5811 if let Err(err) = session.patch_route(route.path, route.params) {
5812 return vec![ServerMessage::Error {
5813 message: err.to_string(),
5814 code: Some("route_patch_failed".to_string()),
5815 }];
5816 }
5817
5818 vec![
5819 ServerMessage::PatchUrl { to: to.to_string() },
5820 session.render_update(),
5821 ]
5822}
5823
5824fn handle_navigate(
5825 session: &mut LiveSession,
5826 current_route_pattern: &mut String,
5827 routes: &[LiveRoute],
5828 target_id: &str,
5829 session_id: &str,
5830 telemetry: &Arc<dyn TelemetrySink>,
5831 to: &str,
5832) -> Vec<ServerMessage> {
5833 let Some(path) = internal_path(to) else {
5834 return vec![navigation_error(
5835 "invalid_navigation",
5836 format!("internal navigation requires an absolute path: {to}"),
5837 )];
5838 };
5839
5840 let Some(route) = routes.iter().find_map(|route| route.match_path(&path)) else {
5841 return vec![navigation_error(
5842 "navigation_route_not_found",
5843 format!("no Shelly route registered for {path}"),
5844 )];
5845 };
5846
5847 let mut next = LiveSession::new_with_route_and_session_id(
5848 (route.factory)(),
5849 session_id.to_string(),
5850 target_id.to_string(),
5851 route.path,
5852 route.params,
5853 );
5854 next.set_telemetry_sink(telemetry.clone());
5855 if let Err(err) = next.mount() {
5856 return vec![ServerMessage::Error {
5857 message: format!("navigation mount failed: {err}"),
5858 code: Some("navigation_failed".to_string()),
5859 }];
5860 }
5861
5862 *current_route_pattern = route.pattern;
5863 *session = next;
5864 vec![
5865 ServerMessage::Navigate { to: to.to_string() },
5866 session.hello(),
5867 session.render_patch(),
5868 ]
5869}
5870
5871fn internal_path(to: &str) -> Option<String> {
5872 if !to.starts_with('/') || to.starts_with("//") {
5873 return None;
5874 }
5875
5876 let path = to
5877 .split(['?', '#'])
5878 .next()
5879 .filter(|path| !path.is_empty())?;
5880 Some(normalize_path(path.to_string()))
5881}
5882
5883fn navigation_error(code: impl Into<String>, message: impl Into<String>) -> ServerMessage {
5884 ServerMessage::Error {
5885 message: message.into(),
5886 code: Some(code.into()),
5887 }
5888}
5889
5890fn parse_query(query: &str) -> HashMap<String, String> {
5891 query
5892 .split('&')
5893 .filter(|part| !part.is_empty())
5894 .filter_map(|part| {
5895 let (key, value) = part.split_once('=')?;
5896 Some((percent_decode(key)?, percent_decode(value)?))
5897 })
5898 .collect()
5899}
5900
5901fn percent_decode(value: &str) -> Option<String> {
5902 let bytes = value.as_bytes();
5903 let mut out = Vec::with_capacity(bytes.len());
5904 let mut index = 0;
5905
5906 while index < bytes.len() {
5907 match bytes[index] {
5908 b'+' => {
5909 out.push(b' ');
5910 index += 1;
5911 }
5912 b'%' => {
5913 if index + 2 >= bytes.len() {
5914 return None;
5915 }
5916 let high = from_hex(bytes[index + 1])?;
5917 let low = from_hex(bytes[index + 2])?;
5918 out.push((high << 4) | low);
5919 index += 3;
5920 }
5921 byte => {
5922 out.push(byte);
5923 index += 1;
5924 }
5925 }
5926 }
5927
5928 String::from_utf8(out).ok()
5929}
5930
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its numeric value.
///
/// Returns `None` for any other byte.
fn from_hex(byte: u8) -> Option<u8> {
    // Pick the byte that maps to zero for whichever digit range `byte` falls in.
    let zero_point = match byte {
        b'0'..=b'9' => b'0',
        b'a'..=b'f' => b'a' - 10,
        b'A'..=b'F' => b'A' - 10,
        _ => return None,
    };
    Some(byte - zero_point)
}
5939
5940fn verify_websocket_tokens(
5941 signer: &TokenSigner,
5942 query: &HashMap<String, String>,
5943 route_path: &str,
5944) -> Option<SignedSession> {
5945 let session = signer.verify_session(query.get("session")?)?;
5946 if session.path != route_path {
5947 return None;
5948 }
5949
5950 let csrf = query.get("csrf")?;
5951 if !signer.verify_csrf(csrf, &session.session_id, route_path) {
5952 return None;
5953 }
5954
5955 Some(session)
5956}
5957
5958fn origin_allowed(headers: &HeaderMap, security: &SecurityConfig) -> bool {
5959 let Some(origin) = headers
5960 .get(header::ORIGIN)
5961 .and_then(|value| value.to_str().ok())
5962 else {
5963 return true;
5964 };
5965
5966 if security
5967 .allowed_origins
5968 .iter()
5969 .any(|allowed| allowed == origin)
5970 {
5971 return true;
5972 }
5973
5974 let Some(host) = headers
5975 .get(header::HOST)
5976 .and_then(|value| value.to_str().ok())
5977 else {
5978 return false;
5979 };
5980
5981 origin == format!("http://{host}") || origin == format!("https://{host}")
5982}
5983
5984fn security_operation_for_message(
5985 message: &ClientMessage,
5986) -> (SecurityOperation, Option<&str>, Option<&str>) {
5987 match message {
5988 ClientMessage::Connect { .. } => (SecurityOperation::Connect, None, None),
5989 ClientMessage::Event { event, target, .. } => (
5990 SecurityOperation::Event,
5991 Some(event.as_str()),
5992 target.as_deref(),
5993 ),
5994 ClientMessage::Ping { .. } => (SecurityOperation::Ping, None, None),
5995 ClientMessage::PatchUrl { .. } => (SecurityOperation::PatchUrl, None, None),
5996 ClientMessage::Navigate { .. } => (SecurityOperation::Navigate, None, None),
5997 ClientMessage::UploadStart { event, target, .. } => (
5998 SecurityOperation::UploadStart,
5999 Some(event.as_str()),
6000 target.as_deref(),
6001 ),
6002 ClientMessage::UploadChunk { .. } => (SecurityOperation::UploadChunk, None, None),
6003 ClientMessage::UploadComplete { .. } => (SecurityOperation::UploadComplete, None, None),
6004 }
6005}
6006
/// Borrowed arguments for one authorization check.
///
/// `authorization_denied` copies each field into an owned `AuthorizationContext`
/// before calling the configured authorization hook.
struct AuthorizationInput<'a> {
    /// Copied into `AuthorizationContext::route_path`.
    route_path: &'a str,
    /// Copied into `AuthorizationContext::session_id`.
    session_id: &'a str,
    /// Optional tenant identifier; `None` is forwarded as `None`.
    tenant_id: Option<&'a str>,
    /// Static label for the message being checked (forwarded unchanged).
    message_kind: &'static str,
    /// Operation being authorized (forwarded unchanged).
    operation: SecurityOperation,
    /// Optional event name, forwarded as an owned string when present.
    event_name: Option<&'a str>,
    /// Optional event target, forwarded as an owned string when present.
    event_target: Option<&'a str>,
}
6016
6017fn authorization_denied(
6018 security: &SecurityConfig,
6019 input: AuthorizationInput<'_>,
6020) -> Option<ServerMessage> {
6021 let hook = security.authorization.as_ref()?;
6022 let decision = hook(&AuthorizationContext {
6023 route_path: input.route_path.to_string(),
6024 session_id: input.session_id.to_string(),
6025 tenant_id: input.tenant_id.map(ToString::to_string),
6026 message_kind: input.message_kind,
6027 operation: input.operation,
6028 event_name: input.event_name.map(ToString::to_string),
6029 event_target: input.event_target.map(ToString::to_string),
6030 });
6031 if decision.allowed {
6032 None
6033 } else {
6034 Some(ServerMessage::Error {
6035 message: decision
6036 .message
6037 .unwrap_or_else(|| "authorization policy rejected message".to_string()),
6038 code: Some(decision.code.unwrap_or_else(|| "unauthorized".to_string())),
6039 })
6040 }
6041}
6042
6043fn quota_denied(
6044 security: &SecurityConfig,
6045 route_path: &str,
6046 session_id: &str,
6047 tenant_id: Option<&str>,
6048 message_kind: &'static str,
6049 operation: SecurityOperation,
6050 event_name: Option<&str>,
6051) -> Option<ServerMessage> {
6052 let hook = security.quota_policy.as_ref()?;
6053 let decision = hook(&QuotaContext {
6054 route_path: route_path.to_string(),
6055 session_id: session_id.to_string(),
6056 tenant_id: tenant_id.map(ToString::to_string),
6057 message_kind,
6058 operation,
6059 event_name: event_name.map(ToString::to_string),
6060 });
6061 if decision.allowed {
6062 None
6063 } else {
6064 Some(ServerMessage::Error {
6065 message: decision
6066 .message
6067 .unwrap_or_else(|| "quota policy rejected message".to_string()),
6068 code: Some(
6069 decision
6070 .code
6071 .unwrap_or_else(|| "quota_exceeded".to_string()),
6072 ),
6073 })
6074 }
6075}
6076
6077fn overload_priority_for_client_message(message: &ClientMessage) -> OverloadPriority {
6078 match message {
6079 ClientMessage::Event {
6080 event, metadata, ..
6081 } => {
6082 let metadata_priority = metadata
6083 .get("priority")
6084 .and_then(JsonValue::as_str)
6085 .map(|value| value.trim().to_ascii_lowercase());
6086 if matches!(metadata_priority.as_deref(), Some("background"))
6087 || event.starts_with("bg:")
6088 || event.starts_with("cron:")
6089 {
6090 OverloadPriority::Background
6091 } else {
6092 OverloadPriority::Interactive
6093 }
6094 }
6095 _ => OverloadPriority::Interactive,
6096 }
6097}
6098
6099fn tenant_id_for_client_message(message: &ClientMessage) -> Option<String> {
6100 match message {
6101 ClientMessage::Connect { tenant_id, .. } => normalize_tenant_id(tenant_id.clone()),
6102 ClientMessage::Event {
6103 metadata, value, ..
6104 } => metadata
6105 .get("tenant_id")
6106 .and_then(JsonValue::as_str)
6107 .or_else(|| value.get("tenant_id").and_then(JsonValue::as_str))
6108 .and_then(|value| normalize_tenant_id(Some(value.to_string()))),
6109 _ => None,
6110 }
6111}
6112
/// Trims an owned tenant id, mapping missing or blank ids to `None`.
fn normalize_tenant_id(tenant_id: Option<String>) -> Option<String> {
    let trimmed = tenant_id?.trim().to_string();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
6118
/// Trims a borrowed tenant id and returns it as an owned string; missing or blank
/// ids become `None`.
fn normalize_tenant_id_ref(tenant_id: Option<&str>) -> Option<String> {
    match tenant_id.map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => Some(trimmed.to_string()),
        _ => None,
    }
}
6125
6126fn tenant_context_conflict(
6127 session_tenant_id: Option<&str>,
6128 message_tenant_id: Option<&str>,
6129) -> bool {
6130 match (
6131 normalize_tenant_id_ref(session_tenant_id),
6132 normalize_tenant_id_ref(message_tenant_id),
6133 ) {
6134 (Some(session_tenant), Some(message_tenant)) => session_tenant != message_tenant,
6135 _ => false,
6136 }
6137}
6138
6139fn effective_tenant_id_for_message(
6140 session_tenant_id: Option<&str>,
6141 message_tenant_id: Option<&str>,
6142) -> Option<String> {
6143 normalize_tenant_id_ref(message_tenant_id)
6144 .or_else(|| normalize_tenant_id_ref(session_tenant_id))
6145}
6146
6147fn outbound_queue_depth(sender: &mpsc::Sender<OutboundEnvelope>, queue_capacity: usize) -> usize {
6148 queue_capacity.saturating_sub(sender.capacity())
6149}
6150
/// Builds a `+`-joined label listing which overload limits are saturated.
///
/// Labels always appear in the order `queue_depth`, `session_budget`,
/// `tenant_budget`. Returns `"none"` when nothing is saturated.
fn overload_saturation_reason(
    session_saturated: bool,
    tenant_saturated: bool,
    queue_saturated: bool,
) -> String {
    let labelled = [
        (queue_saturated, "queue_depth"),
        (session_saturated, "session_budget"),
        (tenant_saturated, "tenant_budget"),
    ];
    let active: Vec<&str> = labelled
        .iter()
        .filter(|(saturated, _)| *saturated)
        .map(|(_, label)| *label)
        .collect();

    if active.is_empty() {
        "none".to_string()
    } else {
        active.join("+")
    }
}
6172
6173async fn overload_decision_for_dispatch(
6174 overload: &OverloadConfig,
6175 context: &OverloadContext,
6176) -> OverloadDecision {
6177 let mut state = overload.state.lock().await;
6178 let now = Instant::now();
6179 let (projected_session_bytes, projected_session_cpu_ms, projected_session_queue) = {
6180 let session_window = state
6181 .session_windows
6182 .entry(context.session_id.clone())
6183 .or_insert_with(|| OverloadBudgetWindow::new(now));
6184 session_window.roll_window_if_needed(now, overload.window);
6185 (
6186 session_window.bytes.saturating_add(context.inbound_bytes),
6187 session_window.cpu_ms,
6188 session_window.queue_depth_peak.max(context.queue_depth),
6189 )
6190 };
6191 let session_saturated = projected_session_queue > overload.budgets.session_queue_depth
6192 || projected_session_bytes > overload.budgets.session_bytes_per_sec
6193 || projected_session_cpu_ms > overload.budgets.session_cpu_ms_per_sec;
6194
6195 let mut tenant_saturated = false;
6196 if let Some(tenant_id) = context.tenant_id.as_deref() {
6197 let tenant_window = state
6198 .tenant_windows
6199 .entry(tenant_id.to_string())
6200 .or_insert_with(|| OverloadBudgetWindow::new(now));
6201 tenant_window.roll_window_if_needed(now, overload.window);
6202 let projected_tenant_bytes = tenant_window.bytes.saturating_add(context.inbound_bytes);
6203 let projected_tenant_cpu_ms = tenant_window.cpu_ms;
6204 let projected_tenant_queue = tenant_window.queue_depth_peak.max(context.queue_depth);
6205 tenant_saturated = projected_tenant_queue > overload.budgets.tenant_queue_depth
6206 || projected_tenant_bytes > overload.budgets.tenant_bytes_per_sec
6207 || projected_tenant_cpu_ms > overload.budgets.tenant_cpu_ms_per_sec;
6208 }
6209
6210 let queue_saturated = context.queue_depth >= context.queue_capacity;
6211 let saturated = queue_saturated || session_saturated || tenant_saturated;
6212 let saturation_reason =
6213 overload_saturation_reason(session_saturated, tenant_saturated, queue_saturated);
6214
6215 let base_decision = if !saturated {
6216 OverloadDecision::allow()
6217 } else if context.priority == OverloadPriority::Background
6218 && overload.shed_policy == OverloadShedPolicy::PreferInteractive
6219 {
6220 OverloadDecision::shed(format!("background_shed:{saturation_reason}"))
6221 } else if queue_saturated || overload.shed_policy == OverloadShedPolicy::Strict {
6222 OverloadDecision::shed(format!("shed:{saturation_reason}"))
6223 } else {
6224 OverloadDecision::throttle(10, format!("throttle:{saturation_reason}"))
6225 };
6226
6227 let decision = if let Some(policy_hook) = overload.policy_hook.as_ref() {
6228 policy_hook(context, &base_decision)
6229 } else {
6230 base_decision
6231 };
6232
6233 if decision.allowed {
6234 let session_window = state
6235 .session_windows
6236 .entry(context.session_id.clone())
6237 .or_insert_with(|| OverloadBudgetWindow::new(now));
6238 session_window.roll_window_if_needed(now, overload.window);
6239 session_window.events = session_window.events.saturating_add(1);
6240 session_window.bytes = session_window.bytes.saturating_add(context.inbound_bytes);
6241 session_window.queue_depth_peak = session_window.queue_depth_peak.max(context.queue_depth);
6242
6243 if let Some(tenant_id) = context.tenant_id.as_deref() {
6244 let tenant_window = state
6245 .tenant_windows
6246 .entry(tenant_id.to_string())
6247 .or_insert_with(|| OverloadBudgetWindow::new(now));
6248 tenant_window.roll_window_if_needed(now, overload.window);
6249 tenant_window.events = tenant_window.events.saturating_add(1);
6250 tenant_window.bytes = tenant_window.bytes.saturating_add(context.inbound_bytes);
6251 tenant_window.queue_depth_peak =
6252 tenant_window.queue_depth_peak.max(context.queue_depth);
6253 }
6254 }
6255
6256 decision
6257}
6258
6259async fn overload_record_cpu_usage(
6260 overload: &OverloadConfig,
6261 session_id: &str,
6262 tenant_id: Option<&str>,
6263 cpu_ms: u64,
6264) {
6265 let mut state = overload.state.lock().await;
6266 let now = Instant::now();
6267
6268 let session_window = state
6269 .session_windows
6270 .entry(session_id.to_string())
6271 .or_insert_with(|| OverloadBudgetWindow::new(now));
6272 session_window.roll_window_if_needed(now, overload.window);
6273 session_window.cpu_ms = session_window.cpu_ms.saturating_add(cpu_ms);
6274
6275 if let Some(tenant_id) = tenant_id {
6276 let tenant_window = state
6277 .tenant_windows
6278 .entry(tenant_id.to_string())
6279 .or_insert_with(|| OverloadBudgetWindow::new(now));
6280 tenant_window.roll_window_if_needed(now, overload.window);
6281 tenant_window.cpu_ms = tenant_window.cpu_ms.saturating_add(cpu_ms);
6282 }
6283}
6284
6285fn overload_decision_to_error(decision: &OverloadDecision) -> ServerMessage {
6286 ServerMessage::Error {
6287 message: decision
6288 .message
6289 .clone()
6290 .unwrap_or_else(|| "server overloaded".to_string()),
6291 code: decision.code.clone(),
6292 }
6293}
6294
6295fn overload_telemetry_event(
6296 correlation: &CorrelationContext,
6297 context: &OverloadContext,
6298 decision: &OverloadDecision,
6299) -> Option<TelemetryEvent> {
6300 if decision.allowed && decision.throttle_ms == 0 {
6301 return None;
6302 }
6303 let action = if !decision.allowed {
6304 "shed"
6305 } else if decision.throttle_ms > 0 {
6306 "throttle"
6307 } else {
6308 "allow"
6309 };
6310 let mut event = TelemetryEvent::new(TelemetryEventKind::Error)
6311 .with_session(context.session_id.clone())
6312 .with_route(context.route_path.clone())
6313 .with_ok(decision.allowed)
6314 .with_attribute(
6315 "phase".to_string(),
6316 JsonValue::String("overload_control".to_string()),
6317 )
6318 .with_attribute("action".to_string(), JsonValue::String(action.to_string()))
6319 .with_attribute(
6320 "priority".to_string(),
6321 JsonValue::String(context.priority.as_str().to_string()),
6322 )
6323 .with_attribute(
6324 "queue_depth".to_string(),
6325 JsonValue::Number(serde_json::Number::from(context.queue_depth)),
6326 )
6327 .with_attribute(
6328 "queue_capacity".to_string(),
6329 JsonValue::Number(serde_json::Number::from(context.queue_capacity)),
6330 )
6331 .with_attribute(
6332 "inbound_bytes".to_string(),
6333 JsonValue::Number(serde_json::Number::from(context.inbound_bytes)),
6334 );
6335 if decision.throttle_ms > 0 {
6336 event = event.with_latency_ms(decision.throttle_ms);
6337 }
6338 if let Some(event_name) = context.event_name.as_ref() {
6339 event = event.with_event_name(event_name.clone());
6340 }
6341 if let Some(tenant_id) = context.tenant_id.as_ref() {
6342 event = event.with_attribute(
6343 "tenant_id".to_string(),
6344 JsonValue::String(tenant_id.clone()),
6345 );
6346 }
6347 if let Some(reason) = decision.reason.as_ref() {
6348 event = event.with_attribute("reason".to_string(), JsonValue::String(reason.clone()));
6349 }
6350 Some(correlation.apply_to_event(event))
6351}
6352
6353fn emit_overload_telemetry(
6354 telemetry: &Arc<TelemetryPipeline>,
6355 correlation: &CorrelationContext,
6356 context: &OverloadContext,
6357 decision: &OverloadDecision,
6358) {
6359 if let Some(event) = overload_telemetry_event(correlation, context, decision) {
6360 let _ = telemetry.emit(event);
6361 }
6362}
6363
6364async fn apply_overload_throttle(decision: &OverloadDecision) {
6365 if decision.allowed && decision.throttle_ms > 0 {
6366 sleep(Duration::from_millis(decision.throttle_ms)).await;
6367 }
6368}
6369
6370#[allow(clippy::too_many_arguments)]
6371fn security_audit_event(
6372 correlation: &CorrelationContext,
6373 route_path: &str,
6374 session_id: Option<&str>,
6375 tenant_id: Option<&str>,
6376 control: &'static str,
6377 ok: bool,
6378 code: Option<&str>,
6379 operation: SecurityOperation,
6380 message_kind: &'static str,
6381 event_name: Option<&str>,
6382 policy_reason: Option<&str>,
6383) -> TelemetryEvent {
6384 let mut event = TelemetryEvent::new(TelemetryEventKind::SecurityAudit)
6385 .with_route(route_path.to_string())
6386 .with_ok(ok)
6387 .with_count(1)
6388 .with_attribute(
6389 "control".to_string(),
6390 JsonValue::String(control.to_string()),
6391 )
6392 .with_attribute(
6393 "operation".to_string(),
6394 JsonValue::String(operation.as_str().to_string()),
6395 )
6396 .with_attribute(
6397 "message_kind".to_string(),
6398 JsonValue::String(message_kind.to_string()),
6399 );
6400
6401 if let Some(session_id) = session_id {
6402 event = event.with_session(session_id.to_string());
6403 }
6404 if let Some(tenant_id) = tenant_id {
6405 event = event.with_attribute(
6406 "tenant_id".to_string(),
6407 JsonValue::String(tenant_id.to_string()),
6408 );
6409 }
6410 if let Some(code) = code {
6411 event = event.with_attribute("code".to_string(), JsonValue::String(code.to_string()));
6412 }
6413 if let Some(event_name) = event_name {
6414 event = event.with_event_name(event_name.to_string());
6415 }
6416 if let Some(policy_reason) = policy_reason {
6417 event = event.with_attribute(
6418 "policy_reason".to_string(),
6419 JsonValue::String(policy_reason.to_string()),
6420 );
6421 }
6422
6423 correlation.apply_to_event(event)
6424}
6425
6426#[allow(clippy::too_many_arguments)]
6427fn emit_security_audit(
6428 telemetry: &Arc<TelemetryPipeline>,
6429 correlation: &CorrelationContext,
6430 route_path: &str,
6431 session_id: Option<&str>,
6432 tenant_id: Option<&str>,
6433 control: &'static str,
6434 ok: bool,
6435 code: Option<&str>,
6436 operation: SecurityOperation,
6437 message_kind: &'static str,
6438 event_name: Option<&str>,
6439 policy_reason: Option<&str>,
6440) {
6441 let _ = telemetry.emit(security_audit_event(
6442 correlation,
6443 route_path,
6444 session_id,
6445 tenant_id,
6446 control,
6447 ok,
6448 code,
6449 operation,
6450 message_kind,
6451 event_name,
6452 policy_reason,
6453 ));
6454}
6455
6456fn rate_limited(
6457 security: &SecurityConfig,
6458 route_path: &str,
6459 session_id: &str,
6460 message_kind: &'static str,
6461) -> bool {
6462 let Some(rate_limiter) = &security.rate_limiter else {
6463 return false;
6464 };
6465
6466 !rate_limiter(&RateLimitContext {
6467 route_path: route_path.to_string(),
6468 session_id: session_id.to_string(),
6469 message_kind,
6470 })
6471}
6472
6473fn payload_too_large_error(size: usize, max_message_size: usize) -> ServerMessage {
6474 ServerMessage::Error {
6475 message: format!("payload too large: {size} bytes exceeds {max_message_size} byte limit"),
6476 code: Some("payload_too_large".to_string()),
6477 }
6478}
6479
6480fn unsupported_binary_error() -> ServerMessage {
6481 ServerMessage::Error {
6482 message: "binary protocol messages are not supported by Shelly v0".to_string(),
6483 code: Some("unsupported_message_type".to_string()),
6484 }
6485}
6486
6487fn rate_limited_error() -> ServerMessage {
6488 ServerMessage::Error {
6489 message: "rate limit exceeded".to_string(),
6490 code: Some("rate_limited".to_string()),
6491 }
6492}
6493
6494fn payload_too_large_close() -> CloseFrame {
6495 CloseFrame {
6496 code: close_code::SIZE,
6497 reason: "payload too large".into(),
6498 }
6499}
6500
/// Owned copy of a WebSocket close frame's code and reason.
///
/// `queue_close_frame` builds it from a `CloseFrame`, and `run_outbound_writer`
/// converts it back into a `CloseFrame` when writing to the socket.
#[derive(Debug, Clone)]
struct OutboundCloseFrame {
    /// Close status code copied from the source frame.
    code: u16,
    /// Close reason text copied from the source frame.
    reason: String,
}
6506
/// One unit of work placed on the outbound queue for the writer task.
#[derive(Debug, Clone)]
enum OutboundEnvelope {
    /// Already-serialized text payload, written as a text frame.
    Text(String),
    /// Payload for a pong frame.
    Pong(Vec<u8>),
    /// Close request, optionally carrying a close code and reason.
    Close(Option<OutboundCloseFrame>),
}
6513
6514impl OutboundEnvelope {
6515 fn estimated_bytes(&self) -> usize {
6516 match self {
6517 Self::Text(encoded) => encoded.len(),
6518 Self::Pong(payload) => payload.len(),
6519 Self::Close(frame) => frame
6520 .as_ref()
6521 .map(|close| close.reason.len().saturating_add(2))
6522 .unwrap_or(0),
6523 }
6524 }
6525}
6526
/// Outcome of trying to place an envelope on the outbound queue.
enum OutboundQueuePush {
    /// The envelope was accepted by the queue.
    Queued,
    /// The queue was full and the overflow policy is `DropNewest`; the envelope was discarded.
    Dropped,
    /// The queue was full under the `Disconnect` policy, or the channel is closed.
    Disconnect,
}
6532
6533fn emit_outbound_batch_telemetry(
6534 telemetry: &Arc<TelemetryPipeline>,
6535 correlation: &CorrelationContext,
6536 route_path: &str,
6537 session_id: &str,
6538 message_count: usize,
6539 bytes: usize,
6540 flush_count: usize,
6541) {
6542 let _ = telemetry.emit(
6543 correlation.apply_to_event(
6544 TelemetryEvent::new(TelemetryEventKind::Error)
6545 .with_session(session_id)
6546 .with_route(route_path)
6547 .with_ok(true)
6548 .with_count(message_count)
6549 .with_bytes(bytes)
6550 .with_attribute(
6551 "phase".to_string(),
6552 JsonValue::String("ws_outbound_batch".to_string()),
6553 )
6554 .with_attribute(
6555 "flush_count".to_string(),
6556 JsonValue::Number(serde_json::Number::from(flush_count)),
6557 ),
6558 ),
6559 );
6560}
6561
6562fn emit_outbound_overflow_telemetry(
6563 telemetry: &Arc<TelemetryPipeline>,
6564 correlation: &CorrelationContext,
6565 route_path: &str,
6566 session_id: &str,
6567 policy: OutboundOverflowPolicy,
6568 queue_capacity: usize,
6569) {
6570 let _ = telemetry.emit(
6571 correlation.apply_to_event(
6572 TelemetryEvent::new(TelemetryEventKind::Error)
6573 .with_session(session_id)
6574 .with_route(route_path)
6575 .with_ok(false)
6576 .with_attribute(
6577 "phase".to_string(),
6578 JsonValue::String("ws_outbound_overflow".to_string()),
6579 )
6580 .with_attribute(
6581 "overflow_policy".to_string(),
6582 JsonValue::String(
6583 match policy {
6584 OutboundOverflowPolicy::Disconnect => "disconnect",
6585 OutboundOverflowPolicy::DropNewest => "drop_newest",
6586 }
6587 .to_string(),
6588 ),
6589 )
6590 .with_attribute(
6591 "queue_capacity".to_string(),
6592 JsonValue::Number(serde_json::Number::from(queue_capacity)),
6593 ),
6594 ),
6595 );
6596}
6597
6598fn queue_outbound_envelope(
6599 sender: &mpsc::Sender<OutboundEnvelope>,
6600 envelope: OutboundEnvelope,
6601 config: &OutboundConfig,
6602 telemetry: &Arc<TelemetryPipeline>,
6603 correlation: &CorrelationContext,
6604 route_path: &str,
6605 session_id: &str,
6606) -> OutboundQueuePush {
6607 match sender.try_send(envelope) {
6608 Ok(()) => OutboundQueuePush::Queued,
6609 Err(mpsc::error::TrySendError::Full(_)) => {
6610 emit_outbound_overflow_telemetry(
6611 telemetry,
6612 correlation,
6613 route_path,
6614 session_id,
6615 config.overflow_policy,
6616 config.queue_capacity,
6617 );
6618 match config.overflow_policy {
6619 OutboundOverflowPolicy::Disconnect => OutboundQueuePush::Disconnect,
6620 OutboundOverflowPolicy::DropNewest => OutboundQueuePush::Dropped,
6621 }
6622 }
6623 Err(mpsc::error::TrySendError::Closed(_)) => OutboundQueuePush::Disconnect,
6624 }
6625}
6626
6627fn queue_server_message(
6628 sender: &mpsc::Sender<OutboundEnvelope>,
6629 message: &ServerMessage,
6630 config: &OutboundConfig,
6631 telemetry: &Arc<TelemetryPipeline>,
6632 correlation: &CorrelationContext,
6633 route_path: &str,
6634 session_id: &str,
6635) -> OutboundQueuePush {
6636 let encoded = serde_json::to_string(message).expect("server messages should serialize");
6637 queue_outbound_envelope(
6638 sender,
6639 OutboundEnvelope::Text(encoded),
6640 config,
6641 telemetry,
6642 correlation,
6643 route_path,
6644 session_id,
6645 )
6646}
6647
6648fn queue_pong_frame(
6649 sender: &mpsc::Sender<OutboundEnvelope>,
6650 payload: Vec<u8>,
6651 config: &OutboundConfig,
6652 telemetry: &Arc<TelemetryPipeline>,
6653 correlation: &CorrelationContext,
6654 route_path: &str,
6655 session_id: &str,
6656) -> OutboundQueuePush {
6657 queue_outbound_envelope(
6658 sender,
6659 OutboundEnvelope::Pong(payload),
6660 config,
6661 telemetry,
6662 correlation,
6663 route_path,
6664 session_id,
6665 )
6666}
6667
6668fn queue_close_frame(
6669 sender: &mpsc::Sender<OutboundEnvelope>,
6670 frame: Option<CloseFrame>,
6671 config: &OutboundConfig,
6672 telemetry: &Arc<TelemetryPipeline>,
6673 correlation: &CorrelationContext,
6674 route_path: &str,
6675 session_id: &str,
6676) -> OutboundQueuePush {
6677 let close = frame.map(|value| OutboundCloseFrame {
6678 code: value.code,
6679 reason: value.reason.to_string(),
6680 });
6681 queue_outbound_envelope(
6682 sender,
6683 OutboundEnvelope::Close(close),
6684 config,
6685 telemetry,
6686 correlation,
6687 route_path,
6688 session_id,
6689 )
6690}
6691
6692async fn run_outbound_writer(
6693 mut sender: futures_util::stream::SplitSink<WebSocket, Message>,
6694 mut receiver: mpsc::Receiver<OutboundEnvelope>,
6695 config: OutboundConfig,
6696 telemetry: Arc<TelemetryPipeline>,
6697 correlation: CorrelationContext,
6698 route_path: String,
6699 session_id: String,
6700) -> Result<(), axum::Error> {
6701 let mut pending: Option<OutboundEnvelope> = None;
6702 let mut should_stop = false;
6703
6704 while !should_stop {
6705 let first = if let Some(envelope) = pending.take() {
6706 Some(envelope)
6707 } else {
6708 receiver.recv().await
6709 };
6710 let Some(first) = first else {
6711 break;
6712 };
6713
6714 let mut batch = vec![first];
6715 let mut batch_bytes = batch[0].estimated_bytes();
6716 let deadline = Instant::now() + config.batch_flush_interval;
6717 let mut channel_closed = false;
6718
6719 loop {
6720 if batch.len() >= config.batch_max_messages || batch_bytes >= config.batch_max_bytes {
6721 break;
6722 }
6723
6724 match receiver.try_recv() {
6725 Ok(next) => {
6726 let next_bytes = next.estimated_bytes();
6727 let exceeds_limits = batch.len() + 1 > config.batch_max_messages
6728 || batch_bytes.saturating_add(next_bytes) > config.batch_max_bytes;
6729 if exceeds_limits {
6730 pending = Some(next);
6731 break;
6732 }
6733 batch_bytes = batch_bytes.saturating_add(next_bytes);
6734 batch.push(next);
6735 }
6736 Err(mpsc::error::TryRecvError::Empty) => {
6737 if batch.len() == 1 {
6738 break;
6739 }
6740 let remaining = deadline.saturating_duration_since(Instant::now());
6741 if remaining.is_zero() {
6742 break;
6743 }
6744 match timeout(remaining, receiver.recv()).await {
6745 Ok(Some(next)) => {
6746 let next_bytes = next.estimated_bytes();
6747 let exceeds_limits = batch.len() + 1 > config.batch_max_messages
6748 || batch_bytes.saturating_add(next_bytes) > config.batch_max_bytes;
6749 if exceeds_limits {
6750 pending = Some(next);
6751 break;
6752 }
6753 batch_bytes = batch_bytes.saturating_add(next_bytes);
6754 batch.push(next);
6755 }
6756 Ok(None) => {
6757 channel_closed = true;
6758 break;
6759 }
6760 Err(_) => break,
6761 }
6762 }
6763 Err(mpsc::error::TryRecvError::Disconnected) => {
6764 channel_closed = true;
6765 break;
6766 }
6767 }
6768 }
6769
6770 let mut message_count = 0_usize;
6771 let mut flush_count = 0_usize;
6772 let mut flushed_bytes = 0_usize;
6773 for envelope in batch {
6774 flushed_bytes = flushed_bytes.saturating_add(envelope.estimated_bytes());
6775 match envelope {
6776 OutboundEnvelope::Text(encoded) => {
6777 message_count = message_count.saturating_add(1);
6778 sender.feed(Message::Text(encoded.into())).await?;
6779 }
6780 OutboundEnvelope::Pong(payload) => {
6781 message_count = message_count.saturating_add(1);
6782 sender.feed(Message::Pong(payload.into())).await?;
6783 }
6784 OutboundEnvelope::Close(frame) => {
6785 message_count = message_count.saturating_add(1);
6786 let close = frame.map(|value| CloseFrame {
6787 code: value.code,
6788 reason: value.reason.into(),
6789 });
6790 sender.feed(Message::Close(close)).await?;
6791 should_stop = true;
6792 }
6793 }
6794 }
6795
6796 sender.flush().await?;
6797 flush_count = flush_count.saturating_add(1);
6798 emit_outbound_batch_telemetry(
6799 &telemetry,
6800 &correlation,
6801 &route_path,
6802 &session_id,
6803 message_count,
6804 flushed_bytes,
6805 flush_count,
6806 );
6807
6808 if channel_closed {
6809 break;
6810 }
6811 }
6812
6813 Ok(())
6814}
6815
6816async fn send_server_message(
6817 sender: &mut futures_util::stream::SplitSink<WebSocket, Message>,
6818 message: &ServerMessage,
6819) -> Result<(), axum::Error> {
6820 let encoded = serde_json::to_string(message).expect("server messages should serialize");
6821 sender.send(Message::Text(encoded.into())).await
6822}
6823
6824fn render_shell(config: ShellConfig<'_>) -> String {
6825 let target_id = shelly::escape_html(config.target_id);
6826 let path = js_string(config.path);
6827 let title = shelly::escape_html(config.title);
6828 let session_id = js_string(config.session_id);
6829 let session_token = js_string(config.session_token);
6830 let csrf_token = js_string(config.csrf_token);
6831 let protocol = js_string(config.protocol);
6832 let trace_id = js_string(config.trace_id);
6833 let span_id = js_string(config.span_id);
6834 let correlation_id = js_string(config.correlation_id.unwrap_or(""));
6835 let request_id = js_string(config.request_id.unwrap_or(""));
6836 let reconnect_base_ms = config.reconnect_base_ms;
6837 let reconnect_max_ms = config.reconnect_max_ms;
6838 let reconnect_jitter_ms = config.reconnect_jitter_ms;
6839 let heartbeat_interval_ms = config.heartbeat_interval_ms;
6840 let heartbeat_timeout_ms = config.heartbeat_timeout_ms;
6841 let transport_mode = js_string(config.transport_mode);
6842 let transport_fallbacks = config.transport_fallbacks;
6843 let progressive_enhancement = config.progressive_enhancement;
6844 let inner_html = config.inner_html;
6845
6846 format!(
6847 r#"<!doctype html>
6848<html lang="en">
6849 <head>
6850 <meta charset="utf-8" />
6851 <meta name="viewport" content="width=device-width, initial-scale=1" />
6852 <title>{title}</title>
6853 <style>
6854 :root {{
6855 color-scheme: light;
6856 font-family: "Inter", "Segoe UI", system-ui, -apple-system, sans-serif;
6857 --shelly-color-bg: #f7f9fc;
6858 --shelly-color-surface: #ffffff;
6859 --shelly-color-surface-muted: #eef3f8;
6860 --shelly-color-border: #d5dfeb;
6861 --shelly-color-border-strong: #b7c7d8;
6862 --shelly-color-text: #132438;
6863 --shelly-color-text-muted: #5a7187;
6864 --shelly-color-primary: #f2662b;
6865 --shelly-color-primary-strong: #d94f18;
6866 --shelly-color-on-primary: #fff9f6;
6867 --shelly-color-accent: #1c8ea4;
6868 --shelly-color-accent-strong: #14788b;
6869 --shelly-color-on-accent: #f3fcff;
6870 --shelly-color-danger: #c53c26;
6871 --shelly-color-success: #1f8a65;
6872 --shelly-shadow-soft: 0 10px 28px rgba(19, 36, 56, 0.1);
6873 }}
6874 * {{ box-sizing: border-box; }}
6875 body {{
6876 margin: 0;
6877 min-height: 100vh;
6878 padding: 1.5rem;
6879 background: var(--shelly-color-bg);
6880 color: var(--shelly-color-text);
6881 }}
6882 a {{
6883 color: var(--shelly-color-accent-strong);
6884 text-underline-offset: 2px;
6885 }}
6886 a:hover {{ color: var(--shelly-color-accent); }}
6887 button {{
6888 border: 1px solid var(--shelly-color-border-strong);
6889 border-radius: 0.6rem;
6890 padding: 0.5rem 0.8rem;
6891 margin-right: 0.3rem;
6892 background: var(--shelly-color-surface);
6893 color: var(--shelly-color-text);
6894 font: inherit;
6895 cursor: pointer;
6896 transition: background 120ms ease, border-color 120ms ease, color 120ms ease;
6897 }}
6898 button:hover {{
6899 border-color: var(--shelly-color-accent);
6900 background: var(--shelly-color-surface-muted);
6901 }}
6902 button:disabled {{
6903 opacity: 0.6;
6904 cursor: not-allowed;
6905 }}
6906 input, select, textarea {{
6907 border: 1px solid var(--shelly-color-border-strong);
6908 border-radius: 0.55rem;
6909 padding: 0.48rem 0.65rem;
6910 background: var(--shelly-color-surface);
6911 color: var(--shelly-color-text);
6912 font: inherit;
6913 }}
6914 .card {{
6915 max-width: 34rem;
6916 padding: 1rem;
6917 border: 1px solid var(--shelly-color-border);
6918 border-radius: 0.75rem;
6919 background: var(--shelly-color-surface);
6920 box-shadow: var(--shelly-shadow-soft);
6921 }}
6922 .muted {{ color: var(--shelly-color-text-muted); }}
6923 pre {{
6924 margin: 0;
6925 border-radius: 0.65rem;
6926 background: #0f2235;
6927 color: #f8fbff;
6928 }}
6929 </style>
6930 </head>
6931 <body>
6932 <div id="{target_id}">{inner_html}</div>
6933 <script>
6934 window.__SHELLY = {{
6935 target: "{target_id}",
6936 path: "{path}",
6937 sessionId: "{session_id}",
6938 sessionToken: "{session_token}",
6939 csrfToken: "{csrf_token}",
6940 protocol: "{protocol}",
6941 traceId: "{trace_id}",
6942 spanId: "{span_id}",
6943 correlationId: "{correlation_id}",
6944 requestId: "{request_id}",
6945 reconnectBaseMs: {reconnect_base_ms},
6946 reconnectMaxMs: {reconnect_max_ms},
6947 reconnectJitterMs: {reconnect_jitter_ms},
6948 heartbeatIntervalMs: {heartbeat_interval_ms},
6949 heartbeatTimeoutMs: {heartbeat_timeout_ms},
6950 transportMode: "{transport_mode}",
6951 transportFallbacks: {transport_fallbacks},
6952 progressiveEnhancement: {progressive_enhancement}
6953 }};
6954 </script>
6955 <script src="/__shelly/client.js"></script>
6956 </body>
6957</html>"#
6958 )
6959}
6960
/// Escapes `value` so it can be placed between double quotes as a JavaScript string
/// literal inside an inline `<script>` element.
///
/// Escaped characters:
/// * `\` and `"` — they would otherwise end or corrupt the literal;
/// * `\n` and `\r` — raw line breaks are not allowed inside a string literal;
/// * `<`, `>` and `&` — emitted as `\u003c`, `\u003e`, `\u0026` so the value can never
///   form `</script>` or other markup the HTML parser would react to;
/// * U+2028 and U+2029 — line terminators in JavaScript engines that predate ES2019,
///   where a raw occurrence ends the literal.
///
/// All other characters are copied unchanged. The string is built in a single pass.
fn js_string(value: &str) -> String {
    let mut escaped = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            '\n' => escaped.push_str("\\n"),
            '\r' => escaped.push_str("\\r"),
            '<' => escaped.push_str("\\u003c"),
            '>' => escaped.push_str("\\u003e"),
            '&' => escaped.push_str("\\u0026"),
            '\u{2028}' => escaped.push_str("\\u2028"),
            '\u{2029}' => escaped.push_str("\\u2029"),
            other => escaped.push(other),
        }
    }
    escaped
}
6971
6972#[cfg(test)]
6973mod tests {
6974 use super::{
6975 abort_pubsub_tasks, abort_runtime_tasks, apply_overload_throttle, authorization_denied,
6976 cleanup_expired_snapshots, client_js, client_message_log_fields, contains_server_error,
6977 correlation_from_connect, correlation_from_headers, correlation_from_query,
6978 durable_acquire_lease, durable_error_message, durable_message_should_journal,
6979 durable_release_lease, effective_tenant_id_for_message, emit_outbound_batch_telemetry,
6980 emit_outbound_overflow_telemetry, emit_security_audit, emit_upload_lifecycle_telemetry,
6981 from_hex, handle_navigate, handle_patch_url, handle_upload_chunk, handle_upload_complete,
6982 handle_upload_start, header_value, initial_handler, internal_path, log_http_ingress,
6983 log_ws_text_ingress, messages_for_text_payload, non_empty_text, normalize_hex_id,
6984 normalize_path, now_unix_ms, origin_allowed, outbound_queue_depth,
6985 overload_decision_for_dispatch, overload_decision_to_error,
6986 overload_priority_for_client_message, overload_telemetry_event, parse_query,
6987 parse_traceparent, payload_too_large_close, payload_too_large_error, percent_decode,
6988 process_pubsub_commands, process_runtime_commands, queue_close_frame, queue_pong_frame,
6989 queue_server_message, quota_denied, rate_limited, rate_limited_error, reconnect_event,
6990 recover_from_durable_record, render_shell, replay_durable_journal_entry, route_segments,
6991 sanitize_query_for_log, security_audit_event, security_operation_for_message,
6992 session_affinity_mismatch, tenant_id_for_client_message, transport_capabilities,
6993 unregister_pubsub_presence, unsupported_binary_error, upload_error, upload_lifecycle_event,
6994 verify_websocket_tokens, AppState, AuthorizationDecision, AuthorizationInput,
6995 ConnectHandshake, ConsoleLogFormat, CorrelationContext, DistributedConfig,
6996 DurableJournalEntry, DurableLeaseRequest, DurablePlacementDecision, DurableRuntimeConfig,
6997 DurableSessionSnapshot, DurableSessionStore, DurableStoreError, DurableTakeoverPolicy,
6998 HttpTransportConfig, InMemoryDurableSessionStore, LiveRoute, OutboundCloseFrame,
6999 OutboundConfig, OutboundEnvelope, OutboundOverflowPolicy, OutboundQueuePush,
7000 OverloadBudgets, OverloadConfig, OverloadContext, OverloadDecision, OverloadPriority,
7001 OverloadShedPolicy, QuotaDecision, ReconnectConfig, ResumeSnapshot, RouteSegment,
7002 SecurityConfig, SecurityOperation, SessionAffinityMode, SessionTelemetrySink, ShellConfig,
7003 ShellyRouter, SignedSession, SocketConfig, TelemetryConfig, TelemetryExporter,
7004 TelemetryPipeline, TenantQuotaBudgets, TenantQuotaPolicy, TokenSigner, TransportMode,
7005 UploadConfig, UploadStartRequest, CLIENT_JS, DEFAULT_HEARTBEAT_INTERVAL_MS,
7006 DEFAULT_HEARTBEAT_TIMEOUT_MS, DEFAULT_HTTP3_ALT_SVC, DEFAULT_RECONNECT_BASE_MS,
7007 DEFAULT_RECONNECT_JITTER_MS, DEFAULT_RECONNECT_MAX_MS, PROTOCOL_VERSION,
7008 };
7009 use axum::{
7010 body::to_bytes,
7011 extract::{ws::close_code, State},
7012 http::{header, HeaderMap, HeaderValue, Method, StatusCode, Uri},
7013 response::IntoResponse,
7014 };
7015 use futures_util::{SinkExt, StreamExt};
7016 use shelly::{
7017 ClientMessage, Context, Event, Html, LiveResult, LiveSession, LiveView, PubSub,
7018 PubSubCommand, ResumeStatus, RuntimeCommand, RuntimeEvent, ServerMessage,
7019 TelemetryEventKind, TelemetrySink,
7020 };
7021 use std::{
7022 collections::{HashMap, HashSet},
7023 path::PathBuf,
7024 sync::Arc,
7025 };
7026 use tokio::{
7027 sync::mpsc,
7028 task::JoinHandle,
7029 time::{advance, Duration, Instant},
7030 };
7031 use tokio_tungstenite::{
7032 connect_async,
7033 tungstenite::{client::IntoClientRequest, Error as WsError, Message as WsMessage},
7034 };
7035 use uuid::Uuid;
7036
7037 fn test_correlation() -> CorrelationContext {
7038 CorrelationContext {
7039 trace_id: "0123456789abcdef0123456789abcdef".to_string(),
7040 span_id: "0123456789abcdef".to_string(),
7041 parent_span_id: None,
7042 correlation_id: Some("corr-123".to_string()),
7043 request_id: Some("req-123".to_string()),
7044 }
7045 }
7046
7047 fn test_overload_context(
7048 session_id: &str,
7049 tenant_id: Option<&str>,
7050 priority: OverloadPriority,
7051 inbound_bytes: usize,
7052 queue_depth: usize,
7053 queue_capacity: usize,
7054 ) -> OverloadContext {
7055 OverloadContext {
7056 route_path: "/bench".to_string(),
7057 session_id: session_id.to_string(),
7058 tenant_id: tenant_id.map(ToString::to_string),
7059 message_kind: "text",
7060 operation: SecurityOperation::Event,
7061 event_name: Some("bench.event".to_string()),
7062 priority,
7063 queue_depth,
7064 queue_capacity,
7065 inbound_bytes,
7066 }
7067 }
7068
7069 fn test_app_state_with(
7070 pattern: &str,
7071 security: SecurityConfig,
7072 distributed: DistributedConfig,
7073 ) -> Arc<AppState> {
7074 Arc::new(AppState {
7075 routes: Arc::new(vec![LiveRoute::new(
7076 pattern.to_string(),
7077 Arc::new(|| Box::<Counter>::default()),
7078 )]),
7079 target_id: "root".to_string(),
7080 max_message_size: 64 * 1024,
7081 pubsub: PubSub::default(),
7082 uploads: UploadConfig::default(),
7083 security,
7084 telemetry: Arc::new(TelemetryPipeline::disabled()),
7085 reconnect: ReconnectConfig::default(),
7086 distributed,
7087 durable: DurableRuntimeConfig::default(),
7088 outbound: OutboundConfig::default(),
7089 render: super::RenderConfig::default(),
7090 overload: OverloadConfig::default(),
7091 transport_http: HttpTransportConfig::default(),
7092 })
7093 }
7094
7095 async fn spawn_test_router(router: ShellyRouter) -> (String, JoinHandle<()>) {
7096 let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
7097 .await
7098 .expect("bind test listener");
7099 let addr = listener.local_addr().expect("listener addr");
7100 let app = router.into_router();
7101 let handle = tokio::spawn(async move {
7102 let _ = axum::serve(listener, app).await;
7103 });
7104 (format!("ws://{addr}"), handle)
7105 }
7106
7107 fn ws_error_status(err: WsError) -> StatusCode {
7108 match err {
7109 WsError::Http(response) => response.status(),
7110 other => panic!("unexpected websocket error: {other:?}"),
7111 }
7112 }
7113
7114 async fn await_hello(
7115 socket: &mut tokio_tungstenite::WebSocketStream<
7116 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
7117 >,
7118 ) {
7119 for _ in 0..8 {
7120 let frame = socket.next().await.expect("hello frame");
7121 if let Ok(WsMessage::Text(text)) = frame {
7122 let message: ServerMessage =
7123 serde_json::from_str(text.as_str()).expect("hello frame should decode");
7124 if matches!(message, ServerMessage::Hello { .. }) {
7125 return;
7126 }
7127 }
7128 }
7129 panic!("expected hello frame");
7130 }
7131
7132 async fn await_error_code(
7133 socket: &mut tokio_tungstenite::WebSocketStream<
7134 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
7135 >,
7136 code: &str,
7137 ) {
7138 for _ in 0..12 {
7139 let Some(frame) = socket.next().await else {
7140 break;
7141 };
7142 if let Ok(WsMessage::Text(text)) = frame {
7143 let message: ServerMessage =
7144 serde_json::from_str(text.as_str()).expect("error frame should decode");
7145 if matches!(
7146 message,
7147 ServerMessage::Error {
7148 code: Some(ref candidate),
7149 ..
7150 } if candidate == code
7151 ) {
7152 return;
7153 }
7154 }
7155 }
7156 panic!("expected server error code `{code}`");
7157 }
7158
7159 #[test]
7160 fn normalize_path_adds_leading_slash() {
7161 assert_eq!(normalize_path("users".to_string()), "/users");
7162 assert_eq!(normalize_path("/users".to_string()), "/users");
7163 assert_eq!(normalize_path("".to_string()), "/");
7164 }
7165
7166 #[test]
7167 fn durable_and_overload_helper_branches_are_exercised() {
7168 let durable_err = DurableStoreError::new("store_unavailable", "durable store down");
7169 let message = durable_error_message(&durable_err);
7170 assert!(matches!(
7171 message,
7172 ServerMessage::Error {
7173 code: Some(ref code),
7174 ..
7175 } if code == "store_unavailable"
7176 ));
7177
7178 assert!(durable_message_should_journal(&ClientMessage::Navigate {
7179 to: "/x".to_string()
7180 }));
7181 assert!(!durable_message_should_journal(&ClientMessage::Ping {
7182 nonce: None
7183 }));
7184 assert!(!durable_message_should_journal(&ClientMessage::Event {
7185 event: super::INTERNAL_RENDER_FLUSH_EVENT.to_string(),
7186 value: serde_json::Value::Null,
7187 target: None,
7188 metadata: Default::default(),
7189 }));
7190
7191 assert!(contains_server_error(&[
7192 ServerMessage::Patch {
7193 revision: 1,
7194 target: "root".to_string(),
7195 html: "<div/>".to_string(),
7196 },
7197 ServerMessage::Error {
7198 message: "boom".to_string(),
7199 code: Some("e".to_string()),
7200 },
7201 ]));
7202 assert!(!contains_server_error(&[ServerMessage::Pong {
7203 nonce: None
7204 }]));
7205
7206 let interactive = ClientMessage::Event {
7207 event: "save".to_string(),
7208 value: serde_json::json!({"tenant_id": "tenant-a"}),
7209 target: None,
7210 metadata: Default::default(),
7211 };
7212 let background = ClientMessage::Event {
7213 event: "bg:refresh".to_string(),
7214 value: serde_json::Value::Null,
7215 target: None,
7216 metadata: serde_json::Map::from_iter([(
7217 "priority".to_string(),
7218 serde_json::Value::String("background".to_string()),
7219 )]),
7220 };
7221 assert_eq!(
7222 overload_priority_for_client_message(&interactive),
7223 OverloadPriority::Interactive
7224 );
7225 assert_eq!(
7226 overload_priority_for_client_message(&background),
7227 OverloadPriority::Background
7228 );
7229 assert_eq!(
7230 tenant_id_for_client_message(&interactive).as_deref(),
7231 Some("tenant-a")
7232 );
7233 assert_eq!(
7234 effective_tenant_id_for_message(Some("tenant-session"), Some("tenant-msg")).as_deref(),
7235 Some("tenant-msg")
7236 );
7237 assert_eq!(
7238 effective_tenant_id_for_message(Some("tenant-session"), None).as_deref(),
7239 Some("tenant-session")
7240 );
7241
7242 let (sender, _receiver) = mpsc::channel::<OutboundEnvelope>(4);
7243 assert_eq!(outbound_queue_depth(&sender, 4), 0);
7244 let _ = sender.try_send(OutboundEnvelope::Text("x".to_string()));
7245 assert_eq!(outbound_queue_depth(&sender, 4), 1);
7246
7247 let shed = OverloadDecision::shed("shed:queue_depth");
7248 let throttle = OverloadDecision::throttle(12, "throttle:session_budget");
7249 assert!(matches!(
7250 overload_decision_to_error(&shed),
7251 ServerMessage::Error {
7252 code: Some(ref code),
7253 ..
7254 } if code == "overload_shed"
7255 ));
7256 assert!(matches!(
7257 overload_decision_to_error(&throttle),
7258 ServerMessage::Error {
7259 code: Some(ref code),
7260 ..
7261 } if code == "overload_throttle"
7262 ));
7263 }
7264
    /// Drives the durable runtime helpers through their main paths: a config whose
    /// helpers are expected to no-op, the rejection paths of lease acquisition, a
    /// store-backed acquire/snapshot/append/release cycle, and a failing journal replay.
    ///
    /// The steps share one store and one session, so their order matters.
    #[test]
    fn durable_runtime_helpers_cover_acquire_append_release_and_replay_paths() {
        let distributed = DistributedConfig {
            node_id: "node-a".to_string(),
            ..DistributedConfig::default()
        };
        // Default durable config: acquiring a lease succeeds but yields no lease.
        let no_store = DurableRuntimeConfig::default();
        assert!(
            super::durable_acquire_lease(&no_store, &distributed, "session-none", "/")
                .expect("no-store lease lookup should succeed")
                .is_none()
        );

        let mut session = LiveSession::new(Box::<Counter>::default(), "root");
        session.mount().expect("mount counter session");
        // Hand-built lease used only to call the helpers against the default config.
        let dummy_lease = super::DurableLeaseHandle {
            owner_node_id: "node-a".to_string(),
            fence_token: 1,
        };
        super::durable_save_snapshot(
            &no_store,
            &dummy_lease,
            &session,
            "/",
            "root",
            "resume-none",
        );
        // Appending with the default config succeeds and journals nothing.
        assert!(super::durable_append_journal_entry(
            &no_store,
            &dummy_lease,
            "session-none",
            &ClientMessage::Navigate {
                to: "/users".to_string(),
            },
        )
        .expect("append should no-op without store")
        .is_none());
        super::durable_release_lease(&no_store, "session-none", Some(&dummy_lease));

        // Store-backed config, first with drain mode enabled on the config itself.
        let store = Arc::new(InMemoryDurableSessionStore::new());
        let mut durable = DurableRuntimeConfig {
            store: Some(store.clone()),
            drain_mode: true,
            ..DurableRuntimeConfig::default()
        };
        let err = super::durable_acquire_lease(&durable, &distributed, "session-draining", "/")
            .expect_err("drain mode should reject ownership");
        assert_eq!(err.code, "node_draining");

        // Same rejection code when the draining flag is set on the store for this node.
        durable.drain_mode = false;
        store.set_node_draining(&distributed.node_id, true);
        let err = super::durable_acquire_lease(&durable, &distributed, "session-draining", "/")
            .expect_err("draining node flag should reject ownership");
        assert_eq!(err.code, "node_draining");

        // A denying placement hook surfaces its own error code.
        store.set_node_draining(&distributed.node_id, false);
        durable.placement_hook = Some(Arc::new(|_| {
            DurablePlacementDecision::deny("placement_rejected_custom", "placement rejected")
        }));
        let err = super::durable_acquire_lease(&durable, &distributed, "session-placed", "/")
            .expect_err("placement hook should reject ownership");
        assert_eq!(err.code, "placement_rejected_custom");

        // With no hook and no draining, acquisition yields a real lease.
        durable.placement_hook = None;
        let lease = super::durable_acquire_lease(&durable, &distributed, "session-live", "/")
            .expect("acquire lease")
            .expect("store-backed durable runtime should produce lease");

        super::durable_save_snapshot(&durable, &lease, &session, "/", "root", "resume-live");
        // The first journaled message gets sequence number 1.
        let first = super::durable_append_journal_entry(
            &durable,
            &lease,
            "session-live",
            &ClientMessage::Navigate {
                to: "/counter".to_string(),
            },
        )
        .expect("append navigate should succeed")
        .expect("navigate should be journaled");
        assert_eq!(first.sequence, 1);
        // Ping is accepted but not journaled.
        assert!(super::durable_append_journal_entry(
            &durable,
            &lease,
            "session-live",
            &ClientMessage::Ping { nonce: None },
        )
        .expect("append ping should succeed")
        .is_none());

        // After release, the old lease handle can no longer append.
        super::durable_release_lease(&durable, "session-live", Some(&lease));
        let stale = super::durable_append_journal_entry(
            &durable,
            &lease,
            "session-live",
            &ClientMessage::Navigate {
                to: "/counter".to_string(),
            },
        )
        .expect_err("stale lease should not append after release");
        assert_eq!(stale.code, "lease_not_found");

        // Replaying a journal entry whose patch_url points at an absolute external URL
        // must fail with a replay error message.
        let telemetry: Arc<dyn TelemetrySink> = Arc::new(TelemetryPipeline::disabled());
        let routes = vec![LiveRoute::new(
            "/".to_string(),
            Arc::new(|| Box::<Counter>::default()),
        )];
        let mut route_pattern = "/".to_string();
        let replay_error = replay_durable_journal_entry(
            &mut session,
            &mut route_pattern,
            &routes,
            "root",
            "session-live",
            &telemetry,
            &DurableJournalEntry {
                sequence: 99,
                message: ClientMessage::PatchUrl {
                    to: "https://example.test/off-domain".to_string(),
                },
                recorded_at_unix_ms: 0,
            },
        )
        .expect_err("invalid patch_url should surface a replay error");
        assert!(replay_error.contains("durable journal replay failed"));
    }
7390
7391 #[tokio::test]
7392 async fn abort_task_helpers_cancel_running_tasks() {
7393 let pubsub_task = tokio::spawn(async {
7394 tokio::time::sleep(Duration::from_secs(5)).await;
7395 });
7396 abort_pubsub_tasks(vec![pubsub_task]);
7397
7398 let mut runtime_tasks = HashMap::new();
7399 runtime_tasks.insert(
7400 "ticker".to_string(),
7401 tokio::spawn(async {
7402 tokio::time::sleep(Duration::from_secs(5)).await;
7403 }),
7404 );
7405 abort_runtime_tasks(runtime_tasks);
7406 }
7407
7408 #[test]
7409 fn console_log_format_parser_defaults_to_json() {
7410 assert_eq!(ConsoleLogFormat::parse("json"), ConsoleLogFormat::Json);
7411 assert_eq!(ConsoleLogFormat::parse(" JSON "), ConsoleLogFormat::Json);
7412 assert_eq!(ConsoleLogFormat::parse("unknown"), ConsoleLogFormat::Json);
7413 }
7414
7415 #[test]
7416 fn console_log_format_parser_supports_pretty_aliases() {
7417 for raw in ["pretty", "PRETTY", "text", "plain"] {
7418 assert_eq!(ConsoleLogFormat::parse(raw), ConsoleLogFormat::Pretty);
7419 }
7420 }
7421
7422 #[test]
7423 fn console_log_format_as_str_outputs_expected_values() {
7424 assert_eq!(ConsoleLogFormat::Json.as_str(), "json");
7425 assert_eq!(ConsoleLogFormat::Pretty.as_str(), "pretty");
7426 }
7427
7428 #[test]
7429 fn telemetry_config_builders_select_expected_exporters() {
7430 let tracing = TelemetryConfig::tracing("svc-a");
7431 assert_eq!(tracing.service_name, "svc-a");
7432 assert!(matches!(tracing.exporter, TelemetryExporter::Tracing));
7433
7434 let otel = TelemetryConfig::otel_json("svc-b");
7435 assert_eq!(otel.service_name, "svc-b");
7436 assert!(matches!(
7437 otel.exporter,
7438 TelemetryExporter::OpenTelemetryJson
7439 ));
7440
7441 let axiom =
7442 TelemetryConfig::axiom_json("svc-c", "dataset-main", Some("org-123".to_string()));
7443 assert_eq!(axiom.service_name, "svc-c");
7444 assert!(matches!(
7445 axiom.exporter,
7446 TelemetryExporter::AxiomJson { .. }
7447 ));
7448 }
7449
7450 #[test]
7451 fn telemetry_pipeline_emit_supports_all_exporters() {
7452 let event = shelly::TelemetryEvent::new(TelemetryEventKind::HandleEvent)
7453 .with_session("session-1")
7454 .with_route("/demo")
7455 .with_event_name("save")
7456 .with_ok(true)
7457 .with_latency_ms(12)
7458 .with_bytes(256)
7459 .with_count(1)
7460 .with_attribute("api_key", serde_json::json!("secret-token"));
7461
7462 let tracing_pipeline = TelemetryPipeline::enabled(TelemetryConfig::tracing("svc"));
7463 assert!(tracing_pipeline.emit(event.clone()).is_ok());
7464
7465 let otel_pipeline = TelemetryPipeline::enabled(TelemetryConfig::otel_json("svc"));
7466 assert!(otel_pipeline.emit(event.clone()).is_ok());
7467
7468 let axiom_pipeline = TelemetryPipeline::enabled(TelemetryConfig::axiom_json(
7469 "svc",
7470 "dataset-main",
7471 Some("org-1".to_string()),
7472 ));
7473 assert!(axiom_pipeline.emit(event).is_ok());
7474 }
7475
7476 #[test]
7477 fn session_telemetry_sink_applies_correlation_context() {
7478 let sink = SessionTelemetrySink::new(
7479 Arc::new(TelemetryPipeline::disabled()),
7480 CorrelationContext {
7481 trace_id: "0123456789abcdef0123456789abcdef".to_string(),
7482 span_id: "0123456789abcdef".to_string(),
7483 parent_span_id: Some("fedcba9876543210".to_string()),
7484 correlation_id: Some("corr-1".to_string()),
7485 request_id: Some("req-1".to_string()),
7486 },
7487 );
7488 assert!(sink
7489 .emit(shelly::TelemetryEvent::new(TelemetryEventKind::Patch))
7490 .is_ok());
7491 }
7492
7493 #[test]
7494 fn correlation_helpers_extract_and_fallback_fields() {
7495 let mut headers = HeaderMap::new();
7496 headers.insert("x-request-id", HeaderValue::from_static("req-7"));
7497 headers.insert("x-correlation-id", HeaderValue::from_static("corr-7"));
7498 headers.insert(
7499 "traceparent",
7500 HeaderValue::from_static("00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"),
7501 );
7502 headers.insert(
7503 "x-trace-id",
7504 HeaderValue::from_static("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"),
7505 );
7506
7507 assert_eq!(non_empty_text(" "), None);
7508 assert_eq!(
7509 header_value(&headers, "x-request-id").as_deref(),
7510 Some("req-7")
7511 );
7512
7513 let base = correlation_from_headers(&headers);
7514 assert_eq!(
7515 base.trace_id,
7516 "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string()
7517 );
7518 assert_eq!(base.correlation_id.as_deref(), Some("corr-7"));
7519 assert_eq!(base.request_id.as_deref(), Some("req-7"));
7520 assert_eq!(base.parent_span_id.as_deref(), Some("00f067aa0ba902b7"));
7521
7522 let mut query = HashMap::new();
7523 query.insert(
7524 "trace_id".to_string(),
7525 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb".to_string(),
7526 );
7527 query.insert("request_id".to_string(), "req-9".to_string());
7528 let from_query = correlation_from_query(&query, &base);
7529 assert_eq!(
7530 from_query.trace_id,
7531 "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb".to_string()
7532 );
7533 assert_eq!(from_query.request_id.as_deref(), Some("req-9"));
7534 assert_eq!(from_query.correlation_id.as_deref(), Some("corr-7"));
7535 assert_eq!(
7536 from_query.parent_span_id.as_deref(),
7537 Some(base.span_id.as_str())
7538 );
7539 }
7540
7541 #[test]
7542 fn http_ingress_logging_returns_correlation_context() {
7543 let mut headers = HeaderMap::new();
7544 headers.insert(header::USER_AGENT, HeaderValue::from_static("shelly-test"));
7545 headers.insert(header::HOST, HeaderValue::from_static("127.0.0.1:3000"));
7546 headers.insert("x-request-id", HeaderValue::from_static("req-42"));
7547 let method = Method::GET;
7548 let uri: Uri = "/grid?session=sensitive&custom=ok".parse().expect("uri");
7549
7550 let correlation = log_http_ingress("initial", &method, &uri, &headers);
7551 assert_eq!(correlation.request_id.as_deref(), Some("req-42"));
7552 assert!(!correlation.trace_id.is_empty());
7553 assert!(!correlation.span_id.is_empty());
7554 }
7555
7556 #[test]
7557 fn security_operation_helpers_cover_mutation_matrix() {
7558 assert_eq!(SecurityOperation::Connect.as_str(), "connect");
7559 assert_eq!(SecurityOperation::Binary.as_str(), "binary");
7560 assert!(!SecurityOperation::Connect.is_mutating());
7561 assert!(SecurityOperation::Event.is_mutating());
7562 assert!(SecurityOperation::UploadComplete.is_mutating());
7563 assert!(!SecurityOperation::Ping.is_mutating());
7564 }
7565
7566 #[test]
7567 fn durable_placement_decision_helpers_toggle_allowed_state() {
7568 let allow = DurablePlacementDecision::allow();
7569 assert!(allow.allowed);
7570 assert!(allow.code.is_none());
7571
7572 let deny = DurablePlacementDecision::deny("placement_reject", "no capacity");
7573 assert!(!deny.allowed);
7574 assert_eq!(deny.code.as_deref(), Some("placement_reject"));
7575 assert_eq!(deny.message.as_deref(), Some("no capacity"));
7576 }
7577
    /// Exercises the in-memory durable store directly: the per-node draining flag, lease
    /// acquire/renew, ownership checks, and journal appends before and after release.
    ///
    /// All steps act on the same store and lease, so their order matters.
    #[test]
    fn in_memory_durable_store_renew_release_and_drain_controls() {
        let store = InMemoryDurableSessionStore::new();
        // The draining flag starts cleared and can be toggled per node.
        assert!(!store.is_node_draining("node-a"));
        store.set_node_draining("node-a", true);
        assert!(store.is_node_draining("node-a"));
        store.set_node_draining("node-a", false);
        assert!(!store.is_node_draining("node-a"));

        let lease = store
            .acquire_lease(DurableLeaseRequest {
                session_id: "session-1".to_string(),
                node_id: "node-a".to_string(),
                ttl_ms: 10_000,
                takeover_policy: DurableTakeoverPolicy::AllowExpired,
            })
            .expect("acquire lease");
        // Renewal by the owner keeps the same fence token.
        let renewed = store
            .renew_lease("session-1", "node-a", lease.fence_token, 10_000)
            .expect("renew lease");
        assert_eq!(renewed.fence_token, lease.fence_token);

        // Renewal by another node is rejected.
        let renew_err = store
            .renew_lease("session-1", "node-b", lease.fence_token, 10_000)
            .expect_err("non-owner renew should fail");
        assert_eq!(renew_err.code, "lease_not_owner");

        // A release issued by the non-owner must leave the lease in place: the owner's
        // append right after it is still expected to succeed.
        store.release_lease("session-1", "node-b", lease.fence_token);
        store
            .append_journal_entry(
                "session-1",
                "node-a",
                lease.fence_token,
                ClientMessage::Ping { nonce: None },
                2,
            )
            .expect("owner can append journal");
        // Once the owner releases, the same append fails because the lease is gone.
        store.release_lease("session-1", "node-a", lease.fence_token);
        let append_err = store
            .append_journal_entry(
                "session-1",
                "node-a",
                lease.fence_token,
                ClientMessage::Ping { nonce: None },
                2,
            )
            .expect_err("append after release should fail");
        assert_eq!(append_err.code, "lease_not_found");
        // Unknown session ids have no stored record.
        assert!(store.load_record("missing").is_none());
    }
7628
7629 #[test]
7630 fn app_state_route_for_matches_registered_pattern() {
7631 let routes = Arc::new(vec![LiveRoute::new(
7632 "/users/:id".to_string(),
7633 Arc::new(|| Box::<Counter>::default()),
7634 )]);
7635 let state = AppState {
7636 routes,
7637 target_id: "root".to_string(),
7638 max_message_size: 64 * 1024,
7639 pubsub: PubSub::default(),
7640 uploads: UploadConfig::default(),
7641 security: SecurityConfig::default(),
7642 telemetry: Arc::new(TelemetryPipeline::disabled()),
7643 reconnect: ReconnectConfig::default(),
7644 distributed: DistributedConfig::default(),
7645 durable: DurableRuntimeConfig::default(),
7646 outbound: OutboundConfig::default(),
7647 render: super::RenderConfig::default(),
7648 overload: OverloadConfig::default(),
7649 transport_http: HttpTransportConfig::default(),
7650 };
7651
7652 let matched = state.route_for("/users/42").expect("route match");
7653 assert_eq!(matched.pattern, "/users/:id");
7654 assert_eq!(matched.params.get("id").map(String::as_str), Some("42"));
7655 assert!(state.route_for("/posts/42").is_none());
7656 }
7657
    /// Calls every `ShellyRouter` builder method once and checks the resulting
    /// configuration, including the values the assertions expect after undersized
    /// inputs are adjusted.
    #[test]
    fn shelly_router_builder_methods_apply_configuration() {
        let store: Arc<dyn DurableSessionStore> = Arc::new(InMemoryDurableSessionStore::new());
        let overload_budgets = OverloadBudgets {
            session_queue_depth: 12,
            session_bytes_per_sec: 2_048,
            session_cpu_ms_per_sec: 16,
            tenant_queue_depth: 24,
            tenant_bytes_per_sec: 4_096,
            tenant_cpu_ms_per_sec: 32,
        };
        // Several arguments below are deliberately tiny or zero; the assertions further
        // down pin the values the router is expected to hold instead.
        let router = ShellyRouter::new()
            .with_target_id("root-test")
            .with_max_message_size(32 * 1024)
            .with_max_upload_size(2048)
            .with_allowed_upload_content_type("application/json")
            .with_upload_temp_dir(PathBuf::from("/tmp/shelly-upload-tests"))
            .with_secret(b"top-secret".to_vec())
            .with_allowed_origin("https://example.com")
            .with_rate_limiter(|_| true)
            .with_authorization_hook(|_| AuthorizationDecision::allow())
            .with_quota_policy(|_| QuotaDecision::allow())
            .with_telemetry(TelemetryConfig::tracing("svc-router"))
            .with_reconnect_backoff(10, 20, 30)
            .with_heartbeat(100, 200)
            .with_resume_ttl_ms(100)
            .with_connect_handshake_timeout_ms(100)
            // The default Alt-Svc is set first and then replaced by the explicit value.
            .with_http3_default_alt_svc()
            .with_transport_diagnostics_header(true)
            .with_http3_alt_svc("h3=\":8443\"; ma=300")
            .with_node_id("node-test")
            .with_session_affinity_mode(SessionAffinityMode::Required)
            .with_durable_session_store(store.clone())
            .with_durable_lease_ttl_ms(10)
            .with_durable_journal_limit(0)
            .with_durable_takeover_policy(DurableTakeoverPolicy::Force)
            .with_durable_drain_mode(true)
            .with_durable_placement_hook(|_| DurablePlacementDecision::allow())
            .with_outbound_queue_capacity(0)
            .with_outbound_batching(0, 128, 0)
            .with_outbound_overflow_policy(OutboundOverflowPolicy::DropNewest)
            .with_default_render_cadence_ms(22)
            .with_overload_budgets(overload_budgets.clone())
            .with_overload_shed_policy(OverloadShedPolicy::Strict)
            .with_overload_policy_hook(|_ctx, decision| decision.clone())
            .live("/", Counter::default);

        assert_eq!(router.target_id, "root-test");
        assert_eq!(router.max_message_size, 32 * 1024);
        assert_eq!(router.uploads.max_file_size, 2048);
        assert_eq!(
            router.uploads.allowed_content_types.as_ref(),
            &vec!["application/json".to_string()]
        );
        assert_eq!(router.distributed.node_id, "node-test");
        assert_eq!(
            router.distributed.session_affinity,
            SessionAffinityMode::Required
        );
        // Backoff (10, 20, ..) and heartbeat (100, 200) inputs are expected to end up as
        // 100 ms and 1000 ms respectively.
        assert_eq!(router.reconnect.client.reconnect_base_ms, 100);
        assert_eq!(router.reconnect.client.reconnect_max_ms, 100);
        assert_eq!(router.reconnect.client.heartbeat_interval_ms, 1_000);
        assert_eq!(router.reconnect.client.heartbeat_timeout_ms, 1_000);
        assert_eq!(
            router.transport_http.http3_alt_svc.as_deref(),
            Some("h3=\":8443\"; ma=300")
        );
        assert!(router.transport_http.emit_diagnostics_header);
        // Zero journal limit, queue capacity, batch size and flush interval are expected
        // to become 1; the 128-byte batch limit is expected to become 256.
        assert_eq!(router.durable.journal_limit, 1);
        assert_eq!(router.durable.takeover_policy, DurableTakeoverPolicy::Force);
        assert_eq!(router.outbound.queue_capacity, 1);
        assert_eq!(router.outbound.batch_max_messages, 1);
        assert_eq!(router.outbound.batch_max_bytes, 256);
        assert_eq!(
            router.outbound.batch_flush_interval,
            Duration::from_millis(1)
        );
        assert_eq!(
            router.outbound.overflow_policy,
            OutboundOverflowPolicy::DropNewest
        );
        assert_eq!(router.render.default_cadence_ms, 22);
        assert_eq!(router.overload.budgets, overload_budgets);
        assert_eq!(router.overload.shed_policy, OverloadShedPolicy::Strict);
        assert!(router.overload.policy_hook.is_some());
        assert!(router.durable.drain_mode);
        assert!(router.security.rate_limiter.is_some());
        assert!(router.security.authorization.is_some());
        assert!(router.security.quota_policy.is_some());
        assert!(router.durable.placement_hook.is_some());
        assert_eq!(router.routes.len(), 1);

        // Capture what is needed before `into_router` consumes the builder, then check
        // that building the router marked this node as draining in the durable store.
        let node_id = router.distributed.node_id.clone();
        let durable_store = router
            .durable
            .store
            .as_ref()
            .expect("durable store should exist")
            .clone();
        let _axum_router = router.into_router();
        assert!(durable_store.is_node_draining(&node_id));
    }
7760
/// A `:param` route pattern splits into static/param segments, and matching a
/// concrete path exposes the captured parameter value.
#[test]
fn route_patterns_capture_params() {
    let pattern = "/users/:id";

    let expected_segments = vec![
        RouteSegment::Static("users".to_string()),
        RouteSegment::Param("id".to_string()),
    ];
    assert_eq!(route_segments(pattern), expected_segments);

    let route = LiveRoute::new(
        pattern.to_string(),
        std::sync::Arc::new(|| Box::<Counter>::default()),
    );

    // `/users/42` matches the pattern and captures `id = 42`.
    let hit = route.match_path("/users/42").unwrap();
    assert_eq!(hit.pattern, pattern);
    let captured_id = hit.params.get("id").map(String::as_str);
    assert_eq!(captured_id, Some("42"));

    // A path with an extra trailing segment must not match.
    let miss = route.match_path("/users/42/edit");
    assert!(miss.is_none());
}
7780
/// `internal_path` keeps only the path of a same-site URL and yields `None`
/// for absolute and protocol-relative URLs.
#[test]
fn internal_navigation_rejects_external_urls() {
    // Query string and fragment are stripped from an internal target.
    let internal = internal_path("/pages/intro?tab=a#top");
    assert_eq!(internal.as_deref(), Some("/pages/intro"));

    // Both a full URL and a scheme-relative URL are rejected.
    for external in ["https://example.test/pages", "//example.test/pages"] {
        assert_eq!(internal_path(external), None);
    }
}
7790
/// The rendered shell embeds the target container, the client script URL and
/// every bootstrap configuration value passed through `ShellConfig`.
#[test]
fn shell_contains_target_and_client_asset() {
    let config = ShellConfig {
        target_id: "root",
        inner_html: "<p>hi</p>",
        path: "/",
        title: "Test",
        session_id: "sid",
        session_token: "session",
        csrf_token: "csrf",
        protocol: "shelly/1",
        trace_id: "0123456789abcdef0123456789abcdef",
        span_id: "0123456789abcdef",
        correlation_id: Some("corr-1"),
        request_id: Some("req-1"),
        reconnect_base_ms: DEFAULT_RECONNECT_BASE_MS,
        reconnect_max_ms: DEFAULT_RECONNECT_MAX_MS,
        reconnect_jitter_ms: DEFAULT_RECONNECT_JITTER_MS,
        heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
        heartbeat_timeout_ms: DEFAULT_HEARTBEAT_TIMEOUT_MS,
        transport_mode: "websocket",
        transport_fallbacks: r#"["sse","long_poll"]"#,
        progressive_enhancement: true,
    };
    let html = render_shell(config);

    // Every fragment below must appear verbatim in the generated document.
    let expected_fragments = [
        "<div id=\"root\"><p>hi</p></div>",
        "/__shelly/client.js",
        "window.__SHELLY",
        r#"sessionId: "sid""#,
        r#"sessionToken: "session""#,
        r#"csrfToken: "csrf""#,
        r#"protocol: "shelly/1""#,
        r#"traceId: "0123456789abcdef0123456789abcdef""#,
        r#"spanId: "0123456789abcdef""#,
        r#"correlationId: "corr-1""#,
        r#"requestId: "req-1""#,
        "reconnectBaseMs: 750",
        "reconnectMaxMs: 30000",
        "heartbeatIntervalMs: 15000",
        "heartbeatTimeoutMs: 10000",
        r#"transportMode: "websocket""#,
        r#"transportFallbacks: ["sse","long_poll"]"#,
        "progressiveEnhancement: true",
    ];
    for fragment in expected_fragments {
        assert!(
            html.contains(fragment),
            "shell is missing fragment: {fragment}"
        );
    }
}
7834
/// Values that end up in HTML text, HTML attributes and the inline bootstrap
/// script must be escaped by `render_shell`.
///
/// The title and target id are expected in HTML-entity form, while the path
/// inside the script block is expected with `<`/`>` as `\u003c`/`\u003e`.
#[test]
fn shell_escapes_security_config_values() {
    let html = render_shell(ShellConfig {
        target_id: r#"ro"ot"#,
        inner_html: "<p>trusted</p>",
        path: r#"/"><script>"#,
        title: r#"<Title>"#,
        session_id: "sid",
        session_token: "session",
        csrf_token: "csrf",
        protocol: "shelly/1",
        trace_id: "0123456789abcdef0123456789abcdef",
        span_id: "0123456789abcdef",
        correlation_id: Some("corr-1"),
        request_id: Some("req-1"),
        reconnect_base_ms: DEFAULT_RECONNECT_BASE_MS,
        reconnect_max_ms: DEFAULT_RECONNECT_MAX_MS,
        reconnect_jitter_ms: DEFAULT_RECONNECT_JITTER_MS,
        heartbeat_interval_ms: DEFAULT_HEARTBEAT_INTERVAL_MS,
        heartbeat_timeout_ms: DEFAULT_HEARTBEAT_TIMEOUT_MS,
        transport_mode: "websocket",
        transport_fallbacks: "[]",
        progressive_enhancement: true,
    });

    // The expectations must be the *escaped* forms; asserting the raw markup
    // would pass only if no escaping happened at all.
    // NOTE(review): assumes the attribute escaper emits `&quot;` for `"` —
    // confirm against `render_shell`'s escaping helper.
    assert!(html.contains("<title>&lt;Title&gt;</title>"));
    assert!(html.contains(r#"<div id="ro&quot;ot">"#));
    assert!(html.contains("\\u003cscript\\u003e"));
    assert!(!html.contains(r#"path: "/"><script>""#));
}
7865
/// Minimal counter fixture used by the tests in this module.
#[derive(Default)]
struct Counter {
    /// Current counter value; `Default` starts it at zero.
    count: i64,
}
7870
7871 async fn recv_runtime_event_with_virtual_time(
7872 receiver: &mut mpsc::UnboundedReceiver<RuntimeEvent>,
7873 delay: Duration,
7874 ) -> RuntimeEvent {
7875 advance(delay).await;
7876 tokio::task::yield_now().await;
7877 receiver
7878 .recv()
7879 .await
7880 .expect("runtime event should be available after virtual-time advance")
7881 }
7882
7883 impl LiveView for Counter {
7884 fn handle_event(&mut self, event: Event, _ctx: &mut Context) -> LiveResult {
7885 match event.name.as_str() {
7886 "inc" => self.count += 1,
7887 "dec" => self.count -= 1,
7888 _ => {}
7889 }
7890
7891 Ok(())
7892 }
7893
7894 fn render(&self) -> Html {
7895 Html::new(format!("<p>{}</p>", self.count))
7896 }
7897 }
7898
/// An `event` text frame is dispatched to the mounted view and produces a
/// single patch carrying the re-rendered HTML.
#[test]
fn text_payload_dispatches_client_events() {
    let mut session = LiveSession::new(Box::<Counter>::default(), "root");
    session.mount().unwrap();

    let payload = r#"{"type":"event","event":"inc","target":"inc","value":null}"#;
    let replies = messages_for_text_payload(&mut session, payload);

    let expected = vec![ServerMessage::Patch {
        target: "root".to_string(),
        html: "<p>1</p>".to_string(),
        revision: 1,
    }];
    assert_eq!(replies, expected);
}
7918
/// A `connect` frame naming the supported protocol yields no server messages.
#[test]
fn text_payload_accepts_protocol_connect() {
    let mut session = LiveSession::new(Box::<Counter>::default(), "root");
    session.mount().unwrap();

    let payload = r#"{"type":"connect","protocol":"shelly/1"}"#;
    let replies = messages_for_text_payload(&mut session, payload);

    assert_eq!(replies, Vec::<ServerMessage>::new());
}
7929
/// A `connect` frame naming an unknown protocol version is answered with a
/// structured `unsupported_protocol` error.
#[test]
fn text_payload_rejects_unsupported_protocol_connect() {
    let mut session = LiveSession::new(Box::<Counter>::default(), "root");
    session.mount().unwrap();

    let payload = r#"{"type":"connect","protocol":"shelly/9"}"#;
    let replies = messages_for_text_payload(&mut session, payload);

    let expected = vec![ServerMessage::Error {
        message: "unsupported protocol in connect: expected shelly/1, got shelly/9".to_string(),
        code: Some("unsupported_protocol".to_string()),
    }];
    assert_eq!(replies, expected);
}
7944
7945 #[derive(Default)]
7946 struct Page {
7947 slug: String,
7948 }
7949
7950 impl LiveView for Page {
7951 fn handle_params(&mut self, ctx: &mut Context) -> LiveResult {
7952 self.slug = ctx.route_param("slug").unwrap_or("home").to_string();
7953 Ok(())
7954 }
7955
7956 fn render(&self) -> Html {
7957 Html::new(format!("<p>Page {}</p>", self.slug))
7958 }
7959 }
7960
7961 #[derive(Default)]
7962 struct User {
7963 id: String,
7964 }
7965
7966 impl LiveView for User {
7967 fn handle_params(&mut self, ctx: &mut Context) -> LiveResult {
7968 self.id = ctx.route_param("id").unwrap_or_default().to_string();
7969 Ok(())
7970 }
7971
7972 fn render(&self) -> Html {
7973 Html::new(format!("<p>User {}</p>", self.id))
7974 }
7975 }
7976
/// Patching the URL to another path of the same route pattern emits a
/// `PatchUrl` followed by a `Patch` rendered with the new route parameter.
#[test]
fn patch_url_updates_route_params_and_patches_without_remounting() {
    let pattern = "/pages/:slug";
    let routes = vec![LiveRoute::new(
        pattern.to_string(),
        std::sync::Arc::new(|| Box::<Page>::default()),
    )];

    let mut session = LiveSession::new_with_route(
        Box::<Page>::default(),
        "root",
        "/pages/home",
        [("slug".to_string(), "home".to_string())].into(),
    );
    session.mount().unwrap();
    // Render once up front so the patch produced below carries revision 2.
    session.render_patch();

    let mut current_pattern = pattern.to_string();
    let replies = handle_patch_url(&mut session, &mut current_pattern, &routes, "/pages/docs");

    let expected = vec![
        ServerMessage::PatchUrl {
            to: "/pages/docs".to_string(),
        },
        ServerMessage::Patch {
            target: "root".to_string(),
            html: "<p>Page docs</p>".to_string(),
            revision: 2,
        },
    ];
    assert_eq!(replies, expected);
}
8009
/// Navigating to a path owned by a different route switches the current
/// pattern and answers with `Navigate`, `Hello` and a fresh `Patch`.
#[test]
fn navigate_switches_liveview_route_and_renders_target() {
    let routes = vec![
        LiveRoute::new(
            "/pages/:slug".to_string(),
            std::sync::Arc::new(|| Box::<Page>::default()),
        ),
        LiveRoute::new(
            "/users/:id".to_string(),
            std::sync::Arc::new(|| Box::<User>::default()),
        ),
    ];

    let mut session = LiveSession::new_with_route(
        Box::<Page>::default(),
        "root",
        "/pages/home",
        [("slug".to_string(), "home".to_string())].into(),
    );
    session.mount().unwrap();

    let mut current_pattern = String::from("/pages/:slug");
    let telemetry: Arc<dyn TelemetrySink> = Arc::new(TelemetryPipeline::disabled());

    let replies = handle_navigate(
        &mut session,
        &mut current_pattern,
        &routes,
        "root",
        "sid",
        &telemetry,
        "/users/42",
    );

    // The tracked pattern now points at the user route.
    assert_eq!(current_pattern, "/users/:id");
    assert_eq!(replies.len(), 3);

    let expected_navigate = ServerMessage::Navigate {
        to: "/users/42".to_string(),
    };
    assert_eq!(replies[0], expected_navigate);

    assert!(matches!(replies[1], ServerMessage::Hello { .. }));

    let expected_patch = ServerMessage::Patch {
        target: "root".to_string(),
        html: "<p>User 42</p>".to_string(),
        revision: 1,
    };
    assert_eq!(replies[2], expected_patch);
}
8060
/// `parse_query` percent-decodes values: `shelly%2F1` becomes `shelly/1`.
#[test]
fn query_parser_decodes_protocol_version() {
    let raw = "protocol=shelly%2F1&session=abc&csrf=def";
    let parsed = parse_query(raw);

    let protocol = parsed.get("protocol").map(String::as_str);
    assert_eq!(protocol, Some("shelly/1"));
}
8067
/// `sanitize_query_for_log` keeps ordinary pairs verbatim and replaces the
/// `session` and `csrf` values with `<redacted>`.
#[test]
fn query_log_sanitizer_redacts_sensitive_keys() {
    let raw = "protocol=shelly%2F1&session=abc&csrf=def&trace_id=abcd&custom=ok";
    let sanitized = sanitize_query_for_log(Some(raw));

    for kept in ["protocol=shelly%2F1", "trace_id=abcd", "custom=ok"] {
        assert!(sanitized.contains(kept), "expected pair to be kept: {kept}");
    }
    for redacted in ["session=<redacted>", "csrf=<redacted>"] {
        assert!(
            sanitized.contains(redacted),
            "expected redacted pair: {redacted}"
        );
    }
}
8079
/// `parse_traceparent` returns the trace id and span id fields of a
/// `traceparent` header value.
#[test]
fn traceparent_parser_extracts_trace_and_span() {
    let header_value = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
    let ids = parse_traceparent(header_value).expect("parse traceparent");

    // Field 0 is the trace id, field 1 the span id.
    assert_eq!(ids.0, "4bf92f3577b34da6a3ce929d0e0e4736");
    assert_eq!(ids.1, "00f067aa0ba902b7");
}
8087
/// `normalize_hex_id` accepts a hex string of the requested length and
/// rejects non-hex or wrongly sized input.
#[test]
fn normalize_hex_id_rejects_invalid_lengths() {
    let trace_id = "4bf92f3577b34da6a3ce929d0e0e4736";
    let normalized = normalize_hex_id(trace_id, 32);
    assert_eq!(normalized.as_deref(), Some(trace_id));

    // Non-hex characters, then a hex string shorter than requested.
    for (raw, expected_len) in [("xyz", 32), ("1234", 16)] {
        assert!(normalize_hex_id(raw, expected_len).is_none());
    }
}
8097
/// Truncated JSON is answered with exactly one `invalid_protocol` error.
#[test]
fn malformed_protocol_messages_return_structured_errors() {
    let mut session = LiveSession::new(Box::<Counter>::default(), "root");
    session.mount().unwrap();

    // The payload is cut off before the closing brace.
    let replies = messages_for_text_payload(&mut session, r#"{"type":"event""#);

    assert_eq!(replies.len(), 1);
    if let ServerMessage::Error { message, code } = &replies[0] {
        assert!(message.starts_with("invalid protocol message:"));
        assert_eq!(code.as_deref(), Some("invalid_protocol"));
    } else {
        panic!("unexpected message: {:?}", &replies[0]);
    }
}
8114
/// Oversized payloads produce a `payload_too_large` error message and a close
/// frame carrying the WebSocket `SIZE` close code.
#[test]
fn payload_size_errors_are_structured_and_close_with_size_code() {
    let expected_error = ServerMessage::Error {
        message: "payload too large: 12 bytes exceeds 8 byte limit".to_string(),
        code: Some("payload_too_large".to_string()),
    };
    assert_eq!(payload_too_large_error(12, 8), expected_error);

    let frame = payload_too_large_close();
    assert_eq!(frame.code, close_code::SIZE);
    assert_eq!(frame.reason.as_str(), "payload too large");
}
8129
/// `unsupported_binary_error` yields a fixed `unsupported_message_type` error.
#[test]
fn binary_protocol_messages_return_structured_errors() {
    let expected = ServerMessage::Error {
        message: "binary protocol messages are not supported by Shelly v0".to_string(),
        code: Some("unsupported_message_type".to_string()),
    };

    assert_eq!(unsupported_binary_error(), expected);
}
8140
/// Session and CSRF tokens signed for a path verify for that path only, and a
/// signed resume token round-trips its session id, path and nonce.
#[test]
fn signed_session_and_csrf_tokens_round_trip_for_websocket_auth() {
    let signer = TokenSigner::new(b"test secret".to_vec());
    let session = signer.sign_session("session-1", "/", "node-a");
    let csrf = signer.sign_csrf("session-1", "/");
    let resume = signer.sign_resume("session-1", "/");

    let raw_query = format!("session={session}&csrf={csrf}");
    let query = parse_query(&raw_query);

    // Happy path: both tokens verify for the path they were signed for.
    let verified = verify_websocket_tokens(&signer, &query, "/").unwrap();
    assert_eq!(verified.session_id, "session-1");
    assert_eq!(verified.node_id.as_deref(), Some("node-a"));

    // The same tokens are rejected for a different path.
    assert!(verify_websocket_tokens(&signer, &query, "/other").is_none());

    // Replacing the session token with garbage fails verification.
    let mut tampered = query;
    tampered.insert("session".to_string(), "bad".to_string());
    assert!(verify_websocket_tokens(&signer, &tampered, "/").is_none());

    // A CSRF token signed for another path does not pair with the session.
    let foreign_csrf = signer.sign_csrf("session-1", "/different");
    let bad_csrf = parse_query(&format!("session={session}&csrf={foreign_csrf}"));
    assert!(verify_websocket_tokens(&signer, &bad_csrf, "/").is_none());

    let verified_resume = signer.verify_resume(&resume).unwrap();
    assert_eq!(verified_resume.session_id, "session-1");
    assert_eq!(verified_resume.path, "/");
    assert!(!verified_resume.nonce.is_empty());
}
8169
/// A token of one kind must not verify as another kind, and an arbitrary
/// string must not verify as a CSRF token.
#[test]
fn token_signer_rejects_wrong_token_kinds() {
    let signer = TokenSigner::new(b"test secret".to_vec());
    let session_token = signer.sign_session("session-1", "/", "node-a");
    let csrf_token = signer.sign_csrf("session-1", "/");
    let resume_token = signer.sign_resume("session-1", "/");

    // CSRF token presented as a session token.
    let csrf_as_session = signer.verify_session(&csrf_token);
    assert!(csrf_as_session.is_none());

    // Session token presented as a resume token.
    let session_as_resume = signer.verify_resume(&session_token);
    assert!(session_as_resume.is_none());

    // Unsigned garbage presented as a CSRF token.
    let garbage_is_csrf = signer.verify_csrf("bad-token", "session-1", "/");
    assert!(!garbage_is_csrf);

    // Resume token presented as a session token.
    let resume_as_session = signer.verify_session(&resume_token);
    assert!(resume_as_session.is_none());
}
8182
/// Exercises `correlation_from_connect` with invalid and valid client ids,
/// then `reconnect_event` for the resumed and fallback outcomes.
#[test]
fn correlation_and_reconnect_helpers_cover_resume_variants() {
    let base = test_correlation();

    // Invalid or empty client-supplied ids: the base correlation is kept.
    let invalid_handshake = ConnectHandshake {
        client_session_id: None,
        client_revision: 0,
        resume_token: None,
        tenant_id: None,
        trace_id: Some("not-hex".to_string()),
        span_id: Some("bad-span".to_string()),
        parent_span_id: Some("also-bad".to_string()),
        correlation_id: Some(String::new()),
        request_id: Some(String::new()),
    };
    let kept = correlation_from_connect(&base, &invalid_handshake);
    assert_eq!(kept.trace_id, base.trace_id);
    assert_eq!(kept.parent_span_id.as_deref(), Some(base.span_id.as_str()));
    assert_eq!(kept.correlation_id, base.correlation_id);
    assert_eq!(kept.request_id, base.request_id);
    assert_eq!(kept.span_id.len(), 16);

    // Well-formed client ids: they are adopted, and the client span becomes
    // the parent of the resulting span.
    let valid_handshake = ConnectHandshake {
        client_session_id: Some("session-1".to_string()),
        client_revision: 2,
        resume_token: Some("resume-1".to_string()),
        tenant_id: Some("tenant-a".to_string()),
        trace_id: Some("4bf92f3577b34da6a3ce929d0e0e4736".to_string()),
        span_id: Some("00f067aa0ba902b7".to_string()),
        parent_span_id: None,
        correlation_id: Some("corr-next".to_string()),
        request_id: Some("req-next".to_string()),
    };
    let adopted = correlation_from_connect(&base, &valid_handshake);
    assert_eq!(adopted.trace_id, "4bf92f3577b34da6a3ce929d0e0e4736");
    assert_eq!(adopted.parent_span_id.as_deref(), Some("00f067aa0ba902b7"));
    assert_eq!(adopted.correlation_id.as_deref(), Some("corr-next"));
    assert_eq!(adopted.request_id.as_deref(), Some("req-next"));
    assert_eq!(adopted.span_id.len(), 16);

    // Resumed reconnect: both status and reason are recorded as attributes.
    let resumed_event = reconnect_event(
        "session-1",
        "/",
        &adopted,
        true,
        ResumeStatus::Resumed,
        Some("token_match"),
    );
    assert_eq!(resumed_event.kind, TelemetryEventKind::Connect);
    let resumed_status = serde_json::json!("resumed");
    assert_eq!(
        resumed_event.attributes.get("resume_status"),
        Some(&resumed_status)
    );
    let resumed_reason = serde_json::json!("token_match");
    assert_eq!(
        resumed_event.attributes.get("resume_reason"),
        Some(&resumed_reason)
    );

    // Fallback reconnect without a reason: no `resume_reason` attribute.
    let fallback_event = reconnect_event(
        "session-1",
        "/",
        &adopted,
        false,
        ResumeStatus::Fallback,
        None,
    );
    let fallback_status = serde_json::json!("fallback");
    assert_eq!(
        fallback_event.attributes.get("resume_status"),
        Some(&fallback_status)
    );
    assert!(!fallback_event.attributes.contains_key("resume_reason"));
}
8264
/// `cleanup_expired_snapshots` drops entries whose `expires_at` is in the past
/// and keeps those that expire in the future.
#[test]
fn cleanup_expired_snapshots_removes_stale_entries() {
    // Builds a snapshot that differs only in its resume token and expiry.
    let snapshot_expiring_at = |resume_token: &str, expires_at: Instant| ResumeSnapshot {
        session: LiveSession::new(Box::<Counter>::default(), "root"),
        route_pattern: "/".to_string(),
        resume_token: resume_token.to_string(),
        expires_at,
    };

    let mut snapshots = HashMap::new();
    snapshots.insert(
        "fresh".to_string(),
        snapshot_expiring_at("resume-fresh", Instant::now() + Duration::from_secs(5)),
    );
    snapshots.insert(
        "stale".to_string(),
        snapshot_expiring_at("resume-stale", Instant::now() - Duration::from_millis(1)),
    );

    cleanup_expired_snapshots(&mut snapshots);

    assert!(snapshots.contains_key("fresh"));
    assert!(!snapshots.contains_key("stale"));
}
8292
8293 #[tokio::test]
8294 async fn http_handlers_serve_client_js_and_initial_shell() {
8295 let state =
8296 test_app_state_with("/", SecurityConfig::default(), DistributedConfig::default());
8297 let js = client_js(
8298 State(state.clone()),
8299 Method::GET,
8300 HeaderMap::new(),
8301 Uri::from_static("/__shelly/client.js"),
8302 )
8303 .await
8304 .into_response();
8305 assert_eq!(js.status(), StatusCode::OK);
8306 assert_eq!(
8307 js.headers().get(header::CONTENT_TYPE),
8308 Some(&HeaderValue::from_static("text/javascript; charset=utf-8"))
8309 );
8310 assert!(js.headers().get("alt-svc").is_none());
8311 assert!(js.headers().get("x-shelly-transport").is_none());
8312 let js_body = to_bytes(js.into_body(), usize::MAX).await.expect("js body");
8313 assert_eq!(js_body.as_ref(), CLIENT_JS.as_bytes());
8314 let not_found = initial_handler(
8315 State(state.clone()),
8316 Method::GET,
8317 HeaderMap::new(),
8318 Uri::from_static("/missing"),
8319 )
8320 .await;
8321 assert_eq!(not_found.status(), StatusCode::NOT_FOUND);
8322 let not_found_body = to_bytes(not_found.into_body(), usize::MAX)
8323 .await
8324 .expect("not found body");
8325 assert!(String::from_utf8_lossy(¬_found_body).contains("/missing"));
8326
8327 let ok = initial_handler(
8328 State(state),
8329 Method::GET,
8330 HeaderMap::new(),
8331 Uri::from_static("/"),
8332 )
8333 .await;
8334 assert_eq!(ok.status(), StatusCode::OK);
8335 assert!(ok.headers().get("alt-svc").is_none());
8336 assert!(ok.headers().get("x-shelly-transport").is_none());
8337 let body = to_bytes(ok.into_body(), usize::MAX)
8338 .await
8339 .expect("shell body");
8340 let html = String::from_utf8_lossy(&body);
8341 assert!(html.contains("window.__SHELLY"));
8342 assert!(html.contains("<div id=\"root\"><p>0</p></div>"));
8343 assert!(html.contains("sessionToken"));
8344 assert!(html.contains("csrfToken"));
8345 }
8346
8347 #[tokio::test]
8348 async fn http_handlers_emit_http3_headers_when_configured() {
8349 let mut state =
8350 test_app_state_with("/", SecurityConfig::default(), DistributedConfig::default());
8351 Arc::make_mut(&mut state).transport_http = HttpTransportConfig {
8352 http3_alt_svc: Some(DEFAULT_HTTP3_ALT_SVC.to_string()),
8353 emit_diagnostics_header: true,
8354 primary_mode: TransportMode::WebSocket,
8355 fallback_modes: vec![TransportMode::ServerSentEvents, TransportMode::LongPoll],
8356 progressive_enhancement: true,
8357 };
8358
8359 let js = client_js(
8360 State(state.clone()),
8361 Method::GET,
8362 HeaderMap::new(),
8363 Uri::from_static("/__shelly/client.js"),
8364 )
8365 .await
8366 .into_response();
8367 assert_eq!(
8368 js.headers().get("alt-svc"),
8369 Some(&HeaderValue::from_static(DEFAULT_HTTP3_ALT_SVC))
8370 );
8371 assert_eq!(
8372 js.headers().get("x-shelly-transport"),
8373 Some(&HeaderValue::from_static(
8374 "websocket; fallback=sse,long_poll; progressive=enabled; h3=advertised"
8375 ))
8376 );
8377 assert_eq!(
8378 js.headers().get("x-shelly-progressive-enhancement"),
8379 Some(&HeaderValue::from_static("enabled"))
8380 );
8381
8382 let shell = initial_handler(
8383 State(state),
8384 Method::GET,
8385 HeaderMap::new(),
8386 Uri::from_static("/"),
8387 )
8388 .await;
8389 assert_eq!(
8390 shell.headers().get("alt-svc"),
8391 Some(&HeaderValue::from_static(DEFAULT_HTTP3_ALT_SVC))
8392 );
8393 assert_eq!(
8394 shell.headers().get("x-shelly-transport"),
8395 Some(&HeaderValue::from_static(
8396 "websocket; fallback=sse,long_poll; progressive=enabled; h3=advertised"
8397 ))
8398 );
8399 }
8400
8401 #[tokio::test]
8402 async fn transport_capabilities_describe_degraded_profiles() {
8403 let mut state =
8404 test_app_state_with("/", SecurityConfig::default(), DistributedConfig::default());
8405 Arc::make_mut(&mut state).transport_http = HttpTransportConfig {
8406 http3_alt_svc: None,
8407 emit_diagnostics_header: true,
8408 primary_mode: TransportMode::WebSocket,
8409 fallback_modes: vec![TransportMode::ServerSentEvents, TransportMode::LongPoll],
8410 progressive_enhancement: true,
8411 };
8412
8413 let response = transport_capabilities(
8414 State(state),
8415 Method::GET,
8416 HeaderMap::new(),
8417 Uri::from_static("/__shelly/transport"),
8418 )
8419 .await;
8420
8421 assert_eq!(response.status(), StatusCode::OK);
8422 assert_eq!(
8423 response.headers().get("x-shelly-transport"),
8424 Some(&HeaderValue::from_static(
8425 "websocket; fallback=sse,long_poll; progressive=enabled; h3=disabled"
8426 ))
8427 );
8428 let body = to_bytes(response.into_body(), usize::MAX)
8429 .await
8430 .expect("transport capabilities body");
8431 let json: serde_json::Value =
8432 serde_json::from_slice(&body).expect("transport capabilities json");
8433 assert_eq!(json["protocol"], PROTOCOL_VERSION);
8434 assert_eq!(json["primary"], "websocket");
8435 assert_eq!(json["fallbacks"], serde_json::json!(["sse", "long_poll"]));
8436 assert_eq!(json["progressive_enhancement"], true);
8437 assert_eq!(json["http3_advertised"], false);
8438 }
8439
/// End-to-end check of the WebSocket upgrade guardrails.
///
/// Each scenario spawns its own test router, attempts an upgrade against
/// `/__shelly/ws/__root__`, and asserts the HTTP status of the rejection:
/// cross-origin (403), unsupported protocol (400), missing signed tokens
/// (401) and session-affinity mismatch (409). The final scenario performs a
/// valid upgrade and sends a `connect` frame.
#[tokio::test]
async fn ws_handler_enforces_guardrails_and_accepts_valid_upgrade() {
    let secret = b"test secret".to_vec();

    // Scenario 1: no allowed origin configured, request carries a foreign
    // `Origin` header.
    let (origin_base, origin_server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret.clone())
            .live("/", Counter::default),
    )
    .await;
    let mut origin_request = format!("{origin_base}/__shelly/ws/__root__?protocol=shelly/1")
        .into_client_request()
        .expect("origin request");
    origin_request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://evil.example"));
    let origin_err = connect_async(origin_request)
        .await
        .expect_err("cross origin upgrade should fail");
    assert_eq!(ws_error_status(origin_err), StatusCode::FORBIDDEN);
    origin_server.abort();

    // Scenario 2: allowed origin, but the query asks for protocol `shelly/9`.
    let (protocol_base, protocol_server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret.clone())
            .with_allowed_origin("https://example.test")
            .live("/", Counter::default),
    )
    .await;
    let mut protocol_request =
        format!("{protocol_base}/__shelly/ws/__root__?protocol=shelly/9")
            .into_client_request()
            .expect("protocol request");
    protocol_request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let protocol_err = connect_async(protocol_request)
        .await
        .expect_err("unsupported protocol should fail");
    assert_eq!(ws_error_status(protocol_err), StatusCode::BAD_REQUEST);
    protocol_server.abort();

    // Scenario 3: allowed origin and supported protocol, but no `session` /
    // `csrf` query parameters.
    let (unsigned_base, unsigned_server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret.clone())
            .with_allowed_origin("https://example.test")
            .live("/", Counter::default),
    )
    .await;
    let mut unsigned_request =
        format!("{unsigned_base}/__shelly/ws/__root__?protocol=shelly/1")
            .into_client_request()
            .expect("unsigned request");
    unsigned_request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let unsigned_err = connect_async(unsigned_request)
        .await
        .expect_err("missing signed session should fail");
    assert_eq!(ws_error_status(unsigned_err), StatusCode::UNAUTHORIZED);
    unsigned_server.abort();

    // Scenario 4: session token signed for `node-a` presented to a router
    // running as `node-b` with required session affinity.
    let signer = TokenSigner::new(secret.clone());
    let mismatch_session = signer.sign_session("session-1", "/", "node-a");
    let mismatch_csrf = signer.sign_csrf("session-1", "/");
    let (affinity_base, affinity_server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret.clone())
            .with_allowed_origin("https://example.test")
            .with_node_id("node-b")
            .with_session_affinity_mode(SessionAffinityMode::Required)
            .live("/", Counter::default),
    )
    .await;
    let mut mismatch_request = format!(
        "{affinity_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={mismatch_session}&csrf={mismatch_csrf}"
    )
    .into_client_request()
    .expect("mismatch request");
    mismatch_request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let mismatch_err = connect_async(mismatch_request)
        .await
        .expect_err("session-affinity mismatch should fail");
    assert_eq!(ws_error_status(mismatch_err), StatusCode::CONFLICT);
    affinity_server.abort();

    // Scenario 5: tokens signed for `node-b` against a `node-b` router; the
    // upgrade must succeed with 101 Switching Protocols.
    let valid_session = signer.sign_session("session-2", "/", "node-b");
    let valid_csrf = signer.sign_csrf("session-2", "/");
    let (ok_base, ok_server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret)
            .with_allowed_origin("https://example.test")
            .with_node_id("node-b")
            .with_session_affinity_mode(SessionAffinityMode::Required)
            .live("/", Counter::default),
    )
    .await;
    let mut accepted_request = format!(
        "{ok_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={valid_session}&csrf={valid_csrf}"
    )
    .into_client_request()
    .expect("accepted request");
    accepted_request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let (mut socket, response) = connect_async(accepted_request)
        .await
        .expect("valid websocket upgrade should succeed");
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

    // Send the protocol `connect` frame and expect a text frame in reply.
    let connect = serde_json::json!({
        "type": "connect",
        "protocol": PROTOCOL_VERSION,
        "session_id": "session-2",
        "last_revision": 0,
        "trace_id": "4bf92f3577b34da6a3ce929d0e0e4736",
        "span_id": "00f067aa0ba902b7",
        "correlation_id": "corr-ws-test",
        "request_id": "req-ws-test"
    });
    socket
        .send(WsMessage::Text(connect.to_string().into()))
        .await
        .expect("send connect frame");
    let first = socket.next().await.expect("first server frame");
    assert!(matches!(first, Ok(WsMessage::Text(_))));
    // Closing is best-effort; the result is intentionally ignored.
    let _ = socket.close(None).await;
    ok_server.abort();
}
8571
/// After a successful upgrade, the first client frame must be a `connect`
/// frame for the supported protocol.
///
/// Two independent connections are opened: one sends an `event` frame first
/// and expects a `connect_required` error, the other sends a `connect` frame
/// for `shelly/9` and expects an `unsupported_protocol` error.
#[tokio::test]
async fn ws_connect_handshake_rejects_non_connect_and_protocol_mismatch_frames() {
    let secret = b"test secret".to_vec();
    let signer = TokenSigner::new(secret.clone());

    // --- Case 1: first frame is not `connect` ---
    let session_token = signer.sign_session("session-connect-reject", "/", "node-a");
    let csrf_token = signer.sign_csrf("session-connect-reject", "/");
    let (base, server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret.clone())
            .with_allowed_origin("https://example.test")
            .live("/", Counter::default),
    )
    .await;

    let mut request = format!(
        "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
    )
    .into_client_request()
    .expect("connect-required request");
    request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let (mut socket, response) = connect_async(request)
        .await
        .expect("websocket connect should succeed");
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

    let event_first = serde_json::json!({
        "type": "event",
        "event": "inc",
        "target": null,
        "value": null,
        "metadata": {}
    });
    socket
        .send(WsMessage::Text(event_first.to_string().into()))
        .await
        .expect("send non-connect first frame");
    let first = socket.next().await.expect("first rejection frame");
    match first {
        Ok(WsMessage::Text(text)) => {
            let message: ServerMessage =
                serde_json::from_str(text.as_str()).expect("error frame should decode");
            assert!(matches!(
                message,
                ServerMessage::Error { code: Some(code), .. } if code == "connect_required"
            ));
        }
        other => panic!("unexpected first rejection frame: {other:?}"),
    }
    // Closing is best-effort; the result is intentionally ignored.
    let _ = socket.close(None).await;
    server.abort();

    // --- Case 2: `connect` frame names an unsupported protocol ---
    // The upgrade query still uses the supported protocol, so the rejection
    // comes from the frame, not from the upgrade request.
    let session_token = signer.sign_session("session-connect-proto", "/", "node-a");
    let csrf_token = signer.sign_csrf("session-connect-proto", "/");
    let (base, server) = spawn_test_router(
        ShellyRouter::new()
            .with_secret(secret)
            .with_allowed_origin("https://example.test")
            .live("/", Counter::default),
    )
    .await;

    let mut request = format!(
        "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
    )
    .into_client_request()
    .expect("protocol mismatch request");
    request
        .headers_mut()
        .insert("origin", HeaderValue::from_static("https://example.test"));
    let (mut socket, response) = connect_async(request)
        .await
        .expect("websocket connect should succeed");
    assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

    let unsupported_connect = serde_json::json!({
        "type": "connect",
        "protocol": "shelly/9",
        "session_id": "session-connect-proto",
        "last_revision": 0
    });
    socket
        .send(WsMessage::Text(unsupported_connect.to_string().into()))
        .await
        .expect("send unsupported connect frame");
    let first = socket.next().await.expect("first rejection frame");
    match first {
        Ok(WsMessage::Text(text)) => {
            let message: ServerMessage =
                serde_json::from_str(text.as_str()).expect("error frame should decode");
            assert!(matches!(
                message,
                ServerMessage::Error { code: Some(code), .. } if code == "unsupported_protocol"
            ));
        }
        other => panic!("unexpected unsupported protocol frame: {other:?}"),
    }
    let _ = socket.close(None).await;
    server.abort();
}
8674
8675 #[tokio::test]
8676 async fn ws_resume_reconcile_and_fallback_paths_are_exercised() {
8677 let secret = b"test secret".to_vec();
8678 let signer = TokenSigner::new(secret.clone());
8679 let session_id = "session-resume-flow";
8680 let session_token = signer.sign_session(session_id, "/", "node-a");
8681 let csrf_token = signer.sign_csrf(session_id, "/");
8682
8683 let (base, server) = spawn_test_router(
8684 ShellyRouter::new()
8685 .with_secret(secret)
8686 .with_allowed_origin("https://example.test")
8687 .with_resume_ttl_ms(5_000)
8688 .live("/", Counter::default),
8689 )
8690 .await;
8691
8692 let mut request = format!(
8693 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
8694 )
8695 .into_client_request()
8696 .expect("fresh connect request");
8697 request
8698 .headers_mut()
8699 .insert("origin", HeaderValue::from_static("https://example.test"));
8700 let (mut socket, response) = connect_async(request)
8701 .await
8702 .expect("fresh websocket should connect");
8703 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
8704
8705 let connect_fresh = serde_json::json!({
8706 "type": "connect",
8707 "protocol": PROTOCOL_VERSION,
8708 "session_id": session_id,
8709 "last_revision": 0
8710 });
8711 socket
8712 .send(WsMessage::Text(connect_fresh.to_string().into()))
8713 .await
8714 .expect("send fresh connect");
8715
8716 let mut resume_token = None::<String>;
8717 let mut hello_seen = false;
8718 for _ in 0..4 {
8719 let frame = socket.next().await.expect("fresh server frame");
8720 if let Ok(WsMessage::Text(text)) = frame {
8721 let message: ServerMessage =
8722 serde_json::from_str(text.as_str()).expect("fresh frame should decode");
8723 if let ServerMessage::Hello {
8724 resume_status,
8725 resume_token: token,
8726 ..
8727 } = message
8728 {
8729 hello_seen = true;
8730 assert_eq!(resume_status, Some(ResumeStatus::Fresh));
8731 resume_token = token;
8732 break;
8733 }
8734 }
8735 }
8736 assert!(hello_seen, "expected hello frame in fresh connect flow");
8737 let resume_token = resume_token.expect("fresh hello must include resume token");
8738
8739 let increment = serde_json::json!({
8740 "type": "event",
8741 "event": "inc",
8742 "target": null,
8743 "value": null,
8744 "metadata": {}
8745 });
8746 socket
8747 .send(WsMessage::Text(increment.to_string().into()))
8748 .await
8749 .expect("send increment event");
8750 let mut patch_seen = false;
8751 for _ in 0..6 {
8752 let frame = socket.next().await.expect("event response frame");
8753 if let Ok(WsMessage::Text(text)) = frame {
8754 let message: ServerMessage =
8755 serde_json::from_str(text.as_str()).expect("event response should decode");
8756 if let ServerMessage::Patch { html, revision, .. } = message {
8757 assert!(html.contains("<p>1</p>") || html.contains("<p>0</p>"));
8758 assert!(revision >= 1);
8759 patch_seen = true;
8760 break;
8761 }
8762 }
8763 }
8764 assert!(patch_seen, "expected at least one patch after increment");
8765 let _ = socket.close(None).await;
8766 tokio::time::sleep(Duration::from_millis(25)).await;
8767
8768 let mut request = format!(
8769 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
8770 )
8771 .into_client_request()
8772 .expect("resume request");
8773 request
8774 .headers_mut()
8775 .insert("origin", HeaderValue::from_static("https://example.test"));
8776 let (mut resumed_socket, response) = connect_async(request)
8777 .await
8778 .expect("resume websocket should connect");
8779 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
8780
8781 let resume_connect = serde_json::json!({
8782 "type": "connect",
8783 "protocol": PROTOCOL_VERSION,
8784 "session_id": session_id,
8785 "last_revision": 0,
8786 "resume_token": resume_token
8787 });
8788 resumed_socket
8789 .send(WsMessage::Text(resume_connect.to_string().into()))
8790 .await
8791 .expect("send resume connect");
8792
8793 let mut resumed_hello = false;
8794 let mut reconciled_patch = false;
8795 for _ in 0..8 {
8796 let frame = resumed_socket.next().await.expect("resume response frame");
8797 if let Ok(WsMessage::Text(text)) = frame {
8798 let message: ServerMessage =
8799 serde_json::from_str(text.as_str()).expect("resume frame should decode");
8800 match message {
8801 ServerMessage::Hello {
8802 resume_status,
8803 server_revision,
8804 ..
8805 } => {
8806 assert_eq!(resume_status, Some(ResumeStatus::Resumed));
8807 assert!(server_revision.is_some());
8808 resumed_hello = true;
8809 }
8810 ServerMessage::Patch { .. } => {
8811 reconciled_patch = true;
8812 }
8813 _ => {}
8814 }
8815 if resumed_hello && reconciled_patch {
8816 break;
8817 }
8818 }
8819 }
8820 assert!(resumed_hello, "expected resumed hello");
8821 assert!(reconciled_patch, "expected reconciliation patch");
8822 let _ = resumed_socket.close(None).await;
8823 tokio::time::sleep(Duration::from_millis(25)).await;
8824
8825 let mut request = format!(
8826 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
8827 )
8828 .into_client_request()
8829 .expect("fallback request");
8830 request
8831 .headers_mut()
8832 .insert("origin", HeaderValue::from_static("https://example.test"));
8833 let (mut fallback_socket, response) = connect_async(request)
8834 .await
8835 .expect("fallback websocket should connect");
8836 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
8837
8838 let fallback_connect = serde_json::json!({
8839 "type": "connect",
8840 "protocol": PROTOCOL_VERSION,
8841 "session_id": session_id,
8842 "last_revision": 3,
8843 "resume_token": "invalid-token"
8844 });
8845 fallback_socket
8846 .send(WsMessage::Text(fallback_connect.to_string().into()))
8847 .await
8848 .expect("send fallback connect");
8849
8850 let mut fallback_seen = false;
8851 for _ in 0..6 {
8852 let frame = fallback_socket.next().await.expect("fallback frame");
8853 if let Ok(WsMessage::Text(text)) = frame {
8854 let message: ServerMessage =
8855 serde_json::from_str(text.as_str()).expect("fallback frame should decode");
8856 if let ServerMessage::Hello {
8857 resume_status,
8858 resume_reason,
8859 ..
8860 } = message
8861 {
8862 assert_eq!(resume_status, Some(ResumeStatus::Fallback));
8863 assert_eq!(resume_reason.as_deref(), Some("resume_token_invalid"));
8864 fallback_seen = true;
8865 break;
8866 }
8867 }
8868 }
8869 assert!(fallback_seen, "expected fallback hello");
8870 let _ = fallback_socket.close(None).await;
8871 server.abort();
8872 }
8873
8874 #[tokio::test]
8875 async fn ws_connect_handshake_enforces_signed_session_id_match() {
8876 let secret = b"test secret".to_vec();
8877 let signer = TokenSigner::new(secret.clone());
8878 let session_token = signer.sign_session("session-a", "/", "node-a");
8879 let csrf_token = signer.sign_csrf("session-a", "/");
8880 let (base, server) = spawn_test_router(
8881 ShellyRouter::new()
8882 .with_secret(secret)
8883 .with_allowed_origin("https://example.test")
8884 .live("/", Counter::default),
8885 )
8886 .await;
8887
8888 let mut request = format!(
8889 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
8890 )
8891 .into_client_request()
8892 .expect("session mismatch request");
8893 request
8894 .headers_mut()
8895 .insert("origin", HeaderValue::from_static("https://example.test"));
8896 let (mut socket, response) = connect_async(request)
8897 .await
8898 .expect("websocket connect should succeed");
8899 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
8900
8901 let connect = serde_json::json!({
8902 "type": "connect",
8903 "protocol": PROTOCOL_VERSION,
8904 "session_id": "session-b",
8905 "last_revision": 0
8906 });
8907 socket
8908 .send(WsMessage::Text(connect.to_string().into()))
8909 .await
8910 .expect("send connect");
8911 let frame = socket.next().await.expect("session mismatch frame");
8912 match frame {
8913 Ok(WsMessage::Text(text)) => {
8914 let message: ServerMessage =
8915 serde_json::from_str(text.as_str()).expect("frame should decode");
8916 assert!(matches!(
8917 message,
8918 ServerMessage::Error { code: Some(code), .. } if code == "session_mismatch"
8919 ));
8920 }
8921 other => panic!("unexpected session mismatch frame: {other:?}"),
8922 }
8923 let _ = socket.close(None).await;
8924 server.abort();
8925 }
8926
8927 #[tokio::test]
8928 async fn ws_connect_handshake_can_be_rejected_by_quota_and_authorization_hooks() {
8929 let secret = b"test secret".to_vec();
8930 let signer = TokenSigner::new(secret.clone());
8931
8932 let quota_session = signer.sign_session("session-quota", "/", "node-a");
8933 let quota_csrf = signer.sign_csrf("session-quota", "/");
8934 let (quota_base, quota_server) = spawn_test_router(
8935 ShellyRouter::new()
8936 .with_secret(secret.clone())
8937 .with_allowed_origin("https://example.test")
8938 .with_quota_policy(|ctx| {
8939 if ctx.operation == SecurityOperation::Connect {
8940 QuotaDecision::deny("quota_exceeded", "connect quota exceeded")
8941 } else {
8942 QuotaDecision::allow()
8943 }
8944 })
8945 .live("/", Counter::default),
8946 )
8947 .await;
8948
8949 let mut quota_request = format!(
8950 "{quota_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={quota_session}&csrf={quota_csrf}"
8951 )
8952 .into_client_request()
8953 .expect("quota request");
8954 quota_request
8955 .headers_mut()
8956 .insert("origin", HeaderValue::from_static("https://example.test"));
8957 let (mut quota_socket, response) = connect_async(quota_request)
8958 .await
8959 .expect("quota websocket should connect");
8960 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
8961
8962 let connect = serde_json::json!({
8963 "type": "connect",
8964 "protocol": PROTOCOL_VERSION,
8965 "session_id": "session-quota",
8966 "last_revision": 0
8967 });
8968 quota_socket
8969 .send(WsMessage::Text(connect.to_string().into()))
8970 .await
8971 .expect("send quota connect");
8972 let frame = quota_socket.next().await.expect("quota rejection frame");
8973 match frame {
8974 Ok(WsMessage::Text(text)) => {
8975 let message: ServerMessage =
8976 serde_json::from_str(text.as_str()).expect("quota frame should decode");
8977 assert!(matches!(
8978 message,
8979 ServerMessage::Error { code: Some(code), .. } if code == "quota_exceeded"
8980 ));
8981 }
8982 other => panic!("unexpected quota frame: {other:?}"),
8983 }
8984 let _ = quota_socket.close(None).await;
8985 quota_server.abort();
8986
8987 let auth_session = signer.sign_session("session-auth", "/", "node-a");
8988 let auth_csrf = signer.sign_csrf("session-auth", "/");
8989 let (auth_base, auth_server) = spawn_test_router(
8990 ShellyRouter::new()
8991 .with_secret(secret)
8992 .with_allowed_origin("https://example.test")
8993 .with_authorization_hook(|ctx| {
8994 if ctx.operation == SecurityOperation::Connect {
8995 AuthorizationDecision::deny("unauthorized", "connect denied")
8996 } else {
8997 AuthorizationDecision::allow()
8998 }
8999 })
9000 .live("/", Counter::default),
9001 )
9002 .await;
9003
9004 let mut auth_request = format!(
9005 "{auth_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={auth_session}&csrf={auth_csrf}"
9006 )
9007 .into_client_request()
9008 .expect("authorization request");
9009 auth_request
9010 .headers_mut()
9011 .insert("origin", HeaderValue::from_static("https://example.test"));
9012 let (mut auth_socket, response) = connect_async(auth_request)
9013 .await
9014 .expect("authorization websocket should connect");
9015 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9016
9017 let connect = serde_json::json!({
9018 "type": "connect",
9019 "protocol": PROTOCOL_VERSION,
9020 "session_id": "session-auth",
9021 "last_revision": 0
9022 });
9023 auth_socket
9024 .send(WsMessage::Text(connect.to_string().into()))
9025 .await
9026 .expect("send authorization connect");
9027 let frame = auth_socket
9028 .next()
9029 .await
9030 .expect("authorization rejection frame");
9031 match frame {
9032 Ok(WsMessage::Text(text)) => {
9033 let message: ServerMessage =
9034 serde_json::from_str(text.as_str()).expect("authorization frame should decode");
9035 assert!(matches!(
9036 message,
9037 ServerMessage::Error { code: Some(code), .. } if code == "unauthorized"
9038 ));
9039 }
9040 other => panic!("unexpected authorization frame: {other:?}"),
9041 }
9042 let _ = auth_socket.close(None).await;
9043 auth_server.abort();
9044 }
9045
9046 #[tokio::test]
9047 async fn ws_resume_rejects_tenant_context_mismatch() {
9048 let secret = b"test secret".to_vec();
9049 let signer = TokenSigner::new(secret.clone());
9050 let session_id = "session-tenant-resume";
9051 let session_token = signer.sign_session(session_id, "/", "node-a");
9052 let csrf_token = signer.sign_csrf(session_id, "/");
9053 let (base, server) = spawn_test_router(
9054 ShellyRouter::new()
9055 .with_secret(secret)
9056 .with_allowed_origin("https://example.test")
9057 .with_resume_ttl_ms(5_000)
9058 .live("/", Counter::default),
9059 )
9060 .await;
9061
9062 let mut request = format!(
9063 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
9064 )
9065 .into_client_request()
9066 .expect("first tenant connect request");
9067 request
9068 .headers_mut()
9069 .insert("origin", HeaderValue::from_static("https://example.test"));
9070 let (mut socket, response) = connect_async(request)
9071 .await
9072 .expect("first tenant websocket should connect");
9073 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9074
9075 let connect = serde_json::json!({
9076 "type": "connect",
9077 "protocol": PROTOCOL_VERSION,
9078 "session_id": session_id,
9079 "last_revision": 0,
9080 "tenant_id": "tenant-a"
9081 });
9082 socket
9083 .send(WsMessage::Text(connect.to_string().into()))
9084 .await
9085 .expect("send first tenant connect");
9086
9087 let mut resume_token = None::<String>;
9088 for _ in 0..4 {
9089 let frame = socket.next().await.expect("first tenant frame");
9090 if let Ok(WsMessage::Text(text)) = frame {
9091 let message: ServerMessage =
9092 serde_json::from_str(text.as_str()).expect("first tenant frame decode");
9093 if let ServerMessage::Hello {
9094 resume_token: token,
9095 ..
9096 } = message
9097 {
9098 resume_token = token;
9099 break;
9100 }
9101 }
9102 }
9103 let resume_token = resume_token.expect("tenant hello should include resume token");
9104 let _ = socket.close(None).await;
9105 tokio::time::sleep(Duration::from_millis(25)).await;
9106
9107 let mut request = format!(
9108 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
9109 )
9110 .into_client_request()
9111 .expect("tenant mismatch reconnect request");
9112 request
9113 .headers_mut()
9114 .insert("origin", HeaderValue::from_static("https://example.test"));
9115 let (mut reconnect_socket, response) = connect_async(request)
9116 .await
9117 .expect("tenant mismatch websocket should connect");
9118 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9119
9120 let connect = serde_json::json!({
9121 "type": "connect",
9122 "protocol": PROTOCOL_VERSION,
9123 "session_id": session_id,
9124 "last_revision": 0,
9125 "tenant_id": "tenant-b",
9126 "resume_token": resume_token
9127 });
9128 reconnect_socket
9129 .send(WsMessage::Text(connect.to_string().into()))
9130 .await
9131 .expect("send tenant mismatch reconnect");
9132 let frame = reconnect_socket
9133 .next()
9134 .await
9135 .expect("tenant mismatch error frame");
9136 match frame {
9137 Ok(WsMessage::Text(text)) => {
9138 let message: ServerMessage =
9139 serde_json::from_str(text.as_str()).expect("tenant mismatch frame decode");
9140 assert!(matches!(
9141 message,
9142 ServerMessage::Error { code: Some(code), .. } if code == "tenant_mismatch"
9143 ));
9144 }
9145 other => panic!("unexpected tenant mismatch frame: {other:?}"),
9146 }
9147 let _ = reconnect_socket.close(None).await;
9148 server.abort();
9149 }
9150
9151 #[tokio::test]
9152 async fn ws_text_event_rejections_cover_tenant_quota_and_authorization_paths() {
9153 let secret = b"test secret".to_vec();
9154 let signer = TokenSigner::new(secret.clone());
9155
9156 let session_id = "session-tenant-text";
9157 let session_token = signer.sign_session(session_id, "/", "node-a");
9158 let csrf_token = signer.sign_csrf(session_id, "/");
9159 let (base, server) = spawn_test_router(
9160 ShellyRouter::new()
9161 .with_secret(secret.clone())
9162 .with_allowed_origin("https://example.test")
9163 .live("/", Counter::default),
9164 )
9165 .await;
9166
9167 let mut request = format!(
9168 "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
9169 )
9170 .into_client_request()
9171 .expect("text tenant request");
9172 request
9173 .headers_mut()
9174 .insert("origin", HeaderValue::from_static("https://example.test"));
9175 let (mut socket, response) = connect_async(request)
9176 .await
9177 .expect("text tenant websocket should connect");
9178 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9179
9180 let connect = serde_json::json!({
9181 "type": "connect",
9182 "protocol": PROTOCOL_VERSION,
9183 "session_id": session_id,
9184 "last_revision": 0,
9185 "tenant_id": "tenant-a"
9186 });
9187 socket
9188 .send(WsMessage::Text(connect.to_string().into()))
9189 .await
9190 .expect("send tenant connect");
9191 await_hello(&mut socket).await;
9192
9193 let tenant_conflict_event = serde_json::json!({
9194 "type": "event",
9195 "event": "inc",
9196 "target": null,
9197 "value": { "tenant_id": "tenant-b" },
9198 "metadata": { "tenant_id": "tenant-b" }
9199 });
9200 socket
9201 .send(WsMessage::Text(tenant_conflict_event.to_string().into()))
9202 .await
9203 .expect("send tenant conflict event");
9204 await_error_code(&mut socket, "tenant_mismatch").await;
9205 let _ = socket.close(None).await;
9206 server.abort();
9207
9208 let quota_session = signer.sign_session("session-quota-text", "/", "node-a");
9209 let quota_csrf = signer.sign_csrf("session-quota-text", "/");
9210 let (quota_base, quota_server) = spawn_test_router(
9211 ShellyRouter::new()
9212 .with_secret(secret.clone())
9213 .with_allowed_origin("https://example.test")
9214 .with_quota_policy(|ctx| {
9215 if ctx.operation == SecurityOperation::Event {
9216 QuotaDecision::deny("quota_exceeded", "event quota exceeded")
9217 } else {
9218 QuotaDecision::allow()
9219 }
9220 })
9221 .live("/", Counter::default),
9222 )
9223 .await;
9224
9225 let mut quota_request = format!(
9226 "{quota_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={quota_session}&csrf={quota_csrf}"
9227 )
9228 .into_client_request()
9229 .expect("quota text request");
9230 quota_request
9231 .headers_mut()
9232 .insert("origin", HeaderValue::from_static("https://example.test"));
9233 let (mut quota_socket, response) = connect_async(quota_request)
9234 .await
9235 .expect("quota text websocket should connect");
9236 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9237 let connect = serde_json::json!({
9238 "type": "connect",
9239 "protocol": PROTOCOL_VERSION,
9240 "session_id": "session-quota-text",
9241 "last_revision": 0,
9242 "tenant_id": "tenant-a"
9243 });
9244 quota_socket
9245 .send(WsMessage::Text(connect.to_string().into()))
9246 .await
9247 .expect("send quota text connect");
9248 await_hello(&mut quota_socket).await;
9249 let event = serde_json::json!({
9250 "type": "event",
9251 "event": "inc",
9252 "target": null,
9253 "value": { "tenant_id": "tenant-a" },
9254 "metadata": { "tenant_id": "tenant-a" }
9255 });
9256 quota_socket
9257 .send(WsMessage::Text(event.to_string().into()))
9258 .await
9259 .expect("send quota text event");
9260 await_error_code(&mut quota_socket, "quota_exceeded").await;
9261 let _ = quota_socket.close(None).await;
9262 quota_server.abort();
9263
9264 let auth_session = signer.sign_session("session-auth-text", "/", "node-a");
9265 let auth_csrf = signer.sign_csrf("session-auth-text", "/");
9266 let (auth_base, auth_server) = spawn_test_router(
9267 ShellyRouter::new()
9268 .with_secret(secret)
9269 .with_allowed_origin("https://example.test")
9270 .with_authorization_hook(|ctx| {
9271 if ctx.operation == SecurityOperation::Event {
9272 AuthorizationDecision::deny("unauthorized", "event denied")
9273 } else {
9274 AuthorizationDecision::allow()
9275 }
9276 })
9277 .live("/", Counter::default),
9278 )
9279 .await;
9280 let mut auth_request = format!(
9281 "{auth_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={auth_session}&csrf={auth_csrf}"
9282 )
9283 .into_client_request()
9284 .expect("auth text request");
9285 auth_request
9286 .headers_mut()
9287 .insert("origin", HeaderValue::from_static("https://example.test"));
9288 let (mut auth_socket, response) = connect_async(auth_request)
9289 .await
9290 .expect("auth text websocket should connect");
9291 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9292 let connect = serde_json::json!({
9293 "type": "connect",
9294 "protocol": PROTOCOL_VERSION,
9295 "session_id": "session-auth-text",
9296 "last_revision": 0,
9297 "tenant_id": "tenant-a"
9298 });
9299 auth_socket
9300 .send(WsMessage::Text(connect.to_string().into()))
9301 .await
9302 .expect("send auth text connect");
9303 await_hello(&mut auth_socket).await;
9304 let event = serde_json::json!({
9305 "type": "event",
9306 "event": "inc",
9307 "target": null,
9308 "value": { "tenant_id": "tenant-a" },
9309 "metadata": { "tenant_id": "tenant-a" }
9310 });
9311 auth_socket
9312 .send(WsMessage::Text(event.to_string().into()))
9313 .await
9314 .expect("send auth text event");
9315 await_error_code(&mut auth_socket, "unauthorized").await;
9316 let _ = auth_socket.close(None).await;
9317 auth_server.abort();
9318 }
9319
9320 #[tokio::test]
9321 async fn ws_overload_and_binary_branches_emit_expected_errors() {
9322 let secret = b"test secret".to_vec();
9323 let signer = TokenSigner::new(secret.clone());
9324
9325 let shed_session = signer.sign_session("session-overload", "/", "node-a");
9326 let shed_csrf = signer.sign_csrf("session-overload", "/");
9327 let (shed_base, shed_server) = spawn_test_router(
9328 ShellyRouter::new()
9329 .with_secret(secret.clone())
9330 .with_allowed_origin("https://example.test")
9331 .with_overload_budgets(OverloadBudgets {
9332 session_queue_depth: usize::MAX,
9333 session_bytes_per_sec: 1,
9334 session_cpu_ms_per_sec: u64::MAX,
9335 tenant_queue_depth: usize::MAX,
9336 tenant_bytes_per_sec: usize::MAX,
9337 tenant_cpu_ms_per_sec: u64::MAX,
9338 })
9339 .with_overload_shed_policy(OverloadShedPolicy::Strict)
9340 .live("/", Counter::default),
9341 )
9342 .await;
9343
9344 let mut shed_request = format!(
9345 "{shed_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={shed_session}&csrf={shed_csrf}"
9346 )
9347 .into_client_request()
9348 .expect("overload request");
9349 shed_request
9350 .headers_mut()
9351 .insert("origin", HeaderValue::from_static("https://example.test"));
9352 let (mut shed_socket, response) = connect_async(shed_request)
9353 .await
9354 .expect("overload websocket should connect");
9355 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9356 let connect = serde_json::json!({
9357 "type": "connect",
9358 "protocol": PROTOCOL_VERSION,
9359 "session_id": "session-overload",
9360 "last_revision": 0
9361 });
9362 shed_socket
9363 .send(WsMessage::Text(connect.to_string().into()))
9364 .await
9365 .expect("send overload connect");
9366 await_hello(&mut shed_socket).await;
9367 let event = serde_json::json!({
9368 "type": "event",
9369 "event": "inc",
9370 "target": null,
9371 "value": { "payload": "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" },
9372 "metadata": {}
9373 });
9374 shed_socket
9375 .send(WsMessage::Text(event.to_string().into()))
9376 .await
9377 .expect("send overload event");
9378 await_error_code(&mut shed_socket, "overload_shed").await;
9379 let _ = shed_socket.close(None).await;
9380 shed_server.abort();
9381
9382 let binary_large_session = signer.sign_session("session-binary-large", "/", "node-a");
9383 let binary_large_csrf = signer.sign_csrf("session-binary-large", "/");
9384 let (binary_large_base, binary_large_server) = spawn_test_router(
9385 ShellyRouter::new()
9386 .with_secret(secret.clone())
9387 .with_allowed_origin("https://example.test")
9388 .with_max_message_size(8)
9389 .live("/", Counter::default),
9390 )
9391 .await;
9392 let mut binary_large_request = format!(
9393 "{binary_large_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={binary_large_session}&csrf={binary_large_csrf}"
9394 )
9395 .into_client_request()
9396 .expect("binary large request");
9397 binary_large_request
9398 .headers_mut()
9399 .insert("origin", HeaderValue::from_static("https://example.test"));
9400 let (mut binary_large_socket, response) = connect_async(binary_large_request)
9401 .await
9402 .expect("binary large websocket should connect");
9403 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9404 let connect = serde_json::json!({
9405 "type": "connect",
9406 "protocol": PROTOCOL_VERSION,
9407 "session_id": "session-binary-large",
9408 "last_revision": 0
9409 });
9410 binary_large_socket
9411 .send(WsMessage::Text(connect.to_string().into()))
9412 .await
9413 .expect("send binary large connect");
9414 await_hello(&mut binary_large_socket).await;
9415 binary_large_socket
9416 .send(WsMessage::Binary(vec![0_u8; 64].into()))
9417 .await
9418 .expect("send oversized binary");
9419 let mut saw_terminal_frame = false;
9420 for _ in 0..10 {
9421 let Some(frame) = binary_large_socket.next().await else {
9422 saw_terminal_frame = true;
9423 break;
9424 };
9425 match frame {
9426 Ok(WsMessage::Text(text)) => {
9427 let message: ServerMessage =
9428 serde_json::from_str(text.as_str()).expect("payload frame decode");
9429 if matches!(
9430 message,
9431 ServerMessage::Error {
9432 code: Some(ref code),
9433 ..
9434 } if code == "payload_too_large"
9435 ) {
9436 saw_terminal_frame = true;
9437 break;
9438 }
9439 }
9440 Ok(WsMessage::Close(_)) => {
9441 saw_terminal_frame = true;
9442 break;
9443 }
9444 _ => {}
9445 }
9446 }
9447 assert!(
9448 saw_terminal_frame,
9449 "expected oversized binary path to terminate with error/close/no-frame"
9450 );
9451 let _ = binary_large_socket.close(None).await;
9452 binary_large_server.abort();
9453
9454 let binary_rate_session = signer.sign_session("session-binary-rate", "/", "node-a");
9455 let binary_rate_csrf = signer.sign_csrf("session-binary-rate", "/");
9456 let (binary_rate_base, binary_rate_server) = spawn_test_router(
9457 ShellyRouter::new()
9458 .with_secret(secret.clone())
9459 .with_allowed_origin("https://example.test")
9460 .with_rate_limiter(|ctx| ctx.message_kind != "binary")
9461 .live("/", Counter::default),
9462 )
9463 .await;
9464 let mut binary_rate_request = format!(
9465 "{binary_rate_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={binary_rate_session}&csrf={binary_rate_csrf}"
9466 )
9467 .into_client_request()
9468 .expect("binary rate request");
9469 binary_rate_request
9470 .headers_mut()
9471 .insert("origin", HeaderValue::from_static("https://example.test"));
9472 let (mut binary_rate_socket, response) = connect_async(binary_rate_request)
9473 .await
9474 .expect("binary rate websocket should connect");
9475 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9476 let connect = serde_json::json!({
9477 "type": "connect",
9478 "protocol": PROTOCOL_VERSION,
9479 "session_id": "session-binary-rate",
9480 "last_revision": 0
9481 });
9482 binary_rate_socket
9483 .send(WsMessage::Text(connect.to_string().into()))
9484 .await
9485 .expect("send binary rate connect");
9486 await_hello(&mut binary_rate_socket).await;
9487 binary_rate_socket
9488 .send(WsMessage::Binary(vec![1_u8; 4].into()))
9489 .await
9490 .expect("send binary for rate limiter");
9491 await_error_code(&mut binary_rate_socket, "rate_limited").await;
9492 let _ = binary_rate_socket.close(None).await;
9493 binary_rate_server.abort();
9494
9495 let binary_quota_session = signer.sign_session("session-binary-quota", "/", "node-a");
9496 let binary_quota_csrf = signer.sign_csrf("session-binary-quota", "/");
9497 let (binary_quota_base, binary_quota_server) = spawn_test_router(
9498 ShellyRouter::new()
9499 .with_secret(secret.clone())
9500 .with_allowed_origin("https://example.test")
9501 .with_quota_policy(|ctx| {
9502 if ctx.operation == SecurityOperation::Binary {
9503 QuotaDecision::deny("quota_exceeded", "binary quota exceeded")
9504 } else {
9505 QuotaDecision::allow()
9506 }
9507 })
9508 .live("/", Counter::default),
9509 )
9510 .await;
9511 let mut binary_quota_request = format!(
9512 "{binary_quota_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={binary_quota_session}&csrf={binary_quota_csrf}"
9513 )
9514 .into_client_request()
9515 .expect("binary quota request");
9516 binary_quota_request
9517 .headers_mut()
9518 .insert("origin", HeaderValue::from_static("https://example.test"));
9519 let (mut binary_quota_socket, response) = connect_async(binary_quota_request)
9520 .await
9521 .expect("binary quota websocket should connect");
9522 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9523 let connect = serde_json::json!({
9524 "type": "connect",
9525 "protocol": PROTOCOL_VERSION,
9526 "session_id": "session-binary-quota",
9527 "last_revision": 0
9528 });
9529 binary_quota_socket
9530 .send(WsMessage::Text(connect.to_string().into()))
9531 .await
9532 .expect("send binary quota connect");
9533 await_hello(&mut binary_quota_socket).await;
9534 binary_quota_socket
9535 .send(WsMessage::Binary(vec![2_u8; 4].into()))
9536 .await
9537 .expect("send binary for quota");
9538 await_error_code(&mut binary_quota_socket, "quota_exceeded").await;
9539 let _ = binary_quota_socket.close(None).await;
9540 binary_quota_server.abort();
9541
9542 let binary_auth_session = signer.sign_session("session-binary-auth", "/", "node-a");
9543 let binary_auth_csrf = signer.sign_csrf("session-binary-auth", "/");
9544 let (binary_auth_base, binary_auth_server) = spawn_test_router(
9545 ShellyRouter::new()
9546 .with_secret(secret)
9547 .with_allowed_origin("https://example.test")
9548 .with_authorization_hook(|ctx| {
9549 if ctx.operation == SecurityOperation::Binary {
9550 AuthorizationDecision::deny("unauthorized", "binary denied")
9551 } else {
9552 AuthorizationDecision::allow()
9553 }
9554 })
9555 .live("/", Counter::default),
9556 )
9557 .await;
9558 let mut binary_auth_request = format!(
9559 "{binary_auth_base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={binary_auth_session}&csrf={binary_auth_csrf}"
9560 )
9561 .into_client_request()
9562 .expect("binary auth request");
9563 binary_auth_request
9564 .headers_mut()
9565 .insert("origin", HeaderValue::from_static("https://example.test"));
9566 let (mut binary_auth_socket, response) = connect_async(binary_auth_request)
9567 .await
9568 .expect("binary auth websocket should connect");
9569 assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);
9570 let connect = serde_json::json!({
9571 "type": "connect",
9572 "protocol": PROTOCOL_VERSION,
9573 "session_id": "session-binary-auth",
9574 "last_revision": 0
9575 });
9576 binary_auth_socket
9577 .send(WsMessage::Text(connect.to_string().into()))
9578 .await
9579 .expect("send binary auth connect");
9580 await_hello(&mut binary_auth_socket).await;
9581 binary_auth_socket
9582 .send(WsMessage::Binary(vec![3_u8; 4].into()))
9583 .await
9584 .expect("send binary for auth");
9585 await_error_code(&mut binary_auth_socket, "unauthorized").await;
9586 let _ = binary_auth_socket.close(None).await;
9587 binary_auth_server.abort();
9588 }
9589
9590 #[tokio::test]
9591 async fn outbound_frame_helpers_enqueue_pong_and_close_messages() {
9592 let (sender, mut receiver) = mpsc::channel(4);
9593 let config = OutboundConfig::default();
9594 let telemetry = Arc::new(TelemetryPipeline::disabled());
9595 let correlation = test_correlation();
9596
9597 assert!(matches!(
9598 queue_pong_frame(
9599 &sender,
9600 vec![1, 2, 3],
9601 &config,
9602 &telemetry,
9603 &correlation,
9604 "/",
9605 "session-1",
9606 ),
9607 OutboundQueuePush::Queued
9608 ));
9609 match receiver.recv().await {
9610 Some(OutboundEnvelope::Pong(payload)) => assert_eq!(payload, vec![1, 2, 3]),
9611 other => panic!("unexpected envelope: {other:?}"),
9612 }
9613
9614 let close = axum::extract::ws::CloseFrame {
9615 code: close_code::NORMAL,
9616 reason: "closing".into(),
9617 };
9618 assert!(matches!(
9619 queue_close_frame(
9620 &sender,
9621 Some(close),
9622 &config,
9623 &telemetry,
9624 &correlation,
9625 "/",
9626 "session-1",
9627 ),
9628 OutboundQueuePush::Queued
9629 ));
9630 match receiver.recv().await {
9631 Some(OutboundEnvelope::Close(Some(frame))) => {
9632 assert_eq!(frame.code, close_code::NORMAL);
9633 assert_eq!(frame.reason, "closing");
9634 }
9635 other => panic!("unexpected envelope: {other:?}"),
9636 }
9637 }
9638
9639 #[test]
9640 fn cross_origin_websocket_requests_are_rejected_by_default() {
9641 let security = SecurityConfig {
9642 signer: TokenSigner::new(b"test secret".to_vec()),
9643 ..SecurityConfig::default()
9644 };
9645 let mut headers = HeaderMap::new();
9646 headers.insert(header::HOST, HeaderValue::from_static("example.test"));
9647 headers.insert(
9648 header::ORIGIN,
9649 HeaderValue::from_static("https://evil.example"),
9650 );
9651
9652 assert!(!origin_allowed(&headers, &security));
9653
9654 headers.insert(
9655 header::ORIGIN,
9656 HeaderValue::from_static("https://example.test"),
9657 );
9658 assert!(origin_allowed(&headers, &security));
9659 }
9660
9661 #[test]
9662 fn explicitly_allowed_origins_are_accepted() {
9663 let security = SecurityConfig {
9664 signer: TokenSigner::new(b"test secret".to_vec()),
9665 allowed_origins: std::sync::Arc::new(vec!["https://app.example".to_string()]),
9666 rate_limiter: None,
9667 authorization: None,
9668 quota_policy: None,
9669 };
9670 let mut headers = HeaderMap::new();
9671 headers.insert(header::HOST, HeaderValue::from_static("api.example"));
9672 headers.insert(
9673 header::ORIGIN,
9674 HeaderValue::from_static("https://app.example"),
9675 );
9676
9677 assert!(origin_allowed(&headers, &security));
9678 }
9679
9680 #[test]
9681 fn session_affinity_policy_is_explicit_and_testable() {
9682 let distributed = DistributedConfig {
9683 node_id: "node-b".to_string(),
9684 session_affinity: SessionAffinityMode::Required,
9685 };
9686 let signed_session = SignedSession {
9687 session_id: "session-1".to_string(),
9688 path: "/".to_string(),
9689 node_id: Some("node-a".to_string()),
9690 };
9691 let mismatch = session_affinity_mismatch(&distributed, &signed_session, "/").unwrap();
9692 assert_eq!(mismatch.session_id, "session-1");
9693 assert_eq!(mismatch.current_node_id, "node-b");
9694 assert_eq!(mismatch.token_node_id.as_deref(), Some("node-a"));
9695
9696 let matched = SignedSession {
9697 node_id: Some("node-b".to_string()),
9698 ..signed_session
9699 };
9700 assert!(session_affinity_mismatch(&distributed, &matched, "/").is_none());
9701 }
9702
9703 #[test]
9704 fn session_affinity_optional_mode_and_id_generators_cover_additional_paths() {
9705 let distributed = DistributedConfig {
9706 node_id: "node-b".to_string(),
9707 session_affinity: SessionAffinityMode::Disabled,
9708 };
9709 let signed_session = SignedSession {
9710 session_id: "session-1".to_string(),
9711 path: "/".to_string(),
9712 node_id: Some("node-a".to_string()),
9713 };
9714 assert!(session_affinity_mismatch(&distributed, &signed_session, "/").is_none());
9715
9716 let trace_id = super::generate_trace_id();
9717 let span_id = super::generate_span_id();
9718 assert_eq!(trace_id.len(), 32);
9719 assert_eq!(span_id.len(), 16);
9720 assert!(trace_id.chars().all(|ch| ch.is_ascii_hexdigit()));
9721 assert!(span_id.chars().all(|ch| ch.is_ascii_hexdigit()));
9722
9723 let now_ms = now_unix_ms();
9724 assert!(now_ms > 0);
9725
9726 let chrono_ms = super::chrono_like_timestamp();
9727 assert!(chrono_ms.chars().all(|ch| ch.is_ascii_digit()));
9728 }
9729
9730 #[test]
9731 fn durable_release_and_acquire_no_store_paths_are_explicit() {
9732 let durable = DurableRuntimeConfig::default();
9733 durable_release_lease(&durable, "session-none", None);
9735 durable_release_lease(
9736 &durable,
9737 "session-none",
9738 Some(&super::DurableLeaseHandle {
9739 owner_node_id: "node-a".to_string(),
9740 fence_token: 1,
9741 }),
9742 );
9743
9744 let distributed = DistributedConfig::default();
9745 let acquired = durable_acquire_lease(&durable, &distributed, "session-none", "/")
9746 .expect("no-store acquire should not fail");
9747 assert!(acquired.is_none());
9748 }
9749
    /// End-to-end check that the websocket connect handshake is refused when
    /// the `session_id` in the client's `connect` frame differs from the
    /// session id baked into the signed session token on the URL.
    ///
    /// The token is signed for `session-match`, the connect frame claims
    /// `session-other`, and the server is expected to answer with an error
    /// whose code is `session_mismatch`.
    #[tokio::test]
    async fn ws_connect_handshake_rejects_session_id_mismatch() {
        let secret = b"test secret".to_vec();
        // Sign the tokens with the same secret the router is configured with below.
        let signer = TokenSigner::new(secret.clone());
        let session_token = signer.sign_session("session-match", "/", "node-a");
        let csrf_token = signer.sign_csrf("session-match", "/");

        let (base, server) = spawn_test_router(
            ShellyRouter::new()
                .with_secret(secret)
                .with_allowed_origin("https://example.test")
                .live("/", Counter::default),
        )
        .await;

        let mut request = format!(
            "{base}/__shelly/ws/__root__?protocol={PROTOCOL_VERSION}&session={session_token}&csrf={csrf_token}"
        )
        .into_client_request()
        .expect("session mismatch request");
        // Use the origin that was allow-listed on the router above.
        request
            .headers_mut()
            .insert("origin", HeaderValue::from_static("https://example.test"));

        // The HTTP upgrade itself succeeds; the rejection happens at the protocol level.
        let (mut socket, response) = connect_async(request)
            .await
            .expect("session mismatch websocket should connect");
        assert_eq!(response.status(), StatusCode::SWITCHING_PROTOCOLS);

        // Claim a session id that differs from the one in the signed token.
        let connect = serde_json::json!({
            "type": "connect",
            "protocol": PROTOCOL_VERSION,
            "session_id": "session-other",
            "last_revision": 0
        });
        socket
            .send(WsMessage::Text(connect.to_string().into()))
            .await
            .expect("send mismatched connect handshake");

        await_error_code(&mut socket, "session_mismatch").await;
        // Best-effort teardown: the close result is deliberately ignored.
        let _ = socket.close(None).await;
        server.abort();
    }
9794
9795 #[test]
9796 fn durable_store_enforces_lease_conflicts_and_force_takeover() {
9797 let store = InMemoryDurableSessionStore::new();
9798 let first = store
9799 .acquire_lease(DurableLeaseRequest {
9800 session_id: "session-1".to_string(),
9801 node_id: "node-a".to_string(),
9802 ttl_ms: 60_000,
9803 takeover_policy: DurableTakeoverPolicy::AllowExpired,
9804 })
9805 .expect("node-a acquires first lease");
9806 assert_eq!(first.owner_node_id, "node-a");
9807 assert_eq!(first.fence_token, 1);
9808
9809 let conflict = store
9810 .acquire_lease(DurableLeaseRequest {
9811 session_id: "session-1".to_string(),
9812 node_id: "node-b".to_string(),
9813 ttl_ms: 60_000,
9814 takeover_policy: DurableTakeoverPolicy::Deny,
9815 })
9816 .expect_err("node-b should be rejected while node-a owns lease");
9817 assert_eq!(conflict.code, "lease_conflict");
9818
9819 let takeover = store
9820 .acquire_lease(DurableLeaseRequest {
9821 session_id: "session-1".to_string(),
9822 node_id: "node-b".to_string(),
9823 ttl_ms: 60_000,
9824 takeover_policy: DurableTakeoverPolicy::Force,
9825 })
9826 .expect("force takeover should transfer ownership");
9827 assert_eq!(takeover.owner_node_id, "node-b");
9828 assert_eq!(takeover.transferred_from.as_deref(), Some("node-a"));
9829 assert_eq!(takeover.fence_token, 2);
9830 }
9831
9832 #[test]
9833 fn durable_force_takeover_invalidates_stale_owner_fence() {
9834 let store = InMemoryDurableSessionStore::new();
9835 let first = store
9836 .acquire_lease(DurableLeaseRequest {
9837 session_id: "session-2".to_string(),
9838 node_id: "node-a".to_string(),
9839 ttl_ms: 60_000,
9840 takeover_policy: DurableTakeoverPolicy::AllowExpired,
9841 })
9842 .expect("node-a acquires first lease");
9843
9844 let takeover = store
9845 .acquire_lease(DurableLeaseRequest {
9846 session_id: "session-2".to_string(),
9847 node_id: "node-b".to_string(),
9848 ttl_ms: 60_000,
9849 takeover_policy: DurableTakeoverPolicy::Force,
9850 })
9851 .expect("node-b force takes ownership");
9852
9853 let stale_renew = store
9854 .renew_lease("session-2", "node-a", first.fence_token, 60_000)
9855 .expect_err("stale owner should not renew after takeover");
9856 assert_eq!(stale_renew.code, "lease_not_owner");
9857
9858 let stale_append = store
9859 .append_journal_entry(
9860 "session-2",
9861 "node-a",
9862 first.fence_token,
9863 ClientMessage::Event {
9864 event: "inc".to_string(),
9865 target: None,
9866 value: serde_json::Value::Null,
9867 metadata: serde_json::Map::new(),
9868 },
9869 64,
9870 )
9871 .expect_err("stale owner should not append journal after takeover");
9872 assert_eq!(stale_append.code, "lease_not_owner");
9873
9874 store.release_lease("session-2", "node-a", first.fence_token);
9875
9876 let renewed_owner = store
9877 .renew_lease("session-2", "node-b", takeover.fence_token, 60_000)
9878 .expect("current owner keeps lease after stale release");
9879 assert_eq!(renewed_owner.owner_node_id, "node-b");
9880 assert_eq!(renewed_owner.fence_token, takeover.fence_token);
9881 }
9882
9883 #[test]
9884 fn durable_recovery_replays_journal_entries_after_owner_failover() {
9885 let store = Arc::new(InMemoryDurableSessionStore::new());
9886 let lease = store
9887 .acquire_lease(DurableLeaseRequest {
9888 session_id: "session-1".to_string(),
9889 node_id: "node-a".to_string(),
9890 ttl_ms: 60_000,
9891 takeover_policy: DurableTakeoverPolicy::AllowExpired,
9892 })
9893 .expect("node-a acquires lease");
9894 store.save_snapshot(
9895 "session-1",
9896 DurableSessionSnapshot {
9897 route_path: "/".to_string(),
9898 route_pattern: "/".to_string(),
9899 target_id: "root".to_string(),
9900 revision: 0,
9901 resume_token: "resume-1".to_string(),
9902 owner_node_id: "node-a".to_string(),
9903 updated_at_unix_ms: 0,
9904 },
9905 );
9906 store
9907 .append_journal_entry(
9908 "session-1",
9909 "node-a",
9910 lease.fence_token,
9911 ClientMessage::Event {
9912 event: "inc".to_string(),
9913 target: None,
9914 value: serde_json::Value::Null,
9915 metadata: serde_json::Map::new(),
9916 },
9917 128,
9918 )
9919 .expect("append first durable event");
9920 store
9921 .append_journal_entry(
9922 "session-1",
9923 "node-a",
9924 lease.fence_token,
9925 ClientMessage::Event {
9926 event: "inc".to_string(),
9927 target: None,
9928 value: serde_json::Value::Null,
9929 metadata: serde_json::Map::new(),
9930 },
9931 128,
9932 )
9933 .expect("append second durable event");
9934 store.release_lease("session-1", "node-a", lease.fence_token);
9935
9936 let routes = Arc::new(vec![LiveRoute::new(
9937 "/".to_string(),
9938 Arc::new(|| Box::<Counter>::default()),
9939 )]);
9940 let telemetry: Arc<dyn TelemetrySink> = Arc::new(TelemetryPipeline::disabled());
9941 let config = SocketConfig {
9942 routes,
9943 target_id: "root".to_string(),
9944 route_path: "/".to_string(),
9945 signed_session_id: "session-1".to_string(),
9946 max_message_size: 64 * 1024,
9947 pubsub: PubSub::default(),
9948 uploads: UploadConfig::default(),
9949 security: SecurityConfig::default(),
9950 telemetry: Arc::new(TelemetryPipeline::disabled()),
9951 correlation: test_correlation(),
9952 reconnect: ReconnectConfig::default(),
9953 distributed: DistributedConfig::default(),
9954 durable: DurableRuntimeConfig {
9955 store: Some(store),
9956 ..DurableRuntimeConfig::default()
9957 },
9958 outbound: OutboundConfig::default(),
9959 render: super::RenderConfig::default(),
9960 overload: OverloadConfig::default(),
9961 };
9962
9963 let recovered = recover_from_durable_record(&config, "session-1", &telemetry)
9964 .expect("durable recovery should rebuild session");
9965 assert_eq!(recovered.replayed_entries, 2);
9966 assert_eq!(recovered.source_owner_node_id, "node-a");
9967 assert_eq!(recovered.snapshot.session.revision(), 2);
9968 assert_eq!(recovered.snapshot.route_pattern, "/");
9969 assert_eq!(recovered.snapshot.session.route_path(), "/");
9970 }
9971
9972 #[test]
9973 fn rate_limiter_hook_can_reject_messages() {
9974 let mut security = SecurityConfig {
9975 signer: TokenSigner::new(b"test secret".to_vec()),
9976 ..SecurityConfig::default()
9977 };
9978 security.rate_limiter = Some(std::sync::Arc::new(|ctx| {
9979 ctx.route_path == "/" && ctx.message_kind == "text" && ctx.session_id == "allowed"
9980 }));
9981
9982 assert!(!rate_limited(&security, "/", "allowed", "text"));
9983 assert!(rate_limited(&security, "/", "blocked", "text"));
9984 assert_eq!(
9985 rate_limited_error(),
9986 ServerMessage::Error {
9987 message: "rate limit exceeded".to_string(),
9988 code: Some("rate_limited".to_string()),
9989 }
9990 );
9991 }
9992
9993 #[test]
9994 fn rate_limited_defaults_to_false_without_hook() {
9995 let security = SecurityConfig::default();
9996 assert!(!rate_limited(&security, "/", "session-1", "text"));
9997 }
9998
9999 #[test]
10000 fn security_operation_mapping_covers_navigation_and_upload_variants() {
10001 let upload_start = ClientMessage::UploadStart {
10002 upload_id: "up-1".to_string(),
10003 event: "uploaded".to_string(),
10004 target: Some("profile".to_string()),
10005 name: "avatar.png".to_string(),
10006 size: 1,
10007 content_type: None,
10008 };
10009 let (operation, event_name, event_target) = security_operation_for_message(&upload_start);
10010 assert_eq!(operation, SecurityOperation::UploadStart);
10011 assert_eq!(event_name, Some("uploaded"));
10012 assert_eq!(event_target, Some("profile"));
10013 assert_eq!(
10014 security_operation_for_message(&ClientMessage::PatchUrl {
10015 to: "/users".to_string()
10016 })
10017 .0,
10018 SecurityOperation::PatchUrl
10019 );
10020 assert_eq!(
10021 security_operation_for_message(&ClientMessage::Navigate {
10022 to: "/users/1".to_string()
10023 })
10024 .0,
10025 SecurityOperation::Navigate
10026 );
10027 assert_eq!(
10028 security_operation_for_message(&ClientMessage::UploadChunk {
10029 upload_id: "up-1".to_string(),
10030 offset: 0,
10031 data: "AA==".to_string(),
10032 })
10033 .0,
10034 SecurityOperation::UploadChunk
10035 );
10036 assert_eq!(
10037 security_operation_for_message(&ClientMessage::UploadComplete {
10038 upload_id: "up-1".to_string()
10039 })
10040 .0,
10041 SecurityOperation::UploadComplete
10042 );
10043 }
10044
10045 #[test]
10046 fn client_message_log_fields_cover_all_client_variants() {
10047 let connect = client_message_log_fields(&ClientMessage::Connect {
10048 protocol: "shelly/1".to_string(),
10049 session_id: Some("sid-1".to_string()),
10050 last_revision: Some(7),
10051 resume_token: Some("resume".to_string()),
10052 tenant_id: None,
10053 trace_id: None,
10054 span_id: None,
10055 parent_span_id: None,
10056 correlation_id: None,
10057 request_id: None,
10058 });
10059 assert_eq!(connect.message_type, "connect");
10060 assert_eq!(connect.connect_session_id.as_deref(), Some("sid-1"));
10061 assert_eq!(connect.connect_last_revision, Some(7));
10062 assert_eq!(connect.has_resume_token, Some(true));
10063
10064 let event = client_message_log_fields(&ClientMessage::Event {
10065 event: "save".to_string(),
10066 target: Some("profile".to_string()),
10067 value: serde_json::json!({}),
10068 metadata: serde_json::Map::new(),
10069 });
10070 assert_eq!(event.message_type, "event");
10071 assert_eq!(event.event_name.as_deref(), Some("save"));
10072 assert_eq!(event.event_target.as_deref(), Some("profile"));
10073
10074 let ping = client_message_log_fields(&ClientMessage::Ping { nonce: None });
10075 assert_eq!(ping.message_type, "ping");
10076
10077 let patch_url = client_message_log_fields(&ClientMessage::PatchUrl {
10078 to: "/users".to_string(),
10079 });
10080 assert_eq!(patch_url.message_type, "patch_url");
10081 assert_eq!(patch_url.navigation_to.as_deref(), Some("/users"));
10082
10083 let navigate = client_message_log_fields(&ClientMessage::Navigate {
10084 to: "/users/1".to_string(),
10085 });
10086 assert_eq!(navigate.message_type, "navigate");
10087 assert_eq!(navigate.navigation_to.as_deref(), Some("/users/1"));
10088
10089 let upload_start = client_message_log_fields(&ClientMessage::UploadStart {
10090 upload_id: "up-1".to_string(),
10091 event: "uploaded".to_string(),
10092 target: Some("profile".to_string()),
10093 name: "avatar.png".to_string(),
10094 size: 1,
10095 content_type: None,
10096 });
10097 assert_eq!(upload_start.message_type, "upload_start");
10098 assert_eq!(upload_start.event_target.as_deref(), Some("profile"));
10099 assert_eq!(upload_start.upload_id.as_deref(), Some("up-1"));
10100 assert_eq!(upload_start.upload_event.as_deref(), Some("uploaded"));
10101
10102 let upload_chunk = client_message_log_fields(&ClientMessage::UploadChunk {
10103 upload_id: "up-1".to_string(),
10104 offset: 0,
10105 data: "AA==".to_string(),
10106 });
10107 assert_eq!(upload_chunk.message_type, "upload_chunk");
10108 assert_eq!(upload_chunk.upload_id.as_deref(), Some("up-1"));
10109
10110 let upload_complete = client_message_log_fields(&ClientMessage::UploadComplete {
10111 upload_id: "up-1".to_string(),
10112 });
10113 assert_eq!(upload_complete.message_type, "upload_complete");
10114 assert_eq!(upload_complete.upload_id.as_deref(), Some("up-1"));
10115 }
10116
10117 #[test]
10118 fn log_helpers_and_decoders_cover_error_paths() {
10119 let correlation = test_correlation();
10120 let message = ClientMessage::Event {
10121 event: "save".to_string(),
10122 target: Some("profile".to_string()),
10123 value: serde_json::json!({}),
10124 metadata: serde_json::Map::new(),
10125 };
10126 log_ws_text_ingress("/users", "sid-1", 32, Ok(&message), &correlation);
10127 let parse_err = serde_json::from_str::<ClientMessage>("{").expect_err("invalid JSON");
10128 log_ws_text_ingress("/users", "sid-1", 1, Err(&parse_err), &correlation);
10129
10130 assert_eq!(
10131 percent_decode("hello+world").as_deref(),
10132 Some("hello world")
10133 );
10134 assert_eq!(percent_decode("a%2Fb").as_deref(), Some("a/b"));
10135 assert!(percent_decode("%").is_none());
10136 assert!(percent_decode("%GG").is_none());
10137 assert_eq!(from_hex(b'F'), Some(15));
10138 assert_eq!(from_hex(b'z'), None);
10139 assert!(parse_query("broken=%GG").is_empty());
10140 }
10141
10142 #[test]
10143 fn outbound_envelope_size_and_telemetry_helpers_are_safe() {
10144 assert_eq!(
10145 OutboundEnvelope::Text("abc".to_string()).estimated_bytes(),
10146 3
10147 );
10148 assert_eq!(
10149 OutboundEnvelope::Pong(vec![1, 2, 3, 4]).estimated_bytes(),
10150 4
10151 );
10152 assert_eq!(
10153 OutboundEnvelope::Close(Some(OutboundCloseFrame {
10154 code: 1000,
10155 reason: "bye".to_string(),
10156 }))
10157 .estimated_bytes(),
10158 5
10159 );
10160 assert_eq!(OutboundEnvelope::Close(None).estimated_bytes(), 0);
10161
10162 let telemetry = Arc::new(TelemetryPipeline::disabled());
10163 let correlation = test_correlation();
10164 emit_outbound_batch_telemetry(&telemetry, &correlation, "/users", "sid-1", 3, 120, 2);
10165 emit_outbound_overflow_telemetry(
10166 &telemetry,
10167 &correlation,
10168 "/users",
10169 "sid-1",
10170 OutboundOverflowPolicy::Disconnect,
10171 64,
10172 );
10173 emit_outbound_overflow_telemetry(
10174 &telemetry,
10175 &correlation,
10176 "/users",
10177 "sid-1",
10178 OutboundOverflowPolicy::DropNewest,
10179 64,
10180 );
10181 }
10182
10183 #[test]
10184 fn security_audit_emitter_handles_optional_context_fields() {
10185 let telemetry = Arc::new(TelemetryPipeline::disabled());
10186 let correlation = test_correlation();
10187 emit_security_audit(
10188 &telemetry,
10189 &correlation,
10190 "/users",
10191 Some("sid-1"),
10192 Some("tenant-a"),
10193 "authorization",
10194 false,
10195 Some("forbidden"),
10196 SecurityOperation::Event,
10197 "text",
10198 Some("save"),
10199 Some("policy-deny"),
10200 );
10201 emit_security_audit(
10202 &telemetry,
10203 &correlation,
10204 "/users",
10205 None,
10206 None,
10207 "authorization",
10208 true,
10209 None,
10210 SecurityOperation::Ping,
10211 "text",
10212 None,
10213 None,
10214 );
10215 }
10216
10217 #[test]
10218 fn outbound_queue_overflow_disconnect_policy_requests_disconnect() {
10219 let (sender, mut receiver) = mpsc::channel(1);
10220 let outbound = OutboundConfig {
10221 queue_capacity: 1,
10222 overflow_policy: OutboundOverflowPolicy::Disconnect,
10223 ..OutboundConfig::default()
10224 };
10225 let telemetry = Arc::new(TelemetryPipeline::disabled());
10226 let correlation = test_correlation();
10227
10228 let first = queue_server_message(
10229 &sender,
10230 &ServerMessage::Pong { nonce: None },
10231 &outbound,
10232 &telemetry,
10233 &correlation,
10234 "/",
10235 "session-1",
10236 );
10237 assert!(matches!(first, super::OutboundQueuePush::Queued));
10238
10239 let second = queue_server_message(
10240 &sender,
10241 &ServerMessage::Pong {
10242 nonce: Some("2".to_string()),
10243 },
10244 &outbound,
10245 &telemetry,
10246 &correlation,
10247 "/",
10248 "session-1",
10249 );
10250 assert!(matches!(second, super::OutboundQueuePush::Disconnect));
10251
10252 let _ = receiver.try_recv().expect("first outbound message queued");
10253 }
10254
10255 #[test]
10256 fn outbound_queue_overflow_drop_policy_drops_newest() {
10257 let (sender, mut receiver) = mpsc::channel(1);
10258 let outbound = OutboundConfig {
10259 queue_capacity: 1,
10260 overflow_policy: OutboundOverflowPolicy::DropNewest,
10261 ..OutboundConfig::default()
10262 };
10263 let telemetry = Arc::new(TelemetryPipeline::disabled());
10264 let correlation = test_correlation();
10265
10266 let first = queue_server_message(
10267 &sender,
10268 &ServerMessage::Pong { nonce: None },
10269 &outbound,
10270 &telemetry,
10271 &correlation,
10272 "/",
10273 "session-1",
10274 );
10275 assert!(matches!(first, super::OutboundQueuePush::Queued));
10276
10277 let second = queue_server_message(
10278 &sender,
10279 &ServerMessage::Pong {
10280 nonce: Some("2".to_string()),
10281 },
10282 &outbound,
10283 &telemetry,
10284 &correlation,
10285 "/",
10286 "session-1",
10287 );
10288 assert!(matches!(second, super::OutboundQueuePush::Dropped));
10289
10290 let _ = receiver.try_recv().expect("first outbound message queued");
10291 }
10292
10293 #[test]
10294 fn security_operation_mapping_tracks_event_metadata() {
10295 let event = ClientMessage::Event {
10296 event: "save".to_string(),
10297 target: Some("form".to_string()),
10298 value: serde_json::Value::Null,
10299 metadata: serde_json::Map::new(),
10300 };
10301 let upload_start = ClientMessage::UploadStart {
10302 upload_id: "u1".to_string(),
10303 event: "upload".to_string(),
10304 target: Some("avatar".to_string()),
10305 name: "a.png".to_string(),
10306 size: 10,
10307 content_type: Some("image/png".to_string()),
10308 };
10309
10310 let (event_operation, event_name, event_target) = security_operation_for_message(&event);
10311 assert_eq!(event_operation, SecurityOperation::Event);
10312 assert_eq!(event_name, Some("save"));
10313 assert_eq!(event_target, Some("form"));
10314
10315 let (upload_operation, upload_event_name, upload_target) =
10316 security_operation_for_message(&upload_start);
10317 assert_eq!(upload_operation, SecurityOperation::UploadStart);
10318 assert_eq!(upload_event_name, Some("upload"));
10319 assert_eq!(upload_target, Some("avatar"));
10320 }
10321
10322 #[test]
10323 fn authorization_hook_can_reject_sensitive_event_names() {
10324 let mut security = SecurityConfig {
10325 signer: TokenSigner::new(b"test secret".to_vec()),
10326 ..SecurityConfig::default()
10327 };
10328 security.authorization = Some(Arc::new(|ctx| {
10329 if ctx.event_name.as_deref() == Some("admin.delete_user") {
10330 AuthorizationDecision::deny(
10331 "authz_denied",
10332 "authorization policy rejected admin.delete_user",
10333 )
10334 } else {
10335 AuthorizationDecision::allow()
10336 }
10337 }));
10338
10339 assert_eq!(
10340 authorization_denied(
10341 &security,
10342 AuthorizationInput {
10343 route_path: "/admin",
10344 session_id: "session-1",
10345 tenant_id: None,
10346 message_kind: "text",
10347 operation: SecurityOperation::Event,
10348 event_name: Some("admin.delete_user"),
10349 event_target: Some("delete"),
10350 },
10351 ),
10352 Some(ServerMessage::Error {
10353 message: "authorization policy rejected admin.delete_user".to_string(),
10354 code: Some("authz_denied".to_string()),
10355 })
10356 );
10357 assert!(authorization_denied(
10358 &security,
10359 AuthorizationInput {
10360 route_path: "/admin",
10361 session_id: "session-1",
10362 tenant_id: None,
10363 message_kind: "text",
10364 operation: SecurityOperation::Event,
10365 event_name: Some("profile.save"),
10366 event_target: Some("save"),
10367 },
10368 )
10369 .is_none());
10370 }
10371
10372 #[test]
10373 fn quota_policy_can_apply_session_route_and_event_scopes() {
10374 let mut security = SecurityConfig {
10375 signer: TokenSigner::new(b"test secret".to_vec()),
10376 ..SecurityConfig::default()
10377 };
10378 security.quota_policy = Some(Arc::new(|ctx| {
10379 if ctx.session_id == "blocked-session" {
10380 return QuotaDecision::deny("session_quota_exceeded", "session quota exceeded");
10381 }
10382 if ctx.route_path == "/admin" && ctx.event_name.as_deref() == Some("bulk.export") {
10383 return QuotaDecision::deny("event_quota_exceeded", "event quota exceeded");
10384 }
10385 QuotaDecision::allow()
10386 }));
10387
10388 assert_eq!(
10389 quota_denied(
10390 &security,
10391 "/feed",
10392 "blocked-session",
10393 None,
10394 "text",
10395 SecurityOperation::Event,
10396 Some("scroll"),
10397 ),
10398 Some(ServerMessage::Error {
10399 message: "session quota exceeded".to_string(),
10400 code: Some("session_quota_exceeded".to_string()),
10401 })
10402 );
10403 assert_eq!(
10404 quota_denied(
10405 &security,
10406 "/admin",
10407 "session-1",
10408 None,
10409 "text",
10410 SecurityOperation::Event,
10411 Some("bulk.export"),
10412 ),
10413 Some(ServerMessage::Error {
10414 message: "event quota exceeded".to_string(),
10415 code: Some("event_quota_exceeded".to_string()),
10416 })
10417 );
10418 assert!(quota_denied(
10419 &security,
10420 "/admin",
10421 "session-1",
10422 None,
10423 "text",
10424 SecurityOperation::Event,
10425 Some("profile.save"),
10426 )
10427 .is_none());
10428 }
10429
10430 #[test]
10431 fn tenant_quota_policy_requires_tenant_context_by_default() {
10432 let policy = TenantQuotaPolicy::new();
10433 let decision = policy.evaluate(&super::QuotaContext {
10434 route_path: "/tenant".to_string(),
10435 session_id: "session-1".to_string(),
10436 tenant_id: None,
10437 message_kind: "text",
10438 operation: SecurityOperation::Connect,
10439 event_name: None,
10440 });
10441 assert!(!decision.allowed);
10442 assert_eq!(decision.code.as_deref(), Some("tenant_context_required"));
10443 }
10444
10445 #[test]
10446 fn tenant_quota_policy_enforces_per_tenant_session_and_event_limits() {
10447 let policy = TenantQuotaPolicy::new()
10448 .with_window_ms(u64::MAX)
10449 .with_budgets(TenantQuotaBudgets {
10450 max_sessions_per_window: 1,
10451 max_events_per_window: 2,
10452 require_tenant_id: true,
10453 });
10454
10455 let connect_for = |session_id: &str, tenant_id: &str| super::QuotaContext {
10456 route_path: "/tenant".to_string(),
10457 session_id: session_id.to_string(),
10458 tenant_id: Some(tenant_id.to_string()),
10459 message_kind: "text",
10460 operation: SecurityOperation::Connect,
10461 event_name: None,
10462 };
10463 let event_for = |session_id: &str, tenant_id: &str, name: &str| super::QuotaContext {
10464 route_path: "/tenant".to_string(),
10465 session_id: session_id.to_string(),
10466 tenant_id: Some(tenant_id.to_string()),
10467 message_kind: "text",
10468 operation: SecurityOperation::Event,
10469 event_name: Some(name.to_string()),
10470 };
10471
10472 assert!(
10473 policy
10474 .evaluate(&connect_for("session-a1", "tenant-a"))
10475 .allowed
10476 );
10477 let tenant_a_second_session = policy.evaluate(&connect_for("session-a2", "tenant-a"));
10478 assert!(!tenant_a_second_session.allowed);
10479 assert_eq!(
10480 tenant_a_second_session.code.as_deref(),
10481 Some("tenant_session_quota_exceeded")
10482 );
10483 assert!(
10484 policy
10485 .evaluate(&connect_for("session-b1", "tenant-b"))
10486 .allowed,
10487 "tenant-b must keep an isolated session quota window"
10488 );
10489
10490 assert!(
10491 policy
10492 .evaluate(&event_for("session-b1", "tenant-b", "save"))
10493 .allowed
10494 );
10495 assert!(
10496 policy
10497 .evaluate(&event_for("session-b1", "tenant-b", "save"))
10498 .allowed
10499 );
10500 let tenant_b_third_event = policy.evaluate(&event_for("session-b1", "tenant-b", "save"));
10501 assert!(!tenant_b_third_event.allowed);
10502 assert_eq!(
10503 tenant_b_third_event.code.as_deref(),
10504 Some("tenant_event_quota_exceeded")
10505 );
10506 assert!(
10507 policy
10508 .evaluate(&event_for("session-a1", "tenant-a", "save"))
10509 .allowed,
10510 "tenant-a event budget must remain independent from tenant-b"
10511 );
10512 }
10513
10514 #[test]
10515 fn security_audit_event_schema_is_structured_and_correlated() {
10516 let event = security_audit_event(
10517 &test_correlation(),
10518 "/admin",
10519 Some("session-1"),
10520 Some("tenant-a"),
10521 "authorization_policy",
10522 false,
10523 Some("authz_denied"),
10524 SecurityOperation::Event,
10525 "text",
10526 Some("admin.delete_user"),
10527 Some("authorization policy rejected admin.delete_user"),
10528 );
10529
10530 assert_eq!(event.kind, TelemetryEventKind::SecurityAudit);
10531 assert_eq!(event.route_path.as_deref(), Some("/admin"));
10532 assert_eq!(event.session_id.as_deref(), Some("session-1"));
10533 assert_eq!(
10534 event.attributes.get("tenant_id"),
10535 Some(&serde_json::Value::String("tenant-a".to_string()))
10536 );
10537 assert_eq!(event.event_name.as_deref(), Some("admin.delete_user"));
10538 assert!(!event.ok);
10539 assert_eq!(
10540 event.attributes.get("control"),
10541 Some(&serde_json::Value::String(
10542 "authorization_policy".to_string()
10543 ))
10544 );
10545 assert_eq!(
10546 event.attributes.get("operation"),
10547 Some(&serde_json::Value::String("event".to_string()))
10548 );
10549 assert_eq!(
10550 event.attributes.get("code"),
10551 Some(&serde_json::Value::String("authz_denied".to_string()))
10552 );
10553 assert_eq!(
10554 event.attributes.get("policy_reason"),
10555 Some(&serde_json::Value::String(
10556 "authorization policy rejected admin.delete_user".to_string()
10557 ))
10558 );
10559 assert_eq!(
10560 event.trace_id.as_deref(),
10561 Some("0123456789abcdef0123456789abcdef")
10562 );
10563 }
10564
10565 #[tokio::test]
10566 async fn overload_decision_sheds_background_when_tenant_budget_is_saturated() {
10567 let mut overload = OverloadConfig::default();
10568 overload.budgets.session_queue_depth = usize::MAX;
10569 overload.budgets.session_bytes_per_sec = usize::MAX;
10570 overload.budgets.session_cpu_ms_per_sec = u64::MAX;
10571 overload.budgets.tenant_queue_depth = usize::MAX;
10572 overload.budgets.tenant_bytes_per_sec = 32;
10573 overload.budgets.tenant_cpu_ms_per_sec = u64::MAX;
10574
10575 let first = test_overload_context(
10576 "session-1",
10577 Some("tenant-a"),
10578 OverloadPriority::Interactive,
10579 24,
10580 1,
10581 256,
10582 );
10583 assert!(
10584 overload_decision_for_dispatch(&overload, &first)
10585 .await
10586 .allowed
10587 );
10588
10589 let second = test_overload_context(
10590 "session-2",
10591 Some("tenant-a"),
10592 OverloadPriority::Background,
10593 24,
10594 1,
10595 256,
10596 );
10597 let decision = overload_decision_for_dispatch(&overload, &second).await;
10598 assert!(!decision.allowed);
10599 assert_eq!(decision.code.as_deref(), Some("overload_shed"));
10600 assert!(decision
10601 .reason
10602 .as_deref()
10603 .unwrap_or_default()
10604 .contains("tenant_budget"));
10605 }
10606
10607 #[tokio::test]
10608 async fn overload_decision_throttles_interactive_when_session_budget_is_saturated() {
10609 let mut overload = OverloadConfig::default();
10610 overload.budgets.session_queue_depth = usize::MAX;
10611 overload.budgets.session_bytes_per_sec = 32;
10612 overload.budgets.session_cpu_ms_per_sec = u64::MAX;
10613 overload.budgets.tenant_queue_depth = usize::MAX;
10614 overload.budgets.tenant_bytes_per_sec = usize::MAX;
10615 overload.budgets.tenant_cpu_ms_per_sec = u64::MAX;
10616
10617 let first =
10618 test_overload_context("session-1", None, OverloadPriority::Interactive, 24, 1, 256);
10619 assert!(
10620 overload_decision_for_dispatch(&overload, &first)
10621 .await
10622 .allowed
10623 );
10624
10625 let second =
10626 test_overload_context("session-1", None, OverloadPriority::Interactive, 24, 1, 256);
10627 let decision = overload_decision_for_dispatch(&overload, &second).await;
10628 assert!(decision.allowed);
10629 assert_eq!(decision.code.as_deref(), Some("overload_throttle"));
10630 assert_eq!(decision.throttle_ms, 10);
10631 assert!(decision
10632 .reason
10633 .as_deref()
10634 .unwrap_or_default()
10635 .contains("session_budget"));
10636 }
10637
10638 #[tokio::test]
10639 async fn overload_policy_hook_can_override_shed_decision() {
10640 let mut overload = OverloadConfig::default();
10641 overload.budgets.session_bytes_per_sec = 1;
10642 overload.policy_hook = Some(Arc::new(|_ctx, _decision| OverloadDecision::allow()));
10643
10644 let context =
10645 test_overload_context("session-1", None, OverloadPriority::Background, 8, 1, 128);
10646 let decision = overload_decision_for_dispatch(&overload, &context).await;
10647 assert!(decision.allowed);
10648 assert_eq!(decision.code, None);
10649 assert_eq!(decision.reason, None);
10650 }
10651
/// Pins the shape of the telemetry event produced for a shed decision: an error-kind event
/// with route/session/event identifiers and a fixed set of overload attributes.
#[test]
fn overload_telemetry_event_has_expected_schema() {
    let context = test_overload_context(
        "session-5",
        Some("tenant-z"),
        OverloadPriority::Background,
        512,
        12,
        16,
    );
    let shed = OverloadDecision::shed("shed:queue_depth+tenant_budget");
    let event = overload_telemetry_event(&test_correlation(), &context, &shed)
        .expect("shed decisions must emit telemetry");

    // Top-level event fields.
    assert_eq!(event.kind, TelemetryEventKind::Error);
    assert!(!event.ok);
    assert_eq!(event.route_path.as_deref(), Some("/bench"));
    assert_eq!(event.session_id.as_deref(), Some("session-5"));
    assert_eq!(event.event_name.as_deref(), Some("bench.event"));

    // Attribute payload, checked key by key.
    let expected_attributes = [
        ("phase", serde_json::json!("overload_control")),
        ("action", serde_json::json!("shed")),
        ("priority", serde_json::json!("background")),
        ("queue_depth", serde_json::json!(12)),
        ("queue_capacity", serde_json::json!(16)),
        ("inbound_bytes", serde_json::json!(512)),
        ("tenant_id", serde_json::json!("tenant-z")),
        (
            "reason",
            serde_json::json!("shed:queue_depth+tenant_budget"),
        ),
    ];
    for (key, expected) in &expected_attributes {
        assert_eq!(
            event.attributes.get(*key),
            Some(expected),
            "unexpected value for telemetry attribute `{}`",
            key
        );
    }
}
10703
10704 #[tokio::test(start_paused = true)]
10705 async fn overload_throttle_respects_configured_delay() {
10706 let decision = OverloadDecision::throttle(25, "throttle:session_budget");
10707 let handle = tokio::spawn(async move {
10708 apply_overload_throttle(&decision).await;
10709 });
10710
10711 tokio::task::yield_now().await;
10712 advance(Duration::from_millis(24)).await;
10713 tokio::task::yield_now().await;
10714 assert!(!handle.is_finished());
10715
10716 advance(Duration::from_millis(1)).await;
10717 handle.await.expect("throttle task should complete");
10718 }
10719
10720 #[tokio::test]
10721 async fn pubsub_commands_forward_broadcasts_to_socket_channel() {
10722 let pubsub = PubSub::default();
10723 let (sender, mut receiver) = mpsc::unbounded_channel();
10724 let mut tasks: Vec<JoinHandle<()>> = Vec::new();
10725 let mut subscribed_topics = HashSet::new();
10726 let correlation = test_correlation();
10727
10728 process_pubsub_commands(
10729 vec![PubSubCommand::Subscribe {
10730 topic: "chat:lobby".to_string(),
10731 }],
10732 &pubsub,
10733 &sender,
10734 &mut tasks,
10735 &mut subscribed_topics,
10736 "session-1",
10737 "node-test",
10738 &Arc::new(TelemetryPipeline::disabled()),
10739 &correlation,
10740 );
10741 process_pubsub_commands(
10742 vec![PubSubCommand::Broadcast {
10743 topic: "chat:lobby".to_string(),
10744 messages: vec![ServerMessage::Pong {
10745 nonce: Some("n1".to_string()),
10746 }],
10747 }],
10748 &pubsub,
10749 &sender,
10750 &mut tasks,
10751 &mut subscribed_topics,
10752 "session-1",
10753 "node-test",
10754 &Arc::new(TelemetryPipeline::disabled()),
10755 &correlation,
10756 );
10757
10758 tokio::task::yield_now().await;
10759 let message = receiver.recv().await.unwrap();
10760 assert_eq!(message.topic, "chat:lobby");
10761 assert_eq!(
10762 message.messages,
10763 vec![ServerMessage::Pong {
10764 nonce: Some("n1".to_string())
10765 }]
10766 );
10767
10768 for task in tasks {
10769 task.abort();
10770 }
10771 }
10772
10773 #[tokio::test]
10774 async fn pubsub_commands_register_presence_for_subscribed_topics() {
10775 let pubsub = PubSub::default();
10776 let (sender, _receiver) = mpsc::unbounded_channel();
10777 let mut tasks: Vec<JoinHandle<()>> = Vec::new();
10778 let mut subscribed_topics = HashSet::new();
10779 let correlation = test_correlation();
10780
10781 process_pubsub_commands(
10782 vec![PubSubCommand::Subscribe {
10783 topic: "chat:lobby".to_string(),
10784 }],
10785 &pubsub,
10786 &sender,
10787 &mut tasks,
10788 &mut subscribed_topics,
10789 "session-42",
10790 "node-alpha",
10791 &Arc::new(TelemetryPipeline::disabled()),
10792 &correlation,
10793 );
10794 let snapshot = pubsub.presence_snapshot("chat:lobby");
10795 assert_eq!(snapshot.total_sessions, 1);
10796 assert_eq!(snapshot.by_node.get("node-alpha"), Some(&1));
10797
10798 unregister_pubsub_presence(&pubsub, &subscribed_topics, "session-42", "node-alpha");
10799 let cleared = pubsub.presence_snapshot("chat:lobby");
10800 assert_eq!(cleared.total_sessions, 0);
10801
10802 for task in tasks {
10803 task.abort();
10804 }
10805 }
10806
10807 #[tokio::test(start_paused = true)]
10808 async fn runtime_commands_dispatch_scheduled_events_to_socket_channel() {
10809 let (sender, mut receiver) = mpsc::unbounded_channel();
10810 let mut tasks: HashMap<String, JoinHandle<()>> = HashMap::new();
10811 process_runtime_commands(
10812 vec![RuntimeCommand::ScheduleOnce {
10813 id: "once".to_string(),
10814 delay_ms: 0,
10815 dispatch: RuntimeEvent::new("tick", serde_json::json!({"n": 1})),
10816 }],
10817 &sender,
10818 &mut tasks,
10819 );
10820
10821 let event =
10822 recv_runtime_event_with_virtual_time(&mut receiver, Duration::from_millis(1)).await;
10823 assert_eq!(event.event, "tick");
10824 assert_eq!(event.value, serde_json::json!({"n": 1}));
10825
10826 for task in tasks.into_values() {
10827 task.abort();
10828 }
10829 }
10830
10831 #[tokio::test(start_paused = true)]
10832 async fn runtime_commands_replace_existing_task_for_same_id() {
10833 let (sender, mut receiver) = mpsc::unbounded_channel();
10834 let mut tasks: HashMap<String, JoinHandle<()>> = HashMap::new();
10835
10836 process_runtime_commands(
10837 vec![
10838 RuntimeCommand::ScheduleOnce {
10839 id: "tick".to_string(),
10840 delay_ms: 100,
10841 dispatch: RuntimeEvent::new("tick_first", serde_json::json!({"n": 1})),
10842 },
10843 RuntimeCommand::ScheduleOnce {
10844 id: "tick".to_string(),
10845 delay_ms: 200,
10846 dispatch: RuntimeEvent::new("tick_second", serde_json::json!({"n": 2})),
10847 },
10848 ],
10849 &sender,
10850 &mut tasks,
10851 );
10852
10853 assert_eq!(tasks.len(), 1);
10854 let event =
10855 recv_runtime_event_with_virtual_time(&mut receiver, Duration::from_millis(250)).await;
10856 assert_eq!(event.event, "tick_second");
10857 assert_eq!(event.value, serde_json::json!({"n": 2}));
10858
10859 for task in tasks.into_values() {
10860 task.abort();
10861 }
10862 }
10863
10864 #[tokio::test]
10865 async fn runtime_commands_cancel_active_interval_task() {
10866 let (sender, _receiver) = mpsc::unbounded_channel();
10867 let mut tasks: HashMap<String, JoinHandle<()>> = HashMap::new();
10868 process_runtime_commands(
10869 vec![RuntimeCommand::ScheduleInterval {
10870 id: "pulse".to_string(),
10871 every_ms: 1000,
10872 dispatch: RuntimeEvent::new("pulse", serde_json::Value::Null),
10873 }],
10874 &sender,
10875 &mut tasks,
10876 );
10877 assert!(tasks.contains_key("pulse"));
10878
10879 process_runtime_commands(
10880 vec![RuntimeCommand::Cancel {
10881 id: "pulse".to_string(),
10882 }],
10883 &sender,
10884 &mut tasks,
10885 );
10886 assert!(!tasks.contains_key("pulse"));
10887 }
10888
10889 #[tokio::test]
10890 async fn upload_lifecycle_writes_temp_file_and_dispatches_event() {
10891 #[derive(Default)]
10892 struct UploadRecorder {
10893 name: String,
10894 }
10895
10896 impl LiveView for UploadRecorder {
10897 fn handle_event(&mut self, event: Event, _ctx: &mut Context) -> LiveResult {
10898 if event.name == "uploaded" {
10899 self.name = event
10900 .value
10901 .get("name")
10902 .and_then(serde_json::Value::as_str)
10903 .unwrap_or_default()
10904 .to_string();
10905 }
10906 Ok(())
10907 }
10908
10909 fn render(&self) -> Html {
10910 Html::new(format!("<p>{}</p>", self.name))
10911 }
10912 }
10913
10914 let temp_dir = std::env::temp_dir().join(format!("shelly-upload-test-{}", Uuid::new_v4()));
10915 let config = UploadConfig {
10916 max_file_size: 16,
10917 allowed_content_types: Arc::new(vec!["text/plain".to_string()]),
10918 temp_dir: Arc::new(temp_dir.clone()),
10919 };
10920 let mut uploads = HashMap::new();
10921 let mut session = LiveSession::new(Box::<UploadRecorder>::default(), "root");
10922 session.mount().unwrap();
10923 session.render_patch();
10924
10925 assert_eq!(
10926 handle_upload_start(
10927 &mut uploads,
10928 &config,
10929 UploadStartRequest {
10930 upload_id: "up-1".to_string(),
10931 event: "uploaded".to_string(),
10932 target: Some("file".to_string()),
10933 name: "notes.txt".to_string(),
10934 size: 5,
10935 content_type: Some("text/plain".to_string()),
10936 },
10937 )
10938 .await,
10939 vec![ServerMessage::UploadProgress {
10940 upload_id: "up-1".to_string(),
10941 received: 0,
10942 total: 5,
10943 }]
10944 );
10945 assert_eq!(
10946 handle_upload_chunk(&mut uploads, "up-1", 0, "aGVsbG8=").await,
10947 vec![ServerMessage::UploadProgress {
10948 upload_id: "up-1".to_string(),
10949 received: 5,
10950 total: 5,
10951 }]
10952 );
10953
10954 let messages = handle_upload_complete(&mut session, &mut uploads, "up-1").await;
10955 assert!(matches!(
10956 &messages[0],
10957 ServerMessage::UploadComplete {
10958 upload_id,
10959 name,
10960 size: 5,
10961 content_type: Some(content_type),
10962 } if upload_id == "up-1" && name == "notes.txt" && content_type == "text/plain"
10963 ));
10964 assert!(
10965 matches!(&messages[1], ServerMessage::Patch { html, .. } if html.contains("notes.txt"))
10966 );
10967
10968 let stored = tokio::fs::read_to_string(temp_dir.join("up-1.upload"))
10969 .await
10970 .unwrap();
10971 assert_eq!(stored, "hello");
10972 let _ = tokio::fs::remove_dir_all(temp_dir).await;
10973 }
10974
10975 #[tokio::test]
10976 async fn upload_start_validates_size_and_content_type() {
10977 let config = UploadConfig {
10978 max_file_size: 4,
10979 allowed_content_types: Arc::new(vec!["text/plain".to_string()]),
10980 temp_dir: Arc::new(PathBuf::from("/tmp/shelly-unused")),
10981 };
10982 let mut uploads = HashMap::new();
10983
10984 assert_eq!(
10985 handle_upload_start(
10986 &mut uploads,
10987 &config,
10988 UploadStartRequest {
10989 upload_id: "up-large".to_string(),
10990 event: "uploaded".to_string(),
10991 target: None,
10992 name: "big.txt".to_string(),
10993 size: 5,
10994 content_type: Some("text/plain".to_string()),
10995 },
10996 )
10997 .await,
10998 vec![upload_error(
10999 "up-large",
11000 "upload exceeds configured size limit",
11001 "upload_too_large",
11002 )]
11003 );
11004 assert_eq!(
11005 handle_upload_start(
11006 &mut uploads,
11007 &config,
11008 UploadStartRequest {
11009 upload_id: "up-type".to_string(),
11010 event: "uploaded".to_string(),
11011 target: None,
11012 name: "image.png".to_string(),
11013 size: 4,
11014 content_type: Some("image/png".to_string()),
11015 },
11016 )
11017 .await,
11018 vec![upload_error(
11019 "up-type",
11020 "upload content type is not allowed",
11021 "upload_type_not_allowed",
11022 )]
11023 );
11024 }
11025
/// Each upload-related server message maps to an `UploadLifecycle` telemetry event with the
/// expected `ok` flag and `phase` attribute.
#[test]
fn upload_lifecycle_event_maps_progress_complete_and_error() {
    let session_id = "sid-1";
    let route_path = "/upload";
    let correlation = test_correlation();

    // (server message, expected `ok`, expected `phase`)
    let cases = [
        (
            ServerMessage::UploadProgress {
                upload_id: "up-1".to_string(),
                received: 0,
                total: 5,
            },
            true,
            "start",
        ),
        (
            ServerMessage::UploadComplete {
                upload_id: "up-1".to_string(),
                name: "notes.txt".to_string(),
                size: 5,
                content_type: Some("text/plain".to_string()),
            },
            true,
            "complete",
        ),
        (
            ServerMessage::UploadError {
                upload_id: "up-1".to_string(),
                message: "upload failed".to_string(),
                code: Some("upload_write_failed".to_string()),
            },
            false,
            "error",
        ),
    ];

    for (message, expected_ok, expected_phase) in &cases {
        let event =
            upload_lifecycle_event(session_id, route_path, message, &correlation).unwrap();
        assert_eq!(event.kind, TelemetryEventKind::UploadLifecycle);
        assert_eq!(
            event.ok, *expected_ok,
            "unexpected `ok` flag for phase `{}`",
            expected_phase
        );
        assert_eq!(
            event.attributes.get("phase"),
            Some(&serde_json::Value::String((*expected_phase).to_string()))
        );
    }
}
11087
/// A message unrelated to uploads (`Pong`) produces no lifecycle event, and handing it to
/// `emit_upload_lifecycle_telemetry` with a disabled pipeline does not panic.
#[test]
fn upload_lifecycle_telemetry_ignores_non_upload_messages() {
    let correlation = test_correlation();
    let pong = ServerMessage::Pong { nonce: None };

    let mapped = upload_lifecycle_event("sid-1", "/upload", &pong, &correlation);
    assert!(mapped.is_none());

    // Smoke check only: the call must complete for a non-upload message.
    let telemetry = Arc::new(TelemetryPipeline::disabled());
    emit_upload_lifecycle_telemetry(&telemetry, "sid-1", "/upload", &pong, &correlation);
}
11108
/// The bundled client script must contain the reconnect/heartbeat settings and the event
/// hooks listed below (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_has_production_reconnect_behavior() {
    let required_snippets = [
        "reconnectBaseMs",
        "reconnectMaxMs",
        "reconnectJitterMs",
        "heartbeatIntervalMs",
        "heartbeatTimeoutMs",
        "scheduleReconnect(",
        "window.addEventListener(\"online\"",
        "document.addEventListener(\"visibilitychange\"",
        "window.addEventListener(\"beforeunload\"",
        "socket.addEventListener(\"close\"",
        "nextReconnectDelayMs()",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11123
/// The bundled client script must contain the snippets that put session, CSRF and tracing
/// data on the websocket connection (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_sends_session_and_csrf_tokens_on_websocket_connect() {
    // Snippets that look like the websocket URL / query-string construction.
    let url_snippets = [
        "new URLSearchParams",
        "protocol: protocol",
        "session: cfg.sessionToken",
        "csrf: cfg.csrfToken",
        "trace_id: traceId",
        "parent_span_id: spanId",
        "correlation_id: correlationId || \"\"",
        "request_id: requestId || \"\"",
        "/__shelly/ws/${wsPath}?${wsQuery}",
    ];
    // Snippets that look like the initial `connect` message.
    // NOTE(review): the last four entries are identical to, or substrings of, entries in
    // `url_snippets`, so they cannot fail on their own — consider more specific needles.
    let connect_snippets = [
        "type: \"connect\"",
        "session_id: window.__SHELLY.sessionId || null",
        "last_revision: currentRevision()",
        "resume_token: window.__SHELLY.resumeToken || null",
        "trace_id: traceId",
        "span_id: spanId",
        "correlation_id: correlationId",
        "request_id: requestId",
    ];
    for &snippet in url_snippets.iter().chain(connect_snippets.iter()) {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11144
/// The bundled client script must handle the `hello` message and reference the resume
/// metadata fields listed below (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_handles_resume_handshake_metadata() {
    let required_snippets = [
        "case \"hello\":",
        "message.resume_status",
        "message.resume_reason",
        "message.server_revision",
        "window.__SHELLY.resumeToken = message.resume_token",
        "recordDebug(\"resume_handshake\"",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11154
/// The bundled client script must contain the revision-gap diagnostic, the resync request
/// and the `1012` close call (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_detects_diff_revision_gaps_and_requests_resync() {
    let required_snippets = [
        "diff revision gap detected",
        "requestResync(\"diff_revision_gap\")",
        "socket.close(1012, \"resync\")",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11161
/// The bundled client script must handle `diff` messages by patching dynamic slots
/// (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_patches_dynamic_slots() {
    let required_snippets = [
        "case \"diff\":",
        "patchDynamicSlots(message.target, message.slots || [])",
        "data-shelly-slot",
        "node.innerHTML = slot.html || \"\"",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11169
/// The bundled client script must handle the stream insert/delete/batch messages and the
/// DOM operations listed below (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_handles_stream_operations() {
    let required_snippets = [
        "case \"stream_insert\":",
        "case \"stream_delete\":",
        "case \"stream_batch\":",
        "streamBatch(message.target, message.operations || [])",
        "stream.prepend(item)",
        "stream.append(item)",
        "item.remove()",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11180
/// The bundled client script must contain the upload attribute, the three upload message
/// types and the progress handling (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_handles_uploads() {
    let required_snippets = [
        "shelly-upload",
        "type: \"upload_start\"",
        "type: \"upload_chunk\"",
        "type: \"upload_complete\"",
        "case \"upload_progress\":",
        "shelly:upload-progress",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11190
/// The bundled client script must contain the debug inspector identifiers, its debug
/// recording calls and its panel titles (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_has_debug_inspector() {
    let required_snippets = [
        "shelly_debug",
        "shelly-debug-inspector",
        "shelly-debug-summary",
        "shelly:debug",
        "recordDebug(\"client_message\"",
        "recordDebug(\"server_message\"",
        "Session tree",
        "Transport diagnostics",
        "Patch stream",
        "Profiler hotspots",
        "Event timeline",
        "trackDebugRuntimeStats(message)",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11206
/// The bundled client script must send click events and apply `patch` messages in place,
/// and must not contain `window.location.reload` (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_updates_clicks_by_patching_without_reload() {
    let required_snippets = [
        "document.addEventListener(\"click\"",
        "send(eventPayload(name, element, null))",
        "case \"patch\":",
        "root.innerHTML = html",
        "root.replaceWith(replacement)",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }

    // The one snippet that must be absent.
    let reloads_page = CLIENT_JS.contains("window.location.reload");
    assert!(
        !reloads_page,
        "client runtime must not contain `window.location.reload`"
    );
}
11216
/// The bundled client script must contain the `shelly-target` attribute and the
/// id-matching replacement logic (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_supports_explicit_component_targets() {
    let required_snippets = [
        "shelly-target",
        "elementFromHtml(html)",
        "replacement.getAttribute(\"id\") === target",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11223
/// The bundled client script must intercept form submits and send the serialized form as
/// the event payload (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_sends_submit_payloads_as_json_objects() {
    let required_snippets = [
        "document.addEventListener(\"submit\"",
        "event.preventDefault()",
        "const data = new FormData(form)",
        "send(eventPayload(name, form, serializeForm(form)))",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11231
/// The bundled client script must listen for `input` and `change` events and send the
/// element's value (checked as plain substrings of `CLIENT_JS`).
#[test]
fn client_runtime_sends_input_and_change_values() {
    let required_snippets = [
        "document.addEventListener(\"input\"",
        "nearestActionTarget(event.target, \"shelly-input\")",
        "document.addEventListener(\"change\"",
        "nearestActionTarget(event.target, \"shelly-change\")",
        "send(eventPayload(name, element, element.value))",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11240
/// The bundled client script must contain the focus-trap, roving-focus, typeahead, keydown
/// and command-hotkey identifiers listed below (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_headless_accessibility_primitives() {
    let required_snippets = [
        "shelly-focus-trap",
        "focusTrapState",
        "handleFocusTrapTab(event)",
        "shelly-dismiss-outside",
        "handleRovingKeydown(event)",
        "handleRovingProxyKeydown(event)",
        "handleTypeahead(event)",
        "resolveRovingRoot(event.target)",
        "shelly-keydown",
        "keyMetadata(event)",
        "shelly-command-hotkey",
        "handleCommandHotkey(event)",
        "shelly-focus",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11257
/// The bundled client script must contain the chart message cases, the calls they make,
/// the chart data attributes and the pointer/wheel handlers listed below (plain substring
/// checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_native_charting() {
    let required_snippets = [
        "shelly-chart-line",
        // Message cases.
        "case \"chart_series_append\":",
        "case \"chart_series_append_many\":",
        "case \"chart_series_replace\":",
        "case \"chart_reset\":",
        "case \"chart_annotation_upsert\":",
        "case \"chart_annotation_delete\":",
        // Calls made for those messages.
        "chartSeriesAppend(message.chart, message.series, message.point)",
        "chartSeriesAppendMany(message.chart, message.series, message.points)",
        "chartSeriesReplace(message.chart, message.series, message.points)",
        "chartReset(message.chart)",
        "chartAnnotationUpsert(message.chart, message.annotation)",
        "chartAnnotationDelete(message.chart, message.id)",
        // Attributes and interaction handlers.
        "data-shelly-chart-series",
        "data-shelly-chart-annotations",
        "shelly-chart-hover",
        "shelly-chart-zoom",
        "handleChartPointerMove(event)",
        "handleChartWheel(event)",
        "handleChartPointerDown(event)",
        "handleChartPointerUp(event)",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11287
/// The bundled client script must contain the toast/inbox message cases, the calls they
/// make and the notification host attributes (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_native_notifications() {
    let required_snippets = [
        "case \"toast_push\":",
        "case \"toast_dismiss\":",
        "case \"inbox_upsert\":",
        "case \"inbox_delete\":",
        "toastPush(message.toast)",
        "toastDismiss(message.id)",
        "inboxUpsert(message.item)",
        "inboxDelete(message.id)",
        "shelly-notifications-host",
        "data-shelly-inbox-toggle",
        "data-shelly-toast-region",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11302
/// The bundled client script must contain the sortable attributes, the drag/sort event
/// names, the sortable handlers and the quoted DOM drag event names listed below (plain
/// substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_native_dragdrop_primitives() {
    let required_snippets = [
        "data-shelly-sortable",
        "data-shelly-sort-item",
        "data-shelly-sort-keyboard",
        "shelly_drag_start",
        "shelly_drag_over",
        "shelly_sort",
        "handleSortableKeyboard(event)",
        "handleSortableDragStart(event)",
        "handleSortableDragOver(event)",
        "handleSortableDrop(event)",
        "\"dragstart\",",
        "\"dragover\",",
        "\"drop\",",
        "\"dragend\",",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11320
/// The bundled client script must contain the editor attributes, the editor helper calls
/// and the quoted `selectionchange` event name (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_native_markdown_editor_primitives() {
    let required_snippets = [
        "shelly-editor-input",
        "shelly-editor-select",
        "shelly-editor-command",
        "shelly-editor-command-event",
        "handleEditorCommandClick(event.target)",
        "editorSelectionState",
        "editorSnapshot(element)",
        "resolveEditorTarget(actionElement)",
        "\"selectionchange\"",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11333
/// The bundled client script must contain the field-path parsing and nested-value helper
/// calls used for form serialization (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_nested_form_serialization() {
    let required_snippets = [
        "parseFieldPath(key)",
        "setNestedValue(out, path, value)",
        "isNumericPathToken(token)",
        "serializeForm(form)",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11341
/// The bundled client script must contain the interop dispatch case and function, the
/// interop binding setup and the related attribute/event names (plain substring checks on
/// `CLIENT_JS`).
#[test]
fn client_runtime_supports_js_interop_dispatch_and_bindings() {
    let required_snippets = [
        "case \"interop_dispatch\":",
        "interopDispatch(message.dispatch)",
        "function interopDispatch(dispatch)",
        "shelly:interop-dispatch",
        "[shelly-js-interop]",
        "function ensureJsInteropBindings(root)",
        "shelly-js-push",
        "\"js_interop\"",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11353
/// The bundled client script must contain the `CustomEvent` construction snippets listed
/// below (plain substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_webrtc_signal_compatibility() {
    let required_snippets = [
        "dispatch.event",
        "new CustomEvent(eventName",
        "detail: detail",
        "dispatch.bubbles !== false",
        "composed: true",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11362
/// The bundled client script must contain the grid message cases, the grid rendering and
/// interaction handlers, the grid attributes and the CSV download hook listed below (plain
/// substring checks on `CLIENT_JS`).
#[test]
fn client_runtime_supports_enterprise_grid_primitives() {
    let required_snippets = [
        "case \"grid_replace\":",
        "case \"grid_rows_replace\":",
        "gridReplace(message.grid, message.state)",
        "gridRowsReplace(message.grid, message.window)",
        "function renderGrid(root)",
        "data-shelly-grid-viewport",
        "data-shelly-grid-resize",
        "handleGridResizePointerDown(event)",
        "handleGridResizePointerMove(event)",
        "handleGridResizePointerUp(event)",
        "handleGridKeydown(event)",
        "shelly-grid-edit",
        "shelly-grid-filter",
        "window_start",
        "window_end",
        "shelly:grid-csv-download",
        "downloadGridCsv(event.detail)",
    ];
    for &snippet in &required_snippets {
        assert!(
            CLIENT_JS.contains(snippet),
            "client runtime is missing `{}`",
            snippet
        );
    }
}
11382 }
11383}