1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
#![cfg_attr(not(test), no_std)]
#![cfg_attr(test, feature(untagged_unions))]

//! This crate provides a macro that generates a trait-union type: a trait object type
//! that can contain any one of a predetermined set of implementors.
//!
//! The generated type does not allocate. The size of the type is the size of the largest
//! variant plus some constant overhead.
//!
//! **NOTE**: As of rustc 1.47, you must enable the `untagged_unions` feature to store
//! non-[Copy] types in a trait-union. This requirement is expected to be lifted by
//! [rust-lang/rust#77547](https://github.com/rust-lang/rust/pull/77547).
//!
//! # Example
//!
//! ```rust
//! # use trait_union::trait_union;
//! # use std::fmt::Display;
//! #
//! trait_union! {
//!     /// Container can contain either an i32, a &'static str, or a bool.
//!     union Container: Display = i32 | &'static str | bool;
//! }
//!
//! let mut container = Container::new(32);
//! assert_eq!(container.to_string(), "32");
//!
//! container = Container::new("Hello World");
//! assert_eq!(container.to_string(), "Hello World");
//!
//! container = Container::new(true);
//! assert_eq!(container.to_string(), "true");
//! ```
//!
//! The generated type has the following interface:
//!
//! ```rust,ignore
//! struct Container {
//!     /* ... */
//! }
//!
//! impl Container {
//!     fn new(value: impl ContainerVariant) -> Self { /* ... */ }
//! }
//!
//! impl Deref for Container {
//!     type Target = dyn Display + 'static;
//!     /* ... */
//! }
//!
//! impl DerefMut for Container {
//!     /* ... */
//! }
//!
//! unsafe trait ContainerVariant: Display + 'static { }
//!
//! unsafe impl ContainerVariant for i32 { }
//! unsafe impl ContainerVariant for &'static str { }
//! unsafe impl ContainerVariant for bool { }
//! ```

/// Macro that generates a trait-union type
///
/// # Syntax
///
/// Each invocation of the macro can generate an arbitrary number of trait-union types.
///
/// The syntax of each declaration is as follows:
///
/// ```txt
/// ATTRIBUTE* VISIBILITY? 'union' NAME GENERICS? ':' TRAIT_BOUNDS ('where' WHERE_CLAUSE)? '=' TYPE ('|' TYPE)* '|'? ';'
/// ```
///
/// `?` denotes an optional segment. `*` denotes 0 or more repetitions.
///
/// For example:
///
/// ```rust,ignore
/// /// MyUnion trait-union
/// pub(crate) union MyUnion<'a, T: 'a>: Debug+'a where T: Debug+Copy = &'a str | Option<T>;
/// ```
///
/// # Trait bounds
///
/// The `TRAIT_BOUNDS` segment denotes the trait that the trait-union will deref to. As
/// such, it must contain at least one trait, at most one non-auto trait, and 0 or more
/// lifetimes.
///
/// For example:
///
/// ```rust,ignore
/// Debug+Send+'a // OK
/// 'a            // Error: No trait
/// Debug+Display // Error: More than one non-auto trait
/// ```
///
/// If you do not provide a lifetime, the `'static` lifetime is added automatically;
/// that is, `Debug` is the same as `Debug+'static`. For example,
///
/// ```rust,ignore
/// union MyUnion<'a>: Debug = &'a str;
/// ```
///
/// will not compile because `&'a str` is not `'static`. Write
///
/// ```rust,ignore
/// union MyUnion<'a>: Debug+'a = &'a str;
/// ```
///
/// instead.
///
/// # Output
///
/// The macro generates a struct with the specified name and an unsafe trait of the same
/// name plus the suffix `Variant`. For example
///
/// ```rust,ignore
/// pub(crate) union MyUnion<'a, T: 'a>: Debug+'a where T: Debug+Copy = &'a str | Option<T>;
/// ```
///
/// generates
///
/// ```rust,ignore
/// pub(crate) struct MyUnion<'a, T: 'a> where T: Debug+Copy {
///     /* ... */
/// }
///
/// pub(crate) unsafe trait MyUnionVariant<'a, T: 'a>: Debug+'a where T: Debug+Copy { }
/// ```
///
/// The trait will automatically be implemented for all specified variants. The struct has
/// a single associated method:
///
/// ```rust,ignore
/// pub(crate) fn new(value: impl MyUnionVariant<'a, T>) -> Self { /* ... */ }
/// ```
///
/// The struct implements `Deref` and `DerefMut` with `Target = dyn Debug+'a`.
pub use trait_union_proc::trait_union;

/// Macro that generates a trait-union type for [Copy] implementors
///
/// This macro is identical to [trait_union] except that
///
/// - all implementors must be [Copy]
/// - the generated type does not implement [Drop]
/// - `#[derive(Copy, Clone)]` can be used as an attribute
pub use trait_union_proc::trait_union_copy;

// Unit tests for the re-exported macros. `compile` also runs the trybuild UI tests
// under `tests/`, so it needs the `trybuild` dev-dependency.
#[cfg(test)]
mod test {
    use super::{trait_union, trait_union_copy};
    use std::{
        fmt,
        fmt::{Display, Formatter},
        mem,
        sync::atomic::{AtomicUsize, Ordering::Relaxed},
    };

    // Deref target of the `U` trait-union below. It has one `&self` method and one
    // `&mut self` method so the tests exercise both `Deref` and `DerefMut`.
    trait F: Display {
        fn len(&self) -> usize;

        fn set_len(&mut self, len: usize);
    }

    // 1-byte, 1-aligned variant: `len` reads the stored value and `set_len` overwrites
    // it (truncating to `u8`).
    impl F for u8 {
        fn len(&self) -> usize {
            *self as usize
        }

        fn set_len(&mut self, len: usize) {
            *self = len as u8;
        }
    }

    // Heap-owning, non-`Copy` variant.
    impl F for String {
        fn len(&self) -> usize {
            // Calls the inherent `String::len`, not `F::len` recursively: inherent
            // methods take precedence over trait methods during method resolution.
            self.len()
        }

        fn set_len(&mut self, len: usize) {
            self.truncate(len);
        }
    }

    // Zero-sized variant with alignment 4 and a `Drop` impl. `test1` uses it to check
    // that the union is aligned for its most-aligned variant and that the stored value
    // is dropped exactly once.
    #[repr(align(4))]
    struct X;
    impl F for X {
        fn len(&self) -> usize {
            !0
        }

        // Not a real "length": writes `len` straight into `X_DROP_COUNT` so a test can
        // observe that a `&mut self` call through `DerefMut` actually reached the `X`.
        fn set_len(&mut self, len: usize) {
            X_DROP_COUNT.store(len, Relaxed);
        }
    }
    // Incremented by `X::drop` and overwritten by `X::set_len`. It is a process-wide
    // static, and `test1` asserts it is still 0 after creating its `X`, so no other test
    // should create or drop an `X`.
    static X_DROP_COUNT: AtomicUsize = AtomicUsize::new(0);
    impl Drop for X {
        fn drop(&mut self) {
            X_DROP_COUNT.fetch_add(1, Relaxed);
        }
    }
    impl Display for X {
        fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
            write!(f, "X")
        }
    }

    // Trait-union over the three variants above. No lifetime is given, so it derefs to
    // `dyn F + 'static`.
    trait_union! {
        union U: F = u8 | String | X;
    }

    // Exercises `Deref`/`DerefMut`, layout and drop behaviour across all variants.
    #[test]
    fn test1() {
        // `&*c` is the contained `u8` (1 byte, align 1). The union itself must be at
        // least 4-aligned because of `X`, and therefore at least 4 bytes.
        let mut c = U::new(33);
        assert_eq!(mem::align_of_val(&*c), 1);
        assert_eq!(mem::size_of_val(&*c), 1);
        assert!(mem::align_of_val(&c) >= 4);
        assert!(mem::size_of_val(&c) >= 4);
        assert_eq!(c.len(), 33);
        c.set_len(22);
        assert_eq!(c.len(), 22);
        // Replace the `u8` with a `String`.
        c = U::new("Hello World".to_string());
        assert_eq!(c.len(), 11);
        c.set_len(5);
        assert_eq!(c.len(), 5);
        assert_eq!(c.to_string(), "Hello");
        // The `X` is moved into the union, replacing the `String`. No `X` has been
        // dropped yet, so the counter is still 0.
        c = U::new(X);
        assert_eq!(mem::align_of_val(&*c), 4);
        assert_eq!(mem::size_of_val(&*c), 0);
        assert_eq!(c.len(), !0);
        assert_eq!(X_DROP_COUNT.load(Relaxed), 0);
        // `X::set_len` overwrites the counter with 2, proving `DerefMut` reached the `X`.
        c.set_len(2);
        assert_eq!(X_DROP_COUNT.load(Relaxed), 2);
        // Dropping the union drops the contained `X` exactly once: 2 -> 3.
        drop(c);
        assert_eq!(X_DROP_COUNT.load(Relaxed), 3);
    }

    // `Option<U>` must be no larger than `U`, i.e. `U` exposes a niche that `Option` can
    // use (presumably a non-null pointer in the generated struct — not visible here).
    #[test]
    fn size() {
        assert_eq!(mem::size_of::<U>(), mem::size_of::<Option<U>>());
    }

    // trybuild UI tests: every file in `tests/compile-fail` is expected to fail to
    // compile, and every file in `tests/pass` is expected to compile and run.
    #[test]
    fn compile() {
        let t = trybuild::TestCases::new();
        t.compile_fail("tests/compile-fail/*.rs");
        t.pass("tests/pass/*.rs");
    }

    // `trait_union_copy!` accepts `#[derive(Copy, Clone)]`. `let v = u;` copies rather
    // than moves, so `u` is still usable afterwards. This local `U` shadows the
    // module-level one.
    #[test]
    fn copy() {
        trait_union_copy! {
            #[derive(Copy, Clone)]
            union U: Display = u8 | &'static str;
        }

        let u = U::new("test");
        let v = u;
        assert_eq!(u.to_string(), v.to_string());
    }

    // Compile-time check that `U` implements `Sync`.
    #[test]
    fn assert_sync() {
        let _: &dyn Sync = &U::new(1);
    }
}