1use lru::LruCache;
4use once_cell::sync::Lazy;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::collections::HashMap;
8use std::env;
9use std::fmt::{self, Display, Formatter};
10use std::future::Future;
11use std::hash::{Hash, Hasher};
12use std::num::NonZero;
13use std::sync::{Arc, RwLock};
14use tokio::sync::RwLock as AsyncRwLock; #[derive(Debug, Clone, Serialize, Deserialize)]
/// A unit of work handed to the user-supplied action handler.
///
/// The state machine never interprets either field itself; it only copies
/// them out of the configuration and passes them to the handler.
pub struct Action {
    /// Category of the action, as written in the configuration.
    pub action_type: String,
    /// Payload of the action, as written in the configuration.
    pub command: String,
}
24
25#[derive(Debug, Clone)]
27struct State {
28 name: String,
29 on_enter_actions: Vec<Action>,
30 on_exit_actions: Vec<Action>,
31 transitions: HashMap<String, Transition>, validations: Vec<ValidationRule>, }
34
35#[derive(Debug, Clone)]
37struct Transition {
38 to_state: String,
39 actions: Vec<Action>,
40 validations: Vec<ValidationRule>, }
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45struct ValidationRule {
46 field: String,
47 rules: Vec<FieldRule>,
48 condition: Option<Condition>, }
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type")]
54enum FieldRule {
55 #[serde(rename = "type_check")]
56 TypeCheck { expected_type: String },
57 #[serde(rename = "nullable")]
58 Nullable { is_nullable: bool },
59 #[serde(rename = "min_value")]
60 MinValue { value: f64 },
61 #[serde(rename = "max_value")]
62 MaxValue { value: f64 },
63 #[serde(rename = "editable")]
64 Editable { is_editable: bool },
65 #[serde(rename = "read_only")]
66 ReadOnly { is_read_only: bool },
67 #[serde(rename = "enum")]
68 Enum { values: Vec<Value> },
69 }
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
74struct Condition {
75 field: String,
76 operator: String,
77 value: Value,
78}
79
80#[derive(Debug, Serialize, Deserialize)]
82struct StateMachineConfig {
83 states: Vec<StateConfig>,
84 transitions: Vec<TransitionConfig>,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct StateConfig {
89 name: String,
90 #[serde(default)]
91 on_enter_actions: Vec<ActionConfig>,
92 #[serde(default)]
93 on_exit_actions: Vec<ActionConfig>,
94 validations: Option<Vec<ValidationRule>>,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98struct TransitionConfig {
99 from: String,
100 event: String,
101 to: String,
102 #[serde(default)]
103 actions: Vec<ActionConfig>, validations: Option<Vec<ValidationRule>>,
105}
106
107#[derive(Debug, Serialize, Deserialize)]
108struct ActionConfig {
109 action_type: String,
110 command: String,
111}
112
113type ActionHandler<C> = dyn for<'a> Fn(
114 &'a Action,
115 &'a mut Map<String, Value>,
116 &'a mut C,
117 ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'a>>
118 + Send
119 + Sync;
120
/// Environment variable that overrides the configuration cache capacity.
const LRU_CACHE_SIZE_ENV_KEY: &str = "STATEFLOW_LRU_CACHE_SIZE";
/// Capacity used when the variable is unset, unparsable, or zero.
const DEFAULT_CACHE_SIZE: usize = 100;

/// Returns the capacity for the configuration cache.
///
/// Reads `STATEFLOW_LRU_CACHE_SIZE` and falls back to
/// `DEFAULT_CACHE_SIZE` when the variable is missing, is not a valid
/// `usize`, or is `0`. The result is therefore never zero.
fn get_cache_size() -> usize {
    match env::var(LRU_CACHE_SIZE_ENV_KEY) {
        Ok(raw) => match raw.parse::<usize>() {
            // Zero is not a usable capacity; treat it like a bad value.
            Ok(0) | Err(_) => DEFAULT_CACHE_SIZE,
            Ok(size) => size,
        },
        Err(_) => DEFAULT_CACHE_SIZE,
    }
}
138
139static CONFIG_CACHE: Lazy<RwLock<LruCache<u64, Arc<StateMachineConfig>>>> = Lazy::new(|| {
141 let cache_size = get_cache_size();
142 RwLock::new(LruCache::new(NonZero::new(cache_size).unwrap()))
143});
144
145pub struct StateMachine<'a, C> {
147 states: Arc<RwLock<HashMap<String, State>>>,
148 current_state: Arc<RwLock<String>>,
149 action_handler: Arc<ActionHandler<C>>,
150 pub memory: Arc<AsyncRwLock<Map<String, Value>>>,
152 pub context: Arc<AsyncRwLock<C>>,
154 _marker: std::marker::PhantomData<&'a ()>, }
156
157impl<C> StateMachine<'_, C> {
158 pub fn new<F>(
160 config_content: &str,
161 initial_state: Option<String>,
162 action_handler: F,
163 memory: Map<String, Value>,
164 context: C,
165 ) -> Result<Self, String>
166 where
167 F: for<'b> Fn(
168 &'b Action,
169 &'b mut Map<String, Value>,
170 &'b mut C,
171 ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'b>>
172 + Send
173 + Sync
174 + 'static,
175 {
176 let mut hasher = std::collections::hash_map::DefaultHasher::new();
178 config_content.hash(&mut hasher);
179 let config_hash = hasher.finish();
180
181 let config: Arc<StateMachineConfig> = {
183 let mut cache = CONFIG_CACHE.write().unwrap();
184 if let Some(cached_config) = cache.get(&config_hash) {
185 cached_config.clone()
186 } else {
187 let schema = Self::generate_and_compile_schema()?;
190
191 let config_value: serde_json::Value = serde_json::from_str(config_content)
193 .map_err(|err| format!("Invalid JSON format in configuration: {}", err))?;
194
195 let compiled_schema = jsonschema::Validator::new(&schema)
197 .map_err(|e| format!("Failed to compile JSON schema: {}", e))?;
198 if let Err(error) = compiled_schema.validate(&config_value) {
199 return Err(format!(
200 "JSON configuration does not conform to schema: {}",
201 error
202 ));
203 }
204
205 let config_deserialized: StateMachineConfig = serde_json::from_value(config_value)
207 .map_err(|err| format!("Failed to deserialize configuration: {}", err))?;
208
209 Self::validate_config(&config_deserialized)?;
211
212 let config_arc = Arc::new(config_deserialized);
214 cache.put(config_hash, config_arc.clone());
215 config_arc
216 }
217 };
218
219 let mut states = HashMap::new();
222 for state_config in &config.states {
223 let state = State {
224 name: state_config.name.clone(),
225 on_enter_actions: Self::create_actions(&state_config.on_enter_actions),
226 on_exit_actions: Self::create_actions(&state_config.on_exit_actions),
227 transitions: HashMap::new(),
228 validations: state_config.validations.clone().unwrap_or_default(),
229 };
230 states.insert(state_config.name.clone(), state);
231 }
232
233 for transition_config in &config.transitions {
235 if let Some(state) = states.get_mut(&transition_config.from) {
236 let transition = Transition {
237 to_state: transition_config.to.clone(),
238 actions: Self::create_actions(&transition_config.actions),
239 validations: transition_config.validations.clone().unwrap_or_default(),
240 };
241 state
242 .transitions
243 .insert(transition_config.event.clone(), transition);
244 }
245 }
246
247 let current_state = initial_state.unwrap_or_else(|| config.states[0].name.clone());
249
250 Ok(StateMachine {
251 states: Arc::new(RwLock::new(states)),
252 current_state: Arc::new(RwLock::new(current_state)),
253 action_handler: Arc::new(action_handler),
254 memory: Arc::new(AsyncRwLock::new(memory)),
255 context: Arc::new(AsyncRwLock::new(context)),
256 _marker: std::marker::PhantomData,
257 })
258 }
259
260 fn generate_and_compile_schema() -> Result<serde_json::Value, String> {
262 let schema_json = serde_json::json!({
264 "$schema": "http://json-schema.org/draft-07/schema#",
265 "title": "StateMachineConfig",
266 "type": "object",
267 "required": ["states", "transitions"],
268 "properties": {
269 "states": {
270 "type": "array",
271 "items": {
272 "type": "object",
273 "required": ["name"],
274 "properties": {
275 "name": { "type": "string" },
276 "on_enter_actions": {
277 "type": "array",
278 "items": { "$ref": "#/definitions/action" },
279 "default": []
280 },
281 "on_exit_actions": {
282 "type": "array",
283 "items": { "$ref": "#/definitions/action" },
284 "default": []
285 },
286 "validations": {
287 "type": "array",
288 "items": { "$ref": "#/definitions/validation_rule" }
289 }
290 }
291 }
292 },
293 "transitions": {
294 "type": "array",
295 "items": {
296 "type": "object",
297 "required": ["from", "event", "to"],
298 "properties": {
299 "from": { "type": "string" },
300 "event": { "type": "string" },
301 "to": { "type": "string" },
302 "actions": {
303 "type": "array",
304 "items": { "$ref": "#/definitions/action" },
305 "default": []
306 },
307 "validations": {
308 "type": "array",
309 "items": { "$ref": "#/definitions/validation_rule" }
310 }
311 }
312 }
313 }
314 },
315 "definitions": {
316 "action": {
317 "type": "object",
318 "required": ["action_type", "command"],
319 "properties": {
320 "action_type": { "type": "string" },
321 "command": { "type": "string" }
322 }
323 },
324 "validation_rule": {
325 "type": "object",
326 "required": ["field", "rules"],
327 "properties": {
328 "field": { "type": "string" },
329 "rules": {
330 "type": "array",
331 "items": { "$ref": "#/definitions/field_rule" }
332 },
333 "condition": { "$ref": "#/definitions/condition" }
334 }
335 },
336 "field_rule": {
337 "type": "object",
338 "oneOf": [
339 {
340 "type": "object",
341 "required": ["type"],
342 "properties": {
343 "type": { "const": "type_check" },
344 "expected_type": { "type": "string" }
345 }
346 },
347 {
348 "type": "object",
349 "required": ["type"],
350 "properties": {
351 "type": { "const": "nullable" },
352 "is_nullable": { "type": "boolean" }
353 }
354 },
355 {
356 "type": "object",
357 "required": ["type"],
358 "properties": {
359 "type": { "const": "min_value" },
360 "value": { "type": "number" }
361 }
362 },
363 {
364 "type": "object",
365 "required": ["type"],
366 "properties": {
367 "type": { "const": "max_value" },
368 "value": { "type": "number" }
369 }
370 },
371 {
372 "type": "object",
373 "required": ["type"],
374 "properties": {
375 "type": { "const": "editable" },
376 "is_editable": { "type": "boolean" }
377 }
378 },
379 {
380 "type": "object",
381 "required": ["type"],
382 "properties": {
383 "type": { "const": "read_only" },
384 "is_read_only": { "type": "boolean" }
385 }
386 },
387 {
388 "type": "object",
389 "required": ["type"],
390 "properties": {
391 "type": { "const": "enum" },
392 "values": {
393 "type": "array",
394 "items": {}
395 }
396 }
397 }
398 ]
400 },
401 "condition": {
402 "type": "object",
403 "required": ["field", "operator", "value"],
404 "properties": {
405 "field": { "type": "string" },
406 "operator": { "type": "string" },
407 "value": {}
408 }
409 }
410 }
411 });
412
413 Ok(schema_json)
414 }
415
416 fn create_actions(action_configs: &[ActionConfig]) -> Vec<Action> {
418 action_configs
419 .iter()
420 .map(|config| Action {
421 action_type: config.action_type.clone(),
422 command: config.command.clone(),
423 })
424 .collect()
425 }
426
427 fn validate_config(config: &StateMachineConfig) -> Result<(), String> {
429 if config.states.is_empty() {
430 return Err("State machine must have at least one state.".into());
431 }
432
433 let mut state_set = std::collections::HashSet::new();
434 for state in &config.states {
435 if !state_set.insert(&state.name) {
436 return Err(format!("Duplicate state found: {}", state.name));
437 }
438 }
439
440 for transition in &config.transitions {
441 if !config.states.iter().any(|s| s.name == transition.from) {
442 return Err(format!(
443 "Transition 'from' state '{}' is not defined in the states list.",
444 transition.from
445 ));
446 }
447 if !config.states.iter().any(|s| s.name == transition.to) {
448 return Err(format!(
449 "Transition 'to' state '{}' is not defined in the states list.",
450 transition.to
451 ));
452 }
453 if transition.event.trim().is_empty() {
454 return Err(format!(
455 "Transition from '{}' to '{}' has an empty event.",
456 transition.from, transition.to
457 ));
458 }
459 }
460
461 Ok(())
462 }
463
464 pub async fn trigger(&self, event: &str) -> Result<(), String> {
466 let current_state_name = {
468 let current_state_guard = self.current_state.read().unwrap();
469 current_state_guard.clone()
470 }; let (current_state, transition) = {
474 let states_guard = self.states.read().unwrap();
475 let current_state = states_guard.get(¤t_state_name).cloned();
477 if let Some(current_state) = current_state {
478 if let Some(transition) = current_state.transitions.get(event).cloned() {
480 (current_state, transition)
481 } else {
482 return Err(format!(
483 "No transition found for event '{}' from state '{}'.",
484 event, current_state_name
485 ));
486 }
487 } else {
488 return Err(format!(
489 "Current state '{}' not found in state machine.",
490 current_state_name
491 ));
492 }
493 }; let mut memory = self.memory.write().await;
499 let mut context = self.context.write().await;
500
501 Self::evaluate_validations(¤t_state.validations, &memory)?;
503
504 Self::evaluate_validations(&transition.validations, &memory)?;
506
507 self.execute_actions(¤t_state.on_exit_actions, &mut memory, &mut context)
509 .await;
510
511 self.execute_actions(&transition.actions, &mut memory, &mut context)
513 .await;
514
515 {
517 let mut current_state_guard = self.current_state.write().unwrap();
518 *current_state_guard = transition.to_state.clone();
519 } let next_state_on_enter_actions = {
523 let states_guard = self.states.read().unwrap();
524 if let Some(next_state) = states_guard.get(&transition.to_state) {
525 next_state.on_enter_actions.clone()
526 } else {
527 return Err(format!(
528 "Next state '{}' not found in state machine.",
529 transition.to_state
530 ));
531 }
532 }; self.execute_actions(&next_state_on_enter_actions, &mut memory, &mut context)
536 .await;
537
538 Ok(())
539 }
540
541 async fn execute_actions<'b>(
543 &self,
544 actions: &[Action],
545 memory: &'b mut Map<String, Value>,
546 context: &'b mut C,
547 ) {
548 for action in actions {
549 (self.action_handler)(action, memory, context).await;
550 }
551 }
552
553 fn evaluate_validations(
555 validations: &[ValidationRule],
556 memory: &Map<String, Value>,
557 ) -> Result<(), String> {
558 for validation in validations {
559 if let Some(condition) = &validation.condition {
561 if !Self::evaluate_condition(condition, memory)? {
562 continue;
564 }
565 }
566
567 let field_value = memory.get(&validation.field);
569
570 for rule in &validation.rules {
571 match rule {
572 FieldRule::TypeCheck { expected_type } => {
573 if let Some(value) = field_value {
574 let actual_type = Self::get_type_name(value);
575 if actual_type != expected_type {
576 return Err(format!(
577 "Validation failed: Field '{}' expected type '{}', got '{}'",
578 validation.field, expected_type, actual_type
579 ));
580 }
581 } else {
582 return Err(format!(
583 "Validation failed: Field '{}' is missing in memory",
584 validation.field
585 ));
586 }
587 }
588 FieldRule::Nullable { is_nullable } => {
589 if !*is_nullable && field_value.is_none() {
590 return Err(format!(
591 "Validation failed: Field '{}' cannot be null",
592 validation.field
593 ));
594 }
595 }
596 FieldRule::MinValue { value: min_value } => {
597 if let Some(Value::Number(num)) = field_value {
598 if num.as_f64().unwrap_or(f64::NAN) < *min_value {
599 return Err(format!(
600 "Validation failed: Field '{}' value '{}' is less than minimum '{}'",
601 validation.field, num, min_value
602 ));
603 }
604 } else {
605 return Err(format!(
606 "Validation failed: Field '{}' is not a number",
607 validation.field
608 ));
609 }
610 }
611 FieldRule::MaxValue { value: max_value } => {
612 if let Some(Value::Number(num)) = field_value {
613 if num.as_f64().unwrap_or(f64::NAN) > *max_value {
614 return Err(format!(
615 "Validation failed: Field '{}' value '{}' is greater than maximum '{}'",
616 validation.field, num, max_value
617 ));
618 }
619 } else {
620 return Err(format!(
621 "Validation failed: Field '{}' is not a number",
622 validation.field
623 ));
624 }
625 }
626 FieldRule::Editable { is_editable: _ }
627 | FieldRule::ReadOnly { is_read_only: _ } => {
628 }
630 FieldRule::Enum { values } => {
631 if let Some(value) = field_value {
632 if !values.contains(value) {
633 return Err(format!(
634 "Validation failed: Field '{}' value '{}' is not in enum {:?}",
635 validation.field, value, values
636 ));
637 }
638 } else {
639 return Err(format!(
640 "Validation failed: Field '{}' is missing in memory",
641 validation.field
642 ));
643 }
644 } }
646 }
647 }
648 Ok(())
649 }
650
651 fn evaluate_condition(
653 condition: &Condition,
654 memory: &Map<String, Value>,
655 ) -> Result<bool, String> {
656 let field_value = memory.get(&condition.field);
657 if let Some(actual_value) = field_value {
658 let result = match condition.operator.as_str() {
659 "==" => actual_value == &condition.value,
660 "!=" => actual_value != &condition.value,
661 ">" => Self::compare_values(
662 actual_value,
663 &condition.value,
664 std::cmp::Ordering::Greater,
665 )?,
666 "<" => {
667 Self::compare_values(actual_value, &condition.value, std::cmp::Ordering::Less)?
668 }
669 ">=" => {
670 let ordering = Self::compare_values_ordering(actual_value, &condition.value)?;
671 ordering == std::cmp::Ordering::Greater || ordering == std::cmp::Ordering::Equal
672 }
673 "<=" => {
674 let ordering = Self::compare_values_ordering(actual_value, &condition.value)?;
675 ordering == std::cmp::Ordering::Less || ordering == std::cmp::Ordering::Equal
676 }
677 _ => return Err(format!("Unsupported operator '{}'", condition.operator)),
678 };
679 Ok(result)
680 } else {
681 Err(format!(
682 "Condition evaluation failed: Field '{}' is missing in memory",
683 condition.field
684 ))
685 }
686 }
687
688 fn compare_values(
690 actual: &Value,
691 expected: &Value,
692 ordering: std::cmp::Ordering,
693 ) -> Result<bool, String> {
694 let actual_num = actual
695 .as_f64()
696 .ok_or_else(|| format!("Cannot compare non-numeric value '{}' in condition", actual))?;
697 let expected_num = expected.as_f64().ok_or_else(|| {
698 format!(
699 "Cannot compare non-numeric value '{}' in condition",
700 expected
701 )
702 })?;
703 Ok(actual_num.partial_cmp(&expected_num) == Some(ordering))
704 }
705
706 fn compare_values_ordering(
708 actual: &Value,
709 expected: &Value,
710 ) -> Result<std::cmp::Ordering, String> {
711 let actual_num = actual
712 .as_f64()
713 .ok_or_else(|| format!("Cannot compare non-numeric value '{}' in condition", actual))?;
714 let expected_num = expected.as_f64().ok_or_else(|| {
715 format!(
716 "Cannot compare non-numeric value '{}' in condition",
717 expected
718 )
719 })?;
720 Ok(actual_num
721 .partial_cmp(&expected_num)
722 .unwrap_or(std::cmp::Ordering::Equal))
723 }
724
725 fn get_type_name(value: &Value) -> &str {
727 match value {
728 Value::Null => "null",
729 Value::Bool(_) => "boolean",
730 Value::Number(_) => "number",
731 Value::String(_) => "string",
732 Value::Array(_) => "array",
733 Value::Object(_) => "object",
734 }
735 }
736
737 pub async fn get_current_state(&self) -> Result<String, String> {
739 let current_state_guard = self.current_state.read().unwrap();
740 Ok(current_state_guard.clone())
741 }
742}
743
744impl<C> Display for StateMachine<'_, C> {
746 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
747 let states = self.states.read().unwrap();
748 let current_state = self.current_state.read().unwrap();
749
750 writeln!(f, "State Machine Diagram:")?;
751 writeln!(f, "======================")?;
752
753 for (state_name, state) in &*states {
754 let marker = if *state_name == *current_state {
755 "->" } else {
757 " "
758 };
759 writeln!(f, "{} State: {}", marker, state.name)?;
760
761 for (event, transition) in &state.transitions {
762 writeln!(f, " -[{}]-> {}", event, transition.to_state)?;
763 }
764 }
765
766 writeln!(f, "======================")
767 }
768}