//===- Traits.cpp - Common op traits shared by dialects ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;
bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
                                        ArrayRef<int64_t> shape2,
                                        SmallVectorImpl<int64_t> &resultShape) {
  // To compute the resulting broadcasted shape, we compare the operand shapes
  // element-wise: starting with the trailing dimensions and working our way
  // backward. Two dimensions are compatible when
  //   1. they are equal, or
  //   2. one of them is 1.
  // The result shape has the maximum of the two inputs at every dimension
  // index.
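  //
  // For example, broadcasting [3, 1, 7] with [8, 1] compares trailing
  // dimensions first: 7 vs. 1 gives 7, 1 vs. 8 gives 8, and the leftover 3 is
  // carried over unchanged, yielding the result shape [3, 8, 7].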

  resultShape.clear();
  if (shape1.size() > shape2.size()) {
    std::copy(shape1.begin(), shape1.end(), std::back_inserter(resultShape));
  } else {
    std::copy(shape2.begin(), shape2.end(), std::back_inserter(resultShape));
  }

  auto i1 = shape1.rbegin(), e1 = shape1.rend();
  auto i2 = shape2.rbegin(), e2 = shape2.rend();
  auto iR = resultShape.rbegin();

  // Check that each dimension is consistent.
  for (; i1 != e1 && i2 != e2; ++i1, ++i2, ++iR) {
    if (*i1 == -1 || *i2 == -1) {
      // One or both of the dimensions is unknown. Follow TensorFlow behavior:
      // - If either dimension is greater than 1, we assume that the program is
      //   correct, and the other dimension will be broadcast to match it.
      // - If either dimension is 1, the other dimension is the output.
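      // For example: (?, 5) gives 5, (?, 1) gives ?, and (?, ?) gives ?.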
      if (*i1 > 1) {
        *iR = *i1;
      } else if (*i2 > 1) {
        *iR = *i2;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else if (*i2 == 1) {
        *iR = *i1;
      } else {
        *iR = -1;
      }
    } else {
      if (*i1 == *i2 || *i2 == 1) {
        *iR = *i1;
      } else if (*i1 == 1) {
        *iR = *i2;
      } else {
        // This dimension of the two operand types is incompatible.
        resultShape.clear();
        return false;
      }
    }
  }

  return true;
}

/// Returns the shape of the given type. Scalars are considered to have a
/// shape with zero dimensions.
static ArrayRef<int64_t> getShape(Type type) {
  if (auto sType = type.dyn_cast<ShapedType>())
    return sType.getShape();
  return {};
}

/// Returns the result broadcast composition type from the two given types by
/// following NumPy broadcast semantics. The returned type may have a dynamic
/// shape if either of the input types has a dynamic shape. Returns a null
/// type if the two given types are not broadcast-compatible.
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise it is required that the element types of
/// type1 and type2 are the same, and this common element type will be used as
/// the resultant element type.
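///
/// For example, following NumPy semantics, broadcasting tensor<5x1xf32> with
/// tensor<4xf32> yields tensor<5x4xf32>.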
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
                                       Type elementType) {
  // If elementType is not specified, then use the common element type of the
  // inputs, or fail if there is no common element type.
  if (!elementType) {
    elementType = getElementTypeOrSelf(type1);
    if (elementType != getElementTypeOrSelf(type2))
      return {};
  }

  // If one of the types is an unranked tensor, then the other type may not be
  // a vector, and the result is an unranked tensor.
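  // For example, broadcasting tensor<*xf32> with tensor<4x4xf32> yields
  // tensor<*xf32>.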
  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
    if (type1.isa<VectorType>() || type2.isa<VectorType>())
      return {};
    return UnrankedTensorType::get(elementType);
  }

  // Returns the type kind if the given type is a vector or ranked tensor type.
  // Returns llvm::None otherwise.
  auto getCompositeTypeKind = [](Type type) -> Optional<StandardTypes::Kind> {
    if (type.isa<VectorType>() || type.isa<RankedTensorType>())
      return static_cast<StandardTypes::Kind>(type.getKind());
    return llvm::None;
  };

  // Make sure the composite kinds, if present, are consistent.
  auto compositeKind1 = getCompositeTypeKind(type1);
  auto compositeKind2 = getCompositeTypeKind(type2);
  Optional<StandardTypes::Kind> resultCompositeKind;

  if (compositeKind1 && compositeKind2) {
    // Disallow mixing vector and tensor.
    if (compositeKind1 != compositeKind2)
      return {};
    resultCompositeKind = compositeKind1;
  } else if (compositeKind1) {
    resultCompositeKind = compositeKind1;
  } else if (compositeKind2) {
    resultCompositeKind = compositeKind2;
  }

  // Broadcast the shapes of the two types.
  SmallVector<int64_t, 4> resultShape;
  if (!getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
    return {};

  // Compose the final broadcasted type.
  if (resultCompositeKind == StandardTypes::Vector)
    return VectorType::get(resultShape, elementType);
  if (resultCompositeKind == StandardTypes::RankedTensor)
    return RankedTensorType::get(resultShape, elementType);
  return elementType;
}


/// Returns a tuple corresponding to whether the range has tensor or vector
/// type.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
  return std::make_tuple(
      llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
      llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
}

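/// Returns true if the two given shapes have the same rank and each pair of
/// dimensions either matches or at least one of the two is dynamic (-1).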
static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
                                ArrayRef<int64_t> shape2) {
  auto isCompatible = [](int64_t dim1, int64_t dim2) {
    return dim1 == dim2 || dim1 == -1 || dim2 == -1;
  };
  if (shape1.size() != shape2.size())
    return false;
  for (auto p : llvm::zip(shape1, shape2))
    if (!isCompatible(std::get<0>(p), std::get<1>(p)))
      return false;
  return true;
}

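/// Returns a quoted 'dim0xdim1x...' string for the given shape, e.g. '2x3x4'.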
static std::string getShapeString(ArrayRef<int64_t> shape) {
  // TODO: Replace this with a more uniform way of printing shapes both here
  // and when printing types.
  return std::string(
      formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end())));
}

LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
  // Ensure broadcasting only tensor or only vector types.
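  // For example, an op mixing an operand of type vector<2xf32> with one of
  // type tensor<2xf32> is rejected here.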
  auto operandsHasTensorVectorType =
      hasTensorOrVectorType(op->getOperandTypes());
  auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
  if ((std::get<0>(operandsHasTensorVectorType) ||
       std::get<0>(resultsHasTensorVectorType)) &&
      (std::get<1>(operandsHasTensorVectorType) ||
       std::get<1>(resultsHasTensorVectorType)))
    return op->emitError("cannot broadcast vector with tensor");

  auto rankedOperands = make_filter_range(
      op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

  // If all operands are unranked, then all result shapes are possible.
  if (rankedOperands.empty())
    return success();

  // Compute the broadcasted shape of the operands (which requires that the
  // operands are broadcast compatible). The results need to be broadcast
  // compatible with this shape.
  SmallVector<int64_t, 4> resultShape;
  (void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
                                  resultShape);
  for (auto other : make_early_inc_range(rankedOperands)) {
    SmallVector<int64_t, 4> temp = resultShape;
    if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
      return op->emitOpError("operands don't have broadcast-compatible shapes");
  }

  auto rankedResults = make_filter_range(
      op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });

  // If all of the results are unranked, then no further verification is
  // needed.
  if (rankedResults.empty())
    return success();

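  // Verify that the broadcasted operand shape is a compatible suffix of each
  // ranked result shape; any leading result dimensions beyond the broadcasted
  // rank are left unconstrained.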
  for (auto type : rankedResults) {
    ArrayRef<int64_t> actualSuffix =
        getShape(type).take_back(resultShape.size());
    if (!areCompatibleShapes(actualSuffix, resultShape))
      return op->emitOpError()
             << "result type " << getShapeString(getShape(type))
             << " not broadcast compatible with broadcasted operands' shapes "
             << getShapeString(resultShape);
  }
  return success();
}