1use std::io::{self, Read, Write};
4use std::thread;
5use std::time::{Duration, Instant};
6
7use interprocess::local_socket::prelude::*;
8use prost::Message;
9
10use crate::broker::backend_lifecycle::identity::{DaemonProcess, IdentityError};
11use crate::broker::backend_lifecycle::verify_pid::{self, ProcessHandle, VerifyPidError};
12use crate::broker::protocol::{
13 self, read_frame, write_frame, Endpoint, Frame, FrameKind, FramingError, PayloadEncoding,
14 ENVELOPE_VERSION, MAX_FRAME_BYTES, PROTOCOL_VERSION,
15};
16
/// Length in bytes of the random nonce carried in an endpoint probe request payload.
pub const PROBE_NONCE_BYTES: usize = 32;
/// Upper bound on a single sleep between retries of a non-blocking read, write, or flush.
const NONBLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(5);

// Re-exported so users of this module can name the probe payload protocol id directly.
pub use crate::broker::protocol::registry::BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL;

/// Timeout used by [`probe_endpoint_response`], which takes no explicit timeout.
pub const DEFAULT_ENDPOINT_PROBE_TIMEOUT: Duration = Duration::from_millis(500);
30
31pub fn probe_endpoint(
33 endpoint: &Endpoint,
34 expected: &DaemonProcess,
35) -> Result<ProcessHandle, ProbeError> {
36 if !same_endpoint(endpoint, &expected.ipc_endpoint) {
37 return Err(ProbeError::EndpointMismatch);
38 }
39 let process_handle =
40 verify_pid::verify_daemon_process(expected).map_err(ProbeError::VerifyPid)?;
41 probe_endpoint_response(endpoint, expected)?;
42 Ok(process_handle)
43}
44
45pub fn same_endpoint(left: &Endpoint, right: &Endpoint) -> bool {
47 left.namespace_id == right.namespace_id && left.path == right.path
48}
49
50pub fn probe_endpoint_response(
57 endpoint: &Endpoint,
58 expected: &DaemonProcess,
59) -> Result<(), EndpointProbeError> {
60 probe_endpoint_response_with_timeout(endpoint, expected, DEFAULT_ENDPOINT_PROBE_TIMEOUT)
61}
62
63pub fn probe_endpoint_response_with_timeout(
65 endpoint: &Endpoint,
66 expected: &DaemonProcess,
67 timeout: Duration,
68) -> Result<(), EndpointProbeError> {
69 let mut nonce = [0_u8; PROBE_NONCE_BYTES];
70 getrandom::fill(&mut nonce).map_err(EndpointProbeError::Random)?;
71 let request_id = u64::from_le_bytes(nonce[..8].try_into().expect("nonce has 8 bytes"));
72 let request_frame = endpoint_probe_request_frame(request_id, &nonce);
73 let mut request_bytes = Vec::new();
74 request_frame
75 .encode(&mut request_bytes)
76 .map_err(EndpointProbeError::EncodeFrame)?;
77
78 let deadline = Instant::now() + timeout;
79 let mut stream = connect_endpoint_with_deadline(endpoint, deadline)?;
80 stream
81 .set_nonblocking(true)
82 .map_err(EndpointProbeError::ConfigureNonblocking)?;
83 write_probe_frame_with_deadline(&mut stream, &request_bytes, deadline)?;
84
85 let response_bytes = read_probe_frame_with_deadline(&mut stream, deadline)?;
86 let response_frame =
87 Frame::decode(response_bytes.as_slice()).map_err(EndpointProbeError::DecodeFrame)?;
88 validate_endpoint_probe_response_frame(&response_frame, request_id)?;
89 let actual = decode_response_identity(&response_frame.payload, &nonce)?;
90 if !same_daemon_identity(&actual, expected) {
91 return Err(identity_mismatch(expected, &actual));
92 }
93 Ok(())
94}
95
/// A validated endpoint probe request, extracted from a request [`Frame`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointProbeRequest {
    /// `request_id` of the request frame; the response frame repeats it.
    pub request_id: u64,
    /// Nonce taken from the request payload; the response payload starts with it.
    pub nonce: [u8; PROBE_NONCE_BYTES],
    /// `traceparent` of the request frame; copied into the response frame.
    pub traceparent: String,
    /// `tracestate` of the request frame; copied into the response frame.
    pub tracestate: String,
}
108
109pub fn read_endpoint_probe_request<S: Read>(
111 stream: &mut S,
112) -> Result<EndpointProbeRequest, EndpointProbeServerError> {
113 let request_bytes = read_frame(stream)?;
114 let frame =
115 Frame::decode(request_bytes.as_slice()).map_err(EndpointProbeServerError::DecodeFrame)?;
116 endpoint_probe_request_from_frame(&frame)
117}
118
119pub fn endpoint_probe_request_from_frame(
125 frame: &Frame,
126) -> Result<EndpointProbeRequest, EndpointProbeServerError> {
127 validate_endpoint_probe_request_frame(frame)?;
128 let nonce = frame
129 .payload
130 .as_slice()
131 .try_into()
132 .map_err(|_| EndpointProbeServerError::MalformedPayload("nonce must be 32 bytes"))?;
133 Ok(EndpointProbeRequest {
134 request_id: frame.request_id,
135 nonce,
136 traceparent: frame.traceparent.clone(),
137 tracestate: frame.tracestate.clone(),
138 })
139}
140
141pub fn write_endpoint_probe_response<S: Write>(
143 stream: &mut S,
144 request: &EndpointProbeRequest,
145 daemon: &DaemonProcess,
146) -> Result<(), EndpointProbeServerError> {
147 let response_frame = endpoint_probe_response_frame(request, daemon);
148 let mut response_bytes = Vec::new();
149 response_frame
150 .encode(&mut response_bytes)
151 .map_err(EndpointProbeServerError::EncodeFrame)?;
152 write_frame(stream, &response_bytes)?;
153 Ok(())
154}
155
156pub fn handle_endpoint_probe<S: Read + Write>(
158 stream: &mut S,
159 daemon: &DaemonProcess,
160) -> Result<(), EndpointProbeServerError> {
161 let request = read_endpoint_probe_request(stream)?;
162 write_endpoint_probe_response(stream, &request, daemon)
163}
164
/// Errors returned by [`probe_endpoint`].
#[derive(Debug, thiserror::Error)]
pub enum ProbeError {
    /// The probed endpoint differs from the expected daemon's `ipc_endpoint`.
    #[error("endpoint does not match expected daemon identity")]
    EndpointMismatch,
    /// The wire-level probe exchange failed; the inner error is displayed as-is.
    #[error(transparent)]
    EndpointResponse(#[from] EndpointProbeError),
    /// `verify_pid::verify_daemon_process` failed; the inner error is displayed as-is.
    #[error(transparent)]
    VerifyPid(#[from] VerifyPidError),
}
178
/// Errors produced on the probing (client) side of the endpoint probe exchange.
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeError {
    /// Filling the probe nonce with random bytes failed.
    #[error("backend endpoint probe random generation failed: {0}")]
    Random(getrandom::Error),
    /// The endpoint path could not be turned into a local-socket name.
    #[error("backend endpoint probe local-socket name failed: {0}")]
    LocalSocketName(io::Error),
    /// Connecting failed: empty path, connect thread could not be spawned, the
    /// connect call itself failed, or it did not finish before the deadline.
    #[error("backend endpoint probe connect failed: {0}")]
    Connect(io::Error),
    /// Switching the connected stream to non-blocking mode failed.
    #[error("backend endpoint probe nonblocking setup failed: {0}")]
    ConfigureNonblocking(io::Error),
    /// The probe deadline passed while waiting to read, write, or flush.
    #[error("backend endpoint probe timed out")]
    Timeout,
    /// A read, write, or flush failed with an error other than "would block".
    #[error("backend endpoint probe I/O failed: {0}")]
    Io(io::Error),
    /// The leading framing byte of a received frame was not `ENVELOPE_VERSION`.
    #[error("backend endpoint probe unsupported framing version: got {got}, expected {expected}")]
    UnsupportedFramingVersion {
        /// Version byte actually received.
        got: u8,
        /// Version byte this side requires.
        expected: u8,
    },
    /// A frame body (outgoing or announced by the peer) exceeds `MAX_FRAME_BYTES`.
    #[error("backend endpoint probe frame body too large: {body_length} bytes exceeds cap {cap}")]
    FrameTooLarge {
        /// Length of the offending body in bytes.
        body_length: usize,
        /// The cap that was exceeded.
        cap: usize,
    },
    /// Encoding the request [`Frame`] failed.
    #[error("failed to encode endpoint probe frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The response bytes could not be decoded as a [`Frame`].
    #[error("failed to decode endpoint probe response Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// The response frame had an unexpected header field or echoed the wrong nonce.
    #[error("unexpected endpoint probe response: {0}")]
    UnexpectedFrame(&'static str),
    /// The response payload was structurally invalid (e.g. shorter than the nonce).
    #[error("endpoint probe response payload is malformed: {0}")]
    MalformedPayload(&'static str),
    /// The bytes after the nonce could not be decoded as a protocol `DaemonProcess`.
    #[error("failed to decode endpoint probe daemon identity: {0}")]
    DecodeDaemonProcess(prost::DecodeError),
    /// The decoded protocol identity could not be converted into a [`DaemonProcess`].
    #[error(transparent)]
    Identity(#[from] IdentityError),
    /// The daemon answered, but with an identity different from the expected one.
    #[error("endpoint probe response identity did not match expected daemon identity: {field}")]
    IdentityMismatch {
        /// Name of the first differing identity field, or `"unknown"`.
        field: &'static str,
    },
}
241
/// Errors produced on the serving (daemon) side of the endpoint probe exchange.
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeServerError {
    /// Reading or writing a frame failed at the framing layer.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// The request bytes could not be decoded as a [`Frame`].
    #[error("failed to decode endpoint probe request Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// Encoding the response [`Frame`] failed.
    #[error("failed to encode endpoint probe response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// A header field of the request frame had an unexpected value.
    #[error("unexpected endpoint probe request: {0}")]
    UnexpectedFrame(&'static str),
    /// The request payload was not a nonce of the required length.
    #[error("endpoint probe request payload is malformed: {0}")]
    MalformedPayload(&'static str),
}
261
262fn endpoint_probe_request_frame(request_id: u64, nonce: &[u8; PROBE_NONCE_BYTES]) -> Frame {
263 Frame {
264 envelope_version: PROTOCOL_VERSION,
265 kind: FrameKind::Request as i32,
266 payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
267 payload: nonce.to_vec(),
268 request_id,
269 payload_encoding: PayloadEncoding::None as i32,
270 deadline_unix_ms: 0,
271 traceparent: String::new(),
272 tracestate: String::new(),
273 }
274}
275
276pub fn endpoint_probe_response_frame(
282 request: &EndpointProbeRequest,
283 daemon: &DaemonProcess,
284) -> Frame {
285 let mut payload = Vec::with_capacity(PROBE_NONCE_BYTES + 128);
286 payload.extend_from_slice(&request.nonce);
287 daemon.to_proto().encode(&mut payload).expect(
288 "prost encoding DaemonProcess into Vec cannot fail because Vec writes are infallible",
289 );
290
291 Frame {
292 envelope_version: PROTOCOL_VERSION,
293 kind: FrameKind::Response as i32,
294 payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
295 payload,
296 request_id: request.request_id,
297 payload_encoding: PayloadEncoding::None as i32,
298 deadline_unix_ms: 0,
299 traceparent: request.traceparent.clone(),
300 tracestate: request.tracestate.clone(),
301 }
302}
303
304pub fn validate_endpoint_probe_request_frame(
308 frame: &Frame,
309) -> Result<(), EndpointProbeServerError> {
310 if frame.envelope_version != PROTOCOL_VERSION {
311 return Err(EndpointProbeServerError::UnexpectedFrame(
312 "envelope_version is not v1",
313 ));
314 }
315 if FrameKind::try_from(frame.kind) != Ok(FrameKind::Request) {
316 return Err(EndpointProbeServerError::UnexpectedFrame(
317 "kind is not REQUEST",
318 ));
319 }
320 if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
321 return Err(EndpointProbeServerError::UnexpectedFrame(
322 "payload_protocol is not endpoint probe",
323 ));
324 }
325 if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
326 return Err(EndpointProbeServerError::UnexpectedFrame(
327 "payload is compressed",
328 ));
329 }
330 if frame.payload.len() != PROBE_NONCE_BYTES {
331 return Err(EndpointProbeServerError::MalformedPayload(
332 "nonce must be 32 bytes",
333 ));
334 }
335 Ok(())
336}
337
338fn validate_endpoint_probe_response_frame(
339 frame: &Frame,
340 request_id: u64,
341) -> Result<(), EndpointProbeError> {
342 if frame.envelope_version != PROTOCOL_VERSION {
343 return Err(EndpointProbeError::UnexpectedFrame(
344 "envelope_version is not v1",
345 ));
346 }
347 if FrameKind::try_from(frame.kind) != Ok(FrameKind::Response) {
348 return Err(EndpointProbeError::UnexpectedFrame("kind is not RESPONSE"));
349 }
350 if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
351 return Err(EndpointProbeError::UnexpectedFrame(
352 "payload_protocol is not endpoint probe",
353 ));
354 }
355 if frame.request_id != request_id {
356 return Err(EndpointProbeError::UnexpectedFrame(
357 "request_id does not match endpoint probe request",
358 ));
359 }
360 if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
361 return Err(EndpointProbeError::UnexpectedFrame("payload is compressed"));
362 }
363 Ok(())
364}
365
366pub fn decode_response_identity(
374 payload: &[u8],
375 expected_nonce: &[u8; PROBE_NONCE_BYTES],
376) -> Result<DaemonProcess, EndpointProbeError> {
377 if payload.len() < PROBE_NONCE_BYTES {
378 return Err(EndpointProbeError::MalformedPayload(
379 "payload shorter than nonce",
380 ));
381 }
382 let (nonce, identity_bytes) = payload.split_at(PROBE_NONCE_BYTES);
383 if nonce != expected_nonce {
384 return Err(EndpointProbeError::UnexpectedFrame(
385 "nonce does not match endpoint probe request",
386 ));
387 }
388 let proto_identity = protocol::DaemonProcess::decode(identity_bytes)
389 .map_err(EndpointProbeError::DecodeDaemonProcess)?;
390 DaemonProcess::try_from(proto_identity).map_err(EndpointProbeError::Identity)
391}
392
393fn identity_mismatch(expected: &DaemonProcess, actual: &DaemonProcess) -> EndpointProbeError {
394 let field = if actual.pid != expected.pid {
395 "pid"
396 } else if actual.exe_path != expected.exe_path {
397 "exe_path"
398 } else if actual.exe_sha256 != expected.exe_sha256 {
399 "exe_sha256"
400 } else if actual.boot_id != expected.boot_id {
401 "boot_id"
402 } else if !same_endpoint(&actual.ipc_endpoint, &expected.ipc_endpoint) {
403 "ipc_endpoint"
404 } else {
405 "unknown"
406 };
407 EndpointProbeError::IdentityMismatch { field }
408}
409
410fn same_daemon_identity(left: &DaemonProcess, right: &DaemonProcess) -> bool {
411 left.pid == right.pid
412 && left.exe_path == right.exe_path
413 && left.exe_sha256 == right.exe_sha256
414 && left.boot_id == right.boot_id
415 && same_endpoint(&left.ipc_endpoint, &right.ipc_endpoint)
416}
417
418fn connect_endpoint_with_deadline(
429 endpoint: &Endpoint,
430 deadline: Instant,
431) -> Result<interprocess::local_socket::Stream, EndpointProbeError> {
432 if endpoint.path.is_empty() {
433 return Err(EndpointProbeError::Connect(io::Error::new(
434 io::ErrorKind::InvalidInput,
435 "backend endpoint path is empty",
436 )));
437 }
438 endpoint_name(&endpoint.path).map_err(EndpointProbeError::LocalSocketName)?;
440
441 let path = endpoint.path.clone();
442 let (tx, rx) = std::sync::mpsc::channel();
443 thread::Builder::new()
444 .name("rp-endpoint-probe-connect".to_string())
445 .spawn(move || {
446 let result = match endpoint_name(&path) {
447 Ok(name) => interprocess::local_socket::Stream::connect(name),
448 Err(err) => Err(err),
449 };
450 let _ = tx.send(result);
452 })
453 .map_err(EndpointProbeError::Connect)?;
454
455 let remaining = deadline.saturating_duration_since(Instant::now());
456 match rx.recv_timeout(remaining) {
457 Ok(Ok(stream)) => Ok(stream),
458 Ok(Err(err)) => Err(EndpointProbeError::Connect(err)),
459 Err(_) => Err(EndpointProbeError::Connect(io::Error::new(
460 io::ErrorKind::TimedOut,
461 format!(
462 "backend endpoint probe connect timed out after the probe deadline \
463 (endpoint {}): the listener exists but never completed the connection",
464 endpoint.path
465 ),
466 ))),
467 }
468}
469
470fn write_probe_frame_with_deadline(
471 stream: &mut interprocess::local_socket::Stream,
472 body: &[u8],
473 deadline: Instant,
474) -> Result<(), EndpointProbeError> {
475 if body.len() > MAX_FRAME_BYTES {
476 return Err(EndpointProbeError::FrameTooLarge {
477 body_length: body.len(),
478 cap: MAX_FRAME_BYTES,
479 });
480 }
481 let mut wire = Vec::with_capacity(1 + 4 + body.len());
482 wire.push(ENVELOPE_VERSION);
483 wire.extend_from_slice(&(body.len() as u32).to_le_bytes());
484 wire.extend_from_slice(body);
485 write_all_with_deadline(stream, &wire, deadline)?;
486 flush_with_deadline(stream, deadline)
487}
488
489fn read_probe_frame_with_deadline(
490 stream: &mut interprocess::local_socket::Stream,
491 deadline: Instant,
492) -> Result<Vec<u8>, EndpointProbeError> {
493 parse_probe_frame(|buf| read_exact_with_deadline(stream, buf, deadline))
494}
495
496pub fn read_probe_frame<R: Read>(reader: &mut R) -> Result<Vec<u8>, EndpointProbeError> {
504 parse_probe_frame(|buf| reader.read_exact(buf).map_err(EndpointProbeError::Io))
505}
506
507fn parse_probe_frame(
512 mut read_exact: impl FnMut(&mut [u8]) -> Result<(), EndpointProbeError>,
513) -> Result<Vec<u8>, EndpointProbeError> {
514 let mut version = [0_u8; 1];
515 read_exact(&mut version)?;
516 if version[0] != ENVELOPE_VERSION {
517 return Err(EndpointProbeError::UnsupportedFramingVersion {
518 got: version[0],
519 expected: ENVELOPE_VERSION,
520 });
521 }
522
523 let mut len = [0_u8; 4];
524 read_exact(&mut len)?;
525 let body_length = u32::from_le_bytes(len) as usize;
526 if body_length > MAX_FRAME_BYTES {
527 return Err(EndpointProbeError::FrameTooLarge {
528 body_length,
529 cap: MAX_FRAME_BYTES,
530 });
531 }
532
533 let mut body = vec![0_u8; body_length];
534 if body_length > 0 {
535 read_exact(&mut body)?;
536 }
537 Ok(body)
538}
539
540fn write_all_with_deadline<W: Write>(
541 writer: &mut W,
542 mut buf: &[u8],
543 deadline: Instant,
544) -> Result<(), EndpointProbeError> {
545 while !buf.is_empty() {
546 match writer.write(buf) {
547 Ok(0) => {
548 return Err(EndpointProbeError::Io(io::Error::new(
549 io::ErrorKind::WriteZero,
550 "endpoint probe write returned zero bytes",
551 )));
552 }
553 Ok(written) => buf = &buf[written..],
554 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
555 Err(err) => return Err(EndpointProbeError::Io(err)),
556 }
557 }
558 Ok(())
559}
560
561fn read_exact_with_deadline<R: Read>(
562 reader: &mut R,
563 mut buf: &mut [u8],
564 deadline: Instant,
565) -> Result<(), EndpointProbeError> {
566 while !buf.is_empty() {
567 match reader.read(buf) {
568 Ok(0) => wait_for_io(deadline)?,
569 Ok(read) => {
570 let tmp = buf;
571 buf = &mut tmp[read..];
572 }
573 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
574 Err(err) => return Err(EndpointProbeError::Io(err)),
575 }
576 }
577 Ok(())
578}
579
580fn flush_with_deadline<W: Write>(
581 writer: &mut W,
582 deadline: Instant,
583) -> Result<(), EndpointProbeError> {
584 loop {
585 match writer.flush() {
586 Ok(()) => return Ok(()),
587 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
588 Err(err) => return Err(EndpointProbeError::Io(err)),
589 }
590 }
591}
592
593fn wait_for_io(deadline: Instant) -> Result<(), EndpointProbeError> {
594 if Instant::now() >= deadline {
595 return Err(EndpointProbeError::Timeout);
596 }
597 let remaining = deadline.saturating_duration_since(Instant::now());
598 thread::sleep(remaining.min(NONBLOCKING_POLL_INTERVAL));
599 Ok(())
600}
601
/// Converts an endpoint path string into a local-socket name for the current
/// platform.
///
/// On Unix the path is interpreted as a filesystem path (`GenericFilePath`); on
/// Windows it is interpreted as a namespaced name (`GenericNamespaced`). The
/// returned name borrows from `path`.
///
/// # Errors
///
/// Returns the `io::Error` produced by the underlying name conversion.
fn endpoint_name(path: &str) -> io::Result<interprocess::local_socket::Name<'_>> {
    #[cfg(unix)]
    {
        use interprocess::local_socket::GenericFilePath;
        path.to_fs_name::<GenericFilePath>()
    }

    #[cfg(windows)]
    {
        use interprocess::local_socket::GenericNamespaced;
        path.to_ns_name::<GenericNamespaced>()
    }
}