use crate::error::OptimizeError;
use ndarray::{Array1, Array2, ArrayView1};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

type CompiledObjectiveFn = Box<dyn Fn(&ArrayView1<f64>) -> f64 + Send + Sync>;

type CompiledGradientFn = Box<dyn Fn(&ArrayView1<f64>) -> Array1<f64> + Send + Sync>;

type CompiledHessianFn = Box<dyn Fn(&ArrayView1<f64>) -> Array2<f64> + Send + Sync>;

type JitCompilationResult = Result<CompiledObjectiveFn, OptimizeError>;

type DerivativeCompilationResult =
    Result<(Option<CompiledGradientFn>, Option<CompiledHessianFn>), OptimizeError>;

type OptimizedFunctionResult = Result<Box<dyn Fn(&ArrayView1<f64>) -> f64>, OptimizeError>;

/// Options controlling JIT-style compilation of objective functions.
#[derive(Debug, Clone)]
pub struct JitOptions {
    /// Enable JIT compilation (if `false`, functions are used as-is).
    pub enable_jit: bool,
    /// Enable vectorized implementations for patterns that support them.
    pub enable_vectorization: bool,
    /// Optimization level hint; higher values request more aggressive optimization.
    pub optimization_level: u8,
    /// Enable pattern detection and pattern-specific implementations.
    pub enable_specialization: bool,
    /// Cache compiled functions keyed by their signature.
    pub enable_caching: bool,
    /// Maximum number of compiled functions kept in the cache.
    pub max_cache_size: usize,
    /// Enable profile-guided optimization (attaches a `FunctionProfiler`).
    pub enable_pgo: bool,
}

impl Default for JitOptions {
    fn default() -> Self {
        Self {
            enable_jit: true,
            enable_vectorization: true,
            optimization_level: 2,
            enable_specialization: true,
            enable_caching: true,
            max_cache_size: 100,
            enable_pgo: false,
        }
    }
}

/// Structural patterns that the compiler can detect and specialize for.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FunctionPattern {
    /// Quadratic form (e.g. `x^T A x + b^T x + c`).
    Quadratic,
    /// Sum of squared terms.
    SumOfSquares,
    /// Polynomial of the given degree.
    Polynomial(usize),
    /// Dominated by exponential terms.
    Exponential,
    /// Dominated by trigonometric terms.
    Trigonometric,
    /// Separable into independent per-variable terms.
    Separable,
    /// No special structure detected.
    General,
}

/// An objective function after compilation, together with optional
/// derivatives and bookkeeping metadata.
pub struct CompiledFunction {
    pub signature: u64,
    pub pattern: FunctionPattern,
    pub implementation: CompiledObjectiveFn,
    pub gradient: Option<CompiledGradientFn>,
    pub hessian: Option<CompiledHessianFn>,
    pub metadata: FunctionMetadata,
}

/// Metadata recorded for a compiled function.
#[derive(Debug, Clone)]
pub struct FunctionMetadata {
    /// Number of variables the function was compiled for.
    pub n_vars: usize,
    /// Wall-clock time spent compiling, in milliseconds.
    pub compile_time_ms: u64,
    /// Number of recorded calls.
    pub call_count: usize,
    /// Average execution time per call, in nanoseconds.
    pub avg_execution_time_ns: u64,
    /// Whether a vectorized implementation was selected.
    pub is_vectorized: bool,
    /// Optimization flags applied during compilation.
    pub optimization_flags: Vec<String>,
}

/// Compiles and caches optimized implementations of objective functions.
pub struct JitCompiler {
    options: JitOptions,
    cache: Arc<Mutex<HashMap<u64, Arc<CompiledFunction>>>>,
    pattern_detector: PatternDetector,
    #[allow(dead_code)]
    profiler: Option<FunctionProfiler>,
}

impl JitCompiler {
    pub fn new(options: JitOptions) -> Self {
        let profiler = if options.enable_pgo {
            Some(FunctionProfiler::new())
        } else {
            None
        };

        Self {
            options,
            cache: Arc::new(Mutex::new(HashMap::new())),
            pattern_detector: PatternDetector::new(),
            profiler,
        }
    }

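    /// Compiles (specializes) an objective function for repeated evaluation:
    /// a signature is hashed, the cache is consulted, a structural pattern is
    /// optionally detected, and a specialized implementation plus placeholder
    /// derivatives are built and cached.
    ///
    /// Minimal usage sketch (see the tests module for a compiling example):
    ///
    /// ```ignore
    /// let mut compiler = JitCompiler::new(JitOptions::default());
    /// let compiled = compiler.compile_function(|x: &ArrayView1<f64>| x.sum(), 4)?;
    /// let value = (compiled.implementation)(&Array1::ones(4).view());
    /// ```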
    pub fn compile_function<F>(
        &mut self,
        fun: F,
        n_vars: usize,
    ) -> Result<Arc<CompiledFunction>, OptimizeError>
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        let start_time = std::time::Instant::now();

        let signature = self.generate_signature(&fun, n_vars);

        if self.options.enable_caching {
            let cache = self.cache.lock().unwrap();
            if let Some(compiled) = cache.get(&signature) {
                return Ok(compiled.clone());
            }
        }

        let pattern = if self.options.enable_specialization {
            self.pattern_detector.detect_pattern(&fun, n_vars)?
        } else {
            FunctionPattern::General
        };

        let implementation = self.create_optimized_implementation(fun, n_vars, &pattern)?;

        let (gradient, hessian) = self.generate_derivatives(&pattern, n_vars)?;

        let compile_time = start_time.elapsed().as_millis() as u64;

        let metadata = FunctionMetadata {
            n_vars,
            compile_time_ms: compile_time,
            call_count: 0,
            avg_execution_time_ns: 0,
            is_vectorized: self.options.enable_vectorization && pattern.supports_vectorization(),
            optimization_flags: self.get_optimization_flags(&pattern),
        };

        let compiled = Arc::new(CompiledFunction {
            signature,
            pattern,
            implementation,
            gradient,
            hessian,
            metadata,
        });

        if self.options.enable_caching {
            let mut cache = self.cache.lock().unwrap();
            if cache.len() >= self.options.max_cache_size {
                // Evict an arbitrary entry; `HashMap` iteration order is
                // unspecified, so this is not a true LRU policy.
                if let Some((&oldest_key, _)) = cache.iter().next() {
                    cache.remove(&oldest_key);
                }
            }
            cache.insert(signature, compiled.clone());
        }

        Ok(compiled)
    }

    fn generate_signature<F>(&self, fun: &F, n_vars: usize) -> u64
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        n_vars.hash(&mut hasher);
        // The closure's address stands in for a content hash; it is cheap but
        // not a stable identity across separate closure instances.
        (std::ptr::addr_of!(*fun) as usize).hash(&mut hasher);
        hasher.finish()
    }

    fn create_optimized_implementation<F>(
        &self,
        fun: F,
        n_vars: usize,
        pattern: &FunctionPattern,
    ) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        match pattern {
            FunctionPattern::Quadratic => {
                self.create_quadratic_implementation(fun, n_vars)
            }
            FunctionPattern::SumOfSquares => {
                self.create_sum_of_squares_implementation(fun, n_vars)
            }
            FunctionPattern::Separable => {
                self.create_separable_implementation(fun, n_vars)
            }
            FunctionPattern::Polynomial(_degree) => {
                self.create_polynomial_implementation(fun, n_vars)
            }
            _ => {
                if self.options.enable_vectorization {
                    self.create_vectorized_implementation(fun, n_vars)
                } else {
                    Ok(Box::new(fun))
                }
            }
        }
    }

    fn create_quadratic_implementation<F>(&self, fun: F, _n_vars: usize) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        // Placeholder: no dedicated quadratic fast path yet; delegate to the
        // original closure.
        Ok(Box::new(move |x: &ArrayView1<f64>| fun(x)))
    }

    fn create_sum_of_squares_implementation<F>(&self, fun: F, _n_vars: usize) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        // Placeholder: no dedicated sum-of-squares fast path yet; delegate to
        // the original closure.
        Ok(Box::new(move |x: &ArrayView1<f64>| fun(x)))
    }

    fn create_separable_implementation<F>(&self, fun: F, n_vars: usize) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        Ok(Box::new(move |x: &ArrayView1<f64>| {
            if n_vars > 1000 {
                use scirs2_core::parallel_ops::*;

                // Note: this parallel branch does not decompose `fun` into
                // per-variable terms; each chunk evaluates the full function
                // on a copy of `x` and contributes `chunk.len() / n_vars` of
                // its value, so the chunk sums add back up to `fun(x)`.
                let chunk_size = (n_vars / num_threads()).max(100);
                (0..n_vars)
                    .into_par_iter()
                    .chunks(chunk_size)
                    .map(|chunk| {
                        let mut chunk_x = Array1::zeros(x.len());
                        chunk_x.assign(x);

                        let mut chunk_sum = 0.0;
                        for _i in chunk {
                            chunk_sum += fun(&chunk_x.view()) / n_vars as f64;
                        }
                        chunk_sum
                    })
                    .sum()
            } else {
                fun(x)
            }
        }))
    }

    fn create_polynomial_implementation<F>(&self, fun: F, _n_vars: usize) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        // Placeholder: no polynomial-specific specialization yet; use the closure directly.
        Ok(Box::new(fun))
    }

    fn create_vectorized_implementation<F>(&self, fun: F, n_vars: usize) -> JitCompilationResult
    where
        F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
    {
        if n_vars >= 8 && self.options.enable_vectorization {
            // Placeholder: no explicit SIMD dispatch yet; the wrapper simply
            // forwards to the original closure.
            Ok(Box::new(move |x: &ArrayView1<f64>| fun(x)))
        } else {
            Ok(Box::new(fun))
        }
    }

    fn generate_derivatives(
        &self,
        pattern: &FunctionPattern,
        n_vars: usize,
    ) -> DerivativeCompilationResult {
        match pattern {
            FunctionPattern::Quadratic => {
                // Placeholder derivatives: the analytic gradient/Hessian of the
                // detected quadratic are not derived yet, so zero arrays are
                // returned.
                let gradient: CompiledGradientFn =
                    Box::new(move |_x: &ArrayView1<f64>| Array1::zeros(n_vars));

                let hessian: CompiledHessianFn =
                    Box::new(move |_x: &ArrayView1<f64>| Array2::zeros((n_vars, n_vars)));

                Ok((Some(gradient), Some(hessian)))
            }
            FunctionPattern::Separable => {
                // Placeholder gradient for the separable case (also zeros).
                let gradient: CompiledGradientFn =
                    Box::new(move |_x: &ArrayView1<f64>| Array1::zeros(n_vars));

                Ok((Some(gradient), None))
            }
            _ => Ok((None, None)),
        }
    }

    fn get_optimization_flags(&self, pattern: &FunctionPattern) -> Vec<String> {
        let mut flags = Vec::new();

        if self.options.enable_vectorization {
            flags.push("vectorization".to_string());
        }

        match pattern {
            FunctionPattern::Quadratic => flags.push("quadratic-opt".to_string()),
            FunctionPattern::SumOfSquares => flags.push("sum-of-squares-opt".to_string()),
            FunctionPattern::Separable => flags.push("separable-opt".to_string()),
            FunctionPattern::Polynomial(_) => flags.push("polynomial-opt".to_string()),
            _ => flags.push("general-opt".to_string()),
        }

        flags
    }

    pub fn get_stats(&self) -> JitStats {
        let cache = self.cache.lock().unwrap();
        JitStats {
            total_compiled: cache.len(),
            // Hit/miss counters are not tracked yet.
            cache_hits: 0,
            cache_misses: 0,
            total_compile_time_ms: cache.values().map(|f| f.metadata.compile_time_ms).sum(),
        }
    }
}

/// Detects structural patterns by sampling an objective function at a set of
/// probe points.
pub struct PatternDetector {
    sample_points: Vec<Array1<f64>>,
}

impl Default for PatternDetector {
    fn default() -> Self {
        Self::new()
    }
}

impl PatternDetector {
    pub fn new() -> Self {
        Self {
            sample_points: Vec::new(),
        }
    }

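    /// Evaluates `fun` at the cached sample points and runs the heuristics
    /// below. With the current placeholder heuristics every function is
    /// classified as `FunctionPattern::General`.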
    pub fn detect_pattern<F>(
        &mut self,
        fun: &F,
        n_vars: usize,
    ) -> Result<FunctionPattern, OptimizeError>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        if self.sample_points.is_empty() {
            self.generate_sample_points(n_vars)?;
        }

        let mut values = Vec::new();
        for point in &self.sample_points {
            values.push(fun(&point.view()));
        }

        if self.is_quadratic(&values, n_vars) {
            Ok(FunctionPattern::Quadratic)
        } else if self.is_sum_of_squares(&values) {
            Ok(FunctionPattern::SumOfSquares)
        } else if self.is_separable(fun, n_vars)? {
            Ok(FunctionPattern::Separable)
        } else if let Some(degree) = self.detect_polynomial_degree(&values) {
            Ok(FunctionPattern::Polynomial(degree))
        } else {
            Ok(FunctionPattern::General)
        }
    }

    fn generate_sample_points(&mut self, n_vars: usize) -> Result<(), OptimizeError> {
        use rand::Rng;
        let mut rng = rand::rng();

        // Random probe points in [-2, 2)^n, capped at 100 samples.
        let n_samples = (20 + n_vars).min(100);
        for _ in 0..n_samples {
            let mut point = Array1::zeros(n_vars);
            for j in 0..n_vars {
                point[j] = rng.random_range(-2.0..2.0);
            }
            self.sample_points.push(point);
        }

        // Always include the origin and the all-ones point as deterministic samples.
        self.sample_points.push(Array1::zeros(n_vars));
        self.sample_points.push(Array1::ones(n_vars));

        Ok(())
    }

    fn is_quadratic(&self, _values: &[f64], _n_vars: usize) -> bool {
        // Placeholder heuristic: quadratic detection is not implemented yet.
        false
    }

    fn is_sum_of_squares(&self, _values: &[f64]) -> bool {
        // Placeholder heuristic: sum-of-squares detection is not implemented yet.
        false
    }

    fn is_separable<F>(&self, _fun: &F, _n_vars: usize) -> Result<bool, OptimizeError>
    where
        F: Fn(&ArrayView1<f64>) -> f64,
    {
        // Placeholder heuristic: separability testing is not implemented yet.
        Ok(false)
    }

    fn detect_polynomial_degree(&self, _values: &[f64]) -> Option<usize> {
        // Placeholder heuristic: polynomial-degree detection is not implemented yet.
        None
    }
}

impl FunctionPattern {
    pub fn supports_vectorization(&self) -> bool {
        matches!(
            self,
            FunctionPattern::Quadratic
                | FunctionPattern::SumOfSquares
                | FunctionPattern::Separable
                | FunctionPattern::Polynomial(_)
        )
    }
}

/// Collects per-function call statistics for profile-guided optimization.
pub struct FunctionProfiler {
    profiles: HashMap<u64, ProfileData>,
}

#[derive(Debug, Clone)]
struct ProfileData {
    call_count: usize,
    total_time_ns: u64,
    #[allow(dead_code)]
    hot_paths: Vec<String>,
}

impl Default for FunctionProfiler {
    fn default() -> Self {
        Self::new()
    }
}

impl FunctionProfiler {
    pub fn new() -> Self {
        Self {
            profiles: HashMap::new(),
        }
    }

    pub fn record_call(&mut self, signature: u64, execution_time_ns: u64) {
        let profile = self.profiles.entry(signature).or_insert(ProfileData {
            call_count: 0,
            total_time_ns: 0,
            hot_paths: Vec::new(),
        });

        profile.call_count += 1;
        profile.total_time_ns += execution_time_ns;
    }

    pub fn get_hot_functions(&self) -> Vec<u64> {
        let mut functions: Vec<_> = self.profiles.iter().collect();
        functions.sort_by_key(|(_, profile)| profile.total_time_ns);
        functions
            .into_iter()
            .rev()
            .take(10)
            .map(|(&sig, _)| sig)
            .collect()
    }
}

/// Aggregate statistics reported by [`JitCompiler::get_stats`].
#[derive(Debug, Clone)]
pub struct JitStats {
    pub total_compiled: usize,
    pub cache_hits: usize,
    pub cache_misses: usize,
    pub total_compile_time_ms: u64,
}

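/// Convenience wrapper: compiles `fun` with the given (or default)
/// [`JitOptions`] and returns a boxed callable that dispatches to the compiled
/// implementation. If `enable_jit` is `false`, the original closure is
/// returned unchanged.
///
/// Illustrative sketch (not compiled as a doctest):
///
/// ```ignore
/// let f = |x: &ArrayView1<f64>| x.dot(x);
/// let fast = optimize_function(f, 8, None)?;
/// let y = fast(&Array1::ones(8).view());
/// ```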
#[allow(dead_code)]
pub fn optimize_function<F>(
    fun: F,
    n_vars: usize,
    options: Option<JitOptions>,
) -> OptimizedFunctionResult
where
    F: Fn(&ArrayView1<f64>) -> f64 + Send + Sync + 'static,
{
    let options = options.unwrap_or_default();

    if !options.enable_jit {
        return Ok(Box::new(fun));
    }

    let mut compiler = JitCompiler::new(options);
    let compiled = compiler.compile_function(fun, n_vars)?;

    Ok(Box::new(move |x: &ArrayView1<f64>| -> f64 {
        (compiled.implementation)(x)
    }))
}

#[allow(dead_code)]
fn estimate_memory_usage(n_vars: usize, max_history: usize) -> usize {
    let vector_size = n_vars * std::mem::size_of::<f64>();
    let matrix_size = n_vars * n_vars * std::mem::size_of::<f64>();

    // Three working vectors (e.g. iterate, gradient, search direction).
    let basic_vectors = 3 * vector_size;

    // Two sets of history vectors of length `max_history` (e.g. L-BFGS s/y pairs).
    let history_vectors = 2 * max_history * vector_size;

    // Temporary matrices and scratch vectors.
    let temp_memory = 2 * matrix_size + 5 * vector_size;

    basic_vectors + history_vectors + temp_memory
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_jit_compiler_creation() {
        let options = JitOptions::default();
        let compiler = JitCompiler::new(options);

        let stats = compiler.get_stats();
        assert_eq!(stats.total_compiled, 0);
    }
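
    // Additional example: with the placeholder pattern heuristics above, a
    // small unstructured function is classified as `General` and (for fewer
    // than 8 variables) compiled to a passthrough, so the compiled callable
    // must agree with the original closure and occupy one cache slot.
    #[test]
    fn test_compile_function_passthrough_and_cache() {
        let mut compiler = JitCompiler::new(JitOptions::default());
        let fun = |x: &ArrayView1<f64>| x.iter().map(|v| v * v).sum::<f64>();

        let compiled = compiler.compile_function(fun, 3).unwrap();
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        assert_abs_diff_eq!((compiled.implementation)(&x.view()), 14.0, epsilon = 1e-10);
        assert_eq!(compiled.metadata.n_vars, 3);
        assert_eq!(compiler.get_stats().total_compiled, 1);
    }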

    #[test]
    fn test_pattern_detection() {
        let mut detector = PatternDetector::new();

        let quadratic = |x: &ArrayView1<f64>| x[0] * x[0] + x[1] * x[1];

        let pattern = detector.detect_pattern(&quadratic, 2).unwrap();

        assert!(matches!(
            pattern,
            FunctionPattern::General | FunctionPattern::Quadratic
        ));
    }
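
    // Example of the pattern/vectorization relationship encoded in
    // `supports_vectorization`: structured patterns report support, while the
    // `General` and transcendental patterns do not.
    #[test]
    fn test_pattern_vectorization_support() {
        assert!(FunctionPattern::SumOfSquares.supports_vectorization());
        assert!(FunctionPattern::Polynomial(3).supports_vectorization());
        assert!(!FunctionPattern::Exponential.supports_vectorization());
        assert!(!FunctionPattern::General.supports_vectorization());
    }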

    #[test]
    fn test_function_optimization() {
        let quadratic = |x: &ArrayView1<f64>| x[0] * x[0] + x[1] * x[1];

        let optimized = optimize_function(quadratic, 2, None).unwrap();

        let x = Array1::from_vec(vec![1.0, 2.0]);
        let result = (*optimized)(&x.view());

        assert_abs_diff_eq!(result, 5.0, epsilon = 1e-10);
    }
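
    // Sketch of the `enable_jit: false` escape hatch: `optimize_function`
    // returns the original closure unchanged, so the result matches a direct
    // evaluation.
    #[test]
    fn test_optimize_function_jit_disabled() {
        let options = JitOptions {
            enable_jit: false,
            ..Default::default()
        };
        let fun = |x: &ArrayView1<f64>| x[0] + 2.0 * x[1];

        let optimized = optimize_function(fun, 2, Some(options)).unwrap();
        let x = Array1::from_vec(vec![1.5, 2.5]);

        assert_abs_diff_eq!(optimized(&x.view()), 6.5, epsilon = 1e-10);
    }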

    #[test]
    fn test_memory_usage_estimation() {
        let n_vars = 1000;
        let max_history = 10;

        let estimated = estimate_memory_usage(n_vars, max_history);
        assert!(estimated > 0);

        let estimated_large = estimate_memory_usage(n_vars * 2, max_history);
        assert!(estimated_large > estimated);
    }
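
    // Illustrative profiler usage: the signature with the largest accumulated
    // execution time is reported first by `get_hot_functions`.
    #[test]
    fn test_function_profiler_hot_functions() {
        let mut profiler = FunctionProfiler::new();
        profiler.record_call(1, 100);
        profiler.record_call(2, 10_000);
        profiler.record_call(2, 10_000);

        let hot = profiler.get_hot_functions();
        assert_eq!(hot.len(), 2);
        assert_eq!(hot.first(), Some(&2));
    }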
}