1use std::collections::HashMap;
9use std::io::{self, Read, Write};
10use std::sync::Arc;
11use std::thread;
12
13use interprocess::local_socket::traits::Listener;
14use prost::Message;
15
16use crate::broker::protocol::{
17 hello_reply::Result as HelloReplyResult, read_frame_with_cap, write_frame, ErrorCode, Frame,
18 FrameKind, FramingError, HelloReply, PayloadEncoding, Refused, CONTROL_PAYLOAD_PROTOCOL,
19 MAX_HELLO_BYTES, PROTOCOL_VERSION,
20};
21use crate::broker::server::{HelloHandler, HelloRouter, PeerIdentity};
22
/// Decides which connecting peers a broker local socket will serve.
///
/// Peers rejected by the policy get no reply: the connection handler returns without reading
/// from or writing to the stream.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PeerCredentialPolicy {
    /// Accept every peer regardless of its credentials.
    AllowAny,
    /// Accept only peers whose identity string equals `uid_or_sid` exactly.
    OwnerOnly {
        /// Expected peer identity. On Unix this is the effective UID rendered as a decimal
        /// string; on Windows a `windows-sid:`-prefixed lowercase hex SID string.
        /// An empty value matches no peer.
        uid_or_sid: String,
    },
}
34
35impl PeerCredentialPolicy {
36 pub fn allow_any() -> Self {
38 Self::AllowAny
39 }
40
41 pub fn owner_only(uid_or_sid: impl Into<String>) -> Self {
43 Self::OwnerOnly {
44 uid_or_sid: uid_or_sid.into(),
45 }
46 }
47
48 pub fn current_user() -> Option<Self> {
50 #[cfg(unix)]
51 {
52 Some(Self::owner_only(unsafe { libc::geteuid() }.to_string()))
53 }
54
55 #[cfg(windows)]
56 {
57 current_process_user_sid().ok().map(Self::owner_only)
58 }
59 }
60
61 pub fn allows(&self, peer: &PeerIdentity) -> bool {
63 match self {
64 Self::AllowAny => true,
65 Self::OwnerOnly { uid_or_sid } => {
66 !uid_or_sid.is_empty() && peer.uid_or_sid == *uid_or_sid
67 }
68 }
69 }
70}
71
/// Produces a [`HelloReply`] for a decoded Hello request [`Frame`] sent by an identified peer.
///
/// The connection helpers in this module are generic over this trait. This module implements
/// it for `HelloHandler` and `HelloRouter` by delegating to their inherent `handle_frame`.
pub trait HelloResponder {
    /// Handles one request `frame` received from `peer` and returns the reply to send back.
    fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply;
}
81
82impl HelloResponder for HelloHandler {
83 fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
84 Self::handle_frame(self, frame, peer)
85 }
86}
87
88impl HelloResponder for HelloRouter<'_> {
89 fn handle_frame(&self, frame: Frame, peer: PeerIdentity) -> HelloReply {
90 Self::handle_frame(self, frame, peer)
91 }
92}
93
94pub fn handle_hello_connection<S: Read + Write>(
100 stream: &mut S,
101 handler: &HelloHandler,
102 peer: PeerIdentity,
103) -> Result<HelloReply, BrokerConnectionError> {
104 handle_hello_connection_with(stream, handler, peer)
105}
106
107pub fn handle_hello_connection_with<S, R>(
112 stream: &mut S,
113 responder: &R,
114 peer: PeerIdentity,
115) -> Result<HelloReply, BrokerConnectionError>
116where
117 S: Read + Write,
118 R: HelloResponder + ?Sized,
119{
120 handle_hello_connection_with_peer_policy(
121 stream,
122 responder,
123 peer,
124 &PeerCredentialPolicy::allow_any(),
125 )
126 .map(|reply| reply.expect("allow-any policy must not drop peers"))
127}
128
129pub fn handle_hello_connection_with_peer_policy<S, R>(
135 stream: &mut S,
136 responder: &R,
137 peer: PeerIdentity,
138 peer_policy: &PeerCredentialPolicy,
139) -> Result<Option<HelloReply>, BrokerConnectionError>
140where
141 S: Read + Write,
142 R: HelloResponder + ?Sized,
143{
144 if !peer_policy.allows(&peer) {
145 return Ok(None);
146 }
147
148 let request_bytes = match read_frame_with_cap(stream, MAX_HELLO_BYTES) {
149 Ok(bytes) => bytes,
150 Err(err) => {
151 let reply = reply_for_framing_error(&err);
152 write_response_frame(stream, None, &reply)?;
153 return Ok(Some(reply));
154 }
155 };
156
157 let request_frame = match Frame::decode(request_bytes.as_slice()) {
158 Ok(frame) => frame,
159 Err(_) => {
160 let reply = refused_reply(ErrorCode::ErrorPeerRejected, "malformed broker Frame", 0);
161 write_response_frame(stream, None, &reply)?;
162 return Ok(Some(reply));
163 }
164 };
165
166 let reply = responder.handle_frame(request_frame.clone(), peer);
167 write_response_frame(stream, Some(&request_frame), &reply)?;
168 Ok(Some(reply))
169}
170
171pub fn serve_one_local_socket(
178 socket_path: &str,
179 handler: &HelloHandler,
180) -> Result<HelloReply, BrokerConnectionError> {
181 serve_one_local_socket_with(socket_path, handler)
182}
183
184pub fn serve_one_local_socket_with<R>(
187 socket_path: &str,
188 responder: &R,
189) -> Result<HelloReply, BrokerConnectionError>
190where
191 R: HelloResponder + ?Sized,
192{
193 serve_one_local_socket_with_peer_policy(
194 socket_path,
195 responder,
196 &PeerCredentialPolicy::allow_any(),
197 )
198 .map(|reply| reply.expect("allow-any policy must not drop peers"))
199}
200
201pub fn serve_one_local_socket_with_peer_policy<R>(
203 socket_path: &str,
204 responder: &R,
205 peer_policy: &PeerCredentialPolicy,
206) -> Result<Option<HelloReply>, BrokerConnectionError>
207where
208 R: HelloResponder + ?Sized,
209{
210 let listener = bind_local_socket(socket_path)?;
211 let cleanup = LocalSocketCleanup(socket_path);
212 let result = (|| {
213 let mut stream = listener.accept()?;
214 let peer = peer_identity_from_stream(&stream)?;
215 handle_hello_connection_with_peer_policy(&mut stream, responder, peer, peer_policy)
216 })();
217 drop(listener);
218 drop(cleanup);
219 result
220}
221
222pub fn serve_local_socket_connections(
228 socket_path: &str,
229 handler: Arc<HelloHandler>,
230 connection_count: usize,
231) -> Result<(), BrokerConnectionError> {
232 serve_local_socket_connections_with_peer_policy(
233 socket_path,
234 handler,
235 connection_count,
236 &PeerCredentialPolicy::allow_any(),
237 )
238}
239
240pub fn serve_local_socket_connections_with_peer_policy(
242 socket_path: &str,
243 handler: Arc<HelloHandler>,
244 connection_count: usize,
245 peer_policy: &PeerCredentialPolicy,
246) -> Result<(), BrokerConnectionError> {
247 if connection_count == 0 {
248 return Ok(());
249 }
250
251 let listener = bind_local_socket(socket_path)?;
252 let cleanup = LocalSocketCleanup(socket_path);
253 let result = (|| {
254 let mut workers = Vec::with_capacity(connection_count);
255 let peer_policy = Arc::new(peer_policy.clone());
256
257 for _ in 0..connection_count {
258 let mut stream = listener.accept()?;
259 let peer = peer_identity_from_stream(&stream)?;
260 let handler = Arc::clone(&handler);
261 let peer_policy = Arc::clone(&peer_policy);
262 workers.push(thread::spawn(move || {
263 handle_hello_connection_with_peer_policy(
264 &mut stream,
265 handler.as_ref(),
266 peer,
267 peer_policy.as_ref(),
268 )
269 .map(|_| ())
270 }));
271 }
272
273 for worker in workers {
274 match worker.join() {
275 Ok(Ok(())) => {}
276 Ok(Err(err)) => return Err(err),
277 Err(_) => return Err(BrokerConnectionError::WorkerPanic),
278 }
279 }
280 Ok(())
281 })();
282 drop(listener);
283 drop(cleanup);
284 result
285}
286
287pub fn serve_local_socket_connections_with<R>(
293 socket_path: &str,
294 responder: &R,
295 connection_count: usize,
296) -> Result<(), BrokerConnectionError>
297where
298 R: HelloResponder + ?Sized,
299{
300 serve_local_socket_connections_with_policy(
301 socket_path,
302 responder,
303 connection_count,
304 &PeerCredentialPolicy::allow_any(),
305 )
306}
307
308pub fn serve_local_socket_connections_with_policy<R>(
310 socket_path: &str,
311 responder: &R,
312 connection_count: usize,
313 peer_policy: &PeerCredentialPolicy,
314) -> Result<(), BrokerConnectionError>
315where
316 R: HelloResponder + ?Sized,
317{
318 if connection_count == 0 {
319 return Ok(());
320 }
321
322 let listener = bind_local_socket(socket_path)?;
323 let cleanup = LocalSocketCleanup(socket_path);
324 let result = (|| {
325 for _ in 0..connection_count {
326 let mut stream = listener.accept()?;
327 let peer = peer_identity_from_stream(&stream)?;
328 let _ = handle_hello_connection_with_peer_policy(
329 &mut stream,
330 responder,
331 peer,
332 peer_policy,
333 )?;
334 }
335 Ok(())
336 })();
337 drop(listener);
338 drop(cleanup);
339 result
340}
341
/// Converts `socket_path` into an `interprocess` local-socket name.
///
/// On Unix the string is interpreted as a filesystem path (`GenericFilePath`). On Windows it is
/// interpreted as a namespaced name (`GenericNamespaced`).
///
/// # Errors
/// Propagates the `io::Error` from the name conversion, for example when the string is not a
/// valid name on this platform.
pub fn local_socket_name(socket_path: &str) -> io::Result<interprocess::local_socket::Name<'_>> {
    #[cfg(unix)]
    {
        use interprocess::local_socket::{GenericFilePath, ToFsName};
        socket_path.to_fs_name::<GenericFilePath>()
    }

    #[cfg(windows)]
    {
        use interprocess::local_socket::{GenericNamespaced, ToNsName};
        socket_path.to_ns_name::<GenericNamespaced>()
    }
}
357
/// Errors that abort serving a broker local-socket connection.
///
/// Malformed or oversized requests are not errors. They are answered with a `Refused` reply
/// instead.
#[derive(Debug, thiserror::Error)]
pub enum BrokerConnectionError {
    /// Failure from the broker framing layer, e.g. while writing the response frame.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// prost failed to encode the response `Frame`.
    #[error("failed to encode broker response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// I/O failure, e.g. while preparing the socket path, binding, accepting a connection or
    /// querying peer credentials.
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A per-connection worker thread panicked.
    #[error("broker connection worker panicked")]
    WorkerPanic,
}
374
375pub(super) fn bind_local_socket(
376 socket_path: &str,
377) -> Result<interprocess::local_socket::Listener, BrokerConnectionError> {
378 use interprocess::local_socket::ListenerOptions;
379
380 prepare_local_socket_path(socket_path)?;
381 let name = local_socket_name(socket_path)?;
382 let listener = ListenerOptions::new().name(name).create_sync()?;
383 secure_local_socket_path(socket_path)?;
384 Ok(listener)
385}
386
387pub(super) struct LocalSocketCleanup<'a>(pub(super) &'a str);
388
389impl Drop for LocalSocketCleanup<'_> {
390 fn drop(&mut self) {
391 cleanup_local_socket_path(self.0);
392 }
393}
394
395pub(super) fn write_response_frame<W: Write>(
396 writer: &mut W,
397 request_frame: Option<&Frame>,
398 reply: &HelloReply,
399) -> Result<(), BrokerConnectionError> {
400 let response_frame = Frame {
401 envelope_version: PROTOCOL_VERSION,
402 kind: FrameKind::Response as i32,
403 payload_protocol: CONTROL_PAYLOAD_PROTOCOL,
404 payload: reply.encode_to_vec(),
405 request_id: request_frame.map_or(0, |frame| frame.request_id),
406 payload_encoding: PayloadEncoding::None as i32,
407 deadline_unix_ms: 0,
408 traceparent: request_frame
409 .map(|frame| frame.traceparent.clone())
410 .unwrap_or_default(),
411 tracestate: request_frame
412 .map(|frame| frame.tracestate.clone())
413 .unwrap_or_default(),
414 };
415 let mut response_bytes = Vec::new();
416 response_frame
417 .encode(&mut response_bytes)
418 .map_err(BrokerConnectionError::EncodeFrame)?;
419 write_frame(writer, &response_bytes)?;
420 Ok(())
421}
422
423pub(super) fn reply_for_framing_error(error: &FramingError) -> HelloReply {
424 match error {
425 FramingError::UnsupportedFramingVersion { .. } => refused_reply(
426 ErrorCode::ErrorVersionUnsupported,
427 "unsupported framing version",
428 0,
429 ),
430 FramingError::FrameTooLarge { .. } => refused_reply(
431 ErrorCode::ErrorPeerRejected,
432 "initial Hello frame exceeds 64 KiB",
433 0,
434 ),
435 FramingError::UnexpectedEof { .. } | FramingError::Io(_) => {
436 refused_reply(ErrorCode::ErrorPeerRejected, "incomplete Hello frame", 0)
437 }
438 FramingError::Decode(_) => {
441 refused_reply(ErrorCode::ErrorPeerRejected, "malformed Hello frame", 0)
442 }
443 }
444}
445
446pub(super) fn refused_reply(
447 code: ErrorCode,
448 reason: impl Into<String>,
449 retry_after_ms: u64,
450) -> HelloReply {
451 HelloReply {
452 result: Some(HelloReplyResult::Refused(Refused {
453 reason: reason.into(),
454 daemon_min_protocol: PROTOCOL_VERSION,
455 daemon_max_protocol: PROTOCOL_VERSION,
456 code: code as i32,
457 details: HashMap::new(),
458 retry_after_ms,
459 })),
460 }
461}
462
/// Reads the connecting peer's identity from an accepted local-socket stream.
///
/// Missing or unrepresentable values are not errors: the pid falls back to `0`, and the
/// identity string falls back to `""`, which `PeerCredentialPolicy::OwnerOnly` never matches.
///
/// # Errors
/// Propagates the error from `peer_creds()`, converted into `BrokerConnectionError`.
pub(super) fn peer_identity_from_stream(
    stream: &interprocess::local_socket::Stream,
) -> Result<PeerIdentity, BrokerConnectionError> {
    use interprocess::local_socket::traits::StreamCommon;

    let creds = stream.peer_creds()?;
    // A pid that does not fit in u32 (e.g. negative) is treated like a missing one.
    #[cfg(unix)]
    let pid = creds
        .pid()
        .and_then(|pid| u32::try_from(pid).ok())
        .unwrap_or(0);

    #[cfg(windows)]
    let pid = creds.pid().unwrap_or(0);

    // Unix identity: the peer's effective UID as a decimal string.
    #[cfg(unix)]
    let uid_or_sid = creds.euid().map(|uid| uid.to_string()).unwrap_or_default();

    // Windows identity: the token user SID of the peer process, looked up by pid. Any lookup
    // failure yields "" (fails closed under an owner-only policy).
    // NOTE(review): a pid-based lookup can race with the peer exiting and its pid being
    // reused. Confirm this is acceptable for the threat model.
    #[cfg(windows)]
    let uid_or_sid = if pid == 0 {
        String::new()
    } else {
        process_user_sid(pid).unwrap_or_default()
    };

    Ok(PeerIdentity { pid, uid_or_sid })
}
490
/// Returns the stable SID string of the user that owns the current process.
#[cfg(windows)]
fn current_process_user_sid() -> io::Result<String> {
    let own_pid = std::process::id();
    process_user_sid(own_pid)
}
495
/// Returns the stable SID string (see `sid_to_stable_string`) of the user owning process `pid`.
///
/// # Errors
/// Returns the OS error if either of these fails:
/// - opening the process or its token;
/// - reading the token's user information.
///
/// Also errors if the SID is invalid or implausibly long.
#[cfg(windows)]
fn process_user_sid(pid: u32) -> io::Result<String> {
    use std::ptr;
    use winapi::um::processthreadsapi::{OpenProcess, OpenProcessToken};
    use winapi::um::winnt::{
        TokenUser, HANDLE, PROCESS_QUERY_LIMITED_INFORMATION, TOKEN_QUERY, TOKEN_USER,
    };

    // SAFETY: every handle is checked before use and closed by a `WindowsHandle` guard. The
    // token-information buffer is sized from the size probe and kept alive (and suitably
    // aligned) for as long as the SID pointer into it is used.
    unsafe {
        let process = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid);
        if process.is_null() {
            return Err(io::Error::last_os_error());
        }
        let _process_guard = WindowsHandle(process);

        let mut token: HANDLE = ptr::null_mut();
        if OpenProcessToken(process, TOKEN_QUERY, &mut token) == 0 {
            return Err(io::Error::last_os_error());
        }
        let _token_guard = WindowsHandle(token);

        // Size probe: this call is expected to fail while reporting the required size.
        let mut required_size = 0_u32;
        let _ = winapi::um::securitybaseapi::GetTokenInformation(
            token,
            TokenUser,
            ptr::null_mut(),
            0,
            &mut required_size,
        );
        if required_size == 0 {
            return Err(io::Error::last_os_error());
        }

        // TOKEN_USER starts with a pointer, and the SID it points at is stored in this same
        // buffer. A `Vec<u8>` only guarantees 1-byte alignment, so reading it as TOKEN_USER
        // could be misaligned (UB). Back the buffer with u64 words to get 8-byte alignment.
        let word_count = (required_size as usize).div_ceil(std::mem::size_of::<u64>());
        let mut buffer = vec![0_u64; word_count];
        if winapi::um::securitybaseapi::GetTokenInformation(
            token,
            TokenUser,
            buffer.as_mut_ptr().cast(),
            required_size,
            &mut required_size,
        ) == 0
        {
            return Err(io::Error::last_os_error());
        }

        let token_user = &*buffer.as_ptr().cast::<TOKEN_USER>();
        sid_to_stable_string(token_user.User.Sid)
    }
}
546
/// Owning wrapper that closes a Win32 `HANDLE` when dropped.
///
/// Only construct it with a handle that is valid and owned by the caller.
#[cfg(windows)]
struct WindowsHandle(winapi::um::winnt::HANDLE);

#[cfg(windows)]
impl Drop for WindowsHandle {
    fn drop(&mut self) {
        let raw = self.0;
        // SAFETY: the wrapped handle was obtained from a successful Win32 call and is owned
        // exclusively by this guard, so it is closed exactly once. The result is ignored
        // because nothing useful can be done about a failed close here.
        let _ = unsafe { winapi::um::handleapi::CloseHandle(raw) };
    }
}
558
/// Renders a Windows SID as `windows-sid:` followed by the lowercase hex of its raw bytes.
///
/// # Errors
/// Fails if `sid` is null or invalid, or if its reported length is 0 or greater than 1024
/// bytes.
///
/// # Safety
/// `sid` must be null or point to memory that is readable for the SID's full length, as
/// reported by `GetLengthSid`, for the duration of the call.
#[cfg(windows)]
unsafe fn sid_to_stable_string(sid: winapi::um::winnt::PSID) -> io::Result<String> {
    use winapi::um::securitybaseapi::{GetLengthSid, IsValidSid};

    // Validate before asking for the length.
    if sid.is_null() || IsValidSid(sid) == 0 {
        return Err(io::Error::other("invalid Windows SID"));
    }
    let len = GetLengthSid(sid) as usize;
    // Sanity bound before building a slice over raw memory.
    if len == 0 || len > 1024 {
        return Err(io::Error::other(format!(
            "implausible Windows SID length {len}"
        )));
    }
    let bytes = std::slice::from_raw_parts(sid.cast::<u8>(), len);
    let mut out = String::with_capacity("windows-sid:".len() + len * 2);
    out.push_str("windows-sid:");
    // Two hex digits per byte, high nibble first.
    for byte in bytes {
        out.push(nibble_to_hex(byte >> 4));
        out.push(nibble_to_hex(byte & 0x0f));
    }
    Ok(out)
}
581
/// Converts a value in `0..=15` to its lowercase hexadecimal digit.
///
/// # Panics
/// Panics (as unreachable) if `nibble` is greater than 15.
#[cfg(windows)]
fn nibble_to_hex(nibble: u8) -> char {
    char::from_digit(u32::from(nibble), 16).unwrap_or_else(|| unreachable!("nibble out of range"))
}
590
/// Prepares `socket_path` for binding a new local socket.
///
/// On Unix this creates any missing parent directories. It then refuses to continue if
/// anything already exists at the path, including a dangling symlink (`symlink_metadata`
/// does not follow links), so an existing entry is never replaced. On Windows it does
/// nothing.
///
/// # Errors
/// Returns `AlreadyExists` if the path is occupied. Otherwise returns any I/O error from
/// creating the parent directories or probing the path.
fn prepare_local_socket_path(socket_path: &str) -> io::Result<()> {
    #[cfg(unix)]
    {
        let socket = std::path::Path::new(socket_path);
        if let Some(dir) = socket.parent() {
            std::fs::create_dir_all(dir)?;
        }

        let occupied = match std::fs::symlink_metadata(socket) {
            Ok(_) => true,
            Err(probe) if probe.kind() == io::ErrorKind::NotFound => false,
            Err(probe) => return Err(probe),
        };
        if occupied {
            return Err(io::Error::new(
                io::ErrorKind::AlreadyExists,
                "broker local socket path already exists",
            ));
        }
    }

    #[cfg(windows)]
    let _ = socket_path;

    Ok(())
}
615
/// Restricts the socket file at `socket_path` to owner read/write (`0o600`) on Unix.
///
/// It does nothing on other platforms.
///
/// # Errors
/// Returns the I/O error from `set_permissions`, for example if the path does not exist.
fn secure_local_socket_path(socket_path: &str) -> io::Result<()> {
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;

        let owner_only = std::fs::Permissions::from_mode(0o600);
        std::fs::set_permissions(socket_path, owner_only)
    }

    #[cfg(not(unix))]
    {
        let _ = socket_path;
        Ok(())
    }
}
630
/// Best-effort removal of the socket file at `socket_path` on Unix.
///
/// Errors such as a missing file are ignored. It does nothing on Windows.
fn cleanup_local_socket_path(socket_path: &str) {
    if cfg!(unix) {
        let _ = std::fs::remove_file(socket_path);
    }
}