Skip to main content

sp1_hypercube/prover/
memory_permit.rs

1use std::sync::Arc;
2use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
3
4/// A semaphore that can be used to permit memory usage.
5pub struct MemoryPermitting {
6    inner: Arc<Semaphore>,
7    mem_in_bytes: usize,
8}
9
10impl MemoryPermitting {
11    /// The maximum number of bytes that can be permitted.
12    ///
13    /// Note: This bound comes from underlying [`Semaphore`].
14    pub const MAX: usize = usize::MAX >> 3;
15
16    /// Create a new memory permitting.
17    ///
18    /// Panics if the number of permits is greater than [`Self::MAX`].
19    #[must_use]
20    pub fn new(mem_in_bytes: usize) -> Self {
21        Self { inner: Arc::new(Semaphore::new(mem_in_bytes)), mem_in_bytes }
22    }
23
24    /// Get the total memory that can be permitted.
25    #[must_use]
26    pub fn total_memory(&self) -> usize {
27        self.inner.available_permits()
28    }
29
30    /// Get a memory permit for the given number of bytes.
31    ///
32    /// # Panics
33    ///
34    /// Panics if the number of bytes is greater than [`Self::MAX`].
35    pub async fn acquire(&self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
36        if mem_in_bytes > self.mem_in_bytes {
37            return Err(MemoryPermitError::ExceedsMaxPermittedMemory);
38        } else if mem_in_bytes == 0 {
39            return Err(MemoryPermitError::TriedToAcquireZero);
40        }
41
42        let permits = accquire_raw(&self.inner, mem_in_bytes).await?;
43        Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
44    }
45}
46
47impl Clone for MemoryPermitting {
48    fn clone(&self) -> Self {
49        Self { inner: self.inner.clone(), mem_in_bytes: self.mem_in_bytes }
50    }
51}
52
53/// Errors that can occur when acquiring a memory permit.
54#[derive(Debug, thiserror::Error)]
55pub enum MemoryPermitError {
56    /// The requested memory is zero.
57    #[error("Request a permit for 0 memory.")]
58    TriedToAcquireZero,
59    /// The requested memory exceeds the maximum permitted memory.
60    #[error("Requested memory exceeds the maximum permitted memory")]
61    ExceedsMaxPermittedMemory,
62    /// The requested memory is negative.
63    #[error("Split request with insufficient memory permit")]
64    NotEnoughMemoryToSplit,
65    /// The requested memory is negative.
66    #[error("The semaphore has been explicitly closed, this is a bug")]
67    Closed(#[from] AcquireError),
68}
69
70/// A memory permit.
71pub struct MemoryPermit {
72    inner: Vec<OwnedSemaphorePermit>,
73    /// The total possible memory that can be permitted.
74    mem_in_bytes: usize,
75}
76
77impl MemoryPermit {
78    /// Create a new memory permit from a list of permits.
79    #[must_use]
80    pub fn num_bytes(&self) -> usize {
81        #[allow(clippy::pedantic)]
82        self.inner.iter().map(|p| p.num_permits()).sum()
83    }
84
85    /// Split the memory permit into two.
86    ///
87    /// # Panics
88    ///
89    /// Panics if the number of bytes is greater than [`Self::num_bytes`].
90    pub fn split(&mut self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
91        if mem_in_bytes > self.num_bytes() {
92            return Err(MemoryPermitError::NotEnoughMemoryToSplit);
93        } else if mem_in_bytes == 0 {
94            return Err(MemoryPermitError::TriedToAcquireZero);
95        }
96
97        let mut permits = Vec::new();
98        let mut to_acquire = mem_in_bytes;
99        while let Some(permit) = self.inner.last_mut() {
100            let num_permits = permit.num_permits();
101
102            if num_permits <= to_acquire {
103                to_acquire -= num_permits;
104                permits.push(self.inner.pop().unwrap());
105            } else {
106                // todo this accepts a usize, should we just use usize everywhere?
107                permits.push(permit.split(to_acquire).unwrap());
108            }
109        }
110
111        Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
112    }
113
114    /// Increase the memory permit by the given number of bytes.
115    pub async fn increase(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
116        if mem_in_bytes == 0 {
117            return Ok(());
118        }
119
120        self.inner.extend(
121            accquire_raw(
122                self.inner
123                    .first()
124                    .expect("We should have at least one permit, this is a bug.")
125                    .semaphore(),
126                mem_in_bytes,
127            )
128            .await?,
129        );
130
131        Ok(())
132    }
133
134    /// Partially release the memory permit.
135    ///
136    /// This will release the memory permit by the given number of bytes.
137    ///
138    /// # Panics
139    ///
140    /// Panics if the number of bytes is greater than [`Self::num_bytes`].
141    pub fn release(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
142        if mem_in_bytes == 0 {
143            return Ok(());
144        }
145
146        // On drop, the permits will be released.
147        let _ = self.split(mem_in_bytes)?;
148
149        Ok(())
150    }
151}
152
153/// Acquire a list of permits from a semaphore.
154///
155/// This helper function is needed because the [`Semaphore::acquire_many_owned`] method only accepts
156/// a [`u32`].
157async fn accquire_raw(
158    inner: &Arc<Semaphore>,
159    mem_in_bytes: usize,
160) -> Result<Vec<OwnedSemaphorePermit>, MemoryPermitError> {
161    let mut permits = Vec::new();
162    let mut to_acquire = mem_in_bytes;
163    while to_acquire > 0 {
164        let n = to_acquire.min(u32::MAX as usize);
165        let permit = inner.clone().acquire_many_owned(n as u32).await?;
166
167        permits.push(permit);
168        to_acquire -= n;
169    }
170
171    Ok(permits)
172}