1use std::collections::{HashMap, HashSet};
2use crate::world::World;
3use crate::error::{SystemErrorContext, SystemErrorHandler, SystemErrorStrategy, default_error_handler};
4use std::sync::{Arc, Mutex};
5
/// Identifier under which a system is registered with the scheduler and
/// referred to from other systems' `run_before` / `run_after` sets.
pub type SystemId = String;
7
/// Phase of execution that a system can be registered for.
///
/// The scheduler keeps one execution order per phase, and the phase being
/// executed is handed to every system through its context.
///
/// NOTE(review): when each phase is actually executed is decided by the
/// caller of the scheduler and is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SystemPhase {
    Init,
    FixedUpdate,
    Update,
    LateUpdate,
    Cleanup,
}
16
17pub struct SystemContext<'a> {
18 pub world: &'a mut World,
19 pub delta_time: f64,
20 pub time: f64,
21 pub phase: SystemPhase,
22}
23
24pub trait SystemFn: Send + Sync {
25 fn run(&mut self, ctx: SystemContext) -> Result<(), String>;
26}
27
28impl<F> SystemFn for F
29where
30 F: FnMut(SystemContext) -> Result<(), String> + Send + Sync,
31{
32 fn run(&mut self, ctx: SystemContext) -> Result<(), String> {
33 (self)(ctx)
34 }
35}
36
37pub trait InfallibleSystemFn: Send + Sync {
39 fn run_infallible(&mut self, ctx: SystemContext);
40}
41
42impl<F> InfallibleSystemFn for F
43where
44 F: FnMut(SystemContext) + Send + Sync,
45{
46 fn run_infallible(&mut self, ctx: SystemContext) {
47 (self)(ctx);
48 }
49}
50
51struct SystemFnWrapper {
53 inner: Box<dyn SystemFn>,
54}
55
56impl SystemFnWrapper {
57 fn from_fallible(f: Box<dyn SystemFn>) -> Self {
58 Self { inner: f }
59 }
60
61 }
65
66pub struct System {
67 pub id: SystemId,
68 pub name: String,
69 pub phases: HashSet<SystemPhase>,
70 pub priority: i32,
71 pub run_before: HashSet<SystemId>,
72 pub run_after: HashSet<SystemId>,
73 pub enabled: bool,
74 pub consecutive_failures: u32,
75 pub on_error: Option<SystemErrorHandler>,
76 fn_ptr: Box<dyn SystemFn>,
77}
78
79impl System {
80 pub fn new(
81 id: SystemId,
82 name: String,
83 phases: HashSet<SystemPhase>,
84 priority: i32,
85 run_before: HashSet<SystemId>,
86 run_after: HashSet<SystemId>,
87 func: Box<dyn SystemFn>,
88 ) -> Self {
89 Self {
90 id,
91 name,
92 phases,
93 priority,
94 run_before,
95 run_after,
96 enabled: true,
97 consecutive_failures: 0,
98 on_error: None,
99 fn_ptr: func,
100 }
101 }
102
103 pub fn with_error_handler(mut self, handler: SystemErrorHandler) -> Self {
104 self.on_error = Some(handler);
105 self
106 }
107
108 pub fn run(&mut self, ctx: SystemContext) {
109 if !self.enabled {
110 return;
111 }
112
113 match self.fn_ptr.run(ctx) {
114 Ok(_) => {
115 self.consecutive_failures = 0;
116 }
117 Err(e) => {
118 self.consecutive_failures += 1;
119 let phase_str = match self.phases.iter().next().unwrap_or(&SystemPhase::Update) {
120 SystemPhase::Init => "init",
121 SystemPhase::FixedUpdate => "fixedUpdate",
122 SystemPhase::Update => "update",
123 SystemPhase::LateUpdate => "lateUpdate",
124 SystemPhase::Cleanup => "cleanup",
125 }.to_string(); let error_ctx = SystemErrorContext {
128 system_id: self.id.clone(),
129 error: e,
130 phase: phase_str,
131 consecutive_failures: self.consecutive_failures,
132 };
133
134 let strategy = if let Some(handler) = self.on_error {
135 handler(&error_ctx)
136 } else {
137 default_error_handler(&error_ctx)
138 };
139
140 match strategy {
141 SystemErrorStrategy::Disable => self.enabled = false,
142 SystemErrorStrategy::Ignore => {},
143 SystemErrorStrategy::Retry => {
144 }
147 }
148 }
149 }
150 }
151}
152
153pub struct SystemScheduler {
154 systems: HashMap<SystemId, Arc<Mutex<System>>>,
155 execution_order: HashMap<SystemPhase, Vec<SystemId>>,
156 dirty: bool,
157}
158
159impl SystemScheduler {
160 pub fn new() -> Self {
161 Self {
162 systems: HashMap::new(),
163 execution_order: HashMap::new(),
164 dirty: true,
165 }
166 }
167
168 pub fn add(&mut self, system: System) {
169 if self.systems.contains_key(&system.id) {
170 panic!("System {} already exists", system.id);
171 }
172 self.systems.insert(system.id.clone(), Arc::new(Mutex::new(system)));
173 self.dirty = true;
174 }
175
176 pub fn remove(&mut self, system_id: &str) -> bool {
177 if self.systems.remove(system_id).is_some() {
178 self.dirty = true;
179 return true;
180 }
181 false
182 }
183
184 pub fn execute_phase(&mut self, phase: SystemPhase, world: &mut World, delta_time: f64, time: f64) {
185 if self.dirty {
186 self.recompute_execution_order();
187 }
188
189 if let Some(system_ids) = self.execution_order.get(&phase) {
190 let ids = system_ids.clone();
191
192 for system_id in ids {
193 if let Some(system_arc) = self.systems.get(&system_id) {
194 let mut system = system_arc.lock().unwrap();
195 let ctx = SystemContext {
196 world,
197 delta_time,
198 time,
199 phase,
200 };
201 system.run(ctx);
202 }
203 }
204 }
205 }
206
207 fn recompute_execution_order(&mut self) {
208 self.execution_order.clear();
209 let phases = [
210 SystemPhase::Init,
211 SystemPhase::FixedUpdate,
212 SystemPhase::Update,
213 SystemPhase::LateUpdate,
214 SystemPhase::Cleanup,
215 ];
216
217 for phase in phases {
218 let phase_systems: Vec<Arc<Mutex<System>>> = self.systems.values()
219 .filter(|s| s.lock().unwrap().phases.contains(&phase))
220 .cloned()
221 .collect();
222
223 let sorted = self.topological_sort(phase_systems);
224 self.execution_order.insert(phase, sorted);
225 }
226
227 self.dirty = false;
228 }
229
230 fn topological_sort(&self, systems: Vec<Arc<Mutex<System>>>) -> Vec<SystemId> {
231 let mut sorted = Vec::new();
232 let mut visited = HashSet::new();
233 let mut visiting = HashSet::new();
234
235 let mut systems_by_priority = systems;
236 systems_by_priority.sort_by(|a, b| {
237 b.lock().unwrap().priority.cmp(&a.lock().unwrap().priority)
238 });
239
240 for system_arc in &systems_by_priority {
241 let system_id = system_arc.lock().unwrap().id.clone();
242 self.visit(&system_id, &mut visited, &mut visiting, &mut sorted);
243 }
244
245 sorted
246 }
247
248 fn visit(
249 &self,
250 system_id: &SystemId,
251 visited: &mut HashSet<SystemId>,
252 visiting: &mut HashSet<SystemId>,
253 sorted: &mut Vec<SystemId>,
254 ) {
255 if visited.contains(system_id) {
256 return;
257 }
258
259 if visiting.contains(system_id) {
260 panic!("Circular dependency detected in system: {}", system_id);
261 }
262
263 visiting.insert(system_id.clone());
264
265 let system_arc = self.systems.get(system_id).unwrap();
266 let run_after = system_arc.lock().unwrap().run_after.clone();
267
268 for after_id in run_after {
269 if self.systems.contains_key(&after_id) {
270 self.visit(&after_id, visited, visiting, sorted);
271 }
272 }
273
274 let run_before = system_arc.lock().unwrap().run_before.clone();
275 for before_id in run_before {
276 if let Some(_) = self.systems.get(&before_id) {
277 if !visited.contains(&before_id) {
278 visiting.remove(system_id);
279 return;
280 }
281 }
282 }
283
284 visiting.remove(system_id);
285 visited.insert(system_id.clone());
286 sorted.push(system_id.clone());
287 }
288}