1use proc_macro2::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{
4 parse::{Parse, ParseStream},
5 punctuated::Punctuated,
6 token::Comma,
7 FnArg, Ident, ItemTrait, TraitItem, TraitItemMethod,
8};
9
10pub fn shardize_transform(
15 macro_config: MacroConfig,
16 trait_definition: TraitDefinition,
17) -> Result<TokenStream, &'static str> {
18 let new_struct_name = macro_config.new_struct_name;
19 let trait_name = trait_definition.name();
20 let original_trait = trait_definition.to_token_stream();
21 let impl_methods = trait_definition.impl_methods();
22
23 Ok(quote! {
24 #original_trait
25
26 struct #new_struct_name<Impl, const NUM_SHARDS: usize>
27 where
28 Impl: #trait_name
29 {
30 sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
31 shards: [Impl; NUM_SHARDS],
32 }
33
34 impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
35 where
36 Impl: #trait_name,
37 [Impl; NUM_SHARDS]: Default
38 {
39 fn new(
40 sharder: &'static (dyn Fn(usize) -> usize + Send + Sync),
41 ) -> Self {
42 let shards: [Impl; NUM_SHARDS] = Default::default();
43
44 Self {
45 sharder,
46 shards,
47 }
48 }
49 }
50
51 impl <Impl, const NUM_SHARDS: usize> #new_struct_name<Impl, NUM_SHARDS>
52 where
53 Impl: #trait_name,
54 {
55 fn shard_key(&self, key: usize) -> usize {
56 (self.sharder)(key) % NUM_SHARDS
57 }
58 }
59
60 impl<Impl, const NUM_SHARDS: usize> Default for #new_struct_name<Impl, NUM_SHARDS>
61 where
62 Impl: #trait_name
63 {
64 fn default() -> Self {
65 panic!();
66 }
67 }
68
69
70 impl <Impl, const NUM_SHARDS: usize> #trait_name for #new_struct_name<Impl, NUM_SHARDS>
71 where
72 Impl: #trait_name
73 {
74 #impl_methods
75 }
76 })
77}
78
79pub struct MacroConfig {
80 pub new_struct_name: Ident,
81}
82
83impl Parse for MacroConfig {
84 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
85 let new_struct_name = Ident::parse(input)?;
86 Ok(Self { new_struct_name })
87 }
88}
89
/// A parsed `trait` item.
///
/// [`shardize_transform`] uses it to re-emit the trait and to generate the
/// sharded implementation.
pub struct TraitDefinition(ItemTrait);
91
92impl Parse for TraitDefinition {
93 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
94 Ok(Self(ItemTrait::parse(input)?))
95 }
96}
97
98impl TraitDefinition {
99 pub fn name(&self) -> &Ident {
100 &self.0.ident
101 }
102
103 pub fn to_token_stream(&self) -> TokenStream {
104 self.0.to_token_stream()
105 }
106
107 pub fn impl_methods(&self) -> TokenStream {
108 self.0
109 .items
110 .iter()
111 .filter_map(|item| match item {
112 TraitItem::Method(method) => {
113 Some(Self::to_sharded_method(method))
114 }
115 _ => None,
116 })
117 .collect()
118 }
119
120 fn to_sharded_method(method: &TraitItemMethod) -> TokenStream {
121 let signature = method.sig.to_token_stream();
122 let method_name = &method.sig.ident;
123 let args: Punctuated<_, Comma> = method
125 .sig
126 .inputs
127 .iter()
128 .filter_map(|arg| match arg {
129 FnArg::Typed(pat_type) => Some(&pat_type.pat),
130 _ => None,
131 })
132 .collect();
133 quote!(
136 #signature {
137 let k = self.shard_key(key);
138 self.shards[k].#method_name(#args)
139 }
140 )
141 }
142}
143
#[cfg(test)]
mod test {
    use super::*;

    /// A bare identifier becomes the generated struct's name.
    #[test]
    fn macro_config_parse_test() {
        let parsed = syn::parse2::<MacroConfig>(quote!(ShardedHashMap)).unwrap();

        assert_eq!(parsed.new_struct_name, "ShardedHashMap");
    }

    /// An empty trait parses and exposes its identifier.
    #[test]
    fn trait_definition_parse_test() {
        let input = quote!(
            trait MyTrait {}
        );
        let parsed = syn::parse2::<TraitDefinition>(input).unwrap();

        assert_eq!(parsed.name(), "MyTrait");
    }

    /// Each trait method becomes a body that selects a shard from `key` and
    /// forwards every non-receiver argument.
    ///
    /// Only the token output is compared here. With `key: String` the
    /// generated code would not type-check against `shard_key(usize)`.
    #[test]
    fn impl_methods_test() {
        let input = quote!(
            trait MyTrait {
                fn get(&self, key: String);
                fn set(&self, key: String, value: String);
            }
        );
        let expected = quote!(
            fn get(&self, key: String) {
                let k = self.shard_key(key);
                self.shards[k].get(key)
            }

            fn set(&self, key: String, value: String) {
                let k = self.shard_key(key);
                self.shards[k].set(key, value)
            }
        );

        let parsed = syn::parse2::<TraitDefinition>(input).unwrap();

        assert_eq!(parsed.impl_methods().to_string(), expected.to_string());
    }
}