//! NUMA (Non-Uniform Memory Access) optimization for trustformers_core
//! (`numa_optimization.rs`): memory allocation and thread-placement helpers
//! for multi-socket systems.
1#![allow(unused_variables)] // NUMA optimization with platform-specific code
2
3use crate::errors::{Result, TrustformersError};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::{Arc, Mutex, RwLock};
7use std::thread;
8
/// Description of a single NUMA (Non-Uniform Memory Access) node:
/// the CPU cores it owns, its memory capacity, and its latency to peers.
///
/// Part of the NUMA optimization system that provides memory allocation
/// and thread scheduling for multi-socket systems.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct NumaNode {
    /// Identifier of this node within the topology.
    pub node_id: u32,
    /// Logical CPU core ids owned by this node.
    pub cpu_cores: Vec<u32>,
    /// Total memory on this node, in GB.
    pub memory_size_gb: f64,
    /// Memory currently available for allocation, in GB.
    pub available_memory_gb: f64,
    /// Peak memory bandwidth of this node, in GB/s.
    pub memory_bandwidth_gbps: f64,
    /// Access latency from this node to each node (including self), in ns.
    pub interconnect_latency_ns: HashMap<u32, u32>, // node_id -> latency
    /// Whether this node can currently serve allocations.
    pub is_available: bool,
}
23
/// Full NUMA topology of the machine: all nodes plus pairwise distances.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NumaTopology {
    /// All detected nodes, keyed by node id.
    pub nodes: HashMap<u32, NumaNode>,
    /// Number of entries in `nodes`.
    pub total_nodes: u32,
    /// Total logical CPU cores across all nodes.
    pub total_cores: u32,
    /// Total system memory, in GB.
    pub total_memory_gb: f64,
    /// Pairwise access distance, keyed by `(from, to)` node ids.
    pub node_distances: HashMap<(u32, u32), u32>, // (from, to) -> distance
}
33
/// Strategy used to choose the NUMA node for an allocation
/// (see `NumaAllocator::select_optimal_node`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum NumaStrategy {
    /// Allocate memory on the same node as the current thread
    LocalNode,
    /// Spread allocations across all available nodes
    Interleaved,
    /// Prefer specific nodes in order
    PreferredNodes(Vec<u32>),
    /// Custom strategy based on workload characteristics
    WorkloadAware,
    /// Bind to specific nodes (the first node in the list is used)
    Bind(Vec<u32>),
}
48
/// NUMA memory allocation policy: a selection strategy plus behavioral flags.
#[derive(Debug, Clone)]
pub struct NumaPolicy {
    /// Primary node-selection strategy.
    pub strategy: NumaStrategy,
    pub strict: bool, // Fail if preferred nodes are not available
    /// Strategy to fall back to when the primary cannot be satisfied.
    /// NOTE(review): not yet consumed anywhere visible in this file.
    pub fallback_strategy: Option<NumaStrategy>,
    /// Large/huge page support flag. NOTE(review): currently unused here.
    pub large_page_support: bool,
    /// Memory prefetch hint flag. NOTE(review): currently unused here.
    pub memory_prefetch: bool,
}
58
59impl Default for NumaPolicy {
60    fn default() -> Self {
61        Self {
62            strategy: NumaStrategy::LocalNode,
63            strict: false,
64            fallback_strategy: Some(NumaStrategy::Interleaved),
65            large_page_support: true,
66            memory_prefetch: false,
67        }
68    }
69}
70
/// Record of a single NUMA-aware allocation, kept for tracking/statistics.
#[derive(Debug, Clone)]
pub struct NumaAllocation {
    /// Unique id of the form `numa_alloc_<n>`.
    pub allocation_id: String,
    /// Node the memory was allocated on.
    pub node_id: u32,
    /// Size of the allocation, in bytes.
    pub size_bytes: usize,
    /// Address of the allocation (simulated in this implementation).
    pub address: usize,
    /// Wall-clock time the allocation was made.
    pub allocation_time: std::time::SystemTime,
    /// Expected access pattern supplied by the caller.
    pub access_pattern: AccessPattern,
}

/// Expected memory access pattern; used when scoring candidate nodes.
#[derive(Debug, Clone, PartialEq)]
pub enum AccessPattern {
    Sequential,
    Random,
    /// Fixed-stride access; payload is the stride
    /// (unit not used in this file — confirm with callers).
    Strided(usize),
    /// Mixed hot/cold data; `hot_ratio` is presumably the hot fraction
    /// in 0.0..=1.0 — TODO confirm.
    HotCold { hot_ratio: f64 },
    ReadOnly,
    WriteOnly,
    ReadWrite,
    Interleaved,
}
93
/// Desired placement of a thread: nodes, cores, and a priority hint.
#[derive(Debug, Clone)]
pub struct ThreadAffinity {
    /// Thread to bind.
    pub thread_id: thread::ThreadId,
    /// NUMA nodes the thread should run on, in preference order.
    pub preferred_nodes: Vec<u32>,
    /// Specific CPU cores to pin to (unused by the mock implementation).
    pub cpu_cores: Vec<u32>,
    /// Scheduling priority hint.
    pub priority: ThreadPriority,
}

/// Scheduling priority hint for an affinity request.
#[derive(Debug, Clone, PartialEq)]
pub enum ThreadPriority {
    Low,
    Normal,
    High,
    RealTime,
}
110
/// NUMA-aware memory allocator.
///
/// All state sits behind `Arc`ed locks so the allocator can be shared
/// across threads (a global instance is provided at file scope).
pub struct NumaAllocator {
    /// Detected (currently simulated) system topology.
    topology: Arc<RwLock<NumaTopology>>,
    /// Live allocations, keyed by allocation id.
    allocations: Arc<Mutex<HashMap<String, NumaAllocation>>>,
    /// Named allocation policies registered by callers.
    policies: Arc<RwLock<HashMap<String, NumaPolicy>>>,
    /// Monotonic counter used to mint allocation ids.
    allocation_counter: Arc<Mutex<u64>>,
    /// Per-node statistics and traffic counters.
    performance_monitor: Arc<Mutex<NumaPerformanceMonitor>>,
}
119
/// Performance monitoring for NUMA operations.
#[derive(Debug, Clone, Default)]
pub struct NumaPerformanceMonitor {
    /// Per-node allocation statistics.
    pub allocation_stats: HashMap<u32, AllocationStats>,
    /// Current bandwidth usage per node, in GB/s.
    pub memory_bandwidth_usage: HashMap<u32, f64>,
    /// Traffic volume between node pairs, keyed by `(from, to)`.
    pub cross_node_traffic: HashMap<(u32, u32), u64>,
    /// Cache miss rate per node (presumably a 0.0..=1.0 fraction — TODO confirm).
    pub cache_miss_rates: HashMap<u32, f64>,
    /// Observed memory access latencies per node, in ns.
    pub memory_latencies: HashMap<u32, Vec<u64>>,
}

/// Allocation statistics for a single NUMA node.
#[derive(Debug, Default, Clone)]
pub struct AllocationStats {
    pub total_allocations: u64,
    pub total_bytes: u64,
    /// `total_bytes / total_allocations`, maintained on each allocation.
    pub average_allocation_size: f64,
    /// High-water mark of `current_memory_usage`.
    pub peak_memory_usage: u64,
    pub current_memory_usage: u64,
    pub allocation_failures: u64,
}
139
140impl NumaAllocator {
141    pub fn new() -> Result<Self> {
142        let topology = Self::detect_numa_topology()?;
143
144        Ok(Self {
145            topology: Arc::new(RwLock::new(topology)),
146            allocations: Arc::new(Mutex::new(HashMap::new())),
147            policies: Arc::new(RwLock::new(HashMap::new())),
148            allocation_counter: Arc::new(Mutex::new(0)),
149            performance_monitor: Arc::new(Mutex::new(NumaPerformanceMonitor::default())),
150        })
151    }
152
153    /// Detect NUMA topology on the current system
154    fn detect_numa_topology() -> Result<NumaTopology> {
155        // In a real implementation, this would use system calls to detect actual NUMA topology
156        // For now, we'll create a mock topology for demonstration
157
158        let num_nodes = Self::get_numa_node_count()?;
159        let mut nodes = HashMap::new();
160        let mut node_distances = HashMap::new();
161
162        let cores_per_node = num_cpus::get() / num_nodes as usize;
163        let memory_per_node = Self::get_total_memory()? / num_nodes as f64;
164
165        for node_id in 0..num_nodes {
166            let cpu_cores: Vec<u32> = ((node_id * cores_per_node as u32)
167                ..((node_id + 1) * cores_per_node as u32))
168                .collect();
169
170            let mut interconnect_latency = HashMap::new();
171            for other_node in 0..num_nodes {
172                let latency = if node_id == other_node {
173                    10 // Local access latency (ns)
174                } else {
175                    50 + (node_id.abs_diff(other_node) * 10) // Remote access latency
176                };
177                interconnect_latency.insert(other_node, latency);
178            }
179
180            let node = NumaNode {
181                node_id,
182                cpu_cores,
183                memory_size_gb: memory_per_node,
184                available_memory_gb: memory_per_node * 0.8, // 80% available
185                memory_bandwidth_gbps: 100.0,               // GB/s
186                interconnect_latency_ns: interconnect_latency,
187                is_available: true,
188            };
189
190            nodes.insert(node_id, node);
191
192            // Calculate node distances
193            for other_node in 0..num_nodes {
194                let distance = if node_id == other_node {
195                    10 // Local distance
196                } else {
197                    20 + (node_id.abs_diff(other_node) * 10) // Remote distance
198                };
199                node_distances.insert((node_id, other_node), distance);
200            }
201        }
202
203        Ok(NumaTopology {
204            nodes,
205            total_nodes: num_nodes,
206            total_cores: num_cpus::get() as u32,
207            total_memory_gb: Self::get_total_memory()?,
208            node_distances,
209        })
210    }
211
212    fn get_numa_node_count() -> Result<u32> {
213        // Try to detect actual NUMA nodes, fallback to 1 if not available
214        #[cfg(target_os = "linux")]
215        {
216            use std::fs;
217            match fs::read_dir("/sys/devices/system/node") {
218                Ok(entries) => {
219                    let count = entries
220                        .filter_map(|entry| entry.ok())
221                        .filter(|entry| entry.file_name().to_string_lossy().starts_with("node"))
222                        .count() as u32;
223                    Ok(if count > 0 { count } else { 1 })
224                },
225                Err(_) => Ok(1),
226            }
227        }
228
229        #[cfg(not(target_os = "linux"))]
230        {
231            // For non-Linux systems, assume single NUMA node for now
232            // In a full implementation, this would use platform-specific APIs
233            Ok(std::cmp::max(1, (num_cpus::get() / 8) as u32))
234        }
235    }
236
237    fn get_total_memory() -> Result<f64> {
238        // Get total system memory in GB
239        #[cfg(target_os = "linux")]
240        {
241            use std::fs;
242            if let Ok(meminfo) = fs::read_to_string("/proc/meminfo") {
243                for line in meminfo.lines() {
244                    if line.starts_with("MemTotal:") {
245                        if let Some(kb_str) = line.split_whitespace().nth(1) {
246                            if let Ok(kb) = kb_str.parse::<u64>() {
247                                return Ok(kb as f64 / 1024.0 / 1024.0); // Convert KB to GB
248                            }
249                        }
250                    }
251                }
252            }
253        }
254
255        // Fallback estimation
256        Ok(8.0) // Default to 8GB
257    }
258
259    /// Allocate memory with NUMA awareness
260    pub fn allocate_numa_aware(
261        &self,
262        size: usize,
263        alignment: usize,
264        policy_name: Option<&str>,
265        access_pattern: AccessPattern,
266    ) -> Result<NumaAllocation> {
267        let policy = if let Some(name) = policy_name {
268            let policies = self.policies.read().expect("lock should not be poisoned");
269            policies.get(name).cloned().unwrap_or_default()
270        } else {
271            NumaPolicy::default()
272        };
273
274        let node_id = self.select_optimal_node(&policy, size, &access_pattern)?;
275
276        // Simulate memory allocation (in a real implementation, this would use NUMA-specific allocation)
277        let address = self.allocate_on_node(node_id, size, alignment)?;
278
279        let allocation_id = self.generate_allocation_id();
280        let allocation = NumaAllocation {
281            allocation_id: allocation_id.clone(),
282            node_id,
283            size_bytes: size,
284            address,
285            allocation_time: std::time::SystemTime::now(),
286            access_pattern,
287        };
288
289        // Track allocation
290        {
291            let mut allocations = self.allocations.lock().expect("lock should not be poisoned");
292            allocations.insert(allocation_id, allocation.clone());
293        }
294
295        // Update performance statistics
296        self.update_allocation_stats(node_id, size);
297
298        Ok(allocation)
299    }
300
301    /// Select optimal NUMA node based on policy and workload characteristics
302    fn select_optimal_node(
303        &self,
304        policy: &NumaPolicy,
305        size: usize,
306        access_pattern: &AccessPattern,
307    ) -> Result<u32> {
308        let topology = self.topology.read().expect("lock should not be poisoned");
309
310        match &policy.strategy {
311            NumaStrategy::LocalNode => self.get_current_node(),
312            NumaStrategy::Interleaved => self.select_least_loaded_node(&topology),
313            NumaStrategy::PreferredNodes(nodes) => {
314                self.select_from_preferred_nodes(&topology, nodes, policy.strict)
315            },
316            NumaStrategy::WorkloadAware => {
317                self.select_workload_aware_node(&topology, size, access_pattern)
318            },
319            NumaStrategy::Bind(nodes) => {
320                if nodes.is_empty() {
321                    Err(TrustformersError::other(
322                        "No nodes specified for bind strategy".to_string(),
323                    ))
324                } else {
325                    Ok(nodes[0]) // Use first node in bind list
326                }
327            },
328        }
329    }
330
331    fn get_current_node(&self) -> Result<u32> {
332        // In a real implementation, this would detect which NUMA node the current thread is running on
333        // For now, we'll use a simple heuristic based on thread ID
334        let thread_id = thread::current().id();
335        let topology = self.topology.read().expect("lock should not be poisoned");
336        let node_count = topology.total_nodes;
337
338        // Simple hash-based selection
339        let hash = format!("{:?}", thread_id).len();
340        Ok((hash as u32) % node_count)
341    }
342
343    fn select_least_loaded_node(&self, topology: &NumaTopology) -> Result<u32> {
344        let monitor = self.performance_monitor.lock().expect("lock should not be poisoned");
345
346        let least_loaded = topology
347            .nodes
348            .keys()
349            .min_by_key(|&&node_id| {
350                monitor
351                    .allocation_stats
352                    .get(&node_id)
353                    .map(|stats| stats.current_memory_usage)
354                    .unwrap_or(0)
355            })
356            .copied();
357
358        least_loaded.ok_or_else(|| TrustformersError::other("No available NUMA nodes".to_string()))
359    }
360
361    fn select_from_preferred_nodes(
362        &self,
363        topology: &NumaTopology,
364        preferred_nodes: &[u32],
365        strict: bool,
366    ) -> Result<u32> {
367        for &node_id in preferred_nodes {
368            if topology.nodes.contains_key(&node_id) {
369                let node = &topology.nodes[&node_id];
370                if node.is_available && node.available_memory_gb > 0.1 {
371                    return Ok(node_id);
372                }
373            }
374        }
375
376        if strict {
377            Err(TrustformersError::other(
378                "No preferred NUMA nodes available".to_string(),
379            ))
380        } else {
381            self.select_least_loaded_node(topology)
382        }
383    }
384
    /// Score every available node and return the highest-scoring one.
    ///
    /// Weighted score (weights sum to 1.0): memory availability 0.3,
    /// bandwidth headroom 0.2, access-pattern compatibility 0.3, current
    /// load 0.2. Note: `size` is currently unused in the scoring (covered
    /// by the file-level `allow(unused_variables)`).
    fn select_workload_aware_node(
        &self,
        topology: &NumaTopology,
        size: usize,
        access_pattern: &AccessPattern,
    ) -> Result<u32> {
        let mut scores = HashMap::new();
        let monitor = self.performance_monitor.lock().expect("lock should not be poisoned");

        for (&node_id, node) in &topology.nodes {
            if !node.is_available {
                continue;
            }

            let mut score = 0.0;

            // Memory availability score: fraction of the node still free.
            let memory_score = node.available_memory_gb / node.memory_size_gb;
            score += memory_score * 0.3;

            // Bandwidth utilization score (prefer less utilized nodes)
            let bandwidth_util =
                monitor.memory_bandwidth_usage.get(&node_id).copied().unwrap_or(0.0);
            let bandwidth_score = 1.0 - (bandwidth_util / node.memory_bandwidth_gbps);
            score += bandwidth_score * 0.2;

            // Access pattern compatibility score
            let pattern_score = match access_pattern {
                AccessPattern::Sequential => {
                    // Prefer nodes with lower outgoing cross-node traffic
                    let cross_traffic: u64 = monitor
                        .cross_node_traffic
                        .iter()
                        .filter(|((from, _to), _)| *from == node_id)
                        .map(|(_, traffic)| *traffic)
                        .sum();
                    1.0 / (1.0 + cross_traffic as f64 / 1000000.0) // Normalize
                },
                AccessPattern::Random => {
                    // Prefer nodes with better cache performance
                    let cache_miss_rate =
                        monitor.cache_miss_rates.get(&node_id).copied().unwrap_or(0.1);
                    1.0 - cache_miss_rate
                },
                _ => 0.5, // Neutral score for other patterns
            };
            score += pattern_score * 0.3;

            // Current load score: 1.0 means empty, 0.0 means full.
            let current_load = monitor
                .allocation_stats
                .get(&node_id)
                .map(|stats| {
                    stats.current_memory_usage as f64
                        / (node.memory_size_gb * 1024.0 * 1024.0 * 1024.0)
                })
                .unwrap_or(0.0);
            let load_score = 1.0 - current_load;
            score += load_score * 0.2;

            scores.insert(node_id, score);
        }

        // Highest score wins; NaN comparisons fall back to Equal.
        scores
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(node_id, _)| node_id)
            .ok_or_else(|| TrustformersError::other("No suitable NUMA node found".to_string()))
    }
454
455    fn allocate_on_node(&self, node_id: u32, size: usize, _alignment: usize) -> Result<usize> {
456        // In a real implementation, this would use NUMA-specific allocation APIs
457        // For now, we'll simulate allocation
458
459        let topology = self.topology.read().expect("lock should not be poisoned");
460        if !topology.nodes.contains_key(&node_id) {
461            return Err(TrustformersError::other(format!(
462                "Invalid NUMA node: {}",
463                node_id
464            )));
465        }
466
467        // Simulate memory allocation by returning a mock address
468        // In reality, this would call numa_alloc_onnode() or similar
469        let mock_address = 0x1000000 + (node_id as usize * 0x10000000) + size;
470        Ok(mock_address)
471    }
472
473    fn generate_allocation_id(&self) -> String {
474        let mut counter = self.allocation_counter.lock().expect("lock should not be poisoned");
475        *counter += 1;
476        format!("numa_alloc_{}", *counter)
477    }
478
479    fn update_allocation_stats(&self, node_id: u32, size: usize) {
480        let mut monitor = self.performance_monitor.lock().expect("lock should not be poisoned");
481        let stats = monitor.allocation_stats.entry(node_id).or_default();
482
483        stats.total_allocations += 1;
484        stats.total_bytes += size as u64;
485        stats.current_memory_usage += size as u64;
486        stats.average_allocation_size = stats.total_bytes as f64 / stats.total_allocations as f64;
487
488        if stats.current_memory_usage > stats.peak_memory_usage {
489            stats.peak_memory_usage = stats.current_memory_usage;
490        }
491    }
492
493    /// Set thread affinity to specific NUMA nodes
494    pub fn set_thread_affinity(&self, affinity: ThreadAffinity) -> Result<()> {
495        // In a real implementation, this would set CPU affinity using platform-specific APIs
496        // For Linux: sched_setaffinity()
497        // For Windows: SetThreadAffinityMask()
498
499        tracing::info!(
500            "Setting thread affinity for {:?} to nodes {:?}",
501            affinity.thread_id,
502            affinity.preferred_nodes
503        );
504
505        // Mock implementation - in reality would call OS-specific APIs
506        self.bind_thread_to_nodes(&affinity.preferred_nodes)?;
507
508        Ok(())
509    }
510
    /// Bind the current thread to the given nodes.
    ///
    /// Mock: only logs on Linux/Windows and is a no-op elsewhere
    /// (`node_ids` unused there, covered by the file-level allow).
    /// A real version would use libnuma/sched syscalls or
    /// `SetThreadAffinityMask`.
    fn bind_thread_to_nodes(&self, node_ids: &[u32]) -> Result<()> {
        #[cfg(target_os = "linux")]
        {
            // On Linux, we would use libnuma or direct syscalls
            // This is a simplified mock implementation
            tracing::debug!("Binding thread to NUMA nodes: {:?}", node_ids);
        }

        #[cfg(target_os = "windows")]
        {
            // On Windows, we would use SetThreadAffinityMask
            tracing::debug!("Binding thread to NUMA nodes: {:?}", node_ids);
        }

        Ok(())
    }
527
528    /// Free NUMA-aware allocated memory
529    pub fn deallocate(&self, allocation_id: &str) -> Result<()> {
530        let allocation = {
531            let mut allocations = self.allocations.lock().expect("lock should not be poisoned");
532            allocations.remove(allocation_id).ok_or_else(|| {
533                TrustformersError::other(format!("Allocation not found: {}", allocation_id))
534            })?
535        };
536
537        // Update statistics
538        {
539            let mut monitor = self.performance_monitor.lock().expect("lock should not be poisoned");
540            if let Some(stats) = monitor.allocation_stats.get_mut(&allocation.node_id) {
541                stats.current_memory_usage =
542                    stats.current_memory_usage.saturating_sub(allocation.size_bytes as u64);
543            }
544        }
545
546        // In a real implementation, this would call numa_free() or similar
547        tracing::debug!(
548            "Deallocated {} bytes from NUMA node {} (allocation: {})",
549            allocation.size_bytes,
550            allocation.node_id,
551            allocation_id
552        );
553
554        Ok(())
555    }
556
557    /// Register a custom NUMA policy
558    pub fn register_policy(&self, name: String, policy: NumaPolicy) {
559        let mut policies = self.policies.write().expect("lock should not be poisoned");
560        policies.insert(name, policy);
561    }
562
563    /// Get NUMA topology information
564    pub fn get_topology(&self) -> NumaTopology {
565        let topology = self.topology.read().expect("lock should not be poisoned");
566        (*topology).clone()
567    }
568
569    /// Get performance statistics
570    pub fn get_performance_stats(&self) -> NumaPerformanceMonitor {
571        let monitor = self.performance_monitor.lock().expect("lock should not be poisoned");
572        (*monitor).clone()
573    }
574
    /// Optimize memory layout for a specific access pattern.
    ///
    /// Returns one id per *found* input allocation: either the original id
    /// (already well placed) or a synthetic `*_migrated*` id suggesting a
    /// move. Ids not present in the tracker are silently dropped, and for
    /// `Sequential` an unknown first id yields an empty result.
    /// NOTE(review): no actual data migration is performed — this only
    /// produces suggested ids.
    pub fn optimize_memory_layout(
        &self,
        allocations: &[String],
        access_pattern: AccessPattern,
    ) -> Result<Vec<String>> {
        let mut optimized_allocations = Vec::new();
        let allocations_map = self.allocations.lock().expect("Lock poisoned");

        match access_pattern {
            AccessPattern::Sequential => {
                // For sequential access, try to place allocations on the same
                // node; the first allocation's node serves as the anchor.
                if let Some(first_alloc) =
                    allocations.first().and_then(|id| allocations_map.get(id))
                {
                    let preferred_node = first_alloc.node_id;

                    for alloc_id in allocations {
                        if let Some(allocation) = allocations_map.get(alloc_id) {
                            if allocation.node_id != preferred_node {
                                // Suggest migration
                                let new_id = format!("{}_migrated", alloc_id);
                                optimized_allocations.push(new_id);
                            } else {
                                optimized_allocations.push(alloc_id.clone());
                            }
                        }
                    }
                }
            },
            AccessPattern::Interleaved => {
                // For interleaved access, spread allocations across nodes
                // round-robin over the available node ids.
                let topology = self.topology.read().expect("lock should not be poisoned");
                let available_nodes: Vec<u32> = topology.nodes.keys().copied().collect();

                for (node_index, alloc_id) in allocations.iter().enumerate() {
                    let target_node = available_nodes[node_index % available_nodes.len()];

                    if let Some(allocation) = allocations_map.get(alloc_id) {
                        if allocation.node_id != target_node {
                            let new_id = format!("{}_migrated_to_node_{}", alloc_id, target_node);
                            optimized_allocations.push(new_id);
                        } else {
                            optimized_allocations.push(alloc_id.clone());
                        }
                    }
                }
            },
            _ => {
                // For other patterns, keep current layout
                optimized_allocations.extend_from_slice(allocations);
            },
        }

        Ok(optimized_allocations)
    }
631
632    /// Monitor cross-NUMA traffic and suggest optimizations
633    pub fn analyze_numa_traffic(&self) -> NumaTrafficAnalysis {
634        let monitor = self.performance_monitor.lock().expect("lock should not be poisoned");
635        let topology = self.topology.read().expect("lock should not be poisoned");
636
637        let mut analysis = NumaTrafficAnalysis {
638            total_cross_node_traffic: 0,
639            hotspots: Vec::new(),
640            optimization_suggestions: Vec::new(),
641        };
642
643        // Calculate total cross-node traffic
644        for ((from, to), traffic) in &monitor.cross_node_traffic {
645            if from != to {
646                analysis.total_cross_node_traffic += traffic;
647            }
648        }
649
650        // Identify traffic hotspots
651        let mut traffic_by_node: HashMap<u32, u64> = HashMap::new();
652        for ((from, _to), traffic) in &monitor.cross_node_traffic {
653            *traffic_by_node.entry(*from).or_insert(0) += traffic;
654        }
655
656        let mut sorted_traffic: Vec<_> = traffic_by_node.into_iter().collect();
657        sorted_traffic.sort_by_key(|item| std::cmp::Reverse(item.1));
658
659        for (node_id, traffic) in sorted_traffic.into_iter().take(3) {
660            analysis.hotspots.push(TrafficHotspot {
661                node_id,
662                traffic_volume: traffic,
663                severity: if traffic > 1000000 {
664                    HotspotSeverity::High
665                } else if traffic > 100000 {
666                    HotspotSeverity::Medium
667                } else {
668                    HotspotSeverity::Low
669                },
670            });
671        }
672
673        // Generate optimization suggestions
674        if analysis.total_cross_node_traffic > 10000000 {
675            analysis.optimization_suggestions.push(
676                "Consider using NUMA-local allocations to reduce cross-node traffic".to_string(),
677            );
678        }
679
680        for hotspot in &analysis.hotspots {
681            if hotspot.severity == HotspotSeverity::High {
682                analysis.optimization_suggestions.push(format!(
683                    "Node {} is experiencing high traffic - consider redistributing workload",
684                    hotspot.node_id
685                ));
686            }
687        }
688
689        analysis
690    }
691}
692
/// Result of `NumaAllocator::analyze_numa_traffic`.
#[derive(Debug, Clone)]
pub struct NumaTrafficAnalysis {
    /// Sum of recorded traffic between distinct node pairs.
    pub total_cross_node_traffic: u64,
    /// Up to three busiest source nodes, highest traffic first.
    pub hotspots: Vec<TrafficHotspot>,
    /// Human-readable optimization suggestions.
    pub optimization_suggestions: Vec<String>,
}

/// A single high-traffic source node.
#[derive(Debug, Clone)]
pub struct TrafficHotspot {
    pub node_id: u32,
    /// Total outgoing traffic recorded for this node.
    pub traffic_volume: u64,
    pub severity: HotspotSeverity,
}

/// Hotspot grade: Low (<= 100_000), Medium (<= 1_000_000), High (above).
#[derive(Debug, Clone, PartialEq)]
pub enum HotspotSeverity {
    Low,
    Medium,
    High,
}
713
/// Global NUMA allocator instance, set once via `init_numa_allocator`.
static NUMA_ALLOCATOR: std::sync::OnceLock<Arc<NumaAllocator>> = std::sync::OnceLock::new();
716
717/// Initialize global NUMA allocator
718pub fn init_numa_allocator() -> Result<()> {
719    let allocator = Arc::new(NumaAllocator::new()?);
720    NUMA_ALLOCATOR
721        .set(allocator)
722        .map_err(|_| TrustformersError::other("NUMA allocator already initialized".to_string()))?;
723    Ok(())
724}
725
726/// Get global NUMA allocator
727pub fn get_numa_allocator() -> Result<Arc<NumaAllocator>> {
728    NUMA_ALLOCATOR
729        .get()
730        .cloned()
731        .ok_or_else(|| TrustformersError::other("NUMA allocator not initialized".to_string()))
732}
733
734/// Convenience function for NUMA-aware allocation
735pub fn numa_alloc(
736    size: usize,
737    alignment: usize,
738    policy: Option<&str>,
739    pattern: AccessPattern,
740) -> Result<NumaAllocation> {
741    get_numa_allocator()?.allocate_numa_aware(size, alignment, policy, pattern)
742}
743
744/// Convenience function for NUMA deallocation
745pub fn numa_free(allocation_id: &str) -> Result<()> {
746    get_numa_allocator()?.deallocate(allocation_id)
747}
748
#[cfg(test)]
mod tests {
    //! Unit tests covering topology detection, allocation/deallocation,
    //! policies, node selection, statistics, and layout optimization.
    use super::*;

    // Construction succeeds and yields a non-empty topology.
    #[test]
    fn test_numa_allocator_creation() {
        let allocator = NumaAllocator::new().expect("operation failed in test");
        let topology = allocator.get_topology();
        assert!(topology.total_nodes > 0);
        assert!(topology.total_cores > 0);
    }

    // Allocate-then-free round trip with the default policy.
    #[test]
    fn test_numa_allocation() {
        let allocator = NumaAllocator::new().expect("operation failed in test");

        let allocation = allocator
            .allocate_numa_aware(1024, 64, None, AccessPattern::Sequential)
            .expect("operation failed in test");

        assert_eq!(allocation.size_bytes, 1024);
        assert!(!allocation.allocation_id.is_empty());

        allocator
            .deallocate(&allocation.allocation_id)
            .expect("operation failed in test");
    }

    // Struct-update syntax builds a policy with only strategy/strict changed.
    #[test]
    fn test_numa_policy() {
        let policy = NumaPolicy {
            strategy: NumaStrategy::PreferredNodes(vec![0, 1]),
            strict: true,
            ..Default::default()
        };

        assert_eq!(policy.strategy, NumaStrategy::PreferredNodes(vec![0, 1]));
        assert!(policy.strict);
    }

    // Every detected node is self-consistent (id, memory, cores).
    #[test]
    fn test_topology_detection() {
        let topology = NumaAllocator::detect_numa_topology().expect("operation failed in test");
        assert!(topology.total_nodes >= 1);
        assert!(!topology.nodes.is_empty());

        for (node_id, node) in &topology.nodes {
            assert_eq!(*node_id, node.node_id);
            assert!(node.memory_size_gb > 0.0);
            assert!(!node.cpu_cores.is_empty());
        }
    }

    // Workload-aware selection returns a node that exists in the topology.
    #[test]
    fn test_workload_aware_selection() {
        let allocator = NumaAllocator::new().expect("operation failed in test");
        let topology = allocator.get_topology();

        let node_id = allocator
            .select_workload_aware_node(&topology, 1024 * 1024, &AccessPattern::Sequential)
            .expect("operation failed in test");

        assert!(topology.nodes.contains_key(&node_id));
    }

    // Allocation statistics accumulate across allocations.
    #[test]
    fn test_performance_monitoring() {
        let allocator = NumaAllocator::new().expect("operation failed in test");

        // Make some allocations
        let _alloc1 = allocator
            .allocate_numa_aware(1024, 64, None, AccessPattern::Sequential)
            .expect("operation failed in test");

        let _alloc2 = allocator
            .allocate_numa_aware(2048, 64, None, AccessPattern::Random)
            .expect("operation failed in test");

        let stats = allocator.get_performance_stats();
        let total_allocations: u64 =
            stats.allocation_stats.values().map(|s| s.total_allocations).sum();

        assert!(total_allocations >= 2);
    }

    // Layout optimization returns one id per tracked input allocation.
    #[test]
    fn test_memory_layout_optimization() {
        let allocator = NumaAllocator::new().expect("operation failed in test");

        let alloc1 = allocator
            .allocate_numa_aware(1024, 64, None, AccessPattern::Sequential)
            .expect("operation failed in test");

        let alloc2 = allocator
            .allocate_numa_aware(1024, 64, None, AccessPattern::Sequential)
            .expect("operation failed in test");

        let allocation_ids = vec![alloc1.allocation_id.clone(), alloc2.allocation_id.clone()];

        let optimized = allocator
            .optimize_memory_layout(&allocation_ids, AccessPattern::Sequential)
            .expect("operation failed in test");

        assert_eq!(optimized.len(), allocation_ids.len());
    }
}