tensorlogic_compiler/compile/
custom_ops.rs1use anyhow::{bail, Result};
7use std::collections::HashMap;
8use std::sync::{Arc, RwLock};
9use tensorlogic_ir::{EinsumGraph, TLExpr};
10
11use crate::{config::CompilationConfig, CompilerContext};
12
13pub type CustomOpHandler = Arc<
23 dyn Fn(&TLExpr, &mut CompilerContext, &mut EinsumGraph, &CustomOpData) -> Result<usize>
24 + Send
25 + Sync,
26>;
27
/// Descriptive information stored next to each registered custom operation.
///
/// All fields are plain data, so the type also implements `PartialEq`/`Eq` for easy
/// comparison (e.g. in tests or when checking what `get_metadata` returned).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CustomOpMetadata {
    /// Informational operation name.
    ///
    /// The registry keys entries by the name passed to `register`, not by this field.
    pub name: String,
    /// Human-readable description of the operation.
    pub description: String,
    /// Required number of predicate arguments.
    ///
    /// `None` disables the arity check performed by `CustomOpRegistry::invoke`. That
    /// check applies only to `TLExpr::Pred` expressions.
    pub expected_arity: Option<usize>,
    /// Whether the operation is declared as differentiable.
    pub is_differentiable: bool,
}
40
/// Free-form key/value parameters handed to custom operation handlers.
///
/// String and numeric values live in separate maps, so the same key may appear in
/// both without conflict.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct CustomOpData {
    /// String-valued parameters, keyed by name.
    pub string_data: HashMap<String, String>,
    /// Numeric (`f64`) parameters, keyed by name.
    pub numeric_data: HashMap<String, f64>,
}

impl CustomOpData {
    /// Creates an empty parameter set (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: stores a string parameter, replacing any previous value for `key`.
    pub fn with_string(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.string_data.insert(key.into(), value.into());
        self
    }

    /// Builder: stores a numeric parameter, replacing any previous value for `key`.
    pub fn with_numeric(mut self, key: impl Into<String>, value: f64) -> Self {
        self.numeric_data.insert(key.into(), value);
        self
    }

    /// Looks up a string parameter; returns `None` when `key` is absent.
    pub fn get_string(&self, key: &str) -> Option<&String> {
        self.string_data.get(key)
    }

    /// Looks up a numeric parameter by value; returns `None` when `key` is absent.
    pub fn get_numeric(&self, key: &str) -> Option<f64> {
        self.numeric_data.get(key).copied()
    }
}
78
79pub struct CustomOpRegistry {
81 handlers: RwLock<HashMap<String, (CustomOpHandler, CustomOpMetadata)>>,
82}
83
84impl Default for CustomOpRegistry {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl CustomOpRegistry {
91 pub fn new() -> Self {
93 Self {
94 handlers: RwLock::new(HashMap::new()),
95 }
96 }
97
98 pub fn register(
126 &mut self,
127 name: impl Into<String>,
128 metadata: CustomOpMetadata,
129 handler: CustomOpHandler,
130 ) -> Result<()> {
131 let name = name.into();
132
133 let mut handlers = self.handlers.write().unwrap();
134
135 if handlers.contains_key(&name) {
136 bail!("Custom operation '{}' is already registered", name);
137 }
138
139 handlers.insert(name, (handler, metadata));
140 Ok(())
141 }
142
143 pub fn unregister(&mut self, name: &str) -> Result<()> {
145 let mut handlers = self.handlers.write().unwrap();
146
147 if handlers.remove(name).is_none() {
148 bail!("Custom operation '{}' not found", name);
149 }
150
151 Ok(())
152 }
153
154 pub fn has_operation(&self, name: &str) -> bool {
156 let handlers = self.handlers.read().unwrap();
157 handlers.contains_key(name)
158 }
159
160 pub fn get_metadata(&self, name: &str) -> Option<CustomOpMetadata> {
162 let handlers = self.handlers.read().unwrap();
163 handlers.get(name).map(|(_, meta)| meta.clone())
164 }
165
166 pub fn list_operations(&self) -> Vec<String> {
168 let handlers = self.handlers.read().unwrap();
169 handlers.keys().cloned().collect()
170 }
171
172 pub fn invoke(
174 &self,
175 name: &str,
176 expr: &TLExpr,
177 ctx: &mut CompilerContext,
178 graph: &mut EinsumGraph,
179 data: &CustomOpData,
180 ) -> Result<usize> {
181 let handlers = self.handlers.read().unwrap();
182
183 let (handler, metadata) = handlers
184 .get(name)
185 .ok_or_else(|| anyhow::anyhow!("Custom operation '{}' not found", name))?;
186
187 if let Some(expected) = metadata.expected_arity {
189 if let TLExpr::Pred { args, .. } = expr {
190 if args.len() != expected {
191 bail!(
192 "Custom operation '{}' expects {} arguments, got {}",
193 name,
194 expected,
195 args.len()
196 );
197 }
198 }
199 }
200
201 handler(expr, ctx, graph, data)
202 }
203}
204
205#[derive(Clone)]
207pub struct ExtendedCompilerContext {
208 pub base_context: CompilerContext,
210 pub custom_ops: Arc<CustomOpRegistry>,
212 pub custom_data: CustomOpData,
214}
215
216impl ExtendedCompilerContext {
217 pub fn new() -> Self {
219 Self {
220 base_context: CompilerContext::new(),
221 custom_ops: Arc::new(CustomOpRegistry::new()),
222 custom_data: CustomOpData::new(),
223 }
224 }
225
226 pub fn from_context(ctx: CompilerContext) -> Self {
228 Self {
229 base_context: ctx,
230 custom_ops: Arc::new(CustomOpRegistry::new()),
231 custom_data: CustomOpData::new(),
232 }
233 }
234
235 pub fn with_config(mut self, config: CompilationConfig) -> Self {
237 self.base_context = CompilerContext::with_config(config);
238 self
239 }
240
241 pub fn with_custom_data(mut self, data: CustomOpData) -> Self {
243 self.custom_data = data;
244 self
245 }
246
247 pub fn custom_ops_mut(&mut self) -> &mut CustomOpRegistry {
249 Arc::get_mut(&mut self.custom_ops)
250 .expect("Cannot get mutable access to shared CustomOpRegistry")
251 }
252}
253
254impl Default for ExtendedCompilerContext {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260pub mod presets {
262 use super::*;
263
264 pub fn create_soft_threshold_and(sharpness: f64) -> (CustomOpMetadata, CustomOpHandler) {
268 let metadata = CustomOpMetadata {
269 name: "soft_threshold_and".to_string(),
270 description: format!("Soft threshold AND with sharpness parameter {}", sharpness),
271 expected_arity: Some(2),
272 is_differentiable: true,
273 };
274
275 let handler = Arc::new(
276 move |_expr: &TLExpr,
277 _ctx: &mut CompilerContext,
278 graph: &mut EinsumGraph,
279 data: &CustomOpData| {
280 let _k = data.get_numeric("sharpness").unwrap_or(sharpness);
282
283 let tensor_idx = graph.add_tensor("soft_threshold_and_result");
286
287 Ok(tensor_idx)
288 },
289 ) as CustomOpHandler;
290
291 (metadata, handler)
292 }
293
294 pub fn create_weighted_or(w1: f64, w2: f64) -> (CustomOpMetadata, CustomOpHandler) {
298 let metadata = CustomOpMetadata {
299 name: "weighted_or".to_string(),
300 description: format!("Weighted OR with weights {} and {}", w1, w2),
301 expected_arity: Some(2),
302 is_differentiable: true,
303 };
304
305 let handler = Arc::new(
306 move |_expr: &TLExpr,
307 _ctx: &mut CompilerContext,
308 graph: &mut EinsumGraph,
309 data: &CustomOpData| {
310 let weight1 = data.get_numeric("w1").unwrap_or(w1);
311 let weight2 = data.get_numeric("w2").unwrap_or(w2);
312
313 let tensor_idx =
315 graph.add_tensor(format!("weighted_or_result_{}_{}", weight1, weight2));
316
317 Ok(tensor_idx)
318 },
319 ) as CustomOpHandler;
320
321 (metadata, handler)
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that touches nothing and reports tensor index 0.
    fn noop_handler() -> CustomOpHandler {
        Arc::new(
            |_: &TLExpr,
             _: &mut CompilerContext,
             _: &mut EinsumGraph,
             _: &CustomOpData|
             -> Result<usize> { Ok(0) },
        )
    }

    /// Minimal metadata for registry tests.
    fn sample_metadata(name: &str) -> CustomOpMetadata {
        CustomOpMetadata {
            name: name.to_string(),
            description: String::new(),
            expected_arity: Some(1),
            is_differentiable: false,
        }
    }

    #[test]
    fn test_custom_op_data() {
        let data = CustomOpData::new()
            .with_string("mode", "test")
            .with_numeric("threshold", 0.5);

        assert_eq!(data.get_string("mode"), Some(&"test".to_string()));
        assert_eq!(data.get_numeric("threshold"), Some(0.5));
        assert_eq!(data.get_string("nonexistent"), None);
    }

    #[test]
    fn test_extended_context() {
        let ctx = ExtendedCompilerContext::new();
        assert_eq!(ctx.base_context.domains.len(), 0);
    }

    #[test]
    fn test_registry_register_and_lookup() {
        let mut ext = ExtendedCompilerContext::new();
        let registry = ext.custom_ops_mut();

        registry
            .register("my_op", sample_metadata("my_op"), noop_handler())
            .unwrap();

        assert!(registry.has_operation("my_op"));
        assert!(!registry.has_operation("other_op"));

        let meta = registry.get_metadata("my_op").unwrap();
        assert_eq!(meta.name, "my_op");
        assert_eq!(meta.expected_arity, Some(1));
        assert!(registry.get_metadata("other_op").is_none());

        assert_eq!(registry.list_operations(), vec!["my_op".to_string()]);
    }

    #[test]
    fn test_registry_rejects_duplicates_and_unknown_unregister() {
        let mut ext = ExtendedCompilerContext::new();
        let registry = ext.custom_ops_mut();

        registry
            .register("dup", sample_metadata("dup"), noop_handler())
            .unwrap();
        assert!(registry
            .register("dup", sample_metadata("dup"), noop_handler())
            .is_err());

        registry.unregister("dup").unwrap();
        assert!(!registry.has_operation("dup"));
        assert!(registry.unregister("dup").is_err());
    }

    #[test]
    fn test_preset_soft_threshold_and() {
        let (metadata, _handler) = presets::create_soft_threshold_and(2.0);
        assert_eq!(metadata.name, "soft_threshold_and");
        assert_eq!(metadata.expected_arity, Some(2));
        assert!(metadata.is_differentiable);
    }

    #[test]
    fn test_preset_weighted_or() {
        let (metadata, _handler) = presets::create_weighted_or(0.6, 0.4);
        assert_eq!(metadata.name, "weighted_or");
        assert_eq!(metadata.expected_arity, Some(2));
        assert!(metadata.is_differentiable);
    }
}