Skip to main content

sp1_hypercube/prover/
permits.rs

use std::{sync::Arc, time::Duration};

use thiserror::Error;
use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
use tracing::Span;

7/// A permit for the prover.
8#[derive(Debug)]
9pub struct ProverPermit {
10    /// The underlying permit.
11    #[allow(dead_code)]
12    permit: OwnedSemaphorePermit,
13    /// The span for the permit lifetime.
14    span: Span,
15    /// The time the permit was acquired.
16    time: tokio::time::Instant,
17}
18
19impl ProverPermit {
20    /// Release the permit and return the duration it was held for.
21    #[must_use]
22    pub fn release(self) -> Duration {
23        self.time.elapsed()
24    }
25}
26
27impl Drop for ProverPermit {
28    fn drop(&mut self) {
29        let duration = self.time.elapsed();
30        tracing::debug!(parent: &self.span, "permit acquired for {:?} ", duration);
31    }
32}
33
34/// A semaphore for the prover.
35#[derive(Debug, Clone)]
36pub struct ProverSemaphore {
37    /// The underlying semaphore.
38    sem: Arc<Semaphore>,
39}
40
41impl ProverSemaphore {
42    /// Create a new prover semaphore with the given number of permits.
43    #[must_use]
44    #[inline]
45    pub fn new(max_permits: usize) -> Self {
46        Self { sem: Arc::new(Semaphore::new(max_permits)) }
47    }
48
49    /// Acquire a permit.
50    #[inline]
51    pub async fn acquire(self) -> Result<ProverPermit, ProverAcquireError> {
52        let span = tracing::Span::current();
53        let permit = self.sem.acquire_owned().await?;
54        let time = tokio::time::Instant::now();
55        Ok(ProverPermit { permit, span, time })
56    }
57
58    /// Acquire multiple permits.
59    #[inline]
60    pub async fn acquire_many(self, n: u32) -> Result<ProverPermit, ProverAcquireError> {
61        let span = tracing::Span::current();
62        let permit = self.sem.acquire_many_owned(n).await?;
63        let time = tokio::time::Instant::now();
64        Ok(ProverPermit { permit, span, time })
65    }
66}
67
68/// An error that occurs when acquiring a permit.
69#[derive(Debug, Error)]
70#[error("failed to acquire permit")]
71pub struct ProverAcquireError(#[from] AcquireError);