sine_macro/
lib.rs

1/*
2 * Copyright (c) 2025 Tomi Leppänen
3 * SPDX-License-Identifier: MIT
4 */
5
6//! A procedural macro for generating signed integer sine waves as arrays.
7//!
8//! # Example
9//! ```rust
10//! use sine_macro::sine_wave;
11//!
12//! // Sine wave defined as const item:
13//! sine_wave! {
14//!     const MY_CONST_SINE_WAVE = sine_wave(frequency: 400, rate: 16_000);
15//! }
16//!
17//! // Or as static item:
18//! sine_wave! {
19//!     static MY_STATIC_SINE_WAVE = sine_wave(frequency: 1000, rate: 48_000, len: 48_000);
20//! }
21//!
22//! // Sine wave defined as local variable with default rate of 44,100 Hz:
23//! let wave = sine_wave!(frequency: 800, repeats: 10);
24//! ```
25//!
26//! See the macro documentation for [more examples][crate::sine_wave!#arguments-and-examples].
27
28#![deny(missing_docs)]
29#![forbid(unsafe_code)]
30
31use itertools::Itertools;
32use proc_macro2::{Delimiter, Group, Punct, Spacing, TokenStream, TokenTree};
33use quote::quote;
34use std::f64::consts::PI;
35use std::iter::repeat_n;
36use std::num::{NonZero, NonZeroU32, NonZeroUsize};
37use syn::parse::{Error, Parse, ParseStream};
38use syn::punctuated::Punctuated;
39use syn::token::Paren;
40use syn::{Ident, LitInt, Result, StaticMutability, Visibility, parse_macro_input};
41use syn::{Token, parenthesized};
42
43mod types;
44use crate::types::helpers::{Ident as GetIdent, Literal as GetLiteral, Max as GetMax};
45use crate::types::*;
46
/// Sampling rate in Hz that is used when the `rate` attribute is not given.
const DEFAULT_RATE: u32 = 44_100;
/// Name of the element type that is used when the `type` attribute is not given.
const DEFAULT_TYPE: &str = "i16";
49
50struct SineWaveAttrs {
51    frequency: LitInt,
52    rate: Option<LitInt>,
53    len: Option<LitInt>,
54    repeats: Option<LitInt>,
55    skip: Option<LitInt>,
56    ty: Option<Type>,
57}
58
59impl Parse for SineWaveAttrs {
60    fn parse(input: ParseStream) -> Result<Self> {
61        let attrs = Punctuated::<AttrInput, Token![,]>::parse_terminated(input)?;
62        let mut frequency = None;
63        let mut rate: Option<LitInt> = None;
64        let mut len = None;
65        let mut repeats = None;
66        let mut skip = None;
67        let mut ty = None;
68        for attr in attrs {
69            match attr {
70                AttrInput::Int(IntAttrInput {
71                    name,
72                    value: Int::Frequency(attr_value),
73                    ..
74                }) => {
75                    if frequency.is_none() {
76                        let value: NonZeroU32 = attr_value.base10_parse()?;
77                        if let Some(rate) = &rate {
78                            let rate: NonZeroU32 = rate.base10_parse().unwrap();
79                            if rate < value {
80                                return Err(Error::new_spanned(
81                                    attr_value,
82                                    format_args!(
83                                        "`frequency` should be less than `rate`, which is {} Hz",
84                                        rate
85                                    ),
86                                ));
87                            }
88                        }
89                        frequency = Some(attr_value)
90                    } else {
91                        return Err(Error::new_spanned(name, "`frequency` defined twice"));
92                    }
93                }
94                AttrInput::Int(IntAttrInput {
95                    name,
96                    value: Int::Rate(attr_value),
97                    ..
98                }) => {
99                    if rate.is_none() {
100                        let value: NonZeroU32 = attr_value.base10_parse()?;
101                        if let Some(frequency) = &frequency {
102                            let frequency: NonZeroU32 = frequency.base10_parse().unwrap();
103                            if frequency > value {
104                                return Err(Error::new_spanned(
105                                    attr_value,
106                                    format_args!(
107                                        "`rate` should be more than `frequency`, which is {} Hz",
108                                        frequency
109                                    ),
110                                ));
111                            }
112                        }
113                        rate = Some(attr_value)
114                    } else {
115                        return Err(Error::new_spanned(name, "`rate` defined twice"));
116                    }
117                }
118                AttrInput::Int(IntAttrInput {
119                    name,
120                    value: Int::Len(attr_value),
121                    ..
122                }) => {
123                    if repeats.is_some() {
124                        return Err(Error::new_spanned(
125                            name,
126                            "cannot define both `len` and `repeats`",
127                        ));
128                    } else if len.is_none() {
129                        let _value: NonZeroUsize = attr_value.base10_parse()?;
130                        len = Some(attr_value)
131                    } else {
132                        return Err(Error::new_spanned(name, "`len` defined twice"));
133                    }
134                }
135                AttrInput::Int(IntAttrInput {
136                    name,
137                    value: Int::Repeats(attr_value),
138                    ..
139                }) => {
140                    if len.is_some() {
141                        return Err(Error::new(
142                            name.span(),
143                            "cannot define both `len` and `repeats`",
144                        ));
145                    } else if repeats.is_none() {
146                        let value: usize = attr_value.base10_parse()?;
147                        if value > 0 {
148                            repeats = Some(attr_value)
149                        } else {
150                            return Err(Error::new_spanned(
151                                attr_value,
152                                "`repeats` must be positive",
153                            ));
154                        }
155                    } else {
156                        return Err(Error::new_spanned(name, "`repeats` defined twice"));
157                    }
158                }
159                AttrInput::Int(IntAttrInput {
160                    name,
161                    value: Int::Skip(attr_value),
162                    ..
163                }) => {
164                    if skip.is_none() {
165                        let _value: u32 = attr_value.base10_parse()?;
166                        skip = Some(attr_value);
167                    } else {
168                        return Err(Error::new_spanned(name, "`skip` defined twice"));
169                    }
170                }
171                AttrInput::Type(TypeAttrInput {
172                    name,
173                    value: attr_value,
174                    ..
175                }) => {
176                    if ty.is_none() {
177                        ty = Some(attr_value)
178                    } else {
179                        return Err(Error::new_spanned(name, "`type` defined twice"));
180                    }
181                }
182            };
183        }
184        if let Some(frequency) = frequency {
185            if rate.is_none() {
186                let value: NonZeroU32 = frequency.base10_parse().unwrap();
187                if DEFAULT_RATE < value.get() {
188                    return Err(Error::new_spanned(
189                        frequency,
190                        "`frequency` should be less than `rate`, which is 44100 Hz",
191                    ));
192                }
193            }
194            Ok(SineWaveAttrs {
195                frequency,
196                rate,
197                len,
198                repeats,
199                skip,
200                ty,
201            })
202        } else {
203            Err(Error::new(input.span(), "`frequency` must be defined"))
204        }
205    }
206}
207
208struct Static {
209    vis: Visibility,
210    _static_token: Token![static],
211    mutability: StaticMutability,
212    ident: Ident,
213    _eq_token: Token![=],
214    name: Ident,
215    _paren: Paren,
216    attrs: SineWaveAttrs,
217    _semi_token: Token![;],
218}
219
220impl Parse for Static {
221    fn parse(input: ParseStream) -> Result<Self> {
222        let content;
223        Ok(Static {
224            vis: input.parse()?,
225            _static_token: input.parse()?,
226            mutability: input.parse()?,
227            ident: input.parse()?,
228            _eq_token: input.parse()?,
229            name: {
230                let name: Ident = input.parse()?;
231                if name != "sine_wave" {
232                    return Err(Error::new(
233                        name.span(),
234                        "the identifier must be `sine_wave`",
235                    ));
236                }
237                name
238            },
239            _paren: parenthesized!(content in input),
240            attrs: content.parse()?,
241            _semi_token: input.parse()?,
242        })
243    }
244}
245
246struct Const {
247    vis: Visibility,
248    _const_token: Token![const],
249    ident: Ident,
250    _eq_token: Token![=],
251    name: Ident,
252    _paren: Paren,
253    attrs: SineWaveAttrs,
254    _semi_token: Token![;],
255}
256
257impl Parse for Const {
258    fn parse(input: ParseStream) -> Result<Self> {
259        let content;
260        Ok(Const {
261            vis: input.parse()?,
262            _const_token: input.parse()?,
263            ident: input.parse()?,
264            _eq_token: input.parse()?,
265            name: {
266                let name: Ident = input.parse()?;
267                if name != "sine_wave" {
268                    return Err(Error::new(
269                        name.span(),
270                        "the identifier must be `sine_wave`",
271                    ));
272                }
273                name
274            },
275            _paren: parenthesized!(content in input),
276            attrs: content.parse()?,
277            _semi_token: input.parse()?,
278        })
279    }
280}
281
282enum SineWaveInput {
283    Local(SineWaveAttrs),
284    Static(Static),
285    Const(Const),
286}
287
288impl Parse for SineWaveInput {
289    fn parse(input: ParseStream) -> Result<Self> {
290        if input.peek(Token![pub]) && (input.peek2(Token![static]) || input.peek2(Token![const])) {
291            if input.peek2(Token![static]) {
292                input.parse().map(SineWaveInput::Static)
293            } else {
294                input.parse().map(SineWaveInput::Const)
295            }
296        } else if input.peek(Token![static]) {
297            input.parse().map(SineWaveInput::Static)
298        } else if input.peek(Token![const]) {
299            input.parse().map(SineWaveInput::Const)
300        } else {
301            input.parse().map(SineWaveInput::Local)
302        }
303    }
304}
305
/// Returns the number of samples in one period, i.e. `rate / frequency` truncated to an integer.
///
/// Panics if the truncated value does not fit in `usize`.
fn get_number_of_samples(frequency: f64, rate: f64) -> usize {
    let samples_per_period = (rate / frequency) as u64;
    usize::try_from(samples_per_period).unwrap()
}
309
310impl SineWaveInput {
311    fn get_attrs(&self) -> &SineWaveAttrs {
312        match self {
313            Self::Local(attrs) => attrs,
314            Self::Static(Static { attrs, .. }) => attrs,
315            Self::Const(Const { attrs, .. }) => attrs,
316        }
317    }
318}
319
320/// Generates an array of signed integers for a sine wave.
321///
322/// Sample rate and frequency of the wave can be controlled with `rate` and `frequency`
323/// respectively. [Rounding][crate::sine_wave!#rounding] may apply which can affect the frequency
324/// of the final wave slightly. The length of the array is calculated as `floor(rate / frequency)`.
325///
326/// The array is by default one period long so it can be repeated as many times as needed. If a
327/// specific number of samples or number of repeated periods are required use `len` and `repeats`
328/// respectively. Both cannot be used simultaneously. It is also possible to start the array on a
329/// later point with `skip`. That effectively introduces a phase shift and defaults to zero skipped
330/// samples.
331///
332/// # Arguments and examples
333/// `frequency` selects the frequency of the sine wave, and it is the only required argument.
334/// Negative or zero frequency is not accepted. It also must be sufficiently smaller than the
335/// sampling rate used. See [Nyquist frequency][Nyquist_frequency] for more information. This macro
336/// refuses to generate arrays with only zero values.
337///
338/// [Nyquist_frequency]: https://en.wikipedia.org/wiki/Nyquist_frequency
339///
340/// ```rust
341/// # use sine_macro::sine_wave;
342/// // Sine wave of 1,000 Hz with sampling rate of 44,100 Hz (the default).
343/// let wave = sine_wave!(frequency: 1_000);
344/// ```
345///
346/// `rate` specifies sampling rate of the array. If unspecified, 44,100 Hz is used instead.
347/// Sampling rate must be sufficiently larger than the specified frequency of the wave. See the
348/// information above about `frequency` for more information.
349///
350/// ```rust
351/// # use sine_macro::sine_wave;
352/// // Sine wave of 400 Hz with sampling rate of 48,000 Hz
353/// let wave = sine_wave!(rate: 48_000, frequency: 400);
354/// ```
355///
356/// `type` defines the data type of the array. It can be any of [`i8`], [`i16`] and [`i32`]. Defaults to
357/// [`i16`] when unspecified. The values will always span the whole range of the type sans `MIN`.
358///
359/// ```rust
360/// # use sine_macro::sine_wave;
361/// // Sine wave of 100 Hz with i8 data type, so
362/// let wave = sine_wave!(frequency: 100, type: i8);
363/// ```
364///
365/// `len` specifies how many samples the array must contain. This may cut the wave short on any
366/// period but it can be also used for generating waves of specific duration. E.g. one second long
367/// wave can be generated by setting (sampling) `rate` and `len` to the same value. However the
368/// same effect can be achieved with iterators (see [cycle][core::iter::Iterator::cycle] and
369/// [take][core::iter::Iterator::take]) without storing longer arrays. `len` and `repeats` cannot
370/// be used at the same time.
371///
372/// ```rust
373/// # use sine_macro::sine_wave;
374/// // Sine wave of 440 Hz with sampling rate of 16,000 Hz and one second length
375/// let wave_of_a_second = sine_wave!(frequency: 440, rate: 16_000, len: 16_000);
376/// // Or you can store only the first phase and use iterator tricks:
377/// let wave = sine_wave!(frequency: 440, rate: 16_000);
378/// let iter = wave.iter().cycle().take(16_000);
379/// assert_eq!(wave_of_a_second.len(), iter.count());
380/// ```
381///
382/// `repeats` specifies the number of periods to create. This can be useful if you know that you
383/// need a number of repeats but cannot read the same array repeatedly. Otherwise this can be
384/// achieved with iterators (see [cycle][core::iter::Iterator::cycle] and
385/// [take][core::iter::Iterator::take]) without storing longer arrays. `len` and `repeats` cannot
386/// be used at the same time.
387///
388/// ```rust
389/// # use sine_macro::sine_wave;
390/// // Sine wave of 360 Hz repeated 10 times
391/// let wave_10_repeats = sine_wave!(frequency: 360, rate: 44_100, repeats: 10);
392/// // Or you can store only the first phase and use iterator tricks:
393/// let wave_no_repeats = sine_wave!(frequency: 360, rate: 44_100);
394/// let iter = wave_no_repeats.iter().cycle().take(44_100 / 360 * 10);
395/// assert_eq!(wave_10_repeats.len(), iter.count());
396/// ```
397///
398/// `skip` specifies the number of samples to skip before starting to generate the array. This does
399/// not affect the length of the array but it can be used for introducing a phase shift.
400///
401/// ```rust
402/// # use sine_macro::sine_wave;
403/// // Cosine wave of 400 Hz (cosine is sine with 90 degree phase shift, i.e. it starts 1/4 later)
404/// let wave = sine_wave!(frequency: 400, skip: 100);
405/// ```
406///
407/// # Use with static and const
408/// Since `const` and `static` items must have their types defined and a macro cannot override
409/// that, this provides a syntax similar to
410/// [lazy_static!](https://docs.rs/lazy_static/latest/lazy_static/) for defining `const` or
411/// `static` items. All the same arguments are supported with this syntax.
412///
413/// ```ignore
414/// # use sine_macro::sine_wave;
415/// // Syntax for const:
416/// sine_wave! {
417///     [pub] const NAME = sine_wave(frequency: ...);
418/// }
419/// // or static:
420/// sine_wave! {
421///     [pub] static [mut] NAME = sine_wave(frequency: ...);
422/// }
423/// ```
424///
425/// So, for example, to define public const item for a sine wave:
426/// ```rust
427/// # use sine_macro::sine_wave;
428/// sine_wave! {
429///     pub const BEEP = sine_wave(frequency: 440, rate: 48_000, type: i8);
430/// }
431/// ```
432///
433/// Or static mutable (for whatever reason):
434/// ```rust
435/// # use sine_macro::sine_wave;
436/// sine_wave! {
437///     static mut MUTABLE_BEEP = sine_wave(frequency: 440, rate: 48_000, type: i8);
438/// }
439/// ```
440///
441/// # Rounding
442/// Rounding of the length of the array can affect the sine wave slightly. This always rounds
443/// before generating the wave so the waves will always start from zero and end so that the next
444/// value would be zero, unless `skip` or `len` is changing that.
445///
446/// ```rust
447/// # use sine_macro::sine_wave;
448/// // For example these waves are actually both 441 Hz
449/// let wave_440 = sine_wave!(frequency: 440, rate: 44_100);
450/// let wave_441 = sine_wave!(frequency: 441, rate: 44_100);
451/// assert_eq!(wave_440, wave_441);
452/// ```
453#[proc_macro]
454pub fn sine_wave(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
455    let input = parse_macro_input!(tokens as SineWaveInput);
456    let attrs = input.get_attrs();
457    let ty = attrs.ty.clone();
458    let frequency: NonZeroU32 = attrs.frequency.clone().base10_parse().unwrap();
459    let rate: NonZeroU32 = attrs
460        .rate
461        .clone()
462        .map(|input| input.base10_parse().unwrap())
463        .unwrap_or_else(|| NonZero::new(DEFAULT_RATE).unwrap());
464    let values = get_number_of_samples(frequency.get() as f64, rate.get() as f64);
465    let count;
466    let sine_wave_tokens = {
467        let multiplier = PI * 2_f64 / values as f64;
468        let samples: Vec<_> = (0..values)
469            .map(|i| (i as f64 * multiplier))
470            .map(f64::sin)
471            .map(|value| value * ty.max() as f64)
472            .map(|value| value as i32)
473            .collect();
474        // Just a little sanity check
475        if !samples.iter().any(|x| *x != 0) {
476            return {
477                Error::new_spanned(
478                    &attrs.frequency,
479                    format_args!(
480                        "could not generate sine wave for `rate` of {} Hz and `frequency` of {} Hz",
481                        rate, frequency
482                    ),
483                )
484                .into_compile_error()
485                .into()
486            };
487        }
488        count = attrs
489            .len
490            .clone()
491            .map(|input| input.base10_parse().unwrap())
492            .unwrap_or_else(|| {
493                samples.len()
494                    * attrs
495                        .repeats
496                        .clone()
497                        .map(|input| input.base10_parse().unwrap())
498                        .unwrap_or(1)
499            });
500        let skip = attrs
501            .skip
502            .clone()
503            .map(|input| input.base10_parse().unwrap())
504            .unwrap_or(0);
505        let tokens = TokenStream::from_iter(
506            samples
507                .iter()
508                .cycle()
509                .skip(skip)
510                .take(count)
511                .map(|value| TokenTree::Literal(ty.literal(*value)))
512                .interleave(repeat_n(
513                    TokenTree::from(Punct::new(',', Spacing::Alone)),
514                    count - 1,
515                )),
516        );
517        TokenStream::from(TokenTree::from(Group::new(Delimiter::Bracket, tokens)))
518    };
519    match input {
520        SineWaveInput::Local(_) => sine_wave_tokens.into(),
521        SineWaveInput::Static(item) => {
522            assert_eq!(item.name, "sine_wave");
523            let vis = item.vis;
524            let mutability = item.mutability;
525            let ident = item.ident;
526            let ty = ty.ident();
527            quote! {
528                #vis static #mutability #ident: [#ty; #count] = #sine_wave_tokens;
529            }
530            .into()
531        }
532        SineWaveInput::Const(item) => {
533            assert_eq!(item.name, "sine_wave");
534            let vis = item.vis;
535            let ident = item.ident;
536            let ty = ty.ident();
537            quote! {
538                #vis const #ident: [#ty; #count] = #sine_wave_tokens;
539            }
540            .into()
541        }
542    }
543}