comparison mlir/unittests/Transforms/Canonicalizer.cpp @ 236:c4bab56944e8 llvm-original

LLVM 16
author kono
date Wed, 09 Nov 2022 17:45:10 +0900
parents
children
comparison
equal deleted inserted replaced
232:70dce7da266c 236:c4bab56944e8
1 //===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/Parser/Parser.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13 #include "mlir/Transforms/Passes.h"
14 #include "gtest/gtest.h"
15
16 using namespace mlir;
17
18 namespace {
19
20 struct DisabledPattern : public RewritePattern {
21 DisabledPattern(MLIRContext *context)
22 : RewritePattern("test.foo", /*benefit=*/0, context,
23 /*generatedNamed=*/{}) {
24 setDebugName("DisabledPattern");
25 }
26
27 LogicalResult matchAndRewrite(Operation *op,
28 PatternRewriter &rewriter) const override {
29 if (op->getNumResults() != 1)
30 return failure();
31 rewriter.eraseOp(op);
32 return success();
33 }
34 };
35
36 struct EnabledPattern : public RewritePattern {
37 EnabledPattern(MLIRContext *context)
38 : RewritePattern("test.foo", /*benefit=*/0, context,
39 /*generatedNamed=*/{}) {
40 setDebugName("EnabledPattern");
41 }
42
43 LogicalResult matchAndRewrite(Operation *op,
44 PatternRewriter &rewriter) const override {
45 if (op->getNumResults() == 1)
46 return failure();
47 rewriter.eraseOp(op);
48 return success();
49 }
50 };
51
52 struct TestDialect : public Dialect {
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
54
55 static StringRef getDialectNamespace() { return "test"; }
56
57 TestDialect(MLIRContext *context)
58 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
59 allowUnknownOperations();
60 }
61
62 void getCanonicalizationPatterns(RewritePatternSet &results) const override {
63 results.add<DisabledPattern, EnabledPattern>(results.getContext());
64 }
65 };
66
67 TEST(CanonicalizerTest, TestDisablePatterns) {
68 MLIRContext context;
69 context.getOrLoadDialect<TestDialect>();
70 PassManager mgr(&context);
71 mgr.addPass(
72 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
73
74 const char *const code = R"mlir(
75 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
76 %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
77 )mlir";
78
79 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
80 ASSERT_TRUE(succeeded(mgr.run(*module)));
81
82 EXPECT_TRUE(module->lookupSymbol("B"));
83 EXPECT_FALSE(module->lookupSymbol("A"));
84 }
85
86 } // end anonymous namespace