// volt_client_rs/websocket_rpc.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex as StdMutex};
3use tokio::sync::{mpsc, Mutex};
4
/// Kind of [`Event`], used as the key under which callbacks are registered.
///
/// Each variant corresponds to the [`Event`] variant of the same name. The enum carries no
/// data, so it is `Copy` and can be passed around by value without cloning.
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
pub enum EventType {
    /// Selects callbacks for [`Event::Data`].
    Data,
    /// Selects callbacks for [`Event::Error`].
    Error,
    /// Selects callbacks for [`Event::End`].
    End,
}
11
12#[derive(Debug, Clone)]
13pub enum Event {
14    Data(serde_json::Value),
15    Error(String),
16    End,
17}
18
19type Callback = Arc<StdMutex<dyn Fn(Event) + Send + Sync>>;
20
21#[derive(Clone)]
22pub struct WebsocketRpc {
23    pub id: u64,
24    protocol_manager: Arc<Mutex<volt_ws_protocol::rpc_manager::RpcManager>>,
25    send_channel: mpsc::Sender<Vec<u8>>,
26    callbacks: Arc<StdMutex<HashMap<EventType, Vec<Callback>>>>,
27}
28
29impl WebsocketRpc {
30    pub fn new(
31        id: u64,
32        protocol_manager: Arc<Mutex<volt_ws_protocol::rpc_manager::RpcManager>>,
33        send_channel: mpsc::Sender<Vec<u8>>,
34    ) -> Self {
35        WebsocketRpc {
36            id,
37            protocol_manager,
38            send_channel,
39            callbacks: Arc::new(StdMutex::new(HashMap::new())),
40        }
41    }
42
43    pub fn on<F>(&self, event_type: EventType, callback: F)
44    where
45        F: Fn(Event) + Send + Sync + 'static,
46    {
47        let mut callbacks = self.callbacks.lock().unwrap();
48        let entry = callbacks.entry(event_type).or_insert_with(Vec::new);
49        entry.push(Arc::new(StdMutex::new(callback)));
50    }
51
52    fn emit(&self, event: Event) {
53        let event_type = match &event {
54            Event::Data(_) => EventType::Data,
55            Event::Error(_) => EventType::Error,
56            Event::End => EventType::End,
57        };
58
59        if let Some(callbacks) = self.callbacks.lock().unwrap().get(&event_type) {
60            for callback in callbacks.iter() {
61                let callback = callback.clone();
62                let event = event.clone();
63                (callback.lock().unwrap())(event);
64            }
65        }
66    }
67
68    pub fn abort(&mut self, error: String) {
69        self.emit(Event::Error(error));
70    }
71
72    async fn send_internal(&self, payload: &str) -> Result<(), String> {
73        // Encode the payload using the wasm protocol manager.
74        let protocol = self.protocol_manager.lock().await;
75        let encoded = match protocol.encode_payload(&self.id, payload) {
76            Ok(encoded) => encoded,
77            Err(e) => {
78                return Err(format!("Failed to encode payload: {}", e));
79            }
80        };
81
82        println!("Sending payload size: {}", encoded.len());
83
84        // Send the request via the websocket send channel.
85        let send_result = self.send_channel.send(encoded).await;
86
87        match send_result {
88            Ok(_) => Ok(()),
89            Err(e) => {
90                return Err(format!("Failed to send payload: {}", e));
91            }
92        }
93    }
94
95    pub async fn send(&self, payload: &serde_json::Value) -> Result<(), String> {
96        let payload_json = match serde_json::to_string(&payload) {
97            Ok(payload_json) => payload_json,
98            Err(e) => {
99                return Err(format!("Failed to serialize payload: {}", e));
100            }
101        };
102
103        self.send_internal(&payload_json).await
104    }
105
106    pub async fn end(&self) -> Result<(), String> {
107        self.send_internal("").await
108    }
109
110    pub fn handle_response(&mut self, response: &serde_json::Value) {
111        if !response["error"].is_null() {
112            self.abort(response["error"].as_str().unwrap().to_string());
113        } else {
114            let response_payload = &response["payload"];
115
116            if response_payload.is_null() {
117                println!(
118                    "Received response with no payload for method_id: {}",
119                    self.id
120                );
121            } else {
122                let method_payload = &response_payload["method_payload"];
123                if method_payload.is_null() {
124                    let method_end = &response_payload["method_end"];
125                    if method_end.is_null() {
126                        println!(
127                            "Received response with no method_payload for method_id: {}",
128                            self.id
129                        );
130                    } else {
131                        println!("received method_end for {} {}", self.id, method_end);
132                        if response["error"].is_null() {
133                            self.emit(Event::End);
134                        } else {
135                            self.emit(Event::Error(
136                                response["error"].as_str().unwrap().to_string(),
137                            ));
138                        }
139                    }
140                } else {
141                    println!("received payload for {} {}", self.id, method_payload);
142                    let payload_json = match method_payload["json_payload"].as_str() {
143                        Some(json_payload) => json_payload,
144                        None => {
145                            self.abort("Received response with no json_payload".to_string());
146                            return;
147                        }
148                    };
149
150                    let payload: serde_json::Value = match serde_json::from_str(payload_json) {
151                        Ok(payload) => payload,
152                        Err(e) => {
153                            self.abort(format!("Failed to parse json_payload: {}", e));
154                            return;
155                        }
156                    };
157
158                    if payload["status"].is_null() {
159                        self.emit(Event::Data(payload));
160                    } else {
161                        self.emit(Event::Error(
162                            payload["status"]["message"].as_str().unwrap().to_string(),
163                        ));
164                    }
165                }
166            }
167        }
168    }
169}