diff mlir/lib/Transforms/PipelineDataTransfer.cpp @ 150:1d019706d866

LLVM10
author anatofuz
date Thu, 13 Feb 2020 15:10:13 +0900
parents
children 0572611fdcc8
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp	Thu Feb 13 15:10:13 2020 +0900
@@ -0,0 +1,378 @@
+//===- PipelineDataTransfer.cpp --- Pass for pipelining data movement ---*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to pipeline data transfers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Analysis/AffineAnalysis.h"
+#include "mlir/Analysis/LoopAnalysis.h"
+#include "mlir/Analysis/Utils.h"
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Utils.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#define DEBUG_TYPE "affine-pipeline-data-transfer"
+
+using namespace mlir;
+
+namespace {
+
+struct PipelineDataTransfer : public FunctionPass<PipelineDataTransfer> {
+  void runOnFunction() override;
+  void runOnAffineForOp(AffineForOp forOp);
+
+  std::vector<AffineForOp> forOps;
+};
+
+} // end anonymous namespace
+
+/// Creates a pass to pipeline explicit movement of data across levels of the
+/// memory hierarchy.
+std::unique_ptr<OpPassBase<FuncOp>> mlir::createPipelineDataTransferPass() {
+  return std::make_unique<PipelineDataTransfer>();
+}
+
+// Returns the position of the tag memref operand given a DMA operation.
+// Temporary utility: will be replaced when DmaStart/DmaFinish abstract ops
+// are added.  TODO(b/117228571)
+static unsigned getTagMemRefPos(Operation &dmaInst) {
+  assert(isa<AffineDmaStartOp>(dmaInst) || isa<AffineDmaWaitOp>(dmaInst));
+  if (auto dmaStartOp = dyn_cast<AffineDmaStartOp>(dmaInst)) {
+    return dmaStartOp.getTagMemRefOperandIndex();
+  }
+  // First operand for a dma finish operation.
+  return 0;
+}
+
+/// Doubles the buffer of the supplied memref on the specified 'affine.for'
+/// operation by adding a leading dimension of size two to the memref.
+/// Replaces all uses of the old memref by the new one while indexing the newly
+/// added dimension by the loop IV of the specified 'affine.for' operation
+/// modulo 2. Returns false if such a replacement cannot be performed.
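+///
+/// A minimal sketch of the rewrite (assuming a unit-step loop with static
+/// shapes; the actual index computed is '(iv floordiv step) mod 2'):
+///
+///   %buf = alloc() : memref<256xf32>
+///   affine.for %i = 0 to %N {
+///     ... %buf[%j] ...
+///   }
+///
+/// becomes
+///
+///   %buf2 = alloc() : memref<2x256xf32>
+///   affine.for %i = 0 to %N {
+///     %p = affine.apply affine_map<(d0) -> (d0 mod 2)>(%i)
+///     ... %buf2[%p, %j] ...
+///   }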
+static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
+  auto *forBody = forOp.getBody();
+  OpBuilder bInner(forBody, forBody->begin());
+
+  // Doubles the shape with a leading dimension extent of 2.
+  auto doubleShape = [&](MemRefType oldMemRefType) -> MemRefType {
+    // Add the leading dimension in the shape for the double buffer.
+    ArrayRef<int64_t> oldShape = oldMemRefType.getShape();
+    SmallVector<int64_t, 4> newShape(1 + oldMemRefType.getRank());
+    newShape[0] = 2;
+    std::copy(oldShape.begin(), oldShape.end(), newShape.begin() + 1);
+    return MemRefType::Builder(oldMemRefType)
+        .setShape(newShape)
+        .setAffineMaps({});
+  };
+
+  auto oldMemRefType = oldMemRef.getType().cast<MemRefType>();
+  auto newMemRefType = doubleShape(oldMemRefType);
+
+  // The double buffer is allocated right before 'forInst'.
+  auto *forInst = forOp.getOperation();
+  OpBuilder bOuter(forInst);
+  // Put together alloc operands for any dynamic dimensions of the memref.
+  SmallVector<Value, 4> allocOperands;
+  unsigned dynamicDimCount = 0;
+  for (auto dimSize : oldMemRefType.getShape()) {
+    if (dimSize == -1)
+      allocOperands.push_back(bOuter.create<DimOp>(forInst->getLoc(), oldMemRef,
+                                                   dynamicDimCount++));
+  }
+
+  // Create and place the alloc right before the 'affine.for' operation.
+  Value newMemRef =
+      bOuter.create<AllocOp>(forInst->getLoc(), newMemRefType, allocOperands);
+
+  // Create the '(iv floordiv step) mod 2' value to index the leading
+  // dimension: dividing the IV by the step recovers the iteration number, so
+  // the index alternates between 0 and 1 across iterations.
+  auto d0 = bInner.getAffineDimExpr(0);
+  int64_t step = forOp.getStep();
+  auto modTwoMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
+                                  {d0.floorDiv(step) % 2});
+  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
+                                                 forOp.getInductionVar());
+
+  // replaceAllMemRefUsesWith will succeed unless the body of forOp has
+  // non-dereferencing uses of the memref (deallocs are fine though).
+  if (failed(replaceAllMemRefUsesWith(
+          oldMemRef, newMemRef,
+          /*extraIndices=*/{ivModTwoOp},
+          /*indexRemap=*/AffineMap(),
+          /*extraOperands=*/{},
+          /*symbolOperands=*/{},
+          /*domInstFilter=*/&*forOp.getBody()->begin()))) {
+    LLVM_DEBUG(
+        forOp.emitError("memref replacement for double buffering failed"));
+    ivModTwoOp.erase();
+    return false;
+  }
+  // Insert the dealloc op right after the for loop.
+  bOuter.setInsertionPointAfter(forInst);
+  bOuter.create<DeallocOp>(forInst->getLoc(), newMemRef);
+
+  return true;
+}
+
+/// Collects all 'affine.for' ops in the function and pipelines the data
+/// transfers within each.
+void PipelineDataTransfer::runOnFunction() {
+  // Do a post order walk so that inner loop DMAs are processed first. This is
+  // necessary since 'affine.for' operations nested within would otherwise
+  // become invalid (erased) when the outer loop is pipelined (the pipelined one
+  // gets deleted and replaced by a prologue, a new steady-state loop and an
+  // epilogue).
+  forOps.clear();
+  getFunction().walk([&](AffineForOp forOp) { forOps.push_back(forOp); });
+  for (auto forOp : forOps)
+    runOnAffineForOp(forOp);
+}
+
+// Checks whether the tags of the DMA start op and the DMA wait op match.
+static bool checkTagMatch(AffineDmaStartOp startOp, AffineDmaWaitOp waitOp) {
+  if (startOp.getTagMemRef() != waitOp.getTagMemRef())
+    return false;
+  auto startIndices = startOp.getTagIndices();
+  auto waitIndices = waitOp.getTagIndices();
+  // Both of these have the same number of indices since they correspond to the
+  // same tag memref.
+  for (auto it = startIndices.begin(), wIt = waitIndices.begin(),
+            e = startIndices.end();
+       it != e; ++it, ++wIt) {
+    // Keep it simple for now, just checking if the indices match.
+    // TODO(mlir-team): this would in general need to check that there is no
+    // intervening write to the same tag location, i.e., memory last
+    // write/data flow analysis. This is however sufficient for now since the
+    // DMA generation pass (or the input to it) will always have start/wait
+    // ops with matching tags (same SSA operand indices).
+    if (*it != *wIt)
+      return false;
+  }
+  return true;
+}
+
+// Identify matching DMA start/finish operations to overlap computation with.
+static void findMatchingStartFinishInsts(
+    AffineForOp forOp,
+    SmallVectorImpl<std::pair<Operation *, Operation *>> &startWaitPairs) {
+
+  // Collect outgoing DMA operations - needed to check for dependences below.
+  SmallVector<AffineDmaStartOp, 4> outgoingDmaOps;
+  for (auto &op : *forOp.getBody()) {
+    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+    if (dmaStartOp && dmaStartOp.isSrcMemorySpaceFaster())
+      outgoingDmaOps.push_back(dmaStartOp);
+  }
+
+  SmallVector<Operation *, 4> dmaStartInsts, dmaFinishInsts;
+  for (auto &op : *forOp.getBody()) {
+    // Collect DMA finish operations.
+    if (isa<AffineDmaWaitOp>(op)) {
+      dmaFinishInsts.push_back(&op);
+      continue;
+    }
+    auto dmaStartOp = dyn_cast<AffineDmaStartOp>(op);
+    if (!dmaStartOp)
+      continue;
+
+    // Only DMAs incoming into higher memory spaces are pipelined for now.
+    // TODO(bondhugula): handle outgoing DMA pipelining.
+    if (!dmaStartOp.isDestMemorySpaceFaster())
+      continue;
+
+    // Check for dependences with outgoing DMAs; this check is conservative.
+    // TODO(andydavis,bondhugula): use dependence analysis to check for
+    // dependences between an incoming and an outgoing DMA in the same
+    // iteration.
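+    // E.g., if an outgoing DMA writes a buffer back to a slow-memory memref
+    // %A and this incoming DMA reads from %A, shifting one relative to the
+    // other could read stale data, so we conservatively bail out.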
+    auto it = outgoingDmaOps.begin();
+    for (; it != outgoingDmaOps.end(); ++it) {
+      if (it->getDstMemRef() == dmaStartOp.getSrcMemRef())
+        break;
+    }
+    if (it != outgoingDmaOps.end())
+      continue;
+
+    // We only double buffer if the buffer is not live out of the loop.
+    auto memref = dmaStartOp.getOperand(dmaStartOp.getFasterMemPos());
+    bool escapingUses = false;
+    for (auto *user : memref.getUsers()) {
+      // We can double buffer regardless of deallocs outside the loop.
+      if (isa<DeallocOp>(user))
+        continue;
+      if (!forOp.getBody()->findAncestorOpInBlock(*user)) {
+        LLVM_DEBUG(llvm::dbgs()
+                       << "can't pipeline: buffer is live out of loop\n";);
+        escapingUses = true;
+        break;
+      }
+    }
+    if (!escapingUses)
+      dmaStartInsts.push_back(&op);
+  }
+
+  // For each start operation, we look for a matching finish operation.
+  for (auto *dmaStartInst : dmaStartInsts) {
+    for (auto *dmaFinishInst : dmaFinishInsts) {
+      if (checkTagMatch(cast<AffineDmaStartOp>(dmaStartInst),
+                        cast<AffineDmaWaitOp>(dmaFinishInst))) {
+        startWaitPairs.push_back({dmaStartInst, dmaFinishInst});
+        break;
+      }
+    }
+  }
+}
+
+/// Overlaps DMA transfers with computation in the given loop. If successful,
+/// 'forOp' is deleted, and a prologue, a new pipelined loop, and an epilogue
+/// are inserted in its place.
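+///
+/// A sketch of the resulting structure (assuming a single incoming DMA, a
+/// unit step, and a double-buffered %buf):
+///
+///   affine.dma_start ... %buf[0, ...]           // prologue: first transfer
+///   affine.for %i = 1 to %N {                   // steady state
+///     affine.dma_start ... %buf[%i mod 2, ...]  // fetch for iteration %i
+///     affine.dma_wait ...                       // data for %i - 1 is ready
+///     ... compute on %buf[(%i - 1) mod 2, ...] ...
+///   }
+///   affine.dma_wait ...                         // epilogue: last wait
+///   ... compute on %buf[(%N - 1) mod 2, ...] ...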
+void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
+  auto mayBeConstTripCount = getConstantTripCount(forOp);
+  if (!mayBeConstTripCount.hasValue()) {
+    LLVM_DEBUG(forOp.emitRemark("won't pipeline due to unknown trip count"));
+    return;
+  }
+
+  SmallVector<std::pair<Operation *, Operation *>, 4> startWaitPairs;
+  findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+  if (startWaitPairs.empty()) {
+    LLVM_DEBUG(forOp.emitRemark("No dma start/finish pairs\n"));
+    return;
+  }
+
+  // Double the buffers for the higher (faster) memory space memrefs.
+  // Identify the memrefs to replace by scanning through all DMA start
+  // operations. A DMA start operation has two memrefs - the one in the
+  // higher level of the memory hierarchy is the one to double buffer.
+  // TODO(bondhugula): check whether double-buffering is even necessary.
+  // TODO(bondhugula): make this work with different layouts: we assume here
+  // that the dimension being added for double buffering is the outermost
+  // dimension.
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartInst = pair.first;
+    Value oldMemRef = dmaStartInst->getOperand(
+        cast<AffineDmaStartOp>(dmaStartInst).getFasterMemPos());
+    if (!doubleBuffer(oldMemRef, forOp)) {
+      // Normally, double buffering should not fail because we already checked
+      // that there are no uses outside the loop.
+      LLVM_DEBUG(llvm::dbgs() << "double buffering failed for "
+                              << *dmaStartInst << "\n";);
+      // IR still valid and semantically correct.
+      return;
+    }
+    // If the old memref has no more uses, remove its 'dead' alloc if it was
+    // alloc'ed. (Note: DMA buffers are rarely function live-in, but a 'dim'
+    // operation could have been used on a dynamically shaped memref in order
+    // to create the double buffer above.)
+    // '-canonicalize' does this in a more general way, but we handle the
+    // simple/common case here anyway so that the output / test cases look
+    // clear.
+    if (auto *allocInst = oldMemRef.getDefiningOp()) {
+      if (oldMemRef.use_empty()) {
+        allocInst->erase();
+      } else if (oldMemRef.hasOneUse()) {
+        if (auto dealloc = dyn_cast<DeallocOp>(*oldMemRef.user_begin())) {
+          dealloc.erase();
+          allocInst->erase();
+        }
+      }
+    }
+  }
+
+  // Double the buffers for tag memrefs.
+  for (auto &pair : startWaitPairs) {
+    auto *dmaFinishInst = pair.second;
+    Value oldTagMemRef =
+        dmaFinishInst->getOperand(getTagMemRefPos(*dmaFinishInst));
+    if (!doubleBuffer(oldTagMemRef, forOp)) {
+      LLVM_DEBUG(llvm::dbgs() << "tag double buffering failed\n";);
+      return;
+    }
+    // If the old tag has no uses or a single dealloc use, remove it.
+    // (canonicalization handles more complex cases).
+    if (auto *tagAllocInst = oldTagMemRef.getDefiningOp()) {
+      if (oldTagMemRef.use_empty()) {
+        tagAllocInst->erase();
+      } else if (oldTagMemRef.hasOneUse()) {
+        if (auto dealloc = dyn_cast<DeallocOp>(*oldTagMemRef.user_begin())) {
+          dealloc.erase();
+          tagAllocInst->erase();
+        }
+      }
+    }
+  }
+
+  // Double buffering would have invalidated all the old DMA start/wait insts.
+  startWaitPairs.clear();
+  findMatchingStartFinishInsts(forOp, startWaitPairs);
+
+  // Store the shift for each operation, including the AffineApplyOps feeding
+  // the DMA operands, for later lookup.
+  DenseMap<Operation *, unsigned> instShiftMap;
+  for (auto &pair : startWaitPairs) {
+    auto *dmaStartInst = pair.first;
+    assert(isa<AffineDmaStartOp>(dmaStartInst));
+    instShiftMap[dmaStartInst] = 0;
+    // Set shifts for DMA start op's affine operand computation slices to 0.
+    SmallVector<AffineApplyOp, 4> sliceOps;
+    mlir::createAffineComputationSlice(dmaStartInst, &sliceOps);
+    if (!sliceOps.empty()) {
+      for (auto sliceOp : sliceOps) {
+        instShiftMap[sliceOp.getOperation()] = 0;
+      }
+    } else {
+      // If a slice wasn't created, the affine.apply ops reachable from its
+      // operands are the ones that go with it.
+      SmallVector<Operation *, 4> affineApplyInsts;
+      SmallVector<Value, 4> operands(dmaStartInst->getOperands());
+      getReachableAffineApplyOps(operands, affineApplyInsts);
+      for (auto *op : affineApplyInsts) {
+        instShiftMap[op] = 0;
+      }
+    }
+  }
+  // Everything else (including compute ops and DMA finish ops) is shifted by
+  // one.
+  for (auto &op : *forOp.getBody()) {
+    if (instShiftMap.find(&op) == instShiftMap.end()) {
+      instShiftMap[&op] = 1;
+    }
+  }
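+  // E.g. (a sketch), for a body of the form [dma_start, dma_wait, compute],
+  // the shift vector is [0, 1, 1]: the skew then overlaps iteration %i's
+  // dma_start with iteration %i - 1's wait and compute.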
+
+  // Gather the shifts stored in the map, in body order.
+  std::vector<uint64_t> shifts(forOp.getBody()->getOperations().size());
+  unsigned s = 0;
+  for (auto &op : *forOp.getBody()) {
+    assert(instShiftMap.find(&op) != instShiftMap.end());
+    shifts[s++] = instShiftMap[&op];
+
+    // Tag operations with their shifts, for debugging purposes.
+    LLVM_DEBUG({
+      OpBuilder b(&op);
+      op.setAttr("shift", b.getI64IntegerAttr(shifts[s - 1]));
+    });
+  }
+
+  if (!isInstwiseShiftValid(forOp, shifts)) {
+    // Violates dependences.
+    LLVM_DEBUG(llvm::dbgs() << "Shifts invalid - unexpected\n";);
+    return;
+  }
+
+  if (failed(instBodySkew(forOp, shifts))) {
+    LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
+    return;
+  }
+}
+
+static PassRegistration<PipelineDataTransfer> pass(
+    "affine-pipeline-data-transfer",
+    "Pipeline non-blocking data transfers between explicitly managed levels of "
+    "the memory hierarchy");