1//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Linalg operations.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Linalg/IR/Linalg.h"
14
15#include "mlir/AsmParser/AsmParser.h"
16#include "mlir/Dialect/Affine/IR/AffineOps.h"
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Dialect/Arith/Utils/Utils.h"
19#include "mlir/Dialect/Complex/IR/Complex.h"
20#include "mlir/Dialect/Math/IR/Math.h"
21#include "mlir/Dialect/MemRef/IR/MemRef.h"
22#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
23#include "mlir/Dialect/Tensor/IR/Tensor.h"
24#include "mlir/Dialect/Utils/IndexingUtils.h"
25#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
26#include "mlir/Dialect/Utils/StaticValueUtils.h"
27#include "mlir/IR/AffineMap.h"
28#include "mlir/IR/Attributes.h"
29#include "mlir/IR/Builders.h"
30#include "mlir/IR/BuiltinAttributes.h"
31#include "mlir/IR/BuiltinTypeInterfaces.h"
32#include "mlir/IR/OpImplementation.h"
33#include "mlir/IR/OperationSupport.h"
34#include "mlir/IR/PatternMatch.h"
35#include "mlir/Interfaces/InferTypeOpInterface.h"
36#include "mlir/Interfaces/SideEffectInterfaces.h"
37
38#include "llvm/ADT/DenseMap.h"
39#include "llvm/ADT/STLExtras.h"
40#include "llvm/ADT/SetOperations.h"
41#include "llvm/ADT/SmallVector.h"
42#include "llvm/ADT/StringSet.h"
43#include "llvm/ADT/TypeSwitch.h"
44#include "llvm/Support/FormatVariadic.h"
45#include "llvm/Support/InterleavedRange.h"
46#include "llvm/Support/LogicalResult.h"
47#include "llvm/Support/MathExtras.h"
48#include "llvm/Support/raw_ostream.h"
49#include <cassert>
50#include <optional>
51
52using namespace mlir;
53using namespace mlir::linalg;
54
55/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
56static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
57 int64_t dim) {
58 auto type = cast<ShapedType>(Val: v.getType());
59 if (!type.isDynamicDim(idx: dim))
60 return builder.getIndexAttr(value: type.getDimSize(idx: dim));
61
62 return getAsOpFoldResult(
63 val: TypeSwitch<Type, Value>(v.getType())
64 .Case<RankedTensorType>(caseFn: [&](RankedTensorType t) -> Value {
65 return builder.create<tensor::DimOp>(location: loc, args&: v, args&: dim);
66 })
67 .Case<MemRefType>(caseFn: [&](MemRefType t) -> Value {
68 return builder.create<memref::DimOp>(location: loc, args&: v, args&: dim);
69 }));
70}
71
72/// Returns a memref.subview or a tensor.extract_slice based on the type of the
73/// `source`.
74static Operation *getSlice(OpBuilder &b, Location loc, Value source,
75 ArrayRef<OpFoldResult> offsets,
76 ArrayRef<OpFoldResult> sizes,
77 ArrayRef<OpFoldResult> strides) {
78 return TypeSwitch<Type, Operation *>(source.getType())
79 .Case<RankedTensorType>(caseFn: [&](RankedTensorType t) -> Operation * {
80 return b.create<tensor::ExtractSliceOp>(location: loc, args&: source, args&: offsets, args&: sizes,
81 args&: strides);
82 })
83 .Case<MemRefType>(caseFn: [&](MemRefType type) -> Operation * {
84 return b.create<memref::SubViewOp>(location: loc, args&: source, args&: offsets, args&: sizes,
85 args&: strides);
86 })
87 .Default(defaultFn: [&](Type t) -> Operation * { return nullptr; });
88}
89
90//===----------------------------------------------------------------------===//
91// Helper functions
92//===----------------------------------------------------------------------===//
93
94Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
95 int64_t dim) {
96 if (llvm::isa<UnrankedMemRefType, MemRefType>(Val: source.getType()))
97 return b.createOrFold<memref::DimOp>(location: loc, args&: source, args&: dim);
98 if (llvm::isa<UnrankedTensorType, RankedTensorType>(Val: source.getType()))
99 return b.createOrFold<tensor::DimOp>(location: loc, args&: source, args&: dim);
100 llvm_unreachable("Expected MemRefType or TensorType");
101}
102
103OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
104 int64_t dim) {
105 auto shapedType = llvm::cast<ShapedType>(Val: source.getType());
106 if (!shapedType.hasRank() || shapedType.isDynamicDim(idx: dim))
107 return createOrFoldDimOp(b, loc, source, dim);
108 return b.getIndexAttr(value: shapedType.getDimSize(idx: dim));
109}
110
111//===----------------------------------------------------------------------===//
112// Support for named Linalg ops defined in ods-gen.
113//===----------------------------------------------------------------------===//
114
115using RegionBuilderFn = llvm::function_ref<void(
116 ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>,
117 function_ref<InFlightDiagnostic()>)>;
118
/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   function_ref<InFlightDiagnostic()> emitError,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  // Block arguments are the *element* types of shaped (memref/tensor)
  // operands; scalar operand types are taken as-is. Inputs precede outputs.
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          Elt: isa<MemRefType, RankedTensorType>(Val: t) ? getElementTypeOrSelf(type: t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(Elt: opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(parent: &region, /*insertPt=*/{}, argTypes, locs: argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  // Delegate the construction of the payload ops and terminator to the per-op
  // region builder; `emitError` (may be null) reports unsupported payloads.
  regionBuilder(b, *body, attrs, emitError);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}
154
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed: every ranked-tensor output becomes a
  // result (memref outputs produce no results).
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(u: TypeRange());
  if (!resultTensorTypes)
    copy_if(Range: outputs.getTypes(), Out: std::back_inserter(x&: derivedResultTypes),
            P: llvm::IsaPred<RankedTensorType>);

  state.addOperands(newOperands: inputs);
  state.addOperands(newOperands: outputs);
  state.addTypes(newTypes: derivedResultTypes);

  state.addAttributes(newAttributes: attributes);
  // Record the input/output operand split for the ODS-generated accessors.
  state.addAttribute(
      name: "operandSegmentSizes",
      attr: b.getDenseI32ArrayAttr(values: {static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(opBuilder&: b, region, inputTypes: TypeRange(inputs), outputTypes: TypeRange(outputs),
                         attrs: state.attributes.getAttrs(), /*emitError=*/{},
                         regionBuilder);
}
187
188static void buildMatmulOp(OpBuilder &b, OperationState &state,
189 std::optional<TypeRange> resultTensorTypes,
190 ValueRange inputs, ValueRange outputs,
191 ArrayRef<NamedAttribute> attributes,
192 RegionBuilderFn regionBuilder,
193 ArrayRef<AffineMap> indexingMaps) {
194 // Initialize indexingMaps attribute, for MatmulOp.
195 SmallVector<Attribute, 3> indexingMapsAttrVal;
196 indexingMapsAttrVal = llvm::map_to_vector(
197 C: MatmulOp::getDefaultIndexingMaps(context: b.getContext()),
198 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
199 state.addAttribute(name: "indexing_maps", attr: b.getArrayAttr(value: indexingMapsAttrVal));
200 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
201 attributes, regionBuilder);
202}
203
204static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
205 std::optional<TypeRange> resultTensorTypes,
206 ValueRange inputs, ValueRange outputs,
207 ArrayRef<NamedAttribute> attributes,
208 RegionBuilderFn regionBuilder,
209 ArrayRef<AffineMap> indexingMaps) {
210 // Initialize indexingMaps attribute, for BatchMatmulOp.
211 SmallVector<Attribute, 4> indexingMapsAttrVal;
212 indexingMapsAttrVal =
213 llvm::map_to_vector(C&: indexingMaps, F: [](AffineMap map) -> Attribute {
214 return AffineMapAttr::get(value: map);
215 });
216 state.addAttribute(name: "indexing_maps", attr: b.getArrayAttr(value: indexingMapsAttrVal));
217 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
218 attributes, regionBuilder);
219}
220
221static void buildBatchReduceMatmulOp(OpBuilder &b, OperationState &state,
222 std::optional<TypeRange> resultTensorTypes,
223 ValueRange inputs, ValueRange outputs,
224 ArrayRef<NamedAttribute> attributes,
225 RegionBuilderFn regionBuilder,
226 ArrayRef<AffineMap> indexingMaps) {
227 // Initialize indexingMaps attribute, for BatchReduceMatmulOp.
228 SmallVector<Attribute, 4> indexingMapsAttrVal;
229 indexingMapsAttrVal =
230 llvm::map_to_vector(C&: indexingMaps, F: [](AffineMap map) -> Attribute {
231 return AffineMapAttr::get(value: map);
232 });
233 state.addAttribute(name: "indexing_maps", attr: b.getArrayAttr(value: indexingMapsAttrVal));
234 return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
235 attributes, regionBuilder);
236}
237
/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
///
/// Parses, in order: an optional `<...>` properties attribute, an optional
/// attribute dictionary, an optional `ins(...)` clause, and an optional
/// `outs(...)` clause. Parsed operand types are appended to `inputTypes` /
/// `outputTypes` and the resolved operands to `result.operands`.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  // Optional `<dict>` properties attribute.
  if (succeeded(Result: parser.parseOptionalLess())) {
    if (parser.parseAttribute(result&: result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  // Optional `ins(%a, %b : typeA, typeB)` clause.
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(result&: inputsOperands) ||
        parser.parseColonTypeList(result&: inputTypes) || parser.parseRParen())
      return failure();
  }

  // Optional `outs(%c : typeC)` clause.
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(result&: outputsOperands) ||
        parser.parseColonTypeList(result&: outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(operands&: inputsOperands, types&: inputTypes, loc: inputsOperandsLoc,
                             result&: result.operands) ||
      parser.resolveOperands(operands&: outputsOperands, types&: outputTypes, loc: outputsOperandsLoc,
                             result&: result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible with
    // operation syntax that mix the inherent attributes and the discardable
    // ones in the same dictionary. If the properties are used, we append the
    // operandSegmentSizes there directly. Otherwise we append it to the
    // discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(Val&: result.propertiesAttr);
      attrs.append(name: "operandSegmentSizes",
                   attr: parser.getBuilder().getDenseI32ArrayAttr(
                       values: {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(context: parser.getContext());
    } else {
      result.addAttribute(name: "operandSegmentSizes",
                          attr: parser.getBuilder().getDenseI32ArrayAttr(
                              values: {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    // Without explicit properties, inherent attributes were parsed into the
    // discardable dictionary above; verify them eagerly so errors point at
    // the attribute dictionary location.
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(Result: info->verifyInherentAttrs(attributes&: result.attributes, emitError: [&]() {
            return parser.emitError(loc: attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}
314
315static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
316 ValueRange outputs) {
317 if (!inputs.empty())
318 p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
319 if (!outputs.empty())
320 p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
321}
322
323//===----------------------------------------------------------------------===//
324// Specific parsing and printing for named structured ops created by ods-gen.
325//===----------------------------------------------------------------------===//
326
/// Reconstructs the (elided) region of a named structured op from the parsed
/// operand types using `regionBuilder`. Fails if the operand count does not
/// match the ods-gen'd region arity, or if the region builder reports an
/// unsupported payload via the error callback.
static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder, SMLoc loc) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        loc: parser.getCurrentLocation(),
        message: llvm::formatv(Fmt: "[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      Vals&: numRegionArgs, Vals: inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  ParseResult result = success();
  // Any diagnostic emitted by the region builder flips `result` to failure.
  fillStructuredOpRegion(
      opBuilder, region, inputTypes, outputTypes, attrs,
      emitError: [&]() {
        result = failure();
        return parser.emitError(loc);
      },
      regionBuilder);
  return result;
}
350
351static ParseResult
352parseNamedStructuredOpResults(OpAsmParser &parser,
353 SmallVectorImpl<Type> &resultTypes) {
354 if (parser.parseOptionalArrowTypeList(result&: resultTypes))
355 return failure();
356 return success();
357}
358
/// Parses a named structured op: common parts (properties, attr-dict, ins/
/// outs), an optional trailing attr-dict, optional result types, and finally
/// rebuilds the elided region from the parsed operand types.
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  SMLoc loc = parser.getCurrentLocation();
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, resultTypes&: outputTensorsTypes))
    return failure();
  result.addTypes(newTypes: outputTensorsTypes);

  // The region is not spelled in the assembly; regenerate it from the
  // operand types via the ods-gen'd region builder.
  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, region&: *region, numRegionArgs, inputTypes,
                                   outputTypes, attrs: result.attributes.getAttrs(),
                                   regionBuilder, loc))
    return failure();
  result.addRegion(region: std::move(region));

  return success();
}
389
390static void printNamedStructuredOpResults(OpAsmPrinter &p,
391 TypeRange resultTypes) {
392 if (resultTypes.empty())
393 return;
394 p.printOptionalArrowTypeList(types&: resultTypes);
395}
396
397static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
398 ValueRange inputs, ValueRange outputs,
399 ArrayRef<StringRef> elidedAttrs = {}) {
400 p.printOptionalAttrDict(attrs: op->getAttrs(), elidedAttrs);
401
402 // Printing is shared with generic ops, except for the region and
403 // attributes.
404 printCommonStructuredOpParts(p, inputs, outputs);
405
406 // Results printing.
407 printNamedStructuredOpResults(p, resultTypes: op->getResultTypes());
408
409 // Region is elided.
410}
411
412//===----------------------------------------------------------------------===//
413// Region builder helper.
414// TODO: Move this to a utility library.
415// The public methods on this class are referenced directly from generated code.
416// Helper build the unary, binary, and type conversion functions defined by the
417// DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
418// class.
419//
420// Implementations of the math functions must be polymorphic over numeric types,
421// internally performing necessary casts. If the function application makes no
422// sense, then the only recourse is to assert and return nullptr. This can be
423// extended later if it becomes possible to fail construction of the region. The
424// invariant should be enforced at a higher level.
425//
426// TODO: These helpers are currently type polymorphic over the class of integer
427// and floating point types, but they will not internally cast within bit
428// widths of a class (mixed precision such as i8->i32) or across classes
429// (i.e. mixed float and integer). Many such combinations are ambiguous or need
430// to be handled with care and work is being considered to extend the op
431// language to make such cases explicit. In the mean-time, violating this will
432// fail verification, which is deemed acceptable.
433//===----------------------------------------------------------------------===//
434
435namespace {
436
/// Helper used by generated code to build the scalar payload of structured
/// ops (see the file-level comment above). All builders insert at the end of
/// `block` and restore the caller's insertion point on return. When an
/// `emitError` callback is supplied, unsupported cases emit a diagnostic and
/// return nullptr instead of asserting.
class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg,
                     function_ref<InFlightDiagnostic()> emitError = {}) {
    // All unary functions below are only defined on floating-point operands.
    if (!isFloatingPoint(value: arg)) {
      if (emitError) {
        emitError() << "unsupported non numeric type";
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::reciprocal: {
      // reciprocal(x) is lowered as 1.0 / x.
      Attribute oneAttr = builder.getOneAttr(type: arg.getType());
      auto one = builder.create<arith::ConstantOp>(location: arg.getLoc(),
                                                   args: ::cast<TypedAttr>(Val&: oneAttr));
      return builder.create<arith::DivFOp>(location: arg.getLoc(), args&: one, args&: arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(location: arg.getLoc(), args&: arg, args&: arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(location: arg.getLoc(), args&: arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(location: arg.getLoc(), args&: arg);
    }
    if (emitError) {
      emitError() << "unsupported unary function";
      return nullptr;
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  // If emitError is provided, an error will be emitted if the operation is not
  // supported and a nullptr will be returned, otherwise an assertion will be
  // raised.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1,
                      function_ref<InFlightDiagnostic()> emitError = {}) {
    bool allComplex = isComplex(value: arg0) && isComplex(value: arg1);
    bool allFloatingPoint = isFloatingPoint(value: arg0) && isFloatingPoint(value: arg1);
    bool allInteger = isInteger(value: arg0) && isInteger(value: arg1);
    // i1 operands get boolean lowerings (or/and for add/mul) and reject
    // sub/div below.
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger) {
      if (emitError) {
        emitError()
            << "Cannot build binary Linalg operation: expects allComplex, "
               "allFloatingPoint, or allInteger, got "
            << arg0.getType() << " and " << arg1.getType();
        return nullptr;
      }
      llvm_unreachable("unsupported non numeric type");
    }
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::AddIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: sub with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: sub with bools");
      }
      return builder.create<arith::SubIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::MulIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      if (allBool) {
        if (emitError) {
          emitError() << "unsupported operation: div with bools";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: div with bools");
      }
      return builder.create<arith::DivSIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool) {
        if (emitError) {
          emitError() << "unsupported operation: unsigned div not on uint";
          return nullptr;
        }
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      }
      return builder.create<arith::DivUIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::MaxSIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::MinSIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::MaxUIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
      return builder.create<arith::MinUIOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(location: arg0.getLoc(), args&: arg0, args&: arg1);
    }
    if (emitError) {
      emitError() << "unsupported binary function";
      return nullptr;
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  // `select` expects an i1 condition (`arg0`) and numeric branch values.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1, Value arg2,
                       function_ref<InFlightDiagnostic()> emitError = {}) {
    bool headBool =
        isInteger(value: arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(value: arg0) && isFloatingPoint(value: arg1) && isFloatingPoint(value: arg2);
    bool tailInteger = isInteger(value: arg0) && isInteger(value: arg1) && isInteger(value: arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(location: arg0.getLoc(), args&: arg0, args&: arg1, args&: arg2);
    }
    if (emitError) {
      emitError() << "unsupported ternary function";
      return nullptr;
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand,
                    function_ref<InFlightDiagnostic()> emitError = {}) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, isUnsignedCast: false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, isUnsignedCast: true);
    }
    if (emitError) {
      emitError() << "unsupported type conversion function";
      return nullptr;
    }
    llvm_unreachable("unsupported type conversion function");
  }

  // Terminates the block with a linalg.yield of `values`.
  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(location: loc, args&: values);
  }

  // Materializes an arith.constant from the textual attribute `value`
  // (parsed with the MLIR attribute parser).
  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(attrStr: value, context: builder.getContext());
    return builder.create<arith::ConstantOp>(location: loc, args: ::cast<TypedAttr>(Val&: valueAttr));
  }

  // Materializes a linalg.index for loop dimension `dim`.
  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(location: builder.getUnknownLoc(), args&: dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(context: builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(context: builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(context: builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    // Prefer a real location over UnknownLoc when one is reachable from the
    // defining op or the enclosing region's parent op.
    if (isa<UnknownLoc>(Val: loc)) {
      if (operand.getDefiningOp())
        loc = operand.getDefiningOp()->getLoc();
      else if (operand.getParentBlock() &&
               operand.getParentBlock()->getParentOp())
        loc = operand.getParentBlock()->getParentOp()->getLoc();
    }
    return convertScalarToDtype(b&: builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(Val: value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(Val: value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(Val: value.getType());
  }

  OpBuilder &builder;
  Block &block;
};
698
699} // namespace
700
701//===----------------------------------------------------------------------===//
702// CopyOp
703//===----------------------------------------------------------------------===//
704
705namespace {
706
707struct EraseSelfCopy : OpRewritePattern<CopyOp> {
708 using OpRewritePattern<CopyOp>::OpRewritePattern;
709 LogicalResult matchAndRewrite(CopyOp copyOp,
710 PatternRewriter &rewriter) const override {
711 if (copyOp.getInputs() != copyOp.getOutputs())
712 return rewriter.notifyMatchFailure(arg&: copyOp, msg: "not a self copy");
713 if (copyOp.hasPureBufferSemantics())
714 rewriter.eraseOp(op: copyOp);
715 else
716 rewriter.replaceOp(op: copyOp, newValues: copyOp.getInputs());
717
718 return success();
719 }
720};
721
722} // namespace
723
724void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
725 MLIRContext *context) {
726 results.add<EraseSelfCopy>(arg&: context);
727}
728
729//===----------------------------------------------------------------------===//
730// FillOp
731//===----------------------------------------------------------------------===//
732
733namespace {
734
/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op. The reshape is re-applied to
/// the fill's init operand and the fill is re-created on top of it.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    // expand_shape additionally carries the (static and dynamic) output
    // shape, so its builder takes two extra operands.
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {

      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    // Fill the reshaped init with the original fill value.
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
766
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    // Materialize the pad's result shape so a correctly sized empty tensor
    // can be created for the replacement fill.
    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(Result: reifyResultShapes(b&: rewriter, op: padOp, reifiedReturnShapes&: reifiedShape)))
      return rewriter.notifyMatchFailure(
          arg&: padOp, msg: "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        location: padOp.getLoc(), args&: reifiedShape.front(),
        args: padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(location: fillOp.getLoc(), args: ValueRange{padValue},
                            args: ValueRange{emptyTensor})
            .getResult(i: 0);
    // The reified shape may yield a type that differs from the pad's result
    // type (e.g. in static-ness); reconcile with a cast when needed.
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          location: fillOp.getLoc(), args: padOp.getResultType(), args&: replacement);
    }
    rewriter.replaceOp(op: padOp, newValues: replacement);
    return success();
  }
};
805
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    // Only handle rank-preserving insertions.
    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the range of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjoint. So just skip it.
        if (insertOp.isDynamicOffset(idx: i) || insertOp.isDynamicSize(idx: i) ||
            insertOp.isDynamicStride(idx: i) || prevOp.isDynamicOffset(idx: i) ||
            prevOp.isDynamicSize(idx: i) || prevOp.isDynamicStride(idx: i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(idx: i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(idx: i) - 1) *
                                          prevOp.getStaticStride(idx: i);
        int64_t nextStart = insertOp.getStaticOffset(idx: i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(idx: i) - 1) *
                                          insertOp.getStaticStride(idx: i);
        // Disjoint along any single dimension implies disjoint ranges.
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      // Possibly overlapping with this insert_slice: stop walking here. The
      // fill check below will then fail (firstDest is an insert_slice).
      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(ctx: context, exprs&: sym0, exprs&: sym1);
    auto addMap = AffineMap::get(dimCount: 0, symbolCount: 2, results: {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. It should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(t&: lowPads, u&: oldOffsets)) {
      newOffsets.push_back(Elt: affine::makeComposedFoldedAffineApply(
          b&: rewriter, loc, map: addMap, operands: {std::get<0>(t: p), std::get<1>(t: p)}));
    }

    // The new insert covers only the unpadded source extents.
    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(idx: i)) {
        newSizes.push_back(
            Elt: rewriter.create<tensor::DimOp>(location: loc, args: srcPadOp.getSource(), args&: i)
                .getResult());
      } else {
        newSizes.push_back(Elt: rewriter.getIndexAttr(value: srcPadType.getDimSize(idx: i)));
      }
    }

    // Insert the unpadded source directly: the region that the padding would
    // have covered already holds the fill value in the destination.
    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        op: insertOp, args: srcPadOp.getSource(), args: insertOp.getDest(), args&: newOffsets,
        args&: newSizes, args: insertOp.getMixedStrides());
    return success();
  }
};
905
906/// Fold tensor.extract(linalg.fill(<input>)) into <input>
907struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
908public:
909 using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
910
911 LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
912 PatternRewriter &rewriter) const override {
913 // See if tensor input of tensor.extract op is the result of a linalg.fill
914 // op.
915 auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
916 if (!fillOp)
917 return failure();
918
919 // Get scalar input operand of linalg.fill op.
920 Value extractedScalar = fillOp.getInputs()[0];
921
922 // Replace tensor.extract op with scalar value used to fill the tensor.
923 rewriter.replaceOp(op: extractOp, newValues: extractedScalar);
924 return success();
925 }
926};
927
928/// Folds pack(fill) into a single fill op if
929/// 1. The pack op does not have padding value, or
930/// 2. The filled value and padding value are the same.
931static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
932 linalg::PackOp packOp) {
933 auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
934 if (!fillOp)
935 return failure();
936
937 if (auto paddingValue = packOp.getPaddingValue())
938 if (!isEqualConstantIntOrValue(ofr1: paddingValue, ofr2: fillOp.value()))
939 return failure();
940
941 Value packOpDest = packOp.getDest();
942 if (!packOpDest.hasOneUse())
943 return failure();
944
945 return rewriter.create<linalg::FillOp>(location: packOp.getLoc(), args: fillOp.getInputs(),
946 args: packOp.getDest());
947}
948
949/// Wrapper pattern that applies foldFillPackIntoFillOp method.
950struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
951public:
952 FoldFillWithPack(MLIRContext *context)
953 : OpRewritePattern<linalg::PackOp>(context) {}
954
955 LogicalResult matchAndRewrite(linalg::PackOp packOp,
956 PatternRewriter &rewriter) const override {
957 auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
958 if (failed(Result: fillOp))
959 return failure();
960 rewriter.replaceOp(op: packOp, newValues: fillOp.value().result());
961 return success();
962 }
963};
964
965/// Fold fill with copy.
966struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
967 using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;
968
969 LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
970 PatternRewriter &rewriter) const override {
971 if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
972 rewriter.replaceOpWithNewOp<FillOp>(op: copyOp, args: copyOp.getResultTypes(),
973 args: fillOp.getInputs(),
974 args: copyOp.getOutputs());
975 return success();
976 }
977 if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
978 rewriter.replaceOpWithNewOp<linalg::CopyOp>(op: copyOp, args: copyOp.getInputs(),
979 args: fillOp.getOutputs());
980 return success();
981 }
982 return failure();
983 }
984};
985
986/// Fold fill with transpose.
987struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
988 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
989
990 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
991 PatternRewriter &rewriter) const override {
992 if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
993 rewriter.replaceOpWithNewOp<FillOp>(
994 op: transposeOp, args: transposeOp.getResultTypes(), args: fillOp.getInputs(),
995 args: transposeOp.getDpsInitOperand(i: 0)->get());
996 return success();
997 }
998 return failure();
999 }
1000};
1001
/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(val: firstFillOp.getDpsInputOperand(i: 0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(Elt: firstFillOp.getDpsInitOperand(i: 0)->get());

    // NOTE: as a side effect this predicate appends each compatible fill's
    // init to `allOuts`; on a mismatch the pattern bails out, so partial
    // contents are discarded along with the match.
    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(val: fillOp.getDpsInputOperand(i: 0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(Elt: fillOp.getDpsInitOperand(i: 0)->get());
      return true;
    };
    if (!llvm::all_of(Range: concatOperands.drop_front(),
                      P: isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          arg&: concatOp, msg: "not all operands are defined by a compatible fill op");
    }

    // Concatenate the fills' init tensors and fill the result once.
    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        location: concatOp.getLoc(), args: concatOp.getDim(), args&: allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        op: concatOp, args: firstFillOp.getDpsInputOperand(i: 0)->get(), args&: outsConcat);
    return success();
  }
};
1052
1053} // namespace
1054
/// Register all FillOp canonicalization patterns defined above (folding fill
/// with concat, copy, extract, pack, pad, reshape, insert_slice-of-pad, and
/// transpose consumers/producers).
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(arg&: context);
}
1063
1064//===----------------------------------------------------------------------===//
1065// GenericOp
1066//===----------------------------------------------------------------------===//
1067
1068static void buildGenericRegion(
1069 OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
1070 ValueRange outputs,
1071 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
1072 SmallVector<Type, 4> blockArgTypes;
1073 SmallVector<Location, 4> blockArgLocs;
1074 for (ValueRange container : {inputs, outputs}) {
1075 for (Value v : container) {
1076 Type t = v.getType();
1077 blockArgTypes.push_back(
1078 Elt: isa<MemRefType, RankedTensorType>(Val: t) ? getElementTypeOrSelf(type: t) : t);
1079 blockArgLocs.push_back(Elt: v.getLoc());
1080 }
1081 }
1082
1083 OpBuilder::InsertionGuard guard(builder);
1084 Block *bodyBlock =
1085 builder.createBlock(parent: &region, insertPt: region.end(), argTypes: blockArgTypes, locs: blockArgLocs);
1086 bodyBuild(builder, loc, bodyBlock->getArguments());
1087}
1088
1089void GenericOp::getAsmBlockArgumentNames(Region &region,
1090 OpAsmSetValueNameFn setNameFn) {
1091 for (Value v : getRegionInputArgs())
1092 setNameFn(v, "in");
1093 for (Value v : getRegionOutputArgs())
1094 setNameFn(v, "out");
1095}
1096
/// Build a GenericOp from attribute-typed indexing maps / iterator types,
/// attach the extra `attributes`, and, when `bodyBuild` is provided, populate
/// the region's body block with the payload computation.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(odsBuilder&: builder, odsState&: result, result_tensors: resultTensorTypes, inputs, outputs, indexing_maps: indexingMaps,
        iterator_types: iteratorTypes, doc, library_call: libraryCall);
  result.addAttributes(newAttributes: attributes);
  if (bodyBuild)
    buildGenericRegion(builder, loc: result.location, region&: *result.regions.front(),
                       inputs, outputs, bodyBuild);
}
1110
/// Convenience overload taking raw ArrayRef / StringRef values: wraps the
/// indexing maps and iterator-type enums into attributes and forwards. Empty
/// `doc`/`libraryCall` strings become null attributes.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        indexingMaps: builder.getAffineMapArrayAttr(values: indexingMaps),
        iteratorTypes: builder.getArrayAttr(value: llvm::to_vector(Range: llvm::map_range(
            C&: iteratorTypes,
            F: [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(context: builder.getContext(), value: iter);
            }))),
        doc: doc.empty() ? StringAttr() : builder.getStringAttr(bytes: doc),
        libraryCall: libraryCall.empty() ? StringAttr() : builder.getStringAttr(bytes: libraryCall),
        bodyBuild, attributes);
}
1129
/// Convenience overload without result tensor types: forwards with an empty
/// result-type list.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes: TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1140
/// Convenience overload without result types, doc string, or library call:
/// both strings default to empty (and are thus elided as attributes).
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1151
/// Convenience overload with result tensor types but no doc string or library
/// call: both strings default to empty.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1163
/// Print a linalg.generic op: the dictionary of core trait attributes first,
/// then the shared ins/outs operand lists, any extra attributes behind
/// `attrs =`, the region, and finally the tensor result types.
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(R&: genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(Val: attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(Range: llvm::map_range(
              C&: iteratorTypes, F: [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(context: getContext(), bytes: stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          Args: getIteratorTypesAttrName(),
          Args: ArrayAttr::get(context: getContext(), value: iteratorTypeNames));
    } else if (genericAttrNamesSet.count(Key: attr.getName().strref()) > 0) {
      genericAttrs.push_back(Elt: attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(context: getContext(), value: genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, inputs: getDpsInputs(), outputs: getDpsInits());

  // Elide the segment-size bookkeeping attribute from the trailing dict.
  genericAttrNames.push_back(Elt: "operandSegmentSizes");
  genericAttrNamesSet.insert(key: genericAttrNames.back());

  // Only print the trailing `attrs =` dict if any non-core attribute exists.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(key: n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict(attrs: (*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(blocks&: getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, resultTypes: getResultTensors().getTypes());
}
1226
/// Parse a linalg.generic op: the leading dictionary of core trait
/// attributes, the shared ins/outs operand lists, an optional `attrs =`
/// dictionary, the region, and the trailing tensor result types. The inverse
/// of GenericOp::print.
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(result&: dictAttr, attrName: "_", attrs&: result.attributes))
    return failure();
  result.attributes.assign(inStart: dictAttr.getValue().begin(),
                           inEnd: dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      Val: result.attributes.get(name: getIteratorTypesAttrName(name: result.name)));
  if (!iteratorTypes) {
    return parser.emitError(loc: attributeLocation)
           << "expected " << getIteratorTypesAttrName(name: result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(loc: parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        Elt: IteratorTypeAttr::get(context: parser.getContext(), value: maybeIteratorType.value()));
  }
  // Overwrite the string array with the enum-typed attribute array.
  result.attributes.set(name: getIteratorTypesAttrName(name: result.name),
                        value: parser.getBuilder().getArrayAttr(value: iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "attrs")))
    if (failed(Result: parser.parseEqual()) ||
        failed(Result: parser.parseOptionalAttrDict(result&: result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(region&: *region, arguments: {}))
    return failure();
  result.addRegion(region: std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, resultTypes&: outputTensorsTypes))
    return failure();
  result.addTypes(newTypes: outputTensorsTypes);

  return success();
}
1292
/// Populate `effects` for a structured op: memref inputs contribute a read
/// effect; memref inits contribute a write effect, plus a read when the
/// payload actually uses the operand's value. Tensor operands contribute no
/// memory effects.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(First: linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(Val: operand.getType()))
      continue;
    effects.emplace_back(
        Args: MemoryEffects::Read::get(), Args: &linalgOp->getOpOperand(idx: index), /*stage=*/Args: 0,
        /*effectOnFullRegion=*/Args: true, Args: SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(Val: operand.get().getType()))
      continue;
    // Inits are only read when the payload references their value.
    if (linalgOp.payloadUsesValueFromOperand(opOperand: &operand)) {
      effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &operand, /*stage=*/Args: 0,
                           /*effectOnFullRegion=*/Args: true,
                           Args: SideEffects::DefaultResource::get());
    }
    effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &operand, /*stage=*/Args: 0,
                         /*effectOnFullRegion=*/Args: true,
                         Args: SideEffects::DefaultResource::get());
  }
}
1318
/// Report this op's memory effects via the shared structured-op logic.
void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1324
1325static Speculation::Speculatability
1326getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
1327 // Operands with value semantics are speculatable, while operands with memory
1328 // semantics are not.
1329 if (!linalgOp.hasPureTensorSemantics())
1330 return Speculation::NotSpeculatable;
1331 // The body of the op can still have speculation in its region.
1332 return Speculation::RecursivelySpeculatable;
1333}
1334
/// Delegate speculatability to the shared structured-op logic.
Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1338
// No op-specific verification here; presumably the structural invariants are
// enforced by the generated and interface verifiers.
LogicalResult GenericOp::verify() { return success(); }
1340
1341namespace {
1342
/// Remove linalg operations that are just copying the values from inputs to
/// results. In the memref case, the operation must be copying to and from the
/// same value. Requirements are:
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(C&: body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(Val: body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1 ||
          linalgOp.getDpsInputOperand(0)->get() !=
              linalgOp.getDpsInitOperand(0)->get()) {
        return rewriter.notifyMatchFailure(
            linalgOp, "expected single input and output to be the same value");
      }

      // The yielded value must be a block argument (not e.g. a constant,
      // which would make this a fill, not an identity copy).
      auto yieldArg = dyn_cast<BlockArgument>(Val: yieldOp.getOperand(i: 0));
      if (!yieldArg || yieldArg.getOwner() != &body) {
        return rewriter.notifyMatchFailure(linalgOp,
                                           "cannot fold fill-like op");
      }

      rewriter.eraseOp(op: linalgOp);
      return success();
    }

    if (!linalgOp.hasPureTensorSemantics()) {
      return rewriter.notifyMatchFailure(
          linalgOp, "mixed semantics is not supported yet");
    }

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(First: yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(Val&: yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(type: returnType) ||
            sparse_tensor::getSparseTensorEncoding(type: resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(inputs: returnedArg.getType(),
                                                 outputs: resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(Elt: returnedArg);
    }

    // Every result must have been matched to an operand.
    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1429
1430} // namespace
1431
/// Register GenericOp canonicalization patterns; currently only the removal
/// of identity (pure copy) generics.
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(arg&: context);
}
1436
/// Fold memref.cast producers into this op's operands where possible
/// (delegates to memref::foldMemRefCast).
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(op: *this);
}
1440
1441//===----------------------------------------------------------------------===//
1442// MapOp
1443//===----------------------------------------------------------------------===//
1444
/// Parse the common assembly of a destination-style op: the `ins`/`outs`
/// operand lists, then op-specific required attributes via `parseAttrsFn`
/// (if provided), then an optional attribute dictionary. Ranked-tensor
/// output types become the op's result types.
static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(Val: outputType))
      result.addTypes(newTypes: outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(Result: parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result&: result.attributes))
    return failure();
  return success();
}
1470
1471void MapOp::getAsmBlockArgumentNames(Region &region,
1472 OpAsmSetValueNameFn setNameFn) {
1473 for (Value v : getRegionInputArgs())
1474 setNameFn(v, "in");
1475}
1476
1477void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1478 if (!getResults().empty())
1479 setNameFn(getResults().front(), "mapped");
1480}
1481
/// Build a MapOp over `inputs` with destination `init`, attach `attributes`,
/// and, when `bodyBuild` is provided, populate the region with the payload.
/// A ranked-tensor init contributes the op's (single) result type.
void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(odsBuilder&: builder, odsState&: result, result: TypeRange{}, inputs, init);
  result.addAttributes(newAttributes: attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(Val: initType))
    result.addTypes(newTypes: initType);

  // Note: only the inputs become block arguments; the init is not passed to
  // the region here.
  if (bodyBuild)
    buildGenericRegion(builder, loc: result.location, region&: *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}
1498
/// Create a region body for the short-form syntax: a block with one argument
/// per operand (element types), a single `payloadOpName` op applied to those
/// arguments, and a linalg.yield of its results. With `initFirst`, the last
/// block argument (the init) is passed as the payload op's first operand.
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  // One scalar block argument per (shaped) operand.
  for (auto &operand : operands) {
    block.addArgument(
        type: llvm::cast<ShapedType>(Val: operand.getType()).getElementType(),
        loc: b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If initFirst flag is enabled, we consider init as the first position of
  // payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(Elt: block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(Elt: arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  // The payload's result type is the element type of the last op operand
  // (the init).
  Operation *payloadOp = b.create(
      loc: result.location, opName: b.getStringAttr(bytes: payloadOpName.getStringRef()),
      operands: payloadOpOperands,
      types: TypeRange{llvm::cast<ShapedType>(Val: result.operands.back().getType())
                    .getElementType()},
      attributes: payloadOpAttrs);
  b.create<YieldOp>(location: result.location, args: payloadOp->getResults());
}
1533
/// Parse a linalg.map op. Supports two forms: a short form where the payload
/// op is named in braces (`{ arith.addf }`) and the body is synthesized, and
/// a long form with explicit region arguments and region.
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  // Short form: `{ op-name attr-dict }` preceding ins/outs.
  if (succeeded(Result: parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(Result: operationName))
      return failure();
    if (parser.parseOptionalAttrDict(result&: payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Synthesize the body from the named payload op; the trailing operand
    // (the init) is excluded from the payload's operands.
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName: payloadOpName.value(),
                           payloadOpAttrs,
                           operands: ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    // Long form: explicit region arguments followed by the region body.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(result&: regionArgs, delimiter: OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(region&: *body, arguments: regionArgs))
      return failure();
  }
  return success();
}
1570
// Retrieve the operation from the body, if it is the only one (except
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we additionally check that init takes
// the first position in the operands of the payload.
1575static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1576 if (body->getOperations().size() != 2)
1577 return nullptr;
1578 Operation &payload = body->getOperations().front();
1579 assert(isa<YieldOp>(body->getOperations().back()));
1580
1581 if (payload.getNumOperands() == 0 ||
1582 payload.getNumOperands() != body->getNumArguments())
1583 return nullptr;
1584 if (initFirst) {
1585 // check init
1586 if (payload.getOperands().back() != body->getArgument(i: 0))
1587 return nullptr;
1588 // check rest
1589 for (const auto &[operand, bbArg] :
1590 llvm::zip(t: payload.getOperands(), u: body->getArguments().drop_front())) {
1591 if (bbArg != operand)
1592 return nullptr;
1593 }
1594 } else {
1595 for (const auto &[operand, bbArg] :
1596 llvm::zip(t: payload.getOperands(), u: body->getArguments())) {
1597 if (bbArg != operand)
1598 return nullptr;
1599 }
1600 }
1601 return &payload;
1602}
1603
1604void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1605 SmallVector<StringRef> elidedAttrs;
1606 std::string attrToElide;
1607 p << " { " << payloadOp->getName().getStringRef();
1608 for (const auto &attr : payloadOp->getAttrs()) {
1609 auto fastAttr =
1610 llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(Val: attr.getValue());
1611 if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1612 attrToElide = attr.getName().str();
1613 elidedAttrs.push_back(Elt: attrToElide);
1614 break;
1615 }
1616 }
1617 p.printOptionalAttrDict(attrs: payloadOp->getAttrs(), elidedAttrs);
1618 p << " }";
1619}
1620
/// Prints linalg.map, using the short form when the body consists of a single
/// recognizable payload op, otherwise printing the full region.
void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(body: mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, inputs: getDpsInputs(), outputs: getDpsInits());
  p.printOptionalAttrDict(attrs: (*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(c: mapper->getArguments(), os&: p,
                          each_fn: [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(blocks&: getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1644
1645LogicalResult MapOp::verify() {
1646 auto *bodyBlock = getBody();
1647 auto blockArgs = bodyBlock->getArguments();
1648
1649 // Checks if the number of `inputs` match the arity of the `mapper` region.
1650 if (getInputs().size() != blockArgs.size())
1651 return emitOpError() << "expects number of operands to match the arity of "
1652 "mapper, but got: "
1653 << getInputs().size() << " and " << blockArgs.size();
1654
1655 // The parameters of mapper should all match the element type of inputs.
1656 for (const auto &[bbArgType, inputArg] :
1657 llvm::zip(t: bodyBlock->getArgumentTypes(), u: getInputs())) {
1658 auto inputElemType =
1659 llvm::cast<ShapedType>(Val: inputArg.getType()).getElementType();
1660 if (bbArgType != inputElemType) {
1661 return emitOpError() << "expected element type of input " << inputElemType
1662 << " to match bbArg type " << bbArgType;
1663 }
1664 }
1665
1666 // The shape of each input must match the shape of the output.
1667 auto outputShape = getInit().getType().getShape();
1668 for (Type inputArgType : TypeRange{getInputs()}) {
1669 auto inputElemShape = llvm::cast<ShapedType>(Val&: inputArgType).getShape();
1670 if (inputElemShape != outputShape) {
1671 return emitOpError() << "expected shape of input (" << inputElemShape
1672 << ") to match shape of output (" << outputShape
1673 << ")";
1674 }
1675 }
1676
1677 return success();
1678}
1679
1680SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1681 int64_t rank = getInit().getType().getRank();
1682 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1683}
1684
1685ArrayAttr MapOp::getIndexingMaps() {
1686 Builder builder(getContext());
1687 int64_t rank = getInit().getType().getRank();
1688 int64_t numIndexingMaps = getOperands().size();
1689 return builder.getAffineMapArrayAttr(values: SmallVector<AffineMap>(
1690 numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1691}
1692
/// Delegates memory-effect reporting to the shared LinalgOp implementation.
void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1698
/// Delegates speculatability to the shared LinalgOp implementation.
Speculation::Speculatability MapOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1702
1703//===----------------------------------------------------------------------===//
1704// ReduceOp
1705//===----------------------------------------------------------------------===//
1706
1707void ReduceOp::getAsmBlockArgumentNames(Region &region,
1708 OpAsmSetValueNameFn setNameFn) {
1709 for (Value v : getRegionInputArgs())
1710 setNameFn(v, "in");
1711 for (Value v : getRegionOutputArgs())
1712 setNameFn(v, "init");
1713}
1714
1715void ReduceOp::getAsmResultNames(
1716 function_ref<void(Value, StringRef)> setNameFn) {
1717 if (!getResults().empty())
1718 setNameFn(getResults().front(), "reduced");
1719}
1720
/// Builds a ReduceOp from inputs/inits/dimensions, inferring result types
/// from tensor-typed inits and optionally populating the combiner region via
/// `bodyBuild`.
void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(odsBuilder&: builder, odsState&: result, resultType0: TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(newAttributes: attributes);

  // Add output types for `RankedTensorType` output arguments.
  // Memref inits produce no results, so only tensor inits add types.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(Val: initType))
      result.addTypes(newTypes: initType);
  }

  // Populate the combiner region if a body builder callback was provided.
  if (bodyBuild)
    buildGenericRegion(builder, loc: result.location, region&: *result.regions.front(),
                       inputs, outputs: inits, bodyBuild);
}
1740
1741SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1742 int64_t inputRank =
1743 llvm::cast<ShapedType>(Val: getInputs()[0].getType()).getRank();
1744 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1745 utils::IteratorType::parallel);
1746 for (int64_t reductionDim : getDimensions())
1747 iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1748 return iteratorTypes;
1749}
1750
1751ArrayAttr ReduceOp::getIndexingMaps() {
1752 int64_t inputRank =
1753 llvm::cast<ShapedType>(Val: getInputs()[0].getType()).getRank();
1754 SmallVector<AffineMap> affineMaps(
1755 getNumDpsInputs(),
1756 AffineMap::getMultiDimIdentityMap(numDims: inputRank, context: getContext()));
1757 AffineMap resultMap =
1758 AffineMap::getMultiDimIdentityMap(numDims: inputRank, context: getContext())
1759 .dropResults(positions: getDimensions());
1760 for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1761 affineMaps.push_back(Elt: resultMap);
1762 return Builder(getContext()).getAffineMapArrayAttr(values: affineMaps);
1763}
1764
/// Delegates memory-effect reporting to the shared LinalgOp implementation.
void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1770
/// Delegates speculatability to the shared LinalgOp implementation.
Speculation::Speculatability ReduceOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
1774
1775static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1776 NamedAttrList &attributes,
1777 StringRef attributeName) {
1778 if (parser.parseKeyword(keyword: attributeName) || parser.parseEqual())
1779 return failure();
1780
1781 attributes.set(name: attributeName, value: DenseI64ArrayAttr::parse(parser, type: Type{}));
1782 return success();
1783}
1784
/// Parses linalg.reduce. Like MapOp::parse, accepts an optional short-form
/// `{ payload-op attr-dict }` header followed by the common destination-style
/// clauses plus the mandatory `dimensions = [...]` attribute.
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  // Optional short-form header naming the combiner op.
  if (succeeded(Result: parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(Result: operationName))
      return failure();
    if (parser.parseOptionalAttrDict(result&: payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // ins/outs clauses plus the `dimensions` attribute.
  if (parseDstStyleOp(
          parser, result, parseAttrsFn: [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, attributeName: "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the combiner region; init is the payload's first
    // operand (initFirst = true).
    addBodyWithPayloadOp(parser, result, payloadOpName: payloadOpName.value(), payloadOpAttrs,
                         operands: ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: explicit block arguments and region.
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(result&: regionArgs, delimiter: OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(region&: *body, arguments: regionArgs))
      return failure();
  }

  return success();
}
1822
/// Prints ` <attributeName> = [<values>] ` for a dense i64 array attribute.
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1827
/// Prints linalg.reduce, using the short form when the combiner is a single
/// recognizable payload op (with init in first position).
void ReduceOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(body: mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, inputs: getDpsInputs(), outputs: getDpsInits());
  printDenseI64ArrayAttr(p, attributeName: getDimensionsAttrName(), attributeValue: getDimensions());
  // `dimensions` is printed explicitly above, so elide it from the dict.
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(c: mapper->getArguments(), os&: p,
                          each_fn: [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(blocks&: getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1851
/// Verifies linalg.reduce: uniform input/output shapes, in-range reduction
/// dimensions, init rank/shape consistent with the non-reduced input
/// dimensions, and block arguments matching operand element types.
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs must share one shape; all inits must share one shape.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(Val: getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(Val: getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(Val: getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(Val: getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(Val: getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(Val: getInits()[0].getType());

  // Every reduction dimension must be a valid input dimension index.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(V: dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(First&: inputDims)) {
    if (!dimensionsToReduce.count(V: en.index()))
      reducedInputDims.push_back(Elt: en.value());
  }

  // The surviving dimensions must match the init rank and shape exactly.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(t: getInputs(), u: block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(Val: input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           t: getDpsInits(), u: block->getArguments().take_back(N: getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(Val: output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1937
1938//===----------------------------------------------------------------------===//
1939// TransposeOp
1940//===----------------------------------------------------------------------===//
1941
/// Populates `region` with a body that simply yields its first argument —
/// the identity payload used by transpose and broadcast.
static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     bodyBuild: [](OpBuilder &b, Location loc, ValueRange args) {
                       // Guard against an empty block-argument list.
                       if (!args.empty())
                         b.create<linalg::YieldOp>(location: loc, args: args[0]);
                     });
}
1951
/// Builds a TransposeOp from input/init and a permutation attribute; infers
/// the result type from a tensor-typed init and installs the identity region.
void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(newOperands: input);
  result.addOperands(newOperands: init);
  result.addAttribute(name: getPermutationAttrName(name: result.name), attr: permutation);
  result.addAttributes(newAttributes: attributes);

  // Add output types for `RankedTensorType` output arguments.
  // Memref inits produce no results.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(Val: initType))
    result.addTypes(newTypes: initType);

  buildIdentityRegion(builder, loc: result.location, region&: *result.addRegion(), inputs: input,
                      outputs: init);
}
1969
1970void TransposeOp::build(::mlir::OpBuilder &builder,
1971 ::mlir::OperationState &result, Value input, Value init,
1972 ArrayRef<int64_t> permutation,
1973 ArrayRef<NamedAttribute> attributes) {
1974 build(builder, result, input, init, permutation: builder.getDenseI64ArrayAttr(values: permutation),
1975 attributes);
1976}
1977
/// Parses linalg.transpose: the common destination-style clauses plus a
/// mandatory `permutation = [...]` attribute; the identity region is built
/// rather than parsed.
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(Result: parseDstStyleOp(
          parser, result, parseAttrsFn: [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, attributeName: "permutation");
          })))
    return failure();

  // The body is implicit; synthesize the identity region.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, loc: result.location, region&: *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
1991
1992void TransposeOp::getAsmResultNames(
1993 function_ref<void(Value, StringRef)> setNameFn) {
1994 if (!getResults().empty())
1995 setNameFn(getResults().front(), "transposed");
1996}
1997
/// Prints linalg.transpose; `permutation` is printed explicitly and elided
/// from the trailing attribute dictionary.
void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, inputs: getDpsInputs(), outputs: getDpsInits());
  printDenseI64ArrayAttr(p, attributeName: getPermutationAttrName(), attributeValue: getPermutation());
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: {getPermutationAttrName()});
}
2003
2004LogicalResult TransposeOp::verify() {
2005 ArrayRef<int64_t> permutationRef = getPermutation();
2006
2007 if (!isPermutationVector(interchange: permutationRef))
2008 return emitOpError(message: "permutation is not valid");
2009
2010 auto inputType = getInput().getType();
2011 auto initType = getInit().getType();
2012
2013 int64_t rank = inputType.getRank();
2014
2015 if (rank != initType.getRank())
2016 return emitOpError() << "input rank " << rank
2017 << " does not match init rank " << initType.getRank();
2018
2019 if (rank != static_cast<int64_t>(permutationRef.size()))
2020 return emitOpError() << "size of permutation " << permutationRef.size()
2021 << " does not match the argument rank " << rank;
2022
2023 auto inputDims = inputType.getShape();
2024 auto initDims = initType.getShape();
2025
2026 for (int64_t i = 0; i < rank; ++i) {
2027 int64_t inputDim = inputDims[permutationRef[i]];
2028 int64_t initDim = initDims[i];
2029
2030 if (inputDim != initDim) {
2031 return emitOpError() << "dim(result, " << i << ") = " << initDim
2032 << " doesn't match dim(input, permutation[" << i
2033 << "]) = " << inputDim;
2034 }
2035 }
2036
2037 return success();
2038}
2039
2040SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
2041 int64_t rank = getInit().getType().getRank();
2042 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2043}
2044
/// Indexing maps: the input uses the inverse of the permutation map, the init
/// uses the identity map over the init rank.
ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      values: {inversePermutation(map: AffineMap::getPermutationMap(
           permutation: llvm::to_vector_of<unsigned>(Range: getPermutation()), context: getContext())),
       builder.getMultiDimIdentityMap(rank)});
}
2053
/// Delegates memory-effect reporting to the shared LinalgOp implementation.
void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
2059
/// Delegates speculatability to the shared LinalgOp implementation.
Speculation::Speculatability TransposeOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
2063
2064LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
2065 SmallVectorImpl<OpFoldResult> &result) {
2066 // Only the tensor type is supported.
2067 if (!isa<TensorType>(Val: getInput().getType()))
2068 return failure();
2069
2070 // Single dimension transpose.
2071 if (getPermutation().size() == 0) {
2072 result.push_back(Elt: getInput());
2073 return success();
2074 }
2075 // Identity permutation.
2076 if (isIdentityPermutation(permutation: getPermutation())) {
2077 result.push_back(Elt: getInput());
2078 return success();
2079 }
2080
2081 return failure();
2082}
2083
2084/// Fold transpose with transpose.
2085struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
2086 using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2087
2088 LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2089 PatternRewriter &rewriter) const override {
2090 auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2091 if (!defTransposeOp)
2092 return failure();
2093 ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2094 ArrayRef<int64_t> perms = transposeOp.getPermutation();
2095 SmallVector<int64_t> foldedPerms;
2096 foldedPerms.reserve(N: perms.size());
2097 for (int64_t perm : perms)
2098 foldedPerms.push_back(Elt: defPerms[perm]);
2099
2100 rewriter.replaceOpWithNewOp<TransposeOp>(
2101 op: transposeOp, args: defTransposeOp.getInput(), args: transposeOp.getInit(),
2102 args&: foldedPerms);
2103 return success();
2104 }
2105};
2106
2107/// This pattern canonicalize transpose by swapping the order of
2108/// broadcast and transpose:
2109/// transpose(broadcast(input)) -> broadcast(transpose(input))
struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  /// Rewrites transpose(broadcast(x)) as broadcast(transpose(x)), moving the
  /// transpose onto the smaller pre-broadcast value. Only fires when the
  /// broadcast result has no other users.
  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    Value input = transposeOp.getInput();
    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
    if (!input.hasOneUse() || !broadcastOp)
      return failure();

    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
    ArrayRef<int64_t> perms = transposeOp.getPermutation();

    // Get new perms and new dimensions.
    // The inner transpose permutes only the non-broadcast dims; the new
    // broadcast dims are the old ones mapped through the inverse permutation.
    SmallVector<int64_t> resultPerms = dropDims(inputPerm: perms, dropPositions: dimensions);
    SmallVector<int64_t> invertPerm = invertPermutationVector(permutation: perms);
    SmallVector<int64_t> resultDimensions;
    unsigned dimensionSize = dimensions.size();
    for (unsigned i = 0; i < dimensionSize; ++i)
      resultDimensions.push_back(Elt: invertPerm[dimensions[i]]);

    // Create transpose result.
    Value broadcastInput = broadcastOp.getInput();
    Location loc = transposeOp.getLoc();
    MLIRContext *ctx = transposeOp.getContext();
    // Collect the (possibly dynamic) sizes of the pre-broadcast input so we
    // can materialize an init for the new transpose.
    SmallVector<OpFoldResult> dims;
    auto broadcastInputTy =
        mlir::cast<RankedTensorType>(Val: broadcastInput.getType());
    unsigned inputRank = broadcastInputTy.getRank();
    for (unsigned i = 0; i < inputRank; ++i) {
      if (broadcastInputTy.isDynamicDim(idx: i)) {
        dims.push_back(Elt: rewriter.create<tensor::DimOp>(location: loc, args&: broadcastInput, args&: i)
                           ->getResult(idx: 0));
      } else {
        dims.push_back(Elt: IntegerAttr::get(type: IndexType::get(context: ctx),
                                        value: broadcastInputTy.getDimSize(idx: i)));
      }
    }
    SmallVector<OpFoldResult> transposeResultShapes =
        applyPermutation(input: dims, permutation: resultPerms);
    Value transposeInit = rewriter.create<tensor::EmptyOp>(
        location: transposeOp.getLoc(), args&: transposeResultShapes,
        args: broadcastInputTy.getElementType());

    // Create broadcast(transpose(input)).
    Value transposeResult =
        rewriter
            .create<TransposeOp>(location: loc, args: broadcastOp.getInput(), args&: transposeInit,
                                 args&: resultPerms)
            ->getResult(idx: 0);
    rewriter.replaceOpWithNewOp<BroadcastOp>(
        op: transposeOp, args&: transposeResult, args: transposeOp.getInit(), args&: resultDimensions);
    return success();
  }
};
2165
/// Registers the transpose-transpose fold and the transpose/broadcast swap.
void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(arg&: context);
}
2170
2171//===----------------------------------------------------------------------===//
2172// BroadcastOp
2173//===----------------------------------------------------------------------===//
2174
/// Builds a BroadcastOp from input/init and a dimensions attribute; infers
/// the result type from a tensor-typed init and installs the identity region.
void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(newOperands: input);
  result.addOperands(newOperands: init);
  result.addAttribute(name: getDimensionsAttrName(name: result.name), attr: dimensions);
  result.addAttributes(newAttributes: attributes);

  // Add output types for `RankedTensorType` output arguments.
  // Memref inits produce no results.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(Val: initType))
    result.addTypes(newTypes: initType);

  buildIdentityRegion(builder, loc: result.location, region&: *result.addRegion(), inputs: input,
                      outputs: init);
}
2192
2193void BroadcastOp::build(::mlir::OpBuilder &builder,
2194 ::mlir::OperationState &result, Value input, Value init,
2195 ArrayRef<int64_t> dimensions,
2196 ArrayRef<NamedAttribute> attributes) {
2197 build(builder, result, input, init, dimensions: builder.getDenseI64ArrayAttr(values: dimensions),
2198 attributes);
2199}
2200
/// Parses linalg.broadcast: the common destination-style clauses plus a
/// mandatory `dimensions = [...]` attribute; the identity region is built
/// rather than parsed.
ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(Result: parseDstStyleOp(
          parser, result, parseAttrsFn: [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, attributeName: "dimensions");
          })))
    return failure();

  // The body is implicit; synthesize the identity region.
  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, loc: result.location, region&: *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}
2214
2215void BroadcastOp::getAsmResultNames(
2216 function_ref<void(Value, StringRef)> setNameFn) {
2217 if (!getResults().empty())
2218 setNameFn(getResults().front(), "broadcasted");
2219}
2220
/// Prints linalg.broadcast; `dimensions` is printed explicitly and elided
/// from the trailing attribute dictionary.
void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, inputs: getDpsInputs(), outputs: getDpsInits());
  printDenseI64ArrayAttr(p, attributeName: getDimensionsAttrName(), attributeValue: getDimensions());
  p.printOptionalAttrDict(attrs: (*this)->getAttrs(), elidedAttrs: {getDimensionsAttrName()});
}
2226
/// Verifies linalg.broadcast: init rank equals input rank plus the number of
/// added dimensions, every added dimension index is in range, and each
/// non-added init dimension matches the corresponding input dimension.
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  // Each added dimension must be a valid init dimension index.
  for (const auto &[idx, dim] : llvm::enumerate(First&: dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(Begin: 0, End: initRank)) {
    if (!llvm::is_contained(Range&: dimensionsRef, Element: dim))
      dimMap.push_back(Elt: dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(First&: dimMap)) {
    // This dimensions is mapped from the input. Init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
2272
2273SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2274 int64_t rank = getInit().getType().getRank();
2275 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2276}
2277
/// Indexing maps: the input uses the identity map with the broadcast
/// dimensions dropped; the init uses the full identity map.
ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      values: {builder.getMultiDimIdentityMap(rank).dropResults(positions: getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}
2285
/// Delegates memory-effect reporting to the shared LinalgOp implementation.
void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
2291
/// Delegates speculatability to the shared LinalgOp implementation.
Speculation::Speculatability BroadcastOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
2295
/// Registers the pattern that erases identity broadcasts.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(arg&: context);
}
2300
2301//===----------------------------------------------------------------------===//
2302// YieldOp
2303//===----------------------------------------------------------------------===//
2304
2305void linalg::YieldOp::print(OpAsmPrinter &p) {
2306 if (getNumOperands() > 0)
2307 p << ' ' << getOperands();
2308 p.printOptionalAttrDict(attrs: (*this)->getAttrs());
2309 if (getNumOperands() > 0)
2310 p << " : " << getOperandTypes();
2311}
2312
/// Parses an optional operand list, attribute dictionary, and — only when
/// operands are present — a trailing colon type list.
ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(IsFailure: parser.parseOperandList(result&: opInfo) ||
                 parser.parseOptionalAttrDict(result&: result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(result&: types)) ||
                 parser.resolveOperands(operands&: opInfo, types, loc, result&: result.operands));
}
2322
// Check that the operand number and types match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // One yielded value per init/out operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError(message: "expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(i: opOperand.getOperandNumber());
    // For shaped (memref/tensor) outputs compare against the element type;
    // otherwise compare against the output type itself.
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(Val: elementType))
      elementType = getElementTypeOrSelf(type: outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError(message: "type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
2347
2348LogicalResult linalg::YieldOp::verify() {
2349 auto *parentOp = (*this)->getParentOp();
2350 if (parentOp->getNumRegions() != 1 || parentOp->getRegion(index: 0).empty())
2351 return emitOpError(message: "expected single non-empty parent region");
2352
2353 if (auto linalgOp = dyn_cast<LinalgOp>(Val: parentOp))
2354 return verifyYield(op: *this, linalgOp);
2355
2356 return emitOpError(message: "expected parent op with LinalgOp interface");
2357}
2358
2359//===----------------------------------------------------------------------===//
2360// IndexOp
2361//===----------------------------------------------------------------------===//
2362
2363LogicalResult IndexOp::verify() {
2364 auto linalgOp = dyn_cast<LinalgOp>(Val: (*this)->getParentOp());
2365 if (!linalgOp)
2366 return emitOpError(message: "expected parent op with LinalgOp interface");
2367 if (linalgOp.getNumLoops() <= getDim())
2368 return emitOpError(message: "expected dim (")
2369 << getDim() << ") to be lower than the number of loops ("
2370 << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2371 return success();
2372}
2373
2374OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
2375 auto linalgOp = dyn_cast_or_null<LinalgOp>(Val: (*this)->getParentOp());
2376 // Bail out if `linalg.index` does not have a proper parent yet at this
2377 // point, e.g., when calling `createOrFold` during IR construction in
2378 // `genericOp::build`.
2379 if (!linalgOp)
2380 return OpFoldResult{};
2381
2382 // Index of unit dims is always 0.
2383 SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
2384 uint64_t dim = getDim();
2385 assert(dim < loopBounds.size() && "Dim is out of bounds");
2386 if (loopBounds[dim] == 1)
2387 return IntegerAttr::get(type: IndexType::get(context: getContext()), value: 0);
2388
2389 return OpFoldResult{};
2390}
2391
2392/////// Operations corresponding to library calls defined with Tablegen ////////
2393
2394#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2395
2396#define GET_OP_CLASSES
2397#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2398
2399#define GET_OP_CLASSES
2400#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2401#define GET_OP_CLASSES
2402#include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2403
2404AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2405 unsigned rank,
2406 MLIRContext *context) {
2407 if (maybeMap)
2408 return *maybeMap;
2409 if (rank == 0)
2410 return AffineMap::get(context);
2411 return AffineMap::getMultiDimIdentityMap(numDims: rank, context);
2412}
2413
2414SmallVector<AffineExpr, 4>
2415mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2416 MLIRContext *context) {
2417 SmallVector<AffineExpr, 4> res;
2418 res.reserve(N: num);
2419 for (unsigned i = 0; i < num; ++i)
2420 res.push_back(Elt: getAffineDimExpr(position: startIdx++, context));
2421 return res;
2422}
2423
2424SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2425 ArrayRef<AffineExpr> b) {
2426 auto rangeA = llvm::make_range(x: a.begin(), y: a.end());
2427 auto rangeB = llvm::make_range(x: b.begin(), y: b.end());
2428 auto concatRanges = llvm::concat<const AffineExpr>(Ranges&: rangeA, Ranges&: rangeB);
2429 return llvm::to_vector<4>(Range&: concatRanges);
2430}
2431
// Append a mangled spelling of type `t` to `ss` for use in library call names.
// Supported types: memref (-> "view" + sizes + element + optional address
// space), vector (-> "vector" + shape + element), and signless
// int/index/float (printed directly). Returns failure for any other type or
// for a non-integer memory space attribute.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(Val&: t)) {
    ss << "view";
    // Dynamic extents are spelled "s"; static extents print their value. Each
    // extent is followed by "x".
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    // Recurse on the element type.
    if (failed(Result: appendMangledType(ss, t: memref.getElementType())))
      return failure();
    // Only integer memory spaces are representable in the mangling.
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(Val&: as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(Val&: t)) {
    // Vector shapes are "x"-separated (no trailing "x", unlike memref).
    ss << "vector";
    llvm::interleave(
        c: vec.getShape(), each_fn: [&](int64_t i) { ss << i; }, between_fn: [&]() { ss << "x"; });
    if (failed(Result: appendMangledType(ss, t: vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    // Scalars use MLIR's own printed form, e.g. "f32" or "i64".
    ss << t;
    return success();
  }
  return failure();
}
2464
/// Build the external library call name for a LinalgOp: the op name with '.'
/// replaced by '_', an optional unary/binary fn suffix, and a mangled spelling
/// of each operand type. Returns the empty string when any operand type
/// cannot be mangled.
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  // If the op carries a unary/binary fn attribute (e.g. elementwise named
  // ops), fold its stringified value into the name. Last such attribute wins.
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(Val: kv.getValue())) {
      fun = stringifyEnum(enumValue: ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(Val: kv.getValue())) {
      fun = stringifyEnum(enumValue: bfa.getValue()).str() + "_";
    }
  }
  name.reserve(res_arg: 128);
  llvm::replace(Range&: name, OldValue: '.', NewValue: '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  // Append one "<mangled-type>_" group per operand; bail out entirely if a
  // type is not manglable.
  for (Type t : op->getOperandTypes()) {
    if (failed(Result: appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  // Drop the trailing separator. NOTE(review): with zero operands this pops
  // the '_' (or last fun char) appended above — presumably never hit since
  // LinalgOps always have operands; confirm before relying on it.
  name.pop_back();
  return name;
}
2488
2489//===----------------------------------------------------------------------===//
2490// Canonicalizers and Folders.
2491//===----------------------------------------------------------------------===//
2492
2493namespace {
2494struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2495 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2496
2497 LogicalResult matchAndRewrite(LinalgOp op,
2498 PatternRewriter &rewriter) const override {
2499 for (OpOperand &opOperand : op->getOpOperands()) {
2500 // Linalg "inputs" may be either tensor or memref type.
2501 // tensor<0xelt_type> is a convention that may not always mean
2502 // "0 iterations". Only erase in cases we see memref<...x0x...>.
2503 auto mt = llvm::dyn_cast<MemRefType>(Val: opOperand.get().getType());
2504 if (!mt)
2505 continue;
2506 if (llvm::is_contained(Range: op.getShape(opOperand: &opOperand), Element: 0)) {
2507 rewriter.eraseOp(op);
2508 return success();
2509 }
2510 }
2511 return failure();
2512 }
2513};
2514
/// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has a
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    // Only casts that refine the type (dynamic -> static) can be absorbed
    // into the producer.
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // Cast can be in a conditionally reachable region, in which case folding
    // will generate invalid code. Only conservatively fold ops in the same
    // block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    // Build the replacement at the producer's position so the new cast on the
    // init operand can keep propagating upwards.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    // Which result of the LinalgOp feeds the cast.
    OpResult resultValue = llvm::cast<OpResult>(Val: castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(Val: castOp->getResult(idx: 0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(i: resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(location: loc, args&: resultType, args: outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(in_start: outputOperands.begin(), in_end: outputOperands.end());

    // Clone the op with the refined result type for the folded result.
    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(b&: rewriter, op: linalgOp, newResultTypes: resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type so remaining
    // users of the old result keep seeing the old (more dynamic) type.
    Value castBack = rewriter.create<tensor::CastOp>(
        location: loc, args: resultValue.getType(), args: newOp->getResult(idx: resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(op: linalgOp, newValues: results);
    // The original cast now directly takes the new op's (static) result.
    rewriter.replaceOp(op: castOp, newValues: newOp->getResult(idx: resultNumber));
    return success();
  }
};
2573
/// For each operand in `operands`, map the affine dim expressions of its
/// indexing map to the known static sizes of the corresponding dimensions,
/// accumulating into `affineExprToSize`. First mapping for a dim expr wins
/// (try_emplace).
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    // Scalars have no shape to contribute.
    if (linalgOp.isScalar(opOperand: &opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(Val: src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(opOperand: &opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // a `tensor.cast` operation and the source of the cast operation has a
    // static shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(Val: parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(Val: castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size. NOTE(review): dims that are
    // dynamic in `sourceType` are skipped even when `sourceShape` came from a
    // static cast source — verify whether that is the intended behavior.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(idx: i))
        continue;
      // Only plain dim exprs (projected permutations) can be mapped directly.
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(Val: sourceMap.getResult(idx: i)))
        affineExprToSize.try_emplace(Key: affineDimExpr, Args: sourceShape[i]);
    }
  }
}
2610
/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are appended to `newOperands`
/// and, for init operands, their result types are appended to `resultTypes`.
/// If `opOperand` requires no change, the same operand is appended and
/// `changeNeeded` is left untouched (it is only ever set to true).
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Append first; may be overwritten below if a cast is needed.
  newOperands.push_back(Elt: src);
  // Scalars have no shape to refine.
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(Val: src.getType());
  Type resultType = sourceType;
  // Fully static init operands are already as refined as possible; still
  // record their type since they contribute a result.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(Elt: resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(idx: i);
    // Keep the dim as-is when it is already static or no static size is known
    // for its dim expression.
    if (!affineExprToSize.contains(Val: dimExpr) || !sourceType.isDynamicDim(idx: i)) {
      newShape.push_back(Elt: dimShape);
      continue;
    }
    // Dimension has a dynamic shape and the corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(Elt: affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(shape: newShape, elementType: sourceType.getElementType(),
                                     encoding: sourceType.getEncoding());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it; replace the placeholder appended above.
    Value newOperand = rewriter.create<tensor::CastOp>(location: loc, args&: resultType, args&: src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  // Init operands determine result types; record the (possibly refined) type.
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(Elt: resultType);
}
2663
/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand: any dimension whose dim expr also indexes a
/// static dimension of another operand must have that same static size.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // tensor.cast insertion below only makes sense on tensors.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations: each map result is a plain dim
    // expr, so dim sizes transfer 1:1 between operands.
    if (llvm::any_of(Range: linalgOp.getIndexingMapsArray(), P: [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each of the affine dim expressions, check if the size is known. If
    // known, add it to the map.
    populateMap(linalgOp, operands: linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(N: linalgOp->getNumOperands());
    resultTypes.reserve(N: linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, opOperand: &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op already has all the required static information, no
    // canonicalization is needed.
    if (!changeNeeded)
      return failure();

    // Clone op with the refined operand/result types.
    Operation *newOp = clone(b&: rewriter, op: linalgOp, newResultTypes: resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(N: newOp->getNumResults());
    for (auto it : llvm::zip(t: linalgOp->getResults(), u: newOp->getResults())) {
      Value newResult = std::get<1>(t&: it);
      Value oldResult = std::get<0>(t&: it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      // Cast refined results back to the old type so existing users still
      // type-check.
      replacements.push_back(
          Elt: (newType != oldType)
              ? rewriter.create<tensor::CastOp>(location: loc, args&: oldType, args&: newResult)
              : newResult);
    }
    rewriter.replaceOp(op: linalgOp, newValues: replacements);
    return success();
  }
};
2728
2729} // namespace
2730
2731// All named ops canonicalizers and folders are auto-generated in the
2732// .cpp.inc.
2733
2734//===----------------------------------------------------------------------===//
2735// SoftmaxOp
2736//===----------------------------------------------------------------------===//
2737
2738LogicalResult SoftmaxOp::verify() {
2739 ShapedType inputType = getInputOperandType();
2740 ShapedType outputType = getOutputOperandType();
2741
2742 ArrayRef<int64_t> inputShape = inputType.getShape();
2743 ArrayRef<int64_t> outputShape = outputType.getShape();
2744 if (failed(Result: verifyCompatibleShape(shape1: inputShape, shape2: outputShape)))
2745 return emitOpError(message: "incompatible output shape");
2746
2747 int64_t inputRank = getInputOperandRank();
2748 int64_t dimension = getDimension();
2749 if ((dimension < 0) || (dimension >= inputRank))
2750 return emitOpError(message: "incorrect dimension specified");
2751
2752 return success();
2753}
2754
2755SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2756 int64_t operandRank = getInputOperandRank();
2757 SmallVector<Range> loopBounds(operandRank);
2758 Location loc = getLoc();
2759 Value zero = builder.create<arith::ConstantIndexOp>(location: loc, args: 0);
2760 Value one = builder.create<arith::ConstantIndexOp>(location: loc, args: 1);
2761 Value source = getInput();
2762 for (auto dim : llvm::seq<int64_t>(Begin: 0, End: operandRank)) {
2763 loopBounds[dim].offset = zero;
2764 loopBounds[dim].size = getDimValue(builder, loc, v: source, dim);
2765 loopBounds[dim].stride = one;
2766 }
2767 return loopBounds;
2768}
2769
2770SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2771 SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2772 utils::IteratorType::parallel);
2773 iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2774 return iteratorTypes;
2775}
2776
/// Tile softmax by slicing both the input and the init with the given
/// offsets/sizes (unit strides) and cloning the op onto the slices. With
/// tensor semantics the cloned op gets the sliced init's type as result type.
FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(value: 1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  // Slice the input tile.
  Operation *inputSlice =
      getSlice(b&: builder, loc: getLoc(), source: getInput(), offsets, sizes, strides);
  if (!inputSlice) {
    return emitOpError(message: "failed to compute input slice");
  }
  tiledOperands.emplace_back(Args: inputSlice->getResult(idx: 0));
  // Slice the init/output tile with the same offsets/sizes.
  Operation *outputSlice =
      getSlice(b&: builder, loc: getLoc(), source: getOutput(), offsets, sizes, strides);
  if (!outputSlice) {
    return emitOpError(message: "failed to compute output slice");
  }
  tiledOperands.emplace_back(Args: outputSlice->getResult(idx: 0));

  // Memref semantics produce no results; tensor semantics yield the sliced
  // output type.
  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(Elt: tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(b&: builder, op: getOperation(), newResultTypes: resultTypes, newOperands: tiledOperands);

  return TilingResult{
      .tiledOps: {tiledOp},
      .tiledValues: SmallVector<Value>(tiledOp->getResults()),
      .generatedSlices: llvm::to_vector(Range: ArrayRef<Operation *>{inputSlice, outputSlice})};
}
2809
2810LogicalResult SoftmaxOp::getResultTilePosition(
2811 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2812 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2813 SmallVector<OpFoldResult> &resultSizes) {
2814 if (resultNumber == 0) {
2815 resultOffsets.assign(in_start: offsets.begin(), in_end: offsets.end());
2816 resultSizes.assign(in_start: sizes.begin(), in_end: sizes.end());
2817 return success();
2818 }
2819 return failure();
2820}
2821
// cast(dynamic) -> static: fold memref.cast operands that refine the type
// into the op itself.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(op: *this);
}
2826
/// Reify the result shape of softmax, one OpFoldResult per output dimension:
/// static dims become IntegerAttrs, dynamic dims become tensor.dim values on
/// the input (input and output shapes are verified compatible).
LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(Val: getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(Val: getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(Begin: 0, End: getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(idx: dim)) {
      // Static dim: Return IntegerAttr.
      shapes.push_back(Elt: b.getIndexAttr(value: inputShapedType.getDimSize(idx: dim)));
    } else {
      // Dynamic dim: Return Value.
      OpFoldResult ofr = createOrFoldDimOp(b, loc, source: getInput(), dim);
      shapes.push_back(Elt: getValueOrCreateConstantIndexOp(b, loc, ofr));
    }
  }
  reifiedReturnShapes.emplace_back(Args: std::move(shapes));
  return success();
}
2848
/// Report memory effects: memref inputs are read; memref inits are both read
/// and written. Tensor operands carry no memory effects.
void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto [index, operand] : llvm::enumerate(First: getDpsInputs())) {
    if (!llvm::isa<MemRefType>(Val: operand.getType()))
      continue;
    // Inputs are only read, over the full buffer.
    effects.emplace_back(Args: MemoryEffects::Read::get(),
                         Args: &getOperation()->getOpOperand(idx: index), /*stage=*/Args: 0,
                         /*effectOnFullRegion=*/Args: true,
                         Args: SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(Val: operand.get().getType()))
      continue;
    // Inits are read and written, over the full buffer.
    effects.emplace_back(Args: MemoryEffects::Read::get(), Args: &operand, /*stage=*/Args: 0,
                         /*effectOnFullRegion=*/Args: true,
                         Args: SideEffects::DefaultResource::get());
    effects.emplace_back(Args: MemoryEffects::Write::get(), Args: &operand, /*stage=*/Args: 0,
                         /*effectOnFullRegion=*/Args: true,
                         Args: SideEffects::DefaultResource::get());
  }
}
2872
2873// Helper functions for softmax decomposition.
2874// @{
2875
2876// Helper function to produce the iterator types (reduction or parallel) and
2877// affine maps for the iterators used in the decomposition of softmax.
2878// This method creates:
2879// If allParallel == true:
2880// - iterator type: {parallel, ..., parallel}
2881// - affine maps:
2882// -- identity with inputRank dimensions.
2883// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2884// where N == inputRank.
2885//
2886// If allParallel == false:
2887// - iterator type at dim(i) == parallel for i != \p dim and
2888// dim(dim) == reduction.
2889// - affine map:
2890// -- identity with inputRank dimensions.
2891// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2892// where N == inputRank.
2893static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2894computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2895 int64_t dim, bool allParallel = false) {
2896 SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2897 utils::IteratorType::parallel);
2898 if (!allParallel)
2899 iteratorTypes[dim] = utils::IteratorType::reduction;
2900 MLIRContext *ctxt = builder.getContext();
2901 auto identityMap = AffineMap::getMultiDimIdentityMap(numDims: inputRank, context: ctxt);
2902 SmallVector<AffineExpr, 2> affineExprs;
2903 for (int i = 0; i < inputRank; i++) {
2904 if (i != dim)
2905 affineExprs.push_back(Elt: mlir::getAffineDimExpr(position: i, context: ctxt));
2906 }
2907 auto reductionMap =
2908 AffineMap::get(dimCount: inputRank, /*symbols=*/symbolCount: 0, results: affineExprs, context: ctxt);
2909 SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2910 return std::make_tuple(args&: iteratorTypes, args&: indexingMaps);
2911}
2912
// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the elementwise combining operation type \p T
// (e.g. arith::MaxNumFOp, arith::AddFOp). \p output carries the initial
// (neutral) value and becomes the init of the generic.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(Val: input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  // Iterators: parallel everywhere but `dim`. Maps: identity for the input,
  // `dim` dropped for the output.
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        // Combine the current input element with the accumulator.
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(location: loc, args&: result);
      });
  return genericOp.getResult(0);
}
2935
/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where \p max is the max of \p input
/// on dimension \p dim. \p max is rank-reduced (no `dim` dimension) and is
/// broadcast via the reduction map.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(Val: input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  // allParallel: this step is purely elementwise (with broadcast of max).
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument (identity, like the input).
  indexingMaps.push_back(Elt: indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      location: loc, args: input.getType(), args: ValueRange{input, max}, args&: output, args&: indexingMaps,
      args&: iteratorTypes, args: [&](OpBuilder &b, Location loc, ValueRange args) {
        // exp(x - m) per element.
        Value diff = b.create<arith::SubFOp>(location: loc, args: args[0], args: args[1]);
        Value result = b.create<math::ExpOp>(location: loc, args&: diff);
        b.create<linalg::YieldOp>(location: loc, args&: result);
      });
  return genericOp.getResult(i: 0);
}
2959
/// Produce a linalg generic that computes the final step of the softmax
/// decomposition.
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
///   yield n / d
/// }
/// \p denominator is rank-reduced (no `dim` dimension) and is broadcast via
/// the reduction map.
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(Val: numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  // allParallel: this step is purely elementwise (with broadcast of the
  // denominator).
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor (identity, like the numerator).
  indexingMaps.push_back(Elt: indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      location: loc, args: numerator.getType(), args: ValueRange{numerator, denominator}, args&: output,
      args&: indexingMaps, args&: iteratorTypes,
      args: [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(location: loc, args: args[0], args: args[1]);
        b.create<linalg::YieldOp>(location: loc, args&: result);
      });
  return genericOp.getResult(i: 0);
}
2986// @} End helper functions for softmax decomposition.
2987
/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in a N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    a N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    a N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(builder&: b, loc, value: input);
  Value output = getOutput();
  // Drop the reduced dim: `dims` now describes the rank-reduced shape used by
  // both intermediate reductions (max and sum).
  dims.erase(CI: dims.begin() + reductionDim);
  // Step 1: Compute max along dim, starting from the maxnumf identity.
  Value outputReduce = b.create<tensor::EmptyOp>(location: loc, args&: dims, args&: elementType);
  Value neutralForMaxF = arith::getIdentityValue(op: arith::AtomicRMWKind::maxnumf,
                                                 resultType: elementType, builder&: b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(location: loc, args: Value{neutralForMaxF}, args&: outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(builder&: b, loc, input, output: neutralForMaxFInit, dim: reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(builder&: b, loc, input, max, output, dim: reductionDim);

  // Step 3: Compute sum along dim, reusing `outputReduce` filled with zero.
  Value zero = arith::getIdentityValue(op: arith::AtomicRMWKind::addf, resultType: elementType,
                                       builder&: b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(location: loc, args: Value{zero}, args&: outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(builder&: b, loc, input: numerator, output: zeroInit, dim: reductionDim);

  // Step 4: Compute softmax as the broadcast division z / l.
  Value result =
      buildDivOp(builder&: b, loc, numerator, denominator, output, dim: reductionDim);
  return SmallVector<Value>{result};
}
3044
3045//===----------------------------------------------------------------------===//
3046// WinogradFilterTransformOp
3047//===----------------------------------------------------------------------===//
3048
/// Verify the Winograd filter transform: for an F(m, r) configuration, each
/// of the filter's H/W extents must be either r or 1 (1 marks a degenerate
/// axis of a 1-D convolution), not both 1, and the output shape must be
/// (alphaH, alphaW, C, F) with alpha = m + r - 1 on non-degenerate axes.
LogicalResult WinogradFilterTransformOp::verify() {
  auto filterType = cast<ShapedType>(Val: getFilter().getType());
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr);

  if (filterH != r && filterH != 1)
    return emitOpError(message: "expect filter height either equals to r or 1");
  if (filterW != r && filterW != 1)
    return emitOpError(message: "expect filter width either equals to r or 1");
  if (filterH == 1 && filterW == 1)
    return emitOpError(message: "expect either filter height or width equals to r");

  // Expected output: alpha on transformed axes, 1 on degenerate ones, then
  // the filter's channel and filter-count dims.
  SmallVector<int64_t> expectedOutputShape;
  expectedOutputShape.push_back(Elt: filterH == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(Elt: filterW == r ? m + r - 1 : 1);
  expectedOutputShape.push_back(Elt: filterShape[getFilterCDim()]);
  expectedOutputShape.push_back(Elt: filterShape[getFilterFDim()]);

  auto outputType = cast<ShapedType>(Val: getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(Result: verifyCompatibleShape(shape1: expectedOutputShape, shape2: outputShape))) {
    return emitOpError(message: "the output shape is not expected");
  }
  return success();
}
3078
3079SmallVector<Range>
3080WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
3081 Location loc = getLoc();
3082 IntegerAttr zeroAttr = builder.getIndexAttr(value: 0);
3083 IntegerAttr oneAttr = builder.getIndexAttr(value: 1);
3084 Value filter = getFilter();
3085 int64_t filterRank = getFilterOperandRank();
3086 SmallVector<Range> loopBounds(filterRank);
3087 for (unsigned dim = 0; dim < filterRank; ++dim) {
3088 loopBounds[dim].offset = zeroAttr;
3089 loopBounds[dim].size = getDimValue(builder, loc, v: filter, dim);
3090 loopBounds[dim].stride = oneAttr;
3091 }
3092 return loopBounds;
3093}
3094
3095SmallVector<utils::IteratorType>
3096WinogradFilterTransformOp::getLoopIteratorTypes() {
3097 int64_t filterRank = getFilterOperandRank();
3098 SmallVector<utils::IteratorType> iteratorTypes(filterRank,
3099 utils::IteratorType::parallel);
3100 return iteratorTypes;
3101}
3102
/// Compute where a tile of the result lands: the transformed dims (alphaH,
/// alphaW) are always taken whole, while the C and F dims follow the
/// requested tile.
LogicalResult WinogradFilterTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // alpha = m + r - 1 is the transformed tile extent. A filter dim of size 1
  // means the transform is not applied along that dim, so its extent stays 1.
  int64_t alpha = m + r - 1;
  int64_t alphaH = filterH != 1 ? alpha : 1;
  int64_t alphaW = filterW != 1 ? alpha : 1;
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // Output layout is (alphaH, alphaW, C, F): transformed dims start at 0 and
  // cover their full extent; C and F take the requested tile.
  resultOffsets.append(
      {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
  resultSizes.append(
      {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});

  return success();
}
3128
/// Implement tiling for winograd_filter_transform
/// The input of winograd_filter_transform is (F, KH, KW, C).
/// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
/// Users can specify the tile sizes of F and C.
/// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
/// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType filterType = getFilterOperandType();
  ArrayRef<int64_t> filterShape = filterType.getShape();
  int64_t filterH = filterShape[getFilterHDim()];
  int64_t filterW = filterShape[getFilterWDim()];
  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Slice the filter input: tile F and C, take KH and KW whole.
  sliceOffsets.append(
      {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
                     sizes[getFilterCDim()]});
  int64_t filterRank = getFilterOperandRank();
  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
  Location loc = getLoc();
  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
  tiledOperands.emplace_back(filterSlice);

  // Compute the matching tile of the output operand. NOTE(review): the
  // result-number argument passed here is 1 although the op's result index is
  // 0; getResultTilePosition above ignores it — confirm intent.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the tiled op's result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
}
3181
3182//===----------------------------------------------------------------------===//
3183// WinogradInputTransformOp
3184//===----------------------------------------------------------------------===//
3185
/// Verify that the output shape of winograd_input_transform is consistent
/// with the input shape (N, H, W, C), the F(m, r) parameters, and whether the
/// transform is applied along H (leftTransform) and/or W (rightTransform).
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[getInputHDim()];
  int64_t inputW = inputShape[getInputWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  int64_t tileSize = m + r - 1;

  // An alpha dim of size 1 in the output means the transform is skipped along
  // the corresponding spatial dimension.
  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;

  // Expected output layout: (alphaH, alphaW, tileH, tileW, N, C).
  SmallVector<int64_t> expectedOutputShape(6, inputH);
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[getOutputAlphaHDim()] = tileSize;
    expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
    // Number of tiles along H: each tile advances by m, with r - 1 rows of
    // overlap consumed at the border.
    expectedOutputShape[getOutputTileHDim()] =
        leftTransform ? (inputH - (r - 1)) / m : inputH;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[getOutputAlphaWDim()] = tileSize;
    expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
    expectedOutputShape[getOutputTileWDim()] =
        rightTransform ? (inputW - (r - 1)) / m : inputW;
  }
  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];

  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3226
3227SmallVector<Range>
3228WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3229 Location loc = getLoc();
3230 IntegerAttr zeroAttr = builder.getIndexAttr(value: 0);
3231 IntegerAttr oneAttr = builder.getIndexAttr(value: 1);
3232 Value output = getOutput();
3233 int64_t outputRank = getOutputOperandRank();
3234 SmallVector<Range> loopBounds(outputRank);
3235 for (unsigned dim = 0; dim < outputRank; ++dim) {
3236 loopBounds[dim].offset = zeroAttr;
3237 // alphaH, alphaW, tileH, tileW, N, C
3238 loopBounds[dim].size = getDimValue(builder, loc, v: output, dim);
3239 loopBounds[dim].stride = oneAttr;
3240 }
3241 return loopBounds;
3242}
3243
3244SmallVector<utils::IteratorType>
3245WinogradInputTransformOp::getLoopIteratorTypes() {
3246 int64_t outputRank = getOutputOperandRank();
3247 SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3248 utils::IteratorType::parallel);
3249 return iteratorTypes;
3250}
3251
/// Compute where a tile of the result lands: the transformed dims (alphaH,
/// alphaW) are always taken whole, while tileH, tileW, N, C follow the
/// requested tile.
LogicalResult WinogradInputTransformOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];

  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // alpha = m + r - 1; a static alpha dim of 1 means the transform is skipped
  // along that dimension, so its extent stays 1.
  int64_t alpha = m + r - 1;
  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;

  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // Output layout is (alphaH, alphaW, tileH, tileW, N, C).
  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
                        offsets[getOutputTileWDim()], offsets[getOutputNDim()],
                        offsets[getOutputCDim()]});
  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
                      sizes[getOutputTileWDim()], sizes[getOutputNDim()],
                      sizes[getOutputCDim()]});

  return success();
}
3281
/// Implement tiling for winograd_input_transform
/// The input of winograd_input_transform is (N, H, W, C).
/// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
/// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
/// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
/// the values for the sizes of tileH, tileW, N, C for one tile.
FailureOr<TilingResult>
WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                 ArrayRef<OpFoldResult> offsets,
                                                 ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);

  ShapedType outputType = getOutputOperandType();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  int64_t alphaH = outputShape[getOutputAlphaHDim()];
  int64_t alphaW = outputShape[getOutputAlphaWDim()];

  Location loc = getLoc();
  MLIRContext *context = builder.getContext();
  // Map a tile index to an input coordinate: each tile advances by m input
  // elements along a transformed dim; an untransformed dim (alpha == 1) maps
  // through unchanged.
  auto identityAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
  auto offsetAffineMap =
      AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
  Value mappedOffsetH = affine::makeComposedAffineApply(
      builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileHDim()]);
  Value mappedOffsetW = affine::makeComposedAffineApply(
      builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
      offsets[getOutputTileWDim()]);
  // A run of `size` tiles reads size * m + (r - 1) input elements (overlap of
  // r - 1 between adjacent tiles).
  auto sizeAffineMap = AffineMap::get(
      1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
  Value mappedSizeH = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
  Value mappedSizeW = affine::makeComposedAffineApply(
      builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);

  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  // Slice the (N, H, W, C) input for this tile.
  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
  sliceOffsets.append(
      {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
  OpFoldResult sizeH =
      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
  OpFoldResult sizeW =
      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
  sliceSizes.append(
      {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
  int64_t inputRank = getInputOperandRank();
  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
  tiledOperands.emplace_back(inputSlice);

  // Compute the matching tile of the output operand. NOTE(review): the
  // result-number argument passed here is 1 although the op's result index is
  // 0; getResultTilePosition above ignores it — confirm intent.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, outputStrides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the tiled op's result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
}
3361
3362//===----------------------------------------------------------------------===//
3363// WinogradOutputTransformOp
3364//===----------------------------------------------------------------------===//
3365
/// Verify that the value operand (alphaH, alphaW, tileH, tileW, N, F) of
/// winograd_output_transform has the extents implied by F(m, r), and that the
/// output operand (N, H, W, F) is compatible with the reconstructed shape.
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[getValueAlphaHDim()];
  int64_t valueW = valueShape[getValueAlphaWDim()];
  int64_t valueTileH = valueShape[getValueTileHDim()];
  int64_t valueTileW = valueShape[getValueTileWDim()];
  WinogradConv2DFmr fmr = getFmr();
  int64_t m, r;
  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
  // An alpha dim of size 1 means the transform was skipped along that
  // spatial dimension.
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  int64_t outputRank = getOutputOperandRank();
  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
  } else {
    // Static case: alphaH must equal m + r - 1, and each tile reconstructs m
    // output rows.
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[getOutputWDim()] =
        (rightTransform ? m : 1) * valueTileW;
  }
  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
3406
3407SmallVector<Range>
3408WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3409 Location loc = getLoc();
3410 IntegerAttr zeroAttr = builder.getIndexAttr(value: 0);
3411 IntegerAttr oneAttr = builder.getIndexAttr(value: 1);
3412 Value value = getValue();
3413 int64_t valueRank = getValueOperandRank();
3414 SmallVector<Range> loopBounds(valueRank);
3415 for (unsigned dim = 0; dim < valueRank; ++dim) {
3416 loopBounds[dim].offset = zeroAttr;
3417 // alphaH, alphaW, tileH, tileW, N, F
3418 loopBounds[dim].size = getDimValue(builder, loc, v: value, dim);
3419 loopBounds[dim].stride = oneAttr;
3420 }
3421 return loopBounds;
3422}
3423
3424SmallVector<utils::IteratorType>
3425WinogradOutputTransformOp::getLoopIteratorTypes() {
3426 int64_t valueRank = getValueOperandRank();
3427 SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3428 utils::IteratorType::parallel);
3429 return iteratorTypes;
3430}
3431
3432LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3433 OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3434 ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3435 SmallVector<OpFoldResult> &resultSizes) {
3436 WinogradConv2DFmr fmr = getFmr();
3437 int64_t m, r;
3438 std::tie(args&: m, args&: r) = getFmrFromWinogradConv2DFmr(fmr);
3439
3440 Location loc = getLoc();
3441 MLIRContext *context = builder.getContext();
3442 auto identityAffineMap =
3443 AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0)}, context);
3444 auto affineMap =
3445 AffineMap::get(dimCount: 1, symbolCount: 0, results: {builder.getAffineDimExpr(position: 0) * m}, context);
3446
3447 ShapedType valueType = getValueOperandType();
3448 ArrayRef<int64_t> valueShape = valueType.getShape();
3449 int64_t valueH = valueShape[0];
3450 int64_t valueW = valueShape[1];
3451 Value mappedOffsetH = affine::makeComposedAffineApply(
3452 b&: builder, loc, map: (valueH != 1 ? affineMap : identityAffineMap),
3453 operands: offsets[getValueTileHDim()]);
3454 Value mappedOffsetW = affine::makeComposedAffineApply(
3455 b&: builder, loc, map: (valueW != 1 ? affineMap : identityAffineMap),
3456 operands: offsets[getValueTileWDim()]);
3457 Value mappedSizeH = affine::makeComposedAffineApply(
3458 b&: builder, loc, map: affineMap, operands: sizes[getValueTileHDim()]);
3459 Value mappedSizeW = affine::makeComposedAffineApply(
3460 b&: builder, loc, map: affineMap, operands: sizes[getValueTileWDim()]);
3461
3462 IntegerAttr oneAttr = builder.getI64IntegerAttr(value: 1);
3463 OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3464 OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3465 OpFoldResult sizeH =
3466 valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3467 OpFoldResult sizeW =
3468 valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3469
3470 resultOffsets.append(
3471 IL: {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3472 resultSizes.append(
3473 IL: {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3474 return success();
3475}
3476
/// Implement tiling for winograd_output_transform
/// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
/// F). The output of winograd_output_transform is (N, H, W, F) Users can
/// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
/// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
/// for the sizes of tileH, tileW, N, F for one tile.
FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes) {
  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
  Location loc = getLoc();
  SmallVector<Value> tiledOperands;
  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;

  ShapedType valueType = getValueOperandType();
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t alphaH = valueShape[getValueAlphaHDim()];
  int64_t alphaW = valueShape[getValueAlphaWDim()];
  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);

  // Slice the value input: take the transformed dims (alphaH, alphaW) whole,
  // tile tileH, tileW, N, F.
  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
                       offsets[getValueTileWDim()], offsets[getValueNDim()],
                       offsets[getValueFDim()]});
  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
                     sizes[getValueTileWDim()], sizes[getValueNDim()],
                     sizes[getValueFDim()]});
  int64_t valueRank = getValueOperandRank();
  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
  tiledOperands.emplace_back(valueSlice);

  // Compute the matching tile of the output operand. NOTE(review): the
  // result-number argument passed here is 1 although the op's result index is
  // 0; getResultTilePosition above ignores it — confirm intent.
  SmallVector<OpFoldResult> resultOffsets, resultSizes;
  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
                                   resultSizes)))
    return failure();

  int64_t outputRank = getOutputOperandRank();
  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
      loc, getOutput(), resultOffsets, resultSizes, strides);
  tiledOperands.emplace_back(outputSlice);

  // Clone the op onto the slices; the tiled op's result type is the output
  // slice's type.
  SmallVector<Type> resultTypes;
  resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{
      {tiledOp},
      SmallVector<Value>(tiledOp->getResults()),
      llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
}
3532
3533//===----------------------------------------------------------------------===//
3534// LinalgDialect
3535// TODO: Merge with the LinalgDialect block at the bottom
3536//===----------------------------------------------------------------------===//
3537
3538// Returns true if the result expression of `subMap` are a subset of `fullMap`.
3539static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3540 auto explicitRange = subMap.getResults();
3541 auto defaultRange = fullMap.getResults();
3542 DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3543 DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3544 llvm::set_union(S1&: explicitSet, S2: defaultSet);
3545 return explicitSet == defaultSet;
3546}
3547
3548/// Check if the user defined map is valid broadcast map. Here broadcast
3549/// indexing maps are defined in context of corresponding default indexing maps
3550/// for the given Op. This way the check becomes very simple i.e just check the
3551/// number of result dims.
3552/// Returns true if the explictMap is broadcasted with respect to the
3553/// defaultMap.
3554static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap) {
3555 return explictMap.getNumResults() < defaultMap.getNumResults();
3556}
3557
3558/// Verifies the broadcast and transpose semantic sepecified by the explicit
3559/// indexing map for the MatmulOp \p op for each operand specified by \p
3560/// opIndex.
3561static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3562 unsigned opIndex) {
3563 SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3564 SmallVector<AffineMap, 3> defaultIndexingMaps =
3565 matmulOp.getDefaultIndexingMaps(context: matmulOp->getContext());
3566
3567 auto opIndexingMap = opIndexingMaps[opIndex];
3568 auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3569 // Check general validity of indexing map results.
3570 if (!areResultExprsSubsetOf(subMap: opIndexingMap, fullMap: defaultIndexingMap))
3571 return matmulOp->emitOpError()
3572 << "Unexpected dim expression in map result.";
3573
3574 if (isBroadcasted(explictMap: opIndexingMap, defaultMap: defaultIndexingMap)) {
3575 if (!matmulOp.isValidLhsRhsBroadcastMap(bcastMap: opIndexingMap)) {
3576 return matmulOp->emitOpError()
3577 << "Invalid broadcast requested, should be (d2).";
3578 }
3579 return success();
3580 }
3581 return success();
3582}
3583
// Check general validity of input indexing map of
// BatchMatmulOp/BatchReduceMatmulOp.
template <typename OpTy>
static LogicalResult verifyInputMaps(OpTy batchVariantMatmulOp,
                                     AffineMap opIndexingMap,
                                     AffineMap defaultIndexingMap, bool isLHS) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  // Check the result dims are valid: every result expression must come from
  // the corresponding default map.
  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
    return batchVariantMatmulOp->emitOpError()
           << "Unexpected result dim expression (outside the set of default "
              "result dims).";

  // Check for valid number of result dims of input maps.
  if (opIndexingMap.getNumResults() > 3)
    return batchVariantMatmulOp->emitOpError()
           << "no. of result dim expressions exceeds 3.";

  // A non-broadcast input map must expose the batch dimension (d0) as its
  // first result.
  auto hasValidBatchDim = [](AffineMap map) {
    AffineExpr batchDim = map.getResult(0);
    return batchDim.isFunctionOfDim(0);
  };

  // Check if the requested broadcast is valid: fewer results than the default
  // map means some dims are broadcast.
  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
    if (!batchVariantMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
      return batchVariantMatmulOp->emitOpError()
             << "Invalid broadcast requested.";
  } else if (!hasValidBatchDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid batch dimension expression.";
  }
  return success();
}
3620
/// This function checks if the given AffineMap for the output of a
/// BatchMatmulOp/BatchReduceMatmulOp has exactly the desired number of result
/// dimensions and if the output map result dimensions are valid.
template <typename OpTy>
static LogicalResult verifyOutputMap(OpTy batchVariantMatmulOp,
                                     AffineMap opIndexingMap) {
  assert((isa<BatchMatmulOp>(batchVariantMatmulOp) ||
          isa<BatchReduceMatmulOp>(batchVariantMatmulOp)) &&
         "Expected BatchMatmulOp or BatchReduceMatmulOp");
  // BatchMatmulOp keeps the batch dim in its output: (batch, M, N) -> 3 dims.
  if (isa<BatchMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 3) {

    return batchVariantMatmulOp->emitOpError()
           << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }
  // BatchReduceMatmulOp reduces over the batch dim: (M, N) -> 2 dims.
  if (isa<BatchReduceMatmulOp>(batchVariantMatmulOp) &&
      opIndexingMap.getNumResults() != 2) {
    return batchVariantMatmulOp->emitOpError()
           << "expects 2 dims, but got (" << opIndexingMap.getNumResults()
           << ").";
  }

  // Each output result must map to its expected loop dimension: (d0, d1, d2)
  // for batch matmul, (d1, d2) for batch-reduce matmul (d0 is reduced away).
  auto areValidOutputResultDim = [&](AffineMap outputMap) {
    return isa<BatchMatmulOp>(batchVariantMatmulOp)
               ? outputMap.getResult(0).isFunctionOfDim(0) &&
                     outputMap.getResult(1).isFunctionOfDim(1) &&
                     outputMap.getResult(2).isFunctionOfDim(2)
               : outputMap.getResult(0).isFunctionOfDim(1) &&
                     outputMap.getResult(1).isFunctionOfDim(2);
  };

  if (!areValidOutputResultDim(opIndexingMap)) {
    return batchVariantMatmulOp->emitOpError()
           << "Invalid output map result dimension.";
  }

  return success();
}
3660
/// Verifies the broadcast and transpose semantic specified by the explicit
/// indexing map for the BatchMatmulOp/BatchReduceMatmulOp op for each operand
/// specified by opIndex (0 = LHS, 1 = RHS, 2 = output).
template <typename OpTy>
static LogicalResult
verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
                                         unsigned opIndex) {
  SmallVector<AffineMap, 3> opIndexingMaps =
      batchVariantMatmulOp.getIndexingMapsArray();
  SmallVector<AffineMap, 3> defaultIndexingMaps =
      batchVariantMatmulOp.getDefaultIndexingMaps(
          batchVariantMatmulOp->getContext());

  // One map per operand: LHS, RHS, output.
  if (opIndexingMaps.size() != 3)
    return batchVariantMatmulOp->emitOpError()
           << "Indexing_map attribute must have 3 affine maps.";

  auto opIndexingMap = opIndexingMaps[opIndex];
  auto defaultIndexingMap = defaultIndexingMaps[opIndex];

  // The output map (index 2) and the input maps have different validity
  // rules; dispatch accordingly.
  if (opIndex == 2 &&
      failed(verifyOutputMap(batchVariantMatmulOp, opIndexingMap)))
    return failure();

  if (opIndex != 2 &&
      failed(verifyInputMaps(batchVariantMatmulOp, opIndexingMap,
                             defaultIndexingMap, opIndex == 0)))
    return failure();

  return success();
}
3692
3693namespace mlir {
3694namespace linalg {
3695
3696std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3697 if (m == 2 && r == 3)
3698 return WinogradConv2DFmr::F_2_3;
3699 if (m == 4 && r == 3)
3700 return WinogradConv2DFmr::F_4_3;
3701 if (m == 2 && r == 5)
3702 return WinogradConv2DFmr::F_2_5;
3703 return std::nullopt;
3704}
3705
3706std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
3707 switch (fmr) {
3708 case WinogradConv2DFmr::F_2_3:
3709 return {2, 3};
3710 case WinogradConv2DFmr::F_4_3:
3711 return {4, 3};
3712 case WinogradConv2DFmr::F_2_5:
3713 return {2, 5};
3714 }
3715}
3716
3717//===----------------------------------------------------------------------===//
3718// MatMulOp
3719//===----------------------------------------------------------------------===//
3720
3721/// Returns a list of AffineMap with the typical matmul indexing charactristic.
3722SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3723 AffineExpr d0, d1, d2;
3724 SmallVector<AffineMap> indexingMaps;
3725 bindDims(ctx: context, exprs&: d0, exprs&: d1, exprs&: d2);
3726 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 3, symbolCount: 0, results: {d0, d2}, context));
3727 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 3, symbolCount: 0, results: {d2, d1}, context));
3728 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 3, symbolCount: 0, results: {d0, d1}, context));
3729 return indexingMaps;
3730}
3731
3732SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3733 return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3734 utils::IteratorType::parallel,
3735 utils::IteratorType::reduction};
3736}
3737
// Number of region block arguments: two inputs plus one output accumulator.
unsigned MatmulOp::getNumRegionArgs() { return 3; }
3739
// Library call name is derived from the operation itself.
std::string MatmulOp::getLibraryCallName() {
  return generateLibraryCallName(getOperation());
}
3743
// Indexing maps may be supplied by the user (see parseIndexingMapsAttr).
bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3745
3746/// Check if the op has broadcast and/or transpose semantic. Returns true if
3747/// the user defined indexing maps are not equal to default map.
3748bool MatmulOp::hasUserDefinedMaps() {
3749 SmallVector<AffineMap, 3> defaultMaps =
3750 getDefaultIndexingMaps(context: this->getContext());
3751 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3752 return defaultMaps != explicitMaps;
3753}
3754
/// Implements the block region builder for the MatmulOp. This is called by
/// 'fillStructuredOpRegion'.
///
/// The region computes out += cast(a) * cast(b); the cast kind defaults to
/// TypeFn::cast_signed and can be overridden via a "cast" attribute.
void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
                             ArrayRef<NamedAttribute> attrs,
                             function_ref<InFlightDiagnostic()> emitError) {
  // Report (rather than assert on) the arity error when a diagnostic callback
  // is available.
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "MatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "MatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // Default cast kind, overridden by an explicit "cast" TypeFnAttr when
  // present in `attrs`.
  TypeFn castVal = TypeFn::cast_signed;
  const auto *castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
    return attr.getName() == "cast";
  });
  if (castIter != attrs.end()) {
    if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
      castVal = attr.getValue();
  }

  // Cast both inputs to the output (accumulator) element type.
  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(0));
  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
                                    block.getArgument(1));
  // Multiply then accumulate; each step may fail and report via emitError.
  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2, emitError);
  if (!value3)
    return;
  Value value4 = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
                                      value3, emitError);
  if (!value4)
    return;
  yields.push_back(value4);
  helper.yieldOutputs(yields);
}
3793
3794/// Returns true if the given bcastMap map is a valid broadcast map. A valid
3795/// broadcast map must include K dimension.
3796/// TODO: Strict inclusion of K dimension in the broadcast map is not
3797/// necessary for both input matrices simultaneously. We can relax this
3798/// condition to have K dimension for one input matrix map and infer the K
3799/// dimension for other input matrix map from the one already having K
3800/// dimension.
3801bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3802 assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3803 AffineExpr expr = bcastMap.getResult(idx: 0);
3804 // Invalid map if the common dimension of matmul not found.
3805 return expr.isFunctionOfDim(position: bcastMap.getNumDims() - 1);
3806}
3807
3808FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3809 if (parser.parseOptionalKeyword(keyword: "indexing_maps"))
3810 return ArrayAttr{
3811 nullptr}; // Success in case indexing_maps was not provided.
3812
3813 ArrayAttr arrayAttr;
3814 if (parser.parseEqual() || parser.parseAttribute(result&: arrayAttr))
3815 return failure();
3816
3817 if (llvm::any_of(Range&: arrayAttr,
3818 P: [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3819 return parser.emitError(loc: parser.getCurrentLocation())
3820 << "element of indexing_maps array is not an affine_map";
3821
3822 return arrayAttr;
3823}
3824
3825ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3826 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3827 if (failed(Result: indexingMapsAttr))
3828 return failure();
3829
3830 if (*indexingMapsAttr == nullptr) {
3831 auto indexingMapAttrs = llvm::map_to_vector(
3832 C: MatmulOp::getDefaultIndexingMaps(context: parser.getContext()),
3833 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
3834 indexingMapsAttr = parser.getBuilder().getArrayAttr(value: indexingMapAttrs);
3835 }
3836
3837 result.addAttribute(name: "indexing_maps", attr: *indexingMapsAttr);
3838 return parseNamedStructuredOp(parser, result, numRegionArgs: MatmulOp::getNumRegionArgs(),
3839 regionBuilder: MatmulOp::getRegionBuilder());
3840}
3841
3842void MatmulOp::print(OpAsmPrinter &p) {
3843 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
3844 C: MatmulOp::getDefaultIndexingMaps(context: getContext()),
3845 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
3846 if (!llvm::equal(LRange: getIndexingMaps(), RRange&: indexingMaps))
3847 p << " indexing_maps = " << llvm::interleaved_array(R: getIndexingMaps());
3848
3849 std::array<StringRef, 3> elidedAttrs = {
3850 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3851 printNamedStructuredOp(p, op: getOperation(), inputs: getInputs(), outputs: getOutputs(),
3852 elidedAttrs);
3853}
3854
3855/// Verify the user defined indexing maps.
3856LogicalResult MatmulOp::verify() {
3857 // Verification of pure matmul is handled by verifyStructuredOpInterface().
3858 if (!hasUserDefinedMaps())
3859 return success();
3860
3861 for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3862 if (failed(Result: verifyExtendedMatmulSemantic(matmulOp: *this, opIndex)))
3863 return failure();
3864 }
3865 return success();
3866}
3867
3868LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3869 return memref::foldMemRefCast(op: *this);
3870}
3871
3872void MatmulOp::getEffects(
3873 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3874 &effects) {
3875 if (hasPureTensorSemantics())
3876 return;
3877 getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
3878}
3879
3880Speculation::Speculatability MatmulOp::getSpeculatability() {
3881 return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
3882}
3883
3884//===----------------------------------------------------------------------===//
3885// ContractOp
3886//===----------------------------------------------------------------------===//
3887
3888SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3889 AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3890 // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3891 // domains are all the same, and each implements a projected permutation.
3892 // Each iteration space dim must occur for at least one operand and either
3893 // takes part in a contraction/reduction or else has parallel iteration type.
3894 // We have that a dim is a contraction/reduction dim if and only if the dim
3895 // occurs for the output operand. We use this fact for fast inference:
3896 // NB: In case we allow dims to occur solely for one input, the above still
3897 // holds: per the einsum semantics, these are reduction dims as well.
3898 SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3899 for (auto result : outAffineMap.getResults()) {
3900 auto dimExpr = dyn_cast<AffineDimExpr>(Val&: result);
3901 assert(dimExpr && "affine_map is a projected permutation");
3902 dimsInOutput[dimExpr.getPosition()] = true;
3903 }
3904
3905 SmallVector<utils::IteratorType> iteratorTypes;
3906 for (auto dimOccursInOutput : dimsInOutput)
3907 iteratorTypes.push_back(Elt: dimOccursInOutput ? utils::IteratorType::parallel
3908 : utils::IteratorType::reduction);
3909
3910 return iteratorTypes;
3911}
3912
3913unsigned ContractOp::getNumRegionArgs() { return 3; }
3914
3915/// Implement block region builder, which is called by 'fillStructuredOpRegion'.
3916void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3917 ArrayRef<NamedAttribute> attrs,
3918 function_ref<InFlightDiagnostic()> emitError) {
3919 if (emitError && block.getNumArguments() != 3) {
3920 emitError() << "ContractOp regionBuilder expects 3 args, got "
3921 << block.getNumArguments();
3922 return;
3923 }
3924 assert(block.getNumArguments() == 3 &&
3925 "ContractOp regionBuilder expects 3 args");
3926 RegionBuilderHelper helper(b, block);
3927
3928 TypeFn castSignedness = TypeFn::cast_signed;
3929 auto castIter = llvm::find_if(Range&: attrs, P: [&](const NamedAttribute &attr) {
3930 return attr.getName() == "cast";
3931 });
3932 if (castIter != attrs.end()) {
3933 if (auto attr = llvm::dyn_cast<TypeFnAttr>(Val: castIter->getValue()))
3934 castSignedness = attr.getValue();
3935 }
3936
3937 // TODO: Support fields with operators besides mult & add.
3938 Type outType = block.getArgument(i: 2).getType();
3939 Value lhsAtOutType =
3940 helper.buildTypeFn(typeFn: castSignedness, toType: outType, operand: block.getArgument(i: 0));
3941 Value rhsAtOutType =
3942 helper.buildTypeFn(typeFn: castSignedness, toType: outType, operand: block.getArgument(i: 1));
3943 Value productAtOutType = helper.buildBinaryFn(binaryFn: BinaryFn::mul, arg0: lhsAtOutType,
3944 arg1: rhsAtOutType, emitError);
3945 if (!productAtOutType)
3946 return;
3947 Value result = helper.buildBinaryFn(binaryFn: BinaryFn::add, arg0: block.getArgument(i: 2),
3948 arg1: productAtOutType, emitError);
3949 if (!result)
3950 return;
3951 helper.yieldOutputs(values: {result});
3952}
3953
3954ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3955 FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3956 if (failed(Result: indexingMapsAttr) || *indexingMapsAttr == nullptr)
3957 return parser.emitError(loc: parser.getCurrentLocation(),
3958 message: "expected 'indexing_maps' attribute");
3959 result.addAttribute(name: "indexing_maps", attr: *indexingMapsAttr);
3960
3961 return parseNamedStructuredOp(parser, result, numRegionArgs: getNumRegionArgs(),
3962 regionBuilder);
3963}
3964
3965void ContractOp::print(OpAsmPrinter &p) {
3966 p << " indexing_maps = " << llvm::interleaved_array(R: getIndexingMaps());
3967 printNamedStructuredOp(
3968 p, op: getOperation(), inputs: getInputs(), outputs: getOutputs(),
3969 /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3970}
3971
/// Verifies that the indexing maps describe a valid einsum-like contraction:
/// each map is a projected permutation over a shared iteration space whose
/// ranks match the operands, at least one dimension is contracting, and no
/// dimension occurs for a single input only (input-only reduction dims are
/// deliberately rejected; see NB below).
LogicalResult ContractOp::verify() {
  int iterationSpaceDims = -1;
  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
  // access an input operand (so occurrence count can be at most 2) and
  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
  SmallVector<size_t> inOccurrences;
  SmallVector<size_t> outOccurrences;

  // A helper so that for each operand's affine_map and type we check that ...
  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
                                   bool isInput) -> LogicalResult {
    // ... the affine_map is a projected permutation;
    if (!affineMap.isProjectedPermutation())
      return emitError(message: "provided affine_map is not a projected permutation");

    // ... the rank of the affine_map's results and corresponding type match;
    if (auto shapedType = dyn_cast<ShapedType>(Val&: operandType)) {
      if (affineMap.getNumResults() != shapedType.getRank())
        return emitError(message: "ranks of shaped operand and results of corresponding "
                         "affine_map differ");
    } else if (affineMap.getNumResults() != 0) {
      return emitError(message: "affine_map specifies shaped access while operand has "
                       "non-shaped type");
    }

    // ... the rank of the affine_map's domain is the same as those seen prior;
    // the first map encountered fixes the iteration-space rank and sizes the
    // occurrence counters.
    if (iterationSpaceDims == -1) {
      iterationSpaceDims = affineMap.getNumDims();
      inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
      outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
    } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
      return emitError(message: "iteration spaces of provided affine_maps differ");
    }

    // ... update counts of dims used to access either an input or the output.
    for (AffineExpr affineExpr : affineMap.getResults()) {
      auto affineDimExpr = dyn_cast<AffineDimExpr>(Val&: affineExpr);
      if (!affineDimExpr)
        llvm_unreachable("affine_map is a projected permutation");

      if (isInput)
        inOccurrences[affineDimExpr.getPosition()] += 1;
      else
        outOccurrences[affineDimExpr.getPosition()] += 1;
    }

    return success();
  };

  // Operand order is (lhs, rhs, out), hence the {true, true, false} mask.
  for (auto &&[affineMap, operandType, isInput] :
       llvm::zip(t: getIndexingMapsArray(), u: getOperandTypes(),
                 args: SmallVector<bool>{true, true, false})) {
    if (failed(Result: checkAffineMapAndType(affineMap, operandType, isInput)))
      return failure(); // NB: checkAffineMapAndType will emit relevant error.
  }

  bool hasContractingDim = false;
  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
    size_t inOccCount = inOccurrences[dimIndex];
    size_t outOccCount = outOccurrences[dimIndex];

    // We have a contracting dim if and only if ...
    hasContractingDim |= inOccCount == 2 && outOccCount == 0;

    if (inOccCount == 0 && outOccCount == 0)
      return emitError() << "iteration space dim at index " << dimIndex
                         << " not used to access any operand";

    // NB: We disallow a dim which occurs for only one input operand and not
    //     for the output. In terms of einsum semantics such dims have a
    //     sensible meaning - namely an additional reduction per each such dim.
    //     By contrast, the ContractionOpInterface does not know about this
    //     iter type - cf. inferContractionDims' supported dim kinds. Similarly,
    //     while vector.contract's verifier accepts dims of this kind many of
    //     its lowerings give up on encountering these dims.
    // TODO: Remove following once we have comprehensive support for input-only
    //       reduction dims, at both the linalg- and vector-dialect levels.
    if (inOccCount == 1 && outOccCount != 1)
      return emitError()
             << "iteration space dim at index " << dimIndex
             << " is neither a contracting dim nor of parallel iteration type";
  }

  if (!hasContractingDim)
    return emitError(message: "'indexing_maps' do not specify a contracting dimension");

  return success();
}
4061
4062LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
4063 return memref::foldMemRefCast(op: *this);
4064}
4065
4066void ContractOp::getEffects(
4067 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4068 &effects) {
4069 if (hasPureTensorSemantics())
4070 return;
4071 getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
4072}
4073
4074Speculation::Speculatability ContractOp::getSpeculatability() {
4075 return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
4076}
4077
4078//===----------------------------------------------------------------------===//
4079// Implementation of BatchMatmulOp
4080//===----------------------------------------------------------------------===//
4081SmallVector<AffineMap>
4082BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
4083 AffineExpr d0, d1, d2, d3;
4084 SmallVector<AffineMap> indexingMaps;
4085 bindDims(ctx: context, exprs&: d0, exprs&: d1, exprs&: d2, exprs&: d3);
4086 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d0, d1, d3}, context));
4087 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d0, d3, d2}, context));
4088 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d0, d1, d2}, context));
4089 return indexingMaps;
4090}
4091
4092SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
4093 return SmallVector<utils::IteratorType>{
4094 utils::IteratorType::parallel, utils::IteratorType::parallel,
4095 utils::IteratorType::parallel, utils::IteratorType::reduction};
4096}
4097
4098unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
4099
4100std::string BatchMatmulOp::getLibraryCallName() {
4101 return generateLibraryCallName(op: getOperation());
4102}
4103
4104/// Check if the op has broadcast and/or transpose semantic. Returns true if
4105/// the user defined indexing maps are not equal to default map.
4106bool BatchMatmulOp::hasUserDefinedMaps() {
4107 SmallVector<AffineMap, 3> defaultMaps =
4108 getDefaultIndexingMaps(context: this->getContext());
4109 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
4110 return defaultMaps != explicitMaps;
4111}
4112
4113/// Returns true if the given bcastMap map is a valid broadcast map. A valid
4114/// broadcast map must include K dimension.
4115/// TODO: Strict inclusion of K dimension in the broadcast map is not
4116/// necessary for both input matrices simultaneously. We can relax this
4117/// condition to have K dimension for one input matrix map and infer the K
4118/// dimension for other input matrix map from the one already having K
4119/// dimension.
4120bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
4121 assert(bcastMap.getNumResults() < 3 &&
4122 "Expected less than 3 result dim expr.");
4123 bool isValid = false;
4124 enum Indices { batchPos, mPos, nPos, kPos };
4125 if (bcastMap.getNumResults() == 1) {
4126 AffineExpr expr = bcastMap.getResult(idx: 0);
4127 isValid = expr.isFunctionOfDim(position: kPos);
4128 } else if (bcastMap.getNumResults() == 2) {
4129 AffineExpr expr0 = bcastMap.getResult(idx: 0);
4130 AffineExpr expr1 = bcastMap.getResult(idx: 1);
4131 isValid =
4132 isLHS ? ((expr0.isFunctionOfDim(position: batchPos) ||
4133 expr0.isFunctionOfDim(position: mPos)) &&
4134 expr1.isFunctionOfDim(position: kPos))
4135 : ((expr0.isFunctionOfDim(position: batchPos) &&
4136 expr1.isFunctionOfDim(position: kPos)) ||
4137 (expr0.isFunctionOfDim(position: kPos) && expr1.isFunctionOfDim(position: nPos)));
4138 }
4139 return isValid;
4140}
4141
4142void BatchMatmulOp::regionBuilder(
4143 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4144 function_ref<InFlightDiagnostic()> emitError) {
4145 if (emitError && block.getNumArguments() != 3) {
4146 emitError() << "BatchMatmulOp regionBuilder expects 3 args, got "
4147 << block.getNumArguments();
4148 return;
4149 }
4150 assert(block.getNumArguments() == 3 &&
4151 "BatchMatmulOp regionBuilder expects 3 args");
4152 RegionBuilderHelper helper(b, block);
4153 SmallVector<Value> yields;
4154
4155 TypeFn castVal = TypeFn::cast_signed;
4156 auto castIter = llvm::find_if(Range&: attrs, P: [&](const NamedAttribute &attr) {
4157 return attr.getName() == "cast";
4158 });
4159 if (castIter != attrs.end()) {
4160 if (auto attr = llvm::dyn_cast<TypeFnAttr>(Val: castIter->getValue()))
4161 castVal = attr.getValue();
4162 }
4163
4164 auto toType = block.getArgument(i: 2).getType();
4165 Value castValA = helper.buildTypeFn(typeFn: castVal, toType, operand: block.getArgument(i: 0));
4166 Value castValB = helper.buildTypeFn(typeFn: castVal, toType, operand: block.getArgument(i: 1));
4167 Value mulVal = helper.buildBinaryFn(binaryFn: BinaryFn::mul, arg0: castValA, arg1: castValB);
4168 Value addVal =
4169 helper.buildBinaryFn(binaryFn: BinaryFn::add, arg0: block.getArgument(i: 2), arg1: mulVal);
4170 yields.push_back(Elt: addVal);
4171 helper.yieldOutputs(values: yields);
4172}
4173
4174ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
4175 SmallVector<Attribute, 3> indexingMapsAttr;
4176 Attribute mapAttr;
4177 if (succeeded(Result: parser.parseOptionalKeyword(keyword: "indexing_maps"))) {
4178 if (parser.parseEqual())
4179 return failure();
4180
4181 if (parser.parseLSquare())
4182 return failure();
4183
4184 do {
4185 if (parser.parseAttribute(result&: mapAttr))
4186 return failure();
4187 if (!isa<AffineMapAttr>(Val: mapAttr)) {
4188 return parser.emitError(loc: parser.getCurrentLocation(),
4189 message: "expected affine map attribute");
4190 }
4191 indexingMapsAttr.push_back(Elt: mapAttr);
4192
4193 if (parser.parseOptionalComma())
4194 break;
4195 } while (true);
4196
4197 if (parser.parseRSquare())
4198 return failure();
4199 }
4200 // Initialize indexingMaps, if not supplied explicitly.
4201 if (indexingMapsAttr.empty()) {
4202 indexingMapsAttr = llvm::map_to_vector(
4203 C: BatchMatmulOp::getDefaultIndexingMaps(context: parser.getContext()),
4204 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
4205 }
4206 result.addAttribute(name: "indexing_maps",
4207 attr: parser.getBuilder().getArrayAttr(value: indexingMapsAttr));
4208
4209 return ::parseNamedStructuredOp(parser, result,
4210 numRegionArgs: BatchMatmulOp::getNumRegionArgs(),
4211 regionBuilder: BatchMatmulOp::getRegionBuilder());
4212}
4213
4214void BatchMatmulOp::print(OpAsmPrinter &p) {
4215 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4216 C: BatchMatmulOp::getDefaultIndexingMaps(context: getContext()),
4217 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
4218 if (!llvm::equal(LRange: getIndexingMaps(), RRange&: indexingMaps))
4219 p << " indexing_maps = " << llvm::interleaved_array(R: getIndexingMaps());
4220
4221 std::array<StringRef, 3> elidedAttrs = {
4222 "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4223 ::printNamedStructuredOp(p, op: getOperation(), inputs: getInputs(), outputs: getOutputs(),
4224 elidedAttrs);
4225}
4226
4227/// Verify the user defined indexing maps.
4228LogicalResult BatchMatmulOp::verify() {
4229 // Verification of pure batch_matmul is handled by
4230 // verifyStructuredOpInterface().
4231 if (!hasUserDefinedMaps())
4232 return success();
4233
4234 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4235 if (failed(Result: verifyExtendedBatchVariantMatmulSemantic(batchVariantMatmulOp: *this, opIndex)))
4236 return failure();
4237 }
4238 return success();
4239}
4240
4241LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4242 SmallVectorImpl<OpFoldResult> &) {
4243 return memref::foldMemRefCast(op: *this);
4244}
4245
4246void BatchMatmulOp::getEffects(
4247 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4248 &effects) {
4249 if (hasPureTensorSemantics())
4250 return;
4251 getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
4252}
4253
4254Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4255 return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
4256}
4257
4258//===----------------------------------------------------------------------===//
4259// ElementwiseOp
4260//===----------------------------------------------------------------------===//
4261//
namespace {
/// Pairs an elementwise operation's arity group with the specific function
/// kind inside that group (decoded by getArityGroupAndKind below).
struct ArityGroupAndKind {
  // The enum class {Unary, Binary, Ternary, ..}
  ElementwiseArityGroup arityGroup;

  // The kind (e.g. `exp` or `add`) belonging to the arity group.
  // Only the member matching `arityGroup` is meaningful.
  union Kind {
    UnaryFn unaryFn;
    BinaryFn binaryFn;
    TernaryFn ternaryFn;
  } kind;
};

/// Returns the numeric value of the arity group. Callers in this file treat
/// this as the op's input count (region args = this value + 1 for the output).
unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
  return static_cast<unsigned>(arityGroup);
}
} // namespace
4279
4280static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4281 constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4282 constexpr int lastBinary =
4283 static_cast<int>(ElementwiseCaseLimits::LastBinary);
4284 constexpr int lastTernary =
4285 static_cast<int>(ElementwiseCaseLimits::LastTernary);
4286
4287 int val = static_cast<int>(kind);
4288 ArityGroupAndKind result;
4289
4290 if (val < lastUnary) {
4291 result.arityGroup = ElementwiseArityGroup::Unary;
4292 result.kind.unaryFn = static_cast<UnaryFn>(val);
4293 return result;
4294 }
4295 if (val < lastBinary) {
4296 result.arityGroup = ElementwiseArityGroup::Binary;
4297 result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4298 return result;
4299 }
4300 if (val >= lastTernary) {
4301 llvm_unreachable("unhandled ElementwiseFn");
4302 }
4303 result.arityGroup = ElementwiseArityGroup::Ternary;
4304 result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4305 return result;
4306}
4307
4308SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4309 auto rank = getResultRank();
4310 return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4311}
4312
4313SmallVector<AffineMap>
4314ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4315 MLIRContext *context) {
4316 auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4317 return SmallVector<AffineMap>(numMaps, map);
4318}
4319
/// Parses an ElementwiseOp: a mandatory `kind = #linalg.elemwise_kind<...>`
/// clause, an optional `indexing_maps = [...]` clause, then the common
/// named-structured-op syntax. When indexing maps are omitted, identity maps
/// are synthesized from the (already parsed) output rank.
ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
  Attribute attr;
  mlir::linalg::ElementwiseKind elemwiseKindVal;
  if (parser.parseKeyword(keyword: "kind") || parser.parseEqual())
    return failure();

  if (succeeded(Result: parser.parseAttribute(result&: attr))) {
    auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(Val&: attr);
    if (!elemwiseKindAttr)
      return parser.emitError(loc: parser.getCurrentLocation(),
                              message: "expected ElementwiseKind attribute");
    elemwiseKindVal = elemwiseKindAttr.getValue();
  } else {
    return parser.emitError(loc: parser.getCurrentLocation(),
                            message: "expected operation 'kind' attribute");
  }
  result.addAttribute(
      name: "kind", attr: ElementwiseKindAttr::get(context: parser.getContext(), value: elemwiseKindVal));

  // Parse optional `indexing_maps`
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();
    do {
      if (parser.parseAttribute(result&: mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(Val: mapAttr))
        return parser.emitError(loc: parser.getCurrentLocation(),
                                message: "expected affine map attribute");
      indexingMapsAttr.push_back(Elt: mapAttr);
      if (parser.parseOptionalComma())
        break;
    } while (true);
    if (parser.parseRSquare())
      return failure();
  }
  // At this stage of parsing the only way to infer number of region
  // args is through op kind, as input output tensors are not parsed yet.
  auto arityGroupAndKind = getArityGroupAndKind(kind: elemwiseKindVal);
  int numRegionArgs =
      getArityGroupAsUInt(arityGroup: arityGroupAndKind.arityGroup) + 1 /*output*/;
  if (parseNamedStructuredOp(parser, result, numRegionArgs,
                             regionBuilder: ElementwiseOp::getRegionBuilder())) {
    return parser.emitError(loc: parser.getCurrentLocation(),
                            message: "unable to parse elemwise op");
  }

  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    // We need to infer the numDims of the indexing maps from the output
    // type which is already parsed by now.
    auto resultType = result.operands[result.operands.size() - 1].getType();
    auto shapedType = llvm::dyn_cast<ShapedType>(Val&: resultType);
    if (!shapedType)
      return parser.emitError(loc: parser.getCurrentLocation(),
                              message: "return type needs to be shaped type");
    auto numDims = shapedType.getRank();
    indexingMapsAttr = llvm::map_to_vector(
        C: ElementwiseOp::getDefaultIndexingMaps(numMaps: numRegionArgs, numDims,
                                               context: parser.getContext()),
        F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
  }

  result.addAttribute(name: "indexing_maps",
                      attr: parser.getBuilder().getArrayAttr(value: indexingMapsAttr));
  return success();
}
4392
4393void ElementwiseOp::print(OpAsmPrinter &p) {
4394 p << " kind=";
4395 p.printAttribute(attr: getKindAttr());
4396 SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4397 "indexing_maps"};
4398 unsigned arity =
4399 getArityGroupAsUInt(arityGroup: getArityGroupAndKind(kind: getKind()).arityGroup);
4400 unsigned numDims = getResultRank();
4401
4402 SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector<3>(
4403 C: ElementwiseOp::getDefaultIndexingMaps(numMaps: arity + 1 /*output*/, numDims,
4404 context: getContext()),
4405 F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
4406
4407 if (!llvm::equal(LRange: getIndexingMaps(), RRange&: indexingMaps))
4408 p << " indexing_maps = " << llvm::interleaved_array(R: getIndexingMaps());
4409
4410 printNamedStructuredOp(p, op: getOperation(), inputs: getInputs(), outputs: getOutputs(),
4411 elidedAttrs);
4412}
4413
4414LogicalResult ElementwiseOp::verify() {
4415 // All necessary checks are done either by
4416 // - EnumAttr (e.g. unknown operation kind)
4417 // - verifyStructuredOpInterface (incorrect map, sizes).
4418 return success();
4419}
4420
4421/// Implements the block region builder for the ElementwiseOp. This is called by
4422/// 'fillStructuredOpRegion'.
4423void ElementwiseOp::regionBuilder(
4424 ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
4425 function_ref<InFlightDiagnostic()> emitError) {
4426 ElementwiseKind elemwiseKind;
4427 for (auto attr : attrs) {
4428 if (attr.getName() == b.getStringAttr(bytes: "kind")) {
4429 auto kindAttr = dyn_cast<ElementwiseKindAttr>(Val: attr.getValue());
4430 assert(kindAttr && "op kind attribute incorrectly set");
4431 elemwiseKind = kindAttr.getValue();
4432 break;
4433 }
4434 }
4435
4436 ArityGroupAndKind groupAndKind = getArityGroupAndKind(kind: elemwiseKind);
4437 auto arityGroup = groupAndKind.arityGroup;
4438 auto kind = groupAndKind.kind;
4439 if (emitError && block.getNumArguments() !=
4440 getArityGroupAsUInt(arityGroup) + 1 /*output*/) {
4441 emitError() << "Elementwise regionBuilder expects "
4442 << (getArityGroupAsUInt(arityGroup) + 1) << " args, got "
4443 << block.getNumArguments();
4444 return;
4445 }
4446 assert(block.getNumArguments() ==
4447 getArityGroupAsUInt(arityGroup) + 1 /*output*/
4448 && "Elementwise regionBuilder number of block args mismatch");
4449
4450 RegionBuilderHelper helper(b, block);
4451 SmallVector<Value> yields;
4452 Value result;
4453
4454 if (arityGroup == ElementwiseArityGroup::Unary) {
4455 result = helper.buildUnaryFn(unaryFn: kind.unaryFn, arg: block.getArgument(i: 0));
4456
4457 } else if (arityGroup == ElementwiseArityGroup::Binary) {
4458 result = helper.buildBinaryFn(binaryFn: kind.binaryFn, arg0: block.getArgument(i: 0),
4459 arg1: block.getArgument(i: 1));
4460
4461 } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4462 result = helper.buildTernaryFn(ternaryFn: kind.ternaryFn, arg0: block.getArgument(i: 0),
4463 arg1: block.getArgument(i: 1), arg2: block.getArgument(i: 2));
4464
4465 } else {
4466 assert(false && "found unhandled category in elemwise");
4467 }
4468
4469 yields.push_back(Elt: result);
4470 helper.yieldOutputs(values: yields);
4471}
4472
4473LogicalResult ElementwiseOp::fold(FoldAdaptor,
4474 SmallVectorImpl<OpFoldResult> &) {
4475 return memref::foldMemRefCast(op: *this);
4476}
4477
4478void ElementwiseOp::getEffects(
4479 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4480 &effects) {
4481 if (hasPureTensorSemantics())
4482 return;
4483 getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
4484}
4485
4486Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4487 return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
4488}
4489
4490//===----------------------------------------------------------------------===//
4491// PackOp/UnPackOp Common
4492//===----------------------------------------------------------------------===//
4493// Given the (potentially) updated packed type, `newPackedTy`, generates an
4494// updated mixed-tile-sizes attribute. A tile size is updated only
4495// when:
4496// * a dim from newPackedTy is static, and
4497// * the corresponding size from mixedTiles is still dynamic.
4498// Otherwise, the original tile size is preserved.
4499// Note - packed-type-dim and mixed-tile-size should always match!
4500static SmallVector<OpFoldResult>
4501getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4502 SmallVector<OpFoldResult> mixedTiles) {
4503 SmallVector<OpFoldResult> newMixedTileSizes;
4504 for (auto it : llvm::zip(t: cast<ShapedType>(Val&: newPackedTy)
4505 .getShape()
4506 .take_back(N: mixedTiles.size()),
4507 u&: mixedTiles)) {
4508 int64_t shape = std::get<0>(t&: it);
4509 if (shape == ShapedType::kDynamic) {
4510 newMixedTileSizes.push_back(Elt: std::get<1>(t&: it));
4511 continue;
4512 }
4513
4514 // If the current result dim is static, update the dynamic mixed-size
4515 // (provided the original value is dynamic).
4516 OpFoldResult tile = std::get<1>(t&: it);
4517 if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(Val&: tile)) {
4518 // Already a constant
4519 newMixedTileSizes.push_back(Elt: tile);
4520 } else {
4521 assert(getConstantIntValue(tile).value() == shape &&
4522 "tile size and dim size don't match!");
4523 newMixedTileSizes.push_back(
4524 Elt: (rewriter.getIntegerAttr(type: rewriter.getIndexType(), value: shape)));
4525 }
4526 }
4527
4528 return newMixedTileSizes;
4529}
4530
4531template <typename OpTy>
4532static LogicalResult
4533reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4534 ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4535 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4536 "applies to only pack or unpack operations");
4537 int64_t destRank = op.getDestRank();
4538 reifiedReturnShapes.resize(N: 1, NV: SmallVector<OpFoldResult>(destRank));
4539 reifiedReturnShapes[0] =
4540 tensor::getMixedSizes(builder, loc: op.getLoc(), value: op.getDest());
4541 return success();
4542}
4543
4544template <typename OpTy>
4545static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4546 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4547 "applies to only pack or unpack operations");
4548 DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4549 ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4550 SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4551 assert(tiles.size() == dimsToTile.size() &&
4552 "tiles must match indices of dimension to block");
4553 // bind the dimension `i` with the tile factor.
4554 for (auto i : llvm::seq<int64_t>(Begin: 0, End: dimsToTile.size()))
4555 dimAndTileMapping[dimsToTile[i]] = tiles[i];
4556 return dimAndTileMapping;
4557}
4558
4559template <typename OpTy>
4560static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4561 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4562 "applies to only pack or unpack operations");
4563 Builder builder(op);
4564 SmallVector<OpFoldResult> mixedInnerTiles;
4565 unsigned dynamicValIndex = 0;
4566 for (int64_t staticTile : op.getStaticInnerTiles()) {
4567 if (ShapedType::isStatic(dValue: staticTile))
4568 mixedInnerTiles.push_back(Elt: builder.getI64IntegerAttr(value: staticTile));
4569 else
4570 mixedInnerTiles.push_back(Elt: op.getInnerTiles()[dynamicValIndex++]);
4571 }
4572 return mixedInnerTiles;
4573}
4574
4575template <typename OpTy>
4576static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4577 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4578 "applies to only pack or unpack operations");
4579 SmallVector<Value> dynamicTiles;
4580 SmallVector<int64_t> staticTiles;
4581 dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4582 return staticTiles;
4583}
4584
4585/// Returns true if `dimsPos` is invalid. It is invalid when:
4586/// a) It contains duplicate.
4587/// b) At least one dimension is out of bound (`dimPos` is >= 0 and < rank).
4588/// c) The number of elements in `dimsPos` is > than `rank`.
4589static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4590 size_t rank) {
4591 size_t dimsPosSize = dimsPos.size();
4592 if (dimsPosSize > rank)
4593 return true;
4594 DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4595 if (dimsPosSize != uniqued.size())
4596 return true;
4597 return llvm::any_of(Range&: dimsPos, P: [rank](int64_t dimPos) {
4598 return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4599 });
4600}
4601
4602/// Returns true if the dimension of `sourceShape` is smaller than the dimension
4603/// of the `limitShape`.
4604static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4605 ArrayRef<int64_t> limitShape) {
4606 assert(
4607 sourceShape.size() == limitShape.size() &&
4608 "expected source shape rank, and limit of the shape to have same rank");
4609 return llvm::all_of(
4610 Range: llvm::zip(t&: sourceShape, u&: limitShape), P: [](std::tuple<int64_t, int64_t> it) {
4611 int64_t sourceExtent = std::get<0>(t&: it);
4612 int64_t limit = std::get<1>(t&: it);
4613 return ShapedType::isDynamic(dValue: sourceExtent) ||
4614 ShapedType::isDynamic(dValue: limit) || sourceExtent <= limit;
4615 });
4616}
4617
4618template <typename OpTy>
4619static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4620 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4621 "applies to only pack or unpack operations");
4622 Operation *op = packOrUnPack.getOperation();
4623
4624 // Return true if we have a zero-value tile.
4625 auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4626 return llvm::any_of(Range&: tiles, P: isZeroInteger);
4627 };
4628
4629 // Verify tiles. Do not allow zero tiles.
4630 SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4631 if (hasZeros(mixedTiles))
4632 return op->emitError(message: "invalid zero tile factor");
4633
4634 // Verify inner_dims_pos and outer_dims_perm.
4635 RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4636 ? packOrUnPack.getSourceType()
4637 : packOrUnPack.getDestType();
4638 size_t unpackedRank = unpackedType.getRank();
4639 ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4640 ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4641 if (isInvalidPackingPosSpecification(dimsPos: innerDimsPos, rank: unpackedRank))
4642 return op->emitError(message: "invalid inner_dims_pos vector");
4643 if (isInvalidPackingPosSpecification(dimsPos: outerDimPerm, rank: unpackedRank))
4644 return op->emitError(message: "invalid outer_dims_perm vector");
4645 if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4646 return op->emitError(message: "outer_dims_perm must be a permutation or empty");
4647
4648 // Tiling factors must be less than or equal to the input rank for pack (or
4649 // output rank for unpack), and must match the number of `inner_dims_pos`.
4650 if (mixedTiles.size() > unpackedRank) {
4651 return op->emitError(message: "tiling factors must be less than or equal to the "
4652 "input rank for pack or output rank for unpack");
4653 }
4654 if (mixedTiles.size() != innerDimsPos.size()) {
4655 return op->emitError(
4656 message: "tiling factors must equal the number of dimensions to tile");
4657 }
4658
4659 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4660 ? packOrUnPack.getDestType()
4661 : packOrUnPack.getSourceType();
4662 size_t packedRank = packedType.getRank();
4663 // Require output rank to match input rank + number of blocking factors.
4664 size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4665 if (expectedPackedRank != packedRank) {
4666 return op->emitError(
4667 message: "packed rank != (unpacked rank + num tiling factors), got ")
4668 << packedRank << " != " << expectedPackedRank;
4669 }
4670
4671 // Verify result shape is greater than the minimum expected
4672 // by the pack operation, and that the output shape
4673 // represents full tiles.
4674 RankedTensorType expectedPackedType = PackOp::inferPackedType(
4675 sourceType: unpackedType, innerTileSizes: packOrUnPack.getStaticTiles(), innerDimsPos, outerDimsPerm: outerDimPerm);
4676 if (!areAllInBound(sourceShape: expectedPackedType.getShape(), limitShape: packedType.getShape())) {
4677 return op->emitError(message: "the shape of output is not large enough to hold the "
4678 "packed data. Expected at least ")
4679 << expectedPackedType << ", got " << packedType;
4680 }
4681 if (!llvm::all_of(
4682 llvm::zip(t: packedType.getShape().take_back(N: mixedTiles.size()),
4683 u&: mixedTiles),
4684 [](std::tuple<int64_t, OpFoldResult> it) {
4685 int64_t shape = std::get<0>(t&: it);
4686 if (Attribute attr =
4687 llvm::dyn_cast_if_present<Attribute>(Val&: std::get<1>(t&: it))) {
4688 IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(Val&: attr);
4689 int64_t staticTileSize = intAttr.getValue().getSExtValue();
4690 return shape == staticTileSize;
4691 }
4692 return ShapedType::isDynamic(dValue: shape);
4693 })) {
4694 return op->emitError(message: "mismatch in inner tile sizes specified and shaped of "
4695 "tiled dimension in the packed type");
4696 }
4697 return success();
4698}
4699
namespace {
/// Subset of PackOp/UnPackOp fields used to compute the result of applying
/// various permutations to the op.
// TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
// these. These may or may not become true foldings / canonicalizations
// depending on how aggressive we want to be in automatically folding
// transposes.
struct PackOrUnPackTransposeResult {
  // Tiled dimension positions after applying the inner permutation.
  SmallVector<int64_t> innerDimsPos;
  // Tile sizes, permuted alongside `innerDimsPos`.
  SmallVector<OpFoldResult> innerTiles;
  // Permutation of the outer (untiled) dimensions.
  SmallVector<int64_t> outerDimsPerm;
};
} // namespace
4713
4714template <typename OpTy>
4715static PackOrUnPackTransposeResult
4716commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4717 ArrayRef<int64_t> innerPermutation,
4718 ArrayRef<int64_t> outerPermutation) {
4719 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4720 "applies to only pack or unpack operations");
4721 assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4722 "some permutation must be non-empty");
4723 PackOrUnPackTransposeResult metadata;
4724 metadata.innerDimsPos =
4725 SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4726 metadata.innerTiles =
4727 SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4728 int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4729 ? packOrUnPackOp.getSourceRank()
4730 : packOrUnPackOp.getDestRank();
4731 metadata.outerDimsPerm =
4732 packOrUnPackOp.getOuterDimsPerm().empty()
4733 ? llvm::to_vector(Range: llvm::seq<int64_t>(Begin: 0, End: numOuterDims))
4734 : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4735 if (!innerPermutation.empty()) {
4736 assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4737 isPermutationVector(innerPermutation) &&
4738 "invalid inner permutation");
4739 applyPermutationToVector(inVec&: metadata.innerDimsPos, permutation: innerPermutation);
4740 applyPermutationToVector(inVec&: metadata.innerTiles, permutation: innerPermutation);
4741 }
4742 if (!outerPermutation.empty()) {
4743 assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4744 isPermutationVector(outerPermutation) &&
4745 "invalid outer permutation");
4746 applyPermutationToVector(inVec&: metadata.outerDimsPerm, permutation: outerPermutation);
4747 }
4748 return metadata;
4749}
4750
4751//===----------------------------------------------------------------------===//
4752// PackOp
4753//===----------------------------------------------------------------------===//
4754
// Seeds the pretty-printed SSA name of the result with "pack".
void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "pack");
}
4758
/// Custom builder: splits `innerTiles` into static and dynamic components,
/// materializes the optional outer_dims_perm / padding value, and takes the
/// result type directly from `dest`.
void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
                   Value dest, ArrayRef<int64_t> innerDimsPos,
                   ArrayRef<OpFoldResult> innerTiles,
                   std::optional<Value> paddingValue,
                   ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(ofrs: innerTiles, dynamicVec&: dynamicTileSizes, staticVec&: staticTileSizes);
  // Delegate to the generated builder; empty optional attrs become null.
  build(odsBuilder&: builder, odsState&: state, result: dest.getType(), source, dest,
        padding_value: paddingValue ? *paddingValue : nullptr,
        outer_dims_perm: outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(values: outerDimsPerm),
        inner_dims_pos: builder.getDenseI64ArrayAttr(values: innerDimsPos), inner_tiles: dynamicTileSizes,
        static_inner_tiles: builder.getDenseI64ArrayAttr(values: staticTileSizes));
}
4777
/// Reifies the result shape from the `dest` operand (DPS-style op).
LogicalResult
PackOp::reifyResultShapes(OpBuilder &builder,
                          ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(op: *this, builder, reifiedReturnShapes);
}
4783
/// Returns the (tiled dimension -> tile size) mapping for this pack op.
DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(op: *this);
}

/// Returns tile sizes as attributes (static) or SSA values (dynamic).
SmallVector<OpFoldResult> PackOp::getMixedTiles() {
  return getMixedTilesImpl(op: *this);
}

/// Returns tile sizes with dynamic ones mapped to ShapedType::kDynamic.
SmallVector<int64_t> PackOp::getStaticTiles() {
  return getStaticTilesImpl(op: *this);
}
4795
4796ArrayRef<int64_t> PackOp::getAllOuterDims() {
4797 ShapedType inputType = getSourceType();
4798 int64_t inputRank = inputType.getRank();
4799 return getDestType().getShape().take_front(N: inputRank);
4800}
4801
4802SmallVector<int64_t> PackOp::getTiledOuterDims() {
4803 auto innerDimsPos = getInnerDimsPos();
4804 auto packedShape = getDestType().getShape();
4805 SmallVector<int64_t> res;
4806
4807 for (auto index : innerDimsPos)
4808 res.push_back(Elt: packedShape[index]);
4809
4810 return res;
4811}
4812
/// Returns true if packing may create partial (incomplete) tiles, i.e. a
/// padding value is required. Dynamic dims/tiles that cannot be checked
/// statically are conservatively treated as not requiring padding.
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                 ArrayRef<int64_t> innerDimsPos,
                                 ArrayRef<int64_t> outputShape,
                                 ArrayRef<int64_t> outerDimsPerm,
                                 ArrayRef<OpFoldResult> innerTiles) {
  // Outer sizes of the packed layout, permuted back into source dim order.
  SmallVector<int64_t> outputTileSizes(
      outputShape.take_front(N: inputShape.size()));
  if (!outerDimsPerm.empty()) {
    assert(outerDimsPerm.size() == outputTileSizes.size() &&
           "expected output and outer_dims_perm to have same size");
    applyPermutationToVector(inVec&: outputTileSizes,
                             permutation: invertPermutationVector(permutation: outerDimsPerm));
  }
  for (auto [pos, tileSize] : llvm::zip_equal(t&: innerDimsPos, u&: innerTiles)) {
    // Dynamic source dims cannot be checked statically.
    if (ShapedType::isDynamic(dValue: inputShape[pos]))
      continue;
    std::optional<int64_t> constantTile = getConstantIntValue(ofr: tileSize);

    if (!constantTile) {
      // Dynamic tile size: fall back to checking divisibility against the
      // static outer size of the packed layout, when available.
      if (ShapedType::isStatic(dValue: outputTileSizes[pos]) &&
          (inputShape[pos] % outputTileSizes[pos] != 0))
        return true;
    } else if (inputShape[pos] % (*constantTile) != 0) {
      // Static tile does not divide the static dim evenly: padding needed.
      return true;
    }
  }
  return false;
}
4841
LogicalResult PackOp::verify() {
  // Run the checks shared with UnPackOp first (tiles, dims, ranks, shapes).
  if (failed(Result: commonVerifierPackAndUnPackOp(packOrUnPack: *this)))
    return failure();

  // Verify padding value, and bail out if the tile does not divide the
  // dimension fully. In the case of dynamic tile factors or dimensions, having
  // a partial tile is undefined behavior.
  auto paddingValue = getPaddingValue();
  if (paddingValue &&
      paddingValue.getType() != getSourceType().getElementType()) {
    return emitOpError(message: "expected padding_value has ")
           << getSourceType().getElementType()
           << " but got: " << paddingValue.getType();
  }

  // Without a padding value, every tile must divide its dim evenly.
  if (!paddingValue &&
      requirePaddingValue(inputShape: getSourceType().getShape(), innerDimsPos: getInnerDimsPos(),
                          outputShape: getDestType().getShape(), outerDimsPerm: getOuterDimsPerm(),
                          innerTiles: getMixedTiles())) {
    return emitOpError(
        message: "invalid tile factor or output size provided. Only full tiles are "
        "supported when padding_value is not set");
  }
  return success();
}
4867
4868/// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4869/// Value's to kDynamic, even if they are arith.constant values.
4870static SmallVector<int64_t>
4871asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4872 SmallVector<int64_t> result;
4873 for (auto o : ofrs) {
4874 // Have to do this first, as getConstantIntValue special-cases constants.
4875 if (llvm::dyn_cast_if_present<Value>(Val&: o))
4876 result.push_back(Elt: ShapedType::kDynamic);
4877 else
4878 result.push_back(Elt: getConstantIntValue(ofr: o).value_or(u: ShapedType::kDynamic));
4879 }
4880 return result;
4881}
4882
4883/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4884/// the packed type. Having a shared helper helps implement these two methods in
4885/// a way that ensures that they agree on which dimensions are dynamic.
4886static SmallVector<int64_t> getPackOpResultTypeShape(
4887 ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4888 ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4889 SmallVector<int64_t> resultShape = llvm::to_vector(Range&: sourceShape);
4890 for (auto tiledDim : llvm::enumerate(First: llvm::to_vector(Range&: innerDimsPos))) {
4891 if (ShapedType::isDynamic(dValue: resultShape[tiledDim.value()]))
4892 continue;
4893 if (ShapedType::isDynamic(dValue: innerTileSizes[tiledDim.index()])) {
4894 resultShape[tiledDim.value()] = ShapedType::kDynamic;
4895 continue;
4896 }
4897 resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4898 Numerator: resultShape[tiledDim.value()], Denominator: innerTileSizes[tiledDim.index()]);
4899 }
4900
4901 // Swap tile loops if outer_dims_perm is available.
4902 if (!outerDimsPerm.empty())
4903 applyPermutationToVector(inVec&: resultShape, permutation: outerDimsPerm);
4904
4905 // Append the inner tile dimensions.
4906 resultShape.append(in_start: innerTileSizes.begin(), in_end: innerTileSizes.end());
4907 return resultShape;
4908}
4909
4910SmallVector<OpFoldResult> PackOp::getResultShape(
4911 OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4912 ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4913 ArrayRef<int64_t> outerDimsPerm) {
4914 SmallVector<OpFoldResult> resultDims = llvm::to_vector(Range&: sourceDims);
4915
4916 AffineExpr s0, s1;
4917 bindSymbols(ctx: builder.getContext(), exprs&: s0, exprs&: s1);
4918 AffineExpr ceilDivExpr = s0.ceilDiv(other: s1);
4919 for (auto tiledDim : llvm::enumerate(First: llvm::to_vector(Range&: innerDimsPos))) {
4920 resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4921 b&: builder, loc, expr: ceilDivExpr,
4922 operands: {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4923 }
4924 if (!outerDimsPerm.empty())
4925 applyPermutationToVector(inVec&: resultDims, permutation: outerDimsPerm);
4926 resultDims.append(in_start: innerTileSizes.begin(), in_end: innerTileSizes.end());
4927
4928 SmallVector<int64_t> resultTypeShape =
4929 getPackOpResultTypeShape(sourceShape: asShapeWithAnyValueAsDynamic(ofrs: sourceDims),
4930 innerTileSizes: asShapeWithAnyValueAsDynamic(ofrs: innerTileSizes),
4931 innerDimsPos, outerDimsPerm);
4932
4933 // Fix-up `resultDims` to ensure that they are Value's if and only if the
4934 // result type shape says it's a dynamic dim. This is needed as callers may
4935 // use dispatchIndexOpFoldResults on the result, and rely on exact number of
4936 // dynamic dims returned by that.
4937 for (unsigned i = 0; i < resultDims.size(); ++i) {
4938 if (ShapedType::isStatic(dValue: resultTypeShape[i]))
4939 continue;
4940 resultDims[i] =
4941 getValueOrCreateConstantIndexOp(b&: builder, loc, ofr: resultDims[i]);
4942 }
4943
4944 return resultDims;
4945}
4946
/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
                                         ArrayRef<int64_t> innerTileSizes,
                                         ArrayRef<int64_t> innerDimsPos,
                                         ArrayRef<int64_t> outerDimsPerm) {
  // Delegate to the shared helper so this agrees with getResultShape() on
  // which dimensions are dynamic.
  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
      sourceShape: sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
  return RankedTensorType::get(shape: resultShape, elementType: sourceType.getElementType());
}
4957
/// Creates an empty destination tensor with the shape the pack result will
/// have: tiled dims ceil-divided by their tile size, outer dims permuted by
/// `outerDimsPerm`, and the inner tile sizes appended.
Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                      ArrayRef<OpFoldResult> innerTileSizes,
                                      ArrayRef<int64_t> innerDimsPos,
                                      ArrayRef<int64_t> outerDimsPerm) {
  AffineExpr dim0, dim1;
  bindDims(ctx: b.getContext(), exprs&: dim0, exprs&: dim1);
  // Folded ceildiv of two OpFoldResults via an affine apply.
  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
    return affine::makeComposedFoldedAffineApply(b, loc, expr: dim0.ceilDiv(other: dim1),
                                                 operands: {v1, v2});
  };

  // Materialize the source sizes (attr for static, tensor.dim for dynamic).
  SmallVector<OpFoldResult> mixedSizes;
  for (auto [index, value] : llvm::enumerate(
           First: llvm::cast<RankedTensorType>(Val: source.getType()).getShape())) {
    if (ShapedType::isDynamic(dValue: value))
      mixedSizes.push_back(
          Elt: b.create<tensor::DimOp>(location: loc, args&: source, args&: index).getResult());
    else
      mixedSizes.push_back(Elt: b.getIndexAttr(value));
  }
  // Divide each tiled dim by its tile size (rounding up).
  for (auto it : llvm::zip(t&: innerDimsPos, u&: innerTileSizes)) {
    int64_t dimPos = std::get<0>(t&: it);
    OpFoldResult tileSize = std::get<1>(t&: it);
    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
  }
  if (!outerDimsPerm.empty())
    applyPermutationToVector<OpFoldResult>(inVec&: mixedSizes, permutation: outerDimsPerm);

  mixedSizes.append(in_start: innerTileSizes.begin(), in_end: innerTileSizes.end());
  auto elemType = llvm::cast<ShapedType>(Val: source.getType()).getElementType();
  return b.create<tensor::EmptyOp>(location: loc, args&: mixedSizes, args&: elemType);
}
4990
/// Builds a clone of this pack op whose metadata is permuted by
/// `innerPermutation` / `outerPermutation`, including a freshly created
/// destination tensor matching the transposed layout.
PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                     ArrayRef<int64_t> innerPermutation,
                                     ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      packOrUnPackOp: *this, innerPermutation, outerPermutation);
  Value transposedDest =
      createDestinationTensor(b, loc, source: getSource(), innerTileSizes: metadata.innerTiles,
                              innerDimsPos: metadata.innerDimsPos, outerDimsPerm: metadata.outerDimsPerm);
  return b.create<PackOp>(location: loc, args: getSource(), args&: transposedDest,
                          args&: metadata.innerDimsPos, args&: metadata.innerTiles,
                          args: getPaddingValue(), args&: metadata.outerDimsPerm);
}
5003
5004/// Returns true if the tiles and the tiled dims are constant.
5005template <typename OpTy>
5006bool areTilesAndTiledDimsAllConstant(OpTy op) {
5007 static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
5008 "applies to only pack or unpack operations");
5009 ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
5010 ? op.getDestType()
5011 : op.getSourceType();
5012 SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
5013 for (auto [dimDest, tile] : llvm::zip(
5014 t: packedType.getShape().take_back(N: mixedTiles.size()), u&: mixedTiles)) {
5015 std::optional<int64_t> constTileSize = getConstantIntValue(ofr: tile);
5016 if (!constTileSize || ShapedType::isDynamic(dValue: dimDest))
5017 return false;
5018 }
5019 return true;
5020}
5021
Speculation::Speculatability PackOp::getSpeculatability() {
  // With an explicit padding value, partial tiles are well-defined, so the
  // op can always be speculated.
  if (getPaddingValue())
    return Speculation::Speculatable;

  // The verifier already rejects operations where we can statically prove
  // that the tile sizes do not divide the dimension evenly; thus it is enough
  // to check that all tiles and tiled inner dimensions are constant.
  if (!areTilesAndTiledDimsAllConstant(op: *this))
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}
5034
5035// Return true if `inner_dims_pos` and `outer_dims_perm` target the same
5036// dimensions for pack and unpack.
5037static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
5038 if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
5039 return false;
5040 if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
5041 return true;
5042 // Outer dims permutation is optional.
5043 // To compare unbalanced pack-unpack pair, treat no permutation as equal to
5044 // identity permutation.
5045 return isIdentityPermutation(permutation: packOp.getOuterDimsPerm()) &&
5046 isIdentityPermutation(permutation: unPackOp.getOuterDimsPerm());
5047}
5048
5049// Return true if pack and unpack have the same tiles.
5050// Same SSA values or same integer constants.
5051static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
5052 auto packTiles = packOp.getMixedTiles();
5053 auto unPackTiles = unPackOp.getMixedTiles();
5054 if (packTiles.size() != unPackTiles.size())
5055 return false;
5056 for (size_t i = 0, e = packTiles.size(); i < e; i++) {
5057 if (!isEqualConstantIntOrValue(ofr1: packTiles[i], ofr2: unPackTiles[i]))
5058 return false;
5059 }
5060 return true;
5061}
5062
/// Returns true if the pack op does not need a padding value.
static bool paddingIsNotNeeded(PackOp op) {
  auto srcType = op.getSourceType();
  // A dynamic tiled dim or a dynamic tile size may produce partial tiles;
  // be conservative in those cases.
  if (llvm::any_of(Range: op.getInnerDimsPos(),
                   P: [&](int64_t pos) { return srcType.isDynamicDim(idx: pos); }))
    return false;
  if (ShapedType::isDynamicShape(dSizes: op.getStaticInnerTiles()))
    return false;
  // Everything is static: padding is needed only when some tile does not
  // divide its dimension evenly.
  return !PackOp::requirePaddingValue(
      inputShape: srcType.getShape(), innerDimsPos: op.getInnerDimsPos(), outputShape: op.getDestType().getShape(),
      outerDimsPerm: op.getOuterDimsPerm(), innerTiles: op.getMixedTiles());
}
5075
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `packOp` and populates each with the inferred static shape.
static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  srcShape.assign(in_start: packOp.getSourceType().getShape().begin(),
                  in_end: packOp.getSourceType().getShape().end());
  destShape.assign(in_start: packOp.getDestType().getShape().begin(),
                   in_end: packOp.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(R: packOp.getInnerDimsPos());
  // Inverse permutation maps a source position to its dest position.
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!packOp.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(permutation: packOp.getOuterDimsPerm());
  int srcRank = packOp.getSourceRank();
  for (auto i : llvm::seq<int64_t>(Begin: 0, End: srcRank)) {
    // Skip tiled dims: their dest size is the dim divided by the tile, so a
    // static size cannot be copied across directly.
    if (innerDims.contains(key: i))
      continue;
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      destPos = inverseOuterDimsPerm[srcPos];
    // Nothing to infer when both sides agree on static-ness.
    if (ShapedType::isDynamic(dValue: srcShape[srcPos]) ==
        ShapedType::isDynamic(dValue: destShape[destPos])) {
      continue;
    }
    // Exactly one side is static: propagate its size to the dynamic side.
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(dValue: size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
5111
/// Canonicalizations:
/// 1. pack(unpack(x)) with matching metadata folds to x.
/// 2. Drops the padding value operand when padding is provably not needed.
/// 3. Propagates inferable static shapes via tensor.cast ops on the operands
///    and (if the dest type changed) a cast back to the original result type.
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
  // Fold a pack(unpack(x)) to x.
  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
    if (unPackOp.getSourceType() != packOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(op: packOp, newValues: unPackOp.getSource());
    return success();
  }

  // Fold optional PaddingValue operand away if padding is not needed.
  if (packOp.getPaddingValue() && paddingIsNotNeeded(op: packOp)) {
    rewriter.startOpModification(op: packOp);
    packOp.getPaddingValueMutable().clear();
    rewriter.finalizeOpModification(op: packOp);
    return success();
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(packOp, srcShape, destShape)) {
    Location loc = packOp.getLoc();
    Value source = packOp.getSource();
    if (srcShape != packOp.getSourceType().getShape()) {
      auto newSrcType = packOp.getSourceType().clone(shape: srcShape);
      source =
          rewriter.create<tensor::CastOp>(location: loc, args&: newSrcType, args: packOp.getSource());
    }
    Value dest = packOp.getDest();
    RankedTensorType originalResultType = packOp.getDestType();
    bool needUpdateDestType = (destShape != originalResultType.getShape());
    if (needUpdateDestType) {
      auto newDestType = packOp.getDestType().clone(shape: destShape);
      dest =
          rewriter.create<tensor::CastOp>(location: loc, args&: newDestType, args: packOp.getDest());
    }
    // Rewire the op onto the casted operands; the result type follows dest.
    rewriter.modifyOpInPlace(root: packOp, callable: [&] {
      packOp.getSourceMutable().assign(value: source);
      packOp.getDestMutable().assign(value: dest);
      packOp.getResult().setType(cast<RankedTensorType>(Val: dest.getType()));
    });
    // Insert a cast if needed so existing users keep the original type.
    if (needUpdateDestType) {
      rewriter.setInsertionPointAfter(packOp);
      auto castOp =
          rewriter.create<tensor::CastOp>(location: loc, args&: originalResultType, args&: packOp);
      rewriter.replaceAllUsesExcept(from: packOp, to: castOp, exceptedUser: castOp);
    }
    return success();
  }

  return failure();
}
5168
/// Returns true if `packOp` (a pack or unpack) is equivalent to a simple
/// pad/unpad: no dimensions are transposed and every dimension that bubbles
/// outward has size one.
template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
                           RankedTensorType packedTensorType) {
  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                    std::is_same<PackOrUnpackOp, UnPackOp>::value,
                "Function meant for pack/unpack");
  // This is a pad if packing only adds ones and we don't transpose dimensions.

  // Check that we are not transposing any dimensions.
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  int64_t numPackedDims = innerDimsPos.size();
  auto orderedDims = llvm::to_vector<4>(Range: llvm::seq<int64_t>(Begin: 0, End: numPackedDims));
  if (orderedDims != innerDimsPos) {
    // Dimensions don't happen in order.
    return false;
  }

  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
  int64_t packedRank = packedTensorType.getRank();
  // At this point we know that we are taking numPackedDims outer
  // dimensions and pushing them all the way as the inner most dimensions.
  // What's left on the outer most dimensions is, in this order:
  // - the factor of the packed dimensions, then
  // - the untouched dimensions
  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
  // if all the dimensions that bubble outerward are ones.
  // Therefore check that all the dimensions but the numPackedDims inner most
  // ones are ones.
  return llvm::all_of(
      llvm::seq<int64_t>(Begin: 0, End: packedRank - numPackedDims),
      [&packedShape](int64_t i) { return packedShape[i] == 1; });
}
5201
/// Returns true if this pack is equivalent to a tensor.pad (see
/// isLikePadUnPad above).
bool PackOp::isLikePad() {
  auto packedTensorType =
      llvm::cast<RankedTensorType>(Val: (*this)->getResultTypes().front());
  return isLikePadUnPad(packOp: *this, packedTensorType);
}
5207
/// Folds a pack of a constant (DenseElementsAttr) source into a reshaped
/// constant, taking the optional padding value into account.
OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
  std::optional<Attribute> paddingValue;
  if (auto pad = adaptor.getPaddingValue())
    paddingValue = pad;
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          source: llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getSource()),
          result: getDestType(), cst: paddingValue))
    return reshapedSource;
  return {};
}
5218
5219/// Folds a tensor.cast op into a consuming PackOp op if the
5220/// `tensor.cast` has source that is more static than the consuming op.
5221///
5222/// Example:
5223/// ```mlir
5224/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5225/// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5226/// ```
5227///
5228/// folds into:
5229///
5230/// ```mlir
5231/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5232/// ```
struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
  using OpRewritePattern<PackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(PackOp op,
                                PatternRewriter &rewriter) const override {
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    // Fold the casts into the operands; the result types may become more
    // static as well.
    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResTy&: newResultTypes);

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes =
        getNewMixedTileSizes(rewriter, newPackedTy: newResultTypes[0], mixedTiles: op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    PackOp newOp = rewriter.create<PackOp>(
        location: op.getLoc(), args&: newOperands[0], args&: newOperands[1], args: op.getInnerDimsPos(),
        args&: newMixedTileSizes, args: op.getPaddingValue(), args: op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op, casting back to the original (less static) result type if
    // the new op's type differs so existing users remain valid.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  location: op->getLoc(), args: oldResult.getType(), args&: newResult)
                            : newResult;

    rewriter.replaceOp(op, newValues: {replacement});

    return success();
  }
};
5271
5272//===----------------------------------------------------------------------===//
5273// UnPackOp
5274//===----------------------------------------------------------------------===//
5275
/// Suggest the prefix "unpack" for this op's SSA result in IR dumps.
void UnPackOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "unpack");
}
5280
/// Reify the result shape via the implementation shared with PackOp.
LogicalResult
UnPackOp::reifyResultShapes(OpBuilder &builder,
                            ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  return reifyResultShapesImpl(op: *this, builder, reifiedReturnShapes);
}
5286
/// Map each tiled dimension position to its tile size, via the shared
/// pack/unpack helper.
DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
  return getDimAndTileMappingImpl(op: *this);
}
5290
/// Tile sizes as a mix of static attributes and dynamic values, via the
/// shared pack/unpack helper.
SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
  return getMixedTilesImpl(op: *this);
}
5294
/// Static tile sizes (dynamic entries encoded as sentinels), via the shared
/// pack/unpack helper.
SmallVector<int64_t> UnPackOp::getStaticTiles() {
  return getStaticTilesImpl(op: *this);
}
5298
5299ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5300 ShapedType destType = getDestType();
5301 int64_t destRank = destType.getRank();
5302 return getSourceType().getShape().take_front(N: destRank);
5303}
5304
5305SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5306 auto innerDimsPos = getInnerDimsPos();
5307 auto packedShape = getSourceType().getShape();
5308 SmallVector<int64_t> res;
5309
5310 for (auto index : innerDimsPos)
5311 res.push_back(Elt: packedShape[index]);
5312
5313 return res;
5314}
5315
/// Delegate verification to the verifier shared between PackOp and UnPackOp.
LogicalResult UnPackOp::verify() {
  return commonVerifierPackAndUnPackOp(packOrUnPack: *this);
}
5319
5320Speculation::Speculatability UnPackOp::getSpeculatability() {
5321 // See PackOp::getSpeculatability.
5322 if (!areTilesAndTiledDimsAllConstant(op: *this))
5323 return Speculation::NotSpeculatable;
5324
5325 return Speculation::Speculatable;
5326}
5327
/// Convenience builder: splits `innerTiles` into static and dynamic parts and
/// forwards to the generated builder. The result type is `dest`'s type.
void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
                     Value dest, ArrayRef<int64_t> innerDimsPos,
                     ArrayRef<OpFoldResult> innerTiles,
                     ArrayRef<int64_t> outerDimsPerm) {
  assert(innerDimsPos.size() == innerTiles.size() &&
         "number of tile sizes specified must match the specified number of "
         "original dimensions to be tiled");
  SmallVector<int64_t> staticTileSizes;
  SmallVector<Value> dynamicTileSizes;
  dispatchIndexOpFoldResults(ofrs: innerTiles, dynamicVec&: dynamicTileSizes, staticVec&: staticTileSizes);
  // An empty outer permutation is encoded as an absent attribute.
  build(odsBuilder&: builder, odsState&: state, result: dest.getType(), source, dest,
        outer_dims_perm: outerDimsPerm.empty() ? nullptr
                              : builder.getDenseI64ArrayAttr(values: outerDimsPerm),
        inner_dims_pos: builder.getDenseI64ArrayAttr(values: innerDimsPos), inner_tiles: dynamicTileSizes,
        static_inner_tiles: builder.getDenseI64ArrayAttr(values: staticTileSizes));
}
5344
5345Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5346 Value source,
5347 ArrayRef<OpFoldResult> innerTileSizes,
5348 ArrayRef<int64_t> innerDimsPos,
5349 ArrayRef<int64_t> outerDimsPerm) {
5350 AffineExpr sym0, sym1;
5351 bindSymbols(ctx: b.getContext(), exprs&: sym0, exprs&: sym1);
5352 auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5353 return affine::makeComposedFoldedAffineApply(b, loc, expr: sym0 * sym1, operands: {v1, v2});
5354 };
5355
5356 SmallVector<OpFoldResult> mixedSizes;
5357 auto srcType = llvm::cast<RankedTensorType>(Val: source.getType());
5358 for (auto i :
5359 llvm::seq<unsigned>(Begin: 0, End: srcType.getRank() - innerTileSizes.size())) {
5360 if (srcType.isDynamicDim(idx: i))
5361 mixedSizes.push_back(Elt: b.create<tensor::DimOp>(location: loc, args&: source, args&: i).getResult());
5362 else
5363 mixedSizes.push_back(Elt: b.getIndexAttr(value: srcType.getDimSize(idx: i)));
5364 }
5365 if (!outerDimsPerm.empty()) {
5366 applyPermutationToVector<OpFoldResult>(
5367 inVec&: mixedSizes, permutation: invertPermutationVector(permutation: outerDimsPerm));
5368 }
5369
5370 for (auto [dimPos, tileSize] : llvm::zip_equal(t&: innerDimsPos, u&: innerTileSizes))
5371 mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5372
5373 auto elemType = srcType.getElementType();
5374 return b.create<tensor::EmptyOp>(location: loc, args&: mixedSizes, args&: elemType);
5375}
5376
/// Build a clone of this unpack whose tiling metadata (inner dims pos, inner
/// tiles, outer perm) is permuted by `innerPermutation`/`outerPermutation`.
/// The source is replaced with `transposedSource`; the dest is reused as-is.
UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
                                         Value transposedSource,
                                         ArrayRef<int64_t> innerPermutation,
                                         ArrayRef<int64_t> outerPermutation) {
  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
      packOrUnPackOp: *this, innerPermutation, outerPermutation);
  return b.create<UnPackOp>(location: loc, args&: transposedSource, args: getDest(),
                            args&: metadata.innerDimsPos, args&: metadata.innerTiles,
                            args&: metadata.outerDimsPerm);
}
5387
/// Returns true if the `srcShape` or `destShape` is different from the one in
/// `op` and populates each with the inferred static shape.
static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
                             SmallVectorImpl<int64_t> &destShape) {
  bool changeNeeded = false;
  // Start from the op's current (possibly dynamic) shapes.
  srcShape.assign(in_start: op.getSourceType().getShape().begin(),
                  in_end: op.getSourceType().getShape().end());
  destShape.assign(in_start: op.getDestType().getShape().begin(),
                   in_end: op.getDestType().getShape().end());
  llvm::SmallSetVector<int64_t, 4> innerDims;
  innerDims.insert_range(R: op.getInnerDimsPos());
  SmallVector<int64_t> inverseOuterDimsPerm;
  if (!op.getOuterDimsPerm().empty())
    inverseOuterDimsPerm = invertPermutationVector(permutation: op.getOuterDimsPerm());
  int destRank = op.getDestRank();
  for (auto i : llvm::seq<int64_t>(Begin: 0, End: destRank)) {
    // Tiled dims don't map 1:1 between source and dest; skip them.
    if (innerDims.contains(key: i))
      continue;
    // Map the dest position back to the matching source outer position,
    // undoing the outer permutation when present.
    int64_t srcPos = i;
    int64_t destPos = i;
    if (!inverseOuterDimsPerm.empty())
      srcPos = inverseOuterDimsPerm[destPos];
    // Nothing to learn when both sides agree on static-ness.
    if (ShapedType::isDynamic(dValue: srcShape[srcPos]) ==
        ShapedType::isDynamic(dValue: destShape[destPos])) {
      continue;
    }
    // Exactly one side is static: propagate it to the dynamic side.
    int64_t size = srcShape[srcPos];
    if (ShapedType::isDynamic(dValue: size))
      size = destShape[destPos];
    srcShape[srcPos] = size;
    destShape[destPos] = size;
    changeNeeded = true;
  }
  return changeNeeded;
}
5423
/// Canonicalization patterns for UnPackOp, tried in order:
///   1. unpack(pack(x)) -> x
///   2. unpack(destinationStyleOp(x)) -> unpack(x)  (dest is only a shape)
///   3. extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
///   4. tighten operand/result types when static shapes can be inferred.
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
                                     PatternRewriter &rewriter) {
  /// unpack(pack(x)) -> x
  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
    // The round-trip only folds when types and tiling metadata match exactly
    // and the pack introduced no padding.
    if (packOp.getSourceType() != unPackOp.getDestType())
      return failure();
    if (packOp.getPaddingValue() ||
        !hasSameInnerOuterAttribute(packOp, unPackOp) ||
        !haveSameTiles(packOp, unPackOp))
      return failure();
    rewriter.replaceOp(op: unPackOp, newValues: packOp.getSource());
    return success();
  }
  /// unpack(destinationStyleOp(x)) -> unpack(x)
  if (auto dstStyleOp =
          unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
    // Replace the dest with the producer's matching init operand; the dest
    // value itself is not read, it only carries the shape.
    auto destValue = cast<OpResult>(Val: unPackOp.getDest());
    Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
    rewriter.modifyOpInPlace(root: unPackOp,
                             callable: [&]() { unPackOp.setDpsInitOperand(i: 0, value: newDest); });
    return success();
  }
  /// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
  if (unPackOp->hasOneUse()) {
    auto extractSliceUser =
        dyn_cast<tensor::ExtractSliceOp>(Val: *unPackOp->getUsers().begin());
    // Only rank-preserving, zero-offset, unit-stride slices can be absorbed
    // into the unpack destination.
    if (extractSliceUser &&
        areAllConstantIntValue(ofrs: extractSliceUser.getMixedOffsets(), value: 0) &&
        areAllConstantIntValue(ofrs: extractSliceUser.getMixedStrides(), value: 1) &&
        extractSliceUser.getSourceType().getRank() ==
            extractSliceUser.getResultType().getRank()) {
      OpBuilder::InsertionGuard g(rewriter);
      rewriter.setInsertionPoint(unPackOp);
      auto newDest = rewriter.create<tensor::ExtractSliceOp>(
          location: unPackOp->getLoc(), args: unPackOp.getDest(),
          args: extractSliceUser.getMixedOffsets(), args: extractSliceUser.getMixedSizes(),
          args: extractSliceUser.getMixedStrides());
      rewriter.modifyOpInPlace(root: unPackOp, callable: [&]() {
        unPackOp.setDpsInitOperand(i: 0, value: newDest);
        unPackOp.getResult().setType(newDest.getType());
      });
      rewriter.replaceOp(op: extractSliceUser, newOp: unPackOp);
      return success();
    }
  }

  // Insert tensor.cast ops if static shape inference is available.
  SmallVector<int64_t> srcShape, destShape;
  if (inferStaticShape(op: unPackOp, srcShape, destShape)) {
    Location loc = unPackOp.getLoc();
    // Cast operands whose inferred shape is more static, rebuild the unpack
    // with them, then cast the result back to the original result type.
    Value source = unPackOp.getSource();
    if (srcShape != unPackOp.getSourceType().getShape()) {
      auto newSrcType = unPackOp.getSourceType().clone(shape: srcShape);
      source = rewriter.create<tensor::CastOp>(location: loc, args&: newSrcType,
                                               args: unPackOp.getSource());
    }
    Value dest = unPackOp.getDest();
    if (destShape != unPackOp.getDestType().getShape()) {
      auto newDestType = unPackOp.getDestType().clone(shape: destShape);
      dest =
          rewriter.create<tensor::CastOp>(location: loc, args&: newDestType, args: unPackOp.getDest());
    }
    Value newOp = rewriter.create<UnPackOp>(
        location: loc, args&: source, args&: dest, args: unPackOp.getInnerDimsPos(), args: unPackOp.getMixedTiles(),
        args: unPackOp.getOuterDimsPerm());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(
        op: unPackOp, args: unPackOp.getResult().getType(), args&: newOp);
    return success();
  }

  return failure();
}
5496
5497bool UnPackOp::isLikeUnPad() {
5498 RankedTensorType packedTensorType = getSourceType();
5499 return isLikePadUnPad(packOp: *this, packedTensorType);
5500}
5501
/// Constant-fold: when the packed source is a constant that can simply be
/// reshaped into the result type, return that reshaped constant.
OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
  if (OpFoldResult reshapedSource = reshapeConstantSource(
          source: llvm::dyn_cast_if_present<DenseElementsAttr>(Val: adaptor.getSource()),
          result: getResult().getType()))
    return reshapedSource;
  return {};
}
5509
5510/// Folds a tensor.cast op into a consuming UnPackOp op if the
5511/// `tensor.cast` has source that is more static than the consuming op.
5512///
5513/// Example:
5514/// ```mlir
5515/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5516/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5517/// ```
5518///
5519/// folds into:
5520///
5521/// ```mlir
/// %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5523/// ```
struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
  using OpRewritePattern<UnPackOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(UnPackOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when at least one operand is produced by a foldable
    // tensor.cast (one whose source type is more static).
    if (!tensor::hasFoldableTensorCastOperand(op))
      return failure();

    // Swap cast results for cast sources; `newResultTypes` is refined
    // accordingly.
    SmallVector<Type> newResultTypes(op->getResultTypes());
    SmallVector<Value> newOperands =
        tensor::getUpdatedOperandsAfterCastOpFolding(op, newResTy&: newResultTypes);
    Value sourceTensor = newOperands[0];

    // Get the updated mixed-tile-sizes attribute.
    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
        rewriter, newPackedTy: sourceTensor.getType(), mixedTiles: op.getMixedTiles());

    // Clone op.
    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
    // this point. However, in practice, we use them for things that we'd like
    // to preserve. Implement a better abstraction.
    UnPackOp newOp = rewriter.create<UnPackOp>(
        location: op.getLoc(), args&: sourceTensor, args&: newOperands[1], args: op.getInnerDimsPos(),
        args&: newMixedTileSizes, args: op.getOuterDimsPerm());
    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());

    // Replace op.
    // If the folded op has a more static result type, cast back to the
    // original type so existing users still type-check.
    Value oldResult = op.getResult();
    Value newResult = newOp.getResult();
    Value replacement = (newResult.getType() != oldResult.getType())
                            ? rewriter.create<tensor::CastOp>(
                                  location: op->getLoc(), args: oldResult.getType(), args&: newResult)
                            : newResult;

    rewriter.replaceOp(op, newValues: {replacement});

    return success();
  }
};
5563
5564//===----------------------------------------------------------------------===//
5565// BatchReduceMatmulOp
5566//===----------------------------------------------------------------------===//
5567SmallVector<utils::IteratorType> BatchReduceMatmulOp::getIteratorTypesArray() {
5568 return SmallVector<utils::IteratorType>{
5569 utils::IteratorType::reduction, utils::IteratorType::parallel,
5570 utils::IteratorType::parallel, utils::IteratorType::reduction};
5571}
5572
5573SmallVector<AffineMap>
5574BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
5575 AffineExpr d0, d1, d2, d3;
5576 SmallVector<AffineMap> indexingMaps;
5577 bindDims(ctx: context, exprs&: d0, exprs&: d1, exprs&: d2, exprs&: d3);
5578 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d0, d1, d3}, context));
5579 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d0, d3, d2}, context));
5580 indexingMaps.push_back(Elt: AffineMap::get(dimCount: 4, symbolCount: 0, results: {d1, d2}, context));
5581 return indexingMaps;
5582}
5583
/// Region block arguments: lhs element, rhs element, accumulator element.
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
5585
/// Delegate to the shared library-call-name generator.
std::string BatchReduceMatmulOp::getLibraryCallName() {
  return generateLibraryCallName(op: getOperation());
}
5589
5590/// Check if the op has broadcast and/or transpose semantic. Returns true if
5591/// the user defined indexing maps are not equal to default map.
5592bool BatchReduceMatmulOp::hasUserDefinedMaps() {
5593 SmallVector<AffineMap, 3> defaultMaps =
5594 getDefaultIndexingMaps(context: this->getContext());
5595 SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
5596 return defaultMaps != explicitMaps;
5597}
5598
/// Returns true if the given bcastMap map is a valid broadcast map. A valid
/// broadcast map must include K dimension.
/// TODO: Strict inclusion of K dimension in the broadcast map is not
/// necessary for both input matrices simultaneously. We can relax this
/// condition to have K dimension for one input matrix map and infer the K
/// dimension for other input matrix map from the one already having K
/// dimension.
bool BatchReduceMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap,
                                                    bool isLHS) {
  assert(bcastMap.getNumResults() < 3 &&
         "Expected less than 3 result dim expr.");
  bool isValid = false;
  // Positions in the (batch, m, n, k) iteration space.
  enum Indices { batchPos, mPos, nPos, kPos };
  if (bcastMap.getNumResults() == 1) {
    // Rank-1 operand: the single result dim must be the reduction (K) dim.
    AffineExpr expr = bcastMap.getResult(idx: 0);
    isValid = expr.isFunctionOfDim(position: kPos);
  } else if (bcastMap.getNumResults() == 2) {
    AffineExpr expr0 = bcastMap.getResult(idx: 0);
    AffineExpr expr1 = bcastMap.getResult(idx: 1);
    // Rank-2 LHS accepts (batch|m, k); rank-2 RHS accepts (batch, k) or
    // (k, n).
    isValid =
        isLHS ? ((expr0.isFunctionOfDim(position: batchPos) ||
                  expr0.isFunctionOfDim(position: mPos)) &&
                 expr1.isFunctionOfDim(position: kPos))
              : ((expr0.isFunctionOfDim(position: batchPos) &&
                  expr1.isFunctionOfDim(position: kPos)) ||
                 (expr0.isFunctionOfDim(position: kPos) && expr1.isFunctionOfDim(position: nPos)));
  }
  return isValid;
}
5628
/// Build the scalar body: out = out + cast(lhs) * cast(rhs), with both
/// operands cast (signed) to the accumulator element type.
void BatchReduceMatmulOp::regionBuilder(
    ImplicitLocOpBuilder &b, Block &block, ArrayRef<NamedAttribute> attrs,
    function_ref<InFlightDiagnostic()> emitError) {
  // Prefer a diagnostic over an assert when a handler was provided.
  if (emitError && block.getNumArguments() != 3) {
    emitError() << "BatchReduceMatmulOp regionBuilder expects 3 args, got "
                << block.getNumArguments();
    return;
  }
  assert(block.getNumArguments() == 3 &&
         "BatchReduceMatmulOp regionBuilder expects 3 args");
  RegionBuilderHelper helper(b, block);
  SmallVector<Value> yields;

  // Arguments: 0 = lhs element, 1 = rhs element, 2 = accumulator element.
  auto toType = block.getArgument(i: 2).getType();
  Value castValA =
      helper.buildTypeFn(typeFn: TypeFn::cast_signed, toType, operand: block.getArgument(i: 0));
  Value castValB =
      helper.buildTypeFn(typeFn: TypeFn::cast_signed, toType, operand: block.getArgument(i: 1));
  Value mulVal = helper.buildBinaryFn(binaryFn: BinaryFn::mul, arg0: castValA, arg1: castValB);
  Value addVal =
      helper.buildBinaryFn(binaryFn: BinaryFn::add, arg0: block.getArgument(i: 2), arg1: mulVal);
  yields.push_back(Elt: addVal);
  helper.yieldOutputs(values: yields);
}
5653
/// Parse an optional `indexing_maps = [...]` clause, then the common
/// named-structured-op syntax. Without the clause, the canonical
/// batch-reduce-matmul maps are attached.
ParseResult BatchReduceMatmulOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  SmallVector<Attribute, 3> indexingMapsAttr;
  Attribute mapAttr;
  if (succeeded(Result: parser.parseOptionalKeyword(keyword: "indexing_maps"))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseLSquare())
      return failure();

    // Comma-separated list; each entry must be an affine-map attribute.
    do {
      if (parser.parseAttribute(result&: mapAttr))
        return failure();
      if (!isa<AffineMapAttr>(Val: mapAttr)) {
        return parser.emitError(loc: parser.getCurrentLocation(),
                                message: "expected affine map attribute");
      }
      indexingMapsAttr.push_back(Elt: mapAttr);

      if (parser.parseOptionalComma())
        break;
    } while (true);

    if (parser.parseRSquare())
      return failure();
  }
  // Initialize indexingMaps, if not supplied explicitly.
  if (indexingMapsAttr.empty()) {
    indexingMapsAttr = llvm::map_to_vector(
        C: BatchReduceMatmulOp::getDefaultIndexingMaps(context: parser.getContext()),
        F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });
  }
  result.addAttribute(name: "indexing_maps",
                      attr: parser.getBuilder().getArrayAttr(value: indexingMapsAttr));
  return ::parseNamedStructuredOp(parser, result,
                                  numRegionArgs: BatchReduceMatmulOp::getNumRegionArgs(),
                                  regionBuilder: BatchReduceMatmulOp::getRegionBuilder());
}
5692
/// Print `indexing_maps = [...]` only when the attached maps differ from the
/// defaults, then defer to the common named-structured-op printer.
void BatchReduceMatmulOp::print(OpAsmPrinter &p) {
  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
      C: BatchReduceMatmulOp::getDefaultIndexingMaps(context: getContext()),
      F: [](AffineMap map) -> Attribute { return AffineMapAttr::get(value: map); });

  if (!llvm::equal(LRange: getIndexingMaps(), RRange&: indexingMaps)) {
    p << " indexing_maps = [";
    llvm::interleaveComma(c: getIndexingMaps(), os&: p,
                          each_fn: [&](Attribute attr) { p.printAttribute(attr); });
    p << "]";
  }

  // These attributes are either implied by the syntax or printed above.
  SmallVector<StringRef, 3> elidedAttrs = {
      "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
  ::printNamedStructuredOp(p, op: getOperation(), inputs: getInputs(), outputs: getOutputs(),
                           elidedAttrs);
}
5710
5711/// Verify the user defined indexing maps.
5712LogicalResult BatchReduceMatmulOp::verify() {
5713 // Verification of pure batch_reduce_matmul is handled by
5714 // verifyStructuredOpInterface().
5715 if (!hasUserDefinedMaps())
5716 return success();
5717
5718 for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
5719 if (failed(Result: verifyExtendedBatchVariantMatmulSemantic(batchVariantMatmulOp: *this, opIndex)))
5720 return failure();
5721 }
5722 return success();
5723}
/// Folding hook: absorb memref.cast ops on operands via the shared helper.
LogicalResult BatchReduceMatmulOp::fold(FoldAdaptor,
                                        SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(op: *this);
}
/// Memory effects: none with pure tensor semantics; otherwise the generic
/// structured-op effects on memref operands.
void BatchReduceMatmulOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (hasPureTensorSemantics())
    return;
  getGenericEffectsImpl(effects, linalgOp: cast<LinalgOp>(Val: getOperation()));
}
5735
/// Speculatability is decided by the generic structured-op implementation.
Speculation::Speculatability BatchReduceMatmulOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(linalgOp: cast<LinalgOp>(Val: getOperation()));
}
5739
5740} // namespace linalg
5741} // namespace mlir
5742
5743//===----------------------------------------------------------------------===//
5744// LinalgDialect
5745//===----------------------------------------------------------------------===//
5746
/// Register dialect-wide canonicalization patterns (applied to all ops).
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
              FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(arg: getContext());
}
5752
/// Materialize folder-produced constants as arith.constant ops.
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
5758

source code of mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp