author    Peiming Liu <36770114+PeimingLiu@users.noreply.github.com>  2023-11-08 11:28:00 -0800
committer GitHub <noreply@github.com>  2023-11-08 11:28:00 -0800
commit    c99951d4916e18c9191d6a25a4a4fb1b2243d4c4
tree      b37d4297d37f541cd489cbe5519787981c0c9bb1
parent    64d413efdd76f2e6464ae6f578161811b9d12411
[mlir][sparse] end-to-end matmul between Dense and BSR tensors (#71448)
 mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp              | 461
 mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp |   4
 mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir                      |   8
 mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir         | 123
 4 files changed, 471 insertions(+), 125 deletions(-)
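
For orientation, below is a minimal sketch, condensed from the new integration test added in this patch, of the kind of kernel that can now be sparsified end-to-end: a linalg.generic contraction whose second operand carries a BSR (block sparse row) encoding. The value names %A, %B, %C are illustrative; #trait_mul and #BSR are spelled out in the test file further down.

#BSR = #sparse_tensor.encoding<{
  map = ( i, j ) ->
    ( i floordiv 2 : dense,
      j floordiv 3 : compressed,
      i mod 2      : dense,
      j mod 3      : dense )
}>

// C(i,j) += A(i,k) * B(j,k), with B stored in 2x3-blocked BSR form.
%0 = linalg.generic #trait_mul
       ins(%A, %B : tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
       outs(%C : tensor<4x4xf64>) {
  ^bb0(%a: f64, %b: f64, %c: f64):
    %p = arith.mulf %a, %b : f64
    %s = arith.addf %p, %c : f64
    linalg.yield %s : f64
} -> tensor<4x4xf64>
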
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 307a609fd1b7..83e86d137335 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -16,27 +16,316 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
using namespace mlir;
using namespace mlir::sparse_tensor;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP to help implementing a rewriter that demaps all its inputs.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
+ using OpAdaptor = typename SourceOp::Adaptor;
+
+ LogicalResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const override {
+ Location loc = op.getLoc();
+ // Demaps non-trivial inputs.
+ SmallVector<Value> deMappedIns(op->getOperands());
+ for (Value &in : deMappedIns)
+ if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+ in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+ // CRTP call.
+ OpAdaptor adaptor(deMappedIns, op);
+ return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+ rewriter);
+ }
+};
+
+// Collects the positions of the AffineDimExprs used in an affine expression.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+ explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
+ void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+ BitVector dims;
+};
+
+// Checks whether an affine expression is admissible as a level expression.
+struct AffineExprAdmissibleVisitor
+ : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+ explicit AffineExprAdmissibleVisitor(bool isOutput)
+ : admissible(true), isOutput(isOutput){};
+
+ // We only allow AffineDimExpr on output.
+ void visitAddExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+ void visitMulExpr(AffineBinaryOpExpr expr) {
+ if (isOutput)
+ admissible = false;
+ }
+
+ // We disallow mod, floor div and ceil div on inputs.
+ void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+ operator bool() { return admissible; }
+
+private:
+ bool admissible;
+ bool isOutput;
+};
+
+// The first BitVector stores the levels on which inadmissible expressions appear.
+// The second BitVector stores the AffineDimExprs that are used by those
+// inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// File Local Helper methods.
//===----------------------------------------------------------------------===//
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
- AffineMap map) {
- unsigned lvlRank = stt.getLvlRank();
- AffineMap lvl2dim = stt.getLvlToDim();
- assert(lvl2dim.getNumInputs() == lvlRank);
- SmallVector<AffineExpr> exps;
- for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
- unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
- exps.push_back(lvl2dim.getResult(pos));
+// Collects the inadmissible affine expressions imposed on levels.
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+ auto ret = std::make_pair(BitVector(map.getNumResults()),
+ BitVector(map.getNumDims()));
+ AffineDimCollector collector(map.getNumDims());
+ for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+ AffineExprAdmissibleVisitor admissible(isOutput);
+ admissible.walkPostOrder(map.getResult(lvl));
+ if (!admissible) {
+ // Record the inadmissible level.
+ ret.first.set(lvl);
+ // Record the AffineDimExpr that is used in the inadmissible expr.
+ collector.walkPostOrder(map.getResult(lvl));
+ }
+ }
+ ret.second = collector.dims;
+ return ret;
+}
+
+// Builds the AffineMap that replaces the dims in idxMap with lvls such that
+// all the inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+// idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+// idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+// inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+// -> ((l0 * 2 + l2) floordiv 2,
+// (l1 * 3 + l3) floordiv 3,
+// (l0 * 2 + l2) mod 2,
+// (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with levels, and updates the iterator type array `itTps` for the
+// new index variables introduced.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
+// for the replaced levels, and the remaining ones for the dimensions that are
+// not replaced.
+// For example, to handle
+// idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+// which is a typical map for block_2to4, the function returns
+// inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which (l0, l1) together replace `d1`, yet they appear
+// before `d0` in the parameter list of the resulting affine map.
+// The index (loop) order can later be canonicalized by a topological sort.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+ SmallVector<utils::IteratorType> &itTps) {
+ MLIRContext *ctx = idxMap.getContext();
+ auto [inAdLvls, usedDims] = info;
+ // Note that idxMap is not equal to the dim2Lvl map; it is computed by
+ // composing dim2Lvl(idx2Dim). The two are only equal when idx2Dim is an
+ // identity map.
+ // TODO: we might fail here; in those cases we should really return
+ // failure instead of asserting.
+ auto lvl2Idx = inferLvlToDim(idxMap, ctx);
+
+ assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
+ if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
+ // This could happen when some dimensions are projected.
+ // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
+ // ==> lvl2Idx = (j, k) -> (j, k)
+ // In this case, we append the unused dimensions at the end.
+ // ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
+ SmallVector<AffineExpr> results;
+ AffineDimCollector usedInLvl(idxMap.getNumDims());
+ for (auto e : idxMap.getResults())
+ usedInLvl.walkPostOrder(e);
+
+ unsigned curUsedDimID = 0;
+ unsigned curUnusedDimID = lvl2Idx.getNumDims();
+
+ BitVector unused = usedInLvl.dims.flip();
+ for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
+ if (unused.test(i))
+ results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
+ else
+ results.push_back(lvl2Idx.getResult(curUsedDimID++));
+ }
+ lvl2Idx =
+ AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
+ }
+ assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
+
+ // We do not need to replace the DimExprs that are not used in inadmissible
+ // level expressions. We use the first inAdLvls.count() dims to represent the
+ // replaced levels; the remaining ones are reserved for unchanged dimensions.
+ // Note that the results from the inverse map computed previously do not
+ // follow this convention, and we need to fix the mismatch below.
+ unsigned curRepID = 0;
+ unsigned curOriID = inAdLvls.count();
+ SmallVector<AffineExpr> results;
+ SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
+ SmallVector<utils::IteratorType> transItTps;
+
+ for (unsigned l : inAdLvls.set_bits()) {
+ // By our convention, the inadmissible level `l` always appears in the
+ // leading part (accumulated by curRepID) of the affine map's parameter
+ // list. Record the mapping so that we can redirect all uses of `l` to its
+ // new position after the translation.
+ dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+ // A new index variable is introduced for the inadmissible level; it inherits
+ // the iterator type of the dimension it is derived from. E.g., if
+ // l0 = d0 floordiv 2, the iterator type of l0 equals the iterator type of d0.
+ AffineExpr lvlExp = idxMap.getResult(l);
+ AffineDimCollector collector(idxMap.getNumDims());
+ collector.walkPostOrder(lvlExp);
+ // We assume a level can only be derived from one dimension.
+ assert(collector.dims.count() == 1);
+ transItTps.push_back(itTps[collector.dims.find_first()]);
+ }
+
+ for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
+ if (usedDims.test(d)) {
+ // The dimension is used in some of the inadmissible levels and needs
+ // to be inverted. Get the inverse expression from the inverse map, and
+ // fix the mismatch captured by the above loop.
+ results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
+ } else {
+ // The dimension is not used in any of the inadmissible levels, and it
+ // does not need to be inverted. Fix the mismatch by mapping it to the
+ // trailing part of the affine map (accumulated by curOriID).
+ results.push_back(getAffineDimExpr(curOriID++, ctx));
+ transItTps.push_back(itTps[d]);
+ }
}
- return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+ unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
+ // Update iterator type.
+ itTps.assign(transItTps.begin(), transItTps.end());
+ return AffineMap::get(numDim, 0, results, ctx);
+}
+
+// Translates the index maps in the linalg::GenericOp from idx->dim maps to
+// idx->lvl maps. Returns std::nullopt if the index maps cannot be translated
+// to an admissible form.
+// On success, returns the translated index map array and the iterator type
+// array.
+static std::optional<std::pair<ArrayAttr, ArrayAttr>>
+translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
+ // idxMap is an idx2dim map before reinterpretation.
+ MLIRContext *ctx = op.getContext();
+ SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
+ SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
+ for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+ Value tensor = op->getOpOperand(i).get();
+ auto stt = tryGetSparseTensorType(tensor);
+ if (stt && !stt->isIdentity()) {
+ AffineMap dim2Lvl = stt->getDimToLvl();
+ // By composing dim2lvl(idx2dim), we get an idx2lvl map.
+ idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
+ }
+ }
+
+ // A naive way to handle the common div/mod expressions over constant level
+ // sizes that arise during the dim2lvl translation.
+ auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
+ unsigned pos, int64_t lvlSz) {
+ if (!ShapedType::isDynamic(lvlSz)) {
+ auto c0 = getAffineConstantExpr(0, ctx);
+ auto lvlExp = getAffineDimExpr(pos, ctx);
+ auto szExp = getAffineConstantExpr(lvlSz, ctx);
+
+ // lvl floordiv lvlSz = 0
+ auto divExp =
+ getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
+ cstMapping.try_emplace(divExp, c0);
+
+ // lvl mod lvlSz = lvl
+ auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
+ cstMapping.try_emplace(modExp, lvlExp);
+ }
+ };
+
+ unsigned boundedNum = 0;
+ // A fixed-point algorithm.
+ bool changed = true;
+ while (changed) {
+ changed = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto stt = tryGetSparseTensorType(operand.get());
+ // Skip dense operands.
+ if (!stt || !stt->getEncoding())
+ continue;
+
+ unsigned tid = operand.getOperandNumber();
+ bool isOutput = &operand == op.getDpsInitOperand(0);
+ AffineMap idxMap = idxMapArray[tid];
+ InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
+ auto [inAdLvls, dimExprs] = inAdInfo;
+ for (unsigned d : dimExprs.set_bits()) {
+ // The first `boundedNum` dims in the AffineMap were introduced to
+ // resolve previous inadmissible expressions. We cannot replace them,
+ // as doing so might bring back the inadmissible expressions.
+ if (d < boundedNum)
+ return std::nullopt;
+ }
+
+ if (inAdLvls.count() != 0) {
+ // Naive constant propagation, which should be sufficient to handle block
+ // sparsity in our cases.
+ SmallVector<int64_t> lvlShape = stt->getLvlShape();
+ DenseMap<AffineExpr, AffineExpr> cstMapping;
+ unsigned position = 0;
+ for (unsigned lvl : inAdLvls.set_bits()) {
+ int64_t lvlSz = lvlShape[lvl];
+ populateCstMapping(cstMapping, position, lvlSz);
+ position++;
+ }
+
+ AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
+ // Compose the lvl2Idx map with every affine index map to eliminate the
+ // inadmissible expressions.
+ for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
+ AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
+ idxMapArray[tid] = transMap.replace(
+ cstMapping, /*numResultDims=*/transMap.getNumDims(),
+ /*numResultSyms=*/0);
+ }
+ changed = true;
+ boundedNum += inAdLvls.count();
+ }
+ }
+ };
+
+ SmallVector<Attribute> iterAttr =
+ llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
+ return linalg::IteratorTypeAttr::get(ctx, itTp);
+ });
+
+ return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
+ rewriter.getArrayAttr(iterAttr));
}
// Generates a "de"mapping reinterpretation of the map.
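
To make the new translateMap concrete, here is a worked example based on the updated regression test in sparse_reinterpret_map.mlir below (the translated maps appear verbatim in its CHECK lines; the "before" maps are inferred, assuming the 32x32 output uses a 2x4 blocked encoding with dim-to-lvl map (i, j) -> (i floordiv 2, j floordiv 4, i mod 2, j mod 4)). Composing the output's idx-to-dim map with that dim-to-lvl map and inverting the inadmissible expressions yields the substitution d0 = l0 * 2 + l2, d1 = l1 * 4 + l3 over the 4-D level space; the naive constant mapping (lvl floordiv blockSize -> 0, lvl mod blockSize -> lvl) then lets the output map reduce to the identity over the levels, since each trailing level index is bounded by its block size.

Before (inferred idx-to-dim maps of the original 2-D generic; only the output is sparse):
  A:   affine_map<(d0, d1) -> (d0, d1)>
  B:   affine_map<(d0, d1) -> (d1, d0)>
  out: affine_map<(d0, d1) -> (d0, d1)>

After translateMap (idx-to-lvl maps over the 4-D level space; these are #map0/#map1/#map2 in the CHECK lines):
  A:   affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
  B:   affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
  out: affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
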
@@ -73,41 +362,6 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
llvm::any_of(op->getResults(), hasNonIdentityMap);
}
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As decribed below, this is proof-of-concept code which makes a lot
-// of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
- linalg::GenericOp linalgOp,
- SparseTensorType stt, Value out) {
- unsigned dimRank = stt.getDimRank();
- unsigned lvlRank = stt.getLvlRank();
- SmallVector<Value> inputOps = linalgOp.getInputs();
- SmallVector<Value> outputOps = {out};
- SmallVector<AffineMap> indexMaps;
- SmallVector<utils::IteratorType> iterTypes;
- // Translate the index maps, except output map, which is lvl-identity.
- auto maps = linalgOp.getIndexingMapsArray();
- for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
- indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
- indexMaps.push_back(
- AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
- // Add additional "parallel" iteration types at the top.
- for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
- iterTypes.push_back(utils::IteratorType::parallel);
- for (auto &i : linalgOp.getIteratorTypesArray())
- iterTypes.push_back(i);
- // Generate the new linalg generic operation and clone body.
- auto newOp = rewriter.create<linalg::GenericOp>(
- linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
- iterTypes);
- rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
- newOp.getRegion().begin());
- return newOp;
-}
-
namespace {
//===----------------------------------------------------------------------===//
@@ -115,53 +369,45 @@ namespace {
//===----------------------------------------------------------------------===//
/// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+ : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
public:
- GenericOpReinterpretMap(MLIRContext *context)
- : OpRewritePattern<linalg::GenericOp>(context) {}
-
- LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
- PatternRewriter &rewriter) const override {
- // Only rewrite single output operations with pure tensor semantics.
- if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+ using DemapInsRewriter::DemapInsRewriter;
+ LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+ PatternRewriter &rewriter) const {
+ // Only rewrite single output operations with pure (sparse) tensor
+ // semantics.
+ if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+ !hasAnySparseOperandOrResult(linalgOp) ||
+ !hasNonIdentityOperandsOrResults(linalgOp))
return failure();
- // Scan all operands, inspect sparse tensors.
- //
- // TODO: generalize this proof-of-concept algorithm, since the current
- // implementation accepts only simple indexing maps, and one
- // non-permutation sparse tensor, which must have an identity
- // indexing map and be the output.
- //
- OpOperand *tx = nullptr;
- for (OpOperand &t : linalgOp->getOpOperands()) {
- // Ensure every index map is "simple".
- const auto map = linalgOp.getMatchingIndexingMap(&t);
- for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
- if (map.getResult(i).getKind() != AffineExprKind::DimId)
- return failure();
- // Inspect sparse operands.
- auto stt = tryGetSparseTensorType(t.get());
- if (stt && stt->hasEncoding()) {
- if (stt->isPermutation())
- continue;
- assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
- if (tx)
- return failure(); // more than one non-perm
- if (!map.isIdentity())
- return failure(); // no ID indexing map on the non-perm
- tx = &t;
- }
- }
- // Found a non-permutation, rewrite when this is the output.
- if (tx && tx == linalgOp.getDpsInitOperand(0)) {
- auto stt = getSparseTensorType(tx->get());
- auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
- auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
- auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
- rewriter.replaceOp(linalgOp, remap);
- return success();
+
+ // Try translating the index map.
+ auto transMap = translateMap(linalgOp, rewriter);
+ if (!transMap)
+ return rewriter.notifyMatchFailure(
+ linalgOp, "the sparse kernel can not be sparsified.");
+
+ // On success, update the linalg operands and maps in place.
+ Value res = linalgOp.getResult(0);
+ auto stt = tryGetSparseTensorType(res);
+ auto [idxMap, itTp] = *transMap;
+
+ rewriter.startRootUpdate(linalgOp);
+ linalgOp.setIndexingMapsAttr(idxMap);
+ linalgOp.setIteratorTypesAttr(itTp);
+ // Use demapped arguments.
+ linalgOp.getInputsMutable().assign(adaptor.getInputs());
+ linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
+ res.setType(adaptor.getOutputs()[0].getType());
+ rewriter.finalizeRootUpdate(linalgOp);
+
+ rewriter.setInsertionPointAfter(linalgOp);
+ if (stt && stt->hasEncoding()) {
+ Value t = genRemap(rewriter, stt->getEncoding(), res);
+ rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
}
- return failure();
+ return success();
}
};
@@ -169,32 +415,10 @@ public:
// Reinterpret Map Rewriters for operations other than linalg.generics
//===----------------------------------------------------------------------===//
-// CRTP to help implementing a rewriter that demaps all its inputs.
-template <typename SubClass, typename SourceOp>
-struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
- using OpRewritePattern<SourceOp>::OpRewritePattern;
- using OpAdaptor = typename SourceOp::Adaptor;
-
- LogicalResult matchAndRewrite(SourceOp op,
- PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
- // Demaps non-trivial inputs.
- SmallVector<Value> deMappedIns(op->getOperands());
- for (Value &in : deMappedIns)
- if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
- in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
-
- // CRTP call.
- OpAdaptor adaptor(deMappedIns);
- return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
- rewriter);
- }
-};
-
-struct TensorAllocDemapper
- : public OpRewritePattern<bufferization::AllocTensorOp> {
- using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+template <typename AllocOp>
+struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(AllocOp op,
PatternRewriter &rewriter) const override {
if (!hasNonIdentityOperandsOrResults(op))
return failure();
@@ -362,7 +586,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
}
if (scope == ReinterpretMapScope::kAll ||
scope == ReinterpretMapScope::kExceptGeneric) {
- patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
- patterns.getContext());
+ patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
+ TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
+ ForeachOpDemapper>(patterns.getContext());
}
}
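
Taken together with the demapping CRTP above, the net effect of GenericOpReinterpretMap on the IR has the shape sketched below, mirroring the CHECK lines of the regression test that follows. The demapped shape 16x8x2x4, the #BSR_demapped name, and the elementwise body are illustrative assumptions (a 32x32 tensor with 2x4 blocks); the actual level type comes from getDemappedType and the body is whatever the original generic contained.

// Demap the blocked init operand into its level representation.
%t0 = sparse_tensor.reinterpret_map %a2
        : tensor<32x32xf32, #BSR> to tensor<16x8x2x4xf32, #BSR_demapped>
// The generic now iterates over the 4-D level space with the translated maps.
%t1 = linalg.generic {indexing_maps = [#map0, #map1, #map2],
                      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%a0, %a1 : tensor<32x32xf32>, tensor<32x32xf32>)
        outs(%t0 : tensor<16x8x2x4xf32, #BSR_demapped>) {
  ^bb0(%x: f32, %y: f32, %z: f32):
    %m = arith.mulf %x, %y : f32
    linalg.yield %m : f32
} -> tensor<16x8x2x4xf32, #BSR_demapped>
// Remap the result back to the original dim space for the op's users.
%t2 = sparse_tensor.reinterpret_map %t1
        : tensor<16x8x2x4xf32, #BSR_demapped> to tensor<32x32xf32, #BSR>
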
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 4a293f6819d0..e20b98add19a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -139,13 +139,11 @@ public:
// of `bufferization.alloc_tensor` ops.
{
OpPassManager pm("builtin.module");
- pm.addPass(
- createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
+ pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
pm.addPass(createSparsificationPass(sparsificationOptions));
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
- // Handle dim-to-lvl maps on operations other than linalg.generic.
pm.addPass(
createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 972364289ac2..c4931c62c626 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -19,15 +19,15 @@
)
}>
-// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
-// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
-// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @mul(
// CHECK-SAME: %[[A0:.*0]]: tensor<32x32xf32>,
// CHECK-SAME: %[[A1:.*1]]: tensor<32x32xf32>,
// CHECK-SAME: %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
-// CHECK: %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK: %[[T1:.*]] = linalg.generic {doc = {{.*}} indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// CHECK: %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
// CHECK: return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
func.func @mul(%arg0: tensor<32x32xf32>,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
new file mode 100644
index 000000000000..7e9606b1caed
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -0,0 +1,123 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users that
+// do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#trait_mul = {
+ indexing_maps = [
+ affine_map<(i,j,k) -> (i,k)>, // A (in)
+ affine_map<(i,j,k) -> (j,k)>, // B (in, transposed)
+ affine_map<(i,j,k) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel", "reduction"],
+ doc = "X(i,j) += A(i,k) * B(j,k)"
+}
+
+
+#BSR = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i floordiv 2 : dense,
+ j floordiv 3 : compressed,
+ i mod 2 : dense,
+ j mod 3 : dense
+ )
+}>
+
+module {
+
+func.func @mul(%arg0: tensor<4x6xf64>,
+ %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
+ %out = tensor.empty() : tensor<4x4xf64>
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+ outs(%out: tensor<4x4xf64>) {
+ ^bb(%x: f64, %y : f64, %z : f64):
+ %1 = arith.mulf %x, %y : f64
+ %2 = arith.addf %1, %z : f64
+ linalg.yield %2 : f64
+ } -> tensor<4x4xf64>
+ return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_dense(%arg0: tensor<4x6xf64>,
+ %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
+ %out = tensor.empty() : tensor<4x4xf64>
+ %0 = linalg.generic #trait_mul
+ ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+ outs(%out: tensor<4x4xf64>) {
+ ^bb(%x: f64, %y : f64, %z : f64):
+ %1 = arith.mulf %x, %y : f64
+ %2 = arith.addf %1, %z : f64
+ linalg.yield %2 : f64
+ } -> tensor<4x4xf64>
+ return %0 : tensor<4x4xf64>
+}
+
+
+ //
+ // Output utilities.
+ //
+ func.func @dumpf64(%arg0: tensor<4x4xf64>) {
+ %c0 = arith.constant 0 : index
+ %d0 = arith.constant -1.0 : f64
+ %0 = vector.transfer_read %arg0[%c0, %c0], %d0: tensor<4x4xf64>, vector<4x4xf64>
+ vector.print %0 : vector<4x4xf64>
+ return
+ }
+
+ //
+ // Main driver.
+ //
+ func.func @entry() {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+
+
+ %td = arith.constant dense<[[ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
+ [ 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
+ [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
+ [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+
+
+ %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+
+ %d = call @mul_dense(%td, %td)
+ : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+ %s = call @mul(%td, %2)
+ : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+
+ // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+ call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
+ call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+
+ return
+ }
+}
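
As a quick sanity check of the FileCheck expectation: both calls feed the same 4x6 constant for A and B, so the kernel computes X(i,j) = sum_k td(i,k) * td(j,k). For the first row:

  X(0,0) = 0*0  + 1*1  + 2*2  + 3*3  + 4*4  + 5*5   = 55
  X(0,1) = 0*6  + 1*7  + 2*8  + 3*9  + 4*10 + 5*11  = 145
  X(0,2) = 0*12 + 1*13 + 2*14 + 3*15 + 4*16 + 5*17  = 235
  X(0,3) = 0*18 + 1*19 + 2*20 + 3*21 + 4*22 + 5*23  = 325

which matches the first row of the CHECK-COUNT-2 vector printed by both the dense and the BSR run.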