1use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DrugInteractionConfig {
18 pub max_order: usize,
20 pub min_confidence: f64,
22 pub include_known: bool,
24 pub severity_filter: Vec<Severity>,
26}
27
28impl Default for DrugInteractionConfig {
29 fn default() -> Self {
30 Self {
31 max_order: 3,
32 min_confidence: 0.5,
33 include_known: true,
34 severity_filter: vec![Severity::Major, Severity::Moderate, Severity::Minor],
35 }
36 }
37}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
41pub enum Severity {
42 Major,
44 Moderate,
46 Minor,
48 Minimal,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Drug {
55 pub id: String,
57 pub name: String,
59 pub drug_class: Option<String>,
61 pub moa_features: Vec<f64>,
63 pub targets: Vec<String>,
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct KnownInteraction {
70 pub drug_ids: Vec<String>,
72 pub severity: Severity,
74 pub description: String,
76 pub recommendation: String,
78}
79
80#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct InteractionKnowledgeBase {
83 pub pairwise: HashMap<(String, String), KnownInteraction>,
85 pub higher_order: HashMap<Vec<String>, KnownInteraction>,
87 pub class_interactions: HashMap<(String, String), Severity>,
89}
90
91impl InteractionKnowledgeBase {
92 pub fn get_known_interaction(&self, drug_ids: &[String]) -> Option<&KnownInteraction> {
94 if drug_ids.len() == 2 {
95 let key = Self::normalize_pair(&drug_ids[0], &drug_ids[1]);
96 self.pairwise.get(&key)
97 } else {
98 let mut sorted = drug_ids.to_vec();
99 sorted.sort();
100 self.higher_order.get(&sorted)
101 }
102 }
103
104 fn normalize_pair(a: &str, b: &str) -> (String, String) {
105 if a < b {
106 (a.to_string(), b.to_string())
107 } else {
108 (b.to_string(), a.to_string())
109 }
110 }
111
112 pub fn add_interaction(&mut self, interaction: KnownInteraction) {
114 if interaction.drug_ids.len() == 2 {
115 let key = Self::normalize_pair(&interaction.drug_ids[0], &interaction.drug_ids[1]);
116 self.pairwise.insert(key, interaction);
117 } else {
118 let mut sorted = interaction.drug_ids.clone();
119 sorted.sort();
120 self.higher_order.insert(sorted, interaction);
121 }
122 }
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct PredictedInteraction {
128 pub drug_ids: Vec<String>,
130 pub drug_names: Vec<String>,
132 pub severity: Severity,
134 pub confidence: f64,
136 pub is_known: bool,
138 pub mechanism: Option<String>,
140 pub risk_score: f64,
142}
143
144#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct DrugInteractionResult {
147 pub interactions: Vec<PredictedInteraction>,
149 pub high_risk_combinations: Vec<Vec<String>>,
151 pub polypharmacy_risk: f64,
153 pub recommendations: Vec<String>,
155}
156
157#[derive(Debug, Clone)]
163pub struct DrugInteractionPrediction {
164 metadata: KernelMetadata,
165}
166
167impl Default for DrugInteractionPrediction {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl DrugInteractionPrediction {
174 #[must_use]
176 pub fn new() -> Self {
177 Self {
178 metadata: KernelMetadata::batch(
179 "ml/drug-interaction-prediction",
180 Domain::StatisticalML,
181 )
182 .with_description("Multi-drug interaction prediction using mechanism features")
183 .with_throughput(5_000)
184 .with_latency_us(200.0),
185 }
186 }
187
188 pub fn predict(
190 drugs: &[Drug],
191 knowledge_base: &InteractionKnowledgeBase,
192 config: &DrugInteractionConfig,
193 ) -> DrugInteractionResult {
194 if drugs.is_empty() {
195 return DrugInteractionResult {
196 interactions: Vec::new(),
197 high_risk_combinations: Vec::new(),
198 polypharmacy_risk: 0.0,
199 recommendations: Vec::new(),
200 };
201 }
202
203 let mut interactions = Vec::new();
204 let mut high_risk = Vec::new();
205
206 for i in 0..drugs.len() {
208 for j in (i + 1)..drugs.len() {
209 let drug_ids = vec![drugs[i].id.clone(), drugs[j].id.clone()];
210 let drug_names = vec![drugs[i].name.clone(), drugs[j].name.clone()];
211
212 if let Some(known) = knowledge_base.get_known_interaction(&drug_ids) {
214 if config.include_known && config.severity_filter.contains(&known.severity) {
215 interactions.push(PredictedInteraction {
216 drug_ids: drug_ids.clone(),
217 drug_names: drug_names.clone(),
218 severity: known.severity,
219 confidence: 1.0,
220 is_known: true,
221 mechanism: Some(known.description.clone()),
222 risk_score: Self::severity_to_risk(known.severity),
223 });
224
225 if known.severity == Severity::Major {
226 high_risk.push(drug_ids.clone());
227 }
228 }
229 } else {
230 let (severity, confidence) =
232 Self::predict_pairwise(&drugs[i], &drugs[j], knowledge_base);
233
234 if confidence >= config.min_confidence
235 && config.severity_filter.contains(&severity)
236 {
237 let risk = Self::severity_to_risk(severity) * confidence;
238
239 interactions.push(PredictedInteraction {
240 drug_ids: drug_ids.clone(),
241 drug_names,
242 severity,
243 confidence,
244 is_known: false,
245 mechanism: Self::predict_mechanism(&drugs[i], &drugs[j]),
246 risk_score: risk,
247 });
248
249 if severity == Severity::Major && confidence > 0.7 {
250 high_risk.push(drug_ids);
251 }
252 }
253 }
254 }
255 }
256
257 if config.max_order >= 3 && drugs.len() >= 3 {
259 for i in 0..drugs.len() {
260 for j in (i + 1)..drugs.len() {
261 for k in (j + 1)..drugs.len() {
262 let drug_ids = vec![
263 drugs[i].id.clone(),
264 drugs[j].id.clone(),
265 drugs[k].id.clone(),
266 ];
267
268 let (severity, confidence) =
269 Self::predict_triplet(&drugs[i], &drugs[j], &drugs[k], knowledge_base);
270
271 if confidence >= config.min_confidence {
272 interactions.push(PredictedInteraction {
273 drug_ids: drug_ids.clone(),
274 drug_names: vec![
275 drugs[i].name.clone(),
276 drugs[j].name.clone(),
277 drugs[k].name.clone(),
278 ],
279 severity,
280 confidence,
281 is_known: false,
282 mechanism: Some("Complex multi-drug interaction".to_string()),
283 risk_score: Self::severity_to_risk(severity) * confidence,
284 });
285 }
286 }
287 }
288 }
289 }
290
291 let polypharmacy_risk = Self::calculate_polypharmacy_risk(drugs.len(), &interactions);
293
294 let recommendations = Self::generate_recommendations(&interactions, &high_risk);
296
297 interactions.sort_by(|a, b| {
299 b.risk_score
300 .partial_cmp(&a.risk_score)
301 .unwrap_or(std::cmp::Ordering::Equal)
302 });
303
304 DrugInteractionResult {
305 interactions,
306 high_risk_combinations: high_risk,
307 polypharmacy_risk,
308 recommendations,
309 }
310 }
311
312 fn predict_pairwise(
314 drug_a: &Drug,
315 drug_b: &Drug,
316 kb: &InteractionKnowledgeBase,
317 ) -> (Severity, f64) {
318 if let (Some(class_a), Some(class_b)) = (&drug_a.drug_class, &drug_b.drug_class) {
320 let key = if class_a < class_b {
321 (class_a.clone(), class_b.clone())
322 } else {
323 (class_b.clone(), class_a.clone())
324 };
325
326 if let Some(&severity) = kb.class_interactions.get(&key) {
327 return (severity, 0.8);
328 }
329 }
330
331 let moa_similarity = Self::cosine_similarity(&drug_a.moa_features, &drug_b.moa_features);
333
334 let target_overlap = Self::jaccard_similarity(&drug_a.targets, &drug_b.targets);
336
337 let risk_score = target_overlap * (1.0 - moa_similarity) + moa_similarity * 0.3;
339
340 let (severity, confidence) = if risk_score > 0.7 {
341 (Severity::Major, risk_score)
342 } else if risk_score > 0.5 {
343 (Severity::Moderate, risk_score)
344 } else if risk_score > 0.3 {
345 (Severity::Minor, risk_score)
346 } else {
347 (Severity::Minimal, risk_score)
348 };
349
350 (severity, confidence)
351 }
352
353 fn predict_triplet(
355 drug_a: &Drug,
356 drug_b: &Drug,
357 drug_c: &Drug,
358 _kb: &InteractionKnowledgeBase,
359 ) -> (Severity, f64) {
360 let sim_ab = Self::cosine_similarity(&drug_a.moa_features, &drug_b.moa_features);
362 let sim_bc = Self::cosine_similarity(&drug_b.moa_features, &drug_c.moa_features);
363 let sim_ac = Self::cosine_similarity(&drug_a.moa_features, &drug_c.moa_features);
364
365 let avg_sim = (sim_ab + sim_bc + sim_ac) / 3.0;
367
368 let all_targets: HashSet<_> = drug_a
370 .targets
371 .iter()
372 .chain(drug_b.targets.iter())
373 .chain(drug_c.targets.iter())
374 .collect();
375
376 let unique_targets = all_targets.len();
377 let total_targets = drug_a.targets.len() + drug_b.targets.len() + drug_c.targets.len();
378
379 let overlap_ratio = if total_targets > 0 {
380 1.0 - (unique_targets as f64 / total_targets as f64)
381 } else {
382 0.0
383 };
384
385 let risk_score = avg_sim * 0.4 + overlap_ratio * 0.6;
386 let confidence = risk_score * 0.7; let severity = if risk_score > 0.6 {
389 Severity::Major
390 } else if risk_score > 0.4 {
391 Severity::Moderate
392 } else {
393 Severity::Minor
394 };
395
396 (severity, confidence)
397 }
398
399 fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
400 if a.is_empty() || b.is_empty() || a.len() != b.len() {
401 return 0.0;
402 }
403
404 let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
405 let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
406 let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
407
408 if norm_a < 1e-10 || norm_b < 1e-10 {
409 0.0
410 } else {
411 dot / (norm_a * norm_b)
412 }
413 }
414
415 fn jaccard_similarity(a: &[String], b: &[String]) -> f64 {
416 if a.is_empty() && b.is_empty() {
417 return 0.0;
418 }
419
420 let set_a: HashSet<_> = a.iter().collect();
421 let set_b: HashSet<_> = b.iter().collect();
422
423 let intersection = set_a.intersection(&set_b).count();
424 let union = set_a.union(&set_b).count();
425
426 if union == 0 {
427 0.0
428 } else {
429 intersection as f64 / union as f64
430 }
431 }
432
433 fn severity_to_risk(severity: Severity) -> f64 {
434 match severity {
435 Severity::Major => 1.0,
436 Severity::Moderate => 0.6,
437 Severity::Minor => 0.3,
438 Severity::Minimal => 0.1,
439 }
440 }
441
442 fn predict_mechanism(drug_a: &Drug, drug_b: &Drug) -> Option<String> {
443 let target_overlap = Self::jaccard_similarity(&drug_a.targets, &drug_b.targets);
444
445 if target_overlap > 0.5 {
446 Some("Pharmacodynamic: competing for same targets".to_string())
447 } else if target_overlap > 0.2 {
448 Some("Pharmacodynamic: overlapping target pathways".to_string())
449 } else {
450 Some("Pharmacokinetic: potential metabolic interaction".to_string())
451 }
452 }
453
454 fn calculate_polypharmacy_risk(
455 drug_count: usize,
456 interactions: &[PredictedInteraction],
457 ) -> f64 {
458 let count_risk = (drug_count as f64 - 1.0).max(0.0) * 0.1;
460
461 let interaction_risk: f64 = interactions
463 .iter()
464 .map(|i| i.risk_score * i.confidence)
465 .sum::<f64>()
466 / interactions.len().max(1) as f64;
467
468 (count_risk + interaction_risk).min(1.0)
469 }
470
471 fn generate_recommendations(
472 interactions: &[PredictedInteraction],
473 high_risk: &[Vec<String>],
474 ) -> Vec<String> {
475 let mut recommendations = Vec::new();
476
477 if !high_risk.is_empty() {
478 recommendations.push(format!(
479 "ALERT: {} high-risk drug combinations detected. Consider alternatives.",
480 high_risk.len()
481 ));
482 }
483
484 let major_count = interactions
485 .iter()
486 .filter(|i| i.severity == Severity::Major)
487 .count();
488 if major_count > 0 {
489 recommendations.push(format!(
490 "Review {} major interactions before prescribing.",
491 major_count
492 ));
493 }
494
495 if interactions.len() > 5 {
496 recommendations
497 .push("Consider medication review to reduce polypharmacy risk.".to_string());
498 }
499
500 recommendations
501 }
502}
503
504impl GpuKernel for DrugInteractionPrediction {
505 fn metadata(&self) -> &KernelMetadata {
506 &self.metadata
507 }
508}
509
510#[derive(Debug, Clone, Serialize, Deserialize)]
516pub struct PathwayConformanceConfig {
517 pub strictness: ConformanceStrictness,
519 pub allow_documented_deviations: bool,
521 pub time_tolerance_hours: f64,
523 pub required_only: bool,
525}
526
527impl Default for PathwayConformanceConfig {
528 fn default() -> Self {
529 Self {
530 strictness: ConformanceStrictness::Standard,
531 allow_documented_deviations: true,
532 time_tolerance_hours: 24.0,
533 required_only: false,
534 }
535 }
536}
537
538#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
540pub enum ConformanceStrictness {
541 Relaxed,
543 Standard,
545 Strict,
547}
548
549#[derive(Debug, Clone, Serialize, Deserialize)]
551pub struct PathwayStep {
552 pub id: String,
554 pub name: String,
556 pub required: bool,
558 pub expected_timing: Option<f64>,
560 pub dependencies: Vec<String>,
562 pub category: StepCategory,
564}
565
566#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
568pub enum StepCategory {
569 Diagnostic,
571 Treatment,
573 Medication,
575 Monitoring,
577 Consultation,
579 Administrative,
581}
582
583#[derive(Debug, Clone, Serialize, Deserialize)]
585pub struct ClinicalPathway {
586 pub id: String,
588 pub name: String,
590 pub condition: String,
592 pub steps: Vec<PathwayStep>,
594 pub expected_duration_hours: f64,
596}
597
598#[derive(Debug, Clone, Serialize, Deserialize)]
600pub struct CareEvent {
601 pub id: String,
603 pub step_id: Option<String>,
605 pub description: String,
607 pub timestamp_hours: f64,
609 pub category: StepCategory,
611 pub deviation_reason: Option<String>,
613}
614
615#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct PathwayDeviation {
618 pub step_id: String,
620 pub deviation_type: DeviationType,
622 pub severity: DeviationSeverity,
624 pub description: String,
626 pub reason_documented: bool,
628}
629
630#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
632pub enum DeviationType {
633 MissedStep,
635 OrderViolation,
637 TimingDeviation,
639 ExtraStep,
641 DuplicateStep,
643}
644
645#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
647pub enum DeviationSeverity {
648 Critical,
650 Major,
652 Minor,
654 Info,
656}
657
658#[derive(Debug, Clone, Serialize, Deserialize)]
660pub struct ConformanceResult {
661 pub conformance_score: f64,
663 pub is_conformant: bool,
665 pub deviations: Vec<PathwayDeviation>,
667 pub matched_steps: Vec<String>,
669 pub missing_steps: Vec<String>,
671 pub extra_events: Vec<String>,
673 pub completion_percentage: f64,
675}
676
677#[derive(Debug, Clone)]
683pub struct ClinicalPathwayConformance {
684 metadata: KernelMetadata,
685}
686
687impl Default for ClinicalPathwayConformance {
688 fn default() -> Self {
689 Self::new()
690 }
691}
692
693impl ClinicalPathwayConformance {
694 #[must_use]
696 pub fn new() -> Self {
697 Self {
698 metadata: KernelMetadata::batch(
699 "ml/clinical-pathway-conformance",
700 Domain::StatisticalML,
701 )
702 .with_description("Clinical guideline and pathway conformance checking")
703 .with_throughput(10_000)
704 .with_latency_us(50.0),
705 }
706 }
707
708 pub fn check_conformance(
710 pathway: &ClinicalPathway,
711 events: &[CareEvent],
712 config: &PathwayConformanceConfig,
713 ) -> ConformanceResult {
714 if pathway.steps.is_empty() {
715 return ConformanceResult {
716 conformance_score: 1.0,
717 is_conformant: true,
718 deviations: Vec::new(),
719 matched_steps: Vec::new(),
720 missing_steps: Vec::new(),
721 extra_events: Vec::new(),
722 completion_percentage: 100.0,
723 };
724 }
725
726 let mut deviations = Vec::new();
727 let mut matched_steps = Vec::new();
728 let mut matched_event_ids: HashSet<String> = HashSet::new();
729
730 for step in &pathway.steps {
732 if config.required_only && !step.required {
733 continue;
734 }
735
736 let matching_events: Vec<&CareEvent> = events
737 .iter()
738 .filter(|e| {
739 e.step_id.as_ref() == Some(&step.id)
740 || (e.category == step.category
741 && e.description
742 .to_lowercase()
743 .contains(&step.name.to_lowercase()))
744 })
745 .collect();
746
747 if matching_events.is_empty() {
748 if step.required {
749 deviations.push(PathwayDeviation {
750 step_id: step.id.clone(),
751 deviation_type: DeviationType::MissedStep,
752 severity: DeviationSeverity::Major,
753 description: format!("Required step '{}' was not completed", step.name),
754 reason_documented: false,
755 });
756 }
757 } else {
758 matched_steps.push(step.id.clone());
759 matched_event_ids.insert(matching_events[0].id.clone());
760
761 if let Some(expected_time) = step.expected_timing {
763 let actual_time = matching_events[0].timestamp_hours;
764 let time_diff = (actual_time - expected_time).abs();
765
766 if time_diff > config.time_tolerance_hours {
767 let severity = if time_diff > config.time_tolerance_hours * 2.0 {
768 DeviationSeverity::Major
769 } else {
770 DeviationSeverity::Minor
771 };
772
773 deviations.push(PathwayDeviation {
774 step_id: step.id.clone(),
775 deviation_type: DeviationType::TimingDeviation,
776 severity,
777 description: format!(
778 "Step '{}' timing deviation: expected {}h, actual {}h",
779 step.name, expected_time, actual_time
780 ),
781 reason_documented: matching_events[0].deviation_reason.is_some(),
782 });
783 }
784 }
785
786 if matching_events.len() > 1 {
788 deviations.push(PathwayDeviation {
789 step_id: step.id.clone(),
790 deviation_type: DeviationType::DuplicateStep,
791 severity: DeviationSeverity::Info,
792 description: format!(
793 "Step '{}' completed {} times",
794 step.name,
795 matching_events.len()
796 ),
797 reason_documented: true,
798 });
799 }
800 }
801 }
802
803 for step in &pathway.steps {
805 if !matched_steps.contains(&step.id) {
806 continue;
807 }
808
809 for dep_id in &step.dependencies {
810 if !matched_steps.contains(dep_id) {
811 deviations.push(PathwayDeviation {
812 step_id: step.id.clone(),
813 deviation_type: DeviationType::OrderViolation,
814 severity: DeviationSeverity::Major,
815 description: format!(
816 "Step '{}' completed before dependency '{}'",
817 step.name, dep_id
818 ),
819 reason_documented: false,
820 });
821 }
822 }
823 }
824
825 let extra_events: Vec<String> = events
827 .iter()
828 .filter(|e| !matched_event_ids.contains(&e.id))
829 .map(|e| e.id.clone())
830 .collect();
831
832 let required_steps: Vec<_> = pathway
834 .steps
835 .iter()
836 .filter(|s| s.required)
837 .map(|s| s.id.clone())
838 .collect();
839
840 let missing_steps: Vec<String> = required_steps
841 .iter()
842 .filter(|s| !matched_steps.contains(s))
843 .cloned()
844 .collect();
845
846 if config.allow_documented_deviations {
848 deviations
849 .retain(|d| !(d.reason_documented && d.severity != DeviationSeverity::Critical));
850 }
851
852 let completion_percentage = if required_steps.is_empty() {
854 100.0
855 } else {
856 (matched_steps.len() as f64 / required_steps.len() as f64) * 100.0
857 };
858
859 let deviation_penalty: f64 = deviations
860 .iter()
861 .map(|d| match d.severity {
862 DeviationSeverity::Critical => 0.4,
863 DeviationSeverity::Major => 0.2,
864 DeviationSeverity::Minor => 0.05,
865 DeviationSeverity::Info => 0.0,
866 })
867 .sum();
868
869 let conformance_score =
870 (1.0 - deviation_penalty).max(0.0) * (completion_percentage / 100.0);
871
872 let is_conformant = match config.strictness {
873 ConformanceStrictness::Relaxed => conformance_score >= 0.7,
874 ConformanceStrictness::Standard => {
875 conformance_score >= 0.85 && missing_steps.is_empty()
876 }
877 ConformanceStrictness::Strict => {
878 conformance_score >= 0.95
879 && missing_steps.is_empty()
880 && deviations
881 .iter()
882 .all(|d| d.severity == DeviationSeverity::Info)
883 }
884 };
885
886 ConformanceResult {
887 conformance_score,
888 is_conformant,
889 deviations,
890 matched_steps,
891 missing_steps,
892 extra_events,
893 completion_percentage,
894 }
895 }
896
897 pub fn check_batch(
899 pathway: &ClinicalPathway,
900 patient_events: &[Vec<CareEvent>],
901 config: &PathwayConformanceConfig,
902 ) -> Vec<ConformanceResult> {
903 patient_events
904 .iter()
905 .map(|events| Self::check_conformance(pathway, events, config))
906 .collect()
907 }
908
909 pub fn aggregate_stats(results: &[ConformanceResult]) -> PathwayStatistics {
911 if results.is_empty() {
912 return PathwayStatistics::default();
913 }
914
915 let n = results.len() as f64;
916
917 let avg_conformance = results.iter().map(|r| r.conformance_score).sum::<f64>() / n;
918 let conformant_count = results.iter().filter(|r| r.is_conformant).count();
919 let avg_completion = results.iter().map(|r| r.completion_percentage).sum::<f64>() / n;
920
921 let mut deviation_counts: HashMap<DeviationType, usize> = HashMap::new();
923 for result in results {
924 for dev in &result.deviations {
925 *deviation_counts.entry(dev.deviation_type).or_insert(0) += 1;
926 }
927 }
928
929 PathwayStatistics {
930 total_patients: results.len(),
931 conformant_patients: conformant_count,
932 conformance_rate: conformant_count as f64 / n,
933 average_conformance_score: avg_conformance,
934 average_completion: avg_completion,
935 deviation_counts,
936 }
937 }
938}
939
940#[derive(Debug, Clone, Default, Serialize, Deserialize)]
942pub struct PathwayStatistics {
943 pub total_patients: usize,
945 pub conformant_patients: usize,
947 pub conformance_rate: f64,
949 pub average_conformance_score: f64,
951 pub average_completion: f64,
953 pub deviation_counts: HashMap<DeviationType, usize>,
955}
956
957impl GpuKernel for ClinicalPathwayConformance {
958 fn metadata(&self) -> &KernelMetadata {
959 &self.metadata
960 }
961}
962
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `Drug` fixture with its class set.
    fn drug(id: &str, name: &str, class: &str, moa_features: Vec<f64>, targets: &[&str]) -> Drug {
        Drug {
            id: id.to_string(),
            name: name.to_string(),
            drug_class: Some(class.to_string()),
            moa_features,
            targets: targets.iter().map(|t| t.to_string()).collect(),
        }
    }

    /// Builds a required `PathwayStep` fixture.
    fn required_step(
        id: &str,
        name: &str,
        expected_timing: Option<f64>,
        dependencies: &[&str],
        category: StepCategory,
    ) -> PathwayStep {
        PathwayStep {
            id: id.to_string(),
            name: name.to_string(),
            required: true,
            expected_timing,
            dependencies: dependencies.iter().map(|d| d.to_string()).collect(),
            category,
        }
    }

    /// Builds a `CareEvent` fixture linked to a step by id, with no deviation reason.
    fn linked_event(
        id: &str,
        step_id: &str,
        description: &str,
        timestamp_hours: f64,
        category: StepCategory,
    ) -> CareEvent {
        CareEvent {
            id: id.to_string(),
            step_id: Some(step_id.to_string()),
            description: description.to_string(),
            timestamp_hours,
            category,
            deviation_reason: None,
        }
    }

    /// Builds a `ClinicalPathway` fixture.
    fn pathway(
        id: &str,
        name: &str,
        condition: &str,
        steps: Vec<PathwayStep>,
        expected_duration_hours: f64,
    ) -> ClinicalPathway {
        ClinicalPathway {
            id: id.to_string(),
            name: name.to_string(),
            condition: condition.to_string(),
            steps,
            expected_duration_hours,
        }
    }

    #[test]
    fn test_drug_interaction_metadata() {
        let kernel = DrugInteractionPrediction::new();
        assert_eq!(kernel.metadata().id, "ml/drug-interaction-prediction");
    }

    #[test]
    fn test_drug_interaction_basic() {
        // Same targets, orthogonal feature vectors, no knowledge-base entries.
        let drugs = vec![
            drug("drug1", "Aspirin", "NSAID", vec![1.0, 0.0, 0.0], &["COX1", "COX2"]),
            drug("drug2", "Ibuprofen", "NSAID", vec![0.0, 1.0, 0.0], &["COX1", "COX2"]),
        ];

        let kb = InteractionKnowledgeBase::default();
        let config = DrugInteractionConfig::default();

        let result = DrugInteractionPrediction::predict(&drugs, &kb, &config);

        assert!(!result.interactions.is_empty());
    }

    #[test]
    fn test_known_interaction() {
        let drugs = vec![
            drug("warfarin", "Warfarin", "Anticoagulant", vec![], &[]),
            drug("aspirin", "Aspirin", "NSAID", vec![], &[]),
        ];

        let mut kb = InteractionKnowledgeBase::default();
        kb.add_interaction(KnownInteraction {
            drug_ids: vec!["warfarin".to_string(), "aspirin".to_string()],
            severity: Severity::Major,
            description: "Increased bleeding risk".to_string(),
            recommendation: "Avoid combination".to_string(),
        });

        let config = DrugInteractionConfig::default();
        let result = DrugInteractionPrediction::predict(&drugs, &kb, &config);

        assert!(
            result
                .interactions
                .iter()
                .any(|i| i.is_known && i.severity == Severity::Major)
        );
    }

    #[test]
    fn test_empty_drugs() {
        let kb = InteractionKnowledgeBase::default();
        let config = DrugInteractionConfig::default();
        let result = DrugInteractionPrediction::predict(&[], &kb, &config);
        assert!(result.interactions.is_empty());
    }

    #[test]
    fn test_pathway_conformance_metadata() {
        let kernel = ClinicalPathwayConformance::new();
        assert_eq!(kernel.metadata().id, "ml/clinical-pathway-conformance");
    }

    #[test]
    fn test_pathway_conformance_basic() {
        let sepsis = pathway(
            "sepsis",
            "Sepsis Bundle",
            "Sepsis",
            vec![
                required_step(
                    "lactate",
                    "Measure lactate",
                    Some(1.0),
                    &[],
                    StepCategory::Diagnostic,
                ),
                required_step(
                    "cultures",
                    "Blood cultures",
                    Some(1.0),
                    &[],
                    StepCategory::Diagnostic,
                ),
                required_step(
                    "antibiotics",
                    "Broad spectrum antibiotics",
                    Some(3.0),
                    &["cultures"],
                    StepCategory::Medication,
                ),
            ],
            6.0,
        );

        let events = vec![
            linked_event("e1", "lactate", "Lactate measured", 0.5, StepCategory::Diagnostic),
            linked_event(
                "e2",
                "cultures",
                "Blood cultures drawn",
                0.75,
                StepCategory::Diagnostic,
            ),
            linked_event(
                "e3",
                "antibiotics",
                "Antibiotics administered",
                2.0,
                StepCategory::Medication,
            ),
        ];

        let config = PathwayConformanceConfig::default();
        let result = ClinicalPathwayConformance::check_conformance(&sepsis, &events, &config);

        assert!(result.conformance_score > 0.9);
        assert!(result.is_conformant);
        assert!(result.missing_steps.is_empty());
    }

    #[test]
    fn test_missed_required_step() {
        let single_step = pathway(
            "test",
            "Test",
            "Test",
            vec![required_step(
                "required",
                "Required step",
                None,
                &[],
                StepCategory::Treatment,
            )],
            24.0,
        );

        let events: Vec<CareEvent> = vec![];
        let config = PathwayConformanceConfig::default();

        let result =
            ClinicalPathwayConformance::check_conformance(&single_step, &events, &config);

        assert!(!result.is_conformant);
        assert!(
            result
                .deviations
                .iter()
                .any(|d| d.deviation_type == DeviationType::MissedStep)
        );
    }

    #[test]
    fn test_batch_conformance() {
        let simple = pathway(
            "simple",
            "Simple",
            "Test",
            vec![required_step("step1", "Step 1", None, &[], StepCategory::Treatment)],
            24.0,
        );

        // First patient completes the step; second patient has no events at all.
        let patient_events = vec![
            vec![linked_event("p1e1", "step1", "Step 1", 1.0, StepCategory::Treatment)],
            vec![],
        ];

        let config = PathwayConformanceConfig::default();
        let results = ClinicalPathwayConformance::check_batch(&simple, &patient_events, &config);

        assert_eq!(results.len(), 2);
        assert!(results[0].is_conformant);
        assert!(!results[1].is_conformant);

        let stats = ClinicalPathwayConformance::aggregate_stats(&results);
        assert_eq!(stats.total_patients, 2);
        assert_eq!(stats.conformant_patients, 1);
    }
}