view pyrect/translator/llvm_grep_translator.py @ 108:2632b963e441

modify llvm*.
author Ryoma SHINYA <shinya@firefly.cr.ie.u-ryukyu.ac.jp>
date Thu, 30 Dec 2010 17:18:40 +0900
parents 701beabd7d97
children
line wrap: on
line source

#!/usr/bin/env python

import os
from llvm.core import *
from llvm.passes import *
from llvm.ee import *
from llvm_translator import LLVMTranslator
from pyrect.regexp import Regexp

class LLVMGREPTranslator(LLVMTranslator):
    """LLVMGREPTranslator
    This class translates a DFA into a grep LLVM module, which can be
    emitted as LLVM-IR and can also be JIT-compiled and executed by itself.
    >>> string = 'def'
    >>> reg = Regexp(string)
    >>> lt = LLVMGREPTranslator(reg)
    >>> lt.translate()
    >>> ret = lt.execute()
    >>> isinstance(ret, GenericValue)
    True
    """

    BASE_DIR = os.path.dirname(os.path.abspath(__file__))

    def __init__(self, regexp):
        LLVMTranslator.__init__(self, regexp)
        # The generated DFA functions are added to a module parsed from this
        # template; emit_driver() expects it to define `llgrep` and
        # `llgrep_with_name`.
        template = os.path.join(self.BASE_DIR, "template", "grep.ll")
        with open(template) as llfile:  # closed even if parsing fails
            self.llvm_module = Module.from_assembly(llfile)
        self.compiled = False
        self.string = regexp.regexp
        # Strings handed one by one to the grep routine in pre_main
        # (presumably input file names -- filled in by the caller).
        self.args = []

    def state_name(self, state_name):
        """Return the LLVM function name used for DFA state `state_name`."""
        return str(state_name)

    def emit_driver(self):
        """Emit the entry points that glue the DFA to the grep template.

        * ``DFA(char*) -> int`` calls the start state's function and returns
          its result.
        * ``pre_main(int) -> void`` calls the template's grep routine once per
          entry of ``self.args`` with (regexp string, argument). With exactly
          one argument ``llgrep`` is used, otherwise ``llgrep_with_name``.

        Sets ``self.regexp_str`` and ``self.main`` (the ``pre_main`` function).
        """
        self.regexp_str = self.new_str_const(self.string)

        dfa = self.llvm_module.get_or_insert_function(
            Type.function(self.int_t, (self.charptr_t,)), "DFA")
        emit = Builder.new(dfa.append_basic_block("entry"))
        # State functions are named through state_name(); resolve the start
        # state the same way so non-string state identifiers work too.
        start = self.llvm_module.get_function_named(
            self.state_name(self.fa.start))
        emit.ret(emit.call(start, (dfa.args[0],)))

        main = self.llvm_module.add_function(
            Type.function(Type.void(), (self.int_t,)), "pre_main")
        emit = Builder.new(main.append_basic_block("entry"))

        if len(self.args) == 1:
            grep = self.llvm_module.get_function_named("llgrep")
        else:
            grep = self.llvm_module.get_function_named("llgrep_with_name")

        for arg in self.args:
            emit.call(grep, (self.gep_first(emit, self.regexp_str),
                             self.gep_first(emit, self.new_str_const(arg))))
        emit.ret_void()

        self.main = main

    def _add_state_function(self, name):
        """Add an ``int name(char* index)`` function to the module and return it."""
        fun = self.llvm_module.add_function(
            Type.function(self.int_t, (self.charptr_t,)), name)
        fun.args[0].name = "index"
        return fun

    def _emit_state(self, state_fun, accepting, transition, state_ref,
                    accept_state, reject_state):
        """Emit the body of one DFA state function.

        An accepting state calls ``accept`` and returns its result. Any other
        state loads the character at ``index``, switches on it, and returns
        the result of calling the next state's function with ``index + 1``.
        Characters without a transition go to ``reject``.
        """
        emit = Builder.new(state_fun.append_basic_block("entry"))

        if accepting:
            emit.ret(emit.call(accept_state, (state_fun.args[0],)))
            return

        cases = dict((self.char_const(case), state_ref[next_state])
                     for case, next_state in transition.iteritems())

        char = emit.load(state_fun.args[0])
        next_ptr = emit.gep(state_fun.args[0], (self.const_one,))
        if self.debug:
            # Trace the current state's name and the character it reads.
            state_label = self.gep_first(emit,
                                         self.new_str_const(state_fun.name))
            self.emit_call_printf(emit, self.debug_str, state_label,
                                  char, char)

        # Default branch of the switch: unmatched characters are rejected.
        default_bb = state_fun.append_basic_block("default")
        builder = Builder.new(default_bb)
        builder.ret(builder.call(reject_state, (next_ptr,)))

        si = emit.switch(char, default_bb, len(cases))
        for label, (case, next_fun) in enumerate(cases.iteritems()):
            bb = state_fun.append_basic_block("case%d" % label)
            builder = Builder.new(bb)
            builder.ret(builder.call(next_fun, (next_ptr,)))
            si.add_case(case, bb)

    def jitcompile(self):
        """Build one LLVM function per DFA state, emit the driver, and JIT it.

        ``accept`` returns 1 and ``reject`` returns 0. Sets ``self.mp``,
        ``self.ee`` and ``self.compiled``.
        """
        self.debug_str = self.new_str_const("state: %s, arg: %c(int %d)\n")

        # Declare every function before emitting any body, so transitions can
        # call state functions whose bodies are emitted later.
        accept_state = self._add_state_function("accept")
        reject_state = self._add_state_function("reject")
        state_ref = {"accept": accept_state, "reject": reject_state}
        for state in self.fa.transition.iterkeys():
            state_ref[state] = self._add_state_function(self.state_name(state))

        Builder.new(accept_state.append_basic_block("entry")).ret(self.const_one)
        Builder.new(reject_state.append_basic_block("entry")).ret(self.const_zero)

        for state, transition in self.fa.transition.iteritems():
            self._emit_state(state_ref[state], state in self.fa.accepts,
                             transition, state_ref, accept_state, reject_state)

        self.emit_driver()
        self.mp = ModuleProvider.new(self.llvm_module)
        if self.optimize:
            self.do_optimize()
        self.ee = ExecutionEngine.new(self.mp)
        self.compiled = True

    def execute(self):
        """JIT-compile on first use, then run ``pre_main(0)`` and return its GenericValue."""
        if not self.compiled:
            self.jitcompile()
        return self.ee.run_function(self.main,
                                    (GenericValue.int(self.int_t, 0),))

def test():
    """Run the doctests of the ``__main__`` module (this file when run as a script)."""
    from doctest import testmod
    testmod()

# Running this file directly executes its doctests.
if __name__ == "__main__":
    test()