ssh_agent_lib/
client.rs

1//! SSH agent client support.
2
3use std::fmt;
4
5use futures::{SinkExt, TryStreamExt};
6use ssh_key::Signature;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_util::codec::Framed;
9
10use crate::{
11    codec::Codec,
12    error::AgentError,
13    proto::{
14        AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained, Extension, Identity,
15        ProtoError, RemoveIdentity, Request, Response, SignRequest, SmartcardKey,
16    },
17};
18
19/// SSH agent client
20#[derive(Debug)]
21pub struct Client<Stream>
22where
23    Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
24{
25    adapter: Framed<Stream, Codec<Response, Request>>,
26}
27
28impl<Stream> Client<Stream>
29where
30    Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
31{
32    /// Create a new SSH agent client wrapping a given socket.
33    pub fn new(socket: Stream) -> Self {
34        let adapter = Framed::new(socket, Codec::default());
35        Self { adapter }
36    }
37}
38
39/// Wrap a stream into an SSH agent client.
40pub fn connect(
41    stream: service_binding::Stream,
42) -> Result<Box<dyn crate::agent::Session>, Box<dyn std::error::Error>> {
43    match stream {
44        #[cfg(unix)]
45        service_binding::Stream::Unix(stream) => {
46            let stream = tokio::net::UnixStream::from_std(stream)?;
47            Ok(Box::new(Client::new(stream)))
48        }
49        service_binding::Stream::Tcp(stream) => {
50            let stream = tokio::net::TcpStream::from_std(stream)?;
51            Ok(Box::new(Client::new(stream)))
52        }
53        #[cfg(windows)]
54        service_binding::Stream::NamedPipe(pipe) => {
55            use tokio::net::windows::named_pipe::ClientOptions;
56            let stream = loop {
57                // https://docs.rs/windows-sys/latest/windows_sys/Win32/Foundation/constant.ERROR_PIPE_BUSY.html
58                const ERROR_PIPE_BUSY: u32 = 231u32;
59
60                // correct way to do it taken from
61                // https://docs.rs/tokio/latest/tokio/net/windows/named_pipe/struct.NamedPipeClient.html
62                match ClientOptions::new().open(&pipe) {
63                    Ok(client) => break client,
64                    Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
65                    Err(e) => Err(e)?,
66                }
67
68                std::thread::sleep(std::time::Duration::from_millis(50));
69            };
70            Ok(Box::new(Client::new(stream)))
71        }
72        #[cfg(not(windows))]
73        service_binding::Stream::NamedPipe(_) => Err(ProtoError::IO(std::io::Error::other(
74            "Named pipes supported on Windows only",
75        ))
76        .into()),
77    }
78}
79
80#[async_trait::async_trait]
81impl<Stream> crate::agent::Session for Client<Stream>
82where
83    Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
84{
85    async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
86        if let Response::IdentitiesAnswer(identities) =
87            self.handle(Request::RequestIdentities).await?
88        {
89            Ok(identities)
90        } else {
91            Err(ProtoError::UnexpectedResponse.into())
92        }
93    }
94
95    async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
96        if let Response::SignResponse(response) = self.handle(Request::SignRequest(request)).await?
97        {
98            Ok(response)
99        } else {
100            Err(ProtoError::UnexpectedResponse.into())
101        }
102    }
103
104    async fn add_identity(&mut self, identity: AddIdentity) -> Result<(), AgentError> {
105        if let Response::Success = self.handle(Request::AddIdentity(identity)).await? {
106            Ok(())
107        } else {
108            Err(ProtoError::UnexpectedResponse.into())
109        }
110    }
111
112    async fn add_identity_constrained(
113        &mut self,
114        identity: AddIdentityConstrained,
115    ) -> Result<(), AgentError> {
116        if let Response::Success = self.handle(Request::AddIdConstrained(identity)).await? {
117            Ok(())
118        } else {
119            Err(ProtoError::UnexpectedResponse.into())
120        }
121    }
122
123    async fn remove_identity(&mut self, identity: RemoveIdentity) -> Result<(), AgentError> {
124        if let Response::Success = self.handle(Request::RemoveIdentity(identity)).await? {
125            Ok(())
126        } else {
127            Err(ProtoError::UnexpectedResponse.into())
128        }
129    }
130
131    async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
132        if let Response::Success = self.handle(Request::RemoveAllIdentities).await? {
133            Ok(())
134        } else {
135            Err(ProtoError::UnexpectedResponse.into())
136        }
137    }
138
139    async fn add_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
140        if let Response::Success = self.handle(Request::AddSmartcardKey(key)).await? {
141            Ok(())
142        } else {
143            Err(ProtoError::UnexpectedResponse.into())
144        }
145    }
146
147    async fn add_smartcard_key_constrained(
148        &mut self,
149        key: AddSmartcardKeyConstrained,
150    ) -> Result<(), AgentError> {
151        if let Response::Success = self
152            .handle(Request::AddSmartcardKeyConstrained(key))
153            .await?
154        {
155            Ok(())
156        } else {
157            Err(ProtoError::UnexpectedResponse.into())
158        }
159    }
160
161    async fn remove_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
162        if let Response::Success = self.handle(Request::RemoveSmartcardKey(key)).await? {
163            Ok(())
164        } else {
165            Err(ProtoError::UnexpectedResponse.into())
166        }
167    }
168
169    async fn lock(&mut self, key: String) -> Result<(), AgentError> {
170        if let Response::Success = self.handle(Request::Lock(key)).await? {
171            Ok(())
172        } else {
173            Err(ProtoError::UnexpectedResponse.into())
174        }
175    }
176
177    async fn unlock(&mut self, key: String) -> Result<(), AgentError> {
178        if let Response::Success = self.handle(Request::Unlock(key)).await? {
179            Ok(())
180        } else {
181            Err(ProtoError::UnexpectedResponse.into())
182        }
183    }
184
185    async fn extension(&mut self, extension: Extension) -> Result<Option<Extension>, AgentError> {
186        match self.handle(Request::Extension(extension)).await? {
187            Response::Success => Ok(None),
188            Response::ExtensionResponse(response) => Ok(Some(response)),
189            _ => Err(ProtoError::UnexpectedResponse.into()),
190        }
191    }
192
193    async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
194        self.adapter.send(message).await?;
195        if let Some(response) = self.adapter.try_next().await? {
196            Ok(response)
197        } else {
198            Err(ProtoError::IO(std::io::Error::other("server disconnected")).into())
199        }
200    }
201}