1 | // -*- C++ -*- |
2 | //===-- transform_binary.pass.cpp -----------------------------------------===// |
3 | // |
4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | // UNSUPPORTED: c++03, c++11, c++14 |
11 | |
12 | #include "support/pstl_test_config.h" |
13 | |
14 | #include <execution> |
15 | #include <algorithm> |
16 | |
17 | #include "support/utils.h" |
18 | |
19 | using namespace TestUtils; |
20 | |
// Binary operation under test: returns Out(val + x - y), where `val` is fixed at construction.
// operator() is const; main() uses this type both directly and wrapped by non_const().
template <typename In1, typename In2, typename Out>
class TheOperation
{
    Out val;

  public:
    // explicit: an Out value must not silently convert into an operation object.
    explicit TheOperation(Out v) : val(v) {}

    // Computes val + x - y using the usual arithmetic conversions, then converts the result to Out.
    Out
    operator()(const In1& x, const In2& y) const
    {
        return Out(val + x - y);
    }
};
34 | |
// Checks the output of a binary transform and then refills the output range.
//
// For each position k the expected value is Out(1.5) + *first1 - *first2. For integral Out,
// Out(1.5) truncates to 1. Integral outputs must match exactly. Floating-point outputs must match
// within a relative tolerance.
//
// After checking, element k is overwritten with 7 * k - 5 (or 0 when k % 7 == 4). Presumably this
// is so the next policy's run cannot pass by leaving stale results in place.
template <typename InputIterator1, typename InputIterator2, typename OutputIterator>
void
check_and_reset(InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, OutputIterator out_first)
{
    typedef typename std::iterator_traits<OutputIterator>::value_type Out;
    typename std::iterator_traits<OutputIterator>::difference_type k = 0;
    for (; first1 != last1; ++first1, ++first2, ++out_first, ++k)
    {
        // check
        Out expected = Out(1.5) + *first1 - *first2;
        Out actual = *out_first;
        if constexpr (std::is_floating_point<Out>::value)
        {
            // A fixed absolute bound (previously `< 1e7`) cannot suit both small and very large
            // magnitudes, so scale the bound by the largest magnitude involved. Including the
            // operands, not just the result, keeps the check meaningful when x - y cancels.
            auto magnitude = [](Out v) { return v < Out(0) ? -v : v; };
            const Out diff = magnitude(expected - actual);
            const Out scale =
                std::max({Out(1), magnitude(Out(*first1)), magnitude(Out(*first2)), magnitude(expected)});
            // 1e-5 is generous compared with single-precision epsilon (~1.2e-7). It leaves room for
            // evaluation-precision differences between the policy under test and this reference.
            EXPECT_TRUE(diff <= Out(1e-5) * scale, "wrong value in output sequence");
        }
        else
        {
            EXPECT_EQ(expected, actual, "wrong value in output sequence");
        }
        // reset
        *out_first = k % 7 != 4 ? 7 * k - 5 : 0;
    }
}
59 | |
// Functor driven by invoke_on_all_policies. It runs binary std::transform under the given execution
// policy, checks the returned iterator and the produced values, and then resets the output range.
struct test_one_policy
{
    template <typename Policy, typename InputIterator1, typename InputIterator2, typename OutputIterator,
              typename BinaryOp>
    void
    operator()(Policy&& exec, InputIterator1 first1, InputIterator1 last1, InputIterator2 first2, InputIterator2,
               OutputIterator out_first, OutputIterator out_last, BinaryOp op)
    {
        auto result = std::transform(exec, first1, last1, first2, out_first, op);
        // std::transform returns the iterator one past the last element written. Every caller in this
        // file sizes the output range like [first1, last1), so that iterator must equal out_last.
        EXPECT_TRUE(result == out_last, "transform returned wrong iterator");
        check_and_reset(first1, last1, first2, out_first);
    }
};
73 | |
74 | template <typename In1, typename In2, typename Out, typename Predicate> |
75 | void |
76 | test(Predicate pred) |
77 | { |
78 | for (size_t n = 0; n <= 100000; n = n <= 16 ? n + 1 : size_t(3.1415 * n)) |
79 | { |
80 | Sequence<In1> in1(n, [](size_t k) { return k % 5 != 1 ? 3 * k - 7 : 0; }); |
81 | Sequence<In2> in2(n, [](size_t k) { return k % 7 != 2 ? 5 * k - 5 : 0; }); |
82 | |
83 | Sequence<Out> out(n, [](size_t) { return -1; }); |
84 | |
85 | invoke_on_all_policies(test_one_policy(), in1.begin(), in1.end(), in2.begin(), in2.end(), out.begin(), |
86 | out.end(), pred); |
87 | invoke_on_all_policies(test_one_policy(), in1.cbegin(), in1.cend(), in2.cbegin(), in2.cend(), out.begin(), |
88 | out.end(), pred); |
89 | } |
90 | } |
91 | |
// Checks that the parallel binary transform accepts a functor whose operator() is non-const.
// The input range is empty, so this is essentially a compile/instantiation check.
template <typename T>
struct test_non_const
{
    template <typename Policy, typename InputIterator, typename OutputIterator>
    void
    operator()(Policy&& exec, InputIterator input_iter, OutputIterator out_iter)
    {
        auto body = [&]() {
            InputIterator second_first = input_iter;
            // [input_iter, input_iter) is empty; only the call itself matters.
            transform(exec, input_iter, input_iter, second_first, out_iter, non_const(std::plus<T>()));
        };
        // invoke_if is handed the policy and may skip the body for some policies (see its definition).
        invoke_if(exec, body);
    }
};
105 | |
106 | int |
107 | main() |
108 | { |
109 | //const operator() |
110 | test<int32_t, int32_t, int32_t>(pred: TheOperation<int32_t, int32_t, int32_t>(1)); |
111 | test<float32_t, float32_t, float32_t>(pred: TheOperation<float32_t, float32_t, float32_t>(1.5)); |
112 | //non-const operator() |
113 | test<int32_t, float32_t, float32_t>(pred: non_const(f: TheOperation<int32_t, float32_t, float32_t>(1.5))); |
114 | test<int64_t, float64_t, float32_t>(pred: non_const(f: TheOperation<int64_t, float64_t, float32_t>(1.5))); |
115 | //lambda |
116 | test<int8_t, float64_t, int8_t>(pred: [](const int8_t& x, const float64_t& y) { return int8_t(int8_t(1.5) + x - y); }); |
117 | |
118 | test_algo_basic_double<int32_t>(f: run_for_rnd_fw<test_non_const<int32_t>>()); |
119 | |
120 | std::cout << done() << std::endl; |
121 | return 0; |
122 | } |
123 | |