1use std::collections::HashSet;
4use std::fmt;
5
6use async_hash::{Digest, Hash, Output};
7use async_trait::async_trait;
8use destream::{de, en};
9use futures::try_join;
10use get_size::GetSize;
11use get_size_derive::*;
12use log::debug;
13use safecast::{Match, TryCastFrom, TryCastInto};
14
15use tc_error::*;
16use tc_transact::public::{ClosureInstance, Public, StateInstance, ToState};
17use tc_value::Value;
18use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf};
19
20use crate::{OpDef, Scalar, Scope};
21
22use super::Refer;
23
/// A loop reference: while calling `cond` with the current state yields `true`,
/// call `closure` with the current state to produce the next state.
#[derive(Clone, Eq, PartialEq, GetSize)]
pub struct While {
    /// The loop condition; once resolved it must cast to an `Op` or `Closure`,
    /// which is called with the current state and must yield a boolean.
    cond: Scalar,
    /// The loop body; once resolved it must cast to an `Op` or `Closure`,
    /// which is called with the current state and returns the next state.
    closure: Scalar,
    /// The initial state of the loop.
    state: Scalar,
}
31
32#[async_trait]
33impl<State> Refer<State> for While
34where
35 State: StateInstance + Refer<State> + From<Scalar>,
36 State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
37 Map<State>: TryFrom<State, Error = TCError>,
38 Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
39 bool: TryCastFrom<State>,
40{
41 fn dereference_self(self, path: &TCPathBuf) -> Self {
42 Self {
43 cond: self.cond.dereference_self(path),
44 closure: self.closure.dereference_self(path),
45 state: self.state.dereference_self(path),
46 }
47 }
48
49 fn is_conditional(&self) -> bool {
50 self.closure.is_conditional()
51 }
52
53 fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
54 self.cond.is_inter_service_write(cluster_path)
55 || self.closure.is_inter_service_write(cluster_path)
56 || self.state.is_inter_service_write(cluster_path)
57 }
58
59 fn is_ref(&self) -> bool {
60 true
61 }
62
63 fn reference_self(self, path: &TCPathBuf) -> Self {
64 Self {
65 cond: self.cond.reference_self(path),
66 closure: self.closure.reference_self(path),
67 state: self.state.reference_self(path),
68 }
69 }
70
71 fn requires(&self, deps: &mut HashSet<Id>) {
72 self.cond.requires(deps);
73 self.closure.requires(deps);
74 self.state.requires(deps);
75 }
76
77 async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
78 self,
79 context: &'a Scope<'a, State, T>,
80 txn: &'a State::Txn,
81 ) -> TCResult<State> {
82 debug!("While::resolve {:?}", self);
83
84 if self.cond.is_conditional() {
85 return Err(bad_request!(
86 "While does not allow a nested conditional {:?}",
87 self.cond,
88 ));
89 } else if self.state.is_conditional() {
90 return Err(bad_request!(
91 "While does not allow a nested conditional {:?}",
92 self.state,
93 ));
94 }
95
96 let (cond, closure, mut state) = try_join!(
97 self.cond.resolve(context, txn),
98 self.closure.resolve(context, txn),
99 self.state.resolve(context, txn)
100 )?;
101
102 debug!("While condition definition is {:?}", cond);
103
104 loop {
105 let mut cond = cond.clone();
106 let still_going = loop {
107 let cond_op = State::Closure::try_cast_from(cond.clone(), |s| {
108 bad_request!("expected an Op or Closure for a While loop but found {s:?}")
109 })?;
110
111 let intermediate = Box::new(cond_op).call(txn.clone(), state.clone()).await?;
112 if intermediate.is_ref() {
113 cond = intermediate;
114 } else {
115 break bool::try_cast_from(intermediate, |s| {
116 bad_request!("expected a boolean condition but found {s:?}")
117 })?;
118 }
119 };
120
121 if still_going {
122 let while_op: State::Closure = closure.clone().try_cast_into(|s| {
123 bad_request!("expected an Op or Closure for a While loop but found {s:?}")
124 })?;
125
126 state = Box::new(while_op).call(txn.clone(), state).await?;
127
128 if state.is_conditional() {
129 return Err(bad_request!(
130 "conditional state {state:?} is not allowed in a While loop",
131 ));
132 }
133
134 debug!("While loop state is {state:?}");
135 } else {
136 break Ok(state);
137 }
138 }
139 }
140}
141
142impl<'a, D: Digest> Hash<D> for &'a While {
143 fn hash(self) -> Output<D> {
144 Hash::<D>::hash((&self.cond, &self.closure, &self.state))
145 }
146}
147
148impl TryCastFrom<Scalar> for While {
149 fn can_cast_from(scalar: &Scalar) -> bool {
150 scalar.matches::<(Scalar, Scalar, Scalar)>()
151 }
152
153 fn opt_cast_from(scalar: Scalar) -> Option<Self> {
154 if scalar.matches::<(Scalar, Scalar, Scalar)>() {
155 scalar.opt_cast_into().map(|(cond, closure, state)| Self {
156 cond,
157 closure,
158 state,
159 })
160 } else {
161 None
162 }
163 }
164}
165
166#[async_trait]
167impl de::FromStream for While {
168 type Context = ();
169
170 async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
171 let while_loop = Scalar::from_stream(context, decoder).await?;
172 Self::try_cast_from(while_loop, |s| {
173 de::Error::invalid_value(format!("{s:?}"), "a While loop")
174 })
175 }
176}
177
178impl<'en> en::IntoStream<'en> for While {
179 fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
180 (self.cond, self.closure, self.state).into_stream(encoder)
181 }
182}
183
184impl<'en> en::ToStream<'en> for While {
185 fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
186 en::IntoStream::into_stream((&self.cond, &self.closure, &self.state), encoder)
187 }
188}
189
190impl fmt::Debug for While {
191 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
192 write!(
193 f,
194 "while {:?} call {:?} with state {:?}",
195 self.cond, self.closure, self.state
196 )
197 }
198}