//! tenflowers_core/session.rs — session types for executing computation graphs.
1use crate::{
2    dtype::DType,
3    error::TensorError,
4    graph::{AttributeValue, Graph, NodeId, NodeType},
5    ops::registry::OpRegistry,
6    tensor::Tensor,
7};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Session configuration options
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Allow soft device placement (fallback to CPU when GPU unavailable)
    pub allow_soft_placement: bool,
    /// Log device placement decisions
    pub log_device_placement: bool,
    /// GPU memory growth configuration
    pub gpu_memory_growth: bool,
    /// GPU memory limit (bytes)
    pub gpu_memory_limit: Option<usize>,
    /// Number of inter-op threads
    pub inter_op_parallelism_threads: usize,
    /// Number of intra-op threads
    pub intra_op_parallelism_threads: usize,
}

impl Default for SessionConfig {
    /// Conservative defaults: soft placement on, placement logging off,
    /// GPU memory growth enabled with no explicit byte limit, and both
    /// thread-pool sizes set to 0, meaning "use the system default".
    fn default() -> Self {
        SessionConfig {
            allow_soft_placement: true,
            gpu_memory_growth: true,
            gpu_memory_limit: None,
            log_device_placement: false,
            inter_op_parallelism_threads: 0,
            intra_op_parallelism_threads: 0,
        }
    }
}
40
/// Feed dictionary for providing input values to placeholders.
///
/// Keys are placeholder node names; values are the tensors substituted for
/// them during a run. Currently limited to `f32` tensors.
pub type FeedDict = HashMap<String, Tensor<f32>>;
43
/// Fetch specification for requesting output values.
///
/// The plain `Name`/`NodeId` variants implicitly request output index 0;
/// the `*Output` variants select an explicit output index.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum FetchSpec {
    /// Fetch by node name (output 0)
    Name(String),
    /// Fetch by node ID (output 0)
    NodeId(NodeId),
    /// Fetch by node name and output index
    NamedOutput(String, usize),
    /// Fetch by node ID and output index
    IndexedOutput(NodeId, usize),
}
56
/// Session for executing computation graphs.
///
/// TF1-style interface: callers feed placeholder values via a [`FeedDict`]
/// and request node outputs via [`FetchSpec`]s.
pub trait Session {
    /// Run the graph with the given feeds and fetches.
    ///
    /// Returns one tensor per entry in `fetches`, in the same order.
    fn run(
        &mut self,
        fetches: &[FetchSpec],
        feed_dict: &FeedDict,
    ) -> Result<Vec<Tensor<f32>>, TensorError>;

    /// Partial run - allows incrementally feeding inputs and fetching outputs.
    ///
    /// Declares the feeds, fetches, and targets up front; the returned
    /// string is an opaque handle passed to subsequent `partial_run` calls.
    fn partial_run_setup(
        &mut self,
        feeds: &[String],
        fetches: &[FetchSpec],
        targets: &[String],
    ) -> Result<String, TensorError>; // Returns run handle

    /// Continue a partial run identified by `handle`, supplying further
    /// feeds and requesting more fetches.
    fn partial_run(
        &mut self,
        handle: &str,
        feed_dict: &FeedDict,
        fetches: &[FetchSpec],
    ) -> Result<Vec<Tensor<f32>>, TensorError>;

    /// Close the session and release resources.
    ///
    /// After closing, `DefaultSession` rejects further `run`/`partial_run`
    /// calls with an error.
    fn close(&mut self) -> Result<(), TensorError>;
}
85
/// Variable store for managing session variables.
///
/// Keys are variable node names; values persist across `run` calls so a
/// variable keeps its state between executions.
pub type VariableStore = HashMap<String, Tensor<f32>>;
88
/// Default session implementation.
///
/// Owns per-session state: variable values, cached execution plans, and
/// in-flight partial-run bookkeeping. The graph is shared behind an
/// `Arc<RwLock<_>>` so multiple owners can reference the same graph.
#[allow(dead_code)]
pub struct DefaultSession {
    // Shared graph; write access is only taken to compute topological order.
    graph: Arc<RwLock<Graph>>,
    // Session options (stored; not consulted by the simplified executor below).
    config: SessionConfig,
    // Kernel registry (not yet used — see execute_operation).
    op_registry: Arc<OpRegistry>,
    // Set by close(); all entry points fail afterwards.
    closed: bool,
    // Variable storage
    variables: VariableStore,
    // Cached execution plan, keyed by the exact fetch list
    execution_cache: HashMap<Vec<FetchSpec>, ExecutionPlan>,
    // Partial run state, keyed by handle string
    partial_runs: HashMap<String, PartialRunState>,
    // Monotonic counter used to mint unique partial-run handles.
    next_partial_run_id: u64,
}
104
/// Execution plan for a set of fetches.
///
/// Built once per distinct fetch list and cached on the session.
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct ExecutionPlan {
    /// Nodes to execute in topological order (dependencies of the fetches only)
    execution_order: Vec<NodeId>,
    /// Input mapping (placeholder name -> node id)
    input_mapping: HashMap<String, NodeId>,
    /// Output mapping (fetch spec -> (node id, output index))
    output_mapping: HashMap<FetchSpec, (NodeId, usize)>,
}
116
/// State for partial runs.
///
/// Created by `partial_run_setup` and looked up by handle on each
/// subsequent `partial_run` call.
#[derive(Debug)]
#[allow(dead_code)]
struct PartialRunState {
    feeds: Vec<String>,
    fetches: Vec<FetchSpec>,
    targets: Vec<String>,
    plan: ExecutionPlan,
    // Intermediate values stored between partial runs, so already-computed
    // nodes are not re-executed on later calls.
    intermediate_values: HashMap<NodeId, Vec<Tensor<f32>>>,
}
128
129impl DefaultSession {
130    /// Create a new session with the given graph and configuration
131    pub fn new(
132        graph: Arc<RwLock<Graph>>,
133        config: SessionConfig,
134        op_registry: Arc<OpRegistry>,
135    ) -> Self {
136        Self {
137            graph,
138            config,
139            op_registry,
140            closed: false,
141            variables: HashMap::new(),
142            execution_cache: HashMap::new(),
143            partial_runs: HashMap::new(),
144            next_partial_run_id: 0,
145        }
146    }
147
148    /// Create execution plan for the given fetches
149    fn create_execution_plan(&self, fetches: &[FetchSpec]) -> Result<ExecutionPlan, TensorError> {
150        let graph = self.graph.read().expect("read lock should not be poisoned");
151
152        // Find all nodes that need to be executed
153        let mut required_nodes = std::collections::HashSet::new();
154        let mut output_mapping = HashMap::new();
155
156        // Process each fetch specification
157        for fetch in fetches {
158            let (node_id, output_idx) = match fetch {
159                FetchSpec::Name(name) => {
160                    let node = graph.get_node_by_name(name).ok_or_else(|| {
161                        TensorError::invalid_argument(format!("Node '{name}' not found"))
162                    })?;
163                    (node.id, 0)
164                }
165                FetchSpec::NodeId(id) => (*id, 0),
166                FetchSpec::NamedOutput(name, idx) => {
167                    let node = graph.get_node_by_name(name).ok_or_else(|| {
168                        TensorError::invalid_argument(format!("Node '{name}' not found"))
169                    })?;
170                    (node.id, *idx)
171                }
172                FetchSpec::IndexedOutput(id, idx) => (*id, *idx),
173            };
174
175            // Verify node exists
176            if graph.get_node(node_id).is_none() {
177                return Err(TensorError::invalid_argument(format!(
178                    "Node {node_id} not found"
179                )));
180            }
181
182            required_nodes.insert(node_id);
183            output_mapping.insert(fetch.clone(), (node_id, output_idx));
184        }
185
186        // Find all dependencies using DFS
187        let mut stack = required_nodes.iter().cloned().collect::<Vec<_>>();
188        while let Some(node_id) = stack.pop() {
189            if let Some(node) = graph.get_node(node_id) {
190                // Add all input nodes
191                for &edge_id in &node.inputs {
192                    if let Some(edge) = graph.get_edge(edge_id) {
193                        if required_nodes.insert(edge.from_node) {
194                            stack.push(edge.from_node);
195                        }
196                    }
197                }
198            }
199        }
200
201        // We need to access compute_topological_order on the graph
202        // Since we can't clone RwLockReadGuard, we'll call it on the original graph
203        let full_topo_order = {
204            drop(graph); // Release the read lock
205            let mut graph_write = self
206                .graph
207                .write()
208                .expect("write lock should not be poisoned");
209            graph_write.compute_topological_order()?.to_vec()
210        };
211        let execution_order: Vec<NodeId> = full_topo_order
212            .iter()
213            .filter(|&&node_id| required_nodes.contains(&node_id))
214            .cloned()
215            .collect();
216
217        // Create input mapping (placeholders)
218        let graph = self.graph.read().expect("read lock should not be poisoned");
219        let mut input_mapping = HashMap::new();
220        for node in graph.nodes() {
221            if let NodeType::Placeholder { .. } = node.op_type {
222                input_mapping.insert(node.name.clone(), node.id);
223            }
224        }
225
226        Ok(ExecutionPlan {
227            execution_order,
228            input_mapping,
229            output_mapping,
230        })
231    }
232
233    /// Execute a single node
234    fn execute_node(
235        &mut self,
236        node_id: NodeId,
237        node_values: &mut HashMap<NodeId, Vec<Tensor<f32>>>,
238        feed_dict: &FeedDict,
239    ) -> Result<(), TensorError> {
240        let graph = self.graph.read().expect("read lock should not be poisoned");
241        let node = graph
242            .get_node(node_id)
243            .ok_or_else(|| TensorError::invalid_argument(format!("Node {node_id} not found")))?;
244
245        match &node.op_type {
246            NodeType::Placeholder { .. } => {
247                // Look up value in feed_dict
248                if let Some(value) = feed_dict.get(&node.name) {
249                    node_values.insert(node_id, vec![value.clone()]);
250                } else {
251                    return Err(TensorError::invalid_argument(format!(
252                        "No value provided for placeholder '{}'",
253                        node.name
254                    )));
255                }
256            }
257            NodeType::Constant => {
258                // Get constant value from attributes
259                if let Some(AttributeValue::Tensor(tensor)) = node.attributes.get("value") {
260                    node_values.insert(node_id, vec![tensor.clone()]);
261                } else {
262                    return Err(TensorError::invalid_argument(format!(
263                        "Constant node '{}' has no value attribute",
264                        node.name
265                    )));
266                }
267            }
268            NodeType::Variable { shape, dtype, .. } => {
269                // Check if variable is already initialized
270                if let Some(var_tensor) = self.variables.get(&node.name) {
271                    // Use existing variable value
272                    node_values.insert(node_id, vec![var_tensor.clone()]);
273                } else {
274                    // Initialize variable with zeros or from initializer attribute
275                    let tensor = if let Some(AttributeValue::Tensor(init_tensor)) =
276                        node.attributes.get("initializer")
277                    {
278                        init_tensor.clone()
279                    } else {
280                        // Default initialization with zeros
281                        match dtype {
282                            DType::Float32 => Tensor::<f32>::zeros(shape.dims()),
283                            _ => {
284                                return Err(TensorError::unsupported_operation_simple(format!(
285                                    "Variable dtype {dtype:?} not supported"
286                                )))
287                            }
288                        }
289                    };
290
291                    // Store the variable for future use
292                    self.variables.insert(node.name.clone(), tensor.clone());
293                    node_values.insert(node_id, vec![tensor]);
294                }
295            }
296            NodeType::Operation(op_name) => {
297                // Gather input tensors
298                let mut input_tensors = Vec::new();
299                for &edge_id in &node.inputs {
300                    if let Some(edge) = graph.get_edge(edge_id) {
301                        if let Some(from_outputs) = node_values.get(&edge.from_node) {
302                            if edge.from_output < from_outputs.len() {
303                                input_tensors.push(from_outputs[edge.from_output].clone());
304                            } else {
305                                return Err(TensorError::invalid_argument(format!(
306                                    "Invalid output index {} for node {}",
307                                    edge.from_output, edge.from_node
308                                )));
309                            }
310                        } else {
311                            return Err(TensorError::invalid_argument(format!(
312                                "Input node {} has not been computed",
313                                edge.from_node
314                            )));
315                        }
316                    }
317                }
318
319                // Execute the operation
320                let outputs = self.execute_operation(op_name, &input_tensors, &node.attributes)?;
321                node_values.insert(node_id, outputs);
322            }
323        }
324
325        Ok(())
326    }
327
328    /// Execute an operation with given inputs
329    fn execute_operation(
330        &self,
331        op_name: &str,
332        inputs: &[Tensor<f32>],
333        _attributes: &HashMap<String, AttributeValue>,
334    ) -> Result<Vec<Tensor<f32>>, TensorError> {
335        // This is a simplified implementation
336        // In practice, we'd use the op registry to dispatch to the correct kernel
337        match op_name {
338            "Add" => {
339                if inputs.len() != 2 {
340                    return Err(TensorError::invalid_argument(
341                        "Add operation requires 2 inputs".to_string(),
342                    ));
343                }
344                Ok(vec![inputs[0].add(&inputs[1])?])
345            }
346            "Mul" => {
347                if inputs.len() != 2 {
348                    return Err(TensorError::invalid_argument(
349                        "Mul operation requires 2 inputs".to_string(),
350                    ));
351                }
352                Ok(vec![inputs[0].mul(&inputs[1])?])
353            }
354            "MatMul" => {
355                if inputs.len() != 2 {
356                    return Err(TensorError::invalid_argument(
357                        "MatMul operation requires 2 inputs".to_string(),
358                    ));
359                }
360                Ok(vec![inputs[0].matmul(&inputs[1])?])
361            }
362            "Identity" => {
363                if inputs.len() != 1 {
364                    return Err(TensorError::invalid_argument(
365                        "Identity operation requires 1 input".to_string(),
366                    ));
367                }
368                Ok(vec![inputs[0].clone()])
369            }
370            "Sub" => {
371                if inputs.len() != 2 {
372                    return Err(TensorError::invalid_argument(
373                        "Sub operation requires 2 inputs".to_string(),
374                    ));
375                }
376                Ok(vec![inputs[0].sub(&inputs[1])?])
377            }
378            "Div" => {
379                if inputs.len() != 2 {
380                    return Err(TensorError::invalid_argument(
381                        "Div operation requires 2 inputs".to_string(),
382                    ));
383                }
384                Ok(vec![inputs[0].div(&inputs[1])?])
385            }
386            "Pow" => {
387                if inputs.len() != 2 {
388                    return Err(TensorError::invalid_argument(
389                        "Pow operation requires 2 inputs".to_string(),
390                    ));
391                }
392                Ok(vec![crate::ops::pow(&inputs[0], &inputs[1])?])
393            }
394            "Exp" => {
395                if inputs.len() != 1 {
396                    return Err(TensorError::invalid_argument(
397                        "Exp operation requires 1 input".to_string(),
398                    ));
399                }
400                Ok(vec![crate::ops::exp(&inputs[0])?])
401            }
402            "Log" => {
403                if inputs.len() != 1 {
404                    return Err(TensorError::invalid_argument(
405                        "Log operation requires 1 input".to_string(),
406                    ));
407                }
408                Ok(vec![crate::ops::log(&inputs[0])?])
409            }
410            "Sin" => {
411                if inputs.len() != 1 {
412                    return Err(TensorError::invalid_argument(
413                        "Sin operation requires 1 input".to_string(),
414                    ));
415                }
416                Ok(vec![crate::ops::sin(&inputs[0])?])
417            }
418            "Cos" => {
419                if inputs.len() != 1 {
420                    return Err(TensorError::invalid_argument(
421                        "Cos operation requires 1 input".to_string(),
422                    ));
423                }
424                Ok(vec![crate::ops::cos(&inputs[0])?])
425            }
426            "Tanh" => {
427                if inputs.len() != 1 {
428                    return Err(TensorError::invalid_argument(
429                        "Tanh operation requires 1 input".to_string(),
430                    ));
431                }
432                Ok(vec![crate::ops::tanh(&inputs[0])?])
433            }
434            "Relu" => {
435                if inputs.len() != 1 {
436                    return Err(TensorError::invalid_argument(
437                        "Relu operation requires 1 input".to_string(),
438                    ));
439                }
440                Ok(vec![crate::ops::relu(&inputs[0])?])
441            }
442            "Sigmoid" => {
443                if inputs.len() != 1 {
444                    return Err(TensorError::invalid_argument(
445                        "Sigmoid operation requires 1 input".to_string(),
446                    ));
447                }
448                Ok(vec![crate::ops::sigmoid(&inputs[0])?])
449            }
450            "Softmax" => {
451                if inputs.len() != 1 {
452                    return Err(TensorError::invalid_argument(
453                        "Softmax operation requires 1 input".to_string(),
454                    ));
455                }
456                // Default to last axis (-1)
457                Ok(vec![crate::ops::softmax(&inputs[0], Some(-1))?])
458            }
459            "Sum" => {
460                if inputs.len() != 1 {
461                    return Err(TensorError::invalid_argument(
462                        "Sum operation requires 1 input".to_string(),
463                    ));
464                }
465                Ok(vec![crate::ops::sum(&inputs[0], None, false)?])
466            }
467            "Mean" => {
468                if inputs.len() != 1 {
469                    return Err(TensorError::invalid_argument(
470                        "Mean operation requires 1 input".to_string(),
471                    ));
472                }
473                Ok(vec![crate::ops::mean(&inputs[0], None, false)?])
474            }
475            "Reshape" => {
476                if inputs.len() != 1 {
477                    return Err(TensorError::invalid_argument(
478                        "Reshape operation requires 1 input (shape as attribute)".to_string(),
479                    ));
480                }
481                // For session execution, we'll use a simple flattening reshape
482                let total_elements = inputs[0].shape().dims().iter().product::<usize>();
483                Ok(vec![inputs[0].reshape(&[total_elements])?])
484            }
485            "Transpose" => {
486                if inputs.len() != 1 {
487                    return Err(TensorError::invalid_argument(
488                        "Transpose operation requires 1 input".to_string(),
489                    ));
490                }
491                Ok(vec![crate::ops::transpose(&inputs[0])?])
492            }
493            "Conv2D" => {
494                if inputs.len() < 2 {
495                    return Err(TensorError::invalid_argument(
496                        "Conv2D operation requires at least 2 inputs".to_string(),
497                    ));
498                }
499                // Use default parameters for stride, padding
500                Ok(vec![crate::ops::conv2d(
501                    &inputs[0],
502                    &inputs[1],
503                    None,
504                    (1, 1),
505                    "VALID",
506                )?])
507            }
508            "MaxPool2D" => {
509                if inputs.len() != 1 {
510                    return Err(TensorError::invalid_argument(
511                        "MaxPool2D operation requires 1 input".to_string(),
512                    ));
513                }
514                // Use default 2x2 kernel with stride 2
515                Ok(vec![crate::ops::max_pool2d(
516                    &inputs[0],
517                    (2, 2),
518                    (2, 2),
519                    "VALID",
520                )?])
521            }
522            "AvgPool2D" => {
523                if inputs.len() != 1 {
524                    return Err(TensorError::invalid_argument(
525                        "AvgPool2D operation requires 1 input".to_string(),
526                    ));
527                }
528                // Use default 2x2 kernel with stride 2
529                Ok(vec![crate::ops::avg_pool2d(
530                    &inputs[0],
531                    (2, 2),
532                    (2, 2),
533                    "VALID",
534                )?])
535            }
536            "Max" => {
537                if inputs.len() != 1 {
538                    return Err(TensorError::invalid_argument(
539                        "Max operation requires 1 input".to_string(),
540                    ));
541                }
542                Ok(vec![crate::ops::max(&inputs[0], None, false)?])
543            }
544            "Min" => {
545                if inputs.len() != 1 {
546                    return Err(TensorError::invalid_argument(
547                        "Min operation requires 1 input".to_string(),
548                    ));
549                }
550                Ok(vec![crate::ops::min(&inputs[0], None, false)?])
551            }
552            "Gelu" => {
553                if inputs.len() != 1 {
554                    return Err(TensorError::invalid_argument(
555                        "Gelu operation requires 1 input".to_string(),
556                    ));
557                }
558                Ok(vec![crate::ops::gelu(&inputs[0])?])
559            }
560            "Swish" => {
561                if inputs.len() != 1 {
562                    return Err(TensorError::invalid_argument(
563                        "Swish operation requires 1 input".to_string(),
564                    ));
565                }
566                Ok(vec![crate::ops::swish(&inputs[0])?])
567            }
568            _ => Err(TensorError::unsupported_operation_simple(format!(
569                "Operation '{op_name}' not supported in session execution"
570            ))),
571        }
572    }
573}
574
impl Session for DefaultSession {
    /// Full run: plan (or reuse a cached plan), execute in topological
    /// order, then collect one tensor per fetch.
    fn run(
        &mut self,
        fetches: &[FetchSpec],
        feed_dict: &FeedDict,
    ) -> Result<Vec<Tensor<f32>>, TensorError> {
        // Reject use after close()
        if self.closed {
            return Err(TensorError::invalid_argument(
                "Session is closed".to_string(),
            ));
        }

        // Get or create execution plan.
        // NOTE(review): cache is keyed only by `fetches`; if the graph is
        // mutated after a plan is cached, the stale plan is still used.
        let plan = if let Some(cached_plan) = self.execution_cache.get(fetches) {
            cached_plan.clone()
        } else {
            let plan = self.create_execution_plan(fetches)?;
            self.execution_cache.insert(fetches.to_vec(), plan.clone());
            plan
        };

        // Execute nodes in topological order
        let mut node_values: HashMap<NodeId, Vec<Tensor<f32>>> = HashMap::new();

        for &node_id in &plan.execution_order {
            self.execute_node(node_id, &mut node_values, feed_dict)?;
        }

        // Collect results: one tensor per fetch, in fetch order
        let mut results = Vec::new();
        for fetch in fetches {
            if let Some(&(node_id, output_idx)) = plan.output_mapping.get(fetch) {
                if let Some(outputs) = node_values.get(&node_id) {
                    if output_idx < outputs.len() {
                        results.push(outputs[output_idx].clone());
                    } else {
                        return Err(TensorError::invalid_argument(format!(
                            "Invalid output index {output_idx} for node {node_id}"
                        )));
                    }
                } else {
                    return Err(TensorError::invalid_argument(format!(
                        "Node {node_id} was not computed"
                    )));
                }
            } else {
                return Err(TensorError::invalid_argument(
                    "Invalid fetch specification".to_string(),
                ));
            }
        }

        Ok(results)
    }

    /// Set up incremental execution: build a plan for `fetches` and stash
    /// it under a freshly minted handle.
    fn partial_run_setup(
        &mut self,
        feeds: &[String],
        fetches: &[FetchSpec],
        targets: &[String],
    ) -> Result<String, TensorError> {
        if self.closed {
            return Err(TensorError::invalid_argument(
                "Session is closed".to_string(),
            ));
        }

        // Create execution plan for fetches
        let plan = self.create_execution_plan(fetches)?;

        // Generate unique handle from the monotonic counter
        let handle = format!("partial_run_{}", self.next_partial_run_id);
        self.next_partial_run_id += 1;

        // Store partial run state
        let partial_state = PartialRunState {
            feeds: feeds.to_vec(),
            fetches: fetches.to_vec(),
            targets: targets.to_vec(),
            plan,
            intermediate_values: HashMap::new(),
        };

        self.partial_runs.insert(handle.clone(), partial_state);
        Ok(handle)
    }

    /// Continue a partial run: execute only nodes not yet computed in a
    /// previous call on the same handle, then collect the requested fetches.
    fn partial_run(
        &mut self,
        handle: &str,
        feed_dict: &FeedDict,
        fetches: &[FetchSpec],
    ) -> Result<Vec<Tensor<f32>>, TensorError> {
        if self.closed {
            return Err(TensorError::invalid_argument(
                "Session is closed".to_string(),
            ));
        }

        // Clone the plan and cached intermediates out of self.partial_runs
        // first, so the &mut self borrow needed by execute_node is free.
        let (execution_order, output_mapping, mut node_values) = {
            let partial_state = self.partial_runs.get(handle).ok_or_else(|| {
                TensorError::invalid_argument(format!("Invalid partial run handle: {handle}"))
            })?;
            (
                partial_state.plan.execution_order.clone(),
                partial_state.plan.output_mapping.clone(),
                partial_state.intermediate_values.clone(),
            )
        };

        // Execute nodes that aren't already computed from earlier calls
        for &node_id in &execution_order {
            if !node_values.contains_key(&node_id) {
                self.execute_node(node_id, &mut node_values, feed_dict)?;
            }
        }

        // Persist intermediate values for the next partial_run on this handle
        if let Some(partial_state) = self.partial_runs.get_mut(handle) {
            partial_state.intermediate_values = node_values.clone();
        }

        // Collect results (same shape of logic as run())
        let mut results = Vec::new();
        for fetch in fetches {
            if let Some(&(node_id, output_idx)) = output_mapping.get(fetch) {
                if let Some(outputs) = node_values.get(&node_id) {
                    if output_idx < outputs.len() {
                        results.push(outputs[output_idx].clone());
                    } else {
                        return Err(TensorError::invalid_argument(format!(
                            "Invalid output index {output_idx} for node {node_id}"
                        )));
                    }
                } else {
                    return Err(TensorError::invalid_argument(format!(
                        "Node {node_id} was not computed"
                    )));
                }
            } else {
                return Err(TensorError::invalid_argument(
                    "Invalid fetch specification".to_string(),
                ));
            }
        }

        Ok(results)
    }

    /// Close the session: drop cached plans and partial-run state and mark
    /// the session closed. Idempotent.
    fn close(&mut self) -> Result<(), TensorError> {
        if self.closed {
            return Ok(());
        }

        // Clear caches and partial run state
        self.execution_cache.clear();
        self.partial_runs.clear();
        self.closed = true;

        Ok(())
    }
}
738
739/// Convenience function to create a new session
740pub fn create_session(
741    graph: Arc<RwLock<Graph>>,
742    config: Option<SessionConfig>,
743    op_registry: Option<Arc<OpRegistry>>,
744) -> DefaultSession {
745    let config = config.unwrap_or_default();
746    let op_registry = op_registry.unwrap_or_else(|| Arc::new(OpRegistry::new()));
747    DefaultSession::new(graph, config, op_registry)
748}
749
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::Device,
        dtype::DType,
        graph::{AttributeValue, Graph, NodeType},
        shape::Shape,
        tensor::Tensor,
    };
    use std::collections::HashMap;

    /// A freshly created session starts open.
    #[test]
    fn test_session_creation() {
        let graph = Arc::new(RwLock::new(Graph::new()));
        let session = create_session(graph, None, None);
        assert!(!session.closed);
    }

    /// Placeholder -> Identity: the fed tensor's shape survives a run.
    #[test]
    fn test_simple_execution() {
        let mut graph = Graph::new();

        // Create placeholder
        let placeholder_id = graph
            .add_node(
                "input".to_string(),
                NodeType::Placeholder {
                    dtype: DType::Float32,
                    shape: Shape::new(vec![2, 2]),
                },
                Device::Cpu,
                HashMap::new(),
            )
            .expect("test: operation should succeed");

        // Create identity operation
        let identity_id = graph
            .add_node(
                "output".to_string(),
                NodeType::Operation("Identity".to_string()),
                Device::Cpu,
                HashMap::new(),
            )
            .expect("test: operation should succeed");

        // Connect them (placeholder output 0 -> identity input 0)
        graph
            .add_edge(
                placeholder_id,
                identity_id,
                0,
                0,
                DType::Float32,
                Shape::new(vec![2, 2]),
                false,
            )
            .expect("test: operation should succeed");

        let graph = Arc::new(RwLock::new(graph));
        let mut session = create_session(graph, None, None);

        // Create input tensor
        let input_tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let mut feed_dict = FeedDict::new();
        feed_dict.insert("input".to_string(), input_tensor.clone());

        // Run session
        let fetches = vec![FetchSpec::Name("output".to_string())];
        let results = session
            .run(&fetches, &feed_dict)
            .expect("test: run should succeed");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].shape(), input_tensor.shape());
    }

    /// Two placeholders -> Add: checks element-wise values of the result.
    #[test]
    fn test_addition_execution() {
        let mut graph = Graph::new();

        // Create two placeholders
        let input1_id = graph
            .add_node(
                "input1".to_string(),
                NodeType::Placeholder {
                    dtype: DType::Float32,
                    shape: Shape::new(vec![2]),
                },
                Device::Cpu,
                HashMap::new(),
            )
            .expect("test: operation should succeed");

        let input2_id = graph
            .add_node(
                "input2".to_string(),
                NodeType::Placeholder {
                    dtype: DType::Float32,
                    shape: Shape::new(vec![2]),
                },
                Device::Cpu,
                HashMap::new(),
            )
            .expect("test: operation should succeed");

        // Create add operation
        let add_id = graph
            .add_node(
                "add".to_string(),
                NodeType::Operation("Add".to_string()),
                Device::Cpu,
                HashMap::new(),
            )
            .expect("test: operation should succeed");

        // Connect inputs to add (input1 -> slot 0, input2 -> slot 1)
        graph
            .add_edge(
                input1_id,
                add_id,
                0,
                0,
                DType::Float32,
                Shape::new(vec![2]),
                false,
            )
            .expect("test: operation should succeed");

        graph
            .add_edge(
                input2_id,
                add_id,
                0,
                1,
                DType::Float32,
                Shape::new(vec![2]),
                false,
            )
            .expect("operation should succeed");

        let graph = Arc::new(RwLock::new(graph));
        let mut session = create_session(graph, None, None);

        // Create input tensors
        let input1 =
            Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).expect("from_vec should succeed");
        let input2 =
            Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).expect("from_vec should succeed");

        let mut feed_dict = FeedDict::new();
        feed_dict.insert("input1".to_string(), input1);
        feed_dict.insert("input2".to_string(), input2);

        // Run session
        let fetches = vec![FetchSpec::Name("add".to_string())];
        let results = session
            .run(&fetches, &feed_dict)
            .expect("run should succeed");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].shape(), &Shape::new(vec![2]));

        // Check result values (with float tolerance)
        if let Some(result_slice) = results[0].as_slice() {
            assert!((result_slice[0] - 4.0).abs() < 1e-6); // 1.0 + 3.0
            assert!((result_slice[1] - 6.0).abs() < 1e-6); // 2.0 + 4.0
        } else {
            panic!("Failed to get tensor slice");
        }
    }

    /// close() marks the session closed and subsequent run() calls error out.
    #[test]
    fn test_session_close() {
        let graph = Arc::new(RwLock::new(Graph::new()));
        let mut session = create_session(graph, None, None);

        session.close().expect("test: close should succeed");
        assert!(session.closed);

        // Trying to run after close should fail
        let feed_dict = FeedDict::new();
        let fetches = vec![];
        let result = session.run(&fetches, &feed_dict);
        assert!(result.is_err());
    }
}