diff options
author | Guray Ozen <guray.ozen@gmail.com> | 2024-01-09 16:44:25 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 16:44:25 +0100 |
commit | 2aec7083ada09c8b8a0aad79492cbedcf8f9fbb7 (patch) | |
tree | 23e78636a3ce8d53915134f1b2c590c3e1ab8109 | |
parent | ca06c330fd07f05e65a638892c32ca1474d47b5e (diff) |
[mlir][gpu] Use DenseI32Array for NVVM's maxntid and reqntid (NFC) (#77466)
-rw-r--r-- | mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 10 | ||||
-rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 10 | ||||
-rw-r--r-- | mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir | 2 | ||||
-rw-r--r-- | mlir/test/Target/LLVMIR/nvvmir.mlir | 10 |
5 files changed, 13 insertions, 21 deletions
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp index eeb8fbbb180b..ae2bd8e5b540 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -100,7 +100,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, // If any of the dimensions are missing, fill them in with 1. attributes.emplace_back( kernelBlockSizeAttributeName.value(), - rewriter.getI32ArrayAttr( + rewriter.getDenseI32ArrayAttr( {dimX.value_or(1), dimY.value_or(1), dimZ.value_or(1)})); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index a4de89d928e1..aa49c4dc31fb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -1060,19 +1060,13 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, // If maxntid and reqntid exist, it must be an array with max 3 dim if (attrName == NVVMDialect::getMaxntidAttrName() || attrName == NVVMDialect::getReqntidAttrName()) { - auto values = llvm::dyn_cast<ArrayAttr>(attr.getValue()); + auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); if (!values || values.empty() || values.size() > 3) return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; - for (auto val : llvm::cast<ArrayAttr>(attr.getValue())) { - if (!llvm::dyn_cast<IntegerAttr>(val)) - return op->emitError() - << "'" << attrName - << "' attribute must be integer array with maximum 3 index"; - } } - // If minctasm and maxnreg exist, it must be an array with max 3 dim + // If minctasm and maxnreg exist, it must be an integer attribute if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName()) { if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 0d6bca5e2203..45eb8402a734 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -163,20 +163,18 @@ public: ->addOperand(llvmMetadataNode); }; if (attribute.getName() == NVVM::NVVMDialect::getMaxntidAttrName()) { - if (!dyn_cast<ArrayAttr>(attribute.getValue())) + if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) return failure(); - SmallVector<int64_t> values = - extractFromIntegerArrayAttr<int64_t>(attribute.getValue()); + auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); generateMetadata(values[0], NVVM::NVVMDialect::getMaxntidXName()); if (values.size() > 1) generateMetadata(values[1], NVVM::NVVMDialect::getMaxntidYName()); if (values.size() > 2) generateMetadata(values[2], NVVM::NVVMDialect::getMaxntidZName()); } else if (attribute.getName() == NVVM::NVVMDialect::getReqntidAttrName()) { - if (!dyn_cast<ArrayAttr>(attribute.getValue())) + if (!dyn_cast<DenseI32ArrayAttr>(attribute.getValue())) return failure(); - SmallVector<int64_t> values = - extractFromIntegerArrayAttr<int64_t>(attribute.getValue()); + auto values = cast<DenseI32ArrayAttr>(attribute.getValue()); generateMetadata(values[0], NVVM::NVVMDialect::getReqntidXName()); if (values.size() > 1) generateMetadata(values[1], NVVM::NVVMDialect::getReqntidYName()); diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir index c7f1d4f124c1..66630d33d118 100644 --- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -629,7 +629,7 @@ gpu.module @test_module_31 { gpu.module @gpumodule { // CHECK-LABEL: func @kernel_with_block_size() -// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = [128 : i32, 1 : i32, 1 : i32]} +// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>} gpu.func @kernel_with_block_size() kernel attributes {gpu.known_block_size = array<i32: 128, 1, 1>} { gpu.return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index 6076fce598fb..f83be9dbb2ff 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -398,7 +398,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel} { // ----- -llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} { +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>} { llvm.return } @@ -410,7 +410,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32]} { // CHECK: {ptr @kernel_func, !"maxntidz", i32 32} // ----- -llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [1,23,32]} { +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 1, 23, 32>} { llvm.return } @@ -442,7 +442,7 @@ llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxnreg = 16} { // CHECK: {ptr @kernel_func, !"maxnreg", i32 16} // ----- -llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [1,23,32], +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 1, 23, 32>, nvvm.minctasm = 16, nvvm.maxnreg = 32} { llvm.return } @@ -472,13 +472,13 @@ nvvm.maxnreg = "boo"} { } // ----- // expected-error @below {{'"nvvm.reqntid"' attribute must be integer array with maximum 3 index}} -llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = [3,4,5,6]} { +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.reqntid = array<i32: 3, 4, 5, 6>} { llvm.return } // ----- // expected-error @below {{'"nvvm.maxntid"' attribute must be integer array with maximum 3 index}} -llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = [3,4,5,6]} { +llvm.func @kernel_func() attributes {nvvm.kernel, nvvm.maxntid = array<i32: 3, 4, 5, 6>} { llvm.return } |