sqltk/node_key.rs
1use std::{any::TypeId, marker::PhantomData};
2
/// Acts as a type-erased proxy for a borrowed value of any `'static` type.
///
/// The primary purpose is to be used as a key in a [`std::collections::HashMap`] where the key of
/// the map is derived from any type of AST node, effectively making the hashmap heterogeneous over
/// the key type.
///
/// Additionally, the [`NodeKey::get_as`] method can *safely* recover a reference to the proxied
/// node, provided the requested type is exactly the type the key was created from.
///
/// `NodeKey` works by capturing the address of an AST node in addition to its [`TypeId`]. Both are
/// required to identify a node because different node values can have the same address; for
/// example the address of a struct and the address of its first field may be equal but the struct
/// and its first field are different types. This is not an absolute guarantee of uniqueness:
/// distinct values of the same zero-sized type may share an address and would then produce equal
/// keys.
///
/// The `'ast` lifetime is the lifetime of the borrow the key was created from, which prevents a key
/// from outliving the node it refers to.
///
/// Keys are normally obtained through an [`AsNodeKey`] impl, but can also be built directly with
/// [`NodeKey::new`].
///
/// The derived `Ord` compares the node address first and the [`TypeId`] second, so the ordering
/// reflects memory layout rather than any structural property of the AST.
#[derive(Debug, Hash, Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
pub struct NodeKey<'ast> {
    // Address of the referenced node, stored as a plain integer.
    node_addr: usize,
    // Concrete type of the referenced node; `get_as` checks it before dereferencing `node_addr`.
    node_type: TypeId,
    // Ties the key to the borrow of the node it was created from.
    _ast: PhantomData<&'ast ()>,
}
21
22pub trait AsNodeKey
23where
24 Self: 'static,
25{
26 fn as_node_key(&self) -> NodeKey<'_>;
27}
28
29impl<N: AsNodeKey> AsNodeKey for Box<N> {
30 fn as_node_key(&self) -> NodeKey<'_> {
31 (**self).as_node_key()
32 }
33}
34
35impl<N: 'static> AsNodeKey for Option<N> {
36 fn as_node_key(&self) -> NodeKey<'_> {
37 NodeKey::new(self)
38 }
39}
40
41impl<N> AsNodeKey for Vec<N>
42where
43 N: AsNodeKey,
44{
45 fn as_node_key(&self) -> NodeKey<'_> {
46 NodeKey::new(self)
47 }
48}
49
50impl<'ast> NodeKey<'ast> {
51 pub fn new<N: 'static>(node: &'ast N) -> Self {
52 Self {
53 node_addr: node as *const N as usize,
54 node_type: TypeId::of::<N>(),
55 _ast: PhantomData,
56 }
57 }
58
59 pub fn get_as<N: 'static>(&self) -> Option<&'ast N> {
60 if self.node_type == TypeId::of::<N>() {
61 // SAFETY: we have verified that `N` is of the correct type to permit the cast and because `'ast` outlives
62 // `self` we know that the node has not been dropped.
63 unsafe { (self.node_addr as *const N).as_ref() }
64 } else {
65 None
66 }
67 }
68}