1
use ahash::AHashSet;
2
use mcrl2::aterm::ATerm;
3
use mcrl2::aterm::ATermRef;
4
use mcrl2::aterm::Protected;
5
use mcrl2::aterm::TermBuilder;
6
use mcrl2::aterm::TermPool;
7
use mcrl2::aterm::Yield;
8
use mcrl2::data::DataExpression;
9
use mcrl2::data::DataFunctionSymbol;
10
use mcrl2::data::DataVariable;
11

            
12
/// Scratch buffer of argument terms used by [`substitute_with`] while rebuilding terms bottom up.
///
/// The recursive substitution clears this buffer after constructing each rebuilt term, so a
/// single builder can be reused across many `substitute_with` calls instead of allocating a new
/// buffer per call. The elements are wrapped in [`Protected`] (presumably so that the pushed
/// terms stay registered with the term pool while they sit in the buffer — confirm against the
/// `mcrl2::aterm` docs).
pub type SubstitutionBuilder = Protected<Vec<ATermRef<'static>>>;
13

            
14
/// Creates a new term where a subterm is replaced with another term.
15
///
16
/// # Parameters
17
/// 't'             -   The original term
18
/// 'new_subterm'   -   The subterm that will be injected
19
/// 'p'             -   The place in 't' on which 'new_subterm' will be placed,
20
///                     given as a slice of position indexes
21
///
22
/// # Example
23
///
24
/// The term is constructed bottom up. As an an example take the term s(s(a)).
25
/// Lets say we want to replace the a with the term 0. Then we traverse the term
26
/// until we have arrived at a and replace it with 0. We then construct s(0)
27
/// and then construct s(s(0)).
28
1
pub fn substitute(tp: &mut TermPool, t: &ATermRef<'_>, new_subterm: ATerm, p: &[usize]) -> ATerm {
29
1
    let mut args = Protected::new(vec![]);
30
1
    substitute_rec(tp, t, new_subterm, p, &mut args, 0)
31
1
}
32

            
33
7437990
pub fn substitute_with(
34
7437990
    builder: &mut SubstitutionBuilder,
35
7437990
    tp: &mut TermPool,
36
7437990
    t: &ATermRef<'_>,
37
7437990
    new_subterm: ATerm,
38
7437990
    p: &[usize],
39
7437990
) -> ATerm {
40
7437990
    substitute_rec(tp, t, new_subterm, p, builder, 0)
41
7437990
}
42

            
43
/// The recursive implementation for subsitute
44
///
45
/// 'depth'         -   Used to keep track of the depth in 't'. Function should be called with
46
///                     'depth' = 0.
47
8386403
fn substitute_rec(
48
8386403
    tp: &mut TermPool,
49
8386403
    t: &ATermRef<'_>,
50
8386403
    new_subterm: ATerm,
51
8386403
    p: &[usize],
52
8386403
    args: &mut Protected<Vec<ATermRef<'static>>>,
53
8386403
    depth: usize,
54
8386403
) -> ATerm {
55
8386403
    if p.len() == depth {
56
        // in this case we have arrived at the place where 'new_subterm' needs to be injected
57
7437991
        new_subterm
58
    } else {
59
        // else recurse deeper into 't'
60
948412
        let new_child_index = p[depth] - 1;
61
948412
        let new_child = substitute_rec(tp, &t.arg(new_child_index), new_subterm, p, args, depth + 1);
62
948412

            
63
948412
        let mut write_args = args.write();
64
3325932
        for (index, arg) in t.arguments().enumerate() {
65
3325932
            if index == new_child_index {
66
948412
                let t = write_args.protect(&new_child);
67
948412
                write_args.push(t);
68
2377522
            } else {
69
2377520
                let t = write_args.protect(&arg);
70
2377520
                write_args.push(t);
71
2377520
            }
72
        }
73

            
74
948412
        let result = tp.create(&t.get_head_symbol(), &write_args);
75
948412
        drop(write_args);
76
948412

            
77
948412
        // TODO: When write is dropped we check whether all terms where inserted, but this clear violates that assumption.
78
948412
        args.write().clear();
79
948412
        result
80
    }
81
8386403
}
82

            
83
/// Converts an [ATerm] to an untyped data expression.
84
9371
pub fn to_untyped_data_expression(tp: &mut TermPool, t: &ATerm, variables: &AHashSet<String>) -> DataExpression {
85
9371
    let mut builder = TermBuilder::<ATerm, ATerm>::new();
86
9371

            
87
9371
    builder
88
9371
        .evaluate(
89
9371
            tp,
90
9371
            t.clone(),
91
110752
            |tp, args, t| {
92
110752
                debug_assert!(!t.is_int(), "Term cannot be an aterm_int, although not sure why");
93

            
94
110752
                if variables.contains(t.get_head_symbol().name()) {
95
                    // Convert a constant variable, for example 'x', into an untyped variable.
96
9493
                    Ok(Yield::Term(DataVariable::new(tp, t.get_head_symbol().name()).into()))
97
101259
                } else if t.get_head_symbol().arity() == 0 {
98
16081
                    Ok(Yield::Term(
99
16081
                        DataFunctionSymbol::new(tp, t.get_head_symbol().name()).into(),
100
16081
                    ))
101
                } else {
102
                    // This is a function symbol applied to a number of arguments (higher order terms not allowed)
103
85178
                    let head = DataFunctionSymbol::new(tp, t.get_head_symbol().name());
104

            
105
101381
                    for arg in t.arguments() {
106
101381
                        args.push(arg.protect());
107
101381
                    }
108

            
109
85178
                    Ok(Yield::Construct(head.into()))
110
                }
111
110752
            },
112
85178
            |tp, input, args| Ok(tp.create_data_application(&input, args)),
113
9371
        )
114
9371
        .unwrap()
115
9371
        .into()
116
9371
}
117

            
118
#[cfg(test)]
mod tests {
    use crate::utilities::ExplicitPosition;
    use crate::utilities::PositionIndexed;

    use super::*;

    #[test]
    fn test_substitute() {
        let mut term_pool = TermPool::new();

        let t = term_pool.from_string("s(s(a))").unwrap();
        let t0 = term_pool.from_string("0").unwrap();

        // Replace the a in the term s(s(a)) by 0. `substitute` takes a slice, so no Vec is needed.
        let result = substitute(&mut term_pool, &t, t0.clone(), &[1, 1]);

        // Check that indeed the new term has a 0 at position 1.1.
        assert_eq!(t0, result.get_position(&ExplicitPosition::new(&vec![1, 1])).protect());

        // Check that the rest of the term was rebuilt unchanged.
        let expected = term_pool.from_string("s(s(0))").unwrap();
        assert_eq!(expected, result);
    }

    #[test]
    fn test_substitute_with_reused_builder() {
        let mut term_pool = TermPool::new();
        let mut builder = SubstitutionBuilder::new(vec![]);

        let t = term_pool.from_string("s(s(a))").unwrap();
        let t0 = term_pool.from_string("0").unwrap();

        // The builder must be left in a usable state after every call.
        let first = substitute_with(&mut builder, &mut term_pool, &t, t0.clone(), &[1, 1]);
        let second = substitute_with(&mut builder, &mut term_pool, &t, t0.clone(), &[1]);

        assert_eq!(term_pool.from_string("s(s(0))").unwrap(), first);
        assert_eq!(term_pool.from_string("s(0)").unwrap(), second);
    }

    #[test]
    fn test_to_data_expression() {
        let mut term_pool = TermPool::new();

        let t = term_pool.from_string("s(s(a))").unwrap();

        let _expression = to_untyped_data_expression(&mut term_pool, &t, &AHashSet::from_iter(["a".to_string()]));
    }
}