diff options
author | long.chen <lipracer@gmail.com> | 2023-11-19 02:14:53 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-19 02:14:53 +0800 |
commit | dc4786b4877d67d73d3892c45baf6811af0e6f57 (patch) | |
tree | e1903a133c2404923015fc38f5eba4187e7303d7 | |
parent | c093383ffadff8dfadfd6bc0ab7107a0e194aa7e (diff) |
[mlir][affine] remove divide zero check when simplifer affineMap (#64622) (#68519)
When performing constant folding on the affineApplyOp, there is a
division of 0 in the affine map.
[related issue](https://github.com/llvm/llvm-project/issues/64622)
---------
Co-authored-by: Javier Setoain <jsetoain@users.noreply.github.com>
-rw-r--r-- | mlir/include/mlir/IR/AffineExprVisitor.h | 202 | ||||
-rw-r--r-- | mlir/include/mlir/IR/AffineMap.h | 9 | ||||
-rw-r--r-- | mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 10 | ||||
-rw-r--r-- | mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 14 | ||||
-rw-r--r-- | mlir/lib/Dialect/Affine/IR/CMakeLists.txt | 1 | ||||
-rw-r--r-- | mlir/lib/IR/AffineExpr.cpp | 64 | ||||
-rw-r--r-- | mlir/lib/IR/AffineMap.cpp | 46 | ||||
-rw-r--r-- | mlir/test/Dialect/Affine/constant-fold.mlir | 21 |
8 files changed, 255 insertions, 112 deletions
diff --git a/mlir/include/mlir/IR/AffineExprVisitor.h b/mlir/include/mlir/IR/AffineExprVisitor.h index 382db22dce46..2860e73c8f42 100644 --- a/mlir/include/mlir/IR/AffineExprVisitor.h +++ b/mlir/include/mlir/IR/AffineExprVisitor.h @@ -14,6 +14,7 @@ #define MLIR_IR_AFFINEEXPRVISITOR_H #include "mlir/IR/AffineExpr.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/ArrayRef.h" namespace mlir { @@ -65,8 +66,78 @@ namespace mlir { /// just as efficient as having your own switch instruction over the instruction /// opcode. +template <typename SubClass, typename RetTy> +class AffineExprVisitorBase { +public: + // Function to visit an AffineExpr. + RetTy visit(AffineExpr expr) { + static_assert(std::is_base_of<AffineExprVisitorBase, SubClass>::value, + "Must instantiate with a derived type of AffineExprVisitor"); + auto self = static_cast<SubClass *>(this); + switch (expr.getKind()) { + case AffineExprKind::Add: { + auto binOpExpr = cast<AffineBinaryOpExpr>(expr); + return self->visitAddExpr(binOpExpr); + } + case AffineExprKind::Mul: { + auto binOpExpr = cast<AffineBinaryOpExpr>(expr); + return self->visitMulExpr(binOpExpr); + } + case AffineExprKind::Mod: { + auto binOpExpr = cast<AffineBinaryOpExpr>(expr); + return self->visitModExpr(binOpExpr); + } + case AffineExprKind::FloorDiv: { + auto binOpExpr = cast<AffineBinaryOpExpr>(expr); + return self->visitFloorDivExpr(binOpExpr); + } + case AffineExprKind::CeilDiv: { + auto binOpExpr = cast<AffineBinaryOpExpr>(expr); + return self->visitCeilDivExpr(binOpExpr); + } + case AffineExprKind::Constant: + return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); + case AffineExprKind::DimId: + return self->visitDimExpr(cast<AffineDimExpr>(expr)); + case AffineExprKind::SymbolId: + return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); + } + llvm_unreachable("Unknown AffineExpr"); + } + + //===--------------------------------------------------------------------===// + // Visitation functions... these functions provide default fallbacks in case + // the user does not specify what to do for a particular instruction type. + // The default behavior is to generalize the instruction type to its subtype + // and try visiting the subtype. All of this should be inlined perfectly, + // because there are no virtual functions to get in the way. + // + + // Default visit methods. Note that the default op-specific binary op visit + // methods call the general visitAffineBinaryOpExpr visit method. + RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); } + RetTy visitAddExpr(AffineBinaryOpExpr expr) { + return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); + } + RetTy visitMulExpr(AffineBinaryOpExpr expr) { + return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); + } + RetTy visitModExpr(AffineBinaryOpExpr expr) { + return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); + } + RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) { + return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); + } + RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) { + return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); + } + RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); } + RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); } + RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); } +}; + template <typename SubClass, typename RetTy = void> -class AffineExprVisitor { +class AffineExprVisitor : public AffineExprVisitorBase<SubClass, RetTy> { //===--------------------------------------------------------------------===// // Interface code - This is the public interface of the AffineExprVisitor // that you use to visit affine expressions... @@ -75,117 +146,112 @@ public: RetTy walkPostOrder(AffineExpr expr) { static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, "Must instantiate with a derived type of AffineExprVisitor"); + auto self = static_cast<SubClass *>(this); switch (expr.getKind()) { case AffineExprKind::Add: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); walkOperandsPostOrder(binOpExpr); - return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); + return self->visitAddExpr(binOpExpr); } case AffineExprKind::Mul: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); walkOperandsPostOrder(binOpExpr); - return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); + return self->visitMulExpr(binOpExpr); } case AffineExprKind::Mod: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); walkOperandsPostOrder(binOpExpr); - return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); + return self->visitModExpr(binOpExpr); } case AffineExprKind::FloorDiv: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); walkOperandsPostOrder(binOpExpr); - return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); + return self->visitFloorDivExpr(binOpExpr); } case AffineExprKind::CeilDiv: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); walkOperandsPostOrder(binOpExpr); - return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); + return self->visitCeilDivExpr(binOpExpr); } case AffineExprKind::Constant: - return static_cast<SubClass *>(this)->visitConstantExpr( - cast<AffineConstantExpr>(expr)); + return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); case AffineExprKind::DimId: - return static_cast<SubClass *>(this)->visitDimExpr( - cast<AffineDimExpr>(expr)); + return self->visitDimExpr(cast<AffineDimExpr>(expr)); case AffineExprKind::SymbolId: - return static_cast<SubClass *>(this)->visitSymbolExpr( - cast<AffineSymbolExpr>(expr)); + return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); } + llvm_unreachable("Unknown AffineExpr"); } - // Function to visit an AffineExpr. - RetTy visit(AffineExpr expr) { +private: + // Walk the operands - each operand is itself walked in post order. + RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) { + walkPostOrder(expr.getLHS()); + walkPostOrder(expr.getRHS()); + } +}; + +template <typename SubClass> +class AffineExprVisitor<SubClass, LogicalResult> + : public AffineExprVisitorBase<SubClass, LogicalResult> { + //===--------------------------------------------------------------------===// + // Interface code - This is the public interface of the AffineExprVisitor + // that you use to visit affine expressions... +public: + // Function to walk an AffineExpr (in post order). + LogicalResult walkPostOrder(AffineExpr expr) { static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value, "Must instantiate with a derived type of AffineExprVisitor"); + auto self = static_cast<SubClass *>(this); switch (expr.getKind()) { case AffineExprKind::Add: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); - return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr); + if (failed(walkOperandsPostOrder(binOpExpr))) + return failure(); + return self->visitAddExpr(binOpExpr); } case AffineExprKind::Mul: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); - return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr); + if (failed(walkOperandsPostOrder(binOpExpr))) + return failure(); + return self->visitMulExpr(binOpExpr); } case AffineExprKind::Mod: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); - return static_cast<SubClass *>(this)->visitModExpr(binOpExpr); + if (failed(walkOperandsPostOrder(binOpExpr))) + return failure(); + return self->visitModExpr(binOpExpr); } case AffineExprKind::FloorDiv: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); - return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr); + if (failed(walkOperandsPostOrder(binOpExpr))) + return failure(); + return self->visitFloorDivExpr(binOpExpr); } case AffineExprKind::CeilDiv: { auto binOpExpr = cast<AffineBinaryOpExpr>(expr); - return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr); + if (failed(walkOperandsPostOrder(binOpExpr))) + return failure(); + return self->visitCeilDivExpr(binOpExpr); } case AffineExprKind::Constant: - return static_cast<SubClass *>(this)->visitConstantExpr( - cast<AffineConstantExpr>(expr)); + return self->visitConstantExpr(cast<AffineConstantExpr>(expr)); case AffineExprKind::DimId: - return static_cast<SubClass *>(this)->visitDimExpr( - cast<AffineDimExpr>(expr)); + return self->visitDimExpr(cast<AffineDimExpr>(expr)); case AffineExprKind::SymbolId: - return static_cast<SubClass *>(this)->visitSymbolExpr( - cast<AffineSymbolExpr>(expr)); + return self->visitSymbolExpr(cast<AffineSymbolExpr>(expr)); } llvm_unreachable("Unknown AffineExpr"); } - //===--------------------------------------------------------------------===// - // Visitation functions... these functions provide default fallbacks in case - // the user does not specify what to do for a particular instruction type. - // The default behavior is to generalize the instruction type to its subtype - // and try visiting the subtype. All of this should be inlined perfectly, - // because there are no virtual functions to get in the way. - // - - // Default visit methods. Note that the default op-specific binary op visit - // methods call the general visitAffineBinaryOpExpr visit method. - RetTy visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { return RetTy(); } - RetTy visitAddExpr(AffineBinaryOpExpr expr) { - return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); - } - RetTy visitMulExpr(AffineBinaryOpExpr expr) { - return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); - } - RetTy visitModExpr(AffineBinaryOpExpr expr) { - return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); - } - RetTy visitFloorDivExpr(AffineBinaryOpExpr expr) { - return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); - } - RetTy visitCeilDivExpr(AffineBinaryOpExpr expr) { - return static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr); - } - RetTy visitConstantExpr(AffineConstantExpr expr) { return RetTy(); } - RetTy visitDimExpr(AffineDimExpr expr) { return RetTy(); } - RetTy visitSymbolExpr(AffineSymbolExpr expr) { return RetTy(); } - private: // Walk the operands - each operand is itself walked in post order. - RetTy walkOperandsPostOrder(AffineBinaryOpExpr expr) { - walkPostOrder(expr.getLHS()); - walkPostOrder(expr.getRHS()); + LogicalResult walkOperandsPostOrder(AffineBinaryOpExpr expr) { + if (failed(walkPostOrder(expr.getLHS()))) + return failure(); + if (failed(walkPostOrder(expr.getRHS()))) + return failure(); + return success(); } }; @@ -246,7 +312,7 @@ private: // expressions are mapped to the same local identifier (same column position in // 'localVarCst'). class SimpleAffineExprFlattener - : public AffineExprVisitor<SimpleAffineExprFlattener> { + : public AffineExprVisitor<SimpleAffineExprFlattener, LogicalResult> { public: // Flattend expression layout: [dims, symbols, locals, constant] // Stack that holds the LHS and RHS operands while visiting a binary op expr. @@ -275,13 +341,13 @@ public: virtual ~SimpleAffineExprFlattener() = default; // Visitor method overrides. - void visitMulExpr(AffineBinaryOpExpr expr); - void visitAddExpr(AffineBinaryOpExpr expr); - void visitDimExpr(AffineDimExpr expr); - void visitSymbolExpr(AffineSymbolExpr expr); - void visitConstantExpr(AffineConstantExpr expr); - void visitCeilDivExpr(AffineBinaryOpExpr expr); - void visitFloorDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitMulExpr(AffineBinaryOpExpr expr); + LogicalResult visitAddExpr(AffineBinaryOpExpr expr); + LogicalResult visitDimExpr(AffineDimExpr expr); + LogicalResult visitSymbolExpr(AffineSymbolExpr expr); + LogicalResult visitConstantExpr(AffineConstantExpr expr); + LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr); + LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr); // // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1 @@ -289,7 +355,7 @@ public: // A mod expression "expr mod c" is thus flattened by introducing a new local // variable q (= expr floordiv c), such that expr mod c is replaced with // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst. - void visitModExpr(AffineBinaryOpExpr expr); + LogicalResult visitModExpr(AffineBinaryOpExpr expr); protected: // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). @@ -328,7 +394,7 @@ private: // // A ceildiv is similarly flattened: // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c - void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); + LogicalResult visitDivExpr(AffineBinaryOpExpr expr, bool isCeil); int findLocalId(AffineExpr localExpr); diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 713aef767edf..0e4a8d363946 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -310,7 +310,8 @@ public: /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. LogicalResult constantFold(ArrayRef<Attribute> operandConstants, - SmallVectorImpl<Attribute> &results) const; + SmallVectorImpl<Attribute> &results, + bool *hasPoison = nullptr) const; /// Propagates the constant operands into this affine map. Operands are /// allowed to be null, at which point they are treated as non-constant. This @@ -318,9 +319,9 @@ public: /// which may be equal to the old map if no folding happened. If `results` is /// provided and if all expressions in the map were folded to constants, /// `results` will contain the values of these constants. - AffineMap - partialConstantFold(ArrayRef<Attribute> operandConstants, - SmallVectorImpl<int64_t> *results = nullptr) const; + AffineMap partialConstantFold(ArrayRef<Attribute> operandConstants, + SmallVectorImpl<int64_t> *results = nullptr, + bool *hasPoison = nullptr) const; /// Returns the AffineMap resulting from composing `this` with `map`. /// The resulting AffineMap has as many AffineDimExpr as `map` and as many diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index ea123ea56025..69846a356e0c 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -67,7 +67,9 @@ private: } // namespace // Flattens the expressions in map. Returns failure if 'expr' was unable to be -// flattened (i.e., semi-affine expressions not handled yet). +// flattened. For example two specific cases: +// 1. semi-affine expressions not handled yet. +// 2. has poison expression (i.e., division by zero). static LogicalResult getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, unsigned numSymbols, @@ -85,8 +87,10 @@ getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims, for (auto expr : exprs) { if (!expr.isPureAffine()) return failure(); - - flattener.walkPostOrder(expr); + // has poison expression + auto flattenResult = flattener.walkPostOrder(expr); + if (failed(flattenResult)) + return failure(); } assert(flattener.operandExprStack.size() == exprs.size()); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 05496e70716a..d22a7539fb75 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" @@ -226,6 +227,8 @@ void AffineDialect::initialize() { Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { + if (auto poison = dyn_cast<ub::PoisonAttr>(value)) + return builder.create<ub::PoisonOp>(loc, type, poison); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -580,7 +583,12 @@ OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) { // Otherwise, default to folding the map. SmallVector<Attribute, 1> result; - if (failed(map.constantFold(adaptor.getMapOperands(), result))) + bool hasPoison = false; + auto foldResult = + map.constantFold(adaptor.getMapOperands(), result, &hasPoison); + if (hasPoison) + return ub::PoisonAttr::get(getContext()); + if (failed(foldResult)) return {}; return result[0]; } @@ -3379,7 +3387,9 @@ static LogicalResult canonicalizeMapExprAndTermOrder(AffineMap &map) { return failure(); SimpleAffineExprFlattener flattener(map.getNumDims(), map.getNumSymbols()); - flattener.walkPostOrder(resultExpr); + auto flattenResult = flattener.walkPostOrder(resultExpr); + if (failed(flattenResult)) + return failure(); // Fail if the flattened expression has local variables. if (flattener.operandExprStack.back().size() != diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt index 89ea3128b0e7..9e3c1161fd92 100644 --- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt @@ -19,5 +19,6 @@ add_mlir_dialect_library(MLIRAffineDialect MLIRMemRefDialect MLIRShapedOpInterfaces MLIRSideEffectInterfaces + MLIRUBDialect MLIRValueBoundsOpInterface ) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index cdceaac11069..038ceea286a3 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -1216,7 +1216,7 @@ SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims, // In case of semi affine multiplication expressions, t = expr * symbolic_expr, // introduce a local variable p (= expr * symbolic_expr), and the affine // expression expr * symbolic_expr is added to `localExprs`. -void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { +LogicalResult SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); SmallVector<int64_t, 8> rhs = operandExprStack.back(); operandExprStack.pop_back(); @@ -1232,7 +1232,7 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols, localExprs, context); addLocalVariableSemiAffine(a * b, lhs, lhs.size()); - return; + return success(); } // Get the RHS constant. @@ -1240,9 +1240,10 @@ void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) { for (unsigned i = 0, e = lhs.size(); i < e; i++) { lhs[i] *= rhsConst; } + return success(); } -void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { +LogicalResult SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); const auto &rhs = operandExprStack.back(); auto &lhs = operandExprStack[operandExprStack.size() - 2]; @@ -1253,6 +1254,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { } // Pop off the RHS. operandExprStack.pop_back(); + return success(); } // @@ -1265,7 +1267,7 @@ void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) { // In case of semi-affine modulo expressions, t = expr mod symbolic_expr, // introduce a local variable m (= expr mod symbolic_expr), and the affine // expression expr mod symbolic_expr is added to `localExprs`. -void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { +LogicalResult SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { assert(operandExprStack.size() >= 2); SmallVector<int64_t, 8> rhs = operandExprStack.back(); @@ -1283,13 +1285,12 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { localExprs, context); AffineExpr modExpr = dividendExpr % divisorExpr; addLocalVariableSemiAffine(modExpr, lhs, lhs.size()); - return; + return success(); } int64_t rhsConst = rhs[getConstantIndex()]; - // TODO: handle modulo by zero case when this issue is fixed - // at the other places in the IR. - assert(rhsConst > 0 && "RHS constant has to be positive"); + if (rhsConst <= 0) + return failure(); // Check if the LHS expression is a multiple of modulo factor. unsigned i, e; @@ -1299,7 +1300,7 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { // If yes, modulo expression here simplifies to zero. if (i == lhs.size()) { std::fill(lhs.begin(), lhs.end(), 0); - return; + return success(); } // Add a local variable for the quotient, i.e., expr % c is replaced by @@ -1331,33 +1332,41 @@ void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) { // Reuse the existing local id. lhs[getLocalVarStartIndex() + loc] = -rhsConst; } + return success(); } -void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/true); +LogicalResult +SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr, /*isCeil=*/true); } -void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { - visitDivExpr(expr, /*isCeil=*/false); +LogicalResult +SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) { + return visitDivExpr(expr, /*isCeil=*/false); } -void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { +LogicalResult SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) { operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); auto &eq = operandExprStack.back(); assert(expr.getPosition() < numDims && "Inconsistent number of dims"); eq[getDimStartIndex() + expr.getPosition()] = 1; + return success(); } -void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { +LogicalResult +SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) { operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); auto &eq = operandExprStack.back(); assert(expr.getPosition() < numSymbols && "inconsistent number of symbols"); eq[getSymbolStartIndex() + expr.getPosition()] = 1; + return success(); } -void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { +LogicalResult +SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) { operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0)); auto &eq = operandExprStack.back(); eq[getConstantIndex()] = expr.getValue(); + return success(); } void SimpleAffineExprFlattener::addLocalVariableSemiAffine( @@ -1388,8 +1397,8 @@ void SimpleAffineExprFlattener::addLocalVariableSemiAffine( // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to // `localExprs`. -void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, - bool isCeil) { +LogicalResult SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, + bool isCeil) { assert(operandExprStack.size() >= 2); MLIRContext *context = expr.getContext(); @@ -1407,14 +1416,13 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, localExprs, context); AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b); addLocalVariableSemiAffine(divExpr, lhs, lhs.size()); - return; + return success(); } // This is a pure affine expr; the RHS is a positive constant. int64_t rhsConst = rhs[getConstantIndex()]; - // TODO: handle division by zero at the same time the issue is - // fixed at other places. - assert(rhsConst > 0 && "RHS constant has to be positive"); + if (rhsConst <= 0) + return failure(); // Simplify the floordiv, ceildiv if possible by canceling out the greatest // common divisors of the numerator and denominator. @@ -1430,7 +1438,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, // If the divisor becomes 1, the updated LHS is the result. (The // divisor can't be negative since rhsConst is positive). if (divisor == 1) - return; + return success(); // If the divisor cannot be simplified to one, we will have to retain // the ceil/floor expr (simplified up until here). Add an existential @@ -1460,6 +1468,7 @@ void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr, lhs[getLocalVarStartIndex() + numLocals - 1] = 1; else lhs[getLocalVarStartIndex() + loc] = 1; + return success(); } // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr). @@ -1500,7 +1509,9 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims, expr = simplifySemiAffine(expr, numDims, numSymbols); SimpleAffineExprFlattener flattener(numDims, numSymbols); - flattener.walkPostOrder(expr); + // has poison expression + if (failed(flattener.walkPostOrder(expr))) + return expr; ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); if (!expr.isPureAffine() && expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols, @@ -1573,7 +1584,10 @@ std::optional<int64_t> mlir::getBoundForAffineExpr( } // Flatten the expression. SimpleAffineExprFlattener flattener(numDims, numSymbols); - flattener.walkPostOrder(expr); + auto simpleResult = flattener.walkPostOrder(expr); + // has poison expression + if (failed(simpleResult)) + return std::nullopt; ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back(); // TODO: Handle local variables. We can get hold of flattener.localExprs and // get bound on the local expr recursively. diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 93a8d048e0a6..e0293812277a 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/AffineMap.h" #include "AffineMapDetail.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -59,13 +60,34 @@ private: expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; }); case AffineExprKind::Mod: return constantFoldBinExpr( - expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); }); + expr, + [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { + if (rhs < 1) { + hasPoison_ = true; + return std::nullopt; + } + return mod(lhs, rhs); + }); case AffineExprKind::FloorDiv: return constantFoldBinExpr( - expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); }); + expr, + [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { + if (rhs == 0) { + hasPoison_ = true; + return std::nullopt; + } + return floorDiv(lhs, rhs); + }); case AffineExprKind::CeilDiv: return constantFoldBinExpr( - expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); }); + expr, + [expr, this](int64_t lhs, int64_t rhs) -> std::optional<int64_t> { + if (rhs == 0) { + hasPoison_ = true; + return std::nullopt; + } + return ceilDiv(lhs, rhs); + }); case AffineExprKind::Constant: return cast<AffineConstantExpr>(expr).getValue(); case AffineExprKind::DimId: @@ -387,12 +409,12 @@ std::optional<unsigned> AffineMap::getResultPosition(AffineExpr input) const { /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise. -LogicalResult -AffineMap::constantFold(ArrayRef<Attribute> operandConstants, - SmallVectorImpl<Attribute> &results) const { +LogicalResult AffineMap::constantFold(ArrayRef<Attribute> operandConstants, + SmallVectorImpl<Attribute> &results, + bool *hasPoison) const { // Attempt partial folding. SmallVector<int64_t, 2> integers; - partialConstantFold(operandConstants, &integers); + partialConstantFold(operandConstants, &integers, hasPoison); // If all expressions folded to a constant, populate results with attributes // containing those constants. @@ -406,9 +428,9 @@ AffineMap::constantFold(ArrayRef<Attribute> operandConstants, return success(); } -AffineMap -AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, - SmallVectorImpl<int64_t> *results) const { +AffineMap AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, + SmallVectorImpl<int64_t> *results, + bool *hasPoison) const { assert(getNumInputs() == operandConstants.size()); // Fold each of the result expressions. @@ -418,6 +440,10 @@ AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants, for (auto expr : getResults()) { auto folded = exprFolder.constantFold(expr); + if (exprFolder.hasPoison() && hasPoison) { + *hasPoison = true; + return {}; + } // If did not fold to a constant, keep the original expression, and clear // the integer results vector. if (folded) { diff --git a/mlir/test/Dialect/Affine/constant-fold.mlir b/mlir/test/Dialect/Affine/constant-fold.mlir index cdce39855acd..5236b44ddfed 100644 --- a/mlir/test/Dialect/Affine/constant-fold.mlir +++ b/mlir/test/Dialect/Affine/constant-fold.mlir @@ -60,3 +60,24 @@ func.func @affine_min(%variable: index) -> (index, index) { // CHECK: return %[[r]], %[[C44]] return %0, %1 : index, index } + +// ----- + +func.func @affine_apply_poison_division_zero() { + // This is just for mlir::context to load ub dailect + %ub = ub.poison : index + %c16 = arith.constant 16 : index + %0 = affine.apply affine_map<(d0)[s0] -> (d0 mod (s0 - s0))>(%c16)[%c16] + %1 = affine.apply affine_map<(d0)[s0] -> (d0 floordiv (s0 - s0))>(%c16)[%c16] + %2 = affine.apply affine_map<(d0)[s0] -> (d0 ceildiv (s0 - s0))>(%c16)[%c16] + %alloc = memref.alloc(%0, %1, %2) : memref<?x?x?xi1> + %3 = affine.load %alloc[%0, %1, %2] : memref<?x?x?xi1> + affine.store %3, %alloc[%0, %1, %2] : memref<?x?x?xi1> + return +} + +// CHECK-NOT: affine.apply +// CHECK: %[[poison:.*]] = ub.poison : index +// CHECK-NEXT: %[[alloc:.*]] = memref.alloc(%[[poison]], %[[poison]], %[[poison]]) +// CHECK-NEXT: %[[load:.*]] = affine.load %[[alloc]][%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1> +// CHECK-NEXT: affine.store %[[load]], %alloc[%[[poison]], %[[poison]], %[[poison]]] : memref<?x?x?xi1> |