236
|
1 //===- Canonicalizer.cpp - Canonicalizer unit tests -----------------------===//
|
|
2 //
|
|
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
4 // See https://llvm.org/LICENSE.txt for license information.
|
|
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
6 //
|
|
7 //===----------------------------------------------------------------------===//
|
|
8
|
|
9 #include "mlir/IR/PatternMatch.h"
|
|
10 #include "mlir/Parser/Parser.h"
|
|
11 #include "mlir/Pass/PassManager.h"
|
|
12 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
13 #include "mlir/Transforms/Passes.h"
|
|
14 #include "gtest/gtest.h"
|
|
15
|
|
16 using namespace mlir;
|
|
17
|
|
18 namespace {
|
|
19
|
|
20 struct DisabledPattern : public RewritePattern {
|
|
21 DisabledPattern(MLIRContext *context)
|
|
22 : RewritePattern("test.foo", /*benefit=*/0, context,
|
|
23 /*generatedNamed=*/{}) {
|
|
24 setDebugName("DisabledPattern");
|
|
25 }
|
|
26
|
|
27 LogicalResult matchAndRewrite(Operation *op,
|
|
28 PatternRewriter &rewriter) const override {
|
|
29 if (op->getNumResults() != 1)
|
|
30 return failure();
|
|
31 rewriter.eraseOp(op);
|
|
32 return success();
|
|
33 }
|
|
34 };
|
|
35
|
|
36 struct EnabledPattern : public RewritePattern {
|
|
37 EnabledPattern(MLIRContext *context)
|
|
38 : RewritePattern("test.foo", /*benefit=*/0, context,
|
|
39 /*generatedNamed=*/{}) {
|
|
40 setDebugName("EnabledPattern");
|
|
41 }
|
|
42
|
|
43 LogicalResult matchAndRewrite(Operation *op,
|
|
44 PatternRewriter &rewriter) const override {
|
|
45 if (op->getNumResults() == 1)
|
|
46 return failure();
|
|
47 rewriter.eraseOp(op);
|
|
48 return success();
|
|
49 }
|
|
50 };
|
|
51
|
|
52 struct TestDialect : public Dialect {
|
|
53 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialect)
|
|
54
|
|
55 static StringRef getDialectNamespace() { return "test"; }
|
|
56
|
|
57 TestDialect(MLIRContext *context)
|
|
58 : Dialect(getDialectNamespace(), context, TypeID::get<TestDialect>()) {
|
|
59 allowUnknownOperations();
|
|
60 }
|
|
61
|
|
62 void getCanonicalizationPatterns(RewritePatternSet &results) const override {
|
|
63 results.add<DisabledPattern, EnabledPattern>(results.getContext());
|
|
64 }
|
|
65 };
|
|
66
|
|
67 TEST(CanonicalizerTest, TestDisablePatterns) {
|
|
68 MLIRContext context;
|
|
69 context.getOrLoadDialect<TestDialect>();
|
|
70 PassManager mgr(&context);
|
|
71 mgr.addPass(
|
|
72 createCanonicalizerPass(GreedyRewriteConfig(), {"DisabledPattern"}));
|
|
73
|
|
74 const char *const code = R"mlir(
|
|
75 %0:2 = "test.foo"() {sym_name = "A"} : () -> (i32, i32)
|
|
76 %1 = "test.foo"() {sym_name = "B"} : () -> (f32)
|
|
77 )mlir";
|
|
78
|
|
79 OwningOpRef<ModuleOp> module = parseSourceString<ModuleOp>(code, &context);
|
|
80 ASSERT_TRUE(succeeded(mgr.run(*module)));
|
|
81
|
|
82 EXPECT_TRUE(module->lookupSymbol("B"));
|
|
83 EXPECT_FALSE(module->lookupSymbol("A"));
|
|
84 }
|
|
85
|
|
86 } // end anonymous namespace
|