#![allow(unsafe_code)]

use std::arch::x86_64::*;
use std::alloc::{alloc, dealloc, handle_alloc_error, Layout};
use std::time::{Duration, Instant};
use core_affinity;
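/// Weights for a 128 -> 32 -> 4 MLP, quantized row-wise to int8: each row
/// stores one f32 scale so that `w ≈ int8 * scale`. The int8 buffers are
/// 64-byte aligned for SIMD-friendly loads.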
#[repr(C, align(64))]
pub struct QuantizedWeights {
    w1_int8: *mut i8,     // 32 x 128 layer-1 weights, row-major
    w1_scale: [f32; 32],  // per-row dequantization scales for w1
    w2_int8: *mut i8,     // 4 x 32 layer-2 weights, row-major
    w2_scale: [f32; 4],   // per-row dequantization scales for w2
    b1: [f32; 32],
    b2: [f32; 4],
}

impl QuantizedWeights {
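    /// Builds deterministic demo weights from trig patterns, then quantizes
    /// each row symmetrically: scale = max(|w|) / 127, int8 = round(w / scale).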
    pub fn new() -> Self {
        unsafe {
            let w1_layout = Layout::from_size_align(32 * 128, 64).unwrap();
            let w2_layout = Layout::from_size_align(4 * 32, 64).unwrap();

            let w1_ptr = alloc(w1_layout) as *mut i8;
            if w1_ptr.is_null() {
                handle_alloc_error(w1_layout);
            }
            let w2_ptr = alloc(w2_layout) as *mut i8;
            if w2_ptr.is_null() {
                handle_alloc_error(w2_layout);
            }

            let mut w1_scale = [0.0f32; 32];
            let mut w2_scale = [0.0f32; 4];

            for i in 0..32 {
                let mut max_val = 0.0f32;
                let mut row_weights = vec![0.0f32; 128];

                for j in 0..128 {
                    let weight = ((i * j) as f32 * 0.001).sin() * 0.1;
                    row_weights[j] = weight;
                    max_val = max_val.max(weight.abs());
                }

                // Guard against an all-zero row (row 0 here): a zero scale
                // would turn the quantization below into NaN.
                w1_scale[i] = if max_val > 0.0 { max_val / 127.0 } else { 1.0 };
                for j in 0..128 {
                    let quantized = (row_weights[j] / w1_scale[i]).round() as i8;
                    *w1_ptr.add(i * 128 + j) = quantized;
                }
            }

            for i in 0..4 {
                let mut max_val = 0.0f32;
                let mut row_weights = vec![0.0f32; 32];

                for j in 0..32 {
                    let weight = ((i * j) as f32 * 0.002).cos() * 0.2;
                    row_weights[j] = weight;
                    max_val = max_val.max(weight.abs());
                }

                w2_scale[i] = if max_val > 0.0 { max_val / 127.0 } else { 1.0 };
                for j in 0..32 {
                    let quantized = (row_weights[j] / w2_scale[i]).round() as i8;
                    *w2_ptr.add(i * 32 + j) = quantized;
                }
            }

            Self {
                w1_int8: w1_ptr,
                w1_scale,
                w2_int8: w2_ptr,
                w2_scale,
                b1: [0.0; 32],
                b2: [0.0; 4],
            }
        }
    }
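    /// Layer-1 GEMV: dequantizes int8 weights on the fly and accumulates
    /// eight output rows at a time with FMA.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2 and FMA.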
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn gemm_int8_avx2(
        &self,
        input: &[f32; 128],
        hidden: &mut [f32; 32],
    ) {
        for row_block in (0..32).step_by(8) {
            let mut acc0 = _mm256_setzero_ps();
            let mut acc1 = _mm256_setzero_ps();
            let mut acc2 = _mm256_setzero_ps();
            let mut acc3 = _mm256_setzero_ps();
            let mut acc4 = _mm256_setzero_ps();
            let mut acc5 = _mm256_setzero_ps();
            let mut acc6 = _mm256_setzero_ps();
            let mut acc7 = _mm256_setzero_ps();

            for col in (0..128).step_by(8) {
                let input_vec = _mm256_loadu_ps(input.as_ptr().add(col));

                for r in 0..8.min(32 - row_block) {
                    let row = row_block + r;
                    let weight_ptr = self.w1_int8.add(row * 128 + col);

                    // Load 8 int8 weights, sign-extend to i32, convert to f32,
                    // then dequantize with the per-row scale.
                    let weights_i8 = _mm_loadl_epi64(weight_ptr as *const __m128i);
                    let weights_i32 = _mm256_cvtepi8_epi32(weights_i8);
                    let weights_f32 = _mm256_cvtepi32_ps(weights_i32);

                    let scale = _mm256_set1_ps(self.w1_scale[row]);
                    let scaled_weights = _mm256_mul_ps(weights_f32, scale);

                    match r {
                        0 => acc0 = _mm256_fmadd_ps(scaled_weights, input_vec, acc0),
                        1 => acc1 = _mm256_fmadd_ps(scaled_weights, input_vec, acc1),
                        2 => acc2 = _mm256_fmadd_ps(scaled_weights, input_vec, acc2),
                        3 => acc3 = _mm256_fmadd_ps(scaled_weights, input_vec, acc3),
                        4 => acc4 = _mm256_fmadd_ps(scaled_weights, input_vec, acc4),
                        5 => acc5 = _mm256_fmadd_ps(scaled_weights, input_vec, acc5),
                        6 => acc6 = _mm256_fmadd_ps(scaled_weights, input_vec, acc6),
                        7 => acc7 = _mm256_fmadd_ps(scaled_weights, input_vec, acc7),
                        _ => {}
                    }
                }
            }

            // Horizontal sum of one accumulator: two hadds collapse each
            // 128-bit lane, then the two lanes are added together.
            let sum_array = |acc: __m256| -> f32 {
                let sum = _mm256_hadd_ps(acc, acc);
                let sum = _mm256_hadd_ps(sum, sum);
                let high = _mm256_extractf128_ps(sum, 1);
                let low = _mm256_castps256_ps128(sum);
                let final_sum = _mm_add_ps(low, high);
                _mm_cvtss_f32(final_sum)
            };

            for r in 0..8.min(32 - row_block) {
                let row = row_block + r;
                hidden[row] = match r {
                    0 => sum_array(acc0) + self.b1[row],
                    1 => sum_array(acc1) + self.b1[row],
                    2 => sum_array(acc2) + self.b1[row],
                    3 => sum_array(acc3) + self.b1[row],
                    4 => sum_array(acc4) + self.b1[row],
                    5 => sum_array(acc5) + self.b1[row],
                    6 => sum_array(acc6) + self.b1[row],
                    7 => sum_array(acc7) + self.b1[row],
                    _ => 0.0,
                };
            }
        }
    }
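    /// AVX-512 variant of the layer-1 GEMV: one full 16-float lane per
    /// iteration, reduced with `_mm512_reduce_add_ps`. Only compiled when the
    /// crate is built with `avx512f` enabled.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX-512F.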
    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
    #[target_feature(enable = "avx512f")]
    pub unsafe fn gemm_int8_avx512(
        &self,
        input: &[f32; 128],
        hidden: &mut [f32; 32],
    ) {
        for row in 0..32 {
            let mut acc = _mm512_setzero_ps();

            for col in (0..128).step_by(16) {
                let input_vec = _mm512_loadu_ps(input.as_ptr().add(col));

                // 16 int8 weights -> i32 -> f32, then dequantize and accumulate.
                let weight_ptr = self.w1_int8.add(row * 128 + col);
                let weights_i8 = _mm_loadu_si128(weight_ptr as *const __m128i);
                let weights_i32 = _mm512_cvtepi8_epi32(weights_i8);
                let weights_f32 = _mm512_cvtepi32_ps(weights_i32);

                let scale = _mm512_set1_ps(self.w1_scale[row]);
                let scaled_weights = _mm512_mul_ps(weights_f32, scale);
                acc = _mm512_fmadd_ps(scaled_weights, input_vec, acc);
            }

            hidden[row] = _mm512_reduce_add_ps(acc) + self.b1[row];
        }
    }
}
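/// The int8 buffers are allocated manually in `new`, so they must be freed
/// with the same layouts here.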
impl Drop for QuantizedWeights {
    fn drop(&mut self) {
        unsafe {
            let w1_layout = Layout::from_size_align(32 * 128, 64).unwrap();
            let w2_layout = Layout::from_size_align(4 * 32, 64).unwrap();
            dealloc(self.w1_int8 as *mut u8, w1_layout);
            dealloc(self.w2_int8 as *mut u8, w2_layout);
        }
    }
}
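/// Network wrapper with preallocated, cache-line-aligned scratch buffers so
/// `forward` performs no heap allocation.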
#[repr(C, align(64))]
pub struct OptimizedNeuralNetwork {
    weights: QuantizedWeights,
    hidden_buffer: [f32; 32],
    output_buffer: [f32; 4],
}

impl OptimizedNeuralNetwork {
    pub fn new() -> Self {
        Self {
            weights: QuantizedWeights::new(),
            hidden_buffer: [0.0; 32],
            output_buffer: [0.0; 4],
        }
    }
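    /// Full forward pass: int8 GEMV, vectorized ReLU, then the 4-row output
    /// layer. Note that this calls the AVX2/FMA path unconditionally, so the
    /// host CPU must support both features.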
    #[inline(always)]
    pub fn forward(&mut self, input: &[f32; 128]) -> [f32; 4] {
        unsafe {
            self.weights.gemm_int8_avx2(input, &mut self.hidden_buffer);

            // ReLU over the 32 hidden activations, 8 lanes at a time.
            for chunk in self.hidden_buffer.chunks_exact_mut(8) {
                let vals = _mm256_loadu_ps(chunk.as_ptr());
                let zero = _mm256_setzero_ps();
                let relu = _mm256_max_ps(vals, zero);
                _mm256_storeu_ps(chunk.as_mut_ptr(), relu);
            }

            // Layer 2: 4 output rows, each a dot product over 32 hidden values.
            for i in 0..4 {
                let mut acc = _mm256_setzero_ps();

                for j in (0..32).step_by(8) {
                    let hidden_vec = _mm256_loadu_ps(self.hidden_buffer.as_ptr().add(j));

                    let weight_ptr = self.weights.w2_int8.add(i * 32 + j);
                    let weights_i8 = _mm_loadl_epi64(weight_ptr as *const __m128i);
                    let weights_i32 = _mm256_cvtepi8_epi32(weights_i8);
                    let weights_f32 = _mm256_cvtepi32_ps(weights_i32);

                    let scale = _mm256_set1_ps(self.weights.w2_scale[i]);
                    let scaled_weights = _mm256_mul_ps(weights_f32, scale);

                    acc = _mm256_fmadd_ps(scaled_weights, hidden_vec, acc);
                }

                // Same two-hadd horizontal reduction as in the AVX2 GEMM.
                let sum = _mm256_hadd_ps(acc, acc);
                let sum = _mm256_hadd_ps(sum, sum);
                let high = _mm256_extractf128_ps(sum, 1);
                let low = _mm256_castps256_ps128(sum);
                let final_sum = _mm_add_ps(low, high);

                self.output_buffer[i] = _mm_cvtss_f32(final_sum) + self.weights.b2[i];
            }
        }

        self.output_buffer
    }
}
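/// Hand-written AVX2 inline-assembly kernels, as an alternative to the
/// intrinsics used above.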
#[cfg(target_arch = "x86_64")]
pub mod asm_optimizations {
    use std::arch::asm;
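    /// Dot product of two f32 buffers using an FMA loop in inline asm.
    ///
    /// # Safety
    /// `a` and `b` must be valid for `len` reads, `len` must be a nonzero
    /// multiple of 8, and the CPU must support AVX2 and FMA.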
    #[inline(always)]
    pub unsafe fn dot_product_asm(a: *const f32, b: *const f32, len: usize) -> f32 {
        let mut result: f32;

        asm!(
            // Zero the accumulator and the loop counter.
            "vxorps ymm0, ymm0, ymm0",
            "xor {i}, {i}",
            "2:",
            // Unaligned loads: the caller's buffers need not be 32-byte aligned.
            "vmovups ymm1, [{a} + {i}*4]",
            "vmovups ymm2, [{b} + {i}*4]",
            "vfmadd231ps ymm0, ymm1, ymm2",
            "add {i}, 8",
            "cmp {i}, {len}",
            "jl 2b",
            // Horizontal reduction of ymm0 into a scalar in the low lane.
            "vhaddps ymm0, ymm0, ymm0",
            "vhaddps ymm0, ymm0, ymm0",
            "vextractf128 xmm1, ymm0, 1",
            "vaddps xmm0, xmm0, xmm1",
            "vmovaps {result}, xmm0",
            i = out(reg) _,
            a = in(reg) a,
            b = in(reg) b,
            len = in(reg) len,
            result = out(xmm_reg) result,
            out("ymm0") _, out("ymm1") _, out("ymm2") _,
        );

        result
    }
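    /// In-place ReLU over `len` f32 values.
    ///
    /// # Safety
    /// `data` must be valid for `len` reads and writes, `len` must be a
    /// nonzero multiple of 8, and the CPU must support AVX.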
    #[inline(always)]
    pub unsafe fn relu_asm(data: *mut f32, len: usize) {
        asm!(
            // ymm1 holds the zero vector used as the ReLU threshold.
            "vxorps ymm1, ymm1, ymm1",
            "xor {i}, {i}",
            "2:",
            "vmovups ymm0, [{data} + {i}*4]",
            "vmaxps ymm0, ymm0, ymm1",
            "vmovups [{data} + {i}*4], ymm0",
            "add {i}, 8",
            "cmp {i}, {len}",
            "jl 2b",
            i = out(reg) _,
            data = in(reg) data,
            len = in(reg) len,
            out("ymm0") _, out("ymm1") _,
        );
    }
}
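/// Pins the current thread to one core and raises its scheduling priority to
/// reduce jitter in latency measurements.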
pub struct CpuOptimizer {
    core_id: usize,
}

impl CpuOptimizer {
    pub fn new(preferred_core: usize) -> Self {
        // Pin the current thread to the preferred core, if it exists.
        if let Some(core_ids) = core_affinity::get_core_ids() {
            if preferred_core < core_ids.len() {
                core_affinity::set_for_current(core_ids[preferred_core]);
            }
        }

        // Raise scheduling priority to the maximum (-20); this silently
        // fails without the required privileges.
        #[cfg(unix)]
        unsafe {
            libc::setpriority(libc::PRIO_PROCESS, 0, -20);
        }

        Self {
            core_id: preferred_core,
        }
    }
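    /// Issues a T0 prefetch for every 64-byte cache line covered by `data`.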
    pub fn prefetch_data<T>(data: &[T]) {
        unsafe {
            let ptr = data.as_ptr() as *const i8;
            let bytes = data.len() * std::mem::size_of::<T>();
            // Step by bytes, not elements, so every cache line is touched.
            for offset in (0..bytes).step_by(64) {
                _mm_prefetch::<_MM_HINT_T0>(ptr.add(offset));
            }
        }
    }
}
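/// End-to-end solver: pins the thread to a core, then runs and times the
/// int8 SIMD forward pass.
///
/// A minimal usage sketch (marked `ignore` since it assumes an AVX2+FMA host
/// and the surrounding crate):
///
/// ```ignore
/// let mut solver = FullyOptimizedSolver::new();
/// let input = [0.0f32; 128];
/// let (output, latency) = solver.predict(&input);
/// println!("{:?} in {:?}", output, latency);
/// ```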
pub struct FullyOptimizedSolver {
    nn: OptimizedNeuralNetwork,
    cpu_opt: CpuOptimizer,
}

impl FullyOptimizedSolver {
    pub fn new() -> Self {
        Self {
            nn: OptimizedNeuralNetwork::new(),
            cpu_opt: CpuOptimizer::new(0),
        }
    }

    #[inline(always)]
    pub fn predict(&mut self, input: &[f32; 128]) -> ([f32; 4], Duration) {
        // Warm the cache before timing the forward pass.
        CpuOptimizer::prefetch_data(input);

        let start = Instant::now();
        let output = self.nn.forward(input);
        let duration = start.elapsed();

        (output, duration)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() {
        let weights = QuantizedWeights::new();
        unsafe {
            for i in 0..32 {
                for j in 0..128 {
                    // Dequantized weights must round-trip to within half a
                    // quantization step of the original value.
                    let quantized = *weights.w1_int8.add(i * 128 + j);
                    let dequantized = quantized as f32 * weights.w1_scale[i];
                    let original = ((i * j) as f32 * 0.001).sin() * 0.1;
                    assert!((dequantized - original).abs() <= weights.w1_scale[i] * 0.5 + 1e-6);
                }
            }
        }
    }
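    // Added cross-check (a sketch, not part of the original suite): the AVX2
    // kernel should agree with a plain scalar dequantize-multiply-accumulate
    // loop. Skipped at runtime when the host lacks AVX2/FMA.
    #[test]
    fn test_gemm_matches_scalar_reference() {
        if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
            return;
        }

        let weights = QuantizedWeights::new();
        let input: [f32; 128] = std::array::from_fn(|i| (i as f32 * 0.01).sin());
        let mut hidden = [0.0f32; 32];
        unsafe { weights.gemm_int8_avx2(&input, &mut hidden) };

        for row in 0..32 {
            // Scalar reference: dequantize each weight and accumulate.
            let mut expected = weights.b1[row];
            for col in 0..128 {
                let w = unsafe { *weights.w1_int8.add(row * 128 + col) } as f32
                    * weights.w1_scale[row];
                expected += w * input[col];
            }
            assert!(
                (hidden[row] - expected).abs() < 1e-3,
                "row {} mismatch: {} vs {}",
                row,
                hidden[row],
                expected
            );
        }
    }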
    #[test]
    fn test_fully_optimized() {
        let mut solver = FullyOptimizedSolver::new();
        let input = [0.1f32; 128];

        // Warm-up: let caches, branch predictors, and clocks settle.
        for _ in 0..1000 {
            solver.predict(&input);
        }

        let mut timings = Vec::new();
        for _ in 0..1000 {
            let (_, duration) = solver.predict(&input);
            timings.push(duration);
        }

        timings.sort();
        let p50 = timings[500];
        let p99 = timings[990];

        println!("Fully Optimized Performance:");
        println!("  P50: {:?}", p50);
        println!("  P99: {:?}", p99);

        // Latency budget; this threshold is machine-dependent and may need
        // loosening on slower or heavily loaded hosts.
        assert!(p99.as_micros() < 10);
    }
}