1use crate::errors::{Result, TrustformersError};
37use crate::tensor::{DType, Tensor};
38use serde::{Deserialize, Serialize};
39use std::collections::HashMap;
40use std::fmt;
41use std::sync::Arc;
42
43#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
45pub enum OpType {
46 Add,
48 Sub,
49 Mul,
50 Div,
51 MatMul,
53 Transpose,
54 ReLU,
56 Sigmoid,
57 Tanh,
58 GELU,
59 Softmax(i32), Sum(Option<Vec<usize>>), Mean(Option<Vec<usize>>), Max(Option<Vec<usize>>), Min(Option<Vec<usize>>), Reshape(Vec<usize>),
67 Slice(Vec<(usize, usize)>), Concat(usize), Broadcast(Vec<usize>), Pow(f64), Sqrt,
74 Log,
75 Exp,
76 Greater,
78 Less,
79 Equal,
80 Where, }
83
84#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct ExprNode {
87 pub id: usize,
88 pub op: OpType,
89 pub operands: Vec<usize>, pub shape: Vec<usize>,
91 pub dtype: DType,
92 pub is_leaf: bool, #[serde(skip)]
94 pub tensor_data: Option<Arc<Tensor>>, }
96
/// A lazily evaluated tensor expression graph.
///
/// Builder methods (`add`, `matmul`, `sum`, …) append nodes without
/// computing anything; `eval`/`eval_with_context` walk the graph and
/// produce a concrete `Tensor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorExpr {
    // Every node of the graph, keyed by its unique id.
    nodes: HashMap<usize, ExprNode>,
    // Id of the node whose value is the expression's result.
    root: usize,
    // Next id to assign when a node is appended.
    next_id: usize,
}
104
/// Cursor-style builder over a mutable [`TensorExpr`].
// NOTE(review): no methods for this type are visible in this file —
// presumably a work-in-progress API, hence the dead_code allowance.
#[allow(dead_code)]
pub struct ExprBuilder<'a> {
    // Expression being built.
    expr: &'a mut TensorExpr,
    // Node the builder is currently positioned on.
    current_node: usize,
}
111
/// Switches controlling how a [`TensorExpr`] is optimized before
/// evaluation.
///
/// In this file only `enable_fusion` is consulted (by
/// `eval_with_context`); the remaining fields are declared for future
/// optimization passes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationHints {
    // Fuse chains of element-wise operations before evaluating.
    pub enable_fusion: bool,
    // Prefer cache-friendly memory layouts (not consulted in this file).
    pub optimize_memory_layout: bool,
    // Allow SIMD-style vectorization (not consulted in this file).
    pub enable_vectorization: bool,
    // Upper bound on ops merged into one fused kernel (not enforced by
    // the placeholder fusion in this file).
    pub max_fusion_size: usize,
    // Rewrite operations in place when legal (not consulted in this file).
    pub prefer_inplace: bool,
}
126
127#[derive(Debug, Clone, Default)]
129pub struct EvalContext {
130 pub hints: OptimizationHints,
131 pub device: Option<String>,
132 pub memory_budget: Option<usize>, }
134
135impl fmt::Display for TensorExpr {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 write!(f, "{}", self.node_to_string(self.root))
138 }
139}
140
141impl TensorExpr {
142 pub fn from(tensor: &Tensor) -> Result<Self> {
144 let shape = tensor.shape();
145 let dtype = tensor.dtype();
146
147 let mut nodes = HashMap::new();
148 let root_node = ExprNode {
149 id: 0,
150 op: OpType::Add, operands: vec![],
152 shape,
153 dtype,
154 is_leaf: true,
155 tensor_data: Some(Arc::new(tensor.clone())),
156 };
157
158 nodes.insert(0, root_node);
159
160 Ok(TensorExpr {
161 nodes,
162 root: 0,
163 next_id: 1,
164 })
165 }
166
167 pub fn constant(tensor: Tensor) -> Result<Self> {
169 Self::from(&tensor)
170 }
171
172 pub fn shape(&self) -> Vec<usize> {
174 self.nodes[&self.root].shape.clone()
175 }
176
177 pub fn dtype(&self) -> DType {
179 self.nodes[&self.root].dtype
180 }
181
182 #[allow(clippy::should_implement_trait)] pub fn add(self, other: TensorExpr) -> Result<Self> {
185 self.binary_op(other, OpType::Add)
186 }
187
188 #[allow(clippy::should_implement_trait)] pub fn sub(self, other: TensorExpr) -> Result<Self> {
191 self.binary_op(other, OpType::Sub)
192 }
193
194 #[allow(clippy::should_implement_trait)] pub fn mul(self, other: TensorExpr) -> Result<Self> {
197 self.binary_op(other, OpType::Mul)
198 }
199
200 #[allow(clippy::should_implement_trait)] pub fn div(self, other: TensorExpr) -> Result<Self> {
203 self.binary_op(other, OpType::Div)
204 }
205
    /// Appends a matrix-multiplication node: `self × other`.
    ///
    /// Both operands must be at least 2-D, and the contraction
    /// dimensions (last of left, second-to-last of right) must agree.
    /// The result shape is the left shape with its last dimension
    /// replaced by the right operand's last dimension.
    // NOTE(review): leading/batch dimensions are not validated here —
    // presumably Tensor::matmul checks them at eval time; confirm.
    pub fn matmul(mut self, other: TensorExpr) -> Result<Self> {
        let left_shape = self.nodes[&self.root].shape.clone();
        let right_shape = other.nodes[&other.root].shape.clone();

        if left_shape.len() < 2 || right_shape.len() < 2 {
            return Err(TrustformersError::tensor_op_error(
                "Matrix multiplication requires at least 2D tensors",
                "matmul_validate",
            ));
        }

        // Inner dimensions that get contracted must match.
        let left_cols = left_shape[left_shape.len() - 1];
        let right_rows = right_shape[right_shape.len() - 2];

        if left_cols != right_rows {
            return Err(TrustformersError::tensor_op_error(
                &format!(
                    "Incompatible shapes for matmul: {:?} x {:?}",
                    left_shape, right_shape
                ),
                "matmul_shape_check",
            ));
        }

        // Absorb `other`'s nodes into this graph (ids are re-offset).
        let other_root = self.merge_expression(other)?;

        // Result shape: left's leading dims + right's trailing dim.
        let mut result_shape = left_shape[..left_shape.len() - 1].to_vec();
        result_shape.push(right_shape[right_shape.len() - 1]);

        let new_node = ExprNode {
            id: self.next_id,
            op: OpType::MatMul,
            operands: vec![self.root, other_root],
            shape: result_shape,
            dtype: self.nodes[&self.root].dtype,
            is_leaf: false,
            tensor_data: None,
        };

        self.nodes.insert(self.next_id, new_node);
        self.root = self.next_id;
        self.next_id += 1;

        Ok(self)
    }
255
256 pub fn relu(self) -> Result<Self> {
258 self.unary_op(OpType::ReLU)
259 }
260
261 pub fn sigmoid(self) -> Result<Self> {
263 self.unary_op(OpType::Sigmoid)
264 }
265
266 pub fn tanh(self) -> Result<Self> {
268 self.unary_op(OpType::Tanh)
269 }
270
271 pub fn gelu(self) -> Result<Self> {
273 self.unary_op(OpType::GELU)
274 }
275
276 pub fn softmax(self, axis: i32) -> Result<Self> {
278 self.unary_op(OpType::Softmax(axis))
279 }
280
281 pub fn sum(mut self, axes: Option<Vec<usize>>) -> Result<Self> {
283 let result_shape = if let Some(ref axes) = axes {
284 let mut shape = self.nodes[&self.root].shape.clone();
285 let mut sorted_axes = axes.clone();
287 sorted_axes.sort_by(|a, b| b.cmp(a));
288 for &axis in &sorted_axes {
289 if axis >= shape.len() {
290 return Err(TrustformersError::tensor_op_error(
291 &format!(
292 "Axis {} out of bounds for tensor with {} dimensions",
293 axis,
294 shape.len()
295 ),
296 "reduce",
297 ));
298 }
299 shape.remove(axis);
300 }
301 shape
302 } else {
303 vec![] };
305
306 let new_node = ExprNode {
307 id: self.next_id,
308 op: OpType::Sum(axes),
309 operands: vec![self.root],
310 shape: result_shape,
311 dtype: self.nodes[&self.root].dtype,
312 is_leaf: false,
313 tensor_data: None,
314 };
315
316 self.nodes.insert(self.next_id, new_node);
317 self.root = self.next_id;
318 self.next_id += 1;
319
320 Ok(self)
321 }
322
323 pub fn mean(mut self, axes: Option<Vec<usize>>) -> Result<Self> {
325 let result_shape = if let Some(ref axes) = axes {
326 let mut shape = self.nodes[&self.root].shape.clone();
327 let mut sorted_axes = axes.clone();
328 sorted_axes.sort_by(|a, b| b.cmp(a));
329 for &axis in &sorted_axes {
330 if axis >= shape.len() {
331 return Err(TrustformersError::tensor_op_error(
332 &format!(
333 "Axis {} out of bounds for tensor with {} dimensions",
334 axis,
335 shape.len()
336 ),
337 "reduce",
338 ));
339 }
340 shape.remove(axis);
341 }
342 shape
343 } else {
344 vec![] };
346
347 let new_node = ExprNode {
348 id: self.next_id,
349 op: OpType::Mean(axes),
350 operands: vec![self.root],
351 shape: result_shape,
352 dtype: self.nodes[&self.root].dtype,
353 is_leaf: false,
354 tensor_data: None,
355 };
356
357 self.nodes.insert(self.next_id, new_node);
358 self.root = self.next_id;
359 self.next_id += 1;
360
361 Ok(self)
362 }
363
364 pub fn reshape(mut self, shape: &[usize]) -> Result<Self> {
366 let current_shape = &self.nodes[&self.root].shape;
368 let current_size: usize = current_shape.iter().product();
369 let new_size: usize = shape.iter().product();
370
371 if current_size != new_size {
372 return Err(TrustformersError::tensor_op_error(
373 &format!(
374 "Cannot reshape tensor with {} elements to shape with {} elements",
375 current_size, new_size
376 ),
377 "reshape",
378 ));
379 }
380
381 let new_node = ExprNode {
382 id: self.next_id,
383 op: OpType::Reshape(shape.to_vec()),
384 operands: vec![self.root],
385 shape: shape.to_vec(),
386 dtype: self.nodes[&self.root].dtype,
387 is_leaf: false,
388 tensor_data: None,
389 };
390
391 self.nodes.insert(self.next_id, new_node);
392 self.root = self.next_id;
393 self.next_id += 1;
394
395 Ok(self)
396 }
397
398 pub fn transpose(mut self) -> Result<Self> {
400 let current_shape = &self.nodes[&self.root].shape;
401 if current_shape.len() < 2 {
402 return Err(TrustformersError::tensor_op_error(
403 "Transpose requires at least 2D tensor",
404 "transpose",
405 ));
406 }
407
408 let mut new_shape = current_shape.clone();
409 let len = new_shape.len();
410 new_shape.swap(len - 2, len - 1);
411
412 let new_node = ExprNode {
413 id: self.next_id,
414 op: OpType::Transpose,
415 operands: vec![self.root],
416 shape: new_shape,
417 dtype: self.nodes[&self.root].dtype,
418 is_leaf: false,
419 tensor_data: None,
420 };
421
422 self.nodes.insert(self.next_id, new_node);
423 self.root = self.next_id;
424 self.next_id += 1;
425
426 Ok(self)
427 }
428
429 pub fn eval(&self) -> Result<Tensor> {
431 self.eval_with_context(&EvalContext::default())
432 }
433
434 pub fn eval_with_context(&self, context: &EvalContext) -> Result<Tensor> {
436 let optimized_expr =
438 if context.hints.enable_fusion { self.optimize_fusion()? } else { self.clone() };
439
440 optimized_expr.eval_recursive(optimized_expr.root, context)
442 }
443
444 pub fn can_fuse_with(&self, other: &TensorExpr) -> bool {
446 self.shape() == other.shape() && self.is_elementwise() && other.is_elementwise()
448 }
449
450 pub fn operation_count(&self) -> usize {
452 self.nodes.len() - self.leaf_count()
453 }
454
455 pub fn leaf_count(&self) -> usize {
457 self.nodes.values().filter(|n| n.is_leaf).count()
458 }
459
460 pub fn to_dot(&self) -> String {
462 let mut dot = String::from("digraph TensorExpr {\n");
463
464 for node in self.nodes.values() {
465 let label = if node.is_leaf {
466 format!("Tensor\\n{:?}\\n{:?}", node.shape, node.dtype)
467 } else {
468 format!("{:?}\\n{:?}\\n{:?}", node.op, node.shape, node.dtype)
469 };
470
471 let color = if node.is_leaf { "lightblue" } else { "lightgreen" };
472 dot.push_str(&format!(
473 " {} [label=\"{}\" fillcolor={} style=filled];\n",
474 node.id, label, color
475 ));
476
477 for &operand in &node.operands {
478 dot.push_str(&format!(" {} -> {};\n", operand, node.id));
479 }
480 }
481
482 dot.push_str("}\n");
483 dot
484 }
485
486 fn binary_op(mut self, other: TensorExpr, op: OpType) -> Result<Self> {
489 let left_shape = &self.nodes[&self.root].shape;
491 let right_shape = &other.nodes[&other.root].shape;
492 let result_shape = self.broadcast_shapes(left_shape, right_shape)?;
493
494 let other_root = self.merge_expression(other)?;
496
497 let new_node = ExprNode {
498 id: self.next_id,
499 op,
500 operands: vec![self.root, other_root],
501 shape: result_shape,
502 dtype: self.nodes[&self.root].dtype, is_leaf: false,
504 tensor_data: None,
505 };
506
507 self.nodes.insert(self.next_id, new_node);
508 self.root = self.next_id;
509 self.next_id += 1;
510
511 Ok(self)
512 }
513
514 fn unary_op(mut self, op: OpType) -> Result<Self> {
515 let new_node = ExprNode {
516 id: self.next_id,
517 op,
518 operands: vec![self.root],
519 shape: self.nodes[&self.root].shape.clone(),
520 dtype: self.nodes[&self.root].dtype,
521 is_leaf: false,
522 tensor_data: None,
523 };
524
525 self.nodes.insert(self.next_id, new_node);
526 self.root = self.next_id;
527 self.next_id += 1;
528
529 Ok(self)
530 }
531
    /// Moves every node of `other` into `self`, offsetting all ids by
    /// `self.next_id` so they cannot collide with existing nodes.
    ///
    /// Returns the (re-numbered) id of `other`'s root. Afterwards
    /// `next_id` is bumped by `other.next_id`, which is an upper bound
    /// on the ids just inserted.
    fn merge_expression(&mut self, other: TensorExpr) -> Result<usize> {
        let id_offset = self.next_id;

        for (old_id, mut node) in other.nodes {
            let new_id = old_id + id_offset;
            node.id = new_id;

            // Operand references are internal to `other`, so they all
            // shift by the same offset.
            for operand in &mut node.operands {
                *operand += id_offset;
            }

            self.nodes.insert(new_id, node);
        }

        self.next_id += other.next_id;
        Ok(other.root + id_offset)
    }
551
552 fn broadcast_shapes(&self, left: &[usize], right: &[usize]) -> Result<Vec<usize>> {
553 let max_len = left.len().max(right.len());
554 let mut result = vec![1; max_len];
555
556 for i in 0..max_len {
557 let left_dim = if i < left.len() { left[left.len() - 1 - i] } else { 1 };
558 let right_dim = if i < right.len() { right[right.len() - 1 - i] } else { 1 };
559
560 if left_dim == right_dim {
561 result[max_len - 1 - i] = left_dim;
562 } else if left_dim == 1 {
563 result[max_len - 1 - i] = right_dim;
564 } else if right_dim == 1 {
565 result[max_len - 1 - i] = left_dim;
566 } else {
567 return Err(TrustformersError::tensor_op_error(
568 &format!("Cannot broadcast shapes {:?} and {:?}", left, right),
569 "broadcast_shape_check",
570 ));
571 }
572 }
573
574 Ok(result)
575 }
576
577 fn is_elementwise(&self) -> bool {
578 matches!(
579 self.nodes[&self.root].op,
580 OpType::Add
581 | OpType::Sub
582 | OpType::Mul
583 | OpType::Div
584 | OpType::ReLU
585 | OpType::Sigmoid
586 | OpType::Tanh
587 | OpType::GELU
588 | OpType::Pow(_)
589 | OpType::Sqrt
590 | OpType::Log
591 | OpType::Exp
592 )
593 }
594
595 fn optimize_fusion(&self) -> Result<TensorExpr> {
596 let mut optimized = self.clone();
598
599 let fusion_chains = optimized.find_fusion_chains();
601
602 for chain in fusion_chains {
604 optimized.fuse_operations(&chain)?;
605 }
606
607 Ok(optimized)
608 }
609
610 fn find_fusion_chains(&self) -> Vec<Vec<usize>> {
611 let mut chains = Vec::new();
613 let mut visited = std::collections::HashSet::new();
614
615 for &node_id in self.nodes.keys() {
616 if visited.contains(&node_id) {
617 continue;
618 }
619
620 let mut chain = Vec::new();
621 let mut current = node_id;
622
623 while let Some(node) = self.nodes.get(¤t) {
624 if !self.is_node_elementwise(node) {
625 break;
626 }
627
628 chain.push(current);
629 visited.insert(current);
630
631 if node.operands.len() == 1 {
633 current = node.operands[0];
634 } else {
635 break;
636 }
637 }
638
639 if chain.len() > 1 {
640 chains.push(chain);
641 }
642 }
643
644 chains
645 }
646
647 fn is_node_elementwise(&self, node: &ExprNode) -> bool {
648 matches!(
649 node.op,
650 OpType::Add
651 | OpType::Sub
652 | OpType::Mul
653 | OpType::Div
654 | OpType::ReLU
655 | OpType::Sigmoid
656 | OpType::Tanh
657 | OpType::GELU
658 | OpType::Pow(_)
659 | OpType::Sqrt
660 | OpType::Log
661 | OpType::Exp
662 )
663 }
664
    /// Fuses a chain of element-wise operations into a single kernel.
    ///
    /// Currently a placeholder: chains shorter than two nodes return
    /// early, and longer chains are accepted but left untouched — no
    /// graph rewriting is implemented yet.
    fn fuse_operations(&mut self, chain: &[usize]) -> Result<()> {
        if chain.len() < 2 {
            return Ok(());
        }

        Ok(())
    }
678
    /// Recursively evaluates the sub-expression rooted at `node_id`.
    ///
    /// Leaves return a clone of their stored tensor; interior nodes
    /// evaluate all operands depth-first and then dispatch on `op`.
    /// The context is currently unused (`_context`).
    // NOTE(review): graphs built via merge_expression copy nodes, so
    // the structure is a tree and repeated evaluation of shared
    // subtrees cannot occur; revisit if true DAG sharing is added.
    fn eval_recursive(&self, node_id: usize, _context: &EvalContext) -> Result<Tensor> {
        let node = &self.nodes[&node_id];

        if node.is_leaf {
            // Deserialized expressions have leaves without tensor data
            // (#[serde(skip)]), which surfaces here as an error.
            return node
                .tensor_data
                .as_ref()
                .ok_or_else(|| {
                    TrustformersError::tensor_op_error(
                        "Leaf node must have tensor data",
                        "eval_recursive",
                    )
                })
                .map(|t| t.as_ref().clone());
        }

        // Evaluate operands depth-first; the first failure aborts.
        let operand_results: Result<Vec<Tensor>> =
            node.operands.iter().map(|&id| self.eval_recursive(id, _context)).collect();
        let operands = operand_results?;

        match &node.op {
            OpType::Add => operands[0].add(&operands[1]),
            OpType::Sub => operands[0].sub(&operands[1]),
            OpType::Mul => operands[0].mul(&operands[1]),
            OpType::Div => operands[0].div(&operands[1]),
            OpType::MatMul => operands[0].matmul(&operands[1]),
            OpType::Transpose => {
                let shape = operands[0].shape();
                let rank = shape.len();
                if rank < 2 {
                    return Err(crate::errors::TrustformersError::dimension_mismatch(
                        "at least 2 dimensions".to_string(),
                        format!("{} dimensions", rank),
                    ));
                }
                // Swap the two trailing dimensions, matching the shape
                // computed when the node was built.
                operands[0].transpose(rank - 2, rank - 1)
            },
            OpType::ReLU => operands[0].relu(),
            OpType::Sigmoid => operands[0].sigmoid(),
            OpType::Tanh => operands[0].tanh(),
            OpType::GELU => operands[0].gelu(),
            OpType::Softmax(axis) => operands[0].softmax(*axis),
            OpType::Sum(axes) => {
                match axes {
                    Some(ref axes_vec) => operands[0].sum_axes(axes_vec),
                    None => {
                        // Full reduction expressed as a sum over every axis.
                        let shape = operands[0].shape();
                        let all_axes: Vec<usize> = (0..shape.len()).collect();
                        operands[0].sum_axes(&all_axes)
                    },
                }
            },
            // NOTE(review): Sum(None) reduces via sum_axes over every
            // axis while Mean(None) calls Tensor::mean — confirm both
            // produce the same output shape.
            OpType::Mean(axes) => match axes {
                Some(ref axes_vec) => operands[0].mean_axes(axes_vec),
                None => operands[0].mean(),
            },
            OpType::Reshape(shape) => operands[0].reshape(shape),
            OpType::Pow(power) => operands[0].pow_scalar(*power),
            OpType::Sqrt => operands[0].sqrt(),
            OpType::Log => operands[0].log(),
            OpType::Exp => operands[0].exp(),
            OpType::Max(axes) => match axes {
                Some(ref axes_vec) => operands[0].max_axes(axes_vec),
                None => operands[0].max_scalar(),
            },
            OpType::Min(axes) => match axes {
                Some(ref axes_vec) => operands[0].min_axes(axes_vec),
                None => operands[0].min_scalar(),
            },
            OpType::Slice(ranges) => {
                if ranges.is_empty() {
                    return Err(TrustformersError::tensor_op_error(
                        "No slice ranges provided",
                        "slice",
                    ));
                }
                operands[0].slice_multi(ranges)
            },
            OpType::Concat(axis) => {
                if operands.len() < 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Concat requires at least 2 operands",
                        "evaluate_node",
                    ));
                }

                Tensor::concat(&operands, *axis)
            },
            OpType::Broadcast(target_shape) => operands[0].broadcast_to(target_shape),
            OpType::Greater => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Greater operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].greater(&operands[1])
            },
            OpType::Less => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Less operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].less(&operands[1])
            },
            OpType::Equal => {
                if operands.len() != 2 {
                    return Err(TrustformersError::tensor_op_error(
                        "Equal operation requires exactly 2 operands",
                        "evaluate_node",
                    ));
                }
                operands[0].equal(&operands[1])
            },
            OpType::Where => {
                if operands.len() != 3 {
                    return Err(TrustformersError::tensor_op_error(
                        "Where operation requires exactly 3 operands: condition, x, y",
                        "evaluate_node",
                    ));
                }
                operands[0].where_cond(&operands[1], &operands[2])
            },
        }
    }
812
    /// Renders the sub-expression rooted at `node_id` as an infix /
    /// function-call style string (used by the `Display` impl).
    fn node_to_string(&self, node_id: usize) -> String {
        let node = &self.nodes[&node_id];

        if node.is_leaf {
            format!("Tensor{:?}", node.shape)
        } else {
            // Render operands first (depth-first), then combine them.
            let operand_strs: Vec<String> =
                node.operands.iter().map(|&id| self.node_to_string(id)).collect();

            match &node.op {
                OpType::Add => format!("({} + {})", operand_strs[0], operand_strs[1]),
                OpType::Sub => format!("({} - {})", operand_strs[0], operand_strs[1]),
                OpType::Mul => format!("({} * {})", operand_strs[0], operand_strs[1]),
                OpType::Div => format!("({} / {})", operand_strs[0], operand_strs[1]),
                OpType::MatMul => format!("matmul({}, {})", operand_strs[0], operand_strs[1]),
                OpType::ReLU => format!("relu({})", operand_strs[0]),
                OpType::Sigmoid => format!("sigmoid({})", operand_strs[0]),
                OpType::Tanh => format!("tanh({})", operand_strs[0]),
                OpType::GELU => format!("gelu({})", operand_strs[0]),
                OpType::Softmax(axis) => format!("softmax({}, axis={})", operand_strs[0], axis),
                OpType::Sum(axes) => format!("sum({}, axes={:?})", operand_strs[0], axes),
                OpType::Mean(axes) => format!("mean({}, axes={:?})", operand_strs[0], axes),
                OpType::Reshape(shape) => format!("reshape({}, {:?})", operand_strs[0], shape),
                OpType::Transpose => format!("transpose({})", operand_strs[0]),
                // Remaining op kinds fall back to a Debug rendering.
                _ => format!("{:?}({})", node.op, operand_strs.join(", ")),
            }
        }
    }
841}
842
843impl Default for OptimizationHints {
844 fn default() -> Self {
845 Self {
846 enable_fusion: true,
847 optimize_memory_layout: true,
848 enable_vectorization: true,
849 max_fusion_size: 8,
850 prefer_inplace: false,
851 }
852 }
853}
854
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // Wrapping a tensor yields a one-leaf, zero-op expression.
    #[test]
    fn test_basic_expression_creation() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let expr = TensorExpr::from(&a)?;

        assert_eq!(expr.shape(), vec![2, 3]);
        assert_eq!(expr.dtype(), DType::F32);
        assert_eq!(expr.operation_count(), 0);
        assert_eq!(expr.leaf_count(), 1);

        Ok(())
    }

    // a + b: one op node over two leaves, shape preserved.
    #[test]
    fn test_binary_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[2, 3])?;

        let expr_a = TensorExpr::from(&a)?;
        let expr_b = TensorExpr::from(&b)?;

        let result_expr = expr_a.add(expr_b)?;

        assert_eq!(result_expr.shape(), vec![2, 3]);
        assert_eq!(result_expr.operation_count(), 1);
        assert_eq!(result_expr.leaf_count(), 2);

        Ok(())
    }

    // relu((a + b) * c): three ops stacked over three leaves.
    #[test]
    fn test_chained_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[2, 3])?;
        let c = Tensor::ones(&[2, 3])?;

        let expr = TensorExpr::from(&a)?
            .add(TensorExpr::from(&b)?)?
            .mul(TensorExpr::from(&c)?)?
            .relu()?;

        assert_eq!(expr.shape(), vec![2, 3]);
        assert_eq!(expr.operation_count(), 3);
        assert_eq!(expr.leaf_count(), 3);

        Ok(())
    }

    // [2,3] x [3,4] contracts the inner dim to give [2,4].
    #[test]
    fn test_matrix_multiplication() -> Result<()> {
        let a = Tensor::ones(&[2, 3])?;
        let b = Tensor::ones(&[3, 4])?;

        let expr = TensorExpr::from(&a)?.matmul(TensorExpr::from(&b)?)?;

        assert_eq!(expr.shape(), vec![2, 4]);
        assert_eq!(expr.operation_count(), 1);

        Ok(())
    }

    // sum(None) collapses to a scalar shape; sum over axis 1 drops it.
    #[test]
    fn test_reduction_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 3, 4])?;

        let sum_all = TensorExpr::from(&a)?.sum(None)?;
        assert_eq!(sum_all.shape(), vec![] as Vec<usize>);

        let sum_axis = TensorExpr::from(&a)?.sum(Some(vec![1]))?;
        assert_eq!(sum_axis.shape(), vec![2, 4]);

        Ok(())
    }

    // Reshape preserving the 24-element count.
    #[test]
    fn test_reshape_operation() -> Result<()> {
        let a = Tensor::ones(&[2, 3, 4])?;

        let reshaped = TensorExpr::from(&a)?.reshape(&[6, 4])?;
        assert_eq!(reshaped.shape(), vec![6, 4]);

        Ok(())
    }

    // eval() materializes the graph into a concrete tensor.
    #[test]
    fn test_expression_evaluation() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?;

        let result = expr.eval()?;
        assert_eq!(result.shape(), vec![2, 2]);

        // NOTE(review): `_expected` is built but never compared against
        // `result` — presumably an element-wise comparison was planned.
        let _expected = Tensor::full_with_shape(&[2, 2], 2.0)?;
        Ok(())
    }

    // Display rendering should show infix "+" and "relu(...)".
    #[test]
    fn test_expression_to_string() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?.relu()?;

        let expr_str = expr.to_string();
        assert!(expr_str.contains("+"));
        assert!(expr_str.contains("relu"));

        Ok(())
    }

    // DOT export includes the graph header and the op label.
    #[test]
    fn test_dot_export() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr = TensorExpr::from(&a)?.add(TensorExpr::from(&b)?)?;

        let dot = expr.to_dot();
        assert!(dot.contains("digraph TensorExpr"));
        assert!(dot.contains("Add"));

        Ok(())
    }

    // Default hints: everything on except in-place rewriting.
    #[test]
    fn test_optimization_hints() {
        let hints = OptimizationHints::default();
        assert!(hints.enable_fusion);
        assert!(hints.optimize_memory_layout);
        assert!(hints.enable_vectorization);
        assert_eq!(hints.max_fusion_size, 8);
        assert!(!hints.prefer_inplace);
    }

    // Same-shape element-wise roots are fusion candidates.
    #[test]
    fn test_can_fuse_operations() -> Result<()> {
        let a = Tensor::ones(&[2, 2])?;
        let b = Tensor::ones(&[2, 2])?;

        let expr1 = TensorExpr::from(&a)?.relu()?;
        let expr2 = TensorExpr::from(&b)?.sigmoid()?;

        assert!(expr1.can_fuse_with(&expr2));

        Ok(())
    }
}
1011}