scirs2_core/array_protocol/
jit_impl.rs1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21use std::marker::PhantomData;
22use std::sync::{Arc, LazyLock, RwLock};
23
24use crate::array_protocol::{
25 ArrayFunction, ArrayProtocol, JITArray, JITFunction, JITFunctionFactory,
26};
27use crate::error::{CoreError, CoreResult, ErrorContext};
28
/// Code-generation backend a JIT configuration targets.
///
/// The default backend is [`JITBackend::LLVM`].
// The acronym variant names are part of the public API; renaming them would break callers.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JITBackend {
    /// The LLVM backend (the default).
    #[default]
    LLVM,

    /// The Cranelift backend.
    Cranelift,

    /// A WebAssembly backend.
    WASM,

    /// A caller-defined backend identified by a `TypeId`.
    /// NOTE(review): presumably the `TypeId` of the custom backend type — confirm with users.
    Custom(TypeId),
}
50
51#[derive(Debug, Clone)]
53pub struct JITConfig {
54 pub backend: JITBackend,
56
57 pub optimize: bool,
59
60 pub opt_level: usize,
62
63 pub use_cache: bool,
65
66 pub backend_options: HashMap<String, String>,
68}
69
70impl Default for JITConfig {
71 fn default() -> Self {
72 Self {
73 backend: JITBackend::default(),
74 optimize: true,
75 opt_level: 2,
76 use_cache: true,
77 backend_options: HashMap::new(),
78 }
79 }
80}
81
82pub type JITFunctionType = dyn Fn(&[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> + Send + Sync;
84
85pub struct JITFunctionImpl {
87 source: String,
89
90 function: Box<JITFunctionType>,
92
93 compile_info: HashMap<String, String>,
95}
96
97impl Debug for JITFunctionImpl {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 f.debug_struct("JITFunctionImpl")
100 .field("source", &self.source)
101 .field("compile_info", &self.compile_info)
102 .finish_non_exhaustive()
103 }
104}
105
106impl JITFunctionImpl {
107 #[must_use]
109 pub fn new(
110 source: String,
111 function: Box<JITFunctionType>,
112 compile_info: HashMap<String, String>,
113 ) -> Self {
114 Self {
115 source,
116 function,
117 compile_info,
118 }
119 }
120}
121
122impl JITFunction for JITFunctionImpl {
123 fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> {
124 (self.function)(args)
125 }
126
127 fn source(&self) -> String {
128 self.source.clone()
129 }
130
131 fn compile_info(&self) -> HashMap<String, String> {
132 self.compile_info.clone()
133 }
134
135 fn clone_box(&self) -> Box<dyn JITFunction> {
136 let source = self.source.clone();
138 let compile_info = self.compile_info.clone();
139
140 let cloned_function: Box<JITFunctionType> = Box::new(move |_args| {
143 Ok(Box::new(42.0))
145 });
146
147 Box::new(Self {
148 source,
149 function: cloned_function,
150 compile_info,
151 })
152 }
153}
154
155pub struct LLVMFunctionFactory {
157 config: JITConfig,
159
160 cache: HashMap<String, Arc<dyn JITFunction>>,
162}
163
164impl LLVMFunctionFactory {
165 pub fn new(config: JITConfig) -> Self {
167 Self {
168 config,
169 cache: HashMap::new(),
170 }
171 }
172
173 fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
175 let mut compile_info = HashMap::new();
180 compile_info.insert("backend".to_string(), "LLVM".to_string());
181 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
182 compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
183
184 let source = expression.to_string();
187 let function: Box<JITFunctionType> = Box::new(move |_args| {
188 Ok(Box::new(42.0))
190 });
191
192 let jit_function = JITFunctionImpl::new(source, function, compile_info);
194
195 Ok(Arc::new(jit_function))
196 }
197}
198
199impl JITFunctionFactory for LLVMFunctionFactory {
200 fn create_jit_function(
201 &self,
202 expression: &str,
203 array_typeid: TypeId,
204 ) -> CoreResult<Box<dyn JITFunction>> {
205 if self.config.use_cache {
207 let cache_key = format!("{expression}-{array_typeid:?}");
208 if let Some(cached_fn) = self.cache.get(&cache_key) {
209 return Ok(cached_fn.as_ref().clone_box());
210 }
211 }
212
213 let jit_function = self.compile(expression, array_typeid)?;
215
216 if self.config.use_cache {
217 let cache_key = format!("{expression}-{array_typeid:?}");
219 let mut cache = self.cache.clone();
222 cache.insert(cache_key, jit_function.clone());
223 }
224
225 Ok(jit_function.as_ref().clone_box())
227 }
228
229 fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
230 true
232 }
233}
234
235pub struct CraneliftFunctionFactory {
237 config: JITConfig,
239
240 cache: HashMap<String, Arc<dyn JITFunction>>,
242}
243
244impl CraneliftFunctionFactory {
245 pub fn new(config: JITConfig) -> Self {
247 Self {
248 config,
249 cache: HashMap::new(),
250 }
251 }
252
253 fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
255 let mut compile_info = HashMap::new();
260 compile_info.insert("backend".to_string(), "Cranelift".to_string());
261 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
262 compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
263
264 let source = expression.to_string();
267 let function: Box<JITFunctionType> = Box::new(move |_args| {
268 Ok(Box::new(42.0))
270 });
271
272 let jit_function = JITFunctionImpl::new(source, function, compile_info);
274
275 Ok(Arc::new(jit_function))
276 }
277}
278
279impl JITFunctionFactory for CraneliftFunctionFactory {
280 fn create_jit_function(
281 &self,
282 expression: &str,
283 array_typeid: TypeId,
284 ) -> CoreResult<Box<dyn JITFunction>> {
285 if self.config.use_cache {
287 let cache_key = format!("{expression}-{array_typeid:?}");
288 if let Some(cached_fn) = self.cache.get(&cache_key) {
289 return Ok(cached_fn.as_ref().clone_box());
290 }
291 }
292
293 let jit_function = self.compile(expression, array_typeid)?;
295
296 if self.config.use_cache {
297 let cache_key = format!("{expression}-{array_typeid:?}");
299 let mut cache = self.cache.clone();
302 cache.insert(cache_key, jit_function.clone());
303 }
304
305 Ok(jit_function.as_ref().clone_box())
307 }
308
309 fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
310 true
312 }
313}
314
315pub struct JITManager {
317 factories: Vec<Box<dyn JITFunctionFactory>>,
319
320 defaultconfig: JITConfig,
322}
323
324impl JITManager {
325 pub fn new(defaultconfig: JITConfig) -> Self {
327 Self {
328 factories: Vec::new(),
329 defaultconfig,
330 }
331 }
332
333 pub fn register_factory(&mut self, factory: Box<dyn JITFunctionFactory>) {
335 self.factories.push(factory);
336 }
337
338 pub fn get_factory_for_array_type(
340 &self,
341 array_typeid: TypeId,
342 ) -> Option<&dyn JITFunctionFactory> {
343 for factory in &self.factories {
344 if factory.supports_array_type(array_typeid) {
345 return Some(&**factory);
346 }
347 }
348 None
349 }
350
351 pub fn compile(
353 &self,
354 expression: &str,
355 array_typeid: TypeId,
356 ) -> CoreResult<Box<dyn JITFunction>> {
357 if let Some(factory) = self.get_factory_for_array_type(array_typeid) {
359 factory.create_jit_function(expression, array_typeid)
360 } else {
361 Err(CoreError::JITError(ErrorContext::new(format!(
362 "No JIT factory supports array type: {array_typeid:?}"
363 ))))
364 }
365 }
366
367 pub fn initialize(&mut self) {
369 let llvm_config = JITConfig {
371 backend: JITBackend::LLVM,
372 ..self.defaultconfig.clone()
373 };
374 let llvm_factory = Box::new(LLVMFunctionFactory::new(llvm_config));
375
376 let cranelift_config = JITConfig {
377 backend: JITBackend::Cranelift,
378 ..self.defaultconfig.clone()
379 };
380 let cranelift_factory = Box::new(CraneliftFunctionFactory::new(cranelift_config));
381
382 self.register_factory(llvm_factory);
383 self.register_factory(cranelift_factory);
384 }
385
386 #[must_use]
388 pub fn global() -> &'static RwLock<Self> {
389 static INSTANCE: LazyLock<RwLock<JITManager>> = LazyLock::new(|| {
390 RwLock::new(JITManager {
391 factories: Vec::new(),
392 defaultconfig: JITConfig {
393 backend: JITBackend::LLVM,
394 optimize: true,
395 opt_level: 2,
396 use_cache: true,
397 backend_options: HashMap::new(),
398 },
399 })
400 });
401 &INSTANCE
402 }
403}
404
/// Wrapper that adds JIT-compilation support to an existing array value of type `A`.
///
/// `T` is a marker type parameter only; no value of type `T` is stored.
pub struct JITEnabledArray<T, A> {
    /// The wrapped array; `array_function`, `shape` and `dtype` are delegated to it.
    inner: A,

    /// Marker binding the wrapper to `T` without storing a `T`.
    phantom: PhantomData<T>,
}

impl<T, A> JITEnabledArray<T, A> {
    /// Wraps `inner` as-is.
    pub fn new(inner: A) -> Self {
        let phantom = PhantomData;
        Self { inner, phantom }
    }

    /// Borrows the wrapped array.
    pub const fn inner(&self) -> &A {
        &self.inner
    }
}

// Written by hand rather than derived so that `T` itself does not need to be `Clone`.
impl<T, A: Clone> Clone for JITEnabledArray<T, A> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}
437
438impl<T, A> JITArray for JITEnabledArray<T, A>
439where
440 T: Send + Sync + 'static,
441 A: ArrayProtocol + Clone + Send + Sync + 'static,
442{
443 fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>> {
444 let jit_manager = JITManager::global();
446 let jit_manager = jit_manager.read().unwrap();
447
448 (*jit_manager).compile(expression, TypeId::of::<A>())
450 }
451
452 fn supports_jit(&self) -> bool {
453 let jit_manager = JITManager::global();
455 let jit_manager = jit_manager.read().unwrap();
456
457 jit_manager
458 .get_factory_for_array_type(TypeId::of::<A>())
459 .is_some()
460 }
461
462 fn jit_info(&self) -> HashMap<String, String> {
463 let mut info = HashMap::new();
464
465 let supported = self.supports_jit();
467 info.insert("supports_jit".to_string(), supported.to_string());
468
469 if supported {
470 let jit_manager = JITManager::global();
472 let jit_manager = jit_manager.read().unwrap();
473
474 if jit_manager
476 .get_factory_for_array_type(TypeId::of::<A>())
477 .is_some()
478 {
479 info.insert("factory".to_string(), "JIT factory available".to_string());
481 }
482 }
483
484 info
485 }
486}
487
488impl<T, A> ArrayProtocol for JITEnabledArray<T, A>
489where
490 T: Send + Sync + 'static,
491 A: ArrayProtocol + Clone + Send + Sync + 'static,
492{
493 fn array_function(
494 &self,
495 func: &ArrayFunction,
496 types: &[TypeId],
497 args: &[Box<dyn Any>],
498 kwargs: &HashMap<String, Box<dyn Any>>,
499 ) -> Result<Box<dyn Any>, crate::array_protocol::NotImplemented> {
500 self.inner.array_function(func, types, args, kwargs)
502 }
503
504 fn as_any(&self) -> &dyn Any {
505 self
506 }
507
508 fn shape(&self) -> &[usize] {
509 self.inner.shape()
510 }
511
512 fn dtype(&self) -> TypeId {
513 self.inner.dtype()
514 }
515
516 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
517 let inner_clone = self.inner.clone();
519 Box::new(Self {
520 inner: inner_clone,
521 phantom: PhantomData::<T>,
522 })
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::NdarrayWrapper;
    use ndarray::Array2;

    /// Expression compiled by every test in this module.
    const EXPRESSION: &str = "x + y";

    /// `TypeId` of the 2-D `f64` ndarray wrapper the tests compile for.
    fn wrapper_typeid() -> TypeId {
        TypeId::of::<NdarrayWrapper<f64, ndarray::Ix2>>()
    }

    #[test]
    fn test_jit_function_creation() {
        let factory = LLVMFunctionFactory::new(JITConfig {
            backend: JITBackend::LLVM,
            ..Default::default()
        });

        let compiled = factory
            .create_jit_function(EXPRESSION, wrapper_typeid())
            .unwrap();

        assert_eq!(compiled.source(), EXPRESSION);
        let info = compiled.compile_info();
        assert_eq!(info.get("backend").unwrap(), "LLVM");
    }

    #[test]
    fn test_jit_manager() {
        let mut manager = JITManager::new(JITConfig::default());
        manager.initialize();

        let typeid = wrapper_typeid();
        assert!(manager.get_factory_for_array_type(typeid).is_some());

        let compiled = manager.compile(EXPRESSION, typeid).unwrap();
        assert_eq!(compiled.source(), EXPRESSION);
    }

    #[test]
    fn test_jit_enabled_array() {
        let jit_array: JITEnabledArray<f64, _> =
            JITEnabledArray::new(NdarrayWrapper::new(Array2::<f64>::ones((10, 5))));

        // The temporary write guard is released at the end of this statement, before the
        // array below takes its own read lock.
        JITManager::global().write().unwrap().initialize();

        assert!(jit_array.supports_jit());

        let compiled = jit_array.compile(EXPRESSION).unwrap();
        assert_eq!(compiled.source(), EXPRESSION);
    }
}