1use std::collections::HashMap;
9use std::io::{self, Read, Write};
10use std::sync::Arc;
11use std::thread;
12
13use interprocess::local_socket::traits::Listener;
14use prost::Message;
15
16use crate::broker::protocol::{
17 hello_reply::Result as HelloReplyResult, read_frame_with_cap, write_frame, ErrorCode, Frame,
18 FrameKind, FramingError, HelloReply, PayloadEncoding, Refused, MAX_HELLO_BYTES,
19};
20use crate::broker::server::{HelloHandler, HelloRouter, PeerIdentity};
21
/// Protocol version this broker speaks. It is stamped as the `envelope_version` of every
/// response frame and advertised as both the minimum and maximum protocol in `Refused` replies.
const PROTOCOL_VERSION: u32 = 1;
/// `payload_protocol` value stamped on response frames that carry control (Hello) replies.
const CONTROL_PAYLOAD_PROTOCOL: u32 = 0;
24
25#[derive(Clone, Debug, PartialEq, Eq)]
27pub enum PeerCredentialPolicy {
28 AllowAny,
30 OwnerOnly {
32 uid_or_sid: String,
34 },
35}
36
37impl PeerCredentialPolicy {
38 pub fn allow_any() -> Self {
40 Self::AllowAny
41 }
42
43 pub fn owner_only(uid_or_sid: impl Into<String>) -> Self {
45 Self::OwnerOnly {
46 uid_or_sid: uid_or_sid.into(),
47 }
48 }
49
50 pub fn current_user() -> Option<Self> {
52 #[cfg(unix)]
53 {
54 Some(Self::owner_only(unsafe { libc::geteuid() }.to_string()))
55 }
56
57 #[cfg(windows)]
58 {
59 current_process_user_sid().ok().map(Self::owner_only)
60 }
61 }
62
63 pub fn allows(&self, peer: &PeerIdentity) -> bool {
65 match self {
66 Self::AllowAny => true,
67 Self::OwnerOnly { uid_or_sid } => {
68 !uid_or_sid.is_empty() && peer.uid_or_sid == *uid_or_sid
69 }
70 }
71 }
72}
73
74pub trait HelloResponder {
80 fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply;
82}
83
84impl HelloResponder for HelloHandler {
85 fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
86 Self::handle_frame(self, frame, peer)
87 }
88}
89
90impl HelloResponder for HelloRouter<'_> {
91 fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
92 Self::handle_frame(self, frame, peer)
93 }
94}
95
96pub fn handle_hello_connection<S: Read + Write>(
102 stream: &mut S,
103 handler: &HelloHandler,
104 peer: PeerIdentity,
105) -> Result<HelloReply, BrokerConnectionError> {
106 handle_hello_connection_with(stream, handler, peer)
107}
108
109pub fn handle_hello_connection_with<S, R>(
114 stream: &mut S,
115 responder: &R,
116 peer: PeerIdentity,
117) -> Result<HelloReply, BrokerConnectionError>
118where
119 S: Read + Write,
120 R: HelloResponder + ?Sized,
121{
122 handle_hello_connection_with_peer_policy(
123 stream,
124 responder,
125 peer,
126 &PeerCredentialPolicy::allow_any(),
127 )
128 .map(|reply| reply.expect("allow-any policy must not drop peers"))
129}
130
131pub fn handle_hello_connection_with_peer_policy<S, R>(
137 stream: &mut S,
138 responder: &R,
139 peer: PeerIdentity,
140 peer_policy: &PeerCredentialPolicy,
141) -> Result<Option<HelloReply>, BrokerConnectionError>
142where
143 S: Read + Write,
144 R: HelloResponder + ?Sized,
145{
146 if !peer_policy.allows(&peer) {
147 return Ok(None);
148 }
149
150 let request_bytes = match read_frame_with_cap(stream, MAX_HELLO_BYTES) {
151 Ok(bytes) => bytes,
152 Err(err) => {
153 let reply = reply_for_framing_error(&err);
154 write_response_frame(stream, None, &reply)?;
155 return Ok(Some(reply));
156 }
157 };
158
159 let request_frame = match Frame::decode(request_bytes.as_slice()) {
160 Ok(frame) => frame,
161 Err(_) => {
162 let reply = refused_reply(ErrorCode::ErrorPeerRejected, "malformed broker Frame", 0);
163 write_response_frame(stream, None, &reply)?;
164 return Ok(Some(reply));
165 }
166 };
167
168 let reply = responder.handle_frame(request_frame.clone(), peer);
169 write_response_frame(stream, Some(&request_frame), &reply)?;
170 Ok(Some(reply))
171}
172
173pub fn serve_one_local_socket(
180 socket_path: &str,
181 handler: &HelloHandler,
182) -> Result<HelloReply, BrokerConnectionError> {
183 serve_one_local_socket_with(socket_path, handler)
184}
185
186pub fn serve_one_local_socket_with<R>(
189 socket_path: &str,
190 responder: &R,
191) -> Result<HelloReply, BrokerConnectionError>
192where
193 R: HelloResponder + ?Sized,
194{
195 serve_one_local_socket_with_peer_policy(
196 socket_path,
197 responder,
198 &PeerCredentialPolicy::allow_any(),
199 )
200 .map(|reply| reply.expect("allow-any policy must not drop peers"))
201}
202
203pub fn serve_one_local_socket_with_peer_policy<R>(
205 socket_path: &str,
206 responder: &R,
207 peer_policy: &PeerCredentialPolicy,
208) -> Result<Option<HelloReply>, BrokerConnectionError>
209where
210 R: HelloResponder + ?Sized,
211{
212 let listener = bind_local_socket(socket_path)?;
213 let cleanup = LocalSocketCleanup(socket_path);
214 let result = (|| {
215 let mut stream = listener.accept()?;
216 let peer = peer_identity_from_stream(&stream)?;
217 handle_hello_connection_with_peer_policy(&mut stream, responder, peer, peer_policy)
218 })();
219 drop(listener);
220 drop(cleanup);
221 result
222}
223
224pub fn serve_local_socket_connections(
230 socket_path: &str,
231 handler: Arc<HelloHandler>,
232 connection_count: usize,
233) -> Result<(), BrokerConnectionError> {
234 serve_local_socket_connections_with_peer_policy(
235 socket_path,
236 handler,
237 connection_count,
238 &PeerCredentialPolicy::allow_any(),
239 )
240}
241
242pub fn serve_local_socket_connections_with_peer_policy(
244 socket_path: &str,
245 handler: Arc<HelloHandler>,
246 connection_count: usize,
247 peer_policy: &PeerCredentialPolicy,
248) -> Result<(), BrokerConnectionError> {
249 if connection_count == 0 {
250 return Ok(());
251 }
252
253 let listener = bind_local_socket(socket_path)?;
254 let cleanup = LocalSocketCleanup(socket_path);
255 let result = (|| {
256 let mut workers = Vec::with_capacity(connection_count);
257 let peer_policy = Arc::new(peer_policy.clone());
258
259 for _ in 0..connection_count {
260 let mut stream = listener.accept()?;
261 let peer = peer_identity_from_stream(&stream)?;
262 let handler = Arc::clone(&handler);
263 let peer_policy = Arc::clone(&peer_policy);
264 workers.push(thread::spawn(move || {
265 handle_hello_connection_with_peer_policy(
266 &mut stream,
267 handler.as_ref(),
268 peer,
269 peer_policy.as_ref(),
270 )
271 .map(|_| ())
272 }));
273 }
274
275 for worker in workers {
276 match worker.join() {
277 Ok(Ok(())) => {}
278 Ok(Err(err)) => return Err(err),
279 Err(_) => return Err(BrokerConnectionError::WorkerPanic),
280 }
281 }
282 Ok(())
283 })();
284 drop(listener);
285 drop(cleanup);
286 result
287}
288
289pub fn serve_local_socket_connections_with<R>(
295 socket_path: &str,
296 responder: &R,
297 connection_count: usize,
298) -> Result<(), BrokerConnectionError>
299where
300 R: HelloResponder + ?Sized,
301{
302 serve_local_socket_connections_with_policy(
303 socket_path,
304 responder,
305 connection_count,
306 &PeerCredentialPolicy::allow_any(),
307 )
308}
309
310pub fn serve_local_socket_connections_with_policy<R>(
312 socket_path: &str,
313 responder: &R,
314 connection_count: usize,
315 peer_policy: &PeerCredentialPolicy,
316) -> Result<(), BrokerConnectionError>
317where
318 R: HelloResponder + ?Sized,
319{
320 if connection_count == 0 {
321 return Ok(());
322 }
323
324 let listener = bind_local_socket(socket_path)?;
325 let cleanup = LocalSocketCleanup(socket_path);
326 let result = (|| {
327 for _ in 0..connection_count {
328 let mut stream = listener.accept()?;
329 let peer = peer_identity_from_stream(&stream)?;
330 let _ = handle_hello_connection_with_peer_policy(
331 &mut stream,
332 responder,
333 peer,
334 peer_policy,
335 )?;
336 }
337 Ok(())
338 })();
339 drop(listener);
340 drop(cleanup);
341 result
342}
343
/// Converts `socket_path` into an `interprocess` local-socket name.
///
/// On Unix the string is treated as a filesystem path (`GenericFilePath`). This matches the
/// Unix-only path preparation, chmod and cleanup helpers in this module. On Windows the same
/// string is used as a namespaced name (`GenericNamespaced`) rather than a filesystem path.
/// NOTE(review): presumably this maps to a named pipe; confirm against the interprocess docs.
///
/// # Errors
/// Returns the `io::Error` produced by the name conversion.
pub fn local_socket_name(socket_path: &str) -> io::Result<interprocess::local_socket::Name<'_>> {
    #[cfg(unix)]
    {
        use interprocess::local_socket::{GenericFilePath, ToFsName};
        socket_path.to_fs_name::<GenericFilePath>()
    }

    #[cfg(windows)]
    {
        use interprocess::local_socket::{GenericNamespaced, ToNsName};
        socket_path.to_ns_name::<GenericNamespaced>()
    }
}
359
/// Errors produced while binding, accepting or answering broker local-socket connections.
#[derive(Debug, thiserror::Error)]
pub enum BrokerConnectionError {
    /// A framing-layer error, converted via `From<FramingError>`.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// prost failed to encode the response `Frame`.
    #[error("failed to encode broker response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// A socket or filesystem operation failed (bind, accept, path preparation, chmod, ...).
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A per-connection worker thread panicked before producing a result.
    #[error("broker connection worker panicked")]
    WorkerPanic,
}
376
377pub(super) fn bind_local_socket(
378 socket_path: &str,
379) -> Result<interprocess::local_socket::Listener, BrokerConnectionError> {
380 use interprocess::local_socket::ListenerOptions;
381
382 prepare_local_socket_path(socket_path)?;
383 let name = local_socket_name(socket_path)?;
384 let listener = ListenerOptions::new().name(name).create_sync()?;
385 secure_local_socket_path(socket_path)?;
386 Ok(listener)
387}
388
389pub(super) struct LocalSocketCleanup<'a>(pub(super) &'a str);
390
391impl Drop for LocalSocketCleanup<'_> {
392 fn drop(&mut self) {
393 cleanup_local_socket_path(self.0);
394 }
395}
396
397pub(super) fn write_response_frame<W: Write>(
398 writer: &mut W,
399 request_frame: Option<&Frame>,
400 reply: &HelloReply,
401) -> Result<(), BrokerConnectionError> {
402 let response_frame = Frame {
403 envelope_version: PROTOCOL_VERSION,
404 kind: FrameKind::Response as i32,
405 payload_protocol: CONTROL_PAYLOAD_PROTOCOL,
406 payload: reply.encode_to_vec(),
407 request_id: request_frame.map_or(0, |frame| frame.request_id),
408 payload_encoding: PayloadEncoding::None as i32,
409 deadline_unix_ms: 0,
410 traceparent: request_frame
411 .map(|frame| frame.traceparent.clone())
412 .unwrap_or_default(),
413 tracestate: request_frame
414 .map(|frame| frame.tracestate.clone())
415 .unwrap_or_default(),
416 };
417 let mut response_bytes = Vec::new();
418 response_frame
419 .encode(&mut response_bytes)
420 .map_err(BrokerConnectionError::EncodeFrame)?;
421 write_frame(writer, &response_bytes)?;
422 Ok(())
423}
424
425pub(super) fn reply_for_framing_error(error: &FramingError) -> HelloReply {
426 match error {
427 FramingError::UnsupportedFramingVersion { .. } => refused_reply(
428 ErrorCode::ErrorVersionUnsupported,
429 "unsupported framing version",
430 0,
431 ),
432 FramingError::FrameTooLarge { .. } => refused_reply(
433 ErrorCode::ErrorPeerRejected,
434 "initial Hello frame exceeds 64 KiB",
435 0,
436 ),
437 FramingError::UnexpectedEof { .. } | FramingError::Io(_) => {
438 refused_reply(ErrorCode::ErrorPeerRejected, "incomplete Hello frame", 0)
439 }
440 }
441}
442
443pub(super) fn refused_reply(
444 code: ErrorCode,
445 reason: impl Into<String>,
446 retry_after_ms: u64,
447) -> HelloReply {
448 HelloReply {
449 result: Some(HelloReplyResult::Refused(Refused {
450 reason: reason.into(),
451 daemon_min_protocol: PROTOCOL_VERSION,
452 daemon_max_protocol: PROTOCOL_VERSION,
453 code: code as i32,
454 details: HashMap::new(),
455 retry_after_ms,
456 })),
457 }
458}
459
/// Collects the connected peer's pid and user identity from the OS-reported credentials.
///
/// Missing details degrade to placeholders rather than errors:
/// * A pid that is unavailable (or, on Unix, does not fit in `u32`) becomes `0`.
/// * An unknown user becomes an empty `uid_or_sid`. An empty `uid_or_sid` never satisfies
///   `PeerCredentialPolicy::OwnerOnly`, so owner-only policies reject such peers.
///
/// # Errors
/// Fails only if `peer_creds()` itself fails.
pub(super) fn peer_identity_from_stream(
    stream: &interprocess::local_socket::Stream,
) -> Result<PeerIdentity, BrokerConnectionError> {
    use interprocess::local_socket::traits::StreamCommon;

    let creds = stream.peer_creds()?;
    // A negative or otherwise unrepresentable pid falls back to 0.
    #[cfg(unix)]
    let pid = creds
        .pid()
        .and_then(|pid| u32::try_from(pid).ok())
        .unwrap_or(0);

    #[cfg(windows)]
    let pid = creds.pid().unwrap_or(0);

    // Effective UID as a decimal string. This is the same format
    // `PeerCredentialPolicy::current_user` produces on Unix.
    #[cfg(unix)]
    let uid_or_sid = creds.euid().map(|uid| uid.to_string()).unwrap_or_default();

    // On Windows the SID is resolved from the peer's pid by opening that process. Any failure
    // yields an empty identity.
    // NOTE(review): presumably racy if the peer exits and its pid is reused before the
    // lookup; confirm this is acceptable.
    #[cfg(windows)]
    let uid_or_sid = if pid == 0 {
        String::new()
    } else {
        process_user_sid(pid).unwrap_or_default()
    };

    Ok(PeerIdentity { pid, uid_or_sid })
}
487
/// Returns the stable SID string of the user that owns this process's token.
#[cfg(windows)]
fn current_process_user_sid() -> io::Result<String> {
    let own_pid = std::process::id();
    process_user_sid(own_pid)
}
492
/// Returns the stable string form of the user SID in the access token of process `pid`.
///
/// Opens the process with `PROCESS_QUERY_LIMITED_INFORMATION` and its token with `TOKEN_QUERY`,
/// reads the `TokenUser` information, and formats the SID with `sid_to_stable_string`.
///
/// # Errors
/// Returns the last OS error if any Win32 call fails, or the error from `sid_to_stable_string`.
#[cfg(windows)]
fn process_user_sid(pid: u32) -> io::Result<String> {
    use std::ptr;
    use winapi::um::processthreadsapi::{OpenProcess, OpenProcessToken};
    use winapi::um::securitybaseapi::GetTokenInformation;
    use winapi::um::winnt::{
        TokenUser, HANDLE, PROCESS_QUERY_LIMITED_INFORMATION, TOKEN_QUERY, TOKEN_USER,
    };

    unsafe {
        let process = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid);
        if process.is_null() {
            return Err(io::Error::last_os_error());
        }
        let _process_guard = WindowsHandle(process);

        let mut token: HANDLE = ptr::null_mut();
        if OpenProcessToken(process, TOKEN_QUERY, &mut token) == 0 {
            return Err(io::Error::last_os_error());
        }
        let _token_guard = WindowsHandle(token);

        // Size probe: this call is expected to fail and only report the required length.
        let mut required_size = 0_u32;
        let _ = GetTokenInformation(token, TokenUser, ptr::null_mut(), 0, &mut required_size);
        if required_size == 0 {
            return Err(io::Error::last_os_error());
        }

        // TOKEN_USER holds a pointer, so its storage must be pointer-aligned. A `Vec<u8>` only
        // guarantees 1-byte alignment, and dereferencing a misaligned `*const TOKEN_USER` is UB.
        // Backing the buffer with `u64` words guarantees 8-byte alignment.
        let word_count = (required_size as usize).div_ceil(std::mem::size_of::<u64>());
        let mut buffer = vec![0_u64; word_count];
        if GetTokenInformation(
            token,
            TokenUser,
            buffer.as_mut_ptr().cast(),
            required_size,
            &mut required_size,
        ) == 0
        {
            return Err(io::Error::last_os_error());
        }

        // SAFETY: the buffer is 8-byte aligned and was filled by GetTokenInformation with a
        // TOKEN_USER whose SID points into `buffer`, which stays alive for the formatting call.
        let token_user: *const TOKEN_USER = buffer.as_ptr().cast();
        let sid = (*token_user).User.Sid;
        sid_to_stable_string(sid)
    }
}
543
/// Owning wrapper that closes a Win32 `HANDLE` when dropped.
#[cfg(windows)]
struct WindowsHandle(winapi::um::winnt::HANDLE);

#[cfg(windows)]
impl Drop for WindowsHandle {
    fn drop(&mut self) {
        let Self(handle) = *self;
        // SAFETY: in this module a WindowsHandle is only built from a handle that was just
        // opened successfully, and each guard owns its handle exactly once.
        unsafe {
            winapi::um::handleapi::CloseHandle(handle);
        }
    }
}
555
/// Formats a Windows SID as `windows-sid:` followed by the lowercase hex of its binary form.
///
/// Using the raw SID bytes means equal SIDs always yield byte-for-byte equal strings.
///
/// # Safety
/// `sid` must be null or point to memory that `IsValidSid` may inspect. Once it is accepted as
/// valid, the `GetLengthSid(sid)` bytes starting at `sid` must remain readable for the call.
///
/// # Errors
/// Rejects null or invalid SIDs and SID lengths outside `1..=1024`.
#[cfg(windows)]
unsafe fn sid_to_stable_string(sid: winapi::um::winnt::PSID) -> io::Result<String> {
    use winapi::um::securitybaseapi::{GetLengthSid, IsValidSid};

    const PREFIX: &str = "windows-sid:";

    if sid.is_null() || IsValidSid(sid) == 0 {
        return Err(io::Error::other("invalid Windows SID"));
    }
    let len = GetLengthSid(sid) as usize;
    if !(1..=1024).contains(&len) {
        return Err(io::Error::other(format!(
            "implausible Windows SID length {len}"
        )));
    }
    let raw = std::slice::from_raw_parts(sid.cast::<u8>(), len);
    let mut encoded = String::with_capacity(PREFIX.len() + len * 2);
    encoded.push_str(PREFIX);
    encoded.extend(
        raw.iter()
            .flat_map(|&byte| [nibble_to_hex(byte >> 4), nibble_to_hex(byte & 0x0f)]),
    );
    Ok(encoded)
}
578
/// Converts a 4-bit value (`0..=15`) to its lowercase hexadecimal digit.
///
/// # Panics
/// Panics if `nibble` is greater than 15.
#[cfg(windows)]
fn nibble_to_hex(nibble: u8) -> char {
    match char::from_digit(u32::from(nibble), 16) {
        Some(digit) => digit,
        None => unreachable!("nibble out of range"),
    }
}
587
/// Makes `socket_path` ready for binding.
///
/// On Unix this creates any missing parent directories and refuses to proceed if anything
/// (socket, file or symlink) already exists at the path. On Windows it does nothing.
///
/// # Errors
/// Returns `AlreadyExists` if the path is occupied, or the underlying I/O error.
fn prepare_local_socket_path(socket_path: &str) -> io::Result<()> {
    #[cfg(unix)]
    {
        let target = std::path::Path::new(socket_path);
        if let Some(dir) = target.parent() {
            std::fs::create_dir_all(dir)?;
        }
        // `symlink_metadata` does not follow links, so a dangling symlink also counts as taken.
        let occupied = match std::fs::symlink_metadata(target) {
            Ok(_) => true,
            Err(err) if err.kind() == io::ErrorKind::NotFound => false,
            Err(err) => return Err(err),
        };
        if occupied {
            return Err(io::Error::new(
                io::ErrorKind::AlreadyExists,
                "broker local socket path already exists",
            ));
        }
    }

    #[cfg(windows)]
    let _ = socket_path;

    Ok(())
}
612
/// Restricts the bound socket at `socket_path` to its owner.
///
/// On Unix the mode is set to `0o600` (owner read/write). On Windows this does nothing.
///
/// # Errors
/// Returns the I/O error from changing the permissions.
fn secure_local_socket_path(socket_path: &str) -> io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;

        let owner_read_write = std::fs::Permissions::from_mode(0o600);
        std::fs::set_permissions(socket_path, owner_read_write)?;
    }

    #[cfg(windows)]
    let _ = socket_path;

    Ok(())
}
627
/// Removes the socket file at `socket_path` on a best-effort basis.
///
/// On Unix any removal error (including a path that is already gone) is deliberately ignored.
/// On Windows this does nothing.
fn cleanup_local_socket_path(socket_path: &str) {
    #[cfg(unix)]
    {
        std::fs::remove_file(socket_path).ok();
    }

    #[cfg(windows)]
    let _ = socket_path;
}