//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Value getSlice(OpBuilder &b, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Value>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Value {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
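
// For reference, the operand lists printed above take the following textual
// form (operand names and types are illustrative, not from any specific test):
//   ins(%input : tensor<4x8xf32>) outs(%init : tensor<4x8xf32>)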

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operandSegmentSizes",
                       // See generated code in
                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, /*isUnsignedCast=*/false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, /*isUnsignedCast=*/true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
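
// Illustrative example of the fold above (SSA names hypothetical): a self copy
// such as
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// is simply erased; the tensor variant is replaced by its input value instead.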

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};
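
// Illustrative fold performed by the pattern above (names hypothetical):
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %r = tensor.collapse_shape %fill [[0, 1]]
//          : tensor<4x8xf32> into tensor<32xf32>
// becomes a tensor.collapse_shape of %init followed by a linalg.fill producing
// the reshaped type, so the fill applies directly to the reshaped tensor.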

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
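
// Illustrative fold performed by the pattern above (names hypothetical):
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %pad = tensor.pad %fill low[1] high[1] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<6xf32>
// becomes a single linalg.fill of %cst into a tensor.empty of the padded
// shape, since padding with the fill value is itself just a fill.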

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};
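
// Illustrative fold performed by the pattern above (names and constants
// hypothetical): given
//   %pad = tensor.pad %src low[1] high[1] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<6xf32>
//   %r = tensor.insert_slice %pad into %fill[2] [6] [1]
//          : tensor<6xf32> into tensor<8xf32>
// where %fill is produced by linalg.fill of %cst, the pad is dropped and %src
// is inserted directly at offset 2 + 1 = 3 (the old offset plus the low
// padding), since the padded area already matches the fill value.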

/// Fold tensor.extract(linalg.fill(<input>)) into <input>
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if tensor input of tensor.extract op is the result of a linalg.fill
    // op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get scalar input operand of linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace tensor.extract op with scalar value used to fill the tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
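
// Illustrative fold performed by the pattern above (names hypothetical):
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%init : tensor<4xf32>) -> tensor<4xf32>
//   %e = tensor.extract %fill[%i] : tensor<4xf32>
// is replaced by %cst directly, since every element of the filled tensor
// equals the fill value.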

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};
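
// Illustrative fold performed by the pattern above (names and tile sizes
// hypothetical):
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%init : tensor<16x16xf32>) -> tensor<16x16xf32>
//   %pack = tensor.pack %fill inner_dims_pos = [0, 1] inner_tiles = [8, 8]
//             into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
// becomes a single linalg.fill of %cst directly into %dest, since packing a
// uniformly filled tensor just fills the packed layout.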

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
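
// Illustrative fold performed by the first case above (names hypothetical):
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%init : tensor<8xf32>) -> tensor<8xf32>
//   %copy = linalg.copy ins(%fill : tensor<8xf32>)
//             outs(%out : tensor<8xf32>) -> tensor<8xf32>
// becomes linalg.fill ins(%cst) outs(%out), bypassing the copy. The second
// case rewrites a copy whose destination comes from a fill to copy into the
// fill's own output instead.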

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
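
// Illustrative fold performed by the pattern above (names hypothetical):
// transposing a filled tensor is itself just a fill,
//   %fill = linalg.fill ins(%cst : f32)
//             outs(%a : tensor<4x8xf32>) -> tensor<4x8xf32>
//   %t = linalg.transpose ins(%fill : tensor<4x8xf32>)
//          outs(%b : tensor<8x4xf32>) permutation = [1, 0]
// becomes linalg.fill ins(%cst) outs(%b).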

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
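
// For reference, the textual form printed above looks like the following
// (maps, names, and types are illustrative):
//   %0 = linalg.generic {indexing_maps = [#map, #map, #map],
//                        iterator_types = ["parallel"]}
//       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//       outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %sum = arith.addf %a, %b : f32
//     linalg.yield %sum : f32
//   } -> tensor<8xf32>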

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, const ValueRange inputOperands,
    ValueRange outputOperands) {
  for (auto operand : inputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
  }
  for (auto operand : outputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), operand,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
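
// Illustrative example of an identity op removed by the pattern above (#id is
// assumed to be the identity affine map; names hypothetical):
//   %0 = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
// is replaced by %in directly (inserting a tensor.cast or
// sparse_tensor.convert when the types differ).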

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands).drop_back());
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1321 | |
// Retrieve the operation from the body, if it is the only one (except the
// yield) and if it takes the same number of arguments as the body has
// block arguments.
// If the initFirst flag is set, we check that the init value takes the first
// position among the payload operands.
1326 | static Operation *findPayloadOp(Block *body, bool initFirst = false) { |
1327 | if (body->getOperations().size() != 2) |
1328 | return nullptr; |
1329 | Operation &payload = body->getOperations().front(); |
1330 | assert(isa<YieldOp>(body->getOperations().back())); |
1331 | |
1332 | if (payload.getNumOperands() == 0 || |
1333 | payload.getNumOperands() != body->getNumArguments()) |
1334 | return nullptr; |
1335 | if (initFirst) { |
1336 | // check init |
    if (payload.getOperands().back() != body->getArgument(0))
1338 | return nullptr; |
1339 | // check rest |
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1342 | if (bbArg != operand) |
1343 | return nullptr; |
1344 | } |
1345 | } else { |
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
1348 | if (bbArg != operand) |
1349 | return nullptr; |
1350 | } |
1351 | } |
1352 | return &payload; |
1353 | } |
1354 | |
static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1356 | SmallVector<StringRef> elidedAttrs; |
1357 | std::string attrToElide; |
1358 | p << " { " << payloadOp->getName().getStringRef(); |
1359 | for (const auto &attr : payloadOp->getAttrs()) { |
1360 | auto fastAttr = |
1361 | llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue()); |
1362 | if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) { |
1363 | attrToElide = attr.getName().str(); |
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
1370 | } |
1371 | |
1372 | void MapOp::print(OpAsmPrinter &p) { |
1373 | Block *mapper = getBody(); |
1374 | Operation *payloadOp = findPayloadOp(mapper); |
1375 | if (payloadOp) { |
1376 | printShortForm(p, payloadOp); |
1377 | } |
1378 | |
1379 | printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); |
1380 | p.printOptionalAttrDict((*this)->getAttrs()); |
1381 | |
1382 | if (!payloadOp) { |
1383 | // Print region if the payload op was not detected. |
1384 | p.increaseIndent(); |
1385 | p.printNewline(); |
1386 | p << "(" ; |
1387 | llvm::interleaveComma(mapper->getArguments(), p, |
1388 | [&](auto arg) { p.printRegionArgument(arg); }); |
1389 | p << ") " ; |
1390 | |
1391 | p.printRegion(getMapper(), /*printEntryBlockArgs=*/false); |
1392 | p.decreaseIndent(); |
1393 | } |
1394 | } |
1395 | |
1396 | LogicalResult MapOp::verify() { |
1397 | auto *bodyBlock = getBody(); |
1398 | auto blockArgs = bodyBlock->getArguments(); |
1399 | |
1400 | // Checks if the number of `inputs` match the arity of the `mapper` region. |
1401 | if (getInputs().size() != blockArgs.size()) |
1402 | return emitOpError() << "expects number of operands to match the arity of " |
1403 | "mapper, but got: " |
1404 | << getInputs().size() << " and " << blockArgs.size(); |
1405 | |
1406 | // The parameters of mapper should all match the element type of inputs. |
1407 | for (const auto &[bbArgType, inputArg] : |
1408 | llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) { |
1409 | auto inputElemType = |
1410 | llvm::cast<ShapedType>(inputArg.getType()).getElementType(); |
1411 | if (bbArgType != inputElemType) { |
1412 | return emitOpError() << "expected element type of input " << inputElemType |
1413 | << " to match bbArg type " << bbArgType; |
1414 | } |
1415 | } |
1416 | |
1417 | // The shape of each input must match the shape of the output. |
1418 | auto outputShape = getInit().getType().getShape(); |
1419 | for (Type inputArgType : TypeRange{getInputs()}) { |
1420 | auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape(); |
1421 | if (inputElemShape != outputShape) { |
1422 | return emitOpError() << "expected shape of input (" << inputElemShape |
1423 | << ") to match shape of output (" << outputShape |
1424 | << ")" ; |
1425 | } |
1426 | } |
1427 | |
1428 | return success(); |
1429 | } |
1430 | |
1431 | SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() { |
1432 | int64_t rank = getInit().getType().getRank(); |
1433 | return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel); |
1434 | } |
1435 | |
1436 | ArrayAttr MapOp::getIndexingMaps() { |
1437 | Builder builder(getContext()); |
1438 | int64_t rank = getInit().getType().getRank(); |
1439 | int64_t numIndexingMaps = getOperands().size(); |
1440 | return builder.getAffineMapArrayAttr(SmallVector<AffineMap>( |
1441 | numIndexingMaps, builder.getMultiDimIdentityMap(rank))); |
1442 | } |
1443 | |
1444 | void MapOp::getEffects( |
1445 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1446 | &effects) { |
1447 | getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), |
1448 | getDpsInits()); |
1449 | } |
1450 | |
1451 | //===----------------------------------------------------------------------===// |
1452 | // ReduceOp |
1453 | //===----------------------------------------------------------------------===// |
1454 | |
1455 | void ReduceOp::getAsmBlockArgumentNames(Region ®ion, |
1456 | OpAsmSetValueNameFn setNameFn) { |
1457 | for (Value v : getRegionInputArgs()) |
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
1461 | } |
1462 | |
1463 | void ReduceOp::getAsmResultNames( |
1464 | function_ref<void(Value, StringRef)> setNameFn) { |
1465 | if (!getResults().empty()) |
    setNameFn(getResults().front(), "reduced");
1467 | } |
1468 | |
1469 | void ReduceOp::build( |
1470 | OpBuilder &builder, OperationState &result, ValueRange inputs, |
1471 | ValueRange inits, ArrayRef<int64_t> dimensions, |
1472 | function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild, |
1473 | ArrayRef<NamedAttribute> attributes) { |
1474 | build(builder, result, TypeRange{}, inputs, inits, dimensions); |
1475 | result.addAttributes(attributes); |
1476 | |
1477 | // Add output types for `RankedTensorType` output arguments. |
1478 | for (Value init : inits) { |
1479 | Type initType = init.getType(); |
1480 | if (llvm::isa<RankedTensorType>(initType)) |
1481 | result.addTypes(initType); |
1482 | } |
1483 | |
1484 | if (bodyBuild) |
1485 | buildGenericRegion(builder, result.location, *result.regions.front(), |
1486 | inputs, inits, bodyBuild); |
1487 | } |
1488 | |
1489 | SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() { |
1490 | int64_t inputRank = |
1491 | llvm::cast<ShapedType>(getInputs()[0].getType()).getRank(); |
1492 | SmallVector<utils::IteratorType> iteratorTypes(inputRank, |
1493 | utils::IteratorType::parallel); |
1494 | for (int64_t reductionDim : getDimensions()) |
1495 | iteratorTypes[reductionDim] = utils::IteratorType::reduction; |
1496 | return iteratorTypes; |
1497 | } |
1498 | |
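// For example (illustrative): reducing a rank-3 input over dimension 1
// produces an identity map per input and a result map that drops d1:
//   inputs: (d0, d1, d2) -> (d0, d1, d2)
//   inits:  (d0, d1, d2) -> (d0, d2)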
1499 | ArrayAttr ReduceOp::getIndexingMaps() { |
1500 | int64_t inputRank = |
1501 | llvm::cast<ShapedType>(getInputs()[0].getType()).getRank(); |
1502 | SmallVector<AffineMap> affineMaps( |
1503 | getNumDpsInputs(), |
1504 | AffineMap::getMultiDimIdentityMap(inputRank, getContext())); |
1505 | AffineMap resultMap = |
1506 | AffineMap::getMultiDimIdentityMap(inputRank, getContext()) |
1507 | .dropResults(getDimensions()); |
1508 | for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i) |
1509 | affineMaps.push_back(resultMap); |
1510 | return Builder(getContext()).getAffineMapArrayAttr(affineMaps); |
1511 | } |
1512 | |
1513 | void ReduceOp::getEffects( |
1514 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1515 | &effects) { |
1516 | getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), |
1517 | getDpsInits()); |
1518 | } |
1519 | |
1520 | static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, |
1521 | NamedAttrList &attributes, |
1522 | StringRef attributeName) { |
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1524 | return failure(); |
1525 | |
1526 | attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{})); |
1527 | return success(); |
1528 | } |
1529 | |
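// For reference, a sketch of the two assembly forms accepted here (types and
// names are illustrative). Short form, naming the combiner inline; note that
// the init value is passed to the payload op first (see the initFirst flag
// above):
//
//   %reduced = linalg.reduce { arith.addf }
//       ins(%input : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>)
//       dimensions = [1]
//
// Long form, with an explicit combiner region:
//
//   %reduced = linalg.reduce
//       ins(%input : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>)
//       dimensions = [1]
//       (%in: f32, %acc: f32) {
//         %0 = arith.addf %in, %acc : f32
//         linalg.yield %0 : f32
//       }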
1530 | ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) { |
1531 | std::optional<OperationName> payloadOpName; |
1532 | NamedAttrList payloadOpAttrs; |
1533 | if (succeeded(parser.parseOptionalLBrace())) { |
1534 | FailureOr<OperationName> operationName = parser.parseCustomOperationName(); |
1535 | if (failed(operationName)) |
1536 | return failure(); |
1537 | if (parser.parseOptionalAttrDict(payloadOpAttrs)) |
1538 | return failure(); |
1539 | payloadOpName = operationName.value(); |
1540 | if (parser.parseRBrace()) |
1541 | return failure(); |
1542 | } |
1543 | |
1544 | if (parseDstStyleOp( |
1545 | parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { |
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1547 | })) |
1548 | return failure(); |
1549 | |
1550 | if (payloadOpName.has_value()) { |
1551 | addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs, |
1552 | ArrayRef(result.operands), /*initFirst=*/true); |
1553 | } else { |
1554 | SmallVector<OpAsmParser::Argument> regionArgs; |
1555 | if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren, |
1556 | /*allowType=*/true, /*allowAttrs=*/true)) { |
1557 | return failure(); |
1558 | } |
1559 | |
1560 | Region *body = result.addRegion(); |
1561 | if (parser.parseRegion(*body, regionArgs)) |
1562 | return failure(); |
1563 | } |
1564 | |
1565 | return success(); |
1566 | } |
1567 | |
1568 | static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, |
1569 | ArrayRef<int64_t> attributeValue) { |
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1571 | } |
1572 | |
1573 | void ReduceOp::print(OpAsmPrinter &p) { |
1574 | Block *mapper = getBody(); |
1575 | Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); |
1576 | if (payloadOp) { |
1577 | printShortForm(p, payloadOp); |
1578 | } |
1579 | |
1580 | printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); |
1581 | printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); |
1582 | p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); |
1583 | if (!payloadOp) { |
1584 | // Print region if the payload op was not detected. |
1585 | p.increaseIndent(); |
1586 | p.printNewline(); |
1587 | p << "(" ; |
1588 | llvm::interleaveComma(mapper->getArguments(), p, |
1589 | [&](auto arg) { p.printRegionArgument(arg); }); |
1590 | p << ") " ; |
1591 | |
1592 | p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false); |
1593 | p.decreaseIndent(); |
1594 | } |
1595 | } |
1596 | |
1597 | LogicalResult ReduceOp::verify() { |
1598 | ArrayRef<int64_t> dimensionsRef = getDimensions(); |
1599 | |
1600 | for (int64_t i = 1; i < getNumDpsInputs(); ++i) { |
1601 | if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() != |
1602 | llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) { |
1603 | return emitOpError() << "expects all inputs to have the same shapes. " |
1604 | "Shape at input-index " |
1605 | << i |
1606 | << " is not equal to the shape at input-index 0." ; |
1607 | } |
1608 | } |
1609 | for (int64_t i = 1; i < getNumDpsInits(); ++i) { |
1610 | if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() != |
1611 | llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) { |
1612 | return emitOpError() << "expects all outputs to have the same shapes. " |
1613 | "Shape at output-index " |
1614 | << i |
1615 | << " is not equal to the shape at output-index 0." ; |
1616 | } |
1617 | } |
1618 | auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType()); |
1619 | auto initType = llvm::cast<ShapedType>(getInits()[0].getType()); |
1620 | |
1621 | DenseSet<int64_t> dimensionsToReduce; |
1622 | for (int64_t dimension : dimensionsRef) { |
1623 | if (dimension < 0 || dimension >= inputType.getRank()) { |
1624 | return emitOpError() |
1625 | << "dimensions for reduction should be in the range [0, " |
             << inputType.getRank() - 1 << "].";
1627 | } |
1628 | dimensionsToReduce.insert(dimension); |
1629 | } |
1630 | |
1631 | auto inputDims = inputType.getShape(); |
1632 | auto initDims = initType.getShape(); |
1633 | |
1634 | // Input dimensions that will be left after the reduction. |
1635 | SmallVector<int64_t> reducedInputDims; |
1636 | for (const auto &en : llvm::enumerate(inputDims)) { |
1637 | if (!dimensionsToReduce.count(en.index())) |
1638 | reducedInputDims.push_back(en.value()); |
1639 | } |
1640 | |
1641 | if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) { |
1642 | return emitOpError() << "number of dimensions after reduction " |
1643 | << reducedInputDims.size() |
1644 | << " doesn't match the init rank " |
1645 | << initType.getRank(); |
1646 | } |
1647 | |
1648 | if (reducedInputDims != initDims) |
1649 | return emitOpError() << "init dimensions [" << initDims |
1650 | << "] doesn't match input dimensions after reduction [" |
1651 | << reducedInputDims << "]" ; |
1652 | |
1653 | Block *block = getBody(); |
1654 | if (block->getNumArguments() != this->getNumOperands()) |
1655 | return emitOpError() |
1656 | << "mismatching number of operands and block arguments" ; |
1657 | |
1658 | // Check that the first block arguments match the element type of the inputs. |
1659 | for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) { |
1660 | Type inputElementType = |
1661 | llvm::cast<ShapedType>(input.getType()).getElementType(); |
1662 | if (inputElementType != bbArg.getType()) |
1663 | return emitOpError() |
1664 | << "input element type " << inputElementType |
1665 | << " does not match corresponding block argument type " |
1666 | << bbArg.getType(); |
1667 | } |
1668 | |
1669 | // Check that the last block arguments match the element type of the outputs. |
1670 | for (auto [output, bbArg] : llvm::zip( |
1671 | getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) { |
1672 | auto outputElementType = |
1673 | llvm::cast<ShapedType>(output.getType()).getElementType(); |
1674 | if (outputElementType != bbArg.getType()) |
1675 | return emitOpError() |
1676 | << "output element type " << outputElementType |
1677 | << " does not match corresponding block argument type " |
1678 | << bbArg.getType(); |
1679 | } |
1680 | return success(); |
1681 | } |
1682 | |
1683 | //===----------------------------------------------------------------------===// |
1684 | // TransposeOp |
1685 | //===----------------------------------------------------------------------===// |
1686 | |
1687 | static void buildIdentityRegion(OpBuilder &builder, Location loc, |
1688 | Region ®ion, ValueRange inputs, |
1689 | ValueRange outputs) { |
1690 | buildGenericRegion(builder, loc, region, inputs, outputs, |
                     [](OpBuilder &b, Location loc, ValueRange args) {
1692 | b.create<linalg::YieldOp>(loc, args[0]); |
1693 | }); |
1694 | } |
1695 | |
1696 | void TransposeOp::build(::mlir::OpBuilder &builder, |
1697 | ::mlir::OperationState &result, Value input, Value init, |
1698 | DenseI64ArrayAttr permutation, |
1699 | ArrayRef<NamedAttribute> attributes) { |
1700 | result.addOperands(input); |
1701 | result.addOperands(init); |
1702 | result.addAttribute(getPermutationAttrName(result.name), permutation); |
1703 | result.addAttributes(attributes); |
1704 | |
1705 | // Add output types for `RankedTensorType` output arguments. |
1706 | Type initType = init.getType(); |
1707 | if (llvm::isa<RankedTensorType>(initType)) |
1708 | result.addTypes(initType); |
1709 | |
1710 | buildIdentityRegion(builder, result.location, *result.addRegion(), input, |
1711 | init); |
1712 | } |
1713 | |
1714 | void TransposeOp::build(::mlir::OpBuilder &builder, |
1715 | ::mlir::OperationState &result, Value input, Value init, |
1716 | ArrayRef<int64_t> permutation, |
1717 | ArrayRef<NamedAttribute> attributes) { |
1718 | build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation), |
1719 | attributes); |
1720 | } |
1721 | |
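// For reference, a sketch of the assembly form accepted here (types are
// illustrative); the identity region is materialized by buildIdentityRegion:
//
//   %transposed = linalg.transpose
//       ins(%input : tensor<16x64xf32>)
//       outs(%init : tensor<64x16xf32>)
//       permutation = [1, 0]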
1722 | ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) { |
1723 | if (failed(parseDstStyleOp( |
1724 | parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { |
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1726 | }))) |
1727 | return failure(); |
1728 | |
1729 | OpBuilder builder(parser.getContext()); |
1730 | buildIdentityRegion(builder, result.location, *result.addRegion(), |
1731 | /*inputs=*/result.operands, |
1732 | /*outputs=*/{}); |
1733 | return success(); |
1734 | } |
1735 | |
1736 | void TransposeOp::getAsmResultNames( |
1737 | function_ref<void(Value, StringRef)> setNameFn) { |
1738 | if (!getResults().empty()) |
    setNameFn(getResults().front(), "transposed");
1740 | } |
1741 | |
1742 | void TransposeOp::print(OpAsmPrinter &p) { |
1743 | printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); |
1744 | printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation()); |
1745 | p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()}); |
1746 | } |
1747 | |
1748 | LogicalResult TransposeOp::verify() { |
1749 | ArrayRef<int64_t> permutationRef = getPermutation(); |
1750 | |
1751 | if (!isPermutationVector(permutationRef)) |
1752 | return emitOpError("permutation is not valid" ); |
1753 | |
1754 | auto inputType = getInput().getType(); |
1755 | auto initType = getInit().getType(); |
1756 | |
1757 | int64_t rank = inputType.getRank(); |
1758 | |
1759 | if (rank != initType.getRank()) |
1760 | return emitOpError() << "input rank " << rank |
1761 | << " does not match init rank " << initType.getRank(); |
1762 | |
1763 | if (rank != static_cast<int64_t>(permutationRef.size())) |
1764 | return emitOpError() << "size of permutation " << permutationRef.size() |
1765 | << " does not match the argument rank " << rank; |
1766 | |
1767 | auto inputDims = inputType.getShape(); |
1768 | auto initDims = initType.getShape(); |
1769 | |
1770 | for (int64_t i = 0; i < rank; ++i) { |
1771 | int64_t inputDim = inputDims[permutationRef[i]]; |
1772 | int64_t initDim = initDims[i]; |
1773 | |
1774 | if (inputDim != initDim) { |
1775 | return emitOpError() << "dim(result, " << i << ") = " << initDim |
1776 | << " doesn't match dim(input, permutation[" << i |
1777 | << "]) = " << inputDim; |
1778 | } |
1779 | } |
1780 | |
1781 | return success(); |
1782 | } |
1783 | |
1784 | SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() { |
1785 | int64_t rank = getInit().getType().getRank(); |
1786 | return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel); |
1787 | } |
1788 | |
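// For example (illustrative): with permutation = [1, 0] the maps are
//   input: (d0, d1) -> (d1, d0)
//   init:  (d0, d1) -> (d0, d1)
// i.e. the iteration space follows the init and the input is read transposed.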
1789 | ArrayAttr TransposeOp::getIndexingMaps() { |
1790 | Builder builder(getContext()); |
1791 | int64_t rank = getInit().getType().getRank(); |
1792 | return builder.getAffineMapArrayAttr( |
1793 | {inversePermutation(AffineMap::getPermutationMap( |
1794 | llvm::to_vector_of<unsigned>(getPermutation()), getContext())), |
1795 | builder.getMultiDimIdentityMap(rank)}); |
1796 | } |
1797 | |
1798 | void TransposeOp::getEffects( |
1799 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1800 | &effects) { |
1801 | getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), |
1802 | getDpsInits()); |
1803 | } |
1804 | |
1805 | LogicalResult TransposeOp::fold(FoldAdaptor adaptor, |
1806 | SmallVectorImpl<OpFoldResult> &result) { |
  // Single dimension transpose.
  if (getPermutation().size() == 1) {
1809 | result.push_back(getInput()); |
1810 | return success(); |
1811 | } |
1812 | // Identity permutation. |
1813 | if (isIdentityPermutation(getPermutation())) { |
1814 | result.push_back(getInput()); |
1815 | return success(); |
1816 | } |
1817 | |
1818 | return failure(); |
1819 | } |
1820 | |
1821 | //===----------------------------------------------------------------------===// |
1822 | // BroadcastOp |
1823 | //===----------------------------------------------------------------------===// |
1824 | |
1825 | void BroadcastOp::build(::mlir::OpBuilder &builder, |
1826 | ::mlir::OperationState &result, Value input, Value init, |
1827 | DenseI64ArrayAttr dimensions, |
1828 | ArrayRef<NamedAttribute> attributes) { |
1829 | result.addOperands(input); |
1830 | result.addOperands(init); |
1831 | result.addAttribute(getDimensionsAttrName(result.name), dimensions); |
1832 | result.addAttributes(attributes); |
1833 | |
1834 | // Add output types for `RankedTensorType` output arguments. |
1835 | Type initType = init.getType(); |
1836 | if (llvm::isa<RankedTensorType>(initType)) |
1837 | result.addTypes(initType); |
1838 | |
1839 | buildIdentityRegion(builder, result.location, *result.addRegion(), input, |
1840 | init); |
1841 | } |
1842 | |
1843 | void BroadcastOp::build(::mlir::OpBuilder &builder, |
1844 | ::mlir::OperationState &result, Value input, Value init, |
1845 | ArrayRef<int64_t> dimensions, |
1846 | ArrayRef<NamedAttribute> attributes) { |
1847 | build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions), |
1848 | attributes); |
1849 | } |
1850 | |
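// For reference, a sketch of the assembly form accepted here (types are
// illustrative); `dimensions` lists the init dimensions added by the
// broadcast:
//
//   %broadcasted = linalg.broadcast
//       ins(%input : tensor<16xf32>)
//       outs(%init : tensor<16x64xf32>)
//       dimensions = [1]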
1851 | ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) { |
1852 | if (failed(parseDstStyleOp( |
1853 | parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) { |
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1855 | }))) |
1856 | return failure(); |
1857 | |
1858 | OpBuilder builder(parser.getContext()); |
1859 | buildIdentityRegion(builder, result.location, *result.addRegion(), |
1860 | /*inputs=*/result.operands, |
1861 | /*outputs=*/{}); |
1862 | return success(); |
1863 | } |
1864 | |
1865 | void BroadcastOp::getAsmResultNames( |
1866 | function_ref<void(Value, StringRef)> setNameFn) { |
1867 | if (!getResults().empty()) |
    setNameFn(getResults().front(), "broadcasted");
1869 | } |
1870 | |
1871 | void BroadcastOp::print(OpAsmPrinter &p) { |
1872 | printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); |
1873 | printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); |
1874 | p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); |
1875 | } |
1876 | |
1877 | LogicalResult BroadcastOp::verify() { |
1878 | ArrayRef<int64_t> dimensionsRef = getDimensions(); |
1879 | |
1880 | auto inputType = getInput().getType(); |
1881 | auto initType = getInit().getType(); |
1882 | |
1883 | int64_t inputRank = inputType.getRank(); |
1884 | int64_t initRank = initType.getRank(); |
1885 | |
1886 | auto inputShape = inputType.getShape(); |
1887 | auto initShape = initType.getShape(); |
1888 | |
1889 | if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank) |
1890 | return emitOpError() << "input rank plus added dimensions does not " |
1891 | "match init rank. input rank: " |
1892 | << inputRank |
1893 | << ", dimensions size: " << dimensionsRef.size() |
1894 | << ", init rank: " << initRank; |
1895 | |
1896 | for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) { |
1897 | if (dim < 0 || dim >= initRank) |
1898 | return emitOpError() << "dimension " << idx |
1899 | << " is out of range. expected range: [0, " |
1900 | << initRank - 1 << "], got: " << dim; |
1901 | } |
1902 | |
1903 | // Mapping from input dims to init dims. |
1904 | SmallVector<int64_t> dimMap; |
1905 | for (auto dim : llvm::seq<int64_t>(0, initRank)) { |
1906 | if (!llvm::is_contained(dimensionsRef, dim)) |
1907 | dimMap.push_back(dim); |
1908 | } |
1909 | |
1910 | for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) { |
    // This dimension is mapped from the input. Init and input dims should
    // match.
1913 | if (inputShape[inputDimIdx] != initShape[initDimIdx]) |
1914 | return emitOpError() << "input dim " << inputDimIdx |
1915 | << " should match init dim " << initDimIdx |
1916 | << ". input: " << inputShape[inputDimIdx] |
1917 | << ", init: " << initShape[initDimIdx]; |
1918 | } |
1919 | |
1920 | return success(); |
1921 | } |
1922 | |
1923 | SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() { |
1924 | int64_t rank = getInit().getType().getRank(); |
1925 | return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel); |
1926 | } |
1927 | |
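// For example (illustrative): broadcasting a rank-1 input into a rank-2 init
// with dimensions = [1] gives
//   input: (d0, d1) -> (d0)
//   init:  (d0, d1) -> (d0, d1)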
1928 | ArrayAttr BroadcastOp::getIndexingMaps() { |
1929 | Builder builder(getContext()); |
1930 | int64_t rank = getInit().getType().getRank(); |
1931 | return builder.getAffineMapArrayAttr( |
1932 | {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()), |
1933 | builder.getMultiDimIdentityMap(rank)}); |
1934 | } |
1935 | |
1936 | void BroadcastOp::getEffects( |
1937 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
1938 | &effects) { |
1939 | getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), |
1940 | getDpsInits()); |
1941 | } |
1942 | |
1943 | void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, |
1944 | MLIRContext *context) { |
1945 | results.add<EraseIdentityLinalgOp<BroadcastOp>>(context); |
1946 | } |
1947 | |
1948 | //===----------------------------------------------------------------------===// |
1949 | // YieldOp |
1950 | //===----------------------------------------------------------------------===// |
1951 | |
1952 | void linalg::YieldOp::print(OpAsmPrinter &p) { |
1953 | if (getNumOperands() > 0) |
1954 | p << ' ' << getOperands(); |
1955 | p.printOptionalAttrDict((*this)->getAttrs()); |
1956 | if (getNumOperands() > 0) |
1957 | p << " : " << getOperandTypes(); |
1958 | } |
1959 | |
1960 | ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) { |
1961 | SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo; |
1962 | SmallVector<Type, 2> types; |
1963 | SMLoc loc = parser.getCurrentLocation(); |
1964 | return failure(parser.parseOperandList(opInfo) || |
1965 | parser.parseOptionalAttrDict(result.attributes) || |
1966 | (!opInfo.empty() && parser.parseColonTypeList(types)) || |
1967 | parser.resolveOperands(opInfo, types, loc, result.operands)); |
1968 | } |
1969 | |
// Check that the number and types of the operands match the element types of
// the LinalgOp interface's shaped operands.
1972 | static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) { |
1973 | if (op.getNumOperands() != linalgOp.getNumDpsInits()) |
1974 | return op.emitOpError("expected number of yield values (" ) |
1975 | << op.getNumOperands() |
1976 | << ") to match the number of inits / outs operands of the enclosing " |
1977 | << "LinalgOp (" << linalgOp.getNumDpsInits() << ")" ; |
1978 | |
1979 | for (OpOperand &opOperand : op->getOpOperands()) { |
1980 | OpOperand *outputOperand = |
1981 | linalgOp.getDpsInitOperand(opOperand.getOperandNumber()); |
1982 | Type elementType = outputOperand->get().getType(); |
1983 | if (isa<MemRefType, RankedTensorType>(elementType)) |
1984 | elementType = getElementTypeOrSelf(outputOperand->get().getType()); |
1985 | if (opOperand.get().getType() != elementType) |
1986 | return op.emitOpError("type of yield operand " ) |
1987 | << (opOperand.getOperandNumber() + 1) << " (" |
1988 | << opOperand.get().getType() << ") doesn't match " |
1989 | << "the element type of the enclosing linalg.generic op (" |
1990 | << elementType << ")" ; |
1991 | } |
1992 | return success(); |
1993 | } |
1994 | |
1995 | LogicalResult linalg::YieldOp::verify() { |
1996 | auto *parentOp = (*this)->getParentOp(); |
1997 | if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) |
1998 | return emitOpError("expected single non-empty parent region" ); |
1999 | |
2000 | if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) |
2001 | return verifyYield(*this, linalgOp); |
2002 | |
2003 | return emitOpError("expected parent op with LinalgOp interface" ); |
2004 | } |
2005 | |
2006 | //===----------------------------------------------------------------------===// |
2007 | // IndexOp |
2008 | //===----------------------------------------------------------------------===// |
2009 | |
2010 | LogicalResult IndexOp::verify() { |
2011 | auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp()); |
2012 | if (!linalgOp) |
2013 | return emitOpError("expected parent op with LinalgOp interface" ); |
2014 | if (linalgOp.getNumLoops() <= getDim()) |
2015 | return emitOpError("expected dim (" ) |
2016 | << getDim() << ") to be lower than the number of loops (" |
2017 | << linalgOp.getNumLoops() << ") of the enclosing LinalgOp" ; |
2018 | return success(); |
2019 | } |
2020 | |
2021 | /////// Operations corresponding to library calls defined with Tablegen //////// |
2022 | |
2023 | #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" |
2024 | |
2025 | #define GET_OP_CLASSES |
2026 | #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" |
2027 | |
2028 | #define GET_OP_CLASSES |
2029 | #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
2030 | |
AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
2034 | if (maybeMap) |
2035 | return *maybeMap; |
2036 | if (rank == 0) |
2037 | return AffineMap::get(context); |
  return AffineMap::getMultiDimIdentityMap(rank, context);
2039 | } |
2040 | |
2041 | SmallVector<AffineExpr, 4> |
2042 | mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, |
2043 | MLIRContext *context) { |
2044 | SmallVector<AffineExpr, 4> res; |
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
2048 | return res; |
2049 | } |
2050 | |
2051 | SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, |
2052 | ArrayRef<AffineExpr> b) { |
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
2057 | } |
2058 | |
2059 | static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) { |
2060 | if (auto memref = llvm::dyn_cast<MemRefType>(t)) { |
2061 | ss << "view" ; |
2062 | for (auto size : memref.getShape()) |
2063 | if (size < 0) |
2064 | ss << "sx" ; |
2065 | else |
2066 | ss << size << "x" ; |
2067 | if (failed(appendMangledType(ss, memref.getElementType()))) |
2068 | return failure(); |
2069 | if (auto as = memref.getMemorySpace()) { |
2070 | if (auto attr = llvm::dyn_cast<IntegerAttr>(as)) |
2071 | ss << "as" << attr.getInt(); |
2072 | else |
2073 | return failure(); |
2074 | } |
2075 | return success(); |
2076 | } |
2077 | if (auto vec = llvm::dyn_cast<VectorType>(t)) { |
2078 | ss << "vector" ; |
2079 | llvm::interleave( |
2080 | vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x" ; }); |
2081 | if (failed(appendMangledType(ss, vec.getElementType()))) |
2082 | return failure(); |
2083 | return success(); |
2084 | } |
2085 | if (t.isSignlessIntOrIndexOrFloat()) { |
2086 | ss << t; |
2087 | return success(); |
2088 | } |
2089 | return failure(); |
2090 | } |
2091 | |
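// For example (illustrative): a linalg.matmul on three memref<?x?xf32>
// operands mangles to
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
// where "view" prefixes each memref operand and "s" marks a dynamic size.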
2092 | std::string mlir::linalg::generateLibraryCallName(Operation *op) { |
2093 | assert(isa<LinalgOp>(op)); |
2094 | std::string name(op->getName().getStringRef().str()); |
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
2111 | } |
2112 | std::string res = ss.str(); |
2113 | res.pop_back(); |
2114 | return res; |
2115 | } |
2116 | |
2117 | //===----------------------------------------------------------------------===// |
2118 | // Canonicalizers and Folders. |
2119 | //===----------------------------------------------------------------------===// |
2120 | |
2121 | namespace { |
2122 | struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> { |
2123 | using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; |
2124 | |
2125 | LogicalResult matchAndRewrite(LinalgOp op, |
2126 | PatternRewriter &rewriter) const override { |
2127 | for (OpOperand &opOperand : op->getOpOperands()) { |
2128 | // Linalg "inputs" may be either tensor or memref type. |
2129 | // tensor<0xelt_type> is a convention that may not always mean |
2130 | // "0 iterations". Only erase in cases we see memref<...x0x...>. |
2131 | auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType()); |
2132 | if (!mt) |
2133 | continue; |
2134 | if (llvm::is_contained(op.getShape(&opOperand), 0)) { |
2135 | rewriter.eraseOp(op); |
2136 | return success(); |
2137 | } |
2138 | } |
2139 | return failure(); |
2140 | } |
2141 | }; |
2142 | |
2143 | /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has |
2144 | /// result that is more static than the linalg op. |
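///
/// A sketch of the rewrite (illustrative types and names):
///
///   %res = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %use = tensor.cast %res : tensor<?x?xf32> to tensor<4x8xf32>
///
/// becomes
///
///   %cast = tensor.cast %init : tensor<?x?xf32> to tensor<4x8xf32>
///   %new = linalg.generic ... outs(%cast : tensor<4x8xf32>) -> tensor<4x8xf32>
///
/// with %new replacing the cast, and a cast back to tensor<?x?xf32> replacing
/// any remaining uses of %res.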
2145 | struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> { |
2146 | using OpRewritePattern<tensor::CastOp>::OpRewritePattern; |
2147 | |
2148 | LogicalResult matchAndRewrite(tensor::CastOp castOp, |
2149 | PatternRewriter &rewriter) const override { |
2150 | if (!tensor::canFoldIntoProducerOp(castOp)) |
2151 | return failure(); |
2152 | |
2153 | auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>(); |
2154 | if (!linalgOp) |
2155 | return failure(); |
2156 | |
    // The cast can be in a conditionally reachable region, in which case
    // folding would generate invalid code. Only conservatively fold ops in
    // the same block for now.
2160 | if (castOp->getBlock() != linalgOp->getBlock()) |
2161 | return failure(); |
2162 | |
2163 | OpBuilder::InsertionGuard guard(rewriter); |
2164 | rewriter.setInsertionPoint(linalgOp); |
2165 | |
2166 | Location loc = linalgOp.getLoc(); |
2167 | OpResult resultValue = llvm::cast<OpResult>(castOp.getSource()); |
2168 | unsigned resultNumber = resultValue.getResultNumber(); |
2169 | auto resultType = |
2170 | llvm::cast<RankedTensorType>(castOp->getResult(0).getType()); |
2171 | // Replace the `outs` for the result with a `tensor.cast`. This cast is now |
2172 | // going from a more dynamic shape to a less dynamic shape. If the producer |
2173 | // for this cast, i.e. producer of the out operand, is also an operation |
2174 | // that folds with tensor.cast consumer (like this pattern), the cast will |
2175 | // continue to propagate as far up the stack as it can go. |
2176 | OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); |
2177 | Value newOperand = |
2178 | rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get()); |
2179 | SmallVector<Value> newOperands = linalgOp.getDpsInputs(); |
2180 | SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(), |
2181 | linalgOp.getDpsInits().end()); |
2182 | outputOperands[resultNumber] = newOperand; |
  newOperands.append(outputOperands.begin(), outputOperands.end());
2184 | |
2185 | SmallVector<Type> resultTypes(linalgOp->result_type_begin(), |
2186 | linalgOp->result_type_end()); |
2187 | resultTypes[resultNumber] = resultType; |
2188 | Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); |
2189 | |
2190 | // Create a tensor.cast operation back to the original type. |
2191 | Value castBack = rewriter.create<tensor::CastOp>( |
2192 | loc, resultValue.getType(), newOp->getResult(resultNumber)); |
2193 | |
2194 | SmallVector<Value> results(newOp->result_begin(), newOp->result_end()); |
2195 | results[resultNumber] = castBack; |
2196 | rewriter.replaceOp(linalgOp, results); |
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2198 | return success(); |
2199 | } |
2200 | }; |
2201 | |
/// For each of the operands in `operands`, this function maps the static sizes
/// of dimensions to their affine dim expressions.
2204 | static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands, |
2205 | llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) { |
2206 | for (OpOperand &opOperand : operands) { |
2207 | if (linalgOp.isScalar(&opOperand)) |
2208 | continue; |
2209 | Value src = opOperand.get(); |
2210 | auto sourceType = llvm::cast<RankedTensorType>(src.getType()); |
2211 | auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand); |
2212 | |
2213 | // Get the `sourceShape` of the `sourceType`. If the operand is a result of |
2214 | // `tensor.cast` operation and source of the cast operation has a static |
2215 | // shape, then assign it to the `sourceShape`. |
2216 | auto *parentOp = src.getDefiningOp(); |
2217 | ArrayRef<int64_t> sourceShape = sourceType.getShape(); |
2218 | if (parentOp) { |
2219 | if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) { |
2220 | Value castSource = castOp.getSource(); |
2221 | auto castSourceType = |
2222 | llvm::dyn_cast<RankedTensorType>(castSource.getType()); |
2223 | if (castSourceType && castSourceType.hasStaticShape()) |
2224 | sourceShape = castSourceType.getShape(); |
2225 | } |
2226 | } |
2227 | |
2228 | // If the source shape's dimension has a static shape, map the affine dim |
2229 | // expression to the known static size. |
2230 | for (unsigned i = 0; i < sourceShape.size(); i++) { |
2231 | if (sourceType.isDynamicDim(i)) |
2232 | continue; |
2233 | if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i))) |
2234 | affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]); |
2235 | } |
2236 | } |
2237 | } |
2238 | |
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change, then `changeNeeded` is left as is and the same operand is added to
/// the `newOperands` list.
2244 | static void createNewOperandWithStaticSizes( |
2245 | Location loc, PatternRewriter &rewriter, OpOperand *opOperand, |
2246 | llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp, |
2247 | SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes, |
2248 | bool &changeNeeded) { |
2249 | Value src = opOperand->get(); |
  newOperands.push_back(src);
2251 | if (linalgOp.isScalar(opOperand)) |
2252 | return; |
2253 | auto sourceType = llvm::cast<RankedTensorType>(src.getType()); |
2254 | Type resultType = sourceType; |
2255 | if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) { |
    resultTypes.push_back(resultType);
2257 | return; |
2258 | } |
2259 | ArrayRef<int64_t> sourceShape = sourceType.getShape(); |
2260 | AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand); |
2261 | SmallVector<int64_t> newShape; |
2262 | // If operand is updated with new shape, `newOperandNeeded` will be |
2263 | // true. |
2264 | bool newOperandNeeded = false; |
2265 | for (unsigned i = 0; i < sourceShape.size(); i++) { |
2266 | int64_t dimShape = sourceShape[i]; |
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
2270 | continue; |
2271 | } |
2272 | // Dimension has a dynamic shape and corresponding affine dim |
2273 | // expression is present in the map. So assign the size for the |
2274 | // given affine dim expression to the dimension. |
    newShape.push_back(affineExprToSize[dimExpr]);
2276 | newOperandNeeded = true; |
2277 | } |
2278 | resultType = RankedTensorType::get(newShape, sourceType.getElementType()); |
2279 | if (newOperandNeeded) { |
2280 | changeNeeded = true; |
2281 | // Get the new operand value given its size and element type by |
2282 | // casting it. |
2283 | Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src); |
2284 | unsigned index = opOperand->getOperandNumber(); |
2285 | newOperands[index] = newOperand; |
2286 | } |
2287 | if (linalgOp.isDpsInit(opOperand)) |
    resultTypes.push_back(resultType);
2289 | } |
2290 | |
2291 | /// Static shapes for the operands can be inferred if any one of the operands |
2292 | /// have a static shape. This can be done by referring to the affine dim |
2293 | /// expressions for the operand. |
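///
/// A sketch (illustrative): given a linalg.generic whose maps are projected
/// permutations, with one input of type tensor<4x8xf32> and another of type
/// tensor<?x?xf32> indexed by the same dims, the dynamic operand is cast to
/// tensor<4x8xf32>, the op is cloned with the refined types, and the results
/// are cast back to the original types for any remaining uses.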
2294 | struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> { |
2295 | using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; |
2296 | |
2297 | LogicalResult matchAndRewrite(LinalgOp linalgOp, |
2298 | PatternRewriter &rewriter) const override { |
2299 | if (!linalgOp.hasPureTensorSemantics()) |
2300 | return failure(); |
2301 | |
2302 | // Maps must be projected permutations. |
2303 | if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) { |
2304 | return !map.isProjectedPermutation(); |
2305 | })) |
2306 | return failure(); |
2307 | |
2308 | // Maps affine dim expressions to the static size of that dimension. |
2309 | llvm::DenseMap<AffineExpr, int64_t> affineExprToSize; |
2310 | Location loc = linalgOp.getLoc(); |
2311 | |
2312 | // For each of the affine dim expression, check if the size is known. If |
2313 | // known add that in the map. |
2314 | populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize); |
2315 | |
2316 | SmallVector<Value> newOperands; |
2317 | SmallVector<Type> resultTypes; |
2318 | |
2319 | // `changeNeeded` is `false` if the operands of `linalgOp` require no |
2320 | // change in their types. |
2321 | bool changeNeeded = false; |
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());
2324 | |
2325 | // Iterate over all the operands and update the static sizes. |
2326 | for (OpOperand &opOperand : linalgOp->getOpOperands()) { |
2327 | createNewOperandWithStaticSizes(loc, rewriter, &opOperand, |
2328 | affineExprToSize, linalgOp, newOperands, |
2329 | resultTypes, changeNeeded); |
2330 | } |
2331 | |
2332 | // If the generic op has all the required static information, no |
2333 | // canonicalization needed. |
2334 | if (!changeNeeded) |
2335 | return failure(); |
2336 | |
2337 | // Clone op. |
2338 | Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); |
2339 | SmallVector<Value> replacements; |
    replacements.reserve(newOp->getNumResults());
2341 | for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) { |
2342 | Value newResult = std::get<1>(it); |
2343 | Value oldResult = std::get<0>(it); |
2344 | Type newType = newResult.getType(); |
2345 | Type oldType = oldResult.getType(); |
2346 | replacements.push_back( |
2347 | (newType != oldType) |
2348 | ? rewriter.create<tensor::CastOp>(loc, oldType, newResult) |
2349 | : newResult); |
2350 | } |
2351 | rewriter.replaceOp(linalgOp, replacements); |
2352 | return success(); |
2353 | } |
2354 | }; |
2355 | |
2356 | } // namespace |
2357 | |
2358 | // All named ops canonicalizers and folders are auto-generated in the |
2359 | // .cpp.inc. |
2360 | |
2361 | //===----------------------------------------------------------------------===// |
2362 | // SoftmaxOp |
2363 | //===----------------------------------------------------------------------===// |
2364 | |
2365 | LogicalResult SoftmaxOp::verify() { |
2366 | ShapedType inputType = getInputOperandType(); |
2367 | ShapedType outputType = getOutputOperandType(); |
2368 | |
2369 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
2370 | ArrayRef<int64_t> outputShape = outputType.getShape(); |
2371 | if (failed(verifyCompatibleShape(inputShape, outputShape))) |
2372 | return emitOpError("incompatible output shape" ); |
2373 | |
2374 | int64_t inputRank = getInputOperandRank(); |
2375 | int64_t dimension = getDimension(); |
2376 | if ((dimension < 0) || (dimension >= inputRank)) |
2377 | return emitOpError("incorrect dimension specified" ); |
2378 | |
2379 | return success(); |
2380 | } |
2381 | |
2382 | SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) { |
2383 | int64_t operandRank = getInputOperandRank(); |
2384 | SmallVector<Range> loopBounds(operandRank); |
2385 | Location loc = getLoc(); |
2386 | Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); |
2387 | Value one = builder.create<arith::ConstantIndexOp>(loc, 1); |
2388 | Value source = getInput(); |
2389 | for (auto dim : llvm::seq<int64_t>(0, operandRank)) { |
2390 | loopBounds[dim].offset = zero; |
2391 | loopBounds[dim].size = getDimValue(builder, loc, source, dim); |
2392 | loopBounds[dim].stride = one; |
2393 | } |
2394 | return loopBounds; |
2395 | } |
2396 | |
2397 | SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() { |
2398 | SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(), |
2399 | utils::IteratorType::parallel); |
2400 | iteratorTypes[getDimension()] = utils::IteratorType::reduction; |
2401 | return iteratorTypes; |
2402 | } |
2403 | |
2404 | FailureOr<TilingResult> |
2405 | SoftmaxOp::getTiledImplementation(OpBuilder &builder, |
2406 | ArrayRef<OpFoldResult> offsets, |
2407 | ArrayRef<OpFoldResult> sizes) { |
2408 | int64_t rank = getInputOperandRank(); |
2409 | auto oneAttr = builder.getI64IntegerAttr(1); |
2410 | SmallVector<OpFoldResult> strides(rank, oneAttr); |
2411 | SmallVector<Value> tiledOperands; |
2412 | tiledOperands.emplace_back( |
2413 | getSlice(builder, getLoc(), getInput(), offsets, sizes, strides)); |
2414 | tiledOperands.emplace_back( |
2415 | getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides)); |
2416 | |
2417 | SmallVector<Type, 4> resultTypes; |
2418 | if (hasPureTensorSemantics()) |
2419 | resultTypes.push_back(tiledOperands[1].getType()); |
2420 | Operation *tiledOp = |
2421 | mlir::clone(builder, getOperation(), resultTypes, tiledOperands); |
2422 | |
2423 | return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())}; |
2424 | } |
2425 | |
2426 | LogicalResult SoftmaxOp::getResultTilePosition( |
2427 | OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, |
2428 | ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, |
2429 | SmallVector<OpFoldResult> &resultSizes) { |
2430 | if (resultNumber == 0) { |
2431 | resultOffsets.assign(offsets.begin(), offsets.end()); |
2432 | resultSizes.assign(sizes.begin(), sizes.end()); |
2433 | return success(); |
2434 | } |
2435 | return failure(); |
2436 | } |
2437 | |
2438 | // cast(dynamic) -> static. |
2439 | LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) { |
2440 | return memref::foldMemRefCast(*this); |
2441 | } |
2442 | |
2443 | LogicalResult |
2444 | SoftmaxOp::reifyResultShapes(OpBuilder &b, |
2445 | ReifiedRankedShapedTypeDims &reifiedReturnShapes) { |
2446 | SmallVector<OpFoldResult> shapes; |
2447 | Location loc = getOperation()->getLoc(); |
2448 | IRRewriter rewriter(b); |
2449 | auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType()); |
2450 | auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType()); |
2451 | for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) { |
2452 | if (!outputShapedType.isDynamicDim(dim)) { |
2453 | // Static dim: Return IntegerAttr. |
2454 | shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim))); |
2455 | } else { |
2456 | // Dynamic dim: Return Value. |
2457 | OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim); |
2458 | shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr)); |
2459 | } |
2460 | } |
2461 | reifiedReturnShapes.emplace_back(std::move(shapes)); |
2462 | return success(); |
2463 | } |
2464 | |
2465 | void SoftmaxOp::getEffects( |
2466 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
2467 | &effects) { |
2468 | getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(), |
2469 | getDpsInits()); |
2470 | } |
2471 | |
2472 | // Helper functions for softmax decomposition. |
2473 | // @{ |
2474 | |
2475 | // Helper function to produce the iterator types (reduction or parallel) and |
2476 | // affine maps for the iterators used in the decomposition of softmax. |
2477 | // This method creates: |
2478 | // If allParallel == true: |
2479 | // - iterator type: {parallel, ..., parallel} |
2480 | // - affine maps: |
2481 | // -- identity with inputRank dimensions. |
2482 | // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), |
2483 | // where N == inputRank. |
2484 | // |
2485 | // If allParallel == false: |
2486 | // - iterator type at dim(i) == parallel for i != \p dim and |
2487 | // dim(dim) == reduction. |
2488 | // - affine map: |
2489 | // -- identity with inputRank dimensions. |
2490 | // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN), |
2491 | // where N == inputRank. |
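//
// For example (illustrative): inputRank == 3 and dim == 1 produce
//   identityMap:  (d0, d1, d2) -> (d0, d1, d2)
//   reductionMap: (d0, d1, d2) -> (d0, d2)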
2492 | static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>> |
2493 | computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, |
2494 | int64_t dim, bool allParallel = false) { |
2495 | SmallVector<utils::IteratorType> iteratorTypes(inputRank, |
2496 | utils::IteratorType::parallel); |
2497 | if (!allParallel) |
2498 | iteratorTypes[dim] = utils::IteratorType::reduction; |
2499 | MLIRContext *ctxt = builder.getContext(); |
2500 | auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt); |
2501 | SmallVector<AffineExpr, 2> affineExprs; |
2502 | for (int i = 0; i < inputRank; i++) { |
2503 | if (i != dim) |
2504 | affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt)); |
2505 | } |
2506 | auto reductionMap = |
2507 | AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt); |
2508 | SmallVector<AffineMap> indexingMaps{identityMap, reductionMap}; |
2509 | return std::make_tuple(iteratorTypes, indexingMaps); |
2510 | } |
2511 | |
2512 | // Helper function to produce a linalg.generic that computes a reduction on |
2513 | // dimension \p dim with the operation type \p T. |
2514 | template <typename T> |
2515 | static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, |
2516 | int64_t dim) { |
2517 | auto inputType = cast<ShapedType>(input.getType()); |
2518 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
2519 | int64_t inputRank = inputShape.size(); |
2520 | auto [iteratorTypes, indexingMaps] = |
2521 | computeIteratorTypesAndIndexingMaps(builder, inputRank, dim); |
2522 | assert(indexingMaps.size() == 2 && |
2523 | "We should have two maps: 1 for the input, 1 for the output" ); |
2524 | assert(indexingMaps[0].isIdentity() && "input map should be identity" ); |
2525 | |
2526 | auto genericOp = builder.create<linalg::GenericOp>( |
2527 | loc, output.getType(), input, output, indexingMaps, iteratorTypes, |
2528 | [&](OpBuilder &b, Location loc, ValueRange args) { |
2529 | Value result = b.create<T>(loc, args[0], args[1]); |
2530 | b.create<linalg::YieldOp>(loc, result); |
2531 | }); |
2532 | return genericOp.getResult(0); |
2533 | } |
2534 | |
2535 | /// Produce a linalg generic that computes the second step of the softmax |
2536 | /// decomposition: res = exp(input - max), where \p max is the max of \p input |
2537 | /// on dimension \p dim. |
2538 | static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, |
2539 | Value max, Value output, int64_t dim) { |
2540 | auto inputType = cast<ShapedType>(input.getType()); |
2541 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
2542 | int64_t inputRank = inputShape.size(); |
2543 | auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( |
2544 | builder, inputRank, dim, /*allParallel=*/true); |
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2547 | // Add the affine map for the output argument. |
2548 | indexingMaps.push_back(indexingMaps[0]); |
2549 | auto genericOp = builder.create<linalg::GenericOp>( |
2550 | loc, input.getType(), ValueRange{input, max}, output, indexingMaps, |
2551 | iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { |
2552 | Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]); |
2553 | Value result = b.create<math::ExpOp>(loc, diff); |
2554 | b.create<linalg::YieldOp>(loc, result); |
2555 | }); |
2556 | return genericOp.getResult(0); |
2557 | } |
2558 | |
2559 | /// Produce a linalg generic that computes the final step of the softmax |
2560 | /// decomposition. |
2561 | /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) { |
2562 | /// yield n / d |
2563 | /// } |
2564 | static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, |
2565 | Value denominator, Value output, int64_t dim) { |
2566 | auto inputType = cast<ShapedType>(numerator.getType()); |
2567 | ArrayRef<int64_t> inputShape = inputType.getShape(); |
2568 | int64_t inputRank = inputShape.size(); |
2569 | auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps( |
2570 | builder, inputRank, dim, /*allParallel=*/true); |
2571 | assert(indexingMaps.size() == 2 && |
2572 | "We should have one map for each input (2)" ); |
2573 | assert(indexingMaps[0].isIdentity() && "Numerator map should be identity" ); |
2574 | // Add the affine map for the output tensor. |
2575 | indexingMaps.push_back(indexingMaps[0]); |
2576 | auto genericOp = builder.create<linalg::GenericOp>( |
2577 | loc, numerator.getType(), ValueRange{numerator, denominator}, output, |
2578 | indexingMaps, iteratorTypes, |
2579 | [&](OpBuilder &b, Location loc, ValueRange args) { |
2580 | Value result = b.create<arith::DivFOp>(loc, args[0], args[1]); |
2581 | b.create<linalg::YieldOp>(loc, result); |
2582 | }); |
2583 | return genericOp.getResult(0); |
2584 | } |
2585 | // @} End helper functions for softmax decomposition. |
2586 | |
2587 | /// Given an N-dimensional tensor x, this method converts |
2588 | /// softmax(x) to the following sequence of operations: |
2589 | /// |
2590 | /// 1. Compute the max of x along dimension d. This results |
2591 | /// in a N-1 dimensional tensor m. |
2592 | /// m = max(x, dim = d) |
2593 | /// |
2594 | /// 2. Subtract a broadcasted m from x and exponentiate. This results in |
2595 | /// a N dimensional tensor z. |
2596 | /// z = exp(x - m) |
2597 | /// |
2598 | /// 3. Compute the sum of z along dimension d. This results in |
2599 | /// a N-1 dimensional tensor l. |
2600 | /// l = sum(z, dim = d) |
2601 | /// |
2602 | /// 4. Divide z and l. This gives the N-dimensional softmax. |
2603 | /// softmax = z / l |
2604 | /// |
2605 | FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { |
2606 | OpBuilder::InsertionGuard guard(b); |
2607 | b.setInsertionPoint(*this); |
2608 | Location loc = getLoc(); |
2609 | Value input = getInput(); |
2610 | ShapedType inputType = getInputOperandType(); |
2611 | Type elementType = inputType.getElementType(); |
2612 | int64_t reductionDim = getDimension(); |
2613 | SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input); |
2614 | Value output = getOutput(); |
2615 | dims.erase(dims.begin() + reductionDim); |
2616 | // Step 1: Compute max along dim. |
2617 | Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); |
2618 | Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf, |
2619 | elementType, b, loc, |
2620 | /*useOnlyFiniteValue=*/true); |
2621 | Value neutralForMaxFInit = |
2622 | b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) |
2623 | .result(); |
2624 | Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit, |
2625 | reductionDim); |
2626 | |
2627 | // Step 2: Subtract max from input and exponentiate. |
2628 | Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim); |
2629 | |
2630 | // Step 3: Compute sum along dim. |
2631 | Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, |
2632 | b, loc, /*useOnlyFiniteValue=*/true); |
2633 | Value zeroInit = |
2634 | b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); |
2635 | Value denominator = |
2636 | reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim); |
2637 | |
2638 | // Step 4: Compute softmax. |
2639 | Value result = |
2640 | buildDivOp(b, loc, numerator, denominator, output, reductionDim); |
2641 | return SmallVector<Value>{result}; |
2642 | } |
2643 | |
2644 | //===----------------------------------------------------------------------===// |
2645 | // LinalgDialect |
2646 | //===----------------------------------------------------------------------===// |
2647 | |
2648 | void LinalgDialect::getCanonicalizationPatterns( |
2649 | RewritePatternSet &results) const { |
2650 | results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, |
2651 | InferStaticShapeOfOperands>(getContext()); |
2652 | } |
2653 | |
2654 | Operation *LinalgDialect::materializeConstant(OpBuilder &builder, |
2655 | Attribute value, Type type, |
2656 | Location loc) { |
2657 | return arith::ConstantOp::materialize(builder, value, type, loc); |
2658 | } |
2659 | |