diff mlir/unittests/TableGen/PassGenTest.cpp @ 236:c4bab56944e8 llvm-original

LLVM 16
author kono
date Wed, 09 Nov 2022 17:45:10 +0900
parents 5f17cb93ff66
children
line wrap: on
line diff
--- a/mlir/unittests/TableGen/PassGenTest.cpp	Wed Jul 21 10:27:27 2021 +0900
+++ b/mlir/unittests/TableGen/PassGenTest.cpp	Wed Nov 09 17:45:10 2022 +0900
@@ -7,31 +7,34 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
 
 #include "gmock/gmock.h"
 
-std::unique_ptr<mlir::Pass> createTestPass(int v = 0);
+std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v = 0);
 
+#define GEN_PASS_DECL
 #define GEN_PASS_REGISTRATION
 #include "PassGenTest.h.inc"
 
-#define GEN_PASS_CLASSES
+#define GEN_PASS_DEF_TESTPASS
+#define GEN_PASS_DEF_TESTPASSWITHOPTIONS
+#define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
 #include "PassGenTest.h.inc"
 
-struct TestPass : public TestPassBase<TestPass> {
-  explicit TestPass(int v) : extraVal(v) {}
+struct TestPass : public impl::TestPassBase<TestPass> {
+  using TestPassBase::TestPassBase;
 
   void runOnOperation() override {}
 
   std::unique_ptr<mlir::Pass> clone() const {
     return TestPassBase<TestPass>::clone();
   }
-
-  int extraVal;
 };
 
-std::unique_ptr<mlir::Pass> createTestPass(int v) {
-  return std::make_unique<TestPass>(v);
+TEST(PassGenTest, defaultGeneratedConstructor) {
+  std::unique_ptr<mlir::Pass> pass = createTestPass();
+  EXPECT_TRUE(pass.get() != nullptr);
 }
 
 TEST(PassGenTest, PassClone) {
@@ -41,7 +44,75 @@
     return static_cast<const TestPass *>(pass.get());
   };
 
-  const auto origPass = createTestPass(10);
+  const auto origPass = createTestPass();
+  const auto clonePass = unwrap(origPass)->clone();
+
+  EXPECT_TRUE(clonePass.get() != nullptr);
+  EXPECT_TRUE(origPass.get() != clonePass.get());
+}
+
+struct TestPassWithOptions
+    : public impl::TestPassWithOptionsBase<TestPassWithOptions> {
+  using TestPassWithOptionsBase::TestPassWithOptionsBase;
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassWithOptionsBase<TestPassWithOptions>::clone();
+  }
+
+  int getTestOption() const { return testOption; }
+
+  llvm::ArrayRef<int64_t> getTestListOption() const { return testListOption; }
+};
+
+TEST(PassGenTest, PassOptions) {
+  mlir::MLIRContext context;
+
+  TestPassWithOptionsOptions options;
+  options.testOption = 57;
+
+  llvm::SmallVector<int64_t, 2> testListOption = {1, 2};
+  options.testListOption = testListOption;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPassWithOptions *>(pass.get());
+  };
+
+  const auto pass = createTestPassWithOptions(options);
+
+  EXPECT_EQ(unwrap(pass)->getTestOption(), 57);
+  EXPECT_EQ(unwrap(pass)->getTestListOption()[0], 1);
+  EXPECT_EQ(unwrap(pass)->getTestListOption()[1], 2);
+}
+
+struct TestPassWithCustomConstructor
+    : public impl::TestPassWithCustomConstructorBase<
+          TestPassWithCustomConstructor> {
+  explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
+
+  void runOnOperation() override {}
+
+  std::unique_ptr<mlir::Pass> clone() const {
+    return TestPassWithCustomConstructorBase<
+        TestPassWithCustomConstructor>::clone();
+  }
+
+  unsigned int extraVal = 23;
+};
+
+std::unique_ptr<mlir::Pass> createTestPassWithCustomConstructor(int v) {
+  return std::make_unique<TestPassWithCustomConstructor>(v);
+}
+
+TEST(PassGenTest, PassCloneWithCustomConstructor) {
+  mlir::MLIRContext context;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPassWithCustomConstructor *>(pass.get());
+  };
+
+  const auto origPass = createTestPassWithCustomConstructor(10);
   const auto clonePass = unwrap(origPass)->clone();
 
   EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);