diff src/llvm_translator.py @ 18:ec36e784df2e

add LLVMTranslator(Translator) (and remove reg2llvm.py), add --LLVM option to converter.py.
author Ryoma SHINYA <shinya@firefly.cr.ie.u-ryukyu.ac.jp>
date Mon, 05 Jul 2010 08:36:11 +0900
parents
children a24acddedf89
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/llvm_translator.py	Mon Jul 05 08:36:11 2010 +0900
@@ -0,0 +1,225 @@
+#!/usr/bin/env python
+
+from llvm.core import *
+from llvm.passes import *
+from llvm.ee import *
+from translator import Translator
+from dfareg import Regexp, CallGraph
+
+class LLVMTranslator(Translator):
+    """LLVMTranslator
+    This Class can translate from DFA or NFA into LLVM-IR.
+    and also can JIT-Compile/evaluate it's self using llvm-py.
+    >>> string = '(A|B)*C'
+    >>> reg = Regexp(string)
+    >>> dfacg = CallGraph(reg.dfa)
+    >>> lt = LLVMTranslator(string, dfacg)
+    >>> lt.translate()
+    >>> isinstance(lt.execute(), llvm.ee.GenericValue)
+    True
+    """
+    # define llvm core types, and const
+    int_t = Type.int()
+    char_t = Type.int(8)
+    charptr_t = Type.pointer(char_t)
+    charptrptr_t = Type.pointer(charptr_t)
+    const_zero = Constant.int(int_t, 0)
+    const_one = Constant.int(int_t, 1)
+    llvm.GuaranteedTailCallOpt = True
+
+    def __init__(self, regexp, cg): #(self, modName, regexp, string, self.impl_label, optimized, debug):
+        Translator.__init__(self, regexp, cg)
+        self.mod = Module.new(self.cg.type)
+        self.optimize = False
+        self.debug = False
+        self.impl_label = False
+        self.matchp_str = self.new_str_const("ABC")
+        self.debug_str = self.new_str_const("state: %s, arg: %c(int %d)\n")
+        if self.cg.type == "DFA":
+            self.name_hash = self.create_name_hash()
+
+    def modify_state_name(self, state_name):
+        if self.cg.type == "DFA":
+            return self.name_hash[state_name]
+        else:
+            return state_name
+
+    def emit_from_callgraph(self):
+        def optional_func_decl(fun):
+            fun.calling_convertion = CC_X86_FASTCALL
+            fun.args[0].name = "index"
+
+        def func_decl(state):
+            optional_func_decl(state)
+
+        state_ref = dict()
+        main = self.mod.add_function(
+            Type.function(self.int_t, (self.int_t,)), "main")
+        optional_func_decl(main)
+        main_entry = main.append_basic_block("entry")
+
+        if self.impl_label:
+            accept_state = main.append_basic_block("accpet")
+            reject_state = main.append_basic_block("reject")
+            index_ptr = Builder.new(main_entry).malloc(self.int_t)
+            Builder.new(accept_state).free(index_ptr)
+            Builder.new(reject_state).free(index_ptr)
+        else:
+            # Create function - accept and reject (final state).
+            accept_state = self.mod.add_function(
+                Type.function(self.int_t, (self.int_t,)), "accept")
+            optional_func_decl(accept_state)
+            reject_state = self.mod.add_function(
+                Type.function(self.int_t, (self.int_t,)), "reject")
+            optional_func_decl(reject_state)
+
+
+        state_ref["accept"] = accept_state
+        state_ref["reject"] = reject_state
+
+        # add state to module, (as function or label).
+        if (self.impl_label):
+            for state in self.cg.map.iterkeys():
+                label = main.append_basic_block(state)
+                state_ref[state] = label
+        else:
+            for state in self.cg.map.iterkeys():
+                fun = self.mod.add_function(
+                    Type.function(self.int_t, (self.int_t,)), state)
+                optional_func_decl(fun)
+                state_ref[state] = fun
+
+        # emit instructions
+        if (self.impl_label): emit = Builder.new(accept_state)
+        else:            emit = Builder.new(accept_state.append_basic_block("entry"))
+
+        self.emit_call_printf(emit, "%s does match regexp\n", self.gep_first(emit, self.matchp_str))
+        emit.ret(self.const_one)
+
+        if (self.impl_label): emit = Builder.new(reject_state)
+        else:            emit = Builder.new(reject_state.append_basic_block("entry"))
+        self.emit_call_printf(emit, "%s does not match regexp\n", self.gep_first(emit, self.matchp_str))
+        emit.ret(self.const_zero)
+
+        if (self.impl_label):
+            # emit transition instruction with jump instruction
+            emit = Builder.new(main_entry)
+            emit.store(main.args[0], index_ptr)
+            emit.branch(state_ref[self.cg.start])
+
+            for state, transition in self.cg.map.iteritems():
+                emit = Builder.new(state_ref[state])
+                index = emit.load(index_ptr)
+                ptr = emit.gep(self.matchp_str, (self.const_zero, index))
+                emit.store(emit.add(self.const_one, index), index_ptr)
+                char = emit.load(ptr)
+                si = emit.switch(char, state_ref['reject'], len(transition))
+                local_label = 0
+                for case, next_states in transition.iteritems():
+                    bb = main.append_basic_block("%s_case%d" % (state, local_label))   #create default bb
+                    emit = Builder.new(bb)
+                    emit.branch(state_ref[next_states[0]])
+                    si.add_case(self.char_const(case), bb)
+                    local_label += 1
+        else:
+            for state, transition in self.cg.map.iteritems():
+                cases = dict()
+                for case, next_states in transition.iteritems():
+                    cases[self.char_const(case)] = state_ref[next_states[0]]
+                state_fun = state_ref[state]
+                emit = Builder.new(state_fun.append_basic_block("entry"))
+                ptr = emit.gep(self.matchp_str, (self.const_zero, state_fun.args[0]))
+                next_index = emit.add(state_fun.args[0], self.const_one)
+                char = emit.load(ptr)
+
+                if (self.debug): self.emit_call_printf(emit, self.debug_str, self.gep_first(emit, self.new_str_const(fun.name)), char, char)
+
+                label = 0
+                default_bb = state_fun.append_basic_block("default") #create default bb
+                builder = Builder.new(default_bb)              # default is reject.
+                ret = builder.call(reject_state, (next_index,))
+                builder.ret(ret)
+
+                si = emit.switch(char, default_bb, len(cases)) # create switch instruction with deafult case.
+                for case, nextFun in cases.iteritems():
+                    bb = state_fun.append_basic_block("case%d" % label)   #create default bb
+                    builder = Builder.new(bb)
+                    ret = builder.call(nextFun, (next_index,))
+                    builder.ret(ret)
+                    si.add_case(case, bb)
+                    label += 1
+            emit = Builder.new(main_entry)
+            ret = emit.call(state_ref[self.cg.start], (main.args[0],))
+            emit.ret(ret)
+
+        self.mp = ModuleProvider.new(self.mod)
+        if (self.optimize): self.do_optimize()
+        self.ee = ExecutionEngine.new(self.mp)
+        self.main = main
+        self.emit(str(self.mod))
+
+    def get_execution_engine(self):
+        return self.ee
+
+    def do_optimize(self):
+        #optimization passes
+        pm = PassManager.new()
+        pm.add(TargetData.new(''))
+        pm.add(PASS_FUNCTION_INLINING)
+        pm.run(self.mod)
+        fp = FunctionPassManager.new(self.mp)
+        fp.add(TargetData.new(''))
+        fp.add(PASS_BLOCK_PLACEMENT)
+        fp.add(PASS_INSTRUCTION_COMBINING)
+        fp.add(PASS_TAIL_CALL_ELIMINATION)
+        fp.add(PASS_AGGRESSIVE_DCE)
+        fp.add(PASS_DEAD_INST_ELIMINATION)
+        fp.add(PASS_DEAD_CODE_ELIMINATION)
+        for fun in self.mod.functions:
+            fp.run(fun)
+
+    def print_module(self):
+        print self.mod
+
+    def execute(self):
+        return self.ee.run_function(self.main,
+                                    (GenericValue.int(self.int_t, 0),))
+
+    def new_str_const(self, val):
+        '''create string(array of int) as a global value '''
+        str = self.mod.add_global_variable(Type.array(self.char_t, len(val) + 1), "")
+        str.initializer = Constant.stringz(val)
+        return str
+
+    def gep_first(self, emit, val):
+        '''get pointer of array'''
+        return emit.gep(val, (self.const_zero, self.const_zero))
+
+    def char_const(self, val):
+        '''create constant int value'''
+        if isinstance(val, str):
+            if val == '\\0':
+                return Constant.int(self.char_t, 0)
+            else:
+                return Constant.int(self.char_t, ord(val))
+        else:
+            exit('char_const: invalid argument.', val)
+
+    def emit_call_printf(self, emit, string, *args):
+        '''emit libc printf function call instruction'''
+        try:
+            printf = self.mod.get_function_named("printf")
+        except llvm.LLVMException:
+            printf = self.mod.add_function(
+                Type.function(Type.void(),
+                              (Type.pointer(self.char_t, 0),), 1), "printf")
+        if isinstance(string, str):
+            string = self.new_str_const(string)
+        emit.call(printf,
+                  [self.gep_first(emit, string)]+list(args))
+
+def test():
+    import doctest
+    doctest.testmod()
+
+if __name__ == "__main__": test()