1use crate::store_utils::{DEFAULT_TIMEOUT, get_with_timeout, put_with_timeout};
11use anyhow::Result;
12use bytes::Bytes;
13use object_store::path::Path;
14use object_store::{ObjectStore, PutMode, PutOptions, UpdateVersion};
15use serde::{Deserialize, Serialize};
16use std::sync::Arc;
17use tokio::sync::Mutex;
18use uni_common::core::id::{Eid, Vid};
19
20#[derive(Serialize, Deserialize, Default, Clone)]
22struct CounterManifest {
23 next_vid_batch: u64,
25 next_eid_batch: u64,
27}
28
29struct AllocatorState {
31 manifest: CounterManifest,
32 manifest_version: Option<String>, current_vid: u64,
34 current_eid: u64,
35}
36
37pub struct IdAllocator {
44 store: Arc<dyn ObjectStore>,
45 path: Path,
46 state: Mutex<AllocatorState>,
47 batch_size: u64,
48}
49
50impl IdAllocator {
51 pub async fn new(store: Arc<dyn ObjectStore>, path: Path, batch_size: u64) -> Result<Self> {
53 let (manifest, version) = match get_with_timeout(&store, &path, DEFAULT_TIMEOUT).await {
54 Ok(get_result) => {
55 let version = get_result.meta.e_tag.clone();
56 let bytes = get_result.bytes().await?;
57 let manifest: CounterManifest = serde_json::from_slice(&bytes)?;
58 (manifest, version)
59 }
60 Err(e) if e.to_string().contains("not found") => (CounterManifest::default(), None),
61 Err(e) => return Err(e),
62 };
63
64 let current_vid = manifest.next_vid_batch;
66 let current_eid = manifest.next_eid_batch;
67
68 Ok(Self {
69 store,
70 path,
71 state: Mutex::new(AllocatorState {
72 manifest,
73 manifest_version: version,
74 current_vid,
75 current_eid,
76 }),
77 batch_size,
78 })
79 }
80
81 pub async fn allocate_vid(&self) -> Result<Vid> {
85 let mut state = self.state.lock().await;
86
87 if state.current_vid >= state.manifest.next_vid_batch {
89 state.manifest.next_vid_batch = state.current_vid + self.batch_size;
91 self.persist_manifest(&mut state).await?;
92 }
93
94 let vid = Vid::new(state.current_vid);
95 state.current_vid += 1;
96 Ok(vid)
97 }
98
99 pub async fn allocate_vids(&self, count: usize) -> Result<Vec<Vid>> {
101 let mut state = self.state.lock().await;
102 let needed = count as u64;
103
104 if state.current_vid + needed > state.manifest.next_vid_batch {
106 state.manifest.next_vid_batch = state.current_vid + needed + self.batch_size;
108 self.persist_manifest(&mut state).await?;
109 }
110
111 let vids: Vec<Vid> = (0..count)
112 .map(|i| Vid::new(state.current_vid + i as u64))
113 .collect();
114 state.current_vid += needed;
115 Ok(vids)
116 }
117
118 pub async fn allocate_eid(&self) -> Result<Eid> {
122 let mut state = self.state.lock().await;
123
124 if state.current_eid >= state.manifest.next_eid_batch {
126 state.manifest.next_eid_batch = state.current_eid + self.batch_size;
128 self.persist_manifest(&mut state).await?;
129 }
130
131 let eid = Eid::new(state.current_eid);
132 state.current_eid += 1;
133 Ok(eid)
134 }
135
136 pub async fn allocate_eids(&self, count: usize) -> Result<Vec<Eid>> {
138 let mut state = self.state.lock().await;
139 let needed = count as u64;
140
141 if state.current_eid + needed > state.manifest.next_eid_batch {
143 state.manifest.next_eid_batch = state.current_eid + needed + self.batch_size;
145 self.persist_manifest(&mut state).await?;
146 }
147
148 let eids: Vec<Eid> = (0..count)
149 .map(|i| Eid::new(state.current_eid + i as u64))
150 .collect();
151 state.current_eid += needed;
152 Ok(eids)
153 }
154
155 pub async fn current_vid(&self) -> u64 {
157 self.state.lock().await.current_vid
158 }
159
160 pub async fn current_eid(&self) -> u64 {
162 self.state.lock().await.current_eid
163 }
164
165 pub async fn current_hwm(&self) -> (u64, u64) {
174 let state = self.state.lock().await;
175 (state.current_vid, state.current_eid)
176 }
177
178 pub async fn checkpoint(&self) -> Result<()> {
196 let mut state = self.state.lock().await;
197 if state.manifest.next_vid_batch < state.current_vid {
200 state.manifest.next_vid_batch = state.current_vid;
201 }
202 if state.manifest.next_eid_batch < state.current_eid {
203 state.manifest.next_eid_batch = state.current_eid;
204 }
205 self.persist_manifest(&mut state).await
206 }
207
208 async fn persist_manifest(&self, state: &mut AllocatorState) -> Result<()> {
210 let json = serde_json::to_vec_pretty(&state.manifest)?;
211 let bytes = Bytes::from(json);
212
213 let put_result = if let Some(version) = &state.manifest_version {
216 let opts: PutOptions = PutMode::Update(UpdateVersion {
217 e_tag: Some(version.clone()),
218 version: None,
219 })
220 .into();
221 match tokio::time::timeout(
222 DEFAULT_TIMEOUT,
223 self.store.put_opts(&self.path, bytes.clone().into(), opts),
224 )
225 .await
226 {
227 Ok(Ok(result)) => result,
228 Ok(Err(e))
229 if e.to_string().contains("not yet implemented")
230 || e.to_string().contains("not supported") =>
231 {
232 put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
234 }
235 Ok(Err(e)) => return Err(e.into()),
236 Err(_) => {
237 return Err(anyhow::anyhow!(
238 "Object store put_opts timed out after {:?}",
239 DEFAULT_TIMEOUT
240 ));
241 }
242 }
243 } else {
244 let opts: PutOptions = PutMode::Create.into();
246 match tokio::time::timeout(
247 DEFAULT_TIMEOUT,
248 self.store.put_opts(&self.path, bytes.clone().into(), opts),
249 )
250 .await
251 {
252 Ok(Ok(result)) => result,
253 Ok(Err(object_store::Error::AlreadyExists { .. })) => {
254 put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
256 }
257 Ok(Err(e)) if e.to_string().contains("not yet implemented") => {
258 put_with_timeout(&self.store, &self.path, bytes, DEFAULT_TIMEOUT).await?
259 }
260 Ok(Err(e)) => return Err(e.into()),
261 Err(_) => {
262 return Err(anyhow::anyhow!(
263 "Object store put_opts timed out after {:?}",
264 DEFAULT_TIMEOUT
265 ));
266 }
267 }
268 };
269
270 state.manifest_version = put_result.e_tag;
271 Ok(())
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::memory::InMemory;

    /// Creates an empty in-memory object store.
    fn empty_store() -> Arc<InMemory> {
        Arc::new(InMemory::new())
    }

    /// Opens an allocator on `store` using the manifest path shared by all tests.
    async fn open(store: Arc<InMemory>, batch_size: u64) -> IdAllocator {
        IdAllocator::new(store, Path::from("id_counters.json"), batch_size)
            .await
            .unwrap()
    }

    /// Single `Vid` allocations start at 0 and increase by one.
    #[tokio::test]
    async fn test_allocate_vid() {
        let allocator = open(empty_store(), 100).await;

        let mut seen = Vec::new();
        for _ in 0..3 {
            seen.push(allocator.allocate_vid().await.unwrap().as_u64());
        }

        assert_eq!(seen[0], 0);
        assert_eq!(seen[1], 1);
        assert_eq!(seen[2], 2);
    }

    /// Single `Eid` allocations start at 0 and increase by one.
    #[tokio::test]
    async fn test_allocate_eid() {
        let allocator = open(empty_store(), 100).await;

        let first = allocator.allocate_eid().await.unwrap();
        let second = allocator.allocate_eid().await.unwrap();

        assert_eq!(first.as_u64(), 0);
        assert_eq!(second.as_u64(), 1);
    }

    /// A bulk allocation yields consecutive IDs, and the next single
    /// allocation continues right after them.
    #[tokio::test]
    async fn test_allocate_many() {
        let allocator = open(empty_store(), 100).await;

        let vids = allocator.allocate_vids(5).await.unwrap();
        assert_eq!(vids.len(), 5);
        let mut expected = 0u64;
        for vid in &vids {
            assert_eq!(vid.as_u64(), expected);
            expected += 1;
        }

        let following = allocator.allocate_vid().await.unwrap();
        assert_eq!(following.as_u64(), 5);
    }

    /// A reopened allocator resumes at the persisted batch limit.
    #[tokio::test]
    async fn test_persistence() {
        let store = empty_store();

        {
            // 15 allocations with batch size 10 reserve two batches: [0, 10) and [10, 20).
            let first_run = open(Arc::clone(&store), 10).await;
            for _ in 0..15 {
                first_run.allocate_vid().await.unwrap();
            }
        }

        {
            // IDs 15..20 were reserved but never handed out; they are skipped.
            let second_run = open(store, 10).await;
            let vid = second_run.allocate_vid().await.unwrap();
            assert_eq!(vid.as_u64(), 20);
        }
    }
}