150
|
1 //===- toyc.cpp - The Toy Compiler ----------------------------------------===//
|
|
2 //
|
|
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
4 // See https://llvm.org/LICENSE.txt for license information.
|
|
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
6 //
|
|
7 //===----------------------------------------------------------------------===//
|
|
8 //
|
|
9 // This file implements the entry point for the Toy compiler.
|
|
10 //
|
|
11 //===----------------------------------------------------------------------===//
|
|
12
|
|
13 #include "toy/Dialect.h"
|
|
14 #include "toy/MLIRGen.h"
|
|
15 #include "toy/Parser.h"
|
|
16 #include "toy/Passes.h"
|
|
17
|
|
18 #include "mlir/Analysis/Verifier.h"
|
|
19 #include "mlir/ExecutionEngine/ExecutionEngine.h"
|
|
20 #include "mlir/ExecutionEngine/OptUtils.h"
|
|
21 #include "mlir/IR/MLIRContext.h"
|
|
22 #include "mlir/IR/Module.h"
|
|
23 #include "mlir/InitAllDialects.h"
|
|
24 #include "mlir/Parser.h"
|
|
25 #include "mlir/Pass/Pass.h"
|
|
26 #include "mlir/Pass/PassManager.h"
|
|
27 #include "mlir/Target/LLVMIR.h"
|
|
28 #include "mlir/Transforms/Passes.h"
|
|
29
|
|
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
|
|
38
|
|
39 using namespace toy;
|
|
40 namespace cl = llvm::cl;
|
|
41
|
|
42 static cl::opt<std::string> inputFilename(cl::Positional,
|
|
43 cl::desc("<input toy file>"),
|
|
44 cl::init("-"),
|
|
45 cl::value_desc("filename"));
|
|
46
|
|
47 namespace {
|
|
48 enum InputType { Toy, MLIR };
|
|
49 }
|
|
50 static cl::opt<enum InputType> inputType(
|
|
51 "x", cl::init(Toy), cl::desc("Decided the kind of output desired"),
|
|
52 cl::values(clEnumValN(Toy, "toy", "load the input file as a Toy source.")),
|
|
53 cl::values(clEnumValN(MLIR, "mlir",
|
|
54 "load the input file as an MLIR file")));
|
|
55
|
|
56 namespace {
|
|
57 enum Action {
|
|
58 None,
|
|
59 DumpAST,
|
|
60 DumpMLIR,
|
|
61 DumpMLIRAffine,
|
|
62 DumpMLIRLLVM,
|
|
63 DumpLLVMIR,
|
|
64 RunJIT
|
|
65 };
|
|
66 }
|
|
67 static cl::opt<enum Action> emitAction(
|
|
68 "emit", cl::desc("Select the kind of output desired"),
|
|
69 cl::values(clEnumValN(DumpAST, "ast", "output the AST dump")),
|
|
70 cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump")),
|
|
71 cl::values(clEnumValN(DumpMLIRAffine, "mlir-affine",
|
|
72 "output the MLIR dump after affine lowering")),
|
|
73 cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm",
|
|
74 "output the MLIR dump after llvm lowering")),
|
|
75 cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")),
|
|
76 cl::values(
|
|
77 clEnumValN(RunJIT, "jit",
|
|
78 "JIT the code and run it by invoking the main function")));
|
|
79
|
|
80 static cl::opt<bool> enableOpt("opt", cl::desc("Enable optimizations"));
|
|
81
|
|
82 /// Returns a Toy AST resulting from parsing the file or a nullptr on error.
|
|
83 std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
|
|
84 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
|
85 llvm::MemoryBuffer::getFileOrSTDIN(filename);
|
|
86 if (std::error_code ec = fileOrErr.getError()) {
|
|
87 llvm::errs() << "Could not open input file: " << ec.message() << "\n";
|
|
88 return nullptr;
|
|
89 }
|
|
90 auto buffer = fileOrErr.get()->getBuffer();
|
|
91 LexerBuffer lexer(buffer.begin(), buffer.end(), std::string(filename));
|
|
92 Parser parser(lexer);
|
|
93 return parser.parseModule();
|
|
94 }
|
|
95
|
|
96 int loadMLIR(mlir::MLIRContext &context, mlir::OwningModuleRef &module) {
|
|
97 // Handle '.toy' input to the compiler.
|
|
98 if (inputType != InputType::MLIR &&
|
|
99 !llvm::StringRef(inputFilename).endswith(".mlir")) {
|
|
100 auto moduleAST = parseInputFile(inputFilename);
|
|
101 if (!moduleAST)
|
|
102 return 6;
|
|
103 module = mlirGen(context, *moduleAST);
|
|
104 return !module ? 1 : 0;
|
|
105 }
|
|
106
|
|
107 // Otherwise, the input is '.mlir'.
|
|
108 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
|
|
109 llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
|
|
110 if (std::error_code EC = fileOrErr.getError()) {
|
|
111 llvm::errs() << "Could not open input file: " << EC.message() << "\n";
|
|
112 return -1;
|
|
113 }
|
|
114
|
|
115 // Parse the input mlir.
|
|
116 llvm::SourceMgr sourceMgr;
|
|
117 sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
|
|
118 module = mlir::parseSourceFile(sourceMgr, &context);
|
|
119 if (!module) {
|
|
120 llvm::errs() << "Error can't load file " << inputFilename << "\n";
|
|
121 return 3;
|
|
122 }
|
|
123 return 0;
|
|
124 }
|
|
125
|
|
126 int loadAndProcessMLIR(mlir::MLIRContext &context,
|
|
127 mlir::OwningModuleRef &module) {
|
|
128 if (int error = loadMLIR(context, module))
|
|
129 return error;
|
|
130
|
|
131 mlir::PassManager pm(&context);
|
|
132 // Apply any generic pass manager command line options and run the pipeline.
|
|
133 applyPassManagerCLOptions(pm);
|
|
134
|
|
135 // Check to see what granularity of MLIR we are compiling to.
|
|
136 bool isLoweringToAffine = emitAction >= Action::DumpMLIRAffine;
|
|
137 bool isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM;
|
|
138
|
|
139 if (enableOpt || isLoweringToAffine) {
|
|
140 // Inline all functions into main and then delete them.
|
|
141 pm.addPass(mlir::createInlinerPass());
|
|
142 pm.addPass(mlir::createSymbolDCEPass());
|
|
143
|
|
144 // Now that there is only one function, we can infer the shapes of each of
|
|
145 // the operations.
|
|
146 mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
|
147 optPM.addPass(mlir::createCanonicalizerPass());
|
|
148 optPM.addPass(mlir::toy::createShapeInferencePass());
|
|
149 optPM.addPass(mlir::createCanonicalizerPass());
|
|
150 optPM.addPass(mlir::createCSEPass());
|
|
151 }
|
|
152
|
|
153 if (isLoweringToAffine) {
|
|
154 // Partially lower the toy dialect with a few cleanups afterwards.
|
|
155 pm.addPass(mlir::toy::createLowerToAffinePass());
|
|
156
|
|
157 mlir::OpPassManager &optPM = pm.nest<mlir::FuncOp>();
|
|
158 optPM.addPass(mlir::createCanonicalizerPass());
|
|
159 optPM.addPass(mlir::createCSEPass());
|
|
160
|
|
161 // Add optimizations if enabled.
|
|
162 if (enableOpt) {
|
|
163 optPM.addPass(mlir::createLoopFusionPass());
|
|
164 optPM.addPass(mlir::createMemRefDataFlowOptPass());
|
|
165 }
|
|
166 }
|
|
167
|
|
168 if (isLoweringToLLVM) {
|
|
169 // Finish lowering the toy IR to the LLVM dialect.
|
|
170 pm.addPass(mlir::toy::createLowerToLLVMPass());
|
|
171 }
|
|
172
|
|
173 if (mlir::failed(pm.run(*module)))
|
|
174 return 4;
|
|
175 return 0;
|
|
176 }
|
|
177
|
|
178 int dumpAST() {
|
|
179 if (inputType == InputType::MLIR) {
|
|
180 llvm::errs() << "Can't dump a Toy AST when the input is MLIR\n";
|
|
181 return 5;
|
|
182 }
|
|
183
|
|
184 auto moduleAST = parseInputFile(inputFilename);
|
|
185 if (!moduleAST)
|
|
186 return 1;
|
|
187
|
|
188 dump(*moduleAST);
|
|
189 return 0;
|
|
190 }
|
|
191
|
|
192 int dumpLLVMIR(mlir::ModuleOp module) {
|
|
193 auto llvmModule = mlir::translateModuleToLLVMIR(module);
|
|
194 if (!llvmModule) {
|
|
195 llvm::errs() << "Failed to emit LLVM IR\n";
|
|
196 return -1;
|
|
197 }
|
|
198
|
|
199 // Initialize LLVM targets.
|
|
200 llvm::InitializeNativeTarget();
|
|
201 llvm::InitializeNativeTargetAsmPrinter();
|
|
202 mlir::ExecutionEngine::setupTargetTriple(llvmModule.get());
|
|
203
|
|
204 /// Optionally run an optimization pipeline over the llvm module.
|
|
205 auto optPipeline = mlir::makeOptimizingTransformer(
|
|
206 /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
|
|
207 /*targetMachine=*/nullptr);
|
|
208 if (auto err = optPipeline(llvmModule.get())) {
|
|
209 llvm::errs() << "Failed to optimize LLVM IR " << err << "\n";
|
|
210 return -1;
|
|
211 }
|
|
212 llvm::errs() << *llvmModule << "\n";
|
|
213 return 0;
|
|
214 }
|
|
215
|
|
216 int runJit(mlir::ModuleOp module) {
|
|
217 // Initialize LLVM targets.
|
|
218 llvm::InitializeNativeTarget();
|
|
219 llvm::InitializeNativeTargetAsmPrinter();
|
|
220
|
|
221 // An optimization pipeline to use within the execution engine.
|
|
222 auto optPipeline = mlir::makeOptimizingTransformer(
|
|
223 /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0,
|
|
224 /*targetMachine=*/nullptr);
|
|
225
|
|
226 // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
|
|
227 // the module.
|
|
228 auto maybeEngine = mlir::ExecutionEngine::create(module, optPipeline);
|
|
229 assert(maybeEngine && "failed to construct an execution engine");
|
|
230 auto &engine = maybeEngine.get();
|
|
231
|
|
232 // Invoke the JIT-compiled function.
|
|
233 auto invocationResult = engine->invoke("main");
|
|
234 if (invocationResult) {
|
|
235 llvm::errs() << "JIT invocation failed\n";
|
|
236 return -1;
|
|
237 }
|
|
238
|
|
239 return 0;
|
|
240 }
|
|
241
|
|
242 int main(int argc, char **argv) {
|
|
243 mlir::registerAllDialects();
|
|
244 mlir::registerPassManagerCLOptions();
|
|
245 cl::ParseCommandLineOptions(argc, argv, "toy compiler\n");
|
|
246
|
|
247 if (emitAction == Action::DumpAST)
|
|
248 return dumpAST();
|
|
249
|
|
250 // If we aren't dumping the AST, then we are compiling with/to MLIR.
|
|
251
|
|
252 // Register our Dialect with MLIR.
|
|
253 mlir::registerDialect<mlir::toy::ToyDialect>();
|
|
254
|
|
255 mlir::MLIRContext context;
|
|
256 mlir::OwningModuleRef module;
|
|
257 if (int error = loadAndProcessMLIR(context, module))
|
|
258 return error;
|
|
259
|
|
260 // If we aren't exporting to non-mlir, then we are done.
|
|
261 bool isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM;
|
|
262 if (isOutputingMLIR) {
|
|
263 module->dump();
|
|
264 return 0;
|
|
265 }
|
|
266
|
|
267 // Check to see if we are compiling to LLVM IR.
|
|
268 if (emitAction == Action::DumpLLVMIR)
|
|
269 return dumpLLVMIR(*module);
|
|
270
|
|
271 // Otherwise, we must be running the jit.
|
|
272 if (emitAction == Action::RunJIT)
|
|
273 return runJit(*module);
|
|
274
|
|
275 llvm::errs() << "No action specified (parsing only?), use -emit=<action>\n";
|
|
276 return -1;
|
|
277 }
|