summaryrefslogtreecommitdiffstats
path: root/src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp')
-rw-r--r--src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp991
1 files changed, 596 insertions, 395 deletions
diff --git a/src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp b/src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp
index 30bbbff0f5..94225b81c4 100644
--- a/src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp
+++ b/src/3rdparty/angle/src/compiler/translator/OutputHLSL.cpp
@@ -6,26 +6,29 @@
#include "compiler/translator/OutputHLSL.h"
+#include <algorithm>
+#include <cfloat>
+#include <stdio.h>
+
#include "common/angleutils.h"
#include "common/utilities.h"
-#include "common/blocklayout.h"
-#include "compiler/translator/compilerdebug.h"
-#include "compiler/translator/InfoSink.h"
+#include "compiler/translator/BuiltInFunctionEmulator.h"
+#include "compiler/translator/BuiltInFunctionEmulatorHLSL.h"
#include "compiler/translator/DetectDiscontinuity.h"
-#include "compiler/translator/SearchSymbol.h"
-#include "compiler/translator/UnfoldShortCircuit.h"
#include "compiler/translator/FlagStd140Structs.h"
+#include "compiler/translator/InfoSink.h"
#include "compiler/translator/NodeSearch.h"
+#include "compiler/translator/RemoveSwitchFallThrough.h"
#include "compiler/translator/RewriteElseBlocks.h"
-#include "compiler/translator/UtilsHLSL.h"
-#include "compiler/translator/util.h"
-#include "compiler/translator/UniformHLSL.h"
+#include "compiler/translator/SearchSymbol.h"
#include "compiler/translator/StructureHLSL.h"
#include "compiler/translator/TranslatorHLSL.h"
-
-#include <algorithm>
-#include <cfloat>
-#include <stdio.h>
+#include "compiler/translator/UnfoldShortCircuit.h"
+#include "compiler/translator/UniformHLSL.h"
+#include "compiler/translator/UtilsHLSL.h"
+#include "compiler/translator/blocklayout.h"
+#include "compiler/translator/compilerdebug.h"
+#include "compiler/translator/util.h"
namespace sh
{
@@ -94,12 +97,21 @@ bool OutputHLSL::TextureFunction::operator<(const TextureFunction &rhs) const
return false;
}
-OutputHLSL::OutputHLSL(TParseContext &context, TranslatorHLSL *parentTranslator)
+OutputHLSL::OutputHLSL(sh::GLenum shaderType, int shaderVersion,
+ const TExtensionBehavior &extensionBehavior,
+ const char *sourcePath, ShShaderOutput outputType,
+ int numRenderTargets, const std::vector<Uniform> &uniforms,
+ int compileOptions)
: TIntermTraverser(true, true, true),
- mContext(context),
- mOutputType(parentTranslator->getOutputType())
+ mShaderType(shaderType),
+ mShaderVersion(shaderVersion),
+ mExtensionBehavior(extensionBehavior),
+ mSourcePath(sourcePath),
+ mOutputType(outputType),
+ mNumRenderTargets(numRenderTargets),
+ mCompileOptions(compileOptions)
{
- mUnfoldShortCircuit = new UnfoldShortCircuit(context, this);
+ mUnfoldShortCircuit = new UnfoldShortCircuit(this);
mInsideFunction = false;
mUsesFragColor = false;
@@ -109,28 +121,12 @@ OutputHLSL::OutputHLSL(TParseContext &context, TranslatorHLSL *parentTranslator)
mUsesPointCoord = false;
mUsesFrontFacing = false;
mUsesPointSize = false;
+ mUsesInstanceID = false;
mUsesFragDepth = false;
mUsesXor = false;
- mUsesMod1 = false;
- mUsesMod2v = false;
- mUsesMod2f = false;
- mUsesMod3v = false;
- mUsesMod3f = false;
- mUsesMod4v = false;
- mUsesMod4f = false;
- mUsesFaceforward1 = false;
- mUsesFaceforward2 = false;
- mUsesFaceforward3 = false;
- mUsesFaceforward4 = false;
- mUsesAtan2_1 = false;
- mUsesAtan2_2 = false;
- mUsesAtan2_3 = false;
- mUsesAtan2_4 = false;
mUsesDiscardRewriting = false;
mUsesNestedBreak = false;
-
- const ShBuiltInResources &resources = parentTranslator->getResources();
- mNumRenderTargets = resources.EXT_draw_buffers ? resources.MaxDrawBuffers : 1;
+ mRequiresIEEEStrictCompiling = false;
mUniqueIndex = 0;
@@ -143,20 +139,14 @@ OutputHLSL::OutputHLSL(TParseContext &context, TranslatorHLSL *parentTranslator)
mExcessiveLoopIndex = NULL;
mStructureHLSL = new StructureHLSL;
- mUniformHLSL = new UniformHLSL(mStructureHLSL, parentTranslator);
+ mUniformHLSL = new UniformHLSL(mStructureHLSL, outputType, uniforms);
if (mOutputType == SH_HLSL9_OUTPUT)
{
- if (mContext.shaderType == GL_FRAGMENT_SHADER)
- {
- // Reserve registers for dx_DepthRange, dx_ViewCoords and dx_DepthFront
- mUniformHLSL->reserveUniformRegisters(3);
- }
- else
- {
- // Reserve registers for dx_DepthRange and dx_ViewAdjust
- mUniformHLSL->reserveUniformRegisters(2);
- }
+ // Fragment shaders need dx_DepthRange, dx_ViewCoords and dx_DepthFront.
+ // Vertex shaders need a slightly different set: dx_DepthRange, dx_ViewCoords and dx_ViewAdjust.
+ // In both cases total 3 uniform registers need to be reserved.
+ mUniformHLSL->reserveUniformRegisters(3);
}
// Reserve registers for the default uniform block and driver constants
@@ -168,27 +158,55 @@ OutputHLSL::~OutputHLSL()
SafeDelete(mUnfoldShortCircuit);
SafeDelete(mStructureHLSL);
SafeDelete(mUniformHLSL);
+ for (auto it = mStructEqualityFunctions.begin(); it != mStructEqualityFunctions.end(); ++it)
+ {
+ SafeDelete(*it);
+ }
+ for (auto it = mArrayEqualityFunctions.begin(); it != mArrayEqualityFunctions.end(); ++it)
+ {
+ SafeDelete(*it);
+ }
}
-void OutputHLSL::output()
+void OutputHLSL::output(TIntermNode *treeRoot, TInfoSinkBase &objSink)
{
- mContainsLoopDiscontinuity = mContext.shaderType == GL_FRAGMENT_SHADER && containsLoopDiscontinuity(mContext.treeRoot);
- mContainsAnyLoop = containsAnyLoop(mContext.treeRoot);
- const std::vector<TIntermTyped*> &flaggedStructs = FlagStd140ValueStructs(mContext.treeRoot);
+ mContainsLoopDiscontinuity = mShaderType == GL_FRAGMENT_SHADER && containsLoopDiscontinuity(treeRoot);
+ mContainsAnyLoop = containsAnyLoop(treeRoot);
+ const std::vector<TIntermTyped*> &flaggedStructs = FlagStd140ValueStructs(treeRoot);
makeFlaggedStructMaps(flaggedStructs);
// Work around D3D9 bug that would manifest in vertex shaders with selection blocks which
// use a vertex attribute as a condition, and some related computation in the else block.
- if (mOutputType == SH_HLSL9_OUTPUT && mContext.shaderType == GL_VERTEX_SHADER)
+ if (mOutputType == SH_HLSL9_OUTPUT && mShaderType == GL_VERTEX_SHADER)
+ {
+ RewriteElseBlocks(treeRoot);
+ }
+
+ BuiltInFunctionEmulator builtInFunctionEmulator;
+ InitBuiltInFunctionEmulatorForHLSL(&builtInFunctionEmulator);
+ builtInFunctionEmulator.MarkBuiltInFunctionsForEmulation(treeRoot);
+
+ // Output the body and footer first to determine what has to go in the header
+ mInfoSinkStack.push(&mBody);
+ treeRoot->traverse(this);
+ mInfoSinkStack.pop();
+
+ mInfoSinkStack.push(&mFooter);
+ if (!mDeferredGlobalInitializers.empty())
{
- RewriteElseBlocks(mContext.treeRoot);
+ writeDeferredGlobalInitializers(mFooter);
}
+ mInfoSinkStack.pop();
+
+ mInfoSinkStack.push(&mHeader);
+ header(&builtInFunctionEmulator);
+ mInfoSinkStack.pop();
- mContext.treeRoot->traverse(this); // Output the body first to determine what has to go in the header
- header();
+ objSink << mHeader.c_str();
+ objSink << mBody.c_str();
+ objSink << mFooter.c_str();
- mContext.infoSink().obj << mHeader.c_str();
- mContext.infoSink().obj << mBody.c_str();
+ builtInFunctionEmulator.Cleanup();
}
void OutputHLSL::makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flaggedStructs)
@@ -197,10 +215,14 @@ void OutputHLSL::makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flagge
{
TIntermTyped *flaggedNode = flaggedStructs[structIndex];
+ TInfoSinkBase structInfoSink;
+ mInfoSinkStack.push(&structInfoSink);
+
// This will mark the necessary block elements as referenced
flaggedNode->traverse(this);
- TString structName(mBody.c_str());
- mBody.erase();
+
+ TString structName(structInfoSink.c_str());
+ mInfoSinkStack.pop();
mFlaggedStructOriginalNames[flaggedNode] = structName;
@@ -213,11 +235,6 @@ void OutputHLSL::makeFlaggedStructMaps(const std::vector<TIntermTyped *> &flagge
}
}
-TInfoSinkBase &OutputHLSL::getBodyStream()
-{
- return mBody;
-}
-
const std::map<std::string, unsigned int> &OutputHLSL::getInterfaceBlockRegisterMap() const
{
return mUniformHLSL->getInterfaceBlockRegisterMap();
@@ -277,9 +294,9 @@ TString OutputHLSL::structInitializerString(int indent, const TStructure &struct
return init;
}
-void OutputHLSL::header()
+void OutputHLSL::header(const BuiltInFunctionEmulator *builtInFunctionEmulator)
{
- TInfoSinkBase &out = mHeader;
+ TInfoSinkBase &out = getInfoSink();
TString varyings;
TString attributes;
@@ -320,6 +337,23 @@ void OutputHLSL::header()
out << mUniformHLSL->uniformsHeader(mOutputType, mReferencedUniforms);
out << mUniformHLSL->interfaceBlocksHeader(mReferencedInterfaceBlocks);
+ if (!mEqualityFunctions.empty())
+ {
+ out << "\n// Equality functions\n\n";
+ for (auto it = mEqualityFunctions.cbegin(); it != mEqualityFunctions.cend(); ++it)
+ {
+ out << (*it)->functionDefinition << "\n";
+ }
+ }
+ if (!mArrayAssignmentFunctions.empty())
+ {
+ out << "\n// Assignment functions\n\n";
+ for (auto it = mArrayAssignmentFunctions.cbegin(); it != mArrayAssignmentFunctions.cend(); ++it)
+ {
+ out << it->functionDefinition << "\n";
+ }
+ }
+
if (mUsesDiscardRewriting)
{
out << "#define ANGLE_USES_DISCARD_REWRITING\n";
@@ -330,6 +364,11 @@ void OutputHLSL::header()
out << "#define ANGLE_USES_NESTED_BREAK\n";
}
+ if (mRequiresIEEEStrictCompiling)
+ {
+ out << "#define ANGLE_REQUIRES_IEEE_STRICT_COMPILING\n";
+ }
+
out << "#ifdef ANGLE_ENABLE_LOOP_FLATTEN\n"
"#define LOOP [loop]\n"
"#define FLATTEN [flatten]\n"
@@ -338,16 +377,16 @@ void OutputHLSL::header()
"#define FLATTEN\n"
"#endif\n";
- if (mContext.shaderType == GL_FRAGMENT_SHADER)
+ if (mShaderType == GL_FRAGMENT_SHADER)
{
- TExtensionBehavior::const_iterator iter = mContext.extensionBehavior().find("GL_EXT_draw_buffers");
- const bool usingMRTExtension = (iter != mContext.extensionBehavior().end() && (iter->second == EBhEnable || iter->second == EBhRequire));
+ TExtensionBehavior::const_iterator iter = mExtensionBehavior.find("GL_EXT_draw_buffers");
+ const bool usingMRTExtension = (iter != mExtensionBehavior.end() && (iter->second == EBhEnable || iter->second == EBhRequire));
out << "// Varyings\n";
out << varyings;
out << "\n";
- if (mContext.getShaderVersion() >= 300)
+ if (mShaderVersion >= 300)
{
for (ReferencedSymbols::const_iterator outputVariableIt = mReferencedOutputVariables.begin(); outputVariableIt != mReferencedOutputVariables.end(); outputVariableIt++)
{
@@ -493,6 +532,11 @@ void OutputHLSL::header()
out << "static float gl_PointSize = float(1);\n";
}
+ if (mUsesInstanceID)
+ {
+ out << "static int gl_InstanceID;";
+ }
+
out << "\n"
"// Varyings\n";
out << varyings;
@@ -511,14 +555,22 @@ void OutputHLSL::header()
if (mOutputType == SH_HLSL11_OUTPUT)
{
+ out << "cbuffer DriverConstants : register(b1)\n"
+ "{\n";
+
if (mUsesDepthRange)
{
- out << "cbuffer DriverConstants : register(b1)\n"
- "{\n"
- " float3 dx_DepthRange : packoffset(c0);\n"
- "};\n"
- "\n";
+ out << " float3 dx_DepthRange : packoffset(c0);\n";
}
+
+ // dx_ViewAdjust and dx_ViewCoords will only be used in Feature Level 9 shaders.
+ // However, we declare it for all shaders (including Feature Level 10+).
+ // The bytecode is the same whether we declare it or not, since D3DCompiler removes it if it's unused.
+ out << " float4 dx_ViewAdjust : packoffset(c1);\n";
+ out << " float2 dx_ViewCoords : packoffset(c2);\n";
+
+ out << "};\n"
+ "\n";
}
else
{
@@ -527,7 +579,8 @@ void OutputHLSL::header()
out << "uniform float3 dx_DepthRange : register(c0);\n";
}
- out << "uniform float4 dx_ViewAdjust : register(c1);\n"
+ out << "uniform float4 dx_ViewAdjust : register(c1);\n";
+ out << "uniform float2 dx_ViewCoords : register(c2);\n"
"\n";
}
@@ -980,7 +1033,15 @@ void OutputHLSL::header()
}
else if (IsShadowSampler(textureFunction->sampler))
{
- out << "x.SampleCmp(s, ";
+ switch(textureFunction->method)
+ {
+ case TextureFunction::IMPLICIT: out << "x.SampleCmp(s, "; break;
+ case TextureFunction::BIAS: out << "x.SampleCmp(s, "; break;
+ case TextureFunction::LOD: out << "x.SampleCmp(s, "; break;
+ case TextureFunction::LOD0: out << "x.SampleCmpLevelZero(s, "; break;
+ case TextureFunction::LOD0BIAS: out << "x.SampleCmpLevelZero(s, "; break;
+ default: UNREACHABLE();
+ }
}
else
{
@@ -1111,11 +1172,20 @@ void OutputHLSL::header()
else if (IsShadowSampler(textureFunction->sampler))
{
// Compare value
- switch(textureFunction->coords)
+ if (textureFunction->proj)
{
- case 3: out << "), t.z"; break;
- case 4: out << "), t.w"; break;
- default: UNREACHABLE();
+ // According to ESSL 3.00.4 sec 8.8 p95 on textureProj:
+ // The resulting third component of P' in the shadow forms is used as Dref
+ out << "), t.z" << proj;
+ }
+ else
+ {
+ switch(textureFunction->coords)
+ {
+ case 3: out << "), t.z"; break;
+ case 4: out << "), t.w"; break;
+ default: UNREACHABLE();
+ }
}
}
else
@@ -1131,11 +1201,20 @@ void OutputHLSL::header()
else if (IsShadowSampler(textureFunction->sampler))
{
// Compare value
- switch(textureFunction->coords)
+ if (textureFunction->proj)
{
- case 3: out << "), t.z"; break;
- case 4: out << "), t.w"; break;
- default: UNREACHABLE();
+ // According to ESSL 3.00.4 sec 8.8 p95 on textureProj:
+ // The resulting third component of P' in the shadow forms is used as Dref
+ out << "), t.z" << proj;
+ }
+ else
+ {
+ switch(textureFunction->coords)
+ {
+ case 3: out << "), t.z"; break;
+ case 4: out << "), t.w"; break;
+ default: UNREACHABLE();
+ }
}
}
else
@@ -1205,179 +1284,12 @@ void OutputHLSL::header()
"\n";
}
- if (mUsesMod1)
- {
- out << "float mod(float x, float y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod2v)
- {
- out << "float2 mod(float2 x, float2 y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod2f)
- {
- out << "float2 mod(float2 x, float y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod3v)
- {
- out << "float3 mod(float3 x, float3 y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod3f)
- {
- out << "float3 mod(float3 x, float y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod4v)
- {
- out << "float4 mod(float4 x, float4 y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesMod4f)
- {
- out << "float4 mod(float4 x, float y)\n"
- "{\n"
- " return x - y * floor(x / y);\n"
- "}\n"
- "\n";
- }
-
- if (mUsesFaceforward1)
- {
- out << "float faceforward(float N, float I, float Nref)\n"
- "{\n"
- " if(dot(Nref, I) >= 0)\n"
- " {\n"
- " return -N;\n"
- " }\n"
- " else\n"
- " {\n"
- " return N;\n"
- " }\n"
- "}\n"
- "\n";
- }
-
- if (mUsesFaceforward2)
- {
- out << "float2 faceforward(float2 N, float2 I, float2 Nref)\n"
- "{\n"
- " if(dot(Nref, I) >= 0)\n"
- " {\n"
- " return -N;\n"
- " }\n"
- " else\n"
- " {\n"
- " return N;\n"
- " }\n"
- "}\n"
- "\n";
- }
-
- if (mUsesFaceforward3)
- {
- out << "float3 faceforward(float3 N, float3 I, float3 Nref)\n"
- "{\n"
- " if(dot(Nref, I) >= 0)\n"
- " {\n"
- " return -N;\n"
- " }\n"
- " else\n"
- " {\n"
- " return N;\n"
- " }\n"
- "}\n"
- "\n";
- }
-
- if (mUsesFaceforward4)
- {
- out << "float4 faceforward(float4 N, float4 I, float4 Nref)\n"
- "{\n"
- " if(dot(Nref, I) >= 0)\n"
- " {\n"
- " return -N;\n"
- " }\n"
- " else\n"
- " {\n"
- " return N;\n"
- " }\n"
- "}\n"
- "\n";
- }
-
- if (mUsesAtan2_1)
- {
- out << "float atanyx(float y, float x)\n"
- "{\n"
- " if(x == 0 && y == 0) x = 1;\n" // Avoid producing a NaN
- " return atan2(y, x);\n"
- "}\n";
- }
-
- if (mUsesAtan2_2)
- {
- out << "float2 atanyx(float2 y, float2 x)\n"
- "{\n"
- " if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
- " if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
- " return float2(atan2(y[0], x[0]), atan2(y[1], x[1]));\n"
- "}\n";
- }
-
- if (mUsesAtan2_3)
- {
- out << "float3 atanyx(float3 y, float3 x)\n"
- "{\n"
- " if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
- " if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
- " if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
- " return float3(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]));\n"
- "}\n";
- }
-
- if (mUsesAtan2_4)
- {
- out << "float4 atanyx(float4 y, float4 x)\n"
- "{\n"
- " if(x[0] == 0 && y[0] == 0) x[0] = 1;\n"
- " if(x[1] == 0 && y[1] == 0) x[1] = 1;\n"
- " if(x[2] == 0 && y[2] == 0) x[2] = 1;\n"
- " if(x[3] == 0 && y[3] == 0) x[3] = 1;\n"
- " return float4(atan2(y[0], x[0]), atan2(y[1], x[1]), atan2(y[2], x[2]), atan2(y[3], x[3]));\n"
- "}\n";
- }
+ builtInFunctionEmulator->OutputEmulatedFunctions(out);
}
void OutputHLSL::visitSymbol(TIntermSymbol *node)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
// Handle accessing std140 structs by value
if (mFlaggedStructMappedNames.count(node) > 0)
@@ -1458,6 +1370,11 @@ void OutputHLSL::visitSymbol(TIntermSymbol *node)
mUsesPointSize = true;
out << name;
}
+ else if (qualifier == EvqInstanceID)
+ {
+ mUsesInstanceID = true;
+ out << name;
+ }
else if (name == "gl_FragDepthEXT")
{
mUsesFragDepth = true;
@@ -1476,12 +1393,51 @@ void OutputHLSL::visitSymbol(TIntermSymbol *node)
void OutputHLSL::visitRaw(TIntermRaw *node)
{
- mBody << node->getRawText();
+ getInfoSink() << node->getRawText();
+}
+
+void OutputHLSL::outputEqual(Visit visit, const TType &type, TOperator op, TInfoSinkBase &out)
+{
+ if (type.isScalar() && !type.isArray())
+ {
+ if (op == EOpEqual)
+ {
+ outputTriplet(visit, "(", " == ", ")", out);
+ }
+ else
+ {
+ outputTriplet(visit, "(", " != ", ")", out);
+ }
+ }
+ else
+ {
+ if (visit == PreVisit && op == EOpNotEqual)
+ {
+ out << "!";
+ }
+
+ if (type.isArray())
+ {
+ const TString &functionName = addArrayEqualityFunction(type);
+ outputTriplet(visit, (functionName + "(").c_str(), ", ", ")", out);
+ }
+ else if (type.getBasicType() == EbtStruct)
+ {
+ const TStructure &structure = *type.getStruct();
+ const TString &functionName = addStructEqualityFunction(structure);
+ outputTriplet(visit, (functionName + "(").c_str(), ", ", ")", out);
+ }
+ else
+ {
+ ASSERT(type.isMatrix() || type.isVector());
+ outputTriplet(visit, "all(", " == ", ")", out);
+ }
+ }
}
bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
// Handle accessing std140 structs by value
if (mFlaggedStructMappedNames.count(node) > 0)
@@ -1492,7 +1448,17 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
switch (node->getOp())
{
- case EOpAssign: outputTriplet(visit, "(", " = ", ")"); break;
+ case EOpAssign:
+ if (node->getLeft()->isArray())
+ {
+ const TString &functionName = addArrayAssignmentFunction(node->getType());
+ outputTriplet(visit, (functionName + "(").c_str(), ", ", ")");
+ }
+ else
+ {
+ outputTriplet(visit, "(", " = ", ")");
+ }
+ break;
case EOpInitialize:
if (visit == PreVisit)
{
@@ -1502,22 +1468,21 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
// this to "float t = x, x = t;".
TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
+ ASSERT(symbolNode);
TIntermTyped *expression = node->getRight();
- sh::SearchSymbol searchSymbol(symbolNode->getSymbol());
- expression->traverse(&searchSymbol);
- bool sameSymbol = searchSymbol.foundMatch();
-
- if (sameSymbol)
+ // TODO (jmadill): do a 'deep' scan to know if an expression is statically const
+ if (symbolNode->getQualifier() == EvqGlobal && expression->getQualifier() != EvqConst)
{
- // Type already printed
- out << "t" + str(mUniqueIndex) + " = ";
- expression->traverse(this);
- out << ", ";
- symbolNode->traverse(this);
- out << " = t" + str(mUniqueIndex);
-
- mUniqueIndex++;
+ // For variables which are not constant, defer their real initialization until
+ // after we initialize other globals: uniforms, attributes and varyings.
+ mDeferredGlobalInitializers.push_back(std::make_pair(symbolNode, expression));
+ const TString &initString = initializer(node->getType());
+ node->setRight(new TIntermRaw(node->getType(), initString));
+ }
+ else if (writeSameSymbolInitializer(out, symbolNode, expression))
+ {
+ // Skip initializing the rest of the expression
return false;
}
}
@@ -1554,16 +1519,22 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
}
else if (visit == InVisit)
{
- out << " = mul(";
+ out << " = transpose(mul(transpose(";
node->getLeft()->traverse(this);
- out << ", ";
+ out << "), transpose(";
}
else
{
- out << "))";
+ out << "))))";
}
break;
case EOpDivAssign: outputTriplet(visit, "(", " /= ", ")"); break;
+ case EOpIModAssign: outputTriplet(visit, "(", " %= ", ")"); break;
+ case EOpBitShiftLeftAssign: outputTriplet(visit, "(", " <<= ", ")"); break;
+ case EOpBitShiftRightAssign: outputTriplet(visit, "(", " >>= ", ")"); break;
+ case EOpBitwiseAndAssign: outputTriplet(visit, "(", " &= ", ")"); break;
+ case EOpBitwiseXorAssign: outputTriplet(visit, "(", " ^= ", ")"); break;
+ case EOpBitwiseOrAssign: outputTriplet(visit, "(", " |= ", ")"); break;
case EOpIndexDirect:
{
const TType& leftType = node->getLeft()->getType();
@@ -1651,65 +1622,15 @@ bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
case EOpSub: outputTriplet(visit, "(", " - ", ")"); break;
case EOpMul: outputTriplet(visit, "(", " * ", ")"); break;
case EOpDiv: outputTriplet(visit, "(", " / ", ")"); break;
+ case EOpIMod: outputTriplet(visit, "(", " % ", ")"); break;
+ case EOpBitShiftLeft: outputTriplet(visit, "(", " << ", ")"); break;
+ case EOpBitShiftRight: outputTriplet(visit, "(", " >> ", ")"); break;
+ case EOpBitwiseAnd: outputTriplet(visit, "(", " & ", ")"); break;
+ case EOpBitwiseXor: outputTriplet(visit, "(", " ^ ", ")"); break;
+ case EOpBitwiseOr: outputTriplet(visit, "(", " | ", ")"); break;
case EOpEqual:
case EOpNotEqual:
- if (node->getLeft()->isScalar())
- {
- if (node->getOp() == EOpEqual)
- {
- outputTriplet(visit, "(", " == ", ")");
- }
- else
- {
- outputTriplet(visit, "(", " != ", ")");
- }
- }
- else if (node->getLeft()->getBasicType() == EbtStruct)
- {
- if (node->getOp() == EOpEqual)
- {
- out << "(";
- }
- else
- {
- out << "!(";
- }
-
- const TStructure &structure = *node->getLeft()->getType().getStruct();
- const TFieldList &fields = structure.fields();
-
- for (size_t i = 0; i < fields.size(); i++)
- {
- const TField *field = fields[i];
-
- node->getLeft()->traverse(this);
- out << "." + DecorateField(field->name(), structure) + " == ";
- node->getRight()->traverse(this);
- out << "." + DecorateField(field->name(), structure);
-
- if (i < fields.size() - 1)
- {
- out << " && ";
- }
- }
-
- out << ")";
-
- return false;
- }
- else
- {
- ASSERT(node->getLeft()->isMatrix() || node->getLeft()->isVector());
-
- if (node->getOp() == EOpEqual)
- {
- outputTriplet(visit, "all(", " == ", ")");
- }
- else
- {
- outputTriplet(visit, "!all(", " == ", ")");
- }
- }
+ outputEqual(visit, node->getLeft()->getType(), node->getOp(), out);
break;
case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break;
case EOpGreaterThan: outputTriplet(visit, "(", " > ", ")"); break;
@@ -1760,6 +1681,7 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
case EOpPositive: outputTriplet(visit, "(+", "", ")"); break;
case EOpVectorLogicalNot: outputTriplet(visit, "(!", "", ")"); break;
case EOpLogicalNot: outputTriplet(visit, "(!", "", ")"); break;
+ case EOpBitwiseNot: outputTriplet(visit, "(~", "", ")"); break;
case EOpPostIncrement: outputTriplet(visit, "(", "", "++)"); break;
case EOpPostDecrement: outputTriplet(visit, "(", "", "--)"); break;
case EOpPreIncrement: outputTriplet(visit, "(++", "", ")"); break;
@@ -1772,6 +1694,21 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
case EOpAsin: outputTriplet(visit, "asin(", "", ")"); break;
case EOpAcos: outputTriplet(visit, "acos(", "", ")"); break;
case EOpAtan: outputTriplet(visit, "atan(", "", ")"); break;
+ case EOpSinh: outputTriplet(visit, "sinh(", "", ")"); break;
+ case EOpCosh: outputTriplet(visit, "cosh(", "", ")"); break;
+ case EOpTanh: outputTriplet(visit, "tanh(", "", ")"); break;
+ case EOpAsinh:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "asinh(");
+ break;
+ case EOpAcosh:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "acosh(");
+ break;
+ case EOpAtanh:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "atanh(");
+ break;
case EOpExp: outputTriplet(visit, "exp(", "", ")"); break;
case EOpLog: outputTriplet(visit, "log(", "", ")"); break;
case EOpExp2: outputTriplet(visit, "exp2(", "", ")"); break;
@@ -1781,8 +1718,47 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
case EOpAbs: outputTriplet(visit, "abs(", "", ")"); break;
case EOpSign: outputTriplet(visit, "sign(", "", ")"); break;
case EOpFloor: outputTriplet(visit, "floor(", "", ")"); break;
+ case EOpTrunc: outputTriplet(visit, "trunc(", "", ")"); break;
+ case EOpRound: outputTriplet(visit, "round(", "", ")"); break;
+ case EOpRoundEven:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "roundEven(");
+ break;
case EOpCeil: outputTriplet(visit, "ceil(", "", ")"); break;
case EOpFract: outputTriplet(visit, "frac(", "", ")"); break;
+ case EOpIsNan:
+ outputTriplet(visit, "isnan(", "", ")");
+ mRequiresIEEEStrictCompiling = true;
+ break;
+ case EOpIsInf: outputTriplet(visit, "isinf(", "", ")"); break;
+ case EOpFloatBitsToInt: outputTriplet(visit, "asint(", "", ")"); break;
+ case EOpFloatBitsToUint: outputTriplet(visit, "asuint(", "", ")"); break;
+ case EOpIntBitsToFloat: outputTriplet(visit, "asfloat(", "", ")"); break;
+ case EOpUintBitsToFloat: outputTriplet(visit, "asfloat(", "", ")"); break;
+ case EOpPackSnorm2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "packSnorm2x16(");
+ break;
+ case EOpPackUnorm2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "packUnorm2x16(");
+ break;
+ case EOpPackHalf2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "packHalf2x16(");
+ break;
+ case EOpUnpackSnorm2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "unpackSnorm2x16(");
+ break;
+ case EOpUnpackUnorm2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "unpackUnorm2x16(");
+ break;
+ case EOpUnpackHalf2x16:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "unpackHalf2x16(");
+ break;
case EOpLength: outputTriplet(visit, "length(", "", ")"); break;
case EOpNormalize: outputTriplet(visit, "normalize(", "", ")"); break;
case EOpDFdx:
@@ -1815,6 +1791,13 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
outputTriplet(visit, "fwidth(", "", ")");
}
break;
+ case EOpTranspose: outputTriplet(visit, "transpose(", "", ")"); break;
+ case EOpDeterminant: outputTriplet(visit, "determinant(transpose(", "", "))"); break;
+ case EOpInverse:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "inverse(");
+ break;
+
case EOpAny: outputTriplet(visit, "any(", "", ")"); break;
case EOpAll: outputTriplet(visit, "all(", "", ")"); break;
default: UNREACHABLE();
@@ -1825,7 +1808,7 @@ bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
switch (node->getOp())
{
@@ -1843,7 +1826,11 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
traverseStatements(*sit);
- out << ";\n";
+ // Don't output ; after case labels, they're terminated by :
+ // This is needed especially since outputting a ; after a case statement would turn empty
+ // case statements into non-empty case statements, disallowing fall-through from them.
+ if ((*sit)->getAsCaseNode() == nullptr)
+ out << ";\n";
}
if (mInsideFunction)
@@ -1871,11 +1858,12 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "") // Variable declaration
{
- for (TIntermSequence::iterator sit = sequence->begin(); sit != sequence->end(); sit++)
+ for (auto it = sequence->cbegin(); it != sequence->cend(); ++it)
{
- if (isSingleStatement(*sit))
+ const auto &seqElement = *it;
+ if (isSingleStatement(seqElement))
{
- mUnfoldShortCircuit->traverse(*sit);
+ mUnfoldShortCircuit->traverse(seqElement);
}
if (!mInsideFunction)
@@ -1885,7 +1873,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
out << TypeString(variable->getType()) + " ";
- TIntermSymbol *symbol = (*sit)->getAsSymbolNode();
+ TIntermSymbol *symbol = seqElement->getAsSymbolNode();
if (symbol)
{
@@ -1895,10 +1883,10 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
}
else
{
- (*sit)->traverse(this);
+ seqElement->traverse(this);
}
- if (*sit != sequence->back())
+ if (seqElement != sequence->back())
{
out << ";\n";
}
@@ -2149,7 +2137,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
bool bias = (arguments->size() > mandatoryArgumentCount); // Bias argument is optional
- if (lod0 || mContext.shaderType == GL_VERTEX_SHADER)
+ if (lod0 || mShaderType == GL_VERTEX_SHADER)
{
if (bias)
{
@@ -2217,7 +2205,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
{
const TString &structName = StructNameString(*node->getType().getStruct());
mStructureHLSL->addConstructor(node->getType(), structName, node->getSequence());
- outputTriplet(visit, structName + "_ctor(", ", ", ")");
+ outputTriplet(visit, (structName + "_ctor(").c_str(), ", ", ")");
}
break;
case EOpLessThan: outputTriplet(visit, "(", " < ", ")"); break;
@@ -2227,37 +2215,15 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpVectorEqual: outputTriplet(visit, "(", " == ", ")"); break;
case EOpVectorNotEqual: outputTriplet(visit, "(", " != ", ")"); break;
case EOpMod:
- {
- // We need to look at the number of components in both arguments
- const int modValue = (*node->getSequence())[0]->getAsTyped()->getNominalSize() * 10 +
- (*node->getSequence())[1]->getAsTyped()->getNominalSize();
- switch (modValue)
- {
- case 11: mUsesMod1 = true; break;
- case 22: mUsesMod2v = true; break;
- case 21: mUsesMod2f = true; break;
- case 33: mUsesMod3v = true; break;
- case 31: mUsesMod3f = true; break;
- case 44: mUsesMod4v = true; break;
- case 41: mUsesMod4f = true; break;
- default: UNREACHABLE();
- }
-
- outputTriplet(visit, "mod(", ", ", ")");
- }
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "mod(");
break;
+ case EOpModf: outputTriplet(visit, "modf(", ", ", ")"); break;
case EOpPow: outputTriplet(visit, "pow(", ", ", ")"); break;
case EOpAtan:
ASSERT(node->getSequence()->size() == 2); // atan(x) is a unary operator
- switch ((*node->getSequence())[0]->getAsTyped()->getNominalSize())
- {
- case 1: mUsesAtan2_1 = true; break;
- case 2: mUsesAtan2_2 = true; break;
- case 3: mUsesAtan2_3 = true; break;
- case 4: mUsesAtan2_4 = true; break;
- default: UNREACHABLE();
- }
- outputTriplet(visit, "atanyx(", ", ", ")");
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "atan(");
break;
case EOpMin: outputTriplet(visit, "min(", ", ", ")"); break;
case EOpMax: outputTriplet(visit, "max(", ", ", ")"); break;
@@ -2269,21 +2235,15 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
case EOpDot: outputTriplet(visit, "dot(", ", ", ")"); break;
case EOpCross: outputTriplet(visit, "cross(", ", ", ")"); break;
case EOpFaceForward:
- {
- switch ((*node->getSequence())[0]->getAsTyped()->getNominalSize()) // Number of components in the first argument
- {
- case 1: mUsesFaceforward1 = true; break;
- case 2: mUsesFaceforward2 = true; break;
- case 3: mUsesFaceforward3 = true; break;
- case 4: mUsesFaceforward4 = true; break;
- default: UNREACHABLE();
- }
-
- outputTriplet(visit, "faceforward(", ", ", ")");
- }
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "faceforward(");
break;
case EOpReflect: outputTriplet(visit, "reflect(", ", ", ")"); break;
case EOpRefract: outputTriplet(visit, "refract(", ", ", ")"); break;
+ case EOpOuterProduct:
+ ASSERT(node->getUseEmulatedFunction());
+ writeEmulatedFunctionTriplet(visit, "outerProduct(");
+ break;
case EOpMul: outputTriplet(visit, "(", " * ", ")"); break;
default: UNREACHABLE();
}
@@ -2293,7 +2253,7 @@ bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
if (node->usesTernaryOperator())
{
@@ -2307,7 +2267,7 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
// however flattening all the ifs in branch heavy shaders made D3D error too.
// As a temporary workaround we flatten the ifs only if there is at least a loop
// present somewhere in the shader.
- if (mContext.shaderType == GL_FRAGMENT_SHADER && mContainsAnyLoop)
+ if (mShaderType == GL_FRAGMENT_SHADER && mContainsAnyLoop)
{
out << "FLATTEN ";
}
@@ -2361,6 +2321,37 @@ bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
return false;
}
+bool OutputHLSL::visitSwitch(Visit visit, TIntermSwitch *node)
+{
+ if (node->getStatementList())
+ {
+ node->setStatementList(RemoveSwitchFallThrough::removeFallThrough(node->getStatementList()));
+ outputTriplet(visit, "switch (", ") ", "");
+ // The curly braces get written when visiting the statementList aggregate
+ }
+ else
+ {
+ // No statementList, so it won't output curly braces
+ outputTriplet(visit, "switch (", ") {", "}\n");
+ }
+ return true;
+}
+
+bool OutputHLSL::visitCase(Visit visit, TIntermCase *node)
+{
+ if (node->hasCondition())
+ {
+ outputTriplet(visit, "case (", "", "):\n");
+ return true;
+ }
+ else
+ {
+ TInfoSinkBase &out = getInfoSink();
+ out << "default:\n";
+ return false;
+ }
+}
+
void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
{
writeConstantUnion(node->getType(), node->getUnionArrayPointer());
@@ -2388,7 +2379,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
}
}
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
if (node->getType() == ELoopDoWhile)
{
@@ -2454,7 +2445,7 @@ bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
switch (node->getFlowOp())
{
@@ -2556,7 +2547,7 @@ bool OutputHLSL::isSingleStatement(TIntermNode *node)
bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
{
const int MAX_LOOP_ITERATIONS = 254;
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
// Parse loops of the form:
// for(int index = initial; index [comparator] limit; index += increment)
@@ -2756,10 +2747,8 @@ bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
return false; // Not handled as an excessive loop
}
-void OutputHLSL::outputTriplet(Visit visit, const TString &preString, const TString &inString, const TString &postString)
+void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString, TInfoSinkBase &out)
{
- TInfoSinkBase &out = mBody;
-
if (visit == PreVisit)
{
out << preString;
@@ -2774,19 +2763,26 @@ void OutputHLSL::outputTriplet(Visit visit, const TString &preString, const TStr
}
}
+void OutputHLSL::outputTriplet(Visit visit, const char *preString, const char *inString, const char *postString)
+{
+ outputTriplet(visit, preString, inString, postString, getInfoSink());
+}
+
void OutputHLSL::outputLineDirective(int line)
{
- if ((mContext.compileOptions & SH_LINE_DIRECTIVES) && (line > 0))
+ if ((mCompileOptions & SH_LINE_DIRECTIVES) && (line > 0))
{
- mBody << "\n";
- mBody << "#line " << line;
+ TInfoSinkBase &out = getInfoSink();
+
+ out << "\n";
+ out << "#line " << line;
- if (mContext.sourcePath)
+ if (mSourcePath)
{
- mBody << " \"" << mContext.sourcePath << "\"";
+ out << " \"" << mSourcePath << "\"";
}
- mBody << "\n";
+ out << "\n";
}
}
@@ -2832,15 +2828,15 @@ TString OutputHLSL::initializer(const TType &type)
return "{" + string + "}";
}
-void OutputHLSL::outputConstructor(Visit visit, const TType &type, const TString &name, const TIntermSequence *parameters)
+void OutputHLSL::outputConstructor(Visit visit, const TType &type, const char *name, const TIntermSequence *parameters)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
if (visit == PreVisit)
{
mStructureHLSL->addConstructor(type, name, parameters);
- out << name + "(";
+ out << name << "(";
}
else if (visit == InVisit)
{
@@ -2854,7 +2850,7 @@ void OutputHLSL::outputConstructor(Visit visit, const TType &type, const TString
const ConstantUnion *OutputHLSL::writeConstantUnion(const TType &type, const ConstantUnion *constUnion)
{
- TInfoSinkBase &out = mBody;
+ TInfoSinkBase &out = getInfoSink();
const TStructure* structure = type.getStruct();
if (structure)
@@ -2912,4 +2908,209 @@ const ConstantUnion *OutputHLSL::writeConstantUnion(const TType &type, const Con
return constUnion;
}
+void OutputHLSL::writeEmulatedFunctionTriplet(Visit visit, const char *preStr)
+{
+ TString preString = BuiltInFunctionEmulator::GetEmulatedFunctionName(preStr);
+ outputTriplet(visit, preString.c_str(), ", ", ")");
+}
+
+bool OutputHLSL::writeSameSymbolInitializer(TInfoSinkBase &out, TIntermSymbol *symbolNode, TIntermTyped *expression)
+{
+ sh::SearchSymbol searchSymbol(symbolNode->getSymbol());
+ expression->traverse(&searchSymbol);
+
+ if (searchSymbol.foundMatch())
+ {
+ // Type already printed
+ out << "t" + str(mUniqueIndex) + " = ";
+ expression->traverse(this);
+ out << ", ";
+ symbolNode->traverse(this);
+ out << " = t" + str(mUniqueIndex);
+
+ mUniqueIndex++;
+ return true;
+ }
+
+ return false;
+}
+
+void OutputHLSL::writeDeferredGlobalInitializers(TInfoSinkBase &out)
+{
+ out << "#define ANGLE_USES_DEFERRED_INIT\n"
+ << "\n"
+ << "void initializeDeferredGlobals()\n"
+ << "{\n";
+
+ for (auto it = mDeferredGlobalInitializers.cbegin(); it != mDeferredGlobalInitializers.cend(); ++it)
+ {
+ const auto &deferredGlobal = *it;
+ TIntermSymbol *symbol = deferredGlobal.first;
+ TIntermTyped *expression = deferredGlobal.second;
+ ASSERT(symbol);
+ ASSERT(symbol->getQualifier() == EvqGlobal && expression->getQualifier() != EvqConst);
+
+ out << " " << Decorate(symbol->getSymbol()) << " = ";
+
+ if (!writeSameSymbolInitializer(out, symbol, expression))
+ {
+ ASSERT(mInfoSinkStack.top() == &out);
+ expression->traverse(this);
+ }
+
+ out << ";\n";
+ }
+
+ out << "}\n"
+ << "\n";
+}
+
+TString OutputHLSL::addStructEqualityFunction(const TStructure &structure)
+{
+ const TFieldList &fields = structure.fields();
+
+ for (auto it = mStructEqualityFunctions.cbegin(); it != mStructEqualityFunctions.cend(); ++it)
+ {
+ auto *eqFunction = *it;
+ if (eqFunction->structure == &structure)
+ {
+ return eqFunction->functionName;
+ }
+ }
+
+ const TString &structNameString = StructNameString(structure);
+
+ StructEqualityFunction *function = new StructEqualityFunction();
+ function->structure = &structure;
+ function->functionName = "angle_eq_" + structNameString;
+
+ TInfoSinkBase fnOut;
+
+ fnOut << "bool " << function->functionName << "(" << structNameString << " a, " << structNameString + " b)\n"
+ << "{\n"
+ " return ";
+
+ for (size_t i = 0; i < fields.size(); i++)
+ {
+ const TField *field = fields[i];
+ const TType *fieldType = field->type();
+
+ const TString &fieldNameA = "a." + Decorate(field->name());
+ const TString &fieldNameB = "b." + Decorate(field->name());
+
+ if (i > 0)
+ {
+ fnOut << " && ";
+ }
+
+ fnOut << "(";
+ outputEqual(PreVisit, *fieldType, EOpEqual, fnOut);
+ fnOut << fieldNameA;
+ outputEqual(InVisit, *fieldType, EOpEqual, fnOut);
+ fnOut << fieldNameB;
+ outputEqual(PostVisit, *fieldType, EOpEqual, fnOut);
+ fnOut << ")";
+ }
+
+ fnOut << ";\n" << "}\n";
+
+ function->functionDefinition = fnOut.c_str();
+
+ mStructEqualityFunctions.push_back(function);
+ mEqualityFunctions.push_back(function);
+
+ return function->functionName;
+}
+
+TString OutputHLSL::addArrayEqualityFunction(const TType& type)
+{
+ for (auto it = mArrayEqualityFunctions.cbegin(); it != mArrayEqualityFunctions.cend(); ++it)
+ {
+ const auto &eqFunction = *it;
+ if (eqFunction->type == type)
+ {
+ return eqFunction->functionName;
+ }
+ }
+
+ const TString &typeName = TypeString(type);
+
+ ArrayHelperFunction *function = new ArrayHelperFunction();
+ function->type = type;
+
+ TInfoSinkBase fnNameOut;
+ fnNameOut << "angle_eq_" << type.getArraySize() << "_" << typeName;
+ function->functionName = fnNameOut.c_str();
+
+ TType nonArrayType = type;
+ nonArrayType.clearArrayness();
+
+ TInfoSinkBase fnOut;
+
+ fnOut << "bool " << function->functionName << "("
+ << typeName << " a[" << type.getArraySize() << "], "
+ << typeName << " b[" << type.getArraySize() << "])\n"
+ << "{\n"
+ " for (int i = 0; i < " << type.getArraySize() << "; ++i)\n"
+ " {\n"
+ " if (";
+
+ outputEqual(PreVisit, nonArrayType, EOpNotEqual, fnOut);
+ fnOut << "a[i]";
+ outputEqual(InVisit, nonArrayType, EOpNotEqual, fnOut);
+ fnOut << "b[i]";
+ outputEqual(PostVisit, nonArrayType, EOpNotEqual, fnOut);
+
+ fnOut << ") { return false; }\n"
+ " }\n"
+ " return true;\n"
+ "}\n";
+
+ function->functionDefinition = fnOut.c_str();
+
+ mArrayEqualityFunctions.push_back(function);
+ mEqualityFunctions.push_back(function);
+
+ return function->functionName;
+}
+
+TString OutputHLSL::addArrayAssignmentFunction(const TType& type)
+{
+ for (auto it = mArrayAssignmentFunctions.cbegin(); it != mArrayAssignmentFunctions.cend(); ++it)
+ {
+ const auto &assignFunction = *it;
+ if (assignFunction.type == type)
+ {
+ return assignFunction.functionName;
+ }
+ }
+
+ const TString &typeName = TypeString(type);
+
+ ArrayHelperFunction function;
+ function.type = type;
+
+ TInfoSinkBase fnNameOut;
+ fnNameOut << "angle_assign_" << type.getArraySize() << "_" << typeName;
+ function.functionName = fnNameOut.c_str();
+
+ TInfoSinkBase fnOut;
+
+ fnOut << "void " << function.functionName << "(out "
+ << typeName << " a[" << type.getArraySize() << "], "
+ << typeName << " b[" << type.getArraySize() << "])\n"
+ << "{\n"
+ " for (int i = 0; i < " << type.getArraySize() << "; ++i)\n"
+ " {\n"
+ " a[i] = b[i];\n"
+ " }\n"
+ "}\n";
+
+ function.functionDefinition = fnOut.c_str();
+
+ mArrayAssignmentFunctions.push_back(function);
+
+ return function.functionName;
+}
+
}