1use std::fmt;
2
3use oci_spec::runtime::{MemoryPolicyFlagType, MemoryPolicyModeType};
4
5use crate::syscall::{Syscall, SyscallError};
6
7#[derive(Debug, thiserror::Error)]
8pub enum MemoryPolicyError {
9 #[error("Invalid memory policy flag: {0}")]
10 InvalidFlag(String),
11
12 #[error("Invalid node specification: {0}")]
13 InvalidNodes(String),
14
15 #[error("Incompatible flag and mode combination: {0}")]
16 IncompatibleFlagMode(String),
17
18 #[error("Mutually exclusive flags: {0}")]
19 MutuallyExclusiveFlags(String),
20
21 #[error("Syscall error: {0}")]
22 Syscall(#[from] SyscallError),
23}
24
25type Result<T> = std::result::Result<T, MemoryPolicyError>;
26
/// Memory policy modes with their raw integer values.
///
/// The discriminant of each variant is the number that ends up in the `mode`
/// argument handed to `set_mempolicy` (before any flag bits are OR-ed in).
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyMode {
    /// Rendered as `MPOL_DEFAULT`.
    Default = 0,
    /// Rendered as `MPOL_PREFERRED`.
    Preferred = 1,
    /// Rendered as `MPOL_BIND`.
    Bind = 2,
    /// Rendered as `MPOL_INTERLEAVE`.
    Interleave = 3,
    /// Rendered as `MPOL_LOCAL`.
    Local = 4,
    /// Rendered as `MPOL_PREFERRED_MANY`.
    PreferredMany = 5,
    /// Rendered as `MPOL_WEIGHTED_INTERLEAVE`.
    WeightedInterleave = 6,
}
38
39impl From<MemoryPolicyMode> for i32 {
40 fn from(mode: MemoryPolicyMode) -> Self {
41 mode as i32
42 }
43}
44
45impl From<MemoryPolicyModeType> for MemoryPolicyMode {
46 fn from(mode: MemoryPolicyModeType) -> Self {
47 match mode {
48 MemoryPolicyModeType::MpolDefault => MemoryPolicyMode::Default,
49 MemoryPolicyModeType::MpolPreferred => MemoryPolicyMode::Preferred,
50 MemoryPolicyModeType::MpolBind => MemoryPolicyMode::Bind,
51 MemoryPolicyModeType::MpolInterleave => MemoryPolicyMode::Interleave,
52 MemoryPolicyModeType::MpolLocal => MemoryPolicyMode::Local,
53 MemoryPolicyModeType::MpolPreferredMany => MemoryPolicyMode::PreferredMany,
54 MemoryPolicyModeType::MpolWeightedInterleave => MemoryPolicyMode::WeightedInterleave,
55 }
56 }
57}
58
59impl fmt::Display for MemoryPolicyMode {
60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61 let s = match self {
62 MemoryPolicyMode::Default => "MPOL_DEFAULT",
63 MemoryPolicyMode::Preferred => "MPOL_PREFERRED",
64 MemoryPolicyMode::Bind => "MPOL_BIND",
65 MemoryPolicyMode::Interleave => "MPOL_INTERLEAVE",
66 MemoryPolicyMode::Local => "MPOL_LOCAL",
67 MemoryPolicyMode::PreferredMany => "MPOL_PREFERRED_MANY",
68 MemoryPolicyMode::WeightedInterleave => "MPOL_WEIGHTED_INTERLEAVE",
69 };
70 write!(f, "{}", s)
71 }
72}
73
/// Optional mode flags, each occupying a single bit.
///
/// The bits are OR-ed into the mode value before it is passed to
/// `set_mempolicy`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MemoryPolicyFlag {
    /// Bit 13; set for the spec flag `MpolFNumaBalancing`.
    NumaBalancing = 1 << 13,
    /// Bit 14; set for the spec flag `MpolFRelativeNodes`.
    RelativeNodes = 1 << 14,
    /// Bit 15; set for the spec flag `MpolFStaticNodes`.
    StaticNodes = 1 << 15,
}
81
82impl From<MemoryPolicyFlag> for u32 {
83 fn from(flag: MemoryPolicyFlag) -> Self {
84 flag as u32
85 }
86}
87
88struct ValidatedMemoryPolicy {
89 mode_with_flags: i32,
90 nodemask: Vec<libc::c_ulong>,
91 maxnode: u64,
92}
93
94fn validate_memory_policy(
95 memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
96) -> Result<Option<ValidatedMemoryPolicy>> {
97 let Some(policy) = memory_policy else {
98 return Ok(None);
99 };
100
101 let base_mode = MemoryPolicyMode::from(policy.mode());
102
103 let (flags_value, has_static, has_relative) = policy
104 .flags()
105 .as_ref()
106 .map(|flags| {
107 flags
108 .iter()
109 .fold((0u32, false, false), |(val, s, r), flag| match flag {
110 MemoryPolicyFlagType::MpolFNumaBalancing => {
111 (val | u32::from(MemoryPolicyFlag::NumaBalancing), s, r)
112 }
113 MemoryPolicyFlagType::MpolFStaticNodes => {
114 (val | u32::from(MemoryPolicyFlag::StaticNodes), true, r)
115 }
116 MemoryPolicyFlagType::MpolFRelativeNodes => {
117 (val | u32::from(MemoryPolicyFlag::RelativeNodes), s, true)
118 }
119 })
120 })
121 .unwrap_or((0, false, false));
122
123 if let Some(flags) = policy.flags() {
125 if flags.contains(&MemoryPolicyFlagType::MpolFNumaBalancing)
126 && base_mode != MemoryPolicyMode::Bind
127 {
128 return Err(MemoryPolicyError::IncompatibleFlagMode(
129 "MPOL_F_NUMA_BALANCING can only be used with MPOL_BIND".to_string(),
130 ));
131 }
132 }
133
134 if has_static && has_relative {
135 return Err(MemoryPolicyError::MutuallyExclusiveFlags(
136 "MPOL_F_STATIC_NODES and MPOL_F_RELATIVE_NODES are mutually exclusive".to_string(),
137 ));
138 }
139
140 let mode_with_flags = i32::from(base_mode) | (flags_value as i32);
141
142 match base_mode {
143 MemoryPolicyMode::Default | MemoryPolicyMode::Local => {
144 let mode_name = base_mode.to_string();
145
146 if let Some(nodes) = policy.nodes() {
147 if !nodes.trim().is_empty() {
148 return Err(MemoryPolicyError::InvalidNodes(format!(
149 "{} does not accept node specification",
150 mode_name
151 )));
152 }
153 }
154 if flags_value != 0 {
155 return Err(MemoryPolicyError::InvalidFlag(format!(
156 "{} does not accept flags",
157 mode_name
158 )));
159 }
160 Ok(Some(ValidatedMemoryPolicy {
161 mode_with_flags,
162 nodemask: Vec::new(),
163 maxnode: 0,
164 }))
165 }
166 MemoryPolicyMode::Preferred => {
167 let relative_or_static: u32 = u32::from(MemoryPolicyFlag::RelativeNodes)
168 | u32::from(MemoryPolicyFlag::StaticNodes);
169
170 let check_empty_nodes_flags = |flags_value: u32| -> Result<()> {
171 if flags_value & relative_or_static != 0u32 {
172 return Err(MemoryPolicyError::IncompatibleFlagMode(
173 "MPOL_PREFERRED with empty nodes cannot use MPOL_F_STATIC_NODES or MPOL_F_RELATIVE_NODES flags".to_string(),
174 ));
175 }
176 Ok(())
177 };
178
179 match policy.nodes() {
180 None => {
181 check_empty_nodes_flags(flags_value)?;
182 Ok(Some(ValidatedMemoryPolicy {
183 mode_with_flags,
184 nodemask: Vec::new(),
185 maxnode: 0,
186 }))
187 }
188 Some(nodes) if nodes.trim().is_empty() => {
189 check_empty_nodes_flags(flags_value)?;
190 Ok(Some(ValidatedMemoryPolicy {
191 mode_with_flags,
192 nodemask: Vec::new(),
193 maxnode: 0,
194 }))
195 }
196 Some(nodes) => {
197 let (nodemask, maxnode) = build_nodemask(nodes)?;
198 if maxnode == 0 {
199 check_empty_nodes_flags(flags_value)?;
200 return Ok(Some(ValidatedMemoryPolicy {
201 mode_with_flags,
202 nodemask: Vec::new(),
203 maxnode: 0,
204 }));
205 }
206 Ok(Some(ValidatedMemoryPolicy {
207 mode_with_flags,
208 nodemask,
209 maxnode,
210 }))
211 }
212 }
213 }
214 _ => {
215 let mode_name = base_mode.to_string();
216 let nodes = match policy.nodes() {
217 None => {
218 return Err(MemoryPolicyError::InvalidNodes(format!(
219 "Mode {} requires non-empty node specification",
220 mode_name
221 )));
222 }
223 Some(nodes) if nodes.trim().is_empty() => {
224 return Err(MemoryPolicyError::InvalidNodes(format!(
225 "Mode {} requires non-empty node specification",
226 mode_name
227 )));
228 }
229 Some(nodes) => nodes,
230 };
231 let (nodemask, maxnode) = build_nodemask(nodes)?;
232 if maxnode == 0 {
233 return Err(MemoryPolicyError::InvalidNodes(format!(
234 "Mode {} requires non-empty node specification (parsed result is empty)",
235 mode_name
236 )));
237 }
238 Ok(Some(ValidatedMemoryPolicy {
239 mode_with_flags,
240 nodemask,
241 maxnode,
242 }))
243 }
244 }
245}
246
247pub fn setup_memory_policy(
251 memory_policy: &Option<oci_spec::runtime::LinuxMemoryPolicy>,
252 syscall: &dyn Syscall,
253) -> Result<()> {
254 let validated = validate_memory_policy(memory_policy)?;
255 if let Some(valid) = validated {
256 syscall
257 .set_mempolicy(valid.mode_with_flags, &valid.nodemask, valid.maxnode)
258 .map_err(|err| {
259 tracing::error!(?err, "failed to set memory policy");
260 MemoryPolicyError::Syscall(err)
261 })?;
262 }
263 Ok(())
264}
265
266fn build_nodemask(nodes: &str) -> Result<(Vec<libc::c_ulong>, u64)> {
268 let node_ids = parse_node_string(nodes)?;
269
270 if node_ids.is_empty() {
271 return Ok((Vec::new(), 0));
273 }
274
275 let highest_node = node_ids.iter().max().copied().unwrap_or(0) as usize;
277
278 let bits_per_ulong = std::mem::size_of::<libc::c_ulong>() * 8;
280 let num_ulongs = (highest_node / bits_per_ulong) + 1;
281
282 let maxnode = (num_ulongs * bits_per_ulong) as u64;
284
285 let mut nodemask = vec![0 as libc::c_ulong; num_ulongs];
287
288 for node_id in node_ids {
290 let node_id = node_id as usize;
291 let word_index = node_id / bits_per_ulong;
292 let bit_index = node_id % bits_per_ulong;
293
294 if word_index < nodemask.len() {
295 nodemask[word_index] |= (1 as libc::c_ulong) << bit_index;
296 }
297 }
298
299 Ok((nodemask, maxnode))
300}
301
302fn parse_node_string(nodes: &str) -> Result<Vec<u32>> {
303 let mut node_ids = Vec::new();
304
305 let nodes = nodes.trim();
307 if nodes.is_empty() {
308 return Ok(node_ids);
309 }
310
311 for range in nodes.split(',') {
312 let range = range.trim();
313 if range.is_empty() {
314 continue; }
316
317 if let Some(dash_pos) = range.find('-') {
318 let start_str = range[..dash_pos].trim();
320 let end_str = range[dash_pos + 1..].trim();
321
322 let start: u32 = start_str.parse().map_err(|_| {
323 MemoryPolicyError::InvalidNodes(format!("Invalid node range start: {}", start_str))
324 })?;
325 let end: u32 = end_str.parse().map_err(|_| {
326 MemoryPolicyError::InvalidNodes(format!("Invalid node range end: {}", end_str))
327 })?;
328
329 if start > end {
330 return Err(MemoryPolicyError::InvalidNodes(format!(
331 "Invalid node range: {}-{}",
332 start, end
333 )));
334 }
335
336 for node in start..=end {
337 node_ids.push(node);
338 }
339 } else {
340 let node: u32 = range
342 .parse()
343 .map_err(|_| MemoryPolicyError::InvalidNodes(format!("Invalid node: {}", range)))?;
344
345 node_ids.push(node);
346 }
347 }
348
349 Ok(node_ids)
350}
351
#[cfg(test)]
mod tests {
    use oci_spec::runtime::{
        LinuxMemoryPolicy, LinuxMemoryPolicyBuilder, MemoryPolicyFlagType, MemoryPolicyModeType,
    };

    use super::*;
    use crate::syscall::syscall::create_syscall;
    use crate::syscall::test::TestHelperSyscall;

    /// Builds a policy with all three fields set explicitly.
    fn make_policy(
        mode: MemoryPolicyModeType,
        nodes: &str,
        flags: Vec<MemoryPolicyFlagType>,
    ) -> LinuxMemoryPolicy {
        LinuxMemoryPolicyBuilder::default()
            .mode(mode)
            .nodes(nodes.to_string())
            .flags(flags)
            .build()
            .unwrap()
    }

    #[test]
    fn test_parse_node_string() {
        // Inputs that must parse, paired with the expected expansion.
        let accepted: Vec<(&str, Vec<u32>)> = vec![
            // empty input
            ("", vec![]),
            // single nodes
            ("0", vec![0]),
            ("1", vec![1]),
            ("2", vec![2]),
            // ranges
            ("0-2", vec![0, 1, 2]),
            ("1-3", vec![1, 2, 3]),
            // comma-separated lists
            ("0,2", vec![0, 2]),
            ("0,1,3", vec![0, 1, 3]),
            // lists mixed with ranges
            ("0-1,3", vec![0, 1, 3]),
            ("0,2-3", vec![0, 2, 3]),
            // surrounding whitespace
            (" 0 , 2 ", vec![0, 2]),
            (" 0 - 2 ", vec![0, 1, 2]),
            // whitespace / separators only
            (" ", vec![]),
            (" , , ", vec![]),
        ];
        for (input, expected) in accepted {
            assert_eq!(parse_node_string(input).unwrap(), expected);
        }

        // Reversed range, non-numeric entry, non-numeric range end.
        for input in ["2-1", "abc", "0-abc"].iter() {
            assert!(parse_node_string(input).is_err());
        }
    }

    #[test]
    fn test_setup_memory_policy() {
        let syscall = create_syscall();
        // Snapshot of every `set_mempolicy` call recorded by the test syscall so far.
        let recorded = || {
            syscall
                .as_any()
                .downcast_ref::<TestHelperSyscall>()
                .unwrap()
                .get_mempolicy_args()
        };

        // No policy configured: nothing to do.
        assert!(setup_memory_policy(&None, syscall.as_ref()).is_ok());

        // MPOL_BIND on nodes 0 and 1, no flags.
        let policy = make_policy(MemoryPolicyModeType::MpolBind, "0,1", vec![]);
        assert!(setup_memory_policy(&Some(policy), syscall.as_ref()).is_ok());

        let got_args = recorded();
        assert_eq!(got_args.len(), 1);
        assert_eq!(got_args[0].mode, 2);
        assert_eq!(got_args[0].nodemask.len(), 1);
        assert_eq!(got_args[0].nodemask[0], 3);
        assert_eq!(got_args[0].maxnode, 64);

        // MPOL_BIND on node 0 with MPOL_F_STATIC_NODES.
        let policy_with_flags = make_policy(
            MemoryPolicyModeType::MpolBind,
            "0",
            vec![MemoryPolicyFlagType::MpolFStaticNodes],
        );
        assert!(setup_memory_policy(&Some(policy_with_flags), syscall.as_ref()).is_ok());

        let got_args_with_flags = recorded();
        assert_eq!(got_args_with_flags.len(), 2);
        assert_eq!(got_args_with_flags[1].mode, 2 | (1 << 15));
        assert_eq!(got_args_with_flags[1].nodemask.len(), 1);
        assert_eq!(got_args_with_flags[1].nodemask[0], 1);
        assert_eq!(got_args_with_flags[1].maxnode, 64);

        // Every policy below must be rejected by validation.
        let rejected = vec![
            // static + relative flags together
            make_policy(
                MemoryPolicyModeType::MpolBind,
                "0",
                vec![
                    MemoryPolicyFlagType::MpolFStaticNodes,
                    MemoryPolicyFlagType::MpolFRelativeNodes,
                ],
            ),
            // NUMA balancing with a mode other than MPOL_BIND
            make_policy(
                MemoryPolicyModeType::MpolInterleave,
                "0",
                vec![MemoryPolicyFlagType::MpolFNumaBalancing],
            ),
            // MPOL_DEFAULT with nodes
            make_policy(MemoryPolicyModeType::MpolDefault, "0", vec![]),
            // MPOL_DEFAULT with flags
            make_policy(
                MemoryPolicyModeType::MpolDefault,
                "",
                vec![MemoryPolicyFlagType::MpolFStaticNodes],
            ),
            // MPOL_LOCAL with nodes
            make_policy(MemoryPolicyModeType::MpolLocal, "0", vec![]),
            // MPOL_BIND with empty nodes
            make_policy(MemoryPolicyModeType::MpolBind, "", vec![]),
            // MPOL_BIND with whitespace-only nodes
            make_policy(MemoryPolicyModeType::MpolBind, " ", vec![]),
            // MPOL_PREFERRED with empty nodes and a node-interpretation flag
            make_policy(
                MemoryPolicyModeType::MpolPreferred,
                "",
                vec![MemoryPolicyFlagType::MpolFStaticNodes],
            ),
        ];
        for (index, invalid) in rejected.into_iter().enumerate() {
            assert!(
                setup_memory_policy(&Some(invalid), syscall.as_ref()).is_err(),
                "rejected case #{} was unexpectedly accepted",
                index
            );
        }
    }
}