1use std::fmt;
9use std::io;
10
11use async_trait::async_trait;
12use futures::{SinkExt, TryStreamExt};
13pub use service_binding;
14use ssh_key::Signature;
15use tokio::io::{AsyncRead, AsyncWrite};
16#[cfg(windows)]
17use tokio::net::windows::named_pipe::{NamedPipeServer, ServerOptions};
18use tokio::net::{TcpListener, TcpStream};
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21use tokio_util::codec::Framed;
22
23use super::error::AgentError;
24use super::proto::message::{Request, Response};
25use crate::codec::Codec;
26use crate::proto::AddIdentity;
27use crate::proto::AddIdentityConstrained;
28use crate::proto::AddSmartcardKeyConstrained;
29use crate::proto::Extension;
30use crate::proto::Identity;
31use crate::proto::ProtoError;
32use crate::proto::RemoveIdentity;
33use crate::proto::SignRequest;
34use crate::proto::SmartcardKey;
35
36#[async_trait]
71pub trait ListeningSocket {
72 type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static;
74
75 async fn accept(&mut self) -> io::Result<Self::Stream>;
77}
78
79#[cfg(unix)]
80#[async_trait]
81impl ListeningSocket for UnixListener {
82 type Stream = UnixStream;
83 async fn accept(&mut self) -> io::Result<Self::Stream> {
84 UnixListener::accept(self).await.map(|(s, _addr)| s)
85 }
86}
87
88#[async_trait]
89impl ListeningSocket for TcpListener {
90 type Stream = TcpStream;
91 async fn accept(&mut self) -> io::Result<Self::Stream> {
92 TcpListener::accept(self).await.map(|(s, _addr)| s)
93 }
94}
95
/// Listener for a Windows named pipe.
///
/// Field `0` is the pipe server instance that will accept the next client. Field `1` is
/// the pipe name, kept so a fresh server instance can be created after each connection.
#[cfg(windows)]
#[derive(Debug)]
pub struct NamedPipeListener(NamedPipeServer, std::ffi::OsString);

#[cfg(windows)]
impl NamedPipeListener {
    /// Creates the first server instance of the named pipe `pipe`.
    ///
    /// The instance is requested with `first_pipe_instance(true)`. Per the Win32
    /// `FILE_FLAG_FIRST_PIPE_INSTANCE` semantics, creation is expected to fail if an
    /// instance of this pipe already exists.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported while creating the pipe server instance.
    pub fn bind(pipe: impl Into<std::ffi::OsString>) -> std::io::Result<Self> {
        let name: std::ffi::OsString = pipe.into();
        let server = ServerOptions::new()
            .first_pipe_instance(true)
            .create(&name)?;
        Ok(Self(server, name))
    }
}
114
/// Named-pipe listener: waits for a client on the current server instance, then
/// returns that instance as the connection stream. A new instance of the same pipe is
/// created to accept the next client.
#[cfg(windows)]
#[async_trait]
impl ListeningSocket for NamedPipeListener {
    type Stream = NamedPipeServer;

    async fn accept(&mut self) -> io::Result<Self::Stream> {
        self.0.connect().await?;
        // Create the replacement instance first, then swap it in; the previous
        // (now connected) instance is handed to the caller.
        let next_instance = ServerOptions::new().create(&self.1)?;
        let connected = std::mem::replace(&mut self.0, next_instance);
        Ok(connected)
    }
}
127
128#[async_trait]
170pub trait Session: 'static + Sync + Send + Unpin {
171 async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
173 Err(AgentError::from(ProtoError::UnsupportedCommand {
174 command: 11,
175 }))
176 }
177
178 async fn sign(&mut self, _request: SignRequest) -> Result<Signature, AgentError> {
180 Err(AgentError::from(ProtoError::UnsupportedCommand {
181 command: 13,
182 }))
183 }
184
185 async fn add_identity(&mut self, _identity: AddIdentity) -> Result<(), AgentError> {
187 Err(AgentError::from(ProtoError::UnsupportedCommand {
188 command: 17,
189 }))
190 }
191
192 async fn add_identity_constrained(
194 &mut self,
195 _identity: AddIdentityConstrained,
196 ) -> Result<(), AgentError> {
197 Err(AgentError::from(ProtoError::UnsupportedCommand {
198 command: 25,
199 }))
200 }
201
202 async fn remove_identity(&mut self, _identity: RemoveIdentity) -> Result<(), AgentError> {
204 Err(AgentError::from(ProtoError::UnsupportedCommand {
205 command: 18,
206 }))
207 }
208
209 async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
211 Err(AgentError::from(ProtoError::UnsupportedCommand {
212 command: 19,
213 }))
214 }
215
216 async fn add_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
218 Err(AgentError::from(ProtoError::UnsupportedCommand {
219 command: 20,
220 }))
221 }
222
223 async fn add_smartcard_key_constrained(
225 &mut self,
226 _key: AddSmartcardKeyConstrained,
227 ) -> Result<(), AgentError> {
228 Err(AgentError::from(ProtoError::UnsupportedCommand {
229 command: 26,
230 }))
231 }
232
233 async fn remove_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
235 Err(AgentError::from(ProtoError::UnsupportedCommand {
236 command: 21,
237 }))
238 }
239
240 async fn lock(&mut self, _key: String) -> Result<(), AgentError> {
242 Err(AgentError::from(ProtoError::UnsupportedCommand {
243 command: 22,
244 }))
245 }
246
247 async fn unlock(&mut self, _key: String) -> Result<(), AgentError> {
249 Err(AgentError::from(ProtoError::UnsupportedCommand {
250 command: 23,
251 }))
252 }
253
254 async fn extension(&mut self, _extension: Extension) -> Result<Option<Extension>, AgentError> {
256 Err(AgentError::from(ProtoError::UnsupportedCommand {
257 command: 27,
258 }))
259 }
260
261 async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
267 match message {
268 Request::RequestIdentities => {
269 return Ok(Response::IdentitiesAnswer(self.request_identities().await?))
270 }
271 Request::SignRequest(request) => {
272 return Ok(Response::SignResponse(self.sign(request).await?))
273 }
274 Request::AddIdentity(identity) => self.add_identity(identity).await?,
275 Request::RemoveIdentity(identity) => self.remove_identity(identity).await?,
276 Request::RemoveAllIdentities => self.remove_all_identities().await?,
277 Request::AddSmartcardKey(key) => self.add_smartcard_key(key).await?,
278 Request::RemoveSmartcardKey(key) => self.remove_smartcard_key(key).await?,
279 Request::Lock(key) => self.lock(key).await?,
280 Request::Unlock(key) => self.unlock(key).await?,
281 Request::AddIdConstrained(identity) => self.add_identity_constrained(identity).await?,
282 Request::AddSmartcardKeyConstrained(key) => {
283 self.add_smartcard_key_constrained(key).await?
284 }
285 Request::Extension(extension) => {
286 return match self.extension(extension).await? {
287 Some(response) => Ok(Response::ExtensionResponse(response)),
288 None => Ok(Response::Success),
289 }
290 }
291 }
292 Ok(Response::Success)
293 }
294}
295
296async fn handle_socket<S>(
297 mut session: impl Session,
298 mut adapter: Framed<S::Stream, Codec<Request, Response>>,
299) -> Result<(), AgentError>
300where
301 S: ListeningSocket + fmt::Debug + Send,
302{
303 loop {
304 if let Some(incoming_message) = adapter.try_next().await? {
305 log::debug!("Request: {incoming_message:?}");
306 let response = match session.handle(incoming_message).await {
307 Ok(message) => message,
308 Err(AgentError::ExtensionFailure) => {
309 log::error!("Extension failure handling message");
310 Response::ExtensionFailure
311 }
312 Err(e) => {
313 log::error!("Error handling message: {:?}", e);
314 Response::Failure
315 }
316 };
317 log::debug!("Response: {response:?}");
318
319 adapter.send(response).await?;
320 } else {
321 return Ok(());
324 }
325 }
326}
327
328pub trait Agent<S>: 'static + Send + Sync
361where
362 S: ListeningSocket + fmt::Debug + Send,
363{
364 fn new_session(&mut self, socket: &S::Stream) -> impl Session;
366}
367
368pub async fn listen<S>(mut socket: S, mut agent: impl Agent<S>) -> Result<(), AgentError>
396where
397 S: ListeningSocket + fmt::Debug + Send,
398{
399 log::info!("Listening; socket = {:?}", socket);
400 loop {
401 match socket.accept().await {
402 Ok(socket) => {
403 let session = agent.new_session(&socket);
404 tokio::spawn(async move {
405 let adapter = Framed::new(socket, Codec::<Request, Response>::default());
406 if let Err(e) = handle_socket::<S>(session, adapter).await {
407 log::error!("Agent protocol error: {:?}", e);
408 }
409 });
410 }
411 Err(e) => {
412 log::error!("Failed to accept socket: {:?}", e);
413 return Err(AgentError::IO(e));
414 }
415 }
416 }
417}
418
419#[cfg(unix)]
420impl<T> Agent<tokio::net::UnixListener> for T
421where
422 T: Clone + Send + Sync + Session,
423{
424 fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
425 Self::clone(self)
426 }
427}
428
429impl<T> Agent<tokio::net::TcpListener> for T
430where
431 T: Clone + Send + Sync + Session,
432{
433 fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
434 Self::clone(self)
435 }
436}
437
/// Every cloneable [`Session`] is an agent for Windows named pipes: each connection is
/// served by an independent clone of `self`, and the pipe itself is not inspected.
#[cfg(windows)]
impl<T> Agent<NamedPipeListener> for T
where
    T: Clone + Send + Sync + Session,
{
    fn new_session(&mut self, _socket: &NamedPipeServer) -> impl Session {
        self.clone()
    }
}
450
451#[cfg(unix)]
452type PlatformSpecificListener = tokio::net::UnixListener;
453
454#[cfg(windows)]
455type PlatformSpecificListener = NamedPipeListener;
456
457pub async fn bind<A>(listener: service_binding::Listener, agent: A) -> Result<(), AgentError>
492where
493 A: Agent<PlatformSpecificListener> + Agent<tokio::net::TcpListener>,
494{
495 match listener {
496 #[cfg(unix)]
497 service_binding::Listener::Unix(listener) => {
498 listen(UnixListener::from_std(listener)?, agent).await
499 }
500 service_binding::Listener::Tcp(listener) => {
501 listen(TcpListener::from_std(listener)?, agent).await
502 }
503 #[cfg(windows)]
504 service_binding::Listener::NamedPipe(pipe) => {
505 listen(NamedPipeListener::bind(pipe)?, agent).await
506 }
507 #[allow(unreachable_patterns)]
508 _ => Err(AgentError::IO(std::io::Error::other(
509 "Unsupported type of a listener.",
510 ))),
511 }
512}