stateflow/
lib.rs

1//! A simple state machine library for Rust.
2
3use lru::LruCache;
4use once_cell::sync::Lazy;
5use serde::{Deserialize, Serialize};
6use serde_json::{Map, Value};
7use std::collections::HashMap;
8use std::env;
9use std::fmt::{self, Display, Formatter};
10use std::future::Future;
11use std::hash::{Hash, Hasher};
12use std::num::NonZero;
13use std::sync::{Arc, RwLock};
14use tokio::sync::RwLock as AsyncRwLock; // Alias to differentiate
15
16/// Represents an action with a type and command.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Action {
19    /// The type of the action.
20    pub action_type: String,
21    /// The command to execute.
22    pub command: String,
23}
24
25/// A struct representing a state and its transitions, including actions on enter and exit.
26#[derive(Debug, Clone)]
27struct State {
28    name: String,
29    on_enter_actions: Vec<Action>,
30    on_exit_actions: Vec<Action>,
31    transitions: HashMap<String, Transition>, // Key: event name, Value: Transition instance
32    validations: Vec<ValidationRule>,         // State validation rules
33}
34
35/// Represents a transition between states, including actions and validations.
36#[derive(Debug, Clone)]
37struct Transition {
38    to_state: String,
39    actions: Vec<Action>,
40    validations: Vec<ValidationRule>, // Transition validation rules
41}
42
43/// Represents a validation rule applied to the memory.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45struct ValidationRule {
46    field: String,
47    rules: Vec<FieldRule>,
48    condition: Option<Condition>, // Optional condition for conditional validations
49}
50
51/// Represents a single rule for a field.
52#[derive(Debug, Clone, Serialize, Deserialize)]
53#[serde(tag = "type")]
54enum FieldRule {
55    #[serde(rename = "type_check")]
56    TypeCheck { expected_type: String },
57    #[serde(rename = "nullable")]
58    Nullable { is_nullable: bool },
59    #[serde(rename = "min_value")]
60    MinValue { value: f64 },
61    #[serde(rename = "max_value")]
62    MaxValue { value: f64 },
63    #[serde(rename = "editable")]
64    Editable { is_editable: bool },
65    #[serde(rename = "read_only")]
66    ReadOnly { is_read_only: bool },
67    #[serde(rename = "enum")]
68    Enum { values: Vec<Value> },
69    // Add more rules as needed
70}
71
72/// Represents a condition for conditional validations.
73#[derive(Debug, Clone, Serialize, Deserialize)]
74struct Condition {
75    field: String,
76    operator: String,
77    value: Value,
78}
79
80/// Represents the configuration of a state machine loaded from JSON.
81#[derive(Debug, Serialize, Deserialize)]
82struct StateMachineConfig {
83    states: Vec<StateConfig>,
84    transitions: Vec<TransitionConfig>,
85}
86
87#[derive(Debug, Serialize, Deserialize)]
88struct StateConfig {
89    name: String,
90    #[serde(default)]
91    on_enter_actions: Vec<ActionConfig>,
92    #[serde(default)]
93    on_exit_actions: Vec<ActionConfig>,
94    validations: Option<Vec<ValidationRule>>,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98struct TransitionConfig {
99    from: String,
100    event: String,
101    to: String,
102    #[serde(default)]
103    actions: Vec<ActionConfig>, // Actions triggered during the transition
104    validations: Option<Vec<ValidationRule>>,
105}
106
107#[derive(Debug, Serialize, Deserialize)]
108struct ActionConfig {
109    action_type: String,
110    command: String,
111}
112
113type ActionHandler<C> = dyn for<'a> Fn(
114        &'a Action,
115        &'a mut Map<String, Value>,
116        &'a mut C,
117    ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'a>>
118    + Send
119    + Sync;
120
/// Environment variable that overrides the capacity of the config cache.
const LRU_CACHE_SIZE_ENV_KEY: &str = "STATEFLOW_LRU_CACHE_SIZE";
/// Capacity used when the environment variable is absent, unparsable, or zero.
const DEFAULT_CACHE_SIZE: usize = 100;

/// Returns the capacity for the parsed-configuration LRU cache.
///
/// Reads `STATEFLOW_LRU_CACHE_SIZE` as a `usize`. A missing, non-numeric, or
/// zero value falls back to `DEFAULT_CACHE_SIZE`, so the result is never zero.
fn get_cache_size() -> usize {
    env::var(LRU_CACHE_SIZE_ENV_KEY)
        .ok()
        .and_then(|raw| raw.parse::<usize>().ok())
        .filter(|&size| size != 0)
        .unwrap_or(DEFAULT_CACHE_SIZE)
}
138
139/// Static cache for storing parsed configurations
140static CONFIG_CACHE: Lazy<RwLock<LruCache<u64, Arc<StateMachineConfig>>>> = Lazy::new(|| {
141    let cache_size = get_cache_size();
142    RwLock::new(LruCache::new(NonZero::new(cache_size).unwrap()))
143});
144
145/// The state machine containing all states, the current state, memory, context, and handlers.
146pub struct StateMachine<'a, C> {
147    states: Arc<RwLock<HashMap<String, State>>>,
148    current_state: Arc<RwLock<String>>,
149    action_handler: Arc<ActionHandler<C>>,
150    /// The memory used by the state machine to store data.
151    pub memory: Arc<AsyncRwLock<Map<String, Value>>>,
152    /// The context used by the state machine to store state.
153    pub context: Arc<AsyncRwLock<C>>,
154    _marker: std::marker::PhantomData<&'a ()>, // To tie the lifetime to the struct
155}
156
157impl<C> StateMachine<'_, C> {
158    /// Creates a new state machine from a JSON configuration string.
159    pub fn new<F>(
160        config_content: &str,
161        initial_state: Option<String>,
162        action_handler: F,
163        memory: Map<String, Value>,
164        context: C,
165    ) -> Result<Self, String>
166    where
167        F: for<'b> Fn(
168                &'b Action,
169                &'b mut Map<String, Value>,
170                &'b mut C,
171            ) -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + 'b>>
172            + Send
173            + Sync
174            + 'static,
175    {
176        // Compute the hash of the config_content
177        let mut hasher = std::collections::hash_map::DefaultHasher::new();
178        config_content.hash(&mut hasher);
179        let config_hash = hasher.finish();
180
181        // Try to get the cached config
182        let config: Arc<StateMachineConfig> = {
183            let mut cache = CONFIG_CACHE.write().unwrap();
184            if let Some(cached_config) = cache.get(&config_hash) {
185                cached_config.clone()
186            } else {
187                // Parse and validate the config
188                // Generate and compile the JSON schema
189                let schema = Self::generate_and_compile_schema()?;
190
191                // Parse the configuration from the provided string
192                let config_value: serde_json::Value = serde_json::from_str(config_content)
193                    .map_err(|err| format!("Invalid JSON format in configuration: {}", err))?;
194
195                // Validate the configuration against the schema
196                let compiled_schema = jsonschema::Validator::new(&schema)
197                    .map_err(|e| format!("Failed to compile JSON schema: {}", e))?;
198                if let Err(error) = compiled_schema.validate(&config_value) {
199                    return Err(format!(
200                        "JSON configuration does not conform to schema: {}",
201                        error
202                    ));
203                }
204
205                // Deserialize the configuration
206                let config_deserialized: StateMachineConfig = serde_json::from_value(config_value)
207                    .map_err(|err| format!("Failed to deserialize configuration: {}", err))?;
208
209                // Validate the config
210                Self::validate_config(&config_deserialized)?;
211
212                // Cache the config
213                let config_arc = Arc::new(config_deserialized);
214                cache.put(config_hash, config_arc.clone());
215                config_arc
216            }
217        };
218
219        // Now proceed to create the StateMachine using `config`
220        // Create states and populate transitions
221        let mut states = HashMap::new();
222        for state_config in &config.states {
223            let state = State {
224                name: state_config.name.clone(),
225                on_enter_actions: Self::create_actions(&state_config.on_enter_actions),
226                on_exit_actions: Self::create_actions(&state_config.on_exit_actions),
227                transitions: HashMap::new(),
228                validations: state_config.validations.clone().unwrap_or_default(),
229            };
230            states.insert(state_config.name.clone(), state);
231        }
232
233        // Populate transitions for each state
234        for transition_config in &config.transitions {
235            if let Some(state) = states.get_mut(&transition_config.from) {
236                let transition = Transition {
237                    to_state: transition_config.to.clone(),
238                    actions: Self::create_actions(&transition_config.actions),
239                    validations: transition_config.validations.clone().unwrap_or_default(),
240                };
241                state
242                    .transitions
243                    .insert(transition_config.event.clone(), transition);
244            }
245        }
246
247        // Determine the starting state: use provided initial state or default to the first state
248        let current_state = initial_state.unwrap_or_else(|| config.states[0].name.clone());
249
250        Ok(StateMachine {
251            states: Arc::new(RwLock::new(states)),
252            current_state: Arc::new(RwLock::new(current_state)),
253            action_handler: Arc::new(action_handler),
254            memory: Arc::new(AsyncRwLock::new(memory)),
255            context: Arc::new(AsyncRwLock::new(context)),
256            _marker: std::marker::PhantomData,
257        })
258    }
259
260    /// Generates and compiles the JSON schema for the state machine configuration.
261    fn generate_and_compile_schema() -> Result<serde_json::Value, String> {
262        // Define the JSON schema as a serde_json::Value
263        let schema_json = serde_json::json!({
264            "$schema": "http://json-schema.org/draft-07/schema#",
265            "title": "StateMachineConfig",
266            "type": "object",
267            "required": ["states", "transitions"],
268            "properties": {
269                "states": {
270                    "type": "array",
271                    "items": {
272                        "type": "object",
273                        "required": ["name"],
274                        "properties": {
275                            "name": { "type": "string" },
276                            "on_enter_actions": {
277                                "type": "array",
278                                "items": { "$ref": "#/definitions/action" },
279                                "default": []
280                            },
281                            "on_exit_actions": {
282                                "type": "array",
283                                "items": { "$ref": "#/definitions/action" },
284                                "default": []
285                            },
286                            "validations": {
287                                "type": "array",
288                                "items": { "$ref": "#/definitions/validation_rule" }
289                            }
290                        }
291                    }
292                },
293                "transitions": {
294                    "type": "array",
295                    "items": {
296                        "type": "object",
297                        "required": ["from", "event", "to"],
298                        "properties": {
299                            "from": { "type": "string" },
300                            "event": { "type": "string" },
301                            "to": { "type": "string" },
302                            "actions": {
303                                "type": "array",
304                                "items": { "$ref": "#/definitions/action" },
305                                "default": []
306                            },
307                            "validations": {
308                                "type": "array",
309                                "items": { "$ref": "#/definitions/validation_rule" }
310                            }
311                        }
312                    }
313                }
314            },
315            "definitions": {
316                "action": {
317                    "type": "object",
318                    "required": ["action_type", "command"],
319                    "properties": {
320                        "action_type": { "type": "string" },
321                        "command": { "type": "string" }
322                    }
323                },
324                "validation_rule": {
325                    "type": "object",
326                    "required": ["field", "rules"],
327                    "properties": {
328                        "field": { "type": "string" },
329                        "rules": {
330                            "type": "array",
331                            "items": { "$ref": "#/definitions/field_rule" }
332                        },
333                        "condition": { "$ref": "#/definitions/condition" }
334                    }
335                },
336                "field_rule": {
337                    "type": "object",
338                    "oneOf": [
339                        {
340                            "type": "object",
341                            "required": ["type"],
342                            "properties": {
343                                "type": { "const": "type_check" },
344                                "expected_type": { "type": "string" }
345                            }
346                        },
347                        {
348                            "type": "object",
349                            "required": ["type"],
350                            "properties": {
351                                "type": { "const": "nullable" },
352                                "is_nullable": { "type": "boolean" }
353                            }
354                        },
355                        {
356                            "type": "object",
357                            "required": ["type"],
358                            "properties": {
359                                "type": { "const": "min_value" },
360                                "value": { "type": "number" }
361                            }
362                        },
363                        {
364                            "type": "object",
365                            "required": ["type"],
366                            "properties": {
367                                "type": { "const": "max_value" },
368                                "value": { "type": "number" }
369                            }
370                        },
371                        {
372                            "type": "object",
373                            "required": ["type"],
374                            "properties": {
375                                "type": { "const": "editable" },
376                                "is_editable": { "type": "boolean" }
377                            }
378                        },
379                        {
380                            "type": "object",
381                            "required": ["type"],
382                            "properties": {
383                                "type": { "const": "read_only" },
384                                "is_read_only": { "type": "boolean" }
385                            }
386                        },
387                        {
388                            "type": "object",
389                            "required": ["type"],
390                            "properties": {
391                                "type": { "const": "enum" },
392                                "values": {
393                                    "type": "array",
394                                    "items": {}
395                                }
396                            }
397                        }
398                        // Add more field rule schemas as needed
399                    ]
400                },
401                "condition": {
402                    "type": "object",
403                    "required": ["field", "operator", "value"],
404                    "properties": {
405                        "field": { "type": "string" },
406                        "operator": { "type": "string" },
407                        "value": {}
408                    }
409                }
410            }
411        });
412
413        Ok(schema_json)
414    }
415
416    /// Creates actions from the action configuration.
417    fn create_actions(action_configs: &[ActionConfig]) -> Vec<Action> {
418        action_configs
419            .iter()
420            .map(|config| Action {
421                action_type: config.action_type.clone(),
422                command: config.command.clone(),
423            })
424            .collect()
425    }
426
427    /// Validates the state machine configuration.
428    fn validate_config(config: &StateMachineConfig) -> Result<(), String> {
429        if config.states.is_empty() {
430            return Err("State machine must have at least one state.".into());
431        }
432
433        let mut state_set = std::collections::HashSet::new();
434        for state in &config.states {
435            if !state_set.insert(&state.name) {
436                return Err(format!("Duplicate state found: {}", state.name));
437            }
438        }
439
440        for transition in &config.transitions {
441            if !config.states.iter().any(|s| s.name == transition.from) {
442                return Err(format!(
443                    "Transition 'from' state '{}' is not defined in the states list.",
444                    transition.from
445                ));
446            }
447            if !config.states.iter().any(|s| s.name == transition.to) {
448                return Err(format!(
449                    "Transition 'to' state '{}' is not defined in the states list.",
450                    transition.to
451                ));
452            }
453            if transition.event.trim().is_empty() {
454                return Err(format!(
455                    "Transition from '{}' to '{}' has an empty event.",
456                    transition.from, transition.to
457                ));
458            }
459        }
460
461        Ok(())
462    }
463
464    /// Triggers an event, causing a state transition if applicable and executing actions.
465    pub async fn trigger(&self, event: &str) -> Result<(), String> {
466        // Acquire a read lock on the current state and clone its value
467        let current_state_name = {
468            let current_state_guard = self.current_state.read().unwrap();
469            current_state_guard.clone()
470        }; // Lock is released here
471
472        // Acquire a read lock on the states and get the current state and transition
473        let (current_state, transition) = {
474            let states_guard = self.states.read().unwrap();
475            // Clone the current state to own its data
476            let current_state = states_guard.get(&current_state_name).cloned();
477            if let Some(current_state) = current_state {
478                // Clone the transition to own its data
479                if let Some(transition) = current_state.transitions.get(event).cloned() {
480                    (current_state, transition)
481                } else {
482                    return Err(format!(
483                        "No transition found for event '{}' from state '{}'.",
484                        event, current_state_name
485                    ));
486                }
487            } else {
488                return Err(format!(
489                    "Current state '{}' not found in state machine.",
490                    current_state_name
491                ));
492            }
493        }; // Lock is released here
494
495        // Now `current_state` and `transition` own their data and do not borrow from `states_guard`
496
497        // Acquire write locks on memory and context
498        let mut memory = self.memory.write().await;
499        let mut context = self.context.write().await;
500
501        // Execute state validations
502        Self::evaluate_validations(&current_state.validations, &memory)?;
503
504        // Execute transition validations
505        Self::evaluate_validations(&transition.validations, &memory)?;
506
507        // Execute on-exit actions
508        self.execute_actions(&current_state.on_exit_actions, &mut memory, &mut context)
509            .await;
510
511        // Execute transition actions
512        self.execute_actions(&transition.actions, &mut memory, &mut context)
513            .await;
514
515        // Update the current state
516        {
517            let mut current_state_guard = self.current_state.write().unwrap();
518            *current_state_guard = transition.to_state.clone();
519        } // Lock is released here
520
521        // Execute on-enter actions of the next state
522        let next_state_on_enter_actions = {
523            let states_guard = self.states.read().unwrap();
524            if let Some(next_state) = states_guard.get(&transition.to_state) {
525                next_state.on_enter_actions.clone()
526            } else {
527                return Err(format!(
528                    "Next state '{}' not found in state machine.",
529                    transition.to_state
530                ));
531            }
532        }; // Lock is released here
533
534        // Now we can call execute_actions with the cloned actions
535        self.execute_actions(&next_state_on_enter_actions, &mut memory, &mut context)
536            .await;
537
538        Ok(())
539    }
540
541    /// Executes a list of actions using the provided async action handler.
542    async fn execute_actions<'b>(
543        &self,
544        actions: &[Action],
545        memory: &'b mut Map<String, Value>,
546        context: &'b mut C,
547    ) {
548        for action in actions {
549            (self.action_handler)(action, memory, context).await;
550        }
551    }
552
553    /// Evaluates a list of validation rules against the memory.
554    fn evaluate_validations(
555        validations: &[ValidationRule],
556        memory: &Map<String, Value>,
557    ) -> Result<(), String> {
558        for validation in validations {
559            // Check condition if present
560            if let Some(condition) = &validation.condition {
561                if !Self::evaluate_condition(condition, memory)? {
562                    // Condition not met, skip validation
563                    continue;
564                }
565            }
566
567            // Get the value from the memory
568            let field_value = memory.get(&validation.field);
569
570            for rule in &validation.rules {
571                match rule {
572                    FieldRule::TypeCheck { expected_type } => {
573                        if let Some(value) = field_value {
574                            let actual_type = Self::get_type_name(value);
575                            if actual_type != expected_type {
576                                return Err(format!(
577                                    "Validation failed: Field '{}' expected type '{}', got '{}'",
578                                    validation.field, expected_type, actual_type
579                                ));
580                            }
581                        } else {
582                            return Err(format!(
583                                "Validation failed: Field '{}' is missing in memory",
584                                validation.field
585                            ));
586                        }
587                    }
588                    FieldRule::Nullable { is_nullable } => {
589                        if !*is_nullable && field_value.is_none() {
590                            return Err(format!(
591                                "Validation failed: Field '{}' cannot be null",
592                                validation.field
593                            ));
594                        }
595                    }
596                    FieldRule::MinValue { value: min_value } => {
597                        if let Some(Value::Number(num)) = field_value {
598                            if num.as_f64().unwrap_or(f64::NAN) < *min_value {
599                                return Err(format!(
600                                    "Validation failed: Field '{}' value '{}' is less than minimum '{}'",
601                                    validation.field, num, min_value
602                                ));
603                            }
604                        } else {
605                            return Err(format!(
606                                "Validation failed: Field '{}' is not a number",
607                                validation.field
608                            ));
609                        }
610                    }
611                    FieldRule::MaxValue { value: max_value } => {
612                        if let Some(Value::Number(num)) = field_value {
613                            if num.as_f64().unwrap_or(f64::NAN) > *max_value {
614                                return Err(format!(
615                                    "Validation failed: Field '{}' value '{}' is greater than maximum '{}'",
616                                    validation.field, num, max_value
617                                ));
618                            }
619                        } else {
620                            return Err(format!(
621                                "Validation failed: Field '{}' is not a number",
622                                validation.field
623                            ));
624                        }
625                    }
626                    FieldRule::Editable { is_editable: _ }
627                    | FieldRule::ReadOnly { is_read_only: _ } => {
628                        // Not implemented
629                    }
630                    FieldRule::Enum { values } => {
631                        if let Some(value) = field_value {
632                            if !values.contains(value) {
633                                return Err(format!(
634                                    "Validation failed: Field '{}' value '{}' is not in enum {:?}",
635                                    validation.field, value, values
636                                ));
637                            }
638                        } else {
639                            return Err(format!(
640                                "Validation failed: Field '{}' is missing in memory",
641                                validation.field
642                            ));
643                        }
644                    } // Handle more rules as needed
645                }
646            }
647        }
648        Ok(())
649    }
650
651    /// Evaluates a condition against the memory.
652    fn evaluate_condition(
653        condition: &Condition,
654        memory: &Map<String, Value>,
655    ) -> Result<bool, String> {
656        let field_value = memory.get(&condition.field);
657        if let Some(actual_value) = field_value {
658            let result = match condition.operator.as_str() {
659                "==" => actual_value == &condition.value,
660                "!=" => actual_value != &condition.value,
661                ">" => Self::compare_values(
662                    actual_value,
663                    &condition.value,
664                    std::cmp::Ordering::Greater,
665                )?,
666                "<" => {
667                    Self::compare_values(actual_value, &condition.value, std::cmp::Ordering::Less)?
668                }
669                ">=" => {
670                    let ordering = Self::compare_values_ordering(actual_value, &condition.value)?;
671                    ordering == std::cmp::Ordering::Greater || ordering == std::cmp::Ordering::Equal
672                }
673                "<=" => {
674                    let ordering = Self::compare_values_ordering(actual_value, &condition.value)?;
675                    ordering == std::cmp::Ordering::Less || ordering == std::cmp::Ordering::Equal
676                }
677                _ => return Err(format!("Unsupported operator '{}'", condition.operator)),
678            };
679            Ok(result)
680        } else {
681            Err(format!(
682                "Condition evaluation failed: Field '{}' is missing in memory",
683                condition.field
684            ))
685        }
686    }
687
688    /// Compares two serde_json::Value numbers based on the expected ordering.
689    fn compare_values(
690        actual: &Value,
691        expected: &Value,
692        ordering: std::cmp::Ordering,
693    ) -> Result<bool, String> {
694        let actual_num = actual
695            .as_f64()
696            .ok_or_else(|| format!("Cannot compare non-numeric value '{}' in condition", actual))?;
697        let expected_num = expected.as_f64().ok_or_else(|| {
698            format!(
699                "Cannot compare non-numeric value '{}' in condition",
700                expected
701            )
702        })?;
703        Ok(actual_num.partial_cmp(&expected_num) == Some(ordering))
704    }
705
706    /// Compares two serde_json::Value numbers and returns the ordering.
707    fn compare_values_ordering(
708        actual: &Value,
709        expected: &Value,
710    ) -> Result<std::cmp::Ordering, String> {
711        let actual_num = actual
712            .as_f64()
713            .ok_or_else(|| format!("Cannot compare non-numeric value '{}' in condition", actual))?;
714        let expected_num = expected.as_f64().ok_or_else(|| {
715            format!(
716                "Cannot compare non-numeric value '{}' in condition",
717                expected
718            )
719        })?;
720        Ok(actual_num
721            .partial_cmp(&expected_num)
722            .unwrap_or(std::cmp::Ordering::Equal))
723    }
724
725    /// Returns a string representing the type of the serde_json::Value.
726    fn get_type_name(value: &Value) -> &str {
727        match value {
728            Value::Null => "null",
729            Value::Bool(_) => "boolean",
730            Value::Number(_) => "number",
731            Value::String(_) => "string",
732            Value::Array(_) => "array",
733            Value::Object(_) => "object",
734        }
735    }
736
737    /// Returns the current state of the state machine.
738    pub async fn get_current_state(&self) -> Result<String, String> {
739        let current_state_guard = self.current_state.read().unwrap();
740        Ok(current_state_guard.clone())
741    }
742}
743
744/// Implementing the Display trait to render the state machine as a string.
745impl<C> Display for StateMachine<'_, C> {
746    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
747        let states = self.states.read().unwrap();
748        let current_state = self.current_state.read().unwrap();
749
750        writeln!(f, "State Machine Diagram:")?;
751        writeln!(f, "======================")?;
752
753        for (state_name, state) in &*states {
754            let marker = if *state_name == *current_state {
755                "->" // Indicate the current state
756            } else {
757                "  "
758            };
759            writeln!(f, "{} State: {}", marker, state.name)?;
760
761            for (event, transition) in &state.transitions {
762                writeln!(f, "      -[{}]-> {}", event, transition.to_state)?;
763            }
764        }
765
766        writeln!(f, "======================")
767    }
768}