tc_scalar/reference/
while.rs

1//! Resolve a `Closure` repeatedly while a condition is met.
2
3use std::collections::HashSet;
4use std::fmt;
5
6use async_hash::{Digest, Hash, Output};
7use async_trait::async_trait;
8use destream::{de, en};
9use futures::try_join;
10use get_size::GetSize;
11use get_size_derive::*;
12use log::debug;
13use safecast::{Match, TryCastFrom, TryCastInto};
14
15use tc_error::*;
16use tc_transact::public::{ClosureInstance, Public, StateInstance, ToState};
17use tc_value::Value;
18use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf};
19
20use crate::{OpDef, Scalar, Scope};
21
22use super::Refer;
23
24/// A while loop.
25#[derive(Clone, Eq, PartialEq, GetSize)]
26pub struct While {
27    cond: Scalar,
28    closure: Scalar,
29    state: Scalar,
30}
31
32#[async_trait]
33impl<State> Refer<State> for While
34where
35    State: StateInstance + Refer<State> + From<Scalar>,
36    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
37    Map<State>: TryFrom<State, Error = TCError>,
38    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
39    bool: TryCastFrom<State>,
40{
41    fn dereference_self(self, path: &TCPathBuf) -> Self {
42        Self {
43            cond: self.cond.dereference_self(path),
44            closure: self.closure.dereference_self(path),
45            state: self.state.dereference_self(path),
46        }
47    }
48
49    fn is_conditional(&self) -> bool {
50        self.closure.is_conditional()
51    }
52
53    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
54        self.cond.is_inter_service_write(cluster_path)
55            || self.closure.is_inter_service_write(cluster_path)
56            || self.state.is_inter_service_write(cluster_path)
57    }
58
59    fn is_ref(&self) -> bool {
60        true
61    }
62
63    fn reference_self(self, path: &TCPathBuf) -> Self {
64        Self {
65            cond: self.cond.reference_self(path),
66            closure: self.closure.reference_self(path),
67            state: self.state.reference_self(path),
68        }
69    }
70
71    fn requires(&self, deps: &mut HashSet<Id>) {
72        self.cond.requires(deps);
73        self.closure.requires(deps);
74        self.state.requires(deps);
75    }
76
77    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
78        self,
79        context: &'a Scope<'a, State, T>,
80        txn: &'a State::Txn,
81    ) -> TCResult<State> {
82        debug!("While::resolve {:?}", self);
83
84        if self.cond.is_conditional() {
85            return Err(bad_request!(
86                "While does not allow a nested conditional {:?}",
87                self.cond,
88            ));
89        } else if self.state.is_conditional() {
90            return Err(bad_request!(
91                "While does not allow a nested conditional {:?}",
92                self.state,
93            ));
94        }
95
96        let (cond, closure, mut state) = try_join!(
97            self.cond.resolve(context, txn),
98            self.closure.resolve(context, txn),
99            self.state.resolve(context, txn)
100        )?;
101
102        debug!("While condition definition is {:?}", cond);
103
104        loop {
105            let mut cond = cond.clone();
106            let still_going = loop {
107                let cond_op = State::Closure::try_cast_from(cond.clone(), |s| {
108                    bad_request!("expected an Op or Closure for a While loop but found {s:?}")
109                })?;
110
111                let intermediate = Box::new(cond_op).call(txn.clone(), state.clone()).await?;
112                if intermediate.is_ref() {
113                    cond = intermediate;
114                } else {
115                    break bool::try_cast_from(intermediate, |s| {
116                        bad_request!("expected a boolean condition but found {s:?}")
117                    })?;
118                }
119            };
120
121            if still_going {
122                let while_op: State::Closure = closure.clone().try_cast_into(|s| {
123                    bad_request!("expected an Op or Closure for a While loop but found {s:?}")
124                })?;
125
126                state = Box::new(while_op).call(txn.clone(), state).await?;
127
128                if state.is_conditional() {
129                    return Err(bad_request!(
130                        "conditional state {state:?} is not allowed in a While loop",
131                    ));
132                }
133
134                debug!("While loop state is {state:?}");
135            } else {
136                break Ok(state);
137            }
138        }
139    }
140}
141
142impl<'a, D: Digest> Hash<D> for &'a While {
143    fn hash(self) -> Output<D> {
144        Hash::<D>::hash((&self.cond, &self.closure, &self.state))
145    }
146}
147
148impl TryCastFrom<Scalar> for While {
149    fn can_cast_from(scalar: &Scalar) -> bool {
150        scalar.matches::<(Scalar, Scalar, Scalar)>()
151    }
152
153    fn opt_cast_from(scalar: Scalar) -> Option<Self> {
154        if scalar.matches::<(Scalar, Scalar, Scalar)>() {
155            scalar.opt_cast_into().map(|(cond, closure, state)| Self {
156                cond,
157                closure,
158                state,
159            })
160        } else {
161            None
162        }
163    }
164}
165
166#[async_trait]
167impl de::FromStream for While {
168    type Context = ();
169
170    async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
171        let while_loop = Scalar::from_stream(context, decoder).await?;
172        Self::try_cast_from(while_loop, |s| {
173            de::Error::invalid_value(format!("{s:?}"), "a While loop")
174        })
175    }
176}
177
178impl<'en> en::IntoStream<'en> for While {
179    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
180        (self.cond, self.closure, self.state).into_stream(encoder)
181    }
182}
183
184impl<'en> en::ToStream<'en> for While {
185    fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
186        en::IntoStream::into_stream((&self.cond, &self.closure, &self.state), encoder)
187    }
188}
189
190impl fmt::Debug for While {
191    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192        write!(
193            f,
194            "while {:?} call {:?} with state {:?}",
195            self.cond, self.closure, self.state
196        )
197    }
198}