Skip to main content

teaql_runtime/repository/
context.rs

1use std::sync::Arc;
2
3use teaql_core::{
4    DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
5    UpdateCommand,
6};
7use teaql_sql::{CompiledQuery, SqlDialect};
8
9use crate::{
10    ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
11    UserContext,
12};
13
14use super::{
15    AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
16    Repository, ResolvedRepository, UserContextMetadata,
17    helpers::invalidate_aggregation_cache_namespace,
18};
19
20impl UserContext {
21    pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
22    where
23        D: SqlDialect + Send + Sync + 'static,
24        E: QueryExecutor + Send + Sync + 'static,
25    {
26        if self.metadata.is_none() {
27            return Err(ContextError::MissingResource("metadata".to_owned()));
28        }
29
30        let dialect = self.require_resource::<D>()?;
31        let executor = self.require_resource::<E>()?;
32        Ok(ContextRepository {
33            metadata: UserContextMetadata { context: self },
34            dialect,
35            executor,
36        })
37    }
38
39    pub fn resolve_repository<D, E>(
40        &self,
41        entity: impl Into<String>,
42    ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
43    where
44        D: SqlDialect + Send + Sync + 'static,
45        E: QueryExecutor + Send + Sync + 'static,
46    {
47        let entity = entity.into();
48        if !self.has_repository(&entity) {
49            return Err(ContextError::MissingRepository(entity));
50        }
51        Ok(ResolvedRepository {
52            entity,
53            repository: self.repository::<D, E>()?,
54        })
55    }
56
57    pub fn plan_for_save_graph<D, E>(
58        &self,
59        node: GraphNode,
60    ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
61    where
62        D: SqlDialect + Send + Sync + 'static,
63        E: QueryExecutor + Send + Sync + 'static,
64    {
65        let repository = self
66            .resolve_repository::<D, E>(node.entity.clone())
67            .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
68        repository.plan_graph(node)
69    }
70}
71
72impl<'a, D, E> ContextRepository<'a, D, E>
73where
74    D: SqlDialect,
75    E: QueryExecutor,
76{
77    fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
78        Repository::new(self.dialect, &self.metadata, self.executor)
79    }
80
81    pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
82        self.repository().compile(query)
83    }
84
85    pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
86        let compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
87        self.log_sql(SqlLogOperation::Select, &compiled);
88        self.executor
89            .fetch_all(&compiled)
90            .map_err(RepositoryError::Executor)
91    }
92
93    pub fn fetch_smart_list(
94        &self,
95        query: &SelectQuery,
96    ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
97        self.repository().fetch_smart_list(query)
98    }
99
100    pub fn fetch_entities<T>(
101        &self,
102        query: &SelectQuery,
103    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
104    where
105        T: Entity,
106    {
107        self.repository().fetch_entities(query)
108    }
109
110    pub fn fetch_enhanced_entities<T>(
111        &self,
112        query: &SelectQuery,
113    ) -> Result<SmartList<T>, RepositoryError<E::Error>>
114    where
115        T: Entity,
116    {
117        self.repository().fetch_enhanced_entities(query)
118    }
119
120    pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
121        let compiled = self
122            .repository()
123            .compile_insert(command)
124            .map_err(RepositoryError::Runtime)?;
125        self.log_sql(SqlLogOperation::Insert, &compiled);
126        let affected = self
127            .executor
128            .execute(&compiled)
129            .map_err(RepositoryError::Executor)?;
130        self.invalidate_aggregation_cache_for(&command.entity);
131        Ok(affected)
132    }
133
134    pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
135        let affected = self.execute_mutation(
136            SqlLogOperation::Update,
137            &command.entity,
138            self.repository()
139                .compile_update(command)
140                .map_err(RepositoryError::Runtime)?,
141        )?;
142        if command.expected_version.is_some() && affected == 0 {
143            return Err(RepositoryError::Runtime(
144                RuntimeError::OptimisticLockConflict {
145                    entity: command.entity.clone(),
146                    id: format!("{:?}", command.id),
147                },
148            ));
149        }
150        Ok(affected)
151    }
152
153    pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
154        let affected = self.execute_mutation(
155            SqlLogOperation::Delete,
156            &command.entity,
157            self.repository()
158                .compile_delete(command)
159                .map_err(RepositoryError::Runtime)?,
160        )?;
161        if command.expected_version.is_some() && affected == 0 {
162            return Err(RepositoryError::Runtime(
163                RuntimeError::OptimisticLockConflict {
164                    entity: command.entity.clone(),
165                    id: format!("{:?}", command.id),
166                },
167            ));
168        }
169        Ok(affected)
170    }
171
172    pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
173        let affected = self.execute_mutation(
174            SqlLogOperation::Recover,
175            &command.entity,
176            self.repository()
177                .compile_recover(command)
178                .map_err(RepositoryError::Runtime)?,
179        )?;
180        if affected == 0 {
181            return Err(RepositoryError::Runtime(
182                RuntimeError::OptimisticLockConflict {
183                    entity: command.entity.clone(),
184                    id: format!("{:?}", command.id),
185                },
186            ));
187        }
188        Ok(affected)
189    }
190
191    fn execute_mutation(
192        &self,
193        operation: SqlLogOperation,
194        entity: &str,
195        compiled: CompiledQuery,
196    ) -> Result<u64, RepositoryError<E::Error>> {
197        self.log_sql(operation, &compiled);
198        let affected = self
199            .executor
200            .execute(&compiled)
201            .map_err(RepositoryError::Executor)?;
202        self.invalidate_aggregation_cache_for(entity);
203        Ok(affected)
204    }
205
206    pub(super) fn log_sql(&self, operation: SqlLogOperation, compiled: &CompiledQuery) {
207        self.metadata
208            .context
209            .record_sql_log(operation, compiled, self.dialect.kind());
210    }
211
212    pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
213        if let Some(cache) = self
214            .metadata
215            .context
216            .get_resource::<Arc<dyn AggregationCacheBackend>>()
217        {
218            invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
219        }
220        if let Some(cache) = self
221            .metadata
222            .context
223            .get_resource::<InMemoryAggregationCache>()
224        {
225            invalidate_aggregation_cache_namespace(cache, entity);
226        }
227    }
228}