Skip to main content

vpp_plugin/vlib/process_node/
core.rs

1//! Core infrastructure for VPP process nodes
2//!
3//! This module provides the [`ProcessNode`] trait and core infrastructure
4//! for running async/await coroutines within VPP process nodes.
5
6use futures_task::{ArcWake, waker_ref};
7use pin_project_lite::pin_project;
8
9use crate::{
10    bindings::{
11        _vlib_node_registration, async_context, vl_api_force_rpc_call_main_thread,
12        vlib_helper_get_global_main, vlib_helper_process_node_loop,
13        vlib_helper_remove_node_from_registrations, vlib_main_t, vlib_node_registration_t,
14        vlib_node_runtime_t, vlib_process_signal_event_mt_args_t,
15        vlib_process_signal_event_mt_helper,
16    },
17    vlib::{
18        MainRef, NodeRuntimeRef,
19        node::{ErrorCounters, NextNodes},
20        process_node::tw_timer::{Timer, TimerWheel},
21    },
22};
23use std::{
24    cell::{RefCell, UnsafeCell},
25    ffi::c_void,
26    fmt,
27    future::Future,
28    pin::Pin,
29    rc::Rc,
30    sync::Arc,
31    task::{Context, Poll},
32    time::{Duration, Instant},
33};
34
35pub use futures_task::LocalFutureObj;
36
// Length of one timer tick, in milliseconds.
//
// VPP's native resolution would be 1000 / VLIB_TW_TICKS_PER_SECOND (defined in VPP code), i.e.
// 10μs per tick. That is finer than the 1ms granularity of epoll_wait (which is called by
// vlib_file_poll in the main event loop), so it could not be achieved reliably, even with no
// other Unix processes in the system pre-empting the VPP main thread.
//
// 1ms is used instead: the minimum that is theoretically reliably achievable for process nodes.
const TICK_INTERVAL_PER_MS: u64 = 1;
// The same tick length expressed in seconds.
const TICK_INTERVAL_S: f64 = (TICK_INTERVAL_PER_MS as f64) / 1_000.0;
44
45/// Trait for defining a VPP process (async) node
46pub trait ProcessNode {
47    /// Type defining the next nodes of this node
48    ///
49    /// Typically an enum using the [`vpp_plugin_macros::NextNodes`] derive macro.
50    type NextNodes: NextNodes;
51
52    /// Type defining the runtime data of this node
53    ///
54    /// This data is per-node instance and per-thread.
55    // Send + Copy due to:
56    //     if (vec_len (n->runtime_data) > 0)
57    //       clib_memcpy (rt->runtime_data, n->runtime_data,
58    //                    vec_len (n->runtime_data));
59    //     else
60    //       clib_memset (rt->runtime_data, 0, VLIB_NODE_RUNTIME_DATA_SIZE);
61    type RuntimeData: Send + Copy;
62    /// Type defining the error counters of this node
63    ///
64    /// Typically an enum using the [`vpp_plugin_macros::ErrorCounters`] derive macro.
65    type Errors: ErrorCounters;
66
67    /// The main async coroutine for this process node
68    #[must_use = "Futures do nothing unless awaited"]
69    fn function(
70        &self,
71        vm: &mut MainRef,
72        node: &mut NodeRuntimeRef<Self>,
73    ) -> impl Future<Output = ()>;
74}
75
76/// Registration information for a VPP process node
77///
78/// Used for registering and unregistering process nodes with VPP.
79///
80/// This is typically created automatically using the [`vpp_plugin_macros::vlib_process_node`] macro.
81pub struct ProcessNodeRegistration<N: ProcessNode, const N_NEXT_NODES: usize> {
82    registration: UnsafeCell<_vlib_node_registration<[*mut std::os::raw::c_char; N_NEXT_NODES]>>,
83    _marker: std::marker::PhantomData<N>,
84}
85
86impl<N: ProcessNode, const N_NEXT_NODES: usize> ProcessNodeRegistration<N, N_NEXT_NODES> {
87    /// Creates a new `ProcessNodeRegistration` from the given registration data
88    pub const fn new(
89        registration: _vlib_node_registration<[*mut std::os::raw::c_char; N_NEXT_NODES]>,
90    ) -> Self {
91        Self {
92            registration: UnsafeCell::new(registration),
93            _marker: ::std::marker::PhantomData,
94        }
95    }
96
97    /// Registers the node with VPP
98    ///
99    /// # Safety
100    ///
101    /// - Must be called only once for this node registration.
102    /// - Must be called from a constructor function that is invoked before VPP initialises.
103    /// - The following pointers in the registration data must be valid:
104    ///   - `name` (must be a valid, nul-terminated string)
105    ///   - `function` (must point to a valid node function)
106    ///   - `error_descriptions` (must point to an array of `n_errors` valid `vlib_error_desc_t` entries)
107    ///   - `next_nodes` (each entry must be a valid nul-terminated string and length must be at least `n_next_nodes`)
108    /// - Other pointers in the registration data must be either valid or null as appropriate.
109    /// - `vector_size`, `scalar_size`, and `aux_size` must match the sizes of the corresponding types in `N`.
110    /// - `n_errors` must match the discriminants in N::Errors
111    /// - `n_next_nodes` must match the discriminants in N::NextNodes
112    pub unsafe fn register(&'static self) {
113        // SAFETY: The safety requirements are documented in the function's safety comment.
114        unsafe {
115            let vgm = vlib_helper_get_global_main();
116            let reg = self.registration.get();
117            (*reg).next_registration = (*vgm).node_registrations;
118            (*vgm).node_registrations = reg as *mut vlib_node_registration_t;
119        }
120    }
121
122    /// Unregisters the node from VPP
123    ///
124    /// # Safety
125    ///
126    /// - Must be called only once for this node registration.
127    /// - Must be called from a destructor function that is invoked after VPP uninitialises.
128    /// - The node must have been previously registered with VPP using [`Self::register`].
129    pub unsafe fn unregister(&self) {
130        // SAFETY: The safety requirements are documented in the function's safety comment.
131        unsafe {
132            let vgm = vlib_helper_get_global_main();
133            vlib_helper_remove_node_from_registrations(
134                vgm,
135                self.registration.get() as *mut vlib_node_registration_t,
136            );
137        }
138    }
139
140    /// Creates a `&mut NodeRuntimeRef` directly from a pointer
141    ///
142    /// This is a convenience method that calls [`NodeRuntimeRef::from_ptr_mut`], for code that
143    /// has an instance of `NodeRegistration`, but doesn't know the name of the type for the node.
144    /// As such, `self` isn't used, it's just taken so that the generic types are known.
145    ///
146    /// # Safety
147    ///
148    /// - The same preconditions as [`NodeRuntimeRef::from_ptr_mut`] apply.
149    pub unsafe fn node_runtime_from_ptr<'a>(
150        &self,
151        ptr: *mut vlib_node_runtime_t,
152    ) -> &'a mut NodeRuntimeRef<N> {
153        // SAFETY: The safety requirements are documented in the function's safety comment.
154        unsafe { NodeRuntimeRef::from_ptr_mut(ptr) }
155    }
156}
157
158// SAFETY: there is nothing in vlib_node_registration that is tied to a specific thread or that
159// mutates global state, so it's safe to send between threads.
160unsafe impl<N: ProcessNode, const N_NEXT_NODES: usize> Send
161    for ProcessNodeRegistration<N, N_NEXT_NODES>
162{
163}
164// SAFETY: NodeRegistration doesn't allow any modification after creation (and vpp doesn't
165// modify it afterwards either), so it's safe to access from multiple threads. The only exception
166// to this is the register/unregister methods, but it's the duty of the caller
167// to ensure they are called at times when no other threads have a reference to the object.
168unsafe impl<N: ProcessNode, const N_NEXT_NODES: usize> Sync
169    for ProcessNodeRegistration<N, N_NEXT_NODES>
170{
171}
172
173/// Async context shared with other objects that need scheduling
174pub(crate) struct ProcessAsyncContextShared {
175    timer_wheel: Rc<RefCell<Box<TimerWheel>>>,
176    waker: Arc<ProcessAsyncContextWaker>,
177    start_time: Instant,
178}
179
180impl ProcessAsyncContextShared {
181    fn new(node_index: u32) -> Self {
182        // Initialise on the heap to avoid excessive stack usage
183        let mut timer_wheel = Box::new_uninit();
184        TimerWheel::init(&mut timer_wheel);
185        // SAFETY: timer_wheel is initialized by TimerWheel::init above
186        let timer_wheel = unsafe { timer_wheel.assume_init() };
187        Self {
188            timer_wheel: Rc::new(RefCell::new(timer_wheel)),
189            waker: Arc::new(ProcessAsyncContextWaker { node_index }),
190            start_time: Instant::now(),
191        }
192    }
193
194    /// Convert an instant in time to number of ticks since the start time
195    ///
196    /// If the instant in time is before the start time, it will be classed as 0 ticks. Times
197    /// greater than [`u64::MAX`] ticks into the future are treated as just [`u64::MAX`] ticks.
198    fn instant_to_ticks(&self, t: Instant) -> u64 {
199        let duration = t.saturating_duration_since(self.start_time);
200        duration
201            .as_millis()
202            .div_ceil(TICK_INTERVAL_PER_MS.into())
203            .try_into()
204            .unwrap_or(u64::MAX)
205    }
206}
207
208pin_project! {
209    /// Async context for running a future within a VPP process node.
210    ///
211    /// This struct holds the state needed to poll a async future from the
212    /// VPP process node loop, including a timer wheel for async operations.
213    pub struct ProcessAsyncContext<'a> {
214        main_ref: *mut vlib_main_t,
215        #[pin]
216        future: Option<LocalFutureObj<'a, ()>>,
217        shared: Rc<ProcessAsyncContextShared>,
218    }
219}
220
221impl<'a> ProcessAsyncContext<'a> {
222    /// Create a new async context for the given future.
223    pub fn new<N>(
224        vm: &'a mut MainRef,
225        node: &NodeRuntimeRef<N>,
226        future: LocalFutureObj<'a, ()>,
227    ) -> Self {
228        Self {
229            main_ref: vm.as_ptr(),
230            future: Some(future),
231            shared: Rc::new(ProcessAsyncContextShared::new(node.node_index())),
232        }
233    }
234
235    /// Run the async context in the VPP process node loop.
236    ///
237    /// This method never returns as it enters VPP's process node loop.
238    pub fn run(mut self) -> ! {
239        // SAFETY: This enters the VPP process node loop which is the intended
240        // usage of this function. Since `Self::new` enforces that the MainRef must live as long
241        // as self then the underlying pointer must also last that long.
242        unsafe {
243            vlib_helper_process_node_loop(
244                self.main_ref,
245                &mut self as *mut Self as *mut async_context,
246            )
247        }
248    }
249}
250
251struct ProcessAsyncContextWaker {
252    node_index: u32,
253}
254
255impl ArcWake for ProcessAsyncContextWaker {
256    fn wake_by_ref(arc_self: &std::sync::Arc<Self>) {
257        let mut args = vlib_process_signal_event_mt_args_t {
258            node_index: arc_self.node_index as u64,
259            type_opaque: 0,
260            data: 0,
261        };
262        // This is conservative since we don't know whether or not we're on the main thread
263        // SAFETY: this is safe to call on any thread since VPP takes a spinlock around the
264        // critical section and the arguments match what vlib_process_signal_event_mt_helper
265        // expects.
266        unsafe {
267            vl_api_force_rpc_call_main_thread(
268                vlib_process_signal_event_mt_helper as *mut c_void,
269                std::ptr::addr_of_mut!(args) as *mut u8,
270                std::mem::size_of_val(&args) as u32,
271            )
272        };
273    }
274}
275
276/// Poll the async coroutine once.
277///
278/// This function is called by VPP to advance the async future forward.
279/// It should be called repeatedly until the future completes.
280///
281/// # Safety
282///
283/// - `context` must be a valid, non-null pointer to a live `ProcessAsyncContext`.
284/// - The caller must ensure that the context remains valid for the duration of this call.
285/// - This function must only be called from a single thread at a time.
286#[unsafe(no_mangle)]
287unsafe extern "C" fn vpp_plugin_rs_poll_async_coroutine(context: *mut ProcessAsyncContext) {
288    // SAFETY: `context` is guaranteed non-null and points to a valid `ProcessAsyncContext`.
289    let mut ctx = unsafe { Pin::new_unchecked(&mut *context) };
290
291    let ticks_since_start = ctx.shared.instant_to_ticks(Instant::now());
292    ctx.shared
293        .timer_wheel
294        .borrow_mut()
295        .expire_timers(ticks_since_start);
296
297    let ctx_project = ctx.as_mut().project();
298    if let Some(fut) = ctx_project.future.as_pin_mut() {
299        ASYNC_CONTEXT.with(|tls_ctx| {
300            tls_ctx.replace(Some(ctx_project.shared.clone()));
301        });
302        let waker = waker_ref(&ctx_project.shared.waker);
303        let mut executor_context = Context::from_waker(&waker);
304        if matches!(fut.poll(&mut executor_context), Poll::Ready(_)) {
305            // > Once a future has finished, clients should not poll it again.
306            // [https://doc.rust-lang.org/std/future/trait.Future.html]
307            ctx.project().future.set(None);
308        }
309        ASYNC_CONTEXT.with(|tls_ctx| {
310            tls_ctx.replace(None);
311        });
312    }
313}
314
315/// Get the amount of time to wait before the next timer expires.
316///
317/// If there is no next timer, then [`f64::MAX`] will be returned.
318///
319/// # Safety
320///
321/// - `context` must be a pointer to a live `ProcessAsyncContext`.
322/// - The pointer must not be null and must remain valid for the duration of the call.
323#[unsafe(no_mangle)]
324unsafe extern "C" fn vpp_plugin_rs_next_timer_duration(context: *mut ProcessAsyncContext) -> f64 {
325    // SAFETY: `context` is validated by the caller contract to be non-null and valid.
326    let ctx = unsafe { &*context };
327    let next_expiration = ctx.shared.timer_wheel.borrow().next_expiration();
328    next_expiration
329        .map(|ticks| ticks as f64 * TICK_INTERVAL_S)
330        .unwrap_or(f64::MAX)
331}
332
333thread_local! {
334    /// Async context for process nodes
335    ///
336    /// This is updated before and after suspending a VPP process node and is only valid when
337    /// polling the `ProcessAsyncContext` future.
338    static ASYNC_CONTEXT: RefCell<Option<Rc<ProcessAsyncContextShared>>> = const { RefCell::new(None) };
339}
340
341/// Execute a closure that receives a reference to the current process node async context
342///
343/// The result of the function is that of the closure.
344///
345/// # Panics
346///
347/// If not called from a vpp-plugin-rs process node.
348pub(crate) fn with_current_async_context<F, R>(f: F) -> R
349where
350    F: FnOnce(&Rc<ProcessAsyncContextShared>) -> R,
351{
352    ASYNC_CONTEXT.with(|ctx| {
353        f(ctx.borrow().as_ref().expect(
354            "There is no async context present - must be called from a vpp-plugin-rs process node",
355        ))
356    })
357}
358
359pin_project! {
360    /// Future returned by [`sleep()`]
361    #[project(!Unpin)]
362    #[derive(Debug)]
363    #[must_use = "futures do nothing unless you `.await` or poll them"]
364    pub struct Sleep {
365        // The link between the `Sleep` instance and the timer that drives it.
366        #[pin]
367        entry: Timer,
368    }
369}
370
371impl Sleep {
372    pub(crate) fn new_timeout(deadline: Instant, ctx: &Rc<ProcessAsyncContextShared>) -> Self {
373        let deadline_ticks = ctx.instant_to_ticks(deadline);
374        let entry = Timer::new(ctx.timer_wheel.clone(), deadline_ticks);
375        Self { entry }
376    }
377
378    /// Returns `true` if `Sleep` has elapsed.
379    ///
380    /// A `Sleep` instance is elapsed when the requested duration has elapsed.
381    pub fn is_elapsed(&self) -> bool {
382        self.entry.is_ready()
383    }
384}
385
386impl Future for Sleep {
387    type Output = ();
388
389    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
390        self.project().entry.poll(cx)
391    }
392}
393
394/// Waits until `duration` has elapsed.
395///
396/// An asynchronous analog to [`std::thread::sleep`].
397pub fn sleep(duration: Duration) -> Sleep {
398    let deadline = Instant::now().checked_add(duration).unwrap_or_else(|| {
399        // Roughly 30 years from now.
400        // Standard library does not provide a way to obtain max `Instant`
401        // or convert specific date in the future to instant.
402        // 1000 years overflows on macOS, 100 years overflows on FreeBSD.
403        Instant::now() + Duration::from_secs(86400 * 365 * 30)
404    });
405    with_current_async_context(|ctx| Sleep::new_timeout(deadline, ctx))
406}
407
/// Error produced by `Timeout`.
///
/// It is returned when the timeout expires before the wrapped future was able
/// to finish.
#[derive(Debug, PartialEq, Eq)]
// It may become more complicated in the future
#[allow(missing_copy_implementations)]
pub struct Elapsed(());

impl fmt::Display for Elapsed {
    /// Writes the fixed message, honouring any width/precision flags of the formatter
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad("deadline has elapsed")
    }
}

impl std::error::Error for Elapsed {}

impl From<Elapsed> for std::io::Error {
    /// Maps an elapsed timeout onto [`std::io::ErrorKind::TimedOut`]
    fn from(_err: Elapsed) -> std::io::Error {
        std::io::Error::from(std::io::ErrorKind::TimedOut)
    }
}
430
431/// Apply a timeout to the given `future`
432///
433/// If `future` completes before `duration` has elapsed, then the completed value is returned.
434/// Otherwise, an [`Elapsed`] error is returned and the future is cancelled.
435pub fn timeout<F>(duration: Duration, future: F) -> Timeout<F::IntoFuture>
436where
437    F: IntoFuture,
438{
439    let delay = sleep(duration);
440    Timeout::new_with_delay(future.into_future(), delay)
441}
442
443pin_project! {
444    /// Future returned by [`timeout`](timeout).
445    #[must_use = "futures do nothing unless you `.await` or poll them"]
446    #[derive(Debug)]
447    pub struct Timeout<T> {
448        #[pin]
449        value: T,
450        #[pin]
451        delay: Sleep,
452    }
453}
454
455impl<T> Timeout<T> {
456    pub(crate) fn new_with_delay(value: T, delay: Sleep) -> Timeout<T> {
457        Timeout { value, delay }
458    }
459
460    /// Gets a reference to the underlying value in this timeout.
461    pub fn get_ref(&self) -> &T {
462        &self.value
463    }
464
465    /// Gets a mutable reference to the underlying value in this timeout.
466    pub fn get_mut(&mut self) -> &mut T {
467        &mut self.value
468    }
469
470    /// Consumes this timeout, returning the underlying value.
471    pub fn into_inner(self) -> T {
472        self.value
473    }
474}
475
476impl<T> Future for Timeout<T>
477where
478    T: Future,
479{
480    type Output = Result<T::Output, Elapsed>;
481
482    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
483        let me = self.project();
484
485        // First try polling the value future
486        if let Poll::Ready(v) = me.value.poll(cx) {
487            return Poll::Ready(Ok(v));
488        }
489
490        // Then try polling the delay future
491        match me.delay.poll(cx) {
492            Poll::Ready(()) => Poll::Ready(Err(Elapsed(()))),
493            Poll::Pending => Poll::Pending,
494        }
495    }
496}
497
#[cfg(test)]
mod tests {
    use std::{io, time::Duration};

    use super::{Elapsed, sleep};

    /// `sleep` needs an async context, which only exists while a process node is being polled.
    #[test]
    #[should_panic(
        expected = "There is no async context present - must be called from a vpp-plugin-rs process node"
    )]
    fn sleep_outside_process_node_panics() {
        drop(sleep(Duration::from_secs(1)));
    }

    /// `Elapsed` converts into an I/O error of kind `TimedOut`.
    #[test]
    fn elapsed_to_std_error() {
        let converted = io::Error::from(Elapsed(()));
        assert_eq!(converted.kind(), io::ErrorKind::TimedOut);
    }
}