1use std::any::{Any, TypeId};
2use std::collections::{BTreeMap, HashMap};
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Mutex;
6use std::time::{Duration, SystemTime};
7
8use teaql_core::{EntityDescriptor, Record, UpdateCommand, Value};
9use teaql_sql::{CompiledQuery, DatabaseKind, SqlDialect};
10
11use crate::{
12 CheckResults, CheckerRegistry, ContextError, EntityEvent, EntityEventSink, GraphNode,
13 InternalIdGenerator, Language, MetadataStore, ObjectLocation, RepositoryBehavior,
14 RepositoryBehaviorRegistry, RepositoryRegistry, RequestPolicy, RuntimeError,
15 local_id_generator, translate_check_result,
16};
17use crate::{EntityRoot, QueryExecutor, RepositoryError};
18
/// Category of SQL statement that a log entry describes.
///
/// `Select` is the only read category; every other variant is treated as a
/// mutation by [`SqlLogOperation::is_mutation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SqlLogOperation {
    Select,
    Insert,
    Update,
    Delete,
    Recover,
}

impl SqlLogOperation {
    /// Returns `true` only for [`SqlLogOperation::Select`].
    pub fn is_select(self) -> bool {
        self == Self::Select
    }

    /// Returns `true` for every operation other than `Select`.
    pub fn is_mutation(self) -> bool {
        // Spelled out exhaustively so a newly added variant has to be
        // classified here explicitly.
        match self {
            Self::Select => false,
            Self::Insert | Self::Update | Self::Delete | Self::Recover => true,
        }
    }
}
37
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
39pub struct SqlLogOptions {
40 pub select: bool,
41 pub mutation: bool,
42}
43
44impl SqlLogOptions {
45 pub fn disabled() -> Self {
46 Self {
47 select: false,
48 mutation: false,
49 }
50 }
51
52 pub fn select_only() -> Self {
53 Self {
54 select: true,
55 mutation: false,
56 }
57 }
58
59 pub fn mutation_only() -> Self {
60 Self {
61 select: false,
62 mutation: true,
63 }
64 }
65
66 pub fn all() -> Self {
67 Self {
68 select: true,
69 mutation: true,
70 }
71 }
72
73 pub fn enabled_for(self, operation: SqlLogOperation) -> bool {
74 if operation.is_select() {
75 self.select
76 } else {
77 self.mutation
78 }
79 }
80}
81
/// One recorded SQL execution, as stored by `UserContext::record_sql_log`.
#[derive(Debug, Clone, PartialEq)]
pub struct SqlLogEntry {
    /// Kind of statement that was executed.
    pub operation: SqlLogOperation,
    /// Copy of the compiled query's `sql` text.
    pub sql: String,
    /// Copy of the compiled query's `params`.
    pub params: Vec<Value>,
    /// Result of `CompiledQuery::debug_sql` for the database kind in use.
    pub debug_sql: String,
    /// `debug_sql` passed through `pretty_sql` (line breaks before clauses).
    pub pretty_sql: String,
    /// Start timestamp supplied by the caller that recorded the entry.
    pub started_at: SystemTime,
    /// End timestamp supplied by the caller that recorded the entry.
    pub ended_at: SystemTime,
    /// Elapsed time supplied by the caller that recorded the entry.
    pub elapsed: Duration,
    /// Number of results, when the caller reported one.
    pub result_count: Option<usize>,
    /// Name of the result type, when the caller reported one.
    pub result_type: Option<String>,
    /// Number of affected rows, when the caller reported one.
    pub affected_rows: Option<u64>,
    /// Human-readable summary produced by `sql_result_summary`.
    pub result_summary: String,
}
97
/// Pluggable component that `UserContext::ensure_schema` delegates to.
///
/// Implementations must be `Send + Sync` so they can be stored inside a
/// `UserContext`.
pub trait SchemaProvider: Send + Sync {
    /// Runs the provider against `ctx`.
    ///
    /// Returns a boxed, `Send` future that borrows both the provider and the
    /// context for `'a` and resolves to `Ok(())` or a `RuntimeError`.
    fn ensure_schema<'a>(
        &'a self,
        ctx: &'a UserContext,
    ) -> Pin<Box<dyn Future<Output = Result<(), RuntimeError>> + Send + 'a>>;
}
104
/// Runtime context bundling optional pluggable components, resources,
/// context-local values and the SQL log.
///
/// Every pluggable component is optional; the accessor methods treat an
/// unset component as "not available".
pub struct UserContext {
    /// Entity metadata store consulted by `entity` / `all_entities`.
    pub(crate) metadata: Option<Box<dyn MetadataStore>>,
    /// Registry consulted by `has_repository`.
    pub(crate) repository_registry: Option<Box<dyn RepositoryRegistry>>,
    /// Registry consulted by `repository_behavior`.
    pub(crate) repository_behavior_registry: Option<Box<dyn RepositoryBehaviorRegistry>>,
    /// Optional request policy; set and cleared through the policy setters.
    pub(crate) request_policy: Option<Box<dyn RequestPolicy>>,
    /// Registry of per-entity checkers used by `check_and_fix_record_at`.
    pub(crate) checker_registry: Option<Box<dyn CheckerRegistry>>,
    /// Sink that receives events passed to `send_event`.
    pub(crate) event_sink: Option<Box<dyn EntityEventSink>>,
    /// Id generator used by `generate_id`.
    pub(crate) internal_id_generator: Option<Box<dyn InternalIdGenerator>>,
    /// Provider that `ensure_schema` delegates to.
    schema_provider: Option<Box<dyn SchemaProvider>>,
    /// Language used when translating check results.
    language: Language,
    /// Resources keyed by their concrete type.
    typed_resources: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
    /// Resources keyed by a caller-chosen name.
    named_resources: BTreeMap<String, Box<dyn Any + Send + Sync>>,
    /// Context-local key/value pairs.
    locals: BTreeMap<String, Value>,
    /// Graph nodes exposed through `initial_graphs` / `set_initial_graphs`.
    pub(crate) initial_graphs: Vec<GraphNode>,
    /// Entity root whose current change set `commit_changes` flushes.
    entity_root: EntityRoot,
    /// Which SQL operations are recorded in `sql_log_entries`.
    sql_log_options: SqlLogOptions,
    /// Recorded SQL log entries; the mutex allows recording through `&self`.
    sql_log_entries: Mutex<Vec<SqlLogEntry>>,
}
123
124impl Default for UserContext {
125 fn default() -> Self {
126 Self {
127 metadata: None,
128 repository_registry: None,
129 repository_behavior_registry: None,
130 request_policy: None,
131 checker_registry: None,
132 event_sink: None,
133 internal_id_generator: None,
134 schema_provider: None,
135 language: Language::default(),
136 typed_resources: HashMap::new(),
137 named_resources: BTreeMap::new(),
138 locals: BTreeMap::new(),
139 initial_graphs: Vec::new(),
140 entity_root: EntityRoot::default(),
141 sql_log_options: SqlLogOptions::all(),
142 sql_log_entries: Mutex::new(Vec::new()),
143 }
144 }
145}
146
147impl UserContext {
148 pub fn new() -> Self {
149 Self::default()
150 }
151
152 pub fn with_module(mut self, module: crate::RuntimeModule) -> Self {
153 module.apply_to(&mut self);
154 self
155 }
156
157 pub fn entity_root(&self) -> EntityRoot {
158 self.entity_root.clone()
159 }
160
161 pub fn initial_graphs(&self) -> &[GraphNode] {
162 &self.initial_graphs
163 }
164
165 pub fn set_initial_graphs(&mut self, graphs: Vec<GraphNode>) {
166 self.initial_graphs = graphs;
167 }
168
169 pub fn with_metadata(mut self, metadata: impl MetadataStore + 'static) -> Self {
170 self.metadata = Some(Box::new(metadata));
171 self
172 }
173
174 pub fn set_metadata(&mut self, metadata: impl MetadataStore + 'static) {
175 self.metadata = Some(Box::new(metadata));
176 }
177
178 pub fn with_repository_registry(mut self, registry: impl RepositoryRegistry + 'static) -> Self {
179 self.repository_registry = Some(Box::new(registry));
180 self
181 }
182
183 pub fn set_repository_registry(&mut self, registry: impl RepositoryRegistry + 'static) {
184 self.repository_registry = Some(Box::new(registry));
185 }
186
187 pub fn with_repository_behavior_registry(
188 mut self,
189 registry: impl RepositoryBehaviorRegistry + 'static,
190 ) -> Self {
191 self.repository_behavior_registry = Some(Box::new(registry));
192 self
193 }
194
195 pub fn set_repository_behavior_registry(
196 &mut self,
197 registry: impl RepositoryBehaviorRegistry + 'static,
198 ) {
199 self.repository_behavior_registry = Some(Box::new(registry));
200 }
201
202 pub fn with_request_policy(mut self, policy: impl RequestPolicy + 'static) -> Self {
203 self.request_policy = Some(Box::new(policy));
204 self
205 }
206
207 pub fn set_request_policy(&mut self, policy: impl RequestPolicy + 'static) {
208 self.request_policy = Some(Box::new(policy));
209 }
210
211 pub fn clear_request_policy(&mut self) {
212 self.request_policy = None;
213 }
214
215 pub fn with_checker_registry(mut self, registry: impl CheckerRegistry + 'static) -> Self {
216 self.checker_registry = Some(Box::new(registry));
217 self
218 }
219
220 pub fn set_checker_registry(&mut self, registry: impl CheckerRegistry + 'static) {
221 self.checker_registry = Some(Box::new(registry));
222 }
223
224 pub fn with_event_sink(mut self, sink: impl EntityEventSink + 'static) -> Self {
225 self.event_sink = Some(Box::new(sink));
226 self
227 }
228
229 pub fn set_event_sink(&mut self, sink: impl EntityEventSink + 'static) {
230 self.event_sink = Some(Box::new(sink));
231 }
232
233 pub fn with_internal_id_generator(
234 mut self,
235 generator: impl InternalIdGenerator + 'static,
236 ) -> Self {
237 self.internal_id_generator = Some(Box::new(generator));
238 self
239 }
240
241 pub fn set_internal_id_generator(&mut self, generator: impl InternalIdGenerator + 'static) {
242 self.internal_id_generator = Some(Box::new(generator));
243 }
244
245 pub fn with_schema_provider(mut self, provider: impl SchemaProvider + 'static) -> Self {
246 self.schema_provider = Some(Box::new(provider));
247 self
248 }
249
250 pub fn set_schema_provider(&mut self, provider: impl SchemaProvider + 'static) {
251 self.schema_provider = Some(Box::new(provider));
252 }
253
254 pub async fn ensure_schema(&self) -> Result<(), RuntimeError> {
255 let provider = self
256 .schema_provider
257 .as_ref()
258 .ok_or_else(|| RuntimeError::Schema("missing schema provider".to_owned()))?;
259 provider.ensure_schema(self).await
260 }
261
262 pub fn with_language(mut self, language: Language) -> Self {
263 self.language = language;
264 self
265 }
266
267 pub fn set_language(&mut self, language: Language) {
268 self.language = language;
269 }
270
271 pub fn with_sql_log_options(mut self, options: SqlLogOptions) -> Self {
272 self.sql_log_options = options;
273 self
274 }
275
276 pub fn set_sql_log_options(&mut self, options: SqlLogOptions) {
277 self.sql_log_options = options;
278 }
279
280 pub fn enable_select_sql_log(&mut self) {
281 self.sql_log_options.select = true;
282 }
283
284 pub fn enable_mutation_sql_log(&mut self) {
285 self.sql_log_options.mutation = true;
286 }
287
288 pub fn enable_all_sql_log(&mut self) {
289 self.sql_log_options = SqlLogOptions::all();
290 }
291
292 pub fn disable_sql_log(&mut self) {
293 self.sql_log_options = SqlLogOptions::disabled();
294 self.clear_sql_logs();
295 }
296
297 pub fn sql_log_options(&self) -> SqlLogOptions {
298 self.sql_log_options
299 }
300
301 pub fn sql_logs(&self) -> Vec<SqlLogEntry> {
302 self.sql_log_entries
303 .lock()
304 .map(|entries| entries.clone())
305 .unwrap_or_default()
306 }
307
308 pub fn clear_sql_logs(&self) {
309 if let Ok(mut entries) = self.sql_log_entries.lock() {
310 entries.clear();
311 }
312 }
313
314 pub(crate) fn record_sql_log(
315 &self,
316 operation: SqlLogOperation,
317 query: &CompiledQuery,
318 database_kind: DatabaseKind,
319 started_at: SystemTime,
320 ended_at: SystemTime,
321 elapsed: Duration,
322 result_count: Option<usize>,
323 result_type: Option<String>,
324 affected_rows: Option<u64>,
325 ) {
326 if !self.sql_log_options.enabled_for(operation) {
327 return;
328 }
329 let debug_sql = query.debug_sql(database_kind);
330 if let Ok(mut entries) = self.sql_log_entries.lock() {
331 entries.push(SqlLogEntry {
332 operation,
333 sql: query.sql.clone(),
334 params: query.params.clone(),
335 pretty_sql: pretty_sql(&debug_sql),
336 debug_sql,
337 started_at,
338 ended_at,
339 elapsed,
340 result_summary: sql_result_summary(
341 operation,
342 result_count,
343 result_type.as_deref(),
344 affected_rows,
345 ),
346 result_count,
347 result_type,
348 affected_rows,
349 });
350 }
351 }
352
353 pub fn language(&self) -> Language {
354 self.language
355 }
356
357 pub fn set_language_code(&mut self, code: &str) -> Result<(), RuntimeError> {
358 let Some(language) = Language::from_code(code) else {
359 return Err(RuntimeError::Language(format!(
360 "unsupported language code: {code}"
361 )));
362 };
363 self.language = language;
364 Ok(())
365 }
366
367 pub fn generate_id(&self, entity: &str) -> Result<Option<u64>, RuntimeError> {
368 self.internal_id_generator
369 .as_ref()
370 .map(|generator| generator.generate_id(entity))
371 .transpose()
372 }
373
374 pub fn next_id(&self, entity: &str) -> Result<u64, RuntimeError> {
375 match self.generate_id(entity)? {
376 Some(id) => Ok(id),
377 None => local_id_generator().generate_id(entity),
378 }
379 }
380
381 pub fn entity(&self, name: &str) -> Option<&EntityDescriptor> {
382 self.metadata
383 .as_ref()
384 .and_then(|metadata| metadata.entity(name))
385 }
386
387 pub fn all_entities(&self) -> Vec<&EntityDescriptor> {
388 self.metadata
389 .as_ref()
390 .map(|metadata| metadata.all_entities())
391 .unwrap_or_default()
392 }
393
394 pub fn require_entity(&self, name: &str) -> Result<&EntityDescriptor, RuntimeError> {
395 self.entity(name)
396 .ok_or_else(|| RuntimeError::MissingEntity(name.to_owned()))
397 }
398
399 pub fn insert_resource<T>(&mut self, resource: T)
400 where
401 T: Send + Sync + 'static,
402 {
403 self.typed_resources
404 .insert(TypeId::of::<T>(), Box::new(resource));
405 }
406
407 pub fn get_resource<T>(&self) -> Option<&T>
408 where
409 T: Send + Sync + 'static,
410 {
411 self.typed_resources
412 .get(&TypeId::of::<T>())
413 .and_then(|value| value.downcast_ref::<T>())
414 }
415
416 pub fn require_resource<T>(&self) -> Result<&T, ContextError>
417 where
418 T: Send + Sync + 'static,
419 {
420 self.get_resource::<T>()
421 .ok_or(ContextError::MissingTypedResource(
422 std::any::type_name::<T>(),
423 ))
424 }
425
426 pub fn insert_named_resource<T>(&mut self, name: impl Into<String>, resource: T)
427 where
428 T: Send + Sync + 'static,
429 {
430 self.named_resources.insert(name.into(), Box::new(resource));
431 }
432
433 pub fn get_named_resource<T>(&self, name: &str) -> Option<&T>
434 where
435 T: Send + Sync + 'static,
436 {
437 self.named_resources
438 .get(name)
439 .and_then(|value| value.downcast_ref::<T>())
440 }
441
442 pub fn require_named_resource<T>(&self, name: &str) -> Result<&T, ContextError>
443 where
444 T: Send + Sync + 'static,
445 {
446 self.get_named_resource::<T>(name)
447 .ok_or_else(|| ContextError::MissingResource(name.to_owned()))
448 }
449
450 pub fn put_local(&mut self, key: impl Into<String>, value: impl Into<Value>) {
451 self.locals.insert(key.into(), value.into());
452 }
453
454 pub fn local(&self, key: &str) -> Option<&Value> {
455 self.locals.get(key)
456 }
457
458 pub fn remove_local(&mut self, key: &str) -> Option<Value> {
459 self.locals.remove(key)
460 }
461
462 pub fn has_repository(&self, entity: &str) -> bool {
463 let in_registry = self
464 .repository_registry
465 .as_ref()
466 .map(|registry| registry.contains(entity))
467 .unwrap_or(false);
468 in_registry || self.entity(entity).is_some()
469 }
470
471 pub fn repository_behavior(
472 &self,
473 entity: &str,
474 ) -> Option<std::sync::Arc<dyn RepositoryBehavior>> {
475 self.repository_behavior_registry
476 .as_ref()
477 .and_then(|registry| registry.behavior(entity))
478 }
479
480 pub fn has_checker(&self, entity: &str) -> bool {
481 self.checker_registry
482 .as_ref()
483 .and_then(|registry| registry.checker(entity))
484 .is_some()
485 }
486
487 pub fn check_and_fix_record(
488 &self,
489 entity: &str,
490 record: &mut Record,
491 ) -> Result<(), RuntimeError> {
492 self.check_and_fix_record_at(entity, record, &ObjectLocation::root())
493 }
494
495 pub fn check_and_fix_record_at(
496 &self,
497 entity: &str,
498 record: &mut Record,
499 location: &ObjectLocation,
500 ) -> Result<(), RuntimeError> {
501 let Some(checker) = self
502 .checker_registry
503 .as_ref()
504 .and_then(|registry| registry.checker(entity))
505 else {
506 return Ok(());
507 };
508 let mut results = CheckResults::new();
509 checker.check_and_fix(self, record, location, &mut results);
510 if results.is_empty() {
511 Ok(())
512 } else {
513 self.translate_check_results(&mut results);
514 Err(RuntimeError::Check(results))
515 }
516 }
517
518 pub fn translate_check_results(&self, results: &mut CheckResults) {
519 for result in results {
520 result.message = Some(translate_check_result(self.language, result));
521 }
522 }
523
524 pub fn send_event(&self, event: EntityEvent) -> Result<(), RuntimeError> {
525 let Some(sink) = self.event_sink.as_ref() else {
526 return Ok(());
527 };
528 sink.on_event(self, &event)
529 }
530
531 pub fn commit_changes<D, E>(&self) -> Result<(), RepositoryError<E::Error>>
532 where
533 D: SqlDialect + Send + Sync + 'static,
534 E: QueryExecutor + Send + Sync + 'static,
535 {
536 let dialect = self.require_resource::<D>().map_err(|err| {
537 RepositoryError::Runtime(RuntimeError::Graph(format!(
538 "cannot commit changes without dialect: {err}"
539 )))
540 })?;
541 let executor = self.require_resource::<E>().map_err(|err| {
542 RepositoryError::Runtime(RuntimeError::Graph(format!(
543 "cannot commit changes without executor: {err}"
544 )))
545 })?;
546 let change_set = self.entity_root.current_change_set();
547
548 for (key, changes) in change_set.changes() {
549 if changes.is_empty() {
550 continue;
551 }
552 let entity = self
553 .require_entity(&key.entity)
554 .map_err(RepositoryError::Runtime)?;
555 let mut command = UpdateCommand::new(&key.entity, key.id.clone());
556 for (field, value) in changes {
557 command = command.value(field.clone(), value.clone());
558 }
559 let query = dialect
560 .compile_update(entity, &command)
561 .map_err(RuntimeError::from)
562 .map_err(RepositoryError::Runtime)?;
563 executor
564 .execute(&query)
565 .map_err(RepositoryError::Executor)?;
566 }
567
568 self.entity_root.clear_current_change_set();
569 Ok(())
570 }
571}
572
573fn sql_result_summary(
574 operation: SqlLogOperation,
575 result_count: Option<usize>,
576 result_type: Option<&str>,
577 affected_rows: Option<u64>,
578) -> String {
579 match operation {
580 SqlLogOperation::Select => {
581 let count = result_count.unwrap_or(0);
582 match result_type {
583 Some(result_type) => format!("{count} x {result_type}"),
584 None => format!("{count} rows"),
585 }
586 }
587 _ => format!("{} rows affected", affected_rows.unwrap_or(0)),
588 }
589}
590
/// Inserts line breaks before the main SQL clauses and before `AND` / `OR`
/// so a one-line statement becomes easier to read.
///
/// Only text outside single-quoted literals is reformatted. Text between a
/// pair of `'` characters is copied verbatim, so a literal that happens to
/// contain e.g. `" AND "` is not altered. A doubled quote (`''`) inside a
/// literal is handled naturally because it closes and immediately reopens
/// the literal. Keyword matching is case-sensitive (upper-case keywords
/// surrounded by single spaces).
fn pretty_sql(sql: &str) -> String {
    const CLAUSES: [&str; 8] = [
        " FROM ",
        " WHERE ",
        " GROUP BY ",
        " HAVING ",
        " ORDER BY ",
        " LIMIT ",
        " OFFSET ",
        " RETURNING ",
    ];

    // Applies the clause / operator line breaks to one unquoted segment.
    let break_clauses = |segment: &str| -> String {
        let mut text = segment.to_owned();
        for keyword in CLAUSES {
            text = text.replace(keyword, &format!("\n{}", keyword.trim_start()));
        }
        text.replace(" AND ", "\n AND ").replace(" OR ", "\n OR ")
    };

    let mut pretty = String::with_capacity(sql.len() + 16);
    // Splitting on `'` alternates between unquoted (even index) and quoted
    // (odd index) segments; the keywords contain no quote character, so no
    // match can straddle a segment boundary.
    for (index, segment) in sql.split('\'').enumerate() {
        if index > 0 {
            pretty.push('\'');
        }
        if index % 2 == 0 {
            pretty.push_str(&break_clauses(segment));
        } else {
            pretty.push_str(segment);
        }
    }
    pretty
}