author    Vitaly Buka <vitalybuka@google.com>    2024-04-04 14:20:01 -0700
committer Vitaly Buka <vitalybuka@google.com>    2024-04-04 14:20:01 -0700
commit    7a3817eae6589cc4b4c63e7327a3a87e8acf2fb5 (patch)
tree      77164df55d95cf8f7d71aecd34d0f0d0641df57d
parent    03b73cd9c221234dee96ab3cb48b0a38e303644e (diff)
parent    f5960c168dfe17c7599acea0a7d94a26545f4777 (diff)
Created using spr 1.3.4 [skip ci]
-rw-r--r--  clang/lib/Headers/__stddef_unreachable.h  4
-rw-r--r--  compiler-rt/lib/hwasan/hwasan_thread_list.h  3
-rw-r--r--  flang/lib/Semantics/check-declarations.cpp  5
-rw-r--r--  flang/test/Semantics/cuf03.cuf  3
-rw-r--r--  libcxx/include/CMakeLists.txt  2
-rw-r--r--  libcxx/include/__algorithm/comp.h  5
-rw-r--r--  libcxx/include/__algorithm/equal.h  7
-rw-r--r--  libcxx/include/__algorithm/mismatch.h  4
-rw-r--r--  libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h  8
-rw-r--r--  libcxx/include/__functional/operations.h  11
-rw-r--r--  libcxx/include/__functional/ranges_operations.h  5
-rw-r--r--  libcxx/include/__numeric/pstl_transform_reduce.h  2
-rw-r--r--  libcxx/include/__type_traits/desugars_to.h (renamed from libcxx/include/__type_traits/operation_traits.h)  9
-rw-r--r--  libcxx/include/libcxx.imp  2
-rw-r--r--  libcxx/include/module.modulemap  2
-rw-r--r--  llvm/include/llvm/IR/Intrinsics.td  3
-rw-r--r--  llvm/lib/ProfileData/InstrProfWriter.cpp  218
-rw-r--r--  llvm/lib/Target/ARM/Thumb2InstrInfo.cpp  19
-rw-r--r--  llvm/lib/Target/ARM/Thumb2InstrInfo.h  4
-rw-r--r--  llvm/test/CodeGen/ARM/misched-branch-targets.mir  166
-rw-r--r--  mlir/include/mlir/Dialect/Affine/LoopUtils.h  49
-rw-r--r--  mlir/include/mlir/Dialect/SCF/Utils/Utils.h  7
-rw-r--r--  mlir/include/mlir/IR/PatternMatch.h  3
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp  4
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp  48
-rw-r--r--  mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp  4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp  4
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp  353
-rw-r--r--  mlir/lib/IR/PatternMatch.cpp  9
-rw-r--r--  mlir/test/Dialect/Affine/loop-coalescing.mlir  71
-rw-r--r--  mlir/test/Dialect/SCF/transform-op-coalesce.mlir  211
-rw-r--r--  mlir/test/Transforms/parallel-loop-collapsing.mlir  32
-rw-r--r--  mlir/test/Transforms/single-parallel-loop-collapsing.mlir  32
-rw-r--r--  mlir/test/mlir-tblgen/op-properties.td  21
-rw-r--r--  mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp  52
35 files changed, 830 insertions, 552 deletions
diff --git a/clang/lib/Headers/__stddef_unreachable.h b/clang/lib/Headers/__stddef_unreachable.h
index 518580c92d3f..61df43e9732f 100644
--- a/clang/lib/Headers/__stddef_unreachable.h
+++ b/clang/lib/Headers/__stddef_unreachable.h
@@ -7,6 +7,8 @@
*===-----------------------------------------------------------------------===
*/
+#ifndef __cplusplus
+
/*
* When -fbuiltin-headers-in-system-modules is set this is a non-modular header
* and needs to behave as if it was textual.
@@ -15,3 +17,5 @@
(__has_feature(modules) && !__building_module(_Builtin_stddef))
#define unreachable() __builtin_unreachable()
#endif
+
+#endif
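For reference: C++ provides unreachable in the library as std::unreachable() (C++23, <utility>), so a function-like unreachable() macro leaking out of this C header would break the qualified call. A minimal sketch of the conflict the new #ifndef __cplusplus guard avoids (illustrative C++, not part of the commit):

    // If a C header left this visible in C++:
    //   #define unreachable() __builtin_unreachable()
    // the call below would macro-expand to std::__builtin_unreachable()
    // and fail to compile.
    #include <utility>  // declares std::unreachable() in C++23

    [[noreturn]] void bad_path() {
      std::unreachable();  // valid only while no unreachable() macro is in scope
    }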
diff --git a/compiler-rt/lib/hwasan/hwasan_thread_list.h b/compiler-rt/lib/hwasan/hwasan_thread_list.h
index f36d27864fc2..369a1c3d6f5f 100644
--- a/compiler-rt/lib/hwasan/hwasan_thread_list.h
+++ b/compiler-rt/lib/hwasan/hwasan_thread_list.h
@@ -55,6 +55,9 @@ static uptr RingBufferSize() {
uptr desired_bytes = flags()->stack_history_size * sizeof(uptr);
// FIXME: increase the limit to 8 once this bug is fixed:
// https://bugs.llvm.org/show_bug.cgi?id=39030
+ // Note that we *cannot* do that on Android, as the runtime will indefinitely
+ // have to support code that is compiled with ashr, which only works with
+ // shifts up to 6.
for (int shift = 0; shift < 7; ++shift) {
uptr size = 4096 * (1ULL << shift);
if (size >= desired_bytes)
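The loop picks the smallest ring buffer of 4096 * 2^shift bytes that covers the requested stack history, and the shift < 7 bound caps it at 256 KiB for the ABI reason in the new comment. A standalone sketch of the sizing rule (toy code, illustrative name):

    #include <cstdint>

    // Smallest 4096 * 2^shift that holds desired_bytes, with shift capped at 6.
    uint64_t ring_buffer_size(uint64_t desired_bytes) {
      uint64_t size = 4096;
      for (int shift = 0; shift < 7; ++shift) {
        size = 4096 * (1ULL << shift);
        if (size >= desired_bytes)
          break;
      }
      return size;  // 4096 * 64 = 256 KiB when desired_bytes exceeds the cap
    }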
diff --git a/flang/lib/Semantics/check-declarations.cpp b/flang/lib/Semantics/check-declarations.cpp
index dec8fee774c5..b2de37759a06 100644
--- a/flang/lib/Semantics/check-declarations.cpp
+++ b/flang/lib/Semantics/check-declarations.cpp
@@ -948,6 +948,11 @@ void CheckHelper::CheckObjectEntity(
"Component '%s' with ATTRIBUTES(DEVICE) must also be allocatable"_err_en_US,
symbol.name());
}
+ if (IsAssumedSizeArray(symbol)) {
+ messages_.Say(
+ "Object '%s' with ATTRIBUTES(DEVICE) may not be assumed size"_err_en_US,
+ symbol.name());
+ }
break;
case common::CUDADataAttr::Managed:
if (!IsAutomatic(symbol) && !IsAllocatable(symbol) &&
diff --git a/flang/test/Semantics/cuf03.cuf b/flang/test/Semantics/cuf03.cuf
index 41bfbb767813..7384a104831d 100644
--- a/flang/test/Semantics/cuf03.cuf
+++ b/flang/test/Semantics/cuf03.cuf
@@ -51,7 +51,8 @@ module m
contains
attributes(device) subroutine devsubr(n,da)
integer, intent(in) :: n
- real, device :: da(*) ! ok
+ !ERROR: Object 'da' with ATTRIBUTES(DEVICE) may not be assumed size
+ real, device :: da(*)
real, managed :: ma(n) ! ok
!WARNING: Pointer 'dp' may not be associated in a device subprogram
real, device, pointer :: dp
diff --git a/libcxx/include/CMakeLists.txt b/libcxx/include/CMakeLists.txt
index db3980342f50..097a41d4c417 100644
--- a/libcxx/include/CMakeLists.txt
+++ b/libcxx/include/CMakeLists.txt
@@ -738,6 +738,7 @@ set(files
__type_traits/datasizeof.h
__type_traits/decay.h
__type_traits/dependent_type.h
+ __type_traits/desugars_to.h
__type_traits/disjunction.h
__type_traits/enable_if.h
__type_traits/extent.h
@@ -822,7 +823,6 @@ set(files
__type_traits/nat.h
__type_traits/negation.h
__type_traits/noexcept_move_assign_container.h
- __type_traits/operation_traits.h
__type_traits/promote.h
__type_traits/rank.h
__type_traits/remove_all_extents.h
diff --git a/libcxx/include/__algorithm/comp.h b/libcxx/include/__algorithm/comp.h
index 3902f7560304..a089375e3da1 100644
--- a/libcxx/include/__algorithm/comp.h
+++ b/libcxx/include/__algorithm/comp.h
@@ -10,8 +10,7 @@
#define _LIBCPP___ALGORITHM_COMP_H
#include <__config>
-#include <__type_traits/integral_constant.h>
-#include <__type_traits/operation_traits.h>
+#include <__type_traits/desugars_to.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
@@ -27,7 +26,7 @@ struct __equal_to {
};
template <class _Tp, class _Up>
-struct __desugars_to<__equal_tag, __equal_to, _Tp, _Up> : true_type {};
+inline const bool __desugars_to_v<__equal_tag, __equal_to, _Tp, _Up> = true;
// The definition is required because __less is part of the ABI, but it's empty
// because all comparisons should be transparent.
diff --git a/libcxx/include/__algorithm/equal.h b/libcxx/include/__algorithm/equal.h
index c76a16b47f5d..1341d9e4159b 100644
--- a/libcxx/include/__algorithm/equal.h
+++ b/libcxx/include/__algorithm/equal.h
@@ -18,12 +18,11 @@
#include <__iterator/distance.h>
#include <__iterator/iterator_traits.h>
#include <__string/constexpr_c_functions.h>
+#include <__type_traits/desugars_to.h>
#include <__type_traits/enable_if.h>
-#include <__type_traits/integral_constant.h>
#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_equality_comparable.h>
#include <__type_traits/is_volatile.h>
-#include <__type_traits/operation_traits.h>
#include <__utility/move.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -47,7 +46,7 @@ _LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 boo
template <class _Tp,
class _Up,
class _BinaryPredicate,
- __enable_if_t<__desugars_to<__equal_tag, _BinaryPredicate, _Tp, _Up>::value && !is_volatile<_Tp>::value &&
+ __enable_if_t<__desugars_to_v<__equal_tag, _BinaryPredicate, _Tp, _Up> && !is_volatile<_Tp>::value &&
!is_volatile<_Up>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
int> = 0>
_LIBCPP_NODISCARD inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 bool
@@ -87,7 +86,7 @@ template <class _Tp,
class _Pred,
class _Proj1,
class _Proj2,
- __enable_if_t<__desugars_to<__equal_tag, _Pred, _Tp, _Up>::value && __is_identity<_Proj1>::value &&
+ __enable_if_t<__desugars_to_v<__equal_tag, _Pred, _Tp, _Up> && __is_identity<_Proj1>::value &&
__is_identity<_Proj2>::value && !is_volatile<_Tp>::value && !is_volatile<_Up>::value &&
__libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
int> = 0>
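These gates are what let equal fall through to a bytewise comparison: the predicate must be known to desugar to ==, projections must be identities, and the types must be trivially equality-comparable. A sketch of the general pattern (illustrative names, not libc++'s internals):

    #include <cstring>
    #include <functional>
    #include <type_traits>

    template <class Pred, class T>
    inline constexpr bool desugars_to_eq_v = false;  // conservative default
    template <class T>
    inline constexpr bool desugars_to_eq_v<std::equal_to<T>, T> = true;

    template <class T, class Pred>
    bool equal_ranges(const T* f1, const T* l1, const T* f2, Pred pred) {
      if constexpr (desugars_to_eq_v<Pred, T> &&
                    std::has_unique_object_representations_v<T>) {
        // For such types, value equality is bytewise equality.
        return std::memcmp(f1, f2, (l1 - f1) * sizeof(T)) == 0;
      } else {
        for (; f1 != l1; ++f1, ++f2)
          if (!pred(*f1, *f2))
            return false;
        return true;
      }
    }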
diff --git a/libcxx/include/__algorithm/mismatch.h b/libcxx/include/__algorithm/mismatch.h
index 8abb273ac178..4ada29eabc47 100644
--- a/libcxx/include/__algorithm/mismatch.h
+++ b/libcxx/include/__algorithm/mismatch.h
@@ -16,11 +16,11 @@
#include <__algorithm/unwrap_iter.h>
#include <__config>
#include <__functional/identity.h>
+#include <__type_traits/desugars_to.h>
#include <__type_traits/invoke.h>
#include <__type_traits/is_constant_evaluated.h>
#include <__type_traits/is_equality_comparable.h>
#include <__type_traits/is_integral.h>
-#include <__type_traits/operation_traits.h>
#include <__utility/move.h>
#include <__utility/pair.h>
#include <__utility/unreachable.h>
@@ -59,7 +59,7 @@ template <class _Tp,
class _Pred,
class _Proj1,
class _Proj2,
- __enable_if_t<is_integral<_Tp>::value && __desugars_to<__equal_tag, _Pred, _Tp, _Tp>::value &&
+ __enable_if_t<is_integral<_Tp>::value && __desugars_to_v<__equal_tag, _Pred, _Tp, _Tp> &&
__is_identity<_Proj1>::value && __is_identity<_Proj2>::value,
int> = 0>
_LIBCPP_NODISCARD _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20 pair<_Tp*, _Tp*>
diff --git a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
index 14a0d76741d4..376abd39fa36 100644
--- a/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
+++ b/libcxx/include/__algorithm/pstl_backends/cpu_backends/transform_reduce.h
@@ -14,9 +14,9 @@
#include <__iterator/concepts.h>
#include <__iterator/iterator_traits.h>
#include <__numeric/transform_reduce.h>
+#include <__type_traits/desugars_to.h>
#include <__type_traits/is_arithmetic.h>
#include <__type_traits/is_execution_policy.h>
-#include <__type_traits/operation_traits.h>
#include <__utility/move.h>
#include <new>
#include <optional>
@@ -37,7 +37,7 @@ template <typename _DifferenceType,
typename _BinaryOperation,
typename _UnaryOperation,
typename _UnaryResult = invoke_result_t<_UnaryOperation, _DifferenceType>,
- __enable_if_t<__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value && is_arithmetic_v<_Tp> &&
+ __enable_if_t<__desugars_to_v<__plus_tag, _BinaryOperation, _Tp, _UnaryResult> && is_arithmetic_v<_Tp> &&
is_arithmetic_v<_UnaryResult>,
int> = 0>
_LIBCPP_HIDE_FROM_ABI _Tp
@@ -53,8 +53,8 @@ template <typename _Size,
typename _BinaryOperation,
typename _UnaryOperation,
typename _UnaryResult = invoke_result_t<_UnaryOperation, _Size>,
- __enable_if_t<!(__desugars_to<__plus_tag, _BinaryOperation, _Tp, _UnaryResult>::value &&
- is_arithmetic_v<_Tp> && is_arithmetic_v<_UnaryResult>),
+ __enable_if_t<!(__desugars_to_v<__plus_tag, _BinaryOperation, _Tp, _UnaryResult> && is_arithmetic_v<_Tp> &&
+ is_arithmetic_v<_UnaryResult>),
int> = 0>
_LIBCPP_HIDE_FROM_ABI _Tp
__simd_transform_reduce(_Size __n, _Tp __init, _BinaryOperation __binary_op, _UnaryOperation __f) noexcept {
diff --git a/libcxx/include/__functional/operations.h b/libcxx/include/__functional/operations.h
index 7ddc00650f16..9aa28e492506 100644
--- a/libcxx/include/__functional/operations.h
+++ b/libcxx/include/__functional/operations.h
@@ -13,8 +13,7 @@
#include <__config>
#include <__functional/binary_function.h>
#include <__functional/unary_function.h>
-#include <__type_traits/integral_constant.h>
-#include <__type_traits/operation_traits.h>
+#include <__type_traits/desugars_to.h>
#include <__utility/forward.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -41,10 +40,10 @@ _LIBCPP_CTAD_SUPPORTED_FOR_TYPE(plus);
// The non-transparent std::plus specialization is only equivalent to a raw plus
// operator when we don't perform an implicit conversion when calling it.
template <class _Tp>
-struct __desugars_to<__plus_tag, plus<_Tp>, _Tp, _Tp> : true_type {};
+inline const bool __desugars_to_v<__plus_tag, plus<_Tp>, _Tp, _Tp> = true;
template <class _Tp, class _Up>
-struct __desugars_to<__plus_tag, plus<void>, _Tp, _Up> : true_type {};
+inline const bool __desugars_to_v<__plus_tag, plus<void>, _Tp, _Up> = true;
#if _LIBCPP_STD_VER >= 14
template <>
@@ -315,11 +314,11 @@ struct _LIBCPP_TEMPLATE_VIS equal_to<void> {
// The non-transparent std::equal_to specialization is only equivalent to a raw equality
// comparison when we don't perform an implicit conversion when calling it.
template <class _Tp>
-struct __desugars_to<__equal_tag, equal_to<_Tp>, _Tp, _Tp> : true_type {};
+inline const bool __desugars_to_v<__equal_tag, equal_to<_Tp>, _Tp, _Tp> = true;
// In the transparent case, we do not enforce that
template <class _Tp, class _Up>
-struct __desugars_to<__equal_tag, equal_to<void>, _Tp, _Up> : true_type {};
+inline const bool __desugars_to_v<__equal_tag, equal_to<void>, _Tp, _Up> = true;
#if _LIBCPP_STD_VER >= 14
template <class _Tp = void>
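The asymmetry above is deliberate: plus<_Tp> takes both operands as _Tp, so a mixed-type call converts an argument before operator+ runs and no longer desugars to a raw plus on the original types, while the transparent plus<void> forwards operands unchanged, so any type pair qualifies. Concretely (illustrative snippet):

    #include <functional>

    int main() {
      std::plus<int> fixed;
      std::plus<>    transparent;
      int    a = fixed(1, 2.9);        // 2.9 converts to int 2 first: a == 3
      double b = transparent(1, 2.9);  // operands forwarded as-is: b == 3.9
      return (a == 3 && b > 3.8) ? 0 : 1;
    }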
diff --git a/libcxx/include/__functional/ranges_operations.h b/libcxx/include/__functional/ranges_operations.h
index 38b28018049e..a9dffaf69625 100644
--- a/libcxx/include/__functional/ranges_operations.h
+++ b/libcxx/include/__functional/ranges_operations.h
@@ -13,8 +13,7 @@
#include <__concepts/equality_comparable.h>
#include <__concepts/totally_ordered.h>
#include <__config>
-#include <__type_traits/integral_constant.h>
-#include <__type_traits/operation_traits.h>
+#include <__type_traits/desugars_to.h>
#include <__utility/forward.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -98,7 +97,7 @@ struct greater_equal {
// For ranges we do not require that the types on each side of the equality
// operator are of the same type
template <class _Tp, class _Up>
-struct __desugars_to<__equal_tag, ranges::equal_to, _Tp, _Up> : true_type {};
+inline const bool __desugars_to_v<__equal_tag, ranges::equal_to, _Tp, _Up> = true;
#endif // _LIBCPP_STD_VER >= 20
diff --git a/libcxx/include/__numeric/pstl_transform_reduce.h b/libcxx/include/__numeric/pstl_transform_reduce.h
index 2f412d41f7f2..07ecf0d9956b 100644
--- a/libcxx/include/__numeric/pstl_transform_reduce.h
+++ b/libcxx/include/__numeric/pstl_transform_reduce.h
@@ -87,7 +87,7 @@ _LIBCPP_HIDE_FROM_ABI _Tp transform_reduce(
}
// This overload doesn't get a customization point because it's trivial to detect (through e.g.
-// __desugars_to) when specializing the more general variant, which should always be preferred
+// __desugars_to_v) when specializing the more general variant, which should always be preferred
template <class _ExecutionPolicy,
class _ForwardIterator1,
class _ForwardIterator2,
diff --git a/libcxx/include/__type_traits/operation_traits.h b/libcxx/include/__type_traits/desugars_to.h
index ef6e71693430..a8f69c28dfc5 100644
--- a/libcxx/include/__type_traits/operation_traits.h
+++ b/libcxx/include/__type_traits/desugars_to.h
@@ -6,11 +6,10 @@
//
//===----------------------------------------------------------------------===//
-#ifndef _LIBCPP___TYPE_TRAITS_OPERATION_TRAITS_H
-#define _LIBCPP___TYPE_TRAITS_OPERATION_TRAITS_H
+#ifndef _LIBCPP___TYPE_TRAITS_DESUGARS_TO_H
+#define _LIBCPP___TYPE_TRAITS_DESUGARS_TO_H
#include <__config>
-#include <__type_traits/integral_constant.h>
#if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
# pragma GCC system_header
@@ -33,8 +32,8 @@ struct __plus_tag {};
// predicate being passed is actually going to call a builtin operator, or has
// some specific semantics.
template <class _CanonicalTag, class _Operation, class... _Args>
-struct __desugars_to : false_type {};
+inline const bool __desugars_to_v = false;
_LIBCPP_END_NAMESPACE_STD
-#endif // _LIBCPP___TYPE_TRAITS_OPERATION_TRAITS_H
+#endif // _LIBCPP___TYPE_TRAITS_DESUGARS_TO_H
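Besides the rename, the trait changed shape: the old class template was opted into by specializing to true_type, while the new variable template is opted into by specializing to true, which is terser at use sites (no ::value) and drops the integral_constant dependency. A minimal sketch of the pattern with illustrative names:

    struct equal_tag {};
    struct my_eq { bool operator()(int a, int b) const { return a == b; } };

    // Primary: conservatively false for every (tag, operation, args...) triple.
    template <class Tag, class Op, class... Args>
    inline const bool desugars_to_v = false;

    // Opt-in: my_eq on (int, int) is exactly a builtin ==.
    template <>
    inline const bool desugars_to_v<equal_tag, my_eq, int, int> = true;

    static_assert(desugars_to_v<equal_tag, my_eq, int, int>, "opted in");
    static_assert(!desugars_to_v<equal_tag, my_eq, long, long>, "default");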
diff --git a/libcxx/include/libcxx.imp b/libcxx/include/libcxx.imp
index 2cb1fa5e1e2a..607f63e6d822 100644
--- a/libcxx/include/libcxx.imp
+++ b/libcxx/include/libcxx.imp
@@ -734,6 +734,7 @@
{ include: [ "<__type_traits/datasizeof.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/decay.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/dependent_type.h>", "private", "<type_traits>", "public" ] },
+ { include: [ "<__type_traits/desugars_to.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/disjunction.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/enable_if.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/extent.h>", "private", "<type_traits>", "public" ] },
@@ -818,7 +819,6 @@
{ include: [ "<__type_traits/nat.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/negation.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/noexcept_move_assign_container.h>", "private", "<type_traits>", "public" ] },
- { include: [ "<__type_traits/operation_traits.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/promote.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/rank.h>", "private", "<type_traits>", "public" ] },
{ include: [ "<__type_traits/remove_all_extents.h>", "private", "<type_traits>", "public" ] },
diff --git a/libcxx/include/module.modulemap b/libcxx/include/module.modulemap
index 6d4dcc2511f3..ed45a1b18338 100644
--- a/libcxx/include/module.modulemap
+++ b/libcxx/include/module.modulemap
@@ -1867,6 +1867,7 @@ module std_private_type_traits_decay [system
export std_private_type_traits_add_pointer
}
module std_private_type_traits_dependent_type [system] { header "__type_traits/dependent_type.h" }
+module std_private_type_traits_desugars_to [system] { header "__type_traits/desugars_to.h" }
module std_private_type_traits_disjunction [system] { header "__type_traits/disjunction.h" }
module std_private_type_traits_enable_if [system] { header "__type_traits/enable_if.h" }
module std_private_type_traits_extent [system] { header "__type_traits/extent.h" }
@@ -2017,7 +2018,6 @@ module std_private_type_traits_maybe_const [system
module std_private_type_traits_nat [system] { header "__type_traits/nat.h" }
module std_private_type_traits_negation [system] { header "__type_traits/negation.h" }
module std_private_type_traits_noexcept_move_assign_container [system] { header "__type_traits/noexcept_move_assign_container.h" }
-module std_private_type_traits_operation_traits [system] { header "__type_traits/operation_traits.h" }
module std_private_type_traits_promote [system] { header "__type_traits/promote.h" }
module std_private_type_traits_rank [system] { header "__type_traits/rank.h" }
module std_private_type_traits_remove_all_extents [system] { header "__type_traits/remove_all_extents.h" }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index c04f4c526921..f0723a633f0f 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1733,8 +1733,7 @@ def int_ubsantrap : Intrinsic<[], [llvm_i8_ty],
// Return true if ubsan check is allowed.
def int_allow_ubsan_check : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i8_ty],
- [IntrInaccessibleMemOnly, IntrWriteMem, ImmArg<ArgIndex<0>>, NoUndef<RetIndex>]>,
- ClangBuiltin<"__builtin_allow_ubsan_check">;
+ [IntrInaccessibleMemOnly, IntrWriteMem, ImmArg<ArgIndex<0>>, NoUndef<RetIndex>]>;
// Return true if runtime check is allowed.
def int_allow_runtime_check : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_metadata_ty],
diff --git a/llvm/lib/ProfileData/InstrProfWriter.cpp b/llvm/lib/ProfileData/InstrProfWriter.cpp
index 96ab729f91e4..72d77d5580a5 100644
--- a/llvm/lib/ProfileData/InstrProfWriter.cpp
+++ b/llvm/lib/ProfileData/InstrProfWriter.cpp
@@ -414,6 +414,144 @@ static void setSummary(IndexedInstrProf::Summary *TheSummary,
TheSummary->setEntry(I, Res[I]);
}
+// Serialize Schema.
+static void writeMemProfSchema(ProfOStream &OS,
+ const memprof::MemProfSchema &Schema) {
+ OS.write(static_cast<uint64_t>(Schema.size()));
+ for (const auto Id : Schema)
+ OS.write(static_cast<uint64_t>(Id));
+}
+
+// Serialize MemProfRecordData. Return RecordTableOffset.
+static uint64_t writeMemProfRecords(
+ ProfOStream &OS,
+ llvm::MapVector<GlobalValue::GUID, memprof::IndexedMemProfRecord>
+ &MemProfRecordData,
+ memprof::MemProfSchema *Schema) {
+ auto RecordWriter =
+ std::make_unique<memprof::RecordWriterTrait>(memprof::Version1);
+ RecordWriter->Schema = Schema;
+ OnDiskChainedHashTableGenerator<memprof::RecordWriterTrait>
+ RecordTableGenerator;
+ for (auto &I : MemProfRecordData) {
+ // Insert the key (func hash) and value (memprof record).
+ RecordTableGenerator.insert(I.first, I.second, *RecordWriter.get());
+ }
+ // Release the memory of this MapVector as it is no longer needed.
+ MemProfRecordData.clear();
+
+ // The call to Emit invokes RecordWriterTrait::EmitData which destructs
+ // the memprof record copies owned by the RecordTableGenerator. This works
+ // because the RecordTableGenerator is not used after this point.
+ return RecordTableGenerator.Emit(OS.OS, *RecordWriter);
+}
+
+// Serialize MemProfFrameData. Return FrameTableOffset.
+static uint64_t writeMemProfFrames(
+ ProfOStream &OS,
+ llvm::MapVector<memprof::FrameId, memprof::Frame> &MemProfFrameData) {
+ auto FrameWriter = std::make_unique<memprof::FrameWriterTrait>();
+ OnDiskChainedHashTableGenerator<memprof::FrameWriterTrait>
+ FrameTableGenerator;
+ for (auto &I : MemProfFrameData) {
+ // Insert the key (frame id) and value (frame contents).
+ FrameTableGenerator.insert(I.first, I.second);
+ }
+ // Release the memory of this MapVector as it is no longer needed.
+ MemProfFrameData.clear();
+
+ return FrameTableGenerator.Emit(OS.OS, *FrameWriter);
+}
+
+static Error writeMemProfV0(
+ ProfOStream &OS,
+ llvm::MapVector<GlobalValue::GUID, memprof::IndexedMemProfRecord>
+ &MemProfRecordData,
+ llvm::MapVector<memprof::FrameId, memprof::Frame> &MemProfFrameData) {
+ uint64_t HeaderUpdatePos = OS.tell();
+ OS.write(0ULL); // Reserve space for the memprof record table offset.
+ OS.write(0ULL); // Reserve space for the memprof frame payload offset.
+ OS.write(0ULL); // Reserve space for the memprof frame table offset.
+
+ auto Schema = memprof::PortableMemInfoBlock::getSchema();
+ writeMemProfSchema(OS, Schema);
+
+ uint64_t RecordTableOffset =
+ writeMemProfRecords(OS, MemProfRecordData, &Schema);
+
+ uint64_t FramePayloadOffset = OS.tell();
+ uint64_t FrameTableOffset = writeMemProfFrames(OS, MemProfFrameData);
+
+ uint64_t Header[] = {RecordTableOffset, FramePayloadOffset, FrameTableOffset};
+ OS.patch({{HeaderUpdatePos, Header, std::size(Header)}});
+
+ return Error::success();
+}
+
+static Error writeMemProfV1(
+ ProfOStream &OS,
+ llvm::MapVector<GlobalValue::GUID, memprof::IndexedMemProfRecord>
+ &MemProfRecordData,
+ llvm::MapVector<memprof::FrameId, memprof::Frame> &MemProfFrameData) {
+ OS.write(memprof::Version0);
+ uint64_t HeaderUpdatePos = OS.tell();
+ OS.write(0ULL); // Reserve space for the memprof record table offset.
+ OS.write(0ULL); // Reserve space for the memprof frame payload offset.
+ OS.write(0ULL); // Reserve space for the memprof frame table offset.
+
+ auto Schema = memprof::PortableMemInfoBlock::getSchema();
+ writeMemProfSchema(OS, Schema);
+
+ uint64_t RecordTableOffset =
+ writeMemProfRecords(OS, MemProfRecordData, &Schema);
+
+ uint64_t FramePayloadOffset = OS.tell();
+ uint64_t FrameTableOffset = writeMemProfFrames(OS, MemProfFrameData);
+
+ uint64_t Header[] = {RecordTableOffset, FramePayloadOffset, FrameTableOffset};
+ OS.patch({{HeaderUpdatePos, Header, std::size(Header)}});
+
+ return Error::success();
+}
+
+// The MemProf profile data includes a simple schema
+// with the format described below followed by the hashtable:
+// uint64_t Version
+// uint64_t RecordTableOffset = RecordTableGenerator.Emit
+// uint64_t FramePayloadOffset = Stream offset before emitting the frame table
+// uint64_t FrameTableOffset = FrameTableGenerator.Emit
+// uint64_t Num schema entries
+// uint64_t Schema entry 0
+// uint64_t Schema entry 1
+// ....
+// uint64_t Schema entry N - 1
+// OnDiskChainedHashTable MemProfRecordData
+// OnDiskChainedHashTable MemProfFrameData
+static Error writeMemProf(
+ ProfOStream &OS,
+ llvm::MapVector<GlobalValue::GUID, memprof::IndexedMemProfRecord>
+ &MemProfRecordData,
+ llvm::MapVector<memprof::FrameId, memprof::Frame> &MemProfFrameData,
+ memprof::IndexedVersion MemProfVersionRequested) {
+
+ switch (MemProfVersionRequested) {
+ case memprof::Version0:
+ return writeMemProfV0(OS, MemProfRecordData, MemProfFrameData);
+ case memprof::Version1:
+ return writeMemProfV1(OS, MemProfRecordData, MemProfFrameData);
+ case memprof::Version2:
+ // TODO: Implement. Fall through to the error handling below for now.
+ break;
+ }
+
+ return make_error<InstrProfError>(
+ instrprof_error::unsupported_version,
+ formatv("MemProf version {} not supported; "
+ "requires version between {} and {}, inclusive",
+ MemProfVersionRequested, memprof::MinimumSupportedVersion,
+ memprof::MaximumSupportedVersion));
+}
+
Error InstrProfWriter::writeImpl(ProfOStream &OS) {
using namespace IndexedInstrProf;
using namespace support;
@@ -517,85 +655,13 @@ Error InstrProfWriter::writeImpl(ProfOStream &OS) {
// Write the hash table.
uint64_t HashTableStart = Generator.Emit(OS.OS, *InfoObj);
- // Write the MemProf profile data if we have it. This includes a simple schema
- // with the format described below followed by the hashtable:
- // uint64_t Version
- // uint64_t RecordTableOffset = RecordTableGenerator.Emit
- // uint64_t FramePayloadOffset = Stream offset before emitting the frame table
- // uint64_t FrameTableOffset = FrameTableGenerator.Emit
- // uint64_t Num schema entries
- // uint64_t Schema entry 0
- // uint64_t Schema entry 1
- // ....
- // uint64_t Schema entry N - 1
- // OnDiskChainedHashTable MemProfRecordData
- // OnDiskChainedHashTable MemProfFrameData
+ // Write the MemProf profile data if we have it.
uint64_t MemProfSectionStart = 0;
if (static_cast<bool>(ProfileKind & InstrProfKind::MemProf)) {
- if (MemProfVersionRequested < memprof::MinimumSupportedVersion ||
- MemProfVersionRequested > memprof::MaximumSupportedVersion) {
- return make_error<InstrProfError>(
- instrprof_error::unsupported_version,
- formatv("MemProf version {} not supported; "
- "requires version between {} and {}, inclusive",
- MemProfVersionRequested, memprof::MinimumSupportedVersion,
- memprof::MaximumSupportedVersion));
- }
-
MemProfSectionStart = OS.tell();
-
- if (MemProfVersionRequested >= memprof::Version1)
- OS.write(MemProfVersionRequested);
-
- OS.write(0ULL); // Reserve space for the memprof record table offset.
- OS.write(0ULL); // Reserve space for the memprof frame payload offset.
- OS.write(0ULL); // Reserve space for the memprof frame table offset.
-
- auto Schema = memprof::PortableMemInfoBlock::getSchema();
- OS.write(static_cast<uint64_t>(Schema.size()));
- for (const auto Id : Schema) {
- OS.write(static_cast<uint64_t>(Id));
- }
-
- auto RecordWriter =
- std::make_unique<memprof::RecordWriterTrait>(memprof::Version1);
- RecordWriter->Schema = &Schema;
- OnDiskChainedHashTableGenerator<memprof::RecordWriterTrait>
- RecordTableGenerator;
- for (auto &I : MemProfRecordData) {
- // Insert the key (func hash) and value (memprof record).
- RecordTableGenerator.insert(I.first, I.second, *RecordWriter.get());
- }
- // Release the memory of this MapVector as it is no longer needed.
- MemProfRecordData.clear();
-
- // The call to Emit invokes RecordWriterTrait::EmitData which destructs
- // the memprof record copies owned by the RecordTableGenerator. This works
- // because the RecordTableGenerator is not used after this point.
- uint64_t RecordTableOffset =
- RecordTableGenerator.Emit(OS.OS, *RecordWriter);
-
- uint64_t FramePayloadOffset = OS.tell();
-
- auto FrameWriter = std::make_unique<memprof::FrameWriterTrait>();
- OnDiskChainedHashTableGenerator<memprof::FrameWriterTrait>
- FrameTableGenerator;
- for (auto &I : MemProfFrameData) {
- // Insert the key (frame id) and value (frame contents).
- FrameTableGenerator.insert(I.first, I.second);
- }
- // Release the memory of this MapVector as it is no longer needed.
- MemProfFrameData.clear();
-
- uint64_t FrameTableOffset = FrameTableGenerator.Emit(OS.OS, *FrameWriter);
-
- uint64_t Header[] = {RecordTableOffset, FramePayloadOffset,
- FrameTableOffset};
- uint64_t HeaderUpdatePos = MemProfSectionStart;
- if (MemProfVersionRequested >= memprof::Version1)
- // The updates go just after the version field.
- HeaderUpdatePos += sizeof(uint64_t);
- OS.patch({{HeaderUpdatePos, Header, std::size(Header)}});
+ if (auto E = writeMemProf(OS, MemProfRecordData, MemProfFrameData,
+ MemProfVersionRequested))
+ return E;
}
// BinaryIdSection has two parts:
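Both writeMemProfV0 and writeMemProfV1 rely on the same reserve-then-patch idiom: write zero placeholders for offsets that are unknown until the payload is emitted, remember the stream position, and back-fill the real values afterwards. A self-contained toy of the idiom (not LLVM's ProfOStream):

    #include <cstdint>
    #include <vector>

    struct Buf {
      std::vector<uint8_t> bytes;
      uint64_t tell() const { return bytes.size(); }
      void write(uint64_t v) {  // append a little-endian u64
        for (int i = 0; i < 8; ++i) bytes.push_back(uint8_t(v >> (8 * i)));
      }
      void patch(uint64_t pos, uint64_t v) {  // overwrite a u64 in place
        for (int i = 0; i < 8; ++i) bytes[pos + i] = uint8_t(v >> (8 * i));
      }
    };

    int main() {
      Buf os;
      os.write(1);                       // version
      uint64_t header_pos = os.tell();
      os.write(0);                       // reserve: record table offset
      os.write(0);                       // reserve: frame payload offset
      uint64_t record_off = os.tell();
      os.write(0xabcd);                  // ...emit the record table...
      uint64_t frame_off = os.tell();
      os.write(0xef01);                  // ...emit the frame payload...
      os.patch(header_pos, record_off);  // back-fill the reserved slots
      os.patch(header_pos + 8, frame_off);
      return 0;
    }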
diff --git a/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp b/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp
index fc2834cb0b45..083f25f49dec 100644
--- a/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp
+++ b/llvm/lib/Target/ARM/Thumb2InstrInfo.cpp
@@ -286,25 +286,6 @@ MachineInstr *Thumb2InstrInfo::commuteInstructionImpl(MachineInstr &MI,
return ARMBaseInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2);
}
-bool Thumb2InstrInfo::isSchedulingBoundary(const MachineInstr &MI,
- const MachineBasicBlock *MBB,
- const MachineFunction &MF) const {
- // BTI clearing instructions shall not take part in scheduling regions as
- // they must stay in their intended place. Although PAC isn't BTI clearing,
- // it can be transformed into PACBTI after the pre-RA Machine Scheduling
- // has taken place, so its movement must also be restricted.
- switch (MI.getOpcode()) {
- case ARM::t2BTI:
- case ARM::t2PAC:
- case ARM::t2PACBTI:
- case ARM::t2SG:
- return true;
- default:
- break;
- }
- return ARMBaseInstrInfo::isSchedulingBoundary(MI, MBB, MF);
-}
-
void llvm::emitT2RegPlusImmediate(MachineBasicBlock &MBB,
MachineBasicBlock::iterator &MBBI,
const DebugLoc &dl, Register DestReg,
diff --git a/llvm/lib/Target/ARM/Thumb2InstrInfo.h b/llvm/lib/Target/ARM/Thumb2InstrInfo.h
index 8915da8c5bf3..4bb412f09dcb 100644
--- a/llvm/lib/Target/ARM/Thumb2InstrInfo.h
+++ b/llvm/lib/Target/ARM/Thumb2InstrInfo.h
@@ -68,10 +68,6 @@ public:
unsigned OpIdx1,
unsigned OpIdx2) const override;
- bool isSchedulingBoundary(const MachineInstr &MI,
- const MachineBasicBlock *MBB,
- const MachineFunction &MF) const override;
-
private:
void expandLoadStackGuard(MachineBasicBlock::iterator MI) const override;
};
diff --git a/llvm/test/CodeGen/ARM/misched-branch-targets.mir b/llvm/test/CodeGen/ARM/misched-branch-targets.mir
deleted file mode 100644
index b071fbd4538a..000000000000
--- a/llvm/test/CodeGen/ARM/misched-branch-targets.mir
+++ /dev/null
@@ -1,166 +0,0 @@
-# RUN: llc -o - -run-pass=machine-scheduler -misched=shuffle %s | FileCheck %s
-# RUN: llc -o - -run-pass=postmisched %s | FileCheck %s
-
---- |
- target datalayout = "e-m:e-p:32:32-Fi8-i64:64-v128:64:128-a:0:32-n32-S64"
- target triple = "thumbv8.1m.main-arm-none-eabi"
-
- define i32 @foo_bti() #0 {
- entry:
- ret i32 0
- }
-
- define i32 @foo_pac() #0 {
- entry:
- ret i32 0
- }
-
- define i32 @foo_pacbti() #0 {
- entry:
- ret i32 0
- }
-
- define i32 @foo_setjmp() #0 {
- entry:
- ret i32 0
- if.then:
- ret i32 0
- }
-
- define i32 @foo_sg() #0 {
- entry:
- ret i32 0
- }
-
- declare i32 @setjmp(ptr noundef) #1
- declare void @longjmp(ptr noundef, i32 noundef) #2
-
- attributes #0 = { "frame-pointer"="all" "target-cpu"="cortex-m55" "target-features"="+armv8.1-m.main" }
- attributes #1 = { nounwind returns_twice "frame-pointer"="all" "target-cpu"="cortex-m55" "target-features"="+armv8.1-m.main" }
- attributes #2 = { noreturn nounwind "frame-pointer"="all" "target-cpu"="cortex-m55" "target-features"="+armv8.1-m.main" }
-
-...
----
-name: foo_bti
-tracksRegLiveness: true
-body: |
- bb.0.entry:
- liveins: $r0
-
- t2BTI
- renamable $r0, dead $cpsr = nsw tADDi8 killed renamable $r0, 1, 14 /* CC::al */, $noreg
- tBX_RET 14 /* CC::al */, $noreg, implicit killed $r0
-
-...
-
-# CHECK-LABEL: name: foo_bti
-# CHECK: body:
-# CHECK-NEXT: bb.0.entry:
-# CHECK-NEXT: liveins: $r0
-# CHECK-NEXT: {{^ +$}}
-# CHECK-NEXT: t2BTI
-
----
-name: foo_pac
-tracksRegLiveness: true
-body: |
- bb.0.entry:
- liveins: $r0, $lr, $r12
-
- frame-setup t2PAC implicit-def $r12, implicit $lr, implicit $sp
- renamable $r2 = nsw t2ADDri $r0, 3, 14 /* CC::al */, $noreg, $noreg
- $sp = frame-setup t2STMDB_UPD $sp, 14 /* CC::al */, $noreg, killed $r7, killed $lr
- $r7 = frame-setup tMOVr killed $sp, 14 /* CC::al */, $noreg
- early-clobber $sp = frame-setup t2STR_PRE killed $r12, $sp, -4, 14 /* CC::al */, $noreg
- $r12, $sp = frame-destroy t2LDR_POST $sp, 4, 14 /* CC::al */, $noreg
- $sp = frame-destroy t2LDMIA_UPD $sp, 14 /* CC::al */, $noreg, def $r7, def $lr
- t2AUT implicit $r12, implicit $lr, implicit $sp
- tBX_RET 14 /* CC::al */, $noreg, implicit $r0
-
-...
-
-# CHECK-LABEL: name: foo_pac
-# CHECK: body:
-# CHECK-NEXT: bb.0.entry:
-# CHECK-NEXT: liveins: $r0, $lr, $r12
-# CHECK-NEXT: {{^ +$}}
-# CHECK-NEXT: frame-setup t2PAC implicit-def $r12, implicit $lr, implicit $sp
-
----
-name: foo_pacbti
-tracksRegLiveness: true
-body: |
- bb.0.entry:
- liveins: $r0, $lr, $r12
-
- frame-setup t2PACBTI implicit-def $r12, implicit $lr, implicit $sp
- renamable $r2 = nsw t2ADDri $r0, 3, 14 /* CC::al */, $noreg, $noreg
- $sp = frame-setup t2STMDB_UPD $sp, 14 /* CC::al */, $noreg, killed $r7, killed $lr
- $r7 = frame-setup tMOVr killed $sp, 14 /* CC::al */, $noreg
- early-clobber $sp = frame-setup t2STR_PRE killed $r12, $sp, -4, 14 /* CC::al */, $noreg
- $r12, $sp = frame-destroy t2LDR_POST $sp, 4, 14 /* CC::al */, $noreg
- $sp = frame-destroy t2LDMIA_UPD $sp, 14 /* CC::al */, $noreg, def $r7, def $lr
- t2AUT implicit $r12, implicit $lr, implicit $sp
- tBX_RET 14 /* CC::al */, $noreg, implicit $r0
-
-...
-
-# CHECK-LABEL: name: foo_pacbti
-# CHECK: body:
-# CHECK-NEXT: bb.0.entry:
-# CHECK-NEXT: liveins: $r0, $lr, $r12
-# CHECK-NEXT: {{^ +$}}
-# CHECK-NEXT: frame-setup t2PACBTI implicit-def $r12, implicit $lr, implicit $sp
-
----
-name: foo_setjmp
-tracksRegLiveness: true
-body: |
- bb.0.entry:
- successors: %bb.1
- liveins: $lr
-
- frame-setup tPUSH 14 /* CC::al */, $noreg, $r7, killed $lr, implicit-def $sp, implicit $sp
- $r7 = frame-setup tMOVr $sp, 14 /* CC::al */, $noreg
- $sp = frame-setup tSUBspi $sp, 40, 14 /* CC::al */, $noreg
- renamable $r0 = tMOVr $sp, 14 /* CC::al */, $noreg
- tBL 14 /* CC::al */, $noreg, @setjmp, csr_aapcs, implicit-def dead $lr, implicit $sp, implicit killed $r0, implicit-def $sp, implicit-def $r0
- t2BTI
- renamable $r2 = nsw t2ADDri $r0, 3, 14 /* CC::al */, $noreg, $noreg
- tCMPi8 killed renamable $r0, 0, 14 /* CC::al */, $noreg, implicit-def $cpsr
- t2IT 0, 2, implicit-def $itstate
- renamable $r0 = tMOVi8 $noreg, 0, 0 /* CC::eq */, $cpsr, implicit $itstate
- $sp = frame-destroy tADDspi $sp, 40, 0 /* CC::eq */, $cpsr, implicit $itstate
- frame-destroy tPOP_RET 0 /* CC::eq */, killed $cpsr, def $r7, def $pc, implicit killed $r0, implicit $sp, implicit killed $itstate
-
- bb.1.if.then:
- renamable $r0 = tMOVr $sp, 14 /* CC::al */, $noreg
- renamable $r1, dead $cpsr = tMOVi8 1, 14 /* CC::al */, $noreg
- tBL 14 /* CC::al */, $noreg, @longjmp, csr_aapcs, implicit-def dead $lr, implicit $sp, implicit killed $r0, implicit killed $r1, implicit-def $sp
-
-...
-
-# CHECK-LABEL: name: foo_setjmp
-# CHECK: body:
-# CHECK: tBL 14 /* CC::al */, $noreg, @setjmp, csr_aapcs, implicit-def dead $lr, implicit $sp, implicit killed $r0, implicit-def $sp, implicit-def $r0
-# CHECK-NEXT: t2BTI
-
----
-name: foo_sg
-tracksRegLiveness: true
-body: |
- bb.0.entry:
- liveins: $r0
-
- t2SG 14 /* CC::al */, $noreg
- renamable $r0, dead $cpsr = nsw tADDi8 killed renamable $r0, 1, 14 /* CC::al */, $noreg
- tBX_RET 14 /* CC::al */, $noreg, implicit killed $r0
-
-...
-
-# CHECK-LABEL: name: foo_sg
-# CHECK: body:
-# CHECK-NEXT: bb.0.entry:
-# CHECK-NEXT: liveins: $r0
-# CHECK-NEXT: {{^ +$}}
-# CHECK-NEXT: t2SG
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24ac..d143954b78fc 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -299,53 +299,8 @@ LogicalResult
separateFullTiles(MutableArrayRef<AffineForOp> nest,
SmallVectorImpl<AffineForOp> *fullTileNest = nullptr);
-/// Walk either an scf.for or an affine.for to find a band to coalesce.
-template <typename LoopOpTy>
-LogicalResult coalescePerfectlyNestedLoops(LoopOpTy op) {
- LogicalResult result(failure());
- SmallVector<LoopOpTy> loops;
- getPerfectlyNestedLoops(loops, op);
-
- // Look for a band of loops that can be coalesced, i.e. perfectly nested
- // loops with bounds defined above some loop.
- // 1. For each loop, find above which parent loop its operands are
- // defined.
- SmallVector<unsigned, 4> operandsDefinedAbove(loops.size());
- for (unsigned i = 0, e = loops.size(); i < e; ++i) {
- operandsDefinedAbove[i] = i;
- for (unsigned j = 0; j < i; ++j) {
- if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
- operandsDefinedAbove[i] = j;
- break;
- }
- }
- }
-
- // 2. Identify bands of loops such that the operands of all of them are
- // defined above the first loop in the band. Traverse the nest bottom-up
- // so that modifications don't invalidate the inner loops.
- for (unsigned end = loops.size(); end > 0; --end) {
- unsigned start = 0;
- for (; start < end - 1; ++start) {
- auto maxPos =
- *std::max_element(std::next(operandsDefinedAbove.begin(), start),
- std::next(operandsDefinedAbove.begin(), end));
- if (maxPos > start)
- continue;
- assert(maxPos == start &&
- "expected loop bounds to be known at the start of the band");
- auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
- if (succeeded(coalesceLoops(band)))
- result = success();
- break;
- }
- // If a band was found and transformed, keep looking at the loops above
- // the outermost transformed loop.
- if (start != end - 1)
- end = start + 1;
- }
- return result;
-}
+/// Walk an affine.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op);
} // namespace affine
} // namespace mlir
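The band search itself is unchanged by this split; its core rule is that a band [start, end) of perfectly nested loops is coalescible when every loop in it has bounds defined at or above the band's outermost loop, i.e. max(operandsDefinedAbove[start..end)) <= start. A toy run of that scan over plain indices (no IR):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
      // operandsDefinedAbove[i]: outermost loop above which loop i's bounds are
      // defined (i itself when they depend on an enclosing induction variable).
      std::vector<unsigned> above = {0, 0, 2, 2};  // two coalescible bands
      for (unsigned end = above.size(); end > 0; --end) {
        unsigned start = 0;
        for (; start < end - 1; ++start) {
          unsigned maxPos = *std::max_element(above.begin() + start,
                                              above.begin() + end);
          if (maxPos > start)
            continue;
          std::printf("coalescible band [%u, %u)\n", start, end);  // [2,4), [0,2)
          break;
        }
        if (start != end - 1)
          end = start + 1;  // resume above the outermost transformed loop
      }
      return 0;
    }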
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index 883d11bcc4df..bc09cc7f7fa5 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -100,11 +100,16 @@ getSCFMinMaxExpr(Value value, SmallVectorImpl<Value> &dims,
/// `loops` contains a list of perfectly nested loops with bounds and steps
/// independent of any loop induction variable involved in the nest.
LogicalResult coalesceLoops(MutableArrayRef<scf::ForOp> loops);
+LogicalResult coalesceLoops(RewriterBase &rewriter,
+ MutableArrayRef<scf::ForOp>);
+
+/// Walk an scf.for to find a band to coalesce.
+LogicalResult coalescePerfectlyNestedSCFForLoops(scf::ForOp op);
/// Take the ParallelLoop and for each set of dimension indices, combine them
/// into a single dimension. combinedDimensions must contain each index into
/// loops exactly once.
-void collapseParallelLoops(scf::ParallelOp loops,
+void collapseParallelLoops(RewriterBase &rewriter, scf::ParallelOp loops,
ArrayRef<std::vector<unsigned>> combinedDimensions);
/// Unrolls this for operation by the specified unroll factor. Returns failure
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 15b1c3892948..2562301e499d 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -15,6 +15,7 @@
#include "llvm/Support/TypeName.h"
#include <optional>
+using llvm::SmallPtrSetImpl;
namespace mlir {
class PatternRewriter;
@@ -704,6 +705,8 @@ public:
return user != exceptedUser;
});
}
+ void replaceAllUsesExcept(Value from, Value to,
+ const SmallPtrSetImpl<Operation *> &preservedUsers);
/// Used to notify the listener that the IR failed to be rewritten because of
/// a match failure, and provide a callback to populate a diagnostic with the
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 1dc69ab493d4..05c77070a70c 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -39,9 +39,9 @@ struct LoopCoalescingPass
func::FuncOp func = getOperation();
func.walk<WalkOrder::PreOrder>([](Operation *op) {
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
- (void)coalescePerfectlyNestedLoops(scfForOp);
+ (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
- (void)coalescePerfectlyNestedLoops(affineForOp);
+ (void)coalescePerfectlyNestedAffineLoops(affineForOp);
});
}
};
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index af59973d7a92..268050a30e00 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2765,3 +2765,51 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
return success();
}
+
+LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
+ LogicalResult result(failure());
+ SmallVector<AffineForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+ if (loops.size() <= 1)
+ return success();
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+ // 1. For each loop, find above which parent loop its operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
+ for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
+ }
+
+ // 2. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ assert(maxPos == start &&
+ "expected loop bounds to be known at the start of the band");
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
+}
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c09184148208..7e4faf8b73af 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
- result = coalescePerfectlyNestedLoops(scfForOp);
+ result = coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
- result = coalescePerfectlyNestedLoops(affineForOp);
+ result = coalescePerfectlyNestedAffineLoops(affineForOp);
results.push_back(op);
if (failed(result)) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index a69df025bcba..6ba7020e86fa 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -28,6 +28,7 @@ namespace {
struct TestSCFParallelLoopCollapsing
: public impl::TestSCFParallelLoopCollapsingBase<
TestSCFParallelLoopCollapsing> {
+
void runOnOperation() override {
Operation *module = getOperation();
@@ -88,6 +89,7 @@ struct TestSCFParallelLoopCollapsing
// Only apply the transformation on parallel loops where the specified
// transformation is valid, but do NOT early abort in the case of invalid
// loops.
+ IRRewriter rewriter(&getContext());
module->walk([&](scf::ParallelOp op) {
if (flattenedCombinedLoops.size() != op.getNumLoops()) {
op.emitOpError("has ")
@@ -97,7 +99,7 @@ struct TestSCFParallelLoopCollapsing
<< flattenedCombinedLoops.size() << " iter args.";
return;
}
- collapseParallelLoops(op, combinedLoops);
+ collapseParallelLoops(rewriter, op, combinedLoops);
});
}
};
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 914aeb4fa79f..9279081cfd45 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
@@ -472,18 +473,23 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
- OpBuilder &insideLoopBuilder, Location loc,
- Value lowerBound, Value upperBound, Value step,
- Value inductionVar) {
+/// Transform a loop with a strictly positive step
+/// for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+/// %i = %ii * %s + %lb
+/// Return the new bounds; the caller is expected to update the loop in place
+/// and to remap uses of the original induction variable afterwards (see
+/// `denormalizeInductionVariable` below).
+static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ Value lb, Value ub, Value step) {
+ // For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
- if (auto ubCst = getConstantIntValue(lowerBound))
- isZeroBased = ubCst.value() == 0;
+ if (auto lbCst = getConstantIntValue(lb))
+ isZeroBased = lbCst.value() == 0;
bool isStepOne = false;
if (auto stepCst = getConstantIntValue(step))
@@ -493,62 +499,90 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
if (isZeroBased && isStepOne)
- return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
- /*step=*/step};
+ return {lb, ub, step};
- Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+ Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
Value newUpperBound =
- boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
- Value newLowerBound =
- isZeroBased ? lowerBound
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
- Value newStep =
- isStepOne ? step
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
- // Insert code computing the value of the original loop induction variable
- // from the "normalized" one.
- Value scaled =
- isStepOne
- ? inductionVar
- : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
- Value shifted =
- isZeroBased
- ? scaled
- : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
- SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
- shifted.getDefiningOp()};
- inductionVar.replaceAllUsesExcept(shifted, preserve);
- return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
- /*step=*/newStep};
+ isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+
+ Value newLowerBound = isZeroBased
+ ? lb
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(lb.getType()));
+ Value newStep = isStepOne
+ ? step
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(step.getType(), 1));
+
+ return {newLowerBound, newUpperBound, newStep};
}
-/// Transform a loop with a strictly positive step
-/// for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-/// %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
- OpBuilder builder(outer);
- OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
- auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
- loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), loop.getInductionVar());
-
- loop.setLowerBound(loopPieces.lowerBound);
- loop.setUpperBound(loopPieces.upperBound);
- loop.setStep(loopPieces.step);
+/// Get back the original induction variable values after loop normalization
+static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value normalizedIv, Value origLb,
+ Value origStep) {
+ Value denormalizedIv;
+ SmallPtrSet<Operation *, 2> preserve;
+ bool isStepOne = isConstantIntValue(origStep, 1);
+ bool isZeroBased = isConstantIntValue(origLb, 0);
+
+ Value scaled = normalizedIv;
+ if (!isStepOne) {
+ scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+ preserve.insert(scaled.getDefiningOp());
+ }
+ denormalizedIv = scaled;
+ if (!isZeroBased) {
+ denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+ preserve.insert(denormalizedIv.getDefiningOp());
+ }
+
+ rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
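A quick numeric check of the normalize/denormalize pair above (toy values): with lb = 5, ub = 20, step = 4, the normalized trip count is ceildiv(20 - 5, 4) = 4, and ii * step + lb revisits exactly the original iterations:

    #include <cassert>

    int main() {
      long lb = 5, ub = 20, step = 4;
      long trip = (ub - lb + step - 1) / step;  // ceildiv(15, 4) == 4
      for (long ii = 0; ii < trip; ++ii) {
        long i = ii * step + lb;                // denormalized IV: 5, 9, 13, 17
        assert(i >= lb && i < ub);
      }
      return 0;
    }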
-LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+/// Helper function to multiply a sequence of values.
+static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
+ ArrayRef<Value> values) {
+ assert(!values.empty() && "unexpected empty list");
+ Value productOf = values.front();
+ for (auto v : values.drop_front()) {
+ productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+ }
+ return productOf;
+}
+
+/// For each original loop, the value of the
+/// induction variable can be obtained by dividing the induction variable of
+/// the linearized loop by the total number of iterations of the loops nested
+/// in it modulo the number of iterations in this loop (remove the values
+/// related to the outer loops):
+/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+/// Compute these iteratively from the innermost loop by creating a "running
+/// quotient" of division by the range.
+static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
+delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value linearizedIv, ArrayRef<Value> ubs) {
+ Value previous = linearizedIv;
+ SmallVector<Value> delinearizedIvs(ubs.size());
+ SmallPtrSet<Operation *, 2> preservedUsers;
+ for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
+ unsigned idx = ubs.size() - i - 1;
+ if (i != 0) {
+ previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
+ preservedUsers.insert(previous.getDefiningOp());
+ }
+ Value iv = previous;
+ if (i != e - 1) {
+ iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+ preservedUsers.insert(iv.getDefiningOp());
+ }
+ delinearizedIvs[idx] = iv;
+ }
+ return {delinearizedIvs, preservedUsers};
+}
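Working the running-quotient formula once with concrete numbers (ranges 2 x 3 x 4, linearized IV 17, which encodes (1, 1, 1)):

    #include <cassert>
    #include <vector>

    int main() {
      std::vector<long> ubs = {2, 3, 4};  // loop ranges, outermost first
      long linear = 17;                   // ((1 * 3) + 1) * 4 + 1
      const int e = static_cast<int>(ubs.size());
      std::vector<long> iv(e);
      long previous = linear;             // running quotient, innermost first
      for (int i = 0; i < e; ++i) {
        int idx = e - i - 1;
        if (i != 0)
          previous /= ubs[idx + 1];
        iv[idx] = (i != e - 1) ? previous % ubs[idx] : previous;
      }
      assert(iv[0] == 1 && iv[1] == 1 && iv[2] == 1);
      return 0;
    }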
+
+LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
+ MutableArrayRef<scf::ForOp> loops) {
if (loops.size() < 2)
return failure();
@@ -557,57 +591,148 @@ LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
// 1. Make sure all loops iterate from 0 to upperBound with step 1. This
// allows the following code to assume upperBound is the number of iterations.
- for (auto loop : loops)
- normalizeLoop(loop, outermost, innermost);
+ for (auto loop : loops) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
+ Value lb = loop.getLowerBound();
+ Value ub = loop.getUpperBound();
+ Value step = loop.getStep();
+ auto newLoopParams =
+ emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
+
+ rewriter.modifyOpInPlace(loop, [&]() {
+ loop.setLowerBound(newLoopParams.lowerBound);
+ loop.setUpperBound(newLoopParams.upperBound);
+ loop.setStep(newLoopParams.step);
+ });
+
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ denormalizeInductionVariable(rewriter, loop.getLoc(),
+ loop.getInductionVar(), lb, step);
+ }
// 2. Emit code computing the upper bound of the coalesced loop as product
// of the number of iterations of all loops.
- OpBuilder builder(outermost);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
Location loc = outermost.getLoc();
- Value upperBound = outermost.getUpperBound();
- for (auto loop : loops.drop_front())
- upperBound =
- builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
+ SmallVector<Value> upperBounds = llvm::map_to_vector(
+ loops, [](auto loop) { return loop.getUpperBound(); });
+ Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
outermost.setUpperBound(upperBound);
- builder.setInsertionPointToStart(outermost.getBody());
-
- // 3. Remap induction variables. For each original loop, the value of the
- // induction variable can be obtained by dividing the induction variable of
- // the linearized loop by the total number of iterations of the loops nested
- // in it modulo the number of iterations in this loop (remove the values
- // related to the outer loops):
- // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
- // Compute these iteratively from the innermost loop by creating a "running
- // quotient" of division by the range.
- Value previous = outermost.getInductionVar();
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
+ rewriter, loc, outermost.getInductionVar(), upperBounds);
+ rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
+ preservedUsers);
+
+ for (int i = loops.size() - 1; i > 0; --i) {
+ auto outerLoop = loops[i - 1];
+ auto innerLoop = loops[i];
+
+ Operation *innerTerminator = innerLoop.getBody()->getTerminator();
+ auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+ rewriter.eraseOp(innerTerminator);
+
+ SmallVector<Value> innerBlockArgs;
+ innerBlockArgs.push_back(delinearizeIvs[i]);
+ llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
+ rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
+ Block::iterator(innerLoop), innerBlockArgs);
+ rewriter.replaceOp(innerLoop, yieldedVals);
+ }
+ return success();
+}
+
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+ if (loops.empty()) {
+ return failure();
+ }
+ IRRewriter rewriter(loops.front().getContext());
+ return coalesceLoops(rewriter, loops);
+}
+
+LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
+ LogicalResult result(failure());
+ SmallVector<scf::ForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+
+ // 1. For each loop, find above which parent loop its bounds operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
- unsigned idx = loops.size() - i - 1;
- if (i != 0)
- previous = builder.create<arith::DivSIOp>(loc, previous,
- loops[idx + 1].getUpperBound());
-
- Value iv = (i == e - 1) ? previous
- : builder.create<arith::RemSIOp>(
- loc, previous, loops[idx].getUpperBound());
- replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
- loops.back().getRegion());
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
+ loops[i].getUpperBound(),
+ loops[i].getStep()};
+ if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
}
- // 4. Move the operations from the innermost just above the second-outermost
- // loop, delete the extra terminator and the second-outermost loop.
- scf::ForOp second = loops[1];
- innermost.getBody()->back().erase();
- outermost.getBody()->getOperations().splice(
- Block::iterator(second.getOperation()),
- innermost.getBody()->getOperations());
- second.erase();
- return success();
+ // 2. For each inner loop, check that the region iter_args of the immediately
+ // outer loop are the inits of the immediately inner loop, and that the
+ // results of the inner loop are the values yielded by the immediately outer
+ // loop. Keep track of where the chain starts for each loop.
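+ // E.g. the chain is intact for:
+ //   %0 = scf.for ... iter_args(%a = %init) {
+ //     %1 = scf.for ... iter_args(%b = %a) {  // inner inits == outer iter_args
+ //       ...
+ //     }
+ //     scf.yield %1  // outer yields the inner loop's results
+ //   }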
+ SmallVector<unsigned> iterArgChainStart(loops.size());
+ iterArgChainStart[0] = 0;
+ for (unsigned i = 1, e = loops.size(); i < e; ++i) {
+ // By default set the start of the chain to itself.
+ iterArgChainStart[i] = i;
+ auto outerLoop = loops[i - 1];
+ auto innerLoop = loops[i];
+ if (outerLoop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
+ continue;
+ }
+ if (!llvm::equal(outerLoop.getRegionIterArgs(), innerLoop.getInitArgs())) {
+ continue;
+ }
+ auto outerLoopTerminator = outerLoop.getBody()->getTerminator();
+ if (!llvm::equal(outerLoopTerminator->getOperands(),
+ innerLoop.getResults())) {
+ continue;
+ }
+ iterArgChainStart[i] = iterArgChainStart[i - 1];
+ }
+
+ // 3. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ if (iterArgChainStart[end - 1] > start)
+ continue;
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
}
void mlir::collapseParallelLoops(
- scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
- OpBuilder outsideBuilder(loops);
+ RewriterBase &rewriter, scf::ParallelOp loops,
+ ArrayRef<std::vector<unsigned>> combinedDimensions) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loops);
Location loc = loops.getLoc();
// Presort combined dimensions.
@@ -619,25 +744,29 @@ void mlir::collapseParallelLoops(
SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
normalizedUpperBounds;
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
- OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
- auto resultBounds =
- normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
- loops.getLowerBound()[i], loops.getUpperBound()[i],
- loops.getStep()[i], loops.getBody()->getArgument(i));
-
- normalizedLowerBounds.push_back(resultBounds.lowerBound);
- normalizedUpperBounds.push_back(resultBounds.upperBound);
- normalizedSteps.push_back(resultBounds.step);
+ OpBuilder::InsertionGuard g2(rewriter);
+ rewriter.setInsertionPoint(loops);
+ Value lb = loops.getLowerBound()[i];
+ Value ub = loops.getUpperBound()[i];
+ Value step = loops.getStep()[i];
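+ // Normalize the loop to run from 0 to ceilDiv(ub - lb, step) with unit
+ // step; uses of the original IV in the body are rewritten below as
+ // iv * step + lb.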
+ auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+ normalizedLowerBounds.push_back(newLoopParams.lowerBound);
+ normalizedUpperBounds.push_back(newLoopParams.upperBound);
+ normalizedSteps.push_back(newLoopParams.step);
+
+ rewriter.setInsertionPointToStart(loops.getBody());
+ denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
+ step);
}
// Combine iteration spaces.
SmallVector<Value, 3> lowerBounds, upperBounds, steps;
- auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
- auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
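+ // The upper bound of each collapsed dimension is the product of the
+ // normalized upper bounds of the dimensions being combined.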
for (auto &sortedDimension : sortedDimensions) {
- Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (auto idx : sortedDimension) {
- newUpperBound = outsideBuilder.create<arith::MulIOp>(
+ newUpperBound = rewriter.create<arith::MulIOp>(
loc, newUpperBound, normalizedUpperBounds[idx]);
}
lowerBounds.push_back(cst0);
@@ -651,7 +780,7 @@ void mlir::collapseParallelLoops(
// value. The remainders then determine, based on that range, which iteration
// of the original induction variable this represents. This is a normalized
// value that is un-normalized already by the previous logic.
- auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+ auto newPloop = rewriter.create<scf::ParallelOp>(
loc, lowerBounds, upperBounds, steps,
[&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5944a0ea46a1..286f47ce6913 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
@@ -250,6 +251,14 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
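+// Replace all uses of `from` with `to`, skipping uses whose owning operation
+// is in `preservedUsers`. This is a thin wrapper over replaceUsesWithIf, e.g.:
+//   SmallPtrSet<Operation *, 4> keep = {userToPreserve};
+//   rewriter.replaceAllUsesExcept(oldValue, newValue, keep);
+// (names here are illustrative).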
+void RewriterBase::replaceAllUsesExcept(
+ Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
+ return replaceUsesWithIf(from, to, [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return !preservedUsers.contains(user);
+ });
+}
+
void RewriterBase::replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 9c17fb24be69..ae0adf5a0a02 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing --cse %s | FileCheck %s
// CHECK-LABEL: @one_3d_nest
func.func @one_3d_nest() {
@@ -239,19 +239,15 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
}
return
}
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T2:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T3:.*]] = affine.apply #[[IDENTITY]]()[%[[T0]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[PRODUCT]](%[[T3]])[%[[T4]]]
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[IDENTITY]]()[%[[T2]]]
-// CHECK-DAG: %[[T7:.*]] = affine.apply #[[PRODUCT]](%[[T5]])[%[[T6]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T7]]
-// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T6]]]
-// CHECK-DAG: %[[T9:.*]] = affine.apply #[[FLOOR]](%[[IV]])[%[[T6]]]
-// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T9]])[%[[T4]]]
-// CHECK-DAG: %[[I:.*]] = affine.apply #[[FLOOR]](%[[T9]])[%[[T4]]]
+// CHECK: %[[DIM:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.*]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T0]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[PRODUCT]](%[[T1]])[%[[T0]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T2]]
+// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T0]]]
+// CHECK-DAG: %[[T9:.*]] = affine.apply #[[FLOOR]](%[[IV]])[%[[T0]]]
+// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T9]])[%[[T0]]]
+// CHECK-DAG: %[[I:.*]] = affine.apply #[[FLOOR]](%[[T9]])[%[[T0]]]
// CHECK-NEXT: "test.foo"(%[[I]], %[[J]], %[[K]])
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -277,18 +273,16 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
}
return
}
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T2:.*]] = affine.apply #[[IDENTITY]]()[%[[T0]]]
-// CHECK-DAG: %[[T3:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[PRODUCT]](%[[T2]])[%[[T3]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[SIXTY_FOUR]]()
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[PRODUCT]](%[[T4]])[%[[T5]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T6]]
-// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T5]]]
-// CHECK-DAG: %[[T8:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T5]]]
-// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T8]])[%[[T3]]]
-// CHECK-DAG: %[[I:.*]] = affine.apply #[[DIV]](%[[T8]])[%[[T3]]]
+// CHECK: %[[DIM:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.*]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T0]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[SIXTY_FOUR]]()
+// CHECK-DAG: %[[T3:.*]] = affine.apply #[[PRODUCT]](%[[T1]])[%[[T2]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T3]]
+// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T2]]]
+// CHECK-DAG: %[[T5:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T2]]]
+// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T5]])[%[[T0]]]
+// CHECK-DAG: %[[I:.*]] = affine.apply #[[DIV]](%[[T5]])[%[[T0]]]
// CHECK-NEXT: "test.foo"(%[[I]], %[[J]], %[[K]])
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -316,19 +310,16 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
}
return
}
-// CHECK: %[[T0:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T1:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK: %[[T2:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
-// CHECK-DAG: %[[T3:.*]] = affine.min #[[MAP0]]()[%[[T0]]]
-// CHECK-DAG: %[[T4:.*]] = affine.apply #[[IDENTITY]]()[%[[T1]]]
-// CHECK-DAG: %[[T5:.*]] = affine.apply #[[PRODUCT]](%[[T3]])[%[[T4]]]
-// CHECK-DAG: %[[T6:.*]] = affine.apply #[[IDENTITY]]()[%[[T2]]]
-// CHECK-DAG: %[[T7:.*]] = affine.apply #[[PRODUCT]](%[[T5]])[%[[T6]]]
-// CHECK: affine.for %[[IV:.*]] = 0 to %[[T7]]
-// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T6]]]
-// CHECK-DAG: %[[T9:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T6]]]
-// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T9]])[%[[T4]]]
-// CHECK-DAG: %[[I:.*]] = affine.apply #[[DIV]](%[[T9]])[%[[T4]]]
+// CHECK: %[[DIM:.*]] = memref.dim %arg{{.*}}, %c{{.*}} : memref<?x?xf32>
+// CHECK-DAG: %[[T0:.*]] = affine.min #[[MAP0]]()[%[[DIM]]]
+// CHECK-DAG: %[[T1:.*]] = affine.apply #[[IDENTITY]]()[%[[DIM]]]
+// CHECK-DAG: %[[T2:.*]] = affine.apply #[[PRODUCT]](%[[T0]])[%[[T1]]]
+// CHECK-DAG: %[[T3:.*]] = affine.apply #[[PRODUCT]](%[[T2]])[%[[T1]]]
+// CHECK: affine.for %[[IV:.*]] = 0 to %[[T3]]
+// CHECK-DAG: %[[K:.*]] = affine.apply #[[MOD]](%[[IV]])[%[[T1]]]
+// CHECK-DAG: %[[T5:.*]] = affine.apply #[[DIV]](%[[IV]])[%[[T1]]]
+// CHECK-DAG: %[[J:.*]] = affine.apply #[[MOD]](%[[T5]])[%[[T1]]]
+// CHECK-DAG: %[[I:.*]] = affine.apply #[[DIV]](%[[T5]])[%[[T1]]]
// CHECK-NEXT: "test.foo"(%[[I]], %[[J]], %[[K]])
// CHECK-NEXT: }
// CHECK-NEXT: return
@@ -342,12 +333,14 @@ func.func @coalesce_affine_for(%arg0: memref<?x?xf32>) {
func.func @test_loops_do_not_get_coalesced() {
affine.for %i = 0 to 7 {
affine.for %j = #map0(%i) to min #map1(%i) {
+ "use"(%i, %j) : (index, index) -> ()
}
}
return
}
// CHECK: affine.for %[[IV0:.*]] = 0 to 7
// CHECK-NEXT: affine.for %[[IV1:.*]] = #[[MAP0]](%[[IV0]]) to min #[[MAP1]](%[[IV0]])
+// CHECK-NEXT: "use"(%[[IV0]], %[[IV1]])
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: return
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
index 2d59331b72cf..4dc3e4ea0ef4 100644
--- a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
+++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -transform-interpreter -split-input-file -verify-diagnostics -allow-unregistered-dialect --cse | FileCheck %s
func.func @coalesce_inner() {
%c0 = arith.constant 0 : index
@@ -14,7 +14,7 @@ func.func @coalesce_inner() {
scf.for %k = %i to %j step %c1 {
// Inner loop must have been removed.
scf.for %l = %i to %j step %c1 {
- arith.addi %i, %j : index
+ "use"(%i, %j) : (index, index) -> ()
}
} {coalesce}
}
@@ -33,13 +33,19 @@ module attributes {transform.with_named_sequence} {
// -----
+// CHECK-DAG: #[[MAP:.+]] = affine_map<() -> (64)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 mod s0)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 floordiv s0)>
func.func @coalesce_outer(%arg1: memref<64x64xf32, 1>, %arg2: memref<64x64xf32, 1>, %arg3: memref<64x64xf32, 1>) attributes {} {
+ // CHECK: %[[T0:.+]] = affine.apply #[[MAP]]()
+ // CHECK: %[[UB:.+]] = affine.apply #[[MAP1]](%[[T0]])[%[[T0]]]
// CHECK: affine.for %[[IV1:.+]] = 0 to %[[UB:.+]] {
// CHECK-NOT: affine.for %[[IV2:.+]]
affine.for %arg4 = 0 to 64 {
affine.for %arg5 = 0 to 64 {
- // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP0:.+]](%[[IV1]])[%{{.+}}]
- // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP1:.+]](%[[IV1]])[%{{.+}}]
+ // CHECK: %[[IDX0:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%{{.+}}]
+ // CHECK: %[[IDX1:.+]] = affine.apply #[[MAP3]](%[[IV1]])[%{{.+}}]
// CHECK-NEXT: %{{.+}} = affine.load %{{.+}}[%[[IDX1]], %[[IDX0]]] : memref<64x64xf32, 1>
%0 = affine.load %arg1[%arg4, %arg5] : memref<64x64xf32, 1>
%1 = affine.load %arg2[%arg4, %arg5] : memref<64x64xf32, 1>
@@ -96,3 +102,200 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @tensor_loops(%arg0 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+ %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> tensor<?x?xf32> {
+ %0 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg1 = %arg0) -> tensor<?x?xf32> {
+ %1 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg2 = %arg1) -> tensor<?x?xf32> {
+ %2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg3 = %arg2) -> tensor<?x?xf32> {
+ %3 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>)
+ scf.yield %3 : tensor<?x?xf32>
+ }
+ scf.yield %2 : tensor<?x?xf32>
+ }
+ scf.yield %1 : tensor<?x?xf32>
+ } {coalesce}
+ return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK: func.func @tensor_loops(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[LB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP2:[a-zA-Z0-9_]+]]: index
+// CHECK: %[[NEWUB0_DIFF:.+]] = arith.subi %[[UB0]], %[[LB0]]
+// CHECK-DAG: %[[NEWUB0:.+]] = arith.ceildivsi %[[NEWUB0_DIFF]], %[[STEP0]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1
+// CHECK: %[[NEWUB1_DIFF:.+]] = arith.subi %[[UB1]], %[[LB1]]
+// CHECK-DAG: %[[NEWUB1:.+]] = arith.ceildivsi %[[NEWUB1_DIFF]], %[[STEP1]]
+// CHECK: %[[NEWUB2_DIFF:.+]] = arith.subi %[[UB2]], %[[LB2]]
+// CHECK-DAG: %[[NEWUB2:.+]] = arith.ceildivsi %[[NEWUB2_DIFF]], %[[STEP2]]
+// CHECK: %[[PROD1:.+]] = arith.muli %[[NEWUB0]], %[[NEWUB1]]
+// CHECK: %[[NEWUB:.+]] = arith.muli %[[PROD1]], %[[NEWUB2]]
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[NEWUB]] step %[[C1]] iter_args(%[[ITER_ARG:.+]] = %[[ARG0]])
+// CHECK: %[[IV2:.+]] = arith.remsi %[[IV]], %[[NEWUB2]]
+// CHECK: %[[PREVIOUS:.+]] = arith.divsi %[[IV]], %[[NEWUB2]]
+// CHECK: %[[IV1:.+]] = arith.remsi %[[PREVIOUS]], %[[NEWUB1]]
+// CHECK: %[[IV0:.+]] = arith.divsi %[[PREVIOUS]], %[[NEWUB1]]
+// CHECK: %[[K_STEP:.+]] = arith.muli %[[IV2]], %[[STEP2]]
+// CHECK: %[[K:.+]] = arith.addi %[[K_STEP]], %[[LB2]]
+// CHECK: %[[J_STEP:.+]] = arith.muli %[[IV1]], %[[STEP1]]
+// CHECK: %[[J:.+]] = arith.addi %[[J_STEP]], %[[LB1]]
+// CHECK: %[[I_STEP:.+]] = arith.muli %[[IV0]], %[[STEP0]]
+// CHECK: %[[I:.+]] = arith.addi %[[I_STEP]], %[[LB0]]
+// CHECK: %[[USE:.+]] = "use"(%[[ITER_ARG]], %[[I]], %[[J]], %[[K]])
+// CHECK: scf.yield %[[USE]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+// Coalesce only the first two loops, but not the last, since the iter_args don't line up.
+func.func @tensor_loops_first_two(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+ %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg5, %arg7 = %arg4) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+ scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ } {coalesce}
+ return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK: func.func @tensor_loops_first_two(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[LB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP2:[a-zA-Z0-9_]+]]: index
+// CHECK: scf.for
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: scf.for %{{[a-zA-Z0-9]+}} = %[[LB2]] to %[[UB2]] step %[[STEP2]]
+// CHECK-NOT: scf.for
+// CHECK: transform.named_sequence
+
+// -----
+
+// Coalesce only the first two loops, but not the last, since the yields don't match up.
+func.func @tensor_loops_first_two_2(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+ %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg4, %arg7 = %arg5) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+ scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %2#1, %2#0 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#0, %1#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ } {coalesce}
+ return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK: func.func @tensor_loops_first_two_2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[LB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP2:[a-zA-Z0-9_]+]]: index
+// CHECK: scf.for
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK: scf.for %{{[a-zA-Z0-9]+}} = %[[LB2]] to %[[UB2]] step %[[STEP2]]
+// CHECK-NOT: scf.for
+// CHECK: transform.named_sequence
+
+// -----
+
+// Coalesce only the last two loops, but not the first, since the yields don't match up.
+func.func @tensor_loops_last_two(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %lb0 : index, %ub0 : index, %step0 : index,
+ %lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %0:2 = scf.for %i = %lb0 to %ub0 step %step0 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %1:2 = scf.for %j = %lb1 to %ub1 step %step1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %2:2 = scf.for %k = %lb2 to %ub2 step %step2 iter_args(%arg6 = %arg4, %arg7 = %arg5) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+ %3:2 = "use"(%arg3, %i, %j, %k) : (tensor<?x?xf32>, index, index, index) -> (tensor<?x?xf32>, tensor<?x?xf32>)
+ scf.yield %3#0, %3#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xf32>
+ }
+ scf.yield %1#1, %1#0 : tensor<?x?xf32>, tensor<?x?xf32>
+ } {coalesce}
+ return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for">
+ %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
+ transform.yield
+ }
+}
+// CHECK: func.func @tensor_loops_last_two(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[LB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP0:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP1:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[LB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[UB2:[a-zA-Z0-9_]+]]: index
+// CHECK-SAME: %[[STEP2:[a-zA-Z0-9_]+]]: index
+// CHECK: scf.for %{{[a-zA-Z0-9]+}} = %[[LB0]] to %[[UB0]] step %[[STEP0]]
+// CHECK: arith.subi
+// CHECK: arith.ceildivsi
+// CHECK: arith.subi
+// CHECK: arith.ceildivsi
+// CHECK: scf.for
+// CHECK: arith.remsi
+// CHECK: arith.divsi
+// CHECK-NOT: scf.for
+// CHECK: transform.named_sequence
+
diff --git a/mlir/test/Transforms/parallel-loop-collapsing.mlir b/mlir/test/Transforms/parallel-loop-collapsing.mlir
index 660d7edb2fbb..d1c23d584f92 100644
--- a/mlir/test/Transforms/parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/parallel-loop-collapsing.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(test-scf-parallel-loop-collapsing{collapsed-indices-0=0,3 collapsed-indices-1=1,4 collapsed-indices-2=2}, canonicalize))' | FileCheck %s
-// CHECK-LABEL: func @parallel_many_dims() {
+// CHECK: func @parallel_many_dims() {
func.func @parallel_many_dims() {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
@@ -28,19 +28,19 @@ func.func @parallel_many_dims() {
return
}
-// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index
-// CHECK-DAG: [[C10:%.*]] = arith.constant 10 : index
-// CHECK-DAG: [[C9:%.*]] = arith.constant 9 : index
-// CHECK-DAG: [[C6:%.*]] = arith.constant 6 : index
-// CHECK-DAG: [[C4:%.*]] = arith.constant 4 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
-// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
-// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: scf.parallel ([[NEW_I0:%.*]]) = ([[C0]]) to ([[C4]]) step ([[C1]]) {
-// CHECK: [[V0:%.*]] = arith.remsi [[NEW_I0]], [[C2]] : index
-// CHECK: [[I0:%.*]] = arith.divsi [[NEW_I0]], [[C2]] : index
-// CHECK: [[V2:%.*]] = arith.muli [[V0]], [[C10]] : index
-// CHECK: [[I3:%.*]] = arith.addi [[V2]], [[C9]] : index
-// CHECK: "magic.op"([[I0]], [[C3]], [[C6]], [[I3]], [[C12]]) : (index, index, index, index, index) -> index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[C12:.*]] = arith.constant 12 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
+// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: scf.parallel (%[[NEW_I0:.*]]) = (%[[C0]]) to (%[[C4]]) step (%[[C1]]) {
+// CHECK: %[[V0:.*]] = arith.remsi %[[NEW_I0]], %[[C2]] : index
+// CHECK: %[[I0:.*]] = arith.divsi %[[NEW_I0]], %[[C2]] : index
+// CHECK: %[[V2:.*]] = arith.muli %[[V0]], %[[C10]]
+// CHECK: %[[I3:.*]] = arith.addi %[[V2]], %[[C9]]
+// CHECK: "magic.op"(%[[I0]], %[[C3]], %[[C6]], %[[I3]], %[[C12]]) : (index, index, index, index, index) -> index
// CHECK: scf.reduce
diff --git a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
index 542786b5fa5e..4eed61a65aa4 100644
--- a/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
+++ b/mlir/test/Transforms/single-parallel-loop-collapsing.mlir
@@ -13,22 +13,22 @@ func.func @collapse_to_single() {
return
}
-// CHECK-LABEL: func @collapse_to_single() {
-// CHECK-DAG: [[C18:%.*]] = arith.constant 18 : index
-// CHECK-DAG: [[C6:%.*]] = arith.constant 6 : index
-// CHECK-DAG: [[C3:%.*]] = arith.constant 3 : index
-// CHECK-DAG: [[C7:%.*]] = arith.constant 7 : index
-// CHECK-DAG: [[C4:%.*]] = arith.constant 4 : index
-// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
-// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
-// CHECK: scf.parallel ([[NEW_I:%.*]]) = ([[C0]]) to ([[C18]]) step ([[C1]]) {
-// CHECK: [[I0_COUNT:%.*]] = arith.remsi [[NEW_I]], [[C6]] : index
-// CHECK: [[I1_COUNT:%.*]] = arith.divsi [[NEW_I]], [[C6]] : index
-// CHECK: [[V0:%.*]] = arith.muli [[I0_COUNT]], [[C4]] : index
-// CHECK: [[I1:%.*]] = arith.addi [[V0]], [[C7]] : index
-// CHECK: [[V1:%.*]] = arith.muli [[I1_COUNT]], [[C3]] : index
-// CHECK: [[I0:%.*]] = arith.addi [[V1]], [[C3]] : index
-// CHECK: "magic.op"([[I0]], [[I1]]) : (index, index) -> index
+// CHECK: func @collapse_to_single() {
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
+// CHECK: scf.parallel (%[[NEW_I:.*]]) = (%[[C0]]) to (%[[C18]]) step (%[[C1]]) {
+// CHECK: %[[I0_COUNT:.*]] = arith.remsi %[[NEW_I]], %[[C6]] : index
+// CHECK: %[[I1_COUNT:.*]] = arith.divsi %[[NEW_I]], %[[C6]] : index
+// CHECK: %[[V0:.*]] = arith.muli %[[I0_COUNT]], %[[C4]]
+// CHECK: %[[I1:.*]] = arith.addi %[[V0]], %[[C7]]
+// CHECK: %[[V1:.*]] = arith.muli %[[I1_COUNT]], %[[C3]]
+// CHECK: %[[I0:.*]] = arith.addi %[[V1]], %[[C3]]
+// CHECK: "magic.op"(%[[I0]], %[[I1]]) : (index, index) -> index
// CHECK: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: return
diff --git a/mlir/test/mlir-tblgen/op-properties.td b/mlir/test/mlir-tblgen/op-properties.td
new file mode 100644
index 000000000000..a484f68fc4a1
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-properties.td
@@ -0,0 +1,21 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/EnumAttr.td"
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "foobar";
+}
+class NS_Op<string mnemonic, list<Trait> traits = []> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+def OpWithAttr : NS_Op<"op_with_attr"> {
+ let arguments = (ins AnyAttr:$attr, OptionalAttr<AnyAttr>:$optional);
+}
+
+// CHECK: void OpWithAttr::setAttrAttr(::mlir::Attribute attr)
+// CHECK-NEXT: getProperties().attr = attr
+// CHECK: void OpWithAttr::setOptionalAttr(::mlir::Attribute attr)
+// CHECK-NEXT: getProperties().optional = attr
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 3a697520dfad..843760d57c99 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -1804,23 +1804,36 @@ void OpEmitter::genAttrGetters() {
}
void OpEmitter::genAttrSetters() {
+ bool useProperties = op.getDialect().usePropertiesForAttributes();
+
+ // Generate the code to set an attribute.
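+ // For an attribute named "foo" this emits either
+ //   getProperties().foo = <value>;               (properties path)
+ // or
+ //   (*this)->setAttr(getFooAttrName(), <value>); (discardable-attribute path)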
+ auto emitSetAttr = [&](Method *method, StringRef getterName,
+ StringRef attrName, StringRef attrVar) {
+ if (useProperties) {
+ method->body() << formatv(" getProperties().{0} = {1};", attrName,
+ attrVar);
+ } else {
+ method->body() << formatv(" (*this)->setAttr({0}AttrName(), {1});",
+ getterName, attrVar);
+ }
+ };
+
// Generate the raw named setter. This is a wrapper method that allows setting
// the attributes via typed setters instead of having to use the string
// interface, for better compile-time verification.
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName,
- Attribute attr) {
+ StringRef attrName, Attribute attr) {
auto *method =
opClass.addMethod("void", setterName + "Attr",
MethodParameter(attr.getStorageType(), "attr"));
if (method)
- method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);",
- getterName);
+ emitSetAttr(method, getterName, attrName, "attr");
};
// Generate a setter that accepts the underlying C++ type as opposed to the
// attribute type.
auto emitAttrWithReturnType = [&](StringRef setterName, StringRef getterName,
- Attribute attr) {
+ StringRef attrName, Attribute attr) {
Attribute baseAttr = attr.getBaseAttr();
if (!canUseUnwrappedRawValue(baseAttr))
return;
@@ -1849,9 +1862,8 @@ void OpEmitter::genAttrSetters() {
// If the value isn't optional, just set it directly.
if (!isOptional) {
- method->body() << formatv(
- " (*this)->setAttr({0}AttrName(), {1});", getterName,
- constBuildAttrFromParam(attr, fctx, "attrValue"));
+ emitSetAttr(method, getterName, attrName,
+ constBuildAttrFromParam(attr, fctx, "attrValue"));
return;
}
@@ -1862,13 +1874,25 @@ void OpEmitter::genAttrSetters() {
// optional but not in the same way as the others (i.e. it uses bool over
// std::optional<>).
StringRef paramStr = isUnitAttr ? "attrValue" : "*attrValue";
- const char *optionalCodeBody = R"(
+ if (!useProperties) {
+ const char *optionalCodeBody = R"(
if (attrValue)
return (*this)->setAttr({0}AttrName(), {1});
(*this)->removeAttr({0}AttrName());)";
- method->body() << formatv(
- optionalCodeBody, getterName,
- constBuildAttrFromParam(baseAttr, fctx, paramStr));
+ method->body() << formatv(
+ optionalCodeBody, getterName,
+ constBuildAttrFromParam(baseAttr, fctx, paramStr));
+ } else {
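+ // With properties, an optional attribute is cleared by assigning nullptr
+ // to the underlying property rather than by removing a discardable
+ // attribute.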
+ const char *optionalCodeBody = R"(
+ auto &odsProp = getProperties().{0};
+ if (attrValue)
+ odsProp = {1};
+ else
+ odsProp = nullptr;)";
+ method->body() << formatv(
+ optionalCodeBody, attrName,
+ constBuildAttrFromParam(baseAttr, fctx, paramStr));
+ }
};
for (const NamedAttribute &namedAttr : op.getAttributes()) {
@@ -1876,8 +1900,10 @@ void OpEmitter::genAttrSetters() {
continue;
std::string setterName = op.getSetterName(namedAttr.name);
std::string getterName = op.getGetterName(namedAttr.name);
- emitAttrWithStorageType(setterName, getterName, namedAttr.attr);
- emitAttrWithReturnType(setterName, getterName, namedAttr.attr);
+ emitAttrWithStorageType(setterName, getterName, namedAttr.name,
+ namedAttr.attr);
+ emitAttrWithReturnType(setterName, getterName, namedAttr.name,
+ namedAttr.attr);
}
}