1use crate::apps::connections_open::connections_open;
2use crate::error::Error;
3use crate::http_client::SlackWebAPIClient;
4use crate::socket::event::{
5 AcknowledgeMessage, DisconnectEvent, EventsAPI, HelloEvent, InteractiveEvent,
6 SlashCommandsEvent, SocketModeEvent,
7};
8use async_std::fs::read;
9use async_std::net::TcpStream;
10use async_tls::client::TlsStream;
11use async_tls::TlsConnector;
12use async_trait::async_trait;
13use async_tungstenite::tungstenite::Message;
14use async_tungstenite::{client_async, WebSocketStream};
15use futures_util::{SinkExt, StreamExt};
16use rustls::ClientConfig;
17use std::collections::HashMap;
18use std::io::Cursor;
19use std::sync::Arc;
20use url::Url;
21
/// WebSocket stream used by Socket Mode: a TCP connection (`async_std::net::TcpStream`)
/// wrapped in TLS (`async_tls::client::TlsStream`) and upgraded to a WebSocket
/// (`async_tungstenite::WebSocketStream`).
pub type Stream = WebSocketStream<TlsStream<TcpStream>>;
23
24#[allow(unused_variables)]
26#[async_trait]
27pub trait EventHandler<S>: Send
28where
29 S: SlackWebAPIClient,
30{
31 async fn on_close(&mut self, socket_mode: &SocketMode<S>) {
32 log::info!("websocket close");
33 }
34 async fn on_connect(&mut self, socket_mode: &SocketMode<S>) {
35 log::info!("websocket connect");
36 }
37 async fn on_hello(&mut self, socket_mode: &SocketMode<S>, e: HelloEvent, s: &mut Stream) {
38 log::info!("hello event: {:?}", e);
39 }
40 async fn on_disconnect(
41 &mut self,
42 socket_mode: &SocketMode<S>,
43 e: DisconnectEvent,
44 s: &mut Stream,
45 ) {
46 log::info!("disconnect event: {:?}", e);
47 }
48 async fn on_events_api(&mut self, socket_mode: &SocketMode<S>, e: EventsAPI, s: &mut Stream) {
49 log::info!("events api event: {:?}", e);
50 }
51 async fn on_interactive(
52 &mut self,
53 socket_mode: &SocketMode<S>,
54 e: InteractiveEvent,
55 s: &mut Stream,
56 ) {
57 log::info!("interactive event: {:?}", e);
58 }
59 async fn on_slash_commands(
60 &mut self,
61 socket_mode: &SocketMode<S>,
62 e: SlashCommandsEvent,
63 s: &mut Stream,
64 ) {
65 log::info!("slash commands event: {:?}", e);
66 }
67}
68
/// Configuration for a Slack Socket Mode connection.
///
/// Create it with `SocketMode::new`, adjust it with the builder-style setters, and
/// consume it with `SocketMode::run`.
///
/// The `SlackWebAPIClient` bound is required only by the `impl` blocks that use the
/// client, so it is not repeated on the struct definition (Rust API guideline
/// C-STRUCT-BOUNDS). Dropping it here only relaxes the type and stays backward compatible.
pub struct SocketMode<S> {
    /// Web API client passed to `connections_open` to obtain the WebSocket URL.
    pub api_client: S,
    /// App-level token passed to `connections_open` together with `api_client`.
    pub app_token: String,
    /// Bot token. `run` does not read it; it is kept for handlers
    /// (presumably for follow-up Web API calls — confirm with callers).
    pub bot_token: String,
    /// Free-form key/value settings exposed to handlers. `run` does not read them.
    pub option_parameter: HashMap<String, String>,
    /// TCP port used for the WebSocket connection (443 by default via `new`).
    /// Any port embedded in the URL returned by `connections_open` is not consulted.
    pub web_socket_port: u16,
    /// Optional path to a PEM file of CA certificates. When set, TLS trusts only
    /// those certificates instead of using `TlsConnector::default()`.
    pub ca_file_path: Option<String>,
}
81
82impl<S> SocketMode<S>
83where
84 S: SlackWebAPIClient,
85{
86 pub fn new(api_client: S, app_token: String, bot_token: String) -> Self {
87 SocketMode {
88 api_client,
89 app_token,
90 bot_token,
91 option_parameter: HashMap::new(),
92 web_socket_port: 443,
93 ca_file_path: None,
94 }
95 }
96 pub fn option_parameter(mut self, key: String, value: String) -> Self {
97 self.option_parameter.insert(key, value);
98 self
99 }
100 pub fn web_socket_port(mut self, port: u16) -> Self {
101 self.web_socket_port = port;
102 self
103 }
104 pub fn ca_file_path(mut self, ca_file_path: String) -> Self {
105 self.ca_file_path = Some(ca_file_path);
106 self
107 }
108 pub async fn run<T>(self, handler: &mut T) -> Result<(), Error>
110 where
111 T: EventHandler<S>,
112 {
113 let response = connections_open(&self.api_client, &self.app_token).await?;
114 let ws_url = response.url.ok_or(Error::SocketModeOpenConnectionError)?;
115 let ws_url_parsed = Url::parse(&ws_url)?;
116 let ws_domain = ws_url_parsed.domain().ok_or(Error::NotFoundDomain)?;
117
118 let tcp_stream = TcpStream::connect((ws_domain, self.web_socket_port)).await?;
119 let connector = if let Some(ca_file_path) = &self.ca_file_path {
120 connector_for_ca_file(ca_file_path).await?
121 } else {
122 TlsConnector::default()
123 };
124 let tls_stream = connector.connect(ws_domain, tcp_stream).await?;
125
126 let (mut ws, _) = client_async(&ws_url, tls_stream).await?;
127
128 handler.on_connect(&self).await;
129
130 loop {
131 let message = ws.next().await.ok_or(Error::NotFoundStream)?;
132
133 match message? {
134 Message::Text(t) => {
135 let event = serde_json::from_str::<SocketModeEvent>(&t)?;
136 match event {
137 SocketModeEvent::HelloEvent(e) => handler.on_hello(&self, e, &mut ws).await,
138 SocketModeEvent::DisconnectEvent(e) => {
139 handler.on_disconnect(&self, e, &mut ws).await
140 }
141 SocketModeEvent::EventsAPI(e) => {
142 handler.on_events_api(&self, e, &mut ws).await
143 }
144 SocketModeEvent::InteractiveEvent(e) => {
145 handler.on_interactive(&self, e, &mut ws).await
146 }
147 SocketModeEvent::SlashCommandsEvent(e) => {
148 handler.on_slash_commands(&self, e, &mut ws).await
149 }
150 }
151 }
152 Message::Ping(p) => log::info!("ping: {:?}", p),
153 Message::Close(_) => {
154 handler.on_close(&self).await;
155 break;
156 }
157 m => log::warn!("unsupported web socket message: {:?}", m),
158 }
159 }
160 Ok(())
161 }
162}
163
164pub async fn ack(envelope_id: &str, stream: &mut Stream) -> Result<(), Error> {
165 let json = serde_json::to_string(&AcknowledgeMessage { envelope_id })?;
166 stream
167 .send(Message::Text(json))
168 .await
169 .map_err(Error::WebSocketError)
170}
171
172pub async fn connector_for_ca_file(ca_file_path: &str) -> Result<TlsConnector, Error> {
173 let mut config = ClientConfig::new();
174 let file = read(ca_file_path).await?;
175 let mut pem = Cursor::new(file);
176 config
177 .root_store
178 .add_pem_file(&mut pem)
179 .map_err(|_| Error::InvalidInputError)?;
180 Ok(TlsConnector::from(Arc::new(config)))
181}
182
#[cfg(test)]
mod test {
    use crate::event_api::event::Event;
    use crate::http_client::{MockSlackWebAPIClient, SlackWebAPIClient};
    use crate::payloads::interactive::InteractiveEventType;
    use crate::socket::event::{
        DisconnectEvent, DisconnectReason, EventsAPI, HelloEvent, InteractiveEvent,
        SlashCommandsEvent,
    };
    use crate::socket::socket_mode::{EventHandler, SocketMode, Stream};
    use async_std::net::TcpListener;
    use async_std::task;
    use async_tls::TlsAcceptor;
    use async_trait::async_trait;
    use async_tungstenite::tungstenite::Message;
    use futures_util::{SinkExt, StreamExt};
    use rustls::internal::pemfile::{certs, pkcs8_private_keys};
    use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
    use std::error::Error;
    use std::fs::File;
    use std::io;
    use std::io::BufReader;
    use std::sync::Arc;

    /// Test handler whose callbacks assert on the fixture events served by `mock_web_socket`.
    pub struct Handler;

    // The callbacks only inspect the event; `socket_mode` and the stream are unused.
    #[allow(unused_variables)]
    #[async_trait]
    impl<S> EventHandler<S> for Handler
    where
        S: SlackWebAPIClient,
    {
        async fn on_hello(&mut self, socket_mode: &SocketMode<S>, e: HelloEvent, s: &mut Stream) {
            assert_eq!(e.connection_info.unwrap().app_id.unwrap(), "app_id");
            assert_eq!(e.num_connections.unwrap(), 1);
            assert_eq!(e.debug_info.unwrap().host.unwrap(), "host");
            log::info!("success on_hello test");
        }
        async fn on_disconnect(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: DisconnectEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.reason, DisconnectReason::LinkDisabled);
            assert_eq!(e.debug_info.unwrap().host.unwrap(), "wss-111.slack.com");
            log::info!("success on_disconnect test");
        }
        async fn on_events_api(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: EventsAPI,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                !e.accepts_response_payload,
                "events_api fixture has accepts_response_payload = false"
            );

            match e.payload {
                Event::AppHomeOpened { user, .. } => {
                    assert_eq!(user, "U061F7AUR");
                }
                _ => panic!("Payload deserialize into incorrect variant"),
            }
            log::info!("success on_events_api test");
        }
        async fn on_interactive(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: InteractiveEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                e.accepts_response_payload,
                "interactive fixture has accepts_response_payload = true"
            );
            assert_eq!(e.payload.type_filed, InteractiveEventType::ViewSubmission);
            log::info!("success on_interactive test");
        }
        async fn on_slash_commands(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: SlashCommandsEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                e.accepts_response_payload,
                "slash_commands fixture has accepts_response_payload = true"
            );
            assert_eq!(e.payload.token.unwrap(), "bHKJ2n9AW6Ju3MjciOHfbA1b");
            log::info!("success on_slash_commands test");
        }
    }

    #[async_std::test]
    async fn test_socket_mode() {
        // `init` panics if another test in this binary already installed a logger;
        // `try_init` makes logger setup idempotent.
        let _ = env_logger::try_init();

        let event = vec![
            r##"{
    "type": "hello",
    "connection_info": {
        "app_id": "app_id"
    },
    "num_connections": 1,
    "debug_info": {
        "host": "host"
    }
}"##
            .to_string(),
            r##"{
    "type": "disconnect",
    "reason": "link_disabled",
    "debug_info": {
        "host": "wss-111.slack.com"
    }
}"##
            .to_string(),
            r##"{
    "type": "events_api",
    "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
    "accepts_response_payload": false,
    "payload": {
        "type": "app_home_opened",
        "user": "U061F7AUR",
        "channel": "D0LAN2Q65",
        "event_ts": "1515449522000016",
        "tab": "home",
        "view": {
            "id": "VPASKP233"
        }
    }
}"##
            .to_string(),
            r##"{
    "type": "interactive",
    "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
    "accepts_response_payload": true,
    "payload": {
        "type": "view_submission"
    }
}"##
            .to_string(),
            r##"{
    "type": "slash_commands",
    "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
    "accepts_response_payload": true,
    "payload": {
        "token": "bHKJ2n9AW6Ju3MjciOHfbA1b"
    }
}"##
            .to_string(),
        ];

        let mut mock = MockSlackWebAPIClient::new();
        mock.expect_post().times(1).returning(|_, _| {
            Ok(r##"{
    "ok": true,
    "url": "wss://localhost"
}"##
            .to_string())
        });

        let port = mock_web_socket(event).await.unwrap();
        SocketMode::new(
            mock,
            "slack_app_token".to_string(),
            "slack_bot_token".to_string(),
        )
        .web_socket_port(port)
        .option_parameter(
            "SLACK_CHANNEL_ID".to_string(),
            "slack_channel_id".to_string(),
        )
        .ca_file_path("rootCA.pem".to_string())
        .run(&mut Handler)
        .await
        .unwrap_or_else(|_| panic!("socket mode run error."));
    }

    /// Binds a listener on an ephemeral localhost port, serves `event` from a background
    /// task, and returns the bound port.
    async fn mock_web_socket(event: Vec<String>) -> Result<u16, Box<dyn Error>> {
        let listener = TcpListener::bind("localhost:0").await?;
        let port = listener.local_addr()?.port();

        task::spawn(async move {
            web_socket_handler(&listener, event).await;
        });

        Ok(port)
    }

    /// For each accepted connection: TLS + WebSocket handshake, send every fixture as a
    /// text frame, then close the WebSocket.
    async fn web_socket_handler(listener: &TcpListener, event: Vec<String>) {
        let config = load_config("localhost.pem", "localhost-key.pem").unwrap();
        let acceptor = TlsAcceptor::from(Arc::new(config));

        let mut incoming = listener.incoming();

        while let Some(stream) = incoming.next().await {
            let tcp_stream = stream.unwrap();
            let tls_stream = acceptor.accept(tcp_stream).await.unwrap();
            let mut ws = async_tungstenite::accept_async(tls_stream).await.unwrap();

            for e in &event {
                ws.send(Message::Text(e.clone())).await.unwrap();
            }

            ws.close(None).await.unwrap();
        }
    }

    /// Builds a server TLS config from a PEM certificate chain and the first PKCS#8 key
    /// in `key_path`. Load/parse problems are returned as `io::Error` rather than panicking.
    fn load_config(certs_path: &str, key_path: &str) -> io::Result<ServerConfig> {
        let certs = load_certs(certs_path)?;
        let key = load_key(key_path)?
            .into_iter()
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no private key found"))?;

        let mut config = ServerConfig::new(NoClientAuth::new());
        config.set_single_cert(certs, key).map_err(|e| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid certificate/key pair: {:?}", e),
            )
        })?;

        Ok(config)
    }

    fn load_certs(path: &str) -> io::Result<Vec<Certificate>> {
        certs(&mut BufReader::new(File::open(path)?))
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
    }

    fn load_key(path: &str) -> io::Result<Vec<PrivateKey>> {
        pkcs8_private_keys(&mut BufReader::new(File::open(path)?))
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
    }
}