Skip to main content

teaql_runtime/repository/
context.rs

1use std::sync::Arc;
2use std::time::{Instant, SystemTime};
3
4use teaql_core::{
5    DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
6    UpdateCommand,
7};
8use teaql_sql::{CompiledQuery, SqlDialect};
9
10use crate::{
11    ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
12    UserContext,
13};
14
15use super::{
16    AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
17    Repository, ResolvedRepository, UserContextMetadata,
18    helpers::invalidate_aggregation_cache_namespace,
19};
20
21impl UserContext {
22    pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
23    where
24        D: SqlDialect + Send + Sync + 'static,
25        E: QueryExecutor + Send + Sync + 'static,
26    {
27        if self.metadata.is_none() {
28            return Err(ContextError::MissingResource("metadata".to_owned()));
29        }
30
31        let dialect = self.require_resource::<D>()?;
32        let executor = self.require_resource::<E>()?;
33        Ok(ContextRepository {
34            metadata: UserContextMetadata { context: self },
35            dialect,
36            executor,
37        })
38    }
39
40    pub fn resolve_repository<D, E>(
41        &self,
42        entity: impl Into<String>,
43    ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
44    where
45        D: SqlDialect + Send + Sync + 'static,
46        E: QueryExecutor + Send + Sync + 'static,
47    {
48        let entity = entity.into();
49        if !self.has_repository(&entity) {
50            return Err(ContextError::MissingRepository(entity));
51        }
52        Ok(ResolvedRepository {
53            entity,
54            repository: self.repository::<D, E>()?,
55        })
56    }
57
58    pub fn plan_for_save_graph<D, E>(
59        &self,
60        node: GraphNode,
61    ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
62    where
63        D: SqlDialect + Send + Sync + 'static,
64        E: QueryExecutor + Send + Sync + 'static,
65    {
66        let repository = self
67            .resolve_repository::<D, E>(node.entity.clone())
68            .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
69        repository.plan_graph(node)
70    }
71}
72
73impl<'a, D, E> ContextRepository<'a, D, E>
74where
75    D: SqlDialect,
76    E: QueryExecutor,
77{
78    fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
79        Repository::new(self.dialect, &self.metadata, self.executor)
80    }
81
82    pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
83        self.repository().compile(query)
84    }
85
86    pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
87        let compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
88        let started_at = SystemTime::now();
89        let started = Instant::now();
90        let rows = self
91            .executor
92            .fetch_all(&compiled)
93            .map_err(RepositoryError::Executor)?;
94        self.log_sql_result(
95            SqlLogOperation::Select,
96            &compiled,
97            started_at,
98            started,
99            Some(rows.len()),
100            Some(query.entity.clone()),
101            None,
102        );
103        Ok(rows)
104    }
105
106    pub fn fetch_smart_list(
107        &self,
108        query: &SelectQuery,
109    ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
110        self.repository().fetch_smart_list(query)
111    }
112
113    pub fn fetch_entities<T>(
114        &self,
115        query: &SelectQuery,
116    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
117    where
118        T: Entity,
119    {
120        self.repository().fetch_entities(query)
121    }
122
123    pub fn fetch_enhanced_entities<T>(
124        &self,
125        query: &SelectQuery,
126    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
127    where
128        T: Entity,
129    {
130        self.repository().fetch_enhanced_entities(query)
131    }
132
133    pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
134        let compiled = self
135            .repository()
136            .compile_insert(command)
137            .map_err(RepositoryError::Runtime)?;
138        let started_at = SystemTime::now();
139        let started = Instant::now();
140        let affected = self
141            .executor
142            .execute(&compiled)
143            .map_err(RepositoryError::Executor)?;
144        self.log_sql_result(
145            SqlLogOperation::Insert,
146            &compiled,
147            started_at,
148            started,
149            None,
150            None,
151            Some(affected),
152        );
153        self.invalidate_aggregation_cache_for(&command.entity);
154        Ok(affected)
155    }
156
157    pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
158        let affected = self.execute_mutation(
159            SqlLogOperation::Update,
160            &command.entity,
161            self.repository()
162                .compile_update(command)
163                .map_err(RepositoryError::Runtime)?,
164        )?;
165        if command.expected_version.is_some() && affected == 0 {
166            return Err(RepositoryError::Runtime(
167                RuntimeError::OptimisticLockConflict {
168                    entity: command.entity.clone(),
169                    id: format!("{:?}", command.id),
170                },
171            ));
172        }
173        Ok(affected)
174    }
175
176    pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
177        let affected = self.execute_mutation(
178            SqlLogOperation::Delete,
179            &command.entity,
180            self.repository()
181                .compile_delete(command)
182                .map_err(RepositoryError::Runtime)?,
183        )?;
184        if command.expected_version.is_some() && affected == 0 {
185            return Err(RepositoryError::Runtime(
186                RuntimeError::OptimisticLockConflict {
187                    entity: command.entity.clone(),
188                    id: format!("{:?}", command.id),
189                },
190            ));
191        }
192        Ok(affected)
193    }
194
195    pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
196        let affected = self.execute_mutation(
197            SqlLogOperation::Recover,
198            &command.entity,
199            self.repository()
200                .compile_recover(command)
201                .map_err(RepositoryError::Runtime)?,
202        )?;
203        if affected == 0 {
204            return Err(RepositoryError::Runtime(
205                RuntimeError::OptimisticLockConflict {
206                    entity: command.entity.clone(),
207                    id: format!("{:?}", command.id),
208                },
209            ));
210        }
211        Ok(affected)
212    }
213
214    fn execute_mutation(
215        &self,
216        operation: SqlLogOperation,
217        entity: &str,
218        compiled: CompiledQuery,
219    ) -> Result<u64, RepositoryError<E::Error>> {
220        let started_at = SystemTime::now();
221        let started = Instant::now();
222        let affected = self
223            .executor
224            .execute(&compiled)
225            .map_err(RepositoryError::Executor)?;
226        self.log_sql_result(
227            operation,
228            &compiled,
229            started_at,
230            started,
231            None,
232            None,
233            Some(affected),
234        );
235        self.invalidate_aggregation_cache_for(entity);
236        Ok(affected)
237    }
238
239    pub(super) fn log_sql_result(
240        &self,
241        operation: SqlLogOperation,
242        compiled: &CompiledQuery,
243        started_at: SystemTime,
244        started: Instant,
245        result_count: Option<usize>,
246        result_type: Option<String>,
247        affected_rows: Option<u64>,
248    ) {
249        self.metadata.context.record_sql_log(
250            operation,
251            compiled,
252            self.dialect.kind(),
253            started_at,
254            SystemTime::now(),
255            started.elapsed(),
256            result_count,
257            result_type,
258            affected_rows,
259        );
260    }
261
262    pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
263        if let Some(cache) = self
264            .metadata
265            .context
266            .get_resource::<Arc<dyn AggregationCacheBackend>>()
267        {
268            invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
269        }
270        if let Some(cache) = self
271            .metadata
272            .context
273            .get_resource::<InMemoryAggregationCache>()
274        {
275            invalidate_aggregation_cache_namespace(cache, entity);
276        }
277    }
278}