diff options
Diffstat (limited to 'src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp')
-rw-r--r-- | src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp | 227 |
1 files changed, 117 insertions, 110 deletions
diff --git a/src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp b/src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp index 510ade84c1..17721fb0dc 100644 --- a/src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp +++ b/src/3rdparty/angle/src/compiler/translator/ArrayReturnValueToOutParameter.cpp @@ -3,17 +3,23 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // -// The ArrayReturnValueToOutParameter function changes return values of an array type to out parameters in -// function definitions, prototypes, and call sites. +// The ArrayReturnValueToOutParameter function changes return values of an array type to out +// parameters in function definitions, prototypes, and call sites. #include "compiler/translator/ArrayReturnValueToOutParameter.h" -#include "compiler/translator/IntermNode.h" +#include <map> + +#include "compiler/translator/IntermTraverse.h" +#include "compiler/translator/SymbolTable.h" + +namespace sh +{ namespace { -void CopyAggregateChildren(TIntermAggregate *from, TIntermAggregate *to) +void CopyAggregateChildren(TIntermAggregateBase *from, TIntermAggregateBase *to) { const TIntermSequence *fromSequence = from->getSequence(); for (size_t ii = 0; ii < fromSequence->size(); ++ii) @@ -22,156 +28,153 @@ void CopyAggregateChildren(TIntermAggregate *from, TIntermAggregate *to) } } -TIntermSymbol *CreateReturnValueSymbol(const TType &type) +TIntermSymbol *CreateReturnValueSymbol(const TSymbolUniqueId &id, const TType &type) { - TIntermSymbol *node = new TIntermSymbol(0, "angle_return", type); + TIntermSymbol *node = new TIntermSymbol(id, "angle_return", type); node->setInternal(true); + node->getTypePointer()->setQualifier(EvqOut); return node; } -TIntermSymbol *CreateReturnValueOutSymbol(const TType &type) +TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall, + TIntermTyped *returnValueTarget) { - TType outType(type); - outType.setQualifier(EvqOut); - return CreateReturnValueSymbol(outType); -} - -TIntermAggregate *CreateReplacementCall(TIntermAggregate *originalCall, TIntermTyped *returnValueTarget) -{ - TIntermAggregate *replacementCall = new TIntermAggregate(EOpFunctionCall); - replacementCall->setType(TType(EbtVoid)); - replacementCall->setUserDefined(); - replacementCall->setNameObj(originalCall->getNameObj()); - replacementCall->setFunctionId(originalCall->getFunctionId()); - replacementCall->setLine(originalCall->getLine()); - TIntermSequence *replacementParameters = replacementCall->getSequence(); - TIntermSequence *originalParameters = originalCall->getSequence(); - for (auto ¶m : *originalParameters) + TIntermSequence *replacementArguments = new TIntermSequence(); + TIntermSequence *originalArguments = originalCall->getSequence(); + for (auto &arg : *originalArguments) { - replacementParameters->push_back(param); + replacementArguments->push_back(arg); } - replacementParameters->push_back(returnValueTarget); + replacementArguments->push_back(returnValueTarget); + TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall( + TType(EbtVoid), originalCall->getFunctionSymbolInfo()->getId(), + originalCall->getFunctionSymbolInfo()->getNameObj(), replacementArguments); + replacementCall->setLine(originalCall->getLine()); return replacementCall; } class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser { public: - static void apply(TIntermNode *root, unsigned int *temporaryIndex); + static void apply(TIntermNode *root, TSymbolTable *symbolTable); + private: - ArrayReturnValueToOutParameterTraverser(); + ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable); + bool visitFunctionPrototype(Visit visit, TIntermFunctionPrototype *node) override; + bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override; bool visitAggregate(Visit visit, TIntermAggregate *node) override; bool visitBranch(Visit visit, TIntermBranch *node) override; bool visitBinary(Visit visit, TIntermBinary *node) override; - bool mInFunctionWithArrayReturnValue; + // Set when traversal is inside a function with array return value. + TIntermFunctionDefinition *mFunctionWithArrayReturnValue; + + // Map from function symbol ids to array return value ids. + std::map<int, TSymbolUniqueId *> mReturnValueIds; }; -void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, unsigned int *temporaryIndex) +void ArrayReturnValueToOutParameterTraverser::apply(TIntermNode *root, TSymbolTable *symbolTable) { - ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam; - arrayReturnValueToOutParam.useTemporaryIndex(temporaryIndex); + ArrayReturnValueToOutParameterTraverser arrayReturnValueToOutParam(symbolTable); root->traverse(&arrayReturnValueToOutParam); arrayReturnValueToOutParam.updateTree(); } -ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser() - : TIntermTraverser(true, false, true), - mInFunctionWithArrayReturnValue(false) +ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser( + TSymbolTable *symbolTable) + : TIntermTraverser(true, false, true, symbolTable), mFunctionWithArrayReturnValue(nullptr) { } -bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node) +bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition( + Visit visit, + TIntermFunctionDefinition *node) +{ + if (node->getFunctionPrototype()->isArray() && visit == PreVisit) + { + // Replacing the function header is done on visitFunctionPrototype(). + mFunctionWithArrayReturnValue = node; + } + if (visit == PostVisit) + { + mFunctionWithArrayReturnValue = nullptr; + } + return true; +} + +bool ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(Visit visit, + TIntermFunctionPrototype *node) { - if (visit == PreVisit) + if (visit == PreVisit && node->isArray()) { - if (node->isArray()) + // Replace the whole prototype node with another node that has the out parameter + // added. Also set the function to return void. + TIntermFunctionPrototype *replacement = + new TIntermFunctionPrototype(TType(EbtVoid), node->getFunctionSymbolInfo()->getId()); + CopyAggregateChildren(node, replacement); + const TSymbolUniqueId &functionId = node->getFunctionSymbolInfo()->getId(); + if (mReturnValueIds.find(functionId.get()) == mReturnValueIds.end()) { - if (node->getOp() == EOpFunction) - { - // Replace the parameters child node of the function definition with another node - // that has the out parameter added. - // Also set the function to return void. - - TIntermAggregate *params = node->getSequence()->front()->getAsAggregate(); - ASSERT(params != nullptr && params->getOp() == EOpParameters); - - TIntermAggregate *replacementParams = new TIntermAggregate; - replacementParams->setOp(EOpParameters); - CopyAggregateChildren(params, replacementParams); - replacementParams->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType())); - replacementParams->setLine(params->getLine()); - - mReplacements.push_back(NodeUpdateEntry(node, params, replacementParams, false)); - - node->setType(TType(EbtVoid)); - - mInFunctionWithArrayReturnValue = true; - } - else if (node->getOp() == EOpPrototype) - { - // Replace the whole prototype node with another node that has the out parameter added. - TIntermAggregate *replacement = new TIntermAggregate; - replacement->setOp(EOpPrototype); - CopyAggregateChildren(node, replacement); - replacement->getSequence()->push_back(CreateReturnValueOutSymbol(node->getType())); - replacement->setUserDefined(); - replacement->setNameObj(node->getNameObj()); - replacement->setFunctionId(node->getFunctionId()); - replacement->setLine(node->getLine()); - replacement->setType(TType(EbtVoid)); - - mReplacements.push_back(NodeUpdateEntry(getParentNode(), node, replacement, false)); - } - else if (node->getOp() == EOpFunctionCall) - { - // Handle call sites where the returned array is not assigned. - // Examples where f() is a function returning an array: - // 1. f(); - // 2. another_array == f(); - // 3. another_function(f()); - // 4. return f(); - // Cases 2 to 4 are already converted to simpler cases by SeparateExpressionsReturningArrays, so we - // only need to worry about the case where a function call returning an array forms an expression by - // itself. - TIntermAggregate *parentAgg = getParentNode()->getAsAggregate(); - if (parentAgg != nullptr && parentAgg->getOp() == EOpSequence) - { - nextTemporaryIndex(); - TIntermSequence replacements; - replacements.push_back(createTempDeclaration(node->getType())); - TIntermSymbol *returnSymbol = createTempSymbol(node->getType()); - replacements.push_back(CreateReplacementCall(node, returnSymbol)); - mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(parentAgg, node, replacements)); - } - return false; - } + mReturnValueIds[functionId.get()] = new TSymbolUniqueId(mSymbolTable); } + replacement->getSequence()->push_back( + CreateReturnValueSymbol(*mReturnValueIds[functionId.get()], node->getType())); + *replacement->getFunctionSymbolInfo() = *node->getFunctionSymbolInfo(); + replacement->setLine(node->getLine()); + + queueReplacement(replacement, OriginalNode::IS_DROPPED); } - else if (visit == PostVisit) + return false; +} + +bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node) +{ + ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction); + if (visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST) { - if (node->getOp() == EOpFunction) + // Handle call sites where the returned array is not assigned. + // Examples where f() is a function returning an array: + // 1. f(); + // 2. another_array == f(); + // 3. another_function(f()); + // 4. return f(); + // Cases 2 to 4 are already converted to simpler cases by + // SeparateExpressionsReturningArrays, so we only need to worry about the case where a + // function call returning an array forms an expression by itself. + TIntermBlock *parentBlock = getParentNode()->getAsBlock(); + if (parentBlock) { - mInFunctionWithArrayReturnValue = false; + nextTemporaryId(); + TIntermSequence replacements; + replacements.push_back(createTempDeclaration(node->getType())); + TIntermSymbol *returnSymbol = createTempSymbol(node->getType()); + replacements.push_back(CreateReplacementCall(node, returnSymbol)); + mMultiReplacements.push_back( + NodeReplaceWithMultipleEntry(parentBlock, node, replacements)); } + return false; } return true; } bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node) { - if (mInFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn) + if (mFunctionWithArrayReturnValue && node->getFlowOp() == EOpReturn) { // Instead of returning a value, assign to the out parameter and then return. TIntermSequence replacements; - TIntermBinary *replacementAssignment = new TIntermBinary(EOpAssign); TIntermTyped *expression = node->getExpression(); ASSERT(expression != nullptr); - replacementAssignment->setLeft(CreateReturnValueSymbol(expression->getType())); - replacementAssignment->setRight(node->getExpression()); - replacementAssignment->setType(expression->getType()); + const TSymbolUniqueId &functionId = + mFunctionWithArrayReturnValue->getFunctionSymbolInfo()->getId(); + ASSERT(mReturnValueIds.find(functionId.get()) != mReturnValueIds.end()); + const TSymbolUniqueId &returnValueId = *mReturnValueIds[functionId.get()]; + TIntermSymbol *returnValueSymbol = + CreateReturnValueSymbol(returnValueId, expression->getType()); + TIntermBinary *replacementAssignment = + new TIntermBinary(EOpAssign, returnValueSymbol, expression); replacementAssignment->setLine(expression->getLine()); replacements.push_back(replacementAssignment); @@ -179,7 +182,8 @@ bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBr replacementBranch->setLine(node->getLine()); replacements.push_back(replacementBranch); - mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(getParentNode()->getAsAggregate(), node, replacements)); + mMultiReplacements.push_back( + NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, replacements)); } return false; } @@ -189,18 +193,21 @@ bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBi if (node->getOp() == EOpAssign && node->getLeft()->isArray()) { TIntermAggregate *rightAgg = node->getRight()->getAsAggregate(); - if (rightAgg != nullptr && rightAgg->getOp() == EOpFunctionCall && rightAgg->isUserDefined()) + ASSERT(rightAgg == nullptr || rightAgg->getOp() != EOpCallInternalRawFunction); + if (rightAgg != nullptr && rightAgg->getOp() == EOpCallFunctionInAST) { TIntermAggregate *replacementCall = CreateReplacementCall(rightAgg, node->getLeft()); - mReplacements.push_back(NodeUpdateEntry(getParentNode(), node, replacementCall, false)); + queueReplacement(replacementCall, OriginalNode::IS_DROPPED); } } return false; } -} // namespace +} // namespace -void ArrayReturnValueToOutParameter(TIntermNode *root, unsigned int *temporaryIndex) +void ArrayReturnValueToOutParameter(TIntermNode *root, TSymbolTable *symbolTable) { - ArrayReturnValueToOutParameterTraverser::apply(root, temporaryIndex); + ArrayReturnValueToOutParameterTraverser::apply(root, symbolTable); } + +} // namespace sh |