1use crate::error::{ProblemJson, ViiperError};
4use crate::types::*;
5use std::net::SocketAddr;
6use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
7use tokio::net::TcpStream;
8
/// A full-duplex connection to the server.
///
/// `Plain` carries a raw TCP connection. `Encrypted` carries the stream
/// returned by `crate::auth::perform_handshake_async` when a password is
/// configured.
#[cfg(feature = "async")]
pub enum AsyncStreamWrapper {
    /// Raw TCP connection (no password configured).
    Plain(tokio::net::TcpStream),
    /// Connection wrapped by the authentication handshake.
    Encrypted(crate::auth::AsyncEncryptedStream),
}
15
/// The read half of a split device connection, plain or encrypted.
#[cfg(feature = "async")]
pub enum AsyncReadWrapper {
    /// Read half of a raw TCP connection.
    Plain(tokio::net::tcp::OwnedReadHalf),
    /// Read half of an encrypted connection.
    Encrypted(crate::auth::AsyncEncryptedRead),
}
22
/// The write half of a split device connection, plain or encrypted.
#[cfg(feature = "async")]
pub enum AsyncWriteWrapper {
    /// Write half of a raw TCP connection.
    Plain(tokio::net::tcp::OwnedWriteHalf),
    /// Write half of an encrypted connection.
    Encrypted(crate::auth::AsyncEncryptedWrite),
}
29
/// Reads are delegated to whichever transport the wrapper currently holds.
#[cfg(feature = "async")]
impl AsyncRead for AsyncStreamWrapper {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Both inner transports are `Unpin` (they are re-pinned with `Pin::new`),
        // so the wrapper itself can be safely unpinned to a plain `&mut`.
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_read(cx, buf),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_read(cx, buf),
        }
    }
}
43
/// Reads are delegated to the wrapped read half.
#[cfg(feature = "async")]
impl AsyncRead for AsyncReadWrapper {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Inner halves are `Unpin`, so unpinning the wrapper is sound.
        let this = self.get_mut();
        match this {
            Self::Plain(half) => std::pin::Pin::new(half).poll_read(cx, buf),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_read(cx, buf),
        }
    }
}
57
/// Writes, flushes and shutdowns are delegated to the wrapped transport.
#[cfg(feature = "async")]
impl AsyncWrite for AsyncStreamWrapper {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_write(cx, buf),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_flush(cx),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(tcp) => std::pin::Pin::new(tcp).poll_shutdown(cx),
            Self::Encrypted(enc) => std::pin::Pin::new(enc).poll_shutdown(cx),
        }
    }
}
91
/// Writes, flushes and shutdowns are delegated to the wrapped write half.
#[cfg(feature = "async")]
impl AsyncWrite for AsyncWriteWrapper {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let this = self.get_mut();
        match this {
            Self::Plain(half) => std::pin::Pin::new(half).poll_write(cx, buf),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(half) => std::pin::Pin::new(half).poll_flush(cx),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        match this {
            Self::Plain(half) => std::pin::Pin::new(half).poll_shutdown(cx),
            Self::Encrypted(half) => std::pin::Pin::new(half).poll_shutdown(cx),
        }
    }
}
125
/// Async client for the request/response API.
///
/// Every request opens its own TCP connection to `addr`. When `password` is
/// set, the connection is first passed through the authentication handshake.
#[cfg(feature = "async")]
pub struct AsyncViiperClient {
    /// Server address used for every request and device connection.
    addr: SocketAddr,
    /// Optional password; `None` means plain, unauthenticated TCP.
    password: Option<String>,
}
132
#[cfg(feature = "async")]
impl AsyncViiperClient {
    /// Creates a client that talks to the server at `addr` over plain TCP.
    pub fn new(addr: SocketAddr) -> Self {
        Self { addr, password: None }
    }

    /// Creates a client that authenticates every connection with `password`.
    ///
    /// An empty `password` is treated as "no password", so the client then
    /// behaves exactly like one built with [`AsyncViiperClient::new`].
    pub fn new_with_password(addr: SocketAddr, password: String) -> Self {
        let password = if password.is_empty() { None } else { Some(password) };
        Self { addr, password }
    }

    /// Sends a single request on a fresh connection and decodes the JSON reply.
    ///
    /// Request wire format: `path`, then optionally a single space and
    /// `payload`, terminated by a NUL byte. The reply is read until the server
    /// closes the connection, and trailing `\n` characters are stripped before
    /// parsing.
    ///
    /// # Errors
    /// - I/O errors from connecting, the handshake, writing, or reading.
    /// - `ViiperError::UnexpectedResponse` if the reply is not valid UTF-8.
    /// - `ViiperError::Protocol` if the reply is a problem object (it starts
    ///   with `{"status":`).
    /// - A JSON error if the reply cannot be deserialized into `T`.
    async fn do_request<T: for<'de> serde::Deserialize<'de>>(
        &self,
        path: &str,
        payload: Option<&str>,
    ) -> Result<T, ViiperError> {
        let tcp_stream = TcpStream::connect(self.addr).await?;
        tcp_stream.set_nodelay(true)?;

        let mut stream = match self.password.as_deref() {
            Some(pwd) => AsyncStreamWrapper::Encrypted(
                crate::auth::perform_handshake_async(tcp_stream, pwd).await?,
            ),
            None => AsyncStreamWrapper::Plain(tcp_stream),
        };

        // Assemble the whole request first. With TCP_NODELAY every separate
        // write may otherwise leave as its own tiny segment or frame.
        let payload_len = payload.map_or(0, |p| p.len() + 1);
        let mut request = Vec::with_capacity(path.len() + payload_len + 1);
        request.extend_from_slice(path.as_bytes());
        if let Some(p) = payload {
            request.push(b' ');
            request.extend_from_slice(p.as_bytes());
        }
        request.push(b'\0');

        stream.write_all(&request).await?;
        // Make sure nothing is left in a write buffer before blocking on the
        // reply. This is a no-op for plain TCP, but the encrypted stream may
        // buffer writes.
        stream.flush().await?;

        let mut buf = Vec::new();
        stream.read_to_end(&mut buf).await?;

        let text = String::from_utf8(buf)
            .map_err(|_| ViiperError::UnexpectedResponse("invalid UTF-8".into()))?;
        let response = text.trim_end_matches('\n');

        if response.starts_with("{\"status\":") {
            let problem: ProblemJson = serde_json::from_str(response)?;
            return Err(ViiperError::Protocol(problem));
        }

        serde_json::from_str(response).map_err(Into::into)
    }

    /// Sends the `ping` request.
    pub async fn ping(&self) -> Result<PingResponse, ViiperError> {
        self.do_request("ping", None).await
    }

    /// Sends the `bus/list` request.
    pub async fn bus_list(&self) -> Result<BusListResponse, ViiperError> {
        self.do_request("bus/list", None).await
    }

    /// Sends `bus/create`, with the optional numeric argument as payload.
    pub async fn bus_create(&self, uint32: Option<u32>) -> Result<BusCreateResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/create", payload.as_deref()).await
    }

    /// Sends `bus/remove`, with the optional numeric argument as payload.
    pub async fn bus_remove(&self, uint32: Option<u32>) -> Result<BusRemoveResponse, ViiperError> {
        let payload = uint32.map(|v| v.to_string());
        self.do_request("bus/remove", payload.as_deref()).await
    }

    /// Sends `bus/{id}/list`.
    pub async fn bus_devices_list(&self, id: u32) -> Result<DevicesListResponse, ViiperError> {
        let path = format!("bus/{}/list", id);
        self.do_request(&path, None).await
    }

    /// Sends `bus/{id}/add` with `device_create_request` serialized as JSON.
    ///
    /// # Errors
    /// Also fails if `device_create_request` cannot be serialized.
    pub async fn bus_device_add(
        &self,
        id: u32,
        device_create_request: &DeviceCreateRequest,
    ) -> Result<Device, ViiperError> {
        let path = format!("bus/{}/add", id);
        let payload = serde_json::to_string(device_create_request)?;
        self.do_request(&path, Some(&payload)).await
    }

    /// Sends `bus/{id}/remove`, with the optional string argument as payload.
    pub async fn bus_device_remove(
        &self,
        id: u32,
        string: Option<&str>,
    ) -> Result<DeviceRemoveResponse, ViiperError> {
        let path = format!("bus/{}/remove", id);
        self.do_request(&path, string).await
    }

    /// Opens a long-lived device stream for `dev_id` on bus `bus_id`, using this
    /// client's address and password.
    pub async fn connect_device(&self, bus_id: u32, dev_id: &str) -> Result<AsyncDeviceStream, ViiperError> {
        AsyncDeviceStream::connect(self.addr, bus_id, dev_id, self.password.as_deref()).await
    }
}
238
/// A long-lived connection to a single device on a bus.
///
/// The connection is split into independently locked read and write halves,
/// so sending and receiving can proceed concurrently.
#[cfg(feature = "async")]
pub struct AsyncDeviceStream {
    /// Read half; shared with the background task started by `on_output`.
    read_stream: std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>,
    /// Write half used by the `send*` methods.
    write_stream: std::sync::Arc<tokio::sync::Mutex<AsyncWriteWrapper>>,
    /// Cancels the `on_output` task; `Some` once that task has been spawned.
    cancel_token: Option<tokio_util::sync::CancellationToken>,
    /// Callback registered via `on_disconnect`, taken by `on_output`.
    disconnect_callback: std::sync::Mutex<Option<Box<dyn FnOnce() + Send + 'static>>>,
}
247
#[cfg(feature = "async")]
impl AsyncDeviceStream {
    /// Opens a stream to device `dev_id` on bus `bus_id` at `addr`.
    ///
    /// Runs the authentication handshake when `password` is `Some` and splits
    /// the connection into read and write halves. It then sends the
    /// `bus/{bus_id}/{dev_id}\0` selector and flushes it.
    ///
    /// # Errors
    /// I/O or handshake errors from connecting or sending the selector.
    pub async fn connect(
        addr: SocketAddr,
        bus_id: u32,
        dev_id: &str,
        password: Option<&str>,
    ) -> Result<Self, ViiperError> {
        let tcp_stream = TcpStream::connect(addr).await?;
        tcp_stream.set_nodelay(true)?;

        let (read_stream, mut write_stream) = if let Some(pwd) = password {
            let encrypted = crate::auth::perform_handshake_async(tcp_stream, pwd).await?;
            let (read_half, write_half) = encrypted.into_split();
            (AsyncReadWrapper::Encrypted(read_half), AsyncWriteWrapper::Encrypted(write_half))
        } else {
            let (read_half, write_half) = tcp_stream.into_split();
            (AsyncReadWrapper::Plain(read_half), AsyncWriteWrapper::Plain(write_half))
        };

        let handshake = format!("bus/{}/{}\0", bus_id, dev_id);
        write_stream.write_all(handshake.as_bytes()).await?;
        // The selector must actually reach the server. This is a no-op for plain
        // TCP, but the encrypted half may buffer writes.
        write_stream.flush().await?;

        Ok(Self {
            read_stream: std::sync::Arc::new(tokio::sync::Mutex::new(read_stream)),
            write_stream: std::sync::Arc::new(tokio::sync::Mutex::new(write_stream)),
            cancel_token: None,
            disconnect_callback: std::sync::Mutex::new(None),
        })
    }

    /// Serializes `input` with `DeviceInput::to_bytes` and writes it to the device.
    ///
    /// Waits for the write lock without any time limit.
    pub async fn send<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        self.send_raw(&bytes).await
    }

    /// Like [`AsyncDeviceStream::send`], but gives up after `timeout`.
    ///
    /// The timeout covers waiting for the write lock as well as the write and
    /// flush themselves. Another task blocked in a slow send therefore cannot
    /// stretch this call past `timeout`. If the timeout fires mid-write, part of
    /// the input may already have been sent.
    ///
    /// # Errors
    /// `ViiperError::Timeout` if `timeout` elapses, otherwise any I/O error.
    pub async fn send_timeout<T: crate::wire::DeviceInput>(
        &self,
        input: &T,
        timeout: std::time::Duration,
    ) -> Result<(), ViiperError> {
        let bytes = input.to_bytes();
        let write = async {
            let mut stream = self.write_stream.lock().await;
            stream.write_all(&bytes).await?;
            stream.flush().await?;
            Ok::<(), std::io::Error>(())
        };
        tokio::time::timeout(timeout, write)
            .await
            .map_err(|_| ViiperError::Timeout)?
            .map_err(Into::into)
    }

    /// Spawns a background task that repeatedly calls `callback` with the shared
    /// read half.
    ///
    /// The loop stops when `callback` returns an error or when the stream is
    /// dropped, because `Drop` cancels the task. The disconnect callback is then
    /// invoked. That callback is captured *now*, so [`AsyncDeviceStream::on_disconnect`]
    /// must be called before this method for it to take effect. `callback` is
    /// invoked again as soon as it returns `Ok(())`, so it should wait for or
    /// consume output rather than return immediately.
    ///
    /// # Errors
    /// - `ViiperError::UnexpectedResponse` if an output callback is already
    ///   registered.
    /// - `ViiperError::UnexpectedResponse` if the disconnect-callback mutex is
    ///   poisoned.
    ///
    /// # Panics
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn on_output<F, Fut>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: Fn(std::sync::Arc<tokio::sync::Mutex<AsyncReadWrapper>>) -> Fut + Send + 'static,
        Fut: std::future::Future<Output = std::io::Result<()>> + Send + 'static,
    {
        if self.cancel_token.is_some() {
            return Err(ViiperError::UnexpectedResponse("Output callback already registered".into()));
        }

        // Take the disconnect callback, holding the std mutex only for this scope.
        let disconnect = {
            let Ok(mut guard) = self.disconnect_callback.lock() else {
                return Err(ViiperError::UnexpectedResponse("Disconnect callback mutex poisoned".into()));
            };
            guard.take()
        };

        let stream = self.read_stream.clone();
        let cancel_token = tokio_util::sync::CancellationToken::new();
        let cancel_clone = cancel_token.clone();

        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = cancel_clone.cancelled() => break,
                    result = callback(stream.clone()) => {
                        match result {
                            Ok(()) => continue,
                            Err(_) => break,
                        }
                    }
                }
            }
            if let Some(cb) = disconnect {
                cb();
            }
        });

        self.cancel_token = Some(cancel_token);
        Ok(())
    }

    /// Registers a callback to run when the `on_output` task finishes.
    ///
    /// It must be registered before [`AsyncDeviceStream::on_output`]. The output
    /// task takes the callback when it is spawned, so a callback set afterwards
    /// is stored but never invoked. A later call replaces any callback that has
    /// not been taken yet.
    ///
    /// # Errors
    /// `ViiperError::UnexpectedResponse` if the callback mutex is poisoned.
    pub fn on_disconnect<F>(&mut self, callback: F) -> Result<(), ViiperError>
    where
        F: FnOnce() + Send + 'static,
    {
        let Ok(mut guard) = self.disconnect_callback.lock() else {
            return Err(ViiperError::UnexpectedResponse("Disconnect callback mutex poisoned".into()));
        };
        *guard = Some(Box::new(callback));
        Ok(())
    }

    /// Writes `data` verbatim to the device and flushes it.
    ///
    /// The flush is a no-op for plain TCP, but an encrypted half may buffer
    /// writes.
    pub async fn send_raw(&self, data: &[u8]) -> Result<(), ViiperError> {
        let mut stream = self.write_stream.lock().await;
        stream.write_all(data).await?;
        stream.flush().await?;
        Ok(())
    }

    /// Reads whatever is available into `buf` and returns the byte count.
    ///
    /// A count of 0 means end of stream, per `AsyncReadExt::read`.
    pub async fn read_raw(&self, buf: &mut [u8]) -> Result<usize, ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read(buf).await.map_err(Into::into)
    }

    /// Reads exactly `buf.len()` bytes.
    ///
    /// Errors if the stream ends first.
    pub async fn read_exact(&self, buf: &mut [u8]) -> Result<(), ViiperError> {
        let mut stream = self.read_stream.lock().await;
        stream.read_exact(buf).await?;
        Ok(())
    }
}
376
/// Dropping the stream stops the background `on_output` task, if one exists.
#[cfg(feature = "async")]
impl Drop for AsyncDeviceStream {
    fn drop(&mut self) {
        // Taking the token is harmless: the field is dropped right after this anyway.
        if let Some(token) = self.cancel_token.take() {
            token.cancel();
        }
    }
}