1#![doc = include_str!("../README.md")]
2use std::iter::once;
3
4use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{
7 bracketed, parse::Parse, Attribute, Ident, LitInt, LitStr, StaticMutability, Token, Visibility,
8};
9
10pub fn wstr_impl(input: TokenStream) -> syn::Result<TokenStream> {
11 let WstrArgs { arr_len, input_str } = syn::parse2(input)?;
12
13 let mut v: Vec<_> = input_str.value().encode_utf16().chain(once(0)).collect();
14
15 if let Some(arr_len) = arr_len {
16 let sz: usize = arr_len.base10_parse()?;
17
18 if sz < v.len() {
19 return Err(syn::Error::new_spanned(
20 arr_len,
21 "array size must be at least length of input string plus null terminator",
22 ));
23 }
24
25 v.resize(sz, 0);
26 }
27
28 Ok(quote! {
29 [
30 #(#v),*
31 ]
32 })
33}
34
35struct WstrArgs {
36 arr_len: Option<LitInt>,
37 input_str: LitStr,
38}
39
40impl Parse for WstrArgs {
41 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
42 let arr_len = if input.peek(LitInt) {
43 Some(input.parse()?)
44 } else {
45 None
46 };
47 if arr_len.is_some() {
48 let _: Token![,] = input.parse()?;
49 }
50 let input_str: LitStr = input.parse()?;
51 Ok(Self { arr_len, input_str })
52 }
53}
54
55fn parse_array_size(input: syn::parse::ParseStream) -> syn::Result<Option<LitInt>> {
56 let lookahead = input.lookahead1();
57 if lookahead.peek(LitInt) {
58 Ok(Some(input.parse()?))
59 } else if lookahead.peek(Token![_]) {
60 let _ = input.parse::<Token![_]>()?;
61 Ok(None)
62 } else {
63 Err(lookahead.error())
64 }
65}
66
67struct WstrTypeArray {
68 elem: Ident,
69 size: Option<LitInt>,
70}
71
72impl Parse for WstrTypeArray {
73 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
74 let content;
75
76 let _ = bracketed!(content in input);
77 let elem: Ident = content.parse()?;
78 let _ = content.parse::<Token![;]>()?;
79 let size = parse_array_size(&content)?;
80
81 Ok(Self { elem, size })
82 }
83}
84
85enum WstrConstOrStatic {
86 Const(Token![const]),
87 Static(Token![static]),
88}
89
90impl Parse for WstrConstOrStatic {
91 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
92 let lookahead = input.lookahead1();
93 if lookahead.peek(Token![const]) {
94 Ok(WstrConstOrStatic::Const(input.parse()?))
95 } else if lookahead.peek(Token![static]) {
96 Ok(WstrConstOrStatic::Static(input.parse()?))
97 } else {
98 Err(lookahead.error())
99 }
100 }
101}
102
103impl ToTokens for WstrConstOrStatic {
104 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
105 match self {
106 WstrConstOrStatic::Const(t) => t.to_tokens(tokens),
107 WstrConstOrStatic::Static(t) => t.to_tokens(tokens),
108 }
109 }
110}
111
112struct WstrDeclaration {
113 attrs: Vec<Attribute>,
114 vis: Visibility,
115 const_or_static: WstrConstOrStatic,
116 mutability: StaticMutability,
117 ident: Ident,
118 ty: WstrTypeArray,
119 lit: LitStr,
120}
121
122impl Parse for WstrDeclaration {
123 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
124 let attrs = input.call(Attribute::parse_outer)?;
125 let vis: Visibility = input.parse()?;
126 let const_or_static: WstrConstOrStatic = input.parse()?;
127 let mutability: StaticMutability = input.parse()?;
128 let ident: Ident = input.parse()?;
129 let _ = input.parse::<Token![:]>()?;
130 let ty: WstrTypeArray = input.parse()?;
131 let _ = input.parse::<Token![=]>()?;
132 let lit: LitStr = input.parse()?;
133 let _ = input.parse::<Token![;]>()?;
134
135 Ok(Self {
136 attrs,
137 vis,
138 const_or_static,
139 mutability,
140 ident,
141 ty,
142 lit,
143 })
144 }
145}
146
147pub fn wstr_literal_impl(input: TokenStream) -> syn::Result<TokenStream> {
148 let input = syn::parse2::<WstrDeclaration>(input)?;
149
150 let WstrDeclaration {
151 attrs,
152 vis,
153 const_or_static,
154 mutability,
155 ident,
156 ty,
157 lit,
158 } = input;
159 let WstrTypeArray { elem, size } = ty;
160
161 let mut v: Vec<_> = lit.value().encode_utf16().chain(once(0)).collect();
162 let arr_len = match size {
163 Some(len) => {
164 let sz: usize = len.base10_parse()?;
165 if sz < v.len() {
166 return Err(syn::Error::new_spanned(
167 len,
168 "array size must be at least length of input string plus null terminator",
169 ));
170 }
171 v.resize(sz, 0);
172 sz
173 }
174 None => v.len(),
175 };
176
177 let arr = quote! {
178 [
179 #(#v),*
180 ]
181 };
182
183 Ok(quote! {
184 #(#attrs)*
185 #vis #const_or_static #mutability #ident: [#elem; #arr_len] = #arr;
186 })
187}
188
#[cfg(test)]
mod tests {
    use proc_macro2::Span;
    use syn::{ItemConst, ItemStatic};

    use super::*;

    /// Test helper for comparing token output regardless of the spaces
    /// `TokenStream::to_string` inserts between tokens.
    trait RemoveSpaces {
        /// Returns a copy of `self` with every `' '` removed.
        fn rmsp(&self) -> String;
    }

    impl RemoveSpaces for str {
        fn rmsp(&self) -> String {
            self.replace(' ', "")
        }
    }

    /// Runs `wstr_impl` on `input`, panicking on error, and returns the
    /// space-stripped token text.
    fn wstr(input: TokenStream) -> String {
        wstr_impl(input).unwrap().to_string().rmsp()
    }

    #[test]
    fn test_wstr_impl_single_char() {
        assert_eq!(wstr(quote!("A")), "[65u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_empty_string() {
        // Compare space-stripped output like every other test, so the check
        // does not depend on how a one-element group happens to be formatted.
        assert_eq!(wstr(quote!("")), "[0u16]");
    }

    #[test]
    fn test_wstr_impl_ascii_string() {
        assert_eq!(
            wstr(quote!("Hello")),
            "[72u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_unicode_string() {
        assert_eq!(
            wstr(quote!("こんにちは")),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_astral_char_uses_surrogate_pair() {
        // U+1F600 lies outside the BMP and must encode as a surrogate pair
        // (0xD83D, 0xDE00).
        assert_eq!(wstr(quote!("😀")), "[55357u16,56832u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_ascii() {
        assert_eq!(
            wstr(quote!(10, "Hello")),
            "[72u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_ascii_exact_len() {
        assert_eq!(
            wstr(quote!(6, "Hello")),
            "[72u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_ascii_less_len_error() {
        assert!(wstr_impl(quote!(5, "Hello")).is_err());
    }

    #[test]
    fn test_wstr_impl_full_len_empty_string_exact_len() {
        assert_eq!(wstr(quote!(1, "")), "[0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_empty_string_larger_len() {
        assert_eq!(wstr(quote!(5, "")), "[0u16,0u16,0u16,0u16,0u16]");
    }

    #[test]
    fn test_wstr_impl_full_len_zero_len_error() {
        assert!(wstr_impl(quote!(0, "")).is_err());
    }

    #[test]
    fn test_wstr_impl_full_len_unicode_larger_len() {
        assert_eq!(
            wstr(quote!(10, "こんにちは")),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_impl_full_len_unicode_exact_len() {
        assert_eq!(
            wstr(quote!(6, "こんにちは")),
            "[12371u16,12435u16,12395u16,12385u16,12399u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_static_placeholder() {
        let input = quote! {
            static HELLO: [u16; _] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemStatic>(output).unwrap();

        assert_eq!(item.ty.to_token_stream().to_string().rmsp(), "[u16;6usize]");
        assert_eq!(
            item.expr.to_token_stream().to_string().rmsp(),
            "[104u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_static_larger_len() {
        let input = quote! {
            static HELLO: [u16; 10] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemStatic>(output).unwrap();

        assert_eq!(item.ty.to_token_stream().to_string().rmsp(), "[u16;10usize]");
        assert_eq!(
            item.expr.to_token_stream().to_string().rmsp(),
            "[104u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_const_placeholder() {
        let input = quote! {
            const HELLO: [u16; _] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemConst>(output).unwrap();

        assert_eq!(item.ty.to_token_stream().to_string().rmsp(), "[u16;6usize]");
        assert_eq!(
            item.expr.to_token_stream().to_string().rmsp(),
            "[104u16,101u16,108u16,108u16,111u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_const_larger_len() {
        let input = quote! {
            const HELLO: [u16; 10] = "hello";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemConst>(output).unwrap();

        assert_eq!(item.ty.to_token_stream().to_string().rmsp(), "[u16;10usize]");
        assert_eq!(
            item.expr.to_token_stream().to_string().rmsp(),
            "[104u16,101u16,108u16,108u16,111u16,0u16,0u16,0u16,0u16,0u16]"
        );
    }

    #[test]
    fn test_wstr_literal_impl_less_len_error() {
        // "hello" needs 5 code units plus the terminator, so 5 is too small.
        let input = quote! {
            static HELLO: [u16; 5] = "hello";
        };
        assert!(wstr_literal_impl(input).is_err());
    }

    #[test]
    fn test_wstr_literal_impl_preserves_attrs_vis_and_mut() {
        let input = quote! {
            #[allow(dead_code)]
            pub static mut HELLO: [u16; _] = "hi";
        };
        let output = wstr_literal_impl(input).unwrap();
        let item = syn::parse2::<ItemStatic>(output).unwrap();

        assert_eq!(item.attrs.len(), 1);
        assert!(matches!(item.vis, Visibility::Public(_)));
        assert!(matches!(item.mutability, StaticMutability::Mut(_)));
        assert_eq!(item.ident, "HELLO");
        assert_eq!(item.ty.to_token_stream().to_string().rmsp(), "[u16;3usize]");
    }

    #[test]
    fn test_wstr_typearray() {
        let output = syn::parse2::<WstrTypeArray>(quote!([u16; _])).unwrap();
        assert_eq!(output.elem, Ident::new("u16", Span::call_site()));
        assert!(output.size.is_none());

        let output = syn::parse2::<WstrTypeArray>(quote!([u16; 0x10])).unwrap();
        assert_eq!(output.elem, Ident::new("u16", Span::call_site()));
        let size = output.size.expect("explicit size should be parsed");
        assert_eq!(size.base10_parse::<usize>().unwrap(), 0x10);
    }

    #[test]
    fn test_wstr_typearray_missing_size_error() {
        assert!(syn::parse2::<WstrTypeArray>(quote!([u16])).is_err());
    }
}