sp1_hypercube/prover/
memory_permit.rs1use std::sync::Arc;
2use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
3
4pub struct MemoryPermitting {
6 inner: Arc<Semaphore>,
7 mem_in_bytes: usize,
8}
9
10impl MemoryPermitting {
11 pub const MAX: usize = usize::MAX >> 3;
15
16 #[must_use]
20 pub fn new(mem_in_bytes: usize) -> Self {
21 Self { inner: Arc::new(Semaphore::new(mem_in_bytes)), mem_in_bytes }
22 }
23
24 #[must_use]
26 pub fn total_memory(&self) -> usize {
27 self.inner.available_permits()
28 }
29
30 pub async fn acquire(&self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
36 if mem_in_bytes > self.mem_in_bytes {
37 return Err(MemoryPermitError::ExceedsMaxPermittedMemory);
38 } else if mem_in_bytes == 0 {
39 return Err(MemoryPermitError::TriedToAcquireZero);
40 }
41
42 let permits = accquire_raw(&self.inner, mem_in_bytes).await?;
43 Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
44 }
45}
46
47impl Clone for MemoryPermitting {
48 fn clone(&self) -> Self {
49 Self { inner: self.inner.clone(), mem_in_bytes: self.mem_in_bytes }
50 }
51}
52
53#[derive(Debug, thiserror::Error)]
55pub enum MemoryPermitError {
56 #[error("Request a permit for 0 memory.")]
58 TriedToAcquireZero,
59 #[error("Requested memory exceeds the maximum permitted memory")]
61 ExceedsMaxPermittedMemory,
62 #[error("Split request with insufficient memory permit")]
64 NotEnoughMemoryToSplit,
65 #[error("The semaphore has been explicitly closed, this is a bug")]
67 Closed(#[from] AcquireError),
68}
69
70pub struct MemoryPermit {
72 inner: Vec<OwnedSemaphorePermit>,
73 mem_in_bytes: usize,
75}
76
77impl MemoryPermit {
78 #[must_use]
80 pub fn num_bytes(&self) -> usize {
81 #[allow(clippy::pedantic)]
82 self.inner.iter().map(|p| p.num_permits()).sum()
83 }
84
85 pub fn split(&mut self, mem_in_bytes: usize) -> Result<MemoryPermit, MemoryPermitError> {
91 if mem_in_bytes > self.num_bytes() {
92 return Err(MemoryPermitError::NotEnoughMemoryToSplit);
93 } else if mem_in_bytes == 0 {
94 return Err(MemoryPermitError::TriedToAcquireZero);
95 }
96
97 let mut permits = Vec::new();
98 let mut to_acquire = mem_in_bytes;
99 while let Some(permit) = self.inner.last_mut() {
100 let num_permits = permit.num_permits();
101
102 if num_permits <= to_acquire {
103 to_acquire -= num_permits;
104 permits.push(self.inner.pop().unwrap());
105 } else {
106 permits.push(permit.split(to_acquire).unwrap());
108 }
109 }
110
111 Ok(MemoryPermit { inner: permits, mem_in_bytes: self.mem_in_bytes })
112 }
113
114 pub async fn increase(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
116 if mem_in_bytes == 0 {
117 return Ok(());
118 }
119
120 self.inner.extend(
121 accquire_raw(
122 self.inner
123 .first()
124 .expect("We should have at least one permit, this is a bug.")
125 .semaphore(),
126 mem_in_bytes,
127 )
128 .await?,
129 );
130
131 Ok(())
132 }
133
134 pub fn release(&mut self, mem_in_bytes: usize) -> Result<(), MemoryPermitError> {
142 if mem_in_bytes == 0 {
143 return Ok(());
144 }
145
146 let _ = self.split(mem_in_bytes)?;
148
149 Ok(())
150 }
151}
152
153async fn accquire_raw(
158 inner: &Arc<Semaphore>,
159 mem_in_bytes: usize,
160) -> Result<Vec<OwnedSemaphorePermit>, MemoryPermitError> {
161 let mut permits = Vec::new();
162 let mut to_acquire = mem_in_bytes;
163 while to_acquire > 0 {
164 let n = to_acquire.min(u32::MAX as usize);
165 let permit = inner.clone().acquire_many_owned(n as u32).await?;
166
167 permits.push(permit);
168 to_acquire -= n;
169 }
170
171 Ok(permits)
172}