summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/BuiltinDialect.td2
-rw-r--r--mlir/include/mlir/IR/BuiltinTypes.td25
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h5
-rw-r--r--mlir/lib/AsmParser/DialectSymbolParser.cpp24
-rw-r--r--mlir/lib/AsmParser/Parser.h24
-rw-r--r--mlir/lib/AsmParser/TypeParser.cpp211
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp72
-rw-r--r--mlir/lib/IR/BuiltinTypes.cpp150
-rw-r--r--mlir/test/IR/invalid-builtin-types.mlir10
-rw-r--r--mlir/test/IR/invalid.mlir4
-rw-r--r--mlir/test/IR/qualified-builtin.mlir11
11 files changed, 249 insertions, 289 deletions
diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td
index c131107634b4..a8627170288c 100644
--- a/mlir/include/mlir/IR/BuiltinDialect.td
+++ b/mlir/include/mlir/IR/BuiltinDialect.td
@@ -22,7 +22,7 @@ def Builtin_Dialect : Dialect {
let name = "builtin";
let cppNamespace = "::mlir";
let useDefaultAttributePrinterParser = 0;
- let useDefaultTypePrinterParser = 0;
+ let useDefaultTypePrinterParser = 1;
let extraClassDeclaration = [{
private:
// Register the builtin Attributes.
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index 4cade83dd3c3..f3a51d215504 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -25,7 +25,8 @@ include "mlir/IR/BuiltinTypeInterfaces.td"
// Base class for Builtin dialect types.
class Builtin_Type<string name, string typeMnemonic, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
- : TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
+ : TypeDef<Builtin_Dialect, name, !listconcat(traits, [PrintTypeQualified]),
+ baseCppClass> {
let mnemonic = ?;
let typeName = "builtin." # typeMnemonic;
}
@@ -62,6 +63,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "complex";
+ let assemblyFormat = "`<` $elementType `>`";
}
//===----------------------------------------------------------------------===//
@@ -668,6 +672,9 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "memref";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -698,6 +705,8 @@ def Builtin_None : Builtin_Type<"None", "none"> {
let extraClassDeclaration = [{
static NoneType get(MLIRContext *context);
}];
+
+ let mnemonic = "none";
}
//===----------------------------------------------------------------------===//
@@ -849,6 +858,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "tensor";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -884,7 +896,7 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
tuple<i32, f32, tensor<i1>, i5>
```
}];
- let parameters = (ins "ArrayRef<Type>":$types);
+ let parameters = (ins OptionalArrayRefParameter<"Type">:$types);
let builders = [
TypeBuilder<(ins "TypeRange":$elementTypes), [{
return $_get($_ctxt, elementTypes);
@@ -916,6 +928,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
return getTypes()[index];
}
}];
+
+ let mnemonic = "tuple";
+ let assemblyFormat = "`<` (`>`) : ($types^ `>`)?";
}
//===----------------------------------------------------------------------===//
@@ -994,6 +1009,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "memref";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
@@ -1043,6 +1061,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
}];
let skipDefaultBuilders = 1;
let genVerifyDecl = 1;
+
+ let mnemonic = "tensor";
+ let hasCustomAssemblyFormat = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 50e6cc59ca45..2a5587d43901 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -187,6 +187,11 @@ public:
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr);
+ /// Print the given attribute without its type if and only if the type is the
+ /// default type for the given attribute.
+ /// E.g. '1 : i64' is printed as just '1'.
+ virtual void printAttributeWithoutDefaultType(Attribute attr);
+
/// Print the alias for the given attribute, return failure if no alias could
/// be printed.
virtual LogicalResult printAlias(Attribute attr);
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 80cce7e6ae43..400d26398afc 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -322,6 +322,30 @@ Type Parser::parseExtendedType() {
});
}
+Type Parser::parseExtendedBuiltinType() {
+ // Initially set to just the mnemonic of the type.
+ llvm::StringRef symbolData = getToken().getSpelling();
+ const char *startOfTypePos = symbolData.data();
+ consumeToken();
+ // Extend 'symbolData' to include the body if it is not a singleton type.
+ // Note that all types in the builtin type always use the pretty dialect form
+ // aka 'dialect.mnemonic<body>'.
+ if (getToken().is(Token::less))
+ if (failed(parseDialectSymbolBody(symbolData)))
+ return nullptr;
+
+ const char *endOfTypePos = getToken().getLoc().getPointer();
+
+ // With the body of the type captured, hand it off to the dialect parser.
+ resetToken(startOfTypePos);
+ CustomDialectAsmParser customParser(symbolData, *this);
+ Type type = builtinDialect->parseType(customParser);
+
+ // Move the lexer past the type.
+ resetToken(endOfTypePos);
+ return type;
+}
+
//===----------------------------------------------------------------------===//
// mlir::parseAttribute/parseType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
index b959e67b8e25..73080c88ff6b 100644
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -11,6 +11,7 @@
#include "ParserState.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpImplementation.h"
#include <optional>
@@ -28,9 +29,14 @@ public:
using Delimiter = OpAsmParser::Delimiter;
Builder builder;
+ /// Cached instance of the builtin dialect for parsing builtins.
+ Dialect *builtinDialect;
Parser(ParserState &state)
- : builder(state.config.getContext()), state(state) {}
+ : builder(state.config.getContext()),
+ builtinDialect(
+ builder.getContext()->getLoadedDialect<BuiltinDialect>()),
+ state(state) {}
// Helper methods to get stuff from the parser-global state.
ParserState &getState() const { return state; }
@@ -192,27 +198,19 @@ public:
/// Parse an arbitrary type.
Type parseType();
- /// Parse a complex type.
- Type parseComplexType();
-
/// Parse an extended type.
Type parseExtendedType();
+ /// Parse an extended type from the builtin dialect where the '!builtin'
+ /// prefix is missing.
+ Type parseExtendedBuiltinType();
+
/// Parse a function type.
Type parseFunctionType();
- /// Parse a memref type.
- Type parseMemRefType();
-
/// Parse a non function type.
Type parseNonFunctionType();
- /// Parse a tensor type.
- Type parseTensorType();
-
- /// Parse a tuple type.
- Type parseTupleType();
-
/// Parse a vector type.
VectorType parseVectorType();
ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
index 5da931b77b3b..95df69b899b8 100644
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -11,12 +11,9 @@
//===----------------------------------------------------------------------===//
#include "Parser.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
-#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
@@ -123,29 +120,6 @@ ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
return success();
}
-/// Parse a complex type.
-///
-/// complex-type ::= `complex` `<` type `>`
-///
-Type Parser::parseComplexType() {
- consumeToken(Token::kw_complex);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in complex type"))
- return nullptr;
-
- SMLoc elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType ||
- parseToken(Token::greater, "expected '>' in complex type"))
- return nullptr;
- if (!isa<FloatType>(elementType) && !isa<IntegerType>(elementType))
- return emitError(elementTypeLoc, "invalid element type for complex"),
- nullptr;
-
- return ComplexType::get(elementType);
-}
-
/// Parse a function type.
///
/// function-type ::= type-list-parens `->` function-result-type
@@ -162,95 +136,6 @@ Type Parser::parseFunctionType() {
return builder.getFunctionType(arguments, results);
}
-/// Parse a memref type.
-///
-/// memref-type ::= ranked-memref-type | unranked-memref-type
-///
-/// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
-/// (`,` layout-specification)? (`,` memory-space)? `>`
-///
-/// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
-///
-/// stride-list ::= `[` (dimension (`,` dimension)*)? `]`
-/// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
-/// layout-specification ::= semi-affine-map | strided-layout | attribute
-/// memory-space ::= integer-literal | attribute
-///
-Type Parser::parseMemRefType() {
- SMLoc loc = getToken().getLoc();
- consumeToken(Token::kw_memref);
-
- if (parseToken(Token::less, "expected '<' in memref type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked memref type.
- isUnranked = true;
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto typeLoc = getToken().getLoc();
- auto elementType = parseType();
- if (!elementType)
- return nullptr;
-
- // Check that memref is formed from allowed types.
- if (!BaseMemRefType::isValidElementType(elementType))
- return emitError(typeLoc, "invalid memref element type"), nullptr;
-
- MemRefLayoutAttrInterface layout;
- Attribute memorySpace;
-
- auto parseElt = [&]() -> ParseResult {
- // Either it is MemRefLayoutAttrInterface or memory space attribute.
- Attribute attr = parseAttribute();
- if (!attr)
- return failure();
-
- if (isa<MemRefLayoutAttrInterface>(attr)) {
- layout = cast<MemRefLayoutAttrInterface>(attr);
- } else if (memorySpace) {
- return emitError("multiple memory spaces specified in memref type");
- } else {
- memorySpace = attr;
- return success();
- }
-
- if (isUnranked)
- return emitError("cannot have affine map for unranked memref type");
- if (memorySpace)
- return emitError("expected memory space to be last in memref type");
-
- return success();
- };
-
- // Parse a list of mappings and address space if present.
- if (!consumeIf(Token::greater)) {
- // Parse comma separated list of affine maps, followed by memory space.
- if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
- parseCommaSeparatedListUntil(Token::greater, parseElt,
- /*allowEmptyList=*/false)) {
- return nullptr;
- }
- }
-
- if (isUnranked)
- return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace);
-
- return getChecked<MemRefType>(loc, dimensions, elementType, layout,
- memorySpace);
-}
-
/// Parse any type except the function type.
///
/// non-function-type ::= integer-type
@@ -272,14 +157,12 @@ Type Parser::parseNonFunctionType() {
switch (getToken().getKind()) {
default:
return (emitWrongTokenError("expected non-function type"), nullptr);
- case Token::kw_memref:
- return parseMemRefType();
case Token::kw_tensor:
- return parseTensorType();
+ case Token::kw_memref:
case Token::kw_complex:
- return parseComplexType();
case Token::kw_tuple:
- return parseTupleType();
+ case Token::kw_none:
+ return parseExtendedBuiltinType();
case Token::kw_vector:
return parseVectorType();
// integer-type
@@ -344,11 +227,6 @@ Type Parser::parseNonFunctionType() {
consumeToken(Token::kw_index);
return builder.getIndexType();
- // none-type
- case Token::kw_none:
- consumeToken(Token::kw_none);
- return builder.getNoneType();
-
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
@@ -361,89 +239,6 @@ Type Parser::parseNonFunctionType() {
}
}
-/// Parse a tensor type.
-///
-/// tensor-type ::= `tensor` `<` dimension-list type `>`
-/// dimension-list ::= dimension-list-ranked | `*x`
-///
-Type Parser::parseTensorType() {
- consumeToken(Token::kw_tensor);
-
- if (parseToken(Token::less, "expected '<' in tensor type"))
- return nullptr;
-
- bool isUnranked;
- SmallVector<int64_t, 4> dimensions;
-
- if (consumeIf(Token::star)) {
- // This is an unranked tensor type.
- isUnranked = true;
-
- if (parseXInDimensionList())
- return nullptr;
-
- } else {
- isUnranked = false;
- if (parseDimensionListRanked(dimensions))
- return nullptr;
- }
-
- // Parse the element type.
- auto elementTypeLoc = getToken().getLoc();
- auto elementType = parseType();
-
- // Parse an optional encoding attribute.
- Attribute encoding;
- if (consumeIf(Token::comma)) {
- auto parseResult = parseOptionalAttribute(encoding);
- if (parseResult.has_value()) {
- if (failed(parseResult.value()))
- return nullptr;
- if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
- if (failed(v.verifyEncoding(dimensions, elementType,
- [&] { return emitError(); })))
- return nullptr;
- }
- }
- }
-
- if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
- return nullptr;
- if (!TensorType::isValidElementType(elementType))
- return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
-
- if (isUnranked) {
- if (encoding)
- return emitError("cannot apply encoding to unranked tensor"), nullptr;
- return UnrankedTensorType::get(elementType);
- }
- return RankedTensorType::get(dimensions, elementType, encoding);
-}
-
-/// Parse a tuple type.
-///
-/// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
-///
-Type Parser::parseTupleType() {
- consumeToken(Token::kw_tuple);
-
- // Parse the '<'.
- if (parseToken(Token::less, "expected '<' in tuple type"))
- return nullptr;
-
- // Check for an empty tuple by directly parsing '>'.
- if (consumeIf(Token::greater))
- return TupleType::get(getContext());
-
- // Parse the element types and the '>'.
- SmallVector<Type, 4> types;
- if (parseTypeListNoParens(types) ||
- parseToken(Token::greater, "expected '>' in tuple type"))
- return nullptr;
-
- return TupleType::get(getContext(), types);
-}
-
/// Parse a vector type.
///
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 6b8b7473bf0f..0679d4135048 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2132,6 +2132,13 @@ static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
/// Print the given dialect symbol to the stream.
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
StringRef dialectName, StringRef symString) {
+ // Treat the builtin dialect special by eliding the '<symPrefix>builtin'
+ // prefix.
+ if (dialectName == "builtin") {
+ os << symString;
+ return;
+ }
+
os << symPrefix << dialectName;
// If this symbol name is simple enough, print it directly in pretty form,
@@ -2599,64 +2606,6 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
printType(vectorTy.getElementType());
os << '>';
})
- .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
- os << "tensor<";
- printDimensionList(tensorTy.getShape());
- if (!tensorTy.getShape().empty())
- os << 'x';
- printType(tensorTy.getElementType());
- // Only print the encoding attribute value if set.
- if (tensorTy.getEncoding()) {
- os << ", ";
- printAttribute(tensorTy.getEncoding());
- }
- os << '>';
- })
- .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
- os << "tensor<*x";
- printType(tensorTy.getElementType());
- os << '>';
- })
- .Case<MemRefType>([&](MemRefType memrefTy) {
- os << "memref<";
- printDimensionList(memrefTy.getShape());
- if (!memrefTy.getShape().empty())
- os << 'x';
- printType(memrefTy.getElementType());
- MemRefLayoutAttrInterface layout = memrefTy.getLayout();
- if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
- os << ", ";
- printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
- }
- // Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace()) {
- os << ", ";
- printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
- }
- os << '>';
- })
- .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
- os << "memref<*x";
- printType(memrefTy.getElementType());
- // Only print the memory space if it is the non-default one.
- if (memrefTy.getMemorySpace()) {
- os << ", ";
- printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
- }
- os << '>';
- })
- .Case<ComplexType>([&](ComplexType complexTy) {
- os << "complex<";
- printType(complexTy.getElementType());
- os << '>';
- })
- .Case<TupleType>([&](TupleType tupleTy) {
- os << "tuple<";
- interleaveComma(tupleTy.getTypes(),
- [&](Type type) { printType(type); });
- os << '>';
- })
- .Case<NoneType>([&](Type) { os << "none"; })
.Default([&](Type type) { return printDialectType(type); });
}
@@ -2799,6 +2748,13 @@ void AsmPrinter::printAttributeWithoutType(Attribute attr) {
impl->printAttribute(attr, Impl::AttrTypeElision::Must);
}
+void AsmPrinter::printAttributeWithoutDefaultType(Attribute attr) {
+ assert(
+ impl &&
+ "expected AsmPrinter::printAttributeWithoutDefaultType to be overriden");
+ impl->printAttribute(attr, Impl::AttrTypeElision::May);
+}
+
void AsmPrinter::printKeywordOrString(StringRef keyword) {
assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
::printKeywordOrString(keyword, impl->getStream());
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d45280..e160c0ff4c33 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -10,10 +10,13 @@
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/APFloat.h"
@@ -26,6 +29,52 @@ using namespace mlir;
using namespace mlir::detail;
//===----------------------------------------------------------------------===//
+// Custom printing and parsing
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseMemRefDimension(AsmParser &parser,
+ SmallVectorImpl<int64_t> &dimension,
+ bool &isUnranked) {
+ if (succeeded(parser.parseOptionalStar())) {
+ isUnranked = true;
+ return parser.parseXInDimensionList();
+ }
+
+ isUnranked = false;
+ return parser.parseDimensionList(dimension);
+}
+
+static ParseResult parseMemRefSpaceAndLayout(AsmParser &parser,
+ MemRefLayoutAttrInterface &layout,
+ Attribute &memorySpace,
+ bool isUnranked) {
+ while (succeeded(parser.parseOptionalComma())) {
+ SMLoc loc = parser.getCurrentLocation();
+ Attribute attr;
+ if (parser.parseAttribute(attr))
+ return failure();
+
+ if (auto memRefLayout = dyn_cast<MemRefLayoutAttrInterface>(attr)) {
+ layout = memRefLayout;
+ } else if (memorySpace) {
+ return parser.emitError(
+ loc, "multiple memory spaces specified in memref type");
+ } else {
+ memorySpace = attr;
+ continue;
+ }
+
+ if (isUnranked)
+ return parser.emitError(
+ loc, "cannot have affine map for unranked memref type");
+ if (memorySpace)
+ return parser.emitError(
+ loc, "expected memory space to be last in memref type");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
/// Tablegen Type Definitions
//===----------------------------------------------------------------------===//
@@ -340,6 +389,46 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
+Type RankedTensorType::parse(AsmParser &parser) {
+ SmallVector<int64_t> dimension;
+ Type elementType;
+ bool isUnranked;
+ if (parser.parseLess() ||
+ parseMemRefDimension(parser, dimension, isUnranked) ||
+ parser.parseType(elementType))
+ return nullptr;
+
+ Attribute encoding;
+ if (succeeded(parser.parseOptionalComma())) {
+ SMLoc loc = parser.getCurrentLocation();
+ if (parser.parseAttribute(encoding))
+ return nullptr;
+
+ if (isUnranked) {
+ parser.emitError(loc, "cannot apply encoding to unranked tensor");
+ return nullptr;
+ }
+ }
+
+ if (failed(parser.parseGreater()))
+ return nullptr;
+
+ if (isUnranked)
+ return parser.getChecked<UnrankedTensorType>(elementType);
+ return parser.getChecked<RankedTensorType>(dimension, elementType, encoding);
+}
+
+void RankedTensorType::print(AsmPrinter &printer) const {
+ printer << '<';
+ printer.printDimensionList(getShape());
+ if (!getShape().empty())
+ printer << 'x';
+ printer << getElementType();
+ if (getEncoding())
+ printer << ", " << getEncoding();
+ printer << '>';
+}
+
//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//
@@ -350,6 +439,14 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
+Type UnrankedTensorType::parse(AsmParser &parser) {
+ return RankedTensorType::parse(parser);
+}
+
+void UnrankedTensorType::print(AsmPrinter &printer) const {
+ printer << "<*x" << getElementType() << ">";
+}
+
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
@@ -652,6 +749,44 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+Type MemRefType::parse(AsmParser &parser) {
+ SmallVector<int64_t> dimension;
+ Type elementType;
+ MemRefLayoutAttrInterface layout;
+ Attribute memorySpace;
+ bool isUnranked;
+ if (parser.parseLess() ||
+ parseMemRefDimension(parser, dimension, isUnranked) ||
+ parser.parseType(elementType) ||
+ parseMemRefSpaceAndLayout(parser, layout, memorySpace, isUnranked) ||
+ parser.parseGreater())
+ return nullptr;
+
+ if (isUnranked)
+ return parser.getChecked<UnrankedMemRefType>(elementType, memorySpace);
+ return parser.getChecked<MemRefType>(dimension, elementType, layout,
+ memorySpace);
+}
+
+void MemRefType::print(AsmPrinter &printer) const {
+ printer << '<';
+ printer.printDimensionList(getShape());
+ if (!getShape().empty())
+ printer << 'x';
+ printer << getElementType();
+ MemRefLayoutAttrInterface layout = getLayout();
+ if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getLayout());
+ }
+ // Only print the memory space if it is the non-default one.
+ if (getMemorySpace()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getMemorySpace());
+ }
+ printer << '>';
+}
+
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
@@ -672,6 +807,21 @@ UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+Type UnrankedMemRefType::parse(AsmParser &parser) {
+ return MemRefType::parse(parser);
+}
+
+void UnrankedMemRefType::print(AsmPrinter &printer) const {
+ printer << "<*x";
+ printer << getElementType();
+ // Only print the memory space if it is the non-default one.
+ if (getMemorySpace()) {
+ printer << ", ";
+ printer.printAttributeWithoutDefaultType(getMemorySpace());
+ }
+ printer << '>';
+}
+
// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
// i.e. single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
index 9884212e916c..04995bf7338a 100644
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -27,7 +27,7 @@ func.func @illegalunrankedmemrefelementtype(memref<*xtensor<i8>>) -> () // expec
// -----
// Test no map in memref type.
-func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}}
+func.func @memrefs(memref<2x4xi8, >) // expected-error {{expected attribute}}
// -----
// Test non-existent map in memref type.
@@ -74,7 +74,7 @@ func.func private @memref_unfinished_strided() -> memref<?x?xf32, strided<>>
// -----
-// expected-error @below {{expected a 64-bit signed integer or '?'}}
+// expected-error @below {{unbalanced '[' character in pretty dialect name}}
func.func private @memref_unfinished_stride_list() -> memref<?x?xf32, strided<[>>
// -----
@@ -94,7 +94,7 @@ func.func private @memref_missing_offset_value() -> memref<?x?xf32, strided<[],
// -----
-// expected-error @below {{expected '>'}}
+// expected-error @below {{unbalanced '<' character in pretty dialect name}}
func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<[], offset: 32)>
// -----
@@ -170,12 +170,12 @@ func.func @bad_complex(complex<memref<2x4xi8>>)
// -----
-// expected-error @+1 {{expected '<' in complex type}}
+// expected-error @+1 {{expected '<'}}
func.func @bad_complex(complex memref<2x4xi8>>)
// -----
-// expected-error @+1 {{expected '>' in complex type}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @bad_complex(complex<i32)
// -----
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
index 861f4ef6c020..1e01b477b1ad 100644
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -419,12 +419,12 @@ func.func @invalid_unknown_type_dialect_name() -> !invalid.dialect<!x@#]!@#>
// -----
-// expected-error @+1 {{expected '<' in tuple type}}
+// expected-error @+1 {{expected '<'}}
func.func @invalid_tuple_missing_less(tuple i32>)
// -----
-// expected-error @+1 {{expected '>' in tuple type}}
+// expected-error @+1 {{unbalanced '<' character in pretty dialect name}}
func.func @invalid_tuple_missing_greater(tuple<i32)
// -----
diff --git a/mlir/test/IR/qualified-builtin.mlir b/mlir/test/IR/qualified-builtin.mlir
new file mode 100644
index 000000000000..a2f9e63ea66b
--- /dev/null
+++ b/mlir/test/IR/qualified-builtin.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK-LABEL: @test1
+// CHECK: -> tuple<>
+func.func private @test1() -> !builtin.tuple<>
+
+// CHECK-LABEL: @test2
+// CHECK: -> none
+func.func private @test2() -> !builtin.none
+
+