ssh_agent_lib/
agent.rs

//! Traits for implementing custom SSH agents.
//!
//! Agents that keep no state, or only minimal state, should
//! implement the [`Session`] trait. If more elaborate state is
//! needed, especially state that depends on the socket making the
//! connection, then it is advisable to implement the [`Agent`] trait.
7
8use std::fmt;
9use std::io;
10
11use async_trait::async_trait;
12use futures::{SinkExt, TryStreamExt};
13pub use service_binding;
14use ssh_key::Signature;
15use tokio::io::{AsyncRead, AsyncWrite};
16#[cfg(windows)]
17use tokio::net::windows::named_pipe::{NamedPipeServer, ServerOptions};
18use tokio::net::{TcpListener, TcpStream};
19#[cfg(unix)]
20use tokio::net::{UnixListener, UnixStream};
21use tokio_util::codec::Framed;
22
23use super::error::AgentError;
24use super::proto::message::{Request, Response};
25use crate::codec::Codec;
26use crate::proto::AddIdentity;
27use crate::proto::AddIdentityConstrained;
28use crate::proto::AddSmartcardKeyConstrained;
29use crate::proto::Extension;
30use crate::proto::Identity;
31use crate::proto::ProtoError;
32use crate::proto::RemoveIdentity;
33use crate::proto::SignRequest;
34use crate::proto::SmartcardKey;
35
36/// Type representing a socket that asynchronously returns a list of streams.
37///
38/// This trait is implemented for [TCP sockets](TcpListener) on all
39/// platforms, Unix sockets on Unix platforms (e.g. Linux, macOS) and
40/// Named Pipes on Windows.
41///
42/// Objects implementing this trait are passed to the [`listen`]
43/// function.
44///
45/// # Examples
46///
47/// The following example starts listening for connections and
48/// processes them with the `MyAgent` struct.
49///
50/// ```no_run
51/// # async fn main_() -> testresult::TestResult {
52/// use ssh_agent_lib::agent::{listen, Session};
53/// use tokio::net::TcpListener;
54///
55/// #[derive(Default, Clone)]
56/// struct MyAgent;
57///
58/// impl Session for MyAgent {
59///     // implement your agent logic here
60/// }
61///
62/// listen(
63///     TcpListener::bind("127.0.0.1:8080").await?,
64///     MyAgent::default(),
65/// )
66/// .await?;
67/// # Ok(()) }
68/// ```
69
70#[async_trait]
71pub trait ListeningSocket {
72    /// Stream type that represents an accepted socket.
73    type Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static;
74
75    /// Waits until a client connects and returns connected stream.
76    async fn accept(&mut self) -> io::Result<Self::Stream>;
77}
78
79#[cfg(unix)]
80#[async_trait]
81impl ListeningSocket for UnixListener {
82    type Stream = UnixStream;
83    async fn accept(&mut self) -> io::Result<Self::Stream> {
84        UnixListener::accept(self).await.map(|(s, _addr)| s)
85    }
86}
87
88#[async_trait]
89impl ListeningSocket for TcpListener {
90    type Stream = TcpStream;
91    async fn accept(&mut self) -> io::Result<Self::Stream> {
92        TcpListener::accept(self).await.map(|(s, _addr)| s)
93    }
94}
95
/// Listener for Windows Named Pipes.
// Field 0 is the pipe server instance that the next `accept` will hand out;
// field 1 is the pipe path, kept so that further instances can be created.
#[cfg(windows)]
#[derive(Debug)]
pub struct NamedPipeListener(NamedPipeServer, std::ffi::OsString);

#[cfg(windows)]
impl NamedPipeListener {
    /// Bind to a pipe path.
    ///
    /// The initial server is created with `first_pipe_instance(true)`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported when creating the pipe server fails.
    pub fn bind(pipe: impl Into<std::ffi::OsString>) -> std::io::Result<Self> {
        let path = pipe.into();
        let server = ServerOptions::new()
            .first_pipe_instance(true)
            .create(&path)?;
        Ok(Self(server, path))
    }
}
114
#[cfg(windows)]
#[async_trait]
impl ListeningSocket for NamedPipeListener {
    type Stream = NamedPipeServer;

    /// Waits for a client on the current pipe instance and returns that
    /// instance, leaving a freshly created one in its place for the next
    /// call.
    async fn accept(&mut self) -> io::Result<Self::Stream> {
        self.0.connect().await?;
        // Create the replacement before swapping so that a creation error
        // is returned without touching the stored instance.
        let next_instance = ServerOptions::new().create(&self.1)?;
        let connected = std::mem::replace(&mut self.0, next_instance);
        Ok(connected)
    }
}
127
128/// Represents one active SSH connection.
129///
130/// This type is implemented by agents that want to handle incoming SSH agent
131/// connections.
132///
133/// # Examples
134///
135/// The following examples shows the most minimal [`Session`]
136/// implementation: one that returns a list of public keys that it
137/// manages and signs all incoming signing requests.
138///
139/// Note that the `MyAgent` struct is cloned for all new sessions
140/// (incoming connections). If the cloning needs special behavior
141/// implementing [`Clone`] manually is a viable approach. If the newly
142/// created sessions require information from the underlying socket it
143/// is advisable to implement the [`Agent`] trait.
144///
145/// ```
146/// use ssh_agent_lib::{agent::Session, error::AgentError};
147/// use ssh_agent_lib::proto::{Identity, SignRequest};
148/// use ssh_key::{Algorithm, Signature};
149///
150/// #[derive(Default, Clone)]
151/// struct MyAgent;
152///
153/// #[ssh_agent_lib::async_trait]
154/// impl Session for MyAgent {
155///     async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
156///         Ok(vec![ /* public keys that this agent knows of */ ])
157///     }
158///
159///     async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
160///         // get the signature by signing `request.data`
161///         let signature = vec![];
162///         Ok(Signature::new(
163///              Algorithm::new("algorithm").map_err(AgentError::other)?,
164///              signature,
165///         ).map_err(AgentError::other)?)
166///     }
167/// }
168/// ```
169#[async_trait]
170pub trait Session: 'static + Sync + Send + Unpin {
171    /// Request a list of keys managed by this session.
172    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
173        Err(AgentError::from(ProtoError::UnsupportedCommand {
174            command: 11,
175        }))
176    }
177
178    /// Perform a private key signature operation.
179    async fn sign(&mut self, _request: SignRequest) -> Result<Signature, AgentError> {
180        Err(AgentError::from(ProtoError::UnsupportedCommand {
181            command: 13,
182        }))
183    }
184
185    /// Add a private key to the agent.
186    async fn add_identity(&mut self, _identity: AddIdentity) -> Result<(), AgentError> {
187        Err(AgentError::from(ProtoError::UnsupportedCommand {
188            command: 17,
189        }))
190    }
191
192    /// Add a private key to the agent with a set of constraints.
193    async fn add_identity_constrained(
194        &mut self,
195        _identity: AddIdentityConstrained,
196    ) -> Result<(), AgentError> {
197        Err(AgentError::from(ProtoError::UnsupportedCommand {
198            command: 25,
199        }))
200    }
201
202    /// Remove private key from an agent.
203    async fn remove_identity(&mut self, _identity: RemoveIdentity) -> Result<(), AgentError> {
204        Err(AgentError::from(ProtoError::UnsupportedCommand {
205            command: 18,
206        }))
207    }
208
209    /// Remove all keys from an agent.
210    async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
211        Err(AgentError::from(ProtoError::UnsupportedCommand {
212            command: 19,
213        }))
214    }
215
216    /// Add a key stored on a smartcard.
217    async fn add_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
218        Err(AgentError::from(ProtoError::UnsupportedCommand {
219            command: 20,
220        }))
221    }
222
223    /// Add a key stored on a smartcard with a set of constraints.
224    async fn add_smartcard_key_constrained(
225        &mut self,
226        _key: AddSmartcardKeyConstrained,
227    ) -> Result<(), AgentError> {
228        Err(AgentError::from(ProtoError::UnsupportedCommand {
229            command: 26,
230        }))
231    }
232
233    /// Remove a smartcard key from the agent.
234    async fn remove_smartcard_key(&mut self, _key: SmartcardKey) -> Result<(), AgentError> {
235        Err(AgentError::from(ProtoError::UnsupportedCommand {
236            command: 21,
237        }))
238    }
239
240    /// Temporarily lock the agent with a password.
241    async fn lock(&mut self, _key: String) -> Result<(), AgentError> {
242        Err(AgentError::from(ProtoError::UnsupportedCommand {
243            command: 22,
244        }))
245    }
246
247    /// Unlock the agent with a password.
248    async fn unlock(&mut self, _key: String) -> Result<(), AgentError> {
249        Err(AgentError::from(ProtoError::UnsupportedCommand {
250            command: 23,
251        }))
252    }
253
254    /// Invoke a custom, vendor-specific extension on the agent.
255    async fn extension(&mut self, _extension: Extension) -> Result<Option<Extension>, AgentError> {
256        Err(AgentError::from(ProtoError::UnsupportedCommand {
257            command: 27,
258        }))
259    }
260
261    /// Handle a raw SSH agent request and return agent response.
262    ///
263    /// Note that it is preferable to use high-level functions instead of
264    /// this function. This function should be overridden only for custom
265    /// messages, outside of the SSH agent protocol specification.
266    async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
267        match message {
268            Request::RequestIdentities => {
269                return Ok(Response::IdentitiesAnswer(self.request_identities().await?))
270            }
271            Request::SignRequest(request) => {
272                return Ok(Response::SignResponse(self.sign(request).await?))
273            }
274            Request::AddIdentity(identity) => self.add_identity(identity).await?,
275            Request::RemoveIdentity(identity) => self.remove_identity(identity).await?,
276            Request::RemoveAllIdentities => self.remove_all_identities().await?,
277            Request::AddSmartcardKey(key) => self.add_smartcard_key(key).await?,
278            Request::RemoveSmartcardKey(key) => self.remove_smartcard_key(key).await?,
279            Request::Lock(key) => self.lock(key).await?,
280            Request::Unlock(key) => self.unlock(key).await?,
281            Request::AddIdConstrained(identity) => self.add_identity_constrained(identity).await?,
282            Request::AddSmartcardKeyConstrained(key) => {
283                self.add_smartcard_key_constrained(key).await?
284            }
285            Request::Extension(extension) => {
286                return match self.extension(extension).await? {
287                    Some(response) => Ok(Response::ExtensionResponse(response)),
288                    None => Ok(Response::Success),
289                }
290            }
291        }
292        Ok(Response::Success)
293    }
294}
295
296async fn handle_socket<S>(
297    mut session: impl Session,
298    mut adapter: Framed<S::Stream, Codec<Request, Response>>,
299) -> Result<(), AgentError>
300where
301    S: ListeningSocket + fmt::Debug + Send,
302{
303    loop {
304        if let Some(incoming_message) = adapter.try_next().await? {
305            log::debug!("Request: {incoming_message:?}");
306            let response = match session.handle(incoming_message).await {
307                Ok(message) => message,
308                Err(AgentError::ExtensionFailure) => {
309                    log::error!("Extension failure handling message");
310                    Response::ExtensionFailure
311                }
312                Err(e) => {
313                    log::error!("Error handling message: {:?}", e);
314                    Response::Failure
315                }
316            };
317            log::debug!("Response: {response:?}");
318
319            adapter.send(response).await?;
320        } else {
321            // Reached EOF of the stream (client disconnected),
322            // we can close the socket and exit the handler.
323            return Ok(());
324        }
325    }
326}
327
328/// Factory of sessions for the given type of sockets.
329///
330/// An agent implementation is automatically created for types which
331/// implement [`Session`] and [`Clone`]: new sessions are created by
332/// cloning the agent object. This is usually sufficient for the
333/// majority of use cases. In case the information about the
334/// underlying socket (connection source) is needed the [`Agent`] can
335/// be implemented manually.
336///
337/// # Examples
338///
339/// This example shows how to retrieve the connecting process ID on Unix:
340///
341/// ```
342/// use ssh_agent_lib::agent::{Agent, Session};
343///
344/// #[derive(Debug, Default)]
345/// struct AgentSocketInfo;
346///
347/// #[cfg(unix)]
348/// impl Agent<tokio::net::UnixListener> for AgentSocketInfo {
349///     fn new_session(&mut self, socket: &tokio::net::UnixStream) -> impl Session {
350///         let _socket_info = format!(
351///             "unix: addr: {:?} cred: {:?}",
352///             socket.peer_addr().unwrap(),
353///             socket.peer_cred().unwrap()
354///         );
355///         Self
356///     }
357/// }
358/// # impl Session for AgentSocketInfo { }
359/// ```
360pub trait Agent<S>: 'static + Send + Sync
361where
362    S: ListeningSocket + fmt::Debug + Send,
363{
364    /// Create a [`Session`] object for a given `socket`.
365    fn new_session(&mut self, socket: &S::Stream) -> impl Session;
366}
367
368/// Listen for connections on a given socket and use session factory
369/// to create new session for each accepted socket.
370///
371/// # Examples
372///
373/// The following example starts listening for connections and
374/// processes them with the `MyAgent` struct.
375///
376/// ```no_run
377/// # async fn main_() -> testresult::TestResult {
378/// use ssh_agent_lib::agent::{listen, Session};
379/// use tokio::net::TcpListener;
380///
381/// #[derive(Default, Clone)]
382/// struct MyAgent;
383///
384/// impl Session for MyAgent {
385///     // implement your agent logic here
386/// }
387///
388/// listen(
389///     TcpListener::bind("127.0.0.1:8080").await?,
390///     MyAgent::default(),
391/// )
392/// .await?;
393/// # Ok(()) }
394/// ```
395pub async fn listen<S>(mut socket: S, mut agent: impl Agent<S>) -> Result<(), AgentError>
396where
397    S: ListeningSocket + fmt::Debug + Send,
398{
399    log::info!("Listening; socket = {:?}", socket);
400    loop {
401        match socket.accept().await {
402            Ok(socket) => {
403                let session = agent.new_session(&socket);
404                tokio::spawn(async move {
405                    let adapter = Framed::new(socket, Codec::<Request, Response>::default());
406                    if let Err(e) = handle_socket::<S>(session, adapter).await {
407                        log::error!("Agent protocol error: {:?}", e);
408                    }
409                });
410            }
411            Err(e) => {
412                log::error!("Failed to accept socket: {:?}", e);
413                return Err(AgentError::IO(e));
414            }
415        }
416    }
417}
418
419#[cfg(unix)]
420impl<T> Agent<tokio::net::UnixListener> for T
421where
422    T: Clone + Send + Sync + Session,
423{
424    fn new_session(&mut self, _socket: &tokio::net::UnixStream) -> impl Session {
425        Self::clone(self)
426    }
427}
428
429impl<T> Agent<tokio::net::TcpListener> for T
430where
431    T: Clone + Send + Sync + Session,
432{
433    fn new_session(&mut self, _socket: &tokio::net::TcpStream) -> impl Session {
434        Self::clone(self)
435    }
436}
437
/// Any cloneable [`Session`] acts as an agent for Named Pipes: each new
/// session is a clone of the agent, and the pipe itself is ignored.
#[cfg(windows)]
impl<T> Agent<NamedPipeListener> for T
where
    T: Clone + Send + Sync + Session,
{
    fn new_session(&mut self, _socket: &NamedPipeServer) -> impl Session {
        self.clone()
    }
}
450
451#[cfg(unix)]
452type PlatformSpecificListener = tokio::net::UnixListener;
453
454#[cfg(windows)]
455type PlatformSpecificListener = NamedPipeListener;
456
457/// Bind to a service binding listener.
458///
459/// # Examples
460///
461/// The following example uses `clap` to parse the host socket data
462/// thus allowing the user to choose at runtime whether they want to
463/// use TCP sockets, Unix domain sockets (including systemd socket
464/// activation) or Named Pipes (under Windows).
465///
466/// ```no_run
467/// use clap::Parser;
468/// use service_binding::Binding;
469/// use ssh_agent_lib::agent::{bind, Session};
470///
471/// #[derive(Debug, Parser)]
472/// struct Args {
473///     #[clap(long, short = 'H', default_value = "unix:///tmp/ssh.sock")]
474///     host: Binding,
475/// }
476///
477/// #[derive(Default, Clone)]
478/// struct MyAgent;
479///
480/// impl Session for MyAgent {}
481///
482/// #[tokio::main]
483/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
484///     let args = Args::parse();
485///
486///     bind(args.host.try_into()?, MyAgent::default()).await?;
487///
488///     Ok(())
489/// }
490/// ```
491pub async fn bind<A>(listener: service_binding::Listener, agent: A) -> Result<(), AgentError>
492where
493    A: Agent<PlatformSpecificListener> + Agent<tokio::net::TcpListener>,
494{
495    match listener {
496        #[cfg(unix)]
497        service_binding::Listener::Unix(listener) => {
498            listen(UnixListener::from_std(listener)?, agent).await
499        }
500        service_binding::Listener::Tcp(listener) => {
501            listen(TcpListener::from_std(listener)?, agent).await
502        }
503        #[cfg(windows)]
504        service_binding::Listener::NamedPipe(pipe) => {
505            listen(NamedPipeListener::bind(pipe)?, agent).await
506        }
507        #[allow(unreachable_patterns)]
508        _ => Err(AgentError::IO(std::io::Error::other(
509            "Unsupported type of a listener.",
510        ))),
511    }
512}