scirs2_core/memory_efficient/
fusion.rs

1use crate::error::{CoreError, ErrorContext, ErrorLocation};
2use once_cell::sync::Lazy;
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::fmt;
6use std::sync::{Arc, Mutex};
7
8/// Type alias for a fused operation
9type FusedOpArc = Arc<dyn FusedOp>;
10
11/// Type alias for the fusion registry storage
12type FusionRegistryMap = HashMap<TypeId, Vec<FusedOpArc>>;
13
14// Global registry of fused operations
15static FUSION_REGISTRY: Lazy<Mutex<FusionRegistryMap>> = Lazy::new(|| Mutex::new(HashMap::new()));
16
17/// A trait for operations that can be fused together for better performance
18pub trait FusedOp: Send + Sync {
19    /// Returns the unique name of this operation
20    fn name(&self) -> &str;
21
22    /// Returns the type ID of the input to this operation
23    fn input_type(&self) -> TypeId;
24
25    /// Returns the type ID of the output from this operation
26    fn output_type(&self) -> TypeId;
27
28    /// Checks if this operation can be fused with another operation
29    fn can_fuse_with(&self, other: &dyn FusedOp) -> bool;
30
31    /// Creates a new operation that is the fusion of this operation with another
32    fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp>;
33
34    /// Applies this operation to an input (as Any)
35    fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError>;
36
37    /// Clone this operation
38    fn clone_op(&self) -> Arc<dyn FusedOp>;
39}
40
41/// A structure for chaining multiple operations together and optimizing the execution
42#[derive(Clone)]
43pub struct OpFusion {
44    /// The sequence of operations to apply
45    ops: Vec<Arc<dyn FusedOp>>,
46    /// The input type ID
47    input_type: TypeId,
48    /// The output type ID
49    output_type: TypeId,
50}
51
52impl fmt::Debug for OpFusion {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        f.debug_struct("OpFusion")
55            .field("num_ops", &self.ops.len())
56            .finish()
57    }
58}
59
60impl OpFusion {
61    /// Create a new operation fusion
62    pub fn new() -> Self {
63        Self {
64            ops: Vec::new(),
65            input_type: TypeId::of::<()>(),
66            output_type: TypeId::of::<()>(),
67        }
68    }
69
70    /// Add an operation to the fusion chain
71    pub fn add_op(&mut self, op: Arc<dyn FusedOp>) -> Result<&mut Self, CoreError> {
72        if self.ops.is_empty() {
73            self.input_type = op.input_type();
74            self.output_type = op.output_type();
75        } else if op.input_type() != self.output_type {
76            return Err(CoreError::ValidationError(
77                ErrorContext::new("Operation input type does not match previous output type")
78                    .with_location(ErrorLocation::new(file!(), line!())),
79            ));
80        }
81
82        let output_type = op.output_type();
83        self.ops.push(op);
84        self.output_type = output_type;
85        Ok(self)
86    }
87
88    /// Optimize the operation chain by fusing operations where possible
89    pub fn optimize(&mut self) -> Result<&mut Self, CoreError> {
90        if self.ops.len() <= 1 {
91            return Ok(self);
92        }
93
94        let mut optimized = Vec::new();
95        let mut current_op = self.ops[0].clone_op();
96
97        for i in 1..self.ops.len() {
98            let next_op = &self.ops[i];
99
100            if current_op.can_fuse_with(next_op.as_ref()) {
101                current_op = current_op.fuse_with(next_op.as_ref());
102            } else {
103                optimized.push(current_op);
104                current_op = next_op.clone_op();
105            }
106        }
107
108        optimized.push(current_op);
109        self.ops = optimized;
110
111        Ok(self)
112    }
113
114    /// Apply the operation chain to an input value
115    pub fn apply<A: 'static>(&self, input: A) -> Result<Box<dyn Any>, CoreError> {
116        if TypeId::of::<A>() != self.input_type {
117            return Err(CoreError::ValidationError(
118                ErrorContext::new("Input type does not match expected type")
119                    .with_location(ErrorLocation::new(file!(), line!())),
120            ));
121        }
122
123        let mut result: Box<dyn Any> = Box::new(input);
124
125        for op in &self.ops {
126            result = op.apply(result.as_ref())?;
127        }
128
129        Ok(result)
130    }
131
132    /// Get the number of operations in the chain
133    pub fn num_ops(&self) -> usize {
134        self.ops.len()
135    }
136
137    /// Check if the operation chain is empty
138    pub fn is_empty(&self) -> bool {
139        self.ops.is_empty()
140    }
141}
142
143impl Default for OpFusion {
144    fn default() -> Self {
145        Self::new()
146    }
147}
148
149/// Register a fused operation in the global registry
150#[allow(dead_code)]
151pub fn register_fusion<T: 'static>(op: Arc<dyn FusedOp>) -> Result<(), CoreError> {
152    let type_id = TypeId::of::<T>();
153
154    let mut registry = FUSION_REGISTRY.lock().expect("Operation failed");
155    let ops = registry.entry(type_id).or_default();
156    ops.push(op);
157
158    Ok(())
159}
160
161/// Get all registered fused operations for a type
162#[allow(dead_code)]
163pub fn get_fusions<T: 'static>() -> Vec<Arc<dyn FusedOp>> {
164    let type_id = TypeId::of::<T>();
165
166    let registry = FUSION_REGISTRY.lock().expect("Operation failed");
167    match registry.get(&type_id) {
168        Some(ops) => ops.clone(),
169        None => Vec::new(),
170    }
171}