Skip to main content

libcontainer/process/
memory_policy.rs

1use std::fmt;
2
3use oci_spec::runtime::{MemoryPolicyFlagType, MemoryPolicyModeType};
4
5use crate::syscall::{Syscall, SyscallError};
6
7#[derive(Debug, thiserror::Error)]
8pub enum MemoryPolicyError {
9    #[error("Invalid memory policy flag: {0}")]
10    InvalidFlag(String),
11
12    #[error("Invalid node specification: {0}")]
13    InvalidNodes(String),
14
15    #[error("Incompatible flag and mode combination: {0}")]
16    IncompatibleFlagMode(String),
17
18    #[error("Mutually exclusive flags: {0}")]
19    MutuallyExclusiveFlags(String),
20
21    #[error("Syscall error: {0}")]
22    Syscall(#[from] SyscallError),
23}
24
25type Result<T> = std::result::Result<T, MemoryPolicyError>;
26
/// Base memory policy modes understood by `set_mempolicy(2)`.
///
/// Each discriminant equals the corresponding `MPOL_*` constant from
/// `<linux/mempolicy.h>`, so converting a variant to `i32` yields the value the kernel expects.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyMode {
    /// `MPOL_DEFAULT`
    Default = 0,
    /// `MPOL_PREFERRED`
    Preferred = 1,
    /// `MPOL_BIND`
    Bind = 2,
    /// `MPOL_INTERLEAVE`
    Interleave = 3,
    /// `MPOL_LOCAL`
    Local = 4,
    /// `MPOL_PREFERRED_MANY`
    PreferredMany = 5,
    /// `MPOL_WEIGHTED_INTERLEAVE`
    WeightedInterleave = 6,
}

impl From<MemoryPolicyMode> for i32 {
    /// Returns the raw `MPOL_*` value; `#[repr(i32)]` makes this a plain discriminant read.
    fn from(value: MemoryPolicyMode) -> i32 {
        value as i32
    }
}
44
45impl From<MemoryPolicyModeType> for MemoryPolicyMode {
46    fn from(mode: MemoryPolicyModeType) -> Self {
47        match mode {
48            MemoryPolicyModeType::MpolDefault => MemoryPolicyMode::Default,
49            MemoryPolicyModeType::MpolPreferred => MemoryPolicyMode::Preferred,
50            MemoryPolicyModeType::MpolBind => MemoryPolicyMode::Bind,
51            MemoryPolicyModeType::MpolInterleave => MemoryPolicyMode::Interleave,
52            MemoryPolicyModeType::MpolLocal => MemoryPolicyMode::Local,
53            MemoryPolicyModeType::MpolPreferredMany => MemoryPolicyMode::PreferredMany,
54            MemoryPolicyModeType::MpolWeightedInterleave => MemoryPolicyMode::WeightedInterleave,
55        }
56    }
57}
58
59impl fmt::Display for MemoryPolicyMode {
60    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61        let s = match self {
62            MemoryPolicyMode::Default => "MPOL_DEFAULT",
63            MemoryPolicyMode::Preferred => "MPOL_PREFERRED",
64            MemoryPolicyMode::Bind => "MPOL_BIND",
65            MemoryPolicyMode::Interleave => "MPOL_INTERLEAVE",
66            MemoryPolicyMode::Local => "MPOL_LOCAL",
67            MemoryPolicyMode::PreferredMany => "MPOL_PREFERRED_MANY",
68            MemoryPolicyMode::WeightedInterleave => "MPOL_WEIGHTED_INTERLEAVE",
69        };
70        write!(f, "{}", s)
71    }
72}
73
/// Optional mode flags that are OR-ed into the `mode` argument of `set_mempolicy(2)`.
///
/// The discriminants are the `MPOL_F_*` bit values from `<linux/mempolicy.h>`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyFlag {
    /// `MPOL_F_NUMA_BALANCING` (bit 13)
    NumaBalancing = 0x2000,
    /// `MPOL_F_RELATIVE_NODES` (bit 14)
    RelativeNodes = 0x4000,
    /// `MPOL_F_STATIC_NODES` (bit 15)
    StaticNodes = 0x8000,
}

impl From<MemoryPolicyFlag> for u32 {
    /// Returns the raw flag bit.
    fn from(value: MemoryPolicyFlag) -> u32 {
        value as u32
    }
}
87
88struct ValidatedMemoryPolicy {
89    mode_with_flags: i32,
90    nodemask: Vec<libc::c_ulong>,
91    maxnode: u64,
92}
93
94fn validate_memory_policy(
95    memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
96) -> Result<Option<ValidatedMemoryPolicy>> {
97    let Some(policy) = memory_policy else {
98        return Ok(None);
99    };
100
101    let base_mode = MemoryPolicyMode::from(policy.mode());
102
103    let (flags_value, has_static, has_relative) = policy
104        .flags()
105        .as_ref()
106        .map(|flags| {
107            flags
108                .iter()
109                .fold((0u32, false, false), |(val, s, r), flag| match flag {
110                    MemoryPolicyFlagType::MpolFNumaBalancing => {
111                        (val | u32::from(MemoryPolicyFlag::NumaBalancing), s, r)
112                    }
113                    MemoryPolicyFlagType::MpolFStaticNodes => {
114                        (val | u32::from(MemoryPolicyFlag::StaticNodes), true, r)
115                    }
116                    MemoryPolicyFlagType::MpolFRelativeNodes => {
117                        (val | u32::from(MemoryPolicyFlag::RelativeNodes), s, true)
118                    }
119                })
120        })
121        .unwrap_or((0, false, false));
122
123    // Validate flags
124    if let Some(flags) = policy.flags() {
125        if flags.contains(&MemoryPolicyFlagType::MpolFNumaBalancing)
126            && base_mode != MemoryPolicyMode::Bind
127        {
128            return Err(MemoryPolicyError::IncompatibleFlagMode(
129                "MPOL_F_NUMA_BALANCING can only be used with MPOL_BIND".to_string(),
130            ));
131        }
132    }
133
134    if has_static && has_relative {
135        return Err(MemoryPolicyError::MutuallyExclusiveFlags(
136            "MPOL_F_STATIC_NODES and MPOL_F_RELATIVE_NODES are mutually exclusive".to_string(),
137        ));
138    }
139
140    let mode_with_flags = i32::from(base_mode) | (flags_value as i32);
141
142    match base_mode {
143        MemoryPolicyMode::Default | MemoryPolicyMode::Local => {
144            let mode_name = base_mode.to_string();
145
146            if let Some(nodes) = policy.nodes() {
147                if !nodes.trim().is_empty() {
148                    return Err(MemoryPolicyError::InvalidNodes(format!(
149                        "{} does not accept node specification",
150                        mode_name
151                    )));
152                }
153            }
154            if flags_value != 0 {
155                return Err(MemoryPolicyError::InvalidFlag(format!(
156                    "{} does not accept flags",
157                    mode_name
158                )));
159            }
160            Ok(Some(ValidatedMemoryPolicy {
161                mode_with_flags,
162                nodemask: Vec::new(),
163                maxnode: 0,
164            }))
165        }
166        MemoryPolicyMode::Preferred => {
167            let relative_or_static: u32 = u32::from(MemoryPolicyFlag::RelativeNodes)
168                | u32::from(MemoryPolicyFlag::StaticNodes);
169
170            let check_empty_nodes_flags = |flags_value: u32| -> Result<()> {
171                if flags_value & relative_or_static != 0u32 {
172                    return Err(MemoryPolicyError::IncompatibleFlagMode(
173                        "MPOL_PREFERRED with empty nodes cannot use MPOL_F_STATIC_NODES or MPOL_F_RELATIVE_NODES flags".to_string(),
174                    ));
175                }
176                Ok(())
177            };
178
179            match policy.nodes() {
180                None => {
181                    check_empty_nodes_flags(flags_value)?;
182                    Ok(Some(ValidatedMemoryPolicy {
183                        mode_with_flags,
184                        nodemask: Vec::new(),
185                        maxnode: 0,
186                    }))
187                }
188                Some(nodes) if nodes.trim().is_empty() => {
189                    check_empty_nodes_flags(flags_value)?;
190                    Ok(Some(ValidatedMemoryPolicy {
191                        mode_with_flags,
192                        nodemask: Vec::new(),
193                        maxnode: 0,
194                    }))
195                }
196                Some(nodes) => {
197                    let (nodemask, maxnode) = build_nodemask(nodes)?;
198                    if maxnode == 0 {
199                        check_empty_nodes_flags(flags_value)?;
200                        return Ok(Some(ValidatedMemoryPolicy {
201                            mode_with_flags,
202                            nodemask: Vec::new(),
203                            maxnode: 0,
204                        }));
205                    }
206                    Ok(Some(ValidatedMemoryPolicy {
207                        mode_with_flags,
208                        nodemask,
209                        maxnode,
210                    }))
211                }
212            }
213        }
214        _ => {
215            let mode_name = base_mode.to_string();
216            let nodes = match policy.nodes() {
217                None => {
218                    return Err(MemoryPolicyError::InvalidNodes(format!(
219                        "Mode {} requires non-empty node specification",
220                        mode_name
221                    )));
222                }
223                Some(nodes) if nodes.trim().is_empty() => {
224                    return Err(MemoryPolicyError::InvalidNodes(format!(
225                        "Mode {} requires non-empty node specification",
226                        mode_name
227                    )));
228                }
229                Some(nodes) => nodes,
230            };
231            let (nodemask, maxnode) = build_nodemask(nodes)?;
232            if maxnode == 0 {
233                return Err(MemoryPolicyError::InvalidNodes(format!(
234                    "Mode {} requires non-empty node specification (parsed result is empty)",
235                    mode_name
236                )));
237            }
238            Ok(Some(ValidatedMemoryPolicy {
239                mode_with_flags,
240                nodemask,
241                maxnode,
242            }))
243        }
244    }
245}
246
247/// Configure the memory policy for the process using set_mempolicy(2).
248///
249/// See: https://man7.org/linux/man-pages/man2/set_mempolicy.2.html
250pub fn setup_memory_policy(
251    memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
252    syscall: &dyn Syscall,
253) -> Result<()> {
254    let validated = validate_memory_policy(memory_policy)?;
255    if let Some(valid) = validated {
256        syscall
257            .set_mempolicy(valid.mode_with_flags, &valid.nodemask, valid.maxnode)
258            .map_err(|err| {
259                tracing::error!(?err, "failed to set memory policy");
260                MemoryPolicyError::Syscall(err)
261            })?;
262    }
263    Ok(())
264}
265
266// Build a proper nodemask for set_mempolicy
267fn build_nodemask(nodes: &str) -> Result<(Vec<libc::c_ulong>, u64)> {
268    let node_ids = parse_node_string(nodes)?;
269
270    if node_ids.is_empty() {
271        // Empty nodemask - return NULL equivalent (empty vector)
272        return Ok((Vec::new(), 0));
273    }
274
275    // Find the highest node ID
276    let highest_node = node_ids.iter().max().copied().unwrap_or(0) as usize;
277
278    // Calculate how many c_ulong values we need to store the bitmask
279    let bits_per_ulong = std::mem::size_of::<libc::c_ulong>() * 8;
280    let num_ulongs = (highest_node / bits_per_ulong) + 1;
281
282    // Calculate maxnode = number of bits provided in nodemask
283    let maxnode = (num_ulongs * bits_per_ulong) as u64;
284
285    // Build the nodemask array as Vec<c_ulong>
286    let mut nodemask = vec![0 as libc::c_ulong; num_ulongs];
287
288    // Set bits for each node ID
289    for node_id in node_ids {
290        let node_id = node_id as usize;
291        let word_index = node_id / bits_per_ulong;
292        let bit_index = node_id % bits_per_ulong;
293
294        if word_index < nodemask.len() {
295            nodemask[word_index] |= (1 as libc::c_ulong) << bit_index;
296        }
297    }
298
299    Ok((nodemask, maxnode))
300}
301
302fn parse_node_string(nodes: &str) -> Result<Vec<u32>> {
303    let mut node_ids = Vec::new();
304
305    // Trim whitespace and check for empty string
306    let nodes = nodes.trim();
307    if nodes.is_empty() {
308        return Ok(node_ids);
309    }
310
311    for range in nodes.split(',') {
312        let range = range.trim();
313        if range.is_empty() {
314            continue; // Skip empty entries caused by multiple commas
315        }
316
317        if let Some(dash_pos) = range.find('-') {
318            // Range format: "node1-node2"
319            let start_str = range[..dash_pos].trim();
320            let end_str = range[dash_pos + 1..].trim();
321
322            let start: u32 = start_str.parse().map_err(|_| {
323                MemoryPolicyError::InvalidNodes(format!("Invalid node range start: {}", start_str))
324            })?;
325            let end: u32 = end_str.parse().map_err(|_| {
326                MemoryPolicyError::InvalidNodes(format!("Invalid node range end: {}", end_str))
327            })?;
328
329            if start > end {
330                return Err(MemoryPolicyError::InvalidNodes(format!(
331                    "Invalid node range: {}-{}",
332                    start, end
333                )));
334            }
335
336            for node in start..=end {
337                node_ids.push(node);
338            }
339        } else {
340            // Single node
341            let node: u32 = range
342                .parse()
343                .map_err(|_| MemoryPolicyError::InvalidNodes(format!("Invalid node: {}", range)))?;
344
345            node_ids.push(node);
346        }
347    }
348
349    Ok(node_ids)
350}
351
#[cfg(test)]
mod tests {
    use oci_spec::runtime::{
        LinuxMemoryPolicy, LinuxMemoryPolicyBuilder, MemoryPolicyFlagType, MemoryPolicyModeType,
    };

    use super::*;
    use crate::syscall::syscall::create_syscall;
    use crate::syscall::test::TestHelperSyscall;

    /// Builds a `LinuxMemoryPolicy` from its parts. Every policy used in these tests is
    /// structurally complete, so a builder failure indicates a broken test.
    fn make_policy(
        mode: MemoryPolicyModeType,
        nodes: &str,
        flags: Vec<MemoryPolicyFlagType>,
    ) -> LinuxMemoryPolicy {
        LinuxMemoryPolicyBuilder::default()
            .mode(mode)
            .nodes(nodes.to_string())
            .flags(flags)
            .build()
            .unwrap()
    }

    #[test]
    fn test_parse_node_string() {
        // (input, node IDs in the order they are produced)
        let accepted: Vec<(&str, Vec<u32>)> = vec![
            // empty, blank and separator-only inputs
            ("", vec![]),
            ("   ", vec![]),
            (" , , ", vec![]),
            // single nodes
            ("0", vec![0]),
            ("1", vec![1]),
            ("2", vec![2]),
            // inclusive ranges
            ("0-2", vec![0, 1, 2]),
            ("1-3", vec![1, 2, 3]),
            // plain lists
            ("0,2", vec![0, 2]),
            ("0,1,3", vec![0, 1, 3]),
            // ranges mixed with single nodes
            ("0-1,3", vec![0, 1, 3]),
            ("0,2-3", vec![0, 2, 3]),
            // surrounding whitespace is ignored
            (" 0 , 2 ", vec![0, 2]),
            (" 0 - 2 ", vec![0, 1, 2]),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_node_string(input).unwrap(), expected, "input {:?}", input);
        }

        // reversed range, non-numeric node, non-numeric range end
        for input in ["2-1", "abc", "0-abc"] {
            assert!(
                parse_node_string(input).is_err(),
                "input {:?} should be rejected",
                input
            );
        }
    }

    #[test]
    fn test_setup_memory_policy() {
        let syscall = create_syscall();
        // Snapshot of every set_mempolicy call the test syscall has recorded so far.
        let recorded_calls = || {
            syscall
                .as_any()
                .downcast_ref::<TestHelperSyscall>()
                .unwrap()
                .get_mempolicy_args()
        };

        // No policy configured: nothing to do.
        assert!(setup_memory_policy(&None, syscall.as_ref()).is_ok());

        // MPOL_BIND on nodes 0 and 1 without flags.
        let bind = make_policy(MemoryPolicyModeType::MpolBind, "0,1", vec![]);
        assert!(setup_memory_policy(&Some(bind), syscall.as_ref()).is_ok());

        let calls = recorded_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].mode, 2); // MPOL_BIND
        assert_eq!(calls[0].nodemask.len(), 1); // one word covers nodes 0 and 1
        assert_eq!(calls[0].nodemask[0], 0b11); // bits 0 and 1
        assert_eq!(calls[0].maxnode, 64); // one 64-bit word

        // MPOL_BIND on node 0 with MPOL_F_STATIC_NODES: the flag is OR-ed into the mode.
        let bind_static = make_policy(
            MemoryPolicyModeType::MpolBind,
            "0",
            vec![MemoryPolicyFlagType::MpolFStaticNodes],
        );
        assert!(setup_memory_policy(&Some(bind_static), syscall.as_ref()).is_ok());

        let calls = recorded_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[1].mode, 2 | (1 << 15)); // MPOL_BIND | MPOL_F_STATIC_NODES
        assert_eq!(calls[1].nodemask.len(), 1);
        assert_eq!(calls[1].nodemask[0], 0b1); // bit 0
        assert_eq!(calls[1].maxnode, 64); // one 64-bit word

        // Every one of these must be rejected during validation.
        let rejected = [
            (
                "MPOL_F_STATIC_NODES together with MPOL_F_RELATIVE_NODES",
                make_policy(
                    MemoryPolicyModeType::MpolBind,
                    "0",
                    vec![
                        MemoryPolicyFlagType::MpolFStaticNodes,
                        MemoryPolicyFlagType::MpolFRelativeNodes,
                    ],
                ),
            ),
            (
                "MPOL_F_NUMA_BALANCING with a mode other than MPOL_BIND",
                make_policy(
                    MemoryPolicyModeType::MpolInterleave,
                    "0",
                    vec![MemoryPolicyFlagType::MpolFNumaBalancing],
                ),
            ),
            (
                "MPOL_DEFAULT with nodes",
                make_policy(MemoryPolicyModeType::MpolDefault, "0", vec![]),
            ),
            (
                "MPOL_DEFAULT with a flag",
                make_policy(
                    MemoryPolicyModeType::MpolDefault,
                    "",
                    vec![MemoryPolicyFlagType::MpolFStaticNodes],
                ),
            ),
            (
                "MPOL_LOCAL with nodes",
                make_policy(MemoryPolicyModeType::MpolLocal, "0", vec![]),
            ),
            (
                "MPOL_BIND with empty nodes",
                make_policy(MemoryPolicyModeType::MpolBind, "", vec![]),
            ),
            (
                "MPOL_BIND with whitespace-only nodes",
                make_policy(MemoryPolicyModeType::MpolBind, "   ", vec![]),
            ),
            (
                "MPOL_PREFERRED with empty nodes and MPOL_F_STATIC_NODES",
                make_policy(
                    MemoryPolicyModeType::MpolPreferred,
                    "",
                    vec![MemoryPolicyFlagType::MpolFStaticNodes],
                ),
            ),
        ];
        for (description, policy) in rejected {
            assert!(
                setup_memory_policy(&Some(policy), syscall.as_ref()).is_err(),
                "expected rejection: {}",
                description
            );
        }
    }
}