strange_loop/nano_agent/
optimization.rs

//! SIMD optimizations and cache-aligned data structures for nano-agents

#[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
use std::arch::x86_64::*;

/// Cache-line-aligned vector for SIMD operations
///
/// Note: `#[repr(align(64))]` aligns the struct itself, not the heap buffer
/// behind `data`, which only gets the allocator's default alignment. The
/// SIMD paths below therefore use unaligned loads and stores.
#[repr(align(64))] // Align the struct to a cache line (64 bytes)
pub struct AlignedVector {
    data: Vec<f32>,
    capacity: usize,
}

impl AlignedVector {
    /// Create a new vector, rounding the capacity up to a multiple of 16
    pub fn new(capacity: usize) -> Self {
        let aligned_capacity = (capacity + 15) & !15; // Round up to multiple of 16
        let data = vec![0.0; aligned_capacity];

        Self {
            data,
            capacity: aligned_capacity,
        }
    }

    /// Get raw pointer for SIMD operations
    pub fn as_ptr(&self) -> *const f32 {
        self.data.as_ptr()
    }

    /// Get mutable raw pointer for SIMD operations
    pub fn as_mut_ptr(&mut self) -> *mut f32 {
        self.data.as_mut_ptr()
    }

    /// Get length of vector
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Check if vector is empty
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// SIMD-accelerated vector addition (x86_64 only)
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2, e.g. via
    /// `is_x86_feature_detected!("avx2")`.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn simd_add(&mut self, other: &AlignedVector) -> Result<(), &'static str> {
        if self.len() != other.len() {
            return Err("Vector lengths must match");
        }

        let len = self.len();
        let chunks = len / 8; // Process 8 f32s at a time with AVX2

        let self_ptr = self.as_mut_ptr();
        let other_ptr = other.as_ptr();

        // Process chunks of 8 elements
        for i in 0..chunks {
            let offset = i * 8;

            // Load 8 f32 values from each vector (unaligned loads: the Vec
            // buffer is not guaranteed to be 32-byte aligned)
            let a = _mm256_loadu_ps(self_ptr.add(offset));
            let b = _mm256_loadu_ps(other_ptr.add(offset));

            // Perform SIMD addition
            let result = _mm256_add_ps(a, b);

            // Store result back
            _mm256_storeu_ps(self_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            *self_ptr.add(i) += *other_ptr.add(i);
        }

        Ok(())
    }

    /// Fallback vector addition for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn simd_add(&mut self, other: &AlignedVector) -> Result<(), &'static str> {
        if self.len() != other.len() {
            return Err("Vector lengths must match");
        }

        for (a, b) in self.data.iter_mut().zip(&other.data) {
            *a += *b;
        }

        Ok(())
    }

    /// SIMD-accelerated dot product (x86_64 only)
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn simd_dot(&self, other: &AlignedVector) -> Result<f32, &'static str> {
        if self.len() != other.len() {
            return Err("Vector lengths must match");
        }

        let len = self.len();
        let chunks = len / 8;

        let self_ptr = self.as_ptr();
        let other_ptr = other.as_ptr();

        // Accumulator holding 8 partial sums
        let mut sum_vec = _mm256_setzero_ps();

        // Process chunks of 8 elements
        for i in 0..chunks {
            let offset = i * 8;

            let a = _mm256_loadu_ps(self_ptr.add(offset));
            let b = _mm256_loadu_ps(other_ptr.add(offset));

            // Multiply and accumulate
            let product = _mm256_mul_ps(a, b);
            sum_vec = _mm256_add_ps(sum_vec, product);
        }

        // Horizontal sum of the 8 partial sums (unaligned store: stack arrays
        // of f32 are only guaranteed 4-byte alignment)
        let mut result_array = [0.0f32; 8];
        _mm256_storeu_ps(result_array.as_mut_ptr(), sum_vec);
        let mut dot_product: f32 = result_array.iter().sum();

        // Handle remaining elements
        for i in (chunks * 8)..len {
            dot_product += *self_ptr.add(i) * *other_ptr.add(i);
        }

        Ok(dot_product)
    }

    /// Fallback dot product for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn simd_dot(&self, other: &AlignedVector) -> Result<f32, &'static str> {
        if self.len() != other.len() {
            return Err("Vector lengths must match");
        }

        let dot_product: f32 = self.data.iter()
            .zip(&other.data)
            .map(|(a, b)| a * b)
            .sum();

        Ok(dot_product)
    }

    /// SIMD-accelerated vector scaling (x86_64 only)
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn simd_scale(&mut self, scalar: f32) {
        let len = self.len();
        let chunks = len / 8;

        let self_ptr = self.as_mut_ptr();
        let scalar_vec = _mm256_set1_ps(scalar); // Broadcast scalar to all lanes

        // Process chunks of 8 elements
        for i in 0..chunks {
            let offset = i * 8;

            let a = _mm256_loadu_ps(self_ptr.add(offset));
            let result = _mm256_mul_ps(a, scalar_vec);
            _mm256_storeu_ps(self_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            *self_ptr.add(i) *= scalar;
        }
    }

    /// Fallback vector scaling for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn simd_scale(&mut self, scalar: f32) {
        for value in &mut self.data {
            *value *= scalar;
        }
    }
}
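
// Illustrative caller-side sketch (hypothetical helper, not part of the
// original API): the AVX2 methods above are `unsafe` because of
// `#[target_feature]`, so callers should verify CPU support at runtime and
// keep a scalar path for x86_64 CPUs without AVX2.
#[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
pub fn add_with_runtime_dispatch(
    a: &mut AlignedVector,
    b: &AlignedVector,
) -> Result<(), &'static str> {
    if a.len() != b.len() {
        return Err("Vector lengths must match");
    }
    if is_x86_feature_detected!("avx2") {
        // Safety: AVX2 support was just verified at runtime.
        unsafe { a.simd_add(b) }
    } else {
        // Scalar path mirroring the non-x86_64 fallback
        for (x, y) in a.data.iter_mut().zip(&b.data) {
            *x += *y;
        }
        Ok(())
    }
}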

/// Cache-optimized agent state structure
///
/// `repr(C)` fixes the field order so the hot/warm/cold grouping below
/// actually maps onto consecutive cache lines.
#[repr(C, align(64))]
pub struct AgentState {
    // Hot data (frequently accessed) - first cache line
    pub position: [f32; 3],
    pub velocity: [f32; 3],
    pub acceleration: [f32; 3],
    pub energy: f32,
    pub active: bool,
    _padding1: [u8; 23], // Pad the 41 hot bytes out to the cache line boundary

    // Warm data - next cache lines (`parameters` is itself 64-byte aligned)
    pub parameters: AlignedVector,
    pub last_update_ns: u128,
    pub performance_score: f32,
    _padding2: [u8; 44],

    // Cold data - final cache lines
    pub debug_info: String,
    pub creation_time: std::time::Instant,
}

impl AgentState {
    pub fn new(param_count: usize) -> Self {
        Self {
            position: [0.0; 3],
            velocity: [0.0; 3],
            acceleration: [0.0; 3],
            energy: 1.0,
            active: true,
            _padding1: [0; 23],
            parameters: AlignedVector::new(param_count),
            last_update_ns: 0,
            performance_score: 0.0,
            _padding2: [0; 44],
            debug_info: String::new(),
            creation_time: std::time::Instant::now(),
        }
    }

    /// SIMD-optimized state update (x86_64 only)
    ///
    /// Uses baseline SSE, which every x86_64 CPU supports, so this stays a
    /// safe function.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    pub fn simd_update(&mut self, dt: f32) {
        unsafe {
            // Pad the 3D vectors to 4 lanes for SSE
            let mut pos_padded = [0.0f32; 4];
            let mut vel_padded = [0.0f32; 4];

            pos_padded[..3].copy_from_slice(&self.position);
            vel_padded[..3].copy_from_slice(&self.velocity);

            // Unaligned loads: stack arrays of f32 are only 4-byte aligned
            let pos_vec = _mm_loadu_ps(pos_padded.as_ptr());
            let vel_vec = _mm_loadu_ps(vel_padded.as_ptr());
            let dt_vec = _mm_set1_ps(dt);

            // position += velocity * dt
            let vel_scaled = _mm_mul_ps(vel_vec, dt_vec);
            let new_pos = _mm_add_ps(pos_vec, vel_scaled);

            // Store result back (only the first 3 lanes are meaningful)
            _mm_storeu_ps(pos_padded.as_mut_ptr(), new_pos);
            self.position.copy_from_slice(&pos_padded[..3]);
        }
    }

    /// Fallback state update for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn simd_update(&mut self, dt: f32) {
        // Simple scalar version
        for i in 0..3 {
            self.position[i] += self.velocity[i] * dt;
        }
    }
}
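
// Hypothetical integration sketch (not in the original module): simd_update
// only advances position, so a caller presumably folds acceleration into
// velocity itself, e.g. with semi-implicit Euler:
pub fn integrate_agent(state: &mut AgentState, dt: f32) {
    // velocity += acceleration * dt (scalar: only 3 lanes, SIMD gains little)
    for i in 0..3 {
        state.velocity[i] += state.acceleration[i] * dt;
    }
    // position += velocity * dt, via whichever simd_update was compiled in
    state.simd_update(dt);
}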

/// SIMD-optimized batch operations for multiple agents, stored in
/// structure-of-arrays layout (agent `i` owns indices `3*i..3*i + 3`)
pub struct BatchProcessor {
    positions: AlignedVector,
    velocities: AlignedVector,
    accelerations: AlignedVector,
    agent_count: usize,
}

impl BatchProcessor {
    pub fn new(max_agents: usize) -> Self {
        Self {
            positions: AlignedVector::new(max_agents * 3),
            velocities: AlignedVector::new(max_agents * 3),
            accelerations: AlignedVector::new(max_agents * 3),
            agent_count: 0,
        }
    }

    /// Batch update all agent positions using SIMD (x86_64 only)
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn batch_update_positions(&mut self, dt: f32) {
        // positions += velocities * dt + 0.5 * accelerations * dt^2

        let len = self.agent_count * 3;
        let chunks = len / 8;

        let pos_ptr = self.positions.as_mut_ptr();
        let vel_ptr = self.velocities.as_ptr();
        let acc_ptr = self.accelerations.as_ptr();

        let dt_vec = _mm256_set1_ps(dt);
        let dt2_vec = _mm256_set1_ps(dt * dt * 0.5);

        for i in 0..chunks {
            let offset = i * 8;

            let pos = _mm256_loadu_ps(pos_ptr.add(offset));
            let vel = _mm256_loadu_ps(vel_ptr.add(offset));
            let acc = _mm256_loadu_ps(acc_ptr.add(offset));

            // velocity * dt
            let vel_term = _mm256_mul_ps(vel, dt_vec);

            // 0.5 * acceleration * dt^2
            let acc_term = _mm256_mul_ps(acc, dt2_vec);

            // position + vel_term + acc_term
            let result = _mm256_add_ps(pos, _mm256_add_ps(vel_term, acc_term));

            _mm256_storeu_ps(pos_ptr.add(offset), result);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            *pos_ptr.add(i) += *vel_ptr.add(i) * dt + 0.5 * *acc_ptr.add(i) * dt * dt;
        }
    }

    /// Fallback batch update for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn batch_update_positions(&mut self, dt: f32) {
        let len = self.agent_count * 3;
        for i in 0..len {
            self.positions.data[i] +=
                self.velocities.data[i] * dt + 0.5 * self.accelerations.data[i] * dt * dt;
        }
    }

    /// Calculate forces between agents using SIMD (x86_64 only)
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.
    #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn calculate_forces(&mut self) -> AlignedVector {
        // Placeholder: a real pairwise force computation would be O(N^2)
        AlignedVector::new(self.agent_count * 3)
    }

    /// Fallback force calculation for WASM and other targets
    #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
    pub fn calculate_forces(&mut self) -> AlignedVector {
        // Placeholder: a real pairwise force computation would be O(N^2)
        AlignedVector::new(self.agent_count * 3)
    }
}
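
// Illustrative runtime-dispatch sketch (hypothetical helper, not part of the
// original API): pick the AVX2 batch kernel when the CPU supports it. Note
// that the scalar `batch_update_positions` is only compiled on non-x86_64
// targets, so an x86_64 CPU without AVX2 needs its own scalar path here.
#[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
pub fn step_batch(processor: &mut BatchProcessor, dt: f32) {
    if is_x86_feature_detected!("avx2") {
        // Safety: AVX2 support was just verified at runtime.
        unsafe { processor.batch_update_positions(dt) };
    } else {
        // Scalar fallback mirroring the non-x86_64 implementation
        let len = processor.agent_count * 3;
        for i in 0..len {
            processor.positions.data[i] += processor.velocities.data[i] * dt
                + 0.5 * processor.accelerations.data[i] * dt * dt;
        }
    }
}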

/// Memory pool for zero-allocation agent operations
pub struct AgentMemoryPool {
    states: Vec<AgentState>,
    free_indices: Vec<usize>,
    capacity: usize,
}

impl AgentMemoryPool {
    pub fn new(capacity: usize) -> Self {
        let mut states = Vec::with_capacity(capacity);
        let mut free_indices = Vec::with_capacity(capacity);

        for i in 0..capacity {
            states.push(AgentState::new(16)); // 16 parameters per agent
            free_indices.push(i);
        }

        Self {
            states,
            free_indices,
            capacity,
        }
    }

    pub fn allocate_agent(&mut self) -> Option<usize> {
        self.free_indices.pop()
    }

    pub fn deallocate_agent(&mut self, index: usize) {
        // Guard against out-of-range indices and double deallocation
        if index < self.capacity && !self.free_indices.contains(&index) {
            self.free_indices.push(index);
        }
    }

    pub fn get_state(&self, index: usize) -> Option<&AgentState> {
        self.states.get(index)
    }

    pub fn get_state_mut(&mut self, index: usize) -> Option<&mut AgentState> {
        self.states.get_mut(index)
    }
}
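
// Illustrative pool round-trip (hypothetical driver function): acquire a
// slot, mutate the pooled state in place, then release the slot. No heap
// allocation happens after the pool is constructed.
pub fn with_pooled_agent(pool: &mut AgentMemoryPool) -> Option<f32> {
    let index = pool.allocate_agent()?;
    let mut energy = None;
    if let Some(state) = pool.get_state_mut(index) {
        state.energy *= 0.5; // example in-place mutation
        energy = Some(state.energy);
    }
    pool.deallocate_agent(index); // slot becomes reusable
    energy
}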

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aligned_vector_creation() {
        let vec = AlignedVector::new(100);
        assert_eq!(vec.len(), 112); // Rounded up to a multiple of 16
        // The struct type (not the heap buffer) carries the 64-byte alignment
        assert_eq!(std::mem::align_of::<AlignedVector>(), 64);
    }

    #[test]
    fn test_simd_operations() {
        let mut a = AlignedVector::new(16);
        let mut b = AlignedVector::new(16);

        // Initialize test data
        for i in 0..16 {
            a.data[i] = i as f32;
            b.data[i] = (i * 2) as f32;
        }

        #[cfg(all(target_arch = "x86_64", not(target_family = "wasm")))]
        {
            // Only exercise the AVX2 paths on CPUs that actually support them
            if is_x86_feature_detected!("avx2") {
                unsafe {
                    a.simd_add(&b).unwrap();
                    let dot = a.simd_dot(&b).unwrap();
                    assert!(dot > 0.0);
                    a.simd_scale(2.0);
                }
            }
        }

        #[cfg(any(not(target_arch = "x86_64"), target_family = "wasm"))]
        {
            a.simd_add(&b).unwrap();
            let dot = a.simd_dot(&b).unwrap();
            assert!(dot > 0.0);
            a.simd_scale(2.0);
        }
    }

    #[test]
    fn test_agent_state_alignment() {
        let state = AgentState::new(16);
        let ptr = &state as *const AgentState as usize;
        assert_eq!(ptr % 64, 0); // Cache-aligned
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = AgentMemoryPool::new(10);

        let agent1 = pool.allocate_agent().unwrap();
        let agent2 = pool.allocate_agent().unwrap();

        assert_ne!(agent1, agent2);

        pool.deallocate_agent(agent1);
        let agent3 = pool.allocate_agent().unwrap();
        assert_eq!(agent1, agent3); // Reused index
    }
}
461}