// shaperail_runtime/graphql/dataloader.rs
//! DataLoader for GraphQL relations (M15). Batches and caches relation lookups
//! to prevent N+1 queries when resolving nested relations.
//!
//! Each GraphQL request gets a `RelationLoader` that groups lookups for the same
//! (resource, foreign_key, value) triple and resolves them in a single batch query.

use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

use shaperail_core::{
    EndpointSpec, HttpMethod, PaginationStyle, ResourceDefinition, ShaperailError,
};

use crate::db::{FilterParam, FilterSet, PageRequest, ResourceQuery, ResourceRow, SortParam};
use crate::handlers::crud::{store_for_or_error, AppState};

/// Cache key: (resource_name, filter_field, filter_value).
///
/// `load_by_id` uses the literal field name `"id"` and the UUID's string form
/// as the value, so belongs_to and has_many lookups share one cache map.
type CacheKey = (String, String, String);

/// Batched relation loader. Caches results per request to prevent N+1 queries.
///
/// Thread-safe via `Mutex` — shared across all resolvers in a single request.
#[derive(Clone)]
pub struct RelationLoader {
    // Shared application state (DB pool + stores) used to run relation queries.
    state: Arc<AppState>,
    // Resource definitions, searched linearly by name when resolving a relation.
    resources: Vec<ResourceDefinition>,
    /// Cache of already-loaded rows: key → rows.
    // Wrapped in `Arc` so `Clone` (derived above) shares ONE cache across all
    // clones of this loader within a request, rather than forking it.
    cache: Arc<Mutex<HashMap<CacheKey, Vec<ResourceRow>>>>,
}

32impl RelationLoader {
33    pub fn new(state: Arc<AppState>, resources: Vec<ResourceDefinition>) -> Self {
34        Self {
35            state,
36            resources,
37            cache: Arc::new(Mutex::new(HashMap::new())),
38        }
39    }
40
41    /// Load a single record by ID (belongs_to). Uses cache to avoid duplicate queries.
42    pub async fn load_by_id(
43        &self,
44        resource_name: &str,
45        id: &uuid::Uuid,
46    ) -> Result<Option<ResourceRow>, ShaperailError> {
47        let key: CacheKey = (resource_name.to_string(), "id".to_string(), id.to_string());
48
49        // Check cache first.
50        {
51            let cache = self.cache.lock().await;
52            if let Some(rows) = cache.get(&key) {
53                return Ok(rows.first().cloned());
54            }
55        }
56
57        // Cache miss: load from DB.
58        let resource = self
59            .resources
60            .iter()
61            .find(|r| r.resource == resource_name)
62            .ok_or_else(|| {
63                ShaperailError::Internal(format!("Resource '{resource_name}' not found"))
64            })?;
65
66        let store_opt = store_for_or_error(&self.state, resource)?;
67        let row = if let Some(store) = store_opt {
68            store.find_by_id(id).await?
69        } else {
70            let rq = ResourceQuery::new(resource, &self.state.pool);
71            rq.find_by_id(id).await?
72        };
73
74        // Cache result.
75        {
76            let mut cache = self.cache.lock().await;
77            cache.insert(key, vec![row.clone()]);
78        }
79
80        Ok(Some(row))
81    }
82
83    /// Load related records by a filter field (has_many/has_one).
84    /// Results are cached per (resource, field, value) triple.
85    pub async fn load_by_filter(
86        &self,
87        resource_name: &str,
88        filter_field: &str,
89        filter_value: &str,
90    ) -> Result<Vec<ResourceRow>, ShaperailError> {
91        let key: CacheKey = (
92            resource_name.to_string(),
93            filter_field.to_string(),
94            filter_value.to_string(),
95        );
96
97        // Check cache first.
98        {
99            let cache = self.cache.lock().await;
100            if let Some(rows) = cache.get(&key) {
101                return Ok(rows.clone());
102            }
103        }
104
105        // Cache miss: load from DB.
106        let resource = self
107            .resources
108            .iter()
109            .find(|r| r.resource == resource_name)
110            .ok_or_else(|| {
111                ShaperailError::Internal(format!("Resource '{resource_name}' not found"))
112            })?;
113
114        let endpoint = resource
115            .endpoints
116            .as_ref()
117            .and_then(|e| e.get("list"))
118            .cloned()
119            .unwrap_or_else(|| EndpointSpec {
120                method: Some(HttpMethod::Get),
121                path: Some(format!("/{}", resource.resource)),
122                auth: None,
123                input: None,
124                filters: None,
125                search: None,
126                pagination: Some(PaginationStyle::Offset),
127                sort: None,
128                cache: None,
129                controller: None,
130                events: None,
131                jobs: None,
132                upload: None,
133                soft_delete: false,
134            });
135
136        let filters = FilterSet {
137            filters: vec![FilterParam {
138                field: filter_field.to_string(),
139                value: filter_value.to_string(),
140            }],
141        };
142        let sort = SortParam::default();
143        let page = PageRequest::Offset {
144            offset: 0,
145            limit: 1000,
146        };
147
148        let store_opt = store_for_or_error(&self.state, resource)?;
149        let (rows, _) = if let Some(store) = store_opt {
150            store
151                .find_all(&endpoint, &filters, None, &sort, &page)
152                .await?
153        } else {
154            let rq = ResourceQuery::new(resource, &self.state.pool);
155            rq.find_all(&filters, None, &sort, &page).await?
156        };
157
158        // Cache result.
159        {
160            let mut cache = self.cache.lock().await;
161            cache.insert(key, rows.clone());
162        }
163
164        Ok(rows)
165    }
166
167    /// Returns the number of cached entries (for testing N+1 prevention).
168    pub async fn cache_size(&self) -> usize {
169        self.cache.lock().await.len()
170    }
171}