Mercurial > hg > CbC > CbC_llvm
diff mlir/lib/Transforms/OpStats.cpp @ 236:c4bab56944e8 llvm-original
LLVM 16
author | kono |
---|---|
date | Wed, 09 Nov 2022 17:45:10 +0900 |
parents | 79ff65ed7e25 |
children |
line wrap: on
line diff
--- a/mlir/lib/Transforms/OpStats.cpp Wed Jul 21 10:27:27 2021 +0900 +++ b/mlir/lib/Transforms/OpStats.cpp Wed Nov 09 17:45:10 2022 +0900 @@ -6,20 +6,29 @@ // //===----------------------------------------------------------------------===// -#include "PassDetail.h" +#include "mlir/Transforms/Passes.h" + #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" -#include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMap.h" #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" +namespace mlir { +#define GEN_PASS_DEF_PRINTOPSTATS +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + using namespace mlir; namespace { -struct PrintOpStatsPass : public PrintOpStatsBase<PrintOpStatsPass> { - explicit PrintOpStatsPass(raw_ostream &os = llvm::errs()) : os(os) {} +struct PrintOpStatsPass : public impl::PrintOpStatsBase<PrintOpStatsPass> { + explicit PrintOpStatsPass(raw_ostream &os) : os(os) {} + + explicit PrintOpStatsPass(raw_ostream &os, bool printAsJSON) : os(os) { + this->printAsJSON = printAsJSON; + } // Prints the resultant operation statistics post iterating over the module. void runOnOperation() override; @@ -27,6 +36,9 @@ // Print summary of op stats. void printSummary(); + // Print symmary of op stats in JSON. + void printSummaryInJSON(); + private: llvm::StringMap<int64_t> opCount; raw_ostream &os; @@ -37,8 +49,12 @@ opCount.clear(); // Compute the operation statistics for the currently visited operation. - getOperation()->walk([&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); - printSummary(); + getOperation()->walk( + [&](Operation *op) { ++opCount[op->getName().getStringRef()]; }); + if (printAsJSON) { + printSummaryInJSON(); + } else + printSummary(); } void PrintOpStatsPass::printSummary() { @@ -55,16 +71,15 @@ }; // Compute the largest dialect and operation name. - StringRef dialectName, opName; size_t maxLenOpName = 0, maxLenDialect = 0; for (const auto &key : sorted) { - std::tie(dialectName, opName) = splitOperationName(key); + auto [dialectName, opName] = splitOperationName(key); maxLenDialect = std::max(maxLenDialect, dialectName.size()); maxLenOpName = std::max(maxLenOpName, opName.size()); } for (const auto &key : sorted) { - std::tie(dialectName, opName) = splitOperationName(key); + auto [dialectName, opName] = splitOperationName(key); // Left-align the names (aligning on the dialect) and right-align the count // below. The alignment is for readability and does not affect CSV/FileCheck @@ -80,6 +95,28 @@ } } -std::unique_ptr<Pass> mlir::createPrintOpStatsPass() { - return std::make_unique<PrintOpStatsPass>(); +void PrintOpStatsPass::printSummaryInJSON() { + SmallVector<StringRef, 64> sorted(opCount.keys()); + llvm::sort(sorted); + + os << "{\n"; + + for (unsigned i = 0, e = sorted.size(); i != e; ++i) { + const auto &key = sorted[i]; + os << " \"" << key << "\" : " << opCount[key]; + if (i != e - 1) + os << ",\n"; + else + os << "\n"; + } + os << "}\n"; } + +std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os) { + return std::make_unique<PrintOpStatsPass>(os); +} + +std::unique_ptr<Pass> mlir::createPrintOpStatsPass(raw_ostream &os, + bool printAsJSON) { + return std::make_unique<PrintOpStatsPass>(os, printAsJSON); +}