trait_set/
lib.rs

//! This crate provides support for [trait aliases][alias]: a feature
//! that is already supported by the Rust compiler, but is [not stable][tracking_issue]
//! yet.
4//!
//! The idea is simple: combine a group of traits under a single name. The simplest
//! example is:
7//!
8//! ```rust
9//! use trait_set::trait_set;
10//!
11//! trait_set! {
12//!     pub trait ThreadSafe = Send + Sync;
13//! }
14//! ```
15//!
//! The [`trait_set`] macro shown here is the main entity of the crate:
//! it allows declaring multiple trait aliases, each of which is written
//! as
19//!
20//! ```text
21//! [visibility] trait [AliasName][<generics>] = [Element1] + [Element2] + ... + [ElementN];
22//! ```
23//!
24//! For more details, see the [`trait_set`] macro documentation.
25//!
26//! [alias]: https://doc.rust-lang.org/unstable-book/language-features/trait-alias.html
27//! [tracking_issue]: https://github.com/rust-lang/rust/issues/41517
28//! [`trait_set`]: macro.trait_set.html
29
30extern crate proc_macro;
31
32use std::iter::FromIterator;
33
34use proc_macro::TokenStream;
35use proc_macro2::TokenStream as TokenStream2;
36use quote::quote;
37use syn::{
38    parse::{Error, Parse, ParseStream},
39    parse_macro_input,
40    punctuated::Punctuated,
41    spanned::Spanned,
42    Attribute, GenericParam, Generics, Ident, Lit, Meta, MetaNameValue, Result, Token,
43    TypeTraitObject, Visibility,
44};
45
/// Represents one trait alias, e.g. `pub trait ThreadSafe = Send + Sync`.
struct TraitSet {
    /// Text of the alias's doc-comments (`///` or `#[doc = "..."]`) joined
    /// into one string, or `None` if the alias had none.
    doc_comment: Option<String>,
    /// Visibility applied to the generated trait.
    visibility: Visibility,
    /// The `trait` keyword; consumed during parsing, not used when rendering.
    _trait_token: Token![trait],
    /// Name of the alias, used as the name of the generated trait.
    alias_name: Ident,
    /// Generic parameters of the alias (the `<...>` part).
    generics: Generics,
    /// The `=` between the alias header and its bounds; parsing only.
    _eq_token: Token![=],
    /// The `+`-separated bounds the alias stands for; only `bounds` is used
    /// when rendering.
    traits: TypeTraitObject,
}
56
57impl TraitSet {
58    /// Attempts to parse doc-comments from the trait attributes
59    /// and returns the results as a single string.
60    /// If multiple doc-comments were provided (e.g. with `///` and `#[doc]`),
61    /// they will be joined with a newline.
62    fn parse_doc(attrs: &[Attribute]) -> Result<Option<String>> {
63        let mut out = String::new();
64
65        for attr in attrs {
66            // Check whether current attribute is `#[doc = "..."]`.
67            if let Meta::NameValue(MetaNameValue { path, lit, .. }) = attr.parse_meta()? {
68                if let Some(path_ident) = path.get_ident() {
69                    if path_ident == "doc" {
70                        if let Lit::Str(doc_comment) = lit {
71                            out += &doc_comment.value();
72                            // Newlines are not included in the literal value,
73                            // so we have to add them manually.
74                            out.push('\n');
75                        }
76                    }
77                }
78            }
79        }
80
81        Ok(if !out.is_empty() { Some(out) } else { None })
82    }
83
84    /// Renders trait alias into a new trait with bounds set.
85    fn render(self) -> TokenStream2 {
86        // Generic and non-generic implementation have slightly different
87        // syntax, so it's simpler to process them individually rather than
88        // try to generalize implementation.
89        if self.generics.params.is_empty() {
90            self.render_non_generic()
91        } else {
92            self.render_generic()
93        }
94    }
95
96    /// Renders the trait alias without generic parameters.
97    fn render_non_generic(self) -> TokenStream2 {
98        let visibility = self.visibility;
99        let alias_name = self.alias_name;
100        let bounds = self.traits.bounds;
101        let doc_comment = self.doc_comment.map(|val| quote! { #[doc = #val] });
102        quote! {
103            #doc_comment
104            #visibility trait #alias_name: #bounds {}
105
106            impl<_INNER> #alias_name for _INNER where _INNER: #bounds {}
107        }
108    }
109
110    /// Renders the trait alias with generic parameters.
111    fn render_generic(self) -> TokenStream2 {
112        let visibility = self.visibility;
113        let alias_name = self.alias_name;
114        let bounds = self.traits.bounds;
115        let doc_comment = self.doc_comment.map(|val| quote! { #[doc = #val] });
116
117        // We differentiate `generics` and `bound_generics` because in the
118        // `impl<X> Trait<Y>` block there must be no trait bounds in the `<Y>` part,
119        // they must go into `<X>` part only.
120        // E.g. `impl<X: Send, _INNER> Trait<X> for _INNER`.
121        let mut unbound_generics = self.generics.clone();
122        for param in unbound_generics.params.iter_mut() {
123            if let GenericParam::Type(ty) = param {
124                if !ty.bounds.is_empty() {
125                    ty.bounds.clear();
126                }
127            }
128        }
129        let unbound_generics = unbound_generics.params;
130        let bound_generics = self.generics.params;
131
132        // Note that it's important for `_INNER` to go *after* user-defined
133        // generics, because generics can contain lifetimes, and lifetimes
134        // should always go first.
135        quote! {
136            #doc_comment
137            #visibility trait #alias_name<#bound_generics>: #bounds {}
138
139            impl<#bound_generics, _INNER> #alias_name<#unbound_generics> for _INNER where _INNER: #bounds {}
140        }
141    }
142}
143
144impl Parse for TraitSet {
145    fn parse(input: ParseStream) -> Result<Self> {
146        let attrs: Vec<Attribute> = input.call(Attribute::parse_outer)?;
147        let result = TraitSet {
148            doc_comment: Self::parse_doc(&attrs)?,
149            visibility: input.parse()?,
150            _trait_token: input.parse()?,
151            alias_name: input.parse()?,
152            generics: input.parse()?,
153            _eq_token: input.parse()?,
154            traits: input.parse()?,
155        };
156
157        if let Some(where_clause) = result.generics.where_clause {
158            return Err(Error::new(
159                where_clause.span(),
160                "Where clause is not allowed for trait alias",
161            ));
162        }
163        Ok(result)
164    }
165}
166
/// Represents a sequence of trait aliases delimited by semicolons; this is
/// the full input accepted by the `trait_set!` macro. A semicolon after the
/// last alias is optional.
struct ManyTraitSet {
    /// The parsed aliases together with their `;` separators.
    entries: Punctuated<TraitSet, Token![;]>,
}
171
172impl Parse for ManyTraitSet {
173    fn parse(input: ParseStream) -> Result<Self> {
174        Ok(ManyTraitSet {
175            entries: input.parse_terminated(TraitSet::parse)?,
176        })
177    }
178}
179
180impl ManyTraitSet {
181    fn render(self) -> TokenStream2 {
182        TokenStream2::from_iter(self.entries.into_iter().map(|entry| entry.render()))
183    }
184}
185
/// Creates an alias for a set of traits.
187///
188/// To demonstrate the idea, see the examples:
189///
190/// ```rust
191/// use trait_set::trait_set;
192///
193/// trait_set! {
194///     /// Doc-comments are also supported btw.
195///     pub trait ThreadSafe = Send + Sync;
196///     pub trait ThreadSafeIterator<T> = ThreadSafe + Iterator<Item = T>;
197///     pub trait ThreadSafeBytesIterator = ThreadSafeIterator<u8>;
198///     pub trait StaticDebug = 'static + std::fmt::Debug;
199/// }
200///```
201///
/// This macro also supports [higher-ranked trait bounds][hrtb]:
203///
204/// ```rust
205/// # pub trait Serializer {
206/// #     type Ok;
207/// #     type Error;
208/// #
209/// #     fn ok_value() -> Self::Ok;
210/// # }
211/// # pub trait Deserializer<'de> {
212/// #     type Error;
213/// # }
214/// #
215/// # pub trait Serialize {
216/// #     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
217/// #     where
218/// #         S: Serializer;
219/// # }
220/// #
221/// # pub trait Deserialize<'de>: Sized {
222/// #     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
223/// #     where
224/// #         D: Deserializer<'de>;
225/// # }
226/// #
227/// # impl Serializer for u8 {
228/// #     type Ok = ();
229/// #     type Error = ();
230/// #
231/// #     fn ok_value() -> Self::Ok {
232/// #         ()
233/// #     }
234/// # }
235/// #
236/// # impl<'de> Deserializer<'de> for u8 {
237/// #     type Error = ();
238/// # }
239/// #
240/// # impl Serialize for u8 {
241/// #     fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
242/// #     where
243/// #         S: Serializer
244/// #     {
245/// #         Ok(S::ok_value())
246/// #     }
247/// # }
248/// #
249/// # impl<'de> Deserialize<'de> for u8 {
250/// #     fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
251/// #     where
252/// #         D: Deserializer<'de>
253/// #         {
254/// #             Ok(0u8)
255/// #         }
256/// # }
257/// use trait_set::trait_set;
258///
259/// trait_set!{
260///     pub trait Serde = Serialize + for<'de> Deserialize<'de>;
261///     // Note that you can also use lifetimes as a generic parameter.
262///     pub trait SerdeLifetimeTemplate<'de> = Serialize + Deserialize<'de>;
263/// }
264/// ```
265///
266/// [hrtb]: https://doc.rust-lang.org/nomicon/hrtb.html
267#[proc_macro]
268pub fn trait_set(tokens: TokenStream) -> TokenStream {
269    let input = parse_macro_input!(tokens as ManyTraitSet);
270    input.render().into()
271}