Skip to main content

vortex_array/stats/
session.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4//! Session state for stats rewrite rules.
5
6use std::any::Any;
7use std::sync::Arc;
8
9use parking_lot::RwLock;
10use vortex_session::Ref;
11use vortex_session::SessionExt;
12use vortex_session::SessionVar;
13use vortex_utils::aliases::hash_map::HashMap;
14
15use crate::scalar_fn::ScalarFnId;
16use crate::stats::rewrite::StatsRewriteRule;
17use crate::stats::rewrite::StatsRewriteRuleRef;
18
19type StatsRewriteRuleSet = Arc<[StatsRewriteRuleRef]>;
20
21/// Session state for stats rewrite rules.
22#[derive(Debug, Default)]
23pub struct StatsRewriteSession {
24    rules: RwLock<HashMap<ScalarFnId, StatsRewriteRuleSet>>,
25}
26
27impl StatsRewriteSession {
28    /// Register a stats rewrite rule.
29    #[allow(dead_code)]
30    pub(crate) fn register<R: StatsRewriteRule>(&self, rule: R) {
31        self.register_ref(Arc::new(rule));
32    }
33
34    /// Register a shared stats rewrite rule.
35    #[allow(dead_code)]
36    pub(crate) fn register_ref(&self, rule: StatsRewriteRuleRef) {
37        let mut rules = self.rules.write();
38        let rule_id = rule.scalar_fn_id();
39        let mut updated_rules = rules
40            .get(&rule_id)
41            .map(|rules| rules.iter().cloned().collect::<Vec<_>>())
42            .unwrap_or_default();
43        updated_rules.push(rule);
44        rules.insert(rule_id, updated_rules.into());
45    }
46
47    /// Return the rewrite rules registered for `scalar_fn_id`.
48    pub(crate) fn rules_for(&self, scalar_fn_id: ScalarFnId) -> Option<StatsRewriteRuleSet> {
49        self.rules.read().get(&scalar_fn_id).cloned()
50    }
51}
52
53impl SessionVar for StatsRewriteSession {
54    fn as_any(&self) -> &dyn Any {
55        self
56    }
57
58    fn as_any_mut(&mut self) -> &mut dyn Any {
59        self
60    }
61}
62
63/// Extension trait for accessing stats rewrite session data.
64pub(crate) trait StatsRewriteSessionExt: SessionExt {
65    /// Returns the stats rewrite rule registry.
66    fn stats_rewrites(&self) -> Ref<'_, StatsRewriteSession> {
67        self.get::<StatsRewriteSession>()
68    }
69}
70impl<S: SessionExt> StatsRewriteSessionExt for S {}