1#[cfg(not(feature = "std"))]
37use alloc::collections::BTreeMap;
38#[cfg(not(feature = "std"))]
39use alloc::{string::String, vec::Vec};
40use core::fmt;
41#[cfg(feature = "std")]
42use std::collections::BTreeMap;
43#[cfg(feature = "std")]
44use std::vec::Vec;
45
/// Opaque, totally ordered identifier of a node in a shape graph.
///
/// Wraps a raw `usize` index; ordering and hashing follow that index.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Wraps the raw index `id` as a node identifier.
    pub fn new(id: usize) -> Self {
        NodeId(id)
    }

    /// Returns the raw index this identifier wraps.
    pub fn id(&self) -> usize {
        let NodeId(raw) = *self;
        raw
    }
}
61
62impl fmt::Display for NodeId {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(f, "Node({})", self.0)
65 }
66}
67
/// An operation recorded in a shape graph, together with the parameters needed
/// to infer its output shape. Operands are referenced through the owning node's
/// dependency list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeOp {
    /// A graph input whose shape is supplied when the node is created.
    Input,
    /// Reshape to `target_shape`; a `usize::MAX` entry marks the single dimension
    /// whose size is inferred from the element count.
    Reshape { target_shape: Vec<usize> },
    /// Transpose where output dimension `i` is input dimension `axes[i]`.
    Transpose { axes: Vec<usize> },
    /// Broadcast to `target_shape`, aligning trailing dimensions; only size-1
    /// source dimensions may be stretched.
    Broadcast { target_shape: Vec<usize> },
    /// Concatenate along `axis`; `other` is the second operand (the first operand
    /// is the node's first dependency).
    Concatenate { axis: usize, other: NodeId },
    /// Stack two identically shaped operands along a new size-2 dimension inserted
    /// at `axis`; `other` is the second operand.
    Stack { axis: usize, other: NodeId },
    /// Remove size-1 dimensions: all of them for `None`; for `Some(axes)`, only the
    /// listed axes that have size 1 (listed axes of other sizes are kept).
    Squeeze { axes: Option<Vec<usize>> },
    /// Insert size-1 dimensions at the listed positions of the output shape.
    Unsqueeze { axes: Vec<usize> },
    /// Per-dimension `(start, end)` slice ranges.
    /// NOTE(review): whether `end` is exclusive is not established here — confirm.
    Slice { ranges: Vec<(usize, usize)> },
    /// Expand to `target_shape`.
    /// NOTE(review): presumably broadcast-like semantics — confirm.
    Expand { target_shape: Vec<usize> },
    /// Merge dimensions `start_dim..=end_dim` (both bounds inclusive) into one.
    Flatten { start_dim: usize, end_dim: usize },
    /// Reorder dimensions according to `dims`.
    /// NOTE(review): presumably equivalent to `Transpose` — confirm.
    Permute { dims: Vec<usize> },
}
96
97impl fmt::Display for ShapeOp {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 match self {
100 ShapeOp::Input => write!(f, "Input"),
101 ShapeOp::Reshape { target_shape } => write!(f, "Reshape({:?})", target_shape),
102 ShapeOp::Transpose { axes } => write!(f, "Transpose({:?})", axes),
103 ShapeOp::Broadcast { target_shape } => write!(f, "Broadcast({:?})", target_shape),
104 ShapeOp::Concatenate { axis, other } => {
105 write!(f, "Concatenate(axis={}, {})", axis, other)
106 }
107 ShapeOp::Stack { axis, other } => write!(f, "Stack(axis={}, {})", axis, other),
108 ShapeOp::Squeeze { axes } => write!(f, "Squeeze({:?})", axes),
109 ShapeOp::Unsqueeze { axes } => write!(f, "Unsqueeze({:?})", axes),
110 ShapeOp::Slice { ranges } => write!(f, "Slice({:?})", ranges),
111 ShapeOp::Expand { target_shape } => write!(f, "Expand({:?})", target_shape),
112 ShapeOp::Flatten { start_dim, end_dim } => {
113 write!(f, "Flatten({}..{})", start_dim, end_dim)
114 }
115 ShapeOp::Permute { dims } => write!(f, "Permute({:?})", dims),
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
122pub struct ShapeNode {
123 id: NodeId,
125 shape: Option<Vec<usize>>,
127 op: ShapeOp,
129 dependencies: Vec<NodeId>,
131 name: Option<String>,
133}
134
135impl ShapeNode {
136 pub fn new(id: NodeId, op: ShapeOp, dependencies: Vec<NodeId>) -> Self {
138 Self {
139 id,
140 shape: None,
141 op,
142 dependencies,
143 name: None,
144 }
145 }
146
147 pub fn id(&self) -> NodeId {
149 self.id
150 }
151
152 pub fn set_shape(&mut self, shape: Vec<usize>) {
154 self.shape = Some(shape);
155 }
156
157 pub fn shape(&self) -> Option<&[usize]> {
159 self.shape.as_deref()
160 }
161
162 pub fn op(&self) -> &ShapeOp {
164 &self.op
165 }
166
167 pub fn dependencies(&self) -> &[NodeId] {
169 &self.dependencies
170 }
171
172 pub fn set_name(&mut self, name: String) {
174 self.name = Some(name);
175 }
176
177 pub fn name(&self) -> Option<&str> {
179 self.name.as_deref()
180 }
181}
182
/// Errors produced while ordering, or inferring shapes in, a shape graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeInferenceError {
    /// The referenced node id is not present in the graph.
    NodeNotFound(NodeId),
    /// A reshape could not be resolved; `reason` explains why.
    InvalidReshape {
        /// Shape of the reshape's input.
        source_shape: Vec<usize>,
        /// Requested shape as given (before any inferred dimension is filled in).
        target_shape: Vec<usize>,
        reason: String,
    },
    /// The transpose axes are not valid for the input shape.
    InvalidTranspose {
        /// Shape of the transpose's input.
        shape: Vec<usize>,
        /// Requested axis order.
        axes: Vec<usize>,
        reason: String,
    },
    /// The source shape cannot be broadcast to the target shape.
    InvalidBroadcast {
        /// Shape being broadcast.
        source_shape: Vec<usize>,
        /// Requested broadcast shape.
        target_shape: Vec<usize>,
        reason: String,
    },
    /// Two shapes cannot be combined along `axis`. Also used to report invalid
    /// stack operations, which have no dedicated variant.
    InvalidConcatenate {
        /// Shape of the first operand.
        shape1: Vec<usize>,
        /// Shape of the second operand.
        shape2: Vec<usize>,
        axis: usize,
        reason: String,
    },
    /// A slice is invalid for the given shape.
    /// NOTE(review): no code in this view constructs this variant.
    InvalidSlice {
        shape: Vec<usize>,
        ranges: Vec<(usize, usize)>,
        reason: String,
    },
    /// A dependency cycle was detected at the given node.
    CyclicDependency(NodeId),
    /// The shape of a node could not be determined (e.g. an input without a shape,
    /// or an operation without an inference rule).
    UnknownShape(NodeId),
}
224
225impl fmt::Display for ShapeInferenceError {
226 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227 match self {
228 ShapeInferenceError::NodeNotFound(id) => write!(f, "Node not found: {}", id),
229 ShapeInferenceError::InvalidReshape {
230 source_shape,
231 target_shape,
232 reason,
233 } => write!(
234 f,
235 "Invalid reshape from {:?} to {:?}: {}",
236 source_shape, target_shape, reason
237 ),
238 ShapeInferenceError::InvalidTranspose {
239 shape,
240 axes,
241 reason,
242 } => {
243 write!(
244 f,
245 "Invalid transpose of {:?} with axes {:?}: {}",
246 shape, axes, reason
247 )
248 }
249 ShapeInferenceError::InvalidBroadcast {
250 source_shape,
251 target_shape,
252 reason,
253 } => write!(
254 f,
255 "Invalid broadcast from {:?} to {:?}: {}",
256 source_shape, target_shape, reason
257 ),
258 ShapeInferenceError::InvalidConcatenate {
259 shape1,
260 shape2,
261 axis,
262 reason,
263 } => write!(
264 f,
265 "Invalid concatenate of {:?} and {:?} along axis {}: {}",
266 shape1, shape2, axis, reason
267 ),
268 ShapeInferenceError::InvalidSlice {
269 shape,
270 ranges,
271 reason,
272 } => {
273 write!(
274 f,
275 "Invalid slice of {:?} with ranges {:?}: {}",
276 shape, ranges, reason
277 )
278 }
279 ShapeInferenceError::CyclicDependency(id) => {
280 write!(f, "Cyclic dependency detected at {}", id)
281 }
282 ShapeInferenceError::UnknownShape(id) => write!(f, "Unknown shape for {}", id),
283 }
284 }
285}
286
/// `std::error::Error` integration; only compiled when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for ShapeInferenceError {}

/// Result type returned by shape-graph operations that can fail.
pub type InferenceResult<T> = Result<T, ShapeInferenceError>;
292
/// A graph of shape operations whose output shapes are inferred lazily and memoized.
#[derive(Debug, Clone)]
pub struct ShapeGraph {
    /// All nodes keyed by id; a `BTreeMap`, so iteration follows id order.
    nodes: BTreeMap<NodeId, ShapeNode>,
    /// Id that the next created node will receive.
    next_id: usize,
    /// Memoized shapes: input shapes are stored at creation, other shapes once inferred.
    cache: BTreeMap<NodeId, Vec<usize>>,
}
303
304impl ShapeGraph {
305 pub fn new() -> Self {
307 Self {
308 nodes: BTreeMap::new(),
309 next_id: 0,
310 cache: BTreeMap::new(),
311 }
312 }
313
314 fn alloc_id(&mut self) -> NodeId {
316 let id = NodeId(self.next_id);
317 self.next_id += 1;
318 id
319 }
320
321 pub fn add_input(&mut self, shape: Vec<usize>) -> NodeId {
323 let id = self.alloc_id();
324 let mut node = ShapeNode::new(id, ShapeOp::Input, Vec::new());
325 node.set_shape(shape.clone());
326 self.nodes.insert(id, node);
327 self.cache.insert(id, shape);
328 id
329 }
330
331 pub fn reshape(&mut self, input: NodeId, target_shape: Vec<usize>) -> NodeId {
333 let id = self.alloc_id();
334 let node = ShapeNode::new(
335 id,
336 ShapeOp::Reshape {
337 target_shape: target_shape.clone(),
338 },
339 vec![input],
340 );
341 self.nodes.insert(id, node);
342 id
343 }
344
345 pub fn transpose(&mut self, input: NodeId, axes: Vec<usize>) -> NodeId {
347 let id = self.alloc_id();
348 let node = ShapeNode::new(id, ShapeOp::Transpose { axes }, vec![input]);
349 self.nodes.insert(id, node);
350 id
351 }
352
353 pub fn broadcast(&mut self, input: NodeId, target_shape: Vec<usize>) -> NodeId {
355 let id = self.alloc_id();
356 let node = ShapeNode::new(
357 id,
358 ShapeOp::Broadcast {
359 target_shape: target_shape.clone(),
360 },
361 vec![input],
362 );
363 self.nodes.insert(id, node);
364 id
365 }
366
367 pub fn concatenate(&mut self, input1: NodeId, input2: NodeId, axis: usize) -> NodeId {
369 let id = self.alloc_id();
370 let node = ShapeNode::new(
371 id,
372 ShapeOp::Concatenate {
373 axis,
374 other: input2,
375 },
376 vec![input1, input2],
377 );
378 self.nodes.insert(id, node);
379 id
380 }
381
382 pub fn stack(&mut self, input1: NodeId, input2: NodeId, axis: usize) -> NodeId {
384 let id = self.alloc_id();
385 let node = ShapeNode::new(
386 id,
387 ShapeOp::Stack {
388 axis,
389 other: input2,
390 },
391 vec![input1, input2],
392 );
393 self.nodes.insert(id, node);
394 id
395 }
396
397 pub fn squeeze(&mut self, input: NodeId, axes: Option<Vec<usize>>) -> NodeId {
399 let id = self.alloc_id();
400 let node = ShapeNode::new(id, ShapeOp::Squeeze { axes }, vec![input]);
401 self.nodes.insert(id, node);
402 id
403 }
404
405 pub fn unsqueeze(&mut self, input: NodeId, axes: Vec<usize>) -> NodeId {
407 let id = self.alloc_id();
408 let node = ShapeNode::new(id, ShapeOp::Unsqueeze { axes }, vec![input]);
409 self.nodes.insert(id, node);
410 id
411 }
412
413 pub fn flatten(&mut self, input: NodeId, start_dim: usize, end_dim: usize) -> NodeId {
415 let id = self.alloc_id();
416 let node = ShapeNode::new(id, ShapeOp::Flatten { start_dim, end_dim }, vec![input]);
417 self.nodes.insert(id, node);
418 id
419 }
420
421 pub fn get_node(&self, id: NodeId) -> Option<&ShapeNode> {
423 self.nodes.get(&id)
424 }
425
426 pub fn infer_shape(&mut self, node_id: NodeId) -> InferenceResult<Vec<usize>> {
428 if let Some(cached) = self.cache.get(&node_id) {
430 return Ok(cached.clone());
431 }
432
433 let node = self
435 .nodes
436 .get(&node_id)
437 .ok_or(ShapeInferenceError::NodeNotFound(node_id))?
438 .clone();
439
440 if let Some(shape) = node.shape() {
442 let result = shape.to_vec();
443 self.cache.insert(node_id, result.clone());
444 return Ok(result);
445 }
446
447 let inferred_shape = match &node.op {
449 ShapeOp::Input => {
450 return Err(ShapeInferenceError::UnknownShape(node_id));
451 }
452 ShapeOp::Reshape { target_shape } => {
453 let input_id = node.dependencies[0];
454 let input_shape = self.infer_shape(input_id)?;
455 Self::infer_reshape(&input_shape, target_shape)?
456 }
457 ShapeOp::Transpose { axes } => {
458 let input_id = node.dependencies[0];
459 let input_shape = self.infer_shape(input_id)?;
460 Self::infer_transpose(&input_shape, axes)?
461 }
462 ShapeOp::Broadcast { target_shape } => {
463 let input_id = node.dependencies[0];
464 let input_shape = self.infer_shape(input_id)?;
465 Self::infer_broadcast(&input_shape, target_shape)?
466 }
467 ShapeOp::Concatenate { axis, .. } => {
468 let input1_id = node.dependencies[0];
469 let input2_id = node.dependencies[1];
470 let shape1 = self.infer_shape(input1_id)?;
471 let shape2 = self.infer_shape(input2_id)?;
472 Self::infer_concatenate(&shape1, &shape2, *axis)?
473 }
474 ShapeOp::Stack { axis, .. } => {
475 let input1_id = node.dependencies[0];
476 let input2_id = node.dependencies[1];
477 let shape1 = self.infer_shape(input1_id)?;
478 let shape2 = self.infer_shape(input2_id)?;
479 Self::infer_stack(&shape1, &shape2, *axis)?
480 }
481 ShapeOp::Squeeze { axes } => {
482 let input_id = node.dependencies[0];
483 let input_shape = self.infer_shape(input_id)?;
484 Self::infer_squeeze(&input_shape, axes.as_ref())?
485 }
486 ShapeOp::Unsqueeze { axes } => {
487 let input_id = node.dependencies[0];
488 let input_shape = self.infer_shape(input_id)?;
489 Self::infer_unsqueeze(&input_shape, axes)?
490 }
491 ShapeOp::Flatten { start_dim, end_dim } => {
492 let input_id = node.dependencies[0];
493 let input_shape = self.infer_shape(input_id)?;
494 Self::infer_flatten(&input_shape, *start_dim, *end_dim)?
495 }
496 _ => {
497 return Err(ShapeInferenceError::UnknownShape(node_id));
498 }
499 };
500
501 self.cache.insert(node_id, inferred_shape.clone());
503
504 if let Some(node) = self.nodes.get_mut(&node_id) {
506 node.set_shape(inferred_shape.clone());
507 }
508
509 Ok(inferred_shape)
510 }
511
512 fn infer_reshape(input_shape: &[usize], target_shape: &[usize]) -> InferenceResult<Vec<usize>> {
514 let input_numel: usize = input_shape.iter().product();
515 let mut output_shape = target_shape.to_vec();
516
517 let neg_count = target_shape.iter().filter(|&&x| x == usize::MAX).count();
519 if neg_count > 1 {
520 return Err(ShapeInferenceError::InvalidReshape {
521 source_shape: input_shape.to_vec(),
522 target_shape: target_shape.to_vec(),
523 reason: "At most one dimension can be inferred".to_string(),
524 });
525 }
526
527 if neg_count == 1 {
528 let known_product: usize = target_shape.iter().filter(|&&x| x != usize::MAX).product();
529 if known_product == 0 || input_numel % known_product != 0 {
530 return Err(ShapeInferenceError::InvalidReshape {
531 source_shape: input_shape.to_vec(),
532 target_shape: target_shape.to_vec(),
533 reason: "Cannot infer dimension size".to_string(),
534 });
535 }
536 let inferred = input_numel / known_product;
537 for dim in &mut output_shape {
538 if *dim == usize::MAX {
539 *dim = inferred;
540 }
541 }
542 }
543
544 let output_numel: usize = output_shape.iter().product();
545 if input_numel != output_numel {
546 return Err(ShapeInferenceError::InvalidReshape {
547 source_shape: input_shape.to_vec(),
548 target_shape: target_shape.to_vec(),
549 reason: format!(
550 "Element count mismatch: {} vs {}",
551 input_numel, output_numel
552 ),
553 });
554 }
555
556 Ok(output_shape)
557 }
558
559 fn infer_transpose(input_shape: &[usize], axes: &[usize]) -> InferenceResult<Vec<usize>> {
561 if axes.len() != input_shape.len() {
562 return Err(ShapeInferenceError::InvalidTranspose {
563 shape: input_shape.to_vec(),
564 axes: axes.to_vec(),
565 reason: "Axes count must match shape rank".to_string(),
566 });
567 }
568
569 let mut output_shape = vec![0; input_shape.len()];
570 for (i, &axis) in axes.iter().enumerate() {
571 if axis >= input_shape.len() {
572 return Err(ShapeInferenceError::InvalidTranspose {
573 shape: input_shape.to_vec(),
574 axes: axes.to_vec(),
575 reason: format!("Axis {} out of bounds", axis),
576 });
577 }
578 output_shape[i] = input_shape[axis];
579 }
580
581 Ok(output_shape)
582 }
583
584 fn infer_broadcast(
586 input_shape: &[usize],
587 target_shape: &[usize],
588 ) -> InferenceResult<Vec<usize>> {
589 if input_shape.len() > target_shape.len() {
590 return Err(ShapeInferenceError::InvalidBroadcast {
591 source_shape: input_shape.to_vec(),
592 target_shape: target_shape.to_vec(),
593 reason: "Source rank exceeds target rank".to_string(),
594 });
595 }
596
597 let offset = target_shape.len() - input_shape.len();
598 for (i, &dim) in input_shape.iter().enumerate() {
599 let target_dim = target_shape[offset + i];
600 if dim != 1 && dim != target_dim {
601 return Err(ShapeInferenceError::InvalidBroadcast {
602 source_shape: input_shape.to_vec(),
603 target_shape: target_shape.to_vec(),
604 reason: format!(
605 "Dimension {} cannot be broadcast: {} to {}",
606 i, dim, target_dim
607 ),
608 });
609 }
610 }
611
612 Ok(target_shape.to_vec())
613 }
614
615 fn infer_concatenate(
617 shape1: &[usize],
618 shape2: &[usize],
619 axis: usize,
620 ) -> InferenceResult<Vec<usize>> {
621 if shape1.len() != shape2.len() {
622 return Err(ShapeInferenceError::InvalidConcatenate {
623 shape1: shape1.to_vec(),
624 shape2: shape2.to_vec(),
625 axis,
626 reason: "Shapes must have same rank".to_string(),
627 });
628 }
629
630 if axis >= shape1.len() {
631 return Err(ShapeInferenceError::InvalidConcatenate {
632 shape1: shape1.to_vec(),
633 shape2: shape2.to_vec(),
634 axis,
635 reason: format!("Axis {} out of bounds", axis),
636 });
637 }
638
639 for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
640 if i != axis && dim1 != dim2 {
641 return Err(ShapeInferenceError::InvalidConcatenate {
642 shape1: shape1.to_vec(),
643 shape2: shape2.to_vec(),
644 axis,
645 reason: format!("Dimension {} mismatch: {} vs {}", i, dim1, dim2),
646 });
647 }
648 }
649
650 let mut output = shape1.to_vec();
651 output[axis] += shape2[axis];
652 Ok(output)
653 }
654
655 fn infer_stack(shape1: &[usize], shape2: &[usize], axis: usize) -> InferenceResult<Vec<usize>> {
657 if shape1 != shape2 {
658 return Err(ShapeInferenceError::InvalidConcatenate {
659 shape1: shape1.to_vec(),
660 shape2: shape2.to_vec(),
661 axis,
662 reason: "Shapes must be identical for stack".to_string(),
663 });
664 }
665
666 if axis > shape1.len() {
667 return Err(ShapeInferenceError::InvalidConcatenate {
668 shape1: shape1.to_vec(),
669 shape2: shape2.to_vec(),
670 axis,
671 reason: format!("Axis {} out of bounds", axis),
672 });
673 }
674
675 let mut output = Vec::with_capacity(shape1.len() + 1);
676 output.extend_from_slice(&shape1[..axis]);
677 output.push(2);
678 output.extend_from_slice(&shape1[axis..]);
679 Ok(output)
680 }
681
682 fn infer_squeeze(
684 input_shape: &[usize],
685 axes: Option<&Vec<usize>>,
686 ) -> InferenceResult<Vec<usize>> {
687 let output = if let Some(axes) = axes {
688 input_shape
689 .iter()
690 .enumerate()
691 .filter(|(i, &dim)| !axes.contains(i) || dim != 1)
692 .map(|(_, &dim)| dim)
693 .collect()
694 } else {
695 input_shape
696 .iter()
697 .filter(|&&dim| dim != 1)
698 .copied()
699 .collect()
700 };
701 Ok(output)
702 }
703
704 fn infer_unsqueeze(input_shape: &[usize], axes: &[usize]) -> InferenceResult<Vec<usize>> {
706 let mut sorted_axes = axes.to_vec();
707 sorted_axes.sort_unstable();
708
709 let final_rank = input_shape.len() + axes.len();
711 let mut output = Vec::with_capacity(final_rank);
712 let mut input_idx = 0;
713 let mut axes_idx = 0;
714
715 for i in 0..final_rank {
716 if axes_idx < sorted_axes.len() && sorted_axes[axes_idx] == i {
717 output.push(1);
719 axes_idx += 1;
720 } else {
721 if input_idx >= input_shape.len() {
723 return Err(ShapeInferenceError::UnknownShape(NodeId(0)));
724 }
725 output.push(input_shape[input_idx]);
726 input_idx += 1;
727 }
728 }
729
730 Ok(output)
731 }
732
733 fn infer_flatten(
735 input_shape: &[usize],
736 start_dim: usize,
737 end_dim: usize,
738 ) -> InferenceResult<Vec<usize>> {
739 if start_dim >= input_shape.len() || end_dim >= input_shape.len() || start_dim > end_dim {
740 return Err(ShapeInferenceError::UnknownShape(NodeId(0)));
741 }
742
743 let flattened_size: usize = input_shape[start_dim..=end_dim].iter().product();
744 let mut output = Vec::with_capacity(input_shape.len() - (end_dim - start_dim));
745 output.extend_from_slice(&input_shape[..start_dim]);
746 output.push(flattened_size);
747 if end_dim + 1 < input_shape.len() {
748 output.extend_from_slice(&input_shape[end_dim + 1..]);
749 }
750
751 Ok(output)
752 }
753
754 pub fn clear_cache(&mut self) {
756 self.cache.clear();
757 }
758
759 pub fn topological_sort(&self) -> InferenceResult<Vec<NodeId>> {
761 let mut result = Vec::new();
762 let mut visited = BTreeMap::new();
763 let mut temp_mark = BTreeMap::new();
764
765 for &node_id in self.nodes.keys() {
766 if !visited.contains_key(&node_id) {
767 self.visit_node(node_id, &mut visited, &mut temp_mark, &mut result)?;
768 }
769 }
770
771 Ok(result)
772 }
773
774 fn visit_node(
775 &self,
776 node_id: NodeId,
777 visited: &mut BTreeMap<NodeId, bool>,
778 temp_mark: &mut BTreeMap<NodeId, bool>,
779 result: &mut Vec<NodeId>,
780 ) -> InferenceResult<()> {
781 if visited.get(&node_id) == Some(&true) {
782 return Ok(());
783 }
784
785 if temp_mark.get(&node_id) == Some(&true) {
786 return Err(ShapeInferenceError::CyclicDependency(node_id));
787 }
788
789 temp_mark.insert(node_id, true);
790
791 if let Some(node) = self.nodes.get(&node_id) {
792 for &dep_id in &node.dependencies {
793 self.visit_node(dep_id, visited, temp_mark, result)?;
794 }
795 }
796
797 temp_mark.insert(node_id, false);
798 visited.insert(node_id, true);
799 result.push(node_id);
800
801 Ok(())
802 }
803
804 pub fn node_count(&self) -> usize {
806 self.nodes.len()
807 }
808}
809
810impl Default for ShapeGraph {
811 fn default() -> Self {
812 Self::new()
813 }
814}
815
#[cfg(test)]
mod tests {
    use super::*;

    extern crate std;
    use std::vec;

    #[test]
    fn test_input_node() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);

        let shape = graph
            .infer_shape(input)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_reshape() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 12]);

        let shape = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 12]);
    }

    #[test]
    fn test_reshape_with_inferred_dimension() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        // `usize::MAX` marks the dimension to infer from the element count.
        let reshaped = graph.reshape(input, vec![usize::MAX, 4]);

        let shape = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![6, 4]);
    }

    #[test]
    fn test_reshape_rejects_two_inferred_dimensions() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![usize::MAX, usize::MAX]);
        assert!(graph.infer_shape(reshaped).is_err());
    }

    #[test]
    fn test_transpose() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let transposed = graph.transpose(input, vec![2, 0, 1]);

        let shape = graph
            .infer_shape(transposed)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![4, 2, 3]);
    }

    #[test]
    fn test_broadcast() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![1, 3, 1]);
        let broadcasted = graph.broadcast(input, vec![2, 3, 4]);

        let shape = graph
            .infer_shape(broadcasted)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_concatenate() {
        let mut graph = ShapeGraph::new();
        let input1 = graph.add_input(vec![2, 3, 4]);
        let input2 = graph.add_input(vec![2, 5, 4]);
        let concatenated = graph.concatenate(input1, input2, 1);

        let shape = graph
            .infer_shape(concatenated)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 8, 4]);
    }

    #[test]
    fn test_invalid_concatenate_rank_mismatch() {
        let mut graph = ShapeGraph::new();
        let input1 = graph.add_input(vec![2, 3]);
        let input2 = graph.add_input(vec![2, 3, 4]);
        let concatenated = graph.concatenate(input1, input2, 0);
        assert!(graph.infer_shape(concatenated).is_err());
    }

    #[test]
    fn test_stack() {
        let mut graph = ShapeGraph::new();
        let input1 = graph.add_input(vec![2, 3, 4]);
        let input2 = graph.add_input(vec![2, 3, 4]);
        let stacked = graph.stack(input1, input2, 1);

        let shape = graph
            .infer_shape(stacked)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 2, 3, 4]);
    }

    #[test]
    fn test_invalid_stack_axis() {
        let mut graph = ShapeGraph::new();
        let input1 = graph.add_input(vec![2, 3]);
        let input2 = graph.add_input(vec![2, 3]);
        // Valid stack axes for rank 2 are 0..=2.
        let stacked = graph.stack(input1, input2, 3);
        assert!(graph.infer_shape(stacked).is_err());
    }

    #[test]
    fn test_squeeze() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 1, 3, 1, 4]);
        let squeezed = graph.squeeze(input, None);

        let shape = graph
            .infer_shape(squeezed)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_squeeze_selected_axes() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 1, 3, 1]);
        // Only axis 1 is squeezed; the trailing size-1 axis is kept.
        let squeezed = graph.squeeze(input, Some(vec![1]));

        let shape = graph
            .infer_shape(squeezed)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 3, 1]);
    }

    #[test]
    fn test_unsqueeze() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let unsqueezed = graph.unsqueeze(input, vec![1, 3]);

        let shape = graph
            .infer_shape(unsqueezed)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 1, 3, 1, 4]);
    }

    #[test]
    fn test_invalid_unsqueeze_repeated_axis() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3]);
        let unsqueezed = graph.unsqueeze(input, vec![1, 1]);
        assert!(graph.infer_shape(unsqueezed).is_err());
    }

    #[test]
    fn test_flatten() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4, 5]);
        let flattened = graph.flatten(input, 1, 2);

        let shape = graph
            .infer_shape(flattened)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 12, 5]);
    }

    #[test]
    fn test_invalid_flatten() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reversed = graph.flatten(input, 2, 1);
        let out_of_range = graph.flatten(input, 0, 3);
        assert!(graph.infer_shape(reversed).is_err());
        assert!(graph.infer_shape(out_of_range).is_err());
    }

    #[test]
    fn test_complex_graph() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 12]);
        let transposed = graph.transpose(reshaped, vec![1, 0]);
        let unsqueezed = graph.unsqueeze(transposed, vec![1]);

        let shape = graph
            .infer_shape(unsqueezed)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![12, 1, 2]);
    }

    #[test]
    fn test_invalid_reshape() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 13]);
        assert!(graph.infer_shape(reshaped).is_err());
    }

    #[test]
    fn test_invalid_transpose() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let transposed = graph.transpose(input, vec![0, 1]);
        assert!(graph.infer_shape(transposed).is_err());
    }

    #[test]
    fn test_invalid_broadcast() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let broadcasted = graph.broadcast(input, vec![2, 5, 4]);
        assert!(graph.infer_shape(broadcasted).is_err());
    }

    #[test]
    fn test_missing_node() {
        let mut graph = ShapeGraph::new();
        let missing = NodeId::new(42);
        assert_eq!(
            graph.infer_shape(missing),
            Err(ShapeInferenceError::NodeNotFound(missing))
        );
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 12]);
        let transposed = graph.transpose(reshaped, vec![1, 0]);

        let sorted = graph
            .topological_sort()
            .expect("topological_sort should succeed");
        assert_eq!(sorted.len(), 3);

        let input_pos = sorted
            .iter()
            .position(|&id| id == input)
            .expect("input should be in sorted list");
        let reshape_pos = sorted
            .iter()
            .position(|&id| id == reshaped)
            .expect("reshaped should be in sorted list");
        let transpose_pos = sorted
            .iter()
            .position(|&id| id == transposed)
            .expect("transposed should be in sorted list");

        assert!(input_pos < reshape_pos);
        assert!(reshape_pos < transpose_pos);
    }

    #[test]
    fn test_cache() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 12]);

        let shape1 = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");

        let shape2 = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");

        assert_eq!(shape1, shape2);
        assert_eq!(shape1, vec![2, 12]);
    }

    #[test]
    fn test_clear_cache() {
        let mut graph = ShapeGraph::new();
        let input = graph.add_input(vec![2, 3, 4]);
        let reshaped = graph.reshape(input, vec![2, 12]);

        let _ = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");

        graph.clear_cache();

        let shape = graph
            .infer_shape(reshaped)
            .expect("infer_shape should succeed");
        assert_eq!(shape, vec![2, 12]);
    }

    #[test]
    fn test_node_count() {
        let mut graph = ShapeGraph::new();
        assert_eq!(graph.node_count(), 0);

        let input = graph.add_input(vec![2, 3, 4]);
        assert_eq!(graph.node_count(), 1);

        let _reshaped = graph.reshape(input, vec![2, 12]);
        assert_eq!(graph.node_count(), 2);
    }
}
1055}