diff options
author | Boian Petkantchin <boian.petkantchin@amd.com> | 2024-01-05 07:14:07 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-05 07:14:07 -0800 |
commit | fc18b13492880ba8597333c6050a18ec6db0f831 (patch) | |
tree | fec61d39cab0fccd3cfd413b5e12b3abaa9b5c9a | |
parent | 10b03e66629aedad79a804e22d23b575077303b3 (diff) |
[mlir][mesh] In sharding attr use FlatSymbolRefAttr instead of SymbolRefAttr (#76886)
Analogous to func.call, use FlatSymbolRefAttr to reference the
corresponding mesh.
-rw-r--r-- | mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td | 8 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td | 4 | ||||
-rw-r--r-- | mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/Mesh/IR/MeshOps.cpp | 2 | ||||
-rw-r--r-- | mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp | 2 | ||||
-rw-r--r-- | mlir/test/Dialect/Mesh/invalid.mlir | 8 |
6 files changed, 18 insertions, 10 deletions
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td index 060d54b82efa..bda6467e9c5d 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td @@ -79,7 +79,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { let mnemonic = "shard"; let parameters = (ins - AttrParameter<"::mlir::SymbolRefAttr", "cluster placed">:$cluster, + AttrParameter<"::mlir::FlatSymbolRefAttr", "cluster placed">:$cluster, ArrayRefParameter<"MeshAxesAttr">:$split_axes, OptionalArrayRefParameter<"MeshAxis">:$partial_axes, OptionalParameter<"::mlir::mesh::Partial">:$partial_type @@ -91,7 +91,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { The MeshSharding attribute could be used in the encoding of a `RankedTensorType` or the mesh.shard op. it contains three sub-attributes: - 1. `cluster`: this attribute is a SymbolRefAttr that refers to the mesh + 1. `cluster`: this attribute is a FlatSymbolRefAttr that refers to the mesh cluster where the distributed tensor is placed. The symbol must resolve to a `mesh.cluster` operation. 
@@ -145,7 +145,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { }]; let builders = [ - AttrBuilder<(ins "SymbolRefAttr":$cluster, + AttrBuilder<(ins "FlatSymbolRefAttr":$cluster, "ArrayRef<SmallVector<MeshAxis>>":$split_axes, "ArrayRef<MeshAxis>": $partial_axes, "mesh::Partial": $partial_type), [{ @@ -156,7 +156,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> { return $_get($_ctxt, cluster, splitAxesAttr, partial_axes, partial_type); }]>, - AttrBuilder<(ins "SymbolRefAttr":$cluster, + AttrBuilder<(ins "FlatSymbolRefAttr":$cluster, "ArrayRef<SmallVector<MeshAxis>>":$split_axes), [{ return MeshShardingAttr::get($_ctxt, cluster, split_axes, {}, Partial::Sum); }]> diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td index 1934bdfb4270..f459077ea120 100644 --- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td +++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td @@ -196,12 +196,12 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> { ``` }]; let arguments = (ins - Builtin_RankedTensor:$src, + AnyRankedTensor:$src, MeshSharding:$shard, UnitAttr:$annotate_for_users ); let results = (outs - Builtin_RankedTensor:$result + AnyRankedTensor:$result ); let assemblyFormat = [{ $src `to` $shard (`annotate_for_users` $annotate_for_users^)? attr-dict `:` diff --git a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h index 201c0151754e..a32274d857f1 100644 --- a/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h +++ b/mlir/include/mlir/Dialect/Mesh/Interfaces/ShardingInterface.h @@ -25,13 +25,13 @@ struct ShardingOption { // An array of int array. The sub-array at the i-th position signifies the // mesh axes the i-th loop will be sharded on. 
ShardingArray shardingArray = {}; - SymbolRefAttr cluster = nullptr; + FlatSymbolRefAttr cluster = nullptr; // `empty` being true indicates that no sharding information can be inferred // at present. Note that it is different from the case where an operation is // not sharded. bool empty = false; ShardingOption() = default; - ShardingOption(ShardingArray shardingArray, SymbolRefAttr cluster) + ShardingOption(ShardingArray shardingArray, FlatSymbolRefAttr cluster) : shardingArray(std::move(shardingArray)), cluster(cluster) {} }; diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp index c3d8f1d45610..6667d409df8b 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp @@ -266,7 +266,7 @@ void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, LogicalResult MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError, - SymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes, + FlatSymbolRefAttr, ArrayRef<MeshAxesAttr> splitAxes, ArrayRef<MeshAxis> partialAxes, Partial) { // TODO: At present cluster symbol ref is not verified. This is due to the // difficulty in fetching the corresponding symbol op based on an attribute. 
diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp index ee885ab16b7b..dca7e86e6f07 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp @@ -215,7 +215,7 @@ namespace { // Update the given `shardingOption` according to `meshAxes` and `loopIdx` static LogicalResult fillShardingOption(Operation *op, ShardingOption &shardingOption, - SymbolRefAttr cluster, + FlatSymbolRefAttr cluster, ArrayRef<MeshAxis> meshAxes, unsigned loopIdx) { if ((shardingOption.cluster && cluster && diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir index 3ee578a37235..3e1b04da0dfd 100644 --- a/mlir/test/Dialect/Mesh/invalid.mlir +++ b/mlir/test/Dialect/Mesh/invalid.mlir @@ -70,6 +70,14 @@ func.func @mesh_axis_negtive_in_partial( // ----- +func.func @sharding_attribute_invalid_nested_symbol(%arg0 : tensor<4x8xf32>) { + // expected-error@+2 {{custom op 'mesh.shard' invalid kind of attribute specified}} + // expected-error@+1 {{custom op 'mesh.shard' failed to parse MeshSharding parameter 'cluster' which is to be a `::mlir::FlatSymbolRefAttr`}} + %0 = mesh.shard %arg0 to <@a::@b, [[0]]> : tensor<4x8xf32> +} + +// ----- + mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4) func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) { |