Mercurial > hg > CbC > CbC_llvm
comparison mlir/unittests/Transforms/Canonicalizer.cpp @ 236:c4bab56944e8 llvm-original
LLVM 16
author | kono |
---|---|
date | Wed, 09 Nov 2022 17:45:10 +0900 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
232:70dce7da266c | 236:c4bab56944e8 |
---|---|
1 //===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===// | |
2 // | |
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |
4 // See https://llvm.org/LICENSE.txt for license information. | |
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |
6 // | |
7 //===----------------------------------------------------------------------===// | |
8 | |
9 #include "mlir/IR/PatternMatch.h" | |
10 #include "mlir/Parser/Parser.h" | |
11 #include "mlir/Pass/PassManager.h" | |
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | |
13 #include "mlir/Transforms/Passes.h" | |
14 #include "gtest/gtest.h" | |
15 | |
16 using namespace mlir; | |
17 | |
18 namespace { | |
19 | |
20 struct DisabledPattern : public RewritePattern { | |
21 DisabledPattern(MLIRContext *context) | |
22 : RewritePattern("test.foo", /*benefit=*/0, context, | |
23 /*generatedNamed=*/{}) { | |
24 setDebugName("DisabledPattern"); | |
25 } | |
26 | |
27 LogicalResult matchAndRewrite(Operation *op, | |
28 PatternRewriter &rewriter) const override { | |
29 if (op->getNumResults() != 1) | |
30 return failure(); | |
31 rewriter.eraseOp(op); | |
32 return success(); | |
33 } | |
34 }; | |
35 | |
36 struct EnabledPattern : public RewritePattern { | |
37 EnabledPattern(MLIRContext *context) | |
38 : RewritePattern("test.foo", /*benefit=*/0, context, | |
39 /*generatedNamed=*/{}) { | |
40 setDebugName("EnabledPattern"); | |
41 } | |
42 | |
43 LogicalResult matchAndRewrite(Operation *op, | |
44 PatternRewriter &rewriter) const override { | |
45 if (op->getNumResults() == 1) | |
46 return failure(); | |
47 rewriter.eraseOp(op); | |
48 return success(); | |
49 } | |
50 }; | |
51 | |
52 struct TestDialect : public Dialect { | |
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect) | |
54 | |
55 static StringRef getDialectNamespace() { return "test"; } | |
56 | |
57 TestDialect(MLIRContext *context) | |
58 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) { | |
59 allowUnknownOperations(); | |
60 } | |
61 | |
62 void getCanonicalizationPatterns(RewritePatternSet &results) const override { | |
63 results.add<DisabledPattern, EnabledPattern>(results.getContext()); | |
64 } | |
65 }; | |
66 | |
67 TEST(CanonicalizerTest, TestDisablePatterns) { | |
68 MLIRContext context; | |
69 context.getOrLoadDialect<TestDialect>(); | |
70 PassManager mgr(&context); | |
71 mgr.addPass( | |
72 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"})); | |
73 | |
74 const char *const code = R"mlir( | |
75 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32) | |
76 %1 = "test.foo"() {sym_name = "B"} : () -> (f32) | |
77 )mlir"; | |
78 | |
79 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context); | |
80 ASSERT_TRUE(succeeded(mgr.run(*module))); | |
81 | |
82 EXPECT_TRUE(module->lookupSymbol("B")); | |
83 EXPECT_FALSE(module->lookupSymbol("A")); | |
84 } | |
85 | |
86 } // end anonymous namespace |