1use std::{
2 collections::HashMap,
3 os::unix::net::UnixStream,
4 path::{Path, PathBuf},
5 sync::Arc,
6};
7
8use ssh_agent_lib::{
9 agent::{self, Agent, ListeningSocket, Session},
10 client,
11 error::AgentError,
12 proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13 ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16 net::UnixListener,
17 sync::{Mutex, OwnedMutexGuard},
18};
19
20type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
21type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27 async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28 let mut known_keys = self.known_keys.clone().lock_owned().await;
29 self.refresh_identities(&mut known_keys).await
30 }
31
32 async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
33 if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
34 log::info!(
35 "Request: signature with key {} from upstream agent <{}>",
36 request.pubkey.fingerprint(Default::default()),
37 agent_sock_path.display()
38 );
39
40 let mut client = self.connect_upstream_agent(agent_sock_path)?;
41 client.sign(request).await
42 } else {
43 let fingerprint = request.pubkey.fingerprint(Default::default());
44 log::error!("No upstream agent found for public key {}", &fingerprint);
45 log::trace!("Known keys: {:?}", self.known_keys);
46 Err(AgentError::Other(
47 format!("No agent found for public key: {}", &fingerprint).into(),
48 ))
49 }
50 }
51
52 async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
53 match request.name.as_str() {
54 "query" => Ok(Some(Extension::new_message(QueryResponse {
55 extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
56 })?)),
57 "session-bind@openssh.com" => {
58 let mut session_bind_suceeded = false;
59 for sock_path in &self.socket_paths {
60 if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
64 match client.extension(request.clone()).await {
65 Ok(v) => {
67 session_bind_suceeded = true;
68 if v.is_some() {
69 log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
70 }
71 }
72 Err(AgentError::Failure) => continue,
74 Err(e) => {
76 log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
77 continue;
78 }
79 }
80 }
81 }
82 if session_bind_suceeded {
83 Ok(None)
84 } else {
85 Err(AgentError::Failure)
86 }
87 }
88 _ => Err(AgentError::Failure),
89 }
90 }
91}
92
93#[derive(Clone)]
94pub struct MuxAgent {
95 socket_paths: Vec<PathBuf>,
96 known_keys: KnownPubKeys,
97}
98
99impl MuxAgent {
100 pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
103 where
104 I: IntoIterator<Item = P>,
105 P: AsRef<Path>,
106 {
107 let listen_sock = listen_sock.as_ref();
108 let socket_paths: Vec<_> = agent_socks
109 .into_iter()
110 .map(|p| p.as_ref().to_path_buf())
111 .collect();
112 if socket_paths.is_empty() {
113 log::warn!("Mux agent running but no upstream agents configured");
114 }
115 log::info!(
116 "Starting agent for {} upstream agents; listening on <{}>",
117 socket_paths.len(),
118 listen_sock.display()
119 );
120 log::debug!("Upstream agent sockets: {:?}", &socket_paths);
121
122 let listen_sock = match SelfDeletingUnixListener::bind(listen_sock) {
123 Ok(s) => s,
124 err => {
125 log::error!("Failed to open listening socket at {}", listen_sock.display());
126 err?
127 }
128 };
129 let this = Self {
130 socket_paths,
131 known_keys: Default::default(),
132 };
133 agent::listen(listen_sock, this).await
134 }
135
136 fn connect_upstream_agent(
137 &self,
138 sock_path: impl AsRef<Path>,
139 ) -> Result<Box<dyn Session>, AgentError> {
140 let sock_path = sock_path.as_ref();
141 let stream = UnixStream::connect(sock_path)?;
142 let client = client::connect(stream.into())
143 .map_err(|e| AgentError::Other(format!("Failed to connect to agent at {}: {}", sock_path.display(), e).into()))?;
144 log::trace!(
145 "Connected to upstream agent on socket: {}",
146 sock_path.display()
147 );
148 Ok(client)
149 }
150
151 async fn get_agent_sock_for_pubkey(
152 &mut self,
153 pubkey: &PubKeyData,
154 ) -> Result<Option<PathBuf>, AgentError> {
155 let mut known_keys = self.known_keys.clone().lock_owned().await;
158 if !known_keys.contains_key(pubkey) {
159 log::debug!("Key not found, re-requesting keys from upstream agents");
160 let _ = self.refresh_identities(&mut known_keys).await?;
161 }
162 let maybe_agent = known_keys.get(pubkey).cloned();
163 Ok(maybe_agent)
164 }
165
166 async fn refresh_identities(
169 &mut self,
170 known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
171 ) -> Result<Vec<Identity>, AgentError> {
172 let mut identities = vec![];
173 known_keys.clear();
174
175 for sock_path in &self.socket_paths {
176 let mut client = match self.connect_upstream_agent(sock_path) {
177 Ok(c) => c,
178 Err(_) => {
179 log::warn!(
180 "Ignoring missing upstream agent socket: {}",
181 sock_path.display()
182 );
183 continue;
184 }
185 };
186 let agent_identities = client.request_identities().await?;
187 {
188 for id in &agent_identities {
189 known_keys.insert(id.pubkey.clone(), sock_path.clone());
190 }
191 }
192 identities.extend(agent_identities);
193 }
194
195 Ok(identities)
196 }
197}
198
199impl Agent<SelfDeletingUnixListener> for MuxAgent {
200 #[doc = "Create new session object when a new socket is accepted."]
201 fn new_session(
202 &mut self,
203 _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
204 ) -> impl Session {
205 self.clone()
206 }
207}
208
209#[derive(Debug)]
210struct SelfDeletingUnixListener {
212 path: PathBuf,
213 listener: UnixListener,
214}
215
216impl SelfDeletingUnixListener {
217 fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
218 let path = path.as_ref().to_path_buf();
219 UnixListener::bind(&path).map(|listener| Self { path, listener })
220 }
221}
222
223impl Drop for SelfDeletingUnixListener {
224 fn drop(&mut self) {
225 let _ = std::fs::remove_file(&self.path);
226 }
227}
228
229#[ssh_agent_lib::async_trait]
230impl ListeningSocket for SelfDeletingUnixListener {
231 type Stream = tokio::net::UnixStream;
232
233 async fn accept(&mut self) -> std::io::Result<Self::Stream> {
234 UnixListener::accept(&self.listener)
235 .await
236 .map(|(s, _addr)| s)
237 }
238}