author    | Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> | 2023-11-08 11:28:00 -0800
committer | GitHub <noreply@github.com>                                | 2023-11-08 11:28:00 -0800
commit    | c99951d4916e18c9191d6a25a4a4fb1b2243d4c4
tree      | b37d4297d37f541cd489cbe5519787981c0c9bb1
parent    | 64d413efdd76f2e6464ae6f578161811b9d12411
[mlir][sparse] end-to-end matmul between Dense and BSR tensors (#71448)
4 files changed, 471 insertions, 125 deletions
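
The patch teaches the sparsifier to handle kernels whose sparse operands carry non-permutation dim2lvl maps. A small illustrative sketch of the effect (not part of the patch; the encoding and index maps are taken from the new integration test below, the loop-variable names are assumptions):

  // B carries a 2x3 BSR encoding, i.e., its dim2lvl map is
  //   (i, j) -> (i floordiv 2, j floordiv 3, i mod 2, j mod 3).
  #BSR = #sparse_tensor.encoding<{
    map = (i, j) -> (i floordiv 2 : dense,
                     j floordiv 3 : compressed,
                     i mod 2      : dense,
                     j mod 3      : dense)
  }>

  // Composing B's dim2lvl map with its index map (i, j, k) -> (j, k) gives
  //   (i, j, k) -> (j floordiv 2, k floordiv 3, j mod 2, k mod 3),
  // which uses floordiv/mod on a sparse operand and is thus inadmissible.
  // The pass inverts the map by introducing level variables l0..l3 with
  //   j = l0 * 2 + l2,  k = l1 * 3 + l3,
  // and folds l floordiv sz -> 0 and l mod sz -> l for the block-local
  // levels, so the kernel is driven by five loops (l0, l1, l2, l3, i) with
  // roughly these maps:
  //   A: (l0, l1, l2, l3, i) -> (i, l1 * 3 + l3)
  //   B: (l0, l1, l2, l3, i) -> (l0, l1, l2, l3)   // demapped, lvl-identity
  //   X: (l0, l1, l2, l3, i) -> (i, l0 * 2 + l2)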
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index 307a609fd1b7..83e86d137335 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -16,27 +16,316 @@
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
+namespace {
+
+//===----------------------------------------------------------------------===//
+// File Local Helper classes.
+//===----------------------------------------------------------------------===//
+
+// CRTP to help implement a rewriter that demaps all its inputs.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+    // CRTP call.
+    OpAdaptor adaptor(deMappedIns, op);
+    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
+                                                          rewriter);
+  }
+};
+
+// Flattens an affine expression into a list of AffineDimExprs.
+struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
+  explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {}
+  void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
+  BitVector dims;
+};
+
+// Determines whether an affine expression is admissible as an input or
+// output index expression for the sparsifier.
+struct AffineExprAdmissibleVisitor
+    : public AffineExprVisitor<AffineExprAdmissibleVisitor> {
+  explicit AffineExprAdmissibleVisitor(bool isOutput)
+      : admissible(true), isOutput(isOutput) {}
+
+  // We only allow AffineDimExpr on output.
+  void visitAddExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+  void visitMulExpr(AffineBinaryOpExpr expr) {
+    if (isOutput)
+      admissible = false;
+  }
+
+  // We disallow mod, floordiv and ceildiv on inputs.
+  void visitModExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitFloorDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  void visitCeilDivExpr(AffineBinaryOpExpr expr) { admissible = false; }
+  operator bool() { return admissible; }
+
+private:
+  bool admissible;
+  bool isOutput;
+};
+
+// The first BitVector stores levels where inadmissible exprs are used.
+// The second BitVector stores the AffineDimExprs that are used by the
+// inadmissible expressions.
+using InadmissInfo = std::pair<BitVector, BitVector>;
+
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // File Local Helper methods.
 //===----------------------------------------------------------------------===//
 
-// Translates a "simple" map according to an identity lvl-map.
-static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
-                              AffineMap map) {
-  unsigned lvlRank = stt.getLvlRank();
-  AffineMap lvl2dim = stt.getLvlToDim();
-  assert(lvl2dim.getNumInputs() == lvlRank);
-  SmallVector<AffineExpr> exps;
-  for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
-    unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
-    exps.push_back(lvl2dim.getResult(pos));
+// Collects the inadmissible affine expressions imposed on levels.
+static InadmissInfo collectInadmissInfo(AffineMap map, bool isOutput) {
+  auto ret = std::make_pair(BitVector(map.getNumResults()),
+                            BitVector(map.getNumDims()));
+  AffineDimCollector collector(map.getNumDims());
+  for (unsigned lvl = 0, e = map.getNumResults(); lvl < e; lvl++) {
+    AffineExprAdmissibleVisitor admissible(isOutput);
+    admissible.walkPostOrder(map.getResult(lvl));
+    if (!admissible) {
+      // Record the inadmissible level.
+      ret.first.set(lvl);
+      // Record the AffineDimExprs that are used in the inadmissible expr.
+      collector.walkPostOrder(map.getResult(lvl));
+    }
+  }
+  ret.second = collector.dims;
+  return ret;
+}
+
+// Builds the AffineMap that replaces the dims in idxMap with lvls such that
+// all the inadmissible affine expressions can be eliminated.
+// For example, we can rewrite
+//   idxMap = (d0, d1) -> (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3)
+// to
+//   idxMap = (l0, l1, l2, l3) -> (l0, l1, l2, l3)
+// by composing inverse(idxMap), that is
+//   inverse(idxMap) . idxMap = (l0, l1, l2, l3) -> (l0 * 2 + l2, l1 * 3 + l3)
+//                           -> ((l0 * 2 + l2) floordiv 2,
+//                               (l1 * 3 + l3) floordiv 3,
+//                               (l0 * 2 + l2) mod 2,
+//                               (l1 * 3 + l3) mod 3) = (l0, l1, l2, l3)
+//
+// This function builds the inverse(idxMap) that replaces every dimension used
+// in `info` with a level, and updates the iterator type array `itTps` for the
+// new index variables introduced.
+//
+// Note that the returned affine map does not retain the order of the input
+// affine map. Instead, it always uses the first `info.inAdLvls.count()` dims
+// for the replaced levels, and the remaining ones for the unused dimensions.
+// For example, to handle
+//   idxMap = (d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)
+// which is a typical map for block_2to4, the function returns:
+//   inverse(idxMap) = (l0, l1, d0) -> (d0, l0 * 4 + l1)
+// in which (l0, l1) together replace `d1`, yet they appear
+// before `d0` in the resulting affine map.
+// The index (loop) order can later be canonicalized by a topological sort.
+static AffineMap
+genReplaceDimToLvlMap(const InadmissInfo &info, AffineMap idxMap,
+                      SmallVector<utils::IteratorType> &itTps) {
+  MLIRContext *ctx = idxMap.getContext();
+  auto [inAdLvls, usedDims] = info;
+  // Note that idxMap is not equal to the dim2lvl map; it is computed by
+  // composing dim2lvl(idx2dim). The two are only equal when idx2dim is an
+  // identity map.
+  // TODO: we might fail here; in those cases we should really return
+  // failure instead of an assertion error.
+  auto lvl2Idx = inferLvlToDim(idxMap, ctx);
+
+  assert(lvl2Idx.getNumResults() <= idxMap.getNumDims());
+  if (lvl2Idx.getNumResults() != idxMap.getNumDims()) {
+    // This could happen when some dimensions are projected.
+    // E.g., idx2Lvl = (*i*, j, k) -> (j, k)
+    //   ==> lvl2Idx = (j, k) -> (j, k)
+    // In this case, we append the unused dimensions at the end.
+    //   ==> lvl2Idx = (j, k, *i*) -> (*i*, j, k)
+    SmallVector<AffineExpr> results;
+    AffineDimCollector usedInLvl(idxMap.getNumDims());
+    for (auto e : idxMap.getResults())
+      usedInLvl.walkPostOrder(e);
+
+    unsigned curUsedDimID = 0;
+    unsigned curUnusedDimID = lvl2Idx.getNumDims();
+
+    BitVector unused = usedInLvl.dims.flip();
+    for (unsigned i = 0; i < idxMap.getNumDims(); i++) {
+      if (unused.test(i))
+        results.push_back(getAffineDimExpr(curUnusedDimID++, ctx));
+      else
+        results.push_back(lvl2Idx.getResult(curUsedDimID++));
+    }
+    lvl2Idx =
+        AffineMap::get(lvl2Idx.getNumDims() + unused.count(), 0, results, ctx);
+  }
+  assert(lvl2Idx.getNumResults() == idxMap.getNumDims());
+
+  // We do not need to replace the DimExprs that are not used in inadmissible
+  // level expressions. We use the first inAdLvls.count() dims to represent
+  // the replaced levels; the remaining ones are reserved for unchanged ones.
+  // Note that the results from the inverse map computed previously do not
+  // follow this convention, and we need to fix the mismatch below.
+  unsigned curRepID = 0;
+  unsigned curOriID = inAdLvls.count();
+  SmallVector<AffineExpr> results;
+  SmallVector<AffineExpr> dimRep(idxMap.getNumResults(), AffineExpr());
+  SmallVector<utils::IteratorType> transItTps;
+
+  for (unsigned l : inAdLvls.set_bits()) {
+    // By our convention, the inadmissible level `l` always appears in the
+    // leading part (accumulated by curRepID) of the affine map's parameter
+    // list. Record the mapping so that we can replace all the uses of `l`
+    // with the correct position after the translation.
+    dimRep[l] = getAffineDimExpr(curRepID++, ctx);
+    // A new index variable is introduced for the inadmissible level,
+    // inheriting the iterator type. E.g., if l0 = d0 floordiv 2, the
+    // iterator type of l0 equals the iterator type of d0.
+    AffineExpr lvlExp = idxMap.getResult(l);
+    AffineDimCollector collector(idxMap.getNumDims());
+    collector.walkPostOrder(lvlExp);
+    // We assume that a level can only be derived from one dimension.
+    assert(collector.dims.count() == 1);
+    transItTps.push_back(itTps[collector.dims.find_first()]);
+  }
+
+  for (unsigned d = 0, e = idxMap.getNumDims(); d < e; d++) {
+    if (usedDims.test(d)) {
+      // The dimension is used in some of the inadmissible levels, and it
+      // needs to be inverted. Get the inversion from the inverse map, and fix
+      // the mismatch captured by the above loop.
+      results.push_back(lvl2Idx.getResult(d).replaceDims(dimRep));
+    } else {
+      // The dimension is not used in any of the inadmissible levels, and it
+      // does not need to be inverted. Fix the mismatch by mapping it to the
+      // trailing part of the affine map (accumulated by curOriID).
+      results.push_back(getAffineDimExpr(curOriID++, ctx));
+      transItTps.push_back(itTps[d]);
+    }
   }
-  return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+  unsigned numDim = idxMap.getNumDims() - usedDims.count() + inAdLvls.count();
+  // Update the iterator types.
+  itTps.assign(transItTps.begin(), transItTps.end());
+  return AffineMap::get(numDim, 0, results, ctx);
+}
+
+// Translates the index maps in the linalg::GenericOp from idx->dim maps to
+// idx->lvl maps. Returns failure if the index maps cannot be translated to an
+// admissible form.
+// Returns the translated index map array and the iterator type array.
+static std::optional<std::pair<ArrayAttr, ArrayAttr>>
+translateMap(linalg::GenericOp op, PatternRewriter &rewriter) {
+  // idxMap is an idx2dim map before reinterpretation.
+  MLIRContext *ctx = op.getContext();
+  SmallVector<AffineMap> idxMapArray = op.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> itTps = op.getIteratorTypesArray();
+  for (unsigned i = 0, e = idxMapArray.size(); i < e; i++) {
+    Value tensor = op->getOpOperand(i).get();
+    auto stt = tryGetSparseTensorType(tensor);
+    if (stt && !stt->isIdentity()) {
+      AffineMap dim2Lvl = stt->getDimToLvl();
+      // By composing dim2lvl(idx2dim), we get an idx2lvl map.
+      idxMapArray[i] = dim2Lvl.compose(idxMapArray[i]);
+    }
+  }
+
+  // A naive way to handle common constant expressions that arise during
+  // dim2lvl translation.
+  auto populateCstMapping = [ctx](DenseMap<AffineExpr, AffineExpr> &cstMapping,
+                                  unsigned pos, int64_t lvlSz) {
+    if (!ShapedType::isDynamic(lvlSz)) {
+      auto c0 = getAffineConstantExpr(0, ctx);
+      auto lvlExp = getAffineDimExpr(pos, ctx);
+      auto szExp = getAffineConstantExpr(lvlSz, ctx);
+
+      // lvl floordiv lvlSz = 0
+      auto divExp =
+          getAffineBinaryOpExpr(AffineExprKind::FloorDiv, lvlExp, szExp);
+      cstMapping.try_emplace(divExp, c0);
+
+      // lvl mod lvlSz = lvl
+      auto modExp = getAffineBinaryOpExpr(AffineExprKind::Mod, lvlExp, szExp);
+      cstMapping.try_emplace(modExp, lvlExp);
+    }
+  };
+
+  unsigned boundedNum = 0;
+  // A fixed-point algorithm.
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (OpOperand &operand : op->getOpOperands()) {
+      auto stt = tryGetSparseTensorType(operand.get());
+      // Skip dense operands.
+      if (!stt || !stt->getEncoding())
+        continue;
+
+      unsigned tid = operand.getOperandNumber();
+      bool isOutput = &operand == op.getDpsInitOperand(0);
+      AffineMap idxMap = idxMapArray[tid];
+      InadmissInfo inAdInfo = collectInadmissInfo(idxMap, isOutput);
+      auto [inAdLvls, dimExprs] = inAdInfo;
+      for (unsigned d : dimExprs.set_bits()) {
+        // The first `boundedNum` dims used in the AffineMap were introduced
+        // to resolve previous inadmissible expressions. We cannot replace
+        // them, as that might bring back the inadmissible expressions.
+        if (d < boundedNum)
+          return std::nullopt;
+      }
+
+      if (inAdLvls.count() != 0) {
+        // Naive constant propagation; should be sufficient to handle block
+        // sparsity in our cases.
+        SmallVector<int64_t> lvlShape = stt->getLvlShape();
+        DenseMap<AffineExpr, AffineExpr> cstMapping;
+        unsigned position = 0;
+        for (unsigned lvl : inAdLvls.set_bits()) {
+          int64_t lvlSz = lvlShape[lvl];
+          populateCstMapping(cstMapping, position, lvlSz);
+          position++;
+        }
+
+        AffineMap lvl2Idx = genReplaceDimToLvlMap(inAdInfo, idxMap, itTps);
+        // Compose the lvl2Idx map with all the index maps to eliminate the
+        // inadmissible expressions.
+        for (unsigned tid = 0, e = idxMapArray.size(); tid < e; tid++) {
+          AffineMap transMap = idxMapArray[tid].compose(lvl2Idx);
+          idxMapArray[tid] = transMap.replace(
+              cstMapping, /*numResultDims=*/transMap.getNumDims(),
+              /*numResultSyms=*/0);
+        }
+        changed = true;
+        boundedNum += inAdLvls.count();
+      }
+    }
+  }
+
+  SmallVector<Attribute> iterAttr =
+      llvm::map_to_vector(itTps, [ctx](auto itTp) -> Attribute {
+        return linalg::IteratorTypeAttr::get(ctx, itTp);
+      });
+
+  return std::make_pair(rewriter.getAffineMapArrayAttr(idxMapArray),
+                        rewriter.getArrayAttr(iterAttr));
 }
 
 // Generates a "de"mapping reinterpretation of the map.
@@ -73,41 +362,6 @@ static bool hasNonIdentityOperandsOrResults(Operation *op) {
          llvm::any_of(op->getResults(), hasNonIdentityMap);
 }
 
-// Generates a clone of the given linalg generic operation, but with
-// remapped arguments, index maps, and iteration types.
-//
-// TODO: As described below, this is proof-of-concept code which makes a lot
-// of simplifying assumptions for now.
-//
-static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
-                                          linalg::GenericOp linalgOp,
-                                          SparseTensorType stt, Value out) {
-  unsigned dimRank = stt.getDimRank();
-  unsigned lvlRank = stt.getLvlRank();
-  SmallVector<Value> inputOps = linalgOp.getInputs();
-  SmallVector<Value> outputOps = {out};
-  SmallVector<AffineMap> indexMaps;
-  SmallVector<utils::IteratorType> iterTypes;
-  // Translate the index maps, except the output map, which is lvl-identity.
-  auto maps = linalgOp.getIndexingMapsArray();
-  for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
-    indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
-  indexMaps.push_back(
-      AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
-  // Add additional "parallel" iteration types at the top.
-  for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
-    iterTypes.push_back(utils::IteratorType::parallel);
-  for (auto &i : linalgOp.getIteratorTypesArray())
-    iterTypes.push_back(i);
-  // Generate the new linalg generic operation and clone body.
-  auto newOp = rewriter.create<linalg::GenericOp>(
-      linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
-      iterTypes);
-  rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
-                             newOp.getRegion().begin());
-  return newOp;
-}
-
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -115,53 +369,45 @@ namespace {
 //===----------------------------------------------------------------------===//
 
 /// Sparse rewriting rule for the generic `linalg` operation.
-struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+struct GenericOpReinterpretMap
+    : public DemapInsRewriter<GenericOpReinterpretMap, linalg::GenericOp> {
 public:
-  GenericOpReinterpretMap(MLIRContext *context)
-      : OpRewritePattern<linalg::GenericOp>(context) {}
-
-  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
-                                PatternRewriter &rewriter) const override {
-    // Only rewrite single output operations with pure tensor semantics.
-    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+  using DemapInsRewriter::DemapInsRewriter;
+  LogicalResult rewriteOp(linalg::GenericOp linalgOp, OpAdaptor adaptor,
+                          PatternRewriter &rewriter) const {
+    // Only rewrite single output operations with pure (sparse) tensor
+    // semantics.
+    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics() ||
+        !hasAnySparseOperandOrResult(linalgOp) ||
+        !hasNonIdentityOperandsOrResults(linalgOp))
       return failure();
-    // Scan all operands, inspect sparse tensors.
-    //
-    // TODO: generalize this proof-of-concept algorithm, since the current
-    //       implementation accepts only simple indexing maps, and one
-    //       non-permutation sparse tensor, which must have an identity
-    //       indexing map and be the output.
-    //
-    OpOperand *tx = nullptr;
-    for (OpOperand &t : linalgOp->getOpOperands()) {
-      // Ensure every index map is "simple".
-      const auto map = linalgOp.getMatchingIndexingMap(&t);
-      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
-        if (map.getResult(i).getKind() != AffineExprKind::DimId)
-          return failure();
-      // Inspect sparse operands.
-      auto stt = tryGetSparseTensorType(t.get());
-      if (stt && stt->hasEncoding()) {
-        if (stt->isPermutation())
-          continue;
-        assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
-        if (tx)
-          return failure(); // more than one non-perm
-        if (!map.isIdentity())
-          return failure(); // no ID indexing map on the non-perm
-        tx = &t;
-      }
-    }
-    // Found a non-permutation, rewrite when this is the output.
-    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
-      auto stt = getSparseTensorType(tx->get());
-      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
-      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
-      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
-      rewriter.replaceOp(linalgOp, remap);
-      return success();
+
+    // Try translating the index maps.
+    auto transMap = translateMap(linalgOp, rewriter);
+    if (!transMap)
+      return rewriter.notifyMatchFailure(
+          linalgOp, "the kernel cannot be sparsified.");
+
+    // On success, update the linalg operands and maps in place.
+    Value res = linalgOp.getResult(0);
+    auto stt = tryGetSparseTensorType(res);
+    auto [idxMap, itTp] = *transMap;
+
+    rewriter.startRootUpdate(linalgOp);
+    linalgOp.setIndexingMapsAttr(idxMap);
+    linalgOp.setIteratorTypesAttr(itTp);
+    // Use the demapped arguments.
+    linalgOp.getInputsMutable().assign(adaptor.getInputs());
+    linalgOp.getDpsInitsMutable().assign(adaptor.getOutputs());
+    res.setType(adaptor.getOutputs()[0].getType());
+    rewriter.finalizeRootUpdate(linalgOp);
+
+    rewriter.setInsertionPointAfter(linalgOp);
+    if (stt && stt->hasEncoding()) {
+      Value t = genRemap(rewriter, stt->getEncoding(), res);
+      rewriter.replaceAllUsesExcept(res, t, t.getDefiningOp());
     }
-    return failure();
+    return success();
   }
 };
 
@@ -169,32 +415,10 @@ public:
 // Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//
 
-// CRTP to help implementing a rewriter that demaps all its inputs.
-template <typename SubClass, typename SourceOp>
-struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
-  using OpRewritePattern<SourceOp>::OpRewritePattern;
-  using OpAdaptor = typename SourceOp::Adaptor;
-
-  LogicalResult matchAndRewrite(SourceOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Demaps non-trivial inputs.
-    SmallVector<Value> deMappedIns(op->getOperands());
-    for (Value &in : deMappedIns)
-      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
-        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
-
-    // CRTP call.
-    OpAdaptor adaptor(deMappedIns);
-    return static_cast<const SubClass *>(this)->rewriteOp(op, adaptor,
-                                                          rewriter);
-  }
-};
-
-struct TensorAllocDemapper
-    : public OpRewritePattern<bufferization::AllocTensorOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(bufferization::AllocTensorOp op,
+template <typename AllocOp>
+struct TensorAllocDemapper : public OpRewritePattern<AllocOp> {
+  using OpRewritePattern<AllocOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AllocOp op,
                                 PatternRewriter &rewriter) const override {
     if (!hasNonIdentityOperandsOrResults(op))
       return failure();
@@ -362,7 +586,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
   }
   if (scope == ReinterpretMapScope::kAll ||
       scope == ReinterpretMapScope::kExceptGeneric) {
-    patterns.add<TensorAllocDemapper, TensorInsertDemapper, ForeachOpDemapper>(
-        patterns.getContext());
+    patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
+                 TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
+                 ForeachOpDemapper>(patterns.getContext());
   }
 }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 4a293f6819d0..e20b98add19a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -139,13 +139,11 @@ public:
     // of `bufferization.alloc_tensor` ops.
     {
      OpPassManager pm("builtin.module");
-      pm.addPass(
-          createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
+      pm.addPass(createSparseReinterpretMapPass(ReinterpretMapScope::kAll));
       pm.addPass(createSparsificationPass(sparsificationOptions));
       pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
       pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
                                                    /*enableConvert=*/true));
-      // Handle dim-to-lvl maps on operations other than linalg.generic.
       pm.addPass(
           createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
       pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
diff --git a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
index 972364289ac2..c4931c62c626 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir
@@ -19,15 +19,15 @@
 ) }>
 
-// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
-// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
-// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
+// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
+// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 // CHECK-LABEL: func @mul(
 // CHECK-SAME:  %[[A0:.*0]]: tensor<32x32xf32>,
 // CHECK-SAME:  %[[A1:.*1]]: tensor<32x32xf32>,
 // CHECK-SAME:  %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
 // CHECK: %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
-// CHECK: %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+// CHECK: %[[T1:.*]] = linalg.generic {doc = {{.*}} indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
 // CHECK: %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
 // CHECK: return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
 func.func @mul(%arg0: tensor<32x32xf32>,
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
new file mode 100644
index 000000000000..7e9606b1caed
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_block_matmul.mlir
@@ -0,0 +1,123 @@
+//--------------------------------------------------------------------------------------------------
+// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
+//
+// Set-up that's shared across all tests in this directory. In principle, this
+// config could be moved to lit.local.cfg. However, there are downstream users
+// that do not use these LIT config files. Hence why this is kept inline.
+//
+// DEFINE: %{sparse_compiler_opts} = enable-runtime-library=true
+// DEFINE: %{sparse_compiler_opts_sve} = enable-arm-sve=true %{sparse_compiler_opts}
+// DEFINE: %{compile} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts}"
+// DEFINE: %{compile_sve} = mlir-opt %s --sparse-compiler="%{sparse_compiler_opts_sve}"
+// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
+// DEFINE: %{run_opts} = -e entry -entry-point-result=void
+// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
+// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
+//
+// DEFINE: %{env} =
+//--------------------------------------------------------------------------------------------------
+
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and vectorization.
+// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false enable-index-reduction=true vl=2 reassociate-fp-reductions=true enable-index-optimizations=true
+// RUN: %{compile} | %{run} | FileCheck %s
+//
+// Do the same run, but now with direct IR generation and VLA vectorization.
+// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
+
+#trait_mul = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A (in)
+    affine_map<(i,j,k) -> (j,k)>,  // B (in, transposed)
+    affine_map<(i,j,k) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "X(i,j) += A(i,k) * B(j,k)"
+}
+
+#BSR = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i floordiv 2 : dense,
+    j floordiv 3 : compressed,
+    i mod 2      : dense,
+    j mod 3      : dense
+  )
+}>
+
+module {
+
+func.func @mul(%arg0: tensor<4x6xf64>,
+               %arg1: tensor<4x6xf64, #BSR>) -> tensor<4x4xf64> {
+  %out = tensor.empty() : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64, #BSR>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+func.func @mul_dense(%arg0: tensor<4x6xf64>,
+                     %arg1: tensor<4x6xf64>) -> tensor<4x4xf64> {
+  %out = tensor.empty() : tensor<4x4xf64>
+  %0 = linalg.generic #trait_mul
+    ins(%arg0, %arg1: tensor<4x6xf64>, tensor<4x6xf64>)
+    outs(%out: tensor<4x4xf64>) {
+      ^bb(%x: f64, %y : f64, %z : f64):
+        %1 = arith.mulf %x, %y : f64
+        %2 = arith.addf %1, %z : f64
+        linalg.yield %2 : f64
+  } -> tensor<4x4xf64>
+  return %0 : tensor<4x4xf64>
+}
+
+  //
+  // Output utilities.
+  //
+  func.func @dumpf64(%arg0: tensor<4x4xf64>) {
+    %c0 = arith.constant 0 : index
+    %d0 = arith.constant -1.0 : f64
+    %0 = vector.transfer_read %arg0[%c0, %c0], %d0: tensor<4x4xf64>, vector<4x4xf64>
+    vector.print %0 : vector<4x4xf64>
+    return
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+
+    %td = arith.constant dense<[[ 0.0,  1.0,  2.0,  3.0,  4.0,  5.0],
+                                [ 6.0,  7.0,  8.0,  9.0, 10.0, 11.0],
+                                [12.0, 13.0, 14.0, 15.0, 16.0, 17.0],
+                                [18.0, 19.0, 20.0, 21.0, 22.0, 23.0]]> : tensor<4x6xf64>
+
+    %2 = sparse_tensor.convert %td : tensor<4x6xf64> to tensor<4x6xf64, #BSR>
+
+    %d = call @mul_dense(%td, %td)
+       : (tensor<4x6xf64>, tensor<4x6xf64>) -> tensor<4x4xf64>
+    %s = call @mul(%td, %2)
+       : (tensor<4x6xf64>, tensor<4x6xf64, #BSR>) -> tensor<4x4xf64>
+
+    // CHECK-COUNT-2: ( ( 55, 145, 235, 325 ), ( 145, 451, 757, 1063 ), ( 235, 757, 1279, 1801 ), ( 325, 1063, 1801, 2539 ) )
+    call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
+    call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
+
+    return
+  }
+}
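
A quick sanity check on the expected output (recomputed by hand, not part of the patch): both @mul_dense and @mul compute X(i,j) = sum_k td(i,k) * td(j,k) over the 4x6 constant, so every entry is a dot product of two rows of td. For instance:

  X(0,0) = 0^2 + 1^2 + 2^2 + 3^2 + 4^2 + 5^2 = 55
  X(0,1) = 0*6 + 1*7 + 2*8 + 3*9 + 4*10 + 5*11 = 145

These match the leading entries of the CHECK-COUNT-2 line, which is printed twice: once for the dense reference and once for the BSR run.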