Skip to main content

teaql_runtime/repository/context.rs

1use std::sync::Arc;
2use std::time::{Instant, SystemTime};
3
4use teaql_core::{
5    DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
6    UpdateCommand,
7};
8use teaql_sql::{CompiledQuery, SqlDialect};
9
10use crate::{
11    ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
12    UserContext,
13};
14
15use super::{
16    AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
17    Repository, ResolvedRepository, UserContextMetadata,
18    helpers::invalidate_aggregation_cache_namespace,
19};
20
21impl UserContext {
22    pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
23    where
24        D: SqlDialect + Send + Sync + 'static,
25        E: QueryExecutor + Send + Sync + 'static,
26    {
27        if self.metadata.is_none() {
28            return Err(ContextError::MissingResource("metadata".to_owned()));
29        }
30
31        let dialect = self.require_resource::<D>()?;
32        let executor = self.require_resource::<E>()?;
33        Ok(ContextRepository {
34            metadata: UserContextMetadata { context: self },
35            dialect,
36            executor,
37        })
38    }
39
40    pub fn resolve_repository<D, E>(
41        &self,
42        entity: impl Into<String>,
43    ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
44    where
45        D: SqlDialect + Send + Sync + 'static,
46        E: QueryExecutor + Send + Sync + 'static,
47    {
48        let entity = entity.into();
49        if !self.has_repository(&entity) {
50            return Err(ContextError::MissingRepository(entity));
51        }
52        Ok(ResolvedRepository {
53            entity,
54            repository: self.repository::<D, E>()?,
55            trace_context: Vec::new(),
56        })
57    }
58
59    pub fn plan_for_save_graph<D, E>(
60        &self,
61        node: GraphNode,
62    ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
63    where
64        D: SqlDialect + Send + Sync + 'static,
65        E: QueryExecutor + Send + Sync + 'static,
66    {
67        let repository = self
68            .resolve_repository::<D, E>(node.entity.clone())
69            .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
70        repository.plan_graph(node)
71    }
72}
73
74impl<'a, D, E> ContextRepository<'a, D, E>
75where
76    D: SqlDialect,
77    E: QueryExecutor,
78{
79    fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
80        Repository::new(self.dialect, &self.metadata, self.executor)
81    }
82
83    pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
84        self.repository().compile(query)
85    }
86
87    pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
88        let mut compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
89        let final_comment = self.resolve_final_comment(&query.trace_chain, query.comment.clone());
90        compiled.comment = final_comment;
91
92        let started_at = SystemTime::now();
93        let started = Instant::now();
94        let rows = self
95            .executor
96            .fetch_all(&compiled)
97            .map_err(RepositoryError::Executor)?;
98        self.log_sql_result(
99            SqlLogOperation::Select,
100            &compiled,
101            started_at,
102            started,
103            Some(rows.len()),
104            Some(query.entity.clone()),
105            None,
106            query.trace_chain.clone(),
107        );
108        Ok(rows)
109    }
110
111    pub fn fetch_smart_list(
112        &self,
113        query: &SelectQuery,
114    ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
115        self.repository().fetch_smart_list(query)
116    }
117
118    pub fn fetch_entities<T>(
119        &self,
120        query: &SelectQuery,
121    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
122    where
123        T: Entity,
124    {
125        self.repository().fetch_entities(query)
126    }
127
128    pub fn fetch_enhanced_entities<T>(
129        &self,
130        query: &SelectQuery,
131    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
132    where
133        T: Entity,
134    {
135        self.repository().fetch_enhanced_entities(query)
136    }
137
138    pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
139        let mut compiled = self
140            .repository()
141            .compile_insert(command)
142            .map_err(RepositoryError::Runtime)?;
143        let final_comment = self.resolve_final_comment(&command.trace_chain, None);
144        compiled.comment = final_comment;
145
146        let started_at = SystemTime::now();
147        let started = Instant::now();
148        let affected = self
149            .executor
150            .execute(&compiled)
151            .map_err(RepositoryError::Executor)?;
152        self.log_sql_result(
153            SqlLogOperation::Insert,
154            &compiled,
155            started_at,
156            started,
157            None,
158            None,
159            Some(affected),
160            command.trace_chain.clone(),
161        );
162        self.invalidate_aggregation_cache_for(&command.entity);
163        Ok(affected)
164    }
165
166    pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
167        let affected = self.execute_mutation(
168            SqlLogOperation::Update,
169            &command.entity,
170            self.repository()
171                .compile_update(command)
172                .map_err(RepositoryError::Runtime)?,
173            command.trace_chain.clone(),
174        )?;
175        if command.expected_version.is_some() && affected == 0 {
176            return Err(RepositoryError::Runtime(
177                RuntimeError::OptimisticLockConflict {
178                    entity: command.entity.clone(),
179                    id: format!("{:?}", command.id),
180                },
181            ));
182        }
183        Ok(affected)
184    }
185
186    pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
187        let affected = self.execute_mutation(
188            SqlLogOperation::Delete,
189            &command.entity,
190            self.repository()
191                .compile_delete(command)
192                .map_err(RepositoryError::Runtime)?,
193            command.trace_chain.clone(),
194        )?;
195        if command.expected_version.is_some() && affected == 0 {
196            return Err(RepositoryError::Runtime(
197                RuntimeError::OptimisticLockConflict {
198                    entity: command.entity.clone(),
199                    id: format!("{:?}", command.id),
200                },
201            ));
202        }
203        Ok(affected)
204    }
205
206    pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
207        let affected = self.execute_mutation(
208            SqlLogOperation::Recover,
209            &command.entity,
210            self.repository()
211                .compile_recover(command)
212                .map_err(RepositoryError::Runtime)?,
213            command.trace_chain.clone(),
214        )?;
215        if affected == 0 {
216            return Err(RepositoryError::Runtime(
217                RuntimeError::OptimisticLockConflict {
218                    entity: command.entity.clone(),
219                    id: format!("{:?}", command.id),
220                },
221            ));
222        }
223        Ok(affected)
224    }
225
226    fn execute_mutation(
227        &self,
228        operation: SqlLogOperation,
229        entity: &str,
230        mut compiled: CompiledQuery,
231        trace_chain: Vec<teaql_core::TraceNode>,
232    ) -> Result<u64, RepositoryError<E::Error>> {
233        let final_comment = self.resolve_final_comment(&trace_chain, None);
234        compiled.comment = final_comment;
235
236        let started_at = SystemTime::now();
237        let started = Instant::now();
238        let affected = self
239            .executor
240            .execute(&compiled)
241            .map_err(RepositoryError::Executor)?;
242        self.log_sql_result(
243            operation,
244            &compiled,
245            started_at,
246            started,
247            None,
248            None,
249            Some(affected),
250            trace_chain,
251        );
252        self.invalidate_aggregation_cache_for(entity);
253        Ok(affected)
254    }
255
256    pub(super) fn log_sql_result(
257        &self,
258        operation: SqlLogOperation,
259        compiled: &CompiledQuery,
260        started_at: SystemTime,
261        started: Instant,
262        result_count: Option<usize>,
263        result_type: Option<String>,
264        affected_rows: Option<u64>,
265        trace_chain: Vec<teaql_core::TraceNode>,
266    ) {
267        self.metadata.context.record_sql_log(
268            operation,
269            compiled,
270            self.dialect.kind(),
271            started_at,
272            SystemTime::now(),
273            started.elapsed(),
274            result_count,
275            result_type,
276            affected_rows,
277            trace_chain,
278        );
279    }
280
281    pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
282        if let Some(cache) = self
283            .metadata
284            .context
285            .get_resource::<Arc<dyn AggregationCacheBackend>>()
286        {
287            invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
288        }
289        if let Some(cache) = self
290            .metadata
291            .context
292            .get_resource::<InMemoryAggregationCache>()
293        {
294            invalidate_aggregation_cache_namespace(cache, entity);
295        }
296    }
297
298    pub(crate) fn resolve_final_comment(&self, trace_chain: &[teaql_core::TraceNode], comment: Option<String>) -> Option<String> {
299        let chain_str = if trace_chain.is_empty() {
300            None
301        } else {
302            let formatted = trace_chain.iter().map(|n| {
303                format!("{}({}): {}", n.entity_type, n.entity_id.map(|id| id.to_string()).unwrap_or_else(|| "pending".to_owned()), n.comment)
304            }).collect::<Vec<_>>().join(" -> ");
305            Some(formatted)
306        };
307
308        let business_comment = chain_str.or(comment);
309        let user_id = self.metadata.context.user_identifier().map(|s| s.to_owned());
310
311        match (user_id, business_comment) {
312            (Some(user), Some(bus)) if !user.is_empty() && !bus.is_empty() => {
313                Some(format!("[{user}] {bus}"))
314            }
315            (Some(user), _) if !user.is_empty() => {
316                Some(format!("[{user}]"))
317            }
318            (_, Some(bus)) if !bus.is_empty() => {
319                Some(bus)
320            }
321            _ => None,
322        }
323    }
324}