1use crate::reader::run_reader_thread;
2use crate::request_factory::{self, *};
3use ads_proto::error::AdsError;
4use ads_proto::proto::ads_state::AdsState;
5use ads_proto::proto::ads_transition_mode::AdsTransMode;
6use ads_proto::proto::ams_address::{AmsAddress, AmsNetId};
7use ads_proto::proto::ams_header::{AmsHeader, AmsTcpHeader};
8use ads_proto::proto::proto_traits::*;
9use ads_proto::proto::request::{
10 ReadDeviceInfoRequest, ReadRequest, ReadStateRequest, Request, WriteRequest,
11};
12use ads_proto::proto::response::Response;
13use ads_proto::proto::response::*;
14use ads_proto::proto::state_flags::StateFlags;
15use ads_proto::proto::sumup::sumup_request::{SumupReadRequest, SumupWriteRequest};
16use ads_proto::proto::sumup::sumup_response::{SumupReadResponse, SumupWriteResponse};
17use anyhow::Error;
18use anyhow::{anyhow, Result};
19use byteorder::{LittleEndian, ReadBytesExt};
20use std::collections::HashMap;
21use std::io::Write;
22use std::net::{Ipv4Addr, Shutdown, SocketAddr, TcpStream};
23use std::str::FromStr;
24use std::sync::mpsc::{channel, Receiver, Sender};
25use std::time::Duration;
26
/// UDP port of the ADS server (not used by `Client` in this module).
pub const ADS_UDP_SERVER_PORT: u16 = 48_899;
/// TCP port of the ADS router; `Client` connects to this port on the route (or 127.0.0.1).
pub const ADS_TCP_SERVER_PORT: u16 = 48_898;
/// TCP port of the secure ADS server (not used by `Client` in this module).
pub const ADS_SECURE_TCP_SERVER_PORT: u16 = 8_016;
33
34pub type ClientResult<T> = Result<T, anyhow::Error>;
35type TxGeneral = Sender<(u32, Sender<ClientResult<Response>>)>;
36type TxNotification = Sender<(u32, Sender<ClientResult<AdsNotificationStream>>)>;
37type TxStreamUpdate = Sender<TcpStream>;
38
/// Synchronous ADS client talking to an AMS router over TCP.
///
/// Requests are written directly to `stream`; responses and notifications are delivered by a
/// separate reader thread, with which this struct communicates through the `tx_*` channels.
#[derive(Debug)]
pub struct Client {
    /// Router to connect to. `None` means the local router on 127.0.0.1, in which case an
    /// AMS port is requested from it during `connect`.
    route: Option<Ipv4Addr>,
    /// AMS address every request header is addressed to (field name is a typo of "target").
    ams_targed_address: AmsAddress,
    /// Our own AMS address. All-zero until `connect` fills it in, either from the router's
    /// port reply or from the socket's local address when a route is set.
    ams_source_address: AmsAddress,
    /// TCP connection to the router; `None` until connected, reset when the reader reports
    /// `ErrPortNotConnected`.
    stream: Option<TcpStream>,
    /// Invoke id of the most recently built request header; incremented per request.
    invoke_id: u32,
    /// Registers `(invoke_id, reply sender)` pairs with the reader thread.
    tx_general: Option<TxGeneral>,
    /// Registers `(notification handle, sender)` pairs with the reader thread.
    tx_notification: Option<TxNotification>,
    /// Hands a fresh stream clone to the already running reader thread on reconnect.
    tx_stream_update: Option<TxStreamUpdate>,
    /// Value returned by `run_reader_thread` when the reader thread was started.
    thread_started: bool,
    /// Cache: variable name -> symbol handle.
    handle_list: HashMap<String, u32>,
    /// Variable name -> notification handle of active device notifications.
    notification_handle_list: HashMap<String, u32>,
}
53
54impl Drop for Client {
55 fn drop(&mut self) {
56 if let Some(s) = &self.stream {
57 let _ = s.shutdown(Shutdown::Both);
58 }
59 }
60}
61
62impl Client {
63 pub fn new(ams_targed_address: AmsAddress, route: Option<Ipv4Addr>) -> Self {
66 Client {
67 route,
68 ams_targed_address,
69 ams_source_address: AmsAddress::new(AmsNetId::from([0, 0, 0, 0, 0, 0]), 0),
70 stream: None,
71 invoke_id: 0,
72 tx_general: None,
73 tx_notification: None,
74 tx_stream_update: None,
75 thread_started: false,
76 handle_list: HashMap::new(),
77 notification_handle_list: HashMap::new(),
78 }
79 }
80
81 pub fn connect(&mut self) -> ClientResult<ReadStateResponse> {
84 if self.stream.is_none() {
85 self.stream = Some(self.create_stream()?);
86 if self.route.is_none() {
87 self.open_local_port()?;
88 }
89 }
90
91 if let Some(stream) = &self.stream {
92 if self.route.is_some() {
93 self.ams_source_address
94 .update_from_socket_addr(stream.local_addr()?)?;
95 }
96
97 if !self.thread_started {
98 let (tx, rx) = channel::<(u32, Sender<ClientResult<Response>>)>();
99 let (tx_not, rx_not) =
100 channel::<(u32, Sender<ClientResult<AdsNotificationStream>>)>();
101 let (tx_tcp, rx_tcp) = channel::<TcpStream>();
102 self.tx_general = Some(tx);
103 self.tx_notification = Some(tx_not);
104 self.tx_stream_update = Some(tx_tcp);
105 self.thread_started = run_reader_thread(stream.try_clone()?, rx, rx_not, rx_tcp)?;
106 } else if let Some(tx) = &self.tx_stream_update {
107 tx.send(stream.try_clone()?)?;
108 }
109 self.read_state()
111 } else {
112 Err(anyhow!(AdsError::ErrPortNotConnected))
113 }
114 }
115
116 fn create_stream(&mut self) -> ClientResult<TcpStream> {
118 let mut route = Ipv4Addr::from_str("127.0.0.1")?;
119 if let Some(r) = self.route {
120 route = r;
121 }
122
123 let stream = TcpStream::connect(SocketAddr::from((route, ADS_TCP_SERVER_PORT)))?;
124 stream.set_nodelay(true)?;
125 stream.set_write_timeout(Some(Duration::from_millis(1000)))?;
126 stream.set_read_timeout(Some(Duration::from_millis(1000)))?;
127 Ok(stream)
128 }
129
130 fn open_local_port(&mut self) -> ClientResult<()> {
132 let request_port_msg = [0, 16, 2, 0, 0, 0, 0, 0];
133 let mut buf = [0; 14];
134
135 if let Some(s) = &mut self.stream {
136 s.write_all(&request_port_msg).unwrap();
137 use std::io::Read;
138 s.read_exact(&mut buf)?;
139 let (_, mut buf_split) = buf.split_at(6);
140 let ams_address = AmsAddress::read_from(&mut buf_split);
141 self.ams_source_address = ams_address.unwrap();
142 }
143 Ok(())
144 }
145
146 pub fn request(&mut self, request: Request) -> ClientResult<Response> {
150 let rx = self.request_rx(request)?;
151 let response = rx.recv()?;
152 self.check_tcp_stream(&response);
153 response
154 }
155
156 pub fn request_rx(&mut self, request: Request) -> ClientResult<Receiver<Result<Response>>> {
159 let ams_header = self.new_tcp_ams_request_header(request);
160 let (tx, rx) = channel::<ClientResult<Response>>();
161 self.get_general_tx()?
162 .send((self.invoke_id, tx))
163 .expect("Failed to send request to thread by mpsc channel");
164 let mut buffer = Vec::new();
165
166 ams_header.write_to(&mut buffer)?;
167
168 if let Some(s) = &mut self.stream {
169 s.write_all(&buffer)?;
170 return Ok(rx);
171 }
172 Err(anyhow!(AdsError::AdsErrClientPortNotOpen))
173 }
174
175 pub fn read_by_name(&mut self, var_name: &str, len: u32) -> ClientResult<ReadResponse> {
178 let handle = self.get_var_handle(var_name)?;
179 let request = Request::Read(request_factory::get_read_request(handle, len));
180 let response = self.request(request)?;
181 let read_response: ReadResponse = response.try_into()?;
182 Ok(read_response)
183 }
184
185 pub fn sumup_read_by_name(
188 &mut self,
189 var_list: &HashMap<String, u32>,
190 ) -> ClientResult<HashMap<String, ReadResponse>> {
191 let mut requests: Vec<ReadRequest> = Vec::new();
192 let mut var_names: Vec<String> = Vec::new();
193 for name in var_list.keys() {
194 var_names.push(name.clone());
195 }
196
197 let handles = self.sumup_get_var_handle(&var_names)?;
198 for (var, length) in var_list {
199 if let Some(h) = handles.get(var) {
200 requests.push(get_read_request(*h, *length));
201 }
202 }
203
204 let mut buf = Vec::new();
205 let sumup_request = SumupReadRequest::new(requests);
206 sumup_request.write_to(&mut buf)?;
207 let request = Request::ReadWrite(get_sumup_read_request(
208 sumup_request.request_count(),
209 sumup_request.expected_response_len(),
210 buf,
211 ));
212 let response = self.request(request)?;
213 let read_write_response: ReadWriteResponse = response.try_into()?;
214 let sumup_read_response =
215 SumupReadResponse::read_from(&mut read_write_response.data.as_slice())?;
216 let mut result: HashMap<String, ReadResponse> = HashMap::new();
217
218 if read_write_response.result == AdsError::ErrNoError {
219 for (n, (name, _)) in var_list.iter().enumerate() {
220 result.insert(name.clone(), sumup_read_response.read_responses[n].clone());
221 }
222 } else {
223 return Err(anyhow![read_write_response.result]);
224 }
225 Ok(result)
226 }
227
228 pub fn write_by_name(&mut self, var_name: &str, data: Vec<u8>) -> ClientResult<WriteResponse> {
231 let handle = self.get_var_handle(var_name)?;
232 let request = Request::Write(request_factory::get_write_request(handle, data));
233 let response = self.request(request)?;
234 let write_response: WriteResponse = response.try_into()?;
235 Ok(write_response)
236 }
237
238 pub fn sumup_write_by_name(
241 &mut self,
242 var_list: HashMap<String, Vec<u8>>,
243 ) -> ClientResult<HashMap<String, WriteResponse>> {
244 let mut requests: Vec<WriteRequest> = Vec::new();
245 for (varname, data) in &var_list {
246 let handle = self.get_var_handle(varname.as_str())?;
247 requests.push(get_write_request(handle, data.clone()));
248 }
249
250 let mut buf = Vec::new();
251 let sumup_request = SumupWriteRequest::new(requests);
252 sumup_request.write_to(&mut buf)?;
253 let request = Request::ReadWrite(get_sumup_write_request(
254 sumup_request.request_count(),
255 sumup_request.expected_response_len(),
256 buf,
257 ));
258 let response = self.request(request)?;
259 let read_write_response: ReadWriteResponse = response.try_into()?;
260 let sumup_write_response =
261 SumupWriteResponse::read_from(&mut read_write_response.data.as_slice())?;
262 let mut result: HashMap<String, WriteResponse> = HashMap::new();
263
264 if read_write_response.result == AdsError::ErrNoError {
265 for (n, (name, _)) in var_list.iter().enumerate() {
266 result.insert(
267 name.clone(),
268 sumup_write_response.write_responses[n].clone(),
269 );
270 }
271 } else {
272 return Err(anyhow![read_write_response.result]);
273 }
274 Ok(result)
275 }
276
277 pub fn read_device_info(&mut self) -> ClientResult<ReadDeviceInfoResponse> {
280 let request = Request::ReadDeviceInfo(ReadDeviceInfoRequest::new());
281 let response = self.request(request)?;
282 let device_info_response: ReadDeviceInfoResponse = response.try_into()?;
283 Ok(device_info_response)
284 }
285
286 pub fn read_state(&mut self) -> ClientResult<ReadStateResponse> {
289 let request = Request::ReadState(ReadStateRequest::new());
290 let response = self.request(request)?;
291 let device_state: ReadStateResponse = response.try_into()?;
292 Ok(device_state)
293 }
294
295 pub fn write_control(
298 &mut self,
299 ads_state: AdsState,
300 device_state: u16,
301 ) -> ClientResult<WriteControlResponse> {
302 let request = Request::WriteControl(request_factory::get_write_control_request(
303 ads_state,
304 device_state,
305 ));
306 let response = self.request(request)?;
307 let write_control_response: WriteControlResponse = response.try_into()?;
308 Ok(write_control_response)
309 }
310
311 pub fn read_write(
314 &mut self,
315 index_offset: u32,
316 read_len: u32,
317 write_data: Vec<u8>,
318 ) -> ClientResult<ReadWriteResponse> {
319 let request = Request::ReadWrite(request_factory::get_read_write_request(
320 index_offset,
321 read_len,
322 write_data,
323 ));
324 let response = self.request(request)?;
325 let read_write_response: ReadWriteResponse = response.try_into()?;
326 Ok(read_write_response)
327 }
328
329 pub fn add_device_notification(
332 &mut self,
333 var_name: &str,
334 length: u32,
335 transmission_mode: AdsTransMode,
336 max_delay: u32,
337 cycle_time: u32,
338 ) -> ClientResult<Receiver<Result<AdsNotificationStream, Error>>> {
339 let handle = self.get_var_handle(var_name)?;
340 let request = Request::AddDeviceNotification(request_factory::get_add_device_notification(
341 handle,
342 length,
343 transmission_mode,
344 max_delay,
345 cycle_time,
346 ));
347
348 let response: AddDeviceNotificationResponse = self.request(request)?.try_into()?;
350 let handle = response.notification_handle;
351 let (tx, rx) = channel::<ClientResult<AdsNotificationStream>>();
353 self.get_notification_tx()?
355 .send((handle, tx))
356 .expect("Failed to send request to thread by mpsc channel");
357
358 self.notification_handle_list
359 .insert(var_name.to_string(), handle);
360 Ok(rx)
361 }
362
363 pub fn delete_device_notification(
366 &mut self,
367 var_name: &str,
368 ) -> ClientResult<DeleteDeviceNotificationResponse> {
369 let handle;
370 if let Some(h) = self.notification_handle_list.get(var_name) {
371 handle = *h;
372 let request = Request::DeleteDeviceNotification(
373 request_factory::get_delete_device_notification(handle),
374 );
375 let response = self.request(request)?;
376 let response: DeleteDeviceNotificationResponse = response.try_into()?;
377 self.notification_handle_list.remove(var_name);
378 return Ok(response);
379 }
380 Err(anyhow!(AdsError::AdsErrDeviceSymbolNotFound)) }
382
383 fn get_var_handle(&mut self, var_name: &str) -> ClientResult<u32> {
384 if let Some(handle) = self.handle_list.get(var_name) {
385 Ok(*handle)
386 } else {
387 let handle = self.request_var_handle(var_name)?;
388 self.handle_list.insert(var_name.to_string(), handle);
389 Ok(handle)
390 }
391 }
392
393 fn sumup_get_var_handle(
394 &mut self,
395 var_names: &Vec<String>,
396 ) -> ClientResult<HashMap<String, u32>> {
397 let mut do_request: Vec<String> = Vec::new();
398 let mut handles: HashMap<String, u32> = HashMap::new();
399 for var in var_names {
400 if let Some(handle) = self.handle_list.get(var) {
401 handles.insert(var.clone(), *handle);
402 } else {
403 do_request.push(var.clone());
404 }
405 }
406
407 if !do_request.is_empty() {
408 let requested_handles = self.sumup_request_var_handle(&do_request)?;
409 for (name, handle) in requested_handles {
410 self.handle_list.insert(name.clone(), handle);
411 handles.insert(name, handle);
412 }
413 }
414 Ok(handles)
415 }
416
417 fn request_var_handle(&mut self, var_name: &str) -> ClientResult<u32> {
419 let request = Request::ReadWrite(get_var_handle_request(var_name));
420 let response: ReadWriteResponse = self.request(request)?.try_into()?;
421
422 if response.length == 4 {
423 return Ok(response.data.as_slice().read_u32::<LittleEndian>()?);
424 }
425 Err(anyhow!(
426 "Failed to get var handle! Variable {} not found!",
427 var_name
428 ))
429 }
430
431 fn sumup_request_var_handle(
435 &mut self,
436 var_list: &Vec<String>,
437 ) -> ClientResult<HashMap<String, u32>> {
438 let mut result: HashMap<String, u32> = HashMap::new();
439 for var in var_list {
440 let response = self.request(Request::ReadWrite(get_var_handle_request(var)))?;
441 let handle: ReadWriteResponse = response.try_into()?;
442 let handle = handle.data.as_slice().read_u32::<LittleEndian>()?;
443 result.insert(var.clone(), handle);
444 }
445 Ok(result)
446 }
447
448 pub fn release_handle(&mut self, var_name: &str) -> ClientResult<WriteResponse> {
450 if let Some(handle) = self.handle_list.get(var_name) {
451 let request = Request::Write(request_factory::get_release_handle_request(*handle));
452 let response = self.request(request)?;
453 let response: WriteResponse = response.try_into()?;
454 self.handle_list.remove(var_name);
455 return Ok(response);
456 }
457 Err(anyhow!("Handle not available"))
458 }
459
460 fn new_tcp_ams_request_header(&mut self, request: Request) -> AmsTcpHeader {
462 self.invoke_id += 1;
463 AmsTcpHeader::from(AmsHeader::new(
464 self.ams_targed_address.clone(),
465 self.ams_source_address.clone(),
466 StateFlags::req_default(),
467 self.invoke_id,
468 request,
469 ))
470 }
471
472 fn check_tcp_stream(&mut self, response: &ClientResult<Response>) {
474 if let Err(e) = response {
475 if e.is::<AdsError>() {
476 let e = e.downcast_ref::<AdsError>();
477 if let Some(e) = e {
478 if e == &AdsError::ErrPortNotConnected {
479 if let Some(stream) = &self.stream {
480 let _ = stream.shutdown(Shutdown::Both);
481 }
482 self.handle_list.clear();
483 self.notification_handle_list.clear();
484 self.stream = None;
485 }
486 }
487 }
488 }
489 }
490
491 fn get_general_tx(&self) -> ClientResult<&TxGeneral> {
493 if let Some(tx) = &self.tx_general {
494 return Ok(tx);
495 }
496 Err(anyhow!(AdsError::AdsErrClientError)) }
498
499 fn get_notification_tx(&self) -> ClientResult<&TxNotification> {
501 if let Some(tx) = &self.tx_notification {
502 return Ok(tx);
503 }
504 Err(anyhow!(AdsError::AdsErrClientError)) }
506}