// servo_pio/pwm_cluster.rs

1use arrayvec::ArrayVec;
2use core::any::{Any, TypeId};
3use core::cell::RefCell;
4use core::marker::PhantomData;
5use cortex_m::singleton;
6use critical_section::Mutex;
7use defmt::Format;
8use fugit::HertzU32;
9use pio::Program;
10use pio_proc::pio_file;
11use rp2040_hal::clocks::SystemClock;
12use rp2040_hal::dma::double_buffer::{
13    Config as DoubleBufferingConfig, ReadNext, Transfer as DoubleBuffering,
14};
15use rp2040_hal::dma::{Channel, ChannelIndex, ReadTarget, SingleChannel};
16use rp2040_hal::gpio::{DynPinId, Function, Pin, PullNone};
17use rp2040_hal::pio::{
18    PIOExt, PinDir, Running, StateMachine, StateMachineIndex, Tx, UninitStateMachine, PIO,
19};
20use rp2040_hal::{self, Clock};
21
22use crate::initialize_array;
23
24pub type DynPin<F> = Pin<DynPinId, F, PullNone>;
25type PinData = &'static mut Sequence;
26type TxTransfer<C1, C2, P, SM> =
27    DoubleBuffering<Channel<C1>, Channel<C2>, PinData, Tx<(P, SM)>, ReadNext<PinData>>;
28type WaitingTxTransfer<C1, C2, P, SM> =
29    DoubleBuffering<Channel<C1>, Channel<C2>, PinData, Tx<(P, SM)>, ()>;
30
/// Number of buffers in each sequence list (regular and looping).
const NUM_BUFFERS: usize = 3;
/// Capacity of a transition buffer: 64 is the maximum number of single rises and falls for
/// 32 channels within one looping time period.
const BUFFER_SIZE: usize = 64;
/// The number of dummy transitions to insert into the data to delay the DMA interrupt.
/// If zero, no loading zone is used.
const LOADING_ZONE_SIZE: u32 = 3;
/// The number of levels before the top at which to insert the loading zone.
///
/// Smaller values make the DMA interrupt trigger closer to the time the data is needed, but
/// risk stalling the PIO if the interrupt is delayed by other processes.
const LOADING_ZONE_POSITION: u32 = 55;
/// Upper bound on the wrap (`top`) value chosen by `calculate_pwm_factors`.
const MAX_PWM_CLUSTER_WRAP: u64 = u16::MAX as u64;
/// The system clock is divided by this to get the PWM source frequency in
/// `calculate_pwm_factors`. KEEP IN SYNC WITH pwm.pio!
const PWM_CLUSTER_CYCLES: u64 = 5;
44
45pub struct GlobalState<C1, C2, P, SM>
46where
47    C1: ChannelIndex,
48    C2: ChannelIndex,
49    P: PIOExt,
50    SM: StateMachineIndex,
51{
52    handler: Option<InterruptHandler<C1, C2, P, SM>>,
53    indices: Mutex<RefCell<Indices>>,
54    sequences: Mutex<RefCell<SequenceList>>,
55    sequence1: Sequence,
56    sequence2: Sequence,
57    sequence3: Sequence,
58    loop_sequences: Mutex<RefCell<SequenceList>>,
59    loop_sequence1: Sequence,
60    loop_sequence2: Sequence,
61    loop_sequence3: Sequence,
62}
63
/// A trait to abstract over [GlobalState] structs with different generic parameters.
///
/// `Any` is a supertrait so that a `dyn Handler` can be downcast back to its concrete type.
pub trait Handler: Any {
    /// Service a DMA interrupt for this state, if it is the one that raised it.
    fn try_next_dma_sequence(&mut self);
}

// Upcasting `dyn Handler` to `dyn Any` is not available, so the mutable downcasting helpers of
// `dyn Any` are reproduced here.
impl dyn Handler {
    /// Returns the concrete value as `&mut T` if it is a `T`, otherwise `None`.
    #[inline]
    pub(crate) fn downcast_mut<T: Any>(&mut self) -> Option<&mut T> {
        if !self.is::<T>() {
            return None;
        }
        // SAFETY: `is` just confirmed the concrete type is `T`. That check is reliable because
        // `Any` (and therefore its `type_id`) is implemented for every type by the blanket impl;
        // no other impl can exist.
        Some(unsafe { self.downcast_mut_unchecked() })
    }

    /// Reinterprets this trait object as `&mut T` without checking its type.
    ///
    /// # Safety
    /// The caller guarantees that the concrete type behind `self` is `T`.
    #[inline]
    pub(crate) unsafe fn downcast_mut_unchecked<T: Any>(&mut self) -> &mut T {
        debug_assert!(self.is::<T>());
        let erased: *mut dyn Handler = self;
        // SAFETY: the caller guarantees the pointee is a `T`; dropping the vtable half of the
        // fat pointer leaves a valid pointer to it.
        unsafe { &mut *erased.cast::<T>() }
    }

    /// Returns `true` if the concrete type behind this trait object is `T`.
    #[inline]
    pub(crate) fn is<T: Any>(&self) -> bool {
        // `type_id` is dispatched through the vtable, yielding the concrete type's id.
        let concrete = self.type_id();
        concrete == TypeId::of::<T>()
    }
}
105
106impl<C1, C2, P, SM> Handler for GlobalState<C1, C2, P, SM>
107where
108    C1: ChannelIndex + 'static,
109    C2: ChannelIndex + 'static,
110    P: PIOExt + 'static,
111    SM: StateMachineIndex + 'static,
112{
113    fn try_next_dma_sequence(&mut self) {
114        if let Some(ref mut handler) = self.handler {
115            handler.try_next_dma_sequence();
116        }
117    }
118}
119
120/// A struct that can be used to store multiple [GlobalState]s.
121pub struct GlobalStates<const NUM_CHANNELS: usize> {
122    pub states: [Option<&'static mut dyn Handler>; NUM_CHANNELS],
123}
124
125impl<const NUM_CHANNELS: usize> GlobalStates<NUM_CHANNELS> {
126    /// Returns global state if types match.
127    pub(crate) fn get_mut<C1, C2, P, SM, F>(
128        &mut self,
129        _channels: &mut (Channel<C1>, Channel<C2>),
130        ctor: F,
131    ) -> Option<*mut GlobalState<C1, C2, P, SM>>
132    where
133        C1: ChannelIndex + 'static,
134        C2: ChannelIndex + 'static,
135        P: PIOExt + 'static,
136        SM: StateMachineIndex + 'static,
137        F: FnOnce() -> &'static mut GlobalState<C1, C2, P, SM>,
138    {
139        let entry = &mut self.states[<C1 as ChannelIndex>::id() as usize];
140        if entry.is_none() {
141            *entry = Some(ctor())
142        }
143
144        let state: *mut &'static mut dyn Handler = entry.as_mut().unwrap() as *mut _;
145        // Safety: Already reference to static mut, which is already unsafe.
146        let state: &'static mut dyn Handler = unsafe { *state };
147        <dyn Handler>::downcast_mut::<GlobalState<C1, C2, P, SM>>(state).map(|d| d as *mut _)
148    }
149}
150
/// Indices into the sequence lists, shared between [PwmCluster::load_pwm] (the writer) and the
/// DMA interrupt handler (the reader).
struct Indices {
    /// Index of the sequence most recently written.
    last_written_index: usize,
    /// Index of the sequence currently being read by the DMA handler.
    read_index: usize,
}

impl Indices {
    /// Both indices start at zero, i.e. nothing has been written since the last read.
    fn new() -> Self {
        let (last_written_index, read_index) = (0, 0);
        Self {
            last_written_index,
            read_index,
        }
    }
}
167
168/// Builder for array of [Sequence]s.
169struct SequenceListBuilder;
170type SequenceList = [Option<*mut Sequence>; NUM_BUFFERS];
171
172impl SequenceListBuilder {
173    /// Construct a list of sequences with the first transition in each defaulted to a delay of 10.
174    fn build() -> SequenceList {
175        initialize_array::<Option<*mut Sequence>, NUM_BUFFERS>(|| None)
176    }
177}
178
179pub struct InterruptHandler<C1, C2, P, SM>
180where
181    C1: ChannelIndex,
182    C2: ChannelIndex,
183    P: PIOExt,
184    SM: StateMachineIndex,
185{
186    sequences: &'static Mutex<RefCell<SequenceList>>,
187    loop_sequences: &'static Mutex<RefCell<SequenceList>>,
188    indices: &'static Mutex<RefCell<Indices>>,
189    tx_transfer: Option<TxTransfer<C1, C2, P, SM>>,
190    /// Track where the last sequence came from so we can put it back.
191    /// None means it came from the singletons, so don't worry
192    last_was_loop: Option<bool>,
193}
194
195/// DMA interrupt handler. Call directly from DMA_IRQ_0 or DMA_IRQ_1.
196#[inline]
197pub fn dma_interrupt<const NUM_PINS: usize>(global_state: &'static mut GlobalStates<NUM_PINS>) {
198    for state in global_state.states.iter_mut().flatten() {
199        state.try_next_dma_sequence();
200    }
201}
202
203impl<C1, C2, P, SM> InterruptHandler<C1, C2, P, SM>
204where
205    C1: ChannelIndex,
206    C2: ChannelIndex,
207    P: PIOExt,
208    SM: StateMachineIndex,
209{
210    /// Try to setup the next dma sequence..
211    fn try_next_dma_sequence(&mut self) {
212        let (tx_buf, tx) = {
213            if let Some(mut tx_transfer) = self.tx_transfer.take() {
214                // Check the interrupt to clear it if this is the transfer that's ready.
215                if tx_transfer.check_irq0() && tx_transfer.is_done() {
216                    tx_transfer.wait()
217                } else {
218                    // Either this wasn't the transfer that triggered the interrupt, or it did, but
219                    // for some reason it's not ready. Place it back so we can try again next time.
220                    self.tx_transfer = Some(tx_transfer);
221                    return;
222                }
223            } else {
224                // This interrupt handler has not been configured with a handler yet.
225                return;
226            }
227        };
228        self.next_dma_sequence(tx_buf, tx);
229    }
230
231    fn next_dma_sequence(
232        &mut self,
233        tx_buf: &'static mut Sequence,
234        tx: WaitingTxTransfer<C1, C2, P, SM>,
235    ) {
236        critical_section::with(|cs| {
237            let mut indices = self.indices.borrow(cs).borrow_mut();
238            // If there was a write since the last read...
239            let next_buf = if indices.last_written_index != indices.read_index {
240                if let Some(last_was_loop) = self.last_was_loop {
241                    // Put the sequence back before updating.
242                    if last_was_loop {
243                        self.loop_sequences.borrow(cs).borrow_mut()[indices.read_index] =
244                            Some(tx_buf);
245                    } else {
246                        self.sequences.borrow(cs).borrow_mut()[indices.read_index] = Some(tx_buf);
247                    }
248                }
249
250                // Update the read index and use sequences.
251                indices.read_index = indices.last_written_index;
252                self.last_was_loop = Some(false);
253                self.sequences.borrow(cs).borrow_mut()[indices.read_index]
254                    .take()
255                    .unwrap()
256            } else {
257                if let Some(false) = self.last_was_loop {
258                    // Put the sequence back before updating.
259                    self.sequences.borrow(cs).borrow_mut()[indices.read_index] = Some(tx_buf);
260                }
261
262                // Otherwise just use the loop sequences.
263                if let Some(sequence) =
264                    self.loop_sequences.borrow(cs).borrow_mut()[indices.read_index].take()
265                {
266                    self.last_was_loop = Some(true);
267                    // We took ownership from the sequence.
268                    sequence
269                } else {
270                    // We already have the buffer, so just re-use it.
271                    tx_buf as *mut _
272                }
273            };
274
275            // Start the next transfer.
276            // Safety: We took ownership from the sequence list, so we're the only
277            // location that has access to this unique reference.
278            self.tx_transfer = Some(tx.read_next(unsafe { &mut *next_buf }));
279        });
280    }
281}
282
283/// A type to manage a cluster of PWM signals.
284pub struct PwmCluster<const NUM_PINS: usize, P, SM>
285where
286    P: PIOExt,
287    SM: StateMachineIndex,
288{
289    sm: Option<StateMachine<(P, SM), Running>>,
290    channel_to_pin_map: [u8; NUM_PINS],
291    channels: [ChannelState; NUM_PINS],
292    sequences: &'static Mutex<RefCell<SequenceList>>,
293    loop_sequences: &'static Mutex<RefCell<SequenceList>>,
294    indices: &'static Mutex<RefCell<Indices>>,
295    transitions: [TransitionData; BUFFER_SIZE],
296    looping_transitions: [TransitionData; BUFFER_SIZE],
297    loading_zone: bool,
298    top: u32,
299}
300
301/// A type to build a [PwmCluster]
302pub struct PwmClusterBuilder<const NUM_PINS: usize, P> {
303    pin_mask: u32,
304    #[cfg(feature = "debug_pio")]
305    side_set_pin: u8,
306    channel_to_pin_map: [u8; NUM_PINS],
307    _phantom: PhantomData<P>,
308}
309
310impl<const NUM_PINS: usize, P> PwmClusterBuilder<NUM_PINS, P> {
311    /// Construct a new [PwmClusterBuilder].
312    pub fn new() -> Self {
313        Self {
314            pin_mask: 0,
315            #[cfg(feature = "debug_pio")]
316            side_set_pin: 0,
317            channel_to_pin_map: [0; NUM_PINS],
318            _phantom: PhantomData,
319        }
320    }
321
322    /// Set the pin_base for this cluster. NUM_PINS will be used to determine the pin count.
323    pub fn pin_base<F>(mut self, base_pin: DynPin<F>) -> Self
324    where
325        P: PIOExt<PinFunction = F>,
326        F: Function,
327    {
328        let base_pin = base_pin.id().num;
329        for (i, pin_map) in self.channel_to_pin_map.iter_mut().enumerate() {
330            let pin_id = base_pin + i as u8;
331            self.pin_mask |= 1 << pin_id;
332            *pin_map = pin_id;
333        }
334        self
335    }
336
337    /// Set the pins directly from the `pins` parameter.
338    pub fn pins<F>(mut self, pins: &[DynPin<F>; NUM_PINS]) -> Self
339    where
340        P: PIOExt<PinFunction = F>,
341        F: Function,
342    {
343        for (pin, pin_map) in pins.iter().zip(self.channel_to_pin_map.iter_mut()) {
344            let pin_id = pin.id().num;
345            self.pin_mask |= 1 << pin_id;
346            *pin_map = pin_id;
347        }
348
349        self
350    }
351
352    /// Set the side pin for debugging.
353    ///
354    /// This method can be enabled with the "debug_pio" feature.
355    #[cfg(feature = "debug_pio")]
356    pub fn side_pin<F>(mut self, side_set_pin: &DynPin<F>) -> Self
357    where
358        P: PIOExt<PinFunction = F>,
359        F: Function,
360    {
361        self.side_set_pin = side_set_pin.id().num;
362        self
363    }
364
365    /// Initialize [GlobalState]
366    pub(crate) fn prep_global_state<C1, C2, SM>(
367        global_state: &'static mut Option<GlobalState<C1, C2, P, SM>>,
368    ) -> &'static mut GlobalState<C1, C2, P, SM>
369    where
370        C1: ChannelIndex,
371        C2: ChannelIndex,
372        P: PIOExt,
373        SM: StateMachineIndex,
374    {
375        if global_state.is_none() {
376            *global_state = Some(GlobalState {
377                handler: None,
378                indices: Mutex::new(RefCell::new(Indices::new())),
379                sequences: Mutex::new(RefCell::new(SequenceListBuilder::build())),
380                sequence1: Sequence::new_for_list(),
381                sequence2: Sequence::new_for_list(),
382                sequence3: Sequence::new_for_list(),
383                loop_sequences: Mutex::new(RefCell::new(SequenceListBuilder::build())),
384                loop_sequence1: Sequence::new_for_list(),
385                loop_sequence2: Sequence::new_for_list(),
386                loop_sequence3: Sequence::new_for_list(),
387            });
388            {
389                let state = (*global_state).as_mut().unwrap();
390                critical_section::with(|cs| {
391                    // Safety: These self-referential fields are ok because the global state is
392                    // guaranteed to be stored in a static.
393                    let mut sequences = state.sequences.borrow(cs).borrow_mut();
394                    sequences[0] = Some(&mut state.sequence1 as *mut _);
395                    sequences[1] = Some(&mut state.sequence2 as *mut _);
396                    sequences[2] = Some(&mut state.sequence3 as *mut _);
397                    let mut loop_sequences = state.loop_sequences.borrow(cs).borrow_mut();
398                    loop_sequences[0] = Some(&mut state.loop_sequence1 as *mut _);
399                    loop_sequences[1] = Some(&mut state.loop_sequence2 as *mut _);
400                    loop_sequences[2] = Some(&mut state.loop_sequence3 as *mut _);
401                });
402            }
403        }
404        global_state.as_mut().unwrap()
405    }
406
407    /// Build a PwmCluster.
408    ///
409    /// # Safety
410    /// Caller must ensure that global_state is not being read/mutated anywhere else.
411    #[allow(clippy::too_many_arguments)]
412    pub unsafe fn build<C1, C2, SM, F>(
413        self,
414        servo_pins: [DynPin<F>; NUM_PINS],
415        pio: &mut PIO<P>,
416        sm: UninitStateMachine<(P, SM)>,
417        mut dma_channels: (Channel<C1>, Channel<C2>),
418        sys_clock: &SystemClock,
419        global_state: *mut GlobalState<C1, C2, P, SM>,
420    ) -> PwmCluster<NUM_PINS, P, SM>
421    where
422        C1: ChannelIndex,
423        C2: ChannelIndex,
424        P: PIOExt<PinFunction = F>,
425        F: Function,
426        SM: StateMachineIndex,
427    {
428        let program = Self::pio_program();
429        let installed = pio.install(&program).unwrap();
430        const DESIRED_CLOCK_HZ: u32 = 500_000;
431        let sys_hz = sys_clock.freq().to_Hz();
432        let (int, frac) = (
433            (sys_hz / DESIRED_CLOCK_HZ) as u16,
434            (sys_hz as u64 * 256 / DESIRED_CLOCK_HZ as u64) as u8,
435        );
436        let (mut sm, _, tx) = {
437            let mut builder = rp2040_hal::pio::PIOBuilder::from_program(installed);
438            #[cfg(not(feature = "debug_pio"))]
439            {
440                builder = builder.out_pins(0, 32)
441            }
442            #[cfg(feature = "debug_pio")]
443            {
444                builder = builder
445                    .out_pins(0, self.side_set_pin)
446                    .side_set_pin_base(self.side_set_pin)
447            }
448            builder.clock_divisor_fixed_point(int, frac).build(sm)
449        };
450        {
451            let iter_a = servo_pins.into_iter().map(|pin| pin.id().num);
452            let iter;
453            #[cfg(not(feature = "debug_pio"))]
454            {
455                iter = iter_a;
456            }
457            #[cfg(feature = "debug_pio")]
458            {
459                iter = iter_a.chain(Some(self.side_set_pin));
460            }
461            sm.set_pindirs(iter.map(|id| (id, PinDir::Output)));
462        }
463
464        let sequence = Sequence::new_for_list();
465
466        // Safety: caller guarantees that global_state is not being read/mutated anywhere else.
467        let mut interrupt_handler = unsafe {
468            InterruptHandler {
469                sequences: &(*global_state).sequences,
470                loop_sequences: &(*global_state).loop_sequences,
471                indices: &(*global_state).indices,
472                tx_transfer: None,
473                last_was_loop: None,
474            }
475        };
476
477        let tx_buf = singleton!(: Sequence = sequence.clone()).unwrap();
478        dma_channels.0.enable_irq0();
479        dma_channels.1.enable_irq0();
480        let tx = DoubleBufferingConfig::new(dma_channels, tx_buf, tx).start();
481        let tx_buf = singleton!(: Sequence = sequence).unwrap();
482        interrupt_handler.next_dma_sequence(tx_buf, tx);
483
484        // Safety: caller guarantees that global_state is not being read/mutated anywhere else.
485        unsafe { (*global_state).handler = Some(interrupt_handler) };
486
487        let sm = sm.start();
488
489        // Safety: caller guarantees that global_state is not being read/mutated anywhere else.
490        unsafe {
491            PwmCluster::new(
492                sm,
493                self.channel_to_pin_map,
494                &(*global_state).indices,
495                &(*global_state).sequences,
496                &(*global_state).loop_sequences,
497            )
498        }
499    }
500
501    /// Get PIO program data.
502    #[cfg(feature = "debug_pio")]
503    fn pio_program() -> Program<32> {
504        #[allow(non_snake_case)]
505        let pwm_program = pio_file!("./src/pwm.pio", select_program("debug_pwm_cluster"));
506        pwm_program.program
507    }
508
509    /// Get PIO program data.
510    #[cfg(not(feature = "debug_pio"))]
511    fn pio_program() -> Program<32> {
512        #[allow(non_snake_case)]
513        let pwm_program = pio_file!("./src/pwm.pio", select_program("pwm_cluster"));
514        pwm_program.program
515    }
516}
517
518impl<const NUM_PINS: usize, P> Default for PwmClusterBuilder<NUM_PINS, P> {
519    fn default() -> Self {
520        Self::new()
521    }
522}
523
524impl<const NUM_PINS: usize, P, SM> PwmCluster<NUM_PINS, P, SM>
525where
526    P: PIOExt,
527    SM: StateMachineIndex,
528{
529    /// The number of channels used by this cluster.
530    pub const CHANNEL_COUNT: usize = NUM_PINS;
531    /// The number of channel pairs used by this cluster.
532    pub const CHANNEL_PAIR_COUNT: usize = NUM_PINS / 2;
533
534    /// Get a [PwmClusterBuilder].
535    pub fn builder() -> PwmClusterBuilder<NUM_PINS, P> {
536        PwmClusterBuilder::new()
537    }
538
539    /// Construct the cluster.
540    fn new(
541        sm: StateMachine<(P, SM), Running>,
542        channel_to_pin_map: [u8; NUM_PINS],
543        indices: &'static Mutex<RefCell<Indices>>,
544        sequences: &'static Mutex<RefCell<SequenceList>>,
545        loop_sequences: &'static Mutex<RefCell<SequenceList>>,
546    ) -> Self {
547        let channels = [ChannelState::new(); NUM_PINS];
548        let transitions = [TransitionData::default(); BUFFER_SIZE];
549        let looping_transitions = [TransitionData::default(); BUFFER_SIZE];
550
551        Self {
552            sm: Some(sm),
553            channel_to_pin_map,
554            channels,
555            sequences,
556            loop_sequences,
557            loading_zone: false,
558            top: 0,
559            indices,
560            transitions,
561            looping_transitions,
562        }
563    }
564
565    /// Calculate the pwm factors (div and frac) from the system clock and desired frequency.
566    pub fn calculate_pwm_factors(system_clock_hz: HertzU32, freq: f32) -> Option<(u32, u32)> {
567        let source_hz = system_clock_hz.to_Hz() as u64 / PWM_CLUSTER_CYCLES;
568
569        // Check the provided frequency is valid
570        if (freq >= 0.01) && (freq <= (source_hz >> 1) as f32) {
571            let mut div256_top = ((source_hz << 8) as f32 / freq) as u64;
572            let mut top: u64 = 1;
573
574            loop {
575                // Try a few small prime factors to get close to the desired frequency.
576                if (div256_top >= (11 << 8))
577                    && (div256_top % 11 == 0)
578                    && (top * 11 <= MAX_PWM_CLUSTER_WRAP)
579                {
580                    div256_top /= 11;
581                    top *= 11;
582                } else if (div256_top >= (7 << 8))
583                    && (div256_top % 7 == 0)
584                    && (top * 7 <= MAX_PWM_CLUSTER_WRAP)
585                {
586                    div256_top /= 7;
587                    top *= 7;
588                } else if (div256_top >= (5 << 8))
589                    && (div256_top % 5 == 0)
590                    && (top * 5 <= MAX_PWM_CLUSTER_WRAP)
591                {
592                    div256_top /= 5;
593                    top *= 5;
594                } else if (div256_top >= (3 << 8))
595                    && (div256_top % 3 == 0)
596                    && (top * 3 <= MAX_PWM_CLUSTER_WRAP)
597                {
598                    div256_top /= 3;
599                    top *= 3;
600                } else if (div256_top >= (2 << 8)) && (top * 2 <= MAX_PWM_CLUSTER_WRAP) {
601                    div256_top /= 2;
602                    top *= 2;
603                } else {
604                    break;
605                }
606            }
607
608            // Only return valid factors if the divisor is actually achievable.
609            if div256_top >= 256 && div256_top <= ((u8::MAX as u64) << 8) {
610                Some((top as u32, div256_top as u32))
611            } else {
612                None
613            }
614        } else {
615            None
616        }
617    }
618
619    /// Set the clock divisor for this cluster.
620    pub fn clock_divisor_fixed_point(&mut self, div: u16, frac: u8) {
621        if let Some(sm) = self.sm.take() {
622            let mut sm = sm.stop();
623            sm.clock_divisor_fixed_point(div, frac);
624            self.sm = Some(sm.start());
625        }
626    }
627
628    /// The configured pwm level for the supplied `channel`.
629    pub fn channel_level(&self, channel: u8) -> Result<u32, PwmError> {
630        self.channels
631            .get(channel as usize)
632            .map(|c| c.level)
633            .ok_or(PwmError::MissingChannel)
634    }
635
636    /// Set the pwm level for the supplied `channel`. If `load` is true, update
637    /// the cluster's data in the PIO state machine.
638    pub fn set_channel_level(
639        &mut self,
640        channel: u8,
641        level: u32,
642        load: bool,
643    ) -> Result<(), PwmError> {
644        let res = self
645            .channels
646            .get_mut(channel as usize)
647            .map(|c| c.level = level)
648            .ok_or(PwmError::MissingChannel);
649        if load {
650            self.load_pwm();
651        }
652        res
653    }
654
655    pub fn channel_offset(&self, channel: u8) -> Result<u32, PwmError> {
656        self.channels
657            .get(channel as usize)
658            .map(|c| c.offset)
659            .ok_or(PwmError::MissingChannel)
660    }
661
662    /// Set the start offset for the supplied `channel`. If `load` is true,
663    /// update the cluster's data in the PIO state machine.
664    pub fn set_channel_offset(
665        &mut self,
666        channel: u8,
667        offset: u32,
668        load: bool,
669    ) -> Result<(), PwmError> {
670        self.channels
671            .get_mut(channel as usize)
672            .map(|c| c.offset = offset)
673            .ok_or(PwmError::MissingChannel)?;
674        if load {
675            self.load_pwm();
676        }
677        Ok(())
678    }
679
680    pub fn channel_polarity(&self, channel: u8) -> Result<bool, PwmError> {
681        self.channels
682            .get(channel as usize)
683            .map(|c| c.polarity)
684            .ok_or(PwmError::MissingChannel)
685    }
686
687    /// Set the polarity for the supplied `channel`. If `load` is true, update
688    /// the cluster's data in the PIO state machine.
689    pub fn set_channel_polarity(
690        &mut self,
691        channel: u8,
692        polarity: bool,
693        load: bool,
694    ) -> Result<(), PwmError> {
695        self.channels
696            .get_mut(channel as usize)
697            .map(|c| c.polarity = polarity)
698            .ok_or(PwmError::MissingChannel)?;
699        if load {
700            self.load_pwm()
701        }
702        Ok(())
703    }
704
705    /// Gets the top value for the pwm counter. This is equivalent to [`Slice::get_top`] when using
706    /// [pac::PWM].
707    ///
708    /// [`Slice::get_top`]: rp2040_hal::pwm::Slice
709    pub fn top(&self) -> u32 {
710        self.top
711    }
712
713    /// Set the top value for the supplied `channel`. If `load` is true, update
714    /// the cluster's data in the PIO state machine.
715    pub fn set_top(&mut self, top: u32, load_pwm: bool) {
716        self.top = top.max(1); // Cannot have a wrap of zero.
717        if load_pwm {
718            self.load_pwm();
719        }
720    }
721
722    /// Load the pwm data into the PIO state machine.
723    pub fn load_pwm(&mut self) {
724        let mut data_size = 0;
725        let mut looping_data_size = 0;
726
727        // Start with all pins low.
728        let mut pin_states = 0;
729
730        // Check if the data we last wrote has been picked up by the DMA yet.
731        let read_since_last_write = critical_section::with(|cs| {
732            let indices = self.indices.borrow(cs).borrow();
733            indices.read_index == indices.last_written_index
734        });
735
736        // Go through each channel that we are assigned to.
737        for (channel_idx, state) in self.channels.iter_mut().enumerate() {
738            // Invert this channel's initial state of its polarity invert is set.
739            if state.polarity {
740                pin_states |= 1 << self.channel_to_pin_map[channel_idx];
741            }
742
743            let channel_start = state.offset;
744            let channel_end = state.offset + state.level;
745            let channel_wrap_end = channel_end % self.top;
746
747            // If the data as been read, copy the channel overruns from that sequence.
748            // Otherwise, keep the ones we previously stored.
749            if read_since_last_write {
750                // This condition was added to deal with cases of load_pwm() being called multiple
751                // times between DMA reads, and thus losing memory of the previous sequence's
752                // overruns.
753                state.overrun = state.next_overrun;
754            }
755
756            // Always clear the next channel overruns, as we are loading new data.
757            state.next_overrun = 0;
758
759            // Did the previous sequence overrun the top level?
760            if state.overrun > 0 {
761                // Flip the initial state so the pin starts "high" (or "low" if polarity inverted).
762                pin_states ^= 1 << self.channel_to_pin_map[channel_idx];
763
764                // Is our end level before our start level?
765                if channel_wrap_end < channel_start {
766                    // Yes, so add a transition to "low" (or "high" if polarity inverted) at the end
767                    // level, rather than the overrun (so our pulse takes effect earlier).
768                    Self::sorted_insert(
769                        &mut self.transitions,
770                        &mut data_size,
771                        TransitionData::new(channel_idx as u8, channel_wrap_end, state.polarity),
772                    )
773                } else if state.overrun < channel_start {
774                    // No, so add a transition to "low" (or "high" if polarity inverted) at the
775                    // overrun instead.
776                    Self::sorted_insert(
777                        &mut self.transitions,
778                        &mut data_size,
779                        TransitionData::new(channel_idx as u8, state.overrun, state.polarity),
780                    )
781                }
782            }
783
784            // Is the state level greater than zero, and the start level within the top?
785            if state.level > 0 && channel_start < self.top {
786                // Add a transition to "high" (or "low" if polarity inverted) at the start level.
787                Self::sorted_insert(
788                    &mut self.transitions,
789                    &mut data_size,
790                    TransitionData {
791                        channel: channel_idx as u8,
792                        level: channel_start,
793                        state: !state.polarity,
794                        dummy: false,
795                    },
796                );
797                Self::sorted_insert(
798                    &mut self.looping_transitions,
799                    &mut looping_data_size,
800                    TransitionData {
801                        channel: channel_idx as u8,
802                        level: channel_start,
803                        state: !state.polarity,
804                        dummy: false,
805                    },
806                );
807
808                // If the channel has overrun the top level, record by how much.
809                if channel_wrap_end < channel_start {
810                    state.next_overrun = channel_wrap_end;
811                }
812            }
813
814            // Is the state level within the top?
815            if state.level < self.top {
816                // Is the end level within the wrap too?
817                if channel_end < self.top {
818                    // Add a transition to "low" (or "high" if the polarity inverted) at the end level.
819                    Self::sorted_insert(
820                        &mut self.transitions,
821                        &mut data_size,
822                        TransitionData::new(channel_idx as u8, channel_end, state.polarity),
823                    );
824                }
825
826                // Add a transition to "low" (or "high" if polarity inverted) at the top level.
827                Self::sorted_insert(
828                    &mut self.looping_transitions,
829                    &mut looping_data_size,
830                    TransitionData::new(channel_idx as u8, channel_wrap_end, state.polarity),
831                );
832            }
833        }
834
835        if self.loading_zone {
836            // Introduce "Loading Zone" transitions to the end of the sequence to
837            // prevent the DMA interrupt firing many milliseconds before the sequence ends.
838            let zone_inserts = LOADING_ZONE_SIZE.min(self.top - LOADING_ZONE_POSITION);
839            for i in (zone_inserts + LOADING_ZONE_POSITION)..LOADING_ZONE_POSITION {
840                Self::sorted_insert(
841                    &mut self.transitions,
842                    &mut data_size,
843                    TransitionData::with_level(self.top - i),
844                );
845                Self::sorted_insert(
846                    &mut self.looping_transitions,
847                    &mut looping_data_size,
848                    TransitionData::with_level(self.top - i),
849                );
850            }
851        }
852
853        // Read | Last W = Write
854        // 0    | 0      = 1 (or 2)
855        // 0    | 1      = 2
856        // 0    | 2      = 1
857        // 1    | 0      = 2
858        // 1    | 1      = 2 (or 0)
859        // 1    | 2      = 0
860        // 2    | 0      = 1
861        // 2    | 1      = 0
862        // 2    | 2      = 0 (or 1)
863
864        // Choose the write index based on the read and last written indices (using the above table).
865        let write_index = critical_section::with(|cs| {
866            let indices = self.indices.borrow(cs).borrow();
867            let write_index = (indices.read_index + 1) % NUM_BUFFERS;
868            if write_index == indices.last_written_index {
869                (write_index + 1) % NUM_BUFFERS
870            } else {
871                write_index
872            }
873        });
874
875        self.populate_sequence(
876            TransitionType::Normal,
877            data_size,
878            write_index,
879            &mut pin_states,
880        );
881        self.populate_sequence(
882            TransitionType::Loop,
883            looping_data_size,
884            write_index,
885            &mut pin_states,
886        );
887
888        critical_section::with(|cs| {
889            let mut indices = self.indices.borrow(cs).borrow_mut();
890            indices.last_written_index = write_index;
891        });
892    }
893
894    /// Insert `data` into `transitions` in sorted order.
895    fn sorted_insert(transitions: &mut [TransitionData], size: &mut usize, data: TransitionData) {
896        let mut i = *size;
897        while i > 0 && transitions[i - 1].level > data.level {
898            transitions[i] = transitions[i - 1];
899            i -= 1;
900        }
901        transitions[i] = data;
902        *size += 1;
903    }
904
905    /// Populate the sequence of kind `TransitionType` with `transition_size` number of elements
906    /// into sequence at `sequence_id`. Updates `pin_states` with the current states for each pin.
907    fn populate_sequence(
908        &mut self,
909        transition_type: TransitionType,
910        transition_size: usize,
911        sequence_id: usize,
912        pin_states: &mut u32,
913    ) {
914        critical_section::with(|cs| {
915            let mut sequences;
916            let mut loop_sequences;
917            let (transitions, sequence) = match transition_type {
918                TransitionType::Normal => {
919                    sequences = self.sequences.borrow(cs).borrow_mut();
920                    (
921                        &self.transitions[..transition_size],
922                        sequences[sequence_id].as_mut().unwrap(),
923                    )
924                }
925                TransitionType::Loop => {
926                    loop_sequences = self.loop_sequences.borrow(cs).borrow_mut();
927                    (
928                        &self.looping_transitions[..transition_size],
929                        loop_sequences[sequence_id].as_mut().unwrap(),
930                    )
931                }
932            };
933
934            // Reset the sequence, otherwise we end up appending and weird thing happen.
935            // Safety: We only mutate this if we can get a reference to it. The interrupt takes
936            // ownership of the sequence when it's in use by the DMA controller.
937            (unsafe { &mut **sequence }).data.clear();
938
939            let mut data_index = 0;
940            let mut current_level = 0;
941
942            // Populate the selected write sequence with pin states and delays.
943            while data_index < transition_size {
944                // Set the next level to be the top, initially.
945                let mut next_level = self.top;
946
947                loop {
948                    // Is the level of this transition at the current level being checked?
949                    let transition = &transitions[data_index];
950                    if transition.level <= current_level {
951                        // Yes, so add the transition state to the pin states mask, if it's not a
952                        // dummy transition.
953                        if !transition.dummy {
954                            if transition.state {
955                                *pin_states |=
956                                    1 << self.channel_to_pin_map[transition.channel as usize];
957                            } else {
958                                *pin_states &=
959                                    !(1 << self.channel_to_pin_map[transition.channel as usize]);
960                            }
961                        }
962
963                        // Move on to the next transition.
964                        data_index += 1;
965                    } else {
966                        // No, it is higher, so set it as our next level and break out of the loop.
967                        next_level = transition.level;
968                        break;
969                    }
970
971                    if data_index >= transition_size {
972                        break;
973                    }
974                }
975
976                // Add the transition to the sequence.
977                let transition = Transition {
978                    mask: *pin_states,
979                    delay: (next_level - current_level) - 1,
980                };
981                // Safety: We only mutate this if we can get a reference to it. The interrupt takes
982                // ownership of the sequence when it's in use by the DMA controller.
983                (unsafe { &mut **sequence }).data.push(transition);
984                current_level = next_level;
985            }
986        });
987    }
988}
989
/// Error for PwmCluster methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PwmError {
    /// The supplied channel was not part of the PwmCluster.
    MissingChannel,
}
995
996/// Kinds of transitions.
997#[derive(Copy, Clone, Format)]
998enum TransitionType {
999    /// Normal types are used when a new update comes in.
1000    Normal,
1001    /// Loop types are used to repeat the currently setup structure of updates.
1002    Loop,
1003}
1004
1005/// State used for each channel.
1006#[derive(Copy, Clone, Format)]
1007struct ChannelState {
1008    /// The level the channel is currently set to.
1009    level: u32,
1010    /// The offset the channel should use when starting a new high signal.
1011    offset: u32,
1012    /// Whether to invert the high/low signals.
1013    polarity: bool,
1014    /// Track when level wraps around the `top` value of the cluster.
1015    overrun: u32,
1016    /// Helper storage for overrun to account for multiple loads between DMA reads.
1017    next_overrun: u32,
1018}
1019
1020impl ChannelState {
1021    fn new() -> Self {
1022        Self {
1023            level: 0,
1024            offset: 0,
1025            polarity: false,
1026            overrun: 0,
1027            next_overrun: 0,
1028        }
1029    }
1030}
1031
1032/// A Sequence of [Transition]s
1033#[derive(Clone)]
1034pub struct Sequence {
1035    /// Inner array of transitions.
1036    data: ArrayVec<Transition, BUFFER_SIZE>,
1037}
1038
1039impl Sequence {
1040    /// Constructor for a Sequence.
1041    pub fn new() -> Self {
1042        let mut data = ArrayVec::default();
1043        data.push(Transition::new());
1044        Self { data }
1045    }
1046
1047    fn new_for_list() -> Self {
1048        let mut this = Self::new();
1049        this.data[0].delay = 10;
1050        this
1051    }
1052}
1053
1054impl Default for Sequence {
1055    fn default() -> Self {
1056        Self::new()
1057    }
1058}
1059
1060// ReadTarget allows Sequence to be used directly by the DMA.
1061unsafe impl ReadTarget for &mut Sequence {
1062    type ReceivedWord = u32;
1063
1064    fn rx_treq() -> Option<u8> {
1065        None
1066    }
1067
1068    fn rx_address_count(&self) -> (u32, u32) {
1069        (self.data.as_ptr() as u32, self.data.len() as u32 * 2)
1070    }
1071
1072    fn rx_increment(&self) -> bool {
1073        true
1074    }
1075}
1076
1077/// Data to be sent to the PIO program.
1078#[derive(Copy, Clone)]
1079#[repr(C)]
1080pub struct Transition {
1081    /// Mask for pin states. All low bits turn off the signal for an output pin,
1082    /// and all high bits turn on the signal for an output pin.
1083    mask: u32,
1084    /// The number of cycles to wait before activating the next [Transition].
1085    delay: u32,
1086}
1087
1088impl Format for Transition {
1089    fn format(&self, f: defmt::Formatter) {
1090        defmt::write!(
1091            f,
1092            "Transition {{ mask: {:#032b}, delay: {} }}",
1093            self.mask,
1094            self.delay
1095        )
1096    }
1097}
1098
1099impl Transition {
1100    /// Constructor for a [Transition].
1101    fn new() -> Self {
1102        Self { mask: 0, delay: 0 }
1103    }
1104}
1105
1106/// Data for the PwmCluster to track transitions.
1107#[derive(Copy, Clone, Default, Format)]
1108struct TransitionData {
1109    /// The channel that this transition applies to.
1110    channel: u8,
1111    /// The level when this transition should occur.
1112    level: u32,
1113    /// The state tracks whether to emit a high or low signal.
1114    state: bool,
1115    /// Dummy states just keep the same state but allow delaying the DMA interrupt.
1116    dummy: bool,
1117}
1118
1119impl TransitionData {
1120    /// Construct a transition for `channel` at `level` and set it to `state`.
1121    fn new(channel: u8, level: u32, state: bool) -> Self {
1122        Self {
1123            channel,
1124            level,
1125            state,
1126            dummy: false,
1127        }
1128    }
1129
1130    /// Construct a dummy transition at `level`.
1131    fn with_level(level: u32) -> Self {
1132        Self {
1133            channel: 0,
1134            level,
1135            state: false,
1136            dummy: true,
1137        }
1138    }
1139}
1140
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{size_of, MaybeUninit};

    /// The DMA reads each `Transition` as exactly two 32-bit words (see `rx_address_count`).
    /// An uninitialized slot, as stored inside the ArrayVec, must have the same size so the
    /// buffer stays a contiguous run of words.
    #[test]
    fn transition_must_be_exact() {
        assert_eq!(size_of::<Transition>(), 2 * size_of::<u32>());
        assert_eq!(size_of::<MaybeUninit<Transition>>(), 2 * size_of::<u32>());
    }
}