1use std::io::{self, Read, Write};
4use std::thread;
5use std::time::{Duration, Instant};
6
7use interprocess::local_socket::prelude::*;
8use prost::Message;
9
10use crate::broker::backend_lifecycle::identity::{DaemonProcess, IdentityError};
11use crate::broker::backend_lifecycle::verify_pid::{self, ProcessHandle, VerifyPidError};
12use crate::broker::protocol::{
13 self, read_frame, write_frame, Endpoint, Frame, FrameKind, FramingError, PayloadEncoding,
14 ENVELOPE_VERSION, MAX_FRAME_BYTES,
15};
16
/// Value written to, and required in, `Frame::envelope_version` for probe frames.
const PROTOCOL_VERSION: u32 = 1;
/// Length in bytes of the random nonce carried by a probe request and echoed
/// at the start of the response payload.
const PROBE_NONCE_BYTES: usize = 32;
/// Upper bound on a single sleep while polling the nonblocking probe stream.
const NONBLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(5);

/// `payload_protocol` value that marks a frame as an endpoint probe.
pub const BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL: u32 = 0xB232;

/// Timeout applied by [`probe_endpoint_response`].
pub const DEFAULT_ENDPOINT_PROBE_TIMEOUT: Duration = Duration::from_millis(500);
26
27pub fn probe_endpoint(
29 endpoint: &Endpoint,
30 expected: &DaemonProcess,
31) -> Result<ProcessHandle, ProbeError> {
32 if !same_endpoint(endpoint, &expected.ipc_endpoint) {
33 return Err(ProbeError::EndpointMismatch);
34 }
35 let process_handle =
36 verify_pid::verify_daemon_process(expected).map_err(ProbeError::VerifyPid)?;
37 probe_endpoint_response(endpoint, expected)?;
38 Ok(process_handle)
39}
40
41pub fn same_endpoint(left: &Endpoint, right: &Endpoint) -> bool {
43 left.namespace_id == right.namespace_id && left.path == right.path
44}
45
46pub fn probe_endpoint_response(
53 endpoint: &Endpoint,
54 expected: &DaemonProcess,
55) -> Result<(), EndpointProbeError> {
56 probe_endpoint_response_with_timeout(endpoint, expected, DEFAULT_ENDPOINT_PROBE_TIMEOUT)
57}
58
59pub fn probe_endpoint_response_with_timeout(
61 endpoint: &Endpoint,
62 expected: &DaemonProcess,
63 timeout: Duration,
64) -> Result<(), EndpointProbeError> {
65 let mut nonce = [0_u8; PROBE_NONCE_BYTES];
66 getrandom::fill(&mut nonce).map_err(EndpointProbeError::Random)?;
67 let request_id = u64::from_le_bytes(nonce[..8].try_into().expect("nonce has 8 bytes"));
68 let request_frame = endpoint_probe_request_frame(request_id, &nonce);
69 let mut request_bytes = Vec::new();
70 request_frame
71 .encode(&mut request_bytes)
72 .map_err(EndpointProbeError::EncodeFrame)?;
73
74 let deadline = Instant::now() + timeout;
75 let mut stream = connect_endpoint(endpoint)?;
76 stream
77 .set_nonblocking(true)
78 .map_err(EndpointProbeError::ConfigureNonblocking)?;
79 write_probe_frame_with_deadline(&mut stream, &request_bytes, deadline)?;
80
81 let response_bytes = read_probe_frame_with_deadline(&mut stream, deadline)?;
82 let response_frame =
83 Frame::decode(response_bytes.as_slice()).map_err(EndpointProbeError::DecodeFrame)?;
84 validate_endpoint_probe_response_frame(&response_frame, request_id)?;
85 let actual = decode_response_identity(&response_frame.payload, &nonce)?;
86 if !same_daemon_identity(&actual, expected) {
87 return Err(identity_mismatch(expected, &actual));
88 }
89 Ok(())
90}
91
/// Server-side view of a validated endpoint probe request.
///
/// Produced by [`read_endpoint_probe_request`] and consumed by
/// [`write_endpoint_probe_response`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointProbeRequest {
    /// `request_id` copied from the request frame; echoed in the response frame.
    pub request_id: u64,
    /// Nonce taken from the request payload; echoed at the start of the
    /// response payload.
    pub nonce: [u8; PROBE_NONCE_BYTES],
    /// `traceparent` copied from the request frame; echoed in the response frame.
    pub traceparent: String,
    /// `tracestate` copied from the request frame; echoed in the response frame.
    pub tracestate: String,
}
104
105pub fn read_endpoint_probe_request<S: Read>(
107 stream: &mut S,
108) -> Result<EndpointProbeRequest, EndpointProbeServerError> {
109 let request_bytes = read_frame(stream)?;
110 let frame =
111 Frame::decode(request_bytes.as_slice()).map_err(EndpointProbeServerError::DecodeFrame)?;
112 validate_endpoint_probe_request_frame(&frame)?;
113 let nonce = frame
114 .payload
115 .as_slice()
116 .try_into()
117 .map_err(|_| EndpointProbeServerError::MalformedPayload("nonce must be 32 bytes"))?;
118 Ok(EndpointProbeRequest {
119 request_id: frame.request_id,
120 nonce,
121 traceparent: frame.traceparent,
122 tracestate: frame.tracestate,
123 })
124}
125
126pub fn write_endpoint_probe_response<S: Write>(
128 stream: &mut S,
129 request: &EndpointProbeRequest,
130 daemon: &DaemonProcess,
131) -> Result<(), EndpointProbeServerError> {
132 let response_frame = endpoint_probe_response_frame(request, daemon);
133 let mut response_bytes = Vec::new();
134 response_frame
135 .encode(&mut response_bytes)
136 .map_err(EndpointProbeServerError::EncodeFrame)?;
137 write_frame(stream, &response_bytes)?;
138 Ok(())
139}
140
141pub fn handle_endpoint_probe<S: Read + Write>(
143 stream: &mut S,
144 daemon: &DaemonProcess,
145) -> Result<(), EndpointProbeServerError> {
146 let request = read_endpoint_probe_request(stream)?;
147 write_endpoint_probe_response(stream, &request, daemon)
148}
149
/// Failure of the combined probe performed by [`probe_endpoint`].
#[derive(Debug, thiserror::Error)]
pub enum ProbeError {
    /// The probed endpoint differs from the endpoint recorded in the expected
    /// daemon identity.
    #[error("endpoint does not match expected daemon identity")]
    EndpointMismatch,
    /// The endpoint round trip failed; displays as the wrapped error.
    #[error(transparent)]
    EndpointResponse(#[from] EndpointProbeError),
    /// `verify_pid::verify_daemon_process` rejected the expected daemon;
    /// displays as the wrapped error.
    #[error(transparent)]
    VerifyPid(#[from] VerifyPidError),
}
163
/// Client-side failure of an endpoint probe round trip.
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeError {
    /// Generating the random nonce failed.
    #[error("backend endpoint probe random generation failed: {0}")]
    Random(getrandom::Error),
    /// The endpoint path could not be turned into a local-socket name.
    #[error("backend endpoint probe local-socket name failed: {0}")]
    LocalSocketName(io::Error),
    /// The endpoint path was empty or connecting to the local socket failed.
    #[error("backend endpoint probe connect failed: {0}")]
    Connect(io::Error),
    /// Switching the connected stream to nonblocking mode failed.
    #[error("backend endpoint probe nonblocking setup failed: {0}")]
    ConfigureNonblocking(io::Error),
    /// The probe deadline passed while waiting for the stream.
    #[error("backend endpoint probe timed out")]
    Timeout,
    /// A read, write or flush failed with something other than `WouldBlock`,
    /// or a write reported zero bytes written.
    #[error("backend endpoint probe I/O failed: {0}")]
    Io(io::Error),
    /// The first byte of the response was not the expected framing version.
    #[error("backend endpoint probe unsupported framing version: got {got}, expected {expected}")]
    UnsupportedFramingVersion {
        /// Version byte actually received.
        got: u8,
        /// Version byte this client requires.
        expected: u8,
    },
    /// An outgoing or announced incoming frame body exceeds the frame cap.
    #[error("backend endpoint probe frame body too large: {body_length} bytes exceeds cap {cap}")]
    FrameTooLarge {
        /// Offending body length in bytes.
        body_length: usize,
        /// Maximum body length allowed.
        cap: usize,
    },
    /// Encoding the request frame failed.
    #[error("failed to encode endpoint probe frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The response body could not be decoded as a `Frame`.
    #[error("failed to decode endpoint probe response Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// The response frame had unexpected header fields, or echoed the wrong
    /// nonce; the message names the failed check.
    #[error("unexpected endpoint probe response: {0}")]
    UnexpectedFrame(&'static str),
    /// The response payload was too short to contain the nonce.
    #[error("endpoint probe response payload is malformed: {0}")]
    MalformedPayload(&'static str),
    /// The identity bytes after the nonce could not be decoded.
    #[error("failed to decode endpoint probe daemon identity: {0}")]
    DecodeDaemonProcess(prost::DecodeError),
    /// The decoded identity message could not be converted into a
    /// `DaemonProcess`; displays as the wrapped error.
    #[error(transparent)]
    Identity(#[from] IdentityError),
    /// The responder reported an identity that differs from the expected one.
    #[error("endpoint probe response identity did not match expected daemon identity: {field}")]
    IdentityMismatch {
        /// Name of the first differing identity field that was found.
        field: &'static str,
    },
}
226
/// Server-side failure while reading a probe request or writing its response.
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeServerError {
    /// Reading or writing the framed bytes failed; displays as the wrapped error.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// The request body could not be decoded as a `Frame`.
    #[error("failed to decode endpoint probe request Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// Encoding the response frame failed.
    #[error("failed to encode endpoint probe response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The request frame had unexpected header fields; the message names the
    /// failed check.
    #[error("unexpected endpoint probe request: {0}")]
    UnexpectedFrame(&'static str),
    /// The request payload was not a nonce of the required length.
    #[error("endpoint probe request payload is malformed: {0}")]
    MalformedPayload(&'static str),
}
246
247fn endpoint_probe_request_frame(request_id: u64, nonce: &[u8; PROBE_NONCE_BYTES]) -> Frame {
248 Frame {
249 envelope_version: PROTOCOL_VERSION,
250 kind: FrameKind::Request as i32,
251 payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
252 payload: nonce.to_vec(),
253 request_id,
254 payload_encoding: PayloadEncoding::None as i32,
255 deadline_unix_ms: 0,
256 traceparent: String::new(),
257 tracestate: String::new(),
258 }
259}
260
261fn endpoint_probe_response_frame(request: &EndpointProbeRequest, daemon: &DaemonProcess) -> Frame {
262 let mut payload = Vec::with_capacity(PROBE_NONCE_BYTES + 128);
263 payload.extend_from_slice(&request.nonce);
264 daemon.to_proto().encode(&mut payload).expect(
265 "prost encoding DaemonProcess into Vec cannot fail because Vec writes are infallible",
266 );
267
268 Frame {
269 envelope_version: PROTOCOL_VERSION,
270 kind: FrameKind::Response as i32,
271 payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
272 payload,
273 request_id: request.request_id,
274 payload_encoding: PayloadEncoding::None as i32,
275 deadline_unix_ms: 0,
276 traceparent: request.traceparent.clone(),
277 tracestate: request.tracestate.clone(),
278 }
279}
280
281fn validate_endpoint_probe_request_frame(frame: &Frame) -> Result<(), EndpointProbeServerError> {
282 if frame.envelope_version != PROTOCOL_VERSION {
283 return Err(EndpointProbeServerError::UnexpectedFrame(
284 "envelope_version is not v1",
285 ));
286 }
287 if FrameKind::try_from(frame.kind) != Ok(FrameKind::Request) {
288 return Err(EndpointProbeServerError::UnexpectedFrame(
289 "kind is not REQUEST",
290 ));
291 }
292 if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
293 return Err(EndpointProbeServerError::UnexpectedFrame(
294 "payload_protocol is not endpoint probe",
295 ));
296 }
297 if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
298 return Err(EndpointProbeServerError::UnexpectedFrame(
299 "payload is compressed",
300 ));
301 }
302 if frame.payload.len() != PROBE_NONCE_BYTES {
303 return Err(EndpointProbeServerError::MalformedPayload(
304 "nonce must be 32 bytes",
305 ));
306 }
307 Ok(())
308}
309
310fn validate_endpoint_probe_response_frame(
311 frame: &Frame,
312 request_id: u64,
313) -> Result<(), EndpointProbeError> {
314 if frame.envelope_version != PROTOCOL_VERSION {
315 return Err(EndpointProbeError::UnexpectedFrame(
316 "envelope_version is not v1",
317 ));
318 }
319 if FrameKind::try_from(frame.kind) != Ok(FrameKind::Response) {
320 return Err(EndpointProbeError::UnexpectedFrame("kind is not RESPONSE"));
321 }
322 if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
323 return Err(EndpointProbeError::UnexpectedFrame(
324 "payload_protocol is not endpoint probe",
325 ));
326 }
327 if frame.request_id != request_id {
328 return Err(EndpointProbeError::UnexpectedFrame(
329 "request_id does not match endpoint probe request",
330 ));
331 }
332 if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
333 return Err(EndpointProbeError::UnexpectedFrame("payload is compressed"));
334 }
335 Ok(())
336}
337
338fn decode_response_identity(
339 payload: &[u8],
340 expected_nonce: &[u8; PROBE_NONCE_BYTES],
341) -> Result<DaemonProcess, EndpointProbeError> {
342 if payload.len() < PROBE_NONCE_BYTES {
343 return Err(EndpointProbeError::MalformedPayload(
344 "payload shorter than nonce",
345 ));
346 }
347 let (nonce, identity_bytes) = payload.split_at(PROBE_NONCE_BYTES);
348 if nonce != expected_nonce {
349 return Err(EndpointProbeError::UnexpectedFrame(
350 "nonce does not match endpoint probe request",
351 ));
352 }
353 let proto_identity = protocol::DaemonProcess::decode(identity_bytes)
354 .map_err(EndpointProbeError::DecodeDaemonProcess)?;
355 DaemonProcess::try_from(proto_identity).map_err(EndpointProbeError::Identity)
356}
357
358fn identity_mismatch(expected: &DaemonProcess, actual: &DaemonProcess) -> EndpointProbeError {
359 let field = if actual.pid != expected.pid {
360 "pid"
361 } else if actual.exe_path != expected.exe_path {
362 "exe_path"
363 } else if actual.exe_sha256 != expected.exe_sha256 {
364 "exe_sha256"
365 } else if actual.boot_id != expected.boot_id {
366 "boot_id"
367 } else if !same_endpoint(&actual.ipc_endpoint, &expected.ipc_endpoint) {
368 "ipc_endpoint"
369 } else {
370 "unknown"
371 };
372 EndpointProbeError::IdentityMismatch { field }
373}
374
375fn same_daemon_identity(left: &DaemonProcess, right: &DaemonProcess) -> bool {
376 left.pid == right.pid
377 && left.exe_path == right.exe_path
378 && left.exe_sha256 == right.exe_sha256
379 && left.boot_id == right.boot_id
380 && same_endpoint(&left.ipc_endpoint, &right.ipc_endpoint)
381}
382
383fn connect_endpoint(
384 endpoint: &Endpoint,
385) -> Result<interprocess::local_socket::Stream, EndpointProbeError> {
386 if endpoint.path.is_empty() {
387 return Err(EndpointProbeError::Connect(io::Error::new(
388 io::ErrorKind::InvalidInput,
389 "backend endpoint path is empty",
390 )));
391 }
392 let name = endpoint_name(&endpoint.path).map_err(EndpointProbeError::LocalSocketName)?;
393 interprocess::local_socket::Stream::connect(name).map_err(EndpointProbeError::Connect)
394}
395
396fn write_probe_frame_with_deadline(
397 stream: &mut interprocess::local_socket::Stream,
398 body: &[u8],
399 deadline: Instant,
400) -> Result<(), EndpointProbeError> {
401 if body.len() > MAX_FRAME_BYTES {
402 return Err(EndpointProbeError::FrameTooLarge {
403 body_length: body.len(),
404 cap: MAX_FRAME_BYTES,
405 });
406 }
407 let mut wire = Vec::with_capacity(1 + 4 + body.len());
408 wire.push(ENVELOPE_VERSION);
409 wire.extend_from_slice(&(body.len() as u32).to_le_bytes());
410 wire.extend_from_slice(body);
411 write_all_with_deadline(stream, &wire, deadline)?;
412 flush_with_deadline(stream, deadline)
413}
414
415fn read_probe_frame_with_deadline(
416 stream: &mut interprocess::local_socket::Stream,
417 deadline: Instant,
418) -> Result<Vec<u8>, EndpointProbeError> {
419 let mut version = [0_u8; 1];
420 read_exact_with_deadline(stream, &mut version, deadline)?;
421 if version[0] != ENVELOPE_VERSION {
422 return Err(EndpointProbeError::UnsupportedFramingVersion {
423 got: version[0],
424 expected: ENVELOPE_VERSION,
425 });
426 }
427
428 let mut len = [0_u8; 4];
429 read_exact_with_deadline(stream, &mut len, deadline)?;
430 let body_length = u32::from_le_bytes(len) as usize;
431 if body_length > MAX_FRAME_BYTES {
432 return Err(EndpointProbeError::FrameTooLarge {
433 body_length,
434 cap: MAX_FRAME_BYTES,
435 });
436 }
437
438 let mut body = vec![0_u8; body_length];
439 if body_length > 0 {
440 read_exact_with_deadline(stream, &mut body, deadline)?;
441 }
442 Ok(body)
443}
444
445fn write_all_with_deadline<W: Write>(
446 writer: &mut W,
447 mut buf: &[u8],
448 deadline: Instant,
449) -> Result<(), EndpointProbeError> {
450 while !buf.is_empty() {
451 match writer.write(buf) {
452 Ok(0) => {
453 return Err(EndpointProbeError::Io(io::Error::new(
454 io::ErrorKind::WriteZero,
455 "endpoint probe write returned zero bytes",
456 )));
457 }
458 Ok(written) => buf = &buf[written..],
459 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
460 Err(err) => return Err(EndpointProbeError::Io(err)),
461 }
462 }
463 Ok(())
464}
465
466fn read_exact_with_deadline<R: Read>(
467 reader: &mut R,
468 mut buf: &mut [u8],
469 deadline: Instant,
470) -> Result<(), EndpointProbeError> {
471 while !buf.is_empty() {
472 match reader.read(buf) {
473 Ok(0) => wait_for_io(deadline)?,
474 Ok(read) => {
475 let tmp = buf;
476 buf = &mut tmp[read..];
477 }
478 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
479 Err(err) => return Err(EndpointProbeError::Io(err)),
480 }
481 }
482 Ok(())
483}
484
485fn flush_with_deadline<W: Write>(
486 writer: &mut W,
487 deadline: Instant,
488) -> Result<(), EndpointProbeError> {
489 loop {
490 match writer.flush() {
491 Ok(()) => return Ok(()),
492 Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
493 Err(err) => return Err(EndpointProbeError::Io(err)),
494 }
495 }
496}
497
498fn wait_for_io(deadline: Instant) -> Result<(), EndpointProbeError> {
499 if Instant::now() >= deadline {
500 return Err(EndpointProbeError::Timeout);
501 }
502 let remaining = deadline.saturating_duration_since(Instant::now());
503 thread::sleep(remaining.min(NONBLOCKING_POLL_INTERVAL));
504 Ok(())
505}
506
507fn endpoint_name(path: &str) -> io::Result<interprocess::local_socket::Name<'_>> {
508 #[cfg(unix)]
509 {
510 use interprocess::local_socket::GenericFilePath;
511 path.to_fs_name::<GenericFilePath>()
512 }
513
514 #[cfg(windows)]
515 {
516 use interprocess::local_socket::GenericNamespaced;
517 path.to_ns_name::<GenericNamespaced>()
518 }
519}