
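Lab for the Advanced Compiler Course

I sat through a few lectures on grammars and such and understood none of it. The lab has nothing to do with that material anyway; I'd rather they skipped the weird stuff and taught some LLVM.

The first lab is to build a C interpreter on top of LLVM. Part of the implementation is provided and the rest has to be filled in. There are 25 test programs in total.

Following the README, I wrote a simple Makefile to make building, running, and testing easier.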

makefile

USER_ID=$(shell id -u)
CONTAINER_NAME=lczxxx123/llvm_10_hw:0.2
HOST_CODE_DIR=./ast-interpreter/
HOST_BUILD_DIR=./build
HOST_TESTCASE_DIR=./testcases
CONTAINER_CODE_DIR=/tmp/code
CONTAINER_BUILD_DIR=/tmp/build
CONTAINER_TESTCASE_DIR=/tmp/testcases
LLVM_DIR=/usr/local/llvm10ra
BUILD_OUTPUT=ast-interpreter
BUILD_TYPE=Debug
FLAGS=-DDEBUG

CPP_FILES=$(wildcard $(HOST_CODE_DIR)/*.cpp)
H_FILES=$(wildcard $(HOST_CODE_DIR)/*.h)
ALL_FILES=$(CPP_FILES) $(H_FILES)
TESTCASE_FILES=$(wildcard $(HOST_TESTCASE_DIR)/*.c)

$(HOST_BUILD_DIR)/$(BUILD_OUTPUT): $(ALL_FILES)
	mkdir -p $(HOST_BUILD_DIR)
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_CODE_DIR):$(CONTAINER_CODE_DIR) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		$(CONTAINER_NAME) /bin/bash -c "\
		cmake -DLLVM_DIR=$(LLVM_DIR) \
		-DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \
		-DCMAKE_CXX_FLAGS_DEBUG=$(FLAGS) \
		$(CONTAINER_CODE_DIR) \
		-B $(CONTAINER_BUILD_DIR) && \
		make -C $(CONTAINER_BUILD_DIR)"

.PHONY: all build docker run ast clean test
all: build

docker:
	docker pull $(CONTAINER_NAME)

build: $(HOST_BUILD_DIR)/$(BUILD_OUTPUT)

run: build
	docker run --rm -it \
		--user=$(USER_ID) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		$(CONTAINER_NAME) \
		$(CONTAINER_BUILD_DIR)/$(BUILD_OUTPUT) "`cat $(HOST_TESTCASE_DIR)/test$(test).c`"

ast:
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_TESTCASE_DIR):$(CONTAINER_TESTCASE_DIR) \
		$(CONTAINER_NAME)	\
		$(LLVM_DIR)/bin/clang -fsyntax-only -Xclang -ast-dump $(CONTAINER_TESTCASE_DIR)/test$(test).c

test:
	mkdir -p $(HOST_BUILD_DIR)
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_CODE_DIR):$(CONTAINER_CODE_DIR) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		$(CONTAINER_NAME) /bin/bash -c "\
		cmake -DLLVM_DIR=$(LLVM_DIR) \
		$(CONTAINER_CODE_DIR) \
		-B $(CONTAINER_BUILD_DIR) && \
		make -C $(CONTAINER_BUILD_DIR)"

	for test in $(TESTCASE_FILES) ; do \
		gcc ./builtin.c $$test -o /tmp/a.out ; \
		/tmp/a.out > /tmp/ans ; \
		docker run --rm -it \
		--user=$(USER_ID) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		$(CONTAINER_NAME) \
		$(CONTAINER_BUILD_DIR)/$(BUILD_OUTPUT) "`cat $$test`" > /tmp/out ; \
		if diff /tmp/ans /tmp/out ; then \
			echo -e "\033[32mTest $$test passed\033[0m" ; \
		else \
			echo -e "\033[31mTest $$test failed\033[0m" ; \
		fi ; \
	done

clean:
	rm -rf $(HOST_BUILD_DIR)

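Time was tight, so I didn't dig into the LLVM source. I just guessed from the names and asked GPT, enough to understand how this interpreter is written and what needs to be added.

The main function in ASTInterpreter.cpp looks like this: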

cpp

int main(int argc, char **argv) {
  if (argc > 1) {
    clang::tooling::runToolOnCode(
        std::unique_ptr<clang::FrontendAction>(new InterpreterClassAction),
        argv[1]);
  }
}

It calls Clang's runToolOnCode, passing a FrontendAction instance and the input program. The FrontendAction here is InterpreterClassAction, a subclass:

cpp

class InterpreterClassAction : public ASTFrontendAction {
public:
  virtual std::unique_ptr<clang::ASTConsumer>
  CreateASTConsumer(clang::CompilerInstance &Compiler, llvm::StringRef InFile) {
    return std::unique_ptr<clang::ASTConsumer>(
        new InterpreterConsumer(Compiler.getASTContext()));
  }
};

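This class inherits from ASTFrontendAction and overrides CreateASTConsumer, which instantiates an InterpreterConsumer. Judging by the name, that is the class that operates on the AST. The argument passed to it is the ASTContext, which holds the AST of the code. InterpreterConsumer looks like this: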

cpp

class InterpreterConsumer : public ASTConsumer {
public:
  explicit InterpreterConsumer(const ASTContext &context)
      : mEnv(), mVisitor(context, &mEnv) {}
  virtual ~InterpreterConsumer() {}

  virtual void HandleTranslationUnit(clang::ASTContext &Context) {
    // ...
  }

private:
  Environment mEnv;
  InterpreterVisitor mVisitor;
};

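It inherits from ASTConsumer, and its constructor initializes the members Environment mEnv and InterpreterVisitor mVisitor. Environment is implemented in Environment.h and is the core of the interpreter: it holds the stack environment and the operations performed when a statement is visited (the part we have to implement). InterpreterVisitor walks the AST and handles each statement. It inherits from EvaluatedExprVisitor and is implemented as follows: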

cpp

class InterpreterVisitor : public EvaluatedExprVisitor<InterpreterVisitor> {
public:
  explicit InterpreterVisitor(const ASTContext &context, Environment *env)
      : EvaluatedExprVisitor(context), mEnv(env) {}
  virtual ~InterpreterVisitor() {}

  virtual void VisitBinaryOperator(BinaryOperator *bop) {
    VisitStmt(bop);
    mEnv->binop(bop);
  }
  virtual void VisitDeclRefExpr(DeclRefExpr *expr) {
    VisitStmt(expr);
    mEnv->declref(expr);
  }
  virtual void VisitCastExpr(CastExpr *expr) {
    VisitStmt(expr);
    mEnv->cast(expr);
  }
  virtual void VisitCallExpr(CallExpr *call) {
    VisitStmt(call);
    mEnv->call(call);
  }
  virtual void VisitDeclStmt(DeclStmt *declstmt) {
    mEnv->decl(declstmt);
  }

private:
  Environment *mEnv;
};

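It mostly overrides a handful of Visit functions and delegates the actual work to mEnv. Every one of them except VisitDeclStmt calls VisitStmt first. VisitStmt traverses the children of the current node, since a node can only be processed after its children have been. In the syntax the assignment requires, a DeclStmt has no subtree (as long as the declaration carries no initializer), so the call is skipped there. Adding it does no harm, though.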
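The children-first order can be tried out without Clang. Below is a minimal sketch, assuming a made-up Node type and visit function (neither exists in the lab code): each node first visits its children, which bind their values into a map playing the role of mExprs, and only then computes its own value from those bindings.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical toy AST node: a literal (no children) or '+' / '*' (two children).
struct Node {
  char kind = 'n';  // 'n' = literal, '+' or '*'
  int literal = 0;
  std::vector<std::unique_ptr<Node>> children;
};

std::unique_ptr<Node> lit(int v) {
  auto n = std::make_unique<Node>();
  n->literal = v;
  return n;
}

std::unique_ptr<Node> bin(char op, std::unique_ptr<Node> l,
                          std::unique_ptr<Node> r) {
  auto n = std::make_unique<Node>();
  n->kind = op;
  n->children.push_back(std::move(l));
  n->children.push_back(std::move(r));
  return n;
}

// Plays the role of StackFrame::mExprs: node -> computed value.
using Bindings = std::map<const Node *, int>;

void visit(const Node *node, Bindings &vals) {
  // Step 1: children first (what the VisitStmt call does in the real visitor).
  for (const auto &child : node->children)
    visit(child.get(), vals);
  // Step 2: the children are bound now, so this node can be computed.
  if (node->kind == 'n') {
    vals[node] = node->literal;
  } else {
    int l = vals.at(node->children[0].get());
    int r = vals.at(node->children[1].get());
    vals[node] = node->kind == '+' ? l + r : l * r;
  }
}

// Evaluates (2 + 3) * 4.
int demo() {
  auto root = bin('*', bin('+', lit(2), lit(3)), lit(4));
  Bindings vals;
  visit(root.get(), vals);
  return vals.at(root.get());
}
```

If step 1 and step 2 were swapped, vals.at would throw for the inner '+' node, which is the toy equivalent of the assert in getStmtVal.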

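Back to InterpreterConsumer: its overridden HandleTranslationUnit is the entry point for processing the AST. Clang calls it once the whole translation unit has been parsed, which is normally where a consumer would start turning the AST into target code; here it is repurposed as the entry point of the AST interpreter.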

cpp

  virtual void HandleTranslationUnit(clang::ASTContext &Context) {
    TranslationUnitDecl *decl = Context.getTranslationUnitDecl();
    mEnv.init(decl);

    FunctionDecl *entry = mEnv.getEntry();
    mVisitor.VisitStmt(entry->getBody());
  }

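TranslationUnitDecl is the root node of the AST. mEnv's init function takes this root node to initialize itself (more on that in a moment). The code then looks up the entry function and uses InterpreterVisitor's VisitStmt to traverse the entry function's AST, that is, interpretation starts from the entry function. Environment is defined as follows: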

cpp

class Environment {
  std::vector<StackFrame> mStack;

  FunctionDecl *mFree; /// Declarations of the built-in functions
  FunctionDecl *mMalloc;
  FunctionDecl *mInput;
  FunctionDecl *mOutput;

  FunctionDecl *mEntry;

public:
  /// Get the declarations of the built-in functions
  Environment()
      : mStack(), mFree(NULL), mMalloc(NULL), mInput(NULL), mOutput(NULL),
        mEntry(NULL) {}

  /// Initialize the Environment
  void init(TranslationUnitDecl *unit) {
    for (TranslationUnitDecl::decl_iterator i = unit->decls_begin(),
                                            e = unit->decls_end();
         i != e; ++i) {
      if (FunctionDecl *fdecl = dyn_cast<FunctionDecl>(*i)) {
        if (fdecl->getName().equals("FREE"))
          mFree = fdecl;
        else if (fdecl->getName().equals("MALLOC"))
          mMalloc = fdecl;
        else if (fdecl->getName().equals("GET"))
          mInput = fdecl;
        else if (fdecl->getName().equals("PRINT"))
          mOutput = fdecl;
        else if (fdecl->getName().equals("main"))
          mEntry = fdecl;
      }
    }
    mStack.push_back(StackFrame());
  }

  FunctionDecl *getEntry() { return mEntry; }
  // ...
};

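It defines a "stack", plus FunctionDecl * members identifying the four built-in functions and the entry function. init iterates over the FunctionDecls in the AST, finds the four built-ins and main, and stores their addresses in the corresponding members. Finally it pushes a new frame onto the "stack" for main to use.

This "stack" is a bit different from a stack in the usual sense.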

cpp

class StackFrame {
  /// StackFrame maps Variable Declaration to Value
  /// Which are either integer or addresses (also represented using an Integer
  /// value)
  std::map<Decl *, int> mVars;
  std::map<Stmt *, int> mExprs;
  /// The current stmt
  Stmt *mPC;

public:
  StackFrame() : mVars(), mExprs(), mPC() {}

  void bindDecl(Decl *decl, int val) { mVars[decl] = val; }
  int getDeclVal(Decl *decl) {
    assert(mVars.find(decl) != mVars.end());
    return mVars.find(decl)->second;
  }
  void bindStmt(Stmt *stmt, int val) { mExprs[stmt] = val; }
  int getStmtVal(Stmt *stmt) {
    assert(mExprs.find(stmt) != mExprs.end());
    return mExprs[stmt];
  }
  void setPC(Stmt *stmt) { mPC = stmt; }
  Stmt *getPC() { return mPC; }
};

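It only stores the variables and expression values of one function. A variable is uniquely identified by its VarDecl, and the value of an expression by its Stmt. There is also mPC, the statement the current function has reached. Next, the functions Environment implements to interpret the program: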

cpp

  void binop(BinaryOperator *bop) {
    // ...
  }
  void decl(DeclStmt *declstmt) {
    // ...
  }
  void declref(DeclRefExpr *declref) {
    // ...
  }
  void cast(CastExpr *castexpr) {
    // ...
  }
  /// !TODO Support Function Call
  void call(CallExpr *callexpr) {
    mStack.back().setPC(callexpr);
    int val = 0;
    FunctionDecl *callee = callexpr->getDirectCallee();
    if (callee == mInput) {
      llvm::errs() << "Please Input an Integer Value : ";
      scanf("%d", &val);

      mStack.back().bindStmt(callexpr, val);
    } else if (callee == mOutput) {
      Expr *decl = callexpr->getArg(0);
      val = mStack.back().getStmtVal(decl);
      llvm::errs() << val;
    } else {
      /// You could add your code here for Function call Return
    }
  }

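These are the functions to complete; new ones can also be added to handle more kinds of statements. The one to look at is call, which already implements GET and PRINT: it gets the callee through callexpr->getDirectCallee() and compares it against the FunctionDecl pointers saved earlier.

The assignment does list the full grammar, but I still prefer to work test case by test case, because that's the only way I know what I'm supposed to write qaq.

After adding some logging and running the first test case, an assert fires at a=100: the right-hand expression has no binding. Dumping the AST shows that the right operand is an IntegerLiteral, and there is no visitor handling that yet.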

-BinaryOperator 0x5604311852d0 <line:8:4, col:6> 'int' '='
 |-DeclRefExpr 0x560431185290 <col:4> 'int' lvalue Var 0x560431185210 'a' 'int'
 `-IntegerLiteral 0x5604311852b0 <col:6> 'int' 100

So, following the other Visit functions in InterpreterVisitor, override VisitIntegerLiteral, implement and call mEnv->integer, and bind the expression to its value in the stack frame.

cpp

class InterpreterVisitor : public EvaluatedExprVisitor<InterpreterVisitor> {
  // ...
  virtual void VisitIntegerLiteral(IntegerLiteral *literal) {
    VisitStmt(literal);
    mEnv->integer(literal);
  }
};

class Environment {
  // ...
  // Bind the literal's value to the expression in the current frame
  void integer(IntegerLiteral *literal) {
    mStack.back().setPC(literal);
    mStack.back().bindStmt(literal, literal->getValue().getSExtValue());
  }
};

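This test case involves global variables, function calls, if/else statements, and some binary operations.

Global variables

Global variables just need one more member in Environment (the map mGlobalVars), filled in from init. VarDecl's getInit returns the initializer. Since assignments and initializers in the assignment only involve int, I simply dyn_cast it to IntegerLiteral.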

cpp

  void initGlobal(VarDecl *vardecl) {
    int val = 0;
    if (vardecl->hasInit()) {
      if (IntegerLiteral *literal =
              dyn_cast<IntegerLiteral>(vardecl->getInit()))
        val = literal->getValue().getSExtValue();
    }
    // Register the global even when it has no initializer (it defaults to 0);
    // otherwise later lookups fall through to the current stack frame
    mGlobalVars[vardecl] = val;
#ifdef DEBUG
    llvm::errs() << "Declaring Global Variable: " << vardecl->getName() << " = "
                 << val << "\n";
#endif
  }

Variable lookup needs to change as well:

cpp

  int getDeclVal(Decl *decl) {
    auto it = mGlobalVars.find(decl);
    if (it != mGlobalVars.end())
      return it->second;
    return mStack.back().getDeclVal(decl);
  }

Likewise for bindDecl:

cpp

  void bindDecl(Decl *decl, int val) {
    if (mGlobalVars.find(decl) != mGlobalVars.end())
      mGlobalVars.find(decl)->second = val;
    else
      mStack.back().bindDecl(decl, val);
  }
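The globals-first lookup can be sketched in isolation. This is a minimal sketch under stated assumptions: ToyEnv and its members are made up for illustration, and variables are keyed by name here, whereas the real code keys them by Decl * (so a local that shadows a global gets its own key there, which this toy does not model).

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a stack frame: name -> value.
using Frame = std::map<std::string, int>;

struct ToyEnv {
  Frame globals;
  std::vector<Frame> stack = std::vector<Frame>(1);  // one frame for main

  // Register a global; it defaults to 0 when there is no initializer.
  void declareGlobal(const std::string &name, int val = 0) {
    globals[name] = val;
  }

  int get(const std::string &name) const {
    auto it = globals.find(name);
    if (it != globals.end())
      return it->second;           // known global
    return stack.back().at(name);  // otherwise a local of the current frame
  }

  void bind(const std::string &name, int val) {
    auto it = globals.find(name);
    if (it != globals.end())
      it->second = val;            // assignment to a known global
    else
      stack.back()[name] = val;    // otherwise a local of the current frame
  }
};

int demo() {
  ToyEnv env;
  env.declareGlobal("g");        // int g;
  env.bind("x", 1);              // local x in main
  env.stack.push_back(Frame{});  // enter a callee
  env.bind("g", 42);             // the callee writes the global
  env.stack.pop_back();          // back in main
  return env.get("g") + env.get("x");
}
```

The write made inside the callee is still visible after its frame is popped, because the global never lived in a frame in the first place.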

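Function calls

First we need somewhere to store function bindings: add a member std::map<StringRef, FunctionDecl *> mFunctions to Environment. During initialization, every function declaration other than main and the built-ins goes in there: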

cpp

  void initFunction(FunctionDecl *fdecl) {
#ifdef DEBUG
    llvm::errs() << "Declaring Function: " << fdecl->getName() << "\n";
#endif
    if (fdecl->getName().equals("FREE"))
      mFree = fdecl;
    else if (fdecl->getName().equals("MALLOC"))
      mMalloc = fdecl;
    else if (fdecl->getName().equals("GET"))
      mInput = fdecl;
    else if (fdecl->getName().equals("PRINT"))
      mOutput = fdecl;
    else if (fdecl->getName().equals("main"))
      mEntry = fdecl;
    else
      mFunctions[fdecl->getName()] = fdecl;
  }

init is rewritten a bit too:

cpp

  /// Initialize the Environment
  void init(TranslationUnitDecl *unit) {
#ifdef DEBUG
    llvm::errs() << "Initializing Environment\n";
#endif
    for (TranslationUnitDecl::decl_iterator i = unit->decls_begin(),
                                            e = unit->decls_end();
         i != e; ++i) {
      if (FunctionDecl *fdecl = dyn_cast<FunctionDecl>(*i))
        initFunction(fdecl);
      else if (VarDecl *vardecl = dyn_cast<VarDecl>(*i))
        initGlobal(vardecl);
    }
    mStack.push_back(StackFrame());
  }

Dumping the AST shows what happens when f(b) is called:

CallExpr 0x5614aea6b728 <col:10, col:13> 'int'
|-ImplicitCastExpr 0x5614aea6b710 <col:10> 'int (*)(int)' <FunctionToPointerDecay>
| `-DeclRefExpr 0x5614aea6b6d0 <col:10> 'int (int)' Function 0x5614aea6b300 'f' 'int (int)'
`-ImplicitCastExpr 0x5614aea6b750 <col:12> 'int' <LValueToRValue>
  `-DeclRefExpr 0x5614aea6b6f0 <col:12> 'int' lvalue Var 0x5614aea6b198 'b' 'int'

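The reference to f is implicitly converted to a function pointer, and the argument goes through an lvalue-to-rvalue conversion. ImplicitCast seems to show up in lots of places. The integer case is already written: it takes the value of the referenced sub-expression and binds it to this statement (expression):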

cpp

  void cast(CastExpr *castexpr) {
    mStack.back().setPC(castexpr);
    if (castexpr->getType()->isIntegerType()) {
      Expr *expr = castexpr->getSubExpr();
      int val = mStack.back().getStmtVal(expr);
      mStack.back().bindStmt(castexpr, val);
    }
  }

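The conversion from a function reference to a function pointer isn't handled, though. I didn't notice it at first qaq, and it turned out not to be needed when handling function calls later.

Calling a function means creating a new stack frame, executing the function, and then returning. At first I wondered how to handle the return address, since during traversal there's no way to know the next statement. After thinking for a while I realized it isn't needed at all. For example, when main calls f, what actually happens is that while VisitStmt is walking main's body, it calls VisitStmt on f's body; when that call finishes, we are naturally back in the traversal of main.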
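To make the "no return address" point concrete, here is a minimal sketch with a made-up toy program representation (ToyStmt, ToyProgram, and runFunction are illustrative names, not part of the lab code). A call simply recurses into the callee's body, and the host C++ call stack remembers where to resume in the caller.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy program: a function body is a list of statements, and a
// statement either logs a label or calls another function by name.
struct ToyStmt {
  bool isCall;
  std::string text;  // the label to log, or the callee's name
};
using ToyProgram = std::map<std::string, std::vector<ToyStmt>>;

// Walks one function body. When the recursive call returns, the loop simply
// continues with the caller's next statement.
void runFunction(const ToyProgram &prog, const std::string &name,
                 std::vector<std::string> &log) {
  for (const ToyStmt &s : prog.at(name)) {
    if (s.isCall)
      runFunction(prog, s.text, log);
    else
      log.push_back(s.text);
  }
}

std::vector<std::string> demo() {
  ToyProgram prog;
  prog["main"].push_back({false, "main:before"});
  prog["main"].push_back({true, "f"});
  prog["main"].push_back({false, "main:after"});
  prog["f"].push_back({false, "f:body"});
  std::vector<std::string> log;
  runFunction(prog, "main", log);
  return log;
}
```

The log comes out as main:before, f:body, main:after, with no explicit bookkeeping of where main has to continue.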

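Environment's call can handle the "built-in functions" but can't "call" other functions: it has the FunctionDecl saved but no access to the Visitor. My approach is to have it return a FunctionDecl * and handle that in the Visitor: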

cpp

  FunctionDecl *call(CallExpr *callexpr) {
    mStack.back().setPC(callexpr);
    int val = 0;
    FunctionDecl *callee = callexpr->getDirectCallee();
    if (callee == mInput) {
      llvm::errs() << "Please Input an Integer Value : ";
      scanf("%d", &val);

      mStack.back().bindStmt(callexpr, val);
      return nullptr;
    } else if (callee == mOutput) {
      Expr *decl = callexpr->getArg(0);
      val = mStack.back().getStmtVal(decl);
      llvm::errs() << val;
      return nullptr;
    } else if (callee == mMalloc) {
      return nullptr;
    } else if (callee == mFree) {
      return nullptr;
    } else {
#ifdef DEBUG
      llvm::errs() << "Call Function: " << callee->getName() << "\n";
#endif
      auto it = mFunctions.find(callee->getName());
      assert(it != mFunctions.end());
      return it->second;
    }
  }

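Now VisitCallExpr can get hold of the function definition, including its parameters and body. Before traversing the callee's AST the same way main's is visited, a stack frame is needed to store the arguments and so on. After the function has been traversed, the frame has to be cleaned up and the return value bound to the callexpr (stmt) in the caller's frame. So I wrote pushFrame and popFrame to set up and tear down frames.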

cpp

  virtual void VisitCallExpr(CallExpr *call) {
#ifdef DEBUG
    llvm::errs() << "VisitCallExpr: \n";
    call->dump();
#endif
    VisitStmt(call);
    if (FunctionDecl *fdecl = mEnv->call(call)) {
      mEnv->pushFrame(call, fdecl);
      VisitStmt(fdecl->getBody());
      mEnv->popFrame(call);
    }
  }

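pushFrame creates the new frame and binds the argument values to the callee's parameters. The AST dump shows that the parameter names live in the FunctionDecl, so callee has to be passed in as well.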

cpp

  void pushFrame(CallExpr *callexpr, FunctionDecl *callee) {
    StackFrame frame;
    assert(callee->param_size() == callexpr->getNumArgs());
    for (unsigned i = 0; i < callexpr->getNumArgs(); i++) {
      Expr *arg = callexpr->getArg(i);
      ParmVarDecl *param = callee->getParamDecl(i);
      int val = mStack.back().getStmtVal(arg);
      frame.bindDecl(param, val);
    }
    mStack.push_back(frame);
  }

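Next, returning from a function. This requires overriding VisitReturnStmt (write a return statement and dump it to learn the node name, or look it up in the API source). Then fetch the return value, if there is one. On a real machine the return value is kept in a register; this interpreter has no registers, so my idea is to keep a return-value slot in the stack frame and have the return statement store into it.

One problem remains: how does return "break out" of the traversal of the current function's AST? My approach is to keep another flag, isReturn, in the stack frame and check it on entry to every node's visit; once a return has happened, nothing further is traversed or processed.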

cpp

// Environment
  void ret(ReturnStmt *retstmt) {
    if (Expr *expr = retstmt->getRetValue()) {
      int val = mStack.back().getStmtVal(expr);
      mStack.back().setRetVal(val);
    }
    mStack.back().setReturn();
  }

// InterpreterVisitor 
  virtual void VisitReturnStmt(ReturnStmt *retstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitReturnStmt: \n";
    retstmt->dump();
#endif
    VisitStmt(retstmt);
    mEnv->ret(retstmt);
  }

popFrame reads the return value from the frame, pops the frame, and binds the value to the CallExpr in the caller's frame:

cpp

  void popFrame(CallExpr *callexpr) {
    assert(isReturn());
    int retval = mStack.back().getRetVal();
    mStack.pop_back();
    mStack.back().bindStmt(callexpr, retval);
  }
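The whole call protocol (push a frame, walk the body, let a flag neutralize everything after the return, pop the frame) fits in a small standalone sketch. ToyFrame, ToyStmt, and callFunction are made-up names for illustration; the real code stores the same two pieces of state in StackFrame.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy frame holding only what the return mechanism needs.
struct ToyFrame {
  int retVal = 0;
  bool returned = false;
  int executed = 0;  // how many non-return statements actually ran
};

// Toy statement: either `return value;` or an ordinary statement.
struct ToyStmt {
  bool isReturn;
  int value;
};

void visitStmt(std::vector<ToyFrame> &stack, const ToyStmt &s) {
  ToyFrame &top = stack.back();
  if (top.returned)  // plays the role of `if (mEnv->isReturn()) return;`
    return;
  if (s.isReturn) {
    top.retVal = s.value;  // like setRetVal
    top.returned = true;   // like setReturn
  } else {
    ++top.executed;
  }
}

struct CallResult {
  int value;
  int executed;
};

// Push a frame, walk the body, pop the frame. The walk never breaks out
// early; after a return, every remaining visit is simply a no-op.
CallResult callFunction(std::vector<ToyFrame> &stack,
                        const std::vector<ToyStmt> &body) {
  stack.push_back(ToyFrame{});
  for (const ToyStmt &s : body)
    visitStmt(stack, s);
  ToyFrame done = stack.back();
  stack.pop_back();
  return {done.retVal, done.executed};
}

CallResult demo() {
  std::vector<ToyFrame> stack(1);  // frame for main
  // body: stmt; return 7; stmt; return 9;
  return callFunction(stack, {{false, 0}, {true, 7}, {false, 0}, {true, 9}});
}
```

Only the first statement runs and the first return wins. Because the flag lives in the callee's frame, popping that frame also clears it from the caller's point of view.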

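Binary operations

A small change so that the expression itself carries a value; nothing else worth explaining, it works just like assignment.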

cpp

  void binop(BinaryOperator *bop) {
    Expr *left = bop->getLHS();
    Expr *right = bop->getRHS();

    int val = 0;
    if (bop->isAssignmentOp())
      val = assignOp(left, right);
    else if (bop->isAdditiveOp())
      val = addOp(left, right, bop->getOpcode() == BO_Add ? 1 : -1);
    else if (bop->getOpcode() == BO_Mul)
      val = mulOp(left, right);
    else if (bop->isComparisonOp())
      val = cmpOp(left, right, bop->getOpcode());
    else
      assert(false && "Unsupported Binary Operator");
    mStack.back().bindStmt(bop, val);
  }

  int assignOp(Expr *left, Expr *right) {
    int val = mStack.back().getStmtVal(right);
    mStack.back().bindStmt(left, val);
    if (DeclRefExpr *declexpr = dyn_cast<DeclRefExpr>(left)) {
      Decl *decl = declexpr->getFoundDecl();
      // Go through Environment::bindDecl so that globals are updated too
      bindDecl(decl, val);
    }
    return val;
  }

  int addOp(Expr *left, Expr *right, int sign) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    return lval + sign * rval;
  }

  int mulOp(Expr *left, Expr *right) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    return lval * rval;
  }

  int cmpOp(Expr *left, Expr *right, BinaryOperator::Opcode op) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    switch (op) {
    case BO_LT:
      return lval < rval;
    case BO_GT:
      return lval > rval;
    case BO_LE:
      return lval <= rval;
    case BO_GE:
      return lval >= rval;
    case BO_EQ:
      return lval == rval;
    case BO_NE:
      return lval != rval;
    default:
      assert(false && "Never reach here");
      return 0; // keeps release builds (asserts off) from falling off the end
    }
  }

IfStmt

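Visit the condition first, then decide which branch to enter based on its value. Reading the source, I found there is also a Visit function, which handles a whole subtree including its root (it dispatches on the node itself, whereas VisitStmt only walks the children). That way there's no risk of walking into the branch that isn't taken.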

cpp

// Visitor
  virtual void VisitIfStmt(IfStmt *ifstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitIfStmt: \n";
    ifstmt->dump();
#endif
    Expr *cond = ifstmt->getCond();
    Visit(cond);
    if (mEnv->getStmtVal(cond))
      Visit(ifstmt->getThen());
    else if (Stmt *els = ifstmt->getElse())
      Visit(els);
  }
// Environment
  int getStmtVal(Stmt *stmt) { return mStack.back().getStmtVal(stmt); }

After that several test cases passed in a row, until 04, which has a minus sign, a UnaryOperator:

BinaryOperator 0x5630de6a24b0 'int' lvalue '='
|-DeclRefExpr 0x5630de6a2458 'int' lvalue Var 0x5630de6a2320 'a' 'int'
`-UnaryOperator 0x5630de6a2498 'int' prefix '-'
  `-IntegerLiteral 0x5630de6a2478 'int' 10

This one is easy to write too:

cpp

// InterpreterVisitor
  virtual void VisitUnaryOperator(UnaryOperator *uop) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitUnaryOperator: \n";
    uop->dump();
#endif
    VisitStmt(uop);
    mEnv->unaryop(uop);
  }

// Environment
  void unaryop(UnaryOperator *unaryop) {
    Expr *expr = unaryop->getSubExpr();
    int val = mStack.back().getStmtVal(expr);
    if (unaryop->getOpcode() == UO_Minus)
      val = -val;
    else
      assert(false && "Unsupported Unary Operator");
    mStack.back().bindStmt(unaryop, val);
  }

The while statement. Note that the condition has to be re-visited before every iteration.

cpp

  virtual void VisitWhileStmt(WhileStmt *whilestmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitWhileStmt: \n";
    whilestmt->dump();
#endif
    Expr *cond = whilestmt->getCond();
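    // Stop once the body has executed a return; otherwise the stale cached
    // value of cond would keep the loop spinning
    while (!mEnv->isReturn() && (Visit(cond), mEnv->getStmtVal(cond))) {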
      Visit(whilestmt->getBody());
    }
  }
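Why the condition must be re-visited is easiest to see with the value cache made explicit. In this minimal sketch (ToyEnv and visitExpr are made-up names), expression values are cached in a map just like mExprs, so reading the cache without visiting first would return the value from the previous iteration.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

struct ToyEnv {
  std::map<std::string, int> vars;   // like mVars
  std::map<std::string, int> exprs;  // like mExprs: cached expression values
};

// "Visiting" an expression recomputes it and refreshes the cache.
void visitExpr(ToyEnv &env, const std::string &id,
               const std::function<int(const ToyEnv &)> &compute) {
  env.exprs[id] = compute(env);
}

// Models: i = start; while (i > 0) i = i - 1;  Returns the iteration count.
int runCountdown(int start) {
  ToyEnv env;
  env.vars["i"] = start;
  int iterations = 0;
  auto cond = [](const ToyEnv &e) { return e.vars.at("i") > 0 ? 1 : 0; };
  // Same shape as `while (Visit(cond), mEnv->getStmtVal(cond))`: the comma
  // operator re-visits the condition, then reads the refreshed cached value.
  while (visitExpr(env, "i > 0", cond), env.exprs.at("i > 0")) {
    env.vars["i"] -= 1;  // loop body
    ++iterations;
  }
  return iterations;
}
```

Hoisting the visitExpr call out of the loop would leave the cached 1 in place forever, and the loop would never terminate for any positive start.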

The for statement. Copilot wrote it all for me.

cpp

  virtual void VisitForStmt(ForStmt *forstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitForStmt: \n";
    forstmt->dump();
#endif
    if (Stmt *init = forstmt->getInit())
      Visit(init);
    Expr *cond = forstmt->getCond();
    // Stop once the body has executed a return
    while (!mEnv->isReturn() &&
           (!cond || (Visit(cond), mEnv->getStmtVal(cond)))) {
      Visit(forstmt->getBody());
      if (Stmt *inc = forstmt->getInc())
        Visit(inc);
    }
  }

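Another run of test cases passed without errors, and then came test12: arrays.

Write a simple array first and look at the AST. The declaration looks like this: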

-DeclStmt 0x5560aa2dd378 <line:7:3, col:11>
 `-VarDecl 0x5560aa2dd310 <col:3, col:10> col:7 used a 'int [3]'

Assignment looks like this:

-BinaryOperator 0x5560aa2dd450 <line:8:3, col:10> 'int' '='
 |-ArraySubscriptExpr 0x5560aa2dd410 <col:3, col:6> 'int' lvalue
 | |-ImplicitCastExpr 0x5560aa2dd3f8 <col:3> 'int *' <ArrayToPointerDecay>
 | | `-DeclRefExpr 0x5560aa2dd390 <col:3> 'int [3]' lvalue Var 0x5560aa2dd310 'a' 'int [3]'
 | `-IntegerLiteral 0x5560aa2dd3b0 <col:5> 'int' 0
 `-IntegerLiteral 0x5560aa2dd430 <col:10> 'int' 10

Use looks like this:

CallExpr 0x5560aa2dd550 <line:15:3, col:13> 'void'
  |-ImplicitCastExpr 0x5560aa2dd538 <col:3> 'void (*)(int)' <FunctionToPointerDecay>
  | `-DeclRefExpr 0x5560aa2dd470 <col:3> 'void (int)' Function 0x5560aa2dd108 'PRINT' 'void (int)'
  `-ImplicitCastExpr 0x5560aa2dd578 <col:9, col:12> 'int' <LValueToRValue>
    `-ArraySubscriptExpr 0x5560aa2dd4e8 <col:9, col:12> 'int' lvalue
      |-ImplicitCastExpr 0x5560aa2dd4d0 <col:9> 'int *' <ArrayToPointerDecay>
      | `-DeclRefExpr 0x5560aa2dd490 <col:9> 'int [3]' lvalue Var 0x5560aa2dd310 'a' 'int [3]'
      `-IntegerLiteral 0x5560aa2dd4b0 <col:11> 'int' 0

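As shown, a reference to an array variable is first converted to a pointer, and then an ArraySubscriptExpr picks out the element at a given index. For plain variable assignment, the existing code finds the VarDecl behind the DeclRefExpr and binds a value to it. That doesn't work for arrays: an array holds values at multiple positions, but there is only one VarDecl. And mVars on the stack maps one VarDecl to one int, unlike a normal stack where a variable can occupy more than one slot. So some way to store arrays is needed.

My first idea was to add something like map<VarDecl, std::vector<int>> to the stack frame, but that has an obvious problem: the binding for the ArraySubscriptExpr on the left side of an assignment can't pinpoint one value inside the array. (Actually it can, for example by binding the address directly, but that felt odd to me, as it would be the only place using an address from the interpreter's own runtime.)

In the end I decided to put arrays on the heap! This is of course a heap I implement myself, not the system heap (otherwise it would be essentially the same as the approach above). This heap will come in handy later as well.

A simple heap with its management: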

cpp

/// Heap maps address to a value
class Heap {
  std::vector<int> mHeap;
  std::set<int> mBin;

public:
  Heap() : mHeap(), mBin() {}

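  int malloc(int size) {
    size++; // one extra slot for the chunk header, which stores the size
    for (int addr : mBin) {
      int chunkSize = mHeap[addr];
      if (chunkSize < size)
        continue;
      // Split off the tail unless the chunk fits exactly or with one spare slot
      if (chunkSize > size + 1) {
        int next = addr + size;
        mHeap[next] = chunkSize - size;
        mHeap[addr] = size;
        mBin.insert(next); // the remainder stays available
      }
      mBin.erase(addr); // the chunk is in use again; we return right away,
                        // so the invalidated iterator is never touched
      return toUserAddr(addr);
    }
    // No reusable chunk: grow the heap
    int addr = mHeap.size();
    mHeap.resize(addr + size);
    mHeap[addr] = size;
    return toUserAddr(addr);
  }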

  void free(int addr) {
    addr = toHeapAddr(addr);
    mBin.insert(addr);
    int next = addr + mHeap[addr];
    if (mBin.find(next) != mBin.end()) {
      mHeap[addr] += mHeap[next];
      mBin.erase(next);
    }
  }

  void set(int addr, int val) {
    assert(addr >= 0 && addr < mHeap.size());
    mHeap[addr] = val;
  }

  int get(int addr) {
    assert(addr >= 0 && addr < mHeap.size());
    return mHeap[addr];
  }

#ifdef DEBUG
  void dump() {
    for (int i = 0; i < mHeap.size(); i++)
      llvm::errs() << i << " : " << mHeap[i] << "\n";
  }
#endif

private:
  int toUserAddr(int addr) { return addr + 1; }
  int toHeapAddr(int addr) { return addr - 1; }
};

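Chunk sizes are counted in ints, and the chunk header has only a size field. A single bin holds all freed chunks, with no next field needed: just iterate over the bin. The bin is a set, which conveniently prevents double free (not really). There is also a little bit of coalescing: if the chunk right after the one being freed is also free, the two are merged.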
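A short walk-through of header, bin, and coalescing, as a condensed sketch of the same design with the LLVM dependencies stripped out. ToyHeap is a made-up name, and one assumption is spelled out in the code: a chunk handed out again is removed from the bin.

```cpp
#include <cassert>
#include <set>
#include <vector>

class ToyHeap {
  std::vector<int> mCells;  // headers and payload share one int array
  std::set<int> mBin;       // header addresses of freed chunks

public:
  int malloc(int size) {
    size++;  // one extra cell for the header, which stores the chunk size
    for (int addr : mBin) {
      if (mCells[addr] < size)
        continue;
      if (mCells[addr] > size + 1) {  // big enough to split
        int next = addr + size;
        mCells[next] = mCells[addr] - size;
        mCells[addr] = size;
        mBin.insert(next);
      }
      mBin.erase(addr);  // assumption: a reused chunk leaves the bin
      return addr + 1;   // user address = header address + 1
    }
    int addr = static_cast<int>(mCells.size());
    mCells.resize(addr + size);
    mCells[addr] = size;
    return addr + 1;
  }

  void free(int userAddr) {
    int addr = userAddr - 1;
    mBin.insert(addr);
    int next = addr + mCells[addr];
    if (mBin.count(next)) {  // merge with a freed right-hand neighbour
      mCells[addr] += mCells[next];
      mBin.erase(next);
    }
  }

  void set(int addr, int val) { mCells.at(addr) = val; }
  int get(int addr) const { return mCells.at(addr); }
};

// Allocates two chunks, frees both, and checks that a request too large for
// either chunk alone is served from the merged chunk.
bool reuseAfterMerge() {
  ToyHeap h;
  int a = h.malloc(3);  // header at 0, user cells 1..3
  int b = h.malloc(2);  // header at 4, user cells 5..6
  h.set(a + 2, 10);
  bool stored = h.get(a + 2) == 10;
  h.free(b);
  h.free(a);            // a's right neighbour b is free: merged into 7 cells
  int c = h.malloc(5);  // needs 6 cells including the header
  return stored && a == 1 && b == 5 && c == a;
}
```

Freeing in the other order (a first, then b) would not merge the two, because only the right-hand neighbour is checked at free time.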

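Now, when an array is declared, its size is available, and the "heap address" returned by Heap::malloc is bound to the variable. DeclRefExpr and ImplicitCastExpr are handled the same as for int: just bind the heap address.

Array assignment and array use differ slightly: on use, the parent of the ArraySubscriptExpr is a LValueToRValue cast. Asking GPT about the compiler theory, I learned that an lvalue is a reference to a variable, lvalue-to-rvalue is a dereferencing step, and the rvalue is the value that can actually be used and cannot be modified. (I never studied this, so I'm not sure I understood it right.)

So handling ArraySubscriptExpr means adding the index to the variable's heap address to get the element's heap address, and binding that to the expression. For an assignment, if the left-hand expression is an ArraySubscriptExpr, assign with Heap::set. In every other case, using the value always goes through a LValueToRValue cast, so when handling the cast, if the sub-expression being converted is an ArraySubscriptExpr, bind the value from Heap::get to the cast expression.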

(It sounds simple written down, but fumbling my way through the code took quite a while, and I fixed a few bugs along the way.)

cpp

  void declref(DeclRefExpr *declref) {
    mStack.back().setPC(declref);
    QualType type = declref->getType();
    if (type->isIntegerType() || type->isArrayType()) {
      Decl *decl = declref->getFoundDecl();

      // int val = mStack.back().getDeclVal(decl);
      int val = getDeclVal(decl);
      mStack.back().bindStmt(declref, val);
    } else if (type->isFunctionType()) {
    } else
      assert(false && "Unsupported DeclRefExpr Type");
  }

  void cast(CastExpr *castexpr) {
    mStack.back().setPC(castexpr);
    QualType type = castexpr->getType();
#ifdef DEBUG
    llvm::errs() << "Casting Expression: ";
    castexpr->dump();
    llvm::errs() << " to Type: ";
    type.dump();
#endif
    if (type->isFunctionPointerType()) {
    } else if (type->isIntegerType() || type->isPointerType()) {
      Expr *expr = castexpr->getSubExpr();

      int val = mStack.back().getStmtVal(expr);
      if (ArraySubscriptExpr *arraysub = dyn_cast<ArraySubscriptExpr>(expr))
        val = mHeap.get(val);
      mStack.back().bindStmt(castexpr, val);
    } else
      assert(false && "Unsupported CastExpr Type");
  }

cpp

  virtual void VisitArraySubscriptExpr(ArraySubscriptExpr *subscript) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitArraySubscriptExpr: \n";
    subscript->dump();
#endif
    VisitStmt(subscript);
    mEnv->arraySubscript(subscript);
  }
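The address-versus-value split for arrays can be shown with three tiny functions. This is a minimal sketch with made-up names (ToyMemory, subscript, assign, lvalueToRvalue) standing in for the heap, arraySubscript, the array branch of assignOp, and the LValueToRValue branch of cast.

```cpp
#include <cassert>
#include <vector>

// Hypothetical flat memory standing in for the post's Heap.
struct ToyMemory {
  std::vector<int> cells;
  explicit ToyMemory(int n) : cells(n, 0) {}
  void set(int addr, int val) { cells.at(addr) = val; }
  int get(int addr) const { return cells.at(addr); }
};

// ArraySubscriptExpr: the value bound to `a[i]` is an address, not an element.
int subscript(int base, int index) { return base + index; }

// Assignment with `a[i]` on the left: store through the address.
void assign(ToyMemory &mem, int lhsAddr, int rhsVal) {
  mem.set(lhsAddr, rhsVal);
}

// LValueToRValue cast over `a[i]`: load the element from the address.
int lvalueToRvalue(const ToyMemory &mem, int addr) { return mem.get(addr); }

// Models: int a[3]; a[0] = 10; a[2] = a[0] + 5; return a[2];
int demo() {
  ToyMemory mem(8);
  int a = 4;  // pretend the array was allocated at address 4
  assign(mem, subscript(a, 0), 10);
  int sum = lvalueToRvalue(mem, subscript(a, 0)) + 5;
  assign(mem, subscript(a, 2), sum);
  return lvalueToRvalue(mem, subscript(a, 2));
}
```

The same subscript result is used both ways: as a store target on the left of an assignment, and as a load source only when a cast asks for the value.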

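Several more test cases passed in a row, until test17, where pointers and sizeof appear.

Declaring a pointer works like declaring an integer, but the initial value had better not be 0, since in my implementation heap address 0 is a valid pointer.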

cpp

  void decl(DeclStmt *declstmt) {
    for (DeclStmt::decl_iterator it = declstmt->decl_begin(),
                                 ie = declstmt->decl_end();
         it != ie; ++it) {
      Decl *decl = *it;
      if (VarDecl *vardecl = dyn_cast<VarDecl>(decl)) {
        QualType type = vardecl->getType();

        if (type->isIntegerType() || type->isPointerType()) {
          int val = type->isIntegerType() ? 0 : -1; // init pointer to NULL (-1)
          if (Expr *expr = vardecl->getInit())
            val = mStack.back().getStmtVal(expr);
          mStack.back().bindDecl(vardecl, val);
#ifdef DEBUG
          llvm::errs() << "Declaring Variable: " << vardecl->getName() << " = "
                       << val << "\n";
#endif
        } 
      }
    }
    // ...
  }

declref needs the addition too:

cpp

  void declref(DeclRefExpr *declref) {
    mStack.back().setPC(declref);
    QualType type = declref->getType();
    if (type->isIntegerType() || type->isArrayType() || type->isPointerType()) {
      Decl *decl = declref->getFoundDecl();

      // int val = mStack.back().getDeclVal(decl);
      int val = getDeclVal(decl);
      mStack.back().bindStmt(declref, val);
    } else if (type->isFunctionType()) {
    } else
      assert(false && "Unsupported DeclRefExpr Type");
  }

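Assigning through a pointer and reading through a pointer, that is, the dereference (*) operation used as an lvalue in an assignment and as an rvalue elsewhere, are handled like arrays. In an assignment, if the left child is a dereference, assign with Heap::set. Using the value always involves a LValueToRValue cast, so if the sub-expression is a UnaryOperator with opcode UO_Deref, bind the value from Heap::get to the cast expression.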

cpp

  void unaryop(UnaryOperator *uop) {
    Expr *expr = uop->getSubExpr();
    int val = mStack.back().getStmtVal(expr);
    if (uop->getOpcode() == UO_Minus)
      val = -val;
    else if (uop->getOpcode() == UO_Deref) {
    } else
      assert(false && "Unsupported Unary Operator");
    mStack.back().bindStmt(uop, val);
  }

  int assignOp(Expr *left, Expr *right) {
#ifdef DEBUG
    llvm::errs() << "Assignment Operation\n";
#endif
    int val = mStack.back().getStmtVal(right);
    // ...
    if (UnaryOperator *uop = dyn_cast<UnaryOperator>(left)) {
      int addr = mStack.back().getStmtVal(uop);
      mHeap.set(addr, val);
    } else
      assert(false && "Unsupported Assignment Operation");
    mStack.back().bindStmt(left, val);
    return val;
  }

  void cast(CastExpr *castexpr) {
    mStack.back().setPC(castexpr);
    QualType type = castexpr->getType();
    if (type->isFunctionPointerType()) {
    } else if (type->isIntegerType() || type->isPointerType()) {
      Expr *expr = castexpr->getSubExpr();

      int val = mStack.back().getStmtVal(expr);
      if (ArraySubscriptExpr *arraysub = dyn_cast<ArraySubscriptExpr>(expr))
        val = mHeap.get(val);
      if (UnaryOperator *uop = dyn_cast<UnaryOperator>(expr))
        if (uop->getOpcode() == UO_Deref)
          val = mHeap.get(val);
      mStack.back().bindStmt(castexpr, val);
    } else
      assert(false && "Unsupported CastExpr Type");
  }

The AST node for sizeof is UnaryExprOrTypeTraitExpr:

UnaryExprOrTypeTraitExpr 0x563ea4c44460 <col:21, col:31> 'unsigned long' sizeof 'int'

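Since this interpreter only deals with integers and "pointers", and pointers only ever point into the self-implemented "heap" where addresses are ints, I just bind 1 to it and call it a day ^_^.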

cpp

// InterpreterVisitor
  virtual void VisitUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr *uett) {
#ifdef DEBUG
    llvm::errs() << "VisitUnaryExprOrTypeTraitExpr: \n";
    uett->dump();
#endif
    if (mEnv->isReturn())
      return;
    VisitStmt(uett);
    mEnv->unaryExprOrTypeTrait(uett);
  }

// Environment
  void unaryExprOrTypeTrait(UnaryExprOrTypeTraitExpr *uett) {
    QualType type = uett->getTypeOfArgument();
    if (uett->getKind() == UETT_SizeOf) {
      // HACK: only int or pointer, the pointer is int in this interpreter
      if (type->isIntegerType() || type->isPointerType()) {
        mStack.back().bindStmt(uett, 1);
      }
    }
  }
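Why binding 1 is enough: in a memory made of int-sized cells, one cell is the unit of both allocation and pointer arithmetic. A minimal sketch under that assumption (CellMemory and kSizeofCell are made-up names; the bump allocator only stands in for MALLOC):

```cpp
#include <cassert>
#include <vector>

// Assumption of this sketch: memory is an array of int-sized cells, so the
// "size" of an int or of a pointer is one cell.
constexpr int kSizeofCell = 1;

struct CellMemory {
  std::vector<int> cells;
  // Bump allocator standing in for MALLOC: `units` is what the interpreted
  // program computed from sizeof, e.g. sizeof(int) * 3.
  int alloc(int units) {
    int base = static_cast<int>(cells.size());
    cells.resize(cells.size() + units, 0);
    return base;
  }
};

// Models: int *p = MALLOC(sizeof(int) * 3); *(p + 2) = 7; return *(p + 2);
int demo() {
  CellMemory mem;
  mem.alloc(kSizeofCell * 2);          // some earlier allocation
  int p = mem.alloc(kSizeofCell * 3);  // three cells for three ints
  assert(mem.cells.size() == 5);       // exactly 2 + 3 cells were reserved
  mem.cells.at(p + 2) = 7;             // p + 2 needs no scaling by 4
  return mem.cells.at(p + 2);
}
```

With sizeof(int) reported as 4, the same program would reserve four times as many cells, and p + 2 would still only move two cells, so the two conventions have to agree.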

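With that done, I ran the tests and every output was correct! Being skeptical, I checked each remaining test case: they involve double pointers and the char type. With pointers done properly, double pointers passed as is, and char is also an integer type in Clang, so char needed nothing extra.

The final version of the visitor: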

cpp

//==--- tools/clang-check/ClangInterpreter.cpp - Clang Interpreter tool
//--------------===//
//===----------------------------------------------------------------------===//

#include "clang/AST/ASTConsumer.h"
#include "clang/AST/EvaluatedExprVisitor.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Tooling/Tooling.h"

using namespace clang;

#include "Environment.h"

class InterpreterVisitor : public EvaluatedExprVisitor<InterpreterVisitor> {
public:
  explicit InterpreterVisitor(const ASTContext &context, Environment *env)
      : EvaluatedExprVisitor(context), mEnv(env) {}
  virtual ~InterpreterVisitor() {}

#ifdef DEBUG
  virtual void VisitStmt(Stmt *stmt) {
    Expr *expr = dyn_cast<Expr>(stmt);
    if (!((expr && (isa<DeclRefExpr>(expr) || isa<CastExpr>(expr) ||
                    isa<CallExpr>(expr) || isa<BinaryOperator>(expr) ||
                    isa<UnaryOperator>(expr) || isa<IntegerLiteral>(expr) ||
                    isa<ArraySubscriptExpr>(expr) || isa<ParenExpr>(expr) ||
                    isa<UnaryExprOrTypeTraitExpr>(expr))) ||
          (isa<DeclStmt>(stmt) || isa<CompoundStmt>(stmt) ||
           isa<ReturnStmt>(stmt) || isa<IfStmt>(stmt) ||
           isa<WhileStmt>(stmt)))) {
      // red color
      llvm::errs() << "\033[0;31m" << "Unsupport Stmt: \n";
      stmt->dump();
      llvm::errs() << "\033[0m";
    }
    EvaluatedExprVisitor::VisitStmt(stmt);
  }
#endif

  virtual void VisitBinaryOperator(BinaryOperator *bop) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitBinaryOperator: \n";
    bop->dump();
#endif
    VisitStmt(bop);
    mEnv->binop(bop);
  }

  virtual void VisitDeclRefExpr(DeclRefExpr *expr) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitDeclRefExpr: \n";
    expr->dump();
#endif
    VisitStmt(expr);
    mEnv->declref(expr);
  }

  virtual void VisitCastExpr(CastExpr *expr) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitCastExpr: \n";
    expr->dump();
#endif
    VisitStmt(expr);
    mEnv->cast(expr);
  }

  virtual void VisitCallExpr(CallExpr *call) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitCallExpr: \n";
    call->dump();
#endif
    VisitStmt(call);
    // HACK: ImplicitCastExpr is not used in the interpreter
    // but it just works fine in this case
    if (FunctionDecl *fdecl = mEnv->call(call)) {
      mEnv->pushFrame(call, fdecl);
      VisitStmt(fdecl->getBody());
      mEnv->popFrame(call);
    }
  }

  virtual void VisitDeclStmt(DeclStmt *declstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitDeclStmt: \n";
    declstmt->dump();
#endif
    VisitStmt(declstmt);
    mEnv->decl(declstmt);
  }

  virtual void VisitIntegerLiteral(IntegerLiteral *literal) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitIntegerLiteral: \n";
    literal->dump();
#endif
    VisitStmt(literal);
    mEnv->integer(literal);
  }

  virtual void VisitReturnStmt(ReturnStmt *retstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitReturnStmt: \n";
    retstmt->dump();
#endif
    VisitStmt(retstmt);
    mEnv->ret(retstmt);
  }

  virtual void VisitIfStmt(IfStmt *ifstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitIfStmt: \n";
    ifstmt->dump();
#endif
    Expr *cond = ifstmt->getCond();
    Visit(cond);
    if (mEnv->getStmtVal(cond))
      Visit(ifstmt->getThen());
    else if (Stmt *els = ifstmt->getElse())
      Visit(els);
  }

  virtual void VisitUnaryOperator(UnaryOperator *uop) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitUnaryOperator: \n";
    uop->dump();
#endif
    VisitStmt(uop);
    mEnv->unaryop(uop);
  }

  virtual void VisitWhileStmt(WhileStmt *whilestmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitWhileStmt: \n";
    whilestmt->dump();
#endif
    Expr *cond = whilestmt->getCond();
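    // Stop once a return statement has executed; otherwise the condition
    // keeps its stale (true) value and the loop never exits.
    while (!mEnv->isReturn() && (Visit(cond), mEnv->getStmtVal(cond))) {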
      Visit(whilestmt->getBody());
    }
  }

  virtual void VisitForStmt(ForStmt *forstmt) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitForStmt: \n";
    forstmt->dump();
#endif
    if (Stmt *init = forstmt->getInit())
      Visit(init);
    Expr *cond = forstmt->getCond();
    // Same guard as in VisitWhileStmt: leave the loop after a return.
    while (!mEnv->isReturn() &&
           (!cond || (Visit(cond), mEnv->getStmtVal(cond)))) {
      Visit(forstmt->getBody());
      if (Stmt *inc = forstmt->getInc())
        Visit(inc);
    }
  }

  virtual void VisitArraySubscriptExpr(ArraySubscriptExpr *subscript) {
    if (mEnv->isReturn())
      return;
#ifdef DEBUG
    llvm::errs() << "VisitArraySubscriptExpr: \n";
    subscript->dump();
#endif
    VisitStmt(subscript);
    mEnv->arraySubscript(subscript);
  }

  virtual void VisitParenExpr(ParenExpr *paren) {
#ifdef DEBUG
    llvm::errs() << "VisitParenExpr: \n";
    paren->dump();
#endif
    if (mEnv->isReturn())
      return;
    VisitStmt(paren);
    mEnv->paren(paren);
  }

  virtual void VisitUnaryExprOrTypeTraitExpr(UnaryExprOrTypeTraitExpr *uett) {
#ifdef DEBUG
    llvm::errs() << "VisitUnaryExprOrTypeTraitExpr: \n";
    uett->dump();
#endif
    if (mEnv->isReturn())
      return;
    VisitStmt(uett);
    mEnv->unaryExprOrTypeTrait(uett);
  }

private:
  Environment *mEnv;
};

class InterpreterConsumer : public ASTConsumer {
public:
  explicit InterpreterConsumer(const ASTContext &context)
      : mEnv(), mVisitor(context, &mEnv) {}
  virtual ~InterpreterConsumer() {}

  virtual void HandleTranslationUnit(clang::ASTContext &Context) {
#ifdef DEBUG
    llvm::errs() << "HandleTranslationUnit\n";
#endif
    TranslationUnitDecl *decl = Context.getTranslationUnitDecl();
    mEnv.init(decl);

    FunctionDecl *entry = mEnv.getEntry();
    mVisitor.VisitStmt(entry->getBody());
  }

private:
  Environment mEnv;
  InterpreterVisitor mVisitor;
};

class InterpreterClassAction : public ASTFrontendAction {
public:
  virtual std::unique_ptr<clang::ASTConsumer>
  CreateASTConsumer(clang::CompilerInstance &Compiler, llvm::StringRef InFile) {
    return std::unique_ptr<clang::ASTConsumer>(
        new InterpreterConsumer(Compiler.getASTContext()));
  }
};

int main(int argc, char **argv) {
  if (argc > 1) {
#ifdef DEBUG
    llvm::errs() << "Running interpreter on \n==========\n"
                 << argv[1] << "\n==========\n";
#endif
    clang::tooling::runToolOnCode(
        std::unique_ptr<clang::FrontendAction>(new InterpreterClassAction),
        argv[1]);
  }
}

cpp

//==--- tools/clang-check/ClangInterpreter.cpp - Clang Interpreter tool ---===//
//===----------------------------------------------------------------------===//
#include <stdio.h>

#include <cassert>
#include <map>
#include <set>
#include <vector>

#include "clang/AST/ASTConsumer.h"
#include "clang/AST/Decl.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Frontend/CompilerInstance.h"
#include "clang/Frontend/FrontendAction.h"
#include "clang/Tooling/Tooling.h"

using namespace clang;

class StackFrame {
  /// StackFrame maps Variable Declaration to Value
  /// Which are either integer or addresses (also represented using an Integer
  /// value)
  std::map<Decl *, int> mVars;
  std::map<Stmt *, int> mExprs;
  /// The current stmt
  Stmt *mPC;
  /// The return value of the function
  int mRetVal;
  /// Whether the function has returned
  bool mIsReturn;

public:
  StackFrame() : mVars(), mExprs(), mPC(), mRetVal(), mIsReturn(false) {}

  void bindDecl(Decl *decl, int val) { mVars[decl] = val; }
  int getDeclVal(Decl *decl) {
#ifdef DEBUG
    if (mVars.find(decl) == mVars.end()) {
      // red color
      llvm::errs() << "\033[0;31m";
      llvm::errs() << "Variable not found: ";
      decl->dump();
      llvm::errs() << "\033[0m";
      assert(0);
    }
#else
    assert(mVars.find(decl) != mVars.end());
#endif
    return mVars.find(decl)->second;
  }
  void bindStmt(Stmt *stmt, int val) { mExprs[stmt] = val; }
  int getStmtVal(Stmt *stmt) {
#ifdef DEBUG
    if (mExprs.find(stmt) == mExprs.end()) {
      // red color
      llvm::errs() << "\033[0;31m";
      llvm::errs() << "Statement not found: ";
      stmt->dump();
      llvm::errs() << "\033[0m";
      assert(0);
    }
#else
    assert(mExprs.find(stmt) != mExprs.end());
#endif
    return mExprs[stmt];
  }
  void setPC(Stmt *stmt) { mPC = stmt; }
  Stmt *getPC() { return mPC; }
  void setRetVal(int val) { mRetVal = val; }
  int getRetVal() { return mRetVal; }
  void setReturn() { mIsReturn = true; }
  bool isReturn() { return mIsReturn; }
};

/// Heap maps address to a value
class Heap {
  std::vector<int> mHeap;
  std::set<int> mBin;

public:
  Heap() : mHeap(), mBin() {}

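  int malloc(int size) {
    size++; // one extra cell in front of the payload stores the chunk size
    for (int addr : mBin) {
      int chunkSize = mHeap[addr];
      if (chunkSize < size)
        continue;
      // Split only if the remainder can hold a header plus one payload cell.
      if (chunkSize > size + 1) {
        int next = addr + size;
        mHeap[next] = chunkSize - size;
        mHeap[addr] = size;
        mBin.insert(next); // the remainder stays available
      }
      mBin.erase(addr); // the chunk is in use again; return right away
      return toUserAddr(addr);
    }
    // No free chunk fits: grow the heap.
    int addr = mHeap.size();
    mHeap.resize(addr + size);
    mHeap[addr] = size;
    return toUserAddr(addr);
  }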

  void free(int addr) {
    addr = toHeapAddr(addr);
    mBin.insert(addr);
    int next = addr + mHeap[addr];
    if (mBin.find(next) != mBin.end()) {
      mHeap[addr] += mHeap[next];
      mBin.erase(next);
    }
  }

  void set(int addr, int val) {
    assert(addr >= 0 && addr < mHeap.size());
    mHeap[addr] = val;
  }

  int get(int addr) {
    assert(addr >= 0 && addr < mHeap.size());
    return mHeap[addr];
  }

#ifdef DEBUG
  void dump() {
    for (int i = 0; i < mHeap.size(); i++)
      llvm::errs() << i << " : " << mHeap[i] << "\n";
  }
#endif

private:
  int toUserAddr(int addr) { return addr + 1; }
  int toHeapAddr(int addr) { return addr - 1; }
};

class Environment {
  std::vector<StackFrame> mStack;

  FunctionDecl *mFree; /// Declarations of the built-in functions
  FunctionDecl *mMalloc;
  FunctionDecl *mInput;
  FunctionDecl *mOutput;

  FunctionDecl *mEntry;

  std::map<Decl *, int> mGlobalVars;
  std::map<StringRef, FunctionDecl *> mFunctions;
  Heap mHeap;

public:
  /// Get the declarations of the built-in functions
  Environment()
      : mStack(), mFree(NULL), mMalloc(NULL), mInput(NULL), mOutput(NULL),
        mEntry(NULL), mHeap() {}

  /// Initialize the Environment
  void init(TranslationUnitDecl *unit) {
#ifdef DEBUG
    llvm::errs() << "Initializing Environment\n";
#endif
    for (TranslationUnitDecl::decl_iterator i = unit->decls_begin(),
                                            e = unit->decls_end();
         i != e; ++i) {
      if (FunctionDecl *fdecl = dyn_cast<FunctionDecl>(*i))
        initFunction(fdecl);
      else if (VarDecl *vardecl = dyn_cast<VarDecl>(*i))
        initGlobal(vardecl);
    }
    mStack.push_back(StackFrame());
  }

  FunctionDecl *getEntry() { return mEntry; }

  /// Evaluate a binary operator (assignment, +, -, *, comparisons)
  void binop(BinaryOperator *bop) {
    Expr *left = bop->getLHS();
    Expr *right = bop->getRHS();

    int val = 0;
    if (bop->isAssignmentOp())
      val = assignOp(left, right);
    else if (bop->isAdditiveOp())
      val = addOp(left, right, bop->getOpcode() == BO_Add ? 1 : -1);
    else if (bop->getOpcode() == BO_Mul)
      val = mulOp(left, right);
    else if (bop->isComparisonOp())
      val = cmpOp(left, right, bop->getOpcode());
    else
      assert(false && "Unsupported Binary Operator");
    mStack.back().bindStmt(bop, val);
  }

  void decl(DeclStmt *declstmt) {
    for (DeclStmt::decl_iterator it = declstmt->decl_begin(),
                                 ie = declstmt->decl_end();
         it != ie; ++it) {
      Decl *decl = *it;
      if (VarDecl *vardecl = dyn_cast<VarDecl>(decl)) {
        QualType type = vardecl->getType();

        if (type->isIntegerType() || type->isPointerType()) {
          int val = type->isIntegerType() ? 0 : -1; // init pointer to NULL (-1)
          if (Expr *expr = vardecl->getInit())
            val = mStack.back().getStmtVal(expr);
          mStack.back().bindDecl(vardecl, val);
#ifdef DEBUG
          llvm::errs() << "Declaring Variable: " << vardecl->getName() << " = "
                       << val << "\n";
#endif
        } else if (type->isArrayType()) {
          int size = 0;
          if (const ConstantArrayType *arraytype =
                  dyn_cast<ConstantArrayType>(type.getTypePtr()))
            size = arraytype->getSize().getSExtValue();
#ifdef DEBUG
          llvm::errs() << "Declaring Array: " << vardecl->getName()
                       << " Size: " << size << "\n";
#endif
          mStack.back().bindDecl(vardecl, mHeap.malloc(size));
        } else
          assert(false && "Unsupported Decl Type");
      }
    }
  }

  void declref(DeclRefExpr *declref) {
    mStack.back().setPC(declref);
    QualType type = declref->getType();
    if (type->isIntegerType() || type->isArrayType() || type->isPointerType()) {
      Decl *decl = declref->getFoundDecl();

      // int val = mStack.back().getDeclVal(decl);
      int val = getDeclVal(decl);
      mStack.back().bindStmt(declref, val);
    } else if (type->isFunctionType()) {
    } else
      assert(false && "Unsupported DeclRefExpr Type");
  }

  void cast(CastExpr *castexpr) {
    mStack.back().setPC(castexpr);
    QualType type = castexpr->getType();
    if (type->isFunctionPointerType()) {
    } else if (type->isIntegerType() || type->isPointerType()) {
      Expr *expr = castexpr->getSubExpr();

      int val = mStack.back().getStmtVal(expr);
      if (ArraySubscriptExpr *arraysub = dyn_cast<ArraySubscriptExpr>(expr))
        val = mHeap.get(val);
      if (UnaryOperator *uop = dyn_cast<UnaryOperator>(expr))
        if (uop->getOpcode() == UO_Deref)
          val = mHeap.get(val);
      mStack.back().bindStmt(castexpr, val);
    } else
      assert(false && "Unsupported CastExpr Type");
  }

  FunctionDecl *call(CallExpr *callexpr) {
    mStack.back().setPC(callexpr);
    int val = 0;
    FunctionDecl *callee = callexpr->getDirectCallee();
    if (callee == mInput) {
      llvm::errs() << "Please Input an Integer Value : ";
      scanf("%d", &val);

      mStack.back().bindStmt(callexpr, val);
      return nullptr;
    } else if (callee == mOutput) {
      Expr *decl = callexpr->getArg(0);
      val = mStack.back().getStmtVal(decl);
      llvm::errs() << val;
      return nullptr;
    } else if (callee == mMalloc) {
      Expr *decl = callexpr->getArg(0);
      int size = mStack.back().getStmtVal(decl);
      int addr = mHeap.malloc(size);
      mStack.back().bindStmt(callexpr, addr);
      return nullptr;
    } else if (callee == mFree) {
      Expr *decl = callexpr->getArg(0);
      int addr = mStack.back().getStmtVal(decl);
      mHeap.free(addr);
      return nullptr;
    } else {
#ifdef DEBUG
      llvm::errs() << "Call Function: " << callee->getName() << "\n";
#endif
      auto it = mFunctions.find(callee->getName());
      assert(it != mFunctions.end());
      return it->second;
    }
  }

  void integer(IntegerLiteral *literal) {
    mStack.back().setPC(literal);
    mStack.back().bindStmt(literal, literal->getValue().getSExtValue());
  }

  void pushFrame(CallExpr *callexpr, FunctionDecl *callee) {
    StackFrame frame;
    assert(callee->param_size() == callexpr->getNumArgs());
    for (unsigned i = 0; i < callexpr->getNumArgs(); i++) {
      Expr *arg = callexpr->getArg(i);
      ParmVarDecl *param = callee->getParamDecl(i);
      int val = mStack.back().getStmtVal(arg);
      frame.bindDecl(param, val);
    }
    mStack.push_back(frame);
  }

  void popFrame(CallExpr *callexpr) {
    assert(isReturn());
    int retval = mStack.back().getRetVal();
    mStack.pop_back();
    mStack.back().bindStmt(callexpr, retval);
  }

  void ret(ReturnStmt *retstmt) {
    if (Expr *expr = retstmt->getRetValue()) {
      int val = mStack.back().getStmtVal(expr);
      mStack.back().setRetVal(val);
    }
    mStack.back().setReturn();
  }

  bool isReturn() { return mStack.back().isReturn(); }

  int getStmtVal(Stmt *stmt) { return mStack.back().getStmtVal(stmt); }

  void unaryop(UnaryOperator *uop) {
    Expr *expr = uop->getSubExpr();
    int val = mStack.back().getStmtVal(expr);
    if (uop->getOpcode() == UO_Minus)
      val = -val;
    else if (uop->getOpcode() == UO_Deref) {
    } else
      assert(false && "Unsupported Unary Operator");
    mStack.back().bindStmt(uop, val);
  }

  // HACK: only support array variable as the base
  void arraySubscript(ArraySubscriptExpr *arraysub) {
    Expr *base = arraysub->getBase();
    Expr *idx = arraysub->getIdx();
    int baseAddr = mStack.back().getStmtVal(base);
    int idxVal = mStack.back().getStmtVal(idx);
    if (ImplicitCastExpr *cast = dyn_cast<ImplicitCastExpr>(base))
      if (cast->getCastKind() == CK_ArrayToPointerDecay) {
        if (DeclRefExpr *declexpr = dyn_cast<DeclRefExpr>(cast->getSubExpr())) {
          if (VarDecl *vardecl = dyn_cast<VarDecl>(declexpr->getDecl())) {
            if (const ConstantArrayType *arraytype =
                    dyn_cast<ConstantArrayType>(
                        vardecl->getType().getTypePtr())) {
              int size = arraytype->getSize().getSExtValue();
#ifdef DEBUG
              llvm::errs() << "size = " << size << "\n";
              llvm::errs() << "idxVal = " << idxVal << "\n";
#endif
              assert(idxVal >= 0 && idxVal < size);
            }
          }
        }
      }
#ifdef DEBUG
    llvm::errs() << "Array Subscript: " << baseAddr << "[" << idxVal << "]\n";
    llvm::errs() << "Array Subscript: " << arraysub << " " << baseAddr + idxVal
                 << "\n";
#endif
    mStack.back().bindStmt(arraysub, baseAddr + idxVal);
  }

  void paren(ParenExpr *paren) {
    Expr *expr = paren->getSubExpr();
    int val = mStack.back().getStmtVal(expr);
    mStack.back().bindStmt(paren, val);
  }

  void unaryExprOrTypeTrait(UnaryExprOrTypeTraitExpr *uett) {
    QualType type = uett->getTypeOfArgument();
    if (uett->getKind() == UETT_SizeOf) {
      // HACK: only int or pointer, the pointer is int in this interpreter
      if (type->isIntegerType() || type->isPointerType()) {
        mStack.back().bindStmt(uett, 1);
      }
    }
  }

private:
  void initFunction(FunctionDecl *fdecl) {
#ifdef DEBUG
    llvm::errs() << "Declaring Function: " << fdecl->getName() << "\n";
#endif
    if (fdecl->getName().equals("FREE"))
      mFree = fdecl;
    else if (fdecl->getName().equals("MALLOC"))
      mMalloc = fdecl;
    else if (fdecl->getName().equals("GET"))
      mInput = fdecl;
    else if (fdecl->getName().equals("PRINT"))
      mOutput = fdecl;
    else if (fdecl->getName().equals("main"))
      mEntry = fdecl;
    else
      mFunctions[fdecl->getName()] = fdecl;
  }

  void initGlobal(VarDecl *vardecl) {
    int val = 0;
    if (vardecl->hasInit())
      if (IntegerLiteral *literal =
              dyn_cast<IntegerLiteral>(vardecl->getInit()))
        val = literal->getValue().getSExtValue();
    // Register every global, initialized or not, so later lookups find it.
    mGlobalVars[vardecl] = val;
#ifdef DEBUG
    llvm::errs() << "Declaring Global Variable: " << vardecl->getName() << " = "
                 << val << "\n";
#endif
  }

  int getDeclVal(Decl *decl) {
    if (mGlobalVars.find(decl) != mGlobalVars.end())
      return mGlobalVars.find(decl)->second;
    return mStack.back().getDeclVal(decl);
  }

  void bindDecl(Decl *decl, int val) {
    if (mGlobalVars.find(decl) != mGlobalVars.end())
      mGlobalVars.find(decl)->second = val;
    else
      mStack.back().bindDecl(decl, val);
  }

  int assignOp(Expr *left, Expr *right) {
#ifdef DEBUG
    llvm::errs() << "Assignment Operation\n";
#endif
    int val = mStack.back().getStmtVal(right);
    if (DeclRefExpr *declexpr = dyn_cast<DeclRefExpr>(left)) {
      Decl *decl = declexpr->getFoundDecl();
      bindDecl(decl, val);
#ifdef DEBUG
      llvm::errs() << "Assigning Variable: "
                   << declexpr->getFoundDecl()->getName() << " = " << val
                   << "\n";
#endif
    } else if (ArraySubscriptExpr *arraysub =
                   dyn_cast<ArraySubscriptExpr>(left)) {
      int addr = mStack.back().getStmtVal(arraysub);
#ifdef DEBUG
      mHeap.dump();
      llvm::errs() << "Assigning Array: [" << addr << "] = " << val << "\n";
#endif
      mHeap.set(addr, val);
    } else if (UnaryOperator *uop = dyn_cast<UnaryOperator>(left)) {
      int addr = mStack.back().getStmtVal(uop);
      mHeap.set(addr, val);
    } else
      assert(false && "Unsupported Assignment Operation");
    mStack.back().bindStmt(left, val);
    return val;
  }

  int addOp(Expr *left, Expr *right, int sign) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    return lval + sign * rval;
  }

  int mulOp(Expr *left, Expr *right) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    return lval * rval;
  }

  int cmpOp(Expr *left, Expr *right, BinaryOperator::Opcode op) {
    int lval = mStack.back().getStmtVal(left);
    int rval = mStack.back().getStmtVal(right);
    switch (op) {
    case BO_LT:
      return lval < rval;
    case BO_GT:
      return lval > rval;
    case BO_LE:
      return lval <= rval;
    case BO_GE:
      return lval >= rval;
    case BO_EQ:
      return lval == rval;
    case BO_NE:
      return lval != rval;
    default:
      assert(false && "Never reach here");
      return 0;
    }
  }
};
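
As a standalone illustration of the layout the Heap class above relies on (one header cell in front of each chunk holding its size, user addresses shifted by one), here is a minimal sketch that does not depend on Clang. ToyHeap and its member names are made up for this example, and it only ever grows the backing vector; free-list reuse is left out.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, simplified model of the header-cell layout:
// cell [addr] stores the chunk size (payload + 1), payload starts at addr + 1.
class ToyHeap {
  std::vector<int> cells;

public:
  // Allocates `size` payload cells and returns the user-visible address.
  int alloc(int size) {
    int addr = static_cast<int>(cells.size());
    cells.resize(addr + size + 1);
    cells[addr] = size + 1; // header: total chunk size including itself
    return addr + 1;        // skip the header
  }

  // Size of the chunk (header included) that owns a user address.
  int chunkSize(int userAddr) const { return cells.at(userAddr - 1); }

  void set(int addr, int val) { cells.at(addr) = val; }
  int get(int addr) const { return cells.at(addr); }
};
```

With this layout, array indexing in the interpreter is plain integer arithmetic: after int a = heap.alloc(3), the element a[2] lives at address a + 2.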

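Write a pass that, given the bitcode, works out which functions each call instruction may call. I didn't pay attention in class, so I wasn't sure how to go about it.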

Makefile:

makefile

USER_ID=$(shell id -u)
CONTAINER_NAME=lczxxx123/llvm_10_hw:0.2
HOST_CODE_DIR=./assign2/
HOST_BUILD_DIR=./build
HOST_TESTCASE_DIR=./assign2/assign2-tests/
CONTAINER_CODE_DIR=/tmp/code
CONTAINER_BUILD_DIR=/tmp/build
CONTAINER_TESTCASE_DIR=/tmp/testcases
LLVM_DIR=/usr/local/llvm10ra
BUILD_OUTPUT=llvmassignment
FLAGS=-DDEBUG

CPP_FILES=$(wildcard $(HOST_CODE_DIR)/*.cpp)
H_FILES=$(wildcard $(HOST_CODE_DIR)/*.h)
ALL_FILES=$(CPP_FILES) $(H_FILES)
TESTCASE_FILES=$(wildcard $(HOST_TESTCASE_DIR)/*.c)

$(HOST_BUILD_DIR)/$(BUILD_OUTPUT): $(ALL_FILES)
	mkdir -p $(HOST_BUILD_DIR)
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_CODE_DIR):$(CONTAINER_CODE_DIR) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		$(CONTAINER_NAME) /bin/bash -c "\
		cmake -DLLVM_DIR=$(LLVM_DIR) \
		-DCMAKE_BUILD_TYPE=Debug \
		-DCMAKE_CXX_FLAGS_DEBUG=$(FLAGS) \
		-DCMAKE_CXX_FLAGS="-std=c++14" \
		$(CONTAINER_CODE_DIR) \
		-B $(CONTAINER_BUILD_DIR) && \
		make -C $(CONTAINER_BUILD_DIR)"

.PHONY: all build docker run bc ll clean
all: build

docker:
	docker pull $(CONTAINER_NAME)

build: $(HOST_BUILD_DIR)/$(BUILD_OUTPUT)

run: build
	docker run --rm -it \
		--user=$(USER_ID) \
		-v $(HOST_BUILD_DIR):$(CONTAINER_BUILD_DIR) \
		-v $(HOST_TESTCASE_DIR):$(CONTAINER_TESTCASE_DIR) \
		$(CONTAINER_NAME) \
		$(CONTAINER_BUILD_DIR)/$(BUILD_OUTPUT) $(CONTAINER_TESTCASE_DIR)/test$(test).c.bc


bc:
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_TESTCASE_DIR):$(CONTAINER_TESTCASE_DIR) \
		$(CONTAINER_NAME)	\
		sh -c 'for test in `find $(CONTAINER_TESTCASE_DIR) -name "*.c"`; do \
			echo $$test ; \
			$(LLVM_DIR)/bin/clang -emit-llvm -O0 -g -c $$test -o $$test.bc; \
		done'


ll:
	docker run --rm \
		--user=$(USER_ID) \
		-v $(HOST_TESTCASE_DIR):$(CONTAINER_TESTCASE_DIR) \
		$(CONTAINER_NAME)	\
		$(LLVM_DIR)/bin/clang -emit-llvm -O0 -S $(CONTAINER_TESTCASE_DIR)/test$(test).c -o -


clean:
	rm -rf $(HOST_BUILD_DIR)

The IR is in SSA form; I only half-listened to the SSA lecture.

The main job is to subclass ModulePass and override runOnModule.

For convenience I wrote an output helper:

c++

/**
 * @brief print the call instruction
 * lineno: func1, func2, ...
 *
 * @param callInst call instruction
 * @param funcs possible functions
 */
void printCall(CallInst *callInst, std::vector<Function *> funcs) {
  int lineno = callInst->getDebugLoc().getLine();
  errs() << lineno << ": ";
  Function *last = funcs.back();
  funcs.pop_back();
  for (Function *F : funcs)
    errs() << F->getName() << ", ";
  errs() << last->getName() << '\n';
}

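Since line numbers have to be printed, the test programs need to be compiled in debug mode (the bc target above passes -g).

test00 is just a direct call: walk over all instructions, and whenever a call instruction is a direct call, print it.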

c++

  bool runOnModule(Module &M) override {
#ifdef DEBUG
    errs() << "------------------------------\n";
    errs().write_escaped(M.getName()) << '\n';
    M.dump();
    errs() << "------------------------------\n";
#endif

    for (Function &F : M)
      for (BasicBlock &BB : F)
        for (Instruction &I : BB)
          if (!isa<DbgInfoIntrinsic>(&I))
            if (CallInst *CI = dyn_cast<CallInst>(&I))
              visitCallInst(CI);

    return false;
  }

  void visitCallInst(CallInst *callInst) {
    int line = callInst->getDebugLoc().getLine();
    Function *Callee = callInst->getCalledFunction();
    if (Callee) {
#ifdef DEBUG
      errs() << "Direct call at line " << line << ":\n"
             << Callee->getName() << '\n';
#endif
      printCall(callInst, {Callee});
    } else {
    }
  }

Note that the instruction must not be a DbgInfoIntrinsic; otherwise calls to llvm.dbg.value get picked up as well.

test01 involves an indirect call, so dump it first:

c++

  void visitCallInst(CallInst *callInst) {
    int line = callInst->getDebugLoc().getLine();
    Function *Callee = callInst->getCalledFunction();
    if (Callee) {
#ifdef DEBUG
      errs() << "Direct call at line " << line << ":\n"
             << Callee->getName() << '\n';
#endif
      printCall(callInst, {Callee});
    } else {
#ifdef DEBUG
      errs() << "Indirect call at line " << line << ":\n";
      callInst->dump();
      callInst->getCalledOperand()->dump();
#endif
    }
  }

Indirect call at line 22:
  %call = call i32 %t_fptr.0(i32 1, i32 2), !dbg !32
  %t_fptr.0 = phi i32 (i32, i32)* [ @plus, %if.then ], [ null, %entry ], !dbg !15

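As shown, %t_fptr.0 is a phi node, and a phi node lists every value it may take, so phi nodes have to be handled.

Since the nodes may be nested, I wrote a recursive function, getPossibleCallees, that walks all possible nodes: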

c++

#ifdef DEBUG
  void getPossibleCallees(Value *V, std::vector<Function *> &funcs,
                          int depth = 0) {
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    errs() << "------------- getPossibleCallees\n";
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    V->dump();
#else
  void getPossibleCallees(Value *V, std::vector<Function *> &funcs) {
#endif
    if (PHINode *PN = dyn_cast<PHINode>(V)) {
      for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
#ifdef DEBUG
        getPossibleCallees(PN->getIncomingValue(i), funcs, depth + 1);
#else
        getPossibleCallees(PN->getIncomingValue(i), funcs);
#endif
      }
    }
  }

Indirect call at line 22:
  %call = call i32 %t_fptr.0(i32 1, i32 2), !dbg !32
  %t_fptr.0 = phi i32 (i32, i32)* [ @plus, %if.then ], [ null, %entry ], !dbg !15
------------- getPossibleCallees
  %t_fptr.0 = phi i32 (i32, i32)* [ @plus, %if.then ], [ null, %entry ], !dbg !15
  ------------- getPossibleCallees
  ; Function Attrs: noinline nounwind uwtable
define dso_local i32 @plus(i32 %a, i32 %b) #0 !dbg !9 {
entry:
  call void @llvm.dbg.value(metadata i32 %a, metadata !14, metadata !DIExpression()), !dbg !15
  call void @llvm.dbg.value(metadata i32 %b, metadata !16, metadata !DIExpression()), !dbg !15
  %add = add nsw i32 %a, %b, !dbg !17
  ret i32 %add, !dbg !18
}

  ------------- getPossibleCallees
  i32 (i32, i32)* null

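The phi node's two incoming values are the function @plus and null; null can be ignored. I also checked callInst->getCalledOperand() for a direct call and found that it is a Function, so direct and indirect calls can be handled by the same code.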

c++

#ifdef DEBUG
  void getPossibleCallees(Value *V, std::vector<Function *> &funcs,
                          int depth = 0) {
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    errs() << "------------- getPossibleCallees\n";
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    V->dump();
#else
  void getPossibleCallees(Value *V, std::vector<Function *> &funcs) {
#endif
    if (Function *F = dyn_cast<Function>(V))
      funcs.push_back(F);
    else if (PHINode *PN = dyn_cast<PHINode>(V)) {
      for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
#ifdef DEBUG
        getPossibleCallees(PN->getIncomingValue(i), funcs, depth + 1);
#else
        getPossibleCallees(PN->getIncomingValue(i), funcs);
#endif
      }
    }
  }

  void visitCallInst(CallInst *callInst) {
    int line = callInst->getDebugLoc().getLine();
#ifdef DEBUG
    errs() << "Call at line " << line << ":\n";
    callInst->getCalledOperand()->dump();
#endif
    std::vector<Function *> possibleFuncs;
    getPossibleCallees(callInst->getCalledOperand(), possibleFuncs);
    printCall(callInst, possibleFuncs);
  }
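
To see the recursion in isolation, here is a small model of it without LLVM. ToyValue, collectCallees and resolveTest01Shape are hypothetical names invented for this sketch; the real pass works on llvm::Value, Function and PHINode as shown above.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for llvm::Value: a function, a phi node, or null.
struct ToyValue {
  enum Kind { Function, Phi, Null };
  Kind kind;
  std::string name;                       // only meaningful for Function
  std::vector<const ToyValue *> incoming; // only meaningful for Phi
};

// Collects the name of every function the value may resolve to.
// Null incoming values contribute nothing.
void collectCallees(const ToyValue *v, std::vector<std::string> &out) {
  if (v->kind == ToyValue::Function)
    out.push_back(v->name);
  else if (v->kind == ToyValue::Phi)
    for (const ToyValue *in : v->incoming)
      collectCallees(in, out);
}

// Rebuilds the test01 shape: phi [ @plus, %if.then ], [ null, %entry ].
std::vector<std::string> resolveTest01Shape() {
  ToyValue plus{ToyValue::Function, "plus", {}};
  ToyValue null{ToyValue::Null, "", {}};
  ToyValue phi{ToyValue::Phi, "", {&plus, &null}};
  std::vector<std::string> out;
  collectCallees(&phi, out);
  return out;
}
```

A direct call is just the degenerate case where the starting value is already a function, which is why one function can serve both kinds of call.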

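Here a function pointer is passed in as a parameter, so the callee cannot be determined from inside the function alone. Trace the formal parameter back to the actual arguments at each call site, then recurse on those arguments to find their possible callees.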

c++

  void getPossibleCallees(Value *value, std::vector<Function *> &funcs) {
    // ...
    } else if (Argument *arg = dyn_cast<Argument>(value)) {
      int argNo = arg->getArgNo();
      Function *func = arg->getParent();
      for (User *user : func->users()) {
        if (CallInst *callInst = dyn_cast<CallInst>(user)) {
          Value *argValue = callInst->getArgOperand(argNo);
#ifdef DEBUG
          getPossibleCallees(argValue, funcs, depth + 1);
#else
          getPossibleCallees(argValue, funcs);
#endif
        } else {
#ifdef DEBUG
          for (int i = 0; i < depth; i++)
            errs() << "  ";
          errs() << "Usupported User: \n";
          user->dump();
#endif
          assert(0);
        }
      }
    //...
  }

(While working on test05 I noticed that several phi nodes can push the same function into the vector more than once, so I switched to a set. Code omitted.)
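
A rough sketch of what that switch buys, using plain strings instead of Function pointers. The helper names joinNames and dedupeAndJoin are made up for this example.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical helper: formats names as "a, b, c".
std::string joinNames(const std::set<std::string> &names) {
  std::string out;
  for (auto it = names.begin(); it != names.end(); ++it) {
    if (it != names.begin())
      out += ", ";
    out += *it;
  }
  return out;
}

// Collapses duplicates that were reached through different phi nodes.
std::string dedupeAndJoin(const std::vector<std::string> &found) {
  return joinNames(std::set<std::string>(found.begin(), found.end()));
}
```

One difference to keep in mind: a set of strings iterates in alphabetical order, whereas a std::set<Function *> iterates in pointer order, which need not match the order in which the functions were found.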

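Skipping ahead to test11. Here a function returns a function pointer, and that pointer is then called.

The called operand of this call is an instruction, and that instruction is itself a call instruction. (Could it be some other kind of instruction? qaq)

I checked: in the test cases it is only ever a call instruction.

So recurse upward through that function's return values.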

c++

    } else if (CallInst *inst = dyn_cast<CallInst>(value)) {
      errs() << "Call Instruction: ";
      std::set<Value *> possibleValues;
      getPossibleReturnValues(inst->getCalledOperand(), possibleValues);
      for (Value *retVal : possibleValues) {
        getPossibleCallees(retVal, funcs);
      }

c++

  void getPossibleReturnValues(Value *value, std::set<Value *> &values) {
    if (Function *func = dyn_cast<Function>(value)) {
      for (BasicBlock &BB : *func) {
        for (Instruction &I : BB) {
          if (ReturnInst *retInst = dyn_cast<ReturnInst>(&I)) {
            values.insert(retInst->getReturnValue());
          }
        }
      }
    } else if (PHINode *phi = dyn_cast<PHINode>(value))
      for (unsigned i = 0; i < phi->getNumIncomingValues(); i++)
        getPossibleReturnValues(phi->getIncomingValue(i), values);
    else {
      assert(0);
    }
  }

Call at line 31:
  %call7 = call i32 %call(i32 %op1, i32 %op2), !dbg !41
------------- getPossibleCallees
  %call = call i32 (i32, i32)* %goo_ptr.1(i32 %op1, i32 %op2, i32 (i32, i32)* @plus, i32 (i32, i32)* @minus), !dbg !40
  ------------- getPossibleCallees
  i32 (i32, i32)* %a_fptr
  Usupported User:
  %goo_ptr.1 = phi i32 (i32, i32)* (i32, i32, i32 (i32, i32)*, i32 (i32, i32)*)* [ @foo, %if.then ], [ @clever, %if.end ], !dbg !39
llvmassignment: /tmp/code/LLVMAssignment.cpp:140: void FuncPtrPass::getPossibleCallees(llvm::Value*, std::set<llvm::Function*>&, int): Assertion `0' failed.

I don't fully understand why, but in short one of the function's users is a phi node, so that case has to be handled too.

c++

    } else if (Argument *arg = dyn_cast<Argument>(value)) {
      int argNo = arg->getArgNo();
      Function *func = arg->getParent();
      std::set<CallInst *> callInsts;
      getPossibleUsers(func, callInsts);
      for (CallInst *callInst : callInsts)
#ifdef DEBUG
        getPossibleCallees(callInst->getArgOperand(argNo), funcs, depth + 1);
#else
        getPossibleCallees(callInst->getArgOperand(argNo), funcs);
#endif

c++

  void getPossibleUsers(Value *value, std::set<CallInst *> &users) {
    if (CallInst *callInst = dyn_cast<CallInst>(value))
      users.insert(callInst);
    else if (Function *func = dyn_cast<Function>(value))
      for (User *user : func->users())
        getPossibleUsers(user, users);
    else if (PHINode *phi = dyn_cast<PHINode>(value))
      for (User *user : phi->users())
        getPossibleUsers(user, users);
    else {
#ifdef DEBUG
      errs() << "Usupported Value Type: " << value->getValueID() << '\n';
      value->dump();
#endif
      assert(0);
    }
  }

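This follows the users of the phi node and returns every CallInst it finds. When calling getArgOperand it does not matter which function the call actually invokes: it might be the target function, so take the argument at that position and keep searching from there.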
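
The user walk can be modelled without LLVM as a tiny def-use graph. ToyNode, collectCallUsers and callUsersOfFoo are hypothetical names for this sketch; the graph mirrors the test11 dump, where @foo flows into the phi %goo_ptr.1 and the phi is the called operand of %call.

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a value in a def-use graph.
struct ToyNode {
  enum Kind { Function, Phi, Call };
  Kind kind;
  std::string name;
  std::vector<const ToyNode *> users; // values that use this one
};

// Follows users through phi nodes until call instructions are reached.
void collectCallUsers(const ToyNode *n, std::set<const ToyNode *> &calls) {
  if (n->kind == ToyNode::Call) {
    calls.insert(n);
    return;
  }
  for (const ToyNode *u : n->users)
    collectCallUsers(u, calls);
}

// @foo -> %goo_ptr.1 (phi) -> %call, as in the test11 dump.
std::vector<std::string> callUsersOfFoo() {
  ToyNode call{ToyNode::Call, "%call", {}};
  ToyNode phi{ToyNode::Phi, "%goo_ptr.1", {&call}};
  ToyNode foo{ToyNode::Function, "foo", {&phi}};
  std::set<const ToyNode *> calls;
  collectCallUsers(&foo, calls);
  std::vector<std::string> names;
  for (const ToyNode *c : calls)
    names.push_back(c->name);
  return names;
}
```

The sketch, like the walk it models, treats any call that uses the value as a potential call of it, whether the value is the callee or just an argument.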

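Several function calls on one line. Nothing interesting.

This one also involves function pointers, but unlike test11 the User of the function pointer does not call the function; it passes it on as an argument. So we have to find the function being called, go into it, and look for where the target function is called. It is a bit convoluted and I have not fully worked it out myself, so here is the code: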

c++

    } else if (Argument *arg = dyn_cast<Argument>(value)) {
      int argNo = arg->getArgNo();
      Function *parentFunc = arg->getParent();
      std::set<CallInst *> callInsts;
      getPossibleCallUsers(parentFunc, callInsts);
      for (CallInst *callInst : callInsts) {
        std::set<Function *> possibleCallees;
        // get all callees of the call instruction,
        // even if the called operand is a phi node.
        getPossibleFunctions(callInst->getCalledOperand(), possibleCallees);
        if (possibleCallees.count(parentFunc))
#ifdef DEBUG
          getPossibleFunctions(callInst->getArgOperand(argNo), funcs,
                               depth + 1);
#else
          getPossibleFunctions(callInst->getArgOperand(argNo), funcs);
#endif
        // TODO: phi?
        else if (Function *func =
                     dyn_cast<Function>(callInst->getCalledOperand())) {
          int argNo2 = -1;
          for (int i = 0; i < func->arg_size(); i++) {
            std::set<Function *> possibleFunctionArgs;
            // get all functions that the argument can be,
            // even if the argument is a phi node.
            getPossibleFunctions(callInst->getArgOperand(i),
                                 possibleFunctionArgs,
#ifdef DEBUG
                                 0,
#endif
                                 false);
            if (possibleFunctionArgs.count(parentFunc)) {
              argNo2 = i;
              break;
            }
          }
          assert(argNo2 != -1);

          for (BasicBlock &BB : *func)
            for (Instruction &I : BB)
              if (!isa<DbgInfoIntrinsic>(&I))
                if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
                  if (Argument *arg2 =
                          dyn_cast<Argument>(callInst->getCalledOperand())) {
                    if (func->getArg(argNo2) == arg2) {
#ifdef DEBUG
                      getPossibleFunctions(callInst->getArgOperand(argNo),
                                           funcs, depth + 1);
#else
                      getPossibleFunctions(callInst->getArgOperand(argNo),
                                           funcs);
#endif
                    }

                  } else // TODO: phi?
                    assert(0);
                }
        } else
          assert(0);
      }
    }

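Two places where a phi node may appear are still unhandled: the called operand may be a phi rather than a specific function, and the target function called inside that function may reach the call through a phi node instead of directly as a parameter. The test cases do not exercise either, so I left them.

The calling function and the function receiving the argument could even be the same one, hahaha, no way I can analyze that.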

c++

//===- Hello.cpp - Example code from "Writing an LLVM Pass" ---------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file implements two versions of the LLVM "Hello World" pass described
// in docs/WritingAnLLVMPass.html
//
//===----------------------------------------------------------------------===//

#include "llvm/IR/DerivedTypes.h"
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/LegacyPassManager.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/CommandLine.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/ToolOutputFile.h>

#include <llvm/Transforms/Scalar.h>
#include <llvm/Transforms/Utils.h>

#include <llvm/IR/Function.h>
#include <llvm/IR/Instructions.h>
#include <llvm/IR/IntrinsicInst.h>
#include <llvm/Pass.h>
#include <llvm/Support/raw_ostream.h>

#include <llvm/Bitcode/BitcodeReader.h>
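#include <llvm/Bitcode/BitcodeWriter.h>

// Standard headers used directly below (std::set, std::map, std::next, ...)
#include <cassert>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>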

using namespace llvm;

/**
 * @brief print the possible callees of the call(s) on one source line as
 * lineno : func1, func2, ...
 *
 * @param lineno source line of the call instruction
 * @param funcs possible functions
 */
void printCall(int lineno, std::set<Function *> &funcs) {
  assert(!funcs.empty());
  errs() << lineno << " : ";
  for (auto it = funcs.begin(); it != funcs.end(); it++) {
    errs() << (*it)->getName();
    if (std::next(it) != funcs.end())
      errs() << ", ";
  }
  errs() << '\n';
}

static ManagedStatic<LLVMContext> GlobalContext;
static LLVMContext &getGlobalContext() { return *GlobalContext; }
/* In LLVM 5.0, when -O0 is passed to clang, the generated functions carry the
 * optnone attribute, which disables some transform passes such as mem2reg.
 */
struct EnableFunctionOptPass : public FunctionPass {
  static char ID;
  EnableFunctionOptPass() : FunctionPass(ID) {}
  bool runOnFunction(Function &F) override {
    if (F.hasFnAttribute(Attribute::OptimizeNone)) {
      F.removeFnAttr(Attribute::OptimizeNone);
    }
    return true;
  }
};

char EnableFunctionOptPass::ID = 0;

///!TODO TO BE COMPLETED BY YOU FOR ASSIGNMENT 2
/// Updated 11/10/2017 by fargo: make all functions
/// processed by mem2reg before this pass.
struct FuncPtrPass : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  // add funcs to store all the functions in the module
  FuncPtrPass() : ModulePass(ID) {}

  void getPossibleCallUsers(Value *value, std::set<CallInst *> &users) {
    if (CallInst *callInst = dyn_cast<CallInst>(value))
      users.insert(callInst);
    else if (Function *func = dyn_cast<Function>(value))
      for (User *user : func->users())
        getPossibleCallUsers(user, users);
    else if (PHINode *phi = dyn_cast<PHINode>(value))
      for (User *user : phi->users())
        getPossibleCallUsers(user, users);
    else {
#ifdef DEBUG
      errs() << "Unsupported Value Type: " << value->getValueID() << '\n';
      value->dump();
#endif
      assert(0);
    }
  }

  void getPossibleReturnValues(Value *value, std::set<Value *> &values) {
    if (Function *func = dyn_cast<Function>(value)) {
      for (BasicBlock &BB : *func)
        for (Instruction &I : BB)
          if (ReturnInst *retInst = dyn_cast<ReturnInst>(&I))
            values.insert(retInst->getReturnValue());
    } else if (PHINode *phi = dyn_cast<PHINode>(value))
      for (unsigned i = 0; i < phi->getNumIncomingValues(); i++)
        getPossibleReturnValues(phi->getIncomingValue(i), values);
    else
      assert(0);
  }

#ifdef DEBUG
  void getPossibleFunctions(Value *value, std::set<Function *> &funcs,
                            int depth = 0, bool assert_on_error = true) {
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    errs() << "------------- getPossibleCallees\n";
    for (int i = 0; i < depth; i++)
      errs() << "  ";
    value->dump();
#else
  void getPossibleFunctions(Value *value, std::set<Function *> &funcs,
                            bool assert_on_error = true) {
#endif
    if (Function *func = dyn_cast<Function>(value))
      funcs.insert(func);
    else if (PHINode *phi = dyn_cast<PHINode>(value)) {
      for (unsigned i = 0; i < phi->getNumIncomingValues(); i++)
#ifdef DEBUG
        getPossibleFunctions(phi->getIncomingValue(i), funcs, depth + 1);
#else
        getPossibleFunctions(phi->getIncomingValue(i), funcs);
#endif
    } else if (ConstantPointerNull *ptr =
                   dyn_cast<ConstantPointerNull>(value)) {
    } else if (Argument *arg = dyn_cast<Argument>(value)) {
      int argNo = arg->getArgNo();
      Function *parentFunc = arg->getParent();
      std::set<CallInst *> callInsts;
      getPossibleCallUsers(parentFunc, callInsts);
      for (CallInst *callInst : callInsts) {
        std::set<Function *> possibleCallees;
        // get all callees of the call instruction,
        // even if the called operand is a phi node.
        getPossibleFunctions(callInst->getCalledOperand(), possibleCallees);
        if (possibleCallees.count(parentFunc))
#ifdef DEBUG
          getPossibleFunctions(callInst->getArgOperand(argNo), funcs,
                               depth + 1);
#else
          getPossibleFunctions(callInst->getArgOperand(argNo), funcs);
#endif
        // TODO: phi?
        else if (Function *func =
                     dyn_cast<Function>(callInst->getCalledOperand())) {
          int argNo2 = -1;
          for (int i = 0; i < func->arg_size(); i++) {
            std::set<Function *> possibleFunctionArgs;
            // get all functions that the argument can be,
            // even if the argument is a phi node.
            getPossibleFunctions(callInst->getArgOperand(i),
                                 possibleFunctionArgs,
#ifdef DEBUG
                                 0,
#endif
                                 false);
            if (possibleFunctionArgs.count(parentFunc)) {
              argNo2 = i;
              break;
            }
          }
          assert(argNo2 != -1);

          for (BasicBlock &BB : *func)
            for (Instruction &I : BB)
              if (!isa<DbgInfoIntrinsic>(&I))
                if (CallInst *callInst = dyn_cast<CallInst>(&I)) {
                  if (Argument *arg2 =
                          dyn_cast<Argument>(callInst->getCalledOperand())) {
                    if (func->getArg(argNo2) == arg2) {
#ifdef DEBUG
                      getPossibleFunctions(callInst->getArgOperand(argNo),
                                           funcs, depth + 1);
#else
                      getPossibleFunctions(callInst->getArgOperand(argNo),
                                           funcs);
#endif
                    }

                  } else // TODO: phi?
                    assert(0);
                }
        } else
          assert(0);
      }
    } else if (CallInst *inst = dyn_cast<CallInst>(value)) {
      std::set<Value *> possibleValues;
      getPossibleReturnValues(inst->getCalledOperand(), possibleValues);
      for (Value *value : possibleValues)
#ifdef DEBUG
        getPossibleFunctions(value, funcs, depth + 1);
#else
        getPossibleFunctions(value, funcs);
#endif
    } else {
#ifdef DEBUG
      for (int i = 0; i < depth; i++)
        errs() << "  ";
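      errs() << "Unsupported Value Type: " << value->getValueID() << '\n';
#endif
      // Callers that only probe a value pass assert_on_error == false and
      // simply get no functions back for an unsupported value.
      assert(!assert_on_error && "unsupported value type");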
    }
  }

  void visitCallInst(CallInst *callInst,
                     std::map<int, std::set<Function *>> &ans) {
    int line = callInst->getDebugLoc().getLine();
#ifdef DEBUG
    errs() << "Call at line " << line << ":\n";
    callInst->dump();
#endif
    std::set<Function *> possibleFuncs;
    getPossibleFunctions(callInst->getCalledOperand(), possibleFuncs);
    if (ans.count(line) == 0)
      ans[line] = possibleFuncs;
    else
      for (Function *func : possibleFuncs)
        ans[line].insert(func);
  }

  bool runOnModule(Module &M) override {
#ifdef DEBUG
    errs() << "------------------------------\n";
    errs().write_escaped(M.getName()) << '\n';
    M.dump();
    errs() << "------------------------------\n";
#endif

    std::map<int, std::set<Function *>> ans;

    for (Function &F : M)
      for (BasicBlock &BB : F)
        for (Instruction &I : BB)
          if (!isa<DbgInfoIntrinsic>(&I))
            if (CallInst *CI = dyn_cast<CallInst>(&I))
              visitCallInst(CI, ans);

    for (auto it = ans.begin(); it != ans.end(); it++)
      printCall(it->first, it->second);

    return false;
  }
};

char FuncPtrPass::ID = 0;
static RegisterPass<FuncPtrPass> X("funcptrpass",
                                   "Print function call instruction");

static cl::opt<std::string>
    InputFilename(cl::Positional, cl::desc("<filename>.bc"), cl::init(""));

int main(int argc, char **argv) {
  LLVMContext &Context = getGlobalContext();
  SMDiagnostic Err;
  // Parse the command line to read the Inputfilename
  cl::ParseCommandLineOptions(
      argc, argv, "FuncPtrPass \n My first LLVM tool which does not do much.\n");

  // Load the input module
  std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
  if (!M) {
    Err.print(argv[0], errs());
    return 1;
  }

  llvm::legacy::PassManager Passes;

  /// Remove functions' optnone attribute in LLVM5.0
  Passes.add(new EnableFunctionOptPass());
  /// Transform it to SSA
  Passes.add(llvm::createPromoteMemoryToRegisterPass());

  /// Your pass to print Function and Call Instructions
  Passes.add(new FuncPtrPass());
  Passes.run(*M.get());
}

Points-to Analysis

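I opened the files and was completely lost. What, I have to learn dataflow analysis? I certainly hadn't been listening in class, so I forced myself through some of the slides on my own. I stopped after context-insensitive pointer analysis, which is enough for the assignment.

(Too lazy to write this part up.)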

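The assignment provides a dataflow analysis framework with forward and backward propagation already implemented. All that is left is to implement Gen / Kill for the problem at hand, plus the merge operation where basic blocks join (May or Must).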
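To make that division of labor concrete, here is a minimal sketch of a forward worklist solver in plain C++. It is not the course framework: `Block`, `Facts`, `solveForward`, and `diamondDemo` are hypothetical names, and facts are plain strings. The `transfer` callback plays the role of Gen / Kill, and the merge is a may-style set union.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a set of facts per program point, and a basic
// block described only by the ids of its successors.
using Facts = std::set<std::string>;

struct Block {
  std::vector<int> succs;
};

// Forward may-analysis: IN[b] is the union of OUT[p] over all predecessors p,
// and OUT[b] = transfer(b, IN[b]). Iterates until nothing changes.
std::map<int, std::pair<Facts, Facts>>
solveForward(const std::map<int, Block> &cfg, int entry, const Facts &init,
             const std::function<Facts(int, const Facts &)> &transfer) {
  std::map<int, std::pair<Facts, Facts>> result; // block -> (IN, OUT)
  std::set<int> worklist;
  for (const auto &kv : cfg) {
    result[kv.first] = {Facts{}, Facts{}};
    worklist.insert(kv.first);
  }
  result[entry].first = init;

  while (!worklist.empty()) {
    int b = *worklist.begin();
    worklist.erase(worklist.begin());

    Facts out = transfer(b, result[b].first);
    if (out == result[b].second)
      continue; // OUT unchanged, nothing new to propagate
    result[b].second = out;

    for (int s : cfg.at(b).succs) {
      Facts &in = result[s].first;
      std::size_t before = in.size();
      in.insert(out.begin(), out.end()); // may-merge: set union
      if (in.size() != before)
        worklist.insert(s); // successor input grew, revisit it
    }
  }
  return result;
}

// Diamond CFG 0 -> {1, 2} -> 3 where each branch generates one fact.
Facts diamondDemo() {
  std::map<int, Block> cfg;
  cfg[0].succs = {1, 2};
  cfg[1].succs = {3};
  cfg[2].succs = {3};
  cfg[3]; // exit block, no successors

  auto transfer = [](int b, const Facts &in) {
    Facts out = in;
    if (b == 1)
      out.insert("p->a");
    if (b == 2)
      out.insert("p->b");
    return out;
  };
  return solveForward(cfg, 0, Facts{}, transfer).at(3).first;
}
```

At the join block the two branches are unioned, so `diamondDemo()` yields both facts. A Must analysis would intersect at the merge instead.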

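The task is the same as in the previous assignment: work out which functions each call may invoke. This time it has to be done in a more principled way, and it has to handle more cases. So we compute the may points-to set of every pointer variable, then at each call look up the points-to set of the called pointer to see which functions it may call.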

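The requirements are flow-sensitive, context-insensitive, field-insensitive, and interprocedural. Flow-sensitive means every program point keeps its own copy of all points-to sets. The dataflow framework handles the intraprocedural analysis of each function, so it has to be extended somehow for the interprocedural part. My current plan: run the intraprocedural analysis on every function, then visit every call, process the points-to sets at the call's input, and hand them to the callee's parameters as its initial value. Likewise, process the callee's result and propagate it to the program point after the call (maybe as the output of the current basic block? Then only the framework needs an extra check, and nothing changes per instruction). After that, run the analysis again, and repeat until a fixed point (roughly; not sure it will work).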

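A flow-sensitive analysis does not need a transitive closure, so there is no graph to build; storing the points-to sets is enough. Copy the info struct from the liveness example and change it to this: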

cpp

// Points-to information for each program point
struct PointsToInfo {
  std::map<Value *, std::set<Value *>> PointsToSets;

  PointsToInfo();
  PointsToInfo(const PointsToInfo &info);
  bool operator==(const PointsToInfo &info) const;
};

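The map's key is a variable, and the set stored as its value is that variable's points-to set.

Points-to analysis propagates forward, with Gen and Kill following the Andersen-style rules from the slides. LLVM IR breaks the various kinds of assignment into sequences of store and load, which is far less direct than what the slides show, and that is one of the hard parts here.

First, let's see what IR each operation turns into: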

c

int main() {
  int a;
  int *pa, *pa2, *pa3, **ppa;
  a = 10;
  pa = &a;
  ppa = &pa;
  pa2 = *ppa;
  pa3 = pa;
  *ppa = pa3;
}

llvm

define dso_local i32 @main() #0 {
entry:
  %a = alloca i32, align 4
  %pa = alloca i32*, align 8
  %pa2 = alloca i32*, align 8
  %pa3 = alloca i32*, align 8
  %ppa = alloca i32**, align 8
  store i32 10, i32* %a, align 4
  store i32* %a, i32** %pa, align 8
  store i32** %pa, i32*** %ppa, align 8
  %0 = load i32**, i32*** %ppa, align 8
  %1 = load i32*, i32** %0, align 8
  store i32* %1, i32** %pa2, align 8
  %2 = load i32*, i32** %pa, align 8
  store i32* %2, i32** %pa3, align 8
  %3 = load i32*, i32** %pa3, align 8
  %4 = load i32**, i32*** %ppa, align 8
  store i32* %3, i32** %4, align 8
  ret i32 0
}

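As you can see, every local variable gets its stack slot from an alloca instruction, and a = 10 stores the constant 10 into the memory at %a.

  • base: pa = &a: store %a into the memory at %pa;
  • assign: pa3 = pa: load the value in the memory at %pa into %2, then store %2 into the memory at %pa3;
  • store: *ppa = pa3: load the value in the memory at %pa3 into %3, load the value in the memory at %ppa into %4, then store %3 into the memory at %4;
  • load: pa2 = *ppa: load the value in the memory at %ppa into %0, load the value in the memory at %0 into %1, then store %1 into the memory at %pa2.

So pointer variables on the stack should all be alloca instructions with a pointer type (single-level or multi-level).

The simplest operation, base, stores the variable behind an alloca instruction into a slot, whether or not that variable is itself a pointer (&a or &pa). So we can check for this case and set the points-to set to $\{\texttt{a}\}$.

For assign (taking pa3 = pa as the example), the stored value is the temporary %2 produced by a load. Check for that, take the points-to set of the load's operand %pa, and make it the points-to set of the store's pointer operand (%pa3). To make this easier, when visiting a load instruction, check whether its operand is a variable (an alloca instruction); if so, create a points-to set for the temporary %2 that equals the one of %pa.

For load (taking pa2 = *ppa as the example), there are two load instructions. The first was handled above, so consider the second. It differs in that its operand is the temporary produced by a load (%0) rather than a variable backed by an alloca. So check for that: if the operand is a load instruction, go through every element $x$ of its (%0's) points-to set (here there is only one, pa), union the points-to sets of each $x$ (here $\{\texttt{a}\}$), and make the result the points-to set of the current variable %pa2.

The store operation works the same way: check whether the pointer operand is a load instruction, then update according to the Gen / Kill rules. The rules are fairly involved, so I won't spell them out here.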
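Since the rules are only described in words, here is a minimal sketch of the four pointer statements at the source level, using variable names instead of LLVM values. The `State` alias and the helper names are hypothetical, and the strong / weak update split is the usual flow-sensitive convention (overwrite when the target is unique, union otherwise). It illustrates the rules and is not the code used in the assignment.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

using Set = std::set<std::string>;
using State = std::map<std::string, Set>; // variable -> points-to set

// base: x = &y
void base(State &s, const std::string &x, const std::string &y) {
  s[x] = {y};
}

// assign: x = y
void assign(State &s, const std::string &x, const std::string &y) {
  const Set rhs = s[y]; // copy first so that x = x stays a no-op
  s[x] = rhs;
}

// load: x = *y, the union of the sets of everything y may point to
void load(State &s, const std::string &x, const std::string &y) {
  const Set targets = s[y];
  Set result;
  for (const std::string &v : targets) {
    const Set &pts = s[v];
    result.insert(pts.begin(), pts.end());
  }
  s[x] = result;
}

// store: *x = y
void store(State &s, const std::string &x, const std::string &y) {
  const Set targets = s[x];
  const Set rhs = s[y];
  if (targets.size() == 1) {
    s[*targets.begin()] = rhs; // strong update: the target is unique
  } else {
    for (const std::string &v : targets)
      s[v].insert(rhs.begin(), rhs.end()); // weak update: keep old facts
  }
}

// Replays the pointer statements of the example program above.
State runExample() {
  State s;
  base(s, "pa", "a");     // pa = &a
  base(s, "ppa", "pa");   // ppa = &pa
  load(s, "pa2", "ppa");  // pa2 = *ppa
  assign(s, "pa3", "pa"); // pa3 = pa
  store(s, "ppa", "pa3"); // *ppa = pa3
  return s;
}
```

In the LLVM version the same four cases have to be recognized from the shape of the load / store operands, which is what the operand checks described above are for.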

The code for this part:

cpp

std::set<Value *> PointsToVisitor::getPointsToSet(Value *val,
                                                  const PointsToInfo *dfval) {
  if (isa<LoadInst>(val)) {
    assert(dfval->PointsToSets.count(val) > 0);
    return dfval->PointsToSets.at(val);
  }
  // TODO: handle other kinds of values? Some are neither an alloca, a global
  // variable, nor a function. An argument, for example, may be a pointer or a
  // plain value. Should pointers be handled the same way as LoadInst? Should
  // the pointer type, rather than LoadInst, be used to tell them apart?
  return {val};
}

void PointsToVisitor::compDFVal(Instruction *inst, PointsToInfo *dfval) {
  if (isa<DbgInfoIntrinsic>(inst))
    return;
#ifdef DEBUG
  llvm::errs() << "\033[1;34m";
  inst->dump();
  errs() << "\033[0m";
#endif
  if (StoreInst *store_inst = dyn_cast<StoreInst>(inst)) {
    Value *ptr = store_inst->getPointerOperand();
    Value *val = store_inst->getValueOperand();
    if (isa<ConstantData>(val))
      return;
    if (isa<PointerType>(val->getType())) {
      std::set<Value *> val_points_to = getPointsToSet(val, dfval);
      if (isa<AllocaInst>(ptr) || isa<GlobalVariable>(ptr))
        // base or assign or load: x = y, based on y, maybe &z, z, *z, loaded
        // temp value in PointsToSets
        dfval->PointsToSets[ptr] = val_points_to;
      else if (isa<LoadInst>(ptr)) {
        // store: *x = y;
        assert(dfval->PointsToSets.count(ptr) > 0 &&
               dfval->PointsToSets[ptr].size() > 0); // S(x) = {}, assert
        // S(x) = {v}, S' = S \ {v->S(v)} \cup {v->S(y)}
        if (dfval->PointsToSets[ptr].size() == 1)
          dfval->PointsToSets[*dfval->PointsToSets[ptr].begin()] =
              val_points_to;
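        else // S(x) = {v1, v2, ...}, S' = S \cup {v1->S(y), v2->S(y), ...}
          // weak update: union into every possible target, as the rule above
          // says, instead of overwriting its old set
          for (auto target : dfval->PointsToSets[ptr])
            dfval->PointsToSets[target].insert(val_points_to.begin(),
                                               val_points_to.end());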
      } else
        assert(0);
    }
#ifdef DEBUG
    llvm::errs() << "\033[1;34m";
    llvm::errs() << *dfval;
    llvm::errs() << "\033[0m";
#endif
  } else if (LoadInst *load_inst = dyn_cast<LoadInst>(inst)) {
    if (!load_inst->getType()->isPointerTy())
      return;
    Value *ptr = load_inst->getPointerOperand();
    assert(ptr->getType()->isPointerTy());
    assert(dfval->PointsToSets.count(ptr) > 0);
    if (isa<LoadInst>(ptr)) {
      dfval->PointsToSets[inst] = {};
      for (auto val : dfval->PointsToSets[ptr])
        dfval->PointsToSets[inst].insert(dfval->PointsToSets[val].begin(),
                                         dfval->PointsToSets[val].end());

    } else if (isa<AllocaInst>(ptr) || isa<GlobalVariable>(ptr))
      dfval->PointsToSets[inst] = dfval->PointsToSets[ptr];
    else
      assert(0);
#ifdef DEBUG
    llvm::errs() << "\033[1;34m";
    llvm::errs() << *dfval;
    errs() << "\033[0m";
#endif
  }
}

Next is the may merge, which is a plain union:

cpp

void PointsToVisitor::merge(PointsToInfo *dest, const PointsToInfo &src) {
#ifdef DEBUG
  llvm::errs() << "\033[1;36m";
  llvm::errs() << "merge: ";
  llvm::errs() << src << "\n";
  llvm::errs() << "into: ";
  llvm::errs() << *dest << "\n";
#endif
  for (auto inst : src.PointsToSets)
    dest->PointsToSets[inst.first].insert(inst.second.begin(),
                                          inst.second.end());
#ifdef DEBUG
  llvm::errs() << "after merge: ";
  llvm::errs() << *dest << "\n";
  llvm::errs() << "\033[0m";
#endif
}

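Copy over the output code from the previous assignment and test01 passes (yes, test00 needs interprocedural analysis; the ordering is really unreasonable).

This needs half of an interprocedural analysis, mainly passing function pointers as arguments to the callee. Each function gets its own initval where the arguments can go. Visit every call instruction, get the functions it may call (which is also the result to print) and its pointer arguments, put the pointer arguments into the initval of the corresponding function, and start the next round.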

cpp

  int cnt = 0;
  do {
    last_result = result;
    for (auto &F : M) {
      if (F.isDeclaration())
        continue;
#ifdef DEBUG
      llvm::errs()
          << "==========================================================\n";
      F.dump();
      llvm::errs() << "\033[1;35m";
      llvm::errs() << "initval: " << func_initval[&F] << "\n";
      llvm::errs() << "\033[0m";
#endif

      compForwardDataflow(&F, &visitor, &result, func_initval[&F]);
      // printDataflowResult<PointsToInfo>(llvm::errs(), result);
    }

    for (auto &F : M) {
      if (F.isDeclaration())
        continue;
      for (auto &BB : F)
        for (auto &inst : BB) {
          if (isa<DbgInfoIntrinsic>(&inst))
            continue;
          if (CallInst *call_inst = dyn_cast<CallInst>(&inst)) {
#ifdef DEBUG
            llvm::errs() << "\033[1;35m";
            llvm::errs() << "call inst: ";
            call_inst->dump();
            llvm::errs() << "\033[0m";
#endif
            auto points_to_info_bb = result.at(&BB).second;
            std::map<unsigned, Value *> args;
            for (unsigned arg_idx = 0; arg_idx < call_inst->arg_size();
                 arg_idx++) {
              Value *arg = call_inst->getArgOperand(arg_idx);
              if (arg->getType()->isPointerTy())
                args[arg_idx] = arg;
            }
#ifdef DEBUG
            PointsToInfo tmp;
            llvm::errs() << "\033[1;35m";
            for (auto arg : args) {
              llvm::errs() << "arg " << arg.first << ": ";
              printValName(arg.second, llvm::errs());
              auto points_to = result.at(&BB).second.PointsToSets;
              auto it = points_to.find(arg.second);
              assert(it != points_to.end());
              tmp.PointsToSets.insert(*it);
              llvm::errs() << "\n";
            }
            llvm::errs() << "PointsToSets: " << tmp << "\n";
            llvm::errs() << "\033[0m";
#endif

            std::set<Function *> funcs =
                getCalledFunctions(call_inst, &points_to_info_bb);
            for (Function *func : funcs) {
              PointsToInfo &initval = func_initval[func];
              for (auto arg_pair : args) {
                unsigned arg_idx = arg_pair.first;
                Value *arg = arg_pair.second;
                std::set<Value *> arg_points_to =
                    points_to_info_bb.PointsToSets.at(arg);
                // Argument derives from Value, so no cast is needed here
                Value *fun_arg = func->getArg(arg_idx);
                assert(fun_arg != nullptr && fun_arg->getType()->isPointerTy());
                initval.PointsToSets[fun_arg].insert(arg_points_to.begin(),
                                                     arg_points_to.end());
              }
            }
          }
        }
    }
    cnt++;
  } while (result != last_result);

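Note that the same result object is reused across rounds, while compForwardDataflow fills in each basic block's input and output with insert, which does not overwrite an existing entry. It has to be changed as follows to overwrite the values in the map correctly: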

cpp

  for (auto &bb : *fn) {
    auto it = result->find(&bb);
    if (it == result->end())
      result->insert({&bb, {initval, initval}});
    else
      it->second = {initval, initval};
    // result->insert({&bb, {initval, initval}});
    worklist.insert(&bb);
  }
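The underlying behaviour is that `std::map::insert` leaves an existing key untouched, while assigning through the iterator replaces the value. A small standalone check, with `resetWithInsert` and `resetWithAssign` as made-up helper names:

```cpp
#include <cassert>
#include <map>
#include <string>

// Tries to reset a key with insert(): a no-op when the key already exists.
std::string resetWithInsert() {
  std::map<int, std::string> result{{0, "stale"}};
  result.insert({0, "fresh"});
  return result.at(0);
}

// Resets a key by assigning through the iterator, as in the loop above.
std::string resetWithAssign() {
  std::map<int, std::string> result{{0, "stale"}};
  auto it = result.find(0);
  if (it == result.end())
    result.insert({0, "fresh"});
  else
    it->second = "fresh";
  return result.at(0);
}
```

If the project builds as C++17 or later, `insert_or_assign` does the find-then-assign in a single call.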

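Skipping a bunch of cases, next comes the interprocedural handling of functions that return pointers. When visiting the functions a call may invoke, get the points-to sets of their return values, merge them, and put the result into the current function's initval.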

cpp

    for (auto &F : M) {
      // ...
      for (auto &BB : F)
        for (auto &inst : BB) {
          // ...
          if (CallInst *call_inst = dyn_cast<CallInst>(&inst)) {
            // ...
            std::set<Function *> funcs =
                getCalledFunctions(call_inst, &points_to_info_bb_out);
            for (Function *func : funcs) {
              // ...
              // handle function return value if it is a pointer
              if (func->getReturnType()->isPointerTy()) {
#ifdef DEBUG
                llvm::errs() << "\033[1;37m";
                llvm::errs() << "func: " << func->getName() << "\n";
                llvm::errs() << "\033[0m";
#endif
                for (ReturnInst *return_inst : getReturnInsts(*func)) {
                  llvm::BasicBlock *return_inst_bb = return_inst->getParent();
                  auto points_to_info_return_bb_out =
                      result.at(return_inst_bb).second;
                  std::set<Value *> return_points_to =
                      points_to_info_return_bb_out.PointsToSets.at(
                          return_inst->getReturnValue());
#ifdef DEBUG
                  llvm::errs() << "\033[1;37m";
                  llvm::errs() << "return inst: ";
                  return_inst->dump();
                  llvm::errs() << "\n";
                  llvm::errs() << "return points to: ";
                  for (auto val : return_points_to) {
                    printValName(val, llvm::errs());
                    llvm::errs() << " ";
                  }
                  llvm::errs() << "\n";
                  llvm::errs() << "\033[0m";
#endif
                  func_initval[&F].PointsToSets[call_inst].insert(
                      return_points_to.begin(), return_points_to.end());
                }
              }
            }
          }
        }
    }

Call instructions are initialized while visiting instructions:

cpp

  } else if (isa<CallInst>(inst)) {
    // init as empty set, leaving it to be filled on interprocedural analysis
    if (inst->getType()->isPointerTy() && dfval->PointsToSets.count(inst) == 0)
      dfval->PointsToSets[inst] = {};
  }

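The call instructions we handle produce a pointer value rather than a variable, so getPointsToSet treats them like load instructions and (pointer) Arguments: it looks up their points-to set.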

cpp

std::set<Value *> getPointsToSet(Value *val, const PointsToInfo *dfval) {
  if (isa<LoadInst>(val) || isa<Argument>(val) || isa<CallInst>(val)) {
    return dfval->PointsToSets.at(val);
  }
  return {val};
}

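The malloc in test16 doesn't seem to need any special handling; a bitcast moves it straight into a temporary.

(Still need to check whether an argument gets modified, i.e. an address is passed in and then written through.)

Going through a few more interprocedural tests without structs, I hit a *a = *b that can only be resolved with interprocedural analysis. Change how compDFVal handles a store through a loaded pointer: drop the assert, skip when the condition does not hold, and come back once the interprocedural pass has filled in the values.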

cpp

        // store: *x = y;
        // assert(dfval->PointsToSets.count(ptr) > 0 &&
        //        dfval->PointsToSets[ptr].size() > 0); // S(x) = {}, assert
        if (dfval->PointsToSets.count(ptr) == 0 ||
            dfval->PointsToSets[ptr].size() == 0)
          return;

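test28 passes a double pointer, so the callee's entry state would be missing the variables it points to. The points-to sets therefore have to be collected recursively. Only the key parts are shown: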

cpp

            PointsToInfo args_related_points_to_info;
            for (auto arg_pair : args) {
              Value *arg = arg_pair.second;
              PointsToInfo arg_related_points_to_info =
                  points_to_info_bb_out.getAllRelatedPointsTo(arg);

cpp

PointsToInfo PointsToInfo::getAllRelatedPointsTo(Value *val) {
  PointsToInfo related_points_to_info;
  if (PointsToSets.count(val)) {
    related_points_to_info.PointsToSets[val] = PointsToSets.at(val);
    for (auto related_val : PointsToSets.at(val))
      related_points_to_info.merge(getAllRelatedPointsTo(related_val));
  }
  return related_points_to_info;
}
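One thing to watch for: the recursion above does not terminate if the points-to sets contain a cycle (for example, a slot that may point to itself). A visited check is one way to guard against that. A sketch over plain strings, where `collectRelated`, `State`, and `relatedDemo` are hypothetical names:

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

using Set = std::set<std::string>;
using State = std::map<std::string, Set>; // variable -> points-to set

// Copies into `out` every entry of `all` reachable from `val`.
// Keys already present in `out` count as visited, so cycles terminate.
void collectRelated(const State &all, const std::string &val, State &out) {
  auto it = all.find(val);
  if (it == all.end() || out.count(val))
    return; // no points-to set for this value, or already visited
  out[val] = it->second;
  for (const std::string &next : it->second)
    collectRelated(all, next, out);
}

// "p" points to itself, and "q" is unrelated to the argument "pp".
State relatedDemo() {
  State all;
  all["pp"] = {"p"};
  all["p"] = {"a", "p"};
  all["q"] = {"b"};

  State out;
  collectRelated(all, "pp", out);
  return out;
}
```

Only the entries for `pp` and `p` end up in the result. `a` has no points-to set of its own, and `q` is never reached.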

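Now for structs. I looked through all the tests and every struct contains only a single pointer, so a field-insensitive analysis is enough.

Write a program like this to test: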

c

#include <stdlib.h>
struct s {
  int *ptr;
};
void stru() {
  struct s sptr;
  int a;
  sptr.ptr = &a;
}
void alloc() {
  int a;
  int **ptr = (int **)malloc(8);
  *ptr = &a;
}

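The getelementptr instruction behaves much like the load for *ptr: both pick out a location. So it can be handled in a similar way. First, when handling a getelementptr, set the instruction's points-to set to the struct's base address (base_ptr). Then, in load and store instructions, treat a getelementptr operand the same way as a load operand.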

cpp

  if (StoreInst *store_inst = dyn_cast<StoreInst>(inst)) {
    // ...
    if (isa<PointerType>(val->getType())) {
      // ...
      else if (isa<LoadInst>(ptr) || isa<GetElementPtrInst>(ptr)) {
        // store: *x = y;
        // assert(dfval->PointsToSets.count(ptr) > 0 &&
        //        dfval->PointsToSets[ptr].size() > 0); // S(x) = {}, assert
        if (dfval->PointsToSets.count(ptr) == 0 ||
            dfval->PointsToSets[ptr].size() == 0)
          return;
        // S(x) = {v}, S' = S \ {v->S(v)} \cup {v->S(y)}
        if (dfval->PointsToSets[ptr].size() == 1)
          dfval->PointsToSets[*dfval->PointsToSets[ptr].begin()] =
              val_points_to;
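        else // S(x) = {v1, v2, ...}, S' = S \cup {v1->S(y), v2->S(y), ...}
          // weak update: union into every possible target, as the rule above
          // says, instead of overwriting its old set
          for (auto target : dfval->PointsToSets[ptr])
            dfval->PointsToSets[target].insert(val_points_to.begin(),
                                               val_points_to.end());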
      }
    // ...
    }
  } else if (LoadInst *load_inst = dyn_cast<LoadInst>(inst)) {
    // ...
    if (isa<LoadInst>(ptr) || isa<GetElementPtrInst>(ptr)) {
      dfval->PointsToSets[inst] = {};
      for (auto val : dfval->PointsToSets[ptr])
        dfval->PointsToSets[inst].insert(dfval->PointsToSets[val].begin(),
                                         dfval->PointsToSets[val].end());
    } 
    // ...
  } else if (GetElementPtrInst *gep_inst = dyn_cast<GetElementPtrInst>(inst)) {
    Value *base_ptr = gep_inst->getPointerOperand();
    dfval->PointsToSets[inst] = {base_ptr};
  }

I can't get any further, so I'm giving up here. It passes 20-odd test cases, which is enough for a passing grade.