wick_invocation_server/
invocation_server.rs

1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use flow_component::SharedComponent;
5use parking_lot::RwLock;
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tokio_stream::StreamExt;
9use tonic::{Response, Status};
10use wick_packet::PacketStream;
11use wick_rpc::rpc::invocation_service_server::InvocationService;
12use wick_rpc::rpc::{InvocationRequest, ListResponse, Packet, StatsResponse};
13use wick_rpc::{rpc, DurationStatistics, Statistics};
14
/// A gRPC server for implementers of [flow_component::Component].
///
/// Every invocation received over the `InvocationService` is dispatched to
/// [`InvocationServer::collection`], and per-operation run statistics are kept
/// in memory so they can be reported by the `stats` RPC.
pub struct InvocationServer {
  /// The component that will handle incoming requests.
  pub collection: SharedComponent,

  /// Execution statistics keyed by the operation id of each invocation.
  stats: RwLock<HashMap<String, Statistics>>,
}
22
23impl std::fmt::Debug for InvocationServer {
24  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25    f.debug_struct("InvocationServer").finish()
26  }
27}
28
29impl InvocationServer {
30  /// Constructor.
31  #[must_use]
32  pub fn new(collection: SharedComponent) -> Self {
33    Self {
34      collection,
35      stats: RwLock::new(HashMap::new()),
36    }
37  }
38}
39
/// Outcome of a single execution, as fed into the per-operation statistics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum JobResult {
  /// The execution finished without reporting an error.
  Success,
  /// The execution reported an error; it is also counted in `Statistics::errors`.
  Error,
}
45
46impl InvocationServer {
47  fn record_execution<T: Into<String>>(&self, job: T, status: JobResult, time: Duration) {
48    let mut stats = self.stats.write();
49    let job = job.into();
50    let stat = stats.entry(job.clone()).or_insert_with(Statistics::default);
51    stat.runs += 1;
52    if status == JobResult::Error {
53      stat.errors += 1;
54    }
55    let durations = if stat.execution_duration.is_some() {
56      let mut durations = stat.execution_duration.take().unwrap();
57      if time < durations.min_time {
58        durations.min_time = time;
59      } else if time > durations.max_time {
60        durations.max_time = time;
61      }
62      let average = ((durations.average_time * (stat.runs - 1)) + time) / stat.runs;
63      durations.average_time = average;
64      let total = durations.total_time + time;
65      durations.total_time = total;
66
67      durations
68    } else {
69      DurationStatistics::new(time, time, time, time)
70    };
71    stat.execution_duration.replace(durations);
72  }
73}
74
75fn convert_invocation_stream(mut streaming: tonic::Streaming<InvocationRequest>) -> PacketStream {
76  let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
77  tokio::spawn(async move {
78    while let Some(p) = streaming.next().await {
79      let result = p.map_err(|e| wick_packet::Error::Component(e.to_string())).map(|p| {
80        p.data.map_or_else(
81          || unreachable!(),
82          |p| match p {
83            rpc::invocation_request::Data::Invocation(_) => unreachable!(),
84            rpc::invocation_request::Data::Packet(p) => wick_packet::Packet::from(p),
85          },
86        )
87      });
88
89      let _ = tx.send(result);
90    }
91  });
92
93  wick_packet::PacketStream::new(Box::new(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)))
94}
95
96#[async_trait::async_trait]
97impl InvocationService for InvocationServer {
98  type InvokeStream = ReceiverStream<Result<Packet, Status>>;
99
100  async fn invoke(
101    &self,
102    request: tonic::Request<tonic::Streaming<InvocationRequest>>,
103  ) -> Result<Response<Self::InvokeStream>, Status> {
104    let start = Instant::now();
105
106    let (tx, rx) = mpsc::channel(4);
107    let mut stream = request.into_inner();
108    let first = stream.next().await;
109    let invocation: wick_packet::InvocationData = if let Some(Ok(inv)) = first {
110      if let Some(rpc::invocation_request::Data::Invocation(inv)) = inv.data {
111        inv
112          .try_into()
113          .map_err(|_| Status::invalid_argument("First message must be a valid invocation"))?
114      } else {
115        return Err(Status::invalid_argument("First message must be an invocation"));
116      }
117    } else {
118      return Err(Status::invalid_argument("First message must be an invocation"));
119    };
120    let stream = convert_invocation_stream(stream);
121    let packet_stream = PacketStream::new(Box::new(stream));
122    let invocation = invocation.with_stream(packet_stream);
123
124    let op_id = invocation.target().operation_id().to_owned();
125
126    let result = self
127      .collection
128      .handle(invocation, Default::default(), Default::default())
129      .await;
130    if let Err(e) = result {
131      let message = e.to_string();
132      error!("invocation failed: {}", message);
133      tx.send(Err(Status::internal(message))).await.unwrap();
134      self.record_execution(op_id, JobResult::Error, start.elapsed());
135    } else {
136      tokio::spawn(async move {
137        let mut receiver = result.unwrap();
138        while let Some(next) = receiver.next().await {
139          if next.is_err() {
140            todo!("Handle error");
141          }
142          let next = next.unwrap();
143
144          tx.send(Ok(next.into())).await.unwrap();
145        }
146      });
147      self.record_execution(op_id, JobResult::Success, start.elapsed());
148    }
149
150    Ok(Response::new(ReceiverStream::new(rx)))
151  }
152
153  async fn list(&self, _request: tonic::Request<rpc::ListRequest>) -> Result<Response<ListResponse>, Status> {
154    let response = ListResponse {
155      components: vec![self.collection.signature().clone().try_into().unwrap()],
156    };
157    Ok(Response::new(response))
158  }
159
160  async fn stats(&self, _request: tonic::Request<rpc::StatsRequest>) -> Result<Response<StatsResponse>, Status> {
161    Ok(Response::new(StatsResponse {
162      stats: self.stats.read().values().cloned().map(From::from).collect(),
163    }))
164  }
165}
166
#[cfg(test)]
mod tests {
  // Intentionally empty: this server is exercised by the tests in the workspace root,
  // which run it against a native component.
}