1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Mutex;
11use std::time::{Duration, Instant};
12
13use prost::Message;
14
15use crate::broker::capabilities::{handoff_transport_available, CAP_HANDLE_PASSING};
16use crate::broker::lifecycle::names::{validate_service_name, validate_version, PipePathError};
17use crate::broker::protocol::{
18 hello_reply::Result as HelloReplyResult, validate_frame_envelope, ErrorCode, Frame, FrameKind,
19 FrameValidationError, Hello, HelloReply, Negotiated, Refused, ServiceDefinition,
20 CONTROL_PAYLOAD_PROTOCOL, PROTOCOL_VERSION,
21};
22use crate::broker::server::handoff::{
23 AcknowledgedHandoff, ExpiredHandoff, HandoffAckError, HandoffAckRegistry, HandoffToken,
24 HandoffTokenStore, PendingHandoffBackend,
25};
26use crate::broker::server::version_allow_list::{check_version_allowed, VersionPolicyBlock};
27use crate::broker::server::TraceContext;
28
/// Keepalive interval (30 minutes, expressed in seconds) negotiated when the
/// client's Hello leaves `client_keepalive_secs` at 0.
const DEFAULT_KEEPALIVE_SECS: u64 = 30 * 60;
/// Default number of Hellos a single peer pid may send per rate-limit window.
const DEFAULT_RATE_LIMIT_MAX_PER_WINDOW: u32 = 256;
/// Default length of the per-peer Hello rate-limit window.
const DEFAULT_RATE_LIMIT_WINDOW: Duration = Duration::from_secs(1);
32
#[derive(Clone, Debug, PartialEq, Eq)]
/// Identity of the process on the other end of a broker connection.
///
/// The Hello handler treats this as the *verified* peer: a Hello's
/// self-reported `peer_pid` is checked against `pid`. A `pid` of 0 disables
/// both that cross-check and per-peer Hello rate limiting.
pub struct PeerIdentity {
    /// Process id of the peer. 0 is treated specially (see above); presumably
    /// it means the transport could not determine the pid — confirm.
    pub pid: u32,
    /// Peer account identity as a string — going by the name, a Unix uid or
    /// a Windows SID; verify against the transport that populates it.
    pub uid_or_sid: String,
}
41
#[derive(Clone, Debug)]
/// A Hello control frame together with its decoded payload and the identity
/// of the peer that sent it.
///
/// Normally produced by [`HelloRequest::decode`], which validates the frame
/// envelope before decoding.
pub struct HelloRequest {
    /// The original frame; kept so trace context can be read from it.
    pub frame: Frame,
    /// The decoded Hello payload.
    pub hello: Hello,
    /// Identity of the sending peer, as supplied by the caller.
    pub peer: PeerIdentity,
}
52
53impl HelloRequest {
54 pub fn decode(frame: Frame, peer: PeerIdentity) -> Result<Self, Refused> {
56 validate_frame_envelope(&frame, FrameKind::Request, CONTROL_PAYLOAD_PROTOCOL).map_err(
57 |error| match error {
58 FrameValidationError::EnvelopeVersion { .. } => refused(
59 ErrorCode::ErrorVersionUnsupported,
60 "frame envelope_version is not v1",
61 0,
62 ),
63 FrameValidationError::Kind { .. } => refused(
64 ErrorCode::ErrorPeerRejected,
65 "Hello frame kind must be REQUEST",
66 0,
67 ),
68 FrameValidationError::PayloadProtocol { .. } => refused(
69 ErrorCode::ErrorPeerRejected,
70 "Hello frame payload_protocol must be control-plane",
71 0,
72 ),
73 FrameValidationError::PayloadEncoding { .. } => refused(
74 ErrorCode::ErrorPeerRejected,
75 "Hello payload must not be compressed",
76 0,
77 ),
78 },
79 )?;
80 let hello = Hello::decode(frame.payload.as_slice())
81 .map_err(|_| refused(ErrorCode::ErrorPeerRejected, "malformed Hello payload", 0))?;
82 Ok(Self { frame, hello, peer })
83 }
84
85 pub fn trace_context(&self) -> TraceContext {
87 TraceContext::from_frame(&self.frame)
88 }
89}
90
#[derive(Clone, Debug)]
/// A backend service the Hello handler can route clients to.
pub struct RegisteredBackend {
    /// Service definition. Its `service_name` is the registration key, and the
    /// whole definition is passed to `check_version_allowed` for the
    /// version-policy check (`min_version` / `version_allow_list`).
    pub service_definition: ServiceDefinition,
    /// Daemon version reported to clients in the `Negotiated` reply.
    pub daemon_version: String,
    /// Backend pipe reported to clients in the `Negotiated` reply.
    pub backend_pipe: String,
    /// Capability bits advertised to clients; `CAP_HANDLE_PASSING` is OR'd in
    /// per connection when a handoff token is issued.
    pub server_capabilities: u64,
}
103
#[derive(Debug)]
/// Answers broker Hello requests: validates them, applies per-peer rate
/// limiting and per-service version policy, and negotiates a connection to a
/// registered backend.
pub struct HelloHandler {
    /// Registered backends keyed by service name.
    backends: HashMap<String, RegisteredBackend>,
    /// Next connection id to hand out (initialised to 1 by `new`).
    next_connection_id: AtomicU64,
    /// Per-peer-pid Hello rate limiter.
    rate_limiter: PeerRateLimiter,
    /// Issued handoff tokens. When both handoff locks are needed, the handler
    /// locks `handoff_acks` first and this one second.
    handoff_tokens: Mutex<HandoffTokenStore>,
    /// Handoffs awaiting acknowledgement (locked before `handoff_tokens`).
    handoff_acks: Mutex<HandoffAckRegistry>,
}
113
114impl HelloHandler {
115 pub fn new() -> Self {
117 Self {
118 backends: HashMap::new(),
119 next_connection_id: AtomicU64::new(1),
120 rate_limiter: PeerRateLimiter::default(),
121 handoff_tokens: Mutex::new(HandoffTokenStore::new()),
122 handoff_acks: Mutex::new(HandoffAckRegistry::new()),
123 }
124 }
125
126 pub fn with_handoff_ack_deadline(self, ack_deadline: Duration) -> Self {
128 *self.handoff_ack_registry() = HandoffAckRegistry::with_ack_deadline(ack_deadline);
129 self
130 }
131
132 pub fn handoff_token_store(&self) -> std::sync::MutexGuard<'_, HandoffTokenStore> {
138 self.handoff_tokens
139 .lock()
140 .unwrap_or_else(|poisoned| poisoned.into_inner())
141 }
142
143 pub fn handoff_ack_registry(&self) -> std::sync::MutexGuard<'_, HandoffAckRegistry> {
150 self.handoff_acks
151 .lock()
152 .unwrap_or_else(|poisoned| poisoned.into_inner())
153 }
154
155 pub fn acknowledge_handoff(
160 &self,
161 token: &HandoffToken,
162 now: Instant,
163 ) -> Result<AcknowledgedHandoff, HandoffAckError> {
164 let mut acks = self.handoff_ack_registry();
165 let mut tokens = self.handoff_token_store();
166 acks.acknowledge(&mut tokens, token, now)
167 }
168
169 pub fn expire_overdue_handoffs(&self, now: Instant) -> Vec<ExpiredHandoff> {
174 let mut acks = self.handoff_ack_registry();
175 let mut tokens = self.handoff_token_store();
176 acks.expire_overdue(&mut tokens, now)
177 }
178
179 pub fn with_rate_limit(mut self, max_per_window: u32, window: Duration) -> Self {
181 self.rate_limiter = PeerRateLimiter::new(max_per_window, window);
182 self
183 }
184
185 pub fn with_backend(mut self, backend: RegisteredBackend) -> Result<Self, HelloHandlerError> {
187 validate_service_name_for_result(&backend.service_definition.service_name)?;
188 if !backend.service_definition.min_version.is_empty() {
189 validate_version_for_result(&backend.service_definition.min_version)?;
190 }
191 for version in &backend.service_definition.version_allow_list {
192 validate_version_for_result(version)?;
193 }
194 self.backends
195 .insert(backend.service_definition.service_name.clone(), backend);
196 Ok(self)
197 }
198
199 pub fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
201 match HelloRequest::decode(frame, peer) {
202 Ok(request) => self.handle_request(&request),
203 Err(refused) => refused_reply(refused),
204 }
205 }
206
207 pub fn handle_request(&self, request: &HelloRequest) -> HelloReply {
209 let hello = &request.hello;
210 if let Some(refused) = validate_hello_shape(hello, &request.peer) {
211 return refused_reply(refused);
212 }
213 if let Some(retry_after) = self.rate_limiter.check(request.peer.pid) {
214 return refused_reply(refused(
215 ErrorCode::ErrorRateLimited,
216 "Hello rate limit exceeded",
217 duration_to_retry_ms(retry_after),
218 ));
219 }
220
221 let Some(backend) = self.backends.get(&hello.service_name) else {
222 return refused_reply(refused(
223 ErrorCode::ErrorServiceUnknown,
224 "service is not registered",
225 0,
226 ));
227 };
228
229 if let Some(refused) = validate_version_policy(hello, &backend.service_definition) {
230 return refused_reply(refused);
231 }
232
233 let connection_id = self.next_connection_id.fetch_add(1, Ordering::Relaxed);
234 let handle_passed_token =
235 self.issue_handoff_token(hello.client_capabilities, &hello.service_name);
236 let mut server_capabilities = backend.server_capabilities;
237 if !handle_passed_token.is_empty() {
238 server_capabilities |= CAP_HANDLE_PASSING;
239 }
240 refused_or_negotiated(HelloReplyResult::Negotiated(Negotiated {
241 negotiated_protocol: PROTOCOL_VERSION,
242 daemon_version: backend.daemon_version.clone(),
243 backend_pipe: backend.backend_pipe.clone(),
244 warnings: Vec::new(),
245 server_capabilities,
246 keepalive_interval_secs: if hello.client_keepalive_secs == 0 {
247 DEFAULT_KEEPALIVE_SECS
248 } else {
249 hello.client_keepalive_secs
250 },
251 handle_passed_token,
252 connection_id,
253 }))
254 }
255
256 fn issue_handoff_token(&self, client_capabilities: u64, service_name: &str) -> Vec<u8> {
269 if client_capabilities & CAP_HANDLE_PASSING == 0 || !handoff_transport_available() {
270 return Vec::new();
271 }
272 let now = Instant::now();
273 let mut acks = self.handoff_ack_registry();
275 let mut tokens = self.handoff_token_store();
276 match tokens.issue(now) {
277 Ok(token) => {
278 acks.register(token, PendingHandoffBackend::for_service(service_name), now);
279 token.into_bytes().to_vec()
280 }
281 Err(_) => Vec::new(),
282 }
283 }
284}
285
286impl Default for HelloHandler {
287 fn default() -> Self {
288 Self::new()
289 }
290}
291
#[derive(Debug, thiserror::Error)]
/// Errors returned while configuring a [`HelloHandler`].
pub enum HelloHandlerError {
    #[error(transparent)]
    /// A backend's service name or version string failed validation in
    /// `with_backend`. Display and source are forwarded unchanged from the
    /// inner [`PipePathError`].
    PipePath(#[from] PipePathError),
}
299
#[derive(Debug)]
/// Fixed-window Hello rate limiter keyed by peer pid.
struct PeerRateLimiter {
    /// Hellos admitted per peer per window (clamped to at least 1 by `new`).
    max_per_window: u32,
    /// Window length (a zero window is replaced by 1 ms in `new`).
    window: Duration,
    /// Current window state for each peer pid seen.
    entries: Mutex<HashMap<u32, PeerRateWindow>>,
}
307
308impl PeerRateLimiter {
309 fn new(max_per_window: u32, window: Duration) -> Self {
310 Self {
311 max_per_window: max_per_window.max(1),
312 window: if window.is_zero() {
313 Duration::from_millis(1)
314 } else {
315 window
316 },
317 entries: Mutex::new(HashMap::new()),
318 }
319 }
320
321 fn check(&self, pid: u32) -> Option<Duration> {
322 if pid == 0 {
323 return None;
324 }
325
326 let now = Instant::now();
327 let mut entries = self
328 .entries
329 .lock()
330 .unwrap_or_else(|poisoned| poisoned.into_inner());
331 let entry = entries.entry(pid).or_insert(PeerRateWindow {
332 started_at: now,
333 count: 0,
334 });
335 let elapsed = now.duration_since(entry.started_at);
336 if elapsed >= self.window {
337 entry.started_at = now;
338 entry.count = 0;
339 }
340
341 if entry.count < self.max_per_window {
342 entry.count += 1;
343 None
344 } else {
345 Some(self.window.saturating_sub(elapsed))
346 }
347 }
348}
349
350impl Default for PeerRateLimiter {
351 fn default() -> Self {
352 Self::new(DEFAULT_RATE_LIMIT_MAX_PER_WINDOW, DEFAULT_RATE_LIMIT_WINDOW)
353 }
354}
355
/// Rate-limit state for one peer's current window.
#[derive(Debug)]
struct PeerRateWindow {
    /// When the current window began.
    started_at: Instant,
    /// Hellos admitted so far in the current window.
    count: u32,
}
361
362fn validate_hello_shape(hello: &Hello, peer: &PeerIdentity) -> Option<Refused> {
363 if hello.client_min_protocol > PROTOCOL_VERSION || hello.client_max_protocol < PROTOCOL_VERSION
364 {
365 return Some(refused(
366 ErrorCode::ErrorVersionUnsupported,
367 "client protocol range does not include v1",
368 0,
369 ));
370 }
371 if validate_service_name(&hello.service_name).is_err() {
372 return Some(refused(
373 ErrorCode::ErrorPeerRejected,
374 "invalid service_name",
375 0,
376 ));
377 }
378 if hello.wanted_version.len() > 64 || validate_version(&hello.wanted_version).is_err() {
379 return Some(refused(
380 ErrorCode::ErrorPeerRejected,
381 "invalid wanted_version",
382 0,
383 ));
384 }
385 if hello.client_version.len() > 128 {
386 return Some(refused(
387 ErrorCode::ErrorPeerRejected,
388 "client_version exceeds 128 bytes",
389 0,
390 ));
391 }
392 if hello.client_lib_name.len() > 64 || hello.client_lib_version.len() > 64 {
393 return Some(refused(
394 ErrorCode::ErrorPeerRejected,
395 "client_lib fields exceed 64 bytes",
396 0,
397 ));
398 }
399 if hello.peer_pid != 0 && peer.pid != 0 && hello.peer_pid != peer.pid {
402 return Some(refused(
403 ErrorCode::ErrorPeerRejected,
404 "peer_pid does not match verified peer",
405 0,
406 ));
407 }
408 None
409}
410
411fn validate_version_policy(hello: &Hello, service: &ServiceDefinition) -> Option<Refused> {
412 match check_version_allowed(&hello.wanted_version, service) {
413 Ok(()) => None,
414 Err(VersionPolicyBlock::BelowMinVersion) => Some(refused(
415 ErrorCode::ErrorVersionBlocked,
416 "wanted_version is below min_version",
417 30_000,
418 )),
419 Err(VersionPolicyBlock::OutsideAllowList) => Some(refused(
420 ErrorCode::ErrorVersionBlocked,
421 "wanted_version is not in version_allow_list",
422 30_000,
423 )),
424 }
425}
426
427fn validate_service_name_for_result(name: &str) -> Result<(), HelloHandlerError> {
428 validate_service_name(name).map_err(HelloHandlerError::PipePath)
429}
430
431fn validate_version_for_result(version: &str) -> Result<(), HelloHandlerError> {
432 validate_version(version).map_err(HelloHandlerError::PipePath)
433}
434
/// Converts a rate-limit back-off into the `retry_after_ms` wire value.
///
/// The duration is rounded *up* to whole milliseconds. A client that honours
/// the hint therefore never retries before the window reopens; truncating
/// would report 999 ms for 999.5 ms remaining and get the retry refused again.
/// The result is at least 1 ms and saturates at `u64::MAX`.
fn duration_to_retry_ms(duration: Duration) -> u64 {
    const NANOS_PER_MILLI: u128 = 1_000_000;
    // `as_nanos` is at most ~1.8e28, far below u128::MAX, so this cannot overflow.
    let millis = (duration.as_nanos() + (NANOS_PER_MILLI - 1)) / NANOS_PER_MILLI;
    u64::try_from(millis.max(1)).unwrap_or(u64::MAX)
}
439
440fn refused(code: ErrorCode, reason: impl Into<String>, retry_after_ms: u64) -> Refused {
441 Refused {
442 reason: reason.into(),
443 daemon_min_protocol: PROTOCOL_VERSION,
444 daemon_max_protocol: PROTOCOL_VERSION,
445 code: code as i32,
446 details: HashMap::new(),
447 retry_after_ms,
448 }
449}
450
451fn refused_reply(refused: Refused) -> HelloReply {
452 refused_or_negotiated(HelloReplyResult::Refused(refused))
453}
454
455fn refused_or_negotiated(result: HelloReplyResult) -> HelloReply {
456 HelloReply {
457 result: Some(result),
458 }
459}