using_param/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::*;
4use proc_macro::Spacing::*;
5use proc_macro_tool::*;
6
7#[proc_macro_attribute]
8pub fn using_param(attr: TokenStream, item: TokenStream) -> TokenStream {
9    process(Type::Param, attr, item)
10}
11
12#[proc_macro_attribute]
13pub fn using_generic(attr: TokenStream, item: TokenStream) -> TokenStream {
14    process(Type::Generic, attr, item)
15}
16
17#[proc_macro_attribute]
18pub fn using_return(attr: TokenStream, item: TokenStream) -> TokenStream {
19    process(Type::RetType, attr, item)
20}
21
/// Selects which part of each function signature an attribute rewrites.
enum Type {
    /// Extra parameters (`using_param`).
    Param,
    /// Extra generic parameters (`using_generic`).
    Generic,
    /// A default return type (`using_return`).
    RetType,
}
23
24fn process(ty: Type, attr: TokenStream, item: TokenStream) -> TokenStream {
25    let cfg = match ty {
26        Type::Param => param_cfg(attr),
27        Type::Generic => generic_cfg(attr),
28        Type::RetType => Conf { return_type: attr, ..Default::default() },
29    };
30
31    let mut iter = item.parse_iter();
32    let mut out = TokenStream::new();
33    out.extend(iter.next_attributes());
34
35    if !iter.peek_is(|i| i.is_keyword("impl")) {
36        err!("expected impl keyword", iter.span())
37    }
38    let block = iter.reduce(|a, b| {
39        out.push(a);
40        b
41    }).unwrap();
42
43    let items = block.to_brace_stream().unwrap();
44    match process_impl_block(cfg, items) {
45        Err(e) => e,
46        Ok(b) => {
47            out.push(b.grouped_brace().tt());
48            out
49        }
50    }
51}
52
/// Settings derived from the attribute arguments.
///
/// Fields that do not apply to the invoked macro keep their `Default`
/// (empty / `false`) value.
#[derive(Debug, Default)]
struct Conf {
    /// Place `param` after a function's own parameters instead of before.
    params_after: bool,
    /// Parameters to add to every function; may start with a `self` receiver.
    param: TokenStream,
    /// Generic parameters to add to every function.
    generics: TokenStream,
    /// Place `generics` after a function's own generics instead of before.
    generics_after: bool,
    /// Return type for functions that do not declare one.
    return_type: TokenStream,
}
61
62fn fn_generic(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
63    let mut out = TokenStream::new();
64
65    loop {
66        if let Some(arrow) = iter.next_puncts("->") {
67            out.extend(arrow);
68        } else if iter.is_puncts(">")
69            && iter.peek_i(1).is_none_or(|t| t.is_delimiter_paren())
70        {
71            break;
72        } else if let Some(tt) = iter.next() {
73            out.push(tt);
74        } else {
75            break;
76        }
77    }
78
79    out
80}
81
82fn param_cfg(attr: TokenStream) -> Conf {
83    let mut iter = attr.parse_iter();
84    let mut cfg = Conf::default();
85
86    cfg.params_after = iter.next_if(|t| t.is_punch(',')).is_some();
87    cfg.param = iter.collect();
88
89    cfg
90}
91
92fn generic_cfg(attr: TokenStream) -> Conf {
93    let mut iter = attr.parse_iter();
94    let mut cfg = Conf::default();
95
96    cfg.generics_after = iter.next_if(|t| t.is_punch(',')).is_some();
97    cfg.generics = iter.collect();
98
99    cfg
100}
101
102fn join_with_comma(a: TokenStream, b: TokenStream) -> TokenStream {
103    if a.is_empty() {
104        return b;
105    }
106    if b.is_empty() {
107        return a;
108    }
109
110    let mut left = a.into_iter().collect::<Vec<_>>();
111    let right = b.into_iter().collect::<Vec<_>>();
112
113    left.pop_if(|t| t.is_punch(','));
114
115    if !left.is_empty() {
116        left.push(','.punct(Alone).tt());
117    }
118
119    let rcom = right.first().is_some_and(|t| t.is_punch(','));
120    stream(left.into_iter().chain(right.into_iter().skip(rcom.into())))
121}
122
123fn self_param(iter: &mut ParseIter<impl Iterator<Item = TokenTree>>) -> TokenStream {
124    macro_rules! ok {
125        () => {
126            return iter
127                .split_puncts_include(",")
128                .unwrap_or_else(|| iter.collect());
129        };
130    }
131    if iter.peek_is(|t| t.is_punch('&')) {
132        if iter.peek_i_is(1, |t| t.is_punch('\''))
133        && iter.peek_i_is(2, |t| t.is_ident())
134        {
135            if iter.peek_i_is(3, |t| t.is_keyword("self")) {
136                ok!();
137            }
138
139            if iter.peek_i_is(3, |t| t.is_keyword("mut"))
140            && iter.peek_i_is(4, |t| t.is_keyword("self"))
141            {
142                ok!();
143            }
144        }
145
146        if iter.peek_i_is(1, |t| t.is_keyword("self")) {
147            ok!();
148        }
149
150        if iter.peek_i_is(1, |t| t.is_keyword("mut"))
151        && iter.peek_i_is(2, |t| t.is_keyword("self"))
152        {
153            ok!();
154        }
155    } else if iter.peek_is(|t| t.is_keyword("self")) {
156        ok!();
157    }
158    TokenStream::new()
159}
160
161fn process_impl_block(
162    cfg: Conf,
163    items: TokenStream,
164) -> Result<TokenStream, TokenStream> {
165    let mut out = TokenStream::new();
166    let mut iter = items.parse_iter();
167
168    out.extend(iter.next_outer_attributes());
169
170    while iter.peek().is_some() {
171        out.extend(iter.next_attributes());
172        out.extend(iter.next_vis());
173
174        if iter.peek_is(|t| t.is_keyword("fn"))
175            && iter.peek_i_is(1, |t| t.is_ident())
176        {
177            out.extend(iter.next_tts::<2>());
178
179            if iter.push_if_to(&mut out, |t| t.is_punch('<')) {
180                let generic = fn_generic(&mut iter);
181
182                out.add(if cfg.generics_after {
183                    join_with_comma(generic, cfg.generics.clone())
184                } else {
185                    join_with_comma(cfg.generics.clone(), generic)
186                });
187
188                iter.push_if_to(&mut out, |t| t.is_punch('>'));
189            } else if !cfg.generics.is_empty() {
190                out.push('<'.punct(Alone).tt());
191                out.add(cfg.generics.clone());
192                out.push('>'.punct(Alone).tt());
193            }
194
195            if let Some(TokenTree::Group(paren))
196                = iter.next_if(|t| t.is_delimiter_paren())
197            {
198
199                let params_group = paren.map(|paren| {
200                    let mut iter = paren.parse_iter();
201                    let mut self_ = self_param(&mut iter);
202                    let mut param = cfg.param.clone().parse_iter();
203                    let other_self = self_param(&mut param);
204
205                    if self_.is_empty() {
206                        self_ = other_self;
207                    }
208
209                    join_with_comma(self_, if cfg.params_after {
210                        join_with_comma(iter.collect(), param.collect())
211                    } else {
212                        join_with_comma(param.collect(), iter.collect())
213                    })
214                });
215                out.push(params_group.tt());
216
217                if !cfg.return_type.is_empty() && !iter.is_puncts("->") {
218                    out.push('-'.punct(Joint).tt());
219                    out.push('>'.punct(Alone).tt());
220                    out.add(cfg.return_type.clone());
221                }
222            }
223        } else {
224            out.push(iter.next().unwrap());
225        }
226    }
227
228    Ok(out)
229}
230
231/// ```
232/// using_param::__test_join! {}
233/// ```
234#[doc(hidden)]
235#[proc_macro]
236pub fn __test_join(_: TokenStream) -> TokenStream {
237    let datas = [
238        ("", "", ""),
239        ("a", "", "a"),
240        ("a,", "", "a,"),
241        ("", "a", "a"),
242        ("", "a,", "a,"),
243        ("a", "b", "a, b"),
244        ("a,", "b", "a, b"),
245        ("a,", "b,", "a, b,"),
246        ("a,", "b,", "a, b,"),
247        ("a", "b,", "a, b,"),
248    ];
249    for (a, b, expected) in datas {
250        let out = join_with_comma(a.parse().unwrap(), b.parse().unwrap());
251        assert_eq!(out.to_string(), expected, "{a:?}, {b:?}");
252    }
253    TokenStream::new()
254}
255
256/// ```
257/// using_param::__test_before! {}
258/// ```
259#[doc(hidden)]
260#[proc_macro]
261pub fn __test_before(_: TokenStream) -> TokenStream {
262    let out = using_param("ctx: i32".parse().unwrap(), "
263impl Foo {
264    #[doc(hidden)]
265    pub fn foo(&self, s: &str) -> &str {
266        s
267    }
268    pub fn bar(&self) -> i32 {
269        ctx
270    }
271    pub fn baz() -> i32 {
272        ctx
273    }
274    pub fn f(self: &Self) -> i32 {
275        ctx
276    }
277    pub fn a(x: i32) -> i32 {
278        ctx+x
279    }
280    pub fn b(&mut self, a: i32, b: i32) -> i32 {
281        ctx+a+b
282    }
283    pub fn c(&'a mut self, a: i32, b: i32) -> i32 {
284        ctx+a+b
285    }
286    pub fn d(&'static mut self, a: i32, b: i32) -> i32 {
287        ctx+a+b
288    }
289}
290    ".parse().unwrap()).to_string();
291    assert_eq!(out, "
292impl Foo {
293    #[doc(hidden)]
294    pub fn foo(& self, ctx : i32, s : & str) -> & str {
295        s
296    }
297    pub fn bar(& self, ctx : i32) -> i32 {
298        ctx
299    }
300    pub fn baz(ctx : i32) -> i32 {
301        ctx
302    }
303    pub fn f(self : & Self, ctx : i32) -> i32 {
304        ctx
305    }
306    pub fn a(ctx : i32, x : i32) -> i32 {
307        ctx+x
308    }
309    pub fn b(& mut self, ctx : i32, a : i32, b : i32) -> i32 {
310        ctx+a+b
311    }
312    pub fn c(& 'a mut self, ctx : i32, a : i32, b : i32) -> i32 {
313        ctx+a+b
314    }
315    pub fn d(& 'static mut self, ctx : i32, a : i32, b : i32) -> i32 {
316        ctx+a+b
317    }
318}
319    ".parse::<TokenStream>().unwrap().to_string());
320    TokenStream::new()
321}
322
323
324/// ```
325/// using_param::__test_after! {}
326/// ```
327#[doc(hidden)]
328#[proc_macro]
329pub fn __test_after(_: TokenStream) -> TokenStream {
330    let out = using_param(", ctx: i32".parse().unwrap(), "
331impl Foo {
332    #[doc(hidden)]
333    pub fn foo(&self, s: &str) -> &str {
334        s
335    }
336    pub fn bar(&self) -> i32 {
337        ctx
338    }
339    pub fn baz() -> i32 {
340        ctx
341    }
342    pub fn f(self: &Self) -> i32 {
343        ctx
344    }
345    pub fn a(x: i32) -> i32 {
346        ctx+x
347    }
348    pub fn b(&mut self, a: i32, b: i32) -> i32 {
349        ctx+a+b
350    }
351}
352    ".parse().unwrap()).to_string();
353    assert_eq!(out, "
354impl Foo {
355    #[doc(hidden)]
356    pub fn foo(& self, s : & str, ctx : i32) -> & str {
357        s
358    }
359    pub fn bar(& self, ctx : i32) -> i32 {
360        ctx
361    }
362    pub fn baz(ctx : i32) -> i32 {
363        ctx
364    }
365    pub fn f(self : & Self, ctx : i32) -> i32 {
366        ctx
367    }
368    pub fn a(x : i32, ctx : i32) -> i32 {
369        ctx+x
370    }
371    pub fn b(& mut self, a : i32, b : i32, ctx : i32) -> i32 {
372        ctx+a+b
373    }
374}
375    ".parse::<TokenStream>().unwrap().to_string());
376    TokenStream::new()
377}
378
379
380/// ```
381/// using_param::__test_self_param! {}
382/// ```
383#[doc(hidden)]
384#[proc_macro]
385pub fn __test_self_param(_: TokenStream) -> TokenStream {
386    let out = using_param("&'static self, ctx: i32".parse().unwrap(), "
387impl Foo {
388    pub fn foo(&self, s: &str) -> &str {
389        s
390    }
391    pub fn bar(&mut self) -> i32 {
392        ctx
393    }
394    pub fn baz(self: &Self) -> i32 {
395        ctx
396    }
397    pub fn a(this: &Self) -> i32 {
398        ctx
399    }
400}
401    ".parse().unwrap()).to_string();
402    assert_eq!(out, "
403impl Foo {
404    pub fn foo(& self, ctx : i32, s : & str) -> & str {
405        s
406    }
407    pub fn bar(& mut self, ctx : i32) -> i32 {
408        ctx
409    }
410    pub fn baz(self : & Self, ctx : i32) -> i32 {
411        ctx
412    }
413    pub fn a(& 'static self, ctx : i32, this : & Self) -> i32 {
414        ctx
415    }
416}
417    ".parse::<TokenStream>().unwrap().to_string());
418    TokenStream::new()
419}
420
421
422/// ```
423/// using_param::__test_self_param! {}
424/// ```
425#[doc(hidden)]
426#[proc_macro]
427pub fn __test_self_param_after(_: TokenStream) -> TokenStream {
428    let out = using_param(", &'static self, ctx: i32".parse().unwrap(), "
429impl Foo {
430    pub fn foo(&self, s: &str) -> &str {
431        s
432    }
433    pub fn bar(&mut self) -> i32 {
434        ctx
435    }
436    pub fn baz(self: &Self) -> i32 {
437        ctx
438    }
439    pub fn a(this: &Self) -> i32 {
440        ctx
441    }
442}
443    ".parse().unwrap()).to_string();
444    assert_eq!(out, "
445impl Foo {
446    pub fn foo(& self, s : & str, ctx : i32) -> & str {
447        s
448    }
449    pub fn bar(& mut self, ctx : i32) -> i32 {
450        ctx
451    }
452    pub fn baz(self : & Self, ctx : i32) -> i32 {
453        ctx
454    }
455    pub fn a(& 'static self, this : & Self, ctx : i32) -> i32 {
456        ctx
457    }
458}
459    ".parse::<TokenStream>().unwrap().to_string());
460    TokenStream::new()
461}
462
463
464/// ```
465/// using_param::__test_generic_before! {}
466/// ```
467#[doc(hidden)]
468#[proc_macro]
469pub fn __test_generic_before(_: TokenStream) -> TokenStream {
470    let out = using_generic("'a".parse().unwrap(), "
471impl Foo {
472    fn foo() {}
473    fn bar<'b>() {}
474    fn baz<'b, T>() {}
475}
476    ".parse().unwrap()).to_string();
477    assert_eq!(out, "
478impl Foo {
479    fn foo < 'a > () {}
480    fn bar < 'a, 'b > () {}
481    fn baz < 'a, 'b, T > () {}
482}
483    ".parse::<TokenStream>().unwrap().to_string());
484    TokenStream::new()
485}
486
487
488/// ```
489/// using_param::__test_generic_after! {}
490/// ```
491#[doc(hidden)]
492#[proc_macro]
493pub fn __test_generic_after(_: TokenStream) -> TokenStream {
494    let out = using_generic(", 'a".parse().unwrap(), "
495impl Foo {
496    fn foo() {}
497    fn bar<'b>() {}
498}
499    ".parse().unwrap()).to_string();
500    assert_eq!(out, "
501impl Foo {
502    fn foo < 'a > () {}
503    fn bar < 'b, 'a > () {}
504}
505    ".parse::<TokenStream>().unwrap().to_string());
506    TokenStream::new()
507}
508
509
510/// ```
511/// using_param::__test_other_assoc_item! {}
512/// ```
513#[doc(hidden)]
514#[proc_macro]
515pub fn __test_other_assoc_item(_: TokenStream) -> TokenStream {
516    let out = using_param("ctx: i32".parse().unwrap(), "
517impl Foo {
518    pub const M: usize = 3;
519    pub type C = i32;
520    some_macro!();
521    fn foo() {}
522    fn bar(m: i32) { m+ctx }
523    fn baz(self, m: i32) { m+ctx }
524}
525    ".parse().unwrap()).to_string();
526    assert_eq!(out, "
527impl Foo {
528    pub const M : usize = 3;
529    pub type C = i32;
530    some_macro! ();
531    fn foo(ctx : i32) {}
532    fn bar(ctx : i32, m : i32) { m+ctx }
533    fn baz(self, ctx : i32, m : i32) { m+ctx }
534}
535    ".parse::<TokenStream>().unwrap().to_string());
536    TokenStream::new()
537}
538
539
540/// ```
541/// using_param::__test_return_type! {}
542/// ```
543#[doc(hidden)]
544#[proc_macro]
545pub fn __test_return_type(_: TokenStream) -> TokenStream {
546    let out = using_return("i32".parse().unwrap(), "
547impl Foo {
548    pub const M: usize = 3;
549    pub type C = i32;
550    some_macro!();
551    fn foo() {}
552    fn bar(m: i32) { m+ctx }
553    fn baz(self, m: i32) { m+ctx }
554    fn xfoo() -> u32 {}
555    fn xbar(m: i32) -> u32 { m+ctx }
556    fn xbaz(self, m: i32) -> u32 { m+ctx }
557}
558    ".parse().unwrap()).to_string();
559    assert_eq!(out, "
560impl Foo {
561    pub const M : usize = 3;
562    pub type C = i32;
563    some_macro! ();
564    fn foo() -> i32 {}
565    fn bar(m : i32) -> i32 { m+ctx }
566    fn baz(self, m : i32) -> i32 { m+ctx }
567    fn xfoo() -> u32 {}
568    fn xbar(m : i32) -> u32 { m+ctx }
569    fn xbaz(self, m : i32) -> u32 { m+ctx }
570}
571    ".parse::<TokenStream>().unwrap().to_string());
572    TokenStream::new()
573}