summaryrefslogtreecommitdiffstats
path: root/mlir/test/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib')
-rw-r--r--mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp26
-rw-r--r--mlir/test/lib/Dialect/Test/TestDialect.cpp37
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td16
3 files changed, 58 insertions, 21 deletions
diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
index 6730f9b292ad..b098a5a23fd3 100644
--- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
@@ -109,7 +109,7 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
FailureOr<OpFoldResult> reified = failure();
if (constant) {
auto reifiedConst = ValueBoundsConstraintSet::computeConstantBound(
- boundType, value, dim, /*stopCondition=*/nullptr);
+ boundType, {value, dim}, /*stopCondition=*/nullptr);
if (succeeded(reifiedConst))
reified = FailureOr<OpFoldResult>(rewriter.getIndexAttr(*reifiedConst));
} else if (scalable) {
@@ -128,22 +128,12 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,
rewriter, loc, reifiedScalable->map, vscaleOperand);
}
} else {
- if (dim) {
- if (useArithOps) {
- reified = arith::reifyShapedValueDimBound(
- rewriter, op->getLoc(), boundType, value, *dim, stopCondition);
- } else {
- reified = reifyShapedValueDimBound(rewriter, op->getLoc(), boundType,
- value, *dim, stopCondition);
- }
+ if (useArithOps) {
+ reified = arith::reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
} else {
- if (useArithOps) {
- reified = arith::reifyIndexValueBound(
- rewriter, op->getLoc(), boundType, value, stopCondition);
- } else {
- reified = reifyIndexValueBound(rewriter, op->getLoc(), boundType,
- value, stopCondition);
- }
+ reified = reifyValueBound(rewriter, op->getLoc(), boundType,
+ op.getVariable(), stopCondition);
}
}
if (failed(reified)) {
@@ -188,9 +178,7 @@ static LogicalResult testEquality(func::FuncOp funcOp) {
}
auto compare = [&](ValueBoundsConstraintSet::ComparisonOperator cmp) {
- return ValueBoundsConstraintSet::compare(
- /*lhs=*/op.getLhs(), /*lhsDim=*/std::nullopt, cmp,
- /*rhs=*/op.getRhs(), /*rhsDim=*/std::nullopt);
+ return ValueBoundsConstraintSet::compare(op.getLhs(), cmp, op.getRhs());
};
if (compare(cmpType)) {
op->emitRemark("true");
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 25c5190ca0ef..a23ed89c4b04 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -549,6 +549,12 @@ LogicalResult ReifyBoundOp::verify() {
return success();
}
+::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() {
+ if (getDim().has_value())
+ return ValueBoundsConstraintSet::Variable(getVar(), *getDim());
+ return ValueBoundsConstraintSet::Variable(getVar());
+}
+
::mlir::ValueBoundsConstraintSet::ComparisonOperator
CompareOp::getComparisonOperator() {
if (getCmp() == "EQ")
@@ -564,6 +570,37 @@ CompareOp::getComparisonOperator() {
llvm_unreachable("invalid comparison operator");
}
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() {
+ if (!getLhsMap())
+ return ValueBoundsConstraintSet::Variable(getVarOperands()[0]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(0, getLhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands);
+}
+
+::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() {
+ int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ if (!getRhsMap())
+ return ValueBoundsConstraintSet::Variable(
+ getVarOperands()[rhsOperandsBegin]);
+ SmallVector<Value> mapOperands(
+ getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs()));
+ return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands);
+}
+
+LogicalResult CompareOp::verify() {
+ if (getCompose() && (getLhsMap() || getRhsMap()))
+ return emitOpError(
+ "'compose' not supported when 'lhs_map' or 'rhs_map' is present");
+ int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1;
+ expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1;
+ if (getVarOperands().size() != size_t(expectedNumOperands))
+ return emitOpError("expected ")
+ << expectedNumOperands << " operands, but got "
+ << getVarOperands().size();
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index ebf158b8bb82..b641b3da719c 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2207,6 +2207,7 @@ def ReifyBoundOp : TEST_Op<"reify_bound", [Pure]> {
let extraClassDeclaration = [{
::mlir::presburger::BoundType getBoundType();
+ ::mlir::ValueBoundsConstraintSet::Variable getVariable();
}];
let hasVerifier = 1;
@@ -2217,18 +2218,29 @@ def CompareOp : TEST_Op<"compare"> {
Compare `lhs` and `rhs`. A remark is emitted which indicates whether the
specified comparison operator was proven to hold. The remark also indicates
whether the opposite comparison operator was proven to hold.
+
+ `var_operands` must have exactly two operands: one for the LHS operand and
+ one for the RHS operand. If `lhs_map` is specified, as many operands as
+ `lhs_map` has inputs are expected instead of the first operand. If `rhs_map`
+ is specified, as many operands as `rhs_map` has inputs are expected instead
+ of the second operand.
}];
- let arguments = (ins Index:$lhs,
- Index:$rhs,
+ let arguments = (ins Variadic<Index>:$var_operands,
DefaultValuedAttr<StrAttr, "\"EQ\"">:$cmp,
+ OptionalAttr<AffineMapAttr>:$lhs_map,
+ OptionalAttr<AffineMapAttr>:$rhs_map,
UnitAttr:$compose);
let results = (outs);
let extraClassDeclaration = [{
::mlir::ValueBoundsConstraintSet::ComparisonOperator
getComparisonOperator();
+ ::mlir::ValueBoundsConstraintSet::Variable getLhs();
+ ::mlir::ValueBoundsConstraintSet::Variable getRhs();
}];
+
+ let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//