tycho_util/mem/slicer.rs

1use std::collections::BTreeMap;
2use std::num::{NonZeroU64, NonZeroUsize};
3use std::ops::{Bound, RangeBounds};
4use std::sync::{Arc, Mutex, MutexGuard};
5
6use bytesize::ByteSize;
7use serde::{Deserialize, Serialize};
8
9#[derive(Clone)]
10pub struct MemorySlicer {
11    range: MemorySlicerRange,
12    total: ByteSize,
13    inner: Arc<Inner>,
14}
15
16impl MemorySlicer {
17    pub fn new(range: MemorySlicerRange) -> Self {
18        let total = range.measure_available();
19        Self {
20            range,
21            total,
22            inner: Arc::new(Inner {
23                available: Mutex::new(total),
24            }),
25        }
26    }
27
28    pub fn range(&self) -> MemorySlicerRange {
29        self.range
30    }
31
32    pub fn total(&self) -> ByteSize {
33        self.total
34    }
35
36    pub fn lock(&self) -> MemorySlicerGuard<'_> {
37        MemorySlicerGuard {
38            range: self.range,
39            total: self.total,
40            available: self.inner.available.lock().unwrap(),
41        }
42    }
43}
44
45pub struct MemorySlicerGuard<'a> {
46    range: MemorySlicerRange,
47    total: ByteSize,
48    available: MutexGuard<'a, ByteSize>,
49}
50
51impl MemorySlicerGuard<'_> {
52    pub fn range(&self) -> MemorySlicerRange {
53        self.range
54    }
55
56    pub fn total(&self) -> ByteSize {
57        self.total
58    }
59
60    pub fn available(&self) -> ByteSize {
61        *self.available
62    }
63
64    pub fn snapshot(&self) -> MemorySlicer {
65        let available = *self.available;
66        MemorySlicer {
67            range: MemorySlicerRange::Fixed {
68                capacity: self.total,
69            },
70            total: self.total,
71            inner: Arc::new(Inner {
72                available: Mutex::new(available),
73            }),
74        }
75    }
76
77    pub fn alloc_fixed(&mut self, bytes: ByteSize) -> bool {
78        if let Some(available) = self.available.0.checked_sub(bytes.0) {
79            *self.available = ByteSize(available);
80            true
81        } else {
82            false
83        }
84    }
85
86    pub fn alloc_ratio(&mut self, nom: usize, denom: usize) -> Option<ByteSize> {
87        // TODO: Panic here?
88        if nom > denom || denom == 0 {
89            return None;
90        }
91        let to_alloc = self.available.0.saturating_mul(nom as u64) / (denom as u64);
92        self.available.0 -= to_alloc;
93        (to_alloc != 0).then_some(ByteSize(to_alloc))
94    }
95
96    pub fn alloc_in_range<R: RangeBounds<ByteSize>>(&mut self, bytes: R) -> Option<ByteSize> {
97        let min_alloc = match bytes.start_bound() {
98            Bound::Included(bytes) => bytes.0,
99            Bound::Excluded(bytes) => bytes.0.saturating_add(1),
100            Bound::Unbounded => 0,
101        };
102
103        let mut to_alloc = match bytes.end_bound() {
104            Bound::Included(bytes) => bytes.0,
105            Bound::Excluded(bytes) => bytes.0.saturating_sub(1),
106            Bound::Unbounded => u64::MAX,
107        };
108        to_alloc = std::cmp::min(to_alloc, self.available.0);
109        if to_alloc < min_alloc {
110            return None;
111        }
112
113        self.available.0 -= to_alloc;
114        Some(ByteSize(to_alloc))
115    }
116
117    /// Tries to allocate an amount of memory which satisfies provided
118    /// constraints.
119    ///
120    /// On success returns allocated amounts in the same order as constraints.
121    /// Otherwise returns the minimum required memory size (total).
122    pub fn alloc_constraints<C: MemoryConstraints>(
123        &mut self,
124        constraints: C,
125    ) -> Result<C::Output, ByteSize> {
126        let (solved, result) =
127            solve_constraints(*self.available, constraints.as_constraitns_slice());
128        let result_total = result.iter().map(|size| size.as_u64()).sum();
129
130        if let Some(available) = self.available.0.checked_sub(result_total)
131            && solved
132        {
133            self.available.0 = available;
134            return Ok(C::make_output(result));
135        }
136
137        Err(ByteSize(result_total))
138    }
139
140    /// Tries to allocate an amount of memory which satisfies provided
141    /// constraints.
142    ///
143    /// In case of overflow still allocates, but returns the required remainder
144    /// in [`AllocatedMemoryConstraints::overflow`].
145    pub fn overflowing_alloc_constraints<C: MemoryConstraints>(
146        &mut self,
147        constraints: C,
148    ) -> AllocatedMemoryConstraints<C::Output> {
149        let (_, result) = solve_constraints(*self.available, constraints.as_constraitns_slice());
150        let result_total = result.iter().map(|size| size.as_u64()).sum();
151
152        let allocated = std::cmp::min(result_total, self.available.0);
153        self.available.0 -= allocated;
154
155        AllocatedMemoryConstraints {
156            result: C::make_output(result),
157            total: ByteSize(result_total),
158            overflow: (result_total > allocated).then(|| ByteSize(result_total - allocated)),
159        }
160    }
161
162    pub fn subdivide(&mut self, bytes: ByteSize) -> Option<MemorySlicer> {
163        if self.alloc_fixed(bytes) {
164            Some(MemorySlicer::new(MemorySlicerRange::Fixed {
165                capacity: bytes,
166            }))
167        } else {
168            None
169        }
170    }
171
172    pub fn subdivide_ratio(&mut self, nom: usize, denom: usize) -> Option<MemorySlicer> {
173        self.alloc_ratio(nom, denom)
174            .map(|capacity| MemorySlicer::new(MemorySlicerRange::Fixed { capacity }))
175    }
176
177    pub fn subdivide_in_range<R: RangeBounds<ByteSize>>(
178        &mut self,
179        bytes: R,
180    ) -> Option<MemorySlicer> {
181        self.alloc_in_range(bytes)
182            .map(|capacity| MemorySlicer::new(MemorySlicerRange::Fixed { capacity }))
183    }
184}
185
186impl From<MemorySlicerRange> for MemorySlicer {
187    #[inline]
188    fn from(value: MemorySlicerRange) -> Self {
189        Self::new(value)
190    }
191}
192
/// State shared by all clones of a [`MemorySlicer`] (held behind an `Arc`).
struct Inner {
    /// Memory that has not been handed out yet; guarded for `MemorySlicer::lock`.
    available: Mutex<ByteSize>,
}
196
/// Describes where a [`MemorySlicer`] takes its total budget from.
///
/// Serialized as an internally tagged enum: the variant name is stored in a
/// `type` field.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[cfg_attr(feature = "sysinfo", derive(Default))]
#[serde(tag = "type")]
pub enum MemorySlicerRange {
    /// Memory reported as available by `sysinfo` when measured.
    /// This is the default variant when the `sysinfo` feature is enabled.
    #[cfg(feature = "sysinfo")]
    #[cfg_attr(feature = "sysinfo", default)]
    Available,
    /// Total memory reported by `sysinfo` when measured.
    #[cfg(feature = "sysinfo")]
    Physical,
    /// An explicitly configured budget.
    Fixed {
        capacity: ByteSize,
    },
}
210
211impl MemorySlicerRange {
212    pub fn fixed(capacity: ByteSize) -> Self {
213        Self::Fixed { capacity }
214    }
215
216    pub fn into_slicer(self) -> MemorySlicer {
217        MemorySlicer::new(self)
218    }
219
220    pub fn measure_available(&self) -> ByteSize {
221        match self {
222            #[cfg(feature = "sysinfo")]
223            Self::Available => {
224                let mut sys = sysinfo::System::new();
225                sys.refresh_memory();
226                ByteSize(sys.available_memory())
227            }
228            // TODO: Add support for cgroups?
229            #[cfg(feature = "sysinfo")]
230            Self::Physical => {
231                let mut sys = sysinfo::System::new();
232                sys.refresh_memory();
233                ByteSize(sys.total_memory())
234            }
235            Self::Fixed { capacity } => *capacity,
236        }
237    }
238}
239
/// Outcome of [`MemorySlicerGuard::overflowing_alloc_constraints`].
pub struct AllocatedMemoryConstraints<T> {
    /// Sizes proposed by the solver, in the same order as the constraints.
    pub result: T,
    /// Total of all sizes in `result`.
    pub total: ByteSize,
    /// How much of `total` could not be taken from the slicer, if anything.
    pub overflow: Option<ByteSize>,
}
245
/// A collection of [`MemoryConstraint`]s that is solved as a unit.
pub trait MemoryConstraints {
    /// Container for the solved sizes.
    type Output;

    /// Builds the output from sizes listed in constraint order.
    fn make_output(result: Vec<ByteSize>) -> Self::Output;

    /// Returns the constraints as a slice.
    // NOTE: Not named `as_slice` to avoid annoying imports in other places.
    // The spelling is part of the public method name, so it is kept as is.
    fn as_constraitns_slice(&self) -> &[MemoryConstraint];
}
254
255impl<const N: usize> MemoryConstraints for [MemoryConstraint; N] {
256    type Output = [ByteSize; N];
257
258    fn make_output(result: Vec<ByteSize>) -> Self::Output {
259        result.try_into().unwrap()
260    }
261
262    fn as_constraitns_slice(&self) -> &[MemoryConstraint] {
263        self.as_slice()
264    }
265}
266
267impl MemoryConstraints for Vec<MemoryConstraint> {
268    type Output = Vec<ByteSize>;
269
270    fn make_output(result: Vec<ByteSize>) -> Self::Output {
271        result
272    }
273
274    fn as_constraitns_slice(&self) -> &[MemoryConstraint] {
275        self.as_slice()
276    }
277}
278
/// A single memory requirement consumed by the constraint solver.
#[derive(Debug, Clone)]
pub struct MemoryConstraint {
    /// Group key. Groups with smaller values are offered spare memory first.
    priority: usize,
    /// Relative weight used when spare memory is split inside a priority group.
    ratio: NonZeroUsize,
    /// Amount reserved before any spare memory is distributed.
    min_bytes: ByteSize,
    /// Upper limit the constraint may grow to (never below `min_bytes`).
    max_bytes: ByteSize,
}
286
287impl MemoryConstraint {
288    pub const HIGH_PRIORITY: usize = 0;
289    pub const MID_PRIORITY: usize = 1;
290    pub const LOW_PRIORITY: usize = 2;
291
292    pub fn exact(priority: usize, bytes: ByteSize) -> Self {
293        Self {
294            priority,
295            ratio: NonZeroUsize::MIN,
296            min_bytes: bytes,
297            max_bytes: bytes,
298        }
299    }
300
301    pub fn range<R>(priority: usize, ratio: usize, range: R) -> Self
302    where
303        R: RangeBounds<ByteSize>,
304    {
305        let min_bytes = match range.start_bound() {
306            Bound::Included(bytes) => *bytes,
307            Bound::Excluded(bytes) => ByteSize(bytes.as_u64().saturating_add(1)),
308            Bound::Unbounded => ByteSize(0),
309        };
310        let max_bytes = match range.end_bound() {
311            Bound::Included(bytes) => *bytes,
312            Bound::Excluded(bytes) => ByteSize(bytes.as_u64().saturating_sub(1)),
313            Bound::Unbounded => ByteSize(u64::MAX),
314        };
315        assert!(min_bytes <= max_bytes);
316
317        Self {
318            priority,
319            ratio: NonZeroUsize::new(ratio).unwrap_or(NonZeroUsize::MIN),
320            min_bytes,
321            max_bytes,
322        }
323    }
324}
325
326fn solve_constraints(total: ByteSize, constraitns: &[MemoryConstraint]) -> (bool, Vec<ByteSize>) {
327    struct Item {
328        idx: usize,
329        ratio: Option<NonZeroU64>,
330        max: NonZeroU64,
331    }
332
333    #[derive(Default)]
334    struct Group {
335        total_ratio: u64,
336        items: Vec<Item>,
337    }
338
339    const SCALE: u64 = 1 << 16;
340
341    let mut total = total.as_u64();
342    let mut result = Vec::with_capacity(constraitns.len());
343    let mut groups = BTreeMap::<usize, Group>::new();
344
345    // Group constraints by priorities.
346    let mut min_required = 0u64;
347    for (idx, constraint) in constraitns.iter().enumerate() {
348        min_required = min_required.saturating_add(constraint.min_bytes.as_u64());
349        result.push(constraint.min_bytes);
350
351        let range = constraint.max_bytes.0 - constraint.min_bytes.0;
352        if let Some(max) = NonZeroU64::new(range) {
353            let group = groups.entry(constraint.priority).or_default();
354            group.total_ratio = group
355                .total_ratio
356                .checked_add(constraint.ratio.get() as u64)
357                .expect("too big total ratio");
358            group.items.push(Item {
359                idx,
360                ratio: NonZeroU64::new(constraint.ratio.get() as _),
361                max,
362            });
363        }
364    }
365
366    // Consume the minimum required amount first.
367    let solved = total >= min_required;
368    total = total.saturating_sub(min_required);
369
370    // Distribute memory inside groups.
371    'outer: for mut group in groups.into_values() {
372        // TODO: Replace loop with some math (but this might lose some precision).
373        while total > 0 && group.total_ratio > 0 {
374            // Compute the size of a chunk per ratio unit.
375            // NOTE: We additionally multiply by `SCALE` to preserve some precision.
376            let chunk = total.saturating_mul(SCALE) / group.total_ratio;
377
378            // Try to complete some ranges first.
379            let mut has_complete = false;
380            for item in &mut group.items {
381                // Skip items without ratio (i.e. complete)
382                let Some(ratio) = item.ratio else {
383                    continue;
384                };
385
386                // Compute the suggested memory for this item`s ratio.
387                let available = chunk.saturating_mul(ratio.get()) / SCALE;
388                if available >= item.max.get() {
389                    // Found a complete constraint.
390
391                    // Add the maximum available memory to it.
392                    let slot = &mut result[item.idx];
393                    slot.0 = slot.0.saturating_add(item.max.get());
394
395                    // Decrease the total amount by the distributed amount.
396                    total -= item.max.get();
397                    // Remove item ratio from the group.
398                    group.total_ratio -= ratio.get();
399                    // Mark item as complete.
400                    item.ratio = None;
401                    // Update the flag.
402                    has_complete = true;
403                }
404            }
405
406            // Recompute chunk size if some constraints were complete.
407            if has_complete {
408                continue;
409            }
410
411            // At this point we will need to distribute the remaining memory across this group.
412            for item in &group.items {
413                // Skip items without ratio (i.e. complete)
414                let Some(ratio) = item.ratio else {
415                    continue;
416                };
417
418                let available = chunk.saturating_mul(ratio.get()) / SCALE;
419
420                // Add the available memory to the result slot.
421                let slot = &mut result[item.idx];
422                slot.0 = slot.0.saturating_add(available);
423
424                // Decrease the total amount by the distributed amount.
425                total -= available;
426            }
427            continue 'outer;
428        }
429    }
430
431    (solved, result)
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn memory_slicer_works() {
        let slicer = MemorySlicerRange::fixed(ByteSize::gb(32)).into_slicer();

        // 2/3 of 32 GB, rounded down to whole bytes.
        let subslicer = slicer.lock().subdivide_ratio(2, 3).unwrap();
        assert_eq!(subslicer.total(), ByteSize(21_333_333_333));
        assert_eq!(slicer.lock().available(), ByteSize(10_666_666_667));

        let [fixed, large_range, small_range] = subslicer
            .lock()
            .alloc_constraints([
                MemoryConstraint::exact(0, ByteSize::gb(8)),
                MemoryConstraint::range(1, 1, ByteSize::mb(128)..=ByteSize::gb(8)),
                MemoryConstraint::range(1, 10, ByteSize::mb(128)..=ByteSize::gb(4)),
            ])
            .unwrap();

        assert_eq!(fixed, ByteSize::gb(8));
        assert_eq!(large_range, ByteSize::gb(8));
        assert_eq!(small_range, ByteSize::gb(4));

        // The 20 GB come out of the subslicer only; the parent is untouched.
        assert_eq!(subslicer.lock().available(), ByteSize(1_333_333_333));
        assert_eq!(slicer.lock().available(), ByteSize(10_666_666_667));
    }

    #[test]
    fn constraints_solver_works() {
        #[derive(Debug)]
        struct Task {
            total: ByteSize,
            solved: bool,
            /// Exact expected sizes, where the split has no rounding residue.
            expected: Option<[ByteSize; 3]>,
        }

        let constraints = [
            MemoryConstraint::exact(0, ByteSize::gb(4)),
            MemoryConstraint::range(1, 1, ByteSize::gb(1)..),
            MemoryConstraint::range(1, 10, ByteSize::mb(512)..=ByteSize::gb(1)),
        ];

        for task in [
            Task {
                total: ByteSize::gb(4),
                solved: false,
                expected: Some([ByteSize::gb(4), ByteSize::gb(1), ByteSize::mb(512)]),
            },
            Task {
                total: ByteSize::gb(6),
                solved: true,
                expected: None,
            },
            Task {
                total: ByteSize::gb(8),
                solved: true,
                expected: Some([ByteSize::gb(4), ByteSize::gb(3), ByteSize::gb(1)]),
            },
            Task {
                total: ByteSize::gb(16),
                solved: true,
                expected: Some([ByteSize::gb(4), ByteSize::gb(11), ByteSize::gb(1)]),
            },
        ] {
            let (solved, sizes) = solve_constraints(task.total, &constraints);
            assert_eq!(solved, task.solved, "{task:?}");

            // Every size stays within its constraint bounds.
            assert_eq!(sizes[0], ByteSize::gb(4), "{task:?}");
            assert!(sizes[1] >= ByteSize::gb(1), "{task:?}");
            assert!(
                (ByteSize::mb(512)..=ByteSize::gb(1)).contains(&sizes[2]),
                "{task:?}"
            );

            // A solved task never hands out more than it was given.
            if task.solved {
                let distributed: u64 = sizes.iter().map(|size| size.as_u64()).sum();
                assert!(distributed <= task.total.as_u64(), "{task:?}");
            }

            if let Some(expected) = task.expected {
                assert_eq!(sizes, expected, "{task:?}");
            }
        }
    }
}