1use std::sync::Arc;
2use std::time::{Instant, SystemTime};
3
4use teaql_core::{
5 DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
6 UpdateCommand,
7};
8use teaql_sql::{CompiledQuery, SqlDialect};
9
10use crate::{
11 ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
12 UserContext,
13};
14
15use super::{
16 AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
17 Repository, ResolvedRepository, UserContextMetadata,
18 helpers::invalidate_aggregation_cache_namespace,
19};
20
21impl UserContext {
22 pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
23 where
24 D: SqlDialect + Send + Sync + 'static,
25 E: QueryExecutor + Send + Sync + 'static,
26 {
27 if self.metadata.is_none() {
28 return Err(ContextError::MissingResource("metadata".to_owned()));
29 }
30
31 let dialect = self.require_resource::<D>()?;
32 let executor = self.require_resource::<E>()?;
33 Ok(ContextRepository {
34 metadata: UserContextMetadata { context: self },
35 dialect,
36 executor,
37 })
38 }
39
40 pub fn resolve_repository<D, E>(
41 &self,
42 entity: impl Into<String>,
43 ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
44 where
45 D: SqlDialect + Send + Sync + 'static,
46 E: QueryExecutor + Send + Sync + 'static,
47 {
48 let entity = entity.into();
49 if !self.has_repository(&entity) {
50 return Err(ContextError::MissingRepository(entity));
51 }
52 Ok(ResolvedRepository {
53 entity,
54 repository: self.repository::<D, E>()?,
55 trace_context: Vec::new(),
56 })
57 }
58
59 pub fn plan_for_save_graph<D, E>(
60 &self,
61 node: GraphNode,
62 ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
63 where
64 D: SqlDialect + Send + Sync + 'static,
65 E: QueryExecutor + Send + Sync + 'static,
66 {
67 let repository = self
68 .resolve_repository::<D, E>(node.entity.clone())
69 .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
70 repository.plan_graph(node)
71 }
72}
73
74impl<'a, D, E> ContextRepository<'a, D, E>
75where
76 D: SqlDialect,
77 E: QueryExecutor,
78{
79 fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
80 Repository::new(self.dialect, &self.metadata, self.executor)
81 }
82
83 pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
84 self.repository().compile(query)
85 }
86
87 pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
88 let mut compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
89 let final_comment = self.resolve_final_comment(&query.trace_chain, query.comment.clone());
90 compiled.comment = final_comment;
91
92 let started_at = SystemTime::now();
93 let started = Instant::now();
94 let rows = self
95 .executor
96 .fetch_all(&compiled)
97 .map_err(RepositoryError::Executor)?;
98 self.log_sql_result(
99 SqlLogOperation::Select,
100 &compiled,
101 started_at,
102 started,
103 Some(rows.len()),
104 Some(query.entity.clone()),
105 None,
106 query.trace_chain.clone(),
107 );
108 Ok(rows)
109 }
110
111 pub fn fetch_smart_list(
112 &self,
113 query: &SelectQuery,
114 ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
115 self.repository().fetch_smart_list(query)
116 }
117
118 pub fn fetch_entities<T>(
119 &self,
120 query: &SelectQuery,
121 ) -> Result<SmartList<T>, RepositoryError<E::Error>>
122 where
123 T: Entity,
124 {
125 self.repository().fetch_entities(query)
126 }
127
128 pub fn fetch_enhanced_entities<T>(
129 &self,
130 query: &SelectQuery,
131 ) -> Result<SmartList<T>, RepositoryError<E::Error>>
132 where
133 T: Entity,
134 {
135 self.repository().fetch_enhanced_entities(query)
136 }
137
138 pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
139 let mut compiled = self
140 .repository()
141 .compile_insert(command)
142 .map_err(RepositoryError::Runtime)?;
143 let final_comment = self.resolve_final_comment(&command.trace_chain, None);
144 compiled.comment = final_comment;
145
146 let started_at = SystemTime::now();
147 let started = Instant::now();
148 let affected = self
149 .executor
150 .execute(&compiled)
151 .map_err(RepositoryError::Executor)?;
152 self.log_sql_result(
153 SqlLogOperation::Insert,
154 &compiled,
155 started_at,
156 started,
157 None,
158 None,
159 Some(affected),
160 command.trace_chain.clone(),
161 );
162 self.invalidate_aggregation_cache_for(&command.entity);
163 Ok(affected)
164 }
165
166 pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
167 let affected = self.execute_mutation(
168 SqlLogOperation::Update,
169 &command.entity,
170 self.repository()
171 .compile_update(command)
172 .map_err(RepositoryError::Runtime)?,
173 command.trace_chain.clone(),
174 )?;
175 if command.expected_version.is_some() && affected == 0 {
176 return Err(RepositoryError::Runtime(
177 RuntimeError::OptimisticLockConflict {
178 entity: command.entity.clone(),
179 id: format!("{:?}", command.id),
180 },
181 ));
182 }
183 Ok(affected)
184 }
185
186 pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
187 let affected = self.execute_mutation(
188 SqlLogOperation::Delete,
189 &command.entity,
190 self.repository()
191 .compile_delete(command)
192 .map_err(RepositoryError::Runtime)?,
193 command.trace_chain.clone(),
194 )?;
195 if command.expected_version.is_some() && affected == 0 {
196 return Err(RepositoryError::Runtime(
197 RuntimeError::OptimisticLockConflict {
198 entity: command.entity.clone(),
199 id: format!("{:?}", command.id),
200 },
201 ));
202 }
203 Ok(affected)
204 }
205
206 pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
207 let affected = self.execute_mutation(
208 SqlLogOperation::Recover,
209 &command.entity,
210 self.repository()
211 .compile_recover(command)
212 .map_err(RepositoryError::Runtime)?,
213 command.trace_chain.clone(),
214 )?;
215 if affected == 0 {
216 return Err(RepositoryError::Runtime(
217 RuntimeError::OptimisticLockConflict {
218 entity: command.entity.clone(),
219 id: format!("{:?}", command.id),
220 },
221 ));
222 }
223 Ok(affected)
224 }
225
226 fn execute_mutation(
227 &self,
228 operation: SqlLogOperation,
229 entity: &str,
230 mut compiled: CompiledQuery,
231 trace_chain: Vec<teaql_core::TraceNode>,
232 ) -> Result<u64, RepositoryError<E::Error>> {
233 let final_comment = self.resolve_final_comment(&trace_chain, None);
234 compiled.comment = final_comment;
235
236 let started_at = SystemTime::now();
237 let started = Instant::now();
238 let affected = self
239 .executor
240 .execute(&compiled)
241 .map_err(RepositoryError::Executor)?;
242 self.log_sql_result(
243 operation,
244 &compiled,
245 started_at,
246 started,
247 None,
248 None,
249 Some(affected),
250 trace_chain,
251 );
252 self.invalidate_aggregation_cache_for(entity);
253 Ok(affected)
254 }
255
256 pub(super) fn log_sql_result(
257 &self,
258 operation: SqlLogOperation,
259 compiled: &CompiledQuery,
260 started_at: SystemTime,
261 started: Instant,
262 result_count: Option<usize>,
263 result_type: Option<String>,
264 affected_rows: Option<u64>,
265 trace_chain: Vec<teaql_core::TraceNode>,
266 ) {
267 self.metadata.context.record_sql_log(
268 operation,
269 compiled,
270 self.dialect.kind(),
271 started_at,
272 SystemTime::now(),
273 started.elapsed(),
274 result_count,
275 result_type,
276 affected_rows,
277 trace_chain,
278 );
279 }
280
281 pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
282 if let Some(cache) = self
283 .metadata
284 .context
285 .get_resource::<Arc<dyn AggregationCacheBackend>>()
286 {
287 invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
288 }
289 if let Some(cache) = self
290 .metadata
291 .context
292 .get_resource::<InMemoryAggregationCache>()
293 {
294 invalidate_aggregation_cache_namespace(cache, entity);
295 }
296 }
297
298 pub(crate) fn resolve_final_comment(&self, trace_chain: &[teaql_core::TraceNode], comment: Option<String>) -> Option<String> {
299 let chain_str = if trace_chain.is_empty() {
300 None
301 } else {
302 let formatted = trace_chain.iter().map(|n| {
303 format!("{}({}): {}", n.entity_type, n.entity_id.map(|id| id.to_string()).unwrap_or_else(|| "pending".to_owned()), n.comment)
304 }).collect::<Vec<_>>().join(" -> ");
305 Some(formatted)
306 };
307
308 let business_comment = chain_str.or(comment);
309 let user_id = self.metadata.context.user_identifier().map(|s| s.to_owned());
310
311 match (user_id, business_comment) {
312 (Some(user), Some(bus)) if !user.is_empty() && !bus.is_empty() => {
313 Some(format!("[{user}] {bus}"))
314 }
315 (Some(user), _) if !user.is_empty() => {
316 Some(format!("[{user}]"))
317 }
318 (_, Some(bus)) if !bus.is_empty() => {
319 Some(bus)
320 }
321 _ => None,
322 }
323 }
324}