scirs2_core/array_protocol/
jit_impl.rs1use std::any::{Any, TypeId};
19use std::collections::HashMap;
20use std::fmt::Debug;
21use std::marker::PhantomData;
22use std::sync::{Arc, LazyLock, RwLock};
23
24use crate::array_protocol::{ArrayProtocol, JITArray, JITFunction, JITFunctionFactory};
25use crate::error::{CoreError, CoreResult, ErrorContext};
26
/// Code-generation backend that a JIT function factory targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum JITBackend {
    /// LLVM-based code generation; this is the default backend.
    #[default]
    LLVM,

    /// Cranelift-based code generation.
    Cranelift,

    /// WebAssembly code generation.
    WASM,

    /// A custom backend identified by a `TypeId`
    /// (presumably of the implementing type — confirm with callers).
    Custom(TypeId),
}
48
49#[derive(Debug, Clone)]
51pub struct JITConfig {
52 pub backend: JITBackend,
54
55 pub optimize: bool,
57
58 pub opt_level: usize,
60
61 pub use_cache: bool,
63
64 pub backend_options: HashMap<String, String>,
66}
67
68impl Default for JITConfig {
69 fn default() -> Self {
70 Self {
71 backend: JITBackend::default(),
72 optimize: true,
73 opt_level: 2,
74 use_cache: true,
75 backend_options: HashMap::new(),
76 }
77 }
78}
79
80pub type JITFunctionType = dyn Fn(&[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> + Send + Sync;
82
83pub struct JITFunctionImpl {
85 source: String,
87
88 function: Box<JITFunctionType>,
90
91 compile_info: HashMap<String, String>,
93}
94
95impl Debug for JITFunctionImpl {
96 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
97 f.debug_struct("JITFunctionImpl")
98 .field("source", &self.source)
99 .field("compile_info", &self.compile_info)
100 .finish_non_exhaustive()
101 }
102}
103
104impl JITFunctionImpl {
105 #[must_use]
107 pub fn new(
108 source: String,
109 function: Box<JITFunctionType>,
110 compile_info: HashMap<String, String>,
111 ) -> Self {
112 Self {
113 source,
114 function,
115 compile_info,
116 }
117 }
118}
119
120impl JITFunction for JITFunctionImpl {
121 fn evaluate(&self, args: &[Box<dyn Any>]) -> CoreResult<Box<dyn Any>> {
122 (self.function)(args)
123 }
124
125 fn source(&self) -> String {
126 self.source.clone()
127 }
128
129 fn compile_info(&self) -> HashMap<String, String> {
130 self.compile_info.clone()
131 }
132
133 fn clone_box(&self) -> Box<dyn JITFunction> {
134 let source = self.source.clone();
136 let compile_info = self.compile_info.clone();
137
138 let cloned_function: Box<JITFunctionType> = Box::new(move |_args| {
141 Ok(Box::new(42.0))
143 });
144
145 Box::new(Self {
146 source,
147 function: cloned_function,
148 compile_info,
149 })
150 }
151}
152
153pub struct LLVMFunctionFactory {
155 config: JITConfig,
157
158 cache: HashMap<String, Arc<dyn JITFunction>>,
160}
161
162impl LLVMFunctionFactory {
163 pub fn new(config: JITConfig) -> Self {
165 Self {
166 config,
167 cache: HashMap::new(),
168 }
169 }
170
171 fn compile(&self, expression: &str, array_type_id: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
173 let mut compile_info = HashMap::new();
178 compile_info.insert("backend".to_string(), "LLVM".to_string());
179 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
180 compile_info.insert("array_type".to_string(), format!("{array_type_id:?}"));
181
182 let source = expression.to_string();
185 let function: Box<JITFunctionType> = Box::new(move |_args| {
186 Ok(Box::new(42.0))
188 });
189
190 let jit_function = JITFunctionImpl::new(source, function, compile_info);
192
193 Ok(Arc::new(jit_function))
194 }
195}
196
197impl JITFunctionFactory for LLVMFunctionFactory {
198 fn create(&self, expression: &str, array_type_id: TypeId) -> CoreResult<Box<dyn JITFunction>> {
199 if self.config.use_cache {
201 let cache_key = format!("{expression}-{array_type_id:?}");
202 if let Some(cached_fn) = self.cache.get(&cache_key) {
203 return Ok(cached_fn.as_ref().clone_box());
204 }
205 }
206
207 let jit_function = self.compile(expression, array_type_id)?;
209
210 if self.config.use_cache {
211 let cache_key = format!("{expression}-{array_type_id:?}");
213 let mut cache = self.cache.clone();
216 cache.insert(cache_key, jit_function.clone());
217 }
218
219 Ok(jit_function.as_ref().clone_box())
221 }
222
223 fn supports_array_type(&self, _array_type_id: TypeId) -> bool {
224 true
226 }
227}
228
229pub struct CraneliftFunctionFactory {
231 config: JITConfig,
233
234 cache: HashMap<String, Arc<dyn JITFunction>>,
236}
237
238impl CraneliftFunctionFactory {
239 pub fn new(config: JITConfig) -> Self {
241 Self {
242 config,
243 cache: HashMap::new(),
244 }
245 }
246
247 fn compile(&self, expression: &str, array_type_id: TypeId) -> CoreResult<Arc<dyn JITFunction>> {
249 let mut compile_info = HashMap::new();
254 compile_info.insert("backend".to_string(), "Cranelift".to_string());
255 compile_info.insert("opt_level".to_string(), self.config.opt_level.to_string());
256 compile_info.insert("array_type".to_string(), format!("{array_type_id:?}"));
257
258 let source = expression.to_string();
261 let function: Box<JITFunctionType> = Box::new(move |_args| {
262 Ok(Box::new(42.0))
264 });
265
266 let jit_function = JITFunctionImpl::new(source, function, compile_info);
268
269 Ok(Arc::new(jit_function))
270 }
271}
272
273impl JITFunctionFactory for CraneliftFunctionFactory {
274 fn create(&self, expression: &str, array_type_id: TypeId) -> CoreResult<Box<dyn JITFunction>> {
275 if self.config.use_cache {
277 let cache_key = format!("{expression}-{array_type_id:?}");
278 if let Some(cached_fn) = self.cache.get(&cache_key) {
279 return Ok(cached_fn.as_ref().clone_box());
280 }
281 }
282
283 let jit_function = self.compile(expression, array_type_id)?;
285
286 if self.config.use_cache {
287 let cache_key = format!("{expression}-{array_type_id:?}");
289 let mut cache = self.cache.clone();
292 cache.insert(cache_key, jit_function.clone());
293 }
294
295 Ok(jit_function.as_ref().clone_box())
297 }
298
299 fn supports_array_type(&self, _array_type_id: TypeId) -> bool {
300 true
302 }
303}
304
305pub struct JITManager {
307 factories: Vec<Box<dyn JITFunctionFactory>>,
309
310 default_config: JITConfig,
312}
313
314impl JITManager {
315 pub fn new(default_config: JITConfig) -> Self {
317 Self {
318 factories: Vec::new(),
319 default_config,
320 }
321 }
322
323 pub fn register_factory(&mut self, factory: Box<dyn JITFunctionFactory>) {
325 self.factories.push(factory);
326 }
327
328 pub fn get_factory_for_array_type(
330 &self,
331 array_type_id: TypeId,
332 ) -> Option<&dyn JITFunctionFactory> {
333 for factory in &self.factories {
334 if factory.supports_array_type(array_type_id) {
335 return Some(&**factory);
336 }
337 }
338 None
339 }
340
341 pub fn compile(
343 &self,
344 expression: &str,
345 array_type_id: TypeId,
346 ) -> CoreResult<Box<dyn JITFunction>> {
347 if let Some(factory) = self.get_factory_for_array_type(array_type_id) {
349 factory.create(expression, array_type_id)
350 } else {
351 Err(CoreError::JITError(ErrorContext::new(format!(
352 "No JIT factory supports array type: {:?}",
353 array_type_id
354 ))))
355 }
356 }
357
358 pub fn initialize(&mut self) {
360 let llvm_config = JITConfig {
362 backend: JITBackend::LLVM,
363 ..self.default_config.clone()
364 };
365 let llvm_factory = Box::new(LLVMFunctionFactory::new(llvm_config));
366
367 let cranelift_config = JITConfig {
368 backend: JITBackend::Cranelift,
369 ..self.default_config.clone()
370 };
371 let cranelift_factory = Box::new(CraneliftFunctionFactory::new(cranelift_config));
372
373 self.register_factory(llvm_factory);
374 self.register_factory(cranelift_factory);
375 }
376
377 #[must_use]
379 pub fn global() -> &'static RwLock<Self> {
380 static INSTANCE: LazyLock<RwLock<JITManager>> = LazyLock::new(|| {
381 RwLock::new(JITManager {
382 factories: Vec::new(),
383 default_config: JITConfig {
384 backend: JITBackend::LLVM,
385 optimize: true,
386 opt_level: 2,
387 use_cache: true,
388 backend_options: HashMap::new(),
389 },
390 })
391 });
392 &INSTANCE
393 }
394}
395
/// Wrapper that adds JIT-compilation support to an existing array type `A`.
///
/// `T` is a marker for the element type. It is never stored, only tracked via
/// `PhantomData`.
pub struct JITEnabledArray<T, A> {
    /// The wrapped array that operations are delegated to.
    inner: A,

    /// Zero-sized marker tying the wrapper to element type `T`.
    _phantom: PhantomData<T>,
}

impl<T, A> JITEnabledArray<T, A> {
    /// Wraps `inner` so expressions can be JIT-compiled against its type.
    pub fn new(inner: A) -> Self {
        Self {
            _phantom: PhantomData,
            inner,
        }
    }

    /// Borrows the wrapped array.
    pub const fn inner(&self) -> &A {
        &self.inner
    }
}

/// Implemented by hand so cloning only requires `A: Clone`.
/// A derive would additionally demand `T: Clone`.
impl<T, A: Clone> Clone for JITEnabledArray<T, A> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}
428
429impl<T, A> JITArray for JITEnabledArray<T, A>
430where
431 T: Send + Sync + 'static,
432 A: ArrayProtocol + Clone + Send + Sync + 'static,
433{
434 fn compile(&self, expression: &str) -> CoreResult<Box<dyn JITFunction>> {
435 let jit_manager = JITManager::global();
437 let jit_manager = jit_manager.read().unwrap();
438
439 jit_manager.compile(expression, TypeId::of::<A>())
441 }
442
443 fn supports_jit(&self) -> bool {
444 let jit_manager = JITManager::global();
446 let jit_manager = jit_manager.read().unwrap();
447
448 jit_manager
449 .get_factory_for_array_type(TypeId::of::<A>())
450 .is_some()
451 }
452
453 fn jit_info(&self) -> HashMap<String, String> {
454 let mut info = HashMap::new();
455
456 let supported = self.supports_jit();
458 info.insert("supports_jit".to_string(), supported.to_string());
459
460 if supported {
461 let jit_manager = JITManager::global();
463 let jit_manager = jit_manager.read().unwrap();
464
465 if jit_manager
467 .get_factory_for_array_type(TypeId::of::<A>())
468 .is_some()
469 {
470 info.insert("factory".to_string(), "JIT factory available".to_string());
472 }
473 }
474
475 info
476 }
477}
478
479impl<T, A> ArrayProtocol for JITEnabledArray<T, A>
480where
481 T: Send + Sync + 'static,
482 A: ArrayProtocol + Clone + Send + Sync + 'static,
483{
484 fn array_function(
485 &self,
486 func: &crate::array_protocol::ArrayFunction,
487 types: &[TypeId],
488 args: &[Box<dyn Any>],
489 kwargs: &HashMap<String, Box<dyn Any>>,
490 ) -> Result<Box<dyn Any>, crate::array_protocol::NotImplemented> {
491 self.inner.array_function(func, types, args, kwargs)
493 }
494
495 fn as_any(&self) -> &dyn Any {
496 self
497 }
498
499 fn shape(&self) -> &[usize] {
500 self.inner.shape()
501 }
502
503 fn dtype(&self) -> TypeId {
504 self.inner.dtype()
505 }
506
507 fn box_clone(&self) -> Box<dyn ArrayProtocol> {
508 let inner_clone = self.inner.clone();
510 Box::new(Self {
511 inner: inner_clone,
512 _phantom: PhantomData::<T>,
513 })
514 }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::NdarrayWrapper;
    use ndarray::Array2;

    /// `TypeId` of the 2-D `f64` ndarray wrapper exercised by every test below.
    fn ndarray_f64_2d() -> TypeId {
        TypeId::of::<NdarrayWrapper<f64, ndarray::Ix2>>()
    }

    /// An LLVM factory keeps the source expression and reports its backend.
    #[test]
    fn test_jit_function_creation() {
        let expression = "x + y";
        let factory = LLVMFunctionFactory::new(JITConfig {
            backend: JITBackend::LLVM,
            ..Default::default()
        });

        let function = factory.create(expression, ndarray_f64_2d()).unwrap();
        let info = function.compile_info();

        assert_eq!(function.source(), expression);
        assert_eq!(info.get("backend").unwrap(), "LLVM");
    }

    /// A locally initialized manager finds a factory for the ndarray wrapper and compiles with it.
    #[test]
    fn test_jit_manager() {
        let mut manager = JITManager::new(JITConfig::default());
        manager.initialize();

        let type_id = ndarray_f64_2d();
        assert!(manager.get_factory_for_array_type(type_id).is_some());

        let expression = "x + y";
        let function = manager.compile(expression, type_id).unwrap();
        assert_eq!(function.source(), expression);
    }

    /// A JIT-enabled array compiles through the global manager once it is initialized.
    #[test]
    fn test_jit_enabled_array() {
        let jit_array: JITEnabledArray<f64, _> =
            JITEnabledArray::new(NdarrayWrapper::new(Array2::<f64>::ones((10, 5))));

        // The write guard is a temporary, released at the end of this statement.
        JITManager::global().write().unwrap().initialize();

        assert!(jit_array.supports_jit());

        let expression = "x + y";
        let function = jit_array.compile(expression).unwrap();
        assert_eq!(function.source(), expression);
    }
}