summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSahilPatidar <89641424+SahilPatidar@users.noreply.github.com>2024-02-18 11:51:35 +0530
committerGitHub <noreply@github.com>2024-02-17 22:21:35 -0800
commitead0a9777f8ccb5c26d50d96bade6cd5b47f496b (patch)
tree1908df5aaa2eebbd2fd6d58c5b77ae21aaf8a96c
parent339baae3e223693a98f0f25e06147e4e6dde1254 (diff)
[mlir][spirv] Replace hardcoded strings with op methods (#81443)
Progress towards #77627 --------- Co-authored-by: SahilPatidar <patidarsahil@2001gmail.com> Co-authored-by: Lei Zhang <antiagainst@gmail.com>
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp3
-rw-r--r--mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp5
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp20
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp8
4 files changed, 23 insertions, 13 deletions
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
index a678124bf483..5b2903824c9e 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
@@ -489,7 +489,8 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
auto attrValue = words[wordIndex++];
auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
static_cast<spirv::MemoryAccess>(attrValue));
- attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
+ attributes.push_back(
+ opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr));
isAlignedAttr = (attrValue == 2);
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0fae..83ef01b4e3a4 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
return emitError(unknownLoc, "OpMemoryModel must have two operands");
(*module)->setAttr(
- "addressing_model",
+ module->getAddressingModelAttrName(),
opBuilder.getAttr<spirv::AddressingModelAttr>(
static_cast<spirv::AddressingModel>(operands.front())));
- (*module)->setAttr("memory_model",
+
+ (*module)->setAttr(module->getMemoryModelAttrName(),
opBuilder.getAttr<spirv::MemoryModelAttr>(
static_cast<spirv::MemoryModel>(operands.back())));
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index e68ed5efaca7..c283e64fa185 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -709,33 +709,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
operands.push_back(id);
}
- if (auto attr = op->getAttr("memory_access")) {
+ StringAttr memoryAccess = op.getMemoryAccessAttrName();
+ if (auto attr = op->getAttr(memoryAccess)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back("memory_access");
+ elidedAttrs.push_back(memoryAccess.strref());
- if (auto attr = op->getAttr("alignment")) {
+ StringAttr alignment = op.getAlignmentAttrName();
+ if (auto attr = op->getAttr(alignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back("alignment");
+ elidedAttrs.push_back(alignment.strref());
- if (auto attr = op->getAttr("source_memory_access")) {
+ StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName();
+ if (auto attr = op->getAttr(sourceMemoryAccess)) {
operands.push_back(
static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue()));
}
- elidedAttrs.push_back("source_memory_access");
+ elidedAttrs.push_back(sourceMemoryAccess.strref());
- if (auto attr = op->getAttr("source_alignment")) {
+ StringAttr sourceAlignment = op.getSourceAlignmentAttrName();
+ if (auto attr = op->getAttr(sourceAlignment)) {
operands.push_back(static_cast<uint32_t>(
cast<IntegerAttr>(attr).getValue().getZExtValue()));
}
- elidedAttrs.push_back("source_alignment");
+ elidedAttrs.push_back(sourceAlignment.strref());
if (failed(emitDebugLine(functionBody, op.getLoc())))
return failure();
encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf..4a4e878d8af9 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -197,10 +197,14 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
+ StringAttr memoryModelName = module.getMemoryModelAttrName();
auto mm = static_cast<uint32_t>(
- module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
+ module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName)
+ .getValue());
+
+ StringAttr addressingModelName = module.getAddressingModelAttrName();
auto am = static_cast<uint32_t>(
- module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
+ module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName)
.getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});