1use std::collections::HashMap;
17use std::net::Ipv6Addr;
18use std::sync::Arc;
19
20use serde::de::DeserializeOwned;
21use serde::Serialize;
22use serde_json::{Map, Value};
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, ReadHalf, WriteHalf};
24use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
25use tokio::sync::{oneshot, Mutex};
26use tokio::{task, time};
27
28mod discovery;
29pub mod error;
30
31pub use discovery::{run as discover, DiscoveredDevice, Protocol};
32
33enum WriteSocketKind {
34 TCP(WriteHalf<TcpStream>),
35 #[allow(dead_code)]
36 UDP(UdpSocket),
37}
38
39enum ReadSocketKind {
40 TCP(ReadHalf<TcpStream>),
41 #[allow(dead_code)]
42 UDP(UdpSocket),
43}
44
45struct State {
46 reply_to: Option<oneshot::Sender<String>>,
47}
48
49pub struct Client {
59 state: Arc<Mutex<State>>,
60 socket: WriteSocketKind,
61}
62
/// One entry of the schema tree returned by `Client::list`.
///
/// `Clone` is derived so callers can copy (sub)trees without rebuilding them.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum ListNode {
    /// A node with children, keyed by path component.
    ///
    /// NOTE(review): the `Box` is not needed for recursion (the `HashMap`
    /// already stores its values on the heap), but it is part of the public
    /// API, so it is kept.
    Branch(Box<HashMap<String, ListNode>>),
    /// A terminal entry without children.
    Leaf,
}
68
69impl Client {
70 pub async fn connect<TSA: ToSocketAddrs>(addr: TSA, mode: Protocol) -> error::Result<Self> {
72 let (write_socket, read_socket) = match mode {
73 Protocol::UDP => {
74 let socket = UdpSocket::bind((Ipv6Addr::UNSPECIFIED, 0)).await?;
75 socket.connect(addr).await?;
76 WriteSocketKind::UDP(socket);
77 todo!()
78 }
79 Protocol::TCP => {
80 let (rx, tx) = tokio::io::split(TcpStream::connect(addr).await?);
81 (WriteSocketKind::TCP(tx), ReadSocketKind::TCP(rx))
82 }
83 };
84
85 let state = Arc::new(Mutex::new(State { reply_to: None }));
86
87 task::spawn(receiver(state.clone(), read_socket));
88
89 Ok(Client {
90 state,
91 socket: write_socket,
92 })
93 }
94
95 pub async fn get<T: DeserializeOwned>(&mut self, path: &str) -> error::Result<T> {
97 let (tx, rx) = oneshot::channel();
98 self.register_callback(path.to_owned(), tx).await;
99 self.send_message(path, &serde_json::Value::Null).await?;
100
101 let response = wait_response(rx).await?;
102
103 unserialize_json_message(path, response)
104 }
105
106 pub async fn set<T: Serialize>(&mut self, path: &str, value: &T) -> error::Result<()> {
108 let (tx, rx) = oneshot::channel();
109 self.register_callback(path.to_owned(), tx).await;
110 self.send_message(path, &value).await?;
111
112 let res = wait_response(rx).await?;
113
114 println!("SET Result: {res}");
115
116 Ok(())
117 }
118
119 pub async fn list(&mut self, path: &str) -> error::Result<HashMap<String, ListNode>> {
121 let message = build_json_message(path, &serde_json::Value::Null)?;
122
123 let message = if path == "/" {
124 message
125 } else {
126 serde_json::Value::Array(vec![message])
127 };
128
129 let (tx, rx) = oneshot::channel();
130 self.register_callback(path.to_owned(), tx).await;
131
132 self.send_message("/osc/schema", &message).await?;
133
134 let res = wait_response(rx).await?;
135
136 let outer_response: Vec<serde_json::Value> = unserialize_json_message("/osc/schema", res)?;
137
138 let actual_schema: HashMap<String, serde_json::Value> =
139 unpack_json_message(path, outer_response.into_iter().next().unwrap())?;
140
141 let mut res = HashMap::new();
142
143 for (k, v) in actual_schema {
144 let v = match v {
145 Value::Null => ListNode::Leaf,
146 Value::Object(_) => {
147 let sub_path = if path == "/" {
148 format!("/{k}")
149 } else {
150 format!("{path}/{k}")
151 };
152 ListNode::Branch(Box::new(Box::pin(self.list(&sub_path)).await?))
153 }
154 _ => return Err(error::Error::InvalidPath),
155 };
156 res.insert(k, v);
157 }
158
159 Ok(res)
160 }
161
162 async fn register_callback(&self, _path: String, callback: oneshot::Sender<String>) {
163 let mut guard = self.state.lock().await;
164 guard.reply_to = Some(callback)
165 }
166
167 async fn send_message<T: serde::Serialize>(
168 &mut self,
169 path: &str,
170 message: &T,
171 ) -> error::Result<()> {
172 let mut data = serialize_json_message(path, message)?;
173
174 match &mut self.socket {
175 WriteSocketKind::TCP(socket) => {
176 data.extend_from_slice(b"\r\n");
177 socket.write_all(&data).await?;
178 }
179 WriteSocketKind::UDP(_) => todo!(),
180 }
181 Ok(())
182 }
183}
184
185fn serialize_json_message<T: Serialize>(path: &str, content: &T) -> error::Result<Vec<u8>> {
186 let data = build_json_message(path, content)?;
187 Ok(serde_json::to_vec(&data)?)
188}
189
190fn build_json_message<T: Serialize>(path: &str, content: &T) -> error::Result<serde_json::Value> {
191 let components = normalize_path(path)?
192 .split("/")
193 .collect::<Vec<_>>()
194 .into_iter()
195 .rev();
196
197 let mut data = serde_json::to_value(content)?;
198 for component in components {
199 if component == "" {
200 data = serde_json::Value::Null;
201 } else {
202 let mut hm = Map::new();
203 hm.insert(component.to_owned(), data);
204 data = serde_json::Value::Object(hm);
205 }
206 }
207
208 Ok(data)
209}
210
211fn unserialize_json_message<T: DeserializeOwned>(path: &str, data: String) -> error::Result<T> {
212 let value: serde_json::Value = serde_json::from_str(&data)?;
213
214 unpack_json_message(path, value)
215}
216
217fn unpack_json_message<T: DeserializeOwned>(
218 path: &str,
219 mut value: serde_json::Value,
220) -> error::Result<T> {
221 if path != "/" {
223 for component in normalize_path(path)?.split("/") {
224 if let serde_json::Value::Object(mut map) = value {
225 if let Some((key, new_value)) = map.remove_entry(component) {
226 if key != component {
227 return Err(error::Error::UnexpectedPath);
228 }
229 value = new_value;
230 } else {
231 return Err(error::Error::UnexpectedPath);
232 }
233 } else {
234 return Err(error::Error::UnexpectedPath);
235 }
236 }
237 }
238
239 Ok(serde_json::from_value(value)?)
240}
241
242async fn wait_response(rx: oneshot::Receiver<String>) -> error::Result<String> {
243 Ok(time::timeout(time::Duration::from_secs(5), rx)
244 .await
245 .map_err(|_| error::Error::RequestTimeout)?
246 .map_err(|_| error::Error::ProcessingResponseError)?)
247}
248
249async fn receiver(state: Arc<Mutex<State>>, read: ReadSocketKind) {
250 match read {
251 ReadSocketKind::TCP(read) => {
252 let mut lines = BufReader::new(read).lines();
253
254 loop {
255 while let Some(line) = lines.next_line().await.unwrap() {
256 let mut guard = state.lock().await;
257
258 if let Some(reply_to) = guard.reply_to.take() {
259 reply_to.send(line).unwrap();
260 }
261 }
262 }
263 }
264 ReadSocketKind::UDP(_) => todo!(),
265 }
266}
267
268fn normalize_path(path: &str) -> error::Result<&str> {
269 if path.starts_with("/") {
270 Ok(&path[1..])
271 } else {
272 Err(error::Error::InvalidPath)
273 }
274}
275
#[cfg(test)]
mod test {
    use serde::{Deserialize, Serialize};

    use super::*;

    #[test]
    fn test_build_json_message() {
        #[derive(Serialize)]
        struct Test {
            value: u8,
        }

        let got =
            String::from_utf8(serialize_json_message("/test/42", &Test { value: 42 }).unwrap())
                .unwrap();
        let want = r#"{"test":{"42":{"value":42}}}"#;

        assert_eq!(&got, want);
    }

    /// The root path collapses to `null`; `Client::list("/")` relies on this.
    #[test]
    fn test_build_json_message_root_is_null() {
        let got = build_json_message("/", &42u8).unwrap();

        assert_eq!(got, serde_json::Value::Null);
    }

    #[test]
    fn test_unpack_json_message() {
        #[derive(Deserialize, PartialEq, Eq, Debug)]
        struct Test {
            value: u8,
        }

        let got: Test = unserialize_json_message(
            "/test/42",
            r#"{"test": {"42": {"value": 42 } } }"#.to_owned(),
        )
        .unwrap();
        let want = Test { value: 42 };

        assert_eq!(got, want);
    }

    #[test]
    fn test_unpack_json_message_wrong_path() {
        let got = unserialize_json_message::<serde_json::Value>(
            "/other/42",
            r#"{"test": {"42": {"value": 42 } } }"#.to_owned(),
        );

        assert!(matches!(got, Err(error::Error::UnexpectedPath)));
    }

    #[test]
    fn test_normalize_path() {
        assert_eq!(normalize_path("/test/42").ok(), Some("test/42"));
        assert!(matches!(
            normalize_path("test"),
            Err(error::Error::InvalidPath)
        ));
    }
}