summaryrefslogtreecommitdiffstats
path: root/libc/utils/MPFRWrapper/MPFRUtils.h
blob: d5ff590cd7bb69bcae59bd0dfd231846cd5dc44f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
//===-- MPFRUtils.h ---------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H
#define LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H

#include "src/__support/CPP/type_traits.h"
#include "test/UnitTest/RoundingModeUtils.h"
#include "test/UnitTest/Test.h"

#include <stdint.h>

namespace LIBC_NAMESPACE {
namespace testing {
namespace mpfr {

// Operations that can be checked against MPFR. The Begin*/End* enumerators are
// group markers, not operations: is_valid_operation() uses strict `<`
// comparisons against them to decide which input/output shape an operation
// has, so a new operation must be added between the markers of its group.
enum class Operation : int {
  // Operations which take a single floating point number as input
  // and produce a single floating point number as output. The input
  // and output floating point numbers are of the same kind.
  BeginUnaryOperationsSingleOutput,
  Abs,
  Acos,
  Acosh,
  Asin,
  Asinh,
  Atan,
  Atanh,
  Ceil,
  Cos,
  Cosh,
  Erf,
  Exp,
  Exp2,
  Exp10,
  Expm1,
  Floor,
  Log,
  Log2,
  Log10,
  Log1p,
  Mod2PI,
  ModPIOver2,
  ModPIOver4,
  Round,
  Sin,
  Sinh,
  Sqrt,
  Tan,
  Tanh,
  Trunc,
  EndUnaryOperationsSingleOutput,

  // Operations which take a single floating point number as input
  // but produce two outputs. The first output is a floating point
  // number of the same type as the input. The second output is of type
  // 'int'.
  BeginUnaryOperationsTwoOutputs,
  Frexp, // The floating point (first) output is the fractional part.
  EndUnaryOperationsTwoOutputs,

  // Operations which take two floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  BeginBinaryOperationsSingleOutput,
  Atan2,
  Fmod,
  Hypot,
  Pow,
  EndBinaryOperationsSingleOutput,

  // Operations which take two floating point numbers of the same type as
  // input and produce two outputs. The first output is a floating point
  // number of the same type as the inputs. The second output is of type 'int'.
  BeginBinaryOperationsTwoOutputs,
  RemQuo, // The floating point (first) output is the remainder.
  EndBinaryOperationsTwoOutputs,

  // Operations which take three floating point numbers of the same type as
  // input and produce a single floating point number of the same type as
  // output.
  // NOTE: "Ouput" in the marker below is a long-standing misspelling. It is
  // kept as-is because is_valid_operation() refers to this exact name.
  BeginTernaryOperationsSingleOuput,
  Fma,
  EndTernaryOperationsSingleOutput,
};

// Bring the rounding-mode helpers from fputil::testing into this namespace, so
// tests can spell them mpfr::RoundingMode / mpfr::ForceRoundingMode (the
// *_ALL_ROUNDING macros rely on this).
using LIBC_NAMESPACE::fputil::testing::RoundingMode;
using LIBC_NAMESPACE::fputil::testing::ForceRoundingMode;

// The two same-typed floating point arguments (x, y) of a binary operation.
// `Type` exposes the element type to generic code.
template <typename FloatT> struct BinaryInput {
  static_assert(
      LIBC_NAMESPACE::cpp::is_floating_point_v<FloatT>,
      "Template parameter of BinaryInput must be a floating point type.");

  using Type = FloatT;

  FloatT x;
  FloatT y;
};

// The three same-typed floating point arguments (x, y, z) of a ternary
// operation such as Operation::Fma. `Type` exposes the element type.
template <typename FloatT> struct TernaryInput {
  static_assert(
      LIBC_NAMESPACE::cpp::is_floating_point_v<FloatT>,
      "Template parameter of TernaryInput must be a floating point type.");

  using Type = FloatT;

  FloatT x;
  FloatT y;
  FloatT z;
};

// Result of an operation that yields two outputs: a value of type ValueT
// followed by an int (the "two outputs" groups of Operation).
template <typename ValueT> struct BinaryOutput {
  ValueT f; // First output: the floating point value.
  int i;    // Second output: the integral value.
};

namespace internal {

// Compile-time trait: VALUE is true exactly when InputT is BinaryInput<F> and
// OutputT is BinaryOutput<F> for one and the same floating point type F.
template <typename InputT, typename OutputT>
struct AreMatchingBinaryInputAndBinaryOutput {
  static constexpr bool VALUE{false};
};

template <typename FloatT>
struct AreMatchingBinaryInputAndBinaryOutput<BinaryInput<FloatT>,
                                             BinaryOutput<FloatT>> {
  static constexpr bool VALUE{cpp::is_floating_point_v<FloatT>};
};

// Checkers used by MPFRMatcher::match(). Only declarations live in this
// header; the definitions are out of line (presumably MPFRUtils.cpp —
// confirm). Each takes the operation, its input(s), the value libc produced,
// an ULP tolerance and a rounding mode. Each returns a bool that the matcher
// reports as the match result; the exact acceptance criterion lives in the
// definitions.

// T -> T
template <typename T>
bool compare_unary_operation_single_output(Operation op, T input,
                                           T libc_output,
                                           double ulp_tolerance,
                                           RoundingMode rounding);

// T -> BinaryOutput<T>
template <typename T>
bool compare_unary_operation_two_outputs(Operation op, T input,
                                         const BinaryOutput<T> &libc_output,
                                         double ulp_tolerance,
                                         RoundingMode rounding);

// BinaryInput<T> -> T
template <typename T>
bool compare_binary_operation_one_output(Operation op,
                                         const BinaryInput<T> &input,
                                         T libc_output,
                                         double ulp_tolerance,
                                         RoundingMode rounding);

// BinaryInput<T> -> BinaryOutput<T>
template <typename T>
bool compare_binary_operation_two_outputs(Operation op,
                                          const BinaryInput<T> &input,
                                          const BinaryOutput<T> &libc_output,
                                          double ulp_tolerance,
                                          RoundingMode rounding);

// TernaryInput<T> -> T
template <typename T>
bool compare_ternary_operation_one_output(Operation op,
                                          const TernaryInput<T> &input,
                                          T libc_output,
                                          double ulp_tolerance,
                                          RoundingMode rounding);

// Diagnostics used by MPFRMatcher::explainError(), one per input/output shape
// and mirroring the compare_* checkers. They are declared here and defined out
// of line; what exactly they print is determined by those definitions.

// T -> T
template <typename T>
void explain_unary_operation_single_output_error(Operation op, T input,
                                                 T match_value,
                                                 double ulp_tolerance,
                                                 RoundingMode rounding);

// T -> BinaryOutput<T>
template <typename T>
void explain_unary_operation_two_outputs_error(
    Operation op, T input, const BinaryOutput<T> &match_value,
    double ulp_tolerance, RoundingMode rounding);

// BinaryInput<T> -> T
template <typename T>
void explain_binary_operation_one_output_error(
    Operation op, const BinaryInput<T> &input, T match_value,
    double ulp_tolerance, RoundingMode rounding);

// BinaryInput<T> -> BinaryOutput<T>
template <typename T>
void explain_binary_operation_two_outputs_error(
    Operation op, const BinaryInput<T> &input,
    const BinaryOutput<T> &match_value, double ulp_tolerance,
    RoundingMode rounding);

// TernaryInput<T> -> T
template <typename T>
void explain_ternary_operation_one_output_error(
    Operation op, const TernaryInput<T> &input, T match_value,
    double ulp_tolerance, RoundingMode rounding);

// Matcher used through EXPECT_THAT / ASSERT_THAT (by the EXPECT_MPFR_MATCH and
// ASSERT_MPFR_MATCH macros) and directly by the TEST_MPFR_MATCH* macros. It
// checks a libc result using the compare_* / explain_* functions declared
// above.
//
// Template parameters:
//   op         - the operation being checked.
//   silent     - value returned by is_silent(); when true the `explainError`
//                step is skipped.
//   InputType  - a floating point type, BinaryInput<T> or TernaryInput<T>.
//   OutputType - a floating point type or BinaryOutput<T>.
// The (InputType, OutputType) pair selects, through overload resolution on
// the private match()/explain_error() helpers, which compare_* / explain_*
// function is called. get_mpfr_matcher and get_silent_mpfr_matcher only
// produce matchers whose types satisfy is_valid_operation.
template <Operation op, bool silent, typename InputType, typename OutputType>
class MPFRMatcher : public testing::Matcher<OutputType> {
  InputType input;
  // The value last passed to match(). It is value-initialized so that
  // explainError() never reads an indeterminate value, even if it is invoked
  // before match().
  OutputType match_value{};
  double ulp_tolerance;
  RoundingMode rounding;

public:
  MPFRMatcher(InputType testInput, double ulp_tolerance, RoundingMode rounding)
      : input(testInput), ulp_tolerance(ulp_tolerance), rounding(rounding) {}

  // Stores `libcResult` (so that explainError() can describe it) and returns
  // the verdict of the compare_* function matching this matcher's types.
  bool match(OutputType libcResult) {
    match_value = libcResult;
    return match(input, match_value);
  }

  // This method is marked with NOLINT because the name `explainError` does not
  // conform to the coding style.
  void explainError() override { // NOLINT
    explain_error(input, match_value);
  }

  // Whether the `explainError` step is skipped or not.
  bool is_silent() const override { return silent; }

private:
  // Each overload below handles one input/output shape and forwards to the
  // corresponding checker together with this matcher's tolerance and
  // rounding mode. All structured inputs/outputs are taken by const
  // reference.

  // T -> T
  template <typename T> bool match(T in, T out) {
    return compare_unary_operation_single_output(op, in, out, ulp_tolerance,
                                                 rounding);
  }

  // T -> BinaryOutput<T>
  template <typename T> bool match(T in, const BinaryOutput<T> &out) {
    return compare_unary_operation_two_outputs(op, in, out, ulp_tolerance,
                                               rounding);
  }

  // BinaryInput<T> -> T
  template <typename T> bool match(const BinaryInput<T> &in, T out) {
    return compare_binary_operation_one_output(op, in, out, ulp_tolerance,
                                               rounding);
  }

  // BinaryInput<T> -> BinaryOutput<T>
  template <typename T>
  bool match(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
    return compare_binary_operation_two_outputs(op, in, out, ulp_tolerance,
                                                rounding);
  }

  // TernaryInput<T> -> T
  template <typename T> bool match(const TernaryInput<T> &in, T out) {
    return compare_ternary_operation_one_output(op, in, out, ulp_tolerance,
                                                rounding);
  }

  // T -> T
  template <typename T> void explain_error(T in, T out) {
    explain_unary_operation_single_output_error(op, in, out, ulp_tolerance,
                                                rounding);
  }

  // T -> BinaryOutput<T>
  template <typename T> void explain_error(T in, const BinaryOutput<T> &out) {
    explain_unary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                              rounding);
  }

  // BinaryInput<T> -> T
  template <typename T> void explain_error(const BinaryInput<T> &in, T out) {
    explain_binary_operation_one_output_error(op, in, out, ulp_tolerance,
                                              rounding);
  }

  // BinaryInput<T> -> BinaryOutput<T>
  template <typename T>
  void explain_error(const BinaryInput<T> &in, const BinaryOutput<T> &out) {
    explain_binary_operation_two_outputs_error(op, in, out, ulp_tolerance,
                                               rounding);
  }

  // TernaryInput<T> -> T
  template <typename T> void explain_error(const TernaryInput<T> &in, T out) {
    explain_ternary_operation_one_output_error(op, in, out, ulp_tolerance,
                                               rounding);
  }
};

} // namespace internal

// Return true if the input and ouput types for the operation op are valid
// types.
template <Operation op, typename InputType, typename OutputType>
constexpr bool is_valid_operation() {
  return (Operation::BeginUnaryOperationsSingleOutput < op &&
          op < Operation::EndUnaryOperationsSingleOutput &&
          cpp::is_same_v<InputType, OutputType> &&
          cpp::is_floating_point_v<InputType>) ||
         (Operation::BeginUnaryOperationsTwoOutputs < op &&
          op < Operation::EndUnaryOperationsTwoOutputs &&
          cpp::is_floating_point_v<InputType> &&
          cpp::is_same_v<OutputType, BinaryOutput<InputType>>) ||
         (Operation::BeginBinaryOperationsSingleOutput < op &&
          op < Operation::EndBinaryOperationsSingleOutput &&
          cpp::is_floating_point_v<OutputType> &&
          cpp::is_same_v<InputType, BinaryInput<OutputType>>) ||
         (Operation::BeginBinaryOperationsTwoOutputs < op &&
          op < Operation::EndBinaryOperationsTwoOutputs &&
          internal::AreMatchingBinaryInputAndBinaryOutput<InputType,
                                                          OutputType>::VALUE) ||
         (Operation::BeginTernaryOperationsSingleOuput < op &&
          op < Operation::EndTernaryOperationsSingleOutput &&
          cpp::is_floating_point_v<OutputType> &&
          cpp::is_same_v<InputType, TernaryInput<OutputType>>);
}

// Creates a non-silent MPFRMatcher for `op`. `output_unused` is only there so
// that OutputType can be deduced from the value being checked; its value is
// ignored. The overload is removed by SFINAE unless
// is_valid_operation<op, InputType, OutputType>() holds. AddressSanitizer
// instrumentation is disabled for this function via `no_sanitize("address")`
// (the reason is not stated here).
template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address"))) cpp::enable_if_t<
    is_valid_operation<op, InputType, OutputType>(),
    internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>>
get_mpfr_matcher(InputType input, [[maybe_unused]] OutputType output_unused,
                 double ulp_tolerance, RoundingMode rounding) {
  using LoudMatcher =
      internal::MPFRMatcher<op, /*is_silent*/ false, InputType, OutputType>;
  return LoudMatcher(input, ulp_tolerance, rounding);
}

// Same as get_mpfr_matcher, except that the matcher's is_silent() returns
// true, so its `explainError` step is skipped. `output_unused` only drives
// deduction of OutputType and is otherwise ignored.
template <Operation op, typename InputType, typename OutputType>
__attribute__((no_sanitize("address"))) cpp::enable_if_t<
    is_valid_operation<op, InputType, OutputType>(),
    internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>>
get_silent_mpfr_matcher(InputType input,
                        [[maybe_unused]] OutputType output_unused,
                        double ulp_tolerance, RoundingMode rounding) {
  using QuietMatcher =
      internal::MPFRMatcher<op, /*is_silent*/ true, InputType, OutputType>;
  return QuietMatcher(input, ulp_tolerance, rounding);
}

// Rounding helpers. They are only declared here; the definitions are not in
// this header.
// NOTE(review): judging by the names, `round` presumably rounds `x` to an
// integral value in rounding mode `mode`. `round_to_long` writes an integral
// value to `result`, and its bool return reports success or failure — confirm
// the polarity against the definitions before relying on it.
template <typename T>
T round(T x, RoundingMode mode);

template <typename T>
bool round_to_long(T x, long &result);

template <typename T>
bool round_to_long(T x, RoundingMode mode, long &result);

} // namespace mpfr
} // namespace testing
} // namespace LIBC_NAMESPACE

// GET_MPFR_DUMMY_ARG is appended as the last argument of every GET_MPFR_MACRO
// call. This guarantees that the variadic `...` of GET_MPFR_MACRO always
// receives at least one argument, which avoids the compiler warning
// `gnu-zero-variadic-macro-arguments`. For the supported 4- and 5-argument
// forms it only ever lands in `...` and is never selected.
#define GET_MPFR_DUMMY_ARG(...) 0

// Argument-count dispatch: expands to its sixth argument. Callers pass
// `__VA_ARGS__` followed by the 5-argument macro, the 4-argument macro and
// GET_MPFR_DUMMY_ARG. With 5 user arguments the 5-argument macro is in sixth
// position; with 4 the list shifts by one and the 4-argument macro is chosen.
#define GET_MPFR_MACRO(ARG1, ARG2, ARG3, ARG4, ARG5, NAME, ...) NAME

// Non-fatal check (via EXPECT_THAT) that `match_value`, the libc result, is
// accepted by the MPFR matcher for `op` at `input`, with `rounding` fixed to
// RoundingMode::Nearest. `match_value` appears twice in the expansion: once to
// deduce the matcher's OutputType and once as the value checked. Side effects
// of that expression may therefore happen twice.
#define EXPECT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

// Same as EXPECT_MPFR_MATCH_DEFAULT, but with an explicit `rounding` mode.
#define EXPECT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  EXPECT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance, rounding))

// Dispatches on argument count:
//   EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance)
//     -> EXPECT_MPFR_MATCH_DEFAULT
//   EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance, rounding)
//     -> EXPECT_MPFR_MATCH_ROUNDING
#define EXPECT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, EXPECT_MPFR_MATCH_ROUNDING,                      \
                 EXPECT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)

// Expression form: evaluates to the bool returned by the matcher's match().
// Unlike the EXPECT_/ASSERT_ macros it does not go through EXPECT_THAT, so the
// caller decides how to report a mismatch. `match_value` appears twice in the
// expansion (for OutputType deduction and as the checked value).
#define TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,        \
                                 rounding)                                     \
  LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(input, match_value,      \
                                                      ulp_tolerance, rounding) \
      .match(match_value)

// Expression form with rounding fixed to RoundingMode::Nearest: evaluates to a
// bool, exactly like TEST_MPFR_MATCH_ROUNDING.
#define TEST_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)         \
  TEST_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,              \
                           LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest)

// Dispatches on argument count, and both forms evaluate to a bool:
//   TEST_MPFR_MATCH(op, input, match_value, ulp_tolerance)
//     -> TEST_MPFR_MATCH_DEFAULT
//   TEST_MPFR_MATCH(op, input, match_value, ulp_tolerance, rounding)
//     -> TEST_MPFR_MATCH_ROUNDING
// The 4-argument form must not route to EXPECT_MPFR_MATCH_DEFAULT: that is an
// EXPECT_THAT statement, not a bool expression.
#define TEST_MPFR_MATCH(...)                                                   \
  GET_MPFR_MACRO(__VA_ARGS__, TEST_MPFR_MATCH_ROUNDING,                        \
                 TEST_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                  \
  (__VA_ARGS__)

// Runs EXPECT_MPFR_MATCH once for each rounding mode, in the order Nearest,
// Upward, Downward, TowardZero, passing that mode as `rounding`.
//
// Before each check a mpfr::ForceRoundingMode guard is constructed for the
// mode, and the check runs only if the guard's `success` flag is set. The
// guard presumably switches the floating point rounding mode for its lifetime
// (see test/UnitTest/RoundingModeUtils.h — confirm). `match_value` is written
// out again inside every EXPECT_MPFR_MATCH, so a libc call passed as
// `match_value` is re-executed while each guard is alive. That is the point of
// the macro, so do not hoist it into a variable. All four guards stay alive
// until the closing brace.
//
// The expansion is a bare `{ ... }` block (not `do { ... } while (0)`), so
// `if (c) EXPECT_MPFR_MATCH_ALL_ROUNDING(...); else ...` does not compile.
// Brace the enclosing if instead. The block also introduces a local namespace
// alias `mpfr`.
#define EXPECT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    if (__r1.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    if (__r2.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    if (__r3.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    if (__r4.success) {                                                        \
      EXPECT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }

// Like TEST_MPFR_MATCH_ROUNDING, but the matcher comes from
// get_silent_mpfr_matcher (its is_silent() returns true). Evaluates to the
// bool returned by match(), so callers can branch on it, e.g.
// `if (!TEST_MPFR_MATCH_ROUNDING_SILENTLY(...))`.
#define TEST_MPFR_MATCH_ROUNDING_SILENTLY(op, input, match_value,              \
                                          ulp_tolerance, rounding)             \
  LIBC_NAMESPACE::testing::mpfr::get_silent_mpfr_matcher<op>(                  \
      input, match_value, ulp_tolerance, rounding)                             \
      .match(match_value)

// ASSERT_THAT counterpart of EXPECT_MPFR_MATCH_DEFAULT: checks `match_value`
// against the MPFR matcher for `op` at `input`, with `rounding` fixed to
// RoundingMode::Nearest. `match_value` appears twice in the expansion.
#define ASSERT_MPFR_MATCH_DEFAULT(op, input, match_value, ulp_tolerance)       \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance,                           \
                  LIBC_NAMESPACE::testing::mpfr::RoundingMode::Nearest))

// Same as ASSERT_MPFR_MATCH_DEFAULT, but with an explicit `rounding` mode.
#define ASSERT_MPFR_MATCH_ROUNDING(op, input, match_value, ulp_tolerance,      \
                                   rounding)                                   \
  ASSERT_THAT(match_value,                                                     \
              LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<op>(             \
                  input, match_value, ulp_tolerance, rounding))

// Dispatches on argument count:
//   4 arguments -> ASSERT_MPFR_MATCH_DEFAULT
//   5 arguments -> ASSERT_MPFR_MATCH_ROUNDING
#define ASSERT_MPFR_MATCH(...)                                                 \
  GET_MPFR_MACRO(__VA_ARGS__, ASSERT_MPFR_MATCH_ROUNDING,                      \
                 ASSERT_MPFR_MATCH_DEFAULT, GET_MPFR_DUMMY_ARG)                \
  (__VA_ARGS__)

// ASSERT counterpart of EXPECT_MPFR_MATCH_ALL_ROUNDING. It runs
// ASSERT_MPFR_MATCH once per rounding mode (Nearest, Upward, Downward,
// TowardZero). Each check is guarded by a mpfr::ForceRoundingMode whose
// `success` flag must be set, and `match_value` is re-evaluated for every
// mode, so do not hoist it.
// NOTE(review): if ASSERT_THAT ends the enclosing test on failure (as ASSERT_*
// macros conventionally do — confirm in the test framework), the remaining
// rounding modes are not checked after the first mismatch.
// The expansion is a bare `{ ... }` block, so it cannot be the body of an
// unbraced if that is followed by `else`.
#define ASSERT_MPFR_MATCH_ALL_ROUNDING(op, input, match_value, ulp_tolerance)  \
  {                                                                            \
    namespace mpfr = LIBC_NAMESPACE::testing::mpfr;                            \
    mpfr::ForceRoundingMode __r1(mpfr::RoundingMode::Nearest);                 \
    if (__r1.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Nearest);                          \
    }                                                                          \
    mpfr::ForceRoundingMode __r2(mpfr::RoundingMode::Upward);                  \
    if (__r2.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Upward);                           \
    }                                                                          \
    mpfr::ForceRoundingMode __r3(mpfr::RoundingMode::Downward);                \
    if (__r3.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::Downward);                         \
    }                                                                          \
    mpfr::ForceRoundingMode __r4(mpfr::RoundingMode::TowardZero);              \
    if (__r4.success) {                                                        \
      ASSERT_MPFR_MATCH(op, input, match_value, ulp_tolerance,                 \
                        mpfr::RoundingMode::TowardZero);                       \
    }                                                                          \
  }

#endif // LLVM_LIBC_UTILS_MPFRWRAPPER_MPFRUTILS_H