author     Florian Mayer <fmayer@google.com>   2024-03-08 15:46:12 -0800
committer  Florian Mayer <fmayer@google.com>   2024-03-08 15:46:12 -0800
commit     a8be7221b79d4fc4f239c107be4155755c050514 (patch)
tree       65d1bfae5f15d8fe660287beed69a1a901764379
parent     b408241d0ad9ce009b49018fe1e9838887abf3c1 (diff)
parent     da4957be2365831c94eab0b52612367c29f1d299 (diff)
Created using spr 1.3.4 [skip ci]
-rw-r--r--  clang/docs/LanguageExtensions.rst  1
-rw-r--r--  clang/docs/ReleaseNotes.rst  5
-rw-r--r--  clang/include/clang/Basic/Builtins.td  2
-rw-r--r--  clang/lib/AST/ExprConstant.cpp  1
-rw-r--r--  clang/test/Sema/constant-builtins-2.c  7
-rw-r--r--  compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp  2
-rw-r--r--  llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h  6
-rw-r--r--  llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h  1
-rw-r--r--  llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp  21
-rw-r--r--  llvm/lib/Target/AArch64/AArch64StackTagging.cpp  6
-rw-r--r--  llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp  28
-rw-r--r--  llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp  16
-rw-r--r--  llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll  19
-rw-r--r--  llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll  4
-rw-r--r--  llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll  16
-rw-r--r--  llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll  2
-rw-r--r--  llvm/test/CodeGen/PowerPC/pr74951.ll  54
-rw-r--r--  llvm/test/CodeGen/RISCV/forced-atomics.ll  2
-rw-r--r--  llvm/test/CodeGen/RISCV/fpclamptosat.ll  60
-rw-r--r--  llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test  6
-rw-r--r--  llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll  4
-rw-r--r--  llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll  2
-rw-r--r--  llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll  12
-rw-r--r--  llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll  8
-rw-r--r--  llvm/test/Instrumentation/InstrProfiling/timestamp.ll  8
-rw-r--r--  llvm/test/Object/Inputs/small.ll  10
-rw-r--r--  llvm/test/Object/Inputs/trivial.ll  10
-rw-r--r--  llvm/test/Object/X86/irsymtab-bad-alias.ll  4
-rw-r--r--  llvm/test/Object/X86/nm-ir.ll  10
-rw-r--r--  llvm/test/Object/dllimport-globalref.ll  2
-rw-r--r--  llvm/test/Object/dllimport.ll  2
-rw-r--r--  llvm/test/Object/mangle-ir.ll  4
-rw-r--r--  llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll  8
-rw-r--r--  llvm/test/tools/llvm-readobj/ELF/unwind.test  34
-rw-r--r--  llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h  7
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h  26
-rw-r--r--  mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h  20
-rw-r--r--  mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td  6
-rw-r--r--  mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td  4
-rw-r--r--  mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h  18
-rw-r--r--  mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h  6
-rw-r--r--  mlir/include/mlir/IR/Dialect.h  8
-rw-r--r--  mlir/include/mlir/InitAllDialects.h  10
-rw-r--r--  mlir/include/mlir/Transforms/DialectConversion.h  33
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/CMakeLists.txt  1
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp  7
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp  24
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt  5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp  353
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp  8
-rw-r--r--  mlir/lib/Dialect/Mesh/IR/MeshOps.cpp  7
-rw-r--r--  mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp  89
-rw-r--r--  mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp  13
-rw-r--r--  mlir/lib/Transforms/Utils/DialectConversion.cpp  248
-rw-r--r--  mlir/test/Dialect/Linalg/mesh-spmdization.mlir  165
-rw-r--r--  mlir/test/Transforms/test-legalizer.mlir  71
-rw-r--r--  mlir/test/lib/Dialect/Test/TestPatterns.cpp  28
-rw-r--r--  openmp/runtime/src/kmp_collapse.cpp  311
-rw-r--r--  openmp/runtime/src/kmp_collapse.h  11
-rw-r--r--  openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c  124
-rw-r--r--  openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c  124
-rw-r--r--  openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c  124
-rw-r--r--  utils/bazel/llvm-project-overlay/mlir/BUILD.bazel  4
63 files changed, 2007 insertions, 225 deletions
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 2b54dffd058a..06af93fd3c15 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -5378,6 +5378,7 @@ The following builtin intrinsics can be used in constant expressions:
* ``__builtin_popcount``
* ``__builtin_popcountl``
* ``__builtin_popcountll``
+* ``__builtin_popcountg``
* ``__builtin_rotateleft8``
* ``__builtin_rotateleft16``
* ``__builtin_rotateleft32``
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 42c4a7c4d4bd..fa23c215790f 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -157,6 +157,11 @@ Non-comprehensive list of changes in this release
- ``__builtin_addc``, ``__builtin_subc``, and the other sizes of those
builtins are now constexpr and may be used in constant expressions.
+- Added ``__builtin_popcountg`` as a type-generic alternative to
+ ``__builtin_popcount{,l,ll}`` with support for any unsigned integer type. Like
+ the previous builtins, this new builtin is constexpr and may be used in
+ constant expressions.
+
New Compiler Flags
------------------
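
As a quick illustration of the release note above (a minimal sketch, not part of the patch; assumes a Clang build that includes this change), __builtin_popcountg accepts any unsigned integer type and, with this patch, folds in constant expressions:

    // Illustrative only; the values are arbitrary examples.
    static_assert(__builtin_popcountg(0xF0F0u) == 8, "");
    static_assert(__builtin_popcountg(~0ull) == 64, "");  // assumes a 64-bit unsigned long long
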
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index a81131d82c4c..9c703377ca8d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -706,7 +706,7 @@ def Popcount : Builtin, BitInt_Long_LongLongTemplate {
def Popcountg : Builtin {
let Spellings = ["__builtin_popcountg"];
- let Attributes = [NoThrow, Const, CustomTypeChecking];
+ let Attributes = [NoThrow, Const, Constexpr, CustomTypeChecking];
let Prototype = "int(...)";
}
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index d8ca35740fbc..4a7c7755e1d6 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -12483,6 +12483,7 @@ bool IntExprEvaluator::VisitBuiltinCallExpr(const CallExpr *E,
case Builtin::BI__builtin_popcount:
case Builtin::BI__builtin_popcountl:
case Builtin::BI__builtin_popcountll:
+ case Builtin::BI__builtin_popcountg:
case Builtin::BI__popcnt16: // Microsoft variants of popcount
case Builtin::BI__popcnt:
case Builtin::BI__popcnt64: {
diff --git a/clang/test/Sema/constant-builtins-2.c b/clang/test/Sema/constant-builtins-2.c
index 2bdd7b06daab..0935abe4c65f 100644
--- a/clang/test/Sema/constant-builtins-2.c
+++ b/clang/test/Sema/constant-builtins-2.c
@@ -237,6 +237,13 @@ char popcount7[__builtin_popcountl(~0L) == BITSIZE(long) ? 1 : -1];
char popcount8[__builtin_popcountll(0LL) == 0 ? 1 : -1];
char popcount9[__builtin_popcountll(0xF0F0LL) == 8 ? 1 : -1];
char popcount10[__builtin_popcountll(~0LL) == BITSIZE(long long) ? 1 : -1];
+char popcount11[__builtin_popcountg(0U) == 0 ? 1 : -1];
+char popcount12[__builtin_popcountg(0xF0F0U) == 8 ? 1 : -1];
+char popcount13[__builtin_popcountg(~0U) == BITSIZE(int) ? 1 : -1];
+char popcount14[__builtin_popcountg(~0UL) == BITSIZE(long) ? 1 : -1];
+char popcount15[__builtin_popcountg(~0ULL) == BITSIZE(long long) ? 1 : -1];
+char popcount16[__builtin_popcountg(~(unsigned __int128)0) == BITSIZE(__int128) ? 1 : -1];
+char popcount17[__builtin_popcountg(~(unsigned _BitInt(128))0) == BITSIZE(_BitInt(128)) ? 1 : -1];
char parity1[__builtin_parity(0) == 0 ? 1 : -1];
char parity2[__builtin_parity(0xb821) == 0 ? 1 : -1];
diff --git a/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp b/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
index 0dbcec8b5f22..db80eb383885 100644
--- a/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
+++ b/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
@@ -243,7 +243,7 @@ void SetThreadName(std::thread &thread, const std::string &name) {
HMODULE kbase = GetModuleHandleA("KernelBase.dll");
proc ThreadNameProc =
reinterpret_cast<proc>(GetProcAddress(kbase, "SetThreadDescription"));
- if (proc) {
+ if (ThreadNameProc) {
std::wstring buf;
auto sz = MultiByteToWideChar(CP_UTF8, 0, name.data(), -1, nullptr, 0);
if (sz > 0) {
diff --git a/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h b/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
index bc35f2ab988e..c7c558850a28 100644
--- a/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
+++ b/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
@@ -454,8 +454,8 @@ public:
/// where a problem occurred in case an error is returned.
Error parse(DWARFDataExtractor Data, uint64_t *Offset, uint64_t EndOffset);
- void dump(raw_ostream &OS, DIDumpOptions DumpOpts,
- unsigned IndentLevel = 1) const;
+ void dump(raw_ostream &OS, DIDumpOptions DumpOpts, unsigned IndentLevel,
+ std::optional<uint64_t> InitialLocation) const;
void addInstruction(const Instruction &I) { Instructions.push_back(I); }
@@ -524,7 +524,7 @@ private:
/// Print \p Opcode's operand number \p OperandIdx which has value \p Operand.
void printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
const Instruction &Instr, unsigned OperandIdx,
- uint64_t Operand) const;
+ uint64_t Operand, std::optional<uint64_t> &Address) const;
};
/// An entry in either debug_frame or eh_frame. This entry can be a CIE or an
diff --git a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
index eb00e6c4e856..df61f60de4f2 100644
--- a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
+++ b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
@@ -78,6 +78,7 @@ private:
uint64_t getAllocaSizeInBytes(const AllocaInst &AI);
void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Align);
+bool isLifetimeIntrinsic(Value *V);
} // namespace memtag
} // namespace llvm
diff --git a/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp b/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
index aae1668c1639..aff26824dda1 100644
--- a/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
+++ b/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
@@ -630,6 +630,8 @@ Error UnwindTable::parseRows(const CFIProgram &CFIP, UnwindRow &Row,
if (LRLoc->getLocation() == UnwindLocation::Constant) {
// Toggle the constant value from 0 to 1 or 1 to 0.
LRLoc->setConstant(LRLoc->getConstant() ^ 1);
+ Row.getRegisterLocations().setRegisterLocation(
+ AArch64DWARFPAuthRaState, *LRLoc);
} else {
return createStringError(
errc::invalid_argument,
@@ -858,7 +860,8 @@ CFIProgram::getOperandTypes() {
/// Print \p Opcode's operand number \p OperandIdx which has value \p Operand.
void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
const Instruction &Instr, unsigned OperandIdx,
- uint64_t Operand) const {
+ uint64_t Operand,
+ std::optional<uint64_t> &Address) const {
assert(OperandIdx < MaxOperands);
uint8_t Opcode = Instr.Opcode;
OperandType Type = getOperandTypes()[Opcode][OperandIdx];
@@ -877,6 +880,7 @@ void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
break;
case OT_Address:
OS << format(" %" PRIx64, Operand);
+ Address = Operand;
break;
case OT_Offset:
// The offsets are all encoded in a unsigned form, but in practice
@@ -888,7 +892,11 @@ void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
if (CodeAlignmentFactor)
OS << format(" %" PRId64, Operand * CodeAlignmentFactor);
else
- OS << format(" %" PRId64 "*code_alignment_factor" , Operand);
+ OS << format(" %" PRId64 "*code_alignment_factor", Operand);
+ if (Address && CodeAlignmentFactor) {
+ *Address += Operand * CodeAlignmentFactor;
+ OS << format(" to 0x%" PRIx64, *Address);
+ }
break;
case OT_SignedFactDataOffset:
if (DataAlignmentFactor)
@@ -918,13 +926,14 @@ void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
}
void CFIProgram::dump(raw_ostream &OS, DIDumpOptions DumpOpts,
- unsigned IndentLevel) const {
+ unsigned IndentLevel,
+ std::optional<uint64_t> Address) const {
for (const auto &Instr : Instructions) {
uint8_t Opcode = Instr.Opcode;
OS.indent(2 * IndentLevel);
OS << callFrameString(Opcode) << ":";
for (unsigned i = 0; i < Instr.Ops.size(); ++i)
- printOperand(OS, DumpOpts, Instr, i, Instr.Ops[i]);
+ printOperand(OS, DumpOpts, Instr, i, Instr.Ops[i], Address);
OS << '\n';
}
}
@@ -975,7 +984,7 @@ void CIE::dump(raw_ostream &OS, DIDumpOptions DumpOpts) const {
OS << "\n";
}
OS << "\n";
- CFIs.dump(OS, DumpOpts);
+ CFIs.dump(OS, DumpOpts, /*IndentLevel=*/1, /*InitialLocation=*/{});
OS << "\n";
if (Expected<UnwindTable> RowsOrErr = UnwindTable::create(this))
@@ -1003,7 +1012,7 @@ void FDE::dump(raw_ostream &OS, DIDumpOptions DumpOpts) const {
OS << " Format: " << FormatString(IsDWARF64) << "\n";
if (LSDAAddress)
OS << format(" LSDA Address: %016" PRIx64 "\n", *LSDAAddress);
- CFIs.dump(OS, DumpOpts);
+ CFIs.dump(OS, DumpOpts, /*IndentLevel=*/1, InitialLocation);
OS << "\n";
if (Expected<UnwindTable> RowsOrErr = UnwindTable::create(this))
diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
index ef7c517732ef..f2812d2b49bc 100644
--- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
+++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
@@ -533,7 +533,9 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
if (Info.AI->hasName())
TagPCall->setName(Info.AI->getName() + ".tag");
// Does not replace metadata, so we don't have to handle DPValues.
- Info.AI->replaceNonMetadataUsesWith(TagPCall);
+ Info.AI->replaceUsesWithIf(TagPCall, [&](const Use &U) {
+ return !memtag::isLifetimeIntrinsic(U.getUser());
+ });
TagPCall->setOperand(0, Info.AI);
// Calls to functions that may return twice (e.g. setjmp) confuse the
@@ -550,7 +552,7 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
uint64_t Size =
cast<ConstantInt>(Start->getArgOperand(0))->getZExtValue();
Size = alignTo(Size, kTagGranuleSize);
- tagAlloca(AI, Start->getNextNode(), Start->getArgOperand(1), Size);
+ tagAlloca(AI, Start->getNextNode(), TagPCall, Size);
auto TagEnd = [&](Instruction *Node) { untagAlloca(AI, Node, Size); };
if (!DT || !PDT ||
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 289183ecf0f2..88553d49b1b5 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -187,15 +187,15 @@ static cl::opt<bool>
cl::desc("Use selective instrumentation"),
cl::Hidden, cl::init(false));
-static cl::opt<int> HotPercentileCutoff(
+static cl::opt<int> ClHotPercentileCutoff(
"hwasan-percentile-cutoff-hot", cl::init(0),
cl::desc("Alternative hot percentile cuttoff."
"By default `-profile-summary-cutoff-hot` is used."));
static cl::opt<float>
- RandomSkipRate("hwasan-random-skip-rate", cl::init(0),
- cl::desc("Probability value in the range [0.0, 1.0] "
- "to skip instrumentation of a function."));
+ ClRandomSkipRate("hwasan-random-skip-rate", cl::init(0),
+ cl::desc("Probability value in the range [0.0, 1.0] "
+ "to skip instrumentation of a function."));
STATISTIC(NumTotalFuncs, "Number of total funcs");
STATISTIC(NumInstrumentedFuncs, "Number of instrumented funcs");
@@ -301,7 +301,7 @@ public:
? ClEnableKhwasan
: CompileKernel;
this->Rng =
- RandomSkipRate.getNumOccurrences() ? M.createRNG("hwasan") : nullptr;
+ ClRandomSkipRate.getNumOccurrences() ? M.createRNG("hwasan") : nullptr;
initializeModule();
}
@@ -1391,11 +1391,6 @@ bool HWAddressSanitizer::instrumentLandingPads(
return true;
}
-static bool isLifetimeIntrinsic(Value *V) {
- auto *II = dyn_cast<IntrinsicInst>(V);
- return II && II->isLifetimeStartOrEnd();
-}
-
static DbgAssignIntrinsic *DynCastToDbgAssign(DbgVariableIntrinsic *DVI) {
return dyn_cast<DbgAssignIntrinsic>(DVI);
}
@@ -1455,7 +1450,8 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
AI->replaceUsesWithIf(Replacement, [AICast, AILong](const Use &U) {
auto *User = U.getUser();
- return User != AILong && User != AICast && !isLifetimeIntrinsic(User);
+ return User != AILong && User != AICast &&
+ !memtag::isLifetimeIntrinsic(User);
});
// Helper utility for adding DW_OP_LLVM_tag_offset to debug-info records,
@@ -1537,8 +1533,8 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
NumTotalFuncs++;
if (CSelectiveInstrumentation) {
- if (RandomSkipRate.getNumOccurrences()) {
- std::bernoulli_distribution D(RandomSkipRate);
+ if (ClRandomSkipRate.getNumOccurrences()) {
+ std::bernoulli_distribution D(ClRandomSkipRate);
if (D(*Rng))
return;
} else {
@@ -1547,10 +1543,10 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
if (PSI && PSI->hasProfileSummary()) {
auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
- if ((HotPercentileCutoff.getNumOccurrences() &&
- HotPercentileCutoff >= 0)
+ if ((ClHotPercentileCutoff.getNumOccurrences() &&
+ ClHotPercentileCutoff >= 0)
? PSI->isFunctionHotInCallGraphNthPercentile(
- HotPercentileCutoff, &F, BFI)
+ ClHotPercentileCutoff, &F, BFI)
: PSI->isFunctionHotInCallGraph(&F, BFI))
return;
} else {
diff --git a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
index 2ffe89a24584..f4b9b155827a 100644
--- a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
+++ b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
@@ -12,6 +12,7 @@
#include "llvm/Transforms/Utils/MemoryTaggingSupport.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/StackSafetyAnalysis.h"
@@ -69,14 +70,12 @@ bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT,
++NumCoveredExits;
}
}
- // If there's a mix of covered and non-covered exits, just put the untag
- // on exits, so we avoid the redundancy of untagging twice.
if (NumCoveredExits == ReachableRetVec.size()) {
- for (auto *End : Ends)
- Callback(End);
+ for_each(Ends, Callback);
} else {
- for (auto *RI : ReachableRetVec)
- Callback(RI);
+ // If there's a mix of covered and non-covered exits, just put the untag
+ // on exits, so we avoid the redundancy of untagging twice.
+ for_each(ReachableRetVec, Callback);
// We may have inserted untag outside of the lifetime interval.
// Signal the caller to remove the lifetime end call for this alloca.
return false;
@@ -237,5 +236,10 @@ void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) {
Info.AI = NewAI;
}
+bool isLifetimeIntrinsic(Value *V) {
+ auto *II = dyn_cast<IntrinsicInst>(V);
+ return II && II->isLifetimeStartOrEnd();
+}
+
} // namespace memtag
} // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll b/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
index da2c2985acf9..9464e3447993 100644
--- a/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
+++ b/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
@@ -213,6 +213,10 @@ attributes #0 = { "sign-return-address"="all" }
; CHECK-DUMP-NOT: DW_CFA_remember_state
; CHECK-DUMP-NOT: DW_CFA_restore_state
+; CHECK-DUMP: CFA=WSP{{$}}
+; CHECK-DUMP: reg34=1
+; CHECK-DUMP-NOT: reg34=0
+
; baz_async
; CHECK-DUMP-LABEL: FDE
; CHECK-DUMP: Format: DWARF32
@@ -222,9 +226,24 @@ attributes #0 = { "sign-return-address"="all" }
; CHECK-DUMP: DW_CFA_restore_state:
; CHECK-DUMP: DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: CFA=WSP{{$}}
+;; First DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=1
+;; Second DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=0
+;; DW_CFA_restore_state:
+; CHECK-DUMP: reg34=1
+;; Third DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=0
+; CHECK-DUMP-NOT: reg34=
+
; baz_sync
; CHECK-DUMP-LABEL: FDE
; CHECK-DUMP: DW_CFA_AARCH64_negate_ra_state:
; CHECK-DUMP-NOT: DW_CFA_AARCH64_negate_ra_state
; CHECK-DUMP-NOT: DW_CFA_remember_state
; CHECK-DUMP-NOT: DW_CFA_restore_state
+
+; CHECK-DUMP: CFA=WSP{{$}}
+; CHECK-DUMP: reg34=1
+; CHECK-DUMP-NOT: reg34=0
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll b/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
index d8969fc9bebd..22d177ca3267 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
@@ -20,10 +20,10 @@ entry:
; CHECK-LABEL: define void @OneVarNoInit(
; CHECK-DAG: [[X:%.*]] = alloca { i32, [12 x i8] }, align 16
; CHECK-DAG: [[TX:%.*]] = call ptr @llvm.aarch64.tagp.{{.*}}(ptr [[X]], {{.*}}, i64 0)
-; CHECK-DAG: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TX]])
+; CHECK-DAG: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[X]])
; CHECK-DAG: call void @llvm.aarch64.settag(ptr [[TX]], i64 16)
; CHECK-DAG: call void @use(ptr nonnull [[TX]])
-; CHECK-DAG: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TX]])
+; CHECK-DAG: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[X]])
define void @OneVarInitConst() sanitize_memtag {
entry:
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll b/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
index 6eb72013fb0e..81349620fb77 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
@@ -1,20 +1,20 @@
; Test that storage for allocas with disjoint lifetimes is reused with stack
; tagging.
-; RUN: opt -S -aarch64-stack-tagging %s -o - | \
-; RUN: llc -no-stack-coloring=false -o - | \
+; RUN: opt -S -aarch64-stack-tagging -stack-tagging-use-stack-safety=0 %s -o - | \
+; RUN: llc --mattr=+mte -no-stack-coloring=false -o - | \
; RUN: FileCheck %s --check-prefix=COLOR
-; RUN: opt -S -aarch64-stack-tagging %s -o - | \
-; RUN: llc -no-stack-coloring=true -o - | \
+; RUN: opt -S -aarch64-stack-tagging %s -stack-tagging-use-stack-safety=0 -o - | \
+; RUN: llc --mattr=+mte -no-stack-coloring=true -o - | \
; RUN: FileCheck %s --check-prefix=NOCOLOR
target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-unknown-linux-android29"
+target triple = "aarch64"
-; COLOR: sub sp, sp, #192
-; NOCOLOR: sub sp, sp, #320
+; COLOR: sub sp, sp, #208
+; NOCOLOR: sub sp, sp, #336
-define i32 @myCall_w2(i32 %in) sanitize_hwaddress {
+define i32 @myCall_w2(i32 %in) sanitize_memtag {
entry:
%a = alloca [17 x ptr], align 8
%a2 = alloca [16 x ptr], align 8
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll b/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
index 06f8cd5241eb..aa9cccc58712 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
@@ -27,7 +27,7 @@ S1:
; CHECK: call void @llvm.aarch64.settag(ptr %w, i64 48)
; CHECK-NOT: settag{{.*}}%v
call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w) #1
-; CHECK: call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w.tag)
+; CHECK: call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w)
%b1 = icmp eq i32 %t1, 0
br i1 %b1, label %S2, label %S3
; CHECK-NOT: settag
diff --git a/llvm/test/CodeGen/PowerPC/pr74951.ll b/llvm/test/CodeGen/PowerPC/pr74951.ll
new file mode 100644
index 000000000000..a0d19fc09cc2
--- /dev/null
+++ b/llvm/test/CodeGen/PowerPC/pr74951.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -verify-machineinstrs -ppc-asm-full-reg-names -mtriple=powerpc64-ibm-aix-xcoff | FileCheck %s
+
+%struct.anon = type { i32 }
+
+@b = local_unnamed_addr global %struct.anon { i32 -1 }, align 4
+@g = local_unnamed_addr global [1 x i1] zeroinitializer, align 1
+
+define noundef signext i32 @main() {
+; CHECK-LABEL: main:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: ld r3, L..C0(r2) # @b
+; CHECK-NEXT: lwz r3, 0(r3)
+; CHECK-NEXT: extsw r4, r3
+; CHECK-NEXT: neg r4, r4
+; CHECK-NEXT: andi. r5, r3, 65535
+; CHECK-NEXT: rldicl r4, r4, 1, 63
+; CHECK-NEXT: bne cr0, L..BB0_4
+; CHECK-NEXT: # %bb.1: # %lor.rhs.i.i
+; CHECK-NEXT: xori r5, r4, 1
+; CHECK-NEXT: cmpw r3, r5
+; CHECK-NEXT: crnot 4*cr5+lt, eq
+; CHECK-NEXT: li r3, 1
+; CHECK-NEXT: bc 12, 4*cr5+lt, L..BB0_3
+; CHECK-NEXT: # %bb.2: # %lor.rhs.i.i
+; CHECK-NEXT: li r3, 0
+; CHECK-NEXT: L..BB0_3: # %lor.rhs.i.i
+; CHECK-NEXT: ld r5, L..C1(r2) # @g
+; CHECK-NEXT: stb r3, 0(r5)
+; CHECK-NEXT: L..BB0_4: # %g.exit
+; CHECK-NEXT: ld r5, L..C1(r2) # @g
+; CHECK-NEXT: li r3, 0
+; CHECK-NEXT: stb r4, 0(r5)
+; CHECK-NEXT: blr
+entry:
+ %0 = load i32, ptr @b, align 4
+ %conv4.i = sext i32 %0 to i64
+ %cmp.i = icmp slt i32 %0, 1
+ %conv.i = zext i1 %cmp.i to i32
+ %cmp1.i = icmp ne i32 %0, %conv.i
+ %conv3.i = trunc i32 %0 to i16
+ %tobool.not.i.i = icmp eq i16 %conv3.i, 0
+ br i1 %tobool.not.i.i, label %lor.rhs.i.i, label %g.exit
+
+lor.rhs.i.i: ; preds = %entry
+ store i1 %cmp1.i, ptr @g, align 1
+ br label %g.exit
+
+g.exit: ; preds = %lor.end.i.i
+ %4 = trunc i64 %conv4.i to i32
+ %cmp.i9.i = icmp sgt i32 %4, 0
+ store i1 %cmp.i9.i, ptr @g, align 1
+ ret i32 0
+}
diff --git a/llvm/test/CodeGen/RISCV/forced-atomics.ll b/llvm/test/CodeGen/RISCV/forced-atomics.ll
index 2b198afb47a9..659e0748dd53 100644
--- a/llvm/test/CodeGen/RISCV/forced-atomics.ll
+++ b/llvm/test/CodeGen/RISCV/forced-atomics.ll
@@ -3567,8 +3567,8 @@ define i64 @rmw64_umax_seq_cst(ptr %p) nounwind {
; RV32-NEXT: # in Loop: Header=BB51_2 Depth=1
; RV32-NEXT: neg a3, a0
; RV32-NEXT: and a3, a3, a1
-; RV32-NEXT: sw a1, 4(sp)
; RV32-NEXT: sw a4, 0(sp)
+; RV32-NEXT: sw a1, 4(sp)
; RV32-NEXT: mv a1, sp
; RV32-NEXT: li a4, 5
; RV32-NEXT: li a5, 5
diff --git a/llvm/test/CodeGen/RISCV/fpclamptosat.ll b/llvm/test/CodeGen/RISCV/fpclamptosat.ll
index 6bfacc3e9814..630d16e7c888 100644
--- a/llvm/test/CodeGen/RISCV/fpclamptosat.ll
+++ b/llvm/test/CodeGen/RISCV/fpclamptosat.ll
@@ -1324,8 +1324,8 @@ define i64 @ustest_f64i64(double %x) {
; RV32IF-NEXT: # %bb.4: # %entry
; RV32IF-NEXT: li a0, 1
; RV32IF-NEXT: .LBB20_5: # %entry
-; RV32IF-NEXT: lw a3, 8(sp)
-; RV32IF-NEXT: lw a4, 12(sp)
+; RV32IF-NEXT: lw a4, 8(sp)
+; RV32IF-NEXT: lw a3, 12(sp)
; RV32IF-NEXT: and a5, a2, a1
; RV32IF-NEXT: beqz a5, .LBB20_7
; RV32IF-NEXT: # %bb.6: # %entry
@@ -1334,17 +1334,18 @@ define i64 @ustest_f64i64(double %x) {
; RV32IF-NEXT: .LBB20_7:
; RV32IF-NEXT: snez a1, a0
; RV32IF-NEXT: .LBB20_8: # %entry
-; RV32IF-NEXT: and a4, a2, a4
+; RV32IF-NEXT: and a3, a2, a3
; RV32IF-NEXT: or a0, a0, a5
-; RV32IF-NEXT: and a2, a2, a3
+; RV32IF-NEXT: and a2, a2, a4
; RV32IF-NEXT: bnez a0, .LBB20_10
; RV32IF-NEXT: # %bb.9:
-; RV32IF-NEXT: or a0, a2, a4
-; RV32IF-NEXT: snez a1, a0
+; RV32IF-NEXT: snez a0, a3
+; RV32IF-NEXT: snez a1, a2
+; RV32IF-NEXT: or a1, a1, a0
; RV32IF-NEXT: .LBB20_10: # %entry
; RV32IF-NEXT: neg a1, a1
; RV32IF-NEXT: and a0, a1, a2
-; RV32IF-NEXT: and a1, a1, a4
+; RV32IF-NEXT: and a1, a1, a3
; RV32IF-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
; RV32IF-NEXT: addi sp, sp, 32
; RV32IF-NEXT: ret
@@ -1403,8 +1404,8 @@ define i64 @ustest_f64i64(double %x) {
; RV32IFD-NEXT: # %bb.4: # %entry
; RV32IFD-NEXT: li a0, 1
; RV32IFD-NEXT: .LBB20_5: # %entry
-; RV32IFD-NEXT: lw a3, 8(sp)
-; RV32IFD-NEXT: lw a4, 12(sp)
+; RV32IFD-NEXT: lw a4, 8(sp)
+; RV32IFD-NEXT: lw a3, 12(sp)
; RV32IFD-NEXT: and a5, a2, a1
; RV32IFD-NEXT: beqz a5, .LBB20_7
; RV32IFD-NEXT: # %bb.6: # %entry
@@ -1413,17 +1414,18 @@ define i64 @ustest_f64i64(double %x) {
; RV32IFD-NEXT: .LBB20_7:
; RV32IFD-NEXT: snez a1, a0
; RV32IFD-NEXT: .LBB20_8: # %entry
-; RV32IFD-NEXT: and a4, a2, a4
+; RV32IFD-NEXT: and a3, a2, a3
; RV32IFD-NEXT: or a0, a0, a5
-; RV32IFD-NEXT: and a2, a2, a3
+; RV32IFD-NEXT: and a2, a2, a4
; RV32IFD-NEXT: bnez a0, .LBB20_10
; RV32IFD-NEXT: # %bb.9:
-; RV32IFD-NEXT: or a0, a2, a4
-; RV32IFD-NEXT: snez a1, a0
+; RV32IFD-NEXT: snez a0, a3
+; RV32IFD-NEXT: snez a1, a2
+; RV32IFD-NEXT: or a1, a1, a0
; RV32IFD-NEXT: .LBB20_10: # %entry
; RV32IFD-NEXT: neg a1, a1
; RV32IFD-NEXT: and a0, a1, a2
-; RV32IFD-NEXT: and a1, a1, a4
+; RV32IFD-NEXT: and a1, a1, a3
; RV32IFD-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
; RV32IFD-NEXT: addi sp, sp, 32
; RV32IFD-NEXT: ret
@@ -1594,8 +1596,8 @@ define i64 @ustest_f32i64(float %x) {
; RV32-NEXT: # %bb.4: # %entry
; RV32-NEXT: li a0, 1
; RV32-NEXT: .LBB23_5: # %entry
-; RV32-NEXT: lw a3, 8(sp)
-; RV32-NEXT: lw a4, 12(sp)
+; RV32-NEXT: lw a4, 8(sp)
+; RV32-NEXT: lw a3, 12(sp)
; RV32-NEXT: and a5, a2, a1
; RV32-NEXT: beqz a5, .LBB23_7
; RV32-NEXT: # %bb.6: # %entry
@@ -1604,17 +1606,18 @@ define i64 @ustest_f32i64(float %x) {
; RV32-NEXT: .LBB23_7:
; RV32-NEXT: snez a1, a0
; RV32-NEXT: .LBB23_8: # %entry
-; RV32-NEXT: and a4, a2, a4
+; RV32-NEXT: and a3, a2, a3
; RV32-NEXT: or a0, a0, a5
-; RV32-NEXT: and a2, a2, a3
+; RV32-NEXT: and a2, a2, a4
; RV32-NEXT: bnez a0, .LBB23_10
; RV32-NEXT: # %bb.9:
-; RV32-NEXT: or a0, a2, a4
-; RV32-NEXT: snez a1, a0
+; RV32-NEXT: snez a0, a3
+; RV32-NEXT: snez a1, a2
+; RV32-NEXT: or a1, a1, a0
; RV32-NEXT: .LBB23_10: # %entry
; RV32-NEXT: neg a1, a1
; RV32-NEXT: and a0, a1, a2
-; RV32-NEXT: and a1, a1, a4
+; RV32-NEXT: and a1, a1, a3
; RV32-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
; RV32-NEXT: addi sp, sp, 32
; RV32-NEXT: ret
@@ -1847,8 +1850,8 @@ define i64 @ustest_f16i64(half %x) {
; RV32-NEXT: # %bb.4: # %entry
; RV32-NEXT: li a0, 1
; RV32-NEXT: .LBB26_5: # %entry
-; RV32-NEXT: lw a3, 8(sp)
-; RV32-NEXT: lw a4, 12(sp)
+; RV32-NEXT: lw a4, 8(sp)
+; RV32-NEXT: lw a3, 12(sp)
; RV32-NEXT: and a5, a2, a1
; RV32-NEXT: beqz a5, .LBB26_7
; RV32-NEXT: # %bb.6: # %entry
@@ -1857,17 +1860,18 @@ define i64 @ustest_f16i64(half %x) {
; RV32-NEXT: .LBB26_7:
; RV32-NEXT: snez a1, a0
; RV32-NEXT: .LBB26_8: # %entry
-; RV32-NEXT: and a4, a2, a4
+; RV32-NEXT: and a3, a2, a3
; RV32-NEXT: or a0, a0, a5
-; RV32-NEXT: and a2, a2, a3
+; RV32-NEXT: and a2, a2, a4
; RV32-NEXT: bnez a0, .LBB26_10
; RV32-NEXT: # %bb.9:
-; RV32-NEXT: or a0, a2, a4
-; RV32-NEXT: snez a1, a0
+; RV32-NEXT: snez a0, a3
+; RV32-NEXT: snez a1, a2
+; RV32-NEXT: or a1, a1, a0
; RV32-NEXT: .LBB26_10: # %entry
; RV32-NEXT: neg a1, a1
; RV32-NEXT: and a0, a1, a2
-; RV32-NEXT: and a1, a1, a4
+; RV32-NEXT: and a1, a1, a3
; RV32-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
; RV32-NEXT: addi sp, sp, 32
; RV32-NEXT: ret
diff --git a/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test b/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
index 6c049af43efe..2cd281c8d0af 100644
--- a/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
+++ b/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
@@ -12,15 +12,15 @@
; FRAMES-NEXT: DW_CFA_nop:
; FRAMES: 00000014 00000010 00000000 FDE cie=00000000 pc=00000000...00000022
-; FRAMES: DW_CFA_advance_loc: 3
+; FRAMES: DW_CFA_advance_loc: 3 to 0x3
; FRAMES-NEXT: DW_CFA_def_cfa_offset: +12
; FRAMES-NEXT: DW_CFA_nop:
; FRAMES: 00000028 00000014 00000000 FDE cie=00000000 pc=00000030...00000080
-; FRAMES: DW_CFA_advance_loc: 1
+; FRAMES: DW_CFA_advance_loc: 1 to 0x31
; FRAMES-NEXT: DW_CFA_def_cfa_offset: +8
; FRAMES-NEXT: DW_CFA_offset: {{reg5|EBP}} -8
-; FRAMES-NEXT: DW_CFA_advance_loc: 2
+; FRAMES-NEXT: DW_CFA_advance_loc: 2 to 0x33
; FRAMES-NEXT: DW_CFA_def_cfa_register: {{reg5|EBP}}
; FRAMES-NOT: CIE
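
(For reference: the new "to 0x..." suffix printed above is initial_location + operand * code_alignment_factor; for the first FDE, which starts at pc 0x0, a code alignment factor of 1 gives 0x0 + 3*1 = 0x3.)
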
diff --git a/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll b/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
index eb522a0f3f31..aeb1b0e8ebe7 100644
--- a/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
@@ -2,9 +2,9 @@
; RUN: opt < %s -passes=asan -S -mtriple=aarch64_be-linux-gnu | FileCheck --check-prefix=CHECK-AARCH64BE %s
; REQUIRES: aarch64-registered-target
-define i32 @read_4_bytes(i32* %a) sanitize_address {
+define i32 @read_4_bytes(ptr %a) sanitize_address {
entry:
- %tmp1 = load i32, i32* %a, align 4
+ %tmp1 = load i32, ptr %a, align 4
ret i32 %tmp1
}
diff --git a/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll b/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
index adfe21135e7a..1d5bfb09ead9 100644
--- a/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
@@ -16,7 +16,7 @@ target datalayout = "P1"
define i1 @b(i64 %c) addrspace(1) {
%cast = inttoptr i64 %c to ptr addrspace(42)
- %cmp = icmp ugt ptr addrspace(42) %cast, getelementptr inbounds ([1 x i32], ptr addrspace(42) @a, i64 0, i64 0)
+ %cmp = icmp ugt ptr addrspace(42) %cast, @a
ret i1 %cmp
}
diff --git a/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll b/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
index 5dfec433f4ec..870e74ccfdac 100644
--- a/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
@@ -7,17 +7,17 @@
target triple = "x86_64-unknown-linux-gnu"
-declare void @llvm.instrprof.increment.step(i8*, i64, i32, i32, i64)
+declare void @llvm.instrprof.increment.step(ptr, i64, i32, i32, i64)
-declare void @llvm.instrprof.value.profile(i8*, i64, i64, i32, i32)
+declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)
; CHECK: @__profd_foo = private global
@__profn_foo = private constant [3 x i8] c"foo"
-define i32 @foo(i32 ()* ) {
- %2 = ptrtoint i32 ()* %0 to i64
- call void @llvm.instrprof.value.profile(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 0, i64 %2, i32 0, i32 0)
- call void @llvm.instrprof.increment.step(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 0, i32 1, i32 0, i64 0)
+define i32 @foo(ptr ) {
+ %2 = ptrtoint ptr %0 to i64
+ call void @llvm.instrprof.value.profile(ptr @__profn_foo, i64 0, i64 %2, i32 0, i32 0)
+ call void @llvm.instrprof.increment.step(ptr @__profn_foo, i64 0, i32 1, i32 0, i64 0)
%3 = tail call i32 %0()
ret i32 %3
}
diff --git a/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll b/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
index ab9b664a2cff..d40cc2ac02c1 100644
--- a/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
@@ -6,11 +6,11 @@ target triple = "aarch64-unknown-linux-gnu"
; CHECK: @__profc_foo = private global [9 x i8] c"\FF\FF\FF\FF\FF\FF\FF\FF\FF", section "__llvm_prf_cnts", comdat, align 8
define void @_Z3foov() {
- call void @llvm.instrprof.timestamp(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 9, i32 0)
+ call void @llvm.instrprof.timestamp(ptr @__profn_foo, i64 12345678, i32 9, i32 0)
; CHECK: call void @__llvm_profile_set_timestamp(ptr @__profc_foo)
- call void @llvm.instrprof.cover(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 9, i32 8)
+ call void @llvm.instrprof.cover(ptr @__profn_foo, i64 12345678, i32 9, i32 8)
ret void
}
-declare void @llvm.instrprof.timestamp(i8*, i64, i32, i32)
-declare void @llvm.instrprof.cover(i8*, i64, i32, i32)
+declare void @llvm.instrprof.timestamp(ptr, i64, i32, i32)
+declare void @llvm.instrprof.cover(ptr, i64, i32, i32)
diff --git a/llvm/test/Instrumentation/InstrProfiling/timestamp.ll b/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
index aa2393695d6b..c08ba4485fc5 100644
--- a/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
@@ -6,11 +6,11 @@ target triple = "aarch64-unknown-linux-gnu"
; CHECK: @__profc_foo = private global [2 x i64] zeroinitializer, section "__llvm_prf_cnts", comdat, align 8
define void @_Z3foov() {
- call void @llvm.instrprof.timestamp(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 2, i32 0)
+ call void @llvm.instrprof.timestamp(ptr @__profn_foo, i64 12345678, i32 2, i32 0)
; CHECK: call void @__llvm_profile_set_timestamp(ptr @__profc_foo)
- call void @llvm.instrprof.increment(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 2, i32 1)
+ call void @llvm.instrprof.increment(ptr @__profn_foo, i64 12345678, i32 2, i32 1)
ret void
}
-declare void @llvm.instrprof.timestamp(i8*, i64, i32, i32)
-declare void @llvm.instrprof.increment(i8*, i64, i32, i32)
+declare void @llvm.instrprof.timestamp(ptr, i64, i32, i32)
+declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
diff --git a/llvm/test/Object/Inputs/small.ll b/llvm/test/Object/Inputs/small.ll
index ef68a8c324a3..677f20ade4c5 100644
--- a/llvm/test/Object/Inputs/small.ll
+++ b/llvm/test/Object/Inputs/small.ll
@@ -4,15 +4,15 @@ target triple = "i386-pc-windows"
define i32 @main() nounwind {
entry:
- %call = tail call i32 @puts(i8* getelementptr inbounds ([13 x i8], [13 x i8]* @.str, i32 0, i32 0)) nounwind
- tail call void bitcast (void (...)* @SomeOtherFunction to void ()*)() nounwind
+ %call = tail call i32 @puts(ptr @.str) nounwind
+ tail call void @SomeOtherFunction() nounwind
ret i32 0
}
-declare i32 @puts(i8* nocapture) nounwind
+declare i32 @puts(ptr nocapture) nounwind
declare void @SomeOtherFunction(...)
@var = global i32 0
-@llvm.used = appending global [1 x i8*] [i8* bitcast (i32* @var to i8*)], section "llvm.metadata"
-@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* null, i8* null }]
+@llvm.used = appending global [1 x ptr] [ptr @var], section "llvm.metadata"
+@llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr null, ptr null }]
diff --git a/llvm/test/Object/Inputs/trivial.ll b/llvm/test/Object/Inputs/trivial.ll
index 82eabc6389fb..1a6a76298b23 100644
--- a/llvm/test/Object/Inputs/trivial.ll
+++ b/llvm/test/Object/Inputs/trivial.ll
@@ -5,15 +5,15 @@
define i32 @main() nounwind {
entry:
- %call = tail call i32 @puts(i8* getelementptr inbounds ([13 x i8], [13 x i8]* @.str, i32 0, i32 0)) nounwind
- tail call void bitcast (void (...)* @SomeOtherFunction to void ()*)() nounwind
+ %call = tail call i32 @puts(ptr @.str) nounwind
+ tail call void @SomeOtherFunction() nounwind
ret i32 0
}
-declare i32 @puts(i8* nocapture) nounwind
+declare i32 @puts(ptr nocapture) nounwind
declare void @SomeOtherFunction(...)
@var = global i32 0
-@llvm.used = appending global [1 x i8*] [i8* bitcast (i32* @var to i8*)], section "llvm.metadata"
-@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* null, i8* null }]
+@llvm.used = appending global [1 x ptr] [ptr @var], section "llvm.metadata"
+@llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr null, ptr null }]
diff --git a/llvm/test/Object/X86/irsymtab-bad-alias.ll b/llvm/test/Object/X86/irsymtab-bad-alias.ll
index c54436d59219..7f204d1dd157 100644
--- a/llvm/test/Object/X86/irsymtab-bad-alias.ll
+++ b/llvm/test/Object/X86/irsymtab-bad-alias.ll
@@ -11,5 +11,5 @@ target triple = "x86_64-unknown-linux-gnu"
@g1 = global i32 1
@g2 = global i32 2
-@a = alias i32, inttoptr(i32 sub (i32 ptrtoint (i32* @g1 to i32),
- i32 ptrtoint (i32* @g2 to i32)) to i32*)
+@a = alias i32, inttoptr(i32 sub (i32 ptrtoint (ptr @g1 to i32),
+ i32 ptrtoint (ptr @g2 to i32)) to ptr)
diff --git a/llvm/test/Object/X86/nm-ir.ll b/llvm/test/Object/X86/nm-ir.ll
index e57c6d9a11c6..0324efb2948d 100644
--- a/llvm/test/Object/X86/nm-ir.ll
+++ b/llvm/test/Object/X86/nm-ir.ll
@@ -29,15 +29,15 @@ module asm ".long undef_asm_sym"
@g3 = common global i32 0
@g4 = private global i32 42
-@a1 = alias i32, i32* @g1
-@a2 = internal alias i32, i32* @g1
+@a1 = alias i32, ptr @g1
+@a2 = internal alias i32, ptr @g1
-define void ()* @f1() {
+define ptr @f1() {
call void @f5()
- ret void ()* null
+ ret ptr null
}
-@ifunc_f1 = ifunc void (), void ()* ()* @f1
+@ifunc_f1 = ifunc void (), ptr @f1
define internal void @f2() {
ret void
diff --git a/llvm/test/Object/dllimport-globalref.ll b/llvm/test/Object/dllimport-globalref.ll
index dd518bc2266c..0a95be20a9d1 100644
--- a/llvm/test/Object/dllimport-globalref.ll
+++ b/llvm/test/Object/dllimport-globalref.ll
@@ -11,4 +11,4 @@ target triple = "x86_64-pc-windows-msvc"
; CHECK: U f
declare dllimport void @f()
-@fp = constant void ()* @f
+@fp = constant ptr @f
diff --git a/llvm/test/Object/dllimport.ll b/llvm/test/Object/dllimport.ll
index afdb4562cc9f..52f583fa2487 100644
--- a/llvm/test/Object/dllimport.ll
+++ b/llvm/test/Object/dllimport.ll
@@ -12,6 +12,6 @@ declare dllimport void @f()
define void @g() {
call void @f()
- store i32 42, i32* @v
+ store i32 42, ptr @v
ret void
}
diff --git a/llvm/test/Object/mangle-ir.ll b/llvm/test/Object/mangle-ir.ll
index bd7c3d93b7c9..76442f070385 100644
--- a/llvm/test/Object/mangle-ir.ll
+++ b/llvm/test/Object/mangle-ir.ll
@@ -7,8 +7,8 @@ target datalayout = "m:o"
; CHECK-NOT: memcpy
define void @f() {
- tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* null, i8* null, i64 0, i1 false)
+ tail call void @llvm.memcpy.p0.p0.i64(ptr null, ptr null, i64 0, i1 false)
ret void
}
-declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture, i8* nocapture readonly, i64, i1)
+declare void @llvm.memcpy.p0.p0.i64(ptr nocapture, ptr nocapture readonly, i64, i1)
diff --git a/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll b/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
index d2518f46cc27..c506c9687ec2 100644
--- a/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
+++ b/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
@@ -5,11 +5,11 @@
target triple = "x86_64-apple-macosx10.15.0"
-@llvm.used = appending global [1 x i8*] [i8* bitcast (i16* @__swift_reflection_version to i8*)], section "llvm.metadata", align 8
+@llvm.used = appending global [1 x ptr] [ptr @__swift_reflection_version], section "llvm.metadata", align 8
@__swift_reflection_version = linkonce_odr hidden constant i16 3
-define i32 @main(i32 %0, i8** %1) #0 {
- %3 = bitcast i8** %1 to i8*
+define i32 @main(i32 %0, ptr %1) #0 {
+ %3 = bitcast ptr %1 to ptr
ret i32 0
}
@@ -25,7 +25,7 @@ attributes #0 = { "frame-pointer"="all" "target-cpu"="penryn" "target-features"=
!1 = !{!"-lswiftSwiftOnoneSupport"}
!2 = !{!"-lswiftCore"}
!3 = !{!"-lobjc"}
-!4 = !{[1 x i8*]* @llvm.used, null, null, i1 false, i1 true}
+!4 = !{ptr @llvm.used, null, null, i1 false, i1 true}
!5 = !{i32 2, !"SDK Version", [2 x i32] [i32 10, i32 15]}
!6 = !{i32 1, !"Objective-C Version", i32 2}
!7 = !{i32 1, !"Objective-C Image Info Version", i32 0}
diff --git a/llvm/test/tools/llvm-readobj/ELF/unwind.test b/llvm/test/tools/llvm-readobj/ELF/unwind.test
index 2deb1a587d24..2e51ec2a61a6 100644
--- a/llvm/test/tools/llvm-readobj/ELF/unwind.test
+++ b/llvm/test/tools/llvm-readobj/ELF/unwind.test
@@ -96,9 +96,9 @@
# CHECK: Program:
# CHECK-NEXT: DW_CFA_def_cfa_offset: +16
-# CHECK-NEXT: DW_CFA_advance_loc: 6
+# CHECK-NEXT: DW_CFA_advance_loc: 6 to 0x4004a6
# CHECK-NEXT: DW_CFA_def_cfa_offset: +24
-# CHECK-NEXT: DW_CFA_advance_loc: 10
+# CHECK-NEXT: DW_CFA_advance_loc: 10 to 0x4004b0
# CHECK-NEXT: DW_CFA_def_cfa_expression: DW_OP_breg7 +8, DW_OP_breg16 +0, DW_OP_lit15, DW_OP_and, DW_OP_lit11, DW_OP_ge, DW_OP_lit3, DW_OP_shl, DW_OP_plus
# CHECK-NEXT: DW_CFA_nop:
# CHECK-NEXT: DW_CFA_nop:
@@ -110,12 +110,12 @@
# CHECK-NEXT: address_range: 0x10 (end : 0x4005c6)
# CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 1
+# CHECK-NEXT: DW_CFA_advance_loc: 1 to 0x4005b7
# CHECK-NEXT: DW_CFA_def_cfa_offset: +16
# CHECK-NEXT: DW_CFA_offset: reg6 -16
-# CHECK-NEXT: DW_CFA_advance_loc: 3
+# CHECK-NEXT: DW_CFA_advance_loc: 3 to 0x4005ba
# CHECK-NEXT: DW_CFA_def_cfa_register: reg6
-# CHECK-NEXT: DW_CFA_advance_loc: 11
+# CHECK-NEXT: DW_CFA_advance_loc: 11 to 0x4005c5
# CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
# CHECK-NEXT: DW_CFA_nop:
# CHECK-NEXT: DW_CFA_nop:
@@ -126,15 +126,15 @@
# CHECK-NEXT: address_range: 0xc7f (end : 0x40124f)
# CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x4005d5
# CHECK-NEXT: DW_CFA_def_cfa: reg10 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 9
+# CHECK-NEXT: DW_CFA_advance_loc: 9 to 0x4005de
# CHECK-NEXT: DW_CFA_expression: reg6 DW_OP_breg6 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x4005e3
# CHECK-NEXT: DW_CFA_def_cfa_expression: DW_OP_breg6 -8, DW_OP_deref
-# CHECK-NEXT: DW_CFA_advance_loc2: 3174
+# CHECK-NEXT: DW_CFA_advance_loc2: 3174 to 0x401249
# CHECK-NEXT: DW_CFA_def_cfa: reg10 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x40124e
# CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
# CHECK-NEXT: DW_CFA_nop:
# CHECK-NEXT: DW_CFA_nop:
@@ -146,21 +146,21 @@
# CHECK-NEXT: address_range: 0x66 (end : 0x4012b6)
# CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 1
+# CHECK-NEXT: DW_CFA_advance_loc: 1 to 0x401251
# CHECK-NEXT: DW_CFA_def_cfa_offset: +16
# CHECK-NEXT: DW_CFA_offset: reg6 -16
-# CHECK-NEXT: DW_CFA_advance_loc: 3
+# CHECK-NEXT: DW_CFA_advance_loc: 3 to 0x401254
# CHECK-NEXT: DW_CFA_def_cfa_register: reg6
-# CHECK-NEXT: DW_CFA_advance_loc: 2
+# CHECK-NEXT: DW_CFA_advance_loc: 2 to 0x401256
# CHECK-NEXT: DW_CFA_offset: reg15 -24
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x40125b
# CHECK-NEXT: DW_CFA_offset: reg14 -32
-# CHECK-NEXT: DW_CFA_advance_loc: 7
+# CHECK-NEXT: DW_CFA_advance_loc: 7 to 0x401262
# CHECK-NEXT: DW_CFA_offset: reg13 -40
# CHECK-NEXT: DW_CFA_offset: reg12 -48
-# CHECK-NEXT: DW_CFA_advance_loc: 8
+# CHECK-NEXT: DW_CFA_advance_loc: 8 to 0x40126a
# CHECK-NEXT: DW_CFA_offset: reg3 -56
-# CHECK-NEXT: DW_CFA_advance_loc1: 75
+# CHECK-NEXT: DW_CFA_advance_loc1: 75 to 0x4012b5
# CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
# CHECK-NEXT: DW_CFA_nop:
# CHECK-NEXT: DW_CFA_nop:
diff --git a/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h b/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
index 687d97abd023..2e89463e68d5 100644
--- a/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
+++ b/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
@@ -196,6 +196,7 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
reportError(std::move(E), ObjF.getFileName());
for (const dwarf::FrameEntry &Entry : EHFrame) {
+ std::optional<uint64_t> InitialLocation;
if (const dwarf::CIE *CIE = dyn_cast<dwarf::CIE>(&Entry)) {
W.startLine() << format("[0x%" PRIx64 "] CIE length=%" PRIu64 "\n",
Address + CIE->getOffset(), CIE->getLength());
@@ -214,8 +215,9 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
Address + FDE->getLinkedCIE()->getOffset());
W.indent();
+ InitialLocation = FDE->getInitialLocation();
W.startLine() << format("initial_location: 0x%" PRIx64 "\n",
- FDE->getInitialLocation());
+ *InitialLocation);
W.startLine() << format(
"address_range: 0x%" PRIx64 " (end : 0x%" PRIx64 ")\n",
FDE->getAddressRange(),
@@ -227,7 +229,8 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
W.indent();
auto DumpOpts = DIDumpOptions();
DumpOpts.IsEH = true;
- Entry.cfis().dump(W.getOStream(), DumpOpts, W.getIndentLevel());
+ Entry.cfis().dump(W.getOStream(), DumpOpts, W.getIndentLevel(),
+ InitialLocation);
W.unindent();
W.unindent();
W.getOStream() << "\n";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
new file mode 100644
index 000000000000..a69751e072b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
@@ -0,0 +1,26 @@
+//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all external
+// interface implementations to the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalg
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
new file mode 100644
index 000000000000..c57501ea86b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index fc2acc70381e..9d9b5892e1a5 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
I32EnumAttrCase<"Sum", 1, "sum">,
I32EnumAttrCase<"Max", 2, "max">,
I32EnumAttrCase<"Min", 3, "min">,
+ I32EnumAttrCase<"Product", 4, "product">,
+ // Arithmetic mean.
+ I32EnumAttrCase<"Average", 5, "average">,
+ I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
+ I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
+ I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
I32EnumAttrCase<"Generic", 100, "generic">
]> {
let genSpecializedAttr = 0;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index b9cd15e20626..8e1e47546358 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
attr-dict `:` type($input) `->` type($result)
}];
let hasCanonicalizer = 1;
+ let builders = [
+ OpBuilder<(ins "Value":$input, "StringRef":$mesh,
+ "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+ ];
}
def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ffc9b6fb18be..ab4df2ab028d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -22,6 +22,24 @@ class SymbolTableCollection;
namespace mesh {
+// Retrieve the mesh axes corresponding to each operation loop iterator based
+// on the provided shardings for the op's operands and results.
+// Assumes that the indexingMaps are projected permutations.
+ShardingArray getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps);
+
+bool isAtLeastOneReductionIteratorSharded(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
+// Get the set of mesh axes that correspond to reduction loop iterators.
+SmallVector<MeshAxis> getReductionMeshAxes(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
// Inserts a clone of the operation that has all ranked tensor
// arguments/results sharded.
void spmdizeTriviallyShardableOperation(
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index aeab28961a4e..be82e2af399d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -13,6 +13,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType>
createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
ImplicitLocOpBuilder &builder);
+// Get process linear index along the given mesh axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder);
+
} // namespace mesh
} // namespace mlir
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 50f6f6de5c28..6c8a170a03c7 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -216,6 +216,14 @@ public:
{TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
}
+ // Declare the same interface for multiple types.
+ // Example:
+ // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
+ template <typename InterfaceT, typename... ConcreteT>
+ void declarePromisedInterfaces() {
+ (declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
+ }
+
/// Checks if the given interface, which is attempting to be used, is a
/// promised interface of this dialect that has yet to be implemented. If so,
/// emits a fatal error. `interfaceName` is an optional string that contains a
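
A minimal usage sketch of the variadic helper added above (dialect, op, and interface names here are hypothetical, not from the patch): a dialect can promise a single interface for several of its ops in its initialize() hook, with the actual external models registered elsewhere.

    // Hypothetical dialect/ops; shows declarePromisedInterfaces only.
    void MyDialect::initialize() {
      addOperations<MyOpA, MyOpB>();
      // Expands to declarePromisedInterface<MyOpA, SomeOpInterface>() and
      // declarePromisedInterface<MyOpB, SomeOpInterface>().
      declarePromisedInterfaces<SomeOpInterface, MyOpA, MyOpB>();
    }
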
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 838bd03622a6..21775e11e071 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -43,10 +43,7 @@
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/MPI/IR/MPI.h"
@@ -157,10 +154,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
cf::registerBufferizableOpInterfaceExternalModels(registry);
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
- linalg::registerBufferizableOpInterfaceExternalModels(registry);
- linalg::registerSubsetOpInterfaceExternalModels(registry);
- linalg::registerTilingInterfaceExternalModels(registry);
- linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
memref::registerAllocationOpInterfaceExternalModels(registry);
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
memref::registerValueBoundsOpInterfaceExternalModels(registry);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 01fde101ef3c..83198c9b0db5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1085,6 +1085,39 @@ struct ConversionConfig {
/// IR during an analysis conversion and only pre-existing operations are
/// added to the set.
DenseSet<Operation *> *legalizableOps = nullptr;
+
+ /// An optional listener that is notified about all IR modifications in case
+ /// dialect conversion succeeds. If the dialect conversion fails and no IR
+ /// modifications are visible (i.e., they were all rolled back), no
+ /// notifications are sent.
+ ///
+ /// Note: Notifications are sent in a delayed fashion, when the dialect
+ /// conversion is guaranteed to succeed. At that point, some IR modifications
+ /// may already have been materialized. Consequently, operations/blocks that
+ /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+ /// guaranteed to be valid pointers and accessing op names is allowed. But
+ /// there are no guarantees about the state of ops/blocks at the time that a
+ /// callback is triggered.)
+ ///
+ /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+ /// created and inserted, and later moved to another block. (Moving ops also
+ /// triggers "notifyOperationInserted".)
+ ///
+ /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+ /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+ ///
+ /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+ /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+ /// immediately performed by the dialect conversion (and rolled back upon
+ /// failure).
+ ///
+ /// Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
+ /// callback, the previous region/block is provided to the callback, but not
+ /// the iterator pointing to the exact location within the region/block. That
+ /// is because these notifications are sent with a delay (after the IR has
+ /// already been modified) and iterators into past IR state cannot be
+ /// represented at the moment.
+ RewriterBase::Listener *listener = nullptr;
};
//===----------------------------------------------------------------------===//
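A sketch of how the new listener hook might be wired up, assuming the applyPartialConversion overload that accepts a ConversionConfig; DumpListener and the surrounding pass variables (root, target, patterns) are illustrative only, not part of this patch.

struct DumpListener : public RewriterBase::Listener {
  void notifyOperationInserted(Operation *op,
                               OpBuilder::InsertPoint previous) override {
    llvm::outs() << "inserted: " << op->getName() << "\n";
  }
  void notifyOperationErased(Operation *op) override {
    llvm::outs() << "erased: " << op->getName() << "\n";
  }
};

// Inside a conversion pass, after `target` and `patterns` have been set up:
DumpListener listener;
ConversionConfig config;
config.listener = &listener;
// Callbacks fire only once the conversion is guaranteed to succeed; on
// failure everything is rolled back and no notifications are sent.
if (failed(applyPartialConversion(root, target, patterns, config)))
  signalPassFailure();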
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index f0ac1899bb02..c187563b8f0c 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRParser
+ MLIRShardingInterface
MLIRSideEffectInterfaces
MLIRSparseTensorDialect
MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index 5069d43e7db9..027058d4de63 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
>(namedStructuredOpRegionBuilders);
addInterfaces<LinalgInlinerInterface>();
+
+ declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
+ declarePromisedInterfaces<mesh::ShardingInterface,
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >();
}
LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
new file mode 100644
index 000000000000..281d9f220448
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -0,0 +1,24 @@
+//===- AllInterfaces.cpp - ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+
+#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalg::registerAllDialectInterfaceImplementations(
+ DialectRegistry &registry) {
+ registerBufferizableOpInterfaceExternalModels(registry);
+ registerMeshShardingInterfaceExternalModels(registry);
+ registerSubsetOpInterfaceExternalModels(registry);
+ registerTilingInterfaceExternalModels(registry);
+ registerValueBoundsOpInterfaceExternalModels(registry);
+}
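For a tool that does not call registerAllDialects, the intended use of the new aggregate registration function would look roughly like this (a sketch; the surrounding setup is assumed):

DialectRegistry registry;
registry.insert<linalg::LinalgDialect>();
// One call attaches all Linalg external interface models, including the new
// mesh sharding implementation.
linalg::registerAllDialectInterfaceImplementations(registry);
MLIRContext context(registry);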
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4f47e3b87184..513c54de5d7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRLinalgTransforms
+ AllInterfaces.cpp
BubbleUpExtractSlice.cpp
BufferizableOpInterfaceImpl.cpp
Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
InlineScalarOperands.cpp
Interchange.cpp
Loops.cpp
+ MeshShardingInterfaceImpl.cpp
NamedOpConversions.cpp
Padding.cpp
Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRIR
MLIRMemRefDialect
MLIRMemRefTransforms
+ MLIRMeshDialect
+ MLIRMeshTransforms
MLIRLinalgDialect
MLIRLinalgUtils
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
MLIRPass
+ MLIRShardingInterface
MLIRSubsetOpInterface
MLIRSparseTensorDialect
MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
new file mode 100644
index 000000000000..146e88076566
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
@@ -0,0 +1,353 @@
+//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+// Returns the corresponding mesh reduction kind for the given arith op.
+static ReductionKind getReductionKind(Operation *op) {
+ return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+ // Floating-point operations.
+ .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+ .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+ // TODO: handle maxnumf and minnumf.
+ .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+ // Integer operations.
+ .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+ .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+ .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+ .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+ // TODO: handle signless, signed and unsigned types properly.
+ // It is assumed that the element type of the collective operands and
+ // result drives the meaning of the reduction kind, i.e. whether it is
+ // signed or unsigned.
+ // The reduction op inside the linalg op may have a different result type
+ // than the element type of the linalg op's result.
+ // Also, signed and unsigned Arith dialect ops may accept signed, unsigned
+ // or signless operands.
+ // Maybe expand the reduction kinds.
+ .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+ .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+ .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+ .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getCombinerOp(LinalgOp op) {
+ SmallVector<Operation *> combinerOps;
+ Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
+ if (!reducedValue || combinerOps.size() != 1) {
+ return std::nullopt;
+ }
+
+ return combinerOps[0];
+}
+
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+ std::optional<Operation *> reductionOp = getCombinerOp(op);
+ if (!reductionOp) {
+ return ReductionKind::Generic;
+ }
+ [[maybe_unused]] Type resultElementType =
+ llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
+ // TODO: handle the case when the result type of the reduction op does not
+ // match the element type of the result tensor.
+ // Would that make sense at all?
+ assert(resultElementType == reductionOp.value()->getResult(0).getType());
+ return getReductionKind(reductionOp.value());
+}
+
+static MeshOp getMesh(Operation *op,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ SymbolTableCollection &symbolTable) {
+ for (MeshShardingAttr sharding : operandShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+ for (MeshShardingAttr sharding : resultShardings) {
+ if (sharding) {
+ return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+ }
+ }
+
+ assert(false && "expected at least one operand or result sharding");
+ return nullptr;
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// The initial value must be used only once so that it is not included in the
+// reduction multiple times: in each process group, only the leading process
+// (the one with linear index 0) uses the original operand, while all other
+// processes use a tensor holding the reduction operation's neutral element.
+static Value createDestinationPassingStyleInitOperand(
+ LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+ MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+ Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+ meshOp.getSymName(), reductionMeshAxes, builder);
+ Value zero = builder.create<arith::ConstantIndexOp>(0);
+ Value isLeadProcess = builder.create<arith::CmpIOp>(
+ builder.getI1Type(), arith::CmpIPredicate::eq,
+ processLinearIndexInReductionGroup, zero);
+ scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+ isLeadProcess, true, true);
+ // Then block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ builder.create<scf::YieldOp>(spmdizedOperand);
+ }
+
+ // Else block.
+ {
+ OpBuilder::InsertionGuard insertionGuard(builder);
+ builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
+ SmallVector<OpFoldResult> shape =
+ tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
+ PartialReductionOpInterface partialReductionIface =
+ llvm::cast<PartialReductionOpInterface>(op.getOperation());
+ FailureOr<Operation *> reductionNeutralTensorOp =
+ partialReductionIface.generateInitialTensorForPartialReduction(
+ builder, builder.getLoc(), shape, {});
+ assert(succeeded(reductionNeutralTensorOp));
+ builder.create<scf::YieldOp>(
+ reductionNeutralTensorOp.value()->getResult(0));
+ }
+ return ifOp.getResult(0);
+}
+
+// Create the DPS init operands for the spmdized Linalg op.
+// Return all the new spmdized operands.
+static SmallVector<Value> createDestinationPassingStyleInitOperands(
+ LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ // TODO: add support for multiple destination passing style initial value
+ // operands.
+ // PartialReductionOpInterface::generateInitialTensorForPartialReduction
+ // needs to also support multiple DPS initial operands.
+ SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
+ auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
+ Value spmdizedInitOperand =
+ spmdizationMap.lookup(op->getOperands()[operandIdx]);
+ newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
+ op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
+ return newOperands;
+}
+
+static void createAllReduceForResultWithoutPartialSharding(
+ Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes,
+ MeshShardingAttr resultSharding, ReductionKind reductionKind,
+ IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) {
+ SmallVector<MeshAxis> allReduceMeshAxes;
+ llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes),
+ [&resultSharding](MeshAxis axis) {
+ return !llvm::is_contained(resultSharding.getPartialAxes(),
+ axis);
+ });
+ if (allReduceMeshAxes.empty()) {
+ return;
+ }
+
+ Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult);
+ Value reducedValue = builder.create<mesh::AllReduceOp>(
+ spmdizedLinalgOpResult, resultSharding.getMesh().getValue(),
+ allReduceMeshAxes, reductionKind);
+ spmdizationMap.map(unshardedLinalgOpResult, reducedValue);
+}
+
+static void createAllReduceForResultsWithoutPartialShardings(
+ LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes,
+ ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap,
+ ImplicitLocOpBuilder &builder) {
+ ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp);
+ for (auto [unshardedLinalgOpResult, resultSharding] :
+ llvm::zip_equal(unshardedOp->getResults(), resultShardings)) {
+ createAllReduceForResultWithoutPartialSharding(
+ unshardedLinalgOpResult, opReductionMeshAxes, resultSharding,
+ reductionKind, spmdizationMap, builder);
+ }
+}
+
+static void spmdizeLinalgOpWithShardedReduction(
+ LinalgOp op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators,
+ IRMapping &spmdizationMap, SymbolTableCollection &symbolTable,
+ ImplicitLocOpBuilder &builder) {
+ MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable);
+ SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators);
+ SmallVector<Value> spmdizedLinalgOpOperands =
+ createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands,
+ reductionMeshAxes,
+ spmdizationMap, builder);
+ // We must not change the operand mappings of the original spmdizationMap as
+ // they are the mappings for the whole spmdization blob and may be used by
+ // others.
+ IRMapping internalSpmdizationMap;
+ for (auto [unshardedOperand, spmdizedOperand] :
+ llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) {
+ internalSpmdizationMap.map(unshardedOperand, spmdizedOperand);
+ }
+ spmdizeTriviallyShardableOperation(
+ *op, spmdizedLinalgOpOperands, operandShardings, resultShardings,
+ internalSpmdizationMap, symbolTable, builder);
+ for (Value result : op->getResults()) {
+ spmdizationMap.map(result, internalSpmdizationMap.lookup(result));
+ }
+
+ // Handle partial shardings.
+ createAllReduceForResultsWithoutPartialShardings(
+ op, reductionMeshAxes, resultShardings, spmdizationMap, builder);
+}
+
+namespace {
+
+// ShardingInterface external model for ops that implement
+// LinalgStructuredInterface. Only ops whose indexing maps are projected
+// permutations are supported.
+template <typename Op>
+struct StructuredOpShardingInterface
+ : public mesh::ShardingInterface::ExternalModel<
+ StructuredOpShardingInterface<Op>, Op> {
+ SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const {
+ return llvm::cast<LinalgOp>(op).getIteratorTypesArray();
+ }
+
+ SmallVector<AffineMap> getIndexingMaps(Operation *op) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+ SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray();
+
+ // Results must have the same indexing as destination passing style initial
+ // operands.
+ for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) {
+ res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]);
+ }
+
+ return res;
+ }
+
+ LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands,
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ IRMapping &spmdizationMap,
+ SymbolTableCollection &symbolTable,
+ OpBuilder &builder) const {
+ LinalgOp linalgOp = llvm::cast<LinalgOp>(op);
+
+ SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+ bool allIndexingMapsAreProjectedPermutation =
+ llvm::all_of(indexingMaps, [](AffineMap map) {
+ return map.isProjectedPermutation();
+ });
+ if (!allIndexingMapsAreProjectedPermutation) {
+ // TODO: handle non-projected permutations.
+ return op->emitOpError()
+ << "only supports indexing maps that are projected permutations";
+ }
+
+ SmallVector<utils::IteratorType> loopIteratorTypes =
+ linalgOp.getIteratorTypesArray();
+ ShardingArray meshAxisAssignmentForLoopIterators =
+ getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+ loopIteratorTypes, indexingMaps);
+ if (mesh::isAtLeastOneReductionIteratorSharded(
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
+ spmdizeLinalgOpWithShardedReduction(
+ linalgOp, spmdizedOperands, operandShardings, resultShardings,
+ loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+ symbolTable, implicitLocBuilder);
+ } else {
+ spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
+ operandShardings, resultShardings,
+ spmdizationMap, symbolTable, builder);
+ }
+
+ return success();
+ }
+};
+
+} // namespace
+
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+ OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+ (registerOne<OpTypes>(ctx), ...);
+}
+
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
+ DialectRegistry registry;
+ registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
+ tensor::TensorDialect>();
+ ctx->appendDialectRegistry(registry);
+ for (StringRef name : registry.getDialectNames())
+ ctx->getOrLoadDialect(name);
+
+ registerOne<linalg::GenericOp>(ctx);
+ registerAll<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+ >(ctx);
+ });
+}
+
+} // namespace mlir::linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8b3119f02e8f..bd870d4f982e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -275,14 +275,6 @@ struct LinalgOpPartialReductionInterface
ArrayRef<int64_t> oldShape =
linalgOp.getShape(linalgOp.getDpsInitOperand(0));
- // Extend tile size vector to the rank of the output tensor.
- SmallVector<Value> tileSizeVector =
- getValueOrCreateConstantIndexOp(b, loc, sizes);
- if (tileSizeVector.size() < oldShape.size()) {
- auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
- tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
- }
-
// Calculate the new shape, we insert the new dimensions based on the index
// of the reduction dimensions.
SmallVector<int64_t> newOutputShape;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 50163880e85f..03f11ad1f949 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -647,6 +647,13 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
}
+void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ Value input, StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
+ build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+ reduction);
+}
+
void AllReduceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "all_reduce");
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c44413f..9acee5aa8d86 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
if (std::size(values) != std::size(shardings)) {
return false;
}
- return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
- std::forward<MeshShardingAttrRage>(shardings)),
+ return llvm::all_of(llvm::zip_equal(
+ std::forward<ValueRange>(values),
+ std::forward<MeshShardingAttrRage>(shardings)),
[](auto valueAndSharding) {
return isValueCompatibleWithFullReplicationSharding(
std::get<0>(valueAndSharding),
@@ -563,6 +564,88 @@ void mesh::spmdizeFullyReplicatedOperation(
builder.clone(op, spmdizationMap);
}
+static void updateMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ &meshAxesAssignmentForLoopIterators) {
+ AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+ unsigned loopIteratorIdx = affineDimExpr.getPosition();
+ if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+ assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+ *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+ } else {
+ meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+ llvm::to_vector(meshAxesAssignmentForTensorAxis);
+ }
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+ ArrayRef<MeshShardingAttr> operandShardings,
+ ArrayRef<MeshShardingAttr> resultShardings,
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<AffineMap> indexingMaps) {
+ SmallVector<std::optional<SmallVector<MeshAxis>>>
+ meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+ SmallVector<MeshShardingAttr> operandAndResultShardings;
+ operandAndResultShardings.reserve(operandShardings.size() +
+ resultShardings.size());
+ llvm::append_range(operandAndResultShardings, operandShardings);
+ llvm::append_range(operandAndResultShardings, resultShardings);
+ for (auto [sharding, affineMap] :
+ llvm::zip_equal(operandAndResultShardings, indexingMaps)) {
+ if (!sharding) {
+ continue;
+ }
+ for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+ llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+ updateMeshAxisAssignmentForLoopIterators(
+ meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+ meshAxisAssignmentForLoopIterators);
+ }
+ // Missing trailing split axes mean that the remaining tensor dimensions are
+ // replicated.
+ for (unsigned i = sharding.getSplitAxes().size();
+ i < affineMap.getNumResults(); ++i) {
+ updateMeshAxisAssignmentForLoopIterators(
+ {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators);
+ }
+ }
+
+ ShardingArray res;
+ llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res),
+ [](std::optional<SmallVector<MeshAxis>> &axes) {
+ if (!axes) {
+ return SmallVector<MeshAxis>();
+ }
+ return std::move(*axes);
+ });
+ return res;
+}
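A worked example of the assignment computed above, using the reduction-sharded matmul from the new mesh-spmdization.mlir test further down: linalg.matmul has iterators (i, j, k) and indexing maps (i, k), (k, j), (i, j). With the first input sharded as [[], [0]] (mesh axis 0 on its k dimension), the second input as [[0]] (mesh axis 0 on its k dimension), and the output replicated, both shardings assign mesh axis 0 to the reduction iterator k, the assignments agree (as the assert above requires), and the parallel iterators i and j end up with empty assignments.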
+
+bool mesh::isAtLeastOneReductionIteratorSharded(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction &&
+ !meshAxisAssignment.empty()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+SmallVector<MeshAxis> mesh::getReductionMeshAxes(
+ ArrayRef<utils::IteratorType> loopIteratorTypes,
+ ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) {
+ SmallVector<MeshAxis> meshAxes;
+ for (auto [loopIteratorType, meshAxisAssignment] :
+ llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+ if (loopIteratorType == utils::IteratorType::reduction) {
+ llvm::append_range(meshAxes, meshAxisAssignment);
+ }
+ }
+ return meshAxes;
+}
+
void mesh::spmdizeTriviallyShardableOperation(
Operation &op, ArrayRef<Value> spmdizedOperands,
ArrayRef<MeshShardingAttr> operandShardings,
@@ -572,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation(
Operation *newOp = builder.clone(op, spmdizationMap);
// Set the result types to the sharded counterparts.
for (auto [oldResult, newResult, sharding] :
- llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) {
+ llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) {
newResult.setType(shardType(newResult.getType(),
getMesh(&op, sharding.getMesh(), symbolTable),
sharding));
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
index d59b9119dea5..cb13ee404751 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp
@@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes,
.cast<TypedValue<IndexType>>();
}
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+ ArrayRef<MeshAxis> meshAxes,
+ ImplicitLocOpBuilder &builder) {
+ ResultRange processInGroupMultiIndex =
+ builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults();
+ Operation::result_range processGroupShape =
+ builder.create<MeshShapeOp>(mesh, meshAxes).getResult();
+ OpFoldResult processInGroupLinearIndex = affine::linearizeIndex(
+ llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex),
+ llvm::to_vector_of<OpFoldResult>(processGroupShape), builder);
+ return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>());
+}
+
} // namespace mlir::mesh
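The linearization performed by affine::linearizeIndex above is, in effect, the usual row-major index computation. A standalone sketch of the same arithmetic on constant values (illustration only, not code from this patch):

#include <cassert>
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// linear = ((i0 * s1 + i1) * s2 + i2) * ...
static int64_t linearizeProcessIndex(llvm::ArrayRef<int64_t> multiIndex,
                                     llvm::ArrayRef<int64_t> groupShape) {
  assert(multiIndex.size() == groupShape.size());
  int64_t linear = 0;
  for (auto [idx, dim] : llvm::zip_equal(multiIndex, groupShape))
    linear = linear * dim + idx;
  return linear;
}
// For example, in a 2x3 process group the process with multi-index (1, 2)
// gets linear index 1 * 3 + 2 = 5.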
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d7dc902a9a5e..c1a261eab848 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -204,14 +204,22 @@ public:
/// Roll back the rewrite. Operations may be erased during rollback.
virtual void rollback() = 0;
- /// Commit the rewrite. Operations/blocks may be unlinked during the commit
- /// phase, but they must not be erased yet. This is because internal dialect
- /// conversion state (such as `mapping`) may still be using them. Operations/
- /// blocks must be erased during cleanup.
- virtual void commit() {}
+ /// Commit the rewrite. At this point, it is certain that the dialect
+ /// conversion will succeed. All IR modifications, except for operation/block
+ /// erasure, must be performed through the given rewriter.
+ ///
+ /// Instead of erasing operations/blocks, they should merely be unlinked during
+ /// the commit phase and finally be erased during the cleanup phase. This is
+ /// because internal dialect conversion state (such as `mapping`) may still be
+ /// using them.
+ ///
+ /// Any IR modification that was already performed before the commit phase
+ /// (e.g., insertion of an op) must be communicated to the listener that may
+ /// be attached to the given rewriter.
+ virtual void commit(RewriterBase &rewriter) {}
/// Cleanup operations/blocks. Cleanup is called after commit.
- virtual void cleanup() {}
+ virtual void cleanup(RewriterBase &rewriter) {}
Kind getKind() const { return kind; }
@@ -221,12 +229,6 @@ protected:
IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
: kind(kind), rewriterImpl(rewriterImpl) {}
- /// Erase the given op (unless it was already erased).
- void eraseOp(Operation *op);
-
- /// Erase the given block (unless it was already erased).
- void eraseBlock(Block *block);
-
const ConversionConfig &getConfig() const;
const Kind kind;
@@ -265,6 +267,12 @@ public:
return rewrite->getKind() == Kind::CreateBlock;
}
+ void commit(RewriterBase &rewriter) override {
+ // The block was already created and inserted. Just inform the listener.
+ if (auto *listener = rewriter.getListener())
+ listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+ }
+
void rollback() override {
// Unlink all of the operations within this block, they will be deleted
// separately.
@@ -311,10 +319,19 @@ public:
block = nullptr;
}
- void cleanup() override {
+ void commit(RewriterBase &rewriter) override {
// Erase the block.
assert(block && "expected block");
assert(block->empty() && "expected empty block");
+
+ // Notify the listener that the block is about to be erased.
+ if (auto *listener =
+ dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+ listener->notifyBlockErased(block);
+ }
+
+ void cleanup(RewriterBase &rewriter) override {
+ // Erase the block.
block->dropAllDefinedValueUses();
delete block;
block = nullptr;
@@ -341,6 +358,13 @@ public:
firstInlinedInst(sourceBlock->empty() ? nullptr
: &sourceBlock->front()),
lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+ // If a listener is attached to the dialect conversion, ops must be moved
+ // one-by-one. When they are moved in bulk, notifications cannot be sent
+ // because the ops that used to be in the source block at the time of the
+ // inlining (before the "commit" phase) are unknown at the time when
+ // notifications are sent (which is during the "commit" phase).
+ assert(!getConfig().listener &&
+ "InlineBlockRewrite not supported if listener is attached");
}
static bool classof(const IRRewrite *rewrite) {
@@ -382,6 +406,16 @@ public:
return rewrite->getKind() == Kind::MoveBlock;
}
+ void commit(RewriterBase &rewriter) override {
+ // The block was already moved. Just inform the listener.
+ if (auto *listener = rewriter.getListener()) {
+ // Note: `previousIt` cannot be passed because this is a delayed
+ // notification and iterators into past IR state cannot be represented.
+ listener->notifyBlockInserted(block, /*previous=*/region,
+ /*previousIt=*/{});
+ }
+ }
+
void rollback() override {
// Move the block back to its original position.
Region::iterator before =
@@ -437,7 +471,7 @@ public:
LogicalResult
materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -466,7 +500,7 @@ public:
return rewrite->getKind() == Kind::ReplaceBlockArg;
}
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
@@ -506,6 +540,17 @@ public:
return rewrite->getKind() == Kind::MoveOperation;
}
+ void commit(RewriterBase &rewriter) override {
+ // The operation was already moved. Just inform the listener.
+ if (auto *listener = rewriter.getListener()) {
+ // Note: `previousIt` cannot be passed because this is a delayed
+ // notification and iterators into past IR state cannot be represented.
+ listener->notifyOperationInserted(
+ op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+ /*insertPt=*/{}));
+ }
+ }
+
void rollback() override {
// Move the operation back to its original position.
Block::iterator before =
@@ -549,7 +594,12 @@ public:
"rewrite was neither committed nor rolled back");
}
- void commit() override {
+ void commit(RewriterBase &rewriter) override {
+ // Notify the listener that the operation was modified in-place.
+ if (auto *listener =
+ dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+ listener->notifyOperationModified(op);
+
if (propertiesStorage) {
OpaqueProperties propCopy(propertiesStorage);
// Note: The operation may have been erased in the mean time, so
@@ -600,11 +650,11 @@ public:
return rewrite->getKind() == Kind::ReplaceOperation;
}
- void commit() override;
+ void commit(RewriterBase &rewriter) override;
void rollback() override;
- void cleanup() override;
+ void cleanup(RewriterBase &rewriter) override;
const TypeConverter *getConverter() const { return converter; }
@@ -629,6 +679,12 @@ public:
return rewrite->getKind() == Kind::CreateOperation;
}
+ void commit(RewriterBase &rewriter) override {
+ // The operation was already created and inserted. Just inform the listener.
+ if (auto *listener = rewriter.getListener())
+ listener->notifyOperationInserted(op, /*previous=*/{});
+ }
+
void rollback() override;
};
@@ -666,7 +722,7 @@ public:
void rollback() override;
- void cleanup() override;
+ void cleanup(RewriterBase &rewriter) override;
/// Return the type converter of this materialization (which may be null).
const TypeConverter *getConverter() const {
@@ -735,7 +791,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : eraseRewriter(ctx), config(config) {}
+ : context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -900,6 +956,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
}
void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
void notifyBlockErased(Block *block) override { erased.insert(block); }
/// Pointers to all erased operations and blocks.
@@ -910,8 +967,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// This rewriter must be used for erasing ops/blocks.
- SingleEraseRewriter eraseRewriter;
+ /// MLIR context.
+ MLIRContext *context;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -955,19 +1012,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
} // namespace detail
} // namespace mlir
-void IRRewrite::eraseOp(Operation *op) {
- rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
- rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
const ConversionConfig &IRRewrite::getConfig() const {
return rewriterImpl.config;
}
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+ // Inform the listener about all IR modifications that have already taken
+ // place: References to the original block have been replaced with the new
+ // block.
+ if (auto *listener = dyn_cast_or_null<RewriterBase::Listener>(
+ rewriter.getListener()))
+ for (Operation *op : block->getUsers())
+ listener->notifyOperationModified(op);
+
// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -975,7 +1032,7 @@ void BlockTypeConversionRewrite::commit() {
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
- origArg.replaceAllUsesWith(newArg);
+ rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}
@@ -985,8 +1042,8 @@ void BlockTypeConversionRewrite::commit() {
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
- origArg.replaceAllUsesWith(
- rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+ rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+ castValue, origArg.getType()));
}
}
}
@@ -1042,13 +1099,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
return success();
}
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
if (!repl)
return;
if (isa<BlockArgument>(repl)) {
- arg.replaceAllUsesWith(repl);
+ rewriter.replaceAllUsesWith(arg, repl);
return;
}
@@ -1057,7 +1114,7 @@ void ReplaceBlockArgRewrite::commit() {
// replacement value.
Operation *replOp = cast<OpResult>(repl).getOwner();
Block *replBlock = replOp->getBlock();
- arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+ rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
});
@@ -1065,14 +1122,40 @@ void ReplaceBlockArgRewrite::commit() {
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
-void ReplaceOperationRewrite::commit() {
- for (OpResult result : op->getResults())
- if (Value newValue =
- rewriterImpl.mapping.lookupOrNull(result, result.getType()))
- result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+ auto *listener = dyn_cast_or_null<RewriterBase::Listener>(
+ rewriter.getListener());
+
+ // Compute replacement values.
+ SmallVector<Value> replacements =
+ llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+ return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+ });
+
+ // Notify the listener that the operation is about to be replaced.
+ if (listener)
+ listener->notifyOperationReplaced(op, replacements);
+
+ // Replace all uses with the new values.
+ for (auto [result, newValue] :
+ llvm::zip_equal(op->getResults(), replacements))
+ if (newValue)
+ rewriter.replaceAllUsesWith(result, newValue);
+
+ // The original op will be erased, so remove it from the set of unlegalized
+ // ops.
if (getConfig().unlegalizedOps)
getConfig().unlegalizedOps->erase(op);
+
+ // Notify the listener that the operation (and its nested operations) was
+ // erased.
+ if (listener) {
+ op->walk<WalkOrder::PostOrder>(
+ [&](Operation *op) { listener->notifyOperationErased(op); });
+ }
+
// Do not erase the operation yet. It may still be referenced in `mapping`.
+ // Just unlink it for now and erase it during cleanup.
op->getBlock()->getOperations().remove(op);
}
@@ -1081,7 +1164,9 @@ void ReplaceOperationRewrite::rollback() {
rewriterImpl.mapping.erase(result);
}
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+ rewriter.eraseOp(op);
+}
void CreateOperationRewrite::rollback() {
for (Region &region : op->getRegions()) {
@@ -1100,14 +1185,20 @@ void UnresolvedMaterializationRewrite::rollback() {
op->erase();
}
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+ rewriter.eraseOp(op);
+}
void ConversionPatternRewriterImpl::applyRewrites() {
// Commit all rewrites.
+ IRRewriter rewriter(context, config.listener);
for (auto &rewrite : rewrites)
- rewrite->commit();
+ rewrite->commit(rewriter);
+
+ // Clean up all rewrites.
+ SingleEraseRewriter eraseRewriter(context);
for (auto &rewrite : rewrites)
- rewrite->cleanup();
+ rewrite->cleanup(eraseRewriter);
}
//===----------------------------------------------------------------------===//
@@ -1281,7 +1372,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
ConversionPatternRewriter &rewriter, Block *block,
const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
- MLIRContext *ctx = rewriter.getContext();
+ OpBuilder::InsertionGuard g(rewriter);
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
@@ -1289,14 +1380,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
if (llvm::equal(block->getArgumentTypes(), convertedTypes))
return block;
- // Split the block at the beginning to get a new block to use for the updated
- // signature.
- Block *newBlock = rewriter.splitBlock(block, block->begin());
- block->replaceAllUsesWith(newBlock);
-
- // Map all new arguments to the location of the argument they originate from.
+ // Compute the locations of all block arguments in the new block.
SmallVector<Location> newLocs(convertedTypes.size(),
- Builder(ctx).getUnknownLoc());
+ rewriter.getUnknownLoc());
for (unsigned i = 0; i < origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap || inputMap->replacementValue)
@@ -1306,9 +1392,29 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
newLocs[inputMap->inputNo + j] = origLoc;
}
- SmallVector<Value, 4> newArgRange(
- newBlock->addArguments(convertedTypes, newLocs));
- ArrayRef<Value> newArgs(newArgRange);
+ // Insert a new block with the converted block argument types and move all ops
+ // from the old block to the new block.
+ Block *newBlock =
+ rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
+ convertedTypes, newLocs);
+
+ // If a listener is attached to the dialect conversion, ops cannot be moved
+ // to the destination block in bulk ("fast path"). This is because at the time
+ // the notifications are sent, it is unknown which ops were moved. Instead,
+ // ops should be moved one-by-one ("slow path"), so that a separate
+ // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+ // a bit more efficient, so we try to do that when possible.
+ bool fastPath = !config.listener;
+ if (fastPath) {
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+ } else {
+ while (!block->empty())
+ rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
+ }
+
+ // Replace all uses of the old block with the new block.
+ block->replaceAllUsesWith(newBlock);
// Remap each of the original arguments as determined by the signature
// conversion.
@@ -1333,7 +1439,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
}
// Otherwise, this is a 1->1+ mapping.
- auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size);
+ auto replArgs =
+ newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value newArg;
// If this is a 1->1 mapping and the types of new and replacement arguments
@@ -1642,10 +1749,31 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
"expected 'source' to have no predecessors");
#endif // NDEBUG
- impl->notifyBlockBeingInlined(dest, source, before);
+ // If a listener is attached to the dialect conversion, ops cannot be moved
+ // to the destination block in bulk ("fast path"). This is because at the time
+ // the notifications are sent, it is unknown which ops were moved. Instead,
+ // ops should be moved one-by-one ("slow path"), so that a separate
+ // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+ // a bit more efficient, so we try to do that when possible.
+ bool fastPath = !impl->config.listener;
+
+ if (fastPath)
+ impl->notifyBlockBeingInlined(dest, source, before);
+
+ // Replace all uses of block arguments.
for (auto it : llvm::zip(source->getArguments(), argValues))
replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
- dest->getOperations().splice(before, source->getOperations());
+
+ if (fastPath) {
+ // Move all ops at once.
+ dest->getOperations().splice(before, source->getOperations());
+ } else {
+ // Move op by op.
+ while (!source->empty())
+ moveOpBefore(&source->front(), dest, before);
+ }
+
+ // Erase the source block.
eraseBlock(source);
}
diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
new file mode 100644
index 000000000000..6d21def8de27
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir
@@ -0,0 +1,165 @@
+// RUN: mlir-opt \
+// RUN: --mesh-spmdization \
+// RUN: --test-constant-fold \
+// RUN: --split-input-file \
+// RUN: %s | FileCheck %s
+
+// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)>
+#map_identity_1d = affine_map<(d0) -> (d0)>
+
+mesh.mesh @mesh_1d(shape = 2)
+
+// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor
+func.func @elementwise_static_1d_mesh_static_1d_tensor(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>,
+ %in1: tensor<2xi8>,
+ // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>,
+ %in2: tensor<2xi8>,
+ // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8>
+ %dps_out: tensor<2xi8>
+// CHECK-SAME: -> tensor<1xi8> {
+) -> tensor<2xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: %[[RES:.*]] = linalg.generic {
+ // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]],
+ // CHECK-SAME: iterator_types = ["parallel"]}
+ // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>)
+ // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) {
+ %res = linalg.generic {
+ indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d],
+ iterator_types = ["parallel"]
+ } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>)
+ outs(%dps_out_shared2 : tensor<2xi8>) {
+ ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8):
+ %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8
+ linalg.yield %res_scalar : i8
+ } -> tensor<2xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8>
+ // CHECK: return %[[RES]] : tensor<1xi8>
+ return %res_shared2 : tensor<2xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 4)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>,
+ %in1: tensor<4x3xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>,
+ %in2: tensor<3x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<1x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: %[[RES:.*]] = linalg.matmul
+ // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1x3xi8>, tensor<3x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>)
+ // CHECK-SAME: -> tensor<1x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[RES]] : tensor<1x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+ %in1: tensor<4x6xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+ %in2: tensor<6x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+ // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+ // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+ // CHECK: } else {
+ // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+ // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+ // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+ // CHECK: }
+ // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
+
+// -----
+
+mesh.mesh @mesh_1d(shape = 3)
+
+// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result
+func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result(
+ // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>,
+ %in1: tensor<4x6xi8>,
+// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>,
+ %in2: tensor<6x8xi8>,
+// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8>
+ %dps_out: tensor<4x8xi8>
+// CHECK-SAME: -> tensor<4x8xi8> {
+) -> tensor<4x8xi8> {
+ %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8>
+ %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8>
+ %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8>
+ %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8>
+ %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8>
+ %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8>
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8
+ // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index
+ // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index
+ // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) {
+ // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8>
+ // CHECK: } else {
+ // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8>
+ // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8)
+ // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8>
+ // CHECK: }
+ // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>)
+ // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>)
+ outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8>
+ %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8>
+ %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8>
+ // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8>
+ return %res_shared2 : tensor<4x8xi8>
+}
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ccdc9fe78ea0..d552f0346644 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
+// CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
// CHECK-LABEL: verifyDirectPattern
func.func @verifyDirectPattern() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
return %result : i32
}
+// -----
+
+// CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
// CHECK-LABEL: verifyLargerBenefit
func.func @verifyLargerBenefit() -> i32 {
// CHECK-NEXT: "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
return %result : i32
}
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
// CHECK-LABEL: func private @remap_input_1_to_0()
func.func private @remap_input_1_to_0(i16)
+// -----
+
// CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
func.func @remap_input_1_to_1(%arg0: i64) {
// CHECK-NEXT: "test.valid"{{.*}} : (f64)
"test.invalid"(%arg0) : (i64) -> ()
}
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
func.func @remap_call_1_to_1(%arg0: i64) {
// CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
return
}
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK: notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
// CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
// CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
"test.return"(%arg0) : (f32) -> ()
}
+// -----
+
// CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
// CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
"test.return"(%arg0) : (i42) -> ()
}
+// -----
+
// CHECK-LABEL: func @remap_input_to_self
func.func @remap_input_to_self(%arg0: index) {
// CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
"test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
}
+// -----
+
// CHECK-LABEL: func @no_remap_nested
func.func @no_remap_nested() {
// CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
return
}
+// -----
+
// CHECK-LABEL: func @remap_moved_region_args
func.func @remap_moved_region_args() {
// CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
return
}
+// -----
+
// CHECK-LABEL: func @remap_cloned_region_args
func.func @remap_cloned_region_args() {
// CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
return
}
+// -----
+
// CHECK-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
"work"(%arg) : (i16) -> ()
}
+// -----
+
// CHECK-LABEL: func @up_to_date_replacement
func.func @up_to_date_replacement(%arg: i8) -> i8 {
// CHECK-NEXT: return
@@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 {
return %repl_2 : i8
}
+// -----
+
// CHECK-LABEL: func @remove_foldable_op
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32)
func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
@@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
return %0 : i32
}
+// -----
+
// CHECK-LABEL: @create_block
func.func @create_block() {
// Check that we created a block with arguments.
@@ -161,6 +224,12 @@ func.func @create_block() {
return
}
+// -----
+
+// CHECK: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+
// CHECK-LABEL: @bounded_recursion
func.func @bounded_recursion() {
// CHECK: test.recursive_rewrite 0
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 27eae2ffd694..2da184bc3d85 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -327,8 +327,12 @@ struct TestPatternDriver
struct DumpNotifications : public RewriterBase::Listener {
void notifyBlockInserted(Block *block, Region *previous,
Region::iterator previousIt) override {
- llvm::outs() << "notifyBlockInserted into "
- << block->getParentOp()->getName() << ": ";
+ llvm::outs() << "notifyBlockInserted";
+ if (block->getParentOp()) {
+ llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
+ } else {
+ llvm::outs() << " into unknown op: ";
+ }
if (previous == nullptr) {
llvm::outs() << "was unlinked\n";
} else {
@@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener {
if (!previous.isSet()) {
llvm::outs() << ", was unlinked\n";
} else {
- if (previous.getPoint() == previous.getBlock()->end()) {
+ if (!previous.getPoint().getNodePtr()) {
+ llvm::outs() << ", was linked, exact position unknown\n";
+ } else if (previous.getPoint() == previous.getBlock()->end()) {
llvm::outs() << ", was last in block\n";
} else {
llvm::outs() << ", previous = " << previous.getPoint()->getName()
@@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener {
}
}
}
+ void notifyBlockErased(Block *block) override {
+ llvm::outs() << "notifyBlockErased\n";
+ }
void notifyOperationErased(Operation *op) override {
llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
}
+ void notifyOperationModified(Operation *op) override {
+ llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
+ }
+ void notifyOperationReplaced(Operation *op, ValueRange values) override {
+ llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
+ }
};
struct TestStrictPatternDriver
@@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
+ DumpNotifications dumpNotifications;
+ config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns), config))) {
@@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver
return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
});
+ ConversionConfig config;
+ DumpNotifications dumpNotifications;
+ config.listener = &dumpNotifications;
if (failed(applyFullConversion(getOperation(), target,
- std::move(patterns)))) {
+ std::move(patterns), config))) {
getOperation()->emitRemark() << "applyFullConversion failed";
}
return;
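For downstream users, the listener wiring added above is not specific to the test driver. A minimal sketch of the same idea (hedged: `MyListener` and `convertWithNotifications` are made-up names; the `ConversionConfig::listener` field and the notification hooks are the ones exercised by this patch):

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// Hypothetical listener that mirrors a subset of DumpNotifications above.
struct MyListener : public RewriterBase::Listener {
  void notifyOperationModified(Operation *op) override {
    llvm::outs() << "modified: " << op->getName() << "\n";
  }
  void notifyOperationErased(Operation *op) override {
    llvm::outs() << "erased: " << op->getName() << "\n";
  }
};

// Hypothetical helper: run a partial conversion and receive notifications.
static LogicalResult convertWithNotifications(Operation *root,
                                              ConversionTarget &target,
                                              FrozenRewritePatternSet patterns) {
  MyListener listener;
  ConversionConfig config;
  config.listener = &listener;
  return applyPartialConversion(root, target, std::move(patterns), config);
}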
diff --git a/openmp/runtime/src/kmp_collapse.cpp b/openmp/runtime/src/kmp_collapse.cpp
index 2c410ca9b603..569d2c150831 100644
--- a/openmp/runtime/src/kmp_collapse.cpp
+++ b/openmp/runtime/src/kmp_collapse.cpp
@@ -1272,6 +1272,304 @@ void kmp_calc_original_ivs_for_end(
}
}
+/**************************************************************************
+ * Identify the nested loop structure - loops come in the canonical form:
+ * Lower triangle matrix: i = 0; i <= N; i++        {0,0}:{N,0}
+ *                        j = 0; j <= 0/-1+1*i; j++ {0,0}:{0/-1,1}
+ * Upper triangle matrix: i = 0; i <= N; i++        {0,0}:{N,0}
+ *                        j = 0+1*i; j <= N; j++    {0,1}:{N,0}
+ * ************************************************************************/
+nested_loop_type_t
+kmp_identify_nested_loop_structure(/*in*/ bounds_info_t *original_bounds_nest,
+ /*in*/ kmp_index_t n) {
+ // only 2-level nested loops are supported
+ if (n != 2) {
+ return nested_loop_type_unkown;
+ }
+ // loops must be canonical
+ KMP_ASSERT(
+ (original_bounds_nest[0].comparison == comparison_t::comp_less_or_eq) &&
+ (original_bounds_nest[1].comparison == comparison_t::comp_less_or_eq));
+ // check outer loop bounds: for triangular need to be {0,0}:{N,0}
+ kmp_uint64 outer_lb0_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].lb0_u64);
+ kmp_uint64 outer_ub0_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].ub0_u64);
+ kmp_uint64 outer_lb1_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].lb1_u64);
+ kmp_uint64 outer_ub1_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].ub1_u64);
+ if (outer_lb0_u64 != 0 || outer_lb1_u64 != 0 || outer_ub1_u64 != 0) {
+ return nested_loop_type_unkown;
+ }
+ // check inner bounds to determine triangle type
+ kmp_uint64 inner_lb0_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].lb0_u64);
+ kmp_uint64 inner_ub0_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].ub0_u64);
+ kmp_uint64 inner_lb1_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].lb1_u64);
+ kmp_uint64 inner_ub1_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].ub1_u64);
+ // lower triangle loop inner bounds need to be {0,0}:{0/-1,1}
+ if (inner_lb0_u64 == 0 && inner_lb1_u64 == 0 &&
+ (inner_ub0_u64 == 0 || inner_ub0_u64 == -1) && inner_ub1_u64 == 1) {
+ return nested_loop_type_lower_triangular_matrix;
+ }
+ // upper triangle loop inner bounds need to be {0,1}:{N,0}
+ if (inner_lb0_u64 == 0 && inner_lb1_u64 == 1 &&
+ inner_ub0_u64 == outer_ub0_u64 && inner_ub1_u64 == 0) {
+ return nested_loop_type_upper_triangular_matrix;
+ }
+ return nested_loop_type_unkown;
+}
+
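For orientation, a minimal sketch of the two source-level nests this classifier is intended to recognize, assuming the usual canonical lowering into the bounds encoding shown in the comment above (`N` and the array argument are placeholders):

/* Lower triangle: the inner bound tracks the outer IV ({0,0}:{0/-1,1}). */
void lower_triangle_example(int N, double *a) {
#pragma omp parallel for schedule(static) collapse(2)
  for (int i = 0; i <= N; i++)
    for (int j = 0; j <= i; j++) /* or j < i for the strict variant */
      a[i * (N + 1) + j] += 1.0;
}

/* Upper triangle: the inner loop starts at the outer IV ({0,1}:{N,0}). */
void upper_triangle_example(int N, double *a) {
#pragma omp parallel for schedule(static) collapse(2)
  for (int i = 0; i <= N; i++)
    for (int j = i; j <= N; j++)
      a[i * (N + 1) + j] += 1.0;
}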
+/**************************************************************************
+ * SQRT Approximation: https://math.mit.edu/~stevenj/18.335/newton-sqrt.pdf
+ * The start point is x, so the result is always >= sqrt(x).
+ * The method converges monotonically; the stopping tolerance
+ * (level_of_precision) is set to 0.1.
+ * ************************************************************************/
+#define level_of_precision 0.1
+double sqrt_newton_approx(/*in*/ kmp_uint64 x) {
+ double sqrt_old = 0.;
+ double sqrt_new = (double)x;
+ do {
+ sqrt_old = sqrt_new;
+ sqrt_new = (sqrt_old + x / sqrt_old) / 2;
+ } while ((sqrt_old - sqrt_new) > level_of_precision);
+ return sqrt_new;
+}
+
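A quick usage note (the numeric trace is hand-computed, not program output): for x = 81 the iterates are 81 -> 41 -> 21.5 -> 12.6 -> 9.5 -> 9.01 -> ~9.00001, and the loop stops once two successive iterates differ by at most level_of_precision, so the returned value never drops below sqrt(x).

double approx = sqrt_newton_approx(81); /* roughly 9.00001, never below 9 */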
+/**************************************************************************
+ * Handle lower triangle matrix in the canonical form
+ * i = 0; i <= N; i++ {0,0}:{N,0}
+ * j = 0; j <= 0/-1 + 1*i; j++ {0,0}:{0/-1,1}
+ * ************************************************************************/
+void kmp_handle_lower_triangle_matrix(
+ /*in*/ kmp_uint32 nth,
+ /*in*/ kmp_uint32 tid,
+ /*in */ kmp_index_t n,
+ /*in/out*/ bounds_info_t *original_bounds_nest,
+ /*out*/ bounds_info_t *chunk_bounds_nest) {
+
+ // transfer loop types from the original loop to the chunks
+ for (kmp_index_t i = 0; i < n; ++i) {
+ chunk_bounds_nest[i] = original_bounds_nest[i];
+ }
+ // cleanup iv variables
+ kmp_uint64 outer_ub0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].ub0_u64);
+ kmp_uint64 outer_lb0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].lb0_u64);
+ kmp_uint64 inner_ub0 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].ub0_u64);
+  // calculate the chunk's lower and upper bounds
+  // the total number of iterations in the loop is the sum of the arithmetic
+  // progression from the outer lower to the outer upper bound (inclusive,
+  // since the loop is canonical). Note that less_than inner loops
+  // (inner_ub0 == -1) effectively make the progression 1-based, reducing
+  // N = (outer_ub0 - outer_lb0 + 1) to N - 1. A standalone worked sketch of
+  // this arithmetic follows the function.
+ kmp_uint64 outer_iters = (outer_ub0 - outer_lb0 + 1) + inner_ub0;
+ kmp_uint64 iter_total = outer_iters * (outer_iters + 1) / 2;
+  // the current thread's number of iterations:
+  // each thread gets an equal share (the total number of iterations divided
+  // by the number of threads); if there is a remainder, the first
+  // (remainder) threads get one additional iteration each to cover it
+ kmp_uint64 iter_current =
+ iter_total / nth + ((tid < (iter_total % nth)) ? 1 : 0);
+  // cumulative number of iterations executed by all the previous threads:
+  // threads with a tid below the remainder have (iter_total/nth+1) iterations,
+  // and so do all threads before them, so the cumulative count is the current
+  // thread's number of iterations multiplied by the number of previous
+  // threads, which equals the current thread's tid; threads with a tid equal
+  // to or above the remainder have (iter_total/nth) iterations, so the
+  // cumulative count is again the thread's own number of iterations
+  // multiplied by its tid, PLUS all the remainder iterations already covered
+  // by the previous threads
+ kmp_uint64 iter_before_current =
+ tid * iter_current + ((tid < iter_total % nth) ? 0 : (iter_total % nth));
+ // cumulative number of iterations executed with the current thread is
+ // the cumulative number executed before it plus its own
+ kmp_uint64 iter_with_current = iter_before_current + iter_current;
+  // calculate the outer loop lower bound (lbo), the max outer iv value whose
+  // iteration count is equal to or just below the total number of iterations
+  // executed by the previous threads. For less_than (1-based) inner loops
+  // (inner_ub0 == -1) the condition is
+  // lbo*(lbo-1)/2 <= iter_before_current => lbo^2-lbo-2*iter_before_current <= 0;
+  // for less_than_equal (0-based) inner loops (inner_ub0 == 0) it is
+  // lbo*(lbo+1)/2 <= iter_before_current => lbo^2+lbo-2*iter_before_current <= 0.
+  // Both cases are handled the same way, with a parameter controlling the
+  // equation sign
+ kmp_int64 inner_adjustment = 1 + 2 * inner_ub0;
+ kmp_uint64 lower_bound_outer =
+ (kmp_uint64)(sqrt_newton_approx(inner_adjustment * inner_adjustment +
+ 8 * iter_before_current) +
+ inner_adjustment) /
+ 2 -
+ inner_adjustment;
+  // calculate the inner loop lower bound, the remaining number of iterations
+  // needed to reach the total executed by the previous threads; this gives
+  // the starting point of this thread
+ kmp_uint64 lower_bound_inner =
+ iter_before_current -
+ ((lower_bound_outer + inner_adjustment) * lower_bound_outer) / 2;
+  // calculate the outer loop upper bound using the same approach as for the
+  // lower bound, except with the total number of iterations executed up to
+  // and including the current thread
+ kmp_uint64 upper_bound_outer =
+ (kmp_uint64)(sqrt_newton_approx(inner_adjustment * inner_adjustment +
+ 8 * iter_with_current) +
+ inner_adjustment) /
+ 2 -
+ inner_adjustment;
+  // calculate the inner loop upper bound, the remaining number of iterations
+  // needed to reach the total executed up to and including the current
+  // thread; this gives the starting point of the next thread
+ kmp_uint64 upper_bound_inner =
+ iter_with_current -
+ ((upper_bound_outer + inner_adjustment) * upper_bound_outer) / 2;
+  // adjust the upper bounds down by 1 element so they point at the last
+  // iteration of the current thread rather than the first iteration of the
+  // next thread
+ if (upper_bound_inner == 0) {
+ // {n,0} => {n-1,n-1}
+ upper_bound_outer -= 1;
+ upper_bound_inner = upper_bound_outer;
+ } else {
+ // {n,m} => {n,m-1} (m!=0)
+ upper_bound_inner -= 1;
+ }
+
+  // assign the values, zeroing out the lb1 and ub1 coefficients since the
+  // chunk bounds no longer depend on the outer loop variable
+ chunk_bounds_nest[0].lb0_u64 = lower_bound_outer;
+ chunk_bounds_nest[1].lb0_u64 = lower_bound_inner;
+ chunk_bounds_nest[0].ub0_u64 = upper_bound_outer;
+ chunk_bounds_nest[1].ub0_u64 = upper_bound_inner;
+ chunk_bounds_nest[0].lb1_u64 = 0;
+ chunk_bounds_nest[0].ub1_u64 = 0;
+ chunk_bounds_nest[1].lb1_u64 = 0;
+ chunk_bounds_nest[1].ub1_u64 = 0;
+
+#if 0
+ printf("tid/nth = %d/%d : From [%llu, %llu] To [%llu, %llu] : Chunks %llu/%llu\n",
+ tid, nth, chunk_bounds_nest[0].lb0_u64, chunk_bounds_nest[1].lb0_u64,
+ chunk_bounds_nest[0].ub0_u64, chunk_bounds_nest[1].ub0_u64, iter_current, iter_total);
+#endif
+}
+
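To make the bound computation concrete, here is a small standalone sketch (editorial, not part of the runtime) that mirrors the arithmetic above for a 0-based `j <= i` nest with N = 3 and two threads; it uses libm's sqrt instead of sqrt_newton_approx and assumes inner_ub0 == 0, so the adjustment term is 1:

#include <math.h>
#include <stdio.h>

int main(void) {
  unsigned long long N = 3, nth = 2;      /* i = 0..3, j = 0..i */
  unsigned long long outer_iters = N + 1; /* 4 rows */
  unsigned long long iter_total = outer_iters * (outer_iters + 1) / 2; /* 10 */
  for (unsigned long long tid = 0; tid < nth; ++tid) {
    unsigned long long iter_current =
        iter_total / nth + ((tid < iter_total % nth) ? 1 : 0);
    unsigned long long iter_before =
        tid * iter_current + ((tid < iter_total % nth) ? 0 : iter_total % nth);
    unsigned long long iter_with = iter_before + iter_current;
    /* largest lbo with lbo*(lbo+1)/2 <= iter_before, i.e. the positive root
       of lbo^2 + lbo - 2*iter_before = 0, rounded down */
    unsigned long long lbo =
        (unsigned long long)(sqrt(1.0 + 8.0 * (double)iter_before) + 1.0) / 2 - 1;
    unsigned long long lbi = iter_before - lbo * (lbo + 1) / 2;
    unsigned long long ubo =
        (unsigned long long)(sqrt(1.0 + 8.0 * (double)iter_with) + 1.0) / 2 - 1;
    unsigned long long ubi = iter_with - ubo * (ubo + 1) / 2;
    if (ubi == 0) { /* step back one element, as in the code above */
      ubo -= 1;
      ubi = ubo;
    } else {
      ubi -= 1;
    }
    printf("tid %llu: [%llu,%llu] .. [%llu,%llu] (%llu iterations)\n", tid, lbo,
           lbi, ubo, ubi, iter_current);
    /* prints: tid 0: [0,0] .. [2,1] (5 iterations)
               tid 1: [2,2] .. [3,3] (5 iterations) */
  }
  return 0;
}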
+/**************************************************************************
+ * Handle upper triangle matrix in the canonical form
+ * i = 0; i <= N; i++ {0,0}:{N,0}
+ * j = 0+1*i; j <= N; j++ {0,1}:{N,0}
+ * ************************************************************************/
+void kmp_handle_upper_triangle_matrix(
+ /*in*/ kmp_uint32 nth,
+ /*in*/ kmp_uint32 tid,
+ /*in */ kmp_index_t n,
+ /*in/out*/ bounds_info_t *original_bounds_nest,
+ /*out*/ bounds_info_t *chunk_bounds_nest) {
+
+ // transfer loop types from the original loop to the chunks
+ for (kmp_index_t i = 0; i < n; ++i) {
+ chunk_bounds_nest[i] = original_bounds_nest[i];
+ }
+ // cleanup iv variables
+ kmp_uint64 outer_ub0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].ub0_u64);
+ kmp_uint64 outer_lb0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+ original_bounds_nest[0].lb0_u64);
+ kmp_uint64 inner_ub0 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+ original_bounds_nest[1].ub0_u64);
+  // calculate the chunk's lower and upper bounds
+  // the total number of iterations in the loop is the sum of the arithmetic
+  // progression from the outer lower to the outer upper bound (inclusive,
+  // since the loop is canonical). Note that less_than inner loops
+  // (inner_ub0 == -1) effectively make the progression 1-based, reducing
+  // N = (outer_ub0 - outer_lb0 + 1) to N - 1
+ kmp_uint64 outer_iters = (outer_ub0 - outer_lb0 + 1);
+ kmp_uint64 iter_total = outer_iters * (outer_iters + 1) / 2;
+  // the current thread's number of iterations:
+  // each thread gets an equal share (the total number of iterations divided
+  // by the number of threads); if there is a remainder, the first
+  // (remainder) threads get one additional iteration each to cover it
+ kmp_uint64 iter_current =
+ iter_total / nth + ((tid < (iter_total % nth)) ? 1 : 0);
+  // cumulative number of iterations executed by all the previous threads:
+  // threads with a tid below the remainder have (iter_total/nth+1) iterations,
+  // and so do all threads before them, so the cumulative count is the current
+  // thread's number of iterations multiplied by the number of previous
+  // threads, which equals the current thread's tid; threads with a tid equal
+  // to or above the remainder have (iter_total/nth) iterations, so the
+  // cumulative count is again the thread's own number of iterations
+  // multiplied by its tid, PLUS all the remainder iterations already covered
+  // by the previous threads
+ kmp_uint64 iter_before_current =
+ tid * iter_current + ((tid < iter_total % nth) ? 0 : (iter_total % nth));
+ // cumulative number of iterations executed with the current thread is
+ // the cumulative number executed before it plus its own
+ kmp_uint64 iter_with_current = iter_before_current + iter_current;
+  // calculate the outer loop lower bound (lbo), the max outer iv value whose
+  // iteration count is equal to or just below the total number of iterations
+  // executed by the previous threads. For less_than (1-based) inner loops
+  // (inner_ub0 == -1) the condition is
+  // lbo*(lbo-1)/2 <= iter_before_current => lbo^2-lbo-2*iter_before_current <= 0;
+  // for less_than_equal (0-based) inner loops (inner_ub0 == 0) it is
+  // lbo*(lbo+1)/2 <= iter_before_current => lbo^2+lbo-2*iter_before_current <= 0.
+  // Both cases are handled the same way, with a parameter controlling the
+  // equation sign; here the 0-based form is used directly
+ kmp_uint64 lower_bound_outer =
+ (kmp_uint64)(sqrt_newton_approx(1 + 8 * iter_before_current) + 1) / 2 - 1;
+  // calculate the inner loop lower bound, the remaining number of iterations
+  // needed to reach the total executed by the previous threads; this gives
+  // the starting point of this thread
+ kmp_uint64 lower_bound_inner =
+ iter_before_current - ((lower_bound_outer + 1) * lower_bound_outer) / 2;
+  // calculate the outer loop upper bound using the same approach as for the
+  // lower bound, except with the total number of iterations executed up to
+  // and including the current thread
+ kmp_uint64 upper_bound_outer =
+ (kmp_uint64)(sqrt_newton_approx(1 + 8 * iter_with_current) + 1) / 2 - 1;
+  // calculate the inner loop upper bound, the remaining number of iterations
+  // needed to reach the total executed up to and including the current
+  // thread; this gives the starting point of the next thread
+ kmp_uint64 upper_bound_inner =
+ iter_with_current - ((upper_bound_outer + 1) * upper_bound_outer) / 2;
+  // adjust the upper bounds down by 1 element so they point at the last
+  // iteration of the current thread rather than the first iteration of the
+  // next thread
+ if (upper_bound_inner == 0) {
+ // {n,0} => {n-1,n-1}
+ upper_bound_outer -= 1;
+ upper_bound_inner = upper_bound_outer;
+ } else {
+ // {n,m} => {n,m-1} (m!=0)
+ upper_bound_inner -= 1;
+ }
+
+  // assign the values, zeroing out the lb1 and ub1 coefficients since the
+  // chunk bounds no longer depend on the outer loop variable
+ chunk_bounds_nest[0].lb0_u64 = (outer_iters - 1) - upper_bound_outer;
+ chunk_bounds_nest[1].lb0_u64 = (outer_iters - 1) - upper_bound_inner;
+ chunk_bounds_nest[0].ub0_u64 = (outer_iters - 1) - lower_bound_outer;
+ chunk_bounds_nest[1].ub0_u64 = (outer_iters - 1) - lower_bound_inner;
+ chunk_bounds_nest[0].lb1_u64 = 0;
+ chunk_bounds_nest[0].ub1_u64 = 0;
+ chunk_bounds_nest[1].lb1_u64 = 0;
+ chunk_bounds_nest[1].ub1_u64 = 0;
+
+#if 0
+ printf("tid/nth = %d/%d : From [%llu, %llu] To [%llu, %llu] : Chunks %llu/%llu\n",
+ tid, nth, chunk_bounds_nest[0].lb0_u64, chunk_bounds_nest[1].lb0_u64,
+ chunk_bounds_nest[0].ub0_u64, chunk_bounds_nest[1].ub0_u64, iter_current, iter_total);
+#endif
+}
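The upper-triangle handler reuses the lower-triangle arithmetic and then mirrors the result. Continuing the hand-worked N = 3, two-thread example from the sketch above (editorial note; the linearized iteration lists are derived by hand, not emitted by the runtime):

/* With outer_iters = 4, thread 0's lower-triangle chunk [0,0]..[2,1] is
   reflected via (outer_iters - 1) - bound to [1,2]..[3,3], i.e. the
   iterations (1,2),(1,3),(2,2),(2,3),(3,3) of the j = i..N nest; thread 1's
   chunk [2,2]..[3,3] reflects to [0,0]..[1,1], i.e.
   (0,0),(0,1),(0,2),(0,3),(1,1).  Each thread still gets 5 of the 10
   iterations. */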
//----------Init API for non-rectangular loops--------------------------------
// Init API for collapsed loops (static, no chunks defined).
@@ -1334,6 +1632,19 @@ __kmpc_for_collapsed_init(ident_t *loc, kmp_int32 gtid,
KMP_DEBUG_ASSERT(tid < nth);
+ // Handle special cases
+ nested_loop_type_t loop_type =
+ kmp_identify_nested_loop_structure(original_bounds_nest, n);
+ if (loop_type == nested_loop_type_lower_triangular_matrix) {
+ kmp_handle_lower_triangle_matrix(nth, tid, n, original_bounds_nest,
+ chunk_bounds_nest);
+ return TRUE;
+ } else if (loop_type == nested_loop_type_upper_triangular_matrix) {
+ kmp_handle_upper_triangle_matrix(nth, tid, n, original_bounds_nest,
+ chunk_bounds_nest);
+ return TRUE;
+ }
+
CollapseAllocator<kmp_uint64> original_ivs_start(n);
if (!kmp_calc_original_ivs_for_start(original_bounds_nest, n,
diff --git a/openmp/runtime/src/kmp_collapse.h b/openmp/runtime/src/kmp_collapse.h
index e4870185645d..1044478554a0 100644
--- a/openmp/runtime/src/kmp_collapse.h
+++ b/openmp/runtime/src/kmp_collapse.h
@@ -45,6 +45,13 @@ enum loop_type_t : kmp_int32 {
loop_type_int64 = 7
};
+// Defining loop types to handle special cases
+enum nested_loop_type_t : kmp_int32 {
+ nested_loop_type_unkown = 0,
+ nested_loop_type_lower_triangular_matrix = 1,
+ nested_loop_type_upper_triangular_matrix = 2
+};
+
/*!
@ingroup WORK_SHARING
* Describes the structure for rectangular nested loops.
@@ -124,14 +131,14 @@ struct bounds_info_t {
// It's represented in kmp_uint64, but each dimention is calculated in
// that loop IV type. Also dimentions have to be converted to those types
// when used in generated code.
-typedef kmp_uint64* kmp_point_t;
+typedef kmp_uint64 *kmp_point_t;
// Array: Number of loop iterations on each nesting level to achieve some point,
// in expanded space or in original space.
// OMPTODO: move from using iterations to using offsets (iterations multiplied
// by steps). For those we need to be careful with the types, as step can be
// negative, but it'll remove multiplications and divisions in several places.
-typedef kmp_loop_nest_iv_t* kmp_iterations_t;
+typedef kmp_loop_nest_iv_t *kmp_iterations_t;
// Internal struct with additional info:
template <typename T> struct bounds_info_internalXX_template {
diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c
new file mode 100644
index 000000000000..9d742066cf1f
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
+
+#ifndef MAX_BOUND
+#define MAX_BOUND 64
+#endif
+#ifndef _MSC_VER
+#define NO_EFFICIENCY_CHECK
+#endif
+
+/* For correctness, only valid iterations may be executed, and each must be
+   executed exactly once. execution_count stores the number of times each
+   iteration is executed. */
+unsigned *execution_count = NULL;
+/* Stores the number of iterations executed by each thread. */
+unsigned *iterations_per_thread = NULL;
+
+unsigned *Alloc(unsigned bound1, unsigned bound2) {
+ return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned)));
+}
+
+void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) {
+ memset(p, 0, bound1 * bound2 * sizeof(unsigned));
+}
+
+void Free(unsigned *p) { free((void *)p); }
+
+unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) {
+ return &p[i * bound2 + j];
+}
+
+int test(unsigned upper_bound) {
+
+ unsigned total_iterations = upper_bound * (upper_bound - 1) / 2;
+ unsigned num_threads = omp_get_max_threads();
+ unsigned lower_per_chunk = total_iterations / num_threads;
+ unsigned upper_per_chunk =
+ lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0);
+ int i, j;
+
+ omp_set_num_threads(num_threads);
+
+ ZeroOut(execution_count, upper_bound, upper_bound);
+ ZeroOut(iterations_per_thread, num_threads, 1);
+
+#ifdef VERBOSE
+ fprintf(stderr,
+ "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] "
+ "chunks "
+ "loop type lower triangle <,< - ",
+ num_threads, upper_bound, total_iterations, lower_per_chunk,
+ upper_per_chunk);
+#endif
+
+#pragma omp parallel shared(iterations_per_thread, execution_count)
+ { /* begin of parallel */
+ /* Lower triangular execution_count matrix */
+#pragma omp for schedule(static) collapse(2)
+ for (i = 0; i < upper_bound; i++) {
+ for (j = 0; j < i; j++) {
+ (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++;
+ (*Index(execution_count, i, j, upper_bound))++;
+ }
+ } /* end of for*/
+ } /* end of parallel */
+
+ /* check the execution_count array */
+ for (i = 0; i < upper_bound; i++) {
+ for (j = 0; j < i; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j<i are valid, but each should have been executed
+         only once */
+ if (value != 1) {
+ fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ for (j = i; j < upper_bound; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j>=i are invalid and should not have been executed */
+ if (value > 0) {
+ fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ }
+
+#ifndef NO_EFFICIENCY_CHECK
+ /* Ensure the number of iterations executed by each thread is within bounds */
+ for (i = 0; i < num_threads; i++) {
+ unsigned value = *Index(iterations_per_thread, i, 0, 1);
+ if (value < lower_per_chunk || value > upper_per_chunk) {
+ fprintf(stderr,
+ "ERROR: Inefficient Collapse thread %d of %d assigned %i "
+ "iterations; must be between %d and %d\n",
+ i, num_threads, value, lower_per_chunk, upper_per_chunk);
+ return 0;
+ }
+ }
+#endif
+#ifdef VERBOSE
+ fprintf(stderr, "PASSED\r\n");
+#endif
+ return 1;
+}
+
+int main() {
+
+ execution_count = Alloc(MAX_BOUND, MAX_BOUND);
+ iterations_per_thread = Alloc(omp_get_max_threads(), 1);
+
+ for (unsigned j = 0; j < MAX_BOUND; j++) {
+ if (!test(j))
+ return 1;
+ }
+ Free(execution_count);
+ Free(iterations_per_thread);
+ return 0;
+}
diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c
new file mode 100644
index 000000000000..154ee0f69daa
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
+
+#ifndef MAX_BOUND
+#define MAX_BOUND 64
+#endif
+#ifndef _MSC_VER
+#define NO_EFFICIENCY_CHECK
+#endif
+
+/* For correctness, only valid iterations may be executed, and each must be
+   executed exactly once. execution_count stores the number of times each
+   iteration is executed. */
+unsigned *execution_count = NULL;
+/* Stores the number of iterations executed by each thread. */
+unsigned *iterations_per_thread = NULL;
+
+unsigned *Alloc(unsigned bound1, unsigned bound2) {
+ return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned)));
+}
+
+void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) {
+ memset(p, 0, bound1 * bound2 * sizeof(unsigned));
+}
+
+void Free(unsigned *p) { free((void *)p); }
+
+unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) {
+ return &p[i * bound2 + j];
+}
+
+int test(int upper_bound) {
+
+ unsigned total_iterations = upper_bound * (upper_bound + 1) / 2;
+ unsigned num_threads = omp_get_max_threads();
+ unsigned lower_per_chunk = total_iterations / num_threads;
+ unsigned upper_per_chunk =
+ lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0);
+ int i, j;
+
+ omp_set_num_threads(num_threads);
+
+ ZeroOut(execution_count, upper_bound, upper_bound);
+ ZeroOut(iterations_per_thread, num_threads, 1);
+
+#ifdef VERBOSE
+ fprintf(stderr,
+ "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] "
+ "chunks "
+ "loop type lower triangle <,<= - ",
+ num_threads, upper_bound, total_iterations, lower_per_chunk,
+ upper_per_chunk);
+#endif
+
+#pragma omp parallel shared(iterations_per_thread, execution_count)
+ { /* begin of parallel */
+ /* Lower triangular execution_count matrix */
+#pragma omp for schedule(static) collapse(2)
+ for (i = 0; i < upper_bound; i++) {
+ for (j = 0; j <= i; j++) {
+ (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++;
+ (*Index(execution_count, i, j, upper_bound))++;
+ }
+ } /* end of for*/
+ } /* end of parallel */
+
+ /* check the execution_count array */
+ for (i = 0; i < upper_bound; i++) {
+ for (j = 0; j <= i; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j<=i are valid, but each should have been executed
+         only once */
+ if (value != 1) {
+ fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ for (j = i + 1; j < upper_bound; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j>i are invalid and should not have been executed */
+ if (value > 0) {
+ fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ }
+
+#ifndef NO_EFFICIENCY_CHECK
+ /* Ensure the number of iterations executed by each thread is within bounds */
+ for (i = 0; i < num_threads; i++) {
+ unsigned value = *Index(iterations_per_thread, i, 0, 1);
+ if (value < lower_per_chunk || value > upper_per_chunk) {
+ fprintf(stderr,
+ "ERROR: Inefficient Collapse thread %d of %d assigned %i "
+ "iterations; must be between %d and %d\n",
+ i, num_threads, value, lower_per_chunk, upper_per_chunk);
+ return 0;
+ }
+ }
+#endif
+#ifdef VERBOSE
+ fprintf(stderr, "PASSED\r\n");
+#endif
+ return 1;
+}
+
+int main() {
+
+ execution_count = Alloc(MAX_BOUND, MAX_BOUND);
+ iterations_per_thread = Alloc(omp_get_max_threads(), 1);
+
+ for (unsigned j = 0; j < MAX_BOUND; j++) {
+ if (!test(j))
+ return 1;
+ }
+ Free(execution_count);
+ Free(iterations_per_thread);
+ return 0;
+}
diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c b/openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c
new file mode 100644
index 000000000000..452410025be0
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
+
+#ifndef MAX_BOUND
+#define MAX_BOUND 64
+#endif
+#ifndef _MSC_VER
+#define NO_EFFICIENCY_CHECK
+#endif
+
+/* For correctness, only valid iterations may be executed, and each must be
+   executed exactly once. execution_count stores the number of times each
+   iteration is executed. */
+unsigned *execution_count = NULL;
+/* Stores the number of iterations executed by each thread. */
+unsigned *iterations_per_thread = NULL;
+
+unsigned *Alloc(unsigned bound1, unsigned bound2) {
+ return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned)));
+}
+
+void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) {
+ memset(p, 0, bound1 * bound2 * sizeof(unsigned));
+}
+
+void Free(unsigned *p) { free((void *)p); }
+
+unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) {
+ return &p[i * bound2 + j];
+}
+
+int test(unsigned upper_bound) {
+
+ unsigned total_iterations = upper_bound * (upper_bound + 1) / 2;
+ unsigned num_threads = omp_get_max_threads();
+ unsigned lower_per_chunk = total_iterations / num_threads;
+ unsigned upper_per_chunk =
+ lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0);
+ int i, j;
+
+ omp_set_num_threads(num_threads);
+
+ ZeroOut(execution_count, upper_bound, upper_bound);
+ ZeroOut(iterations_per_thread, num_threads, 1);
+
+#ifdef VERBOSE
+ fprintf(stderr,
+ "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] "
+ "chunks "
+ "loop type upper triangle <,< - ",
+ num_threads, upper_bound, total_iterations, lower_per_chunk,
+ upper_per_chunk);
+#endif
+
+#pragma omp parallel shared(iterations_per_thread, execution_count)
+ { /* begin of parallel */
+    /* Upper triangular execution_count matrix */
+#pragma omp for schedule(static) collapse(2)
+ for (i = 0; i < upper_bound; i++) {
+ for (j = i; j < upper_bound; j++) {
+ (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++;
+ (*Index(execution_count, i, j, upper_bound))++;
+ }
+ } /* end of for*/
+ } /* end of parallel */
+
+ /* check the execution_count array */
+ for (i = 0; i < upper_bound; i++) {
+ for (j = i; j < upper_bound; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j>=i are valid, but each should have been executed
+         only once */
+ if (value != 1) {
+ fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ for (j = 0; j < i; j++) {
+ unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j<i are invalid and should not have been executed */
+ if (value > 0) {
+ fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n",
+ i, j, value);
+ return 0;
+ }
+ }
+ }
+
+#ifndef NO_EFFICIENCY_CHECK
+ /* Ensure the number of iterations executed by each thread is within bounds */
+ for (i = 0; i < num_threads; i++) {
+ unsigned value = *Index(iterations_per_thread, i, 0, 1);
+ if (value < lower_per_chunk || value > upper_per_chunk) {
+ fprintf(stderr,
+ "ERROR: Inefficient Collapse thread %d of %d assigned %i "
+ "iterations; must be between %d and %d\n",
+ i, num_threads, value, lower_per_chunk, upper_per_chunk);
+ return 0;
+ }
+ }
+#endif
+#ifdef VERBOSE
+ fprintf(stderr, "PASSED\r\n");
+#endif
+ return 1;
+}
+
+int main() {
+
+ execution_count = Alloc(MAX_BOUND, MAX_BOUND);
+ iterations_per_thread = Alloc(omp_get_max_threads(), 1);
+
+ for (unsigned j = 0; j < MAX_BOUND; j++) {
+ if (!test(j))
+ return 1;
+ }
+ Free(execution_count);
+ Free(iterations_per_thread);
+ return 0;
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7a6bc2dc3202..2cfe61844703 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10841,6 +10841,7 @@ cc_library(
":MemRefDialect",
":Parser",
":SCFDialect",
+ ":MeshShardingInterface",
":SideEffectInterfaces",
":SparseTensorDialect",
":Support",
@@ -10994,10 +10995,13 @@ cc_library(
":MathDialect",
":MemRefDialect",
":MemRefTransforms",
+ ":MeshDialect",
+ ":MeshTransforms",
":Pass",
":SCFDialect",
":SCFTransforms",
":SCFUtils",
+ ":MeshShardingInterface",
":SparseTensorDialect",
":SubsetOpInterface",
":Support",