1use std::sync::Arc;
2use std::time::{Instant, SystemTime};
3
4use teaql_core::{
5 DeleteCommand, Entity, InsertCommand, Record, RecoverCommand, SelectQuery, SmartList,
6 UpdateCommand,
7};
8use teaql_sql::{CompiledQuery, SqlDialect};
9
10use crate::{
11 ContextError, GraphMutationPlan, GraphNode, RepositoryError, RuntimeError, SqlLogOperation,
12 UserContext,
13};
14
15use super::{
16 AggregationCacheBackend, ContextRepository, InMemoryAggregationCache, QueryExecutor,
17 Repository, ResolvedRepository, UserContextMetadata,
18 helpers::invalidate_aggregation_cache_namespace,
19};
20
21impl UserContext {
22 pub fn repository<D, E>(&self) -> Result<ContextRepository<'_, D, E>, ContextError>
23 where
24 D: SqlDialect + Send + Sync + 'static,
25 E: QueryExecutor + Send + Sync + 'static,
26 {
27 if self.metadata.is_none() {
28 return Err(ContextError::MissingResource("metadata".to_owned()));
29 }
30
31 let dialect = self.require_resource::<D>()?;
32 let executor = self.require_resource::<E>()?;
33 Ok(ContextRepository {
34 metadata: UserContextMetadata { context: self },
35 dialect,
36 executor,
37 })
38 }
39
40 #[doc(hidden)]
41 pub fn resolve_repository<D, E>(
42 &self,
43 entity: impl Into<String>,
44 ) -> Result<ResolvedRepository<'_, D, E>, ContextError>
45 where
46 D: SqlDialect + Send + Sync + 'static,
47 E: QueryExecutor + Send + Sync + 'static,
48 {
49 let entity = entity.into();
50 if !self.has_repository(&entity) {
51 return Err(ContextError::MissingRepository(entity));
52 }
53 Ok(ResolvedRepository {
54 entity,
55 repository: self.repository::<D, E>()?,
56 trace_context: Vec::new(),
57 })
58 }
59
60 pub fn plan_for_save_graph<D, E>(
61 &self,
62 node: GraphNode,
63 ) -> Result<GraphMutationPlan, RepositoryError<E::Error>>
64 where
65 D: SqlDialect + Send + Sync + 'static,
66 E: QueryExecutor + Send + Sync + 'static,
67 {
68 let repository = self
69 .resolve_repository::<D, E>(node.entity.clone())
70 .map_err(|err| RepositoryError::Runtime(RuntimeError::Graph(err.to_string())))?;
71 repository.plan_graph(node)
72 }
73}
74
75impl<'a, D, E> ContextRepository<'a, D, E>
76where
77 D: SqlDialect,
78 E: QueryExecutor,
79{
80 fn repository(&self) -> Repository<'_, D, UserContextMetadata<'_>, E> {
81 Repository::new(self.dialect, &self.metadata, self.executor)
82 }
83
84 pub fn compile(&self, query: &SelectQuery) -> Result<CompiledQuery, RuntimeError> {
85 self.repository().compile(query)
86 }
87
88 pub fn fetch_all(&self, query: &SelectQuery) -> Result<Vec<Record>, RepositoryError<E::Error>> {
89 let mut compiled = self.compile(query).map_err(RepositoryError::Runtime)?;
90 let final_comment = self.resolve_final_comment(&query.trace_chain, query.comment.clone());
91 compiled.comment = final_comment;
92
93 let started_at = SystemTime::now();
94 let started = Instant::now();
95 let rows = self
96 .executor
97 .fetch_all(&compiled)
98 .map_err(RepositoryError::Executor)?;
99 self.log_sql_result(
100 SqlLogOperation::Select,
101 &compiled,
102 started_at,
103 started,
104 Some(rows.len()),
105 Some(query.entity.clone()),
106 None,
107 query.trace_chain.clone(),
108 );
109 Ok(rows)
110 }
111
112 pub fn fetch_smart_list(
113 &self,
114 query: &SelectQuery,
115 ) -> Result<SmartList<Record>, RepositoryError<E::Error>> {
116 self.repository().fetch_smart_list(query)
117 }
118
119 pub fn fetch_entities<T>(
120 &self,
121 query: &SelectQuery,
122 ) -> Result<SmartList<T>, RepositoryError<E::Error>>
123 where
124 T: Entity,
125 {
126 self.repository().fetch_entities(query)
127 }
128
129 pub fn fetch_enhanced_entities<T>(
130 &self,
131 query: &SelectQuery,
132 ) -> Result<SmartList<T>, RepositoryError<E::Error>>
133 where
134 T: Entity,
135 {
136 self.repository().fetch_enhanced_entities(query)
137 }
138
139 pub fn insert(&self, command: &InsertCommand) -> Result<u64, RepositoryError<E::Error>> {
140 let mut compiled = self
141 .repository()
142 .compile_insert(command)
143 .map_err(RepositoryError::Runtime)?;
144 let final_comment = self.resolve_final_comment(&command.trace_chain, None);
145 compiled.comment = final_comment;
146
147 let started_at = SystemTime::now();
148 let started = Instant::now();
149 let affected = self
150 .executor
151 .execute(&compiled)
152 .map_err(RepositoryError::Executor)?;
153 self.log_sql_result(
154 SqlLogOperation::Insert,
155 &compiled,
156 started_at,
157 started,
158 None,
159 None,
160 Some(affected),
161 command.trace_chain.clone(),
162 );
163 self.invalidate_aggregation_cache_for(&command.entity);
164 Ok(affected)
165 }
166
167 pub fn update(&self, command: &UpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
168 let affected = self.execute_mutation(
169 SqlLogOperation::Update,
170 &command.entity,
171 self.repository()
172 .compile_update(command)
173 .map_err(RepositoryError::Runtime)?,
174 command.trace_chain.clone(),
175 )?;
176 if command.expected_version.is_some() && affected == 0 {
177 return Err(RepositoryError::Runtime(
178 RuntimeError::OptimisticLockConflict {
179 entity: command.entity.clone(),
180 id: format!("{:?}", command.id),
181 },
182 ));
183 }
184 Ok(affected)
185 }
186
187 pub fn batch_insert(&self, command: &teaql_core::BatchInsertCommand) -> Result<u64, RepositoryError<E::Error>> {
188 let trace_chain = command.trace_chains.first().cloned().unwrap_or_default();
189 let affected = self.execute_mutation(
190 SqlLogOperation::Insert,
191 &command.entity,
192 self.repository()
193 .compile_batch_insert(command)
194 .map_err(RepositoryError::Runtime)?,
195 trace_chain,
196 )?;
197 Ok(affected)
198 }
199
200 pub fn batch_update(&self, command: &teaql_core::BatchUpdateCommand) -> Result<u64, RepositoryError<E::Error>> {
201 let trace_chain = command.trace_chains.first().cloned().unwrap_or_default();
202 let affected = self.execute_mutation(
203 SqlLogOperation::Update,
204 &command.entity,
205 self.repository()
206 .compile_batch_update(command)
207 .map_err(RepositoryError::Runtime)?,
208 trace_chain,
209 )?;
210 if command.batch_expected_versions.iter().any(|v| v.is_some()) {
211 if affected != command.batch_ids.len() as u64 {
212 return Err(RepositoryError::Runtime(
213 RuntimeError::OptimisticLockConflict {
214 entity: command.entity.clone(),
215 id: "BATCH".to_owned(),
216 },
217 ));
218 }
219 }
220 Ok(affected)
221 }
222
223 pub fn delete(&self, command: &DeleteCommand) -> Result<u64, RepositoryError<E::Error>> {
224 let affected = self.execute_mutation(
225 SqlLogOperation::Delete,
226 &command.entity,
227 self.repository()
228 .compile_delete(command)
229 .map_err(RepositoryError::Runtime)?,
230 command.trace_chain.clone(),
231 )?;
232 if command.expected_version.is_some() && affected == 0 {
233 return Err(RepositoryError::Runtime(
234 RuntimeError::OptimisticLockConflict {
235 entity: command.entity.clone(),
236 id: format!("{:?}", command.id),
237 },
238 ));
239 }
240 Ok(affected)
241 }
242
243 pub fn recover(&self, command: &RecoverCommand) -> Result<u64, RepositoryError<E::Error>> {
244 let affected = self.execute_mutation(
245 SqlLogOperation::Recover,
246 &command.entity,
247 self.repository()
248 .compile_recover(command)
249 .map_err(RepositoryError::Runtime)?,
250 command.trace_chain.clone(),
251 )?;
252 if affected == 0 {
253 return Err(RepositoryError::Runtime(
254 RuntimeError::OptimisticLockConflict {
255 entity: command.entity.clone(),
256 id: format!("{:?}", command.id),
257 },
258 ));
259 }
260 Ok(affected)
261 }
262
263 fn execute_mutation(
264 &self,
265 operation: SqlLogOperation,
266 entity: &str,
267 mut compiled: CompiledQuery,
268 trace_chain: Vec<teaql_core::TraceNode>,
269 ) -> Result<u64, RepositoryError<E::Error>> {
270 let final_comment = self.resolve_final_comment(&trace_chain, None);
271 compiled.comment = final_comment;
272
273 let started_at = SystemTime::now();
274 let started = Instant::now();
275 let affected = self
276 .executor
277 .execute(&compiled)
278 .map_err(RepositoryError::Executor)?;
279 self.log_sql_result(
280 operation,
281 &compiled,
282 started_at,
283 started,
284 None,
285 None,
286 Some(affected),
287 trace_chain,
288 );
289 self.invalidate_aggregation_cache_for(entity);
290 Ok(affected)
291 }
292
293 pub(super) fn log_sql_result(
294 &self,
295 operation: SqlLogOperation,
296 compiled: &CompiledQuery,
297 started_at: SystemTime,
298 started: Instant,
299 result_count: Option<usize>,
300 result_type: Option<String>,
301 affected_rows: Option<u64>,
302 trace_chain: Vec<teaql_core::TraceNode>,
303 ) {
304 self.metadata.context.record_sql_log(
305 operation,
306 compiled,
307 self.dialect.kind(),
308 started_at,
309 SystemTime::now(),
310 started.elapsed(),
311 result_count,
312 result_type,
313 affected_rows,
314 trace_chain,
315 );
316 }
317
318 pub(super) fn invalidate_aggregation_cache_for(&self, entity: &str) {
319 if let Some(cache) = self
320 .metadata
321 .context
322 .get_resource::<Arc<dyn AggregationCacheBackend>>()
323 {
324 invalidate_aggregation_cache_namespace(cache.as_ref(), entity);
325 }
326 if let Some(cache) = self
327 .metadata
328 .context
329 .get_resource::<InMemoryAggregationCache>()
330 {
331 invalidate_aggregation_cache_namespace(cache, entity);
332 }
333 }
334
335 pub(crate) fn resolve_final_comment(&self, trace_chain: &[teaql_core::TraceNode], comment: Option<String>) -> Option<String> {
336 let chain_str = if trace_chain.is_empty() {
337 None
338 } else {
339 let formatted = trace_chain.iter().map(|n| {
340 format!("{}({}): {}", n.entity_type, n.entity_id.map(|id| id.to_string()).unwrap_or_else(|| "pending".to_owned()), n.comment)
341 }).collect::<Vec<_>>().join(" -> ");
342 Some(formatted)
343 };
344
345 let business_comment = chain_str.or(comment);
346 let user_id = self.metadata.context.user_identifier().map(|s| s.to_owned());
347
348 match (user_id, business_comment) {
349 (Some(user), Some(bus)) if !user.is_empty() && !bus.is_empty() => {
350 Some(format!("[{user}] {bus}"))
351 }
352 (Some(user), _) if !user.is_empty() => {
353 Some(format!("[{user}]"))
354 }
355 (_, Some(bus)) if !bus.is_empty() => {
356 Some(bus)
357 }
358 _ => None,
359 }
360 }
361}