summaryrefslogtreecommitdiffstats
path: root/src/qdoc/catch_generators/src/catch_generators/generators/combinators/oneof_generator.h
blob: 5de9dcb6c6921395432d9fea7d415a3c7c1854b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
// Copyright (C) 2022 The Qt Company Ltd.
// SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0

#pragma once

#include "../../namespaces.h"
#include "../../utilities/statistics/percentages.h"
#include "../../utilities/semantics/generator_handler.h"

#include <catch/catch.hpp>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <memory>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

namespace QDOC_CATCH_GENERATORS_ROOT_NAMESPACE {
    namespace QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE {

        /*
         * Catch generator that, for every step, draws an index from a
         * discrete distribution built over the given weights and asks
         * the generator stored at that index for its next value.
         *
         * A step fails, and thus generation ends, as soon as the
         * generator drawn for that step cannot advance.
         */
        template<typename T>
        class OneOfGenerator : public Catch::Generators::IGenerator<T> {
        public:
            /*
             * Takes ownership of \a generators. \a weights must hold one
             * entry per generator, and the entries must add up to a
             * hundred. Both conditions are only checked by assertions.
             */
            OneOfGenerator(
                std::vector<Catch::Generators::GeneratorWrapper<T>>&& generators,
                const std::vector<double>& weights
            ) : generators{std::move(generators)},
                random_engine{std::random_device{}()},
                choice_distribution{weights.cbegin(), weights.cend()}
            {
                assert(weights.size() == this->generators.size());
                assert(std::reduce(weights.cbegin(), weights.cend()) == Approx(100.0));

                // Replace every stored generator, in place, with the result
                // of passing it through the utilities' handler.
                for (auto& wrapped : this->generators)
                    wrapped = QDOC_CATCH_GENERATORS_UTILITIES_ABSOLUTE_NAMESPACE::handler(std::move(wrapped));

                // Produce an initial value; whether it succeeded is ignored.
                static_cast<void>(next());
            }

            T const& get() const override { return current_value; }

            bool next() override {
                auto& selected = generators[choice_distribution(random_engine)];

                const bool advanced = selected.next();
                if (advanced)
                    current_value = selected.get();

                return advanced;
            }

        private:
            std::vector<Catch::Generators::GeneratorWrapper<T>> generators;

            std::mt19937 random_engine;
            std::discrete_distribution<std::size_t> choice_distribution;

            T current_value;
        };

    } // end QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE

    /*!
     * Returns a generator whose set of elements is the union of the
     * set of elements of the generators in \a generators.
     *
     * Each time the generator produces a value, a generator from \a
     * generators is randomly chosen to produce the value.
     *
     * The distribution for the choice is given by \a weights.
     * The \e {ith} element in \a weights represents the percentage
     * probability of the \e {ith} element of \a generators being
     * chosen.
     *
     * It follows that the size of \a weights must be the same as the
     * size of \a generators.
     *
     * Furthermore, the sum of the elements in \a weights should be a
     * hundred.
     *
     * The generator produces values until a generator that is chosen
     * to produce a value is unable to do so.
     * The first such generator to do so will stop the generation
     * independently of the availability of the other generators.
     *
     * Similarly, values will be produced as long as the chosen
     * generator can produce a value, independently of the other
     * generators being exhausted already.
     */
    template<typename T>
    inline Catch::Generators::GeneratorWrapper<T> oneof(
        std::vector<Catch::Generators::GeneratorWrapper<T>>&& generators,
        const std::vector<double>& weights
    ) {
        // Spell out the element type instead of relying on class template
        // argument deduction, and let make_unique own the allocation
        // rather than a naked new.
        std::unique_ptr<Catch::Generators::IGenerator<T>> generator{
            std::make_unique<QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE::OneOfGenerator<T>>(std::move(generators), weights)
        };

        return Catch::Generators::GeneratorWrapper<T>(std::move(generator));
    }


    /*!
     * Returns a generator whose set of elements is the union of the
     * set of elements of the generators in \a generators and in which
     * the distribution of the generated elements is uniform over \a
     * generators.
     *
     * Each time the generator produces a value, a generator from \a
     * generators is randomly chosen to produce the value.
     *
     * Each generator from \a generators has the same chance of being
     * chosen.
     *
     * Do note that the distribution over the set of values is not
     * necessarily uniform.
     *
     * The generator produces values until a generator that is chosen
     * to produce a value is unable to do so.
     * The first such generator to do so will stop the generation
     * independently of the availability of the other generators.
     *
     * Similarly, values will be produced as long as the chosen
     * generator can produce a value, independently of the other
     * generators being exhausted already.
     */
    template<typename T>
    inline Catch::Generators::GeneratorWrapper<T> uniform_oneof(
        std::vector<Catch::Generators::GeneratorWrapper<T>>&& generators
    ) {
        // Every generator gets the same weight, as computed by
        // uniform_probability() for the number of generators.
        const std::vector<double> weights(
            generators.size(),
            QDOC_CATCH_GENERATORS_UTILITIES_ABSOLUTE_NAMESPACE::uniform_probability(generators.size())
        );

        // oneof() takes the weights by const reference; moving them
        // would be a no-op, so they are passed as-is.
        return oneof(std::move(generators), weights);
    }

    /*!
     * Returns a generator whose set of elements is the union of the
     * set of elements of the generators in \a generators and in which
     * the distribution of the generated elements is uniform over the
     * elements of \a generators.
     *
     * The generators in \a generators should have a uniform
     * distribution and be finite.
     * If the sets of elements produced by the generators in \a
     * generators are not disjoint, the distribution will be skewed
     * towards repeated elements.
     *
     * Each time the generator produces a value, a generator from \a
     * generators is randomly chosen to produce the value.
     *
     * Each generator from \a generators has a probability of being
     * chosen that is proportional to the cardinality of the set it
     * produces.
     *
     * The \e {ith} element of \a amounts should contain the
     * cardinality of the set produced by the \e {ith} generator in \a
     * generators. Hence \a amounts must have the same size as \a
     * generators, and its elements must not all be zero.
     *
     * The generator produces values until a generator that is chosen
     * to produce a value is unable to do so.
     * The first such generator to do so will stop the generation
     * independently of the availability of the other generators.
     *
     * Similarly, values will be produced as long as the chosen
     * generator can produce a value, independently of the other
     * generators being exhausted already.
     */
    template<typename T>
    inline Catch::Generators::GeneratorWrapper<T> uniformly_valued_oneof(
        std::vector<Catch::Generators::GeneratorWrapper<T>>&& generators,
        const std::vector<std::size_t>& amounts
    ) {
        const std::size_t total_amount{std::accumulate(amounts.cbegin(), amounts.cend(), std::size_t{0})};

        // Every weight below is computed relative to total_amount, so a
        // zero total (no amounts, or all of them zero) cannot yield valid
        // weights. Fail here, where the cause is visible, instead of in
        // the weight checks performed later by the generator.
        assert(total_amount > 0);

        std::vector<double> weights;
        weights.reserve(amounts.size());

        // Turn each amount into a weight relative to total_amount via
        // percent_of(); oneof() expects the weights to add up to a hundred.
        std::transform(
            amounts.cbegin(), amounts.cend(),
            std::back_inserter(weights),
            [total_amount](std::size_t amount){
                return QDOC_CATCH_GENERATORS_UTILITIES_ABSOLUTE_NAMESPACE::percent_of(static_cast<double>(amount), static_cast<double>(total_amount));
            }
        );

        // oneof() takes the weights by const reference; moving them
        // would be a no-op, so they are passed as-is.
        return oneof(std::move(generators), weights);
    }

} // end QDOC_CATCH_GENERATORS_ROOT_NAMESPACE