1use crate::{
2 dtype::DType,
3 error::TensorError,
4 graph::{AttributeValue, Graph, NodeId, NodeType},
5 ops::registry::OpRegistry,
6 tensor::Tensor,
7};
8use std::collections::HashMap;
9use std::sync::{Arc, RwLock};
10
/// Configuration options carried by a session.
///
/// `DefaultSession` in this file only stores these values; how (or whether)
/// each option influences execution is decided elsewhere.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Soft device placement switch (`true` in the `Default` impl).
    pub allow_soft_placement: bool,
    /// Device placement logging switch (`false` in the `Default` impl).
    pub log_device_placement: bool,
    /// GPU memory growth switch (`true` in the `Default` impl).
    pub gpu_memory_growth: bool,
    /// Optional GPU memory limit (`None` in the `Default` impl).
    ///
    /// NOTE(review): the unit is presumably bytes — confirm against the consumer.
    pub gpu_memory_limit: Option<usize>,
    /// Thread count for inter-op parallelism (`0` in the `Default` impl).
    ///
    /// NOTE(review): `0` presumably means "let the runtime decide" — confirm.
    pub inter_op_parallelism_threads: usize,
    /// Thread count for intra-op parallelism (`0` in the `Default` impl).
    ///
    /// NOTE(review): `0` presumably means "let the runtime decide" — confirm.
    pub intra_op_parallelism_threads: usize,
}
27
28impl Default for SessionConfig {
29 fn default() -> Self {
30 Self {
31 allow_soft_placement: true,
32 log_device_placement: false,
33 gpu_memory_growth: true,
34 gpu_memory_limit: None,
35 inter_op_parallelism_threads: 0, intra_op_parallelism_threads: 0, }
38 }
39}
40
41pub type FeedDict = HashMap<String, Tensor<f32>>;
43
44#[derive(Clone, Debug, PartialEq, Eq, Hash)]
46pub enum FetchSpec {
47 Name(String),
49 NodeId(NodeId),
51 NamedOutput(String, usize),
53 IndexedOutput(NodeId, usize),
55}
56
57pub trait Session {
59 fn run(
61 &mut self,
62 fetches: &[FetchSpec],
63 feed_dict: &FeedDict,
64 ) -> Result<Vec<Tensor<f32>>, TensorError>;
65
66 fn partial_run_setup(
68 &mut self,
69 feeds: &[String],
70 fetches: &[FetchSpec],
71 targets: &[String],
72 ) -> Result<String, TensorError>; fn partial_run(
76 &mut self,
77 handle: &str,
78 feed_dict: &FeedDict,
79 fetches: &[FetchSpec],
80 ) -> Result<Vec<Tensor<f32>>, TensorError>;
81
82 fn close(&mut self) -> Result<(), TensorError>;
84}
85
86pub type VariableStore = HashMap<String, Tensor<f32>>;
88
89#[allow(dead_code)]
91pub struct DefaultSession {
92 graph: Arc<RwLock<Graph>>,
93 config: SessionConfig,
94 op_registry: Arc<OpRegistry>,
95 closed: bool,
96 variables: VariableStore,
98 execution_cache: HashMap<Vec<FetchSpec>, ExecutionPlan>,
100 partial_runs: HashMap<String, PartialRunState>,
102 next_partial_run_id: u64,
103}
104
105#[derive(Clone, Debug)]
107#[allow(dead_code)]
108struct ExecutionPlan {
109 execution_order: Vec<NodeId>,
111 input_mapping: HashMap<String, NodeId>,
113 output_mapping: HashMap<FetchSpec, (NodeId, usize)>,
115}
116
117#[derive(Debug)]
119#[allow(dead_code)]
120struct PartialRunState {
121 feeds: Vec<String>,
122 fetches: Vec<FetchSpec>,
123 targets: Vec<String>,
124 plan: ExecutionPlan,
125 intermediate_values: HashMap<NodeId, Vec<Tensor<f32>>>,
127}
128
129impl DefaultSession {
130 pub fn new(
132 graph: Arc<RwLock<Graph>>,
133 config: SessionConfig,
134 op_registry: Arc<OpRegistry>,
135 ) -> Self {
136 Self {
137 graph,
138 config,
139 op_registry,
140 closed: false,
141 variables: HashMap::new(),
142 execution_cache: HashMap::new(),
143 partial_runs: HashMap::new(),
144 next_partial_run_id: 0,
145 }
146 }
147
148 fn create_execution_plan(&self, fetches: &[FetchSpec]) -> Result<ExecutionPlan, TensorError> {
150 let graph = self.graph.read().expect("read lock should not be poisoned");
151
152 let mut required_nodes = std::collections::HashSet::new();
154 let mut output_mapping = HashMap::new();
155
156 for fetch in fetches {
158 let (node_id, output_idx) = match fetch {
159 FetchSpec::Name(name) => {
160 let node = graph.get_node_by_name(name).ok_or_else(|| {
161 TensorError::invalid_argument(format!("Node '{name}' not found"))
162 })?;
163 (node.id, 0)
164 }
165 FetchSpec::NodeId(id) => (*id, 0),
166 FetchSpec::NamedOutput(name, idx) => {
167 let node = graph.get_node_by_name(name).ok_or_else(|| {
168 TensorError::invalid_argument(format!("Node '{name}' not found"))
169 })?;
170 (node.id, *idx)
171 }
172 FetchSpec::IndexedOutput(id, idx) => (*id, *idx),
173 };
174
175 if graph.get_node(node_id).is_none() {
177 return Err(TensorError::invalid_argument(format!(
178 "Node {node_id} not found"
179 )));
180 }
181
182 required_nodes.insert(node_id);
183 output_mapping.insert(fetch.clone(), (node_id, output_idx));
184 }
185
186 let mut stack = required_nodes.iter().cloned().collect::<Vec<_>>();
188 while let Some(node_id) = stack.pop() {
189 if let Some(node) = graph.get_node(node_id) {
190 for &edge_id in &node.inputs {
192 if let Some(edge) = graph.get_edge(edge_id) {
193 if required_nodes.insert(edge.from_node) {
194 stack.push(edge.from_node);
195 }
196 }
197 }
198 }
199 }
200
201 let full_topo_order = {
204 drop(graph); let mut graph_write = self
206 .graph
207 .write()
208 .expect("write lock should not be poisoned");
209 graph_write.compute_topological_order()?.to_vec()
210 };
211 let execution_order: Vec<NodeId> = full_topo_order
212 .iter()
213 .filter(|&&node_id| required_nodes.contains(&node_id))
214 .cloned()
215 .collect();
216
217 let graph = self.graph.read().expect("read lock should not be poisoned");
219 let mut input_mapping = HashMap::new();
220 for node in graph.nodes() {
221 if let NodeType::Placeholder { .. } = node.op_type {
222 input_mapping.insert(node.name.clone(), node.id);
223 }
224 }
225
226 Ok(ExecutionPlan {
227 execution_order,
228 input_mapping,
229 output_mapping,
230 })
231 }
232
233 fn execute_node(
235 &mut self,
236 node_id: NodeId,
237 node_values: &mut HashMap<NodeId, Vec<Tensor<f32>>>,
238 feed_dict: &FeedDict,
239 ) -> Result<(), TensorError> {
240 let graph = self.graph.read().expect("read lock should not be poisoned");
241 let node = graph
242 .get_node(node_id)
243 .ok_or_else(|| TensorError::invalid_argument(format!("Node {node_id} not found")))?;
244
245 match &node.op_type {
246 NodeType::Placeholder { .. } => {
247 if let Some(value) = feed_dict.get(&node.name) {
249 node_values.insert(node_id, vec![value.clone()]);
250 } else {
251 return Err(TensorError::invalid_argument(format!(
252 "No value provided for placeholder '{}'",
253 node.name
254 )));
255 }
256 }
257 NodeType::Constant => {
258 if let Some(AttributeValue::Tensor(tensor)) = node.attributes.get("value") {
260 node_values.insert(node_id, vec![tensor.clone()]);
261 } else {
262 return Err(TensorError::invalid_argument(format!(
263 "Constant node '{}' has no value attribute",
264 node.name
265 )));
266 }
267 }
268 NodeType::Variable { shape, dtype, .. } => {
269 if let Some(var_tensor) = self.variables.get(&node.name) {
271 node_values.insert(node_id, vec![var_tensor.clone()]);
273 } else {
274 let tensor = if let Some(AttributeValue::Tensor(init_tensor)) =
276 node.attributes.get("initializer")
277 {
278 init_tensor.clone()
279 } else {
280 match dtype {
282 DType::Float32 => Tensor::<f32>::zeros(shape.dims()),
283 _ => {
284 return Err(TensorError::unsupported_operation_simple(format!(
285 "Variable dtype {dtype:?} not supported"
286 )))
287 }
288 }
289 };
290
291 self.variables.insert(node.name.clone(), tensor.clone());
293 node_values.insert(node_id, vec![tensor]);
294 }
295 }
296 NodeType::Operation(op_name) => {
297 let mut input_tensors = Vec::new();
299 for &edge_id in &node.inputs {
300 if let Some(edge) = graph.get_edge(edge_id) {
301 if let Some(from_outputs) = node_values.get(&edge.from_node) {
302 if edge.from_output < from_outputs.len() {
303 input_tensors.push(from_outputs[edge.from_output].clone());
304 } else {
305 return Err(TensorError::invalid_argument(format!(
306 "Invalid output index {} for node {}",
307 edge.from_output, edge.from_node
308 )));
309 }
310 } else {
311 return Err(TensorError::invalid_argument(format!(
312 "Input node {} has not been computed",
313 edge.from_node
314 )));
315 }
316 }
317 }
318
319 let outputs = self.execute_operation(op_name, &input_tensors, &node.attributes)?;
321 node_values.insert(node_id, outputs);
322 }
323 }
324
325 Ok(())
326 }
327
328 fn execute_operation(
330 &self,
331 op_name: &str,
332 inputs: &[Tensor<f32>],
333 _attributes: &HashMap<String, AttributeValue>,
334 ) -> Result<Vec<Tensor<f32>>, TensorError> {
335 match op_name {
338 "Add" => {
339 if inputs.len() != 2 {
340 return Err(TensorError::invalid_argument(
341 "Add operation requires 2 inputs".to_string(),
342 ));
343 }
344 Ok(vec![inputs[0].add(&inputs[1])?])
345 }
346 "Mul" => {
347 if inputs.len() != 2 {
348 return Err(TensorError::invalid_argument(
349 "Mul operation requires 2 inputs".to_string(),
350 ));
351 }
352 Ok(vec![inputs[0].mul(&inputs[1])?])
353 }
354 "MatMul" => {
355 if inputs.len() != 2 {
356 return Err(TensorError::invalid_argument(
357 "MatMul operation requires 2 inputs".to_string(),
358 ));
359 }
360 Ok(vec![inputs[0].matmul(&inputs[1])?])
361 }
362 "Identity" => {
363 if inputs.len() != 1 {
364 return Err(TensorError::invalid_argument(
365 "Identity operation requires 1 input".to_string(),
366 ));
367 }
368 Ok(vec![inputs[0].clone()])
369 }
370 "Sub" => {
371 if inputs.len() != 2 {
372 return Err(TensorError::invalid_argument(
373 "Sub operation requires 2 inputs".to_string(),
374 ));
375 }
376 Ok(vec![inputs[0].sub(&inputs[1])?])
377 }
378 "Div" => {
379 if inputs.len() != 2 {
380 return Err(TensorError::invalid_argument(
381 "Div operation requires 2 inputs".to_string(),
382 ));
383 }
384 Ok(vec![inputs[0].div(&inputs[1])?])
385 }
386 "Pow" => {
387 if inputs.len() != 2 {
388 return Err(TensorError::invalid_argument(
389 "Pow operation requires 2 inputs".to_string(),
390 ));
391 }
392 Ok(vec![crate::ops::pow(&inputs[0], &inputs[1])?])
393 }
394 "Exp" => {
395 if inputs.len() != 1 {
396 return Err(TensorError::invalid_argument(
397 "Exp operation requires 1 input".to_string(),
398 ));
399 }
400 Ok(vec![crate::ops::exp(&inputs[0])?])
401 }
402 "Log" => {
403 if inputs.len() != 1 {
404 return Err(TensorError::invalid_argument(
405 "Log operation requires 1 input".to_string(),
406 ));
407 }
408 Ok(vec![crate::ops::log(&inputs[0])?])
409 }
410 "Sin" => {
411 if inputs.len() != 1 {
412 return Err(TensorError::invalid_argument(
413 "Sin operation requires 1 input".to_string(),
414 ));
415 }
416 Ok(vec![crate::ops::sin(&inputs[0])?])
417 }
418 "Cos" => {
419 if inputs.len() != 1 {
420 return Err(TensorError::invalid_argument(
421 "Cos operation requires 1 input".to_string(),
422 ));
423 }
424 Ok(vec![crate::ops::cos(&inputs[0])?])
425 }
426 "Tanh" => {
427 if inputs.len() != 1 {
428 return Err(TensorError::invalid_argument(
429 "Tanh operation requires 1 input".to_string(),
430 ));
431 }
432 Ok(vec![crate::ops::tanh(&inputs[0])?])
433 }
434 "Relu" => {
435 if inputs.len() != 1 {
436 return Err(TensorError::invalid_argument(
437 "Relu operation requires 1 input".to_string(),
438 ));
439 }
440 Ok(vec![crate::ops::relu(&inputs[0])?])
441 }
442 "Sigmoid" => {
443 if inputs.len() != 1 {
444 return Err(TensorError::invalid_argument(
445 "Sigmoid operation requires 1 input".to_string(),
446 ));
447 }
448 Ok(vec![crate::ops::sigmoid(&inputs[0])?])
449 }
450 "Softmax" => {
451 if inputs.len() != 1 {
452 return Err(TensorError::invalid_argument(
453 "Softmax operation requires 1 input".to_string(),
454 ));
455 }
456 Ok(vec![crate::ops::softmax(&inputs[0], Some(-1))?])
458 }
459 "Sum" => {
460 if inputs.len() != 1 {
461 return Err(TensorError::invalid_argument(
462 "Sum operation requires 1 input".to_string(),
463 ));
464 }
465 Ok(vec![crate::ops::sum(&inputs[0], None, false)?])
466 }
467 "Mean" => {
468 if inputs.len() != 1 {
469 return Err(TensorError::invalid_argument(
470 "Mean operation requires 1 input".to_string(),
471 ));
472 }
473 Ok(vec![crate::ops::mean(&inputs[0], None, false)?])
474 }
475 "Reshape" => {
476 if inputs.len() != 1 {
477 return Err(TensorError::invalid_argument(
478 "Reshape operation requires 1 input (shape as attribute)".to_string(),
479 ));
480 }
481 let total_elements = inputs[0].shape().dims().iter().product::<usize>();
483 Ok(vec![inputs[0].reshape(&[total_elements])?])
484 }
485 "Transpose" => {
486 if inputs.len() != 1 {
487 return Err(TensorError::invalid_argument(
488 "Transpose operation requires 1 input".to_string(),
489 ));
490 }
491 Ok(vec![crate::ops::transpose(&inputs[0])?])
492 }
493 "Conv2D" => {
494 if inputs.len() < 2 {
495 return Err(TensorError::invalid_argument(
496 "Conv2D operation requires at least 2 inputs".to_string(),
497 ));
498 }
499 Ok(vec![crate::ops::conv2d(
501 &inputs[0],
502 &inputs[1],
503 None,
504 (1, 1),
505 "VALID",
506 )?])
507 }
508 "MaxPool2D" => {
509 if inputs.len() != 1 {
510 return Err(TensorError::invalid_argument(
511 "MaxPool2D operation requires 1 input".to_string(),
512 ));
513 }
514 Ok(vec![crate::ops::max_pool2d(
516 &inputs[0],
517 (2, 2),
518 (2, 2),
519 "VALID",
520 )?])
521 }
522 "AvgPool2D" => {
523 if inputs.len() != 1 {
524 return Err(TensorError::invalid_argument(
525 "AvgPool2D operation requires 1 input".to_string(),
526 ));
527 }
528 Ok(vec![crate::ops::avg_pool2d(
530 &inputs[0],
531 (2, 2),
532 (2, 2),
533 "VALID",
534 )?])
535 }
536 "Max" => {
537 if inputs.len() != 1 {
538 return Err(TensorError::invalid_argument(
539 "Max operation requires 1 input".to_string(),
540 ));
541 }
542 Ok(vec![crate::ops::max(&inputs[0], None, false)?])
543 }
544 "Min" => {
545 if inputs.len() != 1 {
546 return Err(TensorError::invalid_argument(
547 "Min operation requires 1 input".to_string(),
548 ));
549 }
550 Ok(vec![crate::ops::min(&inputs[0], None, false)?])
551 }
552 "Gelu" => {
553 if inputs.len() != 1 {
554 return Err(TensorError::invalid_argument(
555 "Gelu operation requires 1 input".to_string(),
556 ));
557 }
558 Ok(vec![crate::ops::gelu(&inputs[0])?])
559 }
560 "Swish" => {
561 if inputs.len() != 1 {
562 return Err(TensorError::invalid_argument(
563 "Swish operation requires 1 input".to_string(),
564 ));
565 }
566 Ok(vec![crate::ops::swish(&inputs[0])?])
567 }
568 _ => Err(TensorError::unsupported_operation_simple(format!(
569 "Operation '{op_name}' not supported in session execution"
570 ))),
571 }
572 }
573}
574
575impl Session for DefaultSession {
576 fn run(
577 &mut self,
578 fetches: &[FetchSpec],
579 feed_dict: &FeedDict,
580 ) -> Result<Vec<Tensor<f32>>, TensorError> {
581 if self.closed {
582 return Err(TensorError::invalid_argument(
583 "Session is closed".to_string(),
584 ));
585 }
586
587 let plan = if let Some(cached_plan) = self.execution_cache.get(fetches) {
589 cached_plan.clone()
590 } else {
591 let plan = self.create_execution_plan(fetches)?;
592 self.execution_cache.insert(fetches.to_vec(), plan.clone());
593 plan
594 };
595
596 let mut node_values: HashMap<NodeId, Vec<Tensor<f32>>> = HashMap::new();
598
599 for &node_id in &plan.execution_order {
600 self.execute_node(node_id, &mut node_values, feed_dict)?;
601 }
602
603 let mut results = Vec::new();
605 for fetch in fetches {
606 if let Some(&(node_id, output_idx)) = plan.output_mapping.get(fetch) {
607 if let Some(outputs) = node_values.get(&node_id) {
608 if output_idx < outputs.len() {
609 results.push(outputs[output_idx].clone());
610 } else {
611 return Err(TensorError::invalid_argument(format!(
612 "Invalid output index {output_idx} for node {node_id}"
613 )));
614 }
615 } else {
616 return Err(TensorError::invalid_argument(format!(
617 "Node {node_id} was not computed"
618 )));
619 }
620 } else {
621 return Err(TensorError::invalid_argument(
622 "Invalid fetch specification".to_string(),
623 ));
624 }
625 }
626
627 Ok(results)
628 }
629
630 fn partial_run_setup(
631 &mut self,
632 feeds: &[String],
633 fetches: &[FetchSpec],
634 targets: &[String],
635 ) -> Result<String, TensorError> {
636 if self.closed {
637 return Err(TensorError::invalid_argument(
638 "Session is closed".to_string(),
639 ));
640 }
641
642 let plan = self.create_execution_plan(fetches)?;
644
645 let handle = format!("partial_run_{}", self.next_partial_run_id);
647 self.next_partial_run_id += 1;
648
649 let partial_state = PartialRunState {
651 feeds: feeds.to_vec(),
652 fetches: fetches.to_vec(),
653 targets: targets.to_vec(),
654 plan,
655 intermediate_values: HashMap::new(),
656 };
657
658 self.partial_runs.insert(handle.clone(), partial_state);
659 Ok(handle)
660 }
661
662 fn partial_run(
663 &mut self,
664 handle: &str,
665 feed_dict: &FeedDict,
666 fetches: &[FetchSpec],
667 ) -> Result<Vec<Tensor<f32>>, TensorError> {
668 if self.closed {
669 return Err(TensorError::invalid_argument(
670 "Session is closed".to_string(),
671 ));
672 }
673
674 let (execution_order, output_mapping, mut node_values) = {
676 let partial_state = self.partial_runs.get(handle).ok_or_else(|| {
677 TensorError::invalid_argument(format!("Invalid partial run handle: {handle}"))
678 })?;
679 (
680 partial_state.plan.execution_order.clone(),
681 partial_state.plan.output_mapping.clone(),
682 partial_state.intermediate_values.clone(),
683 )
684 };
685
686 for &node_id in &execution_order {
688 if !node_values.contains_key(&node_id) {
689 self.execute_node(node_id, &mut node_values, feed_dict)?;
690 }
691 }
692
693 if let Some(partial_state) = self.partial_runs.get_mut(handle) {
695 partial_state.intermediate_values = node_values.clone();
696 }
697
698 let mut results = Vec::new();
700 for fetch in fetches {
701 if let Some(&(node_id, output_idx)) = output_mapping.get(fetch) {
702 if let Some(outputs) = node_values.get(&node_id) {
703 if output_idx < outputs.len() {
704 results.push(outputs[output_idx].clone());
705 } else {
706 return Err(TensorError::invalid_argument(format!(
707 "Invalid output index {output_idx} for node {node_id}"
708 )));
709 }
710 } else {
711 return Err(TensorError::invalid_argument(format!(
712 "Node {node_id} was not computed"
713 )));
714 }
715 } else {
716 return Err(TensorError::invalid_argument(
717 "Invalid fetch specification".to_string(),
718 ));
719 }
720 }
721
722 Ok(results)
723 }
724
725 fn close(&mut self) -> Result<(), TensorError> {
726 if self.closed {
727 return Ok(());
728 }
729
730 self.execution_cache.clear();
732 self.partial_runs.clear();
733 self.closed = true;
734
735 Ok(())
736 }
737}
738
739pub fn create_session(
741 graph: Arc<RwLock<Graph>>,
742 config: Option<SessionConfig>,
743 op_registry: Option<Arc<OpRegistry>>,
744) -> DefaultSession {
745 let config = config.unwrap_or_default();
746 let op_registry = op_registry.unwrap_or_else(|| Arc::new(OpRegistry::new()));
747 DefaultSession::new(graph, config, op_registry)
748}
749
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::Device,
        dtype::DType,
        graph::{AttributeValue, Graph, NodeType},
        shape::Shape,
        tensor::Tensor,
    };
    use std::collections::HashMap;

    /// A freshly created session starts out open.
    #[test]
    fn test_session_creation() {
        let shared_graph = Arc::new(RwLock::new(Graph::new()));
        let session = create_session(shared_graph, None, None);
        assert!(!session.closed);
    }

    /// placeholder -> Identity: the fetched tensor has the shape of the fed one.
    #[test]
    fn test_simple_execution() {
        let mut graph = Graph::new();

        // Both nodes live on the CPU and carry no attributes.
        let mut add_node = |name: &str, op_type: NodeType| {
            graph
                .add_node(name.to_string(), op_type, Device::Cpu, HashMap::new())
                .expect("test: operation should succeed")
        };
        let placeholder_id = add_node(
            "input",
            NodeType::Placeholder {
                dtype: DType::Float32,
                shape: Shape::new(vec![2, 2]),
            },
        );
        let identity_id = add_node("output", NodeType::Operation("Identity".to_string()));

        graph
            .add_edge(
                placeholder_id,
                identity_id,
                0,
                0,
                DType::Float32,
                Shape::new(vec![2, 2]),
                false,
            )
            .expect("test: operation should succeed");

        let mut session = create_session(Arc::new(RwLock::new(graph)), None, None);

        let input_tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])
            .expect("test: from_vec should succeed");
        let feed_dict = FeedDict::from([("input".to_string(), input_tensor.clone())]);

        let results = session
            .run(&[FetchSpec::Name("output".to_string())], &feed_dict)
            .expect("test: run should succeed");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].shape(), input_tensor.shape());
    }

    /// Two placeholders feeding an Add node yield the element-wise sum.
    #[test]
    fn test_addition_execution() {
        let mut graph = Graph::new();

        let mut add_node = |name: &str, op_type: NodeType| {
            graph
                .add_node(name.to_string(), op_type, Device::Cpu, HashMap::new())
                .expect("test: operation should succeed")
        };
        let vector_placeholder = || NodeType::Placeholder {
            dtype: DType::Float32,
            shape: Shape::new(vec![2]),
        };
        let input1_id = add_node("input1", vector_placeholder());
        let input2_id = add_node("input2", vector_placeholder());
        let add_id = add_node("add", NodeType::Operation("Add".to_string()));

        graph
            .add_edge(
                input1_id,
                add_id,
                0,
                0,
                DType::Float32,
                Shape::new(vec![2]),
                false,
            )
            .expect("test: operation should succeed");
        graph
            .add_edge(
                input2_id,
                add_id,
                0,
                1,
                DType::Float32,
                Shape::new(vec![2]),
                false,
            )
            .expect("operation should succeed");

        let mut session = create_session(Arc::new(RwLock::new(graph)), None, None);

        let lhs = Tensor::<f32>::from_vec(vec![1.0, 2.0], &[2]).expect("from_vec should succeed");
        let rhs = Tensor::<f32>::from_vec(vec![3.0, 4.0], &[2]).expect("from_vec should succeed");
        let feed_dict = FeedDict::from([
            ("input1".to_string(), lhs),
            ("input2".to_string(), rhs),
        ]);

        let results = session
            .run(&[FetchSpec::Name("add".to_string())], &feed_dict)
            .expect("run should succeed");

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].shape(), &Shape::new(vec![2]));

        let sums = results[0]
            .as_slice()
            .expect("Failed to get tensor slice");
        assert!((sums[0] - 4.0).abs() < 1e-6);
        assert!((sums[1] - 6.0).abs() < 1e-6);
    }

    /// A closed session reports itself as closed and refuses to run.
    #[test]
    fn test_session_close() {
        let shared_graph = Arc::new(RwLock::new(Graph::new()));
        let mut session = create_session(shared_graph, None, None);

        session.close().expect("test: close should succeed");
        assert!(session.closed);

        let outcome = session.run(&[], &FeedDict::new());
        assert!(outcome.is_err());
    }
}