1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Mutex;
11use std::time::{Duration, Instant};
12
13use prost::Message;
14
15use crate::broker::capabilities::{handoff_transport_available, CAP_HANDLE_PASSING};
16use crate::broker::lifecycle::names::{validate_service_name, validate_version, PipePathError};
17use crate::broker::protocol::{
18 hello_reply::Result as HelloReplyResult, ErrorCode, Frame, FrameKind, Hello, HelloReply,
19 Negotiated, PayloadEncoding, Refused, ServiceDefinition,
20};
21use crate::broker::server::handoff::{
22 AcknowledgedHandoff, ExpiredHandoff, HandoffAckError, HandoffAckRegistry, HandoffToken,
23 HandoffTokenStore, PendingHandoffBackend,
24};
25use crate::broker::server::version_allow_list::{check_version_allowed, VersionPolicyBlock};
26use crate::broker::server::TraceContext;
27
/// The only protocol version (frame envelope and Hello negotiation) this handler speaks.
const PROTOCOL_VERSION: u32 = 1;
/// Keepalive interval, in seconds (30 minutes), used when the client requests 0.
const DEFAULT_KEEPALIVE_SECS: u64 = 1_800;
/// `payload_protocol` value a Hello frame must carry: the control plane.
const CONTROL_PAYLOAD_PROTOCOL: u32 = 0;
/// Default number of Hello requests a single peer pid may make per window.
const DEFAULT_RATE_LIMIT_MAX_PER_WINDOW: u32 = 256;
/// Default length of the per-peer rate-limit window (one second).
const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_millis(1_000);
33
/// Identity of the process on the other end of a Hello connection.
///
/// `pid` keys the per-peer Hello rate limit. It is also the value a
/// non-zero `peer_pid` claimed inside the Hello must match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PeerIdentity {
    /// Process id of the peer. A pid of 0 is exempt from rate limiting.
    pub pid: u32,
    /// Textual user identity of the peer. Presumably a Unix uid or a Windows
    /// SID (inferred from the name; confirm against the transport layer).
    pub uid_or_sid: String,
}
42
43#[derive(Clone, Debug)]
45pub struct HelloRequest {
46 pub frame: Frame,
48 pub hello: Hello,
50 pub peer: PeerIdentity,
52}
53
54impl HelloRequest {
55 pub fn decode(frame: Frame, peer: PeerIdentity) -> Result<Self, Refused> {
57 if frame.envelope_version != PROTOCOL_VERSION {
58 return Err(refused(
59 ErrorCode::ErrorVersionUnsupported,
60 "frame envelope_version is not v1",
61 0,
62 ));
63 }
64 if FrameKind::try_from(frame.kind) != Ok(FrameKind::Request) {
65 return Err(refused(
66 ErrorCode::ErrorPeerRejected,
67 "Hello frame kind must be REQUEST",
68 0,
69 ));
70 }
71 if frame.payload_protocol != CONTROL_PAYLOAD_PROTOCOL {
72 return Err(refused(
73 ErrorCode::ErrorPeerRejected,
74 "Hello frame payload_protocol must be control-plane",
75 0,
76 ));
77 }
78 if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
79 return Err(refused(
80 ErrorCode::ErrorPeerRejected,
81 "Hello payload must not be compressed",
82 0,
83 ));
84 }
85 let hello = Hello::decode(frame.payload.as_slice())
86 .map_err(|_| refused(ErrorCode::ErrorPeerRejected, "malformed Hello payload", 0))?;
87 Ok(Self { frame, hello, peer })
88 }
89
90 pub fn trace_context(&self) -> TraceContext {
92 TraceContext::from_frame(&self.frame)
93 }
94}
95
96#[derive(Clone, Debug)]
98pub struct RegisteredBackend {
99 pub service_definition: ServiceDefinition,
101 pub daemon_version: String,
103 pub backend_pipe: String,
105 pub server_capabilities: u64,
107}
108
109#[derive(Debug)]
111pub struct HelloHandler {
112 backends: HashMap<String, RegisteredBackend>,
113 next_connection_id: AtomicU64,
114 rate_limiter: PeerRateLimiter,
115 handoff_tokens: Mutex<HandoffTokenStore>,
116 handoff_acks: Mutex<HandoffAckRegistry>,
117}
118
119impl HelloHandler {
120 pub fn new() -> Self {
122 Self {
123 backends: HashMap::new(),
124 next_connection_id: AtomicU64::new(1),
125 rate_limiter: PeerRateLimiter::default(),
126 handoff_tokens: Mutex::new(HandoffTokenStore::new()),
127 handoff_acks: Mutex::new(HandoffAckRegistry::new()),
128 }
129 }
130
131 pub fn with_handoff_ack_deadline(self, ack_deadline: Duration) -> Self {
133 *self.handoff_ack_registry() = HandoffAckRegistry::with_ack_deadline(ack_deadline);
134 self
135 }
136
137 pub fn handoff_token_store(&self) -> std::sync::MutexGuard<'_, HandoffTokenStore> {
143 self.handoff_tokens
144 .lock()
145 .unwrap_or_else(|poisoned| poisoned.into_inner())
146 }
147
148 pub fn handoff_ack_registry(&self) -> std::sync::MutexGuard<'_, HandoffAckRegistry> {
155 self.handoff_acks
156 .lock()
157 .unwrap_or_else(|poisoned| poisoned.into_inner())
158 }
159
160 pub fn acknowledge_handoff(
165 &self,
166 token: &HandoffToken,
167 now: Instant,
168 ) -> Result<AcknowledgedHandoff, HandoffAckError> {
169 let mut acks = self.handoff_ack_registry();
170 let mut tokens = self.handoff_token_store();
171 acks.acknowledge(&mut tokens, token, now)
172 }
173
174 pub fn expire_overdue_handoffs(&self, now: Instant) -> Vec<ExpiredHandoff> {
179 let mut acks = self.handoff_ack_registry();
180 let mut tokens = self.handoff_token_store();
181 acks.expire_overdue(&mut tokens, now)
182 }
183
184 pub fn with_rate_limit(mut self, max_per_window: u32, window: Duration) -> Self {
186 self.rate_limiter = PeerRateLimiter::new(max_per_window, window);
187 self
188 }
189
190 pub fn with_backend(mut self, backend: RegisteredBackend) -> Result<Self, HelloHandlerError> {
192 validate_service_name_for_result(&backend.service_definition.service_name)?;
193 if !backend.service_definition.min_version.is_empty() {
194 validate_version_for_result(&backend.service_definition.min_version)?;
195 }
196 for version in &backend.service_definition.version_allow_list {
197 validate_version_for_result(version)?;
198 }
199 self.backends
200 .insert(backend.service_definition.service_name.clone(), backend);
201 Ok(self)
202 }
203
204 pub fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
206 match HelloRequest::decode(frame, peer) {
207 Ok(request) => self.handle_request(&request),
208 Err(refused) => refused_reply(refused),
209 }
210 }
211
212 pub fn handle_request(&self, request: &HelloRequest) -> HelloReply {
214 let hello = &request.hello;
215 if let Some(refused) = validate_hello_shape(hello, &request.peer) {
216 return refused_reply(refused);
217 }
218 if let Some(retry_after) = self.rate_limiter.check(request.peer.pid) {
219 return refused_reply(refused(
220 ErrorCode::ErrorRateLimited,
221 "Hello rate limit exceeded",
222 duration_to_retry_ms(retry_after),
223 ));
224 }
225
226 let Some(backend) = self.backends.get(&hello.service_name) else {
227 return refused_reply(refused(
228 ErrorCode::ErrorServiceUnknown,
229 "service is not registered",
230 0,
231 ));
232 };
233
234 if let Some(refused) = validate_version_policy(hello, &backend.service_definition) {
235 return refused_reply(refused);
236 }
237
238 let connection_id = self.next_connection_id.fetch_add(1, Ordering::Relaxed);
239 let handle_passed_token =
240 self.issue_handoff_token(hello.client_capabilities, &hello.service_name);
241 let mut server_capabilities = backend.server_capabilities;
242 if !handle_passed_token.is_empty() {
243 server_capabilities |= CAP_HANDLE_PASSING;
244 }
245 refused_or_negotiated(HelloReplyResult::Negotiated(Negotiated {
246 negotiated_protocol: PROTOCOL_VERSION,
247 daemon_version: backend.daemon_version.clone(),
248 backend_pipe: backend.backend_pipe.clone(),
249 warnings: Vec::new(),
250 server_capabilities,
251 keepalive_interval_secs: if hello.client_keepalive_secs == 0 {
252 DEFAULT_KEEPALIVE_SECS
253 } else {
254 hello.client_keepalive_secs
255 },
256 handle_passed_token,
257 connection_id,
258 }))
259 }
260
261 fn issue_handoff_token(&self, client_capabilities: u64, service_name: &str) -> Vec<u8> {
274 if client_capabilities & CAP_HANDLE_PASSING == 0 || !handoff_transport_available() {
275 return Vec::new();
276 }
277 let now = Instant::now();
278 let mut acks = self.handoff_ack_registry();
280 let mut tokens = self.handoff_token_store();
281 match tokens.issue(now) {
282 Ok(token) => {
283 acks.register(token, PendingHandoffBackend::for_service(service_name), now);
284 token.into_bytes().to_vec()
285 }
286 Err(_) => Vec::new(),
287 }
288 }
289}
290
291impl Default for HelloHandler {
292 fn default() -> Self {
293 Self::new()
294 }
295}
296
297#[derive(Debug, thiserror::Error)]
299pub enum HelloHandlerError {
300 #[error(transparent)]
302 PipePath(#[from] PipePathError),
303}
304
305#[derive(Debug)]
307struct PeerRateLimiter {
308 max_per_window: u32,
309 window: Duration,
310 entries: Mutex<HashMap<u32, PeerRateWindow>>,
311}
312
313impl PeerRateLimiter {
314 fn new(max_per_window: u32, window: Duration) -> Self {
315 Self {
316 max_per_window: max_per_window.max(1),
317 window: if window.is_zero() {
318 Duration::from_millis(1)
319 } else {
320 window
321 },
322 entries: Mutex::new(HashMap::new()),
323 }
324 }
325
326 fn check(&self, pid: u32) -> Option<Duration> {
327 if pid == 0 {
328 return None;
329 }
330
331 let now = Instant::now();
332 let mut entries = self
333 .entries
334 .lock()
335 .unwrap_or_else(|poisoned| poisoned.into_inner());
336 let entry = entries.entry(pid).or_insert(PeerRateWindow {
337 started_at: now,
338 count: 0,
339 });
340 let elapsed = now.duration_since(entry.started_at);
341 if elapsed >= self.window {
342 entry.started_at = now;
343 entry.count = 0;
344 }
345
346 if entry.count < self.max_per_window {
347 entry.count += 1;
348 None
349 } else {
350 Some(self.window.saturating_sub(elapsed))
351 }
352 }
353}
354
355impl Default for PeerRateLimiter {
356 fn default() -> Self {
357 Self::new(DEFAULT_RATE_LIMIT_MAX_PER_WINDOW, DEFAULT_RATE_LIMIT_WINDOW)
358 }
359}
360
361#[derive(Debug)]
362struct PeerRateWindow {
363 started_at: Instant,
364 count: u32,
365}
366
367fn validate_hello_shape(hello: &Hello, peer: &PeerIdentity) -> Option<Refused> {
368 if hello.client_min_protocol > PROTOCOL_VERSION || hello.client_max_protocol < PROTOCOL_VERSION
369 {
370 return Some(refused(
371 ErrorCode::ErrorVersionUnsupported,
372 "client protocol range does not include v1",
373 0,
374 ));
375 }
376 if validate_service_name(&hello.service_name).is_err() {
377 return Some(refused(
378 ErrorCode::ErrorPeerRejected,
379 "invalid service_name",
380 0,
381 ));
382 }
383 if hello.wanted_version.len() > 64 || validate_version(&hello.wanted_version).is_err() {
384 return Some(refused(
385 ErrorCode::ErrorPeerRejected,
386 "invalid wanted_version",
387 0,
388 ));
389 }
390 if hello.client_version.len() > 128 {
391 return Some(refused(
392 ErrorCode::ErrorPeerRejected,
393 "client_version exceeds 128 bytes",
394 0,
395 ));
396 }
397 if hello.client_lib_name.len() > 64 || hello.client_lib_version.len() > 64 {
398 return Some(refused(
399 ErrorCode::ErrorPeerRejected,
400 "client_lib fields exceed 64 bytes",
401 0,
402 ));
403 }
404 if hello.peer_pid != 0 && hello.peer_pid != peer.pid {
405 return Some(refused(
406 ErrorCode::ErrorPeerRejected,
407 "peer_pid does not match verified peer",
408 0,
409 ));
410 }
411 None
412}
413
414fn validate_version_policy(hello: &Hello, service: &ServiceDefinition) -> Option<Refused> {
415 match check_version_allowed(&hello.wanted_version, service) {
416 Ok(()) => None,
417 Err(VersionPolicyBlock::BelowMinVersion) => Some(refused(
418 ErrorCode::ErrorVersionBlocked,
419 "wanted_version is below min_version",
420 30_000,
421 )),
422 Err(VersionPolicyBlock::OutsideAllowList) => Some(refused(
423 ErrorCode::ErrorVersionBlocked,
424 "wanted_version is not in version_allow_list",
425 30_000,
426 )),
427 }
428}
429
430fn validate_service_name_for_result(name: &str) -> Result<(), HelloHandlerError> {
431 validate_service_name(name).map_err(HelloHandlerError::PipePath)
432}
433
434fn validate_version_for_result(version: &str) -> Result<(), HelloHandlerError> {
435 validate_version(version).map_err(HelloHandlerError::PipePath)
436}
437
/// Converts a retry-after `duration` into whole milliseconds for a `Refused`
/// retry hint.
///
/// The value is rounded up. A client that sleeps for the advertised time is
/// then at or past the moment the server named. Truncating (as `as_millis`
/// does) could wake it a fraction of a millisecond early, and it would be
/// refused again.
///
/// Never returns 0, and saturates at `u64::MAX`.
fn duration_to_retry_ms(duration: Duration) -> u64 {
    const NANOS_PER_MILLI: u128 = 1_000_000;
    // `as_nanos` fits comfortably in u128 even for `Duration::MAX`, so adding
    // (NANOS_PER_MILLI - 1) for ceiling division cannot overflow.
    let millis = (duration.as_nanos() + (NANOS_PER_MILLI - 1)) / NANOS_PER_MILLI;
    u64::try_from(millis.max(1)).unwrap_or(u64::MAX)
}
442
443fn refused(code: ErrorCode, reason: impl Into<String>, retry_after_ms: u64) -> Refused {
444 Refused {
445 reason: reason.into(),
446 daemon_min_protocol: PROTOCOL_VERSION,
447 daemon_max_protocol: PROTOCOL_VERSION,
448 code: code as i32,
449 details: HashMap::new(),
450 retry_after_ms,
451 }
452}
453
454fn refused_reply(refused: Refused) -> HelloReply {
455 refused_or_negotiated(HelloReplyResult::Refused(refused))
456}
457
458fn refused_or_negotiated(result: HelloReplyResult) -> HelloReply {
459 HelloReply {
460 result: Some(result),
461 }
462}