slack_rust/socket/socket_mode.rs

1use crate::apps::connections_open::connections_open;
2use crate::error::Error;
3use crate::http_client::SlackWebAPIClient;
4use crate::socket::event::{
5    AcknowledgeMessage, DisconnectEvent, EventsAPI, HelloEvent, InteractiveEvent,
6    SlashCommandsEvent, SocketModeEvent,
7};
8use async_std::fs::read;
9use async_std::net::TcpStream;
10use async_tls::client::TlsStream;
11use async_tls::TlsConnector;
12use async_trait::async_trait;
13use async_tungstenite::tungstenite::Message;
14use async_tungstenite::{client_async, WebSocketStream};
15use futures_util::{SinkExt, StreamExt};
16use rustls::ClientConfig;
17use std::collections::HashMap;
18use std::io::Cursor;
19use std::sync::Arc;
20use url::Url;
21
/// The live Socket Mode connection: a WebSocket running over a TLS stream on top of an
/// async-std TCP socket. This is the type handed to [`EventHandler`] callbacks and to [`ack`].
pub type Stream = WebSocketStream<TlsStream<TcpStream>>;
23
24/// Implement this trait in your code to handle slack events.
25#[allow(unused_variables)]
26#[async_trait]
27pub trait EventHandler<S>: Send
28where
29    S: SlackWebAPIClient,
30{
31    async fn on_close(&mut self, socket_mode: &SocketMode<S>) {
32        log::info!("websocket close");
33    }
34    async fn on_connect(&mut self, socket_mode: &SocketMode<S>) {
35        log::info!("websocket connect");
36    }
37    async fn on_hello(&mut self, socket_mode: &SocketMode<S>, e: HelloEvent, s: &mut Stream) {
38        log::info!("hello event: {:?}", e);
39    }
40    async fn on_disconnect(
41        &mut self,
42        socket_mode: &SocketMode<S>,
43        e: DisconnectEvent,
44        s: &mut Stream,
45    ) {
46        log::info!("disconnect event: {:?}", e);
47    }
48    async fn on_events_api(&mut self, socket_mode: &SocketMode<S>, e: EventsAPI, s: &mut Stream) {
49        log::info!("events api event: {:?}", e);
50    }
51    async fn on_interactive(
52        &mut self,
53        socket_mode: &SocketMode<S>,
54        e: InteractiveEvent,
55        s: &mut Stream,
56    ) {
57        log::info!("interactive event: {:?}", e);
58    }
59    async fn on_slash_commands(
60        &mut self,
61        socket_mode: &SocketMode<S>,
62        e: SlashCommandsEvent,
63        s: &mut Stream,
64    ) {
65        log::info!("slash commands event: {:?}", e);
66    }
67}
68
69/// The socket mode client.
70pub struct SocketMode<S>
71where
72    S: SlackWebAPIClient,
73{
74    pub api_client: S,
75    pub app_token: String,
76    pub bot_token: String,
77    pub option_parameter: HashMap<String, String>,
78    pub web_socket_port: u16,
79    pub ca_file_path: Option<String>,
80}
81
82impl<S> SocketMode<S>
83where
84    S: SlackWebAPIClient,
85{
86    pub fn new(api_client: S, app_token: String, bot_token: String) -> Self {
87        SocketMode {
88            api_client,
89            app_token,
90            bot_token,
91            option_parameter: HashMap::new(),
92            web_socket_port: 443,
93            ca_file_path: None,
94        }
95    }
96    pub fn option_parameter(mut self, key: String, value: String) -> Self {
97        self.option_parameter.insert(key, value);
98        self
99    }
100    pub fn web_socket_port(mut self, port: u16) -> Self {
101        self.web_socket_port = port;
102        self
103    }
104    pub fn ca_file_path(mut self, ca_file_path: String) -> Self {
105        self.ca_file_path = Some(ca_file_path);
106        self
107    }
108    /// Run slack and websocket communication.
109    pub async fn run<T>(self, handler: &mut T) -> Result<(), Error>
110    where
111        T: EventHandler<S>,
112    {
113        let response = connections_open(&self.api_client, &self.app_token).await?;
114        let ws_url = response.url.ok_or(Error::SocketModeOpenConnectionError)?;
115        let ws_url_parsed = Url::parse(&ws_url)?;
116        let ws_domain = ws_url_parsed.domain().ok_or(Error::NotFoundDomain)?;
117
118        let tcp_stream = TcpStream::connect((ws_domain, self.web_socket_port)).await?;
119        let connector = if let Some(ca_file_path) = &self.ca_file_path {
120            connector_for_ca_file(ca_file_path).await?
121        } else {
122            TlsConnector::default()
123        };
124        let tls_stream = connector.connect(ws_domain, tcp_stream).await?;
125
126        let (mut ws, _) = client_async(&ws_url, tls_stream).await?;
127
128        handler.on_connect(&self).await;
129
130        loop {
131            let message = ws.next().await.ok_or(Error::NotFoundStream)?;
132
133            match message? {
134                Message::Text(t) => {
135                    let event = serde_json::from_str::<SocketModeEvent>(&t)?;
136                    match event {
137                        SocketModeEvent::HelloEvent(e) => handler.on_hello(&self, e, &mut ws).await,
138                        SocketModeEvent::DisconnectEvent(e) => {
139                            handler.on_disconnect(&self, e, &mut ws).await
140                        }
141                        SocketModeEvent::EventsAPI(e) => {
142                            handler.on_events_api(&self, e, &mut ws).await
143                        }
144                        SocketModeEvent::InteractiveEvent(e) => {
145                            handler.on_interactive(&self, e, &mut ws).await
146                        }
147                        SocketModeEvent::SlashCommandsEvent(e) => {
148                            handler.on_slash_commands(&self, e, &mut ws).await
149                        }
150                    }
151                }
152                Message::Ping(p) => log::info!("ping: {:?}", p),
153                Message::Close(_) => {
154                    handler.on_close(&self).await;
155                    break;
156                }
157                m => log::warn!("unsupported web socket message: {:?}", m),
158            }
159        }
160        Ok(())
161    }
162}
163
164pub async fn ack(envelope_id: &str, stream: &mut Stream) -> Result<(), Error> {
165    let json = serde_json::to_string(&AcknowledgeMessage { envelope_id })?;
166    stream
167        .send(Message::Text(json))
168        .await
169        .map_err(Error::WebSocketError)
170}
171
172pub async fn connector_for_ca_file(ca_file_path: &str) -> Result<TlsConnector, Error> {
173    let mut config = ClientConfig::new();
174    let file = read(ca_file_path).await?;
175    let mut pem = Cursor::new(file);
176    config
177        .root_store
178        .add_pem_file(&mut pem)
179        .map_err(|_| Error::InvalidInputError)?;
180    Ok(TlsConnector::from(Arc::new(config)))
181}
182
#[cfg(test)]
mod test {
    use crate::event_api::event::Event;
    use crate::http_client::{MockSlackWebAPIClient, SlackWebAPIClient};
    use crate::payloads::interactive::InteractiveEventType;
    use crate::socket::event::{
        DisconnectEvent, DisconnectReason, EventsAPI, HelloEvent, InteractiveEvent,
        SlashCommandsEvent,
    };
    use crate::socket::socket_mode::{EventHandler, SocketMode, Stream};
    use async_std::net::TcpListener;
    use async_std::task;
    use async_tls::TlsAcceptor;
    use async_trait::async_trait;
    use async_tungstenite::tungstenite::Message;
    use futures_util::{SinkExt, StreamExt};
    use rustls::internal::pemfile::{certs, pkcs8_private_keys};
    use rustls::{Certificate, NoClientAuth, PrivateKey, ServerConfig};
    use std::error::Error;
    use std::fs::File;
    use std::io;
    use std::io::BufReader;
    use std::sync::Arc;

    /// Test handler that asserts on each event's contents and records which callbacks
    /// fired. The test can then verify that no event was silently dropped.
    #[derive(Default)]
    pub struct Handler {
        received: Vec<&'static str>,
    }

    #[allow(unused_variables)]
    #[async_trait]
    impl<S> EventHandler<S> for Handler
    where
        S: SlackWebAPIClient,
    {
        async fn on_close(&mut self, socket_mode: &SocketMode<S>) {
            self.received.push("close");
        }
        async fn on_hello(&mut self, socket_mode: &SocketMode<S>, e: HelloEvent, s: &mut Stream) {
            assert_eq!(e.connection_info.unwrap().app_id.unwrap(), "app_id");
            assert_eq!(e.num_connections.unwrap(), 1);
            assert_eq!(e.debug_info.unwrap().host.unwrap(), "host");
            self.received.push("hello");
            log::info!("success on_hello test");
        }
        async fn on_disconnect(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: DisconnectEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.reason, DisconnectReason::LinkDisabled);
            assert_eq!(e.debug_info.unwrap().host.unwrap(), "wss-111.slack.com");
            self.received.push("disconnect");
            log::info!("success on_disconnect test");
        }
        async fn on_events_api(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: EventsAPI,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                !e.accepts_response_payload,
                "events_api envelope should not accept a response payload"
            );

            match e.payload {
                Event::AppHomeOpened { user, .. } => {
                    assert_eq!(user, "U061F7AUR");
                }
                _ => panic!("Payload deserialize into incorrect variant"),
            }
            self.received.push("events_api");
            log::info!("success on_events_api test");
        }
        async fn on_interactive(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: InteractiveEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                e.accepts_response_payload,
                "interactive envelope should accept a response payload"
            );
            assert_eq!(e.payload.type_filed, InteractiveEventType::ViewSubmission);
            self.received.push("interactive");
            log::info!("success on_interactive test")
        }
        async fn on_slash_commands(
            &mut self,
            socket_mode: &SocketMode<S>,
            e: SlashCommandsEvent,
            s: &mut Stream,
        ) {
            assert_eq!(e.envelope_id, "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545");
            assert!(
                e.accepts_response_payload,
                "slash_commands envelope should accept a response payload"
            );
            assert_eq!(e.payload.token.unwrap(), "bHKJ2n9AW6Ju3MjciOHfbA1b");
            self.received.push("slash_commands");
            log::info!("success on_slash_commands test");
        }
    }

    #[async_std::test]
    async fn test_socket_mode() {
        // The global logger can be installed only once per test binary, and `init` panics
        // if another test already did it. `try_init` tolerates that.
        let _ = env_logger::try_init();

        let event = vec![
            r##"{
  "type": "hello",
  "connection_info": {
    "app_id": "app_id"
  },
  "num_connections": 1,
  "debug_info": {
    "host": "host"
  }
}"##
            .to_string(),
            r##"{
  "type": "disconnect",
  "reason": "link_disabled",
  "debug_info": {
    "host": "wss-111.slack.com"
  }
}"##
            .to_string(),
            r##"{
  "type": "events_api",
  "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
  "accepts_response_payload": false,
  "payload": {
    "type": "app_home_opened",
    "user": "U061F7AUR",
    "channel": "D0LAN2Q65",
    "event_ts": "1515449522000016",
    "tab": "home",
    "view": {
      "id": "VPASKP233"
    }
  }
}"##
            .to_string(),
            r##"{
  "type": "interactive",
  "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
  "accepts_response_payload": true,
  "payload": {
    "type": "view_submission"
  }
}"##
            .to_string(),
            r##"{
  "type": "slash_commands",
  "envelope_id": "dbdd0ef3-1543-4f94-bfb4-133d0e6c1545",
  "accepts_response_payload": true,
  "payload": {
    "token": "bHKJ2n9AW6Ju3MjciOHfbA1b"
  }
}"##
            .to_string(),
        ];

        let mut mock = MockSlackWebAPIClient::new();
        mock.expect_post().times(1).returning(|_, _| {
            Ok(r##"{
                  "ok": true,
                  "url": "wss://localhost"
                }"##
            .to_string())
        });

        let port = mock_web_socket(event).await.unwrap();
        let mut handler = Handler::default();
        SocketMode::new(
            mock,
            "slack_app_token".to_string(),
            "slack_bot_token".to_string(),
        )
        .web_socket_port(port)
        .option_parameter(
            "SLACK_CHANNEL_ID".to_string(),
            "slack_channel_id".to_string(),
        )
        .ca_file_path("rootCA.pem".to_string())
        .run(&mut handler)
        .await
        .unwrap_or_else(|_| panic!("socket mode run error."));

        // Every event sent by the mock server must reach its callback, in order, followed
        // by the close notification.
        assert_eq!(
            handler.received,
            [
                "hello",
                "disconnect",
                "events_api",
                "interactive",
                "slash_commands",
                "close"
            ]
        );
    }

    /// Starts a TLS WebSocket server on an ephemeral localhost port and returns that port.
    async fn mock_web_socket(event: Vec<String>) -> Result<u16, Box<dyn Error>> {
        let listener = TcpListener::bind("localhost:0").await?;
        let port = listener.local_addr()?.port();

        task::spawn(async move {
            web_socket_handler(&listener, event).await;
        });

        Ok(port)
    }

    /// For every accepted connection, sends each `event` as a text frame, then closes.
    async fn web_socket_handler(listener: &TcpListener, event: Vec<String>) {
        let config = load_config("localhost.pem", "localhost-key.pem")
            .expect("failed to load the mock server's TLS certificate and key");
        // TODO: async-tungstenite latest version Crate depends on rustls v.0.19
        let acceptor = TlsAcceptor::from(Arc::new(config));

        let mut incoming = listener.incoming();

        while let Some(stream) = incoming.next().await {
            let acceptor = acceptor.clone();
            let tcp_stream = stream.unwrap();
            let tls_stream = acceptor.accept(tcp_stream).await.unwrap();
            let mut ws = async_tungstenite::accept_async(tls_stream).await.unwrap();

            for e in &event {
                ws.send(Message::Text(e.clone())).await.unwrap();
            }

            ws.close(None).await.unwrap();
        }
    }

    /// Builds a server config from a PEM certificate chain and the first PKCS#8 key found.
    fn load_config(certs_path: &str, key_path: &str) -> io::Result<ServerConfig> {
        let certs = load_certs(certs_path)?;
        let private_key = load_key(key_path)?
            .into_iter()
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no private key found"))?;

        let mut config = ServerConfig::new(NoClientAuth::new());
        config.set_single_cert(certs, private_key).map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("invalid certificate/key pair: {:?}", err),
            )
        })?;

        Ok(config)
    }

    fn load_certs(path: &str) -> io::Result<Vec<Certificate>> {
        certs(&mut BufReader::new(File::open(path)?))
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))
    }

    fn load_key(path: &str) -> io::Result<Vec<PrivateKey>> {
        pkcs8_private_keys(&mut BufReader::new(File::open(path)?))
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))
    }
}