rustkernel_ml/
healthcare.rs

//! Healthcare analytics kernels.
//!
//! This module provides GPU-accelerated healthcare algorithms:
//! - `DrugInteractionPrediction` - multi-drug interaction analysis
//! - `ClinicalPathwayConformance` - treatment guideline checking
6
7use rustkernel_core::{domain::Domain, kernel::KernelMetadata, traits::GpuKernel};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10
// ============================================================================
// Drug Interaction Prediction Kernel
// ============================================================================
14
15/// Configuration for drug interaction prediction.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct DrugInteractionConfig {
18    /// Maximum interaction order to check (2 = pairwise, 3 = triplets, etc.).
19    pub max_order: usize,
20    /// Minimum confidence for reported interactions.
21    pub min_confidence: f64,
22    /// Include known interactions in output.
23    pub include_known: bool,
24    /// Severity levels to include.
25    pub severity_filter: Vec<Severity>,
26}
27
28impl Default for DrugInteractionConfig {
29    fn default() -> Self {
30        Self {
31            max_order: 3,
32            min_confidence: 0.5,
33            include_known: true,
34            severity_filter: vec![Severity::Major, Severity::Moderate, Severity::Minor],
35        }
36    }
37}
38
39/// Drug severity level.
40#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
41pub enum Severity {
42    /// Life-threatening or major organ damage.
43    Major,
44    /// Significant but not life-threatening.
45    Moderate,
46    /// Minor effects, usually manageable.
47    Minor,
48    /// Theoretical or minimal clinical significance.
49    Minimal,
50}
51
52/// A drug entity.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct Drug {
55    /// Drug identifier (e.g., RxNorm CUI).
56    pub id: String,
57    /// Drug name.
58    pub name: String,
59    /// Drug class/category.
60    pub drug_class: Option<String>,
61    /// Mechanism of action features.
62    pub moa_features: Vec<f64>,
63    /// Target proteins/receptors.
64    pub targets: Vec<String>,
65}
66
67/// Known interaction entry.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct KnownInteraction {
70    /// Drug IDs involved.
71    pub drug_ids: Vec<String>,
72    /// Severity of interaction.
73    pub severity: Severity,
74    /// Description of the interaction.
75    pub description: String,
76    /// Clinical recommendation.
77    pub recommendation: String,
78}
79
80/// Drug interaction knowledge base.
81#[derive(Debug, Clone, Default, Serialize, Deserialize)]
82pub struct InteractionKnowledgeBase {
83    /// Known pairwise interactions.
84    pub pairwise: HashMap<(String, String), KnownInteraction>,
85    /// Known higher-order interactions.
86    pub higher_order: HashMap<Vec<String>, KnownInteraction>,
87    /// Drug class interactions.
88    pub class_interactions: HashMap<(String, String), Severity>,
89}
90
91impl InteractionKnowledgeBase {
92    /// Check if drugs have a known interaction.
93    pub fn get_known_interaction(&self, drug_ids: &[String]) -> Option<&KnownInteraction> {
94        if drug_ids.len() == 2 {
95            let key = Self::normalize_pair(&drug_ids[0], &drug_ids[1]);
96            self.pairwise.get(&key)
97        } else {
98            let mut sorted = drug_ids.to_vec();
99            sorted.sort();
100            self.higher_order.get(&sorted)
101        }
102    }
103
104    fn normalize_pair(a: &str, b: &str) -> (String, String) {
105        if a < b {
106            (a.to_string(), b.to_string())
107        } else {
108            (b.to_string(), a.to_string())
109        }
110    }
111
112    /// Add a known interaction.
113    pub fn add_interaction(&mut self, interaction: KnownInteraction) {
114        if interaction.drug_ids.len() == 2 {
115            let key = Self::normalize_pair(&interaction.drug_ids[0], &interaction.drug_ids[1]);
116            self.pairwise.insert(key, interaction);
117        } else {
118            let mut sorted = interaction.drug_ids.clone();
119            sorted.sort();
120            self.higher_order.insert(sorted, interaction);
121        }
122    }
123}
124
125/// Predicted drug interaction.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct PredictedInteraction {
128    /// Drugs involved.
129    pub drug_ids: Vec<String>,
130    /// Drug names.
131    pub drug_names: Vec<String>,
132    /// Predicted severity.
133    pub severity: Severity,
134    /// Confidence score (0-1).
135    pub confidence: f64,
136    /// Whether this is a known interaction.
137    pub is_known: bool,
138    /// Interaction mechanism (if predicted).
139    pub mechanism: Option<String>,
140    /// Risk score.
141    pub risk_score: f64,
142}
143
144/// Result of drug interaction prediction.
145#[derive(Debug, Clone, Serialize, Deserialize)]
146pub struct DrugInteractionResult {
147    /// All predicted interactions.
148    pub interactions: Vec<PredictedInteraction>,
149    /// High-risk drug combinations.
150    pub high_risk_combinations: Vec<Vec<String>>,
151    /// Overall polypharmacy risk score.
152    pub polypharmacy_risk: f64,
153    /// Recommendations.
154    pub recommendations: Vec<String>,
155}
156
157/// Drug Interaction Prediction kernel.
158///
159/// Analyzes drug combinations for potential interactions using
160/// mechanism-of-action features and known interaction databases.
161/// Supports pairwise and higher-order (multi-drug) interactions.
162#[derive(Debug, Clone)]
163pub struct DrugInteractionPrediction {
164    metadata: KernelMetadata,
165}
166
167impl Default for DrugInteractionPrediction {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl DrugInteractionPrediction {
174    /// Create a new Drug Interaction Prediction kernel.
175    #[must_use]
176    pub fn new() -> Self {
177        Self {
178            metadata: KernelMetadata::batch(
179                "ml/drug-interaction-prediction",
180                Domain::StatisticalML,
181            )
182            .with_description("Multi-drug interaction prediction using mechanism features")
183            .with_throughput(5_000)
184            .with_latency_us(200.0),
185        }
186    }
187
188    /// Predict interactions for a set of drugs.
189    pub fn predict(
190        drugs: &[Drug],
191        knowledge_base: &InteractionKnowledgeBase,
192        config: &DrugInteractionConfig,
193    ) -> DrugInteractionResult {
194        if drugs.is_empty() {
195            return DrugInteractionResult {
196                interactions: Vec::new(),
197                high_risk_combinations: Vec::new(),
198                polypharmacy_risk: 0.0,
199                recommendations: Vec::new(),
200            };
201        }
202
203        let mut interactions = Vec::new();
204        let mut high_risk = Vec::new();
205
206        // Check pairwise interactions
207        for i in 0..drugs.len() {
208            for j in (i + 1)..drugs.len() {
209                let drug_ids = vec![drugs[i].id.clone(), drugs[j].id.clone()];
210                let drug_names = vec![drugs[i].name.clone(), drugs[j].name.clone()];
211
212                // Check known interactions
213                if let Some(known) = knowledge_base.get_known_interaction(&drug_ids) {
214                    if config.include_known && config.severity_filter.contains(&known.severity) {
215                        interactions.push(PredictedInteraction {
216                            drug_ids: drug_ids.clone(),
217                            drug_names: drug_names.clone(),
218                            severity: known.severity,
219                            confidence: 1.0,
220                            is_known: true,
221                            mechanism: Some(known.description.clone()),
222                            risk_score: Self::severity_to_risk(known.severity),
223                        });
224
225                        if known.severity == Severity::Major {
226                            high_risk.push(drug_ids.clone());
227                        }
228                    }
229                } else {
230                    // Predict interaction based on features
231                    let (severity, confidence) =
232                        Self::predict_pairwise(&drugs[i], &drugs[j], knowledge_base);
233
234                    if confidence >= config.min_confidence
235                        && config.severity_filter.contains(&severity)
236                    {
237                        let risk = Self::severity_to_risk(severity) * confidence;
238
239                        interactions.push(PredictedInteraction {
240                            drug_ids: drug_ids.clone(),
241                            drug_names,
242                            severity,
243                            confidence,
244                            is_known: false,
245                            mechanism: Self::predict_mechanism(&drugs[i], &drugs[j]),
246                            risk_score: risk,
247                        });
248
249                        if severity == Severity::Major && confidence > 0.7 {
250                            high_risk.push(drug_ids);
251                        }
252                    }
253                }
254            }
255        }
256
257        // Check higher-order interactions if configured
258        if config.max_order >= 3 && drugs.len() >= 3 {
259            for i in 0..drugs.len() {
260                for j in (i + 1)..drugs.len() {
261                    for k in (j + 1)..drugs.len() {
262                        let drug_ids = vec![
263                            drugs[i].id.clone(),
264                            drugs[j].id.clone(),
265                            drugs[k].id.clone(),
266                        ];
267
268                        let (severity, confidence) =
269                            Self::predict_triplet(&drugs[i], &drugs[j], &drugs[k], knowledge_base);
270
271                        if confidence >= config.min_confidence {
272                            interactions.push(PredictedInteraction {
273                                drug_ids: drug_ids.clone(),
274                                drug_names: vec![
275                                    drugs[i].name.clone(),
276                                    drugs[j].name.clone(),
277                                    drugs[k].name.clone(),
278                                ],
279                                severity,
280                                confidence,
281                                is_known: false,
282                                mechanism: Some("Complex multi-drug interaction".to_string()),
283                                risk_score: Self::severity_to_risk(severity) * confidence,
284                            });
285                        }
286                    }
287                }
288            }
289        }
290
291        // Calculate polypharmacy risk
292        let polypharmacy_risk = Self::calculate_polypharmacy_risk(drugs.len(), &interactions);
293
294        // Generate recommendations
295        let recommendations = Self::generate_recommendations(&interactions, &high_risk);
296
297        // Sort interactions by risk
298        interactions.sort_by(|a, b| {
299            b.risk_score
300                .partial_cmp(&a.risk_score)
301                .unwrap_or(std::cmp::Ordering::Equal)
302        });
303
304        DrugInteractionResult {
305            interactions,
306            high_risk_combinations: high_risk,
307            polypharmacy_risk,
308            recommendations,
309        }
310    }
311
312    /// Predict pairwise interaction from drug features.
313    fn predict_pairwise(
314        drug_a: &Drug,
315        drug_b: &Drug,
316        kb: &InteractionKnowledgeBase,
317    ) -> (Severity, f64) {
318        // Check class-level interactions
319        if let (Some(class_a), Some(class_b)) = (&drug_a.drug_class, &drug_b.drug_class) {
320            let key = if class_a < class_b {
321                (class_a.clone(), class_b.clone())
322            } else {
323                (class_b.clone(), class_a.clone())
324            };
325
326            if let Some(&severity) = kb.class_interactions.get(&key) {
327                return (severity, 0.8);
328            }
329        }
330
331        // Compute feature-based similarity
332        let moa_similarity = Self::cosine_similarity(&drug_a.moa_features, &drug_b.moa_features);
333
334        // Check target overlap
335        let target_overlap = Self::jaccard_similarity(&drug_a.targets, &drug_b.targets);
336
337        // Heuristic: high target overlap + different MOA = higher risk
338        let risk_score = target_overlap * (1.0 - moa_similarity) + moa_similarity * 0.3;
339
340        let (severity, confidence) = if risk_score > 0.7 {
341            (Severity::Major, risk_score)
342        } else if risk_score > 0.5 {
343            (Severity::Moderate, risk_score)
344        } else if risk_score > 0.3 {
345            (Severity::Minor, risk_score)
346        } else {
347            (Severity::Minimal, risk_score)
348        };
349
350        (severity, confidence)
351    }
352
353    /// Predict triplet interaction.
354    fn predict_triplet(
355        drug_a: &Drug,
356        drug_b: &Drug,
357        drug_c: &Drug,
358        _kb: &InteractionKnowledgeBase,
359    ) -> (Severity, f64) {
360        // Aggregate pairwise features
361        let sim_ab = Self::cosine_similarity(&drug_a.moa_features, &drug_b.moa_features);
362        let sim_bc = Self::cosine_similarity(&drug_b.moa_features, &drug_c.moa_features);
363        let sim_ac = Self::cosine_similarity(&drug_a.moa_features, &drug_c.moa_features);
364
365        // Complex interaction if all pairs have some relationship
366        let avg_sim = (sim_ab + sim_bc + sim_ac) / 3.0;
367
368        // Target overlap analysis
369        let all_targets: HashSet<_> = drug_a
370            .targets
371            .iter()
372            .chain(drug_b.targets.iter())
373            .chain(drug_c.targets.iter())
374            .collect();
375
376        let unique_targets = all_targets.len();
377        let total_targets = drug_a.targets.len() + drug_b.targets.len() + drug_c.targets.len();
378
379        let overlap_ratio = if total_targets > 0 {
380            1.0 - (unique_targets as f64 / total_targets as f64)
381        } else {
382            0.0
383        };
384
385        let risk_score = avg_sim * 0.4 + overlap_ratio * 0.6;
386        let confidence = risk_score * 0.7; // Lower confidence for triplets
387
388        let severity = if risk_score > 0.6 {
389            Severity::Major
390        } else if risk_score > 0.4 {
391            Severity::Moderate
392        } else {
393            Severity::Minor
394        };
395
396        (severity, confidence)
397    }
398
399    fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
400        if a.is_empty() || b.is_empty() || a.len() != b.len() {
401            return 0.0;
402        }
403
404        let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
405        let norm_a: f64 = a.iter().map(|x| x * x).sum::<f64>().sqrt();
406        let norm_b: f64 = b.iter().map(|x| x * x).sum::<f64>().sqrt();
407
408        if norm_a < 1e-10 || norm_b < 1e-10 {
409            0.0
410        } else {
411            dot / (norm_a * norm_b)
412        }
413    }
414
415    fn jaccard_similarity(a: &[String], b: &[String]) -> f64 {
416        if a.is_empty() && b.is_empty() {
417            return 0.0;
418        }
419
420        let set_a: HashSet<_> = a.iter().collect();
421        let set_b: HashSet<_> = b.iter().collect();
422
423        let intersection = set_a.intersection(&set_b).count();
424        let union = set_a.union(&set_b).count();
425
426        if union == 0 {
427            0.0
428        } else {
429            intersection as f64 / union as f64
430        }
431    }
432
433    fn severity_to_risk(severity: Severity) -> f64 {
434        match severity {
435            Severity::Major => 1.0,
436            Severity::Moderate => 0.6,
437            Severity::Minor => 0.3,
438            Severity::Minimal => 0.1,
439        }
440    }
441
442    fn predict_mechanism(drug_a: &Drug, drug_b: &Drug) -> Option<String> {
443        let target_overlap = Self::jaccard_similarity(&drug_a.targets, &drug_b.targets);
444
445        if target_overlap > 0.5 {
446            Some("Pharmacodynamic: competing for same targets".to_string())
447        } else if target_overlap > 0.2 {
448            Some("Pharmacodynamic: overlapping target pathways".to_string())
449        } else {
450            Some("Pharmacokinetic: potential metabolic interaction".to_string())
451        }
452    }
453
454    fn calculate_polypharmacy_risk(
455        drug_count: usize,
456        interactions: &[PredictedInteraction],
457    ) -> f64 {
458        // Base risk from drug count
459        let count_risk = (drug_count as f64 - 1.0).max(0.0) * 0.1;
460
461        // Interaction-based risk
462        let interaction_risk: f64 = interactions
463            .iter()
464            .map(|i| i.risk_score * i.confidence)
465            .sum::<f64>()
466            / interactions.len().max(1) as f64;
467
468        (count_risk + interaction_risk).min(1.0)
469    }
470
471    fn generate_recommendations(
472        interactions: &[PredictedInteraction],
473        high_risk: &[Vec<String>],
474    ) -> Vec<String> {
475        let mut recommendations = Vec::new();
476
477        if !high_risk.is_empty() {
478            recommendations.push(format!(
479                "ALERT: {} high-risk drug combinations detected. Consider alternatives.",
480                high_risk.len()
481            ));
482        }
483
484        let major_count = interactions
485            .iter()
486            .filter(|i| i.severity == Severity::Major)
487            .count();
488        if major_count > 0 {
489            recommendations.push(format!(
490                "Review {} major interactions before prescribing.",
491                major_count
492            ));
493        }
494
495        if interactions.len() > 5 {
496            recommendations
497                .push("Consider medication review to reduce polypharmacy risk.".to_string());
498        }
499
500        recommendations
501    }
502}
503
504impl GpuKernel for DrugInteractionPrediction {
505    fn metadata(&self) -> &KernelMetadata {
506        &self.metadata
507    }
508}
509
// ============================================================================
// Clinical Pathway Conformance Kernel
// ============================================================================
513
514/// Configuration for clinical pathway conformance.
515#[derive(Debug, Clone, Serialize, Deserialize)]
516pub struct PathwayConformanceConfig {
517    /// Strictness level for conformance checking.
518    pub strictness: ConformanceStrictness,
519    /// Allow deviations with documented reasons.
520    pub allow_documented_deviations: bool,
521    /// Time tolerance for step ordering (in hours).
522    pub time_tolerance_hours: f64,
523    /// Check required steps only or all steps.
524    pub required_only: bool,
525}
526
527impl Default for PathwayConformanceConfig {
528    fn default() -> Self {
529        Self {
530            strictness: ConformanceStrictness::Standard,
531            allow_documented_deviations: true,
532            time_tolerance_hours: 24.0,
533            required_only: false,
534        }
535    }
536}
537
538/// Strictness level.
539#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
540pub enum ConformanceStrictness {
541    /// Relaxed checking, warnings only.
542    Relaxed,
543    /// Standard checking.
544    Standard,
545    /// Strict checking, all deviations flagged.
546    Strict,
547}
548
549/// A step in a clinical pathway.
550#[derive(Debug, Clone, Serialize, Deserialize)]
551pub struct PathwayStep {
552    /// Step identifier.
553    pub id: String,
554    /// Step name/description.
555    pub name: String,
556    /// Required step (must be completed).
557    pub required: bool,
558    /// Expected timing (hours from start).
559    pub expected_timing: Option<f64>,
560    /// Dependencies (steps that must come before).
561    pub dependencies: Vec<String>,
562    /// Step category.
563    pub category: StepCategory,
564}
565
566/// Category of pathway step.
567#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
568pub enum StepCategory {
569    /// Diagnostic test or assessment.
570    Diagnostic,
571    /// Treatment or intervention.
572    Treatment,
573    /// Medication administration.
574    Medication,
575    /// Monitoring or observation.
576    Monitoring,
577    /// Consultation or referral.
578    Consultation,
579    /// Documentation or administrative.
580    Administrative,
581}
582
583/// A clinical pathway/protocol.
584#[derive(Debug, Clone, Serialize, Deserialize)]
585pub struct ClinicalPathway {
586    /// Pathway identifier.
587    pub id: String,
588    /// Pathway name.
589    pub name: String,
590    /// Condition/diagnosis this pathway applies to.
591    pub condition: String,
592    /// Steps in the pathway.
593    pub steps: Vec<PathwayStep>,
594    /// Expected total duration (hours).
595    pub expected_duration_hours: f64,
596}
597
598/// A completed care event.
599#[derive(Debug, Clone, Serialize, Deserialize)]
600pub struct CareEvent {
601    /// Event identifier.
602    pub id: String,
603    /// Corresponding pathway step ID (if matched).
604    pub step_id: Option<String>,
605    /// Event description.
606    pub description: String,
607    /// Timestamp (hours from pathway start).
608    pub timestamp_hours: f64,
609    /// Category of the event.
610    pub category: StepCategory,
611    /// Deviation reason if applicable.
612    pub deviation_reason: Option<String>,
613}
614
615/// A conformance deviation.
616#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct PathwayDeviation {
618    /// Step that deviated.
619    pub step_id: String,
620    /// Type of deviation.
621    pub deviation_type: DeviationType,
622    /// Severity of deviation.
623    pub severity: DeviationSeverity,
624    /// Description.
625    pub description: String,
626    /// Was a reason documented.
627    pub reason_documented: bool,
628}
629
630/// Type of pathway deviation.
631#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
632pub enum DeviationType {
633    /// Required step was missed.
634    MissedStep,
635    /// Step completed out of order.
636    OrderViolation,
637    /// Step timing deviated significantly.
638    TimingDeviation,
639    /// Extra step not in pathway.
640    ExtraStep,
641    /// Step completed multiple times.
642    DuplicateStep,
643}
644
645/// Severity of deviation.
646#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
647pub enum DeviationSeverity {
648    /// Critical - safety concern.
649    Critical,
650    /// Major - significant protocol violation.
651    Major,
652    /// Minor - documented acceptable deviation.
653    Minor,
654    /// Informational only.
655    Info,
656}
657
658/// Result of conformance check.
659#[derive(Debug, Clone, Serialize, Deserialize)]
660pub struct ConformanceResult {
661    /// Overall conformance score (0-1).
662    pub conformance_score: f64,
663    /// Is the pathway fully conformant.
664    pub is_conformant: bool,
665    /// List of deviations.
666    pub deviations: Vec<PathwayDeviation>,
667    /// Matched steps.
668    pub matched_steps: Vec<String>,
669    /// Unmatched pathway steps.
670    pub missing_steps: Vec<String>,
671    /// Events not matching any step.
672    pub extra_events: Vec<String>,
673    /// Completion percentage.
674    pub completion_percentage: f64,
675}
676
677/// Clinical Pathway Conformance kernel.
678///
679/// Checks patient care events against clinical pathways/protocols
680/// to identify deviations, ensure guideline adherence, and
681/// support quality metrics.
682#[derive(Debug, Clone)]
683pub struct ClinicalPathwayConformance {
684    metadata: KernelMetadata,
685}
686
687impl Default for ClinicalPathwayConformance {
688    fn default() -> Self {
689        Self::new()
690    }
691}
692
693impl ClinicalPathwayConformance {
694    /// Create a new Clinical Pathway Conformance kernel.
695    #[must_use]
696    pub fn new() -> Self {
697        Self {
698            metadata: KernelMetadata::batch(
699                "ml/clinical-pathway-conformance",
700                Domain::StatisticalML,
701            )
702            .with_description("Clinical guideline and pathway conformance checking")
703            .with_throughput(10_000)
704            .with_latency_us(50.0),
705        }
706    }
707
708    /// Check conformance of care events against a pathway.
709    pub fn check_conformance(
710        pathway: &ClinicalPathway,
711        events: &[CareEvent],
712        config: &PathwayConformanceConfig,
713    ) -> ConformanceResult {
714        if pathway.steps.is_empty() {
715            return ConformanceResult {
716                conformance_score: 1.0,
717                is_conformant: true,
718                deviations: Vec::new(),
719                matched_steps: Vec::new(),
720                missing_steps: Vec::new(),
721                extra_events: Vec::new(),
722                completion_percentage: 100.0,
723            };
724        }
725
726        let mut deviations = Vec::new();
727        let mut matched_steps = Vec::new();
728        let mut matched_event_ids: HashSet<String> = HashSet::new();
729
730        // Match events to steps
731        for step in &pathway.steps {
732            if config.required_only && !step.required {
733                continue;
734            }
735
736            let matching_events: Vec<&CareEvent> = events
737                .iter()
738                .filter(|e| {
739                    e.step_id.as_ref() == Some(&step.id)
740                        || (e.category == step.category
741                            && e.description
742                                .to_lowercase()
743                                .contains(&step.name.to_lowercase()))
744                })
745                .collect();
746
747            if matching_events.is_empty() {
748                if step.required {
749                    deviations.push(PathwayDeviation {
750                        step_id: step.id.clone(),
751                        deviation_type: DeviationType::MissedStep,
752                        severity: DeviationSeverity::Major,
753                        description: format!("Required step '{}' was not completed", step.name),
754                        reason_documented: false,
755                    });
756                }
757            } else {
758                matched_steps.push(step.id.clone());
759                matched_event_ids.insert(matching_events[0].id.clone());
760
761                // Check timing
762                if let Some(expected_time) = step.expected_timing {
763                    let actual_time = matching_events[0].timestamp_hours;
764                    let time_diff = (actual_time - expected_time).abs();
765
766                    if time_diff > config.time_tolerance_hours {
767                        let severity = if time_diff > config.time_tolerance_hours * 2.0 {
768                            DeviationSeverity::Major
769                        } else {
770                            DeviationSeverity::Minor
771                        };
772
773                        deviations.push(PathwayDeviation {
774                            step_id: step.id.clone(),
775                            deviation_type: DeviationType::TimingDeviation,
776                            severity,
777                            description: format!(
778                                "Step '{}' timing deviation: expected {}h, actual {}h",
779                                step.name, expected_time, actual_time
780                            ),
781                            reason_documented: matching_events[0].deviation_reason.is_some(),
782                        });
783                    }
784                }
785
786                // Check for duplicates
787                if matching_events.len() > 1 {
788                    deviations.push(PathwayDeviation {
789                        step_id: step.id.clone(),
790                        deviation_type: DeviationType::DuplicateStep,
791                        severity: DeviationSeverity::Info,
792                        description: format!(
793                            "Step '{}' completed {} times",
794                            step.name,
795                            matching_events.len()
796                        ),
797                        reason_documented: true,
798                    });
799                }
800            }
801        }
802
803        // Check dependencies (ordering)
804        for step in &pathway.steps {
805            if !matched_steps.contains(&step.id) {
806                continue;
807            }
808
809            for dep_id in &step.dependencies {
810                if !matched_steps.contains(dep_id) {
811                    deviations.push(PathwayDeviation {
812                        step_id: step.id.clone(),
813                        deviation_type: DeviationType::OrderViolation,
814                        severity: DeviationSeverity::Major,
815                        description: format!(
816                            "Step '{}' completed before dependency '{}'",
817                            step.name, dep_id
818                        ),
819                        reason_documented: false,
820                    });
821                }
822            }
823        }
824
825        // Find extra events
826        let extra_events: Vec<String> = events
827            .iter()
828            .filter(|e| !matched_event_ids.contains(&e.id))
829            .map(|e| e.id.clone())
830            .collect();
831
832        // Calculate missing steps
833        let required_steps: Vec<_> = pathway
834            .steps
835            .iter()
836            .filter(|s| s.required)
837            .map(|s| s.id.clone())
838            .collect();
839
840        let missing_steps: Vec<String> = required_steps
841            .iter()
842            .filter(|s| !matched_steps.contains(s))
843            .cloned()
844            .collect();
845
846        // Apply documented deviation allowance
847        if config.allow_documented_deviations {
848            deviations
849                .retain(|d| !(d.reason_documented && d.severity != DeviationSeverity::Critical));
850        }
851
852        // Calculate scores
853        let completion_percentage = if required_steps.is_empty() {
854            100.0
855        } else {
856            (matched_steps.len() as f64 / required_steps.len() as f64) * 100.0
857        };
858
859        let deviation_penalty: f64 = deviations
860            .iter()
861            .map(|d| match d.severity {
862                DeviationSeverity::Critical => 0.4,
863                DeviationSeverity::Major => 0.2,
864                DeviationSeverity::Minor => 0.05,
865                DeviationSeverity::Info => 0.0,
866            })
867            .sum();
868
869        let conformance_score =
870            (1.0 - deviation_penalty).max(0.0) * (completion_percentage / 100.0);
871
872        let is_conformant = match config.strictness {
873            ConformanceStrictness::Relaxed => conformance_score >= 0.7,
874            ConformanceStrictness::Standard => {
875                conformance_score >= 0.85 && missing_steps.is_empty()
876            }
877            ConformanceStrictness::Strict => {
878                conformance_score >= 0.95
879                    && missing_steps.is_empty()
880                    && deviations
881                        .iter()
882                        .all(|d| d.severity == DeviationSeverity::Info)
883            }
884        };
885
886        ConformanceResult {
887            conformance_score,
888            is_conformant,
889            deviations,
890            matched_steps,
891            missing_steps,
892            extra_events,
893            completion_percentage,
894        }
895    }
896
897    /// Check multiple patients against the same pathway.
898    pub fn check_batch(
899        pathway: &ClinicalPathway,
900        patient_events: &[Vec<CareEvent>],
901        config: &PathwayConformanceConfig,
902    ) -> Vec<ConformanceResult> {
903        patient_events
904            .iter()
905            .map(|events| Self::check_conformance(pathway, events, config))
906            .collect()
907    }
908
909    /// Calculate aggregate statistics across patients.
910    pub fn aggregate_stats(results: &[ConformanceResult]) -> PathwayStatistics {
911        if results.is_empty() {
912            return PathwayStatistics::default();
913        }
914
915        let n = results.len() as f64;
916
917        let avg_conformance = results.iter().map(|r| r.conformance_score).sum::<f64>() / n;
918        let conformant_count = results.iter().filter(|r| r.is_conformant).count();
919        let avg_completion = results.iter().map(|r| r.completion_percentage).sum::<f64>() / n;
920
921        // Count deviation types
922        let mut deviation_counts: HashMap<DeviationType, usize> = HashMap::new();
923        for result in results {
924            for dev in &result.deviations {
925                *deviation_counts.entry(dev.deviation_type).or_insert(0) += 1;
926            }
927        }
928
929        PathwayStatistics {
930            total_patients: results.len(),
931            conformant_patients: conformant_count,
932            conformance_rate: conformant_count as f64 / n,
933            average_conformance_score: avg_conformance,
934            average_completion: avg_completion,
935            deviation_counts,
936        }
937    }
938}
939
940/// Aggregate pathway statistics.
941#[derive(Debug, Clone, Default, Serialize, Deserialize)]
942pub struct PathwayStatistics {
943    /// Total patients analyzed.
944    pub total_patients: usize,
945    /// Patients meeting conformance threshold.
946    pub conformant_patients: usize,
947    /// Conformance rate (0-1).
948    pub conformance_rate: f64,
949    /// Average conformance score.
950    pub average_conformance_score: f64,
951    /// Average completion percentage.
952    pub average_completion: f64,
953    /// Count of each deviation type.
954    pub deviation_counts: HashMap<DeviationType, usize>,
955}
956
957impl GpuKernel for ClinicalPathwayConformance {
958    fn metadata(&self) -> &KernelMetadata {
959        &self.metadata
960    }
961}
962
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_drug_interaction_metadata() {
        let kernel = DrugInteractionPrediction::new();
        assert_eq!(kernel.metadata().id, "ml/drug-interaction-prediction");
    }

    #[test]
    fn test_drug_interaction_basic() {
        let drugs = vec![
            Drug {
                id: "drug1".to_string(),
                name: "Aspirin".to_string(),
                drug_class: Some("NSAID".to_string()),
                moa_features: vec![1.0, 0.0, 0.0],
                targets: vec!["COX1".to_string(), "COX2".to_string()],
            },
            Drug {
                id: "drug2".to_string(),
                name: "Ibuprofen".to_string(),
                drug_class: Some("NSAID".to_string()),
                // Different MOA (orthogonal) + same targets = high risk
                moa_features: vec![0.0, 1.0, 0.0],
                targets: vec!["COX1".to_string(), "COX2".to_string()],
            },
        ];

        let kb = InteractionKnowledgeBase::default();
        let config = DrugInteractionConfig::default();

        let result = DrugInteractionPrediction::predict(&drugs, &kb, &config);

        // Should detect potential interaction (same targets + different MOA)
        assert!(!result.interactions.is_empty());
    }

    #[test]
    fn test_known_interaction() {
        let drugs = vec![
            Drug {
                id: "warfarin".to_string(),
                name: "Warfarin".to_string(),
                drug_class: Some("Anticoagulant".to_string()),
                moa_features: vec![],
                targets: vec![],
            },
            Drug {
                id: "aspirin".to_string(),
                name: "Aspirin".to_string(),
                drug_class: Some("NSAID".to_string()),
                moa_features: vec![],
                targets: vec![],
            },
        ];

        let mut kb = InteractionKnowledgeBase::default();
        kb.add_interaction(KnownInteraction {
            drug_ids: vec!["warfarin".to_string(), "aspirin".to_string()],
            severity: Severity::Major,
            description: "Increased bleeding risk".to_string(),
            recommendation: "Avoid combination".to_string(),
        });

        let config = DrugInteractionConfig::default();
        let result = DrugInteractionPrediction::predict(&drugs, &kb, &config);

        assert!(
            result
                .interactions
                .iter()
                .any(|i| i.is_known && i.severity == Severity::Major)
        );
    }

    #[test]
    fn test_empty_drugs() {
        let kb = InteractionKnowledgeBase::default();
        let config = DrugInteractionConfig::default();
        let result = DrugInteractionPrediction::predict(&[], &kb, &config);
        assert!(result.interactions.is_empty());
    }

    #[test]
    fn test_pathway_conformance_metadata() {
        let kernel = ClinicalPathwayConformance::new();
        assert_eq!(kernel.metadata().id, "ml/clinical-pathway-conformance");
    }

    #[test]
    fn test_pathway_conformance_basic() {
        let pathway = ClinicalPathway {
            id: "sepsis".to_string(),
            name: "Sepsis Bundle".to_string(),
            condition: "Sepsis".to_string(),
            steps: vec![
                PathwayStep {
                    id: "lactate".to_string(),
                    name: "Measure lactate".to_string(),
                    required: true,
                    expected_timing: Some(1.0),
                    dependencies: vec![],
                    category: StepCategory::Diagnostic,
                },
                PathwayStep {
                    id: "cultures".to_string(),
                    name: "Blood cultures".to_string(),
                    required: true,
                    expected_timing: Some(1.0),
                    dependencies: vec![],
                    category: StepCategory::Diagnostic,
                },
                PathwayStep {
                    id: "antibiotics".to_string(),
                    name: "Broad spectrum antibiotics".to_string(),
                    required: true,
                    expected_timing: Some(3.0),
                    dependencies: vec!["cultures".to_string()],
                    category: StepCategory::Medication,
                },
            ],
            expected_duration_hours: 6.0,
        };

        let events = vec![
            CareEvent {
                id: "e1".to_string(),
                step_id: Some("lactate".to_string()),
                description: "Lactate measured".to_string(),
                timestamp_hours: 0.5,
                category: StepCategory::Diagnostic,
                deviation_reason: None,
            },
            CareEvent {
                id: "e2".to_string(),
                step_id: Some("cultures".to_string()),
                description: "Blood cultures drawn".to_string(),
                timestamp_hours: 0.75,
                category: StepCategory::Diagnostic,
                deviation_reason: None,
            },
            CareEvent {
                id: "e3".to_string(),
                step_id: Some("antibiotics".to_string()),
                description: "Antibiotics administered".to_string(),
                timestamp_hours: 2.0,
                category: StepCategory::Medication,
                deviation_reason: None,
            },
        ];

        let config = PathwayConformanceConfig::default();
        let result = ClinicalPathwayConformance::check_conformance(&pathway, &events, &config);

        assert!(result.conformance_score > 0.9);
        assert!(result.is_conformant);
        assert!(result.missing_steps.is_empty());
    }

    #[test]
    fn test_missed_required_step() {
        let pathway = ClinicalPathway {
            id: "test".to_string(),
            name: "Test".to_string(),
            condition: "Test".to_string(),
            steps: vec![PathwayStep {
                id: "required".to_string(),
                name: "Required step".to_string(),
                required: true,
                expected_timing: None,
                dependencies: vec![],
                category: StepCategory::Treatment,
            }],
            expected_duration_hours: 24.0,
        };

        let events: Vec<CareEvent> = vec![];
        let config = PathwayConformanceConfig::default();

        let result = ClinicalPathwayConformance::check_conformance(&pathway, &events, &config);

        assert!(!result.is_conformant);
        assert!(
            result
                .deviations
                .iter()
                .any(|d| d.deviation_type == DeviationType::MissedStep)
        );
        // The unmet required step must be reported, and an empty event log
        // cannot produce any unmatched (extra) events.
        assert_eq!(result.missing_steps, vec!["required".to_string()]);
        assert!(result.extra_events.is_empty());
    }

    #[test]
    fn test_batch_conformance() {
        let pathway = ClinicalPathway {
            id: "simple".to_string(),
            name: "Simple".to_string(),
            condition: "Test".to_string(),
            steps: vec![PathwayStep {
                id: "step1".to_string(),
                name: "Step 1".to_string(),
                required: true,
                expected_timing: None,
                dependencies: vec![],
                category: StepCategory::Treatment,
            }],
            expected_duration_hours: 24.0,
        };

        let patient_events = vec![
            vec![CareEvent {
                id: "p1e1".to_string(),
                step_id: Some("step1".to_string()),
                description: "Step 1".to_string(),
                timestamp_hours: 1.0,
                category: StepCategory::Treatment,
                deviation_reason: None,
            }],
            vec![], // Non-conformant
        ];

        let config = PathwayConformanceConfig::default();
        let results = ClinicalPathwayConformance::check_batch(&pathway, &patient_events, &config);

        assert_eq!(results.len(), 2);
        assert!(results[0].is_conformant);
        assert!(!results[1].is_conformant);

        let stats = ClinicalPathwayConformance::aggregate_stats(&results);
        assert_eq!(stats.total_patients, 2);
        assert_eq!(stats.conformant_patients, 1);
        // One of two patients conformant.
        assert!((stats.conformance_rate - 0.5).abs() < 1e-12);
        // The patient with no events missed the required step.
        assert!(
            stats
                .deviation_counts
                .get(&DeviationType::MissedStep)
                .copied()
                .unwrap_or(0)
                >= 1
        );
    }

    #[test]
    fn test_batch_conformance_empty_input() {
        let pathway = ClinicalPathway {
            id: "simple".to_string(),
            name: "Simple".to_string(),
            condition: "Test".to_string(),
            steps: vec![],
            expected_duration_hours: 24.0,
        };
        let config = PathwayConformanceConfig::default();

        // No patients in, no results out.
        let results = ClinicalPathwayConformance::check_batch(&pathway, &[], &config);
        assert!(results.is_empty());
    }

    #[test]
    fn test_aggregate_stats_empty() {
        // An empty batch falls back to the all-zero default statistics.
        let stats = ClinicalPathwayConformance::aggregate_stats(&[]);
        assert_eq!(stats.total_patients, 0);
        assert_eq!(stats.conformant_patients, 0);
        assert_eq!(stats.conformance_rate, 0.0);
        assert_eq!(stats.average_conformance_score, 0.0);
        assert_eq!(stats.average_completion, 0.0);
        assert!(stats.deviation_counts.is_empty());
    }
}