view pyrect/translator/llvm_translator.py @ 108:2632b963e441

modify llvm*.
author Ryoma SHINYA <shinya@firefly.cr.ie.u-ryukyu.ac.jp>
date Thu, 30 Dec 2010 17:18:40 +0900
parents 701beabd7d97
children
line wrap: on
line source

#!/usr/bin/env python

from llvm.core import *
from llvm.passes import *
from llvm.ee import *
from pyrect.regexp import Regexp
from translator import Translator

class LLVMTranslator(Translator):
    """Translate the DFA of a regular expression into LLVM IR.

    Every DFA state becomes one LLVM function that takes the current index
    into the subject string, loads the character at that index and calls the
    function of the next state. The resulting module can also be
    JIT-compiled and evaluated through llvm-py.

    >>> string = '(A|B)*C'
    >>> reg = Regexp(string)
    >>> lt = LLVMTranslator(reg)
    >>> lt.debug = True
    >>> lt.translate()
    >>> lt.execute()
    """

    # LLVM core types and constants shared by all instances.
    int_t = Type.int(32)
    char_t = Type.int(8)
    charptr_t = Type.pointer(char_t)
    charptrptr_t = Type.pointer(charptr_t)
    const_zero = Constant.int(int_t, 0)
    const_one = Constant.int(int_t, 1)
    llvm.GuaranteedTailCallOpt = True

    def __init__(self, regexp):
        """Prepare a translator for ``regexp`` (its ``dfa`` attribute is used).

        ``optimize`` and ``debug`` default to False and ``string`` (the
        subject text that gets compiled into the module) defaults to "ABC";
        all three are meant to be set by the caller before compilation.
        """
        Translator.__init__(self, regexp)
        self.optimize = False
        self.debug = False
        self.string = "ABC"
        self.fa = regexp.dfa
        self.llvm_module = Module.new("DFA")
        self.compiled = False

    def emit_driver(self):
        """Add the ``unitmain(index)`` entry point to the module.

        ``unitmain`` forwards its argument to the function of the DFA start
        state and returns that function's result. The created function is
        stored in ``self.main``.
        """
        main = self.llvm_module.add_function(
            Type.function(self.int_t, (self.int_t,)), "unitmain")
        main.args[0].name = "index"
        main_entry = main.append_basic_block("entry")

        emit = Builder.new(main_entry)
        start = self.llvm_module.get_function_named(self.fa.start)
        ret = emit.call(start, (main.args[0],))
        emit.ret(ret)
        self.main = main

    def jitcompile(self):
        """Build the LLVM module for the DFA and create the execution engine.

        Declares one function per state plus the final ``accept`` (returns 1)
        and ``reject`` (returns 0) functions, fills in every state body with
        a ``switch`` on the current character, optionally runs the
        optimization passes, and finally adds the ``unitmain`` driver.
        Sets ``self.mp``, ``self.ee`` and ``self.compiled``.
        """
        self.matchp_str = self.new_str_const(self.string)
        self.debug_str = self.new_str_const("state: %s, arg: %c(int %d)\n")

        def declare_state(name):
            # Declare an ``int name(int index)`` function in the module.
            fun = self.llvm_module.add_function(
                Type.function(self.int_t, (self.int_t,)), name)
            # NOTE(review): 'calling_convertion' looks like a misspelling of
            # 'calling_convention'; as written it presumably only sets an
            # unused Python attribute. Kept unchanged on purpose: switching
            # to a real non-default convention would also require matching
            # call sites - TODO confirm before fixing.
            fun.calling_convertion = CC_X86_FASTCALL
            fun.args[0].name = "index"
            return fun

        # Final states first, then one function per DFA state.
        accept_state = declare_state("accept")
        reject_state = declare_state("reject")
        state_ref = {"accept": accept_state, "reject": reject_state}
        for state in self.fa.transition:
            state_ref[state] = declare_state(state)

        # accept: report (in debug mode) and return 1.
        emit = Builder.new(accept_state.append_basic_block("entry"))
        if self.debug:
            self.emit_call_printf(emit, "%s does match regexp\n",
                                  self.gep_first(emit, self.matchp_str))
        emit.ret(self.const_one)

        # reject: report (in debug mode) and return 0.
        emit = Builder.new(reject_state.append_basic_block("entry"))
        if self.debug:
            self.emit_call_printf(emit, "%s does not match regexp\n",
                                  self.gep_first(emit, self.matchp_str))
        emit.ret(self.const_zero)

        for state, transition in self.fa.transition.items():
            # Work on a copy so the transition table owned by the DFA is not
            # modified as a side effect of compiling.
            targets = dict(transition)
            if state in self.fa.accepts:
                # An accepting state accepts when it reads the terminating NUL.
                targets['\\0'] = ["accept"]
            cases = dict()
            for case, next_states in targets.items():
                cases[self.char_const(case)] = state_ref[next_states[0]]

            state_fun = state_ref[state]
            emit = Builder.new(state_fun.append_basic_block("entry"))
            # Load string[index] and compute index + 1 for the next state.
            ptr = emit.gep(self.matchp_str,
                           (self.const_zero, state_fun.args[0]))
            next_index = emit.add(state_fun.args[0], self.const_one)
            char = emit.load(ptr)

            if self.debug:
                # Report the state being emitted (state_fun), not a function
                # left over from the declaration loop above.
                state_name = self.gep_first(
                    emit, self.new_str_const(state_fun.name))
                self.emit_call_printf(emit, self.debug_str,
                                      state_name, char, char)

            # Any character without an explicit transition goes to reject.
            default_bb = state_fun.append_basic_block("default")
            builder = Builder.new(default_bb)
            ret = builder.call(reject_state, (next_index,))
            builder.ret(ret)

            # One switch case (and basic block) per outgoing transition.
            si = emit.switch(char, default_bb, len(cases))
            label = 0
            for case, next_fun in cases.items():
                bb = state_fun.append_basic_block("case%d" % label)
                builder = Builder.new(bb)
                ret = builder.call(next_fun, (next_index,))
                builder.ret(ret)
                si.add_case(case, bb)
                label += 1

        self.mp = ModuleProvider.new(self.llvm_module)
        if self.optimize:
            self.do_optimize()
        self.ee = ExecutionEngine.new(self.mp)
        self.emit_driver()
        self.compiled = True

    def emit_from_callgraph(self):
        """Pass the textual form of the module to ``self.emit``.

        Compiles first if that has not happened yet. ``self.emit`` is not
        defined in this class (presumably inherited from Translator).
        """
        if not self.compiled:
            self.jitcompile()
        self.emit(str(self.llvm_module))

    def get_execution_engine(self):
        """Return the execution engine, compiling the module on first use."""
        if not self.compiled:
            self.jitcompile()
        return self.ee

    def do_optimize(self):
        """Run the module-level and per-function optimization passes.

        Requires ``self.mp`` to be set, i.e. it is meant to be called from
        ``jitcompile``.
        """
        # Module-level passes.
        pm = PassManager.new()
        pm.add(TargetData.new(''))
        pm.add(PASS_FUNCTION_INLINING)
        pm.run(self.llvm_module)
        # Per-function passes.
        fp = FunctionPassManager.new(self.mp)
        fp.add(TargetData.new(''))
        fp.add(PASS_BLOCK_PLACEMENT)
        fp.add(PASS_INSTRUCTION_COMBINING)
        fp.add(PASS_TAIL_CALL_ELIMINATION)
        fp.add(PASS_AGGRESSIVE_DCE)
        fp.add(PASS_DEAD_INST_ELIMINATION)
        fp.add(PASS_DEAD_CODE_ELIMINATION)
        for fun in self.llvm_module.functions:
            fp.run(fun)

    def print_module(self):
        """Print the module to stdout, compiling it first if needed."""
        if not self.compiled:
            self.jitcompile()
        print(self.llvm_module)

    def execute(self):
        """Run ``unitmain(0)`` in the execution engine.

        Compiles first if needed. The result of the call is discarded and
        None is returned.
        """
        if not self.compiled:
            self.jitcompile()
        self.ee.run_function(self.main,
                             (GenericValue.int(self.int_t, 0),))

    def new_str_const(self, val):
        """Create a NUL-terminated string as an unnamed global variable.

        The global is an array of ``char_t`` with ``len(val) + 1`` elements;
        the global variable itself is returned.
        """
        global_str = self.llvm_module.add_global_variable(
            Type.array(self.char_t, len(val) + 1), "")
        global_str.initializer = Constant.stringz(val)
        return global_str

    def gep_first(self, emit, val):
        """Emit a GEP with indices (0, 0) on ``val`` and return the result.

        Used on array globals to obtain the address of their first element.
        """
        return emit.gep(val, (self.const_zero, self.const_zero))

    def char_const(self, val):
        """Return ``val`` as a ``char_t`` constant.

        The two-character string backslash + '0' stands for the NUL
        character; any other string is converted with ``ord``.

        Raises:
            TypeError: if ``val`` is not a ``str``.
        """
        if not isinstance(val, str):
            raise TypeError('char_const: invalid argument: %r' % (val,))
        if val == '\\0':
            return Constant.int(self.char_t, 0)
        return Constant.int(self.char_t, ord(val))

    def emit_call_printf(self, emit, string, *args):
        """Emit a call to libc ``printf`` with format ``string`` and ``args``.

        ``printf`` is declared in the module on first use. ``string`` may be
        a Python ``str`` (turned into a new global string constant) or an
        existing global string; ``args`` are passed through as LLVM values.
        """
        try:
            printf = self.llvm_module.get_function_named("printf")
        except llvm.LLVMException:
            # Not declared yet: add a variadic declaration taking a char*.
            printf = self.llvm_module.add_function(
                Type.function(Type.void(),
                              (Type.pointer(self.char_t, 0),), 1), "printf")
        if isinstance(string, str):
            string = self.new_str_const(string)
        emit.call(printf,
                  [self.gep_first(emit, string)] + list(args))

def test():
    """Run the doctests via ``doctest.testmod()`` with default arguments."""
    from doctest import testmod
    testmod()

# Run the doctests when this file is executed as a script.
if __name__ == "__main__":
    test()