1use std::{collections::HashSet, sync::Arc, time::Duration};
2
3use futures::{
4 future::{self, BoxFuture},
5 pin_mut, Future, StreamExt,
6};
7use paste::paste;
8use reqwest::Url;
9use tokio_tungstenite::{
10 connect_async,
11 tungstenite::{handshake::client::generate_key, Message},
12};
13
14use crate::events::{payload, Events};
15
16pub mod handler;
17pub mod keys;
18
19use self::handler::Handler;
20
/// HTTP(S) origin of the public traQ instance.
pub const TRAQ_ORIGIN: &str = "https://q.trap.jp";
/// WebSocket origin of the public traQ instance (same host as [`TRAQ_ORIGIN`]).
pub const TRAQ_ORIGIN_WS: &str = "wss://q.trap.jp";

/// Path of the bot WebSocket gateway, relative to the WebSocket origin.
pub const TRAQ_WS_GATEWAY_PATH: &str = "/api/v3/bots/ws";

/// Reconnection wait used after a session ends, and the starting point of the
/// exponential backoff applied after connection failures.
pub const INITIAL_RETRY_WAIT: Duration = Duration::from_secs(3);
/// Upper bound on the reconnection wait (ten minutes).
pub const MAX_RETRY_WAIT: Duration = Duration::from_secs(600);
28
/// Builder for [`TraqBot`], as returned by [`builder`].
///
/// `T` is the type of the shared resource handed to `*_with_resource` handlers.
pub struct TraqBotBuilder<T: Send + Sync + 'static> {
    /// Scheme placed before the token in the `Authorization` header (`"Bearer"` by default).
    authorization_scheme: String,
    /// Bot access token.
    token: String,
    /// Gateway URL; `build` converts it to a `ws`/`wss` URL.
    target_url: Url,
    /// Registered handlers, one list per event kind, indexed by `keys::Keys as usize`.
    handlers: [Vec<Arc<dyn Handler<T>>>; keys::KEYS_COUNT],
    /// Shared resource; `build` unwraps it, so it must be `Some` by then.
    resource: Option<T>,
}
36
/// A configured traQ bot. Call [`TraqBot::start`] to connect and dispatch events.
pub struct TraqBot<T: Send + Sync + 'static> {
    /// Scheme placed before the token in the `Authorization` header.
    authorization_scheme: String,
    /// Bot access token.
    token: String,
    /// Origin (scheme, host, port) of the gateway; `build` only produces `ws`/`wss` origins.
    ws_origin: Url,
    /// Gateway path, joined onto `ws_origin` to form the WebSocket URL.
    gateway_path: String,
    /// Registered handlers, one slice per event kind, indexed by `keys::Keys as usize`.
    handlers: [Box<[Arc<dyn Handler<T>>]>; keys::KEYS_COUNT],
    /// Shared resource; handlers receive clones of this `Arc`.
    resource: Arc<T>,
}
45
// Generates, for each listed event name `$x`, two handler-registration methods:
//
// - `on_<x>`: registers a handler that receives the payload `payload::<X>`.
// - `on_<x>_with_resource`: registers a handler that additionally receives the
//   shared resource as `Arc<T>`.
//
// Both wrap the handler in `Arc`, push it into `self.handlers` at index
// `keys::Keys::<X> as usize`, and return the builder for chaining. The expansion
// refers to `self`, `T` and `Self`, so it must be invoked inside an
// `impl<T> TraqBotBuilder<T>` block.
macro_rules! on_x_payload {
    ($($x:ident),*$(,)?) => {
        $(
            paste! {
                #[doc = "Registers a handler for the "[<$x:camel>]" event."]
                #[doc = ""]
                #[doc = "# Example"]
                #[doc = "```rust"]
                #[doc = "use traq_ws_bot::bot::builder;"]
                #[doc = ""]
                #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
                #[doc = " ."[<on_ $x:snake>]"(|event| async move {"]
                #[doc = " println!(\"{:?}\", event);"]
                #[doc = " })"]
                #[doc = " .build();"]
                #[doc = "```"]
                pub fn [<on_ $x:snake>]<Fut>(mut self, handler: fn(payload::[<$x:camel>]) -> Fut) -> Self
                where
                    Fut: Future<Output = ()> + std::marker::Send + 'static,
                {
                    self.handlers[keys::Keys::[<$x:camel>] as usize].push(Arc::new(handler));
                    self
                }
                #[doc = "Registers a handler for the "[<$x:camel>]" event."]
                #[doc = "The handler also receives the shared resource as an argument."]
                #[doc = ""]
                #[doc = "# Example"]
                #[doc = "```rust"]
                #[doc = "use traq_ws_bot::bot::builder;"]
                #[doc = ""]
                #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
                #[doc = " ."[<on_ $x:snake _with_resource>]"(|event, resource| async move {"]
                #[doc = " println!(\"{:?}, {:?}\", event, resource);"]
                #[doc = " })"]
                #[doc = " .build();"]
                #[doc = "```"]
                pub fn [<on_ $x:snake _with_resource>]<Fut>(mut self, handler: fn(payload::[<$x:camel>], Arc<T>) -> Fut) -> Self
                where
                    Fut: Future<Output = ()> + std::marker::Send + 'static,
                {
                    self.handlers[keys::Keys::[<$x:camel>] as usize].push(Arc::new(handler));
                    self
                }
            }
        )*
    };
}
93
// Expands to a `match` over `$event` with one arm per listed `Events` variant.
// Each arm yields a pinned, boxed future that runs every handler registered in
// the matching `keys::Keys` slot through `future::join_all`. Each handler gets
// its own clone of the event and of `$resource`.
//
// The variant list must name every `Events` variant; otherwise the generated
// `match` is non-exhaustive and fails to compile. `$event` and `$resource` are
// expanded (and therefore evaluated) once per use, so pass plain bindings
// rather than expressions with side effects.
macro_rules! handle_event_inner {
    ($self:expr, $event:expr => {$($x:ident),*$(,)?}, $resource:expr) => {
        paste!{
            match $event {
                $(
                    Events::[<$x:camel>](_) => Box::pin(async {
                        // Poll all handlers for this event kind together and wait for every one.
                        future::join_all($self.handlers[keys::Keys::[<$x:camel>] as usize].iter().map(
                            |handler| async {
                                handler.handle($event.clone(), $resource.clone()).await;
                            },
                        ))
                        .await;
                    }),
                )*
            }
        }
    }
}
112
113impl<T: Send + Sync + 'static> TraqBot<T> {
114 pub async fn start(&self) -> anyhow::Result<()> {
137 let host = self.get_ws_url().host_str().unwrap().to_owned();
138 let mut retry_wait = INITIAL_RETRY_WAIT;
139
140 loop {
141 match self.start_inner(&host).await {
142 Ok(()) => {
143 retry_wait = INITIAL_RETRY_WAIT;
144 }
145 Err(e) => {
146 log::error!("Error: {}", e);
147 retry_wait = (retry_wait * 2).min(MAX_RETRY_WAIT);
148 }
149 }
150
151 log::info!("Disconnected. retry after {} seconds", retry_wait.as_secs());
152 tokio::time::sleep(retry_wait).await;
153 }
154 }
155
156 async fn start_inner(&self, host: &str) -> anyhow::Result<()> {
157 let request = http::Request::builder()
158 .method("GET")
159 .header("Host", host)
160 .header("Connection", "Upgrade")
161 .header("Upgrade", "websocket")
162 .header("Sec-Websocket-Version", "13")
163 .header("Sec-WebSocket-Key", generate_key())
164 .uri(self.get_ws_url().to_string())
165 .header(
166 "Authorization",
167 format!("{} {}", self.authorization_scheme, self.token),
168 )
169 .body(())?;
170
171 let (ws_stream, _) = connect_async(request).await?;
172
173 let (_tx, rx) = futures::channel::mpsc::unbounded();
174 let (write, read) = ws_stream.split();
175
176 let write_loop = rx.map(Ok).forward(write);
177
178 let read_loop = {
179 futures::TryStreamExt::try_for_each(
180 read.map(|msg| -> Result<_, ()> { Ok(msg) }),
181 |message| async {
182 match message {
183 Ok(message) => match message {
184 Message::Ping(_) => {
185 Ok(())
187 }
188 Message::Text(content) => {
189 let event = serde_json::from_str(&content);
190 if let Ok(event) = event {
191 self.handle_event(&event, self.resource.clone()).await;
192 } else {
193 eprintln!("failed to parse event: {}", content);
194 }
195 Ok(())
196 }
197 Message::Close(_) => Err(()),
198 _ => {
199 eprintln!("not supported message: {:?}", message);
200 Ok(())
201 }
202 },
203 Err(e) => {
204 eprintln!("error: {:?}", e);
205 Ok(())
206 }
207 }
208 },
209 )
210 };
211
212 pin_mut!(write_loop, read_loop);
213 future::select(read_loop, write_loop).await;
214
215 Ok(())
216 }
217
218 pub fn get_ws_origin(&self) -> Url {
222 self.ws_origin.clone()
223 }
224 pub fn get_http_origin(&self) -> Url {
228 let mut origin = self.get_ws_origin();
229 match origin.scheme() {
230 "wss" => origin.set_scheme("https").unwrap(),
231 "ws" => origin.set_scheme("http").unwrap(),
232 _ => panic!("Invalid scheme: {} (expected: ws, wss)", origin.scheme()),
233 }
234 origin
235 }
236
237 pub fn get_ws_url(&self) -> Url {
241 self.ws_origin.join(&self.gateway_path).unwrap()
242 }
243 pub fn get_http_url(&self) -> Url {
247 let mut url = self.get_ws_url();
248 match url.scheme() {
249 "wss" => url.set_scheme("https").unwrap(),
250 "ws" => url.set_scheme("http").unwrap(),
251 _ => panic!("Invalid scheme: {} (expected: ws, wss)", url.scheme()),
252 }
253 url
254 }
255
256 pub fn get_token(&self) -> &str {
258 &self.token
259 }
260
261 async fn handle_event(&self, event: &Events, resource: Arc<T>) {
263 let promise: BoxFuture<()> = handle_event_inner!(
264 self,
265 event => {
266 Ping,
267 Joined,
268 Left,
269 MessageCreated,
270 MessageUpdated,
271 MessageDeleted,
272 BotMessageStampsUpdated,
273 DirectMessageCreated,
274 DirectMessageUpdated,
275 DirectMessageDeleted,
276 ChannelCreated,
277 ChannelTopicChanged,
278 UserCreated,
279 StampCreated,
280 TagAdded,
281 TagRemoved,
282 Error,
283 },
284 resource
285 );
286 promise.await;
287 }
288}
289
290pub fn builder(token: impl Into<String>) -> TraqBotBuilder<()> {
292 TraqBotBuilder {
293 token: token.into(),
294 resource: Some(()),
295 ..Default::default()
296 }
297}
298
// Unimplemented stub (presumably a planned config-based constructor — TODO
// confirm). Calling it panics via `unimplemented!()`.
#[doc(hidden)]
// Nothing calls this stub yet; silence the unused-item lints until it is wired up.
#[allow(unused)]
#[rustfmt::skip]
fn builder_with_config(_config: ()) -> TraqBotBuilder<()> {
    unimplemented!()
}
305
306impl<T: Send + Sync + 'static> Default for TraqBotBuilder<T> {
307 fn default() -> Self {
308 let handlers_arr: [Vec<_>; keys::KEYS_COUNT] = Default::default();
309
310 Self {
311 authorization_scheme: "Bearer".to_owned(),
312 token: Default::default(),
313 target_url: Url::parse(TRAQ_ORIGIN_WS)
314 .unwrap()
315 .join(TRAQ_WS_GATEWAY_PATH)
316 .unwrap(),
317 handlers: handlers_arr,
318 resource: Default::default(),
319 }
320 }
321}
322
323fn convert_to_ws_url<U>(url: U) -> anyhow::Result<Url>
324where
325 U: TryInto<Url>,
326 U::Error: std::error::Error + Send + Sync + 'static,
327{
328 let mut url = url.try_into()?;
329 match url.scheme() {
330 "wss" | "ws" => Ok(url),
331 "http" => {
332 url.set_scheme("ws").unwrap();
333 Ok(url)
334 }
335 "https" => {
336 url.set_scheme("wss").unwrap();
337 Ok(url)
338 }
339 _ => Err(anyhow::anyhow!(
340 "Invalid scheme: {} (expected: ws, wss, http, https)",
341 url.scheme()
342 )),
343 }
344}
345
346impl<T: Send + Sync + 'static> TraqBotBuilder<T> {
347 pub fn build(self) -> TraqBot<T> {
360 let target_url_ws = convert_to_ws_url(self.target_url).unwrap();
361 let ws_origin = target_url_ws
362 .origin()
363 .ascii_serialization()
364 .parse()
365 .unwrap();
366 let gateway_path = target_url_ws.path().to_owned();
367
368 TraqBot {
369 authorization_scheme: self.authorization_scheme,
370 token: self.token,
371 ws_origin,
372 gateway_path,
373 handlers: self
374 .handlers
375 .into_iter()
376 .map(|v| v.into_boxed_slice())
377 .collect::<Vec<_>>()
378 .try_into()
379 .map_err(|v: Vec<Box<[Arc<dyn Handler<T>>]>>| {
380 format!(
381 "Invalid handlers length: {} (expected: {})",
382 v.len(),
383 keys::KEYS_COUNT
384 )
385 })
386 .unwrap(),
387 resource: Arc::new(self.resource.unwrap()),
388 }
389 }
390
391 pub fn set_auth_scheme(mut self, scheme: impl Into<String>) -> Self {
395 self.authorization_scheme = scheme.into();
396 self
397 }
398 pub fn set_token(mut self, token: impl Into<String>) -> Self {
400 self.token = token.into();
401 self
402 }
403 pub fn set_target_url<U>(mut self, url: U) -> Self
407 where
408 U: TryInto<Url>,
409 U::Error: std::fmt::Debug,
410 {
411 self.target_url = url.try_into().unwrap();
412 self
413 }
414
415 pub fn on_event<Fut, K>(mut self, keys: K, handler: fn(Events) -> Fut) -> Self
449 where
450 Fut: Future<Output = ()> + std::marker::Send + 'static,
451 K: IntoIterator<Item = keys::Keys>,
452 {
453 let keys_set = keys.into_iter().collect::<HashSet<_>>();
454 let handler = Arc::new(handler);
455 for key in keys_set {
456 self.handlers[key as usize].push(handler.clone());
457 }
458 self
459 }
460
461 on_x_payload!(
462 Ping,
463 Joined,
464 Left,
465 MessageCreated,
466 MessageUpdated,
467 MessageDeleted,
468 BotMessageStampsUpdated,
469 DirectMessageCreated,
470 DirectMessageUpdated,
471 DirectMessageDeleted,
472 ChannelCreated,
473 ChannelTopicChanged,
474 UserCreated,
475 StampCreated,
476 TagAdded,
477 TagRemoved,
478 );
479
480 #[doc = "Error イベントを受け取った際のハンドラを登録する"]
481 #[doc = ""]
482 #[doc = "# Example"]
483 #[doc = "```rust"]
484 #[doc = "use traq_ws_bot::bot::builder;"]
485 #[doc = ""]
486 #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
487 #[doc = " .on_error(|event| async move {"]
488 #[doc = " println!(\"{:?}\", event);"]
489 #[doc = " })"]
490 #[doc = " .build();"]
491 #[doc = "```"]
492 pub fn on_error<Fut>(mut self, handler: fn(String) -> Fut) -> Self
493 where
494 Fut: Future<Output = ()> + std::marker::Send + 'static,
495 {
496 self.handlers[keys::Keys::Error as usize].push(Arc::new(handler));
497 self
498 }
499 #[doc = "Error イベントを受け取った際のハンドラを登録する"]
500 #[doc = "引数から resource を取得することができる"]
501 #[doc = ""]
502 #[doc = "# Example"]
503 #[doc = "```rust"]
504 #[doc = "use traq_ws_bot::bot::builder;"]
505 #[doc = ""]
506 #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
507 #[doc = " .on_error_with_resource(|event, resource| async move {"]
508 #[doc = " println!(\"{:?}, {:?}\", event, resource);"]
509 #[doc = " })"]
510 #[doc = " .build();"]
511 #[doc = "```"]
512 pub fn on_error_with_resource<Fut>(mut self, handler: fn(String, Arc<T>) -> Fut) -> Self
513 where
514 Fut: Future<Output = ()> + std::marker::Send + 'static,
515 {
516 self.handlers[keys::Keys::Error as usize].push(Arc::new(handler));
517 self
518 }
519
520 #[doc = "Resource を登録する"]
521 #[doc = ""]
522 #[doc = "**Warning**: これより前に登録したハンドラは削除される"]
523 #[doc = ""]
524 #[doc = "# Example"]
525 #[doc = "```rust"]
526 #[doc = "use traq_ws_bot::bot::builder;"]
527 #[doc = ""]
528 #[doc = "let bot = builder(\"BOT_ACCESS_TOKEN\")"]
529 #[doc = " .insert_resource(\"Hello, world!\")"]
530 #[doc = " .build();"]
531 #[doc = "```"]
532 pub fn insert_resource<U>(self, resource: U) -> TraqBotBuilder<U>
533 where
534 U: Send + Sync + 'static,
535 {
536 TraqBotBuilder {
537 token: self.token,
538 target_url: self.target_url,
539 resource: Some(resource),
540 authorization_scheme: self.authorization_scheme,
541 ..Default::default()
542 }
543 }
544}
545
// Compile-time check that the handler-registration methods accept both named
// `async fn` items and non-capturing closures. It is only built for tests, so
// this scratch code is not compiled into the library.
#[cfg(test)]
async fn tmp(_: String) {
    println!("tmp");
}

#[cfg(test)]
#[allow(dead_code)] // never called: it exists only to be type-checked
async fn tmp2() {
    let _bot = builder("")
        .on_error(tmp)
        .on_error(|_: String| async move { println!("tmp") })
        .on_event(keys::Keys::Error, |event| async {
            if let Events::Error(event) = event {
                println!("{:?}", event);
                tmp(event).await;
            }
        });
}