1
use std::error::Error;
2
use std::fmt;
3

            
4
use ahash::AHashSet;
5
use log::trace;
6

            
7
use crate::aterm::ATerm;
8
use crate::aterm::Symbol;
9
use crate::aterm::TermPool;
10

            
11
/// This can be used to construct a term from a given input of (inductive) type I,
12
/// without using the system stack, i.e. recursion. See evaluate.
13
#[derive(Default)]
14
pub struct TermBuilder<I, C> {
15
    // The stack of terms
16
    terms: Vec<ATerm>,
17
    configs: Vec<Config<I, C>>,
18
}
19

            
20
/// Applies the given function to every subterm of the given term using the [TermBuilder].
21
///     function(subterm) returns:
22
///         None   , in which case subterm is kept and it is recursed into its argments.
23
///         Some(x), in which case subterm is replaced by x.
24
5
pub fn apply<F>(tp: &mut TermPool, t: &ATerm, function: &F) -> ATerm
25
5
where
26
5
    F: Fn(&mut TermPool, &ATerm) -> Option<ATerm>,
27
5
{
28
5
    let mut builder = TermBuilder::<ATerm, Symbol>::new();
29
5

            
30
5
    builder
31
5
        .evaluate(
32
5
            tp,
33
5
            t.clone(),
34
19
            |tp, args, t| match function(tp, &t) {
35
8
                Some(result) => Ok(Yield::Term(result)),
36
                None => {
37
14
                    for arg in t.arguments() {
38
14
                        args.push(arg.protect());
39
14
                    }
40

            
41
11
                    Ok(Yield::Construct(t.get_head_symbol().protect()))
42
                }
43
19
            },
44
11
            |tp, symbol, args| Ok(tp.create(&symbol, args)),
45
5
        )
46
5
        .unwrap()
47
5
}
48

            
49
impl<I: fmt::Debug, C: fmt::Debug> TermBuilder<I, C> {
50
8091052
    pub fn new() -> TermBuilder<I, C> {
51
8091052
        TermBuilder {
52
8091052
            terms: vec![],
53
8091052
            configs: vec![],
54
8091052
        }
55
8091052
    }
56

            
57
    /// This can be used to construct a term from a given input of (inductive)
58
    /// type I, without using the system stack, i.e. recursion.
59
    ///
60
    /// The `transformer` function is applied to every instance I, which can put
61
    /// more generate more inputs using a so-called argument stack and some
62
    /// instance C that is used to construct the result term. Alternatively, it
63
    /// yields a result term directly.
64
    ///
65
    /// The `construct` function takes an instance C and the arguments pushed to
66
    /// stack where the transformer was applied for every input pushed onto the
67
    /// stack previously.
68
    ///
69
    /// # Example
70
    ///
71
    /// A simple example could be to transform a term into another term using a
72
    /// function `f : ATerm -> Option<ATerm>`. Then `I` will be ATerm since that is
73
    /// the input, and `C` will be the Symbol from which we can construct the
74
    /// recursive term.
75
    ///
76
    /// `transformer` takes the input and applies f(input). Then either we
77
    /// return Yield(x) when f returns some term, or Construct(head(input)) with
78
    /// the arguments of the input term pushed to stack.
79
    ///
80
    /// `construct` simply constructs the term from the symbol and the arguments
81
    /// on the stack.
82
    ///
83
    /// However, it can also be that I is some syntax tree from which we want to
84
    /// construct a term.
85
10886591
    pub fn evaluate<F, G>(
86
10886591
        &mut self,
87
10886591
        tp: &mut TermPool,
88
10886591
        input: I,
89
10886591
        transformer: F,
90
10886591
        construct: G,
91
10886591
    ) -> Result<ATerm, Box<dyn Error>>
92
10886591
    where
93
10886591
        F: Fn(&mut TermPool, &mut ArgStack<I, C>, I) -> Result<Yield<C>, Box<dyn Error>>,
94
10886591
        G: Fn(&mut TermPool, C, &[ATerm]) -> Result<ATerm, Box<dyn Error>>,
95
10886591
    {
96
10886591
        trace!("Transforming {:?}", input);
97
10886591
        self.terms.push(ATerm::default());
98
10886591
        self.configs.push(Config::Apply(input, 0));
99

            
100
60827513
        while let Some(config) = self.configs.pop() {
101
49940922
            match config {
102
39737948
                Config::Apply(input, result) => {
103
39737948
                    // Applies the given function to this input, and obtain a number of symbol and arguments.
104
39737948
                    let top_of_stack = self.configs.len();
105
39737948
                    let mut args = ArgStack::new(&mut self.terms, &mut self.configs);
106
39737948

            
107
39737948
                    match transformer(tp, &mut args, input)? {
108
10202974
                        Yield::Construct(input) => {
109
10202974
                            // This occurs before the other constructs.
110
10202974
                            let arity = args.len();
111
10202974
                            self.configs.reserve(1);
112
10202974
                            self.configs
113
10202974
                                .insert(top_of_stack, Config::Construct(input, arity, result));
114
10202974
                        }
115
29534974
                        Yield::Term(term) => {
116
29534974
                            self.terms[result] = term;
117
29534974
                        }
118
                    }
119
                }
120
10202974
                Config::Construct(input, arity, result) => {
121
10202974
                    let arguments = &self.terms[self.terms.len() - arity..];
122

            
123
10202974
                    self.terms[result] = construct(tp, input, arguments)?;
124

            
125
                    // Remove elements from the stack.
126
10202974
                    self.terms.drain(self.terms.len() - arity..);
127
                }
128
            }
129

            
130
49940922
            trace!("{:?}", self);
131
        }
132

            
133
10886591
        debug_assert!(self.terms.len() == 1, "Expect exactly one term on the result stack");
134

            
135
10886591
        Ok(self.terms.pop().expect("There should be at last one result"))
136
10886591
    }
137
}
138

            
139
/// A unit of pending work on the configuration stack of a [`TermBuilder`].
enum Config<I, C> {
    /// Apply the transformer to this input and store the outcome in the given
    /// slot of the term stack.
    Apply(I, usize),
    /// Call `construct` with this value and the topmost `arity` terms (the
    /// middle field), storing the result in the slot given by the last field.
    Construct(C, usize, usize),
}
143

            
144
/// The outcome of applying the transformer to a single input in
/// [`TermBuilder::evaluate`].
pub enum Yield<C> {
    /// Yield this term as is.
    Term(ATerm),
    /// Yield the term constructed from this value and the arguments pushed to
    /// the argument stack, each with the transformer applied to it.
    Construct(C),
}
148

            
149
/// This struct defines a local argument stack on the global stack.
///
/// It is handed to the transformer in [`TermBuilder::evaluate`]; every pushed
/// input reserves a result slot in `terms` and schedules its evaluation in
/// `configs`.
pub struct ArgStack<'a, I, C> {
    /// The builder's stack of term (result) slots.
    terms: &'a mut Vec<ATerm>,
    /// The builder's stack of pending work.
    configs: &'a mut Vec<Config<I, C>>,
    /// Height of `terms` when this argument stack was created; the number of
    /// pushed arguments is counted from here.
    top_of_stack: usize,
}
155

            
156
impl<'a, I, C> ArgStack<'a, I, C> {
157
39737948
    fn new(terms: &'a mut Vec<ATerm>, configs: &'a mut Vec<Config<I, C>>) -> ArgStack<'a, I, C> {
158
39737948
        let top_of_stack = terms.len();
159
39737948
        ArgStack {
160
39737948
            terms,
161
39737948
            configs,
162
39737948
            top_of_stack,
163
39737948
        }
164
39737948
    }
165

            
166
    /// Returns the amount of arguments added.
167
10202974
    fn len(&self) -> usize {
168
10202974
        self.terms.len() - self.top_of_stack
169
10202974
    }
170

            
171
    /// Adds the term to the argument stack, will construct construct(C, args...) with the transformer applied to arguments.
172
28851357
    pub fn push(&mut self, input: I) {
173
28851357
        self.configs.push(Config::Apply(input, self.terms.len()));
174
28851357
        self.terms.push(ATerm::default());
175
28851357
    }
176
}
177

            
178
impl<I: fmt::Debug, C: fmt::Debug> fmt::Debug for TermBuilder<I, C> {
179
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180
        writeln!(f, "Terms: [")?;
181
        for (i, term) in self.terms.iter().enumerate() {
182
            writeln!(f, "{}\t{:?}", i, term)?;
183
        }
184
        writeln!(f, "]")?;
185

            
186
        writeln!(f, "Configs: [")?;
187
        for config in &self.configs {
188
            writeln!(f, "\t{:?}", config)?;
189
        }
190
        write!(f, "]")
191
    }
192
}
193

            
194
impl<I: fmt::Debug, C: fmt::Debug> fmt::Debug for Config<I, C> {
195
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196
        match self {
197
            Config::Apply(x, result) => write!(f, "Apply({:?}, {})", x, result),
198
            Config::Construct(symbol, arity, result) => {
199
                write!(f, "Construct({:?}, {}, {})", symbol, arity, result)
200
            }
201
        }
202
    }
203
}
204

            
205
/// Create a random term consisting of the given symbol and constants. Performs
206
/// iterations number of constructions, and uses chance_duplicates to choose the
207
/// amount of subterms that are duplicated.
208
401
pub fn random_term(
209
401
    tp: &mut TermPool,
210
401
    rng: &mut impl rand::Rng,
211
401
    symbols: &[(String, usize)],
212
401
    constants: &[String],
213
401
    iterations: usize,
214
401
) -> ATerm {
215
    use rand::prelude::IteratorRandom;
216

            
217
401
    debug_assert!(!constants.is_empty(), "We need constants to be able to create a term");
218

            
219
802
    let mut subterms = AHashSet::<ATerm>::from_iter(constants.iter().map(|name| {
220
802
        let symbol = tp.create_symbol(name, 0);
221
802
        let a: &[ATerm] = &[];
222
802
        tp.create(&symbol, a)
223
802
    }));
224
401

            
225
401
    let mut result = ATerm::default();
226
401
    for _ in 0..iterations {
227
4005
        let (symbol, arity) = symbols.iter().choose(rng).unwrap();
228
4005

            
229
4005
        let mut arguments = vec![];
230
8010
        for _ in 0..*arity {
231
8010
            arguments.push(subterms.iter().choose(rng).unwrap().clone());
232
8010
        }
233

            
234
4005
        let symbol = tp.create_symbol(symbol, *arity);
235
4005
        result = tp.create(&symbol, &arguments);
236
4005

            
237
4005
        // Make this term available as another subterm that can be used.
238
4005
        subterms.insert(result.clone());
239
    }
240

            
241
401
    result
242
401
}