ssh_agent_mux/
lib.rs

1use std::{
2    collections::HashMap,
3    os::unix::net::UnixStream,
4    path::{Path, PathBuf},
5    sync::Arc,
6};
7
8use ssh_agent_lib::{
9    agent::{self, Agent, ListeningSocket, Session},
10    client,
11    error::AgentError,
12    proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13    ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16    net::UnixListener,
17    sync::{Mutex, OwnedMutexGuard},
18};
19
/// Maps each public key offered by an upstream agent to the socket path of that agent.
type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
/// Shared handle to the key cache; the async mutex lets the lock be held across `.await`s.
type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23/// Only the `request_identities`, `sign`, and `extension` commands are implemented. For
24/// `extension`, only the `session-bind@openssh.com` and `query` extensions are supported.
25#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28        let mut known_keys = self.known_keys.clone().lock_owned().await;
29        self.refresh_identities(&mut known_keys).await
30    }
31
32    async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
33        if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
34            log::info!(
35                "Request: signature with key {} from upstream agent <{}>",
36                request.pubkey.fingerprint(Default::default()),
37                agent_sock_path.display()
38            );
39
40            let mut client = self.connect_upstream_agent(agent_sock_path)?;
41            client.sign(request).await
42        } else {
43            let fingerprint = request.pubkey.fingerprint(Default::default());
44            log::error!("No upstream agent found for public key {}", &fingerprint);
45            log::trace!("Known keys: {:?}", self.known_keys);
46            Err(AgentError::Other(
47                format!("No agent found for public key: {}", &fingerprint).into(),
48            ))
49        }
50    }
51
52    async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
53        match request.name.as_str() {
54            "query" => Ok(Some(Extension::new_message(QueryResponse {
55                extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
56            })?)),
57            "session-bind@openssh.com" => {
58                let mut session_bind_suceeded = false;
59                for sock_path in &self.socket_paths {
60                    // Try extension on upstream agents; discard any upstream failures from agents
61                    // that don't support the extension (but the default is Failure if there are no
62                    // successful upstream responses)
63                    if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
64                        match client.extension(request.clone()).await {
65                            // Any agent succeeding is an overall success
66                            Ok(v) => {
67                                session_bind_suceeded = true;
68                                if v.is_some() {
69                                    log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
70                                }
71                            }
72                            // Don't propagate upstream lack of extension support
73                            Err(AgentError::Failure) => continue,
74                            // Report but ignore any unexpected errors
75                            Err(e) => {
76                                log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
77                                continue;
78                            }
79                        }
80                    }
81                }
82                if session_bind_suceeded {
83                    Ok(None)
84                } else {
85                    Err(AgentError::Failure)
86                }
87            }
88            _ => Err(AgentError::Failure),
89        }
90    }
91}
92
93#[derive(Clone)]
94pub struct MuxAgent {
95    socket_paths: Vec<PathBuf>,
96    known_keys: KnownPubKeys,
97}
98
99impl MuxAgent {
100    /// Run a MuxAgent, listening for SSH agent protocol requests on `listen_sock`, forwarding
101    /// requests to the specified paths in `agent_socks`
102    pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
103    where
104        I: IntoIterator<Item = P>,
105        P: AsRef<Path>,
106    {
107        let listen_sock = listen_sock.as_ref();
108        let socket_paths: Vec<_> = agent_socks
109            .into_iter()
110            .map(|p| p.as_ref().to_path_buf())
111            .collect();
112        if socket_paths.is_empty() {
113            log::warn!("Mux agent running but no upstream agents configured");
114        }
115        log::info!(
116            "Starting agent for {} upstream agents; listening on <{}>",
117            socket_paths.len(),
118            listen_sock.display()
119        );
120        log::debug!("Upstream agent sockets: {:?}", &socket_paths);
121
122        let listen_sock = match SelfDeletingUnixListener::bind(listen_sock) {
123            Ok(s) => s,
124            err => {
125                log::error!("Failed to open listening socket at {}", listen_sock.display());
126                err?
127            }
128        };
129        let this = Self {
130            socket_paths,
131            known_keys: Default::default(),
132        };
133        agent::listen(listen_sock, this).await
134    }
135
136    fn connect_upstream_agent(
137        &self,
138        sock_path: impl AsRef<Path>,
139    ) -> Result<Box<dyn Session>, AgentError> {
140        let sock_path = sock_path.as_ref();
141        let stream = UnixStream::connect(sock_path)?;
142        let client = client::connect(stream.into())
143            .map_err(|e| AgentError::Other(format!("Failed to connect to agent at {}: {}", sock_path.display(), e).into()))?;
144        log::trace!(
145            "Connected to upstream agent on socket: {}",
146            sock_path.display()
147        );
148        Ok(client)
149    }
150
151    async fn get_agent_sock_for_pubkey(
152        &mut self,
153        pubkey: &PubKeyData,
154    ) -> Result<Option<PathBuf>, AgentError> {
155        // Refresh available identities if the public key isn't found;
156        // hold lock for duration of signing operation
157        let mut known_keys = self.known_keys.clone().lock_owned().await;
158        if !known_keys.contains_key(pubkey) {
159            log::debug!("Key not found, re-requesting keys from upstream agents");
160            let _ = self.refresh_identities(&mut known_keys).await?;
161        }
162        let maybe_agent = known_keys.get(pubkey).cloned();
163        Ok(maybe_agent)
164    }
165
166    // Factored out so that the known_keys lock can be held across a total request that includes a
167    // refresh of keys from upstream agents
168    async fn refresh_identities(
169        &mut self,
170        known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
171    ) -> Result<Vec<Identity>, AgentError> {
172        let mut identities = vec![];
173        known_keys.clear();
174
175        for sock_path in &self.socket_paths {
176            let mut client = match self.connect_upstream_agent(sock_path) {
177                Ok(c) => c,
178                Err(_) => {
179                    log::warn!(
180                        "Ignoring missing upstream agent socket: {}",
181                        sock_path.display()
182                    );
183                    continue;
184                }
185            };
186            let agent_identities = client.request_identities().await?;
187            {
188                for id in &agent_identities {
189                    known_keys.insert(id.pubkey.clone(), sock_path.clone());
190                }
191            }
192            identities.extend(agent_identities);
193        }
194
195        Ok(identities)
196    }
197}
198
199impl Agent<SelfDeletingUnixListener> for MuxAgent {
200    #[doc = "Create new session object when a new socket is accepted."]
201    fn new_session(
202        &mut self,
203        _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
204    ) -> impl Session {
205        self.clone()
206    }
207}
208
209#[derive(Debug)]
210/// A wrapper for UnixListener that keeps the socket path around so it can be deleted
211struct SelfDeletingUnixListener {
212    path: PathBuf,
213    listener: UnixListener,
214}
215
216impl SelfDeletingUnixListener {
217    fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
218        let path = path.as_ref().to_path_buf();
219        UnixListener::bind(&path).map(|listener| Self { path, listener })
220    }
221}
222
223impl Drop for SelfDeletingUnixListener {
224    fn drop(&mut self) {
225        let _ = std::fs::remove_file(&self.path);
226    }
227}
228
229#[ssh_agent_lib::async_trait]
230impl ListeningSocket for SelfDeletingUnixListener {
231    type Stream = tokio::net::UnixStream;
232
233    async fn accept(&mut self) -> std::io::Result<Self::Stream> {
234        UnixListener::accept(&self.listener)
235            .await
236            .map(|(s, _addr)| s)
237    }
238}