scirs2_core/memory_efficient/
fusion.rs1use crate::error::{CoreError, ErrorContext, ErrorLocation};
2use once_cell::sync::Lazy;
3use std::any::{Any, TypeId};
4use std::collections::HashMap;
5use std::fmt;
6use std::sync::{Arc, Mutex};
7
8type FusedOpArc = Arc<dyn FusedOp>;
10
11type FusionRegistryMap = HashMap<TypeId, Vec<FusedOpArc>>;
13
14static FUSION_REGISTRY: Lazy<Mutex<FusionRegistryMap>> = Lazy::new(|| Mutex::new(HashMap::new()));
16
/// A type-erased operation that can be chained, and possibly merged with its
/// neighbours, inside an [`OpFusion`].
///
/// Implementors must be `Send + Sync` because operations are shared through
/// `Arc<dyn FusedOp>` and stored in a global, mutex-guarded registry.
pub trait FusedOp: Send + Sync {
    /// Returns the name of this operation.
    fn name(&self) -> &str;

    /// Returns the `TypeId` this operation reports as its input type.
    ///
    /// [`OpFusion::add_op`] compares this value with the output type of the
    /// operation added before it.
    fn input_type(&self) -> TypeId;

    /// Returns the `TypeId` this operation reports as its output type.
    fn output_type(&self) -> TypeId;

    /// Reports whether this operation can be fused with `other`.
    ///
    /// [`OpFusion::optimize`] calls this with `self` as the earlier operation
    /// and `other` as the one that follows it.
    fn can_fuse_with(&self, other: &dyn FusedOp) -> bool;

    /// Returns a new operation that combines `self` with `other`.
    ///
    /// [`OpFusion::optimize`] only calls this after `can_fuse_with(other)`
    /// returned `true`.
    // NOTE(review): presumably the result must behave like `self` followed by
    // `other`; the trait itself does not enforce this - confirm with implementors.
    fn fuse_with(&self, other: &dyn FusedOp) -> Arc<dyn FusedOp>;

    /// Applies the operation to a type-erased `input`.
    ///
    /// # Errors
    ///
    /// Returns a [`CoreError`] when the implementation cannot process `input`.
    // NOTE(review): implementors presumably downcast `input` to the type behind
    // `input_type()` and box a value of `output_type()` - confirm.
    fn apply(&self, input: &dyn Any) -> Result<Box<dyn Any>, CoreError>;

    /// Returns a new shared handle to an equivalent operation.
    fn clone_op(&self) -> Arc<dyn FusedOp>;
}
40
/// An ordered chain of [`FusedOp`]s together with the `TypeId`s recorded for
/// the chain's input and output.
#[derive(Clone)]
pub struct OpFusion {
    /// Operations in application order.
    ops: Vec<Arc<dyn FusedOp>>,
    /// Input type taken from the first operation added; `TypeId::of::<()>()`
    /// while the chain is empty.
    input_type: TypeId,
    /// Output type taken from the most recently added operation;
    /// `TypeId::of::<()>()` while the chain is empty.
    output_type: TypeId,
}
51
52impl fmt::Debug for OpFusion {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 f.debug_struct("OpFusion")
55 .field("num_ops", &self.ops.len())
56 .finish()
57 }
58}
59
60impl OpFusion {
61 pub fn new() -> Self {
63 Self {
64 ops: Vec::new(),
65 input_type: TypeId::of::<()>(),
66 output_type: TypeId::of::<()>(),
67 }
68 }
69
70 pub fn add_op(&mut self, op: Arc<dyn FusedOp>) -> Result<&mut Self, CoreError> {
72 if self.ops.is_empty() {
73 self.input_type = op.input_type();
74 self.output_type = op.output_type();
75 } else if op.input_type() != self.output_type {
76 return Err(CoreError::ValidationError(
77 ErrorContext::new("Operation input type does not match previous output type")
78 .with_location(ErrorLocation::new(file!(), line!())),
79 ));
80 }
81
82 let output_type = op.output_type();
83 self.ops.push(op);
84 self.output_type = output_type;
85 Ok(self)
86 }
87
88 pub fn optimize(&mut self) -> Result<&mut Self, CoreError> {
90 if self.ops.len() <= 1 {
91 return Ok(self);
92 }
93
94 let mut optimized = Vec::new();
95 let mut current_op = self.ops[0].clone_op();
96
97 for i in 1..self.ops.len() {
98 let next_op = &self.ops[i];
99
100 if current_op.can_fuse_with(next_op.as_ref()) {
101 current_op = current_op.fuse_with(next_op.as_ref());
102 } else {
103 optimized.push(current_op);
104 current_op = next_op.clone_op();
105 }
106 }
107
108 optimized.push(current_op);
109 self.ops = optimized;
110
111 Ok(self)
112 }
113
114 pub fn apply<A: 'static>(&self, input: A) -> Result<Box<dyn Any>, CoreError> {
116 if TypeId::of::<A>() != self.input_type {
117 return Err(CoreError::ValidationError(
118 ErrorContext::new("Input type does not match expected type")
119 .with_location(ErrorLocation::new(file!(), line!())),
120 ));
121 }
122
123 let mut result: Box<dyn Any> = Box::new(input);
124
125 for op in &self.ops {
126 result = op.apply(result.as_ref())?;
127 }
128
129 Ok(result)
130 }
131
132 pub fn num_ops(&self) -> usize {
134 self.ops.len()
135 }
136
137 pub fn is_empty(&self) -> bool {
139 self.ops.is_empty()
140 }
141}
142
143impl Default for OpFusion {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149#[allow(dead_code)]
151pub fn register_fusion<T: 'static>(op: Arc<dyn FusedOp>) -> Result<(), CoreError> {
152 let type_id = TypeId::of::<T>();
153
154 let mut registry = FUSION_REGISTRY.lock().expect("Operation failed");
155 let ops = registry.entry(type_id).or_default();
156 ops.push(op);
157
158 Ok(())
159}
160
161#[allow(dead_code)]
163pub fn get_fusions<T: 'static>() -> Vec<Arc<dyn FusedOp>> {
164 let type_id = TypeId::of::<T>();
165
166 let registry = FUSION_REGISTRY.lock().expect("Operation failed");
167 match registry.get(&type_id) {
168 Some(ops) => ops.clone(),
169 None => Vec::new(),
170 }
171}