Skip to main content

tsafe_agent/
lib.rs

1//! tsafe Agent daemon — library entry point exposed for the tsafe meta-crate.
2//!
3//! All daemon logic is defined here so the meta-crate can link against it.
4//! `main.rs` is a thin shim that calls [`run`].
5
6#[cfg(not(target_os = "windows"))]
7use std::collections::HashMap;
8use std::io::{BufRead, BufReader, Write};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::Arc;
11#[cfg(not(target_os = "windows"))]
12use std::sync::Mutex;
13use std::time::{Duration, Instant};
14
15use anyhow::{Context, Result};
16use zeroize::{Zeroize, ZeroizeOnDrop};
17
18#[cfg(not(target_os = "windows"))]
19use tsafe_core::agent::{cellos_socket_path, CellRecord, CellState, CellosRequest, CellosResponse};
20use tsafe_core::agent::{
21    clear_agent_sock, pipe_name, write_agent_sock, AgentRequest, AgentResponse, AgentSession,
22};
23#[cfg(not(target_os = "windows"))]
24use tsafe_core::audit::{AuditCellosContext, AuditContext, AuditEntry, AuditLog};
25use tsafe_core::profile;
26use tsafe_core::{keyring_store, vault::Vault};
27
28// ── Password holder — zeroed on drop ─────────────────────────────────────────
29
30#[derive(Zeroize, ZeroizeOnDrop)]
31struct Password(String);
32
33// ── CellOS cell cache ─────────────────────────────────────────────────────────
34
35#[cfg(not(target_os = "windows"))]
36type CellCache = Arc<Mutex<HashMap<String, CellState>>>;
37
38/// Lock the cell cache, recovering from poison if a previous holder panicked.
39///
40/// `cell_cache` is just a token store; partial-update panics cannot leave the
41/// `HashMap` in a structurally invalid state — at worst a single cell entry is
42/// half-written, which the next request will overwrite cleanly. Recovering is
43/// strictly preferable to the alternative (`.unwrap()`), which would propagate
44/// the original panic and kill the long-lived agent on every subsequent
45/// request.
46#[cfg(not(target_os = "windows"))]
47fn lock_cell_cache(
48    cell_cache: &CellCache,
49) -> std::sync::MutexGuard<'_, HashMap<String, CellState>> {
50    match cell_cache.lock() {
51        Ok(g) => g,
52        Err(poisoned) => poisoned.into_inner(),
53    }
54}
55
56/// Install process signal handlers for daemon lifecycle.
57///
58/// Without this, `kill <agent-pid>` (or a logout-time SIGTERM) terminates the
59/// process abruptly and leaves the unix socket file orphaned on disk —
60/// subsequent `tsafe agent unlock` runs then fail to bind. With the handler in
61/// place, the accept loop observes `stop` on its next iteration, returns from
62/// `serve()`, and the `SocketCleanup` RAII guard unlinks the socket file
63/// before the process exits.
64///
65/// SIGHUP is different: this is a background unlock daemon with explicit TTL
66/// and lock controls, so terminal hangup should not kill it immediately after
67/// the interactive launcher exits.
68#[cfg(not(target_os = "windows"))]
69fn install_signal_handlers(stop: Arc<AtomicBool>) -> std::io::Result<()> {
70    use signal_hook::consts::{SIGINT, SIGTERM};
71    use signal_hook::flag;
72    flag::register(SIGTERM, Arc::clone(&stop))?;
73    flag::register(SIGINT, Arc::clone(&stop))?;
74    // SAFETY: signal() is called during single-threaded daemon startup before
75    // request-serving threads are spawned. Ignoring SIGHUP is process-global
76    // and intentional for this short-lived TTL-bound background agent.
77    unsafe {
78        libc::signal(libc::SIGHUP, libc::SIG_IGN);
79    }
80    Ok(())
81}
82
83// ── Daemon entry ──────────────────────────────────────────────────────────────
84
85fn daemon_main() -> Result<()> {
86    let args: Vec<String> = std::env::args().collect();
87    if args.len() != 6 {
88        eprintln!(
89            "usage: tsafe-agent <profile> <session_token_hex> <requesting_pid> <idle_ttl_secs> <absolute_ttl_secs>"
90        );
91        std::process::exit(1);
92    }
93
94    let profile_name = &args[1];
95    let session_token = &args[2];
96    let requesting_pid: u32 = args[3].parse().context("invalid requesting_pid")?;
97    let idle_ttl_secs: u64 = args[4].parse().context("invalid idle_ttl_secs")?;
98    let absolute_ttl_secs: u64 = args[5].parse().context("invalid absolute_ttl_secs")?;
99
100    // Acquire the vault password: keychain → TSAFE_VAULT_PASSWORD env → interactive prompt.
101    #[cfg(not(target_os = "windows"))]
102    let raw_password = acquire_password(profile_name)?;
103    #[cfg(not(target_os = "windows"))]
104    let pw = Password(raw_password.clone());
105    #[cfg(target_os = "windows")]
106    let pw = Password(acquire_password(profile_name)?);
107
108    // Validate the password works before we start serving.
109    {
110        let path = profile::vault_path(profile_name);
111        Vault::open(&path, pw.0.as_bytes()).context("wrong password — agent will not start")?;
112    }
113
114    let agent_pid = std::process::id();
115    let pipe = pipe_name(agent_pid);
116
117    // Write the socket address to stdout so `tsafe agent unlock` can read it.
118    //
119    // IMPORTANT: this println is the agent→shell handshake. It MUST be consumed
120    // by `eval $(tsafe agent unlock)` so the token never lands in shell
121    // history. Anyone copy-pasting this output is exposing their session
122    // token. Do not change the format without updating the unlock-shell
123    // helper — the `KEY=PIPE::TOKEN` shape is load-bearing.
124    println!("TSAFE_AGENT_SOCK={pipe}::{session_token}");
125    let _ = std::io::stdout().flush();
126
127    // Persist the socket address so any tsafe invocation (not just the unlock
128    // shell) can find the running agent via the state file.
129    write_agent_sock(&format!("{pipe}::{session_token}"));
130
131    let stop = Arc::new(AtomicBool::new(false));
132    // Install SIGTERM/SIGINT handlers on unix so the accept loop drains and
133    // the SocketCleanup Drop guard unlinks the socket file. Without this, a
134    // `kill <agent-pid>` would orphan the socket and break the next unlock.
135    #[cfg(not(target_os = "windows"))]
136    install_signal_handlers(Arc::clone(&stop))
137        .context("failed to install SIGTERM/SIGINT handlers")?;
138    let absolute_deadline = Instant::now() + Duration::from_secs(absolute_ttl_secs);
139    let mut session =
140        AgentSession::new(session_token.to_string(), idle_ttl_secs, absolute_deadline);
141    spawn_expiry_watchdog(pipe.clone(), absolute_deadline, Arc::clone(&stop));
142
143    // Spawn the CellOS broker socket in a background thread (Unix only).
144    // The thread shares the vault password (Arc) and the revocation cache.
145    #[cfg(not(target_os = "windows"))]
146    {
147        let cell_cache: CellCache = Arc::new(Mutex::new(HashMap::new()));
148        let profile = profile_name.clone();
149        let shared_pw = Arc::new(raw_password);
150        let cache = Arc::clone(&cell_cache);
151        let stop_clone = Arc::clone(&stop);
152        std::thread::spawn(move || {
153            if let Err(e) = serve_cellos(&profile, shared_pw, cache, stop_clone) {
154                eprintln!("tsafe-agent: CellOS socket error: {e:#}");
155            }
156        });
157    }
158
159    serve(
160        &pipe,
161        &pw,
162        &mut session,
163        requesting_pid,
164        absolute_deadline,
165        stop,
166    )?;
167
168    // Remove the state file so stale sock addresses don't linger.
169    clear_agent_sock();
170
171    Ok(())
172}
173
174/// Launch the tsafe-agent daemon.
175///
176/// Parses `std::env::args()`, validates the vault password, binds the IPC
177/// socket/pipe, and serves requests until the TTL expires or a Lock request
178/// is received. Exits the process with code 1 on fatal error.
179pub fn run() {
180    if let Err(e) = daemon_main() {
181        eprintln!("tsafe-agent: {e:#}");
182        std::process::exit(1);
183    }
184}
185
186// ── Password acquisition ──────────────────────────────────────────────────────
187
188/// Acquire the vault password using a priority chain:
189///   1. OS keychain (macOS Keychain / Linux Secret Service)
190///   2. `TSAFE_VAULT_PASSWORD` env var — loud warning (ends up in /proc/self/environ)
191///   3. Interactive TTY prompt via rpassword
192fn acquire_password(profile: &str) -> Result<String> {
193    // 1. Try the OS keychain first (uses the same entry as `tsafe biometric enable`).
194    //
195    // We do NOT pre-check with `has_password`: on non-macOS the generic backend
196    // has no no-UI existence probe, so `has_password` performs a full
197    // `get_password` and `retrieve_password` performs a second one — surfacing
198    // as a double OS-keychain prompt for the user. `retrieve_password` already
199    // returns `Ok(None)` when the entry is absent, so a single call is
200    // sufficient and on every platform fires at most one keychain interaction.
201    match keyring_store::retrieve_password(profile) {
202        Ok(Some(pw)) => return Ok(pw),
203        Ok(None) => {}
204        Err(e) => eprintln!("tsafe-agent: keychain lookup failed: {e}; falling back"),
205    }
206
207    // 2. TSAFE_VAULT_PASSWORD env var — security risk, warn loudly.
208    if let Ok(env_pw) = std::env::var("TSAFE_VAULT_PASSWORD") {
209        eprintln!(
210            "tsafe-agent: WARNING — using TSAFE_VAULT_PASSWORD from environment. \
211             This value is visible in /proc/self/environ, `docker inspect`, and shell \
212             history. Use `tsafe biometric enable` to store the password in the OS \
213             keychain instead."
214        );
215        return Ok(env_pw);
216    }
217
218    // 3. Interactive prompt (original behaviour).
219    rpassword_read()
220}
221
222// ── CellOS socket server (Unix-only) ─────────────────────────────────────────
223
224/// Serve the CellOS broker socket at `cellos_socket_path()`.
225///
226/// Authentication: caller UID must match the daemon's own UID (SO_PEERCRED / getpeereid).
227/// Handles `Resolve` and `RevokeForCell` requests over newline-terminated JSON.
228#[cfg(not(target_os = "windows"))]
229fn serve_cellos(
230    profile: &str,
231    password: Arc<String>,
232    cell_cache: CellCache,
233    stop: Arc<AtomicBool>,
234) -> Result<()> {
235    use std::os::unix::net::UnixListener;
236
237    let sock_path = cellos_socket_path();
238    if let Some(parent) = sock_path.parent() {
239        std::fs::create_dir_all(parent)?;
240    }
241    let _ = std::fs::remove_file(&sock_path);
242    let listener = UnixListener::bind(&sock_path)
243        .with_context(|| format!("CellOS: failed to bind {}", sock_path.display()))?;
244    #[cfg(unix)]
245    {
246        use std::os::unix::fs::PermissionsExt;
247        let _ = std::fs::set_permissions(&sock_path, std::fs::Permissions::from_mode(0o600));
248    }
249    let _cleanup = SocketCleanup(sock_path.to_string_lossy().into_owned());
250
251    listener.set_nonblocking(true)?;
252    // SAFETY: getuid() is a leaf POSIX libc call that takes no arguments and
253    // reads/writes no user memory. Per POSIX it cannot fail and has no
254    // preconditions on the caller, so calling it is unconditionally sound.
255    let daemon_uid = unsafe { libc::getuid() };
256    let vault_path = profile::vault_path(profile);
257    let audit = AuditLog::new(&profile::audit_log_path(profile));
258
259    loop {
260        if stop.load(Ordering::Relaxed) {
261            break;
262        }
263        match listener.accept() {
264            Ok((stream, _)) => {
265                stream.set_nonblocking(false)?;
266                let cred = match unix_peer_credential(&stream) {
267                    Ok(c) => c,
268                    Err(e) => {
269                        eprintln!("tsafe-agent: CellOS: peer credential failed: {e}");
270                        continue;
271                    }
272                };
273                if cred.uid != daemon_uid {
274                    let resp = CellosResponse::Err {
275                        error: "uid mismatch".to_string(),
276                    };
277                    let mut w = &stream;
278                    let _ = writeln!(w, "{}", serde_json::to_string(&resp).unwrap_or_default());
279                    continue;
280                }
281                handle_cellos_connection(
282                    &stream,
283                    cred.pid,
284                    &vault_path,
285                    &password,
286                    profile,
287                    &cell_cache,
288                    &audit,
289                );
290            }
291            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
292                std::thread::sleep(Duration::from_millis(200));
293            }
294            Err(e) => return Err(e.into()),
295        }
296    }
297    Ok(())
298}
299
300#[cfg(not(target_os = "windows"))]
301fn handle_cellos_connection(
302    stream: &std::os::unix::net::UnixStream,
303    peer_pid: u32,
304    vault_path: &std::path::Path,
305    password: &str,
306    profile: &str,
307    cell_cache: &CellCache,
308    audit: &AuditLog,
309) {
310    use std::io::BufRead;
311
312    let mut reader = BufReader::new(stream);
313    let mut line = String::new();
314    if reader.read_line(&mut line).unwrap_or(0) == 0 {
315        return;
316    }
317
318    let req: CellosRequest = match serde_json::from_str(line.trim()) {
319        Ok(r) => r,
320        Err(e) => {
321            let resp = CellosResponse::Err {
322                error: format!("bad request: {e}"),
323            };
324            let mut w = stream;
325            let _ = writeln!(w, "{}", serde_json::to_string(&resp).unwrap_or_default());
326            return;
327        }
328    };
329
330    let resp = dispatch_cellos(
331        req, peer_pid, vault_path, password, profile, cell_cache, audit,
332    );
333    let mut w = stream;
334    let _ = writeln!(w, "{}", serde_json::to_string(&resp).unwrap_or_default());
335}
336
337#[cfg(not(target_os = "windows"))]
338fn dispatch_cellos(
339    req: CellosRequest,
340    peer_pid: u32,
341    vault_path: &std::path::Path,
342    password: &str,
343    profile: &str,
344    cell_cache: &CellCache,
345    audit: &AuditLog,
346) -> CellosResponse {
347    match req {
348        CellosRequest::Resolve {
349            key,
350            cell_id,
351            ttl_seconds: _,
352            cell_token,
353        } => {
354            // Validate or register the cell.
355            {
356                let mut cache = lock_cell_cache(cell_cache);
357                match cache.get(&cell_id) {
358                    Some(CellState::Revoked) => {
359                        return CellosResponse::Err {
360                            error: "cell revoked".to_string(),
361                        };
362                    }
363                    Some(CellState::Active(record)) => {
364                        if record.token != cell_token {
365                            return CellosResponse::Err {
366                                error: "cell_token mismatch".to_string(),
367                            };
368                        }
369                    }
370                    None => {
371                        // First Resolve for this cell — register it.
372                        cache.insert(
373                            cell_id.clone(),
374                            CellState::Active(CellRecord {
375                                pid: peer_pid,
376                                token: cell_token.clone(),
377                            }),
378                        );
379                    }
380                }
381            }
382
383            // Open vault and look up the secret.
384            let value = match Vault::open_read_only(vault_path, password.as_bytes()) {
385                Ok(v) => match v.get(&key) {
386                    Ok(s) => s.to_string(),
387                    Err(_) => {
388                        return CellosResponse::Err {
389                            error: format!("key not found: {key}"),
390                        };
391                    }
392                },
393                Err(e) => {
394                    return CellosResponse::Err {
395                        error: format!("vault error: {e}"),
396                    };
397                }
398            };
399
400            // Audit the resolve.
401            audit
402                .append(
403                    &AuditEntry::success(profile, "cellos-resolve", Some(&key)).with_context(
404                        AuditContext::from_cellos(AuditCellosContext {
405                            cellos_cell_id: cell_id,
406                            cell_token: Some(cell_token),
407                        }),
408                    ),
409                )
410                .ok();
411
412            CellosResponse::Value { value }
413        }
414
415        CellosRequest::RevokeForCell { cell_id } => {
416            {
417                let mut cache = lock_cell_cache(cell_cache);
418                cache.insert(cell_id.clone(), CellState::Revoked);
419            }
420
421            audit
422                .append(
423                    &AuditEntry::success(profile, "cellos-revoke", None).with_context(
424                        AuditContext::from_cellos(AuditCellosContext {
425                            cellos_cell_id: cell_id,
426                            cell_token: None,
427                        }),
428                    ),
429                )
430                .ok();
431
432            CellosResponse::Ok
433        }
434    }
435}
436
437// ── Named-pipe server ─────────────────────────────────────────────────────────
438
/// Windows named-pipe server: serve one client at a time until `stop` is set
/// or `deadline` passes.
///
/// Each iteration creates a fresh single-instance pipe server. It blocks until
/// a client connects; the expiry watchdog connects a throwaway client to end
/// that wait at the deadline. One request is then handled. If the client's PID
/// cannot be determined, only that connection is dropped; the agent keeps
/// running.
///
/// # Errors
///
/// Returns an error if the pipe instance cannot be created or a response
/// cannot be serialized.
#[cfg(target_os = "windows")]
fn serve(
    pipe: &str,
    pw: &Password,
    session: &mut AgentSession,
    _requesting_pid: u32,
    deadline: Instant,
    stop: Arc<AtomicBool>,
) -> Result<()> {
    use std::fs::File;
    use std::os::windows::io::FromRawHandle;

    // CreateNamedPipeW takes the full `\\.\pipe\<name>` path, so the name is
    // passed through unchanged, just NUL-terminated UTF-16.
    let pipe_wide: Vec<u16> = pipe.encode_utf16().chain(std::iter::once(0)).collect();

    loop {
        if stop.load(Ordering::Relaxed) || Instant::now() >= deadline {
            break;
        }
        // SAFETY: `pipe_wide` is a NUL-terminated UTF-16 buffer that outlives the call.
        let handle = unsafe { windows_create_named_pipe(&pipe_wide)? };

        // Blocks until a client connects (the timeout argument is not honoured).
        // SAFETY: `handle` is the live server handle created just above.
        let connected = unsafe { windows_connect_with_timeout(handle, 5_000) };
        if !connected {
            // SAFETY: `handle` is owned here and not used after this call.
            unsafe { windows_close_handle(handle) };
            continue;
        }

        // Resolve the peer PID before `File` takes ownership of the handle, so a
        // failure only drops this connection instead of tearing down the agent.
        // SAFETY: `handle` is a live, connected pipe server handle.
        let peer_pid = match unsafe { windows_get_named_pipe_client_process_id(handle) } {
            Ok(pid) => pid,
            Err(e) => {
                eprintln!("tsafe-agent: {e:#}; dropping connection");
                // SAFETY: `handle` is owned here and not used after this call.
                unsafe { windows_close_handle(handle) };
                continue;
            }
        };

        // SAFETY: `handle` is a valid pipe handle owned by this loop iteration;
        // `File` takes ownership and closes it on drop.
        let client_file = unsafe { File::from_raw_handle(handle as _) };
        let mut reader = BufReader::new(&client_file);
        let mut writer = &client_file;
        handle_connection(
            &mut reader,
            &mut writer,
            pw,
            session,
            Some(peer_pid),
            &stop,
        )?;
    }

    Ok(())
}
482
483pub(crate) fn handle_connection(
484    reader: &mut impl BufRead,
485    writer: &mut impl Write,
486    pw: &Password,
487    session: &mut AgentSession,
488    peer_pid: Option<u32>,
489    stop: &Arc<AtomicBool>,
490) -> Result<()> {
491    let mut line = String::new();
492    if reader.read_line(&mut line).unwrap_or(0) == 0 {
493        return Ok(());
494    }
495
496    let req: AgentRequest = match serde_json::from_str(line.trim()) {
497        Ok(r) => r,
498        Err(e) => {
499            let resp = AgentResponse::Err {
500                reason: format!("bad request: {e}"),
501            };
502            let _ = writeln!(writer, "{}", serde_json::to_string(&resp)?);
503            return Ok(());
504        }
505    };
506
507    let outcome = session.handle_request(&req, peer_pid, &pw.0, Instant::now());
508    if outcome.stop {
509        stop.store(true, Ordering::Relaxed);
510    }
511    let resp = outcome.response;
512
513    let _ = writeln!(writer, "{}", serde_json::to_string(&resp)?);
514    Ok(())
515}
516
517pub(crate) fn spawn_expiry_watchdog(pipe: String, deadline: Instant, stop: Arc<AtomicBool>) {
518    std::thread::spawn(move || loop {
519        if stop.load(Ordering::Relaxed) {
520            return;
521        }
522
523        let now = Instant::now();
524        if now >= deadline {
525            stop.store(true, Ordering::Relaxed);
526            wake_listener(&pipe);
527            return;
528        }
529
530        std::thread::sleep((deadline - now).min(Duration::from_millis(200)));
531    });
532}
533
/// Nudge the named-pipe server by connecting a throwaway client.
///
/// The client handle, or the connection error, is discarded immediately.
#[cfg(target_os = "windows")]
fn wake_listener(pipe: &str) {
    if let Ok(client) = windows_connect_pipe_client(pipe) {
        drop(client);
    }
}
538
/// Nudge the Unix socket server by connecting and immediately hanging up.
///
/// Failure is fine: for example, the listener may already be gone.
#[cfg(not(target_os = "windows"))]
fn wake_listener(pipe: &str) {
    drop(std::os::unix::net::UnixStream::connect(pipe));
}
543
544// ── Password prompt ───────────────────────────────────────────────────────────
545
546fn rpassword_read() -> Result<String> {
547    // The agent is always launched interactively (the user just approved the toast).
548    rpassword::prompt_password("Vault password: ").context("failed to read password")
549}
550
551// ── Windows FFI shims ─────────────────────────────────────────────────────────
552
/// Raw Win32 bindings used by the named-pipe server and its wake-up client.
#[cfg(target_os = "windows")]
mod ffi {
    use std::ffi::c_void;

    extern "system" {
        // ── Pipe server side ──
        pub fn CreateNamedPipeW(
            name: *const u16,
            open_mode: u32,
            pipe_mode: u32,
            max_instances: u32,
            out_buf: u32,
            in_buf: u32,
            default_timeout: u32,
            security: *mut c_void,
        ) -> *mut c_void;
        pub fn ConnectNamedPipe(pipe: *mut c_void, overlapped: *mut c_void) -> i32;
        pub fn GetNamedPipeClientProcessId(pipe: *mut c_void, client_process_id: *mut u32) -> i32;

        // ── Pipe client side ──
        pub fn CreateFileW(
            name: *const u16,
            access: u32,
            share: u32,
            security: *mut c_void,
            creation: u32,
            flags: u32,
            template: *mut c_void,
        ) -> *mut c_void;

        // ── Shared ──
        pub fn CloseHandle(handle: *mut c_void) -> i32;

        // Declared but not currently called in this file.
        #[allow(dead_code)]
        pub fn WaitForSingleObject(handle: *mut c_void, ms: u32) -> u32;
    }
}
584
/// Create one server instance of the named pipe whose path is `pipe_wide`.
///
/// The pipe is duplex, byte-mode and blocking, limited to a single instance,
/// and rejects clients connecting from remote machines. The agent is strictly
/// local, so it has no reason to be reachable over the network.
///
/// # Safety
///
/// `pipe_wide` must be a NUL-terminated UTF-16 string. The caller owns the
/// returned handle and must close it exactly once.
///
/// # Errors
///
/// Fails, including the OS error, if `CreateNamedPipeW` returns a null or
/// `INVALID_HANDLE_VALUE` handle.
#[cfg(target_os = "windows")]
unsafe fn windows_create_named_pipe(pipe_wide: &[u16]) -> Result<*mut std::ffi::c_void> {
    const PIPE_ACCESS_DUPLEX: u32 = 0x0000_0003;
    const PIPE_TYPE_BYTE: u32 = 0x0000_0000;
    const PIPE_WAIT: u32 = 0x0000_0000;
    const PIPE_REJECT_REMOTE_CLIENTS: u32 = 0x0000_0008;
    const BUFFER_SIZE: u32 = 4096;

    let h = ffi::CreateNamedPipeW(
        pipe_wide.as_ptr(),
        PIPE_ACCESS_DUPLEX,
        PIPE_TYPE_BYTE | PIPE_WAIT | PIPE_REJECT_REMOTE_CLIENTS,
        1, // 1 instance — only the approved process can connect
        BUFFER_SIZE,
        BUFFER_SIZE,
        0, // default timeout (50 ms); only affects clients using WaitNamedPipe
        std::ptr::null_mut(),
    );
    if h.is_null() || h as isize == -1 {
        anyhow::bail!("CreateNamedPipeW failed: {}", std::io::Error::last_os_error());
    }
    Ok(h)
}
603
/// Wait for a client to connect to the pipe server instance `handle`.
///
/// Returns `true` once a client is connected. That includes the race where the
/// client connected between `CreateNamedPipeW` and this call. In that case
/// `ConnectNamedPipe` returns 0 with `ERROR_PIPE_CONNECTED`, which is a success.
/// Any other failure returns `false`.
///
/// The timeout argument is currently ignored. `ConnectNamedPipe` blocks until a
/// client connects, and the expiry watchdog connects a throwaway client to end
/// the wait when the session expires.
///
/// # Safety
///
/// `handle` must be a live pipe server handle that has not been closed.
#[cfg(target_os = "windows")]
unsafe fn windows_connect_with_timeout(handle: *mut std::ffi::c_void, _ms: u32) -> bool {
    const ERROR_PIPE_CONNECTED: i32 = 535;

    if ffi::ConnectNamedPipe(handle, std::ptr::null_mut()) != 0 {
        return true;
    }
    // Read the thread's last error immediately, before any other API call.
    std::io::Error::last_os_error().raw_os_error() == Some(ERROR_PIPE_CONNECTED)
}
610
/// Close a Win32 handle, discarding the `CloseHandle` status.
///
/// # Safety
///
/// `handle` must be a valid handle owned by the caller and never used again.
#[cfg(target_os = "windows")]
unsafe fn windows_close_handle(handle: *mut std::ffi::c_void) {
    let _status = ffi::CloseHandle(handle);
}
615
/// Return the process ID of the client connected to the pipe server `handle`.
///
/// # Safety
///
/// `handle` must be a live, connected pipe server handle.
#[cfg(target_os = "windows")]
unsafe fn windows_get_named_pipe_client_process_id(handle: *mut std::ffi::c_void) -> Result<u32> {
    let mut client_pid: u32 = 0;
    let succeeded = ffi::GetNamedPipeClientProcessId(handle, &mut client_pid) != 0;
    anyhow::ensure!(succeeded, "GetNamedPipeClientProcessId failed");
    Ok(client_pid)
}
624
/// Open a read/write client connection to the existing named pipe `pipe`.
///
/// The returned `File` owns the handle and closes it on drop.
#[cfg(target_os = "windows")]
fn windows_connect_pipe_client(pipe: &str) -> Result<std::fs::File> {
    use std::os::windows::ffi::OsStrExt;
    use std::os::windows::io::FromRawHandle;

    const GENERIC_READ_WRITE: u32 = 0x8000_0000 | 0x4000_0000;
    const OPEN_EXISTING: u32 = 3;
    const FILE_ATTRIBUTE_NORMAL: u32 = 0x80;

    let mut wide: Vec<u16> = std::ffi::OsStr::new(pipe).encode_wide().collect();
    wide.push(0);

    // SAFETY: `wide` is NUL-terminated and outlives the call; the optional
    // pointer arguments are null, which CreateFileW accepts.
    let handle = unsafe {
        ffi::CreateFileW(
            wide.as_ptr(),
            GENERIC_READ_WRITE,
            0,
            std::ptr::null_mut(),
            OPEN_EXISTING,
            FILE_ATTRIBUTE_NORMAL,
            std::ptr::null_mut(),
        )
    };

    let invalid = handle.is_null() || handle as isize == -1;
    if invalid {
        anyhow::bail!("CreateFileW failed");
    }

    // SAFETY: `handle` was just returned by CreateFileW, is valid, and is not
    // owned by anything else.
    let file = unsafe { std::fs::File::from_raw_handle(handle as _) };
    Ok(file)
}
653
654// Unix domain socket server — mirrors the Windows named-pipe server above.
655#[cfg(not(target_os = "windows"))]
656fn serve(
657    pipe: &str,
658    pw: &Password,
659    session: &mut AgentSession,
660    _requesting_pid: u32,
661    deadline: Instant,
662    stop: Arc<AtomicBool>,
663) -> Result<()> {
664    use std::os::unix::net::UnixListener;
665
666    // Clean up stale socket from a previous crash.
667    let _ = std::fs::remove_file(pipe);
668
669    let listener =
670        UnixListener::bind(pipe).with_context(|| format!("failed to bind Unix socket: {pipe}"))?;
671
672    // Set socket file to owner-only (0600).
673    #[cfg(unix)]
674    {
675        use std::os::unix::fs::PermissionsExt;
676        let _ = std::fs::set_permissions(pipe, std::fs::Permissions::from_mode(0o600));
677    }
678
679    // Non-blocking accept with timeout so we can check deadline + PID liveness.
680    listener.set_nonblocking(true)?;
681
682    let _cleanup = SocketCleanup(pipe.to_string());
683
684    loop {
685        if stop.load(Ordering::Relaxed) || Instant::now() >= deadline {
686            break;
687        }
688        match listener.accept() {
689            Ok((stream, _)) => {
690                stream.set_nonblocking(false)?;
691                let mut reader = BufReader::new(&stream);
692                let mut writer = &stream;
693                handle_connection(
694                    &mut reader,
695                    &mut writer,
696                    pw,
697                    session,
698                    Some(unix_peer_pid(&stream)?),
699                    &stop,
700                )?;
701            }
702            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
703                std::thread::sleep(Duration::from_millis(200));
704            }
705            Err(e) => return Err(e.into()),
706        }
707    }
708
709    Ok(())
710}
711
/// Guard that unlinks a socket file when dropped, on normal exit or an
/// unwinding panic.
///
/// Removal is best-effort: a missing file or a permission error is ignored.
#[cfg(not(target_os = "windows"))]
struct SocketCleanup(String);

#[cfg(not(target_os = "windows"))]
impl Drop for SocketCleanup {
    fn drop(&mut self) {
        let SocketCleanup(path) = self;
        // Nothing actionable can be done about a failure during teardown.
        let _ignored = std::fs::remove_file(path);
    }
}
722
/// Identity of the process on the far end of a Unix domain socket.
#[cfg(unix)]
struct PeerCredential {
    /// Peer process ID.
    pid: u32,
    /// Peer user ID.
    uid: u32,
}
728
729/// Retrieve both the PID and UID of the connecting process in one syscall where possible.
730#[cfg(target_os = "linux")]
731fn unix_peer_credential(
732    stream: &std::os::unix::net::UnixStream,
733) -> std::io::Result<PeerCredential> {
734    use std::mem::size_of;
735    use std::os::fd::AsRawFd;
736    let fd = stream.as_raw_fd();
737    // SAFETY: `libc::ucred` is a plain POD struct (three integral fields) and
738    // zeroing it is a valid initialization for the `getsockopt(SO_PEERCRED)`
739    // out-parameter contract.
740    let mut cred: libc::ucred = unsafe { std::mem::zeroed() };
741    let mut len = size_of::<libc::ucred>() as libc::socklen_t;
742    // SAFETY: `fd` is owned by `stream` (a `&UnixStream`) and remains valid for
743    // the duration of this call — the borrow holds the file descriptor open.
744    // The `optval` pointer addresses the local `cred` stack variable, which
745    // lives until the end of this function and is sized exactly
746    // `size_of::<libc::ucred>()` (matching the value passed via `&mut len`).
747    // `getsockopt` only writes through the pointer; the kernel will not
748    // exceed `len` bytes per the SO_PEERCRED contract.
749    let rc = unsafe {
750        libc::getsockopt(
751            fd,
752            libc::SOL_SOCKET,
753            libc::SO_PEERCRED,
754            &mut cred as *mut _ as *mut libc::c_void,
755            &mut len,
756        )
757    };
758    if rc != 0 {
759        return Err(std::io::Error::last_os_error());
760    }
761    Ok(PeerCredential {
762        pid: cred.pid as u32,
763        uid: cred.uid,
764    })
765}
766
/// Retrieve the PID and UID of the connecting process (macOS).
///
/// The UID comes from `getpeereid`, and the PID from
/// `getsockopt(SOL_LOCAL, LOCAL_PEERPID)`. Each call's return code is checked
/// right after that call. That way, the OS error reported is the one left by
/// the call that actually failed, not one clobbered by a later libc call.
#[cfg(target_os = "macos")]
fn unix_peer_credential(
    stream: &std::os::unix::net::UnixStream,
) -> std::io::Result<PeerCredential> {
    use std::mem::size_of;
    use std::os::fd::AsRawFd;
    let fd = stream.as_raw_fd();

    let mut uid: libc::uid_t = 0;
    let mut gid: libc::gid_t = 0;
    // SAFETY: `fd` is owned by `stream` (a `&UnixStream`) and is valid for the
    // entirety of this call. `&mut uid` and `&mut gid` point to local stack
    // variables sized exactly as `getpeereid` requires (`uid_t` and `gid_t`
    // respectively); the kernel writes one value to each and never reads them.
    let rc_uid = unsafe { libc::getpeereid(fd, &mut uid, &mut gid) };
    if rc_uid != 0 {
        // Capture errno before any other libc call can overwrite it.
        return Err(std::io::Error::last_os_error());
    }

    let mut pid: libc::pid_t = 0;
    let mut len = size_of::<libc::pid_t>() as libc::socklen_t;
    // SAFETY: same `fd` lifetime invariant as above. The `optval` pointer
    // addresses the local `pid` (a `libc::pid_t`), and `len` is initialized to
    // `size_of::<libc::pid_t>()` so the kernel will write at most that many
    // bytes per the LOCAL_PEERPID contract on Darwin.
    let rc_pid = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_LOCAL,
            libc::LOCAL_PEERPID,
            &mut pid as *mut _ as *mut libc::c_void,
            &mut len,
        )
    };
    if rc_pid != 0 {
        return Err(std::io::Error::last_os_error());
    }

    Ok(PeerCredential {
        pid: pid as u32,
        uid: uid as u32,
    })
}
807
/// Peer-credential lookup for Unix targets other than Linux and macOS.
///
/// No lookup is implemented here, so this always fails with
/// `ErrorKind::Unsupported`.
#[cfg(all(unix, not(any(target_os = "linux", target_os = "macos"))))]
fn unix_peer_credential(
    _stream: &std::os::unix::net::UnixStream,
) -> std::io::Result<PeerCredential> {
    let kind = std::io::ErrorKind::Unsupported;
    Err(std::io::Error::new(kind, "peer credentials unsupported on this platform"))
}
817
818#[cfg(unix)]
819fn unix_peer_pid(stream: &std::os::unix::net::UnixStream) -> Result<u32> {
820    unix_peer_credential(stream)
821        .map(|c| c.pid)
822        .map_err(Into::into)
823}
824
#[cfg(test)]
mod tests {
    use super::*;

    /// Session token shared by every request in these tests.
    const TOKEN: &str = "token-123";

    /// Deadline `secs` seconds in the future.
    fn in_secs(secs: u64) -> Instant {
        Instant::now() + Duration::from_secs(secs)
    }

    /// Push `req` through `handle_connection` as one JSON line.
    ///
    /// Returns the decoded response and whether the stop flag was raised.
    fn run_request(
        req: AgentRequest,
        peer_pid: Option<u32>,
        absolute_deadline: Instant,
    ) -> (AgentResponse, bool) {
        let request_line = format!("{}\n", serde_json::to_string(&req).unwrap());
        let mut input = std::io::Cursor::new(request_line);
        let mut output = Vec::new();
        let stop = Arc::new(AtomicBool::new(false));
        let password = Password("secret".to_string());

        // Give the idle timer the same window as the absolute deadline so idle
        // expiry never fires first.
        let idle_secs = absolute_deadline
            .checked_duration_since(Instant::now())
            .map_or(1, |remaining| remaining.as_secs().max(1));
        let mut session = AgentSession::new(TOKEN, idle_secs, absolute_deadline);

        handle_connection(
            &mut input,
            &mut output,
            &password,
            &mut session,
            peer_pid,
            &stop,
        )
        .unwrap();

        let response: AgentResponse = serde_json::from_slice(&output).unwrap();
        (response, stop.load(Ordering::Relaxed))
    }

    #[test]
    fn open_vault_allows_matching_peer_pid() {
        let request = AgentRequest::OpenVault {
            profile: "default".into(),
            session_token: TOKEN.into(),
            requesting_pid: 4242,
        };
        let (response, stopped) = run_request(request, Some(4242), in_secs(60));

        assert!(!stopped);
        match response {
            AgentResponse::Password { password } => assert_eq!(password, "secret"),
            unexpected => panic!("expected password response, got {unexpected:?}"),
        }
    }

    #[test]
    fn open_vault_rejects_pid_mismatch() {
        let request = AgentRequest::OpenVault {
            profile: "default".into(),
            session_token: TOKEN.into(),
            requesting_pid: 4242,
        };
        let (response, stopped) = run_request(request, Some(9001), in_secs(60));

        assert!(!stopped);
        match response {
            AgentResponse::Err { reason } => {
                assert!(reason.contains("does not match the connecting process"));
            }
            unexpected => panic!("expected authorization error, got {unexpected:?}"),
        }
    }

    #[test]
    fn expired_session_rejects_requests_and_stops() {
        let already_past = Instant::now() - Duration::from_secs(1);
        let (response, stopped) = run_request(AgentRequest::Ping, Some(4242), already_past);

        assert!(stopped);
        match response {
            AgentResponse::Err { reason } => assert!(
                reason.contains("agent session expired"),
                "unexpected: {reason}"
            ),
            unexpected => panic!("expected expiry error, got {unexpected:?}"),
        }
    }

    #[test]
    fn lock_request_transitions_session_and_stops() {
        let request = AgentRequest::Lock {
            session_token: TOKEN.into(),
        };
        let (response, stopped) = run_request(request, Some(4242), in_secs(60));

        assert!(stopped);
        assert!(matches!(response, AgentResponse::Ok));
    }
}