wstr_literal_impl/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::iter::once;
3
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7    bracketed, parse::Parse, Attribute, Ident, LitInt, LitStr, StaticMutability, Token, Visibility,
8};
9
10pub fn wstr_impl(input: TokenStream) -> syn::Result<TokenStream> {
11    let WstrArgs { arr_len, input_str } = syn::parse2(input)?;
12
13    let mut v: Vec<_> = input_str.value().encode_utf16().chain(once(0)).collect();
14
15    if let Some(arr_len) = arr_len {
16        let sz: usize = arr_len.base10_parse()?;
17
18        if sz < v.len() {
19            return Err(syn::Error::new_spanned(
20                arr_len,
21                "array size must be at least length of input string plus null terminator",
22            ));
23        }
24
25        v.resize(sz, 0);
26    }
27
28    Ok(quote! {
29        [
30            #(#v),*
31        ]
32    })
33}
34
/// Arguments accepted by [`wstr_impl`]: an optional array length followed by
/// the string literal to encode.
struct WstrArgs {
    /// Requested array length; `None` when the caller passed only the string.
    arr_len: Option<LitInt>,
    /// String literal that gets encoded as UTF-16.
    input_str: LitStr,
}
39
40impl Parse for WstrArgs {
41    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42        let arr_len = if input.peek(LitInt) {
43            Some(input.parse()?)
44        } else {
45            None
46        };
47        if arr_len.is_some() {
48            let _: Token![,] = input.parse()?;
49        }
50        let input_str: LitStr = input.parse()?;
51        Ok(Self { arr_len, input_str })
52    }
53}
54
55fn parse_array_size(input: syn::parse::ParseStream) -> syn::Result<Option<LitInt>> {
56    let lookahead = input.lookahead1();
57    if lookahead.peek(LitInt) {
58        Ok(Some(input.parse()?))
59    } else if lookahead.peek(Token![_]) {
60        let _ = input.parse::<Token![_]>()?;
61        Ok(None)
62    } else {
63        Err(lookahead.error())
64    }
65}
66
/// An array type written as `[elem; size]` or `[elem; _]`.
struct WstrTypeArray {
    /// Identifier in the element-type position of the array type.
    elem: Ident,
    /// Explicit length literal; `None` when the `_` placeholder was used.
    size: Option<LitInt>,
}
71
72impl Parse for WstrTypeArray {
73    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
74        let content;
75
76        let _ = bracketed!(content in input);
77        let elem: Ident = content.parse()?;
78        let _ = content.parse::<Token![;]>()?;
79        let size = parse_array_size(&content)?;
80
81        Ok(Self { elem, size })
82    }
83}
84
/// The item keyword introducing a declaration: either `const` or `static`.
enum WstrConstOrStatic {
    /// The `const` keyword token.
    Const(Token![const]),
    /// The `static` keyword token.
    Static(Token![static]),
}
89
90impl Parse for WstrConstOrStatic {
91    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
92        let lookahead = input.lookahead1();
93        if lookahead.peek(Token![const]) {
94            Ok(WstrConstOrStatic::Const(input.parse()?))
95        } else if lookahead.peek(Token![static]) {
96            Ok(WstrConstOrStatic::Static(input.parse()?))
97        } else {
98            Err(lookahead.error())
99        }
100    }
101}
102
103impl ToTokens for WstrConstOrStatic {
104    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
105        match self {
106            WstrConstOrStatic::Const(t) => t.to_tokens(tokens),
107            WstrConstOrStatic::Static(t) => t.to_tokens(tokens),
108        }
109    }
110}
111
/// A full item declaration of the form
/// `#[attrs] vis const|static [mut] NAME: [elem; size] = "literal";`.
struct WstrDeclaration {
    /// Outer attributes preceding the item.
    attrs: Vec<Attribute>,
    /// Visibility qualifier (possibly inherited/empty).
    vis: Visibility,
    /// Whether the item is a `const` or a `static`.
    const_or_static: WstrConstOrStatic,
    /// Optional `mut` following the item keyword.
    mutability: StaticMutability,
    /// Name of the declared item.
    ident: Ident,
    /// Declared array type, with an explicit or placeholder length.
    ty: WstrTypeArray,
    /// String literal on the right-hand side of the `=`.
    lit: LitStr,
}
121
122impl Parse for WstrDeclaration {
123    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
124        let attrs = input.call(Attribute::parse_outer)?;
125        let vis: Visibility = input.parse()?;
126        let const_or_static: WstrConstOrStatic = input.parse()?;
127        let mutability: StaticMutability = input.parse()?;
128        let ident: Ident = input.parse()?;
129        let _ = input.parse::<Token![:]>()?;
130        let ty: WstrTypeArray = input.parse()?;
131        let _ = input.parse::<Token![=]>()?;
132        let lit: LitStr = input.parse()?;
133        let _ = input.parse::<Token![;]>()?;
134
135        Ok(Self {
136            attrs,
137            vis,
138            const_or_static,
139            mutability,
140            ident,
141            ty,
142            lit,
143        })
144    }
145}
146
147pub fn wstr_literal_impl(input: TokenStream) -> syn::Result<TokenStream> {
148    let input = syn::parse2::<WstrDeclaration>(input)?;
149
150    let WstrDeclaration {
151        attrs,
152        vis,
153        const_or_static,
154        mutability,
155        ident,
156        ty,
157        lit,
158    } = input;
159    let WstrTypeArray { elem, size } = ty;
160
161    let mut v: Vec<_> = lit.value().encode_utf16().chain(once(0)).collect();
162    let arr_len = match size {
163        Some(len) => {
164            let sz: usize = len.base10_parse()?;
165            if sz < v.len() {
166                return Err(syn::Error::new_spanned(
167                    len,
168                    "array size must be at least length of input string plus null terminator",
169                ));
170            }
171            v.resize(sz, 0);
172            sz
173        }
174        None => v.len(),
175    };
176
177    let arr = quote! {
178        [
179            #(#v),*
180        ]
181    };
182
183    Ok(quote! {
184        #(#attrs)*
185        #vis #const_or_static #mutability #ident: [#elem; #arr_len] = #arr;
186    })
187}
188
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::{ItemConst, ItemStatic};

    use super::*;

    /// Renders `tokens` as a string with every space removed.
    ///
    /// Assertions compare against this form so they do not depend on where
    /// the token-stream printer chooses to insert whitespace.
    fn compact(tokens: &impl ToTokens) -> String {
        tokens.to_token_stream().to_string().replace(' ', "")
    }

    #[test]
    fn test_wstr_impl_single_char() {
        let result = wstr_impl(quote!("A")).unwrap();
        assert_eq!(compact(&result), "[65u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_empty_string() {
        let result = wstr_impl(quote!("")).unwrap();
        assert_eq!(compact(&result), "[0u16]");
    }

    #[test]
    fn test_wstr_impl_ascii_string() {
        let result = wstr_impl(quote!("Hello")).unwrap();
        assert_eq!(compact(&result), "[72u16,101u16,108u16,108u16,111u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_unicode_string() {
        let result = wstr_impl(quote!("こんにちは")).unwrap();
        assert_eq!(
            compact(&result),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_ascii() {
        let result = wstr_impl(quote!(10, "Hello")).unwrap();

        assert_eq!(
            compact(&result),
            "[72u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_ascii_exact_len() {
        let result = wstr_impl(quote!(6, "Hello")).unwrap();

        assert_eq!(compact(&result), "[72u16,101u16,108u16,108u16,111u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_ascii_less_len_error() {
        let result = wstr_impl(quote!(5, "Hello"));
        assert!(result.is_err());
    }

    #[test]
    fn test_wstr_impl_full_len_empty_string_exact_len() {
        // empty string + null terminator => len 1
        let result = wstr_impl(quote!(1, "")).unwrap();
        assert_eq!(compact(&result), "[0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_empty_string_larger_len() {
        let result = wstr_impl(quote!(5, "")).unwrap();
        assert_eq!(compact(&result), "[0u16,0u16,0u16,0u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_zero_len_error() {
        // Even the empty string needs one slot for the null terminator.
        let result = wstr_impl(quote!(0, ""));
        assert!(result.is_err());
    }

    #[test]
    fn test_wstr_impl_full_len_unicode_larger_len() {
        let result = wstr_impl(quote!(10, "こんにちは")).unwrap();
        // 5 chars + null + 4 zeros padding => 10 total
        assert_eq!(
            compact(&result),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_unicode_exact_len() {
        let result = wstr_impl(quote!(6, "こんにちは")).unwrap();
        assert_eq!(
            compact(&result),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_static_placeholder() {
        let input = quote! {
            static HELLO: [u16; _] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemStatic>(output).unwrap();

        assert_eq!(compact(&item.ty), "[u16;6usize]");
        assert_eq!(
            compact(&item.expr),
            "[104u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_static_larger_len() {
        let input = quote! {
            static HELLO: [u16; 10] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemStatic>(output).unwrap();

        assert_eq!(compact(&item.ty), "[u16;10usize]");
        assert_eq!(
            compact(&item.expr),
            "[104u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_const_placeholder() {
        let input = quote! {
            const HELLO: [u16; _] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemConst>(output).unwrap();

        assert_eq!(compact(&item.ty), "[u16;6usize]");
        assert_eq!(
            compact(&item.expr),
            "[104u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_const_larger_len() {
        let input = quote! {
            const HELLO: [u16; 10] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemConst>(output).unwrap();

        assert_eq!(compact(&item.ty), "[u16;10usize]");
        assert_eq!(
            compact(&item.expr),
            "[104u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_less_len_error() {
        // "hello" needs 6 elements (5 code units + null terminator).
        let input = quote! {
            static HELLO: [u16; 5] = "hello";
        };
        assert!(wstr_literal_impl(input).is_err());
    }

    #[test]
    fn test_wstr_typearray() {
        let output = syn::parse2::<WstrTypeArray>(quote!([u16; _])).unwrap();
        assert_eq!(output.elem, Ident::new("u16", Span::call_site()));
        assert!(output.size.is_none());

        let output = syn::parse2::<WstrTypeArray>(quote!([u16; 0x10])).unwrap();
        assert_eq!(output.elem, Ident::new("u16", Span::call_site()));
        let size = output
            .size
            .expect("an explicit length literal must be captured");
        assert_eq!(size.base10_parse::<usize>().unwrap(), 0x10);
    }
}