summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPeiming Liu <peiming@google.com>2024-05-01 15:37:38 -0700
committerGitHub <noreply@github.com>2024-05-01 15:37:38 -0700
commit78885395c802b5c1eea3decece3695a7240902b1 (patch)
tree80c75c8429cbf29f966511c6a85b186695575e27
parent0af415d436e8d352397d3c5b279fca5d9b4e29f5 (diff)
[mlir][sparse] support tensor.pad on CSR tensors (#90687)
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp141
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h6
-rw-r--r--mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir79
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir30
4 files changed, 198 insertions, 58 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32..dbec46d2616d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ public:
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ public:
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
+ assert(!inPadZone && "Not implemented");
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
@@ -129,18 +131,42 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
- Value p = parentPos.front();
- SmallVector<Value> memCrd(batchPrefix);
- memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
- memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
- return {pLo, pHi};
+ auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+ Value p = parentPos.front();
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+ return {pLo, pHi};
+ };
+
+ if (inPadZone == nullptr)
+ return loadRange();
+
+ SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+ scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+ // True branch, returns a "fake" empty range [0, 0) if parent
+ // iterator is in pad zone.
+ b.setInsertionPointToStart(posRangeIf.thenBlock());
+
+ SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+ b.create<scf::YieldOp>(l, emptyRange);
+
+ // False branch, returns the actual range.
+ b.setInsertionPointToStart(posRangeIf.elseBlock());
+ auto [pLo, pHi] = loadRange();
+ SmallVector<Value, 2> loadedRange{pLo, pHi};
+ b.create<scf::YieldOp>(l, loadedRange);
+
+ b.setInsertionPointAfter(posRangeIf);
+ ValueRange posRange = posRangeIf.getResults();
+ return {posRange.front(), posRange.back()};
}
};
@@ -151,9 +177,10 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
+ assert(!inPadZone && "Not implemented");
SmallVector<Value> memCrd(batchPrefix);
Value p = parentPos.front();
p = MULI(p, C_IDX(2));
@@ -172,8 +199,9 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 || parentPos.size() == 2);
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
@@ -191,9 +219,10 @@ public:
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && isUnique() &&
"n:m level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +354,7 @@ public:
};
void genInitImpl(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
-
- if (isBatchIterator() && batchCrds.size() <= stl.lvl)
- batchCrds.resize(stl.lvl + 1, nullptr);
-
- Value c0 = C_IDX(0);
- ValueRange pPos = c0;
- // If the parent iterator is a batch iterator, we also start from 0 (but
- // on a different batch).
- if (parent && !parent->isBatchIterator())
- pPos = parent->getCurPosition();
-
- ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
- // Seek to the lowest position.
- seek(posLo);
- }
+ const SparseIterator *parent) override;
ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
@@ -465,8 +478,9 @@ public:
// A util base-iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
- SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
- : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+ SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+ unsigned extraCursorVal = 0)
+ : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
return wrap->getCursorValTypes(b);
@@ -474,6 +488,7 @@ public:
bool isBatchIterator() const override { return wrap->isBatchIterator(); }
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return wrap->iteratableByFor(); };
+
SmallVector<Value> serialize() const override { return wrap->serialize(); };
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
@@ -586,10 +601,9 @@ class PadIterator : public SimpleWrapIterator {
public:
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
Value padHigh)
- : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
- padHigh(padHigh) {
- assert(!randomAccessible() && "Not implemented.");
- }
+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+ wrap->randomAccessible() ? 1 : 0),
+ padLow(padLow), padHigh(padHigh) {}
// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
@@ -600,6 +614,26 @@ public:
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
}
+ // Returns a pair of values for *upper*, *lower* bound respectively.
+ ValuePair genForCond(OpBuilder &b, Location l) override {
+ if (randomAccessible())
+ return {getCrd(), upperBound(b, l)};
+ return wrap->genForCond(b, l);
+ }
+
+ // For padded dense iterator, we append a `inPadZone: bool` in addition to
+ // values used by the wrapped iterator.
+ ValueRange getCurPosition() const override { return getCursor(); }
+
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ SmallVector<Type> ret = wrap->getCursorValTypes(b);
+ // Need an extra boolean value `inPadZone` for padded dense iterator.
+ if (randomAccessible())
+ ret.push_back(b.getI1Type());
+
+ return ret;
+ }
+
// The upper bound after padding becomes `size + padLow + padHigh`.
Value upperBound(OpBuilder &b, Location l) const override {
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +647,14 @@ public:
void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
+ wrap->locate(b, l, SUBI(crd, padLow));
+
+ // inPadZone = crd < padLow || crd >= size + padLow.
+ Value inPadLow = CMPI(ult, crd, padLow);
+ Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+ getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+ updateCrd(crd);
}
Value padLow, padHigh;
@@ -1227,6 +1269,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) {
+
+ if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+ batchCrds.resize(stl.lvl + 1, nullptr);
+
+ Value c0 = C_IDX(0);
+ ValueRange pPos = c0;
+ Value inPadZone = nullptr;
+ // If the parent iterator is a batch iterator, we also start from 0 (but
+ // on a different batch).
+ if (parent && !parent->isBatchIterator()) {
+ pPos = parent->getCurPosition();
+ if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+ // A padded dense iterator create "sparse" padded zone, which need to be
+ // handled specially.
+ inPadZone = pPos.back();
+ pPos = pPos.drop_back();
+ }
+ }
+
+ ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+ // Seek to the lowest position.
+ seek(posLo);
+}
+
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
const SparseIterator *) {
Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f9..120a806536f1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ public:
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
/// to load coordinate from the coordinate buffer.
- virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
- ValueRange batchPrefix,
- ValueRange parentPos) const = 0;
+ virtual std::pair<Value, Value>
+ peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ ValueRange parentPos, Value inPadZone = nullptr) const = 0;
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
diff --git a/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
new file mode 100644
index 000000000000..4f509bf747ab
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/fuse_sparse_pad_with_consumer.mlir
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -canonicalize | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#elemwise = {
+ indexing_maps = [
+ affine_map<(i,j) -> (i,j)>, // A
+ affine_map<(i,j) -> (i,j)>, // B
+ affine_map<(i,j) -> (i,j)> // X (out)
+ ],
+ iterator_types = ["parallel", "parallel"],
+ doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+
+// CHECK-LABEL: func.func @padded_mul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf32, #sparse>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf32>) -> tensor<8x8xf32> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
+// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
+// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
+// CHECK: linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
+// CHECK: %[[VAL_18:.*]] = arith.cmpi uge, %[[VAL_15]], %[[VAL_3]] : index
+// CHECK: %[[VAL_19:.*]] = arith.ori %[[VAL_17]], %[[VAL_18]] : i1
+// CHECK: %[[VAL_20:.*]]:2 = scf.if %[[VAL_19]] -> (index, index) {
+// CHECK: scf.yield %[[VAL_6]], %[[VAL_6]] : index, index
+// CHECK: } else {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_21]], %[[VAL_23]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_20]]#0 to %[[VAL_20]]#1 step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK: %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
+// CHECK: memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<8x8xf32>
+// CHECK: return %[[VAL_31]] : tensor<8x8xf32>
+// CHECK: }
+func.func @padded_mul(%arg0: tensor<4x4xf32, #CSR>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
+ %cst_0 = arith.constant 0.00000e+00 : f32
+ %buf = tensor.empty() : tensor<8x8xf32>
+ %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+ %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
+ ^bb0(%arg75: index, %arg76: index):
+ tensor.yield %cst_0 : f32
+ } : tensor<4x4xf32, #CSR> to tensor<8x8xf32, #CSR>
+
+ %0 = linalg.generic #elemwise
+ ins(%padded, %arg1: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
+ outs(%s: tensor<8x8xf32>) {
+ ^bb(%a: f32, %b: f32, %x: f32):
+ %0 = arith.mulf %a, %b : f32
+ linalg.yield %0 : f32
+ } -> tensor<8x8xf32>
+
+ return %0 : tensor<8x8xf32>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf545582..50dd989416e2 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-#CCCC = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
}>
// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
return %ret : tensor<3x8x8x1xf32>
}
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
%cst_0 = arith.constant 0.00000e+00 : f32
%buf = tensor.empty() : tensor<3x8x8x1xf32>
%s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
%padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
tensor.yield %cst_0 : f32
- } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+ } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+ ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
return %ret : tensor<3x8x8x1xf32>
}
@@ -105,8 +97,8 @@ func.func @main() {
%dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
- %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
- %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+ %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+ %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
// CHECK: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
// CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
// CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
// CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
- %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+ %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
: tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
- vector.print %CCCC_v : vector<3x8x8x1xf32>
+ vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
bufferization.dealloc_tensor %static_input : tensor<3x8x8x3xf32>
bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+ bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+ bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
return
}