1use crate::domain::Domain;
7use crate::error::{KernelError, Result};
8use crate::kernel::{KernelMetadata, KernelMode};
9use crate::license::{LicenseError, LicenseValidator, SharedLicenseValidator};
10use crate::traits::{BatchKernelDyn, RingKernelDyn};
11use hashbrown::HashMap;
12use std::sync::{Arc, RwLock};
13use tracing::{debug, info, warn};
14
15#[derive(Debug, Clone, Default)]
17pub struct RegistryStats {
18 pub total: usize,
20 pub batch_kernels: usize,
22 pub ring_kernels: usize,
24 pub by_domain: HashMap<Domain, usize>,
26}
27
28#[derive(Clone)]
30pub struct BatchKernelEntry {
31 pub metadata: KernelMetadata,
33 factory: Arc<dyn Fn() -> Arc<dyn BatchKernelDyn> + Send + Sync>,
35}
36
37impl BatchKernelEntry {
38 pub fn new<F>(metadata: KernelMetadata, factory: F) -> Self
40 where
41 F: Fn() -> Arc<dyn BatchKernelDyn> + Send + Sync + 'static,
42 {
43 Self {
44 metadata,
45 factory: Arc::new(factory),
46 }
47 }
48
49 #[must_use]
51 pub fn create(&self) -> Arc<dyn BatchKernelDyn> {
52 (self.factory)()
53 }
54}
55
56impl std::fmt::Debug for BatchKernelEntry {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 f.debug_struct("BatchKernelEntry")
59 .field("metadata", &self.metadata)
60 .finish()
61 }
62}
63
64#[derive(Clone)]
66pub struct RingKernelEntry {
67 pub metadata: KernelMetadata,
69 factory: Arc<dyn Fn() -> Arc<dyn RingKernelDyn> + Send + Sync>,
71}
72
73impl RingKernelEntry {
74 pub fn new<F>(metadata: KernelMetadata, factory: F) -> Self
76 where
77 F: Fn() -> Arc<dyn RingKernelDyn> + Send + Sync + 'static,
78 {
79 Self {
80 metadata,
81 factory: Arc::new(factory),
82 }
83 }
84
85 #[must_use]
87 pub fn create(&self) -> Arc<dyn RingKernelDyn> {
88 (self.factory)()
89 }
90}
91
92impl std::fmt::Debug for RingKernelEntry {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("RingKernelEntry")
95 .field("metadata", &self.metadata)
96 .finish()
97 }
98}
99
100#[derive(Debug)]
102pub struct KernelRegistry {
103 batch_kernels: RwLock<HashMap<String, BatchKernelEntry>>,
105 ring_kernels: RwLock<HashMap<String, RingKernelEntry>>,
107 metadata_only: RwLock<HashMap<String, KernelMetadata>>,
109 license: Option<SharedLicenseValidator>,
111}
112
113impl KernelRegistry {
114 #[must_use]
116 pub fn new() -> Self {
117 Self {
118 batch_kernels: RwLock::new(HashMap::new()),
119 ring_kernels: RwLock::new(HashMap::new()),
120 metadata_only: RwLock::new(HashMap::new()),
121 license: None,
122 }
123 }
124
125 #[must_use]
127 pub fn with_license(license: SharedLicenseValidator) -> Self {
128 Self {
129 batch_kernels: RwLock::new(HashMap::new()),
130 ring_kernels: RwLock::new(HashMap::new()),
131 metadata_only: RwLock::new(HashMap::new()),
132 license: Some(license),
133 }
134 }
135
136 pub fn set_license(&mut self, license: SharedLicenseValidator) {
138 self.license = Some(license);
139 }
140
141 pub fn register_batch(&self, entry: BatchKernelEntry) -> Result<()> {
143 if let Some(ref license) = self.license {
145 self.validate_kernel_license(license.as_ref(), &entry.metadata)?;
146 }
147
148 let id = entry.metadata.id.clone();
149 let mut kernels = self.batch_kernels.write().unwrap();
150
151 if kernels.contains_key(&id) {
152 return Err(KernelError::KernelAlreadyRegistered(id));
153 }
154
155 debug!(kernel_id = %id, domain = %entry.metadata.domain, "Registering batch kernel");
156 kernels.insert(id, entry);
157 Ok(())
158 }
159
160 pub fn register_ring(&self, entry: RingKernelEntry) -> Result<()> {
162 if let Some(ref license) = self.license {
164 self.validate_kernel_license(license.as_ref(), &entry.metadata)?;
165 }
166
167 let id = entry.metadata.id.clone();
168 let mut kernels = self.ring_kernels.write().unwrap();
169
170 if kernels.contains_key(&id) {
171 return Err(KernelError::KernelAlreadyRegistered(id));
172 }
173
174 debug!(kernel_id = %id, domain = %entry.metadata.domain, "Registering ring kernel");
175 kernels.insert(id, entry);
176 Ok(())
177 }
178
179 pub fn register_metadata(&self, metadata: KernelMetadata) -> Result<()> {
185 if let Some(ref license) = self.license {
187 self.validate_kernel_license(license.as_ref(), &metadata)?;
188 }
189
190 let id = metadata.id.clone();
191
192 {
194 let batch = self.batch_kernels.read().unwrap();
195 if batch.contains_key(&id) {
196 return Err(KernelError::KernelAlreadyRegistered(id));
197 }
198 }
199 {
200 let ring = self.ring_kernels.read().unwrap();
201 if ring.contains_key(&id) {
202 return Err(KernelError::KernelAlreadyRegistered(id));
203 }
204 }
205
206 let mut metadata_map = self.metadata_only.write().unwrap();
207 if metadata_map.contains_key(&id) {
208 return Err(KernelError::KernelAlreadyRegistered(id));
209 }
210
211 debug!(kernel_id = %id, domain = %metadata.domain, mode = ?metadata.mode, "Registering kernel metadata");
212 metadata_map.insert(id, metadata);
213 Ok(())
214 }
215
216 fn validate_kernel_license(
218 &self,
219 license: &dyn LicenseValidator,
220 metadata: &KernelMetadata,
221 ) -> Result<()> {
222 license
224 .validate_domain(metadata.domain)
225 .map_err(KernelError::from)?;
226
227 if metadata.requires_gpu_native && !license.gpu_native_enabled() {
229 return Err(KernelError::from(LicenseError::GpuNativeNotLicensed));
230 }
231
232 Ok(())
233 }
234
235 #[must_use]
237 pub fn get_batch(&self, id: &str) -> Option<BatchKernelEntry> {
238 let kernels = self.batch_kernels.read().unwrap();
239 kernels.get(id).cloned()
240 }
241
242 #[must_use]
244 pub fn get_ring(&self, id: &str) -> Option<RingKernelEntry> {
245 let kernels = self.ring_kernels.read().unwrap();
246 kernels.get(id).cloned()
247 }
248
249 #[must_use]
251 pub fn get(&self, id: &str) -> Option<KernelMetadata> {
252 if let Some(entry) = self.get_batch(id) {
253 return Some(entry.metadata);
254 }
255 if let Some(entry) = self.get_ring(id) {
256 return Some(entry.metadata);
257 }
258 let metadata_map = self.metadata_only.read().unwrap();
259 metadata_map.get(id).cloned()
260 }
261
262 #[must_use]
264 pub fn get_metadata_only(&self, id: &str) -> Option<KernelMetadata> {
265 let metadata_map = self.metadata_only.read().unwrap();
266 metadata_map.get(id).cloned()
267 }
268
269 #[must_use]
271 pub fn contains(&self, id: &str) -> bool {
272 let batch = self.batch_kernels.read().unwrap();
273 let ring = self.ring_kernels.read().unwrap();
274 let metadata = self.metadata_only.read().unwrap();
275 batch.contains_key(id) || ring.contains_key(id) || metadata.contains_key(id)
276 }
277
278 #[must_use]
280 pub fn batch_kernel_ids(&self) -> Vec<String> {
281 let kernels = self.batch_kernels.read().unwrap();
282 kernels.keys().cloned().collect()
283 }
284
285 #[must_use]
287 pub fn ring_kernel_ids(&self) -> Vec<String> {
288 let kernels = self.ring_kernels.read().unwrap();
289 kernels.keys().cloned().collect()
290 }
291
292 #[must_use]
294 pub fn metadata_only_ids(&self) -> Vec<String> {
295 let metadata = self.metadata_only.read().unwrap();
296 metadata.keys().cloned().collect()
297 }
298
299 #[must_use]
301 pub fn all_kernel_ids(&self) -> Vec<String> {
302 let mut ids = self.batch_kernel_ids();
303 ids.extend(self.ring_kernel_ids());
304 ids.extend(self.metadata_only_ids());
305 ids
306 }
307
308 #[must_use]
310 pub fn by_domain(&self, domain: Domain) -> Vec<KernelMetadata> {
311 let mut result = Vec::new();
312
313 let batch = self.batch_kernels.read().unwrap();
314 for entry in batch.values() {
315 if entry.metadata.domain == domain {
316 result.push(entry.metadata.clone());
317 }
318 }
319
320 let ring = self.ring_kernels.read().unwrap();
321 for entry in ring.values() {
322 if entry.metadata.domain == domain {
323 result.push(entry.metadata.clone());
324 }
325 }
326
327 let metadata = self.metadata_only.read().unwrap();
328 for entry in metadata.values() {
329 if entry.domain == domain {
330 result.push(entry.clone());
331 }
332 }
333
334 result
335 }
336
337 #[must_use]
339 pub fn by_mode(&self, mode: KernelMode) -> Vec<KernelMetadata> {
340 let mut result: Vec<KernelMetadata> = match mode {
341 KernelMode::Batch => {
342 let kernels = self.batch_kernels.read().unwrap();
343 kernels.values().map(|e| e.metadata.clone()).collect()
344 }
345 KernelMode::Ring => {
346 let kernels = self.ring_kernels.read().unwrap();
347 kernels.values().map(|e| e.metadata.clone()).collect()
348 }
349 };
350
351 let metadata = self.metadata_only.read().unwrap();
353 for entry in metadata.values() {
354 if entry.mode == mode {
355 result.push(entry.clone());
356 }
357 }
358
359 result
360 }
361
362 #[must_use]
364 pub fn stats(&self) -> RegistryStats {
365 let batch = self.batch_kernels.read().unwrap();
366 let ring = self.ring_kernels.read().unwrap();
367 let metadata = self.metadata_only.read().unwrap();
368
369 let mut by_domain: HashMap<Domain, usize> = HashMap::new();
370
371 for entry in batch.values() {
372 *by_domain.entry(entry.metadata.domain).or_default() += 1;
373 }
374
375 for entry in ring.values() {
376 *by_domain.entry(entry.metadata.domain).or_default() += 1;
377 }
378
379 let mut metadata_batch = 0;
381 let mut metadata_ring = 0;
382 for entry in metadata.values() {
383 *by_domain.entry(entry.domain).or_default() += 1;
384 match entry.mode {
385 KernelMode::Batch => metadata_batch += 1,
386 KernelMode::Ring => metadata_ring += 1,
387 }
388 }
389
390 RegistryStats {
391 total: batch.len() + ring.len() + metadata.len(),
392 batch_kernels: batch.len() + metadata_batch,
393 ring_kernels: ring.len() + metadata_ring,
394 by_domain,
395 }
396 }
397
398 #[must_use]
400 pub fn total_count(&self) -> usize {
401 let batch = self.batch_kernels.read().unwrap();
402 let ring = self.ring_kernels.read().unwrap();
403 let metadata = self.metadata_only.read().unwrap();
404 batch.len() + ring.len() + metadata.len()
405 }
406
407 pub fn clear(&self) {
409 let mut batch = self.batch_kernels.write().unwrap();
410 let mut ring = self.ring_kernels.write().unwrap();
411 let mut metadata = self.metadata_only.write().unwrap();
412 batch.clear();
413 ring.clear();
414 metadata.clear();
415 info!("Cleared kernel registry");
416 }
417
418 pub fn unregister(&self, id: &str) -> bool {
420 let mut batch = self.batch_kernels.write().unwrap();
421 if batch.remove(id).is_some() {
422 debug!(kernel_id = %id, "Unregistered batch kernel");
423 return true;
424 }
425
426 let mut ring = self.ring_kernels.write().unwrap();
427 if ring.remove(id).is_some() {
428 debug!(kernel_id = %id, "Unregistered ring kernel");
429 return true;
430 }
431
432 let mut metadata = self.metadata_only.write().unwrap();
433 if metadata.remove(id).is_some() {
434 debug!(kernel_id = %id, "Unregistered metadata-only kernel");
435 return true;
436 }
437
438 warn!(kernel_id = %id, "Attempted to unregister non-existent kernel");
439 false
440 }
441}
442
443impl Default for KernelRegistry {
444 fn default() -> Self {
445 Self::new()
446 }
447}
448
449#[derive(Default)]
451pub struct KernelRegistryBuilder {
452 license: Option<SharedLicenseValidator>,
453 batch_entries: Vec<BatchKernelEntry>,
454 ring_entries: Vec<RingKernelEntry>,
455}
456
457impl KernelRegistryBuilder {
458 #[must_use]
460 pub fn new() -> Self {
461 Self::default()
462 }
463
464 #[must_use]
466 pub fn with_license(mut self, license: SharedLicenseValidator) -> Self {
467 self.license = Some(license);
468 self
469 }
470
471 #[must_use]
473 pub fn with_batch(mut self, entry: BatchKernelEntry) -> Self {
474 self.batch_entries.push(entry);
475 self
476 }
477
478 #[must_use]
480 pub fn with_ring(mut self, entry: RingKernelEntry) -> Self {
481 self.ring_entries.push(entry);
482 self
483 }
484
485 pub fn build(self) -> Result<KernelRegistry> {
491 let registry = match self.license {
492 Some(license) => KernelRegistry::with_license(license),
493 None => KernelRegistry::new(),
494 };
495
496 for entry in self.batch_entries {
497 registry.register_batch(entry)?;
498 }
499
500 for entry in self.ring_entries {
501 registry.register_ring(entry)?;
502 }
503
504 info!(
505 total = registry.total_count(),
506 batch = registry.batch_kernel_ids().len(),
507 ring = registry.ring_kernel_ids().len(),
508 "Built kernel registry"
509 );
510
511 Ok(registry)
512 }
513}
514
515static GLOBAL_REGISTRY: std::sync::OnceLock<KernelRegistry> = std::sync::OnceLock::new();
519
520pub fn global_registry() -> &'static KernelRegistry {
522 GLOBAL_REGISTRY.get_or_init(KernelRegistry::new)
523}
524
525pub fn init_global_registry(license: SharedLicenseValidator) -> &'static KernelRegistry {
533 let registry = KernelRegistry::with_license(license);
534 GLOBAL_REGISTRY
535 .set(registry)
536 .expect("Global registry already initialized");
537 GLOBAL_REGISTRY.get().unwrap()
538}
539
#[cfg(test)]
mod tests {
    use super::*;
    use crate::license::DevelopmentLicense;

    /// Batch entry in `Domain::Core` whose factory must never be called.
    fn test_batch_entry() -> BatchKernelEntry {
        let metadata = KernelMetadata::batch("test-batch", Domain::Core);
        BatchKernelEntry::new(metadata, || {
            panic!("Not implemented for tests")
        })
    }

    /// Ring entry in `Domain::Core` whose factory must never be called.
    fn test_ring_entry() -> RingKernelEntry {
        let metadata = KernelMetadata::ring("test-ring", Domain::Core);
        RingKernelEntry::new(metadata, || {
            panic!("Not implemented for tests")
        })
    }

    #[test]
    fn test_registry_creation() {
        let registry = KernelRegistry::new();
        assert_eq!(registry.total_count(), 0);
    }

    #[test]
    fn test_batch_registration() {
        let registry = KernelRegistry::new();
        let entry = test_batch_entry();

        registry.register_batch(entry).unwrap();
        assert_eq!(registry.total_count(), 1);
        assert!(registry.contains("test-batch"));
        assert!(registry.get_batch("test-batch").is_some());
    }

    #[test]
    fn test_ring_registration() {
        let registry = KernelRegistry::new();
        let entry = test_ring_entry();

        registry.register_ring(entry).unwrap();
        assert_eq!(registry.total_count(), 1);
        assert!(registry.contains("test-ring"));
        assert!(registry.get_ring("test-ring").is_some());
    }

    #[test]
    fn test_duplicate_registration() {
        let registry = KernelRegistry::new();
        let entry1 = test_batch_entry();
        let entry2 = test_batch_entry();

        registry.register_batch(entry1).unwrap();
        let result = registry.register_batch(entry2);
        assert!(result.is_err());
    }

    #[test]
    fn test_by_domain() {
        let registry = KernelRegistry::new();

        let core_entry = test_batch_entry();
        registry.register_batch(core_entry).unwrap();

        let graph_entry = BatchKernelEntry::new(
            KernelMetadata::batch("test-graph", Domain::GraphAnalytics),
            || panic!("Not implemented"),
        );
        registry.register_batch(graph_entry).unwrap();

        let core_kernels = registry.by_domain(Domain::Core);
        assert_eq!(core_kernels.len(), 1);

        let graph_kernels = registry.by_domain(Domain::GraphAnalytics);
        assert_eq!(graph_kernels.len(), 1);
    }

    #[test]
    fn test_stats() {
        let registry = KernelRegistry::new();

        registry.register_batch(test_batch_entry()).unwrap();
        registry.register_ring(test_ring_entry()).unwrap();

        let stats = registry.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.batch_kernels, 1);
        assert_eq!(stats.ring_kernels, 1);
        assert_eq!(stats.by_domain.get(&Domain::Core), Some(&2));
    }

    #[test]
    fn test_unregister() {
        let registry = KernelRegistry::new();
        registry.register_batch(test_batch_entry()).unwrap();

        assert!(registry.contains("test-batch"));
        assert!(registry.unregister("test-batch"));
        assert!(!registry.contains("test-batch"));
        assert!(!registry.unregister("test-batch"));
    }

    #[test]
    fn test_with_license() {
        let license: SharedLicenseValidator = Arc::new(DevelopmentLicense);
        let registry = KernelRegistry::with_license(license);

        registry.register_batch(test_batch_entry()).unwrap();
        registry.register_ring(test_ring_entry()).unwrap();
    }

    #[test]
    fn test_builder() {
        let registry = KernelRegistryBuilder::new()
            .with_batch(test_batch_entry())
            .with_ring(test_ring_entry())
            .build()
            .unwrap();

        assert_eq!(registry.total_count(), 2);
    }

    #[test]
    fn test_metadata_only_registration() {
        let registry = KernelRegistry::new();
        let metadata = KernelMetadata::batch("test-metadata", Domain::GraphAnalytics);

        registry.register_metadata(metadata).unwrap();

        assert_eq!(registry.total_count(), 1);
        assert!(registry.contains("test-metadata"));
        assert!(registry.get("test-metadata").is_some());
        assert!(registry.get_metadata_only("test-metadata").is_some());
        assert!(registry.get_batch("test-metadata").is_none());
    }

    #[test]
    fn test_metadata_only_duplicate() {
        let registry = KernelRegistry::new();
        let metadata1 = KernelMetadata::batch("test-dup", Domain::Core);
        let metadata2 = KernelMetadata::batch("test-dup", Domain::Core);

        registry.register_metadata(metadata1).unwrap();
        let result = registry.register_metadata(metadata2);
        assert!(result.is_err());
    }

    #[test]
    fn test_metadata_only_conflict_with_batch() {
        let registry = KernelRegistry::new();
        registry.register_batch(test_batch_entry()).unwrap();

        let metadata = KernelMetadata::batch("test-batch", Domain::Core);
        let result = registry.register_metadata(metadata);
        assert!(result.is_err());
    }

    #[test]
    fn test_metadata_only_conflict_with_ring() {
        let registry = KernelRegistry::new();
        registry.register_ring(test_ring_entry()).unwrap();

        let metadata = KernelMetadata::ring("test-ring", Domain::Core);
        assert!(registry.register_metadata(metadata).is_err());
        // The rejected registration must not have been stored anywhere.
        assert_eq!(registry.total_count(), 1);
    }

    #[test]
    fn test_metadata_only_in_stats() {
        let registry = KernelRegistry::new();

        let batch_meta = KernelMetadata::batch("meta-batch", Domain::GraphAnalytics);
        let ring_meta = KernelMetadata::ring("meta-ring", Domain::GraphAnalytics);
        registry.register_metadata(batch_meta).unwrap();
        registry.register_metadata(ring_meta).unwrap();

        let stats = registry.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.batch_kernels, 1);
        assert_eq!(stats.ring_kernels, 1);
        assert_eq!(stats.by_domain.get(&Domain::GraphAnalytics), Some(&2));
    }

    #[test]
    fn test_metadata_only_by_domain() {
        let registry = KernelRegistry::new();

        let graph_meta = KernelMetadata::batch("graph-kernel", Domain::GraphAnalytics);
        let ml_meta = KernelMetadata::batch("ml-kernel", Domain::StatisticalML);
        registry.register_metadata(graph_meta).unwrap();
        registry.register_metadata(ml_meta).unwrap();

        let graph_kernels = registry.by_domain(Domain::GraphAnalytics);
        assert_eq!(graph_kernels.len(), 1);
        assert_eq!(graph_kernels[0].id, "graph-kernel");

        let ml_kernels = registry.by_domain(Domain::StatisticalML);
        assert_eq!(ml_kernels.len(), 1);
        assert_eq!(ml_kernels[0].id, "ml-kernel");
    }

    #[test]
    fn test_unregister_metadata_only() {
        let registry = KernelRegistry::new();
        let metadata = KernelMetadata::batch("test-unreg", Domain::Core);

        registry.register_metadata(metadata).unwrap();
        assert!(registry.contains("test-unreg"));

        assert!(registry.unregister("test-unreg"));
        assert!(!registry.contains("test-unreg"));
    }

    #[test]
    fn test_get_searches_all_kinds() {
        let registry = KernelRegistry::new();
        registry.register_batch(test_batch_entry()).unwrap();
        registry.register_ring(test_ring_entry()).unwrap();
        registry
            .register_metadata(KernelMetadata::batch("meta-only", Domain::GraphAnalytics))
            .unwrap();

        assert_eq!(
            registry.get("test-batch").map(|m| m.id),
            Some("test-batch".to_string())
        );
        assert_eq!(
            registry.get("test-ring").map(|m| m.id),
            Some("test-ring".to_string())
        );
        assert_eq!(
            registry.get("meta-only").map(|m| m.domain),
            Some(Domain::GraphAnalytics)
        );
        assert!(registry.get("missing").is_none());
    }

    #[test]
    fn test_by_mode_includes_metadata_only() {
        let registry = KernelRegistry::new();
        registry.register_batch(test_batch_entry()).unwrap();
        registry.register_ring(test_ring_entry()).unwrap();
        registry
            .register_metadata(KernelMetadata::ring("meta-ring", Domain::Core))
            .unwrap();

        assert_eq!(registry.by_mode(KernelMode::Batch).len(), 1);
        assert_eq!(registry.by_mode(KernelMode::Ring).len(), 2);
    }

    #[test]
    fn test_all_kernel_ids_and_clear() {
        let registry = KernelRegistry::new();
        registry.register_batch(test_batch_entry()).unwrap();
        registry.register_ring(test_ring_entry()).unwrap();
        registry
            .register_metadata(KernelMetadata::batch("meta-batch", Domain::Core))
            .unwrap();

        let mut ids = registry.all_kernel_ids();
        ids.sort();
        assert_eq!(ids, vec!["meta-batch", "test-batch", "test-ring"]);

        registry.clear();
        assert_eq!(registry.total_count(), 0);
        assert!(registry.all_kernel_ids().is_empty());
    }
}