1use std::io::{Read, Write};
17
18use interprocess::local_socket::traits::Stream as _;
19use interprocess::local_socket::Stream;
20use prost::Message;
21
22use crate::broker::lifecycle::names::PipePathError;
23use crate::broker::lifecycle::names_v2::v2_program_pipe;
24use crate::broker::lifecycle::sid::{user_sid_hash, SidError};
25use crate::broker::protocol::{
26 hello_reply, read_frame, write_frame, FramingError, Hello, HelloReply, Negotiated, Refused,
27 ENVELOPE_VERSION,
28};
29
30#[derive(Debug, thiserror::Error)]
32pub enum BrokerV2Error {
33 #[error(transparent)]
35 Sid(#[from] SidError),
36
37 #[error(transparent)]
39 PipeName(#[from] PipePathError),
40
41 #[error("dial v2 broker pipe at {socket_path:?}: {source}")]
43 Dial {
44 socket_path: String,
46 #[source]
48 source: std::io::Error,
49 },
50
51 #[error(transparent)]
54 Framing(#[from] FramingError),
55
56 #[error("Hello round-trip io: {0}")]
58 Io(#[from] std::io::Error),
59
60 #[error("HelloReply decode: {0}")]
62 Decode(#[from] prost::DecodeError),
63
64 #[error("HelloReply.result missing")]
66 MissingResult,
67
68 #[error("broker refused Hello: {reason}")]
70 Refused {
71 reason: String,
73 details: Box<Refused>,
75 },
76
77 #[error("Hello encode: {0}")]
79 Encode(#[from] prost::EncodeError),
80}
81
82#[derive(Debug)]
89pub struct ClientSession {
90 stream: Stream,
91 negotiated: Negotiated,
92}
93
94impl ClientSession {
95 pub fn negotiated(&self) -> &Negotiated {
97 &self.negotiated
98 }
99
100 pub fn into_inner(self) -> (Stream, Negotiated) {
105 (self.stream, self.negotiated)
106 }
107}
108
109pub fn connect(program: &str, version_hint: &str) -> Result<ClientSession, BrokerV2Error> {
120 let sid = user_sid_hash()?;
121 let pipe_name = v2_program_pipe(program, &sid, 0)?;
122 let socket_path = resolve_socket_path(&pipe_name);
123 let name = wrap_socket_name(&socket_path).map_err(|err| BrokerV2Error::Dial {
124 socket_path: socket_path.clone(),
125 source: std::io::Error::new(std::io::ErrorKind::InvalidInput, err),
126 })?;
127 let mut stream = Stream::connect(name).map_err(|source| BrokerV2Error::Dial {
128 socket_path: socket_path.clone(),
129 source,
130 })?;
131 let negotiated = hello_round_trip(&mut stream, program, version_hint)?;
132 Ok(ClientSession { stream, negotiated })
133}
134
135fn hello_round_trip<S: Read + Write>(
136 stream: &mut S,
137 program: &str,
138 version_hint: &str,
139) -> Result<Negotiated, BrokerV2Error> {
140 let hello = Hello {
141 client_min_protocol: ENVELOPE_VERSION as u32,
142 client_max_protocol: ENVELOPE_VERSION as u32,
143 service_name: program.to_string(),
144 wanted_version: version_hint.to_string(),
145 client_version: env!("CARGO_PKG_VERSION").to_string(),
146 client_capabilities: 0,
147 auth_token: Vec::new(),
148 request_id: format!("client_v2-{program}-{}", std::process::id()),
149 connection_id: 0,
150 peer_pid: std::process::id(),
151 client_lib_name: "running-process broker::client_v2".to_string(),
152 client_lib_version: env!("CARGO_PKG_VERSION").to_string(),
153 peer_attestation_nonce: Vec::new(),
154 capability_token: Vec::new(),
155 client_keepalive_secs: 0,
156 };
157 let mut body = Vec::with_capacity(hello.encoded_len());
158 hello.encode(&mut body)?;
159 write_frame(stream, &body)?;
160
161 let reply_bytes = read_frame(stream)?;
162 let reply = HelloReply::decode(reply_bytes.as_slice())?;
163 match reply.result {
164 Some(hello_reply::Result::Negotiated(n)) => Ok(n),
165 Some(hello_reply::Result::Refused(r)) => Err(BrokerV2Error::Refused {
166 reason: r.reason.clone(),
167 details: Box::new(r),
168 }),
169 None => Err(BrokerV2Error::MissingResult),
170 }
171}
172
173fn resolve_socket_path(bare_name: &str) -> String {
174 #[cfg(windows)]
175 {
176 format!(r"\\.\pipe\{bare_name}")
177 }
178 #[cfg(unix)]
179 {
180 use std::path::PathBuf;
181 let dir: PathBuf = {
182 #[cfg(target_os = "macos")]
183 {
184 let uid = unsafe { libc::getuid() };
185 let tmp = std::env::var_os("TMPDIR")
186 .map(PathBuf::from)
187 .unwrap_or_else(|| PathBuf::from("/tmp"));
188 tmp.join(format!(".rp-{uid}-broker-v2"))
189 }
190 #[cfg(not(target_os = "macos"))]
191 {
192 if let Some(d) = std::env::var_os("XDG_RUNTIME_DIR") {
193 PathBuf::from(d).join("running-process").join("broker-v2")
194 } else {
195 let uid = unsafe { libc::getuid() };
196 PathBuf::from(format!("/tmp/running-process-{uid}/broker-v2"))
197 }
198 }
199 };
200 let leaf = if cfg!(target_os = "macos") {
201 let mut hash = blake3::Hasher::new();
202 hash.update(bare_name.as_bytes());
203 let bytes = hash.finalize();
204 let mut hex = String::with_capacity(16);
205 for b in bytes.as_bytes().iter().take(8) {
206 use std::fmt::Write as _;
207 let _ = write!(hex, "{b:02x}");
208 }
209 format!("{hex}.sock")
210 } else {
211 format!("{bare_name}.sock")
212 };
213 dir.join(leaf).to_string_lossy().into_owned()
214 }
215}
216
217fn wrap_socket_name(socket_path: &str) -> Result<interprocess::local_socket::Name<'_>, String> {
218 use interprocess::local_socket::prelude::*;
219 #[cfg(windows)]
220 {
221 use interprocess::local_socket::GenericNamespaced;
222 let bare = socket_path
223 .strip_prefix(r"\\.\pipe\")
224 .unwrap_or(socket_path);
225 bare.to_ns_name::<GenericNamespaced>()
226 .map_err(|e| format!("to_ns_name: {e}"))
227 }
228 #[cfg(unix)]
229 {
230 use interprocess::local_socket::GenericFilePath;
231 socket_path
232 .to_fs_name::<GenericFilePath>()
233 .map_err(|e| format!("to_fs_name: {e}"))
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use interprocess::local_socket::traits::Listener as _;
    use interprocess::local_socket::ListenerOptions;
    use std::sync::mpsc;
    use std::thread;
    use std::time::{Duration, Instant};

    /// Spawns a one-shot stub broker listening on `socket_path`.
    ///
    /// The stub accepts a single connection, reads one `Hello` frame, and
    /// answers with a fixed `Negotiated` reply.
    ///
    /// Returns:
    /// * a receiver that fires once the listener is bound, and
    /// * the thread handle, which yields the `Hello` the stub received so the
    ///   caller can assert on what the client actually sent. Joining it also
    ///   surfaces any panic raised inside the stub.
    fn spawn_stub_broker(socket_path: String) -> (mpsc::Receiver<()>, thread::JoinHandle<Hello>) {
        let (tx, rx) = mpsc::channel();
        let handle = thread::spawn(move || {
            let name = wrap_socket_name(&socket_path).expect("wrap_socket_name");
            #[cfg(unix)]
            {
                // Best-effort: make sure the parent directory exists and no
                // stale socket file is left at the path before binding.
                let _ = std::fs::create_dir_all(
                    std::path::Path::new(&socket_path).parent().unwrap(),
                );
                let _ = std::fs::remove_file(&socket_path);
            }
            let listener = ListenerOptions::new()
                .name(name)
                .create_sync()
                .expect("ListenerOptions create_sync");
            tx.send(()).expect("send listener-ready signal");

            let mut stream = listener.accept().expect("accept");
            let bytes = read_frame(&mut stream).expect("read Hello frame");
            let hello = Hello::decode(bytes.as_slice()).expect("decode Hello");

            let reply = HelloReply {
                result: Some(hello_reply::Result::Negotiated(Negotiated {
                    negotiated_protocol: ENVELOPE_VERSION as u32,
                    daemon_version: "stub-1.2.3".to_string(),
                    backend_pipe: String::new(),
                    warnings: Vec::new(),
                    server_capabilities: 0,
                    keepalive_interval_secs: 0,
                    handle_passed_token: Vec::new(),
                    connection_id: 0x00C0_FFEE,
                })),
            };
            let mut body = Vec::with_capacity(reply.encoded_len());
            reply.encode(&mut body).expect("encode HelloReply");
            write_frame(&mut stream, &body).expect("write HelloReply frame");
            #[cfg(unix)]
            {
                // Best-effort cleanup of the socket file.
                let _ = std::fs::remove_file(&socket_path);
            }
            hello
        });
        (rx, handle)
    }

    /// `connect` completes the handshake against the stub, returns the stub's
    /// negotiated values, and sends a `Hello` describing this client.
    #[test]
    fn connect_completes_hello_round_trip_against_stub_broker() {
        let program = "client-v2-stub";
        let sid = user_sid_hash().expect("user_sid_hash");
        let pipe_name = v2_program_pipe(program, &sid, 0).expect("pipe name");
        let socket_path = resolve_socket_path(&pipe_name);

        let (ready, stub) = spawn_stub_broker(socket_path);
        ready
            .recv_timeout(Duration::from_secs(2))
            .expect("stub broker listening");

        // The listener has already signalled readiness; the bounded retry is
        // a defensive margin around the first dial.
        let start = Instant::now();
        let session = loop {
            match connect(program, "0.0.0") {
                Ok(s) => break s,
                Err(err) if start.elapsed() < Duration::from_secs(2) => {
                    eprintln!("connect retry after error: {err}");
                    std::thread::sleep(Duration::from_millis(50));
                    continue;
                }
                Err(err) => panic!("connect failed after retries: {err}"),
            }
        };

        let neg = session.negotiated();
        assert_eq!(neg.negotiated_protocol, ENVELOPE_VERSION as u32);
        assert_eq!(neg.connection_id, 0x00C0_FFEE);
        assert_eq!(neg.daemon_version, "stub-1.2.3");

        // Verify the client's side of the handshake as observed by the stub.
        let hello = stub.join().expect("stub broker thread panicked");
        assert_eq!(hello.service_name, program);
        assert_eq!(hello.wanted_version, "0.0.0");
        assert_eq!(hello.peer_pid, std::process::id());
    }

    /// With nothing listening on the derived path, `connect` fails at the
    /// dial step.
    #[test]
    fn connect_with_no_broker_returns_dial_error() {
        let err = connect("client-v2-no-broker-ever", "0.0.0")
            .expect_err("no broker => Dial error");
        assert!(
            matches!(err, BrokerV2Error::Dial { .. }),
            "expected Dial, got: {err:?}"
        );
    }
}