summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuray Ozen <guray.ozen@gmail.com>2024-01-09 16:44:25 +0100
committerGitHub <noreply@github.com>2024-01-09 16:44:25 +0100
commit2aec7083ada09c8b8a0aad79492cbedcf8f9fbb7 (patch)
tree23e78636a3ce8d53915134f1b2c590c3e1ab8109
parentca06c330fd07f05e65a638892c32ca1474d47b5e (diff)
[mlir][gpu] Use DenseI32Array for NVVM's maxntid and reqntid (NFC) (#77466)
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp2
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp10
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp10
-rw-r--r--mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir2
-rw-r--r--mlir/test/Target/LLVMIR/nvvmir.mlir10
5 files changed, 13 insertions, 21 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index eeb8fbbb180b..ae2bd8e5b540 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -100,7 +100,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// If any of the dimensions are missing, fill them in with 1.
attributes.emplace_back(
kernelBlockSizeAttributeName.value(),
- rewriter.getI32ArrayAttr(
+ rewriter.getDenseI32ArrayAttr(
{dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)}));
}
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a4de89d928e1..aa49c4dc31fb 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1060,19 +1060,13 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
// If maxntid and reqntid exist, it must be an array with max 3 dim
if (attrName == NVVMDialect::getMaxntidAttrName() ||
attrName == NVVMDialect::getReqntidAttrName()) {
- auto values = llvm::dyn_cast<ArrayAttr>(attr.getValue());
+ auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue());
if (!values || values.empty() || values.size() > 3)
return op->emitError()
<< "'" << attrName
<< "' attribute must be integer array with maximum 3 index";
- for (auto val : llvm::cast<ArrayAttr>(attr.getValue())) {
- if (!llvm::dyn_cast<IntegerAttr>(val))
- return op->emitError()
- << "'" << attrName
- << "' attribute must be integer array with maximum 3 index";
- }
}
- // If minctasm and maxnreg exist, it must be an array with max 3 dim
+ // If minctasm and maxnreg exist, it must be an integer attribute
if (attrName == NVVMDialect::getMinctasmAttrName() ||
attrName == NVVMDialect::getMaxnregAttrName()) {
if (!llvm::dyn_cast<IntegerAttr>(attr.getValue()))
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 0d6bca5e2203..45eb8402a734 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -163,20 +163,18 @@ public:
->addOperand(llvmMetadataNode);
};
if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) {
- if (!dyn_cast<ArrayAttr>(attribute.getValue()))
+ if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
return failure();
- SmallVector<int64_t> values =
- extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
+ auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName());
if (values.size() > 2)
generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName());
} else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) {
- if (!dyn_cast<ArrayAttr>(attribute.getValue()))
+ if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue()))
return failure();
- SmallVector<int64_t> values =
- extractFromIntegerArrayAttr<int64_t>(attribute.getValue());
+ auto values = cast<DenseI32ArrayAttr>(attribute.getValue());
generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName());
if (values.size() > 1)
generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName());
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index c7f1d4f124c1..66630d33d118 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -629,7 +629,7 @@ gpu.module @test_module_31 {
gpu.module @gpumodule {
// CHECK-LABEL: func @kernel_with_block_size()
-// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = [128 : i32, 1 : i32, 1 : i32]}
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
gpu.func @kernel_with_block_size() kernel attributes {gpu.known_block_size = array<i32: 128, 1, 1>} {
gpu.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 6076fce598fb..f83be9dbb2ff 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -398,7 +398,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} {
// -----
-llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>} {
llvm.return
}
@@ -410,7 +410,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} {
// CHECK: {ptr @kernel_func, !"maxntidz", i32 32}
// -----
-llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} {
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 23, 32>} {
llvm.return
}
@@ -442,7 +442,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} {
// CHECK: {ptr @kernel_func, !"maxnreg", i32 16}
// -----
-llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32],
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>,
nvvm.minctasm = 16, nvvm.maxnreg = 32} {
llvm.return
}
@@ -472,13 +472,13 @@ nvvm.maxnreg = "boo"} {
}
// -----
// expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}}
-llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} {
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 3, 4, 5, 6>} {
llvm.return
}
// -----
// expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}}
-llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} {
+llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4, 5, 6>} {
llvm.return
}