1use std::{
2 collections::HashMap,
3 os::unix::net::UnixStream,
4 path::{Path, PathBuf},
5 sync::Arc,
6};
7
8use ssh_agent_lib::{
9 agent::{self, Agent, ListeningSocket, Session},
10 client,
11 error::AgentError,
12 proto::{extension::QueryResponse, Extension, Identity, SignRequest},
13 ssh_key::{public::KeyData as PubKeyData, Signature},
14};
15use tokio::{
16 net::UnixListener,
17 sync::{Mutex, OwnedMutexGuard},
18};
19
/// Maps each public key reported by an upstream agent to that agent's socket path.
type KnownPubKeysMap = HashMap<PubKeyData, PathBuf>;
/// Key -> socket map shared by all sessions (clones of `MuxAgent` share the `Arc`).
/// Uses `tokio::sync::Mutex` because the guard is held across `.await` while
/// upstream agents are queried.
type KnownPubKeys = Arc<Mutex<KnownPubKeysMap>>;
22
23#[ssh_agent_lib::async_trait]
26impl Session for MuxAgent {
27 async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
28 let mut known_keys = self.known_keys.clone().lock_owned().await;
29 self.refresh_identities(&mut known_keys).await
30 }
31
32 async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
33 if let Some(agent_sock_path) = self.get_agent_sock_for_pubkey(&request.pubkey).await? {
34 log::info!(
35 "Request: signature with key {} from upstream agent <{}>",
36 request.pubkey.fingerprint(Default::default()),
37 agent_sock_path.display()
38 );
39
40 let mut client = self.connect_upstream_agent(agent_sock_path)?;
41 client.sign(request).await
42 } else {
43 let fingerprint = request.pubkey.fingerprint(Default::default());
44 log::error!("No upstream agent found for public key {}", &fingerprint);
45 log::trace!("Known keys: {:?}", self.known_keys);
46 Err(AgentError::Other(
47 format!("No agent found for public key: {}", &fingerprint).into(),
48 ))
49 }
50 }
51
52 async fn extension(&mut self, request: Extension) -> Result<Option<Extension>, AgentError> {
53 match request.name.as_str() {
54 "query" => Ok(Some(Extension::new_message(QueryResponse {
55 extensions: ["session-bind@openssh.com"].map(String::from).to_vec(),
56 })?)),
57 "session-bind@openssh.com" => {
58 let mut session_bind_suceeded = false;
59 for sock_path in &self.socket_paths {
60 if let Ok(mut client) = self.connect_upstream_agent(sock_path) {
64 match client.extension(request.clone()).await {
65 Ok(v) => {
67 session_bind_suceeded = true;
68 if v.is_some() {
69 log::warn!("session-bind@openssh.com request succeeded on socket <{}>, but an invalid response was received", sock_path.display());
70 }
71 }
72 Err(AgentError::Failure) => continue,
74 Err(e) => {
76 log::error!("Unexpected error on socket <{}> when requesting session-bind@openssh.com extension: {}", sock_path.display(), e);
77 continue;
78 }
79 }
80 }
81 }
82 if session_bind_suceeded {
83 Ok(None)
84 } else {
85 Err(AgentError::Failure)
86 }
87 }
88 _ => Err(AgentError::Failure),
89 }
90 }
91}
92
93#[derive(Clone)]
94pub struct MuxAgent {
95 socket_paths: Vec<PathBuf>,
96 known_keys: KnownPubKeys,
97}
98
99impl MuxAgent {
100 pub async fn run<I, P>(listen_sock: impl AsRef<Path>, agent_socks: I) -> Result<(), AgentError>
103 where
104 I: IntoIterator<Item = P>,
105 P: AsRef<Path>,
106 {
107 let socket_paths: Vec<_> = agent_socks
108 .into_iter()
109 .map(|p| p.as_ref().to_path_buf())
110 .collect();
111 log::info!(
112 "Starting agent for {} upstream agents; listening on <{}>",
113 socket_paths.len(),
114 listen_sock.as_ref().display()
115 );
116 log::debug!("Upstream agent sockets: {:?}", &socket_paths);
117
118 let listen_sock = SelfDeletingUnixListener::bind(listen_sock)?;
119 let this = Self {
120 socket_paths,
121 known_keys: Default::default(),
122 };
123 agent::listen(listen_sock, this).await
124 }
125
126 fn connect_upstream_agent(
127 &self,
128 sock_path: impl AsRef<Path>,
129 ) -> Result<Box<dyn Session>, AgentError> {
130 let sock_path = sock_path.as_ref();
131 let stream = UnixStream::connect(sock_path)?;
132 let client = client::connect(stream.into())
133 .map_err(|e| AgentError::Other(format!("Failed to connect to agent: {e}").into()))?;
134 log::trace!(
135 "Connected to upstream agent on socket: {}",
136 sock_path.display()
137 );
138 Ok(client)
139 }
140
141 async fn get_agent_sock_for_pubkey(
142 &mut self,
143 pubkey: &PubKeyData,
144 ) -> Result<Option<PathBuf>, AgentError> {
145 let mut known_keys = self.known_keys.clone().lock_owned().await;
148 if !known_keys.contains_key(pubkey) {
149 log::debug!("Key not found, re-requesting keys from upstream agents");
150 let _ = self.refresh_identities(&mut known_keys).await?;
151 }
152 let maybe_agent = known_keys.get(pubkey).cloned();
153 Ok(maybe_agent)
154 }
155
156 async fn refresh_identities(
159 &mut self,
160 known_keys: &mut OwnedMutexGuard<KnownPubKeysMap>,
161 ) -> Result<Vec<Identity>, AgentError> {
162 let mut identities = vec![];
163 known_keys.clear();
164
165 for sock_path in &self.socket_paths {
166 let mut client = match self.connect_upstream_agent(sock_path) {
167 Ok(c) => c,
168 Err(_) => {
169 log::warn!(
170 "Ignoring missing upstream agent socket: {}",
171 sock_path.display()
172 );
173 continue;
174 }
175 };
176 let agent_identities = client.request_identities().await?;
177 {
178 for id in &agent_identities {
179 known_keys.insert(id.pubkey.clone(), sock_path.clone());
180 }
181 }
182 identities.extend(agent_identities);
183 }
184
185 Ok(identities)
186 }
187}
188
189impl Agent<SelfDeletingUnixListener> for MuxAgent {
190 #[doc = "Create new session object when a new socket is accepted."]
191 fn new_session(
192 &mut self,
193 _socket: &<SelfDeletingUnixListener as ListeningSocket>::Stream,
194 ) -> impl Session {
195 self.clone()
196 }
197}
198
/// A tokio Unix-socket listener that deletes its socket file when dropped.
#[derive(Debug)]
struct SelfDeletingUnixListener {
    /// Path the listener was bound to; removed by the `Drop` impl.
    path: PathBuf,
    /// The underlying tokio listener that accepts connections.
    listener: UnixListener,
}
205
206impl SelfDeletingUnixListener {
207 fn bind(path: impl AsRef<Path>) -> std::io::Result<Self> {
208 let path = path.as_ref().to_path_buf();
209 UnixListener::bind(&path).map(|listener| Self { path, listener })
210 }
211}
212
213impl Drop for SelfDeletingUnixListener {
214 fn drop(&mut self) {
215 let _ = std::fs::remove_file(&self.path);
216 }
217}
218
219#[ssh_agent_lib::async_trait]
220impl ListeningSocket for SelfDeletingUnixListener {
221 type Stream = tokio::net::UnixStream;
222
223 async fn accept(&mut self) -> std::io::Result<Self::Stream> {
224 UnixListener::accept(&self.listener)
225 .await
226 .map(|(s, _addr)| s)
227 }
228}