volt_client_rs/
websocket_rpc.rs1use std::collections::HashMap;
2use std::sync::{Arc, Mutex as StdMutex};
3use tokio::sync::{mpsc, Mutex};
4
/// Discriminant of [`Event`], used as the key when registering and dispatching callbacks.
///
/// The enum carries no data, so it is `Copy` as well: callers can pass it by value
/// without cloning.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum EventType {
    /// Selects [`Event::Data`] events.
    Data,
    /// Selects [`Event::Error`] events.
    Error,
    /// Selects [`Event::End`] events.
    End,
}
11
12#[derive(Debug, Clone)]
13pub enum Event {
14 Data(serde_json::Value),
15 Error(String),
16 End,
17}
18
19type Callback = Arc<StdMutex<dyn Fn(Event) + Send + Sync>>;
20
21#[derive(Clone)]
22pub struct WebsocketRpc {
23 pub id: u64,
24 protocol_manager: Arc<Mutex<volt_ws_protocol::rpc_manager::RpcManager>>,
25 send_channel: mpsc::Sender<Vec<u8>>,
26 callbacks: Arc<StdMutex<HashMap<EventType, Vec<Callback>>>>,
27}
28
29impl WebsocketRpc {
30 pub fn new(
31 id: u64,
32 protocol_manager: Arc<Mutex<volt_ws_protocol::rpc_manager::RpcManager>>,
33 send_channel: mpsc::Sender<Vec<u8>>,
34 ) -> Self {
35 WebsocketRpc {
36 id,
37 protocol_manager,
38 send_channel,
39 callbacks: Arc::new(StdMutex::new(HashMap::new())),
40 }
41 }
42
43 pub fn on<F>(&self, event_type: EventType, callback: F)
44 where
45 F: Fn(Event) + Send + Sync + 'static,
46 {
47 let mut callbacks = self.callbacks.lock().unwrap();
48 let entry = callbacks.entry(event_type).or_insert_with(Vec::new);
49 entry.push(Arc::new(StdMutex::new(callback)));
50 }
51
52 fn emit(&self, event: Event) {
53 let event_type = match &event {
54 Event::Data(_) => EventType::Data,
55 Event::Error(_) => EventType::Error,
56 Event::End => EventType::End,
57 };
58
59 if let Some(callbacks) = self.callbacks.lock().unwrap().get(&event_type) {
60 for callback in callbacks.iter() {
61 let callback = callback.clone();
62 let event = event.clone();
63 (callback.lock().unwrap())(event);
64 }
65 }
66 }
67
68 pub fn abort(&mut self, error: String) {
69 self.emit(Event::Error(error));
70 }
71
72 async fn send_internal(&self, payload: &str) -> Result<(), String> {
73 let protocol = self.protocol_manager.lock().await;
75 let encoded = match protocol.encode_payload(&self.id, payload) {
76 Ok(encoded) => encoded,
77 Err(e) => {
78 return Err(format!("Failed to encode payload: {}", e));
79 }
80 };
81
82 println!("Sending payload size: {}", encoded.len());
83
84 let send_result = self.send_channel.send(encoded).await;
86
87 match send_result {
88 Ok(_) => Ok(()),
89 Err(e) => {
90 return Err(format!("Failed to send payload: {}", e));
91 }
92 }
93 }
94
95 pub async fn send(&self, payload: &serde_json::Value) -> Result<(), String> {
96 let payload_json = match serde_json::to_string(&payload) {
97 Ok(payload_json) => payload_json,
98 Err(e) => {
99 return Err(format!("Failed to serialize payload: {}", e));
100 }
101 };
102
103 self.send_internal(&payload_json).await
104 }
105
106 pub async fn end(&self) -> Result<(), String> {
107 self.send_internal("").await
108 }
109
110 pub fn handle_response(&mut self, response: &serde_json::Value) {
111 if !response["error"].is_null() {
112 self.abort(response["error"].as_str().unwrap().to_string());
113 } else {
114 let response_payload = &response["payload"];
115
116 if response_payload.is_null() {
117 println!(
118 "Received response with no payload for method_id: {}",
119 self.id
120 );
121 } else {
122 let method_payload = &response_payload["method_payload"];
123 if method_payload.is_null() {
124 let method_end = &response_payload["method_end"];
125 if method_end.is_null() {
126 println!(
127 "Received response with no method_payload for method_id: {}",
128 self.id
129 );
130 } else {
131 println!("received method_end for {} {}", self.id, method_end);
132 if response["error"].is_null() {
133 self.emit(Event::End);
134 } else {
135 self.emit(Event::Error(
136 response["error"].as_str().unwrap().to_string(),
137 ));
138 }
139 }
140 } else {
141 println!("received payload for {} {}", self.id, method_payload);
142 let payload_json = match method_payload["json_payload"].as_str() {
143 Some(json_payload) => json_payload,
144 None => {
145 self.abort("Received response with no json_payload".to_string());
146 return;
147 }
148 };
149
150 let payload: serde_json::Value = match serde_json::from_str(payload_json) {
151 Ok(payload) => payload,
152 Err(e) => {
153 self.abort(format!("Failed to parse json_payload: {}", e));
154 return;
155 }
156 };
157
158 if payload["status"].is_null() {
159 self.emit(Event::Data(payload));
160 } else {
161 self.emit(Event::Error(
162 payload["status"]["message"].as_str().unwrap().to_string(),
163 ));
164 }
165 }
166 }
167 }
168 }
169}