spenso/
network.rs

1use anyhow::anyhow;
2use graph::{NAdd, NMul, NetworkEdge, NetworkGraph, NetworkLeaf, NetworkNode, NetworkOp};
3use linnet::half_edge::NodeIndex;
4use serde::{Deserialize, Serialize};
5
6use library::{Library, LibraryError};
7
8use crate::algebra::algebraic_traits::RefOne;
9use crate::contraction::Contract;
10use crate::network::library::{DummyKey, FunctionLibrary, FunctionLibraryError, LibraryTensor};
11use crate::structure::abstract_index::AbstractIndex;
12use crate::structure::permuted::PermuteTensor;
13// use crate::shadowing::Concretize;
14use crate::structure::representation::LibrarySlot;
15use crate::structure::slot::{AbsInd, IsAbstractSlot};
16use crate::structure::{HasName, PermutedStructure, StructureError};
17use std::borrow::Cow;
18use std::fmt::Display;
19use std::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
20use store::{NetworkStore, TensorScalarStore, TensorScalarStoreMapping};
21use thiserror::Error;
22// use log::trace;
23
24use crate::{
25    contraction::ContractionError,
26    structure::{CastStructure, HasStructure, ScalarTensor, TensorStructure},
27};
28
29// use anyhow::Result;
30
31use std::{convert::Infallible, fmt::Debug};
32
33#[derive(
34    Debug,
35    Clone,
36    Serialize,
37    Deserialize,
38    bincode_trait_derive::Encode,
39    bincode_trait_derive::Decode,
40    PartialEq,
41    Eq,
42)]
43#[cfg_attr(
44    feature = "shadowing",
45    trait_decode(trait = symbolica::state::HasStateMap),
46)]
47pub struct Network<S, LibKey, FunKey, Aind = AbstractIndex> {
48    pub graph: NetworkGraph<LibKey, FunKey, Aind>,
49    pub store: S,
50    pub state: NetworkState,
51}
52
53#[derive(
54    Debug,
55    Clone,
56    Serialize,
57    Deserialize,
58    bincode_trait_derive::Encode,
59    bincode_trait_derive::Decode,
60    PartialEq,
61    Eq,
62    Copy,
63)]
64pub enum NetworkState {
65    PureScalar,
66    Tensor,
67    SelfDualTensor,
68    Scalar,
69}
70
71impl NetworkState {
72    pub fn is_scalar(&self) -> bool {
73        matches!(self, NetworkState::Scalar | NetworkState::PureScalar)
74    }
75
76    pub fn is_tensor(&self) -> bool {
77        matches!(self, NetworkState::Tensor | NetworkState::SelfDualTensor)
78    }
79
80    pub fn pow(self, pow: i8) -> Self {
81        match self {
82            NetworkState::PureScalar => NetworkState::PureScalar,
83            NetworkState::Scalar => NetworkState::Scalar,
84            NetworkState::SelfDualTensor => {
85                if pow % 2 == 0 {
86                    NetworkState::Scalar
87                } else {
88                    NetworkState::SelfDualTensor
89                }
90            }
91            NetworkState::Tensor => panic!("Cannot have integer power of non-self dual tensor"),
92        }
93    }
94
95    pub fn is_compatible(&self, other: &Self) -> bool {
96        matches!(
97            (self, other),
98            (NetworkState::Tensor, NetworkState::Tensor)
99                | (NetworkState::SelfDualTensor, NetworkState::SelfDualTensor)
100                | (NetworkState::Scalar, NetworkState::Scalar)
101                | (NetworkState::PureScalar, NetworkState::Scalar)
102                | (NetworkState::Scalar, NetworkState::PureScalar)
103                | (NetworkState::PureScalar, NetworkState::PureScalar)
104        )
105    }
106}
107impl MulAssign for NetworkState {
108    fn mul_assign(&mut self, rhs: Self) {
109        // println!("{self:?} *={rhs:?}");
110        *self = match (*self, rhs) {
111            (NetworkState::PureScalar, NetworkState::PureScalar) => NetworkState::PureScalar,
112            (NetworkState::PureScalar, NetworkState::Scalar) => NetworkState::Scalar,
113            (NetworkState::Tensor, _) => NetworkState::Tensor,
114            (NetworkState::Scalar, NetworkState::PureScalar) => NetworkState::Scalar,
115            (_, NetworkState::Tensor) => NetworkState::Tensor,
116            (NetworkState::SelfDualTensor, _) => NetworkState::SelfDualTensor,
117            (_, NetworkState::SelfDualTensor) => NetworkState::SelfDualTensor,
118            (NetworkState::Scalar, NetworkState::Scalar) => NetworkState::Scalar,
119        }
120    }
121}
122
123impl AddAssign for NetworkState {
124    fn add_assign(&mut self, rhs: Self) {
125        // println!("{self:?} *={rhs:?}");
126        *self = match (*self, rhs) {
127            (NetworkState::PureScalar, NetworkState::PureScalar) => NetworkState::PureScalar,
128            (NetworkState::PureScalar, NetworkState::Scalar) => NetworkState::Scalar,
129            (a, b) => {
130                assert_eq!(a, b, "Cannot add incompatible network states:{a:?} + {b:?}");
131                a
132            }
133        }
134    }
135}
136
137// pub type TensorNetwork<T, S, Str: TensorScalarStore<Tensor = T, Scalar = S>, K> = Network<Str, K>;
138
139// pub struct TensorNetwork<
140//     T,
141//     S,
142//     K,
143//     Str: TensorScalarStore<Tensor = T, Scalar = S> = NetworkStore<T, S>,
144// > {
145//     net: Network<Str, K>,
146// }
147
148pub mod graph;
149pub mod library;
150pub mod set;
151pub mod store;
152
153impl<S: TensorScalarStoreMapping, K: Clone, FK: Clone, Aind: AbsInd> TensorScalarStoreMapping
154    for Network<S, K, FK, Aind>
155{
156    type Store<U, V> = Network<S::Store<U, V>, K, FK, Aind>;
157    type Scalar = S::Scalar;
158    type Tensor = S::Tensor;
159
160    fn iter_scalars(&self) -> impl Iterator<Item = &Self::Scalar> {
161        self.store.iter_scalars()
162    }
163
164    fn iter_tensors(&self) -> impl Iterator<Item = &Self::Tensor> {
165        self.store.iter_tensors()
166    }
167
168    fn iter_scalars_mut(&mut self) -> impl Iterator<Item = &mut Self::Scalar> {
169        self.store.iter_scalars_mut()
170    }
171    fn iter_tensors_mut(&mut self) -> impl Iterator<Item = &mut Self::Tensor> {
172        self.store.iter_tensors_mut()
173    }
174
175    fn map<U, V>(
176        self,
177        scalar_map: impl FnMut(Self::Scalar) -> U,
178        tensor_map: impl FnMut(Self::Tensor) -> V,
179    ) -> Self::Store<V, U> {
180        Network {
181            store: self.store.map(scalar_map, tensor_map),
182            graph: self.graph,
183            state: self.state,
184        }
185    }
186
187    fn map_result<U, V, Er>(
188        self,
189        scalar_map: impl FnMut(Self::Scalar) -> Result<U, Er>,
190        tensor_map: impl FnMut(Self::Tensor) -> Result<V, Er>,
191    ) -> Result<Self::Store<V, U>, Er> {
192        Ok(Network {
193            store: self.store.map_result(scalar_map, tensor_map)?,
194            graph: self.graph,
195            state: self.state,
196        })
197    }
198
199    fn map_ref<'a, U, V>(
200        &'a self,
201        scalar_map: impl FnMut(&'a Self::Scalar) -> U,
202        tensor_map: impl FnMut(&'a Self::Tensor) -> V,
203    ) -> Self::Store<V, U> {
204        Network {
205            store: self.store.map_ref(scalar_map, tensor_map),
206            graph: self.graph.clone(),
207            state: self.state,
208        }
209    }
210
211    fn map_ref_result<U, V, Er>(
212        &self,
213        scalar_map: impl FnMut(&Self::Scalar) -> Result<U, Er>,
214        tensor_map: impl FnMut(&Self::Tensor) -> Result<V, Er>,
215    ) -> Result<Self::Store<V, U>, Er> {
216        Ok(Network {
217            store: self.store.map_ref_result(scalar_map, tensor_map)?,
218            graph: self.graph.clone(),
219            state: self.state,
220        })
221    }
222
223    fn map_ref_enumerate<U, V>(
224        &self,
225        scalar_map: impl FnMut((usize, &Self::Scalar)) -> U,
226        tensor_map: impl FnMut((usize, &Self::Tensor)) -> V,
227    ) -> Self::Store<V, U> {
228        Network {
229            store: self.store.map_ref_enumerate(scalar_map, tensor_map),
230            graph: self.graph.clone(),
231            state: self.state,
232        }
233    }
234
235    fn map_ref_result_enumerate<U, V, Er>(
236        &self,
237        scalar_map: impl FnMut((usize, &Self::Scalar)) -> Result<U, Er>,
238        tensor_map: impl FnMut((usize, &Self::Tensor)) -> Result<V, Er>,
239    ) -> Result<Self::Store<V, U>, Er> {
240        Ok(Network {
241            store: self
242                .store
243                .map_ref_result_enumerate(scalar_map, tensor_map)?,
244            graph: self.graph.clone(),
245            state: self.state,
246        })
247    }
248
249    fn map_ref_mut<U, V>(
250        &mut self,
251        scalar_map: impl FnMut(&mut Self::Scalar) -> U,
252        tensor_map: impl FnMut(&mut Self::Tensor) -> V,
253    ) -> Self::Store<V, U> {
254        Network {
255            store: self.store.map_ref_mut(scalar_map, tensor_map),
256            graph: self.graph.clone(),
257            state: self.state,
258        }
259    }
260
261    fn map_ref_mut_result<U, V, Er>(
262        &mut self,
263        scalar_map: impl FnMut(&mut Self::Scalar) -> Result<U, Er>,
264        tensor_map: impl FnMut(&mut Self::Tensor) -> Result<V, Er>,
265    ) -> Result<Self::Store<V, U>, Er> {
266        Ok(Network {
267            store: self.store.map_ref_mut_result(scalar_map, tensor_map)?,
268            graph: self.graph.clone(),
269            state: self.state,
270        })
271    }
272
273    fn map_ref_mut_enumerate<U, V>(
274        &mut self,
275        scalar_map: impl FnMut((usize, &mut Self::Scalar)) -> U,
276        tensor_map: impl FnMut((usize, &mut Self::Tensor)) -> V,
277    ) -> Self::Store<V, U> {
278        Network {
279            store: self.store.map_ref_mut_enumerate(scalar_map, tensor_map),
280            graph: self.graph.clone(),
281            state: self.state,
282        }
283    }
284
285    fn map_ref_mut_result_enumerate<U, V, Er>(
286        &mut self,
287        scalar_map: impl FnMut((usize, &mut Self::Scalar)) -> Result<U, Er>,
288        tensor_map: impl FnMut((usize, &mut Self::Tensor)) -> Result<V, Er>,
289    ) -> Result<Self::Store<V, U>, Er> {
290        Ok(Network {
291            store: self
292                .store
293                .map_ref_mut_result_enumerate(scalar_map, tensor_map)?,
294            graph: self.graph.clone(),
295            state: self.state,
296        })
297    }
298}
299
300impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Default for Network<S, K, FK, Aind> {
301    fn default() -> Self {
302        Self::one()
303    }
304}
305
306impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> NMul for Network<S, K, FK, Aind> {
307    type Output = Self;
308    fn n_mul<I: IntoIterator<Item = Self>>(self, iter: I) -> Self::Output {
309        let mut store = self.store;
310        let mut state = self.state;
311
312        let items = iter.into_iter().map(|mut a| {
313            a.graph.shift_scalars(store.n_scalars());
314            a.graph.shift_tensors(store.n_tensors());
315            store.extend(a.store);
316
317            state *= a.state;
318            a.graph
319        });
320
321        let graph = self.graph.n_mul(items);
322
323        if state.is_tensor() && graph.dangling_indices().is_empty() {
324            state = NetworkState::Scalar;
325        }
326
327        Network {
328            graph,
329            store,
330            state,
331        }
332    }
333}
334
335impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Mul for Network<S, K, FK, Aind> {
336    type Output = Self;
337    fn mul(mut self, mut other: Self) -> Self::Output {
338        let mut store = self.store;
339
340        other.graph.shift_scalars(store.n_scalars());
341        other.graph.shift_tensors(store.n_tensors());
342        store.extend(other.store);
343        self.state *= other.state;
344
345        Network {
346            graph: self.graph * other.graph,
347            store,
348            state: self.state,
349        }
350    }
351}
352
353impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> MulAssign
354    for Network<S, K, FK, Aind>
355{
356    fn mul_assign(&mut self, mut rhs: Self) {
357        rhs.graph.shift_scalars(self.store.n_scalars());
358        rhs.graph.shift_tensors(self.store.n_tensors());
359        self.store.extend(rhs.store);
360        self.state *= rhs.state;
361        self.graph *= rhs.graph;
362    }
363}
364
365impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> MulAssign<T>
366    for Network<NetworkStore<T, S>, K, FK, Aind>
367where
368    T::Slot: IsAbstractSlot<Aind = Aind>,
369{
370    fn mul_assign(&mut self, rhs: T) {
371        *self *= Network::from_tensor(rhs);
372    }
373}
374
375impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> Mul<T>
376    for Network<NetworkStore<T, S>, K, FK, Aind>
377where
378    T::Slot: IsAbstractSlot<Aind = Aind>,
379{
380    type Output = Self;
381    fn mul(self, other: T) -> Self::Output {
382        let mut store = self.store;
383
384        let mut other = Network::from_tensor(other);
385
386        other.graph.shift_scalars(store.n_scalars());
387        other.graph.shift_tensors(store.n_tensors());
388        store.extend(other.store);
389
390        Network {
391            graph: self.graph * other.graph,
392            store,
393            state: self.state,
394        }
395    }
396}
397
398impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Add for Network<S, K, FK, Aind> {
399    type Output = Self;
400    fn add(self, mut other: Self) -> Self::Output {
401        let mut store = self.store;
402
403        other.graph.shift_scalars(store.n_scalars());
404        other.graph.shift_tensors(store.n_tensors());
405        store.extend(other.store);
406
407        Network {
408            graph: self.graph + other.graph,
409            store,
410            state: self.state,
411        }
412    }
413}
414
415impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> AddAssign
416    for Network<S, K, FK, Aind>
417{
418    fn add_assign(&mut self, mut rhs: Self) {
419        rhs.graph.shift_scalars(self.store.n_scalars());
420        rhs.graph.shift_tensors(self.store.n_tensors());
421        self.store.extend(rhs.store);
422
423        self.graph += rhs.graph;
424    }
425}
426
427impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> AddAssign<T>
428    for Network<NetworkStore<T, S>, K, FK, Aind>
429where
430    T::Slot: IsAbstractSlot<Aind = Aind>,
431{
432    fn add_assign(&mut self, rhs: T) {
433        *self += Network::from_tensor(rhs);
434    }
435}
436
437impl<T: TensorStructure, S, FK: Debug, K: Debug, Aind: AbsInd> Add<T>
438    for Network<NetworkStore<T, S>, K, FK, Aind>
439where
440    T::Slot: IsAbstractSlot<Aind = Aind>,
441{
442    type Output = Self;
443    fn add(mut self, other: T) -> Self::Output {
444        self += other;
445        self
446    }
447}
448
449impl<T: TensorStructure, FK: Debug, K: Debug, Aind: AbsInd> Add<i8>
450    for Network<NetworkStore<T, i8>, K, FK, Aind>
451where
452    T::Slot: IsAbstractSlot<Aind = Aind>,
453{
454    type Output = Self;
455    fn add(mut self, other: i8) -> Self::Output {
456        let mut other = Network::from_scalar(other);
457        other.graph.shift_tensors(self.store.n_tensors());
458        other.graph.shift_tensors(self.store.n_tensors());
459
460        self.store.extend(other.store);
461        Network {
462            graph: self.graph + other.graph,
463            store: self.store,
464            state: self.state,
465        }
466    }
467}
468
469impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> NAdd for Network<S, K, FK, Aind> {
470    type Output = Self;
471    fn n_add<I: IntoIterator<Item = Self>>(self, iter: I) -> Self::Output {
472        let mut store = self.store;
473
474        let mut state = self.state;
475
476        let items = iter.into_iter().map(|mut a| {
477            a.graph.shift_scalars(store.n_scalars());
478            a.graph.shift_tensors(store.n_tensors());
479            store.extend(a.store);
480
481            state += a.state;
482            a.graph
483        });
484
485        Network {
486            graph: self.graph.n_add(items),
487            store,
488            state,
489        }
490    }
491}
492
493impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Neg
494    for Network<S, K, FK, Aind>
495{
496    type Output = Self;
497    fn neg(self) -> Self::Output {
498        Self {
499            store: self.store,
500            graph: self.graph.neg(),
501            state: self.state,
502        }
503    }
504}
505
506impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Sub
507    for Network<S, K, FK, Aind>
508{
509    type Output = Self;
510    fn sub(mut self, rhs: Self) -> Self::Output {
511        self -= rhs;
512        self
513    }
514}
515
516impl<S: TensorScalarStore, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> SubAssign
517    for Network<S, K, FK, Aind>
518{
519    fn sub_assign(&mut self, mut rhs: Self) {
520        rhs.graph.shift_scalars(self.store.n_scalars());
521        rhs.graph.shift_tensors(self.store.n_tensors());
522        self.store.extend(rhs.store);
523
524        self.graph -= rhs.graph
525    }
526}
527
528impl<T: TensorStructure, S, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> SubAssign<T>
529    for Network<NetworkStore<T, S>, K, FK, Aind>
530where
531    T::Slot: IsAbstractSlot<Aind = Aind>,
532{
533    fn sub_assign(&mut self, rhs: T) {
534        *self -= Network::from_tensor(rhs)
535    }
536}
537
538impl<T: TensorStructure, S, K: Clone + Debug, FK: Clone + Debug, Aind: AbsInd> Sub<T>
539    for Network<NetworkStore<T, S>, K, FK, Aind>
540where
541    T::Slot: IsAbstractSlot<Aind = Aind>,
542{
543    type Output = Self;
544    fn sub(mut self, other: T) -> Self::Output {
545        self -= other;
546        self
547    }
548}
549
550impl<S: TensorScalarStore, FK: Debug, K: Debug, Aind: AbsInd> Network<S, K, FK, Aind> {
551    pub fn pow(self, pow: i8) -> Self {
552        Self {
553            store: self.store,
554            graph: self.graph.pow(pow),
555            state: self.state.pow(pow),
556        }
557    }
558    pub fn fun(self, key: FK) -> Self {
559        let graph = self.graph.function(key);
560        Self {
561            store: self.store,
562            state: graph.state(),
563            graph,
564        }
565    }
566
567    pub fn from_scalar(scalar: S::Scalar) -> Self {
568        let mut store = S::default();
569        let id = store.add_scalar(scalar);
570        Network {
571            graph: NetworkGraph::scalar(id),
572            store,
573
574            state: NetworkState::PureScalar,
575        }
576    }
577
578    pub fn merge_ops(&mut self)
579    where
580        K: Clone,
581    {
582        self.graph.merge_ops();
583    }
584
585    pub fn from_tensor(tensor: S::Tensor) -> Self
586    where
587        S::Tensor: TensorStructure,
588        <S::Tensor as TensorStructure>::Slot: IsAbstractSlot<Aind = Aind>,
589    {
590        let mut store = S::default();
591
592        let state = if tensor.is_scalar() {
593            NetworkState::Scalar
594        } else if tensor.is_fully_self_dual() {
595            // tensor.structure().dual();
596            NetworkState::SelfDualTensor
597        } else {
598            NetworkState::Tensor
599        };
600        let id = store.add_tensor(tensor);
601
602        Network {
603            graph: NetworkGraph::tensor(store.get_tensor(id), NetworkLeaf::LocalTensor(id)),
604            store,
605            state,
606        }
607    }
608
609    pub fn library_tensor<T>(tensor: &T, key: PermutedStructure<K>) -> Self
610    where
611        T: TensorStructure,
612        T::Slot: IsAbstractSlot<Aind = Aind>,
613    {
614        let state = if tensor.is_scalar() {
615            NetworkState::Scalar
616        } else if tensor.is_fully_self_dual() {
617            NetworkState::SelfDualTensor
618        } else {
619            NetworkState::Tensor
620        };
621        Network {
622            graph: NetworkGraph::tensor(tensor, NetworkLeaf::LibraryKey(key)),
623            store: S::default(),
624            state,
625        }
626    }
627
628    pub fn one() -> Self {
629        Network {
630            graph: NetworkGraph::one(),
631            store: S::default(),
632            state: NetworkState::PureScalar,
633        }
634    }
635
636    pub fn zero() -> Self {
637        Network {
638            graph: NetworkGraph::zero(),
639            store: S::default(),
640            state: NetworkState::PureScalar,
641        }
642    }
643}
644
645#[derive(Error, Debug)]
646pub enum TensorNetworkError<K: Display, FK: Display> {
647    #[error("Slot edge to prod node")]
648    SlotEdgeToProdNode,
649    #[error("Slot edge to scalar node")]
650    SlotEdgeToScalarNode,
651    #[error("More than one neg")]
652    MoreThanOneNeg,
653    #[error("Childless neg")]
654    ChildlessNeg,
655    #[error("Contraction Error:{0}")]
656    ContractionError(#[from] ContractionError),
657    #[error("Scalar connected by a slot edge")]
658    ScalarSlotEdge,
659    #[error("Structure Error:{0}")]
660    StructErr(#[from] StructureError),
661    #[error("LibraryError:{0}")]
662    LibErr(#[from] LibraryError<K>),
663    #[error("FunctionLibraryError:{0}")]
664    FunLibErr(#[from] FunctionLibraryError<FK>),
665    #[error("Non tensor node still present")]
666    NonTensorNodePresent,
667    #[error("Negative non-even power on non-scalar node:{0}")]
668    NegativeExponentNonScalar(String),
669    #[error("Too many arguments for function:{0}")]
670    TooManyArgsFunction(String),
671    #[error("Non self-dual tensor power{0}")]
672    InvalidDotFunction(String),
673    #[error("Invalid dot function{0}")]
674    NonSelfDualTensorPower(String),
675    #[error("invalid resulting node{0}")]
676    InvalidResultNode(NetworkNode<DummyKey, FK>),
677    #[error("internal edge still present, contract it first")]
678    InternalEdgePresent,
679    #[error("uncontracted scalar")]
680    UncontractedScalar,
681    #[error("Cannot contract edge between {0} and {1}")]
682    CannotContractEdgeBetween(NetworkNode<K, FK>, NetworkNode<K, FK>),
683    #[error("no nodes in the graph")]
684    NoNodes,
685    #[error("no scalar present")]
686    NoScalar,
687    #[error("more than one node in the graph")]
688    MoreThanOneNode,
689    #[error("is not scalar output")]
690    NotScalarOutput,
691    #[error("failed scalar multiplication")]
692    FailedScalarMul,
693    #[error("scalar field is empty")]
694    ScalarFieldEmpty,
695    #[error("not all scalars: {0}")]
696    NotAllScalars(String),
697    #[error("try to sum scalar with library tensor: {0}")]
698    ScalarLibSum(String),
699    #[error("try to sum scalar with a tensor: {0}")]
700    SumScalarTensor(String),
701    #[error("Incompatible summands: {0}")]
702    IncompatibleSummand(String),
703    #[error("failed to contract")]
704    FailedContract(ContractionError),
705    #[error("negative exponent not yet supported")]
706    NegativeExponent,
707    #[error("failed to contract: {0}")]
708    FailedContractMsg(String),
709    #[error(transparent)]
710    Other(#[from] anyhow::Error),
711    #[error("Io error")]
712    InOut(#[from] std::io::Error),
713    #[error("Infallible")]
714    Infallible,
715}
716
717impl<K: Display, FK: Display> From<Infallible> for TensorNetworkError<K, FK> {
718    fn from(_: Infallible) -> Self {
719        TensorNetworkError::Infallible
720    }
721}
722
/// The payload left once a network has been reduced to a single leaf.
///
/// `Network::result` instantiates it with references into the network's store / graph.
pub enum TensorOrScalarOrKey<T, S, K, Aind> {
    /// A tensor from the local store, with the library slots the graph reports for the
    /// result node.
    Tensor {
        tensor: T,
        graph_slots: Vec<LibrarySlot<Aind>>,
    },
    /// A scalar from the local store.
    Scalar(S),
    /// A library key that still has to be resolved; `nodeid` is its node in the graph.
    Key {
        key: K,
        nodeid: NodeIndex,
    },
}
734
/// Outcome of reading back the result of a network.
///
/// `Network::result` returns `One` when the result node is a product operation node, `Zero`
/// when it is a sum operation node, and `Val` otherwise.
pub enum ExecutionResult<T> {
    One,
    Zero,
    Val(T),
}

impl<T: Display> Display for ExecutionResult<T> {
    /// Prints `One` / `Zero` by name and `Val` through the inner value's `Display`.
    ///
    /// The outer formatter's width / precision flags are not forwarded.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Val(inner) => write!(f, "{inner}"),
            Self::One => f.write_str("One"),
            Self::Zero => f.write_str("Zero"),
        }
    }
}
750
751impl<
752    T: TensorStructure,
753    S,
754    K: Display + Debug,
755    FK: Display + Debug,
756    Str: TensorScalarStore<Tensor = T, Scalar = S>,
757    Aind: AbsInd,
758> Network<Str, K, FK, Aind>
759where
760    T::Slot: IsAbstractSlot<Aind = Aind>,
761{
762    pub fn validate(&self)
763    where
764        K: TensorStructure,
765    {
766        for (n, _neigh, v) in self.graph.graph.iter_nodes() {
767            match v {
768                NetworkNode::Leaf(NetworkLeaf::LibraryKey(k)) => {
769                    let reps = self
770                        .graph
771                        .slots(n)
772                        .into_iter()
773                        .map(|s| s.rep())
774                        .collect::<Vec<_>>();
775                    // let p = Permutation::sort(&reps);
776
777                    let n_reps = k
778                        .structure
779                        .external_reps_iter()
780                        .map(|r| r.to_lib())
781                        .collect::<Vec<_>>();
782                    // let q = Permutation::sort(&n_reps);
783                    // println!("p{p}q{q}");
784                    assert_eq!(n_reps, reps);
785                }
786                NetworkNode::Leaf(NetworkLeaf::LocalTensor(k)) => {
787                    let reps = self
788                        .graph
789                        .slots(n)
790                        .into_iter()
791                        .map(|s| s.rep())
792                        .collect::<Vec<_>>();
793                    let n_reps = self
794                        .store
795                        .get_tensor(*k)
796                        .external_reps_iter()
797                        .map(|r| r.to_lib())
798                        .collect::<Vec<_>>();
799                    assert_eq!(n_reps, reps);
800                }
801                _ => {}
802            }
803        }
804    }
805
806    #[allow(clippy::result_large_err, clippy::type_complexity)]
807    pub fn result(
808        &self,
809    ) -> Result<
810        ExecutionResult<TensorOrScalarOrKey<&T, &S, &PermutedStructure<K>, Aind>>,
811        TensorNetworkError<K, FK>,
812    >
813    where
814        FK: Clone,
815    {
816        let (node, nid, graph_slots) = self.graph.result()?;
817
818        match node {
819            NetworkNode::Leaf(l) => match l {
820                NetworkLeaf::LibraryKey(k) => Ok(ExecutionResult::Val(TensorOrScalarOrKey::Key {
821                    key: k,
822                    nodeid: nid,
823                })),
824                NetworkLeaf::LocalTensor(t) => {
825                    Ok(ExecutionResult::Val(TensorOrScalarOrKey::Tensor {
826                        tensor: self.store.get_tensor(*t),
827                        graph_slots,
828                    }))
829                }
830                NetworkLeaf::Scalar(t) => Ok(ExecutionResult::Val(TensorOrScalarOrKey::Scalar(
831                    self.store.get_scalar(*t),
832                ))),
833            },
834            NetworkNode::Op(o) => match o {
835                NetworkOp::Product => Ok(ExecutionResult::One),
836                NetworkOp::Sum => Ok(ExecutionResult::Zero),
837                o => Err(TensorNetworkError::InvalidResultNode(NetworkNode::Op(
838                    o.clone(),
839                ))),
840            },
841        }
842    }
843
844    #[allow(clippy::result_large_err)]
845    pub fn result_tensor<'a, LT, L: Library<T::Structure, Key = K, Value = PermutedStructure<LT>>>(
846        &'a self,
847        lib: &L,
848    ) -> Result<ExecutionResult<Cow<'a, T>>, TensorNetworkError<K, FK>>
849    where
850        S: 'a,
851        T: Clone + ScalarTensor + HasStructure,
852        K: Display + Debug,
853        FK: Display + Debug + Clone,
854        LT: TensorStructure<Indexed = T> + Clone + LibraryTensor<WithIndices = T>,
855        T: PermuteTensor<Permuted = T>,
856        for<'b> &'b S: Into<T::Scalar>,
857        <<LT::WithIndices as HasStructure>::Structure as TensorStructure>::Slot:
858            IsAbstractSlot<Aind = Aind>,
859    {
860        Ok(match self.result()? {
861            ExecutionResult::One => ExecutionResult::One,
862            ExecutionResult::Zero => ExecutionResult::Zero,
863            ExecutionResult::Val(v) => ExecutionResult::Val(match v {
864                TensorOrScalarOrKey::Tensor { tensor, .. } => Cow::Borrowed(tensor),
865                TensorOrScalarOrKey::Scalar(s) => Cow::Owned(T::new_scalar(s.into())),
866                TensorOrScalarOrKey::Key { nodeid, .. } => {
867                    let less = self.graph.get_lib_data(lib, nodeid).unwrap();
868
869                    Cow::Owned(less)
870                }
871            }),
872        })
873    }
874
875    #[allow(clippy::result_large_err)]
876    pub fn result_scalar<'a>(
877        &'a self,
878    ) -> Result<ExecutionResult<Cow<'a, S>>, TensorNetworkError<K, FK>>
879    where
880        T: Clone + ScalarTensor + 'a,
881        T::Scalar: Into<S>,
882        K: Display,
883        FK: Display + Clone,
884        S: Clone,
885    {
886        Ok(match self.result()? {
887            ExecutionResult::One => ExecutionResult::One,
888            ExecutionResult::Zero => ExecutionResult::Zero,
889            ExecutionResult::Val(v) => ExecutionResult::Val(match v {
890                TensorOrScalarOrKey::Tensor { tensor: t, .. } => Cow::Owned(
891                    t.clone()
892                        .scalar()
893                        .ok_or(TensorNetworkError::NoScalar)?
894                        .into(),
895                ),
896                TensorOrScalarOrKey::Scalar(s) => Cow::Borrowed(s),
897                TensorOrScalarOrKey::Key { .. } => return Err(TensorNetworkError::NoScalar),
898            }),
899        })
900    }
901
902    pub fn cast<U>(self) -> Network<Str::Store<U, S>, K, FK, Aind>
903    where
904        K: Clone,
905        FK: Clone,
906        T: CastStructure<U> + HasStructure,
907        T::Structure: TensorStructure,
908        U: HasStructure,
909        U::Structure: From<T::Structure> + TensorStructure<Slot = T::Slot>,
910    {
911        self.map(|a| a, |t| t.cast_structure())
912    }
913}
914
/// A short, structure-free textual rendering used for DOT labels.
pub trait StructureLessDisplay {
    /// Text to show; by default nothing at all.
    fn display(&self) -> String {
        "".to_owned()
    }
}
920
921impl<S: HasName> StructureLessDisplay for S
922where
923    S::Args: StructureLessDisplay,
924    S::Name: Display,
925{
926    fn display(&self) -> String {
927        format!(
928            "{}({})",
929            self.name().map(|t| t.to_string()).unwrap_or_default(),
930            self.args().map(|t| t.display()).unwrap_or_default()
931        )
932    }
933}
934
935impl<S, K: Display + Debug, FK: Display + Debug, Aind: AbsInd> Network<S, K, FK, Aind> {
936    pub fn dot(&self) -> std::string::String {
937        self.graph.dot()
938    }
939
940    pub fn dot_pretty(&self) -> std::string::String
941    where
942        S: TensorScalarStore,
943        S::Scalar: Display,
944        K: StructureLessDisplay,
945        S::Tensor: StructureLessDisplay,
946    {
947        self.graph.dot_impl(
948            |i| {
949                let ss = &self.store.get_scalar(i);
950                format!("{}:{}", i, ss)
951            },
952            |k| k.display(),
953            |t| {
954                let tt = &self.store.get_tensor(t);
955                tt.display()
956            },
957            |fk| fk.to_string(),
958        )
959    }
960}
961
962impl<T, S, FK: Debug, K: Debug, Aind: AbsInd> Network<NetworkStore<T, S>, K, FK, Aind> {
963    pub fn dot_display_impl(
964        &self,
965        scalar_disp: impl Fn(&S) -> String,
966        library_disp: impl Fn(&K) -> Option<String>,
967        tensor_disp: impl Fn(&T) -> String,
968        function_disp: impl Fn(&FK) -> String,
969    ) -> std::string::String {
970        self.graph.graph.dot_impl(
971            &self.graph.graph.full_filter(),
972            "",
973            &|_| None,
974            &|e| {
975                if let NetworkEdge::Slot(s) = e {
976                    Some(format!("label=\"{s}\""))
977                } else {
978                    None
979                }
980            },
981            &|n| match n {
982                NetworkNode::Leaf(l) => match l {
983                    NetworkLeaf::LibraryKey(l) => {
984                        // if let Ok(v) = lib.get(l) {
985                        Some(format!("label = \"L:{}\"", library_disp(&l.structure)?))
986                        // } else {
987                        // None
988                        // }
989                    }
990                    NetworkLeaf::LocalTensor(l) => Some(format!(
991                        "label = \"T:{}\"",
992                        tensor_disp(self.store.get_tensor(*l))
993                    )),
994                    NetworkLeaf::Scalar(s) => Some(format!(
995                        "label = \"S:{}\"",
996                        scalar_disp(self.store.get_scalar(*s))
997                    )),
998                },
999                NetworkNode::Op(o) => {
1000                    Some(format!("label = \"{}\"", o.display_with(&function_disp)))
1001                }
1002            },
1003        )
1004        // self.graph.dot()
1005    }
1006}
1007
1008// use log::trace;
1009#[cfg(feature = "shadowing")]
1010pub mod parsing;
1011// use log::trace;
1012pub mod contract;
1013pub use contract::{
1014    ContractScalars, ContractionStrategy, SingleSmallestDegree, SmallestDegree, SmallestDegreeIter,
1015};
/// A scheduling policy that drives an [`ExecuteOp`] executor over a [`NetworkGraph`].
///
/// The implementations in this module repeatedly:
/// 1. take a ready op out of `graph` (`extract_next_ready_op`);
/// 2. evaluate it with `executor`;
/// 3. splice the returned subgraph back in (`splice_descendents_of`).
///
/// They differ in how many rounds they perform.
pub trait ExecutionStrategy<E, FL, L, K, FK, Aind>
where
    E: ExecuteOp<FL, L, K, FK, Aind>,
{
    /// Runs the strategy's rounds of op execution on `graph`.
    ///
    /// Whether the graph is fully reduced depends on the strategy: e.g. `Sequential`
    /// runs until no op is ready, while `Steps<N>` performs at most `N` rounds. `C` is
    /// the contraction strategy forwarded to the executor; `lib` and `fnlib` are the
    /// tensor and function libraries.
    ///
    /// # Errors
    /// Propagates the first [`TensorNetworkError`] returned by the executor.
    #[allow(clippy::result_large_err)]
    fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
        executor: &mut E,
        graph: &mut NetworkGraph<K, FK, Aind>,
        lib: &L,
        fnlib: &FL,
    ) -> Result<(), TensorNetworkError<K, FK>>
    where
        K: Display,
        FK: Display;
}
1032
/// Execution strategy that keeps executing ready ops until none remain.
pub struct Sequential;

/// Execution strategy that performs at most `N` extract/execute/splice rounds.
pub struct Steps<const N: usize> {}
/// Like [`Steps`], but prints DOT renderings of each round's extracted subgraph,
/// remaining graph and replacement subgraph to stdout (a debugging aid).
pub struct StepsDebug<const N: usize> {}
1037
1038impl<const N: usize, E, L, FL, FK: Debug, K: Debug, Aind: AbsInd>
1039    ExecutionStrategy<E, FL, L, K, FK, Aind> for StepsDebug<N>
1040where
1041    E: ExecuteOp<FL, L, K, FK, Aind>,
1042    K: Clone,
1043    FK: Clone,
1044{
1045    fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1046        executor: &mut E,
1047        graph: &mut NetworkGraph<K, FK, Aind>,
1048        lib: &L,
1049        fnlib: &FL,
1050    ) -> Result<(), TensorNetworkError<K, FK>>
1051    where
1052        K: Display,
1053        FK: Display,
1054    {
1055        for _ in 0..N {
1056            // find the *one* ready op
1057            if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1058                println!(
1059                    "Extracted_graph: {}",
1060                    extracted_graph.dot_impl(
1061                        |s| s.to_string(),
1062                        |_| "".to_string(),
1063                        |s| s.to_string(),
1064                        |_| "".to_string()
1065                    )
1066                );
1067                println!(
1068                    "Graph: {}",
1069                    graph.dot_impl(
1070                        |s| s.to_string(),
1071                        |_| "".to_string(),
1072                        |s| s.to_string(),
1073                        |_| "".to_string()
1074                    )
1075                );
1076                // execute + splice
1077                let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1078                println!(
1079                    "Replacement Graph: {}",
1080                    replacement.dot_impl(
1081                        |s| s.to_string(),
1082                        |_| "".to_string(),
1083                        |s| s.to_string(),
1084                        |_| "".to_string()
1085                    )
1086                );
1087
1088                graph.splice_descendents_of(replacement);
1089            }
1090        }
1091
1092        Ok(())
1093    }
1094}
1095
1096impl<const N: usize, E, FL, L, K, FK, Aind: AbsInd> ExecutionStrategy<E, FL, L, K, FK, Aind>
1097    for Steps<N>
1098where
1099    E: ExecuteOp<FL, L, K, FK, Aind>,
1100    K: Clone + Debug,
1101    FK: Clone + Debug,
1102{
1103    fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1104        executor: &mut E,
1105        graph: &mut NetworkGraph<K, FK, Aind>,
1106        lib: &L,
1107        fnlib: &FL,
1108    ) -> Result<(), TensorNetworkError<K, FK>>
1109    where
1110        K: Display,
1111        FK: Display,
1112    {
1113        for _ in 0..N {
1114            // find the *one* ready op
1115            if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1116                // execute + splice
1117                let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1118                graph.splice_descendents_of(replacement);
1119            }
1120        }
1121
1122        Ok(())
1123    }
1124}
1125
1126impl<E, L, FL, K, FK, Aind: AbsInd> ExecutionStrategy<E, FL, L, K, FK, Aind> for Sequential
1127where
1128    E: ExecuteOp<FL, L, K, FK, Aind>,
1129    FK: Clone + Debug,
1130    K: Clone + Debug,
1131{
1132    fn execute_all<C: ContractionStrategy<E, L, K, FK, Aind>>(
1133        executor: &mut E,
1134        graph: &mut NetworkGraph<K, FK, Aind>,
1135        lib: &L,
1136        fnlib: &FL,
1137    ) -> Result<(), TensorNetworkError<K, FK>>
1138    where
1139        K: Display,
1140        FK: Display,
1141    {
1142        while {
1143            // find the *one* ready op
1144            if let Some((extracted_graph, op)) = graph.extract_next_ready_op() {
1145                // execute + splice
1146                let replacement = executor.execute::<C>(extracted_graph, lib, fnlib, op)?;
1147                graph.splice_descendents_of(replacement);
1148                true
1149            } else {
1150                false
1151            }
1152        } {}
1153
1154        Ok(())
1155    }
1156}
1157
1158// 2b) Parallel: batch‐execute all ready ops, then splice serially.
1159// pub struct Parallel;
1160// impl<E, K> ExecutionStrategy<E, K> for Parallel
1161// where
1162//     E: ExecuteOp<K> + Clone + Send + Sync,
1163//     K: Clone + Send + Sync,
1164// {
1165//     fn contract_all<L: Library<Key = K> + Sync>(
1166//         &self,
1167//         executor: &mut E,
1168//         graph: &mut NetworkGraph<K>,
1169//         lib: &L,
1170//     ) {
1171//         loop {
1172//             // 1) collect *all* ready ops this round
1173//             let ready = graph.find_all_ready_ops();
1174
1175//             if ready.is_empty() {
1176//                 break;
1177//             }
1178
1179//             // 2) execute them in parallel
1180//             let results: Vec<(NodeIndex, NetworkGraph<K>)> = ready
1181//                 .into_par_iter()
1182//                 .map(|(nid, op, leaves)| {
1183//                     let mut local = executor.clone();
1184//                     let replacement = local.execute(lib, op, &leaves);
1185//                     (nid, replacement)
1186//                 })
1187//                 .collect();
1188
1189//             // 3) splice back sequentially
1190//             for (nid, replacement) in results {
1191//                 graph.splice_descendents_of(nid, replacement);
1192//             }
1193//         }
1194//     }
1195// }
1196
/// Evaluator for a single network operation.
///
/// The `ExecutionStrategy` impls in this module call [`ExecuteOp::execute`] with the
/// subgraph returned by `extract_next_ready_op` and splice the returned graph back in
/// with `splice_descendents_of`.
pub trait ExecuteOp<FL, L, K, FK, Aind>: Sized {
    /// Evaluates `op` on its extracted subgraph `graph` and returns the replacement
    /// subgraph.
    ///
    /// `C` is the contraction strategy to use for products, `lib` the tensor library,
    /// and `fn_lib` the function library.
    ///
    /// # Errors
    /// Returns a [`TensorNetworkError`] when the op cannot be evaluated.
    #[allow(clippy::result_large_err)]
    fn execute<C: ContractionStrategy<Self, L, K, FK, Aind>>(
        &mut self,
        graph: NetworkGraph<K, FK, Aind>,
        lib: &L,
        fn_lib: &FL,
        op: NetworkOp<FK>,
    ) -> Result<NetworkGraph<K, FK, Aind>, TensorNetworkError<K, FK>>
    where
        K: Display,
        FK: Display;
}
1211
1212impl<S, Store: TensorScalarStore, K, FK, Aind: AbsInd> Network<Store, K, FK, Aind>
1213where
1214    Store::Tensor: HasStructure<Structure = S>,
1215{
1216    #[allow(clippy::result_large_err)]
1217    pub fn execute<
1218        Strat: ExecutionStrategy<Store, FL, L, K, FK, Aind>,
1219        C: ContractionStrategy<Store, L, K, FK, Aind>,
1220        LT,
1221        L,
1222        FL,
1223    >(
1224        &mut self,
1225        lib: &L,
1226        fn_lib: &FL,
1227    ) -> Result<(), TensorNetworkError<K, FK>>
1228    where
1229        K: Display + Clone + Debug,
1230        FK: Display + Clone + Debug,
1231        L: Library<S, Key = K, Value = PermutedStructure<LT>> + Sync,
1232        FL: FunctionLibrary<Store::Tensor, Store::Scalar, Key = FK>,
1233        LT: LibraryTensor<WithIndices = Store::Tensor>,
1234        Store: ExecuteOp<FL, L, K, FK, Aind>,
1235    {
1236        self.merge_ops();
1237        // println!("Hi");
1238        // println!("{}", self.graph.dot());
1239        // Ok(())
1240        Strat::execute_all::<C>(&mut self.store, &mut self.graph, lib, fn_lib)
1241    }
1242}
1243
1244impl<
1245    LT: LibraryTensor + Clone,
1246    T: HasStructure
1247        + TensorStructure
1248        + Neg<Output = T>
1249        + Clone
1250        + Ref
1251        + Contract<LCM = T>
1252        + for<'a> AddAssign<T::Ref<'a>>
1253        + for<'a> AddAssign<LT::WithIndices>
1254        + From<LT::WithIndices>,
1255    L: Library<T::Structure, Key = K, Value = PermutedStructure<LT>>,
1256    Sc: Neg<Output = Sc>
1257        + RefOne
1258        + Div<Output = Sc>
1259        + for<'a> AddAssign<Sc::Ref<'a>>
1260        + Clone
1261        + for<'a> AddAssign<T::ScalarRef<'a>>
1262        + From<T::Scalar>
1263        + Ref
1264        + for<'a> MulAssign<Sc::Ref<'a>>,
1265    K: Display + Debug + Clone,
1266    FK: Display + Debug,
1267    FL: FunctionLibrary<T, Sc, Key = FK>,
1268    Aind: AbsInd,
1269> ExecuteOp<FL, L, K, FK, Aind> for NetworkStore<T, Sc>
1270where
1271    LT::WithIndices: PermuteTensor<Permuted = LT::WithIndices>,
1272    <<LT::WithIndices as HasStructure>::Structure as TensorStructure>::Slot:
1273        IsAbstractSlot<Aind = Aind>,
1274{
1275    fn execute<C: ContractionStrategy<Self, L, K, FK, Aind>>(
1276        &mut self,
1277        mut graph: NetworkGraph<K, FK, Aind>,
1278        lib: &L,
1279        fn_lib: &FL,
1280        op: NetworkOp<FK>,
1281    ) -> Result<NetworkGraph<K, FK, Aind>, TensorNetworkError<K, FK>> {
1282        graph.sync_order();
1283        match op {
1284            NetworkOp::Neg => {
1285                let ops = graph
1286                    .graph
1287                    .iter_nodes()
1288                    .find(|(_, _, d)| matches!(d, NetworkNode::Op(NetworkOp::Neg)));
1289
1290                let (opid, children, _) = ops.unwrap();
1291
1292                let mut child = None;
1293                for c in children {
1294                    if let Some(id) = graph.graph.involved_node_id(c)
1295                        && let NetworkNode::Leaf(l) = &graph.graph[id]
1296                    {
1297                        if child.is_some() {
1298                            return Err(TensorNetworkError::MoreThanOneNeg);
1299                        } else {
1300                            child = Some((id, l));
1301                        }
1302                    }
1303                }
1304                if let Some((child_id, leaf)) = child {
1305                    let new_node = match leaf {
1306                        NetworkLeaf::Scalar(s) => {
1307                            let s = self.scalar[*s].clone().neg();
1308                            let pos = self.scalar.len();
1309                            self.scalar.push(s);
1310
1311                            NetworkLeaf::Scalar(pos)
1312                        }
1313                        NetworkLeaf::LibraryKey(_) => {
1314                            let inds = graph.get_lib_data(lib, child_id).unwrap();
1315
1316                            let t = T::from(inds).neg();
1317                            let pos = self.tensors.len();
1318                            self.tensors.push(t);
1319                            NetworkLeaf::LocalTensor(pos)
1320                        }
1321                        NetworkLeaf::LocalTensor(t) => {
1322                            let t = self.tensors[*t].clone().neg();
1323                            let pos = self.tensors.len();
1324                            self.tensors.push(t);
1325                            NetworkLeaf::LocalTensor(pos)
1326                        }
1327                    };
1328                    graph.identify_nodes_without_self_edges(
1329                        &[child_id, opid],
1330                        NetworkNode::Leaf(new_node),
1331                    );
1332                    Ok(graph)
1333                } else {
1334                    Err(TensorNetworkError::ChildlessNeg)
1335                }
1336            }
1337            NetworkOp::Product => {
1338                // println!("Doing Product");
1339                let (graph, _) = C::contract(self, graph, lib)?;
1340                Ok(graph)
1341            }
1342            NetworkOp::Sum => {
1343                // let mut op = None;
1344                let mut targets = Vec::new();
1345                let mut all_nodes = Vec::new();
1346                for (n, _, v) in graph.graph.iter_nodes() {
1347                    all_nodes.push(n);
1348                    if let NetworkNode::Leaf(l) = &v {
1349                        targets.push((n, l));
1350                    }
1351                }
1352
1353                let (nf, first) = &targets[0];
1354
1355                let new_node = match first {
1356                    NetworkLeaf::Scalar(s) => {
1357                        let mut accumulator = self.scalar[*s].clone();
1358
1359                        for (_, t) in &targets[1..] {
1360                            match t {
1361                                NetworkLeaf::Scalar(s) => {
1362                                    accumulator += self.scalar[*s].refer();
1363                                }
1364                                NetworkLeaf::LocalTensor(t) => {
1365                                    if let Some(s) = self.tensors[*t].scalar_ref() {
1366                                        accumulator += s;
1367                                    } else {
1368                                        return Err(TensorNetworkError::NotAllScalars(
1369                                            "".to_string(),
1370                                        ));
1371                                    }
1372                                }
1373                                NetworkLeaf::LibraryKey { .. } => {
1374                                    return Err(TensorNetworkError::ScalarLibSum("".to_string()));
1375                                }
1376                            }
1377                        }
1378
1379                        let pos = self.scalar.len();
1380                        self.scalar.push(accumulator);
1381                        NetworkLeaf::Scalar(pos)
1382                    }
1383                    NetworkLeaf::LocalTensor(t) => {
1384                        let mut accumulator = self.tensors[*t].clone();
1385                        if accumulator.is_scalar() {
1386                            let mut accumulator = Sc::from(accumulator.scalar().unwrap());
1387
1388                            for (_, t) in &targets[1..] {
1389                                match t {
1390                                    NetworkLeaf::Scalar(s) => {
1391                                        accumulator += self.scalar[*s].refer();
1392                                    }
1393                                    NetworkLeaf::LocalTensor(t) => {
1394                                        if let Some(s) = self.tensors[*t].scalar_ref() {
1395                                            accumulator += s;
1396                                        } else {
1397                                            return Err(TensorNetworkError::NotAllScalars(
1398                                                "".to_string(),
1399                                            ));
1400                                        }
1401                                    }
1402                                    NetworkLeaf::LibraryKey { .. } => {
1403                                        return Err(TensorNetworkError::ScalarLibSum(
1404                                            "".to_string(),
1405                                        ));
1406                                    }
1407                                }
1408                            }
1409
1410                            let pos = self.scalar.len();
1411                            self.scalar.push(accumulator);
1412                            NetworkLeaf::Scalar(pos)
1413                        } else {
1414                            for (nid, t) in &targets[1..] {
1415                                match t {
1416                                    NetworkLeaf::Scalar(_) => {
1417                                        return Err(TensorNetworkError::SumScalarTensor(
1418                                            "".to_string(),
1419                                        ));
1420                                    }
1421                                    NetworkLeaf::LocalTensor(t) => {
1422                                        accumulator += self.tensors[*t].refer();
1423                                    }
1424                                    NetworkLeaf::LibraryKey(_) => {
1425                                        let with_index = graph.get_lib_data(lib, *nid).unwrap();
1426
1427                                        accumulator += with_index;
1428                                    }
1429                                }
1430                            }
1431
1432                            let pos = self.tensors.len();
1433                            self.tensors.push(accumulator);
1434
1435                            NetworkLeaf::LocalTensor(pos)
1436                        }
1437                    }
1438                    NetworkLeaf::LibraryKey(_) => {
1439                        let inds = graph.get_lib_data(lib, *nf).unwrap();
1440                        let mut accumulator = T::from(inds);
1441                        for (nid, t) in &targets[1..] {
1442                            match t {
1443                                NetworkLeaf::Scalar(_) => {
1444                                    return Err(TensorNetworkError::SumScalarTensor(
1445                                        "".to_string(),
1446                                    ));
1447                                }
1448                                NetworkLeaf::LocalTensor(t) => {
1449                                    accumulator += self.tensors[*t].refer();
1450                                }
1451                                NetworkLeaf::LibraryKey(_) => {
1452                                    let with = graph.get_lib_data(lib, *nid).unwrap();
1453                                    accumulator += with;
1454                                }
1455                            }
1456                        }
1457
1458                        let pos = self.tensors.len();
1459                        self.tensors.push(accumulator);
1460
1461                        NetworkLeaf::LocalTensor(pos)
1462                    }
1463                };
1464
1465                graph.identify_nodes_without_self_edges(&all_nodes, NetworkNode::Leaf(new_node));
1466                Ok(graph)
1467            }
1468            NetworkOp::Function(f) => {
1469                let ops = graph
1470                    .graph
1471                    .iter_nodes()
1472                    .find(|(_, _, d)| matches!(d, NetworkNode::Op(NetworkOp::Function(_))));
1473
1474                let (opid, children, _) = ops.unwrap();
1475
1476                let mut child = None;
1477                for c in children {
1478                    if let Some(id) = graph.graph.involved_node_id(c)
1479                        && let NetworkNode::Leaf(l) = &graph.graph[id]
1480                    {
1481                        if let Some((nid, _)) = child {
1482                            if nid != id {
1483                                return Err(TensorNetworkError::Other(anyhow!(
1484                                    "Cannot have more than one tensor argument to function"
1485                                )));
1486                            }
1487                        } else {
1488                            child = Some((id, l));
1489                        }
1490                    }
1491                }
1492                if let Some((child_id, leaf)) = child {
1493                    let new_node = match leaf {
1494                        NetworkLeaf::Scalar(s) => {
1495                            let s = self.scalar[*s].clone();
1496                            let pos = self.scalar.len();
1497                            let s = fn_lib.apply_scalar(&f, s)?;
1498                            self.scalar.push(s);
1499
1500                            NetworkLeaf::Scalar(pos)
1501                        }
1502                        NetworkLeaf::LibraryKey(_) => {
1503                            let inds = graph.get_lib_data(lib, child_id).unwrap();
1504                            let t = fn_lib.apply(&f, T::from(inds))?;
1505                            let pos = self.tensors.len();
1506                            self.tensors.push(t);
1507                            NetworkLeaf::LocalTensor(pos)
1508                        }
1509                        NetworkLeaf::LocalTensor(t) => {
1510                            let t = self.tensors[*t].clone();
1511                            let t = fn_lib.apply(&f, t)?;
1512                            let pos = self.tensors.len();
1513                            self.tensors.push(t);
1514                            NetworkLeaf::LocalTensor(pos)
1515                        }
1516                    };
1517                    graph.identify_nodes_without_self_edges(
1518                        &[child_id, opid],
1519                        NetworkNode::Leaf(new_node),
1520                    );
1521                    Ok(graph)
1522                } else {
1523                    Err(TensorNetworkError::ChildlessNeg)
1524                }
1525            }
1526            NetworkOp::Power(_) => {
1527                let mut pow = 0;
1528                let ops = graph.graph.iter_nodes().find(|(_, _, d)| {
1529                    if let NetworkNode::Op(NetworkOp::Power(i)) = d {
1530                        pow = *i;
1531                        true
1532                    } else {
1533                        false
1534                    }
1535                });
1536
1537                let (opid, children, _) = ops.unwrap();
1538
1539                let mut child = None;
1540                for c in children {
1541                    if let Some(id) = graph.graph.involved_node_id(c)
1542                        && let NetworkNode::Leaf(l) = &graph.graph[id]
1543                    {
1544                        if let Some((nid, _)) = child {
1545                            if nid != id {
1546                                return Err(TensorNetworkError::Other(anyhow!(
1547                                    "Cannot have more than one tensor argument to power:{}",
1548                                    graph.dot()
1549                                )));
1550                            }
1551                        } else {
1552                            child = Some((id, l));
1553                        }
1554                    }
1555                }
1556                let n = pow.abs();
1557                if let Some((child_id, leaf)) = child {
1558                    let new_node = match leaf {
1559                        NetworkLeaf::Scalar(si) => {
1560                            if n == 0 {
1561                                NetworkLeaf::Scalar(*si)
1562                            } else {
1563                                let mut s = self.scalar[*si].clone();
1564
1565                                for _ in 1..n {
1566                                    s *= self.scalar[*si].refer();
1567                                }
1568
1569                                if pow < 0 {
1570                                    s = s.ref_one() / s;
1571                                }
1572
1573                                let pos = self.scalar.len();
1574                                self.scalar.push(s);
1575
1576                                NetworkLeaf::Scalar(pos)
1577                            }
1578                        }
1579                        NetworkLeaf::LibraryKey(a) => {
1580                            let inds = graph.get_lib_data(lib, child_id).unwrap();
1581                            let mut t = T::from(inds);
1582
1583                            match pow {
1584                                0 => {
1585                                    let pos = self.scalar.len();
1586                                    let one = self.scalar[0].ref_one();
1587                                    self.scalar.push(one);
1588                                    NetworkLeaf::Scalar(pos)
1589                                }
1590                                1 => NetworkLeaf::LibraryKey(a.clone()),
1591                                _ => {
1592                                    let squares = n / 2;
1593                                    let mut square = t.contract(&t)?;
1594
1595                                    if n % 2 == 1 {
1596                                        if n != 1 {
1597                                            for _ in 0..squares {
1598                                                square = square.contract(&square)?;
1599                                            }
1600                                            t = square.contract(&t)?;
1601                                        }
1602
1603                                        if pow < 0 {
1604                                            if !t.is_scalar() {
1605                                                return Err(
1606                                                    TensorNetworkError::NegativeExponentNonScalar(
1607                                                        "".to_string(),
1608                                                    ),
1609                                                );
1610                                            } else {
1611                                                let mut s = Sc::from(t.scalar().unwrap());
1612                                                let pos = self.scalar.len();
1613                                                s = s.ref_one() / s;
1614                                                self.scalar.push(s);
1615                                                NetworkLeaf::Scalar(pos)
1616                                            }
1617                                        } else {
1618                                            let pos = self.tensors.len();
1619                                            self.tensors.push(t);
1620                                            NetworkLeaf::LocalTensor(pos)
1621                                        }
1622                                    } else {
1623                                        let mut s = Sc::from(square.scalar().unwrap());
1624                                        let sc = s.clone();
1625                                        for _ in 1..squares {
1626                                            s *= sc.refer();
1627                                        }
1628                                        let pos = self.scalar.len();
1629                                        if pow < 0 {
1630                                            s = s.ref_one() / s;
1631                                        }
1632                                        self.scalar.push(s);
1633                                        NetworkLeaf::Scalar(pos)
1634                                    }
1635                                }
1636                            }
1637                        }
1638                        NetworkLeaf::LocalTensor(ti) => {
1639                            let mut t = self.tensors[*ti].clone();
1640                            match pow {
1641                                0 => {
1642                                    let pos = self.scalar.len();
1643                                    let one = self.scalar[0].ref_one();
1644                                    self.scalar.push(one);
1645                                    NetworkLeaf::Scalar(pos)
1646                                }
1647                                1 => NetworkLeaf::LocalTensor(*ti),
1648                                _ => {
1649                                    let squares = n / 2;
1650                                    let mut square = t.contract(&t)?;
1651
1652                                    if n % 2 == 1 {
1653                                        if n != 1 {
1654                                            for _ in 0..squares {
1655                                                square = square.contract(&square)?;
1656                                            }
1657                                            t = square.contract(&t)?;
1658                                        }
1659                                        if pow < 0 {
1660                                            if !t.is_scalar() {
1661                                                return Err(
1662                                                    TensorNetworkError::NegativeExponentNonScalar(
1663                                                        "".to_string(),
1664                                                    ),
1665                                                );
1666                                            } else {
1667                                                let mut s = Sc::from(t.scalar().unwrap());
1668                                                let pos = self.scalar.len();
1669                                                s = s.ref_one() / s;
1670                                                self.scalar.push(s);
1671                                                NetworkLeaf::Scalar(pos)
1672                                            }
1673                                        } else {
1674                                            let pos = self.tensors.len();
1675                                            self.tensors.push(t);
1676                                            NetworkLeaf::LocalTensor(pos)
1677                                        }
1678                                    } else {
1679                                        let mut s = Sc::from(square.scalar().unwrap());
1680                                        let sc = s.clone();
1681                                        for _ in 1..squares {
1682                                            s *= sc.refer();
1683                                        }
1684                                        let pos = self.scalar.len();
1685                                        if pow < 0 {
1686                                            s = s.ref_one() / s;
1687                                        }
1688                                        self.scalar.push(s);
1689                                        NetworkLeaf::Scalar(pos)
1690                                    }
1691                                }
1692                            }
1693                        }
1694                    };
1695                    graph.identify_nodes_without_self_edges(
1696                        &[child_id, opid],
1697                        NetworkNode::Leaf(new_node),
1698                    );
1699                    Ok(graph)
1700                } else {
1701                    Err(TensorNetworkError::ChildlessNeg)
1702                }
1703            }
1704        }
1705    }
1706}
1707
/// Types that can hand out a borrowed "reference form" of themselves.
///
/// `Ref<'a>` is used as the right-hand operand of in-place arithmetic in this module
/// (e.g. the `AddAssign<Sc::Ref<'a>>` / `MulAssign<Sc::Ref<'a>>` bounds), so values can
/// be accumulated without cloning the operand. For `f64` it is simply `&f64`.
pub trait Ref {
    /// The borrowed form; may not outlive `Self`.
    type Ref<'a>
    where
        Self: 'a;
    /// Returns the borrowed form of `self`.
    fn refer(&self) -> Self::Ref<'_>;
}
1714
1715impl Ref for f64 {
1716    type Ref<'a>
1717        = &'a f64
1718    where
1719        Self: 'a;
1720
1721    fn refer(&self) -> Self::Ref<'_> {
1722        self
1723    }
1724}
1725
1726// #[cfg(feature = "shadowing")]
1727// pub mod levels;
1728#[cfg(feature = "shadowing")]
1729pub mod symbolica_interop;
1730
1731#[cfg(test)]
1732mod tests;