//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
                                     std::optional<TypeRange> resultTensorTypes,
                                     ValueRange inputs, ValueRange outputs,
                                     ArrayRef<NamedAttribute> attributes,
                                     RegionBuilderFn regionBuilder,
                                     ArrayRef<AffineMap> indexingMaps) {
  // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
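///
/// For example (illustrative, types abbreviated), this parses the `ins`/`outs`
/// portion of a named structured op such as:
///   linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
///                 outs(%C : tensor<?x?xf32>) -> tensor<?x?xf32>
/// The trailing `-> ...` result types are handled separately by
/// parseNamedStructuredOpResults.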
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated
// code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric
// types, internally performing necessary casts. If the function application
// makes no sense, then the only recourse is to assert and return nullptr. This
// can be extended later if it becomes possible to fail construction of the
// region. The invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//
429
430namespace {
431
432class RegionBuilderHelper {
433public:
434 RegionBuilderHelper(OpBuilder &builder, Block &block)
435 : builder(builder), block(block) {}
436
437 // Build the unary functions defined by OpDSL.
438 Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
441 OpBuilder::InsertionGuard g(builder);
442 builder.setInsertionPointToEnd(&block);
443 switch (unaryFn) {
444 case UnaryFn::exp:
445 return builder.create<math::ExpOp>(arg.getLoc(), arg);
446 case UnaryFn::log:
447 return builder.create<math::LogOp>(arg.getLoc(), arg);
448 case UnaryFn::abs:
449 return builder.create<math::AbsFOp>(arg.getLoc(), arg);
450 case UnaryFn::ceil:
451 return builder.create<math::CeilOp>(arg.getLoc(), arg);
452 case UnaryFn::floor:
453 return builder.create<math::FloorOp>(arg.getLoc(), arg);
454 case UnaryFn::negf:
455 return builder.create<arith::NegFOp>(arg.getLoc(), arg);
456 case UnaryFn::reciprocal: {
457 Attribute oneAttr = builder.getOneAttr(arg.getType());
458 auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
459 ::cast<TypedAttr>(oneAttr));
460 return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
461 }
462 case UnaryFn::round:
463 return builder.create<math::RoundOp>(arg.getLoc(), arg);
464 case UnaryFn::sqrt:
465 return builder.create<math::SqrtOp>(arg.getLoc(), arg);
466 case UnaryFn::rsqrt:
467 return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
468 case UnaryFn::square:
469 return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
470 case UnaryFn::tanh:
471 return builder.create<math::TanhOp>(arg.getLoc(), arg);
472 case UnaryFn::erf:
473 return builder.create<math::ErfOp>(arg.getLoc(), arg);
474 }
475 llvm_unreachable("unsupported unary function");
476 }
477
478 // Build the binary functions defined by OpDSL.
479 Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
485 if (!allComplex && !allFloatingPoint && !allInteger)
486 llvm_unreachable("unsupported non numeric type");
487 OpBuilder::InsertionGuard g(builder);
488 builder.setInsertionPointToEnd(&block);
489 switch (binaryFn) {
490 case BinaryFn::add:
491 if (allComplex)
492 return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
493 if (allFloatingPoint)
494 return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
495 if (allBool)
496 return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
497 return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
498 case BinaryFn::sub:
499 if (allComplex)
500 return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
501 if (allFloatingPoint)
502 return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
503 if (allBool)
504 llvm_unreachable("unsupported operation: sub with bools");
505 return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
506 case BinaryFn::mul:
507 if (allComplex)
508 return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
509 if (allFloatingPoint)
510 return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
511 if (allBool)
512 return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
513 return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
514 case BinaryFn::div:
515 if (allComplex)
516 return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
517 if (allFloatingPoint)
518 return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
519 if (allBool)
520 llvm_unreachable("unsupported operation: div with bools");
521 return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
522 case BinaryFn::div_unsigned:
523 if (!allInteger || allBool)
524 llvm_unreachable("unsupported operation: unsigned div not on uint");
525 return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
526 case BinaryFn::max_signed:
527 assert(!allComplex);
528 if (allFloatingPoint)
529 return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
530 return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
531 case BinaryFn::min_signed:
532 assert(!allComplex);
533 if (allFloatingPoint)
534 return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
535 return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
536 case BinaryFn::max_unsigned:
537 assert(!allComplex);
538 if (allFloatingPoint)
539 return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
540 return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
541 case BinaryFn::min_unsigned:
542 assert(!allComplex);
543 if (allFloatingPoint)
544 return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
545 return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
546 case BinaryFn::powf:
547 assert(allFloatingPoint);
548 return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
549 }
550 llvm_unreachable("unsupported binary function");
551 }
552
553 // Build the ternary functions defined by OpDSL.
554 Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
555 Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
561 OpBuilder::InsertionGuard g(builder);
562 builder.setInsertionPointToEnd(&block);
563 switch (ternaryFn) {
564 case TernaryFn::select:
565 if (!headBool && !(tailFloatingPoint || tailInteger))
566 llvm_unreachable("unsupported non numeric type");
567 return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
568 }
569 llvm_unreachable("unsupported ternary function");
570 }
571
572 // Build the type functions defined by OpDSL.
573 Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
574 switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, /*isUnsignedCast=*/false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, /*isUnsignedCast=*/true);
579 }
580 llvm_unreachable("unsupported type conversion function");
581 }
582
583 void yieldOutputs(ValueRange values) {
584 OpBuilder::InsertionGuard g(builder);
585 builder.setInsertionPointToEnd(&block);
586 Location loc = builder.getUnknownLoc();
587 builder.create<YieldOp>(loc, values);
588 }
589
590 Value constant(const std::string &value) {
591 OpBuilder::InsertionGuard g(builder);
592 builder.setInsertionPointToEnd(&block);
593 Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
595 return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
596 }
597
598 Value index(int64_t dim) {
599 OpBuilder::InsertionGuard g(builder);
600 builder.setInsertionPointToEnd(&block);
601 return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
602 }
603
604 Type getIntegerType(unsigned width) {
605 return IntegerType::get(builder.getContext(), width);
606 }
607
608 Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
609 Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
610
611private:
612 // Generates operations to cast the given operand to a specified type.
613 // If the cast cannot be performed, a warning will be issued and the
614 // operand returned as-is (which will presumably yield a verification
615 // issue downstream).
616 Value cast(Type toType, Value operand, bool isUnsignedCast) {
617 OpBuilder::InsertionGuard g(builder);
618 builder.setInsertionPointToEnd(&block);
619 auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
621 }
622
623 bool isComplex(Value value) {
624 return llvm::isa<ComplexType>(value.getType());
625 }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }
632
633 OpBuilder &builder;
634 Block &block;
635};
636
637} // namespace
638
639//===----------------------------------------------------------------------===//
640// CopyOp
641//===----------------------------------------------------------------------===//
642
643namespace {
644
645struct EraseSelfCopy : OpRewritePattern<CopyOp> {
646 using OpRewritePattern<CopyOp>::OpRewritePattern;
647 LogicalResult matchAndRewrite(CopyOp copyOp,
648 PatternRewriter &rewriter) const override {
649 if (copyOp.getInputs() != copyOp.getOutputs())
650 return rewriter.notifyMatchFailure(copyOp, "not a self copy");
651 if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
653 else
654 rewriter.replaceOp(copyOp, copyOp.getInputs());
655
656 return success();
657 }
658};
659
660} // namespace
661
662void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
663 MLIRContext *context) {
664 results.add<EraseSelfCopy>(context);
665}
666
667//===----------------------------------------------------------------------===//
668// FillOp
669//===----------------------------------------------------------------------===//
670
671namespace {
672
/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
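///
/// For example (illustrative), a tensor.collapse_shape of
///   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>)
///          -> tensor<6x4xf32>
/// is rewritten to collapse %init first and then linalg.fill the collapsed
/// tensor with %cst.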
677template <typename TensorReshapeOp>
678struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
679 using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
680 LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
681 PatternRewriter &rewriter) const override {
682 auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
683 if (!oldFill)
684 return failure();
685
686 Location loc = oldFill.getLoc();
687 TensorReshapeOp newInit;
688 if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
689
690 newInit = rewriter.create<TensorReshapeOp>(
691 loc, reshapeOp.getResultType(), oldFill.output(),
692 reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
693 reshapeOp.getStaticOutputShape());
694 } else {
695 newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
696 oldFill.output(),
697 reshapeOp.getReassociation());
698 }
699 rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
700 ValueRange{newInit});
701 return success();
702 }
703};
704
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
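///
/// For example (illustrative), padding
///   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
/// with %cst as the padding value is rewritten to a linalg.fill of %cst into
/// a larger tensor.empty that matches the padded result shape.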
707struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
708 using OpRewritePattern::OpRewritePattern;
709
710 LogicalResult matchAndRewrite(tensor::PadOp padOp,
711 PatternRewriter &rewriter) const override {
712 auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
713 if (!fillOp)
714 return failure();
715
716 // We can only fold if the padding value is the same as the original
717 // filling value.
718 Value padValue = padOp.getConstantPaddingValue();
719 if (!padValue || fillOp.value() != padValue)
720 return failure();
721
722 ReifiedRankedShapedTypeDims reifiedShape;
723 if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
724 return rewriter.notifyMatchFailure(
725 padOp, "failed to reify tensor.pad op result shape");
726
727 auto emptyTensor = rewriter.create<tensor::EmptyOp>(
728 padOp.getLoc(), reifiedShape.front(),
729 padOp.getResultType().getElementType());
730 Value replacement =
731 rewriter
732 .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
733 ValueRange{emptyTensor})
734 .getResult(0);
735 if (replacement.getType() != padOp.getResultType()) {
736 replacement = rewriter.create<tensor::CastOp>(
737 fillOp.getLoc(), padOp.getResultType(), replacement);
738 }
739 rewriter.replaceOp(padOp, replacement);
740 return success();
741 }
742};
743
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
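///
/// For example (illustrative), if the tensor.pad pads its source with the same
/// constant that already fills the insertion destination, then the padded
/// result need not be materialized: the original source is inserted directly,
/// with the insert offsets shifted by the low padding amounts.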
747struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
748 using OpRewritePattern::OpRewritePattern;
749
750 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
751 PatternRewriter &rewriter) const override {
752 auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
753 if (!srcPadOp)
754 return failure();
755
756 if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
757 return failure();
758
759 // Walk back the tensor.insert_slice chain and find the first destination
760 // value at the start of the chain.
761 Value firstDest = insertOp.getDest();
762 while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
763 if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
764 return failure();
765
766 // Make sure the range of values accessed are disjoint. Without this, we
767 // cannot fold tensor.pad away.
768 bool disjoint = false;
769 for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
770 // If the dimension has dynamic offset/size, we cannot guarantee
771 // disjoint. So just skip it.
772 if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
773 insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
774 prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
775 continue;
776
777 // Get the range start and end, inclusively for both.
778 int64_t prevStart = prevOp.getStaticOffset(i);
779 int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
780 prevOp.getStaticStride(i);
781 int64_t nextStart = insertOp.getStaticOffset(i);
782 int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
783 insertOp.getStaticStride(i);
784 if (prevEnd < nextStart || nextEnd < prevStart) {
785 disjoint = true;
786 break;
787 }
788 }
789
790 if (!disjoint)
791 break;
792 firstDest = prevOp.getDest();
793 }
794
795 // Check whether the first destination is a fill op. For overlapped cases,
796 // this also cannot be true.
797 auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
798 if (!dstFillOp)
799 return failure();
800
801 // We can only fold if the padding value is the same as the original
802 // filling value.
803 Value padValue = srcPadOp.getConstantPaddingValue();
804 if (!padValue || dstFillOp.value() != padValue)
805 return failure();
806
807 SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
808 SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
809
810 Location loc = insertOp.getLoc();
811 MLIRContext *context = getContext();
812
813 AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
816
817 // Calculate the new offsets for the insert. It should be the old offsets
818 // plus low padding sizes.
819 SmallVector<OpFoldResult, 4> newOffsets;
820 for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
821 newOffsets.push_back(affine::makeComposedFoldedAffineApply(
822 rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
823 }
824
825 RankedTensorType srcPadType = srcPadOp.getSourceType();
826 SmallVector<OpFoldResult, 4> newSizes;
827 for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
828 if (srcPadType.isDynamicDim(i)) {
829 newSizes.push_back(
830 rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
831 .getResult());
832 } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
834 }
835 }
836
837 rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
838 insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
839 newSizes, insertOp.getMixedStrides());
840 return success();
841 }
842};
843
844/// Fold tensor.extract(linalg.fill(<input>)) into <input>
845struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
846public:
847 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
850 PatternRewriter &rewriter) const override {
851 // See if tensor input of tensor.extract op is the result of a linalg.fill
852 // op.
853 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
854 if (!fillOp)
855 return failure();
856
857 // Get scalar input operand of linalg.fill op.
858 Value extractedScalar = fillOp.getInputs()[0];
859
860 // Replace tensor.extract op with scalar value used to fill the tensor.
861 rewriter.replaceOp(extractOp, extractedScalar);
862 return success();
863 }
864};
865
/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
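///
/// For example (illustrative), packing
///   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<16x16xf32>)
/// into a tensor<2x2x8x8xf32> destination is rewritten to fill the packed
/// destination with %cst directly (assuming any padding value equals %cst).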
869static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
870 linalg::PackOp packOp) {
871 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
872 if (!fillOp)
873 return failure();
874
875 if (auto paddingValue = packOp.getPaddingValue())
876 if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
877 return failure();
878
879 Value packOpDest = packOp.getDest();
880 if (!packOpDest.hasOneUse())
881 return failure();
882
883 return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
884 packOp.getDest());
885}
886
887/// Wrapper pattern that applies foldFillPackIntoFillOp method.
888struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
889public:
890 FoldFillWithPack(MLIRContext *context)
891 : OpRewritePattern<linalg::PackOp>(context) {}
892
893 LogicalResult matchAndRewrite(linalg::PackOp packOp,
894 PatternRewriter &rewriter) const override {
895 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
896 if (failed(fillOp))
897 return failure();
898 rewriter.replaceOp(packOp, fillOp.value().result());
899 return success();
900 }
901};
902
903/// Fold fill with copy.
904struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
905 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
906
907 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
908 PatternRewriter &rewriter) const override {
909 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
910 rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
911 fillOp.getInputs(),
912 copyOp.getOutputs());
913 return success();
914 }
915 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
916 rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
917 fillOp.getOutputs());
918 return success();
919 }
920 return failure();
921 }
922};
923
924/// Fold fill with transpose.
925struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
926 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
927
928 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
929 PatternRewriter &rewriter) const override {
930 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
931 rewriter.replaceOpWithNewOp<FillOp>(
932 transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
933 transposeOp.getDpsInitOperand(0)->get());
934 return success();
935 }
936 return failure();
937 }
938};
939
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
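///
/// For example (illustrative), concatenating two tensors that are both
/// produced by linalg.fill of the same %cst is rewritten to concatenate their
/// outs tensors and fill the concatenated result with %cst once.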
942struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
943 using OpRewritePattern::OpRewritePattern;
944
945 LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
946 PatternRewriter &rewriter) const override {
947 auto concatOperands = concatOp.getInputs();
948 if (concatOperands.empty()) {
949 return failure();
950 }
951
952 auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
953 if (!firstFillOp) {
954 return failure();
955 }
956 // Prefetch the fill value.
957 OpFoldResult firstFillVal =
958 getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
959 // Collect all the outs values for the fill operations.
960 SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
962
963 auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
964 auto fillOp = v.getDefiningOp<linalg::FillOp>();
965 if (!fillOp) {
966 return false;
967 }
968
969 OpFoldResult fillVal =
970 getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
971 if (fillVal != firstFillVal)
972 return false;
973
      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
975 return true;
976 };
977 if (!llvm::all_of(concatOperands.drop_front(),
978 isDefinedByCompatibleFillOp)) {
979 return rewriter.notifyMatchFailure(
980 concatOp, "not all operands are defined by a compatible fill op");
981 }
982
983 Value outsConcat = rewriter.create<tensor::ConcatOp>(
984 concatOp.getLoc(), concatOp.getDim(), allOuts);
985 rewriter.replaceOpWithNewOp<linalg::FillOp>(
986 concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
987 return success();
988 }
989};
990
991} // namespace
992
993void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
994 MLIRContext *context) {
995 results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
996 FoldFillWithPack, FoldFillWithPad,
997 FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
998 FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
999 FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// GenericOp
1004//===----------------------------------------------------------------------===//
1005
1006static void buildGenericRegion(
1007 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1008 ValueRange outputs,
1009 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1010 SmallVector<Type, 4> blockArgTypes;
1011 SmallVector<Location, 4> blockArgLocs;
1012 for (ValueRange container : {inputs, outputs}) {
1013 for (Value v : container) {
1014 Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
1024 bodyBuild(builder, loc, bodyBlock->getArguments());
1025}
1026
1027void GenericOp::getAsmBlockArgumentNames(Region &region,
1028 OpAsmSetValueNameFn setNameFn) {
1029 for (Value v : getRegionInputArgs())
1030 setNameFn(v, "in");
1031 for (Value v : getRegionOutputArgs())
1032 setNameFn(v, "out");
1033}
1034
1035void GenericOp::build(
1036 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1037 ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
1038 ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
1039 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1040 ArrayRef<NamedAttribute> attributes) {
1041 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1042 iteratorTypes, doc, libraryCall);
1043 result.addAttributes(attributes);
1044 if (bodyBuild)
1045 buildGenericRegion(builder, result.location, *result.regions.front(),
1046 inputs, outputs, bodyBuild);
1047}
1048
1049void GenericOp::build(
1050 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1051 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1052 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1053 StringRef libraryCall,
1054 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1055 ArrayRef<NamedAttribute> attributes) {
1056 build(builder, result, resultTensorTypes, inputs, outputs,
1057 builder.getAffineMapArrayAttr(indexingMaps),
1058 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
1059 iteratorTypes,
1060 [&](utils::IteratorType iter) -> mlir::Attribute {
1061 return IteratorTypeAttr::get(builder.getContext(), iter);
1062 }))),
1063 doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1064 libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1065 bodyBuild, attributes);
1066}
1067
1068void GenericOp::build(
1069 OpBuilder &builder, OperationState &result, ValueRange inputs,
1070 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1071 ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
1072 StringRef libraryCall,
1073 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1074 ArrayRef<NamedAttribute> attributes) {
1075 build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
1076 iteratorTypes, doc, libraryCall, bodyBuild, attributes);
1077}
1078
1079void GenericOp::build(
1080 OpBuilder &builder, OperationState &result, ValueRange inputs,
1081 ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1082 ArrayRef<utils::IteratorType> iteratorTypes,
1083 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1084 ArrayRef<NamedAttribute> attributes) {
1085 build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
1086 /*doc=*/"",
1087 /*libraryCall=*/"", bodyBuild, attributes);
1088}
1089
1090void GenericOp::build(
1091 OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
1092 ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
1093 ArrayRef<utils::IteratorType> iteratorTypes,
1094 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1095 ArrayRef<NamedAttribute> attributes) {
1096 build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
1097 iteratorTypes,
1098 /*doc=*/"",
1099 /*libraryCall=*/"", bodyBuild, attributes);
1100}
1101
1102void GenericOp::print(OpAsmPrinter &p) {
1103 p << " ";
1104
1105 // Print extra attributes.
1106 auto genericAttrNames = linalgTraitAttrNames();
1107
1108 llvm::StringSet<> genericAttrNamesSet;
1109 genericAttrNamesSet.insert_range(genericAttrNames);
1110 SmallVector<NamedAttribute, 8> genericAttrs;
1111 for (auto attr : (*this)->getAttrs()) {
1112 if (attr.getName() == getIteratorTypesAttrName()) {
1113 auto iteratorTypes =
1114 llvm::cast<ArrayAttr>(attr.getValue())
1115 .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
1116 // Convert IteratorType enums into the string representation. This is
1117 // needed, because tests still use the old format when 'iterator_types'
1118 // attribute is represented as an array of strings.
1119 // TODO: Remove this conversion once tests are fixed.
1120 SmallVector<Attribute> iteratorTypeNames =
1121 llvm::to_vector(llvm::map_range(
1122 iteratorTypes, [&](utils::IteratorType t) -> Attribute {
1123 return StringAttr::get(getContext(), stringifyIteratorType(t));
1124 }));
1125
1126 genericAttrs.emplace_back(
1127 getIteratorTypesAttrName(),
1128 ArrayAttr::get(getContext(), iteratorTypeNames));
1129 } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
1130 genericAttrs.push_back(attr);
1131 }
1132 }
1133 if (!genericAttrs.empty()) {
1134 auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
1135 p << genericDictAttr;
1136 }
1137
1138 // Printing is shared with named ops, except for the region and attributes
1139 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1140
1141 genericAttrNames.push_back("operandSegmentSizes");
1142 genericAttrNamesSet.insert(genericAttrNames.back());
1143
1144 bool hasExtraAttrs = false;
1145 for (NamedAttribute n : (*this)->getAttrs()) {
1146 if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
1147 break;
1148 }
1149 if (hasExtraAttrs) {
1150 p << " attrs = ";
1151 p.printOptionalAttrDict((*this)->getAttrs(),
1152 /*elidedAttrs=*/genericAttrNames);
1153 }
1154
1155 // Print region.
1156 if (!getRegion().empty()) {
1157 p << ' ';
1158 p.printRegion(getRegion());
1159 }
1160
1161 // Print results.
1162 printNamedStructuredOpResults(p, getResultTensors().getTypes());
1163}
1164
1165ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1166 DictionaryAttr dictAttr;
1167 // Parse the core linalg traits that must check into a dictAttr.
1168 // The name is unimportant as we will overwrite result.attributes.
1169 // The core linalg traits must contain the information necessary to pass the
1170 // verifier.
1171 llvm::SMLoc attributeLocation = parser.getCurrentLocation();
1172 if (parser.parseAttribute(dictAttr, "_", result.attributes))
1173 return failure();
1174 result.attributes.assign(dictAttr.getValue().begin(),
1175 dictAttr.getValue().end());
1176
1177 // Convert array of string into an array of IteratorType enums. This is
1178 // needed, because tests still use the old format when 'iterator_types'
1179 // attribute is represented as an array of strings.
1180 // TODO: Remove this conversion once tests are fixed.
1181 auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
1182 result.attributes.get(getIteratorTypesAttrName(result.name)));
1183 if (!iteratorTypes) {
1184 return parser.emitError(attributeLocation)
1185 << "expected " << getIteratorTypesAttrName(result.name)
1186 << " array attribute";
1187 }
1188
1189 SmallVector<Attribute> iteratorTypeAttrs;
1190
1191 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
1192 auto maybeIteratorType = utils::symbolizeIteratorType(s);
1193 if (!maybeIteratorType.has_value())
1194 return parser.emitError(parser.getCurrentLocation())
1195 << "unexpected iterator_type (" << s << ")";
1196
1197 iteratorTypeAttrs.push_back(
1198 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
1199 }
1200 result.attributes.set(getIteratorTypesAttrName(result.name),
1201 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
1202
1203 // Parsing is shared with named ops, except for the region.
1204 SmallVector<Type, 1> inputTypes, outputTypes;
1205 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
1206 return failure();
1207
1208 // Optional attributes may be added.
1209 if (succeeded(parser.parseOptionalKeyword("attrs")))
1210 if (failed(parser.parseEqual()) ||
1211 failed(parser.parseOptionalAttrDict(result.attributes)))
1212 return failure();
1213
1214 std::unique_ptr<Region> region = std::make_unique<Region>();
1215 if (parser.parseRegion(*region, {}))
1216 return failure();
1217 result.addRegion(std::move(region));
1218
1219 // Generic ops may specify that a subset of its outputs are tensors. Such
1220 // outputs are specified in the result type.
1221 // TODO: may need to move output parsing before region parsing.
1222 // Need to wait for declarative assembly resolution to decide.
1223 SmallVector<Type, 1> outputTensorsTypes;
1224 if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
1225 return failure();
1226 result.addTypes(outputTensorsTypes);
1227
1228 return success();
1229}
1230
1231static void getGenericEffectsImpl(
1232 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1233 &effects,
1234 LinalgOp linalgOp) {
1235 for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1236 if (!llvm::isa<MemRefType>(operand.getType()))
1237 continue;
1238 effects.emplace_back(
1239 MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1240 /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1241 }
1242
1243 for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1244 if (!llvm::isa<MemRefType>(operand.get().getType()))
1245 continue;
1246 if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1247 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1248 /*effectOnFullRegion=*/true,
1249 SideEffects::DefaultResource::get());
1250 }
1251 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1252 /*effectOnFullRegion=*/true,
1253 SideEffects::DefaultResource::get());
1254 }
1255}
1256
1257void GenericOp::getEffects(
1258 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1259 &effects) {
1260 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1261}
1262
1263static Speculation::Speculatability
1264getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
1265 // Operands with value semantics are speculatable, while operands with memory
1266 // semantics are not.
1267 if (!linalgOp.hasPureTensorSemantics())
1268 return Speculation::NotSpeculatable;
1269 // The body of the op can still have speculation in its region.
1270 return Speculation::RecursivelySpeculatable;
1271}
1272
1273Speculation::Speculatability GenericOp::getSpeculatability() {
1274 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1275}
1276
1277LogicalResult GenericOp::verify() { return success(); }
1278
1279namespace {
1280
/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
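///
/// For example (illustrative), a tensor-semantics linalg.generic with
/// identical indexing maps whose body is just `linalg.yield %in` is replaced
/// by its input (with a tensor.cast inserted if the static types differ).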
1287template <typename OpTy>
1288struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1289 using OpRewritePattern<OpTy>::OpRewritePattern;
1290
1291 LogicalResult matchAndRewrite(OpTy linalgOp,
1292 PatternRewriter &rewriter) const override {
1293 // All indexing maps must be equal. It follows that they are permutations.
1294 if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1295 return failure();
1296
1297 // Check that the body of the linalg operation is just a linalg.yield
1298 // operation.
1299 Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
1301 return failure();
1302 auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1303 if (!yieldOp)
1304 return failure();
1305
1306 // In the buffer case, we need to check exact buffer equality.
1307 if (linalgOp.hasPureBufferSemantics()) {
1308 if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
1309 linalgOp.getDpsInputOperand(0)->get() !=
1310 linalgOp.getDpsInitOperand(0)->get()) {
1311 return rewriter.notifyMatchFailure(
1312 linalgOp, "expected single input and output to be the same value");
1313 }
1314
1315 auto yieldArg = dyn_cast<BlockArgument>(yieldOp.getOperand(0));
1316 if (!yieldArg || yieldArg.getOwner() != &body) {
1317 return rewriter.notifyMatchFailure(linalgOp,
1318 "cannot fold fill-like op");
1319 }
1320
      rewriter.eraseOp(linalgOp);
1322 return success();
1323 }
1324
1325 if (!linalgOp.hasPureTensorSemantics()) {
1326 return rewriter.notifyMatchFailure(
1327 linalgOp, "mixed semantics is not supported yet");
1328 }
1329
1330 // Get the argument number of the returned values. That is the operand
1331 // number to use for replacing uses of this operation.
1332 SmallVector<Value> returnedArgs;
1333 for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1334 auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1335 if (!yieldArg || yieldArg.getOwner() != &body)
1336 return failure();
1337 unsigned argumentNumber = yieldArg.getArgNumber();
1338 Value returnedArg = linalgOp->getOperand(argumentNumber);
1339 Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1340 // The input can have a different type than the result, e.g. a dynamic
1341 // input dimension can be turned into a static output dimension.
1342 Type returnType = returnedArg.getType();
1343 if (returnType != resultType) {
1344 // Distinguish between sparse conversion or dense tensor casting.
1345 // TODO: unify the two ops?
1346 if (sparse_tensor::getSparseTensorEncoding(returnType) ||
1347 sparse_tensor::getSparseTensorEncoding(resultType))
1348 returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
1349 linalgOp.getLoc(), resultType, returnedArg);
1350 else {
1351 if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1352 resultType))
1353 return failure();
1354 returnedArg = rewriter.create<tensor::CastOp>(
1355 linalgOp.getLoc(), resultType, returnedArg);
1356 }
1357 }
1358 returnedArgs.push_back(returnedArg);
1359 }
1360
1361 if (returnedArgs.size() != linalgOp->getNumResults())
1362 return failure();
1363 rewriter.replaceOp(linalgOp, returnedArgs);
1364 return success();
1365 }
1366};
1367
1368} // namespace
1369
1370void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1371 MLIRContext *context) {
1372 results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1373}
1374
1375LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1376 return memref::foldMemRefCast(*this);
1377}
1378
1379//===----------------------------------------------------------------------===//
1380// MapOp
1381//===----------------------------------------------------------------------===//
1382
1383static ParseResult parseDstStyleOp(
1384 OpAsmParser &parser, OperationState &result,
1385 function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1386 nullptr) {
1387 // Parse `ins` and `outs`.
1388 SmallVector<Type, 4> inputTypes, outputTypes;
1389 if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1390 /*addOperandSegmentSizes=*/false))
1391 return failure();
1392
1393 // Add result types.
1394 for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
1397 }
1398
1399 // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1401 return failure();
1402
1403 // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
1405 return failure();
1406 return success();
1407}
1408
1409void MapOp::getAsmBlockArgumentNames(Region &region,
1410 OpAsmSetValueNameFn setNameFn) {
1411 for (Value v : getRegionInputArgs())
1412 setNameFn(v, "in");
1413}
1414
1415void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1416 if (!getResults().empty())
1417 setNameFn(getResults().front(), "mapped");
1418}
1419
1420void MapOp::build(
1421 OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1422 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1423 ArrayRef<NamedAttribute> attributes) {
1424 build(builder, result, TypeRange{}, inputs, init);
1425 result.addAttributes(attributes);
1426
1427 // Add output types for `RankedTensorType` output arguments.
1428 Type initType = init.getType();
1429 if (llvm::isa<RankedTensorType>(initType))
1430 result.addTypes(initType);
1431
1432 if (bodyBuild)
1433 buildGenericRegion(builder, result.location, *result.regions.front(),
1434 inputs, /*outputs=*/{}, bodyBuild);
1435}
1436
1437static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1438 const OperationName &payloadOpName,
1439 const NamedAttrList &payloadOpAttrs,
1440 ArrayRef<Value> operands,
1441 bool initFirst = false) {
1442 OpBuilder b(parser.getContext());
1443 Region *body = result.addRegion();
1444 Block &block = body->emplaceBlock();
1445 b.setInsertionPointToStart(&block);
1446 for (auto &operand : operands) {
1447 block.addArgument(
1448 llvm::cast<ShapedType>(operand.getType()).getElementType(),
1449 b.getUnknownLoc());
1450 }
1451 SmallVector<Value> payloadOpOperands;
1452 // If the `initFirst` flag is set, the init block argument is passed as the
1453 // first payload operand.
1454 if (initFirst) {
1455 payloadOpOperands.push_back(block.getArguments().back());
1456 for (const auto &arg : block.getArguments().drop_back())
1457 payloadOpOperands.push_back(arg);
1458 } else {
1459 payloadOpOperands = {block.getArguments().begin(),
1460 block.getArguments().end()};
1461 }
1462
1463 Operation *payloadOp = b.create(
1464 result.location, b.getStringAttr(payloadOpName.getStringRef()),
1465 payloadOpOperands,
1466 TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1467 .getElementType()},
1468 payloadOpAttrs);
1469 b.create<YieldOp>(result.location, payloadOp->getResults());
1470}
1471
1472ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1473 std::optional<OperationName> payloadOpName;
1474 NamedAttrList payloadOpAttrs;
1475 if (succeeded(parser.parseOptionalLBrace())) {
1476 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1477 if (failed(operationName))
1478 return failure();
1479 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1480 return failure();
1481 payloadOpName = operationName.value();
1482 if (parser.parseRBrace())
1483 return failure();
1484 }
1485
1486 if (parseDstStyleOp(parser, result))
1487 return failure();
1488
1489 if (payloadOpName.has_value()) {
1490 if (!result.operands.empty())
1491 addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1492 payloadOpAttrs,
1493 ArrayRef(result.operands).drop_back());
1494 else
1495 result.addRegion();
1496 } else {
1497 SmallVector<OpAsmParser::Argument> regionArgs;
1498 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1499 /*allowType=*/true, /*allowAttrs=*/true)) {
1500 return failure();
1501 }
1502 Region *body = result.addRegion();
1503 if (parser.parseRegion(*body, regionArgs))
1504 return failure();
1505 }
1506 return success();
1507}
1508
1509 // Retrieve the operation from the body, if it is the only one (besides the
1510 // terminating yield) and if it takes the same number of arguments as the
1511 // body does. If the `initFirst` flag is set, also check that the init value
1512 // occupies the first position among the payload operands.
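// For example (an illustrative sketch), the body of
//   %sum = linalg.map { arith.addf } ins(%lhs, %rhs : ...) outs(%init : ...)
// contains a single arith.addf whose operands are exactly the two block
// arguments in order, so that arith.addf is returned here and the op can be
// printed in its short form.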
1513static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1514 if (body->getOperations().size() != 2)
1515 return nullptr;
1516 Operation &payload = body->getOperations().front();
1517 assert(isa<YieldOp>(body->getOperations().back()));
1518
1519 if (payload.getNumOperands() == 0 ||
1520 payload.getNumOperands() != body->getNumArguments())
1521 return nullptr;
1522 if (initFirst) {
1523 // check init
1524 if (payload.getOperands().back() != body->getArgument(0))
1525 return nullptr;
1526 // check rest
1527 for (const auto &[operand, bbArg] :
1528 llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1529 if (bbArg != operand)
1530 return nullptr;
1531 }
1532 } else {
1533 for (const auto &[operand, bbArg] :
1534 llvm::zip(payload.getOperands(), body->getArguments())) {
1535 if (bbArg != operand)
1536 return nullptr;
1537 }
1538 }
1539 return &payload;
1540}
1541
1542void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1543 SmallVector<StringRef> elidedAttrs;
1544 std::string attrToElide;
1545 p << " { " << payloadOp->getName().getStringRef();
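// Elide a fastmath flags attribute equal to the default (none) so the short
// form stays concise.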
1546 for (const auto &attr : payloadOp->getAttrs()) {
1547 auto fastAttr =
1548 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1549 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1550 attrToElide = attr.getName().str();
1551 elidedAttrs.push_back(attrToElide);
1552 break;
1553 }
1554 }
1555 p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1556 p << " }";
1557}
1558
1559void MapOp::print(OpAsmPrinter &p) {
1560 Block *mapper = getBody();
1561 Operation *payloadOp = findPayloadOp(mapper);
1562 if (payloadOp) {
1563 printShortForm(p, payloadOp);
1564 }
1565
1566 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1567 p.printOptionalAttrDict((*this)->getAttrs());
1568
1569 if (!payloadOp) {
1570 // Print region if the payload op was not detected.
1571 p.increaseIndent();
1572 p.printNewline();
1573 p << "(";
1574 llvm::interleaveComma(mapper->getArguments(), p,
1575 [&](auto arg) { p.printRegionArgument(arg); });
1576 p << ") ";
1577
1578 p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1579 p.decreaseIndent();
1580 }
1581}
1582
1583LogicalResult MapOp::verify() {
1584 auto *bodyBlock = getBody();
1585 auto blockArgs = bodyBlock->getArguments();
1586
1587 // Check that the number of `inputs` matches the arity of the `mapper` region.
1588 if (getInputs().size() != blockArgs.size())
1589 return emitOpError() << "expects number of operands to match the arity of "
1590 "mapper, but got: "
1591 << getInputs().size() << " and " << blockArgs.size();
1592
1593 // The parameters of the mapper should all match the element types of the inputs.
1594 for (const auto &[bbArgType, inputArg] :
1595 llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1596 auto inputElemType =
1597 llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1598 if (bbArgType != inputElemType) {
1599 return emitOpError() << "expected element type of input " << inputElemType
1600 << " to match bbArg type " << bbArgType;
1601 }
1602 }
1603
1604 // The shape of each input must match the shape of the output.
1605 auto outputShape = getInit().getType().getShape();
1606 for (Type inputArgType : TypeRange{getInputs()}) {
1607 auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1608 if (inputElemShape != outputShape) {
1609 return emitOpError() << "expected shape of input (" << inputElemShape
1610 << ") to match shape of output (" << outputShape
1611 << ")";
1612 }
1613 }
1614
1615 return success();
1616}
1617
1618SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1619 int64_t rank = getInit().getType().getRank();
1620 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1621}
1622
1623ArrayAttr MapOp::getIndexingMaps() {
1624 Builder builder(getContext());
1625 int64_t rank = getInit().getType().getRank();
1626 int64_t numIndexingMaps = getOperands().size();
1627 return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1628 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1629}
1630
1631void MapOp::getEffects(
1632 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1633 &effects) {
1634 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1635}
1636
1637Speculation::Speculatability MapOp::getSpeculatability() {
1638 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1639}
1640
1641//===----------------------------------------------------------------------===//
1642// ReduceOp
1643//===----------------------------------------------------------------------===//
1644
1645void ReduceOp::getAsmBlockArgumentNames(Region &region,
1646 OpAsmSetValueNameFn setNameFn) {
1647 for (Value v : getRegionInputArgs())
1648 setNameFn(v, "in");
1649 for (Value v : getRegionOutputArgs())
1650 setNameFn(v, "init");
1651}
1652
1653void ReduceOp::getAsmResultNames(
1654 function_ref<void(Value, StringRef)> setNameFn) {
1655 if (!getResults().empty())
1656 setNameFn(getResults().front(), "reduced");
1657}
1658
1659void ReduceOp::build(
1660 OpBuilder &builder, OperationState &result, ValueRange inputs,
1661 ValueRange inits, ArrayRef<int64_t> dimensions,
1662 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1663 ArrayRef<NamedAttribute> attributes) {
1664 build(builder, result, TypeRange{}, inputs, inits, dimensions);
1665 result.addAttributes(attributes);
1666
1667 // Add output types for `RankedTensorType` output arguments.
1668 for (Value init : inits) {
1669 Type initType = init.getType();
1670 if (llvm::isa<RankedTensorType>(initType))
1671 result.addTypes(initType);
1672 }
1673
1674 if (bodyBuild)
1675 buildGenericRegion(builder, result.location, *result.regions.front(),
1676 inputs, inits, bodyBuild);
1677}
1678
1679SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1680 int64_t inputRank =
1681 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1682 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1683 utils::IteratorType::parallel);
1684 for (int64_t reductionDim : getDimensions())
1685 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1686 return iteratorTypes;
1687}
1688
1689ArrayAttr ReduceOp::getIndexingMaps() {
1690 int64_t inputRank =
1691 llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1692 SmallVector<AffineMap> affineMaps(
1693 getNumDpsInputs(),
1694 AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1695 AffineMap resultMap =
1696 AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1697 .dropResults(getDimensions());
1698 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1699 affineMaps.push_back(resultMap);
1700 return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1701}
1702
1703void ReduceOp::getEffects(
1704 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1705 &effects) {
1706 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1707}
1708
1709Speculation::Speculatability ReduceOp::getSpeculatability() {
1710 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1711}
1712
1713static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1714 NamedAttrList &attributes,
1715 StringRef attributeName) {
1716 if (parser.parseKeyword(attributeName) || parser.parseEqual())
1717 return failure();
1718
1719 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1720 return success();
1721}
1722
1723ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1724 std::optional<OperationName> payloadOpName;
1725 NamedAttrList payloadOpAttrs;
1726 if (succeeded(parser.parseOptionalLBrace())) {
1727 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1728 if (failed(operationName))
1729 return failure();
1730 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1731 return failure();
1732 payloadOpName = operationName.value();
1733 if (parser.parseRBrace())
1734 return failure();
1735 }
1736
1737 if (parseDstStyleOp(
1738 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1739 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1740 }))
1741 return failure();
1742
1743 if (payloadOpName.has_value()) {
1744 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1745 ArrayRef(result.operands), /*initFirst=*/true);
1746 } else {
1747 SmallVector<OpAsmParser::Argument> regionArgs;
1748 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1749 /*allowType=*/true, /*allowAttrs=*/true)) {
1750 return failure();
1751 }
1752
1753 Region *body = result.addRegion();
1754 if (parser.parseRegion(*body, regionArgs))
1755 return failure();
1756 }
1757
1758 return success();
1759}
1760
1761static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1762 ArrayRef<int64_t> attributeValue) {
1763 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1764}
1765
1766void ReduceOp::print(OpAsmPrinter &p) {
1767 Block *mapper = getBody();
1768 Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1769 if (payloadOp) {
1770 printShortForm(p, payloadOp);
1771 }
1772
1773 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1774 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1775 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1776 if (!payloadOp) {
1777 // Print region if the payload op was not detected.
1778 p.increaseIndent();
1779 p.printNewline();
1780 p << "(";
1781 llvm::interleaveComma(mapper->getArguments(), p,
1782 [&](auto arg) { p.printRegionArgument(arg); });
1783 p << ") ";
1784
1785 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1786 p.decreaseIndent();
1787 }
1788}
1789
1790LogicalResult ReduceOp::verify() {
1791 ArrayRef<int64_t> dimensionsRef = getDimensions();
1792
1793 for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1794 if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1795 llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1796 return emitOpError() << "expects all inputs to have the same shapes. "
1797 "Shape at input-index "
1798 << i
1799 << " is not equal to the shape at input-index 0.";
1800 }
1801 }
1802 for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1803 if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1804 llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1805 return emitOpError() << "expects all outputs to have the same shapes. "
1806 "Shape at output-index "
1807 << i
1808 << " is not equal to the shape at output-index 0.";
1809 }
1810 }
1811 auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1812 auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1813
1814 DenseSet<int64_t> dimensionsToReduce;
1815 for (int64_t dimension : dimensionsRef) {
1816 if (dimension < 0 || dimension >= inputType.getRank()) {
1817 return emitOpError()
1818 << "dimensions for reduction should be in the range [0, "
1819 << inputType.getRank() - 1 << "].";
1820 }
1821 dimensionsToReduce.insert(dimension);
1822 }
1823
1824 auto inputDims = inputType.getShape();
1825 auto initDims = initType.getShape();
1826
1827 // Input dimensions that will be left after the reduction.
1828 SmallVector<int64_t> reducedInputDims;
1829 for (const auto &en : llvm::enumerate(inputDims)) {
1830 if (!dimensionsToReduce.count(en.index()))
1831 reducedInputDims.push_back(en.value());
1832 }
1833
1834 if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1835 return emitOpError() << "number of dimensions after reduction "
1836 << reducedInputDims.size()
1837 << " doesn't match the init rank "
1838 << initType.getRank();
1839 }
1840
1841 if (reducedInputDims != initDims)
1842 return emitOpError() << "init dimensions [" << initDims
1843 << "] doesn't match input dimensions after reduction ["
1844 << reducedInputDims << "]";
1845
1846 Block *block = getBody();
1847 if (block->getNumArguments() != this->getNumOperands())
1848 return emitOpError()
1849 << "mismatching number of operands and block arguments";
1850
1851 // Check that the first block arguments match the element type of the inputs.
1852 for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1853 Type inputElementType =
1854 llvm::cast<ShapedType>(input.getType()).getElementType();
1855 if (inputElementType != bbArg.getType())
1856 return emitOpError()
1857 << "input element type " << inputElementType
1858 << " does not match corresponding block argument type "
1859 << bbArg.getType();
1860 }
1861
1862 // Check that the last block arguments match the element type of the outputs.
1863 for (auto [output, bbArg] : llvm::zip(
1864 getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1865 auto outputElementType =
1866 llvm::cast<ShapedType>(output.getType()).getElementType();
1867 if (outputElementType != bbArg.getType())
1868 return emitOpError()
1869 << "output element type " << outputElementType
1870 << " does not match corresponding block argument type "
1871 << bbArg.getType();
1872 }
1873 return success();
1874}
1875
1876//===----------------------------------------------------------------------===//
1877// TransposeOp
1878//===----------------------------------------------------------------------===//
1879
1880static void buildIdentityRegion(OpBuilder &builder, Location loc,
1881 Region &region, ValueRange inputs,
1882 ValueRange outputs) {
1883 buildGenericRegion(builder, loc, region, inputs, outputs,
1884 [](OpBuilder &b, Location loc, ValueRange args) {
1885 if (!args.empty())
1886 b.create<linalg::YieldOp>(loc, args[0]);
1887 });
1888}
1889
1890void TransposeOp::build(::mlir::OpBuilder &builder,
1891 ::mlir::OperationState &result, Value input, Value init,
1892 DenseI64ArrayAttr permutation,
1893 ArrayRef<NamedAttribute> attributes) {
1894 result.addOperands(input);
1895 result.addOperands(init);
1896 result.addAttribute(getPermutationAttrName(result.name), permutation);
1897 result.addAttributes(attributes);
1898
1899 // Add output types for `RankedTensorType` output arguments.
1900 Type initType = init.getType();
1901 if (llvm::isa<RankedTensorType>(initType))
1902 result.addTypes(initType);
1903
1904 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1905 init);
1906}
1907
1908void TransposeOp::build(::mlir::OpBuilder &builder,
1909 ::mlir::OperationState &result, Value input, Value init,
1910 ArrayRef<int64_t> permutation,
1911 ArrayRef<NamedAttribute> attributes) {
1912 build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1913 attributes);
1914}
1915
1916ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1917 if (failed(parseDstStyleOp(
1918 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1919 return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1920 })))
1921 return failure();
1922
1923 OpBuilder builder(parser.getContext());
1924 buildIdentityRegion(builder, result.location, *result.addRegion(),
1925 /*inputs=*/result.operands,
1926 /*outputs=*/{});
1927 return success();
1928}
1929
1930void TransposeOp::getAsmResultNames(
1931 function_ref<void(Value, StringRef)> setNameFn) {
1932 if (!getResults().empty())
1933 setNameFn(getResults().front(), "transposed");
1934}
1935
1936void TransposeOp::print(OpAsmPrinter &p) {
1937 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1938 printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1939 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1940}
1941
1942LogicalResult TransposeOp::verify() {
1943 ArrayRef<int64_t> permutationRef = getPermutation();
1944
1945 if (!isPermutationVector(permutationRef))
1946 return emitOpError("permutation is not valid");
1947
1948 auto inputType = getInput().getType();
1949 auto initType = getInit().getType();
1950
1951 int64_t rank = inputType.getRank();
1952
1953 if (rank != initType.getRank())
1954 return emitOpError() << "input rank " << rank
1955 << " does not match init rank " << initType.getRank();
1956
1957 if (rank != static_cast<int64_t>(permutationRef.size()))
1958 return emitOpError() << "size of permutation " << permutationRef.size()
1959 << " does not match the argument rank " << rank;
1960
1961 auto inputDims = inputType.getShape();
1962 auto initDims = initType.getShape();
1963
1964 for (int64_t i = 0; i < rank; ++i) {
1965 int64_t inputDim = inputDims[permutationRef[i]];
1966 int64_t initDim = initDims[i];
1967
1968 if (inputDim != initDim) {
1969 return emitOpError() << "dim(result, " << i << ") = " << initDim
1970 << " doesn't match dim(input, permutation[" << i
1971 << "]) = " << inputDim;
1972 }
1973 }
1974
1975 return success();
1976}
1977
1978SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1979 int64_t rank = getInit().getType().getRank();
1980 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1981}
1982
1983ArrayAttr TransposeOp::getIndexingMaps() {
1984 Builder builder(getContext());
1985 int64_t rank = getInit().getType().getRank();
1986 return builder.getAffineMapArrayAttr(
1987 {inversePermutation(AffineMap::getPermutationMap(
1988 llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1989 builder.getMultiDimIdentityMap(rank)});
1990}
1991
1992void TransposeOp::getEffects(
1993 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1994 &effects) {
1995 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1996}
1997
1998Speculation::Speculatability TransposeOp::getSpeculatability() {
1999 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2000}
2001
2002LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2003 SmallVectorImpl<OpFoldResult> &result) {
2004 // Only the tensor type is supported.
2005 if (!isa<TensorType>(getInput().getType()))
2006 return failure();
2007
2008 // An empty permutation (rank-0 operand) makes the transpose a no-op.
2009 if (getPermutation().size() == 0) {
2010 result.push_back(getInput());
2011 return success();
2012 }
2013 // Identity permutation.
2014 if (isIdentityPermutation(getPermutation())) {
2015 result.push_back(getInput());
2016 return success();
2017 }
2018
2019 return failure();
2020}
2021
2022/// Fold transpose with transpose.
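/// As an illustrative sketch (types and names are hypothetical), the two
/// transposes
///   %t0 = linalg.transpose ins(%arg0 : tensor<2x3x4xf32>)
///          outs(%i0 : tensor<4x2x3xf32>) permutation = [2, 0, 1]
///   %t1 = linalg.transpose ins(%t0 : tensor<4x2x3xf32>)
///          outs(%i1 : tensor<3x4x2xf32>) permutation = [2, 0, 1]
/// are rewritten into a single transpose of %arg0 into %i1 with the composed
/// permutation [1, 2, 0].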
2023struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2024 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2025
2026 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2027 PatternRewriter &rewriter) const override {
2028 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2029 if (!defTransposeOp)
2030 return failure();
2031 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2032 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2033 SmallVector<int64_t> foldedPerms;
2034 foldedPerms.reserve(perms.size());
2035 for (int64_t perm : perms)
2036 foldedPerms.push_back(defPerms[perm]);
2037
2038 rewriter.replaceOpWithNewOp<TransposeOp>(
2039 transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2040 foldedPerms);
2041 return success();
2042 }
2043};
2044
2045 /// This pattern canonicalizes transpose by swapping the order of
2046 /// broadcast and transpose:
2047 /// transpose(broadcast(input)) -> broadcast(transpose(input))
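/// For instance (a sketch with hypothetical shapes), with broadcast
/// dimensions = [0] and transpose permutation = [2, 1, 0]:
///   %b = linalg.broadcast ins(%in : tensor<2x4xf32>)
///          outs(%i0 : tensor<5x2x4xf32>) dimensions = [0]
///   %t = linalg.transpose ins(%b : tensor<5x2x4xf32>)
///          outs(%i1 : tensor<4x2x5xf32>) permutation = [2, 1, 0]
/// becomes a transpose of %in with permutation [1, 0] into tensor<4x2xf32>,
/// followed by a broadcast with dimensions = [2] into tensor<4x2x5xf32>.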
2048struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2049 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2050
2051 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2052 PatternRewriter &rewriter) const override {
2053 Value input = transposeOp.getInput();
2054 BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2055 if (!input.hasOneUse() || !broadcastOp)
2056 return failure();
2057
2058 ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2059 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2060
2061 // Get new perms and new dimensions.
2062 SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2063 SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2064 SmallVector<int64_t> resultDimensions;
2065 unsigned dimensionSize = dimensions.size();
2066 for (unsigned i = 0; i < dimensionSize; ++i)
2067 resultDimensions.push_back(invertPerm[dimensions[i]]);
2068
2069 // Create transpose result.
2070 Value broadcastInput = broadcastOp.getInput();
2071 Location loc = transposeOp.getLoc();
2072 MLIRContext *ctx = transposeOp.getContext();
2073 SmallVector<OpFoldResult> dims;
2074 auto broadcastInputTy =
2075 mlir::cast<RankedTensorType>(broadcastInput.getType());
2076 unsigned inputRank = broadcastInputTy.getRank();
2077 for (unsigned i = 0; i < inputRank; ++i) {
2078 if (broadcastInputTy.isDynamicDim(i)) {
2079 dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2080 ->getResult(0));
2081 } else {
2082 dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2083 broadcastInputTy.getDimSize(i)));
2084 }
2085 }
2086 SmallVector<OpFoldResult> transposeResultShapes =
2087 applyPermutation(dims, resultPerms);
2088 Value transposeInit = rewriter.create<tensor::EmptyOp>(
2089 transposeOp.getLoc(), transposeResultShapes,
2090 broadcastInputTy.getElementType());
2091
2092 // Create broadcast(transpose(input)).
2093 Value transposeResult =
2094 rewriter
2095 .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2096 resultPerms)
2097 ->getResult(0);
2098 rewriter.replaceOpWithNewOp<BroadcastOp>(
2099 transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2100 return success();
2101 }
2102};
2103
2104void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2105 MLIRContext *context) {
2106 results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2107}
2108
2109//===----------------------------------------------------------------------===//
2110// BroadcastOp
2111//===----------------------------------------------------------------------===//
2112
2113void BroadcastOp::build(::mlir::OpBuilder &builder,
2114 ::mlir::OperationState &result, Value input, Value init,
2115 DenseI64ArrayAttr dimensions,
2116 ArrayRef<NamedAttribute> attributes) {
2117 result.addOperands(input);
2118 result.addOperands(init);
2119 result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2120 result.addAttributes(attributes);
2121
2122 // Add output types for `RankedTensorType` output arguments.
2123 Type initType = init.getType();
2124 if (llvm::isa<RankedTensorType>(initType))
2125 result.addTypes(initType);
2126
2127 buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2128 init);
2129}
2130
2131void BroadcastOp::build(::mlir::OpBuilder &builder,
2132 ::mlir::OperationState &result, Value input, Value init,
2133 ArrayRef<int64_t> dimensions,
2134 ArrayRef<NamedAttribute> attributes) {
2135 build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2136 attributes);
2137}
2138
2139ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2140 if (failed(parseDstStyleOp(
2141 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2142 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2143 })))
2144 return failure();
2145
2146 OpBuilder builder(parser.getContext());
2147 buildIdentityRegion(builder, result.location, *result.addRegion(),
2148 /*inputs=*/result.operands,
2149 /*outputs=*/{});
2150 return success();
2151}
2152
2153void BroadcastOp::getAsmResultNames(
2154 function_ref<void(Value, StringRef)> setNameFn) {
2155 if (!getResults().empty())
2156 setNameFn(getResults().front(), "broadcasted");
2157}
2158
2159void BroadcastOp::print(OpAsmPrinter &p) {
2160 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2161 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2162 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2163}
2164
2165LogicalResult BroadcastOp::verify() {
2166 ArrayRef<int64_t> dimensionsRef = getDimensions();
2167
2168 auto inputType = getInput().getType();
2169 auto initType = getInit().getType();
2170
2171 int64_t inputRank = inputType.getRank();
2172 int64_t initRank = initType.getRank();
2173
2174 auto inputShape = inputType.getShape();
2175 auto initShape = initType.getShape();
2176
2177 if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2178 return emitOpError() << "input rank plus added dimensions does not "
2179 "match init rank. input rank: "
2180 << inputRank
2181 << ", dimensions size: " << dimensionsRef.size()
2182 << ", init rank: " << initRank;
2183
2184 for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2185 if (dim < 0 || dim >= initRank)
2186 return emitOpError() << "dimension " << idx
2187 << " is out of range. expected range: [0, "
2188 << initRank - 1 << "], got: " << dim;
2189 }
2190
2191 // Mapping from input dims to init dims.
2192 SmallVector<int64_t> dimMap;
2193 for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2194 if (!llvm::is_contained(dimensionsRef, dim))
2195 dimMap.push_back(dim);
2196 }
2197
2198 for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2199 // This dimension is mapped from the input. Init and input dims should
2200 // match.
2201 if (inputShape[inputDimIdx] != initShape[initDimIdx])
2202 return emitOpError() << "input dim " << inputDimIdx
2203 << " should match init dim " << initDimIdx
2204 << ". input: " << inputShape[inputDimIdx]
2205 << ", init: " << initShape[initDimIdx];
2206 }
2207
2208 return success();
2209}
2210
2211SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2212 int64_t rank = getInit().getType().getRank();
2213 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2214}
2215
2216ArrayAttr BroadcastOp::getIndexingMaps() {
2217 Builder builder(getContext());
2218 int64_t rank = getInit().getType().getRank();
2219 return builder.getAffineMapArrayAttr(
2220 {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2221 builder.getMultiDimIdentityMap(rank)});
2222}
2223
2224void BroadcastOp::getEffects(
2225 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2226 &effects) {
2227 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2228}
2229
2230Speculation::Speculatability BroadcastOp::getSpeculatability() {
2231 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2232}
2233
2234void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2235 MLIRContext *context) {
2236 results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2237}
2238
2239//===----------------------------------------------------------------------===//
2240// YieldOp
2241//===----------------------------------------------------------------------===//
2242
2243void linalg::YieldOp::print(OpAsmPrinter &p) {
2244 if (getNumOperands() > 0)
2245 p << ' ' << getOperands();
2246 p.printOptionalAttrDict((*this)->getAttrs());
2247 if (getNumOperands() > 0)
2248 p << " : " << getOperandTypes();
2249}
2250
2251ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2252 SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2253 SmallVector<Type, 2> types;
2254 SMLoc loc = parser.getCurrentLocation();
2255 return failure(parser.parseOperandList(opInfo) ||
2256 parser.parseOptionalAttrDict(result.attributes) ||
2257 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2258 parser.resolveOperands(opInfo, types, loc, result.operands));
2259}
2260
2261 // Check that the operand number and types match the element types of the
2262 // enclosing LinalgOp's init (outs) operands.
2263static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2264 if (op.getNumOperands() != linalgOp.getNumDpsInits())
2265 return op.emitOpError("expected number of yield values (")
2266 << op.getNumOperands()
2267 << ") to match the number of inits / outs operands of the enclosing "
2268 << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2269
2270 for (OpOperand &opOperand : op->getOpOperands()) {
2271 OpOperand *outputOperand =
2272 linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2273 Type elementType = outputOperand->get().getType();
2274 if (isa<MemRefType, RankedTensorType>(elementType))
2275 elementType = getElementTypeOrSelf(outputOperand->get().getType());
2276 if (opOperand.get().getType() != elementType)
2277 return op.emitOpError("type of yield operand ")
2278 << (opOperand.getOperandNumber() + 1) << " ("
2279 << opOperand.get().getType() << ") doesn't match "
2280 << "the element type of the enclosing linalg.generic op ("
2281 << elementType << ")";
2282 }
2283 return success();
2284}
2285
2286LogicalResult linalg::YieldOp::verify() {
2287 auto *parentOp = (*this)->getParentOp();
2288 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2289 return emitOpError("expected single non-empty parent region");
2290
2291 if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2292 return verifyYield(*this, linalgOp);
2293
2294 return emitOpError("expected parent op with LinalgOp interface");
2295}
2296
2297//===----------------------------------------------------------------------===//
2298// IndexOp
2299//===----------------------------------------------------------------------===//
2300
2301LogicalResult IndexOp::verify() {
2302 auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2303 if (!linalgOp)
2304 return emitOpError("expected parent op with LinalgOp interface");
2305 if (linalgOp.getNumLoops() <= getDim())
2306 return emitOpError("expected dim (")
2307 << getDim() << ") to be lower than the number of loops ("
2308 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2309 return success();
2310}
2311
2312OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2313 auto linalgOp = dyn_cast_or_null<LinalgOp>((*this)->getParentOp());
2314 // Bail out if `linalg.index` does not have a proper parent yet at this
2315 // point, e.g., when calling `createOrFold` during IR construction in
2316 // `GenericOp::build`.
2317 if (!linalgOp)
2318 return OpFoldResult{};
2319
2320 // Index of unit dims is always 0.
2321 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2322 uint64_t dim = getDim();
2323 assert(dim < loopBounds.size() && "Dim is out of bounds");
2324 if (loopBounds[dim] == 1)
2325 return IntegerAttr::get(IndexType::get(getContext()), 0);
2326
2327 return OpFoldResult{};
2328}
2329
2330/////// Operations corresponding to library calls defined with Tablegen ////////
2331
2332#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2333
2334#define GET_OP_CLASSES
2335#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2336
2337#define GET_OP_CLASSES
2338#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2339#define GET_OP_CLASSES
2340#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2341
2342AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2343 unsigned rank,
2344 MLIRContext *context) {
2345 if (maybeMap)
2346 return *maybeMap;
2347 if (rank == 0)
2348 return AffineMap::get(context);
2349 return AffineMap::getMultiDimIdentityMap(rank, context);
2350}
2351
2352SmallVector<AffineExpr, 4>
2353mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2354 MLIRContext *context) {
2355 SmallVector<AffineExpr, 4> res;
2356 res.reserve(num);
2357 for (unsigned i = 0; i < num; ++i)
2358 res.push_back(getAffineDimExpr(startIdx++, context));
2359 return res;
2360}
2361
2362SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2363 ArrayRef<AffineExpr> b) {
2364 auto rangeA = llvm::make_range(a.begin(), a.end());
2365 auto rangeB = llvm::make_range(b.begin(), b.end());
2366 auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2367 return llvm::to_vector<4>(concatRanges);
2368}
2369
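/// Appends a mangled encoding of type `t` to `ss` for use in library call
/// names, e.g. (illustrative) memref<4x?xf32> becomes "view4xsxf32" and
/// vector<4x8xi32> becomes "vector4x8i32"; returns failure for types that
/// cannot be mangled.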
2370static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2371 if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2372 ss << "view";
2373 for (auto size : memref.getShape())
2374 if (size < 0)
2375 ss << "sx";
2376 else
2377 ss << size << "x";
2378 if (failed(appendMangledType(ss, memref.getElementType())))
2379 return failure();
2380 if (auto as = memref.getMemorySpace()) {
2381 if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2382 ss << "as" << attr.getInt();
2383 else
2384 return failure();
2385 }
2386 return success();
2387 }
2388 if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2389 ss << "vector";
2390 llvm::interleave(
2391 vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2392 if (failed(appendMangledType(ss, vec.getElementType())))
2393 return failure();
2394 return success();
2395 }
2396 if (t.isSignlessIntOrIndexOrFloat()) {
2397 ss << t;
2398 return success();
2399 }
2400 return failure();
2401}
2402
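/// Generates the external library call name for a LinalgOp by mangling the op
/// name and its operand types, e.g. (illustrative) a linalg.copy between two
/// memref<?xf32> buffers yields "linalg_copy_viewsxf32_viewsxf32".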
2403std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2404 assert(isa<LinalgOp>(op));
2405 std::string name(op->getName().getStringRef().str());
2406 std::string fun = "";
2407 for (NamedAttribute kv : op->getAttrs()) {
2408 if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2409 fun = stringifyEnum(ufa.getValue()).str() + "_";
2410 } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2411 fun = stringifyEnum(bfa.getValue()).str() + "_";
2412 }
2413 }
2414 name.reserve(128);
2415 llvm::replace(name, '.', '_');
2416 llvm::raw_string_ostream ss(name);
2417 ss << "_" << fun;
2418 for (Type t : op->getOperandTypes()) {
2419 if (failed(appendMangledType(ss, t)))
2420 return std::string();
2421 ss << "_";
2422 }
2423 name.pop_back();
2424 return name;
2425}
2426
2427//===----------------------------------------------------------------------===//
2428// Canonicalizers and Folders.
2429//===----------------------------------------------------------------------===//
2430
2431namespace {
2432struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2433 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2434
2435 LogicalResult matchAndRewrite(LinalgOp op,
2436 PatternRewriter &rewriter) const override {
2437 for (OpOperand &opOperand : op->getOpOperands()) {
2438 // Linalg "inputs" may be either tensor or memref type.
2439 // tensor<0xelt_type> is a convention that may not always mean
2440 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2441 auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2442 if (!mt)
2443 continue;
2444 if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2445 rewriter.eraseOp(op);
2446 return success();
2447 }
2448 }
2449 return failure();
2450 }
2451};
2452
2453 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2454 /// result that is more static than the linalg op.
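/// For example (an illustrative sketch), given
///   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// the pattern casts the init operand to tensor<4x8xf32>, re-creates the
/// linalg op so that it directly produces tensor<4x8xf32>, and casts that
/// result back to tensor<?x?xf32> for any remaining uses of %0.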
2455struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2456 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2457
2458 LogicalResult matchAndRewrite(tensor::CastOp castOp,
2459 PatternRewriter &rewriter) const override {
2460 if (!tensor::canFoldIntoProducerOp(castOp))
2461 return failure();
2462
2463 auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2464 if (!linalgOp)
2465 return failure();
2466
2467 // The cast can be in a conditionally reachable region, in which case folding
2468 // would generate invalid code. Only conservatively fold ops in the same
2469 // block for now.
2470 if (castOp->getBlock() != linalgOp->getBlock())
2471 return failure();
2472
2473 OpBuilder::InsertionGuard guard(rewriter);
2474 rewriter.setInsertionPoint(linalgOp);
2475
2476 Location loc = linalgOp.getLoc();
2477 OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2478 unsigned resultNumber = resultValue.getResultNumber();
2479 auto resultType =
2480 llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2481 // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2482 // going from a more dynamic shape to a less dynamic shape. If the producer
2483 // for this cast, i.e. producer of the out operand, is also an operation
2484 // that folds with tensor.cast consumer (like this pattern), the cast will
2485 // continue to propagate as far up the stack as it can go.
2486 OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2487 Value newOperand =
2488 rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2489 SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2490 SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2491 linalgOp.getDpsInits().end());
2492 outputOperands[resultNumber] = newOperand;
2493 newOperands.append(outputOperands.begin(), outputOperands.end());
2494
2495 SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2496 linalgOp->result_type_end());
2497 resultTypes[resultNumber] = resultType;
2498 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2499
2500 // Create a tensor.cast operation back to the original type.
2501 Value castBack = rewriter.create<tensor::CastOp>(
2502 loc, resultValue.getType(), newOp->getResult(resultNumber));
2503
2504 SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2505 results[resultNumber] = castBack;
2506 rewriter.replaceOp(linalgOp, results);
2507 rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2508 return success();
2509 }
2510};
2511
2512 /// For each operand in `operands`, this function maps the static sizes of its
2513 /// dimensions to the corresponding affine dim expressions.
2514static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2515 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2516 for (OpOperand &opOperand : operands) {
2517 if (linalgOp.isScalar(&opOperand))
2518 continue;
2519 Value src = opOperand.get();
2520 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2521 auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2522
2523 // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2524 // `tensor.cast` operation and source of the cast operation has a static
2525 // shape, then assign it to the `sourceShape`.
2526 auto *parentOp = src.getDefiningOp();
2527 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2528 if (parentOp) {
2529 if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2530 Value castSource = castOp.getSource();
2531 auto castSourceType =
2532 llvm::dyn_cast<RankedTensorType>(castSource.getType());
2533 if (castSourceType && castSourceType.hasStaticShape())
2534 sourceShape = castSourceType.getShape();
2535 }
2536 }
2537
2538 // If the source shape's dimension has a static shape, map the affine dim
2539 // expression to the known static size.
2540 for (unsigned i = 0; i < sourceShape.size(); i++) {
2541 if (sourceType.isDynamicDim(i))
2542 continue;
2543 if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2544 affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2545 }
2546 }
2547}
2548
2549 /// Creates a new operand for `opOperand` of `linalgOp` with the static sizes
2550 /// recorded in `affineExprToSize`. New operands are appended to `newOperands`
2551 /// and their result types are stored in `resultTypes`. If `opOperand` requires
2552 /// no change, `changeNeeded` is left as is and the same operand is appended to
2553 /// the `newOperands` list.
2554static void createNewOperandWithStaticSizes(
2555 Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2556 llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2557 SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2558 bool &changeNeeded) {
2559 Value src = opOperand->get();
2560 newOperands.push_back(src);
2561 if (linalgOp.isScalar(opOperand))
2562 return;
2563 auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2564 Type resultType = sourceType;
2565 if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2566 resultTypes.push_back(resultType);
2567 return;
2568 }
2569 ArrayRef<int64_t> sourceShape = sourceType.getShape();
2570 AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2571 SmallVector<int64_t> newShape;
2572 // If operand is updated with new shape, `newOperandNeeded` will be
2573 // true.
2574 bool newOperandNeeded = false;
2575 for (unsigned i = 0; i < sourceShape.size(); i++) {
2576 int64_t dimShape = sourceShape[i];
2577 AffineExpr dimExpr = sourceMap.getResult(i);
2578 if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2579 newShape.push_back(dimShape);
2580 continue;
2581 }
2582 // Dimension has a dynamic shape and corresponding affine dim
2583 // expression is present in the map. So assign the size for the
2584 // given affine dim expression to the dimension.
2585 newShape.push_back(affineExprToSize[dimExpr]);
2586 newOperandNeeded = true;
2587 }
2588 resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2589 sourceType.getEncoding());
2590 if (newOperandNeeded) {
2591 changeNeeded = true;
2592 // Get the new operand value given its size and element type by
2593 // casting it.
2594 Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2595 unsigned index = opOperand->getOperandNumber();
2596 newOperands[index] = newOperand;
2597 }
2598 if (linalgOp.isDpsInit(opOperand))
2599 resultTypes.push_back(resultType);
2600}
2601
2602 /// Static shapes for the operands can be inferred if any one of the operands
2603 /// has a static shape. This can be done by referring to the affine dim
2604 /// expressions of the operands.
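/// As an illustrative sketch, for an elementwise op with identity maps
///   %0 = linalg.generic ... ins(%a : tensor<4x?xf32>)
///          outs(%b : tensor<?x8xf32>) -> tensor<?x8xf32>
/// the sizes 4 and 8 are attached to d0 and d1, so both operands are cast to
/// tensor<4x8xf32>, the op is re-created with that static type, and the new
/// result is cast back to tensor<?x8xf32> for existing uses.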
2605struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2606 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2607
2608 LogicalResult matchAndRewrite(LinalgOp linalgOp,
2609 PatternRewriter &rewriter) const override {
2610 if (!linalgOp.hasPureTensorSemantics())
2611 return failure();
2612
2613 // Maps must be projected permutations.
2614 if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2615 return !map.isProjectedPermutation();
2616 }))
2617 return failure();
2618
2619 // Maps affine dim expressions to the static size of that dimension.
2620 llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2621 Location loc = linalgOp.getLoc();
2622
2623 // For each affine dim expression, check whether the size is known. If so,
2624 // add it to the map.
2625 populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2626
2627 SmallVector<Value> newOperands;
2628 SmallVector<Type> resultTypes;
2629
2630 // `changeNeeded` is `false` if the operands of `linalgOp` require no
2631 // change in their types.
2632 bool changeNeeded = false;
2633 newOperands.reserve(linalgOp->getNumOperands());
2634 resultTypes.reserve(linalgOp.getNumDpsInits());
2635
2636 // Iterate over all the operands and update the static sizes.
2637 for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2638 createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2639 affineExprToSize, linalgOp, newOperands,
2640 resultTypes, changeNeeded);
2641 }
2642
2643 // If the generic op has all the required static information, no
2644 // canonicalization needed.
2645 if (!changeNeeded)
2646 return failure();
2647
2648 // Clone op.
2649 Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2650 SmallVector<Value> replacements;
2651 replacements.reserve(newOp->getNumResults());
2652 for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2653 Value newResult = std::get<1>(it);
2654 Value oldResult = std::get<0>(it);
2655 Type newType = newResult.getType();
2656 Type oldType = oldResult.getType();
2657 replacements.push_back(
2658 (newType != oldType)
2659 ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2660 : newResult);
2661 }
2662 rewriter.replaceOp(linalgOp, replacements);
2663 return success();
2664 }
2665};
2666
2667} // namespace
2668
2669// All named ops canonicalizers and folders are auto-generated in the
2670// .cpp.inc.
2671
2672//===----------------------------------------------------------------------===//
2673// SoftmaxOp
2674//===----------------------------------------------------------------------===//
2675
2676LogicalResult SoftmaxOp::verify() {
2677 ShapedType inputType = getInputOperandType();
2678 ShapedType outputType = getOutputOperandType();
2679
2680 ArrayRef<int64_t> inputShape = inputType.getShape();
2681 ArrayRef<int64_t> outputShape = outputType.getShape();
2682 if (failed(verifyCompatibleShape(inputShape, outputShape)))
2683 return emitOpError("incompatible output shape");
2684
2685 int64_t inputRank = getInputOperandRank();
2686 int64_t dimension = getDimension();
2687 if ((dimension < 0) || (dimension >= inputRank))
2688 return emitOpError("incorrect dimension specified");
2689
2690 return success();
2691}
2692
2693SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2694 int64_t operandRank = getInputOperandRank();
2695 SmallVector<Range> loopBounds(operandRank);
2696 Location loc = getLoc();
2697 Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2698 Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2699 Value source = getInput();
2700 for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2701 loopBounds[dim].offset = zero;
2702 loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2703 loopBounds[dim].stride = one;
2704 }
2705 return loopBounds;
2706}
2707
2708SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2709 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2710 utils::IteratorType::parallel);
2711 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2712 return iteratorTypes;
2713}
2714
2715FailureOr<TilingResult>
2716SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2717 ArrayRef<OpFoldResult> offsets,
2718 ArrayRef<OpFoldResult> sizes) {
2719 int64_t rank = getInputOperandRank();
2720 auto oneAttr = builder.getI64IntegerAttr(1);
2721 SmallVector<OpFoldResult> strides(rank, oneAttr);
2722 SmallVector<Value> tiledOperands;
2723 Operation *inputSlice =
2724 getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2725 if (!inputSlice) {
2726 return emitOpError("failed to compute input slice");
2727 }
2728 tiledOperands.emplace_back(inputSlice->getResult(0));
2729 Operation *outputSlice =
2730 getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2731 if (!outputSlice) {
2732 return emitOpError("failed to compute output slice");
2733 }
2734 tiledOperands.emplace_back(outputSlice->getResult(0));
2735
2736 SmallVector<Type, 4> resultTypes;
2737 if (hasPureTensorSemantics())
2738 resultTypes.push_back(tiledOperands[1].getType());
2739 Operation *tiledOp =
2740 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2741
2742 return TilingResult{
2743 {tiledOp},
2744 SmallVector<Value>(tiledOp->getResults()),
2745 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2746}
2747
2748LogicalResult SoftmaxOp::getResultTilePosition(
2749 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2750 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2751 SmallVector<OpFoldResult> &resultSizes) {
2752 if (resultNumber == 0) {
2753 resultOffsets.assign(offsets.begin(), offsets.end());
2754 resultSizes.assign(sizes.begin(), sizes.end());
2755 return success();
2756 }
2757 return failure();
2758}
2759
2760// cast(dynamic) -> static.
2761LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2762 return memref::foldMemRefCast(*this);
2763}
2764
2765LogicalResult
2766SoftmaxOp::reifyResultShapes(OpBuilder &b,
2767 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2768 SmallVector<OpFoldResult> shapes;
2769 Location loc = getOperation()->getLoc();
2770 IRRewriter rewriter(b);
2771 auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2772 auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2773 for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2774 if (!outputShapedType.isDynamicDim(dim)) {
2775 // Static dim: Return IntegerAttr.
2776 shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2777 } else {
2778 // Dynamic dim: Return Value.
2779 OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2780 shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2781 }
2782 }
2783 reifiedReturnShapes.emplace_back(std::move(shapes));
2784 return success();
2785}
2786
2787void SoftmaxOp::getEffects(
2788 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2789 &effects) {
2790 for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2791 if (!llvm::isa<MemRefType>(operand.getType()))
2792 continue;
2793 effects.emplace_back(MemoryEffects::Read::get(),
2794 &getOperation()->getOpOperand(index), /*stage=*/0,
2795 /*effectOnFullRegion=*/true,
2796 SideEffects::DefaultResource::get());
2797 }
2798
2799 for (OpOperand &operand : getDpsInitsMutable()) {
2800 if (!llvm::isa<MemRefType>(operand.get().getType()))
2801 continue;
2802 effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2803 /*effectOnFullRegion=*/true,
2804 SideEffects::DefaultResource::get());
2805 effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2806 /*effectOnFullRegion=*/true,
2807 SideEffects::DefaultResource::get());
2808 }
2809}
2810
2811// Helper functions for softmax decomposition.
2812// @{
2813
2814// Helper function to produce the iterator types (reduction or parallel) and
2815// affine maps for the iterators used in the decomposition of softmax.
2816// This method creates:
2817// If allParallel == true:
2818// - iterator type: {parallel, ..., parallel}
2819// - affine maps:
2820// -- identity with inputRank dimensions.
2821// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2822// where N == inputRank.
2823//
2824// If allParallel == false:
2825// - iterator type at dim(i) == parallel for i != \p dim and
2826// dim(dim) == reduction.
2827// - affine map:
2828// -- identity with inputRank dimensions.
2829// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2830// where N == inputRank.
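// For example, with inputRank == 3 and dim == 1 this returns:
//   - iterator types: {parallel, reduction, parallel}
//     (or all parallel when allParallel == true), and
//   - affine maps: (d0, d1, d2) -> (d0, d1, d2) and (d0, d1, d2) -> (d0, d2).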
2831static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2832computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2833 int64_t dim, bool allParallel = false) {
2834 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2835 utils::IteratorType::parallel);
2836 if (!allParallel)
2837 iteratorTypes[dim] = utils::IteratorType::reduction;
2838 MLIRContext *ctxt = builder.getContext();
2839 auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2840 SmallVector<AffineExpr, 2> affineExprs;
2841 for (int i = 0; i < inputRank; i++) {
2842 if (i != dim)
2843 affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2844 }
2845 auto reductionMap =
2846 AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2847 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2848 return std::make_tuple(iteratorTypes, indexingMaps);
2849}
2850
2851// Helper function to produce a linalg.generic that computes a reduction on
2852// dimension \p dim with the operation type \p T.
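// For example (sketch), reducing a tensor<2x16xf32> along dim = 1 with
// T = arith::MaxNumFOp produces roughly:
//   linalg.generic {indexing_maps = [(d0, d1) -> (d0, d1), (d0, d1) -> (d0)],
//                   iterator_types = ["parallel", "reduction"]}
//       ins(%input : tensor<2x16xf32>) outs(%output : tensor<2xf32>) {
//     ^bb0(%in: f32, %acc: f32):
//       %m = arith.maxnumf %in, %acc : f32
//       linalg.yield %m : f32
//   }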
2853template <typename T>
2854static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2855 int64_t dim) {
2856 auto inputType = cast<ShapedType>(input.getType());
2857 ArrayRef<int64_t> inputShape = inputType.getShape();
2858 int64_t inputRank = inputShape.size();
2859 auto [iteratorTypes, indexingMaps] =
2860 computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2861 assert(indexingMaps.size() == 2 &&
2862 "We should have two maps: 1 for the input, 1 for the output");
2863 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2864
2865 auto genericOp = builder.create<linalg::GenericOp>(
2866 loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2867 [&](OpBuilder &b, Location loc, ValueRange args) {
2868 Value result = b.create<T>(loc, args[0], args[1]);
2869 b.create<linalg::YieldOp>(loc, result);
2870 });
2871 return genericOp.getResult(0);
2872}
2873
2874/// Produce a linalg generic that computes the second step of the softmax
2875/// decomposition: res = exp(input - max), where \p max is the max of \p input
2876/// on dimension \p dim.
2877static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2878 Value max, Value output, int64_t dim) {
2879 auto inputType = cast<ShapedType>(input.getType());
2880 ArrayRef<int64_t> inputShape = inputType.getShape();
2881 int64_t inputRank = inputShape.size();
2882 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2883 builder, inputRank, dim, /*allParallel=*/true);
2884 assert(indexingMaps.size() == 2 && "We should have one map for each input");
2885 assert(indexingMaps[0].isIdentity() && "input map should be identity");
2886 // Add the affine map for the output argument.
2887 indexingMaps.push_back(indexingMaps[0]);
2888 auto genericOp = builder.create<linalg::GenericOp>(
2889 loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2890 iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2891 Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2892 Value result = b.create<math::ExpOp>(loc, diff);
2893 b.create<linalg::YieldOp>(loc, result);
2894 });
2895 return genericOp.getResult(0);
2896}
2897
2898/// Produce a linalg generic that computes the final step of the softmax
2899/// decomposition.
2900/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2901/// yield n / d
2902/// }
2903static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2904 Value denominator, Value output, int64_t dim) {
2905 auto inputType = cast<ShapedType>(numerator.getType());
2906 ArrayRef<int64_t> inputShape = inputType.getShape();
2907 int64_t inputRank = inputShape.size();
2908 auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2909 builder, inputRank, dim, /*allParallel=*/true);
2910 assert(indexingMaps.size() == 2 &&
2911 "We should have one map for each input (2)");
2912 assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2913 // Add the affine map for the output tensor.
2914 indexingMaps.push_back(indexingMaps[0]);
2915 auto genericOp = builder.create<linalg::GenericOp>(
2916 loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2917 indexingMaps, iteratorTypes,
2918 [&](OpBuilder &b, Location loc, ValueRange args) {
2919 Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2920 b.create<linalg::YieldOp>(loc, result);
2921 });
2922 return genericOp.getResult(0);
2923}
2924// @} End helper functions for softmax decomposition.
2925
2926/// Given an N-dimensional tensor x, this method converts
2927/// softmax(x) to the following sequence of operations:
2928///
/// 1. Compute the max of x along dimension d. This results
///    in an N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    an N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    an N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z by l. This gives the N-dimensional softmax.
///    softmax = z / l
2943///
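/// For illustration, applied along d = 1 to a tensor<2x16xf32>, the produced
/// IR has roughly this shape (value names illustrative, types mostly elided):
///   %red  = tensor.empty() : tensor<2xf32>
///   %mIni = linalg.fill ins(%negInf) outs(%red)   // neutral element of max
///   %m    = linalg.generic ... // reduce arith.maxnumf over d     (step 1)
///   %z    = linalg.generic ... // exp(input - m), m broadcast     (step 2)
///   %lIni = linalg.fill ins(%zero) outs(%red)     // neutral element of add
///   %l    = linalg.generic ... // reduce arith.addf over d        (step 3)
///   %res  = linalg.generic ... // z / l, l broadcast              (step 4)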
2944FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2945 OpBuilder::InsertionGuard guard(b);
2946 b.setInsertionPoint(*this);
2947 Location loc = getLoc();
2948 Value input = getInput();
2949 ShapedType inputType = getInputOperandType();
2950 Type elementType = inputType.getElementType();
2951 int64_t reductionDim = getDimension();
2952 SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2953 Value output = getOutput();
2954 dims.erase(dims.begin() + reductionDim);
2955 // Step 1: Compute max along dim.
2956 Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2957 Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2958 elementType, b, loc,
2959 /*useOnlyFiniteValue=*/true);
2960 Value neutralForMaxFInit =
2961 b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2962 .result();
2963 Value max =
2964 reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2965
2966 // Step 2: Subtract max from input and exponentiate.
2967 Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2968
2969 // Step 3: Compute sum along dim.
2970 Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2971 b, loc, /*useOnlyFiniteValue=*/true);
2972 Value zeroInit =
2973 b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2974 Value denominator =
2975 reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2976
2977 // Step 4: Compute softmax.
2978 Value result =
2979 buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2980 return SmallVector<Value>{result};
2981}
2982
2983//===----------------------------------------------------------------------===//
2984// WinogradFilterTransformOp
2985//===----------------------------------------------------------------------===//
2986
2987LogicalResult WinogradFilterTransformOp::verify() {
2988 auto filterType = cast<ShapedType>(getFilter().getType());
2989 ArrayRef<int64_t> filterShape = filterType.getShape();
2990 int64_t filterH = filterShape[getFilterHDim()];
2991 int64_t filterW = filterShape[getFilterWDim()];
2992 int64_t r = getR();
2993 int64_t m = getM();
2994
2995 if (filterH != r && filterH != 1)
2996 return emitOpError("expect filter height either equals to r or 1");
2997 if (filterW != r && filterW != 1)
2998 return emitOpError("expect filter width either equals to r or 1");
2999 if (filterH == 1 && filterW == 1)
3000 return emitOpError("expect either filter height or width equals to r");
3001
3002 SmallVector<int64_t> expectedOutputShape;
3003 expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
3004 expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
3005 expectedOutputShape.push_back(filterShape[getFilterCDim()]);
3006 expectedOutputShape.push_back(filterShape[getFilterFDim()]);
3007
3008 auto outputType = cast<ShapedType>(getOutput().getType());
3009 ArrayRef<int64_t> outputShape = outputType.getShape();
3010 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3011 return emitOpError("the output shape is not expected");
3012 }
3013 return success();
3014}
3015
3016SmallVector<Range>
3017WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3018 Location loc = getLoc();
3019 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3020 IntegerAttr oneAttr = builder.getIndexAttr(1);
3021 Value filter = getFilter();
3022 int64_t filterRank = getFilterOperandRank();
3023 SmallVector<Range> loopBounds(filterRank);
3024 for (unsigned dim = 0; dim < filterRank; ++dim) {
3025 loopBounds[dim].offset = zeroAttr;
3026 loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
3027 loopBounds[dim].stride = oneAttr;
3028 }
3029 return loopBounds;
3030}
3031
3032SmallVector<utils::IteratorType>
3033WinogradFilterTransformOp::getLoopIteratorTypes() {
3034 int64_t filterRank = getFilterOperandRank();
3035 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3036 utils::IteratorType::parallel);
3037 return iteratorTypes;
3038}
3039
3040LogicalResult WinogradFilterTransformOp::getResultTilePosition(
3041 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3042 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3043 SmallVector<OpFoldResult> &resultSizes) {
3044 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3045 ShapedType filterType = getFilterOperandType();
3046 ArrayRef<int64_t> filterShape = filterType.getShape();
3047 int64_t filterH = filterShape[getFilterHDim()];
3048 int64_t filterW = filterShape[getFilterWDim()];
3049 int64_t m = getM();
3050 int64_t r = getR();
3051 int64_t alpha = m + r - 1;
3052 int64_t alphaH = filterH != 1 ? alpha : 1;
3053 int64_t alphaW = filterW != 1 ? alpha : 1;
3054 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3055 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3056
3057 resultOffsets.append(
3058 {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3059 resultSizes.append(
3060 {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3061
3062 return success();
3063}
3064
/// Implement tiling for winograd_filter_transform.
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
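///
/// For illustration, with m = 4, r = 3, a filter of type tensor<64x3x3x32xf32>
/// (F, KH, KW, C), and a tile with sizes F = 16, C = 8, this extracts a
/// tensor<16x3x3x8xf32> slice of the filter and a tensor<6x6x8x16xf32> slice
/// of the output (alpha = m + r - 1 = 6), then clones the op onto the slices.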
3071FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3072 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3073 ArrayRef<OpFoldResult> sizes) {
3074 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3075 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3076 ShapedType filterType = getFilterOperandType();
3077 ArrayRef<int64_t> filterShape = filterType.getShape();
3078 int64_t filterH = filterShape[getFilterHDim()];
3079 int64_t filterW = filterShape[getFilterWDim()];
3080 IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3081 IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3082 SmallVector<Value> tiledOperands;
3083 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3084
3085 sliceOffsets.append(
3086 {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3087 sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3088 sizes[getFilterCDim()]});
3089 int64_t filterRank = getFilterOperandRank();
3090 SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3091 Location loc = getLoc();
3092 auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3093 loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3094 tiledOperands.emplace_back(filterSlice);
3095
3096 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3097 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3098 resultSizes)))
3099 return failure();
3100
3101 int64_t outputRank = getOutputOperandRank();
3102 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3103 auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3104 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3105 tiledOperands.emplace_back(outputSlice);
3106
3107 SmallVector<Type> resultTypes;
3108 resultTypes.push_back(tiledOperands[1].getType());
3109 Operation *tiledOp =
3110 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3111
3112 return TilingResult{
3113 {tiledOp},
3114 SmallVector<Value>(tiledOp->getResults()),
3115 llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3116}
3117
3118//===----------------------------------------------------------------------===//
3119// WinogradInputTransformOp
3120//===----------------------------------------------------------------------===//
3121
3122LogicalResult WinogradInputTransformOp::verify() {
3123 auto inputType = cast<ShapedType>(getInput().getType());
3124 ArrayRef<int64_t> inputShape = inputType.getShape();
3125 int64_t inputH = inputShape[getInputHDim()];
3126 int64_t inputW = inputShape[getInputWDim()];
3127 int m = getM();
3128 int r = getR();
3129 int64_t tileSize = m + r - 1;
3130
3131 auto outputType = cast<ShapedType>(getOutput().getType());
3132 ArrayRef<int64_t> outputShape = outputType.getShape();
3133 bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3134 bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3135
3136 SmallVector<int64_t> expectedOutputShape(6, inputH);
3137 if (ShapedType::isDynamic(inputH)) {
3138 expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3139 expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3140 } else {
3141 expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3142 expectedOutputShape[getOutputTileHDim()] =
3143 leftTransform ? (inputH - (r - 1)) / m : inputH;
3144 }
3145 if (ShapedType::isDynamic(inputW)) {
3146 expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3147 expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3148 } else {
3149 expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3150 expectedOutputShape[getOutputTileWDim()] =
3151 rightTransform ? (inputW - (r - 1)) / m : inputW;
3152 }
3153 expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3154 expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3155
3156 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3157 return emitOpError("the output shape is not expected");
3158 }
3159 return success();
3160}
3161
3162SmallVector<Range>
3163WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3164 Location loc = getLoc();
3165 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3166 IntegerAttr oneAttr = builder.getIndexAttr(1);
3167 Value output = getOutput();
3168 int64_t outputRank = getOutputOperandRank();
3169 SmallVector<Range> loopBounds(outputRank);
3170 for (unsigned dim = 0; dim < outputRank; ++dim) {
3171 loopBounds[dim].offset = zeroAttr;
3172 // alphaH, alphaW, tileH, tileW, N, C
3173 loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3174 loopBounds[dim].stride = oneAttr;
3175 }
3176 return loopBounds;
3177}
3178
3179SmallVector<utils::IteratorType>
3180WinogradInputTransformOp::getLoopIteratorTypes() {
3181 int64_t outputRank = getOutputOperandRank();
3182 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3183 utils::IteratorType::parallel);
3184 return iteratorTypes;
3185}
3186
3187LogicalResult WinogradInputTransformOp::getResultTilePosition(
3188 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3189 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3190 SmallVector<OpFoldResult> &resultSizes) {
3191 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3192 ShapedType outputType = getOutputOperandType();
3193 ArrayRef<int64_t> outputShape = outputType.getShape();
3194 int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3195 int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3196
3197 int64_t m = getM();
3198 int64_t r = getR();
3199 int64_t alpha = m + r - 1;
3200 int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3201 int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3202
3203 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3204 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3205
3206 resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3207 offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3208 offsets[getOutputCDim()]});
3209 resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3210 sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3211 sizes[getOutputCDim()]});
3212
3213 return success();
3214}
3215
/// Implement tiling for winograd_input_transform.
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
/// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
/// `offsets` are the values for the offsets of tileH, tileW, N, C for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, C for one
/// tile.
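///
/// When the height/width transform applies, tile offsets and sizes on
/// tileH/tileW are mapped back to the input H/W dimensions as `offset * m`
/// and `size * m + (r - 1)`. For illustration, with m = 4 and r = 3, a tileH
/// offset of 2 and a tileH size of 2 read input rows starting at H offset 8
/// and spanning 10 rows.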
3222FailureOr<TilingResult>
3223WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3224 ArrayRef<OpFoldResult> offsets,
3225 ArrayRef<OpFoldResult> sizes) {
3226 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3227 int64_t m = getM();
3228 int64_t r = getR();
3229
3230 ShapedType outputType = getOutputOperandType();
3231 ArrayRef<int64_t> outputShape = outputType.getShape();
3232 int64_t alphaH = outputShape[getOutputAlphaHDim()];
3233 int64_t alphaW = outputShape[getOutputAlphaWDim()];
3234
3235 Location loc = getLoc();
3236 MLIRContext *context = builder.getContext();
3237 auto identityAffineMap =
3238 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3239 auto offsetAffineMap =
3240 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3241 Value mappedOffsetH = affine::makeComposedAffineApply(
3242 builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3243 offsets[getOutputTileHDim()]);
3244 Value mappedOffsetW = affine::makeComposedAffineApply(
3245 builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3246 offsets[getOutputTileWDim()]);
3247 auto sizeAffineMap = AffineMap::get(
3248 1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3249 Value mappedSizeH = affine::makeComposedAffineApply(
3250 builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3251 Value mappedSizeW = affine::makeComposedAffineApply(
3252 builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3253
3254 SmallVector<Value> tiledOperands;
3255 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3256
3257 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3258 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3259 sliceOffsets.append(
3260 {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3261 OpFoldResult sizeH =
3262 alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3263 OpFoldResult sizeW =
3264 alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3265 sliceSizes.append(
3266 {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3267 int64_t inputRank = getInputOperandRank();
3268 SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3269 auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3270 loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3271 tiledOperands.emplace_back(inputSlice);
3272
3273 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3274 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3275 resultSizes)))
3276 return failure();
3277
3278 int64_t outputRank = getOutputOperandRank();
3279 SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3280 auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3281 loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3282 tiledOperands.emplace_back(outputSlice);
3283
3284 SmallVector<Type> resultTypes;
3285 resultTypes.push_back(tiledOperands[1].getType());
3286 Operation *tiledOp =
3287 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3288
3289 return TilingResult{
3290 {tiledOp},
3291 SmallVector<Value>(tiledOp->getResults()),
3292 llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3293}
3294
3295//===----------------------------------------------------------------------===//
3296// WinogradOutputTransformOp
3297//===----------------------------------------------------------------------===//
3298
3299LogicalResult WinogradOutputTransformOp::verify() {
3300 auto valueType = cast<ShapedType>(getValue().getType());
3301 ArrayRef<int64_t> valueShape = valueType.getShape();
3302 int64_t valueH = valueShape[getValueAlphaHDim()];
3303 int64_t valueW = valueShape[getValueAlphaWDim()];
3304 int64_t valueTileH = valueShape[getValueTileHDim()];
3305 int64_t valueTileW = valueShape[getValueTileWDim()];
3306 int m = getM();
3307 int r = getR();
3308 bool leftTransform = valueH != 1;
3309 bool rightTransform = valueW != 1;
3310
3311 int64_t outputRank = getOutputOperandRank();
3312 SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3313 if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3314 expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3315 } else {
3316 if (valueH != (leftTransform ? m + r - 1 : 1))
3317 return emitOpError("expect input height equals to input tile size");
3318 expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3319 }
3320 if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3321 expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3322 } else {
3323 if (valueW != (rightTransform ? m + r - 1 : 1))
3324 return emitOpError("expect input width equals to input tile size");
3325 expectedOutputShape[getOutputWDim()] =
3326 (rightTransform ? m : 1) * valueTileW;
3327 }
3328 expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3329 expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3330
3331 auto outputType = cast<ShapedType>(getOutput().getType());
3332 ArrayRef<int64_t> outputShape = outputType.getShape();
3333 if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3334 return emitOpError("the output shape is not expected");
3335 }
3336 return success();
3337}
3338
3339SmallVector<Range>
3340WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3341 Location loc = getLoc();
3342 IntegerAttr zeroAttr = builder.getIndexAttr(0);
3343 IntegerAttr oneAttr = builder.getIndexAttr(1);
3344 Value value = getValue();
3345 int64_t valueRank = getValueOperandRank();
3346 SmallVector<Range> loopBounds(valueRank);
3347 for (unsigned dim = 0; dim < valueRank; ++dim) {
3348 loopBounds[dim].offset = zeroAttr;
3349 // alphaH, alphaW, tileH, tileW, N, F
3350 loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3351 loopBounds[dim].stride = oneAttr;
3352 }
3353 return loopBounds;
3354}
3355
3356SmallVector<utils::IteratorType>
3357WinogradOutputTransformOp::getLoopIteratorTypes() {
3358 int64_t valueRank = getValueOperandRank();
3359 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3360 utils::IteratorType::parallel);
3361 return iteratorTypes;
3362}
3363
3364LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3365 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3366 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3367 SmallVector<OpFoldResult> &resultSizes) {
3368 int64_t m = getM();
3369
3370 Location loc = getLoc();
3371 MLIRContext *context = builder.getContext();
3372 auto identityAffineMap =
3373 AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3374 auto affineMap =
3375 AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3376
3377 ShapedType valueType = getValueOperandType();
3378 ArrayRef<int64_t> valueShape = valueType.getShape();
3379 int64_t valueH = valueShape[0];
3380 int64_t valueW = valueShape[1];
3381 Value mappedOffsetH = affine::makeComposedAffineApply(
3382 builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3383 offsets[getValueTileHDim()]);
3384 Value mappedOffsetW = affine::makeComposedAffineApply(
3385 builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3386 offsets[getValueTileWDim()]);
3387 Value mappedSizeH = affine::makeComposedAffineApply(
3388 builder, loc, affineMap, sizes[getValueTileHDim()]);
3389 Value mappedSizeW = affine::makeComposedAffineApply(
3390 builder, loc, affineMap, sizes[getValueTileWDim()]);
3391
3392 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3393 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3394 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3395 OpFoldResult sizeH =
3396 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3397 OpFoldResult sizeW =
3398 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3399
3400 resultOffsets.append(
3401 {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3402 resultSizes.append(
3403 {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3404 return success();
3405}
3406
/// Implement tiling for winograd_output_transform.
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
/// N, F). The output of winograd_output_transform is (N, H, W, F).
/// Users can specify the tile sizes of tileH, tileW, N, and F.
/// `offsets` are the values for the offsets of tileH, tileW, N, F for one
/// tile. `sizes` are the values for the sizes of tileH, tileW, N, F for one
/// tile.
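///
/// When the corresponding transform applies, tile offsets and sizes on
/// tileH/tileW map to the output H/W dimensions as `offset * m` and
/// `size * m`. For illustration, with m = 4, a tileH offset of 2 and a tileH
/// size of 2 write output rows [8, 16).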
3413FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3414 OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3415 ArrayRef<OpFoldResult> sizes) {
3416 IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3417 IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3418 Location loc = getLoc();
3419 SmallVector<Value> tiledOperands;
3420 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3421
3422 ShapedType valueType = getValueOperandType();
3423 ArrayRef<int64_t> valueShape = valueType.getShape();
3424 int64_t alphaH = valueShape[getValueAlphaHDim()];
3425 int64_t alphaW = valueShape[getValueAlphaWDim()];
3426 IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3427 IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3428
3429 sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3430 offsets[getValueTileWDim()], offsets[getValueNDim()],
3431 offsets[getValueFDim()]});
3432 sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3433 sizes[getValueTileWDim()], sizes[getValueNDim()],
3434 sizes[getValueFDim()]});
3435 int64_t valueRank = getValueOperandRank();
3436 SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3437 auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3438 loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3439 tiledOperands.emplace_back(valueSlice);
3440
3441 SmallVector<OpFoldResult> resultOffsets, resultSizes;
3442 if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3443 resultSizes)))
3444 return failure();
3445
3446 int64_t outputRank = getOutputOperandRank();
3447 SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3448 auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3449 loc, getOutput(), resultOffsets, resultSizes, strides);
3450 tiledOperands.emplace_back(outputSlice);
3451
3452 SmallVector<Type> resultTypes;
3453 resultTypes.push_back(tiledOperands[1].getType());
3454 Operation *tiledOp =
3455 mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3456
3457 return TilingResult{
3458 {tiledOp},
3459 SmallVector<Value>(tiledOp->getResults()),
3460 llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3461}
3462
3463//===----------------------------------------------------------------------===//
3464// LinalgDialect
3465// TODO: Merge with the LinalgDialect block at the bottom
3466//===----------------------------------------------------------------------===//
3467
// Returns true if the result expressions of `subMap` are a subset of those of
// `fullMap`.
3469static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3470 auto explicitRange = subMap.getResults();
3471 auto defaultRange = fullMap.getResults();
3472 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3473 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
  llvm::set_union(explicitSet, defaultSet);
3475 return explicitSet == defaultSet;
3476}
3477
/// Check if the user-defined map is a valid broadcast map. Broadcast indexing
/// maps are defined here in the context of the corresponding default indexing
/// maps for the given op, so the check becomes very simple: just check the
/// number of result dims.
/// Returns true if `explicitMap` is broadcast with respect to `defaultMap`.
static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
  return explicitMap.getNumResults() < defaultMap.getNumResults();
}
3487
/// Verifies the broadcast and transpose semantics specified by the explicit
/// indexing map of the MatmulOp \p matmulOp for the operand specified by
/// \p opIndex.
3491static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3492 unsigned opIndex) {
3493 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3494 SmallVector<AffineMap, 3> defaultIndexingMaps =
3495 matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3496
3497 auto opIndexingMap = opIndexingMaps[opIndex];
3498 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3499 // Check general validity of indexing map results.
3500 if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3501 return matmulOp->emitOpError()
3502 << "Unexpected dim expression in map result.";
3503
3504 if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3505 if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3506 return matmulOp->emitOpError()
3507 << "Invalid broadcast requested, should be (d2).";
3508 }
3509 return success();
3510 }
3511 return success();
3512}
3513
3514// Check general validity of input indexing map of
3515// BatchMatmulOp/BatchReduceMatmulOp.
3516template <typename OpTy>
3517static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
3518 AffineMap opIndexingMap,
3519 AffineMap defaultIndexingMap, bool isLHS) {
3520 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3521 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3522 "Expected BatchMatmulOp or BatchReduceMatmulOp");
3523 // Check the result dims are valid.
  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3525 return batchVariantMatmulOp->emitOpError()
3526 << "Unexpected result dim expression (outside the set of default "
3527 "result dims).";
3528
3529 // Check for valid number of result dims of input maps.
3530 if (opIndexingMap.getNumResults() > 3)
3531 return batchVariantMatmulOp->emitOpError()
3532 << "no. of result dim expressions exceeds 3.";
3533
3534 auto hasValidBatchDim = [](AffineMap map) {
    AffineExpr batchDim = map.getResult(0);
    return batchDim.isFunctionOfDim(0);
3537 };
3538
3539 // Check if the requested broadcast is valid.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3541 if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3542 return batchVariantMatmulOp->emitOpError()
3543 << "Invalid broadcast requested.";
3544 } else if (!hasValidBatchDim(opIndexingMap)) {
3545 return batchVariantMatmulOp->emitOpError()
3546 << "Invalid batch dimension expression.";
3547 }
3548 return success();
3549}
3550
3551/// This function checks if the given AffineMap for the output of a
3552/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
3553/// dimensions and if the output map result dimensions are valid.
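///
/// For illustration, the default output maps that satisfy these checks are
///   BatchMatmulOp:       (d0, d1, d2, d3) -> (d0, d1, d2)
///   BatchReduceMatmulOp: (d0, d1, d2, d3) -> (d1, d2)
/// with d0 = batch, d1 = M, d2 = N, d3 = K.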
3554template <typename OpTy>
3555static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
3556 AffineMap opIndexingMap) {
3557 assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
3558 isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
3559 "Expected BatchMatmulOp or BatchReduceMatmulOp");
  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 3) {
3563 return batchVariantMatmulOp->emitOpError()
3564 << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3565 << ").";
3566 }
3567 if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
3568 opIndexingMap.getNumResults() != 2) {
3569 return batchVariantMatmulOp->emitOpError()
3570 << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
3571 << ").";
3572 }
3573
3574 auto areValidOutputResultDim = [&](AffineMap outputMap) {
3575 return isa<BatchMatmulOp>(batchVariantMatmulOp)
3576 ? outputMap.getResult(0).isFunctionOfDim(0) &&
3577 outputMap.getResult(1).isFunctionOfDim(1) &&
3578 outputMap.getResult(2).isFunctionOfDim(2)
3579 : outputMap.getResult(0).isFunctionOfDim(1) &&
3580 outputMap.getResult(1).isFunctionOfDim(2);
3581 };
3582
3583 if (!areValidOutputResultDim(opIndexingMap)) {
3584 return batchVariantMatmulOp->emitOpError()
3585 << "Invalid output map result dimension.";
3586 }
3587
3588 return success();
3589}
3590
3591/// Verifies the broadcast and transpose semantic specified by the explicit
3592/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
3593/// specified by opIndex.
3594template <typename OpTy>
3595static LogicalResult
3596verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
3597 unsigned opIndex) {
3598 SmallVector<AffineMap, 3> opIndexingMaps =
3599 batchVariantMatmulOp.getIndexingMapsArray();
3600 SmallVector<AffineMap, 3> defaultIndexingMaps =
3601 batchVariantMatmulOp.getDefaultIndexingMaps(
3602 batchVariantMatmulOp->getContext());
3603
3604 if (opIndexingMaps.size() != 3)
3605 return batchVariantMatmulOp->emitOpError()
3606 << "Indexing_map attribute must have 3 affine maps.";
3607
3608 auto opIndexingMap = opIndexingMaps[opIndex];
3609 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3610
3611 if (opIndex == 2 &&
3612 failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
3613 return failure();
3614
3615 if (opIndex != 2 &&
3616 failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
3617 defaultIndexingMap, opIndex == 0)))
3618 return failure();
3619
3620 return success();
3621}
3622
3623namespace mlir {
3624namespace linalg {
3625
3626//===----------------------------------------------------------------------===//
3627// MatMulOp
3628//===----------------------------------------------------------------------===//
3629
/// Returns a list of AffineMaps with the typical matmul indexing
/// characteristics.
3631SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3632 AffineExpr d0, d1, d2;
3633 SmallVector<AffineMap> indexingMaps;
3634 bindDims(context, d0, d1, d2);
3635 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3636 indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3637 indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3638 return indexingMaps;
3639}
3640
3641SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3642 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3643 utils::IteratorType::parallel,
3644 utils::IteratorType::reduction};
3645}
3646
3647unsigned MatmulOp::getNumRegionArgs() { return 3; }
3648
3649std::string MatmulOp::getLibraryCallName() {
3650 return generateLibraryCallName(getOperation());
3651}
3652
3653bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3654
/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
3657bool MatmulOp::hasUserDefinedMaps() {
3658 SmallVector<AffineMap, 3> defaultMaps =
3659 getDefaultIndexingMaps(this->getContext());
3660 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3661 return defaultMaps != explicitMaps;
3662}
3663
3664/// Implements the block region builder for the MatmulOp. This is called by
3665/// 'fillStructuredOpRegion'.
3666void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3667 ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
3670 RegionBuilderHelper helper(b, block);
3671 SmallVector<Value> yields;
3672
3673 TypeFn castVal = TypeFn::cast_signed;
3674 const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3675 return attr.getName() == "cast";
3676 });
3677 if (castIter != attrs.end()) {
3678 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3679 castVal = attr.getValue();
3680 }
3681
3682 Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3683 block.getArgument(0));
3684 Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3685 block.getArgument(1));
3686 Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3687 Value value4 =
3688 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3689 yields.push_back(value4);
3690 helper.yieldOutputs(yields);
3691}
3692
/// Returns true if the given bcastMap is a valid broadcast map. A valid
/// broadcast map must include the K dimension.
/// TODO: Strict inclusion of the K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to require the K dimension in only one input map and infer it
/// for the other input map from the one already carrying it.
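///
/// For illustration, with the default iteration space (d0 = M, d1 = N,
/// d2 = K), affine_map<(d0, d1, d2) -> (d2)> is a valid broadcast map for
/// either input, whereas affine_map<(d0, d1, d2) -> (d0)> is not, since it
/// drops the K dimension.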
3700bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3701 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3702 AffineExpr expr = bcastMap.getResult(0);
3703 // Invalid map if the common dimension of matmul not found.
3704 return expr.isFunctionOfDim(bcastMap.getNumDims() - 1);
3705}
3706
3707FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
  if (parser.parseOptionalKeyword("indexing_maps"))
3709 return ArrayAttr{
3710 nullptr}; // Success in case indexing_maps was not provided.
3711
3712 ArrayAttr arrayAttr;
3713 if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3714 return failure();
3715
3716 if (llvm::any_of(arrayAttr,
3717 [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
    return parser.emitError(parser.getCurrentLocation())
3719 << "element of indexing_maps array is not an affine_map";
3720
3721 return arrayAttr;
3722}
3723
3724ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3725 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3726 if (failed(indexingMapsAttr))
3727 return failure();
3728
3729 if (*indexingMapsAttr == nullptr) {
3730 auto indexingMapAttrs = llvm::map_to_vector(
3731 MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3732 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3733 indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3734 }
3735
3736 result.addAttribute("indexing_maps", *indexingMapsAttr);
3737 return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3738 MatmulOp::getRegionBuilder());
3739}
3740
3741void MatmulOp::print(OpAsmPrinter &p) {
3742 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3743 MatmulOp::getDefaultIndexingMaps(getContext()),
3744 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3745 if (!llvm::equal(getIndexingMaps(), indexingMaps))
3746 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3747
3748 std::array<StringRef, 3> elidedAttrs = {
3749 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3750 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3751 elidedAttrs);
3752}
3753
3754/// Verify the user defined indexing maps.
3755LogicalResult MatmulOp::verify() {
3756 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3757 if (!hasUserDefinedMaps())
3758 return success();
3759
3760 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3761 if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3762 return failure();
3763 }
3764 return success();
3765}
3766
3767LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3768 return memref::foldMemRefCast(*this);
3769}
3770
3771void MatmulOp::getEffects(
3772 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3773 &effects) {
3774 if (hasPureTensorSemantics())
3775 return;
3776 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3777}
3778
3779Speculation::Speculatability MatmulOp::getSpeculatability() {
3780 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3781}
3782
3783//===----------------------------------------------------------------------===//
3784// ContractOp
3785//===----------------------------------------------------------------------===//
3786
3787SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3788 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3789 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3790 // domains are all the same, and each implements a projected permutation.
3791 // Each iteration space dim must occur for at least one operand and either
3792 // takes part in a contraction/reduction or else has parallel iteration type.
  // We have that a dim is a contraction/reduction dim if and only if the dim
  // does not occur for the output operand. We use this fact for fast
  // inference:
3795 // NB: In case we allow dims to occur solely for one input, the above still
3796 // holds: per the einsum semantics, these are reduction dims as well.
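  //
  // For illustration, with matmul-like maps
  //   [(m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n)]
  // m and n appear in the output map and get parallel iterator types, while k
  // does not appear and is inferred to be a reduction dim.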
3797 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3798 for (auto result : outAffineMap.getResults()) {
3799 auto dimExpr = dyn_cast<AffineDimExpr>(result);
3800 assert(dimExpr && "affine_map is a projected permutation");
3801 dimsInOutput[dimExpr.getPosition()] = true;
3802 }
3803
3804 SmallVector<utils::IteratorType> iteratorTypes;
3805 for (auto dimOccursInOutput : dimsInOutput)
3806 iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3807 : utils::IteratorType::reduction);
3808
3809 return iteratorTypes;
3810}
3811
3812unsigned ContractOp::getNumRegionArgs() { return 3; }
3813
/// Implements the block region builder for the ContractOp. This is called by
/// 'fillStructuredOpRegion'.
3815void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3816 ArrayRef<NamedAttribute> attrs) {
3817 assert(block.getNumArguments() == 3 &&
3818 "ContractOp regionBuilder expects 3 args");
3819 RegionBuilderHelper helper(b, block);
3820
3821 TypeFn castSignedness = TypeFn::cast_signed;
3822 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3823 return attr.getName() == "cast";
3824 });
3825 if (castIter != attrs.end()) {
3826 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3827 castSignedness = attr.getValue();
3828 }
3829
3830 // TODO: Support fields with operators besides mult & add.
3831 Type outType = block.getArgument(2).getType();
3832 Value lhsAtOutType =
3833 helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3834 Value rhsAtOutType =
3835 helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3836 Value productAtOutType =
3837 helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3838 Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3839 productAtOutType);
3840 helper.yieldOutputs({result});
3841}
3842
3843ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3844 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3845 if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3846 return parser.emitError(parser.getCurrentLocation(),
3847 "expected 'indexing_maps' attribute");
3848 result.addAttribute("indexing_maps", *indexingMapsAttr);
3849
3850 return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3851 regionBuilder);
3852}
3853
3854void ContractOp::print(OpAsmPrinter &p) {
3855 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
3856 printNamedStructuredOp(
3857 p, getOperation(), getInputs(), getOutputs(),
3858 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3859}
3860
3861LogicalResult ContractOp::verify() {
3862 int iterationSpaceDims = -1;
3863 // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3864 // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3865 // access an input operand (so occurrence count can be at most 2) and
3866 // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3867 SmallVector<size_t> inOccurrences;
3868 SmallVector<size_t> outOccurrences;
3869
3870 // A helper so that for each operand's affine_map and type we check that ...
3871 auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3872 bool isInput) -> LogicalResult {
3873 // ... the affine_map is a projected permutation;
3874 if (!affineMap.isProjectedPermutation())
3875 return emitError("provided affine_map is not a projected permutation");
3876
3877 // ... the rank of the affine_map's results and corresponding type match;
3878 if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3879 if (affineMap.getNumResults() != shapedType.getRank())
3880 return emitError("ranks of shaped operand and results of corresponding "
3881 "affine_map differ");
3882 } else if (affineMap.getNumResults() != 0) {
3883 return emitError("affine_map specifies shaped access while operand has "
3884 "non-shaped type");
3885 }
3886
3887 // ... the rank of the affine_map's domain is the same as those seen prior;
3888 if (iterationSpaceDims == -1) {
3889 iterationSpaceDims = affineMap.getNumDims();
3890 inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3891 outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3892 } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3893 return emitError("iteration spaces of provided affine_maps differ");
3894 }
3895
3896 // ... update counts of dims used to access either an input or the output.
3897 for (AffineExpr affineExpr : affineMap.getResults()) {
3898 auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3899 if (!affineDimExpr)
3900 llvm_unreachable("affine_map is a projected permutation");
3901
3902 if (isInput)
3903 inOccurrences[affineDimExpr.getPosition()] += 1;
3904 else
3905 outOccurrences[affineDimExpr.getPosition()] += 1;
3906 }
3907
3908 return success();
3909 };
3910
3911 for (auto &&[affineMap, operandType, isInput] :
3912 llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3913 SmallVector<bool>{true, true, false})) {
3914 if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3915 return failure(); // NB: checkAffineMapAndType will emit relevant error.
3916 }
3917
3918 bool hasContractingDim = false;
3919 for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3920 size_t inOccCount = inOccurrences[dimIndex];
3921 size_t outOccCount = outOccurrences[dimIndex];
3922
3923 // We have a contracting dim if and only if ...
3924 hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3925
3926 if (inOccCount == 0 && outOccCount == 0)
3927 return emitError() << "iteration space dim at index " << dimIndex
3928 << " not used to access any operand";
3929
3930 // NB: We disallow a dim which occurs for only one input operand and not
3931 // for the output. In terms of einsum semantics such dims have a
3932 // sensible meaning - namely an additional reduction per each such dim.
3933 // By contrast, the ContractionOpInterface does not know about this
3934 // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
  // while vector.contract's verifier accepts dims of this kind, many of
  // its lowerings give up on encountering these dims.
3937 // TODO: Remove following once we have comprehensive support for input-only
3938 // reduction dims, at both the linalg- and vector-dialect levels.
3939 if (inOccCount == 1 && outOccCount != 1)
3940 return emitError()
3941 << "iteration space dim at index " << dimIndex
3942 << " is neither a contracting dim nor of parallel iteration type";
3943 }
3944
3945 if (!hasContractingDim)
3946 return emitError("'indexing_maps' do not specify a contracting dimension");
3947
3948 return success();
3949}
3950
3951LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3952 return memref::foldMemRefCast(*this);
3953}
3954
3955void ContractOp::getEffects(
3956 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3957 &effects) {
3958 if (hasPureTensorSemantics())
3959 return;
3960 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3961}
3962
3963Speculation::Speculatability ContractOp::getSpeculatability() {
3964 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3965}
3966
3967//===----------------------------------------------------------------------===//
3968// Implementation of BatchMatmulOp
3969//===----------------------------------------------------------------------===//
3970SmallVector<AffineMap>
3971BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3972 AffineExpr d0, d1, d2, d3;
3973 SmallVector<AffineMap> indexingMaps;
3974 bindDims(context, d0, d1, d2, d3);
3975 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3976 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3977 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3978 return indexingMaps;
3979}
3980
3981SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3982 return SmallVector<utils::IteratorType>{
3983 utils::IteratorType::parallel, utils::IteratorType::parallel,
3984 utils::IteratorType::parallel, utils::IteratorType::reduction};
3985}
3986
3987unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3988
3989std::string BatchMatmulOp::getLibraryCallName() {
3990 return generateLibraryCallName(getOperation());
3991}
3992
/// Check if the op has broadcast and/or transpose semantics. Returns true if
/// the user-defined indexing maps are not equal to the default maps.
3995bool BatchMatmulOp::hasUserDefinedMaps() {
3996 SmallVector<AffineMap, 3> defaultMaps =
3997 getDefaultIndexingMaps(this->getContext());
3998 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3999 return defaultMaps != explicitMaps;
4000}
4001
/// Returns true if the given bcastMap is a valid broadcast map. A valid
/// broadcast map must include the K dimension.
/// TODO: Strict inclusion of the K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to require the K dimension in only one input map and infer it
/// for the other input map from the one already carrying it.
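///
/// For illustration, with iteration dims (d0 = batch, d1 = M, d2 = N,
/// d3 = K), maps such as (d3), (d0, d3), or (d1, d3) are accepted for the
/// LHS, and (d3), (d0, d3), or (d3, d2) for the RHS.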
4009bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4010 assert(bcastMap.getNumResults() < 3 &&
4011 "Expected less than 3 result dim expr.");
4012 bool isValid = false;
4013 enum Indices { batchPos, mPos, nPos, kPos };
4014 if (bcastMap.getNumResults() == 1) {
4015 AffineExpr expr = bcastMap.getResult(0);
4016 isValid = expr.isFunctionOfDim(kPos);
4017 } else if (bcastMap.getNumResults() == 2) {
4018 AffineExpr expr0 = bcastMap.getResult(0);
4019 AffineExpr expr1 = bcastMap.getResult(1);
4020 isValid =
4021 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
4022 expr0.isFunctionOfDim(mPos)) &&
4023 expr1.isFunctionOfDim(kPos))
4024 : ((expr0.isFunctionOfDim(batchPos) &&
4025 expr1.isFunctionOfDim(kPos)) ||
4026 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
4027 }
4028 return isValid;
4029}
4030
4031void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4032 ArrayRef<NamedAttribute> attrs) {
  assert(block.getNumArguments() == 3 &&
         "BatchMatmulOp regionBuilder expects 3 args");
4035 RegionBuilderHelper helper(b, block);
4036 SmallVector<Value> yields;
4037
4038 TypeFn castVal = TypeFn::cast_signed;
4039 auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
4040 return attr.getName() == "cast";
4041 });
4042 if (castIter != attrs.end()) {
4043 if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
4044 castVal = attr.getValue();
4045 }
4046
4047 auto toType = block.getArgument(2).getType();
4048 Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
4049 Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
4050 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
4051 Value addVal =
4052 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
4053 yields.push_back(addVal);
4054 helper.yieldOutputs(yields);
4055}
4056
4057ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4058 SmallVector<Attribute, 3> indexingMapsAttr;
4059 Attribute mapAttr;
4060 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4061 if (parser.parseEqual())
4062 return failure();
4063
4064 if (parser.parseLSquare())
4065 return failure();
4066
4067 do {
4068 if (parser.parseAttribute(mapAttr))
4069 return failure();
4070 if (!isa<AffineMapAttr>(mapAttr)) {
4071 return parser.emitError(parser.getCurrentLocation(),
4072 "expected affine map attribute");
4073 }
4074 indexingMapsAttr.push_back(mapAttr);
4075
4076 if (parser.parseOptionalComma())
4077 break;
4078 } while (true);
4079
4080 if (parser.parseRSquare())
4081 return failure();
4082 }
4083 // Initialize indexingMaps, if not supplied explicitly.
4084 if (indexingMapsAttr.empty()) {
4085 indexingMapsAttr = llvm::map_to_vector(
4086 BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4087 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4088 }
4089 result.addAttribute("indexing_maps",
4090 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4091
4092 return ::parseNamedStructuredOp(parser, result,
4093 BatchMatmulOp::getNumRegionArgs(),
4094 BatchMatmulOp::getRegionBuilder());
4095}
4096
4097void BatchMatmulOp::print(OpAsmPrinter &p) {
4098 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4099 BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4100 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4101 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4102 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4103
4104 std::array<StringRef, 3> elidedAttrs = {
4105 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4106 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4107 elidedAttrs);
4108}
4109
4110/// Verify the user defined indexing maps.
4111LogicalResult BatchMatmulOp::verify() {
4112 // Verification of pure batch_matmul is handled by
4113 // verifyStructuredOpInterface().
4114 if (!hasUserDefinedMaps())
4115 return success();
4116
4117 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4118 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
4119 return failure();
4120 }
4121 return success();
4122}
4123
4124LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4125 SmallVectorImpl<OpFoldResult> &) {
4126 return memref::foldMemRefCast(*this);
4127}
4128
4129void BatchMatmulOp::getEffects(
4130 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4131 &effects) {
4132 if (hasPureTensorSemantics())
4133 return;
4134 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4135}
4136
4137Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4138 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4139}
4140
4141//===----------------------------------------------------------------------===//
4142// ElementwiseOp
4143//===----------------------------------------------------------------------===//
4144//
4145namespace {
4146struct ArityGroupAndKind {
4147 // The enum class {Unary, Binary, Ternary, ..}
4148 ElementwiseArityGroup arityGroup;
4149
4150 // The kind (e.g. `exp` or `add`) belonging to the arity group.
4151 union Kind {
4152 UnaryFn unaryFn;
4153 BinaryFn binaryFn;
4154 TernaryFn ternaryFn;
4155 } kind;
4156};
4157
4158unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4159 return static_cast<unsigned>(arityGroup);
4160}
4161} // namespace
4162
4163static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4164 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4165 constexpr int lastBinary =
4166 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4167 constexpr int lastTernary =
4168 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4169
4170 int val = static_cast<int>(kind);
4171 ArityGroupAndKind result;
4172
4173 if (val < lastUnary) {
4174 result.arityGroup = ElementwiseArityGroup::Unary;
4175 result.kind.unaryFn = static_cast<UnaryFn>(val);
4176 return result;
4177 }
4178 if (val < lastBinary) {
4179 result.arityGroup = ElementwiseArityGroup::Binary;
4180 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4181 return result;
4182 }
4183 if (val >= lastTernary) {
4184 llvm_unreachable("unhandled ElementwiseFn");
4185 }
4186 result.arityGroup = ElementwiseArityGroup::Ternary;
4187 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4188 return result;
4189}
4190
4191SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4192 auto rank = getResultRank();
4193 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4194}
4195
4196SmallVector<AffineMap>
4197ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4198 MLIRContext *context) {
4199 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4200 return SmallVector<AffineMap>(numMaps, map);
4201}
4202
4203ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4204 // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4205 Attribute attr;
4206 mlir::linalg::ElementwiseKind elemwiseKindVal;
4207 if (parser.parseKeyword("kind") || parser.parseEqual())
4208 return failure();
4209
4210 if (succeeded(parser.parseAttribute(attr))) {
4211 auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4212 if (!elemwiseKindAttr)
4213 return parser.emitError(parser.getCurrentLocation(),
4214 "expected ElementwiseKind attribute");
4215 elemwiseKindVal = elemwiseKindAttr.getValue();
4216 } else {
4217 return parser.emitError(parser.getCurrentLocation(),
4218 "expected operation 'kind' attribute");
4219 }
4220 result.addAttribute(
4221 "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4222
4223 // Parse optional `indexing_maps`
4224 SmallVector<Attribute, 3> indexingMapsAttr;
4225 Attribute mapAttr;
4226 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4227 if (parser.parseEqual())
4228 return failure();
4229 if (parser.parseLSquare())
4230 return failure();
4231 do {
4232 if (parser.parseAttribute(mapAttr))
4233 return failure();
4234 if (!isa<AffineMapAttr>(mapAttr))
4235 return parser.emitError(parser.getCurrentLocation(),
4236 "expected affine map attribute");
4237 indexingMapsAttr.push_back(mapAttr);
4238 if (parser.parseOptionalComma())
4239 break;
4240 } while (true);
4241 if (parser.parseRSquare())
4242 return failure();
4243 }
4244  // At this stage of parsing, the only way to infer the number of region
4245  // args is through the op kind, as the input/output tensors are not parsed yet.
4246 auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4247 int numRegionArgs =
4248 getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4249 if (parseNamedStructuredOp(parser, result, numRegionArgs,
4250 ElementwiseOp::getRegionBuilder())) {
4251 return parser.emitError(parser.getCurrentLocation(),
4252 "unable to parse elemwise op");
4253 }
4254
4255 // Initialize indexingMaps, if not supplied explicitly.
4256 if (indexingMapsAttr.empty()) {
4257    // We need to infer the numDims of the indexing maps from the output
4258    // type, which has already been parsed by now.
4259 auto resultType = result.operands[result.operands.size() - 1].getType();
4260 auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4261 if (!shapedType)
4262 return parser.emitError(parser.getCurrentLocation(),
4263 "return type needs to be shaped type");
4264 auto numDims = shapedType.getRank();
4265 indexingMapsAttr = llvm::map_to_vector(
4266 ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4267 parser.getContext()),
4268 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4269 }
4270
4271 result.addAttribute("indexing_maps",
4272 parser.getBuilder().getArrayAttr(indexingMapsAttr));
4273 return success();
4274}
4275
4276void ElementwiseOp::print(OpAsmPrinter &p) {
4277 p << " kind=";
4278 p.printAttribute(getKindAttr());
4279 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4280 "indexing_maps"};
4281 unsigned arity =
4282 getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4283 unsigned numDims = getResultRank();
4284
4285 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4286 ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4287 getContext()),
4288 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4289
4290 if (!llvm::equal(getIndexingMaps(), indexingMaps))
4291 p << " indexing_maps = " << llvm::interleaved_array(getIndexingMaps());
4292
4293 printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4294 elidedAttrs);
4295}
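
// As an informal illustration of the custom assembly handled by the parse and
// print methods above (a sketch only; SSA names and types are made up, and the
// attribute mnemonic follows the parser comment above), a binary elementwise
// op with the default identity indexing maps would round-trip roughly as:
//
//   %r = linalg.elementwise kind=#linalg.elemwise_kind<add>
//          ins(%lhs, %rhs : tensor<8x16xf32>, tensor<8x16xf32>)
//          outs(%init : tensor<8x16xf32>) -> tensor<8x16xf32>
//
// The `indexing_maps` clause is elided because it matches the default maps
// produced by getDefaultIndexingMaps.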
4296
4297LogicalResult ElementwiseOp::verify() {
4298 // All necessary checks are done either by
4299 // - EnumAttr (e.g. unknown operation kind)
4300 // - verifyStructuredOpInterface (incorrect map, sizes).
4301 return success();
4302}
4303
4304/// Implements the block region builder for the ElementwiseOp. This is called by
4305/// 'fillStructuredOpRegion'.
4306void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4307 ArrayRef<NamedAttribute> attrs) {
4308 ElementwiseKind elemwiseKind;
4309 for (auto attr : attrs) {
4310 if (attr.getName() == b.getStringAttr("kind")) {
4311 auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4312 assert(kindAttr && "op kind attribute incorrectly set");
4313 elemwiseKind = kindAttr.getValue();
4314 break;
4315 }
4316 }
4317
4318 ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4319 auto arityGroup = groupAndKind.arityGroup;
4320 auto kind = groupAndKind.kind;
4321 assert(block.getNumArguments() ==
4322 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4323 && "Elementwise regionBuilder number of block args mismatch");
4324
4325 RegionBuilderHelper helper(b, block);
4326 SmallVector<Value> yields;
4327 Value result;
4328
4329 if (arityGroup == ElementwiseArityGroup::Unary) {
4330 result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4331
4332 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4333 result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4334 block.getArgument(1));
4335
4336 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4337 result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4338 block.getArgument(1), block.getArgument(2));
4339
4340 } else {
4341 assert(false && "found unhandled category in elemwise");
4342 }
4343
4344 yields.push_back(result);
4345 helper.yieldOutputs(yields);
4346}
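
// For example (a sketch, assuming f32 operands so that RegionBuilderHelper
// lowers the arithmetic to the arith dialect), a binary `mul` kind yields a
// region body roughly equivalent to:
//
//   ^bb0(%a: f32, %b: f32, %out: f32):
//     %0 = arith.mulf %a, %b : f32
//     linalg.yield %0 : f32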
4347
4348LogicalResult ElementwiseOp::fold(FoldAdaptor,
4349 SmallVectorImpl<OpFoldResult> &) {
4350 return memref::foldMemRefCast(*this);
4351}
4352
4353void ElementwiseOp::getEffects(
4354 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4355 &effects) {
4356 if (hasPureTensorSemantics())
4357 return;
4358 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4359}
4360
4361Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4362 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4363}
4364
4365//===----------------------------------------------------------------------===//
4366// PackOp/UnPackOp Common
4367//===----------------------------------------------------------------------===//
4368// Given the (potentially) updated packed type, `newPackedTy`, generates an
4369// updated mixed-tile-sizes attribute. A tile size is updated only
4370// when:
4371// * a dim from newPackedTy is static, and
4372// * the corresponding size from mixedTiles is still dynamic.
4373// Otherwise, the original tile size is preserved.
4374// Note - packed-type-dim and mixed-tile-size should always match!
4375static SmallVector<OpFoldResult>
4376getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4377 SmallVector<OpFoldResult> mixedTiles) {
4378 SmallVector<OpFoldResult> newMixedTileSizes;
4379 for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4380 .getShape()
4381 .take_back(mixedTiles.size()),
4382 mixedTiles)) {
4383 int64_t shape = std::get<0>(it);
4384 if (shape == ShapedType::kDynamic) {
4385 newMixedTileSizes.push_back(std::get<1>(it));
4386 continue;
4387 }
4388
4389 // If the current result dim is static, update the dynamic mixed-size
4390 // (provided the original value is dynamic).
4391 OpFoldResult tile = std::get<1>(it);
4392 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4393 // Already a constant
4394 newMixedTileSizes.push_back(tile);
4395 } else {
4396 assert(getConstantIntValue(tile).value() == shape &&
4397 "tile size and dim size don't match!");
4398 newMixedTileSizes.push_back(
4399 (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4400 }
4401 }
4402
4403 return newMixedTileSizes;
4404}
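
// For example (informal, with made-up SSA names): given trailing packed-type
// dims [16, ?] and mixedTiles [%c16, %sz], the dynamic %c16 (statically known
// to equal 16) is replaced by the index attribute 16, while %sz is kept
// because its dimension is still dynamic, yielding [16, %sz].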
4405
4406template <typename OpTy>
4407static LogicalResult
4408reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4409 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4410 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4411 "applies to only pack or unpack operations");
4412 int64_t destRank = op.getDestRank();
4413 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(destRank));
4414 reifiedReturnShapes[0] =
4415 tensor::getMixedSizes(builder, loc: op.getLoc(), value: op.getDest());
4416 return success();
4417}
4418
4419template <typename OpTy>
4420static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4421 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4422 "applies to only pack or unpack operations");
4423 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4424 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4425 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4426 assert(tiles.size() == dimsToTile.size() &&
4427 "tiles must match indices of dimension to block");
4428 // bind the dimension `i` with the tile factor.
4429 for (auto i : llvm::seq<int64_t>(Begin: 0, End: dimsToTile.size()))
4430 dimAndTileMapping[dimsToTile[i]] = tiles[i];
4431 return dimAndTileMapping;
4432}
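
// For example, with inner_dims_pos = [1, 0] and mixed tiles [32, 16], the
// returned mapping is {1 -> 32, 0 -> 16}.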
4433
4434template <typename OpTy>
4435static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4436 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4437 "applies to only pack or unpack operations");
4438 Builder builder(op);
4439 SmallVector<OpFoldResult> mixedInnerTiles;
4440 unsigned dynamicValIndex = 0;
4441 for (int64_t staticTile : op.getStaticInnerTiles()) {
4442 if (!ShapedType::isDynamic(staticTile))
4443 mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4444 else
4445 mixedInnerTiles.push_back(Elt: op.getInnerTiles()[dynamicValIndex++]);
4446 }
4447 return mixedInnerTiles;
4448}
4449
4450template <typename OpTy>
4451static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4452 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4453 "applies to only pack or unpack operations");
4454 SmallVector<Value> dynamicTiles;
4455 SmallVector<int64_t> staticTiles;
4456 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4457 return staticTiles;
4458}
4459
4460/// Returns true if `dimsPos` is invalid. It is invalid when:
4461/// a) It contains duplicates.
4462/// b) At least one dimension is out of bounds (a valid `dimPos` satisfies 0 <= `dimPos` < `rank`).
4463/// c) The number of elements in `dimsPos` is greater than `rank`.
4464static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4465 size_t rank) {
4466 size_t dimsPosSize = dimsPos.size();
4467 if (dimsPosSize > rank)
4468 return true;
4469 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4470 if (dimsPosSize != uniqued.size())
4471 return true;
4472 return llvm::any_of(Range&: dimsPos, P: [rank](int64_t dimPos) {
4473 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4474 });
4475}
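
// For example, with rank = 3: [0, 2] is a valid specification, whereas
// [0, 0] (duplicate), [3] (out of bounds) and [0, 1, 2, 0] (more entries than
// the rank) are all invalid.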
4476
4477/// Returns true if each dimension of `sourceShape` is no larger than the
4478/// corresponding dimension of `limitShape`, treating dynamic dimensions as always in bound.
4479static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4480 ArrayRef<int64_t> limitShape) {
4481 assert(
4482 sourceShape.size() == limitShape.size() &&
4483 "expected source shape rank, and limit of the shape to have same rank");
4484 return llvm::all_of(
4485 Range: llvm::zip(t&: sourceShape, u&: limitShape), P: [](std::tuple<int64_t, int64_t> it) {
4486 int64_t sourceExtent = std::get<0>(t&: it);
4487 int64_t limit = std::get<1>(t&: it);
4488 return ShapedType::isDynamic(sourceExtent) ||
4489 ShapedType::isDynamic(limit) || sourceExtent <= limit;
4490 });
4491}
4492
4493template <typename OpTy>
4494static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4495 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4496 "applies to only pack or unpack operations");
4497 Operation *op = packOrUnPack.getOperation();
4498
4499 // Return true if we have a zero-value tile.
4500 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4501 return llvm::any_of(Range&: tiles, P: isZeroInteger);
4502 };
4503
4504 // Verify tiles. Do not allow zero tiles.
4505 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4506 if (hasZeros(mixedTiles))
4507 return op->emitError(message: "invalid zero tile factor");
4508
4509 // Verify inner_dims_pos and outer_dims_perm.
4510 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4511 ? packOrUnPack.getSourceType()
4512 : packOrUnPack.getDestType();
4513 size_t unpackedRank = unpackedType.getRank();
4514 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4515 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4516 if (isInvalidPackingPosSpecification(dimsPos: innerDimsPos, rank: unpackedRank))
4517 return op->emitError(message: "invalid inner_dims_pos vector");
4518 if (isInvalidPackingPosSpecification(dimsPos: outerDimPerm, rank: unpackedRank))
4519 return op->emitError(message: "invalid outer_dims_perm vector");
4520 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4521 return op->emitError(message: "outer_dims_perm must be a permutation or empty");
4522
4523 // Tiling factors must be less than or equal to the input rank for pack (or
4524 // output rank for unpack), and must match the number of `inner_dims_pos`.
4525 if (mixedTiles.size() > unpackedRank) {
4526 return op->emitError(message: "tiling factors must be less than or equal to the "
4527 "input rank for pack or output rank for unpack");
4528 }
4529 if (mixedTiles.size() != innerDimsPos.size()) {
4530 return op->emitError(
4531 message: "tiling factors must equal the number of dimensions to tile");
4532 }
4533
4534 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4535 ? packOrUnPack.getDestType()
4536 : packOrUnPack.getSourceType();
4537 size_t packedRank = packedType.getRank();
4538 // Require output rank to match input rank + number of blocking factors.
4539 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4540 if (expectedPackedRank != packedRank) {
4541 return op->emitError(
4542 message: "packed rank != (unpacked rank + num tiling factors), got ")
4543 << packedRank << " != " << expectedPackedRank;
4544 }
4545
4546  // Verify that the result shape is at least as large as the minimum shape
4547  // expected by the pack operation, and that the output shape
4548  // represents full tiles.
4549 RankedTensorType expectedPackedType = PackOp::inferPackedType(
4550 unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4551 if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4552 return op->emitError(message: "the shape of output is not large enough to hold the "
4553 "packed data. Expected at least ")
4554 << expectedPackedType << ", got " << packedType;
4555 }
4556 if (!llvm::all_of(
4557 llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4558 mixedTiles),
4559 [](std::tuple<int64_t, OpFoldResult> it) {
4560 int64_t shape = std::get<0>(t&: it);
4561 if (Attribute attr =
4562 llvm::dyn_cast_if_present<Attribute>(Val&: std::get<1>(t&: it))) {
4563 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4564 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4565 return shape == staticTileSize;
4566 }
4567 return ShapedType::isDynamic(shape);
4568 })) {
4569    return op->emitError(message: "mismatch in inner tile sizes specified and shape of "
4570                         "the tiled dimension in the packed type");
4571 }
4572 return success();
4573}
4574
4575namespace {
4576/// Subset of PackOp/UnPackOp fields used to compute the result of applying
4577/// various permutations to the op.
4578// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4579// these. These may or may not become true foldings / canonicalizations
4580// depending on how aggressive we want to be in automatically folding
4581// transposes.
4582struct PackOrUnPackTransposeResult {
4583 SmallVector<int64_t> innerDimsPos;
4584 SmallVector<OpFoldResult> innerTiles;
4585 SmallVector<int64_t> outerDimsPerm;
4586};
4587} // namespace
4588
4589template <typename OpTy>
4590static PackOrUnPackTransposeResult
4591commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4592 ArrayRef<int64_t> innerPermutation,
4593 ArrayRef<int64_t> outerPermutation) {
4594 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4595 "applies to only pack or unpack operations");
4596 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4597 "some permutation must be non-empty");
4598 PackOrUnPackTransposeResult metadata;
4599 metadata.innerDimsPos =
4600 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4601 metadata.innerTiles =
4602 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4603 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4604 ? packOrUnPackOp.getSourceRank()
4605 : packOrUnPackOp.getDestRank();
4606 metadata.outerDimsPerm =
4607 packOrUnPackOp.getOuterDimsPerm().empty()
4608 ? llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOuterDims))
4609 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4610 if (!innerPermutation.empty()) {
4611 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4612 isPermutationVector(innerPermutation) &&
4613 "invalid inner permutation");
4614 applyPermutationToVector(inVec&: metadata.innerDimsPos, permutation: innerPermutation);
4615 applyPermutationToVector(inVec&: metadata.innerTiles, permutation: innerPermutation);
4616 }
4617 if (!outerPermutation.empty()) {
4618 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4619 isPermutationVector(outerPermutation) &&
4620 "invalid outer permutation");
4621 applyPermutationToVector(inVec&: metadata.outerDimsPerm, permutation: outerPermutation);
4622 }
4623 return metadata;
4624}
4625
4626//===----------------------------------------------------------------------===//
4627// PackOp
4628//===----------------------------------------------------------------------===//
4629
4630void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4631 setNameFn(getResult(), "pack");
4632}
4633
4634void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4635 Value dest, ArrayRef<int64_t> innerDimsPos,
4636 ArrayRef<OpFoldResult> innerTiles,
4637 std::optional<Value> paddingValue,
4638 ArrayRef<int64_t> outerDimsPerm) {
4639 assert(innerDimsPos.size() == innerTiles.size() &&
4640 "number of tile sizes specified must match the specified number of "
4641 "original dimensions to be tiled");
4642 SmallVector<int64_t> staticTileSizes;
4643 SmallVector<Value> dynamicTileSizes;
4644 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4645 build(builder, state, dest.getType(), source, dest,
4646 paddingValue ? *paddingValue : nullptr,
4647 outerDimsPerm.empty() ? nullptr
4648 : builder.getDenseI64ArrayAttr(outerDimsPerm),
4649 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4650 builder.getDenseI64ArrayAttr(staticTileSizes));
4651}
4652
4653LogicalResult
4654PackOp::reifyResultShapes(OpBuilder &builder,
4655 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4656 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4657}
4658
4659DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4660 return getDimAndTileMappingImpl(*this);
4661}
4662
4663SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4664 return getMixedTilesImpl(*this);
4665}
4666
4667SmallVector<int64_t> PackOp::getStaticTiles() {
4668 return getStaticTilesImpl(*this);
4669}
4670
4671ArrayRef<int64_t> PackOp::getAllOuterDims() {
4672 ShapedType inputType = getSourceType();
4673 int64_t inputRank = inputType.getRank();
4674 return getDestType().getShape().take_front(inputRank);
4675}
4676
4677SmallVector<int64_t> PackOp::getTiledOuterDims() {
4678 auto innerDimsPos = getInnerDimsPos();
4679 auto packedShape = getDestType().getShape();
4680 SmallVector<int64_t> res;
4681
4682 for (auto index : innerDimsPos)
4683 res.push_back(packedShape[index]);
4684
4685 return res;
4686}
4687
4688bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4689 ArrayRef<int64_t> innerDimsPos,
4690 ArrayRef<int64_t> outputShape,
4691 ArrayRef<int64_t> outerDimsPerm,
4692 ArrayRef<OpFoldResult> innerTiles) {
4693 SmallVector<int64_t> outputTileSizes(
4694 outputShape.take_front(inputShape.size()));
4695 if (!outerDimsPerm.empty()) {
4696 assert(outerDimsPerm.size() == outputTileSizes.size() &&
4697 "expected output and outer_dims_perm to have same size");
4698 applyPermutationToVector(outputTileSizes,
4699 invertPermutationVector(outerDimsPerm));
4700 }
4701 for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4702 if (ShapedType::isDynamic(inputShape[pos]))
4703 continue;
4704 std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4705
4706 if (!constantTile) {
4707 if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4708 (inputShape[pos] % outputTileSizes[pos] != 0))
4709 return true;
4710 } else if (inputShape[pos] % (*constantTile) != 0) {
4711 return true;
4712 }
4713 }
4714 return false;
4715}
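
// For example (informal): a static input dimension of size 13 tiled with a
// constant tile size of 4 requires a padding value (13 % 4 != 0), whereas a
// dimension of size 16 does not. Dynamic input dimensions do not, by
// themselves, force a padding value here.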
4716
4717LogicalResult PackOp::verify() {
4718 if (failed(commonVerifierPackAndUnPackOp(*this)))
4719 return failure();
4720
4721 // Verify padding value, and bail out if the tile does not divide the
4722 // dimension fully. In the case of dynamic tile factors or dimensions, having
4723 // a partial tile is undefined behavior.
4724 auto paddingValue = getPaddingValue();
4725 if (paddingValue &&
4726 paddingValue.getType() != getSourceType().getElementType()) {
4727 return emitOpError("expected padding_value has ")
4728 << getSourceType().getElementType()
4729 << " but got: " << paddingValue.getType();
4730 }
4731
4732 if (!paddingValue &&
4733 requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4734 getDestType().getShape(), getOuterDimsPerm(),
4735 getMixedTiles())) {
4736 return emitOpError(
4737 "invalid tile factor or output size provided. Only full tiles are "
4738 "supported when padding_value is not set");
4739 }
4740 return success();
4741}
4742
4743/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4744/// Values to kDynamic, even if they are arith.constant values.
4745static SmallVector<int64_t>
4746asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4747 SmallVector<int64_t> result;
4748 for (auto o : ofrs) {
4749 // Have to do this first, as getConstantIntValue special-cases constants.
4750 if (llvm::dyn_cast_if_present<Value>(o))
4751 result.push_back(ShapedType::kDynamic);
4752 else
4753 result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4754 }
4755 return result;
4756}
4757
4758/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4759/// the packed type. Having a shared helper helps implement these two methods in
4760/// a way that ensures that they agree on which dimensions are dynamic.
4761static SmallVector<int64_t> getPackOpResultTypeShape(
4762 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4763 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4764 SmallVector<int64_t> resultShape = llvm::to_vector(Range&: sourceShape);
4765 for (auto tiledDim : llvm::enumerate(First: llvm::to_vector(Range&: innerDimsPos))) {
4766 if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4767 continue;
4768 if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4769 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4770 continue;
4771 }
4772 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4773 Numerator: resultShape[tiledDim.value()], Denominator: innerTileSizes[tiledDim.index()]);
4774 }
4775
4776 // Swap tile loops if outer_dims_perm is available.
4777 if (!outerDimsPerm.empty())
4778 applyPermutationToVector(inVec&: resultShape, permutation: outerDimsPerm);
4779
4780 // Append the inner tile dimensions.
4781 resultShape.append(in_start: innerTileSizes.begin(), in_end: innerTileSizes.end());
4782 return resultShape;
4783}
4784
4785SmallVector<OpFoldResult> PackOp::getResultShape(
4786 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4787 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4788 ArrayRef<int64_t> outerDimsPerm) {
4789 SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4790
4791 AffineExpr s0, s1;
4792 bindSymbols(builder.getContext(), s0, s1);
4793 AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4794 for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4795 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4796 builder, loc, ceilDivExpr,
4797 {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4798 }
4799 if (!outerDimsPerm.empty())
4800 applyPermutationToVector(resultDims, outerDimsPerm);
4801 resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4802
4803 SmallVector<int64_t> resultTypeShape =
4804 getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4805 asShapeWithAnyValueAsDynamic(innerTileSizes),
4806 innerDimsPos, outerDimsPerm);
4807
4808 // Fix-up `resultDims` to ensure that they are Value's if and only if the
4809 // result type shape says it's a dynamic dim. This is needed as callers may
4810 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4811 // dynamic dims returned by that.
4812 for (unsigned i = 0; i < resultDims.size(); ++i) {
4813 if (!ShapedType::isDynamic(resultTypeShape[i]))
4814 continue;
4815 resultDims[i] =
4816 getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4817 }
4818
4819 return resultDims;
4820}
4821
4822/// Get the expected packed type based on source type, tile factors, position of
4823/// the inner tiles and permutation of the outer tiled loop.
4824RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4825 ArrayRef<int64_t> innerTileSizes,
4826 ArrayRef<int64_t> innerDimsPos,
4827 ArrayRef<int64_t> outerDimsPerm) {
4828 SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4829 sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4830 return RankedTensorType::get(resultShape, sourceType.getElementType());
4831}
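
// For example (a sketch): packing tensor<128x256xf32> with
// inner_dims_pos = [0, 1], inner_tiles = [16, 32] and no outer_dims_perm
// infers tensor<8x8x16x32xf32> -- each tiled outer dimension becomes
// ceilDiv(size, tile) and the tile sizes are appended as the innermost dims.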
4832
4833Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4834 ArrayRef<OpFoldResult> innerTileSizes,
4835 ArrayRef<int64_t> innerDimsPos,
4836 ArrayRef<int64_t> outerDimsPerm) {
4837 AffineExpr dim0, dim1;
4838 bindDims(b.getContext(), dim0, dim1);
4839 auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4840 return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4841 {v1, v2});
4842 };
4843
4844 SmallVector<OpFoldResult> mixedSizes;
4845 for (auto [index, value] : llvm::enumerate(
4846 llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4847 if (ShapedType::isDynamic(value))
4848 mixedSizes.push_back(
4849 b.create<tensor::DimOp>(loc, source, index).getResult());
4850 else
4851 mixedSizes.push_back(b.getIndexAttr(value));
4852 }
4853 for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4854 int64_t dimPos = std::get<0>(it);
4855 OpFoldResult tileSize = std::get<1>(it);
4856 mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4857 }
4858 if (!outerDimsPerm.empty())
4859 applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4860
4861 mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4862 auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4863 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4864}
4865
4866PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4867 ArrayRef<int64_t> innerPermutation,
4868 ArrayRef<int64_t> outerPermutation) {
4869 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4870 *this, innerPermutation, outerPermutation);
4871 Value transposedDest =
4872 createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4873 metadata.innerDimsPos, metadata.outerDimsPerm);
4874 return b.create<PackOp>(loc, getSource(), transposedDest,
4875 metadata.innerDimsPos, metadata.innerTiles,
4876 getPaddingValue(), metadata.outerDimsPerm);
4877}
4878
4879/// Returns true if the tiles and the tiled dims are constant.
4880template <typename OpTy>
4881bool areTilesAndTiledDimsAllConstant(OpTy op) {
4882 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4883 "applies to only pack or unpack operations");
4884 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4885 ? op.getDestType()
4886 : op.getSourceType();
4887 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4888 for (auto [dimDest, tile] : llvm::zip(
4889 packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4890 std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4891 if (!constTileSize || ShapedType::isDynamic(dimDest))
4892 return false;
4893 }
4894 return true;
4895}
4896
4897Speculation::Speculatability PackOp::getSpeculatability() {
4898 if (getPaddingValue())
4899 return Speculation::Speculatable;
4900
4901  // The verifier already rejects operations where we can statically prove that
4902  // the tile sizes do not divide the dimension evenly; thus, it suffices to
4903  // check that the tiles and the tiled inner dimensions are constant.
4904 if (!areTilesAndTiledDimsAllConstant(*this))
4905 return Speculation::NotSpeculatable;
4906
4907 return Speculation::Speculatable;
4908}
4909
4910// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4911// dimensions for pack and unpack.
4912static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4913 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4914 return false;
4915 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4916 return true;
4917 // Outer dims permutation is optional.
4918 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
4919 // identity permutation.
4920 return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4921 isIdentityPermutation(unPackOp.getOuterDimsPerm());
4922}
4923
4924// Return true if pack and unpack have the same tiles.
4925// Same SSA values or same integer constants.
4926static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4927 auto packTiles = packOp.getMixedTiles();
4928 auto unPackTiles = unPackOp.getMixedTiles();
4929 if (packTiles.size() != unPackTiles.size())
4930 return false;
4931 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4932 if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4933 return false;
4934 }
4935 return true;
4936}
4937
4938/// Returns true if the pack op does not need a padding value.
4939static bool paddingIsNotNeeded(PackOp op) {
4940 auto srcType = op.getSourceType();
4941 if (llvm::any_of(op.getInnerDimsPos(),
4942 [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4943 return false;
4944 if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4945 return false;
4946 return !PackOp::requirePaddingValue(
4947 srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4948 op.getOuterDimsPerm(), op.getMixedTiles());
4949}
4950
4951/// Returns true if `srcShape` or `destShape` differs from the corresponding
4952/// shape in `packOp`, and populates both with the inferred static shapes.
4953static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4954 SmallVectorImpl<int64_t> &destShape) {
4955 bool changeNeeded = false;
4956 srcShape.assign(packOp.getSourceType().getShape().begin(),
4957 packOp.getSourceType().getShape().end());
4958 destShape.assign(packOp.getDestType().getShape().begin(),
4959 packOp.getDestType().getShape().end());
4960 llvm::SmallSetVector<int64_t, 4> innerDims;
4961 innerDims.insert_range(packOp.getInnerDimsPos());
4962 SmallVector<int64_t> inverseOuterDimsPerm;
4963 if (!packOp.getOuterDimsPerm().empty())
4964 inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4965 int srcRank = packOp.getSourceRank();
4966 for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4967 if (innerDims.contains(i))
4968 continue;
4969 int64_t srcPos = i;
4970 int64_t destPos = i;
4971 if (!inverseOuterDimsPerm.empty())
4972 destPos = inverseOuterDimsPerm[srcPos];
4973 if (ShapedType::isDynamic(srcShape[srcPos]) ==
4974 ShapedType::isDynamic(destShape[destPos])) {
4975 continue;
4976 }
4977 int64_t size = srcShape[srcPos];
4978 if (ShapedType::isDynamic(size))
4979 size = destShape[destPos];
4980 srcShape[srcPos] = size;
4981 destShape[destPos] = size;
4982 changeNeeded = true;
4983 }
4984 return changeNeeded;
4985}
4986
4987LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4988  // Fold a pack(unpack(x)) to x.
4989 if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4990 if (unPackOp.getSourceType() != packOp.getDestType())
4991 return failure();
4992 if (packOp.getPaddingValue() ||
4993 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4994 !haveSameTiles(packOp, unPackOp))
4995 return failure();
4996 rewriter.replaceOp(packOp, unPackOp.getSource());
4997 return success();
4998 }
4999
5000 // Fold optional PaddingValue operand away if padding is not needed.
5001 if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
5002 rewriter.startOpModification(packOp);
5003 packOp.getPaddingValueMutable().clear();
5004 rewriter.finalizeOpModification(packOp);
5005 return success();
5006 }
5007
5008  // Insert tensor.cast ops if static shape inference is available.
5009 SmallVector<int64_t> srcShape, destShape;
5010 if (inferStaticShape(packOp, srcShape, destShape)) {
5011 Location loc = packOp.getLoc();
5012 Value source = packOp.getSource();
5013 if (srcShape != packOp.getSourceType().getShape()) {
5014 auto newSrcType = packOp.getSourceType().clone(srcShape);
5015 source =
5016 rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
5017 }
5018 Value dest = packOp.getDest();
5019 RankedTensorType originalResultType = packOp.getDestType();
5020 bool needUpdateDestType = (destShape != originalResultType.getShape());
5021 if (needUpdateDestType) {
5022 auto newDestType = packOp.getDestType().clone(destShape);
5023 dest =
5024 rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
5025 }
5026 rewriter.modifyOpInPlace(packOp, [&] {
5027 packOp.getSourceMutable().assign(source);
5028 packOp.getDestMutable().assign(dest);
5029 packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
5030 });
5031 // Insert a cast if needed
5032 if (needUpdateDestType) {
5033 rewriter.setInsertionPointAfter(packOp);
5034 auto castOp =
5035 rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
5036 rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
5037 }
5038 return success();
5039 }
5040
5041 return failure();
5042}
5043
5044template <typename PackOrUnpackOp>
5045static bool isLikePadUnPad(PackOrUnpackOp packOp,
5046 RankedTensorType packedTensorType) {
5047 static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
5048 std::is_same<PackOrUnpackOp, UnPackOp>::value,
5049 "Function meant for pack/unpack");
5050 // This is a pad if packing only adds ones and we don't transpose dimensions.
5051
5052 // Check that we are not transposing any dimensions.
5053 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
5054 int64_t numPackedDims = innerDimsPos.size();
5055 auto orderedDims = llvm::to_vector<4>(Range: llvm::seq<int64_t>(Begin: 0, End: numPackedDims));
5056 if (orderedDims != innerDimsPos) {
5057 // Dimensions don't happen in order.
5058 return false;
5059 }
5060
5061 ArrayRef<int64_t> packedShape = packedTensorType.getShape();
5062 int64_t packedRank = packedTensorType.getRank();
5063  // At this point we know that we are taking numPackedDims outer
5064  // dimensions and pushing them all the way in as the innermost dimensions.
5065  // What's left on the outermost dimensions is, in this order:
5066  // - the factors of the packed dimensions, then
5067  // - the untouched dimensions.
5068  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
5069  // if all the dimensions that bubble outward are ones.
5070  // Therefore, check that all the dimensions except the numPackedDims innermost
5071  // ones are ones.
5072 return llvm::all_of(
5073 llvm::seq<int64_t>(Begin: 0, End: packedRank - numPackedDims),
5074 [&packedShape](int64_t i) { return packedShape[i] == 1; });
5075}
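
// For example, a pack producing tensor<1x1x8x16xf32> from tensor<8x16xf32>
// with inner_dims_pos = [0, 1] is pad-like: every outer dimension is 1. A
// packed type such as tensor<2x1x8x16xf32> is not, since a non-unit outer
// dimension remains.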
5076
5077bool PackOp::isLikePad() {
5078 auto packedTensorType =
5079 llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5080 return isLikePadUnPad(*this, packedTensorType);
5081}
5082
5083OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5084 std::optional<Attribute> paddingValue;
5085 if (auto pad = adaptor.getPaddingValue())
5086 paddingValue = pad;
5087 if (OpFoldResult reshapedSource = reshapeConstantSource(
5088 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5089 getDestType(), paddingValue))
5090 return reshapedSource;
5091 return {};
5092}
5093
5094/// Folds a tensor.cast op into a consuming PackOp if the
5095/// `tensor.cast` has a source that is more static than the consuming op.
5096///
5097/// Example:
5098/// ```mlir
5099/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5100/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5101/// ```
5102///
5103/// folds into:
5104///
5105/// ```mlir
5106/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5107/// ```
5108struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5109 using OpRewritePattern<PackOp>::OpRewritePattern;
5110
5111 LogicalResult matchAndRewrite(PackOp op,
5112 PatternRewriter &rewriter) const override {
5113 if (!tensor::hasFoldableTensorCastOperand(op: op))
5114 return failure();
5115
5116 SmallVector<Type> newResultTypes(op->getResultTypes());
5117 SmallVector<Value> newOperands =
5118 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5119
5120 // Get the updated mixed-tile-sizes attribute.
5121 SmallVector<OpFoldResult> newMixedTileSizes =
5122 getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5123
5124 // Clone op.
5125 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5126 // this point. However, in practice, we use them for things that we'd like
5127 // to preserve. Implement a better abstraction.
5128 PackOp newOp = rewriter.create<PackOp>(
5129 op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5130 newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5131 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5132
5133 // Replace op.
5134 Value oldResult = op.getResult();
5135 Value newResult = newOp.getResult();
5136 Value replacement = (newResult.getType() != oldResult.getType())
5137 ? rewriter.create<tensor::CastOp>(
5138 op->getLoc(), oldResult.getType(), newResult)
5139 : newResult;
5140
5141 rewriter.replaceOp(op, {replacement});
5142
5143 return success();
5144 }
5145};
5146
5147//===----------------------------------------------------------------------===//
5148// UnPackOp
5149//===----------------------------------------------------------------------===//
5150
5151void UnPackOp::getAsmResultNames(
5152 function_ref<void(Value, StringRef)> setNameFn) {
5153 setNameFn(getResult(), "unpack");
5154}
5155
5156LogicalResult
5157UnPackOp::reifyResultShapes(OpBuilder &builder,
5158 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5159 return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5160}
5161
5162DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5163 return getDimAndTileMappingImpl(*this);
5164}
5165
5166SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5167 return getMixedTilesImpl(*this);
5168}
5169
5170SmallVector<int64_t> UnPackOp::getStaticTiles() {
5171 return getStaticTilesImpl(*this);
5172}
5173
5174ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5175 ShapedType destType = getDestType();
5176 int64_t destRank = destType.getRank();
5177 return getSourceType().getShape().take_front(destRank);
5178}
5179
5180SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5181 auto innerDimsPos = getInnerDimsPos();
5182 auto packedShape = getSourceType().getShape();
5183 SmallVector<int64_t> res;
5184
5185 for (auto index : innerDimsPos)
5186 res.push_back(packedShape[index]);
5187
5188 return res;
5189}
5190
5191LogicalResult UnPackOp::verify() {
5192 return commonVerifierPackAndUnPackOp(*this);
5193}
5194
5195Speculation::Speculatability UnPackOp::getSpeculatability() {
5196 // See PackOp::getSpeculatability.
5197 if (!areTilesAndTiledDimsAllConstant(*this))
5198 return Speculation::NotSpeculatable;
5199
5200 return Speculation::Speculatable;
5201}
5202
5203void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5204 Value dest, ArrayRef<int64_t> innerDimsPos,
5205 ArrayRef<OpFoldResult> innerTiles,
5206 ArrayRef<int64_t> outerDimsPerm) {
5207 assert(innerDimsPos.size() == innerTiles.size() &&
5208 "number of tile sizes specified must match the specified number of "
5209 "original dimensions to be tiled");
5210 SmallVector<int64_t> staticTileSizes;
5211 SmallVector<Value> dynamicTileSizes;
5212 dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5213 build(builder, state, dest.getType(), source, dest,
5214 outerDimsPerm.empty() ? nullptr
5215 : builder.getDenseI64ArrayAttr(outerDimsPerm),
5216 builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5217 builder.getDenseI64ArrayAttr(staticTileSizes));
5218}
5219
5220Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5221 Value source,
5222 ArrayRef<OpFoldResult> innerTileSizes,
5223 ArrayRef<int64_t> innerDimsPos,
5224 ArrayRef<int64_t> outerDimsPerm) {
5225 AffineExpr sym0, sym1;
5226 bindSymbols(b.getContext(), sym0, sym1);
5227 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5228 return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5229 };
5230
5231 SmallVector<OpFoldResult> mixedSizes;
5232 auto srcType = llvm::cast<RankedTensorType>(source.getType());
5233 for (auto i :
5234 llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5235 if (srcType.isDynamicDim(i))
5236 mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
5237 else
5238 mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5239 }
5240 if (!outerDimsPerm.empty()) {
5241 applyPermutationToVector<OpFoldResult>(
5242 mixedSizes, invertPermutationVector(outerDimsPerm));
5243 }
5244
5245 for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5246 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5247
5248 auto elemType = srcType.getElementType();
5249 return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5250}
5251
5252UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5253 Value transposedSource,
5254 ArrayRef<int64_t> innerPermutation,
5255 ArrayRef<int64_t> outerPermutation) {
5256 PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5257 *this, innerPermutation, outerPermutation);
5258 return b.create<UnPackOp>(loc, transposedSource, getDest(),
5259 metadata.innerDimsPos, metadata.innerTiles,
5260 metadata.outerDimsPerm);
5261}
5262
5263/// Returns true if `srcShape` or `destShape` differs from the corresponding
5264/// shape in `op`, and populates both with the inferred static shapes.
5265static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5266 SmallVectorImpl<int64_t> &destShape) {
5267 bool changeNeeded = false;
5268 srcShape.assign(op.getSourceType().getShape().begin(),
5269 op.getSourceType().getShape().end());
5270 destShape.assign(op.getDestType().getShape().begin(),
5271 op.getDestType().getShape().end());
5272 llvm::SmallSetVector<int64_t, 4> innerDims;
5273 innerDims.insert_range(op.getInnerDimsPos());
5274 SmallVector<int64_t> inverseOuterDimsPerm;
5275 if (!op.getOuterDimsPerm().empty())
5276 inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5277 int destRank = op.getDestRank();
5278 for (auto i : llvm::seq<int64_t>(0, destRank)) {
5279 if (innerDims.contains(i))
5280 continue;
5281 int64_t srcPos = i;
5282 int64_t destPos = i;
5283 if (!inverseOuterDimsPerm.empty())
5284 srcPos = inverseOuterDimsPerm[destPos];
5285 if (ShapedType::isDynamic(srcShape[srcPos]) ==
5286 ShapedType::isDynamic(destShape[destPos])) {
5287 continue;
5288 }
5289 int64_t size = srcShape[srcPos];
5290 if (ShapedType::isDynamic(size))
5291 size = destShape[destPos];
5292 srcShape[srcPos] = size;
5293 destShape[destPos] = size;
5294 changeNeeded = true;
5295 }
5296 return changeNeeded;
5297}
5298
5299LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5300 PatternRewriter &rewriter) {
5301 /// unpack(pack(x)) -> x
5302 if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5303 if (packOp.getSourceType() != unPackOp.getDestType())
5304 return failure();
5305 if (packOp.getPaddingValue() ||
5306 !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5307 !haveSameTiles(packOp, unPackOp))
5308 return failure();
5309 rewriter.replaceOp(unPackOp, packOp.getSource());
5310 return success();
5311 }
5312 /// unpack(destinationStyleOp(x)) -> unpack(x)
5313 if (auto dstStyleOp =
5314 unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5315 auto destValue = cast<OpResult>(unPackOp.getDest());
5316 Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5317 rewriter.modifyOpInPlace(unPackOp,
5318 [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5319 return success();
5320 }
5321 /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5322 if (unPackOp->hasOneUse()) {
5323 auto extractSliceUser =
5324 dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5325 if (extractSliceUser &&
5326 areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
5327 areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
5328 extractSliceUser.getSourceType().getRank() ==
5329 extractSliceUser.getResultType().getRank()) {
5330 OpBuilder::InsertionGuard g(rewriter);
5331 rewriter.setInsertionPoint(unPackOp);
5332 auto newDest = rewriter.create<tensor::ExtractSliceOp>(
5333 unPackOp->getLoc(), unPackOp.getDest(),
5334 extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5335 extractSliceUser.getMixedStrides());
5336 rewriter.modifyOpInPlace(unPackOp, [&]() {
5337 unPackOp.setDpsInitOperand(0, newDest);
5338 unPackOp.getResult().setType(newDest.getType());
5339 });
5340 rewriter.replaceOp(extractSliceUser, unPackOp);
5341 return success();
5342 }
5343 }
5344
5345  // Insert tensor.cast ops if static shape inference is available.
5346 SmallVector<int64_t> srcShape, destShape;
5347 if (inferStaticShape(unPackOp, srcShape, destShape)) {
5348 Location loc = unPackOp.getLoc();
5349 Value source = unPackOp.getSource();
5350 if (srcShape != unPackOp.getSourceType().getShape()) {
5351 auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5352 source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5353 unPackOp.getSource());
5354 }
5355 Value dest = unPackOp.getDest();
5356 if (destShape != unPackOp.getDestType().getShape()) {
5357 auto newDestType = unPackOp.getDestType().clone(destShape);
5358 dest =
5359 rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5360 }
5361 Value newOp = rewriter.create<UnPackOp>(
5362 loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5363 unPackOp.getOuterDimsPerm());
5364 rewriter.replaceOpWithNewOp<tensor::CastOp>(
5365 unPackOp, unPackOp.getResult().getType(), newOp);
5366 return success();
5367 }
5368
5369 return failure();
5370}
5371
5372bool UnPackOp::isLikeUnPad() {
5373 RankedTensorType packedTensorType = getSourceType();
5374 return isLikePadUnPad(*this, packedTensorType);
5375}
5376
5377OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5378 if (OpFoldResult reshapedSource = reshapeConstantSource(
5379 llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5380 getResult().getType()))
5381 return reshapedSource;
5382 return {};
5383}
5384
5385/// Folds a tensor.cast op into a consuming UnPackOp if the
5386/// `tensor.cast` has a source that is more static than the consuming op.
5387///
5388/// Example:
5389/// ```mlir
5390/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5391/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5392/// ```
5393///
5394/// folds into:
5395///
5396/// ```mlir
5397/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5398/// ```
5399struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5400 using OpRewritePattern<UnPackOp>::OpRewritePattern;
5401
5402 LogicalResult matchAndRewrite(UnPackOp op,
5403 PatternRewriter &rewriter) const override {
5404 if (!tensor::hasFoldableTensorCastOperand(op: op))
5405 return failure();
5406
5407 SmallVector<Type> newResultTypes(op->getResultTypes());
5408 SmallVector<Value> newOperands =
5409 tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5410 Value sourceTensor = newOperands[0];
5411
5412 // Get the updated mixed-tile-sizes attribute.
5413 SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5414 rewriter, sourceTensor.getType(), op.getMixedTiles());
5415
5416 // Clone op.
5417 // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5418 // this point. However, in practice, we use them for things that we'd like
5419 // to preserve. Implement a better abstraction.
5420 UnPackOp newOp = rewriter.create<UnPackOp>(
5421 op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5422 newMixedTileSizes, op.getOuterDimsPerm());
5423 newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5424
5425 // Replace op.
5426 Value oldResult = op.getResult();
5427 Value newResult = newOp.getResult();
5428 Value replacement = (newResult.getType() != oldResult.getType())
5429 ? rewriter.create<tensor::CastOp>(
5430 op->getLoc(), oldResult.getType(), newResult)
5431 : newResult;
5432
5433 rewriter.replaceOp(op, {replacement});
5434
5435 return success();
5436 }
5437};
5438
5439//===----------------------------------------------------------------------===//
5440// BatchReduceMatmulOp
5441//===----------------------------------------------------------------------===//
5442SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
5443 return SmallVector<utils::IteratorType>{
5444 utils::IteratorType::reduction, utils::IteratorType::parallel,
5445 utils::IteratorType::parallel, utils::IteratorType::reduction};
5446}
5447
5448SmallVector<AffineMap>
5449BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5450 AffineExpr d0, d1, d2, d3;
5451 SmallVector<AffineMap> indexingMaps;
5452 bindDims(context, d0, d1, d2, d3);
5453 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
5454 indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
5455 indexingMaps.push_back(AffineMap::get(4, 0, {d1, d2}, context));
5456 return indexingMaps;
5457}
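
// Written out, the default maps built above are:
//   affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>  // LHS
//   affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>  // RHS
//   affine_map<(d0, d1, d2, d3) -> (d1, d2)>      // output (batch dim reduced)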
5458
5459unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5460
5461std::string BatchReduceMatmulOp::getLibraryCallName() {
5462 return generateLibraryCallName(getOperation());
5463}
5464
5465/// Check if the op has broadcast and/or transpose semantics. Returns true if
5466/// the user-defined indexing maps are not equal to the default maps.
5467bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5468 SmallVector<AffineMap, 3> defaultMaps =
5469 getDefaultIndexingMaps(this->getContext());
5470 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
5471 return defaultMaps != explicitMaps;
5472}
5473
5474/// Returns true if the given `bcastMap` is a valid broadcast map. A valid
5475/// broadcast map must include the K dimension.
5476/// TODO: Strict inclusion of the K dimension in the broadcast map is not
5477/// necessary for both input matrices simultaneously. We can relax this
5478/// condition to require the K dimension in only one input matrix map and
5479/// infer the K dimension for the other input matrix map from the one that
5480/// already has it.
5481bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
5482 bool isLHS) {
5483 assert(bcastMap.getNumResults() < 3 &&
5484 "Expected less than 3 result dim expr.");
5485 bool isValid = false;
5486 enum Indices { batchPos, mPos, nPos, kPos };
5487 if (bcastMap.getNumResults() == 1) {
5488 AffineExpr expr = bcastMap.getResult(0);
5489 isValid = expr.isFunctionOfDim(kPos);
5490 } else if (bcastMap.getNumResults() == 2) {
5491 AffineExpr expr0 = bcastMap.getResult(0);
5492 AffineExpr expr1 = bcastMap.getResult(1);
5493 isValid =
5494 isLHS ? ((expr0.isFunctionOfDim(batchPos) ||
5495 expr0.isFunctionOfDim(mPos)) &&
5496 expr1.isFunctionOfDim(kPos))
5497 : ((expr0.isFunctionOfDim(batchPos) &&
5498 expr1.isFunctionOfDim(kPos)) ||
5499 (expr0.isFunctionOfDim(kPos) && expr1.isFunctionOfDim(nPos)));
5500 }
5501 return isValid;
5502}
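
// For example (using the Indices enum above, i.e. d0 = batch, d1 = m, d2 = n,
// d3 = k):
//   (d0, d1, d2, d3) -> (d3)       valid for either operand (K only)
//   (d0, d1, d2, d3) -> (d1, d3)   valid LHS broadcast map (m, k)
//   (d0, d1, d2, d3) -> (d3, d2)   valid RHS broadcast map (k, n)
//   (d0, d1, d2, d3) -> (d1, d2)   invalid: the K dimension is missing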
5503
5504void BatchReduceMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
5505 ArrayRef<NamedAttribute> attrs) {
5506 assert(block.getNumArguments() == 3 &&
5507 "BatchReduceMatmulOp regionBuilder expects 3 (>=0) args");
5508 RegionBuilderHelper helper(b, block);
5509 SmallVector<Value> yields;
5510
5511 auto toType = block.getArgument(2).getType();
5512 Value castValA =
5513 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(0));
5514 Value castValB =
5515 helper.buildTypeFn(TypeFn::cast_signed, toType, block.getArgument(1));
5516 Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
5517 Value addVal =
5518 helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
5519 yields.push_back(addVal);
5520 helper.yieldOutputs(yields);
5521}
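
// For f32 operands (a sketch; the cast_signed type casts are no-ops when the
// element types already match the accumulator type), the body built above is
// roughly:
//
//   ^bb0(%a: f32, %b: f32, %acc: f32):
//     %0 = arith.mulf %a, %b : f32
//     %1 = arith.addf %acc, %0 : f32
//     linalg.yield %1 : f32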
5522
5523ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
5524 OperationState &result) {
5525 SmallVector<Attribute, 3> indexingMapsAttr;
5526 Attribute mapAttr;
5527 if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
5528 if (parser.parseEqual())
5529 return failure();
5530 if (parser.parseLSquare())
5531 return failure();
5532
5533 do {
5534 if (parser.parseAttribute(mapAttr))
5535 return failure();
5536 if (!isa<AffineMapAttr>(mapAttr)) {
5537 return parser.emitError(parser.getCurrentLocation(),
5538 "expected affine map attribute");
5539 }
5540 indexingMapsAttr.push_back(mapAttr);
5541
5542 if (parser.parseOptionalComma())
5543 break;
5544 } while (true);
5545
5546 if (parser.parseRSquare())
5547 return failure();
5548 }
5549 // Initialize indexingMaps, if not supplied explicitly.
5550 if (indexingMapsAttr.empty()) {
5551 indexingMapsAttr = llvm::map_to_vector(
5552 BatchReduceMatmulOp::getDefaultIndexingMaps(parser.getContext()),
5553 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5554 }
5555 result.addAttribute("indexing_maps",
5556 parser.getBuilder().getArrayAttr(indexingMapsAttr));
5557 return ::parseNamedStructuredOp(parser, result,
5558 BatchReduceMatmulOp::getNumRegionArgs(),
5559 BatchReduceMatmulOp::getRegionBuilder());
5560}
5561
5562void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
5563 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
5564 BatchReduceMatmulOp::getDefaultIndexingMaps(getContext()),
5565 [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
5566
5567 if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
5568 p << " indexing_maps = [";
5569 llvm::interleaveComma(getIndexingMaps(), p,
5570 [&](Attribute attr) { p.printAttribute(attr); });
5571 p << "]";
5572 }
5573
5574 SmallVector<StringRef, 3> elidedAttrs = {
5575 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
5576 ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
5577 elidedAttrs);
5578}
5579
5580/// Verify the user defined indexing maps.
5581LogicalResult BatchReduceMatmulOp::verify() {
5582 // Verification of pure batch_reduce_matmul is handled by
5583 // verifyStructuredOpInterface().
5584 if (!hasUserDefinedMaps())
5585 return success();
5586
5587 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5588 if (failed(verifyExtendedBatchVariantMatmulSemantic(*this, opIndex)))
5589 return failure();
5590 }
5591 return success();
5592}
5593LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
5594 SmallVectorImpl<OpFoldResult> &) {
5595 return memref::foldMemRefCast(*this);
5596}
5597void BatchReduceMatmulOp::getEffects(
5598 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
5599 &effects) {
5600 if (hasPureTensorSemantics())
5601 return;
5602 getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
5603}
5604
5605Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
5606 return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
5607}
5608
5609} // namespace linalg
5610} // namespace mlir
5611
5612//===----------------------------------------------------------------------===//
5613// LinalgDialect
5614//===----------------------------------------------------------------------===//
5615
5616void LinalgDialect::getCanonicalizationPatterns(
5617 RewritePatternSet &results) const {
5618 results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5619 FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5620}
5621
5622Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5623 Attribute value, Type type,
5624 Location loc) {
5625 return arith::ConstantOp::materialize(builder, value, type, loc);
5626}
5627