1use anyhow::anyhow;
2use graph::{NAdd, NMul, NetworkEdge, NetworkGraph, NetworkLeaf, NetworkNode, NetworkOp};
3use linnet::half_edge::NodeIndex;
4use serde::{Deserialize, Serialize};
5
6use library::{Library, LibraryError};
7
8use crate::algebra::algebraic_traits::RefOne;
9use crate::contraction::Contract;
10use crate::network::library::{DummyKey, FunctionLibrary, FunctionLibraryError, LibraryTensor};
11use crate::structure::abstract_index::AbstractIndex;
12use crate::structure::permuted::PermuteTensor;
13use crate::structure::representation::LibrarySlot;
15use crate::structure::slot::{AbsInd, IsAbstractSlot};
16use crate::structure::{HasName, PermutedStructure, StructureError};
17use std::borrow::Cow;
18use std::fmt::Display;
19use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
20use store::{NetworkStore, TensorScalarStore, TensorScalarStoreMapping};
21use thiserror::Error;
22use crate::{
25 contraction::ContractionError,
26 structure::{CastStructure, HasStructure, ScalarTensor, TensorStructure},
27};
28
29use std::{convert::Infallible, fmt::Debug};
32
/// A tensor network: a graph describing how leaves (library keys, local
/// tensors, scalars) are combined by operations, together with the store that
/// owns the local tensor and scalar data the graph's leaves index into.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    bincode_trait_derive::Encode,
    bincode_trait_derive::Decode,
    PartialEq,
    Eq,
)]
#[cfg_attr(
    feature = "shadowing",
    trait_decode(trait = symbolica::state::HasStateMap),
)]
pub struct Network<S, LibKey, FunKey, Aind = AbstractIndex> {
    /// Graph of operation nodes and leaves; `LocalTensor`/`Scalar` leaves hold
    /// indices into `store`.
    pub graph: NetworkGraph<LibKey, FunKey, Aind>,
    /// Storage for the local tensors and scalars referenced by `graph`.
    pub store: S,
    /// Coarse classification of what the network evaluates to.
    pub state: NetworkState,
}
52
/// Coarse classification of the value a [`Network`] represents.
#[derive(
    Debug,
    Clone,
    Serialize,
    Deserialize,
    bincode_trait_derive::Encode,
    bincode_trait_derive::Decode,
    PartialEq,
    Eq,
    Copy,
)]
pub enum NetworkState {
    /// Built only from store scalars (`from_scalar`, `one`, `zero`).
    PureScalar,
    /// A tensor that is not fully self-dual.
    Tensor,
    /// A tensor whose structure is fully self-dual.
    SelfDualTensor,
    /// A scalar obtained from tensor data: a structure with `is_scalar()`, a
    /// fully contracted product (see `NMul`), or an even power of a
    /// self-dual tensor.
    Scalar,
}
70
71impl NetworkState {
72 pub fn is_scalar(&self) -> bool {
73 matches!(self, NetworkState::Scalar | NetworkState::PureScalar)
74 }
75
76 pub fn is_tensor(&self) -> bool {
77 matches!(self, NetworkState::Tensor | NetworkState::SelfDualTensor)
78 }
79
80 pub fn pow(self, pow: i8) -> Self {
81 match self {
82 NetworkState::PureScalar => NetworkState::PureScalar,
83 NetworkState::Scalar => NetworkState::Scalar,
84 NetworkState::SelfDualTensor => {
85 if pow % 2 == 0 {
86 NetworkState::Scalar
87 } else {
88 NetworkState::SelfDualTensor
89 }
90 }
91 NetworkState::Tensor => panic!("Cannot have integer power of non-self dual tensor"),
92 }
93 }
94
95 pub fn is_compatible(&self, other: &Self) -> bool {
96 matches!(
97 (self, other),
98 (NetworkState::Tensor, NetworkState::Tensor)
99 | (NetworkState::SelfDualTensor, NetworkState::SelfDualTensor)
100 | (NetworkState::Scalar, NetworkState::Scalar)
101 | (NetworkState::PureScalar, NetworkState::Scalar)
102 | (NetworkState::Scalar, NetworkState::PureScalar)
103 | (NetworkState::PureScalar, NetworkState::PureScalar)
104 )
105 }
106}
107impl MulAssign for NetworkState {
108 fn mul_assign(&mut self, rhs: Self) {
109 *self = match (*self, rhs) {
111 (NetworkState::PureScalar, NetworkState::PureScalar) => NetworkState::PureScalar,
112 (NetworkState::PureScalar, NetworkState::Scalar) => NetworkState::Scalar,
113 (NetworkState::Tensor, _) => NetworkState::Tensor,
114 (NetworkState::Scalar, NetworkState::PureScalar) => NetworkState::Scalar,
115 (_, NetworkState::Tensor) => NetworkState::Tensor,
116 (NetworkState::SelfDualTensor, _) => NetworkState::SelfDualTensor,
117 (_, NetworkState::SelfDualTensor) => NetworkState::SelfDualTensor,
118 (NetworkState::Scalar, NetworkState::Scalar) => NetworkState::Scalar,
119 }
120 }
121}
122
123impl AddAssign for NetworkState {
124 fn add_assign(&mut self, rhs: Self) {
125 *self = match (*self, rhs) {
127 (NetworkState::PureScalar, NetworkState::PureScalar) => NetworkState::PureScalar,
128 (NetworkState::PureScalar, NetworkState::Scalar) => NetworkState::Scalar,
129 (a, b) => {
130 assert_eq!(a, b, "Cannot add incompatible network states:{a:?} + {b:?}");
131 a
132 }
133 }
134 }
135}
136
137pub mod graph;
149pub mod library;
150pub mod set;
151pub mod store;
152
153impl<S: TensorScalarStoreMapping, K: Clone, FK: Clone, Aind: AbsInd> TensorScalarStoreMapping
154 for Network<S, K, FK, Aind>
155{
156 type Store<U, V> = Network<S::Store<U, V>, K, FK, Aind>;
157 type Scalar = S::Scalar;
158 type Tensor = S::Tensor;
159
160 fn iter_scalars(&self) -> impl Iterator<Item = &Self::Scalar> {
161 self.store.iter_scalars()
162 }
163
164 fn iter_tensors(&self) -> impl Iterator<Item = &Self::Tensor> {
165 self.store.iter_tensors()
166 }
167
168 fn iter_scalars_mut(&mut self) -> impl Iterator<Item = &mut Self::Scalar> {
169 self.store.iter_scalars_mut()
170 }
171 fn iter_tensors_mut(&mut self) -> impl Iterator<Item = &mut Self::Tensor> {
172 self.store.iter_tensors_mut()
173 }
174
175 fn map<U, V>(
176 self,
177 scalar_map: impl FnMut(Self::Scalar) -> U,
178 tensor_map: impl FnMut(Self::Tensor) -> V,
179 ) -> Self::Store<V, U> {
180 Network {
181 store: self.store.map(scalar_map, tensor_map),
182 graph: self.graph,
183 state: self.state,
184 }
185 }
186
187 fn map_result<U, V, Er>(
188 self,
189 scalar_map: impl FnMut(Self::Scalar) -> Result<U, Er>,
190 tensor_map: impl FnMut(Self::Tensor) -> Result<V, Er>,
191 ) -> Result<Self::Store<V, U>, Er> {
192 Ok(Network {
193 store: self.store.map_result(scalar_map, tensor_map)?,
194 graph: self.graph,
195 state: self.state,
196 })
197 }
198
199 fn map_ref<'a, U, V>(
200 &'a self,
201 scalar_map: impl FnMut(&'a Self::Scalar) -> U,
202 tensor_map: impl FnMut(&'a Self::Tensor) -> V,
203 ) -> Self::Store<V, U> {
204 Network {
205 store: self.store.map_ref(scalar_map, tensor_map),
206 graph: self.graph.clone(),
207 state: self.state,
208 }
209 }
210
211 fn map_ref_result<U, V, Er>(
212 &self,
213 scalar_map: impl FnMut(&Self::Scalar) -> Result<U, Er>,
214 tensor_map: impl FnMut(&Self::Tensor) -> Result<V, Er>,
215 ) -> Result<Self::Store<V, U>, Er> {
216 Ok(Network {
217 store: self.store.map_ref_result(scalar_map, tensor_map)?,
218 graph: self.graph.clone(),
219 state: self.state,
220 })
221 }
222
223 fn map_ref_enumerate<U, V>(
224 &self,
225 scalar_map: impl FnMut((usize, &Self::Scalar)) -> U,
226 tensor_map: impl FnMut((usize, &Self::Tensor)) -> V,
227 ) -> Self::Store<V, U> {
228 Network {
229 store: self.store.map_ref_enumerate(scalar_map, tensor_map),
230 graph: self.graph.clone(),
231 state: self.state,
232 }
233 }
234
235 fn map_ref_result_enumerate<U, V, Er>(
236 &self,
237 scalar_map: impl FnMut((usize, &Self::Scalar)) -> Result<U, Er>,
238 tensor_map: impl FnMut((usize, &Self::Tensor)) -> Result<V, Er>,
239 ) -> Result<Self::Store<V, U>, Er> {
240 Ok(Network {
241 store: self
242 .store
243 .map_ref_result_enumerate(scalar_map, tensor_map)?,
244 graph: self.graph.clone(),
245 state: self.state,
246 })
247 }
248
249 fn map_ref_mut<U, V>(
250 &mut self,
251 scalar_map: impl FnMut(&mut Self::Scalar) -> U,
252 tensor_map: impl FnMut(&mut Self::Tensor) -> V,
253 ) -> Self::Store<V, U> {
254 Network {
255 store: self.store.map_ref_mut(scalar_map, tensor_map),
256 graph: self.graph.clone(),
257 state: self.state,
258 }
259 }
260
261 fn map_ref_mut_result<U, V, Er>(
262 &mut self,
263 scalar_map: impl FnMut(&mut Self::Scalar) -> Result<U, Er>,
264 tensor_map: impl FnMut(&mut Self::Tensor) -> Result<V, Er>,
265 ) -> Result<Self::Store<V, U>, Er> {
266 Ok(Network {
267 store: self.store.map_ref_mut_result(scalar_map, tensor_map)?,
268 graph: self.graph.clone(),
269 state: self.state,
270 })
271 }
272
273 fn map_ref_mut_enumerate<U, V>(
274 &mut self,
275 scalar_map: impl FnMut((usize, &mut Self::Scalar)) -> U,
276 tensor_map: impl FnMut((usize, &mut Self::Tensor)) -> V,
277 ) -> Self::Store<V, U> {
278 Network {
279 store: self.store.map_ref_mut_enumerate(scalar_map, tensor_map),
280 graph: self.graph.clone(),
281 state: self.state,
282 }
283 }
284
285 fn map_ref_mut_result_enumerate<U, V, Er>(
286 &mut self,
287 scalar_map: impl FnMut((usize, &mut Self::Scalar)) -> Result<U, Er>,
288 tensor_map: impl FnMut((usize, &mut Self::Tensor)) -> Result<V, Er>,
289 ) -> Result<Self::Store<V, U>, Er> {
290 Ok(Network {
291 store: self
292 .store
293 .map_ref_mut_result_enumerate(scalar_map, tensor_map)?,
294 graph: self.graph.clone(),
295 state: self.state,
296 })
297 }
298}
299
300impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Default for Network<S, K, FK, Aind> {
301 fn default() -> Self {
302 Self::one()
303 }
304}
305
306impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> NMul for Network<S, K, FK, Aind> {
307 type Output = Self;
308 fn n_mul<I: IntoIterator<Item = Self>>(self, iter: I) -> Self::Output {
309 let mut store = self.store;
310 let mut state = self.state;
311
312 let items = iter.into_iter().map(|mut a| {
313 a.graph.shift_scalars(store.n_scalars());
314 a.graph.shift_tensors(store.n_tensors());
315 store.extend(a.store);
316
317 state *= a.state;
318 a.graph
319 });
320
321 let graph = self.graph.n_mul(items);
322
323 if state.is_tensor() && graph.dangling_indices().is_empty() {
324 state = NetworkState::Scalar;
325 }
326
327 Network {
328 graph,
329 store,
330 state,
331 }
332 }
333}
334
335impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Mul for Network<S, K, FK, Aind> {
336 type Output = Self;
337 fn mul(mut self, mut other: Self) -> Self::Output {
338 let mut store = self.store;
339
340 other.graph.shift_scalars(store.n_scalars());
341 other.graph.shift_tensors(store.n_tensors());
342 store.extend(other.store);
343 self.state *= other.state;
344
345 Network {
346 graph: self.graph * other.graph,
347 store,
348 state: self.state,
349 }
350 }
351}
352
353impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> MulAssign
354 for Network<S, K, FK, Aind>
355{
356 fn mul_assign(&mut self, mut rhs: Self) {
357 rhs.graph.shift_scalars(self.store.n_scalars());
358 rhs.graph.shift_tensors(self.store.n_tensors());
359 self.store.extend(rhs.store);
360 self.state *= rhs.state;
361 self.graph *= rhs.graph;
362 }
363}
364
365impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> MulAssign<T>
366 for Network<NetworkStore<T, S>, K, FK, Aind>
367where
368 T::Slot: IsAbstractSlot<Aind = Aind>,
369{
370 fn mul_assign(&mut self, rhs: T) {
371 *self *= Network::from_tensor(rhs);
372 }
373}
374
375impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> Mul<T>
376 for Network<NetworkStore<T, S>, K, FK, Aind>
377where
378 T::Slot: IsAbstractSlot<Aind = Aind>,
379{
380 type Output = Self;
381 fn mul(self, other: T) -> Self::Output {
382 let mut store = self.store;
383
384 let mut other = Network::from_tensor(other);
385
386 other.graph.shift_scalars(store.n_scalars());
387 other.graph.shift_tensors(store.n_tensors());
388 store.extend(other.store);
389
390 Network {
391 graph: self.graph * other.graph,
392 store,
393 state: self.state,
394 }
395 }
396}
397
398impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Add for Network<S, K, FK, Aind> {
399 type Output = Self;
400 fn add(self, mut other: Self) -> Self::Output {
401 let mut store = self.store;
402
403 other.graph.shift_scalars(store.n_scalars());
404 other.graph.shift_tensors(store.n_tensors());
405 store.extend(other.store);
406
407 Network {
408 graph: self.graph + other.graph,
409 store,
410 state: self.state,
411 }
412 }
413}
414
415impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> AddAssign
416 for Network<S, K, FK, Aind>
417{
418 fn add_assign(&mut self, mut rhs: Self) {
419 rhs.graph.shift_scalars(self.store.n_scalars());
420 rhs.graph.shift_tensors(self.store.n_tensors());
421 self.store.extend(rhs.store);
422
423 self.graph += rhs.graph;
424 }
425}
426
427impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> AddAssign<T>
428 for Network<NetworkStore<T, S>, K, FK, Aind>
429where
430 T::Slot: IsAbstractSlot<Aind = Aind>,
431{
432 fn add_assign(&mut self, rhs: T) {
433 *self += Network::from_tensor(rhs);
434 }
435}
436
437impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> Add<T>
438 for Network<NetworkStore<T, S>, K, FK, Aind>
439where
440 T::Slot: IsAbstractSlot<Aind = Aind>,
441{
442 type Output = Self;
443 fn add(mut self, other: T) -> Self::Output {
444 self += other;
445 self
446 }
447}
448
449impl<T: TensorStructure, FK: Debug, K: Debug, Aind: AbsInd> Add<i8>
450 for Network<NetworkStore<T, i8>, K, FK, Aind>
451where
452 T::Slot: IsAbstractSlot<Aind = Aind>,
453{
454 type Output = Self;
455 fn add(mut self, other: i8) -> Self::Output {
456 let mut other = Network::from_scalar(other);
457 other.graph.shift_tensors(self.store.n_tensors());
458 other.graph.shift_tensors(self.store.n_tensors());
459
460 self.store.extend(other.store);
461 Network {
462 graph: self.graph + other.graph,
463 store: self.store,
464 state: self.state,
465 }
466 }
467}
468
469impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> NAdd for Network<S, K, FK, Aind> {
470 type Output = Self;
471 fn n_add<I: IntoIterator<Item = Self>>(self, iter: I) -> Self::Output {
472 let mut store = self.store;
473
474 let mut state = self.state;
475
476 let items = iter.into_iter().map(|mut a| {
477 a.graph.shift_scalars(store.n_scalars());
478 a.graph.shift_tensors(store.n_tensors());
479 store.extend(a.store);
480
481 state += a.state;
482 a.graph
483 });
484
485 Network {
486 graph: self.graph.n_add(items),
487 store,
488 state,
489 }
490 }
491}
492
493impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Neg
494 for Network<S, K, FK, Aind>
495{
496 type Output = Self;
497 fn neg(self) -> Self::Output {
498 Self {
499 store: self.store,
500 graph: self.graph.neg(),
501 state: self.state,
502 }
503 }
504}
505
506impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Sub
507 for Network<S, K, FK, Aind>
508{
509 type Output = Self;
510 fn sub(mut self, rhs: Self) -> Self::Output {
511 self -= rhs;
512 self
513 }
514}
515
516impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> SubAssign
517 for Network<S, K, FK, Aind>
518{
519 fn sub_assign(&mut self, mut rhs: Self) {
520 rhs.graph.shift_scalars(self.store.n_scalars());
521 rhs.graph.shift_tensors(self.store.n_tensors());
522 self.store.extend(rhs.store);
523
524 self.graph -= rhs.graph
525 }
526}
527
528impl<T: TensorStructure, S, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> SubAssign<T>
529 for Network<NetworkStore<T, S>, K, FK, Aind>
530where
531 T::Slot: IsAbstractSlot<Aind = Aind>,
532{
533 fn sub_assign(&mut self, rhs: T) {
534 *self -= Network::from_tensor(rhs)
535 }
536}
537
538impl<T: TensorStructure, S, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Sub<T>
539 for Network<NetworkStore<T, S>, K, FK, Aind>
540where
541 T::Slot: IsAbstractSlot<Aind = Aind>,
542{
543 type Output = Self;
544 fn sub(mut self, other: T) -> Self::Output {
545 self -= other;
546 self
547 }
548}
549
550impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Network<S, K, FK, Aind> {
551 pub fn pow(self, pow: i8) -> Self {
552 Self {
553 store: self.store,
554 graph: self.graph.pow(pow),
555 state: self.state.pow(pow),
556 }
557 }
558 pub fn fun(self, key: FK) -> Self {
559 let graph = self.graph.function(key);
560 Self {
561 store: self.store,
562 state: graph.state(),
563 graph,
564 }
565 }
566
567 pub fn from_scalar(scalar: S::Scalar) -> Self {
568 let mut store = S::default();
569 let id = store.add_scalar(scalar);
570 Network {
571 graph: NetworkGraph::scalar(id),
572 store,
573
574 state: NetworkState::PureScalar,
575 }
576 }
577
578 pub fn merge_ops(&mut self)
579 where
580 K: Clone,
581 {
582 self.graph.merge_ops();
583 }
584
585 pub fn from_tensor(tensor: S::Tensor) -> Self
586 where
587 S::Tensor: TensorStructure,
588 <S::Tensor as TensorStructure>::Slot: IsAbstractSlot<Aind = Aind>,
589 {
590 let mut store = S::default();
591
592 let state = if tensor.is_scalar() {
593 NetworkState::Scalar
594 } else if tensor.is_fully_self_dual() {
595 NetworkState::SelfDualTensor
597 } else {
598 NetworkState::Tensor
599 };
600 let id = store.add_tensor(tensor);
601
602 Network {
603 graph: NetworkGraph::tensor(store.get_tensor(id), NetworkLeaf::LocalTensor(id)),
604 store,
605 state,
606 }
607 }
608
609 pub fn library_tensor<T>(tensor: &T, key: PermutedStructure<K>) -> Self
610 where
611 T: TensorStructure,
612 T::Slot: IsAbstractSlot<Aind = Aind>,
613 {
614 let state = if tensor.is_scalar() {
615 NetworkState::Scalar
616 } else if tensor.is_fully_self_dual() {
617 NetworkState::SelfDualTensor
618 } else {
619 NetworkState::Tensor
620 };
621 Network {
622 graph: NetworkGraph::tensor(tensor, NetworkLeaf::LibraryKey(key)),
623 store: S::default(),
624 state,
625 }
626 }
627
628 pub fn one() -> Self {
629 Network {
630 graph: NetworkGraph::one(),
631 store: S::default(),
632 state: NetworkState::PureScalar,
633 }
634 }
635
636 pub fn zero() -> Self {
637 Network {
638 graph: NetworkGraph::zero(),
639 store: S::default(),
640 state: NetworkState::PureScalar,
641 }
642 }
643}
644
645#[derive(Error, Debug)]
646pub enum TensorNetworkError<K: Display, FK: Display> {
647 #[error("Slot edge to prod node")]
648 SlotEdgeToProdNode,
649 #[error("Slot edge to scalar node")]
650 SlotEdgeToScalarNode,
651 #[error("More than one neg")]
652 MoreThanOneNeg,
653 #[error("Childless neg")]
654 ChildlessNeg,
655 #[error("Contraction Error:{0}")]
656 ContractionError(#[from] ContractionError),
657 #[error("Scalar connected by a slot edge")]
658 ScalarSlotEdge,
659 #[error("Structure Error:{0}")]
660 StructErr(#[from] StructureError),
661 #[error("LibraryError:{0}")]
662 LibErr(#[from] LibraryError<K>),
663 #[error("FunctionLibraryError:{0}")]
664 FunLibErr(#[from] FunctionLibraryError<FK>),
665 #[error("Non tensor node still present")]
666 NonTensorNodePresent,
667 #[error("Negative non-even power on non-scalar node:{0}")]
668 NegativeExponentNonScalar(String),
669 #[error("Too many arguments for function:{0}")]
670 TooManyArgsFunction(String),
671 #[error("Non self-dual tensor power{0}")]
672 InvalidDotFunction(String),
673 #[error("Invalid dot function{0}")]
674 NonSelfDualTensorPower(String),
675 #[error("invalid resulting node{0}")]
676 InvalidResultNode(NetworkNode<DummyKey, FK>),
677 #[error("internal edge still present, contract it first")]
678 InternalEdgePresent,
679 #[error("uncontracted scalar")]
680 UncontractedScalar,
681 #[error("Cannot contract edge between {0} and {1}")]
682 CannotContractEdgeBetween(NetworkNode<K, FK>, NetworkNode<K, FK>),
683 #[error("no nodes in the graph")]
684 NoNodes,
685 #[error("no scalar present")]
686 NoScalar,
687 #[error("more than one node in the graph")]
688 MoreThanOneNode,
689 #[error("is not scalar output")]
690 NotScalarOutput,
691 #[error("failed scalar multiplication")]
692 FailedScalarMul,
693 #[error("scalar field is empty")]
694 ScalarFieldEmpty,
695 #[error("not all scalars: {0}")]
696 NotAllScalars(String),
697 #[error("try to sum scalar with library tensor: {0}")]
698 ScalarLibSum(String),
699 #[error("try to sum scalar with a tensor: {0}")]
700 SumScalarTensor(String),
701 #[error("Incompatible summands: {0}")]
702 IncompatibleSummand(String),
703 #[error("failed to contract")]
704 FailedContract(ContractionError),
705 #[error("negative exponent not yet supported")]
706 NegativeExponent,
707 #[error("failed to contract: {0}")]
708 FailedContractMsg(String),
709 #[error(transparent)]
710 Other(#[from] anyhow::Error),
711 #[error("Io error")]
712 InOut(#[from] std::io::Error),
713 #[error("Infallible")]
714 Infallible,
715}
716
717impl<K: Display, FK: Display> From<Infallible> for TensorNetworkError<K, FK> {
718 fn from(_: Infallible) -> Self {
719 TensorNetworkError::Infallible
720 }
721}
722
/// Payload of a network result, as produced by `Network::result`.
pub enum TensorOrScalarOrKey<T, S, K, Aind> {
    /// A tensor held in the network's store, with the slots reported by
    /// `NetworkGraph::result` for the result node.
    Tensor {
        tensor: T,
        graph_slots: Vec<LibrarySlot<Aind>>,
    },
    /// A scalar held in the network's store.
    Scalar(S),
    /// A library key; `nodeid` is the graph node holding it (used by
    /// `result_tensor` to look up the library data).
    Key {
        key: K,
        nodeid: NodeIndex,
    },
}
734
/// Outcome of reading out a network's result.
pub enum ExecutionResult<T> {
    /// The result node is a bare `Product` op (presumably an empty product).
    One,
    /// The result node is a bare `Sum` op (presumably an empty sum).
    Zero,
    /// An actual value.
    Val(T),
}
740
741impl<T: Display> Display for ExecutionResult<T> {
742 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
743 match self {
744 ExecutionResult::One => write!(f, "One"),
745 ExecutionResult::Zero => write!(f, "Zero"),
746 ExecutionResult::Val(val) => write!(f, "{}", val),
747 }
748 }
749}
750
/// Read-out and validation helpers for networks whose store holds tensors `T`
/// and scalars `S`.
impl<
    T: TensorStructure,
    S,
    K: Display + Debug,
    FK: Display + Debug,
    Str: TensorScalarStore<Tensor = T, Scalar = S>,
    Aind: AbsInd,
> Network<Str, K, FK, Aind>
where
    T::Slot: IsAbstractSlot<Aind = Aind>,
{
    /// Checks that every leaf's graph slots carry the same representations as
    /// the structure the leaf refers to.
    ///
    /// Library-key leaves are compared against the key's `structure`, and
    /// local-tensor leaves against the stored tensor. All other nodes are
    /// skipped.
    ///
    /// # Panics
    /// Panics (via `assert_eq!`) on the first leaf whose representations differ.
    pub fn validate(&self)
    where
        K: TensorStructure,
    {
        for (n, _neigh, v) in self.graph.graph.iter_nodes() {
            match v {
                NetworkNode::Leaf(NetworkLeaf::LibraryKey(k)) => {
                    // Representations seen on this node's slot edges in the graph.
                    let reps = self
                        .graph
                        .slots(n)
                        .into_iter()
                        .map(|s| s.rep())
                        .collect::<Vec<_>>();
                    // Representations declared by the key's structure, converted
                    // with `to_lib` so they are comparable with the graph's.
                    let n_reps = k
                        .structure
                        .external_reps_iter()
                        .map(|r| r.to_lib())
                        .collect::<Vec<_>>();
                    assert_eq!(n_reps, reps);
                }
                NetworkNode::Leaf(NetworkLeaf::LocalTensor(k)) => {
                    let reps = self
                        .graph
                        .slots(n)
                        .into_iter()
                        .map(|s| s.rep())
                        .collect::<Vec<_>>();
                    // Representations of the tensor stored at index `k`.
                    let n_reps = self
                        .store
                        .get_tensor(*k)
                        .external_reps_iter()
                        .map(|r| r.to_lib())
                        .collect::<Vec<_>>();
                    assert_eq!(n_reps, reps);
                }
                _ => {}
            }
        }
    }

    /// Reads the graph's result node (from `NetworkGraph::result`) and borrows
    /// the matching data from the store.
    ///
    /// - library-key leaf → [`TensorOrScalarOrKey::Key`] with the node id;
    /// - local-tensor leaf → [`TensorOrScalarOrKey::Tensor`] plus its graph slots;
    /// - scalar leaf → [`TensorOrScalarOrKey::Scalar`];
    /// - `Product` op → [`ExecutionResult::One`], `Sum` op → [`ExecutionResult::Zero`].
    ///
    /// # Errors
    /// Propagates errors from `NetworkGraph::result` and returns
    /// [`TensorNetworkError::InvalidResultNode`] for any other op node.
    #[allow(clippy::result_large_err, clippy::type_complexity)]
    pub fn result(
        &self,
    ) -> Result<
        ExecutionResult<TensorOrScalarOrKey<&T, &S, &PermutedStructure<K>, Aind>>,
        TensorNetworkError<K, FK>,
    >
    where
        FK: Clone,
    {
        let (node, nid, graph_slots) = self.graph.result()?;

        match node {
            NetworkNode::Leaf(l) => match l {
                NetworkLeaf::LibraryKey(k) => Ok(ExecutionResult::Val(TensorOrScalarOrKey::Key {
                    key: k,
                    nodeid: nid,
                })),
                NetworkLeaf::LocalTensor(t) => {
                    Ok(ExecutionResult::Val(TensorOrScalarOrKey::Tensor {
                        tensor: self.store.get_tensor(*t),
                        graph_slots,
                    }))
                }
                NetworkLeaf::Scalar(t) => Ok(ExecutionResult::Val(TensorOrScalarOrKey::Scalar(
                    self.store.get_scalar(*t),
                ))),
            },
            NetworkNode::Op(o) => match o {
                NetworkOp::Product => Ok(ExecutionResult::One),
                NetworkOp::Sum => Ok(ExecutionResult::Zero),
                o => Err(TensorNetworkError::InvalidResultNode(NetworkNode::Op(
                    o.clone(),
                ))),
            },
        }
    }

    /// Like [`Self::result`], but always yields a tensor.
    ///
    /// Stored tensors are borrowed. Scalars are wrapped with `T::new_scalar`,
    /// and library keys are materialised through `lib` via
    /// `NetworkGraph::get_lib_data`.
    ///
    /// # Panics
    /// NOTE(review): panics (via `unwrap`) if `get_lib_data` fails for a
    /// library-key result; consider propagating this as an error instead.
    #[allow(clippy::result_large_err)]
    pub fn result_tensor<'a, LT, L: Library<T::Structure, Key = K, Value = PermutedStructure<LT>>>(
        &'a self,
        lib: &L,
    ) -> Result<ExecutionResult<Cow<'a, T>>, TensorNetworkError<K, FK>>
    where
        S: 'a,
        T: Clone + ScalarTensor + HasStructure,
        K: Display + Debug,
        FK: Display + Debug + Clone,
        LT: TensorStructure<Indexed = T> + Clone + LibraryTensor<WithIndices = T>,
        T: PermuteTensor<Permuted = T>,
        for<'b> &'b S: Into<T::Scalar>,
        <<LT::WithIndices as HasStructure>::Structure as TensorStructure>::Slot:
            IsAbstractSlot<Aind = Aind>,
    {
        Ok(match self.result()? {
            ExecutionResult::One => ExecutionResult::One,
            ExecutionResult::Zero => ExecutionResult::Zero,
            ExecutionResult::Val(v) => ExecutionResult::Val(match v {
                TensorOrScalarOrKey::Tensor { tensor, .. } => Cow::Borrowed(tensor),
                TensorOrScalarOrKey::Scalar(s) => Cow::Owned(T::new_scalar(s.into())),
                TensorOrScalarOrKey::Key { nodeid, .. } => {
                    let less = self.graph.get_lib_data(lib, nodeid).unwrap();

                    Cow::Owned(less)
                }
            }),
        })
    }

    /// Like [`Self::result`], but always yields a scalar.
    ///
    /// Stored scalars are borrowed. For a tensor result, the tensor is cloned
    /// and its `scalar()` is converted into `S`.
    ///
    /// # Errors
    /// [`TensorNetworkError::NoScalar`] when the tensor has no scalar value or
    /// the result is a library key. Errors from [`Self::result`] are propagated.
    #[allow(clippy::result_large_err)]
    pub fn result_scalar<'a>(
        &'a self,
    ) -> Result<ExecutionResult<Cow<'a, S>>, TensorNetworkError<K, FK>>
    where
        T: Clone + ScalarTensor + 'a,
        T::Scalar: Into<S>,
        K: Display,
        FK: Display + Clone,
        S: Clone,
    {
        Ok(match self.result()? {
            ExecutionResult::One => ExecutionResult::One,
            ExecutionResult::Zero => ExecutionResult::Zero,
            ExecutionResult::Val(v) => ExecutionResult::Val(match v {
                TensorOrScalarOrKey::Tensor { tensor: t, .. } => Cow::Owned(
                    t.clone()
                        .scalar()
                        .ok_or(TensorNetworkError::NoScalar)?
                        .into(),
                ),
                TensorOrScalarOrKey::Scalar(s) => Cow::Borrowed(s),
                TensorOrScalarOrKey::Key { .. } => return Err(TensorNetworkError::NoScalar),
            }),
        })
    }

    /// Converts every stored tensor with `cast_structure`. Scalars are kept as
    /// they are, and the graph and state are carried over.
    pub fn cast<U>(self) -> Network<Str::Store<U, S>, K, FK, Aind>
    where
        K: Clone,
        FK: Clone,
        T: CastStructure<U> + HasStructure,
        T::Structure: TensorStructure,
        U: HasStructure,
        U::Structure: From<T::Structure> + TensorStructure<Slot = T::Slot>,
    {
        self.map(|a| a, |t| t.cast_structure())
    }
}
914
/// Short textual rendering used for graph labels, e.g. by
/// `Network::dot_pretty`.
pub trait StructureLessDisplay {
    /// Returns the label text. The default implementation renders nothing.
    fn display(&self) -> String {
        "".to_owned()
    }
}
920
921impl<S: HasName> StructureLessDisplay for S
922where
923 S::Args: StructureLessDisplay,
924 S::Name: Display,
925{
926 fn display(&self) -> String {
927 format!(
928 "{}({})",
929 self.name().map(|t| t.to_string()).unwrap_or_default(),
930 self.args().map(|t| t.display()).unwrap_or_default()
931 )
932 }
933}
934
935impl<S, K: Display + Debug, FK: Display + Debug, Aind: AbsInd> Network<S, K, FK, Aind> {
936 pub fn dot(&self) -> std::string::String {
937 self.graph.dot()
938 }
939
940 pub fn dot_pretty(&self) -> std::string::String
941 where
942 S: TensorScalarStore,
943 S::Scalar: Display,
944 K: StructureLessDisplay,
945 S::Tensor: StructureLessDisplay,
946 {
947 self.graph.dot_impl(
948 |i| {
949 let ss = &self.store.get_scalar(i);
950 format!("{}:{}", i, ss)
951 },
952 |k| k.display(),
953 |t| {
954 let tt = &self.store.get_tensor(t);
955 tt.display()
956 },
957 |fk| fk.to_string(),
958 )
959 }
960}
961
962impl<T, S, FK: Debug, K: Debug, Aind: AbsInd> Network<NetworkStore<T, S>, K, FK, Aind> {
963 pub fn dot_display_impl(
964 &self,
965 scalar_disp: impl Fn(&S) -> String,
966 library_disp: impl Fn(&K) -> Option<String>,
967 tensor_disp: impl Fn(&T) -> String,
968 function_disp: impl Fn(&FK) -> String,
969 ) -> std::string::String {
970 self.graph.graph.dot_impl(
971 &self.graph.graph.full_filter(),
972 "",
973 &|_| None,
974 &|e| {
975 if let NetworkEdge::Slot(s) = e {
976 Some(format!("label=\"{s}\""))
977 } else {
978 None
979 }
980 },
981 &|n| match n {
982 NetworkNode::Leaf(l) => match l {
983 NetworkLeaf::LibraryKey(l) => {
984 Some(format!("label = \"L:{}\"", library_disp(&l.structure)?))
986 }
990 NetworkLeaf::LocalTensor(l) => Some(format!(
991 "label = \"T:{}\"",
992 tensor_disp(self.store.get_tensor(*l))
993 )),
994 NetworkLeaf::Scalar(s) => Some(format!(
995 "label = \"S:{}\"",
996 scalar_disp(self.store.get_scalar(*s))
997 )),
998 },
999 NetworkNode::Op(o) => {
1000 Some(format!("label = \"{}\"", o.display_with(&function_disp)))
1001 }
1002 },
1003 )
1004 }
1006}
1007
1008#[cfg(feature = "shadowing")]
1010pub mod parsing;
1011pub mod contract;
1013pub use contract::{
1014 ContractScalars, ContractionStrategy, SingleSmallestDegree, SmallestDegree, SmallestDegreeIter,
1015};
/// Drives the execution of ready operations in a [`NetworkGraph`] using the
/// executor `E` (see the `Sequential`, `Steps` and `StepsDebug` implementors).
pub trait ExecutionStrategy<E, FL, L, K, FK, Aind>
where
    E: ExecuteOp<FL, L, K, FK, Aind>,
{
    /// Executes operations on `graph` according to the strategy, with
    /// contraction strategy `C`, library `lib`, and function library `fnlib`.
    ///
    /// # Errors
    /// Returns the first error reported by the executor.
    #[allow(clippy::result_large_err)]
    fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
        executor: &mut E,
        graph: &mut NetworkGraph<K, FK, Aind>,
        lib: &L,
        fnlib: &FL,
    ) -> Result<(), TensorNetworkError<K, FK>>
    where
        K: Display,
        FK: Display;
}
1032
/// Execution strategy that keeps executing ready operations until none remain.
pub struct Sequential;

/// Execution strategy that performs at most `N` extract-and-execute steps.
pub struct Steps<const N: usize> {}
/// Like `Steps`, but prints DOT renderings of the graphs at every step.
pub struct StepsDebug<const N: usize> {}
1037
1038impl<const N: usize, E, L, FL, FK: Debug, K: Debug, Aind: AbsInd>
1039 ExecutionStrategy<E, FL, L, K, FK, Aind> for StepsDebug<N>
1040where
1041 E: ExecuteOp<FL, L, K, FK, Aind>,
1042 K: Clone,
1043 FK: Clone,
1044{
1045 fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1046 executor: &mut E,
1047 graph: &mut NetworkGraph<K, FK, Aind>,
1048 lib: &L,
1049 fnlib: &FL,
1050 ) -> Result<(), TensorNetworkError<K, FK>>
1051 where
1052 K: Display,
1053 FK: Display,
1054 {
1055 for _ in 0..N {
1056 if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1058 println!(
1059 "Extracted_graph: {}",
1060 extracted_graph.dot_impl(
1061 |s| s.to_string(),
1062 |_| "".to_string(),
1063 |s| s.to_string(),
1064 |_| "".to_string()
1065 )
1066 );
1067 println!(
1068 "Graph: {}",
1069 graph.dot_impl(
1070 |s| s.to_string(),
1071 |_| "".to_string(),
1072 |s| s.to_string(),
1073 |_| "".to_string()
1074 )
1075 );
1076 let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1078 println!(
1079 "Replacement Graph: {}",
1080 replacement.dot_impl(
1081 |s| s.to_string(),
1082 |_| "".to_string(),
1083 |s| s.to_string(),
1084 |_| "".to_string()
1085 )
1086 );
1087
1088 graph.splice_descendents_of(replacement);
1089 }
1090 }
1091
1092 Ok(())
1093 }
1094}
1095
1096impl<const N: usize, E, FL, L, K, FK, Aind: AbsInd> ExecutionStrategy<E, FL, L, K, FK, Aind>
1097 for Steps<N>
1098where
1099 E: ExecuteOp<FL, L, K, FK, Aind>,
1100 K: Clone + Debug,
1101 FK: Clone + Debug,
1102{
1103 fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1104 executor: &mut E,
1105 graph: &mut NetworkGraph<K, FK, Aind>,
1106 lib: &L,
1107 fnlib: &FL,
1108 ) -> Result<(), TensorNetworkError<K, FK>>
1109 where
1110 K: Display,
1111 FK: Display,
1112 {
1113 for _ in 0..N {
1114 if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1116 let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1118 graph.splice_descendents_of(replacement);
1119 }
1120 }
1121
1122 Ok(())
1123 }
1124}
1125
1126impl<E, L, FL, K, FK, Aind: AbsInd> ExecutionStrategy<E, FL, L, K, FK, Aind> for Sequential
1127where
1128 E: ExecuteOp<FL, L, K, FK, Aind>,
1129 FK: Clone + Debug,
1130 K: Clone + Debug,
1131{
1132 fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1133 executor: &mut E,
1134 graph: &mut NetworkGraph<K, FK, Aind>,
1135 lib: &L,
1136 fnlib: &FL,
1137 ) -> Result<(), TensorNetworkError<K, FK>>
1138 where
1139 K: Display,
1140 FK: Display,
1141 {
1142 while {
1143 if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1145 let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1147 graph.splice_descendents_of(replacement);
1148 true
1149 } else {
1150 false
1151 }
1152 } {}
1153
1154 Ok(())
1155 }
1156}
1157
/// An executor able to evaluate one extracted network operation.
pub trait ExecuteOp<FL, L, K, FK, Aind>: Sized {
    /// Executes `op` on the extracted subgraph `graph`, using the library `lib`
    /// and the function library `fn_lib`. It returns the replacement subgraph,
    /// which execution strategies splice back into the full graph.
    ///
    /// # Errors
    /// Returns a [`TensorNetworkError`] when the operation cannot be executed.
    #[allow(clippy::result_large_err)]
    fn execute<C: ContractionStrategy<Self, L, K, FK, Aind>>(
        &mut self,
        graph: NetworkGraph<K, FK, Aind>,
        lib: &L,
        fn_lib: &FL,
        op: NetworkOp<FK>,
    ) -> Result<NetworkGraph<K, FK, Aind>, TensorNetworkError<K, FK>>
    where
        K: Display,
        FK: Display;
}
1211
1212impl<S, Store: TensorScalarStore, K, FK, Aind: AbsInd> Network<Store, K, FK, Aind>
1213where
1214 Store::Tensor: HasStructure<Structure = S>,
1215{
1216 #[allow(clippy::result_large_err)]
1217 pub fn execute<
1218 Strat: ExecutionStrategy<Store, FL, L, K, FK, Aind>,
1219 C: ContractionStrategy<Store, L, K, FK, Aind>,
1220 LT,
1221 L,
1222 FL,
1223 >(
1224 &mut self,
1225 lib: &L,
1226 fn_lib: &FL,
1227 ) -> Result<(), TensorNetworkError<K, FK>>
1228 where
1229 K: Display + Clone + Debug,
1230 FK: Display + Clone + Debug,
1231 L: Library<S, Key = K, Value = PermutedStructure<LT>> + Sync,
1232 FL: FunctionLibrary<Store::Tensor, Store::Scalar, Key = FK>,
1233 LT: LibraryTensor<WithIndices = Store::Tensor>,
1234 Store: ExecuteOp<FL, L, K, FK, Aind>,
1235 {
1236 self.merge_ops();
1237 Strat::execute_all::<C>(&mut self.store, &mut self.graph, lib, fn_lib)
1241 }
1242}
1243
1244impl<
1245 LT: LibraryTensor + Clone,
1246 T: HasStructure
1247 + TensorStructure
1248 + Neg<Output = T>
1249 + Clone
1250 + Ref
1251 + Contract<LCM = T>
1252 + for<'a> AddAssign<T::Ref<'a>>
1253 + for<'a> AddAssign<LT::WithIndices>
1254 + From<LT::WithIndices>,
1255 L: Library<T::Structure, Key = K, Value = PermutedStructure<LT>>,
1256 Sc: Neg<Output = Sc>
1257 + RefOne
1258 + Div<Output = Sc>
1259 + for<'a> AddAssign<Sc::Ref<'a>>
1260 + Clone
1261 + for<'a> AddAssign<T::ScalarRef<'a>>
1262 + From<T::Scalar>
1263 + Ref
1264 + for<'a> MulAssign<Sc::Ref<'a>>,
1265 K: Display + Debug + Clone,
1266 FK: Display + Debug,
1267 FL: FunctionLibrary<T, Sc, Key = FK>,
1268 Aind: AbsInd,
1269> ExecuteOp<FL, L, K, FK, Aind> for NetworkStore<T, Sc>
1270where
1271 LT::WithIndices: PermuteTensor<Permuted = LT::WithIndices>,
1272 <<LT::WithIndices as HasStructure>::Structure as TensorStructure>::Slot:
1273 IsAbstractSlot<Aind = Aind>,
1274{
1275 fn execute<C: ContractionStrategy<Self, L, K, FK, Aind>>(
1276 &mut self,
1277 mut graph: NetworkGraph<K, FK, Aind>,
1278 lib: &L,
1279 fn_lib: &FL,
1280 op: NetworkOp<FK>,
1281 ) -> Result<NetworkGraph<K, FK, Aind>, TensorNetworkError<K, FK>> {
1282 graph.sync_order();
1283 match op {
1284 NetworkOp::Neg => {
1285 let ops = graph
1286 .graph
1287 .iter_nodes()
1288 .find(|(_, _, d)| matches!(d, NetworkNode::Op(NetworkOp::Neg)));
1289
1290 let (opid, children, _) = ops.unwrap();
1291
1292 let mut child = None;
1293 for c in children {
1294 if let Some(id) = graph.graph.involved_node_id(c)
1295 && let NetworkNode::Leaf(l) = &graph.graph[id]
1296 {
1297 if child.is_some() {
1298 return Err(TensorNetworkError::MoreThanOneNeg);
1299 } else {
1300 child = Some((id, l));
1301 }
1302 }
1303 }
1304 if let Some((child_id, leaf)) = child {
1305 let new_node = match leaf {
1306 NetworkLeaf::Scalar(s) => {
1307 let s = self.scalar[*s].clone().neg();
1308 let pos = self.scalar.len();
1309 self.scalar.push(s);
1310
1311 NetworkLeaf::Scalar(pos)
1312 }
1313 NetworkLeaf::LibraryKey(_) => {
1314 let inds = graph.get_lib_data(lib, child_id).unwrap();
1315
1316 let t = T::from(inds).neg();
1317 let pos = self.tensors.len();
1318 self.tensors.push(t);
1319 NetworkLeaf::LocalTensor(pos)
1320 }
1321 NetworkLeaf::LocalTensor(t) => {
1322 let t = self.tensors[*t].clone().neg();
1323 let pos = self.tensors.len();
1324 self.tensors.push(t);
1325 NetworkLeaf::LocalTensor(pos)
1326 }
1327 };
1328 graph.identify_nodes_without_self_edges(
1329 &[child_id, opid],
1330 NetworkNode::Leaf(new_node),
1331 );
1332 Ok(graph)
1333 } else {
1334 Err(TensorNetworkError::ChildlessNeg)
1335 }
1336 }
1337 NetworkOp::Product => {
1338 let (graph, _) = C::contract(self, graph, lib)?;
1340 Ok(graph)
1341 }
1342 NetworkOp::Sum => {
1343 let mut targets = Vec::new();
1345 let mut all_nodes = Vec::new();
1346 for (n, _, v) in graph.graph.iter_nodes() {
1347 all_nodes.push(n);
1348 if let NetworkNode::Leaf(l) = &v {
1349 targets.push((n, l));
1350 }
1351 }
1352
1353 let (nf, first) = &targets[0];
1354
1355 let new_node = match first {
1356 NetworkLeaf::Scalar(s) => {
1357 let mut accumulator = self.scalar[*s].clone();
1358
1359 for (_, t) in &targets[1..] {
1360 match t {
1361 NetworkLeaf::Scalar(s) => {
1362 accumulator += self.scalar[*s].refer();
1363 }
1364 NetworkLeaf::LocalTensor(t) => {
1365 if let Some(s) = self.tensors[*t].scalar_ref() {
1366 accumulator += s;
1367 } else {
1368 return Err(TensorNetworkError::NotAllScalars(
1369 "".to_string(),
1370 ));
1371 }
1372 }
1373 NetworkLeaf::LibraryKey { .. } => {
1374 return Err(TensorNetworkError::ScalarLibSum("".to_string()));
1375 }
1376 }
1377 }
1378
1379 let pos = self.scalar.len();
1380 self.scalar.push(accumulator);
1381 NetworkLeaf::Scalar(pos)
1382 }
1383 NetworkLeaf::LocalTensor(t) => {
1384 let mut accumulator = self.tensors[*t].clone();
1385 if accumulator.is_scalar() {
1386 let mut accumulator = Sc::from(accumulator.scalar().unwrap());
1387
1388 for (_, t) in &targets[1..] {
1389 match t {
1390 NetworkLeaf::Scalar(s) => {
1391 accumulator += self.scalar[*s].refer();
1392 }
1393 NetworkLeaf::LocalTensor(t) => {
1394 if let Some(s) = self.tensors[*t].scalar_ref() {
1395 accumulator += s;
1396 } else {
1397 return Err(TensorNetworkError::NotAllScalars(
1398 "".to_string(),
1399 ));
1400 }
1401 }
1402 NetworkLeaf::LibraryKey { .. } => {
1403 return Err(TensorNetworkError::ScalarLibSum(
1404 "".to_string(),
1405 ));
1406 }
1407 }
1408 }
1409
1410 let pos = self.scalar.len();
1411 self.scalar.push(accumulator);
1412 NetworkLeaf::Scalar(pos)
1413 } else {
1414 for (nid, t) in &targets[1..] {
1415 match t {
1416 NetworkLeaf::Scalar(_) => {
1417 return Err(TensorNetworkError::SumScalarTensor(
1418 "".to_string(),
1419 ));
1420 }
1421 NetworkLeaf::LocalTensor(t) => {
1422 accumulator += self.tensors[*t].refer();
1423 }
1424 NetworkLeaf::LibraryKey(_) => {
1425 let with_index = graph.get_lib_data(lib, *nid).unwrap();
1426
1427 accumulator += with_index;
1428 }
1429 }
1430 }
1431
1432 let pos = self.tensors.len();
1433 self.tensors.push(accumulator);
1434
1435 NetworkLeaf::LocalTensor(pos)
1436 }
1437 }
1438 NetworkLeaf::LibraryKey(_) => {
1439 let inds = graph.get_lib_data(lib, *nf).unwrap();
1440 let mut accumulator = T::from(inds);
1441 for (nid, t) in &targets[1..] {
1442 match t {
1443 NetworkLeaf::Scalar(_) => {
1444 return Err(TensorNetworkError::SumScalarTensor(
1445 "".to_string(),
1446 ));
1447 }
1448 NetworkLeaf::LocalTensor(t) => {
1449 accumulator += self.tensors[*t].refer();
1450 }
1451 NetworkLeaf::LibraryKey(_) => {
1452 let with = graph.get_lib_data(lib, *nid).unwrap();
1453 accumulator += with;
1454 }
1455 }
1456 }
1457
1458 let pos = self.tensors.len();
1459 self.tensors.push(accumulator);
1460
1461 NetworkLeaf::LocalTensor(pos)
1462 }
1463 };
1464
1465 graph.identify_nodes_without_self_edges(&all_nodes, NetworkNode::Leaf(new_node));
1466 Ok(graph)
1467 }
1468 NetworkOp::Function(f) => {
1469 let ops = graph
1470 .graph
1471 .iter_nodes()
1472 .find(|(_, _, d)| matches!(d, NetworkNode::Op(NetworkOp::Function(_))));
1473
1474 let (opid, children, _) = ops.unwrap();
1475
1476 let mut child = None;
1477 for c in children {
1478 if let Some(id) = graph.graph.involved_node_id(c)
1479 && let NetworkNode::Leaf(l) = &graph.graph[id]
1480 {
1481 if let Some((nid, _)) = child {
1482 if nid != id {
1483 return Err(TensorNetworkError::Other(anyhow!(
1484 "Cannot have more than one tensor argument to function"
1485 )));
1486 }
1487 } else {
1488 child = Some((id, l));
1489 }
1490 }
1491 }
1492 if let Some((child_id, leaf)) = child {
1493 let new_node = match leaf {
1494 NetworkLeaf::Scalar(s) => {
1495 let s = self.scalar[*s].clone();
1496 let pos = self.scalar.len();
1497 let s = fn_lib.apply_scalar(&f, s)?;
1498 self.scalar.push(s);
1499
1500 NetworkLeaf::Scalar(pos)
1501 }
1502 NetworkLeaf::LibraryKey(_) => {
1503 let inds = graph.get_lib_data(lib, child_id).unwrap();
1504 let t = fn_lib.apply(&f, T::from(inds))?;
1505 let pos = self.tensors.len();
1506 self.tensors.push(t);
1507 NetworkLeaf::LocalTensor(pos)
1508 }
1509 NetworkLeaf::LocalTensor(t) => {
1510 let t = self.tensors[*t].clone();
1511 let t = fn_lib.apply(&f, t)?;
1512 let pos = self.tensors.len();
1513 self.tensors.push(t);
1514 NetworkLeaf::LocalTensor(pos)
1515 }
1516 };
1517 graph.identify_nodes_without_self_edges(
1518 &[child_id, opid],
1519 NetworkNode::Leaf(new_node),
1520 );
1521 Ok(graph)
1522 } else {
1523 Err(TensorNetworkError::ChildlessNeg)
1524 }
1525 }
1526 NetworkOp::Power(_) => {
1527 let mut pow = 0;
1528 let ops = graph.graph.iter_nodes().find(|(_, _, d)| {
1529 if let NetworkNode::Op(NetworkOp::Power(i)) = d {
1530 pow = *i;
1531 true
1532 } else {
1533 false
1534 }
1535 });
1536
1537 let (opid, children, _) = ops.unwrap();
1538
1539 let mut child = None;
1540 for c in children {
1541 if let Some(id) = graph.graph.involved_node_id(c)
1542 && let NetworkNode::Leaf(l) = &graph.graph[id]
1543 {
1544 if let Some((nid, _)) = child {
1545 if nid != id {
1546 return Err(TensorNetworkError::Other(anyhow!(
1547 "Cannot have more than one tensor argument to power:{}",
1548 graph.dot()
1549 )));
1550 }
1551 } else {
1552 child = Some((id, l));
1553 }
1554 }
1555 }
1556 let n = pow.abs();
1557 if let Some((child_id, leaf)) = child {
1558 let new_node = match leaf {
1559 NetworkLeaf::Scalar(si) => {
1560 if n == 0 {
1561 NetworkLeaf::Scalar(*si)
1562 } else {
1563 let mut s = self.scalar[*si].clone();
1564
1565 for _ in 1..n {
1566 s *= self.scalar[*si].refer();
1567 }
1568
1569 if pow < 0 {
1570 s = s.ref_one() / s;
1571 }
1572
1573 let pos = self.scalar.len();
1574 self.scalar.push(s);
1575
1576 NetworkLeaf::Scalar(pos)
1577 }
1578 }
1579 NetworkLeaf::LibraryKey(a) => {
1580 let inds = graph.get_lib_data(lib, child_id).unwrap();
1581 let mut t = T::from(inds);
1582
1583 match pow {
1584 0 => {
1585 let pos = self.scalar.len();
1586 let one = self.scalar[0].ref_one();
1587 self.scalar.push(one);
1588 NetworkLeaf::Scalar(pos)
1589 }
1590 1 => NetworkLeaf::LibraryKey(a.clone()),
1591 _ => {
1592 let squares = n / 2;
1593 let mut square = t.contract(&t)?;
1594
1595 if n % 2 == 1 {
1596 if n != 1 {
1597 for _ in 0..squares {
1598 square = square.contract(&square)?;
1599 }
1600 t = square.contract(&t)?;
1601 }
1602
1603 if pow < 0 {
1604 if !t.is_scalar() {
1605 return Err(
1606 TensorNetworkError::NegativeExponentNonScalar(
1607 "".to_string(),
1608 ),
1609 );
1610 } else {
1611 let mut s = Sc::from(t.scalar().unwrap());
1612 let pos = self.scalar.len();
1613 s = s.ref_one() / s;
1614 self.scalar.push(s);
1615 NetworkLeaf::Scalar(pos)
1616 }
1617 } else {
1618 let pos = self.tensors.len();
1619 self.tensors.push(t);
1620 NetworkLeaf::LocalTensor(pos)
1621 }
1622 } else {
1623 let mut s = Sc::from(square.scalar().unwrap());
1624 let sc = s.clone();
1625 for _ in 1..squares {
1626 s *= sc.refer();
1627 }
1628 let pos = self.scalar.len();
1629 if pow < 0 {
1630 s = s.ref_one() / s;
1631 }
1632 self.scalar.push(s);
1633 NetworkLeaf::Scalar(pos)
1634 }
1635 }
1636 }
1637 }
1638 NetworkLeaf::LocalTensor(ti) => {
1639 let mut t = self.tensors[*ti].clone();
1640 match pow {
1641 0 => {
1642 let pos = self.scalar.len();
1643 let one = self.scalar[0].ref_one();
1644 self.scalar.push(one);
1645 NetworkLeaf::Scalar(pos)
1646 }
1647 1 => NetworkLeaf::LocalTensor(*ti),
1648 _ => {
1649 let squares = n / 2;
1650 let mut square = t.contract(&t)?;
1651
1652 if n % 2 == 1 {
1653 if n != 1 {
1654 for _ in 0..squares {
1655 square = square.contract(&square)?;
1656 }
1657 t = square.contract(&t)?;
1658 }
1659 if pow < 0 {
1660 if !t.is_scalar() {
1661 return Err(
1662 TensorNetworkError::NegativeExponentNonScalar(
1663 "".to_string(),
1664 ),
1665 );
1666 } else {
1667 let mut s = Sc::from(t.scalar().unwrap());
1668 let pos = self.scalar.len();
1669 s = s.ref_one() / s;
1670 self.scalar.push(s);
1671 NetworkLeaf::Scalar(pos)
1672 }
1673 } else {
1674 let pos = self.tensors.len();
1675 self.tensors.push(t);
1676 NetworkLeaf::LocalTensor(pos)
1677 }
1678 } else {
1679 let mut s = Sc::from(square.scalar().unwrap());
1680 let sc = s.clone();
1681 for _ in 1..squares {
1682 s *= sc.refer();
1683 }
1684 let pos = self.scalar.len();
1685 if pow < 0 {
1686 s = s.ref_one() / s;
1687 }
1688 self.scalar.push(s);
1689 NetworkLeaf::Scalar(pos)
1690 }
1691 }
1692 }
1693 }
1694 };
1695 graph.identify_nodes_without_self_edges(
1696 &[child_id, opid],
1697 NetworkNode::Leaf(new_node),
1698 );
1699 Ok(graph)
1700 } else {
1701 Err(TensorNetworkError::ChildlessNeg)
1702 }
1703 }
1704 }
1705 }
1706}
1707
1708pub trait Ref {
1709 type Ref<'a>
1710 where
1711 Self: 'a;
1712 fn refer(&self) -> Self::Ref<'_>;
1713}
1714
1715impl Ref for f64 {
1716 type Ref<'a>
1717 = &'a f64
1718 where
1719 Self: 'a;
1720
1721 fn refer(&self) -> Self::Ref<'_> {
1722 self
1723 }
1724}
1725
1726#[cfg(feature = "shadowing")]
1729pub mod symbolica_interop;
1730
1731#[cfg(test)]
1732mod tests;