summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarkus Böck <markus.boeck02@gmail.com>2024-02-02 11:57:16 +0100
committerMarkus Böck <markus.boeck02@gmail.com>2024-02-02 13:06:29 +0100
commit94c47995dff5a3a57b7066676dfe40b1caf713dc (patch)
treea3da33cd4c15627723b3e463e42372b60f421afd
parent1bbb797e9c7f37aa814b9bbaba2961f730a26891 (diff)
[mlir] Add `Print(Attr|Type)Qualified` trait
This PR adds a new trait to attributes and types that force the use of the qualified syntax for attributes and types. More concretely, any attribute or type with the trait must be parsed and printed with the `dialect.mnemonic` prefix. The motivation for this PR is the dependent PR where it is used to retain backwards-compatibility of syntax, but downstream projects may also use the trait if the subjectively prefer the verbose syntax.
-rw-r--r--mlir/include/mlir/IR/AttrTypeBase.td8
-rw-r--r--mlir/include/mlir/IR/Attributes.h8
-rw-r--r--mlir/include/mlir/IR/OpImplementation.h44
-rw-r--r--mlir/include/mlir/IR/Types.h5
-rw-r--r--mlir/test/IR/always-qualified-trait.mlir4
-rw-r--r--mlir/test/lib/Dialect/Test/TestAttrDefs.td9
-rw-r--r--mlir/test/lib/Dialect/Test/TestOps.td6
-rw-r--r--mlir/test/lib/Dialect/Test/TestTypeDefs.td9
8 files changed, 80 insertions, 13 deletions
diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
index 91c9283de8bd..c371ce9e515d 100644
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -38,6 +38,10 @@ class ParamNativeAttrTrait<string prop, string params>
class GenInternalAttrTrait<string prop> : GenInternalTrait<prop, "Attribute">;
class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
+// Trait used to tell the printer and parser to always print and parse
+// instances of the attribute as if it occurs within a `qualified` directive.
+def PrintAttrQualified : NativeAttrTrait<"PrintQualified">;
+
//===----------------------------------------------------------------------===//
// TypeTrait definitions
//===----------------------------------------------------------------------===//
@@ -56,6 +60,10 @@ class ParamNativeTypeTrait<string prop, string params>
class GenInternalTypeTrait<string prop> : GenInternalTrait<prop, "Type">;
class PredTypeTrait<string descr, Pred pred> : PredTrait<descr, pred>;
+// Trait used to tell the printer and parser to always print and parse
+// instances of the type as if it occurs within a `qualified` directive.
+def PrintTypeQualified : NativeTypeTrait<"PrintQualified">;
+
//===----------------------------------------------------------------------===//
// Builders
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index cc0cee6a3118..2122e5f8e135 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -317,12 +317,18 @@ protected:
// Core AttributeTrait
//===----------------------------------------------------------------------===//
+namespace AttributeTrait {
+
/// This trait is used to determine if an attribute is mutable or not. It is
/// attached on an attribute if the corresponding ImplType defines a `mutate`
/// function with proper signature.
-namespace AttributeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// Trait used to tell the printer and parser to always print and parse
+/// instances of the attribute as if it occurs within a `qualified` directive.
+template <typename ConcreteAttr>
+struct PrintQualified : TraitBase<ConcreteAttr, PrintQualified> {};
} // namespace AttributeTrait
} // namespace mlir.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
index 5333d7446df5..402399cf2966 100644
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -132,12 +132,20 @@ public:
using detect_has_print_method =
llvm::is_detected<has_print_method, AttrOrType>;
+ /// Constexpr bool that is true if `AttrOrType` should be printed with the
+ /// dialect prefix stripped.
+ template <typename AttrOrType>
+ constexpr static bool shouldPrintStripped =
+ detect_has_print_method<AttrOrType>::value &&
+ (!std::is_base_of_v<AttributeTrait::PrintQualified<AttrOrType>,
+ AttrOrType> &&
+ !std::is_base_of_v<TypeTrait::PrintQualified<AttrOrType>, AttrOrType>);
+
/// Print the provided attribute in the context of an operation custom
/// printer/parser: this will invoke directly the print method on the
/// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
template <typename AttrOrType,
- std::enable_if_t<detect_has_print_method<AttrOrType>::value>
- *sfinae = nullptr>
+ std::enable_if_t<shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
if (succeeded(printAlias(attrOrType)))
return;
@@ -158,8 +166,7 @@ public:
/// method on the attribute class and skip the `#dialect.mnemonic` prefix in
/// most cases.
template <typename AttrOrType,
- std::enable_if_t<detect_has_print_method<AttrOrType>::value>
- *sfinae = nullptr>
+ std::enable_if_t<shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
llvm::interleaveComma(
attrOrTypes, getStream(),
@@ -170,8 +177,7 @@ public:
/// custom printer in the case where the attribute does not define a print
/// method.
template <typename AttrOrType,
- std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
- *sfinae = nullptr>
+ std::enable_if_t<!shouldPrintStripped<AttrOrType>> *sfinae = nullptr>
void printStrippedAttrOrType(AttrOrType attrOrType) {
*this << attrOrType;
}
@@ -980,12 +986,19 @@ public:
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
+ /// Constexpr bool that is true if `AttrType` can be parsed with the dialect
+ /// prefix stripped.
+ template <typename AttrType>
+ constexpr static bool shouldParseAttrStripped =
+ detect_has_parse_method<AttrType>::value &&
+ !std::is_base_of_v<AttributeTrait::PrintQualified<AttrType>, AttrType>;
+
/// Parse a custom attribute of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result` and also added to the specified attribute list with
/// the specified name.
template <typename AttrType>
- std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+ std::enable_if_t<shouldParseAttrStripped<AttrType>, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
SMLoc loc = getCurrentLocation();
@@ -1012,7 +1025,7 @@ public:
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
- std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+ std::enable_if_t<!shouldParseAttrStripped<AttrType>, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
StringRef attrName, NamedAttrList &attrs) {
return parseAttribute(result, type, attrName, attrs);
@@ -1022,7 +1035,7 @@ public:
/// which case the generic parser is invoked. The parsed attribute is
/// populated in `result`.
template <typename AttrType>
- std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
+ std::enable_if_t<shouldParseAttrStripped<AttrType>, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
SMLoc loc = getCurrentLocation();
@@ -1044,7 +1057,7 @@ public:
/// SFINAE parsing method for Attribute that don't implement a parse method.
template <typename AttrType>
- std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
+ std::enable_if_t<!shouldParseAttrStripped<AttrType>, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type = {}) {
return parseAttribute(result, type);
}
@@ -1213,11 +1226,18 @@ public:
using detect_type_has_parse_method =
llvm::is_detected<type_has_parse_method, TypeT>;
+ /// Constexpr bool that is true if `TypeT` can be parsed with the dialect
+ /// prefix stripped.
+ template <typename TypeT>
+ constexpr static bool shouldParseTypeStripped =
+ detect_type_has_parse_method<TypeT>::value &&
+ !std::is_base_of_v<TypeTrait::PrintQualified<TypeT>, TypeT>;
+
/// Parse a custom Type of a given type unless the next token is `#`, in
/// which case the generic parser is invoked. The parsed Type is
/// populated in `result`.
template <typename TypeT>
- std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
+ std::enable_if_t<shouldParseTypeStripped<TypeT>, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
SMLoc loc = getCurrentLocation();
@@ -1238,7 +1258,7 @@ public:
/// SFINAE parsing method for Type that don't implement a parse method.
template <typename TypeT>
- std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
+ std::enable_if_t<!shouldParseTypeStripped<TypeT>, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
return parseType(result);
}
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 46bb733101c1..5c647504f7a7 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -304,6 +304,11 @@ protected:
namespace TypeTrait {
template <typename ConcreteType>
using IsMutable = detail::StorageUserTrait::IsMutable<ConcreteType>;
+
+/// Trait used to tell the printer and parser to always print and parse
+/// instances of the type as if it occurs within a `qualified` directive.
+template <typename ConcreteType>
+struct PrintQualified : TraitBase<ConcreteType, PrintQualified> {};
} // namespace TypeTrait
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/IR/always-qualified-trait.mlir b/mlir/test/IR/always-qualified-trait.mlir
new file mode 100644
index 000000000000..c029cb20df34
--- /dev/null
+++ b/mlir/test/IR/always-qualified-trait.mlir
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s | FileCheck %s
+
+// CHECK: test.would_print_unqualified #test.always_qualified<5> -> !test.always_qualified<7>
+%0 = test.would_print_unqualified #test.always_qualified<5> -> !test.always_qualified<7>
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
index 40f035a3e3a4..5e7f8e290dbd 100644
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -340,4 +340,13 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
}];
}
+def TestAlwaysQualifiedAttr : Test_Attr<"TestAlwaysQualified",
+ [PrintAttrQualified]> {
+ let mnemonic = "always_qualified";
+ let parameters = (ins "int":$value);
+ let assemblyFormat = [{
+ `<` $value `>`
+ }];
+}
+
#endif // TEST_ATTRDEFS
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 91ce0af9cd7f..e62322e8ad0d 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3056,4 +3056,10 @@ def TestOpOptionallyImplementingInterface
let arguments = (ins BoolAttr:$implementsInterface);
}
+def TestOpWouldPrintUnqualified : TEST_Op<"would_print_unqualified"> {
+ let arguments = (ins TestAlwaysQualifiedAttr:$attr);
+ let results = (outs TestAlwaysQualifiedType:$result);
+ let assemblyFormat = "$attr `->` type($result) attr-dict";
+}
+
#endif // TEST_OPS
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index 1957845c842f..681ab3de440a 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -391,4 +391,13 @@ def TestRecursiveAlias
}];
}
+def TestAlwaysQualifiedType : Test_Type<"TestAlwaysQualified",
+ [PrintTypeQualified]> {
+ let mnemonic = "always_qualified";
+ let parameters = (ins "int":$value);
+ let assemblyFormat = [{
+ `<` $value `>`
+ }];
+}
+
#endif // TEST_TYPEDEFS