1use super::{Message, MessageOut};
2use failure::Fail;
3use futures::prelude::*;
4use serde::{de, ser};
5use serde_derive::Serialize;
6use std::marker::PhantomData;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::net::TcpStream;
10use tokio_tungstenite::{self, MaybeTlsStream, WebSocketStream};
11use url::Url;
12
13pub struct StreamDeckSocket<G, S, MI, MO> {
19 inner: WebSocketStream<MaybeTlsStream<TcpStream>>,
20 _g: PhantomData<G>,
21 _s: PhantomData<S>,
22 _mi: PhantomData<MI>,
23 _mo: PhantomData<MO>,
24}
25
26impl<G, S, MI, MO> StreamDeckSocket<G, S, MI, MO> {
27 pub async fn connect<A: Into<Address>>(
42 address: A,
43 event: String,
44 uuid: String,
45 ) -> Result<Self, ConnectError> {
46 let address = address.into();
47
48 let (mut stream, _) = tokio_tungstenite::connect_async(address.url)
49 .await
50 .map_err(ConnectError::ConnectionError)?;
51
52 let message = serde_json::to_string(&Registration {
53 event: &event,
54 uuid: &uuid,
55 })
56 .unwrap();
57 stream
58 .send(tungstenite::Message::Text(message))
59 .await
60 .map_err(ConnectError::SendError)?;
61
62 Ok(StreamDeckSocket {
63 inner: stream,
64 _g: PhantomData,
65 _s: PhantomData,
66 _mi: PhantomData,
67 _mo: PhantomData,
68 })
69 }
70
71 fn pin_get_inner(self: Pin<&mut Self>) -> Pin<&mut WebSocketStream<MaybeTlsStream<TcpStream>>> {
72 unsafe { self.map_unchecked_mut(|s| &mut s.inner) }
73 }
74}
75
/// Errors yielded while reading messages from, or writing messages to, a
/// [`StreamDeckSocket`].
#[derive(Debug, Fail)]
pub enum StreamDeckSocketError {
    /// The underlying WebSocket reported an error while receiving, sending, flushing
    /// or closing. The original error is kept as the cause.
    #[fail(display = "WebSocket error")]
    WebSocketError(#[fail(cause)] tungstenite::error::Error),
    /// JSON conversion failed. This covers an incoming text frame that does not
    /// deserialize into a message, and an outgoing message that fails to serialize.
    /// The serde_json error is kept as the cause.
    #[fail(display = "Bad message")]
    BadMessage(#[fail(cause)] serde_json::Error),
}
86
87impl<G, S, MI, MO> Stream for StreamDeckSocket<G, S, MI, MO>
88where
89 G: de::DeserializeOwned,
90 S: de::DeserializeOwned,
91 MI: de::DeserializeOwned,
92{
93 type Item = Result<Message<G, S, MI>, StreamDeckSocketError>;
94
95 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
96 let mut inner = self.pin_get_inner();
97 loop {
98 match inner.as_mut().poll_next(cx) {
99 Poll::Ready(Some(Ok(tungstenite::Message::Text(message)))) => {
100 break match serde_json::from_str(&message) {
101 Ok(message) => Poll::Ready(Some(Ok(message))),
102 Err(error) => {
103 Poll::Ready(Some(Err(StreamDeckSocketError::BadMessage(error))))
104 }
105 };
106 }
107 Poll::Ready(Some(Ok(_))) => {}
108 Poll::Ready(Some(Err(error))) => {
109 break Poll::Ready(Some(Err(StreamDeckSocketError::WebSocketError(error))))
110 }
111 Poll::Ready(None) => break Poll::Ready(None),
112 Poll::Pending => break Poll::Pending,
113 }
114 }
115 }
116}
117
118impl<G, S, MI, MO> Sink<MessageOut<G, S, MO>> for StreamDeckSocket<G, S, MI, MO>
119where
120 G: ser::Serialize,
121 S: ser::Serialize,
122 MO: ser::Serialize,
123{
124 type Error = StreamDeckSocketError;
125
126 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
127 self.pin_get_inner()
128 .poll_ready(cx)
129 .map_err(StreamDeckSocketError::WebSocketError)
130 }
131
132 fn start_send(self: Pin<&mut Self>, item: MessageOut<G, S, MO>) -> Result<(), Self::Error> {
133 let message = serde_json::to_string(&item).map_err(StreamDeckSocketError::BadMessage)?;
134 self.pin_get_inner()
135 .start_send(tungstenite::Message::Text(message))
136 .map_err(StreamDeckSocketError::WebSocketError)
137 }
138
139 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
140 self.pin_get_inner()
141 .poll_flush(cx)
142 .map_err(StreamDeckSocketError::WebSocketError)
143 }
144
145 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
146 self.pin_get_inner()
147 .poll_close(cx)
148 .map_err(StreamDeckSocketError::WebSocketError)
149 }
150}
151
152pub struct Address {
154 pub url: Url,
155}
156
157impl From<Url> for Address {
158 fn from(value: Url) -> Self {
159 Address { url: value }
160 }
161}
162
163impl From<u16> for Address {
164 fn from(value: u16) -> Self {
165 let mut url = Url::parse("ws://localhost").unwrap();
166 url.set_port(Some(value)).unwrap();
167 Address { url }
168 }
169}
170
/// Errors returned by [`StreamDeckSocket::connect`].
#[derive(Debug, Fail)]
pub enum ConnectError {
    /// Establishing the WebSocket connection to the given address failed.
    /// The tungstenite error is kept as the cause.
    #[fail(display = "Websocket connection error")]
    ConnectionError(#[fail(cause)] tungstenite::error::Error),
    /// The connection was established, but sending the registration frame failed.
    /// The tungstenite error is kept as the cause.
    #[fail(display = "Send error")]
    SendError(#[fail(cause)] tungstenite::error::Error),
}
181
182#[derive(Serialize)]
183struct Registration<'a> {
184 event: &'a str,
185 uuid: &'a str,
186}