temporal_neural_solver/optimizations/fully_optimized.rs

//! Fully optimized implementation with real SIMD, INT8 quantization, and CPU pinning
//! No simulations - all real optimizations

#![allow(unsafe_code)]

use std::arch::x86_64::*;
use std::alloc::{alloc, dealloc, Layout};
use std::time::{Duration, Instant};
use core_affinity;

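// Not part of the original file: a small runtime-guard sketch. The AVX2/FMA
// paths below are only safe to call on CPUs that actually expose those
// features, so a caller could check once at startup with the standard
// `is_x86_feature_detected!` macro before constructing the solver.
// `simd_features_available` is a name introduced here for illustration.
#[cfg(target_arch = "x86_64")]
pub fn simd_features_available() -> bool {
    std::is_x86_feature_detected!("avx2") && std::is_x86_feature_detected!("fma")
}
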
/// INT8 quantized weights with scale factors
#[repr(C, align(64))]  // Cache-line aligned
pub struct QuantizedWeights {
    // INT8 weights for layer 1 (32x128)
    w1_int8: *mut i8,
    w1_scale: [f32; 32],  // Per-row scale factors

    // INT8 weights for layer 2 (4x32)
    w2_int8: *mut i8,
    w2_scale: [f32; 4],

    // Biases remain FP32 for accuracy
    b1: [f32; 32],
    b2: [f32; 4],
}

impl QuantizedWeights {
    pub fn new() -> Self {
        unsafe {
            // Allocate 64-byte aligned memory for SIMD
            let w1_layout = Layout::from_size_align(32 * 128, 64).unwrap();
            let w2_layout = Layout::from_size_align(4 * 32, 64).unwrap();

            let w1_ptr = alloc(w1_layout) as *mut i8;
            let w2_ptr = alloc(w2_layout) as *mut i8;

            let mut w1_scale = [0.0f32; 32];
            let mut w2_scale = [0.0f32; 4];

            // Initialize and quantize weights
            for i in 0..32 {
                let mut max_val = 0.0f32;
                let mut row_weights = vec![0.0f32; 128];

                // Generate weights and find max for quantization
                for j in 0..128 {
                    let weight = ((i * j) as f32 * 0.001).sin() * 0.1;
                    row_weights[j] = weight;
                    max_val = max_val.max(weight.abs());
                }

                // Quantize to INT8 (guard against an all-zero row so the
                // division below cannot produce NaN scales)
                w1_scale[i] = if max_val > 0.0 { max_val / 127.0 } else { 1.0 };
                for j in 0..128 {
                    let quantized = (row_weights[j] / w1_scale[i]).round() as i8;
                    *w1_ptr.add(i * 128 + j) = quantized;
                }
            }

            // Quantize layer 2
            for i in 0..4 {
                let mut max_val = 0.0f32;
                let mut row_weights = vec![0.0f32; 32];

                for j in 0..32 {
                    let weight = ((i * j) as f32 * 0.002).cos() * 0.2;
                    row_weights[j] = weight;
                    max_val = max_val.max(weight.abs());
                }

                w2_scale[i] = if max_val > 0.0 { max_val / 127.0 } else { 1.0 };
                for j in 0..32 {
                    let quantized = (row_weights[j] / w2_scale[i]).round() as i8;
                    *w2_ptr.add(i * 32 + j) = quantized;
                }
            }

            Self {
                w1_int8: w1_ptr,
                w1_scale,
                w2_int8: w2_ptr,
                w2_scale,
                b1: [0.0; 32],
                b2: [0.0; 4],
            }
        }
    }

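    /// Illustrative helper (added here as a sketch, not used by `new()`): the
    /// symmetric per-row quantization scheme above in one place. For a row of
    /// FP32 weights it returns the INT8 values and the scale such that
    /// `w ≈ q as f32 * scale`, with `scale = max|w| / 127`.
    pub fn quantize_row(row: &[f32]) -> (Vec<i8>, f32) {
        let max_abs = row.iter().fold(0.0f32, |m, w| m.max(w.abs()));
        // Same guard as in `new()`: an all-zero row gets a scale of 1.0.
        let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
        let quantized = row
            .iter()
            .map(|&w| (w / scale).round().clamp(-127.0, 127.0) as i8)
            .collect();
        (quantized, scale)
    }
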
    /// AVX2 INT8 matrix multiplication with FP32 accumulation.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA before calling.
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn gemm_int8_avx2(
        &self,
        input: &[f32; 128],
        hidden: &mut [f32; 32],
    ) {
        // Process 8 outputs at a time using AVX2
        for row_block in (0..32).step_by(8) {
            // Initialize 8 accumulators
            let mut acc0 = _mm256_setzero_ps();
            let mut acc1 = _mm256_setzero_ps();
            let mut acc2 = _mm256_setzero_ps();
            let mut acc3 = _mm256_setzero_ps();
            let mut acc4 = _mm256_setzero_ps();
            let mut acc5 = _mm256_setzero_ps();
            let mut acc6 = _mm256_setzero_ps();
            let mut acc7 = _mm256_setzero_ps();

            // Process input in chunks of 8
            for col in (0..128).step_by(8) {
                // Load 8 input values
                let input_vec = _mm256_loadu_ps(input.as_ptr().add(col));

                // Load INT8 weights for 8 rows x 8 cols,
                // convert to FP32, and apply the per-row scale
                for r in 0..8.min(32 - row_block) {
                    let row = row_block + r;
                    let weight_ptr = self.w1_int8.add(row * 128 + col);

                    // Load 8 INT8 weights
                    let weights_i8 = _mm_loadl_epi64(weight_ptr as *const __m128i);
                    // Convert INT8 to INT32
                    let weights_i32 = _mm256_cvtepi8_epi32(weights_i8);
                    // Convert INT32 to FP32
                    let weights_f32 = _mm256_cvtepi32_ps(weights_i32);

                    // Scale weights
                    let scale = _mm256_set1_ps(self.w1_scale[row]);
                    let scaled_weights = _mm256_mul_ps(weights_f32, scale);

                    // Multiply and accumulate
                    match r {
                        0 => acc0 = _mm256_fmadd_ps(scaled_weights, input_vec, acc0),
                        1 => acc1 = _mm256_fmadd_ps(scaled_weights, input_vec, acc1),
                        2 => acc2 = _mm256_fmadd_ps(scaled_weights, input_vec, acc2),
                        3 => acc3 = _mm256_fmadd_ps(scaled_weights, input_vec, acc3),
                        4 => acc4 = _mm256_fmadd_ps(scaled_weights, input_vec, acc4),
                        5 => acc5 = _mm256_fmadd_ps(scaled_weights, input_vec, acc5),
                        6 => acc6 = _mm256_fmadd_ps(scaled_weights, input_vec, acc6),
                        7 => acc7 = _mm256_fmadd_ps(scaled_weights, input_vec, acc7),
                        _ => {}
                    }
                }
            }

            // Horizontal sum and store results
            let sum_array = |acc: __m256| -> f32 {
                let sum = _mm256_hadd_ps(acc, acc);
                let sum = _mm256_hadd_ps(sum, sum);
                let high = _mm256_extractf128_ps(sum, 1);
                let low = _mm256_castps256_ps128(sum);
                let final_sum = _mm_add_ps(low, high);
                _mm_cvtss_f32(final_sum)
            };

            for r in 0..8.min(32 - row_block) {
                let row = row_block + r;
                hidden[row] = match r {
                    0 => sum_array(acc0) + self.b1[row],
                    1 => sum_array(acc1) + self.b1[row],
                    2 => sum_array(acc2) + self.b1[row],
                    3 => sum_array(acc3) + self.b1[row],
                    4 => sum_array(acc4) + self.b1[row],
                    5 => sum_array(acc5) + self.b1[row],
                    6 => sum_array(acc6) + self.b1[row],
                    7 => sum_array(acc7) + self.b1[row],
                    _ => 0.0,
                };
            }
        }
    }

    /// AVX-512 implementation for newer CPUs
    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
    #[target_feature(enable = "avx512f")]
    pub unsafe fn gemm_int8_avx512(
        &self,
        input: &[f32; 128],
        hidden: &mut [f32; 32],
    ) {
        // Process 16 elements at once with AVX-512
        for row in 0..32 {
            let mut acc = _mm512_setzero_ps();

            for col in (0..128).step_by(16) {
                // Load 16 input values
                let input_vec = _mm512_loadu_ps(input.as_ptr().add(col));

                // Load and convert INT8 weights to FP32
                let weight_ptr = self.w1_int8.add(row * 128 + col);
                let weights_i8 = _mm_loadu_si128(weight_ptr as *const __m128i);
                let weights_i32 = _mm512_cvtepi8_epi32(weights_i8);
                let weights_f32 = _mm512_cvtepi32_ps(weights_i32);

                // Scale and accumulate
                let scale = _mm512_set1_ps(self.w1_scale[row]);
                let scaled_weights = _mm512_mul_ps(weights_f32, scale);
                acc = _mm512_fmadd_ps(scaled_weights, input_vec, acc);
            }

            // Reduce and store
            hidden[row] = _mm512_reduce_add_ps(acc) + self.b1[row];
        }
    }
}

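// Added sketch (not in the original): a plain scalar reference for layer 1,
// useful for validating the SIMD paths above. It computes the same thing as
// `gemm_int8_avx2` -- dequantize each INT8 weight with its per-row scale,
// multiply by the input, accumulate in FP32, then add the bias -- just
// without intrinsics. `gemm_layer1_scalar` is a name introduced here.
impl QuantizedWeights {
    pub fn gemm_layer1_scalar(&self, input: &[f32; 128], hidden: &mut [f32; 32]) {
        for row in 0..32 {
            let mut acc = 0.0f32;
            for col in 0..128 {
                // SAFETY: w1_int8 points to a 32x128 allocation owned by self.
                let q = unsafe { *self.w1_int8.add(row * 128 + col) };
                acc += q as f32 * self.w1_scale[row] * input[col];
            }
            hidden[row] = acc + self.b1[row];
        }
    }
}
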
impl Drop for QuantizedWeights {
    fn drop(&mut self) {
        unsafe {
            let w1_layout = Layout::from_size_align(32 * 128, 64).unwrap();
            let w2_layout = Layout::from_size_align(4 * 32, 64).unwrap();
            dealloc(self.w1_int8 as *mut u8, w1_layout);
            dealloc(self.w2_int8 as *mut u8, w2_layout);
        }
    }
}

/// Ultra-optimized neural network with INT8 quantization and SIMD
#[repr(C, align(64))]
pub struct OptimizedNeuralNetwork {
    weights: QuantizedWeights,
    // Pre-allocated aligned buffers
    hidden_buffer: [f32; 32],
    output_buffer: [f32; 4],
}

impl OptimizedNeuralNetwork {
    pub fn new() -> Self {
        Self {
            weights: QuantizedWeights::new(),
            hidden_buffer: [0.0; 32],
            output_buffer: [0.0; 4],
        }
    }

    /// Run one forward pass. The caller must ensure the CPU supports AVX2 and
    /// FMA (e.g., via the `simd_features_available` guard sketched above);
    /// the intrinsics below execute unconditionally.
    #[inline(always)]
    pub fn forward(&mut self, input: &[f32; 128]) -> [f32; 4] {
        unsafe {
            // Layer 1: INT8 GEMM with AVX2
            self.weights.gemm_int8_avx2(input, &mut self.hidden_buffer);

            // ReLU activation using AVX2 (branchless)
            for chunk in self.hidden_buffer.chunks_exact_mut(8) {
                let vals = _mm256_loadu_ps(chunk.as_ptr());
                let zero = _mm256_setzero_ps();
                let relu = _mm256_max_ps(vals, zero);
                _mm256_storeu_ps(chunk.as_mut_ptr(), relu);
            }

            // Layer 2: Small matrix, use AVX2 for output
            for i in 0..4 {
                let mut acc = _mm256_setzero_ps();

                for j in (0..32).step_by(8) {
                    let hidden_vec = _mm256_loadu_ps(self.hidden_buffer.as_ptr().add(j));

                    // Load INT8 weights and convert
                    let weight_ptr = self.weights.w2_int8.add(i * 32 + j);
                    let weights_i8 = _mm_loadl_epi64(weight_ptr as *const __m128i);
                    let weights_i32 = _mm256_cvtepi8_epi32(weights_i8);
                    let weights_f32 = _mm256_cvtepi32_ps(weights_i32);

                    let scale = _mm256_set1_ps(self.weights.w2_scale[i]);
                    let scaled_weights = _mm256_mul_ps(weights_f32, scale);

                    acc = _mm256_fmadd_ps(scaled_weights, hidden_vec, acc);
                }

                // Horizontal sum
                let sum = _mm256_hadd_ps(acc, acc);
                let sum = _mm256_hadd_ps(sum, sum);
                let high = _mm256_extractf128_ps(sum, 1);
                let low = _mm256_castps256_ps128(sum);
                let final_sum = _mm_add_ps(low, high);

                self.output_buffer[i] = _mm_cvtss_f32(final_sum) + self.weights.b2[i];
            }
        }

        self.output_buffer
    }
}

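// Usage sketch (added for illustration; not called anywhere in this file, and
// `example_forward_pass` is a name introduced here): construct the network
// once, verify the CPU features up front via the `simd_features_available`
// guard sketched above, then reuse the pre-allocated buffers across calls.
#[cfg(target_arch = "x86_64")]
#[allow(dead_code)]
fn example_forward_pass() -> Option<[f32; 4]> {
    if !simd_features_available() {
        // Refuse to run the AVX2/FMA path on CPUs that lack those features.
        return None;
    }
    let mut nn = OptimizedNeuralNetwork::new();
    let input = [0.25f32; 128];
    Some(nn.forward(&input))
}
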
/// Custom assembly optimizations for critical paths
#[cfg(target_arch = "x86_64")]
pub mod asm_optimizations {
    use std::arch::asm;

    /// Ultra-fast dot product using inline assembly.
    ///
    /// # Safety
    /// `a` and `b` must be valid for `len` reads, and `len` must be a
    /// non-zero multiple of 8: the loop consumes 8 floats per iteration
    /// and runs at least once.
    #[inline(always)]
    pub unsafe fn dot_product_asm(a: *const f32, b: *const f32, len: usize) -> f32 {
        let mut result: f32;

        asm!(
            "xor {i}, {i}",                   // i = 0
            "vxorps ymm0, ymm0, ymm0",        // acc = 0

            "2:",                             // Loop label
            "vmovups ymm1, [{a} + {i}*4]",    // Load 8 floats from a (unaligned-safe)
            "vmovups ymm2, [{b} + {i}*4]",    // Load 8 floats from b
            "vfmadd231ps ymm0, ymm1, ymm2",   // acc += a * b
            "add {i}, 8",                     // i += 8
            "cmp {i}, {len}",                 // Compare i with len
            "jb 2b",                          // Loop while i < len (unsigned)

            // Horizontal sum of the 8 accumulator lanes
            "vhaddps ymm0, ymm0, ymm0",
            "vhaddps ymm0, ymm0, ymm0",
            "vextractf128 xmm1, ymm0, 1",
            "vaddps xmm0, xmm0, xmm1",
            "vmovaps {result}, xmm0",         // Low lane holds the final sum

            i = out(reg) _,
            a = in(reg) a,
            b = in(reg) b,
            len = in(reg) len,
            result = out(xmm_reg) result,
            out("ymm0") _, out("ymm1") _, out("ymm2") _,
        );

        result
    }

    /// Fast ReLU using assembly.
    ///
    /// # Safety
    /// `data` must be valid for `len` reads and writes, and `len` must be a
    /// non-zero multiple of 8.
    #[inline(always)]
    pub unsafe fn relu_asm(data: *mut f32, len: usize) {
        asm!(
            "vxorps ymm1, ymm1, ymm1",        // Zero vector for comparison
            "xor {i}, {i}",                   // i = 0

            "2:",                             // Loop
            "vmovups ymm0, [{data} + {i}*4]", // Load 8 floats (unaligned-safe)
            "vmaxps ymm0, ymm0, ymm1",        // max(x, 0)
            "vmovups [{data} + {i}*4], ymm0", // Store back
            "add {i}, 8",
            "cmp {i}, {len}",
            "jb 2b",

            i = out(reg) _,
            data = in(reg) data,
            len = in(reg) len,
            out("ymm0") _, out("ymm1") _,
        );
    }
}

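// A safe wrapper sketch around `dot_product_asm` (added here, not part of the
// original API; `dot_product_checked` is a name introduced for illustration).
// It upholds the documented preconditions: the SIMD loop only ever sees a
// non-zero multiple of 8 elements, the remainder is handled by a scalar tail,
// and the whole thing falls back to scalar code when AVX2/FMA are missing.
#[cfg(target_arch = "x86_64")]
pub fn dot_product_checked(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "input slices must have equal length");
    let simd_len = a.len() - (a.len() % 8); // largest multiple of 8
    let mut sum = 0.0f32;
    if simd_len > 0 && simd_features_available() {
        // SAFETY: simd_len is a non-zero multiple of 8 and within both slices.
        sum = unsafe { asm_optimizations::dot_product_asm(a.as_ptr(), b.as_ptr(), simd_len) };
    } else {
        for i in 0..simd_len {
            sum += a[i] * b[i];
        }
    }
    // Scalar tail for the remaining (fewer than 8) elements.
    for i in simd_len..a.len() {
        sum += a[i] * b[i];
    }
    sum
}
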
/// CPU affinity and NUMA optimization
pub struct CpuOptimizer {
    core_id: usize,
}

impl CpuOptimizer {
    pub fn new(preferred_core: usize) -> Self {
        // Pin the current thread to a specific CPU core
        if let Some(core_ids) = core_affinity::get_core_ids() {
            if preferred_core < core_ids.len() {
                core_affinity::set_for_current(core_ids[preferred_core]);
            }
        }

        // Raise scheduling priority to the highest nice level (-20).
        // This requires elevated privileges and silently has no effect otherwise.
        #[cfg(unix)]
        unsafe {
            libc::setpriority(libc::PRIO_PROCESS, 0, -20);
        }

        Self {
            core_id: preferred_core,
        }
    }

    /// Prefetch `data` into L1, issuing one request per 64-byte cache line.
    pub fn prefetch_data<T>(data: &[T]) {
        unsafe {
            let bytes = data.len() * std::mem::size_of::<T>();
            let ptr = data.as_ptr() as *const i8;
            // Step through the buffer in cache-line (64-byte) increments.
            for offset in (0..bytes).step_by(64) {
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(offset));
            }
        }
    }
}

/// Complete optimized temporal solver
pub struct FullyOptimizedSolver {
    nn: OptimizedNeuralNetwork,
    cpu_opt: CpuOptimizer,
}

impl FullyOptimizedSolver {
    pub fn new() -> Self {
        Self {
            nn: OptimizedNeuralNetwork::new(),
            cpu_opt: CpuOptimizer::new(0), // Pin to core 0
        }
    }

    #[inline(always)]
    pub fn predict(&mut self, input: &[f32; 128]) -> ([f32; 4], Duration) {
        // Prefetch input data
        CpuOptimizer::prefetch_data(input);

        let start = Instant::now();
        let output = self.nn.forward(input);
        let duration = start.elapsed();

        (output, duration)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let weights = QuantizedWeights::new();
        unsafe {
            // Verify the quantization round-trip: dequantizing each stored
            // INT8 value must land within one quantization step of the
            // original FP32 weight generated in `QuantizedWeights::new`.
            for i in 0..32 {
                for j in 0..128 {
                    let original = ((i * j) as f32 * 0.001).sin() * 0.1;
                    let quantized = *weights.w1_int8.add(i * 128 + j);
                    let dequantized = quantized as f32 * weights.w1_scale[i];
                    assert!(
                        (dequantized - original).abs() <= weights.w1_scale[i],
                        "row {} col {}: round-trip error too large",
                        i,
                        j
                    );
                }
            }
        }
    }

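    // Added test sketch (not in the original): checks the AVX2 GEMM against
    // the scalar reference `gemm_layer1_scalar` defined above. Skipped when
    // the host CPU lacks AVX2/FMA, since the intrinsics path must not run there.
    #[test]
    fn test_gemm_avx2_matches_scalar() {
        if !std::is_x86_feature_detected!("avx2") || !std::is_x86_feature_detected!("fma") {
            return;
        }

        let weights = QuantizedWeights::new();
        let mut input = [0.0f32; 128];
        for (i, x) in input.iter_mut().enumerate() {
            *x = (i as f32 * 0.01).sin();
        }

        let mut hidden_simd = [0.0f32; 32];
        let mut hidden_ref = [0.0f32; 32];
        unsafe { weights.gemm_int8_avx2(&input, &mut hidden_simd) };
        weights.gemm_layer1_scalar(&input, &mut hidden_ref);

        for row in 0..32 {
            // Allow for different floating-point summation order.
            assert!(
                (hidden_simd[row] - hidden_ref[row]).abs() < 1e-3,
                "row {}: {} vs {}",
                row,
                hidden_simd[row],
                hidden_ref[row]
            );
        }
    }
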
    #[test]
    fn test_fully_optimized() {
        if !std::is_x86_feature_detected!("avx2") || !std::is_x86_feature_detected!("fma") {
            return;
        }

        let mut solver = FullyOptimizedSolver::new();
        let input = [0.1f32; 128];

        // Warmup
        for _ in 0..1000 {
            solver.predict(&input);
        }

        // Benchmark
        let mut timings = Vec::new();
        for _ in 0..1000 {
            let (_, duration) = solver.predict(&input);
            timings.push(duration);
        }

        timings.sort();
        let p50 = timings[499];
        let p99 = timings[989];

        println!("Fully Optimized Performance:");
        println!("  P50: {:?}", p50);
        println!("  P99: {:?}", p99);

        // Tail latency should stay well under 10 microseconds
        assert!(p99.as_micros() < 10);
    }
}