scirs2_core/array_protocol/
jit_impl.rs1use std::any::{Any, TypeId};
13use std::collections::HashMap;
14use std::fmt::Debug;
15use std::marker::PhantomData;
16use std::sync::{Arc, LazyLock, RwLock};
17
18use crate::array_protocol::{
19 ArrayFunction, ArrayProtocol, JITArray, JITFunction, JITFunctionFactory,
20};
21use crate::error::{CoreError, CoreResult, ErrorContext};
22
/// Code-generation backend used to JIT-compile array expressions.
///
/// [`JITBackend::LLVM`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JITBackend {
    /// LLVM backend (the default).
    #[default]
    LLVM,

    /// Cranelift backend.
    Cranelift,

    /// WebAssembly backend.
    WASM,

    /// A user-defined backend, identified by a `TypeId`.
    Custom(TypeId),
}
44
45#[derive(Debug, Clone)]
47pub struct JITConfig {
48 pub backend: JITBackend,
50
51 pub optimize: bool,
53
54 pub opt_level: usize,
56
57 pub use_cache: bool,
59
60 pub backend_options: HashMap<String, String>,
62}
63
64impl Default for JITConfig {
65 fn default() -> Self {
66 Self {
67 backend: JITBackend::default(),
68 optimize: true,
69 opt_level: 2,
70 use_cache: true,
71 backend_options: HashMap::new(),
72 }
73 }
74}
75
76pub type JITFunctionType = dyn Fn(&[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> + Send + Sync;
78
79pub struct JITFunctionImpl {
81 source: String,
83
84 function: Box<JITFunctionType>,
86
87 compile_info: HashMap<String, String>,
89}
90
91impl Debug for JITFunctionImpl {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("JITFunctionImpl")
94 .field("source", &self.source)
95 .field("compile_info", &self.compile_info)
96 .finish_non_exhaustive()
97 }
98}
99
100impl JITFunctionImpl {
101 #[must_use]
103 pub fn new(
104 source: String,
105 function: Box<JITFunctionType>,
106 compile_info: HashMap<String, String>,
107 ) -> Self {
108 Self {
109 source,
110 function,
111 compile_info,
112 }
113 }
114}
115
116impl JITFunction for JITFunctionImpl {
117 fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> {
118 (self.function)(args)
119 }
120
121 fn source(&self) -> String {
122 self.source.clone()
123 }
124
125 fn compile_info(&self) -> HashMap<String, String> {
126 self.compile_info.clone()
127 }
128
129 fn clone_box(&self) -> Box<dyn JITFunction> {
130 let source = self.source.clone();
132 let compile_info = self.compile_info.clone();
133
134 let cloned_function: Box<JITFunctionType> = Box::new(move |_args| {
137 Ok(Box::new(42.0))
139 });
140
141 Box::new(Self {
142 source,
143 function: cloned_function,
144 compile_info,
145 })
146 }
147}
148
149pub struct LLVMFunctionFactory {
151 config: JITConfig,
153
154 cache: HashMap<String, Arc<dyn JITFunction>>,
156}
157
158impl LLVMFunctionFactory {
159 pub fn new(config: JITConfig) -> Self {
161 Self {
162 config,
163 cache: HashMap::new(),
164 }
165 }
166
167 fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
169 let mut compile_info = HashMap::new();
174 compile_info.insert("backend".to_string(), "LLVM".to_string());
175 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
176 compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
177
178 let source = expression.to_string();
181 let function: Box<JITFunctionType> = Box::new(move |_args| {
182 Ok(Box::new(42.0))
184 });
185
186 let jit_function = JITFunctionImpl::new(source, function, compile_info);
188
189 Ok(Arc::new(jit_function))
190 }
191}
192
193impl JITFunctionFactory for LLVMFunctionFactory {
194 fn create_jit_function(
195 &self,
196 expression: &str,
197 array_typeid: TypeId,
198 ) -> CoreResult<Box<dyn JITFunction>> {
199 if self.config.use_cache {
201 let cache_key = format!("{expression}-{array_typeid:?}");
202 if let Some(cached_fn) = self.cache.get(&cache_key) {
203 return Ok(cached_fn.as_ref().clone_box());
204 }
205 }
206
207 let jit_function = self.compile(expression, array_typeid)?;
209
210 if self.config.use_cache {
211 let cache_key = format!("{expression}-{array_typeid:?}");
213 let mut cache = self.cache.clone();
216 cache.insert(cache_key, jit_function.clone());
217 }
218
219 Ok(jit_function.as_ref().clone_box())
221 }
222
223 fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
224 true
226 }
227}
228
229pub struct CraneliftFunctionFactory {
231 config: JITConfig,
233
234 cache: HashMap<String, Arc<dyn JITFunction>>,
236}
237
238impl CraneliftFunctionFactory {
239 pub fn new(config: JITConfig) -> Self {
241 Self {
242 config,
243 cache: HashMap::new(),
244 }
245 }
246
247 fn compile(&self, expression: &str, array_typeid: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
249 let mut compile_info = HashMap::new();
254 compile_info.insert("backend".to_string(), "Cranelift".to_string());
255 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
256 compile_info.insert("array_type".to_string(), format!("{array_typeid:?}"));
257
258 let source = expression.to_string();
261 let function: Box<JITFunctionType> = Box::new(move |_args| {
262 Ok(Box::new(42.0))
264 });
265
266 let jit_function = JITFunctionImpl::new(source, function, compile_info);
268
269 Ok(Arc::new(jit_function))
270 }
271}
272
273impl JITFunctionFactory for CraneliftFunctionFactory {
274 fn create_jit_function(
275 &self,
276 expression: &str,
277 array_typeid: TypeId,
278 ) -> CoreResult<Box<dyn JITFunction>> {
279 if self.config.use_cache {
281 let cache_key = format!("{expression}-{array_typeid:?}");
282 if let Some(cached_fn) = self.cache.get(&cache_key) {
283 return Ok(cached_fn.as_ref().clone_box());
284 }
285 }
286
287 let jit_function = self.compile(expression, array_typeid)?;
289
290 if self.config.use_cache {
291 let cache_key = format!("{expression}-{array_typeid:?}");
293 let mut cache = self.cache.clone();
296 cache.insert(cache_key, jit_function.clone());
297 }
298
299 Ok(jit_function.as_ref().clone_box())
301 }
302
303 fn supports_array_type(&self, _array_typeid: TypeId) -> bool {
304 true
306 }
307}
308
309pub struct JITManager {
311 factories: Vec<Box<dyn JITFunctionFactory>>,
313
314 defaultconfig: JITConfig,
316}
317
318impl JITManager {
319 pub fn new(defaultconfig: JITConfig) -> Self {
321 Self {
322 factories: Vec::new(),
323 defaultconfig,
324 }
325 }
326
327 pub fn register_factory(&mut self, factory: Box<dyn JITFunctionFactory>) {
329 self.factories.push(factory);
330 }
331
332 pub fn get_factory_for_array_type(
334 &self,
335 array_typeid: TypeId,
336 ) -> Option<&dyn JITFunctionFactory> {
337 for factory in &self.factories {
338 if factory.supports_array_type(array_typeid) {
339 return Some(&**factory);
340 }
341 }
342 None
343 }
344
345 pub fn compile(
347 &self,
348 expression: &str,
349 array_typeid: TypeId,
350 ) -> CoreResult<Box<dyn JITFunction>> {
351 if let Some(factory) = self.get_factory_for_array_type(array_typeid) {
353 factory.create_jit_function(expression, array_typeid)
354 } else {
355 Err(CoreError::JITError(ErrorContext::new(format!(
356 "No JIT factory supports array type: {array_typeid:?}"
357 ))))
358 }
359 }
360
361 pub fn initialize(&mut self) {
363 let llvm_config = JITConfig {
365 backend: JITBackend::LLVM,
366 ..self.defaultconfig.clone()
367 };
368 let llvm_factory = Box::new(LLVMFunctionFactory::new(llvm_config));
369
370 let cranelift_config = JITConfig {
371 backend: JITBackend::Cranelift,
372 ..self.defaultconfig.clone()
373 };
374 let cranelift_factory = Box::new(CraneliftFunctionFactory::new(cranelift_config));
375
376 self.register_factory(llvm_factory);
377 self.register_factory(cranelift_factory);
378 }
379
380 #[must_use]
382 pub fn global() -> &'static RwLock<Self> {
383 static INSTANCE: LazyLock<RwLock<JITManager>> = LazyLock::new(|| {
384 RwLock::new(JITManager {
385 factories: Vec::new(),
386 defaultconfig: JITConfig {
387 backend: JITBackend::LLVM,
388 optimize: true,
389 opt_level: 2,
390 use_cache: true,
391 backend_options: HashMap::new(),
392 },
393 })
394 });
395 &INSTANCE
396 }
397}
398
/// Wrapper that adds JIT-compilation support to an existing array value of type `A`.
///
/// `T` is carried only as a type marker; no value of type `T` is stored.
pub struct JITEnabledArray<T, A> {
    /// The wrapped array.
    inner: A,

    /// Marker for the otherwise-unused type parameter `T`.
    phantom: PhantomData<T>,
}

impl<T, A> JITEnabledArray<T, A> {
    /// Wraps `inner`.
    pub fn new(inner: A) -> Self {
        let phantom = PhantomData;
        Self { inner, phantom }
    }

    /// Borrows the wrapped array.
    pub const fn inner(&self) -> &A {
        &self.inner
    }
}

// Written by hand rather than derived: `#[derive(Clone)]` would also require `T: Clone`,
// although only the wrapped `A` is actually cloned.
impl<T, A: Clone> Clone for JITEnabledArray<T, A> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}
431
432impl<T, A> JITArray for JITEnabledArray<T, A>
433where
434 T: Send + Sync + 'static,
435 A: ArrayProtocol + Clone + Send + Sync + 'static,
436{
437 fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>> {
438 let jit_manager = JITManager::global();
440 let jit_manager = jit_manager.read().expect("Operation failed");
441
442 (*jit_manager).compile(expression, TypeId::of::<A>())
444 }
445
446 fn supports_jit(&self) -> bool {
447 let jit_manager = JITManager::global();
449 let jit_manager = jit_manager.read().expect("Operation failed");
450
451 jit_manager
452 .get_factory_for_array_type(TypeId::of::<A>())
453 .is_some()
454 }
455
456 fn jit_info(&self) -> HashMap<String, String> {
457 let mut info = HashMap::new();
458
459 let supported = self.supports_jit();
461 info.insert("supports_jit".to_string(), supported.to_string());
462
463 if supported {
464 let jit_manager = JITManager::global();
466 let jit_manager = jit_manager.read().expect("Operation failed");
467
468 if jit_manager
470 .get_factory_for_array_type(TypeId::of::<A>())
471 .is_some()
472 {
473 info.insert("factory".to_string(), "JIT factory available".to_string());
475 }
476 }
477
478 info
479 }
480}
481
482impl<T, A> ArrayProtocol for JITEnabledArray<T, A>
483where
484 T: Send + Sync + 'static,
485 A: ArrayProtocol + Clone + Send + Sync + 'static,
486{
487 fn array_function(
488 &self,
489 func: &ArrayFunction,
490 types: &[TypeId],
491 args: &[Box<dyn Any>],
492 kwargs: &HashMap<String, Box<dyn Any>>,
493 ) -> Result<Box<dyn Any>, crate::array_protocol::NotImplemented> {
494 self.inner.array_function(func, types, args, kwargs)
496 }
497
498 fn as_any(&self) -> &dyn Any {
499 self
500 }
501
502 fn shape(&self) -> &[usize] {
503 self.inner.shape()
504 }
505
506 fn dtype(&self) -> TypeId {
507 self.inner.dtype()
508 }
509
510 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
511 let inner_clone = self.inner.clone();
513 Box::new(Self {
514 inner: inner_clone,
515 phantom: PhantomData::<T>,
516 })
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::NdarrayWrapper;
    use ::ndarray::Array2;

    /// Expression compiled by every test.
    const EXPR: &str = "x + y";

    /// `TypeId` of the 2-D `f64` array wrapper used throughout these tests.
    fn f64_matrix_typeid() -> TypeId {
        TypeId::of::<NdarrayWrapper<f64, crate::ndarray::Ix2>>()
    }

    #[test]
    fn test_jit_function_creation() {
        let factory = LLVMFunctionFactory::new(JITConfig {
            backend: JITBackend::LLVM,
            ..Default::default()
        });

        let compiled = factory
            .create_jit_function(EXPR, f64_matrix_typeid())
            .expect("Operation failed");

        assert_eq!(compiled.source(), EXPR);
        let info = compiled.compile_info();
        let backend = info.get("backend").expect("Operation failed");
        assert_eq!(backend, "LLVM");
    }

    #[test]
    fn test_jit_manager() {
        let mut manager = JITManager::new(JITConfig::default());
        manager.initialize();

        let typeid = f64_matrix_typeid();
        assert!(manager.get_factory_for_array_type(typeid).is_some());

        let compiled = manager.compile(EXPR, typeid).expect("Operation failed");
        assert_eq!(compiled.source(), EXPR);
    }

    #[test]
    fn test_jit_enabled_array() {
        let wrapped = NdarrayWrapper::new(Array2::<f64>::ones((10, 5)));
        let jit_array: JITEnabledArray<f64, _> = JITEnabledArray::new(wrapped);

        // The write guard is a statement temporary, so it is released before the
        // array reads the global manager below.
        JITManager::global()
            .write()
            .expect("Operation failed")
            .initialize();

        assert!(jit_array.supports_jit());

        let compiled = jit_array.compile(EXPR).expect("Operation failed");
        assert_eq!(compiled.source(), EXPR);
    }
}