1use super::OptimizationError;
7use crate::graph::{Graph, TensorID};
8use crate::tensor::TensorInternal;
9use crate::Float;
10use std::collections::{HashMap, HashSet};
11
12pub struct ConstantFolder<F: Float> {
14 constant_cache: HashMap<TensorID, F>,
16 constant_nodes: HashSet<TensorID>,
18}
19
20impl<F: Float> ConstantFolder<F> {
21 pub fn new() -> Self {
23 Self {
24 constant_cache: HashMap::new(),
25 constant_nodes: HashSet::new(),
26 }
27 }
28
29 pub fn fold_constants(&mut self, graph: &mut Graph<F>) -> Result<usize, OptimizationError> {
31 let folded_count = 0;
32
33 self.mark_constant_nodes(graph)?;
40 let _propagated = self.propagate_constants(graph)?;
41 let _evaluated = self.evaluate_constant_expressions(graph)?;
42
43 Ok(folded_count)
44 }
45
46 fn mark_constant_nodes(&mut self, _graph: &Graph<F>) -> Result<(), OptimizationError> {
48 Ok(())
54 }
55
56 fn propagate_constants(&mut self, _graph: &Graph<F>) -> Result<usize, OptimizationError> {
58 Ok(0)
63 }
64
65 fn evaluate_constant_expressions(
67 &mut self,
68 _graph: &mut Graph<F>,
69 ) -> Result<usize, OptimizationError> {
70 Ok(0)
76 }
77
78 pub fn is_constant(&self, tensor_id: TensorID) -> bool {
80 self.constant_nodes.contains(&tensor_id)
81 }
82
83 pub fn get_constant_value(&self, tensor_id: TensorID) -> Option<F> {
85 self.constant_cache.get(&tensor_id).copied()
86 }
87
88 pub fn clear_cache(&mut self) {
90 self.constant_cache.clear();
91 self.constant_nodes.clear();
92 }
93}
94
95impl<F: Float> Default for ConstantFolder<F> {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101#[derive(Debug, Clone)]
103pub enum ConstantValue<F: Float> {
104 Scalar(F),
106 Vector(Vec<F>),
108 Matrix { values: Vec<F>, shape: Vec<usize> },
110}
111
112impl<F: Float> ConstantValue<F> {
113 pub fn is_zero(&self) -> bool {
115 match self {
116 ConstantValue::Scalar(x) => x.is_zero(),
117 ConstantValue::Vector(v) => v.iter().all(|x| x.is_zero()),
118 ConstantValue::Matrix { values, .. } => values.iter().all(|x| x.is_zero()),
119 }
120 }
121
122 pub fn is_one(&self) -> bool {
124 match self {
125 ConstantValue::Scalar(x) => *x == F::one(),
126 ConstantValue::Vector(v) => v.iter().all(|x| *x == F::one()),
127 ConstantValue::Matrix { values, .. } => values.iter().all(|x| *x == F::one()),
128 }
129 }
130
131 pub fn shape(&self) -> Vec<usize> {
133 match self {
134 ConstantValue::Scalar(_) => vec![],
135 ConstantValue::Vector(v) => vec![v.len()],
136 ConstantValue::Matrix { shape, .. } => shape.clone(),
137 }
138 }
139}
140
/// Predicates over [`ConstantValue`]s, evaluated by `ConstantPattern::matches`.
///
/// The enum is a plain tag, so it is comparable and hashable in addition to
/// being `Copy`; this lets patterns be compared in rewrite rules and used as
/// map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConstantPattern {
    /// Every element of the value is zero.
    Zero,
    /// Every element of the value is one.
    One,
    /// The value is a scalar equal to negative one; vectors and matrices
    /// never match.
    NegativeOne,
    /// The value is not all-zero, i.e. at least one element is non-zero.
    NonZero,
    /// Pattern for finite values; see `ConstantPattern::matches` for the
    /// exact test applied.
    Finite,
}
155
156impl ConstantPattern {
157 pub fn matches<F: Float>(&self, value: &ConstantValue<F>) -> bool {
159 match self {
160 ConstantPattern::Zero => value.is_zero(),
161 ConstantPattern::One => value.is_one(),
162 ConstantPattern::NegativeOne => {
163 matches!(value, ConstantValue::Scalar(x) if *x == -F::one())
164 }
165 ConstantPattern::NonZero => !value.is_zero(),
166 ConstantPattern::Finite => true, }
168 }
169}
170
171#[allow(dead_code)]
175pub(crate) fn is_literal_constant<F: Float>(_tensor_internal: &TensorInternal<F>) -> bool {
176 false
178}
179
180#[allow(dead_code)]
182pub(crate) fn extract_constant_value<F: Float>(
183 _tensor_internal: &TensorInternal<F>,
184) -> Option<ConstantValue<F>> {
185 None
187}
188
189#[allow(dead_code)]
191pub fn create_constant_tensor<F: Float>(
192 _graph: &mut Graph<F>,
193 _value: ConstantValue<F>,
194) -> Result<TensorID, OptimizationError> {
195 Err(OptimizationError::InvalidOperation(
197 "Not implemented".to_string(),
198 ))
199}
200
201impl<F: Float> ConstantValue<F> {
203 pub fn add(&self, other: &Self) -> Result<Self, OptimizationError> {
205 match (self, other) {
206 (ConstantValue::Scalar(a), ConstantValue::Scalar(b)) => {
207 Ok(ConstantValue::Scalar(*a + *b))
208 }
209 (ConstantValue::Vector(a), ConstantValue::Vector(b)) => {
210 if a.len() != b.len() {
211 return Err(OptimizationError::InvalidOperation(format!(
212 "Vector length mismatch in add: {} vs {}",
213 a.len(),
214 b.len()
215 )));
216 }
217 Ok(ConstantValue::Vector(
218 a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect(),
219 ))
220 }
221 (
222 ConstantValue::Matrix {
223 values: a,
224 shape: sa,
225 },
226 ConstantValue::Matrix {
227 values: b,
228 shape: sb,
229 },
230 ) => {
231 if sa != sb {
232 return Err(OptimizationError::InvalidOperation(format!(
233 "Matrix shape mismatch in add: {:?} vs {:?}",
234 sa, sb
235 )));
236 }
237 Ok(ConstantValue::Matrix {
238 values: a.iter().zip(b.iter()).map(|(&x, &y)| x + y).collect(),
239 shape: sa.clone(),
240 })
241 }
242 _ => Err(OptimizationError::InvalidOperation(
243 "Incompatible constant types for addition".to_string(),
244 )),
245 }
246 }
247
248 pub fn sub(&self, other: &Self) -> Result<Self, OptimizationError> {
250 match (self, other) {
251 (ConstantValue::Scalar(a), ConstantValue::Scalar(b)) => {
252 Ok(ConstantValue::Scalar(*a - *b))
253 }
254 (ConstantValue::Vector(a), ConstantValue::Vector(b)) => {
255 if a.len() != b.len() {
256 return Err(OptimizationError::InvalidOperation(format!(
257 "Vector length mismatch in sub: {} vs {}",
258 a.len(),
259 b.len()
260 )));
261 }
262 Ok(ConstantValue::Vector(
263 a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect(),
264 ))
265 }
266 (
267 ConstantValue::Matrix {
268 values: a,
269 shape: sa,
270 },
271 ConstantValue::Matrix {
272 values: b,
273 shape: sb,
274 },
275 ) => {
276 if sa != sb {
277 return Err(OptimizationError::InvalidOperation(format!(
278 "Matrix shape mismatch in sub: {:?} vs {:?}",
279 sa, sb
280 )));
281 }
282 Ok(ConstantValue::Matrix {
283 values: a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect(),
284 shape: sa.clone(),
285 })
286 }
287 _ => Err(OptimizationError::InvalidOperation(
288 "Incompatible constant types for subtraction".to_string(),
289 )),
290 }
291 }
292
293 pub fn mul(&self, other: &Self) -> Result<Self, OptimizationError> {
295 match (self, other) {
296 (ConstantValue::Scalar(a), ConstantValue::Scalar(b)) => {
297 Ok(ConstantValue::Scalar(*a * *b))
298 }
299 (ConstantValue::Vector(a), ConstantValue::Vector(b)) => {
300 if a.len() != b.len() {
301 return Err(OptimizationError::InvalidOperation(format!(
302 "Vector length mismatch in mul: {} vs {}",
303 a.len(),
304 b.len()
305 )));
306 }
307 Ok(ConstantValue::Vector(
308 a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect(),
309 ))
310 }
311 (ConstantValue::Scalar(s), ConstantValue::Vector(v))
312 | (ConstantValue::Vector(v), ConstantValue::Scalar(s)) => {
313 Ok(ConstantValue::Vector(v.iter().map(|&x| x * *s).collect()))
314 }
315 (ConstantValue::Scalar(s), ConstantValue::Matrix { values, shape })
316 | (ConstantValue::Matrix { values, shape }, ConstantValue::Scalar(s)) => {
317 Ok(ConstantValue::Matrix {
318 values: values.iter().map(|&x| x * *s).collect(),
319 shape: shape.clone(),
320 })
321 }
322 (
323 ConstantValue::Matrix {
324 values: a,
325 shape: sa,
326 },
327 ConstantValue::Matrix {
328 values: b,
329 shape: sb,
330 },
331 ) => {
332 if sa != sb {
333 return Err(OptimizationError::InvalidOperation(format!(
334 "Matrix shape mismatch in mul: {:?} vs {:?}",
335 sa, sb
336 )));
337 }
338 Ok(ConstantValue::Matrix {
339 values: a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect(),
340 shape: sa.clone(),
341 })
342 }
343 (ConstantValue::Vector(_), ConstantValue::Matrix { .. })
344 | (ConstantValue::Matrix { .. }, ConstantValue::Vector(_)) => {
345 Err(OptimizationError::InvalidOperation(
346 "Incompatible constant types for multiplication (Vector vs Matrix)".to_string(),
347 ))
348 }
349 }
350 }
351
352 pub fn div(&self, other: &Self) -> Result<Self, OptimizationError> {
354 match (self, other) {
355 (ConstantValue::Scalar(a), ConstantValue::Scalar(b)) => {
356 if b.is_zero() {
357 return Err(OptimizationError::InvalidOperation(
358 "Division by zero".to_string(),
359 ));
360 }
361 Ok(ConstantValue::Scalar(*a / *b))
362 }
363 (ConstantValue::Vector(a), ConstantValue::Vector(b)) => {
364 if a.len() != b.len() {
365 return Err(OptimizationError::InvalidOperation(format!(
366 "Vector length mismatch in div: {} vs {}",
367 a.len(),
368 b.len()
369 )));
370 }
371 if b.iter().any(|x| x.is_zero()) {
372 return Err(OptimizationError::InvalidOperation(
373 "Division by zero in vector".to_string(),
374 ));
375 }
376 Ok(ConstantValue::Vector(
377 a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect(),
378 ))
379 }
380 (ConstantValue::Vector(v), ConstantValue::Scalar(s)) => {
381 if s.is_zero() {
382 return Err(OptimizationError::InvalidOperation(
383 "Division by zero".to_string(),
384 ));
385 }
386 Ok(ConstantValue::Vector(v.iter().map(|&x| x / *s).collect()))
387 }
388 (ConstantValue::Matrix { values, shape }, ConstantValue::Scalar(s)) => {
389 if s.is_zero() {
390 return Err(OptimizationError::InvalidOperation(
391 "Division by zero".to_string(),
392 ));
393 }
394 Ok(ConstantValue::Matrix {
395 values: values.iter().map(|&x| x / *s).collect(),
396 shape: shape.clone(),
397 })
398 }
399 (
400 ConstantValue::Matrix {
401 values: a,
402 shape: sa,
403 },
404 ConstantValue::Matrix {
405 values: b,
406 shape: sb,
407 },
408 ) => {
409 if sa != sb {
410 return Err(OptimizationError::InvalidOperation(format!(
411 "Matrix shape mismatch in div: {:?} vs {:?}",
412 sa, sb
413 )));
414 }
415 if b.iter().any(|x| x.is_zero()) {
416 return Err(OptimizationError::InvalidOperation(
417 "Division by zero in matrix".to_string(),
418 ));
419 }
420 Ok(ConstantValue::Matrix {
421 values: a.iter().zip(b.iter()).map(|(&x, &y)| x / y).collect(),
422 shape: sa.clone(),
423 })
424 }
425 (ConstantValue::Scalar(_), ConstantValue::Vector(_))
426 | (ConstantValue::Scalar(_), ConstantValue::Matrix { .. })
427 | (ConstantValue::Vector(_), ConstantValue::Matrix { .. })
428 | (ConstantValue::Matrix { .. }, ConstantValue::Vector(_)) => {
429 Err(OptimizationError::InvalidOperation(
430 "Incompatible constant types for division".to_string(),
431 ))
432 }
433 }
434 }
435
436 pub fn neg(&self) -> Result<Self, OptimizationError> {
438 match self {
439 ConstantValue::Scalar(x) => Ok(ConstantValue::Scalar(-*x)),
440 ConstantValue::Vector(v) => Ok(ConstantValue::Vector(v.iter().map(|x| -*x).collect())),
441 ConstantValue::Matrix { values, shape } => Ok(ConstantValue::Matrix {
442 values: values.iter().map(|x| -*x).collect(),
443 shape: shape.clone(),
444 }),
445 }
446 }
447
448 pub fn apply_unary<Func>(&self, func: Func) -> Result<Self, OptimizationError>
450 where
451 Func: Fn(F) -> F,
452 {
453 match self {
454 ConstantValue::Scalar(x) => Ok(ConstantValue::Scalar(func(*x))),
455 ConstantValue::Vector(v) => {
456 Ok(ConstantValue::Vector(v.iter().map(|x| func(*x)).collect()))
457 }
458 ConstantValue::Matrix { values, shape } => Ok(ConstantValue::Matrix {
459 values: values.iter().map(|x| func(*x)).collect(),
460 shape: shape.clone(),
461 }),
462 }
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing a folder must not panic.
    #[test]
    fn test_constant_folder_creation() {
        let _folder = ConstantFolder::<f32>::new();
    }

    /// `shape()` reports the expected dimensions for each kind of value.
    #[test]
    fn test_constant_value_creation() {
        let scalar = ConstantValue::Scalar(42.0f32);
        assert!(scalar.shape().is_empty());

        let vector = ConstantValue::Vector(vec![1.0, 2.0, 3.0]);
        assert_eq!(vector.shape(), [3usize]);

        let matrix = ConstantValue::Matrix {
            values: vec![1.0, 2.0, 3.0, 4.0],
            shape: vec![2, 2],
        };
        assert_eq!(matrix.shape(), [2usize, 2]);
    }

    /// Each pattern accepts its matching scalar and rejects a non-matching one.
    #[test]
    fn test_constant_patterns() {
        let zero = ConstantValue::Scalar(0.0f32);
        let one = ConstantValue::Scalar(1.0f32);
        let neg_one = ConstantValue::Scalar(-1.0f32);
        let other = ConstantValue::Scalar(42.0f32);

        let cases = [
            (ConstantPattern::Zero, &zero, true),
            (ConstantPattern::Zero, &one, false),
            (ConstantPattern::One, &one, true),
            (ConstantPattern::One, &zero, false),
            (ConstantPattern::NegativeOne, &neg_one, true),
            (ConstantPattern::NegativeOne, &one, false),
            (ConstantPattern::NonZero, &other, true),
            (ConstantPattern::NonZero, &zero, false),
            (ConstantPattern::Finite, &other, true),
        ];

        for &(pattern, value, expected) in &cases {
            assert_eq!(
                pattern.matches(value),
                expected,
                "{:?} against {:?}",
                pattern,
                value
            );
        }
    }

    /// `is_zero` and `is_one` each accept exactly one of the sample scalars.
    #[test]
    fn test_constant_value_properties() {
        let zero = ConstantValue::Scalar(0.0f32);
        let one = ConstantValue::Scalar(1.0f32);
        let other = ConstantValue::Scalar(42.0f32);

        // (value, expected is_zero, expected is_one)
        let cases = [(&zero, true, false), (&one, false, true), (&other, false, false)];

        for &(value, expect_zero, expect_one) in &cases {
            assert_eq!(value.is_zero(), expect_zero, "is_zero of {:?}", value);
            assert_eq!(value.is_one(), expect_one, "is_one of {:?}", value);
        }
    }

    /// Negating a scalar flips its sign and keeps it a scalar.
    #[test]
    fn test_constant_value_negation() {
        let positive = ConstantValue::Scalar(42.0f32);

        match positive.neg().expect("Operation failed") {
            ConstantValue::Scalar(val) => assert_eq!(val, -42.0),
            _ => panic!("Expected scalar result"),
        }
    }

    /// `apply_unary` maps a scalar through the supplied function.
    #[test]
    fn test_constant_value_unary_function() {
        let value = ConstantValue::Scalar(4.0f32);

        match value.apply_unary(|x| x.sqrt()).expect("Operation failed") {
            ConstantValue::Scalar(val) => assert_eq!(val, 2.0),
            _ => panic!("Expected scalar result"),
        }
    }
}