tensorflow/
variable.rs

1use crate::ops;
2use crate::AnyTensor;
3use crate::DataType;
4use crate::Operation;
5use crate::Output;
6use crate::Result;
7use crate::Scope;
8use crate::Shape;
9use crate::Tensor;
10use crate::TensorType;
11use std::borrow::Borrow;
12
13/// Holds state in the form of a tensor that persists across steps.
14#[derive(Debug, Clone)]
15pub struct Variable {
16    pub(crate) name: String,
17    pub(crate) initializer: Operation,
18    pub(crate) output: Output,
19    pub(crate) dtype: DataType,
20    pub(crate) shape: Shape,
21}
22
23impl Variable {
24    /// Creates a builder which can be used to create a Variable.
25    pub fn builder<'a>() -> VariableBuilder<'a> {
26        VariableBuilder::default()
27    }
28
29    /// Returns the name.
30    pub fn name(&self) -> &str {
31        &self.name
32    }
33
34    /// Returns the output which evaluates to the value of the variable.
35    pub fn output(&self) -> &Output {
36        &self.output
37    }
38
39    /// Returns the initializer.
40    pub fn initializer(&self) -> &Operation {
41        &self.initializer
42    }
43
44    /// Returns the data type.
45    pub fn data_type(&self) -> DataType {
46        self.dtype
47    }
48
49    /// Returns the shape.
50    pub fn shape(&self) -> &Shape {
51        &self.shape
52    }
53}
54
55#[derive(Debug)]
56enum VariableInitialValue<'a> {
57    Unspecified,
58    TensorBox(Box<dyn AnyTensor>),
59    TensorRef(&'a dyn AnyTensor),
60    Output(Output),
61}
62
63/// Builds a Variable.
64#[derive(Debug)]
65pub struct VariableBuilder<'a> {
66    initial_value: VariableInitialValue<'a>,
67    shape: Shape,
68    dtype: Option<DataType>,
69}
70
71impl<'a> Default for VariableBuilder<'a> {
72    fn default() -> Self {
73        Self {
74            initial_value: VariableInitialValue::Unspecified,
75            shape: Shape(None),
76            dtype: None,
77        }
78    }
79}
80
81impl<'a> VariableBuilder<'a> {
82    /// Sets the initial value from anything that can be converted into a Tensor.
83    /// This also sets the type and shape.
84    pub fn const_initial_value<T: TensorType, TT: Into<Tensor<T>>>(self, value: TT) -> Self {
85        let t: Tensor<T> = value.into();
86        let shape = t.shape();
87        Self {
88            initial_value: VariableInitialValue::TensorBox(Box::<Tensor<T>>::new(t)),
89            dtype: Some(T::data_type()),
90            shape,
91        }
92    }
93
94    /// Sets the initial value from a Tensor.
95    /// This also sets the type and shape.
96    pub fn const_initial_tensor<T: TensorType>(self, value: &'a Tensor<T>) -> Self {
97        let shape = value.shape();
98        Self {
99            initial_value: VariableInitialValue::TensorRef(value),
100            dtype: Some(T::data_type()),
101            shape,
102        }
103    }
104
105    /// Sets the initial value from an existing output in the graph.
106    /// The type and shape are not set and will need to be set manually.
107    pub fn initial_value<T: Into<Output>>(self, value: T) -> Self {
108        Self {
109            initial_value: VariableInitialValue::Output(value.into()),
110            ..self
111        }
112    }
113
114    /// Sets the shape of the variable.
115    pub fn shape<S: Into<Shape>>(self, shape: S) -> Self {
116        Self {
117            shape: shape.into(),
118            ..self
119        }
120    }
121
122    /// Sets the data type of the variable.
123    pub fn data_type(self, data_type: DataType) -> Self {
124        Self {
125            dtype: Some(data_type),
126            ..self
127        }
128    }
129
130    /// Builds the Variable.
131    pub fn build(self, scope: &mut Scope) -> Result<Variable> {
132        let name = scope.get_unique_name_for_op("VariableV2");
133        let dtype = match self.dtype {
134            Some(d) => d,
135            None => return Err(invalid_arg!("data_type must be specified")),
136        };
137        let variable_op = {
138            let mut graph = scope.graph_mut();
139            let mut nd = graph.new_operation("VariableV2", &name)?;
140            nd.set_attr_type("dtype", dtype)?;
141            nd.set_attr_shape("shape", &self.shape)?;
142            nd.finish()?
143        };
144        let initial_value = match self.initial_value {
145            VariableInitialValue::Unspecified => {
146                return Err(invalid_arg!("an initial value is required"))
147            }
148            VariableInitialValue::TensorBox(t) => ops::any_constant(t.borrow(), scope)?.into(),
149            VariableInitialValue::TensorRef(t) => ops::any_constant(t, scope)?.into(),
150            VariableInitialValue::Output(o) => o,
151        };
152        let initializer = ops::assign(variable_op.clone(), initial_value, scope)?;
153        Ok(Variable {
154            name,
155            output: variable_op.into(),
156            initializer,
157            dtype,
158            shape: self.shape,
159        })
160    }
161}
162
163////////////////////////
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Code;
    use crate::Session;
    use crate::SessionOptions;
    use crate::SessionRunArgs;

    #[test]
    fn const_initialized_scalar() {
        let scope = Scope::new_root_scope();

        let variable = Variable::builder()
            .const_initial_value(3.0f32)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_eq!(variable.name, "foo");
        assert_eq!(variable.shape, Shape(Some(vec![])));
        assert_eq!(variable.dtype, DataType::Float);
        assert_eq!(
            variable.output.operation.get_attr_shape("shape").unwrap(),
            Shape(Some(vec![]))
        );
        assert_eq!(
            variable.output.operation.get_attr_type("dtype").unwrap(),
            DataType::Float
        );

        let options = SessionOptions::new();
        let session = Session::new(&options, &scope.graph()).unwrap();
        let mut run_args = SessionRunArgs::new();
        run_args.add_target(&variable.initializer);
        session.run(&mut run_args).unwrap();

        let mut run_args = SessionRunArgs::new();
        let fetch = run_args.request_fetch(&variable.output.operation, 0);
        session.run(&mut run_args).unwrap();
        let output = run_args.fetch::<f32>(fetch).unwrap();
        assert_eq!(&output[..], &[3.0f32]);
    }

    #[test]
    fn const_initialized_matrix() {
        let scope = Scope::new_root_scope();

        let initial = Tensor::<i32>::new(&[2, 3])
            .with_values(&[1, 2, 3, 4, 5, 6])
            .unwrap();
        let variable = Variable::builder()
            .const_initial_tensor(&initial)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_eq!(variable.name, "foo");
        assert_eq!(variable.shape, Shape(Some(vec![Some(2), Some(3)])));
        assert_eq!(variable.dtype, DataType::Int32);
        assert_eq!(
            variable.output.operation.get_attr_shape("shape").unwrap(),
            Shape(Some(vec![Some(2), Some(3)]))
        );
        assert_eq!(
            variable.output.operation.get_attr_type("dtype").unwrap(),
            DataType::Int32
        );

        let options = SessionOptions::new();
        let session = Session::new(&options, &scope.graph()).unwrap();
        let mut run_args = SessionRunArgs::new();
        run_args.add_target(&variable.initializer);
        session.run(&mut run_args).unwrap();

        let mut run_args = SessionRunArgs::new();
        let fetch = run_args.request_fetch(&variable.output.operation, 0);
        session.run(&mut run_args).unwrap();
        let output = run_args.fetch::<i32>(fetch).unwrap();
        assert_eq!(&output[..], &initial[..]);
    }

    #[test]
    fn custom_initializer_missing_dtype() {
        let mut scope = Scope::new_root_scope();
        let value = Tensor::new(&[]).with_values(&[3.0f32]).unwrap();
        let const_op = ops::constant(value, &mut scope).unwrap();

        assert_eq!(
            Variable::builder()
                .initial_value(const_op)
                .build(&mut scope.with_op_name("foo"))
                .unwrap_err()
                .code(),
            Code::InvalidArgument
        );
    }

    /// A builder with a data type but no initial value must be rejected.
    #[test]
    fn missing_initial_value() {
        let scope = Scope::new_root_scope();

        assert_eq!(
            Variable::builder()
                .data_type(DataType::Float)
                .build(&mut scope.with_op_name("foo"))
                .unwrap_err()
                .code(),
            Code::InvalidArgument
        );
    }

    #[test]
    fn custom_initializer() {
        let mut scope = Scope::new_root_scope();
        let value = Tensor::new(&[]).with_values(&[3.0f32]).unwrap();
        let const_op = ops::constant(value, &mut scope).unwrap();

        let variable = Variable::builder()
            .initial_value(const_op)
            .data_type(DataType::Float)
            .build(&mut scope.with_op_name("foo"))
            .unwrap();
        assert_eq!(variable.name, "foo");
        assert_eq!(variable.shape, Shape(None));
        assert_eq!(variable.dtype, DataType::Float);
        assert_eq!(
            variable.output.operation.get_attr_shape("shape").unwrap(),
            Shape(None)
        );
        assert_eq!(
            variable.output.operation.get_attr_type("dtype").unwrap(),
            DataType::Float
        );

        let options = SessionOptions::new();
        let session = Session::new(&options, &scope.graph()).unwrap();
        let mut run_args = SessionRunArgs::new();
        run_args.add_target(&variable.initializer);
        session.run(&mut run_args).unwrap();

        let mut run_args = SessionRunArgs::new();
        let fetch = run_args.request_fetch(&variable.output.operation, 0);
        session.run(&mut run_args).unwrap();
        let output = run_args.fetch::<f32>(fetch).unwrap();
        assert_eq!(&output[..], &[3.0f32]);
    }
}