1use std::fmt;
4
5use futures::{SinkExt, TryStreamExt};
6use ssh_key::Signature;
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio_util::codec::Framed;
9
10use crate::{
11 codec::Codec,
12 error::AgentError,
13 proto::{
14 AddIdentity, AddIdentityConstrained, AddSmartcardKeyConstrained, Extension, Identity,
15 ProtoError, RemoveIdentity, Request, Response, SignRequest, SmartcardKey,
16 },
17};
18
19#[derive(Debug)]
21pub struct Client<Stream>
22where
23 Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
24{
25 adapter: Framed<Stream, Codec<Response, Request>>,
26}
27
28impl<Stream> Client<Stream>
29where
30 Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Unpin + 'static,
31{
32 pub fn new(socket: Stream) -> Self {
34 let adapter = Framed::new(socket, Codec::default());
35 Self { adapter }
36 }
37}
38
39pub fn connect(
41 stream: service_binding::Stream,
42) -> Result<Box<dyn crate::agent::Session>, Box<dyn std::error::Error>> {
43 match stream {
44 #[cfg(unix)]
45 service_binding::Stream::Unix(stream) => {
46 let stream = tokio::net::UnixStream::from_std(stream)?;
47 Ok(Box::new(Client::new(stream)))
48 }
49 service_binding::Stream::Tcp(stream) => {
50 let stream = tokio::net::TcpStream::from_std(stream)?;
51 Ok(Box::new(Client::new(stream)))
52 }
53 #[cfg(windows)]
54 service_binding::Stream::NamedPipe(pipe) => {
55 use tokio::net::windows::named_pipe::ClientOptions;
56 let stream = loop {
57 const ERROR_PIPE_BUSY: u32 = 231u32;
59
60 match ClientOptions::new().open(&pipe) {
63 Ok(client) => break client,
64 Err(e) if e.raw_os_error() == Some(ERROR_PIPE_BUSY as i32) => (),
65 Err(e) => Err(e)?,
66 }
67
68 std::thread::sleep(std::time::Duration::from_millis(50));
69 };
70 Ok(Box::new(Client::new(stream)))
71 }
72 #[cfg(not(windows))]
73 service_binding::Stream::NamedPipe(_) => Err(ProtoError::IO(std::io::Error::other(
74 "Named pipes supported on Windows only",
75 ))
76 .into()),
77 }
78}
79
80#[async_trait::async_trait]
81impl<Stream> crate::agent::Session for Client<Stream>
82where
83 Stream: fmt::Debug + AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
84{
85 async fn request_identities(&mut self) -> Result<Vec<Identity>, AgentError> {
86 if let Response::IdentitiesAnswer(identities) =
87 self.handle(Request::RequestIdentities).await?
88 {
89 Ok(identities)
90 } else {
91 Err(ProtoError::UnexpectedResponse.into())
92 }
93 }
94
95 async fn sign(&mut self, request: SignRequest) -> Result<Signature, AgentError> {
96 if let Response::SignResponse(response) = self.handle(Request::SignRequest(request)).await?
97 {
98 Ok(response)
99 } else {
100 Err(ProtoError::UnexpectedResponse.into())
101 }
102 }
103
104 async fn add_identity(&mut self, identity: AddIdentity) -> Result<(), AgentError> {
105 if let Response::Success = self.handle(Request::AddIdentity(identity)).await? {
106 Ok(())
107 } else {
108 Err(ProtoError::UnexpectedResponse.into())
109 }
110 }
111
112 async fn add_identity_constrained(
113 &mut self,
114 identity: AddIdentityConstrained,
115 ) -> Result<(), AgentError> {
116 if let Response::Success = self.handle(Request::AddIdConstrained(identity)).await? {
117 Ok(())
118 } else {
119 Err(ProtoError::UnexpectedResponse.into())
120 }
121 }
122
123 async fn remove_identity(&mut self, identity: RemoveIdentity) -> Result<(), AgentError> {
124 if let Response::Success = self.handle(Request::RemoveIdentity(identity)).await? {
125 Ok(())
126 } else {
127 Err(ProtoError::UnexpectedResponse.into())
128 }
129 }
130
131 async fn remove_all_identities(&mut self) -> Result<(), AgentError> {
132 if let Response::Success = self.handle(Request::RemoveAllIdentities).await? {
133 Ok(())
134 } else {
135 Err(ProtoError::UnexpectedResponse.into())
136 }
137 }
138
139 async fn add_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
140 if let Response::Success = self.handle(Request::AddSmartcardKey(key)).await? {
141 Ok(())
142 } else {
143 Err(ProtoError::UnexpectedResponse.into())
144 }
145 }
146
147 async fn add_smartcard_key_constrained(
148 &mut self,
149 key: AddSmartcardKeyConstrained,
150 ) -> Result<(), AgentError> {
151 if let Response::Success = self
152 .handle(Request::AddSmartcardKeyConstrained(key))
153 .await?
154 {
155 Ok(())
156 } else {
157 Err(ProtoError::UnexpectedResponse.into())
158 }
159 }
160
161 async fn remove_smartcard_key(&mut self, key: SmartcardKey) -> Result<(), AgentError> {
162 if let Response::Success = self.handle(Request::RemoveSmartcardKey(key)).await? {
163 Ok(())
164 } else {
165 Err(ProtoError::UnexpectedResponse.into())
166 }
167 }
168
169 async fn lock(&mut self, key: String) -> Result<(), AgentError> {
170 if let Response::Success = self.handle(Request::Lock(key)).await? {
171 Ok(())
172 } else {
173 Err(ProtoError::UnexpectedResponse.into())
174 }
175 }
176
177 async fn unlock(&mut self, key: String) -> Result<(), AgentError> {
178 if let Response::Success = self.handle(Request::Unlock(key)).await? {
179 Ok(())
180 } else {
181 Err(ProtoError::UnexpectedResponse.into())
182 }
183 }
184
185 async fn extension(&mut self, extension: Extension) -> Result<Option<Extension>, AgentError> {
186 match self.handle(Request::Extension(extension)).await? {
187 Response::Success => Ok(None),
188 Response::ExtensionResponse(response) => Ok(Some(response)),
189 _ => Err(ProtoError::UnexpectedResponse.into()),
190 }
191 }
192
193 async fn handle(&mut self, message: Request) -> Result<Response, AgentError> {
194 self.adapter.send(message).await?;
195 if let Some(response) = self.adapter.try_next().await? {
196 Ok(response)
197 } else {
198 Err(ProtoError::IO(std::io::Error::other("server disconnected")).into())
199 }
200 }
201}