Skip to main content

tulpje_framework/
framework.rs

1use std::time::Duration;
2use std::{future::Future, pin::Pin, sync::Arc};
3
4use tokio::{sync::mpsc, task::JoinHandle};
5use tracing::{Instrument as _, Span};
6use twilight_standby::Standby;
7
8use crate::Metadata;
9use tokio_util::{sync::CancellationToken, task::TaskTracker};
10use twilight_gateway::Event;
11use twilight_http::Client;
12use twilight_model::id::{Id, marker::ApplicationMarker};
13
14use crate::handler::task_handler::TaskHandler;
15use crate::scheduler::{SchedulerHandle, SchedulerTaskMessage};
16use crate::{Context, Error, Registry};
17
18type SetupFunc<T> = fn(ctx: Context<T>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;
19type EventMessage = (Metadata, Event, Option<Span>);
20
21#[derive(Clone)]
22pub struct FrameworkBuilder<T: Clone + Send + Sync> {
23    registry: Arc<Registry<T>>,
24    client: Arc<Client>,
25    app_id: Id<ApplicationMarker>,
26    user_data: Arc<T>,
27
28    setup_fn: Option<SetupFunc<T>>,
29}
30
31impl<T: Clone + Send + Sync + 'static> FrameworkBuilder<T> {
32    pub fn new(
33        registry: Arc<Registry<T>>,
34        client: Client,
35        app_id: Id<ApplicationMarker>,
36        user_data: T,
37    ) -> Self {
38        Self {
39            registry,
40            client: Arc::new(client),
41            app_id,
42            user_data: Arc::new(user_data),
43            setup_fn: None,
44        }
45    }
46
47    pub fn setup(&mut self, func: SetupFunc<T>) -> &mut Self {
48        self.setup_fn = Some(func);
49        self
50    }
51
52    pub fn build(&self) -> Framework<T> {
53        Framework::new(
54            Arc::clone(&self.registry),
55            Arc::clone(&self.client),
56            self.app_id,
57            Arc::clone(&self.user_data),
58            self.setup_fn,
59        )
60    }
61}
62
63pub struct Framework<T: Clone + Send + Sync> {
64    ctx: Context<T>,
65    setup_fn: Option<SetupFunc<T>>,
66
67    scheduler: SchedulerHandle<T>,
68    dispatcher: DispatchHandle,
69}
70
71impl<T: Clone + Send + Sync + 'static> Framework<T> {
72    pub fn new(
73        registry: Arc<Registry<T>>,
74        client: Arc<Client>,
75        application_id: Id<ApplicationMarker>,
76        services: Arc<T>,
77        setup_fn: Option<SetupFunc<T>>,
78    ) -> Self {
79        let ctx = Context {
80            application_id,
81            services,
82            client,
83            standby: Arc::new(Standby::new()),
84        };
85        let scheduler =
86            SchedulerHandle::new(registry.tasks.values().cloned().collect(), ctx.clone());
87        let dispatcher = DispatchHandle::new(registry, ctx.clone());
88
89        Self {
90            ctx,
91            setup_fn,
92
93            scheduler,
94            dispatcher,
95        }
96    }
97
98    pub async fn start(&mut self) -> Result<(), Error> {
99        if let Some(setup_fn) = self.setup_fn.take() {
100            (setup_fn)(self.ctx.clone())
101                .await
102                .map_err(|err| format!("error running setup function: {}", err))?;
103        }
104
105        self.scheduler
106            .start()
107            .map_err(|err| format!("error starting scheduled tasks: {}", err))?;
108
109        Ok(())
110    }
111
112    pub fn enable_task(
113        &mut self,
114        handler: TaskHandler<T>,
115    ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
116        self.scheduler.enable_task(handler)
117    }
118
119    pub fn disable_task(
120        &mut self,
121        name: String,
122    ) -> Result<(), Box<mpsc::error::SendError<SchedulerTaskMessage<T>>>> {
123        self.scheduler.disable_task(name)
124    }
125
126    pub fn sender(&self) -> Sender {
127        Sender {
128            sender: self.dispatcher.sender.clone(),
129        }
130    }
131
132    pub fn send(
133        &mut self,
134        meta: Metadata,
135        event: Event,
136        span: Option<Span>,
137    ) -> Result<(), Box<mpsc::error::SendError<EventMessage>>> {
138        self.dispatcher.send(meta, event, span)
139    }
140
141    pub async fn shutdown(&mut self) {
142        self.scheduler.shutdown();
143        self.dispatcher.shutdown();
144    }
145
146    pub async fn join(&mut self) -> Result<(), Error> {
147        self.scheduler.join().await?;
148        self.dispatcher.join().await?;
149
150        Ok(())
151    }
152}
153
154struct DispatchHandle {
155    sender: mpsc::UnboundedSender<EventMessage>,
156    shutdown: CancellationToken,
157    handle: Option<JoinHandle<()>>,
158}
159impl DispatchHandle {
160    fn new<T: Clone + Send + Sync + 'static>(registry: Arc<Registry<T>>, ctx: Context<T>) -> Self {
161        let (sender, receiver) = mpsc::unbounded_channel();
162        let shutdown = CancellationToken::new();
163
164        let mut dispatch = Dispatch::new(ctx, registry, receiver, shutdown.child_token());
165        let handle = Some(tokio::spawn(async move { dispatch.run().await }));
166
167        Self {
168            sender,
169            shutdown,
170            handle,
171        }
172    }
173
174    fn send(
175        &mut self,
176        meta: Metadata,
177        event: Event,
178        span: Option<Span>,
179    ) -> Result<(), Box<mpsc::error::SendError<EventMessage>>> {
180        Ok(self.sender.send((meta, event, span))?)
181    }
182
183    fn shutdown(&mut self) {
184        self.shutdown.cancel();
185    }
186
187    async fn join(&mut self) -> Result<(), Error> {
188        Ok(self
189            .handle
190            .take()
191            .ok_or("Dispatch already shutdown")?
192            .await?)
193    }
194}
195
196struct Dispatch<T: Clone + Send + Sync> {
197    registry: Arc<Registry<T>>,
198    ctx: Context<T>,
199
200    receiver: mpsc::UnboundedReceiver<EventMessage>,
201    shutdown: CancellationToken,
202
203    tracker: TaskTracker,
204}
205impl<T: Clone + Send + Sync + 'static> Dispatch<T> {
206    fn new(
207        ctx: Context<T>,
208        registry: Arc<Registry<T>>,
209
210        receiver: mpsc::UnboundedReceiver<EventMessage>,
211        shutdown: CancellationToken,
212    ) -> Self {
213        Self {
214            registry,
215            ctx,
216
217            receiver,
218            shutdown,
219
220            tracker: TaskTracker::new(),
221        }
222    }
223
224    async fn run(&mut self) {
225        loop {
226            tokio::select! {
227                Some((meta, event, span)) = self.receiver.recv() => {
228                    let registry = Arc::clone(&self.registry);
229                    let ctx = self.ctx.clone();
230
231                    self.tracker.spawn(async move {
232                        crate::handle(meta, ctx, &registry, event).instrument(span.unwrap_or(Span::none())).await;
233                    });
234                },
235                () = self.shutdown.cancelled() => break,
236            }
237        }
238
239        self.receiver.close();
240        self.tracker.close();
241
242        if let Err(err) = tokio::time::timeout(Duration::from_secs(5), self.tracker.wait()).await {
243            tracing::warn!("waiting for dispatch tasks timed out: {err}");
244        };
245    }
246}
247
248pub struct Sender {
249    sender: mpsc::UnboundedSender<EventMessage>,
250}
251
252impl Sender {
253    pub fn send(
254        &self,
255        meta: Metadata,
256        event: Event,
257    ) -> Result<(), Box<mpsc::error::SendError<EventMessage>>> {
258        Ok(self.sender.send((meta, event, None))?)
259    }
260
261    pub fn with_span(
262        &self,
263        meta: Metadata,
264        event: Event,
265        span: Span,
266    ) -> Result<(), Box<mpsc::error::SendError<EventMessage>>> {
267        Ok(self.sender.send((meta, event, Some(span)))?)
268    }
269
270    pub fn closed(&self) -> bool {
271        self.sender.is_closed()
272    }
273}