shardize_core/
lib.rs

1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4    parse::{Parse, ParseStream},
5    punctuated::Punctuated,
6    token::Comma,
7    FnArg, Ident, ItemTrait, TraitItem, TraitItemMethod,
8};
9
10// pub trait Shardable: Default {
11//     // fn shard_key();
12// }
13
14pub fn shardize_transform(
15    macro_config: MacroConfig,
16    trait_definition: TraitDefinition,
17) -> Result<TokenStream, &'static str> {
18    let new_struct_name = macro_config.new_struct_name;
19    let trait_name = trait_definition.name();
20    let original_trait = trait_definition.to_token_stream();
21    let impl_methods = trait_definition.impl_methods();
22
23    Ok(quote! {
24        #original_trait
25
26        struct #new_struct_name<Impl, const NUM_SHARDS: usize>
27        where
28            Impl: #trait_name
29        {
30            sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
31            shards: [Impl; NUM_SHARDS],
32        }
33
34        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
35        where
36            Impl: #trait_name,
37            [Impl; NUM_SHARDS]: Default
38        {
39            fn new(
40                    sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
41            ) -> Self {
42                let shards: [Impl; NUM_SHARDS] = Default::default();
43
44                Self {
45                    sharder,
46                    shards,
47                }
48            }
49        }
50
51        impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
52        where
53            Impl: #trait_name,
54        {
55            fn shard_key(&self, key: usize) -> usize {
56                (self.sharder)(key) % NUM_SHARDS
57            }
58        }
59
60        impl<Impl, const NUM_SHARDS: usize> Default for #new_struct_name<Impl, NUM_SHARDS>
61        where
62            Impl: #trait_name
63        {
64            fn default() -> Self {
65                panic!();
66            }
67        }
68
69
70        impl <Impl, const NUM_SHARDS: usize> #trait_name for #new_struct_name<Impl, NUM_SHARDS>
71        where
72            Impl: #trait_name
73        {
74            #impl_methods
75        }
76    })
77}
78
79pub struct MacroConfig {
80    pub new_struct_name: Ident,
81}
82
83impl Parse for MacroConfig {
84    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
85        let new_struct_name = Ident::parse(input)?;
86        Ok(Self { new_struct_name })
87    }
88}
89
90pub struct TraitDefinition(ItemTrait);
91
92impl Parse for TraitDefinition {
93    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
94        Ok(Self(ItemTrait::parse(input)?))
95    }
96}
97
98impl TraitDefinition {
99    pub fn name(&self) -> &Ident {
100        &self.0.ident
101    }
102
103    pub fn to_token_stream(&self) -> TokenStream {
104        self.0.to_token_stream()
105    }
106
107    pub fn impl_methods(&self) -> TokenStream {
108        self.0
109            .items
110            .iter()
111            .filter_map(|item| match item {
112                TraitItem::Method(method) => {
113                    Some(Self::to_sharded_method(method))
114                }
115                _ => None,
116            })
117            .collect()
118    }
119
120    fn to_sharded_method(method: &TraitItemMethod) -> TokenStream {
121        let signature = method.sig.to_token_stream();
122        let method_name = &method.sig.ident;
123        // TODO ensure it takes self
124        let args: Punctuated<_, Comma> = method
125            .sig
126            .inputs
127            .iter()
128            .filter_map(|arg| match arg {
129                FnArg::Typed(pat_type) => Some(&pat_type.pat),
130                _ => None,
131            })
132            .collect();
133        // TODO ignore associated functions
134
135        quote!(
136            #signature {
137                let k = self.shard_key(key);
138                self.shards[k].#method_name(#args)
139            }
140        )
141    }
142}
143
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `tokens` as a trait definition, panicking if parsing fails.
    fn parse_trait(tokens: TokenStream) -> TraitDefinition {
        syn::parse2(tokens).unwrap()
    }

    /// A bare identifier is accepted as the macro configuration.
    #[test]
    fn macro_config_parse_test() {
        let parsed = syn::parse2::<MacroConfig>(quote!(ShardedHashMap)).unwrap();

        assert_eq!(parsed.new_struct_name, "ShardedHashMap");
    }

    /// The trait's identifier is exposed through `name()`.
    #[test]
    fn trait_definition_parse_test() {
        let parsed = parse_trait(quote!(
            trait MyTrait {}
        ));

        assert_eq!(parsed.name(), "MyTrait");
    }

    /// Every declared method gets a body that routes the call to a shard.
    #[test]
    fn impl_methods_test() {
        let parsed = parse_trait(quote!(
            trait MyTrait {
                fn get(&self, key: String);
                fn set(&self, key: String, value: String);
            }
        ));

        let expected = quote!(
            fn get(&self, key: String) {
                let k = self.shard_key(key);
                self.shards[k].get(key)
            }

            fn set(&self, key: String, value: String) {
                let k = self.shard_key(key);
                self.shards[k].set(key, value)
            }
        )
        .to_string();
        let actual = parsed.impl_methods().to_string();

        assert_eq!(actual, expected);
    }
}