Skip to main content

tenflowers_dataset/
debug_tools.rs

1//! Debug and profiling tools for data pipeline analysis
2//!
3//! This module provides comprehensive debugging and profiling capabilities for
4//! analyzing data pipeline performance, identifying bottlenecks, and optimizing
5//! data loading workflows.
6
7use crate::Dataset;
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10use tenflowers_core::{Result, TensorError};
11
/// Pipeline profiler for analyzing data loading performance.
///
/// Collects [`ProfileEvent`]s and per-stage timings during a profiling
/// session and aggregates them into a [`ProfileReport`].
pub struct PipelineProfiler {
    /// Name of the pipeline being profiled (used in report headers)
    name: String,
    /// Start time of the profiling session; `None` until `start()` is called
    start_time: Option<Instant>,
    /// Recorded events, oldest first; capped at `config.max_events`
    events: Vec<ProfileEvent>,
    /// Durations recorded per stage name via `end_stage`
    stage_timings: HashMap<String, Vec<Duration>>,
    /// Profiler configuration (sampling rate, tracking toggles, event cap)
    config: ProfilerConfig,
}
25
/// Configuration for the profiler
// NOTE(review): the three `track_*` flags are not consulted anywhere in
// this module's visible code (only `max_events` and `sample_rate` are) —
// confirm whether they are read elsewhere or are reserved for future use.
#[derive(Debug, Clone)]
pub struct ProfilerConfig {
    /// Enable memory tracking
    pub track_memory: bool,
    /// Enable cache statistics
    pub track_cache: bool,
    /// Enable I/O statistics
    pub track_io: bool,
    /// Maximum events to store; once reached, the oldest event is evicted
    pub max_events: usize,
    /// Sample rate (1.0 = record all events, 0.1 = record ~10% of events)
    pub sample_rate: f64,
}
40
41impl Default for ProfilerConfig {
42    fn default() -> Self {
43        Self {
44            track_memory: true,
45            track_cache: true,
46            track_io: true,
47            max_events: 10000,
48            sample_rate: 1.0,
49        }
50    }
51}
52
/// A profiling event
#[derive(Debug, Clone)]
pub struct ProfileEvent {
    /// Event timestamp (the instant the event was recorded)
    pub timestamp: Instant,
    /// Event type
    pub event_type: EventType,
    /// Stage name the event belongs to
    pub stage: String,
    /// Duration (if applicable, e.g. for `StageEnd` events)
    pub duration: Option<Duration>,
    /// Additional metadata
    // NOTE(review): `record_event` always creates this map empty; it is only
    // populated if callers mutate events directly — confirm intended use.
    pub metadata: HashMap<String, String>,
}
67
/// Types of profiling events
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EventType {
    /// Stage started (emitted by `start_stage`, paired with `StageEnd`)
    StageStart,
    /// Stage completed (emitted by `end_stage`, paired with `StageStart`)
    StageEnd,
    /// Data loaded
    DataLoad,
    /// Transform applied
    Transform,
    /// Cache hit (counted when computing the cache hit rate)
    CacheHit,
    /// Cache miss (counted when computing the cache hit rate)
    CacheMiss,
    /// Memory allocation
    MemoryAlloc,
    /// I/O operation
    IoOperation,
    /// Custom event carrying a free-form label
    Custom(String),
}
90
91impl PipelineProfiler {
92    /// Create a new profiler
93    pub fn new(name: impl Into<String>, config: ProfilerConfig) -> Self {
94        Self {
95            name: name.into(),
96            start_time: None,
97            events: Vec::new(),
98            stage_timings: HashMap::new(),
99            config,
100        }
101    }
102
103    /// Create with default configuration
104    pub fn default_config(name: impl Into<String>) -> Self {
105        Self::new(name, ProfilerConfig::default())
106    }
107
108    /// Start profiling
109    pub fn start(&mut self) {
110        self.start_time = Some(Instant::now());
111        self.record_event(
112            EventType::Custom("profiling_started".to_string()),
113            "root",
114            None,
115        );
116    }
117
118    /// Stop profiling
119    pub fn stop(&mut self) {
120        if let Some(start) = self.start_time {
121            let duration = start.elapsed();
122            self.record_event(
123                EventType::Custom("profiling_stopped".to_string()),
124                "root",
125                Some(duration),
126            );
127        }
128    }
129
130    /// Record a profiling event
131    pub fn record_event(
132        &mut self,
133        event_type: EventType,
134        stage: impl Into<String>,
135        duration: Option<Duration>,
136    ) {
137        // Apply sampling
138        if self.config.sample_rate < 1.0 {
139            use scirs2_core::random::rand_prelude::*;
140            let mut rng = scirs2_core::random::rng();
141            let sample: f64 = rng.random();
142            if sample > self.config.sample_rate {
143                return;
144            }
145        }
146
147        if self.events.len() >= self.config.max_events {
148            // Remove oldest event
149            self.events.remove(0);
150        }
151
152        let event = ProfileEvent {
153            timestamp: Instant::now(),
154            event_type,
155            stage: stage.into(),
156            duration,
157            metadata: HashMap::new(),
158        };
159
160        self.events.push(event);
161    }
162
163    /// Start timing a stage
164    pub fn start_stage(&mut self, stage: impl Into<String>) -> StageTimer {
165        let stage_name = stage.into();
166        self.record_event(EventType::StageStart, &stage_name, None);
167        StageTimer::new(stage_name, self.start_time.unwrap_or_else(Instant::now))
168    }
169
170    /// End timing a stage
171    pub fn end_stage(&mut self, timer: StageTimer) {
172        let duration = timer.elapsed();
173        self.record_event(EventType::StageEnd, &timer.stage, Some(duration));
174
175        self.stage_timings
176            .entry(timer.stage.clone())
177            .or_insert_with(Vec::new)
178            .push(duration);
179    }
180
181    /// Generate profiling report
182    pub fn generate_report(&self) -> ProfileReport {
183        let total_duration = self
184            .start_time
185            .map(|start| start.elapsed())
186            .unwrap_or(Duration::from_secs(0));
187
188        // Aggregate stage statistics
189        let mut stage_stats = HashMap::new();
190        for (stage, durations) in &self.stage_timings {
191            let stats = StageStatistics::from_durations(durations);
192            stage_stats.insert(stage.clone(), stats);
193        }
194
195        // Count event types
196        let mut event_counts = HashMap::new();
197        for event in &self.events {
198            let event_name = format!("{:?}", event.event_type);
199            *event_counts.entry(event_name).or_insert(0) += 1;
200        }
201
202        // Calculate cache statistics
203        let cache_hits = self
204            .events
205            .iter()
206            .filter(|e| e.event_type == EventType::CacheHit)
207            .count();
208        let cache_misses = self
209            .events
210            .iter()
211            .filter(|e| e.event_type == EventType::CacheMiss)
212            .count();
213        let cache_hit_rate = if cache_hits + cache_misses > 0 {
214            cache_hits as f64 / (cache_hits + cache_misses) as f64
215        } else {
216            0.0
217        };
218
219        ProfileReport {
220            pipeline_name: self.name.clone(),
221            total_duration,
222            total_events: self.events.len(),
223            stage_stats,
224            event_counts,
225            cache_hit_rate,
226            bottlenecks: self.identify_bottlenecks(),
227            recommendations: self.generate_recommendations(),
228        }
229    }
230
231    /// Identify performance bottlenecks
232    fn identify_bottlenecks(&self) -> Vec<Bottleneck> {
233        let mut bottlenecks = Vec::new();
234
235        // Find slow stages
236        for (stage, durations) in &self.stage_timings {
237            if durations.is_empty() {
238                continue;
239            }
240
241            let avg_duration = durations.iter().sum::<Duration>() / durations.len() as u32;
242
243            // Flag stages taking >100ms on average
244            if avg_duration.as_millis() > 100 {
245                bottlenecks.push(Bottleneck {
246                    category: BottleneckCategory::SlowStage,
247                    description: format!("Stage '{}' is slow (avg: {:?})", stage, avg_duration),
248                    severity: if avg_duration.as_millis() > 1000 {
249                        Severity::High
250                    } else {
251                        Severity::Medium
252                    },
253                    affected_component: stage.clone(),
254                });
255            }
256        }
257
258        // Check cache hit rate
259        let cache_hits = self
260            .events
261            .iter()
262            .filter(|e| e.event_type == EventType::CacheHit)
263            .count();
264        let cache_misses = self
265            .events
266            .iter()
267            .filter(|e| e.event_type == EventType::CacheMiss)
268            .count();
269
270        if cache_hits + cache_misses > 0 {
271            let hit_rate = cache_hits as f64 / (cache_hits + cache_misses) as f64;
272            if hit_rate < 0.5 {
273                bottlenecks.push(Bottleneck {
274                    category: BottleneckCategory::LowCacheHitRate,
275                    description: format!("Low cache hit rate: {:.1}%", hit_rate * 100.0),
276                    severity: Severity::Medium,
277                    affected_component: "cache".to_string(),
278                });
279            }
280        }
281
282        bottlenecks
283    }
284
285    /// Generate optimization recommendations
286    fn generate_recommendations(&self) -> Vec<String> {
287        let mut recommendations = Vec::new();
288
289        // Check cache configuration
290        let cache_hits = self
291            .events
292            .iter()
293            .filter(|e| e.event_type == EventType::CacheHit)
294            .count();
295        let cache_misses = self
296            .events
297            .iter()
298            .filter(|e| e.event_type == EventType::CacheMiss)
299            .count();
300
301        if cache_hits + cache_misses > 0 {
302            let hit_rate = cache_hits as f64 / (cache_hits + cache_misses) as f64;
303            if hit_rate < 0.7 {
304                recommendations.push(
305                    "Consider increasing cache size or using predictive prefetching".to_string(),
306                );
307            }
308        }
309
310        // Check for slow stages
311        for (stage, durations) in &self.stage_timings {
312            if durations.is_empty() {
313                continue;
314            }
315
316            let avg_duration = durations.iter().sum::<Duration>() / durations.len() as u32;
317            if avg_duration.as_millis() > 500 {
318                recommendations.push(format!(
319                    "Optimize '{}' stage - consider parallelization or GPU acceleration",
320                    stage
321                ));
322            }
323        }
324
325        if recommendations.is_empty() {
326            recommendations.push("Pipeline is well optimized".to_string());
327        }
328
329        recommendations
330    }
331
332    /// Export events to JSON-compatible format
333    pub fn export_events(&self) -> Vec<HashMap<String, String>> {
334        self.events
335            .iter()
336            .map(|event| {
337                let mut map = HashMap::new();
338                map.insert("stage".to_string(), event.stage.clone());
339                map.insert("type".to_string(), format!("{:?}", event.event_type));
340                if let Some(duration) = event.duration {
341                    map.insert("duration_ms".to_string(), duration.as_millis().to_string());
342                }
343                map
344            })
345            .collect()
346    }
347}
348
/// Timer for measuring stage duration.
///
/// Created by [`PipelineProfiler::start_stage`] and consumed by
/// [`PipelineProfiler::end_stage`].
pub struct StageTimer {
    /// Name of the stage being timed
    stage: String,
    /// Instant the timer was constructed
    start: Instant,
}
354
355impl StageTimer {
356    fn new(stage: String, start: Instant) -> Self {
357        Self {
358            stage,
359            start: Instant::now(),
360        }
361    }
362
363    fn elapsed(&self) -> Duration {
364        self.start.elapsed()
365    }
366}
367
/// Statistics for a pipeline stage
#[derive(Debug, Clone)]
pub struct StageStatistics {
    /// Number of executions
    pub count: usize,
    /// Total time spent across all executions
    pub total_duration: Duration,
    /// Average (mean) duration
    pub avg_duration: Duration,
    /// Minimum duration
    pub min_duration: Duration,
    /// Maximum duration
    pub max_duration: Duration,
    /// Population standard deviation of the durations
    pub std_dev: Duration,
}
384
385impl StageStatistics {
386    fn from_durations(durations: &[Duration]) -> Self {
387        if durations.is_empty() {
388            return Self {
389                count: 0,
390                total_duration: Duration::from_secs(0),
391                avg_duration: Duration::from_secs(0),
392                min_duration: Duration::from_secs(0),
393                max_duration: Duration::from_secs(0),
394                std_dev: Duration::from_secs(0),
395            };
396        }
397
398        let total: Duration = durations.iter().sum();
399        let avg = total / durations.len() as u32;
400        let min = *durations
401            .iter()
402            .min()
403            .expect("collection should not be empty for min()");
404        let max = *durations
405            .iter()
406            .max()
407            .expect("collection should not be empty for max()");
408
409        // Calculate standard deviation
410        let variance: f64 = durations
411            .iter()
412            .map(|d| {
413                let diff = d.as_secs_f64() - avg.as_secs_f64();
414                diff * diff
415            })
416            .sum::<f64>()
417            / durations.len() as f64;
418        let std_dev = Duration::from_secs_f64(variance.sqrt());
419
420        Self {
421            count: durations.len(),
422            total_duration: total,
423            avg_duration: avg,
424            min_duration: min,
425            max_duration: max,
426            std_dev,
427        }
428    }
429}
430
/// Profiling report produced by [`PipelineProfiler::generate_report`]
#[derive(Debug, Clone)]
pub struct ProfileReport {
    /// Pipeline name
    pub pipeline_name: String,
    /// Total duration of the profiling session
    pub total_duration: Duration,
    /// Total events recorded
    pub total_events: usize,
    /// Stage statistics, keyed by stage name
    pub stage_stats: HashMap<String, StageStatistics>,
    /// Event counts keyed by the event type's Debug representation
    pub event_counts: HashMap<String, usize>,
    /// Cache hit rate in [0, 1]; 0.0 when no cache events were recorded
    pub cache_hit_rate: f64,
    /// Identified bottlenecks
    pub bottlenecks: Vec<Bottleneck>,
    /// Optimization recommendations
    pub recommendations: Vec<String>,
}
451
452impl ProfileReport {
453    /// Generate human-readable report
454    pub fn format_report(&self) -> String {
455        let mut report = String::new();
456
457        report.push_str(&format!(
458            "Pipeline Profiling Report: {}\n",
459            self.pipeline_name
460        ));
461        report.push_str("=".repeat(60).as_str());
462        report.push('\n');
463
464        report.push_str(&format!("Total Duration: {:?}\n", self.total_duration));
465        report.push_str(&format!("Total Events: {}\n", self.total_events));
466        report.push_str(&format!(
467            "Cache Hit Rate: {:.1}%\n\n",
468            self.cache_hit_rate * 100.0
469        ));
470
471        // Stage statistics
472        if !self.stage_stats.is_empty() {
473            report.push_str("Stage Statistics:\n");
474            report.push_str("-".repeat(60).as_str());
475            report.push('\n');
476
477            let mut stages: Vec<_> = self.stage_stats.iter().collect();
478            stages.sort_by_key(|a| std::cmp::Reverse(a.1.total_duration));
479
480            for (stage, stats) in stages {
481                report.push_str(&format!(
482                    "  {}: {} calls, avg {:?}, total {:?}\n",
483                    stage, stats.count, stats.avg_duration, stats.total_duration
484                ));
485            }
486            report.push('\n');
487        }
488
489        // Bottlenecks
490        if !self.bottlenecks.is_empty() {
491            report.push_str("Identified Bottlenecks:\n");
492            report.push_str("-".repeat(60).as_str());
493            report.push('\n');
494
495            for bottleneck in &self.bottlenecks {
496                report.push_str(&format!(
497                    "  [{:?}] {}\n",
498                    bottleneck.severity, bottleneck.description
499                ));
500            }
501            report.push('\n');
502        }
503
504        // Recommendations
505        if !self.recommendations.is_empty() {
506            report.push_str("Recommendations:\n");
507            report.push_str("-".repeat(60).as_str());
508            report.push('\n');
509
510            for (i, rec) in self.recommendations.iter().enumerate() {
511                report.push_str(&format!("  {}. {}\n", i + 1, rec));
512            }
513        }
514
515        report
516    }
517}
518
/// Performance bottleneck identified from recorded events and timings
#[derive(Debug, Clone)]
pub struct Bottleneck {
    /// Bottleneck category
    pub category: BottleneckCategory,
    /// Human-readable description
    pub description: String,
    /// Severity
    pub severity: Severity,
    /// Affected component (a stage name, or "cache")
    pub affected_component: String,
}
531
/// Bottleneck categories
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BottleneckCategory {
    /// Slow stage execution (average stage duration over threshold)
    SlowStage,
    /// High memory usage
    HighMemoryUsage,
    /// Low cache hit rate (below 50% in `identify_bottlenecks`)
    LowCacheHitRate,
    /// Slow I/O
    SlowIo,
    /// Inefficient transform
    InefficientTransform,
}
546
/// Severity levels
///
/// Variants are declared in ascending order, so the derived `Ord` gives
/// `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Severity {
    /// Low severity
    Low,
    /// Medium severity
    Medium,
    /// High severity
    High,
    /// Critical severity
    Critical,
}
559
/// Dataset debugger for inspecting dataset contents and checking that
/// samples are structurally consistent. Stateless; all methods are associated
/// functions.
pub struct DatasetDebugger;
562
563impl DatasetDebugger {
564    /// Inspect dataset samples
565    pub fn inspect_samples<T>(
566        dataset: &dyn Dataset<T>,
567        num_samples: usize,
568    ) -> Result<Vec<SampleInfo>>
569    where
570        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
571    {
572        let mut samples = Vec::new();
573        let count = num_samples.min(dataset.len());
574
575        for i in 0..count {
576            if let Ok((features, labels)) = dataset.get(i) {
577                samples.push(SampleInfo {
578                    index: i,
579                    feature_shape: features.shape().dims().to_vec(),
580                    label_shape: labels.shape().dims().to_vec(),
581                    feature_size: features.size(),
582                    label_size: labels.size(),
583                });
584            }
585        }
586
587        Ok(samples)
588    }
589
590    /// Verify dataset consistency
591    pub fn verify_consistency<T>(dataset: &dyn Dataset<T>) -> Result<ConsistencyReport>
592    where
593        T: Clone + Default + scirs2_core::numeric::Zero + Send + Sync + 'static,
594    {
595        let mut issues = Vec::new();
596        let samples_to_check = dataset.len().min(100);
597
598        if samples_to_check == 0 {
599            return Ok(ConsistencyReport {
600                total_samples: 0,
601                checked_samples: 0,
602                issues,
603                is_consistent: true,
604            });
605        }
606
607        // Check first sample to establish expected shapes
608        let (first_features, first_labels) = dataset.get(0)?;
609        let expected_feature_shape = first_features.shape().dims().to_vec();
610        let expected_label_shape = first_labels.shape().dims().to_vec();
611
612        for i in 1..samples_to_check {
613            if let Ok((features, labels)) = dataset.get(i) {
614                if features.shape().dims() != expected_feature_shape.as_slice() {
615                    issues.push(format!(
616                        "Sample {}: Inconsistent feature shape {:?}, expected {:?}",
617                        i,
618                        features.shape().dims(),
619                        expected_feature_shape
620                    ));
621                }
622                if labels.shape().dims() != expected_label_shape.as_slice() {
623                    issues.push(format!(
624                        "Sample {}: Inconsistent label shape {:?}, expected {:?}",
625                        i,
626                        labels.shape().dims(),
627                        expected_label_shape
628                    ));
629                }
630            } else {
631                issues.push(format!("Sample {}: Failed to load", i));
632            }
633        }
634
635        let is_consistent = issues.is_empty();
636        Ok(ConsistencyReport {
637            total_samples: dataset.len(),
638            checked_samples: samples_to_check,
639            issues,
640            is_consistent,
641        })
642    }
643}
644
/// Information about a dataset sample, as collected by
/// [`DatasetDebugger::inspect_samples`]
#[derive(Debug, Clone)]
pub struct SampleInfo {
    /// Sample index within the dataset
    pub index: usize,
    /// Feature tensor shape
    pub feature_shape: Vec<usize>,
    /// Label tensor shape
    pub label_shape: Vec<usize>,
    /// Feature size (total number of elements)
    pub feature_size: usize,
    /// Label size (total number of elements)
    pub label_size: usize,
}
659
/// Consistency check report produced by [`DatasetDebugger::verify_consistency`]
#[derive(Debug, Clone)]
pub struct ConsistencyReport {
    /// Total samples in the dataset
    pub total_samples: usize,
    /// Number of samples actually checked (capped at 100)
    pub checked_samples: usize,
    /// Issues found (shape mismatches and load failures)
    pub issues: Vec<String>,
    /// Whether the dataset is consistent (no issues found)
    pub is_consistent: bool,
}
672
673impl ConsistencyReport {
674    /// Generate report string
675    pub fn format_report(&self) -> String {
676        let mut report = String::new();
677
678        report.push_str("Dataset Consistency Report\n");
679        report.push_str("=".repeat(60).as_str());
680        report.push('\n');
681
682        report.push_str(&format!("Total Samples: {}\n", self.total_samples));
683        report.push_str(&format!("Checked Samples: {}\n", self.checked_samples));
684        report.push_str(&format!("Is Consistent: {}\n\n", self.is_consistent));
685
686        if !self.issues.is_empty() {
687            report.push_str(&format!("Issues Found ({}):\n", self.issues.len()));
688            for (i, issue) in self.issues.iter().enumerate() {
689                report.push_str(&format!("  {}. {}\n", i + 1, issue));
690            }
691        } else {
692            report.push_str("No issues found.\n");
693        }
694
695        report
696    }
697}
698
699// ──────────────────────────────────────────────────────────────────────────────
700// PipelineInspector — per-step transform instrumentation
701// ──────────────────────────────────────────────────────────────────────────────
702
/// Record of one transform step applied to a single sample.
#[derive(Debug, Clone)]
pub struct InspectionEvent {
    /// Name given to this transform step.
    pub step_name: String,
    /// Shape of the feature tensor *before* the transform.
    pub input_shape: Vec<usize>,
    /// Shape of the feature tensor *after* the transform (`None` if an error occurred).
    pub output_shape: Option<Vec<usize>>,
    /// Wall-clock time the transform took, in microseconds.
    pub latency_micros: u64,
    /// Error message if the transform failed, otherwise `None`.
    pub error: Option<String>,
}
717
/// Aggregated result of running a pipeline inspection over one or more samples.
#[derive(Debug, Clone)]
pub struct PipelineInspectionReport {
    /// All recorded events, in chronological order.
    pub events: Vec<InspectionEvent>,
    /// Sum of all per-event latencies (microseconds).
    pub total_latency_micros: u64,
    /// Number of events that recorded an error.
    pub error_count: usize,
    /// Number of samples successfully loaded and processed.
    pub sample_count: usize,
}
730
731impl PipelineInspectionReport {
732    /// Create a new empty report.
733    pub fn new() -> Self {
734        Self {
735            events: Vec::new(),
736            total_latency_micros: 0,
737            error_count: 0,
738            sample_count: 0,
739        }
740    }
741
742    fn push_event(&mut self, event: InspectionEvent) {
743        self.total_latency_micros += event.latency_micros;
744        if event.error.is_some() {
745            self.error_count += 1;
746        }
747        self.events.push(event);
748    }
749
750    /// Average latency per step in microseconds. Returns `0` if no events recorded.
751    pub fn avg_latency_per_step_micros(&self) -> u64 {
752        if self.events.is_empty() {
753            return 0;
754        }
755        self.total_latency_micros / self.events.len() as u64
756    }
757
758    /// Fraction of steps that produced an error (in [0, 1]).
759    pub fn error_rate(&self) -> f64 {
760        if self.events.is_empty() {
761            return 0.0;
762        }
763        self.error_count as f64 / self.events.len() as f64
764    }
765}
766
767impl Default for PipelineInspectionReport {
768    fn default() -> Self {
769        Self::new()
770    }
771}
772
/// An instrumented transform pipeline that records per-step latency, shapes, and errors.
///
/// Steps are appended with [`InspectablePipeline::add_step`] and the pipeline is exercised via
/// [`InspectablePipeline::inspect_sample`] or [`InspectablePipeline::run_inspection_batch`].
pub struct InspectablePipeline {
    /// Ordered list of (name, transform) steps, applied in sequence
    steps: Vec<(String, Box<dyn crate::transforms::Transform<f32>>)>,
}
780
781impl InspectablePipeline {
782    /// Create an empty pipeline.
783    pub fn new() -> Self {
784        Self { steps: Vec::new() }
785    }
786
787    /// Append a named transform step to the pipeline.
788    pub fn add_step(
789        &mut self,
790        name: impl Into<String>,
791        transform: Box<dyn crate::transforms::Transform<f32>>,
792    ) {
793        self.steps.push((name.into(), transform));
794    }
795
796    /// Run all steps on a single `(features, labels)` sample and return one
797    /// `InspectionEvent` per step.
798    pub fn inspect_sample(
799        &self,
800        sample: (tenflowers_core::Tensor<f32>, tenflowers_core::Tensor<f32>),
801    ) -> Vec<InspectionEvent> {
802        let mut events = Vec::with_capacity(self.steps.len());
803        let mut current = sample;
804
805        for (name, transform) in &self.steps {
806            let input_shape = current.0.shape().to_vec();
807            let start = std::time::Instant::now();
808            match transform.apply(current.clone()) {
809                Ok(out) => {
810                    let latency_micros = start.elapsed().as_micros() as u64;
811                    let output_shape = Some(out.0.shape().to_vec());
812                    events.push(InspectionEvent {
813                        step_name: name.clone(),
814                        input_shape,
815                        output_shape,
816                        latency_micros,
817                        error: None,
818                    });
819                    current = out;
820                }
821                Err(e) => {
822                    let latency_micros = start.elapsed().as_micros() as u64;
823                    events.push(InspectionEvent {
824                        step_name: name.clone(),
825                        input_shape,
826                        output_shape: None,
827                        latency_micros,
828                        error: Some(e.to_string()),
829                    });
830                    break;
831                }
832            }
833        }
834
835        events
836    }
837
838    /// Inspect `n_samples` from a `Dataset` and return an aggregated `PipelineInspectionReport`.
839    pub fn run_inspection_batch<D>(&self, dataset: &D, n_samples: usize) -> PipelineInspectionReport
840    where
841        D: crate::Dataset<f32>,
842    {
843        let mut report = PipelineInspectionReport::new();
844        let count = n_samples.min(dataset.len());
845
846        for idx in 0..count {
847            if let Ok(sample) = dataset.get(idx) {
848                for event in self.inspect_sample(sample) {
849                    report.push_event(event);
850                }
851                report.sample_count += 1;
852            }
853        }
854
855        report
856    }
857}
858
859impl Default for InspectablePipeline {
860    fn default() -> Self {
861        Self::new()
862    }
863}
864
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TensorDataset;
    use tenflowers_core::Tensor;

    #[test]
    fn test_profiler_creation() {
        // A freshly created profiler stores the pipeline name verbatim.
        let profiler = PipelineProfiler::default_config("test_pipeline");
        assert_eq!(profiler.name, "test_pipeline");
    }

    #[test]
    fn test_profiler_events() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        profiler.record_event(EventType::DataLoad, "load_stage", None);
        profiler.record_event(
            EventType::Transform,
            "transform_stage",
            Some(Duration::from_millis(10)),
        );

        profiler.stop();

        // start()/stop() also record marker events, so the count is > 0.
        let report = profiler.generate_report();
        assert!(report.total_events > 0);
    }

    #[test]
    fn test_stage_timing() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        // Time a short sleep so a nonzero duration is recorded for the stage.
        let timer = profiler.start_stage("test_stage");
        std::thread::sleep(Duration::from_millis(10));
        profiler.end_stage(timer);

        let report = profiler.generate_report();
        assert!(report.stage_stats.contains_key("test_stage"));
    }

    #[test]
    fn test_dataset_debugger_inspect() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Requesting more samples (5) than exist (2) clamps to the dataset size.
        let samples =
            DatasetDebugger::inspect_samples(&dataset, 5).expect("test: operation should succeed");
        assert_eq!(samples.len(), 2);
        assert_eq!(samples[0].feature_shape, vec![2]);
    }

    #[test]
    fn test_consistency_check() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        // Both samples come from the same tensors, so shapes must agree.
        let report =
            DatasetDebugger::verify_consistency(&dataset).expect("test: operation should succeed");
        assert!(report.is_consistent);
        assert_eq!(report.total_samples, 2);
    }

    #[test]
    fn test_profile_report_generation() {
        let mut profiler = PipelineProfiler::default_config("test");
        profiler.start();

        let timer = profiler.start_stage("data_loading");
        std::thread::sleep(Duration::from_millis(5));
        profiler.end_stage(timer);

        profiler.stop();

        let report = profiler.generate_report();
        let report_string = report.format_report();

        // The formatted report must include the header and the timed stage.
        assert!(report_string.contains("Pipeline Profiling Report"));
        assert!(report_string.contains("data_loading"));
    }

    // ── InspectablePipeline tests ────────────────────────────────────────────

    // Transform that returns its input unchanged; used to exercise the
    // inspection machinery without touching tensor contents.
    struct IdentityTransform;

    impl crate::transforms::Transform<f32> for IdentityTransform {
        fn apply(
            &self,
            sample: (Tensor<f32>, Tensor<f32>),
        ) -> tenflowers_core::Result<(Tensor<f32>, Tensor<f32>)> {
            Ok(sample)
        }
    }

    #[test]
    fn test_inspectable_pipeline_records_events() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));

        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0], &[3])
            .expect("test: tensor creation should succeed");
        let labels =
            Tensor::<f32>::from_vec(vec![1.0], &[1]).expect("test: tensor creation should succeed");

        // One step -> one event, successful, with an output shape recorded.
        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].step_name, "identity");
        assert!(events[0].error.is_none());
        assert!(events[0].output_shape.is_some());
    }

    #[test]
    fn test_inspectable_pipeline_shape_tracking() {
        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("step1", Box::new(IdentityTransform));
        pipeline.add_step("step2", Box::new(IdentityTransform));

        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");

        // Identity transforms preserve the input shape at every step.
        let events = pipeline.inspect_sample((features, labels));
        assert_eq!(events.len(), 2);
        assert_eq!(events[0].input_shape, vec![2, 2]);
        assert_eq!(events[1].input_shape, vec![2, 2]);
    }

    #[test]
    fn test_run_inspection_batch_aggregation() {
        let features = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: tensor creation should succeed");
        let labels = Tensor::<f32>::from_vec(vec![0.0, 1.0], &[2])
            .expect("test: tensor creation should succeed");
        let dataset = TensorDataset::new(features, labels);

        let mut pipeline = InspectablePipeline::new();
        pipeline.add_step("identity", Box::new(IdentityTransform));

        // n_samples (100) is clamped to the 2 available samples.
        let report = pipeline.run_inspection_batch(&dataset, 100);
        assert_eq!(report.sample_count, 2);
        assert_eq!(report.events.len(), 2);
        assert_eq!(report.error_count, 0);
        assert_eq!(report.error_rate(), 0.0);
    }

    #[test]
    fn test_pipeline_inspection_report_empty() {
        // An empty report must not divide by zero.
        let report = PipelineInspectionReport::new();
        assert_eq!(report.avg_latency_per_step_micros(), 0);
        assert_eq!(report.error_rate(), 0.0);
    }
}
1025}