summaryrefslogtreecommitdiffstats
path: root/llvm/include/llvm/Analysis/IndirectCallVisitor.h
blob: c8429e52bee966d1d137e23ebaf75b7fdc2c6b6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
//===-- IndirectCallVisitor.h - indirect call visitor ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
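// This file defines a visitor class and helper functions that find all
// indirect call-sites in a function and, optionally, the instructions that
// produce the vtable address values those calls load their targets from.
//
//===----------------------------------------------------------------------===//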

#ifndef LLVM_ANALYSIS_INDIRECTCALLVISITOR_H
#define LLVM_ANALYSIS_INDIRECTCALLVISITOR_H

#include "llvm/ADT/SetVector.h"
#include "llvm/IR/InstVisitor.h"
#include <utility>
#include <vector>

namespace llvm {
// Instruction visitor that records every indirect call it visits and, when
// constructed with InstructionType::kVTableVal, additionally records the
// instructions that are heuristically taken to produce vtable address values.
struct PGOIndirectCallVisitor : public InstVisitor<PGOIndirectCallVisitor> {
  // Selects what the visitor collects in addition to the indirect calls.
  enum class InstructionType {
    kIndirectCall = 0,
    kVTableVal = 1,
  };

  // Every indirect call visited, in visitation order.
  std::vector<CallBase *> IndirectCalls;
  // Candidate vtable-address instructions; only populated in kVTableVal mode.
  std::vector<Instruction *> ProfiledAddresses;

  PGOIndirectCallVisitor(InstructionType Type) : Type(Type) {}

  // Called for each call-like instruction. Direct calls are ignored.
  void visitCallBase(CallBase &Call) {
    if (!Call.isIndirectCall())
      return;

    IndirectCalls.push_back(&Call);

    // Nothing more to collect unless vtable values were requested.
    if (Type != InstructionType::kVTableVal)
      return;

    // The code pattern to look for
    //
    // %vtable = load ptr, ptr %b
    // %vfn = getelementptr inbounds ptr, ptr %vtable, i64 1
    // %2 = load ptr, ptr %vfn
    // %call = tail call i32 %2(ptr %b)
    //
    // %vtable is the vtable address value to profile, and
    // %2 is the indirect call target address to profile.
    auto *TargetLoad = dyn_cast<LoadInst>(Call.getCalledOperand());
    if (TargetLoad == nullptr)
      return;

    // Walk from the loaded slot (%vfn above) back through in-bounds constant
    // offsets to the value it was derived from (%vtable above).
    Value *SlotAddr = TargetLoad->getPointerOperand();
    Value *VTableCandidate = SlotAddr->stripInBoundsConstantOffsets();

    // This is a heuristic to find address feeding instructions.
    // FIXME: Add support in the frontend so LLVM type intrinsics are
    // emitted without LTO. This way, added intrinsics could filter
    // non-vtable instructions and reduce instrumentation overhead.
    // Since a non-vtable profiled address is not within the address
    // range of vtable objects, it's stored as zero in indexed profiles.
    // A pass that looks up a symbol with a zero hash will (almost) always
    // find nullptr and skip the actual transformation (e.g., comparison
    // of symbols). So the performance overhead from a non-vtable profiled
    // address is negligible, if it exists at all. Comparing the loaded
    // address with the symbol address guarantees correctness.
    if (VTableCandidate == nullptr || !isa<Instruction>(VTableCandidate))
      return;

    ProfiledAddresses.push_back(cast<Instruction>(VTableCandidate));
  }

private:
  // Collection mode chosen at construction time.
  InstructionType Type;
};

// Returns all indirect call-sites found in F, in the order the visitor
// encounters them.
inline std::vector<CallBase *> findIndirectCalls(Function &F) {
  PGOIndirectCallVisitor ICV(
      PGOIndirectCallVisitor::InstructionType::kIndirectCall);
  ICV.visit(F);
  // ICV dies at the end of this function. Returning the member by name would
  // copy the vector (implicit move-on-return only applies to a local object
  // itself, not to one of its members), so move it out explicitly.
  return std::move(ICV.IndirectCalls);
}

// Returns the instructions in F that the visitor's heuristic identifies as
// producing vtable address values for indirect calls.
inline std::vector<Instruction *> findVTableAddrs(Function &F) {
  PGOIndirectCallVisitor ICV(
      PGOIndirectCallVisitor::InstructionType::kVTableVal);
  ICV.visit(F);
  // ICV dies at the end of this function. Returning the member by name would
  // copy the vector (implicit move-on-return only applies to a local object
  // itself, not to one of its members), so move it out explicitly.
  return std::move(ICV.ProfiledAddresses);
}

} // namespace llvm

#endif