summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2023-12-01 17:35:55 -0800
committerGitHub <noreply@github.com>2023-12-02 02:35:55 +0100
commitccfc2d687c106ee8430fccd09e165e0aaea39081 (patch)
treedb3d784512624369ef236f057e6614dd9d68d555
parentfc74db466b0d2b87d2013d5e24be137f0d8b6f0a (diff)
[mlir][transform] Remove `cachedNames` expensive check (#73961)
This check was trying to find cases of invalid API usage: incorrect/missing handle side effects and/or incorrect rewriter usage. This check is not implemented correctly and can report false positives in case of pointer reuse (different op created at same location). It is unclear if such a check can be implemented given that we have both tracking listener-based handle updates and handle consumption. Fixes #72931.
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h12
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp124
2 files changed, 3 insertions, 133 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 5fcde11d52f0..2fdc15db9ad8 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -789,18 +789,6 @@ private:
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
SmallVector<Region *> regionStack;
-
- /// This cache stores operation names for operations that are tracked in the
- /// transform dialect state. It is used to detect missing memory side effects
- /// and op tracking.
- ///
- /// All tracked ops are added to this cache before a transform op is applied.
- /// After the application of the transform op, the names of all tracked ops
- /// are compared with the names in the cache. If there is a mismatch (or a
- /// crash), op tracking is missing somewhere. This is typically a missing
- /// "consumesHandle" side effect or a pattern that removes an op without
- /// notifying a TrackingListener.
- DenseMap<Operation *, OperationName> cachedNames;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index d0cd879d560c..de5b7a81286b 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -386,23 +386,6 @@ transform::TransformState::replacePayloadOp(Operation *op,
dropMappingEntry(mappings.reverse, op, handle);
}
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- if (options.getExpensiveChecksEnabled()) {
- auto it = cachedNames.find(op);
- assert(it != cachedNames.end() && "entry not found");
- assert(it->second == op->getName() && "operation name mismatch");
- cachedNames.erase(it);
- if (replacement) {
- auto insertion =
- cachedNames.insert({replacement, replacement->getName()});
- if (!insertion.second) {
- assert(insertion.first->second == replacement->getName() &&
- "operation is already cached with a different name");
- }
- }
- }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
// Replace the pointed-to object of all handles with the replacement object.
// In case a payload op was erased (replacement object is nullptr), a nullptr
// is stored in the mapping. These nullptrs are removed after each transform.
@@ -494,10 +477,10 @@ void transform::TransformState::recordOpHandleInvalidationOne(
unsigned operandNo = consumingHandle.getOperandNumber();
for (Operation *ancestor : potentialAncestors) {
// clang-format off
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
{ (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); });
- DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
- { (DBGS() << "----of payload with name: "
+ DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
+ { (DBGS() << "----of payload with name: "
<< payloadOp->getName().getIdentifier() << "\n"); });
DEBUG_WITH_TYPE(DEBUG_TYPE_FULL,
{ (DBGS() << "----of payload: " << *payloadOp << "\n"); });
@@ -872,29 +855,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n");
}
}
-
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- // Cache Operation* -> OperationName mappings. These will be checked after
- // the transform has been applied to detect incorrect memory side effects
- // and missing op tracking.
- for (std::unique_ptr<Mappings> &mapping :
- llvm::make_second_range(mappings)) {
- for (Operation *op : llvm::make_first_range(mapping->reverse)) {
- auto insertion = cachedNames.insert({op, op->getName()});
- if (!insertion.second) {
- if (insertion.first->second != op->getName()) {
- // Operation is already in the cache, but with a different name.
- DiagnosedDefiniteFailure diag =
- emitDefiniteFailure(transform->getLoc())
- << "expensive checks failure: operation mismatch, expected "
- << insertion.first->second;
- diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
- return diag;
- }
- }
- }
- }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
// Find which operands are consumed.
@@ -908,22 +868,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
// IR after that.
SmallVector<Value> origOpFlatResults;
SmallVector<Operation *> origAssociatedOps;
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- DenseSet<Operation *> consumedPayloadOps;
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
for (OpOperand *opOperand : consumedOperands) {
Value operand = opOperand->get();
if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) {
for (Operation *payloadOp : getPayloadOps(operand)) {
llvm::append_range(origOpFlatResults, payloadOp->getResults());
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- if (options.getExpensiveChecksEnabled()) {
- // Store all consumed payload ops (and their nested ops) in a set for
- // extra error checking.
- payloadOp->walk(
- [&](Operation *op) { consumedPayloadOps.insert(op); });
- }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
continue;
}
@@ -1004,63 +953,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
}
-#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- if (options.getExpensiveChecksEnabled()) {
- // Remove erased ops from the transform state.
- for (Operation *op : consumedPayloadOps) {
- // This payload op was consumed but it may still be mapped to one or
- // multiple handles. Forget all handles that are mapped to the op, so that
- // there are no dangling pointers in the transform dialect state. This is
- // necessary so that the `cachedNames`-based checks work correctly.
- //
- // Note: Dangling pointers to erased payload ops are allowed if the
- // corresponding handles are not used anymore. There is another
- // "expensive-check" that looks for future uses of dangling payload op
- // pointers (through arbitrary handles). Removing handles to erased ops
- // does not interfere with the other expensive checks: handle invalidation
- // happens earlier and keeps track of invalidated handles with
- // pre-generated error messages, so we do not need the association to
- // still be there when the invalidated handle is accessed.
- SmallVector<Value> handles;
- (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true);
- for (Value handle : handles)
- forgetMapping(handle, /*origOpFlatResults=*/ValueRange(),
- /*allowOutOfScope=*/true);
- cachedNames.erase(op);
- }
-
- // Check cached operation names.
- for (std::unique_ptr<Mappings> &mapping :
- llvm::make_second_range(mappings)) {
- for (Operation *op : llvm::make_first_range(mapping->reverse)) {
- // Make sure that the name of the op has not changed. If it has changed,
- // the op was removed and a new op was allocated at the same memory
- // location. This means that we are missing op tracking somewhere.
- auto cacheIt = cachedNames.find(op);
- if (cacheIt == cachedNames.end()) {
- DiagnosedDefiniteFailure diag =
- emitDefiniteFailure(transform->getLoc())
- << "expensive checks failure: operation not found in cache";
- diag.attachNote(op->getLoc()) << "payload op";
- return diag;
- }
- // If the `getName` call (or the above `attachNote`) is crashing, we
- // have a dangling pointer. This usually means that an op was erased but
- // the transform dialect was not made aware of that; e.g., missing
- // "consumesHandle" or rewriter usage.
- if (cacheIt->second != op->getName()) {
- DiagnosedDefiniteFailure diag =
- emitDefiniteFailure(transform->getLoc())
- << "expensive checks failure: operation mismatch, expected "
- << cacheIt->second;
- diag.attachNote(op->getLoc()) << "payload op: " << op->getName();
- return diag;
- }
- }
- }
- }
-#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
-
if (failed(updateStateFromResults(results, transform->getResults())))
return DiagnosedSilenceableFailure::definiteFailure();
@@ -1150,16 +1042,6 @@ transform::TransformState::RegionScope::~RegionScope() {
state.mappings.erase(region);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
- // If the last handle to a payload op has gone out of scope, we no longer
- // need to store the cached name. Pointers may get reused, leading to
- // incorrect associations in the cache.
- for (Operation *op : referencedOps) {
- SmallVector<Value> handles;
- if (succeeded(state.getHandlesForPayloadOp(op, handles)))
- continue;
- state.cachedNames.erase(op);
- }
-
state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}