traq_ws_bot/bot/
mod.rs

1use std::{collections::HashSet, sync::Arc, time::Duration};
2
3use futures::{
4    future::{self, BoxFuture},
5    pin_mut, Future, StreamExt,
6};
7use paste::paste;
8use reqwest::Url;
9use tokio_tungstenite::{
10    connect_async,
11    tungstenite::{handshake::client::generate_key, Message},
12};
13
14use crate::events::{payload, Events};
15
16pub mod handler;
17pub mod keys;
18
19use self::handler::Handler;
20
/// Default traQ origin over HTTP(S).
pub const TRAQ_ORIGIN: &str = "https://q.trap.jp";
/// Default traQ origin over WebSocket; combined with [`TRAQ_WS_GATEWAY_PATH`]
/// it forms the builder's default target URL.
pub const TRAQ_ORIGIN_WS: &str = "wss://q.trap.jp";

/// Path of the bot WebSocket gateway, relative to the origin.
pub const TRAQ_WS_GATEWAY_PATH: &str = "/api/v3/bots/ws";

/// Reconnect wait used for the first attempt and after every cleanly ended session.
pub const INITIAL_RETRY_WAIT: Duration = Duration::from_secs(3);
/// Upper bound for the doubling reconnect wait (10 minutes).
pub const MAX_RETRY_WAIT: Duration = Duration::from_secs(600);
28
/// Builder for [`TraqBot`], normally created with [`builder`].
///
/// Collects the connection settings, the event handlers and the resource shared
/// with handlers, and turns them into a bot with `build`.
pub struct TraqBotBuilder<T: Send + Sync + 'static> {
    /// Scheme placed before the token in the `Authorization` header (default `Bearer`).
    authorization_scheme: String,
    /// Bot access token.
    token: String,
    /// Gateway URL; `ws`, `wss`, `http` and `https` are accepted and converted to
    /// `ws`/`wss` when the bot is built.
    target_url: Url,
    /// Registered handlers, one list per event kind, indexed by `keys::Keys as usize`.
    handlers: [Vec<Arc<dyn Handler<T>>>; keys::KEYS_COUNT],
    /// Resource handed to handlers. `builder` sets it to `Some(())`; it is `None`
    /// for a builder obtained from `Default`, and `build` unwraps it.
    resource: Option<T>,
}
36
/// A configured traQ bot; call [`TraqBot::start`] to connect and process events.
pub struct TraqBot<T: Send + Sync + 'static> {
    /// Scheme placed before the token in the `Authorization` header.
    authorization_scheme: String,
    /// Bot access token.
    token: String,
    /// `ws`/`wss` origin of the gateway (built from the target URL's origin only).
    ws_origin: Url,
    /// Gateway path joined onto `ws_origin` to form the connection URL.
    gateway_path: String,
    /// Handlers per event kind, indexed by `keys::Keys as usize`.
    handlers: [Box<[Arc<dyn Handler<T>>]>; keys::KEYS_COUNT],
    /// Resource shared (through `Arc`) with every handler invocation.
    resource: Arc<T>,
}
45
// Generates two builder methods for every event name `X` passed in (CamelCase):
// - `on_<x>`: registers a `fn(payload::X) -> Fut` handler for `Keys::X`;
// - `on_<x>_with_resource`: registers a `fn(payload::X, Arc<T>) -> Fut` handler,
//   which additionally receives the shared resource.
// Both push the function pointer onto `self.handlers[Keys::X as usize]`; this
// presumably relies on `Handler<T>` impls for these fn-pointer types in
// `handler` (TODO confirm). `paste!` builds the snake_case method names and
// splices the event name into the doc strings.
// The generated doc strings are Japanese; they read "Registers the handler
// invoked when an `X` event is received" and, for the `_with_resource` variant,
// additionally "The resource can be obtained from the arguments".
macro_rules! on_x_payload {
    ($($x:ident),*$(,)?) => {
        $(
            paste! {
                #[doc = ""[<$x:camel>]" イベントを受け取った際のハンドラを登録する"]
                #[doc = ""]
                #[doc = "# Example"]
                #[doc = "```rust"]
                #[doc = "use traq_ws_bot::bot::builder;"]
                #[doc = ""]
                #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
                #[doc = "    ."[<on_ $x:snake>]"(|event| async move {"]
                #[doc = "        println!(\"{:?}\", event);"]
                #[doc = "    })"]
                #[doc = "    .build();"]
                #[doc = "```"]
                pub fn [<on_ $x:snake>]<Fut>(mut self, handler: fn(payload::[<$x:camel>]) -> Fut) -> Self
                where
                    Fut: Future<Output = ()> + std::marker::Send + 'static,
                {
                    self.handlers[keys::Keys::[<$x:camel>] as usize].push(Arc::new(handler));
                    self
                }
                #[doc = ""[<$x:camel>]" イベントを受け取った際のハンドラを登録する"]
                #[doc = "引数から resource を取得することができる"]
                #[doc = ""]
                #[doc = "# Example"]
                #[doc = "```rust"]
                #[doc = "use traq_ws_bot::bot::builder;"]
                #[doc = ""]
                #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
                #[doc = "    ."[<on_ $x:snake _with_resource>]"(|event, resource| async move {"]
                #[doc = "        println!(\"{:?}, {:?}\", event, resource);"]
                #[doc = "    })"]
                #[doc = "    .build();"]
                #[doc = "```"]
                pub fn [<on_ $x:snake _with_resource>]<Fut>(mut self, handler: fn(payload::[<$x:camel>], Arc<T>) -> Fut) -> Self
                where
                    Fut: Future<Output = ()> + std::marker::Send + 'static,
                {
                    self.handlers[keys::Keys::[<$x:camel>] as usize].push(Arc::new(handler));
                    self
                }
            }
        )*
    };
}
93
// Expands to a `match` on `$event` (an `&Events` at the call site in
// `TraqBot::handle_event`) with one arm per listed variant. Each arm returns a
// boxed future that runs every handler registered under the matching
// `Keys` index concurrently via `future::join_all`, giving each handler its
// own clone of the event and of `$resource`.
// The variant list must cover every `Events` variant; otherwise the generated
// `match` is non-exhaustive and the expansion does not compile.
macro_rules! handle_event_inner {
    ($self:expr, $event:expr => {$($x:ident),*$(,)?}, $resource:expr) => {
        paste!{
            match $event {
                $(
                    Events::[<$x:camel>](_) => Box::pin(async {
                        future::join_all($self.handlers[keys::Keys::[<$x:camel>] as usize].iter().map(
                            |handler| async {
                                handler.handle($event.clone(), $resource.clone()).await;
                            },
                        ))
                        .await;
                    }),
                )*
            }
        }
    }
}
112
113impl<T: Send + Sync + 'static> TraqBot<T> {
114    /// BOT を起動する
115    ///
116    /// # Examples
117    /// ```rust
118    /// use traq_ws_bot::bot::builder;
119    ///
120    /// # async fn try_main() -> anyhow::Result<()> {
121    /// let bot = builder("BOT_ACCESS_TOKEN")
122    ///     .on_message_created(|event| async move {
123    ///       println!("{:?}", event);
124    ///     })
125    ///     .build();
126    /// bot.start().await?;
127    /// # Ok(())
128    /// # }
129    ///
130    /// # #[tokio::main]
131    /// # async fn main() -> anyhow::Result<()> {
132    /// #     let _ = try_main().await;
133    /// #     Ok(())
134    /// # }
135    /// ```
136    pub async fn start(&self) -> anyhow::Result<()> {
137        let host = self.get_ws_url().host_str().unwrap().to_owned();
138        let mut retry_wait = INITIAL_RETRY_WAIT;
139
140        loop {
141            match self.start_inner(&host).await {
142                Ok(()) => {
143                    retry_wait = INITIAL_RETRY_WAIT;
144                }
145                Err(e) => {
146                    log::error!("Error: {}", e);
147                    retry_wait = (retry_wait * 2).min(MAX_RETRY_WAIT);
148                }
149            }
150
151            log::info!("Disconnected. retry after {} seconds", retry_wait.as_secs());
152            tokio::time::sleep(retry_wait).await;
153        }
154    }
155
156    async fn start_inner(&self, host: &str) -> anyhow::Result<()> {
157        let request = http::Request::builder()
158            .method("GET")
159            .header("Host", host)
160            .header("Connection", "Upgrade")
161            .header("Upgrade", "websocket")
162            .header("Sec-Websocket-Version", "13")
163            .header("Sec-WebSocket-Key", generate_key())
164            .uri(self.get_ws_url().to_string())
165            .header(
166                "Authorization",
167                format!("{} {}", self.authorization_scheme, self.token),
168            )
169            .body(())?;
170
171        let (ws_stream, _) = connect_async(request).await?;
172
173        let (_tx, rx) = futures::channel::mpsc::unbounded();
174        let (write, read) = ws_stream.split();
175
176        let write_loop = rx.map(Ok).forward(write);
177
178        let read_loop = {
179            futures::TryStreamExt::try_for_each(
180                read.map(|msg| -> Result<_, ()> { Ok(msg) }),
181                |message| async {
182                    match message {
183                        Ok(message) => match message {
184                            Message::Ping(_) => {
185                                // nop
186                                Ok(())
187                            }
188                            Message::Text(content) => {
189                                let event = serde_json::from_str(&content);
190                                if let Ok(event) = event {
191                                    self.handle_event(&event, self.resource.clone()).await;
192                                } else {
193                                    eprintln!("failed to parse event: {}", content);
194                                }
195                                Ok(())
196                            }
197                            Message::Close(_) => Err(()),
198                            _ => {
199                                eprintln!("not supported message: {:?}", message);
200                                Ok(())
201                            }
202                        },
203                        Err(e) => {
204                            eprintln!("error: {:?}", e);
205                            Ok(())
206                        }
207                    }
208                },
209            )
210        };
211
212        pin_mut!(write_loop, read_loop);
213        future::select(read_loop, write_loop).await;
214
215        Ok(())
216    }
217
218    /// ws もしくは wss で始まる origin に相当する URL を返す
219    ///
220    /// **Example** `wss://q.trap.jp`, `ws://localhost:8080`
221    pub fn get_ws_origin(&self) -> Url {
222        self.ws_origin.clone()
223    }
224    /// http もしくは https で始まる origin に相当する URL を返す
225    ///
226    /// **Example** `https://q.trap.jp`, `http://localhost:8080`
227    pub fn get_http_origin(&self) -> Url {
228        let mut origin = self.get_ws_origin();
229        match origin.scheme() {
230            "wss" => origin.set_scheme("https").unwrap(),
231            "ws" => origin.set_scheme("http").unwrap(),
232            _ => panic!("Invalid scheme: {} (expected: ws, wss)", origin.scheme()),
233        }
234        origin
235    }
236
237    /// ws もしくは wss で始まる gateway の URL を返す
238    ///
239    /// **Example** `wss://q.trap.jp/api/v3/bot/ws`
240    pub fn get_ws_url(&self) -> Url {
241        self.ws_origin.join(&self.gateway_path).unwrap()
242    }
243    /// http もしくは https で始まる gateway の URL を返す
244    ///
245    /// **Example** `https://q.trap.jp/api/v3/bot/ws`
246    pub fn get_http_url(&self) -> Url {
247        let mut url = self.get_ws_url();
248        match url.scheme() {
249            "wss" => url.set_scheme("https").unwrap(),
250            "ws" => url.set_scheme("http").unwrap(),
251            _ => panic!("Invalid scheme: {} (expected: ws, wss)", url.scheme()),
252        }
253        url
254    }
255
256    /// bot access token を返す
257    pub fn get_token(&self) -> &str {
258        &self.token
259    }
260
261    /// イベントに対してハンドラを呼び出す
262    async fn handle_event(&self, event: &Events, resource: Arc<T>) {
263        let promise: BoxFuture<()> = handle_event_inner!(
264            self,
265            event => {
266                Ping,
267                Joined,
268                Left,
269                MessageCreated,
270                MessageUpdated,
271                MessageDeleted,
272                BotMessageStampsUpdated,
273                DirectMessageCreated,
274                DirectMessageUpdated,
275                DirectMessageDeleted,
276                ChannelCreated,
277                ChannelTopicChanged,
278                UserCreated,
279                StampCreated,
280                TagAdded,
281                TagRemoved,
282                Error,
283            },
284            resource
285        );
286        promise.await;
287    }
288}
289
290/// TraqBot の Builder を作成する
291pub fn builder(token: impl Into<String>) -> TraqBotBuilder<()> {
292    TraqBotBuilder {
293        token: token.into(),
294        resource: Some(()),
295        ..Default::default()
296    }
297}
298
// Placeholder for a configuration-based builder. It is not exported yet (note
// the commented-out `pub`), and calling it panics via `unimplemented!()`.
// `#[allow(unused)]` silences the unused-item warning while it stays private and
// uncalled. `#[rustfmt::skip]` presumably keeps rustfmt from reformatting the
// `/*pub */` marker; TODO confirm.
#[doc(hidden)]
#[allow(unused)]
#[rustfmt::skip]
/*pub */ fn builder_with_config(_config: ()) -> TraqBotBuilder<()> {
    unimplemented!()
}
305
306impl<T: Send + Sync + 'static> Default for TraqBotBuilder<T> {
307    fn default() -> Self {
308        let handlers_arr: [Vec<_>; keys::KEYS_COUNT] = Default::default();
309
310        Self {
311            authorization_scheme: "Bearer".to_owned(),
312            token: Default::default(),
313            target_url: Url::parse(TRAQ_ORIGIN_WS)
314                .unwrap()
315                .join(TRAQ_WS_GATEWAY_PATH)
316                .unwrap(),
317            handlers: handlers_arr,
318            resource: Default::default(),
319        }
320    }
321}
322
323fn convert_to_ws_url<U>(url: U) -> anyhow::Result<Url>
324where
325    U: TryInto<Url>,
326    U::Error: std::error::Error + Send + Sync + 'static,
327{
328    let mut url = url.try_into()?;
329    match url.scheme() {
330        "wss" | "ws" => Ok(url),
331        "http" => {
332            url.set_scheme("ws").unwrap();
333            Ok(url)
334        }
335        "https" => {
336            url.set_scheme("wss").unwrap();
337            Ok(url)
338        }
339        _ => Err(anyhow::anyhow!(
340            "Invalid scheme: {} (expected: ws, wss, http, https)",
341            url.scheme()
342        )),
343    }
344}
345
346impl<T: Send + Sync + 'static> TraqBotBuilder<T> {
347    /// TraqBotBuilder から TraqBot を作成する
348    ///
349    /// # Example
350    /// ```
351    /// use traq_ws_bot::bot::builder;
352    ///
353    /// let bot = builder("BOT_ACCESS_TOKEN")
354    ///     .on_message_created(|event| async move {
355    ///         println!("{:?}", event);
356    ///     })
357    ///    .build();
358    /// ```
359    pub fn build(self) -> TraqBot<T> {
360        let target_url_ws = convert_to_ws_url(self.target_url).unwrap();
361        let ws_origin = target_url_ws
362            .origin()
363            .ascii_serialization()
364            .parse()
365            .unwrap();
366        let gateway_path = target_url_ws.path().to_owned();
367
368        TraqBot {
369            authorization_scheme: self.authorization_scheme,
370            token: self.token,
371            ws_origin,
372            gateway_path,
373            handlers: self
374                .handlers
375                .into_iter()
376                .map(|v| v.into_boxed_slice())
377                .collect::<Vec<_>>()
378                .try_into()
379                .map_err(|v: Vec<Box<[Arc<dyn Handler<T>>]>>| {
380                    format!(
381                        "Invalid handlers length: {} (expected: {})",
382                        v.len(),
383                        keys::KEYS_COUNT
384                    )
385                })
386                .unwrap(),
387            resource: Arc::new(self.resource.unwrap()),
388        }
389    }
390
391    /// 認証の scheme を指定する
392    ///
393    /// **Default** `Bearer`
394    pub fn set_auth_scheme(mut self, scheme: impl Into<String>) -> Self {
395        self.authorization_scheme = scheme.into();
396        self
397    }
398    /// Bot の access token を指定する
399    pub fn set_token(mut self, token: impl Into<String>) -> Self {
400        self.token = token.into();
401        self
402    }
403    /// Bot が参加するための WebSocket の URL を指定する
404    ///
405    /// **Default** `wss://q.trap.jp/api/v3/bot/ws`
406    pub fn set_target_url<U>(mut self, url: U) -> Self
407    where
408        U: TryInto<Url>,
409        U::Error: std::fmt::Debug,
410    {
411        self.target_url = url.try_into().unwrap();
412        self
413    }
414
415    /// keys のイベントに対応するハンドラを設定する
416    /// keys は複数同時に指定可能
417    ///
418    /// ハンドラに渡される enum は key で指定したイベントに含まれることが保証される
419    ///
420    /// **NOTE** 全ての key を指定したい場合は KEYS_ALL を用いると良い
421    ///
422    /// # Example
423    /// ```rust
424    /// use traq_ws_bot::{bot::{builder, keys::Keys}, events::Events};
425    ///
426    /// let bot = builder("BOT_ACCESS_TOKEN")
427    ///     .on_event(Keys::Joined, |event| async move {
428    ///        if let Events::Joined(event) = event {
429    ///           println!("{:?}", event);
430    ///       }
431    ///    })
432    ///   .build();
433    /// ```
434    ///
435    /// ```rust
436    /// use traq_ws_bot::{bot::{builder, keys::Keys}, events::Events};
437    ///
438    /// let bot = builder("BOT_ACCESS_TOKEN")
439    ///     .on_event([Keys::Joined, Keys::Left], |event| async move {
440    ///         match event {
441    ///             Events::Joined(event) => println!("{:?}", event),
442    ///             Events::Left(event) => println!("{:?}", event),
443    ///             _ => unreachable!(),
444    ///         }
445    ///     })
446    ///     .build();
447    /// ```
448    pub fn on_event<Fut, K>(mut self, keys: K, handler: fn(Events) -> Fut) -> Self
449    where
450        Fut: Future<Output = ()> + std::marker::Send + 'static,
451        K: IntoIterator<Item = keys::Keys>,
452    {
453        let keys_set = keys.into_iter().collect::<HashSet<_>>();
454        let handler = Arc::new(handler);
455        for key in keys_set {
456            self.handlers[key as usize].push(handler.clone());
457        }
458        self
459    }
460
461    on_x_payload!(
462        Ping,
463        Joined,
464        Left,
465        MessageCreated,
466        MessageUpdated,
467        MessageDeleted,
468        BotMessageStampsUpdated,
469        DirectMessageCreated,
470        DirectMessageUpdated,
471        DirectMessageDeleted,
472        ChannelCreated,
473        ChannelTopicChanged,
474        UserCreated,
475        StampCreated,
476        TagAdded,
477        TagRemoved,
478    );
479
480    #[doc = "Error イベントを受け取った際のハンドラを登録する"]
481    #[doc = ""]
482    #[doc = "# Example"]
483    #[doc = "```rust"]
484    #[doc = "use traq_ws_bot::bot::builder;"]
485    #[doc = ""]
486    #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
487    #[doc = "    .on_error(|event| async move {"]
488    #[doc = "        println!(\"{:?}\", event);"]
489    #[doc = "    })"]
490    #[doc = "    .build();"]
491    #[doc = "```"]
492    pub fn on_error<Fut>(mut self, handler: fn(String) -> Fut) -> Self
493    where
494        Fut: Future<Output = ()> + std::marker::Send + 'static,
495    {
496        self.handlers[keys::Keys::Error as usize].push(Arc::new(handler));
497        self
498    }
499    #[doc = "Error イベントを受け取った際のハンドラを登録する"]
500    #[doc = "引数から resource を取得することができる"]
501    #[doc = ""]
502    #[doc = "# Example"]
503    #[doc = "```rust"]
504    #[doc = "use traq_ws_bot::bot::builder;"]
505    #[doc = ""]
506    #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
507    #[doc = "    .on_error_with_resource(|event, resource| async move {"]
508    #[doc = "        println!(\"{:?}, {:?}\", event, resource);"]
509    #[doc = "    })"]
510    #[doc = "    .build();"]
511    #[doc = "```"]
512    pub fn on_error_with_resource<Fut>(mut self, handler: fn(String, Arc<T>) -> Fut) -> Self
513    where
514        Fut: Future<Output = ()> + std::marker::Send + 'static,
515    {
516        self.handlers[keys::Keys::Error as usize].push(Arc::new(handler));
517        self
518    }
519
520    #[doc = "Resource を登録する"]
521    #[doc = ""]
522    #[doc = "**Warning**: これより前に登録したハンドラは削除される"]
523    #[doc = ""]
524    #[doc = "# Example"]
525    #[doc = "```rust"]
526    #[doc = "use traq_ws_bot::bot::builder;"]
527    #[doc = ""]
528    #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
529    #[doc = "    .insert_resource(\"Hello, world!\")"]
530    #[doc = "    .build();"]
531    #[doc = "```"]
532    pub fn insert_resource<U>(self, resource: U) -> TraqBotBuilder<U>
533    where
534        U: Send + Sync + 'static,
535    {
536        TraqBotBuilder {
537            token: self.token,
538            target_url: self.target_url,
539            resource: Some(resource),
540            authorization_scheme: self.authorization_scheme,
541            ..Default::default()
542        }
543    }
544}
545
/// Placeholder async handler: ignores its argument and prints `tmp`.
async fn tmp(_event: String) {
    let label = "tmp";
    println!("{}", label);
}
549
550#[allow(dead_code)]
551async fn tmp2() {
552    let _bot = builder("")
553        .on_error(tmp)
554        .on_error(|_: String| async move { println!("tmp") })
555        .on_event(keys::Keys::Error, |event| async {
556            if let Events::Error(event) = event {
557                println!("{:?}", event);
558                tmp(event).await;
559            }
560        });
561}