// torsh_core/shape_graph.rs
1//! Graph-Based Shape Inference for Optimization
2//!
3//! This module provides graph-based shape inference capabilities that enable
4//! compile-time shape analysis and optimization. By building a computation graph
5//! of shape operations, we can:
6//!
7//! - Infer output shapes from input shapes
8//! - Detect shape errors at graph construction time
9//! - Optimize shape calculations by caching and reuse
10//! - Enable shape fusion and elimination of redundant operations
11//! - Provide better error messages with full operation context
12//!
13//! # Architecture
14//!
15//! The shape graph consists of nodes representing shape values and edges
16//! representing operations that transform shapes. Each node stores:
17//! - The shape value (if known)
18//! - Dependencies on other nodes
19//! - Metadata about the operation that produced it
20//!
21//! # Example
22//!
23//! ```ignore
24//! use torsh_core::shape_graph::*;
25//!
26//! let mut graph = ShapeGraph::new();
27//! let input = graph.add_input(vec![2, 3, 4]);
28//! let reshaped = graph.reshape(input, vec![2, 12]);
29//! let transposed = graph.transpose(reshaped, vec![1, 0]);
30//!
31//! // Infer the final shape
32//! let output_shape = graph.infer_shape(transposed).unwrap();
33//! assert_eq!(output_shape, vec![12, 2]);
34//! ```
35
#[cfg(not(feature = "std"))]
use alloc::collections::BTreeMap;
#[cfg(not(feature = "std"))]
use alloc::{
    format,
    string::{String, ToString},
    vec,
    vec::Vec,
};
use core::fmt;
#[cfg(feature = "std")]
use std::collections::BTreeMap;
#[cfg(feature = "std")]
use std::vec::Vec;
45
/// Node ID in the shape graph
///
/// A lightweight, copyable handle identifying a node in a `ShapeGraph`.
/// IDs are allocated sequentially by the graph that owns the node and are
/// only meaningful within that graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeId(usize);

impl NodeId {
    /// Create a new node ID from a raw index.
    ///
    /// `const` so IDs can be formed in const contexts; backward compatible
    /// with the previous non-const version.
    pub const fn new(id: usize) -> Self {
        Self(id)
    }

    /// Get the underlying ID.
    pub const fn id(&self) -> usize {
        self.0
    }
}

impl fmt::Display for NodeId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Node({})", self.0)
    }
}
67
/// Shape operation type
///
/// Describes how a node's output shape derives from its dependencies.
/// Unary variants read a single input node; `Concatenate` and `Stack` are
/// binary and record their second operand (`other`), which also appears in
/// the node's dependency list.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeOp {
    /// Input shape (leaf node; its shape is supplied, never inferred)
    Input,
    /// Reshape operation
    Reshape { target_shape: Vec<usize> },
    /// Transpose operation
    Transpose { axes: Vec<usize> },
    /// Broadcast operation
    Broadcast { target_shape: Vec<usize> },
    /// Concatenate operation along axis
    Concatenate { axis: usize, other: NodeId },
    /// Stack operation along new axis
    Stack { axis: usize, other: NodeId },
    /// Squeeze operation (remove dimensions of size 1)
    Squeeze { axes: Option<Vec<usize>> },
    /// Unsqueeze operation (add dimensions of size 1)
    Unsqueeze { axes: Vec<usize> },
    /// Slice operation (`(start, end)` index ranges, one per dimension)
    Slice { ranges: Vec<(usize, usize)> },
    /// Expand operation (broadcast without materialization)
    Expand { target_shape: Vec<usize> },
    /// Flatten operation (collapse dims `start_dim..=end_dim` into one)
    Flatten { start_dim: usize, end_dim: usize },
    /// Permute operation (generalized transpose)
    Permute { dims: Vec<usize> },
}
96
97impl fmt::Display for ShapeOp {
98    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99        match self {
100            ShapeOp::Input => write!(f, "Input"),
101            ShapeOp::Reshape { target_shape } => write!(f, "Reshape({:?})", target_shape),
102            ShapeOp::Transpose { axes } => write!(f, "Transpose({:?})", axes),
103            ShapeOp::Broadcast { target_shape } => write!(f, "Broadcast({:?})", target_shape),
104            ShapeOp::Concatenate { axis, other } => {
105                write!(f, "Concatenate(axis={}, {})", axis, other)
106            }
107            ShapeOp::Stack { axis, other } => write!(f, "Stack(axis={}, {})", axis, other),
108            ShapeOp::Squeeze { axes } => write!(f, "Squeeze({:?})", axes),
109            ShapeOp::Unsqueeze { axes } => write!(f, "Unsqueeze({:?})", axes),
110            ShapeOp::Slice { ranges } => write!(f, "Slice({:?})", ranges),
111            ShapeOp::Expand { target_shape } => write!(f, "Expand({:?})", target_shape),
112            ShapeOp::Flatten { start_dim, end_dim } => {
113                write!(f, "Flatten({}..{})", start_dim, end_dim)
114            }
115            ShapeOp::Permute { dims } => write!(f, "Permute({:?})", dims),
116        }
117    }
118}
119
/// Shape graph node
///
/// Couples the operation that produces a value with its (lazily computed)
/// shape, the IDs of the nodes it reads, and an optional debug name.
#[derive(Debug, Clone)]
pub struct ShapeNode {
    /// Node ID
    id: NodeId,
    /// Known shape (if computed; `None` until set or inferred)
    shape: Option<Vec<usize>>,
    /// Operation that produces this node
    op: ShapeOp,
    /// Dependencies (input nodes, in operand order)
    dependencies: Vec<NodeId>,
    /// Metadata for debugging (optional human-readable name)
    name: Option<String>,
}
134
135impl ShapeNode {
136    /// Create a new shape node
137    pub fn new(id: NodeId, op: ShapeOp, dependencies: Vec<NodeId>) -> Self {
138        Self {
139            id,
140            shape: None,
141            op,
142            dependencies,
143            name: None,
144        }
145    }
146
147    /// Get the node ID
148    pub fn id(&self) -> NodeId {
149        self.id
150    }
151
152    /// Set the shape for this node
153    pub fn set_shape(&mut self, shape: Vec<usize>) {
154        self.shape = Some(shape);
155    }
156
157    /// Get the shape if known
158    pub fn shape(&self) -> Option<&[usize]> {
159        self.shape.as_deref()
160    }
161
162    /// Get the operation
163    pub fn op(&self) -> &ShapeOp {
164        &self.op
165    }
166
167    /// Get the dependencies
168    pub fn dependencies(&self) -> &[NodeId] {
169        &self.dependencies
170    }
171
172    /// Set a name for debugging
173    pub fn set_name(&mut self, name: String) {
174        self.name = Some(name);
175    }
176
177    /// Get the name if set
178    pub fn name(&self) -> Option<&str> {
179        self.name.as_deref()
180    }
181}
182
/// Shape inference error
///
/// Each variant carries the offending shapes/parameters so callers can
/// produce precise diagnostics; the `Display` impl renders them in full.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShapeInferenceError {
    /// Node not found
    NodeNotFound(NodeId),
    /// Invalid reshape target
    InvalidReshape {
        source_shape: Vec<usize>,
        target_shape: Vec<usize>,
        reason: String,
    },
    /// Invalid transpose axes
    InvalidTranspose {
        shape: Vec<usize>,
        axes: Vec<usize>,
        reason: String,
    },
    /// Invalid broadcast
    InvalidBroadcast {
        source_shape: Vec<usize>,
        target_shape: Vec<usize>,
        reason: String,
    },
    /// Invalid concatenation (also reused by stack inference)
    InvalidConcatenate {
        shape1: Vec<usize>,
        shape2: Vec<usize>,
        axis: usize,
        reason: String,
    },
    /// Invalid slice
    InvalidSlice {
        shape: Vec<usize>,
        ranges: Vec<(usize, usize)>,
        reason: String,
    },
    /// Cyclic dependency detected
    CyclicDependency(NodeId),
    /// Unknown shape (cannot infer)
    UnknownShape(NodeId),
}
224
225impl fmt::Display for ShapeInferenceError {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        match self {
228            ShapeInferenceError::NodeNotFound(id) => write!(f, "Node not found: {}", id),
229            ShapeInferenceError::InvalidReshape {
230                source_shape,
231                target_shape,
232                reason,
233            } => write!(
234                f,
235                "Invalid reshape from {:?} to {:?}: {}",
236                source_shape, target_shape, reason
237            ),
238            ShapeInferenceError::InvalidTranspose {
239                shape,
240                axes,
241                reason,
242            } => {
243                write!(
244                    f,
245                    "Invalid transpose of {:?} with axes {:?}: {}",
246                    shape, axes, reason
247                )
248            }
249            ShapeInferenceError::InvalidBroadcast {
250                source_shape,
251                target_shape,
252                reason,
253            } => write!(
254                f,
255                "Invalid broadcast from {:?} to {:?}: {}",
256                source_shape, target_shape, reason
257            ),
258            ShapeInferenceError::InvalidConcatenate {
259                shape1,
260                shape2,
261                axis,
262                reason,
263            } => write!(
264                f,
265                "Invalid concatenate of {:?} and {:?} along axis {}: {}",
266                shape1, shape2, axis, reason
267            ),
268            ShapeInferenceError::InvalidSlice {
269                shape,
270                ranges,
271                reason,
272            } => {
273                write!(
274                    f,
275                    "Invalid slice of {:?} with ranges {:?}: {}",
276                    shape, ranges, reason
277                )
278            }
279            ShapeInferenceError::CyclicDependency(id) => {
280                write!(f, "Cyclic dependency detected at {}", id)
281            }
282            ShapeInferenceError::UnknownShape(id) => write!(f, "Unknown shape for {}", id),
283        }
284    }
285}
286
// `std::error::Error` is only available with the `std` feature enabled;
// no_std builds still get `Debug`/`Display` on the error type.
#[cfg(feature = "std")]
impl std::error::Error for ShapeInferenceError {}

/// Result type for shape inference
pub type InferenceResult<T> = Result<T, ShapeInferenceError>;
292
/// Shape graph for inference and optimization
///
/// Nodes are created through the builder methods and no method in this
/// module removes them, so `next_id` grows monotonically. `cache` memoizes
/// `infer_shape` results keyed by node ID.
#[derive(Debug, Clone)]
pub struct ShapeGraph {
    /// All nodes in the graph
    nodes: BTreeMap<NodeId, ShapeNode>,
    /// Next node ID (sequential allocation)
    next_id: usize,
    /// Cached inference results
    cache: BTreeMap<NodeId, Vec<usize>>,
}
303
304impl ShapeGraph {
305    /// Create a new empty shape graph
306    pub fn new() -> Self {
307        Self {
308            nodes: BTreeMap::new(),
309            next_id: 0,
310            cache: BTreeMap::new(),
311        }
312    }
313
314    /// Allocate a new node ID
315    fn alloc_id(&mut self) -> NodeId {
316        let id = NodeId(self.next_id);
317        self.next_id += 1;
318        id
319    }
320
321    /// Add an input node with a known shape
322    pub fn add_input(&mut self, shape: Vec<usize>) -> NodeId {
323        let id = self.alloc_id();
324        let mut node = ShapeNode::new(id, ShapeOp::Input, Vec::new());
325        node.set_shape(shape.clone());
326        self.nodes.insert(id, node);
327        self.cache.insert(id, shape);
328        id
329    }
330
331    /// Add a reshape operation
332    pub fn reshape(&mut self, input: NodeId, target_shape: Vec<usize>) -> NodeId {
333        let id = self.alloc_id();
334        let node = ShapeNode::new(
335            id,
336            ShapeOp::Reshape {
337                target_shape: target_shape.clone(),
338            },
339            vec![input],
340        );
341        self.nodes.insert(id, node);
342        id
343    }
344
345    /// Add a transpose operation
346    pub fn transpose(&mut self, input: NodeId, axes: Vec<usize>) -> NodeId {
347        let id = self.alloc_id();
348        let node = ShapeNode::new(id, ShapeOp::Transpose { axes }, vec![input]);
349        self.nodes.insert(id, node);
350        id
351    }
352
353    /// Add a broadcast operation
354    pub fn broadcast(&mut self, input: NodeId, target_shape: Vec<usize>) -> NodeId {
355        let id = self.alloc_id();
356        let node = ShapeNode::new(
357            id,
358            ShapeOp::Broadcast {
359                target_shape: target_shape.clone(),
360            },
361            vec![input],
362        );
363        self.nodes.insert(id, node);
364        id
365    }
366
367    /// Add a concatenate operation
368    pub fn concatenate(&mut self, input1: NodeId, input2: NodeId, axis: usize) -> NodeId {
369        let id = self.alloc_id();
370        let node = ShapeNode::new(
371            id,
372            ShapeOp::Concatenate {
373                axis,
374                other: input2,
375            },
376            vec![input1, input2],
377        );
378        self.nodes.insert(id, node);
379        id
380    }
381
382    /// Add a stack operation
383    pub fn stack(&mut self, input1: NodeId, input2: NodeId, axis: usize) -> NodeId {
384        let id = self.alloc_id();
385        let node = ShapeNode::new(
386            id,
387            ShapeOp::Stack {
388                axis,
389                other: input2,
390            },
391            vec![input1, input2],
392        );
393        self.nodes.insert(id, node);
394        id
395    }
396
397    /// Add a squeeze operation
398    pub fn squeeze(&mut self, input: NodeId, axes: Option<Vec<usize>>) -> NodeId {
399        let id = self.alloc_id();
400        let node = ShapeNode::new(id, ShapeOp::Squeeze { axes }, vec![input]);
401        self.nodes.insert(id, node);
402        id
403    }
404
405    /// Add an unsqueeze operation
406    pub fn unsqueeze(&mut self, input: NodeId, axes: Vec<usize>) -> NodeId {
407        let id = self.alloc_id();
408        let node = ShapeNode::new(id, ShapeOp::Unsqueeze { axes }, vec![input]);
409        self.nodes.insert(id, node);
410        id
411    }
412
413    /// Add a flatten operation
414    pub fn flatten(&mut self, input: NodeId, start_dim: usize, end_dim: usize) -> NodeId {
415        let id = self.alloc_id();
416        let node = ShapeNode::new(id, ShapeOp::Flatten { start_dim, end_dim }, vec![input]);
417        self.nodes.insert(id, node);
418        id
419    }
420
    /// Get a node by ID
    ///
    /// Returns `None` if the ID was never allocated by this graph.
    pub fn get_node(&self, id: NodeId) -> Option<&ShapeNode> {
        self.nodes.get(&id)
    }
425
426    /// Infer the shape for a node
427    pub fn infer_shape(&mut self, node_id: NodeId) -> InferenceResult<Vec<usize>> {
428        // Check cache first
429        if let Some(cached) = self.cache.get(&node_id) {
430            return Ok(cached.clone());
431        }
432
433        // Get the node
434        let node = self
435            .nodes
436            .get(&node_id)
437            .ok_or(ShapeInferenceError::NodeNotFound(node_id))?
438            .clone();
439
440        // If shape is already known, return it
441        if let Some(shape) = node.shape() {
442            let result = shape.to_vec();
443            self.cache.insert(node_id, result.clone());
444            return Ok(result);
445        }
446
447        // Infer based on operation
448        let inferred_shape = match &node.op {
449            ShapeOp::Input => {
450                return Err(ShapeInferenceError::UnknownShape(node_id));
451            }
452            ShapeOp::Reshape { target_shape } => {
453                let input_id = node.dependencies[0];
454                let input_shape = self.infer_shape(input_id)?;
455                Self::infer_reshape(&input_shape, target_shape)?
456            }
457            ShapeOp::Transpose { axes } => {
458                let input_id = node.dependencies[0];
459                let input_shape = self.infer_shape(input_id)?;
460                Self::infer_transpose(&input_shape, axes)?
461            }
462            ShapeOp::Broadcast { target_shape } => {
463                let input_id = node.dependencies[0];
464                let input_shape = self.infer_shape(input_id)?;
465                Self::infer_broadcast(&input_shape, target_shape)?
466            }
467            ShapeOp::Concatenate { axis, .. } => {
468                let input1_id = node.dependencies[0];
469                let input2_id = node.dependencies[1];
470                let shape1 = self.infer_shape(input1_id)?;
471                let shape2 = self.infer_shape(input2_id)?;
472                Self::infer_concatenate(&shape1, &shape2, *axis)?
473            }
474            ShapeOp::Stack { axis, .. } => {
475                let input1_id = node.dependencies[0];
476                let input2_id = node.dependencies[1];
477                let shape1 = self.infer_shape(input1_id)?;
478                let shape2 = self.infer_shape(input2_id)?;
479                Self::infer_stack(&shape1, &shape2, *axis)?
480            }
481            ShapeOp::Squeeze { axes } => {
482                let input_id = node.dependencies[0];
483                let input_shape = self.infer_shape(input_id)?;
484                Self::infer_squeeze(&input_shape, axes.as_ref())?
485            }
486            ShapeOp::Unsqueeze { axes } => {
487                let input_id = node.dependencies[0];
488                let input_shape = self.infer_shape(input_id)?;
489                Self::infer_unsqueeze(&input_shape, axes)?
490            }
491            ShapeOp::Flatten { start_dim, end_dim } => {
492                let input_id = node.dependencies[0];
493                let input_shape = self.infer_shape(input_id)?;
494                Self::infer_flatten(&input_shape, *start_dim, *end_dim)?
495            }
496            _ => {
497                return Err(ShapeInferenceError::UnknownShape(node_id));
498            }
499        };
500
501        // Cache the result
502        self.cache.insert(node_id, inferred_shape.clone());
503
504        // Update the node
505        if let Some(node) = self.nodes.get_mut(&node_id) {
506            node.set_shape(inferred_shape.clone());
507        }
508
509        Ok(inferred_shape)
510    }
511
    /// Infer reshape output shape
    ///
    /// `usize::MAX` in `target_shape` plays the role of `-1` in other
    /// frameworks: that single dimension is inferred so the total element
    /// count matches the input.
    ///
    /// # Errors
    ///
    /// Fails when more than one dimension is `usize::MAX`, when the
    /// inferred dimension does not divide evenly, or when source and
    /// target element counts disagree.
    fn infer_reshape(input_shape: &[usize], target_shape: &[usize]) -> InferenceResult<Vec<usize>> {
        let input_numel: usize = input_shape.iter().product();
        let mut output_shape = target_shape.to_vec();

        // Handle -1 in target shape (infer dimension)
        let neg_count = target_shape.iter().filter(|&&x| x == usize::MAX).count();
        if neg_count > 1 {
            return Err(ShapeInferenceError::InvalidReshape {
                source_shape: input_shape.to_vec(),
                target_shape: target_shape.to_vec(),
                reason: "At most one dimension can be inferred".to_string(),
            });
        }

        if neg_count == 1 {
            // Product of the known (non-wildcard) dims; zero or a
            // non-dividing product means no integer size fits the wildcard.
            let known_product: usize = target_shape.iter().filter(|&&x| x != usize::MAX).product();
            if known_product == 0 || input_numel % known_product != 0 {
                return Err(ShapeInferenceError::InvalidReshape {
                    source_shape: input_shape.to_vec(),
                    target_shape: target_shape.to_vec(),
                    reason: "Cannot infer dimension size".to_string(),
                });
            }
            let inferred = input_numel / known_product;
            for dim in &mut output_shape {
                if *dim == usize::MAX {
                    *dim = inferred;
                }
            }
        }

        // Final sanity check: total element count must be preserved.
        let output_numel: usize = output_shape.iter().product();
        if input_numel != output_numel {
            return Err(ShapeInferenceError::InvalidReshape {
                source_shape: input_shape.to_vec(),
                target_shape: target_shape.to_vec(),
                reason: format!(
                    "Element count mismatch: {} vs {}",
                    input_numel, output_numel
                ),
            });
        }

        Ok(output_shape)
    }
558
559    /// Infer transpose output shape
560    fn infer_transpose(input_shape: &[usize], axes: &[usize]) -> InferenceResult<Vec<usize>> {
561        if axes.len() != input_shape.len() {
562            return Err(ShapeInferenceError::InvalidTranspose {
563                shape: input_shape.to_vec(),
564                axes: axes.to_vec(),
565                reason: "Axes count must match shape rank".to_string(),
566            });
567        }
568
569        let mut output_shape = vec![0; input_shape.len()];
570        for (i, &axis) in axes.iter().enumerate() {
571            if axis >= input_shape.len() {
572                return Err(ShapeInferenceError::InvalidTranspose {
573                    shape: input_shape.to_vec(),
574                    axes: axes.to_vec(),
575                    reason: format!("Axis {} out of bounds", axis),
576                });
577            }
578            output_shape[i] = input_shape[axis];
579        }
580
581        Ok(output_shape)
582    }
583
584    /// Infer broadcast output shape
585    fn infer_broadcast(
586        input_shape: &[usize],
587        target_shape: &[usize],
588    ) -> InferenceResult<Vec<usize>> {
589        if input_shape.len() > target_shape.len() {
590            return Err(ShapeInferenceError::InvalidBroadcast {
591                source_shape: input_shape.to_vec(),
592                target_shape: target_shape.to_vec(),
593                reason: "Source rank exceeds target rank".to_string(),
594            });
595        }
596
597        let offset = target_shape.len() - input_shape.len();
598        for (i, &dim) in input_shape.iter().enumerate() {
599            let target_dim = target_shape[offset + i];
600            if dim != 1 && dim != target_dim {
601                return Err(ShapeInferenceError::InvalidBroadcast {
602                    source_shape: input_shape.to_vec(),
603                    target_shape: target_shape.to_vec(),
604                    reason: format!(
605                        "Dimension {} cannot be broadcast: {} to {}",
606                        i, dim, target_dim
607                    ),
608                });
609            }
610        }
611
612        Ok(target_shape.to_vec())
613    }
614
615    /// Infer concatenate output shape
616    fn infer_concatenate(
617        shape1: &[usize],
618        shape2: &[usize],
619        axis: usize,
620    ) -> InferenceResult<Vec<usize>> {
621        if shape1.len() != shape2.len() {
622            return Err(ShapeInferenceError::InvalidConcatenate {
623                shape1: shape1.to_vec(),
624                shape2: shape2.to_vec(),
625                axis,
626                reason: "Shapes must have same rank".to_string(),
627            });
628        }
629
630        if axis >= shape1.len() {
631            return Err(ShapeInferenceError::InvalidConcatenate {
632                shape1: shape1.to_vec(),
633                shape2: shape2.to_vec(),
634                axis,
635                reason: format!("Axis {} out of bounds", axis),
636            });
637        }
638
639        for (i, (&dim1, &dim2)) in shape1.iter().zip(shape2.iter()).enumerate() {
640            if i != axis && dim1 != dim2 {
641                return Err(ShapeInferenceError::InvalidConcatenate {
642                    shape1: shape1.to_vec(),
643                    shape2: shape2.to_vec(),
644                    axis,
645                    reason: format!("Dimension {} mismatch: {} vs {}", i, dim1, dim2),
646                });
647            }
648        }
649
650        let mut output = shape1.to_vec();
651        output[axis] += shape2[axis];
652        Ok(output)
653    }
654
655    /// Infer stack output shape
656    fn infer_stack(shape1: &[usize], shape2: &[usize], axis: usize) -> InferenceResult<Vec<usize>> {
657        if shape1 != shape2 {
658            return Err(ShapeInferenceError::InvalidConcatenate {
659                shape1: shape1.to_vec(),
660                shape2: shape2.to_vec(),
661                axis,
662                reason: "Shapes must be identical for stack".to_string(),
663            });
664        }
665
666        if axis > shape1.len() {
667            return Err(ShapeInferenceError::InvalidConcatenate {
668                shape1: shape1.to_vec(),
669                shape2: shape2.to_vec(),
670                axis,
671                reason: format!("Axis {} out of bounds", axis),
672            });
673        }
674
675        let mut output = Vec::with_capacity(shape1.len() + 1);
676        output.extend_from_slice(&shape1[..axis]);
677        output.push(2);
678        output.extend_from_slice(&shape1[axis..]);
679        Ok(output)
680    }
681
682    /// Infer squeeze output shape
683    fn infer_squeeze(
684        input_shape: &[usize],
685        axes: Option<&Vec<usize>>,
686    ) -> InferenceResult<Vec<usize>> {
687        let output = if let Some(axes) = axes {
688            input_shape
689                .iter()
690                .enumerate()
691                .filter(|(i, &dim)| !axes.contains(i) || dim != 1)
692                .map(|(_, &dim)| dim)
693                .collect()
694        } else {
695            input_shape
696                .iter()
697                .filter(|&&dim| dim != 1)
698                .copied()
699                .collect()
700        };
701        Ok(output)
702    }
703
    /// Infer unsqueeze output shape
    ///
    /// Inserts a size-1 dimension at each position in `axes`; positions
    /// are interpreted against the *output* rank
    /// (`input rank + axes.len()`).
    ///
    /// # Errors
    ///
    /// Returns `UnknownShape` when the axes don't line up with the output
    /// rank (e.g. an out-of-range or duplicated axis leaves input
    /// dimensions unplaced). NOTE(review): the error carries a placeholder
    /// `NodeId(0)` because this helper has no node context — consider a
    /// dedicated error variant.
    fn infer_unsqueeze(input_shape: &[usize], axes: &[usize]) -> InferenceResult<Vec<usize>> {
        let mut sorted_axes = axes.to_vec();
        sorted_axes.sort_unstable();

        // Build output shape by inserting 1s at the specified axes
        let final_rank = input_shape.len() + axes.len();
        let mut output = Vec::with_capacity(final_rank);
        let mut input_idx = 0;
        let mut axes_idx = 0;

        for i in 0..final_rank {
            if axes_idx < sorted_axes.len() && sorted_axes[axes_idx] == i {
                // This position is an unsqueezed axis
                output.push(1);
                axes_idx += 1;
            } else {
                // This position comes from the input
                if input_idx >= input_shape.len() {
                    return Err(ShapeInferenceError::UnknownShape(NodeId(0)));
                }
                output.push(input_shape[input_idx]);
                input_idx += 1;
            }
        }

        Ok(output)
    }
732
733    /// Infer flatten output shape
734    fn infer_flatten(
735        input_shape: &[usize],
736        start_dim: usize,
737        end_dim: usize,
738    ) -> InferenceResult<Vec<usize>> {
739        if start_dim >= input_shape.len() || end_dim >= input_shape.len() || start_dim > end_dim {
740            return Err(ShapeInferenceError::UnknownShape(NodeId(0)));
741        }
742
743        let flattened_size: usize = input_shape[start_dim..=end_dim].iter().product();
744        let mut output = Vec::with_capacity(input_shape.len() - (end_dim - start_dim));
745        output.extend_from_slice(&input_shape[..start_dim]);
746        output.push(flattened_size);
747        if end_dim + 1 < input_shape.len() {
748            output.extend_from_slice(&input_shape[end_dim + 1..]);
749        }
750
751        Ok(output)
752    }
753
    /// Clear the inference cache
    ///
    /// Forces later `infer_shape` calls to recompute; note that shapes
    /// already recorded on the nodes themselves are still reused.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
    }
758
759    /// Get all nodes in topological order
760    pub fn topological_sort(&self) -> InferenceResult<Vec<NodeId>> {
761        let mut result = Vec::new();
762        let mut visited = BTreeMap::new();
763        let mut temp_mark = BTreeMap::new();
764
765        for &node_id in self.nodes.keys() {
766            if !visited.contains_key(&node_id) {
767                self.visit_node(node_id, &mut visited, &mut temp_mark, &mut result)?;
768            }
769        }
770
771        Ok(result)
772    }
773
774    fn visit_node(
775        &self,
776        node_id: NodeId,
777        visited: &mut BTreeMap<NodeId, bool>,
778        temp_mark: &mut BTreeMap<NodeId, bool>,
779        result: &mut Vec<NodeId>,
780    ) -> InferenceResult<()> {
781        if visited.get(&node_id) == Some(&true) {
782            return Ok(());
783        }
784
785        if temp_mark.get(&node_id) == Some(&true) {
786            return Err(ShapeInferenceError::CyclicDependency(node_id));
787        }
788
789        temp_mark.insert(node_id, true);
790
791        if let Some(node) = self.nodes.get(&node_id) {
792            for &dep_id in &node.dependencies {
793                self.visit_node(dep_id, visited, temp_mark, result)?;
794            }
795        }
796
797        temp_mark.insert(node_id, false);
798        visited.insert(node_id, true);
799        result.push(node_id);
800
801        Ok(())
802    }
803
    /// Count the number of nodes in the graph
    ///
    /// Inputs and derived nodes both count; no method in this module
    /// removes nodes, so the count only grows.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }
808}
809
/// `Default` is the empty graph, identical to `ShapeGraph::new()`.
impl Default for ShapeGraph {
    fn default() -> Self {
        Self::new()
    }
}
815
816#[cfg(test)]
817mod tests {
818    use super::*;
819
820    extern crate std;
821    use std::vec;
822
823    #[test]
824    fn test_input_node() {
825        let mut graph = ShapeGraph::new();
826        let input = graph.add_input(vec![2, 3, 4]);
827
828        let shape = graph
829            .infer_shape(input)
830            .expect("infer_shape should succeed");
831        assert_eq!(shape, vec![2, 3, 4]);
832    }
833
834    #[test]
835    fn test_reshape() {
836        let mut graph = ShapeGraph::new();
837        let input = graph.add_input(vec![2, 3, 4]);
838        let reshaped = graph.reshape(input, vec![2, 12]);
839
840        let shape = graph
841            .infer_shape(reshaped)
842            .expect("infer_shape should succeed");
843        assert_eq!(shape, vec![2, 12]);
844    }
845
846    #[test]
847    fn test_transpose() {
848        let mut graph = ShapeGraph::new();
849        let input = graph.add_input(vec![2, 3, 4]);
850        let transposed = graph.transpose(input, vec![2, 0, 1]);
851
852        let shape = graph
853            .infer_shape(transposed)
854            .expect("infer_shape should succeed");
855        assert_eq!(shape, vec![4, 2, 3]);
856    }
857
858    #[test]
859    fn test_broadcast() {
860        let mut graph = ShapeGraph::new();
861        let input = graph.add_input(vec![1, 3, 1]);
862        let broadcasted = graph.broadcast(input, vec![2, 3, 4]);
863
864        let shape = graph
865            .infer_shape(broadcasted)
866            .expect("infer_shape should succeed");
867        assert_eq!(shape, vec![2, 3, 4]);
868    }
869
870    #[test]
871    fn test_concatenate() {
872        let mut graph = ShapeGraph::new();
873        let input1 = graph.add_input(vec![2, 3, 4]);
874        let input2 = graph.add_input(vec![2, 5, 4]);
875        let concatenated = graph.concatenate(input1, input2, 1);
876
877        let shape = graph
878            .infer_shape(concatenated)
879            .expect("infer_shape should succeed");
880        assert_eq!(shape, vec![2, 8, 4]);
881    }
882
883    #[test]
884    fn test_stack() {
885        let mut graph = ShapeGraph::new();
886        let input1 = graph.add_input(vec![2, 3, 4]);
887        let input2 = graph.add_input(vec![2, 3, 4]);
888        let stacked = graph.stack(input1, input2, 1);
889
890        let shape = graph
891            .infer_shape(stacked)
892            .expect("infer_shape should succeed");
893        assert_eq!(shape, vec![2, 2, 3, 4]);
894    }
895
896    #[test]
897    fn test_squeeze() {
898        let mut graph = ShapeGraph::new();
899        let input = graph.add_input(vec![2, 1, 3, 1, 4]);
900        let squeezed = graph.squeeze(input, None);
901
902        let shape = graph
903            .infer_shape(squeezed)
904            .expect("infer_shape should succeed");
905        assert_eq!(shape, vec![2, 3, 4]);
906    }
907
908    #[test]
909    fn test_unsqueeze() {
910        let mut graph = ShapeGraph::new();
911        let input = graph.add_input(vec![2, 3, 4]);
912        let unsqueezed = graph.unsqueeze(input, vec![1, 3]);
913
914        let shape = graph
915            .infer_shape(unsqueezed)
916            .expect("infer_shape should succeed");
917        assert_eq!(shape, vec![2, 1, 3, 1, 4]);
918    }
919
920    #[test]
921    fn test_flatten() {
922        let mut graph = ShapeGraph::new();
923        let input = graph.add_input(vec![2, 3, 4, 5]);
924        let flattened = graph.flatten(input, 1, 2);
925
926        let shape = graph
927            .infer_shape(flattened)
928            .expect("infer_shape should succeed");
929        assert_eq!(shape, vec![2, 12, 5]);
930    }
931
932    #[test]
933    fn test_complex_graph() {
934        let mut graph = ShapeGraph::new();
935        let input = graph.add_input(vec![2, 3, 4]);
936        let reshaped = graph.reshape(input, vec![2, 12]);
937        let transposed = graph.transpose(reshaped, vec![1, 0]);
938        let unsqueezed = graph.unsqueeze(transposed, vec![1]);
939
940        let shape = graph
941            .infer_shape(unsqueezed)
942            .expect("infer_shape should succeed");
943        assert_eq!(shape, vec![12, 1, 2]);
944    }
945
946    #[test]
947    fn test_invalid_reshape() {
948        let mut graph = ShapeGraph::new();
949        let input = graph.add_input(vec![2, 3, 4]);
950        let reshaped = graph.reshape(input, vec![2, 13]); // 24 != 26
951
952        assert!(graph.infer_shape(reshaped).is_err());
953    }
954
955    #[test]
956    fn test_invalid_transpose() {
957        let mut graph = ShapeGraph::new();
958        let input = graph.add_input(vec![2, 3, 4]);
959        let transposed = graph.transpose(input, vec![0, 1]); // Wrong number of axes
960
961        assert!(graph.infer_shape(transposed).is_err());
962    }
963
964    #[test]
965    fn test_invalid_broadcast() {
966        let mut graph = ShapeGraph::new();
967        let input = graph.add_input(vec![2, 3, 4]);
968        let broadcasted = graph.broadcast(input, vec![2, 5, 4]); // 3 != 5
969
970        assert!(graph.infer_shape(broadcasted).is_err());
971    }
972
973    #[test]
974    fn test_topological_sort() {
975        let mut graph = ShapeGraph::new();
976        let input = graph.add_input(vec![2, 3, 4]);
977        let reshaped = graph.reshape(input, vec![2, 12]);
978        let transposed = graph.transpose(reshaped, vec![1, 0]);
979
980        let sorted = graph
981            .topological_sort()
982            .expect("topological_sort should succeed");
983        assert_eq!(sorted.len(), 3);
984
985        // Input should come before reshape, reshape before transpose
986        let input_pos = sorted
987            .iter()
988            .position(|&id| id == input)
989            .expect("input should be in sorted list");
990        let reshape_pos = sorted
991            .iter()
992            .position(|&id| id == reshaped)
993            .expect("reshaped should be in sorted list");
994        let transpose_pos = sorted
995            .iter()
996            .position(|&id| id == transposed)
997            .expect("transposed should be in sorted list");
998
999        assert!(input_pos < reshape_pos);
1000        assert!(reshape_pos < transpose_pos);
1001    }
1002
1003    #[test]
1004    fn test_cache() {
1005        let mut graph = ShapeGraph::new();
1006        let input = graph.add_input(vec![2, 3, 4]);
1007        let reshaped = graph.reshape(input, vec![2, 12]);
1008
1009        // First inference
1010        let shape1 = graph
1011            .infer_shape(reshaped)
1012            .expect("infer_shape should succeed");
1013
1014        // Second inference should use cache
1015        let shape2 = graph
1016            .infer_shape(reshaped)
1017            .expect("infer_shape should succeed");
1018
1019        assert_eq!(shape1, shape2);
1020        assert_eq!(shape1, vec![2, 12]);
1021    }
1022
1023    #[test]
1024    fn test_clear_cache() {
1025        let mut graph = ShapeGraph::new();
1026        let input = graph.add_input(vec![2, 3, 4]);
1027        let reshaped = graph.reshape(input, vec![2, 12]);
1028
1029        // Infer and cache
1030        let _ = graph
1031            .infer_shape(reshaped)
1032            .expect("infer_shape should succeed");
1033
1034        // Clear cache
1035        graph.clear_cache();
1036
1037        // Should still work after clearing cache
1038        let shape = graph
1039            .infer_shape(reshaped)
1040            .expect("infer_shape should succeed");
1041        assert_eq!(shape, vec![2, 12]);
1042    }
1043
1044    #[test]
1045    fn test_node_count() {
1046        let mut graph = ShapeGraph::new();
1047        assert_eq!(graph.node_count(), 0);
1048
1049        let input = graph.add_input(vec![2, 3, 4]);
1050        assert_eq!(graph.node_count(), 1);
1051
1052        let _reshaped = graph.reshape(input, vec![2, 12]);
1053        assert_eq!(graph.node_count(), 2);
1054    }
1055}