// ssh_agent_mux/lib.rs

1use std::{
2    collections::HashMap,
3    os::unix::net::UnixStream,
4    path::{Path, PathBuf},
5    sync::Arc,
6};
7
8use ssh_agent_lib::{
9    agent::{self, Agent, ListeningSocket, Session},
10    client,
11    error::AgentError,
12    proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13    ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16    net::UnixListener,
17    sync::{Mutex, OwnedMutexGuard},
18};
19
20type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
21type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23/// Only the `request_identities`, `sign`, and `extension` commands are implemented. For
24/// `extension`, only the `session-bind@openssh.com` and `query` extensions are supported.
25#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28        log::trace!("incoming: request_identities");
29        let mut known_keys = self.known_keys.clone().lock_owned().await;
30        self.refresh_identities(&mut known_keys).await
31    }
32
33    async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
34        let fingerprint = request.pubkey.fingerprint(Default::default());
35        log::trace!("incoming: sign({})", &fingerprint);
36
37        if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
38            log::info!(
39                "Requesting signature with key {} from upstream agent <{}>",
40                &fingerprint,
41                agent_sock_path.display()
42            );
43
44            let mut client = self.connect_upstream_agent(agent_sock_path)?;
45            client.sign(request).await
46        } else {
47            log::error!("No upstream agent found for public key {}", &fingerprint);
48            log::trace!("Known keys:\n{:#?}", self.known_keys);
49            Err(AgentError::Other(
50                format!("No agent found for public key: {}", &fingerprint).into(),
51            ))
52        }
53    }
54
55    async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
56        log::trace!("incoming: extension({})", request.name);
57        match request.name.as_str() {
58            "query" => Ok(Some(Extension::new_message(QueryResponse {
59                extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
60            })?)),
61            "session-bind@openssh.com" => {
62                let mut session_bind_suceeded = false;
63                for sock_path in &self.socket_paths {
64                    // Try extension on upstream agents; discard any upstream failures from agents
65                    // that don't support the extension (but the default is Failure if there are no
66                    // successful upstream responses)
67                    if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
68                        match client.extension(request.clone()).await {
69                            // Any agent succeeding is an overall success
70                            Ok(v) => {
71                                session_bind_suceeded = true;
72                                if v.is_some() {
73                                    log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
74                                }
75                            }
76                            // Don't propagate upstream lack of extension support
77                            Err(AgentError::Failure) => continue,
78                            // Report but ignore any unexpected errors
79                            Err(e) => {
80                                log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
81                                continue;
82                            }
83                        }
84                    }
85                }
86                if session_bind_suceeded {
87                    Ok(None)
88                } else {
89                    Err(AgentError::Failure)
90                }
91            }
92            _ => Err(AgentError::Failure),
93        }
94    }
95}
96
97#[derive(Clone)]
98pub struct MuxAgent {
99    socket_paths: Vec<PathBuf>,
100    known_keys: KnownPubKeys,
101}
102
103impl MuxAgent {
104    /// Run a MuxAgent, listening for SSH agent protocol requests on `listen_sock`, forwarding
105    /// requests to the specified paths in `agent_socks`
106    pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
107    where
108        I: IntoIterator<Item = P>,
109        P: AsRef<Path>,
110    {
111        let listen_sock = listen_sock.as_ref();
112        let socket_paths: Vec<_> = agent_socks
113            .into_iter()
114            .map(|p| p.as_ref().to_path_buf())
115            .collect();
116        if socket_paths.is_empty() {
117            log::warn!("Mux agent running but no upstream agents configured");
118        }
119        log::info!(
120            "Starting agent for {} upstream agents; listening on <{}>",
121            socket_paths.len(),
122            listen_sock.display()
123        );
124        log::debug!("Upstream agent sockets: {:?}", &socket_paths);
125
126        let listen_sock = match SelfDeletingUnixListener::bind(listen_sock) {
127            Ok(s) => s,
128            err => {
129                log::error!(
130                    "Failed to open listening socket at {}",
131                    listen_sock.display()
132                );
133                err?
134            }
135        };
136        let this = Self {
137            socket_paths,
138            known_keys: Default::default(),
139        };
140        agent::listen(listen_sock, this).await
141    }
142
143    fn connect_upstream_agent(
144        &self,
145        sock_path: impl AsRef<Path>,
146    ) -> Result<Box<dyn Session>, AgentError> {
147        let sock_path = sock_path.as_ref();
148        let stream = UnixStream::connect(sock_path)?;
149        let client = client::connect(stream.into()).map_err(|e| {
150            AgentError::Other(
151                format!(
152                    "Failed to connect to agent at {}: {}",
153                    sock_path.display(),
154                    e
155                )
156                .into(),
157            )
158        })?;
159        log::trace!(
160            "Connected to upstream agent on socket: {}",
161            sock_path.display()
162        );
163        Ok(client)
164    }
165
166    async fn get_agent_sock_for_pubkey(
167        &mut self,
168        pubkey: &PubKeyData,
169    ) -> Result<Option<PathBuf>, AgentError> {
170        // Refresh available identities if the public key isn't found;
171        // hold lock for duration of signing operation
172        let mut known_keys = self.known_keys.clone().lock_owned().await;
173        if !known_keys.contains_key(pubkey) {
174            log::debug!("Key not found, re-requesting keys from upstream agents");
175            let _ = self.refresh_identities(&mut known_keys).await?;
176        }
177        let maybe_agent = known_keys.get(pubkey).cloned();
178        Ok(maybe_agent)
179    }
180
181    // Factored out so that the known_keys lock can be held across a total request that includes a
182    // refresh of keys from upstream agents
183    async fn refresh_identities(
184        &mut self,
185        known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
186    ) -> Result<Vec<Identity>, AgentError> {
187        let mut identities = vec![];
188        known_keys.clear();
189
190        log::debug!("Refreshing identities");
191        for sock_path in &self.socket_paths {
192            let mut client = match self.connect_upstream_agent(sock_path) {
193                Ok(c) => c,
194                Err(_) => {
195                    log::warn!(
196                        "Ignoring missing upstream agent socket: {}",
197                        sock_path.display()
198                    );
199                    continue;
200                }
201            };
202            let agent_identities = client.request_identities().await?;
203            {
204                for id in &agent_identities {
205                    known_keys.insert(id.pubkey.clone(), sock_path.clone());
206                }
207            }
208            log::trace!("Got {} identities from {}", agent_identities.len(), sock_path.display());
209            identities.extend(agent_identities);
210        }
211
212        Ok(identities)
213    }
214}
215
216impl Agent<SelfDeletingUnixListener> for MuxAgent {
217    #[doc = "Create new session object when a new socket is accepted."]
218    fn new_session(
219        &mut self,
220        _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
221    ) -> impl Session {
222        self.clone()
223    }
224}
225
226#[derive(Debug)]
227/// A wrapper for UnixListener that keeps the socket path around so it can be deleted
228struct SelfDeletingUnixListener {
229    path: PathBuf,
230    listener: UnixListener,
231}
232
233impl SelfDeletingUnixListener {
234    fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
235        let path = path.as_ref().to_path_buf();
236        UnixListener::bind(&path).map(|listener| Self { path, listener })
237    }
238}
239
240impl Drop for SelfDeletingUnixListener {
241    fn drop(&mut self) {
242        log::debug!("Cleaning up socket {}", self.path.display());
243        let _ = std::fs::remove_file(&self.path);
244    }
245}
246
247#[ssh_agent_lib::async_trait]
248impl ListeningSocket for SelfDeletingUnixListener {
249    type Stream = tokio::net::UnixStream;
250
251    async fn accept(&mut self) -> std::io::Result<Self::Stream> {
252        UnixListener::accept(&self.listener)
253            .await
254            .map(|(s, _addr)| s)
255    }
256}