1use std::collections::BTreeMap;
2use std::num::{NonZeroU64, NonZeroUsize};
3use std::ops::{Bound, RangeBounds};
4use std::sync::{Arc, Mutex, MutexGuard};
5
6use bytesize::ByteSize;
7use serde::{Deserialize, Serialize};
8
9#[derive(Clone)]
10pub struct MemorySlicer {
11 range: MemorySlicerRange,
12 total: ByteSize,
13 inner: Arc<Inner>,
14}
15
16impl MemorySlicer {
17 pub fn new(range: MemorySlicerRange) -> Self {
18 let total = range.measure_available();
19 Self {
20 range,
21 total,
22 inner: Arc::new(Inner {
23 available: Mutex::new(total),
24 }),
25 }
26 }
27
28 pub fn range(&self) -> MemorySlicerRange {
29 self.range
30 }
31
32 pub fn total(&self) -> ByteSize {
33 self.total
34 }
35
36 pub fn lock(&self) -> MemorySlicerGuard<'_> {
37 MemorySlicerGuard {
38 range: self.range,
39 total: self.total,
40 available: self.inner.available.lock().unwrap(),
41 }
42 }
43}
44
45pub struct MemorySlicerGuard<'a> {
46 range: MemorySlicerRange,
47 total: ByteSize,
48 available: MutexGuard<'a, ByteSize>,
49}
50
51impl MemorySlicerGuard<'_> {
52 pub fn range(&self) -> MemorySlicerRange {
53 self.range
54 }
55
56 pub fn total(&self) -> ByteSize {
57 self.total
58 }
59
60 pub fn available(&self) -> ByteSize {
61 *self.available
62 }
63
64 pub fn snapshot(&self) -> MemorySlicer {
65 let available = *self.available;
66 MemorySlicer {
67 range: MemorySlicerRange::Fixed {
68 capacity: self.total,
69 },
70 total: self.total,
71 inner: Arc::new(Inner {
72 available: Mutex::new(available),
73 }),
74 }
75 }
76
77 pub fn alloc_fixed(&mut self, bytes: ByteSize) -> bool {
78 if let Some(available) = self.available.0.checked_sub(bytes.0) {
79 *self.available = ByteSize(available);
80 true
81 } else {
82 false
83 }
84 }
85
86 pub fn alloc_ratio(&mut self, nom: usize, denom: usize) -> Option<ByteSize> {
87 if nom > denom || denom == 0 {
89 return None;
90 }
91 let to_alloc = self.available.0.saturating_mul(nom as u64) / (denom as u64);
92 self.available.0 -= to_alloc;
93 (to_alloc != 0).then_some(ByteSize(to_alloc))
94 }
95
96 pub fn alloc_in_range<R: RangeBounds<ByteSize>>(&mut self, bytes: R) -> Option<ByteSize> {
97 let min_alloc = match bytes.start_bound() {
98 Bound::Included(bytes) => bytes.0,
99 Bound::Excluded(bytes) => bytes.0.saturating_add(1),
100 Bound::Unbounded => 0,
101 };
102
103 let mut to_alloc = match bytes.end_bound() {
104 Bound::Included(bytes) => bytes.0,
105 Bound::Excluded(bytes) => bytes.0.saturating_sub(1),
106 Bound::Unbounded => u64::MAX,
107 };
108 to_alloc = std::cmp::min(to_alloc, self.available.0);
109 if to_alloc < min_alloc {
110 return None;
111 }
112
113 self.available.0 -= to_alloc;
114 Some(ByteSize(to_alloc))
115 }
116
117 pub fn alloc_constraints<C: MemoryConstraints>(
123 &mut self,
124 constraints: C,
125 ) -> Result<C::Output, ByteSize> {
126 let (solved, result) =
127 solve_constraints(*self.available, constraints.as_constraitns_slice());
128 let result_total = result.iter().map(|size| size.as_u64()).sum();
129
130 if let Some(available) = self.available.0.checked_sub(result_total)
131 && solved
132 {
133 self.available.0 = available;
134 return Ok(C::make_output(result));
135 }
136
137 Err(ByteSize(result_total))
138 }
139
140 pub fn overflowing_alloc_constraints<C: MemoryConstraints>(
146 &mut self,
147 constraints: C,
148 ) -> AllocatedMemoryConstraints<C::Output> {
149 let (_, result) = solve_constraints(*self.available, constraints.as_constraitns_slice());
150 let result_total = result.iter().map(|size| size.as_u64()).sum();
151
152 let allocated = std::cmp::min(result_total, self.available.0);
153 self.available.0 -= allocated;
154
155 AllocatedMemoryConstraints {
156 result: C::make_output(result),
157 total: ByteSize(result_total),
158 overflow: (result_total > allocated).then(|| ByteSize(result_total - allocated)),
159 }
160 }
161
162 pub fn subdivide(&mut self, bytes: ByteSize) -> Option<MemorySlicer> {
163 if self.alloc_fixed(bytes) {
164 Some(MemorySlicer::new(MemorySlicerRange::Fixed {
165 capacity: bytes,
166 }))
167 } else {
168 None
169 }
170 }
171
172 pub fn subdivide_ratio(&mut self, nom: usize, denom: usize) -> Option<MemorySlicer> {
173 self.alloc_ratio(nom, denom)
174 .map(|capacity| MemorySlicer::new(MemorySlicerRange::Fixed { capacity }))
175 }
176
177 pub fn subdivide_in_range<R: RangeBounds<ByteSize>>(
178 &mut self,
179 bytes: R,
180 ) -> Option<MemorySlicer> {
181 self.alloc_in_range(bytes)
182 .map(|capacity| MemorySlicer::new(MemorySlicerRange::Fixed { capacity }))
183 }
184}
185
186impl From<MemorySlicerRange> for MemorySlicer {
187 #[inline]
188 fn from(value: MemorySlicerRange) -> Self {
189 Self::new(value)
190 }
191}
192
193struct Inner {
194 available: Mutex<ByteSize>,
195}
196
197#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
198#[cfg_attr(feature = "sysinfo", derive(Default))]
199#[serde(tag = "type")]
200pub enum MemorySlicerRange {
201 #[cfg(feature = "sysinfo")]
202 #[cfg_attr(feature = "sysinfo", default)]
203 Available,
204 #[cfg(feature = "sysinfo")]
205 Physical,
206 Fixed {
207 capacity: ByteSize,
208 },
209}
210
211impl MemorySlicerRange {
212 pub fn fixed(capacity: ByteSize) -> Self {
213 Self::Fixed { capacity }
214 }
215
216 pub fn into_slicer(self) -> MemorySlicer {
217 MemorySlicer::new(self)
218 }
219
220 pub fn measure_available(&self) -> ByteSize {
221 match self {
222 #[cfg(feature = "sysinfo")]
223 Self::Available => {
224 let mut sys = sysinfo::System::new();
225 sys.refresh_memory();
226 ByteSize(sys.available_memory())
227 }
228 #[cfg(feature = "sysinfo")]
230 Self::Physical => {
231 let mut sys = sysinfo::System::new();
232 sys.refresh_memory();
233 ByteSize(sys.total_memory())
234 }
235 Self::Fixed { capacity } => *capacity,
236 }
237 }
238}
239
240pub struct AllocatedMemoryConstraints<T> {
241 pub result: T,
242 pub total: ByteSize,
243 pub overflow: Option<ByteSize>,
244}
245
246pub trait MemoryConstraints {
247 type Output;
248
249 fn make_output(result: Vec<ByteSize>) -> Self::Output;
250
251 fn as_constraitns_slice(&self) -> &[MemoryConstraint];
253}
254
255impl<const N: usize> MemoryConstraints for [MemoryConstraint; N] {
256 type Output = [ByteSize; N];
257
258 fn make_output(result: Vec<ByteSize>) -> Self::Output {
259 result.try_into().unwrap()
260 }
261
262 fn as_constraitns_slice(&self) -> &[MemoryConstraint] {
263 self.as_slice()
264 }
265}
266
267impl MemoryConstraints for Vec<MemoryConstraint> {
268 type Output = Vec<ByteSize>;
269
270 fn make_output(result: Vec<ByteSize>) -> Self::Output {
271 result
272 }
273
274 fn as_constraitns_slice(&self) -> &[MemoryConstraint] {
275 self.as_slice()
276 }
277}
278
279#[derive(Debug, Clone)]
280pub struct MemoryConstraint {
281 priority: usize,
282 ratio: NonZeroUsize,
283 min_bytes: ByteSize,
284 max_bytes: ByteSize,
285}
286
287impl MemoryConstraint {
288 pub const HIGH_PRIORITY: usize = 0;
289 pub const MID_PRIORITY: usize = 1;
290 pub const LOW_PRIORITY: usize = 2;
291
292 pub fn exact(priority: usize, bytes: ByteSize) -> Self {
293 Self {
294 priority,
295 ratio: NonZeroUsize::MIN,
296 min_bytes: bytes,
297 max_bytes: bytes,
298 }
299 }
300
301 pub fn range<R>(priority: usize, ratio: usize, range: R) -> Self
302 where
303 R: RangeBounds<ByteSize>,
304 {
305 let min_bytes = match range.start_bound() {
306 Bound::Included(bytes) => *bytes,
307 Bound::Excluded(bytes) => ByteSize(bytes.as_u64().saturating_add(1)),
308 Bound::Unbounded => ByteSize(0),
309 };
310 let max_bytes = match range.end_bound() {
311 Bound::Included(bytes) => *bytes,
312 Bound::Excluded(bytes) => ByteSize(bytes.as_u64().saturating_sub(1)),
313 Bound::Unbounded => ByteSize(u64::MAX),
314 };
315 assert!(min_bytes <= max_bytes);
316
317 Self {
318 priority,
319 ratio: NonZeroUsize::new(ratio).unwrap_or(NonZeroUsize::MIN),
320 min_bytes,
321 max_bytes,
322 }
323 }
324}
325
326fn solve_constraints(total: ByteSize, constraitns: &[MemoryConstraint]) -> (bool, Vec<ByteSize>) {
327 struct Item {
328 idx: usize,
329 ratio: Option<NonZeroU64>,
330 max: NonZeroU64,
331 }
332
333 #[derive(Default)]
334 struct Group {
335 total_ratio: u64,
336 items: Vec<Item>,
337 }
338
339 const SCALE: u64 = 1 << 16;
340
341 let mut total = total.as_u64();
342 let mut result = Vec::with_capacity(constraitns.len());
343 let mut groups = BTreeMap::<usize, Group>::new();
344
345 let mut min_required = 0u64;
347 for (idx, constraint) in constraitns.iter().enumerate() {
348 min_required = min_required.saturating_add(constraint.min_bytes.as_u64());
349 result.push(constraint.min_bytes);
350
351 let range = constraint.max_bytes.0 - constraint.min_bytes.0;
352 if let Some(max) = NonZeroU64::new(range) {
353 let group = groups.entry(constraint.priority).or_default();
354 group.total_ratio = group
355 .total_ratio
356 .checked_add(constraint.ratio.get() as u64)
357 .expect("too big total ratio");
358 group.items.push(Item {
359 idx,
360 ratio: NonZeroU64::new(constraint.ratio.get() as _),
361 max,
362 });
363 }
364 }
365
366 let solved = total >= min_required;
368 total = total.saturating_sub(min_required);
369
370 'outer: for mut group in groups.into_values() {
372 while total > 0 && group.total_ratio > 0 {
374 let chunk = total.saturating_mul(SCALE) / group.total_ratio;
377
378 let mut has_complete = false;
380 for item in &mut group.items {
381 let Some(ratio) = item.ratio else {
383 continue;
384 };
385
386 let available = chunk.saturating_mul(ratio.get()) / SCALE;
388 if available >= item.max.get() {
389 let slot = &mut result[item.idx];
393 slot.0 = slot.0.saturating_add(item.max.get());
394
395 total -= item.max.get();
397 group.total_ratio -= ratio.get();
399 item.ratio = None;
401 has_complete = true;
403 }
404 }
405
406 if has_complete {
408 continue;
409 }
410
411 for item in &group.items {
413 let Some(ratio) = item.ratio else {
415 continue;
416 };
417
418 let available = chunk.saturating_mul(ratio.get()) / SCALE;
419
420 let slot = &mut result[item.idx];
422 slot.0 = slot.0.saturating_add(available);
423
424 total -= available;
426 }
427 continue 'outer;
428 }
429 }
430
431 (solved, result)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_slicer_works() {
        let outer = MemorySlicerRange::fixed(ByteSize::gb(32)).into_slicer();
        let inner = outer.lock().subdivide_ratio(2, 3).unwrap();

        let sizes = inner
            .lock()
            .alloc_constraints([
                MemoryConstraint::exact(0, ByteSize::gb(8)),
                MemoryConstraint::range(1, 1, ByteSize::mb(128)..=ByteSize::gb(8)),
                MemoryConstraint::range(1, 10, ByteSize::mb(128)..=ByteSize::gb(4)),
            ])
            .unwrap();
        assert_eq!(sizes, [ByteSize::gb(8), ByteSize::gb(8), ByteSize::gb(4)]);

        let [exact, wide, narrow] = sizes;
        println!(
            "fixed={}, large_range={}, small_range={}",
            exact, wide, narrow
        );
        println!("available_outer={}", outer.lock().available());
        println!("available_inner={}", inner.lock().available());
    }

    #[test]
    fn constraints_solver_works() {
        #[derive(Debug)]
        struct Task {
            total: ByteSize,
            solved: bool,
        }

        let tasks = [
            (ByteSize::gb(4), false),
            (ByteSize::gb(6), true),
            (ByteSize::gb(8), true),
            (ByteSize::gb(16), true),
        ]
        .map(|(total, solved)| Task { total, solved });

        for task in tasks {
            let constraints = [
                MemoryConstraint::exact(0, ByteSize::gb(4)),
                MemoryConstraint::range(1, 1, ByteSize::gb(1)..),
                MemoryConstraint::range(1, 10, ByteSize::mb(512)..=ByteSize::gb(1)),
            ];
            let (solved, sizes) = solve_constraints(task.total, &constraints);
            assert_eq!(solved, task.solved);

            println!("{:?}", task);
            println!("{:#?}", sizes);
        }
    }
}