statsig_rust/
statsig_runtime.rs1use futures::future::join_all;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5use std::sync::Mutex;
6use std::time::Duration;
7use tokio::runtime::{Builder, Handle, Runtime};
8use tokio::sync::Notify;
9use tokio::task::JoinHandle;
10
11use crate::log_d;
12use crate::log_e;
13use crate::StatsigErr;
14
15const TAG: &str = stringify!(StatsigRuntime);
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19 tag: String,
20 tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24 pub runtime_handle: Handle,
25 inner_runtime: Mutex<Option<Runtime>>,
26 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
27 shutdown_notify: Arc<Notify>,
28}
29
30impl StatsigRuntime {
31 #[must_use]
32 pub fn get_runtime() -> Arc<StatsigRuntime> {
33 let (opt_runtime, runtime_handle) = create_runtime_if_required();
34
35 let shutdown_notify = Notify::new();
36 Arc::new(StatsigRuntime {
37 inner_runtime: Mutex::new(opt_runtime),
38 runtime_handle,
39 spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
40 shutdown_notify: Arc::new(shutdown_notify),
41 })
42 }
43
44 pub fn get_handle(&self) -> Handle {
45 self.runtime_handle.clone()
46 }
47
48 pub fn get_num_active_tasks(&self) -> usize {
49 match self.spawned_tasks.lock() {
50 Ok(lock) => lock.len(),
51 Err(e) => {
52 log_e!(TAG, "Failed to lock spawned tasks {}", e);
53 0
54 }
55 }
56 }
57
58 pub fn shutdown(&self, timeout: Duration) {
59 self.shutdown_notify.notify_waiters();
60
61 if let Ok(mut lock) = self.spawned_tasks.lock() {
62 for (_, task) in lock.drain() {
63 task.abort();
64 }
65 }
66
67 if let Ok(mut lock) = self.inner_runtime.lock() {
68 if let Some(runtime) = lock.take() {
69 log_d!(
70 TAG,
71 "Shutting down Statsig runtime with timeout: {:?}",
72 timeout
73 );
74 if timeout.as_millis() > 0 {
75 runtime.shutdown_timeout(timeout);
76 } else {
77 runtime.shutdown_background();
78 }
79 }
80 }
81 }
82
83 pub fn shutdown_immediate(&self) {
84 self.shutdown(Duration::from_millis(0));
85 }
86
87 pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> tokio::task::Id
88 where
89 F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
90 Fut: Future<Output = ()> + Send + 'static,
91 {
92 let tag_string = tag.to_string();
93 let shutdown_notify = self.shutdown_notify.clone();
94 let spawned_tasks = self.spawned_tasks.clone();
95
96 log_d!(TAG, "Spawning task {}", tag);
97
98 let handle = self.runtime_handle.spawn(async move {
99 let task_id = tokio::task::id();
100 log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
101 task(shutdown_notify).await;
102 remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
103 });
104
105 self.insert_join_handle(tag, handle)
106 }
107
108 pub async fn await_tasks_with_tag(&self, tag: &str) {
109 let mut handles = Vec::new();
110
111 match self.spawned_tasks.lock() {
112 Ok(mut lock) => {
113 let keys: Vec<TaskId> = lock.keys().cloned().collect();
114 for key in &keys {
115 if key.tag == tag {
116 let removed = if let Some(handle) = lock.remove(key) {
117 handle
118 } else {
119 log_e!(TAG, "No running task found for tag {}", tag);
120 continue;
121 };
122
123 handles.push(removed);
124 }
125 }
126 }
127 Err(e) => {
128 log_e!(TAG, "Failed to lock spawned tasks {}", e);
129 return;
130 }
131 };
132
133 join_all(handles).await;
134 }
135
136 pub async fn await_join_handle(
137 &self,
138 tag: &str,
139 handle_id: &tokio::task::Id,
140 ) -> Result<(), StatsigErr> {
141 let task_id = TaskId {
142 tag: tag.to_string(),
143 tokio_id: *handle_id,
144 };
145
146 let handle = match self.spawned_tasks.lock() {
147 Ok(mut lock) => match lock.remove(&task_id) {
148 Some(handle) => handle,
149 None => {
150 return Err(StatsigErr::ThreadFailure(
151 "No running task found".to_string(),
152 ));
153 }
154 },
155 Err(e) => {
156 log_e!(
157 TAG,
158 "An error occurred while getting join handle with id: {}: {}",
159 handle_id,
160 e.to_string()
161 );
162 return Err(StatsigErr::ThreadFailure(e.to_string()));
163 }
164 };
165
166 handle
167 .await
168 .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
169
170 Ok(())
171 }
172
173 fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
174 let handle_id = handle.id();
175 let task_id = TaskId {
176 tag: tag.to_string(),
177 tokio_id: handle_id,
178 };
179
180 match self.spawned_tasks.lock() {
181 Ok(mut lock) => {
182 lock.insert(task_id, handle);
183 }
184 Err(e) => {
185 log_e!(
186 TAG,
187 "An error occurred while inserting join handle: {}",
188 e.to_string()
189 );
190 }
191 }
192
193 handle_id
194 }
195}
196
197fn remove_join_handle_with_id(
198 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
199 tag: String,
200 handle_id: &tokio::task::Id,
201) {
202 let task_id = TaskId {
203 tag,
204 tokio_id: *handle_id,
205 };
206
207 match spawned_tasks.lock() {
208 Ok(mut lock) => {
209 lock.remove(&task_id);
210 }
211 Err(e) => {
212 log_e!(
213 TAG,
214 "An error occurred while removing join handle {}",
215 e.to_string()
216 );
217 }
218 }
219}
220
221fn create_runtime_if_required() -> (Option<Runtime>, Handle) {
222 if let Ok(handle) = Handle::try_current() {
223 log_d!(TAG, "Existing tokio runtime found");
224 return (None, handle);
225 }
226
227 let rt = Builder::new_multi_thread()
229 .worker_threads(5)
230 .thread_name("statsig")
231 .enable_all()
232 .build()
233 .expect("Failed to find or create a tokio Runtime");
234
235 let handle = rt.handle().clone();
236 log_d!(TAG, "New tokio runtime created");
237 (Some(rt), handle)
238}
239
240impl Drop for StatsigRuntime {
241 fn drop(&mut self) {
242 self.shutdown(Duration::from_secs(1));
243
244 log_d!(TAG, "StatsigRuntime dropped");
245 }
246}