diff options
author | SahilPatidar <89641424+SahilPatidar@users.noreply.github.com> | 2024-02-18 11:51:35 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-17 22:21:35 -0800 |
commit | ead0a9777f8ccb5c26d50d96bade6cd5b47f496b (patch) | |
tree | 1908df5aaa2eebbd2fd6d58c5b77ae21aaf8a96c | |
parent | 339baae3e223693a98f0f25e06147e4e6dde1254 (diff) |
[mlir][spirv] Replace hardcoded strings with op methods (#81443)
Progress towards #77627
---------
Co-authored-by: SahilPatidar <patidarsahil@2001gmail.com>
Co-authored-by: Lei Zhang <antiagainst@gmail.com>
4 files changed, 23 insertions, 13 deletions
diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp index a678124bf483..5b2903824c9e 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -489,7 +489,8 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { auto attrValue = words[wordIndex++]; auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>( static_cast<spirv::MemoryAccess>(attrValue)); - attributes.push_back(opBuilder.getNamedAttr("memory_access", attr)); + attributes.push_back( + opBuilder.getNamedAttr(attributeName<MemoryAccess>(), attr)); isAlignedAttr = (attrValue == 2); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 02d03b3a0fae..83ef01b4e3a4 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { return emitError(unknownLoc, "OpMemoryModel must have two operands"); (*module)->setAttr( - "addressing_model", + module->getAddressingModelAttrName(), opBuilder.getAttr<spirv::AddressingModelAttr>( static_cast<spirv::AddressingModel>(operands.front()))); - (*module)->setAttr("memory_model", + + (*module)->setAttr(module->getMemoryModelAttrName(), opBuilder.getAttr<spirv::MemoryModelAttr>( static_cast<spirv::MemoryModel>(operands.back()))); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index e68ed5efaca7..c283e64fa185 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -709,33 +709,37 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) { operands.push_back(id); } - if (auto attr = op->getAttr("memory_access")) { + StringAttr memoryAccess = op.getMemoryAccessAttrName(); + if (auto attr = op->getAttr(memoryAccess)) { operands.push_back( static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); } - elidedAttrs.push_back("memory_access"); + elidedAttrs.push_back(memoryAccess.strref()); - if (auto attr = op->getAttr("alignment")) { + StringAttr alignment = op.getAlignmentAttrName(); + if (auto attr = op->getAttr(alignment)) { operands.push_back(static_cast<uint32_t>( cast<IntegerAttr>(attr).getValue().getZExtValue())); } - elidedAttrs.push_back("alignment"); + elidedAttrs.push_back(alignment.strref()); - if (auto attr = op->getAttr("source_memory_access")) { + StringAttr sourceMemoryAccess = op.getSourceMemoryAccessAttrName(); + if (auto attr = op->getAttr(sourceMemoryAccess)) { operands.push_back( static_cast<uint32_t>(cast<spirv::MemoryAccessAttr>(attr).getValue())); } - elidedAttrs.push_back("source_memory_access"); + elidedAttrs.push_back(sourceMemoryAccess.strref()); - if (auto attr = op->getAttr("source_alignment")) { + StringAttr sourceAlignment = op.getSourceAlignmentAttrName(); + if (auto attr = op->getAttr(sourceAlignment)) { operands.push_back(static_cast<uint32_t>( cast<IntegerAttr>(attr).getValue().getZExtValue())); } - elidedAttrs.push_back("source_alignment"); + elidedAttrs.push_back(sourceAlignment.strref()); if (failed(emitDebugLine(functionBody, op.getLoc()))) return failure(); encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 40337e007bbf..4a4e878d8af9 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -197,10 +197,14 @@ void Serializer::processExtension() { } void Serializer::processMemoryModel() { + StringAttr memoryModelName = module.getMemoryModelAttrName(); auto mm = static_cast<uint32_t>( - module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue()); + module->getAttrOfType<spirv::MemoryModelAttr>(memoryModelName) + .getValue()); + + StringAttr addressingModelName = module.getAddressingModelAttrName(); auto am = static_cast<uint32_t>( - module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model") + module->getAttrOfType<spirv::AddressingModelAttr>(addressingModelName) .getValue()); encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm}); |