tc_scalar/reference/
with.rs

1//! Limits the execution scope of an inline `Op`.
2
3use std::collections::HashSet;
4use std::fmt;
5use std::ops::Deref;
6
7use async_hash::{Digest, Hash, Output};
8use async_trait::async_trait;
9use destream::{de, en};
10use futures::future::TryFutureExt;
11use get_size::GetSize;
12use get_size_derive::*;
13use log::debug;
14use safecast::{TryCastFrom, TryCastInto};
15
16use tc_error::*;
17use tc_transact::public::{Public, StateInstance, ToState};
18use tc_value::Value;
19use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf, Tuple};
20
21use crate::{OpDef, Scalar, Scope, SELF};
22
23use super::Refer;
24
25/// A flow control operator which closes over the context of an [`OpDef`] to produce a closure.
26#[derive(Clone, Eq, PartialEq, GetSize)]
27pub struct With {
28    capture: Tuple<Id>,
29    op: OpDef,
30}
31
32impl With {
33    pub fn new(capture: Tuple<Id>, op: OpDef) -> Self {
34        With { capture, op }
35    }
36}
37
38#[async_trait]
39impl<State> Refer<State> for With
40where
41    State: StateInstance + Refer<State> + From<Scalar>,
42    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
43    Map<State>: TryFrom<State, Error = TCError>,
44    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
45    bool: TryCastFrom<State>,
46{
47    fn dereference_self(self, path: &TCPathBuf) -> Self {
48        Self {
49            capture: self.capture.into_iter().filter(|id| id != &SELF).collect(),
50            op: self.op.dereference_self::<State>(path),
51        }
52    }
53
54    fn is_conditional(&self) -> bool {
55        false
56    }
57
58    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
59        self.op.is_inter_service_write::<State>(cluster_path)
60    }
61
62    fn is_ref(&self) -> bool {
63        true
64    }
65
66    fn reference_self(self, path: &TCPathBuf) -> Self {
67        let before = self.op.clone();
68        let op = self.op.reference_self::<State>(path);
69        let capture = if op == before {
70            self.capture
71        } else {
72            let mut capture = self.capture;
73            capture.push(SELF.into());
74            capture
75        };
76
77        Self { capture, op }
78    }
79
80    fn requires(&self, deps: &mut HashSet<Id>) {
81        deps.extend(self.capture.iter().filter(|id| *id != &SELF).cloned())
82    }
83
84    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
85        self,
86        context: &'a Scope<'a, State, T>,
87        _txn: &'a State::Txn,
88    ) -> TCResult<State> {
89        let closed_over = self
90            .capture
91            .into_iter()
92            .map(|id| {
93                context.resolve_id(&id).map(|state| {
94                    debug!("closure captured {}: {:?}", id, state);
95                    (id, state)
96                })
97            })
98            .collect::<TCResult<Map<State>>>()?;
99
100        Ok(State::Closure::from((closed_over, self.op)).into())
101    }
102}
103
104impl<'a, D: Digest> Hash<D> for &'a With {
105    fn hash(self) -> Output<D> {
106        Hash::<D>::hash((self.capture.deref(), &self.op))
107    }
108}
109
110impl TryCastFrom<Scalar> for With {
111    fn can_cast_from(scalar: &Scalar) -> bool {
112        if let Scalar::Tuple(tuple) = scalar {
113            if tuple.len() == 2 {
114                if !OpDef::can_cast_from(&tuple[1]) {
115                    return false;
116                }
117
118                return match &tuple[0] {
119                    Scalar::Tuple(capture) => capture.iter().all(Id::can_cast_from),
120                    Scalar::Value(Value::Tuple(capture)) => capture.iter().all(Id::can_cast_from),
121                    _ => false,
122                };
123            }
124        }
125
126        false
127    }
128
129    fn opt_cast_from(scalar: Scalar) -> Option<Self> {
130        let (capture, op): (Scalar, OpDef) = scalar.opt_cast_into()?;
131        let capture = match capture {
132            Scalar::Tuple(capture) => capture
133                .into_iter()
134                .map(Id::opt_cast_from)
135                .collect::<Option<Tuple<Id>>>(),
136
137            Scalar::Value(Value::Tuple(capture)) => capture
138                .into_iter()
139                .map(Id::opt_cast_from)
140                .collect::<Option<Tuple<Id>>>(),
141
142            _ => None,
143        }?;
144
145        Some(Self { capture, op })
146    }
147}
148
149#[async_trait]
150impl de::FromStream for With {
151    type Context = ();
152
153    async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
154        de::FromStream::from_stream(context, decoder)
155            .map_ok(|(capture, op)| Self { capture, op })
156            .await
157    }
158}
159
160impl<'en> en::IntoStream<'en> for With {
161    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
162        (self.capture, self.op).into_stream(encoder)
163    }
164}
165
166impl<'en> en::ToStream<'en> for With {
167    fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
168        en::IntoStream::into_stream((&self.capture, &self.op), encoder)
169    }
170}
171
172impl fmt::Debug for With {
173    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
174        write!(f, "with {:?}: {:?}", self.capture, self.op)
175    }
176}