summaryrefslogtreecommitdiffstats
path: root/src/3rdparty/angle/src/compiler/translator/RunAtTheEndOfShader.cpp
blob: 3c4209c539d0c6d9bbf43b847a29f054637df936 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
//
// Copyright (c) 2017 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RunAtTheEndOfShader.cpp: Add code to be run at the end of the shader. In case main() contains a
// return statement, this is done by replacing the main() function with another function that calls
// the old main, like this:
//
// void main() { body }
// =>
// void main0() { body }
// void main()
// {
//     main0();
//     codeToRun
// }
//
// This way the code will get run even if the return statement inside main is executed.
//

#include "compiler/translator/RunAtTheEndOfShader.h"

#include <locale>
#include <sstream>
#include <string>

#include "compiler/translator/FindMain.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNode_util.h"
#include "compiler/translator/IntermTraverse.h"
#include "compiler/translator/SymbolTable.h"

namespace sh
{

namespace
{

class ContainsReturnTraverser : public TIntermTraverser
{
  public:
    ContainsReturnTraverser() : TIntermTraverser(true, false, false), mContainsReturn(false) {}

    bool visitBranch(Visit visit, TIntermBranch *node) override
    {
        if (node->getFlowOp() == EOpReturn)
        {
            mContainsReturn = true;
        }
        return false;
    }

    bool containsReturn() { return mContainsReturn; }

  private:
    bool mContainsReturn;
};

// Traverses the tree rooted at |node| with a ContainsReturnTraverser and reports whether the
// traverser saw a return statement.
bool ContainsReturn(TIntermNode *node)
{
    ContainsReturnTraverser returnFinder;
    node->traverse(&returnFinder);

    const bool foundReturn = returnFinder.containsReturn();
    return foundReturn;
}

// Turns the existing main() into an internal function and appends a new main() that calls it and
// then runs |codeToRun|:
//
//   void main() { body }
//   =>
//   void main<id>() { body }
//   void main()
//   {
//       main<id>();
//       codeToRun
//   }
//
// root:        block that currently contains |main|; the replacement definition takes |main|'s
//              place and the new main() is appended at the end.
// main:        the current main() definition; its body is reused by the renamed function.
// codeToRun:   node appended to the new main() after the call to the renamed function.
// symbolTable: used to create the unique id of the renamed function.
void WrapMainAndAppend(TIntermBlock *root,
                       TIntermFunctionDefinition *main,
                       TIntermNode *codeToRun,
                       TSymbolTable *symbolTable)
{
    // Replace main() with main<id>() with the same body.
    TSymbolUniqueId oldMainId(symbolTable);

    // A default-constructed stream adopts the global C++ locale, whose numpunct facet may insert
    // digit-group separators when formatting numbers. That would make the generated function name
    // depend on the host application's locale and could yield an invalid identifier, so force the
    // classic "C" locale before formatting the id.
    std::stringstream oldMainNameStream;
    oldMainNameStream.imbue(std::locale::classic());
    oldMainNameStream << "main" << oldMainId.get();

    // Build the name once; it is needed for both the definition and the call below.
    const std::string oldMainName = oldMainNameStream.str();

    TIntermFunctionDefinition *oldMain = CreateInternalFunctionDefinitionNode(
        TType(EbtVoid), oldMainName.c_str(), main->getBody(), oldMainId);

    bool replaced = root->replaceChildNode(main, oldMain);
    ASSERT(replaced);

    // void main()
    // The new prototype reuses the symbol id of the original main() prototype.
    TIntermFunctionPrototype *newMainProto = new TIntermFunctionPrototype(
        TType(EbtVoid), main->getFunctionPrototype()->getFunctionSymbolInfo()->getId());
    newMainProto->getFunctionSymbolInfo()->setName("main");

    // {
    //     main<id>();
    //     codeToRun
    // }
    TIntermBlock *newMainBody     = new TIntermBlock();
    TIntermAggregate *oldMainCall = CreateInternalFunctionCallNode(
        TType(EbtVoid), oldMainName.c_str(), oldMainId, new TIntermSequence());
    newMainBody->appendStatement(oldMainCall);
    newMainBody->appendStatement(codeToRun);

    // Add the new main() to the root node.
    TIntermFunctionDefinition *newMain = new TIntermFunctionDefinition(newMainProto, newMainBody);
    root->appendStatement(newMain);
}

}  // anonymous namespace

// Arranges for |codeToRun| to run at the end of the shader whose tree is rooted at |root|.
//
// If the main() definition found in |root| contains no return statement, |codeToRun| is appended
// directly to its body. Otherwise main() is wrapped via WrapMainAndAppend(), which uses
// |symbolTable| to create the id of the wrapper, so that |codeToRun| still runs after a return
// inside the original body.
void RunAtTheEndOfShader(TIntermBlock *root, TIntermNode *codeToRun, TSymbolTable *symbolTable)
{
    TIntermFunctionDefinition *mainDefinition = FindMain(root);

    if (ContainsReturn(mainDefinition))
    {
        // A return could skip code appended to the body, so call the old main from a new one.
        WrapMainAndAppend(root, mainDefinition, codeToRun, symbolTable);
    }
    else
    {
        mainDefinition->getBody()->appendStatement(codeToRun);
    }
}

}  // namespace sh