summaryrefslogtreecommitdiffstats
path: root/src/qdoc/catch_generators/src/catch_generators/generators/k_partition_of_r_generator.h
blob: 832ee2838e7739ac396fb3d1bda243569d8f5020 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Copyright (C) 2022 The Qt Company Ltd.
// SPDX-License-Identifier: LicenseRef-Qt-Commercial OR GPL-3.0-only WITH Qt-GPL-exception-1.0

#pragma once

#include "../namespaces.h"

#include <catch/catch.hpp>

#include <random>
#include <numeric>
#include <algorithm>

namespace QDOC_CATCH_GENERATORS_ROOT_NAMESPACE {
    namespace QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE {

        /*
         * Catch generator that yields random partitions of a real
         * number r into k addends, i.e. vectors of k doubles whose
         * sum is (up to floating point rounding) r.
         *
         * The generator never exhausts: next() always reports that a
         * new value is available.
         *
         * The random engine is seeded from std::random_device.
         */
        class KPartitionOfRGenerator : public Catch::Generators::IGenerator<std::vector<double>> {
        public:
            /*
             * Builds a generator of k-element partitions of r.
             *
             * r is expected to be greater or equal to zero and k to be
             * at least one; both expectations are only checked through
             * assert, after the data members have been initialized.
             */
            KPartitionOfRGenerator(double r, std::size_t k)
                : random_engine{std::random_device{}()},
                  interval_distribution{0.0, r},
                  k{k},
                  r{r},
                  current_partition(k)
            {
                assert(r >= 0.0);
                assert(k >= 1);

                // Produce the first partition right away, so that
                // get() can be called on a freshly built generator.
                static_cast<void>(next());
            }

            // Returns the partition produced by the latest call to next().
            std::vector<double> const& get() const override { return current_partition; }

            /*
             * Replaces the current partition with a newly generated
             * one. Always returns true.
             *
             * REMARK: The approach used here was not formally proved;
             * it is based on intuition. It is probably erroneous but
             * is expected to be good enough for our case.
             *
             * The aim is a non skewed distribution of the elements of
             * the partition, so that, over many runs, the tests have
             * a good chance of hitting many of the available values.
             *
             * Intuitively, take the interval [0.0, r] on the real
             * line and choose k - 1 points in it. Those points split
             * the interval into k disjoint sub-intervals:
             *
             *   0      c_1          c_2                  c_3        r
             *   |-------|------------|--------------------|---------|
             *      p1         p2               p3              p4
             *
             *   p1 = c_1 - 0, p2 = c_2 - c_1, p3 = c_3 - c_2, p4 = r - c_3
             *
             * The length of each sub-interval becomes one element of
             * the partition, so that the elements trivially sum to r.
             *
             * As long as the points are chosen uniformly, the
             * probability of each partition being produced should
             * tend to be uniform too.
             */
            bool next() override {
                // A single addend can only be r itself.
                if (k == 1) {
                    current_partition[0] = r;
                    return true;
                }

                // Fill every slot but the last with a random cut
                // point; the last slot holds the right end of the
                // interval.
                auto const last_slot = current_partition.end() - 1;
                for (auto slot = current_partition.begin(); slot != last_slot; ++slot)
                    *slot = interval_distribution(random_engine);
                *last_slot = r;

                // Order the cut points, then replace each of them, in
                // place, with its distance from the preceding one
                // (the first element is kept as is, being the
                // distance from 0).
                std::sort(current_partition.begin(), current_partition.end());
                std::adjacent_difference(current_partition.begin(), current_partition.end(), current_partition.begin());

                return true;
            }

        private:
            std::mt19937 random_engine;
            // Draws the cut points from the interval going from 0.0 to r.
            std::uniform_real_distribution<double> interval_distribution;

            // Number of elements of each generated partition.
            std::size_t k;
            // Value that the elements of each partition sum to.
            double r;

            // Latest generated partition; always holds k elements.
            std::vector<double> current_partition;
        };

    } // end QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE

    /*!
     * Returns a generator that produces collections of \a k elements
     * whose sum is \a r.
     *
     * \a r must be a real number greater than or equal to zero and
     * \a k must be a natural number greater than zero.
     *
     * The generated partitions tend to be uniformly distributed over
     * the set of partitions of \a r.
     */
    inline Catch::Generators::GeneratorWrapper<std::vector<double>> k_partition_of_r(double r, std::size_t k) {
        using Partition = std::vector<double>;
        using PartitionGenerator = Catch::Generators::IGenerator<Partition>;

        // The wrapper takes ownership of the newly allocated generator.
        return Catch::Generators::GeneratorWrapper<Partition>(
            std::unique_ptr<PartitionGenerator>(
                new QDOC_CATCH_GENERATORS_PRIVATE_NAMESPACE::KPartitionOfRGenerator(r, k)
            )
        );
    }

} // end QDOC_CATCH_GENERATORS_ROOT_NAMESPACE