Skip to main content

running_process/broker/backend_lifecycle/
probe.rs

1//! Endpoint and process identity checks for backend handles.
2
3use std::io::{self, Read, Write};
4use std::thread;
5use std::time::{Duration, Instant};
6
7use interprocess::local_socket::prelude::*;
8use prost::Message;
9
10use crate::broker::backend_lifecycle::identity::{DaemonProcess, IdentityError};
11use crate::broker::backend_lifecycle::verify_pid::{self, ProcessHandle, VerifyPidError};
12use crate::broker::protocol::{
13    self, read_frame, write_frame, Endpoint, Frame, FrameKind, FramingError, PayloadEncoding,
14    ENVELOPE_VERSION, MAX_FRAME_BYTES,
15};
16
/// `Frame::envelope_version` value carried by probe requests and required of
/// probe responses. Distinct from `ENVELOPE_VERSION`, the single framing byte
/// written before each frame on the wire.
const PROTOCOL_VERSION: u32 = 1;
/// Length in bytes of the random challenge sent in each probe request and
/// echoed at the start of each response payload.
const PROBE_NONCE_BYTES: usize = 32;
/// Upper bound on each sleep while waiting for a nonblocking probe stream.
const NONBLOCKING_POLL_INTERVAL: Duration = Duration::from_millis(5);

/// Payload protocol reserved for `BackendHandle` endpoint identity probes.
pub const BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL: u32 = 0xB232;

/// Default deadline for the active endpoint-response proof.
///
/// Used by [`probe_endpoint_response`], and therefore by [`probe_endpoint`].
pub const DEFAULT_ENDPOINT_PROBE_TIMEOUT: Duration = Duration::from_millis(500);
26
27/// Verify that an endpoint refers to the expected daemon process.
28pub fn probe_endpoint(
29    endpoint: &Endpoint,
30    expected: &DaemonProcess,
31) -> Result<ProcessHandle, ProbeError> {
32    if !same_endpoint(endpoint, &expected.ipc_endpoint) {
33        return Err(ProbeError::EndpointMismatch);
34    }
35    let process_handle =
36        verify_pid::verify_daemon_process(expected).map_err(ProbeError::VerifyPid)?;
37    probe_endpoint_response(endpoint, expected)?;
38    Ok(process_handle)
39}
40
41/// Compare two endpoint identities exactly.
42pub fn same_endpoint(left: &Endpoint, right: &Endpoint) -> bool {
43    left.namespace_id == right.namespace_id && left.path == right.path
44}
45
46/// Actively probe a backend endpoint and verify that it returns the expected
47/// daemon identity.
48///
49/// The probe uses the broker v1 frame layout with a dedicated payload protocol.
50/// Requests carry a 32-byte nonce. Responses must echo that nonce and include a
51/// prost-encoded `DaemonProcess` payload that exactly matches `expected`.
52pub fn probe_endpoint_response(
53    endpoint: &Endpoint,
54    expected: &DaemonProcess,
55) -> Result<(), EndpointProbeError> {
56    probe_endpoint_response_with_timeout(endpoint, expected, DEFAULT_ENDPOINT_PROBE_TIMEOUT)
57}
58
59/// Timed variant of [`probe_endpoint_response`] used by tests and diagnostics.
60pub fn probe_endpoint_response_with_timeout(
61    endpoint: &Endpoint,
62    expected: &DaemonProcess,
63    timeout: Duration,
64) -> Result<(), EndpointProbeError> {
65    let mut nonce = [0_u8; PROBE_NONCE_BYTES];
66    getrandom::fill(&mut nonce).map_err(EndpointProbeError::Random)?;
67    let request_id = u64::from_le_bytes(nonce[..8].try_into().expect("nonce has 8 bytes"));
68    let request_frame = endpoint_probe_request_frame(request_id, &nonce);
69    let mut request_bytes = Vec::new();
70    request_frame
71        .encode(&mut request_bytes)
72        .map_err(EndpointProbeError::EncodeFrame)?;
73
74    let deadline = Instant::now() + timeout;
75    let mut stream = connect_endpoint(endpoint)?;
76    stream
77        .set_nonblocking(true)
78        .map_err(EndpointProbeError::ConfigureNonblocking)?;
79    write_probe_frame_with_deadline(&mut stream, &request_bytes, deadline)?;
80
81    let response_bytes = read_probe_frame_with_deadline(&mut stream, deadline)?;
82    let response_frame =
83        Frame::decode(response_bytes.as_slice()).map_err(EndpointProbeError::DecodeFrame)?;
84    validate_endpoint_probe_response_frame(&response_frame, request_id)?;
85    let actual = decode_response_identity(&response_frame.payload, &nonce)?;
86    if !same_daemon_identity(&actual, expected) {
87        return Err(identity_mismatch(expected, &actual));
88    }
89    Ok(())
90}
91
/// Decoded endpoint probe request for backend-side responders.
///
/// [`read_endpoint_probe_request`] builds this after the request frame passes
/// validation. [`write_endpoint_probe_response`] consumes it and copies
/// `request_id`, `nonce`, and the trace fields into the response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointProbeRequest {
    /// Request frame ID that the response must echo.
    pub request_id: u64,
    /// Random challenge that the response must echo.
    ///
    /// Requests whose payload is not exactly this many bytes are rejected
    /// before this struct is built.
    pub nonce: [u8; PROBE_NONCE_BYTES],
    /// Trace context copied from the request frame, if any.
    ///
    /// Copied verbatim; an empty string when the request carried none.
    pub traceparent: String,
    /// Trace state copied from the request frame, if any.
    ///
    /// Copied verbatim; an empty string when the request carried none.
    pub tracestate: String,
}
104
105/// Read and validate one endpoint probe request from an accepted IPC stream.
106pub fn read_endpoint_probe_request<S: Read>(
107    stream: &mut S,
108) -> Result<EndpointProbeRequest, EndpointProbeServerError> {
109    let request_bytes = read_frame(stream)?;
110    let frame =
111        Frame::decode(request_bytes.as_slice()).map_err(EndpointProbeServerError::DecodeFrame)?;
112    validate_endpoint_probe_request_frame(&frame)?;
113    let nonce = frame
114        .payload
115        .as_slice()
116        .try_into()
117        .map_err(|_| EndpointProbeServerError::MalformedPayload("nonce must be 32 bytes"))?;
118    Ok(EndpointProbeRequest {
119        request_id: frame.request_id,
120        nonce,
121        traceparent: frame.traceparent,
122        tracestate: frame.tracestate,
123    })
124}
125
126/// Write one endpoint probe response for a validated request.
127pub fn write_endpoint_probe_response<S: Write>(
128    stream: &mut S,
129    request: &EndpointProbeRequest,
130    daemon: &DaemonProcess,
131) -> Result<(), EndpointProbeServerError> {
132    let response_frame = endpoint_probe_response_frame(request, daemon);
133    let mut response_bytes = Vec::new();
134    response_frame
135        .encode(&mut response_bytes)
136        .map_err(EndpointProbeServerError::EncodeFrame)?;
137    write_frame(stream, &response_bytes)?;
138    Ok(())
139}
140
141/// Serve exactly one endpoint probe request on an already-accepted IPC stream.
142pub fn handle_endpoint_probe<S: Read + Write>(
143    stream: &mut S,
144    daemon: &DaemonProcess,
145) -> Result<(), EndpointProbeServerError> {
146    let request = read_endpoint_probe_request(stream)?;
147    write_endpoint_probe_response(stream, &request, daemon)
148}
149
/// Errors returned while probing a backend endpoint.
///
/// Returned by [`probe_endpoint`], whose checks run in this order:
/// endpoint comparison ([`ProbeError::EndpointMismatch`]), process
/// verification ([`ProbeError::VerifyPid`]), then the active response probe
/// ([`ProbeError::EndpointResponse`]).
#[derive(Debug, thiserror::Error)]
pub enum ProbeError {
    /// The caller-provided endpoint did not match the expected daemon endpoint.
    #[error("endpoint does not match expected daemon identity")]
    EndpointMismatch,
    /// The endpoint did not answer the active identity probe as expected.
    #[error(transparent)]
    EndpointResponse(#[from] EndpointProbeError),
    /// The daemon process identity could not be verified.
    #[error(transparent)]
    VerifyPid(#[from] VerifyPidError),
}
163
/// Errors returned by the active endpoint-response probe.
///
/// Produced by [`probe_endpoint_response`] and
/// [`probe_endpoint_response_with_timeout`]. [`probe_endpoint`] wraps them in
/// [`ProbeError::EndpointResponse`].
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeError {
    /// The probe nonce could not be generated.
    #[error("backend endpoint probe random generation failed: {0}")]
    Random(getrandom::Error),
    /// The endpoint path/name could not be converted to a local socket name.
    #[error("backend endpoint probe local-socket name failed: {0}")]
    LocalSocketName(io::Error),
    /// Connecting to the endpoint failed.
    ///
    /// Also returned, with `ErrorKind::InvalidInput`, when the endpoint path is
    /// empty; in that case no connection is attempted.
    #[error("backend endpoint probe connect failed: {0}")]
    Connect(io::Error),
    /// The stream could not be switched to nonblocking mode for deadline I/O.
    #[error("backend endpoint probe nonblocking setup failed: {0}")]
    ConfigureNonblocking(io::Error),
    /// Probe I/O exceeded the configured deadline.
    ///
    /// Raised when a read, write, or flush still cannot make progress once the
    /// deadline has passed.
    #[error("backend endpoint probe timed out")]
    Timeout,
    /// Raw probe I/O failed.
    ///
    /// Includes a write that accepted zero bytes (`ErrorKind::WriteZero`).
    #[error("backend endpoint probe I/O failed: {0}")]
    Io(io::Error),
    /// The peer used the wrong broker framing byte.
    #[error("backend endpoint probe unsupported framing version: got {got}, expected {expected}")]
    UnsupportedFramingVersion {
        /// Framing byte received from the peer.
        got: u8,
        /// Framing byte expected by v1.
        expected: u8,
    },
    /// The peer advertised a frame that exceeds the v1 frame cap.
    ///
    /// Also returned before sending if the outbound request itself exceeds the
    /// cap.
    #[error("backend endpoint probe frame body too large: {body_length} bytes exceeds cap {cap}")]
    FrameTooLarge {
        /// Advertised frame body length.
        body_length: usize,
        /// Maximum accepted frame body length.
        cap: usize,
    },
    /// The outbound probe request frame could not be encoded.
    #[error("failed to encode endpoint probe frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The response frame could not be decoded.
    #[error("failed to decode endpoint probe response Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// The response frame did not match the endpoint-probe contract.
    ///
    /// Also used when the echoed nonce differs from the one that was sent.
    #[error("unexpected endpoint probe response: {0}")]
    UnexpectedFrame(&'static str),
    /// The response payload did not match the endpoint-probe contract.
    #[error("endpoint probe response payload is malformed: {0}")]
    MalformedPayload(&'static str),
    /// The response daemon identity could not be decoded.
    #[error("failed to decode endpoint probe daemon identity: {0}")]
    DecodeDaemonProcess(prost::DecodeError),
    /// The response daemon identity was malformed.
    #[error(transparent)]
    Identity(#[from] IdentityError),
    /// The response daemon identity did not match the expected identity.
    #[error("endpoint probe response identity did not match expected daemon identity: {field}")]
    IdentityMismatch {
        /// First mismatched identity field.
        field: &'static str,
    },
}
226
/// Errors returned by backend-side endpoint probe responders.
///
/// Produced by [`read_endpoint_probe_request`],
/// [`write_endpoint_probe_response`], and [`handle_endpoint_probe`].
#[derive(Debug, thiserror::Error)]
pub enum EndpointProbeServerError {
    /// v1 framing failed.
    ///
    /// Wraps errors from the shared `read_frame` / `write_frame` helpers.
    #[error(transparent)]
    Framing(#[from] FramingError),
    /// The request frame could not be decoded.
    #[error("failed to decode endpoint probe request Frame: {0}")]
    DecodeFrame(prost::DecodeError),
    /// The response frame could not be encoded.
    #[error("failed to encode endpoint probe response Frame: {0}")]
    EncodeFrame(prost::EncodeError),
    /// The request frame did not match the endpoint-probe contract.
    #[error("unexpected endpoint probe request: {0}")]
    UnexpectedFrame(&'static str),
    /// The request payload did not match the endpoint-probe contract.
    ///
    /// Raised when the nonce payload is not exactly 32 bytes.
    #[error("endpoint probe request payload is malformed: {0}")]
    MalformedPayload(&'static str),
}
246
247fn endpoint_probe_request_frame(request_id: u64, nonce: &[u8; PROBE_NONCE_BYTES]) -> Frame {
248    Frame {
249        envelope_version: PROTOCOL_VERSION,
250        kind: FrameKind::Request as i32,
251        payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
252        payload: nonce.to_vec(),
253        request_id,
254        payload_encoding: PayloadEncoding::None as i32,
255        deadline_unix_ms: 0,
256        traceparent: String::new(),
257        tracestate: String::new(),
258    }
259}
260
261fn endpoint_probe_response_frame(request: &EndpointProbeRequest, daemon: &DaemonProcess) -> Frame {
262    let mut payload = Vec::with_capacity(PROBE_NONCE_BYTES + 128);
263    payload.extend_from_slice(&request.nonce);
264    daemon.to_proto().encode(&mut payload).expect(
265        "prost encoding DaemonProcess into Vec cannot fail because Vec writes are infallible",
266    );
267
268    Frame {
269        envelope_version: PROTOCOL_VERSION,
270        kind: FrameKind::Response as i32,
271        payload_protocol: BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL,
272        payload,
273        request_id: request.request_id,
274        payload_encoding: PayloadEncoding::None as i32,
275        deadline_unix_ms: 0,
276        traceparent: request.traceparent.clone(),
277        tracestate: request.tracestate.clone(),
278    }
279}
280
281fn validate_endpoint_probe_request_frame(frame: &Frame) -> Result<(), EndpointProbeServerError> {
282    if frame.envelope_version != PROTOCOL_VERSION {
283        return Err(EndpointProbeServerError::UnexpectedFrame(
284            "envelope_version is not v1",
285        ));
286    }
287    if FrameKind::try_from(frame.kind) != Ok(FrameKind::Request) {
288        return Err(EndpointProbeServerError::UnexpectedFrame(
289            "kind is not REQUEST",
290        ));
291    }
292    if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
293        return Err(EndpointProbeServerError::UnexpectedFrame(
294            "payload_protocol is not endpoint probe",
295        ));
296    }
297    if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
298        return Err(EndpointProbeServerError::UnexpectedFrame(
299            "payload is compressed",
300        ));
301    }
302    if frame.payload.len() != PROBE_NONCE_BYTES {
303        return Err(EndpointProbeServerError::MalformedPayload(
304            "nonce must be 32 bytes",
305        ));
306    }
307    Ok(())
308}
309
310fn validate_endpoint_probe_response_frame(
311    frame: &Frame,
312    request_id: u64,
313) -> Result<(), EndpointProbeError> {
314    if frame.envelope_version != PROTOCOL_VERSION {
315        return Err(EndpointProbeError::UnexpectedFrame(
316            "envelope_version is not v1",
317        ));
318    }
319    if FrameKind::try_from(frame.kind) != Ok(FrameKind::Response) {
320        return Err(EndpointProbeError::UnexpectedFrame("kind is not RESPONSE"));
321    }
322    if frame.payload_protocol != BACKEND_HANDLE_PROBE_PAYLOAD_PROTOCOL {
323        return Err(EndpointProbeError::UnexpectedFrame(
324            "payload_protocol is not endpoint probe",
325        ));
326    }
327    if frame.request_id != request_id {
328        return Err(EndpointProbeError::UnexpectedFrame(
329            "request_id does not match endpoint probe request",
330        ));
331    }
332    if PayloadEncoding::try_from(frame.payload_encoding) != Ok(PayloadEncoding::None) {
333        return Err(EndpointProbeError::UnexpectedFrame("payload is compressed"));
334    }
335    Ok(())
336}
337
338fn decode_response_identity(
339    payload: &[u8],
340    expected_nonce: &[u8; PROBE_NONCE_BYTES],
341) -> Result<DaemonProcess, EndpointProbeError> {
342    if payload.len() < PROBE_NONCE_BYTES {
343        return Err(EndpointProbeError::MalformedPayload(
344            "payload shorter than nonce",
345        ));
346    }
347    let (nonce, identity_bytes) = payload.split_at(PROBE_NONCE_BYTES);
348    if nonce != expected_nonce {
349        return Err(EndpointProbeError::UnexpectedFrame(
350            "nonce does not match endpoint probe request",
351        ));
352    }
353    let proto_identity = protocol::DaemonProcess::decode(identity_bytes)
354        .map_err(EndpointProbeError::DecodeDaemonProcess)?;
355    DaemonProcess::try_from(proto_identity).map_err(EndpointProbeError::Identity)
356}
357
358fn identity_mismatch(expected: &DaemonProcess, actual: &DaemonProcess) -> EndpointProbeError {
359    let field = if actual.pid != expected.pid {
360        "pid"
361    } else if actual.exe_path != expected.exe_path {
362        "exe_path"
363    } else if actual.exe_sha256 != expected.exe_sha256 {
364        "exe_sha256"
365    } else if actual.boot_id != expected.boot_id {
366        "boot_id"
367    } else if !same_endpoint(&actual.ipc_endpoint, &expected.ipc_endpoint) {
368        "ipc_endpoint"
369    } else {
370        "unknown"
371    };
372    EndpointProbeError::IdentityMismatch { field }
373}
374
375fn same_daemon_identity(left: &DaemonProcess, right: &DaemonProcess) -> bool {
376    left.pid == right.pid
377        && left.exe_path == right.exe_path
378        && left.exe_sha256 == right.exe_sha256
379        && left.boot_id == right.boot_id
380        && same_endpoint(&left.ipc_endpoint, &right.ipc_endpoint)
381}
382
383fn connect_endpoint(
384    endpoint: &Endpoint,
385) -> Result<interprocess::local_socket::Stream, EndpointProbeError> {
386    if endpoint.path.is_empty() {
387        return Err(EndpointProbeError::Connect(io::Error::new(
388            io::ErrorKind::InvalidInput,
389            "backend endpoint path is empty",
390        )));
391    }
392    let name = endpoint_name(&endpoint.path).map_err(EndpointProbeError::LocalSocketName)?;
393    interprocess::local_socket::Stream::connect(name).map_err(EndpointProbeError::Connect)
394}
395
396fn write_probe_frame_with_deadline(
397    stream: &mut interprocess::local_socket::Stream,
398    body: &[u8],
399    deadline: Instant,
400) -> Result<(), EndpointProbeError> {
401    if body.len() > MAX_FRAME_BYTES {
402        return Err(EndpointProbeError::FrameTooLarge {
403            body_length: body.len(),
404            cap: MAX_FRAME_BYTES,
405        });
406    }
407    let mut wire = Vec::with_capacity(1 + 4 + body.len());
408    wire.push(ENVELOPE_VERSION);
409    wire.extend_from_slice(&(body.len() as u32).to_le_bytes());
410    wire.extend_from_slice(body);
411    write_all_with_deadline(stream, &wire, deadline)?;
412    flush_with_deadline(stream, deadline)
413}
414
415fn read_probe_frame_with_deadline(
416    stream: &mut interprocess::local_socket::Stream,
417    deadline: Instant,
418) -> Result<Vec<u8>, EndpointProbeError> {
419    let mut version = [0_u8; 1];
420    read_exact_with_deadline(stream, &mut version, deadline)?;
421    if version[0] != ENVELOPE_VERSION {
422        return Err(EndpointProbeError::UnsupportedFramingVersion {
423            got: version[0],
424            expected: ENVELOPE_VERSION,
425        });
426    }
427
428    let mut len = [0_u8; 4];
429    read_exact_with_deadline(stream, &mut len, deadline)?;
430    let body_length = u32::from_le_bytes(len) as usize;
431    if body_length > MAX_FRAME_BYTES {
432        return Err(EndpointProbeError::FrameTooLarge {
433            body_length,
434            cap: MAX_FRAME_BYTES,
435        });
436    }
437
438    let mut body = vec![0_u8; body_length];
439    if body_length > 0 {
440        read_exact_with_deadline(stream, &mut body, deadline)?;
441    }
442    Ok(body)
443}
444
445fn write_all_with_deadline<W: Write>(
446    writer: &mut W,
447    mut buf: &[u8],
448    deadline: Instant,
449) -> Result<(), EndpointProbeError> {
450    while !buf.is_empty() {
451        match writer.write(buf) {
452            Ok(0) => {
453                return Err(EndpointProbeError::Io(io::Error::new(
454                    io::ErrorKind::WriteZero,
455                    "endpoint probe write returned zero bytes",
456                )));
457            }
458            Ok(written) => buf = &buf[written..],
459            Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
460            Err(err) => return Err(EndpointProbeError::Io(err)),
461        }
462    }
463    Ok(())
464}
465
466fn read_exact_with_deadline<R: Read>(
467    reader: &mut R,
468    mut buf: &mut [u8],
469    deadline: Instant,
470) -> Result<(), EndpointProbeError> {
471    while !buf.is_empty() {
472        match reader.read(buf) {
473            Ok(0) => wait_for_io(deadline)?,
474            Ok(read) => {
475                let tmp = buf;
476                buf = &mut tmp[read..];
477            }
478            Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
479            Err(err) => return Err(EndpointProbeError::Io(err)),
480        }
481    }
482    Ok(())
483}
484
485fn flush_with_deadline<W: Write>(
486    writer: &mut W,
487    deadline: Instant,
488) -> Result<(), EndpointProbeError> {
489    loop {
490        match writer.flush() {
491            Ok(()) => return Ok(()),
492            Err(err) if err.kind() == io::ErrorKind::WouldBlock => wait_for_io(deadline)?,
493            Err(err) => return Err(EndpointProbeError::Io(err)),
494        }
495    }
496}
497
498fn wait_for_io(deadline: Instant) -> Result<(), EndpointProbeError> {
499    if Instant::now() >= deadline {
500        return Err(EndpointProbeError::Timeout);
501    }
502    let remaining = deadline.saturating_duration_since(Instant::now());
503    thread::sleep(remaining.min(NONBLOCKING_POLL_INTERVAL));
504    Ok(())
505}
506
507fn endpoint_name(path: &str) -> io::Result<interprocess::local_socket::Name<'_>> {
508    #[cfg(unix)]
509    {
510        use interprocess::local_socket::GenericFilePath;
511        path.to_fs_name::<GenericFilePath>()
512    }
513
514    #[cfg(windows)]
515    {
516        use interprocess::local_socket::GenericNamespaced;
517        path.to_ns_name::<GenericNamespaced>()
518    }
519}