1use std::io;
4use std::sync::mpsc;
5use std::thread;
6use std::time::Duration;
7
8use interprocess::local_socket::prelude::*;
9use prost::Message;
10
11use crate::broker::capabilities::{handoff_transport_available, CAP_HANDLE_PASSING};
12use crate::broker::protocol::{
13 hello_reply::Result as HelloReplyResult, read_frame, validate_frame_envelope, write_frame,
14 AdminReply, AdminRequest, ErrorCode, Frame, FrameKind, FrameValidationError, FramingError,
15 HandoffAck, Hello, HelloReply, Negotiated, PayloadEncoding, ADMIN_PAYLOAD_PROTOCOL,
16 CONTROL_PAYLOAD_PROTOCOL, PROTOCOL_VERSION,
17};
18use crate::broker::server::handoff::validate_handoff_frame;
19use crate::broker::server::local_socket_name;
20
/// Default value of [`ConnectBackendRequest::handoff_ready_timeout`]: how long
/// `connect_to_backend` waits for the broker's handoff-ready frame before it
/// falls back to dialing the negotiated backend endpoint itself.
pub const DEFAULT_HANDOFF_READY_TIMEOUT: Duration = Duration::from_secs(2);

/// Name of the environment variable inspected by [`broker_disabled_by_env`].
pub const RUNNING_PROCESS_DISABLE_ENV: &str = "RUNNING_PROCESS_DISABLE";
/// The only value of [`RUNNING_PROCESS_DISABLE_ENV`] that is accepted; any
/// other value (including the empty string) is reported as an error.
pub const RUNNING_PROCESS_DISABLE_VALUE: &str = "1";
/// Name of the environment variable whose non-empty value is used directly as
/// the backend endpoint, bypassing the broker handshake entirely.
pub const RUNNING_PROCESS_FAKE_BACKEND_ENV: &str = "RUNNING_PROCESS_FAKE_BACKEND";
42
43pub fn broker_disabled_by_env() -> Result<bool, BrokerDisableEnvError> {
48 let Some(value) = std::env::var_os(RUNNING_PROCESS_DISABLE_ENV) else {
49 return Ok(false);
50 };
51 let value = value.to_string_lossy();
52 if value == RUNNING_PROCESS_DISABLE_VALUE {
53 Ok(true)
54 } else {
55 Err(BrokerDisableEnvError {
56 value: value.into_owned(),
57 })
58 }
59}
60
/// Borrowed parameters describing one `connect_to_backend` attempt.
#[derive(Debug, Clone)]
pub struct ConnectBackendRequest<'a> {
    /// Local-socket endpoint of the broker that receives the `Hello`.
    pub broker_endpoint: &'a str,
    /// Service name sent in the `Hello`.
    pub service_name: &'a str,
    /// Backend version the caller wants; sent in the `Hello`.
    pub wanted_version: &'a str,
    /// Version the caller identifies as. The hello-skip shortcut is only taken
    /// when this equals `wanted_version`.
    pub self_version: &'a str,
    /// Previously known backend endpoint. When present (and the versions
    /// match) it is dialed directly before the broker is contacted.
    pub cached_backend_endpoint: Option<&'a str>,
    /// Client version string sent in the `Hello`.
    pub client_version: &'a str,
    /// Client library name sent in the `Hello`.
    pub client_lib_name: &'a str,
    /// Client library version sent in the `Hello`.
    pub client_lib_version: &'a str,
    /// Keepalive interval, in seconds, sent in the `Hello`.
    pub client_keepalive_secs: u64,
    /// When `true` and the broker negotiated handle passing, the broker stream
    /// itself is adopted as the backend connection once the handoff-ready
    /// frame arrives.
    pub adopt_handed_off_connection: bool,
    /// Maximum time to wait for the handoff-ready frame before falling back to
    /// connecting to the negotiated backend endpoint.
    pub handoff_ready_timeout: Duration,
}
102
103impl<'a> ConnectBackendRequest<'a> {
104 pub fn new(
106 broker_endpoint: &'a str,
107 service_name: &'a str,
108 wanted_version: &'a str,
109 self_version: &'a str,
110 ) -> Self {
111 Self {
112 broker_endpoint,
113 service_name,
114 wanted_version,
115 self_version,
116 cached_backend_endpoint: None,
117 client_version: "",
118 client_lib_name: "running-process",
119 client_lib_version: env!("CARGO_PKG_VERSION"),
120 client_keepalive_secs: 0,
121 adopt_handed_off_connection: false,
122 handoff_ready_timeout: DEFAULT_HANDOFF_READY_TIMEOUT,
123 }
124 }
125
126 fn can_hello_skip(&self) -> bool {
127 self.cached_backend_endpoint.is_some() && self.wanted_version == self.self_version
128 }
129
130 fn hello(&self) -> Hello {
131 Hello {
132 client_min_protocol: PROTOCOL_VERSION,
133 client_max_protocol: PROTOCOL_VERSION,
134 service_name: self.service_name.into(),
135 wanted_version: self.wanted_version.into(),
136 client_version: self.client_version.into(),
137 client_capabilities: client_capabilities(),
138 auth_token: Vec::new(),
139 request_id: "hello".into(),
140 connection_id: 0,
141 peer_pid: std::process::id(),
142 client_lib_name: self.client_lib_name.into(),
143 client_lib_version: self.client_lib_version.into(),
144 peer_attestation_nonce: Vec::new(),
145 capability_token: Vec::new(),
146 client_keepalive_secs: self.client_keepalive_secs,
147 }
148 }
149}
150
151fn client_capabilities() -> u64 {
158 if handoff_transport_available() {
159 CAP_HANDLE_PASSING
160 } else {
161 0
162 }
163}
164
/// How a [`BackendConnection`] was obtained by `connect_to_backend`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendConnectionRoute {
    /// The backend was dialed directly without a broker handshake: either the
    /// fake-backend endpoint from the environment or the cached endpoint.
    HelloSkip,
    /// The broker answered the `Hello` and the client then connected to the
    /// backend endpoint the broker negotiated.
    BrokerNegotiated,
    /// The broker stream itself was adopted as the backend connection after a
    /// valid handoff-ready frame was received on it.
    HandlePassed,
}
188
189#[derive(Debug)]
191pub struct BackendConnection {
192 pub stream: interprocess::local_socket::Stream,
194 pub endpoint: String,
201 pub route: BackendConnectionRoute,
203 pub negotiated: Option<Negotiated>,
205}
206
207impl BackendConnection {
208 pub fn handoff_token(&self) -> Option<&[u8]> {
219 self.negotiated
220 .as_ref()
221 .map(|negotiated| negotiated.handle_passed_token.as_slice())
222 .filter(|token| !token.is_empty())
223 }
224}
225
226pub fn connect_to_backend(
249 request: ConnectBackendRequest<'_>,
250) -> Result<BackendConnection, BrokerClientError> {
251 if let Some(endpoint) = fake_backend_endpoint_from_env() {
252 let stream = connect_local_socket(&endpoint).map_err(BrokerClientError::BackendConnect)?;
253 return Ok(BackendConnection {
254 stream,
255 endpoint,
256 route: BackendConnectionRoute::HelloSkip,
257 negotiated: None,
258 });
259 }
260
261 if request.can_hello_skip() {
262 if let Some(endpoint) = request.cached_backend_endpoint {
263 if let Ok(stream) = connect_local_socket(endpoint) {
264 return Ok(BackendConnection {
265 stream,
266 endpoint: endpoint.into(),
267 route: BackendConnectionRoute::HelloSkip,
268 negotiated: None,
269 });
270 }
271 }
272 }
273
274 let (broker_stream, negotiated) = broker_hello(&request)?;
275 if request.adopt_handed_off_connection && handoff_negotiated(&negotiated) {
276 if let Some(adopted) = await_handoff_ready(
277 broker_stream,
278 negotiated.handle_passed_token.clone(),
279 request.handoff_ready_timeout,
280 ) {
281 return Ok(BackendConnection {
282 endpoint: negotiated.backend_pipe.clone(),
283 stream: adopted,
284 route: BackendConnectionRoute::HandlePassed,
285 negotiated: Some(negotiated),
286 });
287 }
288 }
289
290 if negotiated.backend_pipe.is_empty() {
291 return Err(BrokerClientError::EmptyBackendPipe);
292 }
293 let stream = connect_local_socket(&negotiated.backend_pipe)
294 .map_err(BrokerClientError::BackendConnect)?;
295 Ok(BackendConnection {
296 endpoint: negotiated.backend_pipe.clone(),
297 stream,
298 route: BackendConnectionRoute::BrokerNegotiated,
299 negotiated: Some(negotiated),
300 })
301}
302
303fn fake_backend_endpoint_from_env() -> Option<String> {
313 let value = std::env::var_os(RUNNING_PROCESS_FAKE_BACKEND_ENV)?;
314 let value = value.to_string_lossy();
315 if value.is_empty() {
316 return None;
317 }
318 if matches!(broker_disabled_by_env(), Ok(true)) {
319 return None;
320 }
321 Some(value.into_owned())
322}
323
324fn handoff_negotiated(negotiated: &Negotiated) -> bool {
327 negotiated.server_capabilities & CAP_HANDLE_PASSING == CAP_HANDLE_PASSING
328 && !negotiated.handle_passed_token.is_empty()
329}
330
331fn await_handoff_ready(
344 stream: interprocess::local_socket::Stream,
345 expected_token: Vec<u8>,
346 timeout: Duration,
347) -> Option<interprocess::local_socket::Stream> {
348 let (result_tx, result_rx) = mpsc::channel();
349 thread::spawn(move || {
350 let mut stream = stream;
351 let outcome = read_handoff_ready(&mut stream, &expected_token).map(|()| stream);
352 let _ = result_tx.send(outcome);
353 });
354 match result_rx.recv_timeout(timeout) {
355 Ok(Ok(stream)) => Some(stream),
356 Ok(Err(_)) | Err(_) => None,
357 }
358}
359
360fn read_handoff_ready(
365 stream: &mut interprocess::local_socket::Stream,
366 expected_token: &[u8],
367) -> Result<(), &'static str> {
368 let bytes = read_frame(stream).map_err(|_| "failed to read handoff-ready frame")?;
369 let frame =
370 Frame::decode(bytes.as_slice()).map_err(|_| "failed to decode handoff-ready Frame")?;
371 validate_handoff_frame(&frame, FrameKind::Event)?;
372 let ack = HandoffAck::decode(frame.payload.as_slice())
373 .map_err(|_| "failed to decode handoff-ready HandoffAck payload")?;
374 if ack.token != expected_token {
375 return Err("handoff-ready token echo does not match the negotiated token");
376 }
377 if !ack.accepted {
378 return Err("broker relayed a refused handoff");
379 }
380 Ok(())
381}
382
383pub fn send_admin_request(
385 broker_endpoint: &str,
386 request: AdminRequest,
387) -> Result<AdminReply, BrokerClientError> {
388 let mut stream =
389 connect_local_socket(broker_endpoint).map_err(BrokerClientError::BrokerConnect)?;
390 let request_frame = Frame {
391 envelope_version: PROTOCOL_VERSION,
392 kind: FrameKind::Request as i32,
393 payload_protocol: ADMIN_PAYLOAD_PROTOCOL,
394 payload: request.encode_to_vec(),
395 request_id: 1,
396 payload_encoding: PayloadEncoding::None as i32,
397 deadline_unix_ms: 0,
398 traceparent: String::new(),
399 tracestate: String::new(),
400 };
401 write_frame(&mut stream, &request_frame.encode_to_vec())?;
402
403 let response_bytes = read_frame(&mut stream)?;
404 let response_frame =
405 Frame::decode(response_bytes.as_slice()).map_err(BrokerClientError::DecodeFrame)?;
406 validate_response_frame(
407 &response_frame,
408 ADMIN_PAYLOAD_PROTOCOL,
409 "payload_protocol is not admin",
410 )?;
411 AdminReply::decode(response_frame.payload.as_slice())
412 .map_err(BrokerClientError::DecodeAdminReply)
413}
414
415pub fn connect_local_socket(endpoint: &str) -> io::Result<interprocess::local_socket::Stream> {
417 let name = local_socket_name(endpoint)?;
418 LocalSocketStream::connect(name)
419}
420
421fn broker_hello(
422 request: &ConnectBackendRequest<'_>,
423) -> Result<(interprocess::local_socket::Stream, Negotiated), BrokerClientError> {
424 let mut stream =
425 connect_local_socket(request.broker_endpoint).map_err(BrokerClientError::BrokerConnect)?;
426 let hello = request.hello();
427 let request_frame = Frame {
428 envelope_version: PROTOCOL_VERSION,
429 kind: FrameKind::Request as i32,
430 payload_protocol: CONTROL_PAYLOAD_PROTOCOL,
431 payload: hello.encode_to_vec(),
432 request_id: 1,
433 payload_encoding: PayloadEncoding::None as i32,
434 deadline_unix_ms: 0,
435 traceparent: String::new(),
436 tracestate: String::new(),
437 };
438 write_frame(&mut stream, &request_frame.encode_to_vec())?;
439
440 let response_bytes = read_frame(&mut stream)?;
441 let response_frame =
442 Frame::decode(response_bytes.as_slice()).map_err(BrokerClientError::DecodeFrame)?;
443 validate_response_frame(
444 &response_frame,
445 CONTROL_PAYLOAD_PROTOCOL,
446 "payload_protocol is not control-plane",
447 )?;
448 let reply = HelloReply::decode(response_frame.payload.as_slice())
449 .map_err(BrokerClientError::DecodeHelloReply)?;
450 match reply
451 .result
452 .ok_or(BrokerClientError::MissingHelloReplyResult)?
453 {
454 HelloReplyResult::Negotiated(negotiated) => Ok((stream, negotiated)),
455 HelloReplyResult::Refused(refused) => Err(BrokerClientError::Refused {
456 code: ErrorCode::try_from(refused.code).unwrap_or(ErrorCode::Unspecified),
457 reason: refused.reason,
458 retry_after_ms: refused.retry_after_ms,
459 }),
460 }
461}
462
463fn validate_response_frame(
464 frame: &Frame,
465 expected_payload_protocol: u32,
466 payload_protocol_error: &'static str,
467) -> Result<(), BrokerClientError> {
468 validate_frame_envelope(frame, FrameKind::Response, expected_payload_protocol).map_err(
469 |error| {
470 BrokerClientError::UnexpectedResponseFrame(match error {
471 FrameValidationError::EnvelopeVersion { .. } => "envelope_version is not v1",
472 FrameValidationError::Kind { .. } => "kind is not RESPONSE",
473 FrameValidationError::PayloadProtocol { .. } => payload_protocol_error,
474 FrameValidationError::PayloadEncoding { .. } => "payload is compressed",
475 })
476 },
477 )
478}
479
480#[derive(Clone, Debug, PartialEq, Eq, thiserror::Error)]
482#[error("RUNNING_PROCESS_DISABLE must be unset or 1, got {value:?}")]
483pub struct BrokerDisableEnvError {
484 pub value: String,
486}
487
488#[derive(Debug, thiserror::Error)]
490pub enum BrokerClientError {
491 #[error("failed to connect to broker: {0}")]
493 BrokerConnect(io::Error),
494 #[error("failed to connect to negotiated backend: {0}")]
496 BackendConnect(io::Error),
497 #[error(transparent)]
499 Framing(#[from] FramingError),
500 #[error("failed to decode broker response Frame: {0}")]
502 DecodeFrame(prost::DecodeError),
503 #[error("failed to decode broker HelloReply: {0}")]
505 DecodeHelloReply(prost::DecodeError),
506 #[error("failed to decode broker AdminReply: {0}")]
508 DecodeAdminReply(prost::DecodeError),
509 #[error("unexpected broker response frame: {0}")]
511 UnexpectedResponseFrame(&'static str),
512 #[error("broker HelloReply did not contain a result")]
514 MissingHelloReplyResult,
515 #[error("broker refused Hello: {reason} ({code:?}, retry_after_ms={retry_after_ms})")]
517 Refused {
518 code: ErrorCode,
520 reason: String,
522 retry_after_ms: u64,
524 },
525 #[error("broker negotiated an empty backend endpoint")]
527 EmptyBackendPipe,
528}