// statsig_rust/statsig_runtime.rs

1use futures::future::join_all;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5use std::sync::Mutex;
6use std::time::Duration;
7use tokio::runtime::{Builder, Handle, Runtime};
8use tokio::sync::Notify;
9use tokio::task::JoinHandle;
10
11use crate::log_d;
12use crate::log_e;
13use crate::StatsigErr;
14
/// Log tag used by every message emitted from this module.
const TAG: &str = "StatsigRuntime";
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19    tag: String,
20    tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24    pub runtime_handle: Handle,
25    inner_runtime: Mutex<Option<Runtime>>,
26    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
27    shutdown_notify: Arc<Notify>,
28}
29
30impl StatsigRuntime {
31    #[must_use]
32    pub fn get_runtime() -> Arc<StatsigRuntime> {
33        let (opt_runtime, runtime_handle) = create_runtime_if_required();
34
35        let shutdown_notify = Notify::new();
36        Arc::new(StatsigRuntime {
37            inner_runtime: Mutex::new(opt_runtime),
38            runtime_handle,
39            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
40            shutdown_notify: Arc::new(shutdown_notify),
41        })
42    }
43
44    pub fn get_handle(&self) -> Handle {
45        self.runtime_handle.clone()
46    }
47
48    pub fn get_num_active_tasks(&self) -> usize {
49        match self.spawned_tasks.lock() {
50            Ok(lock) => lock.len(),
51            Err(e) => {
52                log_e!(TAG, "Failed to lock spawned tasks {}", e);
53                0
54            }
55        }
56    }
57
58    pub fn shutdown(&self, timeout: Duration) {
59        self.shutdown_notify.notify_waiters();
60
61        if let Ok(mut lock) = self.spawned_tasks.lock() {
62            for (_, task) in lock.drain() {
63                task.abort();
64            }
65        }
66
67        if let Ok(mut lock) = self.inner_runtime.lock() {
68            if let Some(runtime) = lock.take() {
69                log_d!(
70                    TAG,
71                    "Shutting down Statsig runtime with timeout: {:?}",
72                    timeout
73                );
74                if timeout.as_millis() > 0 {
75                    runtime.shutdown_timeout(timeout);
76                } else {
77                    runtime.shutdown_background();
78                }
79            }
80        }
81    }
82
83    pub fn shutdown_immediate(&self) {
84        self.shutdown(Duration::from_millis(0));
85    }
86
87    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> tokio::task::Id
88    where
89        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
90        Fut: Future<Output = ()> + Send + 'static,
91    {
92        let tag_string = tag.to_string();
93        let shutdown_notify = self.shutdown_notify.clone();
94        let spawned_tasks = self.spawned_tasks.clone();
95
96        log_d!(TAG, "Spawning task {}", tag);
97
98        let handle = self.runtime_handle.spawn(async move {
99            let task_id = tokio::task::id();
100            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
101            task(shutdown_notify).await;
102            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
103        });
104
105        self.insert_join_handle(tag, handle)
106    }
107
108    pub async fn await_tasks_with_tag(&self, tag: &str) {
109        let mut handles = Vec::new();
110
111        match self.spawned_tasks.lock() {
112            Ok(mut lock) => {
113                let keys: Vec<TaskId> = lock.keys().cloned().collect();
114                for key in &keys {
115                    if key.tag == tag {
116                        let removed = if let Some(handle) = lock.remove(key) {
117                            handle
118                        } else {
119                            log_e!(TAG, "No running task found for tag {}", tag);
120                            continue;
121                        };
122
123                        handles.push(removed);
124                    }
125                }
126            }
127            Err(e) => {
128                log_e!(TAG, "Failed to lock spawned tasks {}", e);
129                return;
130            }
131        };
132
133        join_all(handles).await;
134    }
135
136    pub async fn await_join_handle(
137        &self,
138        tag: &str,
139        handle_id: &tokio::task::Id,
140    ) -> Result<(), StatsigErr> {
141        let task_id = TaskId {
142            tag: tag.to_string(),
143            tokio_id: *handle_id,
144        };
145
146        let handle = match self.spawned_tasks.lock() {
147            Ok(mut lock) => match lock.remove(&task_id) {
148                Some(handle) => handle,
149                None => {
150                    return Err(StatsigErr::ThreadFailure(
151                        "No running task found".to_string(),
152                    ));
153                }
154            },
155            Err(e) => {
156                log_e!(
157                    TAG,
158                    "An error occurred while getting join handle with id: {}: {}",
159                    handle_id,
160                    e.to_string()
161                );
162                return Err(StatsigErr::ThreadFailure(e.to_string()));
163            }
164        };
165
166        handle
167            .await
168            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
169
170        Ok(())
171    }
172
173    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
174        let handle_id = handle.id();
175        let task_id = TaskId {
176            tag: tag.to_string(),
177            tokio_id: handle_id,
178        };
179
180        match self.spawned_tasks.lock() {
181            Ok(mut lock) => {
182                lock.insert(task_id, handle);
183            }
184            Err(e) => {
185                log_e!(
186                    TAG,
187                    "An error occurred while inserting join handle: {}",
188                    e.to_string()
189                );
190            }
191        }
192
193        handle_id
194    }
195}
196
197fn remove_join_handle_with_id(
198    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
199    tag: String,
200    handle_id: &tokio::task::Id,
201) {
202    let task_id = TaskId {
203        tag,
204        tokio_id: *handle_id,
205    };
206
207    match spawned_tasks.lock() {
208        Ok(mut lock) => {
209            lock.remove(&task_id);
210        }
211        Err(e) => {
212            log_e!(
213                TAG,
214                "An error occurred while removing join handle {}",
215                e.to_string()
216            );
217        }
218    }
219}
220
221fn create_runtime_if_required() -> (Option<Runtime>, Handle) {
222    if let Ok(handle) = Handle::try_current() {
223        log_d!(TAG, "Existing tokio runtime found");
224        return (None, handle);
225    }
226
227    // todo: remove expects and return error
228    let rt = Builder::new_multi_thread()
229        .worker_threads(5)
230        .thread_name("statsig")
231        .enable_all()
232        .build()
233        .expect("Failed to find or create a tokio Runtime");
234
235    let handle = rt.handle().clone();
236    log_d!(TAG, "New tokio runtime created");
237    (Some(rt), handle)
238}
239
240impl Drop for StatsigRuntime {
241    fn drop(&mut self) {
242        self.shutdown(Duration::from_secs(1));
243
244        log_d!(TAG, "StatsigRuntime dropped");
245    }
246}