Mercurial > hg > CbC > CbC_llvm
view mlir/unittests/Transforms/Canonicalizer.cpp @ 266:00f31e85ec16 default tip
Added tag current for changeset 31d058e83c98
author | Shinji KONO <kono@ie.u-ryukyu.ac.jp> |
---|---|
date | Sat, 14 Oct 2023 10:13:55 +0900 |
parents | c4bab56944e8 |
children |
line wrap: on
line source
//===- DialectConversion.cpp - Dialect conversion unit tests --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/IR/PatternMatch.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "gtest/gtest.h" using namespace mlir; namespace { struct DisabledPattern : public RewritePattern { DisabledPattern(MLIRContext *context) : RewritePattern("test.foo", /*benefit=*/0, context, /*generatedNamed=*/{}) { setDebugName("DisabledPattern"); } LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); rewriter.eraseOp(op); return success(); } }; struct EnabledPattern : public RewritePattern { EnabledPattern(MLIRContext *context) : RewritePattern("test.foo", /*benefit=*/0, context, /*generatedNamed=*/{}) { setDebugName("EnabledPattern"); } LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { if (op->getNumResults() == 1) return failure(); rewriter.eraseOp(op); return success(); } }; struct TestDialect : public Dialect { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) static StringRef getDialectNamespace() { return "test"; } TestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) { allowUnknownOperations(); } void getCanonicalizationPatterns(RewritePatternSet &results) const override { results.add<DisabledPattern, EnabledPattern>(results.getContext()); } }; TEST(CanonicalizerTest, TestDisablePatterns) { MLIRContext context; context.getOrLoadDialect<TestDialect>(); PassManager mgr(&context); mgr.addPass( createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"})); const char *const code = R"mlir( %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32) %1 = "test.foo"() {sym_name = "B"} : () -> (f32) )mlir"; OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context); ASSERT_TRUE(succeeded(mgr.run(*module))); EXPECT_TRUE(module->lookupSymbol("B")); EXPECT_FALSE(module->lookupSymbol("A")); } } // end anonymous namespace