diff options
author | Matthias Springer <me@m-sp.org> | 2023-12-01 17:35:55 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-02 02:35:55 +0100 |
commit | ccfc2d687c106ee8430fccd09e165e0aaea39081 (patch) | |
tree | db3d784512624369ef236f057e6614dd9d68d555 | |
parent | fc74db466b0d2b87d2013d5e24be137f0d8b6f0a (diff) |
[mlir][transform] Remove `cachedNames` expensive check (#73961)
This check was trying to find cases of invalid API usage:
incorrect/missing handle side effects and/or incorrect rewriter usage.
This check is not implemented correctly and can report false positives
in case of pointer reuse (different op created at same location). It is
unclear if such a check can be implemented given that we have both
tracking listener-based handle updates and handle consumption.
Fixes #72931.
-rw-r--r-- | mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h | 12 | ||||
-rw-r--r-- | mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp | 124 |
2 files changed, 3 insertions, 133 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 5fcde11d52f0..2fdc15db9ad8 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -789,18 +789,6 @@ private: /// Each region must be an ancestor of the following regions in this list. /// These are also the keys for "mappings". SmallVector<Region *> regionStack; - - /// This cache stores operation names for operations that are tracked in the - /// transform dialect state. It is used to detect missing memory side effects - /// and op tracking. - /// - /// All tracked ops are added to this cache before a transform op is applied. - /// After the application of the transform op, the names of all tracked ops - /// are compared with the names in the cache. If there is a mismatch (or a - /// crash), op tracking is missing somewhere. This is typically a missing - /// "consumesHandle" side effect or a pattern that removes an op without - /// notifying a TrackingListener. - DenseMap<Operation *, OperationName> cachedNames; #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS }; diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index d0cd879d560c..de5b7a81286b 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -386,23 +386,6 @@ transform::TransformState::replacePayloadOp(Operation *op, dropMappingEntry(mappings.reverse, op, handle); } -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (options.getExpensiveChecksEnabled()) { - auto it = cachedNames.find(op); - assert(it != cachedNames.end() && "entry not found"); - assert(it->second == op->getName() && "operation name mismatch"); - cachedNames.erase(it); - if (replacement) { - auto insertion = - cachedNames.insert({replacement, replacement->getName()}); - if (!insertion.second) { - assert(insertion.first->second == replacement->getName() && - "operation is already cached with a different name"); - } - } - } -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - // Replace the pointed-to object of all handles with the replacement object. // In case a payload op was erased (replacement object is nullptr), a nullptr // is stored in the mapping. These nullptrs are removed after each transform. @@ -494,10 +477,10 @@ void transform::TransformState::recordOpHandleInvalidationOne( unsigned operandNo = consumingHandle.getOperandNumber(); for (Operation *ancestor : potentialAncestors) { // clang-format off - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, + DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----of payload with name: " + DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, + { (DBGS() << "----of payload with name: " << payloadOp->getName().getIdentifier() << "\n"); }); DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { (DBGS() << "----of payload: " << *payloadOp << "\n"); }); @@ -872,29 +855,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); } } - -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - // Cache Operation* -> OperationName mappings. These will be checked after - // the transform has been applied to detect incorrect memory side effects - // and missing op tracking. - for (std::unique_ptr<Mappings> &mapping : - llvm::make_second_range(mappings)) { - for (Operation *op : llvm::make_first_range(mapping->reverse)) { - auto insertion = cachedNames.insert({op, op->getName()}); - if (!insertion.second) { - if (insertion.first->second != op->getName()) { - // Operation is already in the cache, but with a different name. - DiagnosedDefiniteFailure diag = - emitDefiniteFailure(transform->getLoc()) - << "expensive checks failure: operation mismatch, expected " - << insertion.first->second; - diag.attachNote(op->getLoc()) << "payload op: " << op->getName(); - return diag; - } - } - } - } -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } // Find which operands are consumed. @@ -908,22 +868,11 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // IR after that. SmallVector<Value> origOpFlatResults; SmallVector<Operation *> origAssociatedOps; -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - DenseSet<Operation *> consumedPayloadOps; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS for (OpOperand *opOperand : consumedOperands) { Value operand = opOperand->get(); if (llvm::isa<TransformHandleTypeInterface>(operand.getType())) { for (Operation *payloadOp : getPayloadOps(operand)) { llvm::append_range(origOpFlatResults, payloadOp->getResults()); -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (options.getExpensiveChecksEnabled()) { - // Store all consumed payload ops (and their nested ops) in a set for - // extra error checking. - payloadOp->walk( - [&](Operation *op) { consumedPayloadOps.insert(op); }); - } -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } continue; } @@ -1004,63 +953,6 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { } } -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - if (options.getExpensiveChecksEnabled()) { - // Remove erased ops from the transform state. - for (Operation *op : consumedPayloadOps) { - // This payload op was consumed but it may still be mapped to one or - // multiple handles. Forget all handles that are mapped to the op, so that - // there are no dangling pointers in the transform dialect state. This is - // necessary so that the `cachedNames`-based checks work correctly. - // - // Note: Dangling pointers to erased payload ops are allowed if the - // corresponding handles are not used anymore. There is another - // "expensive-check" that looks for future uses of dangling payload op - // pointers (through arbitrary handles). Removing handles to erased ops - // does not interfere with the other expensive checks: handle invalidation - // happens earlier and keeps track of invalidated handles with - // pre-generated error messages, so we do not need the association to - // still be there when the invalidated handle is accessed. - SmallVector<Value> handles; - (void)getHandlesForPayloadOp(op, handles, /*includeOutOfScope=*/true); - for (Value handle : handles) - forgetMapping(handle, /*origOpFlatResults=*/ValueRange(), - /*allowOutOfScope=*/true); - cachedNames.erase(op); - } - - // Check cached operation names. - for (std::unique_ptr<Mappings> &mapping : - llvm::make_second_range(mappings)) { - for (Operation *op : llvm::make_first_range(mapping->reverse)) { - // Make sure that the name of the op has not changed. If it has changed, - // the op was removed and a new op was allocated at the same memory - // location. This means that we are missing op tracking somewhere. - auto cacheIt = cachedNames.find(op); - if (cacheIt == cachedNames.end()) { - DiagnosedDefiniteFailure diag = - emitDefiniteFailure(transform->getLoc()) - << "expensive checks failure: operation not found in cache"; - diag.attachNote(op->getLoc()) << "payload op"; - return diag; - } - // If the `getName` call (or the above `attachNote`) is crashing, we - // have a dangling pointer. This usually means that an op was erased but - // the transform dialect was not made aware of that; e.g., missing - // "consumesHandle" or rewriter usage. - if (cacheIt->second != op->getName()) { - DiagnosedDefiniteFailure diag = - emitDefiniteFailure(transform->getLoc()) - << "expensive checks failure: operation mismatch, expected " - << cacheIt->second; - diag.attachNote(op->getLoc()) << "payload op: " << op->getName(); - return diag; - } - } - } - } -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - if (failed(updateStateFromResults(results, transform->getResults()))) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1150,16 +1042,6 @@ transform::TransformState::RegionScope::~RegionScope() { state.mappings.erase(region); #if LLVM_ENABLE_ABI_BREAKING_CHECKS - // If the last handle to a payload op has gone out of scope, we no longer - // need to store the cached name. Pointers may get reused, leading to - // incorrect associations in the cache. - for (Operation *op : referencedOps) { - SmallVector<Value> handles; - if (succeeded(state.getHandlesForPayloadOp(op, handles))) - continue; - state.cachedNames.erase(op); - } - state.regionStack.pop_back(); #endif // LLVM_ENABLE_ABI_BREAKING_CHECKS } |