1use crate::error::TraceEngineError;
2use crate::parquet::tracing::traits::arrow_schema_to_delta;
3use crate::parquet::utils::register_cloud_logstore_factories;
4use crate::storage::ObjectStore;
5use arrow::array::*;
6use arrow::datatypes::*;
7use arrow_array::RecordBatch;
8use chrono::{DateTime, Duration, Utc};
9use datafusion::logical_expr::{col, lit};
10use datafusion::prelude::SessionContext;
11use deltalake::{DeltaTable, DeltaTableBuilder, TableProperty};
12use std::sync::Arc;
13use tokio::sync::RwLock as AsyncRwLock;
14use tracing::{debug, info, warn};
15use url::Url;
16
/// Name under which the control Delta table is created and registered with DataFusion.
const CONTROL_TABLE_NAME: &str = "_scouter_control";

/// A `processing` claim older than this is treated as abandoned (crashed pod) and may be reclaimed.
const STALE_LOCK_MINUTES: i64 = 30;
22
/// String values persisted in the control table's `status` column.
mod status {
    /// Task is not currently claimed and may be picked up when due.
    pub const IDLE: &str = "idle";
    /// Task is claimed by a pod and being worked on.
    pub const PROCESSING: &str = "processing";
}
28
/// One row of the control table: the persisted coordination state for a
/// single named background task.
#[derive(Debug, Clone)]
pub struct TaskRecord {
    // Unique task identifier (e.g. "optimize", "vacuum" in the tests below).
    pub task_name: String,
    // One of `status::IDLE` / `status::PROCESSING`.
    pub status: String,
    // Pod that last claimed or released this task.
    pub pod_id: String,
    // When the current (or most recent) claim was taken, UTC.
    pub claimed_at: DateTime<Utc>,
    // When the task last finished; `None` while a claim is in flight.
    pub completed_at: Option<DateTime<Utc>>,
    // Earliest time the task should run again.
    pub next_run_at: DateTime<Utc>,
}
39
40fn control_schema() -> Schema {
45 Schema::new(vec![
46 Field::new("task_name", DataType::Utf8, false),
47 Field::new("status", DataType::Utf8, false),
48 Field::new("pod_id", DataType::Utf8, false),
49 Field::new(
50 "claimed_at",
51 DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
52 false,
53 ),
54 Field::new(
55 "completed_at",
56 DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
57 true,
58 ),
59 Field::new(
60 "next_run_at",
61 DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
62 false,
63 ),
64 ])
65}
66
67fn build_task_batch(
68 schema: &SchemaRef,
69 record: &TaskRecord,
70) -> Result<RecordBatch, TraceEngineError> {
71 let task_name = StringArray::from(vec![record.task_name.as_str()]);
72 let status = StringArray::from(vec![record.status.as_str()]);
73 let pod_id = StringArray::from(vec![record.pod_id.as_str()]);
74 let claimed_at = TimestampMicrosecondArray::from(vec![record.claimed_at.timestamp_micros()])
75 .with_timezone("UTC");
76 let completed_at = if let Some(ts) = record.completed_at {
77 TimestampMicrosecondArray::from(vec![Some(ts.timestamp_micros())]).with_timezone("UTC")
78 } else {
79 TimestampMicrosecondArray::from(vec![None::<i64>]).with_timezone("UTC")
80 };
81 let next_run_at = TimestampMicrosecondArray::from(vec![record.next_run_at.timestamp_micros()])
82 .with_timezone("UTC");
83
84 RecordBatch::try_new(
85 schema.clone(),
86 vec![
87 Arc::new(task_name),
88 Arc::new(status),
89 Arc::new(pod_id),
90 Arc::new(claimed_at),
91 Arc::new(completed_at),
92 Arc::new(next_run_at),
93 ],
94 )
95 .map_err(Into::into)
96}
97
/// Returns an identifier for this pod.
///
/// Prefers the `HOSTNAME` env var, then `POD_NAME`; outside a cluster it
/// falls back to a process-unique `local-<pid>` string.
pub fn get_pod_id() -> String {
    for var in ["HOSTNAME", "POD_NAME"] {
        if let Ok(id) = std::env::var(var) {
            return id;
        }
    }
    format!("local-{}", std::process::id())
}
107
/// Coordinates singleton background tasks across pods using a Delta table
/// as the shared source of truth; concurrent claims are resolved via Delta
/// commit (transaction) conflicts.
pub struct ControlTableEngine {
    // Arrow schema of the control table (see `control_schema`).
    schema: SchemaRef,
    // Shared handle to the control Delta table; the write lock serializes
    // claim/release operations within this process.
    // NOTE(review): this field is read and written by the methods below, so
    // the `dead_code` allow looks stale — confirm before removing it.
    #[allow(dead_code)] table: Arc<AsyncRwLock<DeltaTable>>,
    // DataFusion session the control table is (re-)registered with.
    ctx: Arc<SessionContext>,
    // Identity of this pod, recorded on each claim/release.
    pod_id: String,
}
124
125impl ControlTableEngine {
    /// Creates the engine: loads (or creates) the control Delta table and
    /// registers it with the object store's DataFusion session.
    pub async fn new(object_store: &ObjectStore, pod_id: String) -> Result<Self, TraceEngineError> {
        let schema = Arc::new(control_schema());
        let table = build_or_create_control_table(object_store, schema.clone()).await?;
        let ctx = object_store.get_session()?;

        // A freshly created table may not yet yield a provider; registration
        // is deferred — the query paths below re-register before each read.
        if let Ok(provider) = table.table_provider().await {
            ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        } else {
            info!("Empty control table at init — deferring registration until first write");
        }

        Ok(Self {
            schema,
            table: Arc::new(AsyncRwLock::new(table)),
            ctx: Arc::new(ctx),
            pod_id,
        })
    }
148
    /// Attempts to claim `task_name` for this pod.
    ///
    /// Returns `Ok(true)` when the claim was written; `Ok(false)` when the
    /// task is already being processed (and not stale), is not yet due, or
    /// the write lost the optimistic-concurrency race to another pod.
    pub async fn try_claim_task(&self, task_name: &str) -> Result<bool, TraceEngineError> {
        let mut table_guard = self.table.write().await;

        // Pull in commits made by other pods; a brand-new table may have no
        // log yet, so failures are only logged.
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped (new table): {}", e);
        }

        // Re-register so the read below sees the refreshed table state.
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }

        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;

        let now = Utc::now();

        match current {
            Some(record) => {
                if record.status == status::PROCESSING {
                    // Another pod holds the task; only steal the claim when
                    // it is older than STALE_LOCK_MINUTES (presumed dead pod).
                    let stale_threshold = now - Duration::minutes(STALE_LOCK_MINUTES);
                    if record.claimed_at > stale_threshold {
                        debug!(
                            "Task '{}' is being processed by pod '{}' (not stale), skipping",
                            task_name, record.pod_id
                        );
                        return Ok(false);
                    }
                    warn!(
                        "Task '{}' claimed by pod '{}' is stale (claimed_at: {}), reclaiming",
                        task_name, record.pod_id, record.claimed_at
                    );
                }

                // Even a stale task is only reclaimed once its schedule is due.
                if now < record.next_run_at {
                    debug!(
                        "Task '{}' not due until {}, skipping",
                        task_name, record.next_run_at
                    );
                    return Ok(false);
                }

                // next_run_at is preserved; release_task sets the new schedule.
                let claimed = TaskRecord {
                    task_name: task_name.to_string(),
                    status: status::PROCESSING.to_string(),
                    pod_id: self.pod_id.clone(),
                    claimed_at: now,
                    completed_at: None,
                    next_run_at: record.next_run_at,
                };

                // A conflicting Delta transaction means another pod committed
                // first — a normal outcome, not an error.
                match self.write_task_update(&mut table_guard, &claimed).await {
                    Ok(()) => {
                        info!("Successfully claimed task '{}'", task_name);
                        Ok(true)
                    }
                    Err(TraceEngineError::DataTableError(ref e))
                        if e.to_string().contains("Transaction") =>
                    {
                        info!("Lost OCC race for task '{}' to another pod", task_name);
                        Ok(false)
                    }
                    Err(e) => Err(e),
                }
            }
            None => {
                // First sighting of this task: create its row already claimed.
                let claimed = TaskRecord {
                    task_name: task_name.to_string(),
                    status: status::PROCESSING.to_string(),
                    pod_id: self.pod_id.clone(),
                    claimed_at: now,
                    completed_at: None,
                    next_run_at: now, };

                match self.write_task_update(&mut table_guard, &claimed).await {
                    Ok(()) => {
                        info!("Created and claimed new task '{}'", task_name);
                        Ok(true)
                    }
                    Err(TraceEngineError::DataTableError(ref e))
                        if e.to_string().contains("Transaction") =>
                    {
                        info!("Lost OCC race for new task '{}' to another pod", task_name);
                        Ok(false)
                    }
                    Err(e) => Err(e),
                }
            }
        }
    }
259
260 pub async fn release_task(
264 &self,
265 task_name: &str,
266 next_run_interval: Duration,
267 ) -> Result<(), TraceEngineError> {
268 let mut table_guard = self.table.write().await;
269 let now = Utc::now();
270
271 let released = TaskRecord {
272 task_name: task_name.to_string(),
273 status: status::IDLE.to_string(),
274 pod_id: self.pod_id.clone(),
275 claimed_at: now,
276 completed_at: Some(now),
277 next_run_at: now + next_run_interval,
278 };
279
280 self.write_task_update(&mut table_guard, &released).await?;
281
282 info!(
283 "Released task '{}', next run at {}",
284 task_name, released.next_run_at
285 );
286 Ok(())
287 }
288
    /// Releases `task_name` after a failed run, leaving `next_run_at`
    /// unchanged so the failure does not push the schedule out.
    pub async fn release_task_on_failure(&self, task_name: &str) -> Result<(), TraceEngineError> {
        let mut table_guard = self.table.write().await;

        // Refresh from other pods' commits before reading current state.
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped: {}", e);
        }

        // Re-register so the read below sees the refreshed table state.
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }

        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;

        let now = Utc::now();
        // Preserve the existing schedule; if the row vanished, retry "now".
        let next_run = current.map(|r| r.next_run_at).unwrap_or(now);

        let released = TaskRecord {
            task_name: task_name.to_string(),
            status: status::IDLE.to_string(),
            pod_id: self.pod_id.clone(),
            claimed_at: now,
            completed_at: Some(now),
            next_run_at: next_run,
        };

        self.write_task_update(&mut table_guard, &released).await?;

        warn!(
            "Released task '{}' after failure, next_run_at unchanged: {}",
            task_name, next_run
        );
        Ok(())
    }
328
    /// Reports whether `task_name` could be claimed right now: its schedule
    /// has elapsed, its PROCESSING claim has gone stale, or it has never
    /// been recorded at all.
    pub async fn is_task_due(&self, task_name: &str) -> Result<bool, TraceEngineError> {
        let mut table_guard = self.table.write().await;

        // Refresh from other pods' commits; failures on a new table are benign.
        if let Err(e) = table_guard.update_incremental(None).await {
            debug!("Control table update skipped: {}", e);
        }

        // Re-register so the read below sees the refreshed table state.
        let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
        if let Ok(provider) = table_guard.table_provider().await {
            self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
        }

        let current = self
            .read_task(&table_guard_to_ctx(&self.ctx), task_name)
            .await?;

        let now = Utc::now();
        match current {
            Some(record) => {
                if record.status == status::PROCESSING {
                    // In-flight claims are only "due" once considered stale.
                    let stale_threshold = now - Duration::minutes(STALE_LOCK_MINUTES);
                    Ok(record.claimed_at <= stale_threshold)
                } else {
                    Ok(now >= record.next_run_at)
                }
            }
            // Unknown task: treat as due so it gets created on first claim.
            None => Ok(true),
        }
    }
361
    /// Reads the row for `task_name` from the registered control table.
    ///
    /// Returns `None` when the table has not been registered yet or no row
    /// matches. Only row 0 of the first non-empty batch is decoded —
    /// `task_name` is assumed unique (write path deletes before appending).
    async fn read_task(
        &self,
        ctx: &SessionContext,
        task_name: &str,
    ) -> Result<Option<TaskRecord>, TraceEngineError> {
        // Before the first successful write the table may never have been
        // registered with this session.
        let table_exists = ctx.table_exist(CONTROL_TABLE_NAME)?;
        if !table_exists {
            return Ok(None);
        }

        let df = ctx
            .table(CONTROL_TABLE_NAME)
            .await
            .map_err(TraceEngineError::DatafusionError)?;

        let df = df
            .filter(col("task_name").eq(lit(task_name)))
            .map_err(TraceEngineError::DatafusionError)?;

        let batches = df
            .collect()
            .await
            .map_err(TraceEngineError::DatafusionError)?;

        for batch in &batches {
            if batch.num_rows() == 0 {
                continue;
            }

            // Decodes column `col_name` of row 0 as a string: cast to Utf8
            // first, then downcast to StringArray.
            let get_string = |col_name: &'static str| -> Result<String, TraceEngineError> {
                let col = batch
                    .column_by_name(col_name)
                    .ok_or(TraceEngineError::DowncastError(col_name))?;
                let casted = arrow::compute::cast(col, &DataType::Utf8)
                    .map_err(TraceEngineError::ArrowError)?;
                let arr = casted
                    .as_any()
                    .downcast_ref::<StringArray>()
                    .ok_or(TraceEngineError::DowncastError(col_name))?;
                Ok(arr.value(0).to_string())
            };

            // Decodes column `col_name` of row 0 as an optional UTC
            // timestamp (microsecond precision, matching `control_schema`).
            let get_timestamp =
                |col_name: &'static str| -> Result<Option<DateTime<Utc>>, TraceEngineError> {
                    let col = batch
                        .column_by_name(col_name)
                        .ok_or(TraceEngineError::DowncastError(col_name))?;
                    if col.is_null(0) {
                        return Ok(None);
                    }
                    let arr = col
                        .as_any()
                        .downcast_ref::<TimestampMicrosecondArray>()
                        .ok_or(TraceEngineError::DowncastError(col_name))?;
                    Ok(DateTime::from_timestamp_micros(arr.value(0)))
                };

            let task_name_val = get_string("task_name")?;
            let status_val = get_string("status")?;
            let pod_id_val = get_string("pod_id")?;
            // NOTE(review): falling back to `Utc::now()` on these non-null
            // columns silently masks nulls/out-of-range values — confirm
            // this is intended rather than surfacing an error.
            let claimed_at = get_timestamp("claimed_at")?.unwrap_or_else(Utc::now);
            let completed_at = get_timestamp("completed_at")?;
            let next_run_at = get_timestamp("next_run_at")?.unwrap_or_else(Utc::now);

            return Ok(Some(TaskRecord {
                task_name: task_name_val,
                status: status_val,
                pod_id: pod_id_val,
                claimed_at,
                completed_at,
                next_run_at,
            }));
        }

        Ok(None)
    }
442
443 async fn write_task_update(
449 &self,
450 table_guard: &mut DeltaTable,
451 record: &TaskRecord,
452 ) -> Result<(), TraceEngineError> {
453 let batch = build_task_batch(&self.schema, record)?;
454
455 debug_assert!(
461 record
462 .task_name
463 .chars()
464 .all(|c| c.is_alphanumeric() || c == '_'),
465 "task_name must be alphanumeric + underscore, got: {}",
466 record.task_name
467 );
468 let predicate = format!("task_name = '{}'", record.task_name);
469 let delete_result = table_guard.clone().delete().with_predicate(predicate).await;
470
471 match delete_result {
472 Ok((updated_table, _metrics)) => {
473 let updated_table = updated_table
475 .write(vec![batch])
476 .with_save_mode(deltalake::protocol::SaveMode::Append)
477 .await?;
478
479 let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
480 if let Ok(provider) = updated_table.table_provider().await {
481 self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
482 }
483 *table_guard = updated_table;
484 }
485 Err(e) => {
486 let err_msg = e.to_string();
487 if !err_msg.contains("No data") && !err_msg.contains("empty") {
491 warn!(
492 "Delete before write_task_update failed unexpectedly: {}",
493 err_msg
494 );
495 return Err(TraceEngineError::DataTableError(e));
496 }
497
498 let updated_table = table_guard
499 .clone()
500 .write(vec![batch])
501 .with_save_mode(deltalake::protocol::SaveMode::Append)
502 .await?;
503
504 let _ = self.ctx.deregister_table(CONTROL_TABLE_NAME);
505 if let Ok(provider) = updated_table.table_provider().await {
506 self.ctx.register_table(CONTROL_TABLE_NAME, provider)?;
507 }
508 *table_guard = updated_table;
509 }
510 }
511
512 Ok(())
513 }
514}
515
516fn table_guard_to_ctx(ctx: &Arc<SessionContext>) -> SessionContext {
518 ctx.as_ref().clone()
519}
520
/// Loads the control Delta table under the object store's base URL, creating
/// it (with `schema` and a checkpoint interval of 5) when it does not exist.
async fn build_or_create_control_table(
    object_store: &ObjectStore,
    schema: SchemaRef,
) -> Result<DeltaTable, TraceEngineError> {
    // Ensure cloud logstore factories are registered before any builder use.
    register_cloud_logstore_factories();

    let base_url = object_store.get_base_url()?;
    let control_url = append_path_to_url(&base_url, CONTROL_TABLE_NAME)?;

    // Log only scheme + last path segment to avoid leaking full bucket paths.
    info!(
        "Loading control table [{}://.../{} ]",
        control_url.scheme(),
        control_url
            .path_segments()
            .and_then(|mut s| s.next_back())
            .unwrap_or(CONTROL_TABLE_NAME)
    );

    let store = object_store.as_dyn_object_store();

    // Probe for an existing table. Local filesystem: check for a _delta_log
    // directory (creating the table directory if missing). Remote stores:
    // probe by attempting a load.
    // NOTE(review): the remote branch loads the table once to probe and then
    // again below — consider reusing the first load if this path is hot.
    let is_delta_table = if control_url.scheme() == "file" {
        if let Ok(path) = control_url.to_file_path() {
            if !path.exists() {
                info!("Creating directory for control table: {:?}", path);
                std::fs::create_dir_all(&path)?;
            }
            path.join("_delta_log").exists()
        } else {
            false
        }
    } else {
        match DeltaTableBuilder::from_url(control_url.clone()) {
            Ok(builder) => builder
                .with_storage_backend(store.clone(), control_url.clone())
                .load()
                .await
                .is_ok(),
            Err(_) => false,
        }
    };

    if is_delta_table {
        info!(
            "Loaded existing control table [{}://.../{} ]",
            control_url.scheme(),
            control_url
                .path_segments()
                .and_then(|mut s| s.next_back())
                .unwrap_or(CONTROL_TABLE_NAME)
        );
        let table = DeltaTableBuilder::from_url(control_url.clone())?
            .with_storage_backend(store, control_url)
            .load()
            .await?;
        Ok(table)
    } else {
        info!("Creating new control table");
        let table = DeltaTableBuilder::from_url(control_url.clone())?
            .with_storage_backend(store, control_url)
            .build()?;

        // Convert the Arrow schema into Delta column definitions.
        let delta_fields = arrow_schema_to_delta(&schema);

        table
            .create()
            .with_table_name(CONTROL_TABLE_NAME)
            .with_columns(delta_fields)
            .with_configuration_property(TableProperty::CheckpointInterval, Some("5"))
            .await
            .map_err(Into::into)
    }
}
596
597fn append_path_to_url(base: &Url, segment: &str) -> Result<Url, TraceEngineError> {
599 let mut url = base.clone();
600 if !url.path().ends_with('/') {
602 url.set_path(&format!("{}/", url.path()));
603 }
604 url = url.join(segment)?;
605 Ok(url)
606}
607
#[cfg(test)]
mod tests {
    use super::*;
    use scouter_settings::ObjectStorageSettings;

    // NOTE(review): these tests share the default local storage root and
    // delete it in `cleanup()`; they would race if cargo runs them in
    // parallel — confirm they are serialized (or use per-test roots).

    // Builds an ObjectStore from the given settings; panics on failure
    // (acceptable in tests).
    fn make_test_object_store(storage_settings: &ObjectStorageSettings) -> ObjectStore {
        ObjectStore::new(storage_settings).unwrap()
    }

    // Best-effort removal of the default storage root so each test starts
    // with a fresh control table.
    fn cleanup() {
        let storage_settings = ObjectStorageSettings::default();
        let current_dir = std::env::current_dir().unwrap();
        let storage_path = current_dir.join(storage_settings.storage_root());
        if storage_path.exists() {
            let _ = std::fs::remove_dir_all(storage_path);
        }
    }

    // A task that has never been recorded should report as due.
    #[tokio::test]
    async fn test_control_table_init() -> Result<(), TraceEngineError> {
        cleanup();

        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;

        let due = engine.is_task_due("optimize").await?;
        assert!(due, "New task should be due (never run before)");

        cleanup();
        Ok(())
    }

    // Claiming twice must fail the second time; after release with a 1h
    // interval the task is not due.
    #[tokio::test]
    async fn test_claim_and_release() -> Result<(), TraceEngineError> {
        cleanup();

        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;

        let claimed = engine.try_claim_task("optimize").await?;
        assert!(claimed, "First claim should succeed");

        let claimed_again = engine.try_claim_task("optimize").await?;
        assert!(
            !claimed_again,
            "Second claim should fail (already processing)"
        );

        engine.release_task("optimize", Duration::hours(1)).await?;

        let due = engine.is_task_due("optimize").await?;
        assert!(!due, "Task should not be due yet");

        cleanup();
        Ok(())
    }

    // A 0-second release interval makes the task immediately due and
    // re-claimable; also exercises the failure-release path.
    #[tokio::test]
    async fn test_claim_release_then_due() -> Result<(), TraceEngineError> {
        cleanup();

        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;

        let claimed = engine.try_claim_task("vacuum").await?;
        assert!(claimed);

        engine.release_task("vacuum", Duration::seconds(0)).await?;

        let due = engine.is_task_due("vacuum").await?;
        assert!(due, "Task should be due after 0-second interval");

        let claimed = engine.try_claim_task("vacuum").await?;
        assert!(claimed, "Task should be claimable after release");

        engine.release_task_on_failure("vacuum").await?;

        cleanup();
        Ok(())
    }

    // Different task names are claimed and released independently.
    #[tokio::test]
    async fn test_multiple_tasks() -> Result<(), TraceEngineError> {
        cleanup();

        let settings = ObjectStorageSettings::default();
        let object_store = make_test_object_store(&settings);
        let engine = ControlTableEngine::new(&object_store, "pod-1".to_string()).await?;

        let claimed_opt = engine.try_claim_task("optimize").await?;
        let claimed_vac = engine.try_claim_task("vacuum").await?;
        assert!(claimed_opt, "Optimize claim should succeed");
        assert!(claimed_vac, "Vacuum claim should succeed");

        engine.release_task("optimize", Duration::hours(24)).await?;
        engine.release_task("vacuum", Duration::hours(168)).await?;

        cleanup();
        Ok(())
    }
}