| author | Florian Mayer <fmayer@google.com> | 2024-03-08 15:46:12 -0800 |
|---|---|---|
| committer | Florian Mayer <fmayer@google.com> | 2024-03-08 15:46:12 -0800 |
| commit | a8be7221b79d4fc4f239c107be4155755c050514 (patch) | |
| tree | 65d1bfae5f15d8fe660287beed69a1a901764379 | |
| parent | b408241d0ad9ce009b49018fe1e9838887abf3c1 (diff) | |
| parent | da4957be2365831c94eab0b52612367c29f1d299 (diff) | |
[𝘀𝗽𝗿] changes introduced through rebase (upstream/users/fmayer/spr/main.nfc-hwasan-also-be-more-consistent-when-getting-pointer-types)
Created using spr 1.3.4
[skip ci]
63 files changed, 2007 insertions, 225 deletions
diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst
index 2b54dffd058a..06af93fd3c15 100644
--- a/clang/docs/LanguageExtensions.rst
+++ b/clang/docs/LanguageExtensions.rst
@@ -5378,6 +5378,7 @@ The following builtin intrinsics can be used in constant expressions:
 * ``__builtin_popcount``
 * ``__builtin_popcountl``
 * ``__builtin_popcountll``
+* ``__builtin_popcountg``
 * ``__builtin_rotateleft8``
 * ``__builtin_rotateleft16``
 * ``__builtin_rotateleft32``
diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
index 42c4a7c4d4bd..fa23c215790f 100644
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -157,6 +157,11 @@ Non-comprehensive list of changes in this release
 - ``__builtin_addc``, ``__builtin_subc``, and the other sizes of those
   builtins are now constexpr and may be used in constant expressions.

+- Added ``__builtin_popcountg`` as a type-generic alternative to
+  ``__builtin_popcount{,l,ll}`` with support for any unsigned integer type. Like
+  the previous builtins, this new builtin is constexpr and may be used in
+  constant expressions.
+
 New Compiler Flags
 ------------------
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index a81131d82c4c..9c703377ca8d 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -706,7 +706,7 @@ def Popcount : Builtin, BitInt_Long_LongLongTemplate {
 def Popcountg : Builtin {
   let Spellings = ["__builtin_popcountg"];
-  let Attributes = [NoThrow, Const, CustomTypeChecking];
+  let Attributes = [NoThrow, Const, Constexpr, CustomTypeChecking];
   let Prototype = "int(...)";
 }
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index d8ca35740fbc..4a7c7755e1d6 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -12483,6 +12483,7 @@ bool IntExprEvaluator::VisitBuiltinCallExpr(const CallExpr *E,
   case Builtin::BI__builtin_popcount:
   case Builtin::BI__builtin_popcountl:
   case Builtin::BI__builtin_popcountll:
+  case Builtin::BI__builtin_popcountg:
   case Builtin::BI__popcnt16: // Microsoft variants of popcount
   case Builtin::BI__popcnt:
   case Builtin::BI__popcnt64: {
diff --git a/clang/test/Sema/constant-builtins-2.c b/clang/test/Sema/constant-builtins-2.c
index 2bdd7b06daab..0935abe4c65f 100644
--- a/clang/test/Sema/constant-builtins-2.c
+++ b/clang/test/Sema/constant-builtins-2.c
@@ -237,6 +237,13 @@ char popcount7[__builtin_popcountl(~0L) == BITSIZE(long) ? 1 : -1];
 char popcount8[__builtin_popcountll(0LL) == 0 ? 1 : -1];
 char popcount9[__builtin_popcountll(0xF0F0LL) == 8 ? 1 : -1];
 char popcount10[__builtin_popcountll(~0LL) == BITSIZE(long long) ? 1 : -1];
+char popcount11[__builtin_popcountg(0U) == 0 ? 1 : -1];
+char popcount12[__builtin_popcountg(0xF0F0U) == 8 ? 1 : -1];
+char popcount13[__builtin_popcountg(~0U) == BITSIZE(int) ? 1 : -1];
+char popcount14[__builtin_popcountg(~0UL) == BITSIZE(long) ? 1 : -1];
+char popcount15[__builtin_popcountg(~0ULL) == BITSIZE(long long) ? 1 : -1];
+char popcount16[__builtin_popcountg(~(unsigned __int128)0) == BITSIZE(__int128) ? 1 : -1];
+char popcount17[__builtin_popcountg(~(unsigned _BitInt(128))0) == BITSIZE(_BitInt(128)) ? 1 : -1];
 char parity1[__builtin_parity(0) == 0 ? 1 : -1];
 char parity2[__builtin_parity(0xb821) == 0 ? 1 : -1];
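The test lines above exercise the builtin at compile time. A minimal standalone sketch of what this change enables, assuming a Clang build that includes it: `__builtin_popcountg` accepts any unsigned integer type and, now being constexpr, can drive `static_assert` checks directly.

```cpp
// Hedged example (not taken from the patch itself): constexpr use of the
// type-generic popcount across different unsigned widths.
#include <cstdint>

constexpr int Ones16 = __builtin_popcountg(static_cast<std::uint16_t>(0xF0F0));
constexpr int Ones64 = __builtin_popcountg(~0ULL);

static_assert(Ones16 == 8, "0xF0F0 has eight bits set");
static_assert(Ones64 == 64, "~0ULL has all 64 bits set");

int main() { return 0; }
```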
diff --git a/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp b/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
index 0dbcec8b5f22..db80eb383885 100644
--- a/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
+++ b/compiler-rt/lib/fuzzer/FuzzerUtilWindows.cpp
@@ -243,7 +243,7 @@ void SetThreadName(std::thread &thread, const std::string &name) {
   HMODULE kbase = GetModuleHandleA("KernelBase.dll");
   proc ThreadNameProc =
       reinterpret_cast<proc>(GetProcAddress(kbase, "SetThreadDescription"));
-  if (proc) {
+  if (ThreadNameProc) {
     std::wstring buf;
     auto sz = MultiByteToWideChar(CP_UTF8, 0, name.data(), -1, nullptr, 0);
     if (sz > 0) {
diff --git a/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h b/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
index bc35f2ab988e..c7c558850a28 100644
--- a/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
+++ b/llvm/include/llvm/DebugInfo/DWARF/DWARFDebugFrame.h
@@ -454,8 +454,8 @@ public:
   /// where a problem occurred in case an error is returned.
   Error parse(DWARFDataExtractor Data, uint64_t *Offset, uint64_t EndOffset);

-  void dump(raw_ostream &OS, DIDumpOptions DumpOpts,
-            unsigned IndentLevel = 1) const;
+  void dump(raw_ostream &OS, DIDumpOptions DumpOpts, unsigned IndentLevel,
+            std::optional<uint64_t> InitialLocation) const;

   void addInstruction(const Instruction &I) { Instructions.push_back(I); }

@@ -524,7 +524,7 @@ private:
   /// Print \p Opcode's operand number \p OperandIdx which has value \p Operand.
   void printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
                     const Instruction &Instr, unsigned OperandIdx,
-                    uint64_t Operand) const;
+                    uint64_t Operand, std::optional<uint64_t> &Address) const;
 };

 /// An entry in either debug_frame or eh_frame. This entry can be a CIE or an
diff --git a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
index eb00e6c4e856..df61f60de4f2 100644
--- a/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
+++ b/llvm/include/llvm/Transforms/Utils/MemoryTaggingSupport.h
@@ -78,6 +78,7 @@ private:
 uint64_t getAllocaSizeInBytes(const AllocaInst &AI);
 void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Align);
+bool isLifetimeIntrinsic(Value *V);
 } // namespace memtag
 } // namespace llvm
diff --git a/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp b/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
index aae1668c1639..aff26824dda1 100644
--- a/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
+++ b/llvm/lib/DebugInfo/DWARF/DWARFDebugFrame.cpp
@@ -630,6 +630,8 @@ Error UnwindTable::parseRows(const CFIProgram &CFIP, UnwindRow &Row,
       if (LRLoc->getLocation() == UnwindLocation::Constant) {
         // Toggle the constant value from 0 to 1 or 1 to 0.
         LRLoc->setConstant(LRLoc->getConstant() ^ 1);
+        Row.getRegisterLocations().setRegisterLocation(
+            AArch64DWARFPAuthRaState, *LRLoc);
       } else {
         return createStringError(
             errc::invalid_argument,
@@ -858,7 +860,8 @@ CFIProgram::getOperandTypes() {
 /// Print \p Opcode's operand number \p OperandIdx which has value \p Operand.
 void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
                               const Instruction &Instr, unsigned OperandIdx,
-                              uint64_t Operand) const {
+                              uint64_t Operand,
+                              std::optional<uint64_t> &Address) const {
   assert(OperandIdx < MaxOperands);
   uint8_t Opcode = Instr.Opcode;
   OperandType Type = getOperandTypes()[Opcode][OperandIdx];
@@ -877,6 +880,7 @@ void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
     break;
   case OT_Address:
     OS << format(" %" PRIx64, Operand);
+    Address = Operand;
     break;
   case OT_Offset:
     // The offsets are all encoded in a unsigned form, but in practice
@@ -888,7 +892,11 @@ void CFIProgram::printOperand(raw_ostream &OS, DIDumpOptions DumpOpts,
     if (CodeAlignmentFactor)
       OS << format(" %" PRId64, Operand * CodeAlignmentFactor);
     else
-      OS << format(" %" PRId64 "*code_alignment_factor" , Operand);
+      OS << format(" %" PRId64 "*code_alignment_factor", Operand);
+    if (Address && CodeAlignmentFactor) {
+      *Address += Operand * CodeAlignmentFactor;
+      OS << format(" to 0x%" PRIx64, *Address);
+    }
     break;
   case OT_SignedFactDataOffset:
     if (DataAlignmentFactor)
@@ -918,13 +926,14 @@ }
 void CFIProgram::dump(raw_ostream &OS, DIDumpOptions DumpOpts,
-                      unsigned IndentLevel) const {
+                      unsigned IndentLevel,
+                      std::optional<uint64_t> Address) const {
   for (const auto &Instr : Instructions) {
     uint8_t Opcode = Instr.Opcode;
     OS.indent(2 * IndentLevel);
     OS << callFrameString(Opcode) << ":";
     for (unsigned i = 0; i < Instr.Ops.size(); ++i)
-      printOperand(OS, DumpOpts, Instr, i, Instr.Ops[i]);
+      printOperand(OS, DumpOpts, Instr, i, Instr.Ops[i], Address);
     OS << '\n';
   }
 }
@@ -975,7 +984,7 @@ void CIE::dump(raw_ostream &OS, DIDumpOptions DumpOpts) const {
     OS << "\n";
   }
   OS << "\n";
-  CFIs.dump(OS, DumpOpts);
+  CFIs.dump(OS, DumpOpts, /*IndentLevel=*/1, /*InitialLocation=*/{});
   OS << "\n";
   if (Expected<UnwindTable> RowsOrErr = UnwindTable::create(this))
@@ -1003,7 +1012,7 @@ void FDE::dump(raw_ostream &OS, DIDumpOptions DumpOpts) const {
   OS << " Format: " << FormatString(IsDWARF64) << "\n";
   if (LSDAAddress)
     OS << format(" LSDA Address: %016" PRIx64 "\n", *LSDAAddress);
-  CFIs.dump(OS, DumpOpts);
+  CFIs.dump(OS, DumpOpts, /*IndentLevel=*/1, InitialLocation);
   OS << "\n";
   if (Expected<UnwindTable> RowsOrErr = UnwindTable::create(this))
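The plumbing above threads an optional running address through the CFI printer: `OT_Address` operands seed it, and factored code offsets advance it. A self-contained sketch of the accumulation rule (the harness and values are invented for illustration; the update expression mirrors the patch):

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>

// Mirrors the patch's update: an advance_loc operand is scaled by the CIE's
// code_alignment_factor and added to the tracked address, which produces the
// new "DW_CFA_advance_loc: N to 0xADDR" dump output.
static void printAdvanceLoc(std::optional<uint64_t> &Address, uint64_t Operand,
                            uint64_t CodeAlignmentFactor) {
  std::printf("DW_CFA_advance_loc: %llu", (unsigned long long)Operand);
  if (Address && CodeAlignmentFactor) {
    *Address += Operand * CodeAlignmentFactor;
    std::printf(" to 0x%llx", (unsigned long long)*Address);
  }
  std::printf("\n");
}

int main() {
  std::optional<uint64_t> Address = 0x4004a0; // e.g. an FDE initial_location
  printAdvanceLoc(Address, 6, 1);  // DW_CFA_advance_loc: 6 to 0x4004a6
  printAdvanceLoc(Address, 10, 1); // DW_CFA_advance_loc: 10 to 0x4004b0
}
```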
diff --git a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
index ef7c517732ef..f2812d2b49bc 100644
--- a/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
+++ b/llvm/lib/Target/AArch64/AArch64StackTagging.cpp
@@ -533,7 +533,9 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
     if (Info.AI->hasName())
       TagPCall->setName(Info.AI->getName() + ".tag");
     // Does not replace metadata, so we don't have to handle DPValues.
-    Info.AI->replaceNonMetadataUsesWith(TagPCall);
+    Info.AI->replaceUsesWithIf(TagPCall, [&](const Use &U) {
+      return !memtag::isLifetimeIntrinsic(U.getUser());
+    });
     TagPCall->setOperand(0, Info.AI);

     // Calls to functions that may return twice (e.g. setjmp) confuse the
@@ -550,7 +552,7 @@ bool AArch64StackTagging::runOnFunction(Function &Fn) {
       uint64_t Size =
           cast<ConstantInt>(Start->getArgOperand(0))->getZExtValue();
       Size = alignTo(Size, kTagGranuleSize);
-      tagAlloca(AI, Start->getNextNode(), Start->getArgOperand(1), Size);
+      tagAlloca(AI, Start->getNextNode(), TagPCall, Size);

       auto TagEnd = [&](Instruction *Node) { untagAlloca(AI, Node, Size); };
       if (!DT || !PDT ||
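The behavioral core of this hunk: uses of the alloca are redirected to the tagged pointer except for lifetime markers, which must keep pointing at the raw alloca so stack coloring still sees the lifetime ranges. A reduced sketch of that pattern (the helper function is illustrative; `replaceUsesWithIf` and `isLifetimeStartOrEnd` are the real LLVM APIs):

```cpp
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"

using namespace llvm;

// Same predicate the patch centralizes as memtag::isLifetimeIntrinsic.
static bool isLifetimeIntrinsicUser(Value *V) {
  auto *II = dyn_cast<IntrinsicInst>(V);
  return II && II->isLifetimeStartOrEnd();
}

// Redirect all uses of AI to TaggedPtr, but leave lifetime.start/end
// operating on the original alloca.
static void redirectNonLifetimeUses(AllocaInst *AI, Value *TaggedPtr) {
  AI->replaceUsesWithIf(TaggedPtr, [](const Use &U) {
    return !isLifetimeIntrinsicUser(U.getUser());
  });
}
```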
diff --git a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
index 289183ecf0f2..88553d49b1b5 100644
--- a/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp
@@ -187,15 +187,15 @@ static cl::opt<bool>
                      cl::desc("Use selective instrumentation"),
                      cl::Hidden, cl::init(false));

-static cl::opt<int> HotPercentileCutoff(
+static cl::opt<int> ClHotPercentileCutoff(
     "hwasan-percentile-cutoff-hot", cl::init(0),
     cl::desc("Alternative hot percentile cuttoff."
              "By default `-profile-summary-cutoff-hot` is used."));

 static cl::opt<float>
-    RandomSkipRate("hwasan-random-skip-rate", cl::init(0),
-                   cl::desc("Probability value in the range [0.0, 1.0] "
-                            "to skip instrumentation of a function."));
+    ClRandomSkipRate("hwasan-random-skip-rate", cl::init(0),
+                     cl::desc("Probability value in the range [0.0, 1.0] "
+                              "to skip instrumentation of a function."));

 STATISTIC(NumTotalFuncs, "Number of total funcs");
 STATISTIC(NumInstrumentedFuncs, "Number of instrumented funcs");
@@ -301,7 +301,7 @@ public:
                        ? ClEnableKhwasan
                        : CompileKernel;
     this->Rng =
-        RandomSkipRate.getNumOccurrences() ? M.createRNG("hwasan") : nullptr;
+        ClRandomSkipRate.getNumOccurrences() ? M.createRNG("hwasan") : nullptr;

     initializeModule();
   }
@@ -1391,11 +1391,6 @@ bool HWAddressSanitizer::instrumentLandingPads(
   return true;
 }

-static bool isLifetimeIntrinsic(Value *V) {
-  auto *II = dyn_cast<IntrinsicInst>(V);
-  return II && II->isLifetimeStartOrEnd();
-}
-
 static DbgAssignIntrinsic *DynCastToDbgAssign(DbgVariableIntrinsic *DVI) {
   return dyn_cast<DbgAssignIntrinsic>(DVI);
 }
@@ -1455,7 +1450,8 @@ bool HWAddressSanitizer::instrumentStack(memtag::StackInfo &SInfo,
     AI->replaceUsesWithIf(Replacement, [AICast, AILong](const Use &U) {
       auto *User = U.getUser();
-      return User != AILong && User != AICast && !isLifetimeIntrinsic(User);
+      return User != AILong && User != AICast &&
+             !memtag::isLifetimeIntrinsic(User);
     });

     // Helper utility for adding DW_OP_LLVM_tag_offset to debug-info records,
@@ -1537,8 +1533,8 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
   NumTotalFuncs++;
   if (CSelectiveInstrumentation) {
-    if (RandomSkipRate.getNumOccurrences()) {
-      std::bernoulli_distribution D(RandomSkipRate);
+    if (ClRandomSkipRate.getNumOccurrences()) {
+      std::bernoulli_distribution D(ClRandomSkipRate);
       if (D(*Rng))
         return;
     } else {
@@ -1547,10 +1543,10 @@ void HWAddressSanitizer::sanitizeFunction(Function &F,
           MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
       if (PSI && PSI->hasProfileSummary()) {
         auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
-        if ((HotPercentileCutoff.getNumOccurrences() &&
-             HotPercentileCutoff >= 0)
+        if ((ClHotPercentileCutoff.getNumOccurrences() &&
+             ClHotPercentileCutoff >= 0)
                 ? PSI->isFunctionHotInCallGraphNthPercentile(
-                      HotPercentileCutoff, &F, BFI)
+                      ClHotPercentileCutoff, &F, BFI)
                 : PSI->isFunctionHotInCallGraph(&F, BFI))
           return;
       } else {
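The renamed `ClRandomSkipRate` drives a simple Bernoulli draw per function. A standalone model of the sampling logic (the seed, counts, and harness are invented; only the `std::bernoulli_distribution` usage mirrors the pass):

```cpp
#include <cstdio>
#include <random>

int main() {
  const float ClRandomSkipRate = 0.3f; // e.g. -hwasan-random-skip-rate=0.3
  std::mt19937 Rng(20240308);          // stand-in for M.createRNG("hwasan")
  std::bernoulli_distribution D(ClRandomSkipRate);

  int Skipped = 0;
  const int Total = 100000;
  for (int I = 0; I < Total; ++I)
    if (D(Rng)) // true => skip instrumenting this function
      ++Skipped;
  std::printf("skipped %d of %d functions (~%.1f%%)\n", Skipped, Total,
              100.0 * Skipped / Total);
}
```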
diff --git a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
index 2ffe89a24584..f4b9b155827a 100644
--- a/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
+++ b/llvm/lib/Transforms/Utils/MemoryTaggingSupport.cpp
@@ -12,6 +12,7 @@
 #include "llvm/Transforms/Utils/MemoryTaggingSupport.h"

+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/CFG.h"
 #include "llvm/Analysis/PostDominators.h"
 #include "llvm/Analysis/StackSafetyAnalysis.h"
@@ -69,14 +70,12 @@ bool forAllReachableExits(const DominatorTree &DT, const PostDominatorTree &PDT,
       ++NumCoveredExits;
     }
   }
-  // If there's a mix of covered and non-covered exits, just put the untag
-  // on exits, so we avoid the redundancy of untagging twice.
   if (NumCoveredExits == ReachableRetVec.size()) {
-    for (auto *End : Ends)
-      Callback(End);
+    for_each(Ends, Callback);
   } else {
-    for (auto *RI : ReachableRetVec)
-      Callback(RI);
+    // If there's a mix of covered and non-covered exits, just put the untag
+    // on exits, so we avoid the redundancy of untagging twice.
+    for_each(ReachableRetVec, Callback);
     // We may have inserted untag outside of the lifetime interval.
     // Signal the caller to remove the lifetime end call for this alloca.
     return false;
@@ -237,5 +236,10 @@ void alignAndPadAlloca(memtag::AllocaInfo &Info, llvm::Align Alignment) {
   Info.AI = NewAI;
 }

+bool isLifetimeIntrinsic(Value *V) {
+  auto *II = dyn_cast<IntrinsicInst>(V);
+  return II && II->isLifetimeStartOrEnd();
+}
+
 } // namespace memtag
 } // namespace llvm
diff --git a/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll b/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
index da2c2985acf9..9464e3447993 100644
--- a/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
+++ b/llvm/test/CodeGen/AArch64/sign-return-address-cfi-negate-ra-state.ll
@@ -213,6 +213,10 @@ attributes #0 = { "sign-return-address"="all" }
 ; CHECK-DUMP-NOT: DW_CFA_remember_state
 ; CHECK-DUMP-NOT: DW_CFA_restore_state

+; CHECK-DUMP: CFA=WSP{{$}}
+; CHECK-DUMP: reg34=1
+; CHECK-DUMP-NOT: reg34=0
+
 ; baz_async
 ; CHECK-DUMP-LABEL: FDE
 ; CHECK-DUMP: Format: DWARF32
@@ -222,9 +226,24 @@ attributes #0 = { "sign-return-address"="all" }
 ; CHECK-DUMP: DW_CFA_restore_state:
 ; CHECK-DUMP: DW_CFA_AARCH64_negate_ra_state:

+; CHECK-DUMP: CFA=WSP{{$}}
+;; First DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=1
+;; Second DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=0
+;; DW_CFA_restore_state:
+; CHECK-DUMP: reg34=1
+;; Third DW_CFA_AARCH64_negate_ra_state:
+; CHECK-DUMP: reg34=0
+; CHECK-DUMP-NOT: reg34=
+
 ; baz_sync
 ; CHECK-DUMP-LABEL: FDE
 ; CHECK-DUMP: DW_CFA_AARCH64_negate_ra_state:
 ; CHECK-DUMP-NOT: DW_CFA_AARCH64_negate_ra_state
 ; CHECK-DUMP-NOT: DW_CFA_remember_state
 ; CHECK-DUMP-NOT: DW_CFA_restore_state
+
+; CHECK-DUMP: CFA=WSP{{$}}
+; CHECK-DUMP: reg34=1
+; CHECK-DUMP-NOT: reg34=0
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll b/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
index d8969fc9bebd..22d177ca3267 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-initializer-merge.ll
@@ -20,10 +20,10 @@ entry:
 ; CHECK-LABEL: define void @OneVarNoInit(
 ; CHECK-DAG: [[X:%.*]] = alloca { i32, [12 x i8] }, align 16
 ; CHECK-DAG: [[TX:%.*]] = call ptr @llvm.aarch64.tagp.{{.*}}(ptr [[X]], {{.*}}, i64 0)
-; CHECK-DAG: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[TX]])
+; CHECK-DAG: call void @llvm.lifetime.start.p0(i64 4, ptr nonnull [[X]])
 ; CHECK-DAG: call void @llvm.aarch64.settag(ptr [[TX]], i64 16)
 ; CHECK-DAG: call void @use(ptr nonnull [[TX]])
-; CHECK-DAG: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[TX]])
+; CHECK-DAG: call void @llvm.lifetime.end.p0(i64 4, ptr nonnull [[X]])

 define void @OneVarInitConst() sanitize_memtag {
 entry:
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll b/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
index 6eb72013fb0e..81349620fb77 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-stack-coloring.ll
@@ -1,20 +1,20 @@
 ; Test that storage for allocas with disjoint lifetimes is reused with stack
 ; tagging.

-; RUN: opt -S -aarch64-stack-tagging %s -o - | \
-; RUN:   llc -no-stack-coloring=false -o - | \
+; RUN: opt -S -aarch64-stack-tagging -stack-tagging-use-stack-safety=0 %s -o - | \
+; RUN:   llc --mattr=+mte -no-stack-coloring=false -o - | \
 ; RUN:   FileCheck %s --check-prefix=COLOR
-; RUN: opt -S -aarch64-stack-tagging %s -o - | \
-; RUN:   llc -no-stack-coloring=true -o - | \
+; RUN: opt -S -aarch64-stack-tagging %s -stack-tagging-use-stack-safety=0 -o - | \
+; RUN:   llc --mattr=+mte -no-stack-coloring=true -o - | \
 ; RUN:   FileCheck %s --check-prefix=NOCOLOR

 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
-target triple = "aarch64-unknown-linux-android29"
+target triple = "aarch64"

-; COLOR: sub sp, sp, #192
-; NOCOLOR: sub sp, sp, #320
+; COLOR: sub sp, sp, #208
+; NOCOLOR: sub sp, sp, #336

-define i32 @myCall_w2(i32 %in) sanitize_hwaddress {
+define i32 @myCall_w2(i32 %in) sanitize_memtag {
 entry:
   %a = alloca [17 x ptr], align 8
   %a2 = alloca [16 x ptr], align 8
diff --git a/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll b/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
index 06f8cd5241eb..aa9cccc58712 100644
--- a/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
+++ b/llvm/test/CodeGen/AArch64/stack-tagging-untag-placement.ll
@@ -27,7 +27,7 @@ S1:
 ; CHECK: call void @llvm.aarch64.settag(ptr %w, i64 48)
 ; CHECK-NOT: settag{{.*}}%v
   call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w) #1
-; CHECK: call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w.tag)
+; CHECK: call void @llvm.lifetime.end.p0(i64 48, ptr nonnull %w)
   %b1 = icmp eq i32 %t1, 0
   br i1 %b1, label %S2, label %S3
 ; CHECK-NOT: settag
diff --git a/llvm/test/CodeGen/PowerPC/pr74951.ll b/llvm/test/CodeGen/PowerPC/pr74951.ll
new file mode 100644
index 000000000000..a0d19fc09cc2
--- /dev/null
+++ b/llvm/test/CodeGen/PowerPC/pr74951.ll
@@ -0,0 +1,54 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc < %s -verify-machineinstrs -ppc-asm-full-reg-names -mtriple=powerpc64-ibm-aix-xcoff | FileCheck %s
+
+%struct.anon = type { i32 }
+
+@b = local_unnamed_addr global %struct.anon { i32 -1 }, align 4
+@g = local_unnamed_addr global [1 x i1] zeroinitializer, align 1
+
+define noundef signext i32 @main() {
+; CHECK-LABEL: main:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: ld r3, L..C0(r2) # @b
+; CHECK-NEXT: lwz r3, 0(r3)
+; CHECK-NEXT: extsw r4, r3
+; CHECK-NEXT: neg r4, r4
+; CHECK-NEXT: andi. r5, r3, 65535
+; CHECK-NEXT: rldicl r4, r4, 1, 63
+; CHECK-NEXT: bne cr0, L..BB0_4
+; CHECK-NEXT: # %bb.1: # %lor.rhs.i.i
+; CHECK-NEXT: xori r5, r4, 1
+; CHECK-NEXT: cmpw r3, r5
+; CHECK-NEXT: crnot 4*cr5+lt, eq
+; CHECK-NEXT: li r3, 1
+; CHECK-NEXT: bc 12, 4*cr5+lt, L..BB0_3
+; CHECK-NEXT: # %bb.2: # %lor.rhs.i.i
+; CHECK-NEXT: li r3, 0
+; CHECK-NEXT: L..BB0_3: # %lor.rhs.i.i
+; CHECK-NEXT: ld r5, L..C1(r2) # @g
+; CHECK-NEXT: stb r3, 0(r5)
+; CHECK-NEXT: L..BB0_4: # %g.exit
+; CHECK-NEXT: ld r5, L..C1(r2) # @g
+; CHECK-NEXT: li r3, 0
+; CHECK-NEXT: stb r4, 0(r5)
+; CHECK-NEXT: blr
+entry:
+  %0 = load i32, ptr @b, align 4
+  %conv4.i = sext i32 %0 to i64
+  %cmp.i = icmp slt i32 %0, 1
+  %conv.i = zext i1 %cmp.i to i32
+  %cmp1.i = icmp ne i32 %0, %conv.i
+  %conv3.i = trunc i32 %0 to i16
+  %tobool.not.i.i = icmp eq i16 %conv3.i, 0
+  br i1 %tobool.not.i.i, label %lor.rhs.i.i, label %g.exit
+
+lor.rhs.i.i: ; preds = %entry
+  store i1 %cmp1.i, ptr @g, align 1
+  br label %g.exit
+
+g.exit: ; preds = %lor.end.i.i
+  %4 = trunc i64 %conv4.i to i32
+  %cmp.i9.i = icmp sgt i32 %4, 0
+  store i1 %cmp.i9.i, ptr @g, align 1
+  ret i32 0
+}
diff --git a/llvm/test/CodeGen/RISCV/forced-atomics.ll b/llvm/test/CodeGen/RISCV/forced-atomics.ll
index 2b198afb47a9..659e0748dd53 100644
--- a/llvm/test/CodeGen/RISCV/forced-atomics.ll
+++ b/llvm/test/CodeGen/RISCV/forced-atomics.ll
@@ -3567,8 +3567,8 @@ define i64 @rmw64_umax_seq_cst(ptr %p) nounwind {
 ; RV32-NEXT: # in Loop: Header=BB51_2 Depth=1
 ; RV32-NEXT: neg a3, a0
 ; RV32-NEXT: and a3, a3, a1
-; RV32-NEXT: sw a1, 4(sp)
 ; RV32-NEXT: sw a4, 0(sp)
+; RV32-NEXT: sw a1, 4(sp)
 ; RV32-NEXT: mv a1, sp
 ; RV32-NEXT: li a4, 5
 ; RV32-NEXT: li a5, 5
diff --git a/llvm/test/CodeGen/RISCV/fpclamptosat.ll b/llvm/test/CodeGen/RISCV/fpclamptosat.ll
index 6bfacc3e9814..630d16e7c888 100644
--- a/llvm/test/CodeGen/RISCV/fpclamptosat.ll
+++ b/llvm/test/CodeGen/RISCV/fpclamptosat.ll
@@ -1324,8 +1324,8 @@ define i64 @ustest_f64i64(double %x) {
 ; RV32IF-NEXT: # %bb.4: # %entry
 ; RV32IF-NEXT: li a0, 1
 ; RV32IF-NEXT: .LBB20_5: # %entry
-; RV32IF-NEXT: lw a3, 8(sp)
-; RV32IF-NEXT: lw a4, 12(sp)
+; RV32IF-NEXT: lw a4, 8(sp)
+; RV32IF-NEXT: lw a3, 12(sp)
 ; RV32IF-NEXT: and a5, a2, a1
 ; RV32IF-NEXT: beqz a5, .LBB20_7
 ; RV32IF-NEXT: # %bb.6: # %entry
@@ -1334,17 +1334,18 @@ define i64 @ustest_f64i64(double %x) {
 ; RV32IF-NEXT: .LBB20_7:
 ; RV32IF-NEXT: snez a1, a0
 ; RV32IF-NEXT: .LBB20_8: # %entry
-; RV32IF-NEXT: and a4, a2, a4
+; RV32IF-NEXT: and a3, a2, a3
 ; RV32IF-NEXT: or a0, a0, a5
-; RV32IF-NEXT: and a2, a2, a3
+; RV32IF-NEXT: and a2, a2, a4
 ; RV32IF-NEXT: bnez a0, .LBB20_10
 ; RV32IF-NEXT: # %bb.9:
-; RV32IF-NEXT: or a0, a2, a4
-; RV32IF-NEXT: snez a1, a0
+; RV32IF-NEXT: snez a0, a3
+; RV32IF-NEXT: snez a1, a2
+; RV32IF-NEXT: or a1, a1, a0
 ; RV32IF-NEXT: .LBB20_10: # %entry
 ; RV32IF-NEXT: neg a1, a1
 ; RV32IF-NEXT: and a0, a1, a2
-; RV32IF-NEXT: and a1, a1, a4
+; RV32IF-NEXT: and a1, a1, a3
 ; RV32IF-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
 ; RV32IF-NEXT: addi sp, sp, 32
 ; RV32IF-NEXT: ret
@@ -1403,8 +1404,8 @@ define i64 @ustest_f64i64(double %x) {
 ; RV32IFD-NEXT: # %bb.4: # %entry
 ; RV32IFD-NEXT: li a0, 1
 ; RV32IFD-NEXT: .LBB20_5: # %entry
-; RV32IFD-NEXT: lw a3, 8(sp)
-; RV32IFD-NEXT: lw a4, 12(sp)
+; RV32IFD-NEXT: lw a4, 8(sp)
+; RV32IFD-NEXT: lw a3, 12(sp)
 ; RV32IFD-NEXT: and a5, a2, a1
 ; RV32IFD-NEXT: beqz a5, .LBB20_7
 ; RV32IFD-NEXT: # %bb.6: # %entry
@@ -1413,17 +1414,18 @@ define i64 @ustest_f64i64(double %x) {
 ; RV32IFD-NEXT: .LBB20_7:
 ; RV32IFD-NEXT: snez a1, a0
 ; RV32IFD-NEXT: .LBB20_8: # %entry
-; RV32IFD-NEXT: and a4, a2, a4
+; RV32IFD-NEXT: and a3, a2, a3
 ; RV32IFD-NEXT: or a0, a0, a5
-; RV32IFD-NEXT: and a2, a2, a3
+; RV32IFD-NEXT: and a2, a2, a4
 ; RV32IFD-NEXT: bnez a0, .LBB20_10
 ; RV32IFD-NEXT: # %bb.9:
-; RV32IFD-NEXT: or a0, a2, a4
-; RV32IFD-NEXT: snez a1, a0
+; RV32IFD-NEXT: snez a0, a3
+; RV32IFD-NEXT: snez a1, a2
+; RV32IFD-NEXT: or a1, a1, a0
 ; RV32IFD-NEXT: .LBB20_10: # %entry
 ; RV32IFD-NEXT: neg a1, a1
 ; RV32IFD-NEXT: and a0, a1, a2
-; RV32IFD-NEXT: and a1, a1, a4
+; RV32IFD-NEXT: and a1, a1, a3
 ; RV32IFD-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
 ; RV32IFD-NEXT: addi sp, sp, 32
 ; RV32IFD-NEXT: ret
@@ -1594,8 +1596,8 @@ define i64 @ustest_f32i64(float %x) {
 ; RV32-NEXT: # %bb.4: # %entry
 ; RV32-NEXT: li a0, 1
 ; RV32-NEXT: .LBB23_5: # %entry
-; RV32-NEXT: lw a3, 8(sp)
-; RV32-NEXT: lw a4, 12(sp)
+; RV32-NEXT: lw a4, 8(sp)
+; RV32-NEXT: lw a3, 12(sp)
 ; RV32-NEXT: and a5, a2, a1
 ; RV32-NEXT: beqz a5, .LBB23_7
 ; RV32-NEXT: # %bb.6: # %entry
@@ -1604,17 +1606,18 @@ define i64 @ustest_f32i64(float %x) {
 ; RV32-NEXT: .LBB23_7:
 ; RV32-NEXT: snez a1, a0
 ; RV32-NEXT: .LBB23_8: # %entry
-; RV32-NEXT: and a4, a2, a4
+; RV32-NEXT: and a3, a2, a3
 ; RV32-NEXT: or a0, a0, a5
-; RV32-NEXT: and a2, a2, a3
+; RV32-NEXT: and a2, a2, a4
 ; RV32-NEXT: bnez a0, .LBB23_10
 ; RV32-NEXT: # %bb.9:
-; RV32-NEXT: or a0, a2, a4
-; RV32-NEXT: snez a1, a0
+; RV32-NEXT: snez a0, a3
+; RV32-NEXT: snez a1, a2
+; RV32-NEXT: or a1, a1, a0
 ; RV32-NEXT: .LBB23_10: # %entry
 ; RV32-NEXT: neg a1, a1
 ; RV32-NEXT: and a0, a1, a2
-; RV32-NEXT: and a1, a1, a4
+; RV32-NEXT: and a1, a1, a3
 ; RV32-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
 ; RV32-NEXT: addi sp, sp, 32
 ; RV32-NEXT: ret
@@ -1847,8 +1850,8 @@ define i64 @ustest_f16i64(half %x) {
 ; RV32-NEXT: # %bb.4: # %entry
 ; RV32-NEXT: li a0, 1
 ; RV32-NEXT: .LBB26_5: # %entry
-; RV32-NEXT: lw a3, 8(sp)
-; RV32-NEXT: lw a4, 12(sp)
+; RV32-NEXT: lw a4, 8(sp)
+; RV32-NEXT: lw a3, 12(sp)
 ; RV32-NEXT: and a5, a2, a1
 ; RV32-NEXT: beqz a5, .LBB26_7
 ; RV32-NEXT: # %bb.6: # %entry
@@ -1857,17 +1860,18 @@ define i64 @ustest_f16i64(half %x) {
 ; RV32-NEXT: .LBB26_7:
 ; RV32-NEXT: snez a1, a0
 ; RV32-NEXT: .LBB26_8: # %entry
-; RV32-NEXT: and a4, a2, a4
+; RV32-NEXT: and a3, a2, a3
 ; RV32-NEXT: or a0, a0, a5
-; RV32-NEXT: and a2, a2, a3
+; RV32-NEXT: and a2, a2, a4
 ; RV32-NEXT: bnez a0, .LBB26_10
 ; RV32-NEXT: # %bb.9:
-; RV32-NEXT: or a0, a2, a4
-; RV32-NEXT: snez a1, a0
+; RV32-NEXT: snez a0, a3
+; RV32-NEXT: snez a1, a2
+; RV32-NEXT: or a1, a1, a0
 ; RV32-NEXT: .LBB26_10: # %entry
 ; RV32-NEXT: neg a1, a1
 ; RV32-NEXT: and a0, a1, a2
-; RV32-NEXT: and a1, a1, a4
+; RV32-NEXT: and a1, a1, a3
 ; RV32-NEXT: lw ra, 28(sp) # 4-byte Folded Reload
 ; RV32-NEXT: addi sp, sp, 32
 ; RV32-NEXT: ret
diff --git a/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test b/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
index 6c049af43efe..2cd281c8d0af 100644
--- a/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
+++ b/llvm/test/DebugInfo/dwarfdump-debug-frame-simple.test
@@ -12,15 +12,15 @@
 ; FRAMES-NEXT: DW_CFA_nop:

 ; FRAMES: 00000014 00000010 00000000 FDE cie=00000000 pc=00000000...00000022
-; FRAMES: DW_CFA_advance_loc: 3
+; FRAMES: DW_CFA_advance_loc: 3 to 0x3
 ; FRAMES-NEXT: DW_CFA_def_cfa_offset: +12
 ; FRAMES-NEXT: DW_CFA_nop:

 ; FRAMES: 00000028 00000014 00000000 FDE cie=00000000 pc=00000030...00000080
-; FRAMES: DW_CFA_advance_loc: 1
+; FRAMES: DW_CFA_advance_loc: 1 to 0x31
 ; FRAMES-NEXT: DW_CFA_def_cfa_offset: +8
 ; FRAMES-NEXT: DW_CFA_offset: {{reg5|EBP}} -8
-; FRAMES-NEXT: DW_CFA_advance_loc: 2
+; FRAMES-NEXT: DW_CFA_advance_loc: 2 to 0x33
 ; FRAMES-NEXT: DW_CFA_def_cfa_register: {{reg5|EBP}}

 ; FRAMES-NOT: CIE
diff --git a/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll b/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
index eb522a0f3f31..aeb1b0e8ebe7 100644
--- a/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/aarch64be.ll
@@ -2,9 +2,9 @@
 ; RUN: opt < %s -passes=asan -S -mtriple=aarch64_be-linux-gnu | FileCheck --check-prefix=CHECK-AARCH64BE %s
 ; REQUIRES: aarch64-registered-target

-define i32 @read_4_bytes(i32* %a) sanitize_address {
+define i32 @read_4_bytes(ptr %a) sanitize_address {
 entry:
-  %tmp1 = load i32, i32* %a, align 4
+  %tmp1 = load i32, ptr %a, align 4
   ret i32 %tmp1
 }
diff --git a/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll b/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
index adfe21135e7a..1d5bfb09ead9 100644
--- a/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
+++ b/llvm/test/Instrumentation/AddressSanitizer/program-addrspace.ll
@@ -16,7 +16,7 @@ target datalayout = "P1"

 define i1 @b(i64 %c) addrspace(1) {
   %cast = inttoptr i64 %c to ptr addrspace(42)
-  %cmp = icmp ugt ptr addrspace(42) %cast, getelementptr inbounds ([1 x i32], ptr addrspace(42) @a, i64 0, i64 0)
+  %cmp = icmp ugt ptr addrspace(42) %cast, @a
   ret i1 %cmp
 }
diff --git a/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll b/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
index 5dfec433f4ec..870e74ccfdac 100644
--- a/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/before-value-profile-lowering.ll
@@ -7,17 +7,17 @@
 target triple = "x86_64-unknown-linux-gnu"

-declare void @llvm.instrprof.increment.step(i8*, i64, i32, i32, i64)
+declare void @llvm.instrprof.increment.step(ptr, i64, i32, i32, i64)

-declare void @llvm.instrprof.value.profile(i8*, i64, i64, i32, i32)
+declare void @llvm.instrprof.value.profile(ptr, i64, i64, i32, i32)

 ; CHECK: @__profd_foo = private global

 @__profn_foo = private constant [3 x i8] c"foo"

-define i32 @foo(i32 ()* ) {
-  %2 = ptrtoint i32 ()* %0 to i64
-  call void @llvm.instrprof.value.profile(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 0, i64 %2, i32 0, i32 0)
-  call void @llvm.instrprof.increment.step(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 0, i32 1, i32 0, i64 0)
+define i32 @foo(ptr ) {
+  %2 = ptrtoint ptr %0 to i64
+  call void @llvm.instrprof.value.profile(ptr @__profn_foo, i64 0, i64 %2, i32 0, i32 0)
+  call void @llvm.instrprof.increment.step(ptr @__profn_foo, i64 0, i32 1, i32 0, i64 0)
   %3 = tail call i32 %0()
   ret i32 %3
 }
diff --git a/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll b/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
index ab9b664a2cff..d40cc2ac02c1 100644
--- a/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/timestamp-coverage.ll
@@ -6,11 +6,11 @@ target triple = "aarch64-unknown-linux-gnu"
 ; CHECK: @__profc_foo = private global [9 x i8] c"\FF\FF\FF\FF\FF\FF\FF\FF\FF", section "__llvm_prf_cnts", comdat, align 8

 define void @_Z3foov() {
-  call void @llvm.instrprof.timestamp(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 9, i32 0)
+  call void @llvm.instrprof.timestamp(ptr @__profn_foo, i64 12345678, i32 9, i32 0)
 ; CHECK: call void @__llvm_profile_set_timestamp(ptr @__profc_foo)
-  call void @llvm.instrprof.cover(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 9, i32 8)
+  call void @llvm.instrprof.cover(ptr @__profn_foo, i64 12345678, i32 9, i32 8)
   ret void
 }

-declare void @llvm.instrprof.timestamp(i8*, i64, i32, i32)
-declare void @llvm.instrprof.cover(i8*, i64, i32, i32)
+declare void @llvm.instrprof.timestamp(ptr, i64, i32, i32)
+declare void @llvm.instrprof.cover(ptr, i64, i32, i32)
diff --git a/llvm/test/Instrumentation/InstrProfiling/timestamp.ll b/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
index aa2393695d6b..c08ba4485fc5 100644
--- a/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
+++ b/llvm/test/Instrumentation/InstrProfiling/timestamp.ll
@@ -6,11 +6,11 @@ target triple = "aarch64-unknown-linux-gnu"
 ; CHECK: @__profc_foo = private global [2 x i64] zeroinitializer, section "__llvm_prf_cnts", comdat, align 8

 define void @_Z3foov() {
-  call void @llvm.instrprof.timestamp(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 2, i32 0)
+  call void @llvm.instrprof.timestamp(ptr @__profn_foo, i64 12345678, i32 2, i32 0)
 ; CHECK: call void @__llvm_profile_set_timestamp(ptr @__profc_foo)
-  call void @llvm.instrprof.increment(i8* getelementptr inbounds ([3 x i8], [3 x i8]* @__profn_foo, i32 0, i32 0), i64 12345678, i32 2, i32 1)
+  call void @llvm.instrprof.increment(ptr @__profn_foo, i64 12345678, i32 2, i32 1)
   ret void
 }

-declare void @llvm.instrprof.timestamp(i8*, i64, i32, i32)
-declare void @llvm.instrprof.increment(i8*, i64, i32, i32)
+declare void @llvm.instrprof.timestamp(ptr, i64, i32, i32)
+declare void @llvm.instrprof.increment(ptr, i64, i32, i32)
diff --git a/llvm/test/Object/Inputs/small.ll b/llvm/test/Object/Inputs/small.ll
index ef68a8c324a3..677f20ade4c5 100644
--- a/llvm/test/Object/Inputs/small.ll
+++ b/llvm/test/Object/Inputs/small.ll
@@ -4,15 +4,15 @@ target triple = "i386-pc-windows"

 define i32 @main() nounwind {
 entry:
-  %call = tail call i32 @puts(i8* getelementptr inbounds ([13 x i8], [13 x i8]* @.str, i32 0, i32 0)) nounwind
-  tail call void bitcast (void (...)* @SomeOtherFunction to void ()*)() nounwind
+  %call = tail call i32 @puts(ptr @.str) nounwind
+  tail call void @SomeOtherFunction() nounwind
   ret i32 0
 }

-declare i32 @puts(i8* nocapture) nounwind
+declare i32 @puts(ptr nocapture) nounwind

 declare void @SomeOtherFunction(...)

 @var = global i32 0
-@llvm.used = appending global [1 x i8*] [i8* bitcast (i32* @var to i8*)], section "llvm.metadata"
-@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* null, i8* null }]
+@llvm.used = appending global [1 x ptr] [ptr @var], section "llvm.metadata"
+@llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr null, ptr null }]
diff --git a/llvm/test/Object/Inputs/trivial.ll b/llvm/test/Object/Inputs/trivial.ll
index 82eabc6389fb..1a6a76298b23 100644
--- a/llvm/test/Object/Inputs/trivial.ll
+++ b/llvm/test/Object/Inputs/trivial.ll
@@ -5,15 +5,15 @@

 define i32 @main() nounwind {
 entry:
-  %call = tail call i32 @puts(i8* getelementptr inbounds ([13 x i8], [13 x i8]* @.str, i32 0, i32 0)) nounwind
-  tail call void bitcast (void (...)* @SomeOtherFunction to void ()*)() nounwind
+  %call = tail call i32 @puts(ptr @.str) nounwind
+  tail call void @SomeOtherFunction() nounwind
   ret i32 0
 }

-declare i32 @puts(i8* nocapture) nounwind
+declare i32 @puts(ptr nocapture) nounwind

 declare void @SomeOtherFunction(...)

 @var = global i32 0
-@llvm.used = appending global [1 x i8*] [i8* bitcast (i32* @var to i8*)], section "llvm.metadata"
-@llvm.global_ctors = appending global [1 x { i32, void ()*, i8* }] [{ i32, void ()*, i8* } { i32 65535, void ()* null, i8* null }]
+@llvm.used = appending global [1 x ptr] [ptr @var], section "llvm.metadata"
+@llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 65535, ptr null, ptr null }]
diff --git a/llvm/test/Object/X86/irsymtab-bad-alias.ll b/llvm/test/Object/X86/irsymtab-bad-alias.ll
index c54436d59219..7f204d1dd157 100644
--- a/llvm/test/Object/X86/irsymtab-bad-alias.ll
+++ b/llvm/test/Object/X86/irsymtab-bad-alias.ll
@@ -11,5 +11,5 @@ target triple = "x86_64-unknown-linux-gnu"

 @g1 = global i32 1
 @g2 = global i32 2
-@a = alias i32, inttoptr(i32 sub (i32 ptrtoint (i32* @g1 to i32),
-                                  i32 ptrtoint (i32* @g2 to i32)) to i32*)
+@a = alias i32, inttoptr(i32 sub (i32 ptrtoint (ptr @g1 to i32),
+                                  i32 ptrtoint (ptr @g2 to i32)) to ptr)
diff --git a/llvm/test/Object/X86/nm-ir.ll b/llvm/test/Object/X86/nm-ir.ll
index e57c6d9a11c6..0324efb2948d 100644
--- a/llvm/test/Object/X86/nm-ir.ll
+++ b/llvm/test/Object/X86/nm-ir.ll
@@ -29,15 +29,15 @@ module asm ".long undef_asm_sym"
 @g3 = common global i32 0
 @g4 = private global i32 42

-@a1 = alias i32, i32* @g1
-@a2 = internal alias i32, i32* @g1
+@a1 = alias i32, ptr @g1
+@a2 = internal alias i32, ptr @g1

-define void ()* @f1() {
+define ptr @f1() {
   call void @f5()
-  ret void ()* null
+  ret ptr null
 }

-@ifunc_f1 = ifunc void (), void ()* ()* @f1
+@ifunc_f1 = ifunc void (), ptr @f1

 define internal void @f2() {
   ret void
diff --git a/llvm/test/Object/dllimport-globalref.ll b/llvm/test/Object/dllimport-globalref.ll
index dd518bc2266c..0a95be20a9d1 100644
--- a/llvm/test/Object/dllimport-globalref.ll
+++ b/llvm/test/Object/dllimport-globalref.ll
@@ -11,4 +11,4 @@ target triple = "x86_64-pc-windows-msvc"
 ; CHECK: U f

 declare dllimport void @f()
-@fp = constant void ()* @f
+@fp = constant ptr @f
diff --git a/llvm/test/Object/dllimport.ll b/llvm/test/Object/dllimport.ll
index afdb4562cc9f..52f583fa2487 100644
--- a/llvm/test/Object/dllimport.ll
+++ b/llvm/test/Object/dllimport.ll
@@ -12,6 +12,6 @@ declare dllimport void @f()

 define void @g() {
   call void @f()
-  store i32 42, i32* @v
+  store i32 42, ptr @v
   ret void
 }
diff --git a/llvm/test/Object/mangle-ir.ll b/llvm/test/Object/mangle-ir.ll
index bd7c3d93b7c9..76442f070385 100644
--- a/llvm/test/Object/mangle-ir.ll
+++ b/llvm/test/Object/mangle-ir.ll
@@ -7,8 +7,8 @@ target datalayout = "m:o"
 ; CHECK-NOT: memcpy

 define void @f() {
-  tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* null, i8* null, i64 0, i1 false)
+  tail call void @llvm.memcpy.p0.p0.i64(ptr null, ptr null, i64 0, i1 false)
   ret void
 }

-declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture, i8* nocapture readonly, i64, i1)
+declare void @llvm.memcpy.p0.p0.i64(ptr nocapture, ptr nocapture readonly, i64, i1)
diff --git a/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll b/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
index d2518f46cc27..c506c9687ec2 100644
--- a/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
+++ b/llvm/test/Object/objc-swift-mixed-imageinfo-macho.ll
@@ -5,11 +5,11 @@

 target triple = "x86_64-apple-macosx10.15.0"

-@llvm.used = appending global [1 x i8*] [i8* bitcast (i16* @__swift_reflection_version to i8*)], section "llvm.metadata", align 8
+@llvm.used = appending global [1 x ptr] [ptr @__swift_reflection_version], section "llvm.metadata", align 8
 @__swift_reflection_version = linkonce_odr hidden constant i16 3

-define i32 @main(i32 %0, i8** %1) #0 {
-  %3 = bitcast i8** %1 to i8*
+define i32 @main(i32 %0, ptr %1) #0 {
+  %3 = bitcast ptr %1 to ptr
   ret i32 0
 }

@@ -25,7 +25,7 @@ attributes #0 = { "frame-pointer"="all" "target-cpu"="penryn" "target-features"=
 !1 = !{!"-lswiftSwiftOnoneSupport"}
 !2 = !{!"-lswiftCore"}
 !3 = !{!"-lobjc"}
-!4 = !{[1 x i8*]* @llvm.used, null, null, i1 false, i1 true}
+!4 = !{ptr @llvm.used, null, null, i1 false, i1 true}
 !5 = !{i32 2, !"SDK Version", [2 x i32] [i32 10, i32 15]}
 !6 = !{i32 1, !"Objective-C Version", i32 2}
 !7 = !{i32 1, !"Objective-C Image Info Version", i32 0}
diff --git a/llvm/test/tools/llvm-readobj/ELF/unwind.test b/llvm/test/tools/llvm-readobj/ELF/unwind.test
index 2deb1a587d24..2e51ec2a61a6 100644
--- a/llvm/test/tools/llvm-readobj/ELF/unwind.test
+++ b/llvm/test/tools/llvm-readobj/ELF/unwind.test
@@ -96,9 +96,9 @@
 # CHECK: Program:
 # CHECK-NEXT: DW_CFA_def_cfa_offset: +16
-# CHECK-NEXT: DW_CFA_advance_loc: 6
+# CHECK-NEXT: DW_CFA_advance_loc: 6 to 0x4004a6
 # CHECK-NEXT: DW_CFA_def_cfa_offset: +24
-# CHECK-NEXT: DW_CFA_advance_loc: 10
+# CHECK-NEXT: DW_CFA_advance_loc: 10 to 0x4004b0
 # CHECK-NEXT: DW_CFA_def_cfa_expression: DW_OP_breg7 +8, DW_OP_breg16 +0, DW_OP_lit15, DW_OP_and, DW_OP_lit11, DW_OP_ge, DW_OP_lit3, DW_OP_shl, DW_OP_plus
 # CHECK-NEXT: DW_CFA_nop:
 # CHECK-NEXT: DW_CFA_nop:
@@ -110,12 +110,12 @@
 # CHECK-NEXT: address_range: 0x10 (end : 0x4005c6)

 # CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 1
+# CHECK-NEXT: DW_CFA_advance_loc: 1 to 0x4005b7
 # CHECK-NEXT: DW_CFA_def_cfa_offset: +16
 # CHECK-NEXT: DW_CFA_offset: reg6 -16
-# CHECK-NEXT: DW_CFA_advance_loc: 3
+# CHECK-NEXT: DW_CFA_advance_loc: 3 to 0x4005ba
 # CHECK-NEXT: DW_CFA_def_cfa_register: reg6
-# CHECK-NEXT: DW_CFA_advance_loc: 11
+# CHECK-NEXT: DW_CFA_advance_loc: 11 to 0x4005c5
 # CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
 # CHECK-NEXT: DW_CFA_nop:
 # CHECK-NEXT: DW_CFA_nop:
@@ -126,15 +126,15 @@
 # CHECK-NEXT: address_range: 0xc7f (end : 0x40124f)

 # CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x4005d5
 # CHECK-NEXT: DW_CFA_def_cfa: reg10 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 9
+# CHECK-NEXT: DW_CFA_advance_loc: 9 to 0x4005de
 # CHECK-NEXT: DW_CFA_expression: reg6 DW_OP_breg6 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x4005e3
 # CHECK-NEXT: DW_CFA_def_cfa_expression: DW_OP_breg6 -8, DW_OP_deref
-# CHECK-NEXT: DW_CFA_advance_loc2: 3174
+# CHECK-NEXT: DW_CFA_advance_loc2: 3174 to 0x401249
 # CHECK-NEXT: DW_CFA_def_cfa: reg10 +0
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x40124e
 # CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
 # CHECK-NEXT: DW_CFA_nop:
 # CHECK-NEXT: DW_CFA_nop:
@@ -146,21 +146,21 @@
 # CHECK-NEXT: address_range: 0x66 (end : 0x4012b6)

 # CHECK: Program:
-# CHECK-NEXT: DW_CFA_advance_loc: 1
+# CHECK-NEXT: DW_CFA_advance_loc: 1 to 0x401251
 # CHECK-NEXT: DW_CFA_def_cfa_offset: +16
 # CHECK-NEXT: DW_CFA_offset: reg6 -16
-# CHECK-NEXT: DW_CFA_advance_loc: 3
+# CHECK-NEXT: DW_CFA_advance_loc: 3 to 0x401254
 # CHECK-NEXT: DW_CFA_def_cfa_register: reg6
-# CHECK-NEXT: DW_CFA_advance_loc: 2
+# CHECK-NEXT: DW_CFA_advance_loc: 2 to 0x401256
 # CHECK-NEXT: DW_CFA_offset: reg15 -24
-# CHECK-NEXT: DW_CFA_advance_loc: 5
+# CHECK-NEXT: DW_CFA_advance_loc: 5 to 0x40125b
 # CHECK-NEXT: DW_CFA_offset: reg14 -32
-# CHECK-NEXT: DW_CFA_advance_loc: 7
+# CHECK-NEXT: DW_CFA_advance_loc: 7 to 0x401262
 # CHECK-NEXT: DW_CFA_offset: reg13 -40
 # CHECK-NEXT: DW_CFA_offset: reg12 -48
-# CHECK-NEXT: DW_CFA_advance_loc: 8
+# CHECK-NEXT: DW_CFA_advance_loc: 8 to 0x40126a
 # CHECK-NEXT: DW_CFA_offset: reg3 -56
-# CHECK-NEXT: DW_CFA_advance_loc1: 75
+# CHECK-NEXT: DW_CFA_advance_loc1: 75 to 0x4012b5
 # CHECK-NEXT: DW_CFA_def_cfa: reg7 +8
 # CHECK-NEXT: DW_CFA_nop:
 # CHECK-NEXT: DW_CFA_nop:
diff --git a/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h b/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
index 687d97abd023..2e89463e68d5 100644
--- a/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
+++ b/llvm/tools/llvm-readobj/DwarfCFIEHPrinter.h
@@ -196,6 +196,7 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
     reportError(std::move(E), ObjF.getFileName());

   for (const dwarf::FrameEntry &Entry : EHFrame) {
+    std::optional<uint64_t> InitialLocation;
     if (const dwarf::CIE *CIE = dyn_cast<dwarf::CIE>(&Entry)) {
       W.startLine() << format("[0x%" PRIx64 "] CIE length=%" PRIu64 "\n",
                               Address + CIE->getOffset(), CIE->getLength());
@@ -214,8 +215,9 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
                               Address + FDE->getLinkedCIE()->getOffset());
       W.indent();

+      InitialLocation = FDE->getInitialLocation();
       W.startLine() << format("initial_location: 0x%" PRIx64 "\n",
-                              FDE->getInitialLocation());
+                              *InitialLocation);
       W.startLine() << format(
           "address_range: 0x%" PRIx64 " (end : 0x%" PRIx64 ")\n",
           FDE->getAddressRange(),
@@ -227,7 +229,8 @@ void PrinterContext<ELFT>::printEHFrame(const Elf_Shdr *EHFrameShdr) const {
     W.indent();
     auto DumpOpts = DIDumpOptions();
     DumpOpts.IsEH = true;
-    Entry.cfis().dump(W.getOStream(), DumpOpts, W.getIndentLevel());
+    Entry.cfis().dump(W.getOStream(), DumpOpts, W.getIndentLevel(),
+                      InitialLocation);
     W.unindent();
     W.unindent();
     W.getOStream() << "\n";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
new file mode 100644
index 000000000000..a69751e072b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/AllInterfaces.h
@@ -0,0 +1,26 @@
+//===- AllInterfaces.h - ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a common entry point for registering all external
+// interface implementations to the linalg dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+#define MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerAllDialectInterfaceImplementations(DialectRegistry &registry);
+} // namespace linalg
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_TRANSFORMS_ALLINTERFACES_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
new file mode 100644
index 000000000000..c57501ea86b7
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h
@@ -0,0 +1,20 @@
+//===- MeshShardingInterfaceImpl.h ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+#define MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace linalg {
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry);
+} // namespace linalg
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LINALG_MESHSHARDINGINTERFACEIMPL_H
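The two new headers above give downstream tools a single registration entry point for linalg's external interface models. A hedged usage sketch, assuming a tool built against this revision:

```cpp
#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::DialectRegistry registry;
  // One call now covers the bufferization, subset, tiling, value-bounds and
  // new mesh-sharding interface implementations for linalg.
  mlir::linalg::registerAllDialectInterfaceImplementations(registry);
  mlir::MLIRContext context(registry);
  return 0;
}
```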
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index fc2acc70381e..9d9b5892e1a5 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -46,6 +46,12 @@ def Mesh_ReductionKind : I32EnumAttr<"ReductionKind",
     I32EnumAttrCase<"Sum", 1, "sum">,
     I32EnumAttrCase<"Max", 2, "max">,
     I32EnumAttrCase<"Min", 3, "min">,
+    I32EnumAttrCase<"Product", 4, "product">,
+    // Arithmetic mean.
+    I32EnumAttrCase<"Average", 5, "average">,
+    I32EnumAttrCase<"BitwiseAnd", 6, "bitwise_and">,
+    I32EnumAttrCase<"BitwiseOr", 7, "bitwise_or">,
+    I32EnumAttrCase<"BitwiseXor", 8, "bitwise_xor">,
     I32EnumAttrCase<"Generic", 100, "generic">
   ]> {
   let genSpecializedAttr = 0;
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index b9cd15e20626..8e1e47546358 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -353,6 +353,10 @@ def Mesh_AllReduceOp : Mesh_CollectiveCommunicationOpBase<"all_reduce", [
     attr-dict `:` type($input) `->` type($result)
   }];
   let hasCanonicalizer = 1;
+  let builders = [
+    OpBuilder<(ins "Value":$input, "StringRef":$mesh,
+      "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
+  ];
 }

 def Mesh_AllSliceOp : Mesh_CollectiveCommunicationOpBase<"all_slice", [
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
index ffc9b6fb18be..ab4df2ab028d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h
@@ -22,6 +22,24 @@ class SymbolTableCollection;

 namespace mesh {

+// Retrieve the mesh axes corresponding to each operation loop iterator based
+// on the provided shardings for the op's operands and results.
+// Assumes that the indexingMaps are projected permutations.
+ShardingArray getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps);
+
+bool isAtLeastOneReductionIteratorSharded(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
+// Get the set of mesh axes that correspond to reduction loop iterators.
+SmallVector<MeshAxis> getReductionMeshAxes(
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators);
+
 // Inserts a clone of the operation that has all ranked tensor
 // arguments/results sharded.
 void spmdizeTriviallyShardableOperation(
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
index aeab28961a4e..be82e2af399d 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Transforms.h
@@ -13,6 +13,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ArrayRef.h"

 namespace mlir {
 class RewritePatternSet;
@@ -37,6 +38,11 @@ TypedValue<IndexType> createCollectiveProcessGroupSize(MeshOp mesh,
                                                        ArrayRef<MeshAxis> axes,
                                                        ImplicitLocOpBuilder &builder);

+// Get process linear index along the given mesh axes.
+TypedValue<IndexType> createProcessLinearIndex(StringRef mesh,
+                                               ArrayRef<MeshAxis> meshAxes,
+                                               ImplicitLocOpBuilder &builder);
+
 } // namespace mesh
 } // namespace mlir
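The new `all_reduce` builder declared above takes the input value, mesh name, axes, and reduction kind directly. A sketch of how spmdization code could call it (the generated builder's exact signature is inferred from the ODS declaration and is an assumption, as is the helper around it):

```cpp
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;

// Hypothetical helper: all-reduce a shard-local value over the reduction axes.
static Value allReduceSum(ImplicitLocOpBuilder &builder, Value shardLocal,
                          StringRef meshName,
                          ArrayRef<mesh::MeshAxis> reductionAxes) {
  // Matches: OpBuilder<(ins "Value":$input, "StringRef":$mesh,
  //            "ArrayRef<MeshAxis>":$meshAxes, "ReductionKind":$reduction)>
  return builder
      .create<mesh::AllReduceOp>(shardLocal, meshName, reductionAxes,
                                 mesh::ReductionKind::Sum)
      .getResult();
}
```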
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 50f6f6de5c28..6c8a170a03c7 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -216,6 +216,14 @@ public:
         {TypeID::get<ConcreteT>(), InterfaceT::getInterfaceID()});
   }

+  // Declare the same interface for multiple types.
+  // Example:
+  // declarePromisedInterfaces<FunctionOpInterface, MyFuncType1, MyFuncType2>()
+  template <typename InterfaceT, typename... ConcreteT>
+  void declarePromisedInterfaces() {
+    (declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
+  }
+
   /// Checks if the given interface, which is attempting to be used, is a
   /// promised interface of this dialect that has yet to be implemented. If so,
   /// emits a fatal error. `interfaceName` is an optional string that contains a
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 838bd03622a6..21775e11e071 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -43,10 +43,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
-#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MPI/IR/MPI.h"
@@ -157,10 +154,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
   cf::registerBufferizableOpInterfaceExternalModels(registry);
   cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
   gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
-  linalg::registerBufferizableOpInterfaceExternalModels(registry);
-  linalg::registerSubsetOpInterfaceExternalModels(registry);
-  linalg::registerTilingInterfaceExternalModels(registry);
-  linalg::registerValueBoundsOpInterfaceExternalModels(registry);
+  linalg::registerAllDialectInterfaceImplementations(registry);
   memref::registerAllocationOpInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   memref::registerValueBoundsOpInterfaceExternalModels(registry);
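The variadic `declarePromisedInterfaces` above is a C++17 fold over the existing single-type `declarePromisedInterface`. A tiny standalone model of the expansion (the types and the print statement are placeholders, not MLIR code):

```cpp
#include <cstdio>

struct ShardingInterface {};
struct MatmulOp {};
struct FillOp {};

template <typename ConcreteT, typename InterfaceT>
void declarePromisedInterface() {
  std::printf("promised one interface for one op type\n");
}

// Fold expression: one call per ConcreteT, all with the same InterfaceT.
template <typename InterfaceT, typename... ConcreteT>
void declarePromisedInterfaces() {
  (declarePromisedInterface<ConcreteT, InterfaceT>(), ...);
}

int main() {
  declarePromisedInterfaces<ShardingInterface, MatmulOp, FillOp>(); // 2 calls
}
```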
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 01fde101ef3c..83198c9b0db5 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1085,6 +1085,39 @@ struct ConversionConfig {
   /// IR during an analysis conversion and only pre-existing operations are
   /// added to the set.
   DenseSet<Operation *> *legalizableOps = nullptr;
+
+  /// An optional listener that is notified about all IR modifications in case
+  /// dialect conversion succeeds. If the dialect conversion fails and no IR
+  /// modifications are visible (i.e., they were all rolled back), no
+  /// notifications are sent.
+  ///
+  /// Note: Notifications are sent in a delayed fashion, when the dialect
+  /// conversion is guaranteed to succeed. At that point, some IR modifications
+  /// may already have been materialized. Consequently, operations/blocks that
+  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+  /// guaranteed to be valid pointers and accessing op names is allowed. But
+  /// there are no guarantees about the state of ops/blocks at the time that a
+  /// callback is triggered.)
+  ///
+  /// Example: Consider a dialect conversion a new op ("test.foo") is created
+  /// and inserted, and later moved to another block. (Moving ops also triggers
+  /// "notifyOperationInserted".)
+  ///
+  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+  ///
+  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+  /// immediately performed by the dialect conversion (and rolled back upon
+  /// failure).
+  //
+  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
+  // callback, the previous region/block is provided to the callback, but not
+  // the iterator pointing to the exact location within the region/block. That
+  // is because these notifications are sent with a delay (after the IR has
+  // already been modified) and iterators into past IR state cannot be
+  // represented at the moment.
+  RewriterBase::Listener *listener = nullptr;
 };

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
index f0ac1899bb02..c187563b8f0c 100644
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -25,6 +25,7 @@ add_mlir_dialect_library(MLIRLinalgDialect
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser
+  MLIRShardingInterface
   MLIRSideEffectInterfaces
   MLIRSparseTensorDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
index 5069d43e7db9..027058d4de63 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
@@ -118,6 +119,12 @@ void mlir::linalg::LinalgDialect::initialize() {
       >(namedStructuredOpRegionBuilders);

   addInterfaces<LinalgInlinerInterface>();
+
+  declarePromisedInterface<GenericOp, mesh::ShardingInterface>();
+  declarePromisedInterfaces<mesh::ShardingInterface,
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+                            >();
 }

 LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op,
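For context on the `ConversionConfig::listener` field added earlier in this hunk: a hedged sketch of wiring it up, assuming the `applyPartialConversion` overload that accepts a `ConversionConfig` (the helper function itself is invented for illustration):

```cpp
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Run a partial conversion and let `listener` observe the surviving IR
// changes; per the documentation above, it is only notified on success.
static LogicalResult
convertWithListener(Operation *root, const ConversionTarget &target,
                    const FrozenRewritePatternSet &patterns,
                    RewriterBase::Listener &listener) {
  ConversionConfig config;
  config.listener = &listener;
  return applyPartialConversion(root, target, patterns, config);
}
```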
diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
new file mode 100644
index 000000000000..281d9f220448
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp
@@ -0,0 +1,24 @@
+//===- AllInterfaces.cpp - ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+
+#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
+
+void mlir::linalg::registerAllDialectInterfaceImplementations(
+    DialectRegistry &registry) {
+  registerBufferizableOpInterfaceExternalModels(registry);
+  registerMeshShardingInterfaceExternalModels(registry);
+  registerSubsetOpInterfaceExternalModels(registry);
+  registerTilingInterfaceExternalModels(registry);
+  registerValueBoundsOpInterfaceExternalModels(registry);
+}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 4f47e3b87184..513c54de5d7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgTransforms
+  AllInterfaces.cpp
   BubbleUpExtractSlice.cpp
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
@@ -21,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   InlineScalarOperands.cpp
   Interchange.cpp
   Loops.cpp
+  MeshShardingInterfaceImpl.cpp
   NamedOpConversions.cpp
   Padding.cpp
   Promotion.cpp
@@ -61,12 +63,15 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   MLIRIR
   MLIRMemRefDialect
   MLIRMemRefTransforms
+  MLIRMeshDialect
+  MLIRMeshTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRPass
+  MLIRShardingInterface
   MLIRSubsetOpInterface
   MLIRSparseTensorDialect
   MLIRTensorDialect
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h"
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
+#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Mesh/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include <iterator>
+#include <optional>
+#include <utility>
+
+namespace mlir::linalg {
+
+using MeshAxis = mesh::MeshAxis;
+using ReductionKind = mesh::ReductionKind;
+using MeshShardingAttr = mesh::MeshShardingAttr;
+using ShardingArray = mesh::ShardingArray;
+using MeshOp = mesh::MeshOp;
+
+// Returns the corresponding mesh reduction kind for the given arith op.
+static ReductionKind getReductionKind(Operation *op) {
+  return llvm::TypeSwitch<Operation *, ReductionKind>(op)
+      // Floating-point operations.
+      .Case([](arith::AddFOp op) { return ReductionKind::Sum; })
+      .Case([](arith::MulFOp op) { return ReductionKind::Product; })
+      // TODO: handle maxnumf and minnumf.
+      .Case([](arith::MaximumFOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinimumFOp op) { return ReductionKind::Min; })
+      // Integer operations.
+      .Case([](arith::AddIOp op) { return ReductionKind::Sum; })
+      .Case([](arith::OrIOp op) { return ReductionKind::BitwiseOr; })
+      .Case([](arith::XOrIOp op) { return ReductionKind::BitwiseXor; })
+      .Case([](arith::AndIOp op) { return ReductionKind::BitwiseAnd; })
+      // TODO: handle signless, signed and unsigned types properly.
+      // It is assumed that the element type of the collective operands and
+      // result drives the meaning of the reduction kind, i.e. whether it is
+      // signed or unsigned.
+      // The reduction op inside the linalg op may have a different result type
+      // from the element type of the linalg op's result.
+      // Also, signed and unsigned Arith dialect ops may accept signed,
+      // unsigned or signless operands.
+      // Maybe expand the reduction kinds.
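+      // Example (illustrative, not part of the original comment): a combiner
+      // such as
+      //   %0 = arith.addf %a, %b : f32
+      // maps to ReductionKind::Sum, while arith.maxsi and arith.maxui below
+      // both map to ReductionKind::Max despite differing in signedness.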
+      .Case([](arith::MaxUIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinUIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MaxSIOp op) { return ReductionKind::Max; })
+      .Case([](arith::MinSIOp op) { return ReductionKind::Min; })
+      .Case([](arith::MulIOp op) { return ReductionKind::Product; })
+      .Default([](Operation *op) { return ReductionKind::Generic; });
+}
+
+static std::optional<Operation *> getCombinerOp(LinalgOp op) {
+  SmallVector<Operation *> combinerOps;
+  Value reducedValue = matchReduction(op.getRegionOutputArgs(), 0, combinerOps);
+  if (!reducedValue || combinerOps.size() != 1) {
+    return std::nullopt;
+  }
+
+  return combinerOps[0];
+}
+
+static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
+  std::optional<Operation *> reductionOp = getCombinerOp(op);
+  if (!reductionOp) {
+    return ReductionKind::Generic;
+  }
+  [[maybe_unused]] Type resultElementType =
+      llvm::cast<RankedTensorType>(op->getResult(0).getType()).getElementType();
+  // TODO: handle the case when the result type of the reduction op does not
+  // match the element type of the result tensor.
+  // Would it make sense at all?
+  assert(resultElementType == reductionOp.value()->getResult(0).getType());
+  return getReductionKind(reductionOp.value());
+}
+
+static MeshOp getMesh(Operation *op,
+                      ArrayRef<MeshShardingAttr> operandShardings,
+                      ArrayRef<MeshShardingAttr> resultShardings,
+                      SymbolTableCollection &symbolTable) {
+  for (MeshShardingAttr sharding : operandShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  for (MeshShardingAttr sharding : resultShardings) {
+    if (sharding) {
+      return mesh::getMesh(op, sharding.getMesh(), symbolTable);
+    }
+  }
+
+  assert(false);
+  return nullptr;
+}
+
+// Choose the operand based on the current process index along the reduction
+// mesh axes.
+// We need to use the initial value only once to avoid including it in the
+// reduction multiple times.
+// In each process group, only the leading process (the one with linear index
+// 0) uses the original operand; the other processes use the reduction
+// operation's neutral tensor.
+static Value createDestinationPassingStyleInitOperand(
+    LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
+    MeshOp meshOp, ImplicitLocOpBuilder &builder) {
+  Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
+      meshOp.getSymName(), reductionMeshAxes, builder);
+  Value zero = builder.create<arith::ConstantIndexOp>(0);
+  Value isLeadProcess = builder.create<arith::CmpIOp>(
+      builder.getI1Type(), arith::CmpIPredicate::eq,
+      processLinearIndexInReductionGroup, zero);
+  scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(),
+                                             isLeadProcess, true, true);
+  // Then block.
+  {
+    OpBuilder::InsertionGuard insertionGuard(builder);
+    builder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+    builder.create<scf::YieldOp>(spmdizedOperand);
+  }
+
+  // Else block.
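+  // Illustrative note (not in the original source): the else branch built
+  // below yields the reduction's neutral element tensor, e.g. a zero-filled
+  // tensor for a sum reduction, so that non-leading processes do not
+  // contribute the initial value a second time.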
+ { + OpBuilder::InsertionGuard insertionGuard(builder); + builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); + SmallVector<OpFoldResult> shape = + tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); + PartialReductionOpInterface partialReductionIface = + llvm::cast<PartialReductionOpInterface>(op.getOperation()); + FailureOr<Operation *> reductionNeutralTensorOp = + partialReductionIface.generateInitialTensorForPartialReduction( + builder, builder.getLoc(), shape, {}); + assert(succeeded(reductionNeutralTensorOp)); + builder.create<scf::YieldOp>( + reductionNeutralTensorOp.value()->getResult(0)); + } + return ifOp.getResult(0); +} + +// Create the DPS init operands for the spmdized Linalg op. +// Return all the new spmdized operands. +static SmallVector<Value> createDestinationPassingStyleInitOperands( + LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + // TODO: add support for multiple destination passing style initial value + // operands. + // PartialReductionOpInterface::generateInitialTensorForPartialReduction + // needs to also support multiple DPS initial operands. + SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); + auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); + Value spmdizedInitOperand = + spmdizationMap.lookup(op->getOperands()[operandIdx]); + newOperands[operandIdx] = createDestinationPassingStyleInitOperand( + op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); + return newOperands; +} + +static void createAllReduceForResultWithoutPartialSharding( + Value unshardedLinalgOpResult, ArrayRef<MeshAxis> opReductionMeshAxes, + MeshShardingAttr resultSharding, ReductionKind reductionKind, + IRMapping &spmdizationMap, ImplicitLocOpBuilder &builder) { + SmallVector<MeshAxis> allReduceMeshAxes; + llvm::copy_if(opReductionMeshAxes, std::back_inserter(allReduceMeshAxes), + [&resultSharding](MeshAxis axis) { + return !llvm::is_contained(resultSharding.getPartialAxes(), + axis); + }); + if (allReduceMeshAxes.empty()) { + return; + } + + Value spmdizedLinalgOpResult = spmdizationMap.lookup(unshardedLinalgOpResult); + Value reducedValue = builder.create<mesh::AllReduceOp>( + spmdizedLinalgOpResult, resultSharding.getMesh().getValue(), + allReduceMeshAxes, reductionKind); + spmdizationMap.map(unshardedLinalgOpResult, reducedValue); +} + +static void createAllReduceForResultsWithoutPartialShardings( + LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, + ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap, + ImplicitLocOpBuilder &builder) { + ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); + for (auto [unshardedLinalgOpResult, resultSharding] : + llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { + createAllReduceForResultWithoutPartialSharding( + unshardedLinalgOpResult, opReductionMeshAxes, resultSharding, + reductionKind, spmdizationMap, builder); + } +} + +static void spmdizeLinalgOpWithShardedReduction( + LinalgOp op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, + IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, + ImplicitLocOpBuilder &builder) { + MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); + 
SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( + loopIteratorTypes, meshAxisAssignmentForLoopIterators); + SmallVector<Value> spmdizedLinalgOpOperands = + createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, + reductionMeshAxes, + spmdizationMap, builder); + // We must not change the operand mappings of the original spmdizationMap as + // they are the mappings for the whole spmdization blob and may be used by + // others. + IRMapping internalSpmdizationMap; + for (auto [unshardedOperand, spmdizedOperand] : + llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { + internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); + } + spmdizeTriviallyShardableOperation( + *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, + internalSpmdizationMap, symbolTable, builder); + for (Value result : op->getResults()) { + spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); + } + + // Handle partial shardings. + createAllReduceForResultsWithoutPartialShardings( + op, reductionMeshAxes, resultShardings, spmdizationMap, builder); +} + +namespace { + +// ShardingInterface for ops that implement LinalgStructuredInterface. +// The supported ops are only those where the indexing maps are projected +// permutations. +template <typename Op> +struct StructuredOpShardingInterface + : public mesh::ShardingInterface::ExternalModel< + StructuredOpShardingInterface<Op>, Op> { + SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { + return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); + } + + SmallVector<AffineMap> getIndexingMaps(Operation *op) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + SmallVector<AffineMap> res = linalgOp.getIndexingMapsArray(); + + // Results must have the same indexing as destination passing style initial + // operands. + for (int64_t i = 0; i < linalgOp.getNumDpsInits(); ++i) { + res.push_back(res[linalgOp.getDpsInitOperand(i)->getOperandNumber()]); + } + + return res; + } + + LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, + ArrayRef<MeshShardingAttr> operandShardings, + ArrayRef<MeshShardingAttr> resultShardings, + IRMapping &spmdizationMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + LinalgOp linalgOp = llvm::cast<LinalgOp>(op); + + SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); + bool allIndexingMapsAreProjectedPermutation = + llvm::all_of(indexingMaps, [](AffineMap map) { + return map.isProjectedPermutation(); + }); + if (!allIndexingMapsAreProjectedPermutation) { + // TODO: handle non-projected permutations. 
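+      // Illustrative note: a projected permutation is an affine map that
+      // selects and permutes a subset of the loop dimensions without any
+      // arithmetic, e.g. affine_map<(d0, d1, d2) -> (d2, d0)>; a map such as
+      // affine_map<(d0, d1) -> (d0 + d1)> is not one and is rejected here.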
+      return op->emitOpError()
+             << "only supports indexing maps that are projected permutations";
+    }
+
+    SmallVector<utils::IteratorType> loopIteratorTypes =
+        linalgOp.getIteratorTypesArray();
+    ShardingArray meshAxisAssignmentForLoopIterators =
+        getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings,
+                                              loopIteratorTypes, indexingMaps);
+    if (mesh::isAtLeastOneReductionIteratorSharded(
+            loopIteratorTypes, meshAxisAssignmentForLoopIterators)) {
+      ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder);
+      spmdizeLinalgOpWithShardedReduction(
+          linalgOp, spmdizedOperands, operandShardings, resultShardings,
+          loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap,
+          symbolTable, implicitLocBuilder);
+    } else {
+      spmdizeTriviallyShardableOperation(*op, spmdizedOperands,
+                                         operandShardings, resultShardings,
+                                         spmdizationMap, symbolTable, builder);
+    }
+
+    return success();
+  }
+};
+
+} // namespace
+
+template <typename OpType>
+static void registerOne(MLIRContext *ctx) {
+  OpType::template attachInterface<StructuredOpShardingInterface<OpType>>(*ctx);
+}
+
+/// Variadic helper function.
+template <typename... OpTypes>
+static void registerAll(MLIRContext *ctx) {
+  (registerOne<OpTypes>(ctx), ...);
+}
+
+void registerMeshShardingInterfaceExternalModels(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) {
+    DialectRegistry registry;
+    registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect,
+                    tensor::TensorDialect>();
+    ctx->appendDialectRegistry(registry);
+    for (StringRef name : registry.getDialectNames())
+      ctx->getOrLoadDialect(name);
+
+    registerOne<linalg::GenericOp>(ctx);
+    registerAll<
+#define GET_OP_LIST
+#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
+        >(ctx);
+  });
+}
+
+} // namespace mlir::linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 8b3119f02e8f..bd870d4f982e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -275,14 +275,6 @@ struct LinalgOpPartialReductionInterface
     ArrayRef<int64_t> oldShape =
         linalgOp.getShape(linalgOp.getDpsInitOperand(0));
 
-    // Extend tile size vector to the rank of the output tensor.
-    SmallVector<Value> tileSizeVector =
-        getValueOrCreateConstantIndexOp(b, loc, sizes);
-    if (tileSizeVector.size() < oldShape.size()) {
-      auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
-      tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
-    }
-
     // Calculate the new shape, we insert the new dimensions based on the index
     // of the reduction dimensions.
     SmallVector<int64_t> newOutputShape;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index 50163880e85f..03f11ad1f949 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -647,6 +647,13 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
   patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context);
 }
 
+void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+                        Value input, StringRef mesh,
+                        ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) {
+  build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input,
+        reduction);
+}
+
 void AllReduceOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
   setNameFn(getResult(), "all_reduce");
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index fe3d7c44413f..9acee5aa8d86 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -539,8 +539,9 @@ static bool areValuesCompatibleWithFullReplicationShardings(
   if (std::size(values) != std::size(shardings)) {
     return false;
   }
-  return llvm::all_of(llvm::zip(std::forward<ValueRange>(values),
-                                std::forward<MeshShardingAttrRage>(shardings)),
+  return llvm::all_of(llvm::zip_equal(
+                          std::forward<ValueRange>(values),
+                          std::forward<MeshShardingAttrRage>(shardings)),
                       [](auto valueAndSharding) {
                         return isValueCompatibleWithFullReplicationSharding(
                             std::get<0>(valueAndSharding),
@@ -563,6 +564,88 @@ void mesh::spmdizeFullyReplicatedOperation(
   builder.clone(op, spmdizationMap);
 }
 
+static void updateMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr,
+    SmallVector<std::optional<SmallVector<MeshAxis>>>
+        &meshAxesAssignmentForLoopIterators) {
+  AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr);
+  unsigned loopIteratorIdx = affineDimExpr.getPosition();
+  if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) {
+    assert(llvm::equal(meshAxesAssignmentForTensorAxis,
+                       *meshAxesAssignmentForLoopIterators[loopIteratorIdx]));
+  } else {
+    meshAxesAssignmentForLoopIterators[loopIteratorIdx] =
+        llvm::to_vector(meshAxesAssignmentForTensorAxis);
+  }
+}
+
+ShardingArray mesh::getMeshAxisAssignmentForLoopIterators(
+    ArrayRef<MeshShardingAttr> operandShardings,
+    ArrayRef<MeshShardingAttr> resultShardings,
+    ArrayRef<utils::IteratorType> loopIteratorTypes,
+    ArrayRef<AffineMap> indexingMaps) {
+  SmallVector<std::optional<SmallVector<MeshAxis>>>
+      meshAxisAssignmentForLoopIterators(loopIteratorTypes.size());
+  SmallVector<MeshShardingAttr> operandAndResultShardings;
+  operandAndResultShardings.reserve(operandShardings.size() +
+                                    resultShardings.size());
+  llvm::append_range(operandAndResultShardings, operandShardings);
+  for (auto [sharding, affineMap] :
+       llvm::zip_equal(operandAndResultShardings, indexingMaps)) {
+    if (!sharding) {
+      continue;
+    }
+    for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] :
+         llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) {
+      updateMeshAxisAssignmentForLoopIterators(
+          meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr,
+          meshAxisAssignmentForLoopIterators);
+    }
+    // Missing trailing split axes means replication on those tensor dimensions.
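+    // For example (illustrative): split axes [[0]] on a rank-2 tensor only
+    // describe dimension 0; dimension 1 falls through to the loop below and
+    // receives an empty mesh-axis assignment, i.e. it is replicated.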
+ for (unsigned i = sharding.getSplitAxes().size(); + i < affineMap.getNumResults(); ++i) { + updateMeshAxisAssignmentForLoopIterators( + {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators); + } + } + + ShardingArray res; + llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res), + [](std::optional<SmallVector<MeshAxis>> &axes) { + if (!axes) { + return SmallVector<MeshAxis>(); + }; + return std::move(*axes); + }); + return res; +} + +bool mesh::isAtLeastOneReductionIteratorSharded( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { + for (auto [loopIteratorType, meshAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (loopIteratorType == utils::IteratorType::reduction && + !meshAxisAssignment.empty()) { + return true; + } + } + return false; +} + +SmallVector<MeshAxis> mesh::getReductionMeshAxes( + ArrayRef<utils::IteratorType> loopIteratorTypes, + ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { + SmallVector<MeshAxis> meshAxes; + for (auto [loopIteratorType, meshAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (loopIteratorType == utils::IteratorType::reduction) { + llvm::append_range(meshAxes, meshAxisAssignment); + } + } + return meshAxes; +} + void mesh::spmdizeTriviallyShardableOperation( Operation &op, ArrayRef<Value> spmdizedOperands, ArrayRef<MeshShardingAttr> operandShardings, @@ -572,7 +655,7 @@ void mesh::spmdizeTriviallyShardableOperation( Operation *newOp = builder.clone(op, spmdizationMap); // Set the result types to the sharded counterparts. for (auto [oldResult, newResult, sharding] : - llvm::zip(op.getResults(), newOp->getResults(), resultShardings)) { + llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { newResult.setType(shardType(newResult.getType(), getMesh(&op, sharding.getMesh(), symbolTable), sharding)); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp index d59b9119dea5..cb13ee404751 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp @@ -208,4 +208,17 @@ createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, .cast<TypedValue<IndexType>>(); } +TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, + ArrayRef<MeshAxis> meshAxes, + ImplicitLocOpBuilder &builder) { + ResultRange processInGroupMultiIndex = + builder.create<ProcessMultiIndexOp>(mesh, meshAxes).getResults(); + Operation::result_range processGroupShape = + builder.create<MeshShapeOp>(mesh, meshAxes).getResult(); + OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( + llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), + llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); + return cast<TypedValue<IndexType>>(processInGroupLinearIndex.get<Value>()); +} + } // namespace mlir::mesh diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d7dc902a9a5e..c1a261eab848 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -204,14 +204,22 @@ public: /// Roll back the rewrite. Operations may be erased during rollback. virtual void rollback() = 0; - /// Commit the rewrite. Operations/blocks may be unlinked during the commit - /// phase, but they must not be erased yet. 
This is because internal dialect
-  /// conversion state (such as `mapping`) may still be using them. Operations/
-  /// blocks must be erased during cleanup.
-  virtual void commit() {}
+  /// Commit the rewrite. At this point, it is certain that the dialect
+  /// conversion will succeed. All IR modifications, except for operation/block
+  /// erasure, must be performed through the given rewriter.
+  ///
+  /// Instead of erasing operations/blocks, they should merely be unlinked
+  /// during the commit phase and finally be erased during the cleanup phase.
+  /// This is because internal dialect conversion state (such as `mapping`)
+  /// may still be using them.
+  ///
+  /// Any IR modification that was already performed before the commit phase
+  /// (e.g., insertion of an op) must be communicated to the listener that may
+  /// be attached to the given rewriter.
+  virtual void commit(RewriterBase &rewriter) {}
 
   /// Cleanup operations/blocks. Cleanup is called after commit.
-  virtual void cleanup() {}
+  virtual void cleanup(RewriterBase &rewriter) {}
 
   Kind getKind() const { return kind; }
 
@@ -221,12 +229,6 @@ protected:
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
-  /// Erase the given op (unless it was already erased).
-  void eraseOp(Operation *op);
-
-  /// Erase the given block (unless it was already erased).
-  void eraseBlock(Block *block);
-
   const ConversionConfig &getConfig() const;
 
   const Kind kind;
@@ -265,6 +267,12 @@ public:
     return rewrite->getKind() == Kind::CreateBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+  }
+
   void rollback() override {
     // Unlink all of the operations within this block, they will be deleted
     // separately.
@@ -311,10 +319,19 @@ public:
     block = nullptr;
   }
 
-  void cleanup() override {
+  void commit(RewriterBase &rewriter) override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
+
+    // Notify the listener that the block is about to be erased.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyBlockErased(block);
+  }
+
+  void cleanup(RewriterBase &rewriter) override {
+    // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
@@ -341,6 +358,13 @@ public:
         firstInlinedInst(sourceBlock->empty() ? nullptr
                                               : &sourceBlock->front()),
        lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+    // If a listener is attached to the dialect conversion, ops must be moved
+    // one-by-one. When they are moved in bulk, notifications cannot be sent
+    // because the ops that used to be in the source block at the time of the
+    // inlining (before the "commit" phase) are unknown at the time when
+    // notifications are sent (which is during the "commit" phase).
+    assert(!getConfig().listener &&
+           "InlineBlockRewrite not supported if listener is attached");
   }
 
   static bool classof(const IRRewrite *rewrite) {
@@ -382,6 +406,16 @@ public:
     return rewrite->getKind() == Kind::MoveBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+ listener->notifyBlockInserted(block, /*previous=*/region, + /*previousIt=*/{}); + } + } + void rollback() override { // Move the block back to its original position. Region::iterator before = @@ -437,7 +471,7 @@ public: LogicalResult materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser); - void commit() override; + void commit(RewriterBase &rewriter) override; void rollback() override; @@ -466,7 +500,7 @@ public: return rewrite->getKind() == Kind::ReplaceBlockArg; } - void commit() override; + void commit(RewriterBase &rewriter) override; void rollback() override; @@ -506,6 +540,17 @@ public: return rewrite->getKind() == Kind::MoveOperation; } + void commit(RewriterBase &rewriter) override { + // The operation was already moved. Just inform the listener. + if (auto *listener = rewriter.getListener()) { + // Note: `previousIt` cannot be passed because this is a delayed + // notification and iterators into past IR state cannot be represented. + listener->notifyOperationInserted( + op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block, + /*insertPt=*/{})); + } + } + void rollback() override { // Move the operation back to its original position. Block::iterator before = @@ -549,7 +594,12 @@ public: "rewrite was neither committed nor rolled back"); } - void commit() override { + void commit(RewriterBase &rewriter) override { + // Notify the listener that the operation was modified in-place. + if (auto *listener = + dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener())) + listener->notifyOperationModified(op); + if (propertiesStorage) { OpaqueProperties propCopy(propertiesStorage); // Note: The operation may have been erased in the mean time, so @@ -600,11 +650,11 @@ public: return rewrite->getKind() == Kind::ReplaceOperation; } - void commit() override; + void commit(RewriterBase &rewriter) override; void rollback() override; - void cleanup() override; + void cleanup(RewriterBase &rewriter) override; const TypeConverter *getConverter() const { return converter; } @@ -629,6 +679,12 @@ public: return rewrite->getKind() == Kind::CreateOperation; } + void commit(RewriterBase &rewriter) override { + // The operation was already created and inserted. Just inform the listener. + if (auto *listener = rewriter.getListener()) + listener->notifyOperationInserted(op, /*previous=*/{}); + } + void rollback() override; }; @@ -666,7 +722,7 @@ public: void rollback() override; - void cleanup() override; + void cleanup(RewriterBase &rewriter) override; /// Return the type converter of this materialization (which may be null). const TypeConverter *getConverter() const { @@ -735,7 +791,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : eraseRewriter(ctx), config(config) {} + : context(ctx), config(config) {} //===--------------------------------------------------------------------===// // State Management @@ -900,6 +956,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } void notifyOperationErased(Operation *op) override { erased.insert(op); } + void notifyBlockErased(Block *block) override { erased.insert(block); } /// Pointers to all erased operations and blocks. @@ -910,8 +967,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // State //===--------------------------------------------------------------------===// - /// This rewriter must be used for erasing ops/blocks. 
- SingleEraseRewriter eraseRewriter; + /// MLIR context. + MLIRContext *context; // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. @@ -955,19 +1012,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { } // namespace detail } // namespace mlir -void IRRewrite::eraseOp(Operation *op) { - rewriterImpl.eraseRewriter.eraseOp(op); -} - -void IRRewrite::eraseBlock(Block *block) { - rewriterImpl.eraseRewriter.eraseBlock(block); -} - const ConversionConfig &IRRewrite::getConfig() const { return rewriterImpl.config; } -void BlockTypeConversionRewrite::commit() { +void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { + // Inform the listener about all IR modifications that have already taken + // place: References to the original block have been replaced with the new + // block. + if (auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>( + rewriter.getListener())) + for (Operation *op : block->getUsers()) + listener->notifyOperationModified(op); + // Process the remapping for each of the original arguments. for (auto [origArg, info] : llvm::zip_equal(origBlock->getArguments(), argInfo)) { @@ -975,7 +1032,7 @@ void BlockTypeConversionRewrite::commit() { if (!info) { if (Value newArg = rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) - origArg.replaceAllUsesWith(newArg); + rewriter.replaceAllUsesWith(origArg, newArg); continue; } @@ -985,8 +1042,8 @@ void BlockTypeConversionRewrite::commit() { // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) { - origArg.replaceAllUsesWith( - rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType())); + rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault( + castValue, origArg.getType())); } } } @@ -1042,13 +1099,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( return success(); } -void ReplaceBlockArgRewrite::commit() { +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType()); if (!repl) return; if (isa<BlockArgument>(repl)) { - arg.replaceAllUsesWith(repl); + rewriter.replaceAllUsesWith(arg, repl); return; } @@ -1057,7 +1114,7 @@ void ReplaceBlockArgRewrite::commit() { // replacement value. Operation *replOp = cast<OpResult>(repl).getOwner(); Block *replBlock = replOp->getBlock(); - arg.replaceUsesWithIf(repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); }); @@ -1065,14 +1122,40 @@ void ReplaceBlockArgRewrite::commit() { void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } -void ReplaceOperationRewrite::commit() { - for (OpResult result : op->getResults()) - if (Value newValue = - rewriterImpl.mapping.lookupOrNull(result, result.getType())) - result.replaceAllUsesWith(newValue); +void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { + auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>( + rewriter.getListener()); + + // Compute replacement values. + SmallVector<Value> replacements = + llvm::map_to_vector(op->getResults(), [&](OpResult result) { + return rewriterImpl.mapping.lookupOrNull(result, result.getType()); + }); + + // Notify the listener that the operation is about to be replaced. 
+ if (listener) + listener->notifyOperationReplaced(op, replacements); + + // Replace all uses with the new values. + for (auto [result, newValue] : + llvm::zip_equal(op->getResults(), replacements)) + if (newValue) + rewriter.replaceAllUsesWith(result, newValue); + + // The original op will be erased, so remove it from the set of unlegalized + // ops. if (getConfig().unlegalizedOps) getConfig().unlegalizedOps->erase(op); + + // Notify the listener that the operation (and its nested operations) was + // erased. + if (listener) { + op->walk<WalkOrder::PostOrder>( + [&](Operation *op) { listener->notifyOperationErased(op); }); + } + // Do not erase the operation yet. It may still be referenced in `mapping`. + // Just unlink it for now and erase it during cleanup. op->getBlock()->getOperations().remove(op); } @@ -1081,7 +1164,9 @@ void ReplaceOperationRewrite::rollback() { rewriterImpl.mapping.erase(result); } -void ReplaceOperationRewrite::cleanup() { eraseOp(op); } +void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { + rewriter.eraseOp(op); +} void CreateOperationRewrite::rollback() { for (Region ®ion : op->getRegions()) { @@ -1100,14 +1185,20 @@ void UnresolvedMaterializationRewrite::rollback() { op->erase(); } -void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); } +void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) { + rewriter.eraseOp(op); +} void ConversionPatternRewriterImpl::applyRewrites() { // Commit all rewrites. + IRRewriter rewriter(context, config.listener); for (auto &rewrite : rewrites) - rewrite->commit(); + rewrite->commit(rewriter); + + // Clean up all rewrites. + SingleEraseRewriter eraseRewriter(context); for (auto &rewrite : rewrites) - rewrite->cleanup(); + rewrite->cleanup(eraseRewriter); } //===----------------------------------------------------------------------===// @@ -1281,7 +1372,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( ConversionPatternRewriter &rewriter, Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { - MLIRContext *ctx = rewriter.getContext(); + OpBuilder::InsertionGuard g(rewriter); // If no arguments are being changed or added, there is nothing to do. unsigned origArgCount = block->getNumArguments(); @@ -1289,14 +1380,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( if (llvm::equal(block->getArgumentTypes(), convertedTypes)) return block; - // Split the block at the beginning to get a new block to use for the updated - // signature. - Block *newBlock = rewriter.splitBlock(block, block->begin()); - block->replaceAllUsesWith(newBlock); - - // Map all new arguments to the location of the argument they originate from. + // Compute the locations of all block arguments in the new block. SmallVector<Location> newLocs(convertedTypes.size(), - Builder(ctx).getUnknownLoc()); + rewriter.getUnknownLoc()); for (unsigned i = 0; i < origArgCount; ++i) { auto inputMap = signatureConversion.getInputMapping(i); if (!inputMap || inputMap->replacementValue) @@ -1306,9 +1392,29 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( newLocs[inputMap->inputNo + j] = origLoc; } - SmallVector<Value, 4> newArgRange( - newBlock->addArguments(convertedTypes, newLocs)); - ArrayRef<Value> newArgs(newArgRange); + // Insert a new block with the converted block argument types and move all ops + // from the old block to the new block. 
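+  // Illustrative note: the new block is created directly after the old block,
+  // so it takes over the old block's position in the region once the old
+  // block is erased.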
+ Block *newBlock = + rewriter.createBlock(block->getParent(), std::next(block->getIterator()), + convertedTypes, newLocs); + + // If a listener is attached to the dialect conversion, ops cannot be moved + // to the destination block in bulk ("fast path"). This is because at the time + // the notifications are sent, it is unknown which ops were moved. Instead, + // ops should be moved one-by-one ("slow path"), so that a separate + // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is + // a bit more efficient, so we try to do that when possible. + bool fastPath = !config.listener; + if (fastPath) { + appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); + newBlock->getOperations().splice(newBlock->end(), block->getOperations()); + } else { + while (!block->empty()) + rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end()); + } + + // Replace all uses of the old block with the new block. + block->replaceAllUsesWith(newBlock); // Remap each of the original arguments as determined by the signature // conversion. @@ -1333,7 +1439,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( } // Otherwise, this is a 1->1+ mapping. - auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); + auto replArgs = + newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); Value newArg; // If this is a 1->1 mapping and the types of new and replacement arguments @@ -1642,10 +1749,31 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, "expected 'source' to have no predecessors"); #endif // NDEBUG - impl->notifyBlockBeingInlined(dest, source, before); + // If a listener is attached to the dialect conversion, ops cannot be moved + // to the destination block in bulk ("fast path"). This is because at the time + // the notifications are sent, it is unknown which ops were moved. Instead, + // ops should be moved one-by-one ("slow path"), so that a separate + // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is + // a bit more efficient, so we try to do that when possible. + bool fastPath = !impl->config.listener; + + if (fastPath) + impl->notifyBlockBeingInlined(dest, source, before); + + // Replace all uses of block arguments. for (auto it : llvm::zip(source->getArguments(), argValues)) replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); - dest->getOperations().splice(before, source->getOperations()); + + if (fastPath) { + // Move all ops at once. + dest->getOperations().splice(before, source->getOperations()); + } else { + // Move op by op. + while (!source->empty()) + moveOpBefore(&source->front(), dest, before); + } + + // Erase the source block. 
eraseBlock(source); } diff --git a/mlir/test/Dialect/Linalg/mesh-spmdization.mlir b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir new file mode 100644 index 000000000000..6d21def8de27 --- /dev/null +++ b/mlir/test/Dialect/Linalg/mesh-spmdization.mlir @@ -0,0 +1,165 @@ +// RUN: mlir-opt \ +// RUN: --mesh-spmdization \ +// RUN: --test-constant-fold \ +// RUN: --split-input-file \ +// RUN: %s | FileCheck %s + +// CHECK: #[[$MAP_IDENTITY_1D:.*]] = affine_map<(d0) -> (d0)> +#map_identity_1d = affine_map<(d0) -> (d0)> + +mesh.mesh @mesh_1d(shape = 2) + +// CHECK-LABEL: func @elementwise_static_1d_mesh_static_1d_tensor +func.func @elementwise_static_1d_mesh_static_1d_tensor( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1xi8>, + %in1: tensor<2xi8>, + // CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<1xi8>, + %in2: tensor<2xi8>, + // CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1xi8> + %dps_out: tensor<2xi8> +// CHECK-SAME: -> tensor<1xi8> { +) -> tensor<2xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<2xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<2xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<2xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + // CHECK: %[[RES:.*]] = linalg.generic { + // CHECK-SAME: indexing_maps = [#[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]], #[[$MAP_IDENTITY_1D]]], + // CHECK-SAME: iterator_types = ["parallel"]} + // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : tensor<1xi8>, tensor<1xi8>) + // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1xi8>) { + %res = linalg.generic { + indexing_maps = [#map_identity_1d, #map_identity_1d, #map_identity_1d], + iterator_types = ["parallel"] + } ins(%in1_shared2, %in2_shared2 : tensor<2xi8>, tensor<2xi8>) + outs(%dps_out_shared2 : tensor<2xi8>) { + ^bb0(%in1_scalar: i8, %in2_scalar: i8, %out: i8): + %res_scalar = arith.muli %in1_scalar, %in2_scalar : i8 + linalg.yield %res_scalar : i8 + } -> tensor<2xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<2xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<2xi8> + // CHECK: return %[[RES]] : tensor<1xi8> + return %res_shared2 : tensor<2xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 4) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding +func.func @matmul_1d_mesh_static_tensors_parallel_iterator_sharding( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<1x3xi8>, + %in1: tensor<4x3xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<3x8xi8>, + %in2: tensor<3x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<1x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<1x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[0]]> : tensor<4x3xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x3xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[]]> : tensor<3x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<3x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[0]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8> + // CHECK: %[[RES:.*]] = linalg.matmul + // CHECK-SAME: ins(%[[IN1]], %[[IN2]] : 
tensor<1x3xi8>, tensor<3x8xi8>) + // CHECK-SAME: outs(%[[DPS_OUT]] : tensor<1x8xi8>) + // CHECK-SAME: -> tensor<1x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x3xi8>, tensor<3x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[0]]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[RES]] : tensor<1x8xi8> + return %res_shared2 : tensor<4x8xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 3) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding +func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, + %in1: tensor<4x6xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, + %in2: tensor<6x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<4x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 + // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index + // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { + // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> + // CHECK: } else { + // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8> + // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8) + // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8> + // CHECK: } + // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) + // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: %[[ALL_REDUCED:.*]] = mesh.all_reduce %[[SHARDED_MATMUL]] on @mesh_1d mesh_axes = [0] : tensor<4x8xi8> -> tensor<4x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[]]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[ALL_REDUCED]] : tensor<4x8xi8> + return %res_shared2 : tensor<4x8xi8> +} + +// ----- + +mesh.mesh @mesh_1d(shape = 3) + +// CHECK-LABEL: func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result +func.func @matmul_1d_mesh_static_tensors_reduction_iterator_sharding_with_partial_result( + // CHECK-SAME: %[[IN1:[A-Za-z0-9_]+]]: tensor<4x2xi8>, + %in1: tensor<4x6xi8>, +// CHECK-SAME: %[[IN2:[A-Za-z0-9_]+]]: tensor<2x8xi8>, + %in2: 
tensor<6x8xi8>, +// CHECK-SAME: %[[DPS_OUT:[A-Za-z0-9_]+]]: tensor<4x8xi8> + %dps_out: tensor<4x8xi8> +// CHECK-SAME: -> tensor<4x8xi8> { +) -> tensor<4x8xi8> { + %in1_shared1 = mesh.shard %in1 to <@mesh_1d, [[], [0]]> : tensor<4x6xi8> + %in1_shared2 = mesh.shard %in1_shared1 to <@mesh_1d, [[], [0]]> annotate_for_users: tensor<4x6xi8> + %in2_shared1 = mesh.shard %in2 to <@mesh_1d, [[0]]> : tensor<6x8xi8> + %in2_shared2 = mesh.shard %in2_shared1 to <@mesh_1d, [[0]]> annotate_for_users: tensor<6x8xi8> + %dps_out_shared1 = mesh.shard %dps_out to <@mesh_1d, [[]]> : tensor<4x8xi8> + %dps_out_shared2 = mesh.shard %dps_out_shared1 to <@mesh_1d, [[]]> annotate_for_users: tensor<4x8xi8> + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[C0_I8:.*]] = arith.constant 0 : i8 + // CHECK-DAG: %[[PROCESS_IDX:.*]] = mesh.process_multi_index on @mesh_1d axes = [0] : index + // CHECK-DAG: %[[MESH_SIZE:.*]] = mesh.mesh_shape @mesh_1d axes = [0] : index + // CHECK: %[[DPS_INIT_OPERAND_CONDITION:.*]] = arith.cmpi eq, %[[PROCESS_IDX]], %[[C0]] : index + // CHECK: %[[DPS_INIT_OPERAND:.*]] = scf.if %[[DPS_INIT_OPERAND_CONDITION]] -> (tensor<4x8xi8>) { + // CHECK: scf.yield %[[DPS_OUT]] : tensor<4x8xi8> + // CHECK: } else { + // CHECK-DAG: %[[EMPTY_TENSOR:.*]] = tensor.empty() : tensor<4x8xi8> + // CHECK: %[[NEUTRAL_ELEMENT_FILLED_TENSOR:.*]] = linalg.fill ins(%[[C0_I8]] : i8) + // CHECK-SAME: outs(%[[EMPTY_TENSOR]] : tensor<4x8xi8>) -> tensor<4x8xi8> + // CHECK: scf.yield %[[NEUTRAL_ELEMENT_FILLED_TENSOR]] : tensor<4x8xi8> + // CHECK: } + // CHECK: %[[SHARDED_MATMUL:.*]] = linalg.matmul ins(%[[IN1]], %[[IN2]] : tensor<4x2xi8>, tensor<2x8xi8>) + // CHECK-SAME: outs(%[[DPS_INIT_OPERAND]] : tensor<4x8xi8>) -> tensor<4x8xi8> + %res = linalg.matmul ins(%in1_shared2, %in2_shared2 : tensor<4x6xi8>, tensor<6x8xi8>) + outs(%dps_out_shared2 : tensor<4x8xi8>) -> tensor<4x8xi8> + %res_shared1 = mesh.shard %res to <@mesh_1d, [[]], partial = sum[0]> : tensor<4x8xi8> + %res_shared2 = mesh.shard %res_shared1 to <@mesh_1d, [[]], partial = sum[0]> annotate_for_users: tensor<4x8xi8> + // CHECK: return %[[SHARDED_MATMUL]] : tensor<4x8xi8> + return %res_shared2 : tensor<4x8xi8> +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index ccdc9fe78ea0..d552f0346644 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,5 +1,10 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s +// CHECK: notifyOperationInserted: test.legal_op_a, was unlinked +// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a +// CHECK-NEXT: notifyOperationModified: func.return +// CHECK-NEXT: notifyOperationErased: test.illegal_op_a + // CHECK-LABEL: verifyDirectPattern func.func @verifyDirectPattern() -> i32 { // CHECK-NEXT: "test.legal_op_a"() <{status = "Success"} @@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 { return %result : i32 } +// ----- + +// CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked +// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c +// CHECK-NEXT: notifyOperationModified: func.return +// CHECK-NEXT: notifyOperationErased: test.illegal_op_c +// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked +// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e +// CHECK-NEXT: notifyOperationErased: test.illegal_op_e + // CHECK-LABEL: verifyLargerBenefit func.func @verifyLargerBenefit() -> i32 { // CHECK-NEXT: 
"test.legal_op_a"() <{status = "Success"} @@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 { return %result : i32 } +// ----- + +// CHECK: notifyOperationModified: func.func +// Note: No block insertion because this function is external and no block +// signature conversion is performed. + // CHECK-LABEL: func private @remap_input_1_to_0() func.func private @remap_input_1_to_0(i16) +// ----- + // CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64) func.func @remap_input_1_to_1(%arg0: i64) { // CHECK-NEXT: "test.valid"{{.*}} : (f64) "test.invalid"(%arg0) : (i64) -> () } -// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64) +// CHECK: func @remap_call_1_to_1(%arg0: f64) func.func @remap_call_1_to_1(%arg0: i64) { // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> () call @remap_input_1_to_1(%arg0) : (i64) -> () @@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) { return } +// ----- + +// Block signature conversion: new block is inserted. +// CHECK: notifyBlockInserted into func.func: was unlinked + +// Contents of the old block are moved to the new block. +// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown + +// The new block arguments are used in "test.return". +// CHECK-NEXT: notifyOperationModified: test.return + +// The old block is erased. +// CHECK-NEXT: notifyBlockErased + +// The function op gets a new type attribute. +// CHECK-NEXT: notifyOperationModified: func.func + +// "test.return" is replaced. +// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked +// CHECK-NEXT: notifyOperationReplaced: test.return +// CHECK-NEXT: notifyOperationErased: test.return + // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16) func.func @remap_input_1_to_N(%arg0: f32) -> f32 { // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () "test.return"(%arg0) : (f32) -> () } +// ----- + // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16) func.func @remap_input_1_to_N_remaining_use(%arg0: f32) { // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 @@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) { "test.return"(%arg0) : (i42) -> () } +// ----- + // CHECK-LABEL: func @remap_input_to_self func.func @remap_input_to_self(%arg0: index) { // CHECK-NOT: test.cast @@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) { "test.invalid"(%arg0, %arg1) : (i64, i64) -> () } +// ----- + // CHECK-LABEL: func @no_remap_nested func.func @no_remap_nested() { // CHECK-NEXT: "foo.region" @@ -82,6 +133,8 @@ func.func @no_remap_nested() { return } +// ----- + // CHECK-LABEL: func @remap_moved_region_args func.func @remap_moved_region_args() { // CHECK-NEXT: return @@ -96,6 +149,8 @@ func.func @remap_moved_region_args() { return } +// ----- + // CHECK-LABEL: func @remap_cloned_region_args func.func @remap_cloned_region_args() { // CHECK-NEXT: return @@ -122,6 +177,8 @@ func.func @remap_drop_region() { return } +// ----- + // CHECK-LABEL: func @dropped_input_in_use func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { // CHECK-NEXT: "test.cast"{{.*}} : () -> i16 @@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) { "work"(%arg) : (i16) -> () } +// ----- + // CHECK-LABEL: func @up_to_date_replacement func.func @up_to_date_replacement(%arg: i8) -> i8 { // CHECK-NEXT: return @@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 { return %repl_2 : i8 } +// ----- + // CHECK-LABEL: func @remove_foldable_op // 
CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: i32) func.func @remove_foldable_op(%arg0 : i32) -> (i32) { @@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) { return %0 : i32 } +// ----- + // CHECK-LABEL: @create_block func.func @create_block() { // Check that we created a block with arguments. @@ -161,6 +224,12 @@ func.func @create_block() { return } +// ----- + +// CHECK: notifyOperationModified: test.recursive_rewrite +// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite +// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite + // CHECK-LABEL: @bounded_recursion func.func @bounded_recursion() { // CHECK: test.recursive_rewrite 0 diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 27eae2ffd694..2da184bc3d85 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -327,8 +327,12 @@ struct TestPatternDriver struct DumpNotifications : public RewriterBase::Listener { void notifyBlockInserted(Block *block, Region *previous, Region::iterator previousIt) override { - llvm::outs() << "notifyBlockInserted into " - << block->getParentOp()->getName() << ": "; + llvm::outs() << "notifyBlockInserted"; + if (block->getParentOp()) { + llvm::outs() << " into " << block->getParentOp()->getName() << ": "; + } else { + llvm::outs() << " into unknown op: "; + } if (previous == nullptr) { llvm::outs() << "was unlinked\n"; } else { @@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener { if (!previous.isSet()) { llvm::outs() << ", was unlinked\n"; } else { - if (previous.getPoint() == previous.getBlock()->end()) { + if (!previous.getPoint().getNodePtr()) { + llvm::outs() << ", was linked, exact position unknown\n"; + } else if (previous.getPoint() == previous.getBlock()->end()) { llvm::outs() << ", was last in block\n"; } else { llvm::outs() << ", previous = " << previous.getPoint()->getName() @@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener { } } } + void notifyBlockErased(Block *block) override { + llvm::outs() << "notifyBlockErased\n"; + } void notifyOperationErased(Operation *op) override { llvm::outs() << "notifyOperationErased: " << op->getName() << "\n"; } + void notifyOperationModified(Operation *op) override { + llvm::outs() << "notifyOperationModified: " << op->getName() << "\n"; + } + void notifyOperationReplaced(Operation *op, ValueRange values) override { + llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n"; + } }; struct TestStrictPatternDriver @@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver if (mode == ConversionMode::Partial) { DenseSet<Operation *> unlegalizedOps; ConversionConfig config; + DumpNotifications dumpNotifications; + config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; if (failed(applyPartialConversion(getOperation(), target, std::move(patterns), config))) { @@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal"); }); + ConversionConfig config; + DumpNotifications dumpNotifications; + config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, - std::move(patterns)))) { + std::move(patterns), config))) { getOperation()->emitRemark() << "applyFullConversion failed"; } return; diff --git a/openmp/runtime/src/kmp_collapse.cpp b/openmp/runtime/src/kmp_collapse.cpp index 2c410ca9b603..569d2c150831 100644 --- a/openmp/runtime/src/kmp_collapse.cpp +++ 
b/openmp/runtime/src/kmp_collapse.cpp @@ -1272,6 +1272,304 @@ void kmp_calc_original_ivs_for_end( } } +/************************************************************************** + * Identify nested loop structure - loops come in the canonical form + * Lower triangle matrix: i = 0; i <= N; i++ {0,0}:{N,0} + * j = 0; j <= 0/-1+1*i; j++ {0,0}:{0/-1,1} + * Upper Triangle matrix + * i = 0; i <= N; i++ {0,0}:{N,0} + * j = 0+1*i; j <= N; j++ {0,1}:{N,0} + * ************************************************************************/ +nested_loop_type_t +kmp_identify_nested_loop_structure(/*in*/ bounds_info_t *original_bounds_nest, + /*in*/ kmp_index_t n) { + // only 2-level nested loops are supported + if (n != 2) { + return nested_loop_type_unkown; + } + // loops must be canonical + KMP_ASSERT( + (original_bounds_nest[0].comparison == comparison_t::comp_less_or_eq) && + (original_bounds_nest[1].comparison == comparison_t::comp_less_or_eq)); + // check outer loop bounds: for triangular need to be {0,0}:{N,0} + kmp_uint64 outer_lb0_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type, + original_bounds_nest[0].lb0_u64); + kmp_uint64 outer_ub0_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type, + original_bounds_nest[0].ub0_u64); + kmp_uint64 outer_lb1_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type, + original_bounds_nest[0].lb1_u64); + kmp_uint64 outer_ub1_u64 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type, + original_bounds_nest[0].ub1_u64); + if (outer_lb0_u64 != 0 || outer_lb1_u64 != 0 || outer_ub1_u64 != 0) { + return nested_loop_type_unkown; + } + // check inner bounds to determine triangle type + kmp_uint64 inner_lb0_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type, + original_bounds_nest[1].lb0_u64); + kmp_uint64 inner_ub0_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type, + original_bounds_nest[1].ub0_u64); + kmp_uint64 inner_lb1_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type, + original_bounds_nest[1].lb1_u64); + kmp_uint64 inner_ub1_u64 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type, + original_bounds_nest[1].ub1_u64); + // lower triangle loop inner bounds need to be {0,0}:{0/-1,1} + if (inner_lb0_u64 == 0 && inner_lb1_u64 == 0 && + (inner_ub0_u64 == 0 || inner_ub0_u64 == -1) && inner_ub1_u64 == 1) { + return nested_loop_type_lower_triangular_matrix; + } + // upper triangle loop inner bounds need to be {0,1}:{N,0} + if (inner_lb0_u64 == 0 && inner_lb1_u64 == 1 && + inner_ub0_u64 == outer_ub0_u64 && inner_ub1_u64 == 0) { + return nested_loop_type_upper_triangular_matrix; + } + return nested_loop_type_unkown; +} + +/************************************************************************** + * SQRT Approximation: https://math.mit.edu/~stevenj/18.335/newton-sqrt.pdf + * Start point is x so the result is always > sqrt(x) + * The method has uniform convergence, PRECISION is set to 0.1 + * ************************************************************************/ +#define level_of_precision 0.1 +double sqrt_newton_approx(/*in*/ kmp_uint64 x) { + double sqrt_old = 0.; + double sqrt_new = (double)x; + do { + sqrt_old = sqrt_new; + sqrt_new = (sqrt_old + x / sqrt_old) / 2; + } while ((sqrt_old - sqrt_new) > level_of_precision); + return sqrt_new; +} + +/************************************************************************** + * Handle lower triangle matrix in the canonical form + * i = 0; i <= N; i++ {0,0}:{N,0} + * j = 0; j <= 0/-1 + 1*i; j++ {0,0}:{0/-1,1} + * ************************************************************************/ +void 
+    /*in*/ kmp_uint32 nth,
+    /*in*/ kmp_uint32 tid,
+    /*in*/ kmp_index_t n,
+    /*in/out*/ bounds_info_t *original_bounds_nest,
+    /*out*/ bounds_info_t *chunk_bounds_nest) {
+
+  // transfer loop types from the original loop to the chunks
+  for (kmp_index_t i = 0; i < n; ++i) {
+    chunk_bounds_nest[i] = original_bounds_nest[i];
+  }
+  // clean up the iv variables
+  kmp_uint64 outer_ub0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+                                    original_bounds_nest[0].ub0_u64);
+  kmp_uint64 outer_lb0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+                                    original_bounds_nest[0].lb0_u64);
+  kmp_uint64 inner_ub0 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+                                    original_bounds_nest[1].ub0_u64);
+  // calculate the chunk's lower and upper bounds:
+  // the total number of iterations in the loop is the sum of the arithmetic
+  // progression from the outer lower to the outer upper bound (inclusive,
+  // since the loop is canonical); note that less_than inner loops
+  // (inner_ub0 == -1) effectively make the progression 1-based, turning
+  // N = (outer_ub0 - outer_lb0 + 1) into N - 1
+  kmp_uint64 outer_iters = (outer_ub0 - outer_lb0 + 1) + inner_ub0;
+  kmp_uint64 iter_total = outer_iters * (outer_iters + 1) / 2;
+  // the current thread's number of iterations:
+  // each thread gets an equal share - the total number of iterations divided
+  // by the number of threads - plus, if there is a remainder, the first
+  // threads with tid up to the remainder get one additional iteration each
+  // to cover it
+  kmp_uint64 iter_current =
+      iter_total / nth + ((tid < (iter_total % nth)) ? 1 : 0);
+  // cumulative number of iterations executed by all the previous threads:
+  // threads with a tid below the remainder will have (iter_total/nth + 1)
+  // elements, and so will all the threads before them, so the cumulative
+  // number of iterations executed by all the previous threads is the current
+  // thread's number of iterations multiplied by the number of previous
+  // threads, which is equal to the current thread's tid; threads with a tid
+  // equal to or above the remainder will have (iter_total/nth) elements, so
+  // the cumulative number of iterations previously executed is its number of
+  // iterations multiplied by the number of previous threads (again equal to
+  // the current thread's tid) PLUS all the remainder iterations that will
+  // have been executed by the previous threads
+  kmp_uint64 iter_before_current =
+      tid * iter_current + ((tid < iter_total % nth) ? 0 : (iter_total % nth));
+  // cumulative number of iterations executed with the current thread is
+  // the cumulative number executed before it plus its own
+  kmp_uint64 iter_with_current = iter_before_current + iter_current;
+  // calculate the outer loop lower bound (lbo), which is the max outer iv
+  // value that gives a number of iterations equal to or just below the total
+  // number of iterations executed by the previous threads; for less_than
+  // (1-based) inner loops (inner_ub0 == -1) this is
+  //   lbo*(lbo-1)/2 <= iter_before_current  =>
+  //   lbo^2 - lbo - 2*iter_before_current <= 0
+  // for less_than_equal (0-based) inner loops (inner_ub0 == 0) it is
+  //   lbo*(lbo+1)/2 <= iter_before_current  =>
+  //   lbo^2 + lbo - 2*iter_before_current <= 0
+  // both cases can be handled similarly, using a parameter to control the
+  // sign in the equation
+  kmp_int64 inner_adjustment = 1 + 2 * inner_ub0;
+  kmp_uint64 lower_bound_outer =
+      (kmp_uint64)(sqrt_newton_approx(inner_adjustment * inner_adjustment +
+                                      8 * iter_before_current) +
+                   inner_adjustment) /
+          2 -
+      inner_adjustment;
+  // calculate the inner loop lower bound, which is the remaining number of
+  // iterations required to hit the total number of iterations executed by
+  // the previous threads, giving the starting point of this thread
+  kmp_uint64 lower_bound_inner =
+      iter_before_current -
+      ((lower_bound_outer + inner_adjustment) * lower_bound_outer) / 2;
+  // calculate the outer loop upper bound using the same approach as for the
+  // lower bound, except using the total number of iterations executed with
+  // the current thread
+  kmp_uint64 upper_bound_outer =
+      (kmp_uint64)(sqrt_newton_approx(inner_adjustment * inner_adjustment +
+                                      8 * iter_with_current) +
+                   inner_adjustment) /
+          2 -
+      inner_adjustment;
+  // calculate the inner loop upper bound, which is the remaining number of
+  // iterations required to hit the total number of iterations executed after
+  // the current thread, giving the starting point of the next thread
+  kmp_uint64 upper_bound_inner =
+      iter_with_current -
+      ((upper_bound_outer + inner_adjustment) * upper_bound_outer) / 2;
+  // adjust the upper bounds down by 1 element to point at the last iteration
+  // of the current thread rather than the first iteration of the next thread
+  if (upper_bound_inner == 0) {
+    // {n,0} => {n-1,n-1}
+    upper_bound_outer -= 1;
+    upper_bound_inner = upper_bound_outer;
+  } else {
+    // {n,m} => {n,m-1} (m != 0)
+    upper_bound_inner -= 1;
+  }
+
+  // assign the values, zeroing out the lb1 and ub1 values since the
+  // iteration space is now one-dimensional
+  chunk_bounds_nest[0].lb0_u64 = lower_bound_outer;
+  chunk_bounds_nest[1].lb0_u64 = lower_bound_inner;
+  chunk_bounds_nest[0].ub0_u64 = upper_bound_outer;
+  chunk_bounds_nest[1].ub0_u64 = upper_bound_inner;
+  chunk_bounds_nest[0].lb1_u64 = 0;
+  chunk_bounds_nest[0].ub1_u64 = 0;
+  chunk_bounds_nest[1].lb1_u64 = 0;
+  chunk_bounds_nest[1].ub1_u64 = 0;
+
+#if 0
+  printf("tid/nth = %d/%d : From [%llu, %llu] To [%llu, %llu] : Chunks %llu/%llu\n",
+         tid, nth, chunk_bounds_nest[0].lb0_u64, chunk_bounds_nest[1].lb0_u64,
+         chunk_bounds_nest[0].ub0_u64, chunk_bounds_nest[1].ub0_u64, iter_current, iter_total);
+#endif
+}
+
+/**************************************************************************
+ * Handle upper triangular matrix in the canonical form
+ * i = 0; i <= N; i++      {0,0}:{N,0}
+ * j = 0+1*i; j <= N; j++  {0,1}:{N,0}
+ * ************************************************************************/
+void kmp_handle_upper_triangle_matrix(
+    /*in*/ kmp_uint32 nth,
+    /*in*/ kmp_uint32 tid,
+    /*in*/ kmp_index_t n,
+    /*in/out*/ bounds_info_t *original_bounds_nest,
+    /*out*/ bounds_info_t *chunk_bounds_nest) {
+
+  // transfer loop types from the original loop to the chunks
+  for (kmp_index_t i = 0; i < n; ++i) {
+    chunk_bounds_nest[i] = original_bounds_nest[i];
+  }
+  // clean up the iv variables
+  kmp_uint64 outer_ub0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+                                    original_bounds_nest[0].ub0_u64);
+  kmp_uint64 outer_lb0 = kmp_fix_iv(original_bounds_nest[0].loop_iv_type,
+                                    original_bounds_nest[0].lb0_u64);
+  kmp_uint64 inner_ub0 = kmp_fix_iv(original_bounds_nest[1].loop_iv_type,
+                                    original_bounds_nest[1].ub0_u64);
+  // calculate the chunk's lower and upper bounds:
+  // the total number of iterations in the loop is the sum of the arithmetic
+  // progression from the outer lower to the outer upper bound (inclusive,
+  // since the loop is canonical); note that less_than inner loops
+  // (inner_ub0 == -1) effectively make the progression 1-based, turning
+  // N = (outer_ub0 - outer_lb0 + 1) into N - 1
+  kmp_uint64 outer_iters = (outer_ub0 - outer_lb0 + 1);
+  kmp_uint64 iter_total = outer_iters * (outer_iters + 1) / 2;
+  // the current thread's number of iterations:
+  // each thread gets an equal share - the total number of iterations divided
+  // by the number of threads - plus, if there is a remainder, the first
+  // threads with tid up to the remainder get one additional iteration each
+  // to cover it
+  kmp_uint64 iter_current =
+      iter_total / nth + ((tid < (iter_total % nth)) ? 1 : 0);
+  // cumulative number of iterations executed by all the previous threads:
+  // threads with a tid below the remainder will have (iter_total/nth + 1)
+  // elements, and so will all the threads before them, so the cumulative
+  // number of iterations executed by all the previous threads is the current
+  // thread's number of iterations multiplied by the number of previous
+  // threads, which is equal to the current thread's tid; threads with a tid
+  // equal to or above the remainder will have (iter_total/nth) elements, so
+  // the cumulative number of iterations previously executed is its number of
+  // iterations multiplied by the number of previous threads (again equal to
+  // the current thread's tid) PLUS all the remainder iterations that will
+  // have been executed by the previous threads
+  kmp_uint64 iter_before_current =
+      tid * iter_current + ((tid < iter_total % nth) ? 0 : (iter_total % nth));
+  // cumulative number of iterations executed with the current thread is
+  // the cumulative number executed before it plus its own
+  kmp_uint64 iter_with_current = iter_before_current + iter_current;
+  // calculate the outer loop lower bound (lbo), which is the max outer iv
+  // value that gives a number of iterations equal to or just below the total
+  // number of iterations executed by the previous threads; for less_than
+  // (1-based) inner loops (inner_ub0 == -1) this is
+  //   lbo*(lbo-1)/2 <= iter_before_current  =>
+  //   lbo^2 - lbo - 2*iter_before_current <= 0
+  // for less_than_equal (0-based) inner loops (inner_ub0 == 0) it is
+  //   lbo*(lbo+1)/2 <= iter_before_current  =>
+  //   lbo^2 + lbo - 2*iter_before_current <= 0
+  // both cases can be handled similarly, using a parameter to control the
+  // sign in the equation
+  kmp_uint64 lower_bound_outer =
+      (kmp_uint64)(sqrt_newton_approx(1 + 8 * iter_before_current) + 1) / 2 - 1;
+  // calculate the inner loop lower bound, which is the remaining number of
+  // iterations required to hit the total number of iterations executed by
+  // the previous threads, giving the starting point of this thread
+  kmp_uint64 lower_bound_inner =
+      iter_before_current - ((lower_bound_outer + 1) * lower_bound_outer) / 2;
+  // calculate the outer loop upper bound using the same approach as for the
+  // lower bound, except using the total number of iterations executed with
+  // the current thread
+  kmp_uint64 upper_bound_outer =
+      (kmp_uint64)(sqrt_newton_approx(1 + 8 * iter_with_current) + 1) / 2 - 1;
+  // calculate the inner loop upper bound, which is the remaining number of
+  // iterations required to hit the total number of iterations executed after
+  // the current thread, giving the starting point of the next thread
+  kmp_uint64 upper_bound_inner =
+      iter_with_current - ((upper_bound_outer + 1) * upper_bound_outer) / 2;
+  // adjust the upper bounds down by 1 element to point at the last iteration
+  // of the current thread rather than the first iteration of the next thread
+  if (upper_bound_inner == 0) {
+    // {n,0} => {n-1,n-1}
+    upper_bound_outer -= 1;
+    upper_bound_inner = upper_bound_outer;
+  } else {
+    // {n,m} => {n,m-1} (m != 0)
+    upper_bound_inner -= 1;
+  }
+
+  // assign the values, reflecting the lower-triangle solution across the
+  // diagonal to map it onto the upper triangle, and zeroing out the lb1 and
+  // ub1 values since the iteration space is now one-dimensional
+  chunk_bounds_nest[0].lb0_u64 = (outer_iters - 1) - upper_bound_outer;
+  chunk_bounds_nest[1].lb0_u64 = (outer_iters - 1) - upper_bound_inner;
+  chunk_bounds_nest[0].ub0_u64 = (outer_iters - 1) - lower_bound_outer;
+  chunk_bounds_nest[1].ub0_u64 = (outer_iters - 1) - lower_bound_inner;
+  chunk_bounds_nest[0].lb1_u64 = 0;
+  chunk_bounds_nest[0].ub1_u64 = 0;
+  chunk_bounds_nest[1].lb1_u64 = 0;
+  chunk_bounds_nest[1].ub1_u64 = 0;
+
+#if 0
+  printf("tid/nth = %d/%d : From [%llu, %llu] To [%llu, %llu] : Chunks %llu/%llu\n",
+         tid, nth, chunk_bounds_nest[0].lb0_u64, chunk_bounds_nest[1].lb0_u64,
+         chunk_bounds_nest[0].ub0_u64, chunk_bounds_nest[1].ub0_u64, iter_current, iter_total);
+#endif
+}
 //----------Init API for non-rectangular loops--------------------------------
 
 // Init API for collapsed loops (static, no chunks defined).
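To make the closed-form math in the two handlers above concrete, here is a minimal standalone sketch. It is illustrative only, not part of the patch, and all names are local to the sketch. It splits the 0-based lower-triangle nest i = 0..7, j = 0..i (inner_adjustment == 1) across three threads, using the same remainder distribution and quadratic bounds as kmp_handle_lower_triangle_matrix; the upper-triangle handler reuses the same result by mirroring every bound to (outer_iters - 1) - bound.

// Standalone demo of the closed-form triangular chunk split (illustrative).
#include <cstdint>
#include <cstdio>

// Newton iteration started from x itself, mirroring the runtime's
// sqrt_newton_approx: requires x >= 1 and never returns less than sqrt(x).
static double approx_sqrt(uint64_t x) {
  double sqrt_old = 0., sqrt_new = (double)x;
  do {
    sqrt_old = sqrt_new;
    sqrt_new = (sqrt_old + x / sqrt_old) / 2;
  } while ((sqrt_old - sqrt_new) > 0.1);
  return sqrt_new;
}

int main() {
  const uint64_t N = 7;   // outer bound, inclusive: i = 0..7
  const uint32_t nth = 3; // number of threads
  const uint64_t outer_iters = N + 1;
  const uint64_t total = outer_iters * (outer_iters + 1) / 2; // 36 iterations
  for (uint32_t tid = 0; tid < nth; ++tid) {
    // Remainder distribution: the first (total % nth) threads get one extra.
    uint64_t rem = total % nth;
    uint64_t mine = total / nth + (tid < rem ? 1 : 0);
    uint64_t before = tid * mine + (tid < rem ? 0 : rem);
    uint64_t with = before + mine;
    // Largest k with k*(k+1)/2 <= before: the outer index where this chunk
    // starts; the inner index is whatever is left over on that row.
    uint64_t lb_o = (uint64_t)(approx_sqrt(1 + 8 * before) + 1) / 2 - 1;
    uint64_t lb_i = before - lb_o * (lb_o + 1) / 2;
    uint64_t ub_o = (uint64_t)(approx_sqrt(1 + 8 * with) + 1) / 2 - 1;
    uint64_t ub_i = with - ub_o * (ub_o + 1) / 2;
    // Step the exclusive end point back by one iteration to make it
    // inclusive, exactly as the runtime handlers do.
    if (ub_i == 0) {
      ub_o -= 1;
      ub_i = ub_o;
    } else {
      ub_i -= 1;
    }
    printf("tid %u: [%llu,%llu] .. [%llu,%llu], %llu iterations\n", tid,
           (unsigned long long)lb_o, (unsigned long long)lb_i,
           (unsigned long long)ub_o, (unsigned long long)ub_i,
           (unsigned long long)mine);
  }
  return 0;
}

This prints [0,0]..[4,1], [4,2]..[6,2], and [6,3]..[7,7], twelve iterations each. Starting Newton's iteration at x keeps every estimate at or above the true square root, so the truncating integer cast can never land one row short at the exact triangular-number boundaries between chunks.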
@@ -1334,6 +1632,19 @@ __kmpc_for_collapsed_init(ident_t *loc, kmp_int32 gtid,
 
   KMP_DEBUG_ASSERT(tid < nth);
 
+  // Handle special cases
+  nested_loop_type_t loop_type =
+      kmp_identify_nested_loop_structure(original_bounds_nest, n);
+  if (loop_type == nested_loop_type_lower_triangular_matrix) {
+    kmp_handle_lower_triangle_matrix(nth, tid, n, original_bounds_nest,
+                                     chunk_bounds_nest);
+    return TRUE;
+  } else if (loop_type == nested_loop_type_upper_triangular_matrix) {
+    kmp_handle_upper_triangle_matrix(nth, tid, n, original_bounds_nest,
+                                     chunk_bounds_nest);
+    return TRUE;
+  }
+
   CollapseAllocator<kmp_uint64> original_ivs_start(n);
 
   if (!kmp_calc_original_ivs_for_start(original_bounds_nest, n,
diff --git a/openmp/runtime/src/kmp_collapse.h b/openmp/runtime/src/kmp_collapse.h
index e4870185645d..1044478554a0 100644
--- a/openmp/runtime/src/kmp_collapse.h
+++ b/openmp/runtime/src/kmp_collapse.h
@@ -45,6 +45,13 @@ enum loop_type_t : kmp_int32 {
   loop_type_int64 = 7
 };
 
+// Loop types for handling the special cases
+enum nested_loop_type_t : kmp_int32 {
+  nested_loop_type_unknown = 0,
+  nested_loop_type_lower_triangular_matrix = 1,
+  nested_loop_type_upper_triangular_matrix = 2
+};
+
 /*!
 @ingroup WORK_SHARING
 * Describes the structure for rectangular nested loops.
@@ -124,14 +131,14 @@ struct bounds_info_t {
 // It's represented in kmp_uint64, but each dimension is calculated in
 // that loop IV type. Also dimensions have to be converted to those types
 // when used in generated code.
-typedef kmp_uint64* kmp_point_t;
+typedef kmp_uint64 *kmp_point_t;
 
 // Array: Number of loop iterations on each nesting level to achieve some point,
 // in expanded space or in original space.
 // OMPTODO: move from using iterations to using offsets (iterations multiplied
 // by steps). For those we need to be careful with the types, as step can be
 // negative, but it'll remove multiplications and divisions in several places.
-typedef kmp_loop_nest_iv_t* kmp_iterations_t;
+typedef kmp_loop_nest_iv_t *kmp_iterations_t;
 
 // Internal struct with additional info:
 template <typename T> struct bounds_info_internalXX_template {
diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c
new file mode 100644
index 000000000000..9d742066cf1f
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLess.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
+
+#ifndef MAX_BOUND
+#define MAX_BOUND 64
+#endif
+#ifndef _MSC_VER
+#define NO_EFFICIENCY_CHECK
+#endif
+
+/* To ensure correctness, check that only valid iterations are executed, and
+   executed only once. Stores the number of times each iteration is
+   executed. */
+unsigned *execution_count = NULL;
+/* Stores the number of iterations executed by each thread. */
+unsigned *iterations_per_thread = NULL;
+
+unsigned *Alloc(unsigned bound1, unsigned bound2) {
+  return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned)));
+}
+
+void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) {
+  memset(p, 0, bound1 * bound2 * sizeof(unsigned));
+}
+
+void Free(unsigned *p) { free((void *)p); }
+
+unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) {
+  return &p[i * bound2 + j];
+}
+
+int test(unsigned upper_bound) {
+
+  unsigned total_iterations = upper_bound * (upper_bound - 1) / 2;
+  unsigned num_threads = omp_get_max_threads();
+  unsigned lower_per_chunk = total_iterations / num_threads;
+  unsigned upper_per_chunk =
+      lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0);
+  int i, j;
+
+  omp_set_num_threads(num_threads);
+
+  ZeroOut(execution_count, upper_bound, upper_bound);
+  ZeroOut(iterations_per_thread, num_threads, 1);
+
+#ifdef VERBOSE
+  fprintf(stderr,
+          "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] "
+          "chunks "
+          "loop type lower triangle <,< - ",
+          num_threads, upper_bound, total_iterations, lower_per_chunk,
+          upper_per_chunk);
+#endif
+
+#pragma omp parallel shared(iterations_per_thread, execution_count)
+  { /* beginning of parallel */
+    /* Lower triangular execution_count matrix */
+#pragma omp for schedule(static) collapse(2)
+    for (i = 0; i < upper_bound; i++) {
+      for (j = 0; j < i; j++) {
+        (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++;
+        (*Index(execution_count, i, j, upper_bound))++;
+      }
+    } /* end of for */
+  } /* end of parallel */
+
+  /* check the execution_count array */
+  for (i = 0; i < upper_bound; i++) {
+    for (j = 0; j < i; j++) {
+      unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j < i are valid, but should have been executed
+         only once */
+      if (value != 1) {
+        fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n",
+                i, j, value);
+        return 0;
+      }
+    }
+    for (j = i; j < upper_bound; j++) {
+      unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j >= i are invalid and should not have been
+         executed */
+      if (value > 0) {
+        fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n",
+                i, j, value);
+        return 0;
+      }
+    }
+  }
+
+#ifndef NO_EFFICIENCY_CHECK
+  /* Ensure the number of iterations executed by each thread is within bounds */
+  for (i = 0; i < num_threads; i++) {
+    unsigned value = *Index(iterations_per_thread, i, 0, 1);
+    if (value < lower_per_chunk || value > upper_per_chunk) {
+      fprintf(stderr,
+              "ERROR: Inefficient Collapse thread %d of %d assigned %i "
+              "iterations; must be between %d and %d\n",
+              i, num_threads, value, lower_per_chunk, upper_per_chunk);
+      return 0;
+    }
+  }
+#endif
+#ifdef VERBOSE
+  fprintf(stderr, "PASSED\r\n");
+#endif
+  return 1;
+}
+
+int main() {
+
+  execution_count = Alloc(MAX_BOUND, MAX_BOUND);
+  iterations_per_thread = Alloc(omp_get_max_threads(), 1);
+
+  for (unsigned j = 0; j < MAX_BOUND; j++) {
+    if (!test(j))
+      return 1;
+  }
+  Free(execution_count);
+  Free(iterations_per_thread);
+  return 0;
+}
diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c
new file mode 100644
index 000000000000..154ee0f69daa
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_LowerTriangularLessEqual.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
"omp.h" + +#ifndef MAX_BOUND +#define MAX_BOUND 64 +#endif +#ifndef _MSC_VER +#define NO_EFFICIENCY_CHECK +#endif + +/* To ensure Correctness, only valid iterations are executed and are executed + only once. Stores the number of times an iteration is executed. */ +unsigned *execution_count = NULL; +/* Stores the number of iterations executed by each thread. */ +unsigned *iterations_per_thread = NULL; + +unsigned *Alloc(unsigned bound1, unsigned bound2) { + return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned))); +} + +void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) { + memset(p, 0, bound1 * bound2 * sizeof(unsigned)); +} + +void Free(unsigned *p) { free((void *)p); } + +unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) { + return &p[i * bound2 + j]; +} + +int test(int upper_bound) { + + unsigned total_iterations = upper_bound * (upper_bound + 1) / 2; + unsigned num_threads = omp_get_max_threads(); + unsigned lower_per_chunk = total_iterations / num_threads; + unsigned upper_per_chunk = + lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0); + int i, j; + + omp_set_num_threads(num_threads); + + ZeroOut(execution_count, upper_bound, upper_bound); + ZeroOut(iterations_per_thread, num_threads, 1); + +#ifdef VERBOSE + fprintf(stderr, + "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] " + "chunks " + "loop type lower triangle <,<= - ", + num_threads, upper_bound, total_iterations, lower_per_chunk, + upper_per_chunk); +#endif + +#pragma omp parallel shared(iterations_per_thread, execution_count) + { /* begin of parallel */ + /* Lower triangular execution_count matrix */ +#pragma omp for schedule(static) collapse(2) + for (i = 0; i < upper_bound; i++) { + for (j = 0; j <= i; j++) { + (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++; + (*Index(execution_count, i, j, upper_bound))++; + } + } /* end of for*/ + } /* end of parallel */ + + /* check the execution_count array */ + for (i = 0; i < upper_bound; i++) { + for (j = 0; j <= i; j++) { + unsigned value = *Index(execution_count, i, j, upper_bound); + /* iteration with j<=i are valid, but should have been executed only once + */ + if (value != 1) { + fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n", + i, j, value); + return 0; + } + } + for (j = i + 1; j < upper_bound; j++) { + unsigned value = *Index(execution_count, i, j, upper_bound); + /* iteration with j>=i are invalid and should not have been executed + */ + if (value > 0) { + fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n", + i, j, value); + return 0; + } + } + } + +#ifndef NO_EFFICIENCY_CHECK + /* Ensure the number of iterations executed by each thread is within bounds */ + for (i = 0; i < num_threads; i++) { + unsigned value = *Index(iterations_per_thread, i, 0, 1); + if (value < lower_per_chunk || value > upper_per_chunk) { + fprintf(stderr, + "ERROR: Inefficient Collapse thread %d of %d assigned %i " + "iterations; must be between %d and %d\n", + i, num_threads, value, lower_per_chunk, upper_per_chunk); + return 0; + } + } +#endif +#ifdef VERBOSE + fprintf(stderr, "PASSED\r\n"); +#endif + return 1; +} + +int main() { + + execution_count = Alloc(MAX_BOUND, MAX_BOUND); + iterations_per_thread = Alloc(omp_get_max_threads(), 1); + + for (unsigned j = 0; j < MAX_BOUND; j++) { + if (!test(j)) + return 1; + } + Free(execution_count); + Free(iterations_per_thread); + return 0; +} diff --git a/openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c 
new file mode 100644
index 000000000000..452410025be0
--- /dev/null
+++ b/openmp/runtime/test/worksharing/for/omp_for_collapse_UpperTriangular.c
@@ -0,0 +1,124 @@
+// RUN: %libomp-compile-and-run
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include "omp.h"
+
+#ifndef MAX_BOUND
+#define MAX_BOUND 64
+#endif
+#ifndef _MSC_VER
+#define NO_EFFICIENCY_CHECK
+#endif
+
+/* To ensure correctness, check that only valid iterations are executed, and
+   executed only once. Stores the number of times each iteration is
+   executed. */
+unsigned *execution_count = NULL;
+/* Stores the number of iterations executed by each thread. */
+unsigned *iterations_per_thread = NULL;
+
+unsigned *Alloc(unsigned bound1, unsigned bound2) {
+  return (unsigned *)(malloc(bound1 * bound2 * sizeof(unsigned)));
+}
+
+void ZeroOut(unsigned *p, unsigned bound1, unsigned bound2) {
+  memset(p, 0, bound1 * bound2 * sizeof(unsigned));
+}
+
+void Free(unsigned *p) { free((void *)p); }
+
+unsigned *Index(unsigned *p, unsigned i, unsigned j, unsigned bound2) {
+  return &p[i * bound2 + j];
+}
+
+int test(unsigned upper_bound) {
+
+  unsigned total_iterations = upper_bound * (upper_bound + 1) / 2;
+  unsigned num_threads = omp_get_max_threads();
+  unsigned lower_per_chunk = total_iterations / num_threads;
+  unsigned upper_per_chunk =
+      lower_per_chunk + ((total_iterations % num_threads) ? 1 : 0);
+  int i, j;
+
+  omp_set_num_threads(num_threads);
+
+  ZeroOut(execution_count, upper_bound, upper_bound);
+  ZeroOut(iterations_per_thread, num_threads, 1);
+
+#ifdef VERBOSE
+  fprintf(stderr,
+          "INFO: Using %6d threads for %6d outer iterations with %6d [%6d:%6d] "
+          "chunks "
+          "loop type upper triangle <,< - ",
+          num_threads, upper_bound, total_iterations, lower_per_chunk,
+          upper_per_chunk);
+#endif
+
+#pragma omp parallel shared(iterations_per_thread, execution_count)
+  { /* beginning of parallel */
+    /* Upper triangular execution_count matrix */
+#pragma omp for schedule(static) collapse(2)
+    for (i = 0; i < upper_bound; i++) {
+      for (j = i; j < upper_bound; j++) {
+        (*Index(iterations_per_thread, omp_get_thread_num(), 0, 1))++;
+        (*Index(execution_count, i, j, upper_bound))++;
+      }
+    } /* end of for */
+  } /* end of parallel */
+
+  /* check the execution_count array */
+  for (i = 0; i < upper_bound; i++) {
+    for (j = i; j < upper_bound; j++) {
+      unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j >= i are valid, but should have been executed
+         only once */
+      if (value != 1) {
+        fprintf(stderr, "ERROR: valid iteration [%i,%i] executed %i times.\n",
+                i, j, value);
+        return 0;
+      }
+    }
+    for (j = 0; j < i; j++) {
+      unsigned value = *Index(execution_count, i, j, upper_bound);
+      /* iterations with j < i are invalid and should not have been
+         executed */
+      if (value > 0) {
+        fprintf(stderr, "ERROR: invalid iteration [%i,%i] executed %i times.\n",
+                i, j, value);
+        return 0;
+      }
+    }
+  }
+
+#ifndef NO_EFFICIENCY_CHECK
+  /* Ensure the number of iterations executed by each thread is within bounds */
+  for (i = 0; i < num_threads; i++) {
+    unsigned value = *Index(iterations_per_thread, i, 0, 1);
+    if (value < lower_per_chunk || value > upper_per_chunk) {
+      fprintf(stderr,
+              "ERROR: Inefficient Collapse thread %d of %d assigned %i "
+              "iterations; must be between %d and %d\n",
+              i, num_threads, value, lower_per_chunk, upper_per_chunk);
+      return 0;
+    }
+  }
+#endif
+#ifdef VERBOSE
+  fprintf(stderr, "PASSED\r\n");
+#endif
+  return 1;
+}
+
+int main() {
+
+  execution_count = Alloc(MAX_BOUND, MAX_BOUND);
+  iterations_per_thread = Alloc(omp_get_max_threads(), 1);
+
+  for (unsigned j = 0; j < MAX_BOUND; j++) {
+    if (!test(j))
+      return 1;
+  }
+  Free(execution_count);
+  Free(iterations_per_thread);
+  return 0;
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index 7a6bc2dc3202..2cfe61844703 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -10841,6 +10841,7 @@ cc_library(
         ":MemRefDialect",
         ":Parser",
         ":SCFDialect",
+        ":MeshShardingInterface",
         ":SideEffectInterfaces",
         ":SparseTensorDialect",
         ":Support",
@@ -10994,10 +10995,13 @@ cc_library(
         ":MathDialect",
         ":MemRefDialect",
        ":MemRefTransforms",
+        ":MeshDialect",
+        ":MeshTransforms",
         ":Pass",
         ":SCFDialect",
         ":SCFTransforms",
         ":SCFUtils",
+        ":MeshShardingInterface",
         ":SparseTensorDialect",
         ":SubsetOpInterface",
         ":Support",
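Returning to the TestPatterns.cpp hunk at the top of this section: the listener is wired into both conversion drivers by pointing ConversionConfig::listener at a RewriterBase::Listener before the driver runs. The sketch below shows that wiring in isolation. It is a sketch only, assuming the same public MLIR APIs the patch itself exercises (RewriterBase::Listener, ConversionConfig, applyPartialConversion); the exact header layout and override set may vary between LLVM revisions, and convertWithNotifications is a hypothetical helper, not an upstream function.

// Minimal listener wiring, modeled on the DumpNotifications change above.
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

namespace {
// Prints a line for the in-place modifications and erasures the conversion
// driver reports - the same notifications the new CHECK lines match.
struct PrintingListener : public RewriterBase::Listener {
  void notifyOperationModified(Operation *op) override {
    llvm::outs() << "modified: " << op->getName() << "\n";
  }
  void notifyOperationErased(Operation *op) override {
    llvm::outs() << "erased: " << op->getName() << "\n";
  }
};
} // namespace

// Hypothetical helper: run a partial conversion with notifications printed.
static LogicalResult
convertWithNotifications(Operation *root, ConversionTarget &target,
                         FrozenRewritePatternSet patterns) {
  PrintingListener listener;
  ConversionConfig config;
  config.listener = &listener; // must outlive the conversion call
  return applyPartialConversion(root, target, std::move(patterns), config);
}

The config object holds only a pointer, so the listener must stay alive for the duration of the call; in the patch this is guaranteed by declaring DumpNotifications as a local immediately before each driver invocation.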