diff options
author | Ryan Holt <ryanholt@mathworks.com> | 2024-04-23 11:18:04 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-23 11:18:04 -0400 |
commit | 8317d366212763d907d6d61a6d07450168a33bfb (patch) | |
tree | 0b49cb0e88a04deb47ff802e2346d3e402177c1c | |
parent | c1086532d4d5c0a261457dfb00e79fcb764e3d78 (diff) |
[mlir][linalg] Add runtime verification for linalg ops (#89342)
This commit implements runtime verification for LinalgStructuredOps
using the existing `RuntimeVerifiableOpInterface`. The verification
checks that the runtime sizes of the operands match the runtime sizes
inferred by composing the loop ranges with the op's indexing maps.
9 files changed, 549 insertions, 33 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h new file mode 100644 index 000000000000..6c3643f7835c --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h @@ -0,0 +1,21 @@ +//===- RuntimeOpVerification.h - Op Verification ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H +#define MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H + +namespace mlir { +class DialectRegistry; + +namespace linalg { +void registerRuntimeVerifiableOpInterfaceExternalModels( + DialectRegistry ®istry); +} // namespace linalg +} // namespace mlir + +#endif // MLIR_DIALECT_LINALG_RUNTIMEOPVERIFICATION_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index c4d788cf8ed3..d9db21073e15 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -45,6 +45,7 @@ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/MPI/IR/MPI.h" @@ -161,6 +162,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { cf::registerBufferDeallocationOpInterfaceExternalModels(registry); gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); linalg::registerAllDialectInterfaceImplementations(registry); + linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); memref::registerAllocationOpInterfaceExternalModels(registry); memref::registerBufferViewFlowOpInterfaceExternalModels(registry); memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); diff --git a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td index d5f11d00cc3d..6fd0df59d9d2 100644 --- a/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td +++ b/mlir/include/mlir/Interfaces/RuntimeVerifiableOpInterface.td @@ -35,6 +35,12 @@ def RuntimeVerifiableOpInterface : OpInterface<"RuntimeVerifiableOpInterface"> { "::mlir::Location":$loc) >, ]; + + let extraClassDeclaration = [{ + /// Generate the error message that will be printed to the user when + /// verification fails. + static std::string generateErrorMessage(Operation *op, const std::string &msg); + }]; } #endif // MLIR_INTERFACES_RUNTIMEVERIFIABLEOPINTERFACE diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index ee6e391d0cc6..44d95bbc02d4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms NamedOpConversions.cpp Padding.cpp Promotion.cpp + RuntimeOpVerification.cpp Specialize.cpp Split.cpp SplitReduction.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp new file mode 100644 index 000000000000..b30182dc8407 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp @@ -0,0 +1,135 @@ +//===- RuntimeOpVerification.cpp - Op Verification ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Index/IR/IndexAttrs.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" + +namespace mlir { +namespace linalg { +namespace { +/// Verify that the runtime sizes of the operands to linalg structured ops are +/// compatible with the runtime sizes inferred by composing the loop ranges with +/// the linalg op's indexing maps. This is similar to the verifier except that +/// here we insert IR to perform the verification at runtime. +template <typename T> +struct StructuredOpInterface + : public RuntimeVerifiableOpInterface::ExternalModel< + StructuredOpInterface<T>, T> { + void generateRuntimeVerification(Operation *op, OpBuilder &builder, + Location loc) const { + auto linalgOp = llvm::cast<LinalgOp>(op); + + SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc); + auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges); + + auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); + auto one = builder.create<arith::ConstantIndexOp>(loc, 1); + + // Subtract one from the loop ends before composing with the indexing map + transform(ends, ends.begin(), [&](OpFoldResult end) { + auto endValue = getValueOrCreateConstantIndexOp(builder, loc, end); + return builder.createOrFold<index::SubOp>(loc, endValue, one); + }); + + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand); + auto startIndices = affine::makeComposedFoldedMultiResultAffineApply( + builder, loc, indexingMap, starts); + auto endIndices = affine::makeComposedFoldedMultiResultAffineApply( + builder, loc, indexingMap, ends); + + for (auto dim : llvm::seq(linalgOp.getRank(&opOperand))) { + auto startIndex = + getValueOrCreateConstantIndexOp(builder, loc, startIndices[dim]); + auto endIndex = + getValueOrCreateConstantIndexOp(builder, loc, endIndices[dim]); + + // Generate: + // minIndex = min(startIndex, endIndex) + // assert(minIndex >= 0) + // To ensure we do not generate a negative index. We take the minimum of + // the start and end indices in order to handle reverse loops such as + // `affine_map<(i) -> (3 - i)>` + auto min = + builder.createOrFold<index::MinSOp>(loc, startIndex, endIndex); + auto cmpOp = builder.createOrFold<index::CmpOp>( + loc, index::IndexCmpPredicate::SGE, min, zero); + auto msg = RuntimeVerifiableOpInterface::generateErrorMessage( + linalgOp, "unexpected negative result on dimension #" + + std::to_string(dim) + " of input/output operand #" + + std::to_string(opOperand.getOperandNumber())); + builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg); + + // Generate: + // inferredDimSize = max(startIndex, endIndex) + 1 + // actualDimSize = dim(operand) + // assert(inferredDimSize <= actualDimSize) + // To ensure that we do not index past the bounds of the operands. + auto max = + builder.createOrFold<index::MaxSOp>(loc, startIndex, endIndex); + + auto inferredDimSize = + builder.createOrFold<index::AddOp>(loc, max, one); + + auto actualDimSize = + createOrFoldDimOp(builder, loc, opOperand.get(), dim); + + // Similar to the verifier, when the affine expression in the indexing + // map is complicated, we just check that the inferred dimension sizes + // are in the boundary of the operands' size. Being more precise than + // that is difficult. + auto predicate = isa<AffineDimExpr>(indexingMap.getResult(dim)) + ? index::IndexCmpPredicate::EQ + : index::IndexCmpPredicate::SLE; + + cmpOp = builder.createOrFold<index::CmpOp>( + loc, predicate, inferredDimSize, actualDimSize); + msg = RuntimeVerifiableOpInterface::generateErrorMessage( + linalgOp, "dimension #" + std::to_string(dim) + + " of input/output operand #" + + std::to_string(opOperand.getOperandNumber()) + + " is incompatible with inferred dimension size"); + builder.createOrFold<cf::AssertOp>(loc, cmpOp, msg); + } + } + } +}; + +template <typename... OpTs> +void attachInterface(MLIRContext *ctx) { + (OpTs::template attachInterface<StructuredOpInterface<OpTs>>(*ctx), ...); +} +} // namespace +} // namespace linalg +} // namespace mlir + +void mlir::linalg::registerRuntimeVerifiableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *) { + attachInterface< +#define GET_OP_LIST +#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" + >(ctx); + + // Load additional dialects of which ops may get created. + ctx->loadDialect<affine::AffineDialect, arith::ArithDialect, + cf::ControlFlowDialect, index::IndexDialect, + tensor::TensorDialect>(); + }); +} diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 05b813a3b1e9..450bfa0cec0c 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -20,25 +20,6 @@ using namespace mlir; -/// Generate an error message string for the given op and the specified error. -static std::string generateErrorMessage(Operation *op, const std::string &msg) { - std::string buffer; - llvm::raw_string_ostream stream(buffer); - OpPrintingFlags flags; - // We may generate a lot of error messages and so we need to ensure the - // printing is fast. - flags.elideLargeElementsAttrs(); - flags.printGenericOpForm(); - flags.skipRegions(); - flags.useLocalScope(); - stream << "ERROR: Runtime op verification failed\n"; - op->print(stream, flags); - stream << "\n^ " << msg; - stream << "\nLocation: "; - op->getLoc().print(stream); - return stream.str(); -} - namespace mlir { namespace memref { namespace { @@ -62,8 +43,10 @@ struct CastOpInterface builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); Value isSameRank = builder.create<arith::CmpIOp>( loc, arith::CmpIPredicate::eq, srcRank, resultRank); - builder.create<cf::AssertOp>(loc, isSameRank, - generateErrorMessage(op, "rank mismatch")); + builder.create<cf::AssertOp>( + loc, isSameRank, + RuntimeVerifiableOpInterface::generateErrorMessage(op, + "rank mismatch")); } // Get source offset and strides. We do not have an op to get offsets and @@ -101,8 +84,8 @@ struct CastOpInterface loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); builder.create<cf::AssertOp>( loc, isSameSz, - generateErrorMessage(op, "size mismatch of dim " + - std::to_string(it.index()))); + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "size mismatch of dim " + std::to_string(it.index()))); } // Get result offset and strides. @@ -119,8 +102,10 @@ struct CastOpInterface builder.create<arith::ConstantIndexOp>(loc, resultOffset); Value isSameOffset = builder.create<arith::CmpIOp>( loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); - builder.create<cf::AssertOp>(loc, isSameOffset, - generateErrorMessage(op, "offset mismatch")); + builder.create<cf::AssertOp>( + loc, isSameOffset, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "offset mismatch")); } // Check strides. @@ -137,8 +122,8 @@ struct CastOpInterface loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); builder.create<cf::AssertOp>( loc, isSameStride, - generateErrorMessage(op, "stride mismatch of dim " + - std::to_string(it.index()))); + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "stride mismatch of dim " + std::to_string(it.index()))); } } }; @@ -178,7 +163,9 @@ struct LoadStoreOpInterface : andOp; } builder.create<cf::AssertOp>( - loc, assertCond, generateErrorMessage(op, "out-of-bounds access")); + loc, assertCond, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "out-of-bounds access")); } }; @@ -248,7 +235,7 @@ struct ReinterpretCastOpInterface builder.create<cf::AssertOp>( loc, assertCond, - generateErrorMessage( + RuntimeVerifiableOpInterface::generateErrorMessage( op, "result of reinterpret_cast is out-of-bounds of the base memref")); } @@ -293,8 +280,8 @@ struct SubViewOpInterface builder.create<cf::AssertOp>( loc, assertCond, - generateErrorMessage(op, - "subview is out-of-bounds of the base memref")); + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "subview is out-of-bounds of the base memref")); } }; @@ -334,8 +321,9 @@ struct ExpandShapeOpInterface builder.create<arith::ConstantIndexOp>(loc, 0)); builder.create<cf::AssertOp>( loc, isModZero, - generateErrorMessage(op, "static result dims in reassoc group do not " - "divide src dim evenly")); + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "static result dims in reassoc group do not " + "divide src dim evenly")); } } }; diff --git a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp index 9205d8d8c34a..e823b5df179c 100644 --- a/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp +++ b/mlir/lib/Interfaces/RuntimeVerifiableOpInterface.cpp @@ -11,6 +11,28 @@ namespace mlir { class Location; class OpBuilder; + +/// Generate an error message string for the given op and the specified error. +std::string +RuntimeVerifiableOpInterface::generateErrorMessage(Operation *op, + const std::string &msg) { + std::string buffer; + llvm::raw_string_ostream stream(buffer); + OpPrintingFlags flags; + // We may generate a lot of error messages and so we need to ensure the + // printing is fast. + flags.elideLargeElementsAttrs(); + flags.printGenericOpForm(); + flags.skipRegions(); + flags.useLocalScope(); + stream << "ERROR: Runtime op verification failed\n"; + op->print(stream, flags); + stream << "\n^ " << msg; + stream << "\nLocation: "; + op->getLoc().print(stream); + return stream.str(); +} + } // namespace mlir /// Include the definitions of the interface. diff --git a/mlir/test/Dialect/Linalg/runtime-verification.mlir b/mlir/test/Dialect/Linalg/runtime-verification.mlir new file mode 100644 index 000000000000..a4f29d8457e5 --- /dev/null +++ b/mlir/test/Dialect/Linalg/runtime-verification.mlir @@ -0,0 +1,43 @@ +// RUN: mlir-opt %s -generate-runtime-verification | FileCheck %s + +// Most of the tests for linalg runtime-verification are implemented as integration tests. + +#identity = affine_map<(d0) -> (d0)> + +// CHECK-LABEL: @static_dims +func.func @static_dims(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>) { + // CHECK: %[[TRUE:.*]] = index.bool.constant true + // CHECK: cf.assert %[[TRUE]] + %result = tensor.empty() : tensor<5xf32> + %0 = linalg.generic { + indexing_maps = [#identity, #identity, #identity], + iterator_types = ["parallel"] + } ins(%arg0, %arg1 : tensor<5xf32>, tensor<5xf32>) + outs(%result : tensor<5xf32>) { + ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) : + %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32 + linalg.yield %tmp1 : f32 + } -> tensor<5xf32> + return %0 : tensor<5xf32> +} + +// ----- + +#map = affine_map<() -> ()> + +// CHECK-LABEL: @scalars +func.func @scalars(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) { + // No runtime checks are required if the operands are all scalars + // CHECK-NOT: cf.assert + %result = tensor.empty() : tensor<f32> + %0 = linalg.generic { + indexing_maps = [#map, #map, #map], + iterator_types = [] + } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>) + outs(%result : tensor<f32>) { + ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) : + %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32 + linalg.yield %tmp1 : f32 + } -> tensor<f32> + return %0 : tensor<f32> +} diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir new file mode 100644 index 000000000000..b05ef9422e59 --- /dev/null +++ b/mlir/test/Integration/Dialect/Linalg/CPU/runtime-verification.mlir @@ -0,0 +1,298 @@ +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -convert-linalg-to-loops \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -convert-scf-to-cf \ +// RUN: -test-cf-assert \ +// RUN: -convert-index-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-cpu-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils \ +// RUN: -shared-libs=%mlir_c_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +func.func @main() { + %c5x = arith.constant dense<0.0> : tensor<5xf32> + %c4x = arith.constant dense<0.0> : tensor<4xf32> + %d5x = tensor.cast %c5x : tensor<5xf32> to tensor<?xf32> + %d4x = tensor.cast %c4x : tensor<4xf32> to tensor<?xf32> + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @simple_add(%d5x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + func.call @simple_add(%d5x, %d4x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + func.call @simple_add(%d4x, %d5x) : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>) + + %c1x1 = arith.constant dense<0.0> : tensor<1x1xf32> + %c1x4 = arith.constant dense<0.0> : tensor<1x4xf32> + %c4x4 = arith.constant dense<0.0> : tensor<4x4xf32> + %c4x5 = arith.constant dense<0.0> : tensor<4x5xf32> + %c5x4 = arith.constant dense<0.0> : tensor<5x4xf32> + %d1x1 = tensor.cast %c1x1 : tensor<1x1xf32> to tensor<?x?xf32> + %d1x4 = tensor.cast %c1x4 : tensor<1x4xf32> to tensor<?x?xf32> + %d4x4 = tensor.cast %c4x4 : tensor<4x4xf32> to tensor<?x?xf32> + %d4x5 = tensor.cast %c4x5 : tensor<4x5xf32> to tensor<?x?xf32> + %d5x4 = tensor.cast %c5x4 : tensor<5x4xf32> to tensor<?x?xf32> + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d1x1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d1x1, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @broadcast_add(%d4x4, %d1x4) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + func.call @broadcast_add(%d1x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #1 of input/output operand #1 is incompatible with inferred dimension size + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #1 of input/output operand #2 is incompatible with inferred dimension size + func.call @broadcast_add(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_generic(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + func.call @matmul_generic(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @matmul_named(%d5x4, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.matmul + // CHECK: ^ dimension #0 of input/output operand #1 is incompatible with inferred dimension size + func.call @matmul_named(%d4x5, %d4x5) : (tensor<?x?xf32>, tensor<?x?xf32>) -> (tensor<?x?xf32>) + + %c64x57 = arith.constant dense<0.0> : tensor<16x29xf32> + %c3x4 = arith.constant dense<0.0> : tensor<3x4xf32> + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @conv(%c64x57, %c3x4) : (tensor<16x29xf32>, tensor<3x4xf32>) -> (tensor<5x7xf32>) + + // CHECK-NOT: ERROR: Runtime op verification failed + func.call @reverse_from_3(%d4x) : (tensor<?xf32>) -> (tensor<?xf32>) + + // CHECK: ERROR: Runtime op verification failed + // CHECK: linalg.generic + // CHECK: unexpected negative result on dimension #0 of input/output operand #0 + func.call @reverse_from_3(%d5x) : (tensor<?xf32>) -> (tensor<?xf32>) + + return +} + + +#identity1D = affine_map<(d0) -> (d0)> + +func.func @simple_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> (tensor<?xf32>) { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor<?xf32> + %result = tensor.empty(%dim) : tensor<?xf32> + %0 = linalg.generic { + indexing_maps = [#identity1D, #identity1D, #identity1D], + iterator_types = ["parallel"] + } ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) + outs(%result : tensor<?xf32>) { + ^bb0(%gen_arg1: f32, %gen_arg2: f32, %out: f32) : + %tmp1 = arith.addf %gen_arg1, %gen_arg2 : f32 + linalg.yield %tmp1 : f32 + } -> tensor<?xf32> + return %0 : tensor<?xf32> +} + +#broadcastD0 = affine_map<(d0, d1) -> (0, d1)> +#broadcastD1 = affine_map<(d0, d1) -> (d0, 0)> +#identity2D = affine_map<(d0, d1) -> (d0, d1)> + +func.func @broadcast_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + // Calculate maximum dimension 0 + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %dim_0 = tensor.dim %arg1, %c0 : tensor<?x?xf32> + %0 = arith.maxui %dim, %dim_0 : index + + // Calculate maximum dimension 1 + %c1 = arith.constant 1 : index + %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %dim_2 = tensor.dim %arg1, %c1 : tensor<?x?xf32> + %1 = arith.maxui %dim_1, %dim_2 : index + + // Broadcast dimension 0 of %arg0 + %dim_3 = tensor.dim %arg0, %c0 : tensor<?x?xf32> + %2 = arith.cmpi eq, %dim_3, %c1 : index + %3 = scf.if %2 -> (tensor<?x?xf32>) { + %dim_7 = tensor.dim %arg0, %c1 : tensor<?x?xf32> + %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32> + %13 = linalg.generic { + indexing_maps = [#broadcastD0, #identity2D], + iterator_types = ["parallel", "parallel"] + } ins(%arg0 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<?x?xf32> + scf.yield %13 : tensor<?x?xf32> + } else { + scf.yield %arg0 : tensor<?x?xf32> + } + + // Broadcast dimension 1 of %arg0 + %dim_4 = tensor.dim %3, %c1 : tensor<?x?xf32> + %4 = arith.cmpi eq, %dim_4, %c1 : index + %5 = scf.if %4 -> (tensor<?x?xf32>) { + %dim_7 = tensor.dim %3, %c0 : tensor<?x?xf32> + %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32> + %13 = linalg.generic { + indexing_maps = [#broadcastD1, #identity2D], + iterator_types = ["parallel", "parallel"] + } ins(%3 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<?x?xf32> + scf.yield %13 : tensor<?x?xf32> + } else { + scf.yield %3 : tensor<?x?xf32> + } + + // Broadcast dimension 0 of %arg1 + %dim_5 = tensor.dim %arg1, %c0 : tensor<?x?xf32> + %6 = arith.cmpi eq, %dim_5, %c1 : index + %7 = scf.if %6 -> (tensor<?x?xf32>) { + %dim_7 = tensor.dim %arg1, %c1 : tensor<?x?xf32> + %12 = tensor.empty(%0, %dim_7) : tensor<?x?xf32> + %13 = linalg.generic { + indexing_maps = [#broadcastD0, #identity2D], + iterator_types = ["parallel", "parallel"] + } ins(%arg1 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<?x?xf32> + scf.yield %13 : tensor<?x?xf32> + } else { + scf.yield %arg1 : tensor<?x?xf32> + } + + // Broadcast dimension 1 of %arg1 + %dim_6 = tensor.dim %7, %c1 : tensor<?x?xf32> + %8 = arith.cmpi eq, %dim_6, %c1 : index + %9 = scf.if %8 -> (tensor<?x?xf32>) { + %dim_7 = tensor.dim %7, %c0 : tensor<?x?xf32> + %12 = tensor.empty(%dim_7, %1) : tensor<?x?xf32> + %13 = linalg.generic { + indexing_maps = [#broadcastD1, #identity2D], + iterator_types = ["parallel", "parallel"] + } ins(%7 : tensor<?x?xf32>) outs(%12 : tensor<?x?xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<?x?xf32> + scf.yield %13 : tensor<?x?xf32> + } else { + scf.yield %7 : tensor<?x?xf32> + } + + // Perform element-wise computation + %10 = tensor.empty(%0, %1) : tensor<?x?xf32> + %11 = linalg.generic { + indexing_maps = [#identity2D, #identity2D, #identity2D], + iterator_types = ["parallel", "parallel"] + } ins(%5, %9 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%10 : tensor<?x?xf32>) { + ^bb0(%in: f32, %in_7: f32, %out: f32): + %12 = arith.addf %in, %in_7 : f32 + linalg.yield %12 : f32 + } -> tensor<?x?xf32> + return %11 : tensor<?x?xf32> +} + +#matmul_accesses = [ + affine_map<(m, n, k) -> (m, k)>, + affine_map<(m, n, k) -> (k, n)>, + affine_map<(m, n, k) -> (m, n)> +] +#matmul_trait = { + iterator_types = ["parallel", "parallel", "reduction"], + indexing_maps = #matmul_accesses +} + +func.func @matmul_generic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + %cf0 = arith.constant 0.0 : f32 + %ci0 = arith.constant 0 : index + %ci1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32> + %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32> + %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32> + %0 = linalg.generic #matmul_trait ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) { + ^bb0(%in: f32, %in_0: f32, %out: f32): + %1 = arith.mulf %in, %in_0 : f32 + %2 = arith.addf %out, %1 : f32 + linalg.yield %2 : f32 + } -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +func.func @matmul_named(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { + %cf0 = arith.constant 0.0 : f32 + %ci0 = arith.constant 0 : index + %ci1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %ci0 : tensor<?x?xf32> + %d1 = tensor.dim %arg1, %ci1 : tensor<?x?xf32> + %splat = tensor.splat %cf0[%d0, %d1] : tensor<?x?xf32> + %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%splat : tensor<?x?xf32>) -> tensor<?x?xf32> + return %0 : tensor<?x?xf32> +} + +#conv_trait = { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0 * 3 + d2, d1 * 4 + d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction", "reduction"] +} + +func.func @conv(%arg0: tensor<16x29xf32>, %arg1: tensor<3x4xf32>) -> (tensor<5x7xf32>) { + %c0 = arith.constant 0.0 : f32 + %splat = tensor.splat %c0 : tensor<5x7xf32> + %result = linalg.generic #conv_trait ins(%arg0, %arg1 : tensor<16x29xf32>, tensor<3x4xf32>) outs(%splat : tensor<5x7xf32>) { + ^bb0(%in: f32, %in_64: f32, %out: f32): + %5 = arith.mulf %in, %in_64 : f32 + %6 = arith.addf %out, %5 : f32 + linalg.yield %6 : f32 + } -> tensor<5x7xf32> + return %result : tensor<5x7xf32> +} + +#reverse_trait = { + indexing_maps = [ + affine_map<(i) -> (3 - i)>, + affine_map<(i) -> (i)> + ], + iterator_types = ["parallel"] +} + +func.func @reverse_from_3(%arg0: tensor<?xf32>) -> (tensor<?xf32>) { + %cf0 = arith.constant 0.0 : f32 + %ci0 = arith.constant 0 : index + %d0 = tensor.dim %arg0, %ci0 : tensor<?xf32> + %splat = tensor.splat %cf0[%d0] : tensor<?xf32> + %result = linalg.generic #reverse_trait ins(%arg0: tensor<?xf32>) outs(%splat: tensor<?xf32>) { + ^bb0(%a: f32, %b: f32): + linalg.yield %a : f32 + } -> tensor<?xf32> + return %result : tensor<?xf32> +} |