summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBoian Petkantchin <boian.petkantchin@amd.com>2024-01-05 07:14:07 -0800
committerGitHub <noreply@github.com>2024-01-05 07:14:07 -0800
commitfc18b13492880ba8597333c6050a18ec6db0f831 (patch)
treefec61d39cab0fccd3cfd413b5e12b3abaa9b5c9a
parent10b03e66629aedad79a804e22d23b575077303b3 (diff)
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (#76886)
Analogous to func.call use FlatSymbolRefAttr to reference the corresponding mesh.
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td8
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td4
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h4
-rw-r--r--mlir/lib/Dialect/Mesh/IR/MeshOps.cpp2
-rw-r--r--mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp2
-rw-r--r--mlir/test/Dialect/Mesh/invalid.mlir8
6 files changed, 18 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 060d54b82efa..bda6467e9c5d 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
let mnemonic = "shard";
let parameters = (ins
- AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster,
+ AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster,
ArrayRefParameter<"MeshAxesAttr">:$split_axes,
OptionalArrayRefParameter<"MeshAxis">:$partial_axes,
OptionalParameter<"::mlir::mesh::Partial">:$partial_type
@@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
The MeshSharding attribute could be used in the encoding of a
`RankedTensorType` or the mesh.shard op. it contains three sub-attributes:
- 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh
+ 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh
cluster where the distributed tensor is placed. The symbol must resolve to a
`mesh.cluster` operation.
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}];
let builders = [
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes,
"ArrayRef<MeshAxis>": $partial_axes,
"mesh::Partial": $partial_type), [{
@@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
return $_get($_ctxt, cluster, splitAxesAttr, partial_axes,
partial_type);
}]>,
- AttrBuilder<(ins "SymbolRefAttr":$cluster,
+ AttrBuilder<(ins "FlatSymbolRefAttr":$cluster,
"ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{
return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum);
}]>
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 1934bdfb4270..f459077ea120 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
```
}];
let arguments = (ins
- Builtin_RankedTensor:$src,
+ AnyRankedTensor:$src,
MeshSharding:$shard,
UnitAttr:$annotate_for_users
);
let results = (outs
- Builtin_RankedTensor:$result
+ AnyRankedTensor:$result
);
let assemblyFormat = [{
$src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:`
diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
index 201c0151754e..a32274d857f1 100644
--- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
+++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h
@@ -25,13 +25,13 @@ struct ShardingOption {
// An array of int array. The sub-array at the i-th position signifies the
// mesh axes the i-th loop will be sharded on.
ShardingArray shardingArray = {};
- SymbolRefAttr cluster = nullptr;
+ FlatSymbolRefAttr cluster = nullptr;
// `empty` being true indicates that no sharding information can be inferred
// at present. Note that it is different from the case where an operation is
// not sharded.
bool empty = false;
ShardingOption() = default;
- ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster)
+ ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster)
: shardingArray(std::move(shardingArray)), cluster(cluster) {}
};
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index c3d8f1d45610..6667d409df8b 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
LogicalResult
MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
- SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
+ FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes,
ArrayRef<MeshAxis> partialAxes, Partial) {
// TODO: At present cluster symbol ref is not verified. This is due to the
// difficulty in fetching the corresponding symbol op based on an attribute.
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
index ee885ab16b7b..dca7e86e6f07 100644
--- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
+++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp
@@ -215,7 +215,7 @@ namespace {
// Update the given `shardingOption` according to `meshAxes` and `loopIdx`
static LogicalResult fillShardingOption(Operation *op,
ShardingOption &shardingOption,
- SymbolRefAttr cluster,
+ FlatSymbolRefAttr cluster,
ArrayRef<MeshAxis> meshAxes,
unsigned loopIdx) {
if ((shardingOption.cluster && cluster &&
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 3ee578a37235..3e1b04da0dfd 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,14 @@ func.func @mesh_axis_negtive_in_partial(
// -----
+func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) {
+ // expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}}
+ // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}}
+ %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32>
+}
+
+// -----
+
mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {