Skip to main content

sqlmodel_session/
n1_detection.rs

1//! N+1 Query Detection for SQLModel Rust.
2//!
3//! This module provides detection and warning for the N+1 query anti-pattern,
4//! which occurs when code loads N objects and then lazily loads a relationship
5//! for each, resulting in N+1 database queries instead of 2.
6//!
7//! # Example
8//!
9//! ```ignore
10//! // Enable N+1 detection
11//! session.enable_n1_detection(3);  // Warn after 3 lazy loads
12//!
13//! // This will trigger a warning:
14//! for hero in &mut heroes {
15//!     hero.team.load(&mut session).await?;  // N queries!
16//! }
17//!
18//! // This is the fix:
19//! session.load_many(&mut heroes, |h| &mut h.team).await?;  // 1 query
20//! ```
21
22use std::collections::HashMap;
23use std::sync::atomic::{AtomicUsize, Ordering};
24
25/// Tracks lazy load queries for N+1 detection.
26#[derive(Debug)]
27pub struct N1QueryTracker {
28    /// (parent_type, relationship_name) -> query count
29    counts: HashMap<(&'static str, &'static str), AtomicUsize>,
30    /// Threshold for warning (queries per relationship)
31    threshold: usize,
32    /// Whether detection is enabled
33    enabled: bool,
34    /// Captured call sites for debugging
35    call_sites: Vec<CallSite>,
36}
37
38impl Default for N1QueryTracker {
39    fn default() -> Self {
40        Self::new()
41    }
42}
43
/// Where (and when) a single lazy load was recorded.
///
/// `N1QueryTracker::record_load` appends one of these per load it records
/// while detection is enabled.
#[derive(Clone, Debug)]
pub struct CallSite {
    /// Type name of the model that owns the relationship
    pub parent_type: &'static str,
    /// Name of the relationship field that was loaded
    pub relationship: &'static str,
    /// Source file of the code that triggered the load
    pub file: &'static str,
    /// Line within `file` of the triggering call
    pub line: u32,
    /// Instant captured at the moment the load was recorded
    pub timestamp: std::time::Instant,
}
58
/// Statistics about N+1 detection, as reported by `N1QueryTracker::stats`.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct N1Stats {
    /// Total number of lazy loads recorded
    pub total_loads: usize,
    /// Number of distinct `(parent_type, relationship)` pairs loaded
    pub relationships_loaded: usize,
    /// Number of pairs whose load count has reached (is at least) the threshold
    pub potential_n1: usize,
}
69
70impl N1QueryTracker {
71    /// Create a new tracker with default threshold (3).
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            counts: HashMap::new(),
76            threshold: 3,
77            enabled: true,
78            call_sites: Vec::new(),
79        }
80    }
81
82    /// Set the threshold for N+1 warnings.
83    ///
84    /// A warning is emitted when the number of lazy loads for a single
85    /// relationship reaches this threshold.
86    #[must_use]
87    pub fn with_threshold(mut self, threshold: usize) -> Self {
88        self.threshold = threshold;
89        self
90    }
91
92    /// Get the current threshold.
93    #[must_use]
94    pub fn threshold(&self) -> usize {
95        self.threshold
96    }
97
98    /// Check if detection is enabled.
99    #[must_use]
100    pub fn is_enabled(&self) -> bool {
101        self.enabled
102    }
103
104    /// Disable N+1 detection.
105    pub fn disable(&mut self) {
106        self.enabled = false;
107    }
108
109    /// Enable N+1 detection.
110    pub fn enable(&mut self) {
111        self.enabled = true;
112    }
113
114    /// Record a lazy load query.
115    ///
116    /// This should be called whenever a lazy relationship is loaded.
117    /// When the count for a (parent_type, relationship) pair reaches
118    /// the threshold, a warning is emitted.
119    #[track_caller]
120    pub fn record_load(&mut self, parent_type: &'static str, relationship: &'static str) {
121        if !self.enabled {
122            return;
123        }
124
125        let key = (parent_type, relationship);
126        let count = self
127            .counts
128            .entry(key)
129            .or_insert_with(|| AtomicUsize::new(0))
130            .fetch_add(1, Ordering::Relaxed)
131            + 1;
132
133        // Capture call site
134        let caller = std::panic::Location::caller();
135        self.call_sites.push(CallSite {
136            parent_type,
137            relationship,
138            file: caller.file(),
139            line: caller.line(),
140            timestamp: std::time::Instant::now(),
141        });
142
143        // Check threshold
144        if count == self.threshold {
145            self.emit_warning(parent_type, relationship, count);
146        }
147    }
148
149    /// Emit a warning about potential N+1 query pattern.
150    fn emit_warning(&self, parent_type: &'static str, relationship: &'static str, count: usize) {
151        tracing::warn!(
152            target: "sqlmodel::n1",
153            parent = parent_type,
154            relationship = relationship,
155            queries = count,
156            threshold = self.threshold,
157            "N+1 QUERY PATTERN DETECTED! Consider using Session::load_many() for batch loading."
158        );
159
160        // Log recent call sites for this relationship
161        let sites: Vec<_> = self
162            .call_sites
163            .iter()
164            .filter(|s| s.parent_type == parent_type && s.relationship == relationship)
165            .take(5)
166            .collect();
167
168        for (i, site) in sites.iter().enumerate() {
169            tracing::debug!(
170                target: "sqlmodel::n1",
171                index = i,
172                file = site.file,
173                line = site.line,
174                "  [{}] {}:{}",
175                i,
176                site.file,
177                site.line
178            );
179        }
180    }
181
182    /// Reset all counts and call sites.
183    ///
184    /// Call this at the start of a new request or transaction scope.
185    pub fn reset(&mut self) {
186        self.counts.clear();
187        self.call_sites.clear();
188    }
189
190    /// Get the current count for a specific relationship.
191    #[must_use]
192    pub fn count_for(&self, parent_type: &'static str, relationship: &'static str) -> usize {
193        self.counts
194            .get(&(parent_type, relationship))
195            .map_or(0, |c| c.load(Ordering::Relaxed))
196    }
197
198    /// Get statistics about N+1 detection.
199    #[must_use]
200    pub fn stats(&self) -> N1Stats {
201        N1Stats {
202            total_loads: self
203                .counts
204                .values()
205                .map(|c| c.load(Ordering::Relaxed))
206                .sum(),
207            relationships_loaded: self.counts.len(),
208            potential_n1: self
209                .counts
210                .iter()
211                .filter(|(_, c)| c.load(Ordering::Relaxed) >= self.threshold)
212                .count(),
213        }
214    }
215
216    /// Get all call sites (for debugging).
217    #[must_use]
218    pub fn call_sites(&self) -> &[CallSite] {
219        &self.call_sites
220    }
221}
222
223// ============================================================================
// N1DetectionScope - Scope Summary Helper (not an RAII guard)
225// ============================================================================
226
227/// Scope helper for N+1 detection tracking.
228///
229/// This helper captures the initial N+1 stats when created, allowing you to
230/// compare against final stats and log a summary of issues detected within
231/// the scope.
232///
233/// **Note:** This is NOT an automatic RAII guard - you must call `log_summary()`
234/// manually with the final stats. For automatic logging, wrap your code in a
235/// block and call `log_summary` at the end.
236///
237/// # Example
238///
239/// ```ignore
240/// // Capture initial state
241/// let scope = N1DetectionScope::from_tracker(session.n1_tracker());
242///
243/// // Do work that might cause N+1...
244/// for hero in &mut heroes {
245///     hero.team.load(&mut session).await?;
246/// }
247///
248/// // Manually log summary with final stats
249/// scope.log_summary(&session.n1_stats());
250/// ```
251pub struct N1DetectionScope {
252    /// Stats captured when the scope was created (for comparison)
253    initial_stats: N1Stats,
254    /// Threshold used for this scope
255    threshold: usize,
256    /// Whether to log on drop even if no issues
257    verbose: bool,
258}
259
260impl N1DetectionScope {
261    /// Create a new detection scope.
262    ///
263    /// This does NOT automatically enable detection on a Session - the caller
264    /// should have already called `session.enable_n1_detection()`. This scope
265    /// captures the initial state and logs a summary on drop.
266    ///
267    /// # Arguments
268    ///
269    /// * `initial_stats` - The current N1Stats (from `session.n1_stats()`)
270    /// * `threshold` - The threshold being used for detection
271    #[must_use]
272    pub fn new(initial_stats: N1Stats, threshold: usize) -> Self {
273        tracing::debug!(
274            target: "sqlmodel::n1",
275            threshold = threshold,
276            "N+1 detection scope started"
277        );
278
279        Self {
280            initial_stats,
281            threshold,
282            verbose: false,
283        }
284    }
285
286    /// Create a scope from a tracker reference.
287    ///
288    /// Convenience method that extracts stats and threshold from an existing tracker.
289    #[must_use]
290    pub fn from_tracker(tracker: &N1QueryTracker) -> Self {
291        Self::new(tracker.stats(), tracker.threshold())
292    }
293
294    /// Enable verbose logging (log summary even if no issues).
295    #[must_use]
296    pub fn verbose(mut self) -> Self {
297        self.verbose = true;
298        self
299    }
300
301    /// Log a summary of the detection results.
302    ///
303    /// Called automatically on drop, but can be called manually for
304    /// intermediate reporting.
305    pub fn log_summary(&self, final_stats: &N1Stats) {
306        let new_loads = final_stats
307            .total_loads
308            .saturating_sub(self.initial_stats.total_loads);
309        let new_relationships = final_stats
310            .relationships_loaded
311            .saturating_sub(self.initial_stats.relationships_loaded);
312        let new_n1 = final_stats
313            .potential_n1
314            .saturating_sub(self.initial_stats.potential_n1);
315
316        if new_n1 > 0 {
317            tracing::warn!(
318                target: "sqlmodel::n1",
319                potential_n1 = new_n1,
320                total_loads = new_loads,
321                relationships = new_relationships,
322                threshold = self.threshold,
323                "N+1 ISSUES DETECTED in this scope! Consider using Session::load_many() for batch loading."
324            );
325        } else if self.verbose {
326            tracing::info!(
327                target: "sqlmodel::n1",
328                total_loads = new_loads,
329                relationships = new_relationships,
330                "N+1 detection scope completed (no issues)"
331            );
332        } else {
333            tracing::debug!(
334                target: "sqlmodel::n1",
335                total_loads = new_loads,
336                relationships = new_relationships,
337                "N+1 detection scope completed (no issues)"
338            );
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tracker_new_defaults() {
        let tracker = N1QueryTracker::new();
        assert_eq!(tracker.threshold(), 3);
        assert!(tracker.is_enabled());
    }

    #[test]
    fn test_tracker_default_matches_new() {
        let tracker = N1QueryTracker::default();
        assert_eq!(tracker.threshold(), 3);
        assert!(tracker.is_enabled());
        assert!(tracker.call_sites().is_empty());
        assert_eq!(tracker.stats().total_loads, 0);
    }

    #[test]
    fn test_tracker_with_threshold() {
        let tracker = N1QueryTracker::new().with_threshold(5);
        assert_eq!(tracker.threshold(), 5);
    }

    #[test]
    fn test_tracker_enable_disable() {
        let mut tracker = N1QueryTracker::new();
        assert!(tracker.is_enabled());

        tracker.disable();
        assert!(!tracker.is_enabled());

        tracker.enable();
        assert!(tracker.is_enabled());
    }

    #[test]
    fn test_tracker_records_single_load() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 1);
    }

    #[test]
    fn test_tracker_records_multiple_loads() {
        let mut tracker = N1QueryTracker::new().with_threshold(10);
        for _ in 0..5 {
            tracker.record_load("Hero", "team");
        }
        assert_eq!(tracker.count_for("Hero", "team"), 5);
    }

    #[test]
    fn test_tracker_records_multiple_relationships() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        tracker.record_load("Team", "heroes");

        assert_eq!(tracker.count_for("Hero", "team"), 2);
        assert_eq!(tracker.count_for("Hero", "powers"), 1);
        assert_eq!(tracker.count_for("Team", "heroes"), 1);
    }

    #[test]
    fn test_tracker_disabled_no_recording() {
        let mut tracker = N1QueryTracker::new();
        tracker.disable();
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 0);
        assert!(tracker.call_sites().is_empty());
    }

    #[test]
    fn test_tracker_reenable_resumes_counting() {
        let mut tracker = N1QueryTracker::new().with_threshold(10);
        tracker.record_load("Hero", "team");

        // Loads while disabled are dropped; existing counts are kept.
        tracker.disable();
        tracker.record_load("Hero", "team");

        tracker.enable();
        tracker.record_load("Hero", "team");

        assert_eq!(tracker.count_for("Hero", "team"), 2);
        assert_eq!(tracker.call_sites().len(), 2);
    }

    #[test]
    fn test_tracker_reset_clears_counts() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        assert_eq!(tracker.count_for("Hero", "team"), 2);

        tracker.reset();
        assert_eq!(tracker.count_for("Hero", "team"), 0);
        assert!(tracker.call_sites().is_empty());
    }

    #[test]
    fn test_tracker_reset_clears_stats() {
        let mut tracker = N1QueryTracker::new().with_threshold(1);
        tracker.record_load("Hero", "team");
        tracker.reset();

        let stats = tracker.stats();
        assert_eq!(stats.total_loads, 0);
        assert_eq!(stats.relationships_loaded, 0);
        assert_eq!(stats.potential_n1, 0);
    }

    #[test]
    fn test_callsite_captures_location() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");

        assert_eq!(tracker.call_sites().len(), 1);
        let site = &tracker.call_sites()[0];
        assert_eq!(site.parent_type, "Hero");
        assert_eq!(site.relationship, "team");
        assert!(site.file.contains("n1_detection.rs"));
        assert!(site.line > 0);
    }

    #[test]
    fn test_callsite_order_matches_recording() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Team", "heroes");

        let sites = tracker.call_sites();
        assert_eq!(sites.len(), 2);
        assert_eq!(sites[0].parent_type, "Hero");
        assert_eq!(sites[1].parent_type, "Team");
    }

    #[test]
    fn test_callsite_timestamp_monotonic() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");

        let sites = tracker.call_sites();
        assert!(sites[1].timestamp >= sites[0].timestamp);
    }

    #[test]
    fn test_stats_total_loads_accurate() {
        let mut tracker = N1QueryTracker::new().with_threshold(10);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");

        let stats = tracker.stats();
        assert_eq!(stats.total_loads, 3);
    }

    #[test]
    fn test_stats_relationships_count() {
        let mut tracker = N1QueryTracker::new();
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "powers");
        tracker.record_load("Team", "heroes");

        let stats = tracker.stats();
        assert_eq!(stats.relationships_loaded, 3);
    }

    #[test]
    fn test_stats_potential_n1_count() {
        let mut tracker = N1QueryTracker::new().with_threshold(2);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team"); // Reaches threshold
        tracker.record_load("Hero", "powers"); // Only 1

        let stats = tracker.stats();
        assert_eq!(stats.potential_n1, 1);
    }

    #[test]
    fn test_stats_potential_n1_counts_each_relationship_once() {
        let mut tracker = N1QueryTracker::new().with_threshold(2);
        for _ in 0..4 {
            tracker.record_load("Hero", "team");
        }

        let stats = tracker.stats();
        assert_eq!(stats.total_loads, 4);
        assert_eq!(stats.relationships_loaded, 1);
        assert_eq!(stats.potential_n1, 1);
    }

    #[test]
    fn test_stats_zero_threshold_flags_every_relationship() {
        let mut tracker = N1QueryTracker::new().with_threshold(0);
        tracker.record_load("Hero", "team");
        tracker.record_load("Team", "heroes");

        let stats = tracker.stats();
        assert_eq!(stats.potential_n1, 2);
    }

    #[test]
    fn test_stats_default() {
        let stats = N1Stats::default();
        assert_eq!(stats.total_loads, 0);
        assert_eq!(stats.relationships_loaded, 0);
        assert_eq!(stats.potential_n1, 0);
    }

    // ========================================================================
    // N1DetectionScope Tests
    // ========================================================================

    #[test]
    fn test_scope_new_captures_initial_state() {
        let initial = N1Stats {
            total_loads: 5,
            relationships_loaded: 2,
            potential_n1: 1,
        };
        let scope = N1DetectionScope::new(initial, 3);
        assert_eq!(scope.initial_stats.total_loads, 5);
        assert_eq!(scope.initial_stats.relationships_loaded, 2);
        assert_eq!(scope.initial_stats.potential_n1, 1);
        assert_eq!(scope.threshold, 3);
    }

    #[test]
    fn test_scope_from_tracker() {
        let mut tracker = N1QueryTracker::new().with_threshold(5);
        tracker.record_load("Hero", "team");
        tracker.record_load("Hero", "team");

        let scope = N1DetectionScope::from_tracker(&tracker);
        assert_eq!(scope.threshold, 5);
        assert_eq!(scope.initial_stats.total_loads, 2);
    }

    #[test]
    fn test_scope_verbose_flag() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);
        assert!(!scope.verbose);

        let verbose_scope = scope.verbose();
        assert!(verbose_scope.verbose);

        // The verbose no-issue path must also run cleanly.
        verbose_scope.log_summary(&N1Stats::default());
    }

    #[test]
    fn test_scope_log_summary_no_issues() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);

        // Loads happened, but potential_n1 did not grow - no issues
        let final_stats = N1Stats {
            total_loads: 2,
            relationships_loaded: 1,
            potential_n1: 0,
        };

        // Should not panic and should log at debug level
        scope.log_summary(&final_stats);
    }

    #[test]
    fn test_scope_log_summary_with_issues() {
        let initial = N1Stats::default();
        let scope = N1DetectionScope::new(initial, 3);

        // Final stats show N+1 issues
        let final_stats = N1Stats {
            total_loads: 10,
            relationships_loaded: 2,
            potential_n1: 1,
        };

        // Should log warning
        scope.log_summary(&final_stats);
    }

    #[test]
    fn test_scope_calculates_delta() {
        let initial = N1Stats {
            total_loads: 5,
            relationships_loaded: 2,
            potential_n1: 0,
        };
        let scope = N1DetectionScope::new(initial, 3);

        let final_stats = N1Stats {
            total_loads: 15,
            relationships_loaded: 4,
            potential_n1: 2,
        };

        // Expected deltas: 15-5=10 new loads, 4-2=2 new relationships, 2-0=2 new N+1s.
        // The log output is not observable here; this only checks that the call runs.
        scope.log_summary(&final_stats);
    }

    #[test]
    fn test_scope_log_summary_saturates_when_stats_shrink() {
        // Final stats smaller than initial (e.g. the tracker was reset mid-scope):
        // deltas must saturate at zero instead of underflowing.
        let initial = N1Stats {
            total_loads: 7,
            relationships_loaded: 3,
            potential_n1: 2,
        };
        let scope = N1DetectionScope::new(initial, 3);
        scope.log_summary(&N1Stats::default());
    }
}
569        // The scope should calculate: 15-5=10 new loads, 4-2=2 new relationships, 2-0=2 new N+1s
570        // We can't directly test the calculation, but the log_summary should use deltas
571        scope.log_summary(&final_stats);
572    }
573}