rust_ethernet_ip/
python.rs1use pyo3::prelude::*;
2use pyo3::types::{PyDict, PyTuple};
4use tokio::runtime::Runtime;
6use std::collections::HashMap;
7use crate::{
8    EipClient, PlcValue, SubscriptionOptions, Result as EipResult
9};
10
11#[pymodule]
13fn rust_ethernet_ip(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
14    m.add_class::<PyEipClient>()?;
15    m.add_class::<PyPlcValue>()?;
16    m.add_class::<PySubscriptionOptions>()?;
17    Ok(())
18}
19
20#[pyclass]
22struct PyEipClient {
23    client: EipClient,
24    runtime: Runtime,
25}
26
27struct TagValueArg {
29    name: String,
30    value: PyPlcValue,
31}
32
33impl<'a> FromPyObject<'a> for TagValueArg {
34    fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
35        let tuple = ob.downcast::<PyTuple>()?;
36        if tuple.len() != 2 {
37            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
38                "Expected tuple of length 2"
39            ));
40        }
41        let name = tuple.get_item(0)?.extract::<String>()?;
42        let value = tuple.get_item(1)?.extract::<PyPlcValue>()?;
43        Ok(TagValueArg { name, value })
44    }
45}
46
47struct TagSubOptArg {
49    name: String,
50    options: PySubscriptionOptions,
51}
52
53impl<'a> FromPyObject<'a> for TagSubOptArg {
54    fn extract(ob: &'a pyo3::PyAny) -> PyResult<Self> {
55        let tuple = ob.downcast::<PyTuple>()?;
56        if tuple.len() != 2 {
57            return Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
58                "Expected tuple of length 2"
59            ));
60        }
61        let name = tuple.get_item(0)?.extract::<String>()?;
62        let options = tuple.get_item(1)?.extract::<PySubscriptionOptions>()?;
63        Ok(TagSubOptArg { name, options })
64    }
65}
66
67#[pymethods]
68impl PyEipClient {
69    #[new]
71    fn new(addr: &str) -> PyResult<Self> {
72        let runtime = Runtime::new().unwrap();
73        let client = runtime.block_on(async {
74            EipClient::connect(addr).await
75        }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
76        
77        Ok(PyEipClient { client, runtime })
78    }
79
80    fn read_tag(&mut self, tag_name: &str) -> PyResult<PyPlcValue> {
82        let value = self.runtime.block_on(async {
83            self.client.read_tag(tag_name).await
84        }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
85        
86        Ok(PyPlcValue { value })
87    }
88
89    fn write_tag(&mut self, tag_name: &str, value: &PyPlcValue) -> PyResult<bool> {
91        let result = self.runtime.block_on(async {
92            self.client.write_tag(tag_name, value.value.clone()).await
93        });
94        match result {
95            Ok(_) => Ok(true),
96            Err(e) => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string())),
97        }
98    }
99
100    fn read_tags_batch(&mut self, tag_names: Vec<String>) -> PyResult<Vec<(String, PyObject)>> {
102        Python::with_gil(|py| {
103            let runtime = tokio::runtime::Runtime::new().unwrap();
104            let results = runtime.block_on(async {
105                self.client.read_tags_batch(&tag_names.iter().map(|s| s.as_str()).collect::<Vec<_>>()).await
106            }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
107            Ok(results.into_iter().map(|(name, result)| {
108                let obj = match result {
109                    Ok(v) => PyPlcValue { value: v }.into_py(py),
110                    Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()).into_py(py),
111                };
112                (name, obj)
113            }).collect())
114        })
115    }
116
117    fn write_tags_batch(&mut self, tag_values: Vec<TagValueArg>) -> PyResult<Vec<(String, PyObject)>> {
119        Python::with_gil(|py| {
120            let runtime = tokio::runtime::Runtime::new().unwrap();
121            let results = runtime.block_on(async {
122                self.client.write_tags_batch(&tag_values.iter()
123                    .map(|arg| (arg.name.as_str(), arg.value.value.clone()))
124                    .collect::<Vec<_>>()).await
125            }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
126            Ok(results.into_iter().map(|(name, result)| {
127                let obj = match result {
128                    Ok(()) => py.None(),
129                    Err(e) => PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()).into_py(py),
130                };
131                (name, obj)
132            }).collect())
133        })
134    }
135
136    fn subscribe_to_tag(&self, tag_path: &str, options: &PySubscriptionOptions) -> PyResult<()> {
138        self.runtime.block_on(async {
139            self.client.subscribe_to_tag(tag_path, options.options.clone()).await
140        }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
141        
142        Ok(())
143    }
144
145    fn subscribe_to_tags(&self, tags: Vec<TagSubOptArg>) -> PyResult<()> {
147        self.runtime.block_on(async {
148            self.client.subscribe_to_tags(&tags.iter()
149                .map(|arg| (arg.name.as_str(), arg.options.options.clone()))
150                .collect::<Vec<_>>()).await
151        }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
152        Ok(())
153    }
154
155    fn unregister_session(&mut self) -> PyResult<()> {
157        self.runtime.block_on(async {
158            self.client.unregister_session().await
159        }).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
160        
161        Ok(())
162    }
163}
164
165#[pyclass]
167struct PyPlcValue {
168    value: PlcValue,
169}
170
171impl FromPyObject<'_> for PyPlcValue {
172    fn extract(ob: &PyAny) -> PyResult<Self> {
173        if let Ok(bool_val) = ob.extract::<bool>() {
174            Ok(PyPlcValue { value: PlcValue::Bool(bool_val) })
175        } else if let Ok(int_val) = ob.extract::<i32>() {
176            Ok(PyPlcValue { value: PlcValue::Dint(int_val) })
177        } else if let Ok(float_val) = ob.extract::<f64>() {
178            Ok(PyPlcValue { value: PlcValue::Lreal(float_val) })
179        } else if let Ok(string_val) = ob.extract::<String>() {
180            Ok(PyPlcValue { value: PlcValue::String(string_val) })
181        } else if let Ok(dict) = ob.downcast::<PyDict>() {
182            let mut map = HashMap::new();
183            for (key, value) in dict.iter() {
184                let key = key.extract::<String>()?;
185                let value = value.extract::<PyPlcValue>()?.value;
186                map.insert(key, value);
187            }
188            Ok(PyPlcValue { value: PlcValue::Udt(map) })
189        } else {
190            Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
191                "Unsupported value type"
192            ))
193        }
194    }
195}
196
197#[pymethods]
198impl PyPlcValue {
199    #[new]
200    fn new(value: PyObject) -> PyResult<Self> {
201        Python::with_gil(|py| {
202            if let Ok(val) = value.extract::<bool>(py) {
203                Ok(PyPlcValue { value: PlcValue::Bool(val) })
204            } else if let Ok(val) = value.extract::<i32>(py) {
205                Ok(PyPlcValue { value: PlcValue::Dint(val) })
206            } else if let Ok(val) = value.extract::<f32>(py) {
207                Ok(PyPlcValue { value: PlcValue::Real(val) })
208            } else if let Ok(val) = value.extract::<f64>(py) {
209                Ok(PyPlcValue { value: PlcValue::Real(val as f32) })
210            } else if let Ok(val) = value.extract::<String>(py) {
211                Ok(PyPlcValue { value: PlcValue::String(val) })
212            } else {
213                Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>("Unsupported value type"))
214            }
215        })
216    }
217
218    #[staticmethod]
219    fn real(val: f32) -> Self {
220        PyPlcValue { value: PlcValue::Real(val) }
221    }
222    #[staticmethod]
223    fn lreal(val: f64) -> Self {
224        PyPlcValue { value: PlcValue::Lreal(val) }
225    }
226    #[staticmethod]
227    fn dint(val: i32) -> Self {
228        PyPlcValue { value: PlcValue::Dint(val) }
229    }
230    #[staticmethod]
231    fn lint(val: i64) -> Self {
232        PyPlcValue { value: PlcValue::Lint(val) }
233    }
234    #[staticmethod]
235    fn string(val: String) -> Self {
236        PyPlcValue { value: PlcValue::String(val) }
237    }
238
239    #[getter]
240    fn value(&self, py: Python) -> PyResult<PyObject> {
241        match &self.value {
242            PlcValue::Bool(b) => Ok(b.into_py(py)),
243            PlcValue::Sint(i) => Ok(i.into_py(py)),
244            PlcValue::Int(i) => Ok(i.into_py(py)),
245            PlcValue::Dint(i) => Ok(i.into_py(py)),
246            PlcValue::Lint(i) => Ok(i.into_py(py)),
247            PlcValue::Usint(u) => Ok(u.into_py(py)),
248            PlcValue::Uint(u) => Ok(u.into_py(py)),
249            PlcValue::Udint(u) => Ok(u.into_py(py)),
250            PlcValue::Ulint(u) => Ok(u.into_py(py)),
251            PlcValue::Real(f) => Ok(f.into_py(py)),
252            PlcValue::Lreal(f) => Ok(f.into_py(py)),
253            PlcValue::String(s) => Ok(s.into_py(py)),
254            PlcValue::Udt(map) => {
255                let dict = PyDict::new(py);
256                for (k, v) in map.iter() {
257                    let v_py = PyPlcValue { value: v.clone() }.value(py)?;
258                    dict.set_item(k, v_py)?;
259                }
260                Ok(dict.into_py(py))
261            }
262        }
263    }
264
265    fn __str__(&self) -> String {
266        format!("{:?}", self.value)
267    }
268
269    fn __repr__(&self) -> String {
270        format!("PyPlcValue({:?})", self.value)
271    }
272}
273
274#[pyclass]
276struct PySubscriptionOptions {
277    options: SubscriptionOptions,
278}
279
280impl FromPyObject<'_> for PySubscriptionOptions {
281    fn extract(ob: &PyAny) -> PyResult<Self> {
282        let update_rate = ob.getattr("update_rate")?.extract::<u32>()?;
283        let change_threshold = ob.getattr("change_threshold")?.extract::<f32>()?;
284        let timeout = ob.getattr("timeout")?.extract::<u32>()?;
285        
286        Ok(PySubscriptionOptions {
287            options: SubscriptionOptions {
288                update_rate,
289                change_threshold,
290                timeout,
291            }
292        })
293    }
294}
295
296#[pymethods]
297impl PySubscriptionOptions {
298    #[new]
299    fn new(update_rate: u32, change_threshold: f32, timeout: u32) -> PyResult<Self> {
300        let options = SubscriptionOptions {
301            update_rate,
302            change_threshold,
303            timeout,
304        };
305        
306        Ok(PySubscriptionOptions { options })
307    }
308
309    #[getter]
311    fn update_rate(&self) -> u32 {
312        self.options.update_rate
313    }
314
315    #[getter]
317    fn change_threshold(&self) -> f32 {
318        self.options.change_threshold
319    }
320
321    #[getter]
323    fn timeout(&self) -> u32 {
324        self.options.timeout
325    }
326}