1use std::alloc::{Layout, alloc, dealloc};
33use std::ptr::NonNull;
34use std::sync::atomic::{AtomicUsize, Ordering};
35
/// Index used to refer to a NUMA node.
pub type NumaNode = u32;

/// Shorthand for results whose error type is [`NumaError`].
pub type NumaResult<T> = Result<T, NumaError>;

/// Failure modes reported by the NUMA helpers in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NumaError {
    /// NUMA functionality is not available.
    NotAvailable,
    /// The contained node was rejected as invalid.
    InvalidNode(NumaNode),
    /// Memory could not be allocated.
    AllocationFailed,
    /// The calling thread could not be pinned.
    PinningFailed,
    /// A system-level failure described by the contained message.
    SystemError(String),
}

impl std::fmt::Display for NumaError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants with a payload are formatted immediately; the remaining variants
        // have a fixed message that is written verbatim below.
        let fixed = match self {
            Self::InvalidNode(node) => return write!(f, "invalid NUMA node: {}", node),
            Self::SystemError(message) => return write!(f, "system error: {}", message),
            Self::NotAvailable => "NUMA not available",
            Self::AllocationFailed => "NUMA allocation failed",
            Self::PinningFailed => "thread pinning failed",
        };
        f.write_str(fixed)
    }
}

impl std::error::Error for NumaError {}
70
71#[derive(Debug, Clone)]
73pub struct NumaTopology {
74 pub num_nodes: usize,
76 pub cpus_per_node: Vec<Vec<usize>>,
78 pub memory_per_node: Vec<usize>,
80 pub distances: Vec<Vec<u32>>,
82}
83
84impl NumaTopology {
85 pub fn detect() -> Self {
87 #[cfg(target_os = "linux")]
91 {
92 Self::detect_linux().unwrap_or_else(Self::single_node)
93 }
94
95 #[cfg(not(target_os = "linux"))]
96 {
97 Self::single_node()
98 }
99 }
100
101 pub fn single_node() -> Self {
103 let num_cpus = std::thread::available_parallelism()
104 .map(|n| n.get())
105 .unwrap_or(1);
106
107 Self {
108 num_nodes: 1,
109 cpus_per_node: vec![(0..num_cpus).collect()],
110 memory_per_node: vec![0], distances: vec![vec![10]], }
113 }
114
115 #[cfg(target_os = "linux")]
116 fn detect_linux() -> Option<Self> {
117 use std::fs;
118
119 let node_path = "/sys/devices/system/node";
121 let entries = fs::read_dir(node_path).ok()?;
122
123 let mut nodes = Vec::new();
124 for entry in entries.flatten() {
125 let name = entry.file_name();
126 let name_str = name.to_string_lossy();
127 if name_str.starts_with("node") {
128 if let Ok(num) = name_str[4..].parse::<usize>() {
129 nodes.push(num);
130 }
131 }
132 }
133
134 if nodes.is_empty() {
135 return None;
136 }
137
138 nodes.sort();
139 let num_nodes = nodes.len();
140
141 let mut cpus_per_node = Vec::new();
143 for node in &nodes {
144 let cpu_path = format!("{}/node{}/cpulist", node_path, node);
145 let cpulist = fs::read_to_string(cpu_path).ok()?;
146 let cpus = Self::parse_cpulist(&cpulist);
147 cpus_per_node.push(cpus);
148 }
149
150 let mut memory_per_node = Vec::new();
152 for node in &nodes {
153 let mem_path = format!("{}/node{}/meminfo", node_path, node);
154 let meminfo = fs::read_to_string(mem_path).unwrap_or_default();
155 let mem = Self::parse_meminfo(&meminfo);
156 memory_per_node.push(mem);
157 }
158
159 let mut distances = vec![vec![20u32; num_nodes]; num_nodes];
161 for i in 0..num_nodes {
162 distances[i][i] = 10; }
164
165 Some(Self {
166 num_nodes,
167 cpus_per_node,
168 memory_per_node,
169 distances,
170 })
171 }
172
173 #[cfg(target_os = "linux")]
174 fn parse_cpulist(cpulist: &str) -> Vec<usize> {
175 let mut cpus = Vec::new();
176 for part in cpulist.trim().split(',') {
177 if part.contains('-') {
178 let range: Vec<&str> = part.split('-').collect();
179 if range.len() == 2 {
180 if let (Ok(start), Ok(end)) =
181 (range[0].parse::<usize>(), range[1].parse::<usize>())
182 {
183 cpus.extend(start..=end);
184 }
185 }
186 } else if let Ok(cpu) = part.parse::<usize>() {
187 cpus.push(cpu);
188 }
189 }
190 cpus
191 }
192
193 #[cfg(target_os = "linux")]
194 fn parse_meminfo(meminfo: &str) -> usize {
195 for line in meminfo.lines() {
196 if line.starts_with("Node") && line.contains("MemTotal:") {
197 let parts: Vec<&str> = line.split_whitespace().collect();
198 if parts.len() >= 4 {
199 if let Ok(kb) = parts[3].parse::<usize>() {
200 return kb * 1024; }
202 }
203 }
204 }
205 0
206 }
207
208 pub fn local_cpus(&self, node: NumaNode) -> &[usize] {
210 self.cpus_per_node
211 .get(node as usize)
212 .map(|v| v.as_slice())
213 .unwrap_or(&[])
214 }
215
216 pub fn distance(&self, from: NumaNode, to: NumaNode) -> u32 {
218 self.distances
219 .get(from as usize)
220 .and_then(|row| row.get(to as usize))
221 .copied()
222 .unwrap_or(u32::MAX)
223 }
224
225 pub fn nearest_nodes(&self, node: NumaNode) -> Vec<NumaNode> {
227 let mut nodes: Vec<(NumaNode, u32)> = (0..self.num_nodes as NumaNode)
228 .map(|n| (n, self.distance(node, n)))
229 .collect();
230 nodes.sort_by_key(|&(_, d)| d);
231 nodes.into_iter().map(|(n, _)| n).collect()
232 }
233}
234
235pub struct NumaBuffer {
237 ptr: NonNull<u8>,
239 size: usize,
241 layout: Layout,
243 node: Option<NumaNode>,
245 faulted: bool,
247}
248
249unsafe impl Send for NumaBuffer {}
251unsafe impl Sync for NumaBuffer {}
252
253impl NumaBuffer {
254 #[inline]
256 pub fn as_ptr(&self) -> *const u8 {
257 self.ptr.as_ptr()
258 }
259
260 #[inline]
262 pub fn as_mut_ptr(&mut self) -> *mut u8 {
263 self.ptr.as_ptr()
264 }
265
266 #[inline]
268 pub fn len(&self) -> usize {
269 self.size
270 }
271
272 #[inline]
274 pub fn is_empty(&self) -> bool {
275 self.size == 0
276 }
277
278 #[inline]
280 pub fn node(&self) -> Option<NumaNode> {
281 self.node
282 }
283
284 #[inline]
286 pub fn is_faulted(&self) -> bool {
287 self.faulted
288 }
289
290 #[inline]
292 pub fn as_slice(&self) -> &[u8] {
293 unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size) }
294 }
295
296 #[inline]
298 pub fn as_mut_slice(&mut self) -> &mut [u8] {
299 unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.size) }
300 }
301}
302
303impl Drop for NumaBuffer {
304 fn drop(&mut self) {
305 unsafe {
306 dealloc(self.ptr.as_ptr(), self.layout);
307 }
308 }
309}
310
311pub struct NumaAllocator {
313 topology: NumaTopology,
315 page_size: usize,
317 allocated: AtomicUsize,
319 allocations_per_node: Vec<AtomicUsize>,
321}
322
323impl NumaAllocator {
324 pub fn new() -> NumaResult<Self> {
326 let topology = NumaTopology::detect();
327 Self::with_topology(topology)
328 }
329
330 pub fn with_topology(topology: NumaTopology) -> NumaResult<Self> {
332 let page_size = Self::get_page_size();
333 let allocations_per_node = (0..topology.num_nodes)
334 .map(|_| AtomicUsize::new(0))
335 .collect();
336
337 Ok(Self {
338 topology,
339 page_size,
340 allocated: AtomicUsize::new(0),
341 allocations_per_node,
342 })
343 }
344
345 fn get_page_size() -> usize {
347 #[cfg(unix)]
348 {
349 let page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
351 if page_size > 0 {
352 return page_size as usize;
353 }
354 }
355 4096
357 }
358
359 pub fn allocate_on_node(&self, size: usize, node: NumaNode) -> NumaResult<NumaBuffer> {
361 if node as usize >= self.topology.num_nodes {
362 return Err(NumaError::InvalidNode(node));
363 }
364
365 let aligned_size = (size + self.page_size - 1) & !(self.page_size - 1);
367 let layout = Layout::from_size_align(aligned_size, self.page_size)
368 .map_err(|_| NumaError::AllocationFailed)?;
369
370 let ptr = unsafe { alloc(layout) };
373 let ptr = NonNull::new(ptr).ok_or(NumaError::AllocationFailed)?;
374
375 self.allocated.fetch_add(aligned_size, Ordering::Relaxed);
376 self.allocations_per_node[node as usize].fetch_add(1, Ordering::Relaxed);
377
378 Ok(NumaBuffer {
379 ptr,
380 size: aligned_size,
381 layout,
382 node: Some(node),
383 faulted: false,
384 })
385 }
386
387 pub fn allocate(&self, size: usize) -> NumaResult<NumaBuffer> {
389 let aligned_size = (size + self.page_size - 1) & !(self.page_size - 1);
390 let layout = Layout::from_size_align(aligned_size, self.page_size)
391 .map_err(|_| NumaError::AllocationFailed)?;
392
393 let ptr = unsafe { alloc(layout) };
394 let ptr = NonNull::new(ptr).ok_or(NumaError::AllocationFailed)?;
395
396 self.allocated.fetch_add(aligned_size, Ordering::Relaxed);
397
398 Ok(NumaBuffer {
399 ptr,
400 size: aligned_size,
401 layout,
402 node: None,
403 faulted: false,
404 })
405 }
406
407 pub fn prefault(&self, buffer: &mut NumaBuffer) {
412 if buffer.faulted {
413 return;
414 }
415
416 let page_size = self.page_size;
418 let ptr = buffer.as_mut_ptr();
419 let size = buffer.len();
420
421 for offset in (0..size).step_by(page_size) {
422 unsafe {
423 std::ptr::write_volatile(ptr.add(offset), 0);
424 }
425 }
426
427 std::sync::atomic::fence(Ordering::SeqCst);
429 buffer.faulted = true;
430 }
431
432 pub fn topology(&self) -> &NumaTopology {
434 &self.topology
435 }
436
437 pub fn total_allocated(&self) -> usize {
439 self.allocated.load(Ordering::Relaxed)
440 }
441
442 pub fn page_size(&self) -> usize {
444 self.page_size
445 }
446}
447
448impl Default for NumaAllocator {
449 fn default() -> Self {
450 Self::new().unwrap_or_else(|_| Self::with_topology(NumaTopology::single_node()).unwrap())
451 }
452}
453
454pub struct ThreadPinner {
456 topology: NumaTopology,
457}
458
459impl ThreadPinner {
460 pub fn new(topology: NumaTopology) -> Self {
462 Self { topology }
463 }
464
465 #[cfg(target_os = "linux")]
467 pub fn pin_to_cpu(&self, cpu: usize) -> NumaResult<()> {
468 use std::mem::size_of;
469
470 unsafe {
471 let mut cpuset: libc::cpu_set_t = std::mem::zeroed();
472 libc::CPU_ZERO(&mut cpuset);
473 libc::CPU_SET(cpu, &mut cpuset);
474
475 let result = libc::sched_setaffinity(
476 0, size_of::<libc::cpu_set_t>(),
478 &cpuset,
479 );
480
481 if result == 0 {
482 Ok(())
483 } else {
484 Err(NumaError::PinningFailed)
485 }
486 }
487 }
488
489 #[cfg(not(target_os = "linux"))]
491 pub fn pin_to_cpu(&self, _cpu: usize) -> NumaResult<()> {
492 Err(NumaError::NotAvailable)
494 }
495
496 pub fn pin_to_node(&self, node: NumaNode) -> NumaResult<()> {
498 let cpus = self.topology.local_cpus(node);
499 if cpus.is_empty() {
500 return Err(NumaError::InvalidNode(node));
501 }
502
503 self.pin_to_cpu(cpus[0])
505 }
506
507 #[cfg(target_os = "linux")]
509 pub fn current_cpu(&self) -> Option<usize> {
510 unsafe {
511 let cpu = libc::sched_getcpu();
512 if cpu >= 0 { Some(cpu as usize) } else { None }
513 }
514 }
515
516 #[cfg(not(target_os = "linux"))]
518 pub fn current_cpu(&self) -> Option<usize> {
519 None
520 }
521
522 pub fn current_node(&self) -> Option<NumaNode> {
524 let cpu = self.current_cpu()?;
525
526 for (node, cpus) in self.topology.cpus_per_node.iter().enumerate() {
527 if cpus.contains(&cpu) {
528 return Some(node as NumaNode);
529 }
530 }
531
532 None
533 }
534}
535
536#[derive(Debug, Clone, Copy, PartialEq, Eq)]
538pub enum AllocationStrategy {
539 Fixed(NumaNode),
541 Local,
543 RoundRobin,
545 Interleave,
547}
548
549pub struct NumaVectorStorage<T> {
551 buffers: Vec<NumaBuffer>,
552 len: usize,
553 capacity: usize,
554 element_size: usize,
555 #[allow(dead_code)]
556 allocator: NumaAllocator,
557 _phantom: std::marker::PhantomData<T>,
558}
559
560impl<T: Copy> NumaVectorStorage<T> {
561 pub fn with_capacity_on_node(capacity: usize, node: NumaNode) -> NumaResult<Self> {
563 let allocator = NumaAllocator::new()?;
564 let element_size = std::mem::size_of::<T>();
565 let byte_size = capacity * element_size;
566
567 let mut buffer = allocator.allocate_on_node(byte_size, node)?;
568 allocator.prefault(&mut buffer);
569
570 Ok(Self {
571 buffers: vec![buffer],
572 len: 0,
573 capacity,
574 element_size,
575 allocator,
576 _phantom: std::marker::PhantomData,
577 })
578 }
579
580 pub fn get(&self, index: usize) -> Option<&T> {
582 if index >= self.len {
583 return None;
584 }
585
586 let buffer = &self.buffers[0];
588 let offset = index * self.element_size;
589
590 if offset + self.element_size > buffer.len() {
591 return None;
592 }
593
594 unsafe { Some(&*(buffer.as_ptr().add(offset) as *const T)) }
595 }
596
597 pub fn push(&mut self, value: T) -> NumaResult<()> {
599 if self.len >= self.capacity {
600 return Err(NumaError::AllocationFailed);
601 }
602
603 let buffer = &mut self.buffers[0];
604 let offset = self.len * self.element_size;
605
606 unsafe {
607 std::ptr::write(buffer.as_mut_ptr().add(offset) as *mut T, value);
608 }
609
610 self.len += 1;
611 Ok(())
612 }
613
614 pub fn len(&self) -> usize {
616 self.len
617 }
618
619 pub fn is_empty(&self) -> bool {
621 self.len == 0
622 }
623
624 pub fn capacity(&self) -> usize {
626 self.capacity
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::*;

    /// The single-node fallback has one node with at least one CPU.
    #[test]
    fn test_topology_single_node() {
        let topo = NumaTopology::single_node();
        assert_eq!(topo.num_nodes, 1);
        assert!(!topo.cpus_per_node[0].is_empty());
        assert_eq!(topo.distance(0, 0), 10);
    }

    /// Detection always yields at least one node.
    #[test]
    fn test_topology_detect() {
        let topo = NumaTopology::detect();
        assert!(topo.num_nodes >= 1);
        assert!(!topo.cpus_per_node.is_empty());
    }

    /// A plain allocation is at least as large as requested and starts un-faulted.
    #[test]
    fn test_allocator_basic() {
        let allocator = NumaAllocator::default();

        let buffer = allocator.allocate(4096).unwrap();
        assert!(buffer.len() >= 4096);
        assert!(!buffer.is_faulted());
    }

    /// A node-tagged allocation remembers its node.
    #[test]
    fn test_allocator_on_node() {
        let allocator = NumaAllocator::default();

        let buffer = allocator.allocate_on_node(8192, 0).unwrap();
        assert!(buffer.len() >= 8192);
        assert_eq!(buffer.node(), Some(0));
    }

    /// `prefault` flips the faulted flag.
    #[test]
    fn test_prefault() {
        let allocator = NumaAllocator::default();
        let mut buffer = allocator.allocate(65536).unwrap();

        assert!(!buffer.is_faulted());
        allocator.prefault(&mut buffer);
        assert!(buffer.is_faulted());
    }

    /// Bytes written through the mutable slice are read back unchanged.
    #[test]
    fn test_buffer_read_write() {
        let allocator = NumaAllocator::default();
        let mut buffer = allocator.allocate(4096).unwrap();
        allocator.prefault(&mut buffer);

        for (i, byte) in buffer.as_mut_slice().iter_mut().enumerate() {
            *byte = (i % 256) as u8;
        }

        for (i, &byte) in buffer.as_slice().iter().enumerate() {
            assert_eq!(byte, (i % 256) as u8);
        }
    }

    /// An out-of-range node is rejected with `InvalidNode`.
    #[test]
    fn test_invalid_node() {
        let allocator = NumaAllocator::default();
        let result = allocator.allocate_on_node(4096, 999);
        assert!(matches!(result, Err(NumaError::InvalidNode(999))));
    }

    /// The byte counter grows by at least the requested sizes.
    #[test]
    fn test_total_allocated() {
        let allocator = NumaAllocator::default();

        let initial = allocator.total_allocated();
        let _b1 = allocator.allocate(4096).unwrap();
        let _b2 = allocator.allocate(8192).unwrap();

        let total = allocator.total_allocated();
        assert!(total >= initial + 4096 + 8192);
    }

    /// Nodes are returned in order of increasing distance.
    #[test]
    fn test_nearest_nodes() {
        let topo = NumaTopology {
            num_nodes: 3,
            cpus_per_node: vec![vec![0, 1], vec![2, 3], vec![4, 5]],
            memory_per_node: vec![0, 0, 0],
            distances: vec![vec![10, 20, 30], vec![20, 10, 20], vec![30, 20, 10]],
        };

        let nearest = topo.nearest_nodes(0);
        assert_eq!(nearest[0], 0);
        assert_eq!(nearest[1], 1);
        assert_eq!(nearest[2], 2);
    }

    /// Fresh storage is empty and reports the requested capacity.
    #[test]
    fn test_vector_storage() {
        let storage: NumaVectorStorage<f32> =
            NumaVectorStorage::with_capacity_on_node(100, 0).unwrap();

        assert_eq!(storage.len(), 0);
        assert_eq!(storage.capacity(), 100);
    }

    /// Pushed values can be read back; indices past the length yield `None`.
    #[test]
    fn test_vector_storage_push_get() {
        let mut storage: NumaVectorStorage<u64> =
            NumaVectorStorage::with_capacity_on_node(10, 0).unwrap();

        storage.push(42).unwrap();
        storage.push(123).unwrap();

        assert_eq!(storage.len(), 2);
        assert_eq!(storage.get(0), Some(&42));
        assert_eq!(storage.get(1), Some(&123));
        assert_eq!(storage.get(2), None);
    }

    /// The query methods can be called without panicking.
    #[test]
    fn test_thread_pinner() {
        let pinner = ThreadPinner::new(NumaTopology::detect());

        let _ = pinner.current_cpu();
        let _ = pinner.current_node();
    }

    /// Pinning to a single CPU works on Linux and is unavailable elsewhere.
    #[test]
    fn test_thread_pinner_pin_to_cpu() {
        let pinner = ThreadPinner::new(NumaTopology::detect());

        #[cfg(target_os = "linux")]
        {
            // Pin to the CPU this thread is already running on: it is guaranteed to
            // be permitted, whereas the first CPU of node 0 may be excluded by a
            // cpuset (containers, restricted CI runners).
            let cpu = pinner
                .current_cpu()
                .expect("sched_getcpu should work on Linux");
            let result = pinner.pin_to_cpu(cpu);
            assert!(result.is_ok(), "Pin should succeed on Linux");
        }

        #[cfg(not(target_os = "linux"))]
        {
            let result = pinner.pin_to_cpu(0);
            assert!(matches!(result, Err(NumaError::NotAvailable)));
        }
    }

    /// Pinning to node 0 either succeeds or fails in one of the documented ways.
    #[test]
    fn test_thread_pinner_pin_to_node() {
        let topo = NumaTopology::detect();
        let pinner = ThreadPinner::new(topo.clone());

        let result = pinner.pin_to_node(0);

        #[cfg(target_os = "linux")]
        {
            if topo.local_cpus(0).is_empty() {
                assert!(matches!(result, Err(NumaError::InvalidNode(0))));
            } else {
                // The node's first CPU may be outside the cpuset this process is
                // allowed to run on, in which case the kernel refuses the request.
                assert!(
                    matches!(result, Ok(()) | Err(NumaError::PinningFailed)),
                    "unexpected result: {:?}",
                    result
                );
            }
        }

        #[cfg(not(target_os = "linux"))]
        assert!(matches!(result, Err(NumaError::NotAvailable)));
    }

    /// `current_cpu` reports a CPU on Linux and nothing elsewhere.
    #[test]
    fn test_thread_pinner_current_cpu_libc() {
        let pinner = ThreadPinner::new(NumaTopology::detect());
        let cpu = pinner.current_cpu();

        #[cfg(target_os = "linux")]
        assert!(cpu.is_some(), "sched_getcpu should work on Linux");

        #[cfg(not(target_os = "linux"))]
        assert!(cpu.is_none());
    }

    /// The cached page size is sane and matches what `sysconf` reports.
    #[test]
    fn test_page_size_libc() {
        let allocator = NumaAllocator::default();
        let page_size = allocator.page_size;

        assert!(page_size >= 4096);
        assert!(page_size.is_power_of_two());

        #[cfg(unix)]
        {
            let libc_page_size = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
            if libc_page_size > 0 {
                assert_eq!(page_size, libc_page_size as usize);
            }
        }
    }
}