1
use std::sync::Arc;
2

            
3
use cosmic_text::Metrics;
4
use glam::Mat3;
5
use glam::Vec2;
6
use glam::Vec3;
7
use glam::Vec3Swizzles;
8
use lts::LabelledTransitionSystem;
9
use tiny_skia::Shader;
10
use tiny_skia::Stroke;
11
use tiny_skia::Transform;
12

            
13
use crate::graph_layout::GraphLayout;
14
use crate::text_cache::TextCache;
15

            
16
pub struct Viewer {
17
    /// A cache used to cache strings and font information.
18
    text_cache: TextCache,
19

            
20
    /// A buffer for transition labels.
21
    labels_cache: Vec<cosmic_text::Buffer>,
22

            
23
    /// The underlying LTS being displayed.
24
    lts: Arc<LabelledTransitionSystem>,
25

            
26
    /// Stores a local copy of the state positions.
27
    view_states: Vec<StateView>,
28
}
29

            
30
#[derive(Clone, Default)]
31
struct StateView {
32
    pub position: Vec3,
33
    pub outgoing: Vec<TransitionView>,
34
}
35

            
36
#[derive(Clone, Default)]
37
pub struct TransitionView {
38
    /// The offset of the handle w.r.t. the 'from' state.
39
    pub handle_offset: Vec3,
40
}
41

            
42
impl Viewer {
43
1
    pub fn new(lts: &Arc<LabelledTransitionSystem>) -> Viewer {
44
1
        let mut text_cache = TextCache::new();
45
1
        let mut labels_cache = vec![];
46

            
47
20
        for label in lts.labels() {
48
20
            // Create text elements for all labels that we are going to render.
49
20
            let buffer = text_cache.create_buffer(label, Metrics::new(12.0, 12.0));
50
20

            
51
20
            // Put it in the label cache.
52
20
            labels_cache.push(buffer);
53
20
        }
54

            
55
        // Initialize the view information for the states.
56
1
        let mut view_states = vec![StateView::default(); lts.num_of_states()];
57

            
58
        // Add the transition view information
59
74
        for (state_index, state_view) in view_states.iter_mut().enumerate() {
60
74
            state_view.outgoing = vec![TransitionView::default(); lts.outgoing_transitions(state_index).count()];
61
74

            
62
74
            // Compute the offsets for self-loops, put them at equal distance around the state.
63
74
            let num_selfloops = lts
64
74
                .outgoing_transitions(state_index)
65
92
                .filter(|(_, to)| *to == state_index)
66
74
                .count();
67
74

            
68
74
            // Keep track of the current self loop index.
69
74
            let mut index_selfloop = 0;
70
74

            
71
74
            // Keep track of the current transition index.
72
74
            let mut index_transition = 0;
73

            
74
92
            for (transition_index, (_, to)) in lts.outgoing_transitions(state_index).enumerate() {
75
92
                let transition_view = &mut state_view.outgoing[transition_index];
76
92

            
77
92
                if state_index == *to {
78
                    // This is a self loop so compute a rotation around the state for its handle.
79
                    let rotation_mat = Mat3::from_euler(
80
                        glam::EulerRot::XYZ,
81
                        0.0,
82
                        0.0,
83
                        (index_selfloop as f32 / num_selfloops as f32) * 2.0 * std::f32::consts::PI,
84
                    );
85
                    transition_view.handle_offset = rotation_mat.mul_vec3(Vec3::new(0.0, -40.0, 0.0));
86

            
87
                    index_selfloop += 1;
88
                } else {
89
                    // Determine whether any of the outgoing edges from the reached state point back.
90
92
                    let has_backtransition = lts
91
92
                        .outgoing_transitions(*to)
92
116
                        .filter(|(_, other_to)| *other_to == state_index)
93
92
                        .count()
94
92
                        > 0;
95
92

            
96
92
                    // Compute the number of transitions going to the same state.
97
92
                    let num_transitions = lts
98
92
                        .outgoing_transitions(state_index)
99
128
                        .filter(|(_, to)| *to == state_index)
100
92
                        .count();
101
92

            
102
92
                    if has_backtransition {
103
                        // Offset the outgoing transitions towards that state to the right.
104
                        transition_view.handle_offset =
105
                            Vec3::new(0.0, index_transition as f32 / num_transitions as f32, 0.0);
106
92
                    } else {
107
92
                        // Balance transitions around the midpoint.
108
92
                    }
109

            
110
92
                    index_transition += 1;
111
                }
112
            }
113
        }
114

            
115
1
        Viewer {
116
1
            text_cache,
117
1
            labels_cache,
118
1
            lts: lts.clone(),
119
1
            view_states,
120
1
        }
121
1
    }
122

            
123
    /// Update the state of the viewer with the given graph layout.
124
    pub fn update(&mut self, layout: &GraphLayout) {
125
        for (index, layout_state) in self.view_states.iter_mut().enumerate() {
126
            layout_state.position = layout.layout_states[index].position;
127
        }
128
    }
129

            
130
    /// Returns the center of the graph.
131
    pub fn center(&self) -> Vec3 {
132
        self.view_states.iter().map(|x| x.position).sum::<Vec3>() / self.view_states.len() as f32
133
    }
134

            
135
    /// Render the current state of the simulation into the pixmap.
136
1
    pub fn render(
137
1
        &mut self,
138
1
        pixmap: &mut tiny_skia::PixmapMut,
139
1
        draw_actions: bool,
140
1
        state_radius: f32,
141
1
        view_x: f32,
142
1
        view_y: f32,
143
1
        screen_x: u32,
144
1
        screen_y: u32,
145
1
        zoom_level: f32,
146
1
        label_text_size: f32,
147
1
    ) {
148
1
        pixmap.fill(tiny_skia::Color::WHITE);
149
1

            
150
1
        // Compute the view transform
151
1
        let view_transform = Transform::from_translate(view_x, view_y)
152
1
            .post_scale(zoom_level, zoom_level)
153
1
            .post_translate(screen_x as f32 / 2.0, screen_y as f32 / 2.0);
154
1

            
155
1
        // The color information for states.
156
1
        let state_inner_paint = tiny_skia::Paint {
157
1
            shader: Shader::SolidColor(tiny_skia::Color::from_rgba8(255, 255, 255, 255)),
158
1
            ..Default::default()
159
1
        };
160
1
        let initial_state_paint = tiny_skia::Paint {
161
1
            shader: Shader::SolidColor(tiny_skia::Color::from_rgba8(100, 255, 100, 255)),
162
1
            ..Default::default()
163
1
        };
164
1
        let state_outer = tiny_skia::Paint {
165
1
            shader: Shader::SolidColor(tiny_skia::Color::from_rgba8(0, 0, 0, 255)),
166
1
            ..Default::default()
167
1
        };
168
1

            
169
1
        // The color information for edges
170
1
        let edge_paint = tiny_skia::Paint::default();
171
1

            
172
1
        // The arrow to indicate the direction of the edge, this unwrap should never fail.
173
1
        let arrow = {
174
1
            let mut builder = tiny_skia::PathBuilder::new();
175
1
            builder.line_to(2.0, -5.0);
176
1
            builder.line_to(-2.0, -5.0);
177
1
            builder.close();
178
1
            builder.finish().unwrap()
179
1
        };
180
1

            
181
1
        // A single circle that is used to render colored states.
182
1
        let circle = {
183
1
            let mut builder = tiny_skia::PathBuilder::new();
184
1
            builder.push_circle(0.0, 0.0, state_radius);
185
1
            builder.finish().unwrap()
186
        };
187

            
188
        // Resize the labels if necessary.
189
21
        for buffer in &mut self.labels_cache {
190
20
            self.text_cache
191
20
                .resize(buffer, Metrics::new(label_text_size, label_text_size));
192
20
        }
193

            
194
        // Draw the edges and the arrows on them
195
1
        let mut edge_builder = tiny_skia::PathBuilder::new();
196
1
        let mut arrow_builder = tiny_skia::PathBuilder::new();
197

            
198
74
        for state_index in self.lts.iter_states() {
199
74
            let state_view = &self.view_states[state_index];
200
74

            
201
74
            // For now we only draw 2D graphs properly.
202
74
            debug_assert!(state_view.position.z.abs() < 0.01);
203

            
204
92
            for (transition_index, (label, to)) in self.lts.outgoing_transitions(state_index).enumerate() {
205
92
                let to_state_view = &self.view_states[*to];
206
92
                let transition_view = &state_view.outgoing[transition_index];
207

            
208
92
                let label_position = if *to != state_index {
209
                    // Draw the transition
210
92
                    edge_builder.move_to(state_view.position.x, state_view.position.y);
211
92
                    edge_builder.line_to(to_state_view.position.x, to_state_view.position.y);
212
92

            
213
92
                    let direction = (state_view.position - to_state_view.position).normalize();
214
92
                    let angle = -1.0 * direction.xy().angle_to(Vec2::new(0.0, -1.0)).to_degrees();
215

            
216
                    // Draw the arrow of the transition
217
92
                    if let Some(path) = arrow.clone().transform(
218
92
                        Transform::from_translate(0.0, -state_radius - 0.5)
219
92
                            .post_rotate(angle)
220
92
                            .post_translate(to_state_view.position.x, to_state_view.position.y),
221
92
                    ) {
222
                        arrow_builder.push_path(&path);
223
92
                    };
224

            
225
                    // Draw the edge handle
226
92
                    let middle = (to_state_view.position + state_view.position) / 2.0;
227
92
                    edge_builder.push_circle(
228
92
                        middle.x + transition_view.handle_offset.x,
229
92
                        middle.y + transition_view.handle_offset.y,
230
92
                        1.0,
231
92
                    );
232
92

            
233
92
                    middle
234
                } else {
235
                    // This is a self loop so draw a circle around the middle of the position and the handle.
236
                    let middle = (2.0 * state_view.position + transition_view.handle_offset) / 2.0;
237
                    edge_builder.push_circle(middle.x, middle.y, transition_view.handle_offset.length() / 2.0);
238

            
239
                    // Draw the edge handle
240
                    edge_builder.push_circle(
241
                        state_view.position.x + transition_view.handle_offset.x,
242
                        state_view.position.y + transition_view.handle_offset.y,
243
                        1.0,
244
                    );
245
                    state_view.position + transition_view.handle_offset
246
                };
247

            
248
                // Draw the text label
249
92
                if draw_actions {
250
92
                    let buffer = &self.labels_cache[*label];
251
92
                    self.text_cache.draw(
252
92
                        buffer,
253
92
                        pixmap,
254
92
                        Transform::from_translate(label_position.x, label_position.y).post_concat(view_transform),
255
92
                    );
256
92
                }
257
            }
258
        }
259

            
260
1
        if let Some(path) = arrow_builder.finish() {
261
            pixmap.fill_path(&path, &edge_paint, tiny_skia::FillRule::Winding, view_transform, None);
262
1
        }
263

            
264
        // Draw the path for edges.
265
1
        if let Some(path) = edge_builder.finish() {
266
1
            pixmap.stroke_path(&path, &edge_paint, &Stroke::default(), view_transform, None);
267
1
        }
268

            
269
        // Draw the states on top.
270
1
        let mut state_path_builder = tiny_skia::PathBuilder::new();
271

            
272
74
        for (index, state_view) in self.view_states.iter().enumerate() {
273
74
            if index != self.lts.initial_state_index() {
274
73
                state_path_builder.push_circle(state_view.position.x, state_view.position.y, state_radius);
275
73
            } else {
276
1
                // Draw the colored states individually
277
1
                let transform =
278
1
                    Transform::from_translate(state_view.position.x, state_view.position.y).post_concat(view_transform);
279
1

            
280
1
                pixmap.fill_path(
281
1
                    &circle,
282
1
                    &initial_state_paint,
283
1
                    tiny_skia::FillRule::Winding,
284
1
                    transform,
285
1
                    None,
286
1
                );
287
1

            
288
1
                pixmap.stroke_path(&circle, &state_outer, &Stroke::default(), transform, None);
289
1
            }
290
        }
291

            
292
        // Draw the states with an outline.
293
1
        if let Some(path) = state_path_builder.finish() {
294
1
            pixmap.fill_path(
295
1
                &path,
296
1
                &state_inner_paint,
297
1
                tiny_skia::FillRule::Winding,
298
1
                view_transform,
299
1
                None,
300
1
            );
301
1

            
302
1
            pixmap.stroke_path(&path, &state_outer, &Stroke::default(), view_transform, None);
303
1
        }
304
1
    }
305
}
306

            
307
#[cfg(test)]
mod tests {
    use io::io_aut::read_aut;
    use tiny_skia::Pixmap;
    use tiny_skia::PixmapMut;

    use super::*;

    /// Renders a single frame of the alternating bit protocol; this must not panic.
    #[test]
    fn test_viewer() {
        const WIDTH: u32 = 800;
        const HEIGHT: u32 = 600;

        let input = include_str!("../../../../examples/lts/abp.aut");
        let lts = Arc::new(read_aut(input.as_bytes(), vec![]).unwrap());

        let mut viewer = Viewer::new(&lts);

        let mut pixmap = Pixmap::new(WIDTH, HEIGHT).unwrap();
        let mut target = PixmapMut::from_bytes(pixmap.data_mut(), WIDTH, HEIGHT).unwrap();

        let draw_actions = true;
        let state_radius = 5.0;
        let (view_x, view_y) = (0.0, 0.0);
        let zoom_level = 1.0;
        let label_text_size = 14.0;

        viewer.render(
            &mut target,
            draw_actions,
            state_radius,
            view_x,
            view_y,
            WIDTH,
            HEIGHT,
            zoom_level,
            label_text_size,
        );
    }
}