1use std::{
2 collections::HashMap,
3 os::unix::net::UnixStream,
4 path::{Path, PathBuf},
5 sync::Arc,
6};
7
8use ssh_agent_lib::{
9 agent::{self, Agent, ListeningSocket, Session},
10 client,
11 error::AgentError,
12 proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13 ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16 net::UnixListener,
17 sync::{Mutex, OwnedMutexGuard},
18};
19
20type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
21type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27 async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28 log::trace!("incoming: request_identities");
29 let mut known_keys = self.known_keys.clone().lock_owned().await;
30 self.refresh_identities(&mut known_keys).await
31 }
32
33 async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
34 let fingerprint = request.pubkey.fingerprint(Default::default());
35 log::trace!("incoming: sign({})", &fingerprint);
36
37 if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
38 log::info!(
39 "Requesting signature with key {} from upstream agent <{}>",
40 &fingerprint,
41 agent_sock_path.display()
42 );
43
44 let mut client = self.connect_upstream_agent(agent_sock_path)?;
45 client.sign(request).await
46 } else {
47 log::error!("No upstream agent found for public key {}", &fingerprint);
48 log::trace!("Known keys:\n{:#?}", self.known_keys);
49 Err(AgentError::Other(
50 format!("No agent found for public key: {}", &fingerprint).into(),
51 ))
52 }
53 }
54
55 async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
56 log::trace!("incoming: extension({})", request.name);
57 match request.name.as_str() {
58 "query" => Ok(Some(Extension::new_message(QueryResponse {
59 extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
60 })?)),
61 "session-bind@openssh.com" => {
62 let mut session_bind_suceeded = false;
63 for sock_path in &self.socket_paths {
64 if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
68 match client.extension(request.clone()).await {
69 Ok(v) => {
71 session_bind_suceeded = true;
72 if v.is_some() {
73 log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
74 }
75 }
76 Err(AgentError::Failure) => continue,
78 Err(e) => {
80 log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
81 continue;
82 }
83 }
84 }
85 }
86 if session_bind_suceeded {
87 Ok(None)
88 } else {
89 Err(AgentError::Failure)
90 }
91 }
92 _ => Err(AgentError::Failure),
93 }
94 }
95}
96
97#[derive(Clone)]
98pub struct MuxAgent {
99 socket_paths: Vec<PathBuf>,
100 known_keys: KnownPubKeys,
101}
102
103impl MuxAgent {
104 pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
107 where
108 I: IntoIterator<Item = P>,
109 P: AsRef<Path>,
110 {
111 let listen_sock = listen_sock.as_ref();
112 let socket_paths: Vec<_> = agent_socks
113 .into_iter()
114 .map(|p| p.as_ref().to_path_buf())
115 .collect();
116 if socket_paths.is_empty() {
117 log::warn!("Mux agent running but no upstream agents configured");
118 }
119 log::info!(
120 "Starting agent for {} upstream agents; listening on <{}>",
121 socket_paths.len(),
122 listen_sock.display()
123 );
124 log::debug!("Upstream agent sockets: {:?}", &socket_paths);
125
126 let listen_sock = match SelfDeletingUnixListener::bind(listen_sock) {
127 Ok(s) => s,
128 err => {
129 log::error!(
130 "Failed to open listening socket at {}",
131 listen_sock.display()
132 );
133 err?
134 }
135 };
136 let this = Self {
137 socket_paths,
138 known_keys: Default::default(),
139 };
140 agent::listen(listen_sock, this).await
141 }
142
143 fn connect_upstream_agent(
144 &self,
145 sock_path: impl AsRef<Path>,
146 ) -> Result<Box<dyn Session>, AgentError> {
147 let sock_path = sock_path.as_ref();
148 let stream = UnixStream::connect(sock_path)?;
149 let client = client::connect(stream.into()).map_err(|e| {
150 AgentError::Other(
151 format!(
152 "Failed to connect to agent at {}: {}",
153 sock_path.display(),
154 e
155 )
156 .into(),
157 )
158 })?;
159 log::trace!(
160 "Connected to upstream agent on socket: {}",
161 sock_path.display()
162 );
163 Ok(client)
164 }
165
166 async fn get_agent_sock_for_pubkey(
167 &mut self,
168 pubkey: &PubKeyData,
169 ) -> Result<Option<PathBuf>, AgentError> {
170 let mut known_keys = self.known_keys.clone().lock_owned().await;
173 if !known_keys.contains_key(pubkey) {
174 log::debug!("Key not found, re-requesting keys from upstream agents");
175 let _ = self.refresh_identities(&mut known_keys).await?;
176 }
177 let maybe_agent = known_keys.get(pubkey).cloned();
178 Ok(maybe_agent)
179 }
180
181 async fn refresh_identities(
184 &mut self,
185 known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
186 ) -> Result<Vec<Identity>, AgentError> {
187 let mut identities = vec![];
188 known_keys.clear();
189
190 log::debug!("Refreshing identities");
191 for sock_path in &self.socket_paths {
192 let mut client = match self.connect_upstream_agent(sock_path) {
193 Ok(c) => c,
194 Err(_) => {
195 log::warn!(
196 "Ignoring missing upstream agent socket: {}",
197 sock_path.display()
198 );
199 continue;
200 }
201 };
202 let agent_identities = client.request_identities().await?;
203 {
204 for id in &agent_identities {
205 known_keys.insert(id.pubkey.clone(), sock_path.clone());
206 }
207 }
208 log::trace!("Got {} identities from {}", agent_identities.len(), sock_path.display());
209 identities.extend(agent_identities);
210 }
211
212 Ok(identities)
213 }
214}
215
216impl Agent<SelfDeletingUnixListener> for MuxAgent {
217 #[doc = "Create new session object when a new socket is accepted."]
218 fn new_session(
219 &mut self,
220 _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
221 ) -> impl Session {
222 self.clone()
223 }
224}
225
226#[derive(Debug)]
227struct SelfDeletingUnixListener {
229 path: PathBuf,
230 listener: UnixListener,
231}
232
233impl SelfDeletingUnixListener {
234 fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
235 let path = path.as_ref().to_path_buf();
236 UnixListener::bind(&path).map(|listener| Self { path, listener })
237 }
238}
239
240impl Drop for SelfDeletingUnixListener {
241 fn drop(&mut self) {
242 log::debug!("Cleaning up socket {}", self.path.display());
243 let _ = std::fs::remove_file(&self.path);
244 }
245}
246
247#[ssh_agent_lib::async_trait]
248impl ListeningSocket for SelfDeletingUnixListener {
249 type Stream = tokio::net::UnixStream;
250
251 async fn accept(&mut self) -> std::io::Result<Self::Stream> {
252 UnixListener::accept(&self.listener)
253 .await
254 .map(|(s, _addr)| s)
255 }
256}