1use crate::prelude::{
26    collections::btree_map::{BTreeMap, Entry},
27    marker::PhantomData,
28    vec::Vec,
29};
30
31#[cfg(feature = "serde")]
32use serde::{Deserialize, Serialize};
33
34#[cfg(feature = "schema")]
35use schemars::JsonSchema;
36
37#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, scale::Encode, scale::Decode)]
42#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
43#[cfg_attr(feature = "serde", serde(transparent))]
44pub struct UntrackedSymbol<T> {
45    #[codec(compact)]
47    pub id: u32,
48    #[cfg_attr(feature = "serde", serde(skip))]
49    marker: PhantomData<fn() -> T>,
50}
51
52impl<T> UntrackedSymbol<T> {
53    #[deprecated(
55        since = "2.5.0",
56        note = "Prefer to access the fields directly; this getter will be removed in the next major version"
57    )]
58    pub fn id(&self) -> u32 {
59        self.id
60    }
61}
62
63impl<T> From<u32> for UntrackedSymbol<T> {
64    fn from(id: u32) -> Self {
65        Self {
66            id,
67            marker: Default::default(),
68        }
69    }
70}
71
#[cfg(feature = "schema")]
impl<T> JsonSchema for UntrackedSymbol<T> {
    /// Uses the same schema name for every `T`.
    fn schema_name() -> String {
        "UntrackedSymbol".into()
    }

    /// Describes the symbol exactly like its `u32` id, matching the
    /// `serde(transparent)` representation declared on the struct.
    fn json_schema(generator: &mut schemars::gen::SchemaGenerator) -> schemars::schema::Schema {
        generator.subschema_for::<u32>()
    }
}
82
/// A lifetime-tracked handle to a value interned in an [`Interner`].
///
/// The lifetime `'a` comes from the interner borrow the symbol was obtained
/// through; [`Symbol::into_untracked`] drops it. With the `serde` feature a
/// symbol serializes as its bare `u32` id.
#[derive(Debug, Copy, Clone)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub struct Symbol<'a, T: 'a> {
    /// Index of the referenced value within the interner's elements.
    id: u32,
    /// Carries `'a` and `T` in the type without storing a reference or a `T`.
    #[cfg_attr(feature = "serde", serde(skip))]
    marker: PhantomData<fn() -> &'a T>,
}
95
96impl<T> Symbol<'_, T> {
97    pub fn into_untracked(self) -> UntrackedSymbol<T> {
113        UntrackedSymbol {
114            id: self.id,
115            marker: PhantomData,
116        }
117    }
118}
119
/// Deduplicating store that gives each distinct value a sequential id.
///
/// Values live in `vec` in insertion order, and a value's id is its index
/// there. `map` finds an already-interned value without a linear scan. With
/// the `serde` feature the interner serializes as just its list of elements.
#[derive(Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize), serde(transparent))]
#[cfg_attr(feature = "schema", derive(JsonSchema))]
pub struct Interner<T> {
    /// Maps each interned value to its index in `vec`.
    ///
    /// Skipped by serde since it is redundant with `vec`.
    #[cfg_attr(feature = "serde", serde(skip))]
    map: BTreeMap<T, usize>,
    /// Interned values, indexed by symbol id.
    vec: Vec<T>,
}
148
149impl<T> Interner<T>
150where
151    T: Ord,
152{
153    pub fn new() -> Self {
155        Self {
156            map: BTreeMap::new(),
157            vec: Vec::new(),
158        }
159    }
160}
161
162impl<T: Ord> Default for Interner<T> {
163    fn default() -> Self {
164        Self::new()
165    }
166}
167
168impl<T> Interner<T>
169where
170    T: Ord + Clone,
171{
172    pub fn intern_or_get(&mut self, s: T) -> (bool, Symbol<T>) {
175        let next_id = self.vec.len();
176        let (inserted, sym_id) = match self.map.entry(s.clone()) {
177            Entry::Vacant(vacant) => {
178                vacant.insert(next_id);
179                self.vec.push(s);
180                (true, next_id)
181            }
182            Entry::Occupied(occupied) => (false, *occupied.get()),
183        };
184        (
185            inserted,
186            Symbol {
187                id: sym_id as u32,
188                marker: PhantomData,
189            },
190        )
191    }
192
193    pub fn get(&self, sym: &T) -> Option<Symbol<T>> {
196        self.map.get(sym).map(|&id| Symbol {
197            id: id as u32,
198            marker: PhantomData,
199        })
200    }
201
202    pub fn resolve(&self, sym: Symbol<T>) -> Option<&T> {
205        let idx = sym.id as usize;
206        if idx >= self.vec.len() {
207            return None;
208        }
209        self.vec.get(idx)
210    }
211
212    pub fn elements(&self) -> &[T] {
214        &self.vec
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;

    type StringInterner = Interner<&'static str>;

    /// Interns `new_symbol` and checks it received `expected_id`.
    fn assert_id(interner: &mut StringInterner, new_symbol: &'static str, expected_id: u32) {
        let actual_id = interner.intern_or_get(new_symbol).1.id;
        assert_eq!(actual_id, expected_id);
    }

    /// Resolves a symbol built from `symbol_id` and checks the result.
    fn assert_resolve<E>(interner: &StringInterner, symbol_id: u32, expected_str: E)
    where
        E: Into<Option<&'static str>>,
    {
        let actual_str = interner.resolve(Symbol {
            id: symbol_id,
            marker: PhantomData,
        });
        assert_eq!(actual_str.cloned(), expected_str.into());
    }

    /// Interns `s` and returns the `inserted` flag with the raw id. Extracting
    /// the id ends the mutable borrow held by the returned `Symbol`.
    fn intern(interner: &mut StringInterner, s: &'static str) -> (bool, u32) {
        let (inserted, symbol) = interner.intern_or_get(s);
        (inserted, symbol.id)
    }

    #[test]
    fn simple() {
        let mut interner = StringInterner::new();
        assert_id(&mut interner, "Hello", 0);
        assert_id(&mut interner, ", World!", 1);
        assert_id(&mut interner, "1 2 3", 2);
        assert_id(&mut interner, "Hello", 0);

        assert_resolve(&interner, 0, "Hello");
        assert_resolve(&interner, 1, ", World!");
        assert_resolve(&interner, 2, "1 2 3");
        assert_resolve(&interner, 3, None);
    }

    #[test]
    fn reports_whether_value_was_inserted() {
        let mut interner = StringInterner::new();
        assert_eq!(intern(&mut interner, "a"), (true, 0));
        assert_eq!(intern(&mut interner, "b"), (true, 1));
        assert_eq!(intern(&mut interner, "a"), (false, 0));
        // Re-interning must not duplicate the stored value.
        assert_eq!(interner.elements(), &["a", "b"][..]);
    }

    #[test]
    fn get_does_not_insert() {
        let mut interner = StringInterner::new();
        assert!(interner.get(&"a").is_none());
        assert_id(&mut interner, "a", 0);
        assert_eq!(interner.get(&"a").map(|symbol| symbol.id), Some(0));
        assert!(interner.get(&"missing").is_none());
        assert_eq!(interner.elements(), &["a"][..]);
    }

    #[test]
    fn default_is_empty() {
        let interner = StringInterner::default();
        assert!(interner.elements().is_empty());
        assert_resolve(&interner, 0, None);
    }
}