ssh_agent_mux/
lib.rs

1use std::{
2    collections::HashMap,
3    os::unix::net::UnixStream,
4    path::{Path, PathBuf},
5    sync::Arc,
6};
7
8use ssh_agent_lib::{
9    agent::{self, Agent, ListeningSocket, Session},
10    client,
11    error::AgentError,
12    proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13    ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16    net::UnixListener,
17    sync::{Mutex, OwnedMutexGuard},
18};
19
20type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
21type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23/// Only the `request_identities`, `sign`, and `extension` commands are implemented. For
24/// `extension`, only the `session-bind@openssh.com` and `query` extensions are supported.
25#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28        let mut known_keys = self.known_keys.clone().lock_owned().await;
29        self.refresh_identities(&mut known_keys).await
30    }
31
32    async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
33        if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
34            log::info!(
35                "Request: signature with key {} from upstream agent <{}>",
36                request.pubkey.fingerprint(Default::default()),
37                agent_sock_path.display()
38            );
39
40            let mut client = self.connect_upstream_agent(agent_sock_path)?;
41            client.sign(request).await
42        } else {
43            let fingerprint = request.pubkey.fingerprint(Default::default());
44            log::error!("No upstream agent found for public key {}", &fingerprint);
45            log::trace!("Known keys: {:?}", self.known_keys);
46            Err(AgentError::Other(
47                format!("No agent found for public key: {}", &fingerprint).into(),
48            ))
49        }
50    }
51
52    async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
53        match request.name.as_str() {
54            "query" => Ok(Some(Extension::new_message(QueryResponse {
55                extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
56            })?)),
57            "session-bind@openssh.com" => {
58                let mut session_bind_suceeded = false;
59                for sock_path in &self.socket_paths {
60                    // Try extension on upstream agents; discard any upstream failures from agents
61                    // that don't support the extension (but the default is Failure if there are no
62                    // successful upstream responses)
63                    if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
64                        match client.extension(request.clone()).await {
65                            // Any agent succeeding is an overall success
66                            Ok(v) => {
67                                session_bind_suceeded = true;
68                                if v.is_some() {
69                                    log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
70                                }
71                            }
72                            // Don't propagate upstream lack of extension support
73                            Err(AgentError::Failure) => continue,
74                            // Report but ignore any unexpected errors
75                            Err(e) => {
76                                log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
77                                continue;
78                            }
79                        }
80                    }
81                }
82                if session_bind_suceeded {
83                    Ok(None)
84                } else {
85                    Err(AgentError::Failure)
86                }
87            }
88            _ => Err(AgentError::Failure),
89        }
90    }
91}
92
93#[derive(Clone)]
94pub struct MuxAgent {
95    socket_paths: Vec<PathBuf>,
96    known_keys: KnownPubKeys,
97}
98
99impl MuxAgent {
100    /// Run a MuxAgent, listening for SSH agent protocol requests on `listen_sock`, forwarding
101    /// requests to the specified paths in `agent_socks`
102    pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
103    where
104        I: IntoIterator<Item = P>,
105        P: AsRef<Path>,
106    {
107        let socket_paths: Vec<_> = agent_socks
108            .into_iter()
109            .map(|p| p.as_ref().to_path_buf())
110            .collect();
111        log::info!(
112            "Starting agent for {} upstream agents; listening on <{}>",
113            socket_paths.len(),
114            listen_sock.as_ref().display()
115        );
116        log::debug!("Upstream agent sockets: {:?}", &socket_paths);
117
118        let listen_sock = SelfDeletingUnixListener::bind(listen_sock)?;
119        let this = Self {
120            socket_paths,
121            known_keys: Default::default(),
122        };
123        agent::listen(listen_sock, this).await
124    }
125
126    fn connect_upstream_agent(
127        &self,
128        sock_path: impl AsRef<Path>,
129    ) -> Result<Box<dyn Session>, AgentError> {
130        let sock_path = sock_path.as_ref();
131        let stream = UnixStream::connect(sock_path)?;
132        let client = client::connect(stream.into())
133            .map_err(|e| AgentError::Other(format!("Failed to connect to agent: {e}").into()))?;
134        log::trace!(
135            "Connected to upstream agent on socket: {}",
136            sock_path.display()
137        );
138        Ok(client)
139    }
140
141    async fn get_agent_sock_for_pubkey(
142        &mut self,
143        pubkey: &PubKeyData,
144    ) -> Result<Option<PathBuf>, AgentError> {
145        // Refresh available identities if the public key isn't found;
146        // hold lock for duration of signing operation
147        let mut known_keys = self.known_keys.clone().lock_owned().await;
148        if !known_keys.contains_key(pubkey) {
149            log::debug!("Key not found, re-requesting keys from upstream agents");
150            let _ = self.refresh_identities(&mut known_keys).await?;
151        }
152        let maybe_agent = known_keys.get(pubkey).cloned();
153        Ok(maybe_agent)
154    }
155
156    // Factored out so that the known_keys lock can be held across a total request that includes a
157    // refresh of keys from upstream agents
158    async fn refresh_identities(
159        &mut self,
160        known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
161    ) -> Result<Vec<Identity>, AgentError> {
162        let mut identities = vec![];
163        known_keys.clear();
164
165        for sock_path in &self.socket_paths {
166            let mut client = match self.connect_upstream_agent(sock_path) {
167                Ok(c) => c,
168                Err(_) => {
169                    log::warn!(
170                        "Ignoring missing upstream agent socket: {}",
171                        sock_path.display()
172                    );
173                    continue;
174                }
175            };
176            let agent_identities = client.request_identities().await?;
177            {
178                for id in &agent_identities {
179                    known_keys.insert(id.pubkey.clone(), sock_path.clone());
180                }
181            }
182            identities.extend(agent_identities);
183        }
184
185        Ok(identities)
186    }
187}
188
189impl Agent<SelfDeletingUnixListener> for MuxAgent {
190    #[doc = "Create new session object when a new socket is accepted."]
191    fn new_session(
192        &mut self,
193        _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
194    ) -> impl Session {
195        self.clone()
196    }
197}
198
199#[derive(Debug)]
200/// A wrapper for UnixListener that keeps the socket path around so it can be deleted
201struct SelfDeletingUnixListener {
202    path: PathBuf,
203    listener: UnixListener,
204}
205
206impl SelfDeletingUnixListener {
207    fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
208        let path = path.as_ref().to_path_buf();
209        UnixListener::bind(&path).map(|listener| Self { path, listener })
210    }
211}
212
213impl Drop for SelfDeletingUnixListener {
214    fn drop(&mut self) {
215        let _ = std::fs::remove_file(&self.path);
216    }
217}
218
219#[ssh_agent_lib::async_trait]
220impl ListeningSocket for SelfDeletingUnixListener {
221    type Stream = tokio::net::UnixStream;
222
223    async fn accept(&mut self) -> std::io::Result<Self::Stream> {
224        UnixListener::accept(&self.listener)
225            .await
226            .map(|(s, _addr)| s)
227    }
228}