tx2_core/
query.rs

1use std::collections::{HashSet, HashMap};
2use crate::component::{ComponentStore, ComponentId};
3use crate::entity::EntityId;
4
5#[derive(Debug, Clone)]
6pub enum QueryFilter {
7    All(Vec<ComponentId>),
8    Any(Vec<ComponentId>),
9    None(Vec<ComponentId>),
10}
11
12#[derive(Default)]
13pub struct QueryDescriptor {
14    pub all: Vec<ComponentId>,
15    pub any: Vec<ComponentId>,
16    pub none: Vec<ComponentId>,
17}
18
19pub struct Query {
20    filters: Vec<QueryFilter>,
21    cached_results: Option<HashSet<EntityId>>,
22    dirty: bool,
23}
24
25impl Query {
26    pub fn new(descriptor: QueryDescriptor) -> Self {
27        let mut filters = Vec::new();
28        if !descriptor.all.is_empty() {
29            filters.push(QueryFilter::All(descriptor.all));
30        }
31        if !descriptor.any.is_empty() {
32            filters.push(QueryFilter::Any(descriptor.any));
33        }
34        if !descriptor.none.is_empty() {
35            filters.push(QueryFilter::None(descriptor.none));
36        }
37
38        if filters.is_empty() {
39            panic!("Query must have at least one filter");
40        }
41
42        Self {
43            filters,
44            cached_results: None,
45            dirty: true,
46        }
47    }
48
49    pub fn matches(&self, entity_id: EntityId, store: &ComponentStore) -> bool {
50        for filter in &self.filters {
51            match filter {
52                QueryFilter::All(components) => {
53                    if !components.iter().all(|c| store.has(entity_id, c)) {
54                        return false;
55                    }
56                }
57                QueryFilter::Any(components) => {
58                    if !components.iter().any(|c| store.has(entity_id, c)) {
59                        return false;
60                    }
61                }
62                QueryFilter::None(components) => {
63                    if components.iter().any(|c| store.has(entity_id, c)) {
64                        return false;
65                    }
66                }
67            }
68        }
69        true
70    }
71
72    pub fn execute(&mut self, store: &ComponentStore) -> HashSet<EntityId> {
73        if !self.dirty {
74            if let Some(results) = &self.cached_results {
75                return results.clone();
76            }
77        }
78
79        let candidates = self.get_candidate_entities(store);
80        let mut results = HashSet::new();
81
82        for entity_id in candidates {
83            if self.matches(entity_id, store) {
84                results.insert(entity_id);
85            }
86        }
87
88        self.cached_results = Some(results.clone());
89        self.dirty = false;
90        results
91    }
92
93    fn get_candidate_entities(&self, store: &ComponentStore) -> HashSet<EntityId> {
94        let mut candidates: Option<HashSet<EntityId>> = None;
95
96        // Union entities for any-filters
97        let any_filters: Vec<&QueryFilter> = self.filters.iter().filter(|f| matches!(f, QueryFilter::Any(_))).collect();
98        if !any_filters.is_empty() {
99            let mut any_candidates = HashSet::new();
100            for filter in any_filters {
101                if let QueryFilter::Any(components) = filter {
102                    for component_id in components {
103                        for entity_id in store.get_entities_with_component(component_id) {
104                            any_candidates.insert(entity_id);
105                        }
106                    }
107                }
108            }
109            candidates = Some(any_candidates);
110        }
111
112        // Narrow candidates by all-filters via intersection
113        let all_filters: Vec<&QueryFilter> = self.filters.iter().filter(|f| matches!(f, QueryFilter::All(_))).collect();
114        for filter in all_filters {
115            if let QueryFilter::All(components) = filter {
116                for component_id in components {
117                    let entities = store.get_entities_with_component(component_id);
118                    if let Some(current_candidates) = &mut candidates {
119                        current_candidates.retain(|id| entities.contains(id));
120                    } else {
121                        candidates = Some(entities);
122                    }
123                }
124            }
125        }
126
127        candidates.unwrap_or_else(|| store.get_all_entities())
128    }
129
130    pub fn mark_dirty(&mut self) {
131        self.dirty = true;
132    }
133}
134
135pub struct QueryBuilder {
136    descriptor: QueryDescriptor,
137}
138
139impl QueryBuilder {
140    pub fn new() -> Self {
141        Self {
142            descriptor: QueryDescriptor::default(),
143        }
144    }
145
146    pub fn all(mut self, components: Vec<ComponentId>) -> Self {
147        self.descriptor.all = components;
148        self
149    }
150
151    pub fn any(mut self, components: Vec<ComponentId>) -> Self {
152        self.descriptor.any = components;
153        self
154    }
155
156    pub fn none(mut self, components: Vec<ComponentId>) -> Self {
157        self.descriptor.none = components;
158        self
159    }
160
161    pub fn build(self) -> Query {
162        Query::new(self.descriptor)
163    }
164}
165
166pub struct QueryCache {
167    queries: HashMap<String, Query>,
168}
169
170impl QueryCache {
171    pub fn new() -> Self {
172        Self {
173            queries: HashMap::new(),
174        }
175    }
176
177    pub fn get(&mut self, descriptor: QueryDescriptor) -> &mut Query {
178        let key = self.get_key(&descriptor);
179        self.queries.entry(key).or_insert_with(|| Query::new(descriptor))
180    }
181
182    fn get_key(&self, descriptor: &QueryDescriptor) -> String {
183        let mut parts = Vec::new();
184        if !descriptor.all.is_empty() {
185            let mut sorted = descriptor.all.clone();
186            sorted.sort();
187            parts.push(format!("all:{}", sorted.join(",")));
188        }
189        if !descriptor.any.is_empty() {
190            let mut sorted = descriptor.any.clone();
191            sorted.sort();
192            parts.push(format!("any:{}", sorted.join(",")));
193        }
194        if !descriptor.none.is_empty() {
195            let mut sorted = descriptor.none.clone();
196            sorted.sort();
197            parts.push(format!("none:{}", sorted.join(",")));
198        }
199        parts.join("|")
200    }
201
202    pub fn mark_all_dirty(&mut self) {
203        for query in self.queries.values_mut() {
204            query.mark_dirty();
205        }
206    }
207
208    pub fn mark_dirty_for_component(&mut self, component_id: &str) {
209        for (key, query) in self.queries.iter_mut() {
210            if key.contains(component_id) {
211                query.mark_dirty();
212            }
213        }
214    }
215
216    pub fn clear(&mut self) {
217        self.queries.clear();
218    }
219}