// teaql_runtime/repository/context.rs

1use std::sync::Arc;
2use std::time::{Instant, SystemTime};
3
4use teaql_core::{
5    DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
6    UpdateCommand,
7};
8use teaql_sql::{CompiledQuery, SqlDialect};
9
10use crate::{
11    ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
12    UserContext,
13};
14
15use super::{
16    AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
17    Repository, ResolvedRepository, UserContextMetadata,
18    helpers::invalidate_aggregation_cache_namespace,
19};
20
21impl UserContext {
22    pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
23    where
24        D: SqlDialect + Send + Sync + 'static,
25        E: QueryExecutor + Send + Sync + 'static,
26    {
27        if self.metadata.is_none() {
28            return Err(ContextError::MissingResource("metadata".to_owned()));
29        }
30
31        let dialect = self.require_resource::<D>()?;
32        let executor = self.require_resource::<E>()?;
33        Ok(ContextRepository {
34            metadata: UserContextMetadata { context: self },
35            dialect,
36            executor,
37        })
38    }
39
40    #[doc(hidden)]
41    pub fn resolve_repository<D, E>(
42        &self,
43        entity: impl Into<String>,
44    ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
45    where
46        D: SqlDialect + Send + Sync + 'static,
47        E: QueryExecutor + Send + Sync + 'static,
48    {
49        let entity = entity.into();
50        if !self.has_repository(&entity) {
51            return Err(ContextError::MissingRepository(entity));
52        }
53        Ok(ResolvedRepository {
54            entity,
55            repository: self.repository::<D, E>()?,
56            trace_context: Vec::new(),
57        })
58    }
59
60    pub fn plan_for_save_graph<D, E>(
61        &self,
62        node: GraphNode,
63    ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
64    where
65        D: SqlDialect + Send + Sync + 'static,
66        E: QueryExecutor + Send + Sync + 'static,
67    {
68        let repository = self
69            .resolve_repository::<D, E>(node.entity.clone())
70            .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
71        repository.plan_graph(node)
72    }
73}
74
75impl<'a, D, E> ContextRepository<'a, D, E>
76where
77    D: SqlDialect,
78    E: QueryExecutor,
79{
80    fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
81        Repository::new(self.dialect, &self.metadata, self.executor)
82    }
83
84    pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
85        self.repository().compile(query)
86    }
87
88    pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
89        let mut compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
90        let final_comment = self.resolve_final_comment(&query.trace_chain, query.comment.clone());
91        compiled.comment = final_comment;
92
93        let started_at = SystemTime::now();
94        let started = Instant::now();
95        let rows = self
96            .executor
97            .fetch_all(&compiled)
98            .map_err(RepositoryError::Executor)?;
99        self.log_sql_result(
100            SqlLogOperation::Select,
101            &compiled,
102            started_at,
103            started,
104            Some(rows.len()),
105            Some(query.entity.clone()),
106            None,
107            query.trace_chain.clone(),
108        );
109        Ok(rows)
110    }
111
112    pub fn fetch_smart_list(
113        &self,
114        query: &SelectQuery,
115    ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
116        self.repository().fetch_smart_list(query)
117    }
118
119    pub fn fetch_entities<T>(
120        &self,
121        query: &SelectQuery,
122    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
123    where
124        T: Entity,
125    {
126        self.repository().fetch_entities(query)
127    }
128
129    pub fn fetch_enhanced_entities<T>(
130        &self,
131        query: &SelectQuery,
132    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
133    where
134        T: Entity,
135    {
136        self.repository().fetch_enhanced_entities(query)
137    }
138
139    pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
140        let mut compiled = self
141            .repository()
142            .compile_insert(command)
143            .map_err(RepositoryError::Runtime)?;
144        let final_comment = self.resolve_final_comment(&command.trace_chain, None);
145        compiled.comment = final_comment;
146
147        let started_at = SystemTime::now();
148        let started = Instant::now();
149        let affected = self
150            .executor
151            .execute(&compiled)
152            .map_err(RepositoryError::Executor)?;
153        self.log_sql_result(
154            SqlLogOperation::Insert,
155            &compiled,
156            started_at,
157            started,
158            None,
159            None,
160            Some(affected),
161            command.trace_chain.clone(),
162        );
163        self.invalidate_aggregation_cache_for(&command.entity);
164        Ok(affected)
165    }
166
167    pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
168        let affected = self.execute_mutation(
169            SqlLogOperation::Update,
170            &command.entity,
171            self.repository()
172                .compile_update(command)
173                .map_err(RepositoryError::Runtime)?,
174            command.trace_chain.clone(),
175        )?;
176        if command.expected_version.is_some() && affected == 0 {
177            return Err(RepositoryError::Runtime(
178                RuntimeError::OptimisticLockConflict {
179                    entity: command.entity.clone(),
180                    id: format!("{:?}", command.id),
181                },
182            ));
183        }
184        Ok(affected)
185    }
186
187    pub fn batch_insert(&self, command: &teaql_core::BatchInsertCommand) -> Result<u64, RepositoryError<E::Error>> {
188        let trace_chain = command.trace_chains.first().cloned().unwrap_or_default();
189        let affected = self.execute_mutation(
190            SqlLogOperation::Insert,
191            &command.entity,
192            self.repository()
193                .compile_batch_insert(command)
194                .map_err(RepositoryError::Runtime)?,
195            trace_chain,
196        )?;
197        Ok(affected)
198    }
199
200    pub fn batch_update(&self, command: &teaql_core::BatchUpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
201        let trace_chain = command.trace_chains.first().cloned().unwrap_or_default();
202        let affected = self.execute_mutation(
203            SqlLogOperation::Update,
204            &command.entity,
205            self.repository()
206                .compile_batch_update(command)
207                .map_err(RepositoryError::Runtime)?,
208            trace_chain,
209        )?;
210        if command.batch_expected_versions.iter().any(|v| v.is_some()) {
211            if affected != command.batch_ids.len() as u64 {
212                return Err(RepositoryError::Runtime(
213                    RuntimeError::OptimisticLockConflict {
214                        entity: command.entity.clone(),
215                        id: "BATCH".to_owned(),
216                    },
217                ));
218            }
219        }
220        Ok(affected)
221    }
222
223    pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
224        let affected = self.execute_mutation(
225            SqlLogOperation::Delete,
226            &command.entity,
227            self.repository()
228                .compile_delete(command)
229                .map_err(RepositoryError::Runtime)?,
230            command.trace_chain.clone(),
231        )?;
232        if command.expected_version.is_some() && affected == 0 {
233            return Err(RepositoryError::Runtime(
234                RuntimeError::OptimisticLockConflict {
235                    entity: command.entity.clone(),
236                    id: format!("{:?}", command.id),
237                },
238            ));
239        }
240        Ok(affected)
241    }
242
243    pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
244        let affected = self.execute_mutation(
245            SqlLogOperation::Recover,
246            &command.entity,
247            self.repository()
248                .compile_recover(command)
249                .map_err(RepositoryError::Runtime)?,
250            command.trace_chain.clone(),
251        )?;
252        if affected == 0 {
253            return Err(RepositoryError::Runtime(
254                RuntimeError::OptimisticLockConflict {
255                    entity: command.entity.clone(),
256                    id: format!("{:?}", command.id),
257                },
258            ));
259        }
260        Ok(affected)
261    }
262
263    fn execute_mutation(
264        &self,
265        operation: SqlLogOperation,
266        entity: &str,
267        mut compiled: CompiledQuery,
268        trace_chain: Vec<teaql_core::TraceNode>,
269    ) -> Result<u64, RepositoryError<E::Error>> {
270        let final_comment = self.resolve_final_comment(&trace_chain, None);
271        compiled.comment = final_comment;
272
273        let started_at = SystemTime::now();
274        let started = Instant::now();
275        let affected = self
276            .executor
277            .execute(&compiled)
278            .map_err(RepositoryError::Executor)?;
279        self.log_sql_result(
280            operation,
281            &compiled,
282            started_at,
283            started,
284            None,
285            None,
286            Some(affected),
287            trace_chain,
288        );
289        self.invalidate_aggregation_cache_for(entity);
290        Ok(affected)
291    }
292
293    pub(super) fn log_sql_result(
294        &self,
295        operation: SqlLogOperation,
296        compiled: &CompiledQuery,
297        started_at: SystemTime,
298        started: Instant,
299        result_count: Option<usize>,
300        result_type: Option<String>,
301        affected_rows: Option<u64>,
302        trace_chain: Vec<teaql_core::TraceNode>,
303    ) {
304        self.metadata.context.record_sql_log(
305            operation,
306            compiled,
307            self.dialect.kind(),
308            started_at,
309            SystemTime::now(),
310            started.elapsed(),
311            result_count,
312            result_type,
313            affected_rows,
314            trace_chain,
315        );
316    }
317
318    pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
319        if let Some(cache) = self
320            .metadata
321            .context
322            .get_resource::<Arc<dyn AggregationCacheBackend>>()
323        {
324            invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
325        }
326        if let Some(cache) = self
327            .metadata
328            .context
329            .get_resource::<InMemoryAggregationCache>()
330        {
331            invalidate_aggregation_cache_namespace(cache, entity);
332        }
333    }
334
335    pub(crate) fn resolve_final_comment(&self, trace_chain: &[teaql_core::TraceNode], comment: Option<String>) -> Option<String> {
336        let chain_str = if trace_chain.is_empty() {
337            None
338        } else {
339            let formatted = trace_chain.iter().map(|n| {
340                format!("{}({}): {}", n.entity_type, n.entity_id.map(|id| id.to_string()).unwrap_or_else(|| "pending".to_owned()), n.comment)
341            }).collect::<Vec<_>>().join(" -> ");
342            Some(formatted)
343        };
344
345        let business_comment = chain_str.or(comment);
346        let user_id = self.metadata.context.user_identifier().map(|s| s.to_owned());
347
348        match (user_id, business_comment) {
349            (Some(user), Some(bus)) if !user.is_empty() && !bus.is_empty() => {
350                Some(format!("[{user}] {bus}"))
351            }
352            (Some(user), _) if !user.is_empty() => {
353                Some(format!("[{user}]"))
354            }
355            (_, Some(bus)) if !bus.is_empty() => {
356                Some(bus)
357            }
358            _ => None,
359        }
360    }
361}