stepflow_base/
object_store.rs

1use std::hash::Hash;
2use std::borrow::{Cow, Borrow};
3use std::collections::{HashMap};
4use std::sync::atomic::{AtomicU16, Ordering};
5use super::IdError;
6
/// Implemented by values that can be kept in an [`ObjectStore`].
///
/// Every value knows its own ID. The store calls [`new_id`](ObjectStoreContent::new_id) to turn
/// its raw `u16` counter into an ID of the content's own ID type.
pub trait ObjectStoreContent {
  /// Key type under which values of this type are stored.
  type IdType;

  /// Borrow the ID this value is stored under.
  fn id(&self) -> &Self::IdType;

  /// Build an [`IdType`](ObjectStoreContent::IdType) from a raw counter value.
  fn new_id(id_val: u16) -> Self::IdType;
}
12
/// A store that owns objects and indexes them by an ID and an optional, unique name.
///
/// There are two different ways to insert an object.
/// - Use [`insert_new`](ObjectStore::insert_new), which takes a closure that receives the ID for the new object
/// - Get an ID with [`reserve_id`](ObjectStore::reserve_id) and then [`register`](ObjectStore::register) the object with that ID
///
/// To add objects with an associated name, use the corresponding
/// [`insert_new_named`](ObjectStore::insert_new_named) and [`register_named`](ObjectStore::register_named)
/// instead.
///
/// # Examples
/// ```
/// # use stepflow_base::{ObjectStore, ObjectStoreContent, IdError, generate_id_type};
/// # generate_id_type!(ObjectId);
/// # struct Object { id: ObjectId }
/// # impl ObjectStoreContent for Object {
/// #   type IdType = ObjectId;
/// #   fn new_id(id_val: u16) -> Self::IdType { ObjectId::new(id_val) }
/// #   fn id(&self) -> &Self::IdType { &self.id }
/// # }
/// // create an ObjectStore with a test object
/// let mut store = ObjectStore::new();
/// let object_id = store.insert_new_named("test object", |id| Ok(Object { id })).unwrap();
///
/// // get the object either by ID or name
/// let object = store.get(&object_id).unwrap();
/// let object = store.get_by_name("test object").unwrap();
/// ```
// The `Eq + Hash` requirement on `TID` lives on the impl blocks that need it rather than on
// the type definition, so the type itself can be named without repeating those bounds.
#[derive(Debug)]
pub struct ObjectStore<T, TID> {
  // objects keyed by their ID; the store owns them
  id_to_object: HashMap<TID, T>,
  // optional names, each mapped to the ID of the object it names
  name_to_id: HashMap<Cow<'static, str>, TID>,
  // next raw value passed to `ObjectStoreContent::new_id` when an ID is reserved
  next_id: AtomicU16,
}
49
50impl<'s, T, TID> ObjectStore<T, TID> 
51    where T:ObjectStoreContent + ObjectStoreContent<IdType = TID>,
52          TID: Eq + Hash + Clone,
53          
54{
55  /// Create a new ObjectStore
56  pub fn new() -> Self {
57    Self::with_capacity(0)
58  }
59
60  /// Create a new ObjectStore with initial capacity
61  pub fn with_capacity(capacity: usize) -> Self {
62    Self {
63      id_to_object: HashMap::with_capacity(capacity),
64      name_to_id: HashMap::with_capacity(capacity),
65      next_id: AtomicU16::new(0)
66    }
67  }
68
69  /// Reserve an ID in the ObjectStore. Generally followed with a call to [`register`](ObjectStore::register) using the ID.
70  pub fn reserve_id(&mut self) -> TID {
71    T::new_id(self.next_id.fetch_add(1, Ordering::SeqCst))
72  }
73
74  /// Registers an object into the ObjectStore
75  pub fn register(&mut self, object: T) -> Result<TID, IdError<TID>> {
76    // check if ID of object being registered already exists
77    if self.id_to_object.contains_key(object.id()) {
78      return Err(IdError::IdAlreadyExists(object.id().clone()))
79    }
80
81    // register the object with ID
82    let object_id = object.id().clone();
83    self.id_to_object.insert(object.id().clone(), object);
84
85    Ok(object_id)
86  }
87
88  /// Registers a named object into the ObjectStore
89  pub fn register_named<STR>(&mut self, name: STR, object: T) -> Result<TID, IdError<TID>> 
90      where STR: Into<Cow<'static, str>>
91  {
92    let name: Cow<'static, str> = name.into();
93  
94    // check if name of object being registered already exists
95    if self.name_to_id.contains_key(&name) {
96      return Err(IdError::NameAlreadyExists(name.clone().into_owned()))
97    }
98
99    // register the object
100    self.register(object)
101      .map(|object_id| {
102        // register the object's name
103        self.name_to_id.insert(name, object_id.clone());
104        object_id
105      })    
106  }
107
108  /// Reserves an ID and registers the object in a single call. The object created must use the ID given to the closure.
109  pub fn insert_new<CB>(&mut self, cb: CB) -> Result<TID, IdError<TID>>
110      where CB: FnOnce(TID) -> Result<T, IdError<TID>>
111  {
112    // reserve an ID
113    let id: TID = self.reserve_id();
114    let id_clone = id.clone();
115
116    // get the object and ensure they used the reserved ID
117    let object = cb(id)?;
118    if *object.id() != id_clone {
119      return Err(IdError::IdNotReserved(object.id().clone()));
120    }
121
122    // register the object
123    self.register(object)
124  }
125
126  /// Reserves an ID and registers the named object in a single call. The object created must use the ID given to the closure.
127  pub fn insert_new_named<CB, STR>(&mut self, name: STR, cb: CB) -> Result<TID, IdError<TID>>
128      where CB: FnOnce(TID) -> Result<T, IdError<TID>>,
129            STR: Into<Cow<'static, str>>
130  {
131    let name: Cow<'static, str> = name.into();
132
133    // reserve an ID
134    let id: TID = self.reserve_id();
135    let id_clone = id.clone();
136
137    // get the object and ensure they used the reserved ID
138    let object = cb(id)?;
139    if *object.id() != id_clone {
140      return Err(IdError::IdNotReserved(object.id().clone()));
141    }
142
143    // register the object
144    self.register_named(name, object)
145  }
146
147  /// Get the Object ID from the name
148  pub fn id_from_name(&self, name: &str) -> Option<&TID> {
149    self.name_to_id.get(name)
150  }
151
152  /// Get the name from the Object ID
153  pub fn name_from_id(&self, id: &TID) -> Option<&str> {
154    self.name_to_id.iter()
155      .find(|(_iter_name, iter_id)| { *iter_id == id })
156      .and_then(|(name, _)| Some(name.borrow()))
157  }
158
159  /// Get an object by its name
160  pub fn get_by_name(&self, name: &str) -> Option<&T> {
161    self.id_from_name(name).and_then(|id| self.get(id))
162  }
163
164  /// Get an object by its ID
165  pub fn get(&self, id: &TID) -> Option<&T> {
166    self.id_to_object.get(id)
167  }
168
169  /// Get a mutable reference to the object
170  pub fn get_mut(&mut self, id: &TID) -> Option<&mut T> {
171    self.id_to_object.get_mut(id)
172  }
173
174  // Iterator for registered object names
175  pub fn iter_names(&self) -> impl Iterator<Item = (&Cow<'static, str>, &TID)> {
176    self.name_to_id.iter()
177  }
178}
179
180
#[cfg(test)]
mod tests {
  use stepflow_test_util::test_id;
  use super::ObjectStore;
  use crate::{test::TestObject, test::TestObjectId, IdError};

  #[test]
  fn basic() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let t1 = test_store.insert_new(|id| Ok(TestObject::new(id, 100))).unwrap();
    let t2 = test_store.insert_new(|id| Ok(TestObject::new(id, 200))).unwrap();
    assert_ne!(t1, t2);

    // don't allow dupe
    let t1_dupe = TestObject::new(t1.clone(), 3);
    let dupe_result = test_store.register(t1_dupe);
    assert_eq!(dupe_result, Err(IdError::IdAlreadyExists(t1.clone())));

    // don't allow custom ids
    let testid_bad = TestObjectId::new(1000);
    let t_custom = test_store.insert_new(|_id| Ok(TestObject::new(testid_bad.clone(), 10)));
    assert_eq!(t_custom, Err(IdError::IdNotReserved(testid_bad)));

    // check values
    assert_eq!(test_store.get(&t1).unwrap().val(), 100);
    assert_eq!(test_store.get(&TestObjectId::new(999)), None);

    // callback failure
    assert_eq!(test_store.insert_new(|_id| Err(IdError::CannotParse("hi".to_owned()))), Err(IdError::CannotParse("hi".to_owned())));
  }

  #[test]
  fn register() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let id1 = test_id!(TestObjectId);
    let id2 = test_id!(TestObjectId);
    test_store.register(TestObject::new(id1, 100)).unwrap();
    test_store.register(TestObject::new(id2, 100)).unwrap();
    assert_eq!(test_store.register(TestObject::new(id1, 100)), Err(IdError::IdAlreadyExists(id1)));
  }

  #[test]
  fn register_named() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let id1 = test_id!(TestObjectId);
    let id2 = test_id!(TestObjectId);
    test_store.register_named("first", TestObject::new(id1, 1)).unwrap();

    // duplicate name is rejected even with a fresh ID
    assert_eq!(
      test_store.register_named("first", TestObject::new(id2, 2)),
      Err(IdError::NameAlreadyExists("first".to_owned()))
    );

    // duplicate ID under a fresh name is rejected, and the name is not recorded
    assert_eq!(
      test_store.register_named("second", TestObject::new(id1, 3)),
      Err(IdError::IdAlreadyExists(id1))
    );
    assert!(test_store.id_from_name("second").is_none());
  }

  #[test]
  fn names() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let t1 = test_store.insert_new_named("t1", |id| Ok(TestObject::new(id, 100))).unwrap();
    let _t2 = test_store.insert_new_named("t2".to_owned(), |id| Ok(TestObject::new(id, 200))).unwrap();

    // don't allow register dupe name
    let t1_dupe = test_store.insert_new_named("t1", |id| Ok(TestObject::new(id, 150)));
    assert_eq!(t1_dupe, Err(IdError::NameAlreadyExists("t1".to_owned())));

    // check values
    assert_eq!(test_store.id_from_name("t1").unwrap().val(), t1.val());
    assert_eq!(test_store.get_by_name("t1").unwrap().val(), 100);
    assert_eq!(test_store.get_by_name("BAD"), None);
  }

  #[test]
  fn name_lookup() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let t1 = test_store.insert_new_named("t1", |id| Ok(TestObject::new(id, 100))).unwrap();
    let unnamed = test_store.insert_new(|id| Ok(TestObject::new(id, 200))).unwrap();

    // reverse lookup finds the name only for objects registered with one
    assert_eq!(test_store.name_from_id(&t1), Some("t1"));
    assert_eq!(test_store.name_from_id(&unnamed), None);
    assert_eq!(test_store.iter_names().count(), 1);
  }

  #[test]
  fn get() {
    let mut test_store: ObjectStore<TestObject, TestObjectId> = ObjectStore::new();
    let t1 = test_store.insert_new_named("t1", |id| Ok(TestObject::new(id, 100))).unwrap();
    let _t2 = test_store.insert_new_named("t2", |id| Ok(TestObject::new(id, 200))).unwrap();

    test_store.get_mut(&t1).unwrap().set_val(5);
    assert_eq!(test_store.get(&t1).unwrap().val(), 5);
  }
}