summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorValentin Clement (バレンタイン クレメン) <clementval@gmail.com>2024-05-03 11:25:34 -0700
committerGitHub <noreply@github.com>2024-05-03 11:25:34 -0700
commitf8a9973f8c1ef60281ace6f3cfeb24d9dcd5b3c3 (patch)
tree6562f1477733f075acb8cab9c96af12f4dbef63d
parenta4d10266d20bfe5930dfed77e17832af341ed66e (diff)
[flang][cuda] Add verifier for cuda_alloc/cuda_free (#90983)
Adding a verifier to check the associated cuda attribute.
-rw-r--r--flang/include/flang/Optimizer/Dialect/FIROps.td4
-rw-r--r--flang/lib/Optimizer/Dialect/FIROps.cpp13
-rw-r--r--flang/test/Fir/cuf-invalid.fir18
3 files changed, 35 insertions, 0 deletions
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index dc38e56d93c6..64c5e360b28f 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3364,6 +3364,8 @@ def fir_CUDAAllocOp : fir_Op<"cuda_alloc", [AttrSizedOperandSegments,
CArg<"mlir::ValueRange", "{}">:$typeparams,
CArg<"mlir::ValueRange", "{}">:$shape,
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>];
+
+ let hasVerifier = 1;
}
def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
@@ -3381,6 +3383,8 @@ def fir_CUDAFreeOp : fir_Op<"cuda_free", [MemoryEffects<[MemFree]>]> {
);
let assemblyFormat = "$devptr `:` qualified(type($devptr)) attr-dict";
+
+ let hasVerifier = 1;
}
#endif
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 5e6c18af2dd0..edf7f7f4b1a9 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -4048,6 +4048,19 @@ void fir::CUDAAllocOp::build(
result.addAttributes(attributes);
}
+template <typename Op>
+static mlir::LogicalResult checkCudaAttr(Op op) {
+ if (op.getCudaAttr() == fir::CUDADataAttribute::Device ||
+ op.getCudaAttr() == fir::CUDADataAttribute::Managed ||
+ op.getCudaAttr() == fir::CUDADataAttribute::Unified)
+ return mlir::success();
+ return op.emitOpError("expect device, managed or unified cuda attribute");
+}
+
+mlir::LogicalResult fir::CUDAAllocOp::verify() { return checkCudaAttr(*this); }
+
+mlir::LogicalResult fir::CUDAFreeOp::verify() { return checkCudaAttr(*this); }
+
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/cuf-invalid.fir b/flang/test/Fir/cuf-invalid.fir
index 6c533a32ccf9..5a12e3c1a4bf 100644
--- a/flang/test/Fir/cuf-invalid.fir
+++ b/flang/test/Fir/cuf-invalid.fir
@@ -85,3 +85,21 @@ func.func @_QPsub1() {
%13 = fir.cuda_deallocate %11 : !fir.ref<!fir.box<none>> errmsg(%16 : !fir.box<none>) {cuda_attr = #fir.cuda<device>} -> i32
return
}
+
+// -----
+
+func.func @_QPsub1() {
+ // expected-error@+1{{'fir.cuda_alloc' op expect device, managed or unified cuda attribute}}
+ %0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<pinned>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
+ fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
+ return
+}
+
+// -----
+
+func.func @_QPsub1() {
+ %0 = fir.cuda_alloc f32 {bindc_name = "r", cuda_attr = #fir.cuda<device>, uniq_name = "_QFsub1Er"} -> !fir.ref<f32>
+ // expected-error@+1{{'fir.cuda_free' op expect device, managed or unified cuda attribute}}
+ fir.cuda_free %0 : !fir.ref<f32> {cuda_attr = #fir.cuda<constant>}
+ return
+}