1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime};
7
8use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
9use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
10
11use crate::{
12 CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
13 InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
14 RepositoryBehaviorRegistry, RepositoryRegistry, RuntimeError, local_id_generator,
15 translate_check_result,
16};
17use crate::{EntityRoot, QueryExecutor, RepositoryError};
18
/// Category of a SQL statement as far as SQL logging is concerned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    /// A read-only select query.
    Select,
    /// An insert statement.
    Insert,
    /// An update statement.
    Update,
    /// A delete statement.
    Delete,
    /// A recover statement (exact semantics defined by the code that records it).
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation except [`SqlLogOperation::Select`].
    pub fn is_mutation(self) -> bool {
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}

/// Switches controlling which operation categories are captured into the SQL log.
///
/// The default value disables logging for both categories.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SqlLogOptions {
    /// Capture [`SqlLogOperation::Select`] statements.
    pub select: bool,
    /// Capture every non-select statement.
    pub mutation: bool,
}

impl SqlLogOptions {
    /// Logs selects only.
    pub fn select_only() -> Self {
        Self {
            select: true,
            ..Self::default()
        }
    }

    /// Logs mutations only.
    pub fn mutation_only() -> Self {
        Self {
            mutation: true,
            ..Self::default()
        }
    }

    /// Logs every operation.
    pub fn all() -> Self {
        Self {
            select: true,
            mutation: true,
        }
    }

    /// Reports whether `operation` should be recorded under these options.
    pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
        if operation.is_mutation() {
            self.mutation
        } else {
            self.select
        }
    }
}
74
/// One captured SQL statement, recorded only when the context's SQL log options
/// are enabled for its [`SqlLogOperation`].
#[derive(Debug, Clone, PartialEq)]
pub struct SqlLogEntry {
    /// Whether the statement was a select or one of the mutation kinds.
    pub operation: SqlLogOperation,
    /// The compiled SQL text (`CompiledQuery::sql`); bound values are kept separately in `params`.
    pub sql: String,
    /// Bound parameter values (`CompiledQuery::params`).
    pub params: Vec<Value>,
    /// SQL rendered by `CompiledQuery::debug_sql` for the target database kind
    /// (presumably with parameters inlined — confirm in teaql_sql).
    pub debug_sql: String,
    /// `debug_sql` with line breaks inserted before major clauses and `AND`/`OR`, for display.
    pub pretty_sql: String,
    /// Wall-clock start time, as supplied by the code that executed the statement.
    pub started_at: SystemTime,
    /// Wall-clock end time, as supplied by the code that executed the statement.
    pub ended_at: SystemTime,
    /// Execution time as measured by the caller (not derived from the two timestamps here).
    pub elapsed: Duration,
    /// Number of results reported for the statement, if any; feeds the summary for selects.
    pub result_count: Option<usize>,
    /// Name of the result type, if known; feeds the summary for selects.
    pub result_type: Option<String>,
    /// Rows affected reported for the statement, if any; feeds the summary for mutations.
    pub affected_rows: Option<u64>,
    /// Short human-readable outcome such as `"3 x User"`, `"3 rows"` or `"2 rows affected"`.
    pub result_summary: String,
}
90
91pub trait SchemaProvider: Send + Sync {
92 fn ensure_schema<'a>(
93 &'a self,
94 ctx: &'a UserContext,
95 ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
96}
97
/// Per-user execution context.
///
/// It bundles:
/// - optional pluggable registries and services;
/// - the active language;
/// - typed and named resources and local values;
/// - the entity root;
/// - SQL logging state.
///
/// NOTE: field declaration order is also drop order.
#[derive(Default)]
pub struct UserContext {
    // Entity metadata consulted by `entity`, `all_entities` and `require_entity`.
    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
    // Optional registry consulted (alongside metadata) by `has_repository`.
    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
    // Optional per-entity `RepositoryBehavior` lookup used by `repository_behavior`.
    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
    // Optional per-entity checkers used by `has_checker` / `check_and_fix_record_at`.
    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
    // Optional receiver for `send_event`; events are silently dropped when absent.
    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
    // Optional id source; `next_id` falls back to `local_id_generator()` when absent.
    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
    // Provider invoked by `ensure_schema`; its absence is reported as an error there.
    schema_provider: Option<Box<dyn SchemaProvider>>,
    // Language used when translating check results.
    language: Language,
    // Resources keyed by their Rust type (at most one value per type).
    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    // Resources keyed by name; the stored type is only checked on retrieval.
    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
    // Free-form key/value scratch values for this context.
    locals: BTreeMap<String, Value>,
    // Graphs installed via `set_initial_graphs`.
    pub(crate) initial_graphs: Vec<GraphNode>,
    // Tracks pending entity changes; `commit_changes` flushes its current change set.
    entity_root: EntityRoot,
    // Which SQL operation categories get recorded.
    sql_log_options: SqlLogOptions,
    // Captured SQL log; behind a `Mutex` so `record_sql_log(&self)` can append to it.
    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
}
116
117impl UserContext {
118 pub fn new() -> Self {
119 Self::default()
120 }
121
122 pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
123 module.apply_to(&mut self);
124 self
125 }
126
127 pub fn entity_root(&self) -> EntityRoot {
128 self.entity_root.clone()
129 }
130
131 pub fn initial_graphs(&self) -> &[GraphNode] {
132 &self.initial_graphs
133 }
134
135 pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
136 self.initial_graphs = graphs;
137 }
138
139 pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
140 self.metadata = Some(Box::new(metadata));
141 self
142 }
143
144 pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
145 self.metadata = Some(Box::new(metadata));
146 }
147
148 pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
149 self.repository_registry = Some(Box::new(registry));
150 self
151 }
152
153 pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
154 self.repository_registry = Some(Box::new(registry));
155 }
156
157 pub fn with_repository_behavior_registry(
158 mut self,
159 registry: impl RepositoryBehaviorRegistry + 'static,
160 ) -> Self {
161 self.repository_behavior_registry = Some(Box::new(registry));
162 self
163 }
164
165 pub fn set_repository_behavior_registry(
166 &mut self,
167 registry: impl RepositoryBehaviorRegistry + 'static,
168 ) {
169 self.repository_behavior_registry = Some(Box::new(registry));
170 }
171
172 pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
173 self.checker_registry = Some(Box::new(registry));
174 self
175 }
176
177 pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
178 self.checker_registry = Some(Box::new(registry));
179 }
180
181 pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
182 self.event_sink = Some(Box::new(sink));
183 self
184 }
185
186 pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
187 self.event_sink = Some(Box::new(sink));
188 }
189
190 pub fn with_internal_id_generator(
191 mut self,
192 generator: impl InternalIdGenerator + 'static,
193 ) -> Self {
194 self.internal_id_generator = Some(Box::new(generator));
195 self
196 }
197
198 pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
199 self.internal_id_generator = Some(Box::new(generator));
200 }
201
202 pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
203 self.schema_provider = Some(Box::new(provider));
204 self
205 }
206
207 pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
208 self.schema_provider = Some(Box::new(provider));
209 }
210
211 pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
212 let provider = self
213 .schema_provider
214 .as_ref()
215 .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
216 provider.ensure_schema(self).await
217 }
218
219 pub fn with_language(mut self, language: Language) -> Self {
220 self.language = language;
221 self
222 }
223
224 pub fn set_language(&mut self, language: Language) {
225 self.language = language;
226 }
227
228 pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
229 self.sql_log_options = options;
230 self
231 }
232
233 pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
234 self.sql_log_options = options;
235 }
236
237 pub fn enable_select_sql_log(&mut self) {
238 self.sql_log_options.select = true;
239 }
240
241 pub fn enable_mutation_sql_log(&mut self) {
242 self.sql_log_options.mutation = true;
243 }
244
245 pub fn enable_all_sql_log(&mut self) {
246 self.sql_log_options = SqlLogOptions::all();
247 }
248
249 pub fn disable_sql_log(&mut self) {
250 self.sql_log_options = SqlLogOptions::default();
251 self.clear_sql_logs();
252 }
253
254 pub fn sql_log_options(&self) -> SqlLogOptions {
255 self.sql_log_options
256 }
257
258 pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
259 self.sql_log_entries
260 .lock()
261 .map(|entries| entries.clone())
262 .unwrap_or_default()
263 }
264
265 pub fn clear_sql_logs(&self) {
266 if let Ok(mut entries) = self.sql_log_entries.lock() {
267 entries.clear();
268 }
269 }
270
271 pub(crate) fn record_sql_log(
272 &self,
273 operation: SqlLogOperation,
274 query: &CompiledQuery,
275 database_kind: DatabaseKind,
276 started_at: SystemTime,
277 ended_at: SystemTime,
278 elapsed: Duration,
279 result_count: Option<usize>,
280 result_type: Option<String>,
281 affected_rows: Option<u64>,
282 ) {
283 if !self.sql_log_options.enabled_for(operation) {
284 return;
285 }
286 let debug_sql = query.debug_sql(database_kind);
287 if let Ok(mut entries) = self.sql_log_entries.lock() {
288 entries.push(SqlLogEntry {
289 operation,
290 sql: query.sql.clone(),
291 params: query.params.clone(),
292 pretty_sql: pretty_sql(&debug_sql),
293 debug_sql,
294 started_at,
295 ended_at,
296 elapsed,
297 result_summary: sql_result_summary(
298 operation,
299 result_count,
300 result_type.as_deref(),
301 affected_rows,
302 ),
303 result_count,
304 result_type,
305 affected_rows,
306 });
307 }
308 }
309
310 pub fn language(&self) -> Language {
311 self.language
312 }
313
314 pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
315 let Some(language) = Language::from_code(code) else {
316 return Err(RuntimeError::Language(format!(
317 "unsupported language code: {code}"
318 )));
319 };
320 self.language = language;
321 Ok(())
322 }
323
324 pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
325 self.internal_id_generator
326 .as_ref()
327 .map(|generator| generator.generate_id(entity))
328 .transpose()
329 }
330
331 pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
332 match self.generate_id(entity)? {
333 Some(id) => Ok(id),
334 None => local_id_generator().generate_id(entity),
335 }
336 }
337
338 pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
339 self.metadata
340 .as_ref()
341 .and_then(|metadata| metadata.entity(name))
342 }
343
344 pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
345 self.metadata
346 .as_ref()
347 .map(|metadata| metadata.all_entities())
348 .unwrap_or_default()
349 }
350
351 pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
352 self.entity(name)
353 .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
354 }
355
356 pub fn insert_resource<T>(&mut self, resource: T)
357 where
358 T: Send + Sync + 'static,
359 {
360 self.typed_resources
361 .insert(TypeId::of::<T>(), Box::new(resource));
362 }
363
364 pub fn get_resource<T>(&self) -> Option<&T>
365 where
366 T: Send + Sync + 'static,
367 {
368 self.typed_resources
369 .get(&TypeId::of::<T>())
370 .and_then(|value| value.downcast_ref::<T>())
371 }
372
373 pub fn require_resource<T>(&self) -> Result<&T, ContextError>
374 where
375 T: Send + Sync + 'static,
376 {
377 self.get_resource::<T>()
378 .ok_or(ContextError::MissingTypedResource(
379 std::any::type_name::<T>(),
380 ))
381 }
382
383 pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
384 where
385 T: Send + Sync + 'static,
386 {
387 self.named_resources.insert(name.into(), Box::new(resource));
388 }
389
390 pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
391 where
392 T: Send + Sync + 'static,
393 {
394 self.named_resources
395 .get(name)
396 .and_then(|value| value.downcast_ref::<T>())
397 }
398
399 pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
400 where
401 T: Send + Sync + 'static,
402 {
403 self.get_named_resource::<T>(name)
404 .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
405 }
406
407 pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
408 self.locals.insert(key.into(), value.into());
409 }
410
411 pub fn local(&self, key: &str) -> Option<&Value> {
412 self.locals.get(key)
413 }
414
415 pub fn remove_local(&mut self, key: &str) -> Option<Value> {
416 self.locals.remove(key)
417 }
418
419 pub fn has_repository(&self, entity: &str) -> bool {
420 let in_registry = self
421 .repository_registry
422 .as_ref()
423 .map(|registry| registry.contains(entity))
424 .unwrap_or(false);
425 in_registry || self.entity(entity).is_some()
426 }
427
428 pub fn repository_behavior(
429 &self,
430 entity: &str,
431 ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
432 self.repository_behavior_registry
433 .as_ref()
434 .and_then(|registry| registry.behavior(entity))
435 }
436
437 pub fn has_checker(&self, entity: &str) -> bool {
438 self.checker_registry
439 .as_ref()
440 .and_then(|registry| registry.checker(entity))
441 .is_some()
442 }
443
444 pub fn check_and_fix_record(
445 &self,
446 entity: &str,
447 record: &mut Record,
448 ) -> Result<(), RuntimeError> {
449 self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
450 }
451
452 pub fn check_and_fix_record_at(
453 &self,
454 entity: &str,
455 record: &mut Record,
456 location: &ObjectLocation,
457 ) -> Result<(), RuntimeError> {
458 let Some(checker) = self
459 .checker_registry
460 .as_ref()
461 .and_then(|registry| registry.checker(entity))
462 else {
463 return Ok(());
464 };
465 let mut results = CheckResults::new();
466 checker.check_and_fix(self, record, location, &mut results);
467 if results.is_empty() {
468 Ok(())
469 } else {
470 self.translate_check_results(&mut results);
471 Err(RuntimeError::Check(results))
472 }
473 }
474
475 pub fn translate_check_results(&self, results: &mut CheckResults) {
476 for result in results {
477 result.message = Some(translate_check_result(self.language, result));
478 }
479 }
480
481 pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
482 let Some(sink) = self.event_sink.as_ref() else {
483 return Ok(());
484 };
485 sink.on_event(self, &event)
486 }
487
488 pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
489 where
490 D: SqlDialect + Send + Sync + 'static,
491 E: QueryExecutor + Send + Sync + 'static,
492 {
493 let dialect = self.require_resource::<D>().map_err(|err| {
494 RepositoryError::Runtime(RuntimeError::Graph(format!(
495 "cannot commit changes without dialect: {err}"
496 )))
497 })?;
498 let executor = self.require_resource::<E>().map_err(|err| {
499 RepositoryError::Runtime(RuntimeError::Graph(format!(
500 "cannot commit changes without executor: {err}"
501 )))
502 })?;
503 let change_set = self.entity_root.current_change_set();
504
505 for (key, changes) in change_set.changes() {
506 if changes.is_empty() {
507 continue;
508 }
509 let entity = self
510 .require_entity(&key.entity)
511 .map_err(RepositoryError::Runtime)?;
512 let mut command = UpdateCommand::new(&key.entity, key.id.clone());
513 for (field, value) in changes {
514 command = command.value(field.clone(), value.clone());
515 }
516 let query = dialect
517 .compile_update(entity, &command)
518 .map_err(RuntimeError::from)
519 .map_err(RepositoryError::Runtime)?;
520 executor
521 .execute(&query)
522 .map_err(RepositoryError::Executor)?;
523 }
524
525 self.entity_root.clear_current_change_set();
526 Ok(())
527 }
528}
529
530fn sql_result_summary(
531 operation: SqlLogOperation,
532 result_count: Option<usize>,
533 result_type: Option<&str>,
534 affected_rows: Option<u64>,
535) -> String {
536 match operation {
537 SqlLogOperation::Select => {
538 let count = result_count.unwrap_or(0);
539 match result_type {
540 Some(result_type) => format!("{count} x {result_type}"),
541 None => format!("{count} rows"),
542 }
543 }
544 _ => format!("{} rows affected", affected_rows.unwrap_or(0)),
545 }
546}
547
/// Reformats single-line SQL for display.
///
/// Each major clause keyword starts a new line, and `AND`/`OR` connectives move
/// onto their own indented lines. This is plain case-sensitive substring
/// replacement on space-delimited keywords, so it is meant for log output only.
fn pretty_sql(sql: &str) -> String {
    const CLAUSE_KEYWORDS: [&str; 8] = [
        " FROM ",
        " WHERE ",
        " GROUP BY ",
        " HAVING ",
        " ORDER BY ",
        " LIMIT ",
        " OFFSET ",
        " RETURNING ",
    ];

    let with_clause_breaks = CLAUSE_KEYWORDS
        .iter()
        .copied()
        .fold(sql.to_owned(), |text, keyword| {
            // Drop the leading space so the keyword begins the new line.
            let line_start = format!("\n{}", keyword.trim_start());
            text.replace(keyword, &line_start)
        });

    let with_and_breaks = with_clause_breaks.replace(" AND ", "\n AND ");
    with_and_breaks.replace(" OR ", "\n OR ")
}